Compare commits

...

3 Commits

Author SHA1 Message Date
Liyan Zheng 1b1fc2585b Add: save optime result 2022-11-02 17:38:08 +08:00
Liyan Zheng eb993f7829 Add: evaluate onnx script 2022-11-02 16:51:33 +08:00
Liyan Zheng 63e5df4227 Add: fused conv 2022-11-02 16:39:12 +08:00
10 changed files with 833 additions and 47 deletions

View File

@ -1,10 +1,16 @@
#pragma once
#include "core/operator.h"
namespace infini {
namespace opTimer {
double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
int padh, int padw, int strideh, int stridew,
int dilationh, int dilationw, int group,
const char *name);
int dilationh, int dilationw, int group);
double getPerfConvBiasActCudnn(int n, int c, int h, int w, int f, int r, int s,
int padh, int padw, int strideh, int stridew,
int dilationh, int dilationw, int group,
bool bias, string act);
double getPerfConvTransposed2dCudnn(int n, int c, int h, int w, int f, int r,
int s, int padh, int padw, int strideh,

View File

@ -25,19 +25,22 @@ class ConvBaseObj : public OperatorObj {
int h, w; // input shape (same for conv2d and convTranposed2d)
int f; // output/input channel for conv2d/convTransposed2d
int r, s; // weight shape
ActType act;
public:
// Constructors for explicitly setting padding size
ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, int ph, int pw,
int sh, int sw, int dh, int dw, const Tensor &inputInConvFWD,
const Tensor &weightInConvFWD);
const Tensor &weightInConvFWD, const ActType act);
ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
PaddingMode mode, int sh, int sw, int dh, int dw,
const Tensor &inputInConvFWD, const Tensor &weightInConvFWD);
const Tensor &inputInConvFWD, const Tensor &weightInConvFWD,
const ActType act);
std::string toString() const override;
int numInputs() const override { return 2; }
int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return 1; }
bool hasBias() const { return inputs.size() == 3; }
Tensor getBias() const { return inputs[2]; }
PaddingMode getPaddingMode() const { return padding; }
@ -53,6 +56,7 @@ class ConvBaseObj : public OperatorObj {
auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
int getChannelPerGroup() const { return inputs[1]->getDims()[1]; }
virtual int getNumGroups() const = 0;
ActType getAct() const { return act; }
private:
vector<int> getWorkloadVector() const override;
@ -65,8 +69,6 @@ class ConvBaseObj : public OperatorObj {
};
class ConvObj : public ConvBaseObj {
private:
ActType act;
public:
ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,
@ -79,7 +81,6 @@ class ConvObj : public ConvBaseObj {
ActType act = ActType::None);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }
int getNumGroups() const override { return c / getChannelPerGroup(); }
private:
@ -90,7 +91,6 @@ class ConvTransposed2dObj : public ConvBaseObj {
private:
int oph, opw;
int group;
ActType act;
public:
ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight,
@ -106,7 +106,6 @@ class ConvTransposed2dObj : public ConvBaseObj {
Tensor bias = nullptr, ActType act = ActType::None);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }
int getNumGroups() const override { return group; }
private:

View File

@ -0,0 +1,478 @@
import functools
import numpy as np
import onnx
import onnx.checker
import onnx.numpy_helper
import onnx.shape_inference
from rules import conv_transposed2d_rules, conv_rules, print_result
def _add_value_info_for_constants(model : onnx.ModelProto):
"""
Currently onnx.shape_inference doesn't use the shape of initializers, so add
that info explicitly as ValueInfoProtos.
Mutates the model.
Args:
model: The ModelProto to update.
"""
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
if model.ir_version < 4:
return
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
inputs = {i.name for i in graph.input}
existing_info = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
# Check it really is a constant, not an input
if init.name in inputs:
continue
# The details we want to add
elem_type = init.data_type
shape = init.dims
# Get existing or create new value info for this constant
vi = existing_info.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
# Even though it would be weird, we will not overwrite info even if it doesn't match
tt = vi.type.tensor_type
if tt.elem_type == onnx.TensorProto.UNDEFINED:
tt.elem_type = elem_type
if not tt.HasField("shape"):
# Ensure we set an empty list if the const is scalar (zero dims)
tt.shape.dim.extend([])
for dim in shape:
tt.shape.dim.add().dim_value = dim
# Handle subgraphs
for node in graph.node:
for attr in node.attribute:
# Ref attrs refer to other attrs, so we don't need to do anything
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_const_value_infos_to_graph(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_const_value_infos_to_graph(g)
return add_const_value_infos_to_graph(model.graph)
def _parse_attribute(attributes, defaults=dict()):
atts = defaults
for att in attributes:
if att.type == onnx.AttributeProto.INT:
atts[att.name] = att.i
elif att.type == onnx.AttributeProto.INTS:
atts[att.name] = att.ints
elif att.type == onnx.AttributeProto.FLOAT:
atts[att.name] = att.f
elif att.type == onnx.AttributeProto.STRING:
atts[att.name] = att.s
elif att.type == onnx.AttributeProto.TENSOR:
atts[att.name] = att.t
else:
assert False, "Unsupported Attribute Type: {}".format(att.type)
return atts
def _onnx_datatype_tostring(dtype):
if dtype == 0:
return 'UNDEFINED'
elif dtype == 1:
return 'FLOAT'
elif dtype == 2:
return 'UINT8'
elif dtype == 3:
return 'INT8'
elif dtype == 4:
return 'UINT16'
elif dtype == 5:
return 'INT16'
elif dtype == 6:
return 'INT32'
elif dtype == 7:
return 'INT64'
elif dtype == 8:
return 'STRING'
elif dtype == 9:
return 'BOOL'
elif dtype == 10:
return 'FLOAT16'
elif dtype == 11:
return 'DOUBLE'
elif dtype == 12:
return 'UINT32'
elif dtype == 13:
return 'UINT64'
elif dtype == 14:
return 'COMPLEX64'
elif dtype == 15:
return 'COMPLEX128'
elif dtype == 16:
return 'BFLOAT16'
else:
assert False, 'Unknown onnx datatype'
def import_onnx(model_path: str, bs :int):
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
model = onnx.load(model_path)
# Tensor_input
for input in model.graph.input:
if input.name not in ts:
dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
# ts[input.name] = g.tensor(dims, _onnx_datatype_tostring(input.type.tensor_type.elem_type))
ds[input.name] = dims
# Tensor_weight
for weight in model.graph.initializer:
if weight.name not in ts:
# ts[weight.name] = g.tensor(weight.dims, _onnx_datatype_tostring(weight.data_type))
ds[weight.name] = weight.dims
# Tensor_inference
_add_value_info_for_constants(model)
infered_model = onnx.shape_inference.infer_shapes(model)
for v in infered_model.graph.value_info:
if v.name not in ts:
dims = [d.dim_value for d in v.type.tensor_type.shape.dim]
# ts[v.name] = g.tensor(dims, _onnx_datatype_tostring(v.type.tensor_type.elem_type))
ds[v.name] = dims
# Tensor_output
for output in model.graph.output:
if output.name not in ts:
dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
# ts[output.name] = g.tensor(dims, _onnx_datatype_tostring(output.type.tensor_type.elem_type))
ds[output.name] = dims
# Op
for node in model.graph.node:
# if node.op_type == 'Add':
# assert len(node.output) == 1
# g.add([ts[item] for item in node.input], ts[node.output[0]])
# elif node.op_type == 'Cast':
# assert len(node.input) == 1
# assert len(node.output) == 1
# # Ignore for now (TODO)
# g.identity(ts[node.input[0]], ts[node.output[0]])
if node.op_type == 'Conv':
attrs = _parse_attribute(node.attribute, {
"auto_pad": "NOTSET",
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1]})
assert len(node.input) == 2 or len(node.input) == 3
assert len(node.output) == 1
assert attrs["auto_pad"] == "NOTSET"
assert len(attrs["pads"]) == 4
assert len(attrs["strides"]) == 2
assert len(attrs["dilations"]) == 2
assert attrs["pads"][0] == attrs["pads"][2]
assert attrs["pads"][1] == attrs["pads"][3]
assert ds[node.input[0]][1] % ds[node.input[1]][1] == 0
n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw = ds[node.input[0]][0], ds[node.input[0]][1], ds[node.input[0]][2], ds[node.input[0]][3],ds[node.input[1]][0], ds[node.input[1]][2], ds[node.input[1]][3], attrs["pads"][0], attrs["pads"][1], attrs["strides"][0], attrs["strides"][1], attrs["dilations"][0], attrs["dilations"][1]
group = ds[node.input[0]][1] // ds[node.input[1]][1]
# t = getPerfConv(n, c, h, w, f, r, s, ph, pw,
# sh, sw, dh, dw, group, "")
# print(node.name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, group, f'{t:.3f}')
n=n*bs
for rule in conv_rules:
rule(node.name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, group)
elif node.op_type == 'ConvTranspose':
attrs = _parse_attribute(node.attribute, {
"auto_pad": "NOTSET",
"dilations": [1, 1],
"pads": [0, 0, 0, 0],
"strides": [1, 1],
"group": 1})
assert len(node.input) == 2 or len(node.input) == 3
assert len(node.output) == 1
assert attrs["auto_pad"] == "NOTSET"
assert len(attrs["pads"]) == 4
assert len(attrs["strides"]) == 2
assert len(attrs["dilations"]) == 2
assert attrs["pads"][0] == attrs["pads"][2]
assert attrs["pads"][1] == attrs["pads"][3]
n, f, h, w = ds[node.input[0]]
_, c, r, s = ds[node.input[1]]
ph, pw, sh, sw, dh, dw = attrs["pads"][0], attrs["pads"][1], attrs["strides"][0], attrs["strides"][1], attrs["dilations"][0], attrs["dilations"][1]
oph, opw = 0, 0
if "output_padding" in attrs:
oph, opw = attrs["output_padding"][0], attrs["output_padding"][1]
assert attrs["output_padding"][0] == attrs["output_padding"][1]
group=attrs["group"]
n=n*bs
for rule in conv_transposed2d_rules:
rule(node.name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group)
elif node.op_type == 'MatMul':
print(f'{node.name} skipped')
continue
assert len(node.input) == 2
assert len(node.output) == 1
dimA = list(ds[node.input[0]])
dimB = list(ds[node.input[1]])
dimO = list(ds[node.output[0]])
# if len(dimA) == 2 and len(dimB) == 2:
# tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
# tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
# tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
# g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
# g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)
# else:
# assert len(dimO) >= 3
# batch = functools.reduce(lambda x, y: x * y, dimO[:-2])
# if len(dimA) == 3:
# tmpI0 = ts[node.input[0]]
# else:
# tmpI0 = g.tensor([batch, dimA[-2], dimA[-1]], "FLOAT")
# g.reshape(ts[node.input[0]], tmpI0)
# if len(dimB) == 3:
# tmpI1 = ts[node.input[1]]
# else:
# tmpI1 = g.tensor([batch, dimB[-2], dimB[-1]], "FLOAT")
# g.reshape(ts[node.input[1]], tmpI1)
# if len(dimO) == 3:
# tmpO = ts[node.output[0]]
# g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
# else:
# tmpO = g.tensor([batch, dimO[-2], dimO[-1]], "FLOAT")
# g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
# g.reshape(tmpO, ts[node.output[0]])
# elif node.op_type == 'Concat':
# assert len(node.output) == 1
# attrs = _parse_attribute(node.attribute, {})
# g.concat([ts[item] for item in node.input], ts[node.output[0]], attrs["axis"])
# elif node.op_type == 'Constant':
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.output) == 1
# c = onnx.numpy_helper.to_array(attrs["value"])
# if c.ndim == 0:
# c = c[()]
# consts[node.output[0]] = c
# elif node.op_type == 'Flatten':
# attrs = _parse_attribute(node.attribute, {"axis": 1})
# assert len(node.input) == 1
# assert len(node.output) == 1
# g.flatten(ts[node.input[0]], ts[node.output[0]], attrs["axis"])
# elif node.op_type == 'Gather':
# attrs = _parse_attribute(node.attribute, {"axis": 0})
# assert len(node.input) == 2
# assert len(node.output) == 1
# g.gather(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]], attrs["axis"])
# elif node.op_type == 'Gemm':
# attrs = _parse_attribute(node.attribute, {
# "alpha": 1.0,
# "beta": 1.0,
# "transA": 0,
# "transB": 0})
# assert len(node.input) == 2 or len(node.input) == 3
# assert len(node.output) == 1
# assert attrs["alpha"] == 1.0
# assert attrs["beta"] == 1.0 or len(node.input) == 2
# tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
# tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
# tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
# g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
# g.matmul(tmpI0, tmpI1, tmpO,
# attrs["transA"], attrs["transB"],
# None if len(node.input) == 2 else ts[node.input[2]])
# g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)
# elif node.op_type == 'Mul':
# assert len(node.output) == 1
# g.mul([ts[x] for x in node.input], ts[node.output[0]])
# elif node.op_type == 'GlobalAveragePool':
# assert len(node.input) == 1
# assert len(node.output) == 1
# dims = ds[node.input[0]]
# if len(dims) > 0:
# g.avgpool(ts[node.input[0]], ts[node.output[0]], dims[2], dims[3], 0, 0, 1, 1)
# else:
# g.avgpool(ts[node.input[0]], ts[node.output[0]])
# elif node.op_type == 'MaxPool':
# attrs = _parse_attribute(node.attribute, {
# "auto_pad": "NOTSET",
# "dilations": [1, 1],
# "pads": [0, 0, 0, 0],
# "strides": [1, 1]})
# assert len(node.input) == 1
# assert len(node.output) == 1
# assert len(attrs["kernel_shape"]) == 2
# assert len(attrs["pads"]) == 4
# assert len(attrs["strides"]) == 2
# assert len(attrs["dilations"]) == 2
# assert attrs["pads"][0] == attrs["pads"][2]
# assert attrs["pads"][1] == attrs["pads"][3]
# g.maxpool(ts[node.input[0]], ts[node.output[0]],
# attrs["kernel_shape"][0], attrs["kernel_shape"][1],
# attrs["dilations"][0], attrs["dilations"][1],
# attrs["pads"][0], attrs["pads"][1],
# attrs["strides"][0], attrs["strides"][1])
# elif node.op_type == 'AveragePool':
# attrs = _parse_attribute(node.attribute, {
# "auto_pad": "NOTSET",
# "count_include_pad": 0,
# "pads": [0, 0, 0, 0],
# "strides": [1, 1]})
# # No dilation in ONNX
# assert len(node.input) == 1
# assert len(node.output) == 1
# assert attrs["count_include_pad"] == 0 # To be consistent with operator.cc
# assert len(attrs["kernel_shape"]) == 2
# assert len(attrs["pads"]) == 4
# assert len(attrs["strides"]) == 2
# assert attrs["pads"][0] == attrs["pads"][2]
# assert attrs["pads"][1] == attrs["pads"][3]
# g.avgpool(ts[node.input[0]], ts[node.output[0]],
# attrs["kernel_shape"][0], attrs["kernel_shape"][1],
# attrs["pads"][0], attrs["pads"][1],
# attrs["strides"][0], attrs["strides"][1])
# elif node.op_type == 'Pad':
# attrs = _parse_attribute(node.attribute, {'mode': b'constant'})
# assert attrs["mode"].decode("ascii") == "constant"
# assert len(attrs["pads"]) % 2 == 0
# assert attrs["value"] == 0
# nDim = len(attrs["pads"]) // 2
# begin = attrs["pads"][:nDim]
# end = attrs["pads"][nDim:]
# g.pad(ts[node.input[0]], ts[node.output[0]], begin, end)
# elif node.op_type == 'ReduceMean':
# attrs = _parse_attribute(node.attribute, {'keepdims': 1})
# assert len(node.input) == 1
# assert len(node.output) == 1
# assert len(attrs["axes"]) == 1
# axis = attrs["axes"][0]
# if axis < 0:
# axis = len(ds[node.input[0]]) - axis
# g.reduceMean(ts[node.input[0]], ts[node.output[0]], axis)
# elif node.op_type == 'Softmax':
# attrs = _parse_attribute(node.attribute)
# assert len(node.input) == 1
# assert len(node.output) == 1
# axis = attrs["axis"]
# if axis < 0:
# axis = len(ds[node.input[0]]) - axis
# g.softmax(ts[node.input[0]], ts[node.output[0]], axis)
# elif node.op_type == 'Reshape':
# assert len(node.input) == 2
# assert len(node.output) == 1
# g.reshape(ts[node.input[0]], ts[node.output[0]])
# elif node.op_type == 'Relu':
# assert len(node.input) == 1
# assert len(node.output) == 1
# g.relu(ts[node.input[0]], ts[node.output[0]])
# elif node.op_type == 'Tanh':
# assert len(node.input) == 1
# assert len(node.output) == 1
# g.tanh(ts[node.input[0]], ts[node.output[0]])
# elif node.op_type == 'Sigmoid':
# assert len(node.input) == 1
# assert len(node.output) == 1
# g.sigmoid(ts[node.input[0]], ts[node.output[0]])
# elif node.op_type == 'Shape':
# # Ignore for now, and no need to output anything (TODO)
# pass
# elif node.op_type == 'Sub':
# assert len(node.input) == 2
# assert len(node.output) == 1
# g.sub(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])
# elif node.op_type == 'Transpose':
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 1
# assert len(node.output) == 1
# assert "perm" in attrs
# g.transpose(ts[node.input[0]], ts[node.output[0]], -1,
# Perm([PermItem(x) for x in attrs["perm"]]), 0)
# elif node.op_type == 'Unsqueeze':
# assert len(node.input) == 2
# assert len(node.output) == 1
# g.reshape(ts[node.input[0]], ts[node.output[0]])
# elif node.op_type == "BatchNormalization":
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 5
# assert len(node.output) == 1
# epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-5
# momentum = attrs['momentum'] if 'momentum' in attrs else 0.9
# g.batchnorm(ts[node.input[0]], ts[node.input[1]],
# ts[node.input[2]], ts[node.input[3]],
# ts[node.input[4]], ts[node.output[0]],
# epsilon, momentum)
# elif node.op_type == "Split":
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 1
# assert len(node.output) > 1
# axis = attrs['axis']
# split = attrs['split']
# g.split(ts[node.input[0]], [ts[t] for t in node.output], axis, split)
# elif node.op_type == "Slice":
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 4
# assert len(node.output) == 1
# g.slice(ts[node.input[0]], ts[node.output[0]],
# ts[node.input[1]], ts[node.input[2]])
# elif node.op_type == "Resize":
# attrs = _parse_attribute(node.attribute, {})
# assert len(node.input) == 4
# assert len(node.output) == 1
# roi = ts[node.input[1]] if node.input[1] != '' else g.tensor(
# [], 'FLOAT')
# g.resize(ts[node.input[0]], roi, ts[node.output[0]])
# else:
# assert False, "Unsupported op: " + node.op_type
if __name__ == "__main__":
import sys
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("model", help="ONNX model file")
parser.add_argument("bs", help="batch size", type=int, default=1)
# parser.add_argument("--output", help="Output file")
args = parser.parse_args()
import_onnx(args.model, args.bs)
print_result(args.model)

