forked from jiuyuan/InfiniTensor
Optimize distributed inference for the bert and gpt2 models (#221)

* fix(dist): improve the distributed script and print only the absolute error
* feat(dist): add a pytorch run script that can export onnx
* feat(front): add a graph optimization for Where operators whose Y value is -inf
* feat(kernel): special-case optimization for pow and div operators whose b is a constant
* fix(front): remove the frontend's dependence on global output shape info; drop the unnecessary shape inference from the distributed script
* feat(kernel): specialized optimization for the expand of a row-vector bias in matmul
* fix(kernel): remove an unnecessary synchronization in div pow const
* Update expand.cu
* fix: fix comments

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
Co-authored-by: Derui Yang <ydrml@hotmail.com>
This commit is contained in:
parent a98573990b
commit 7f6aec6c17
@@ -0,0 +1,17 @@
# Distributed scripts

#### 1. Run the pytorch model, generate inputs and reference outputs, and optionally export onnx

Use `--export_onnx` to set the directory for the exported onnx file (it defaults to the current path `./`); without this flag the script only runs the computation and generates the inputs and outputs.

```bash
python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
```

This writes the input and output files `test_inputs.npy` and `test_results.npy` to the current directory; currently only a single input and a single output are supported.

#### 2. Run the InfiniTensor distributed script

```bash
python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
```
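Not part of the diff: a minimal sketch for inspecting the two files the first step writes, assuming the default filenames above.

```python
import numpy as np

# test_inputs.npy holds the token ids fed to the model; test_results.npy holds
# the reference last_hidden_state produced by the PyTorch run.
inputs = np.load("test_inputs.npy")
results = np.load("test_results.npy")
print("inputs:", inputs.shape, inputs.dtype)
print("results:", results.shape, results.dtype)
```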
@@ -47,7 +47,7 @@ def parse_args():

 def run_model(model, runtime, inputs, n=10):
     stub = OnnxStub(model, runtime)
-    for tensor, input in zip(stub.inputs.values(), inputs):
+    for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
         tensor.copyin_numpy(input)
     # stub.tune()
     stub.run()
@@ -55,7 +55,7 @@ def run_model(model, runtime, inputs, n=10):
     outputs = next(stub.outputs.values().__iter__()).copyout_numpy()

     # bench
-    for tensor, input in zip(stub.inputs.values(), inputs):
+    for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
         tensor.copyin_numpy(input)
     begin = time.time()
     for _ in range(n):
@@ -72,7 +72,7 @@ def run_and_compare(name, model, runtime):
     results = np.load(f"{name}_results.npy")
     outputs = run_model(model, runtime, (input_ids, position_ids))
     print("outputs abs mean:", abs(outputs).mean())
-    np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)
+    print("max abs diff:", abs(outputs - results).max())


 def start_worker(
@@ -89,7 +89,7 @@ def start_worker(
         save_as_external_data=True,
         location=extern_path,
     )
-    infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
+    #infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
     runtime = backend.CudaRuntime(local_rank)
     # print("init comm")
     runtime.init_comm(
@@ -244,5 +244,5 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         if tt.HasField("shape"):
             tt.ClearField("shape")
     model = helper.make_model(graph)
-    model = onnx.shape_inference.infer_shapes(model)
+    #model = onnx.shape_inference.infer_shapes(model)
     return model
@@ -0,0 +1,178 @@
import argparse
import torch
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
import time
import numpy as np
import onnx
import os
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

def parse_args():
    parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
    parser.add_argument(
        "--model", type=str, choices=["gpt2", "bert", "opt"], required=True, help="model type"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--export_onnx",
        type=str,
        nargs="?",
        default=None,
        const="./",
        help="whether and where to export onnx file",
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.model,
        args.batch_size,
        args.length,
        args.export_onnx
    )


def get_model(modelname):
    match modelname:
        case "bert":
            model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new")  # erf is not impl by infini
            voc_size = BertConfig().vocab_size
        case "gpt2":
            model = GPT2Model.from_pretrained("gpt2")
            voc_size = GPT2Config().vocab_size
        case "opt":
            model = OPTModel.from_pretrained("./opt-125m")
            voc_size = OPTConfig().vocab_size
        case _:
            raise KeyError(modelname)

    model = model.eval()
    return model, voc_size

def run_pytorch(torch_model, voc_size, batchsize, len):
    data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
    np.save("test_inputs", data)
    inputs = torch.from_numpy(data).to("cuda")
    torch_model = torch_model.to("cuda")

    n_iter = 20
    with torch.no_grad():
        for _ in range(10):
            outputs = torch_model(inputs)
    torch.cuda.synchronize()
    begin = time.time()
    with torch.no_grad():
        for _ in range(n_iter):
            torch.cuda.synchronize()
            outputs = torch_model(inputs)
            #
            torch.cuda.synchronize()
    torch.cuda.synchronize()
    end = time.time()

    avg_time = (end - begin) / n_iter
    outputs = outputs.last_hidden_state.to("cpu")
    print("outputs abs mean:", abs(np.array(outputs)).mean())
    print(f"average time: {avg_time}")
    torch.cuda.memory.empty_cache()
    np.save("test_results", np.array(outputs))
    print("Save input & output as test_inputs.npy and test_results.npy")


def export_onnx(model, data, path, extern=False):
    torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
    onnx_model = onnx.load(path)
    onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
    #onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
    assert check
    add_value_info_for_constants(onnx_model)
    onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
    if extern:
        extern_path = path.replace('.onnx', '.pb')
        if os.path.exists(extern_path):
            os.remove(extern_path)
        convert_model_to_external_data(
            onnx_model,
            all_tensors_to_one_file=True,
            location=extern_path,
            size_threshold=1024,
            convert_attribute=False,
        )
    onnx.save(onnx_model, path)

def add_value_info_for_constants(model : onnx.ModelProto):
    """
    Currently onnx.shape_inference doesn't use the shape of initializers, so add
    that info explicitly as ValueInfoProtos.
    Mutates the model.
    Args:
        model: The ModelProto to update.
    """
    # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
    if model.ir_version < 4:
        return

    def add_const_value_infos_to_graph(graph : onnx.GraphProto):
        inputs = {i.name for i in graph.input}
        existing_info = {vi.name: vi for vi in graph.value_info}
        for init in graph.initializer:
            # Check it really is a constant, not an input
            if init.name in inputs:
                continue

            # The details we want to add
            elem_type = init.data_type
            shape = init.dims

            # Get existing or create new value info for this constant
            vi = existing_info.get(init.name)
            if vi is None:
                vi = graph.value_info.add()
                vi.name = init.name

            # Even though it would be weird, we will not overwrite info even if it doesn't match
            tt = vi.type.tensor_type
            if tt.elem_type == onnx.TensorProto.UNDEFINED:
                tt.elem_type = elem_type
            if not tt.HasField("shape"):
                # Ensure we set an empty list if the const is scalar (zero dims)
                tt.shape.dim.extend([])
                for dim in shape:
                    tt.shape.dim.add().dim_value = dim

        # Handle subgraphs
        for node in graph.node:
            for attr in node.attribute:
                # Ref attrs refer to other attrs, so we don't need to do anything
                if attr.ref_attr_name != "":
                    continue

                if attr.type == onnx.AttributeProto.GRAPH:
                    add_const_value_infos_to_graph(attr.g)
                if attr.type == onnx.AttributeProto.GRAPHS:
                    for g in attr.graphs:
                        add_const_value_infos_to_graph(g)


    return add_const_value_infos_to_graph(model.graph)


def main():
    modelname, batchsize, seqlen, export_path = parse_args()
    model, voc_size = get_model(modelname)
    if export_path is not None:
        filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
        path = os.path.join(export_path, filename)
        param = torch.zeros((batchsize, seqlen), dtype=torch.int)
        export_onnx(model, param, path, True)

    run_pytorch(model, voc_size, batchsize, seqlen)

if __name__ == "__main__":
    main()
@@ -13,4 +13,8 @@ void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
 void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
                  int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
                  int c2, int c3);
+
+void div_const_kernel(int dType, void *a, void *b, void *c, size_t n);
+
+void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n);
 }; // namespace infini
@@ -7,4 +7,6 @@ void expandKernel(int dType, void *input, void *output, int nDims,
                   int outputsize, SmallArray inputShape,
                   SmallArray outputShape);

+void expandRowKernel(int dType, void *input, void *output, int n_rows,
+                     int row_len);
 }; // namespace infini
@@ -23,12 +23,13 @@ from onnx.checker import (
     ValidationError,
 )
 from onnx.shape_inference import infer_shapes
-from onnx.numpy_helper import to_array
+from onnx.numpy_helper import to_array, from_array
 from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
 from functools import reduce
 from onnxsim import simplify
 import copy
 import warnings
+import numpy as np


 class OnnxStub:
@@ -111,12 +112,6 @@ class OnnxStub:
             )
             tensors[input.name].set_input()

-        for output in model.graph.output:
-            dims = _take_shape_dim(output.type.tensor_type.shape)
-            tensors[output.name] = self.handler.tensor(
-                dims, output.type.tensor_type.elem_type
-            )
-            tensors[output.name].set_output()

         for node_idx in sorted_nodes:
             node = model.graph.node[node_idx]
@@ -947,6 +942,25 @@ class OnnxStub:
                     tensors.get(node.output[0]),
                 )
             elif node.op_type == "Where":
+                ## If Y is single -inf, treat Where as Add
+                ## TODO: deal with cases where Y is single inf or 0
+                if node.input[0] in data and node.input[2] in data:
+                    where_condition = to_array(data[node.input[0]])
+                    where_alt = to_array(data[node.input[2]])
+                    if where_alt.size == 1:
+                        if np.isneginf(where_alt) or np.all(where_alt < -3e38):
+                            node.input[0] = node.input[0] + "_alt"
+                            if node.input[0] not in data:
+                                where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
+                                data[node.input[0]] = from_array(where_value, node.input[0])
+                                tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
+                                tensors[node.input[0]].set_weight()
+                            tensors[node.output[0]] = self.handler.add(
+                                tensors[node.input[1]],
+                                tensors[node.input[0]],
+                                tensors.get(node.output[0]),
+                            )
+                            continue
                 tensors[node.output[0]] = self.handler.where(
                     tensors[node.input[1]],
                     tensors[node.input[2]],
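For context (not part of the commit), a small numpy sketch of why this rewrite is valid: when the Where operator's Y input is a single `-inf`, `Where(cond, X, -inf)` equals `X` plus a precomputed mask `where(cond, 0, -inf)`, so the node can be replaced by an Add against that constant.

```python
import numpy as np

# Hypothetical small tensors standing in for the node's inputs.
cond = np.array([[True, False], [False, True]])
x = np.random.randn(2, 2).astype(np.float32)

where_out = np.where(cond, x, -np.inf)             # original Where(cond, X, -inf)
mask = np.where(cond, 0, -np.inf).astype(x.dtype)  # the generated "_alt" constant
add_out = x + mask                                 # rewritten Add(X, mask)

assert np.array_equal(where_out, add_out)
```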
@@ -980,6 +994,8 @@ class OnnxStub:
             else:
                 raise Exception('Unsupported operator "{}"'.format(node.op_type))

+        for output in model.graph.output:
+            tensors[output.name].set_output()
         ################################
         # Allocate memory space for data
         ################################
@@ -115,6 +115,20 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
         auto a_dim = op->getInputs(0)->getDims();
         auto b_dim = op->getInputs(1)->getDims();
         auto c_dim = op->getOutput()->getDims();
+        const int dType = _op->getDType().getIndex();
+
+        // Use optimized kernel if b is constant
+        if (b_dim.size() == 0) {
+            if (op->getOpType() == OpType::Div) {
+                div_const_kernel(dType, aData, bData, cData,
+                                 op->getOutput()->size());
+                return;
+            } else if (op->getOpType() == OpType::Pow) {
+                pow_const_kernel(dType, aData, bData, cData,
+                                 op->getOutput()->size());
+                return;
+            }
+        }

         if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
             IT_TODO_HALT();
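As a rough numpy illustration (not the CUDA code itself): the fast path triggers when operand b has rank 0, i.e. it is a single constant, so Div and Pow degenerate to applying one scalar per element and the generic broadcasting index math can be skipped.

```python
import numpy as np

a = np.random.rand(4, 8).astype(np.float32) + 1.0
b = np.float32(2.0)          # rank-0 operand, the b_dim.size() == 0 case

div_fast = a / b             # what div_const_kernel computes elementwise
pow_fast = a ** b            # what pow_const_kernel computes elementwise

# Same results as the general broadcasting kernels, minus the shape logic.
assert np.allclose(div_fast, np.divide(a, b))
assert np.allclose(pow_fast, np.power(a, b))
```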
@@ -127,7 +141,6 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
         std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
         std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));

-        const int dType = _op->getDType().getIndex();
         if (op->getOpType() == OpType::Div) {
             div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
                        b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
@@ -132,8 +132,8 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,

 #define CASE(OP, T) \
     _##OP##_kernel<DT_CUDA<T>::t> \
-        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
-        (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
+            a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);

 #define SWITCH_DTYPE(OP, DTYPE) \
     switch (DTYPE) { \
@@ -177,7 +177,92 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
         IT_TODO_HALT(); \
     }

+template <class T>
+__global__ void _div_const_kernel(void const *__restrict__ x,
+                                  void const *__restrict__ y,
+                                  void *__restrict__ z, const size_t n) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (tid < n) {
+        ((T *)z)[tid] = ((T *)x)[tid] / *((T *)y);
+    }
+}
+
+template <class T>
+__global__ void _pow_const_kernel(void const *__restrict__ x,
+                                  void const *__restrict__ y,
+                                  void *__restrict__ z, const size_t n) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (tid < n) {
+        ((T *)z)[tid] = pow(((T *)x)[tid], *((T *)y));
+    }
+}
+template <>
+__global__ void _pow_const_kernel<half>(void const *__restrict__ x,
+                                        void const *__restrict__ y,
+                                        void *__restrict__ z, const size_t n) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    if (tid < n) {
+        ((half *)z)[tid] = pow(((float)((half *)x)[tid]), *((half *)y));
+    }
+}
+
+#define CASE_CONST(OP, T) \
+    _##OP##_const_kernel<DT_CUDA<T>::t> \
+        <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(a, b, c, \
+                                                                     n);
+
+#define SWITCH_DTYPE_CONST(OP, DTYPE) \
+    switch (DTYPE) { \
+    case 1: \
+        CASE_CONST(OP, 1) \
+        break; \
+    case 2: \
+        CASE_CONST(OP, 2) \
+        break; \
+    case 3: \
+        CASE_CONST(OP, 3) \
+        break; \
+    case 4: \
+        CASE_CONST(OP, 4) \
+        break; \
+    case 5: \
+        CASE_CONST(OP, 5) \
+        break; \
+    case 6: \
+        CASE_CONST(OP, 6) \
+        break; \
+    case 7: \
+        CASE_CONST(OP, 7) \
+        break; \
+    case 10: \
+        CASE_CONST(OP, 10) \
+        break; \
+    case 11: \
+        CASE_CONST(OP, 11) \
+        break; \
+    case 12: \
+        CASE_CONST(OP, 12) \
+        break; \
+    case 13: \
+        CASE_CONST(OP, 13) \
+        break; \
+    default: \
+        IT_TODO_HALT(); \
+    }
+
 namespace infini {
+void div_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
+    size_t blocksize = block_work_size();
+    size_t gridsize = (n + block_work_size() - 1) / block_work_size();
+    SWITCH_DTYPE_CONST(div, dType);
+}
+
+void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
+    size_t blocksize = block_work_size();
+    size_t gridsize = (n + block_work_size() - 1) / block_work_size();
+    SWITCH_DTYPE_CONST(pow, dType);
+}
+
 void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
                 int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
                 int c3) {
@@ -204,12 +289,12 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
     int gridsize = (num + block_work_size() - 1) / block_work_size();
     if (dType == 1) {
         _pow_kernel<float>
-            <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-            (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
+            <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+                a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
     } else if (dType == 3) {
         _pow_kernel<int8_t>
-            <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-            (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
+            <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+                a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
     } else if (dType == 10) {
         int a_size = a0 * a1 * a2 * a3;
         int b_size = b0 * b1 * b2 * b3;
@@ -224,9 +309,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
             b_float[i] = __half2float(((half *)b)[i]);
         }
         _pow_kernel<float>
-            <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
-            (a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
-            b1, b2, b3, c0, c1, c2, c3);
+            <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
+                a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3,
+                b0, b1, b2, b3, c0, c1, c2, c3);
         for (int i = 0; i < c_size; ++i) {
             ((half *)c)[i] = __float2half(c_float[i]);
         }
@@ -39,6 +39,14 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
     }
 }

+template <class T>
+static __global__ void _expandRowKernel(void *__restrict__ dst,
+                                        void const *__restrict__ src) {
+    auto da = gridDim.x, db = blockDim.y, dx = blockDim.x, n = blockIdx.y,
+         a = blockIdx.x, b = threadIdx.y, x = threadIdx.x;
+    auto i = ((n * da + a) * db + b) * dx + x, j = (a * db + b) * dx + x;
+    reinterpret_cast<T *>(dst)[i] = reinterpret_cast<T const *>(src)[j];
+}
 namespace infini {

 #define CASE(T)                                                                \
@@ -96,4 +104,67 @@ void expandKernel(int dType, void *input, void *output, int nDims,
     SWITCH_DTYPE(dType)
 }

+#define CASE_ROW(T) \
+    _expandRowKernel<float> \
+        <<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);
+
+#define SWITCH_DTYPE_ROW(DTYPE) \
+    switch (DTYPE) { \
+    case 1: \
+        CASE_ROW(1) \
+        break; \
+    case 2: \
+        CASE_ROW(2) \
+        break; \
+    case 3: \
+        CASE_ROW(3) \
+        break; \
+    case 4: \
+        CASE_ROW(4) \
+        break; \
+    case 5: \
+        CASE_ROW(5) \
+        break; \
+    case 6: \
+        CASE_ROW(6) \
+        break; \
+    case 7: \
+        CASE_ROW(7) \
+        break; \
+    case 10: \
+        CASE_ROW(10) \
+        break; \
+    case 11: \
+        CASE_ROW(11) \
+        break; \
+    case 12: \
+        CASE_ROW(12) \
+        break; \
+    case 13: \
+        CASE_ROW(13) \
+        break; \
+    case 16: \
+        CASE_ROW(16) \
+        break; \
+    default: \
+        IT_TODO_HALT(); \
+    }
+
+// Optimization for expanding a row vector. The row length must be a multiple of 32
+void expandRowKernel(int dType, void *input, void *output, int n_rows,
+                     int row_len) {
+    // Factorize row_len: row_len = a x b x 32 (32 is the warp size), b <= 32
+    // input: 1 x (a x b x 32 x sizeT)
+    // output: n_rows x (a x b x 32 x sizeT)
+    // grid: n_rows x a
+    // block: b x 32
+    auto c = row_len / 32, b = c;
+    if (b > 32) {
+        for (b = 32; c % b != 0; --b);
+    }
+    auto a = c / b;
+    dim3 grid(a, n_rows), block(32, b);
+    SWITCH_DTYPE_ROW(dType)
+}
+
 } // namespace infini
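A small Python sketch (not part of the commit; it assumes row_len is a multiple of 32, as the comment above requires) of the same factorization and index mapping the row-expand kernel uses, checking that the output is simply the input row repeated n_rows times.

```python
import numpy as np

def expand_row(src, n_rows):
    row_len = src.size
    assert row_len % 32 == 0
    # Factorize row_len = a * b * 32 with b <= 32, mirroring the host code.
    c = row_len // 32
    b = c
    if b > 32:
        b = 32
        while c % b != 0:
            b -= 1
    a = c // b
    dst = np.empty(n_rows * row_len, dtype=src.dtype)
    # One iteration per CUDA thread: grid is (a, n_rows), block is (32, b).
    for n in range(n_rows):
        for ai in range(a):
            for bi in range(b):
                for x in range(32):
                    i = ((n * a + ai) * b + bi) * 32 + x
                    j = (ai * b + bi) * 32 + x
                    dst[i] = src[j]
    return dst.reshape(n_rows, row_len)

row = np.arange(96, dtype=np.float32)  # 96 = 3 * 32
assert np.array_equal(expand_row(row, 4), np.tile(row, (4, 1)))
```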
@@ -102,9 +102,19 @@ class matmulCublas : public Kernel {
                 inputShape.data[i] = inC->getDims()[i - offset];
             }
             const int dType = dataType.getIndex();
-            expandKernel(dType, inC->getRawDataPtr<void *>(),
-                         out->getRawDataPtr<void *>(), nDims, outputsize,
-                         inputShape, outputShape);
+
+            // Bias in linear layer is row vector of (1,n), n is the number of
+            // features. If row vector and n % 32 == 0, use optimized kernel.
+            if (inC->getRank() == 1 && inC->getDims()[0] % 32 == 0) {
+                expandRowKernel(dType, inC->getRawDataPtr<void *>(),
+                                out->getRawDataPtr<void *>(),
+                                out->size() / inC->getDims()[0],
+                                inC->getDims()[0]);
+            } else {
+                expandKernel(dType, inC->getRawDataPtr<void *>(),
+                             out->getRawDataPtr<void *>(), nDims, outputsize,
+                             inputShape, outputShape);
+            }
         }
         // TODO:use compute type
         cublasStatus_t stat;