forked from jiuyuan/InfiniTensor
fix: fix castType mlu (#117)
This commit is contained in:
parent
9cf6c30e1c
commit
1438f14a25
|
@ -22,142 +22,142 @@ class CastCnnl : public BangKernelWithoutConfig {
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||||
cnnlCastDataType_t NlCastType;
|
cnnlCastDataType_t NlCastType;
|
||||||
CastObj::CastType type = op->getType();
|
CastType type = op->getType();
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case CastObj::Float2Int64:
|
case CastType::Float2Int64:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_FLOAT_TO_INT64;
|
NlCastType = CNNL_CAST_FLOAT_TO_INT64;
|
||||||
break;
|
break;
|
||||||
case CastObj::Float2Int32:
|
case CastType::Float2Int32:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_FLOAT_TO_INT32;
|
NlCastType = CNNL_CAST_FLOAT_TO_INT32;
|
||||||
break;
|
break;
|
||||||
case CastObj::Float2Int16:
|
case CastType::Float2Int16:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_FLOAT_TO_INT16;
|
NlCastType = CNNL_CAST_FLOAT_TO_INT16;
|
||||||
break;
|
break;
|
||||||
case CastObj::Float2Int8:
|
case CastType::Float2Int8:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_FLOAT_TO_INT8;
|
NlCastType = CNNL_CAST_FLOAT_TO_INT8;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int322Float:
|
case CastType::Int322Float:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT32_TO_FLOAT;
|
NlCastType = CNNL_CAST_INT32_TO_FLOAT;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int322Int8:
|
case CastType::Int322Int8:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT32_TO_INT8;
|
NlCastType = CNNL_CAST_INT32_TO_INT8;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int322Int16:
|
case CastType::Int322Int16:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT32_TO_INT16;
|
NlCastType = CNNL_CAST_INT32_TO_INT16;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int162Float:
|
case CastType::Int162Float:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT16_TO_FLOAT;
|
NlCastType = CNNL_CAST_INT16_TO_FLOAT;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int162Int32:
|
case CastType::Int162Int32:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT16_TO_INT32;
|
NlCastType = CNNL_CAST_INT16_TO_INT32;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int82Float:
|
case CastType::Int82Float:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT8_TO_FLOAT;
|
NlCastType = CNNL_CAST_INT8_TO_FLOAT;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int82Int16:
|
case CastType::Int82Int16:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT8_TO_INT16;
|
NlCastType = CNNL_CAST_INT8_TO_INT16;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int82Int32:
|
case CastType::Int82Int32:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT8_TO_INT32;
|
NlCastType = CNNL_CAST_INT8_TO_INT32;
|
||||||
break;
|
break;
|
||||||
case CastObj::Uint82Float:
|
case CastType::Uint82Float:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_UINT8_TO_FLOAT;
|
NlCastType = CNNL_CAST_UINT8_TO_FLOAT;
|
||||||
break;
|
break;
|
||||||
case CastObj::Uint82Int32:
|
case CastType::Uint82Int32:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_UINT8_TO_INT32;
|
NlCastType = CNNL_CAST_UINT8_TO_INT32;
|
||||||
break;
|
break;
|
||||||
case CastObj::Uint82Int64:
|
case CastType::Uint82Int64:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_UINT8_TO_INT64;
|
NlCastType = CNNL_CAST_UINT8_TO_INT64;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int322Int64:
|
case CastType::Int322Int64:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT32_TO_INT64;
|
NlCastType = CNNL_CAST_INT32_TO_INT64;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int642Int32:
|
case CastType::Int642Int32:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT64_TO_INT32;
|
NlCastType = CNNL_CAST_INT64_TO_INT32;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int642Uint32:
|
case CastType::Int642Uint32:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT64_TO_UINT32;
|
NlCastType = CNNL_CAST_INT64_TO_UINT32;
|
||||||
break;
|
break;
|
||||||
case CastObj::Int642Float:
|
case CastType::Int642Float:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
NlCastType = CNNL_CAST_INT64_TO_FLOAT;
|
NlCastType = CNNL_CAST_INT64_TO_FLOAT;
|
||||||
break;
|
break;
|
||||||
case CastObj::Uint322Int64:
|
case CastType::Uint322Int64:
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
|
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
|
||||||
checkCnnlError(cnnlSetTensorDescriptor(
|
checkCnnlError(cnnlSetTensorDescriptor(
|
||||||
|
|
|
@ -23,7 +23,7 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
// GPU
|
// GPU
|
||||||
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastObj::Float2Int32);
|
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastType::Float2Int32);
|
||||||
auto outputGpu = gpuOp->getOutput();
|
auto outputGpu = gpuOp->getOutput();
|
||||||
bangGraph->dataMalloc();
|
bangGraph->dataMalloc();
|
||||||
bangRuntime->run(bangGraph);
|
bangRuntime->run(bangGraph);
|
||||||
|
|
Loading…
Reference in New Issue