Update: Rename GraphFactory -> GraphBuilder && Remove unnecessary outputs

This commit is contained in:
Pairshoe 2022-10-26 21:06:21 +08:00 committed by mazx
parent 7cf2d8f78f
commit 970c77d0f4
8 changed files with 136 additions and 144 deletions

View File

@ -21,12 +21,12 @@
namespace infini {
class GraphFactoryObj {
class GraphBuilderObj {
private:
Graph g;
public:
GraphFactoryObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
GraphBuilderObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
// tensors
Tensor tensor(Shape dim, const std::string &dtype);
@ -163,4 +163,4 @@ class GraphFactoryObj {
nnet::Expr expr, double exec_time, std::string hint = {});
};
} // namespace infini
} // namespace infini

View File

@ -10,7 +10,7 @@ class TensorBaseObj;
class TensorObj;
class OperatorObj;
class GraphObj;
class GraphFactoryObj;
class GraphBuilderObj;
class RuntimeObj;
class BlobObj;
@ -18,7 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
using Tensor = Ref<TensorObj>;
using Operator = Ref<OperatorObj>;
using Graph = Ref<GraphObj>;
using GraphFactory = Ref<GraphFactoryObj>;
using GraphBuilder = Ref<GraphBuilderObj>;
using Runtime = Ref<RuntimeObj>;
using Blob = Ref<BlobObj>;
enum class OpType;

View File

@ -122,7 +122,7 @@ def _onnx_datatype_tostring(dtype):
assert False, 'Unknown onnx datatype'
def import_onnx(gf: GraphFactory, net: str):
def import_onnx(gf: GraphBuilder, net: str):
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
model = onnx.load(net)

View File

@ -5,8 +5,8 @@ import sys
def main(netPath):
runtime = CpuRuntimeObj.getInstance()
graphFactory = GraphFactoryObj(runtime)
import_onnx(graphFactory, netPath)
graphBuilder = GraphBuilderObj(runtime)
import_onnx(graphBuilder, netPath)
if __name__ == "__main__":
main(sys.argv[1])
main(sys.argv[1])

View File

