support mixed dtype (#102)

* feat: support mixed dtype

* feat: support cast op

* test: add test for cast op

* feat: support datatype BFloat16

* feat: support data conversion fp32 <-> bfp16

* fix: fix the inferShape function of all ops

* fix: address review comments
This commit is contained in:
zhangyunze 2023-08-16 21:49:43 +08:00 committed by GitHub
parent 0dc5347089
commit ef672894d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 992 additions and 712 deletions

View File

@ -19,6 +19,7 @@ class DataType {
static const DataType Double; static const DataType Double;
static const DataType UInt32; static const DataType UInt32;
static const DataType UInt64; static const DataType UInt64;
static const DataType BFloat16;
// "sizePerElement" show the DType to cpu_type // "sizePerElement" show the DType to cpu_type
// DataType::Bool -> int8_t DataType::Float16 -> uint16_t // DataType::Bool -> int8_t DataType::Float16 -> uint16_t
static constexpr size_t sizePerElement[]{0, static constexpr size_t sizePerElement[]{0,
@ -34,14 +35,19 @@ class DataType {
sizeof(uint16_t), sizeof(uint16_t),
sizeof(double), sizeof(double),
sizeof(uint32_t), sizeof(uint32_t),
sizeof(uint64_t)}; sizeof(uint64_t),
0,
0,
sizeof(uint16_t)};
static constexpr std::string_view names[]{ static constexpr std::string_view names[]{
"Undefine", "Float32", "UInt8", "Int8", "UInt16", "Undefine", "Float32", "UInt8", "Int8", "UInt16",
"Int16", "Int32", "Int64", "String", "Bool", "Int16", "Int32", "Int64", "String", "Bool",
"Float16", "Double", "UInt32", "UInt64"}; "Float16", "Double", "UInt32", "UInt64", "PlaceHolder",
"PlaceHolder", "BFloat16"};
static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1, 3, 4, 9, 1, 8}; static constexpr int cpuType[]{-1, 0, 2, 3, 4, 5, 6, 7, -1,
3, 4, 9, 1, 8, -1, -1, 4};
private: private:
int index; int index;
@ -79,6 +85,7 @@ inline const DataType DataType::Float16(10);
inline const DataType DataType::Double(11); inline const DataType DataType::Double(11);
inline const DataType DataType::UInt32(12); inline const DataType DataType::UInt32(12);
inline const DataType DataType::UInt64(13); inline const DataType DataType::UInt64(13);
inline const DataType DataType::BFloat16(16);
// Method definitions are out of the declaration due to GCC bug: // Method definitions are out of the declaration due to GCC bug:
// https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc // https://stackoverflow.com/questions/49707184/explicit-specialization-in-non-namespace-scope-does-not-compile-in-gcc
template <> inline int DataType::get<float>() { return 0; } template <> inline int DataType::get<float>() { return 0; }
@ -107,5 +114,6 @@ template <> struct DT<10> { using t = uint16_t; };
template <> struct DT<11> { using t = double; }; template <> struct DT<11> { using t = double; };
template <> struct DT<12> { using t = uint32_t; }; template <> struct DT<12> { using t = uint32_t; };
template <> struct DT<13> { using t = uint64_t; }; template <> struct DT<13> { using t = uint64_t; };
template <> struct DT<16> { using t = uint16_t; };
} // namespace infini } // namespace infini

View File

@ -66,6 +66,7 @@ class GraphHandlerObj {
const optional<vector<int>> &steps); const optional<vector<int>> &steps);
Tensor pad(Tensor input, Tensor output, const vector<int> &pads, Tensor pad(Tensor input, Tensor output, const vector<int> &pads,
const optional<vector<int>> &axes); const optional<vector<int>> &axes);
Tensor cast(Tensor input, Tensor output, int to);
//------ modifiers //------ modifiers

View File

@ -36,6 +36,7 @@ class TensorObj : public TensorBaseObj {
size_t getBytes() const { return _size * dtype.getSize(); } size_t getBytes() const { return _size * dtype.getSize(); }
Shape getDims() const { return shape; } Shape getDims() const { return shape; }
size_t getRank() const { return shape.size(); }
vector<size_t> getStride() const; vector<size_t> getStride() const;
size_t getOffset(const vector<int> &ds) const; size_t getOffset(const vector<int> &ds) const;
void dataMalloc(); void dataMalloc();
@ -330,7 +331,7 @@ class TensorObj : public TensorBaseObj {
// } // }
// void initSplittingPoints() { // void initSplittingPoints() {
// splittingPoints.resize(getDims().size()); } // splittingPoints.resize(getRank()); }
// void printShape(); // void printShape();
}; };

View File

@ -15,7 +15,7 @@ class TransposeObj : public OperatorObj {
std::vector<int> getPermute() const { return transposePermute; } std::vector<int> getPermute() const { return transposePermute; }
private: private:
vector<int> transposePermute = {1, 1, 1, 1}; vector<int> transposePermute;
vector<int> getWorkloadVector() const override; vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override; vector<int> getOpAttrVector() const override;
}; };

View File

@ -134,31 +134,35 @@ class TransformObj : public OperatorObj {
vector<int> getOpAttrVector() const override; vector<int> getOpAttrVector() const override;
}; };
// Enumerates the (source type -> destination type) conversions that a Cast
// operator can be asked to perform.
// Every enumerator is named `<Src>2<Dst>`: the first `2` after the source type
// name is the separator, so `Int322Float` reads "Int32 to Float" and
// `Int642Uint32` reads "Int64 to Uint32".
// Only the first enumerator has an explicit value; the rest are numbered
// consecutively in declaration order, so reordering entries changes their
// underlying integer values.
enum class CastType {
// Conversions from Float.
Float2Float16 = 0,
Float2Int64,
Float2Int32,
Float2Int16,
Float2Int8,
Float2BFloat16,
// Conversions from Int32.
Int322Float,
Int322Int8,
Int322Int16,
Int322Int64,
// Conversions from Int16.
Int162Float,
Int162Int32,
// Conversions from Int8.
Int82Float,
Int82Int16,
Int82Int32,
// Conversions from Uint8.
Uint82Float,
Uint82Int32,
Uint82Int64,
// Conversions from Int64.
Int642Int32,
Int642Uint32,
Int642Float,
// Conversions from Uint32.
Uint322Int64,
// Conversions from the 16-bit floating-point formats back to Float.
Float162Float,
BFloat162Float,
};
class CastObj : public OperatorObj { class CastObj : public OperatorObj {
public: public:
enum CastType {
Float2Half = 0,
Float2Int64,
Float2Int32,
Float2Int16,
Float2Int8,
Int322Float,
Int322Int8,
Int322Int16,
Int162Float,
Int162Int32,
Int82Float,
Int82Int16,
Int82Int32,
Uint82Float,
Uint82Int32,
Uint82Int64,
Int322Int64,
Int642Int32,
Int642Uint32,
Int642Float,
Uint322Int64,
};
CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type); CastObj(GraphObj *graph, Tensor input, Tensor output, CastType type);
OP_CLONE(CastObj); OP_CLONE(CastObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override; optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -8,4 +8,6 @@ union Uf32 {
}; };
uint16_t float_to_fp16(const float x); uint16_t float_to_fp16(const float x);
float fp16_to_float(const uint16_t x); float fp16_to_float(const uint16_t x);
uint16_t float_to_bfp16(const float x);
float bfp16_to_float(const uint16_t x);
} // namespace infini } // namespace infini

View File

@ -0,0 +1,15 @@
// Shared helper declarations for operators.
// The header is protected by both `#pragma once` and a classic include guard.
#pragma once
#ifndef OPERATOR_UTIL_H
#define OPERATOR_UTIL_H
#include "core/tensor.h"
namespace infini {
// Infer the shape produced by broadcasting input shapes A and B against each
// other.
Shape infer_broadcast(const Shape &A, const Shape &B);
// Resolve `axis` to the real axis index for a tensor with `rank` dimensions.
// NOTE(review): presumably this maps a negative (counted-from-the-end) axis
// into [0, rank) — confirm against the definition.
int get_real_axis(const int &axis, const int &rank);
} // namespace infini
#endif

View File

