Compare commits

...

5 Commits

Author SHA1 Message Date
Pairshoe f88aefb2ca Add: pytest for import_onnx 2022-10-26 23:57:29 +08:00
Pairshoe 0ef74a2145 Update: Rename GraphFactory -> GraphBuilder && Remove unnecessary outputs 2022-10-26 21:06:21 +08:00
Pairshoe d50fecc132 Add: python interfaced for importing onnx 2022-10-16 15:17:10 +08:00
Pairshoe 4836decc69 Add: test for class GraphFactoryObj 2022-10-16 15:17:10 +08:00
Pairshoe fe535e72a7 Add: class GraphFactory and pybind11 interfaces 2022-10-16 15:17:10 +08:00
9 changed files with 1650 additions and 1 deletions

View File

@ -0,0 +1,166 @@
#pragma once
#include "core/common.h"
#include "core/graph.h"
#include "core/operator.h"
#include "core/tensor.h"
#include "operators/G2BMM.h"
#include "operators/GBMM.h"
#include "operators/concat.h"
#include "operators/conv.h"
#include "operators/element_wise.h"
#include "operators/extend.h"
#include "operators/gather.h"
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/pad.h"
#include "operators/pooling.h"
#include "operators/reshape.h"
#include "operators/slice.h"
#include "operators/split.h"
#include "operators/unary.h"
namespace infini {
class GraphBuilderObj {
private:
Graph g;
public:
GraphBuilderObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
// tensors
Tensor tensor(Shape dim, const std::string &dtype);
// operators
// conv op
Operator conv(Tensor input, Tensor weight, Tensor output, int ph, int pw,
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
Tensor bias = nullptr);
Operator conv(Tensor input, Tensor weight, int ph, int pw, int sh = 1,
int sw = 1, int dh = 1, int dw = 1, Tensor bias = nullptr);
Operator conv(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh = 1, int sw = 1,
int dh = 1, int dw = 1, Tensor bias = nullptr);
Operator conv(Tensor input, Tensor weight, ConvBaseObj::PaddingMode pm,
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
Tensor bias = nullptr);
// matmul op
Operator matmul(Tensor A, Tensor B, Tensor C, bool transA = false,
bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
Operator matmul(Tensor A, Tensor B, bool transA = false,
bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
// conv trans op
Operator convTrans(Tensor input, Tensor weight, Tensor output, int ph,
int pw, int sh = 1, int sw = 1, int dh = 1, int dw = 1,
int oph = 0, int opw = 0, int group = 1,
Tensor bias = nullptr, ActType act = ActType::None);
Operator convTrans(Tensor input, Tensor weight, int ph, int pw, int sh = 1,
int sw = 1, int dh = 1, int dw = 1, int oph = 0,
int opw = 0, int group = 1, Tensor bias = nullptr,
ActType act = ActType::None);
Operator convTrans(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh = 1, int sw = 1,
int dh = 1, int dw = 1, int oph = 0, int opw = 0,
int group = 1, Tensor bias = nullptr,
ActType act = ActType::None);
Operator convTrans(Tensor input, Tensor weight, ConvBaseObj::PaddingMode pm,
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
int oph = 0, int opw = 0, int group = 1,
Tensor bias = nullptr, ActType act = ActType::None);
// g2bmm op
Operator g2bmm(Tensor A, Tensor B, Tensor C, const int width,
const int dilation, Tensor bias = nullptr,
ActType act = ActType::None);
Operator g2bmm(Tensor A, Tensor B, const int width, const int dilation,
Tensor bias = nullptr, ActType act = ActType::None);
// gbmm-like op
Operator gbmml(Tensor A, Tensor B, Tensor C, const int dilation,
Tensor bias = nullptr, ActType act = ActType::None);
Operator gbmml(Tensor A, Tensor B, const int dilation,
Tensor bias = nullptr, ActType act = ActType::None);
// pad op
Operator pad(Tensor input, Tensor output, const vector<int> &pads,
const optional<const vector<int>> &axis);
Operator pad(Tensor input, const vector<int> &pads,
const optional<const vector<int>> &axis);
// slice op
Operator slice(Tensor input, Tensor output, const vector<int> &starts,
const vector<int> &ends,
const optional<const vector<int>> &axis,
const optional<const vector<int>> &steps);
Operator slice(Tensor input, const vector<int> &starts,
const vector<int> &ends,
const optional<const vector<int>> &axis,
const optional<const vector<int>> &steps);
// concat op
Operator concat(TensorVec inputs, Tensor output, int dim);
Operator concat(TensorVec inputs, int dim);
// split op
Operator split(Tensor input, std::optional<TensorVec> outputs, int dim,
int num);
Operator split(Tensor input, int dim, int num);
Operator split(Tensor input, std::optional<TensorVec> outputs, int dim,
const vector<int> &ratio);
Operator split(Tensor input, int dim, const vector<int> &ratio);
// transpose op
// TODO
// extend op
Operator extend(Tensor input, Tensor output, int dim, int num);
Operator extend(Tensor input, int dim, int num);
// max pool op
Operator maxpool(Tensor input, Tensor output, int kh, int kw, int dh,
int dw, int ph, int pw, int sh, int sw);
Operator maxpool(Tensor input, int kh, int kw, int dh, int dw, int ph,
int pw, int sh, int sw);
// average pool op
Operator avgpool(Tensor input, Tensor output, int kh, int kw, int dh,
int dw, int ph, int pw, int sh, int sw);
Operator avgpool(Tensor input, int kh, int kw, int dh, int dw, int ph,
int pw, int sh, int sw);
// element wise op
Operator add(Tensor input0, Tensor input1, Tensor output);
Operator add(Tensor input0, Tensor input1);
Operator sub(Tensor input0, Tensor input1, Tensor output);
Operator sub(Tensor input0, Tensor input1);
Operator mul(Tensor input0, Tensor input1, Tensor output);
Operator mul(Tensor input0, Tensor input1);
Operator div(Tensor input0, Tensor input1, Tensor output);
Operator div(Tensor input0, Tensor input1);
Operator pow(Tensor input0, Tensor input1, Tensor output);
Operator pow(Tensor input0, Tensor input1);
// gather op
Operator gather(Tensor input, Tensor index, Tensor output, int axis);
Operator gather(Tensor input, Tensor index, int axis);
// reduce mean op
// TODO
// reshape op
Operator reshape(Tensor input, Tensor output, const Shape &dims);
Operator reshape(Tensor input, const Shape &dims);
Operator flatten(Tensor input, Tensor output);
Operator flatten(Tensor input);
Operator identity(Tensor input, Tensor output);
Operator identity(Tensor input);
// unary op
// TODO: batch norm
Operator softmax(Tensor input, Tensor output);
Operator softmax(Tensor input);
// TODO: activation
Operator relu(Tensor input, Tensor output);
Operator relu(Tensor input);
Operator sigmoid(Tensor input, Tensor output);
Operator sigmoid(Tensor input);
Operator tanh(Tensor input, Tensor output);
Operator tanh(Tensor input);
Operator abs(Tensor input, Tensor output);
Operator abs(Tensor input);
// resize op
// TODO
// membound op
Operator memBound(const TensorVec &inputs, const TensorVec &outputs,
const std::vector<nnet::Tensor> &nnetInputs,
nnet::Expr expr, double exec_time, std::string hint = {});
};
} // namespace infini

View File

@ -10,6 +10,7 @@ class TensorBaseObj;
class TensorObj;
class OperatorObj;
class GraphObj;
class GraphBuilderObj;
class RuntimeObj;
class BlobObj;
@ -17,6 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
using Tensor = Ref<TensorObj>;
using Operator = Ref<OperatorObj>;
using Graph = Ref<GraphObj>;
using GraphBuilder = Ref<GraphBuilderObj>;
using Runtime = Ref<RuntimeObj>;
using Blob = Ref<BlobObj>;
enum class OpType;

View File

@ -0,0 +1 @@
from .import_onnx import import_onnx

View File

@ -0,0 +1,515 @@
from pyinfinitensor import *
import functools
import numpy as np
import onnx
import onnx.checker
import onnx.numpy_helper
import onnx.shape_inference
def _add_value_info_for_constants(model : onnx.ModelProto):
"""
Currently onnx.shape_inference doesn't use the shape of initializers, so add
that info explicitly as ValueInfoProtos.
Mutates the model.
Args:
model: The ModelProto to update.
"""
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
if model.ir_version < 4:
return
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
inputs = {i.name for i in graph.input}
existing_info = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
# Check it really is a constant, not an input
if init.name in inputs:
continue
# The details we want to add
elem_type = init.data_type
shape = init.dims
# Get existing or create new value info for this constant
vi = existing_info.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
# Even though it would be weird, we will not overwrite info even if it doesn't match
tt = vi.type.tensor_type
if tt.elem_type == onnx.TensorProto.UNDEFINED:
tt.elem_type = elem_type
if not tt.HasField("shape"):
# Ensure we set an empty list if the const is scalar (zero dims)
tt.shape.dim.extend([])
for dim in shape:
tt.shape.dim.add().dim_value = dim
# Handle subgraphs
for node in graph.node:
for attr in node.attribute:
# Ref attrs refer to other attrs, so we don't need to do anything
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_const_value_infos_to_graph(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_const_value_infos_to_graph(g)
return add_const_value_infos_to_graph(model.graph)
def _parse_attribute(attributes, defaults=dict()):
atts = defaults
for att in attributes:
if att.type == onnx.AttributeProto.INT:
atts[att.name] = att.i
elif att.type == onnx.AttributeProto.INTS:
atts[att.name] = att.ints
elif att.type == onnx.AttributeProto.FLOAT:
atts[att.name] = att.f
elif att.type == onnx.AttributeProto.STRING:
atts[att.name] = att.s
elif att.type == onnx.AttributeProto.TENSOR:
atts[att.name] = att.t
else:
assert False, "Unsupported Attribute Type: {}".format(att.type)
return atts
def _onnx_datatype_tostring(dtype):
if dtype == 0:
return 'UNDEFINED'
elif dtype == 1:
return 'FLOAT'
elif dtype == 2:
return 'UINT8'
elif dtype == 3:
return 'INT8'
elif dtype == 4:
return 'UINT16'
elif dtype == 5:
return 'INT16'
elif dtype == 6:
return 'INT32'
elif dtype == 7:
return 'INT64'
elif dtype == 8:
return 'STRING'
elif dtype == 9:
return 'BOOL'
elif dtype == 10:
return 'FLOAT16'
elif dtype == 11:
return 'DOUBLE'
elif dtype == 12:
return 'UINT32'
elif dtype == 13:
return 'UINT64'
elif dtype == 14:
return 'COMPLEX64'
elif dtype == 15:
return 'COMPLEX128'
elif dtype == 16:
return 'BFLOAT16'
else:
assert False, 'Unknown onnx datatype'
def import_onnx(gf: GraphBuilder, net: str):
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
model = onnx.load(net)
# Tensor_input
for input in model.graph.input:
if input.name not in ts:
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
ts[input.name] = gf.tensor(dims, _onnx_datatype_tostring(input.type.tensor_type.elem_type))
ds[input.name] = dims
# Tensor_weight
for weight in model.graph.initializer:
if weight.name not in ts:
ts[weight.name] = gf.tensor(weight.dims, _onnx_datatype_tostring(weight.data_type))
ds[weight.name] = weight.dims
# Tensor_inference
_add_value_info_for_constants(model)
infered_model = onnx.shape_inference.infer_shapes(model)
for v in infered_model.graph.value_info:
if v.name not in ts:
dims = [d.dim_value for d in v.type.tensor_type.shape.dim]
ts[v.name] = gf.tensor(dims, _onnx_datatype_tostring(v.type.tensor_type.elem_type))
ds[v.name] = dims
# Tensor_output
for output in model.graph.output:
if output.name not in ts:
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
ts[output.name] = gf.tensor(dims, _onnx_datatype_tostring(output.type.tensor_type.elem_type))
ds[output.name] = dims
# Op
for node in model.graph.node:
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
if node.op_type == 'Conv':
attrs = _parse_attribute(node.attribute, {
"auto_pad": "NOTSET",
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1]})
assert len(node.input) == 2 # bias is not implemented yet
assert len(node.output) == 1
assert attrs["auto_pad"] == "NOTSET"
assert len(attrs["pads"]) == 4
assert len(attrs["strides"]) == 2
assert len(attrs["dilations"]) == 2
assert attrs["pads"][0] == attrs["pads"][2]
assert attrs["pads"][1] == attrs["pads"][3]
gf.conv(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]],
attrs["pads"][0], attrs["pads"][1],
attrs["strides"][0], attrs["strides"][1],
attrs["dilations"][0], attrs["dilations"][1],
None if len(node.input) == 2 else ts[node.input[2]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul
# elif node.op_type == 'MatMul':
# assert len(node.input) == 2
# assert len(node.output) == 1
# dimA = list(ds[node.input[0]])
# dimB = list(ds[node.input[1]])
# dimO = list(ds[node.output[0]])
# if len(dimA) == 2 and len(dimB) == 2:
# tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
# tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
# tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
# g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
# g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)
# else:
# assert len(dimO) >= 3
# batch = functools.reduce(lambda x, y: x * y, dimO[:-2])
# if len(dimA) == 3:
# tmpI0 = ts[node.input[0]]
# else:
# tmpI0 = g.tensor([batch, dimA[-2], dimA[-1]], "FLOAT")
# g.reshape(ts[node.input[0]], tmpI0)
# if len(dimB) == 3:
# tmpI1 = ts[node.input[1]]
# else:
# tmpI1 = g.tensor([batch, dimB[-2], dimB[-1]], "FLOAT")
# g.reshape(ts[node.input[1]], tmpI1)
# if len(dimO) == 3:
# tmpO = ts[node.output[0]]
# g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
# else:
# tmpO = g.tensor([batch, dimO[-2], dimO[-1]], "FLOAT")
# g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
# g.reshape(tmpO, ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConvTranspose
elif node.op_type == 'ConvTranspose':
attrs = _parse_attribute(node.attribute, {
"auto_pad": "NOTSET",
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1],
"group": 1})
assert len(node.input) == 2 or len(node.input) == 3
assert len(node.output) == 1
assert attrs["auto_pad"] == "NOTSET"
assert len(attrs["pads"]) == 4
assert len(attrs["strides"]) == 2
assert len(attrs["dilations"]) == 2
assert attrs["pads"][0] == attrs["pads"][2]
assert attrs["pads"][1] == attrs["pads"][3]
oph, opw = 0, 0
if "output_padding" in attrs:
oph, opw = attrs["output_padding"][0], attrs["output_padding"][1]
assert attrs["output_padding"][0] == attrs["output_padding"][1]
gf.convTrans(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]],
attrs["pads"][0], attrs["pads"][1],
attrs["strides"][0], attrs["strides"][1],
attrs["dilations"][0], attrs["dilations"][1],
oph, opw, group,
None if len(node.input) == 2 else ts[node.input[2]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad
elif node.op_type == 'Pad':
attrs = _parse_attribute(node.attribute, {'mode': b'constant'})
assert attrs["mode"].decode("ascii") == "constant"
assert len(attrs["pads"]) % 2 == 0
if "constant_value" in attrs:
assert attrs["constant_value"] == 0
gf.pad(ts[node.input[0]], ts[node.output[0]], attrs["pads"],
attrs["axes"] if axes in attrs else None)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Slice
elif node.op_type == "Slice":
assert 3 <= len(node.input) <= 5
assert len(node.output) == 1
gf.slice(ts[node.input[0]], ts[node.output[0]],
ts[node.input[1]], ts[node.input[2]],
ts[node.input[3]] if len(node.input) == 4 else None,
ts[node.input[4]] if len(node.input) == 5 else None)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
elif node.op_type == 'Concat':
attrs = _parse_attribute(node.attribute, {})
assert len(node.output) == 1
gf.concat([ts[item] for item in node.input], ts[node.output[0]], attrs["axis"])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
elif node.op_type == "Split":
attrs = _parse_attribute(node.attribute, {'axis': 0})
assert len(node.input) == 1
assert len(node.output) > 1
dim = attrs['axis']
num = attrs['num_outputs']
gf.split(ts[node.input[0]], [ts[t] for t in node.output], dim, num)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
elif node.op_type == 'MaxPool':
attrs = _parse_attribute(node.attribute, {
"auto_pad": "NOTSET",
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1]})
assert len(node.input) == 1
assert len(node.output) == 1
assert len(attrs["kernel_shape"]) == 2
assert len(attrs["pads"]) == 4
assert len(attrs["strides"]) == 2
assert len(attrs["dilations"]) == 2
assert attrs["pads"][0] == attrs["pads"][2]
assert attrs["pads"][1] == attrs["pads"][3]
gf.maxpool(ts[node.input[0]], ts[node.output[0]],
attrs["kernel_shape"][0], attrs["kernel_shape"][1],
attrs["dilations"][0], attrs["dilations"][1],
attrs["pads"][0], attrs["pads"][1],
attrs["strides"][0], attrs["strides"][1])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
# No dilation in ONNX
elif node.op_type == 'AveragePool':
attrs = _parse_attribute(node.attribute, {
"auto_pad": "NOTSET",
"count_include_pad": 0,
"pads": [0, 0, 0, 0],
"strides": [1, 1]})
assert len(node.input) == 1
assert len(node.output) == 1
assert attrs["count_include_pad"] == 0 # To be consistent with operator.cc
assert len(attrs["kernel_shape"]) == 2
assert len(attrs["pads"]) == 4
assert len(attrs["strides"]) == 2
assert attrs["pads"][0] == attrs["pads"][2]
assert attrs["pads"][1] == attrs["pads"][3]
dh, dw = 1, 1
gf.avgpool(ts[node.input[0]], ts[node.output[0]],
attrs["kernel_shape"][0], attrs["kernel_shape"][1],
dw, dh,
attrs["pads"][0], attrs["pads"][1],
attrs["strides"][0], attrs["strides"][1])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Add
elif node.op_type == 'Add':
assert len(node.input) == 2
assert len(node.output) == 1
gf.add(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sub
elif node.op_type == 'Sub':
assert len(node.input) == 2
assert len(node.output) == 1
gf.sub(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Mul
elif node.op_type == 'Mul':
assert len(node.input) == 2
assert len(node.output) == 1
gf.mul(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Div
elif node.op_type == 'Div':
assert len(node.input) == 2
assert len(node.output) == 1
gf.div(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pow
elif node.op_type == 'Pow':
assert len(node.input) == 2
assert len(node.output) == 1
gf.pow(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
elif node.op_type == 'Gather':
attrs = _parse_attribute(node.attribute, {"axis": 0})
assert len(node.input) == 2
assert len(node.output) == 1
gf.gather(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]], attrs["axis"])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape
elif node.op_type == 'Reshape':
attrs = _parse_attribute(node.attribute, {"allowzero": 0})
assert len(node.input) == 2
assert len(node.output) == 1
gf.reshape(ts[node.input[0]], ts[node.output[0]], ts[node.input[1]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten
# Output is 2D in ONNX
elif node.op_type == 'Flatten':
assert len(node.input) == 1
assert len(node.output) == 1
gf.flatten(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity
elif node.op_type == 'Identity':
assert len(node.input) == 1
assert len(node.output) == 1
gf.identity(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Softmax
elif node.op_type == 'Softmax':
attrs = _parse_attribute(node.attribute, {"axis": -1})
assert len(node.input) == 1
assert len(node.output) == 1
gf.softmax(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu
elif node.op_type == 'Relu':
assert len(node.input) == 1
assert len(node.output) == 1
gf.relu(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sigmoid
elif node.op_type == 'Sigmoid':
assert len(node.input) == 1
assert len(node.output) == 1
gf.sigmoid(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tanh
elif node.op_type == 'Tanh':
assert len(node.input) == 1
assert len(node.output) == 1
gf.tanh(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Abs
elif node.op_type == 'Abs':
assert len(node.input) == 1
assert len(node.output) == 1
gf.abs(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
# Ignore for now (TODO)
elif node.op_type == 'Cast':
assert len(node.input) == 1
assert len(node.output) == 1
gf.identity(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
elif node.op_type == 'Constant':
attrs = _parse_attribute(node.attribute, {})
assert len(node.output) == 1
c = onnx.numpy_helper.to_array(attrs["value"])
if c.ndim == 0:
c = c[()]
consts[node.output[0]] = c
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
# elif node.op_type == 'Gemm':
# attrs = _parse_attribute(node.attribute, {
# "alpha": 1.0,
# "beta": 1.0,
# "transA": 0,
# "transB": 0})
# assert len(node.input) == 2 or len(node.input) == 3
# assert len(node.output) == 1
# assert attrs["alpha"] == 1.0
# assert attrs["beta"] == 1.0 or len(node.input) == 2
# tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
# tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
# tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
# g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.matmul(tmpI0, tmpI1, tmpO,
# attrs["transA"], attrs["transB"],
# None if len(node.input) == 2 else ts[node.input[2]])
# g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool
# elif node.op_type == 'GlobalAveragePool':
# assert len(node.input) == 1
# assert len(node.output) == 1
# dims = ds[node.input[0]]
# if len(dims) > 0:
# g.avgpool(ts[node.input[0]], ts[node.output[0]], dims[2], dims[3], 0, 0, 1, 1)
# else:
# g.avgpool(ts[node.input[0]], ts[node.output[0]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMean
# elif node.op_type == 'ReduceMean':
# attrs = _parse_attribute(node.attribute, {'keepdims': 1})
# assert len(node.input) == 1
# assert len(node.output) == 1
# assert len(attrs["axes"]) == 1
# axis = attrs["axes"][0]
# if axis < 0:
# axis = len(ds[node.input[0]]) - axis
# gf.reduceMean(ts[node.input[0]], ts[node.output[0]], axis)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
# Ignore for now, and no need to output anything (TODO)
# elif node.op_type == 'Shape':
# pass
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose
# elif node.op_type == 'Transpose':
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 1
# assert len(node.output) == 1
# assert "perm" in attrs
# gf.transpose(ts[node.input[0]], ts[node.output[0]], -1,
# Perm([PermItem(x) for x in attrs["perm"]]), 0)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze
elif node.op_type == 'Unsqueeze':
assert len(node.input) == 2
assert len(node.output) == 1
gf.reshape(ts[node.input[0]], ts[node.output[0]], ts[node.input[1]])
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization
# elif node.op_type == "BatchNormalization":
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 5
# assert len(node.output) == 1
# epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-5
# momentum = attrs['momentum'] if 'momentum' in attrs else 0.9
# g.batchnorm(ts[node.input[0]], ts[node.input[1]],
# ts[node.input[2]], ts[node.input[3]],
# ts[node.input[4]], ts[node.output[0]],
# epsilon, momentum)
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize
# elif node.op_type == "Resize":
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 4
# assert len(node.output) == 1
# roi = ts[node.input[1]] if node.input[1] != '' else g.tensor(
# [], 'FLOAT')
# g.resize(ts[node.input[0]], roi, ts[node.output[0]])
else:
assert False, "Unsupported op: " + node.op_type

5
python/test/run_test.py Normal file
View File

@ -0,0 +1,5 @@
import pytest
if __name__ == "__main__":
retcode = pytest.main()

View File

@ -0,0 +1,9 @@
from pyinfinitensor import *
from infinitensor import import_onnx
class Test_ImportOnnx:
def test_Netname(self):
runtime = CpuRuntimeObj.getInstance()
graphBuilder = GraphBuilderObj(runtime)
import_onnx(graphBuilder, '/path/to/net')

519
src/core/graph_builder.cc Normal file
View File

@ -0,0 +1,519 @@
#include "core/graph_builder.h"
namespace infini {
Tensor GraphBuilderObj::tensor(Shape dim, const std::string &dtype) {
if (dtype == "FLOAT") {
return g->addTensor(dim, DataType::Float32);
}
if (dtype == "INT32") {
return g->addTensor(dim, DataType::UInt32);
}
IT_TODO_HALT_MSG("Unsupported data type");
}
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int dw,
Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op =
g->addOpWithOutputs<ConvObj>(i0, w0, o0, ph, ph, sh, sw, dh, dw, bias);
return op;
}
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, int ph, int pw,
int sh, int sw, int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
auto op = g->addOp<ConvObj>(i0, w0, nullptr, ph, ph, sh, sw, dh, dw, bias);
return op;
}
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op =
g->addOpWithOutputs<ConvObj>(i0, w0, o0, pm, sh, sw, dh, dw, bias);
return op;
}
Operator GraphBuilderObj::conv(Tensor input, Tensor weight,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
auto op = g->addOp<ConvObj>(i0, w0, nullptr, pm, sh, sw, dh, dw, bias);
return op;
}
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
bool transB, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
Tensor o0 = g->addTensor(C->getDims(), C->getDType());
auto op =
g->addOpWithOutputs<MatmulObj>(i0, i1, o0, transA, transB, bias, act);
return op;
}
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
auto op = g->addOp<MatmulObj>(i0, i1, nullptr, transA, transB, bias, act);
return op;
}
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh,
int dw, int oph, int opw, int group,
Tensor bias, ActType act) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ConvTransposed2dObj>(
i0, w0, o0, ph, pw, sh, sw, dh, dw, oph, opw, group, bias, act);
return op;
}
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
int sh, int sw, int dh, int dw, int oph,
int opw, int group, Tensor bias,
ActType act) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
auto op = g->addOp<ConvTransposed2dObj>(i0, w0, nullptr, ph, pw, sh, sw, dh,
dw, oph, opw, group, bias, act);
return op;
}
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, int oph, int opw, int group,
Tensor bias, ActType act) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ConvTransposed2dObj>(
i0, w0, o0, pm, sh, sw, dh, dw, oph, opw, group, bias, act);
return op;
}
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, int oph, int opw, int group,
Tensor bias, ActType act) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
auto op = g->addOp<ConvTransposed2dObj>(i0, w0, nullptr, pm, sh, sw, dh, dw,
oph, opw, group, bias, act);
return op;
}
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
Tensor o0 = g->addTensor(C->getDims(), C->getDType());
auto op =
g->addOpWithOutputs<G2BMMObj>(i0, i1, o0, width, dilation, bias, act);
return op;
}
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, const int width,
const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
auto op = g->addOp<G2BMMObj>(i0, i1, nullptr, width, dilation, bias, act);
return op;
}
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, Tensor C,
const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
Tensor o0 = g->addTensor(C->getDims(), C->getDType());
auto op = g->addOpWithOutputs<GBMMObj>(i0, i1, o0, dilation, bias, act);
return op;
}
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, const int dilation,
Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
auto op = g->addOp<GBMMObj>(i0, i1, nullptr, dilation, bias, act);
return op;
}
Operator GraphBuilderObj::pad(Tensor input, Tensor output,
const vector<int> &pads,
const optional<const vector<int>> &axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<PadObj>(i0, o0, pads, axis);
return op;
}
Operator GraphBuilderObj::pad(Tensor input, const vector<int> &pads,
const optional<const vector<int>> &axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<PadObj>(i0, nullptr, pads, axis);
return op;
}
Operator GraphBuilderObj::slice(Tensor input, Tensor output,
const vector<int> &starts,
const vector<int> &ends,
const optional<const vector<int>> &axis,
const optional<const vector<int>> &steps) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SliceObj>(i0, o0, starts, ends, axis, steps);
return op;
}
Operator GraphBuilderObj::slice(Tensor input, const vector<int> &starts,
const vector<int> &ends,
const optional<const vector<int>> &axis,
const optional<const vector<int>> &steps) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SliceObj>(i0, nullptr, starts, ends, axis, steps);
return op;
}
Operator GraphBuilderObj::concat(TensorVec inputs, Tensor output, int dim) {
TensorVec is;
for (auto input : inputs) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
is.push_back(i);
}
Tensor o = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ConcatObj>(is, o, dim);
return op;
}
Operator GraphBuilderObj::concat(TensorVec inputs, int dim) {
TensorVec is;
for (auto input : inputs) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
is.push_back(i);
}
auto op = g->addOp<ConcatObj>(is, nullptr, dim);
return op;
}
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
int dim, int num) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
if (outputs.has_value()) {
TensorVec os;
for (auto output : outputs.value()) {
Tensor o = g->addTensor(output->getDims(), output->getDType());
os.push_back(o);
}
auto op = g->addOpWithOutputs<SplitObj>(i, os, dim, num);
return op;
} else {
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num);
return op;
}
}
Operator GraphBuilderObj::split(Tensor input, int dim, int num) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num);
return op;
}
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
int dim, const vector<int> &ratio) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
if (outputs.has_value()) {
TensorVec os;
for (auto output : outputs.value()) {
Tensor o = g->addTensor(output->getDims(), output->getDType());
os.push_back(o);
}
auto op = g->addOpWithOutputs<SplitObj>(i, os, dim, ratio);
return op;
} else {
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio);
return op;
}
}
Operator GraphBuilderObj::split(Tensor input, int dim,
const vector<int> &ratio) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio);
return op;
}
Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim,
int num) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ExtendObj>(i0, o0, dim, num);
return op;
}
Operator GraphBuilderObj::extend(Tensor input, int dim, int num) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ExtendObj>(i0, nullptr, dim, num);
return op;
}
Operator GraphBuilderObj::maxpool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op =
g->addOpWithOutputs<MaxPoolObj>(i0, o0, kh, kw, dh, dw, ph, pw, sh, sw);
return op;
}
Operator GraphBuilderObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<MaxPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
return op;
}
Operator GraphBuilderObj::avgpool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(input->getDims(), input->getDType());
auto op =
g->addOpWithOutputs<AvgPoolObj>(i0, o0, kh, kw, dh, dw, ph, pw, sh, sw);
return op;
}
Operator GraphBuilderObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<AvgPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
return op;
}
Operator GraphBuilderObj::add(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<AddObj>(i0, i1, o0);
return op;
}
Operator GraphBuilderObj::add(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<AddObj>(i0, i1, nullptr);
return op;
}
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SubObj>(i0, i1, o0);
return op;
}
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<SubObj>(i0, i1, nullptr);
return op;
}
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<MulObj>(i0, i1, o0);
return op;
}
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<SubObj>(i0, i1, nullptr);
return op;
}
Operator GraphBuilderObj::div(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<DivObj>(i0, i1, o0);
return op;
}
Operator GraphBuilderObj::div(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<DivObj>(i0, i1, nullptr);
return op;
}
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<PowObj>(i0, i1, o0);
return op;
}
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<PowObj>(i0, i1, nullptr);
return op;
}
Operator GraphBuilderObj::gather(Tensor input, Tensor index, Tensor output,
int axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<GatherObj>(i0, index, o0, axis);
return op;
}
Operator GraphBuilderObj::gather(Tensor input, Tensor index, int axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<GatherObj>(i0, index, nullptr, axis);
return op;
}
Operator GraphBuilderObj::reshape(Tensor input, Tensor output,
const Shape &dims) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ReshapeObj>(i0, o0, dims);
return op;
}
Operator GraphBuilderObj::reshape(Tensor input, const Shape &dims) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ReshapeObj>(i0, nullptr, dims);
return op;
}
Operator GraphBuilderObj::flatten(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<FlattenObj>(i0, o0);
return op;
}
Operator GraphBuilderObj::flatten(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<FlattenObj>(i0, nullptr);
return op;
}
Operator GraphBuilderObj::identity(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<IdentityObj>(i0, o0);
return op;
}
Operator GraphBuilderObj::identity(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<IdentityObj>(i0, nullptr);
return op;
}
Operator GraphBuilderObj::softmax(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SoftmaxObj>(i0, o0);
return op;
}
Operator GraphBuilderObj::softmax(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SoftmaxObj>(i0, nullptr);
return op;
}
Operator GraphBuilderObj::relu(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ReluObj>(i0, o0);
return op;
}
Operator GraphBuilderObj::relu(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ReluObj>(i0, nullptr);
return op;
}
Operator GraphBuilderObj::sigmoid(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SigmoidObj>(i0, o0);
return op;
}
Operator GraphBuilderObj::sigmoid(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SigmoidObj>(i0, nullptr);
return op;
}
Operator GraphBuilderObj::tanh(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<TanhObj>(i0, o0);
return op;
}
Operator GraphBuilderObj::tanh(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<TanhObj>(i0, nullptr);
return op;
}
Operator GraphBuilderObj::abs(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<AbsObj>(i0, o0);
return op;
}
Operator GraphBuilderObj::abs(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<AbsObj>(i0, nullptr);
return op;
}
Operator GraphBuilderObj::memBound(const TensorVec &inputs,
const TensorVec &outputs,
const std::vector<nnet::Tensor> &nnetInputs,
nnet::Expr expr, double exec_time,
std::string hint) {
TensorVec is;
for (auto input : inputs) {
auto i = g->addTensor(input->getDims(), input->getDType());
is.push_back(i);
}
TensorVec os;
for (auto output : outputs) {
auto o = g->addTensor(output->getDims(), output->getDType());
os.push_back(o);
}
auto op = g->addOpWithOutputs<MemBoundObj>(is, os, nnetInputs, expr,
exec_time, hint);
return op;
}
} // namespace infini