@ -1,8 +1,8 @@
#include "core/graph_factory.h"
#include "core/graph_builder.h"
namespace infini {
Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) {
Tensor GraphBuilderObj::tensor(Shape dim, const std::string &dtype) {
if (dtype == "FLOAT") {
return g->addTensor(dim, DataType::Float32);
}
@ -12,7 +12,7 @@ Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) {
IT_TODO_HALT_MSG("Unsupported data type");
}
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh, int dw,
Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -23,7 +23,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
return op;
}
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw,
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, int ph, int pw,
int sh, int sw, int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
@ -31,7 +31,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw,
return op;
}
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -42,7 +42,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
return op;
}
Operator GraphFactoryObj::conv(Tensor input, Tensor weight,
Operator GraphBuilderObj::conv(Tensor input, Tensor weight,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, Tensor bias) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -51,7 +51,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight,
return op;
}
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
bool transB, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -61,7 +61,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
return op;
}
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -69,7 +69,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
return op;
}
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
int ph, int pw, int sh, int sw, int dh,
int dw, int oph, int opw, int group,
Tensor bias, ActType act) {
@ -81,7 +81,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
return op;
}
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
int sh, int sw, int dh, int dw, int oph,
int opw, int group, Tensor bias,
ActType act) {
@ -92,7 +92,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
return op;
}
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, int oph, int opw, int group,
Tensor bias, ActType act) {
@ -104,7 +104,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
return op;
}
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight,
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight,
ConvBaseObj::PaddingMode pm, int sh, int sw,
int dh, int dw, int oph, int opw, int group,
Tensor bias, ActType act) {
@ -115,7 +115,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight,
return op;
}
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -125,7 +125,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
return op;
}
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width,
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, const int width,
const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -133,7 +133,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width,
return op;
}
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C,
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, Tensor C,
const int dilation, Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -142,7 +142,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C,
return op;
}
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation,
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, const int dilation,
Tensor bias, ActType act) {
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
@ -150,7 +150,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation,
return op;
}
Operator GraphFactoryObj::pad(Tensor input, Tensor output,
Operator GraphBuilderObj::pad(Tensor input, Tensor output,
const vector<int> &pads,
const optional<const vector<int>> &axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -159,14 +159,14 @@ Operator GraphFactoryObj::pad(Tensor input, Tensor output,
return op;
}
Operator GraphFactoryObj::pad(Tensor input, const vector<int> &pads,
Operator GraphBuilderObj::pad(Tensor input, const vector<int> &pads,
const optional<const vector<int>> &axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<PadObj>(i0, nullptr, pads, axis);
return op;
}
Operator GraphFactoryObj::slice(Tensor input, Tensor output,
Operator GraphBuilderObj::slice(Tensor input, Tensor output,
const vector<int> &starts,
const vector<int> &ends,
const optional<const vector<int>> &axis,
@ -177,7 +177,7 @@ Operator GraphFactoryObj::slice(Tensor input, Tensor output,
return op;
}
Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts,
Operator GraphBuilderObj::slice(Tensor input, const vector<int> &starts,
const vector<int> &ends,
const optional<const vector<int>> &axis,
const optional<const vector<int>> &steps) {
@ -186,7 +186,7 @@ Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts,
return op;
}
Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) {
Operator GraphBuilderObj::concat(TensorVec inputs, Tensor output, int dim) {
TensorVec is;
for (auto input : inputs) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
@ -197,7 +197,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) {
return op;
}
Operator GraphFactoryObj::concat(TensorVec inputs, int dim) {
Operator GraphBuilderObj::concat(TensorVec inputs, int dim) {
TensorVec is;
for (auto input : inputs) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
@ -207,7 +207,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, int dim) {
return op;
}
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
int dim, int num) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
if (outputs.has_value()) {
@ -224,13 +224,13 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
}
}
Operator GraphFactoryObj::split(Tensor input, int dim, int num) {
Operator GraphBuilderObj::split(Tensor input, int dim, int num) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num);
return op;
}
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
int dim, const vector<int> &ratio) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
if (outputs.has_value()) {
@ -247,14 +247,14 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
}
}
Operator GraphFactoryObj::split(Tensor input, int dim,
Operator GraphBuilderObj::split(Tensor input, int dim,
const vector<int> &ratio) {
Tensor i = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio);
return op;
}
Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim,
Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim,
int num) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -262,13 +262,13 @@ Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim,
return op;
}
Operator GraphFactoryObj::extend(Tensor input, int dim, int num) {
Operator GraphBuilderObj::extend(Tensor input, int dim, int num) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ExtendObj>(i0, nullptr, dim, num);
return op;
}
Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw,
Operator GraphBuilderObj::maxpool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -278,14 +278,14 @@ Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw,
return op;
}
Operator GraphFactoryObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
Operator GraphBuilderObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<MaxPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
return op;
}
Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw,
Operator GraphBuilderObj::avgpool(Tensor input, Tensor output, int kh, int kw,
int dh, int dw, int ph, int pw, int sh,
int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
@ -295,14 +295,14 @@ Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw,
return op;
}
Operator GraphFactoryObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
Operator GraphBuilderObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
int ph, int pw, int sh, int sw) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<AvgPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
return op;
}
Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) {
Operator GraphBuilderObj::add(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -310,14 +310,14 @@ Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) {
return op;
}
Operator GraphFactoryObj::add(Tensor input0, Tensor input1) {
Operator GraphBuilderObj::add(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<AddObj>(i0, i1, nullptr);
return op;
}
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) {
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -325,14 +325,14 @@ Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) {
return op;
}
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1) {
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<SubObj>(i0, i1, nullptr);
return op;
}
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) {
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -340,14 +340,14 @@ Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) {
return op;
}
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1) {
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<SubObj>(i0, i1, nullptr);
return op;
}
Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) {
Operator GraphBuilderObj::div(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -355,14 +355,14 @@ Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) {
return op;
}
Operator GraphFactoryObj::div(Tensor input0, Tensor input1) {
Operator GraphBuilderObj::div(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<DivObj>(i0, i1, nullptr);
return op;
}
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) {
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1, Tensor output) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -370,14 +370,14 @@ Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) {
return op;
}
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1) {
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1) {
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
auto op = g->addOp<PowObj>(i0, i1, nullptr);
return op;
}
Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output,
Operator GraphBuilderObj::gather(Tensor input, Tensor index, Tensor output,
int axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -385,13 +385,13 @@ Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output,
return op;
}
Operator GraphFactoryObj::gather(Tensor input, Tensor index, int axis) {
Operator GraphBuilderObj::gather(Tensor input, Tensor index, int axis) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<GatherObj>(i0, index, nullptr, axis);
return op;
}
Operator GraphFactoryObj::reshape(Tensor input, Tensor output,
Operator GraphBuilderObj::reshape(Tensor input, Tensor output,
const Shape &dims) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
@ -399,104 +399,104 @@ Operator GraphFactoryObj::reshape(Tensor input, Tensor output,
return op;
}
Operator GraphFactoryObj::reshape(Tensor input, const Shape &dims) {
Operator GraphBuilderObj::reshape(Tensor input, const Shape &dims) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ReshapeObj>(i0, nullptr, dims);
return op;
}
Operator GraphFactoryObj::flatten(Tensor input, Tensor output) {
Operator GraphBuilderObj::flatten(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<FlattenObj>(i0, o0);
return op;
}
Operator GraphFactoryObj::flatten(Tensor input) {
Operator GraphBuilderObj::flatten(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<FlattenObj>(i0, nullptr);
return op;
}
Operator GraphFactoryObj::identity(Tensor input, Tensor output) {
Operator GraphBuilderObj::identity(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<IdentityObj>(i0, o0);
return op;
}
Operator GraphFactoryObj::identity(Tensor input) {
Operator GraphBuilderObj::identity(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<IdentityObj>(i0, nullptr);
return op;
}
Operator GraphFactoryObj::softmax(Tensor input, Tensor output) {
Operator GraphBuilderObj::softmax(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SoftmaxObj>(i0, o0);
return op;
}
Operator GraphFactoryObj::softmax(Tensor input) {
Operator GraphBuilderObj::softmax(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SoftmaxObj>(i0, nullptr);
return op;
}
Operator GraphFactoryObj::relu(Tensor input, Tensor output) {
Operator GraphBuilderObj::relu(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<ReluObj>(i0, o0);
return op;
}
Operator GraphFactoryObj::relu(Tensor input) {
Operator GraphBuilderObj::relu(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<ReluObj>(i0, nullptr);
return op;
}
Operator GraphFactoryObj::sigmoid(Tensor input, Tensor output) {
Operator GraphBuilderObj::sigmoid(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<SigmoidObj>(i0, o0);
return op;
}
Operator GraphFactoryObj::sigmoid(Tensor input) {
Operator GraphBuilderObj::sigmoid(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<SigmoidObj>(i0, nullptr);
return op;
}
Operator GraphFactoryObj::tanh(Tensor input, Tensor output) {
Operator GraphBuilderObj::tanh(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<TanhObj>(i0, o0);
return op;
}
Operator GraphFactoryObj::tanh(Tensor input) {
Operator GraphBuilderObj::tanh(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<TanhObj>(i0, nullptr);
return op;
}
Operator GraphFactoryObj::abs(Tensor input, Tensor output) {
Operator GraphBuilderObj::abs(Tensor input, Tensor output) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
auto op = g->addOpWithOutputs<AbsObj>(i0, o0);
return op;
}
Operator GraphFactoryObj::abs(Tensor input) {
Operator GraphBuilderObj::abs(Tensor input) {
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
auto op = g->addOp<AbsObj>(i0, nullptr);
return op;
}
Operator GraphFactoryObj::memBound(const TensorVec &inputs,
Operator GraphBuilderObj::memBound(const TensorVec &inputs,
const TensorVec &outputs,
const std::vector<nnet::Tensor> &nnetInputs,
nnet::Expr expr, double exec_time,
@ -516,4 +516,4 @@ Operator GraphFactoryObj::memBound(const TensorVec &inputs,
return op;
}
} // namespace infini
} // namespace infini

View File

@ -71,14 +71,6 @@ bool OperatorObj::checkValid(GraphObj *graph) {
}
} else { // if outputs have been created, check their shapes
for (size_t i = 0; i < shapes.size(); ++i) {
printf("|* i = %ld *|\n", i);
printf("shapes:\n");
for (auto shape : shapes[i])
printf("%d ", shape);
printf("\n");
for (auto dim : outputs[i]->getDims())
printf("%d ", dim);
printf("\n");
if (shapes[i] != outputs[i]->getDims())
return false;
}

View File

@ -2,7 +2,7 @@
#ifdef USE_CUDA
#include "cuda/operator_timer.h"
#endif
#include "core/graph_factory.h"
#include "core/graph_builder.h"
namespace py = pybind11;
namespace infini {
@ -21,7 +21,7 @@ void register_operator_timer(py::module &m) {
#endif
}
void init_graph_factory(py::module &m) {
void init_graph_builder(py::module &m) {
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
m, "CpuRuntimeObj")
@ -75,112 +75,112 @@ void init_graph_factory(py::module &m) {
py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj");
py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>(
m, "MemBoundObj");
py::class_<GraphFactory>(m, "GraphFactory");
py::class_<GraphFactoryObj>(m, "GraphFactoryObj")
py::class_<GraphBuilder>(m, "GraphBuilder");
py::class_<GraphBuilderObj>(m, "GraphBuilderObj")
.def(py::init<Runtime>())
.def("tensor",
py::overload_cast<Shape, const std::string &>(
&GraphFactoryObj::tensor),
&GraphBuilderObj::tensor),
policy::reference_internal)
.def("conv",
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
int, Tensor>(&GraphFactoryObj::conv),
int, Tensor>(&GraphBuilderObj::conv),
policy::reference_internal)
.def("matmul",
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
ActType>(&GraphFactoryObj::matmul),
ActType>(&GraphBuilderObj::matmul),
policy::reference_internal)
.def("convTrans",
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
int, int, int, int, Tensor, ActType>(
&GraphFactoryObj::convTrans),
&GraphBuilderObj::convTrans),
policy::reference_internal)
.def("g2bmm",
py::overload_cast<Tensor, Tensor, Tensor, const int, const int,
Tensor, ActType>(&GraphFactoryObj::g2bmm),
Tensor, ActType>(&GraphBuilderObj::g2bmm),
policy::reference_internal)
.def("gbmml",
py::overload_cast<Tensor, Tensor, Tensor, const int, Tensor,
ActType>(&GraphFactoryObj::gbmml),
ActType>(&GraphBuilderObj::gbmml),
policy::reference_internal)
.def("pad",
py::overload_cast<Tensor, Tensor, const vector<int> &,
const optional<const vector<int>> &>(
&GraphFactoryObj::pad),
&GraphBuilderObj::pad),
policy::reference_internal)
.def("slice",
py::overload_cast<Tensor, Tensor, const vector<int> &,
const vector<int> &,
const optional<const vector<int>> &,
const optional<const vector<int>> &>(
&GraphFactoryObj::slice),
&GraphBuilderObj::slice),
policy::reference_internal)
.def(
"concat",
py::overload_cast<TensorVec, Tensor, int>(&GraphFactoryObj::concat),
py::overload_cast<TensorVec, Tensor, int>(&GraphBuilderObj::concat),
policy::reference_internal)
.def("split",
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
&GraphFactoryObj::split),
&GraphBuilderObj::split),
policy::reference_internal)
.def("extend",
py::overload_cast<Tensor, Tensor, int, int>(
&GraphFactoryObj::extend),
&GraphBuilderObj::extend),
policy::reference_internal)
.def("maxpool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&GraphFactoryObj::maxpool),
int, int>(&GraphBuilderObj::maxpool),
policy::reference_internal)
.def("avgpool",
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
int, int>(&GraphFactoryObj::avgpool),
int, int>(&GraphBuilderObj::avgpool),
policy::reference_internal)
.def("add",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::add),
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::add),
policy::reference_internal)
.def("sub",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::sub),
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::sub),
policy::reference_internal)
.def("mul",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::mul),
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::mul),
policy::reference_internal)
.def("div",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::div),
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::div),
policy::reference_internal)
.def("pow",
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::pow),
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::pow),
policy::reference_internal)
.def("gather",
py::overload_cast<Tensor, Tensor, Tensor, int>(
&GraphFactoryObj::gather),
&GraphBuilderObj::gather),
policy::reference_internal)
.def("reshape",
py::overload_cast<Tensor, Tensor, const Shape &>(
&GraphFactoryObj::reshape),
&GraphBuilderObj::reshape),
policy::reference_internal)
.def("flatten",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::flatten),
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::flatten),
policy::reference_internal)
.def("identity",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::identity),
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::identity),
policy::reference_internal)
.def("softmax",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::softmax),
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::softmax),
policy::reference_internal)
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::relu),
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::relu),
policy::reference_internal)
.def("sigmoid",
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::sigmoid),
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::sigmoid),
policy::reference_internal)
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::tanh),
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::tanh),
policy::reference_internal)
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::abs),
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::abs),
policy::reference_internal)
.def("memBound",
py::overload_cast<const TensorVec &, const TensorVec &,
const std::vector<nnet::Tensor> &, nnet::Expr,
double, std::string>(&GraphFactoryObj::memBound),
double, std::string>(&GraphBuilderObj::memBound),
policy::reference_internal);
}
@ -188,5 +188,5 @@ void init_graph_factory(py::module &m) {
PYBIND11_MODULE(pyinfinitensor, m) {
infini::register_operator_timer(m);
infini::init_graph_factory(m);
}
infini::init_graph_builder(m);
}

