针对bert和gpt2模型分布式推理的优化 (#221)

* fix(dist): 改善分布式脚本,只打印绝对误差

* feat(dist): 增加可导出onnx的pytorch运行脚本

* feat(front): 增加对Y值为-inf的where算子的图优化

* feat(kernel): 对b为常数的pow和div算子进行特判优化

* fix(front): 消除前端对global output形状信息的依赖,分布式脚本删除不必要的shape infer

* feat(kernel): 针对matmul中bias为行向量时的expand操作的特化优化

* fix(kernel): 删除div pow const中不必要的同步

* Update expand.cu

* fix: fix comments

---------

Co-authored-by: Haojie Wang <haojie0429@gmail.com>
Co-authored-by: Derui Yang <ydrml@hotmail.com>
This commit is contained in:
PanZezhong1725 2024-04-01 14:04:28 +08:00 committed by GitHub
parent a98573990b
commit 7f6aec6c17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 421 additions and 25 deletions

View File

@ -0,0 +1,17 @@
# 分布式脚本
#### 1. 运行pytorch模型并生成输入和标准输出可选择导出onnx
使用 `--export_onnx` 设置导出onnx的目录默认为当前路径 `./`不使用这个flag则只进行计算和生成输入输出。
```bash
python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
```
会在当前目录下生成输入输出文件`test_inputs.npy` 和 `test_results.npy`,目前只支持单一输入输出。
#### 2. 运行InfiniTensor分布式脚本
```bash
python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
```

View File

@ -47,7 +47,7 @@ def parse_args():
def run_model(model, runtime, inputs, n=10): def run_model(model, runtime, inputs, n=10):
stub = OnnxStub(model, runtime) stub = OnnxStub(model, runtime)
for tensor, input in zip(stub.inputs.values(), inputs): for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
tensor.copyin_numpy(input) tensor.copyin_numpy(input)
# stub.tune() # stub.tune()
stub.run() stub.run()
@ -55,7 +55,7 @@ def run_model(model, runtime, inputs, n=10):
outputs = next(stub.outputs.values().__iter__()).copyout_numpy() outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
# bench # bench
for tensor, input in zip(stub.inputs.values(), inputs): for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
tensor.copyin_numpy(input) tensor.copyin_numpy(input)
begin = time.time() begin = time.time()
for _ in range(n): for _ in range(n):
@ -72,7 +72,7 @@ def run_and_compare(name, model, runtime):
results = np.load(f"{name}_results.npy") results = np.load(f"{name}_results.npy")
outputs = run_model(model, runtime, (input_ids, position_ids)) outputs = run_model(model, runtime, (input_ids, position_ids))
print("outputs abs mean:", abs(outputs).mean()) print("outputs abs mean:", abs(outputs).mean())
np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3) print("max abs diff:", abs(outputs - results).max())
def start_worker( def start_worker(
@ -89,7 +89,7 @@ def start_worker(
save_as_external_data=True, save_as_external_data=True,
location=extern_path, location=extern_path,
) )
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") #infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
runtime = backend.CudaRuntime(local_rank) runtime = backend.CudaRuntime(local_rank)
# print("init comm") # print("init comm")
runtime.init_comm( runtime.init_comm(

View File

@ -244,5 +244,5 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
if tt.HasField("shape"): if tt.HasField("shape"):
tt.ClearField("shape") tt.ClearField("shape")
model = helper.make_model(graph) model = helper.make_model(graph)
model = onnx.shape_inference.infer_shapes(model) #model = onnx.shape_inference.infer_shapes(model)
return model return model

View File

@ -0,0 +1,178 @@
import argparse
import torch
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
import time
import numpy as np
import onnx
import os
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def parse_args():
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
parser.add_argument(
"--model", type=str, choices=["gpt2", "bert", "opt"], required=True, help="model type"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--export_onnx",
type=str,
nargs="?",
default=None,
const="./",
help="whether and where to export onnx file",
)
args = parser.parse_args()
args = parser.parse_args()
print("arg setting: ", args)
return (
args.model,
args.batch_size,
args.length,
args.export_onnx
)
def get_model(modelname):
match modelname:
case "bert":
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
voc_size = BertConfig().vocab_size
case "gpt2":
model = GPT2Model.from_pretrained("gpt2")
voc_size = GPT2Config().vocab_size
case "opt":
model = model = OPTModel.from_pretrained("./opt-125m")
voc_size = OPTConfig().vocab_size
case _:
raise KeyError(modelname)
model = model.eval()
return model, voc_size
def run_pytorch(torch_model, voc_size, batchsize, len):
data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32)
np.save("test_inputs", data)
inputs = torch.from_numpy(data).to("cuda")
torch_model = torch_model.to("cuda")
n_iter = 20
with torch.no_grad():
for _ in range(10):
outputs = torch_model(inputs)
torch.cuda.synchronize()
begin = time.time()
with torch.no_grad():
for _ in range(n_iter):
torch.cuda.synchronize()
outputs = torch_model(inputs)
#
torch.cuda.synchronize()
torch.cuda.synchronize()
end = time.time()
avg_time = (end - begin) / n_iter
outputs = outputs.last_hidden_state.to("cpu")
print("outputs abs mean:", abs(np.array(outputs)).mean())
print(f"average time: {avg_time}")
torch.cuda.memory.empty_cache()
np.save("test_results", np.array(outputs))
print("Save input & output as test_inputs.npy and test_results.npy")
def export_onnx(model, data, path, extern=False):
torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
onnx_model = onnx.load(path)
onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
#onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
assert check
add_value_info_for_constants(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
if extern:
extern_path = path.replace('.onnx', '.pb')
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
onnx_model,
all_tensors_to_one_file=True,
location=extern_path,
size_threshold=1024,
convert_attribute=False,
)
onnx.save(onnx_model, path)
def add_value_info_for_constants(model : onnx.ModelProto):
"""
Currently onnx.shape_inference doesn't use the shape of initializers, so add
that info explicitly as ValueInfoProtos.
Mutates the model.
Args:
model: The ModelProto to update.
"""
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
if model.ir_version < 4:
return
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
inputs = {i.name for i in graph.input}
existing_info = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
# Check it really is a constant, not an input
if init.name in inputs:
continue
# The details we want to add
elem_type = init.data_type
shape = init.dims
# Get existing or create new value info for this constant
vi = existing_info.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
# Even though it would be weird, we will not overwrite info even if it doesn't match
tt = vi.type.tensor_type
if tt.elem_type == onnx.TensorProto.UNDEFINED:
tt.elem_type = elem_type
if not tt.HasField("shape"):
# Ensure we set an empty list if the const is scalar (zero dims)
tt.shape.dim.extend([])
for dim in shape:
tt.shape.dim.add().dim_value = dim
# Handle subgraphs
for node in graph.node:
for attr in node.attribute:
# Ref attrs refer to other attrs, so we don't need to do anything
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_const_value_infos_to_graph(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_const_value_infos_to_graph(g)
return add_const_value_infos_to_graph(model.graph)
def main():
modelname, batchsize, seqlen, export_path = parse_args()
model, voc_size = get_model(modelname)
if export_path is not None:
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
path = os.path.join(export_path, filename)
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
export_onnx(model, param, path, True)
run_pytorch(model, voc_size, batchsize, seqlen)
if __name__ == "__main__":
main()

View File

@ -13,4 +13,8 @@ void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1, void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
int c2, int c3); int c2, int c3);
void div_const_kernel(int dType, void *a, void *b, void *c, size_t n);
void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n);
}; // namespace infini }; // namespace infini

