forked from jiuyuan/InfiniTensor
add cast operation
This commit is contained in:
parent
5329e66d0f
commit
0079d1271b
|
@ -6,8 +6,9 @@ class DataType {
|
|||
public:
|
||||
static const DataType Float32;
|
||||
static const DataType UInt32;
|
||||
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t)};
|
||||
static constexpr std::string_view names[]{"Float32", "UInt32"};
|
||||
static const DataType Int32;
|
||||
static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t), sizeof(int32_t)};
|
||||
static constexpr std::string_view names[]{"Float32", "UInt32", "Int32"};
|
||||
|
||||
private:
|
||||
int index;
|
||||
|
@ -29,9 +30,11 @@ class DataType {
|
|||
|
||||
inline const DataType DataType::Float32(0);
|
||||
inline const DataType DataType::UInt32(1);
|
||||
inline const DataType DataType::Int32(2);
|
||||
// Method definitions are out of the declaration due to GCC bug:
|
||||
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
|
||||
template <> inline DataType DataType::get<float>() { return Float32; }
|
||||
template <> inline DataType DataType::get<uint32_t>() { return UInt32; }
|
||||
template <> inline DataType DataType::get<int32_t>() { return Int32; }
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
||||
|
|
|
@ -81,6 +81,7 @@ enum class OpType {
|
|||
Transform,
|
||||
AddN,
|
||||
MulN,
|
||||
Cast,
|
||||
//
|
||||
MemBound = 300,
|
||||
};
|
||||
|
@ -170,6 +171,7 @@ class OpRegistry {
|
|||
FOP(Transform);
|
||||
FOP(AddN);
|
||||
FOP(MulN);
|
||||
FOP(Cast);
|
||||
//
|
||||
FOP(MemBound);
|
||||
default:
|
||||
|
@ -251,6 +253,7 @@ class OperatorObj : public Object {
|
|||
* function.
|
||||
*/
|
||||
bool checkValid(GraphObj *graph);
|
||||
bool checkValid(GraphObj *graph, DataType type);
|
||||
OpPerfKey getOpPerfKey() const;
|
||||
/**
|
||||
* @brief Hash operator attributes. Input and output shapes are not
|
||||
|
|
|
@ -72,6 +72,7 @@ class TensorObj : public TensorBaseObj {
|
|||
private:
|
||||
void printDataFloat() const;
|
||||
void printDataUint32_t() const;
|
||||
void printDataInt32_t() const;
|
||||
|
||||
template <typename T>
|
||||
bool equalDataImpl(const T *a, const T *b, size_t size) const {
|
||||
|
|
|
@ -80,6 +80,34 @@ class TransformObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class CastObj : public OperatorObj {
|
||||
public:
|
||||
enum CastType { Float2Half = 0, Float2HalfIEEE754, Float2Double, Float2Int64, Float2Int32, Float2Int16, Float2Int8, Float2Bool,
|
||||
Half2Float, Half2Int32, Half2Int64, Half2Int16, Half2Int8, Half2Uint8, Half2Bool, Half2FloatInf,
|
||||
Int322Float, Int322Half, Int322Int8, Int322Int16,
|
||||
Int162Float, Int162Half, Int162Int32,
|
||||
Int82Float, Int82Half, Int82Int16, Int82Int32,
|
||||
Uint82Float, Uint82Half, Uint82Int32, Uint82Int64,
|
||||
Bool2Float, Bool2Half, Bool2Int32,
|
||||
Int322Int64, Int322Bool,
|
||||
Int642Int32, Int642Uint32, Int642Float, Int642Half,
|
||||
Uint642Uint32,
|
||||
Uint322Int64, Uint322Uint64,
|
||||
Double2Float};
|
||||
CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
||||
std::string toString() const override;
|
||||
CastType getType() const { return castType; }
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
private:
|
||||
CastType castType;
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
#define DEFINE_UNARY_OBJ(prefix, type) \
|
||||
class prefix##Obj : public UnaryObj { \
|
||||
public: \
|
||||
|
|
|
@ -61,4 +61,4 @@ OpVec GraphObj::getComputeOps() const {
|
|||
return opList;
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
||||
|
|
|
@ -82,6 +82,29 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool OperatorObj::checkValid(GraphObj *graph, DataType type) {
|
||||
auto optShapes = inferShape();
|
||||
if (!optShapes) // shape inference failed
|
||||
return false;
|
||||
|
||||
const vector<Shape> &shapes = *optShapes;
|
||||
if (shapes.size() != outputs.size())
|
||||
return false;
|
||||
if (graph) { // if graph != nullptr, outputs should be created
|
||||
auto dataTypes = vector(numOutputs(), type);;
|
||||
for (size_t i = 0; i < outputs.size(); i++) {
|
||||
IT_ASSERT(!outputs[i]);
|
||||
outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);
|
||||
}
|
||||
} else { // if outputs have been created, check their shapes
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
if (shapes[i] != outputs[i]->getDims())
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
optional<vector<Shape>> OperatorObj::inferShape() const {
|
||||
return inferShape(inputs);
|
||||
}
|
||||
|
|
|
@ -69,6 +69,8 @@ void TensorObj::printData() const {
|
|||
printDataFloat();
|
||||
else if (dtype == DataType::UInt32)
|
||||
printDataUint32_t();
|
||||
else if (dtype == DataType::Int32)
|
||||
printDataInt32_t();
|
||||
else
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
@ -128,6 +130,34 @@ void TensorObj::printDataUint32_t() const {
|
|||
}
|
||||
}
|
||||
|
||||
void TensorObj::printDataInt32_t() const {
|
||||
IT_ASSERT(data != nullptr);
|
||||
std::cout << "Tensor: " << guid << std::endl;
|
||||
auto numDims = shape.size();
|
||||
auto dimSzVec = std::vector<int>(numDims, 1);
|
||||
auto ptr = data->getPtr<int32_t *>();
|
||||
dimSzVec[numDims - 1] = shape[numDims - 1];
|
||||
for (int i = numDims - 1; i != 0; --i)
|
||||
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
|
||||
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
|
||||
for (size_t j = 0; j < numDims; ++j) {
|
||||
if (i % dimSzVec[j] == 0) {
|
||||
std::cout << "[";
|
||||
}
|
||||
}
|
||||
std::cout << ptr[i];
|
||||
for (size_t j = 0; j < numDims; ++j) {
|
||||
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
|
||||
std::cout << "]";
|
||||
}
|
||||
}
|
||||
if (i != size() - 1)
|
||||
std::cout << ", ";
|
||||
if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - 1)
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool TensorObj::equalData(const Tensor &rhs) const {
|
||||
IT_ASSERT(data != nullptr);
|
||||
IT_ASSERT(rhs->data != nullptr);
|
||||
|
@ -142,6 +172,9 @@ bool TensorObj::equalData(const Tensor &rhs) const {
|
|||
else if (getDType() == DataType::Float32)
|
||||
return equalDataImpl(getRawDataPtr<float *>(),
|
||||
rhs->getRawDataPtr<float *>(), size());
|
||||
else if (getDType() == DataType::Int32)
|
||||
return equalDataImpl(getRawDataPtr<int32_t *>(),
|
||||
rhs->getRawDataPtr<int32_t *>(), size());
|
||||
else
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
@ -155,6 +188,8 @@ void TensorObj::dataMalloc() {
|
|||
bytesPerElement = sizeof(float);
|
||||
else if (getDType() == DataType::UInt32)
|
||||
bytesPerElement = sizeof(uint32_t);
|
||||
else if (getDType() == DataType::Int32)
|
||||
bytesPerElement = sizeof(int32_t);
|
||||
data = runtime->allocBlob(size() * bytesPerElement);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
#include "bang/bang_kernel_without_config.h"
|
||||
#include "bang/bang_runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
class CastCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<CastObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cnnlTensorDescriptor_t aDesc, cDesc;
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
// get inputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
cnnlCastDataType_t NlCastType;
|
||||
CastObj::CastType type = op->getType();
|
||||
switch(type){
|
||||
case CastObj::Float2Half:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_HALF, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_FLOAT_TO_HALF;
|
||||
break;
|
||||
case CastObj::Float2HalfIEEE754:
|
||||
case CastObj::Float2Double:
|
||||
case CastObj::Float2Int64:
|
||||
case CastObj::Float2Int32:
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
|
||||
NlCastType = CNNL_CAST_FLOAT_TO_INT32;
|
||||
case CastObj::Float2Int16:
|
||||
case CastObj::Float2Int8:
|
||||
case CastObj::Float2Bool:
|
||||
//Todo
|
||||
break;
|
||||
case CastObj::Half2Float:
|
||||
case CastObj::Half2Int32:
|
||||
case CastObj::Half2Int64:
|
||||
case CastObj::Half2Int16:
|
||||
case CastObj::Half2Int8:
|
||||
case CastObj::Half2Uint8:
|
||||
case CastObj::Half2Bool:
|
||||
case CastObj::Half2FloatInf:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Int322Float:
|
||||
case CastObj::Int322Half:
|
||||
case CastObj::Int322Int8:
|
||||
case CastObj::Int322Int16:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Int162Float:
|
||||
case CastObj::Int162Half:
|
||||
case CastObj::Int162Int32:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Int82Float:
|
||||
case CastObj::Int82Half:
|
||||
case CastObj::Int82Int16:
|
||||
case CastObj::Int82Int32:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Uint82Float:
|
||||
case CastObj::Uint82Half:
|
||||
case CastObj::Uint82Int32:
|
||||
case CastObj::Uint82Int64:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Bool2Float:
|
||||
case CastObj::Bool2Half:
|
||||
case CastObj::Bool2Int32:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Int322Int64:
|
||||
case CastObj::Int322Bool:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Int642Int32:
|
||||
case CastObj::Int642Uint32:
|
||||
case CastObj::Int642Float:
|
||||
case CastObj::Int642Half:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Uint642Uint32:
|
||||
case CastObj::Uint322Int64:
|
||||
case CastObj::Uint322Uint64:
|
||||
//todo
|
||||
break;
|
||||
case CastObj::Double2Float:
|
||||
//todo
|
||||
break;
|
||||
}
|
||||
cnnlStatus_t stat =
|
||||
cnnlCastDataType(context->cnnlHandle(), aDesc, aData, NlCastType, cDesc, cData);
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
// Destories in BANG does not require sync. But cnnl does not state
|
||||
// whether sync is required before destories.
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::BANG, OpType::Cast, DataType::Float32, CastCnnl,
|
||||
"Cast_cnnl_BANG_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -150,4 +150,33 @@ vector<int> TransformObj::getOpAttrVector() const {
|
|||
return {enum_to_underlying(type)};
|
||||
}
|
||||
|
||||
CastObj::CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type)
|
||||
: OperatorObj(OpType::Cast, {input}, {output}), castType(type) {
|
||||
IT_ASSERT(checkValid(graph, DataType::Int32));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> CastObj::inferShape(const TensorVec &inputs) const {
|
||||
const auto A = inputs[0];
|
||||
return {{A->getDims()}};
|
||||
}
|
||||
|
||||
std::string CastObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> CastObj::getWorkloadVector() const {
|
||||
vector<int> ret{enum_to_underlying(type)};
|
||||
const Shape shape = outputs[0]->getDims();
|
||||
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> CastObj::getOpAttrVector() const {
|
||||
return {enum_to_underlying(type)};
|
||||
}
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
#include "bang/bang_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testCast(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||
auto inputGpu = bangGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, CastObj::Float2Int32);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
bangGraph->dataMalloc();
|
||||
bangRuntime->run(bangGraph);
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
inputCpu->printData();
|
||||
outputGpu2Cpu->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
TEST(cnnl_Cast, run) {
|
||||
testCast<CastObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue