forked from jiuyuan/InfiniTensor
add cast operation
This commit is contained in:
parent
5329e66d0f
commit
0079d1271b
|
@ -6,8 +6,9 @@ class DataType {
|
||||||
public:
|
public:
|
||||||
static const DataType Float32;
|
static const DataType Float32;
|
||||||
static const DataType UInt32;
|
static const DataType UInt32;
|
||||||
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)};
|
static const DataType Int32;
|
||||||
static constexpr std::string_view names[]{"Float32", "UInt32"};
|
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t), sizeof(int32_t)};
|
||||||
|
static constexpr std::string_view names[]{"Float32", "UInt32", "Int32"};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int index;
|
int index;
|
||||||
|
@ -29,9 +30,11 @@ class DataType {
|
||||||
|
|
||||||
inline const DataType DataType::Float32(0);
|
inline const DataType DataType::Float32(0);
|
||||||
inline const DataType DataType::UInt32(1);
|
inline const DataType DataType::UInt32(1);
|
||||||
|
inline const DataType DataType::Int32(2);
|
||||||
// Method definitions are out of the declaration due to GCC bug:
|
// Method definitions are out of the declaration due to GCC bug:
|
||||||
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
||||||
template <> inline DataType DataType::get<float>() { return Float32; }
|
template <> inline DataType DataType::get<float>() { return Float32; }
|
||||||
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
|
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
|
||||||
|
template <> inline DataType DataType::get<int32_t>() { return Int32; }
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -81,6 +81,7 @@ enum class OpType {
|
||||||
Transform,
|
Transform,
|
||||||
AddN,
|
AddN,
|
||||||
MulN,
|
MulN,
|
||||||
|
Cast,
|
||||||
//
|
//
|
||||||
MemBound = 300,
|
MemBound = 300,
|
||||||
};
|
};
|
||||||
|
@ -170,6 +171,7 @@ class OpRegistry {
|
||||||
FOP(Transform);
|
FOP(Transform);
|
||||||
FOP(AddN);
|
FOP(AddN);
|
||||||
FOP(MulN);
|
FOP(MulN);
|
||||||
|
FOP(Cast);
|
||||||
//
|
//
|
||||||
FOP(MemBound);
|
FOP(MemBound);
|
||||||
default:
|
default:
|
||||||
|
@ -251,6 +253,7 @@ class OperatorObj : public Object {
|
||||||
* function.
|
* function.
|
||||||
*/
|
*/
|
||||||
bool checkValid(GraphObj *graph);
|
bool checkValid(GraphObj *graph);
|
||||||
|
bool checkValid(GraphObj *graph, DataType type);
|
||||||
OpPerfKey getOpPerfKey() const;
|
OpPerfKey getOpPerfKey() const;
|
||||||
/**
|
/**
|
||||||
* @brief Hash operator attributes. Input and output shapes are not
|
* @brief Hash operator attributes. Input and output shapes are not
|
||||||
|
|
|
@ -72,6 +72,7 @@ class TensorObj : public TensorBaseObj {
|
||||||
private:
|
private:
|
||||||
void printDataFloat() const;
|
void printDataFloat() const;
|
||||||
void printDataUint32_t() const;
|
void printDataUint32_t() const;
|
||||||
|
void printDataInt32_t() const;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool equalDataImpl(const T *a, const T *b, size_t size) const {
|
bool equalDataImpl(const T *a, const T *b, size_t size) const {
|
||||||
|
|
|
@ -80,6 +80,34 @@ class TransformObj : public OperatorObj {
|
||||||
vector<int> getOpAttrVector() const override;
|
vector<int> getOpAttrVector() const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class CastObj : public OperatorObj {
|
||||||
|
public:
|
||||||
|
enum CastType { Float2Half = 0, Float2HalfIEEE754, Float2Double, Float2Int64, Float2Int32, Float2Int16, Float2Int8, Float2Bool,
|
||||||
|
Half2Float, Half2Int32, Half2Int64, Half2Int16, Half2Int8, Half2Uint8, Half2Bool, Half2FloatInf,
|
||||||
|
Int322Float, Int322Half, Int322Int8, Int322Int16,
|
||||||
|
Int162Float, Int162Half, Int162Int32,
|
||||||
|
Int82Float, Int82Half, Int82Int16, Int82Int32,
|
||||||
|
Uint82Float, Uint82Half, Uint82Int32, Uint82Int64,
|
||||||
|
Bool2Float, Bool2Half, Bool2Int32,
|
||||||
|
Int322Int64, Int322Bool,
|
||||||
|
Int642Int32, Int642Uint32, Int642Float, Int642Half,
|
||||||
|
Uint642Uint32,
|
||||||
|
Uint322Int64, Uint322Uint64,
|
||||||
|
Double2Float};
|
||||||
|
CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
CastType getType() const { return castType; }
|
||||||
|
int numInputs() const override { return 1; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
CastType castType;
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
|
||||||
#define DEFINE_UNARY_OBJ(prefix, type) \
|
#define DEFINE_UNARY_OBJ(prefix, type) \
|
||||||
class prefix##Obj : public UnaryObj { \
|
class prefix##Obj : public UnaryObj { \
|
||||||
public: \
|
public: \
|
||||||
|
|
|
@ -61,4 +61,4 @@ OpVec GraphObj::getComputeOps() const {
|
||||||
return opList;
|
return opList;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -82,6 +82,29 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Validate this operator and, when a graph is supplied, create its
 * output tensors with an explicit element type.
 *
 * Shapes are obtained from inferShape(). The number of inferred shapes must
 * match the number of output slots.
 *
 * @param graph If non-null, every output slot must still be empty and is
 * filled with a tensor added to @p graph. If null, the existing outputs are
 * only checked against the inferred shapes.
 * @param type Element type given to every output tensor created here. It is
 * not consulted when @p graph is null.
 * @return true if shape inference succeeded and the outputs are consistent
 * with it; false otherwise.
 */
bool OperatorObj::checkValid(GraphObj *graph, DataType type) {
    const auto optShapes = inferShape();
    if (!optShapes) // shape inference failed
        return false;

    const vector<Shape> &shapes = *optShapes;
    if (shapes.size() != outputs.size())
        return false;

    if (graph) { // if graph != nullptr, outputs should be created
        // All outputs share `type`, so it is passed directly instead of
        // building a numOutputs()-sized vector that is indexed by
        // outputs.size().
        for (size_t i = 0; i < outputs.size(); ++i) {
            IT_ASSERT(!outputs[i]);
            outputs[i] = graph->addTensor(shapes[i], type);
        }
    } else { // if outputs have been created, check their shapes
        for (size_t i = 0; i < shapes.size(); ++i) {
            // A missing output cannot match any shape; report it as invalid
            // rather than dereferencing a null tensor.
            if (!outputs[i] || shapes[i] != outputs[i]->getDims())
                return false;
        }
    }
    return true;
}
|
||||||
|
|
||||||
optional<vector<Shape>> OperatorObj::inferShape() const {
|
optional<vector<Shape>> OperatorObj::inferShape() const {
|
||||||
return inferShape(inputs);
|
return inferShape(inputs);
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,6 +69,8 @@ void TensorObj::printData() const {
|
||||||
printDataFloat();
|
printDataFloat();
|
||||||
else if (dtype == DataType::UInt32)
|
else if (dtype == DataType::UInt32)
|
||||||
printDataUint32_t();
|
printDataUint32_t();
|
||||||
|
else if (dtype == DataType::Int32)
|
||||||
|
printDataInt32_t();
|
||||||
else
|
else
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
}
|
}
|
||||||
|
@ -128,6 +130,34 @@ void TensorObj::printDataUint32_t() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TensorObj::printDataInt32_t() const {
|
||||||
|
IT_ASSERT(data != nullptr);
|
||||||
|
std::cout << "Tensor: " << guid << std::endl;
|
||||||
|
auto numDims = shape.size();
|
||||||
|
auto dimSzVec = std::vector<int>(numDims, 1);
|
||||||
|
auto ptr = data->getPtr<int32_t *>();
|
||||||
|
dimSzVec[numDims - 1] = shape[numDims - 1];
|
||||||
|
for (int i = numDims - 1; i != 0; --i)
|
||||||
|
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
|
||||||
|
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
|
||||||
|
for (size_t j = 0; j < numDims; ++j) {
|
||||||
|
if (i % dimSzVec[j] == 0) {
|
||||||
|
std::cout << "[";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::cout << ptr[i];
|
||||||
|
for (size_t j = 0; j < numDims; ++j) {
|
||||||
|
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
|
||||||
|
std::cout << "]";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (i != size() - 1)
|
||||||
|
std::cout << ", ";
|
||||||
|
if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - 1)
|
||||||
|
std::cout << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool TensorObj::equalData(const Tensor &rhs) const {
|
bool TensorObj::equalData(const Tensor &rhs) const {
|
||||||
IT_ASSERT(data != nullptr);
|
IT_ASSERT(data != nullptr);
|
||||||
IT_ASSERT(rhs->data != nullptr);
|
IT_ASSERT(rhs->data != nullptr);
|
||||||
|
@ -142,6 +172,9 @@ bool TensorObj::equalData(const Tensor &rhs) const {
|
||||||
else if (getDType() == DataType::Float32)
|
else if (getDType() == DataType::Float32)
|
||||||
return equalDataImpl(getRawDataPtr<float *>(),
|
return equalDataImpl(getRawDataPtr<float *>(),
|
||||||
rhs->getRawDataPtr<float *>(), size());
|
rhs->getRawDataPtr<float *>(), size());
|
||||||
|
else if (getDType() == DataType::Int32)
|
||||||
|
return equalDataImpl(getRawDataPtr<int32_t *>(),
|
||||||
|
rhs->getRawDataPtr<int32_t *>(), size());
|
||||||
else
|
else
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
}
|
}
|
||||||
|
@ -155,6 +188,8 @@ void TensorObj::dataMalloc() {
|
||||||
bytesPerElement = sizeof(float);
|
bytesPerElement = sizeof(float);
|
||||||
else if (getDType() == DataType::UInt32)
|
else if (getDType() == DataType::UInt32)
|
||||||
bytesPerElement = sizeof(uint32_t);
|
bytesPerElement = sizeof(uint32_t);
|
||||||
|
else if (getDType() == DataType::Int32)
|
||||||
|
bytesPerElement = sizeof(int32_t);
|
||||||
data = runtime->allocBlob(size() * bytesPerElement);
|
data = runtime->allocBlob(size() * bytesPerElement);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,116 @@
|
||||||
|
#include "bang/bang_kernel_without_config.h"
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "operators/unary.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
class CastCnnl : public BangKernelWithoutConfig {
|
||||||
|
void compute(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<CastObj>(_op);
|
||||||
|
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||||
|
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||||
|
|
||||||
|
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||||
|
auto dim = op->getInputs(0)->getDims();
|
||||||
|
if (dim.size() != 4)
|
||||||
|
IT_TODO_HALT();
|
||||||
|
|
||||||
|
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||||
|
// get inputs
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||||
|
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||||
|
cnnlCastDataType_t NlCastType;
|
||||||
|
CastObj::CastType type = op->getType();
|
||||||
|
switch(type){
|
||||||
|
case CastObj::Float2Half:
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_HALF, 4, dim_array));
|
||||||
|
NlCastType = CNNL_CAST_FLOAT_TO_HALF;
|
||||||
|
break;
|
||||||
|
case CastObj::Float2HalfIEEE754:
|
||||||
|
case CastObj::Float2Double:
|
||||||
|
case CastObj::Float2Int64:
|
||||||
|
case CastObj::Float2Int32:
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||||
|
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||||
|
NlCastType = CNNL_CAST_FLOAT_TO_INT32;
|
||||||
|
case CastObj::Float2Int16:
|
||||||
|
case CastObj::Float2Int8:
|
||||||
|
case CastObj::Float2Bool:
|
||||||
|
//Todo
|
||||||
|
break;
|
||||||
|
case CastObj::Half2Float:
|
||||||
|
case CastObj::Half2Int32:
|
||||||
|
case CastObj::Half2Int64:
|
||||||
|
case CastObj::Half2Int16:
|
||||||
|
case CastObj::Half2Int8:
|
||||||
|
case CastObj::Half2Uint8:
|
||||||
|
case CastObj::Half2Bool:
|
||||||
|
case CastObj::Half2FloatInf:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Int322Float:
|
||||||
|
case CastObj::Int322Half:
|
||||||
|
case CastObj::Int322Int8:
|
||||||
|
case CastObj::Int322Int16:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Int162Float:
|
||||||
|
case CastObj::Int162Half:
|
||||||
|
case CastObj::Int162Int32:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Int82Float:
|
||||||
|
case CastObj::Int82Half:
|
||||||
|
case CastObj::Int82Int16:
|
||||||
|
case CastObj::Int82Int32:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Uint82Float:
|
||||||
|
case CastObj::Uint82Half:
|
||||||
|
case CastObj::Uint82Int32:
|
||||||
|
case CastObj::Uint82Int64:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Bool2Float:
|
||||||
|
case CastObj::Bool2Half:
|
||||||
|
case CastObj::Bool2Int32:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Int322Int64:
|
||||||
|
case CastObj::Int322Bool:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Int642Int32:
|
||||||
|
case CastObj::Int642Uint32:
|
||||||
|
case CastObj::Int642Float:
|
||||||
|
case CastObj::Int642Half:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Uint642Uint32:
|
||||||
|
case CastObj::Uint322Int64:
|
||||||
|
case CastObj::Uint322Uint64:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
case CastObj::Double2Float:
|
||||||
|
//todo
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
cnnlStatus_t stat =
|
||||||
|
cnnlCastDataType(context->cnnlHandle(), aDesc, aData, NlCastType, cDesc, cData);
|
||||||
|
if (stat != CNNL_STATUS_SUCCESS)
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Destories in BANG does not require sync. But cnnl does not state
|
||||||
|
// whether sync is required before destories.
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||||
|
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl,
|
||||||
|
"Cast_cnnl_BANG_Float32");
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -150,4 +150,33 @@ vector<int> TransformObj::getOpAttrVector() const {
|
||||||
return {enum_to_underlying(type)};
|
return {enum_to_underlying(type)};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CastObj::CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type)
|
||||||
|
: OperatorObj(OpType::Cast, {input}, {output}), castType(type) {
|
||||||
|
IT_ASSERT(checkValid(graph, DataType::Int32));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * @brief Infer the output shape of the cast.
 *
 * @param inputs Input tensors; only the first one is read.
 * @return A single shape equal to the dimensions of inputs[0].
 */
optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) const {
    vector<Shape> outShapes;
    outShapes.emplace_back(inputs[0]->getDims());
    return outShapes;
}
|
||||||
|
|
||||||
|
/**
 * @brief Describe this operator as "<op name>[<guid>](output=<output guid>)".
 */
std::string CastObj::toString() const {
    std::ostringstream text;
    text << OpRegistry::getOpName(type) << "[" << getGuid() << "]"
         << "("
         << "output=" << outputs[0]->getGuid() << ")";
    return text.str();
}
|
||||||
|
|
||||||
|
/**
 * @brief Workload description of this operator.
 *
 * @return The operator type, the cast kind, then the dimensions of the first
 * output. The cast kind is included so that different conversions on tensors
 * of the same shape do not produce the same workload.
 */
vector<int> CastObj::getWorkloadVector() const {
    vector<int> ret{enum_to_underlying(type), static_cast<int>(castType)};
    const Shape shape = outputs[0]->getDims();
    ret.insert(ret.end(), shape.begin(), shape.end());
    return ret;
}
|
||||||
|
|
||||||
|
/**
 * @brief Attributes of this operator, excluding tensor shapes.
 *
 * @return The operator type followed by the cast kind. The cast kind is
 * included so that casts performing different conversions have different
 * attributes.
 */
vector<int> CastObj::getOpAttrVector() const {
    return {enum_to_underlying(type), static_cast<int>(castType)};
}
|
||||||
|
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
#include "bang/bang_runtime.h"
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "operators/unary.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
template <class T>
|
||||||
|
void testCast(const std::function<void(void *, size_t, DataType)> &generator,
|
||||||
|
const Shape &shape) {
|
||||||
|
// Runtime
|
||||||
|
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||||
|
|
||||||
|
// Build input data on CPU
|
||||||
|
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||||
|
inputCpu->dataMalloc();
|
||||||
|
inputCpu->setData(generator);
|
||||||
|
|
||||||
|
// GPU
|
||||||
|
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||||
|
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||||
|
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastObj::Float2Int32);
|
||||||
|
auto outputGpu = gpuOp->getOutput();
|
||||||
|
bangGraph->dataMalloc();
|
||||||
|
bangRuntime->run(bangGraph);
|
||||||
|
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||||
|
inputCpu->printData();
|
||||||
|
outputGpu2Cpu->printData();
|
||||||
|
EXPECT_TRUE(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(cnnl_Cast, run) {
|
||||||
|
testCast<CastObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue