fix: TensorObj::dataMalloc

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-02-24 09:22:03 +08:00
parent dc79b72655
commit 6dce129cb3
3 changed files with 9 additions and 16 deletions

View File

@ -66,5 +66,5 @@ jobs:
- name: Download test model - name: Download test model
run: wget ${{ env.resnet-download }} run: wget ${{ env.resnet-download }}
# - name: Test onnx frontend - name: Test onnx frontend
# run: make test-onnx run: make test-onnx

View File

@ -8,13 +8,13 @@ from onnx.helper import (
make_tensor_value_info, make_tensor_value_info,
) )
from onnx.checker import check_model from onnx.checker import check_model
from pyinfinitensor.onnx import from_onnx, backend, runtime, to_onnx from pyinfinitensor.onnx import from_onnx, backend, to_onnx, cpu_runtime
def make_and_import_model(graph: onnx.GraphProto): def make_and_import_model(graph: onnx.GraphProto):
model = make_model(graph) model = make_model(graph)
check_model(model) check_model(model)
from_onnx(model) from_onnx(model, cpu_runtime)
class TestStringMethods(unittest.TestCase): class TestStringMethods(unittest.TestCase):
@ -28,7 +28,7 @@ class TestStringMethods(unittest.TestCase):
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024 file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
) )
) )
from_onnx(onnx.load(model_file)) from_onnx(onnx.load(model_file), cpu_runtime)
def test_tensor(self): def test_tensor(self):
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3]) x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
@ -289,10 +289,10 @@ class TestStringMethods(unittest.TestCase):
graph = make_graph([matmul, add], "lr", [x, a, b], [y]) graph = make_graph([matmul, add], "lr", [x, a, b], [y])
model = make_model(graph) model = make_model(graph)
check_model(model) check_model(model)
from_onnx(model) from_onnx(model, cpu_runtime)
def test_frontend(self): def test_frontend(self):
handler = backend.GraphHandler(runtime) handler = backend.GraphHandler(cpu_runtime)
a = handler.tensor([1, 2, 3], 12) a = handler.tensor([1, 2, 3], 12)
b = handler.tensor([1, 2, 3], 12) b = handler.tensor([1, 2, 3], 12)
c = handler.tensor([1, 2, 3], 12) c = handler.tensor([1, 2, 3], 12)

View File

@ -148,15 +148,8 @@ bool TensorObj::equalData(const Tensor &rhs) const {
} }
void TensorObj::dataMalloc() { void TensorObj::dataMalloc() {
if (data != nullptr) if (data == nullptr)
return; data = runtime->allocBlob(getBytes());
// IT_ASSERT(data == nullptr);
size_t bytesPerElement;
if (getDType() == DataType::Float32)
bytesPerElement = sizeof(float);
else if (getDType() == DataType::UInt32)
bytesPerElement = sizeof(uint32_t);
data = runtime->allocBlob(size() * bytesPerElement);
} }
void TensorObj::copyData(const TensorObj *src) { void TensorObj::copyData(const TensorObj *src) {