feat: 添加导出

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-04-23 13:47:02 +08:00
parent 28287f3782
commit 17033fad97
1 changed files with 64 additions and 4 deletions

View File

@ -141,15 +141,33 @@ optimization::DataType cast(DataType ty) {
#undef IT
}
void GraphObj::optimize() {
DataType cast(optimization::DataType ty) {
#define IT(A, B) \
if (optimization::DataTypeId::A == ty.id) \
return {DataType::B};
IT(FLOAT, Float32) //
else IT(UINT32, UInt32) //
else IT(UINT8, UInt8) //
else IT(INT8, Int8) //
else IT(UINT16, UInt16) //
else IT(INT16, Int16) //
else IT(INT32, Int32) //
else IT(INT64, Int64) //
else IT_ASSERT(false, "unsupported data type");
#undef IT
}
optimization::Unigraph cast(GraphObj &g) {
namespace opt = optimization;
topo_sort();
g.topo_sort();
#define I(PTR) reinterpret_cast<uintptr_t>((PTR).get())
unordered_map<uintptr_t, opt::Arc<opt::Tensor>> tensors;
for (const auto &t : this->getTensors()) {
for (const auto &t : g.getTensors()) {
const auto dims = t->getDims();
opt::Vec<size_t> shape(dims.size());
std::transform(dims.begin(), dims.end(), shape.begin(),
@ -168,7 +186,7 @@ void GraphObj::optimize() {
opt::Unigraph ans;
for (const auto &op : this->getOperators()) {
for (const auto &op : g.getOperators()) {
const auto inputs = op->getInputs(), outputs = op->getOutputs();
opt::Vec<opt::Arc<opt::Tensor>> in(inputs.size()), out(outputs.size());
std::transform(inputs.begin(), inputs.end(), in.begin(),
@ -453,6 +471,48 @@ void GraphObj::optimize() {
}
#undef I
return ans;
}
Graph cast(optimization::Unigraph const &g, Runtime rt) {
namespace opt = optimization;
unordered_map<uintptr_t, Tensor> tensors;
#define I(PTR) reinterpret_cast<uintptr_t>((PTR).get())
auto ans = make_ref<GraphObj>(std::move(rt));
for (auto const &op : g.operators) {
for (auto const &t : op.inputs) {
auto const &shape_ = t->shape;
opt::Vec<int> shape(shape_.size());
std::transform(shape_.begin(), shape_.end(), shape.begin(),
[](auto x) { return static_cast<int>(x); });
tensors[I(t)] =
ans->addTensor(std::move(shape), cast(t->data_type));
}
for (auto const &t : op.outputs) {
auto const &shape_ = t->shape;
opt::Vec<int> shape(shape_.size());
std::transform(shape_.begin(), shape_.end(), shape.begin(),
[](auto x) { return static_cast<int>(x); });
tensors[I(t)] =
ans->addTensor(std::move(shape), cast(t->data_type));
}
switch (op.op_type) {
default:
break;
}
}
#undef I
return ans;
}
void GraphObj::optimize() {
auto graph = cast(*this);
auto ans = cast(graph, this->runtime);
}
void GraphObj::dataMalloc() {