forked from jiuyuan/InfiniTensor
fix: fix castType mlu (#117)
This commit is contained in:
parent
9cf6c30e1c
commit
1438f14a25
|
@ -22,142 +22,142 @@ class CastCnnl : public BangKernelWithoutConfig {
|
|||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
cnnlCastDataType_t NlCastType;
|
||||
CastObj::CastType type = op->getType();
|
||||
CastType type = op->getType();
|
||||
switch (type) {
|
||||
case CastObj::Float2Int64:
|
||||
case CastType::Float2Int64:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_FLOAT_TO_INT64;
|
||||
break;
|
||||
case CastObj::Float2Int32:
|
||||
case CastType::Float2Int32:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_FLOAT_TO_INT32;
|
||||
break;
|
||||
case CastObj::Float2Int16:
|
||||
case CastType::Float2Int16:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_FLOAT_TO_INT16;
|
||||
break;
|
||||
case CastObj::Float2Int8:
|
||||
case CastType::Float2Int8:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_FLOAT_TO_INT8;
|
||||
break;
|
||||
case CastObj::Int322Float:
|
||||
case CastType::Int322Float:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT32_TO_FLOAT;
|
||||
break;
|
||||
case CastObj::Int322Int8:
|
||||
case CastType::Int322Int8:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT32_TO_INT8;
|
||||
break;
|
||||
case CastObj::Int322Int16:
|
||||
case CastType::Int322Int16:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT32_TO_INT16;
|
||||
break;
|
||||
case CastObj::Int162Float:
|
||||
case CastType::Int162Float:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT16_TO_FLOAT;
|
||||
break;
|
||||
case CastObj::Int162Int32:
|
||||
case CastType::Int162Int32:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT16_TO_INT32;
|
||||
break;
|
||||
case CastObj::Int82Float:
|
||||
case CastType::Int82Float:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT8_TO_FLOAT;
|
||||
break;
|
||||
case CastObj::Int82Int16:
|
||||
case CastType::Int82Int16:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT8_TO_INT16;
|
||||
break;
|
||||
case CastObj::Int82Int32:
|
||||
case CastType::Int82Int32:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT8_TO_INT32;
|
||||
break;
|
||||
case CastObj::Uint82Float:
|
||||
case CastType::Uint82Float:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_UINT8_TO_FLOAT;
|
||||
break;
|
||||
case CastObj::Uint82Int32:
|
||||
case CastType::Uint82Int32:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_UINT8_TO_INT32;
|
||||
break;
|
||||
case CastObj::Uint82Int64:
|
||||
case CastType::Uint82Int64:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_UINT8_TO_INT64;
|
||||
break;
|
||||
case CastObj::Int322Int64:
|
||||
case CastType::Int322Int64:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT32_TO_INT64;
|
||||
break;
|
||||
case CastObj::Int642Int32:
|
||||
case CastType::Int642Int32:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT64_TO_INT32;
|
||||
break;
|
||||
case CastObj::Int642Uint32:
|
||||
case CastType::Int642Uint32:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT64_TO_UINT32;
|
||||
break;
|
||||
case CastObj::Int642Float:
|
||||
case CastType::Int642Float:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_INT64_TO_FLOAT;
|
||||
break;
|
||||
case CastObj::Uint322Int64:
|
||||
case CastType::Uint322Int64:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(
|
||||
|
|
|
@ -23,7 +23,7 @@ void testCast(const std::function<void(void *, size_t, DataType)> &generator,
|
|||
// GPU
|
||||
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastObj::Float2Int32);
|
||||
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastType::Float2Int32);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
bangGraph->dataMalloc();
|
||||
bangRuntime->run(bangGraph);
|
||||
|
|
Loading…
Reference in New Issue