forked from jiuyuan/InfiniTensor
Update: Rename GraphFactory -> GraphBuilder && Remove unnecessary outputs
This commit is contained in:
parent
7cf2d8f78f
commit
970c77d0f4
|
@ -21,12 +21,12 @@
|
|||
|
||||
namespace infini {
|
||||
|
||||
class GraphFactoryObj {
|
||||
class GraphBuilderObj {
|
||||
private:
|
||||
Graph g;
|
||||
|
||||
public:
|
||||
GraphFactoryObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
|
||||
GraphBuilderObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
|
||||
|
||||
// tensors
|
||||
Tensor tensor(Shape dim, const std::string &dtype);
|
||||
|
@ -163,4 +163,4 @@ class GraphFactoryObj {
|
|||
nnet::Expr expr, double exec_time, std::string hint = {});
|
||||
};
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
|
@ -10,7 +10,7 @@ class TensorBaseObj;
|
|||
class TensorObj;
|
||||
class OperatorObj;
|
||||
class GraphObj;
|
||||
class GraphFactoryObj;
|
||||
class GraphBuilderObj;
|
||||
class RuntimeObj;
|
||||
class BlobObj;
|
||||
|
||||
|
@ -18,7 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
|
|||
using Tensor = Ref<TensorObj>;
|
||||
using Operator = Ref<OperatorObj>;
|
||||
using Graph = Ref<GraphObj>;
|
||||
using GraphFactory = Ref<GraphFactoryObj>;
|
||||
using GraphBuilder = Ref<GraphBuilderObj>;
|
||||
using Runtime = Ref<RuntimeObj>;
|
||||
using Blob = Ref<BlobObj>;
|
||||
enum class OpType;
|
||||
|
|
|
@ -122,7 +122,7 @@ def _onnx_datatype_tostring(dtype):
|
|||
assert False, 'Unknown onnx datatype'
|
||||
|
||||
|
||||
def import_onnx(gf: GraphFactory, net: str):
|
||||
def import_onnx(gf: GraphBuilder, net: str):
|
||||
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
|
||||
model = onnx.load(net)
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@ import sys
|
|||
|
||||
def main(netPath):
|
||||
runtime = CpuRuntimeObj.getInstance()
|
||||
graphFactory = GraphFactoryObj(runtime)
|
||||
import_onnx(graphFactory, netPath)
|
||||
graphBuilder = GraphBuilderObj(runtime)
|
||||
import_onnx(graphBuilder, netPath)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(sys.argv[1])
|
||||
main(sys.argv[1])
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
#include "core/graph_factory.h"
|
||||
#include "core/graph_builder.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) {
|
||||
Tensor GraphBuilderObj::tensor(Shape dim, const std::string &dtype) {
|
||||
if (dtype == "FLOAT") {
|
||||
return g->addTensor(dim, DataType::Float32);
|
||||
}
|
||||
|
@ -12,7 +12,7 @@ Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) {
|
|||
IT_TODO_HALT_MSG("Unsupported data type");
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||
int ph, int pw, int sh, int sw, int dh, int dw,
|
||||
Tensor bias) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -23,7 +23,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw,
|
||||
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw, Tensor bias) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
|
||||
|
@ -31,7 +31,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||
int dh, int dw, Tensor bias) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -42,7 +42,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight,
|
||||
Operator GraphBuilderObj::conv(Tensor input, Tensor weight,
|
||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||
int dh, int dw, Tensor bias) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -51,7 +51,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
|
||||
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
|
||||
bool transB, Tensor bias, ActType act) {
|
||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||
|
@ -61,7 +61,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
|
||||
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
|
||||
Tensor bias, ActType act) {
|
||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||
|
@ -69,7 +69,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||
int ph, int pw, int sh, int sw, int dh,
|
||||
int dw, int oph, int opw, int group,
|
||||
Tensor bias, ActType act) {
|
||||
|
@ -81,7 +81,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
|
||||
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
|
||||
int sh, int sw, int dh, int dw, int oph,
|
||||
int opw, int group, Tensor bias,
|
||||
ActType act) {
|
||||
|
@ -92,7 +92,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||
int dh, int dw, int oph, int opw, int group,
|
||||
Tensor bias, ActType act) {
|
||||
|
@ -104,7 +104,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight,
|
||||
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight,
|
||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||
int dh, int dw, int oph, int opw, int group,
|
||||
Tensor bias, ActType act) {
|
||||
|
@ -115,7 +115,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
|
||||
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
|
||||
const int dilation, Tensor bias, ActType act) {
|
||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||
|
@ -125,7 +125,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width,
|
||||
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, const int width,
|
||||
const int dilation, Tensor bias, ActType act) {
|
||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||
|
@ -133,7 +133,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C,
|
||||
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, Tensor C,
|
||||
const int dilation, Tensor bias, ActType act) {
|
||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||
|
@ -142,7 +142,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation,
|
||||
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, const int dilation,
|
||||
Tensor bias, ActType act) {
|
||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||
|
@ -150,7 +150,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::pad(Tensor input, Tensor output,
|
||||
Operator GraphBuilderObj::pad(Tensor input, Tensor output,
|
||||
const vector<int> &pads,
|
||||
const optional<const vector<int>> &axis) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -159,14 +159,14 @@ Operator GraphFactoryObj::pad(Tensor input, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::pad(Tensor input, const vector<int> &pads,
|
||||
Operator GraphBuilderObj::pad(Tensor input, const vector<int> &pads,
|
||||
const optional<const vector<int>> &axis) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<PadObj>(i0, nullptr, pads, axis);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::slice(Tensor input, Tensor output,
|
||||
Operator GraphBuilderObj::slice(Tensor input, Tensor output,
|
||||
const vector<int> &starts,
|
||||
const vector<int> &ends,
|
||||
const optional<const vector<int>> &axis,
|
||||
|
@ -177,7 +177,7 @@ Operator GraphFactoryObj::slice(Tensor input, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts,
|
||||
Operator GraphBuilderObj::slice(Tensor input, const vector<int> &starts,
|
||||
const vector<int> &ends,
|
||||
const optional<const vector<int>> &axis,
|
||||
const optional<const vector<int>> &steps) {
|
||||
|
@ -186,7 +186,7 @@ Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) {
|
||||
Operator GraphBuilderObj::concat(TensorVec inputs, Tensor output, int dim) {
|
||||
TensorVec is;
|
||||
for (auto input : inputs) {
|
||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -197,7 +197,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::concat(TensorVec inputs, int dim) {
|
||||
Operator GraphBuilderObj::concat(TensorVec inputs, int dim) {
|
||||
TensorVec is;
|
||||
for (auto input : inputs) {
|
||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -207,7 +207,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, int dim) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||
int dim, int num) {
|
||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||
if (outputs.has_value()) {
|
||||
|
@ -224,13 +224,13 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
|||
}
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::split(Tensor input, int dim, int num) {
|
||||
Operator GraphBuilderObj::split(Tensor input, int dim, int num) {
|
||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||
int dim, const vector<int> &ratio) {
|
||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||
if (outputs.has_value()) {
|
||||
|
@ -247,14 +247,14 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
|||
}
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::split(Tensor input, int dim,
|
||||
Operator GraphBuilderObj::split(Tensor input, int dim,
|
||||
const vector<int> &ratio) {
|
||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim,
|
||||
Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim,
|
||||
int num) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -262,13 +262,13 @@ Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::extend(Tensor input, int dim, int num) {
|
||||
Operator GraphBuilderObj::extend(Tensor input, int dim, int num) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<ExtendObj>(i0, nullptr, dim, num);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw,
|
||||
Operator GraphBuilderObj::maxpool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh,
|
||||
int sw) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -278,14 +278,14 @@ Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
|
||||
Operator GraphBuilderObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
|
||||
int ph, int pw, int sh, int sw) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<MaxPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw,
|
||||
Operator GraphBuilderObj::avgpool(Tensor input, Tensor output, int kh, int kw,
|
||||
int dh, int dw, int ph, int pw, int sh,
|
||||
int sw) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
|
@ -295,14 +295,14 @@ Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
|
||||
Operator GraphBuilderObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
|
||||
int ph, int pw, int sh, int sw) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<AvgPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) {
|
||||
Operator GraphBuilderObj::add(Tensor input0, Tensor input1, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -310,14 +310,14 @@ Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::add(Tensor input0, Tensor input1) {
|
||||
Operator GraphBuilderObj::add(Tensor input0, Tensor input1) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
auto op = g->addOp<AddObj>(i0, i1, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
||||
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -325,14 +325,14 @@ Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1) {
|
||||
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) {
|
||||
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -340,14 +340,14 @@ Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1) {
|
||||
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) {
|
||||
Operator GraphBuilderObj::div(Tensor input0, Tensor input1, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -355,14 +355,14 @@ Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::div(Tensor input0, Tensor input1) {
|
||||
Operator GraphBuilderObj::div(Tensor input0, Tensor input1) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
auto op = g->addOp<DivObj>(i0, i1, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) {
|
||||
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -370,14 +370,14 @@ Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) {
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1) {
|
||||
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1) {
|
||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||
auto op = g->addOp<PowObj>(i0, i1, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output,
|
||||
Operator GraphBuilderObj::gather(Tensor input, Tensor index, Tensor output,
|
||||
int axis) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -385,13 +385,13 @@ Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::gather(Tensor input, Tensor index, int axis) {
|
||||
Operator GraphBuilderObj::gather(Tensor input, Tensor index, int axis) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<GatherObj>(i0, index, nullptr, axis);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::reshape(Tensor input, Tensor output,
|
||||
Operator GraphBuilderObj::reshape(Tensor input, Tensor output,
|
||||
const Shape &dims) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
|
@ -399,104 +399,104 @@ Operator GraphFactoryObj::reshape(Tensor input, Tensor output,
|
|||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::reshape(Tensor input, const Shape &dims) {
|
||||
Operator GraphBuilderObj::reshape(Tensor input, const Shape &dims) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<ReshapeObj>(i0, nullptr, dims);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::flatten(Tensor input, Tensor output) {
|
||||
Operator GraphBuilderObj::flatten(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<FlattenObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::flatten(Tensor input) {
|
||||
Operator GraphBuilderObj::flatten(Tensor input) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<FlattenObj>(i0, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::identity(Tensor input, Tensor output) {
|
||||
Operator GraphBuilderObj::identity(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<IdentityObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::identity(Tensor input) {
|
||||
Operator GraphBuilderObj::identity(Tensor input) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<IdentityObj>(i0, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::softmax(Tensor input, Tensor output) {
|
||||
Operator GraphBuilderObj::softmax(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<SoftmaxObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::softmax(Tensor input) {
|
||||
Operator GraphBuilderObj::softmax(Tensor input) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<SoftmaxObj>(i0, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::relu(Tensor input, Tensor output) {
|
||||
Operator GraphBuilderObj::relu(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<ReluObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::relu(Tensor input) {
|
||||
Operator GraphBuilderObj::relu(Tensor input) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<ReluObj>(i0, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::sigmoid(Tensor input, Tensor output) {
|
||||
Operator GraphBuilderObj::sigmoid(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<SigmoidObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::sigmoid(Tensor input) {
|
||||
Operator GraphBuilderObj::sigmoid(Tensor input) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<SigmoidObj>(i0, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::tanh(Tensor input, Tensor output) {
|
||||
Operator GraphBuilderObj::tanh(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<TanhObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::tanh(Tensor input) {
|
||||
Operator GraphBuilderObj::tanh(Tensor input) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<TanhObj>(i0, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::abs(Tensor input, Tensor output) {
|
||||
Operator GraphBuilderObj::abs(Tensor input, Tensor output) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||
auto op = g->addOpWithOutputs<AbsObj>(i0, o0);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::abs(Tensor input) {
|
||||
Operator GraphBuilderObj::abs(Tensor input) {
|
||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||
auto op = g->addOp<AbsObj>(i0, nullptr);
|
||||
return op;
|
||||
}
|
||||
|
||||
Operator GraphFactoryObj::memBound(const TensorVec &inputs,
|
||||
Operator GraphBuilderObj::memBound(const TensorVec &inputs,
|
||||
const TensorVec &outputs,
|
||||
const std::vector<nnet::Tensor> &nnetInputs,
|
||||
nnet::Expr expr, double exec_time,
|
||||
|
@ -516,4 +516,4 @@ Operator GraphFactoryObj::memBound(const TensorVec &inputs,
|
|||
return op;
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
|
@ -71,14 +71,6 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
|||
}
|
||||
} else { // if outputs have been created, check their shapes
|
||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||
printf("|* i = %ld *|\n", i);
|
||||
printf("shapes:\n");
|
||||
for (auto shape : shapes[i])
|
||||
printf("%d ", shape);
|
||||
printf("\n");
|
||||
for (auto dim : outputs[i]->getDims())
|
||||
printf("%d ", dim);
|
||||
printf("\n");
|
||||
if (shapes[i] != outputs[i]->getDims())
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
#ifdef USE_CUDA
|
||||
#include "cuda/operator_timer.h"
|
||||
#endif
|
||||
#include "core/graph_factory.h"
|
||||
#include "core/graph_builder.h"
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace infini {
|
||||
|
@ -21,7 +21,7 @@ void register_operator_timer(py::module &m) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void init_graph_factory(py::module &m) {
|
||||
void init_graph_builder(py::module &m) {
|
||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||
m, "CpuRuntimeObj")
|
||||
|
@ -75,112 +75,112 @@ void init_graph_factory(py::module &m) {
|
|||
py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj");
|
||||
py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>(
|
||||
m, "MemBoundObj");
|
||||
py::class_<GraphFactory>(m, "GraphFactory");
|
||||
py::class_<GraphFactoryObj>(m, "GraphFactoryObj")
|
||||
py::class_<GraphBuilder>(m, "GraphBuilder");
|
||||
py::class_<GraphBuilderObj>(m, "GraphBuilderObj")
|
||||
.def(py::init<Runtime>())
|
||||
.def("tensor",
|
||||
py::overload_cast<Shape, const std::string &>(
|
||||
&GraphFactoryObj::tensor),
|
||||
&GraphBuilderObj::tensor),
|
||||
policy::reference_internal)
|
||||
.def("conv",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
|
||||
int, Tensor>(&GraphFactoryObj::conv),
|
||||
int, Tensor>(&GraphBuilderObj::conv),
|
||||
policy::reference_internal)
|
||||
.def("matmul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||
ActType>(&GraphFactoryObj::matmul),
|
||||
ActType>(&GraphBuilderObj::matmul),
|
||||
policy::reference_internal)
|
||||
.def("convTrans",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
|
||||
int, int, int, int, Tensor, ActType>(
|
||||
&GraphFactoryObj::convTrans),
|
||||
&GraphBuilderObj::convTrans),
|
||||
policy::reference_internal)
|
||||
.def("g2bmm",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, const int, const int,
|
||||
Tensor, ActType>(&GraphFactoryObj::g2bmm),
|
||||
Tensor, ActType>(&GraphBuilderObj::g2bmm),
|
||||
policy::reference_internal)
|
||||
.def("gbmml",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, const int, Tensor,
|
||||
ActType>(&GraphFactoryObj::gbmml),
|
||||
ActType>(&GraphBuilderObj::gbmml),
|
||||
policy::reference_internal)
|
||||
|
||||
.def("pad",
|
||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||
const optional<const vector<int>> &>(
|
||||
&GraphFactoryObj::pad),
|
||||
&GraphBuilderObj::pad),
|
||||
policy::reference_internal)
|
||||
.def("slice",
|
||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||
const vector<int> &,
|
||||
const optional<const vector<int>> &,
|
||||
const optional<const vector<int>> &>(
|
||||
&GraphFactoryObj::slice),
|
||||
&GraphBuilderObj::slice),
|
||||
policy::reference_internal)
|
||||
.def(
|
||||
"concat",
|
||||
py::overload_cast<TensorVec, Tensor, int>(&GraphFactoryObj::concat),
|
||||
py::overload_cast<TensorVec, Tensor, int>(&GraphBuilderObj::concat),
|
||||
policy::reference_internal)
|
||||
.def("split",
|
||||
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
|
||||
&GraphFactoryObj::split),
|
||||
&GraphBuilderObj::split),
|
||||
policy::reference_internal)
|
||||
.def("extend",
|
||||
py::overload_cast<Tensor, Tensor, int, int>(
|
||||
&GraphFactoryObj::extend),
|
||||
&GraphBuilderObj::extend),
|
||||
policy::reference_internal)
|
||||
.def("maxpool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&GraphFactoryObj::maxpool),
|
||||
int, int>(&GraphBuilderObj::maxpool),
|
||||
policy::reference_internal)
|
||||
.def("avgpool",
|
||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||
int, int>(&GraphFactoryObj::avgpool),
|
||||
int, int>(&GraphBuilderObj::avgpool),
|
||||
policy::reference_internal)
|
||||
.def("add",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::add),
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::add),
|
||||
policy::reference_internal)
|
||||
.def("sub",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::sub),
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::sub),
|
||||
policy::reference_internal)
|
||||
.def("mul",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::mul),
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::mul),
|
||||
policy::reference_internal)
|
||||
.def("div",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::div),
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::div),
|
||||
policy::reference_internal)
|
||||
.def("pow",
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::pow),
|
||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::pow),
|
||||
policy::reference_internal)
|
||||
.def("gather",
|
||||
py::overload_cast<Tensor, Tensor, Tensor, int>(
|
||||
&GraphFactoryObj::gather),
|
||||
&GraphBuilderObj::gather),
|
||||
policy::reference_internal)
|
||||
.def("reshape",
|
||||
py::overload_cast<Tensor, Tensor, const Shape &>(
|
||||
&GraphFactoryObj::reshape),
|
||||
&GraphBuilderObj::reshape),
|
||||
policy::reference_internal)
|
||||
.def("flatten",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::flatten),
|
||||
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::flatten),
|
||||
policy::reference_internal)
|
||||
.def("identity",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::identity),
|
||||
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::identity),
|
||||
policy::reference_internal)
|
||||
.def("softmax",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::softmax),
|
||||
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::softmax),
|
||||
policy::reference_internal)
|
||||
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::relu),
|
||||
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::relu),
|
||||
policy::reference_internal)
|
||||
.def("sigmoid",
|
||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::sigmoid),
|
||||
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::sigmoid),
|
||||
policy::reference_internal)
|
||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::tanh),
|
||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::tanh),
|
||||
policy::reference_internal)
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::abs),
|
||||
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::abs),
|
||||
policy::reference_internal)
|
||||
.def("memBound",
|
||||
py::overload_cast<const TensorVec &, const TensorVec &,
|
||||
const std::vector<nnet::Tensor> &, nnet::Expr,
|
||||
double, std::string>(&GraphFactoryObj::memBound),
|
||||
double, std::string>(&GraphBuilderObj::memBound),
|
||||
policy::reference_internal);
|
||||
}
|
||||
|
||||
|
@ -188,5 +188,5 @@ void init_graph_factory(py::module &m) {
|
|||
|
||||
PYBIND11_MODULE(pyinfinitensor, m) {
|
||||
infini::register_operator_timer(m);
|
||||
infini::init_graph_factory(m);
|
||||
}
|
||||
infini::init_graph_builder(m);
|
||||
}
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
#include "core/graph_factory.h"
|
||||
#include "core/graph_builder.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(GraphFactory, ops) {
|
||||
TEST(GraphBuilder, ops) {
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
{ // conv without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||
auto weight =
|
||||
|
@ -15,7 +15,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
|
||||
}
|
||||
{ // conv with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||
auto weight =
|
||||
|
@ -25,21 +25,21 @@ TEST(GraphFactory, ops) {
|
|||
auto conv = gf->conv(input, weight, output, 1, 1);
|
||||
}
|
||||
{ // matmul without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
|
||||
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
|
||||
auto matmul = gf->matmul(A, B);
|
||||
EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2}));
|
||||
}
|
||||
{ // matmul with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
|
||||
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
|
||||
auto C = make_ref<TensorObj>(Shape{1, 3, 2}, DataType::UInt32, runtime);
|
||||
auto matmul = gf->matmul(A, B, C);
|
||||
}
|
||||
{ // convtrans without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
|
||||
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
|
||||
|
@ -48,7 +48,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2}));
|
||||
}
|
||||
{ // convtrans with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
|
||||
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
|
||||
|
@ -58,7 +58,7 @@ TEST(GraphFactory, ops) {
|
|||
auto convtrans = gf->convTrans(input, weight, 0, 0);
|
||||
}
|
||||
{ // pad without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||
DataType::UInt32, runtime);
|
||||
vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5};
|
||||
|
@ -66,7 +66,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172}));
|
||||
}
|
||||
{ // pad with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||
DataType::UInt32, runtime);
|
||||
auto output = make_ref<TensorObj>(Shape{3, 84, 164, 172},
|
||||
|
@ -75,7 +75,7 @@ TEST(GraphFactory, ops) {
|
|||
auto pad = gf->pad(input, output, pads, std::nullopt);
|
||||
}
|
||||
{ // slice without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
|
||||
DataType::UInt32, runtime);
|
||||
vector<int> starts = {2, 10, 1, 5};
|
||||
|
@ -84,7 +84,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96}));
|
||||
}
|
||||
{ // slice with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
|
||||
DataType::UInt32, runtime);
|
||||
auto output = make_ref<TensorObj>(Shape{2, 1, 100, 96},
|
||||
|
@ -95,7 +95,7 @@ TEST(GraphFactory, ops) {
|
|||
gf->slice(input, output, starts, ends, std::nullopt, std::nullopt);
|
||||
}
|
||||
{ // concat without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto t1 =
|
||||
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
|
||||
auto t2 =
|
||||
|
@ -104,7 +104,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
|
||||
}
|
||||
{ // concat with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto t1 =
|
||||
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
|
||||
auto t2 =
|
||||
|
@ -114,7 +114,7 @@ TEST(GraphFactory, ops) {
|
|||
auto concat = gf->concat(TensorVec{t1, t2}, o0, 3);
|
||||
}
|
||||
{ // split without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
|
||||
auto split = gf->split(input, 3, 4);
|
||||
|
@ -126,7 +126,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
||||
}
|
||||
{ // split with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
|
||||
auto output0 =
|
||||
|
@ -141,14 +141,14 @@ TEST(GraphFactory, ops) {
|
|||
input, TensorVec{output0, output1, output2, output3}, 3, 4);
|
||||
}
|
||||
{ // extend without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||
auto extend = gf->extend(input, 2, 1);
|
||||
EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
|
||||
}
|
||||
{ // extend with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||
auto output =
|
||||
|
@ -156,7 +156,7 @@ TEST(GraphFactory, ops) {
|
|||
auto extend = gf->extend(input, output, 2, 1);
|
||||
}
|
||||
{ // maxpool without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||
DataType::UInt32, runtime);
|
||||
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
|
||||
|
@ -165,7 +165,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
|
||||
}
|
||||
{ // maxpool with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||
DataType::UInt32, runtime);
|
||||
auto output = make_ref<TensorObj>(Shape{1, 64, 80, 80},
|
||||
|
@ -176,7 +176,7 @@ TEST(GraphFactory, ops) {
|
|||
gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw);
|
||||
}
|
||||
{ // add without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input0 =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||
auto input1 =
|
||||
|
@ -185,7 +185,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
||||
}
|
||||
{ // add with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input0 =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||
auto input1 =
|
||||
|
@ -195,7 +195,7 @@ TEST(GraphFactory, ops) {
|
|||
auto add = gf->add(input0, input1, output);
|
||||
}
|
||||
{ // gather without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||
auto index =
|
||||
|
@ -204,7 +204,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
||||
}
|
||||
{ // gather with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||
auto index =
|
||||
|
@ -214,7 +214,7 @@ TEST(GraphFactory, ops) {
|
|||
auto gather = gf->gather(input, index, output, 1);
|
||||
}
|
||||
{ // reshape without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||
vector<int> dims = {3, 2, 4, 3};
|
||||
|
@ -222,7 +222,7 @@ TEST(GraphFactory, ops) {
|
|||
EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3}));
|
||||
}
|
||||
{ // reshape with output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||
vector<int> dims = {3, 2, 4, 3};
|
||||
|
@ -231,14 +231,14 @@ TEST(GraphFactory, ops) {
|
|||
auto reshape = gf->reshape(input, output, dims);
|
||||
}
|
||||
{ // flatten without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||
auto flatten = gf->flatten(input);
|
||||
EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72}));
|
||||
}
|
||||
{ // flatten without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||
auto output =
|
||||
|
@ -246,14 +246,14 @@ TEST(GraphFactory, ops) {
|
|||
auto flatten = gf->flatten(input, output);
|
||||
}
|
||||
{ // identity without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||
auto identity = gf->identity(input);
|
||||
EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
||||
}
|
||||
{ // identity without output
|
||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
||||
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||
auto input =
|
||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||
auto output =
|
||||
|
@ -262,4 +262,4 @@ TEST(GraphFactory, ops) {
|
|||
}
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
} // namespace infini
|
Loading…
Reference in New Issue