View File

@ -7,4 +7,6 @@ void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape, int outputsize, SmallArray inputShape,
SmallArray outputShape); SmallArray outputShape);
void expandRowKernel(int dType, void *input, void *output, int n_rows,
int row_len);
}; // namespace infini }; // namespace infini

View File

@ -23,12 +23,13 @@ from onnx.checker import (
ValidationError, ValidationError,
) )
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
from onnx.numpy_helper import to_array from onnx.numpy_helper import to_array, from_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce from functools import reduce
from onnxsim import simplify from onnxsim import simplify
import copy import copy
import warnings import warnings
import numpy as np
class OnnxStub: class OnnxStub:
@ -111,12 +112,6 @@ class OnnxStub:
) )
tensors[input.name].set_input() tensors[input.name].set_input()
for output in model.graph.output:
dims = _take_shape_dim(output.type.tensor_type.shape)
tensors[output.name] = self.handler.tensor(
dims, output.type.tensor_type.elem_type
)
tensors[output.name].set_output()
for node_idx in sorted_nodes: for node_idx in sorted_nodes:
node = model.graph.node[node_idx] node = model.graph.node[node_idx]
@ -947,6 +942,25 @@ class OnnxStub:
tensors.get(node.output[0]), tensors.get(node.output[0]),
) )
elif node.op_type == "Where": elif node.op_type == "Where":
## If Y is single -inf, treat Where as Add
## TODO: deal with cases where Y is single inf or 0
if node.input[0] in data and node.input[2] in data:
where_condition = to_array(data[node.input[0]])
where_alt = to_array(data[node.input[2]])
if where_alt.size == 1:
if np.isneginf(where_alt) or np.all(where_alt < -3e38):
node.input[0] = node.input[0] + "_alt"
if node.input[0] not in data:
where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
data[node.input[0]] = from_array(where_value, node.input[0])
tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
tensors[node.input[0]].set_weight()
tensors[node.output[0]] = self.handler.add(
tensors[node.input[1]],
tensors[node.input[0]],
tensors.get(node.output[0]),
)
continue
tensors[node.output[0]] = self.handler.where( tensors[node.output[0]] = self.handler.where(
tensors[node.input[1]], tensors[node.input[1]],
tensors[node.input[2]], tensors[node.input[2]],
@ -980,6 +994,8 @@ class OnnxStub:
else: else:
raise Exception('Unsupported operator "{}"'.format(node.op_type)) raise Exception('Unsupported operator "{}"'.format(node.op_type))
for output in model.graph.output:
tensors[output.name].set_output()
################################ ################################
# Allocate memory space for data # Allocate memory space for data
################################ ################################

View File

@ -115,6 +115,20 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims(); auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims(); auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims(); auto c_dim = op->getOutput()->getDims();
const int dType = _op->getDType().getIndex();
// Use optimized kernel if b is constant
if (b_dim.size() == 0) {
if (op->getOpType() == OpType::Div) {
div_const_kernel(dType, aData, bData, cData,
op->getOutput()->size());
return;
} else if (op->getOpType() == OpType::Pow) {
pow_const_kernel(dType, aData, bData, cData,
op->getOutput()->size());
return;
}
}
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4) if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
IT_TODO_HALT(); IT_TODO_HALT();
@ -127,7 +141,6 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
const int dType = _op->getDType().getIndex();
if (op->getOpType() == OpType::Div) { if (op->getOpType() == OpType::Div) {
div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0], div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3], c[0], c[1], c[2], c[3]); b[1], b[2], b[3], c[0], c[1], c[2], c[3]);

