forked from jiuyuan/InfiniTensor
impl sqrt on CUDA (#109)
* impl sqrt on CUDA; fix parser of Gather and ReduceMean * fix test_gather * fix test_cuda_gather * impl sqrt cpu and add sqrt to test_cuda_unary * cuda_unary supports arbitrary shapes * fix SplitOp with dim=-1 * fix SplitOp with dim=-1
This commit is contained in:
parent
ef672894d0
commit
48847958d0
4
Makefile
4
Makefile
|
@ -1,6 +1,6 @@
|
||||||
.PHONY : build clean format install-python test-cpp test-onnx
|
.PHONY : build clean format install-python test-cpp test-onnx
|
||||||
|
|
||||||
TYPE ?= release
|
TYPE ?= Release
|
||||||
CUDA ?= OFF
|
CUDA ?= OFF
|
||||||
BANG ?= OFF
|
BANG ?= OFF
|
||||||
INTELCPU ?= off
|
INTELCPU ?= off
|
||||||
|
@ -30,7 +30,7 @@ format:
|
||||||
|
|
||||||
install-python: build
|
install-python: build
|
||||||
cp build/$(TYPE)/backend*.so pyinfinitensor/src/pyinfinitensor
|
cp build/$(TYPE)/backend*.so pyinfinitensor/src/pyinfinitensor
|
||||||
pip install pyinfinitensor/
|
pip install -e pyinfinitensor/
|
||||||
|
|
||||||
test-cpp:
|
test-cpp:
|
||||||
@echo
|
@echo
|
||||||
|
|
|
@ -47,6 +47,7 @@ class GraphHandlerObj {
|
||||||
Tensor tanh(Tensor x, Tensor y);
|
Tensor tanh(Tensor x, Tensor y);
|
||||||
Tensor softmax(Tensor x, Tensor y, int axis);
|
Tensor softmax(Tensor x, Tensor y, int axis);
|
||||||
Tensor abs(Tensor x, Tensor y);
|
Tensor abs(Tensor x, Tensor y);
|
||||||
|
Tensor sqrt(Tensor x, Tensor y);
|
||||||
Tensor shape(Tensor x, Tensor y);
|
Tensor shape(Tensor x, Tensor y);
|
||||||
Tensor identity(Tensor x, Tensor y);
|
Tensor identity(Tensor x, Tensor y);
|
||||||
Tensor flatten(Tensor s, Tensor y, int axis);
|
Tensor flatten(Tensor s, Tensor y, int axis);
|
||||||
|
|
|
@ -3,29 +3,32 @@
|
||||||
#include "operators/unary.h"
|
#include "operators/unary.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
// TODO(constroy): num should be size_t.
|
||||||
void softmax_kernel(float *input, float *output, int num);
|
void softmax_kernel(float *input, float *output, int num);
|
||||||
void relu_kernel(float *input, float *output, int num);
|
void relu_kernel(float *input, float *output, int num);
|
||||||
void sigmoid_kernel(float *input, float *output, int num);
|
void sigmoid_kernel(float *input, float *output, int num);
|
||||||
void tanh_kernel(float *input, float *output, int num);
|
void tanh_kernel(float *input, float *output, int num);
|
||||||
void abs_kernel(float *input, float *output, int num);
|
void abs_kernel(float *input, float *output, int num);
|
||||||
|
void sqrt_kernel(float *input, float *output, int num);
|
||||||
|
|
||||||
void unary_kernel(const Operator &_op) {
|
void unary_kernel(const Operator &_op) {
|
||||||
auto op = as<UnaryObj>(_op);
|
auto op = as<UnaryObj>(_op);
|
||||||
float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
|
float *const inputData = (op->getInputs(0)->getRawDataPtr<float *>());
|
||||||
float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
|
float *const outputData = (op->getOutput()->getRawDataPtr<float *>());
|
||||||
|
|
||||||
auto dim = op->getInputs(0)->getDims();
|
size_t num = op->getOutput()->size();
|
||||||
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
|
|
||||||
if (op->getOpType() == OpType::Softmax)
|
if (op->getOpType() == OpType::Softmax)
|
||||||
softmax_kernel(inputData, outputData, n * c * h * w);
|
softmax_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Relu)
|
else if (op->getOpType() == OpType::Relu)
|
||||||
relu_kernel(inputData, outputData, n * c * h * w);
|
relu_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Sigmoid)
|
else if (op->getOpType() == OpType::Sigmoid)
|
||||||
sigmoid_kernel(inputData, outputData, n * c * h * w);
|
sigmoid_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Tanh)
|
else if (op->getOpType() == OpType::Tanh)
|
||||||
tanh_kernel(inputData, outputData, n * c * h * w);
|
tanh_kernel(inputData, outputData, num);
|
||||||
else if (op->getOpType() == OpType::Abs)
|
else if (op->getOpType() == OpType::Abs)
|
||||||
abs_kernel(inputData, outputData, n * c * h * w);
|
abs_kernel(inputData, outputData, num);
|
||||||
|
else if (op->getOpType() == OpType::Sqrt)
|
||||||
|
sqrt_kernel(inputData, outputData, num);
|
||||||
else
|
else
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
}
|
}
|
||||||
|
|
|
@ -377,6 +377,11 @@ class OnnxStub:
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
)
|
)
|
||||||
|
elif node.op_type == "Sqrt":
|
||||||
|
tensors[node.output[0]] = self.handler.sqrt(
|
||||||
|
tensors[node.input[0]],
|
||||||
|
tensors.get(node.output[0]),
|
||||||
|
)
|
||||||
elif node.op_type == "Shape":
|
elif node.op_type == "Shape":
|
||||||
tensors[node.output[0]] = self.handler.shape(
|
tensors[node.output[0]] = self.handler.shape(
|
||||||
tensors[node.input[0]],
|
tensors[node.input[0]],
|
||||||
|
@ -500,7 +505,7 @@ class OnnxStub:
|
||||||
tensors[node.input[1]],
|
tensors[node.input[1]],
|
||||||
tensors.get(node.output[0]),
|
tensors.get(node.output[0]),
|
||||||
next(
|
next(
|
||||||
(attr.i for attr in node.attribute if attr.name == "axis")
|
(attr.i for attr in node.attribute if attr.name == "axis"), 0
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
elif node.op_type == "ReduceMean":
|
elif node.op_type == "ReduceMean":
|
||||||
|
@ -521,7 +526,8 @@ class OnnxStub:
|
||||||
attr.i
|
attr.i
|
||||||
for attr in node.attribute
|
for attr in node.attribute
|
||||||
if attr.name == "keepdims"
|
if attr.name == "keepdims"
|
||||||
)
|
),
|
||||||
|
1
|
||||||
)
|
)
|
||||||
!= 0,
|
!= 0,
|
||||||
)
|
)
|
||||||
|
|
|
@ -151,6 +151,7 @@ DEFINE_UNARY_METHOD(relu, Relu)
|
||||||
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||||
DEFINE_UNARY_METHOD(tanh, Tanh)
|
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||||
DEFINE_UNARY_METHOD(abs, Abs)
|
DEFINE_UNARY_METHOD(abs, Abs)
|
||||||
|
DEFINE_UNARY_METHOD(sqrt, Sqrt)
|
||||||
DEFINE_UNARY_METHOD(shape, Shape)
|
DEFINE_UNARY_METHOD(shape, Shape)
|
||||||
|
|
||||||
// see operators/reshape.h
|
// see operators/reshape.h
|
||||||
|
|
|
@ -344,6 +344,7 @@ void init_graph_builder(py::module &m) {
|
||||||
.def("tanh", &Handler::tanh, policy::move)
|
.def("tanh", &Handler::tanh, policy::move)
|
||||||
.def("softmax", &Handler::softmax, policy::move)
|
.def("softmax", &Handler::softmax, policy::move)
|
||||||
.def("abs", &Handler::abs, policy::move)
|
.def("abs", &Handler::abs, policy::move)
|
||||||
|
.def("sqrt", &Handler::sqrt, policy::move)
|
||||||
.def("shape", &Handler::shape, policy::move)
|
.def("shape", &Handler::shape, policy::move)
|
||||||
.def("identity", &Handler::identity, policy::move)
|
.def("identity", &Handler::identity, policy::move)
|
||||||
.def("flatten", &Handler::flatten, policy::move)
|
.def("flatten", &Handler::flatten, policy::move)
|
||||||
|
|
|
@ -56,6 +56,10 @@ template <typename T> class NaiveAbs : public NativeUnary<T> {
|
||||||
T doCompute(T val) const override { return val < 0 ? -val : val; }
|
T doCompute(T val) const override { return val < 0 ? -val : val; }
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename T> class NaiveSqrt : public NativeUnary<T> {
|
||||||
|
T doCompute(T val) const override { return std::sqrt(val); }
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T> class Clip : public CpuKernelWithoutConfig {
|
template <typename T> class Clip : public CpuKernelWithoutConfig {
|
||||||
void compute(const Operator &_op,
|
void compute(const Operator &_op,
|
||||||
const RuntimeObj *context) const override {
|
const RuntimeObj *context) const override {
|
||||||
|
@ -91,6 +95,8 @@ REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::UInt32, NaiveAbs<uint32_t>,
|
||||||
"absNaive_CPU_uint32");
|
"absNaive_CPU_uint32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
|
REGISTER_KERNEL(Device::CPU, OpType::Abs, DataType::Float32, NaiveAbs<float>,
|
||||||
"absNaive_CPU_float32");
|
"absNaive_CPU_float32");
|
||||||
|
REGISTER_KERNEL(Device::CPU, OpType::Sqrt, DataType::Float32, NaiveSqrt<float>,
|
||||||
|
"sqrtNaive_CPU_float32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
|
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::UInt32,
|
||||||
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
|
NaiveSoftmax<uint32_t>, "softmaxNaive_CPU_uint32");
|
||||||
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
|
REGISTER_KERNEL(Device::CPU, OpType::Softmax, DataType::Float32,
|
||||||
|
|
|
@ -132,6 +132,8 @@ REGISTER_KERNEL(Device::CUDA, OpType::Tanh, DataType::Float32, TanhCudnn,
|
||||||
"Tanh_CUDA_Float32");
|
"Tanh_CUDA_Float32");
|
||||||
REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
|
REGISTER_KERNEL(Device::CUDA, OpType::Abs, DataType::Float32, UnaryCuda,
|
||||||
"Abs_CUDA_Float32");
|
"Abs_CUDA_Float32");
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::Sqrt, DataType::Float32, UnaryCuda,
|
||||||
|
"Sqrt_CUDA_Float32");
|
||||||
|
|
||||||
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
|
// REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, UnaryCuda,
|
||||||
// "Softmax_CUDA_Float32");
|
// "Softmax_CUDA_Float32");
|
||||||
|
|
|
@ -58,6 +58,14 @@ __global__ void _abs_kernel(float *input, float *output, int n) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__global__ void _sqrt_kernel(float *input, float *output, int n) {
|
||||||
|
int index = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
int stride = blockDim.x * gridDim.x;
|
||||||
|
for (int i = index; i < n; i += stride) {
|
||||||
|
output[i] = sqrt(input[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
void softmax_kernel(float *input, float *output, int num) {
|
void softmax_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
|
@ -90,5 +98,10 @@ void abs_kernel(float *input, float *output, int num) {
|
||||||
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
_abs_kernel<<<blocksize, gridsize>>>(input, output, num);
|
_abs_kernel<<<blocksize, gridsize>>>(input, output, num);
|
||||||
}
|
}
|
||||||
|
void sqrt_kernel(float *input, float *output, int num) {
|
||||||
|
|
||||||
|
int blocksize = block_work_size();
|
||||||
|
int gridsize = (num + block_work_size() - 1) / block_work_size();
|
||||||
|
_sqrt_kernel<<<blocksize, gridsize>>>(input, output, num);
|
||||||
|
}
|
||||||
}; // namespace infini
|
}; // namespace infini
|
||||||
|
|
|
@ -25,7 +25,7 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
|
||||||
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
|
vector<DataType> GatherObj::inferDataType(const TensorVec &inputs) const {
|
||||||
IT_ASSERT(inputs.size() == 2);
|
IT_ASSERT(inputs.size() == 2);
|
||||||
auto index = inputs[1];
|
auto index = inputs[1];
|
||||||
IT_ASSERT(index->getDType() == DataType::UInt32);
|
IT_ASSERT(index->getDType() == DataType::Int32);
|
||||||
return {inputs[0]->getDType()};
|
return {inputs[0]->getDType()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,10 +7,8 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
||||||
std::optional<TensorVec> outputs, int dim, int num)
|
std::optional<TensorVec> outputs, int dim, int num)
|
||||||
: OperatorObj(OpType::Split, {input},
|
: OperatorObj(OpType::Split, {input},
|
||||||
((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
|
((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
|
||||||
dim(dim), num(num), ratio({}) {
|
dim(get_real_axis(dim, input->getRank())), num(num), ratio({}) {
|
||||||
int rank = input->getRank();
|
int dimSize = input->getDims().at(this->dim);
|
||||||
dim = get_real_axis(dim, rank);
|
|
||||||
int dimSize = input->getDims().at(dim);
|
|
||||||
int pieceSize = dimSize / num;
|
int pieceSize = dimSize / num;
|
||||||
int lastSize = dimSize - pieceSize * num;
|
int lastSize = dimSize - pieceSize * num;
|
||||||
|
|
||||||
|
@ -28,9 +26,7 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
|
||||||
const vector<int> &ratio)
|
const vector<int> &ratio)
|
||||||
: OperatorObj(OpType::Split, {input},
|
: OperatorObj(OpType::Split, {input},
|
||||||
((!outputs) ? TensorVec{nullptr} : (*outputs))),
|
((!outputs) ? TensorVec{nullptr} : (*outputs))),
|
||||||
dim(dim), num(-1), ratio(ratio) {
|
dim(get_real_axis(dim, input->getRank())), num(-1), ratio(ratio) {
|
||||||
int rank = input->getRank();
|
|
||||||
dim = get_real_axis(dim, rank);
|
|
||||||
num = ratio.size();
|
num = ratio.size();
|
||||||
if (!outputs) {
|
if (!outputs) {
|
||||||
TensorVec tmp(num, nullptr);
|
TensorVec tmp(num, nullptr);
|
||||||
|
|
|
@ -179,10 +179,10 @@ TEST(Gather, Cuda) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
auto input = gCpu->addTensor({3, 2}, DataType::Float32);
|
auto input = gCpu->addTensor({3, 2}, DataType::Float32);
|
||||||
auto index = gCpu->addTensor({2, 2}, DataType::UInt32);
|
auto index = gCpu->addTensor({2, 2}, DataType::Int32);
|
||||||
gCpu->dataMalloc();
|
gCpu->dataMalloc();
|
||||||
input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
|
input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
|
||||||
index->copyin(vector<uint32_t>{0, 1, 1, 2});
|
index->copyin(vector<int>{0, 1, 1, 2});
|
||||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
@ -191,7 +191,7 @@ TEST(Gather, Cuda) {
|
||||||
auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 0);
|
auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 0);
|
||||||
gCuda->dataMalloc();
|
gCuda->dataMalloc();
|
||||||
inputCuda->copyin(vector<float>{1, 2, 3, 4, 5, 6});
|
inputCuda->copyin(vector<float>{1, 2, 3, 4, 5, 6});
|
||||||
indexCuda->copyin(vector<uint32_t>{0, 1, 1, 2});
|
indexCuda->copyin(vector<int>{0, 1, 1, 2});
|
||||||
cudaRuntime->run(gCuda);
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
// cudaPrintTensor(op->getOutput());
|
// cudaPrintTensor(op->getOutput());
|
||||||
|
@ -203,10 +203,10 @@ TEST(Gather, Cuda) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
auto input = gCpu->addTensor({3, 3}, DataType::Float32);
|
auto input = gCpu->addTensor({3, 3}, DataType::Float32);
|
||||||
auto index = gCpu->addTensor({1, 2}, DataType::UInt32);
|
auto index = gCpu->addTensor({1, 2}, DataType::Int32);
|
||||||
gCpu->dataMalloc();
|
gCpu->dataMalloc();
|
||||||
input->setData(IncrementalGenerator());
|
input->setData(IncrementalGenerator());
|
||||||
index->copyin(vector<uint32_t>{0, 2});
|
index->copyin(vector<int>{0, 2});
|
||||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
@ -215,7 +215,7 @@ TEST(Gather, Cuda) {
|
||||||
auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
|
auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
|
||||||
gCuda->dataMalloc();
|
gCuda->dataMalloc();
|
||||||
inputCuda->setData(IncrementalGenerator());
|
inputCuda->setData(IncrementalGenerator());
|
||||||
indexCuda->copyin(vector<uint32_t>{0, 2});
|
indexCuda->copyin(vector<int>{0, 2});
|
||||||
cudaRuntime->run(gCuda);
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
// cudaPrintTensor(op->getOutput());
|
// cudaPrintTensor(op->getOutput());
|
||||||
|
@ -227,10 +227,10 @@ TEST(Gather, Cuda) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||||
auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
|
auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
|
||||||
auto index = gCpu->addTensor({3, 1}, DataType::UInt32);
|
auto index = gCpu->addTensor({3, 1}, DataType::Int32);
|
||||||
gCpu->dataMalloc();
|
gCpu->dataMalloc();
|
||||||
input->setData(IncrementalGenerator());
|
input->setData(IncrementalGenerator());
|
||||||
index->copyin(vector<uint32_t>{0, 3, 1});
|
index->copyin(vector<int>{0, 3, 1});
|
||||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
Graph gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
|
||||||
|
@ -239,7 +239,7 @@ TEST(Gather, Cuda) {
|
||||||
auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
|
auto op = gCuda->addOp<GatherObj>(inputCuda, indexCuda, nullptr, 1);
|
||||||
gCuda->dataMalloc();
|
gCuda->dataMalloc();
|
||||||
inputCuda->setData(IncrementalGenerator());
|
inputCuda->setData(IncrementalGenerator());
|
||||||
indexCuda->copyin(vector<uint32_t>{0, 3, 1});
|
indexCuda->copyin(vector<int>{0, 3, 1});
|
||||||
cudaRuntime->run(gCuda);
|
cudaRuntime->run(gCuda);
|
||||||
|
|
||||||
// cudaPrintTensor(op->getOutput());
|
// cudaPrintTensor(op->getOutput());
|
||||||
|
|
|
@ -45,6 +45,11 @@ TEST(cuDNN_Unary, run) {
|
||||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||||
|
// more shapes
|
||||||
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{13});
|
||||||
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{4, 3});
|
||||||
|
testUnary<SqrtObj>(IncrementalGenerator(), Shape{2, 3, 4, 5, 6});
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace infini
|
} // namespace infini
|
||||||
|
|
|
@ -11,8 +11,8 @@ TEST(Gather, ShapeInference) {
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
|
||||||
Graph g = make_ref<GraphObj>(runtime);
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
Tensor i = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
|
Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32);
|
||||||
Tensor index = g->addTensor({2, 1, 2}, DataType::UInt32);
|
Tensor index = g->addTensor({2, 1, 2}, DataType::Int32);
|
||||||
auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
|
auto op = g->addOp<GatherObj>(i, index, nullptr, 1);
|
||||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,20 @@ TEST(Split, ShapeInfer) {
|
||||||
EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);
|
||||||
|
|
||||||
|
auto op = g->addOp<SplitObj>(input, std::nullopt, -1, 4);
|
||||||
|
EXPECT_EQ(op->numOutputs(), 4);
|
||||||
|
EXPECT_EQ(op->getOutputs().size(), (size_t)4);
|
||||||
|
EXPECT_EQ(op->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
|
||||||
|
EXPECT_EQ(op->getOutput(1)->getDims(), (Shape{1, 3, 2, 3}));
|
||||||
|
EXPECT_EQ(op->getOutput(2)->getDims(), (Shape{1, 3, 2, 3}));
|
||||||
|
EXPECT_EQ(op->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||||
Graph g = make_ref<GraphObj>(runtime);
|
Graph g = make_ref<GraphObj>(runtime);
|
||||||
|
|
Loading…
Reference in New Issue