feat: 导出新前端的 3 个函数并修改 python

Signed-off-by: YdrMaster <ydrml@hotmail.com>
This commit is contained in:
YdrMaster 2023-09-08 14:19:58 +08:00
parent a8c117a48c
commit ed66ab6a37
3 changed files with 122 additions and 1111 deletions

@ -1 +1 @@
Subproject commit 1a2f03d8cbee2ae936184995cbb606462f54e02c
Subproject commit 7afc18a5632e59f80466049071e96b54adaf2160

File diff suppressed because it is too large Load Diff

View File

@ -1,3 +1,4 @@
#include "../RefactorGraph/src/02computation/include/graph/graph.h"
#include "core/data_type.h"
#include "core/graph_handler.h"
#include "operators/batch_norm.h"
@ -470,9 +471,73 @@ void init_graph_builder(py::module &m) {
} // namespace infini
namespace {
using namespace refactor;
graph::Edge edge(int dataType, std::vector<graph::DimExpr> shape,
std::optional<std::vector<uint8_t>> data) {
graph::Shape s(shape.begin(), shape.end());
auto ans = std::make_shared<graph::Tensor>(
static_cast<common::DataType>(dataType), std::move(s));
if (data) {
auto const bytesSize = ans->bytesSize();
std::cout << "bytesSize: " << bytesSize << std::endl;
std::cout << "data size: " << data->size() << std::endl;
IT_ASSERT(bytesSize == data->size());
ans->data = std::make_shared<graph::Blob>(new uint8_t[bytesSize]);
std::memcpy(ans->data->ptr, data->data(), bytesSize);
}
return ans;
}
graph::Node
node(std::string opType,
std::unordered_map<std::string, decltype(graph::Attribute::value)> attrs) {
std::unordered_map<std::string, graph::Attribute> attrs_;
for (auto it = attrs.begin(); it != attrs.end(); attrs.erase(it++)) {
attrs_.insert({std::move(it->first), {std::move(it->second)}});
}
return std::make_shared<graph::NodeInfo>(
graph::Operator{common::OpType::Unknown, std::move(attrs_)});
}
std::shared_ptr<graph::Graph>
graph(std::unordered_map<std::string, std::pair<std::vector<std::string>,
std::vector<std::string>>>
topology,
std::unordered_map<std::string, graph::Node> nodes,
std::unordered_map<std::string, graph::Edge> edges,
std::vector<std::string> inputs, std::vector<std::string> outputs) {
exit(999);
auto builder = graph_topo::Builder<std::string, graph::Node, std::string,
graph::Edge>{};
for (auto &[node, rels] : topology) {
builder.topology.insert(
{std::move(node), {std::move(rels.first), std::move(rels.second)}});
}
builder.nodes = std::move(nodes);
builder.edges = std::move(edges);
builder.globalInputs = std::move(inputs);
builder.globalOutputs = std::move(outputs);
return std::make_shared<graph::Graph>(builder.build());
}
void register_refactor(py::module &m) {
py::class_<graph::DimExpr>(m, "DimExpr")
.def(py::init<int64_t>())
.def(py::init<std::string &&>());
py::class_<graph::NodeInfo, graph::Node>(m, "Node");
py::class_<graph::Tensor, graph::Edge>(m, "Edge");
m.def("refactor_tensor", edge)
.def("refactor_operator", node)
.def("refactor_graph", graph);
}
} // namespace
PYBIND11_MODULE(backend, m) {
infini::register_operator_timer(m);
infini::export_values(m);
infini::export_functions(m);
infini::init_graph_builder(m);
register_refactor(m);
}