@ -62,462 +62,520 @@ class OnnxStub:
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type) tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
data[initializer.name] = initializer data[initializer.name] = initializer
node_name = []
new_node_name = []
for node in model.graph.node: for node in model.graph.node:
if node.op_type == "Conv": node_name.append(node.name)
attributes = _parse_attribute( node_list = model.graph.node
node, while len(node_list) != 0:
{ for node in model.graph.node:
"dilations": [1, 1], if node.name not in node_list:
"pads": [0, 0, 0, 0], continue
"strides": [1, 1], if _analyse_node(node, tensors):
}, continue
) if node.op_type == "Conv":
(d, p, s) = ( attributes = _parse_attribute(
attributes[name] for name in ["dilations", "pads", "strides"] node,
) {
if p[0] != p[2] or p[1] != p[3]: "dilations": [1, 1],
adapt = "{}-adapt".format(node.output[0]) "pads": [0, 0, 0, 0],
tensors[adapt] = self.handler.pad( "strides": [1, 1],
tensors[node.input[0]], None, p, [-2, -1] },
) )
p = [0, 0, 0, 0] (d, p, s) = (
else: attributes[name] for name in ["dilations", "pads", "strides"]
adapt = node.input[0] )
if p[0] != p[2] or p[1] != p[3]:
adapt = "{}-adapt".format(node.output[0])
tensors[adapt] = self.handler.pad(
tensors[node.input[0]], None, p, [-2, -1]
)
p = [0, 0, 0, 0]
else:
adapt = node.input[0]
if len(node.input) > 2: if len(node.input) > 2:
bias = "{}-bias".format(node.output[0]) bias = "{}-bias".format(node.output[0])
reshape = "{}-reshape".format(node.output[0]) reshape = "{}-reshape".format(node.output[0])
tensors[bias] = self.handler.conv( tensors[bias] = self.handler.conv(
tensors[adapt], tensors[adapt],
tensors[node.input[1]],
None,
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
)
tensors[reshape] = self.handler.reshape(
tensors[node.input[2]],
None,
[
1,
reduce(
lambda acc, x: acc * x,
_search_shape(model, node.input[2]),
),
1,
1,
],
)
tensors[node.output[0]] = self.handler.add(
tensors[bias],
tensors[reshape],
tensors.get(node.output[0]),
)
else:
tensors[node.output[0]] = self.handler.conv(
tensors[adapt],
tensors[node.input[1]],
tensors.get(node.output[0]),
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
)
elif node.op_type == "ConvTranspose":
attributes = _parse_attribute(
node,
{
"dilations": [1, 1],
"pads": [0, 0],
"strides": [1, 1],
"output_padding": [0, 0],
},
)
(d, p, s, op) = (
attributes[name]
for name in ["dilations", "pads", "strides", "output_padding"]
)
tensors[node.output[0]] = self.handler.convTransposed2d(
tensors[node.input[0]],
tensors[node.input[1]], tensors[node.input[1]],
None, tensors.get(node.output[0]),
p[0], p[0],
p[1], p[1],
s[0], s[0],
s[1], s[1],
d[0], d[0],
d[1], d[1],
op[0],
op[1],
) )
tensors[reshape] = self.handler.reshape( elif node.op_type == "MatMul":
tensors[node.input[2]], tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
False,
None, None,
[ backend.ActType.Linear,
1,
reduce(
lambda acc, x: acc * x,
_search_shape(model, node.input[2]),
),
1,
1,
],
) )
elif node.op_type == "Gemm":
attributes = _parse_attribute(
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
)
(alpha, beta, transA, transB) = (
attributes[name]
for name in ["alpha", "beta", "transA", "transB"]
)
# FIXME unsupport attributes: `alpha` `beta`
assert alpha == 1.0
assert beta == 1.0
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
transA == 1,
transB == 1,
tensors[node.input[2]] if len(node.input) > 2 else None,
backend.ActType.Linear,
)
elif node.op_type == "BatchNormalization":
(input, mean, var, scale, bias) = (
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
)
output = tensors.get(node.output[0])
attributes = _parse_attribute(
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
)
(momentum, eps, training) = (
attributes[name]
for name in ["momentum", "epsilon", "training_mode"]
)
tensors[node.output[0]] = self.handler.batchNormalization(
input,
output,
mean,
var,
scale,
bias,
momentum,
eps,
training != 0,
)
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,
{
"kernel_shape": None,
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1],
},
)
(k, d, p, s) = (
attributes[name]
for name in ["kernel_shape", "dilations", "pads", "strides"]
)
if p[0] != p[2] or p[1] != p[3]:
adapt = "{}-adapt".format(node.output[0])
tensors[adapt] = self.handler.pad(
tensors.get(node.input[0]), None, p, [-2, -1]
)
tensors[node.output[0]] = self.handler.maxPool(
tensors[adapt],
tensors.get(node.output[0]),
k[0],
k[1],
d[0],
d[1],
0,
0,
s[0],
s[1],
)
else:
tensors[node.output[0]] = self.handler.maxPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
k[1],
d[0],
d[1],
p[0],
p[1],
s[0],
s[1],
)
elif node.op_type == "AveragePool":
attributes = _parse_attribute(
node,
{
"kernel_shape": None,
"pads": [0, 0, 0, 0],
"strides": [1, 1],
},
)
(k, p, s) = (
attributes[name] for name in ["kernel_shape", "pads", "strides"]
)
if p[0] != p[2] or p[1] != p[3]:
adapt = "{}-adapt".format(node.output[0])
tensors[adapt] = self.handler.pad(
tensors.get(node.input[0]), None, p, [-2, -1]
)
tensors[node.output[0]] = self.handler.avgPool(
tensors[adapt],
tensors.get(node.output[0]),
k[0],
k[1],
1,
1,
0,
0,
s[0],
s[1],
)
else:
tensors[node.output[0]] = self.handler.avgPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
k[0],
k[1],
1,
1,
p[0],
p[1],
s[0],
s[1],
)
elif node.op_type == "GlobalAveragePool":
[_, _, h, w] = _search_shape(model, node.input[0])
tensors[node.output[0]] = self.handler.avgPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
h,
w,
1,
1,
0,
0,
1,
1,
)
elif node.op_type == "Add":
tensors[node.output[0]] = self.handler.add( tensors[node.output[0]] = self.handler.add(
tensors[bias], tensors[node.input[0]],
tensors[reshape],
tensors.get(node.output[0]),
)
else:
tensors[node.output[0]] = self.handler.conv(
tensors[adapt],
tensors[node.input[1]], tensors[node.input[1]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
) )
elif node.op_type == "ConvTranspose": elif node.op_type == "Sub":
attributes = _parse_attribute( tensors[node.output[0]] = self.handler.sub(
node, tensors[node.input[0]],
{ tensors[node.input[1]],
"dilations": [1, 1],
"pads": [0, 0],
"strides": [1, 1],
"output_padding": [0, 0],
},
)
(d, p, s, op) = (
attributes[name]
for name in ["dilations", "pads", "strides", "output_padding"]
)
tensors[node.output[0]] = self.handler.convTransposed2d(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
p[0],
p[1],
s[0],
s[1],
d[0],
d[1],
op[0],
op[1],
)
elif node.op_type == "MatMul":
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
False,
None,
backend.ActType.Linear,
)
elif node.op_type == "Gemm":
attributes = _parse_attribute(
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
)
(alpha, beta, transA, transB) = (
attributes[name] for name in ["alpha", "beta", "transA", "transB"]
)
# FIXME unsupport attributes: `alpha` `beta`
assert alpha == 1.0
assert beta == 1.0
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
transA == 1,
transB == 1,
tensors[node.input[2]] if len(node.input) > 2 else None,
backend.ActType.Linear,
)
elif node.op_type == "BatchNormalization":
(input, mean, var, scale, bias) = (
tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
)
output = tensors.get(node.output[0])
attributes = _parse_attribute(
node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
)
(momentum, eps, training) = (
attributes[name]
for name in ["momentum", "epsilon", "training_mode"]
)
tensors[node.output[0]] = self.handler.batchNormalization(
input, output, mean, var, scale, bias, momentum, eps, training != 0
)
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,
{
"kernel_shape": None,
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1],
},
)
(k, d, p, s) = (
attributes[name]
for name in ["kernel_shape", "dilations", "pads", "strides"]
)
if p[0] != p[2] or p[1] != p[3]:
adapt = "{}-adapt".format(node.output[0])
tensors[adapt] = self.handler.pad(
tensors.get(node.input[0]), None, p, [-2, -1]
)
tensors[node.output[0]] = self.handler.maxPool(
tensors[adapt],
tensors.get(node.output[0]), tensors.get(node.output[0]),
k[0],
k[1],
d[0],
d[1],
0,
0,
s[0],
s[1],
) )
else: elif node.op_type == "Mul":
tensors[node.output[0]] = self.handler.maxPool( tensors[node.output[0]] = self.handler.mul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Div":
tensors[node.output[0]] = self.handler.div(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Pow":
tensors[node.output[0]] = self.handler.pow(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Relu":
tensors[node.output[0]] = self.handler.relu(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
k[0],
k[1],
d[0],
d[1],
p[0],
p[1],
s[0],
s[1],
) )
elif node.op_type == "AveragePool": elif node.op_type == "Sigmoid":
attributes = _parse_attribute( tensors[node.output[0]] = self.handler.sigmoid(
node,
{
"kernel_shape": None,
"pads": [0, 0, 0, 0],
"strides": [1, 1],
},
)
(k, p, s) = (
attributes[name] for name in ["kernel_shape", "pads", "strides"]
)
if p[0] != p[2] or p[1] != p[3]:
adapt = "{}-adapt".format(node.output[0])
tensors[adapt] = self.handler.pad(
tensors.get(node.input[0]), None, p, [-2, -1]
)
tensors[node.output[0]] = self.handler.avgPool(
tensors[adapt],
tensors.get(node.output[0]),
k[0],
k[1],
1,
1,
0,
0,
s[0],
s[1],
)
else:
tensors[node.output[0]] = self.handler.avgPool(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
k[0],
k[1],
1,
1,
p[0],
p[1],
s[0],
s[1],
) )
elif node.op_type == "GlobalAveragePool": elif node.op_type == "Tanh":
[_, _, h, w] = _search_shape(model, node.input[0]) tensors[node.output[0]] = self.handler.tanh(
tensors[node.output[0]] = self.handler.avgPool(
tensors[node.input[0]],
tensors.get(node.output[0]),
h,
w,
1,
1,
0,
0,
1,
1,
)
elif node.op_type == "Add":
tensors[node.output[0]] = self.handler.add(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Sub":
tensors[node.output[0]] = self.handler.sub(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Mul":
tensors[node.output[0]] = self.handler.mul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Div":
tensors[node.output[0]] = self.handler.div(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Pow":
tensors[node.output[0]] = self.handler.pow(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Relu":
tensors[node.output[0]] = self.handler.relu(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Sigmoid":
tensors[node.output[0]] = self.handler.sigmoid(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Tanh":
tensors[node.output[0]] = self.handler.tanh(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Softmax":
tensors[node.output[0]] = self.handler.softmax(
tensors[node.input[0]],
tensors.get(node.output[0]),
next(
(attr.i for attr in node.attribute if attr.name == "axis"), -1
),
)
elif node.op_type == "Abs":
tensors[node.output[0]] = self.handler.abs(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Shape":
tensors[node.output[0]] = self.handler.shape(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Identity":
tensors[node.output[0]] = self.handler.identity(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Flatten":
tensors[node.output[0]] = self.handler.flatten(
tensors[node.input[0]],
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "PRelu":
tensors[node.output[0]] = self.handler.pRelu(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Clip":
tensors[node.output[0]] = self.handler.clip(
tensors[node.input[0]],
tensors.get(node.output[0]),
next(_parse_data(data[node.input[1]]).__iter__(), None)
if len(node.input) > 1
else None,
next(_parse_data(data[node.input[2]]).__iter__(), None)
if len(node.input) > 2
else None,
)
elif node.op_type == "Transpose":
perm = next(
(attr.ints for attr in node.attribute if attr.name == "perm"), None
)
tensors[node.output[0]] = self.handler.transpose(
tensors[node.input[0]],
tensors.get(node.output[0]),
perm,
)
elif node.op_type == "Reshape":
dims = _search_shape(model, node.input[0])
size = reduce(lambda acc, x: acc * x, dims)
input_shape = _parse_data(data[node.input[1]])
for i, x in enumerate(input_shape):
if x == 0:
input_shape[i] = dims[i]
temp = reduce(lambda acc, x: acc * x, input_shape, 1)
if temp < 0:
input_shape[input_shape.index(-1)] = size // -temp
tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
input_shape,
)
elif node.op_type == "Squeeze":
input_shape = _search_shape(model, node.input[0])
axes = set(
[int(i) for i in data[node.input[1]].int64_data]
if len(node.input) > 1
else _parse_attribute(node, {"axes": None})["axes"]
)
assert all(input_shape[d] == 1 for d in axes)
output_shape = []
for i, x in enumerate(input_shape):
if i not in axes:
output_shape.append(x)
tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
output_shape,
)
elif node.op_type == "Unsqueeze":
input_shape = _search_shape(model, node.input[0])
axes = (
[int(i) for i in data[node.input[1]].int64_data]
if len(node.input) > 1
else _parse_attribute(node, {"axes": None})["axes"]
)
for i in axes:
input_shape.insert(i, 1)
tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
input_shape,
)
elif node.op_type == "Concat":
tensors[node.output[0]] = self.handler.concat(
[tensors[name] for name in node.input],
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "Split":
for name, tensor in zip(
node.output,
self.handler.split(
tensors[node.input[0]], tensors[node.input[0]],
None, tensors.get(node.output[0]),
)
elif node.op_type == "Softmax":
tensors[node.output[0]] = self.handler.softmax(
tensors[node.input[0]],
tensors.get(node.output[0]),
next( next(
(attr.i for attr in node.attribute if attr.name == "axis"), (attr.i for attr in node.attribute if attr.name == "axis"),
0, -1,
), ),
len(node.output), )
), elif node.op_type == "Abs":
): tensors[node.output[0]] = self.handler.abs(
tensors[name] = tensor
elif node.op_type == "Gather":
tensors[node.output[0]] = self.handler.gather(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "axis")),
)
elif node.op_type == "ReduceMean":
tensors[node.output[0]] = self.handler.reduce_mean(
tensors[node.input[0]],
tensors.get(node.output[0]),
# NOTE(constroy): `axes` is an attribute until opset version 13.
next(
(attr.ints for attr in node.attribute if attr.name == "axes"),
None,
),
next((attr.i for attr in node.attribute if attr.name == "keepdims"))
!= 0,
)
elif node.op_type == "Slice":
tensors[node.output[0]] = self.handler.slice(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[2]]),
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
_parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
)
elif node.op_type == "Pad":
tensors[node.output[0]] = self.handler.pad(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
)
elif node.op_type == "Dropout":
for name, tensor in zip(
node.output,
self.handler.dropout(
tensors[node.input[0]], tensors[node.input[0]],
tensors.get(node.output[0]), tensors.get(node.output[0]),
tensors.get(node.output[1]) if len(node.output) > 1 else None, )
_parse_data(data[node.input[1]])[0] elif node.op_type == "Shape":
tensors[node.output[0]] = self.handler.shape(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Identity":
tensors[node.output[0]] = self.handler.identity(
tensors[node.input[0]],
tensors.get(node.output[0]),
)
elif node.op_type == "Flatten":
tensors[node.output[0]] = self.handler.flatten(
tensors[node.input[0]],
tensors.get(node.output[0]),
next(
(attr.i for attr in node.attribute if attr.name == "axis")
),
)
elif node.op_type == "PRelu":
tensors[node.output[0]] = self.handler.pRelu(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "Clip":
tensors[node.output[0]] = self.handler.clip(
tensors[node.input[0]],
tensors.get(node.output[0]),
next(_parse_data(data[node.input[1]]).__iter__(), None)
if len(node.input) > 1 if len(node.input) > 1
else 0.5, else None,
_parse_data(data[node.input[2]])[0] next(_parse_data(data[node.input[2]]).__iter__(), None)
if len(node.input) > 2 if len(node.input) > 2
else False, else None,
), )
): elif node.op_type == "Transpose":
tensors[name] = tensor perm = next(
else: (attr.ints for attr in node.attribute if attr.name == "perm"),
raise Exception('Unsupported operator "{}"'.format(node.op_type)) None,
)
tensors[node.output[0]] = self.handler.transpose(
tensors[node.input[0]],
tensors.get(node.output[0]),
perm,
)
elif node.op_type == "Reshape":
dims = _search_shape(model, node.input[0])
size = reduce(lambda acc, x: acc * x, dims)
input_shape = _parse_data(data[node.input[1]])
for i, x in enumerate(input_shape):
if x == 0:
input_shape[i] = dims[i]
temp = reduce(lambda acc, x: acc * x, input_shape, 1)
if temp < 0:
input_shape[input_shape.index(-1)] = size // -temp
tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
input_shape,
)
elif node.op_type == "Squeeze":
input_shape = _search_shape(model, node.input[0])
axes = set(
[int(i) for i in data[node.input[1]].int64_data]
if len(node.input) > 1
else _parse_attribute(node, {"axes": None})["axes"]
)
assert all(input_shape[d] == 1 for d in axes)
output_shape = []
for i, x in enumerate(input_shape):
if i not in axes:
output_shape.append(x)
tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
output_shape,
)
elif node.op_type == "Unsqueeze":
input_shape = _search_shape(model, node.input[0])
axes = (
[int(i) for i in data[node.input[1]].int64_data]
if len(node.input) > 1
else _parse_attribute(node, {"axes": None})["axes"]
)
for i in axes:
input_shape.insert(i, 1)
tensors[node.output[0]] = self.handler.reshape(
tensors[node.input[0]],
tensors.get(node.output[0]),
input_shape,
)
elif node.op_type == "Concat":
tensors[node.output[0]] = self.handler.concat(
[tensors[name] for name in node.input],
tensors.get(node.output[0]),
next(
(attr.i for attr in node.attribute if attr.name == "axis")
),
)
elif node.op_type == "Split":
for name, tensor in zip(
node.output,
self.handler.split(
tensors[node.input[0]],
None,
next(
(
attr.i
for attr in node.attribute
if attr.name == "axis"
),
0,
),
len(node.output),
),
):
tensors[name] = tensor
elif node.op_type == "Gather":
tensors[node.output[0]] = self.handler.gather(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
next(
(attr.i for attr in node.attribute if attr.name == "axis")
),
)
elif node.op_type == "ReduceMean":
tensors[node.output[0]] = self.handler.reduce_mean(
tensors[node.input[0]],
tensors.get(node.output[0]),
# NOTE(constroy): `axes` is an attribute until opset version 13.
next(
(
attr.ints
for attr in node.attribute
if attr.name == "axes"
),
None,
),
next(
(
attr.i
for attr in node.attribute
if attr.name == "keepdims"
)
)
!= 0,
)
elif node.op_type == "Slice":
tensors[node.output[0]] = self.handler.slice(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[2]]),
_parse_data(data[node.input[3]])
if len(node.input) > 3
else None,
_parse_data(data[node.input[4]])
if len(node.input) > 4
else None,
)
elif node.op_type == "Pad":
tensors[node.output[0]] = self.handler.pad(
tensors[node.input[0]],
tensors.get(node.output[0]),
_parse_data(data[node.input[1]]),
_parse_data(data[node.input[3]])
if len(node.input) > 3
else None,
)
elif node.op_type == "Dropout":
for name, tensor in zip(
node.output,
self.handler.dropout(
tensors[node.input[0]],
tensors.get(node.output[0]),
tensors.get(node.output[1])
if len(node.output) > 1
else None,
_parse_data(data[node.input[1]])[0]
if len(node.input) > 1
else 0.5,
_parse_data(data[node.input[2]])[0]
if len(node.input) > 2
else False,
),
):
tensors[name] = tensor
elif node.op_type == "Cast":
tensors[node.output[0]] = self.handler.cast(
tensors[node.input[0]],
tensors.get(node.output[0]),
next((attr.i for attr in node.attribute if attr.name == "to")),
)
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
new_node_name.append(node.name)
# update the node_list
node_list = list(set(node_name) - set(new_node_name))
self.handler.data_malloc() self.handler.data_malloc()
@ -540,6 +598,8 @@ class OnnxStub:
obj.copyin_float16(_parse_data_fp16(tensor)) obj.copyin_float16(_parse_data_fp16(tensor))
elif tensor.data_type == TensorProto.INT8: elif tensor.data_type == TensorProto.INT8:
obj.copyin_uint8(_parse_data(tensor)) obj.copyin_uint8(_parse_data(tensor))
elif tensor.data_type == TensorProto.BFLOAT16:
obj.copyin_float16(_parse_data_fp16(tensor))
else: else:
assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
@ -823,6 +883,9 @@ class OnnxStub:
ctx.push_data_input(name, "max", TensorProto.FLOAT, [], []) ctx.push_data_input(name, "max", TensorProto.FLOAT, [], [])
) )
ctx.push_node(make_node(ty.name, inputs, outputs, name)) ctx.push_node(make_node(ty.name, inputs, outputs, name))
elif ty == backend.OpTypeId.Cast:
to = backend.cast_to_of(op)
ctx.push_node(make_node(ty.name, inputs, outputs, name, to=to))
else: else:
raise Exception("Unsupported OpType", ty) raise Exception("Unsupported OpType", ty)
@ -922,3 +985,10 @@ def _parse_data_fp16(tensor: TensorProto):
def _take_shape_dim(shape: TensorShapeProto) -> List[int]: def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim] return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]
def _analyse_node(node: NodeProto, tensors) -> bool:
    """Report whether `node` still has to wait for one of its inputs.

    Args:
        node: The ONNX node whose inputs are checked.
        tensors: Container of the tensor names that are available so far;
            only membership (`name in tensors`) is used.

    Returns:
        True if at least one named input of `node` is not yet in `tensors`,
        False if every named input is available.
    """
    # ONNX marks an omitted optional input with an empty name. No node ever
    # produces a tensor called "", so counting it as "missing" would keep the
    # node pending forever; skip such placeholders.
    return any(name != "" and name not in tensors for name in node.input)

View File

@ -79,6 +79,21 @@ class TestStringMethods(unittest.TestCase):
) )
make_and_import_model(make_graph([conv], "conv_fp16", [i, w], [o])) make_and_import_model(make_graph([conv], "conv_fp16", [i, w], [o]))
def test_conv_bfp16(self):
    """Import a single-Conv graph whose input, weight and output are BFLOAT16."""
    bf16 = TensorProto.BFLOAT16
    feature_map = make_tensor_value_info("i", bf16, [1, 3, 4, 4])
    kernel = make_tensor_value_info("w", bf16, [2, 3, 3, 3])
    result = make_tensor_value_info("o", bf16, [1, 2, 2, 2])
    # Asymmetric strides/dilations on purpose: with 1-pixel padding they
    # still yield the declared 2x2 spatial output.
    conv_attrs = {
        "pads": [1, 1, 1, 1],
        "strides": [2, 1],
        "dilations": [1, 2],
    }
    conv_node = make_node("Conv", ["i", "w"], ["o"], "conv", **conv_attrs)
    graph = make_graph([conv_node], "conv_bfp16", [feature_map, kernel], [result])
    make_and_import_model(graph)
