From 6dce129cb3fdc20578506528085fe10e0eee7e6b Mon Sep 17 00:00:00 2001 From: YdrMaster Date: Fri, 24 Feb 2023 09:22:03 +0800 Subject: [PATCH] fix: TensorObj::dataMalloc Signed-off-by: YdrMaster --- .github/workflows/build.yml | 4 ++-- pyinfinitensor/tests/test_onnx.py | 10 +++++----- src/core/tensor.cc | 11 ++--------- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 879ae847..7dae8509 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -66,5 +66,5 @@ jobs: - name: Download test model run: wget ${{ env.resnet-download }} - # - name: Test onnx frontend - # run: make test-onnx + - name: Test onnx frontend + run: make test-onnx diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 8be02f11..a529fb67 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -8,13 +8,13 @@ from onnx.helper import ( make_tensor_value_info, ) from onnx.checker import check_model -from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx +from pyinfinitensor.onnx import from_onnx, backend, to_onnx, cpu_runtime def make_and_import_model(graph: onnx.GraphProto): model = make_model(graph) check_model(model) - from_onnx(model) + from_onnx(model, cpu_runtime) class TestStringMethods(unittest.TestCase): @@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase): file=model_file, size=os.path.getsize(model_file) / 1024 / 1024 ) ) - from_onnx(onnx.load(model_file)) + from_onnx(onnx.load(model_file), cpu_runtime) def test_tensor(self): x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) @@ -289,10 +289,10 @@ class TestStringMethods(unittest.TestCase): graph = make_graph([matmul, add], "lr", [x, a, b], [y]) model = make_model(graph) check_model(model) - from_onnx(model) + from_onnx(model, cpu_runtime) def test_frontend(self): - handler = backend.GraphHandler(runtime) + handler = backend.GraphHandler(cpu_runtime) a = handler.tensor([1, 2, 3], 12) b = handler.tensor([1, 2, 3], 12) c = handler.tensor([1, 2, 3], 12) diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 5050e53e..89294a8e 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -148,15 +148,8 @@ bool TensorObj::equalData(const Tensor &rhs) const { } void TensorObj::dataMalloc() { - if (data != nullptr) - return; - // IT_ASSERT(data == nullptr); - size_t bytesPerElement; - if (getDType() == DataType::Float32) - bytesPerElement = sizeof(float); - else if (getDType() == DataType::UInt32) - bytesPerElement = sizeof(uint32_t); - data = runtime->allocBlob(size() * bytesPerElement); + if (data == nullptr) + data = runtime->allocBlob(getBytes()); } void TensorObj::copyData(const TensorObj *src) {