View File

@ -1,10 +1,11 @@
import argparse
from tokenize import Double
import pyinfinitensor # import getPerfConv, getPerfMatmul
def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""):
def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group):
return pyinfinitensor.getPerfConvCudnn(n, c, h, w, f, r, s, padh, padw,
strideh, stridew, dilationh, dilationw, group, name)
strideh, stridew, dilationh, dilationw, group)
def getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group):
@ -13,3 +14,43 @@ def getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, strid
def getPerfMatmul(b, m, n, k, name=""):
return pyinfinitensor.getPerfMatmulCublas(b, m, n, k, name)
def getPerfConvBiasActCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, bias: bool, act="None"):
return pyinfinitensor.getPerfConvBiasActCudnn(n, c, h, w, f, r, s, padh, padw,
strideh, stridew, dilationh, dilationw, group, bias, act)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('op', metavar='operator', type=str)
parser.add_argument('shape', nargs='+')
parser.add_argument('--pad', type=int, default=0)
parser.add_argument('--stride', type=int, default=1)
parser.add_argument('--dilation', type=int, default=1)
parser.add_argument('--group', type=int, default=1)
parser.add_argument('--bias', type=bool, default=False)
parser.add_argument('--act', type=str, default="None")
args = parser.parse_args()
print(args)
if args.op == 'gemm':
t = getPerfMatmul(int(args.shape[0]), int(
args.shape[1]), int(args.shape[2]), int(args.shape[3]))
print(
f'time {t:.3f} ms, {2*int(args.shape[0])*int(args.shape[1])*int(args.shape[2])*int(args.shape[3])/t/1e9:.3f} TFLOPS')
elif args.op == 'conv':
assert len(args.shape) == 7
n, c, h, w, f, r, s = [int(v) for v in args.shape]
padh = padw = int(args.pad)
strideh = stridew = int(args.stride)
dilationh = dilationw = int(args.dilation)
group = int(args.group)
bias = int(args.bias)
act = args.act
assert group==1, "Unsupported"
t = pyinfinitensor.getPerfConvBiasActCudnn(
n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, bias, act)
print(
f'time {t:.3f} ms, {n*c*h*w*f*r*s/strideh/stridew*2/10**9:.3f} TFlops')
else:
assert False, "Not supported"

