From ef672894d06e217317094f56613d4dac0de50623 Mon Sep 17 00:00:00 2001 From: zhangyunze <93699316+bitzyz@users.noreply.github.com> Date: Wed, 16 Aug 2023 21:49:43 +0800 Subject: [PATCH] support mixed dtype (#102) * feat: support mixed dtype * feat: support cast op * test: add test for cast op * feat: support datatype BFloat16 * feat: support data convert fp32 <-> bfp16 * fix: fix all op's infershape func * fix as review comment --- include/core/data_type.h | 18 +- include/core/graph_handler.h | 1 + include/core/tensor.h | 3 +- include/operators/transpose.h | 2 +- include/operators/unary.h | 50 +- include/utils/data_convert.h | 2 + include/utils/operator_utils.h | 15 + pyinfinitensor/src/pyinfinitensor/onnx.py | 916 ++++++++++++---------- pyinfinitensor/tests/test_onnx.py | 27 +- src/core/graph_handler.cc | 73 ++ src/core/tensor.cc | 2 + src/ffi/ffi_infinitensor.cc | 14 +- src/kernels/cpu/matmul.cc | 1 - src/kernels/cuda/gather.cc | 6 +- src/kernels/cuda/matmul.cc | 4 +- src/kernels/cuda/pad_slice.cc | 2 +- src/kernels/cuda/reduce_mean.cc | 2 +- src/kernels/cuda/resize.cc | 2 +- src/kernels/cuda/split_concat.cc | 10 +- src/kernels/intelcpu/batch_norm.cc | 4 +- src/kernels/intelcpu/element_wise.cc | 4 +- src/kernels/intelcpu/gather.cc | 6 +- src/kernels/intelcpu/pad.cc | 2 +- src/kernels/intelcpu/pooling.cc | 2 +- src/kernels/intelcpu/reduce.cc | 10 +- src/kernels/intelcpu/reshape.cc | 2 +- src/kernels/intelcpu/resize.cc | 4 +- src/kernels/intelcpu/slice.cc | 2 +- src/kernels/intelcpu/softmax.cc | 2 +- src/kernels/intelcpu/split.cc | 2 +- src/operators/G2BMM.cc | 15 +- src/operators/GBMM.cc | 12 +- src/operators/batch_norm.cc | 7 +- src/operators/concat.cc | 13 +- src/operators/conv.cc | 9 +- src/operators/det.cc | 4 +- src/operators/element_wise.cc | 31 +- src/operators/extend.cc | 4 +- src/operators/gather.cc | 9 +- src/operators/matmul.cc | 77 +- src/operators/pad.cc | 14 +- src/operators/pooling.cc | 8 +- src/operators/reduce_mean.cc | 10 +- src/operators/reshape.cc | 21 +- src/operators/resize.cc | 65 +- src/operators/softmax.cc | 9 +- src/operators/split.cc | 11 +- src/operators/transpose.cc | 28 +- src/operators/unary.cc | 50 +- src/utils/data_convert.cc | 13 + src/utils/dataloader.cc | 2 +- src/utils/operator_utils.cc | 44 ++ test/core/test_hash.cc | 2 +- test/operators/test_matmul.cc | 24 + test/operators/test_transpose.cc | 32 + 55 files changed, 992 insertions(+), 712 deletions(-) create mode 100644 include/utils/operator_utils.h create mode 100644 src/utils/operator_utils.cc create mode 100644 test/operators/test_transpose.cc diff --git a/include/core/data_type.h b/include/core/data_type.h index 2fb05a07..eb6a6a8d 100644 --- a/include/core/data_type.h +++ b/include/core/data_type.h @@ -19,6 +19,7 @@ class DataType { static const DataType Double; static const DataType UInt32; static const DataType UInt64; + static const DataType BFloat16; // "sizePerElement" show the DType to cpu_type // DataType::Bool -> int8_t DataType::Float16 -> uint16_t static constexpr size_t sizePerElement[]{0, @@ -34,14 +35,19 @@ class DataType { sizeof(uint16_t), sizeof(double), sizeof(uint32_t), - sizeof(uint64_t)}; + sizeof(uint64_t), + 0, + 0, + sizeof(uint16_t)}; static constexpr std::string_view names[]{ - "Undefine", "Float32", "UInt8", "Int8", "UInt16", - "Int16", "Int32", "Int64", "String", "Bool", - "Float16", "Double", "UInt32", "UInt64"}; + "Undefine", "Float32", "UInt8", "Int8", "UInt16", + "Int16", "Int32", "Int64", "String", "Bool", + "Float16", "Double", "UInt32", "UInt64", "PlaceHolder", + 
"PlaceHolder", "BFloat16"}; - static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8}; + static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, + 3, 4, 9, 1, 8, -1, -1, 4}; private: int index; @@ -79,6 +85,7 @@ inline const DataType DataType::Float16(10); inline const DataType DataType::Double(11); inline const DataType DataType::UInt32(12); inline const DataType DataType::UInt64(13); +inline const DataType DataType::BFloat16(16); // Method definitions are out of the declaration due to GCC bug: // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc template <> inline int DataType::get() { return 0; } @@ -107,5 +114,6 @@ template <> struct DT<10> { using t = uint16_t; }; template <> struct DT<11> { using t = double; }; template <> struct DT<12> { using t = uint32_t; }; template <> struct DT<13> { using t = uint64_t; }; +template <> struct DT<16> { using t = uint16_t; }; } // namespace infini diff --git a/include/core/graph_handler.h b/include/core/graph_handler.h index b8b7bc9d..48c79a6e 100644 --- a/include/core/graph_handler.h +++ b/include/core/graph_handler.h @@ -66,6 +66,7 @@ class GraphHandlerObj { const optional> &steps); Tensor pad(Tensor input, Tensor output, const vector &pads, const optional> &axes); + Tensor cast(Tensor input, Tensor output, int to); //------ modifiers diff --git a/include/core/tensor.h b/include/core/tensor.h index 6dadd0d9..89f7d9be 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -36,6 +36,7 @@ class TensorObj : public TensorBaseObj { size_t getBytes() const { return _size * dtype.getSize(); } Shape getDims() const { return shape; } + size_t getRank() const { return shape.size(); } vector getStride() const; size_t getOffset(const vector &ds) const; void dataMalloc(); @@ -330,7 +331,7 @@ class TensorObj : public TensorBaseObj { // } // void initSplittingPoints() { - // splittingPoints.resize(getDims().size()); } + // splittingPoints.resize(getRank()); } // void printShape(); }; diff --git a/include/operators/transpose.h b/include/operators/transpose.h index b26ed49a..61dc8e5a 100644 --- a/include/operators/transpose.h +++ b/include/operators/transpose.h @@ -15,7 +15,7 @@ class TransposeObj : public OperatorObj { std::vector getPermute() const { return transposePermute; } private: - vector transposePermute = {1, 1, 1, 1}; + vector transposePermute; vector getWorkloadVector() const override; vector getOpAttrVector() const override; }; diff --git a/include/operators/unary.h b/include/operators/unary.h index 3e94a548..8a3d9704 100644 --- a/include/operators/unary.h +++ b/include/operators/unary.h @@ -134,31 +134,35 @@ class TransformObj : public OperatorObj { vector getOpAttrVector() const override; }; +enum class CastType { + Float2Float16 = 0, + Float2Int64, + Float2Int32, + Float2Int16, + Float2Int8, + Float2BFloat16, + Int322Float, + Int322Int8, + Int322Int16, + Int322Int64, + Int162Float, + Int162Int32, + Int82Float, + Int82Int16, + Int82Int32, + Uint82Float, + Uint82Int32, + Uint82Int64, + Int642Int32, + Int642Uint32, + Int642Float, + Uint322Int64, + Float162Float, + BFloat162Float, +}; + class CastObj : public OperatorObj { public: - enum CastType { - Float2Half = 0, - Float2Int64, - Float2Int32, - Float2Int16, - Float2Int8, - Int322Float, - Int322Int8, - Int322Int16, - Int162Float, - Int162Int32, - Int82Float, - Int82Int16, - Int82Int32, - Uint82Float, - Uint82Int32, - Uint82Int64, - Int322Int64, - Int642Int32, - Int642Uint32, - Int642Float, - 
Uint322Int64, - }; CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type); OP_CLONE(CastObj); optional> inferShape(const TensorVec &inputs) const override; diff --git a/include/utils/data_convert.h b/include/utils/data_convert.h index 51dd8501..2cbcdbac 100644 --- a/include/utils/data_convert.h +++ b/include/utils/data_convert.h @@ -8,4 +8,6 @@ union Uf32 { }; uint16_t float_to_fp16(const float x); float fp16_to_float(const uint16_t x); +uint16_t float_to_bfp16(const float x); +float bfp16_to_float(const uint16_t x); } // namespace infini diff --git a/include/utils/operator_utils.h b/include/utils/operator_utils.h new file mode 100644 index 00000000..01703252 --- /dev/null +++ b/include/utils/operator_utils.h @@ -0,0 +1,15 @@ +#pragma once +#ifndef OPERATOR_UTIL_H +#define OPERATOR_UTIL_H + +#include "core/tensor.h" + +namespace infini { + +// Launch a broadcast shape based on the shape of input A and B +Shape infer_broadcast(const Shape &A, const Shape &B); +// Launch the real axis based on rank and current axis +int get_real_axis(const int &axis, const int &rank); +} // namespace infini + +#endif diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 1a5186a7..bc057d9b 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -62,462 +62,520 @@ class OnnxStub: tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type) data[initializer.name] = initializer + node_name = [] + new_node_name = [] for node in model.graph.node: - if node.op_type == "Conv": - attributes = _parse_attribute( - node, - { - "dilations": [1, 1], - "pads": [0, 0, 0, 0], - "strides": [1, 1], - }, - ) - (d, p, s) = ( - attributes[name] for name in ["dilations", "pads", "strides"] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors[node.input[0]], None, p, [-2, -1] + node_name.append(node.name) + node_list = model.graph.node + while len(node_list) != 0: + for node in model.graph.node: + if node.name not in node_list: + continue + if _analyse_node(node, tensors): + continue + if node.op_type == "Conv": + attributes = _parse_attribute( + node, + { + "dilations": [1, 1], + "pads": [0, 0, 0, 0], + "strides": [1, 1], + }, ) - p = [0, 0, 0, 0] - else: - adapt = node.input[0] + (d, p, s) = ( + attributes[name] for name in ["dilations", "pads", "strides"] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors[node.input[0]], None, p, [-2, -1] + ) + p = [0, 0, 0, 0] + else: + adapt = node.input[0] - if len(node.input) > 2: - bias = "{}-bias".format(node.output[0]) - reshape = "{}-reshape".format(node.output[0]) - tensors[bias] = self.handler.conv( - tensors[adapt], + if len(node.input) > 2: + bias = "{}-bias".format(node.output[0]) + reshape = "{}-reshape".format(node.output[0]) + tensors[bias] = self.handler.conv( + tensors[adapt], + tensors[node.input[1]], + None, + p[0], + p[1], + s[0], + s[1], + d[0], + d[1], + ) + tensors[reshape] = self.handler.reshape( + tensors[node.input[2]], + None, + [ + 1, + reduce( + lambda acc, x: acc * x, + _search_shape(model, node.input[2]), + ), + 1, + 1, + ], + ) + tensors[node.output[0]] = self.handler.add( + tensors[bias], + tensors[reshape], + tensors.get(node.output[0]), + ) + else: + tensors[node.output[0]] = self.handler.conv( + tensors[adapt], + tensors[node.input[1]], + tensors.get(node.output[0]), + p[0], + 
p[1], + s[0], + s[1], + d[0], + d[1], + ) + elif node.op_type == "ConvTranspose": + attributes = _parse_attribute( + node, + { + "dilations": [1, 1], + "pads": [0, 0], + "strides": [1, 1], + "output_padding": [0, 0], + }, + ) + (d, p, s, op) = ( + attributes[name] + for name in ["dilations", "pads", "strides", "output_padding"] + ) + tensors[node.output[0]] = self.handler.convTransposed2d( + tensors[node.input[0]], tensors[node.input[1]], - None, + tensors.get(node.output[0]), p[0], p[1], s[0], s[1], d[0], d[1], + op[0], + op[1], ) - tensors[reshape] = self.handler.reshape( - tensors[node.input[2]], + elif node.op_type == "MatMul": + tensors[node.output[0]] = self.handler.matmul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + False, + False, None, - [ - 1, - reduce( - lambda acc, x: acc * x, - _search_shape(model, node.input[2]), - ), - 1, - 1, - ], + backend.ActType.Linear, ) + elif node.op_type == "Gemm": + attributes = _parse_attribute( + node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} + ) + (alpha, beta, transA, transB) = ( + attributes[name] + for name in ["alpha", "beta", "transA", "transB"] + ) + # FIXME unsupport attributes: `alpha` `beta` + assert alpha == 1.0 + assert beta == 1.0 + tensors[node.output[0]] = self.handler.matmul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + transA == 1, + transB == 1, + tensors[node.input[2]] if len(node.input) > 2 else None, + backend.ActType.Linear, + ) + elif node.op_type == "BatchNormalization": + (input, mean, var, scale, bias) = ( + tensors[node.input[i]] for i in [0, 3, 4, 1, 2] + ) + output = tensors.get(node.output[0]) + attributes = _parse_attribute( + node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0} + ) + (momentum, eps, training) = ( + attributes[name] + for name in ["momentum", "epsilon", "training_mode"] + ) + tensors[node.output[0]] = self.handler.batchNormalization( + input, + output, + mean, + var, + scale, + bias, + momentum, + eps, + training != 0, + ) + elif node.op_type == "MaxPool": + attributes = _parse_attribute( + node, + { + "kernel_shape": None, + "dilations": [1, 1], + "pads": [0, 0, 0, 0], + "strides": [1, 1], + }, + ) + (k, d, p, s) = ( + attributes[name] + for name in ["kernel_shape", "dilations", "pads", "strides"] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors.get(node.input[0]), None, p, [-2, -1] + ) + tensors[node.output[0]] = self.handler.maxPool( + tensors[adapt], + tensors.get(node.output[0]), + k[0], + k[1], + d[0], + d[1], + 0, + 0, + s[0], + s[1], + ) + else: + tensors[node.output[0]] = self.handler.maxPool( + tensors[node.input[0]], + tensors.get(node.output[0]), + k[0], + k[1], + d[0], + d[1], + p[0], + p[1], + s[0], + s[1], + ) + elif node.op_type == "AveragePool": + attributes = _parse_attribute( + node, + { + "kernel_shape": None, + "pads": [0, 0, 0, 0], + "strides": [1, 1], + }, + ) + (k, p, s) = ( + attributes[name] for name in ["kernel_shape", "pads", "strides"] + ) + if p[0] != p[2] or p[1] != p[3]: + adapt = "{}-adapt".format(node.output[0]) + tensors[adapt] = self.handler.pad( + tensors.get(node.input[0]), None, p, [-2, -1] + ) + tensors[node.output[0]] = self.handler.avgPool( + tensors[adapt], + tensors.get(node.output[0]), + k[0], + k[1], + 1, + 1, + 0, + 0, + s[0], + s[1], + ) + else: + tensors[node.output[0]] = self.handler.avgPool( + tensors[node.input[0]], + tensors.get(node.output[0]), + k[0], + k[1], + 1, + 
1, + p[0], + p[1], + s[0], + s[1], + ) + elif node.op_type == "GlobalAveragePool": + [_, _, h, w] = _search_shape(model, node.input[0]) + tensors[node.output[0]] = self.handler.avgPool( + tensors[node.input[0]], + tensors.get(node.output[0]), + h, + w, + 1, + 1, + 0, + 0, + 1, + 1, + ) + elif node.op_type == "Add": tensors[node.output[0]] = self.handler.add( - tensors[bias], - tensors[reshape], - tensors.get(node.output[0]), - ) - else: - tensors[node.output[0]] = self.handler.conv( - tensors[adapt], + tensors[node.input[0]], tensors[node.input[1]], tensors.get(node.output[0]), - p[0], - p[1], - s[0], - s[1], - d[0], - d[1], ) - elif node.op_type == "ConvTranspose": - attributes = _parse_attribute( - node, - { - "dilations": [1, 1], - "pads": [0, 0], - "strides": [1, 1], - "output_padding": [0, 0], - }, - ) - (d, p, s, op) = ( - attributes[name] - for name in ["dilations", "pads", "strides", "output_padding"] - ) - tensors[node.output[0]] = self.handler.convTransposed2d( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - p[0], - p[1], - s[0], - s[1], - d[0], - d[1], - op[0], - op[1], - ) - elif node.op_type == "MatMul": - tensors[node.output[0]] = self.handler.matmul( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - False, - False, - None, - backend.ActType.Linear, - ) - elif node.op_type == "Gemm": - attributes = _parse_attribute( - node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0} - ) - (alpha, beta, transA, transB) = ( - attributes[name] for name in ["alpha", "beta", "transA", "transB"] - ) - # FIXME unsupport attributes: `alpha` `beta` - assert alpha == 1.0 - assert beta == 1.0 - tensors[node.output[0]] = self.handler.matmul( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - transA == 1, - transB == 1, - tensors[node.input[2]] if len(node.input) > 2 else None, - backend.ActType.Linear, - ) - elif node.op_type == "BatchNormalization": - (input, mean, var, scale, bias) = ( - tensors[node.input[i]] for i in [0, 3, 4, 1, 2] - ) - output = tensors.get(node.output[0]) - attributes = _parse_attribute( - node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0} - ) - (momentum, eps, training) = ( - attributes[name] - for name in ["momentum", "epsilon", "training_mode"] - ) - tensors[node.output[0]] = self.handler.batchNormalization( - input, output, mean, var, scale, bias, momentum, eps, training != 0 - ) - elif node.op_type == "MaxPool": - attributes = _parse_attribute( - node, - { - "kernel_shape": None, - "dilations": [1, 1], - "pads": [0, 0, 0, 0], - "strides": [1, 1], - }, - ) - (k, d, p, s) = ( - attributes[name] - for name in ["kernel_shape", "dilations", "pads", "strides"] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors.get(node.input[0]), None, p, [-2, -1] - ) - tensors[node.output[0]] = self.handler.maxPool( - tensors[adapt], + elif node.op_type == "Sub": + tensors[node.output[0]] = self.handler.sub( + tensors[node.input[0]], + tensors[node.input[1]], tensors.get(node.output[0]), - k[0], - k[1], - d[0], - d[1], - 0, - 0, - s[0], - s[1], ) - else: - tensors[node.output[0]] = self.handler.maxPool( + elif node.op_type == "Mul": + tensors[node.output[0]] = self.handler.mul( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Div": + tensors[node.output[0]] = self.handler.div( + tensors[node.input[0]], + tensors[node.input[1]], + 
tensors.get(node.output[0]), + ) + elif node.op_type == "Pow": + tensors[node.output[0]] = self.handler.pow( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Relu": + tensors[node.output[0]] = self.handler.relu( tensors[node.input[0]], tensors.get(node.output[0]), - k[0], - k[1], - d[0], - d[1], - p[0], - p[1], - s[0], - s[1], ) - elif node.op_type == "AveragePool": - attributes = _parse_attribute( - node, - { - "kernel_shape": None, - "pads": [0, 0, 0, 0], - "strides": [1, 1], - }, - ) - (k, p, s) = ( - attributes[name] for name in ["kernel_shape", "pads", "strides"] - ) - if p[0] != p[2] or p[1] != p[3]: - adapt = "{}-adapt".format(node.output[0]) - tensors[adapt] = self.handler.pad( - tensors.get(node.input[0]), None, p, [-2, -1] - ) - tensors[node.output[0]] = self.handler.avgPool( - tensors[adapt], - tensors.get(node.output[0]), - k[0], - k[1], - 1, - 1, - 0, - 0, - s[0], - s[1], - ) - else: - tensors[node.output[0]] = self.handler.avgPool( + elif node.op_type == "Sigmoid": + tensors[node.output[0]] = self.handler.sigmoid( tensors[node.input[0]], tensors.get(node.output[0]), - k[0], - k[1], - 1, - 1, - p[0], - p[1], - s[0], - s[1], ) - elif node.op_type == "GlobalAveragePool": - [_, _, h, w] = _search_shape(model, node.input[0]) - tensors[node.output[0]] = self.handler.avgPool( - tensors[node.input[0]], - tensors.get(node.output[0]), - h, - w, - 1, - 1, - 0, - 0, - 1, - 1, - ) - elif node.op_type == "Add": - tensors[node.output[0]] = self.handler.add( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Sub": - tensors[node.output[0]] = self.handler.sub( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Mul": - tensors[node.output[0]] = self.handler.mul( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Div": - tensors[node.output[0]] = self.handler.div( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Pow": - tensors[node.output[0]] = self.handler.pow( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Relu": - tensors[node.output[0]] = self.handler.relu( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Sigmoid": - tensors[node.output[0]] = self.handler.sigmoid( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Tanh": - tensors[node.output[0]] = self.handler.tanh( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Softmax": - tensors[node.output[0]] = self.handler.softmax( - tensors[node.input[0]], - tensors.get(node.output[0]), - next( - (attr.i for attr in node.attribute if attr.name == "axis"), -1 - ), - ) - elif node.op_type == "Abs": - tensors[node.output[0]] = self.handler.abs( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Shape": - tensors[node.output[0]] = self.handler.shape( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Identity": - tensors[node.output[0]] = self.handler.identity( - tensors[node.input[0]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Flatten": - tensors[node.output[0]] = self.handler.flatten( - tensors[node.input[0]], - tensors.get(node.output[0]), - next((attr.i for attr in node.attribute if attr.name == 
"axis")), - ) - elif node.op_type == "PRelu": - tensors[node.output[0]] = self.handler.pRelu( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - ) - elif node.op_type == "Clip": - tensors[node.output[0]] = self.handler.clip( - tensors[node.input[0]], - tensors.get(node.output[0]), - next(_parse_data(data[node.input[1]]).__iter__(), None) - if len(node.input) > 1 - else None, - next(_parse_data(data[node.input[2]]).__iter__(), None) - if len(node.input) > 2 - else None, - ) - elif node.op_type == "Transpose": - perm = next( - (attr.ints for attr in node.attribute if attr.name == "perm"), None - ) - tensors[node.output[0]] = self.handler.transpose( - tensors[node.input[0]], - tensors.get(node.output[0]), - perm, - ) - elif node.op_type == "Reshape": - dims = _search_shape(model, node.input[0]) - size = reduce(lambda acc, x: acc * x, dims) - input_shape = _parse_data(data[node.input[1]]) - for i, x in enumerate(input_shape): - if x == 0: - input_shape[i] = dims[i] - temp = reduce(lambda acc, x: acc * x, input_shape, 1) - if temp < 0: - input_shape[input_shape.index(-1)] = size // -temp - tensors[node.output[0]] = self.handler.reshape( - tensors[node.input[0]], - tensors.get(node.output[0]), - input_shape, - ) - elif node.op_type == "Squeeze": - input_shape = _search_shape(model, node.input[0]) - axes = set( - [int(i) for i in data[node.input[1]].int64_data] - if len(node.input) > 1 - else _parse_attribute(node, {"axes": None})["axes"] - ) - assert all(input_shape[d] == 1 for d in axes) - output_shape = [] - for i, x in enumerate(input_shape): - if i not in axes: - output_shape.append(x) - tensors[node.output[0]] = self.handler.reshape( - tensors[node.input[0]], - tensors.get(node.output[0]), - output_shape, - ) - elif node.op_type == "Unsqueeze": - input_shape = _search_shape(model, node.input[0]) - axes = ( - [int(i) for i in data[node.input[1]].int64_data] - if len(node.input) > 1 - else _parse_attribute(node, {"axes": None})["axes"] - ) - for i in axes: - input_shape.insert(i, 1) - tensors[node.output[0]] = self.handler.reshape( - tensors[node.input[0]], - tensors.get(node.output[0]), - input_shape, - ) - elif node.op_type == "Concat": - tensors[node.output[0]] = self.handler.concat( - [tensors[name] for name in node.input], - tensors.get(node.output[0]), - next((attr.i for attr in node.attribute if attr.name == "axis")), - ) - elif node.op_type == "Split": - for name, tensor in zip( - node.output, - self.handler.split( + elif node.op_type == "Tanh": + tensors[node.output[0]] = self.handler.tanh( tensors[node.input[0]], - None, + tensors.get(node.output[0]), + ) + elif node.op_type == "Softmax": + tensors[node.output[0]] = self.handler.softmax( + tensors[node.input[0]], + tensors.get(node.output[0]), next( (attr.i for attr in node.attribute if attr.name == "axis"), - 0, + -1, ), - len(node.output), - ), - ): - tensors[name] = tensor - elif node.op_type == "Gather": - tensors[node.output[0]] = self.handler.gather( - tensors[node.input[0]], - tensors[node.input[1]], - tensors.get(node.output[0]), - next((attr.i for attr in node.attribute if attr.name == "axis")), - ) - elif node.op_type == "ReduceMean": - tensors[node.output[0]] = self.handler.reduce_mean( - tensors[node.input[0]], - tensors.get(node.output[0]), - # NOTE(constroy): `axes` is an attribute until opset version 13. 
- next( - (attr.ints for attr in node.attribute if attr.name == "axes"), - None, - ), - next((attr.i for attr in node.attribute if attr.name == "keepdims")) - != 0, - ) - elif node.op_type == "Slice": - tensors[node.output[0]] = self.handler.slice( - tensors[node.input[0]], - tensors.get(node.output[0]), - _parse_data(data[node.input[1]]), - _parse_data(data[node.input[2]]), - _parse_data(data[node.input[3]]) if len(node.input) > 3 else None, - _parse_data(data[node.input[4]]) if len(node.input) > 4 else None, - ) - elif node.op_type == "Pad": - tensors[node.output[0]] = self.handler.pad( - tensors[node.input[0]], - tensors.get(node.output[0]), - _parse_data(data[node.input[1]]), - _parse_data(data[node.input[3]]) if len(node.input) > 3 else None, - ) - elif node.op_type == "Dropout": - for name, tensor in zip( - node.output, - self.handler.dropout( + ) + elif node.op_type == "Abs": + tensors[node.output[0]] = self.handler.abs( tensors[node.input[0]], tensors.get(node.output[0]), - tensors.get(node.output[1]) if len(node.output) > 1 else None, - _parse_data(data[node.input[1]])[0] + ) + elif node.op_type == "Shape": + tensors[node.output[0]] = self.handler.shape( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Identity": + tensors[node.output[0]] = self.handler.identity( + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Flatten": + tensors[node.output[0]] = self.handler.flatten( + tensors[node.input[0]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis") + ), + ) + elif node.op_type == "PRelu": + tensors[node.output[0]] = self.handler.pRelu( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + ) + elif node.op_type == "Clip": + tensors[node.output[0]] = self.handler.clip( + tensors[node.input[0]], + tensors.get(node.output[0]), + next(_parse_data(data[node.input[1]]).__iter__(), None) if len(node.input) > 1 - else 0.5, - _parse_data(data[node.input[2]])[0] + else None, + next(_parse_data(data[node.input[2]]).__iter__(), None) if len(node.input) > 2 - else False, - ), - ): - tensors[name] = tensor - else: - raise Exception('Unsupported operator "{}"'.format(node.op_type)) + else None, + ) + elif node.op_type == "Transpose": + perm = next( + (attr.ints for attr in node.attribute if attr.name == "perm"), + None, + ) + tensors[node.output[0]] = self.handler.transpose( + tensors[node.input[0]], + tensors.get(node.output[0]), + perm, + ) + elif node.op_type == "Reshape": + dims = _search_shape(model, node.input[0]) + size = reduce(lambda acc, x: acc * x, dims) + input_shape = _parse_data(data[node.input[1]]) + for i, x in enumerate(input_shape): + if x == 0: + input_shape[i] = dims[i] + temp = reduce(lambda acc, x: acc * x, input_shape, 1) + if temp < 0: + input_shape[input_shape.index(-1)] = size // -temp + tensors[node.output[0]] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(node.output[0]), + input_shape, + ) + elif node.op_type == "Squeeze": + input_shape = _search_shape(model, node.input[0]) + axes = set( + [int(i) for i in data[node.input[1]].int64_data] + if len(node.input) > 1 + else _parse_attribute(node, {"axes": None})["axes"] + ) + assert all(input_shape[d] == 1 for d in axes) + output_shape = [] + for i, x in enumerate(input_shape): + if i not in axes: + output_shape.append(x) + tensors[node.output[0]] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(node.output[0]), + output_shape, + ) + elif 
node.op_type == "Unsqueeze": + input_shape = _search_shape(model, node.input[0]) + axes = ( + [int(i) for i in data[node.input[1]].int64_data] + if len(node.input) > 1 + else _parse_attribute(node, {"axes": None})["axes"] + ) + for i in axes: + input_shape.insert(i, 1) + tensors[node.output[0]] = self.handler.reshape( + tensors[node.input[0]], + tensors.get(node.output[0]), + input_shape, + ) + elif node.op_type == "Concat": + tensors[node.output[0]] = self.handler.concat( + [tensors[name] for name in node.input], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis") + ), + ) + elif node.op_type == "Split": + for name, tensor in zip( + node.output, + self.handler.split( + tensors[node.input[0]], + None, + next( + ( + attr.i + for attr in node.attribute + if attr.name == "axis" + ), + 0, + ), + len(node.output), + ), + ): + tensors[name] = tensor + elif node.op_type == "Gather": + tensors[node.output[0]] = self.handler.gather( + tensors[node.input[0]], + tensors[node.input[1]], + tensors.get(node.output[0]), + next( + (attr.i for attr in node.attribute if attr.name == "axis") + ), + ) + elif node.op_type == "ReduceMean": + tensors[node.output[0]] = self.handler.reduce_mean( + tensors[node.input[0]], + tensors.get(node.output[0]), + # NOTE(constroy): `axes` is an attribute until opset version 13. + next( + ( + attr.ints + for attr in node.attribute + if attr.name == "axes" + ), + None, + ), + next( + ( + attr.i + for attr in node.attribute + if attr.name == "keepdims" + ) + ) + != 0, + ) + elif node.op_type == "Slice": + tensors[node.output[0]] = self.handler.slice( + tensors[node.input[0]], + tensors.get(node.output[0]), + _parse_data(data[node.input[1]]), + _parse_data(data[node.input[2]]), + _parse_data(data[node.input[3]]) + if len(node.input) > 3 + else None, + _parse_data(data[node.input[4]]) + if len(node.input) > 4 + else None, + ) + elif node.op_type == "Pad": + tensors[node.output[0]] = self.handler.pad( + tensors[node.input[0]], + tensors.get(node.output[0]), + _parse_data(data[node.input[1]]), + _parse_data(data[node.input[3]]) + if len(node.input) > 3 + else None, + ) + elif node.op_type == "Dropout": + for name, tensor in zip( + node.output, + self.handler.dropout( + tensors[node.input[0]], + tensors.get(node.output[0]), + tensors.get(node.output[1]) + if len(node.output) > 1 + else None, + _parse_data(data[node.input[1]])[0] + if len(node.input) > 1 + else 0.5, + _parse_data(data[node.input[2]])[0] + if len(node.input) > 2 + else False, + ), + ): + tensors[name] = tensor + elif node.op_type == "Cast": + tensors[node.output[0]] = self.handler.cast( + tensors[node.input[0]], + tensors.get(node.output[0]), + next((attr.i for attr in node.attribute if attr.name == "to")), + ) + else: + raise Exception('Unsupported operator "{}"'.format(node.op_type)) + new_node_name.append(node.name) + # update the node_list + node_list = list(set(node_name) - set(new_node_name)) self.handler.data_malloc() @@ -540,6 +598,8 @@ class OnnxStub: obj.copyin_float16(_parse_data_fp16(tensor)) elif tensor.data_type == TensorProto.INT8: obj.copyin_uint8(_parse_data(tensor)) + elif tensor.data_type == TensorProto.BFLOAT16: + obj.copyin_float16(_parse_data_fp16(tensor)) else: assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) @@ -823,6 +883,9 @@ class OnnxStub: ctx.push_data_input(name, "max", TensorProto.FLOAT, [], []) ) ctx.push_node(make_node(ty.name, inputs, outputs, name)) + elif ty == backend.OpTypeId.Cast: + to = backend.cast_to_of(op) + 
ctx.push_node(make_node(ty.name, inputs, outputs, name, to=to)) else: raise Exception("Unsupported OpType", ty) @@ -922,3 +985,10 @@ def _parse_data_fp16(tensor: TensorProto): def _take_shape_dim(shape: TensorShapeProto) -> List[int]: return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim] + + +def _analyse_node(node: NodeProto, tensors) -> bool: + for i in node.input: + if i not in tensors: + return True + return False diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 4283a981..497bae9b 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -79,6 +79,21 @@ class TestStringMethods(unittest.TestCase): ) make_and_import_model(make_graph([conv], "conv_fp16", [i, w], [o])) + def test_conv_bfp16(self): + i = make_tensor_value_info("i", TensorProto.BFLOAT16, [1, 3, 4, 4]) + w = make_tensor_value_info("w", TensorProto.BFLOAT16, [2, 3, 3, 3]) + o = make_tensor_value_info("o", TensorProto.BFLOAT16, [1, 2, 2, 2]) + conv = make_node( + "Conv", + ["i", "w"], + ["o"], + "conv", + pads=[1, 1, 1, 1], + strides=[2, 1], + dilations=[1, 2], + ) + make_and_import_model(make_graph([conv], "conv_bfp16", [i, w], [o])) + def test_matmul(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4]) @@ -226,9 +241,7 @@ class TestStringMethods(unittest.TestCase): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3, 5 * 7]) flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten") - make_and_import_model( - make_graph([flatten], "flatten", [x], [y]) - ) + make_and_import_model(make_graph([flatten], "flatten", [x], [y])) def test_reshape(self): data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) @@ -331,6 +344,14 @@ class TestStringMethods(unittest.TestCase): y = handler.tensor([3, 2, 1], 12) handler.reshape(x, y, [3, 2, 1]) + def test_cast(self): + input1 = make_tensor_value_info("input1", TensorProto.FLOAT, [1, 3, 2, 4]) + output = make_tensor_value_info("output", TensorProto.FLOAT16, [1, 3, 2, 4]) + cast = make_node( + "Cast", ["input1"], ["output"], to=TensorProto.FLOAT16, name="cast" + ) + make_and_import_model(make_graph([cast], "cast", [input1], [output])) + if __name__ == "__main__": unittest.main() diff --git a/src/core/graph_handler.cc b/src/core/graph_handler.cc index 17074ef7..2f59fd98 100644 --- a/src/core/graph_handler.cc +++ b/src/core/graph_handler.cc @@ -18,6 +18,7 @@ namespace infini { static DataType dtype_repr_convert(int); +static CastType inferCastType(Tensor input, int to); Tensor GraphHandlerObj::tensor(Shape dims, int dtype) { return g->addTensor(std::move(dims), dtype_repr_convert(dtype)); @@ -293,6 +294,76 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output, } } +Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) { + if (output) { + g->addOpWithOutputs(std::move(input), output, + inferCastType(input, to)); + return output; + } else { + return g + ->addOp(std::move(input), output, inferCastType(input, to)) + ->getOutput(); + } +} + +static CastType inferCastType(Tensor input, int to) { + auto iType = input->getDType(); + auto oType = DataType(to); + if (iType == DataType::Float32 && oType == DataType::Float16) { + return CastType::Float2Float16; + } else if (iType == DataType::Float32 && oType == DataType::Int64) { + return CastType::Float2Int64; + } else if (iType == DataType::Float32 && oType == DataType::Int32) { + 
return CastType::Float2Int32; + } else if (iType == DataType::Float32 && oType == DataType::Int16) { + return CastType::Float2Int16; + } else if (iType == DataType::Float32 && oType == DataType::Int8) { + return CastType::Float2Int8; + } else if (iType == DataType::Float32 && oType == DataType::BFloat16) { + return CastType::Float2BFloat16; + } else if (iType == DataType::Int32 && oType == DataType::Float32) { + return CastType::Int322Float; + } else if (iType == DataType::Int32 && oType == DataType::Int8) { + return CastType::Int322Int8; + } else if (iType == DataType::Int32 && oType == DataType::Int16) { + return CastType::Int322Int16; + } else if (iType == DataType::Int32 && oType == DataType::Int64) { + return CastType::Int322Int64; + } else if (iType == DataType::Int16 && oType == DataType::Int32) { + return CastType::Int162Int32; + } else if (iType == DataType::Int16 && oType == DataType::Float32) { + return CastType::Int162Float; + } else if (iType == DataType::Int8 && oType == DataType::Float32) { + return CastType::Int82Float; + } else if (iType == DataType::Int8 && oType == DataType::Int16) { + return CastType::Int82Int16; + } else if (iType == DataType::Int8 && oType == DataType::Int32) { + return CastType::Int82Int32; + } else if (iType == DataType::UInt8 && oType == DataType::Int32) { + return CastType::Uint82Int32; + } else if (iType == DataType::UInt8 && oType == DataType::Float32) { + return CastType::Uint82Float; + } else if (iType == DataType::UInt8 && oType == DataType::Int64) { + return CastType::Uint82Int64; + } else if (iType == DataType::Int64 && oType == DataType::Float32) { + return CastType::Int642Float; + } else if (iType == DataType::Int64 && oType == DataType::UInt32) { + return CastType::Int642Uint32; + } else if (iType == DataType::Int64 && oType == DataType::Int32) { + return CastType::Int642Int32; + } else if (iType == DataType::UInt32 && oType == DataType::Int64) { + return CastType::Uint322Int64; + } else if (iType == DataType::Float16 && oType == DataType::Float32) { + return CastType::Float162Float; + } else if (iType == DataType::BFloat16 && oType == DataType::Float32) { + return CastType::BFloat162Float; + } else { + IT_TODO_HALT_MSG("Unsupported CastType : input_type is " + + iType.toString() + " output_type is " + + oType.toString()); + } +} + static DataType dtype_repr_convert(int dtype) { switch (dtype) { case 0: @@ -323,6 +394,8 @@ static DataType dtype_repr_convert(int dtype) { return DataType::UInt32; case 13: return DataType::UInt64; + case 16: + return DataType::BFloat16; default: IT_ASSERT(false, "Unsupported data type"); } diff --git a/src/core/tensor.cc b/src/core/tensor.cc index f7c35d78..77b3b49b 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -85,6 +85,7 @@ void TensorObj::printData() const { else TRY_PRINT(11) // else TRY_PRINT(12) // else TRY_PRINT(13) // + else TRY_PRINT(16) // else IT_TODO_HALT(); #undef TRY_PRINT @@ -118,6 +119,7 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const { else TEST_EQUAL(11) // else TEST_EQUAL(12) // else TEST_EQUAL(13) // + else TEST_EQUAL(16) // else IT_TODO_HALT(); #undef TEST_EQUAL diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 30237a90..f5315e63 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -95,6 +95,7 @@ void export_values(py::module &m) { .VALUE(OpType, Abs) .VALUE(OpType, Resize) .VALUE(OpType, Dropout) + .VALUE(OpType, Cast) .export_values(); #undef VALUE @@ -129,6 +130,8 @@ static int 
tensor_dtype(Tensor t) { return 12; if (t->getDType() == DataType::UInt64) return 13; + if (t->getDType() == DataType::BFloat16) + return 16; IT_ASSERT(false, "Unsupported data type"); } @@ -242,6 +245,13 @@ static int flatten_axis_of(Operator op) { return dynamic_cast(op.get())->getAxis(); } +static int cast_to_of(Operator op) { + IT_ASSERT(op->getOpType() == OpType::Cast); + auto castOutputDtype = + dynamic_cast(op.get())->getOutputDataType(); + return castOutputDtype.getIndex(); +} + void export_functions(py::module &m) { #define FUNCTION(NAME) def(#NAME, &NAME) m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance) @@ -271,7 +281,8 @@ void export_functions(py::module &m) { .FUNCTION(concat_axis_of) .FUNCTION(split_axis_of) .FUNCTION(gather_axis_of) - .FUNCTION(flatten_axis_of); + .FUNCTION(flatten_axis_of) + .FUNCTION(cast_to_of); #undef FUNCTION } @@ -346,6 +357,7 @@ void init_graph_builder(py::module &m) { .def("reduce_mean", &Handler::reduceMean, policy::move) .def("slice", &Handler::slice, policy::move) .def("pad", &Handler::pad, policy::move) + .def("cast", &Handler::cast, policy::move) .def("topo_sort", &Handler::topo_sort, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic) .def("operators", &Handler::operators, policy::move) diff --git a/src/kernels/cpu/matmul.cc b/src/kernels/cpu/matmul.cc index 01dcefa6..248cb60b 100644 --- a/src/kernels/cpu/matmul.cc +++ b/src/kernels/cpu/matmul.cc @@ -13,7 +13,6 @@ template class NaiveMatmul : public CpuKernelWithoutConfig { T *C = op->getOutput()->getRawDataPtr(); IT_ASSERT(op->getTransA() == false && op->getTransB() == false); IT_ASSERT(op->getAct() == ActType::None); - IT_ASSERT(op->getB() == 1); const int M = op->getM(), N = op->getN(), K = op->getK(); for (int i = 0; i < M; i++) { for (int j = 0; j < N; j++) { diff --git a/src/kernels/cuda/gather.cc b/src/kernels/cuda/gather.cc index 63725524..d769440e 100644 --- a/src/kernels/cuda/gather.cc +++ b/src/kernels/cuda/gather.cc @@ -14,9 +14,9 @@ class GatherCuda : public CudaKernelWithoutConfig { auto out = op->getOutput(); metaData.indexValue = index->getRawDataPtr(); metaData.axis = op->getAxis(); - metaData.inNDim = in->getDims().size(); - metaData.outNDim = out->getDims().size(); - metaData.idxNDim = index->getDims().size(); + metaData.inNDim = in->getRank(); + metaData.outNDim = out->getRank(); + metaData.idxNDim = index->getRank(); for (int i = 0; i < metaData.outNDim; ++i) metaData.outDim[i] = out->getDims()[i]; for (int i = 0; i < metaData.idxNDim; ++i) { diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index e238ef49..a2b55e04 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -51,8 +51,8 @@ class matmulCublas : public Kernel { cublasStatus_t stat; if (b > 1) { // Support batch broadcast with zero stride - int dimA = op->getInputs(0)->getDims().size(); - int dimB = op->getInputs(1)->getDims().size(); + int dimA = op->getInputs(0)->getRank(); + int dimB = op->getInputs(1)->getRank(); long long strideA = (dimA == 2 || (dimA == 3 && op->getInputs(0)->getDims()[0] == 1)) diff --git a/src/kernels/cuda/pad_slice.cc b/src/kernels/cuda/pad_slice.cc index 561a0923..2e7e3931 100644 --- a/src/kernels/cuda/pad_slice.cc +++ b/src/kernels/cuda/pad_slice.cc @@ -7,7 +7,7 @@ class PadSliceCudaCompute { public: void do_compute(Tensor partTensor, Tensor wholeTensor, const Shape &begNos, bool isPad) const { - int nDims = partTensor->getDims().size(); + int nDims = partTensor->getRank(); IT_ASSERT(MAX_DIM >= nDims); TransMetaData 
metadata; for (int i = 0; i < nDims; i++) { diff --git a/src/kernels/cuda/reduce_mean.cc b/src/kernels/cuda/reduce_mean.cc index e61f0019..6ae357c8 100644 --- a/src/kernels/cuda/reduce_mean.cc +++ b/src/kernels/cuda/reduce_mean.cc @@ -14,7 +14,7 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig { // Each dimension of the output tensor C must match the corresponding // dimension of the input tensor A or must be equal to 1. The dimensions // equal to 1 indicate the dimensions of A to be reduced. - int nInDims = input->getDims().size(); + int nInDims = input->getRank(); IT_ASSERT(CUDNN_DIM_MAX >= nInDims); int inDimArray[CUDNN_DIM_MAX], outDimArray[CUDNN_DIM_MAX], inStrideArray[CUDNN_DIM_MAX], outStrideArray[CUDNN_DIM_MAX]; diff --git a/src/kernels/cuda/resize.cc b/src/kernels/cuda/resize.cc index 3ee1765a..5becb913 100644 --- a/src/kernels/cuda/resize.cc +++ b/src/kernels/cuda/resize.cc @@ -9,7 +9,7 @@ class ResizeCuda : public CudaKernelWithoutConfig { auto in = op->getInputs(0); auto out = op->getOutputs()[0]; - int nDims = in->getDims().size(); + int nDims = in->getRank(); if (nDims > 4) IT_TODO_HALT(); diff --git a/src/kernels/cuda/split_concat.cc b/src/kernels/cuda/split_concat.cc index 5792753c..dbe2a7ac 100644 --- a/src/kernels/cuda/split_concat.cc +++ b/src/kernels/cuda/split_concat.cc @@ -9,7 +9,7 @@ namespace infini { class CudaCompute { void initComposedTensorMetadata(ComposedTensorMetadata &metadata, Tensor tensor) const { - int nDims = tensor->getDims().size(); + int nDims = tensor->getRank(); auto strides = tensor->getStride(); IT_ASSERT(strides.size() == (size_t)nDims); for (int i = 0; i < nDims; ++i) { @@ -60,8 +60,8 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { do_compute(_op->getOutput(), _op->getInputs(), - as(_op)->getDim(), - _op->getOutput()->getDims().size(), false); + as(_op)->getDim(), _op->getOutput()->getRank(), + false); } }; @@ -69,8 +69,8 @@ class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { do_compute(_op->getInputs(0), _op->getOutputs(), - as(_op)->getDim(), - _op->getInputs(0)->getDims().size(), true); + as(_op)->getDim(), _op->getInputs(0)->getRank(), + true); } }; diff --git a/src/kernels/intelcpu/batch_norm.cc b/src/kernels/intelcpu/batch_norm.cc index 9410bcb8..4583c013 100644 --- a/src/kernels/intelcpu/batch_norm.cc +++ b/src/kernels/intelcpu/batch_norm.cc @@ -14,7 +14,7 @@ class MklBatchNorm : public MklKernelWithoutConfig { // create user memory that describes data layout in the buffers std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) dims.push_back(op->getInputs(0)->getDims()[i]); auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, @@ -25,7 +25,7 @@ class MklBatchNorm : public MklKernelWithoutConfig { getUserFormatTag(dims.size())); auto output = dnnl::memory(dstMd, context->getEngine(), dstData); - std::vector meanDims(op->getInputs(0)->getDims().size(), 1); + std::vector meanDims(op->getInputs(0)->getRank(), 1); meanDims[1] = op->getInputs(0)->getDims()[1]; auto meanMd = dnnl::memory::desc(meanDims, dnnl::memory::data_type::f32, getUserFormatTag(meanDims.size())); diff --git a/src/kernels/intelcpu/element_wise.cc b/src/kernels/intelcpu/element_wise.cc index 2bccc819..0a27c31e 100644 --- a/src/kernels/intelcpu/element_wise.cc +++ 
b/src/kernels/intelcpu/element_wise.cc @@ -34,7 +34,7 @@ class MklBinary : public MklKernelWithoutConfig { // create user memory that describes data layout in the buffers std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) dims.push_back(op->getInputs(0)->getDims()[i]); auto srcMd1 = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, @@ -89,7 +89,7 @@ class MklUnary : public MklKernelWithoutConfig { // create user memory that describes data layout in the buffers std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) dims.push_back(op->getInputs(0)->getDims()[i]); auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, diff --git a/src/kernels/intelcpu/gather.cc b/src/kernels/intelcpu/gather.cc index a95ece4e..61549ccb 100644 --- a/src/kernels/intelcpu/gather.cc +++ b/src/kernels/intelcpu/gather.cc @@ -17,9 +17,9 @@ class MklGather : public MklKernelWithoutConfig { int oSize = out->size(); int idxSize = index->size(); - int inNDim = in->getDims().size(); - int oNDim = out->getDims().size(); - int idxNDim = index->getDims().size(); + int inNDim = in->getRank(); + int oNDim = out->getRank(); + int idxNDim = index->getRank(); int axis = op->getAxis(); int outDim[4] = {0}; diff --git a/src/kernels/intelcpu/pad.cc b/src/kernels/intelcpu/pad.cc index 02dc4143..8f52e7f6 100644 --- a/src/kernels/intelcpu/pad.cc +++ b/src/kernels/intelcpu/pad.cc @@ -10,7 +10,7 @@ class MklPad : public MklKernelWithoutConfig { auto context = dynamic_cast(_context); std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) { dims.push_back(op->getInputs(0)->getDims()[i]); } auto paddedMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, diff --git a/src/kernels/intelcpu/pooling.cc b/src/kernels/intelcpu/pooling.cc index cfe8364f..d3c9e44d 100644 --- a/src/kernels/intelcpu/pooling.cc +++ b/src/kernels/intelcpu/pooling.cc @@ -17,7 +17,7 @@ class MklPooling : public MklKernelWithoutConfig { // create user memory that describes data layout in the buffers auto [n, c, h, w, r, s] = op->getNCHWRS(); auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); - auto nDim = op->getOutput()->getDims().size(); + auto nDim = op->getOutput()->getRank(); auto oh = op->getOutput()->getDims()[nDim - 2]; auto ow = op->getOutput()->getDims()[nDim - 1]; diff --git a/src/kernels/intelcpu/reduce.cc b/src/kernels/intelcpu/reduce.cc index 23202fec..6670229e 100644 --- a/src/kernels/intelcpu/reduce.cc +++ b/src/kernels/intelcpu/reduce.cc @@ -18,16 +18,16 @@ class MklReduce : public MklKernelWithoutConfig { // create user memory that describes data layout in the buffers std::vector inDims, inStrides; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) { inDims.push_back(op->getInputs(0)->getDims()[i]); inStrides.push_back(op->getInputs(0)->getStride()[i]); } - std::vector oDims(op->getInputs(0)->getDims().size(), 0), - oStrides(op->getInputs(0)->getDims().size(), 1); + std::vector oDims(op->getInputs(0)->getRank(), 0), + oStrides(op->getInputs(0)->getRank(), 1); if (!op->getKeepDims()) { oDims = inDims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) { if (op->isReduced(i)) { oDims[i] = 1; } @@ -38,7 +38,7 
@@ class MklReduce : public MklKernelWithoutConfig { stride *= oDims[i]; } } else { - for (size_t i = 0; i < op->getOutput(0)->getDims().size(); ++i) { + for (size_t i = 0; i < op->getOutput(0)->getRank(); ++i) { oDims[i] = op->getOutput(0)->getDims()[i]; oStrides[i] = op->getOutput(0)->getStride()[i]; } diff --git a/src/kernels/intelcpu/reshape.cc b/src/kernels/intelcpu/reshape.cc index bddef40f..2a17b881 100644 --- a/src/kernels/intelcpu/reshape.cc +++ b/src/kernels/intelcpu/reshape.cc @@ -10,7 +10,7 @@ class MklReshape : public MklKernelWithoutConfig { auto context = dynamic_cast(_context); std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) dims.push_back(op->getInputs(0)->getDims()[i]); // create src md and src memory diff --git a/src/kernels/intelcpu/resize.cc b/src/kernels/intelcpu/resize.cc index e7b3eea4..f9a85634 100644 --- a/src/kernels/intelcpu/resize.cc +++ b/src/kernels/intelcpu/resize.cc @@ -30,7 +30,7 @@ class MklResize : public MklKernelWithoutConfig { enum_to_underlying(ResizeObj::ECoordinateTransMode::halfPixel)) IT_TODO_HALT(); - int nDim = op->getInputs(0)->getDims().size(); + int nDim = op->getInputs(0)->getRank(); IT_ASSERT(nDim == 3 || nDim == 4 || nDim == 5 && (op->getInputs(0)->getDims()[0] == 1 && @@ -44,7 +44,7 @@ class MklResize : public MklKernelWithoutConfig { // create user memory that describes data layout in the buffers std::vector idims, odims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) { idims.push_back(op->getInputs(0)->getDims()[i]); odims.push_back(op->getOutput(0)->getDims()[i]); } diff --git a/src/kernels/intelcpu/slice.cc b/src/kernels/intelcpu/slice.cc index 663897cc..a5715ced 100644 --- a/src/kernels/intelcpu/slice.cc +++ b/src/kernels/intelcpu/slice.cc @@ -10,7 +10,7 @@ class MklSlice : public MklKernelWithoutConfig { auto context = dynamic_cast(_context); std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) dims.push_back(op->getInputs(0)->getDims()[i]); // create src md diff --git a/src/kernels/intelcpu/softmax.cc b/src/kernels/intelcpu/softmax.cc index f8ce568c..32c58a94 100644 --- a/src/kernels/intelcpu/softmax.cc +++ b/src/kernels/intelcpu/softmax.cc @@ -14,7 +14,7 @@ class MklSoftmax : public MklKernelWithoutConfig { // create user memory that describes data layout in the buffers std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) dims.push_back(op->getInputs(0)->getDims()[i]); auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, diff --git a/src/kernels/intelcpu/split.cc b/src/kernels/intelcpu/split.cc index 654cf9a8..df859083 100644 --- a/src/kernels/intelcpu/split.cc +++ b/src/kernels/intelcpu/split.cc @@ -10,7 +10,7 @@ class MklSplit : public MklKernelWithoutConfig { auto context = dynamic_cast(_context); std::vector dims; - for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) + for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) dims.push_back(op->getInputs(0)->getDims()[i]); // create src md diff --git a/src/operators/G2BMM.cc b/src/operators/G2BMM.cc index aafc829e..499c1f77 100644 --- a/src/operators/G2BMM.cc +++ b/src/operators/G2BMM.cc @@ -23,16 +23,11 @@ string G2BMMObj::toString() const { optional> G2BMMObj::inferShape(const TensorVec &inputs) 
const {
     auto A = inputs[0], B = inputs[1];
-    if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
-        return {};
-    if (!(A->getDims()[0] == B->getDims()[0]))
-        return {};
-    if (!(A->getDims()[1] == B->getDims()[1]))
-        return {};
-    if (!(A->getDims()[2] == B->getDims()[2]))
-        return {};
-    if (width < 0)
-        return {};
+    IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
+    IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
+    IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
+    IT_ASSERT(A->getDims()[2] == B->getDims()[2]);
+    IT_ASSERT(width >= 0);
     int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1);
     return {{{b, m, n}}};
 }
diff --git a/src/operators/GBMM.cc b/src/operators/GBMM.cc
index ab034472..d51128fa 100644
--- a/src/operators/GBMM.cc
+++ b/src/operators/GBMM.cc
@@ -24,14 +24,10 @@ string GBMMObj::toString() const {
 optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const {
     auto A = inputs[0], B = inputs[1];
-    if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
-        return {};
-    if (!(A->getDims()[0] == B->getDims()[0]))
-        return {};
-    if (!(A->getDims()[1] == B->getDims()[1]))
-        return {};
-    if (A->getDims()[2] % 2 == 0)
-        return {};
+    IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
+    IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
+    IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
+    IT_ASSERT(A->getDims()[2] % 2 != 0);
     int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]);
     return {{{b, m, k}}};
 }
diff --git a/src/operators/batch_norm.cc b/src/operators/batch_norm.cc
index 69271377..ba68cbfd 100644
--- a/src/operators/batch_norm.cc
+++ b/src/operators/batch_norm.cc
@@ -21,9 +21,10 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
     auto scale = inputs[3];
     auto bias = inputs[4];
     auto c = std::vector<int>{input->getDims()[1]};
-    if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c ||
-        bias->getDims() != c)
-        return {};
+    IT_ASSERT(mean->getRank() == 1 && mean->getDims() == c);
+    IT_ASSERT(var->getRank() == 1 && var->getDims() == c);
+    IT_ASSERT(scale->getRank() == 1 && scale->getDims() == c);
+    IT_ASSERT(bias->getRank() == 1 && bias->getDims() == c);
     return {{input->getDims()}};
 }
diff --git a/src/operators/concat.cc b/src/operators/concat.cc
index 8f8a9f7b..78e30dad 100644
--- a/src/operators/concat.cc
+++ b/src/operators/concat.cc
@@ -1,28 +1,29 @@
 #include "operators/concat.h"
+#include "utils/operator_utils.h"

 namespace infini {
 ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
     : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
+    int rank = inputs[0]->getRank();
+    dim = get_real_axis(dim, rank);
     IT_ASSERT(checkValid(graph));
 }

 optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
     IT_ASSERT(inputs.size() > 1);
     Shape dims = inputs[0]->getDims();
+    auto rank = inputs[0]->getRank();
     ShapeElem n = dims.at(dim);
     for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) {
         auto input = *itr;
         auto iDims = input->getDims();
-        if (dims.size() != iDims.size())
-            return {};
-        int nDims = dims.size();
-        for (auto i = 0; i < nDims; i++) {
+        IT_ASSERT(rank == input->getRank());
+        for (auto i = 0; i < (int)rank; i++) {
             if (i == dim) {
                 n += iDims.at(i);
                 continue;
             }
-            if (iDims.at(i) != dims.at(i))
-                return {};
+            IT_ASSERT(iDims.at(i) == dims.at(i));
         }
     }
     dims[dim] = n;
diff --git a/src/operators/conv.cc b/src/operators/conv.cc
index f7ab1944..8c3eafb4 100644
--- a/src/operators/conv.cc
+++ b/src/operators/conv.cc
@@ -93,8 +93,7 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
     int on
= n, oc = f; int oh = 0, ow = 0; // For NCHW+FCRS layout, C of input is divisable by C of weight - if (input->getDims()[1] % weight->getDims()[1] != 0) - return {}; + IT_ASSERT(input->getDims()[1] % weight->getDims()[1] == 0); // Set padding size if (padding == PaddingMode::Other) { oh = (h - (r - sh) * dh + ph * 2) / sh; @@ -151,8 +150,7 @@ ConvTransposed2dObj::inferShape(const TensorVec &inputs) const { auto c = weight->getDims()[1]; auto r = weight->getDims()[2]; auto s = weight->getDims()[3]; - if (f != weight->getDims()[0]) - return {}; + IT_ASSERT(f == weight->getDims()[0]); int on = n, oc = c * group; int oh = 0, ow = 0; @@ -232,8 +230,7 @@ ConvBackwardFilterObj::inferShape(const TensorVec &inputs) const { int on = n, oc = f; int oh = 0, ow = 0; // For NCHW+FCRS layout, C of input is divisable by C of weight - if (inputX->getDims()[1] % diffY->getDims()[1] != 0) - return {}; + IT_ASSERT(inputX->getDims()[1] % diffY->getDims()[1] == 0); // Set padding size if (padding == PaddingMode::Other) { oh = (h - (r - sh) * dh + ph * 2) / sh; diff --git a/src/operators/det.cc b/src/operators/det.cc index 0c84b5b7..473982cd 100644 --- a/src/operators/det.cc +++ b/src/operators/det.cc @@ -9,8 +9,8 @@ DetObj::DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode) optional> DetObj::inferShape(const TensorVec &inputs) const { const auto A = inputs[0]; auto input = A->getDims(); - int length = input.size(); - if (length == 2) { + int rank = A->getRank(); + if (rank == 2) { std::vector output = {1}; return {{output}}; } else { diff --git a/src/operators/element_wise.cc b/src/operators/element_wise.cc index 008c6872..d86ccccf 100644 --- a/src/operators/element_wise.cc +++ b/src/operators/element_wise.cc @@ -1,4 +1,5 @@ #include "operators/element_wise.h" +#include "utils/operator_utils.h" namespace infini { ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, @@ -9,31 +10,8 @@ ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, optional> ElementWiseObj::inferShape(const TensorVec &inputs) const { - // For now,we only process the same dims here, broardcast will be considered - // in the opt layer. 
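+    // Replaced by infer_broadcast (utils/operator_utils.h): shapes are
+    // right-aligned, padded with leading 1s, and each dimension pair must
+    // either match or contain a 1.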
     const auto A = inputs[0], B = inputs[1];
-    int max_len = std::max(A->getDims().size(), B->getDims().size());
-    std::vector<int> A_(max_len, 1);
-    std::vector<int> B_(max_len, 1);
-    std::vector<int> res(max_len, 1);
-    memcpy(A_.data() + max_len - A->getDims().size(), A->getDims().data(),
-           A->getDims().size() * sizeof(int));
-    memcpy(B_.data() + max_len - B->getDims().size(), B->getDims().data(),
-           B->getDims().size() * sizeof(int));
-    // std::copy(A->getDims().begin(), A->getDims().end(), A_.begin() + (max_len
-    // - A->getDims().size())); std::copy(B->getDims().begin(),
-    // B->getDims().end(), B_.begin() + (max_len - B->getDims().size()));
-    // std::copy(A->getDims().rbegin(), A->getDims().rend(), A_.rbegin());
-    // std::copy(B->getDims().rbegin(), B->getDims().rend(), B_.rbegin());
-
-    for (int i = 0; i < max_len; ++i) {
-        if (A_[i] == B_[i] || (A_[i] == 1 || B_[i] == 1)) {
-            res[i] = std::max(A_[i], B_[i]);
-        } else {
-            return {};
-        }
-    }
-
+    auto res = infer_broadcast(A->getDims(), B->getDims());
     return {{res}};
 }
@@ -69,9 +47,8 @@ MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
 optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const {
     const auto A = inputs[0], B = inputs[1];
-    if (A->getDims().size() != B->getDims().size() ||
-        A->getDims() != B->getDims())
-        return {};
+    IT_ASSERT(A->getRank() == B->getRank());
+    IT_ASSERT(A->getDims() == B->getDims());

     if (reductionMode == None) {
         return {{A->getDims()}};
diff --git a/src/operators/extend.cc b/src/operators/extend.cc
index 13efcfcf..e8587dbb 100644
--- a/src/operators/extend.cc
+++ b/src/operators/extend.cc
@@ -1,16 +1,18 @@
 #include "operators/extend.h"
+#include "utils/operator_utils.h"

 namespace infini {
 ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
                      int num)
     : OperatorObj(OpType::Extend, {input}, {output}), dim(dim), num(num) {
+    int rank = input->getRank();
+    dim = get_real_axis(dim, rank);
     IT_ASSERT(checkValid(graph));
 }

 optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) const {
     auto ret = inputs[0]->getDims();
-    IT_ASSERT((size_t)dim < ret.size());
     ret[dim] = ret[dim] * (num + 1);
     return {{ret}};
 }
diff --git a/src/operators/gather.cc b/src/operators/gather.cc
index 7b54701d..0441b6ba 100644
--- a/src/operators/gather.cc
+++ b/src/operators/gather.cc
@@ -1,9 +1,12 @@
 #include "operators/gather.h"
+#include "utils/operator_utils.h"

 namespace infini {
 GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
                      Tensor output, int axis)
     : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
+    int rank = input->getRank();
+    axis = get_real_axis(axis, rank);
     IT_ASSERT(checkValid(graph));
 }

@@ -11,12 +14,6 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
     auto dims0 = inputs[0]->getDims();
     auto dims1 = inputs[1]->getDims();
-    if (axis < 0)
-        IT_TODO_HALT();
-
-    if ((size_t)axis >= dims0.size())
-        return {};
-
     IT_ASSERT(CheckIndexValid());

     Shape dim = dims0;
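Several inferShape rewrites above now funnel through infer_broadcast (defined later in this patch, in src/utils/operator_utils.cc), which implements the usual ONNX/numpy multidirectional broadcasting: shapes are right-aligned, missing leading dims act as 1, and each aligned pair must be equal or contain a 1. Illustrative cases, assuming that definition:

    // {2, 3, 4} vs {4}    -> {2, 3, 4}   (B right-aligned, padded to {1, 1, 4})
    // {2, 1, 4} vs {3, 1} -> {2, 3, 4}   (1s stretch on either side)
    // {2, 3, 4} vs {5}    -> IT_ASSERT fires: 4 vs 5 is incompatible
    // {}        vs {}     -> {}          (both scalars, empty shape)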
diff --git a/src/operators/matmul.cc b/src/operators/matmul.cc
index b26b15ee..963dd591 100644
--- a/src/operators/matmul.cc
+++ b/src/operators/matmul.cc
@@ -1,4 +1,6 @@
 #include "operators/matmul.h"
+#include "utils/operator_utils.h"
+#include <numeric>

 namespace infini {

@@ -9,25 +11,23 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
       transA(transA), transB(transB), act(act), b(1) {
     auto shape_a = A->getDims();
     auto shape_b = B->getDims();
-    int dimA = shape_a.size(), dimB = shape_b.size();
-    IT_ASSERT(dimA >= 2 && dimB >= 2);
-
-    b = 1;
-    if (dimA <= 3 && dimB <= 3) {
-        int b1 = dimA == 2 ? 1 : A->getDims()[0];
-        int b2 = dimB == 2 ? 1 : B->getDims()[0];
-
-        b = std::max(b1, b2);
+    int rankA = A->getRank();
+    int rankB = B->getRank();
+    IT_ASSERT(rankA >= 2 && rankB >= 2);
+    Shape shape_a1(shape_a.begin(), shape_a.begin() + (rankA - 2));
+    Shape shape_b1(shape_b.begin(), shape_b.begin() + (rankB - 2));
+    auto ret = infer_broadcast(shape_a1, shape_b1);
+    if (ret.empty()) {
+        b = 1;
     } else {
-        IT_ASSERT_TODO(dimA == dimB);
-        for (size_t i = 0; i < shape_a.size() - 2; ++i) {
-            IT_ASSERT_TODO(shape_a[i] == shape_b[i]);
-            b *= shape_a[i];
-        }
+        b = std::accumulate(ret.begin(), ret.end(), 1, std::multiplies<int>());
     }
+    auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
+    auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1);
+    IT_ASSERT(kA == kB);
     m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
     n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
-    k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
+    k = kA;
     IT_ASSERT(checkValid(graph));
 }

@@ -42,43 +42,16 @@ string MatmulObj::toString() const {
 optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
     auto A = inputs[0], B = inputs[1];
-    int dimA = A->getDims().size(), dimB = B->getDims().size();
-
-    if (dimA > 3 || dimB > 3) {
-        // no broadcast
-        auto shape_a = inputs[0]->getDims();
-        auto it = shape_a.rbegin();
-        *it++ = n;
-        *it++ = m;
-        return {{std::move(shape_a)}};
-    }
-
-    int b1 = dimA == 2 ? 1 : A->getDims()[0];
-    int b2 = dimB == 2 ? 1 : B->getDims()[0];
-
-    int b = std::max(b1, b2);
-    int m = transA ? A->getDims()[dimA - 1] : A->getDims()[dimA - 2];
-    int n = transB ? B->getDims()[dimB - 2] : B->getDims()[dimB - 1];
-    int kA = transA ? A->getDims()[dimA - 2] : A->getDims()[dimA - 1];
-    int kB = transB ? B->getDims()[dimB - 1] : B->getDims()[dimB - 2];
-
-    if ((dimA != 2 && dimA != 3) || (dimB != 2 && dimB != 3)) {
-        printf("Bad input dim: dimA = %d, dimB = %d\n", dimA, dimB);
-        return {};
-    }
-    if (b1 != 1 && b2 != 1 && b1 != b2) {
-        printf("Bad batch size b1 = %d, b2 = %d\n", b1, b2);
-        return {};
-    }
-    if (kA != kB) {
-        printf("Bad K: kA = %d, kB = %d\n", kA, kB);
-        return {};
-    }
-    if (dimA == 2 && dimB == 2) {
-        return {{{m, n}}};
-    } else {
-        return {{{b, m, n}}};
-    }
+    auto shapeA = A->getDims();
+    auto shapeB = B->getDims();
+    int rankA = A->getRank();
+    int rankB = B->getRank();
+    Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2));
+    Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2));
+    Shape ret = infer_broadcast(shapeA1, shapeB1);
+    ret.emplace_back(m);
+    ret.emplace_back(n);
+    return {{ret}};
 }

 vector<int> MatmulObj::getWorkloadVector() const {
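The rewritten constructor derives the stacked batch count b from the broadcast batch shape, so it must be the product of those dims (std::accumulate with std::multiplies, not a sum). A worked case matching one of the new tests below, assuming A = {2, 3, 5, 4} with transA and B = {1, 3, 2, 5} with transB:

    // batch dims: {2, 3} broadcast {1, 3} -> {2, 3}, so b = 2 * 3 = 6
    // kA = 5 (second-to-last of A under transA), kB = 5 (last of B under transB)
    // m = 4, n = 2 -> inferShape returns {2, 3, 4, 2}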
diff --git a/src/operators/pad.cc b/src/operators/pad.cc
index 3e0ce94c..b870e449 100644
--- a/src/operators/pad.cc
+++ b/src/operators/pad.cc
@@ -9,7 +9,7 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
     else {
         auto nAxis = (*axes).size();
         IT_ASSERT(_pads.size() == nAxis * 2);
-        auto nDims = input->getDims().size();
+        auto nDims = input->getRank();
         pads = vector<int>(nDims * 2, 0);

         for (size_t i = 0; i < nAxis; ++i) {
@@ -24,13 +24,11 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,

 optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) const {
     auto dims = inputs[0]->getDims();
-    int nDims = dims.size();
-    if (nDims * 2 != (int)pads.size())
-        return {};
-    for (int i = 0; i < nDims; ++i) {
-        if (pads[i] < 0 || pads[i + nDims] < 0)
-            return {};
-        dims[i] += pads[i] + pads[i + nDims];
+    int rank = inputs[0]->getRank();
+    IT_ASSERT(rank * 2 == (int)pads.size());
+    for (int i = 0; i < rank; ++i) {
+        IT_ASSERT(pads[i] >= 0 && pads[i + rank] >= 0);
+        dims[i] += pads[i] + pads[i + rank];
     }

     return {{dims}};
diff --git a/src/operators/pooling.cc b/src/operators/pooling.cc
index 0061bf6f..d7153699 100644
--- a/src/operators/pooling.cc
+++ b/src/operators/pooling.cc
@@ -16,13 +16,13 @@ PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
 optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
     const auto &input = inputs[0];
-    auto h = input->getDims()[input->getDims().size() - 2],
-         w = input->getDims()[input->getDims().size() - 1];
+    auto h = input->getDims()[input->getRank() - 2],
+         w = input->getDims()[input->getRank() - 1];

     int oh = (h - (kh - sh) + ph * 2) / sh;
     int ow = (w - (kw - sw) + pw * 2) / sw;
     auto ret = input->getDims();
-    ret[input->getDims().size() - 2] = oh;
-    ret[input->getDims().size() - 1] = ow;
+    ret[input->getRank() - 2] = oh;
+    ret[input->getRank() - 1] = ow;
     return {{ret}};
 }
diff --git a/src/operators/reduce_mean.cc b/src/operators/reduce_mean.cc
index 633e6b86..e3a5ec97 100644
--- a/src/operators/reduce_mean.cc
+++ b/src/operators/reduce_mean.cc
@@ -1,15 +1,14 @@
 #include "operators/reduce_mean.h"
+#include "utils/operator_utils.h"

 namespace infini {
 ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
                              const optional<vector<int>> &_axes, bool keepDims)
     : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
-    const auto size = input->getDims().size();
+    const auto size = input->getRank();
     if (_axes) {
         for (auto idx : *_axes) {
-            if (idx < 0)
-                IT_TODO_HALT();
-            IT_ASSERT((size_t)idx < size);
+            idx = get_real_axis(idx, size);
             axes.emplace(idx);
         }
     } else
@@ -25,6 +24,7 @@ bool ReduceMeanObj::isReduced(int idx) const {
 optional<vector<Shape>> ReduceMeanObj::inferShape(const TensorVec &inputs) const {
     auto dims = inputs[0]->getDims();
+    auto rank = inputs[0]->getRank();

     if (keepDims) {
         Shape ret = dims;
@@ -33,7 +33,7 @@ ReduceMeanObj::inferShape(const TensorVec &inputs) const {
         return {{ret}};
     } else {
         Shape ret;
-        for (size_t i = 0; i < dims.size(); ++i) {
+        for (size_t i = 0; i < rank; ++i) {
             if (!isReduced(i))
                 ret.emplace_back(dims[i]);
         }
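With get_real_axis in the constructor, negative reduce axes are now accepted, and the keepDims split above reduces to: reduced axes become 1 when keepDims is set and are dropped otherwise. For example, assuming an input of shape {2, 3, 4}:

    // axes = {1} (or {-2}), keepDims = true  -> {2, 1, 4}
    // axes = {1} (or {-2}), keepDims = false -> {2, 4}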
diff --git a/src/operators/reshape.cc b/src/operators/reshape.cc
index 7110ab90..df216601 100644
--- a/src/operators/reshape.cc
+++ b/src/operators/reshape.cc
@@ -1,4 +1,5 @@
 #include "operators/reshape.h"
+#include "utils/operator_utils.h"

 namespace infini {
 ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
@@ -8,10 +9,10 @@ ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)

 optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
     size_t size = 1;
-    for (size_t i = 0; i < dims.size(); ++i)
+    for (size_t i = 0; i < dims.size(); ++i) {
         size *= dims.at(i);
-    if (size != inputs[0]->size())
-        return {};
+    }
+    IT_ASSERT(size == inputs[0]->size());

     return {{dims}};
 }
@@ -41,22 +42,18 @@ vector<int> ReshapeObj::getOpAttrVector() const {

 FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
     : OperatorObj(OpType::Flatten, {input}, {output}) {
-    if (_axis >= 0 && (size_t)_axis < input->getDims().size())
-        axis = _axis;
-    else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size())
-        axis = _axis + input->getDims().size();
-    else
-        IT_ASSERT(0);
+    int rank = input->getRank();
+    axis = get_real_axis(_axis, rank);
     IT_ASSERT(checkValid(graph));
 }

 optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) const {
     int sizeB = 1, sizeE = 1;
     auto dims = getInputs(0)->getDims();
-    int ndim = dims.size();
-    for (int i = 0; i < ndim; ++i)
+    int rank = getInputs(0)->getRank();
+    for (int i = 0; i < rank; ++i) {
         ((i < axis) ? sizeB : sizeE) *= dims.at(i);
-
+    }
     return {{{sizeB, sizeE}}};
 }
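FlattenObj now normalizes the axis the same way, then folds every dim before the axis into sizeB and the rest into sizeE. A quick worked case, assuming input shape {2, 3, 4, 5}:

    // axis =  2 -> sizeB = 2 * 3 = 6, sizeE = 4 * 5 = 20 -> {6, 20}
    // axis = -1 -> normalized to 3                       -> {24, 5}
    // axis =  0 -> sizeB stays 1                         -> {1, 120}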
diff --git a/src/operators/resize.cc b/src/operators/resize.cc
index 63998de0..11933414 100644
--- a/src/operators/resize.cc
+++ b/src/operators/resize.cc
@@ -45,11 +45,11 @@ void ResizeObj::init(const Tensor &input, const Tensor &sizes,
     if (ECoordinateTransMode::tfCropAndResize == coMode) {
         IT_ASSERT(nullptr != roi);
         inputs.push_back(roi);
-        IT_ASSERT(roi->getDims().size() == 1);
+        IT_ASSERT(roi->getRank() == 1);
         IT_ASSERT((size_t)roi->getDims()[0] == this->axes.size() * 2);

         // init roi_start = 0;roi_end =1
-        size_t nDims = input->getDims().size();
+        size_t nDims = input->getRank();
         for (size_t i = 0; i < nDims; ++i) {
             this->roi.emplace_back(0);
         }
@@ -75,24 +75,26 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
                             const std::optional<vector<int>> &axes) {
     IT_ASSERT(sizes != nullptr);
     size_t size = sizes->getDims()[0];
-    IT_ASSERT(size == input->getDims().size() ||
+    IT_ASSERT(size == input->getRank() ||
               (axes != std::nullopt && size == (*axes).size()));

-    if (axes == std::nullopt)
-        for (size_t i = 0; i < input->getDims().size(); ++i)
+    if (axes == std::nullopt) {
+        for (size_t i = 0; i < input->getRank(); ++i) {
             this->axes.emplace_back(i);
-    else
+        }
+    } else {
         // check axes
         for (size_t i = 0; i < (*axes).size(); ++i) {
             auto val = (*axes)[i];
-            if (val < 0)
+            if (val < 0) {
                 IT_TODO_HALT();
-            IT_ASSERT((size_t)val < inputs[0]->getDims().size());
+            }
+            IT_ASSERT((size_t)val < inputs[0]->getRank());
             this->axes.emplace_back(val);
         }
-
+    }
     // init this->scales
-    for (size_t i = 0; i < input->getDims().size(); ++i) {
+    for (size_t i = 0; i < input->getRank(); ++i) {
         this->scales.emplace_back(1);
     }

@@ -109,9 +111,10 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
     int n = this->axes.size();
     switch (ratioPolicy) {
     case EKeepAspectRatioPolicy::stretch:
-        for (int i = 0; i < n; ++i)
+        for (int i = 0; i < n; ++i) {
             scales[this->axes[i]] =
                 (float)data[i] / (float)inDims[this->axes[i]];
+        }
         break;
     case EKeepAspectRatioPolicy::notLarger: {
         float scale = (float)data[0] / (float)inDims[this->axes[0]];
@@ -119,8 +122,9 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
             auto tmp = (float)data[i] / (float)inDims[this->axes[i]];
             scale = scale < tmp ? scale : tmp;
         }
-        for (int i = 0; i < n; ++i)
+        for (int i = 0; i < n; ++i) {
             scales[this->axes[i]] = scale;
+        }
         break;
     }
     case EKeepAspectRatioPolicy::notSmaller: {
@@ -129,8 +133,9 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
             auto tmp = (float)data[i] / (float)inDims[this->axes[i]];
             scale = scale > tmp ? scale : tmp;
         }
-        for (int i = 0; i < n; ++i)
+        for (int i = 0; i < n; ++i) {
             scales[this->axes[i]] = scale;
+        }
         break;
     }
     default:
@@ -142,7 +147,7 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales,
                              const std::optional<vector<int>> &axes) {
     IT_ASSERT(scales != nullptr);
     size_t size = scales->getDims()[0];
-    IT_ASSERT(size == input->getDims().size() ||
+    IT_ASSERT(size == input->getRank() ||
              (axes != std::nullopt && size == (*axes).size()));

     // copy scales data to host.
@@ -155,27 +160,29 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales,
         (void *)data, scales->getRawDataPtr<void *>(), scales->getBytes());

     // init this->scales
-    for (size_t i = 0; i < input->getDims().size(); ++i) {
+    for (size_t i = 0; i < input->getRank(); ++i) {
         this->scales.emplace_back(1);
     }

-    if (axes == std::nullopt)
-        for (size_t i = 0; i < input->getDims().size(); ++i) {
+    if (axes == std::nullopt) {
+        for (size_t i = 0; i < input->getRank(); ++i) {
             this->axes.emplace_back(i);
             IT_ASSERT(data[i] > 0);
             this->scales[i] = data[i];
         }
-    else
+    } else {
         // check axes
         for (size_t i = 0; i < (*axes).size(); ++i) {
             auto val = (*axes)[i];
-            if (val < 0)
+            if (val < 0) {
                 IT_TODO_HALT();
-            IT_ASSERT((size_t)val < inputs[0]->getDims().size());
+            }
+            IT_ASSERT((size_t)val < inputs[0]->getRank());
             this->axes.emplace_back(val);
             IT_ASSERT(data[i] > 0);
             this->scales[val] = data[i];
         }
+    }
 }

 vector<DataType> ResizeObj::inferDataType(const TensorVec &inputs) const {
@@ -202,8 +209,8 @@ float ResizeObj::round_int(float x) const {
 optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) const {
     auto inDims = inputs[0]->getDims();
     Shape ret = inDims;
-    int nDim = inDims.size();
-    for (int i = 0; i < nDim; ++i) {
+    int rank = inputs[0]->getRank();
+    for (int i = 0; i < rank; ++i) {
         int size = round_int(scales[i] * inDims[i]);
         ret[i] = size;
     }
@@ -217,12 +224,14 @@ std::string ResizeObj::toString() const {
        << "[" << getGuid() << "]";
     os << "(";
     os << vecToString(inputs[0]->getDims()) << ",";
-    if (inputs.size() == 3)
+    if (inputs.size() == 3) {
         os << "roi=" << vecToString(inputs[2]->getDims()) << ",";
-    if (isResizeBySizes())
+    }
+    if (isResizeBySizes()) {
         os << "sizes=" << vecToString(inputs[1]->getDims()) << ",";
-    else
+    } else {
         os << "scales=" << vecToString(inputs[1]->getDims()) << ",";
+    }
     os << "axes=" << vecToString(axes) << ",";
     os << "coMode=" << enum_to_underlying(coMode) << ",";
     os << "nearestMode=" << enum_to_underlying(nearestMode) << ",";
@@ -230,16 +239,18 @@ std::string ResizeObj::toString() const {
     os << "input=" << inputs[0]->getGuid() << ",";
     os << inputs[1]->getGuid() << ",";
-    if (inputs.size() == 3)
+    if (inputs.size() == 3) {
         os << inputs[2]->getGuid() << ",";
+    }
     os << "output=" << outputs[0]->getGuid() << ")";
     return os.str();
 }

 vector<int> ResizeObj::getWorkloadVector() const {
     vector<int> ret = inputs[0]->getDims();
-    for (size_t i = 0; i < outputs[0]->getDims().size(); ++i)
+    for (size_t i = 0; i < outputs[0]->getRank(); ++i) {
         ret.emplace_back(outputs[0]->getDims()[i]);
+    }
     // ratioPolicy only affects the output shape, so it is not needed here.
     ret.emplace_back(enum_to_underlying(coMode));
diff --git a/src/operators/softmax.cc b/src/operators/softmax.cc
index 2fa2ccf6..f9dde777 100644
--- a/src/operators/softmax.cc
+++ b/src/operators/softmax.cc
@@ -1,15 +1,12 @@
 #include "operators/softmax.h"
+#include "utils/operator_utils.h"

 namespace infini {
 SoftmaxObj::SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
     : OperatorObj(OpType::Softmax, {input}, {output}) {
-    if (_axis >= 0 && (size_t)_axis < input->getDims().size())
-        axis = _axis;
-    else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size())
-        axis = _axis + input->getDims().size();
-    else
-        IT_ASSERT(0);
+    int rank = input->getRank();
+    axis = get_real_axis(_axis, rank);
     IT_ASSERT(checkValid(graph));
 }

diff --git a/src/operators/split.cc b/src/operators/split.cc
index 52c6a61d..45eb1804 100644
--- a/src/operators/split.cc
+++ b/src/operators/split.cc
@@ -1,4 +1,5 @@
 #include "operators/split.h"
+#include "utils/operator_utils.h"
 #include <numeric>

 namespace infini {
@@ -7,6 +8,8 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
     : OperatorObj(OpType::Split, {input},
                   ((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
       dim(dim), num(num), ratio({}) {
+    int rank = input->getRank();
+    dim = get_real_axis(dim, rank);
     int dimSize = input->getDims().at(dim);
     int pieceSize = dimSize / num;
     int lastSize = dimSize - pieceSize * num;
@@ -26,6 +29,8 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
     : OperatorObj(OpType::Split, {input},
                   ((!outputs) ? TensorVec{nullptr} : (*outputs))),
       dim(dim), num(-1), ratio(ratio) {
+    int rank = input->getRank();
+    dim = get_real_axis(dim, rank);
     num = ratio.size();
     if (!outputs) {
         TensorVec tmp(num, nullptr);
@@ -35,13 +40,11 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
 }

 optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const {
-    if (num == -1 || ratio.size() == 0)
-        return {};
+    IT_ASSERT(num != -1 && ratio.size() != 0);
     auto inputDims = inputs[0]->getDims();
     int totalSize = inputDims.at(dim);
     int ratioSum = std::accumulate(ratio.begin(), ratio.end(), 0);
-    if (totalSize % ratioSum != 0)
-        return {};
+    IT_ASSERT(totalSize % ratioSum == 0);

     int pieceSize = totalSize / ratioSum;

diff --git a/src/operators/transpose.cc b/src/operators/transpose.cc
index 490dc9e0..9a457647 100644
--- a/src/operators/transpose.cc
+++ b/src/operators/transpose.cc
@@ -4,26 +4,32 @@ namespace infini {
 TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
                            vector<int> permute)
     : OperatorObj(OpType::Transpose, {input}, {output}) {
-    if (permute.size() != 4) {
-        IT_TODO_HALT();
+    auto rank = input->getRank();
+    if (permute.empty()) {
+        for (size_t i = 0; i < rank; ++i) {
+            transposePermute[i] = i;
+        }
+    } else {
+        IT_ASSERT(rank == permute.size());
+        transposePermute = std::move(permute);
     }
-    transposePermute[0] = permute[0];
-    transposePermute[1] = permute[1];
-    transposePermute[2] = permute[2];
-    transposePermute[3] = permute[3];
     IT_ASSERT(checkValid(graph));
 }

 optional<vector<Shape>> TransposeObj::inferShape(const TensorVec &inputs) const {
     const auto A = inputs[0];
-    auto input = A->getDims();
-    auto output = input;
+    auto input_dim = A->getDims();
+    auto output_dim = input_dim;
+    int rank = A->getRank();

-    for (int i = 0; i < 4; ++i) {
-        output[i] = input[transposePermute[i]];
+    for (auto index : transposePermute) {
+        IT_ASSERT(index < rank);
     }
-    return {{output}};
+    for (int i = 0; i < rank; ++i) {
+        output_dim[i] = input_dim[transposePermute[i]];
+    }
+    return {{output_dim}};
 }

 std::string TransposeObj::toString() const {
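Transpose's inference above is the usual gather of input dims through the permutation, output_dim[i] = input_dim[transposePermute[i]], with an empty permute now meaning identity rather than the old hard-coded rank-4 path. For instance:

    // input {2, 3, 4},    permute {0, 2, 1} -> output {2, 4, 3}
    // input {1, 2, 3, 4}, permute {}        -> output {1, 2, 3, 4} (identity)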
diff --git a/src/operators/unary.cc b/src/operators/unary.cc
index 7436ac9f..7f98940a 100644
--- a/src/operators/unary.cc
+++ b/src/operators/unary.cc
@@ -183,46 +183,54 @@ vector<int> CastObj::getOpAttrVector() const { return {type.underlying()}; }

 DataType CastObj::getOutputDataType() const {
     switch (castType) {
-    case CastObj::Float2Int64:
+    case CastType::Float2Float16:
+        return DataType::Float16;
+    case CastType::Float2Int64:
         return DataType::Int64;
-    case CastObj::Float2Int32:
+    case CastType::Float2Int32:
         return DataType::Int32;
-    case CastObj::Float2Int16:
+    case CastType::Float2Int16:
         return DataType::Int16;
-    case CastObj::Float2Int8:
+    case CastType::Float2Int8:
         return DataType::Int8;
-    case CastObj::Int322Float:
+    case CastType::Int322Float:
         return DataType::Float32;
-    case CastObj::Int322Int8:
+    case CastType::Int322Int8:
         return DataType::Int8;
-    case CastObj::Int322Int16:
+    case CastType::Int322Int16:
         return DataType::Int16;
-    case CastObj::Int162Float:
+    case CastType::Int162Float:
         return DataType::Float32;
-    case CastObj::Int162Int32:
+    case CastType::Int162Int32:
         return DataType::Int32;
-    case CastObj::Int82Float:
+    case CastType::Int82Float:
         return DataType::Float32;
-    case CastObj::Int82Int16:
+    case CastType::Int82Int16:
         return DataType::Int16;
-    case CastObj::Int82Int32:
+    case CastType::Int82Int32:
         return DataType::Int32;
-    case CastObj::Uint82Float:
+    case CastType::Uint82Float:
         return DataType::Float32;
-    case CastObj::Uint82Int32:
+    case CastType::Uint82Int32:
         return DataType::Int32;
-    case CastObj::Uint82Int64:
+    case CastType::Uint82Int64:
         return DataType::Int64;
-    case CastObj::Int322Int64:
+    case CastType::Int322Int64:
         return DataType::Int64;
-    case CastObj::Int642Int32:
+    case CastType::Int642Int32:
         return DataType::Int32;
-    case CastObj::Int642Uint32:
+    case CastType::Int642Uint32:
         return DataType::UInt32;
-    case CastObj::Int642Float:
+    case CastType::Int642Float:
         return DataType::Float32;
-    case CastObj::Uint322Int64:
+    case CastType::Uint322Int64:
         return DataType::Int64;
+    case CastType::Float162Float:
+        return DataType::Float32;
+    case CastType::BFloat162Float:
+        return DataType::Float32;
+    case CastType::Float2BFloat16:
+        return DataType::BFloat16;
     default:
         IT_TODO_HALT();
     }
@@ -234,7 +242,7 @@ ShapeObj::ShapeObj(GraphObj *graph, Tensor input, Tensor output)
 }

 optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) const {
-    return {{{static_cast<int>(inputs[0]->getDims().size())}}};
+    return {{{static_cast<int>(inputs[0]->getRank())}}};
 }

 std::string ShapeObj::toString() const {
diff --git a/src/utils/data_convert.cc b/src/utils/data_convert.cc
index 28b0c923..3dee5f1b 100644
--- a/src/utils/data_convert.cc
+++ b/src/utils/data_convert.cc
@@ -27,4 +27,17 @@ float fp16_to_float(const uint16_t x) {
     u.u32 = r;
     return u.f32;
 }
+
+uint16_t float_to_bfp16(const float x) {
+    Uf32 u;
+    u.f32 = x;
+    return u.u32 >> 16;
+}
+
+float bfp16_to_fp32(const uint16_t x) {
+    Uf32 u;
+    u.u32 = x << 16;
+    return u.f32;
+}
+
 } // namespace infini
diff --git a/src/utils/dataloader.cc b/src/utils/dataloader.cc
index 73ce34fc..de4e04ac 100644
--- a/src/utils/dataloader.cc
+++ b/src/utils/dataloader.cc
@@ -12,7 +12,7 @@ void saveTensorData(TensorObj *tensor, std::string file_path) {
 #ifdef TENSOR_PROTOBUF
     data::Tensor temp;
     temp.set_id("tensor_id");
-    for (size_t i = 0; i < tensor->getDims().size(); ++i) {
+    for (size_t i = 0; i < tensor->getRank(); ++i) {
         temp.add_shape(tensor->getDims()[i]);
     }
     temp.set_layout(data::LAYOUT_NHWC);
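The bfloat16 helpers above keep only the top 16 bits of an IEEE-754 float (same sign and exponent layout, 7 mantissa bits), so float_to_bfp16 truncates rather than rounding to nearest even. A minimal self-contained check of the round-trip, with the Uf32 punning replaced by memcpy:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    int main() {
        float f = 1.0f; // bit pattern 0x3F800000
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        uint16_t b = static_cast<uint16_t>(u >> 16); // 0x3F80, as float_to_bfp16 does
        uint32_t r = static_cast<uint32_t>(b) << 16; // low 16 mantissa bits zeroed
        float g;
        std::memcpy(&g, &r, sizeof(g));
        assert(g == 1.0f); // values representable in bfloat16 round-trip exactly
        return 0;
    }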
diff --git a/src/utils/operator_utils.cc b/src/utils/operator_utils.cc
new file mode 100644
index 00000000..2d26fee5
--- /dev/null
+++ b/src/utils/operator_utils.cc
@@ -0,0 +1,44 @@
+#include "utils/operator_utils.h"
+
+namespace infini {
+
+Shape infer_broadcast(const Shape &A, const Shape &B) {
+    if (A.empty() && B.empty()) {
+        return {};
+    }
+    auto A_ = A;
+    auto B_ = B;
+    int rankA = A.size();
+    int rankB = B.size();
+    int rank = std::max(rankA, rankB);
+    if (rankA < rank) {
+        for (int i = 0; i < rank - rankA; ++i) {
+            A_.insert(A_.begin(), 1);
+        }
+    }
+    if (rankB < rank) {
+        for (int i = 0; i < rank - rankB; ++i) {
+            B_.insert(B_.begin(), 1);
+        }
+    }
+    Shape ret;
+    for (int i = 0; i < rank; ++i) {
+        IT_ASSERT(A_[i] == B_[i] || A_[i] == 1 || B_[i] == 1);
+        auto shapeEle = std::max(A_[i], B_[i]);
+        ret.emplace_back(shapeEle);
+    }
+    return ret;
+}
+
+int get_real_axis(const int &axis, const int &rank) {
+    IT_ASSERT(rank >= 1);
+    IT_ASSERT(axis >= -rank && axis <= (rank - 1));
+    int newAxis;
+    if (axis < 0) {
+        newAxis = rank + axis;
+    } else {
+        newAxis = axis;
+    }
+    return newAxis;
+}
+} // namespace infini
diff --git a/test/core/test_hash.cc b/test/core/test_hash.cc
index c6098aab..2ec4aedd 100644
--- a/test/core/test_hash.cc
+++ b/test/core/test_hash.cc
@@ -28,4 +28,4 @@ TEST(Hash, OperatorHash) {
     EXPECT_NE(key1.hash, key2.hash);
 }

-} // namespace infini
\ No newline at end of file
+} // namespace infini
diff --git a/test/operators/test_matmul.cc b/test/operators/test_matmul.cc
index 22d07a1a..6d3e655d 100644
--- a/test/operators/test_matmul.cc
+++ b/test/operators/test_matmul.cc
@@ -27,6 +27,30 @@ TEST(Matmul, ShapeInference) {
         auto C = matmul->getOutputs()[0];
         EXPECT_EQ(C->getDims(), (Shape{3, 4, 2}));
     }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        auto A = g->addTensor(Shape{1, 2, 3, 5});
+        auto B = g->addTensor(Shape{1, 1, 5, 2});
+        auto matmul = g->addOp<MatmulObj>(A, B, nullptr);
+        auto C = matmul->getOutputs()[0];
+        EXPECT_EQ(C->getDims(), (Shape{1, 2, 3, 2}));
+    }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        auto A = g->addTensor(Shape{2, 3, 5, 4});
+        auto B = g->addTensor(Shape{1, 3, 5, 2});
+        auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, false);
+        auto C = matmul->getOutputs()[0];
+        EXPECT_EQ(C->getDims(), (Shape{2, 3, 4, 2}));
+    }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        auto A = g->addTensor(Shape{2, 3, 5, 4});
+        auto B = g->addTensor(Shape{1, 3, 2, 5});
+        auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, true);
+        auto C = matmul->getOutputs()[0];
+        EXPECT_EQ(C->getDims(), (Shape{2, 3, 4, 2}));
+    }
 }

 }; // namespace infini
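The two helpers introduced above in src/utils/operator_utils.cc are the single point of truth the operator rewrites lean on; in particular, get_real_axis accepts any axis in [-rank, rank-1] and maps negatives ONNX-style. A few expected values, assuming those definitions:

    // get_real_axis(-1, 4) == 3      get_real_axis(2, 4) == 2
    // get_real_axis(-4, 4) == 0      get_real_axis(4, 4) -> IT_ASSERT fires
    // infer_broadcast({2, 3}, {1, 3}) == Shape{2, 3}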
diff --git a/test/operators/test_transpose.cc b/test/operators/test_transpose.cc
new file mode 100644
index 00000000..1c12b79f
--- /dev/null
+++ b/test/operators/test_transpose.cc
@@ -0,0 +1,32 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "operators/transpose.h"
+
+#include "test.h"
+
+namespace infini {
+
+TEST(Transpose, ShapeInference) {
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({1, 2, 3, 4}, DataType::Float32);
+        auto op = g->addOp<TransposeObj>(i, nullptr, Shape{0, 1, 2, 3});
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 2, 3, 4}));
+    }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({1, 2, 3, 4}, DataType::Float32);
+        auto op = g->addOp<TransposeObj>(i, nullptr, Shape{0, 2, 1, 3});
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 3, 2, 4}));
+    }
+    {
+        Graph g = make_ref<GraphObj>(runtime);
+        Tensor i = g->addTensor({2, 3, 4}, DataType::Float32);
+        auto op = g->addOp<TransposeObj>(i, nullptr, Shape{0, 2, 1});
+        EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 4, 3}));
+    }
+}
+
+} // namespace infini