accelerate cuda fp32 matmul

xiaonans 2024-03-26 11:37:54 +08:00
parent 0740d26f43
commit 4bdd33522b
2 changed files with 33 additions and 12 deletions

View File

@@ -23,12 +23,13 @@ from onnx.checker import (
     ValidationError,
 )
 from onnx.shape_inference import infer_shapes
-from onnx.numpy_helper import to_array
+from onnx.numpy_helper import to_array, from_array
 from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
 from functools import reduce
 from onnxsim import simplify
 import copy
 import warnings
+import numpy as np
 class OnnxStub:
@@ -207,15 +208,36 @@ class OnnxStub:
                     op[1],
                 )
             elif node.op_type == "MatMul":
-                tensors[node.output[0]] = self.handler.matmul(
-                    tensors[node.input[0]],
-                    tensors[node.input[1]],
-                    tensors.get(node.output[0]),
-                    False,
-                    False,
-                    None,
-                    backend.ActType.Linear,
-                )
+                if node.input[1] in data \
+                        and to_array(data[node.input[1]]).dtype == np.float32 \
+                        and type(runtime) == backend.CudaRuntime \
+                        and tensors[node.input[0]].shape()[0] == 1 \
+                        and tensors[node.input[0]].shape()[1] == 1 \
+                        and len(tensors[node.input[1]].shape()) == 2:
+                    data[node.input[1]] = from_array(
+                        np.transpose(to_array(data[node.input[1]])))
+                    tensors[node.input[1]] = self.handler.tensor(
+                        [tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
+                        tensors[node.input[1]].dtype())
+                    tensors[node.output[0]] = self.handler.matmul(
+                        tensors[node.input[0]],
+                        tensors[node.input[1]],
+                        tensors.get(node.output[0]),
+                        False,
+                        True,
+                        None,
+                        backend.ActType.Linear,
+                    )
+                else:
+                    tensors[node.output[0]] = self.handler.matmul(
+                        tensors[node.input[0]],
+                        tensors[node.input[1]],
+                        tensors.get(node.output[0]),
+                        False,
+                        False,
+                        None,
+                        backend.ActType.Linear,
+                    )
             elif node.op_type == "Gemm":
                 attributes = _parse_attribute(
                     node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}

View File

@@ -107,13 +107,12 @@ __global__ void _attention_kvcache_kernel_128_1(T* input_k_cache,
             ptr_K[i] = ptr_Q[i] * ptr_K[i];
 #pragma unroll
             for (int offset = WARP_SIZE/2; offset > 0; offset /= 2) {
-                ptr_K[i] += __shfl_down_sync(0xffffffff, ptr_K[i], offset);
+                ptr_K[i] += __shfl_xor_sync(0xffffffff, ptr_K[i], offset);
             }
             ptr_P[idx_SEQ_UNIT] += __half2float(ptr_K[i]);
         }
         // div sqrt(d)
-        ptr_P[idx_SEQ_UNIT] = __shfl_sync(0xffffffff, ptr_P[idx_SEQ_UNIT], 0);
         ptr_P[idx_SEQ_UNIT] /= sqrt(128.0);
         // softmax
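Replacing __shfl_down_sync with __shfl_xor_sync turns the tree reduction into a butterfly: after log2(WARP_SIZE) steps every lane holds the full warp sum, not just lane 0, which is why the trailing __shfl_sync broadcast of ptr_P can be dropped. A small Python model of the two patterns (a 32-lane warp is assumed, and the shfl_down semantics follow the CUDA rule that an out-of-range source returns the caller's own value):

import numpy as np

WARP_SIZE = 32

def shfl_down_reduce(vals):
    """Tree reduction via __shfl_down_sync: only lane 0 ends with the sum."""
    vals = vals.copy()
    lanes = np.arange(WARP_SIZE)
    offset = WARP_SIZE // 2
    while offset > 0:
        src = lanes + offset
        # Out-of-range sources return the calling lane's own value (CUDA rule);
        # those upper lanes accumulate garbage, but their results are never read.
        fetched = np.where(src < WARP_SIZE, vals[np.minimum(src, WARP_SIZE - 1)], vals)
        vals = vals + fetched
        offset //= 2
    return vals

def shfl_xor_reduce(vals):
    """Butterfly reduction via __shfl_xor_sync: every lane ends with the sum."""
    vals = vals.copy()
    lanes = np.arange(WARP_SIZE)
    offset = WARP_SIZE // 2
    while offset > 0:
        vals = vals + vals[lanes ^ offset]
        offset //= 2
    return vals

lane_vals = np.arange(WARP_SIZE, dtype=np.float64)
total = lane_vals.sum()

assert shfl_down_reduce(lane_vals)[0] == total     # sum lands only in lane 0
assert np.all(shfl_xor_reduce(lane_vals) == total) # sum lands in every lane

Both variants issue the same five shuffles per warp; the saving is the broadcast instruction that the butterfly makes redundant.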