forked from jiuyuan/InfiniTensor
ADD add/mul/sub/div/pow operators and CPU/CUDA kernels (#26)
Fix some remove useless code. add div/pow kernel Add add/mul/sub operators. fix cpu kernel. add element wise kenerl for cuda. ADD element wise operator.
This commit is contained in:
parent
0409eafb5f
commit
13b7a2604b
|
@ -71,6 +71,8 @@ class TensorObj : public TensorBaseObj {
|
|||
return equalDataImpl(getRawDataPtr<T *>(), dataVector.data(), size());
|
||||
}
|
||||
|
||||
size_t getOffsetByBroadcastOffset(size_t bcOffset, Shape bcShape) const;
|
||||
|
||||
private:
|
||||
void printDataFloat() const;
|
||||
void printDataUint32_t() const;
|
||||
|
@ -92,6 +94,10 @@ class TensorObj : public TensorBaseObj {
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
Shape getPosByOffset(size_t offset, Shape dim) const;
|
||||
size_t getOffsetByPos(Shape pos, Shape dim) const;
|
||||
|
||||
// void setDims(const Dim &dms) { dims = dms; }
|
||||
|
||||
// bool dataRand(int seed = 0) {
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
#pragma once
|
||||
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
namespace infini {
|
||||
void div_kernel(float *a, float *b, float *c, int num);
|
||||
void pow_kernel(float *a, float *b, float *c, int num);
|
||||
|
||||
void element_wise_kernel(const Operator &_op) {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
float *const aData = (op->getInputs(0)->getRawDataPtr<float *>());
|
||||
float *const bData = (op->getInputs(1)->getRawDataPtr<float *>());
|
||||
float *const cData = (op->getOutput()->getRawDataPtr<float *>());
|
||||
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
|
||||
if (op->getOpType() == OpType::Div)
|
||||
div_kernel(aData, bData, cData, n * c * h * w);
|
||||
else if (op->getOpType() == OpType::Pow)
|
||||
pow_kernel(aData, bData, cData, n * c * h * w);
|
||||
else
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,33 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class ElementWiseObj : public OperatorObj {
|
||||
public:
|
||||
ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, Tensor input1,
|
||||
Tensor output);
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 2; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
|
||||
#define DEFINE_ELEMENT_WISE_OBJ(prefix, type) \
|
||||
class prefix##Obj : public ElementWiseObj { \
|
||||
public: \
|
||||
prefix##Obj(GraphObj *graph, Tensor input0, Tensor input1, \
|
||||
Tensor output) \
|
||||
: ElementWiseObj(type, graph, input0, input1, output) {} \
|
||||
};
|
||||
|
||||
DEFINE_ELEMENT_WISE_OBJ(Add, OpType::Add)
|
||||
DEFINE_ELEMENT_WISE_OBJ(Sub, OpType::Sub)
|
||||
DEFINE_ELEMENT_WISE_OBJ(Mul, OpType::Mul)
|
||||
DEFINE_ELEMENT_WISE_OBJ(Div, OpType::Div)
|
||||
DEFINE_ELEMENT_WISE_OBJ(Pow, OpType::Pow)
|
||||
}; // namespace infini
|
|
@ -141,4 +141,35 @@ void TensorObj::copyData(const TensorObj *src) {
|
|||
runtime->copyBlob(this, src);
|
||||
}
|
||||
|
||||
Shape TensorObj::getPosByOffset(size_t offset, Shape dim) const {
|
||||
Shape pos = dim;
|
||||
for (int i = dim.size() - 1; i >= 0; i--) {
|
||||
pos[i] = offset % dim.at(i);
|
||||
offset = (offset - pos[i]) / dim.at(i);
|
||||
}
|
||||
return pos;
|
||||
}
|
||||
|
||||
size_t TensorObj::getOffsetByPos(Shape pos, Shape dim) const {
|
||||
int n = dim.size();
|
||||
size_t offset = pos.at(0);
|
||||
for (auto i = 1; i < n; i++) {
|
||||
offset = offset * dim.at(i) + pos.at(i);
|
||||
}
|
||||
return offset;
|
||||
}
|
||||
|
||||
size_t TensorObj::getOffsetByBroadcastOffset(size_t bcOffset,
|
||||
Shape bcDim) const {
|
||||
Shape bcPos = getPosByOffset(bcOffset, bcDim);
|
||||
|
||||
Shape pos = bcPos;
|
||||
int n = shape.size();
|
||||
for (auto i = 0; i < n; i++) {
|
||||
if (shape.at(i) == 1)
|
||||
pos[i] = 0;
|
||||
}
|
||||
return getOffsetByPos(pos, shape);
|
||||
}
|
||||
|
||||
}; // namespace infini
|
|
@ -17,4 +17,5 @@ void cudaPrintFloat(float *x, int len) {
|
|||
cudaPrintFloatImpl<<<1, 1>>>(x, len);
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,68 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class NativeElementWise : public Kernel {
|
||||
virtual T doCompute(T val0, T val1) const = 0;
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
T *inptr0 = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *inptr1 = op->getInputs(1)->getRawDataPtr<T *>();
|
||||
T *outptr = op->getOutput()->getRawDataPtr<T *>();
|
||||
|
||||
auto outDim = op->getOutput()->getDims();
|
||||
auto n = op->getOutput()->size();
|
||||
for (size_t offset = 0; offset < n; offset++) {
|
||||
// For now,we only process the same dims here, broardcast will be
|
||||
// considered in the opt layer.
|
||||
/*auto offset0 =
|
||||
op->getInputs(0)->getOffsetByBroadcastOffset(offset, outDim);
|
||||
auto offset1 =
|
||||
op->getInputs(1)->getOffsetByBroadcastOffset(offset, outDim);
|
||||
outptr[offset] = doCompute(inptr0[offset0], inptr1[offset1]);*/
|
||||
outptr[offset] = doCompute(inptr0[offset], inptr1[offset]);
|
||||
}
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
compute(op, {}, context);
|
||||
}
|
||||
|
||||
PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *context) const override {
|
||||
PerfRecord perfrcd(timeit([&]() { compute(op, context); }));
|
||||
return perfrcd;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T> class NaiveAdd : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return val0 + val1; }
|
||||
};
|
||||
template <typename T> class NaiveSub : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return val0 - val1; }
|
||||
};
|
||||
template <typename T> class NaiveMul : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return val0 * val1; }
|
||||
};
|
||||
template <typename T> class NaiveDiv : public NativeElementWise<T> {
|
||||
T doCompute(T val0, T val1) const override { return (T)(val0 / val1); }
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::UInt32, NaiveAdd<uint32_t>,
|
||||
"addNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Add, DataType::Float32, NaiveAdd<float>,
|
||||
"addNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::UInt32, NaiveSub<uint32_t>,
|
||||
"subNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Sub, DataType::Float32, NaiveSub<float>,
|
||||
"subNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::UInt32, NaiveMul<uint32_t>,
|
||||
"mulNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Mul, DataType::Float32, NaiveMul<float>,
|
||||
"mulNaive_CPU_float32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::UInt32, NaiveDiv<uint32_t>,
|
||||
"divNaive_CPU_uint32");
|
||||
REGISTER_KERNEL(Device::CPU, OpType::Div, DataType::Float32, NaiveDiv<float>,
|
||||
"divNaive_CPU_float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,124 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "core/kernel.h"
|
||||
#include "cuda/cuda_element_wise.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class ElementWiseCudnn : public Kernel {
|
||||
virtual cudnnOpTensorOp_t getOpType() const = 0;
|
||||
virtual tuple<float, float, float> getAlphBeta() const {
|
||||
return {1.f, 1.f, 0.f};
|
||||
}
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cudnnTensorDescriptor_t aDesc, bDesc, cDesc;
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
|
||||
|
||||
// get inputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&aDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
aDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
|
||||
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&bDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
bDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
|
||||
|
||||
// get outputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&cDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
cDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
|
||||
|
||||
// get op descriptor
|
||||
cudnnOpTensorDescriptor_t opDesc;
|
||||
checkCudnnError(cudnnCreateOpTensorDescriptor(&opDesc));
|
||||
checkCudnnError(cudnnSetOpTensorDescriptor(
|
||||
opDesc, getOpType(), CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN));
|
||||
|
||||
auto [aAlpha, bAlpha, beta] = getAlphBeta();
|
||||
cudnnStatus_t stat =
|
||||
cudnnOpTensor(context->cudnnHandle(), opDesc, &aAlpha, aDesc, aData,
|
||||
&bAlpha, bDesc, bData, &beta, cDesc, cData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
// Destories in CUDA does not require sync. But cuDNN does not state
|
||||
// whether sync is required before destories.
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(aDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(bDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(cDesc));
|
||||
checkCudnnError(cudnnDestroyOpTensorDescriptor(opDesc));
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(_op, {}, _context);
|
||||
}
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
class AddCudnn : public ElementWiseCudnn {
|
||||
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_ADD; }
|
||||
};
|
||||
|
||||
class SubCudnn : public ElementWiseCudnn {
|
||||
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_ADD; }
|
||||
tuple<float, float, float> getAlphBeta() const override {
|
||||
return {1.f, -1.f, 0.f};
|
||||
}
|
||||
};
|
||||
|
||||
class MulCudnn : public ElementWiseCudnn {
|
||||
cudnnOpTensorOp_t getOpType() const override { return CUDNN_OP_TENSOR_MUL; }
|
||||
};
|
||||
|
||||
class ElementWiseCuda : public Kernel {
|
||||
void compute(const Operator &_op, const PerfRecord &record,
|
||||
const RuntimeObj *_context) const override {
|
||||
element_wise_kernel(_op);
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(_op, {}, _context);
|
||||
}
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
PerfRecord tune(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
PerfRecord ret;
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
ret.time = timeit([&]() { compute(_op, _context); },
|
||||
[&]() { context->sync(); });
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Add, DataType::Float32, AddCudnn,
|
||||
"Add_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sub, DataType::Float32, SubCudnn,
|
||||
"Sub_cuDNN_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Mul, DataType::Float32, MulCudnn,
|
||||
"Mul_cuDNN_CUDA_Float32");
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Div, DataType::Float32, ElementWiseCuda,
|
||||
"Div_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Pow, DataType::Float32, ElementWiseCuda,
|
||||
"Pow__CUDA_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,38 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include <math.h>
|
||||
|
||||
constexpr unsigned int num_threads() { return 32 * 4; }
|
||||
constexpr int thread_work_size() { return 4; }
|
||||
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
|
||||
|
||||
__global__ void _div_kernel(float *x, float *y, float *z, int n) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < n; i += stride) {
|
||||
z[i] = x[i] / y[i];
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void _pow_kernel(float *x, float *y, float *z, int n) {
|
||||
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
int stride = blockDim.x * gridDim.x;
|
||||
for (int i = index; i < n; i += stride) {
|
||||
z[i] = pow(x[i], y[i]);
|
||||
}
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void div_kernel(float *a, float *b, float *c, int num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_div_kernel<<<blocksize, gridsize>>>(a, b, c, num);
|
||||
}
|
||||
void pow_kernel(float *a, float *b, float *c, int num) {
|
||||
|
||||
int blocksize = block_work_size();
|
||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||
_pow_kernel<<<blocksize, gridsize>>>(a, b, c, num);
|
||||
}
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,57 @@
|
|||
#include "operators/element_wise.h"
|
||||
|
||||
namespace infini {
|
||||
ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
|
||||
Tensor input1, Tensor output)
|
||||
: OperatorObj(type, {input0, input1}, {output}) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>>
|
||||
ElementWiseObj::inferShape(const TensorVec &inputs) const {
|
||||
// For now,we only process the same dims here, broardcast will be considered
|
||||
// in the opt layer.
|
||||
const auto A = inputs[0], B = inputs[1];
|
||||
if (A->getDims().size() != B->getDims().size() ||
|
||||
A->getDims() != B->getDims())
|
||||
return {};
|
||||
|
||||
return {{A->getDims()}};
|
||||
/*
|
||||
int n = A->getDims().size();
|
||||
Shape shape;
|
||||
for (int i = 0; i < n; i++) {
|
||||
auto dimA = A->getDims().at(i);
|
||||
auto dimB = B->getDims().at(i);
|
||||
if (!(dimA == dimB || dimA == 1 || dimB == 1))
|
||||
return {};
|
||||
auto dimI = dimA > dimB ? dimA : dimB;
|
||||
shape.emplace_back(dimI);
|
||||
}
|
||||
return {{shape}};*/
|
||||
}
|
||||
|
||||
std::string ElementWiseObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << vecToString(inputs[1]->getDims()) << ",";
|
||||
os << "input0=" << inputs[0]->getGuid() << ",";
|
||||
os << "input1=" << inputs[1]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
// use output dim or inputs dim?
|
||||
vector<int> ElementWiseObj::getWorkloadVector() const {
|
||||
vector<int> ret = outputs[0]->getDims();
|
||||
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> ElementWiseObj::getOpAttrVector() const {
|
||||
return {enum_to_underlying(type)};
|
||||
}
|
||||
|
||||
}; // namespace infini
|
|
@ -28,7 +28,7 @@ optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
|
|||
|
||||
std::string PoolingObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << "Maxpool[" << getGuid() << "]";
|
||||
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << "k=[" << kh << "," << kw << "],";
|
||||
os << "p=[" << ph << "," << pw << "],";
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
using ExpectOutput = vector<float>;
|
||||
TEST(ElementWise, ShapeInference) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({2, 3, 3, 4}, DataType::UInt32);
|
||||
Tensor i1 = g->addTensor({2, 3, 3, 4}, DataType::UInt32);
|
||||
auto op = g->addOp<AddObj>(i0, i1, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
||||
}
|
||||
}
|
||||
/*
|
||||
template <typename T>
|
||||
void test_element_wise(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const vector<uint32_t> &ans) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i0 = g->addTensor({1, 3, 2, 2}, DataType::UInt32);
|
||||
Tensor i1 = g->addTensor({2, 3, 1, 2}, DataType::UInt32);
|
||||
auto op = g->addOp<T>(i0, i1, nullptr);
|
||||
|
||||
g->dataMalloc();
|
||||
i0->setData(generator);
|
||||
i1->setData(generator);
|
||||
runtime->run(g, true, true);
|
||||
// check answer
|
||||
EXPECT_TRUE(op->getOutput()->equalData(ans));
|
||||
}
|
||||
|
||||
TEST(ElementWise, NaiveCPU) {
|
||||
test_element_wise<AddObj>(IncrementalGenerator(),
|
||||
vector<uint32_t>{0, 2, 2, 4, 6, 8, 8, 10,
|
||||
12, 14, 14, 16, 6, 8, 8, 10,
|
||||
12, 14, 14, 16, 18, 20, 20, 22});
|
||||
test_element_wise<SubObj>(
|
||||
IncrementalGenerator(),
|
||||
vector<uint32_t>{0, 0, 2, 2,
|
||||
2, 2, 4, 4,
|
||||
4, 4, 6, 6,
|
||||
4294967290, 4294967290, 4294967292, 4294967292,
|
||||
4294967292, 4294967292, 4294967294, 4294967294,
|
||||
4294967294, 4294967294, 0, 0});
|
||||
test_element_wise<MulObj>(
|
||||
IncrementalGenerator(),
|
||||
vector<uint32_t>{0, 1, 0, 3, 8, 15, 12, 21, 32, 45, 40, 55,
|
||||
0, 7, 12, 21, 32, 45, 48, 63, 80, 99, 100, 121});
|
||||
test_element_wise<DivObj>(OneGenerator(),
|
||||
vector<uint32_t>{
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
});
|
||||
}
|
||||
*/
|
||||
|
||||
template <class T>
|
||||
void testElementWiseCudnn(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const ExpectOutput &ansVec) {
|
||||
Runtime cpuRuntime = CpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor acpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
acpu->dataMalloc();
|
||||
acpu->setData(generator);
|
||||
|
||||
Tensor bcpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
bcpu->dataMalloc();
|
||||
bcpu->setData(generator);
|
||||
|
||||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto a = g->cloneTensor(acpu);
|
||||
auto b = g->cloneTensor(bcpu);
|
||||
auto op = g->addOp<T>(a, b, nullptr);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
|
||||
// Execute on CUDA
|
||||
cudaRuntime->run(g);
|
||||
|
||||
// clone CUDA output to CPU
|
||||
auto c = op->getOutput();
|
||||
auto ccpu = c->clone(cpuRuntime);
|
||||
// cudaPrintTensor(c);
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(ccpu->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(ElementWise, CuDNN) {
|
||||
testElementWiseCudnn<AddObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
|
||||
testElementWiseCudnn<SubObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
testElementWiseCudnn<MulObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
|
||||
|
||||
testElementWiseCudnn<DivObj>(
|
||||
OneGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
|
||||
testElementWiseCudnn<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
|
||||
ExpectOutput{1, 1, 4, 27});
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue