Commit 2b8bca17e2 (parent 156a40806d)
wanghailu, 2023-01-05 14:23:46 +08:00
60 changed files with 517 additions and 425 deletions
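The diff below is a formatting-only pass: long declarations and call expressions are re-wrapped to an 80-column limit, "for(" / "if(" / "switch(" gain a space before the parenthesis, "//todo" comments become "// todo", and several BANG kernel files move their operator header above the runtime headers. This is consistent with running clang-format over the tree. The actual configuration is not part of this commit; a minimal sketch of options that would produce this layout (assumed for illustration, not taken from the repository) is:

    # hypothetical .clang-format sketch, not part of this commit
    BasedOnStyle: LLVM        # brace style and call-argument alignment seen below
    IndentWidth: 4            # 4-space blocks, as in the reformatted code
    ColumnLimit: 80           # drives the re-wrapped declarations
    AccessModifierOffset: -2  # "  public:" indented two columns inside classes

Applying such a pass is typically done with "clang-format -i" over the changed files.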

View File

@@ -7,7 +7,8 @@ class DataType {
     static const DataType Float32;
     static const DataType UInt32;
     static const DataType Int32;
-    static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t), sizeof(int32_t)};
+    static constexpr size_t sizePerElement[]{sizeof(float), sizeof(uint32_t),
+                                             sizeof(int32_t)};
     static constexpr std::string_view names[]{"Float32", "UInt32", "Int32"};
   private:

View File

@@ -4,7 +4,8 @@
 namespace infini {
 class ActivationBackwardObj : public OperatorObj {
   public:
-    ActivationBackwardObj(OpType type, GraphObj *graph, Tensor y, Tensor diff_y, Tensor x, Tensor diff_x);
+    ActivationBackwardObj(OpType type, GraphObj *graph, Tensor y, Tensor diff_y,
+                          Tensor x, Tensor diff_x);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     std::string toString() const override;
@@ -16,11 +17,12 @@ class ActivationBackwardObj : public OperatorObj {
     vector<int> getOpAttrVector() const override;
 };
 #define DEFINE_ACTIVATION_BACKWARD_OBJ(prefix, type) \
     class prefix##Obj : public ActivationBackwardObj { \
       public: \
-        prefix##Obj(GraphObj *graph, Tensor y, Tensor diff_y, Tensor x, Tensor diff_x) \
-            : ActivationBackwardObj(type, graph, y, diff_y, x, diff_x) {} \
+        prefix##Obj(GraphObj *graph, Tensor y, Tensor diff_y, Tensor x, \
+                    Tensor diff_x) \
+            : ActivationBackwardObj(type, graph, y, diff_y, x, diff_x) {} \
     };
 DEFINE_ACTIVATION_BACKWARD_OBJ(ReluBackward, OpType::ReluBackward)

View File

@@ -91,14 +91,15 @@ class ConvBackwardFilterObj : public ConvBaseObj {
     ActType act;
   public:
-    ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Tensor diffY, Tensor diffW, int ph,
-                          int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
-                          Tensor bias = nullptr, ActType act = ActType::None);
+    ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Tensor diffY,
+                          Tensor diffW, int ph, int pw, int sh = 1, int sw = 1,
+                          int dh = 1, int dw = 1, Tensor bias = nullptr,
+                          ActType act = ActType::None);
     // Constructors for setting padding mode
-    ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Tensor diffY, Tensor diffW,
-                          PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
-                          int dh = 1, int dw = 1, Tensor bias = nullptr,
-                          ActType act = ActType::None);
+    ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Tensor diffY,
+                          Tensor diffW, PaddingMode mode = PaddingMode::Same,
+                          int sh = 1, int sw = 1, int dh = 1, int dw = 1,
+                          Tensor bias = nullptr, ActType act = ActType::None);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     ActType getAct() const { return act; }

View File

@@ -20,7 +20,8 @@ class ElementWiseObj : public OperatorObj {
 class MSELossObj : public OperatorObj {
   public:
     enum Reduction { None = 0, Sum, Mean };
-    MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, Reduction reduction, Tensor output);
+    MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
+               Reduction reduction, Tensor output);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     Reduction getReduction() const { return reductionMode; }

View File

@@ -28,7 +28,7 @@ class ClipObj : public OperatorObj {
     int numOutputs() const override { return 1; }
   private:
-    float minValue,maxValue;
+    float minValue, maxValue;
     vector<int> getWorkloadVector() const override;
     vector<int> getOpAttrVector() const override;
 };
@@ -65,7 +65,8 @@ class L2LossObj : public OperatorObj {
 class TransformObj : public OperatorObj {
   public:
-    TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha, float beta);
+    TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha,
+                 float beta);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     std::string toString() const override;
@@ -82,18 +83,52 @@ class TransformObj : public OperatorObj {
 class CastObj : public OperatorObj {
   public:
-    enum CastType { Float2Half = 0, Float2HalfIEEE754, Float2Double, Float2Int64, Float2Int32, Float2Int16, Float2Int8, Float2Bool,
-                    Half2Float, Half2Int32, Half2Int64, Half2Int16, Half2Int8, Half2Uint8, Half2Bool, Half2FloatInf,
-                    Int322Float, Int322Half, Int322Int8, Int322Int16,
-                    Int162Float, Int162Half, Int162Int32,
-                    Int82Float, Int82Half, Int82Int16, Int82Int32,
-                    Uint82Float, Uint82Half, Uint82Int32, Uint82Int64,
-                    Bool2Float, Bool2Half, Bool2Int32,
-                    Int322Int64, Int322Bool,
-                    Int642Int32, Int642Uint32, Int642Float, Int642Half,
-                    Uint642Uint32,
-                    Uint322Int64, Uint322Uint64,
-                    Double2Float};
+    enum CastType {
+        Float2Half = 0,
+        Float2HalfIEEE754,
+        Float2Double,
+        Float2Int64,
+        Float2Int32,
+        Float2Int16,
+        Float2Int8,
+        Float2Bool,
+        Half2Float,
+        Half2Int32,
+        Half2Int64,
+        Half2Int16,
+        Half2Int8,
+        Half2Uint8,
+        Half2Bool,
+        Half2FloatInf,
+        Int322Float,
+        Int322Half,
+        Int322Int8,
+        Int322Int16,
+        Int162Float,
+        Int162Half,
+        Int162Int32,
+        Int82Float,
+        Int82Half,
+        Int82Int16,
+        Int82Int32,
+        Uint82Float,
+        Uint82Half,
+        Uint82Int32,
+        Uint82Int64,
+        Bool2Float,
+        Bool2Half,
+        Bool2Int32,
+        Int322Int64,
+        Int322Bool,
+        Int642Int32,
+        Int642Uint32,
+        Int642Float,
+        Int642Half,
+        Uint642Uint32,
+        Uint322Int64,
+        Uint322Uint64,
+        Double2Float
+    };
     CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
@@ -110,7 +145,8 @@ class CastObj : public OperatorObj {
 class CumsumObj : public OperatorObj {
   public:
-    CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis, bool exclusive, bool reverse);
+    CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis,
+              bool exclusive, bool reverse);
     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
     std::string toString() const override;
@@ -129,8 +165,9 @@ class CumsumObj : public OperatorObj {
 // class CumprodObj : public OperatorObj {
 //  public:
-//     CumprodObj(GraphObj *graph, Tensor input, Tensor output, int axis, bool exclusive, bool reverse);
-//     optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
+//     CumprodObj(GraphObj *graph, Tensor input, Tensor output, int axis, bool
+//     exclusive, bool reverse); optional<vector<Shape>> inferShape(const
+//     TensorVec &inputs) const override;
 //
 //     std::string toString() const override;
 //     int getAxis() const { return axisValue; }

View File

@@ -10,9 +10,7 @@ OperatorObj::OperatorObj(OpType opType, TensorVec inputs, TensorVec outputs)
         IT_ASSERT(t != nullptr);
 }
-OperatorObj::OperatorObj(OpType opType)
-    : type(opType){
-}
+OperatorObj::OperatorObj(OpType opType) : type(opType) {}
 bool OperatorObj::isLinearOp() const {
     return enum_to_underlying(type) >= 100 && enum_to_underlying(type) < 200;
@@ -91,7 +89,8 @@ bool OperatorObj::checkValid(GraphObj *graph, DataType type) {
     if (shapes.size() != outputs.size())
         return false;
     if (graph) { // if graph != nullptr, outputs should be created
-        auto dataTypes = vector(numOutputs(), type);;
+        auto dataTypes = vector(numOutputs(), type);
+        ;
         for (size_t i = 0; i < outputs.size(); i++) {
             IT_ASSERT(!outputs[i]);
             outputs[i] = graph->addTensor(shapes[i], dataTypes[i]);

View File

@@ -1,6 +1,6 @@
+#include "operators/activation_backward.h"
 #include "bang/bang_kernel_without_config.h"
 #include "bang/bang_runtime.h"
-#include "operators/activation_backward.h"
 namespace infini {
 class ActivationBackwardCnnl : public BangKernelWithoutConfig {
@@ -47,11 +47,9 @@ class ActivationBackwardCnnl : public BangKernelWithoutConfig {
             opDesc, getOpType(), CNNL_NOT_PROPAGATE_NAN, getCoef()));
         auto [alpha, beta] = getAlphBeta();
-        cnnlStatus_t stat =
-            cnnlActivationBackward(context->cnnlHandle(), opDesc, &alpha, yDesc, yData,
-                                   diffYDesc, diffYData,
-                                   xDesc, xData,
-                                   &beta, diffXDesc, diffXData);
+        cnnlStatus_t stat = cnnlActivationBackward(
+            context->cnnlHandle(), opDesc, &alpha, yDesc, yData, diffYDesc,
+            diffYData, xDesc, xData, &beta, diffXDesc, diffXData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -86,11 +84,11 @@ class TanhBackwardCnnl : public ActivationBackwardCnnl {
     float getCoef() const override { return 0.0; }
 };
-REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, DataType::Float32, ReluBackwardCnnl,
-                "ReluBackward_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, DataType::Float32, SigmoidBackwardCnnl,
-                "SigmoidBackward_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, DataType::Float32, TanhBackwardCnnl,
-                "TanhBackward_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::ReluBackward, DataType::Float32,
+                ReluBackwardCnnl, "ReluBackward_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::SigmoidBackward, DataType::Float32,
+                SigmoidBackwardCnnl, "SigmoidBackward_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::TanhBackward, DataType::Float32,
+                TanhBackwardCnnl, "TanhBackward_cnnl_BANG_Float32");
 }; // namespace infini

View File

@@ -1,6 +1,6 @@
+#include "operators/element_wise.h"
 #include "bang/bang_kernel_without_config.h"
 #include "bang/bang_runtime.h"
-#include "operators/element_wise.h"
 namespace infini {
 class AddNCnnl : public BangKernelWithoutConfig {
@@ -10,8 +10,8 @@ class AddNCnnl : public BangKernelWithoutConfig {
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
         int num = op->numInputs();
         void *argv[num];
-        for(int i = 0; i < num; ++i) {
+        for (int i = 0; i < num; ++i) {
             argv[i] = op->getInputs(i)->getRawDataPtr<void *>();
         }
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@@ -25,19 +25,21 @@ class AddNCnnl : public BangKernelWithoutConfig {
         checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW,
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         cnnlTensorDescriptor_t descArray[num];
-        for(int i = 0; i < num; ++i) {
+        for (int i = 0; i < num; ++i) {
             checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
-            checkCnnlError(cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_FLOAT, 4, dim_array));
+            checkCnnlError(
+                cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
+                                        CNNL_DTYPE_FLOAT, 4, dim_array));
         }
-        cnnlStatus_t stat = cnnlAddN(context->cnnlHandle(), descArray, argv, num, desc, cData);
+        cnnlStatus_t stat =
+            cnnlAddN(context->cnnlHandle(), descArray, argv, num, desc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
         // Destories in BANG does not require sync. But cnnl does not state
         // whether sync is required before destories.
-        for(int i = 0; i < num; ++i) {
+        for (int i = 0; i < num; ++i) {
             checkCnnlError(cnnlDestroyTensorDescriptor(descArray[i]));
         }
         checkCnnlError(cnnlDestroyTensorDescriptor(desc));

View File

@@ -23,83 +23,87 @@ class CastCnnl : public BangKernelWithoutConfig {
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
         cnnlCastDataType_t NlCastType;
         CastObj::CastType type = op->getType();
-        switch(type){
+        switch (type) {
         case CastObj::Float2Half:
-            checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
-            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_HALF, 4, dim_array));
+            checkCnnlError(cnnlSetTensorDescriptor(
+                aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
+            checkCnnlError(cnnlSetTensorDescriptor(
+                cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_HALF, 4, dim_array));
             NlCastType = CNNL_CAST_FLOAT_TO_HALF;
             break;
         case CastObj::Float2HalfIEEE754:
         case CastObj::Float2Double:
        case CastObj::Float2Int64:
        case CastObj::Float2Int32:
-            checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
-            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
+            checkCnnlError(cnnlSetTensorDescriptor(
+                aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
+            checkCnnlError(cnnlSetTensorDescriptor(
+                cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_INT32, 4, dim_array));
             NlCastType = CNNL_CAST_FLOAT_TO_INT32;
         case CastObj::Float2Int16:
         case CastObj::Float2Int8:
         case CastObj::Float2Bool:
-            //Todo
+            // Todo
             break;
         case CastObj::Half2Float:
         case CastObj::Half2Int32:
         case CastObj::Half2Int64:
         case CastObj::Half2Int16:
         case CastObj::Half2Int8:
         case CastObj::Half2Uint8:
         case CastObj::Half2Bool:
         case CastObj::Half2FloatInf:
-            //todo
+            // todo
             break;
         case CastObj::Int322Float:
         case CastObj::Int322Half:
         case CastObj::Int322Int8:
         case CastObj::Int322Int16:
-            //todo
+            // todo
             break;
         case CastObj::Int162Float:
         case CastObj::Int162Half:
         case CastObj::Int162Int32:
-            //todo
+            // todo
             break;
         case CastObj::Int82Float:
         case CastObj::Int82Half:
         case CastObj::Int82Int16:
         case CastObj::Int82Int32:
-            //todo
+            // todo
             break;
         case CastObj::Uint82Float:
         case CastObj::Uint82Half:
         case CastObj::Uint82Int32:
         case CastObj::Uint82Int64:
-            //todo
+            // todo
             break;
         case CastObj::Bool2Float:
         case CastObj::Bool2Half:
         case CastObj::Bool2Int32:
-            //todo
+            // todo
             break;
         case CastObj::Int322Int64:
         case CastObj::Int322Bool:
-            //todo
+            // todo
             break;
         case CastObj::Int642Int32:
         case CastObj::Int642Uint32:
         case CastObj::Int642Float:
         case CastObj::Int642Half:
-            //todo
+            // todo
             break;
         case CastObj::Uint642Uint32:
         case CastObj::Uint322Int64:
         case CastObj::Uint322Uint64:
-            //todo
+            // todo
             break;
         case CastObj::Double2Float:
-            //todo
+            // todo
             break;
         }
-        cnnlStatus_t stat =
-            cnnlCastDataType(context->cnnlHandle(), aDesc, aData, NlCastType, cDesc, cData);
+        cnnlStatus_t stat = cnnlCastDataType(context->cnnlHandle(), aDesc,
+                                             aData, NlCastType, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -1,6 +1,6 @@
+#include "operators/conv.h"
 #include "bang/bang_kernel_without_config.h"
 #include "bang/bang_runtime.h"
-#include "operators/conv.h"
 namespace infini {
 class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
@@ -27,7 +27,8 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
         void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
-        cnnlTensorDescriptor_t aDesc, bDesc, cDesc, aDescTrans, bDescTrans, cDescTrans;
+        cnnlTensorDescriptor_t aDesc, bDesc, cDesc, aDescTrans, bDescTrans,
+            cDescTrans;
         auto dimInputs0 = op->getInputs(0)->getDims();
         auto dimInputs1 = op->getInputs(1)->getDims();
         auto dimOutput = op->getOutput()->getDims();
@@ -47,17 +48,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
                                 dimOutput[3]};
         int inputs0ArrayTrans[4] = {dimInputs0[0], dimInputs0[2], dimInputs0[3],
                                     dimInputs0[1]};
         int inputs1ArrayTrans[4] = {dimInputs1[0], dimInputs1[2], dimInputs1[3],
                                     dimInputs1[1]};
         int outputArrayTrans[4] = {dimOutput[0], dimOutput[2], dimOutput[3],
                                    dimOutput[1]};
         int transMode[4] = {0, 2, 3, 1};
         cnnlTransposeDescriptor_t transDesc;
         checkCnnlError(cnnlCreateTransposeDescriptor(&transDesc));
-        checkCnnlError(cnnlSetTransposeDescriptor(
-            transDesc, 4, transMode));
+        checkCnnlError(cnnlSetTransposeDescriptor(transDesc, 4, transMode));
         // get inputs
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
@@ -65,13 +65,17 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
             aDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs0Array));
         checkCnnlError(cnnlCreateTensorDescriptor(&aDescTrans));
-        checkCnnlError(cnnlSetTensorDescriptor(
-            aDescTrans, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs0ArrayTrans));
+        checkCnnlError(cnnlSetTensorDescriptor(aDescTrans, CNNL_LAYOUT_NHWC,
+                                               CNNL_DTYPE_FLOAT, 4,
+                                               inputs0ArrayTrans));
-        size_t wsTrans1Size = dimInputs0[0] * dimInputs0[1] * dimInputs0[2] * dimInputs0[3] * sizeof(float);
+        size_t wsTrans1Size = dimInputs0[0] * dimInputs0[1] * dimInputs0[2] *
+                              dimInputs0[3] * sizeof(float);
         BangPtr wsTrans1Data = context->getWorkspace(wsTrans1Size);
-        cnnlStatus_t stat = cnnlTranspose(context->cnnlHandle(), transDesc, aDesc, aData, aDescTrans, wsTrans1Data);
+        cnnlStatus_t stat =
+            cnnlTranspose(context->cnnlHandle(), transDesc, aDesc, aData,
+                          aDescTrans, wsTrans1Data);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -80,13 +84,16 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
             bDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, inputs1Array));
         checkCnnlError(cnnlCreateTensorDescriptor(&bDescTrans));
-        checkCnnlError(cnnlSetTensorDescriptor(
-            bDescTrans, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, inputs1ArrayTrans));
+        checkCnnlError(cnnlSetTensorDescriptor(bDescTrans, CNNL_LAYOUT_NHWC,
+                                               CNNL_DTYPE_FLOAT, 4,
+                                               inputs1ArrayTrans));
-        size_t wsTrans2Size = dimInputs1[0] * dimInputs1[1] * dimInputs1[2] * dimInputs1[3] * sizeof(float);
+        size_t wsTrans2Size = dimInputs1[0] * dimInputs1[1] * dimInputs1[2] *
+                              dimInputs1[3] * sizeof(float);
         BangPtr wsTrans2Data = context->getWorkspace(wsTrans2Size);
-        stat = cnnlTranspose(context->cnnlHandle(), transDesc, bDesc, bData, bDescTrans, wsTrans2Data);
+        stat = cnnlTranspose(context->cnnlHandle(), transDesc, bDesc, bData,
+                             bDescTrans, wsTrans2Data);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -96,36 +103,40 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
             cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, outputArray));
         checkCnnlError(cnnlCreateTensorDescriptor(&cDescTrans));
-        checkCnnlError(cnnlSetTensorDescriptor(
-            cDescTrans, CNNL_LAYOUT_NHWC, CNNL_DTYPE_FLOAT, 4, outputArrayTrans));
+        checkCnnlError(cnnlSetTensorDescriptor(cDescTrans, CNNL_LAYOUT_NHWC,
+                                               CNNL_DTYPE_FLOAT, 4,
+                                               outputArrayTrans));
-        size_t wsTrans3Size = dimOutput[0] * dimOutput[1] * dimOutput[2] * dimOutput[3] * sizeof(float);
+        size_t wsTrans3Size = dimOutput[0] * dimOutput[1] * dimOutput[2] *
+                              dimOutput[3] * sizeof(float);
         BangPtr wsTrans3Data = context->getWorkspace(wsTrans3Size);
         cnnlConvolutionBwdFilterAlgo_t algo;
-        cnnlGetConvolutionBackwardFilterAlgorithm(context->cnnlHandle(), convDesc,
-                                                  aDescTrans, bDescTrans, cDescTrans,
-                                                  CNNL_CONVOLUTION_BWD_FILTER_FASTEST, &algo);
+        cnnlGetConvolutionBackwardFilterAlgorithm(
+            context->cnnlHandle(), convDesc, aDescTrans, bDescTrans, cDescTrans,
+            CNNL_CONVOLUTION_BWD_FILTER_FASTEST, &algo);
         size_t wsSize;
-        cnnlGetConvolutionBackwardFilterWorkspaceSize(context->cnnlHandle(), aDescTrans,
-                                                      bDescTrans, cDescTrans, convDesc,
-                                                      algo, &wsSize);
+        cnnlGetConvolutionBackwardFilterWorkspaceSize(
+            context->cnnlHandle(), aDescTrans, bDescTrans, cDescTrans, convDesc,
+            algo, &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
         stat = cnnlConvolutionBackwardFilter(
             context->cnnlHandle(), NULL, aDescTrans, wsTrans1Data, bDescTrans,
-            wsTrans2Data, convDesc, algo, wsData, wsSize, NULL, cDescTrans, wsTrans3Data);
+            wsTrans2Data, convDesc, algo, wsData, wsSize, NULL, cDescTrans,
+            wsTrans3Data);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
         int transMode2[4] = {0, 3, 1, 2};
         cnnlTransposeDescriptor_t transOutputDesc;
         checkCnnlError(cnnlCreateTransposeDescriptor(&transOutputDesc));
-        checkCnnlError(cnnlSetTransposeDescriptor(
-            transOutputDesc, 4, transMode2));
+        checkCnnlError(
+            cnnlSetTransposeDescriptor(transOutputDesc, 4, transMode2));
-        stat = cnnlTranspose(context->cnnlHandle(), transOutputDesc, cDescTrans, wsTrans3Data, cDesc, cData);
+        stat = cnnlTranspose(context->cnnlHandle(), transOutputDesc, cDescTrans,
+                             wsTrans3Data, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -143,6 +154,6 @@ class ConvBackwardFilterCnnl : public BangKernelWithoutConfig {
     }
 };
-REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, DataType::Float32, ConvBackwardFilterCnnl,
-                "ConvBackwardFilter_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::ConvBackwardFilter, DataType::Float32,
+                ConvBackwardFilterCnnl, "ConvBackwardFilter_cnnl_BANG_Float32");
 }; // namespace infini

View File

@@ -32,7 +32,8 @@ class CumsumCnnl : public BangKernelWithoutConfig {
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         cnnlStatus_t stat =
-            cnnlCumsum(context->cnnlHandle(), aDesc, aData, axis, exclusive, reverse, CNNL_NOT_PROPAGATE_NAN, cDesc, cData);
+            cnnlCumsum(context->cnnlHandle(), aDesc, aData, axis, exclusive,
+                       reverse, CNNL_NOT_PROPAGATE_NAN, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -1,6 +1,6 @@
+#include "operators/det.h"
 #include "bang/bang_kernel_without_config.h"
 #include "bang/bang_runtime.h"
-#include "operators/det.h"
 namespace infini {
 class DetCnnl : public BangKernelWithoutConfig {
@@ -13,7 +13,7 @@ class DetCnnl : public BangKernelWithoutConfig {
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
         DetObj::Mode mode = op->getMode();
         cnnlDetMode_t nlMode;
-        if(mode == DetObj::LogDet) {
+        if (mode == DetObj::LogDet) {
             nlMode = CNNL_DET_MODE_LOGDET;
         } else {
             nlMode = CNNL_DET_MODE_DET;
@@ -28,15 +28,16 @@ class DetCnnl : public BangKernelWithoutConfig {
         int dimout_array[2] = {dimout[0], dimout[1]};
         // get inputs
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
-                                               CNNL_DTYPE_FLOAT, 4, dimin_array));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, 4, dimin_array));
         // get outputs
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
-                                               CNNL_DTYPE_FLOAT, 2, dimout_array));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, 2, dimout_array));
-        cnnlStatus_t stat = cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData);
+        cnnlStatus_t stat =
+            cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -98,13 +98,13 @@ class DivCnnl : public BangKernelWithoutConfig {
         size_t wsSize;
         cnnlGetDivWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
                                 &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlDiv_v2(context->cnnlHandle(),
-                                       CNNL_COMPUTATION_HIGH_PRECISION,
-                                       aDesc, aData, bDesc, bData, wsData, wsSize, cDesc, cData);
+        cnnlStatus_t stat = cnnlDiv_v2(
+            context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc,
+            aData, bDesc, bData, wsData, wsSize, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -152,9 +152,9 @@ class DivNoNanCnnl : public BangKernelWithoutConfig {
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlDivNoNan_v2(context->cnnlHandle(),
-                                            CNNL_COMPUTATION_HIGH_PRECISION,
-                                            aDesc, aData, bDesc, bData, wsData, wsSize, cDesc, cData);
+        cnnlStatus_t stat = cnnlDivNoNan_v2(
+            context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc,
+            aData, bDesc, bData, wsData, wsSize, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -201,8 +201,9 @@ class MaximumCnnl : public BangKernelWithoutConfig {
         cnnlGetMaximumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlMaximum(context->cnnlHandle(), aDesc, aData, bDesc, bData,
-                                        cDesc, cData, wsData, wsSize);
+        cnnlStatus_t stat =
+            cnnlMaximum(context->cnnlHandle(), aDesc, aData, bDesc, bData,
+                        cDesc, cData, wsData, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -249,8 +250,9 @@ class MinimumCnnl : public BangKernelWithoutConfig {
         cnnlGetMinimumWorkspaceSize(context->cnnlHandle(), cDesc, &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlMinimum(context->cnnlHandle(), aDesc, aData, bDesc, bData,
-                                        cDesc, cData, wsData, wsSize);
+        cnnlStatus_t stat =
+            cnnlMinimum(context->cnnlHandle(), aDesc, aData, bDesc, bData,
+                        cDesc, cData, wsData, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -278,7 +280,7 @@ class MSELossCnnl : public BangKernelWithoutConfig {
             IT_TODO_HALT();
         int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
-        int dim_out[4] ={1,1,1,1};
+        int dim_out[4] = {1, 1, 1, 1};
         // get inputs
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
         checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
@@ -290,23 +292,23 @@ class MSELossCnnl : public BangKernelWithoutConfig {
         // get outputs
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
-        if ( reduction == MSELossObj::None ) {
-            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_FLOAT, 4, dim_array));
+        if (reduction == MSELossObj::None) {
+            checkCnnlError(cnnlSetTensorDescriptor(
+                cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_array));
         } else {
-            checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_FLOAT, 4, dim_out));
+            checkCnnlError(cnnlSetTensorDescriptor(
+                cDesc, CNNL_LAYOUT_NCHW, CNNL_DTYPE_FLOAT, 4, dim_out));
         }
         cnnlStatus_t stat;
-        if( reduction == MSELossObj::None ) {
-            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc, aData, bDesc, bData,
-                               cDesc, cData);
+        if (reduction == MSELossObj::None) {
+            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc,
+                               aData, bDesc, bData, cDesc, cData);
         } else if (reduction == MSELossObj::Sum) {
-            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_SUM, aDesc, aData, bDesc, bData,
-                               cDesc, cData);
+            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_SUM, aDesc,
+                               aData, bDesc, bData, cDesc, cData);
         } else {
-            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_MEAN, aDesc, aData, bDesc, bData,
-                               cDesc, cData);
+            stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_MEAN, aDesc,
+                               aData, bDesc, bData, cDesc, cData);
         }
         if (stat != CNNL_STATUS_SUCCESS)
@@ -352,11 +354,13 @@ class PowerCnnl : public BangKernelWithoutConfig {
         // get op descriptor
         size_t wsSize;
-        cnnlGetPowWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc, &wsSize);
+        cnnlGetPowWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
+                                &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlPow(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,
-                                    aDesc, aData, bDesc, bData, wsData, wsSize, cDesc, cData);
+        cnnlStatus_t stat =
+            cnnlPow(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,
+                    aDesc, aData, bDesc, bData, wsData, wsSize, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -404,9 +408,9 @@ class FloorDivCnnl : public BangKernelWithoutConfig {
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlFloorDiv_v2(context->cnnlHandle(),
-                                            CNNL_COMPUTATION_HIGH_PRECISION,
-                                            aDesc, aData, bDesc, bData, cDesc, cData, wsData, wsSize);
+        cnnlStatus_t stat = cnnlFloorDiv_v2(
+            context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc,
+            aData, bDesc, bData, cDesc, cData, wsData, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -449,14 +453,14 @@ class FloorDivTruncCnnl : public BangKernelWithoutConfig {
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         size_t wsSize;
-        cnnlGetFloorDivTruncWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
-                                          &wsSize);
+        cnnlGetFloorDivTruncWorkspaceSize(context->cnnlHandle(), aDesc, bDesc,
+                                          cDesc, &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlFloorDivTrunc(context->cnnlHandle(),
-                                              CNNL_COMPUTATION_HIGH_PRECISION,
-                                              aDesc, aData, bDesc, bData, cDesc, cData, wsData, wsSize);
+        cnnlStatus_t stat = cnnlFloorDivTrunc(
+            context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc,
+            aData, bDesc, bData, cDesc, cData, wsData, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -504,8 +508,9 @@ class FloorModCnnl : public BangKernelWithoutConfig {
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlFloorMod(context->cnnlHandle(),
-                                         aDesc, aData, bDesc, bData, cDesc, cData, wsData, wsSize);
+        cnnlStatus_t stat =
+            cnnlFloorMod(context->cnnlHandle(), aDesc, aData, bDesc, bData,
+                         cDesc, cData, wsData, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -536,25 +541,30 @@ class FloorModCnnl : public BangKernelWithoutConfig {
 //     // get inputs
 //     checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
 //     checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
-//                                            CNNL_DTYPE_FLOAT, 4, dim_array));
+//                                            CNNL_DTYPE_FLOAT, 4,
+//                                            dim_array));
 //
 //     checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
 //     checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
-//                                            CNNL_DTYPE_FLOAT, 4, dim_array));
+//                                            CNNL_DTYPE_FLOAT, 4,
+//                                            dim_array));
 //
 //     // get outputs
 //     checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
 //     checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
-//                                            CNNL_DTYPE_FLOAT, 4, dim_array));
+//                                            CNNL_DTYPE_FLOAT, 4,
+//                                            dim_array));
 //
 //     size_t wsSize;
-//     cnnlGetFloorModTruncWorkspaceSize(context->cnnlHandle(), aDesc, bDesc, cDesc,
+//     cnnlGetFloorModTruncWorkspaceSize(context->cnnlHandle(), aDesc,
+//                                       bDesc, cDesc,
 //                                       &wsSize);
 //
 //     BangPtr wsData = context->getWorkspace(wsSize);
 //
 //     cnnlStatus_t stat = cnnlFloorModTrunc(context->cnnlHandle(),
-//                                           aDesc, aData, bDesc, bData, cDesc, cData, wsData, wsSize);
+//                                           aDesc, aData, bDesc, bData, cDesc,
+//                                           cData, wsData, wsSize);
 //     if (stat != CNNL_STATUS_SUCCESS)
 //         return;
 //
@@ -595,8 +605,8 @@ REGISTER_KERNEL(Device::BANG, OpType::Sub, DataType::Float32, SubCnnl,
 REGISTER_KERNEL(Device::BANG, OpType::Mul, DataType::Float32, MulCnnl,
                 "Mul_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::DivDemo, DataType::Float32, ElementWiseBang,
-                "DivDemo_Bang_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::DivDemo, DataType::Float32,
+                ElementWiseBang, "DivDemo_Bang_Float32");
 REGISTER_KERNEL(Device::BANG, OpType::Div, DataType::Float32, DivCnnl,
                 "Div_cnnl_Float32");
 REGISTER_KERNEL(Device::BANG, OpType::DivNoNan, DataType::Float32, DivNoNanCnnl,
@@ -611,11 +621,12 @@ REGISTER_KERNEL(Device::BANG, OpType::Power, DataType::Float32, PowerCnnl,
                 "Power_cnnl_BANG_Float32");
 REGISTER_KERNEL(Device::BANG, OpType::FloorDiv, DataType::Float32, FloorDivCnnl,
                 "FloorDiv_cnnl_BANG_Float32");
-REGISTER_KERNEL(Device::BANG, OpType::FloorDivTrunc, DataType::Float32, FloorDivTruncCnnl,
-                "FloorDivTrunc_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::FloorDivTrunc, DataType::Float32,
+                FloorDivTruncCnnl, "FloorDivTrunc_cnnl_BANG_Float32");
 REGISTER_KERNEL(Device::BANG, OpType::FloorMod, DataType::Float32, FloorModCnnl,
                 "FloorMod_cnnl_BANG_Float32");
-// REGISTER_KERNEL(Device::BANG, OpType::FloorModTrunc, DataType::Float32, FloorModTruncCnnl,
+// REGISTER_KERNEL(Device::BANG, OpType::FloorModTrunc, DataType::Float32,
+//                 FloorModTruncCnnl,
 //                 "FloorModTrunc_cnnl_BANG_Float32");
 // REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32,
 //                 ElementWiseBang,

View File

@@ -29,7 +29,8 @@ class ErfCnnl : public BangKernelWithoutConfig {
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         cnnlStatus_t stat =
-            cnnlErf_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc, aData, cDesc, cData);
+            cnnlErf_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,
+                       aDesc, aData, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -29,7 +29,8 @@ class ExpCnnl : public BangKernelWithoutConfig {
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         cnnlStatus_t stat =
-            cnnlExp_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc, aData, cDesc, cData);
+            cnnlExp_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,
+                       aDesc, aData, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
            return;

View File

@@ -31,7 +31,7 @@ class LogCnnl : public BangKernelWithoutConfig {
         cnnlStatus_t stat =
             cnnlLog_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,
                        getOpType(), aDesc, aData, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -1,6 +1,6 @@
+#include "operators/element_wise.h"
 #include "bang/bang_kernel_without_config.h"
 #include "bang/bang_runtime.h"
-#include "operators/element_wise.h"
 namespace infini {
 class MulNCnnl : public BangKernelWithoutConfig {
@@ -10,8 +10,8 @@ class MulNCnnl : public BangKernelWithoutConfig {
         auto context = dynamic_cast<const BangRuntimeObj *>(_context);
         int num = op->numInputs();
         void *argv[num];
-        for(int i = 0; i < num; ++i) {
+        for (int i = 0; i < num; ++i) {
             argv[i] = op->getInputs(i)->getRawDataPtr<void *>();
         }
         void *const cData = (op->getOutput()->getRawDataPtr<void *>());
@@ -25,19 +25,21 @@ class MulNCnnl : public BangKernelWithoutConfig {
         checkCnnlError(cnnlSetTensorDescriptor(desc, CNNL_LAYOUT_NCHW,
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         cnnlTensorDescriptor_t descArray[num];
-        for(int i = 0; i < num; ++i) {
+        for (int i = 0; i < num; ++i) {
             checkCnnlError(cnnlCreateTensorDescriptor(&descArray[i]));
-            checkCnnlError(cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
-                                                   CNNL_DTYPE_FLOAT, 4, dim_array));
+            checkCnnlError(
+                cnnlSetTensorDescriptor(descArray[i], CNNL_LAYOUT_NCHW,
+                                        CNNL_DTYPE_FLOAT, 4, dim_array));
         }
-        cnnlStatus_t stat = cnnlMulN(context->cnnlHandle(), descArray, argv, num, desc, cData);
+        cnnlStatus_t stat =
+            cnnlMulN(context->cnnlHandle(), descArray, argv, num, desc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
         // Destories in BANG does not require sync. But cnnl does not state
         // whether sync is required before destories.
-        for(int i = 0; i < num; ++i) {
+        for (int i = 0; i < num; ++i) {
             checkCnnlError(cnnlDestroyTensorDescriptor(descArray[i]));
         }
         checkCnnlError(cnnlDestroyTensorDescriptor(desc));

View File

@@ -40,7 +40,7 @@ class NegTensorCnnl : public BangKernelWithoutConfig {
     }
 };
-REGISTER_KERNEL(Device::BANG, OpType::NegTensor, DataType::Float32, NegTensorCnnl,
-                "NegTensor_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::NegTensor, DataType::Float32,
+                NegTensorCnnl, "NegTensor_cnnl_BANG_Float32");
 }; // namespace infini

View File

@@ -1,6 +1,6 @@
+#include "operators/pad.h"
 #include "bang/bang_kernel_without_config.h"
 #include "bang/bang_runtime.h"
-#include "operators/pad.h"
 namespace infini {
 class PadCnnl : public BangKernelWithoutConfig {
@@ -24,30 +24,31 @@ class PadCnnl : public BangKernelWithoutConfig {
         if (pads.size() == 2 && dim_size != 1) {
             for (int i = 0; i < dim_size * 2; i += 2) {
                 paddings[i] = pads[0];
-                paddings[i+1] = pads[1];
+                paddings[i + 1] = pads[1];
             }
         } else {
             for (int i = 0; i < dim_size * 2; i += 2) {
-                paddings[i] = pads[i/2];
-                paddings[i+1] = pads[i/2 + dim_size];
+                paddings[i] = pads[i / 2];
+                paddings[i + 1] = pads[i / 2 + dim_size];
             }
         }
         int dimout_array[dim_size];
         for (int i = 0; i < dim_size; ++i) {
-            dimout_array[i] = dim[i] + paddings[2*i] + paddings[2*i+1];
+            dimout_array[i] = dim[i] + paddings[2 * i] + paddings[2 * i + 1];
         }
         float paddingValue = 0.0;
         // input
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
-                                               CNNL_DTYPE_FLOAT, dim_size, dim_array));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, dim_size, dim_array));
         // output
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
         checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
-                                               CNNL_DTYPE_FLOAT, dim_size, dimout_array));
+                                               CNNL_DTYPE_FLOAT, dim_size,
+                                               dimout_array));
-        cnnlStatus_t stat =
-            cnnlPad(context->cnnlHandle(), aDesc, aData, paddings, &paddingValue, cDesc, cData);
+        cnnlStatus_t stat = cnnlPad(context->cnnlHandle(), aDesc, aData,
+                                    paddings, &paddingValue, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -40,7 +40,7 @@ class ReciprocalCnnl : public BangKernelWithoutConfig {
     }
 };
-REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, DataType::Float32, ReciprocalCnnl,
-                "Reciprocal_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Reciprocal, DataType::Float32,
+                ReciprocalCnnl, "Reciprocal_cnnl_BANG_Float32");
 }; // namespace infini

View File

@@ -29,7 +29,8 @@ class RsqrtCnnl : public BangKernelWithoutConfig {
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         cnnlStatus_t stat =
-            cnnlRsqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc, aData, cDesc, cData);
+            cnnlRsqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,
+                         aDesc, aData, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -29,7 +29,8 @@ class SqrtCnnl : public BangKernelWithoutConfig {
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
         cnnlStatus_t stat =
-            cnnlSqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION, aDesc, aData, cDesc, cData);
+            cnnlSqrt_v2(context->cnnlHandle(), CNNL_COMPUTATION_HIGH_PRECISION,
+                        aDesc, aData, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -25,8 +25,8 @@ class TransformCnnl : public BangKernelWithoutConfig {
         checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
                                                CNNL_DTYPE_FLOAT, 4, dim_array));
-        cnnlStatus_t stat =
-            cnnlTransform(context->cnnlHandle(), &alpha, cDesc, aData, &beta, cData);
+        cnnlStatus_t stat = cnnlTransform(context->cnnlHandle(), &alpha, cDesc,
+                                          aData, &beta, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -36,7 +36,7 @@ class TransformCnnl : public BangKernelWithoutConfig {
     }
 };
-REGISTER_KERNEL(Device::BANG, OpType::Transform, DataType::Float32, TransformCnnl,
-                "Transform_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Transform, DataType::Float32,
+                TransformCnnl, "Transform_cnnl_BANG_Float32");
 }; // namespace infini

View File

@@ -1,6 +1,6 @@
+#include "operators/transpose.h"
 #include "bang/bang_kernel_without_config.h"
 #include "bang/bang_runtime.h"
-#include "operators/transpose.h"
 namespace infini {
 class TransposeCnnl : public BangKernelWithoutConfig {
@@ -22,13 +22,13 @@ class TransposeCnnl : public BangKernelWithoutConfig {
         int dimout_array[4] = {dimout[0], dimout[1], dimout[2], dimout[3]};
         // get inputs
         checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
-                                               CNNL_DTYPE_FLOAT, 4, dimin_array));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            aDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, 4, dimin_array));
         // get outputs
         checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
-        checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
-                                               CNNL_DTYPE_FLOAT, 4, dimout_array));
+        checkCnnlError(cnnlSetTensorDescriptor(
+            cDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, 4, dimout_array));
         // get op descriptor
         auto permute = op->getPermute();
@@ -37,12 +37,13 @@ class TransposeCnnl : public BangKernelWithoutConfig {
         checkCnnlError(cnnlSetTransposeDescriptor(opDesc, 4, permute));
         size_t wsSize;
-        cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), aDesc, opDesc, &wsSize);
+        cnnlGetTransposeWorkspaceSize(context->cnnlHandle(), aDesc, opDesc,
+                                      &wsSize);
         BangPtr wsData = context->getWorkspace(wsSize);
-        cnnlStatus_t stat = cnnlTranspose_v2(context->cnnlHandle(), opDesc,
-                                             aDesc, aData, cDesc, cData,
-                                             wsData, wsSize);
+        cnnlStatus_t stat =
+            cnnlTranspose_v2(context->cnnlHandle(), opDesc, aDesc, aData, cDesc,
+                             cData, wsData, wsSize);
         if (stat != CNNL_STATUS_SUCCESS)
             return;
@@ -54,6 +55,6 @@ class TransposeCnnl : public BangKernelWithoutConfig {
     }
 };
-REGISTER_KERNEL(Device::BANG, OpType::Transpose, DataType::Float32, TransposeCnnl,
-                "Transpose_cnnl_BANG_Float32");
+REGISTER_KERNEL(Device::BANG, OpType::Transpose, DataType::Float32,
+                TransposeCnnl, "Transpose_cnnl_BANG_Float32");
 }; // namespace infini

View File

@@ -33,12 +33,10 @@ class TrigonCnnl : public BangKernelWithoutConfig {
         // get op descriptor
         cnnlTrigonDescriptor_t opDesc;
         checkCnnlError(cnnlCreateTrigonDescriptor(&opDesc));
-        checkCnnlError(cnnlSetTrigonDescriptor(
-            opDesc, getOpType()));
+        checkCnnlError(cnnlSetTrigonDescriptor(opDesc, getOpType()));
-        cnnlStatus_t stat =
-            cnnlTrigonForward(context->cnnlHandle(), opDesc, aDesc,
-                              aData, cDesc, cData);
+        cnnlStatus_t stat = cnnlTrigonForward(context->cnnlHandle(), opDesc,
+                                              aDesc, aData, cDesc, cData);
         if (stat != CNNL_STATUS_SUCCESS)
             return;

View File

@@ -1,17 +1,15 @@
 #include "operators/activation_backward.h"
 namespace infini {
-ActivationBackwardObj::ActivationBackwardObj(OpType type,
-                                             GraphObj *graph,
-                                             Tensor y,
-                                             Tensor diff_y,
-                                             Tensor x,
+ActivationBackwardObj::ActivationBackwardObj(OpType type, GraphObj *graph,
+                                             Tensor y, Tensor diff_y, Tensor x,
                                              Tensor diff_x)
     : OperatorObj(type, {y, diff_y, x}, {diff_x}) {
     IT_ASSERT(checkValid(graph));
 }
-optional<vector<Shape>> ActivationBackwardObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>>
+ActivationBackwardObj::inferShape(const TensorVec &inputs) const {
     const auto A = inputs[0];
     return {{A->getDims()}};
 }

View File

@@ -183,9 +183,9 @@ void ConvTransposed2dObj::setAuxilaryAttributes(PaddingMode mode) {
 void ConvBackwardFilterObj::setAuxilaryAttributes(PaddingMode mode) {
     const Tensor &inputX = inputs[0];
     const Tensor &diffY = inputs[1];
-    n = inputX->getDims()[0], c = inputX->getDims()[1], h = inputX->getDims()[2],
-    w = inputX->getDims()[3], f = diffY->getDims()[0], r = diffY->getDims()[2],
-    s = diffY->getDims()[3];
+    n = inputX->getDims()[0], c = inputX->getDims()[1],
+    h = inputX->getDims()[2], w = inputX->getDims()[3], f = diffY->getDims()[0],
+    r = diffY->getDims()[2], s = diffY->getDims()[3];
     if (mode == PaddingMode::Same) {
         int oh = h / sh;
         int ow = w / sw;
@@ -196,9 +196,10 @@ void ConvBackwardFilterObj::setAuxilaryAttributes(PaddingMode mode) {
     }
 }
-ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Tensor diffY, Tensor diffW,
-                                             int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
-                                             ActType act)
+ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX,
+                                             Tensor diffY, Tensor diffW, int ph,
+                                             int pw, int sh, int sw, int dh,
+                                             int dw, Tensor bias, ActType act)
     : ConvBaseObj(OpType::Conv, {inputX, diffY}, diffW, ph, pw, sh, sw, dh, dw,
                   inputX, diffY),
       act(act) {
@@ -208,9 +209,11 @@ ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Ten
     IT_ASSERT(checkValid(graph));
 }
-ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Tensor diffY, Tensor diffW,
-                                             PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
-                                             ActType act)
+ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX,
+                                             Tensor diffY, Tensor diffW,
+                                             PaddingMode mode, int sh, int sw,
+                                             int dh, int dw, Tensor bias,
+                                             ActType act)
     : ConvBaseObj(OpType::Conv, {inputX, diffY}, diffW, mode, sh, sw, dh, dw,
                   inputX, diffY),
      act(act) {
@@ -220,7 +223,8 @@ ConvBackwardFilterObj::ConvBackwardFilterObj(GraphObj *graph, Tensor inputX, Ten
     IT_ASSERT(checkValid(graph));
 }
-optional<vector<Shape>> ConvBackwardFilterObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>>
+ConvBackwardFilterObj::inferShape(const TensorVec &inputs) const {
     const auto &inputX = inputs[0], &diffY = inputs[1];
     auto n = inputX->getDims()[0];
     auto h = inputX->getDims()[2];
@@ -251,5 +255,4 @@ optional<vector<Shape>> ConvBackwardFilterObj::inferShape(const TensorVec &input
     return {{{on, oc, oh, ow}}};
 }
 } // namespace infini

View File

@@ -11,7 +11,7 @@ optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const {
     auto input = A->getDims();
     int length = input.size();
     if (length == 2) {
-        std::vector<int> output ={1};
+        std::vector<int> output = {1};
         return {{output}};
     } else {
         std::vector<int> output(input.begin(), input.end() - 2);

View File

@@ -54,24 +54,24 @@ vector<int> ElementWiseObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
-MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, Reduction reduction, Tensor output)
-    : OperatorObj(OpType::MSELoss, {input0, input1}, {output}), reductionMode(reduction) {
+MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
+                       Reduction reduction, Tensor output)
+    : OperatorObj(OpType::MSELoss, {input0, input1}, {output}),
+      reductionMode(reduction) {
     IT_ASSERT(checkValid(graph));
 }
 
-optional<vector<Shape>>
-MSELossObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const {
     const auto A = inputs[0], B = inputs[1];
     if (A->getDims().size() != B->getDims().size() ||
         A->getDims() != B->getDims())
         return {};
     if (reductionMode == None) {
         return {{A->getDims()}};
     } else {
-        Shape temp = { 1 };
+        Shape temp = {1};
         return {{temp}};
     }
 }
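The hunk above only re-wraps lines; the behavior it touches is that MSELossObj keeps the input shape when the reduction is None and collapses to a single element otherwise. A minimal standalone sketch of the usual squared-error reductions, assuming the common None/Sum/Mean definitions (the kernel itself is not shown in this commit; mseLoss is a hypothetical helper, not repository code):

#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: element-wise squared error followed by the three
// common reductions (0 = None, 1 = Sum, 2 = Mean).
std::vector<float> mseLoss(const std::vector<float> &a,
                           const std::vector<float> &b, int reduction) {
    assert(a.size() == b.size());
    std::vector<float> sq(a.size());
    for (std::size_t i = 0; i < a.size(); ++i)
        sq[i] = (a[i] - b[i]) * (a[i] - b[i]);
    if (reduction == 0) // "None": same shape as the inputs
        return sq;
    float sum = 0.0f;
    for (float v : sq)
        sum += v;       // "Sum"
    if (reduction == 2) // "Mean"
        sum /= static_cast<float>(a.size());
    return {sum};       // a single element, matching the {1} shape above
}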
@@ -100,20 +100,19 @@ vector<int> MSELossObj::getOpAttrVector() const {
 AddNObj::AddNObj(GraphObj *graph, int tensorNum, Tensor output, ...)
     : OperatorObj(OpType::AddN), num(tensorNum) {
     TensorVec temp;
     Tensor *start = &output;
     ++start;
-    for(int i = 0; i < num; ++i) {
+    for (int i = 0; i < num; ++i) {
         temp.push_back(*start);
         start++;
     }
     setOutputs({output});
     setInputs(temp);
     IT_ASSERT(checkValid(graph));
 }
 
-optional<vector<Shape>>
-AddNObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>> AddNObj::inferShape(const TensorVec &inputs) const {
     // For now,we only process the same dims here, broardcast will be considered
     // in the opt layer.
     const auto A = inputs[0];
@@ -144,20 +143,19 @@ vector<int> AddNObj::getOpAttrVector() const {
 MulNObj::MulNObj(GraphObj *graph, int tensorNum, Tensor output, ...)
     : OperatorObj(OpType::MulN), num(tensorNum) {
     TensorVec temp;
     Tensor *start = &output;
     ++start;
-    for(int i = 0; i < num; ++i) {
+    for (int i = 0; i < num; ++i) {
         temp.push_back(*start);
         start++;
     }
     setOutputs({output});
     setInputs(temp);
     IT_ASSERT(checkValid(graph));
 }
 
-optional<vector<Shape>>
-MulNObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>> MulNObj::inferShape(const TensorVec &inputs) const {
     // For now,we only process the same dims here, broardcast will be considered
     // in the opt layer.
     const auto A = inputs[0];
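The comment in the two inferShape methods above says that, for now, only identically shaped inputs are handled and broadcasting is deferred to the optimization layer. A small sketch of that rule, using a hypothetical helper name since the method bodies are truncated in this hunk: every input must match the first input's shape, which is then returned unchanged.

#include <optional>
#include <vector>

using ShapeVec = std::vector<int>;

// Hypothetical illustration of the "same dims only" rule described above;
// broadcasting is deliberately left to a later optimization layer.
std::optional<ShapeVec> inferSameShape(const std::vector<ShapeVec> &inputs) {
    const ShapeVec &first = inputs[0];
    for (const ShapeVec &s : inputs)
        if (s != first)
            return std::nullopt; // mismatched dims: no shape inferred
    return first;                // all N inputs share one shape
}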
@@ -1,21 +1,23 @@
 #include "operators/transpose.h"
 
 namespace infini {
-TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output, int permute[4])
+TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
+                           int permute[4])
     : OperatorObj(OpType::Transpose, {input}, {output}) {
     transposePermute[0] = permute[0];
     transposePermute[1] = permute[1];
     transposePermute[2] = permute[2];
     transposePermute[3] = permute[3];
     IT_ASSERT(checkValid(graph));
 }
 
-optional<vector<Shape>> TransposeObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>>
+TransposeObj::inferShape(const TensorVec &inputs) const {
     const auto A = inputs[0];
     auto input = A->getDims();
     auto output = input;
-    for(int i = 0; i < 4; ++i){
+    for (int i = 0; i < 4; ++i) {
         output[i] = input[transposePermute[i]];
     }
     return {{output}};
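The loop above reads as: output dimension i takes the size of input dimension permute[i]. A worked example with a hypothetical NCHW-to-NHWC permutation {0, 2, 3, 1} (not taken from this commit):

#include <cassert>
#include <vector>

// Standalone illustration of the permutation rule used in inferShape above.
int main() {
    std::vector<int> input = {1, 2, 4, 4}; // e.g. NCHW
    int permute[4] = {0, 2, 3, 1};         // NCHW -> NHWC
    std::vector<int> output = input;
    for (int i = 0; i < 4; ++i)
        output[i] = input[permute[i]];
    assert((output == std::vector<int>{1, 4, 4, 2}));
    return 0;
}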
@@ -32,8 +32,10 @@ vector<int> UnaryObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
-ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output, float min, float max)
-    : OperatorObj(OpType::Clip, {input}, {output}), minValue(min), maxValue(max) {
+ClipObj::ClipObj(GraphObj *graph, Tensor input, Tensor output, float min,
+                 float max)
+    : OperatorObj(OpType::Clip, {input}, {output}), minValue(min),
+      maxValue(max) {
     IT_ASSERT(checkValid(graph));
 }
@@ -64,7 +66,7 @@ vector<int> ClipObj::getOpAttrVector() const {
 }
 
 FillObj::FillObj(GraphObj *graph, Tensor input, Tensor output, float value)
     : OperatorObj(OpType::Fill, {input}, {output}), setValue(value) {
     IT_ASSERT(checkValid(graph));
 }
@@ -98,7 +100,7 @@ L2LossObj::L2LossObj(GraphObj *graph, Tensor input, Tensor output)
 }
 
 optional<vector<Shape>> L2LossObj::inferShape(const TensorVec &inputs) const {
-    Shape temp = { 1 };
+    Shape temp = {1};
     return {{temp}};
 }
@@ -121,12 +123,15 @@ vector<int> L2LossObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
-TransformObj::TransformObj(GraphObj *graph, Tensor input, Tensor output, float alpha, float beta)
-    : OperatorObj(OpType::Transform, {input}, {output}), alphaValue(alpha), betaValue(beta) {
+TransformObj::TransformObj(GraphObj *graph, Tensor input, Tensor output,
+                           float alpha, float beta)
+    : OperatorObj(OpType::Transform, {input}, {output}), alphaValue(alpha),
+      betaValue(beta) {
     IT_ASSERT(checkValid(graph));
 }
 
-optional<vector<Shape>> TransformObj::inferShape(const TensorVec &inputs) const {
+optional<vector<Shape>>
+TransformObj::inferShape(const TensorVec &inputs) const {
     const auto A = inputs[0];
     return {{A->getDims()}};
 }
@@ -179,8 +184,10 @@ vector<int> CastObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
-CumsumObj::CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis, bool exclusive, bool reverse)
-    : OperatorObj(OpType::Cumsum, {input}, {output}), axisValue(axis), exclusiveValue(exclusive), reverseValue(reverse) {
+CumsumObj::CumsumObj(GraphObj *graph, Tensor input, Tensor output, int axis,
+                     bool exclusive, bool reverse)
+    : OperatorObj(OpType::Cumsum, {input}, {output}), axisValue(axis),
+      exclusiveValue(exclusive), reverseValue(reverse) {
     IT_ASSERT(checkValid(graph));
 }
@@ -208,12 +215,15 @@ vector<int> CumsumObj::getOpAttrVector() const {
     return {enum_to_underlying(type)};
 }
 
-// CumprodObj::CumprodObj(GraphObj *graph, Tensor input, Tensor output, int axis, bool exclusive, bool reverse)
-//     : OperatorObj(OpType::Cumprod, {input}, {output}), axisValue(axis), exclusiveValue(exclusive), reverseValue(reverse) {
+// CumprodObj::CumprodObj(GraphObj *graph, Tensor input, Tensor output, int
+// axis, bool exclusive, bool reverse)
+//     : OperatorObj(OpType::Cumprod, {input}, {output}), axisValue(axis),
+//     exclusiveValue(exclusive), reverseValue(reverse) {
 //     IT_ASSERT(checkValid(graph));
 // }
 //
-// optional<vector<Shape>> CumprodObj::inferShape(const TensorVec &inputs) const {
+// optional<vector<Shape>> CumprodObj::inferShape(const TensorVec &inputs) const
+// {
 //     const auto A = inputs[0];
 //     return {{A->getDims()}};
 // }
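CumsumObj above stores an axis plus exclusive and reverse flags. Assuming the common ONNX-style semantics (this commit does not show the device kernel): the inclusive cumulative sum of {1, 2, 3} is {1, 3, 6}, the exclusive variant shifts it to {0, 1, 3}, and reverse accumulates from the end, giving {6, 5, 3}. A one-dimensional sketch under that assumption:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical 1-D illustration of the exclusive/reverse flags; the exact
// kernel behavior is assumed, not taken from this commit.
std::vector<int> cumsum1D(std::vector<int> x, bool exclusive, bool reverse) {
    if (reverse)
        std::reverse(x.begin(), x.end());
    std::vector<int> y(x.size());
    int acc = 0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (exclusive) {
            y[i] = acc; // running sum *before* x[i]
            acc += x[i];
        } else {
            acc += x[i];
            y[i] = acc; // running sum *including* x[i]
        }
    }
    if (reverse)
        std::reverse(y.begin(), y.end());
    return y;
}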
@@ -3,16 +3,17 @@
 #include "core/kernel.h"
 #include "core/runtime.h"
 #include "operators/activation_backward.h"
-#include "operators/unary.h"
 #include "operators/element_wise.h"
+#include "operators/unary.h"
 #include "test.h"
 
 namespace infini {
 
 template <class T, class D>
-void testActivationBackward(const std::function<void(void *, size_t, DataType)> &generator,
-                            const Shape &shape) {
+void testActivationBackward(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -44,9 +45,12 @@ void testActivationBackward(const std::function<void(void *, size_t, DataType)>
 }
 
 TEST(cnnl_ActivationBackward, run) {
-    testActivationBackward<ReluBackwardObj, ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
-    testActivationBackward<SigmoidBackwardObj, SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
-    testActivationBackward<TanhBackwardObj, TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testActivationBackward<ReluBackwardObj, ReluObj>(IncrementalGenerator(),
+                                                     Shape{1, 2, 2, 3});
+    testActivationBackward<SigmoidBackwardObj, SigmoidObj>(
+        IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testActivationBackward<TanhBackwardObj, TanhObj>(IncrementalGenerator(),
+                                                     Shape{1, 2, 2, 3});
 }
 } // namespace infini
@@ -9,9 +9,8 @@
 namespace infini {
 
 template <class T>
-void testaddN(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape) {
+void testaddN(const std::function<void(void *, size_t, DataType)> &generator,
+              const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testCast(const std::function<void(void *, size_t, DataType)> &generator,
               const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testCeil(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testClip(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testCopy(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,8 +9,8 @@
 namespace infini {
 
 template <class T>
-void testCumsum(const std::function<void(void *, size_t, DataType)> &generator, int axis,
-                const Shape &shape) {
+void testCumsum(const std::function<void(void *, size_t, DataType)> &generator,
+                int axis, const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testDet(const std::function<void(void *, size_t, DataType)> &generator,
             const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,9 +9,8 @@
 namespace infini {
 
 template <class T>
-void testDivDemo(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape) {
+void testDivDemo(const std::function<void(void *, size_t, DataType)> &generator,
+                 const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testErf(const std::function<void(void *, size_t, DataType)> &generator,
             const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testExp(const std::function<void(void *, size_t, DataType)> &generator,
            const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testFill(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -43,7 +43,8 @@ void testFloorDivTrunc(
 }
 
 TEST(cnnl_FloorDivTrunc, run) {
-    testFloorDivTrunc<FloorDivTruncObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testFloorDivTrunc<FloorDivTruncObj>(IncrementalGenerator(),
+                                        Shape{1, 2, 2, 3});
 }
 } // namespace infini
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testL2Loss(const std::function<void(void *, size_t, DataType)> &generator,
                 const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testLog(const std::function<void(void *, size_t, DataType)> &generator,
            const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,9 +9,8 @@
 namespace infini {
 
 template <class T>
-void testMaximum(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape) {
+void testMaximum(const std::function<void(void *, size_t, DataType)> &generator,
+                 const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,9 +9,8 @@
 namespace infini {
 
 template <class T>
-void testMinimum(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape) {
+void testMinimum(const std::function<void(void *, size_t, DataType)> &generator,
+                 const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,9 +9,8 @@
 namespace infini {
 
 template <class T>
-void testMSELoss(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape) {
+void testMSELoss(const std::function<void(void *, size_t, DataType)> &generator,
+                 const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -30,9 +29,12 @@ void testMSELoss(
     Graph bangGraph = make_ref<GraphObj>(bangRuntime);
     auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
     auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
-    auto gpuOp1 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::None, nullptr);
-    auto gpuOp2 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Sum, nullptr);
-    auto gpuOp3 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Mean, nullptr);
+    auto gpuOp1 =
+        bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::None, nullptr);
+    auto gpuOp2 =
+        bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Sum, nullptr);
+    auto gpuOp3 =
+        bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Mean, nullptr);
     bangGraph->dataMalloc();
     bangRuntime->run(bangGraph);
     auto outputGpu1 = gpuOp1->getOutput();
@@ -9,9 +9,8 @@
 namespace infini {
 
 template <class T>
-void testmulN(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape) {
+void testmulN(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,8 +9,9 @@
 namespace infini {
 
 template <class T>
-void testNegTensor(const std::function<void(void *, size_t, DataType)> &generator,
-                   const Shape &shape) {
+void testNegTensor(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testPad(const std::function<void(void *, size_t, DataType)> &generator,
             const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -23,7 +23,8 @@ void testPad(const std::function<void(void *, size_t, DataType)> &generator,
     // GPU
     Graph bangGraph = make_ref<GraphObj>(bangRuntime);
     auto inputGpu = bangGraph->cloneTensor(inputCpu);
-    auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, vector<int>{1,1,1,1}, vector<int>{0,3});
+    auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, vector<int>{1, 1, 1, 1},
+                                     vector<int>{0, 3});
     bangGraph->dataMalloc();
     bangRuntime->run(bangGraph);
     auto outputGpu = gpuOp->getOutput();
@@ -9,9 +9,8 @@
 namespace infini {
 
 template <class T>
-void testPow(
-    const std::function<void(void *, size_t, DataType)> &generator,
-    const Shape &shape) {
+void testPow(const std::function<void(void *, size_t, DataType)> &generator,
             const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,8 +9,9 @@
 namespace infini {
 
 template <class T>
-void testReciprocal(const std::function<void(void *, size_t, DataType)> &generator,
-                    const Shape &shape) {
+void testReciprocal(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testSqrt(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,8 +9,9 @@
 namespace infini {
 
 template <class T>
-void testTransform(const std::function<void(void *, size_t, DataType)> &generator,
-                   const Shape &shape) {
+void testTransform(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -9,8 +9,9 @@
 namespace infini {
 
 template <class T>
-void testTranspose(const std::function<void(void *, size_t, DataType)> &generator,
-                   const Shape &shape) {
+void testTranspose(
+    const std::function<void(void *, size_t, DataType)> &generator,
+    const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
@@ -10,7 +10,7 @@ namespace infini {
 template <class T>
 void testTrigon(const std::function<void(void *, size_t, DataType)> &generator,
                 const Shape &shape) {
     // Runtime
     Runtime cpuRuntime = CpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();