forked from jiuyuan/InfiniTensor

Compare commits (3 commits): 1b1fc2585b, eb993f7829, 63e5df4227
@@ -1,10 +1,16 @@
#pragma once
#include "core/operator.h"
namespace infini {
namespace opTimer {

double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
                        int padh, int padw, int strideh, int stridew,
-                       int dilationh, int dilationw, int group,
-                       const char *name);
+                       int dilationh, int dilationw, int group);
+
+double getPerfConvBiasActCudnn(int n, int c, int h, int w, int f, int r, int s,
+                               int padh, int padw, int strideh, int stridew,
+                               int dilationh, int dilationw, int group,
+                               bool bias, string act);

double getPerfConvTransposed2dCudnn(int n, int c, int h, int w, int f, int r,
                                    int s, int padh, int padw, int strideh,
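These declarations are exposed to Python through the pybind11 registration later in this change. A minimal sketch of the intended use, assuming the extension builds as the pyinfinitensor module imported by operator_timer.py below (the shape is illustrative):

    import pyinfinitensor

    # plain conv: n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group
    t_plain = pyinfinitensor.getPerfConvCudnn(1, 512, 14, 14, 512, 3, 3,
                                              1, 1, 1, 1, 1, 1, 1)
    # fused conv + bias + ReLU variant added by this change
    t_fused = pyinfinitensor.getPerfConvBiasActCudnn(1, 512, 14, 14, 512, 3, 3,
                                                     1, 1, 1, 1, 1, 1, 1,
                                                     True, "Relu")
    print(f"plain {t_plain:.3f} ms, fused {t_fused:.3f} ms")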
@@ -25,19 +25,22 @@ class ConvBaseObj : public OperatorObj {
    int h, w; // input shape (same for conv2d and convTransposed2d)
    int f;    // output/input channel for conv2d/convTransposed2d
    int r, s; // weight shape
+   ActType act;

  public:
    // Constructors for explicitly setting padding size
    ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output, int ph, int pw,
                int sh, int sw, int dh, int dw, const Tensor &inputInConvFWD,
-               const Tensor &weightInConvFWD);
+               const Tensor &weightInConvFWD, const ActType act);
    ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
                PaddingMode mode, int sh, int sw, int dh, int dw,
-               const Tensor &inputInConvFWD, const Tensor &weightInConvFWD);
+               const Tensor &inputInConvFWD, const Tensor &weightInConvFWD,
+               const ActType act);

    std::string toString() const override;
-   int numInputs() const override { return 2; }
+   int numInputs() const override { return inputs.size(); }
    int numOutputs() const override { return 1; }
+   bool hasBias() const { return inputs.size() == 3; }
+
+   Tensor getBias() const { return inputs[2]; }
    PaddingMode getPaddingMode() const { return padding; }

@@ -53,6 +56,7 @@ class ConvBaseObj : public OperatorObj {
    auto getPadStrideDilation() const { return tuple(ph, pw, sh, sw, dh, dw); }
    int getChannelPerGroup() const { return inputs[1]->getDims()[1]; }
    virtual int getNumGroups() const = 0;
+   ActType getAct() const { return act; }

  private:
    vector<int> getWorkloadVector() const override;

@@ -65,8 +69,6 @@ class ConvBaseObj : public OperatorObj {
};

class ConvObj : public ConvBaseObj {
-  private:
-   ActType act;

  public:
    ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output, int ph,

@@ -79,7 +81,6 @@ class ConvObj : public ConvBaseObj {
            ActType act = ActType::None);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
-   ActType getAct() const { return act; }
    int getNumGroups() const override { return c / getChannelPerGroup(); }

  private:

@@ -90,7 +91,6 @@ class ConvTransposed2dObj : public ConvBaseObj {
  private:
    int oph, opw;
    int group;
-   ActType act;

  public:
    ConvTransposed2dObj(GraphObj *graph, Tensor input, Tensor weight,

@@ -106,7 +106,6 @@ class ConvTransposed2dObj : public ConvBaseObj {
                        Tensor bias = nullptr, ActType act = ActType::None);

    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
-   ActType getAct() const { return act; }
    int getNumGroups() const override { return group; }

  private:
@@ -0,0 +1,478 @@
import functools

import numpy as np

import onnx
import onnx.checker
import onnx.numpy_helper
import onnx.shape_inference
from rules import conv_transposed2d_rules, conv_rules, print_result


def _add_value_info_for_constants(model: onnx.ModelProto):
    """
    Currently onnx.shape_inference doesn't use the shape of initializers, so add
    that info explicitly as ValueInfoProtos.
    Mutates the model.
    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph: onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        existing_info = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            # Check it really is a constant, not an input
            if init.name in inputs:
                continue

            # The details we want to add
            elem_type = init.data_type
            shape = init.dims

            # Get existing or create new value info for this constant
            vi = existing_info.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name

            # Even though it would be weird, we will not overwrite info even if it doesn't match
            tt = vi.type.tensor_type
            if tt.elem_type == onnx.TensorProto.UNDEFINED:
                tt.elem_type = elem_type
            if not tt.HasField("shape"):
                # Ensure we set an empty list if the const is scalar (zero dims)
                tt.shape.dim.extend([])
                for dim in shape:
                    tt.shape.dim.add().dim_value = dim

        # Handle subgraphs
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so we don't need to do anything
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_const_value_infos_to_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_const_value_infos_to_graph(g)

    return add_const_value_infos_to_graph(model.graph)

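# Note: this helper mutates `model` in place, while onnx.shape_inference.infer_shapes
# returns a new ModelProto; import_onnx() below therefore calls this first and then
# reads shapes from the inferred copy's graph.value_info.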
def _parse_attribute(attributes, defaults=dict()):
    atts = defaults
    for att in attributes:
        if att.type == onnx.AttributeProto.INT:
            atts[att.name] = att.i
        elif att.type == onnx.AttributeProto.INTS:
            atts[att.name] = att.ints
        elif att.type == onnx.AttributeProto.FLOAT:
            atts[att.name] = att.f
        elif att.type == onnx.AttributeProto.STRING:
            atts[att.name] = att.s
        elif att.type == onnx.AttributeProto.TENSOR:
            atts[att.name] = att.t
        else:
            assert False, "Unsupported Attribute Type: {}".format(att.type)
    return atts


def _onnx_datatype_tostring(dtype):
    if dtype == 0:
        return 'UNDEFINED'
    elif dtype == 1:
        return 'FLOAT'
    elif dtype == 2:
        return 'UINT8'
    elif dtype == 3:
        return 'INT8'
    elif dtype == 4:
        return 'UINT16'
    elif dtype == 5:
        return 'INT16'
    elif dtype == 6:
        return 'INT32'
    elif dtype == 7:
        return 'INT64'
    elif dtype == 8:
        return 'STRING'
    elif dtype == 9:
        return 'BOOL'
    elif dtype == 10:
        return 'FLOAT16'
    elif dtype == 11:
        return 'DOUBLE'
    elif dtype == 12:
        return 'UINT32'
    elif dtype == 13:
        return 'UINT64'
    elif dtype == 14:
        return 'COMPLEX64'
    elif dtype == 15:
        return 'COMPLEX128'
    elif dtype == 16:
        return 'BFLOAT16'
    else:
        assert False, 'Unknown onnx datatype'

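# For reference, the numeric codes above follow onnx.TensorProto.DataType, so the
# same mapping is also available from the protobuf enum itself, e.g.
# onnx.TensorProto.DataType.Name(1) == 'FLOAT'; the explicit table is kept for clarity.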
def import_onnx(model_path: str, bs: int):
    ts, ds, ops, consts = dict(), dict(), dict(), dict()  # (key, value) = (name, class)
    model = onnx.load(model_path)

    # Tensor_input
    for input in model.graph.input:
        if input.name not in ts:
            dims = [d.dim_value for d in input.type.tensor_type.shape.dim]
            # ts[input.name] = g.tensor(dims, _onnx_datatype_tostring(input.type.tensor_type.elem_type))
            ds[input.name] = dims

    # Tensor_weight
    for weight in model.graph.initializer:
        if weight.name not in ts:
            # ts[weight.name] = g.tensor(weight.dims, _onnx_datatype_tostring(weight.data_type))
            ds[weight.name] = weight.dims

    # Tensor_inference
    _add_value_info_for_constants(model)
    infered_model = onnx.shape_inference.infer_shapes(model)
    for v in infered_model.graph.value_info:
        if v.name not in ts:
            dims = [d.dim_value for d in v.type.tensor_type.shape.dim]
            # ts[v.name] = g.tensor(dims, _onnx_datatype_tostring(v.type.tensor_type.elem_type))
            ds[v.name] = dims

    # Tensor_output
    for output in model.graph.output:
        if output.name not in ts:
            dims = [d.dim_value for d in output.type.tensor_type.shape.dim]
            # ts[output.name] = g.tensor(dims, _onnx_datatype_tostring(output.type.tensor_type.elem_type))
            ds[output.name] = dims

    # Op
    for node in model.graph.node:
        # if node.op_type == 'Add':
        #     assert len(node.output) == 1
        #     g.add([ts[item] for item in node.input], ts[node.output[0]])

        # elif node.op_type == 'Cast':
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     # Ignore for now (TODO)
        #     g.identity(ts[node.input[0]], ts[node.output[0]])

        if node.op_type == 'Conv':
            attrs = _parse_attribute(node.attribute, {
                "auto_pad": "NOTSET",
                "dilations": [1, 1],
                "pads": [0, 0, 0, 0],
                "strides": [1, 1]})
            assert len(node.input) == 2 or len(node.input) == 3
            assert len(node.output) == 1
            assert attrs["auto_pad"] == "NOTSET"
            assert len(attrs["pads"]) == 4
            assert len(attrs["strides"]) == 2
            assert len(attrs["dilations"]) == 2
            assert attrs["pads"][0] == attrs["pads"][2]
            assert attrs["pads"][1] == attrs["pads"][3]
            assert ds[node.input[0]][1] % ds[node.input[1]][1] == 0
            n, c, h, w = ds[node.input[0]][0], ds[node.input[0]][1], ds[node.input[0]][2], ds[node.input[0]][3]
            f, r, s = ds[node.input[1]][0], ds[node.input[1]][2], ds[node.input[1]][3]
            ph, pw = attrs["pads"][0], attrs["pads"][1]
            sh, sw = attrs["strides"][0], attrs["strides"][1]
            dh, dw = attrs["dilations"][0], attrs["dilations"][1]
            group = ds[node.input[0]][1] // ds[node.input[1]][1]
            # t = getPerfConv(n, c, h, w, f, r, s, ph, pw,
            #                 sh, sw, dh, dw, group, "")
            # print(node.name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, group, f'{t:.3f}')
            n = n * bs
            for rule in conv_rules:
                rule(node.name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, group)

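        # Illustration with hypothetical shapes: an input of dims [1, 3, 224, 224]
        # and a weight of dims [64, 3, 7, 7] with pads=[3, 3, 3, 3] and
        # strides=[2, 2] give n, c, h, w = 1, 3, 224, 224; f, r, s = 64, 7, 7;
        # ph, pw = 3, 3; sh, sw = 2, 2; dh, dw = 1, 1; group = 3 // 3 = 1,
        # and n is then scaled by the bs command-line argument.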
        elif node.op_type == 'ConvTranspose':
            attrs = _parse_attribute(node.attribute, {
                "auto_pad": "NOTSET",
                "dilations": [1, 1],
                "pads": [0, 0, 0, 0],
                "strides": [1, 1],
                "group": 1})
            assert len(node.input) == 2 or len(node.input) == 3
            assert len(node.output) == 1
            assert attrs["auto_pad"] == "NOTSET"
            assert len(attrs["pads"]) == 4
            assert len(attrs["strides"]) == 2
            assert len(attrs["dilations"]) == 2
            assert attrs["pads"][0] == attrs["pads"][2]
            assert attrs["pads"][1] == attrs["pads"][3]
            n, f, h, w = ds[node.input[0]]
            _, c, r, s = ds[node.input[1]]
            ph, pw = attrs["pads"][0], attrs["pads"][1]
            sh, sw = attrs["strides"][0], attrs["strides"][1]
            dh, dw = attrs["dilations"][0], attrs["dilations"][1]
            oph, opw = 0, 0
            if "output_padding" in attrs:
                oph, opw = attrs["output_padding"][0], attrs["output_padding"][1]
                assert attrs["output_padding"][0] == attrs["output_padding"][1]
            group = attrs["group"]
            n = n * bs
            for rule in conv_transposed2d_rules:
                rule(node.name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group)

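        # Note on the unpacking above: ONNX ConvTranspose weights are laid out as
        # (C_in, C_out/group, kH, kW), and the input's channel dimension plays the
        # role of "f" in InfiniTensor's convention ("output/input channel for
        # conv2d/convTransposed2d" in conv.h), which is why f is read from the
        # input tensor and c from the weight tensor here.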
        elif node.op_type == 'MatMul':
            print(f'{node.name} skipped')
            continue
            assert len(node.input) == 2
            assert len(node.output) == 1
            dimA = list(ds[node.input[0]])
            dimB = list(ds[node.input[1]])
            dimO = list(ds[node.output[0]])

            # if len(dimA) == 2 and len(dimB) == 2:
            #     tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
            #     tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
            #     tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
            #     g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
            #     g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
            #     g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
            #     g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)

            # else:
            #     assert len(dimO) >= 3
            #     batch = functools.reduce(lambda x, y: x * y, dimO[:-2])

            #     if len(dimA) == 3:
            #         tmpI0 = ts[node.input[0]]
            #     else:
            #         tmpI0 = g.tensor([batch, dimA[-2], dimA[-1]], "FLOAT")
            #         g.reshape(ts[node.input[0]], tmpI0)

            #     if len(dimB) == 3:
            #         tmpI1 = ts[node.input[1]]
            #     else:
            #         tmpI1 = g.tensor([batch, dimB[-2], dimB[-1]], "FLOAT")
            #         g.reshape(ts[node.input[1]], tmpI1)

            #     if len(dimO) == 3:
            #         tmpO = ts[node.output[0]]
            #         g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
            #     else:
            #         tmpO = g.tensor([batch, dimO[-2], dimO[-1]], "FLOAT")
            #         g.matmul(tmpI0, tmpI1, tmpO, False, False, None)
            #         g.reshape(tmpO, ts[node.output[0]])

        # elif node.op_type == 'Concat':
        #     assert len(node.output) == 1
        #     attrs = _parse_attribute(node.attribute, {})
        #     g.concat([ts[item] for item in node.input], ts[node.output[0]], attrs["axis"])

        # elif node.op_type == 'Constant':
        #     attrs = _parse_attribute(node.attribute, {})
        #     assert len(node.output) == 1
        #     c = onnx.numpy_helper.to_array(attrs["value"])
        #     if c.ndim == 0:
        #         c = c[()]
        #     consts[node.output[0]] = c

        # elif node.op_type == 'Flatten':
        #     attrs = _parse_attribute(node.attribute, {"axis": 1})
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     g.flatten(ts[node.input[0]], ts[node.output[0]], attrs["axis"])

        # elif node.op_type == 'Gather':
        #     attrs = _parse_attribute(node.attribute, {"axis": 0})
        #     assert len(node.input) == 2
        #     assert len(node.output) == 1
        #     g.gather(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]], attrs["axis"])

        # elif node.op_type == 'Gemm':
        #     attrs = _parse_attribute(node.attribute, {
        #         "alpha": 1.0,
        #         "beta": 1.0,
        #         "transA": 0,
        #         "transB": 0})
        #     assert len(node.input) == 2 or len(node.input) == 3
        #     assert len(node.output) == 1
        #     assert attrs["alpha"] == 1.0
        #     assert attrs["beta"] == 1.0 or len(node.input) == 2
        #     tmpI0 = g.tensor([1] + list(ds[node.input[0]]), "FLOAT")
        #     tmpI1 = g.tensor([1] + list(ds[node.input[1]]), "FLOAT")
        #     tmpO = g.tensor([1] + list(ds[node.output[0]]), "FLOAT")
        #     g.transpose(ts[node.input[0]], tmpI0, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
        #     g.transpose(ts[node.input[1]], tmpI1, 0, Perm([PermItem(-1), PermItem(0), PermItem(1)]), 1)
        #     g.matmul(tmpI0, tmpI1, tmpO,
        #              attrs["transA"], attrs["transB"],
        #              None if len(node.input) == 2 else ts[node.input[2]])
        #     g.transpose(tmpO, ts[node.output[0]], -1, Perm([PermItem([0, 1]), PermItem(2)]), 0)

        # elif node.op_type == 'Mul':
        #     assert len(node.output) == 1
        #     g.mul([ts[x] for x in node.input], ts[node.output[0]])

        # elif node.op_type == 'GlobalAveragePool':
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     dims = ds[node.input[0]]
        #     if len(dims) > 0:
        #         g.avgpool(ts[node.input[0]], ts[node.output[0]], dims[2], dims[3], 0, 0, 1, 1)
        #     else:
        #         g.avgpool(ts[node.input[0]], ts[node.output[0]])

        # elif node.op_type == 'MaxPool':
        #     attrs = _parse_attribute(node.attribute, {
        #         "auto_pad": "NOTSET",
        #         "dilations": [1, 1],
        #         "pads": [0, 0, 0, 0],
        #         "strides": [1, 1]})
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     assert len(attrs["kernel_shape"]) == 2
        #     assert len(attrs["pads"]) == 4
        #     assert len(attrs["strides"]) == 2
        #     assert len(attrs["dilations"]) == 2
        #     assert attrs["pads"][0] == attrs["pads"][2]
        #     assert attrs["pads"][1] == attrs["pads"][3]
        #     g.maxpool(ts[node.input[0]], ts[node.output[0]],
        #               attrs["kernel_shape"][0], attrs["kernel_shape"][1],
        #               attrs["dilations"][0], attrs["dilations"][1],
        #               attrs["pads"][0], attrs["pads"][1],
        #               attrs["strides"][0], attrs["strides"][1])

        # elif node.op_type == 'AveragePool':
        #     attrs = _parse_attribute(node.attribute, {
        #         "auto_pad": "NOTSET",
        #         "count_include_pad": 0,
        #         "pads": [0, 0, 0, 0],
        #         "strides": [1, 1]})
        #     # No dilation in ONNX
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     assert attrs["count_include_pad"] == 0  # To be consistent with operator.cc
        #     assert len(attrs["kernel_shape"]) == 2
        #     assert len(attrs["pads"]) == 4
        #     assert len(attrs["strides"]) == 2
        #     assert attrs["pads"][0] == attrs["pads"][2]
        #     assert attrs["pads"][1] == attrs["pads"][3]
        #     g.avgpool(ts[node.input[0]], ts[node.output[0]],
        #               attrs["kernel_shape"][0], attrs["kernel_shape"][1],
        #               attrs["pads"][0], attrs["pads"][1],
        #               attrs["strides"][0], attrs["strides"][1])

        # elif node.op_type == 'Pad':
        #     attrs = _parse_attribute(node.attribute, {'mode': b'constant'})
        #     assert attrs["mode"].decode("ascii") == "constant"
        #     assert len(attrs["pads"]) % 2 == 0
        #     assert attrs["value"] == 0
        #     nDim = len(attrs["pads"]) // 2
        #     begin = attrs["pads"][:nDim]
        #     end = attrs["pads"][nDim:]
        #     g.pad(ts[node.input[0]], ts[node.output[0]], begin, end)

        # elif node.op_type == 'ReduceMean':
        #     attrs = _parse_attribute(node.attribute, {'keepdims': 1})
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     assert len(attrs["axes"]) == 1
        #     axis = attrs["axes"][0]
        #     if axis < 0:
        #         axis = len(ds[node.input[0]]) - axis
        #     g.reduceMean(ts[node.input[0]], ts[node.output[0]], axis)

        # elif node.op_type == 'Softmax':
        #     attrs = _parse_attribute(node.attribute)
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     axis = attrs["axis"]
        #     if axis < 0:
        #         axis = len(ds[node.input[0]]) - axis
        #     g.softmax(ts[node.input[0]], ts[node.output[0]], axis)

        # elif node.op_type == 'Reshape':
        #     assert len(node.input) == 2
        #     assert len(node.output) == 1
        #     g.reshape(ts[node.input[0]], ts[node.output[0]])

        # elif node.op_type == 'Relu':
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     g.relu(ts[node.input[0]], ts[node.output[0]])

        # elif node.op_type == 'Tanh':
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     g.tanh(ts[node.input[0]], ts[node.output[0]])

        # elif node.op_type == 'Sigmoid':
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     g.sigmoid(ts[node.input[0]], ts[node.output[0]])

        # elif node.op_type == 'Shape':
        #     # Ignore for now, and no need to output anything (TODO)
        #     pass

        # elif node.op_type == 'Sub':
        #     assert len(node.input) == 2
        #     assert len(node.output) == 1
        #     g.sub(ts[node.input[0]], ts[node.input[1]], ts[node.output[0]])

        # elif node.op_type == 'Transpose':
        #     attrs = _parse_attribute(node.attribute, {})
        #     assert len(node.input) == 1
        #     assert len(node.output) == 1
        #     assert "perm" in attrs
        #     g.transpose(ts[node.input[0]], ts[node.output[0]], -1,
        #                 Perm([PermItem(x) for x in attrs["perm"]]), 0)

        # elif node.op_type == 'Unsqueeze':
        #     assert len(node.input) == 2
        #     assert len(node.output) == 1
        #     g.reshape(ts[node.input[0]], ts[node.output[0]])

        # elif node.op_type == "BatchNormalization":
        #     attrs = _parse_attribute(node.attribute, {})
        #     assert len(node.input) == 5
        #     assert len(node.output) == 1
        #     epsilon = attrs['epsilon'] if 'epsilon' in attrs else 1e-5
        #     momentum = attrs['momentum'] if 'momentum' in attrs else 0.9
        #     g.batchnorm(ts[node.input[0]], ts[node.input[1]],
        #                 ts[node.input[2]], ts[node.input[3]],
        #                 ts[node.input[4]], ts[node.output[0]],
        #                 epsilon, momentum)

        # elif node.op_type == "Split":
        #     attrs = _parse_attribute(node.attribute, {})
        #     assert len(node.input) == 1
        #     assert len(node.output) > 1
        #     axis = attrs['axis']
        #     split = attrs['split']
        #     g.split(ts[node.input[0]], [ts[t] for t in node.output], axis, split)

        # elif node.op_type == "Slice":
        #     attrs = _parse_attribute(node.attribute, {})
        #     assert len(node.input) == 4
        #     assert len(node.output) == 1
        #     g.slice(ts[node.input[0]], ts[node.output[0]],
        #             ts[node.input[1]], ts[node.input[2]])

        # elif node.op_type == "Resize":
        #     attrs = _parse_attribute(node.attribute, {})
        #     assert len(node.input) == 4
        #     assert len(node.output) == 1
        #     roi = ts[node.input[1]] if node.input[1] != '' else g.tensor(
        #         [], 'FLOAT')
        #     g.resize(ts[node.input[0]], roi, ts[node.output[0]])

        # else:
        #     assert False, "Unsupported op: " + node.op_type


if __name__ == "__main__":
    import sys
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("model", help="ONNX model file")
    parser.add_argument("bs", help="batch size", type=int, default=1)
    # parser.add_argument("--output", help="Output file")
    args = parser.parse_args()
    import_onnx(args.model, args.bs)
    print_result(args.model)

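The script is driven from the command line; a typical invocation (the script's file name is not shown in this diff, and the model path and batch size are illustrative) is

    python <this_script>.py resnet18.onnx 16

which runs every rule registered in rules.py on each Conv/ConvTranspose node, scales the batch dimension by 16, and finishes by calling print_result to print and pickle the timing table.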
@@ -1,10 +1,11 @@
import argparse
from tokenize import Double
import pyinfinitensor  # import getPerfConv, getPerfMatmul


-def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""):
+def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group):
    return pyinfinitensor.getPerfConvCudnn(n, c, h, w, f, r, s, padh, padw,
-                                          strideh, stridew, dilationh, dilationw, group, name)
+                                          strideh, stridew, dilationh, dilationw, group)


def getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group):

@@ -13,3 +14,43 @@ def getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, strid

def getPerfMatmul(b, m, n, k, name=""):
    return pyinfinitensor.getPerfMatmulCublas(b, m, n, k, name)


+def getPerfConvBiasActCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, bias: bool, act="None"):
+    return pyinfinitensor.getPerfConvBiasActCudnn(n, c, h, w, f, r, s, padh, padw,
+                                                  strideh, stridew, dilationh, dilationw, group, bias, act)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Process some integers.')
+    parser.add_argument('op', metavar='operator', type=str)
+    parser.add_argument('shape', nargs='+')
+    parser.add_argument('--pad', type=int, default=0)
+    parser.add_argument('--stride', type=int, default=1)
+    parser.add_argument('--dilation', type=int, default=1)
+    parser.add_argument('--group', type=int, default=1)
+    parser.add_argument('--bias', type=bool, default=False)
+    parser.add_argument('--act', type=str, default="None")
+    args = parser.parse_args()
+    print(args)
+    if args.op == 'gemm':
+        t = getPerfMatmul(int(args.shape[0]), int(
+            args.shape[1]), int(args.shape[2]), int(args.shape[3]))
+        print(
+            f'time {t:.3f} ms, {2*int(args.shape[0])*int(args.shape[1])*int(args.shape[2])*int(args.shape[3])/t/1e9:.3f} TFLOPS')
+    elif args.op == 'conv':
+        assert len(args.shape) == 7
+        n, c, h, w, f, r, s = [int(v) for v in args.shape]
+        padh = padw = int(args.pad)
+        strideh = stridew = int(args.stride)
+        dilationh = dilationw = int(args.dilation)
+        group = int(args.group)
+        bias = int(args.bias)
+        act = args.act
+        assert group == 1, "Unsupported"
+        t = pyinfinitensor.getPerfConvBiasActCudnn(
+            n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, bias, act)
+        print(
+            f'time {t:.3f} ms, {n*c*h*w*f*r*s/strideh/stridew*2/10**9:.3f} TFlops')
+    else:
+        assert False, "Not supported"
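For ad-hoc benchmarking the wrapper doubles as a CLI; the shapes below are illustrative:

    python operator_timer.py gemm 1 4096 4096 4096
    python operator_timer.py conv 16 3 224 224 64 7 7 --pad 3 --stride 2 --act Relu

Two caveats worth knowing: argparse's type=bool treats any non-empty string as True (so --bias False still enables bias), and the conv branch prints n*c*h*w*f*r*s*2/(strideh*stridew)/10**9 without dividing by the measured time, so the second number it reports is a GFLOP count rather than a throughput.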
@@ -0,0 +1,84 @@
import pandas as pd
import numpy as np
from operator_timer import *
from datetime import datetime

pd.options.display.float_format = '{:,.3f}'.format

df = pd.DataFrame(columns=['n', 'c', 'h', 'w', 'f', 'r', 's', 'ph', 'pw', 'sh', 'sw', 'dh', 'dw', 'oph', 'opw', 'group'])


def conv_original(name, n, c, h, w, f, r, s, ph, pw,
                  sh, sw, dh, dw, group):
    df.loc[name, ['n', 'c', 'h', 'w', 'f', 'r', 's', 'ph', 'pw', 'sh', 'sw', 'dh', 'dw', 'group']] = n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, group
    df.loc[name, 't_original'] = getPerfConv(n, c, h, w, f, r, s, ph, pw,
                                             sh, sw, dh, dw, group)
    df.loc[name, 't_bias'] = getPerfConvBiasActCudnn(n, c, h, w, f, r, s, ph, pw,
                                                     sh, sw, dh, dw, group, bias=True)
    df.loc[name, 't_bias_relu'] = getPerfConvBiasActCudnn(n, c, h, w, f, r, s, ph, pw,
                                                          sh, sw, dh, dw, group, bias=True, act="Relu")


def conv_rule_5x5_to_3x3(name, n, c, h, w, f, r, s, ph, pw,
                         sh, sw, dh, dw, group):
    col = 't_5x5_to_3x3'
    if r == 5 and s == 5:
        df.loc[name, col] = getPerfConv(n, c, h, w, f*4, 3, 3, ph, pw,
                                        sh, sw, dh, dw, group)
    else:
        df.loc[name, col] = np.inf


def conv_rule_9x9_to_3x3(name, n, c, h, w, f, r, s, ph, pw,
                         sh, sw, dh, dw, group):
    col = 't_9x9_to_3x3'
    if r == 9 and s == 9:
        df.loc[name, col] = getPerfConv(n, c, h, w, f*9, r//3, s//3, ph, pw,
                                        sh, sw, dh, dw, group)
    else:
        df.loc[name, col] = np.inf

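# The two *_to_3x3 rules above appear to price kernel-decomposition candidates:
# a KxK conv is approximated by timing a 3x3 conv whose output-channel count is
# scaled by the number of 3x3 tiles needed to cover the kernel (2*2 = 4 for 5x5,
# 3*3 = 9 for 9x9). np.inf simply marks "rule not applicable to this layer" so the
# candidate never wins the t_min comparison in print_result below.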
bandwidth = 200 * 10**6  # bytes per millisecond, i.e. 200 GB/s


def conv_rule_conv2gemm(name, n, c, h, w, f, r, s, ph, pw,
                        sh, sw, dh, dw, group):
    col = 't_conv2gemm'
    if [sh, sw, dh, dw, group] == [1] * 5:
        # b = group
        # m = batch_size * input_height * input_width
        # n = output_channel * kernel_height * kernel_width
        # k = input_channel // group
        t_reduce = group*n*h*w*f*r*s*4/bandwidth if r > 1 or s > 1 else 0
        df.loc[name, '_'+col+'_mem'] = t_reduce
        df.loc[name, col] = getPerfMatmul(group, n*h*w, f*r*s, c//group) + t_reduce
    else:
        df.loc[name, col] = np.inf


# conv_rules = [conv_original, conv_rule_9x9_to_3x3, conv_rule_5x5_to_3x3, conv_rule_conv2gemm]
conv_rules = [conv_original]


def conv_tranpsposed2d_original(name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group):
    df.loc[name, ['n', 'c', 'h', 'w', 'f', 'r', 's', 'ph', 'pw', 'sh', 'sw', 'dh', 'dw', 'oph', 'opw', 'group']] = n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group
    df.loc[name, 't_original'] = getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group)


def conv_tranpsposed2d_togemm(name, n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, oph, opw, group):
    col = 't_conv2gemm'
    if [dh, dw, group] == [1] * 3:
        # ConvTranspose2gemm
        # b = 1
        # m = batch_size * input_height * input_width
        # n = output_channel * kernel_height * kernel_width
        # k = input_channel
        t_reduce = n*h*w*c*r*s*4/bandwidth if r > 1 or s > 1 else 0
        df.loc[name, '_'+col+'_mem'] = t_reduce
        print('t_conv2gemm', group, n*h*w, c*r*s, f)
        df.loc[name, col] = getPerfMatmul(group, n*h*w, c*r*s, f) + t_reduce
    else:
        df.loc[name, col] = np.inf


conv_transposed2d_rules = [conv_tranpsposed2d_original, conv_tranpsposed2d_togemm]


def print_result(model_fn):
    pd.set_option('display.max_rows', 500)
    df['t_min'] = df.filter(regex=("^t_.*")).min(axis=1)
    print(df)
    print(f'Origin: {df["t_original"].sum():.3f} ms')
    print(f'Min: {df["t_min"].sum():.3f} ms')
    print(f'Speedup: {df["t_original"].sum()/df["t_min"].sum():.3f} x')
    df.to_pickle(f'optime_{model_fn.split("/")[-1]}_{datetime.now().strftime("%m_%d_%H_%M_%S")}.pkl')
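The conv2gemm rule prices a conv-to-GEMM rewrite as one batched matmul of dimensions (group, n*h*w, f*r*s, c/group) plus, when the kernel is larger than 1x1, a memory-bound pass over the GEMM output charged at the assumed 200 GB/s. The arithmetic for one hypothetical layer (shapes are illustrative):

    # conv2gemm estimate for n=16, c=256, h=w=14, f=256, r=s=3, stride=dilation=group=1
    group, n, h, w, c, f, r, s = 1, 16, 14, 14, 256, 256, 3, 3
    b, m, k = group, n * h * w, c // group      # GEMM batch/M/K -> 1, 3136, 256
    n_gemm = f * r * s                          # GEMM N         -> 2304
    reduce_bytes = group * n * h * w * f * r * s * 4   # fp32 GEMM output, about 28.9 MB
    t_reduce = reduce_bytes / (200 * 10**6)     # about 0.145 ms at 200 GB/s
    # estimated total = getPerfMatmul(b, m, n_gemm, k) + t_reduce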
@@ -1,3 +1,4 @@
#include "cuda/operator_timer.h"
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"

@@ -12,8 +13,23 @@ namespace opTimer {

double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
                        int padh, int padw, int strideh, int stridew,
-                       int dilationh, int dilationw, int group,
-                       const char *name) {
+                       int dilationh, int dilationw, int group) {
+    return getPerfConvBiasActCudnn(n, c, h, w, f, r, s, padh, padw, strideh,
+                                   stridew, dilationh, dilationw, group, false,
+                                   "None");
+}
+
+double getPerfConvBiasActCudnn(int n, int c, int h, int w, int f, int r, int s,
+                               int padh, int padw, int strideh, int stridew,
+                               int dilationh, int dilationw, int group,
+                               bool bias, string actName) {
+    ActType act = ActType::None;
+    if (actName == "None")
+        act = ActType::None;
+    else if (actName == "Relu")
+        act = ActType::Relu;
+    else
+        IT_ASSERT(false, "Unsupported activation");
    // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
    //             dilationh, dilationw, group] =
    //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};

@@ -25,17 +41,27 @@ double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
    IT_ASSERT(c % group == 0);
    Tensor i0Cpu = gCpu->addTensor({n, c, h, w}, DataType::Float32);
    Tensor w0Cpu = gCpu->addTensor({f, c / group, r, s}, DataType::Float32);
+   Tensor b0Cpu = gCpu->addTensor({f}, DataType::Float32);
    // Malloc data for all tensors in a graph. Do we need implicit allocation?
    gCpu->dataMalloc();
    i0Cpu->setData(IncrementalGenerator());
    w0Cpu->setData(IncrementalGenerator());
+   b0Cpu->setData(IncrementalGenerator());

    // Copy input tensors from CPU to CUDA
    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+   Tensor b0Cuda = gCuda->cloneTensor(b0Cpu);
    // Build CUDA graph
-   auto conv = gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, padh, padw,
-                                     strideh, stridew, dilationh, dilationw);
+   if (!bias) {
+       auto conv =
+           gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, padh, padw, strideh,
+                                 stridew, dilationh, dilationw);
+   } else {
+       auto conv =
+           gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, padh, padw, strideh,
+                                 stridew, dilationh, dilationw, b0Cuda, act);
+   }
    // allocate CUDA memory
    gCuda->dataMalloc();
    // Execute on CUDA
@@ -13,8 +13,10 @@ void register_operator_timer(py::module &m) {
#ifdef USE_CUDA
    using namespace opTimer;
    m.def("getPerfConvCudnn", &getPerfConvCudnn);
+   m.def("getPerfConvBiasActCudnn", &getPerfConvBiasActCudnn);
    m.def("getPerfConvTransposed2dCudnn", &getPerfConvTransposed2dCudnn);
    m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
+   m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
#endif
}
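Once the extension is rebuilt, the new registration can be smoke-tested from Python (module name as used by operator_timer.py above):

    import pyinfinitensor
    print(sorted(n for n in dir(pyinfinitensor) if n.startswith("getPerf")))
    # expected to include: getPerfConvBiasActCudnn, getPerfConvCudnn,
    # getPerfConvTransposed2dCudnn, getPerfMatmulCublas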
@@ -8,6 +8,8 @@
namespace infini {

struct ConvCuDnnPerfRecordObj : public PerfRecordObj {
+   int kernel =
+       0; // 0 cudnnConvolutionForward, 1 cudnnConvolutionBiasActivationForward
    int algo = 0; // cudnnConvolutionFwdAlgo_t
    int mode = 1;
    size_t workspaceSize = 100000;

@@ -56,8 +58,6 @@ class convCudnn : public Kernel {
                          const ConvCuDnnPerfRecord &record) const {
        void *const inData = (op->getInputs(0)->getRawDataPtr<void *>());
        void *const knData = (op->getInputs(1)->getRawDataPtr<void *>());
-       if (op->getInputs().size() > 2) // Bias is not supported yet
-           IT_TODO_HALT();
        // void *const biasData = (op->getInputs(2)->getRawDataPtr<void *>());
        void *const outData = (op->getOutput()->getRawDataPtr<void *>());

@@ -209,6 +209,36 @@ class convCudnn : public Kernel {
        return true;
    }

+   bool cuDNNfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
+                   const CudaRuntimeObj *context) const {
+       cudnnStatus_t stat;
+
+       const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
+                    convDesc, actDesc, outDesc] =
+           createCuDNNDescriptor(op, record);
+       size_t wsSize = record->workspaceSize;
+       CudaPtr wsData = context->getWorkspace(wsSize);
+       float alpha = 1.f, beta = 0.f;
+
+       // w/ bias & act
+       stat = cudnnConvolutionBiasActivationForward(
+           context->cudnnHandle(), &alpha, inDesc, inData, knDesc, knData,
+           convDesc, ALGOS[record->algo], wsData, wsSize, &beta, outDesc,
+           outData, biasDesc, nullptr, actDesc, outDesc, outData);
+       if (stat != CUDNN_STATUS_SUCCESS)
+           return false;
+
+       // Destroying descriptors in CUDA does not require a sync, but cuDNN
+       // does not state whether a sync is required before destroying them.
+       checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
+       checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
+       checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
+       checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
+       checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
+       checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
+       return true;
+   }
+
    void compute(const Operator &op, const RuntimeObj *context) const override {
        auto record = make_ref<ConvCuDnnPerfRecordObj>(); // with parameters in
                                                          // default ctor

@@ -217,10 +247,88 @@ class convCudnn : public Kernel {

    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
        ConvCuDnnPerfRecordObj ret;
        ret.time = std::numeric_limits<double>::max();
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        auto op = as<ConvObj>(_op);
        printf("%s\n", op->toString().c_str());
        if (op->hasBias() || op->getAct() != ActType::None)
            return tuneFused(op, context);
        else
            return tuneUnfused(op, context);
    }

    PerfRecord tuneFused(const Ref<ConvObj> &op,
                         const CudaRuntimeObj *context) const {
        ConvCuDnnPerfRecordObj ret;
        ret.time = std::numeric_limits<double>::max();

        // Both modes have the same performance. Only run cross-correlation.
        for (int mode = 1; mode < 2; mode++) {
            // Try every possible algorithm of convolution
            for (int algo = 0; algo < N_ALGO; algo++) {
                auto recordRef = make_ref<ConvCuDnnPerfRecordObj>();
                auto &record = *recordRef;
                record.mode = mode;
                record.algo = algo;
                cudnnStatus_t stat;
                const auto &[inData, knData, outData, inDesc, knDesc, biasDesc,
                             convDesc, actDesc, outDesc] =
                    createCuDNNDescriptor(op, recordRef);
                void *biasData = op->getBias()->getRawDataPtr<void *>();

                // get workspace
                stat = cudnnGetConvolutionForwardWorkspaceSize(
                    context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
                    ALGOS[record.algo], &record.workspaceSize);
                if (stat != CUDNN_STATUS_SUCCESS)
                    continue;
                if (record.workspaceSize > context->getWorkspaceSize())
                    continue;
                CudaPtr wsData = context->getWorkspace(record.workspaceSize);
                float alpha = 1.f, beta = 0.f;

                stat = cudnnConvolutionBiasActivationForward(
                    context->cudnnHandle(), &alpha, inDesc, inData, knDesc,
                    knData, convDesc, ALGOS[record.algo], wsData,
                    record.workspaceSize, &beta, outDesc, outData, biasDesc,
                    biasData, actDesc, outDesc, outData);
                if (stat != CUDNN_STATUS_SUCCESS)
                    continue;
                record.time = timeit(
                    [&]() {
                        stat = cudnnConvolutionBiasActivationForward(
                            context->cudnnHandle(), &alpha, inDesc, inData,
                            knDesc, knData, convDesc, ALGOS[record.algo],
                            wsData, record.workspaceSize, &beta, outDesc,
                            outData, biasDesc, biasData, actDesc, outDesc,
                            outData);
                    },
                    [&]() { context->sync(); });
                printf("mode %d, algo %d, time %.3lf ms\n", mode, algo,
                       record.time);

                // Update the tune result
                if (ret.time > record.time)
                    ret = record;
                checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
                checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
                checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
                checkCudnnError(cudnnDestroyTensorDescriptor(biasDesc));
                checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
                checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
            }
        }
        printf("tuneFused: the best algo is %d, the best conv mode is %d\n",
               ret.algo, ret.mode);
        IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
                                                                 "algorithm "
                                                                 "found");
        return make_ref<ConvCuDnnPerfRecordObj>(ret);
    }

    PerfRecord tuneUnfused(const Ref<ConvObj> &op,
                           const CudaRuntimeObj *context) const {
        ConvCuDnnPerfRecordObj ret;
        ret.time = std::numeric_limits<double>::max();
        // Both modes have the same performance. Only run cross-correlation.
        for (int mode = 1; mode < 2; mode++) {
            // Try every possible algorithm of convolution

@@ -260,7 +368,8 @@ class convCudnn : public Kernel {
                            &beta, outDesc, outData);
                    },
                    [&]() { context->sync(); });
-               // printf("mode:%d algo:%d :%.8lf\n", mode, algo, record.time);
+               printf("mode %d, algo %d, time %.3lf ms\n", mode, algo,
+                      record.time);

                // Update the tune result
                if (ret.time > record.time)

@@ -273,8 +382,8 @@ class convCudnn : public Kernel {
                checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
            }
        }
-       // printf("the best algo is %d, the best conv mode is %d\n", ret.algo,
-       //        ret.mode);
+       printf("tuneUnfused: the best algo is %d, the best conv mode is %d\n",
+              ret.algo, ret.mode);
        IT_ASSERT(ret.time < std::numeric_limits<double>::max(), "No valid "
                                                                 "algorithm "
                                                                 "found");
@@ -5,15 +5,15 @@ namespace infini {
ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
                         int ph, int pw, int sh, int sw, int dh, int dw,
                         const Tensor &inputInConvFWD,
-                        const Tensor &weightInConvFWD)
+                        const Tensor &weightInConvFWD, const ActType act)
    : OperatorObj(opType, inputs, {output}), ph(ph), pw(pw), sh(sh), sw(sw),
-     dh(dh), dw(dw), padding(PaddingMode::Other) {}
+     dh(dh), dw(dw), padding(PaddingMode::Other), act(act) {}
ConvBaseObj::ConvBaseObj(OpType opType, TensorVec inputs, Tensor &output,
                         PaddingMode mode, int sh, int sw, int dh, int dw,
                         const Tensor &inputInConvFWD,
-                        const Tensor &weightInConvFWD)
+                        const Tensor &weightInConvFWD, const ActType act)
    : OperatorObj(opType, inputs, {output}), ph(-1), pw(-1), sh(sh), sw(sw),
-     dh(dh), dw(dw), padding(mode) {
+     dh(dh), dw(dw), padding(mode), act(act) {
    IT_ASSERT(mode != PaddingMode::Other);
}

@@ -21,28 +21,60 @@ string ConvBaseObj::toString() const {
    std::ostringstream os;
    os << OpRegistry::getOpName(getOpType()) << "[" << getGuid() << "]";
    os << "(";
-   if (inputs.size() == 2) {
-       os << vecToString(inputs[0]->getDims()) << ",";
-       os << vecToString(inputs[1]->getDims()) << ",";
+   os << vecToString(inputs[0]->getDims()) << ",";
+   os << vecToString(inputs[1]->getDims()) << ",";
+   if (inputs.size() > 2) {
+       os << vecToString(inputs[2]->getDims()) << ",";
    }
    os << "p=[" << ph << "," << pw << "],";
    os << "s=[" << sh << "," << sw << "],";
    os << "d=[" << dh << "," << dw << "],";
    os << "act=" << enum_to_underlying(getAct()) << ",";
    // os << "act=" << enum_to_underlying(act) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "weight=" << inputs[1]->getGuid() << ",";
    os << "bias="
       << ((inputs.size() == 2) ? "nullptr"
                                : std::to_string(inputs[2]->getGuid()))
       << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
}

vector<int> ConvBaseObj::getWorkloadVector() const {
-   return {
-       enum_to_underlying(type), n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw};
+   return {enum_to_underlying(type),
+           n,
+           c,
+           h,
+           w,
+           f,
+           r,
+           s,
+           ph,
+           pw,
+           sh,
+           sw,
+           dh,
+           dw,
+           hasBias(),
+           enum_to_underlying(getAct())};
}

vector<int> ConvBaseObj::getOpAttrVector() const {
    IT_TODO_HALT(); // should padding mode / ph+pw be in attrs?
-   return {enum_to_underlying(type), c, f, r, s, ph, pw, sh, sw, dh, dw};
+   return {enum_to_underlying(type),
+           c,
+           f,
+           r,
+           s,
+           ph,
+           pw,
+           sh,
+           sw,
+           dh,
+           dw,
+           hasBias(),
+           enum_to_underlying(getAct())};
}

void ConvObj::setAuxilaryAttributes(PaddingMode mode) {

@@ -64,11 +96,10 @@ void ConvObj::setAuxilaryAttributes(PaddingMode mode) {
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                 int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias,
                 ActType act)
-   : ConvBaseObj(OpType::Conv, {input, weight}, output, ph, pw, sh, sw, dh, dw,
-                 input, weight),
-     act(act) {
-   if (bias)
-       IT_TODO_HALT();
+   : ConvBaseObj(OpType::Conv,
+                 ((bias) ? (TensorVec{input, weight, bias})
+                         : (TensorVec{input, weight})),
+                 output, ph, pw, sh, sw, dh, dw, input, weight, act) {
    setAuxilaryAttributes(PaddingMode::Other);
    IT_ASSERT(checkValid(graph));
}

@@ -76,11 +107,10 @@ ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
ConvObj::ConvObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output,
                 PaddingMode mode, int sh, int sw, int dh, int dw, Tensor bias,
                 ActType act)
-   : ConvBaseObj(OpType::Conv, {input, weight}, output, mode, sh, sw, dh, dw,
-                 input, weight),
-     act(act) {
-   if (bias)
-       IT_TODO_HALT();
+   : ConvBaseObj(OpType::Conv,
+                 ((bias) ? (TensorVec{input, weight, bias})
+                         : (TensorVec{input, weight})),
+                 output, mode, sh, sw, dh, dw, input, weight, act) {
    setAuxilaryAttributes(mode);
    IT_ASSERT(checkValid(graph));
}

@@ -98,6 +128,11 @@ optional<vector<Shape>> ConvObj::inferShape(const TensorVec &inputs) const {
    // For NCHW+FCRS layout, C of input is divisible by C of weight
    if (input->getDims()[1] % weight->getDims()[1] != 0)
        return {};
+   // check bias shape
+   if (inputs.size() == 3) {
+       if (inputs[2]->size() != (size_t)f)
+           return {};
+   }
    // Set padding size
    if (padding == PaddingMode::Other) {
        oh = (h - (r - sh) * dh + ph * 2) / sh;

@@ -122,8 +157,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
                                         int oph, int opw, int group,
                                         Tensor bias, ActType act)
    : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, ph, pw, sh, sw,
-                 dh, dw, output, weight),
-     oph(oph), opw(opw), group(group), act(act) {
+                 dh, dw, output, weight, act),
+     oph(oph), opw(opw), group(group) {
    if (bias)
        IT_TODO_HALT();
    setAuxilaryAttributes(PaddingMode::Other);

@@ -136,8 +171,8 @@ ConvTransposed2dObj::ConvTransposed2dObj(GraphObj *graph, Tensor input,
                                         int dh, int dw, int oph, int opw,
                                         int group, Tensor bias, ActType act)
    : ConvBaseObj(OpType::ConvTrans, {input, weight}, output, mode, sh, sw, dh,
-                 dw, output, weight),
-     oph(oph), opw(opw), group(group), act(act) {
+                 dw, output, weight, act),
+     oph(oph), opw(opw), group(group) {
    if (bias)
        IT_TODO_HALT();
    setAuxilaryAttributes(mode);

@@ -156,6 +191,8 @@ ConvTransposed2dObj::inferShape(const TensorVec &inputs) const {
    auto s = weight->getDims()[3];
    if (f != weight->getDims()[0])
        return {};
+   if (inputs.size() != 2)
+       IT_TODO_HALT();

    int on = n, oc = c * group;
    int oh = 0, ow = 0;
@@ -0,0 +1,4 @@
. /home/spack/spack/share/spack/setup-env.sh
spack load /yb2wg5g # cuda@10.2.89
spack load /3bfwma4 # cudnn@7.6.5.32-10.2
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-haswell/gcc-7.5.0/gcc-7.5.0-sti65cu3zunc4p4kfylgweim6mqan3mk/bin/gcc