View File

@ -2,6 +2,7 @@
#ifdef USE_CUDA
#include "cuda/operator_timer.h"
#endif
#include "core/graph_builder.h"
namespace py = pybind11;
namespace infini {
@ -18,6 +19,172 @@ void register_operator_timer(py::module &m) {
#endif
}
void init_graph_builder(py::module &m) {
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntimeObj")
.def(py::init<>())
.def("getInstance", py::overload_cast<>(&CpuRuntimeObj::getInstance),
policy::reference_internal);
py::class_<Shape>(m, "Shape");
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "TensorObj");
py::class_<Tensor>(m, "Tensor");
py::class_<TensorVec>(m, "TensorVec");
py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "OperatorObj");
py::class_<Operator>(m, "Operator");
py::class_<ActType>(m, "ActType");
py::class_<ConvObj, std::shared_ptr<ConvObj>, OperatorObj>(m, "ConvObj");
py::class_<MatmulObj, std::shared_ptr<MatmulObj>, OperatorObj>(m,
"MatmulObj");
py::class_<ConvTransposed2dObj, std::shared_ptr<ConvTransposed2dObj>,
OperatorObj>(m, "ConvTransposed2dObj");
py::class_<G2BMMObj, std::shared_ptr<G2BMMObj>, OperatorObj>(m, "G2BMMObj");
py::class_<GBMMObj, std::shared_ptr<GBMMObj>, OperatorObj>(m, "GBMMObj");
py::class_<PadObj, std::shared_ptr<PadObj>, OperatorObj>(m, "PadObj");
py::class_<SliceObj, std::shared_ptr<SliceObj>, OperatorObj>(m, "SliceObj");
py::class_<ConcatObj, std::shared_ptr<ConcatObj>, OperatorObj>(m,
"ConcatObj");
py::class_<SplitObj, std::shared_ptr<SplitObj>, OperatorObj>(m, "SplitObj");
py::class_<ExtendObj, std::shared_ptr<ExtendObj>, OperatorObj>(m,
"ExtendObj");
py::class_<MaxPoolObj, std::shared_ptr<MaxPoolObj>, OperatorObj>(
m, "MaxPoolObj");
py::class_<AvgPoolObj, std::shared_ptr<AvgPoolObj>, OperatorObj>(
m, "AvgPoolObj");
py::class_<AddObj, std::shared_ptr<AddObj>, OperatorObj>(m, "AddObj");
py::class_<SubObj, std::shared_ptr<SubObj>, OperatorObj>(m, "SubObj");
py::class_<MulObj, std::shared_ptr<MulObj>, OperatorObj>(m, "MulObj");
py::class_<DivObj, std::shared_ptr<DivObj>, OperatorObj>(m, "DivObj");
py::class_<PowObj, std::shared_ptr<PowObj>, OperatorObj>(m, "PowObj");
py::class_<GatherObj, std::shared_ptr<GatherObj>, OperatorObj>(m,
"GatherObj");
py::class_<ReshapeObj, std::shared_ptr<ReshapeObj>, OperatorObj>(
m, "ReshapeObj");
py::class_<FlattenObj, std::shared_ptr<FlattenObj>, OperatorObj>(
m, "FlattenObj");
py::class_<IdentityObj, std::shared_ptr<IdentityObj>, OperatorObj>(
m, "IdentityObj");
py::class_<SoftmaxObj, std::shared_ptr<SoftmaxObj>, OperatorObj>(
m, "SoftmaxObj");
py::class_<ReluObj, std::shared_ptr<ReluObj>, OperatorObj>(m, "ReluObj");
py::class_<SigmoidObj, std::shared_ptr<SigmoidObj>, OperatorObj>(
m, "SigmoidObj");
py::class_<TanhObj, std::shared_ptr<TanhObj>, OperatorObj>(m, "TanhObj");
py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj");
py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>(
m, "MemBoundObj");
py::class_<GraphBuilder>(m, "GraphBuilder");
py::class_<GraphBuilderObj>(m, "GraphBuilderObj")
.def(py::init<Runtime>())
.def("tensor",
py::overload_cast<Shape, const std::string &>(
&GraphBuilderObj::tensor),
policy::reference_internal)
.def("conv",
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
int, Tensor>(&GraphBuilderObj::conv),
policy::reference_internal)
.def("matmul",
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
ActType>(&GraphBuilderObj::matmul),
policy::reference_internal)
.def("convTrans",
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
int, int, int, int, Tensor, ActType>(
&GraphBuilderObj::convTrans),
policy::reference_internal)
.def("g2bmm",
py::overload_cast<Tensor, Tensor, Tensor, const int, const int,
Tensor, ActType>(&GraphBuilderObj::g2bmm),
policy::reference_internal)
.def("gbmml",
py::overload_cast<Tensor, Tensor, Tensor, const int, Tensor,
ActType>(&GraphBuilderObj::gbmml),
policy::reference_internal)
.def("pad",
py::overload_cast<Tensor, Tensor, const vector<int> &,
const optional<const vector<int>> &>(
&GraphBuilderObj::pad),
policy::reference_internal)
.def("slice",
py::overload_cast<Tensor, Tensor, const vector<int> &,
const vector<int> &,
const optional<const vector<int>> &,
const optional<const vector<int>> &>(
&GraphBuilderObj::slice),
policy::reference_internal)
.def(
"concat",
py::overload_cast<TensorVec, Tensor, int>(&GraphBuilderObj::concat),
policy::reference_internal)
.def("split",
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
&GraphBuilderObj::split),
policy::reference_internal)
.def("extend",
py::overload_cast<Tensor, Tensor, int, int>(
&GraphBuilderObj::extend),
policy::reference_internal)
.def("maxpool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&GraphBuilderObj::maxpool),
policy::reference_internal)
.def("avgpool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&GraphBuilderObj::avgpool),
policy::reference_internal)
.def("add",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::add),
policy::reference_internal)
.def("sub",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::sub),
policy::reference_internal)
.def("mul",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::mul),
policy::reference_internal)
.def("div",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::div),
policy::reference_internal)
.def("pow",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::pow),
policy::reference_internal)
.def("gather",
py::overload_cast<Tensor, Tensor, Tensor, int>(
&GraphBuilderObj::gather),
policy::reference_internal)
.def("reshape",
py::overload_cast<Tensor, Tensor, const Shape &>(
&GraphBuilderObj::reshape),
policy::reference_internal)
.def("flatten",
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::flatten),
policy::reference_internal)
.def("identity",
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::identity),
policy::reference_internal)
.def("softmax",
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::softmax),
policy::reference_internal)
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::relu),
policy::reference_internal)
.def("sigmoid",
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::sigmoid),
policy::reference_internal)
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::tanh),
policy::reference_internal)
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::abs),
policy::reference_internal)
.def("memBound",
py::overload_cast<const TensorVec &, const TensorVec &,
const std::vector<nnet::Tensor> &, nnet::Expr,
double, std::string>(&GraphBuilderObj::memBound),
policy::reference_internal);
}
} // namespace infini
PYBIND11_MODULE(pyinfinitensor, m) { infini::register_operator_timer(m); }
PYBIND11_MODULE(pyinfinitensor, m) {
infini::register_operator_timer(m);
infini::init_graph_builder(m);
}

