add operation fill

This commit is contained in:
wanghailu 2022-12-19 02:59:33 +00:00
parent 82f510672d
commit 084063a68f
5 changed files with 127 additions and 0 deletions

View File

@ -64,6 +64,7 @@ enum class OpType {
Clip,
Erf,
Exp,
Fill,
//
MemBound = 300,
};
@ -136,6 +137,7 @@ class OpRegistry {
FOP(Clip);
FOP(Erf);
FOP(Exp);
FOP(Fill);
//
FOP(MemBound);
default:

View File

@ -33,6 +33,22 @@ class ClipObj : public OperatorObj {
vector<int> getOpAttrVector() const override;
};
/**
 * Operator object for OpType::Fill.
 *
 * Holds one input tensor, one output tensor and the scalar passed to the
 * constructor, which kernels read back through getValue().
 */
class FillObj : public OperatorObj {
  public:
    /**
     * @param graph  Graph handed to the validity check done at construction.
     * @param input  The operator's single input tensor.
     * @param output The operator's single output tensor.
     * @param value  Scalar stored in the operator and exposed by getValue().
     */
    FillObj(GraphObj *graph, Tensor input, Tensor output, float value);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
    std::string toString() const override;

    int numInputs() const override { return 1; }
    int numOutputs() const override { return 1; }

    /// Returns the scalar given at construction time.
    float getValue() const { return setValue; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;

    // Scalar captured by the constructor.
    float setValue;
};
#define DEFINE_UNARY_OBJ(prefix, type) \
class prefix##Obj : public UnaryObj { \
public: \

40
src/kernels/bang/fill.cc Normal file
View File

@ -0,0 +1,40 @@
#include "bang/bang_kernel_without_config.h"
#include "bang/bang_runtime.h"
#include "operators/unary.h"
namespace infini {
/**
 * BANG kernel for OpType::Fill, implemented with cnnlFill.
 *
 * Reads the scalar from the FillObj and asks CNNL to write it into the
 * operator's output buffer, described as a 4-D NCHW float tensor.
 * Outputs whose rank is not 4 hit IT_TODO_HALT().
 */
class FillCnnl : public BangKernelWithoutConfig {
    /**
     * @param _op      Operator to execute; converted to FillObj via as<>.
     * @param _context Runtime; cast to BangRuntimeObj to obtain the CNNL
     *                 handle.
     */
    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        auto op = as<FillObj>(_op);
        auto context = dynamic_cast<const BangRuntimeObj *>(_context);

        void *const cData = (op->getOutput()->getRawDataPtr<void *>());
        float value = op->getValue();

        cnnlTensorDescriptor_t cDesc;
        auto dim = op->getOutput()->getDims();
        if (dim.size() != 4)
            IT_TODO_HALT();

        int dim_array[4] = {dim[0], dim[1], dim[2], dim[3]};
        // Describe the output tensor for CNNL.
        checkCnnlError(cnnlCreateTensorDescriptor(&cDesc));
        checkCnnlError(cnnlSetTensorDescriptor(cDesc, CNNL_LAYOUT_NCHW,
                                               CNNL_DTYPE_FLOAT, 4, dim_array));
        cnnlStatus_t stat =
            cnnlFill(context->cnnlHandle(), value, cDesc, cData);

        // Release the descriptor before inspecting `stat`; returning early on
        // failure without doing so would leak it.
        // Destroys in BANG do not require sync. But cnnl does not state
        // whether sync is required before destroys.
        checkCnnlError(cnnlDestroyTensorDescriptor(cDesc));

        // A failed fill is still not reported to the caller, as before.
        if (stat != CNNL_STATUS_SUCCESS)
            return;
    }
};
// Registers FillCnnl for the (Device::BANG, OpType::Fill, DataType::Float32)
// combination under the name "Fill_cnnl_BANG_Float32".
// NOTE(review): presumably this is how the runtime finds the kernel for a
// FillObj -- confirm against the REGISTER_KERNEL definition.
REGISTER_KERNEL(Device::BANG, OpType::Fill, DataType::Float32, FillCnnl,
"Fill_cnnl_BANG_Float32");
}; // namespace infini

View File

@ -63,4 +63,33 @@ vector<int> ClipObj::getOpAttrVector() const {
return {enum_to_underlying(type)};
}
// Constructs a Fill operator with `input` as its only input and `output` as
// its only output, and stores `value` in setValue.
// The operator is then validated against `graph`; IT_ASSERT fires if
// checkValid() reports failure.
FillObj::FillObj(GraphObj *graph, Tensor input, Tensor output, float value)
: OperatorObj(OpType::Fill, {input}, {output}), setValue(value) {
IT_ASSERT(checkValid(graph));
}
// Shape inference for Fill: the result holds the dimensions of inputs[0],
// unchanged.
optional<vector<Shape>> FillObj::inferShape(const TensorVec &inputs) const {
    const auto &source = inputs[0];
    return {{source->getDims()}};
}
// Renders the operator as "<OpName>[<guid>](output=<output guid>)".
std::string FillObj::toString() const {
    std::ostringstream text;
    text << OpRegistry::getOpName(type) << "[" << getGuid() << "]"
         << "("
         << "output=" << outputs[0]->getGuid() << ")";
    return text.str();
}
// Workload key: the operator type's underlying value followed by every
// dimension of the first output tensor, in order.
vector<int> FillObj::getWorkloadVector() const {
    const Shape dims = outputs[0]->getDims();

    vector<int> workload;
    workload.push_back(enum_to_underlying(type));
    for (const auto extent : dims)
        workload.push_back(extent);
    return workload;
}
// Attribute key: a single entry holding the operator type's underlying value.
// NOTE(review): the fill value is not part of this vector -- confirm that two
// Fill operators with different values are meant to share the same attributes.
vector<int> FillObj::getOpAttrVector() const {
    vector<int> attrs;
    attrs.push_back(enum_to_underlying(type));
    return attrs;
}
}; // namespace infini