def test_matmul(self): def test_matmul(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4]) a = make_tensor_value_info("a", TensorProto.FLOAT, [1, 3, 4])
@ -226,9 +241,7 @@ class TestStringMethods(unittest.TestCase):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3, 5 * 7]) y = make_tensor_value_info("y", TensorProto.FLOAT, [1 * 3, 5 * 7])
flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten") flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten")
make_and_import_model( make_and_import_model(make_graph([flatten], "flatten", [x], [y]))
make_graph([flatten], "flatten", [x], [y])
)
def test_reshape(self): def test_reshape(self):
data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5]) data = make_tensor_value_info("data", TensorProto.FLOAT, [2, 3, 4, 5])
@ -331,6 +344,14 @@ class TestStringMethods(unittest.TestCase):
y = handler.tensor([3, 2, 1], 12) y = handler.tensor([3, 2, 1], 12)
handler.reshape(x, y, [3, 2, 1]) handler.reshape(x, y, [3, 2, 1])
def test_cast(self):
    """Build and import a single-node Cast graph converting FLOAT to FLOAT16."""
    shape = [1, 3, 2, 4]
    input1 = make_tensor_value_info("input1", TensorProto.FLOAT, shape)
    output = make_tensor_value_info("output", TensorProto.FLOAT16, shape)
    # The `to` attribute carries the target element type of the Cast node.
    cast = make_node(
        "Cast",
        inputs=["input1"],
        outputs=["output"],
        name="cast",
        to=TensorProto.FLOAT16,
    )
    graph = make_graph([cast], "cast", [input1], [output])
    make_and_import_model(graph)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -18,6 +18,7 @@