View File

@ -1,12 +1,12 @@
#include "core/graph_factory.h"
#include "core/graph_builder.h"
#include "test.h"
namespace infini {
TEST(GraphFactory, ops) {
TEST(GraphBuilder, ops) {
Runtime runtime = CpuRuntimeObj::getInstance();
{ // conv without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto weight =
@ -15,7 +15,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
}
{ // conv with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto weight =
@ -25,21 +25,21 @@ TEST(GraphFactory, ops) {
auto conv = gf->conv(input, weight, output, 1, 1);
}
{ // matmul without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
auto matmul = gf->matmul(A, B);
EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2}));
}
{ // matmul with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
auto C = make_ref<TensorObj>(Shape{1, 3, 2}, DataType::UInt32, runtime);
auto matmul = gf->matmul(A, B, C);
}
{ // convtrans without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
@ -48,7 +48,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2}));
}
{ // convtrans with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
@ -58,7 +58,7 @@ TEST(GraphFactory, ops) {
auto convtrans = gf->convTrans(input, weight, 0, 0);
}
{ // pad without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5};
@ -66,7 +66,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172}));
}
{ // pad with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{3, 84, 164, 172},
@ -75,7 +75,7 @@ TEST(GraphFactory, ops) {
auto pad = gf->pad(input, output, pads, std::nullopt);
}
{ // slice without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
DataType::UInt32, runtime);
vector<int> starts = {2, 10, 1, 5};
@ -84,7 +84,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96}));
}
{ // slice with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{2, 1, 100, 96},
@ -95,7 +95,7 @@ TEST(GraphFactory, ops) {
gf->slice(input, output, starts, ends, std::nullopt, std::nullopt);
}
{ // concat without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto t1 =
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
auto t2 =
@ -104,7 +104,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
}
{ // concat with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto t1 =
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
auto t2 =
@ -114,7 +114,7 @@ TEST(GraphFactory, ops) {
auto concat = gf->concat(TensorVec{t1, t2}, o0, 3);
}
{ // split without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
auto split = gf->split(input, 3, 4);
@ -126,7 +126,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
}
{ // split with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
auto output0 =
@ -141,14 +141,14 @@ TEST(GraphFactory, ops) {
input, TensorVec{output0, output1, output2, output3}, 3, 4);
}
{ // extend without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto extend = gf->extend(input, 2, 1);
EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
}
{ // extend with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto output =
@ -156,7 +156,7 @@ TEST(GraphFactory, ops) {
auto extend = gf->extend(input, output, 2, 1);
}
{ // maxpool without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
@ -165,7 +165,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
}
{ // maxpool with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
DataType::UInt32, runtime);
auto output = make_ref<TensorObj>(Shape{1, 64, 80, 80},
@ -176,7 +176,7 @@ TEST(GraphFactory, ops) {
gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw);
}
{ // add without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input0 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto input1 =
@ -185,7 +185,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
}
{ // add with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input0 =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
auto input1 =
@ -195,7 +195,7 @@ TEST(GraphFactory, ops) {
auto add = gf->add(input0, input1, output);
}
{ // gather without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto index =
@ -204,7 +204,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
}
{ // gather with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
auto index =
@ -214,7 +214,7 @@ TEST(GraphFactory, ops) {
auto gather = gf->gather(input, index, output, 1);
}
{ // reshape without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
vector<int> dims = {3, 2, 4, 3};
@ -222,7 +222,7 @@ TEST(GraphFactory, ops) {
EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3}));
}
{ // reshape with output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
vector<int> dims = {3, 2, 4, 3};
@ -231,14 +231,14 @@ TEST(GraphFactory, ops) {
auto reshape = gf->reshape(input, output, dims);
}
{ // flatten without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto flatten = gf->flatten(input);
EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72}));
}
{ // flatten without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto output =
@ -246,14 +246,14 @@ TEST(GraphFactory, ops) {
auto flatten = gf->flatten(input, output);
}
{ // identity without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto identity = gf->identity(input);
EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
}
{ // identity without output
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
auto input =
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
auto output =
@ -262,4 +262,4 @@ TEST(GraphFactory, ops) {
}
}
} // namespace infini
} // namespace infini