View File

@ -0,0 +1,40 @@
#include "bang/bang_runtime.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
/**
 * Builds a one-operator graph of type T on the BANG runtime, runs it, and
 * checks that every element of the output equals the fill value.
 *
 * @param generator Callback used to populate the CPU-side input tensor.
 * @param shape     Shape of the input tensor; the output is expected to hold
 *                  the same number of elements.
 */
template <class T>
void testFill(const std::function<void(void *, size_t, DataType)> &generator,
              const Shape &shape) {
    // Runtime
    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
    auto bangRuntime = make_ref<BangRuntimeObj>();

    // Build input data on CPU
    Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
    inputCpu->dataMalloc();
    inputCpu->setData(generator);

    // GPU
    Graph bangGraph = make_ref<GraphObj>(bangRuntime);
    auto inputGpu = bangGraph->cloneTensor(inputCpu);
    float value = 1.0;
    auto gpuOp = bangGraph->addOp<T>(inputGpu, nullptr, value);
    auto outputGpu = gpuOp->getOutput();
    bangGraph->dataMalloc();
    bangRuntime->run(bangGraph);
    auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
    outputGpu2Cpu->printData();

    // Verify the result rather than passing unconditionally: the kernel
    // returns silently when cnnlFill fails, so only a data comparison can
    // detect a broken run.
    size_t numElements = 1;
    for (const auto extent : shape)
        numElements *= static_cast<size_t>(extent);
    EXPECT_TRUE(
        outputGpu2Cpu->equalData(std::vector<float>(numElements, value)));
}
// Exercises FillObj on a 1x2x2x3 tensor whose input is produced by
// IncrementalGenerator.
TEST(cnnl_Fill, run) {
    const Shape dims{1, 2, 2, 3};
    testFill<FillObj>(IncrementalGenerator(), dims);
}
} // namespace infini