namespace infini { namespace infini {
static DataType dtype_repr_convert(int); static DataType dtype_repr_convert(int);
static CastType inferCastType(Tensor input, int to);
Tensor GraphHandlerObj::tensor(Shape dims, int dtype) { Tensor GraphHandlerObj::tensor(Shape dims, int dtype) {
return g->addTensor(std::move(dims), dtype_repr_convert(dtype)); return g->addTensor(std::move(dims), dtype_repr_convert(dtype));
@ -293,6 +294,76 @@ Tensor GraphHandlerObj::pad(Tensor input, Tensor output,
} }
} }
/**
 * Add a Cast operator to the graph.
 *
 * @param input  tensor to convert.
 * @param output destination tensor; when null, the graph creates the output.
 * @param to     integer code of the requested element type, as consumed by
 *               inferCastType.
 * @return `output` when it was supplied, otherwise the output tensor of the
 *         newly added operator.
 */
Tensor GraphHandlerObj::cast(Tensor input, Tensor output, int to) {
    // Resolve the cast kind before `input` is handed over with std::move.
    // Reading `input` and moving it inside one argument list would make the
    // result depend on the unspecified evaluation order of call arguments.
    const CastType castType = inferCastType(input, to);
    if (output) {
        g->addOpWithOutputs<CastObj>(std::move(input), output, castType);
        return output;
    }
    return g->addOp<CastObj>(std::move(input), output, castType)->getOutput();
}
/**
 * Pick the CastType that converts the data type of `input` into the type
 * built from `to` (via DataType(to)).
 *
 * The supported (source, destination) pairs are listed in a table that is
 * scanned in order; the first matching row decides the result. Any pair that
 * is not listed ends in IT_TODO_HALT_MSG with both type names in the message.
 */
static CastType inferCastType(Tensor input, int to) {
    struct CastRule {
        DataType from;
        DataType dest;
        CastType kind;
    };

    auto srcType = input->getDType();
    auto dstType = DataType(to);

    const CastRule supported[] = {
        {DataType::Float32, DataType::Float16, CastType::Float2Float16},
        {DataType::Float32, DataType::Int64, CastType::Float2Int64},
        {DataType::Float32, DataType::Int32, CastType::Float2Int32},
        {DataType::Float32, DataType::Int16, CastType::Float2Int16},
        {DataType::Float32, DataType::Int8, CastType::Float2Int8},
        {DataType::Float32, DataType::BFloat16, CastType::Float2BFloat16},
        {DataType::Int32, DataType::Float32, CastType::Int322Float},
        {DataType::Int32, DataType::Int8, CastType::Int322Int8},
        {DataType::Int32, DataType::Int16, CastType::Int322Int16},
        {DataType::Int32, DataType::Int64, CastType::Int322Int64},
        {DataType::Int16, DataType::Int32, CastType::Int162Int32},
        {DataType::Int16, DataType::Float32, CastType::Int162Float},
        {DataType::Int8, DataType::Float32, CastType::Int82Float},
        {DataType::Int8, DataType::Int16, CastType::Int82Int16},
        {DataType::Int8, DataType::Int32, CastType::Int82Int32},
        {DataType::UInt8, DataType::Int32, CastType::Uint82Int32},
        {DataType::UInt8, DataType::Float32, CastType::Uint82Float},
        {DataType::UInt8, DataType::Int64, CastType::Uint82Int64},
        {DataType::Int64, DataType::Float32, CastType::Int642Float},
        {DataType::Int64, DataType::UInt32, CastType::Int642Uint32},
        {DataType::Int64, DataType::Int32, CastType::Int642Int32},
        {DataType::UInt32, DataType::Int64, CastType::Uint322Int64},
        {DataType::Float16, DataType::Float32, CastType::Float162Float},
        {DataType::BFloat16, DataType::Float32, CastType::BFloat162Float},
    };

    for (const auto &rule : supported) {
        if (srcType == rule.from && dstType == rule.dest) {
            return rule.kind;
        }
    }

    // No table row matched: this conversion is not implemented.
    IT_TODO_HALT_MSG("Unsupported CastType : input_type is " +
                     srcType.toString() + " output_type is " +
                     dstType.toString());
}
static DataType dtype_repr_convert(int dtype) { static DataType dtype_repr_convert(int dtype) {
switch (dtype) { switch (dtype) {
case 0: case 0:
@ -323,6 +394,8 @@ static DataType dtype_repr_convert(int dtype) {
return DataType::UInt32; return DataType::UInt32;
case 13: case 13:
return DataType::UInt64; return DataType::UInt64;
case 16:
return DataType::BFloat16;
default: default:
IT_ASSERT(false, "Unsupported data type"); IT_ASSERT(false, "Unsupported data type");
} }

View File

@ -85,6 +85,7 @@ void TensorObj::printData() const {
else TRY_PRINT(11) // else TRY_PRINT(11) //
else TRY_PRINT(12) // else TRY_PRINT(12) //
else TRY_PRINT(13) // else TRY_PRINT(13) //
else TRY_PRINT(16) //
else IT_TODO_HALT(); else IT_TODO_HALT();
#undef TRY_PRINT #undef TRY_PRINT
@ -118,6 +119,7 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {
else TEST_EQUAL(11) // else TEST_EQUAL(11) //
else TEST_EQUAL(12) // else TEST_EQUAL(12) //
else TEST_EQUAL(13) // else TEST_EQUAL(13) //
else TEST_EQUAL(16) //
else IT_TODO_HALT(); else IT_TODO_HALT();
#undef TEST_EQUAL #undef TEST_EQUAL

View File

@ -95,6 +95,7 @@ void export_values(py::module &m) {
.VALUE(OpType, Abs) .VALUE(OpType, Abs)
.VALUE(OpType, Resize) .VALUE(OpType, Resize)
.VALUE(OpType, Dropout) .VALUE(OpType, Dropout)
.VALUE(OpType, Cast)
.export_values(); .export_values();
#undef VALUE #undef VALUE
@ -129,6 +130,8 @@ static int tensor_dtype(Tensor t) {
return 12; return 12;
if (t->getDType() == DataType::UInt64) if (t->getDType() == DataType::UInt64)
return 13; return 13;
if (t->getDType() == DataType::BFloat16)
return 16;
IT_ASSERT(false, "Unsupported data type"); IT_ASSERT(false, "Unsupported data type");
} }
@ -242,6 +245,13 @@ static int flatten_axis_of(Operator op) {
return dynamic_cast<const FlattenObj *>(op.get())->getAxis(); return dynamic_cast<const FlattenObj *>(op.get())->getAxis();
} }
/**
 * Return the index of the output data type of a Cast operator.
 * Asserts that `op` really has OpType::Cast before it is downcast.
 */
static int cast_to_of(Operator op) {
    IT_ASSERT(op->getOpType() == OpType::Cast);
    const auto *castOp = dynamic_cast<const CastObj *>(op.get());
    auto outType = castOp->getOutputDataType();
    return outType.getIndex();
}
void export_functions(py::module &m) { void export_functions(py::module &m) {
#define FUNCTION(NAME) def(#NAME, &NAME) #define FUNCTION(NAME) def(#NAME, &NAME)
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance) m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
@ -271,7 +281,8 @@ void export_functions(py::module &m) {
.FUNCTION(concat_axis_of) .FUNCTION(concat_axis_of)
.FUNCTION(split_axis_of) .FUNCTION(split_axis_of)
.FUNCTION(gather_axis_of) .FUNCTION(gather_axis_of)
.FUNCTION(flatten_axis_of); .FUNCTION(flatten_axis_of)
.FUNCTION(cast_to_of);
#undef FUNCTION #undef FUNCTION
} }
@ -346,6 +357,7 @@ void init_graph_builder(py::module &m) {
.def("reduce_mean", &Handler::reduceMean, policy::move) .def("reduce_mean", &Handler::reduceMean, policy::move)
.def("slice", &Handler::slice, policy::move) .def("slice", &Handler::slice, policy::move)
.def("pad", &Handler::pad, policy::move) .def("pad", &Handler::pad, policy::move)
.def("cast", &Handler::cast, policy::move)
.def("topo_sort", &Handler::topo_sort, policy::automatic) .def("topo_sort", &Handler::topo_sort, policy::automatic)
.def("optimize", &Handler::optimize, policy::automatic) .def("optimize", &Handler::optimize, policy::automatic)
.def("operators", &Handler::operators, policy::move) .def("operators", &Handler::operators, policy::move)

View File

@ -13,7 +13,6 @@ template <typename T> class NaiveMatmul : public CpuKernelWithoutConfig {
T *C = op->getOutput()->getRawDataPtr<T *>(); T *C = op->getOutput()->getRawDataPtr<T *>();
IT_ASSERT(op->getTransA() == false && op->getTransB() == false); IT_ASSERT(op->getTransA() == false && op->getTransB() == false);
IT_ASSERT(op->getAct() == ActType::None); IT_ASSERT(op->getAct() == ActType::None);
IT_ASSERT(op->getB() == 1);
const int M = op->getM(), N = op->getN(), K = op->getK(); const int M = op->getM(), N = op->getN(), K = op->getK();
for (int i = 0; i < M; i++) { for (int i = 0; i < M; i++) {
for (int j = 0; j < N; j++) { for (int j = 0; j < N; j++) {

View File

@ -14,9 +14,9 @@ class GatherCuda : public CudaKernelWithoutConfig {
auto out = op->getOutput(); auto out = op->getOutput();
metaData.indexValue = index->getRawDataPtr<int *>(); metaData.indexValue = index->getRawDataPtr<int *>();
metaData.axis = op->getAxis(); metaData.axis = op->getAxis();
metaData.inNDim = in->getDims().size(); metaData.inNDim = in->getRank();
metaData.outNDim = out->getDims().size(); metaData.outNDim = out->getRank();
metaData.idxNDim = index->getDims().size(); metaData.idxNDim = index->getRank();
for (int i = 0; i < metaData.outNDim; ++i) for (int i = 0; i < metaData.outNDim; ++i)
metaData.outDim[i] = out->getDims()[i]; metaData.outDim[i] = out->getDims()[i];
for (int i = 0; i < metaData.idxNDim; ++i) { for (int i = 0; i < metaData.idxNDim; ++i) {

View File

@ -51,8 +51,8 @@ class matmulCublas : public Kernel {
cublasStatus_t stat; cublasStatus_t stat;
if (b > 1) { if (b > 1) {
// Support batch broadcast with zero stride // Support batch broadcast with zero stride
int dimA = op->getInputs(0)->getDims().size(); int dimA = op->getInputs(0)->getRank();
int dimB = op->getInputs(1)->getDims().size(); int dimB = op->getInputs(1)->getRank();
long long strideA = long long strideA =
(dimA == 2 || (dimA == 2 ||
(dimA == 3 && op->getInputs(0)->getDims()[0] == 1)) (dimA == 3 && op->getInputs(0)->getDims()[0] == 1))

View File

@ -7,7 +7,7 @@ class PadSliceCudaCompute {
public: public:
void do_compute(Tensor partTensor, Tensor wholeTensor, const Shape &begNos, void do_compute(Tensor partTensor, Tensor wholeTensor, const Shape &begNos,
bool isPad) const { bool isPad) const {
int nDims = partTensor->getDims().size(); int nDims = partTensor->getRank();
IT_ASSERT(MAX_DIM >= nDims); IT_ASSERT(MAX_DIM >= nDims);
TransMetaData metadata; TransMetaData metadata;
for (int i = 0; i < nDims; i++) { for (int i = 0; i < nDims; i++) {

View File

@ -14,7 +14,7 @@ class ReduceMeanCudnn : public CudaKernelWithoutConfig {
// Each dimension of the output tensor C must match the corresponding // Each dimension of the output tensor C must match the corresponding
// dimension of the input tensor A or must be equal to 1. The dimensions // dimension of the input tensor A or must be equal to 1. The dimensions
// equal to 1 indicate the dimensions of A to be reduced. // equal to 1 indicate the dimensions of A to be reduced.
int nInDims = input->getDims().size(); int nInDims = input->getRank();
IT_ASSERT(CUDNN_DIM_MAX >= nInDims); IT_ASSERT(CUDNN_DIM_MAX >= nInDims);
int inDimArray[CUDNN_DIM_MAX], outDimArray[CUDNN_DIM_MAX], int inDimArray[CUDNN_DIM_MAX], outDimArray[CUDNN_DIM_MAX],
inStrideArray[CUDNN_DIM_MAX], outStrideArray[CUDNN_DIM_MAX]; inStrideArray[CUDNN_DIM_MAX], outStrideArray[CUDNN_DIM_MAX];

View File

@ -9,7 +9,7 @@ class ResizeCuda : public CudaKernelWithoutConfig {
auto in = op->getInputs(0); auto in = op->getInputs(0);
auto out = op->getOutputs()[0]; auto out = op->getOutputs()[0];
int nDims = in->getDims().size(); int nDims = in->getRank();
if (nDims > 4) if (nDims > 4)
IT_TODO_HALT(); IT_TODO_HALT();

View File

@ -9,7 +9,7 @@ namespace infini {
class CudaCompute { class CudaCompute {
void initComposedTensorMetadata(ComposedTensorMetadata &metadata, void initComposedTensorMetadata(ComposedTensorMetadata &metadata,
Tensor tensor) const { Tensor tensor) const {
int nDims = tensor->getDims().size(); int nDims = tensor->getRank();
auto strides = tensor->getStride(); auto strides = tensor->getStride();
IT_ASSERT(strides.size() == (size_t)nDims); IT_ASSERT(strides.size() == (size_t)nDims);
for (int i = 0; i < nDims; ++i) { for (int i = 0; i < nDims; ++i) {
@ -60,8 +60,8 @@ class ConcatCuda : private CudaCompute, public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
do_compute(_op->getOutput(), _op->getInputs(), do_compute(_op->getOutput(), _op->getInputs(),
as<ConcatObj>(_op)->getDim(), as<ConcatObj>(_op)->getDim(), _op->getOutput()->getRank(),
_op->getOutput()->getDims().size(), false); false);
} }
}; };
@ -69,8 +69,8 @@ class SplitCuda : private CudaCompute, public CudaKernelWithoutConfig {
void compute(const Operator &_op, void compute(const Operator &_op,
const RuntimeObj *_context) const override { const RuntimeObj *_context) const override {
do_compute(_op->getInputs(0), _op->getOutputs(), do_compute(_op->getInputs(0), _op->getOutputs(),
as<SplitObj>(_op)->getDim(), as<SplitObj>(_op)->getDim(), _op->getInputs(0)->getRank(),
_op->getInputs(0)->getDims().size(), true); true);
} }
}; };

View File

@ -14,7 +14,7 @@ class MklBatchNorm : public MklKernelWithoutConfig {
// create user memory that describes data layout in the buffers // create user memory that describes data layout in the buffers
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i)
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
@ -25,7 +25,7 @@ class MklBatchNorm : public MklKernelWithoutConfig {
getUserFormatTag(dims.size())); getUserFormatTag(dims.size()));
auto output = dnnl::memory(dstMd, context->getEngine(), dstData); auto output = dnnl::memory(dstMd, context->getEngine(), dstData);
std::vector<dnnl_dim_t> meanDims(op->getInputs(0)->getDims().size(), 1); std::vector<dnnl_dim_t> meanDims(op->getInputs(0)->getRank(), 1);
meanDims[1] = op->getInputs(0)->getDims()[1]; meanDims[1] = op->getInputs(0)->getDims()[1];
auto meanMd = dnnl::memory::desc(meanDims, dnnl::memory::data_type::f32, auto meanMd = dnnl::memory::desc(meanDims, dnnl::memory::data_type::f32,
getUserFormatTag(meanDims.size())); getUserFormatTag(meanDims.size()));

View File

@ -34,7 +34,7 @@ class MklBinary : public MklKernelWithoutConfig {
// create user memory that describes data layout in the buffers // create user memory that describes data layout in the buffers
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i)
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
auto srcMd1 = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, auto srcMd1 = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
@ -89,7 +89,7 @@ class MklUnary : public MklKernelWithoutConfig {
// create user memory that describes data layout in the buffers // create user memory that describes data layout in the buffers
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i)
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,

View File

@ -17,9 +17,9 @@ class MklGather : public MklKernelWithoutConfig {
int oSize = out->size(); int oSize = out->size();
int idxSize = index->size(); int idxSize = index->size();
int inNDim = in->getDims().size(); int inNDim = in->getRank();
int oNDim = out->getDims().size(); int oNDim = out->getRank();
int idxNDim = index->getDims().size(); int idxNDim = index->getRank();
int axis = op->getAxis(); int axis = op->getAxis();
int outDim[4] = {0}; int outDim[4] = {0};

View File

@ -10,7 +10,7 @@ class MklPad : public MklKernelWithoutConfig {
auto context = dynamic_cast<const MklRuntimeObj *>(_context); auto context = dynamic_cast<const MklRuntimeObj *>(_context);
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) {
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
} }
auto paddedMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, auto paddedMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,

View File

@ -17,7 +17,7 @@ class MklPooling : public MklKernelWithoutConfig {
// create user memory that describes data layout in the buffers // create user memory that describes data layout in the buffers
auto [n, c, h, w, r, s] = op->getNCHWRS(); auto [n, c, h, w, r, s] = op->getNCHWRS();
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation(); auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
auto nDim = op->getOutput()->getDims().size(); auto nDim = op->getOutput()->getRank();
auto oh = op->getOutput()->getDims()[nDim - 2]; auto oh = op->getOutput()->getDims()[nDim - 2];
auto ow = op->getOutput()->getDims()[nDim - 1]; auto ow = op->getOutput()->getDims()[nDim - 1];

View File

@ -18,16 +18,16 @@ class MklReduce : public MklKernelWithoutConfig {
// create user memory that describes data layout in the buffers // create user memory that describes data layout in the buffers
std::vector<dnnl_dim_t> inDims, inStrides; std::vector<dnnl_dim_t> inDims, inStrides;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) {
inDims.push_back(op->getInputs(0)->getDims()[i]); inDims.push_back(op->getInputs(0)->getDims()[i]);
inStrides.push_back(op->getInputs(0)->getStride()[i]); inStrides.push_back(op->getInputs(0)->getStride()[i]);
} }
std::vector<dnnl_dim_t> oDims(op->getInputs(0)->getDims().size(), 0), std::vector<dnnl_dim_t> oDims(op->getInputs(0)->getRank(), 0),
oStrides(op->getInputs(0)->getDims().size(), 1); oStrides(op->getInputs(0)->getRank(), 1);
if (!op->getKeepDims()) { if (!op->getKeepDims()) {
oDims = inDims; oDims = inDims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) {
if (op->isReduced(i)) { if (op->isReduced(i)) {
oDims[i] = 1; oDims[i] = 1;
} }
@ -38,7 +38,7 @@ class MklReduce : public MklKernelWithoutConfig {
stride *= oDims[i]; stride *= oDims[i];
} }
} else { } else {
for (size_t i = 0; i < op->getOutput(0)->getDims().size(); ++i) { for (size_t i = 0; i < op->getOutput(0)->getRank(); ++i) {
oDims[i] = op->getOutput(0)->getDims()[i]; oDims[i] = op->getOutput(0)->getDims()[i];
oStrides[i] = op->getOutput(0)->getStride()[i]; oStrides[i] = op->getOutput(0)->getStride()[i];
} }

View File

@ -10,7 +10,7 @@ class MklReshape : public MklKernelWithoutConfig {
auto context = dynamic_cast<const MklRuntimeObj *>(_context); auto context = dynamic_cast<const MklRuntimeObj *>(_context);
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i)
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
// create src md and src memory // create src md and src memory

View File

@ -30,7 +30,7 @@ class MklResize : public MklKernelWithoutConfig {
enum_to_underlying(ResizeObj::ECoordinateTransMode::halfPixel)) enum_to_underlying(ResizeObj::ECoordinateTransMode::halfPixel))
IT_TODO_HALT(); IT_TODO_HALT();
int nDim = op->getInputs(0)->getDims().size(); int nDim = op->getInputs(0)->getRank();
IT_ASSERT(nDim == 3 || nDim == 4 || IT_ASSERT(nDim == 3 || nDim == 4 ||
nDim == 5 && nDim == 5 &&
(op->getInputs(0)->getDims()[0] == 1 && (op->getInputs(0)->getDims()[0] == 1 &&
@ -44,7 +44,7 @@ class MklResize : public MklKernelWithoutConfig {
// create user memory that describes data layout in the buffers // create user memory that describes data layout in the buffers
std::vector<dnnl_dim_t> idims, odims; std::vector<dnnl_dim_t> idims, odims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) { for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i) {
idims.push_back(op->getInputs(0)->getDims()[i]); idims.push_back(op->getInputs(0)->getDims()[i]);
odims.push_back(op->getOutput(0)->getDims()[i]); odims.push_back(op->getOutput(0)->getDims()[i]);
} }

View File

@ -10,7 +10,7 @@ class MklSlice : public MklKernelWithoutConfig {
auto context = dynamic_cast<const MklRuntimeObj *>(_context); auto context = dynamic_cast<const MklRuntimeObj *>(_context);
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i)
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
// create src md // create src md

View File

@ -14,7 +14,7 @@ class MklSoftmax : public MklKernelWithoutConfig {
// create user memory that describes data layout in the buffers // create user memory that describes data layout in the buffers
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i)
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32, auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,

View File

@ -10,7 +10,7 @@ class MklSplit : public MklKernelWithoutConfig {
auto context = dynamic_cast<const MklRuntimeObj *>(_context); auto context = dynamic_cast<const MklRuntimeObj *>(_context);
std::vector<dnnl_dim_t> dims; std::vector<dnnl_dim_t> dims;
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) for (size_t i = 0; i < op->getInputs(0)->getRank(); ++i)
dims.push_back(op->getInputs(0)->getDims()[i]); dims.push_back(op->getInputs(0)->getDims()[i]);
// create src md // create src md

View File

@ -23,16 +23,11 @@ string G2BMMObj::toString() const {
optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const {
auto A = inputs[0], B = inputs[1]; auto A = inputs[0], B = inputs[1];
if (!(A->getDims().size() == 3 && B->getDims().size() == 3)) IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
return {}; IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
if (!(A->getDims()[0] == B->getDims()[0])) IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
return {}; IT_ASSERT(A->getDims()[2] == B->getDims()[2]);
if (!(A->getDims()[1] == B->getDims()[1])) IT_ASSERT(width >= 0);
return {};
if (!(A->getDims()[2] == B->getDims()[2]))
return {};
if (width < 0)
return {};
int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1); int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1);
return {{{b, m, n}}}; return {{{b, m, n}}};
} }

View File

@ -24,14 +24,10 @@ string GBMMObj::toString() const {
optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const {
auto A = inputs[0], B = inputs[1]; auto A = inputs[0], B = inputs[1];
if (!(A->getDims().size() == 3 && B->getDims().size() == 3)) IT_ASSERT(A->getRank() == 3 && B->getRank() == 3);
return {}; IT_ASSERT(A->getDims()[0] == B->getDims()[0]);
if (!(A->getDims()[0] == B->getDims()[0])) IT_ASSERT(A->getDims()[1] == B->getDims()[1]);
return {}; IT_ASSERT(A->getDims()[2] % 2 != 0);
if (!(A->getDims()[1] == B->getDims()[1]))
return {};
if (A->getDims()[2] % 2 == 0)
return {};
int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]); int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]);
return {{{b, m, k}}}; return {{{b, m, k}}};
} }

View File

@ -21,9 +21,10 @@ BatchNormObj::inferShape(const TensorVec &inputs) const {
auto scale = inputs[3]; auto scale = inputs[3];
auto bias = inputs[4]; auto bias = inputs[4];
auto c = std::vector<int>{input->getDims()[1]}; auto c = std::vector<int>{input->getDims()[1]};
if (mean->getDims() != c || var->getDims() != c || scale->getDims() != c || IT_ASSERT(mean->getRank() == 1 && mean->getDims() == c);
bias->getDims() != c) IT_ASSERT(var->getRank() == 1 && var->getDims() == c);
return {}; IT_ASSERT(scale->getRank() == 1 && scale->getDims() == c);
IT_ASSERT(bias->getRank() == 1 && bias->getDims() == c);
return {{input->getDims()}}; return {{input->getDims()}};
} }

View File

@ -1,28 +1,29 @@
#include "operators/concat.h" #include "operators/concat.h"
#include "utils/operator_utils.h"
namespace infini { namespace infini {
/**
 * Build a Concat operator over `inputs` along axis `dim`.
 *
 * `dim` is passed through get_real_axis together with the rank of the first
 * input (presumably to normalise negative axes; confirm against
 * utils/operator_utils.h), and the result is stored in the member.
 */
ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim)
    : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) {
    int rank = inputs[0]->getRank();
    // The parameter `dim` shadows the member of the same name, so the member
    // must be addressed through `this`; a plain `dim = ...` would only change
    // the local parameter and leave the stored axis untouched.
    this->dim = get_real_axis(dim, rank);
    IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ConcatObj::inferShape(const TensorVec &inputs) const {
IT_ASSERT(inputs.size() > 1); IT_ASSERT(inputs.size() > 1);
Shape dims = inputs[0]->getDims(); Shape dims = inputs[0]->getDims();
auto rank = inputs[0]->getRank();
ShapeElem n = dims.at(dim); ShapeElem n = dims.at(dim);
for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) { for (auto itr = inputs.begin() + 1; itr != inputs.end(); ++itr) {
auto input = *itr; auto input = *itr;
auto iDims = input->getDims(); auto iDims = input->getDims();
if (dims.size() != iDims.size()) IT_ASSERT(rank == input->getRank());
return {}; for (auto i = 0; i < (int)rank; i++) {
int nDims = dims.size();
for (auto i = 0; i < nDims; i++) {
if (i == dim) { if (i == dim) {
n += iDims.at(i); n += iDims.at(i);
continue; continue;
} }
if (iDims.at(i) != dims.at(i)) IT_ASSERT(iDims.at(i) == dims.at(i));
return {};
} }
} }
dims[dim] = n; dims[dim] = n;

View File

@ -93,8 +93,7 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
int on = n, oc = f; int on = n, oc = f;
int oh = 0, ow = 0; int oh = 0, ow = 0;
// For NCHW+FCRS layout, C of input is divisable by C of weight // For NCHW+FCRS layout, C of input is divisable by C of weight
if (input->getDims()[1] % weight->getDims()[1] != 0) IT_ASSERT(input->getDims()[1] % weight->getDims()[1] == 0);
return {};
// Set padding size // Set padding size
if (padding == PaddingMode::Other) { if (padding == PaddingMode::Other) {
oh = (h - (r - sh) * dh + ph * 2) / sh; oh = (h - (r - sh) * dh + ph * 2) / sh;
@ -151,8 +150,7 @@ ConvTransposed2dObj::inferShape(const TensorVec &inputs) const {
auto c = weight->getDims()[1]; auto c = weight->getDims()[1];
auto r = weight->getDims()[2]; auto r = weight->getDims()[2];
auto s = weight->getDims()[3]; auto s = weight->getDims()[3];
if (f != weight->getDims()[0]) IT_ASSERT(f == weight->getDims()[0]);
return {};
int on = n, oc = c * group; int on = n, oc = c * group;
int oh = 0, ow = 0; int oh = 0, ow = 0;
@ -232,8 +230,7 @@ ConvBackwardFilterObj::inferShape(const TensorVec &inputs) const {
int on = n, oc = f; int on = n, oc = f;
int oh = 0, ow = 0; int oh = 0, ow = 0;
// For NCHW+FCRS layout, C of input is divisable by C of weight // For NCHW+FCRS layout, C of input is divisable by C of weight
if (inputX->getDims()[1] % diffY->getDims()[1] != 0) IT_ASSERT(inputX->getDims()[1] % diffY->getDims()[1] == 0);
return {};
// Set padding size // Set padding size
if (padding == PaddingMode::Other) { if (padding == PaddingMode::Other) {
oh = (h - (r - sh) * dh + ph * 2) / sh; oh = (h - (r - sh) * dh + ph * 2) / sh;

View File

@ -9,8 +9,8 @@ DetObj::DetObj(GraphObj *graph, Tensor input, Tensor output, Mode mode)
optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> DetObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0]; const auto A = inputs[0];
auto input = A->getDims(); auto input = A->getDims();
int length = input.size(); int rank = A->getRank();
if (length == 2) { if (rank == 2) {
std::vector<int> output = {1}; std::vector<int> output = {1};
return {{output}}; return {{output}};
} else { } else {

View File

@ -1,4 +1,5 @@
#include "operators/element_wise.h" #include "operators/element_wise.h"
#include "utils/operator_utils.h"
namespace infini { namespace infini {
ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0, ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
@ -9,31 +10,8 @@ ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
optional<vector<Shape>> optional<vector<Shape>>
ElementWiseObj::inferShape(const TensorVec &inputs) const { ElementWiseObj::inferShape(const TensorVec &inputs) const {
// For now,we only process the same dims here, broardcast will be considered
// in the opt layer.
const auto A = inputs[0], B = inputs[1]; const auto A = inputs[0], B = inputs[1];
int max_len = std::max(A->getDims().size(), B->getDims().size()); auto res = infer_broadcast(A->getDims(), B->getDims());
std::vector<int> A_(max_len, 1);
std::vector<int> B_(max_len, 1);
std::vector<int> res(max_len, 1);
memcpy(A_.data() + max_len - A->getDims().size(), A->getDims().data(),
A->getDims().size() * sizeof(int));
memcpy(B_.data() + max_len - B->getDims().size(), B->getDims().data(),
B->getDims().size() * sizeof(int));
// std::copy(A->getDims().begin(), A->getDims().end(), A_.begin() + (max_len
// - A->getDims().size())); std::copy(B->getDims().begin(),
// B->getDims().end(), B_.begin() + (max_len - B->getDims().size()));
// std::copy(A->getDims().rbegin(), A->getDims().rend(), A_.rbegin());
// std::copy(B->getDims().rbegin(), B->getDims().rend(), B_.rbegin());
for (int i = 0; i < max_len; ++i) {
if (A_[i] == B_[i] || (A_[i] == 1 || B_[i] == 1)) {
res[i] = std::max(A_[i], B_[i]);
} else {
return {};
}
}
return {{res}}; return {{res}};
} }
@ -69,9 +47,8 @@ MSELossObj::MSELossObj(GraphObj *graph, Tensor input0, Tensor input1,
optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> MSELossObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0], B = inputs[1]; const auto A = inputs[0], B = inputs[1];
if (A->getDims().size() != B->getDims().size() || IT_ASSERT(A->getRank() == B->getRank());
A->getDims() != B->getDims()) IT_ASSERT(A->getDims() == B->getDims());
return {};
if (reductionMode == None) { if (reductionMode == None) {
return {{A->getDims()}}; return {{A->getDims()}};

View File

@ -1,16 +1,18 @@
#include "operators/extend.h" #include "operators/extend.h"
#include "utils/operator_utils.h"
namespace infini { namespace infini {
/**
 * Build an Extend operator that enlarges `input` along axis `dim`; `num` is
 * stored as given.
 *
 * `dim` is passed through get_real_axis together with the rank of `input`
 * (presumably to normalise negative axes; confirm against
 * utils/operator_utils.h), and the result is stored in the member.
 */
ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
                     int num)
    : OperatorObj(OpType::Extend, {input}, {output}), dim(dim), num(num) {
    int rank = input->getRank();
    // The parameter `dim` shadows the member of the same name, so the member
    // must be addressed through `this`; a plain `dim = ...` would only change
    // the local parameter and leave the stored axis untouched.
    this->dim = get_real_axis(dim, rank);
    IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ExtendObj::inferShape(const TensorVec &inputs) const {
auto ret = inputs[0]->getDims(); auto ret = inputs[0]->getDims();
IT_ASSERT((size_t)dim < ret.size());
ret[dim] = ret[dim] * (num + 1); ret[dim] = ret[dim] * (num + 1);
return {{ret}}; return {{ret}};
} }

View File

@ -1,9 +1,12 @@
#include "operators/gather.h" #include "operators/gather.h"
#include "utils/operator_utils.h"
namespace infini { namespace infini {
/**
 * Build a Gather operator that reads `input` at `indices` along `axis`.
 *
 * `axis` is passed through get_real_axis together with the rank of `input`
 * (presumably to normalise negative axes; confirm against
 * utils/operator_utils.h), and the result is stored in the member.
 */
GatherObj::GatherObj(GraphObj *graph, Tensor input, Tensor indices,
                     Tensor output, int axis)
    : OperatorObj(OpType::Gather, {input, indices}, {output}), axis(axis) {
    int rank = input->getRank();
    // The parameter `axis` shadows the member of the same name, so the member
    // must be addressed through `this`; a plain `axis = ...` would only
    // change the local parameter and leave the stored axis untouched.
    this->axis = get_real_axis(axis, rank);
    IT_ASSERT(checkValid(graph));
}
@ -11,12 +14,6 @@ optional<vector<Shape>> GatherObj::inferShape(const TensorVec &inputs) const {
auto dims0 = inputs[0]->getDims(); auto dims0 = inputs[0]->getDims();
auto dims1 = inputs[1]->getDims(); auto dims1 = inputs[1]->getDims();
if (axis < 0)
IT_TODO_HALT();
if ((size_t)axis >= dims0.size())
return {};
IT_ASSERT(CheckIndexValid()); IT_ASSERT(CheckIndexValid());
Shape dim = dims0; Shape dim = dims0;

View File

@ -1,4 +1,6 @@
#include "operators/matmul.h" #include "operators/matmul.h"
#include "utils/operator_utils.h"
#include <numeric>
namespace infini { namespace infini {
@ -9,25 +11,23 @@ MatmulObj::MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, bool transA,
transA(transA), transB(transB), act(act), b(1) { transA(transA), transB(transB), act(act), b(1) {
auto shape_a = A->getDims(); auto shape_a = A->getDims();
auto shape_b = B->getDims(); auto shape_b = B->getDims();
int dimA = shape_a.size(), dimB = shape_b.size(); int rankA = A->getRank();
IT_ASSERT(dimA >= 2 && dimB >= 2); int rankB = B->getRank();
IT_ASSERT(rankA >= 2 && rankB >= 2);
b = 1; Shape shape_a1(shape_a.begin(), shape_a.begin() + (rankA - 2));
if (dimA <= 3 && dimB <= 3) { Shape shape_b1(shape_b.begin(), shape_b.begin() + (rankB - 2));
int b1 = dimA == 2 ? 1 : A->getDims()[0]; auto ret = infer_broadcast(shape_a1, shape_b1);
int b2 = dimB == 2 ? 1 : B->getDims()[0]; if (ret.empty()) {
b = 1;
b = std::max(b1, b2);
} else { } else {
IT_ASSERT_TODO(dimA == dimB); b = std::accumulate(ret.begin(), ret.end(), 1);
for (size_t i = 0; i < shape_a.size() - 2; ++i) {
IT_ASSERT_TODO(shape_a[i] == shape_b[i]);
b *= shape_a[i];
}
} }
auto kA = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin());
auto kB = *(transB ? shape_b.rbegin() : shape_b.rbegin() + 1);
IT_ASSERT(kA == kB);
m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1); m = *(transA ? shape_a.rbegin() : shape_a.rbegin() + 1);
n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin()); n = *(transB ? shape_b.rbegin() + 1 : shape_b.rbegin());
k = *(transA ? shape_a.rbegin() + 1 : shape_a.rbegin()); k = kA;
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
@ -42,43 +42,16 @@ string MatmulObj::toString() const {
optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> MatmulObj::inferShape(const TensorVec &inputs) const {
auto A = inputs[0], B = inputs[1]; auto A = inputs[0], B = inputs[1];
int dimA = A->getDims().size(), dimB = B->getDims().size(); auto shapeA = A->getDims();
auto shapeB = B->getDims();
if (dimA > 3 || dimB > 3) { int rankA = A->getRank();
// no broadcast int rankB = B->getRank();
auto shape_a = inputs[0]->getDims(); Shape shapeA1(shapeA.begin(), shapeA.begin() + (rankA - 2));
auto it = shape_a.rbegin(); Shape shapeB1(shapeB.begin(), shapeB.begin() + (rankB - 2));
*it++ = n; Shape ret = infer_broadcast(shapeA1, shapeB1);
*it++ = m; ret.emplace_back(m);
return {{std::move(shape_a)}}; ret.emplace_back(n);
} return {{ret}};
int b1 = dimA == 2 ? 1 : A->getDims()[0];
int b2 = dimB == 2 ? 1 : B->getDims()[0];
int b = std::max(b1, b2);
int m = transA ? A->getDims()[dimA - 1] : A->getDims()[dimA - 2];
int n = transB ? B->getDims()[dimB - 2] : B->getDims()[dimB - 1];
int kA = transA ? A->getDims()[dimA - 2] : A->getDims()[dimA - 1];
int kB = transB ? B->getDims()[dimB - 1] : B->getDims()[dimB - 2];
if ((dimA != 2 && dimA != 3) || (dimB != 2 && dimB != 3)) {
printf("Bad input dim: dimA = %d, dimB = %d\n", dimA, dimB);
return {};
}
if (b1 != 1 && b2 != 1 && b1 != b2) {
printf("Bad batch size b1 = %d, b2 = %d\n", b1, b2);
return {};
}
if (kA != kB) {
printf("Bad K: kA = %d, kB = %d\n", kA, kB);
return {};
}
if (dimA == 2 && dimB == 2) {
return {{{m, n}}};
} else {
return {{{b, m, n}}};
}
} }
vector<int> MatmulObj::getWorkloadVector() const { vector<int> MatmulObj::getWorkloadVector() const {

View File

@ -9,7 +9,7 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
else { else {
auto nAxis = (*axes).size(); auto nAxis = (*axes).size();
IT_ASSERT(_pads.size() == nAxis * 2); IT_ASSERT(_pads.size() == nAxis * 2);
auto nDims = input->getDims().size(); auto nDims = input->getRank();
pads = vector<int>(nDims * 2, 0); pads = vector<int>(nDims * 2, 0);
for (size_t i = 0; i < nAxis; ++i) { for (size_t i = 0; i < nAxis; ++i) {
@ -24,13 +24,11 @@ PadObj::PadObj(GraphObj *graph, Tensor input, Tensor output,
optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> PadObj::inferShape(const TensorVec &inputs) const {
auto dims = inputs[0]->getDims(); auto dims = inputs[0]->getDims();
int nDims = dims.size(); int rank = inputs[0]->getRank();
if (nDims * 2 != (int)pads.size()) IT_ASSERT(rank * 2 == (int)pads.size());
return {}; for (int i = 0; i < rank; ++i) {
for (int i = 0; i < nDims; ++i) { IT_ASSERT(pads[i] >= 0 && pads[i + rank] >= 0);
if (pads[i] < 0 || pads[i + nDims] < 0) dims[i] += pads[i] + pads[i + rank];
return {};
dims[i] += pads[i] + pads[i + nDims];
} }
return {{dims}}; return {{dims}};

View File

@ -16,13 +16,13 @@ PoolingObj::PoolingObj(GraphObj *graph, OpType optype, Tensor input,
optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> PoolingObj::inferShape(const TensorVec &inputs) const {
const auto &input = inputs[0]; const auto &input = inputs[0];
auto h = input->getDims()[input->getDims().size() - 2], auto h = input->getDims()[input->getRank() - 2],
w = input->getDims()[input->getDims().size() - 1]; w = input->getDims()[input->getRank() - 1];
int oh = (h - (kh - sh) + ph * 2) / sh; int oh = (h - (kh - sh) + ph * 2) / sh;
int ow = (w - (kw - sw) + pw * 2) / sw; int ow = (w - (kw - sw) + pw * 2) / sw;
auto ret = input->getDims(); auto ret = input->getDims();
ret[input->getDims().size() - 2] = oh; ret[input->getRank() - 2] = oh;
ret[input->getDims().size() - 1] = ow; ret[input->getRank() - 1] = ow;
return {{ret}}; return {{ret}};
} }

View File

@ -1,15 +1,14 @@
#include "operators/reduce_mean.h" #include "operators/reduce_mean.h"
#include "utils/operator_utils.h"
namespace infini { namespace infini {
ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output, ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<vector<int>> &_axes, bool keepDims) const optional<vector<int>> &_axes, bool keepDims)
: OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) { : OperatorObj(OpType::ReduceMean, {input}, {output}), keepDims(keepDims) {
const auto size = input->getDims().size(); const auto size = input->getRank();
if (_axes) { if (_axes) {
for (auto idx : *_axes) { for (auto idx : *_axes) {
if (idx < 0) idx = get_real_axis(idx, size);
IT_TODO_HALT();
IT_ASSERT((size_t)idx < size);
axes.emplace(idx); axes.emplace(idx);
} }
} else } else
@ -25,6 +24,7 @@ bool ReduceMeanObj::isReduced(int idx) const {
optional<vector<Shape>> optional<vector<Shape>>
ReduceMeanObj::inferShape(const TensorVec &inputs) const { ReduceMeanObj::inferShape(const TensorVec &inputs) const {
auto dims = inputs[0]->getDims(); auto dims = inputs[0]->getDims();
auto rank = inputs[0]->getRank();
if (keepDims) { if (keepDims) {
Shape ret = dims; Shape ret = dims;
@ -33,7 +33,7 @@ ReduceMeanObj::inferShape(const TensorVec &inputs) const {
return {{ret}}; return {{ret}};
} else { } else {
Shape ret; Shape ret;
for (size_t i = 0; i < dims.size(); ++i) { for (size_t i = 0; i < rank; ++i) {
if (!isReduced(i)) if (!isReduced(i))
ret.emplace_back(dims[i]); ret.emplace_back(dims[i]);
} }

View File

@ -1,4 +1,5 @@
#include "operators/reshape.h" #include "operators/reshape.h"
#include "utils/operator_utils.h"
namespace infini { namespace infini {
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims) ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
@ -8,10 +9,10 @@ ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
size_t size = 1; size_t size = 1;
for (size_t i = 0; i < dims.size(); ++i) for (size_t i = 0; i < dims.size(); ++i) {
size *= dims.at(i); size *= dims.at(i);
if (size != inputs[0]->size()) }
return {}; IT_ASSERT(size == inputs[0]->size());
return {{dims}}; return {{dims}};
} }
@ -41,22 +42,18 @@ vector<int> ReshapeObj::getOpAttrVector() const {
FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis) FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
: OperatorObj(OpType::Flatten, {input}, {output}) { : OperatorObj(OpType::Flatten, {input}, {output}) {
if (_axis >= 0 && (size_t)_axis < input->getDims().size()) int rank = input->getRank();
axis = _axis; axis = get_real_axis(_axis, rank);
else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size())
axis = _axis + input->getDims().size();
else
IT_ASSERT(0);
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }
optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) const {
int sizeB = 1, sizeE = 1; int sizeB = 1, sizeE = 1;
auto dims = getInputs(0)->getDims(); auto dims = getInputs(0)->getDims();
int ndim = dims.size(); int rank = getInputs(0)->getRank();
for (int i = 0; i < ndim; ++i) for (int i = 0; i < rank; ++i) {
((i < axis) ? sizeB : sizeE) *= dims.at(i); ((i < axis) ? sizeB : sizeE) *= dims.at(i);
}
return {{{sizeB, sizeE}}}; return {{{sizeB, sizeE}}};
} }

View File

@ -45,11 +45,11 @@ void ResizeObj::init(const Tensor &input, const Tensor &sizes,
if (ECoordinateTransMode::tfCropAndResize == coMode) { if (ECoordinateTransMode::tfCropAndResize == coMode) {
IT_ASSERT(nullptr != roi); IT_ASSERT(nullptr != roi);
inputs.push_back(roi); inputs.push_back(roi);
IT_ASSERT(roi->getDims().size() == 1); IT_ASSERT(roi->getRank() == 1);
IT_ASSERT((size_t)roi->getDims()[0] == this->axes.size() * 2); IT_ASSERT((size_t)roi->getDims()[0] == this->axes.size() * 2);
// init roi_start = 0;roi_end =1 // init roi_start = 0;roi_end =1
size_t nDims = input->getDims().size(); size_t nDims = input->getRank();
for (size_t i = 0; i < nDims; ++i) { for (size_t i = 0; i < nDims; ++i) {
this->roi.emplace_back(0); this->roi.emplace_back(0);
} }
@ -75,24 +75,26 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
const std::optional<vector<int>> &axes) { const std::optional<vector<int>> &axes) {
IT_ASSERT(sizes != nullptr); IT_ASSERT(sizes != nullptr);
size_t size = sizes->getDims()[0]; size_t size = sizes->getDims()[0];
IT_ASSERT(size == input->getDims().size() || IT_ASSERT(size == input->getRank() ||
(axes != std::nullopt && size == (*axes).size())); (axes != std::nullopt && size == (*axes).size()));
if (axes == std::nullopt) if (axes == std::nullopt) {
for (size_t i = 0; i < input->getDims().size(); ++i) for (size_t i = 0; i < input->getRank(); ++i) {
this->axes.emplace_back(i); this->axes.emplace_back(i);
else }
} else {
// check axes // check axes
for (size_t i = 0; i < (*axes).size(); ++i) { for (size_t i = 0; i < (*axes).size(); ++i) {
auto val = (*axes)[i]; auto val = (*axes)[i];
if (val < 0) if (val < 0) {
IT_TODO_HALT(); IT_TODO_HALT();
IT_ASSERT((size_t)val < inputs[0]->getDims().size()); }
IT_ASSERT((size_t)val < inputs[0]->getRank());
this->axes.emplace_back(val); this->axes.emplace_back(val);
} }
}
// init this->scales // init this->scales
for (size_t i = 0; i < input->getDims().size(); ++i) { for (size_t i = 0; i < input->getRank(); ++i) {
this->scales.emplace_back(1); this->scales.emplace_back(1);
} }
@ -109,9 +111,10 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
int n = this->axes.size(); int n = this->axes.size();
switch (ratioPolicy) { switch (ratioPolicy) {
case EKeepAspectRatioPolicy::stretch: case EKeepAspectRatioPolicy::stretch:
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i) {
scales[this->axes[i]] = scales[this->axes[i]] =
(float)data[i] / (float)inDims[this->axes[i]]; (float)data[i] / (float)inDims[this->axes[i]];
}
break; break;
case EKeepAspectRatioPolicy::notLarger: { case EKeepAspectRatioPolicy::notLarger: {
float scale = (float)data[0] / (float)inDims[this->axes[0]]; float scale = (float)data[0] / (float)inDims[this->axes[0]];
@ -119,8 +122,9 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
auto tmp = (float)data[i] / (float)inDims[this->axes[i]]; auto tmp = (float)data[i] / (float)inDims[this->axes[i]];
scale = scale < tmp ? scale : tmp; scale = scale < tmp ? scale : tmp;
} }
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i) {
scales[this->axes[i]] = scale; scales[this->axes[i]] = scale;
}
break; break;
} }
case EKeepAspectRatioPolicy::notSmaller: { case EKeepAspectRatioPolicy::notSmaller: {
@ -129,8 +133,9 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
auto tmp = (float)data[i] / (float)inDims[this->axes[i]]; auto tmp = (float)data[i] / (float)inDims[this->axes[i]];
scale = scale > tmp ? scale : tmp; scale = scale > tmp ? scale : tmp;
} }
for (int i = 0; i < n; ++i) for (int i = 0; i < n; ++i) {
scales[this->axes[i]] = scale; scales[this->axes[i]] = scale;
}
break; break;
} }
default: default:
@ -142,7 +147,7 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales,
const std::optional<vector<int>> &axes) { const std::optional<vector<int>> &axes) {
IT_ASSERT(scales != nullptr); IT_ASSERT(scales != nullptr);
size_t size = scales->getDims()[0]; size_t size = scales->getDims()[0];
IT_ASSERT(size == input->getDims().size() || IT_ASSERT(size == input->getRank() ||
(axes != std::nullopt && size == (*axes).size())); (axes != std::nullopt && size == (*axes).size()));
// copy scales data to host. // copy scales data to host.
@ -155,27 +160,29 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales,
(void *)data, scales->getRawDataPtr<void *>(), scales->getBytes()); (void *)data, scales->getRawDataPtr<void *>(), scales->getBytes());
// init this->scales // init this->scales
for (size_t i = 0; i < input->getDims().size(); ++i) { for (size_t i = 0; i < input->getRank(); ++i) {
this->scales.emplace_back(1); this->scales.emplace_back(1);
} }
if (axes == std::nullopt) if (axes == std::nullopt) {
for (size_t i = 0; i < input->getDims().size(); ++i) { for (size_t i = 0; i < input->getRank(); ++i) {
this->axes.emplace_back(i); this->axes.emplace_back(i);
IT_ASSERT(data[i] > 0); IT_ASSERT(data[i] > 0);
this->scales[i] = data[i]; this->scales[i] = data[i];
} }
else } else {
// check axes // check axes
for (size_t i = 0; i < (*axes).size(); ++i) { for (size_t i = 0; i < (*axes).size(); ++i) {
auto val = (*axes)[i]; auto val = (*axes)[i];
if (val < 0) if (val < 0) {
IT_TODO_HALT(); IT_TODO_HALT();
IT_ASSERT((size_t)val < inputs[0]->getDims().size()); }
IT_ASSERT((size_t)val < inputs[0]->getRank());
this->axes.emplace_back(val); this->axes.emplace_back(val);
IT_ASSERT(data[i] > 0); IT_ASSERT(data[i] > 0);
this->scales[val] = data[i]; this->scales[val] = data[i];
} }
}
} }
vector<DataType> ResizeObj::inferDataType(const TensorVec &inputs) const { vector<DataType> ResizeObj::inferDataType(const TensorVec &inputs) const {
@ -202,8 +209,8 @@ float ResizeObj::round_int(float x) const {
optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ResizeObj::inferShape(const TensorVec &inputs) const {
auto inDims = inputs[0]->getDims(); auto inDims = inputs[0]->getDims();
Shape ret = inDims; Shape ret = inDims;
int nDim = inDims.size(); int rank = inputs[0]->getRank();
for (int i = 0; i < nDim; ++i) { for (int i = 0; i < rank; ++i) {
int size = round_int(scales[i] * inDims[i]); int size = round_int(scales[i] * inDims[i]);
ret[i] = size; ret[i] = size;
} }
@ -217,12 +224,14 @@ std::string ResizeObj::toString() const {
<< "[" << getGuid() << "]"; << "[" << getGuid() << "]";
os << "("; os << "(";
os << vecToString(inputs[0]->getDims()) << ","; os << vecToString(inputs[0]->getDims()) << ",";
if (inputs.size() == 3) if (inputs.size() == 3) {
os << "roi=" << vecToString(inputs[2]->getDims()) << ","; os << "roi=" << vecToString(inputs[2]->getDims()) << ",";
if (isResizeBySizes()) }
if (isResizeBySizes()) {
os << "sizes=" << vecToString(inputs[1]->getDims()) << ","; os << "sizes=" << vecToString(inputs[1]->getDims()) << ",";
else } else {
os << "scales=" << vecToString(inputs[1]->getDims()) << ","; os << "scales=" << vecToString(inputs[1]->getDims()) << ",";
}
os << "axes=" << vecToString(axes) << ","; os << "axes=" << vecToString(axes) << ",";
os << "coMode=" << enum_to_underlying(coMode) << ","; os << "coMode=" << enum_to_underlying(coMode) << ",";
os << "nearestMode=" << enum_to_underlying(nearestMode) << ","; os << "nearestMode=" << enum_to_underlying(nearestMode) << ",";
@ -230,16 +239,18 @@ std::string ResizeObj::toString() const {
os << "input=" << inputs[0]->getGuid() << ","; os << "input=" << inputs[0]->getGuid() << ",";
os << inputs[1]->getGuid() << ","; os << inputs[1]->getGuid() << ",";
if (inputs.size() == 3) if (inputs.size() == 3) {
os << inputs[2]->getGuid() << ","; os << inputs[2]->getGuid() << ",";
}
os << "output=" << outputs[0]->getGuid() << ")"; os << "output=" << outputs[0]->getGuid() << ")";
return os.str(); return os.str();
} }
vector<int> ResizeObj::getWorkloadVector() const { vector<int> ResizeObj::getWorkloadVector() const {
vector<int> ret = inputs[0]->getDims(); vector<int> ret = inputs[0]->getDims();
for (size_t i = 0; i < outputs[0]->getDims().size(); ++i) for (size_t i = 0; i < outputs[0]->getRank(); ++i) {
ret.emplace_back(outputs[0]->getDims()[i]); ret.emplace_back(outputs[0]->getDims()[i]);
}
// ratioPolicy only effects output shape, so did not need // ratioPolicy only effects output shape, so did not need
// here. // here.
ret.emplace_back(enum_to_underlying(coMode)); ret.emplace_back(enum_to_underlying(coMode));

View File

@ -1,15 +1,12 @@
#include "operators/softmax.h" #include "operators/softmax.h"
#include "utils/operator_utils.h"
namespace infini { namespace infini {
SoftmaxObj::SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int _axis) SoftmaxObj::SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
: OperatorObj(OpType::Softmax, {input}, {output}) { : OperatorObj(OpType::Softmax, {input}, {output}) {
if (_axis >= 0 && (size_t)_axis < input->getDims().size()) int rank = input->getRank();
axis = _axis; axis = get_real_axis(_axis, rank);
else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size())
axis = _axis + input->getDims().size();
else
IT_ASSERT(0);
IT_ASSERT(checkValid(graph)); IT_ASSERT(checkValid(graph));
} }

View File

@ -1,4 +1,5 @@
#include "operators/split.h" #include "operators/split.h"
#include "utils/operator_utils.h"
#include <numeric> #include <numeric>
namespace infini { namespace infini {
@ -7,6 +8,8 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
: OperatorObj(OpType::Split, {input}, : OperatorObj(OpType::Split, {input},
((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))), ((!outputs) ? TensorVec(num, nullptr) : std::move(*outputs))),
dim(dim), num(num), ratio({}) { dim(dim), num(num), ratio({}) {
int rank = input->getRank();
dim = get_real_axis(dim, rank);
int dimSize = input->getDims().at(dim); int dimSize = input->getDims().at(dim);
int pieceSize = dimSize / num; int pieceSize = dimSize / num;
int lastSize = dimSize - pieceSize * num; int lastSize = dimSize - pieceSize * num;
@ -26,6 +29,8 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
: OperatorObj(OpType::Split, {input}, : OperatorObj(OpType::Split, {input},
((!outputs) ? TensorVec{nullptr} : (*outputs))), ((!outputs) ? TensorVec{nullptr} : (*outputs))),
dim(dim), num(-1), ratio(ratio) { dim(dim), num(-1), ratio(ratio) {
int rank = input->getRank();
dim = get_real_axis(dim, rank);
num = ratio.size(); num = ratio.size();
if (!outputs) { if (!outputs) {
TensorVec tmp(num, nullptr); TensorVec tmp(num, nullptr);
@ -35,13 +40,11 @@ SplitObj::SplitObj(GraphObj *graph, Tensor input,
} }
optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> SplitObj::inferShape(const TensorVec &inputs) const {
if (num == -1 || ratio.size() == 0) IT_ASSERT(num != -1 && ratio.size() != 0);
return {};
auto inputDims = inputs[0]->getDims(); auto inputDims = inputs[0]->getDims();
int totalSize = inputDims.at(dim); int totalSize = inputDims.at(dim);
int ratioSum = std::accumulate(ratio.begin(), ratio.end(), 0); int ratioSum = std::accumulate(ratio.begin(), ratio.end(), 0);
if (totalSize % ratioSum != 0) IT_ASSERT(totalSize % ratioSum == 0);
return {};
int pieceSize = totalSize / ratioSum; int pieceSize = totalSize / ratioSum;

View File

@ -4,26 +4,32 @@ namespace infini {
// Builds a Transpose operator over `input`.
// An empty `permute` selects the identity permutation; otherwise `permute`
// must contain exactly one entry per input dimension.
TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
                           vector<int> permute)
    : OperatorObj(OpType::Transpose, {input}, {output}) {
    auto rank = input->getRank();
    if (permute.empty()) {
        // Size the container before writing through operator[]; indexing
        // past the current size is out of bounds.
        // NOTE(review): assumes transposePermute is a vector<int> (it is
        // move-assigned from one below) - confirm against the header.
        transposePermute.resize(rank);
        for (size_t i = 0; i < rank; ++i) {
            transposePermute[i] = static_cast<int>(i);
        }
    } else {
        IT_ASSERT(rank == permute.size());
        transposePermute = std::move(permute);
    }
    IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> optional<vector<Shape>>
TransposeObj::inferShape(const TensorVec &inputs) const { TransposeObj::inferShape(const TensorVec &inputs) const {
const auto A = inputs[0]; const auto A = inputs[0];
auto input = A->getDims(); auto input_dim = A->getDims();
auto output = input; auto output_dim = input_dim;
int rank = A->getRank();
for (int i = 0; i < 4; ++i) { for (auto index : transposePermute) {
output[i] = input[transposePermute[i]]; IT_ASSERT(index < rank);
} }
return {{output}}; for (int i = 0; i < rank; ++i) {
output_dim[i] = input_dim[transposePermute[i]];
}
return {{output_dim}};
} }
std::string TransposeObj::toString() const { std::string TransposeObj::toString() const {

View File

@ -183,46 +183,54 @@ vector<int> CastObj::getOpAttrVector() const { return {type.underlying()}; }
DataType CastObj::getOutputDataType() const { DataType CastObj::getOutputDataType() const {
switch (castType) { switch (castType) {
case CastObj::Float2Int64: case CastType::Float2Float16:
return DataType::Float16;
case CastType::Float2Int64:
return DataType::Int64; return DataType::Int64;
case CastObj::Float2Int32: case CastType::Float2Int32:
return DataType::Int32; return DataType::Int32;
case CastObj::Float2Int16: case CastType::Float2Int16:
return DataType::Int16; return DataType::Int16;
case CastObj::Float2Int8: case CastType::Float2Int8:
return DataType::Int8; return DataType::Int8;
case CastObj::Int322Float: case CastType::Int322Float:
return DataType::Float32; return DataType::Float32;
case CastObj::Int322Int8: case CastType::Int322Int8:
return DataType::Int8; return DataType::Int8;
case CastObj::Int322Int16: case CastType::Int322Int16:
return DataType::Int16; return DataType::Int16;
case CastObj::Int162Float: case CastType::Int162Float:
return DataType::Float32; return DataType::Float32;
case CastObj::Int162Int32: case CastType::Int162Int32:
return DataType::Int32; return DataType::Int32;
case CastObj::Int82Float: case CastType::Int82Float:
return DataType::Float32; return DataType::Float32;
case CastObj::Int82Int16: case CastType::Int82Int16:
return DataType::Int16; return DataType::Int16;
case CastObj::Int82Int32: case CastType::Int82Int32:
return DataType::Int32; return DataType::Int32;
case CastObj::Uint82Float: case CastType::Uint82Float:
return DataType::Float32; return DataType::Float32;
case CastObj::Uint82Int32: case CastType::Uint82Int32:
return DataType::Int32; return DataType::Int32;
case CastObj::Uint82Int64: case CastType::Uint82Int64:
return DataType::Int64; return DataType::Int64;
case CastObj::Int322Int64: case CastType::Int322Int64:
return DataType::Int64; return DataType::Int64;
case CastObj::Int642Int32: case CastType::Int642Int32:
return DataType::Int32; return DataType::Int32;
case CastObj::Int642Uint32: case CastType::Int642Uint32:
return DataType::UInt32; return DataType::UInt32;
case CastObj::Int642Float: case CastType::Int642Float:
return DataType::Float32; return DataType::Float32;
case CastObj::Uint322Int64: case CastType::Uint322Int64:
return DataType::Int64; return DataType::Int64;
case CastType::Float162Float:
return DataType::Float32;
case CastType::BFloat162Float:
return DataType::Float32;
case CastType::Float2BFloat16:
return DataType::BFloat16;
default: default:
IT_TODO_HALT(); IT_TODO_HALT();
} }
@ -234,7 +242,7 @@ ShapeObj::ShapeObj(GraphObj *graph, Tensor input, Tensor output)
} }
optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) const { optional<vector<Shape>> ShapeObj::inferShape(const TensorVec &inputs) const {
return {{{static_cast<int>(inputs[0]->getDims().size())}}}; return {{{static_cast<int>(inputs[0]->getRank())}}};
} }
std::string ShapeObj::toString() const { std::string ShapeObj::toString() const {

View File

@ -27,4 +27,17 @@ float fp16_to_float(const uint16_t x) {
u.u32 = r; u.u32 = r;
return u.f32; return u.f32;
} }
/**
 * @brief Convert a 32-bit float to its bfloat16 bit pattern.
 *
 * The result is the upper 16 bits of the float's representation; the lower
 * 16 bits are dropped (truncation, no rounding).
 *
 * @param x Value to convert.
 * @return The 16-bit bfloat16 pattern.
 */
uint16_t float_to_bfp16(const float x) {
    Uf32 u;
    u.f32 = x;
    uint16_t bits = static_cast<uint16_t>(u.u32 >> 16);
    // A NaN whose payload lies only in the dropped low 16 bits would
    // truncate to the infinity pattern (exponent all ones, mantissa zero).
    // Force a mantissa bit so NaN stays NaN; all other inputs are unchanged.
    if ((u.u32 & 0x7FFFFFFFu) > 0x7F800000u) {
        bits = static_cast<uint16_t>(bits | 0x0040u);
    }
    return bits;
}
/**
 * @brief Convert a bfloat16 bit pattern to a 32-bit float.
 *
 * The 16 input bits become the upper half of the float's representation and
 * the lower 16 bits are zero.
 *
 * @param x The 16-bit bfloat16 pattern.
 * @return The corresponding float value.
 */
float bfp16_to_fp32(const uint16_t x) {
    Uf32 u;
    // Widen before shifting: a bare `x << 16` promotes x to signed int, and
    // inputs with the top bit set would shift into the sign bit.
    u.u32 = static_cast<uint32_t>(x) << 16;
    return u.f32;
}
} // namespace infini } // namespace infini

View File

@ -12,7 +12,7 @@ void saveTensorData(TensorObj *tensor, std::string file_path) {
#ifdef TENSOR_PROTOBUF #ifdef TENSOR_PROTOBUF
data::Tensor temp; data::Tensor temp;
temp.set_id("tensor_id"); temp.set_id("tensor_id");
for (size_t i = 0; i < tensor->getDims().size(); ++i) { for (size_t i = 0; i < tensor->getRank(); ++i) {
temp.add_shape(tensor->getDims()[i]); temp.add_shape(tensor->getDims()[i]);
} }
temp.set_layout(data::LAYOUT_NHWC); temp.set_layout(data::LAYOUT_NHWC);

View File

@ -0,0 +1,44 @@
#include "utils/operator_utils.h"
namespace infini {
/**
 * @brief Compute the shape produced by broadcasting two shapes together.
 *
 * The shapes are aligned at their trailing dimension; the shorter one is
 * treated as if it had leading dimensions of size 1. Each aligned pair must
 * be equal or contain a 1 (checked with IT_ASSERT), and the larger of the two
 * becomes the output dimension.
 *
 * @param A First shape.
 * @param B Second shape.
 * @return The broadcast shape; empty when both inputs are empty.
 */
Shape infer_broadcast(const Shape &A, const Shape &B) {
    const size_t rankA = A.size();
    const size_t rankB = B.size();
    const size_t rank = std::max(rankA, rankB);
    // Number of implicit leading 1s for each operand. Reading through an
    // offset avoids copying the inputs and inserting at the front per dim.
    const size_t padA = rank - rankA;
    const size_t padB = rank - rankB;

    Shape ret;
    ret.reserve(rank);
    for (size_t i = 0; i < rank; ++i) {
        const auto dimA = i < padA ? 1 : A[i - padA];
        const auto dimB = i < padB ? 1 : B[i - padB];
        IT_ASSERT(dimA == dimB || dimA == 1 || dimB == 1);
        ret.emplace_back(std::max(dimA, dimB));
    }
    return ret;
}
// Maps an axis in [-rank, rank - 1] to its non-negative equivalent.
// Negative values count back from the last dimension.
// Both preconditions are enforced with IT_ASSERT.
int get_real_axis(const int &axis, const int &rank) {
    IT_ASSERT(rank >= 1);
    IT_ASSERT(axis >= -rank && axis <= (rank - 1));
    return axis < 0 ? rank + axis : axis;
}
} // namespace infini

View File

@ -28,4 +28,4 @@ TEST(Hash, OperatorHash) {
EXPECT_NE(key1.hash, key2.hash); EXPECT_NE(key1.hash, key2.hash);
} }
} // namespace infini } // namespace infini

View File

@ -27,6 +27,30 @@ TEST(Matmul, ShapeInference) {
auto C = matmul->getOutputs()[0]; auto C = matmul->getOutputs()[0];
EXPECT_EQ(C->getDims(), (Shape{3, 4, 2})); EXPECT_EQ(C->getDims(), (Shape{3, 4, 2}));
} }
{
Graph g = make_ref<GraphObj>(runtime);
auto A = g->addTensor(Shape{1, 2, 3, 5});
auto B = g->addTensor(Shape{1, 1, 5, 2});
auto matmul = g->addOp<MatmulObj>(A, B, nullptr);
auto C = matmul->getOutputs()[0];
EXPECT_EQ(C->getDims(), (Shape{1, 2, 3, 2}));
}
{
Graph g = make_ref<GraphObj>(runtime);
auto A = g->addTensor(Shape{2, 3, 5, 4});
auto B = g->addTensor(Shape{1, 3, 5, 2});
auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, false);
auto C = matmul->getOutputs()[0];
EXPECT_EQ(C->getDims(), (Shape{2, 3, 4, 2}));
}
{
Graph g = make_ref<GraphObj>(runtime);
auto A = g->addTensor(Shape{2, 3, 5, 4});
auto B = g->addTensor(Shape{1, 3, 2, 5});
auto matmul = g->addOp<MatmulObj>(A, B, nullptr, true, true);
auto C = matmul->getOutputs()[0];
EXPECT_EQ(C->getDims(), (Shape{2, 3, 4, 2}));
}
} }
}; // namespace infini }; // namespace infini

View File

@ -0,0 +1,32 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "operators/transpose.h"
#include "test.h"
namespace infini {
// Verifies the output shape inferred by TransposeObj for an identity
// permutation, a 4-D axis swap and a 3-D axis swap.
TEST(Transpose, ShapeInference) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    // Identity permutation leaves a 4-D shape untouched.
    {
        Graph graph = make_ref<GraphObj>(runtime);
        Tensor input = graph->addTensor({1, 2, 3, 4}, DataType::Float32);
        auto transpose =
            graph->addOp<TransposeObj>(input, nullptr, Shape{0, 1, 2, 3});
        const Shape expected{1, 2, 3, 4};
        EXPECT_EQ(transpose->getOutput()->getDims(), expected);
    }
    // Swapping the two middle axes of a 4-D tensor.
    {
        Graph graph = make_ref<GraphObj>(runtime);
        Tensor input = graph->addTensor({1, 2, 3, 4}, DataType::Float32);
        auto transpose =
            graph->addOp<TransposeObj>(input, nullptr, Shape{0, 2, 1, 3});
        const Shape expected{1, 3, 2, 4};
        EXPECT_EQ(transpose->getOutput()->getDims(), expected);
    }
    // Swapping the last two axes of a 3-D tensor.
    {
        Graph graph = make_ref<GraphObj>(runtime);
        Tensor input = graph->addTensor({2, 3, 4}, DataType::Float32);
        auto transpose =
            graph->addOp<TransposeObj>(input, nullptr, Shape{0, 2, 1});
        const Shape expected{2, 4, 3};
        EXPECT_EQ(transpose->getOutput()->getDims(), expected);
    }
}
} // namespace infini