add det operation

wanghailu 2023-01-04 09:24:52 +00:00
parent 68f4630dac
commit c19d6e6bb0
5 changed files with 159 additions and 0 deletions


@@ -88,6 +88,7 @@ enum class OpType {
FloorModTrunc,
Cumsum,
Cumprod,
Det,
//
MemBound = 300,
};
@@ -184,6 +185,7 @@ class OpRegistry {
FOP(FloorModTrunc);
FOP(Cumsum);
FOP(Cumprod);
FOP(Det);
//
FOP(MemBound);
default:

include/operators/det.h (new file, 21 lines)

@@ -0,0 +1,21 @@
#pragma once
#include "core/operator.h"

namespace infini {
// Computes the determinant (NormalDet) or log-determinant (LogDet) of the
// input matrices.
class DetObj : public OperatorObj {
  public:
    enum Mode { NormalDet = 0, LogDet };
    DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode);
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    std::string toString() const override;
    int numInputs() const override { return 1; }
    int numOutputs() const override { return 1; }
    Mode getMode() const { return modeValue; }

  private:
    Mode modeValue;
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
}; // namespace infini

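A minimal usage sketch (not part of this commit) showing how a DetObj is added to a graph; it mirrors the BANG test added at the end of this commit, and the function name runDetExample is purely illustrative.

#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/det.h"
#include "test.h" // for IncrementalGenerator, as in the test below

using namespace infini;

// Builds a graph with a single Det operator on the BANG runtime and runs it.
void runDetExample() {
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    auto bangRuntime = make_ref<BangRuntimeObj>();

    // Prepare the input on the CPU and clone it to the device.
    Tensor inputCpu =
        make_ref<TensorObj>(Shape{1, 1, 2, 2}, DataType::Float32, cpuRuntime);
    inputCpu->dataMalloc();
    inputCpu->setData(IncrementalGenerator());

    Graph g = make_ref<GraphObj>(bangRuntime);
    auto input = g->cloneTensor(inputCpu);
    // nullptr lets the framework allocate the output tensor; NormalDet selects
    // the plain determinant rather than the log-determinant.
    auto det = g->addOp<DetObj>(input, nullptr, DetObj::NormalDet);
    g->dataMalloc();
    bangRuntime->run(g);

    det->getOutput()->clone(cpuRuntime)->printData(); // output shape {1, 1}
}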
src/kernels/bang/det.cc (new file, 52 lines)

@@ -0,0 +1,52 @@
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "operators/det.h"
namespace infini {
class DetCnnl : public BangKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<DetObj>(_op);
auto context = dynamic_cast<const BangRuntimeObj *>(_context);
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
DetObj::Mode mode = op->getMode();
cnnlDetMode_t nlMode;
if(mode == DetObj::LogDet) {
nlMode = CNNL_DET_MODE_LOGDET;
} else {
nlMode = CNNL_DET_MODE_DET;
}
cnnlTensorDescriptor_t aDesc, cDesc;
auto dimin = op->getInputs(0)->getDims();
auto dimout = op->getOutput()->getDims();
if (dimin.size() != 4 || dimout.size() != 2)
IT_TODO_HALT();
int dimin_array[4] = {dimin[0], dimin[1], dimin[2], dimin[3]};
int dimout_array[2] = {dimout[0], dimout[1]};
// get inputs
checkCnnlError(cnnlCreateTensorDescriptor(&aDesc));
checkCnnlError(cnnlSetTensorDescriptor(aDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, 4, dimin_array));
// get outputs
checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, 2, dimout_array));
cnnlStatus_t stat = cnnlDet(context->cnnlHandle(), nlMode, aDesc, aData, cDesc, cData);
if (stat != CNNL_STATUS_SUCCESS)
return;
// Destories in BANG does not require sync. But cnnl does not state
// whether sync is required before destories.
checkCnnlError(cnnlDestroyTensorDescriptor(aDesc));
checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));
}
};
REGISTER_KERNEL(Device::BANG, OpType::Det, DataType::Float32, DetCnnl,
"Det_cnnl_BANG_Float32");
}; // namespace infini

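For reference, a minimal host-side sketch (not part of this commit) of what the NormalDet mode computes for the batched 2x2 case exercised by the test at the end of this commit. The function name and the std::vector interface are illustrative only; it assumes row-major float data with one determinant per trailing 2x2 matrix and does not mirror the cnnl implementation.

#include <cstddef>
#include <vector>

// det([[a, b], [c, d]]) = a*d - b*c, applied independently to each trailing
// 2x2 matrix of a row-major {batch, channel, 2, 2} tensor.
std::vector<float> referenceDet2x2(const std::vector<float> &in,
                                   std::size_t numMatrices) {
    std::vector<float> out(numMatrices);
    for (std::size_t i = 0; i < numMatrices; ++i) {
        const float *m = in.data() + i * 4;
        out[i] = m[0] * m[3] - m[1] * m[2];
    }
    return out;
}
// Example: referenceDet2x2({0, 1, 2, 3}, 1) yields {-2}.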
src/operators/det.cc (new file, 43 lines)

@@ -0,0 +1,43 @@
#include "operators/det.h"
namespace infini {
DetObj::DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode)
: OperatorObj(OpType::Det, {input}, {output}), modeValue(mode) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0];
auto input = A->getDims();
int length = input.size();
if (length == 2) {
std::vector<int> output ={1};
return {{output}};
} else {
std::vector<int> output(input.begin(), input.end() - 2);
return {{output}};
}
}
std::string DetObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> DetObj::getWorkloadVector() const {
vector<int> ret{enum_to_underlying(type)};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> DetObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
}; // namespace infini

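To make the shape rule concrete, a small standalone sketch of the same inference logic (the function name and the example shapes are illustrative, not part of the commit):

#include <cassert>
#include <vector>

using Shape = std::vector<int>;

// Mirrors DetObj::inferShape: drop the trailing two (matrix) dimensions;
// a bare 2-D matrix produces a single scalar, reported as shape {1}.
Shape inferDetShape(const Shape &in) {
    if (in.size() == 2)
        return Shape{1};
    return Shape(in.begin(), in.end() - 2);
}

int main() {
    assert((inferDetShape({3, 3}) == Shape{1}));
    assert((inferDetShape({1, 1, 2, 2}) == Shape{1, 1})); // shape used by the test below
    assert((inferDetShape({4, 8, 5, 5}) == Shape{4, 8}));
    return 0;
}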

@@ -0,0 +1,41 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/det.h"
#include "test.h"
namespace infini {
template <class T>
void testDet(const std::function<void(void *, size_t, DataType)> &generator,
const Shape &shape) {
// Runtime
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
auto bangRuntime = make_ref<BangRuntimeObj>();
// Build input data on CPU
Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
inputCpu->dataMalloc();
inputCpu->setData(generator);
// GPU
Graph bangGraph = make_ref<GraphObj>(bangRuntime);
auto inputGpu = bangGraph->cloneTensor(inputCpu);
auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, DetObj::NormalDet);
bangGraph->dataMalloc();
bangRuntime->run(bangGraph);
auto outputGpu = gpuOp->getOutput();
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
// Check
inputCpu->printData();
outputGpu2Cpu->printData();
EXPECT_TRUE(1);
}
TEST(cnnl_Det, run) {
testDet<DetObj>(IncrementalGenerator(), Shape{1, 1, 2, 2});
}
} // namespace infini
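Assuming IncrementalGenerator fills the four elements with 0, 1, 2, 3 (an assumption about that helper, not something shown in this commit), the single 2x2 input is [[0, 1], [2, 3]] and its determinant is 0*3 - 1*2 = -2, so the placeholder EXPECT_TRUE(1) could later be tightened to a hypothetical check along these lines:

// Hypothetical stricter check for the {1, 1, 2, 2} case above; assumes the
// input values are 0, 1, 2, 3, so det([[0, 1], [2, 3]]) == -2.
EXPECT_FLOAT_EQ(outputGpu2Cpu->getRawDataPtr<float *>()[0], -2.0f);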