add cast operation

This commit is contained in:
wanghailu 2022-12-28 08:57:52 +00:00
parent 5329e66d0f
commit 0079d1271b
10 changed files with 282 additions and 4 deletions

View File

@ -6,8 +6,9 @@ class DataType {
public:
static const DataType Float32;
static const DataType UInt32;
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)};
static constexpr std::string_view names[]{"Float32", "UInt32"};
static const DataType Int32;
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t), sizeof(int32_t)};
static constexpr std::string_view names[]{"Float32", "UInt32", "Int32"};
private:
int index;
@ -29,9 +30,11 @@ class DataType {
inline const DataType DataType::Float32(0);
inline const DataType DataType::UInt32(1);
inline const DataType DataType::Int32(2);
// Method definitions are out of the declaration due to GCC bug:
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
template <> inline DataType DataType::get<float>() { return Float32; }
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
template <> inline DataType DataType::get<int32_t>() { return Int32; }
} // namespace infini
} // namespace infini

View File

@ -81,6 +81,7 @@ enum class OpType {
Transform,
AddN,
MulN,
Cast,
//
MemBound = 300,
};
@ -170,6 +171,7 @@ class OpRegistry {
FOP(Transform);
FOP(AddN);
FOP(MulN);
FOP(Cast);
//
FOP(MemBound);
default:
@ -251,6 +253,7 @@ class OperatorObj : public Object {
* function.
*/
bool checkValid(GraphObj *graph);
bool checkValid(GraphObj *graph, DataType type);
OpPerfKey getOpPerfKey() const;
/**
* @brief Hash operator attributes. Input and output shapes are not

View File

@ -72,6 +72,7 @@ class TensorObj : public TensorBaseObj {
private:
void printDataFloat() const;
void printDataUint32_t() const;
void printDataInt32_t() const;
template <typename T>
bool equalDataImpl(const T *a, const T *b, size_t size) const {

View File

@ -80,6 +80,34 @@ class TransformObj : public OperatorObj {
vector<int> getOpAttrVector() const override;
};
// Element-type conversion operator (analogous to ONNX Cast). The concrete
// conversion is selected by a CastType value; the shape is preserved and
// only the element type of the output differs from the input.
class CastObj : public OperatorObj {
public:
// Every source->destination dtype pair the op may request.
// NOTE(review): enumerator order fixes the integer values (Float2Half = 0,
// then sequential) that are serialized into workload/attr vectors — do not
// reorder or insert in the middle.
enum CastType { Float2Half = 0, Float2HalfIEEE754, Float2Double, Float2Int64, Float2Int32, Float2Int16, Float2Int8, Float2Bool,
Half2Float, Half2Int32, Half2Int64, Half2Int16, Half2Int8, Half2Uint8, Half2Bool, Half2FloatInf,
Int322Float, Int322Half, Int322Int8, Int322Int16,
Int162Float, Int162Half, Int162Int32,
Int82Float, Int82Half, Int82Int16, Int82Int32,
Uint82Float, Uint82Half, Uint82Int32, Uint82Int64,
Bool2Float, Bool2Half, Bool2Int32,
Int322Int64, Int322Bool,
Int642Int32, Int642Uint32, Int642Float, Int642Half,
Uint642Uint32,
Uint322Int64, Uint322Uint64,
Double2Float};
// @param graph owning graph; when non-null the output tensor is created by
//        checkValid, so `output` may be nullptr.
// @param input tensor whose elements are to be converted.
// @param type which conversion to perform.
CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
// Output shape equals the input shape (cast never reshapes).
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;
// Returns the requested conversion kind.
CastType getType() const { return castType; }
int numInputs() const override { return 1; }
int numOutputs() const override { return 1; }
private:
CastType castType;
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
#define DEFINE_UNARY_OBJ(prefix, type) \
class prefix##Obj : public UnaryObj { \
public: \

View File

@ -61,4 +61,4 @@ OpVec GraphObj::getComputeOps() const {
return opList;
};
} // namespace infini
} // namespace infini

View File

@ -82,6 +82,29 @@ bool OperatorObj::checkValid(GraphObj *graph) {
return true;
}
/**
 * @brief Variant of checkValid(GraphObj *) that forces every output tensor
 *        to be created with the given data type instead of inheriting one
 *        from the inputs. Used by ops (e.g. Cast) whose output dtype
 *        differs from the input dtype.
 * @param graph When non-null, outputs must not yet exist and are created
 *        here via graph->addTensor. When null, outputs must already exist
 *        and only their shapes are verified.
 * @param type Element type to assign to every created output.
 * @return true if shape inference succeeded and outputs are consistent.
 */
bool OperatorObj::checkValid(GraphObj *graph, DataType type) {
    auto optShapes = inferShape();
    if (!optShapes) // shape inference failed
        return false;

    const vector<Shape> &shapes = *optShapes;
    if (shapes.size() != outputs.size())
        return false;

    if (graph) { // if graph != nullptr, outputs should be created
        for (size_t i = 0; i < outputs.size(); i++) {
            IT_ASSERT(!outputs[i]);
            // All outputs share the forced dtype; no intermediate
            // vector<DataType> is needed.
            outputs[i] = graph->addTensor(shapes[i], type);
        }
    } else { // if outputs have been created, check their shapes
        // NOTE(review): pre-created outputs are not checked against `type`;
        // verify whether importer paths rely on that before tightening.
        for (size_t i = 0; i < shapes.size(); ++i) {
            if (shapes[i] != outputs[i]->getDims())
                return false;
        }
    }
    return true;
}
optional<vector<Shape>> OperatorObj::inferShape() const {
return inferShape(inputs);
}

View File

@ -69,6 +69,8 @@ void TensorObj::printData() const {
printDataFloat();
else if (dtype == DataType::UInt32)
printDataUint32_t();
else if (dtype == DataType::Int32)
printDataInt32_t();
else
IT_TODO_HALT();
}
@ -128,6 +130,34 @@ void TensorObj::printDataUint32_t() const {
}
}
// Pretty-prints an Int32 tensor with nested brackets per dimension,
// e.g. a 2x2 tensor prints as "[[0, 1], [2, 3]]" with a newline after
// each innermost row.
void TensorObj::printDataInt32_t() const {
    IT_ASSERT(data != nullptr);
    std::cout << "Tensor: " << guid << std::endl;

    const auto rank = shape.size();
    auto ptr = data->getPtr<int32_t *>();

    // suffixProd[d] = shape[d] * shape[d+1] * ... * shape[rank-1];
    // an index divisible by suffixProd[d] opens dimension d's bracket.
    auto suffixProd = std::vector<int>(rank, 1);
    suffixProd[rank - 1] = shape[rank - 1];
    for (int d = rank - 1; d != 0; --d)
        suffixProd[d - 1] = suffixProd[d] * shape[d - 1];

    for (size_t idx = 0, total = size(); idx < total; ++idx) {
        // Opening brackets for every dimension that starts here.
        for (size_t d = 0; d < rank; ++d)
            if (idx % suffixProd[d] == 0)
                std::cout << "[";

        std::cout << ptr[idx];

        // Closing brackets for every dimension that ends here.
        for (size_t d = 0; d < rank; ++d)
            if ((int)idx % suffixProd[d] == suffixProd[d] - 1)
                std::cout << "]";

        if (idx != size() - 1)
            std::cout << ", ";
        // Newline at the end of each innermost row.
        if ((int)idx % suffixProd[rank - 1] == suffixProd[rank - 1] - 1)
            std::cout << std::endl;
    }
}
bool TensorObj::equalData(const Tensor &rhs) const {
IT_ASSERT(data != nullptr);
IT_ASSERT(rhs->data != nullptr);
@ -142,6 +172,9 @@ bool TensorObj::equalData(const Tensor &rhs) const {
else if (getDType() == DataType::Float32)
return equalDataImpl(getRawDataPtr<float *>(),
rhs->getRawDataPtr<float *>(), size());
else if (getDType() == DataType::Int32)
return equalDataImpl(getRawDataPtr<int32_t *>(),
rhs->getRawDataPtr<int32_t *>(), size());
else
IT_TODO_HALT();
}
@ -155,6 +188,8 @@ void TensorObj::dataMalloc() {
bytesPerElement = sizeof(float);
else if (getDType() == DataType::UInt32)
bytesPerElement = sizeof(uint32_t);
else if (getDType() == DataType::Int32)
bytesPerElement = sizeof(int32_t);
data = runtime->allocBlob(size() * bytesPerElement);
}

116
src/kernels/bang/cast.cc Normal file
View File

@ -0,0 +1,116 @@
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "operators/unary.h"
namespace infini {
class CastCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<CastObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
cnnlTensorDescriptor_t aDesc, cDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_TODO_HALT();
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
cnnlCastDataType_t NlCastType;
CastObj::CastType type = op->getType();
switch(type){
case CastObj::Float2Half:
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_HALF, 4, dim_array));
NlCastType = CNNL_CAST_FLOAT_TO_HALF;
break;
case CastObj::Float2HalfIEEE754:
case CastObj::Float2Double:
case CastObj::Float2Int64:
case CastObj::Float2Int32:
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
NlCastType = CNNL_CAST_FLOAT_TO_INT32;
case CastObj::Float2Int16:
case CastObj::Float2Int8:
case CastObj::Float2Bool:
//Todo
break;
case CastObj::Half2Float:
case CastObj::Half2Int32:
case CastObj::Half2Int64:
case CastObj::Half2Int16:
case CastObj::Half2Int8:
case CastObj::Half2Uint8:
case CastObj::Half2Bool:
case CastObj::Half2FloatInf:
//todo
break;
case CastObj::Int322Float:
case CastObj::Int322Half:
case CastObj::Int322Int8:
case CastObj::Int322Int16:
//todo
break;
case CastObj::Int162Float:
case CastObj::Int162Half:
case CastObj::Int162Int32:
//todo
break;
case CastObj::Int82Float:
case CastObj::Int82Half:
case CastObj::Int82Int16:
case CastObj::Int82Int32:
//todo
break;
case CastObj::Uint82Float:
case CastObj::Uint82Half:
case CastObj::Uint82Int32:
case CastObj::Uint82Int64:
//todo
break;
case CastObj::Bool2Float:
case CastObj::Bool2Half:
case CastObj::Bool2Int32:
//todo
break;
case CastObj::Int322Int64:
case CastObj::Int322Bool:
//todo
break;
case CastObj::Int642Int32:
case CastObj::Int642Uint32:
case CastObj::Int642Float:
case CastObj::Int642Half:
//todo
break;
case CastObj::Uint642Uint32:
case CastObj::Uint322Int64:
case CastObj::Uint322Uint64:
//todo
break;
case CastObj::Double2Float:
//todo
break;
}
cnnlStatus_t stat =
cnnlCastDataType(context->cnnlHandle(), aDesc, aData, NlCastType, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
// Destories in BANG does not require sync. But cnnl does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl,
"Cast_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -150,4 +150,33 @@ vector<int> TransformObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
/// @brief Constructs a Cast operator and creates/validates its output.
/// @param graph Owning graph; when non-null the output tensor is created
///        by checkValid with the dtype implied by the cast kind.
/// @param input Tensor to convert.
/// @param output May be nullptr (created by checkValid) or pre-existing.
/// @param type Which conversion to perform.
CastObj::CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type)
    : OperatorObj(OpType::Cast, {input}, {output}), castType(type) {
    // Derive the output element type from the requested cast instead of
    // hard-coding Int32 for every cast kind (the previous behavior).
    // DataType only declares Float32 / UInt32 / Int32 here, so destinations
    // it cannot express (Half, Int8, Bool, ...) fall back to Int32 — the
    // old default — until DataType grows; TODO revisit then.
    DataType dtype = DataType::Int32;
    switch (type) {
    case Half2Float:
    case Half2FloatInf:
    case Int322Float:
    case Int162Float:
    case Int82Float:
    case Uint82Float:
    case Bool2Float:
    case Int642Float:
    case Double2Float:
        dtype = DataType::Float32;
        break;
    case Int642Uint32:
    case Uint642Uint32:
        dtype = DataType::UInt32;
        break;
    default: // *2Int32 casts, plus destinations DataType cannot express yet
        dtype = DataType::Int32;
        break;
    }
    IT_ASSERT(checkValid(graph, dtype));
}
// A cast never alters the shape, only the element type: echo the single
// input's dims back as the single output's shape.
optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) const {
    return {{inputs[0]->getDims()}};
}
// Renders e.g. "Cast[<guid>](output=<guid>)" for logging/debugging.
std::string CastObj::toString() const {
    std::ostringstream oss;
    oss << OpRegistry::getOpName(type) << "[" << getGuid() << "]"
        << "("
        << "output=" << outputs[0]->getGuid() << ")";
    return oss.str();
}
/// @brief Key identifying this op's workload for perf-record lookup:
///        {OpType, CastType, output shape...}.
/// Fix: the vector previously encoded only the OpType member `type`, so
/// two Cast ops with different castType (e.g. Float2Half vs Float2Int32)
/// over the same shape produced identical workload keys.
vector<int> CastObj::getWorkloadVector() const {
    vector<int> ret{enum_to_underlying(type), enum_to_underlying(castType)};
    const Shape shape = outputs[0]->getDims();
    ret.insert(ret.end(), shape.begin(), shape.end());
    return ret;
}
/// @brief Shape-independent attributes of this op: {OpType, CastType}.
/// Fix: previously returned only the OpType member `type`, so all Cast
/// variants were indistinguishable by their attribute vector.
vector<int> CastObj::getOpAttrVector() const {
    return {enum_to_underlying(type), enum_to_underlying(castType)};
}
}; // namespace infini

