forked from jiuyuan/InfiniTensor
Issue 107: Add copyin Numpy and conversion to Numpy (#126)
* Add copyin_numpy and to_numpy for pybind TensorObj
* fix copyin size assertion
* fix size calculation for scalar (rank = 0) tensor
* Use pybind buffer instead of returning array
* fix format
parent 3e6ef305f1
commit 2412c25e67
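In short: a Tensor exposed to Python now accepts any buffer-protocol object in copyin_numpy, and is itself a buffer provider, so Numpy can read it back directly. A minimal usage sketch, assuming a CPU runtime; the GraphHandler, tensor, and data_malloc names are the ones exercised by the tests in this commit (the shape and values here are illustrative):

    import numpy as np
    from onnx import TensorProto
    from pyinfinitensor.onnx import backend

    handler = backend.GraphHandler(backend.cpu_runtime())
    tensor = handler.tensor([2, 3], TensorProto.FLOAT)
    handler.data_malloc()

    src = np.arange(6, dtype=np.float32).reshape(2, 3)
    tensor.copyin_numpy(src)            # item size, element count, shape checked
    dst = np.array(tensor, copy=False)  # read-only view via the buffer protocol
    assert np.array_equal(src, dst)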
@@ -19,14 +19,6 @@ class TensorObj : public TensorBaseObj {
     size_t _size; // Cache of Π(shape).
     Fuid fuid;    // Cloned tensors share the same id. Tensors constructed from
                   // scratch have a new id.
-
-    void copyin(const void *ptr, size_t size) {
-        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), ptr, size);
-    }
-    void copyout(void *ptr, size_t size) const {
-        runtime->copyBlobToCPU(ptr, getRawDataPtr<void *>(), size);
-    }
-
   public:
     TensorObj(Shape shape, DataType dtype, Runtime runtime);
     virtual ~TensorObj() {}
@@ -45,10 +37,17 @@ class TensorObj : public TensorBaseObj {
     void load(std::string file_path);
     void save(std::string file_path);
 
+    void copyin(const void *ptr, size_t size) {
+        runtime->copyBlobFromCPU(getRawDataPtr<void *>(), ptr, size);
+    }
+    void copyout(void *ptr, size_t size) const {
+        runtime->copyBlobToCPU(ptr, getRawDataPtr<void *>(), size);
+    }
+
     // Copy elements from `data`.
     template <typename T> void copyin(const vector<T> &data) {
         IT_ASSERT(DataType::get<T>() == dtype.cpuTypeInt());
-        IT_ASSERT(data.size() >= _size);
+        IT_ASSERT(data.size() == _size);
         copyin(data.data(), getBytes());
     }
     // Copy all the elements to a vector.
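Besides making the raw copyin/copyout overloads public, this hunk tightens the vector assertion from data.size() >= _size to data.size() == _size, so an oversized host vector no longer passes silently; the test_graph.cc updates at the bottom of this commit shrink i0's payload to exactly six values for this reason. A sketch of the contract from the Python side, assuming IT_ASSERT failures surface as Python exceptions through pybind11:

    from onnx import TensorProto
    from pyinfinitensor.onnx import backend

    handler = backend.GraphHandler(backend.cpu_runtime())
    t = handler.tensor([1, 2, 3], TensorProto.FLOAT)  # 1 * 2 * 3 = 6 elements
    handler.data_malloc()
    t.copyin_float([0.0] * 6)    # exactly _size elements: accepted
    # t.copyin_float([0.0] * 12) # previously tolerated by >=, now rejected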
@@ -616,8 +616,14 @@ class OnnxStub:
         # update the node_list
         node_list = list(set(node_name) - set(new_node_name))
 
+        ################################
+        # Allocate memory space for data
+        ################################
         self.handler.data_malloc()
 
+        #################################
+        # Copy in data to tensor objects
+        #################################
         for name, obj in tensors.items():
             tensor = data.get(name)
             if tensor == None:
@@ -625,22 +631,24 @@ class OnnxStub:
                 self.inputs[name] = obj
             else:
                 self.initializer[obj.fuid()] = tensor
-                if tensor.data_type == TensorProto.INT32:
-                    obj.copyin_int32(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.INT64:
-                    obj.copyin_int64(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.FLOAT:
-                    obj.copyin_float(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.BOOL:
-                    obj.copyin_int8(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.FLOAT16:
-                    obj.copyin_float16(_parse_data_fp16(tensor))
-                elif tensor.data_type == TensorProto.INT8:
-                    obj.copyin_uint8(_parse_data(tensor))
-                elif tensor.data_type == TensorProto.BFLOAT16:
-                    obj.copyin_float16(_parse_data_fp16(tensor))
-                else:
-                    assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
+                # TODO: delete these lines after copyin_numpy is stable
+                # if tensor.data_type == TensorProto.INT32:
+                #     obj.copyin_int32(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.INT64:
+                #     obj.copyin_int64(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.FLOAT:
+                #     obj.copyin_float(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.BOOL:
+                #     obj.copyin_int8(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.FLOAT16:
+                #     obj.copyin_float16(_parse_data_fp16(tensor))
+                # elif tensor.data_type == TensorProto.INT8:
+                #     obj.copyin_uint8(_parse_data(tensor))
+                # elif tensor.data_type == TensorProto.BFLOAT16:
+                #     obj.copyin_float16(_parse_data_fp16(tensor))
+                # else:
+                #     assert False, "Unsupported Tensor Type: {}".format(tensor.data_type)
+                obj.copyin_numpy(to_array(tensor))
 
         for output in model.graph.output:
             self.outputs[output.name] = tensors[output.name]
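The per-dtype dispatch above is replaced by onnx.numpy_helper.to_array, which already maps every supported TensorProto element type to the corresponding Numpy dtype; copyin_numpy then validates item size, element count, and shape in one place. A standalone sketch of that conversion step (the round trip through from_array is illustrative):

    import numpy as np
    from onnx import numpy_helper

    proto = numpy_helper.from_array(
        np.arange(6, dtype=np.float32).reshape(2, 3), name="w"
    )
    arr = numpy_helper.to_array(proto)  # dtype and shape recovered from the proto
    assert arr.dtype == np.float32 and arr.shape == (2, 3)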
@@ -9,7 +9,8 @@ from onnx.helper import (
 )
 from onnx.checker import check_model, check_graph
 from onnx.shape_inference import infer_shapes
-from pyinfinitensor.onnx import from_onnx, OnnxStub, backend
+from pyinfinitensor.onnx import from_onnx, OnnxStub, backend, _parse_data_fp16
+import numpy as np
 
 
 def make_and_import_model(graph: onnx.GraphProto):
@@ -382,6 +383,53 @@ class TestStringMethods(unittest.TestCase):
         where = make_node("Where", ["x", "y", "con"], ["output"], name="where")
         make_and_import_model(make_graph([where], "where", [x, y, con], [output]))
 
+    def test_copyin(self):
+        dims = [2, 3, 5, 4]
+        np_array = np.random.random(dims).astype(np.float32)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
+        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
+        handler.data_malloc()
+        tensor1.copyin_numpy(np_array)
+        tensor2.copyin_float(np_array.flatten().tolist())
+        array1 = tensor1.copyout_float()
+        array2 = tensor2.copyout_float()
+        self.assertEqual(array1, array2)
+        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
+
+        np_array = np.random.randint(0, 100, size=dims, dtype=np.int64)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.INT64)
+        tensor2 = handler.tensor(dims, TensorProto.INT64)
+        handler.data_malloc()
+        tensor1.copyin_numpy(np_array)
+        tensor2.copyin_int64(np_array.flatten().tolist())
+        array1 = tensor1.copyout_int64()
+        array2 = tensor2.copyout_int64()
+        self.assertEqual(array1, array2)
+        self.assertTrue(np.array_equal(np.array(array1).reshape(dims), np_array))
+
+    def test_to_numpy(self):
+        dims = [2, 3, 5, 4]
+        np_array = np.random.random(dims).astype(np.float32)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.FLOAT)
+        tensor2 = handler.tensor(dims, TensorProto.FLOAT)
+        handler.data_malloc()
+        tensor1.copyin_float(np_array.flatten().tolist())
+        tensor2.copyin_float(np_array.flatten().tolist())
+        array1 = np.array(tensor1.copyout_float()).reshape(dims)
+        array2 = np.array(tensor2)
+        self.assertTrue(np.array_equal(array2, np_array))
+        self.assertTrue(np.array_equal(array1, array2))
+
+        np_array = np.random.random(dims).astype(np.float16)
+        handler = backend.GraphHandler(backend.cpu_runtime())
+        tensor1 = handler.tensor(dims, TensorProto.FLOAT16)
+        handler.data_malloc()
+        tensor1.copyin_numpy(np_array)
+        array1 = np.array(tensor1, copy=False)
+        self.assertTrue(np.array_equal(array1, np_array))
+
 
 if __name__ == "__main__":
     unittest.main()
@@ -10,10 +10,8 @@ namespace infini {
 
 TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime)
     : TensorBaseObj(shape_.size(), dtype, runtime), shape(std::move(shape_)),
-      _size(shape.empty()
-                ? 0
-                : std::accumulate(shape.begin(), shape.end(), 1,
-                                  [](auto acc, auto x) { return acc * x; })) {}
+      _size(std::accumulate(shape.begin(), shape.end(), 1, std::multiplies{})) {
+}
 
 string TensorObj::toString() const {
     // Convert data pointer to string
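The rewritten initializer fixes the scalar case: the old ternary mapped an empty shape to size 0, while std::accumulate over an empty range simply returns its init value, so a rank-0 tensor now correctly holds one element. The same identity, checked in Python:

    import math

    assert math.prod([]) == 1              # empty product: a scalar has 1 element
    assert math.prod([2, 3, 5, 4]) == 120  # Π(shape) for the shape used in the tests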
@@ -1,3 +1,4 @@
+#include "core/data_type.h"
 #include "core/graph_handler.h"
 #include "operators/batch_norm.h"
 #include "operators/concat.h"

@@ -13,8 +14,9 @@
 #include "operators/transpose.h"
 #include "operators/unary.h"
 #include <algorithm>
+#include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
 #ifdef USE_CUDA
 #include "cuda/cuda_runtime.h"
 #include "cuda/operator_timer.h"
@@ -315,7 +317,8 @@ void init_graph_builder(py::module &m) {
     py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
         m, "BangRuntime");
 #endif
-    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
+    py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor",
+                                                      py::buffer_protocol())
         .def("fuid", &TensorObj::getFuid, policy::automatic)
         .def("shape", &TensorObj::getDims, policy::move)
         .def("copyin_float", &TensorObj::copyin<float>, policy::move)
@@ -330,6 +333,65 @@ void init_graph_builder(py::module &m) {
         .def("copyout_int8", &TensorObj::copyout<int8_t>, policy::move)
         .def("copyout_uint8", &TensorObj::copyout<uint8_t>, policy::move)
         .def("copyout_float16", &TensorObj::copyout<uint16_t>, policy::move)
+        // Copy data from a Numpy array
+        .def("copyin_numpy",
+             [](TensorObj &self, py::buffer buf) {
+                 py::buffer_info buf_info = buf.request();
+                 void *data_np = buf_info.ptr;
+                 size_t itemsize = buf_info.itemsize;
+                 size_t size = buf_info.size;
+                 IT_ASSERT(itemsize == self.getDType().getSize());
+                 IT_ASSERT(size == self.size());
+                 for (size_t i = 0; i < self.getRank(); i++) {
+                     IT_ASSERT(self.getDims()[i] == buf_info.shape[i]);
+                 }
+                 self.copyin(data_np, self.getBytes());
+             })
+        // A buffer can be used to convert a TensorObj directly to Numpy array
+        // without copy
+        .def_buffer([](TensorObj &self) -> py::buffer_info {
+            vector<size_t> stride_byte;
+            for (int s : self.getStride()) {
+                stride_byte.push_back(s * self.getDType().getSize());
+            }
+
+            std::string format;
+            if (self.getDType() == DataType::Float32) {
+                format = py::format_descriptor<float>::format();
+            } else if (self.getDType() == DataType::Double) {
+                format = py::format_descriptor<double>::format();
+            } else if (self.getDType() == DataType::Int32) {
+                format = py::format_descriptor<int>::format();
+            } else if (self.getDType() == DataType::UInt32) {
+                format = py::format_descriptor<uint32_t>::format();
+            } else if (self.getDType() == DataType::Int64) {
+                format = py::format_descriptor<int64_t>::format();
+            } else if (self.getDType() == DataType::UInt64) {
+                format = py::format_descriptor<uint64_t>::format();
+            } else if (self.getDType() == DataType::Int16) {
+                format = py::format_descriptor<int16_t>::format();
+            } else if (self.getDType() == DataType::UInt16) {
+                format = py::format_descriptor<uint16_t>::format();
+            } else if (self.getDType() == DataType::Int8) {
+                format = py::format_descriptor<int8_t>::format();
+            } else if (self.getDType() == DataType::UInt8) {
+                format = py::format_descriptor<uint8_t>::format();
+            } else if (self.getDType() == DataType::Float16 ||
+                       self.getDType() == DataType::BFloat16) {
+                // Python uses "e" for half precision float type code.
+                // Check the following link for more information.
+                // https://docs.python.org/3/library/struct.html#format-characters
+                format = "e";
+            } else {
+                throw std::runtime_error("Error converting TensorObj to "
+                                         "Numpy: unsupported datatype.\n");
+            }
+
+            return py::buffer_info(self.getRawDataPtr<void *>(),
+                                   self.getDType().getSize(), format,
+                                   self.getRank(), self.getDims(), stride_byte,
+                                   true); // Read-only = true
+        })
         .def("has_target", &TensorObj::hasTarget, policy::automatic)
         .def("src", &TensorObj::getSource, policy::move)
         .def("printData", &TensorObj::printData, policy::automatic);
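With py::buffer_protocol() on the class and this def_buffer hook reporting the raw data pointer, byte strides, format string, and a read-only flag, Numpy can wrap a CPU tensor without copying. A consumer-side sketch mirroring test_to_numpy: np.array copies by default, while copy=False requests a view, which should come out non-writeable since the buffer is exported read-only (the shape and fill value here are illustrative):

    import numpy as np
    from onnx import TensorProto
    from pyinfinitensor.onnx import backend

    handler = backend.GraphHandler(backend.cpu_runtime())
    t = handler.tensor([2, 2], TensorProto.FLOAT)
    handler.data_malloc()
    t.copyin_numpy(np.ones((2, 2), dtype=np.float32))

    view = np.array(t, copy=False)   # zero-copy view of the tensor's memory
    assert not view.flags.writeable  # read-only, per the exported buffer_info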
@@ -15,7 +15,7 @@ TEST(Graph, build_and_run) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
     g->print();

@@ -84,7 +84,7 @@ TEST(Graph, perf_engine) {
     auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);
 
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     runtime->run(g, true, true);
     double perfTime = runtime->getPerfTime(g);

@@ -105,7 +105,7 @@ TEST(Graph, test_tensor_id) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto i1 = g->addTensor(i0->clone());
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);

@@ -123,7 +123,7 @@ TEST(Graph, test_OpVec_ctor) {
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
     Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
     g->dataMalloc();
-    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+    i0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6});
     w0->copyin(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
     auto o1 = g->addTensor(o0->clone());
     auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
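These four test fixes all follow from the shapes involved: i0 is {1, 2, 3} and holds 6 elements, so its old 12-element payload only passed under the former >= assertion. A quick check of the arithmetic:

    import math

    i0, w0, o0 = [1, 2, 3], [1, 3, 4], [1, 2, 4]
    assert math.prod(i0) == 6           # i0 now receives exactly 6 values
    assert math.prod(w0) == 12          # w0 keeps its 12-element payload
    assert o0 == [i0[0], i0[1], w0[2]]  # batched matmul output shape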