forked from jiuyuan/InfiniTensor
add mseloss operation
This commit is contained in:
parent
4ad648fa36
commit
0707fb6aff
|
@ -72,6 +72,7 @@ enum class OpType {
|
|||
L2Loss,
|
||||
Maximum,
|
||||
Minimum,
|
||||
MSELoss,
|
||||
//
|
||||
MemBound = 300,
|
||||
};
|
||||
|
@ -152,6 +153,7 @@ class OpRegistry {
|
|||
FOP(L2Loss);
|
||||
FOP(Maximum);
|
||||
FOP(Minimum);
|
||||
FOP(MSELoss);
|
||||
//
|
||||
FOP(MemBound);
|
||||
default:
|
||||
|
|
|
@ -17,6 +17,23 @@ class ElementWiseObj : public OperatorObj {
|
|||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
class MSELossObj : public OperatorObj {
|
||||
public:
|
||||
enum Reduction { None = 0, Sum, Mean };
|
||||
MSELossObj(GraphObj *graph, Tensor input0, Tensor input1, Reduction reduction, Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
||||
Reduction getReduction() const { return reductionMode; }
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
private:
|
||||
Reduction reductionMode;
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
#define DEFINE_ELEMENT_WISE_OBJ(prefix, type) \
|
||||
class prefix##Obj : public ElementWiseObj { \
|
||||
public: \
|
||||
|
|
|
@ -262,6 +262,64 @@ class MinimumCnnl : public BangKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class MSELossCnnl : public BangKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<MSELossObj>(_op);
|
||||
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
MSELossObj::Reduction reduction = op->getReduction();
|
||||
cnnlTensorDescriptor_t aDesc, bDesc, cDesc;
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
|
||||
int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
|
||||
int dim_out[4] ={1,1,1,1};
|
||||
// get inputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&bDesc));
|
||||
checkCnnlError(cnnlSetTensorDescriptor(bDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
|
||||
// get outputs
|
||||
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
|
||||
if ( reduction == MSELossObj::None ) {
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_array));
|
||||
} else {
|
||||
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
|
||||
CNNL_DTYPE_FLOAT, 4, dim_out));
|
||||
}
|
||||
cnnlStatus_t stat;
|
||||
if( reduction == MSELossObj::None ) {
|
||||
stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_NONE, aDesc, aData, bDesc, bData,
|
||||
cDesc, cData);
|
||||
} else if (reduction == MSELossObj::Sum) {
|
||||
stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_SUM, aDesc, aData, bDesc, bData,
|
||||
cDesc, cData);
|
||||
} else {
|
||||
stat = cnnlMSELoss(context->cnnlHandle(), CNNL_MSE_LOSS_MEAN, aDesc, aData, bDesc, bData,
|
||||
cDesc, cData);
|
||||
}
|
||||
|
||||
if (stat != CNNL_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
// Destories in BANG does not require sync. But cnnl does not state
|
||||
// whether sync is required before destories.
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(bDesc));
|
||||
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
|
||||
}
|
||||
};
|
||||
|
||||
class AddCnnl : public ElementWiseCnnl {
|
||||
cnnlOpTensorDesc_t getOpType() const override { return CNNL_OP_TENSOR_ADD; }
|
||||
};
|
||||
|
@ -301,6 +359,8 @@ REGISTER_KERNEL(Device::BANG, OpType::Maximum, DataType::Float32, MaximumCnnl,
|
|||
"Maximum_cnnl_BANG_Float32");
|
||||
// Register the Float32 BANG kernels for the Minimum and MSELoss operators.
REGISTER_KERNEL(Device::BANG, OpType::Minimum, DataType::Float32, MinimumCnnl,
                "Minimum_cnnl_BANG_Float32");
REGISTER_KERNEL(Device::BANG, OpType::MSELoss, DataType::Float32, MSELossCnnl,
                "MSELoss_cnnl_BANG_Float32");