View File

@ -0,0 +1,84 @@
import pandas as pd
import numpy as np
from operator_timer import *
from datetime import datetime
pd.options.display.float_format = '{:,.3f}'.format
df= pd.DataFrame(columns=['n', 'c', 'h', 'w', 'f', 'r', 's', 'ph', 'pw', 'sh', 'sw', 'dh', 'dw', 'oph', 'opw', 'group'])
def conv_original(name, n, c, h, w, f, r, s, ph, pw,
sh, sw, dh, dw, group):
df.loc[name, ['n', 'c', 'h', 'w', 'f', 'r', 's', 'ph', 'pw', 'sh', 'sw', 'dh', 'dw', 'group']] = n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, group
df.loc[name, 't_original'] = getPerfConv(n, c, h, w, f, r, s, ph, pw,
sh, sw, dh, dw, group)
df.loc[name, 't_bias'] = getPerfConvBiasActCudnn(n, c, h, w, f, r, s, ph, pw,
sh, sw, dh, dw, group, bias=True)
df.loc[name, 't_bias_relu'] = getPerfConvBiasActCudnn(n, c, h, w, f, r, s, ph, pw,
sh, sw, dh, dw, group, bias=True, act="Relu")
def conv_rule_5x5_to_3x3(name, n, c, h, w, f, r, s, ph, pw,
sh, sw, dh, dw, group):
col = 't_5x5_to_3x3'
if r == 5 and s == 5:
df.loc[name, col] = getPerfConv(n, c, h, w, f*4, 3, 3, ph, pw,
sh, sw, dh, dw, group)
else:
df.loc[name, col] = np.inf
def conv_rule_9x9_to_3x3(name, n, c, h, w, f, r, s, ph, pw,
sh, sw, dh, dw, group):
col = 't_9x9_to_3x3'
if r == 9 and s == 9:
df.loc[name, col] = getPerfConv(n, c, h, w, f*9, r//3, s//3, ph, pw,
sh, sw, dh, dw, group)
else:
df.loc[name, col] = np.inf
bandwidth=200*10**6 # (200GB/ms)
def conv_rule_conv2gemm(name, n, c, h, w, f, r, s, ph, pw,
sh, sw, dh, dw, group):
col = 't_conv2gemm'
if [sh, sw, dh, dw, group] == [1] * 5:
# b = group
# m = batch_size * input_height * input_width
# n = output_channel * kernel_height * kernel_width
# k = input_channel // group
t_reduce= group*n*h*w*f*r*s*4/bandwidth if r>1 or s>1 else 0
df.loc[name, '_'+col+'_mem'] = t_reduce
df.loc[name, col] = getPerfMatmul(group, n*h*w, f*r*s, c//group) + t_reduce
else:
df.loc[name, col] = np.inf
# conv_rules=[conv_original, conv_rule_9x9_to_3x3, conv_rule_5x5_to_3x3, conv_rule_conv2gemm]
conv_rules=[conv_original]
def conv_tranpsposed2d_original(name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group):
df.loc[name, ['n', 'c', 'h', 'w', 'f', 'r', 's', 'ph', 'pw', 'sh', 'sw', 'dh', 'dw', 'oph', 'opw', 'group']] = n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group
df.loc[name, 't_original'] = getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group)
def conv_tranpsposed2d_togemm(name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group):
col = 't_conv2gemm'
if [dh, dw, group] == [1] * 3:
# ConvTransose2gemm
# b = 1
# m = batch_size * input_height*input_width
# n = output_channel*kernel_height*kernel_width
# k = input_channel
t_reduce= n*h*w*c*r*s*4/bandwidth if r>1 or s>1 else 0
df.loc[name, '_'+col+'_mem'] = t_reduce
print('t_conv2gemm', group, n*h*w, c*r*s, f)
df.loc[name, col] = getPerfMatmul(group, n*h*w, c*r*s, f) + t_reduce
else:
df.loc[name, col] = np.inf
conv_transposed2d_rules=[conv_tranpsposed2d_original, conv_tranpsposed2d_togemm]
def print_result(model_fn):
pd.set_option('display.max_rows', 500)
df['t_min'] = df.filter(regex=("^t_.*")).min(axis=1)
print(df)
print(f'Origin: {df["t_original"].sum():.3f} ms')
print(f'Min: {df["t_min"].sum():.3f} ms')
print(f'Speedup: {df["t_original"].sum()/df["t_min"].sum():.3f} x')
df.to_pickle(f'optime_{model_fn.split("/")[-1]}_{datetime.now().strftime("%m_%d_%H_%M_%S")}.pkl')

