From 2412c25e67bab976d8ac6379bf3286319ecffb63 Mon Sep 17 00:00:00 2001 From: PanZezhong1725 <141193946+PanZezhong1725@users.noreply.github.com> Date: Fri, 1 Sep 2023 11:20:26 +0800 Subject: [PATCH] Issue 107: Add copyin Numpy and covertion to Numpy (#126) * Add copyin_numpy and to_numpy for pybind TensorObj * fix copyin size assertion * fix size calculation for scalar (rank = 0) tensor * Use pybind buffer instead of returning array * fix format --- include/core/tensor.h | 17 +++--- pyinfinitensor/src/pyinfinitensor/onnx.py | 40 ++++++++------ pyinfinitensor/tests/test_onnx.py | 50 ++++++++++++++++- src/core/tensor.cc | 6 +-- src/ffi/ffi_infinitensor.cc | 66 ++++++++++++++++++++++- test/core/test_graph.cc | 8 +-- 6 files changed, 151 insertions(+), 36 deletions(-) diff --git a/include/core/tensor.h b/include/core/tensor.h index fe0b536e..03e1b20c 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -19,14 +19,6 @@ class TensorObj : public TensorBaseObj { size_t _size; // Cache of Π(shape). Fuid fuid; // Cloned tensors share the same id. Tensors constructed from // scratch have a new id. - - void copyin(const void *ptr, size_t size) { - runtime->copyBlobFromCPU(getRawDataPtr(), ptr, size); - } - void copyout(void *ptr, size_t size) const { - runtime->copyBlobToCPU(ptr, getRawDataPtr(), size); - } - public: TensorObj(Shape shape, DataType dtype, Runtime runtime); virtual ~TensorObj() {} @@ -45,10 +37,17 @@ class TensorObj : public TensorBaseObj { void load(std::string file_path); void save(std::string file_path); + void copyin(const void *ptr, size_t size) { + runtime->copyBlobFromCPU(getRawDataPtr(), ptr, size); + } + void copyout(void *ptr, size_t size) const { + runtime->copyBlobToCPU(ptr, getRawDataPtr(), size); + } + // Copy elements from `data`. template void copyin(const vector &data) { IT_ASSERT(DataType::get() == dtype.cpuTypeInt()); - IT_ASSERT(data.size() >= _size); + IT_ASSERT(data.size() == _size); copyin(data.data(), getBytes()); } // Copy all the elements to a vector. diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index 17cdb8fe..9fba35c4 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -616,8 +616,14 @@ class OnnxStub: # update the node_list node_list = list(set(node_name) - set(new_node_name)) + ################################ + # Allocate memory space for data + ################################ self.handler.data_malloc() + ################################# + # Copy in data to tensor objects + ################################# for name, obj in tensors.items(): tensor = data.get(name) if tensor == None: @@ -625,22 +631,24 @@ class OnnxStub: self.inputs[name] = obj else: self.initializer[obj.fuid()] = tensor - if tensor.data_type == TensorProto.INT32: - obj.copyin_int32(_parse_data(tensor)) - elif tensor.data_type == TensorProto.INT64: - obj.copyin_int64(_parse_data(tensor)) - elif tensor.data_type == TensorProto.FLOAT: - obj.copyin_float(_parse_data(tensor)) - elif tensor.data_type == TensorProto.BOOL: - obj.copyin_int8(_parse_data(tensor)) - elif tensor.data_type == TensorProto.FLOAT16: - obj.copyin_float16(_parse_data_fp16(tensor)) - elif tensor.data_type == TensorProto.INT8: - obj.copyin_uint8(_parse_data(tensor)) - elif tensor.data_type == TensorProto.BFLOAT16: - obj.copyin_float16(_parse_data_fp16(tensor)) - else: - assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) + # TODO: delete these lines after copyin_numpy is stable + # if tensor.data_type == TensorProto.INT32: + # obj.copyin_int32(_parse_data(tensor)) + # elif tensor.data_type == TensorProto.INT64: + # obj.copyin_int64(_parse_data(tensor)) + # elif tensor.data_type == TensorProto.FLOAT: + # obj.copyin_float(_parse_data(tensor)) + # elif tensor.data_type == TensorProto.BOOL: + # obj.copyin_int8(_parse_data(tensor)) + # elif tensor.data_type == TensorProto.FLOAT16: + # obj.copyin_float16(_parse_data_fp16(tensor)) + # elif tensor.data_type == TensorProto.INT8: + # obj.copyin_uint8(_parse_data(tensor)) + # elif tensor.data_type == TensorProto.BFLOAT16: + # obj.copyin_float16(_parse_data_fp16(tensor)) + # else: + # assert False, "Unsupported Tensor Type: {}".format(tensor.data_type) + obj.copyin_numpy(to_array(tensor)) for output in model.graph.output: self.outputs[output.name] = tensors[output.name] diff --git a/pyinfinitensor/tests/test_onnx.py b/pyinfinitensor/tests/test_onnx.py index 28134dc4..3fdb5f06 100644 --- a/pyinfinitensor/tests/test_onnx.py +++ b/pyinfinitensor/tests/test_onnx.py @@ -9,7 +9,8 @@ from onnx.helper import ( ) from onnx.checker import check_model, check_graph from onnx.shape_inference import infer_shapes -from pyinfinitensor.onnx import from_onnx, OnnxStub, backend +from pyinfinitensor.onnx import from_onnx, OnnxStub, backend, _parse_data_fp16 +import numpy as np def make_and_import_model(graph: onnx.GraphProto): @@ -382,6 +383,53 @@ class TestStringMethods(unittest.TestCase): where = make_node("Where", ["x", "y", "con"], ["output"], name="where") make_and_import_model(make_graph([where], "where", [x, y, con], [output])) + def test_copyin(self): + dims = [2,3,5,4] + np_array = np.random.random(dims).astype(np.float32) + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.FLOAT) + tensor2 = handler.tensor(dims, TensorProto.FLOAT) + handler.data_malloc() + tensor1.copyin_numpy(np_array) + tensor2.copyin_float(np_array.flatten().tolist()) + array1 = tensor1.copyout_float() + array2 = tensor2.copyout_float() + self.assertEqual(array1, array2) + self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array)) + + np_array = np.random.random(dims).astype(np.int64) + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.INT64) + tensor2 = handler.tensor(dims, TensorProto.INT64) + handler.data_malloc() + tensor1.copyin_numpy(np_array) + tensor2.copyin_int64(np_array.flatten().tolist()) + array1 = tensor1.copyout_int64() + array2 = tensor2.copyout_int64() + self.assertEqual(array1, array2) + self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array)) + + def test_to_numpy(self): + dims = [2,3,5,4] + np_array = np.random.random(dims).astype(np.float32) + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.FLOAT) + tensor2 = handler.tensor(dims, TensorProto.FLOAT) + handler.data_malloc() + tensor1.copyin_float(np_array.flatten().tolist()) + tensor2.copyin_float(np_array.flatten().tolist()) + array1 = np.array(tensor1.copyout_float()).reshape(dims) + array2 = np.array(tensor2) + self.assertTrue(np.array_equal(array2, np_array)) + self.assertTrue(np.array_equal(array1, array2)) + + np_array = np.random.random(dims).astype(np.float16) + handler = backend.GraphHandler(backend.cpu_runtime()) + tensor1 = handler.tensor(dims, TensorProto.FLOAT16) + handler.data_malloc() + tensor1.copyin_numpy(np_array) + array1 = np.array(tensor1, copy=False) + self.assertTrue(np.array_equal(array1, np_array)) if __name__ == "__main__": unittest.main() diff --git a/src/core/tensor.cc b/src/core/tensor.cc index f52127e6..2d786f7b 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -10,10 +10,8 @@ namespace infini { TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime) : TensorBaseObj(shape_.size(), dtype, runtime), shape(std::move(shape_)), - _size(shape.empty() - ? 0 - : std::accumulate(shape.begin(), shape.end(), 1, - [](auto acc, auto x) { return acc * x; })) {} + _size(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{})) { +} string TensorObj::toString() const { // Convert data pointer to string diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index d62e57f6..efe047da 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -1,3 +1,4 @@ +#include "core/data_type.h" #include "core/graph_handler.h" #include "operators/batch_norm.h" #include "operators/concat.h" @@ -13,8 +14,9 @@ #include "operators/transpose.h" #include "operators/unary.h" #include +#include +#include #include - #ifdef USE_CUDA #include "cuda/cuda_runtime.h" #include "cuda/operator_timer.h" @@ -315,7 +317,8 @@ void init_graph_builder(py::module &m) { py::class_, RuntimeObj>( m, "BangRuntime"); #endif - py::class_>(m, "Tensor") + py::class_>(m, "Tensor", + py::buffer_protocol()) .def("fuid", &TensorObj::getFuid, policy::automatic) .def("shape", &TensorObj::getDims, policy::move) .def("copyin_float", &TensorObj::copyin, policy::move) @@ -330,6 +333,65 @@ void init_graph_builder(py::module &m) { .def("copyout_int8", &TensorObj::copyout, policy::move) .def("copyout_uint8", &TensorObj::copyout, policy::move) .def("copyout_float16", &TensorObj::copyout, policy::move) + // Copy data from a Numpy array + .def("copyin_numpy", + [](TensorObj &self, py::buffer buf) { + py::buffer_info buf_info = buf.request(); + void *data_np = buf_info.ptr; + size_t itemsize = buf_info.itemsize; + size_t size = buf_info.size; + IT_ASSERT(itemsize == self.getDType().getSize()); + IT_ASSERT(size == self.size()); + for (size_t i = 0; i < self.getRank(); i++) { + IT_ASSERT(self.getDims()[i] == buf_info.shape[i]); + } + self.copyin(data_np, self.getBytes()); + }) + // A buffer can be used to convert a TensorObj directly to Numpy array + // without copy + .def_buffer([](TensorObj &self) -> py::buffer_info { + vector stride_byte; + for (int s : self.getStride()) { + stride_byte.push_back(s * self.getDType().getSize()); + } + + std::string format; + if (self.getDType() == DataType::Float32) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::Double) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::Int32) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::UInt32) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::Int64) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::UInt64) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::Int16) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::UInt16) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::Int8) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::UInt8) { + format = py::format_descriptor::format(); + } else if (self.getDType() == DataType::Float16 || + self.getDType() == DataType::BFloat16) { + // Python uses "e" for half precision float type code. + // Check the following link for more information. + // https://docs.python.org/3/library/struct.html#format-characters + format = "e"; + } else { + throw std::runtime_error("Error converting TensorObj to " + "Numpy: unsupported datatype.\n"); + } + + return py::buffer_info(self.getRawDataPtr(), + self.getDType().getSize(), format, + self.getRank(), self.getDims(), stride_byte, + true); // Read-only = true + }) .def("has_target", &TensorObj::hasTarget, policy::automatic) .def("src", &TensorObj::getSource, policy::move) .def("printData", &TensorObj::printData, policy::automatic); diff --git a/test/core/test_graph.cc b/test/core/test_graph.cc index c2b1ff4c..28c8f917 100644 --- a/test/core/test_graph.cc +++ b/test/core/test_graph.cc @@ -15,7 +15,7 @@ TEST(Graph, build_and_run) { Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); g->dataMalloc(); - i0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + i0->copyin(vector{1, 2, 3, 4, 5, 6}); w0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto matmul = g->addOpWithOutputs(i0, w0, o0); g->print(); @@ -84,7 +84,7 @@ TEST(Graph, perf_engine) { auto matmul = g->addOp(i0, w0, nullptr); g->dataMalloc(); - i0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + i0->copyin(vector{1, 2, 3, 4, 5, 6}); w0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); runtime->run(g, true, true); double perfTime = runtime->getPerfTime(g); @@ -105,7 +105,7 @@ TEST(Graph, test_tensor_id) { Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); g->dataMalloc(); - i0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + i0->copyin(vector{1, 2, 3, 4, 5, 6}); w0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto i1 = g->addTensor(i0->clone()); auto matmul = g->addOpWithOutputs(i0, w0, o0); @@ -123,7 +123,7 @@ TEST(Graph, test_OpVec_ctor) { Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32); Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32); g->dataMalloc(); - i0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + i0->copyin(vector{1, 2, 3, 4, 5, 6}); w0->copyin(vector{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); auto o1 = g->addTensor(o0->clone()); auto matmul = g->addOpWithOutputs(i0, w0, o0);