impl sqrt on CUDA (#109)

* impl sqrt on CUDA
fix parsing of Gather and ReduceMean

* fix test_gather

* fix test_cuda_gather

* impl sqrt cpu and add sqrt to test_cuda_unary

* cuda_unary supports arbitrary shapes

* fix SplitOp with dim=-1

Authored by constroy Li on 2023-08-18 12:17:47 +08:00; committed by GitHub.
parent ef672894d0
commit 48847958d0
15 changed files with 78 additions and 30 deletions
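Taken together, the frontend changes make models that rely on ONNX attribute defaults import correctly: Sqrt now has a handler, a Gather node without an explicit `axis` attribute falls back to 0, and a ReduceMean without `keepdims` defaults to 1. A minimal sketch of such a model, using only the `onnx` package (names and shapes are illustrative, not from this repo):

```python
import onnx
from onnx import TensorProto, helper

# Sqrt plus Gather/ReduceMean nodes that omit their optional attributes,
# exercising the parser defaults added in this commit.
sqrt = helper.make_node("Sqrt", ["x"], ["s"])
gather = helper.make_node("Gather", ["s", "idx"], ["g"])       # no axis -> 0
mean = helper.make_node("ReduceMean", ["g"], ["y"], axes=[1])  # no keepdims -> 1

graph = helper.make_graph(
    [sqrt, gather, mean],
    "sqrt_gather_reducemean",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 2]),
        helper.make_tensor_value_info("idx", TensorProto.INT32, [2, 2]),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 1, 2])],
)
# pin an opset where ReduceMean still takes "axes" as an attribute
model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])
onnx.checker.check_model(model)
```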

View File

@@ -1,6 +1,6 @@
 .PHONY : build clean format install-python test-cpp test-onnx
-TYPE ?= release
+TYPE ?= Release
 CUDA ?= OFF
 BANG ?= OFF
 INTELCPU ?= off
@@ -30,7 +30,7 @@ format:
 install-python: build
 	cp build/$(TYPE)/backend*.so pyinfinitensor/src/pyinfinitensor
-	pip install pyinfinitensor/
+	pip install -e pyinfinitensor/
 test-cpp:
 	@echo

View File

@@ -47,6 +47,7 @@ class GraphHandlerObj {
     Tensor tanh(Tensor x, Tensor y);
     Tensor softmax(Tensor x, Tensor y, int axis);
     Tensor abs(Tensor x, Tensor y);
+    Tensor sqrt(Tensor x, Tensor y);
     Tensor shape(Tensor x, Tensor y);
     Tensor identity(Tensor x, Tensor y);
     Tensor flatten(Tensor s, Tensor y, int axis);

View File

@@ -3,29 +3,32 @@
 #include "operators/unary.h"
 namespace infini {
+// TODO(constroy): num should be size_t.
 void softmax_kernel(float *input, float *output, int num);
 void relu_kernel(float *input, float *output, int num);
 void sigmoid_kernel(float *input, float *output, int num);
 void tanh_kernel(float *input, float *output, int num);
 void abs_kernel(float *input, float *output, int num);
+void sqrt_kernel(float *input, float *output, int num);
 void unary_kernel(const Operator &_op) {
     auto op = as<UnaryObj>(_op);
     float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
     float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
-    auto dim = op->getInputs(0)->getDims();
-    int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
+    size_t num = op->getOutput()->size();
     if (op->getOpType() == OpType::Softmax)
-        softmax_kernel(inputData, outputData, n * c * h * w);
+        softmax_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Relu)
-        relu_kernel(inputData, outputData, n * c * h * w);
+        relu_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Sigmoid)
-        sigmoid_kernel(inputData, outputData, n * c * h * w);
+        sigmoid_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Tanh)
-        tanh_kernel(inputData, outputData, n * c * h * w);
+        tanh_kernel(inputData, outputData, num);
     else if (op->getOpType() == OpType::Abs)
-        abs_kernel(inputData, outputData, n * c * h * w);
+        abs_kernel(inputData, outputData, num);
+    else if (op->getOpType() == OpType::Sqrt)
+        sqrt_kernel(inputData, outputData, num);
     else
         IT_TODO_HALT();
 }
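This refactor is what the commit message means by "cuda_unary supports arbitrary shapes": the old dispatch read exactly four dimensions, so anything other than a rank-4 tensor either indexed `dim` out of range or computed the wrong element count, while an elementwise kernel only ever needs the flat size. A quick illustration of the invariant (Python, illustrative only):

```python
from functools import reduce
from operator import mul

def element_count(shape):
    # flat element count, what op->getOutput()->size() returns for any rank
    return reduce(mul, shape, 1)

# the old n*c*h*w product only matched 4-D shapes; the flat count works for all
assert element_count([1, 2, 2, 3]) == 12
assert element_count([13]) == 13
assert element_count([2, 3, 4, 5, 6]) == 720
```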

View File

@@ -377,6 +377,11 @@ class OnnxStub:
                     tensors[node.input[0]],
                     tensors.get(node.output[0]),
                 )
+            elif node.op_type == "Sqrt":
+                tensors[node.output[0]] = self.handler.sqrt(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
             elif node.op_type == "Shape":
                 tensors[node.output[0]] = self.handler.shape(
                     tensors[node.input[0]],
@@ -500,7 +505,7 @@
                     tensors[node.input[1]],
                     tensors.get(node.output[0]),
                     next(
-                        (attr.i for attr in node.attribute if attr.name == "axis")
+                        (attr.i for attr in node.attribute if attr.name == "axis"), 0
                     ),
                 )
             elif node.op_type == "ReduceMean":
@@ -521,7 +526,8 @@
                             attr.i
                             for attr in node.attribute
                             if attr.name == "keepdims"
-                        )
+                        ),
+                        1
                     )
                     != 0,
                 )
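Both parser fixes apply the same pattern: `next()` over a generator of matching attributes raises `StopIteration` when a model omits the attribute, so importing models that rely on ONNX defaults used to fail; passing a default to `next()` encodes the fallback (axis=0 for Gather, keepdims=1 for ReduceMean). A standalone illustration (the `Attr` class here is a stand-in, not the real `AttributeProto`):

```python
class Attr:  # stand-in for onnx.AttributeProto, for illustration only
    def __init__(self, name, i):
        self.name, self.i = name, i

attributes = []  # e.g., a Gather node with no "axis" attribute

# old form: next(...) with no default raises StopIteration
try:
    next(attr.i for attr in attributes if attr.name == "axis")
except StopIteration:
    pass

# fixed form: fall back to the ONNX defaults
axis = next((attr.i for attr in attributes if attr.name == "axis"), 0)
keepdims = next((attr.i for attr in attributes if attr.name == "keepdims"), 1)
assert (axis, keepdims) == (0, 1)
```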

View File

@@ -151,6 +151,7 @@ DEFINE_UNARY_METHOD(relu, Relu)
 DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
 DEFINE_UNARY_METHOD(tanh, Tanh)
 DEFINE_UNARY_METHOD(abs, Abs)
+DEFINE_UNARY_METHOD(sqrt, Sqrt)
 DEFINE_UNARY_METHOD(shape, Shape)
 // see operators/reshape.h

View File

@@ -344,6 +344,7 @@ void init_graph_builder(py::module &m) {
         .def("tanh", &Handler::tanh, policy::move)
         .def("softmax", &Handler::softmax, policy::move)
         .def("abs", &Handler::abs, policy::move)
+        .def("sqrt", &Handler::sqrt, policy::move)
         .def("shape", &Handler::shape, policy::move)
         .def("identity", &Handler::identity, policy::move)
         .def("flatten", &Handler::flatten, policy::move)

View File

@@ -56,6 +56,10 @@ template <typename T> class NaiveAbs : public NativeUnary<T> {
     T doCompute(T val) const override { return val < 0 ? -val : val; }
 };

+template <typename T> class NaiveSqrt : public NativeUnary<T> {
+    T doCompute(T val) const override { return std::sqrt(val); }
+};
+
 template <typename T> class Clip : public CpuKernelWithoutConfig {
     void compute(const Operator &_op,
                  const RuntimeObj *context) const override {
@@ -91,6 +95,8 @@ REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
                 "absNaive_CPU_uint32");
 REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
                 "absNaive_CPU_float32");
+REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
+                "sqrtNaive_CPU_float32");
 REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
                 NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
 REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,

View File

@@ -132,6 +132,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
                 "Tanh_CUDA_Float32");
 REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
                 "Abs_CUDA_Float32");
+REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
+                "Sqrt_CUDA_Float32");
 // REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
 //                 "Softmax_CUDA_Float32");

View File

@@ -58,6 +58,14 @@ __global__ void _abs_kernel(float *input, float *output, int n) {
     }
 }

+__global__ void _sqrt_kernel(float *input, float *output, int n) {
+    int index = threadIdx.x + blockIdx.x * blockDim.x;
+    int stride = blockDim.x * gridDim.x;
+    for (int i = index; i < n; i += stride) {
+        output[i] = sqrt(input[i]);
+    }
+}
+
 namespace infini {
 void softmax_kernel(float *input, float *output, int num) {
@@ -90,5 +98,10 @@ void abs_kernel(float *input, float *output, int num) {
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     _abs_kernel<<<blocksize, gridsize>>>(input, output, num);
 }
+void sqrt_kernel(float *input, float *output, int num) {
+    int blocksize = block_work_size();
+    int gridsize = (num + block_work_size() - 1) / block_work_size();
+    _sqrt_kernel<<<blocksize, gridsize>>>(input, output, num);
+}
 }; // namespace infini
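The new kernel follows the grid-stride-loop idiom of the neighboring kernels: each thread starts at its global index and hops by the total thread count, so every element is visited exactly once for any `num`, including counts that are not a multiple of the block size. (Note the launch passes `blocksize` as the first argument; CUDA's first launch parameter is the grid dimension, so the two names appear swapped relative to their positions — the same pattern as the surrounding kernels — but the grid-stride loop is insensitive to the launch geometry, so coverage is unaffected.) A Python simulation of the coverage property:

```python
def grid_stride_indices(n, grid_dim, block_dim):
    # indices visited by one simulated thread grid, grid-stride style
    stride = grid_dim * block_dim
    for block in range(grid_dim):
        for thread in range(block_dim):
            i = thread + block * block_dim
            while i < n:
                yield i
                i += stride

# every element 0..n-1 is visited exactly once, even when n % block_dim != 0
visited = sorted(grid_stride_indices(13, grid_dim=2, block_dim=4))
assert visited == list(range(13))
```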

View File

@@ -25,7 +25,7 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
 vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
     IT_ASSERT(inputs.size() == 2);
     auto index = inputs[1];
-    IT_ASSERT(index->getDType() == DataType::UInt32);
+    IT_ASSERT(index->getDType() == DataType::Int32);
     return {inputs[0]->getDType()};
 }
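The ONNX spec constrains Gather indices to tensor(int32) or tensor(int64), so the old UInt32 assertion rejected spec-conforming models; the tests below switch to Int32 to match. For instance, a valid index input declared with the `onnx` package:

```python
from onnx import TensorProto, helper

# Gather's indices input must be int32 or int64 per the ONNX spec
indices = helper.make_tensor_value_info("indices", TensorProto.INT32, [2, 2])
```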

View File

@@ -7,10 +7,8 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
                    std::optional<TensorVec> outputs, int dim, int num)
     : OperatorObj(OpType::Split, {input},
                   ((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
-      dim(dim), num(num), ratio({}) {
-    int rank = input->getRank();
-    dim = get_real_axis(dim, rank);
-    int dimSize = input->getDims().at(dim);
+      dim(get_real_axis(dim, input->getRank())), num(num), ratio({}) {
+    int dimSize = input->getDims().at(this->dim);
     int pieceSize = dimSize / num;
     int lastSize = dimSize - pieceSize * num;
@@ -28,9 +26,7 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
                    const vector<int> &ratio)
     : OperatorObj(OpType::Split, {input},
                   ((!outputs) ? TensorVec{nullptr} : (*outputs))),
-      dim(dim), num(-1), ratio(ratio) {
-    int rank = input->getRank();
-    dim = get_real_axis(dim, rank);
+      dim(get_real_axis(dim, input->getRank())), num(-1), ratio(ratio) {
     num = ratio.size();
     if (!outputs) {
         TensorVec tmp(num, nullptr);
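The dim=-1 bug lived in the initializer list: the member `dim` was initialized with the raw argument, and only the local parameter was normalized afterwards, so the member kept the negative value and later uses of it saw -1. Normalizing inside the initializer list stores the resolved axis. A sketch of the intended semantics (Python; `get_real_axis` is assumed to follow the usual NumPy/ONNX negative-axis rule, and the remainder handling mirrors `pieceSize`/`lastSize` above):

```python
def get_real_axis(axis, rank):
    # assumed semantics: negative axes count from the end (NumPy/ONNX style)
    assert -rank <= axis < rank
    return axis + rank if axis < 0 else axis

def split_sizes(dim_size, num):
    # equal pieces, with the remainder folded into the last piece
    piece = dim_size // num
    sizes = [piece] * num
    sizes[-1] += dim_size - piece * num
    return sizes

assert get_real_axis(-1, 4) == 3
assert split_sizes(15, 4) == [3, 3, 3, 6]  # cf. the new Split shape-inference test
```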

View File

@@ -179,10 +179,10 @@ TEST(Gather, Cuda) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
     auto input = gCpu->addTensor({3, 2}, DataType::Float32);
-    auto index = gCpu->addTensor({2, 2}, DataType::UInt32);
+    auto index = gCpu->addTensor({2, 2}, DataType::Int32);
     gCpu->dataMalloc();
     input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
-    index->copyin(vector<uint32_t>{0, 1, 1, 2});
+    index->copyin(vector<int>{0, 1, 1, 2});
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -191,7 +191,7 @@ TEST(Gather, Cuda) {
     auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 0);
     gCuda->dataMalloc();
     inputCuda->copyin(vector<float>{1, 2, 3, 4, 5, 6});
-    indexCuda->copyin(vector<uint32_t>{0, 1, 1, 2});
+    indexCuda->copyin(vector<int>{0, 1, 1, 2});
     cudaRuntime->run(gCuda);
     // cudaPrintTensor(op->getOutput());
@@ -203,10 +203,10 @@ TEST(Gather, Cuda) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
     auto input = gCpu->addTensor({3, 3}, DataType::Float32);
-    auto index = gCpu->addTensor({1, 2}, DataType::UInt32);
+    auto index = gCpu->addTensor({1, 2}, DataType::Int32);
     gCpu->dataMalloc();
     input->setData(IncrementalGenerator());
-    index->copyin(vector<uint32_t>{0, 2});
+    index->copyin(vector<int>{0, 2});
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -215,7 +215,7 @@ TEST(Gather, Cuda) {
     auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
     gCuda->dataMalloc();
     inputCuda->setData(IncrementalGenerator());
-    indexCuda->copyin(vector<uint32_t>{0, 2});
+    indexCuda->copyin(vector<int>{0, 2});
     cudaRuntime->run(gCuda);
     // cudaPrintTensor(op->getOutput());
@@ -227,10 +227,10 @@ TEST(Gather, Cuda) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
     auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
-    auto index = gCpu->addTensor({3, 1}, DataType::UInt32);
+    auto index = gCpu->addTensor({3, 1}, DataType::Int32);
     gCpu->dataMalloc();
     input->setData(IncrementalGenerator());
-    index->copyin(vector<uint32_t>{0, 3, 1});
+    index->copyin(vector<int>{0, 3, 1});
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cudaRuntime);
@@ -239,7 +239,7 @@ TEST(Gather, Cuda) {
     auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
     gCuda->dataMalloc();
     inputCuda->setData(IncrementalGenerator());
-    indexCuda->copyin(vector<uint32_t>{0, 3, 1});
+    indexCuda->copyin(vector<int>{0, 3, 1});
     cudaRuntime->run(gCuda);
     // cudaPrintTensor(op->getOutput());

View File

@@ -45,6 +45,11 @@ TEST(cuDNN_Unary, run) {
     testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
     testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
+    // more shapes
+    testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
+    testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
+    testUnary<SqrtObj>(IncrementalGenerator(), Shape{2, 3, 4, 5, 6});
 }
 } // namespace infini

View File

@@ -11,8 +11,8 @@ TEST(Gather, ShapeInference) {
     Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
-    Tensor i = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
-    Tensor index = g->addTensor({2, 1, 2}, DataType::UInt32);
+    Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32);
+    Tensor index = g->addTensor({2, 1, 2}, DataType::Int32);
     auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
     EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
 }
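The expected shape follows the ONNX Gather rule: the output replaces the data dimension at `axis` with the entire shape of the indices tensor. For the test above, data {1, 3, 4, 4} with indices {2, 1, 2} at axis 1 yields {1, 2, 1, 2, 4, 4}:

```python
def gather_output_shape(data_shape, indices_shape, axis):
    # ONNX Gather: splice the indices shape in place of data_shape[axis]
    return data_shape[:axis] + indices_shape + data_shape[axis + 1:]

assert gather_output_shape([1, 3, 4, 4], [2, 1, 2], 1) == [1, 2, 1, 2, 4, 4]
```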

View File

@@ -20,6 +20,20 @@ TEST(Split, ShapeInfer) {
         EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
     }

+    {
+        Runtime runtime = NativeCpuRuntimeObj::getInstance();
+        Graph g = make_ref<GraphObj>(runtime);
+        auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);
+        auto op = g->addOp<SplitObj>(input, std::nullopt, -1, 4);
+        EXPECT_EQ(op->numOutputs(), 4);
+        EXPECT_EQ(op->getOutputs().size(), (size_t)4);
+        EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
+        EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3}));
+        EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3}));
+        EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
+    }
+
     {
         Runtime runtime = NativeCpuRuntimeObj::getInstance();
         Graph g = make_ref<GraphObj>(runtime);