View File

@ -1,3 +1,4 @@
#include "cuda/operator_timer.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
@ -12,8 +13,23 @@ namespace opTimer {
double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
int padh, int padw, int strideh, int stridew,
int dilationh, int dilationw, int group,
const char *name) {
int dilationh, int dilationw, int group) {
return getPerfConvBiasActCudnn(n, c, h, w, f, r, s, padh, padw, strideh,
stridew, dilationh, dilationw, group, false,
"None");
}
double getPerfConvBiasActCudnn(int n, int c, int h, int w, int f, int r, int s,
int padh, int padw, int strideh, int stridew,
int dilationh, int dilationw, int group,
bool bias, string actName) {
ActType act = ActType::None;
if (actName == "None")
act = ActType::None;
else if (actName == "Relu")
act = ActType::Relu;
else
IT_ASSERT(false, "Unsupported activation");
// const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
// dilationh, dilationw, group] =
// tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
@ -25,17 +41,27 @@ double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
IT_ASSERT(c % group == 0);
Tensor i0Cpu = gCpu->addTensor({n, c, h, w}, DataType::Float32);
Tensor w0Cpu = gCpu->addTensor({f, c / group, r, s}, DataType::Float32);
Tensor b0Cpu = gCpu->addTensor({f}, DataType::Float32);
// Malloc data for all tensors in a graph. Do we need implicit allocation?
gCpu->dataMalloc();
i0Cpu->setData(IncrementalGenerator());
w0Cpu->setData(IncrementalGenerator());
b0Cpu->setData(IncrementalGenerator());
// Copy input tensors from CPU to CUDA
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
Tensor b0Cuda = gCuda->cloneTensor(b0Cpu);
// Build CUDA graph
auto conv = gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, padh, padw,
strideh, stridew, dilationh, dilationw);
if (!bias) {
auto conv =
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, padh, padw, strideh,
stridew, dilationh, dilationw);
} else {
auto conv =
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, padh, padw, strideh,
stridew, dilationh, dilationw, b0Cuda, act);
}
// allocate CUDA memory
gCuda->dataMalloc();
// Execute on CUDA

