forked from jiuyuan/InfiniTensor
accelerate cuda fp32 matmul
parent 0740d26f43 · commit 4bdd33522b
```diff
@@ -23,12 +23,13 @@ from onnx.checker import (
     ValidationError,
 )
 from onnx.shape_inference import infer_shapes
-from onnx.numpy_helper import to_array
+from onnx.numpy_helper import to_array, from_array
 from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
 from functools import reduce
 from onnxsim import simplify
 import copy
 import warnings
+import numpy as np


 class OnnxStub:
```
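The two new imports back the MatMul rewrite in the next hunk: `from_array` re-encodes a NumPy array as an ONNX tensor after the weight has been transposed. A minimal round-trip sketch (the tensor name `W` is illustrative):

```python
import numpy as np
from onnx.numpy_helper import to_array, from_array

# Round-trip an ONNX initializer through NumPy: decode, transpose,
# re-encode. This is the exact pattern the MatMul rewrite below uses.
weight = from_array(np.arange(6, dtype=np.float32).reshape(2, 3), name="W")
transposed = from_array(np.transpose(to_array(weight)), name="W")
assert to_array(transposed).shape == (3, 2)
```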
```diff
@@ -207,15 +208,36 @@ class OnnxStub:
                     op[1],
                 )
             elif node.op_type == "MatMul":
-                tensors[node.output[0]] = self.handler.matmul(
-                    tensors[node.input[0]],
-                    tensors[node.input[1]],
-                    tensors.get(node.output[0]),
-                    False,
-                    False,
-                    None,
-                    backend.ActType.Linear,
-                )
+                if node.input[1] in data.keys() \
+                    and to_array(data[node.input[1]]).dtype == np.float32 \
+                    and type(runtime) == backend.CudaRuntime \
+                    and tensors[node.input[0]].shape()[0] == 1 \
+                    and tensors[node.input[0]].shape()[1] == 1 \
+                    and len(tensors[node.input[1]].shape()) == 2:
+                    data[node.input[1]] = from_array(
+                        np.transpose(to_array(data[node.input[1]])))
+                    tensors[node.input[1]] = self.handler.tensor(
+                        [tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
+                        tensors[node.input[1]].dtype())
+                    tensors[node.output[0]] = self.handler.matmul(
+                        tensors[node.input[0]],
+                        tensors[node.input[1]],
+                        tensors.get(node.output[0]),
+                        False,
+                        True,
+                        None,
+                        backend.ActType.Linear,
+                    )
+                else:
+                    tensors[node.output[0]] = self.handler.matmul(
+                        tensors[node.input[0]],
+                        tensors[node.input[1]],
+                        tensors.get(node.output[0]),
+                        False,
+                        False,
+                        None,
+                        backend.ActType.Linear,
+                    )
             elif node.op_type == "Gemm":
                 attributes = _parse_attribute(
                     node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
```

Note the membership test `node.input[1] in data.keys()` is ordered first so that short-circuiting prevents a `KeyError` from `data[node.input[1]]` when the second input is not a constant initializer.
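The rewrite targets the gemv-like case: the activation's leading dimensions are 1 × 1, the weight is a 2-D fp32 initializer, and the runtime is CUDA. It transposes the weight once at load time in host memory and flips the fifth `matmul` argument from `False` to `True` (evidently transB, judging by the paired transpose), presumably so the backend kernel reads each output's dot product from a contiguous row of the stored weight. The math is unchanged, as this sketch checks (array shapes are illustrative):

```python
import numpy as np

# Sanity check for the rewrite: transposing the stored weight and asking
# the backend for B^T (transB=True) reproduces the original product.
rng = np.random.default_rng(0)
A = rng.standard_normal((1, 1, 4, 128)).astype(np.float32)  # leading dims 1 x 1
B = rng.standard_normal((128, 256)).astype(np.float32)      # 2-D fp32 initializer

B_stored = np.transpose(B)      # what the frontend now writes into `data`
ref = A @ B                     # original MatMul
opt = A @ B_stored.T            # what matmul(..., transB=True) computes

assert np.allclose(ref, opt)
```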
```diff
@@ -107,13 +107,12 @@ __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
             ptr_K[i] = ptr_Q[i] * ptr_K[i];
             #pragma unroll
             for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
-                ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset);
+                ptr_K[i] += __shfl_xor_sync(0xffffffff, ptr_K[i], offset);
             }
             ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]);
         }

         // div sqrt(d)
-        ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
         ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);

         // softmax
```
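This hunk swaps the tree reduction (`__shfl_down_sync`), which leaves the warp sum only in lane 0 and therefore needed the follow-up `__shfl_sync(..., 0)` broadcast, for a butterfly reduction (`__shfl_xor_sync`), which leaves the sum in all 32 lanes; the broadcast line becomes dead and is deleted. A self-contained sketch of the pattern (kernel and variable names are illustrative, not from the source):

```cuda
#include <cstdio>

// Butterfly warp reduction: after the loop, every lane holds the full
// 32-element sum, so no broadcast from lane 0 is needed. With
// __shfl_down_sync, only lane 0 would end up with the final value.
__global__ void warp_sum_butterfly(const float* in, float* out) {
    float v = in[threadIdx.x];
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2)
        v += __shfl_xor_sync(0xffffffff, v, offset);
    out[threadIdx.x] = v;  // identical in all 32 lanes
}

int main() {
    float h_in[32], h_out[32], *d_in, *d_out;
    for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_sum_butterfly<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
    printf("%f %f\n", h_out[0], h_out[31]);  // both print 32.0
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```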