fix: fix castType mlu (#117)

zhangyunze 2023-08-22 14:54:32 +08:00 committed by GitHub
parent 9cf6c30e1c
commit 1438f14a25
2 changed files with 22 additions and 22 deletions

@@ -22,142 +22,142 @@ class CastCnnl : public BangKernelWithoutConfig {
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
cnnlCastDataType_t NlCastType;
- CastObj::CastType type = op->getType();
+ CastType type = op->getType();
switch (type) {
- case CastObj::Float2Int64:
+ case CastType::Float2Int64:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
NlCastType = CNNL_CAST_FLOAT_TO_INT64;
break;
- case CastObj::Float2Int32:
+ case CastType::Float2Int32:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
NlCastType = CNNL_CAST_FLOAT_TO_INT32;
break;
- case CastObj::Float2Int16:
+ case CastType::Float2Int16:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
NlCastType = CNNL_CAST_FLOAT_TO_INT16;
break;
- case CastObj::Float2Int8:
+ case CastType::Float2Int8:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
NlCastType = CNNL_CAST_FLOAT_TO_INT8;
break;
- case CastObj::Int322Float:
+ case CastType::Int322Float:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
NlCastType = CNNL_CAST_INT32_TO_FLOAT;
break;
- case CastObj::Int322Int8:
+ case CastType::Int322Int8:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
NlCastType = CNNL_CAST_INT32_TO_INT8;
break;
- case CastObj::Int322Int16:
+ case CastType::Int322Int16:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
NlCastType = CNNL_CAST_INT32_TO_INT16;
break;
- case CastObj::Int162Float:
+ case CastType::Int162Float:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
NlCastType = CNNL_CAST_INT16_TO_FLOAT;
break;
- case CastObj::Int162Int32:
+ case CastType::Int162Int32:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
NlCastType = CNNL_CAST_INT16_TO_INT32;
break;
- case CastObj::Int82Float:
+ case CastType::Int82Float:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
NlCastType = CNNL_CAST_INT8_TO_FLOAT;
break;
- case CastObj::Int82Int16:
+ case CastType::Int82Int16:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
NlCastType = CNNL_CAST_INT8_TO_INT16;
break;
- case CastObj::Int82Int32:
+ case CastType::Int82Int32:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
NlCastType = CNNL_CAST_INT8_TO_INT32;
break;
- case CastObj::Uint82Float:
+ case CastType::Uint82Float:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
NlCastType = CNNL_CAST_UINT8_TO_FLOAT;
break;
- case CastObj::Uint82Int32:
+ case CastType::Uint82Int32:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
NlCastType = CNNL_CAST_UINT8_TO_INT32;
break;
- case CastObj::Uint82Int64:
+ case CastType::Uint82Int64:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
NlCastType = CNNL_CAST_UINT8_TO_INT64;
break;
- case CastObj::Int322Int64:
+ case CastType::Int322Int64:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
NlCastType = CNNL_CAST_INT32_TO_INT64;
break;
- case CastObj::Int642Int32:
+ case CastType::Int642Int32:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
NlCastType = CNNL_CAST_INT64_TO_INT32;
break;
- case CastObj::Int642Uint32:
+ case CastType::Int642Uint32:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
NlCastType = CNNL_CAST_INT64_TO_UINT32;
break;
- case CastObj::Int642Float:
+ case CastType::Int642Float:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
NlCastType = CNNL_CAST_INT64_TO_FLOAT;
break;
- case CastObj::Uint322Int64:
+ case CastType::Uint322Int64:
checkCnnlError(cnnlSetTensorDescriptor(
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(

@@ -23,7 +23,7 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
- auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastObj::Float2Int32);
+ auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastType::Float2Int32);
auto outputGpu = gpuOp->getOutput();
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
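
For illustration only, not part of the commit: a minimal sketch of a call site after the rename, following the pattern in the test hunk above. It assumes the test's template parameter T is the CastObj operator and that the runtime, CPU input tensor, and graph setup exist as in the surrounding test; only the enum qualification changes.

// Sketch: build and run a float->int32 cast on the BANG (MLU) backend,
// qualifying the cast kind with the standalone CastType enum instead of CastObj::.
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto castOp = bangGraph->addOp<CastObj>(inputGpu, nullptr, CastType::Float2Int32);
auto outputGpu = castOp->getOutput();
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);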