View File

@ -132,8 +132,8 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
#define CASE(OP, T) \ #define CASE(OP, T) \
_##OP##_kernel<DT_CUDA<T>::t> \ _##OP##_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \ <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
#define SWITCH_DTYPE(OP, DTYPE) \ #define SWITCH_DTYPE(OP, DTYPE) \
switch (DTYPE) { \ switch (DTYPE) { \
@ -177,7 +177,92 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
IT_TODO_HALT(); \ IT_TODO_HALT(); \
} }
template <class T>
__global__ void _div_const_kernel(void const *__restrict__ x,
void const *__restrict__ y,
void *__restrict__ z, const size_t n) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
((T *)z)[tid] = ((T *)x)[tid] / *((T *)y);
}
}
template <class T>
__global__ void _pow_const_kernel(void const *__restrict__ x,
void const *__restrict__ y,
void *__restrict__ z, const size_t n) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
((T *)z)[tid] = pow(((T *)x)[tid], *((T *)y));
}
}
template <>
__global__ void _pow_const_kernel<half>(void const *__restrict__ x,
void const *__restrict__ y,
void *__restrict__ z, const size_t n) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
((half *)z)[tid] = pow(((float)((half *)x)[tid]), *((half *)y));
}
}
#define CASE_CONST(OP, T) \
_##OP##_const_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(a, b, c, \
n);
#define SWITCH_DTYPE_CONST(OP, DTYPE) \
switch (DTYPE) { \
case 1: \
CASE_CONST(OP, 1) \
break; \
case 2: \
CASE_CONST(OP, 2) \
break; \
case 3: \
CASE_CONST(OP, 3) \
break; \
case 4: \
CASE_CONST(OP, 4) \
break; \
case 5: \
CASE_CONST(OP, 5) \
break; \
case 6: \
CASE_CONST(OP, 6) \
break; \
case 7: \
CASE_CONST(OP, 7) \
break; \
case 10: \
CASE_CONST(OP, 10) \
break; \
case 11: \
CASE_CONST(OP, 11) \
break; \
case 12: \
CASE_CONST(OP, 12) \
break; \
case 13: \
CASE_CONST(OP, 13) \
break; \
default: \
IT_TODO_HALT(); \
}
namespace infini { namespace infini {
void div_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
size_t blocksize = block_work_size();
size_t gridsize = (n + block_work_size() - 1) / block_work_size();
SWITCH_DTYPE_CONST(div, dType);
}
void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
size_t blocksize = block_work_size();
size_t gridsize = (n + block_work_size() - 1) / block_work_size();
SWITCH_DTYPE_CONST(pow, dType);
}
void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) { int c3) {
@ -204,12 +289,12 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int gridsize = (num + block_work_size() - 1) / block_work_size(); int gridsize = (num + block_work_size() - 1) / block_work_size();
if (dType == 1) { if (dType == 1) {
_pow_kernel<float> _pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 3) { } else if (dType == 3) {
_pow_kernel<int8_t> _pow_kernel<int8_t>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 10) { } else if (dType == 10) {
int a_size = a0 * a1 * a2 * a3; int a_size = a0 * a1 * a2 * a3;
int b_size = b0 * b1 * b2 * b3; int b_size = b0 * b1 * b2 * b3;
@ -224,9 +309,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
b_float[i] = __half2float(((half *)b)[i]); b_float[i] = __half2float(((half *)b)[i]);
} }
_pow_kernel<float> _pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> <<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
(a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0, a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3,
b1, b2, b3, c0, c1, c2, c3); b0, b1, b2, b3, c0, c1, c2, c3);
for (int i = 0; i < c_size; ++i) { for (int i = 0; i < c_size; ++i) {
((half *)c)[i] = __float2half(c_float[i]); ((half *)c)[i] = __float2half(c_float[i]);
} }

View File

@ -39,6 +39,14 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
} }
} }
template <class T>
static __global__ void _expandRowKernel(void *__restrict__ dst,
void const *__restrict__ src) {
auto da = gridDim.x, db = blockDim.y, dx = blockDim.x, n = blockIdx.y,
a = blockIdx.x, b = threadIdx.y, x = threadIdx.x;
auto i = ((n * da + a) * db + b) * dx + x, j = (a * db + b) * dx + x;
reinterpret_cast<T *>(dst)[i] = reinterpret_cast<T const *>(src)[j];
}
namespace infini { namespace infini {
#define CASE(T) \ #define CASE(T) \
@ -96,4 +104,67 @@ void expandKernel(int dType, void *input, void *output, int nDims,
SWITCH_DTYPE(dType) SWITCH_DTYPE(dType)
} }
#define CASE_ROW(T) \
_expandRowKernel<float> \
<<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);
#define SWITCH_DTYPE_ROW(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE_ROW(1) \
break; \
case 2: \
CASE_ROW(2) \
break; \
case 3: \
CASE_ROW(3) \
break; \
case 4: \
CASE_ROW(4) \
break; \
case 5: \
CASE_ROW(5) \
break; \
case 6: \
CASE_ROW(6) \
break; \
case 7: \
CASE_ROW(7) \
break; \
case 10: \
CASE_ROW(10) \
break; \
case 11: \
CASE_ROW(11) \
break; \
case 12: \
CASE_ROW(12) \
break; \
case 13: \
CASE_ROW(13) \
break; \
case 16: \
CASE_ROW(16) \
break; \
default: \
IT_TODO_HALT(); \
}
// Optimization for expanding a row vector. The row length must be a multiple of 32
void expandRowKernel(int dType, void *input, void *output, int n_rows,
int row_len) {
// Factorize row_len: row_len = a x b x 32 (32 is the warp size), b<=32
// input: 1 x (a x b x 32 x sizeT)
// output: n_rows x (a x b x 32 x sizeT)
// grid: n_rows x a
// block: b x 32
auto c = row_len / 32, b = c;
if (b > 32) {
for (b = 32; c % b != 0; --b);
}
auto a = c / b;
dim3 grid(a, n_rows), block(32, b);
SWITCH_DTYPE_ROW(dType)
}
} // namespace infini } // namespace infini

View File

@ -102,9 +102,19 @@ class matmulCublas : public Kernel {
inputShape.data[i] = inC->getDims()[i - offset]; inputShape.data[i] = inC->getDims()[i - offset];
} }
const int dType = dataType.getIndex(); const int dType = dataType.getIndex();
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims, outputsize, // Bias in linear layer is row vector of (1,n), n is the number of
inputShape, outputShape); // features. If row vector and n % 32 == 0, use optimized kernel.
if (inC->getRank() == 1 && inC->getDims()[0] % 32 == 0) {
expandRowKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(),
out->size() / inC->getDims()[0],
inC->getDims()[0]);
} else {
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims, outputsize,
inputShape, outputShape);
}
} }
// TODO:use compute type // TODO:use compute type
cublasStatus_t stat; cublasStatus_t stat;