diff --git a/include/core/operator.h b/include/core/operator.h index cca67297..5ca1e918 100644 --- a/include/core/operator.h +++ b/include/core/operator.h @@ -121,6 +121,7 @@ class OpRegistry { FOP(ConvBackwardData); FOP(Matmul); FOP(ConvTrans); + FOP(ConvTransNHWC); FOP(G2BMM); FOP(GBMM); FOP(Pad); @@ -209,7 +210,8 @@ class OpRegistry { // FOP(MemBound); default: - IT_ASSERT(false); + IT_ASSERT(false, "Unknown OpType " + + std::to_string(enum_to_underlying(opType))); break; } #undef FOP diff --git a/include/core/tensor.h b/include/core/tensor.h index 8417a2b2..61ba2bac 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -109,13 +109,13 @@ class TensorObj : public TensorBaseObj { size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const; private: - template string dataToString() const { + template string dataToString(void *rawPtr) const { std::stringstream builder; builder << "Tensor: " << guid << std::endl; auto numDims = shape.size(); auto dimSzVec = vector(numDims, 1); - auto ptr = data->getPtr(); + T *ptr = (T *)rawPtr; dimSzVec[numDims - 1] = shape[numDims - 1]; for (int i = numDims - 1; i != 0; --i) diff --git a/include/operators/unary.h b/include/operators/unary.h index 7d096269..1df2b4a7 100644 --- a/include/operators/unary.h +++ b/include/operators/unary.h @@ -304,5 +304,4 @@ DEFINE_UNARY_OBJ(Sqrt, OpType::Sqrt) DEFINE_UNARY_OBJ(Rsqrt, OpType::Rsqrt) DEFINE_UNARY_OBJ(Round, OpType::Round) DEFINE_UNARY_OBJ(Square, OpType::Square) -DEFINE_UNARY_OBJ(PRelu, OpType::PRelu) }; // namespace infini diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 00ee1b7d..bc173216 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -67,12 +67,19 @@ vector TensorObj::getStride() const { void TensorObj::printData() const { IT_ASSERT(data != nullptr); - if (!runtime->isCpu()) - IT_TODO_HALT(); + void *ptr = nullptr; + Blob buffer; + if (!runtime->isCpu()) { // copy data to main memory + buffer = NativeCpuRuntimeObj::getInstance()->allocBlob(getBytes()); + runtime->copyBlobToCPU(buffer->getPtr(), + getRawDataPtr(), getBytes()); + ptr = buffer->getPtr(); + } else + ptr = data->getPtr(); #define TRY_PRINT(N) \ if (dtype == DataType(N)) \ - std::cout << dataToString::t>() << std::endl; + std::cout << dataToString::t>(ptr) << std::endl; TRY_PRINT(0) // fmt: new line else TRY_PRINT(1) //