View File

@ -13,8 +13,10 @@ void register_operator_timer(py::module &m) {
#ifdef USE_CUDA
using namespace opTimer;
m.def("getPerfConvCudnn", &getPerfConvCudnn);
m.def("getPerfConvBiasActCudnn", &getPerfConvBiasActCudnn);
m.def("getPerfConvTransposed2dCudnn", &getPerfConvTransposed2dCudnn);
m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
#endif
}

View File

@ -8,6 +8,8 @@
namespace infini {
struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
int kernel =
0; // 0 cudnnConvolutionForward, 1 cudnnConvolutionBiasActivationForward
int algo = 0; // cudnnConvolutionFwdAlgo_t
int mode = 1;
size_t workspaceSize = 100000;
@ -56,8 +58,6 @@ class convCudnn : public Kernel {
const ConvCuDnnPerfRecord &record) const {
void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
if (op->getInputs().size() > 2) // Bias is not supported yet
IT_TODO_HALT();
// void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
void *const outData = (op->getOutput()->getRawDataPtr<void *>());
@ -209,6 +209,36 @@ class convCudnn : public Kernel {
return true;
}
bool cuDNNfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
const CudaRuntimeObj *context) const {
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, record);
size_t wsSize = record->workspaceSize;
CudaPtr wsData = context->getWorkspace(wsSize);
float alpha = 1.f, beta = 0.f;
// w/ bias & act
stat = cudnnConvolutionBiasActivationForward(
context->cudnnHandle(), &alpha, inDesc, inData, knDesc, knData,
convDesc, ALGOS[record->algo], wsData, wsSize, &beta, outDesc,
outData, biasDesc, nullptr, actDesc, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
return false;
// Destories in CUDA does not require sync. But cuDNN does not state
// whether sync is required before destories.
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
return true;
}
void compute(const Operator &op, const RuntimeObj *context) const override {
auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with paramters in
// default ctor
@ -217,10 +247,88 @@ class convCudnn : public Kernel {
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
auto op = as<ConvObj>(_op);
printf("%s\n", op->toString().c_str());
if (op->hasBias() || op->getAct() != ActType::None)
return tuneFused(op, context);
else
return tuneUnfused(op, context);
}
PerfRecord tuneFused(const Ref<ConvObj> &op,
const CudaRuntimeObj *context) const {
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
// Both modes have the same performance. Only run cross-correlation.
for (int mode = 1; mode < 2; mode++) {
// Try every possible algorithm of convolution
for (int algo = 0; algo < N_ALGO; algo++) {
auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
auto &record = *recordRef;
record.mode = mode;
record.algo = algo;
cudnnStatus_t stat;
const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
convDesc, actDesc, outDesc] =
createCuDNNDescriptor(op, recordRef);
void *biasData = op->getBias()->getRawDataPtr<void *>();
// get workspace
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &record.workspaceSize);
if (stat != CUDNN_STATUS_SUCCESS)
continue;
if (record.workspaceSize > context->getWorkspaceSize())
continue;
CudaPtr wsData = context->getWorkspace(record.workspaceSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionBiasActivationForward(
context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
knData, convDesc, ALGOS[record.algo], wsData,
record.workspaceSize, &beta, outDesc, outData, biasDesc,
biasData, actDesc, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
continue;
record.time = timeit(
[&]() {
stat = cudnnConvolutionBiasActivationForward(
context->cudnnHandle(), &alpha, inDesc, inData,
knDesc, knData, convDesc, ALGOS[record.algo],
wsData, record.workspaceSize, &beta, outDesc,
outData, biasDesc, biasData, actDesc, outDesc,
outData);
},
[&]() { context->sync(); });
printf("mode %d, algo %d, time %.3lf ms\n", mode, algo,
record.time);
// Update the tune result
if (ret.time > record.time)
ret = record;
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
}
}
printf("tuneFused: the best algo is %d, the best conv mode is %d\n",
ret.algo, ret.mode);
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");
return make_ref<ConvCuDnnPerfRecordObj>(ret);
}
PerfRecord tuneUnfused(const Ref<ConvObj> &op,
const CudaRuntimeObj *context) const {
ConvCuDnnPerfRecordObj ret;
ret.time = std::numeric_limits<double>::max();
// Both modes have the same performance. Only run cross-correlation.
for (int mode = 1; mode < 2; mode++) {
// Try every possible algorithm of convolution
@ -260,7 +368,8 @@ class convCudnn : public Kernel {
&beta, outDesc, outData);
},
[&]() { context->sync(); });
// printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
printf("mode %d, algo %d, time %.3lf ms\n", mode, algo,
record.time);
// Update the tune result
if (ret.time > record.time)
@ -273,8 +382,8 @@ class convCudnn : public Kernel {
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
}
}
// printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
// ret.mode);
printf("tuneUnfused: the best algo is %d, the best conv mode is %d\n",
ret.algo, ret.mode);
IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
"algorithm "
"found");