View File

@ -0,0 +1,40 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
// Runs a Float32 -> Int32 cast of `shape`-sized data on the BANG device
// and dumps input/output for manual inspection. There is no CPU reference
// kernel yet, so no value comparison is performed.
template <class T>
void testCast(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
    // Runtimes: host CPU plus one BANG (MLU) device.
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    auto bangRuntime = make_ref<BangRuntimeObj>();

    // Build the float input on the host.
    Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
    inputCpu->dataMalloc();
    inputCpu->setData(generator);

    // Mirror the input on the device and run the cast there.
    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
    auto inputMlu = bangGraph->cloneTensor(inputCpu);
    auto castOp = bangGraph->addOp<T>(inputMlu, nullptr, CastObj::Float2Int32);
    auto outputMlu = castOp->getOutput();
    bangGraph->dataMalloc();
    bangRuntime->run(bangGraph);

    // Copy the device result back and print both tensors.
    auto outputMlu2Cpu = outputMlu->clone(cpuRuntime);
    inputCpu->printData();
    outputMlu2Cpu->printData();
    EXPECT_TRUE(1);
}
// Smoke test: cast a small 1x2x2x3 incremental Float32 tensor to Int32 on
// the BANG device; output is inspected via printData, not asserted.
TEST(cnnl_Cast, run) {
testCast<CastObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
}
} // namespace infini