diff --git a/include/operators/batch_norm.h b/include/operators/batch_norm.h
index 20842615..8e41a043 100644
--- a/include/operators/batch_norm.h
+++ b/include/operators/batch_norm.h
@@ -39,9 +39,11 @@ class BatchNormObj : public OperatorObj {
     std::string toString() const override;
 
     // output size will be 3 when training
-    int numInputs() const override { return 5; }
-    int numOutputs() const override { return outputs.size(); }
-    float getEps() const { return eps; }
+    inline int numInputs() const override { return 5; }
+    inline int numOutputs() const override { return outputs.size(); }
+    inline float getMomentum() const { return momentum; }
+    inline float getEps() const { return eps; }
+    inline bool getTraining() const { return training; }
 
   private:
     vector<int> getWorkloadVector() const override;
diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py
index f8fd662c..28b8f514 100644
--- a/pyinfinitensor/src/pyinfinitensor/onnx.py
+++ b/pyinfinitensor/src/pyinfinitensor/onnx.py
@@ -32,490 +32,516 @@ def cuda_runtime():
     return backend.cuda_runtime()
 
 
-def from_onnx(
-    model: ModelProto, runtime
-) -> Tuple[Dict[str, backend.Tensor], Dict[str, backend.Tensor], backend.GraphHandler]:
-    model = infer_shapes(model)
-    handler = backend.GraphHandler(runtime)
-
-    tensors: Dict[str, backend.Tensor] = dict()
-    data: Dict[str, TensorProto] = dict()
-
-    for input in model.graph.input:
-        dims = _take_shape_dim(input.type.tensor_type.shape)
-        tensors[input.name] = handler.tensor(dims, input.type.tensor_type.elem_type)
-
-    for output in model.graph.output:
-        dims = _take_shape_dim(output.type.tensor_type.shape)
-        tensors[output.name] = handler.tensor(dims, output.type.tensor_type.elem_type)
-
-    for initializer in model.graph.initializer:
-        data[initializer.name] = initializer
-
-    for node in model.graph.node:
-        if node.op_type == "Conv":
-            attributes = _parse_attribute(
-                node,
-                {
-                    "dilations": [1, 1],
-                    "pads": [0, 0],
-                    "strides": [1, 1],
-                },
-            )
-            (d, p, s) = (attributes[name] for name in ["dilations", "pads", "strides"])
-            tensors[node.output[0]] = handler.conv(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-                p[0],
-                p[1],
-                s[0],
-                s[1],
-                d[0],
-                d[1],
-            )
-        elif node.op_type == "MatMul":
-            tensors[node.output[0]] = handler.matmul(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-                False,
-                False,
-                None,
-                backend.ActType.Linear,
-            )
-        elif node.op_type == "Gemm":
-            attributes = _parse_attribute(
-                node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
-            )
-            (alpha, beta, transA, transB) = (
-                attributes[name] for name in ["alpha", "beta", "transA", "transB"]
-            )
-            # FIXME `alpha` and `beta` are not supported
-            assert alpha == 1.0
-            assert beta == 1.0
-            tensors[node.output[0]] = handler.matmul(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-                transA == 1,
-                transB == 1,
-                tensors[node.input[2]] if len(node.input) > 2 else None,
-                backend.ActType.Linear,
-            )
-        elif node.op_type == "BatchNormalization":
-            (input, mean, var, scale, bias) = (
-                tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
-            )
-            output = tensors.get(node.output[0])
-            attributes = _parse_attribute(
-                node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
-            )
-            (momentum, eps, training) = (
-                attributes[name] for name in ["momentum", "epsilon", "training_mode"]
-            )
-            tensors[node.output[0]] = handler.batchNorm(
-                input, output, mean, var, scale, bias, momentum, eps, training != 0
-            )
-        elif node.op_type == "MaxPool":
-            attributes = _parse_attribute(
-                node,
-                {
-                    "kernel_shape": None,
-                    "dilations": [1, 1],
-                    "pads": [0, 0],
-                    "strides": [1, 1],
-                },
-            )
-            (k, d, p, s) = (
-                attributes[name]
-                for name in ["kernel_shape", "dilations", "pads", "strides"]
-            )
-            tensors[node.output[0]] = handler.maxPool(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-                k[0],
-                k[1],
-                d[0],
-                d[1],
-                p[0],
-                p[1],
-                s[0],
-                s[1],
-            )
-        elif node.op_type == "AveragePool":
-            attributes = _parse_attribute(
-                node,
-                {
-                    "kernel_shape": None,
-                    "pads": [0, 0],
-                    "strides": [1, 1],
-                },
-            )
-            (k, p, s) = (
-                attributes[name] for name in ["kernel_shape", "pads", "strides"]
-            )
-            tensors[node.output[0]] = handler.avgPool(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-                k[0],
-                k[1],
-                1,
-                1,
-                p[0],
-                p[1],
-                s[0],
-                s[1],
-            )
-        elif node.op_type == "GlobalAveragePool":
-            shape = next(
-                (
-                    value.type.tensor_type.shape
-                    for value in model.graph.value_info
-                    if value.name == node.input[0]
-                ),
-                None,
-            ) or next(
-                input.type.tensor_type.shape
-                for input in model.graph.input
-                if input.name == node.input[0]
-            )
-            [_, _, h, w] = _take_shape_dim(shape)
-            tensors[node.output[0]] = handler.avgPool(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-                h,
-                w,
-                1,
-                1,
-                0,
-                0,
-                1,
-                1,
-            )
-        elif node.op_type == "Add":
-            tensors[node.output[0]] = handler.add(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Sub":
-            tensors[node.output[0]] = handler.sub(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Mul":
-            tensors[node.output[0]] = handler.mul(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Div":
-            tensors[node.output[0]] = handler.div(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Pow":
-            tensors[node.output[0]] = handler.pow(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Relu":
-            tensors[node.output[0]] = handler.relu(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Sigmoid":
-            tensors[node.output[0]] = handler.sigmoid(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Tanh":
-            tensors[node.output[0]] = handler.tanh(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Softmax":
-            tensors[node.output[0]] = handler.softmax(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Abs":
-            tensors[node.output[0]] = handler.abs(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Identity":
-            tensors[node.output[0]] = handler.identity(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Flatten":
-            # FIXME the backend operator cannot flatten along an arbitrary axis
-            axis = next(
-                (attr.i for attr in node.attribute if attr.name == "axis"), None
-            )
-            assert axis == None or axis == 1
-            tensors[node.output[0]] = handler.flatten(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-            )
-        elif node.op_type == "Reshape":
-            input_shape = next(
-                (
-                    value.type.tensor_type.shape
-                    for value in model.graph.value_info
-                    if value.name == node.input[0]
-                ),
-                None,
-            ) or next(
-                input.type.tensor_type.shape
-                for input in model.graph.input
-                if input.name == node.input[0]
-            )
-            dims = _take_shape_dim(input_shape)
-            size = reduce(lambda acc, x: acc * x, dims)
-            output_shape = [int(i) for i in data[node.input[1]].int64_data]
-            for i, x in enumerate(output_shape):
-                if x == 0:
-                    output_shape[i] = dims[i]
-            temp = reduce(lambda acc, x: acc * x, output_shape)
-            if temp < 0:
-                output_shape[output_shape.index(-1)] = size // -temp
-            tensors[node.output[0]] = handler.reshape(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-                output_shape,
-            )
-        elif node.op_type == "Concat":
-            tensors[node.output[0]] = handler.concat(
-                [tensors[name] for name in node.input],
-                tensors.get(node.output[0]),
-                next((attr.i for attr in node.attribute if attr.name == "axis")),
-            )
-        elif node.op_type == "Gather":
-            tensors[node.output[0]] = handler.gather(
-                tensors[node.input[0]],
-                tensors[node.input[1]],
-                tensors.get(node.output[0]),
-                next((attr.i for attr in node.attribute if attr.name == "axis")),
-            )
-        elif node.op_type == "ReduceMean":
-            tensors[node.output[0]] = handler.reduceMean(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-                tensors[node.input[1]] if len(node.input) > 1 else None,
-                next((attr.i for attr in node.attribute if attr.name == "keepdims"))
-                != 0,
-            )
-        elif node.op_type == "Slice":
-            tensors[node.output[0]] = handler.slice(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-                _parse_data(data[node.input[1]]),
-                _parse_data(data[node.input[2]]),
-                _parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
-                _parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
-            )
-        elif node.op_type == "Pad":
-            tensors[node.output[0]] = handler.pad(
-                tensors[node.input[0]],
-                tensors.get(node.output[0]),
-                _parse_data(data[node.input[1]]),
-                _parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
-            )
-        else:
-            raise Exception('Unsupported operator "{}"'.format(node.op_type))
-
-    handler.data_malloc()
-
+class OnnxStub:
     inputs: Dict[str, backend.Tensor] = {}
-    for name, obj in tensors.items():
-        tensor = data.get(name)
-        if tensor == None:
-            if any(input.name == name for input in model.graph.input):
-                inputs[name] = obj
-        else:
-            if tensor.data_type == TensorProto.INT32:
-                handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
-            elif tensor.data_type == TensorProto.INT64:
-                handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
-            elif tensor.data_type == TensorProto.FLOAT:
-                handler.copy_float(obj, [float(i) for i in tensor.float_data])
-            else:
-                assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
-    outputs: Dict[str, backend.Tensor] = {}
-    for output in model.graph.output:
-        outputs[output.name] = tensors[output.name]
+    handler: backend.GraphHandler
 
-    return inputs, outputs, handler
+    def __init__(self, model: ModelProto, runtime):
+        model = infer_shapes(model)
+        self.handler = backend.GraphHandler(runtime)
 
+        tensors: Dict[str, backend.Tensor] = dict()
+        data: Dict[str, TensorProto] = dict()
 
-def to_onnx(graph: backend.GraphHandler, name: str) -> ModelProto:
-    class Context:
-        # saves object names, including tensors and operators
-        names: Dict[Any, str] = dict()
-        # counts the occurrence times of each operator for naming
-        count_op: Dict[backend.OpType, int] = dict()
-        # counts input and output tensors for naming
-        count_in, count_out = 0, 0
-        # saves nodes (operators)
-        nodes: List[NodeProto] = []
-        # saves global input tensors
-        inputs: List[ValueInfoProto] = []
-        # saves global output tensors
-        outputs: List[ValueInfoProto] = []
-        # saves global input tensors
-        initializers: List[TensorProto] = []
-
-        def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
-            ty = op.op_type()
-            name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
-            self.names[op] = name
-            self.count_op[ty] += 1
-            return ty, name
-
-        def push_output(self, name: str, tensor: backend.Tensor) -> str:
-            self.names[tensor] = name
-            if not tensor.has_target():
-                shape = tensor.shape()
-                dtype = backend.tensor_dtype(tensor)
-                value_info = make_tensor_value_info(name, dtype, shape)
-                check_value_info(value_info)
-                self.outputs.append(value_info)
-            return name
-
-        def push_input(self, tensor: backend.Tensor) -> str:
-            name = self.names.get(tensor)
-            # means that this input is a global input
-            if name is None:
-                self.count_in += 1
-                name = "input{}".format(self.count_in)
-                self.names[tensor] = name
-                shape = tensor.shape()
-                dtype = backend.tensor_dtype(tensor)
-                value_info = make_tensor_value_info(name, dtype, shape)
-                check_value_info(value_info)
-                self.inputs.append(value_info)
-
-            return name
-
-        def push_data_input(
-            self,
-            node_name: str,
-            attr_name: str,
-            elem_type: int,
-            shape: Sequence[int],
-            vals: Any,
-        ) -> str:
-            name = "{}_{}".format(node_name, attr_name)
-            value_info = make_tensor_value_info(name, elem_type, shape)
-            tensor = make_tensor(name, elem_type, shape, vals)
-            check_value_info(value_info)
-            check_tensor(tensor)
-            self.inputs.append(value_info)
-            self.initializers.append(tensor)
-            return name
-
-        def push_node(self, node: NodeProto) -> None:
-            check_node(node)
-            self.nodes.append(node)
-
-        def build(self, name: str) -> ModelProto:
-            print()
-            print(ctx.names)
-            print()
-            print(ctx.inputs)
-            print()
-            print(ctx.outputs)
-            print()
-            print(ctx.nodes)
-
-            graph = make_graph(
-                self.nodes, name, self.inputs, self.outputs, self.initializers
+        for input in model.graph.input:
+            dims = _take_shape_dim(input.type.tensor_type.shape)
+            tensors[input.name] = self.handler.tensor(
+                dims, input.type.tensor_type.elem_type
             )
-            check_graph(graph)
 
-            model = make_model(graph)
-            check_model(model)
+        for output in model.graph.output:
+            dims = _take_shape_dim(output.type.tensor_type.shape)
+            tensors[output.name] = self.handler.tensor(
+                dims, output.type.tensor_type.elem_type
+            )
 
-            return model
+        for initializer in model.graph.initializer:
+            data[initializer.name] = initializer
 
-    # topological sort
-    if not graph.topo_sort():
-        raise Exception("Sorting fails")
-
-    ops = graph.operators()  # all operators (nodes) in the graph
-
-    ctx = Context()
-
-    for op in ops:
-        ty, name = ctx.name_op(op)
-        inputs = [ctx.push_input(it) for it in op.inputs()]
-        outputs = [
-            ctx.push_output("{}_{}".format(name, i), it)
-            for (i, it) in enumerate(op.outputs())
-        ]
-        if ty == backend.OpType.Matmul:
-            ctx.push_node(make_node("MatMul", inputs, outputs, name))
-        elif ty == backend.OpType.BatchNorm:
-            raise Exception("TODO")
-        elif ty == backend.OpType.MaxPool:
-            raise Exception("TODO")
-        elif ty == backend.OpType.AvgPool:
-            raise Exception("TODO")
-        elif ty in [
-            backend.OpType.Add,
-            backend.OpType.Sub,
-            backend.OpType.Mul,
-            backend.OpType.Div,
-            backend.OpType.Pow,
-            backend.OpType.Relu,
-            backend.OpType.Sigmoid,
-            backend.OpType.Tanh,
-            backend.OpType.Softmax,
-            backend.OpType.Abs,
-            backend.OpType.Identity,
-        ]:
-            ctx.push_node(make_node(ty.name, inputs, outputs, name))
-        elif ty == backend.OpType.Flatten:
-            raise Exception("TODO")
-        elif ty == backend.OpType.Reshape:
-            shape = backend.reshape_shape_of(op)
-            inputs.append(
-                ctx.push_data_input(
-                    name,
-                    "shape",
-                    TensorProto.INT32,
-                    [len(shape)],
-                    shape,
+        for node in model.graph.node:
+            if node.op_type == "Conv":
+                attributes = _parse_attribute(
+                    node,
+                    {
+                        "dilations": [1, 1],
+                        "pads": [0, 0],
+                        "strides": [1, 1],
+                    },
                 )
-            )
-            ctx.push_node(make_node(ty.name, inputs, outputs, name))
-        elif ty == backend.OpType.Concat:
-            axis = backend.concat_axis_of(op)
-            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
-        elif ty == backend.OpType.Gather:
-            axis = backend.gather_axis_of(op)
-            ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
-        elif ty == backend.OpType.ReduceMean:
-            axes = backend.reduce_mean_axes_of(op)
-            inputs.append(
-                ctx.push_data_input(name, "axes", TensorProto.INT32, [len(axes)], axes)
-            )
-            ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
-        elif ty == backend.OpType.Slice:
-            raise Exception("TODO")
-        elif ty == backend.OpType.Pad:
-            raise Exception("TODO")
-        else:
-            raise Exception("Unsupported OpType {}".format(ty.name))
+                (d, p, s) = (
+                    attributes[name] for name in ["dilations", "pads", "strides"]
+                )
+                tensors[node.output[0]] = self.handler.conv(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                    p[0],
+                    p[1],
+                    s[0],
+                    s[1],
+                    d[0],
+                    d[1],
+                )
+            elif node.op_type == "MatMul":
+                tensors[node.output[0]] = self.handler.matmul(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                    False,
+                    False,
+                    None,
+                    backend.ActType.Linear,
+                )
+            elif node.op_type == "Gemm":
+                attributes = _parse_attribute(
+                    node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
+                )
+                (alpha, beta, transA, transB) = (
+                    attributes[name] for name in ["alpha", "beta", "transA", "transB"]
+                )
+                # TODO these parameters are not supported
+                assert alpha == 1.0
+                assert beta == 1.0
+                tensors[node.output[0]] = self.handler.matmul(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                    transA == 1,
+                    transB == 1,
+                    tensors[node.input[2]] if len(node.input) > 2 else None,
+                    backend.ActType.Linear,
+                )
+            elif node.op_type == "BatchNormalization":
+                (input, mean, var, scale, bias) = (
+                    tensors[node.input[i]] for i in [0, 3, 4, 1, 2]
+                )
+                output = tensors.get(node.output[0])
+                attributes = _parse_attribute(
+                    node, {"momentum": 0.9, "epsilon": 1e-05, "training_mode": 0}
+                )
+                (momentum, eps, training) = (
+                    attributes[name]
+                    for name in ["momentum", "epsilon", "training_mode"]
+                )
+                tensors[node.output[0]] = self.handler.batchNorm(
+                    input, output, mean, var, scale, bias, momentum, eps, training != 0
+                )
+            elif node.op_type == "MaxPool":
+                attributes = _parse_attribute(
+                    node,
+                    {
+                        "kernel_shape": None,
+                        "dilations": [1, 1],
+                        "pads": [0, 0],
+                        "strides": [1, 1],
+                    },
+                )
+                (k, d, p, s) = (
+                    attributes[name]
+                    for name in ["kernel_shape", "dilations", "pads", "strides"]
+                )
+                tensors[node.output[0]] = self.handler.maxPool(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    k[0],
+                    k[1],
+                    d[0],
+                    d[1],
+                    p[0],
+                    p[1],
+                    s[0],
+                    s[1],
+                )
+            elif node.op_type == "AveragePool":
+                attributes = _parse_attribute(
+                    node,
+                    {
+                        "kernel_shape": None,
+                        "pads": [0, 0],
+                        "strides": [1, 1],
+                    },
+                )
+                (k, p, s) = (
+                    attributes[name] for name in ["kernel_shape", "pads", "strides"]
+                )
+                tensors[node.output[0]] = self.handler.avgPool(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    k[0],
+                    k[1],
+                    1,
+                    1,
+                    p[0],
+                    p[1],
+                    s[0],
+                    s[1],
+                )
+            elif node.op_type == "GlobalAveragePool":
+                shape = next(
+                    (
+                        value.type.tensor_type.shape
+                        for value in model.graph.value_info
+                        if value.name == node.input[0]
+                    ),
+                    None,
+                ) or next(
+                    input.type.tensor_type.shape
+                    for input in model.graph.input
+                    if input.name == node.input[0]
+                )
+                [_, _, h, w] = _take_shape_dim(shape)
+                tensors[node.output[0]] = self.handler.avgPool(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    h,
+                    w,
+                    1,
+                    1,
+                    0,
+                    0,
+                    1,
+                    1,
+                )
+            elif node.op_type == "Add":
+                tensors[node.output[0]] = self.handler.add(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Sub":
+                tensors[node.output[0]] = self.handler.sub(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Mul":
+                tensors[node.output[0]] = self.handler.mul(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Div":
+                tensors[node.output[0]] = self.handler.div(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Pow":
+                tensors[node.output[0]] = self.handler.pow(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Relu":
+                tensors[node.output[0]] = self.handler.relu(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Sigmoid":
+                tensors[node.output[0]] = self.handler.sigmoid(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Tanh":
+                tensors[node.output[0]] = self.handler.tanh(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Softmax":
+                tensors[node.output[0]] = self.handler.softmax(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Abs":
+                tensors[node.output[0]] = self.handler.abs(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Identity":
+                tensors[node.output[0]] = self.handler.identity(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Flatten":
+                # TODO the backend operator cannot flatten along an arbitrary axis
+                axis = next(
+                    (attr.i for attr in node.attribute if attr.name == "axis"), None
+                )
+                assert axis == None or axis == 1
+                tensors[node.output[0]] = self.handler.flatten(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                )
+            elif node.op_type == "Reshape":
+                input_shape = next(
+                    (
+                        value.type.tensor_type.shape
+                        for value in model.graph.value_info
+                        if value.name == node.input[0]
+                    ),
+                    None,
+                ) or next(
+                    input.type.tensor_type.shape
+                    for input in model.graph.input
+                    if input.name == node.input[0]
+                )
+                dims = _take_shape_dim(input_shape)
+                size = reduce(lambda acc, x: acc * x, dims)
+                output_shape = [int(i) for i in data[node.input[1]].int64_data]
+                for i, x in enumerate(output_shape):
+                    if x == 0:
+                        output_shape[i] = dims[i]
+                temp = reduce(lambda acc, x: acc * x, output_shape)
+                if temp < 0:
+                    output_shape[output_shape.index(-1)] = size // -temp
+                tensors[node.output[0]] = self.handler.reshape(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    output_shape,
+                )
+            elif node.op_type == "Concat":
+                tensors[node.output[0]] = self.handler.concat(
+                    [tensors[name] for name in node.input],
+                    tensors.get(node.output[0]),
+                    next((attr.i for attr in node.attribute if attr.name == "axis")),
+                )
+            elif node.op_type == "Gather":
+                tensors[node.output[0]] = self.handler.gather(
+                    tensors[node.input[0]],
+                    tensors[node.input[1]],
+                    tensors.get(node.output[0]),
+                    next((attr.i for attr in node.attribute if attr.name == "axis")),
+                )
+            elif node.op_type == "ReduceMean":
+                tensors[node.output[0]] = self.handler.reduceMean(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    tensors[node.input[1]] if len(node.input) > 1 else None,
+                    next((attr.i for attr in node.attribute if attr.name == "keepdims"))
+                    != 0,
+                )
+            elif node.op_type == "Slice":
+                tensors[node.output[0]] = self.handler.slice(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    _parse_data(data[node.input[1]]),
+                    _parse_data(data[node.input[2]]),
+                    _parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
+                    _parse_data(data[node.input[4]]) if len(node.input) > 4 else None,
+                )
+            elif node.op_type == "Pad":
+                tensors[node.output[0]] = self.handler.pad(
+                    tensors[node.input[0]],
+                    tensors.get(node.output[0]),
+                    _parse_data(data[node.input[1]]),
+                    _parse_data(data[node.input[3]]) if len(node.input) > 3 else None,
+                )
+            else:
+                raise Exception('Unsupported operator "{}"'.format(node.op_type))
 
-    return ctx.build(name)
+        self.handler.data_malloc()
+
+        for name, obj in tensors.items():
+            tensor = data.get(name)
+            if tensor == None:
+                if any(input.name == name for input in model.graph.input):
+                    self.inputs[name] = obj
+            else:
+                if tensor.data_type == TensorProto.INT32:
+                    self.handler.copy_int32(obj, [int(i) for i in tensor.int32_data])
+                elif tensor.data_type == TensorProto.INT64:
+                    self.handler.copy_int64(obj, [int(i) for i in tensor.int64_data])
+                elif tensor.data_type == TensorProto.FLOAT:
+                    self.handler.copy_float(obj, [float(i) for i in tensor.float_data])
+                else:
+                    assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
+
+        for output in model.graph.output:
+            self.outputs[output.name] = tensors[output.name]
+
+    def to_onnx(self, name: str) -> ModelProto:
+        class Context:
+            # saves object names, including tensors and operators
+            names: Dict[Any, str] = dict()
+            # counts the occurrence times of each operator for naming
+            count_op: Dict[backend.OpType, int] = dict()
+            # counts input and output tensors for naming
+            count_in, count_out = 0, 0
+            # saves nodes (operators)
+            nodes: List[NodeProto] = []
+            # saves global input tensors
+            inputs: List[ValueInfoProto] = []
+            # saves global output tensors
+            outputs: List[ValueInfoProto] = []
+            # saves global input tensors
+            initializers: List[TensorProto] = []
+
+            def name_op(self, op: backend.Operator) -> Tuple[backend.OpType, str]:
+                ty = op.op_type()
+                name = "{}{}".format(ty.name, self.count_op.setdefault(ty, 0) + 1)
+                self.names[op] = name
+                self.count_op[ty] += 1
+                return ty, name
+
+            def push_output(self, name: str, tensor: backend.Tensor) -> str:
+                self.names[tensor] = name
+                if not tensor.has_target():
+                    shape = tensor.shape()
+                    dtype = backend.tensor_dtype(tensor)
+                    value_info = make_tensor_value_info(name, dtype, shape)
+                    check_value_info(value_info)
+                    self.outputs.append(value_info)
+                return name
+
+            def push_input(self, tensor: backend.Tensor) -> str:
+                name = self.names.get(tensor)
+                # means that this input is a global input
+                if name is None:
+                    self.count_in += 1
+                    name = "input{}".format(self.count_in)
+                    self.names[tensor] = name
+                    shape = tensor.shape()
+                    dtype = backend.tensor_dtype(tensor)
+                    value_info = make_tensor_value_info(name, dtype, shape)
+                    check_value_info(value_info)
+                    self.inputs.append(value_info)
+
+                return name
+
+            def push_data_input(
+                self,
+                node_name: str,
+                attr_name: str,
+                elem_type: int,
+                shape: Sequence[int],
+                vals: Any,
+            ) -> str:
+                name = "{}_{}".format(node_name, attr_name)
+                value_info = make_tensor_value_info(name, elem_type, shape)
+                tensor = make_tensor(name, elem_type, shape, vals)
+                check_value_info(value_info)
+                check_tensor(tensor)
+                self.inputs.append(value_info)
+                self.initializers.append(tensor)
+                return name
+
+            def push_node(self, node: NodeProto) -> None:
+                check_node(node)
+                self.nodes.append(node)
+
+            def build(self, name: str) -> ModelProto:
+                print()
+                print(ctx.names)
+                print()
+                print(ctx.inputs)
+                print()
+                print(ctx.outputs)
+                print()
+                print(ctx.nodes)
+
+                graph = make_graph(
+                    self.nodes, name, self.inputs, self.outputs, self.initializers
+                )
+                check_graph(graph)
+
+                model = make_model(graph)
+                check_model(model)
+
+                return model
+
+        # topological sort
+        if not self.handler.topo_sort():
+            raise Exception("Sorting fails")
+
+        ops = self.handler.operators()  # all operators (nodes) in the graph
+
+        ctx = Context()
+
+        for op in ops:
+            ty, name = ctx.name_op(op)
+            inputs = [ctx.push_input(it) for it in op.inputs()]
+            outputs = [
+                ctx.push_output("{}_{}".format(name, i), it)
+                for (i, it) in enumerate(op.outputs())
+            ]
+            if ty == backend.OpType.Conv:
+                raise Exception("TODO")
+            elif ty == backend.OpType.Matmul:
+                ctx.push_node(make_node("MatMul", inputs, outputs, name))
+            elif ty == backend.OpType.BatchNorm:
+                inputs = [inputs[i] for i in [0, 3, 4, 1, 2]]
+                momentum, eps, training = backend.batch_norm_attrs_of(op)
+                ctx.push_node(
+                    make_node(
+                        "BatchNormalization",
+                        inputs,
+                        outputs,
+                        name,
+                        epsilon=eps,
+                        momentum=momentum,
+                        training_mode=training,
+                    )
+                )
+            elif ty == backend.OpType.MaxPool:
+                raise Exception("TODO")
+            elif ty == backend.OpType.AvgPool:
+                raise Exception("TODO")
+            elif ty in [
+                backend.OpType.Add,
+                backend.OpType.Sub,
+                backend.OpType.Mul,
+                backend.OpType.Div,
+                backend.OpType.Pow,
+                backend.OpType.Relu,
+                backend.OpType.Sigmoid,
+                backend.OpType.Tanh,
+                backend.OpType.Softmax,
+                backend.OpType.Abs,
+                backend.OpType.Identity,
+            ]:
+                ctx.push_node(make_node(ty.name, inputs, outputs, name))
+            elif ty == backend.OpType.Flatten:
+                raise Exception("TODO")
+            elif ty == backend.OpType.Reshape:
+                shape = backend.reshape_shape_of(op)
+                inputs.append(
+                    ctx.push_data_input(
+                        name,
+                        "shape",
+                        TensorProto.INT32,
+                        [len(shape)],
+                        shape,
+                    )
+                )
+                ctx.push_node(make_node(ty.name, inputs, outputs, name))
+            elif ty == backend.OpType.Concat:
+                axis = backend.concat_axis_of(op)
+                ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
+            elif ty == backend.OpType.Gather:
+                axis = backend.gather_axis_of(op)
+                ctx.push_node(make_node(ty.name, inputs, outputs, name, axis=axis))
+            elif ty == backend.OpType.ReduceMean:
+                axes = backend.reduce_mean_axes_of(op)
+                inputs.append(
+                    ctx.push_data_input(
+                        name, "axes", TensorProto.INT32, [len(axes)], axes
+                    )
+                )
+                ctx.push_node(make_node(ty.name, inputs, outputs, name, keepdims=1))
+            elif ty == backend.OpType.Slice:
+                raise Exception("TODO")
+            elif ty == backend.OpType.Pad:
+                raise Exception("TODO")
+            else:
+                raise Exception("Unsupported OpType {}".format(ty.name))
+
+        return ctx.build(name)
+
+
+def from_onnx(model: ModelProto, runtime):
+    stub = OnnxStub(model, runtime)
+    return stub.inputs, stub.outputs, stub.handler
 
 
 def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py
index a529fb67..150a96e7 100644
--- a/pyinfinitensor/tests/test_onnx.py
+++ b/pyinfinitensor/tests/test_onnx.py
@@ -8,7 +8,7 @@ from onnx.helper import (
     make_tensor_value_info,
 )
 from onnx.checker import check_model
-from pyinfinitensor.onnx import from_onnx, backend, to_onnx, cpu_runtime
+from pyinfinitensor.onnx import from_onnx, backend, cpu_runtime
 
 
 def make_and_import_model(graph: onnx.GraphProto):
@@ -305,8 +305,6 @@ class TestStringMethods(unittest.TestCase):
         y = handler.tensor([3, 2, 1], 12)
         handler.reshape(x, y, [3, 2, 1])
 
-        to_onnx(handler, "test_frontend")
-
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc
index 53e1376c..87599f28 100644
--- a/src/ffi/ffi_infinitensor.cc
+++ b/src/ffi/ffi_infinitensor.cc
@@ -1,4 +1,5 @@
 #include "core/graph_handler.h"
+#include "operators/batch_norm.h"
 #include "operators/concat.h"
 #include "operators/gather.h"
 #include "operators/reduce_mean.h"
@@ -120,6 +121,13 @@ static Shape reshape_shape_of(Operator op) {
     return dynamic_cast<const ReshapeObj *>(op.get())->getShape();
 }
 
+static std::tuple<float, float, bool> batch_norm_attrs_of(Operator op) {
+    IT_ASSERT(op->getOpType() == OpType::BatchNorm);
+    auto batchnorm = dynamic_cast<const BatchNormObj *>(op.get());
+    return std::make_tuple(batchnorm->getMomentum(), batchnorm->getEps(),
+                           batchnorm->getTraining());
+}
+
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
     m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
@@ -130,7 +138,8 @@ void export_functions(py::module &m) {
         .FUNCTION(reshape_shape_of)
         .FUNCTION(concat_axis_of)
        .FUNCTION(gather_axis_of)
-        .FUNCTION(reduce_mean_axes_of);
+        .FUNCTION(reduce_mean_axes_of)
+        .FUNCTION(batch_norm_attrs_of);
 #undef FUNCTION
 }
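
Usage note (not part of the diff above): a minimal sketch of how the refactored interface is meant to be driven, assuming an ONNX model file on disk. The file name, the onnx.load call and the print statement are illustrative only; OnnxStub, from_onnx and cpu_runtime come from the diff itself, and the module-level from_onnx is kept as a thin compatibility wrapper returning (inputs, outputs, handler).

import onnx
from pyinfinitensor.onnx import OnnxStub, from_onnx, cpu_runtime

# Hypothetical model file; any ONNX model whose shapes can be inferred should work.
model = onnx.load("model.onnx")

# New style: keep the stub around to reach its inputs, outputs, handler and to_onnx(name).
stub = OnnxStub(model, cpu_runtime())
print(list(stub.inputs), list(stub.outputs))

# Old style still works through the compatibility wrapper.
inputs, outputs, handler = from_onnx(model, cpu_runtime())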