View File

@ -0,0 +1,265 @@
#include "core/graph_builder.h"
#include "test.h"
namespace infini {
TEST(GraphBuilder, ops) {
Runtime runtime = CpuRuntimeObj::getInstance();
{ // conv without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto weight =
make_ref<TensorObj>(Shape{2, 3, 3, 3}, DataType::UInt32, runtime);
auto conv = gf->conv(input, weight, 1, 1);
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
}
{ // conv with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto weight =
make_ref<TensorObj>(Shape{2, 3, 3, 3}, DataType::UInt32, runtime);
auto output =
make_ref<TensorObj>(Shape{1, 2, 4, 4}, DataType::UInt32, runtime);
auto conv = gf->conv(input, weight, output, 1, 1);
}
{ // matmul without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
auto matmul = gf->matmul(A, B);
EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2}));
}
{ // matmul with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
auto C = make_ref<TensorObj>(Shape{1, 3, 2}, DataType::UInt32, runtime);
auto matmul = gf->matmul(A, B, C);
}
{ // convtrans without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
DataType::UInt32, runtime);
auto convtrans = gf->convTrans(input, weight, 0, 0);
EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2}));
}
{ // convtrans with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
DataType::UInt32, runtime);
auto output =
make_ref<TensorObj>(Shape{1, 448, 2, 2}, DataType::UInt32, runtime);
auto convtrans = gf->convTrans(input, weight, 0, 0);
}
{ // pad without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5};
auto pad = gf->pad(input, pads, std::nullopt);
EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172}));
}
{ // pad with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{3, 84, 164, 172},
DataType::UInt32, runtime);
vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5};
auto pad = gf->pad(input, output, pads, std::nullopt);
}
{ // slice without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
DataType::UInt32, runtime);
vector<int> starts = {2, 10, 1, 5};
vector<int> ends = {3, 10, 100, 100};
auto slice = gf->slice(input, starts, ends, std::nullopt, std::nullopt);
EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96}));
}
{ // slice with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{2, 1, 100, 96},
DataType::UInt32, runtime);
vector<int> starts = {2, 10, 1, 5};
vector<int> ends = {3, 10, 100, 100};
auto slice =
gf->slice(input, output, starts, ends, std::nullopt, std::nullopt);
}
{ // concat without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto t1 =
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
auto t2 =
make_ref<TensorObj>(Shape{1, 3, 2, 5}, DataType::Float32, runtime);
auto concat = gf->concat(TensorVec{t1, t2}, 3);
EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
}
{ // concat with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto t1 =
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
auto t2 =
make_ref<TensorObj>(Shape{1, 3, 2, 5}, DataType::Float32, runtime);
auto o0 =
make_ref<TensorObj>(Shape{1, 3, 2, 9}, DataType::Float32, runtime);
auto concat = gf->concat(TensorVec{t1, t2}, o0, 3);
}
{ // split without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
auto split = gf->split(input, 3, 4);
EXPECT_EQ(split->numOutputs(), 4);
EXPECT_EQ(split->getOutputs().size(), (size_t)4);
EXPECT_EQ(split->getOutput(0)->getDims(), (Shape{1, 3, 2, 3}));
EXPECT_EQ(split->getOutput(1)->getDims(), (Shape{1, 3, 2, 3}));
EXPECT_EQ(split->getOutput(2)->getDims(), (Shape{1, 3, 2, 3}));
EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
}
{ // split with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
auto output0 =
make_ref<TensorObj>(Shape{1, 3, 2, 3}, DataType::Float32, runtime);
auto output1 =
make_ref<TensorObj>(Shape{1, 3, 2, 3}, DataType::Float32, runtime);
auto output2 =
make_ref<TensorObj>(Shape{1, 3, 2, 3}, DataType::Float32, runtime);
auto output3 =
make_ref<TensorObj>(Shape{1, 3, 2, 6}, DataType::Float32, runtime);
auto split = gf->split(
input, TensorVec{output0, output1, output2, output3}, 3, 4);
}
{ // extend without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto extend = gf->extend(input, 2, 1);
EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
}
{ // extend with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto output =
make_ref<TensorObj>(Shape{2, 3, 6, 4}, DataType::UInt32, runtime);
auto extend = gf->extend(input, output, 2, 1);
}
{ // maxpool without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
sw = 2;
auto maxpool = gf->maxpool(input, kh, kw, dh, dw, ph, pw, sh, sw);
EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
}
{ // maxpool with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{1, 64, 80, 80},
DataType::UInt32, runtime);
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
sw = 2;
auto maxpool =
gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw);
}
{ // add without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input0 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto input1 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto add = gf->add(input0, input1);
EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
}
{ // add with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input0 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto input1 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto output =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto add = gf->add(input0, input1, output);
}
{ // gather without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto index =
make_ref<TensorObj>(Shape{2, 1, 2}, DataType::UInt32, runtime);
auto gather = gf->gather(input, index, 1);
EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
}
{ // gather with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto index =
make_ref<TensorObj>(Shape{2, 1, 2}, DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{1, 2, 1, 2, 4, 4},
DataType::UInt32, runtime);
auto gather = gf->gather(input, index, output, 1);
}
{ // reshape without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
vector<int> dims = {3, 2, 4, 3};
auto reshape = gf->reshape(input, dims);
EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3}));
}
{ // reshape with output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
vector<int> dims = {3, 2, 4, 3};
auto output =
make_ref<TensorObj>(Shape{3, 2, 4, 3}, DataType::Float32, runtime);
auto reshape = gf->reshape(input, output, dims);
}
{ // flatten without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto flatten = gf->flatten(input);
EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72}));
}
{ // flatten without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto output =
make_ref<TensorObj>(Shape{72}, DataType::Float32, runtime);
auto flatten = gf->flatten(input, output);
}
{ // identity without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto identity = gf->identity(input);
EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
}
{ // identity without output
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto output =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto identity = gf->identity(input, output);
}
}
} // namespace infini