gemv2N to gemv2T

This commit is contained in:
xiaonans 2023-11-27 16:19:01 +08:00
parent 86877509c1
commit 2fb1c8cf32
2 changed files with 28 additions and 11 deletions

View File

@ -23,12 +23,13 @@ from onnx.checker import (
ValidationError,
)
from onnx.shape_inference import infer_shapes
from onnx.numpy_helper import to_array
from onnx.numpy_helper import to_array, from_array
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
from functools import reduce
from onnxsim import simplify
import copy
import warnings
import numpy as np
class OnnxStub:
@ -186,15 +187,32 @@ class OnnxStub:
op[1],
)
elif node.op_type == "MatMul":
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
False,
None,
backend.ActType.Linear,
)
if tensors[node.input[0]].shape()[0] == 1 and tensors[node.input[0]].shape()[1] == 1 \
and len(tensors[node.input[1]].shape()) == 2 and node.input[1] in data.keys():
data[node.input[1]] = from_array(
np.transpose(to_array(data[node.input[1]])))
tensors[node.input[1]] = self.handler.tensor(
[tensors[node.input[1]].shape()[1], tensors[node.input[1]].shape()[0]],
tensors[node.input[1]].dtype())
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
True,
None,
backend.ActType.Linear,
)
else:
tensors[node.output[0]] = self.handler.matmul(
tensors[node.input[0]],
tensors[node.input[1]],
tensors.get(node.output[0]),
False,
False,
None,
backend.ActType.Linear,
)
elif node.op_type == "Gemm":
attributes = _parse_attribute(
node, {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}

View File

@ -24,7 +24,6 @@ class AttentionKVCacheCompute {
Tensor output_matmul, Tensor output_temp_O, Tensor output_temp_sum) const {
AttentionKVCacheMetadata metadata;
initAttentionKVCacheMetadata(metadata, input_v_cache);
std::cout << "do compute" << std::endl;
attention_kvcache_kernel(input_k_cache->getRawDataPtr<float *>(),
input_v_cache->getRawDataPtr<float *>(),