From 970c77d0f4077862ba24d07104d1a0d35ba2573b Mon Sep 17 00:00:00 2001 From: Pairshoe Date: Wed, 26 Oct 2022 21:06:21 +0800 Subject: [PATCH] Update: Rename GraphFactory -> GraphBuilder && Remove unnecessary outputs --- .../core/{graph_factory.h => graph_builder.h} | 6 +- include/core/runtime.h | 4 +- python/infinitensor/import_onnx.py | 2 +- python/infinitensor/test_import_onnx.py | 6 +- .../{graph_factory.cc => graph_builder.cc} | 124 +++++++++--------- src/core/operator.cc | 8 -- src/ffi/ffi_infinitensor.cc | 68 +++++----- ...graph_factory.cc => test_graph_builder.cc} | 62 ++++----- 8 files changed, 136 insertions(+), 144 deletions(-) rename include/core/{graph_factory.h => graph_builder.h} (98%) rename src/core/{graph_factory.cc => graph_builder.cc} (82%) rename test/core/{test_graph_factory.cc => test_graph_builder.cc} (85%) diff --git a/include/core/graph_factory.h b/include/core/graph_builder.h similarity index 98% rename from include/core/graph_factory.h rename to include/core/graph_builder.h index c320b39e..431f55e2 100644 --- a/include/core/graph_factory.h +++ b/include/core/graph_builder.h @@ -21,12 +21,12 @@ namespace infini { -class GraphFactoryObj { +class GraphBuilderObj { private: Graph g; public: - GraphFactoryObj(Runtime runtime) : g(make_ref(runtime)) {} + GraphBuilderObj(Runtime runtime) : g(make_ref(runtime)) {} // tensors Tensor tensor(Shape dim, const std::string &dtype); @@ -163,4 +163,4 @@ class GraphFactoryObj { nnet::Expr expr, double exec_time, std::string hint = {}); }; -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/include/core/runtime.h b/include/core/runtime.h index f387efe5..f7e6fad4 100644 --- a/include/core/runtime.h +++ b/include/core/runtime.h @@ -10,7 +10,7 @@ class TensorBaseObj; class TensorObj; class OperatorObj; class GraphObj; -class GraphFactoryObj; +class GraphBuilderObj; class RuntimeObj; class BlobObj; @@ -18,7 +18,7 @@ using TensorBase = Ref; using Tensor = Ref; using Operator = Ref; using Graph = Ref; -using GraphFactory = Ref; +using GraphBuilder = Ref; using Runtime = Ref; using Blob = Ref; enum class OpType; diff --git a/python/infinitensor/import_onnx.py b/python/infinitensor/import_onnx.py index f91df739..0e6e6ccb 100644 --- a/python/infinitensor/import_onnx.py +++ b/python/infinitensor/import_onnx.py @@ -122,7 +122,7 @@ def _onnx_datatype_tostring(dtype): assert False, 'Unknown onnx datatype' -def import_onnx(gf: GraphFactory, net: str): +def import_onnx(gf: GraphBuilder, net: str): ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class) model = onnx.load(net) diff --git a/python/infinitensor/test_import_onnx.py b/python/infinitensor/test_import_onnx.py index d6e35dbf..581fb206 100644 --- a/python/infinitensor/test_import_onnx.py +++ b/python/infinitensor/test_import_onnx.py @@ -5,8 +5,8 @@ import sys def main(netPath): runtime = CpuRuntimeObj.getInstance() - graphFactory = GraphFactoryObj(runtime) - import_onnx(graphFactory, netPath) + graphBuilder = GraphBuilderObj(runtime) + import_onnx(graphBuilder, netPath) if __name__ == "__main__": - main(sys.argv[1]) \ No newline at end of file + main(sys.argv[1]) diff --git a/src/core/graph_factory.cc b/src/core/graph_builder.cc similarity index 82% rename from src/core/graph_factory.cc rename to src/core/graph_builder.cc index 8e3d0934..8859012c 100644 --- a/src/core/graph_factory.cc +++ b/src/core/graph_builder.cc @@ -1,8 +1,8 @@ -#include "core/graph_factory.h" +#include "core/graph_builder.h" namespace infini { -Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) { +Tensor GraphBuilderObj::tensor(Shape dim, const std::string &dtype) { if (dtype == "FLOAT") { return g->addTensor(dim, DataType::Float32); } @@ -12,7 +12,7 @@ Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) { IT_TODO_HALT_MSG("Unsupported data type"); } -Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output, +Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output, int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); @@ -23,7 +23,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output, return op; } -Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw, +Operator GraphBuilderObj::conv(Tensor input, Tensor weight, int ph, int pw, int sh, int sw, int dh, int dw, Tensor bias) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor w0 = g->addTensor(weight->getDims(), weight->getDType()); @@ -31,7 +31,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw, return op; } -Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output, +Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output, ConvBaseObj::PaddingMode pm, int sh, int sw, int dh, int dw, Tensor bias) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); @@ -42,7 +42,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output, return op; } -Operator GraphFactoryObj::conv(Tensor input, Tensor weight, +Operator GraphBuilderObj::conv(Tensor input, Tensor weight, ConvBaseObj::PaddingMode pm, int sh, int sw, int dh, int dw, Tensor bias) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); @@ -51,7 +51,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, return op; } -Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA, +Operator GraphBuilderObj::matmul(Tensor A, Tensor B, Tensor C, bool transA, bool transB, Tensor bias, ActType act) { Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType()); @@ -61,7 +61,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA, return op; } -Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB, +Operator GraphBuilderObj::matmul(Tensor A, Tensor B, bool transA, bool transB, Tensor bias, ActType act) { Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType()); @@ -69,7 +69,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB, return op; } -Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output, +Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output, int ph, int pw, int sh, int sw, int dh, int dw, int oph, int opw, int group, Tensor bias, ActType act) { @@ -81,7 +81,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output, return op; } -Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw, +Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, int ph, int pw, int sh, int sw, int dh, int dw, int oph, int opw, int group, Tensor bias, ActType act) { @@ -92,7 +92,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw, return op; } -Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output, +Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output, ConvBaseObj::PaddingMode pm, int sh, int sw, int dh, int dw, int oph, int opw, int group, Tensor bias, ActType act) { @@ -104,7 +104,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output, return op; } -Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, +Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, ConvBaseObj::PaddingMode pm, int sh, int sw, int dh, int dw, int oph, int opw, int group, Tensor bias, ActType act) { @@ -115,7 +115,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, return op; } -Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width, +Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width, const int dilation, Tensor bias, ActType act) { Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType()); @@ -125,7 +125,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width, return op; } -Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width, +Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, const int width, const int dilation, Tensor bias, ActType act) { Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType()); @@ -133,7 +133,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width, return op; } -Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C, +Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, Tensor C, const int dilation, Tensor bias, ActType act) { Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType()); @@ -142,7 +142,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C, return op; } -Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation, +Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, const int dilation, Tensor bias, ActType act) { Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType()); @@ -150,7 +150,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation, return op; } -Operator GraphFactoryObj::pad(Tensor input, Tensor output, +Operator GraphBuilderObj::pad(Tensor input, Tensor output, const vector &pads, const optional> &axis) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); @@ -159,14 +159,14 @@ Operator GraphFactoryObj::pad(Tensor input, Tensor output, return op; } -Operator GraphFactoryObj::pad(Tensor input, const vector &pads, +Operator GraphBuilderObj::pad(Tensor input, const vector &pads, const optional> &axis) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr, pads, axis); return op; } -Operator GraphFactoryObj::slice(Tensor input, Tensor output, +Operator GraphBuilderObj::slice(Tensor input, Tensor output, const vector &starts, const vector &ends, const optional> &axis, @@ -177,7 +177,7 @@ Operator GraphFactoryObj::slice(Tensor input, Tensor output, return op; } -Operator GraphFactoryObj::slice(Tensor input, const vector &starts, +Operator GraphBuilderObj::slice(Tensor input, const vector &starts, const vector &ends, const optional> &axis, const optional> &steps) { @@ -186,7 +186,7 @@ Operator GraphFactoryObj::slice(Tensor input, const vector &starts, return op; } -Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) { +Operator GraphBuilderObj::concat(TensorVec inputs, Tensor output, int dim) { TensorVec is; for (auto input : inputs) { Tensor i = g->addTensor(input->getDims(), input->getDType()); @@ -197,7 +197,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) { return op; } -Operator GraphFactoryObj::concat(TensorVec inputs, int dim) { +Operator GraphBuilderObj::concat(TensorVec inputs, int dim) { TensorVec is; for (auto input : inputs) { Tensor i = g->addTensor(input->getDims(), input->getDType()); @@ -207,7 +207,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, int dim) { return op; } -Operator GraphFactoryObj::split(Tensor input, std::optional outputs, +Operator GraphBuilderObj::split(Tensor input, std::optional outputs, int dim, int num) { Tensor i = g->addTensor(input->getDims(), input->getDType()); if (outputs.has_value()) { @@ -224,13 +224,13 @@ Operator GraphFactoryObj::split(Tensor input, std::optional outputs, } } -Operator GraphFactoryObj::split(Tensor input, int dim, int num) { +Operator GraphBuilderObj::split(Tensor input, int dim, int num) { Tensor i = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i, std::nullopt, dim, num); return op; } -Operator GraphFactoryObj::split(Tensor input, std::optional outputs, +Operator GraphBuilderObj::split(Tensor input, std::optional outputs, int dim, const vector &ratio) { Tensor i = g->addTensor(input->getDims(), input->getDType()); if (outputs.has_value()) { @@ -247,14 +247,14 @@ Operator GraphFactoryObj::split(Tensor input, std::optional outputs, } } -Operator GraphFactoryObj::split(Tensor input, int dim, +Operator GraphBuilderObj::split(Tensor input, int dim, const vector &ratio) { Tensor i = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i, std::nullopt, dim, ratio); return op; } -Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim, +Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim, int num) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -262,13 +262,13 @@ Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim, return op; } -Operator GraphFactoryObj::extend(Tensor input, int dim, int num) { +Operator GraphBuilderObj::extend(Tensor input, int dim, int num) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr, dim, num); return op; } -Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw, +Operator GraphBuilderObj::maxpool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); @@ -278,14 +278,14 @@ Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw, return op; } -Operator GraphFactoryObj::maxpool(Tensor input, int kh, int kw, int dh, int dw, +Operator GraphBuilderObj::maxpool(Tensor input, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw); return op; } -Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw, +Operator GraphBuilderObj::avgpool(Tensor input, Tensor output, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); @@ -295,14 +295,14 @@ Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw, return op; } -Operator GraphFactoryObj::avgpool(Tensor input, int kh, int kw, int dh, int dw, +Operator GraphBuilderObj::avgpool(Tensor input, int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw); return op; } -Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) { +Operator GraphBuilderObj::add(Tensor input0, Tensor input1, Tensor output) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -310,14 +310,14 @@ Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) { return op; } -Operator GraphFactoryObj::add(Tensor input0, Tensor input1) { +Operator GraphBuilderObj::add(Tensor input0, Tensor input1) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); auto op = g->addOp(i0, i1, nullptr); return op; } -Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) { +Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -325,14 +325,14 @@ Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) { return op; } -Operator GraphFactoryObj::sub(Tensor input0, Tensor input1) { +Operator GraphBuilderObj::sub(Tensor input0, Tensor input1) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); auto op = g->addOp(i0, i1, nullptr); return op; } -Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) { +Operator GraphBuilderObj::mul(Tensor input0, Tensor input1, Tensor output) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -340,14 +340,14 @@ Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) { return op; } -Operator GraphFactoryObj::mul(Tensor input0, Tensor input1) { +Operator GraphBuilderObj::mul(Tensor input0, Tensor input1) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); auto op = g->addOp(i0, i1, nullptr); return op; } -Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) { +Operator GraphBuilderObj::div(Tensor input0, Tensor input1, Tensor output) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -355,14 +355,14 @@ Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) { return op; } -Operator GraphFactoryObj::div(Tensor input0, Tensor input1) { +Operator GraphBuilderObj::div(Tensor input0, Tensor input1) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); auto op = g->addOp(i0, i1, nullptr); return op; } -Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) { +Operator GraphBuilderObj::pow(Tensor input0, Tensor input1, Tensor output) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -370,14 +370,14 @@ Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) { return op; } -Operator GraphFactoryObj::pow(Tensor input0, Tensor input1) { +Operator GraphBuilderObj::pow(Tensor input0, Tensor input1) { Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); auto op = g->addOp(i0, i1, nullptr); return op; } -Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output, +Operator GraphBuilderObj::gather(Tensor input, Tensor index, Tensor output, int axis) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -385,13 +385,13 @@ Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output, return op; } -Operator GraphFactoryObj::gather(Tensor input, Tensor index, int axis) { +Operator GraphBuilderObj::gather(Tensor input, Tensor index, int axis) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, index, nullptr, axis); return op; } -Operator GraphFactoryObj::reshape(Tensor input, Tensor output, +Operator GraphBuilderObj::reshape(Tensor input, Tensor output, const Shape &dims) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); @@ -399,104 +399,104 @@ Operator GraphFactoryObj::reshape(Tensor input, Tensor output, return op; } -Operator GraphFactoryObj::reshape(Tensor input, const Shape &dims) { +Operator GraphBuilderObj::reshape(Tensor input, const Shape &dims) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr, dims); return op; } -Operator GraphFactoryObj::flatten(Tensor input, Tensor output) { +Operator GraphBuilderObj::flatten(Tensor input, Tensor output) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); auto op = g->addOpWithOutputs(i0, o0); return op; } -Operator GraphFactoryObj::flatten(Tensor input) { +Operator GraphBuilderObj::flatten(Tensor input) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr); return op; } -Operator GraphFactoryObj::identity(Tensor input, Tensor output) { +Operator GraphBuilderObj::identity(Tensor input, Tensor output) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); auto op = g->addOpWithOutputs(i0, o0); return op; } -Operator GraphFactoryObj::identity(Tensor input) { +Operator GraphBuilderObj::identity(Tensor input) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr); return op; } -Operator GraphFactoryObj::softmax(Tensor input, Tensor output) { +Operator GraphBuilderObj::softmax(Tensor input, Tensor output) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); auto op = g->addOpWithOutputs(i0, o0); return op; } -Operator GraphFactoryObj::softmax(Tensor input) { +Operator GraphBuilderObj::softmax(Tensor input) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr); return op; } -Operator GraphFactoryObj::relu(Tensor input, Tensor output) { +Operator GraphBuilderObj::relu(Tensor input, Tensor output) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); auto op = g->addOpWithOutputs(i0, o0); return op; } -Operator GraphFactoryObj::relu(Tensor input) { +Operator GraphBuilderObj::relu(Tensor input) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr); return op; } -Operator GraphFactoryObj::sigmoid(Tensor input, Tensor output) { +Operator GraphBuilderObj::sigmoid(Tensor input, Tensor output) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); auto op = g->addOpWithOutputs(i0, o0); return op; } -Operator GraphFactoryObj::sigmoid(Tensor input) { +Operator GraphBuilderObj::sigmoid(Tensor input) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr); return op; } -Operator GraphFactoryObj::tanh(Tensor input, Tensor output) { +Operator GraphBuilderObj::tanh(Tensor input, Tensor output) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); auto op = g->addOpWithOutputs(i0, o0); return op; } -Operator GraphFactoryObj::tanh(Tensor input) { +Operator GraphBuilderObj::tanh(Tensor input) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr); return op; } -Operator GraphFactoryObj::abs(Tensor input, Tensor output) { +Operator GraphBuilderObj::abs(Tensor input, Tensor output) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType()); auto op = g->addOpWithOutputs(i0, o0); return op; } -Operator GraphFactoryObj::abs(Tensor input) { +Operator GraphBuilderObj::abs(Tensor input) { Tensor i0 = g->addTensor(input->getDims(), input->getDType()); auto op = g->addOp(i0, nullptr); return op; } -Operator GraphFactoryObj::memBound(const TensorVec &inputs, +Operator GraphBuilderObj::memBound(const TensorVec &inputs, const TensorVec &outputs, const std::vector &nnetInputs, nnet::Expr expr, double exec_time, @@ -516,4 +516,4 @@ Operator GraphFactoryObj::memBound(const TensorVec &inputs, return op; } -} // namespace infini \ No newline at end of file +} // namespace infini diff --git a/src/core/operator.cc b/src/core/operator.cc index 4e5e96f4..b8e69af8 100644 --- a/src/core/operator.cc +++ b/src/core/operator.cc @@ -71,14 +71,6 @@ bool OperatorObj::checkValid(GraphObj *graph) { } } else { // if outputs have been created, check their shapes for (size_t i = 0; i < shapes.size(); ++i) { - printf("|* i = %ld *|\n", i); - printf("shapes:\n"); - for (auto shape : shapes[i]) - printf("%d ", shape); - printf("\n"); - for (auto dim : outputs[i]->getDims()) - printf("%d ", dim); - printf("\n"); if (shapes[i] != outputs[i]->getDims()) return false; } diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 48459084..6a7d7b0e 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -2,7 +2,7 @@ #ifdef USE_CUDA #include "cuda/operator_timer.h" #endif -#include "core/graph_factory.h" +#include "core/graph_builder.h" namespace py = pybind11; namespace infini { @@ -21,7 +21,7 @@ void register_operator_timer(py::module &m) { #endif } -void init_graph_factory(py::module &m) { +void init_graph_builder(py::module &m) { py::class_>(m, "RuntimeObj"); py::class_, RuntimeObj>( m, "CpuRuntimeObj") @@ -75,112 +75,112 @@ void init_graph_factory(py::module &m) { py::class_, OperatorObj>(m, "AbsObj"); py::class_, OperatorObj>( m, "MemBoundObj"); - py::class_(m, "GraphFactory"); - py::class_(m, "GraphFactoryObj") + py::class_(m, "GraphBuilder"); + py::class_(m, "GraphBuilderObj") .def(py::init()) .def("tensor", py::overload_cast( - &GraphFactoryObj::tensor), + &GraphBuilderObj::tensor), policy::reference_internal) .def("conv", py::overload_cast(&GraphFactoryObj::conv), + int, Tensor>(&GraphBuilderObj::conv), policy::reference_internal) .def("matmul", py::overload_cast(&GraphFactoryObj::matmul), + ActType>(&GraphBuilderObj::matmul), policy::reference_internal) .def("convTrans", py::overload_cast( - &GraphFactoryObj::convTrans), + &GraphBuilderObj::convTrans), policy::reference_internal) .def("g2bmm", py::overload_cast(&GraphFactoryObj::g2bmm), + Tensor, ActType>(&GraphBuilderObj::g2bmm), policy::reference_internal) .def("gbmml", py::overload_cast(&GraphFactoryObj::gbmml), + ActType>(&GraphBuilderObj::gbmml), policy::reference_internal) .def("pad", py::overload_cast &, const optional> &>( - &GraphFactoryObj::pad), + &GraphBuilderObj::pad), policy::reference_internal) .def("slice", py::overload_cast &, const vector &, const optional> &, const optional> &>( - &GraphFactoryObj::slice), + &GraphBuilderObj::slice), policy::reference_internal) .def( "concat", - py::overload_cast(&GraphFactoryObj::concat), + py::overload_cast(&GraphBuilderObj::concat), policy::reference_internal) .def("split", py::overload_cast, int, int>( - &GraphFactoryObj::split), + &GraphBuilderObj::split), policy::reference_internal) .def("extend", py::overload_cast( - &GraphFactoryObj::extend), + &GraphBuilderObj::extend), policy::reference_internal) .def("maxpool", py::overload_cast(&GraphFactoryObj::maxpool), + int, int>(&GraphBuilderObj::maxpool), policy::reference_internal) .def("avgpool", py::overload_cast(&GraphFactoryObj::avgpool), + int, int>(&GraphBuilderObj::avgpool), policy::reference_internal) .def("add", - py::overload_cast(&GraphFactoryObj::add), + py::overload_cast(&GraphBuilderObj::add), policy::reference_internal) .def("sub", - py::overload_cast(&GraphFactoryObj::sub), + py::overload_cast(&GraphBuilderObj::sub), policy::reference_internal) .def("mul", - py::overload_cast(&GraphFactoryObj::mul), + py::overload_cast(&GraphBuilderObj::mul), policy::reference_internal) .def("div", - py::overload_cast(&GraphFactoryObj::div), + py::overload_cast(&GraphBuilderObj::div), policy::reference_internal) .def("pow", - py::overload_cast(&GraphFactoryObj::pow), + py::overload_cast(&GraphBuilderObj::pow), policy::reference_internal) .def("gather", py::overload_cast( - &GraphFactoryObj::gather), + &GraphBuilderObj::gather), policy::reference_internal) .def("reshape", py::overload_cast( - &GraphFactoryObj::reshape), + &GraphBuilderObj::reshape), policy::reference_internal) .def("flatten", - py::overload_cast(&GraphFactoryObj::flatten), + py::overload_cast(&GraphBuilderObj::flatten), policy::reference_internal) .def("identity", - py::overload_cast(&GraphFactoryObj::identity), + py::overload_cast(&GraphBuilderObj::identity), policy::reference_internal) .def("softmax", - py::overload_cast(&GraphFactoryObj::softmax), + py::overload_cast(&GraphBuilderObj::softmax), policy::reference_internal) - .def("relu", py::overload_cast(&GraphFactoryObj::relu), + .def("relu", py::overload_cast(&GraphBuilderObj::relu), policy::reference_internal) .def("sigmoid", - py::overload_cast(&GraphFactoryObj::sigmoid), + py::overload_cast(&GraphBuilderObj::sigmoid), policy::reference_internal) - .def("tanh", py::overload_cast(&GraphFactoryObj::tanh), + .def("tanh", py::overload_cast(&GraphBuilderObj::tanh), policy::reference_internal) - .def("abs", py::overload_cast(&GraphFactoryObj::abs), + .def("abs", py::overload_cast(&GraphBuilderObj::abs), policy::reference_internal) .def("memBound", py::overload_cast &, nnet::Expr, - double, std::string>(&GraphFactoryObj::memBound), + double, std::string>(&GraphBuilderObj::memBound), policy::reference_internal); } @@ -188,5 +188,5 @@ void init_graph_factory(py::module &m) { PYBIND11_MODULE(pyinfinitensor, m) { infini::register_operator_timer(m); - infini::init_graph_factory(m); -} \ No newline at end of file + infini::init_graph_builder(m); +} diff --git a/test/core/test_graph_factory.cc b/test/core/test_graph_builder.cc similarity index 85% rename from test/core/test_graph_factory.cc rename to test/core/test_graph_builder.cc index 3016f643..18cf843f 100644 --- a/test/core/test_graph_factory.cc +++ b/test/core/test_graph_builder.cc @@ -1,12 +1,12 @@ -#include "core/graph_factory.h" +#include "core/graph_builder.h" #include "test.h" namespace infini { -TEST(GraphFactory, ops) { +TEST(GraphBuilder, ops) { Runtime runtime = CpuRuntimeObj::getInstance(); { // conv without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); auto weight = @@ -15,7 +15,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4})); } { // conv with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); auto weight = @@ -25,21 +25,21 @@ TEST(GraphFactory, ops) { auto conv = gf->conv(input, weight, output, 1, 1); } { // matmul without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto A = make_ref(Shape{1, 3, 5}, DataType::UInt32, runtime); auto B = make_ref(Shape{1, 5, 2}, DataType::UInt32, runtime); auto matmul = gf->matmul(A, B); EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2})); } { // matmul with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto A = make_ref(Shape{1, 3, 5}, DataType::UInt32, runtime); auto B = make_ref(Shape{1, 5, 2}, DataType::UInt32, runtime); auto C = make_ref(Shape{1, 3, 2}, DataType::UInt32, runtime); auto matmul = gf->matmul(A, B, C); } { // convtrans without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 228, 1, 1}, DataType::UInt32, runtime); auto weight = make_ref(Shape{228, 448, 2, 2}, @@ -48,7 +48,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2})); } { // convtrans with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 228, 1, 1}, DataType::UInt32, runtime); auto weight = make_ref(Shape{228, 448, 2, 2}, @@ -58,7 +58,7 @@ TEST(GraphFactory, ops) { auto convtrans = gf->convTrans(input, weight, 0, 0); } { // pad without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 64, 162, 162}, DataType::UInt32, runtime); vector pads = {2, 10, 1, 5, 0, 10, 1, 5}; @@ -66,7 +66,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172})); } { // pad with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 64, 162, 162}, DataType::UInt32, runtime); auto output = make_ref(Shape{3, 84, 164, 172}, @@ -75,7 +75,7 @@ TEST(GraphFactory, ops) { auto pad = gf->pad(input, output, pads, std::nullopt); } { // slice without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{10, 64, 162, 162}, DataType::UInt32, runtime); vector starts = {2, 10, 1, 5}; @@ -84,7 +84,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96})); } { // slice with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{10, 64, 162, 162}, DataType::UInt32, runtime); auto output = make_ref(Shape{2, 1, 100, 96}, @@ -95,7 +95,7 @@ TEST(GraphFactory, ops) { gf->slice(input, output, starts, ends, std::nullopt, std::nullopt); } { // concat without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto t1 = make_ref(Shape{1, 3, 2, 4}, DataType::Float32, runtime); auto t2 = @@ -104,7 +104,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9})); } { // concat with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto t1 = make_ref(Shape{1, 3, 2, 4}, DataType::Float32, runtime); auto t2 = @@ -114,7 +114,7 @@ TEST(GraphFactory, ops) { auto concat = gf->concat(TensorVec{t1, t2}, o0, 3); } { // split without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 3, 2, 15}, DataType::Float32, runtime); auto split = gf->split(input, 3, 4); @@ -126,7 +126,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6})); } { // split with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 3, 2, 15}, DataType::Float32, runtime); auto output0 = @@ -141,14 +141,14 @@ TEST(GraphFactory, ops) { input, TensorVec{output0, output1, output2, output3}, 3, 4); } { // extend without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); auto extend = gf->extend(input, 2, 1); EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4})); } { // extend with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); auto output = @@ -156,7 +156,7 @@ TEST(GraphFactory, ops) { auto extend = gf->extend(input, output, 2, 1); } { // maxpool without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 64, 162, 162}, DataType::UInt32, runtime); const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2, @@ -165,7 +165,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80})); } { // maxpool with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 64, 162, 162}, DataType::UInt32, runtime); auto output = make_ref(Shape{1, 64, 80, 80}, @@ -176,7 +176,7 @@ TEST(GraphFactory, ops) { gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw); } { // add without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input0 = make_ref(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); auto input1 = @@ -185,7 +185,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4})); } { // add with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input0 = make_ref(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); auto input1 = @@ -195,7 +195,7 @@ TEST(GraphFactory, ops) { auto add = gf->add(input0, input1, output); } { // gather without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); auto index = @@ -204,7 +204,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); } { // gather with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); auto index = @@ -214,7 +214,7 @@ TEST(GraphFactory, ops) { auto gather = gf->gather(input, index, output, 1); } { // reshape without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::Float32, runtime); vector dims = {3, 2, 4, 3}; @@ -222,7 +222,7 @@ TEST(GraphFactory, ops) { EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3})); } { // reshape with output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::Float32, runtime); vector dims = {3, 2, 4, 3}; @@ -231,14 +231,14 @@ TEST(GraphFactory, ops) { auto reshape = gf->reshape(input, output, dims); } { // flatten without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::Float32, runtime); auto flatten = gf->flatten(input); EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72})); } { // flatten without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::Float32, runtime); auto output = @@ -246,14 +246,14 @@ TEST(GraphFactory, ops) { auto flatten = gf->flatten(input, output); } { // identity without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::Float32, runtime); auto identity = gf->identity(input); EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4})); } { // identity without output - GraphFactory gf = make_ref(runtime); + GraphBuilder gf = make_ref(runtime); auto input = make_ref(Shape{2, 3, 3, 4}, DataType::Float32, runtime); auto output = @@ -262,4 +262,4 @@ TEST(GraphFactory, ops) { } } -} // namespace infini \ No newline at end of file +} // namespace infini