View File

@ -5,15 +5,15 @@ namespace infini {
ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
int ph, int pw, int sh, int sw, int dh, int dw,
const Tensor &inputInConvFWD,
const Tensor &weightInConvFWD)
const Tensor &weightInConvFWD, const ActType act)
: OperatorObj(opType, inputs, {output}), ph(ph), pw(pw), sh(sh), sw(sw),
dh(dh), dw(dw), padding(PaddingMode::Other) {}
dh(dh), dw(dw), padding(PaddingMode::Other), act(act) {}
ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
PaddingMode mode, int sh, int sw, int dh, int dw,
const Tensor &inputInConvFWD,
const Tensor &weightInConvFWD)
const Tensor &weightInConvFWD, const ActType act)
: OperatorObj(opType, inputs, {output}), ph(-1), pw(-1), sh(sh), sw(sw),
dh(dh), dw(dw), padding(mode) {
dh(dh), dw(dw), padding(mode), act(act) {
IT_ASSERT(mode != PaddingMode::Other);
}
@ -21,28 +21,60 @@ string ConvBaseObj::toString() const {
std::ostringstream os;
os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
os << "(";
if (inputs.size() == 2) {
os << vecToString(inputs[0]->getDims()) << ",";
os << vecToString(inputs[1]->getDims()) << ",";
os << vecToString(inputs[0]->getDims()) << ",";
os << vecToString(inputs[1]->getDims()) << ",";
if (inputs.size() > 2) {
os << vecToString(inputs[2]->getDims()) << ",";
}
os << "p=[" << ph << "," << pw << "],";
os << "s=[" << sh << "," << sw << "],";
os << "d=[" << dh << "," << dw << "],";
os << "act=" << enum_to_underlying(getAct()) << ",";
// os << "act=" << enum_to_underlying(act) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "weight=" << inputs[1]->getGuid() << ",";
os << "bias="
<< ((inputs.size() == 2) ? "nullptr"
: std::to_string(inputs[2]->getGuid()))
<< ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> ConvBaseObj::getWorkloadVector() const {
return {
enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
return {enum_to_underlying(type),
n,
c,
h,
w,
f,
r,
s,
ph,
pw,
sh,
sw,
dh,
dw,
hasBias(),
enum_to_underlying(getAct())};
}
vector<int> ConvBaseObj::getOpAttrVector() const {
IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
return {enum_to_underlying(type),
c,
f,
r,
s,
ph,
pw,
sh,
sw,
dh,
dw,
hasBias(),
enum_to_underlying(getAct())};
}
void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
@ -64,11 +96,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
: ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
input, weight),
act(act) {
if (bias)
IT_TODO_HALT();
: ConvBaseObj(OpType::Conv,
((bias) ? (TensorVec{input, weight, bias})
: (TensorVec{input, weight})),
output, ph, pw, sh, sw, dh, dw, input, weight, act) {
setAuxilaryAttributes(PaddingMode::Other);
IT_ASSERT(checkValid(graph));
}
@ -76,11 +107,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
ActType act)
: ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
input, weight),
act(act) {
if (bias)
IT_TODO_HALT();
: ConvBaseObj(OpType::Conv,
((bias) ? (TensorVec{input, weight, bias})
: (TensorVec{input, weight})),
output, mode, sh, sw, dh, dw, input, weight, act) {
setAuxilaryAttributes(mode);
IT_ASSERT(checkValid(graph));
}
@ -98,6 +128,11 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
// For NCHW+FCRS layout, C of input is divisable by C of weight
if (input->getDims()[1] % weight->getDims()[1] != 0)
return {};
// check bias shape
if (inputs.size() == 3) {
if (inputs[2]->size() != (size_t)f)
return {};
}
// Set padding size
if (padding == PaddingMode::Other) {
oh = (h - (r - sh) * dh + ph * 2) / sh;
@ -122,8 +157,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
int oph, int opw, int group,
Tensor bias, ActType act)
: ConvBaseObj(OpType::ConvTrans, {input, weight}, output, ph, pw, sh, sw,
dh, dw, output, weight),
oph(oph), opw(opw), group(group), act(act) {
dh, dw, output, weight, act),
oph(oph), opw(opw), group(group) {
if (bias)
IT_TODO_HALT();
setAuxilaryAttributes(PaddingMode::Other);
@ -136,8 +171,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
int dh, int dw, int oph, int opw,
int group, Tensor bias, ActType act)
: ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
dw, output, weight),
oph(oph), opw(opw), group(group), act(act) {
dw, output, weight, act),
oph(oph), opw(opw), group(group) {
if (bias)
IT_TODO_HALT();
setAuxilaryAttributes(mode);
@ -156,6 +191,8 @@ ConvTransposed2dObj::inferShape(const TensorVec &inputs) const {
auto s = weight->getDims()[3];
if (f != weight->getDims()[0])
return {};
if (inputs.size() != 2)
IT_TODO_HALT();
int on = n, oc = c * group;
int oh = 0, ow = 0;

View File

@ -0,0 +1,4 @@
. /home/spack/spack/share/spack/setup-env.sh
spack load /yb2wg5g # cuda@10.2.89
spack load /3bfwma4 # cudnn@7.6.5.32-10.2
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-haswell/gcc-7.5.0/gcc-7.5.0-sti65cu3zunc4p4kfylgweim6mqan3mk/bin/gcc