Fix: OpType and print device tensors

This commit is contained in:
Liyan Zheng 2023-04-18 20:28:08 +08:00
parent 01fc19795d
commit 2812900ea2
4 changed files with 15 additions and 7 deletions

View File

@ -121,6 +121,7 @@ class OpRegistry {
FOP(ConvBackwardData);
FOP(Matmul);
FOP(ConvTrans);
FOP(ConvTransNHWC);
FOP(G2BMM);
FOP(GBMM);
FOP(Pad);
@ -209,7 +210,8 @@ class OpRegistry {
//
FOP(MemBound);
default:
IT_ASSERT(false);
IT_ASSERT(false, "Unknown OpType " +
std::to_string(enum_to_underlying(opType)));
break;
}
#undef FOP

View File

@ -109,13 +109,13 @@ class TensorObj : public TensorBaseObj {
size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
private:
template <class T> string dataToString() const {
template <class T> string dataToString(void *rawPtr) const {
std::stringstream builder;
builder << "Tensor: " << guid << std::endl;
auto numDims = shape.size();
auto dimSzVec = vector<int>(numDims, 1);
auto ptr = data->getPtr<T *>();
T *ptr = (T *)rawPtr;
dimSzVec[numDims - 1] = shape[numDims - 1];
for (int i = numDims - 1; i != 0; --i)

View File

@ -304,5 +304,4 @@ DEFINE_UNARY_OBJ(Sqrt, OpType::Sqrt)
DEFINE_UNARY_OBJ(Rsqrt, OpType::Rsqrt)
DEFINE_UNARY_OBJ(Round, OpType::Round)
DEFINE_UNARY_OBJ(Square, OpType::Square)
DEFINE_UNARY_OBJ(PRelu, OpType::PRelu)
}; // namespace infini

View File

@ -67,12 +67,19 @@ vector<size_t> TensorObj::getStride() const {
void TensorObj::printData() const {
IT_ASSERT(data != nullptr);
if (!runtime->isCpu())
IT_TODO_HALT();
void *ptr = nullptr;
Blob buffer;
if (!runtime->isCpu()) { // copy data to main memory
buffer = NativeCpuRuntimeObj::getInstance()->allocBlob(getBytes());
runtime->copyBlobToCPU(buffer->getPtr<void *>(),
getRawDataPtr<void *>(), getBytes());
ptr = buffer->getPtr<void *>();
} else
ptr = data->getPtr<float *>();
#define TRY_PRINT(N) \
if (dtype == DataType(N)) \
std::cout << dataToString<DT<N>::t>() << std::endl;
std::cout << dataToString<DT<N>::t>(ptr) << std::endl;
TRY_PRINT(0) // fmt: new line
else TRY_PRINT(1) //