forked from jiuyuan/InfiniTensor
fix: TensorObj::dataMalloc
Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
parent
dc79b72655
commit
6dce129cb3
|
@ -66,5 +66,5 @@ jobs:
|
|||
- name: Download test model
|
||||
run: wget ${{ env.resnet-download }}
|
||||
|
||||
# - name: Test onnx frontend
|
||||
# run: make test-onnx
|
||||
- name: Test onnx frontend
|
||||
run: make test-onnx
|
||||
|
|
|
@ -8,13 +8,13 @@ from onnx.helper import (
|
|||
make_tensor_value_info,
|
||||
)
|
||||
from onnx.checker import check_model
|
||||
from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx
|
||||
from pyinfinitensor.onnx import from_onnx, backend, to_onnx, cpu_runtime
|
||||
|
||||
|
||||
def make_and_import_model(graph: onnx.GraphProto):
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
from_onnx(model, cpu_runtime)
|
||||
|
||||
|
||||
class TestStringMethods(unittest.TestCase):
|
||||
|
@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
||||
)
|
||||
)
|
||||
from_onnx(onnx.load(model_file))
|
||||
from_onnx(onnx.load(model_file), cpu_runtime)
|
||||
|
||||
def test_tensor(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
|
@ -289,10 +289,10 @@ class TestStringMethods(unittest.TestCase):
|
|||
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model)
|
||||
from_onnx(model, cpu_runtime)
|
||||
|
||||
def test_frontend(self):
|
||||
handler = backend.GraphHandler(runtime)
|
||||
handler = backend.GraphHandler(cpu_runtime)
|
||||
a = handler.tensor([1, 2, 3], 12)
|
||||
b = handler.tensor([1, 2, 3], 12)
|
||||
c = handler.tensor([1, 2, 3], 12)
|
||||
|
|
|
@ -148,15 +148,8 @@ bool TensorObj::equalData(const Tensor &rhs) const {
|
|||
}
|
||||
|
||||
void TensorObj::dataMalloc() {
|
||||
if (data != nullptr)
|
||||
return;
|
||||
// IT_ASSERT(data == nullptr);
|
||||
size_t bytesPerElement;
|
||||
if (getDType() == DataType::Float32)
|
||||
bytesPerElement = sizeof(float);
|
||||
else if (getDType() == DataType::UInt32)
|
||||
bytesPerElement = sizeof(uint32_t);
|
||||
data = runtime->allocBlob(size() * bytesPerElement);
|
||||
if (data == nullptr)
|
||||
data = runtime->allocBlob(getBytes());
|
||||
}
|
||||
|
||||
void TensorObj::copyData(const TensorObj *src) {
|
||||
|
|
Loading…
Reference in New Issue