Merge branch 'master' of github.com:InfiniTensor/InfiniTensor into kunlun_dist_op

This commit is contained in:
wanghailu 2024-04-03 01:01:40 +08:00
commit 14a40a1967
23 changed files with 675 additions and 35 deletions

View File

@ -0,0 +1,17 @@
# Distributed scripts
#### 1. Run the PyTorch model and generate inputs and reference outputs (ONNX export optional)
Use `--export_onnx` to set the directory the ONNX model is exported to (defaults to the current path `./`). Without this flag, the script only runs the computation and generates the inputs and outputs.
```bash
python run_pytorch.py --model gpt2 --batch_size 1 --length 1 --export_onnx ./
```
This generates the input/output files `test_inputs.npy` and `test_results.npy` in the current directory; currently only a single input and a single output are supported.
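For reference, a minimal sketch (not part of the scripts in this PR) of how these generated files could be loaded and checked against outputs produced by another runtime; the variable `other_outputs` and the printed metric are illustrative assumptions:
```python
import numpy as np

# Load the reference data written by run_pytorch.py.
inputs = np.load("test_inputs.npy")        # the single saved input
reference = np.load("test_results.npy")    # the single saved reference output

# `other_outputs` stands in for outputs obtained elsewhere (e.g. from the
# InfiniTensor run); here it is just the reference itself as a placeholder.
other_outputs = reference
print("max abs diff:", np.abs(other_outputs - reference).max())
```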
#### 2. Run the InfiniTensor distributed script
```bash
python cuda_launch.py --model "/XXX/XXX.onnx" --nproc_per_node 4
```

View File

@ -47,7 +47,7 @@ def parse_args():
def run_model(model, runtime, inputs, n=10):
stub = OnnxStub(model, runtime)
for tensor, input in zip(stub.inputs.values(), inputs):
for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
tensor.copyin_numpy(input)
# stub.tune()
stub.run()
@ -55,7 +55,7 @@ def run_model(model, runtime, inputs, n=10):
outputs = next(stub.outputs.values().__iter__()).copyout_numpy()
# bench
for tensor, input in zip(stub.inputs.values(), inputs):
for tensor, input in zip(stub.inputs.values(), inputs, strict=False):
tensor.copyin_numpy(input)
begin = time.time()
for _ in range(n):
@ -72,7 +72,7 @@ def run_and_compare(name, model, runtime):
results = np.load(f"{name}_results.npy")
outputs = run_model(model, runtime, (input_ids, position_ids))
print("outputs abs mean:", abs(outputs).mean())
np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3)
print("max abs diff:", abs(outputs - results).max())
def start_worker(
@ -89,7 +89,7 @@ def start_worker(
save_as_external_data=True,
location=extern_path,
)
infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
#infer_shapes_path(f"./{dist_name}_rank{rank}.onnx")
runtime = backend.CudaRuntime(local_rank)
# print("init comm")
runtime.init_comm(

View File

@ -245,7 +245,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
if tt.HasField("shape"):
tt.ClearField("shape")
model = helper.make_model(graph)
model = onnx.shape_inference.infer_shapes(model)
#model = onnx.shape_inference.infer_shapes(model)
return model
if __name__ == "__main__":

View File

@ -0,0 +1,178 @@
import argparse
import torch
from transformers import BertModel, BertConfig
from transformers import GPT2Model, GPT2Config
from transformers import OPTModel, OPTConfig
import time
import numpy as np
import onnx
import os
from onnx.external_data_helper import convert_model_to_external_data
from onnxsim import simplify
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
def parse_args():
parser = argparse.ArgumentParser(description="Run pytorch gpt2/bert/opt and optionally export onnx.")
parser.add_argument(
"--model", type=str, choices=["gpt2", "bert", "opt"], required=True, help="model type"
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--export_onnx",
type=str,
nargs="?",
default=None,
const="./",
help="whether and where to export onnx file",
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.model,
args.batch_size,
args.length,
args.export_onnx
)
def get_model(modelname):
match modelname:
case "bert":
model = BertModel.from_pretrained("bert-base-uncased", add_pooling_layer=False, hidden_act="gelu_new") # erf is not impl by infini
voc_size = BertConfig().vocab_size
case "gpt2":
model = GPT2Model.from_pretrained("gpt2")
voc_size = GPT2Config().vocab_size
case "opt":
model = OPTModel.from_pretrained("./opt-125m")
voc_size = OPTConfig().vocab_size
case _:
raise KeyError(modelname)
model = model.eval()
return model, voc_size
def run_pytorch(torch_model, voc_size, batchsize, seq_len):
data = np.random.randint(0, voc_size, (batchsize, seq_len), dtype=np.int32)
np.save("test_inputs", data)
inputs = torch.from_numpy(data).to("cuda")
torch_model = torch_model.to("cuda")
n_iter = 20
with torch.no_grad():
for _ in range(10):
outputs = torch_model(inputs)
torch.cuda.synchronize()
begin = time.time()
with torch.no_grad():
for _ in range(n_iter):
torch.cuda.synchronize()
outputs = torch_model(inputs)
#
torch.cuda.synchronize()
torch.cuda.synchronize()
end = time.time()
avg_time = (end - begin) / n_iter
outputs = outputs.last_hidden_state.to("cpu")
print("outputs abs mean:", abs(np.array(outputs)).mean())
print(f"average time: {avg_time}")
torch.cuda.memory.empty_cache()
np.save("test_results", np.array(outputs))
print("Save input & output as test_inputs.npy and test_results.npy")
def export_onnx(model, data, path, extern=False):
torch.onnx.export(model, data, path, verbose=False, do_constant_folding=True)
onnx_model = onnx.load(path)
onnx_model, check = simplify(onnx_model, skipped_optimizers=['eliminate_duplicate_initializer'])
#onnx_model, check = simplify(onnx_model, skipped_optimizers=['fuse_qkv', 'eliminate_duplicate_initializer'])
assert check
add_value_info_for_constants(onnx_model)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
if extern:
extern_path = path.replace('.onnx', '.pb')
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
onnx_model,
all_tensors_to_one_file=True,
location=extern_path,
size_threshold=1024,
convert_attribute=False,
)
onnx.save(onnx_model, path)
def add_value_info_for_constants(model : onnx.ModelProto):
"""
Currently onnx.shape_inference doesn't use the shape of initializers, so add
that info explicitly as ValueInfoProtos.
Mutates the model.
Args:
model: The ModelProto to update.
"""
# All (top-level) constants will have ValueInfos before IRv4 as they are all inputs
if model.ir_version < 4:
return
def add_const_value_infos_to_graph(graph : onnx.GraphProto):
inputs = {i.name for i in graph.input}
existing_info = {vi.name: vi for vi in graph.value_info}
for init in graph.initializer:
# Check it really is a constant, not an input
if init.name in inputs:
continue
# The details we want to add
elem_type = init.data_type
shape = init.dims
# Get existing or create new value info for this constant
vi = existing_info.get(init.name)
if vi is None:
vi = graph.value_info.add()
vi.name = init.name
# Even though it would be weird, we will not overwrite info even if it doesn't match
tt = vi.type.tensor_type
if tt.elem_type == onnx.TensorProto.UNDEFINED:
tt.elem_type = elem_type
if not tt.HasField("shape"):
# Ensure we set an empty list if the const is scalar (zero dims)
tt.shape.dim.extend([])
for dim in shape:
tt.shape.dim.add().dim_value = dim
# Handle subgraphs
for node in graph.node:
for attr in node.attribute:
# Ref attrs refer to other attrs, so we don't need to do anything
if attr.ref_attr_name != "":
continue
if attr.type == onnx.AttributeProto.GRAPH:
add_const_value_infos_to_graph(attr.g)
if attr.type == onnx.AttributeProto.GRAPHS:
for g in attr.graphs:
add_const_value_infos_to_graph(g)
return add_const_value_infos_to_graph(model.graph)
def main():
modelname, batchsize, seqlen, export_path = parse_args()
model, voc_size = get_model(modelname)
if export_path is not None:
filename = "{}_{}_{}.onnx".format(modelname, batchsize, seqlen)
path = os.path.join(export_path, filename)
param = torch.zeros((batchsize, seqlen), dtype=torch.int)
export_onnx(model, param, path, True)
run_pytorch(model, voc_size, batchsize, seqlen)
if __name__ == "__main__":
main()

View File

@ -37,6 +37,7 @@ class GraphHandlerObj {
float momentum, float eps, bool training);
Tensor layerNormalization(Tensor input, Tensor scale, Tensor output,
Tensor bias, float eps, int axis, int stash_type);
Tensor rmsNorm(Tensor input, Tensor weight, Tensor output);
Tensor maxPool(Tensor input, Tensor output, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw, int ceilMode);

View File

@ -156,8 +156,9 @@ struct OpType {
Resize,
ReverseSequence,
RoiAlign,
RoPE, // Fusion
Round, // Unary
RoPE, // Fusion
Round, // Unary
RMSNorm, // Fusion
STFT,
Scan,
Scatter,

View File

@ -13,4 +13,8 @@ void pow_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
void less_kernel(int dtypeIndex, void *a, void *b, void *c, int a0, int a1,
int a2, int a3, int b0, int b1, int b2, int b3, int c0, int c1,
int c2, int c3);
void div_const_kernel(int dType, void *a, void *b, void *c, size_t n);
void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n);
}; // namespace infini

View File

@ -7,4 +7,6 @@ void expandKernel(int dType, void *input, void *output, int nDims,
int outputsize, SmallArray inputShape,
SmallArray outputShape);
void expandRowKernel(int dType, void *input, void *output, int n_rows,
int row_len);
}; // namespace infini

View File

@ -0,0 +1,10 @@
#pragma once
#include "operators/rms_norm.h"
namespace infini {
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size);
}; // namespace infini

View File

@ -0,0 +1,34 @@
#pragma once
#include "core/operator.h"
namespace infini {
/**
* @brief Fused RMSNorm Operator
*
*/
class RMSNormObj : public OperatorObj {
int dim;
public:
/**
* @brief Construct a new RMSNorm object.
*
* @param graph The computation graph that this operator belongs to.
* @param input The input tensor.
* @param weight The weight tensor applied after normalization.
* @param output The output tensor.
*/
RMSNormObj(GraphObj *graph, Tensor input, Tensor weight, Tensor output);
OP_CLONE(RMSNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) override;
std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }
int getDim() const { return dim; }
private:
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
} // namespace infini
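For context, a minimal NumPy sketch (illustration only, not code from this PR) of the computation this fused operator performs; it mirrors the CUDA kernel later in this diff, which normalizes each token over the last axis with eps = 1e-5:
```python
import numpy as np

def rmsnorm_reference(x: np.ndarray, weight: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # x: (..., hidden_size), weight: (hidden_size,)
    # Per token: y = x / sqrt(mean(x^2) + eps) * weight
    rms = np.sqrt(np.mean(x.astype(np.float32) ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms * weight).astype(x.dtype)
```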

View File

@ -23,12 +23,13 @@ from onnx.checker import (
ValidationError,
)
from onnx.shape_inference import infer_shapes
from onnx.numpy_helper import to_array
from onnx.numpy_helper import to_array, from_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
from onnxsim import simplify
import copy
import warnings
import numpy as np
class OnnxStub:
@ -111,12 +112,6 @@ class OnnxStub:
)
tensors[input.name].set_input()
for output in model.graph.output:
dims = _take_shape_dim(output.type.tensor_type.shape)
tensors[output.name] = self.handler.tensor(
dims, output.type.tensor_type.elem_type
)
tensors[output.name].set_output()
for node_idx in sorted_nodes:
node = model.graph.node[node_idx]
@ -285,6 +280,12 @@ class OnnxStub:
axis,
stash_type,
)
elif node.op_type == "RMSNorm":
tensors[node.output[0]] = self.handler.RMSNorm(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
)
elif node.op_type == "MaxPool":
attributes = _parse_attribute(
node,
@ -941,6 +942,25 @@ class OnnxStub:
tensors.get(node.output[0]),
)
elif node.op_type == "Where":
## If the "else" value Y is a single -inf scalar, rewrite Where(cond, X, Y) as Add
## (a NumPy sketch of this equivalence follows this file's diff).
## TODO: deal with the cases where Y is a single +inf or 0
if node.input[0] in data and node.input[2] in data:
where_condition = to_array(data[node.input[0]])
where_alt = to_array(data[node.input[2]])
if where_alt.size == 1:
if np.isneginf(where_alt) or np.all(where_alt < -3e38):
node.input[0] = node.input[0] + "_alt"
if node.input[0] not in data:
where_value = np.where(where_condition, 0, -np.inf).astype(where_alt.dtype)
data[node.input[0]] = from_array(where_value, node.input[0])
tensors[node.input[0]] = self.handler.tensor(list(where_value.shape), data[node.input[0]].data_type)
tensors[node.input[0]].set_weight()
tensors[node.output[0]] = self.handler.add(
tensors[node.input[1]],
tensors[node.input[0]],
tensors.get(node.output[0]),
)
continue
tensors[node.output[0]] = self.handler.where(
tensors[node.input[1]],
tensors[node.input[2]],
@ -974,6 +994,8 @@ class OnnxStub:
else:
raise Exception('Unsupported operator "{}"'.format(node.op_type))
for output in model.graph.output:
tensors[output.name].set_output()
################################
# Allocate memory space for data
################################
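For illustration (not code from this PR), a NumPy sketch of why the Where rewrite above is valid when the else-value is a single -inf: adding a 0/-inf mask reproduces Where(cond, X, -inf) for finite X:
```python
import numpy as np

cond = np.array([[True, False], [False, True]])
x = np.random.rand(2, 2).astype(np.float32)

where_out = np.where(cond, x, -np.inf)                 # original Where(cond, X, -inf)
mask = np.where(cond, 0, -np.inf).astype(np.float32)   # constant built by the rewrite
add_out = x + mask                                     # rewritten Add(X, mask)

assert np.array_equal(where_out, add_out)
```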

View File

@ -18,6 +18,7 @@
#include "operators/reduce.h"
#include "operators/reshape.h"
#include "operators/resize.h"
#include "operators/rms_norm.h"
#include "operators/rope.h"
#include "operators/send.h"
#include "operators/slice.h"
@ -124,6 +125,17 @@ Tensor GraphHandlerObj::layerNormalization(Tensor input, Tensor scale,
}
}
Tensor GraphHandlerObj::rmsNorm(Tensor input, Tensor weight, Tensor output) {
if (output) {
g->addOpWithOutputs<RMSNormObj>(std::move(input), std::move(weight),
output);
return output;
} else {
return g->addOp<RMSNormObj>(std::move(input), std::move(weight), output)
->getOutput();
}
}
Tensor GraphHandlerObj::maxPool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int sw,
int ceilMode) {

View File

@ -506,6 +506,7 @@ void init_graph_builder(py::module &m) {
.def("matmul", &Handler::matmul, policy::move)
.def("batchNormalization", &Handler::batchNormalization, policy::move)
.def("layerNormalization", &Handler::layerNormalization, policy::move)
.def("RMSNorm", &Handler::rmsNorm, policy::move)
.def("maxPool", &Handler::maxPool, policy::move)
.def("avgPool", &Handler::avgPool, policy::move)
.def("add", &Handler::add, policy::move)

View File

@ -115,6 +115,20 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
auto a_dim = op->getInputs(0)->getDims();
auto b_dim = op->getInputs(1)->getDims();
auto c_dim = op->getOutput()->getDims();
const int dType = _op->getDType().getIndex();
// Use optimized kernel if b is constant
if (b_dim.size() == 0) {
if (op->getOpType() == OpType::Div) {
div_const_kernel(dType, aData, bData, cData,
op->getOutput()->size());
return;
} else if (op->getOpType() == OpType::Pow) {
pow_const_kernel(dType, aData, bData, cData,
op->getOutput()->size());
return;
}
}
if (a_dim.size() > 4 || b_dim.size() > 4 || c_dim.size() > 4)
IT_TODO_HALT();
@ -127,7 +141,6 @@ class ElementWiseCuda : public CudaKernelWithoutConfig {
std::copy(b_dim.begin(), b_dim.end(), b + (4 - b_dim.size()));
std::copy(c_dim.begin(), c_dim.end(), c + (4 - c_dim.size()));
const int dType = _op->getDType().getIndex();
if (op->getOpType() == OpType::Div) {
div_kernel(dType, aData, bData, cData, a[0], a[1], a[2], a[3], b[0],
b[1], b[2], b[3], c[0], c[1], c[2], c[3]);
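As an aside (illustration only, not code from this PR), the scalar fast path added above is semantically just a broadcast against a single value; a NumPy sketch of the intended behavior when the second operand holds exactly one element:
```python
import numpy as np

a = np.random.rand(4, 8).astype(np.float32)
b = np.float32(2.0)      # scalar second operand (b_dim.size() == 0)

div_const = a / b        # what div_const_kernel computes element-wise
pow_const = a ** b       # what pow_const_kernel computes element-wise

# Same result as the general broadcasting kernels, without per-element
# broadcast index arithmetic.
assert np.allclose(div_const, a / np.broadcast_to(b, a.shape))
assert np.allclose(pow_const, a ** np.broadcast_to(b, a.shape))
```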

View File

@ -132,8 +132,8 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
#define CASE(OP, T) \
_##OP##_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>( \
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
#define SWITCH_DTYPE(OP, DTYPE) \
switch (DTYPE) { \
@ -177,7 +177,92 @@ __global__ void _less_kernel(void *x, void *y, void *z, int a0, int a1, int a2,
IT_TODO_HALT(); \
}
template <class T>
__global__ void _div_const_kernel(void const *__restrict__ x,
void const *__restrict__ y,
void *__restrict__ z, const size_t n) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
((T *)z)[tid] = ((T *)x)[tid] / *((T *)y);
}
}
template <class T>
__global__ void _pow_const_kernel(void const *__restrict__ x,
void const *__restrict__ y,
void *__restrict__ z, const size_t n) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
((T *)z)[tid] = pow(((T *)x)[tid], *((T *)y));
}
}
template <>
__global__ void _pow_const_kernel<half>(void const *__restrict__ x,
void const *__restrict__ y,
void *__restrict__ z, const size_t n) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) {
((half *)z)[tid] = pow(((float)((half *)x)[tid]), *((half *)y));
}
}
#define CASE_CONST(OP, T) \
_##OP##_const_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(a, b, c, \
n);
#define SWITCH_DTYPE_CONST(OP, DTYPE) \
switch (DTYPE) { \
case 1: \
CASE_CONST(OP, 1) \
break; \
case 2: \
CASE_CONST(OP, 2) \
break; \
case 3: \
CASE_CONST(OP, 3) \
break; \
case 4: \
CASE_CONST(OP, 4) \
break; \
case 5: \
CASE_CONST(OP, 5) \
break; \
case 6: \
CASE_CONST(OP, 6) \
break; \
case 7: \
CASE_CONST(OP, 7) \
break; \
case 10: \
CASE_CONST(OP, 10) \
break; \
case 11: \
CASE_CONST(OP, 11) \
break; \
case 12: \
CASE_CONST(OP, 12) \
break; \
case 13: \
CASE_CONST(OP, 13) \
break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void div_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
size_t blocksize = block_work_size();
size_t gridsize = (n + block_work_size() - 1) / block_work_size();
SWITCH_DTYPE_CONST(div, dType);
}
void pow_const_kernel(int dType, void *a, void *b, void *c, size_t n) {
size_t blocksize = block_work_size();
size_t gridsize = (n + block_work_size() - 1) / block_work_size();
SWITCH_DTYPE_CONST(pow, dType);
}
void div_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int a3, int b0, int b1, int b2, int b3, int c0, int c1, int c2,
int c3) {
@ -204,12 +289,12 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
int gridsize = (num + block_work_size() - 1) / block_work_size();
if (dType == 1) {
_pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 3) {
_pow_kernel<int8_t>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
a, b, c, a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3);
} else if (dType == 10) {
int a_size = a0 * a1 * a2 * a3;
int b_size = b0 * b1 * b2 * b3;
@ -224,9 +309,9 @@ void pow_kernel(int dType, void *a, void *b, void *c, int a0, int a1, int a2,
b_float[i] = __half2float(((half *)b)[i]);
}
_pow_kernel<float>
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>
(a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3, b0,
b1, b2, b3, c0, c1, c2, c3);
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>>(
a_float.data(), b_float.data(), c_float.data(), a0, a1, a2, a3,
b0, b1, b2, b3, c0, c1, c2, c3);
for (int i = 0; i < c_size; ++i) {
((half *)c)[i] = __float2half(c_float[i]);
}

View File

@ -39,6 +39,14 @@ __global__ void _expandKernel(void *input, void *output, int nDims,
}
}
template <class T>
static __global__ void _expandRowKernel(void *__restrict__ dst,
void const *__restrict__ src) {
auto da = gridDim.x, db = blockDim.y, dx = blockDim.x, n = blockIdx.y,
a = blockIdx.x, b = threadIdx.y, x = threadIdx.x;
auto i = ((n * da + a) * db + b) * dx + x, j = (a * db + b) * dx + x;
reinterpret_cast<T *>(dst)[i] = reinterpret_cast<T const *>(src)[j];
}
namespace infini {
#define CASE(T) \
@ -96,4 +104,67 @@ void expandKernel(int dType, void *input, void *output, int nDims,
SWITCH_DTYPE(dType)
}
#define CASE_ROW(T) \
_expandRowKernel<DT_CUDA<T>::t> \
<<<grid, block, 0, CUDAStream::getCurrentStream()>>>(output, input);
#define SWITCH_DTYPE_ROW(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE_ROW(1) \
break; \
case 2: \
CASE_ROW(2) \
break; \
case 3: \
CASE_ROW(3) \
break; \
case 4: \
CASE_ROW(4) \
break; \
case 5: \
CASE_ROW(5) \
break; \
case 6: \
CASE_ROW(6) \
break; \
case 7: \
CASE_ROW(7) \
break; \
case 10: \
CASE_ROW(10) \
break; \
case 11: \
CASE_ROW(11) \
break; \
case 12: \
CASE_ROW(12) \
break; \
case 13: \
CASE_ROW(13) \
break; \
case 16: \
CASE_ROW(16) \
break; \
default: \
IT_TODO_HALT(); \
}
// Optimization for expanding a row vector. The row length must be a multiple of 32
void expandRowKernel(int dType, void *input, void *output, int n_rows,
int row_len) {
// Factorize row_len: row_len = a x b x 32 (32 is the warp size), b<=32
// input: 1 x (a x b x 32 x sizeT)
// output: n_rows x (a x b x 32 x sizeT)
// grid: n_rows x a
// block: b x 32
auto c = row_len / 32, b = c;
if (b > 32) {
for (b = 32; c % b != 0; --b);
}
auto a = c / b;
dim3 grid(a, n_rows), block(32, b);
SWITCH_DTYPE_ROW(dType)
}
} // namespace infini
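For illustration (not code from this PR), the launch-shape factorization used by expandRowKernel, rewritten in Python under the kernel's stated assumption that row_len is a multiple of 32:
```python
def expand_row_launch_shape(n_rows: int, row_len: int):
    # Factorize row_len = a * b * 32 with b <= 32, mirroring the C++ loop above.
    assert row_len % 32 == 0
    c = row_len // 32
    b = c
    if b > 32:
        b = 32
        while c % b != 0:
            b -= 1
    a = c // b
    return (a, n_rows), (32, b)   # dim3 grid(a, n_rows), dim3 block(32, b)

# Example: expanding a 4096-wide bias row to 8 rows.
print(expand_row_launch_shape(8, 4096))   # ((4, 8), (32, 32))
```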

View File

@ -102,9 +102,19 @@ class matmulCublas : public Kernel {
inputShape.data[i] = inC->getDims()[i - offset];
}
const int dType = dataType.getIndex();
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims, outputsize,
inputShape, outputShape);
// Bias in linear layer is row vector of (1,n), n is the number of
// features. If row vector and n % 32 == 0, use optimized kernel.
if (inC->getRank() == 1 && inC->getDims()[0] % 32 == 0) {
expandRowKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(),
out->size() / inC->getDims()[0],
inC->getDims()[0]);
} else {
expandKernel(dType, inC->getRawDataPtr<void *>(),
out->getRawDataPtr<void *>(), nDims, outputsize,
inputShape, outputShape);
}
}
// TODO:use compute type
cublasStatus_t stat;

View File

@ -0,0 +1,34 @@
#include "operators/rms_norm.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_rmsnorm.h"
#include "cuda/cuda_runtime.h"
namespace infini {
class RMSNormCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<RMSNormObj>(_op);
auto input = op->getInputs(0);
auto weight = op->getInputs(1);
auto output = op->getOutput();
void *const inputData = input->getRawDataPtr<void *>();
void *const weightData = weight->getRawDataPtr<void *>();
void *const outputData = output->getRawDataPtr<void *>();
const auto &inputShape = input->getDims();
int nDims = input->getDims().size();
int hidden_size = inputShape[nDims - 1];
int num_tokens = input->size() / hidden_size;
IT_ASSERT(hidden_size == (int)weight->size());
const int dType = op->getDType().getIndex();
rmsnorm_kernel(dType, inputData, weightData, outputData, num_tokens,
hidden_size);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::RMSNorm, RMSNormCuda, "RMSNorm_CUDA");
} // namespace infini

View File

@ -0,0 +1,112 @@
#include "core/common.h"
#include "cuda/cuda_common.h"
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
template<class T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(uint32_t(-1), val, mask);
return val;
}
/* Calculate the sum of all elements in a block */
template<class T>
__inline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
// Use blockDim.x / 32.f instead of blockDim.x >> 5 so that a partial final
// warp is still counted when blockDim.x is not a multiple of 32.
val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
val = warpReduceSum<T>(val);
return val;
}
template <class T>
__global__ void _rmsnorm_kernel(void *in, void *weight, void *out, int num_tokens, int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
const float x = ((T*) in)[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
if(threadIdx.x == 0){
s_variance = rsqrtf(variance / hidden_size + 0.00001f);
}
__syncthreads();
for(int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x){
float x = ((T*) in)[blockIdx.x * hidden_size + idx];
((T*)out)[blockIdx.x * hidden_size + idx] = ((T)(x * s_variance)) * ((T*)weight)[idx];
}
}
#define CASE(T) \
_rmsnorm_kernel<DT_CUDA<T>::t> \
<<<gridsize, blocksize, 0, CUDAStream::getCurrentStream()>>> \
(input, weight, output, num_tokens, hidden_size);
#define SWITCH_DTYPE(DTYPE) \
switch (DTYPE) { \
case 1: \
CASE(1) \
break; \
case 2: \
CASE(2) \
break; \
case 3: \
CASE(3) \
break; \
case 4: \
CASE(4) \
break; \
case 5: \
CASE(5) \
break; \
case 6: \
CASE(6) \
break; \
case 7: \
CASE(7) \
break; \
case 10: \
CASE(10) \
break; \
case 11: \
CASE(11) \
break; \
case 12: \
CASE(12) \
break; \
case 13: \
CASE(13) \
break; \
case 16: \
CASE(16) \
break; \
default: \
IT_TODO_HALT(); \
}
namespace infini {
void rmsnorm_kernel(int dType, void *input, void *weight, void *output,
int num_tokens, int hidden_size) {
dim3 blocksize = dim3(std::min(hidden_size, 1024));
dim3 gridsize = dim3(num_tokens);
SWITCH_DTYPE(dType)
}
} // namespace infini

View File

@ -22,7 +22,7 @@ class RoPECuda : public CudaKernelWithoutConfig {
IT_ASSERT(nDims == 3 && pos->getDims().size() == 2);
IT_ASSERT(inputShape[1] == pos->getDims()[1]);
int dim_model = inputShape[2];
int dim_head = dim_model / 32;
int dim_head = 128;
int hidden_stride = dim_model * inputShape[1];
int pos_stride = inputShape[1];

View File

@ -3,11 +3,6 @@
#include "cuda/cuda_utility.h"
#include "utils/small_array.h"
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
// gridDim (batch, seq_len, dim_model / 1024), blockDim (1024, 1, 1)
template <class T>
__global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_model,
int dim_head, int hidden_stride, int pos_stride) {
@ -86,8 +81,8 @@ __global__ void _rope_kernel(int* pos, void *in, void *out, int size, int dim_mo
namespace infini {
void rope_kernel(int dType, int * pos, void *input, void *output, int size,
int dim_model, int dim_head, int hidden_stride, int pos_stride) {
dim3 blocksize = dim3(1024,1,1);
dim3 gridsize = dim3(1, 1, 4);
dim3 blocksize = dim3(32,1,1);
dim3 gridsize = dim3(1, 1, dim_model/32);
SWITCH_DTYPE(dType)
}

View File

@ -315,6 +315,8 @@ void unary_kernel(const Operator &_op) {
} else if (op->getOpType() == OpType::Silu) {
if (_op->getDType() == DataType::Float32) {
silu_kernel<float>((float *)inputData, (float *)outputData, num);
} else if (_op->getDType() == DataType::Float16){
silu_kernel<half>((half *)inputData, (half *)outputData, num);
} else {
IT_TODO_HALT();
}

src/operators/rms_norm.cc
View File

@ -0,0 +1,36 @@
#include "operators/rms_norm.h"
namespace infini {
RMSNormObj::RMSNormObj(GraphObj *graph, Tensor input, Tensor weight,
Tensor output)
: OperatorObj(OpType::RMSNorm, {input, weight}, {output}) {
IT_ASSERT(checkValid(graph));
}
optional<vector<Shape>> RMSNormObj::inferShape(const TensorVec &inputs) {
const auto A = inputs[0];
auto input_dim = A->getDims();
auto output_dim = input_dim;
return {{output_dim}};
}
std::string RMSNormObj::toString() const {
std::ostringstream os;
os << type.toString() << "[" << getGuid() << "]";
os << "(";
os << vecToString(inputs[0]->getDims()) << ",";
os << "input=" << inputs[0]->getGuid() << ",";
os << "output=" << outputs[0]->getGuid() << ")";
return os.str();
}
vector<int> RMSNormObj::getWorkloadVector() const {
vector<int> ret{type.underlying()};
const Shape shape = outputs[0]->getDims();
ret.insert(ret.end(), shape.begin(), shape.end());
return ret;
}
vector<int> RMSNormObj::getOpAttrVector() const { return {type.underlying()}; }
}; // namespace infini