forked from jiuyuan/InfiniTensor
Fix: OpType and print device tensors
This commit is contained in:
parent
01fc19795d
commit
2812900ea2
|
@ -121,6 +121,7 @@ class OpRegistry {
|
|||
FOP(ConvBackwardData);
|
||||
FOP(Matmul);
|
||||
FOP(ConvTrans);
|
||||
FOP(ConvTransNHWC);
|
||||
FOP(G2BMM);
|
||||
FOP(GBMM);
|
||||
FOP(Pad);
|
||||
|
@ -209,7 +210,8 @@ class OpRegistry {
|
|||
//
|
||||
FOP(MemBound);
|
||||
default:
|
||||
IT_ASSERT(false);
|
||||
IT_ASSERT(false, "Unknown OpType " +
|
||||
std::to_string(enum_to_underlying(opType)));
|
||||
break;
|
||||
}
|
||||
#undef FOP
|
||||
|
|
|
@ -109,13 +109,13 @@ class TensorObj : public TensorBaseObj {
|
|||
size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
|
||||
|
||||
private:
|
||||
template <class T> string dataToString() const {
|
||||
template <class T> string dataToString(void *rawPtr) const {
|
||||
std::stringstream builder;
|
||||
builder << "Tensor: " << guid << std::endl;
|
||||
|
||||
auto numDims = shape.size();
|
||||
auto dimSzVec = vector<int>(numDims, 1);
|
||||
auto ptr = data->getPtr<T *>();
|
||||
T *ptr = (T *)rawPtr;
|
||||
dimSzVec[numDims - 1] = shape[numDims - 1];
|
||||
|
||||
for (int i = numDims - 1; i != 0; --i)
|
||||
|
|
|
@ -304,5 +304,4 @@ DEFINE_UNARY_OBJ(Sqrt, OpType::Sqrt)
|
|||
DEFINE_UNARY_OBJ(Rsqrt, OpType::Rsqrt)
|
||||
DEFINE_UNARY_OBJ(Round, OpType::Round)
|
||||
DEFINE_UNARY_OBJ(Square, OpType::Square)
|
||||
DEFINE_UNARY_OBJ(PRelu, OpType::PRelu)
|
||||
}; // namespace infini
|
||||
|
|
|
@ -67,12 +67,19 @@ vector<size_t> TensorObj::getStride() const {
|
|||
|
||||
void TensorObj::printData() const {
|
||||
IT_ASSERT(data != nullptr);
|
||||
if (!runtime->isCpu())
|
||||
IT_TODO_HALT();
|
||||
void *ptr = nullptr;
|
||||
Blob buffer;
|
||||
if (!runtime->isCpu()) { // copy data to main memory
|
||||
buffer = NativeCpuRuntimeObj::getInstance()->allocBlob(getBytes());
|
||||
runtime->copyBlobToCPU(buffer->getPtr<void *>(),
|
||||
getRawDataPtr<void *>(), getBytes());
|
||||
ptr = buffer->getPtr<void *>();
|
||||
} else
|
||||
ptr = data->getPtr<float *>();
|
||||
|
||||
#define TRY_PRINT(N) \
|
||||
if (dtype == DataType(N)) \
|
||||
std::cout << dataToString<DT<N>::t>() << std::endl;
|
||||
std::cout << dataToString<DT<N>::t>(ptr) << std::endl;
|
||||
|
||||
TRY_PRINT(0) // fmt: new line
|
||||
else TRY_PRINT(1) //
|
||||
|
|
Loading…
Reference in New Issue