diff --git a/src/kernels/bang/cast.cc b/src/kernels/bang/cast.cc index 35da0190..a3d56654 100644 --- a/src/kernels/bang/cast.cc +++ b/src/kernels/bang/cast.cc @@ -22,142 +22,142 @@ class CastCnnl : public BangKernelWithoutConfig { checkCnnlError(cnnlCreateTensorDescriptor(&aDesc)); checkCnnlError(cnnlCreateTensorDescriptor(&cDesc)); cnnlCastDataType_t NlCastType; - CastObj::CastType type = op->getType(); + CastType type = op->getType(); switch (type) { - case CastObj::Float2Int64: + case CastType::Float2Int64: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array)); NlCastType = CNNL_CAST_FLOAT_TO_INT64; break; - case CastObj::Float2Int32: + case CastType::Float2Int32: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); NlCastType = CNNL_CAST_FLOAT_TO_INT32; break; - case CastObj::Float2Int16: + case CastType::Float2Int16: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array)); NlCastType = CNNL_CAST_FLOAT_TO_INT16; break; - case CastObj::Float2Int8: + case CastType::Float2Int8: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array)); NlCastType = CNNL_CAST_FLOAT_TO_INT8; break; - case CastObj::Int322Float: + case CastType::Int322Float: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); NlCastType = CNNL_CAST_INT32_TO_FLOAT; break; - case CastObj::Int322Int8: + case CastType::Int322Int8: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array)); NlCastType = CNNL_CAST_INT32_TO_INT8; break; - case CastObj::Int322Int16: + case CastType::Int322Int16: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array)); NlCastType = CNNL_CAST_INT32_TO_INT16; break; - case CastObj::Int162Float: + case CastType::Int162Float: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); NlCastType = CNNL_CAST_INT16_TO_FLOAT; break; - case CastObj::Int162Int32: + case CastType::Int162Int32: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); NlCastType = CNNL_CAST_INT16_TO_INT32; break; - case CastObj::Int82Float: + case CastType::Int82Float: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); NlCastType = CNNL_CAST_INT8_TO_FLOAT; break; - case CastObj::Int82Int16: + case CastType::Int82Int16: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT16, 4, dim_array)); NlCastType = CNNL_CAST_INT8_TO_INT16; break; - case CastObj::Int82Int32: + case CastType::Int82Int32: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT8, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); NlCastType = CNNL_CAST_INT8_TO_INT32; break; - case CastObj::Uint82Float: + case CastType::Uint82Float: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); NlCastType = CNNL_CAST_UINT8_TO_FLOAT; break; - case CastObj::Uint82Int32: + case CastType::Uint82Int32: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); NlCastType = CNNL_CAST_UINT8_TO_INT32; break; - case CastObj::Uint82Int64: + case CastType::Uint82Int64: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT8, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array)); NlCastType = CNNL_CAST_UINT8_TO_INT64; break; - case CastObj::Int322Int64: + case CastType::Int322Int64: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array)); NlCastType = CNNL_CAST_INT32_TO_INT64; break; - case CastObj::Int642Int32: + case CastType::Int642Int32: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array)); NlCastType = CNNL_CAST_INT64_TO_INT32; break; - case CastObj::Int642Uint32: + case CastType::Int642Uint32: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array)); NlCastType = CNNL_CAST_INT64_TO_UINT32; break; - case CastObj::Int642Float: + case CastType::Int642Float: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT64, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array)); NlCastType = CNNL_CAST_INT64_TO_FLOAT; break; - case CastObj::Uint322Int64: + case CastType::Uint322Int64: checkCnnlError(cnnlSetTensorDescriptor( aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_UINT32, 4, dim_array)); checkCnnlError(cnnlSetTensorDescriptor( diff --git a/test/kernels/bang/test_bang_cast.cc b/test/kernels/bang/test_bang_cast.cc index 7bcc44ea..57bea058 100644 --- a/test/kernels/bang/test_bang_cast.cc +++ b/test/kernels/bang/test_bang_cast.cc @@ -23,7 +23,7 @@ void testCast(const std::function &generator, // GPU Graph bangGraph = make_ref(bangRuntime); auto inputGpu = bangGraph->cloneTensor(inputCpu); - auto gpuOp = bangGraph->addOp(inputGpu, nullptr, CastObj::Float2Int32); + auto gpuOp = bangGraph->addOp(inputGpu, nullptr, CastType::Float2Int32); auto outputGpu = gpuOp->getOutput(); bangGraph->dataMalloc(); bangRuntime->run(bangGraph);