forked from jiuyuan/InfiniTensor
gemv2N to gemv2T
This commit is contained in:
parent
86877509c1
commit
2fb1c8cf32
|
@ -23,12 +23,13 @@ from onnx.checker import (
|
|||
ValidationError,
|
||||
)
|
||||
from onnx.shape_inference import infer_shapes
|
||||
from onnx.numpy_helper import to_array
|
||||
from onnx.numpy_helper import to_array, from_array
|
||||
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
|
||||
from functools import reduce
|
||||
from onnxsim import simplify
|
||||
import copy
|
||||
import warnings
|
||||
import numpy as np
|
||||
|
||||
|
||||
class OnnxStub:
|
||||
|
@ -186,15 +187,32 @@ class OnnxStub:
|
|||
op[1],
|
||||
)
|
||||
elif node.op_type == "MatMul":
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
if tensors[node.input[0]].shape()[0] == 1 and tensors[node.input[0]].shape()[1] == 1 \
|
||||
and len(tensors[node.input[1]].shape()) == 2 and node.input[1] in data.keys():
|
||||
data[node.input[1]] = from_array(
|
||||
np.transpose(to_array(data[node.input[1]])))
|
||||
tensors[node.input[1]] = self.handler.tensor(
|
||||
[tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
|
||||
tensors[node.input[1]].dtype())
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
True,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
else:
|
||||
tensors[node.output[0]] = self.handler.matmul(
|
||||
tensors[node.input[0]],
|
||||
tensors[node.input[1]],
|
||||
tensors.get(node.output[0]),
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
backend.ActType.Linear,
|
||||
)
|
||||
elif node.op_type == "Gemm":
|
||||
attributes = _parse_attribute(
|
||||
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
|
||||
|
|
|
@ -24,7 +24,6 @@ class AttentionKVCacheCompute {
|
|||
Tensor output_matmul, Tensor output_temp_O, Tensor output_temp_sum) const {
|
||||
AttentionKVCacheMetadata metadata;
|
||||
initAttentionKVCacheMetadata(metadata, input_v_cache);
|
||||
std::cout << "do compute" << std::endl;
|
||||
|
||||
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
|
||||
input_v_cache->getRawDataPtr<float *>(),
|
||||
|
|
Loading…
Reference in New Issue