From 3e6ef305f17ebaf3c8cd6ffecfc082ad3cd1eae4 Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:06:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A1=86=E6=9E=B6=E6=94=AF=E6=8C=81bert/gpt2?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=9E=84=E5=9B=BE=20(#94)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: support to sqrt op * feat: support to erf op * feat: support to expand op * feat: support to where op * fix: gather op index can be int64_t(hard coding) * fix: some wrong use * style: fix the format style * test: add test for change op * fix: rebase to master * fix: fix matmul b compute wrong * add expand and where kernel * Add int64 support for cuda gather kernel * add test_where.cc * add "expand.(cu/cc,test,cuda),modified where.cu" * Separate initialization of datatypes to avoid compile error * modify where.(cu/cc/h,test), expand and clip * Format fix * Format fix --------- Co-authored-by: xgqdut2016 Co-authored-by: panzezhong Co-authored-by: Haojie Wang --- include/core/data_type.h | 18 +---- include/core/graph_handler.h | 3 + include/cuda/cuda_expand.h | 9 +++ include/cuda/cuda_where.h | 13 ++++ include/cuda/gather.h | 14 ++-- include/operators/expand.h | 36 ++++++++++ include/operators/where.h | 36 ++++++++++ pyinfinitensor/src/pyinfinitensor/onnx.py | 36 +++++++++- pyinfinitensor/tests/test_onnx.py | 30 +++++++++ src/core/data_type.cc | 23 +++++++ src/core/graph_handler.cc | 28 ++++++++ src/ffi/ffi_infinitensor.cc | 18 +++++ src/kernels/cuda/expand.cc | 36 ++++++++++ src/kernels/cuda/expand.cu | 49 ++++++++++++++ src/kernels/cuda/gather.cc | 3 +- src/kernels/cuda/gather.cu | 33 +++++---- src/kernels/cuda/where.cc | 41 ++++++++++++ src/kernels/cuda/where.cu | 82 +++++++++++++++++++++++ src/operators/expand.cc | 41 ++++++++++++ src/operators/gather.cc | 34 +++++++--- src/operators/matmul.cc | 2 +- src/operators/where.cc | 42 ++++++++++++ test/kernels/cuda/test_cuda_expand.cc | 41 ++++++++++++ test/kernels/cuda/test_cuda_gather.cc | 27 +++++++- test/kernels/cuda/test_cuda_where.cc | 63 +++++++++++++++++ test/operators/test_expand.cc | 26 +++++++ test/operators/test_gather.cc | 20 ++++-- test/operators/test_reduce_mean.cc | 13 ++++ test/operators/test_where.cc | 46 +++++++++++++ 29 files changed, 804 insertions(+), 59 deletions(-) create mode 100644 include/cuda/cuda_expand.h create mode 100644 include/cuda/cuda_where.h create mode 100644 include/operators/expand.h create mode 100644 include/operators/where.h create mode 100644 src/core/data_type.cc create mode 100644 src/kernels/cuda/expand.cc create mode 100644 src/kernels/cuda/expand.cu create mode 100644 src/kernels/cuda/where.cc create mode 100644 src/kernels/cuda/where.cu create mode 100644 src/operators/expand.cc create mode 100644 src/operators/where.cc create mode 100644 test/kernels/cuda/test_cuda_expand.cc create mode 100644 test/kernels/cuda/test_cuda_where.cc create mode 100644 test/operators/test_expand.cc create mode 100644 test/operators/test_where.cc diff --git a/include/core/data_type.h b/include/core/data_type.h index eb6a6a8d..0b7c1fa0 100644 --- a/include/core/data_type.h +++ b/include/core/data_type.h @@ -1,3 +1,4 @@ +#pragma once #include "core/common.h" namespace infini { @@ -69,23 +70,6 @@ class DataType { int getIndex() const { return index; } }; -// to be consistent with onnx -// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484 -inline const DataType DataType::Undefine(0); -inline const DataType DataType::Float32(1); -inline const DataType DataType::UInt8(2); -inline const DataType DataType::Int8(3); -inline const DataType DataType::UInt16(4); -inline const DataType DataType::Int16(5); -inline const DataType DataType::Int32(6); -inline const DataType DataType::Int64(7); -inline const DataType DataType::String(8); -inline const DataType DataType::Bool(9); -inline const DataType DataType::Float16(10); -inline const DataType DataType::Double(11); -inline const DataType DataType::UInt32(12); -inline const DataType DataType::UInt64(13); -inline const DataType DataType::BFloat16(16); // Method definitions are out of the declaration due to GCC bug: // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc template <> inline int DataType::get() { return 0; } diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index db29d1c6..49fa2347 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -47,6 +47,7 @@ class GraphHandlerObj { Tensor relu(Tensor x, Tensor y); Tensor sigmoid(Tensor x, Tensor y); Tensor tanh(Tensor x, Tensor y); + Tensor erf(Tensor x, Tensor y); Tensor softmax(Tensor x, Tensor y, int axis); Tensor abs(Tensor x, Tensor y); Tensor sqrt(Tensor x, Tensor y); @@ -70,6 +71,8 @@ class GraphHandlerObj { Tensor pad(Tensor input, Tensor output, const vector &pads, const optional> &axes); Tensor cast(Tensor input, Tensor output, int to); + Tensor expand(Tensor input, Tensor output, Shape dims); + Tensor where(Tensor inputX, Tensor inputY, Tensor condition, Tensor output); //------ modifiers diff --git a/include/cuda/cuda_expand.h b/include/cuda/cuda_expand.h new file mode 100644 index 00000000..b53c4ce4 --- /dev/null +++ b/include/cuda/cuda_expand.h @@ -0,0 +1,9 @@ +#pragma once + +#include "operators/unary.h" +#include "utils/small_array.h" +namespace infini { +void expand_kernel(float *input, float *output, int nDims, int outputsize, + SmallArray inputShape, SmallArray outputShape); + +}; // namespace infini diff --git a/include/cuda/cuda_where.h b/include/cuda/cuda_where.h new file mode 100644 index 00000000..14d9bc73 --- /dev/null +++ b/include/cuda/cuda_where.h @@ -0,0 +1,13 @@ +#pragma once +#include "operators/unary.h" +#include "utils/small_array.h" + +namespace infini { +void where_kernel(const float *inputx, const float *inputy, + const float *condition, float *output, int nDims, + infini::SmallArray inputxShape, + infini::SmallArray inputyShape, + infini::SmallArray conditionShape, + infini::SmallArray outputShape); + +}; // namespace infini diff --git a/include/cuda/gather.h b/include/cuda/gather.h index 0cf45142..f3e0956a 100644 --- a/include/cuda/gather.h +++ b/include/cuda/gather.h @@ -1,7 +1,10 @@ #pragma once +#include "core/data_type.h" -typedef struct { - int *indexValue; +namespace infini { +struct GatherMetaData { + void *indexValue; + DataType indexType; int axis; int inNDim; int outNDim; @@ -10,8 +13,7 @@ typedef struct { int idxDim[4]; int idxStride[4]; int inStride[4]; -} GatherMetaData; +}; -namespace infini { -void gather_kernel(float *in, float *out, GatherMetaData metaData, int num); -} +void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num); +} // namespace infini diff --git a/include/operators/expand.h b/include/operators/expand.h new file mode 100644 index 00000000..8a3558ca --- /dev/null +++ b/include/operators/expand.h @@ -0,0 +1,36 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +/** + * @brief Broadcast the input tensor following the given shape and the + * broadcast rule. + * + */ +class ExpandObj : public OperatorObj { + Shape dims; + + public: + /** + * @brief Construct a new Expand object. + * @param graph The computation graph that this operator belongs to. + * @param input The input tensor. + * @param output The output tensor. + * @param dims The shape you want to expand to, following the broadcast + * rule. + */ + ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims); + OP_CLONE(ExpandObj); + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + int numInputs() const override { return 1; } + int numOutputs() const override { return 1; } + Shape getShape() const { return dims; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; + +} // namespace infini diff --git a/include/operators/where.h b/include/operators/where.h new file mode 100644 index 00000000..6422fe34 --- /dev/null +++ b/include/operators/where.h @@ -0,0 +1,36 @@ +#pragma once +#include "core/operator.h" + +namespace infini { +/** + * @brief Return elements, either from X or Y, depending on condition. + * + */ +class WhereObj : public OperatorObj { + + public: + /** + * @brief Construct a new Where object. + * + * @param graph The computation graph that this operator belongs to. + * @param inputX The input tensor X. + * @param inputY The input tensor Y. + * @param output The output tensor. + * @param condition The condition tensor. + */ + WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY, Tensor condition, + Tensor output); + OP_CLONE(WhereObj); + + optional> inferShape(const TensorVec &inputs) const override; + + std::string toString() const override; + int numInputs() const override { return inputs.size(); } + int numOutputs() const override { return 1; } + + private: + vector getWorkloadVector() const override; + vector getOpAttrVector() const override; +}; + +} // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index ced54ae7..17cdb8fe 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -409,7 +409,8 @@ class OnnxStub: tensors[node.input[0]], tensors.get(node.output[0]), next( - (attr.i for attr in node.attribute if attr.name == "axis") + (attr.i for attr in node.attribute if attr.name == "axis"), + 1, ), ) elif node.op_type == "PRelu": @@ -517,7 +518,8 @@ class OnnxStub: tensors[node.input[1]], tensors.get(node.output[0]), next( - (attr.i for attr in node.attribute if attr.name == "axis"), 0 + (attr.i for attr in node.attribute if attr.name == "axis"), + 0, ), ) elif node.op_type == "ReduceMean": @@ -539,7 +541,7 @@ class OnnxStub: for attr in node.attribute if attr.name == "keepdims" ), - 1 + 1, ) != 0, ) @@ -589,6 +591,25 @@ class OnnxStub: tensors.get(node.output[0]), next((attr.i for attr in node.attribute if attr.name == "to")), ) + elif node.op_type == "Expand": + shape = _parse_data(data[node.input[1]]) + tensors[node.output[0]] = self.handler.expand( + tensors[node.input[0]], + tensors.get(node.output[0]), + shape, + ) + elif node.op_type == "Erf": + tensors[node.output[0]] = self.handler.erf( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Where": + tensors[node.output[0]] = self.handler.where( + tensors[node.input[1]], + tensors[node.input[2]], + tensors[node.input[0]], + tensors.get(node.output[0]), + ) else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) new_node_name.append(node.name) @@ -814,6 +835,8 @@ class OnnxStub: backend.OpTypeId.Abs, backend.OpTypeId.Identity, backend.OpTypeId.PRelu, + backend.OpTypeId.Sqrt, + backend.OpTypeId.Erf, ]: ctx.push_node(make_node(ty.name, inputs, outputs, name)) elif ty == backend.OpTypeId.Flatten: @@ -904,6 +927,13 @@ class OnnxStub: elif ty == backend.OpTypeId.Cast: to = backend.cast_to_of(op) ctx.push_node(make_node(ty.name, inputs, outputs, name, to=to)) + elif ty == backend.OpTypeId.Where: + assert len(inputs) == 3, "Check Where Op must have three inputs." + new_inputs = [inputs[2], inputs[0], inputs[1]] + ctx.push_node(make_node(ty.name, new_inputs, outputs, name)) + elif ty == backend.OpTypeId.Expand: + shape = backend.expand_shape_of(op) + ctx.push_node(make_node(ty.name, inputs, outputs, name, shape=shape)) else: raise Exception("Unsupported OpType", ty) diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 497bae9b..28134dc4 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -207,6 +207,18 @@ class TestStringMethods(unittest.TestCase): relu = make_node("Relu", ["x"], ["y"], name="relu") make_and_import_model(make_graph([relu], "relu", [x], [y])) + def test_erf(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + erf = make_node("Erf", ["x"], ["y"], name="erf") + make_and_import_model(make_graph([erf], "erf", [x], [y])) + + def test_sqrt(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + sqrt = make_node("Sqrt", ["x"], ["y"], name="sqrt") + make_and_import_model(make_graph([sqrt], "sqrt", [x], [y])) + def test_sigmoid(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) @@ -352,6 +364,24 @@ class TestStringMethods(unittest.TestCase): ) make_and_import_model(make_graph([cast], "cast", [input1], [output])) + def test_expand(self): + data = make_tensor_value_info("data", TensorProto.FLOAT, [3, 1]) + dim = make_tensor_value_info("dim", TensorProto.INT64, [3]) + dim_data = make_tensor("dim", TensorProto.INT64, [3], [2, 1, 6]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, 6]) + expand = make_node("Expand", ["data", "dim"], ["output"], name="expand") + make_and_import_model( + make_graph([expand], "expand", [data, dim], [output], [dim_data]) + ) + + def test_where(self): + x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) + y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7]) + con = make_tensor_value_info("con", TensorProto.BOOL, [1, 3, 5, 7]) + output = make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 5, 7]) + where = make_node("Where", ["x", "y", "con"], ["output"], name="where") + make_and_import_model(make_graph([where], "where", [x, y, con], [output])) + if __name__ == "__main__": unittest.main() diff --git a/src/core/data_type.cc b/src/core/data_type.cc new file mode 100644 index 00000000..7b4d2aa4 --- /dev/null +++ b/src/core/data_type.cc @@ -0,0 +1,23 @@ +#include "core/data_type.h" + +namespace infini { +// Move implementation here to avoid compile time error on some platform +// to be consistent with onnx +// https://github.com/onnx/onnx/blob/aeb21329122b96df1d3ef33b500a35ca140b1431/onnx/onnx.proto#L484 +const DataType DataType::Undefine(0); +const DataType DataType::Float32(1); +const DataType DataType::UInt8(2); +const DataType DataType::Int8(3); +const DataType DataType::UInt16(4); +const DataType DataType::Int16(5); +const DataType DataType::Int32(6); +const DataType DataType::Int64(7); +const DataType DataType::String(8); +const DataType DataType::Bool(9); +const DataType DataType::Float16(10); +const DataType DataType::Double(11); +const DataType DataType::UInt32(12); +const DataType DataType::UInt64(13); +// TODO: Reserved for complex data type. +const DataType DataType::BFloat16(16); +} // namespace infini diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 40fb42df..87ed6f46 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -3,6 +3,7 @@ #include "operators/concat.h" #include "operators/conv.h" #include "operators/element_wise.h" +#include "operators/expand.h" #include "operators/gather.h" #include "operators/matmul.h" #include "operators/pad.h" @@ -14,6 +15,7 @@ #include "operators/split.h" #include "operators/transpose.h" #include "operators/unary.h" +#include "operators/where.h" namespace infini { @@ -155,6 +157,7 @@ DEFINE_UNARY_METHOD(tanh, Tanh) DEFINE_UNARY_METHOD(abs, Abs) DEFINE_UNARY_METHOD(sqrt, Sqrt) DEFINE_UNARY_METHOD(shape, Shape) +DEFINE_UNARY_METHOD(erf, Erf) // see operators/reshape.h DEFINE_UNARY_METHOD(identity, Identity) @@ -309,6 +312,31 @@ Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) { } } +Tensor GraphHandlerObj::expand(Tensor input, Tensor output, Shape dims) { + if (output) { + g->addOpWithOutputs(std::move(input), output, + std::move(dims)); + return output; + } else { + return g->addOp(std::move(input), output, std::move(dims)) + ->getOutput(); + } +} + +Tensor GraphHandlerObj::where(Tensor inputX, Tensor inputY, Tensor condition, + Tensor output) { + if (output) { + g->addOpWithOutputs(std::move(inputX), std::move(inputY), + std::move(condition), output); + return output; + } else { + return g + ->addOp(std::move(inputX), std::move(inputY), + std::move(condition), output) + ->getOutput(); + } +} + static CastType inferCastType(Tensor input, int to) { auto iType = input->getDType(); auto oType = DataType(to); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 9289829f..d62e57f6 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -2,6 +2,7 @@ #include "operators/batch_norm.h" #include "operators/concat.h" #include "operators/conv.h" +#include "operators/expand.h" #include "operators/gather.h" #include "operators/matmul.h" #include "operators/pad.h" @@ -96,6 +97,10 @@ void export_values(py::module &m) { .VALUE(OpType, Resize) .VALUE(OpType, Dropout) .VALUE(OpType, Cast) + .VALUE(OpType, Sqrt) + .VALUE(OpType, Expand) + .VALUE(OpType, Erf) + .VALUE(OpType, Where) .export_values(); #undef VALUE @@ -226,6 +231,15 @@ static vector reshape_shape_of(Operator op) { return ans; } +static vector expand_shape_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Expand); + auto shape = dynamic_cast(op.get())->getShape(); + vector ans(shape.size()); + std::transform(shape.begin(), shape.end(), ans.begin(), + [](auto x) { return static_cast(x); }); + return ans; +} + static vector pad_pads_of(Operator op) { IT_ASSERT(op->getOpType() == OpType::Pad); auto shape = dynamic_cast(op.get())->getPads(); @@ -276,6 +290,7 @@ void export_functions(py::module &m) { .FUNCTION(reduce_mean_attrs_of) .FUNCTION(tensor_dtype) .FUNCTION(reshape_shape_of) + .FUNCTION(expand_shape_of) .FUNCTION(pad_pads_of) .FUNCTION(transpose_permute_of) .FUNCTION(concat_axis_of) @@ -359,6 +374,9 @@ void init_graph_builder(py::module &m) { .def("slice", &Handler::slice, policy::move) .def("pad", &Handler::pad, policy::move) .def("cast", &Handler::cast, policy::move) + .def("expand", &Handler::expand, policy::move) + .def("erf", &Handler::erf, policy::move) + .def("where", &Handler::where, policy::move) .def("topo_sort", &Handler::topo_sort, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic) .def("operators", &Handler::operators, policy::move) diff --git a/src/kernels/cuda/expand.cc b/src/kernels/cuda/expand.cc new file mode 100644 index 00000000..b8154d49 --- /dev/null +++ b/src/kernels/cuda/expand.cc @@ -0,0 +1,36 @@ +#include "operators/expand.h" +#include "cuda/cuda_expand.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" + +namespace infini { + +class ExpandCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + + void *const inputData = (op->getInputs(0)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + const auto &in_Shape = op->getInputs(0)->getDims(); // input shape + const auto &out_Shape = op->getShape(); // output shape + + SmallArray inputShape, outputShape; + int nDims = op->getInputs(0)->getDims().size(); + + IT_ASSERT(nDims <= SMALL_ARRAY_SIZE); + int outputsize = 1; // the length of the output vector after flatten + for (int i = 0; i < nDims; ++i) { + outputShape.data[i] = out_Shape[i]; + inputShape.data[i] = in_Shape[i]; + outputsize *= out_Shape[i]; + } + expand_kernel((float *)inputData, (float *)outputData, nDims, + outputsize, inputShape, outputShape); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Expand, DataType::Float32, ExpandCuda, + "Expand_CUDA_Float32"); + +}; // namespace infini diff --git a/src/kernels/cuda/expand.cu b/src/kernels/cuda/expand.cu new file mode 100644 index 00000000..e1649b81 --- /dev/null +++ b/src/kernels/cuda/expand.cu @@ -0,0 +1,49 @@ +#include "core/common.h" +#include "cuda/cuda_common.h" +#include "utils/small_array.h" + +constexpr unsigned int num_threads() { return 32 * 4; } +constexpr int thread_work_size() { return 4; } +constexpr int block_work_size() { return thread_work_size() * num_threads(); } + +__global__ void _expand_kernel(float *input, float *output, int nDims, + int outputsize, infini::SmallArray inputShape, + infini::SmallArray outputShape) { + + int outputIdx = + blockIdx.x * blockDim.x + threadIdx.x; // i(JKS) + j(KS) + k(S) + s + if (outputIdx < outputsize) { + int inputIdx = 0; // record input index + int temp = 1; // stored S, KS, JKS, in order + int tmp = 1; // stored s,k,j,i in order + int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s + for (int i = nDims - 1; i >= 0; --i) { + if (i == 0) { + tmp = v; // i = outputIdx/(JKS) + } else { + tmp = v % outputShape.data[i]; // store s,k,j in order + } + if (inputShape.data[i] == + 1) { // if input shape = 1, the index only equal 0 + inputIdx += 0; + } else { + inputIdx += + tmp * temp; // otherwise +i(JKS) or j(KS) or k(S) or s + } + temp *= inputShape.data[i]; + v = v / outputShape.data[i]; + } + output[outputIdx] = input[inputIdx]; + } +} + +namespace infini { +void expand_kernel(float *input, float *output, int nDims, int outputsize, + SmallArray inputShape, SmallArray outputShape) { + int blocksize = block_work_size(); + int gridsize = (outputsize + block_work_size() - 1) / block_work_size(); + _expand_kernel<<>>(input, output, nDims, outputsize, + inputShape, outputShape); +} + +} // namespace infini diff --git a/src/kernels/cuda/gather.cc b/src/kernels/cuda/gather.cc index d769440e..e438db99 100644 --- a/src/kernels/cuda/gather.cc +++ b/src/kernels/cuda/gather.cc @@ -12,7 +12,8 @@ class GatherCuda : public CudaKernelWithoutConfig { auto in = op->getInputs(0); auto index = op->getInputs(1); auto out = op->getOutput(); - metaData.indexValue = index->getRawDataPtr(); + metaData.indexValue = index->getRawDataPtr(); + metaData.indexType = index->getDType(); metaData.axis = op->getAxis(); metaData.inNDim = in->getRank(); metaData.outNDim = out->getRank(); diff --git a/src/kernels/cuda/gather.cu b/src/kernels/cuda/gather.cu index 0ae69085..8ffeeac9 100644 --- a/src/kernels/cuda/gather.cu +++ b/src/kernels/cuda/gather.cu @@ -1,19 +1,21 @@ #include "cuda/cuda_common.h" #include "cuda/gather.h" -__device__ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) { - int offset = 0; +template +__device__ T gatheredOffset2Offset(int gOffset, + infini::GatherMetaData metaData) { + T offset = 0; for (int i = metaData.inNDim - 1, k = metaData.outNDim - 1; i >= 0; --i) { - int idx = 0; + T idx = 0; if (i == metaData.axis) { - int idxOffset = 0; + T idxOffset = 0; for (int j = metaData.idxNDim - 1; j >= 0; --j) { - int p = gOffset % metaData.idxDim[j]; + T p = gOffset % metaData.idxDim[j]; gOffset = gOffset / metaData.idxDim[j]; idxOffset += p * metaData.idxStride[j]; } - idx = metaData.indexValue[idxOffset]; + idx = static_cast(metaData.indexValue)[idxOffset]; k = k - metaData.idxNDim; } else { @@ -26,22 +28,27 @@ __device__ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) { return offset; } -__global__ void _gather_kernel(float *in, float *out, GatherMetaData metaData, - int num) { - int tid = threadIdx.x + blockIdx.x * blockDim.x; +template +__global__ void _gather_kernel(float *in, float *out, + infini::GatherMetaData metaData, size_t num) { + T tid = threadIdx.x + blockIdx.x * blockDim.x; int stride = blockDim.x * gridDim.x; while (tid < num) { - int offset = gatheredOffset2Offset(tid, metaData); + T offset = gatheredOffset2Offset(tid, metaData); out[tid] = in[offset]; tid += stride; } } namespace infini { -void gather_kernel(float *in, float *out, GatherMetaData metaData, int num) { +void gather_kernel(float *in, float *out, GatherMetaData metaData, size_t num) { int blockSize = 32 * 16; int gridSize = (num + blockSize - 1) / blockSize; - - _gather_kernel<<>>(in, out, metaData, num); + if (metaData.indexType == DataType::Int64) { + _gather_kernel + <<>>(in, out, metaData, num); + } else { + _gather_kernel<<>>(in, out, metaData, num); + } } } // namespace infini diff --git a/src/kernels/cuda/where.cc b/src/kernels/cuda/where.cc new file mode 100644 index 00000000..4769fea0 --- /dev/null +++ b/src/kernels/cuda/where.cc @@ -0,0 +1,41 @@ +#include "operators/where.h" +#include "cuda/cuda_kernel_wihtout_config.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_where.h" + +namespace infini { + +class WhereCuda : public CudaKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + + void *const inputxData = (op->getInputs(0)->getRawDataPtr()); + void *const inputyData = (op->getInputs(1)->getRawDataPtr()); + void *const conditionData = (op->getInputs(2)->getRawDataPtr()); + void *const outputData = (op->getOutput()->getRawDataPtr()); + const auto &inputx_Shape = op->getInputs(0)->getDims(); + const auto &inputy_Shape = op->getInputs(1)->getDims(); + const auto &condition_Shape = op->getInputs(2)->getDims(); + const auto &output_Shape = op->getOutput()->getDims(); + + int nDims = op->getInputs(0)->getDims().size(); + IT_ASSERT(nDims <= SMALL_ARRAY_SIZE); + + SmallArray inputxShape, inputyShape, conditionShape, outputShape; + for (int i = 0; i < nDims; ++i) { + inputxShape.data[i] = inputx_Shape[i]; + inputyShape.data[i] = inputy_Shape[i]; + conditionShape.data[i] = condition_Shape[i]; + outputShape.data[i] = output_Shape[i]; + } + where_kernel((float *)inputxData, (float *)inputyData, + (float *)conditionData, (float *)outputData, nDims, + inputxShape, inputyShape, conditionShape, outputShape); + } +}; + +REGISTER_KERNEL(Device::CUDA, OpType::Where, DataType::Float32, WhereCuda, + "Where_CUDA_Float32"); + +}; // namespace infini diff --git a/src/kernels/cuda/where.cu b/src/kernels/cuda/where.cu new file mode 100644 index 00000000..7d34098c --- /dev/null +++ b/src/kernels/cuda/where.cu @@ -0,0 +1,82 @@ +#include "cuda/cuda_common.h" +#include "utils/small_array.h" + +__global__ void _where_kernel(const float *inputx, const float *inputy, + const float *condition, float *output, int nDims, + int outputsize, infini::SmallArray inputxShape, + infini::SmallArray inputyShape, + infini::SmallArray conditionShape, + infini::SmallArray outputShape) { + + int outputIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (outputIdx < outputsize) { + int inputxIdx = 0; + int temp_inputx = 1; + + int inputyIdx = 0; + int temp_inputy = 1; + + int conditionIdx = 0; + int temp_condition = 1; + + int tmp = 1; // stored s,k,j,i in order + int v = outputIdx; // v = i(JKS) + j(KS) + k(S) + s + for (int i = nDims - 1; i >= 0; --i) { + if (i == 0) { + tmp = v; // i = outputIdx/(JKS) + } else { + tmp = v % outputShape.data[i]; // store s,k,j in order + } + if (inputxShape.data[i] == 1) { + inputxIdx += 0; + } else { + inputxIdx += + tmp * + temp_inputx; // otherwise +i(JKS) or j(KS) or k(S) or s + } + temp_inputx *= inputxShape.data[i]; + //---------------------------- + if (inputyShape.data[i] == 1) { + inputyIdx += 0; + } else { + inputyIdx += + tmp * + temp_inputy; // otherwise +i(JKS) or j(KS) or k(S) or s + } + temp_inputy *= inputyShape.data[i]; + //-------------------------- + if (conditionShape.data[i] == 1) { + conditionIdx += 0; + } else { + conditionIdx += + tmp * + temp_condition; // otherwise +i(JKS) or j(KS) or k(S) or s + } + temp_condition *= conditionShape.data[i]; + //------------------------- + v = v / outputShape.data[i]; + } + output[outputIdx] = + condition[conditionIdx] ? inputx[inputxIdx] : inputy[inputyIdx]; + } +} + +namespace infini { +void where_kernel(const float *inputx, const float *inputy, + const float *condition, float *output, int nDims, + infini::SmallArray inputxShape, + infini::SmallArray inputyShape, + infini::SmallArray conditionShape, + infini::SmallArray outputShape) { + int outputsize = 1; + + for (int i = 0; i < nDims; i++) { + outputsize *= outputShape.data[i]; + } + int blocksize = 32 * 16; + int gridsize = (outputsize + blocksize - 1) / blocksize; + _where_kernel<<>>( + inputx, inputy, condition, output, nDims, outputsize, inputxShape, + inputyShape, conditionShape, outputShape); +} +} // namespace infini diff --git a/src/operators/expand.cc b/src/operators/expand.cc new file mode 100644 index 00000000..faebb34a --- /dev/null +++ b/src/operators/expand.cc @@ -0,0 +1,41 @@ +#include "operators/expand.h" +#include "utils/operator_utils.h" + +namespace infini { + +ExpandObj::ExpandObj(GraphObj *graph, Tensor input, Tensor output, Shape dims) + : OperatorObj(OpType::Expand, {input}, {output}), dims(std::move(dims)) { + IT_ASSERT(checkValid(graph)); +} + +optional> ExpandObj::inferShape(const TensorVec &inputs) const { + auto shape_input = inputs[0]->getDims(); + Shape ret = infer_broadcast(shape_input, dims); + return {{ret}}; +} + +std::string ExpandObj::toString() const { + std::ostringstream os; + os << "Expand[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[0]->getDims()) << ","; + os << "dims=" << vecToString(dims) << ","; + os << "input=" << inputs[0]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector ExpandObj::getWorkloadVector() const { + vector ret = inputs[0]->getDims(); + ret.insert(ret.end(), dims.begin(), dims.end()); + ret.emplace(ret.begin(), type.underlying()); + return ret; +} + +vector ExpandObj::getOpAttrVector() const { + vector ret = dims; + ret.emplace(ret.begin(), type.underlying()); + return ret; +} + +} // namespace infini diff --git a/src/operators/gather.cc b/src/operators/gather.cc index aa9ef79d..f615faf7 100644 --- a/src/operators/gather.cc +++ b/src/operators/gather.cc @@ -24,8 +24,8 @@ optional> GatherObj::inferShape(const TensorVec &inputs) const { vector GatherObj::inferDataType(const TensorVec &inputs) const { IT_ASSERT(inputs.size() == 2); - auto index = inputs[1]; - IT_ASSERT(index->getDType() == DataType::Int32); + auto index_dtype = inputs[1]->getDType(); + IT_ASSERT(index_dtype == DataType::Int32 || index_dtype == DataType::Int64) return {inputs[0]->getDType()}; } @@ -36,19 +36,31 @@ bool GatherObj::CheckIndexValid() const { return true; Runtime runtime = NativeCpuRuntimeObj::getInstance(); - int *data = (int *)runtime->alloc(index->getBytes()); - index->getRuntime()->copyBlobToCPU( - (void *)data, index->getRawDataPtr(), index->getBytes()); - bool ret = true; auto value = inputs[0]->getDims()[axis]; - for (size_t i = 0; i < index->size(); ++i) { - if (data[i] < 0 || data[i] >= value) { - ret = false; - break; + if (index->getDType() == DataType::Int32) { + int *data = (int *)runtime->alloc(index->getBytes()); + index->getRuntime()->copyBlobToCPU( + (void *)data, index->getRawDataPtr(), index->getBytes()); + for (size_t i = 0; i < index->size(); ++i) { + if (data[i] < 0 || data[i] >= value) { + ret = false; + break; + } } + runtime->dealloc(data); + } else { + int64_t *data = (int64_t *)runtime->alloc(index->getBytes()); + index->getRuntime()->copyBlobToCPU( + (void *)data, index->getRawDataPtr(), index->getBytes()); + for (size_t i = 0; i < index->size(); ++i) { + if (data[i] < 0 || data[i] >= value) { + ret = false; + break; + } + } + runtime->dealloc(data); } - runtime->dealloc(data); return ret; } diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc index 963dd591..00207e77 100644 --- a/src/operators/matmul.cc +++ b/src/operators/matmul.cc @@ -20,7 +20,7 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA, if (ret.empty()) { b = 1; } else { - b = std::accumulate(ret.begin(), ret.end(), 1); + b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies()); } auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin()); auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1); diff --git a/src/operators/where.cc b/src/operators/where.cc new file mode 100644 index 00000000..290ca7c6 --- /dev/null +++ b/src/operators/where.cc @@ -0,0 +1,42 @@ +#include "operators/where.h" +#include "utils/operator_utils.h" + +namespace infini { + +WhereObj::WhereObj(GraphObj *graph, Tensor inputX, Tensor inputY, + Tensor condition, Tensor output) + : OperatorObj(OpType::Where, TensorVec{inputX, inputY, condition}, + {output}) { + IT_ASSERT(checkValid(graph)); +} + +optional> WhereObj::inferShape(const TensorVec &inputs) const { + auto shapeX = inputs[0]->getDims(); + auto shapeY = inputs[1]->getDims(); + auto shapeCon = inputs[2]->getDims(); + auto retXY = infer_broadcast(shapeX, shapeY); + auto ret = infer_broadcast(retXY, shapeCon); + return {{ret}}; +} + +std::string WhereObj::toString() const { + std::ostringstream os; + os << "Where[" << getGuid() << "]"; + os << "("; + os << vecToString(inputs[2]->getDims()) << ","; + os << "inputX=" << inputs[0]->getGuid() << ","; + os << "inputY=" << inputs[1]->getGuid() << ","; + os << "condition=" << inputs[2]->getGuid() << ","; + os << "output=" << outputs[0]->getGuid() << ")"; + return os.str(); +} + +vector WhereObj::getWorkloadVector() const { + vector ret = getOutput()->getDims(); + ret.emplace(ret.begin(), type.underlying()); + return ret; +} + +vector WhereObj::getOpAttrVector() const { return {type.underlying()}; } + +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_expand.cc b/test/kernels/cuda/test_cuda_expand.cc new file mode 100644 index 00000000..fb9b350a --- /dev/null +++ b/test/kernels/cuda/test_cuda_expand.cc @@ -0,0 +1,41 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/expand.h" + +#include "test.h" + +namespace infini { + +TEST(Expand, Cuda) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({2, 1, 2, 1}, DataType::Float32); + + gCpu->dataMalloc(); + t1->setData(IncrementalGenerator()); + t1->printData(); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto t1Gpu = gCuda->cloneTensor(t1); + + auto op = gCuda->addOp(t1Gpu, nullptr, Shape{2, 2, 2, 3}); + gCuda->dataMalloc(); + t1Gpu->setData(IncrementalGenerator()); + + cudaRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + oCpu->printData(); + EXPECT_TRUE( + oCpu->equalData(vector{0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, + 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3})); +} + +} // namespace infini diff --git a/test/kernels/cuda/test_cuda_gather.cc b/test/kernels/cuda/test_cuda_gather.cc index 33863406..d1262260 100644 --- a/test/kernels/cuda/test_cuda_gather.cc +++ b/test/kernels/cuda/test_cuda_gather.cc @@ -77,7 +77,7 @@ int gatheredOffset2Offset(int gOffset, GatherMetaData metaData) { idxOffset += p * metaData.idxStride[j]; } - idx = metaData.indexValue[idxOffset]; + idx = static_cast(metaData.indexValue)[idxOffset]; k = k - metaData.idxNDim; } else { @@ -242,6 +242,31 @@ TEST(Gather, Cuda) { indexCuda->copyin(vector{0, 3, 1}); cudaRuntime->run(gCuda); + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData( + vector{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11})); + } + { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32); + auto index = gCpu->addTensor({3, 1}, DataType::Int64); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + index->copyin(vector{0, 3, 1}); + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputCuda = gCuda->cloneTensor(input); + auto indexCuda = gCuda->cloneTensor(index); + auto op = gCuda->addOp(inputCuda, indexCuda, nullptr, 1); + gCuda->dataMalloc(); + inputCuda->setData(IncrementalGenerator()); + indexCuda->copyin(vector{0, 3, 1}); + cudaRuntime->run(gCuda); + // cudaPrintTensor(op->getOutput()); // copy output from CUDA to CPU auto oCpu = gCpu->cloneTensor(op->getOutput()); diff --git a/test/kernels/cuda/test_cuda_where.cc b/test/kernels/cuda/test_cuda_where.cc new file mode 100644 index 00000000..61515445 --- /dev/null +++ b/test/kernels/cuda/test_cuda_where.cc @@ -0,0 +1,63 @@ +#include "core/graph.h" +#include "core/runtime.h" +#include "cuda/cuda_runtime.h" +#include "cuda/cuda_utility.h" +#include "operators/where.h" + +#include "test.h" + +namespace infini { + +void test_where(const Shape &inputxshape, const vector &inputxdata, + const Shape &inputyshape, const vector &inputydata, + const Shape &conditionshape, const vector &conditiondata, + const vector &ExpectData) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + auto condition = gCpu->addTensor(conditionshape, DataType::Int32); + auto inputx = gCpu->addTensor(inputxshape, DataType::Float32); + auto inputy = gCpu->addTensor(inputyshape, DataType::Float32); + + gCpu->dataMalloc(); + condition->copyin(conditiondata); // + inputx->copyin(inputxdata); + inputy->copyin(inputydata); // + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto conditionGpu = gCuda->cloneTensor(condition); + auto inputxGpu = gCuda->cloneTensor(inputx); + auto inputyGpu = gCuda->cloneTensor(inputy); + + auto op = gCuda->addOp(inputxGpu, inputyGpu, conditionGpu, + nullptr); // WhereObj + gCuda->dataMalloc(); + conditionGpu->copyin(conditiondata); + inputxGpu->copyin(inputxdata); + inputyGpu->copyin(inputydata); + cudaRuntime->run(gCuda); + + auto oCpu = gCpu->cloneTensor(op->getOutput()); // move data from gpu to cpu + oCpu->printData(); //->printData + EXPECT_TRUE(oCpu->equalData(ExpectData)); +} + +TEST(CUDA_Where, run) { + test_where( + Shape{2, 2, 3, 1}, vector{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + Shape{2, 2, 3, 1}, vector{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + Shape{2, 2, 3, 1}, vector{0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1}, + vector{0., 1., 2., 0., 0., 0., 6., 7., 0., 9., 10., 11.}); + + test_where(Shape{2, 1, 1, 3}, // inputx + vector{0, 1, 2, 3, 4, 5}, Shape{1, 2, 1, 1}, // inputy + vector{1, 1}, Shape{2, 1, 3, 1}, // condition + vector{0, 1, 1, 0, 0, 0}, + vector{1., 1., 1., 0., 1., 2., 0., 1., 2., 1., 1., 1., + 0., 1., 2., 0., 1., 2., 1., 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + +} // python output + +} // namespace infini diff --git a/test/operators/test_expand.cc b/test/operators/test_expand.cc new file mode 100644 index 00000000..608a36d5 --- /dev/null +++ b/test/operators/test_expand.cc @@ -0,0 +1,26 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/expand.h" + +#include "test.h" + +namespace infini { + +TEST(Expand, ShapeInference) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({3, 1}, DataType::Float32); + auto op = g->addOp(i, nullptr, Shape{2, 1, 6}); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3, 6})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({3, 1}, DataType::Float32); + auto op = g->addOp(i, nullptr, Shape{3, 4}); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{3, 4})); + } +} + +} // namespace infini diff --git a/test/operators/test_gather.cc b/test/operators/test_gather.cc index f3b9190c..8ddb2110 100644 --- a/test/operators/test_gather.cc +++ b/test/operators/test_gather.cc @@ -9,11 +9,19 @@ namespace infini { TEST(Gather, ShapeInference) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); - - Graph g = make_ref(runtime); - Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32); - Tensor index = g->addTensor({2, 1, 2}, DataType::Int32); - auto op = g->addOp(i, index, nullptr, 1); - EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32); + Tensor index = g->addTensor({2, 1, 2}, DataType::Int32); + auto op = g->addOp(i, index, nullptr, 1); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); + } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({1, 3, 4, 4}, DataType::Int32); + Tensor index = g->addTensor({2, 1, 2}, DataType::Int64); + auto op = g->addOp(i, index, nullptr, 1); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); + } } } // namespace infini diff --git a/test/operators/test_reduce_mean.cc b/test/operators/test_reduce_mean.cc index 8c3d477e..336d4018 100644 --- a/test/operators/test_reduce_mean.cc +++ b/test/operators/test_reduce_mean.cc @@ -21,6 +21,12 @@ TEST(ReduceMean, ShapeInference) { auto op = g->addOp(i, nullptr, vector{1, 3}, true); EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1})); } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = g->addOp(i, nullptr, vector{-3, 3}, true); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 3, 1})); + } { Graph g = make_ref(runtime); Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); @@ -33,6 +39,13 @@ TEST(ReduceMean, ShapeInference) { auto op = g->addOp(i, nullptr, vector{1, 3}, false); EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3})); } + { + Graph g = make_ref(runtime); + Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32); + auto op = + g->addOp(i, nullptr, vector{-3, 3}, false); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 3})); + } } } // namespace infini diff --git a/test/operators/test_where.cc b/test/operators/test_where.cc new file mode 100644 index 00000000..c32e2d81 --- /dev/null +++ b/test/operators/test_where.cc @@ -0,0 +1,46 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "operators/where.h" + +#include "test.h" + +namespace infini { + +TEST(Where, ShapeInference) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({2, 2}, DataType::Float32); + Tensor y = g->addTensor({2, 2}, DataType::Float32); + Tensor con = g->addTensor({2, 2}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 2})); + } + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({1, 12, 224, 224}, DataType::Float32); + Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32); + Tensor con = g->addTensor({1, 224, 1}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224})); + } + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({12, 224, 224}, DataType::Float32); + Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32); + Tensor con = g->addTensor({1, 224}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 12, 224, 224})); + } + { + Graph g = make_ref(runtime); + Tensor x = g->addTensor({12, 224, 224}, DataType::Float32); + Tensor y = g->addTensor({1, 1, 224, 224}, DataType::Float32); + Tensor con = g->addTensor({2, 1, 1, 1, 224}, DataType::Bool); + auto op = g->addOp(x, y, con, nullptr); + EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 1, 12, 224, 224})); + } +} + +} // namespace infini