forked from jiuyuan/InfiniTensor
parent
28287f3782
commit
17033fad97
|
@ -141,15 +141,33 @@ optimization::DataType cast(DataType ty) {
|
|||
#undef IT
|
||||
}
|
||||
|
||||
void GraphObj::optimize() {
|
||||
DataType cast(optimization::DataType ty) {
|
||||
#define IT(A, B) \
|
||||
if (optimization::DataTypeId::A == ty.id) \
|
||||
return {DataType::B};
|
||||
|
||||
IT(FLOAT, Float32) //
|
||||
else IT(UINT32, UInt32) //
|
||||
else IT(UINT8, UInt8) //
|
||||
else IT(INT8, Int8) //
|
||||
else IT(UINT16, UInt16) //
|
||||
else IT(INT16, Int16) //
|
||||
else IT(INT32, Int32) //
|
||||
else IT(INT64, Int64) //
|
||||
else IT_ASSERT(false, "unsupported data type");
|
||||
|
||||
#undef IT
|
||||
}
|
||||
|
||||
optimization::Unigraph cast(GraphObj &g) {
|
||||
namespace opt = optimization;
|
||||
|
||||
topo_sort();
|
||||
g.topo_sort();
|
||||
|
||||
#define I(PTR) reinterpret_cast<uintptr_t>((PTR).get())
|
||||
|
||||
unordered_map<uintptr_t, opt::Arc<opt::Tensor>> tensors;
|
||||
for (const auto &t : this->getTensors()) {
|
||||
for (const auto &t : g.getTensors()) {
|
||||
const auto dims = t->getDims();
|
||||
opt::Vec<size_t> shape(dims.size());
|
||||
std::transform(dims.begin(), dims.end(), shape.begin(),
|
||||
|
@ -168,7 +186,7 @@ void GraphObj::optimize() {
|
|||
|
||||
opt::Unigraph ans;
|
||||
|
||||
for (const auto &op : this->getOperators()) {
|
||||
for (const auto &op : g.getOperators()) {
|
||||
const auto inputs = op->getInputs(), outputs = op->getOutputs();
|
||||
opt::Vec<opt::Arc<opt::Tensor>> in(inputs.size()), out(outputs.size());
|
||||
std::transform(inputs.begin(), inputs.end(), in.begin(),
|
||||
|
@ -453,6 +471,48 @@ void GraphObj::optimize() {
|
|||
}
|
||||
|
||||
#undef I
|
||||
return ans;
|
||||
}
|
||||
|
||||
Graph cast(optimization::Unigraph const &g, Runtime rt) {
|
||||
namespace opt = optimization;
|
||||
|
||||
unordered_map<uintptr_t, Tensor> tensors;
|
||||
|
||||
#define I(PTR) reinterpret_cast<uintptr_t>((PTR).get())
|
||||
|
||||
auto ans = make_ref<GraphObj>(std::move(rt));
|
||||
|
||||
for (auto const &op : g.operators) {
|
||||
for (auto const &t : op.inputs) {
|
||||
auto const &shape_ = t->shape;
|
||||
opt::Vec<int> shape(shape_.size());
|
||||
std::transform(shape_.begin(), shape_.end(), shape.begin(),
|
||||
[](auto x) { return static_cast<int>(x); });
|
||||
tensors[I(t)] =
|
||||
ans->addTensor(std::move(shape), cast(t->data_type));
|
||||
}
|
||||
for (auto const &t : op.outputs) {
|
||||
auto const &shape_ = t->shape;
|
||||
opt::Vec<int> shape(shape_.size());
|
||||
std::transform(shape_.begin(), shape_.end(), shape.begin(),
|
||||
[](auto x) { return static_cast<int>(x); });
|
||||
tensors[I(t)] =
|
||||
ans->addTensor(std::move(shape), cast(t->data_type));
|
||||
}
|
||||
switch (op.op_type) {
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#undef I
|
||||
return ans;
|
||||
}
|
||||
|
||||
void GraphObj::optimize() {
|
||||
auto graph = cast(*this);
|
||||
auto ans = cast(graph, this->runtime);
|
||||
}
|
||||
|
||||
void GraphObj::dataMalloc() {
|
||||
|
|
Loading…
Reference in New Issue