|
||||
// REGISTER_KERNEL(Device::BANG, OpType::Pow, DataType::Float32,
|
||||
// ElementWiseBang,
|
||||
// "Pow_Bang_Float32");
|
||||
|
|
|
@ -54,4 +54,48 @@ vector<int> ElementWiseObj::getOpAttrVector() const {
|
|||
return {enum_to_underlying(type)};
|
||||
}
|
||||
|
||||
|
||||
/// Registers {input0, input1} as inputs and {output} as the output of an
/// OpType::MSELoss operator, stores the reduction mode, and asserts that
/// checkValid(graph) succeeds.
MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
                       Reduction reduction, Tensor output)
    : OperatorObj(OpType::MSELoss, {input0, input1}, {output}),
      reductionMode(reduction) {
    IT_ASSERT(checkValid(graph));
}
|
||||
|
||||
/// Computes the single output shape from the two input tensors.
///
/// @return an empty optional when the inputs' dims differ; the first input's
///         dims when the reduction mode is None; otherwise the shape {1}.
optional<vector<Shape>>
MSELossObj::inferShape(const TensorVec &inputs) const {
    const auto &lhs = inputs[0];
    const auto &rhs = inputs[1];

    // Comparing the dims vectors also rejects inputs of different rank.
    if (lhs->getDims() != rhs->getDims())
        return {};

    if (reductionMode != None) {
        Shape reduced = {1};
        return {{reduced}};
    }
    return {{lhs->getDims()}};
}
|
||||
|
||||
/// Renders the operator as
/// `<name>[<guid>](<dims0>,<dims1>,input0=<guid>,input1=<guid>,output=<guid>)`.
std::string MSELossObj::toString() const {
    const auto &lhs = inputs[0];
    const auto &rhs = inputs[1];

    std::ostringstream text;
    text << OpRegistry::getOpName(type) << "[" << getGuid() << "]"
         << "("
         << vecToString(lhs->getDims()) << ","
         << vecToString(rhs->getDims()) << ","
         << "input0=" << lhs->getGuid() << ","
         << "input1=" << rhs->getGuid() << ","
         << "output=" << outputs[0]->getGuid() << ")";
    return text.str();
}
|
||||
|
||||
// use output dim or inputs dim?
|
||||
vector<int> MSELossObj::getWorkloadVector() const {
|
||||
vector<int> ret = outputs[0]->getDims();
|
||||
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||
return ret;
|
||||
}
|
||||
|
||||
/// Returns a one-element vector holding the operator type's underlying value.
// NOTE(review): reductionMode is not part of this vector, so operators that
// differ only in reduction yield identical attributes — confirm that is
// intended.
vector<int> MSELossObj::getOpAttrVector() const {
    vector<int> attrs;
    attrs.push_back(enum_to_underlying(type));
    return attrs;
}
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
#include "bang/bang_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testMSELoss(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto bangRuntime = make_ref<BangRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu1->dataMalloc();
|
||||
inputCpu1->setData(generator);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu2->dataMalloc();
|
||||
inputCpu2->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
|
||||
auto inputGpu1 = bangGraph->cloneTensor(inputCpu1);
|
||||
auto inputGpu2 = bangGraph->cloneTensor(inputCpu2);
|
||||
auto gpuOp1 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::None, nullptr);
|
||||
auto gpuOp2 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Sum, nullptr);
|
||||
auto gpuOp3 = bangGraph->addOp<T>(inputGpu1, inputGpu2, MSELossObj::Mean, nullptr);
|
||||
bangGraph->dataMalloc();
|
||||
bangRuntime->run(bangGraph);
|
||||
auto outputGpu1 = gpuOp1->getOutput();
|
||||
auto outputGpu2 = gpuOp2->getOutput();
|
||||
auto outputGpu3 = gpuOp3->getOutput();
|
||||
auto outputGpu2Cpu1 = outputGpu1->clone(cpuRuntime);
|
||||
auto outputGpu2Cpu2 = outputGpu2->clone(cpuRuntime);
|
||||
auto outputGpu2Cpu3 = outputGpu3->clone(cpuRuntime);
|
||||
// Check
|
||||
outputGpu2Cpu1->printData();
|
||||
outputGpu2Cpu2->printData();
|
||||
outputGpu2Cpu3->printData();
|
||||
EXPECT_TRUE(1);
|
||||
}
|
||||
|
||||
// Runs testMSELoss for MSELossObj on a {1, 2, 2, 3} tensor filled by
// IncrementalGenerator.
TEST(cnnl_MSELoss, run) {
    const Shape inputShape{1, 2, 2, 3};
    testMSELoss<MSELossObj>(IncrementalGenerator(), inputShape);
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue