From 7f6aec6c17cc64a1129301519c73abe2d3de2f69 Mon Sep 17 00:00:00 2001 From: PanZezhong1725 <141193946+PanZezhong1725@users.noreply.github.com> Date: Mon, 1 Apr 2024 14:04:28 +0800 Subject: [PATCH] =?UTF-8?q?=E9=92=88=E5=AF=B9bert=E5=92=8Cgpt2=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=88=86=E5=B8=83=E5=BC=8F=E6=8E=A8=E7=90=86=E7=9A=84?= =?UTF-8?q?=E4=BC=98=E5=8C=96=20(#221)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(dist): 改善分布式脚本,只打印绝对误差 * feat(dist): 增加可导出onnx的pytorch运行脚本 * feat(front): 增加对Y值为-inf的where算子的图优化 * feat(kernel): 对b为常数的pow和div算子进行特判优化 * fix(front): 消除前端对global output形状信息的依赖,分布式脚本删除不必要的shape infer * feat(kernel): 针对matmul中bias为行向量时的expand操作的特化优化 * fix(kernel): 删除div pow const中不必要的同步 * Update expand.cu * fix: fix comments --------- Co-authored-by: Haojie Wang Co-authored-by: Derui Yang --- examples/distributed/README.md | 17 +++ examples/distributed/cuda_launch.py | 8 +- examples/distributed/parallel_opt.py | 2 +- examples/distributed/run_pytorch.py | 178 ++++++++++++++++++++++ include/cuda/cuda_element_wise.h | 4 + include/cuda/cuda_expand.h | 2 + pyinfinitensor/src/pyinfinitensor/onnx.py | 30 +++- src/kernels/cuda/element_wise.cc | 15 +- src/kernels/cuda/element_wise.cu | 103 +++++++++++-- src/kernels/cuda/expand.cu | 71 +++++++++ src/kernels/cuda/matmul.cc | 16 +- 11 files changed, 421 insertions(+), 25 deletions(-) create mode 100644 examples/distributed/README.md create mode 100644 examples/distributed/run_pytorch.py diff --git a/examples/distributed/README.md b/examples/distributed/README.md new file mode 100644 index 00000000..62601d93 --- /dev/null +++ b/examples/distributed/README.md @@ -0,0 +1,17 @@ +# 分布式脚本 + +#### 1. 运行pytorch模型并生成输入和标准输出,可选择导出onnx + +使用 `--export_onnx` 设置导出onnx的目录,默认为当前路径 `./`,不使用这个flag则只进行计算和生成输入输出。 + +```bash +python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./ +``` + +会在当前目录下生成输入输出文件`test_inputs.npy` 和 `test_results.npy`,目前只支持单一输入输出。 + +#### 2. 运行InfiniTensor分布式脚本 + +```bash +python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4 +``` diff --git a/examples/distributed/cuda_launch.py b/examples/distributed/cuda_launch.py index 58f7efb3..0f48598a 100644 --- a/examples/distributed/cuda_launch.py +++ b/examples/distributed/cuda_launch.py @@ -47,7 +47,7 @@ def parse_args(): def run_model(model, runtime, inputs, n=10): stub = OnnxStub(model, runtime) - for tensor, input in zip(stub.inputs.values(), inputs): + for tensor, input in zip(stub.inputs.values(), inputs, strict=False): tensor.copyin_numpy(input) # stub.tune() stub.run() @@ -55,7 +55,7 @@ def run_model(model, runtime, inputs, n=10): outputs = next(stub.outputs.values().__iter__()).copyout_numpy() # bench - for tensor, input in zip(stub.inputs.values(), inputs): + for tensor, input in zip(stub.inputs.values(), inputs, strict=False): tensor.copyin_numpy(input) begin = time.time() for _ in range(n): @@ -72,7 +72,7 @@ def run_and_compare(name, model, runtime): results = np.load(f"{name}_results.npy") outputs = run_model(model, runtime, (input_ids, position_ids)) print("outputs abs mean:", abs(outputs).mean()) - np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3) + print("max abs diff:", abs(outputs - results).max()) def start_worker( @@ -89,7 +89,7 @@ def start_worker( save_as_external_data=True, location=extern_path, ) - infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") + #infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") runtime = backend.CudaRuntime(local_rank) # print("init comm") runtime.init_comm( diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index bbb0ac65..804e48c6 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -244,5 +244,5 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): if tt.HasField("shape"): tt.ClearField("shape") model = helper.make_model(graph) - model = onnx.shape_inference.infer_shapes(model) + #model = onnx.shape_inference.infer_shapes(model) return model diff --git a/examples/distributed/run_pytorch.py b/examples/distributed/run_pytorch.py new file mode 100644 index 00000000..ea2af234 --- /dev/null +++ b/examples/distributed/run_pytorch.py @@ -0,0 +1,178 @@ +import argparse +import torch +from transformers import BertModel, BertConfig +from transformers import GPT2Model, GPT2Config +from transformers import OPTModel, OPTConfig +import time +import numpy as np +import onnx +import os +from onnx.external_data_helper import convert_model_to_external_data +from onnxsim import simplify + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False +def parse_args(): + parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.") + parser.add_argument( + "--model", type=str, choices=["gpt2", "bert", "opt"], required=True, help="model type" + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--export_onnx", + type=str, + nargs="?", + default=None, + const="./", + help="whether and where to export onnx file", + ) + args = parser.parse_args() + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.model, + args.batch_size, + args.length, + args.export_onnx + ) + + +def get_model(modelname): + match modelname: + case "bert": + model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini + voc_size = BertConfig().vocab_size + case "gpt2": + model = GPT2Model.from_pretrained("gpt2") + voc_size = GPT2Config().vocab_size + case "opt": + model = model = OPTModel.from_pretrained("./opt-125m") + voc_size = OPTConfig().vocab_size + case _: + raise KeyError(modelname) + + model = model.eval() + return model, voc_size + +def run_pytorch(torch_model, voc_size, batchsize, len): + data = np.random.randint(0, voc_size, (batchsize, len), dtype=np.int32) + np.save("test_inputs", data) + inputs = torch.from_numpy(data).to("cuda") + torch_model = torch_model.to("cuda") + + n_iter = 20 + with torch.no_grad(): + for _ in range(10): + outputs = torch_model(inputs) + torch.cuda.synchronize() + begin = time.time() + with torch.no_grad(): + for _ in range(n_iter): + torch.cuda.synchronize() + outputs = torch_model(inputs) + # + torch.cuda.synchronize() + torch.cuda.synchronize() + end = time.time() + + avg_time = (end - begin) / n_iter + outputs = outputs.last_hidden_state.to("cpu") + print("outputs abs mean:", abs(np.array(outputs)).mean()) + print(f"average time: {avg_time}") + torch.cuda.memory.empty_cache() + np.save("test_results", np.array(outputs)) + print("Save input & output as test_inputs.npy and test_results.npy") + + +def export_onnx(model, data, path, extern=False): + torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True) + onnx_model = onnx.load(path) + onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer']) + #onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer']) + assert check + add_value_info_for_constants(onnx_model) + onnx_model = onnx.shape_inference.infer_shapes(onnx_model) + if extern: + extern_path = path.replace('.onnx', '.pb') + if os.path.exists(extern_path): + os.remove(extern_path) + convert_model_to_external_data( + onnx_model, + all_tensors_to_one_file=True, + location=extern_path, + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(onnx_model, path) + +def add_value_info_for_constants(model : onnx.ModelProto): + """ + Currently onnx.shape_inference doesn't use the shape of initializers, so add + that info explicitly as ValueInfoProtos. + Mutates the model. + Args: + model: The ModelProto to update. + """ + # All (top-level) constants will have ValueInfos before IRv4 as they are all inputs + if model.ir_version < 4: + return + + def add_const_value_infos_to_graph(graph : onnx.GraphProto): + inputs = {i.name for i in graph.input} + existing_info = {vi.name: vi for vi in graph.value_info} + for init in graph.initializer: + # Check it really is a constant, not an input + if init.name in inputs: + continue + + # The details we want to add + elem_type = init.data_type + shape = init.dims + + # Get existing or create new value info for this constant + vi = existing_info.get(init.name) + if vi is None: + vi = graph.value_info.add() + vi.name = init.name + + # Even though it would be weird, we will not overwrite info even if it doesn't match + tt = vi.type.tensor_type + if tt.elem_type == onnx.TensorProto.UNDEFINED: + tt.elem_type = elem_type + if not tt.HasField("shape"): + # Ensure we set an empty list if the const is scalar (zero dims) + tt.shape.dim.extend([]) + for dim in shape: + tt.shape.dim.add().dim_value = dim + + # Handle subgraphs + for node in graph.node: + for attr in node.attribute: + # Ref attrs refer to other attrs, so we don't need to do anything + if attr.ref_attr_name != "": + continue + + if attr.type == onnx.AttributeProto.GRAPH: + add_const_value_infos_to_graph(attr.g) + if attr.type == onnx.AttributeProto.GRAPHS: + for g in attr.graphs: + add_const_value_infos_to_graph(g) + + + return add_const_value_infos_to_graph(model.graph) + + +def main(): + modelname, batchsize, seqlen, export_path = parse_args() + model, voc_size = get_model(modelname) + if export_path is not None: + filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen) + path = os.path.join(export_path, filename) + param = torch.zeros((batchsize, seqlen), dtype=torch.int) + export_onnx(model, param, path, True) + + run_pytorch(model, voc_size, batchsize, seqlen) + +if __name__ == "__main__": + main() diff --git a/include/cuda/cuda_element_wise.h b/include/cuda/cuda_element_wise.h index 10bb1bca..b4a5d6ac 100644 --- a/include/cuda/cuda_element_wise.h +++ b/include/cuda/cuda_element_wise.h @@ -13,4 +13,8 @@ void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1, void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3); + +void div_const_kernel(int dType, void *a, void *b, void *c, size_t n); + +void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n); }; // namespace infini diff --git a/include/cuda/cuda_expand.h b/include/cuda/cuda_expand.h index 3723a8e7..4001df41 100644 --- a/include/cuda/cuda_expand.h +++ b/include/cuda/cuda_expand.h @@ -7,4 +7,6 @@ void expandKernel(int dType, void *input, void *output, int nDims, int outputsize, SmallArray inputShape, SmallArray outputShape); +void expandRowKernel(int dType, void *input, void *output, int n_rows, + int row_len); }; // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index f47dcd0a..fc1e0bbc 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -23,12 +23,13 @@ from onnx.checker import ( ValidationError, ) from onnx.shape_inference import infer_shapes -from onnx.numpy_helper import to_array +from onnx.numpy_helper import to_array, from_array from typing import Dict, List, Any, Tuple, Sequence, Union, Optional from functools import reduce from onnxsim import simplify import copy import warnings +import numpy as np class OnnxStub: @@ -111,12 +112,6 @@ class OnnxStub: ) tensors[input.name].set_input() - for output in model.graph.output: - dims = _take_shape_dim(output.type.tensor_type.shape) - tensors[output.name] = self.handler.tensor( - dims, output.type.tensor_type.elem_type - ) - tensors[output.name].set_output() for node_idx in sorted_nodes: node = model.graph.node[node_idx] @@ -947,6 +942,25 @@ class OnnxStub: tensors.get(node.output[0]), ) elif node.op_type == "Where": + ## If Y is single -inf, treat Where as Add + ## TODO: deal with cases where Y is single inf or 0 + if node.input[0] in data and node.input[2] in data: + where_condition = to_array(data[node.input[0]]) + where_alt = to_array(data[node.input[2]]) + if where_alt.size == 1: + if np.isneginf(where_alt) or np.all(where_alt < -3e38): + node.input[0] = node.input[0] + "_alt" + if node.input[0] not in data: + where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype) + data[node.input[0]] = from_array(where_value, node.input[0]) + tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type) + tensors[node.input[0]].set_weight() + tensors[node.output[0]] = self.handler.add( + tensors[node.input[1]], + tensors[node.input[0]], + tensors.get(node.output[0]), + ) + continue tensors[node.output[0]] = self.handler.where( tensors[node.input[1]], tensors[node.input[2]], @@ -980,6 +994,8 @@ class OnnxStub: else: raise Exception('Unsupported operator "{}"'.format(node.op_type)) + for output in model.graph.output: + tensors[output.name].set_output() ################################ # Allocate memory space for data ################################ diff --git a/src/kernels/cuda/element_wise.cc b/src/kernels/cuda/element_wise.cc index 4a16de29..dea552b9 100644 --- a/src/kernels/cuda/element_wise.cc +++ b/src/kernels/cuda/element_wise.cc @@ -115,6 +115,20 @@ class ElementWiseCuda : public CudaKernelWithoutConfig { auto a_dim = op->getInputs(0)->getDims(); auto b_dim = op->getInputs(1)->getDims(); auto c_dim = op->getOutput()->getDims(); + const int dType = _op->getDType().getIndex(); + + // Use optimized kernel if b is constant + if (b_dim.size() == 0) { + if (op->getOpType() == OpType::Div) { + div_const_kernel(dType, aData, bData, cData, + op->getOutput()->size()); + return; + } else if (op->getOpType() == OpType::Pow) { + pow_const_kernel(dType, aData, bData, cData, + op->getOutput()->size()); + return; + } + } if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4) IT_TODO_HALT(); @@ -127,7 +141,6 @@ class ElementWiseCuda : public CudaKernelWithoutConfig { std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size())); std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size())); - const int dType = _op->getDType().getIndex(); if (op->getOpType() == OpType::Div) { div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0], b[1], b[2], b[3], c[0], c[1], c[2], c[3]); diff --git a/src/kernels/cuda/element_wise.cu b/src/kernels/cuda/element_wise.cu index e1b68699..a729452e 100644 --- a/src/kernels/cuda/element_wise.cu +++ b/src/kernels/cuda/element_wise.cu @@ -132,8 +132,8 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2, #define CASE(OP, T) \ _##OP##_kernel::t> \ - <<>> \ - (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); + <<>>( \ + a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); #define SWITCH_DTYPE(OP, DTYPE) \ switch (DTYPE) { \ @@ -177,7 +177,92 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2, IT_TODO_HALT(); \ } +template +__global__ void _div_const_kernel(void const *__restrict__ x, + void const *__restrict__ y, + void *__restrict__ z, const size_t n) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) { + ((T *)z)[tid] = ((T *)x)[tid] / *((T *)y); + } +} + +template +__global__ void _pow_const_kernel(void const *__restrict__ x, + void const *__restrict__ y, + void *__restrict__ z, const size_t n) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) { + ((T *)z)[tid] = pow(((T *)x)[tid], *((T *)y)); + } +} +template <> +__global__ void _pow_const_kernel(void const *__restrict__ x, + void const *__restrict__ y, + void *__restrict__ z, const size_t n) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < n) { + ((half *)z)[tid] = pow(((float)((half *)x)[tid]), *((half *)y)); + } +} + +#define CASE_CONST(OP, T) \ + _##OP##_const_kernel::t> \ + <<>>(a, b, c, \ + n); + +#define SWITCH_DTYPE_CONST(OP, DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE_CONST(OP, 1) \ + break; \ + case 2: \ + CASE_CONST(OP, 2) \ + break; \ + case 3: \ + CASE_CONST(OP, 3) \ + break; \ + case 4: \ + CASE_CONST(OP, 4) \ + break; \ + case 5: \ + CASE_CONST(OP, 5) \ + break; \ + case 6: \ + CASE_CONST(OP, 6) \ + break; \ + case 7: \ + CASE_CONST(OP, 7) \ + break; \ + case 10: \ + CASE_CONST(OP, 10) \ + break; \ + case 11: \ + CASE_CONST(OP, 11) \ + break; \ + case 12: \ + CASE_CONST(OP, 12) \ + break; \ + case 13: \ + CASE_CONST(OP, 13) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } + namespace infini { +void div_const_kernel(int dType, void *a, void *b, void *c, size_t n) { + size_t blocksize = block_work_size(); + size_t gridsize = (n + block_work_size() - 1) / block_work_size(); + SWITCH_DTYPE_CONST(div, dType); +} + +void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n) { + size_t blocksize = block_work_size(); + size_t gridsize = (n + block_work_size() - 1) / block_work_size(); + SWITCH_DTYPE_CONST(pow, dType); +} + void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2, int c3) { @@ -204,12 +289,12 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, int gridsize = (num + block_work_size() - 1) / block_work_size(); if (dType == 1) { _pow_kernel - <<>> - (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); + <<>>( + a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); } else if (dType == 3) { _pow_kernel - <<>> - (a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); + <<>>( + a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3); } else if (dType == 10) { int a_size = a0 * a1 * a2 * a3; int b_size = b0 * b1 * b2 * b3; @@ -224,9 +309,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2, b_float[i] = __half2float(((half *)b)[i]); } _pow_kernel - <<>> - (a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0, - b1, b2, b3, c0, c1, c2, c3); + <<>>( + a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, + b0, b1, b2, b3, c0, c1, c2, c3); for (int i = 0; i < c_size; ++i) { ((half *)c)[i] = __float2half(c_float[i]); } diff --git a/src/kernels/cuda/expand.cu b/src/kernels/cuda/expand.cu index 5e22be44..3fbf929e 100644 --- a/src/kernels/cuda/expand.cu +++ b/src/kernels/cuda/expand.cu @@ -39,6 +39,14 @@ __global__ void _expandKernel(void *input, void *output, int nDims, } } +template +static __global__ void _expandRowKernel(void *__restrict__ dst, + void const *__restrict__ src) { + auto da = gridDim.x, db = blockDim.y, dx = blockDim.x, n = blockIdx.y, + a = blockIdx.x, b = threadIdx.y, x = threadIdx.x; + auto i = ((n * da + a) * db + b) * dx + x, j = (a * db + b) * dx + x; + reinterpret_cast(dst)[i] = reinterpret_cast(src)[j]; +} namespace infini { #define CASE(T) \ @@ -96,4 +104,67 @@ void expandKernel(int dType, void *input, void *output, int nDims, SWITCH_DTYPE(dType) } +#define CASE_ROW(T) \ + _expandRowKernel \ + <<>>(output, input); + +#define SWITCH_DTYPE_ROW(DTYPE) \ + switch (DTYPE) { \ + case 1: \ + CASE_ROW(1) \ + break; \ + case 2: \ + CASE_ROW(2) \ + break; \ + case 3: \ + CASE_ROW(3) \ + break; \ + case 4: \ + CASE_ROW(4) \ + break; \ + case 5: \ + CASE_ROW(5) \ + break; \ + case 6: \ + CASE_ROW(6) \ + break; \ + case 7: \ + CASE_ROW(7) \ + break; \ + case 10: \ + CASE_ROW(10) \ + break; \ + case 11: \ + CASE_ROW(11) \ + break; \ + case 12: \ + CASE_ROW(12) \ + break; \ + case 13: \ + CASE_ROW(13) \ + break; \ + case 16: \ + CASE_ROW(16) \ + break; \ + default: \ + IT_TODO_HALT(); \ + } + +// Optimization for expanding a row vector. The row length must be a multiple of 32 +void expandRowKernel(int dType, void *input, void *output, int n_rows, + int row_len) { + // Factorize row_len: row_len = a x b x 32 (32 is the warp size), b<=32 + // input: 1 x (a x b x 32 x sizeT) + // output: n_rows x (a x b x 32 x sizeT) + // grid: n_rows x a + // block: b x 32 + auto c = row_len / 32, b = c; + if (b > 32) { + for (b = 32; c % b != 0; --b); + } + auto a = c / b; + dim3 grid(a, n_rows), block(32, b); + SWITCH_DTYPE_ROW(dType) +} + } // namespace infini diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index 771cadb6..de2c646e 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -102,9 +102,19 @@ class matmulCublas : public Kernel { inputShape.data[i] = inC->getDims()[i - offset]; } const int dType = dataType.getIndex(); - expandKernel(dType, inC->getRawDataPtr(), - out->getRawDataPtr(), nDims, outputsize, - inputShape, outputShape); + + // Bias in linear layer is row vector of (1,n), n is the number of + // features. If row vector and n % 32 == 0, use optimized kernel. + if (inC->getRank() == 1 && inC->getDims()[0] % 32 == 0) { + expandRowKernel(dType, inC->getRawDataPtr(), + out->getRawDataPtr(), + out->size() / inC->getDims()[0], + inC->getDims()[0]); + } else { + expandKernel(dType, inC->getRawDataPtr(), + out->getRawDataPtr(), nDims, outputsize, + inputShape, outputShape); + } } // TODO:use compute type cublasStatus_t stat;