forked from jiuyuan/InfiniTensor
add: graph build for pf.
This commit is contained in:
parent
f88aefb2ca
commit
a0e07199ff
|
@ -14,6 +14,7 @@
|
|||
#include "operators/membound.h"
|
||||
#include "operators/pad.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/split.h"
|
||||
|
@ -155,6 +156,7 @@ class GraphBuilderObj {
|
|||
Operator tanh(Tensor input);
|
||||
Operator abs(Tensor input, Tensor output);
|
||||
Operator abs(Tensor input);
|
||||
Operator reduceMean(Tensor input, Tensor Output, int axis);
|
||||
// resize op
|
||||
// TODO
|
||||
// membound op
|
||||
|
|
|
@ -9,7 +9,8 @@ import onnx.checker
|
|||
import onnx.numpy_helper
|
||||
import onnx.shape_inference
|
||||
|
||||
def _add_value_info_for_constants(model : onnx.ModelProto):
|
||||
|
||||
def _add_value_info_for_constants(model: onnx.ModelProto):
|
||||
"""
|
||||
Currently onnx.shape_inference doesn't use the shape of initializers, so add
|
||||
that info explicitly as ValueInfoProtos.
|
||||
|
@ -21,7 +22,7 @@ def _add_value_info_for_constants(model : onnx.ModelProto):
|
|||
if model.ir_version < 4:
|
||||
return
|
||||
|
||||
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
|
||||
def add_const_value_infos_to_graph(graph: onnx.GraphProto):
|
||||
inputs = {i.name for i in graph.input}
|
||||
existing_info = {vi.name: vi for vi in graph.value_info}
|
||||
for init in graph.initializer:
|
||||
|
@ -123,20 +124,22 @@ def _onnx_datatype_tostring(dtype):
|
|||
|
||||
|
||||
def import_onnx(gf: GraphBuilder, net: str):
|
||||
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
|
||||
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
|
||||
model = onnx.load(net)
|
||||
|
||||
# Tensor_input
|
||||
for input in model.graph.input:
|
||||
if input.name not in ts:
|
||||
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
|
||||
ts[input.name] = gf.tensor(dims, _onnx_datatype_tostring(input.type.tensor_type.elem_type))
|
||||
ts[input.name] = gf.tensor(dims, _onnx_datatype_tostring(
|
||||
input.type.tensor_type.elem_type))
|
||||
ds[input.name] = dims
|
||||
|
||||
# Tensor_weight
|
||||
for weight in model.graph.initializer:
|
||||
if weight.name not in ts:
|
||||
ts[weight.name] = gf.tensor(weight.dims, _onnx_datatype_tostring(weight.data_type))
|
||||
ts[weight.name] = gf.tensor(
|
||||
weight.dims, _onnx_datatype_tostring(weight.data_type))
|
||||
ds[weight.name] = weight.dims
|
||||
|
||||
# Tensor_inference
|
||||
|
@ -145,14 +148,16 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
for v in infered_model.graph.value_info:
|
||||
if v.name not in ts:
|
||||
dims = [d.dim_value for d in v.type.tensor_type.shape.dim]
|
||||
ts[v.name] = gf.tensor(dims, _onnx_datatype_tostring(v.type.tensor_type.elem_type))
|
||||
ts[v.name] = gf.tensor(dims, _onnx_datatype_tostring(
|
||||
v.type.tensor_type.elem_type))
|
||||
ds[v.name] = dims
|
||||
|
||||
# Tensor_output
|
||||
for output in model.graph.output:
|
||||
if output.name not in ts:
|
||||
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
|
||||
ts[output.name] = gf.tensor(dims, _onnx_datatype_tostring(output.type.tensor_type.elem_type))
|
||||
ts[output.name] = gf.tensor(dims, _onnx_datatype_tostring(
|
||||
output.type.tensor_type.elem_type))
|
||||
ds[output.name] = dims
|
||||
|
||||
# Op
|
||||
|
@ -164,7 +169,7 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
"dilations": [1, 1],
|
||||
"pads": [0, 0, 0, 0],
|
||||
"strides": [1, 1]})
|
||||
assert len(node.input) == 2 # bias is not implemented yet
|
||||
assert len(node.input) == 2 # bias is not implemented yet
|
||||
assert len(node.output) == 1
|
||||
assert attrs["auto_pad"] == "NOTSET"
|
||||
assert len(attrs["pads"]) == 4
|
||||
|
@ -178,7 +183,7 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
attrs["dilations"][0], attrs["dilations"][1],
|
||||
None if len(node.input) == 2 else ts[node.input[2]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul
|
||||
# elif node.op_type == 'MatMul':
|
||||
# assert len(node.input) == 2
|
||||
# assert len(node.output) == 1
|
||||
|
@ -240,12 +245,12 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
oph, opw = attrs["output_padding"][0], attrs["output_padding"][1]
|
||||
assert attrs["output_padding"][0] == attrs["output_padding"][1]
|
||||
gf.convTrans(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]],
|
||||
attrs["pads"][0], attrs["pads"][1],
|
||||
attrs["strides"][0], attrs["strides"][1],
|
||||
attrs["dilations"][0], attrs["dilations"][1],
|
||||
oph, opw, group,
|
||||
None if len(node.input) == 2 else ts[node.input[2]])
|
||||
|
||||
attrs["pads"][0], attrs["pads"][1],
|
||||
attrs["strides"][0], attrs["strides"][1],
|
||||
attrs["dilations"][0], attrs["dilations"][1],
|
||||
oph, opw, group,
|
||||
None if len(node.input) == 2 else ts[node.input[2]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad
|
||||
elif node.op_type == 'Pad':
|
||||
attrs = _parse_attribute(node.attribute, {'mode': b'constant'})
|
||||
|
@ -269,12 +274,13 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
elif node.op_type == 'Concat':
|
||||
attrs = _parse_attribute(node.attribute, {})
|
||||
assert len(node.output) == 1
|
||||
gf.concat([ts[item] for item in node.input], ts[node.output[0]], attrs["axis"])
|
||||
gf.concat([ts[item] for item in node.input],
|
||||
ts[node.output[0]], attrs["axis"])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
|
||||
elif node.op_type == "Split":
|
||||
attrs = _parse_attribute(node.attribute, {'axis': 0})
|
||||
assert len(node.input) == 1
|
||||
assert len(node.input) == 1
|
||||
assert len(node.output) > 1
|
||||
dim = attrs['axis']
|
||||
num = attrs['num_outputs']
|
||||
|
@ -296,10 +302,10 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
assert attrs["pads"][0] == attrs["pads"][2]
|
||||
assert attrs["pads"][1] == attrs["pads"][3]
|
||||
gf.maxpool(ts[node.input[0]], ts[node.output[0]],
|
||||
attrs["kernel_shape"][0], attrs["kernel_shape"][1],
|
||||
attrs["dilations"][0], attrs["dilations"][1],
|
||||
attrs["pads"][0], attrs["pads"][1],
|
||||
attrs["strides"][0], attrs["strides"][1])
|
||||
attrs["kernel_shape"][0], attrs["kernel_shape"][1],
|
||||
attrs["dilations"][0], attrs["dilations"][1],
|
||||
attrs["pads"][0], attrs["pads"][1],
|
||||
attrs["strides"][0], attrs["strides"][1])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
|
||||
# No dilation in ONNX
|
||||
|
@ -311,7 +317,8 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
"strides": [1, 1]})
|
||||
assert len(node.input) == 1
|
||||
assert len(node.output) == 1
|
||||
assert attrs["count_include_pad"] == 0 # To be consistent with operator.cc
|
||||
# To be consistent with operator.cc
|
||||
assert attrs["count_include_pad"] == 0
|
||||
assert len(attrs["kernel_shape"]) == 2
|
||||
assert len(attrs["pads"]) == 4
|
||||
assert len(attrs["strides"]) == 2
|
||||
|
@ -319,54 +326,132 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
assert attrs["pads"][1] == attrs["pads"][3]
|
||||
dh, dw = 1, 1
|
||||
gf.avgpool(ts[node.input[0]], ts[node.output[0]],
|
||||
attrs["kernel_shape"][0], attrs["kernel_shape"][1],
|
||||
dw, dh,
|
||||
attrs["pads"][0], attrs["pads"][1],
|
||||
attrs["strides"][0], attrs["strides"][1])
|
||||
attrs["kernel_shape"][0], attrs["kernel_shape"][1],
|
||||
dw, dh,
|
||||
attrs["pads"][0], attrs["pads"][1],
|
||||
attrs["strides"][0], attrs["strides"][1])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Add
|
||||
elif node.op_type == 'Add':
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.add(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
|
||||
assert ds[node.input[0]] == ds[node.output[0]]
|
||||
if ds[node.input[0]] == ds[node.input[1]]:
|
||||
gf.add(ts[node.input[0]], ts[node.input[1]],
|
||||
ts[node.output[0]])
|
||||
elif len(ds[node.input[1]]) == 0:
|
||||
tmp = gf.tensor(ds[node.input[0]], "FLOAT")
|
||||
gf.add(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
elif ds[node.input[1]][-1] == 1 and ds[node.input[0]][:-1] == ds[node.input[1]][:-1]:
|
||||
tmp = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
gf.extend(ts[node.input[1]], tmp,
|
||||
len(ds[node.input[1]]) - 1, ds[node.input[0]][-1] - 1)
|
||||
gf.add(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
elif len(ds[node.input[1]]) == 1 and ds[node.input[0]][-1] == ds[node.input[1]][0]:
|
||||
tmp = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
# gf.extend(ts[node.input[1]], tmp,
|
||||
# len(ds[node.input[1]]) - 1, ds[node.input[0]][-1] - 1)
|
||||
gf.add(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
else:
|
||||
assert False
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sub
|
||||
elif node.op_type == 'Sub':
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.sub(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
|
||||
|
||||
assert ds[node.input[0]] == ds[node.output[0]]
|
||||
if ds[node.input[0]] == ds[node.input[1]]:
|
||||
gf.sub(ts[node.input[0]], ts[node.input[1]],
|
||||
ts[node.output[0]])
|
||||
elif len(ds[node.input[1]]) == 0:
|
||||
tmp = gf.tensor(ds[node.input[0]], "FLOAT")
|
||||
gf.sub(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
elif ds[node.input[1]][-1] == 1 and ds[node.input[0]][:-1] == ds[node.input[1]][:-1]:
|
||||
tmp = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
gf.extend(ts[node.input[1]], tmp,
|
||||
len(ds[node.input[1]]) - 1, ds[node.input[0]][-1] - 1)
|
||||
gf.sub(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
else:
|
||||
assert False
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Mul
|
||||
elif node.op_type == 'Mul':
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.mul(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
|
||||
assert ds[node.input[0]] == ds[node.output[0]]
|
||||
if ds[node.input[0]] == ds[node.input[1]]:
|
||||
gf.mul(ts[node.input[0]], ts[node.input[1]],
|
||||
ts[node.output[0]])
|
||||
elif len(ds[node.input[1]]) == 0:
|
||||
tmp = gf.tensor(ds[node.input[0]], "FLOAT")
|
||||
gf.mul(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
elif ds[node.input[1]][-1] == 1 and ds[node.input[0]][:-1] == ds[node.input[1]][:-1]:
|
||||
tmp = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
gf.extend(ts[node.input[1]], tmp,
|
||||
len(ds[node.input[1]]) - 1, ds[node.input[0]][-1] - 1)
|
||||
gf.mul(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
elif len(ds[node.input[1]]) == 1 and ds[node.input[0]][-1] == ds[node.input[1]][0]:
|
||||
tmp = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
# gf.extend(ts[node.input[1]], tmp,
|
||||
# len(ds[node.input[1]]) - 1, ds[node.input[0]][-1] - 1)
|
||||
gf.mul(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
else:
|
||||
print(ds[node.input[0]], ds[node.input[1]])
|
||||
assert False
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Div
|
||||
elif node.op_type == 'Div':
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.div(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
|
||||
assert ds[node.input[0]] == ds[node.output[0]]
|
||||
if ds[node.input[0]] == ds[node.input[1]]:
|
||||
gf.div(ts[node.input[0]], ts[node.input[1]],
|
||||
ts[node.output[0]])
|
||||
elif len(ds[node.input[1]]) == 0:
|
||||
tmp = gf.tensor(ds[node.input[0]], "FLOAT")
|
||||
gf.div(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
elif ds[node.input[1]][-1] == 1 and ds[node.input[0]][:-1] == ds[node.input[1]][:-1]:
|
||||
tmp = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
gf.extend(ts[node.input[1]], tmp,
|
||||
len(ds[node.input[1]]) - 1, ds[node.input[0]][-1] - 1)
|
||||
gf.div(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
else:
|
||||
assert False
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pow
|
||||
elif node.op_type == 'Pow':
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.pow(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
|
||||
assert ds[node.input[0]] == ds[node.output[0]]
|
||||
if ds[node.input[0]] == ds[node.input[1]]:
|
||||
gf.pow(ts[node.input[0]], ts[node.input[1]],
|
||||
ts[node.output[0]])
|
||||
elif len(ds[node.input[1]]) == 0:
|
||||
tmp = gf.tensor(ds[node.input[0]], "FLOAT")
|
||||
gf.pow(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
elif ds[node.input[1]][-1] == 1 and ds[node.input[0]][:-1] == ds[node.input[1]][:-1]:
|
||||
tmp = gf.tensor(ds[node.output[0]], "FLOAT")
|
||||
gf.extend(ts[node.input[1]], tmp,
|
||||
len(ds[node.input[1]]) - 1, ds[node.input[0]][-1] - 1)
|
||||
gf.pow(ts[node.input[0]], tmp, ts[node.output[0]])
|
||||
else:
|
||||
assert False
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
|
||||
elif node.op_type == 'Gather':
|
||||
attrs = _parse_attribute(node.attribute, {"axis": 0})
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.gather(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]], attrs["axis"])
|
||||
gf.gather(ts[node.input[0]], ts[node.input[1]],
|
||||
ts[node.output[0]], attrs["axis"])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape
|
||||
elif node.op_type == 'Reshape':
|
||||
attrs = _parse_attribute(node.attribute, {"allowzero": 0})
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.reshape(ts[node.input[0]], ts[node.output[0]], ts[node.input[1]])
|
||||
gf.reshape(ts[node.input[0]],
|
||||
ts[node.output[0]], ts[node.input[1]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten
|
||||
# Output is 2D in ONNX
|
||||
|
@ -394,6 +479,12 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
assert len(node.output) == 1
|
||||
gf.relu(ts[node.input[0]], ts[node.output[0]])
|
||||
|
||||
# TODO
|
||||
elif node.op_type == 'Sqrt':
|
||||
assert len(node.input) == 1
|
||||
assert len(node.output) == 1
|
||||
gf.relu(ts[node.input[0]], ts[node.output[0]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sigmoid
|
||||
elif node.op_type == 'Sigmoid':
|
||||
assert len(node.input) == 1
|
||||
|
@ -460,15 +551,16 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
# g.avgpool(ts[node.input[0]], ts[node.output[0]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMean
|
||||
# elif node.op_type == 'ReduceMean':
|
||||
# attrs = _parse_attribute(node.attribute, {'keepdims': 1})
|
||||
# assert len(node.input) == 1
|
||||
# assert len(node.output) == 1
|
||||
# assert len(attrs["axes"]) == 1
|
||||
# axis = attrs["axes"][0]
|
||||
# if axis < 0:
|
||||
# axis = len(ds[node.input[0]]) - axis
|
||||
# gf.reduceMean(ts[node.input[0]], ts[node.output[0]], axis)
|
||||
elif node.op_type == 'ReduceMean':
|
||||
attrs = _parse_attribute(node.attribute, {'keepdims': 1})
|
||||
assert len(node.input) == 1
|
||||
assert len(node.output) == 1
|
||||
assert len(attrs["axes"]) == 1
|
||||
axis = attrs["axes"][0]
|
||||
print(axis, len(ds[node.input[0]]))
|
||||
if axis < 0:
|
||||
axis = len(ds[node.input[0]]) + axis
|
||||
gf.reduceMean(ts[node.input[0]], ts[node.output[0]], axis)
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
|
||||
# Ignore for now, and no need to output anything (TODO)
|
||||
|
@ -488,8 +580,9 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
elif node.op_type == 'Unsqueeze':
|
||||
assert len(node.input) == 2
|
||||
assert len(node.output) == 1
|
||||
gf.reshape(ts[node.input[0]], ts[node.output[0]], ts[node.input[1]])
|
||||
|
||||
gf.reshape(ts[node.input[0]],
|
||||
ts[node.output[0]], ts[node.input[1]])
|
||||
|
||||
# https://github.com/onnx/onnx/blob/main/docs/Operators.md#BatchNormalization
|
||||
# elif node.op_type == "BatchNormalization":
|
||||
# attrs = _parse_attribute(node.attribute, {})
|
||||
|
@ -498,7 +591,7 @@ def import_onnx(gf: GraphBuilder, net: str):
|
|||
# epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-5
|
||||
# momentum = attrs['momentum'] if 'momentum' in attrs else 0.9
|
||||
# g.batchnorm(ts[node.input[0]], ts[node.input[1]],
|
||||
# ts[node.input[2]], ts[node.input[3]],
|
||||
# ts[node.input[2]], ts[node.input[3]],
|
||||
# ts[node.input[4]], ts[node.output[0]],
|
||||
# epsilon, momentum)
|
||||
|
||||
|
|
|
@ -6,4 +6,4 @@ class Test_ImportOnnx:
|
|||
def test_Netname(self):
|
||||
runtime = CpuRuntimeObj.getInstance()
|
||||
graphBuilder = GraphBuilderObj(runtime)
|
||||
import_onnx(graphBuilder, '/path/to/net')
|
||||
import_onnx(graphBuilder, '/home/mazx/git/pf-models/bert.bs1.onnx')
|
||||
|
|
|
@ -9,6 +9,9 @@ Tensor GraphBuilderObj::tensor(Shape dim, const std::string &dtype) {
|
|||
if (dtype == "INT32") {
|
||||
return g->addTensor(dim, DataType::UInt32);
|
||||
}
|
||||
if (dtype == "INT64") {
|
||||
return g->addTensor(dim, DataType::UInt32);
|
||||
}
|
||||
IT_TODO_HALT_MSG("Unsupported data type");
|
||||
}
|
||||
|
||||
|
@ -318,6 +321,7 @@ Operator GraphBuilderObj::add(Tensor input0, Tensor input1) {
|
|||
}
|
||||
|
||||
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
||||
std::cout << "Sub1" << std::endl;
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -326,6 +330,7 @@ Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
|||
}
|
||||
|
||||
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1) {
|
||||
std::cout << "Sub2" << std::endl;
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
||||
|
@ -343,7 +348,7 @@ Operator GraphBuilderObj::mul(Tensor input0, Tensor input1, Tensor output) {
|
|||
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
||||
auto op = g->addOp<MulObj>(i0, i1, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
|
@ -496,6 +501,14 @@ Operator GraphBuilderObj::abs(Tensor input) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphBuilderObj::reduceMean(Tensor input, Tensor output, int axis) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op =
|
||||
g->addOpWithOutputs<ReduceMeanObj>(i0, o0, std::vector<int>({axis}));
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphBuilderObj::memBound(const TensorVec &inputs,
|
||||
const TensorVec &outputs,
|
||||
const std::vector<nnet::Tensor> &nnetInputs,
|
||||
|
|
|
@ -71,6 +71,8 @@ void init_graph_builder(py::module &m) {
|
|||
m, "SigmoidObj");
|
||||
py::class_<TanhObj, std::shared_ptr<TanhObj>, OperatorObj>(m, "TanhObj");
|
||||
py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj");
|
||||
py::class_<ReduceMeanObj, std::shared_ptr<ReduceMeanObj>, OperatorObj>(
|
||||
m, "ReduceMeanObj");
|
||||
py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>(
|
||||
m, "MemBoundObj");
|
||||
py::class_<GraphBuilder>(m, "GraphBuilder");
|
||||
|
@ -175,6 +177,10 @@ void init_graph_builder(py::module &m) {
|
|||
policy::reference_internal)
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::abs),
|
||||
policy::reference_internal)
|
||||
.def("reduceMean",
|
||||
py::overload_cast<Tensor, Tensor, int>(
|
||||
&GraphBuilderObj::reduceMean),
|
||||
policy::reference_internal)
|
||||
.def("memBound",
|
||||
py::overload_cast<const TensorVec &, const TensorVec &,
|
||||
const std::vector<nnet::Tensor> &, nnet::Expr,
|
||||
|
|
|
@ -4,6 +4,15 @@ namespace infini {
|
|||
ElementWiseObj::ElementWiseObj(OpType type, GraphObj *graph, Tensor input0,
|
||||
Tensor input1, Tensor output)
|
||||
: OperatorObj(type, {input0, input1}, {output}) {
|
||||
std::cout << "Element: " << int(type) << std::endl;
|
||||
for (auto x : input0->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (auto x : input1->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
@ -11,7 +20,17 @@ optional<vector<Shape>>
|
|||
ElementWiseObj::inferShape(const TensorVec &inputs) const {
|
||||
// For now,we only process the same dims here, broardcast will be considered
|
||||
// in the opt layer.
|
||||
std::cout << std::endl;
|
||||
const auto A = inputs[0], B = inputs[1];
|
||||
std::cout << "InferShape" << std::endl;
|
||||
for (auto x : A->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (auto x : B->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
if (A->getDims().size() != B->getDims().size() ||
|
||||
A->getDims() != B->getDims())
|
||||
return {};
|
||||
|
|
|
@ -5,6 +5,15 @@ namespace infini {
|
|||
ExtendObj::ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
|
||||
int num)
|
||||
: OperatorObj(OpType::Extend, {input}, {output}), dim(dim), num(num) {
|
||||
std::cout << "Extend" << std::endl;
|
||||
for (auto x : input->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
for (auto x : output->getDims()) {
|
||||
std::cout << x << " ";
|
||||
}
|
||||
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ ReduceMeanObj::ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
|
|||
int idx = (*_axis)[j];
|
||||
if (idx < 0)
|
||||
IT_TODO_HALT();
|
||||
std::cout << idx << " " << input->getDims().size() << std::endl;
|
||||
IT_ASSERT((size_t)idx < input->getDims().size());
|
||||
axis.emplace(idx);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue