Update: Rename GraphFactory -> GraphBuilder && Remove unnecessary outputs

This commit is contained in:
Pairshoe 2022-10-26 21:06:21 +08:00 committed by mazx
parent 7cf2d8f78f
commit 970c77d0f4
8 changed files with 136 additions and 144 deletions

View File

@ -21,12 +21,12 @@
namespace infini { namespace infini {
class GraphFactoryObj { class GraphBuilderObj {
private: private:
Graph g; Graph g;
public: public:
GraphFactoryObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {} GraphBuilderObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
// tensors // tensors
Tensor tensor(Shape dim, const std::string &dtype); Tensor tensor(Shape dim, const std::string &dtype);

View File

@ -10,7 +10,7 @@ class TensorBaseObj;
class TensorObj; class TensorObj;
class OperatorObj; class OperatorObj;
class GraphObj; class GraphObj;
class GraphFactoryObj; class GraphBuilderObj;
class RuntimeObj; class RuntimeObj;
class BlobObj; class BlobObj;
@ -18,7 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
using Tensor = Ref<TensorObj>; using Tensor = Ref<TensorObj>;
using Operator = Ref<OperatorObj>; using Operator = Ref<OperatorObj>;
using Graph = Ref<GraphObj>; using Graph = Ref<GraphObj>;
using GraphFactory = Ref<GraphFactoryObj>; using GraphBuilder = Ref<GraphBuilderObj>;
using Runtime = Ref<RuntimeObj>; using Runtime = Ref<RuntimeObj>;
using Blob = Ref<BlobObj>; using Blob = Ref<BlobObj>;
enum class OpType; enum class OpType;

View File

@ -122,7 +122,7 @@ def _onnx_datatype_tostring(dtype):
assert False, 'Unknown onnx datatype' assert False, 'Unknown onnx datatype'
def import_onnx(gf: GraphFactory, net: str): def import_onnx(gf: GraphBuilder, net: str):
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class) ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
model = onnx.load(net) model = onnx.load(net)

View File

@ -5,8 +5,8 @@ import sys
def main(netPath): def main(netPath):
runtime = CpuRuntimeObj.getInstance() runtime = CpuRuntimeObj.getInstance()
graphFactory = GraphFactoryObj(runtime) graphBuilder = GraphBuilderObj(runtime)
import_onnx(graphFactory, netPath) import_onnx(graphBuilder, netPath)
if __name__ == "__main__": if __name__ == "__main__":
main(sys.argv[1]) main(sys.argv[1])

View File

@ -1,8 +1,8 @@
#include "core/graph_factory.h" #include "core/graph_builder.h"
namespace infini { namespace infini {
Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) { Tensor GraphBuilderObj::tensor(Shape dim, const std::string &dtype) {
if (dtype == "FLOAT") { if (dtype == "FLOAT") {
return g->addTensor(dim, DataType::Float32); return g->addTensor(dim, DataType::Float32);
} }
@ -12,7 +12,7 @@ Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) {
IT_TODO_HALT_MSG("Unsupported data type"); IT_TODO_HALT_MSG("Unsupported data type");
} }
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output, Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int dw, int ph, int pw, int sh, int sw, int dh, int dw,
Tensor bias) { Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -23,7 +23,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw, Operator GraphBuilderObj::conv(Tensor input, Tensor weight, int ph, int pw,
int sh, int sw, int dh, int dw, Tensor bias) { int sh, int sw, int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType()); Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
@ -31,7 +31,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw,
return op; return op;
} }
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output, Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh, int sw, ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, Tensor bias) { int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -42,7 +42,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Operator GraphBuilderObj::conv(Tensor input, Tensor weight,
ConvBaseObj::PaddingMode pm, int sh, int sw, ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, Tensor bias) { int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -51,7 +51,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight,
return op; return op;
} }
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA, Operator GraphBuilderObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
bool transB, Tensor bias, ActType act) { bool transB, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -61,7 +61,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
return op; return op;
} }
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB, Operator GraphBuilderObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
Tensor bias, ActType act) { Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -69,7 +69,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
return op; return op;
} }
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output, Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int ph, int pw, int sh, int sw, int dh,
int dw, int oph, int opw, int group, int dw, int oph, int opw, int group,
Tensor bias, ActType act) { Tensor bias, ActType act) {
@ -81,7 +81,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw, Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
int sh, int sw, int dh, int dw, int oph, int sh, int sw, int dh, int dw, int oph,
int opw, int group, Tensor bias, int opw, int group, Tensor bias,
ActType act) { ActType act) {
@ -92,7 +92,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
return op; return op;
} }
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output, Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh, int sw, ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, int oph, int opw, int group, int dh, int dw, int oph, int opw, int group,
Tensor bias, ActType act) { Tensor bias, ActType act) {
@ -104,7 +104,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight,
ConvBaseObj::PaddingMode pm, int sh, int sw, ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, int oph, int opw, int group, int dh, int dw, int oph, int opw, int group,
Tensor bias, ActType act) { Tensor bias, ActType act) {
@ -115,7 +115,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight,
return op; return op;
} }
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width, Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
const int dilation, Tensor bias, ActType act) { const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -125,7 +125,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
return op; return op;
} }
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width, Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, const int width,
const int dilation, Tensor bias, ActType act) { const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -133,7 +133,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width,
return op; return op;
} }
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C, Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, Tensor C,
const int dilation, Tensor bias, ActType act) { const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -142,7 +142,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C,
return op; return op;
} }
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation, Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, const int dilation,
Tensor bias, ActType act) { Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType()); Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType()); Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -150,7 +150,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation,
return op; return op;
} }
Operator GraphFactoryObj::pad(Tensor input, Tensor output, Operator GraphBuilderObj::pad(Tensor input, Tensor output,
const vector<int> &pads, const vector<int> &pads,
const optional<const vector<int>> &axis) { const optional<const vector<int>> &axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -159,14 +159,14 @@ Operator GraphFactoryObj::pad(Tensor input, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::pad(Tensor input, const vector<int> &pads, Operator GraphBuilderObj::pad(Tensor input, const vector<int> &pads,
const optional<const vector<int>> &axis) { const optional<const vector<int>> &axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<PadObj>(i0, nullptr, pads, axis); auto op = g->addOp<PadObj>(i0, nullptr, pads, axis);
return op; return op;
} }
Operator GraphFactoryObj::slice(Tensor input, Tensor output, Operator GraphBuilderObj::slice(Tensor input, Tensor output,
const vector<int> &starts, const vector<int> &starts,
const vector<int> &ends, const vector<int> &ends,
const optional<const vector<int>> &axis, const optional<const vector<int>> &axis,
@ -177,7 +177,7 @@ Operator GraphFactoryObj::slice(Tensor input, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts, Operator GraphBuilderObj::slice(Tensor input, const vector<int> &starts,
const vector<int> &ends, const vector<int> &ends,
const optional<const vector<int>> &axis, const optional<const vector<int>> &axis,
const optional<const vector<int>> &steps) { const optional<const vector<int>> &steps) {
@ -186,7 +186,7 @@ Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts,
return op; return op;
} }
Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) { Operator GraphBuilderObj::concat(TensorVec inputs, Tensor output, int dim) {
TensorVec is; TensorVec is;
for (auto input : inputs) { for (auto input : inputs) {
Tensor i = g->addTensor(input->getDims(), input->getDType()); Tensor i = g->addTensor(input->getDims(), input->getDType());
@ -197,7 +197,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) {
return op; return op;
} }
Operator GraphFactoryObj::concat(TensorVec inputs, int dim) { Operator GraphBuilderObj::concat(TensorVec inputs, int dim) {
TensorVec is; TensorVec is;
for (auto input : inputs) { for (auto input : inputs) {
Tensor i = g->addTensor(input->getDims(), input->getDType()); Tensor i = g->addTensor(input->getDims(), input->getDType());
@ -207,7 +207,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, int dim) {
return op; return op;
} }
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs, Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
int dim, int num) { int dim, int num) {
Tensor i = g->addTensor(input->getDims(), input->getDType()); Tensor i = g->addTensor(input->getDims(), input->getDType());
if (outputs.has_value()) { if (outputs.has_value()) {
@ -224,13 +224,13 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
} }
} }
Operator GraphFactoryObj::split(Tensor input, int dim, int num) { Operator GraphBuilderObj::split(Tensor input, int dim, int num) {
Tensor i = g->addTensor(input->getDims(), input->getDType()); Tensor i = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num); auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num);
return op; return op;
} }
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs, Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
int dim, const vector<int> &ratio) { int dim, const vector<int> &ratio) {
Tensor i = g->addTensor(input->getDims(), input->getDType()); Tensor i = g->addTensor(input->getDims(), input->getDType());
if (outputs.has_value()) { if (outputs.has_value()) {
@ -247,14 +247,14 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
} }
} }
Operator GraphFactoryObj::split(Tensor input, int dim, Operator GraphBuilderObj::split(Tensor input, int dim,
const vector<int> &ratio) { const vector<int> &ratio) {
Tensor i = g->addTensor(input->getDims(), input->getDType()); Tensor i = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio); auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio);
return op; return op;
} }
Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim, Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim,
int num) { int num) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -262,13 +262,13 @@ Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim,
return op; return op;
} }
Operator GraphFactoryObj::extend(Tensor input, int dim, int num) { Operator GraphBuilderObj::extend(Tensor input, int dim, int num) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ExtendObj>(i0, nullptr, dim, num); auto op = g->addOp<ExtendObj>(i0, nullptr, dim, num);
return op; return op;
} }
Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw, Operator GraphBuilderObj::maxpool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int dh, int dw, int ph, int pw, int sh,
int sw) { int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -278,14 +278,14 @@ Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw,
return op; return op;
} }
Operator GraphFactoryObj::maxpool(Tensor input, int kh, int kw, int dh, int dw, Operator GraphBuilderObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw) { int ph, int pw, int sh, int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<MaxPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw); auto op = g->addOp<MaxPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
return op; return op;
} }
Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw, Operator GraphBuilderObj::avgpool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh, int dh, int dw, int ph, int pw, int sh,
int sw) { int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -295,14 +295,14 @@ Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw,
return op; return op;
} }
Operator GraphFactoryObj::avgpool(Tensor input, int kh, int kw, int dh, int dw, Operator GraphBuilderObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw) { int ph, int pw, int sh, int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<AvgPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw); auto op = g->addOp<AvgPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
return op; return op;
} }
Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) { Operator GraphBuilderObj::add(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -310,14 +310,14 @@ Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) {
return op; return op;
} }
Operator GraphFactoryObj::add(Tensor input0, Tensor input1) { Operator GraphBuilderObj::add(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<AddObj>(i0, i1, nullptr); auto op = g->addOp<AddObj>(i0, i1, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) { Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -325,14 +325,14 @@ Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) {
return op; return op;
} }
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1) { Operator GraphBuilderObj::sub(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<SubObj>(i0, i1, nullptr); auto op = g->addOp<SubObj>(i0, i1, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) { Operator GraphBuilderObj::mul(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -340,14 +340,14 @@ Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) {
return op; return op;
} }
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1) { Operator GraphBuilderObj::mul(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<SubObj>(i0, i1, nullptr); auto op = g->addOp<SubObj>(i0, i1, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) { Operator GraphBuilderObj::div(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -355,14 +355,14 @@ Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) {
return op; return op;
} }
Operator GraphFactoryObj::div(Tensor input0, Tensor input1) { Operator GraphBuilderObj::div(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<DivObj>(i0, i1, nullptr); auto op = g->addOp<DivObj>(i0, i1, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) { Operator GraphBuilderObj::pow(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -370,14 +370,14 @@ Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) {
return op; return op;
} }
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1) { Operator GraphBuilderObj::pow(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType()); Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType()); Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<PowObj>(i0, i1, nullptr); auto op = g->addOp<PowObj>(i0, i1, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output, Operator GraphBuilderObj::gather(Tensor input, Tensor index, Tensor output,
int axis) { int axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -385,13 +385,13 @@ Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::gather(Tensor input, Tensor index, int axis) { Operator GraphBuilderObj::gather(Tensor input, Tensor index, int axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<GatherObj>(i0, index, nullptr, axis); auto op = g->addOp<GatherObj>(i0, index, nullptr, axis);
return op; return op;
} }
Operator GraphFactoryObj::reshape(Tensor input, Tensor output, Operator GraphBuilderObj::reshape(Tensor input, Tensor output,
const Shape &dims) { const Shape &dims) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -399,104 +399,104 @@ Operator GraphFactoryObj::reshape(Tensor input, Tensor output,
return op; return op;
} }
Operator GraphFactoryObj::reshape(Tensor input, const Shape &dims) { Operator GraphBuilderObj::reshape(Tensor input, const Shape &dims) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ReshapeObj>(i0, nullptr, dims); auto op = g->addOp<ReshapeObj>(i0, nullptr, dims);
return op; return op;
} }
Operator GraphFactoryObj::flatten(Tensor input, Tensor output) { Operator GraphBuilderObj::flatten(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<FlattenObj>(i0, o0); auto op = g->addOpWithOutputs<FlattenObj>(i0, o0);
return op; return op;
} }
Operator GraphFactoryObj::flatten(Tensor input) { Operator GraphBuilderObj::flatten(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<FlattenObj>(i0, nullptr); auto op = g->addOp<FlattenObj>(i0, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::identity(Tensor input, Tensor output) { Operator GraphBuilderObj::identity(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<IdentityObj>(i0, o0); auto op = g->addOpWithOutputs<IdentityObj>(i0, o0);
return op; return op;
} }
Operator GraphFactoryObj::identity(Tensor input) { Operator GraphBuilderObj::identity(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<IdentityObj>(i0, nullptr); auto op = g->addOp<IdentityObj>(i0, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::softmax(Tensor input, Tensor output) { Operator GraphBuilderObj::softmax(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SoftmaxObj>(i0, o0); auto op = g->addOpWithOutputs<SoftmaxObj>(i0, o0);
return op; return op;
} }
Operator GraphFactoryObj::softmax(Tensor input) { Operator GraphBuilderObj::softmax(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SoftmaxObj>(i0, nullptr); auto op = g->addOp<SoftmaxObj>(i0, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::relu(Tensor input, Tensor output) { Operator GraphBuilderObj::relu(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ReluObj>(i0, o0); auto op = g->addOpWithOutputs<ReluObj>(i0, o0);
return op; return op;
} }
Operator GraphFactoryObj::relu(Tensor input) { Operator GraphBuilderObj::relu(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ReluObj>(i0, nullptr); auto op = g->addOp<ReluObj>(i0, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::sigmoid(Tensor input, Tensor output) { Operator GraphBuilderObj::sigmoid(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SigmoidObj>(i0, o0); auto op = g->addOpWithOutputs<SigmoidObj>(i0, o0);
return op; return op;
} }
Operator GraphFactoryObj::sigmoid(Tensor input) { Operator GraphBuilderObj::sigmoid(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SigmoidObj>(i0, nullptr); auto op = g->addOp<SigmoidObj>(i0, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::tanh(Tensor input, Tensor output) { Operator GraphBuilderObj::tanh(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<TanhObj>(i0, o0); auto op = g->addOpWithOutputs<TanhObj>(i0, o0);
return op; return op;
} }
Operator GraphFactoryObj::tanh(Tensor input) { Operator GraphBuilderObj::tanh(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<TanhObj>(i0, nullptr); auto op = g->addOp<TanhObj>(i0, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::abs(Tensor input, Tensor output) { Operator GraphBuilderObj::abs(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType()); Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<AbsObj>(i0, o0); auto op = g->addOpWithOutputs<AbsObj>(i0, o0);
return op; return op;
} }
Operator GraphFactoryObj::abs(Tensor input) { Operator GraphBuilderObj::abs(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType()); Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<AbsObj>(i0, nullptr); auto op = g->addOp<AbsObj>(i0, nullptr);
return op; return op;
} }
Operator GraphFactoryObj::memBound(const TensorVec &inputs, Operator GraphBuilderObj::memBound(const TensorVec &inputs,
const TensorVec &outputs, const TensorVec &outputs,
const std::vector<nnet::Tensor> &nnetInputs, const std::vector<nnet::Tensor> &nnetInputs,
nnet::Expr expr, double exec_time, nnet::Expr expr, double exec_time,

View File

@ -71,14 +71,6 @@ bool OperatorObj::checkValid(GraphObj *graph) {
} }
} else { // if outputs have been created, check their shapes } else { // if outputs have been created, check their shapes
for (size_t i = 0; i < shapes.size(); ++i) { for (size_t i = 0; i < shapes.size(); ++i) {
printf("|* i = %ld *|\n", i);
printf("shapes:\n");
for (auto shape : shapes[i])
printf("%d ", shape);
printf("\n");
for (auto dim : outputs[i]->getDims())
printf("%d ", dim);
printf("\n");
if (shapes[i] != outputs[i]->getDims()) if (shapes[i] != outputs[i]->getDims())
return false; return false;
} }

View File

@ -2,7 +2,7 @@
#ifdef USE_CUDA #ifdef USE_CUDA
#include "cuda/operator_timer.h" #include "cuda/operator_timer.h"
#endif #endif
#include "core/graph_factory.h" #include "core/graph_builder.h"
namespace py = pybind11; namespace py = pybind11;
namespace infini { namespace infini {
@ -21,7 +21,7 @@ void register_operator_timer(py::module &m) {
#endif #endif
} }
void init_graph_factory(py::module &m) { void init_graph_builder(py::module &m) {
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj"); py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>( py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntimeObj") m, "CpuRuntimeObj")
@ -75,112 +75,112 @@ void init_graph_factory(py::module &m) {
py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj"); py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj");
py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>( py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>(
m, "MemBoundObj"); m, "MemBoundObj");
py::class_<GraphFactory>(m, "GraphFactory"); py::class_<GraphBuilder>(m, "GraphBuilder");
py::class_<GraphFactoryObj>(m, "GraphFactoryObj") py::class_<GraphBuilderObj>(m, "GraphBuilderObj")
.def(py::init<Runtime>()) .def(py::init<Runtime>())
.def("tensor", .def("tensor",
py::overload_cast<Shape, const std::string &>( py::overload_cast<Shape, const std::string &>(
&GraphFactoryObj::tensor), &GraphBuilderObj::tensor),
policy::reference_internal) policy::reference_internal)
.def("conv", .def("conv",
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int, py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
int, Tensor>(&GraphFactoryObj::conv), int, Tensor>(&GraphBuilderObj::conv),
policy::reference_internal) policy::reference_internal)
.def("matmul", .def("matmul",
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor, py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
ActType>(&GraphFactoryObj::matmul), ActType>(&GraphBuilderObj::matmul),
policy::reference_internal) policy::reference_internal)
.def("convTrans", .def("convTrans",
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int, py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
int, int, int, int, Tensor, ActType>( int, int, int, int, Tensor, ActType>(
&GraphFactoryObj::convTrans), &GraphBuilderObj::convTrans),
policy::reference_internal) policy::reference_internal)
.def("g2bmm", .def("g2bmm",
py::overload_cast<Tensor, Tensor, Tensor, const int, const int, py::overload_cast<Tensor, Tensor, Tensor, const int, const int,
Tensor, ActType>(&GraphFactoryObj::g2bmm), Tensor, ActType>(&GraphBuilderObj::g2bmm),
policy::reference_internal) policy::reference_internal)
.def("gbmml", .def("gbmml",
py::overload_cast<Tensor, Tensor, Tensor, const int, Tensor, py::overload_cast<Tensor, Tensor, Tensor, const int, Tensor,
ActType>(&GraphFactoryObj::gbmml), ActType>(&GraphBuilderObj::gbmml),
policy::reference_internal) policy::reference_internal)
.def("pad", .def("pad",
py::overload_cast<Tensor, Tensor, const vector<int> &, py::overload_cast<Tensor, Tensor, const vector<int> &,
const optional<const vector<int>> &>( const optional<const vector<int>> &>(
&GraphFactoryObj::pad), &GraphBuilderObj::pad),
policy::reference_internal) policy::reference_internal)
.def("slice", .def("slice",
py::overload_cast<Tensor, Tensor, const vector<int> &, py::overload_cast<Tensor, Tensor, const vector<int> &,
const vector<int> &, const vector<int> &,
const optional<const vector<int>> &, const optional<const vector<int>> &,
const optional<const vector<int>> &>( const optional<const vector<int>> &>(
&GraphFactoryObj::slice), &GraphBuilderObj::slice),
policy::reference_internal) policy::reference_internal)
.def( .def(
"concat", "concat",
py::overload_cast<TensorVec, Tensor, int>(&GraphFactoryObj::concat), py::overload_cast<TensorVec, Tensor, int>(&GraphBuilderObj::concat),
policy::reference_internal) policy::reference_internal)
.def("split", .def("split",
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>( py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
&GraphFactoryObj::split), &GraphBuilderObj::split),
policy::reference_internal) policy::reference_internal)
.def("extend", .def("extend",
py::overload_cast<Tensor, Tensor, int, int>( py::overload_cast<Tensor, Tensor, int, int>(
&GraphFactoryObj::extend), &GraphBuilderObj::extend),
policy::reference_internal) policy::reference_internal)
.def("maxpool", .def("maxpool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int, py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&GraphFactoryObj::maxpool), int, int>(&GraphBuilderObj::maxpool),
policy::reference_internal) policy::reference_internal)
.def("avgpool", .def("avgpool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int, py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&GraphFactoryObj::avgpool), int, int>(&GraphBuilderObj::avgpool),
policy::reference_internal) policy::reference_internal)
.def("add", .def("add",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::add), py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::add),
policy::reference_internal) policy::reference_internal)
.def("sub", .def("sub",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::sub), py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::sub),
policy::reference_internal) policy::reference_internal)
.def("mul", .def("mul",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::mul), py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::mul),
policy::reference_internal) policy::reference_internal)
.def("div", .def("div",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::div), py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::div),
policy::reference_internal) policy::reference_internal)
.def("pow", .def("pow",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::pow), py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::pow),
policy::reference_internal) policy::reference_internal)
.def("gather", .def("gather",
py::overload_cast<Tensor, Tensor, Tensor, int>( py::overload_cast<Tensor, Tensor, Tensor, int>(
&GraphFactoryObj::gather), &GraphBuilderObj::gather),
policy::reference_internal) policy::reference_internal)
.def("reshape", .def("reshape",
py::overload_cast<Tensor, Tensor, const Shape &>( py::overload_cast<Tensor, Tensor, const Shape &>(
&GraphFactoryObj::reshape), &GraphBuilderObj::reshape),
policy::reference_internal) policy::reference_internal)
.def("flatten", .def("flatten",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::flatten), py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::flatten),
policy::reference_internal) policy::reference_internal)
.def("identity", .def("identity",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::identity), py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::identity),
policy::reference_internal) policy::reference_internal)
.def("softmax", .def("softmax",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::softmax), py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::softmax),
policy::reference_internal) policy::reference_internal)
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::relu), .def("relu", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::relu),
policy::reference_internal) policy::reference_internal)
.def("sigmoid", .def("sigmoid",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::sigmoid), py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::sigmoid),
policy::reference_internal) policy::reference_internal)
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::tanh), .def("tanh", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::tanh),
policy::reference_internal) policy::reference_internal)
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::abs), .def("abs", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::abs),
policy::reference_internal) policy::reference_internal)
.def("memBound", .def("memBound",
py::overload_cast<const TensorVec &, const TensorVec &, py::overload_cast<const TensorVec &, const TensorVec &,
const std::vector<nnet::Tensor> &, nnet::Expr, const std::vector<nnet::Tensor> &, nnet::Expr,
double, std::string>(&GraphFactoryObj::memBound), double, std::string>(&GraphBuilderObj::memBound),
policy::reference_internal); policy::reference_internal);
} }
@ -188,5 +188,5 @@ void init_graph_factory(py::module &m) {
PYBIND11_MODULE(pyinfinitensor, m) { PYBIND11_MODULE(pyinfinitensor, m) {
infini::register_operator_timer(m); infini::register_operator_timer(m);
infini::init_graph_factory(m); infini::init_graph_builder(m);
} }

View File

@ -1,12 +1,12 @@
#include "core/graph_factory.h" #include "core/graph_builder.h"
#include "test.h" #include "test.h"
namespace infini { namespace infini {
TEST(GraphFactory, ops) { TEST(GraphBuilder, ops) {
Runtime runtime = CpuRuntimeObj::getInstance(); Runtime runtime = CpuRuntimeObj::getInstance();
{ // conv without output { // conv without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto weight = auto weight =
@ -15,7 +15,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4})); EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
} }
{ // conv with output { // conv with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto weight = auto weight =
@ -25,21 +25,21 @@ TEST(GraphFactory, ops) {
auto conv = gf->conv(input, weight, output, 1, 1); auto conv = gf->conv(input, weight, output, 1, 1);
} }
{ // matmul without output { // matmul without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime); auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime); auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
auto matmul = gf->matmul(A, B); auto matmul = gf->matmul(A, B);
EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2})); EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2}));
} }
{ // matmul with output { // matmul with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime); auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime); auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
auto C = make_ref<TensorObj>(Shape{1, 3, 2}, DataType::UInt32, runtime); auto C = make_ref<TensorObj>(Shape{1, 3, 2}, DataType::UInt32, runtime);
auto matmul = gf->matmul(A, B, C); auto matmul = gf->matmul(A, B, C);
} }
{ // convtrans without output { // convtrans without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2}, auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
@ -48,7 +48,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2})); EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2}));
} }
{ // convtrans with output { // convtrans with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2}, auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
@ -58,7 +58,7 @@ TEST(GraphFactory, ops) {
auto convtrans = gf->convTrans(input, weight, 0, 0); auto convtrans = gf->convTrans(input, weight, 0, 0);
} }
{ // pad without output { // pad without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162}, auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime); DataType::UInt32, runtime);
vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5}; vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5};
@ -66,7 +66,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172})); EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172}));
} }
{ // pad with output { // pad with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162}, auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime); DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{3, 84, 164, 172}, auto output = make_ref<TensorObj>(Shape{3, 84, 164, 172},
@ -75,7 +75,7 @@ TEST(GraphFactory, ops) {
auto pad = gf->pad(input, output, pads, std::nullopt); auto pad = gf->pad(input, output, pads, std::nullopt);
} }
{ // slice without output { // slice without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162}, auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
DataType::UInt32, runtime); DataType::UInt32, runtime);
vector<int> starts = {2, 10, 1, 5}; vector<int> starts = {2, 10, 1, 5};
@ -84,7 +84,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96})); EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96}));
} }
{ // slice with output { // slice with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162}, auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
DataType::UInt32, runtime); DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{2, 1, 100, 96}, auto output = make_ref<TensorObj>(Shape{2, 1, 100, 96},
@ -95,7 +95,7 @@ TEST(GraphFactory, ops) {
gf->slice(input, output, starts, ends, std::nullopt, std::nullopt); gf->slice(input, output, starts, ends, std::nullopt, std::nullopt);
} }
{ // concat without output { // concat without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto t1 = auto t1 =
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
auto t2 = auto t2 =
@ -104,7 +104,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9})); EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
} }
{ // concat with output { // concat with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto t1 = auto t1 =
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
auto t2 = auto t2 =
@ -114,7 +114,7 @@ TEST(GraphFactory, ops) {
auto concat = gf->concat(TensorVec{t1, t2}, o0, 3); auto concat = gf->concat(TensorVec{t1, t2}, o0, 3);
} }
{ // split without output { // split without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
auto split = gf->split(input, 3, 4); auto split = gf->split(input, 3, 4);
@ -126,7 +126,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6})); EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
} }
{ // split with output { // split with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
auto output0 = auto output0 =
@ -141,14 +141,14 @@ TEST(GraphFactory, ops) {
input, TensorVec{output0, output1, output2, output3}, 3, 4); input, TensorVec{output0, output1, output2, output3}, 3, 4);
} }
{ // extend without output { // extend without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto extend = gf->extend(input, 2, 1); auto extend = gf->extend(input, 2, 1);
EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4})); EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
} }
{ // extend with output { // extend with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto output = auto output =
@ -156,7 +156,7 @@ TEST(GraphFactory, ops) {
auto extend = gf->extend(input, output, 2, 1); auto extend = gf->extend(input, output, 2, 1);
} }
{ // maxpool without output { // maxpool without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162}, auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime); DataType::UInt32, runtime);
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2, const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
@ -165,7 +165,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80})); EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
} }
{ // maxpool with output { // maxpool with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162}, auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime); DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{1, 64, 80, 80}, auto output = make_ref<TensorObj>(Shape{1, 64, 80, 80},
@ -176,7 +176,7 @@ TEST(GraphFactory, ops) {
gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw); gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw);
} }
{ // add without output { // add without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input0 = auto input0 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto input1 = auto input1 =
@ -185,7 +185,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4})); EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
} }
{ // add with output { // add with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input0 = auto input0 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto input1 = auto input1 =
@ -195,7 +195,7 @@ TEST(GraphFactory, ops) {
auto add = gf->add(input0, input1, output); auto add = gf->add(input0, input1, output);
} }
{ // gather without output { // gather without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto index = auto index =
@ -204,7 +204,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4})); EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
} }
{ // gather with output { // gather with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime); make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto index = auto index =
@ -214,7 +214,7 @@ TEST(GraphFactory, ops) {
auto gather = gf->gather(input, index, output, 1); auto gather = gf->gather(input, index, output, 1);
} }
{ // reshape without output { // reshape without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
vector<int> dims = {3, 2, 4, 3}; vector<int> dims = {3, 2, 4, 3};
@ -222,7 +222,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3})); EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3}));
} }
{ // reshape with output { // reshape with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
vector<int> dims = {3, 2, 4, 3}; vector<int> dims = {3, 2, 4, 3};
@ -231,14 +231,14 @@ TEST(GraphFactory, ops) {
auto reshape = gf->reshape(input, output, dims); auto reshape = gf->reshape(input, output, dims);
} }
{ // flatten without output { // flatten without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto flatten = gf->flatten(input); auto flatten = gf->flatten(input);
EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72})); EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72}));
} }
{ // flatten without output { // flatten without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto output = auto output =
@ -246,14 +246,14 @@ TEST(GraphFactory, ops) {
auto flatten = gf->flatten(input, output); auto flatten = gf->flatten(input, output);
} }
{ // identity without output { // identity without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto identity = gf->identity(input); auto identity = gf->identity(input);
EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4})); EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
} }
{ // identity without output { // identity without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime); GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime); make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto output = auto output =