Issue 107: Add copyin from Numpy and conversion to Numpy (#126)

* Add copyin_numpy and to_numpy for pybind TensorObj

* fix copyin size assertion

* fix size calculation for scalar (rank = 0) tensor

* Use pybind buffer instead of returning array

* fix format
PanZezhong1725 2023-09-01 11:20:26 +08:00 committed by GitHub
parent 3e6ef305f1
commit 2412c25e67
6 changed files with 151 additions and 36 deletions


@@ -19,14 +19,6 @@ class TensorObj : public TensorBaseObj {
     size_t _size; // Cache of Π(shape).
     Fuid fuid;    // Cloned tensors share the same id. Tensors constructed from
                   // scratch have a new id.
-
-    void copyin(const void *ptr, size_t size) {
-        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), ptr, size);
-    }
-    void copyout(void *ptr, size_t size) const {
-        runtime->copyBlobToCPU(ptr, getRawDataPtr<void *>(), size);
-    }
-
   public:
     TensorObj(Shape shape, DataType dtype, Runtime runtime);
     virtual ~TensorObj() {}
@@ -45,10 +37,17 @@ class TensorObj : public TensorBaseObj {
     void load(std::string file_path);
     void save(std::string file_path);
 
+    void copyin(const void *ptr, size_t size) {
+        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), ptr, size);
+    }
+    void copyout(void *ptr, size_t size) const {
+        runtime->copyBlobToCPU(ptr, getRawDataPtr<void *>(), size);
+    }
+
     // Copy elements from `data`.
     template <typename T> void copyin(const vector<T> &data) {
         IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
-        IT_ASSERT(data.size() >= _size);
+        IT_ASSERT(data.size() == _size);
         copyin(data.data(), getBytes());
     }
     // Copy all the elements to a vector.
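The tightened assertion means copyin now rejects any payload whose element count does not match the tensor exactly, instead of silently accepting oversized buffers. A minimal sketch of the new contract, using only bindings exercised in the tests later in this diff (GraphHandler, tensor, data_malloc, copyin_float):

    from onnx import TensorProto
    from pyinfinitensor.onnx import backend

    handler = backend.GraphHandler(backend.cpu_runtime())
    t = handler.tensor([2, 3], TensorProto.FLOAT)   # 6 elements
    handler.data_malloc()
    t.copyin_float([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # exactly _size values: accepted
    # A 12-element list would now trip IT_ASSERT(data.size() == _size), which is
    # why the C++ graph tests at the end of this diff shrink i0's payload to 6.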


@@ -616,8 +616,14 @@ class OnnxStub:
         # update the node_list
         node_list = list(set(node_name) - set(new_node_name))
 
+        ################################
+        # Allocate memory space for data
+        ################################
         self.handler.data_malloc()
+
+        #################################
+        # Copy in data to tensor objects
+        #################################
         for name, obj in tensors.items():
             tensor = data.get(name)
             if tensor == None:
@@ -625,22 +631,24 @@ class OnnxStub:
                 self.inputs[name] = obj
             else:
                 self.initializer[obj.fuid()] = tensor
-                if tensor.data_type == TensorProto.INT32:
-                    obj.copyin_int32(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.INT64:
-                    obj.copyin_int64(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.FLOAT:
-                    obj.copyin_float(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.BOOL:
-                    obj.copyin_int8(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.FLOAT16:
-                    obj.copyin_float16(_parse_data_fp16(tensor))
-                elif tensor.data_type == TensorProto.INT8:
-                    obj.copyin_uint8(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.BFLOAT16:
-                    obj.copyin_float16(_parse_data_fp16(tensor))
-                else:
-                    assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
+                # TODO: delete these lines after copyin_numpy is stable
+                # if tensor.data_type == TensorProto.INT32:
+                #     obj.copyin_int32(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.INT64:
+                #     obj.copyin_int64(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.FLOAT:
+                #     obj.copyin_float(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.BOOL:
+                #     obj.copyin_int8(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.FLOAT16:
+                #     obj.copyin_float16(_parse_data_fp16(tensor))
+                # elif tensor.data_type == TensorProto.INT8:
+                #     obj.copyin_uint8(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.BFLOAT16:
+                #     obj.copyin_float16(_parse_data_fp16(tensor))
+                # else:
+                #     assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
+                obj.copyin_numpy(to_array(tensor))
 
         for output in model.graph.output:
             self.outputs[output.name] = tensors[output.name]
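Because onnx.numpy_helper.to_array decodes an initializer into an ndarray with the correct dtype and shape, a single copyin_numpy call replaces the whole per-dtype dispatch. A small standalone sketch of the same round trip (the initializer proto here is fabricated for illustration; in OnnxStub it comes from the model itself):

    import numpy as np
    from onnx import numpy_helper

    # Build a throwaway initializer proto; OnnxStub reads these from
    # model.graph.initializer instead.
    proto = numpy_helper.from_array(
        np.arange(6, dtype=np.float32).reshape(2, 3), name="w"
    )
    arr = numpy_helper.to_array(proto)  # dtype (float32) and shape (2, 3) recovered
    # obj.copyin_numpy(arr)             # obj: the Tensor the handler built for "w"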


@@ -9,7 +9,8 @@ from onnx.helper import (
 )
 from onnx.checker import check_model, check_graph
 from onnx.shape_inference import infer_shapes
-from pyinfinitensor.onnx import from_onnx, OnnxStub, backend
+from pyinfinitensor.onnx import from_onnx, OnnxStub, backend, _parse_data_fp16
+import numpy as np
 
 
 def make_and_import_model(graph: onnx.GraphProto):
@@ -382,6 +383,53 @@
         where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
         make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
 
+    def test_copyin(self):
+        dims = [2, 3, 5, 4]
+        np_array = np.random.random(dims).astype(np.float32)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
+        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
+        handler.data_malloc()
+        tensor1.copyin_numpy(np_array)
+        tensor2.copyin_float(np_array.flatten().tolist())
+        array1 = tensor1.copyout_float()
+        array2 = tensor2.copyout_float()
+        self.assertEqual(array1, array2)
+        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
+
+        np_array = np.random.random(dims).astype(np.int64)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.INT64)
+        tensor2 = handler.tensor(dims, TensorProto.INT64)
+        handler.data_malloc()
+        tensor1.copyin_numpy(np_array)
+        tensor2.copyin_int64(np_array.flatten().tolist())
+        array1 = tensor1.copyout_int64()
+        array2 = tensor2.copyout_int64()
+        self.assertEqual(array1, array2)
+        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
+
+    def test_to_numpy(self):
+        dims = [2, 3, 5, 4]
+        np_array = np.random.random(dims).astype(np.float32)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
+        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
+        handler.data_malloc()
+        tensor1.copyin_float(np_array.flatten().tolist())
+        tensor2.copyin_float(np_array.flatten().tolist())
+        array1 = np.array(tensor1.copyout_float()).reshape(dims)
+        array2 = np.array(tensor2)
+        self.assertTrue(np.array_equal(array2, np_array))
+        self.assertTrue(np.array_equal(array1, array2))
+
+        np_array = np.random.random(dims).astype(np.float16)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.FLOAT16)
+        handler.data_malloc()
+        tensor1.copyin_numpy(np_array)
+        array1 = np.array(tensor1, copy=False)
+        self.assertTrue(np.array_equal(array1, np_array))
+
 
 if __name__ == "__main__":
     unittest.main()


@@ -10,10 +10,8 @@ namespace infini {
 
 TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime)
     : TensorBaseObj(shape_.size(), dtype, runtime), shape(std::move(shape_)),
-      _size(shape.empty()
-                ? 0
-                : std::accumulate(shape.begin(), shape.end(), 1,
-                                  [](auto acc, auto x) { return acc * x; })) {}
+      _size(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{})) {
+}
 
 string TensorObj::toString() const {
     // Convert data pointer to string
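This constructor change is the scalar fix from the commit message: folding the shape with an initial value of 1 yields 1 for an empty shape, whereas the old shape.empty() branch returned 0 and made a rank-0 tensor occupy no bytes. The same empty-product convention, sketched in Python:

    import math

    # A fold with initial value 1 over an empty sequence is 1, matching
    # std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{}).
    assert math.prod([]) == 1              # rank-0 (scalar) tensor: one element
    assert math.prod([2, 3, 5, 4]) == 120  # the dims used in the tests above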


@@ -1,3 +1,4 @@
+#include "core/data_type.h"
 #include "core/graph_handler.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"
@@ -13,8 +14,9 @@
 #include "operators/transpose.h"
 #include "operators/unary.h"
 #include <algorithm>
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 #ifdef USE_CUDA
 #include "cuda/cuda_runtime.h"
 #include "cuda/operator_timer.h"
@@ -315,7 +317,8 @@ void init_graph_builder(py::module &m) {
     py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
         m, "BangRuntime");
 #endif
-    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
+    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor",
+                                                      py::buffer_protocol())
         .def("fuid", &TensorObj::getFuid, policy::automatic)
         .def("shape", &TensorObj::getDims, policy::move)
         .def("copyin_float", &TensorObj::copyin<float>, policy::move)
@@ -330,6 +333,65 @@ void init_graph_builder(py::module &m) {
         .def("copyout_int8", &TensorObj::copyout<int8_t>, policy::move)
         .def("copyout_uint8", &TensorObj::copyout<uint8_t>, policy::move)
         .def("copyout_float16", &TensorObj::copyout<uint16_t>, policy::move)
+        // Copy data in from a Numpy array
+        .def("copyin_numpy",
+             [](TensorObj &self, py::buffer buf) {
+                 py::buffer_info buf_info = buf.request();
+                 void *data_np = buf_info.ptr;
+                 size_t itemsize = buf_info.itemsize;
+                 size_t size = buf_info.size;
+                 IT_ASSERT(itemsize == self.getDType().getSize());
+                 IT_ASSERT(size == self.size());
+                 for (size_t i = 0; i < self.getRank(); i++) {
+                     IT_ASSERT(self.getDims()[i] == buf_info.shape[i]);
+                 }
+                 self.copyin(data_np, self.getBytes());
+             })
+        // The buffer protocol lets a TensorObj be converted directly to a
+        // Numpy array without copying
+        .def_buffer([](TensorObj &self) -> py::buffer_info {
+            vector<size_t> stride_byte;
+            for (int s : self.getStride()) {
+                stride_byte.push_back(s * self.getDType().getSize());
+            }
+            std::string format;
+            if (self.getDType() == DataType::Float32) {
+                format = py::format_descriptor<float>::format();
+            } else if (self.getDType() == DataType::Double) {
+                format = py::format_descriptor<double>::format();
+            } else if (self.getDType() == DataType::Int32) {
+                format = py::format_descriptor<int>::format();
+            } else if (self.getDType() == DataType::UInt32) {
+                format = py::format_descriptor<uint32_t>::format();
+            } else if (self.getDType() == DataType::Int64) {
+                format = py::format_descriptor<int64_t>::format();
+            } else if (self.getDType() == DataType::UInt64) {
+                format = py::format_descriptor<uint64_t>::format();
+            } else if (self.getDType() == DataType::Int16) {
+                format = py::format_descriptor<int16_t>::format();
+            } else if (self.getDType() == DataType::UInt16) {
+                format = py::format_descriptor<uint16_t>::format();
+            } else if (self.getDType() == DataType::Int8) {
+                format = py::format_descriptor<int8_t>::format();
+            } else if (self.getDType() == DataType::UInt8) {
+                format = py::format_descriptor<uint8_t>::format();
+            } else if (self.getDType() == DataType::Float16 ||
+                       self.getDType() == DataType::BFloat16) {
+                // Python uses "e" as the type code for half-precision floats.
+                // See the following link for more information:
+                // https://docs.python.org/3/library/struct.html#format-characters
+                format = "e";
+            } else {
+                throw std::runtime_error("Error converting TensorObj to "
+                                         "Numpy: unsupported datatype.\n");
+            }
+            return py::buffer_info(self.getRawDataPtr<void *>(),
+                                   self.getDType().getSize(), format,
+                                   self.getRank(), self.getDims(), stride_byte,
+                                   true); // Read-only = true
+        })
         .def("has_target", &TensorObj::hasTarget, policy::automatic)
         .def("src", &TensorObj::getSource, policy::move)
         .def("printData", &TensorObj::printData, policy::automatic);


@@ -15,7 +15,7 @@ TEST(Graph, build_and_run) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
     g->print();
@@ -84,7 +84,7 @@ TEST(Graph, perf_engine) {
     auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     runtime->run(g, true, true);
     double perfTime = runtime->getPerfTime(g);
@@ -105,7 +105,7 @@ TEST(Graph, test_tensor_id) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto i1 = g->addTensor(i0->clone());
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
@@ -123,7 +123,7 @@ TEST(Graph, test_OpVec_ctor) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto o1 = g->addTensor(o0->clone());
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);