forked from jiuyuan/InfiniTensor
Update: Rename GraphFactory -> GraphBuilder && Remove unnecessary outputs
This commit is contained in:
parent
7cf2d8f78f
commit
970c77d0f4
|
@ -21,12 +21,12 @@
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
class GraphFactoryObj {
|
class GraphBuilderObj {
|
||||||
private:
|
private:
|
||||||
Graph g;
|
Graph g;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
GraphFactoryObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
|
GraphBuilderObj(Runtime runtime) : g(make_ref<GraphObj>(runtime)) {}
|
||||||
|
|
||||||
// tensors
|
// tensors
|
||||||
Tensor tensor(Shape dim, const std::string &dtype);
|
Tensor tensor(Shape dim, const std::string &dtype);
|
|
@ -10,7 +10,7 @@ class TensorBaseObj;
|
||||||
class TensorObj;
|
class TensorObj;
|
||||||
class OperatorObj;
|
class OperatorObj;
|
||||||
class GraphObj;
|
class GraphObj;
|
||||||
class GraphFactoryObj;
|
class GraphBuilderObj;
|
||||||
class RuntimeObj;
|
class RuntimeObj;
|
||||||
class BlobObj;
|
class BlobObj;
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ using TensorBase = Ref<TensorBaseObj>;
|
||||||
using Tensor = Ref<TensorObj>;
|
using Tensor = Ref<TensorObj>;
|
||||||
using Operator = Ref<OperatorObj>;
|
using Operator = Ref<OperatorObj>;
|
||||||
using Graph = Ref<GraphObj>;
|
using Graph = Ref<GraphObj>;
|
||||||
using GraphFactory = Ref<GraphFactoryObj>;
|
using GraphBuilder = Ref<GraphBuilderObj>;
|
||||||
using Runtime = Ref<RuntimeObj>;
|
using Runtime = Ref<RuntimeObj>;
|
||||||
using Blob = Ref<BlobObj>;
|
using Blob = Ref<BlobObj>;
|
||||||
enum class OpType;
|
enum class OpType;
|
||||||
|
|
|
@ -122,7 +122,7 @@ def _onnx_datatype_tostring(dtype):
|
||||||
assert False, 'Unknown onnx datatype'
|
assert False, 'Unknown onnx datatype'
|
||||||
|
|
||||||
|
|
||||||
def import_onnx(gf: GraphFactory, net: str):
|
def import_onnx(gf: GraphBuilder, net: str):
|
||||||
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
|
ts, ds, ops, consts = dict(), dict(), dict(), dict() # (key, value) = (name, class)
|
||||||
model = onnx.load(net)
|
model = onnx.load(net)
|
||||||
|
|
||||||
|
|
|
@ -5,8 +5,8 @@ import sys
|
||||||
|
|
||||||
def main(netPath):
|
def main(netPath):
|
||||||
runtime = CpuRuntimeObj.getInstance()
|
runtime = CpuRuntimeObj.getInstance()
|
||||||
graphFactory = GraphFactoryObj(runtime)
|
graphBuilder = GraphBuilderObj(runtime)
|
||||||
import_onnx(graphFactory, netPath)
|
import_onnx(graphBuilder, netPath)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main(sys.argv[1])
|
main(sys.argv[1])
|
|
@ -1,8 +1,8 @@
|
||||||
#include "core/graph_factory.h"
|
#include "core/graph_builder.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) {
|
Tensor GraphBuilderObj::tensor(Shape dim, const std::string &dtype) {
|
||||||
if (dtype == "FLOAT") {
|
if (dtype == "FLOAT") {
|
||||||
return g->addTensor(dim, DataType::Float32);
|
return g->addTensor(dim, DataType::Float32);
|
||||||
}
|
}
|
||||||
|
@ -12,7 +12,7 @@ Tensor GraphFactoryObj::tensor(Shape dim, const std::string &dtype) {
|
||||||
IT_TODO_HALT_MSG("Unsupported data type");
|
IT_TODO_HALT_MSG("Unsupported data type");
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||||
int ph, int pw, int sh, int sw, int dh, int dw,
|
int ph, int pw, int sh, int sw, int dh, int dw,
|
||||||
Tensor bias) {
|
Tensor bias) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -23,7 +23,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw,
|
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, int ph, int pw,
|
||||||
int sh, int sw, int dh, int dw, Tensor bias) {
|
int sh, int sw, int dh, int dw, Tensor bias) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
|
Tensor w0 = g->addTensor(weight->getDims(), weight->getDType());
|
||||||
|
@ -31,7 +31,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, int ph, int pw,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
Operator GraphBuilderObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||||
int dh, int dw, Tensor bias) {
|
int dh, int dw, Tensor bias) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -42,7 +42,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::conv(Tensor input, Tensor weight,
|
Operator GraphBuilderObj::conv(Tensor input, Tensor weight,
|
||||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||||
int dh, int dw, Tensor bias) {
|
int dh, int dw, Tensor bias) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -51,7 +51,7 @@ Operator GraphFactoryObj::conv(Tensor input, Tensor weight,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
|
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
|
||||||
bool transB, Tensor bias, ActType act) {
|
bool transB, Tensor bias, ActType act) {
|
||||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||||
|
@ -61,7 +61,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, Tensor C, bool transA,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
|
Operator GraphBuilderObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
|
||||||
Tensor bias, ActType act) {
|
Tensor bias, ActType act) {
|
||||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||||
|
@ -69,7 +69,7 @@ Operator GraphFactoryObj::matmul(Tensor A, Tensor B, bool transA, bool transB,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||||
int ph, int pw, int sh, int sw, int dh,
|
int ph, int pw, int sh, int sw, int dh,
|
||||||
int dw, int oph, int opw, int group,
|
int dw, int oph, int opw, int group,
|
||||||
Tensor bias, ActType act) {
|
Tensor bias, ActType act) {
|
||||||
|
@ -81,7 +81,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
|
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
|
||||||
int sh, int sw, int dh, int dw, int oph,
|
int sh, int sw, int dh, int dw, int oph,
|
||||||
int opw, int group, Tensor bias,
|
int opw, int group, Tensor bias,
|
||||||
ActType act) {
|
ActType act) {
|
||||||
|
@ -92,7 +92,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, int ph, int pw,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||||
int dh, int dw, int oph, int opw, int group,
|
int dh, int dw, int oph, int opw, int group,
|
||||||
Tensor bias, ActType act) {
|
Tensor bias, ActType act) {
|
||||||
|
@ -104,7 +104,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight,
|
Operator GraphBuilderObj::convTrans(Tensor input, Tensor weight,
|
||||||
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
ConvBaseObj::PaddingMode pm, int sh, int sw,
|
||||||
int dh, int dw, int oph, int opw, int group,
|
int dh, int dw, int oph, int opw, int group,
|
||||||
Tensor bias, ActType act) {
|
Tensor bias, ActType act) {
|
||||||
|
@ -115,7 +115,7 @@ Operator GraphFactoryObj::convTrans(Tensor input, Tensor weight,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
|
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
|
||||||
const int dilation, Tensor bias, ActType act) {
|
const int dilation, Tensor bias, ActType act) {
|
||||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||||
|
@ -125,7 +125,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, Tensor C, const int width,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width,
|
Operator GraphBuilderObj::g2bmm(Tensor A, Tensor B, const int width,
|
||||||
const int dilation, Tensor bias, ActType act) {
|
const int dilation, Tensor bias, ActType act) {
|
||||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||||
|
@ -133,7 +133,7 @@ Operator GraphFactoryObj::g2bmm(Tensor A, Tensor B, const int width,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C,
|
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, Tensor C,
|
||||||
const int dilation, Tensor bias, ActType act) {
|
const int dilation, Tensor bias, ActType act) {
|
||||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||||
|
@ -142,7 +142,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, Tensor C,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation,
|
Operator GraphBuilderObj::gbmml(Tensor A, Tensor B, const int dilation,
|
||||||
Tensor bias, ActType act) {
|
Tensor bias, ActType act) {
|
||||||
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
Tensor i0 = g->addTensor(A->getDims(), A->getDType());
|
||||||
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
Tensor i1 = g->addTensor(B->getDims(), B->getDType());
|
||||||
|
@ -150,7 +150,7 @@ Operator GraphFactoryObj::gbmml(Tensor A, Tensor B, const int dilation,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::pad(Tensor input, Tensor output,
|
Operator GraphBuilderObj::pad(Tensor input, Tensor output,
|
||||||
const vector<int> &pads,
|
const vector<int> &pads,
|
||||||
const optional<const vector<int>> &axis) {
|
const optional<const vector<int>> &axis) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -159,14 +159,14 @@ Operator GraphFactoryObj::pad(Tensor input, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::pad(Tensor input, const vector<int> &pads,
|
Operator GraphBuilderObj::pad(Tensor input, const vector<int> &pads,
|
||||||
const optional<const vector<int>> &axis) {
|
const optional<const vector<int>> &axis) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<PadObj>(i0, nullptr, pads, axis);
|
auto op = g->addOp<PadObj>(i0, nullptr, pads, axis);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::slice(Tensor input, Tensor output,
|
Operator GraphBuilderObj::slice(Tensor input, Tensor output,
|
||||||
const vector<int> &starts,
|
const vector<int> &starts,
|
||||||
const vector<int> &ends,
|
const vector<int> &ends,
|
||||||
const optional<const vector<int>> &axis,
|
const optional<const vector<int>> &axis,
|
||||||
|
@ -177,7 +177,7 @@ Operator GraphFactoryObj::slice(Tensor input, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts,
|
Operator GraphBuilderObj::slice(Tensor input, const vector<int> &starts,
|
||||||
const vector<int> &ends,
|
const vector<int> &ends,
|
||||||
const optional<const vector<int>> &axis,
|
const optional<const vector<int>> &axis,
|
||||||
const optional<const vector<int>> &steps) {
|
const optional<const vector<int>> &steps) {
|
||||||
|
@ -186,7 +186,7 @@ Operator GraphFactoryObj::slice(Tensor input, const vector<int> &starts,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) {
|
Operator GraphBuilderObj::concat(TensorVec inputs, Tensor output, int dim) {
|
||||||
TensorVec is;
|
TensorVec is;
|
||||||
for (auto input : inputs) {
|
for (auto input : inputs) {
|
||||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -197,7 +197,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, Tensor output, int dim) {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::concat(TensorVec inputs, int dim) {
|
Operator GraphBuilderObj::concat(TensorVec inputs, int dim) {
|
||||||
TensorVec is;
|
TensorVec is;
|
||||||
for (auto input : inputs) {
|
for (auto input : inputs) {
|
||||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -207,7 +207,7 @@ Operator GraphFactoryObj::concat(TensorVec inputs, int dim) {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||||
int dim, int num) {
|
int dim, int num) {
|
||||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||||
if (outputs.has_value()) {
|
if (outputs.has_value()) {
|
||||||
|
@ -224,13 +224,13 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::split(Tensor input, int dim, int num) {
|
Operator GraphBuilderObj::split(Tensor input, int dim, int num) {
|
||||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num);
|
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, num);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
Operator GraphBuilderObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||||
int dim, const vector<int> &ratio) {
|
int dim, const vector<int> &ratio) {
|
||||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||||
if (outputs.has_value()) {
|
if (outputs.has_value()) {
|
||||||
|
@ -247,14 +247,14 @@ Operator GraphFactoryObj::split(Tensor input, std::optional<TensorVec> outputs,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::split(Tensor input, int dim,
|
Operator GraphBuilderObj::split(Tensor input, int dim,
|
||||||
const vector<int> &ratio) {
|
const vector<int> &ratio) {
|
||||||
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
Tensor i = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio);
|
auto op = g->addOp<SplitObj>(i, std::nullopt, dim, ratio);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim,
|
Operator GraphBuilderObj::extend(Tensor input, Tensor output, int dim,
|
||||||
int num) {
|
int num) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -262,13 +262,13 @@ Operator GraphFactoryObj::extend(Tensor input, Tensor output, int dim,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::extend(Tensor input, int dim, int num) {
|
Operator GraphBuilderObj::extend(Tensor input, int dim, int num) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<ExtendObj>(i0, nullptr, dim, num);
|
auto op = g->addOp<ExtendObj>(i0, nullptr, dim, num);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw,
|
Operator GraphBuilderObj::maxpool(Tensor input, Tensor output, int kh, int kw,
|
||||||
int dh, int dw, int ph, int pw, int sh,
|
int dh, int dw, int ph, int pw, int sh,
|
||||||
int sw) {
|
int sw) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -278,14 +278,14 @@ Operator GraphFactoryObj::maxpool(Tensor input, Tensor output, int kh, int kw,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
|
Operator GraphBuilderObj::maxpool(Tensor input, int kh, int kw, int dh, int dw,
|
||||||
int ph, int pw, int sh, int sw) {
|
int ph, int pw, int sh, int sw) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<MaxPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
|
auto op = g->addOp<MaxPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw,
|
Operator GraphBuilderObj::avgpool(Tensor input, Tensor output, int kh, int kw,
|
||||||
int dh, int dw, int ph, int pw, int sh,
|
int dh, int dw, int ph, int pw, int sh,
|
||||||
int sw) {
|
int sw) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
|
@ -295,14 +295,14 @@ Operator GraphFactoryObj::avgpool(Tensor input, Tensor output, int kh, int kw,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
|
Operator GraphBuilderObj::avgpool(Tensor input, int kh, int kw, int dh, int dw,
|
||||||
int ph, int pw, int sh, int sw) {
|
int ph, int pw, int sh, int sw) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<AvgPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
|
auto op = g->addOp<AvgPoolObj>(i0, nullptr, kh, kw, dh, dw, ph, pw, sh, sw);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) {
|
Operator GraphBuilderObj::add(Tensor input0, Tensor input1, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -310,14 +310,14 @@ Operator GraphFactoryObj::add(Tensor input0, Tensor input1, Tensor output) {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::add(Tensor input0, Tensor input1) {
|
Operator GraphBuilderObj::add(Tensor input0, Tensor input1) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
auto op = g->addOp<AddObj>(i0, i1, nullptr);
|
auto op = g->addOp<AddObj>(i0, i1, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -325,14 +325,14 @@ Operator GraphFactoryObj::sub(Tensor input0, Tensor input1, Tensor output) {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::sub(Tensor input0, Tensor input1) {
|
Operator GraphBuilderObj::sub(Tensor input0, Tensor input1) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) {
|
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -340,14 +340,14 @@ Operator GraphFactoryObj::mul(Tensor input0, Tensor input1, Tensor output) {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::mul(Tensor input0, Tensor input1) {
|
Operator GraphBuilderObj::mul(Tensor input0, Tensor input1) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
auto op = g->addOp<SubObj>(i0, i1, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) {
|
Operator GraphBuilderObj::div(Tensor input0, Tensor input1, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -355,14 +355,14 @@ Operator GraphFactoryObj::div(Tensor input0, Tensor input1, Tensor output) {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::div(Tensor input0, Tensor input1) {
|
Operator GraphBuilderObj::div(Tensor input0, Tensor input1) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
auto op = g->addOp<DivObj>(i0, i1, nullptr);
|
auto op = g->addOp<DivObj>(i0, i1, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) {
|
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -370,14 +370,14 @@ Operator GraphFactoryObj::pow(Tensor input0, Tensor input1, Tensor output) {
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::pow(Tensor input0, Tensor input1) {
|
Operator GraphBuilderObj::pow(Tensor input0, Tensor input1) {
|
||||||
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
Tensor i0 = g->addTensor(input0->getDims(), input0->getDType());
|
||||||
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
Tensor i1 = g->addTensor(input1->getDims(), input1->getDType());
|
||||||
auto op = g->addOp<PowObj>(i0, i1, nullptr);
|
auto op = g->addOp<PowObj>(i0, i1, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output,
|
Operator GraphBuilderObj::gather(Tensor input, Tensor index, Tensor output,
|
||||||
int axis) {
|
int axis) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -385,13 +385,13 @@ Operator GraphFactoryObj::gather(Tensor input, Tensor index, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::gather(Tensor input, Tensor index, int axis) {
|
Operator GraphBuilderObj::gather(Tensor input, Tensor index, int axis) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<GatherObj>(i0, index, nullptr, axis);
|
auto op = g->addOp<GatherObj>(i0, index, nullptr, axis);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::reshape(Tensor input, Tensor output,
|
Operator GraphBuilderObj::reshape(Tensor input, Tensor output,
|
||||||
const Shape &dims) {
|
const Shape &dims) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
|
@ -399,104 +399,104 @@ Operator GraphFactoryObj::reshape(Tensor input, Tensor output,
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::reshape(Tensor input, const Shape &dims) {
|
Operator GraphBuilderObj::reshape(Tensor input, const Shape &dims) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<ReshapeObj>(i0, nullptr, dims);
|
auto op = g->addOp<ReshapeObj>(i0, nullptr, dims);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::flatten(Tensor input, Tensor output) {
|
Operator GraphBuilderObj::flatten(Tensor input, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
auto op = g->addOpWithOutputs<FlattenObj>(i0, o0);
|
auto op = g->addOpWithOutputs<FlattenObj>(i0, o0);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::flatten(Tensor input) {
|
Operator GraphBuilderObj::flatten(Tensor input) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<FlattenObj>(i0, nullptr);
|
auto op = g->addOp<FlattenObj>(i0, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::identity(Tensor input, Tensor output) {
|
Operator GraphBuilderObj::identity(Tensor input, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
auto op = g->addOpWithOutputs<IdentityObj>(i0, o0);
|
auto op = g->addOpWithOutputs<IdentityObj>(i0, o0);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::identity(Tensor input) {
|
Operator GraphBuilderObj::identity(Tensor input) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<IdentityObj>(i0, nullptr);
|
auto op = g->addOp<IdentityObj>(i0, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::softmax(Tensor input, Tensor output) {
|
Operator GraphBuilderObj::softmax(Tensor input, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
auto op = g->addOpWithOutputs<SoftmaxObj>(i0, o0);
|
auto op = g->addOpWithOutputs<SoftmaxObj>(i0, o0);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::softmax(Tensor input) {
|
Operator GraphBuilderObj::softmax(Tensor input) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<SoftmaxObj>(i0, nullptr);
|
auto op = g->addOp<SoftmaxObj>(i0, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::relu(Tensor input, Tensor output) {
|
Operator GraphBuilderObj::relu(Tensor input, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
auto op = g->addOpWithOutputs<ReluObj>(i0, o0);
|
auto op = g->addOpWithOutputs<ReluObj>(i0, o0);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::relu(Tensor input) {
|
Operator GraphBuilderObj::relu(Tensor input) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<ReluObj>(i0, nullptr);
|
auto op = g->addOp<ReluObj>(i0, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::sigmoid(Tensor input, Tensor output) {
|
Operator GraphBuilderObj::sigmoid(Tensor input, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
auto op = g->addOpWithOutputs<SigmoidObj>(i0, o0);
|
auto op = g->addOpWithOutputs<SigmoidObj>(i0, o0);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::sigmoid(Tensor input) {
|
Operator GraphBuilderObj::sigmoid(Tensor input) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<SigmoidObj>(i0, nullptr);
|
auto op = g->addOp<SigmoidObj>(i0, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::tanh(Tensor input, Tensor output) {
|
Operator GraphBuilderObj::tanh(Tensor input, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
auto op = g->addOpWithOutputs<TanhObj>(i0, o0);
|
auto op = g->addOpWithOutputs<TanhObj>(i0, o0);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::tanh(Tensor input) {
|
Operator GraphBuilderObj::tanh(Tensor input) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<TanhObj>(i0, nullptr);
|
auto op = g->addOp<TanhObj>(i0, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::abs(Tensor input, Tensor output) {
|
Operator GraphBuilderObj::abs(Tensor input, Tensor output) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
Tensor o0 = g->addTensor(output->getDims(), output->getDType());
|
||||||
auto op = g->addOpWithOutputs<AbsObj>(i0, o0);
|
auto op = g->addOpWithOutputs<AbsObj>(i0, o0);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::abs(Tensor input) {
|
Operator GraphBuilderObj::abs(Tensor input) {
|
||||||
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
Tensor i0 = g->addTensor(input->getDims(), input->getDType());
|
||||||
auto op = g->addOp<AbsObj>(i0, nullptr);
|
auto op = g->addOp<AbsObj>(i0, nullptr);
|
||||||
return op;
|
return op;
|
||||||
}
|
}
|
||||||
|
|
||||||
Operator GraphFactoryObj::memBound(const TensorVec &inputs,
|
Operator GraphBuilderObj::memBound(const TensorVec &inputs,
|
||||||
const TensorVec &outputs,
|
const TensorVec &outputs,
|
||||||
const std::vector<nnet::Tensor> &nnetInputs,
|
const std::vector<nnet::Tensor> &nnetInputs,
|
||||||
nnet::Expr expr, double exec_time,
|
nnet::Expr expr, double exec_time,
|
|
@ -71,14 +71,6 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
||||||
}
|
}
|
||||||
} else { // if outputs have been created, check their shapes
|
} else { // if outputs have been created, check their shapes
|
||||||
for (size_t i = 0; i < shapes.size(); ++i) {
|
for (size_t i = 0; i < shapes.size(); ++i) {
|
||||||
printf("|* i = %ld *|\n", i);
|
|
||||||
printf("shapes:\n");
|
|
||||||
for (auto shape : shapes[i])
|
|
||||||
printf("%d ", shape);
|
|
||||||
printf("\n");
|
|
||||||
for (auto dim : outputs[i]->getDims())
|
|
||||||
printf("%d ", dim);
|
|
||||||
printf("\n");
|
|
||||||
if (shapes[i] != outputs[i]->getDims())
|
if (shapes[i] != outputs[i]->getDims())
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
#ifdef USE_CUDA
|
#ifdef USE_CUDA
|
||||||
#include "cuda/operator_timer.h"
|
#include "cuda/operator_timer.h"
|
||||||
#endif
|
#endif
|
||||||
#include "core/graph_factory.h"
|
#include "core/graph_builder.h"
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
@ -21,7 +21,7 @@ void register_operator_timer(py::module &m) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
void init_graph_factory(py::module &m) {
|
void init_graph_builder(py::module &m) {
|
||||||
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "RuntimeObj");
|
||||||
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
|
||||||
m, "CpuRuntimeObj")
|
m, "CpuRuntimeObj")
|
||||||
|
@ -75,112 +75,112 @@ void init_graph_factory(py::module &m) {
|
||||||
py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj");
|
py::class_<AbsObj, std::shared_ptr<AbsObj>, OperatorObj>(m, "AbsObj");
|
||||||
py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>(
|
py::class_<MemBoundObj, std::shared_ptr<MemBoundObj>, OperatorObj>(
|
||||||
m, "MemBoundObj");
|
m, "MemBoundObj");
|
||||||
py::class_<GraphFactory>(m, "GraphFactory");
|
py::class_<GraphBuilder>(m, "GraphBuilder");
|
||||||
py::class_<GraphFactoryObj>(m, "GraphFactoryObj")
|
py::class_<GraphBuilderObj>(m, "GraphBuilderObj")
|
||||||
.def(py::init<Runtime>())
|
.def(py::init<Runtime>())
|
||||||
.def("tensor",
|
.def("tensor",
|
||||||
py::overload_cast<Shape, const std::string &>(
|
py::overload_cast<Shape, const std::string &>(
|
||||||
&GraphFactoryObj::tensor),
|
&GraphBuilderObj::tensor),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("conv",
|
.def("conv",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
|
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
|
||||||
int, Tensor>(&GraphFactoryObj::conv),
|
int, Tensor>(&GraphBuilderObj::conv),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("matmul",
|
.def("matmul",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
py::overload_cast<Tensor, Tensor, Tensor, bool, bool, Tensor,
|
||||||
ActType>(&GraphFactoryObj::matmul),
|
ActType>(&GraphBuilderObj::matmul),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("convTrans",
|
.def("convTrans",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
|
py::overload_cast<Tensor, Tensor, Tensor, int, int, int, int, int,
|
||||||
int, int, int, int, Tensor, ActType>(
|
int, int, int, int, Tensor, ActType>(
|
||||||
&GraphFactoryObj::convTrans),
|
&GraphBuilderObj::convTrans),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("g2bmm",
|
.def("g2bmm",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, const int, const int,
|
py::overload_cast<Tensor, Tensor, Tensor, const int, const int,
|
||||||
Tensor, ActType>(&GraphFactoryObj::g2bmm),
|
Tensor, ActType>(&GraphBuilderObj::g2bmm),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("gbmml",
|
.def("gbmml",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, const int, Tensor,
|
py::overload_cast<Tensor, Tensor, Tensor, const int, Tensor,
|
||||||
ActType>(&GraphFactoryObj::gbmml),
|
ActType>(&GraphBuilderObj::gbmml),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
|
|
||||||
.def("pad",
|
.def("pad",
|
||||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||||
const optional<const vector<int>> &>(
|
const optional<const vector<int>> &>(
|
||||||
&GraphFactoryObj::pad),
|
&GraphBuilderObj::pad),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("slice",
|
.def("slice",
|
||||||
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
py::overload_cast<Tensor, Tensor, const vector<int> &,
|
||||||
const vector<int> &,
|
const vector<int> &,
|
||||||
const optional<const vector<int>> &,
|
const optional<const vector<int>> &,
|
||||||
const optional<const vector<int>> &>(
|
const optional<const vector<int>> &>(
|
||||||
&GraphFactoryObj::slice),
|
&GraphBuilderObj::slice),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def(
|
.def(
|
||||||
"concat",
|
"concat",
|
||||||
py::overload_cast<TensorVec, Tensor, int>(&GraphFactoryObj::concat),
|
py::overload_cast<TensorVec, Tensor, int>(&GraphBuilderObj::concat),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("split",
|
.def("split",
|
||||||
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
|
py::overload_cast<Tensor, std::optional<TensorVec>, int, int>(
|
||||||
&GraphFactoryObj::split),
|
&GraphBuilderObj::split),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("extend",
|
.def("extend",
|
||||||
py::overload_cast<Tensor, Tensor, int, int>(
|
py::overload_cast<Tensor, Tensor, int, int>(
|
||||||
&GraphFactoryObj::extend),
|
&GraphBuilderObj::extend),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("maxpool",
|
.def("maxpool",
|
||||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||||
int, int>(&GraphFactoryObj::maxpool),
|
int, int>(&GraphBuilderObj::maxpool),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("avgpool",
|
.def("avgpool",
|
||||||
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
py::overload_cast<Tensor, Tensor, int, int, int, int, int, int,
|
||||||
int, int>(&GraphFactoryObj::avgpool),
|
int, int>(&GraphBuilderObj::avgpool),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("add",
|
.def("add",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::add),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::add),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("sub",
|
.def("sub",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::sub),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::sub),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("mul",
|
.def("mul",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::mul),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::mul),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("div",
|
.def("div",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::div),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::div),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("pow",
|
.def("pow",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor>(&GraphFactoryObj::pow),
|
py::overload_cast<Tensor, Tensor, Tensor>(&GraphBuilderObj::pow),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("gather",
|
.def("gather",
|
||||||
py::overload_cast<Tensor, Tensor, Tensor, int>(
|
py::overload_cast<Tensor, Tensor, Tensor, int>(
|
||||||
&GraphFactoryObj::gather),
|
&GraphBuilderObj::gather),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("reshape",
|
.def("reshape",
|
||||||
py::overload_cast<Tensor, Tensor, const Shape &>(
|
py::overload_cast<Tensor, Tensor, const Shape &>(
|
||||||
&GraphFactoryObj::reshape),
|
&GraphBuilderObj::reshape),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("flatten",
|
.def("flatten",
|
||||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::flatten),
|
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::flatten),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("identity",
|
.def("identity",
|
||||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::identity),
|
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::identity),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("softmax",
|
.def("softmax",
|
||||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::softmax),
|
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::softmax),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::relu),
|
.def("relu", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::relu),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("sigmoid",
|
.def("sigmoid",
|
||||||
py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::sigmoid),
|
py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::sigmoid),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::tanh),
|
.def("tanh", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::tanh),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphFactoryObj::abs),
|
.def("abs", py::overload_cast<Tensor, Tensor>(&GraphBuilderObj::abs),
|
||||||
policy::reference_internal)
|
policy::reference_internal)
|
||||||
.def("memBound",
|
.def("memBound",
|
||||||
py::overload_cast<const TensorVec &, const TensorVec &,
|
py::overload_cast<const TensorVec &, const TensorVec &,
|
||||||
const std::vector<nnet::Tensor> &, nnet::Expr,
|
const std::vector<nnet::Tensor> &, nnet::Expr,
|
||||||
double, std::string>(&GraphFactoryObj::memBound),
|
double, std::string>(&GraphBuilderObj::memBound),
|
||||||
policy::reference_internal);
|
policy::reference_internal);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,5 +188,5 @@ void init_graph_factory(py::module &m) {
|
||||||
|
|
||||||
PYBIND11_MODULE(pyinfinitensor, m) {
|
PYBIND11_MODULE(pyinfinitensor, m) {
|
||||||
infini::register_operator_timer(m);
|
infini::register_operator_timer(m);
|
||||||
infini::init_graph_factory(m);
|
infini::init_graph_builder(m);
|
||||||
}
|
}
|
|
@ -1,12 +1,12 @@
|
||||||
#include "core/graph_factory.h"
|
#include "core/graph_builder.h"
|
||||||
#include "test.h"
|
#include "test.h"
|
||||||
|
|
||||||
namespace infini {
|
namespace infini {
|
||||||
|
|
||||||
TEST(GraphFactory, ops) {
|
TEST(GraphBuilder, ops) {
|
||||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||||
{ // conv without output
|
{ // conv without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||||
auto weight =
|
auto weight =
|
||||||
|
@ -15,7 +15,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
|
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
|
||||||
}
|
}
|
||||||
{ // conv with output
|
{ // conv with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||||
auto weight =
|
auto weight =
|
||||||
|
@ -25,21 +25,21 @@ TEST(GraphFactory, ops) {
|
||||||
auto conv = gf->conv(input, weight, output, 1, 1);
|
auto conv = gf->conv(input, weight, output, 1, 1);
|
||||||
}
|
}
|
||||||
{ // matmul without output
|
{ // matmul without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
|
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
|
||||||
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
|
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
|
||||||
auto matmul = gf->matmul(A, B);
|
auto matmul = gf->matmul(A, B);
|
||||||
EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2}));
|
EXPECT_EQ(matmul->getOutput()->getDims(), (Shape{1, 3, 2}));
|
||||||
}
|
}
|
||||||
{ // matmul with output
|
{ // matmul with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
|
auto A = make_ref<TensorObj>(Shape{1, 3, 5}, DataType::UInt32, runtime);
|
||||||
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
|
auto B = make_ref<TensorObj>(Shape{1, 5, 2}, DataType::UInt32, runtime);
|
||||||
auto C = make_ref<TensorObj>(Shape{1, 3, 2}, DataType::UInt32, runtime);
|
auto C = make_ref<TensorObj>(Shape{1, 3, 2}, DataType::UInt32, runtime);
|
||||||
auto matmul = gf->matmul(A, B, C);
|
auto matmul = gf->matmul(A, B, C);
|
||||||
}
|
}
|
||||||
{ // convtrans without output
|
{ // convtrans without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
|
||||||
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
|
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
|
||||||
|
@ -48,7 +48,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2}));
|
EXPECT_EQ(convtrans->getOutput()->getDims(), (Shape{1, 448, 2, 2}));
|
||||||
}
|
}
|
||||||
{ // convtrans with output
|
{ // convtrans with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{1, 228, 1, 1}, DataType::UInt32, runtime);
|
||||||
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
|
auto weight = make_ref<TensorObj>(Shape{228, 448, 2, 2},
|
||||||
|
@ -58,7 +58,7 @@ TEST(GraphFactory, ops) {
|
||||||
auto convtrans = gf->convTrans(input, weight, 0, 0);
|
auto convtrans = gf->convTrans(input, weight, 0, 0);
|
||||||
}
|
}
|
||||||
{ // pad without output
|
{ // pad without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||||
DataType::UInt32, runtime);
|
DataType::UInt32, runtime);
|
||||||
vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5};
|
vector<int> pads = {2, 10, 1, 5, 0, 10, 1, 5};
|
||||||
|
@ -66,7 +66,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172}));
|
EXPECT_EQ(pad->getOutput()->getDims(), (Shape{3, 84, 164, 172}));
|
||||||
}
|
}
|
||||||
{ // pad with output
|
{ // pad with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||||
DataType::UInt32, runtime);
|
DataType::UInt32, runtime);
|
||||||
auto output = make_ref<TensorObj>(Shape{3, 84, 164, 172},
|
auto output = make_ref<TensorObj>(Shape{3, 84, 164, 172},
|
||||||
|
@ -75,7 +75,7 @@ TEST(GraphFactory, ops) {
|
||||||
auto pad = gf->pad(input, output, pads, std::nullopt);
|
auto pad = gf->pad(input, output, pads, std::nullopt);
|
||||||
}
|
}
|
||||||
{ // slice without output
|
{ // slice without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
|
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
|
||||||
DataType::UInt32, runtime);
|
DataType::UInt32, runtime);
|
||||||
vector<int> starts = {2, 10, 1, 5};
|
vector<int> starts = {2, 10, 1, 5};
|
||||||
|
@ -84,7 +84,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96}));
|
EXPECT_EQ(slice->getOutput()->getDims(), (Shape{2, 1, 100, 96}));
|
||||||
}
|
}
|
||||||
{ // slice with output
|
{ // slice with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
|
auto input = make_ref<TensorObj>(Shape{10, 64, 162, 162},
|
||||||
DataType::UInt32, runtime);
|
DataType::UInt32, runtime);
|
||||||
auto output = make_ref<TensorObj>(Shape{2, 1, 100, 96},
|
auto output = make_ref<TensorObj>(Shape{2, 1, 100, 96},
|
||||||
|
@ -95,7 +95,7 @@ TEST(GraphFactory, ops) {
|
||||||
gf->slice(input, output, starts, ends, std::nullopt, std::nullopt);
|
gf->slice(input, output, starts, ends, std::nullopt, std::nullopt);
|
||||||
}
|
}
|
||||||
{ // concat without output
|
{ // concat without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto t1 =
|
auto t1 =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
|
||||||
auto t2 =
|
auto t2 =
|
||||||
|
@ -104,7 +104,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
|
EXPECT_EQ(concat->getOutput()->getDims(), (Shape{1, 3, 2, 9}));
|
||||||
}
|
}
|
||||||
{ // concat with output
|
{ // concat with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto t1 =
|
auto t1 =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 2, 4}, DataType::Float32, runtime);
|
||||||
auto t2 =
|
auto t2 =
|
||||||
|
@ -114,7 +114,7 @@ TEST(GraphFactory, ops) {
|
||||||
auto concat = gf->concat(TensorVec{t1, t2}, o0, 3);
|
auto concat = gf->concat(TensorVec{t1, t2}, o0, 3);
|
||||||
}
|
}
|
||||||
{ // split without output
|
{ // split without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
|
||||||
auto split = gf->split(input, 3, 4);
|
auto split = gf->split(input, 3, 4);
|
||||||
|
@ -126,7 +126,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
EXPECT_EQ(split->getOutput(3)->getDims(), (Shape{1, 3, 2, 6}));
|
||||||
}
|
}
|
||||||
{ // split with output
|
{ // split with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 2, 15}, DataType::Float32, runtime);
|
||||||
auto output0 =
|
auto output0 =
|
||||||
|
@ -141,14 +141,14 @@ TEST(GraphFactory, ops) {
|
||||||
input, TensorVec{output0, output1, output2, output3}, 3, 4);
|
input, TensorVec{output0, output1, output2, output3}, 3, 4);
|
||||||
}
|
}
|
||||||
{ // extend without output
|
{ // extend without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||||
auto extend = gf->extend(input, 2, 1);
|
auto extend = gf->extend(input, 2, 1);
|
||||||
EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
|
EXPECT_EQ(extend->getOutput()->getDims(), (Shape{2, 3, 6, 4}));
|
||||||
}
|
}
|
||||||
{ // extend with output
|
{ // extend with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||||
auto output =
|
auto output =
|
||||||
|
@ -156,7 +156,7 @@ TEST(GraphFactory, ops) {
|
||||||
auto extend = gf->extend(input, output, 2, 1);
|
auto extend = gf->extend(input, output, 2, 1);
|
||||||
}
|
}
|
||||||
{ // maxpool without output
|
{ // maxpool without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||||
DataType::UInt32, runtime);
|
DataType::UInt32, runtime);
|
||||||
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
|
const int kh = 3, kw = 3, dh = 1, dw = 1, ph = 0, pw = 0, sh = 2,
|
||||||
|
@ -165,7 +165,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
|
EXPECT_EQ(maxpool->getOutput()->getDims(), (Shape{1, 64, 80, 80}));
|
||||||
}
|
}
|
||||||
{ // maxpool with output
|
{ // maxpool with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
auto input = make_ref<TensorObj>(Shape{1, 64, 162, 162},
|
||||||
DataType::UInt32, runtime);
|
DataType::UInt32, runtime);
|
||||||
auto output = make_ref<TensorObj>(Shape{1, 64, 80, 80},
|
auto output = make_ref<TensorObj>(Shape{1, 64, 80, 80},
|
||||||
|
@ -176,7 +176,7 @@ TEST(GraphFactory, ops) {
|
||||||
gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw);
|
gf->maxpool(input, output, kh, kw, dh, dw, ph, pw, sh, sw);
|
||||||
}
|
}
|
||||||
{ // add without output
|
{ // add without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input0 =
|
auto input0 =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||||
auto input1 =
|
auto input1 =
|
||||||
|
@ -185,7 +185,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
EXPECT_EQ(add->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
||||||
}
|
}
|
||||||
{ // add with output
|
{ // add with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input0 =
|
auto input0 =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::UInt32, runtime);
|
||||||
auto input1 =
|
auto input1 =
|
||||||
|
@ -195,7 +195,7 @@ TEST(GraphFactory, ops) {
|
||||||
auto add = gf->add(input0, input1, output);
|
auto add = gf->add(input0, input1, output);
|
||||||
}
|
}
|
||||||
{ // gather without output
|
{ // gather without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||||
auto index =
|
auto index =
|
||||||
|
@ -204,7 +204,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
EXPECT_EQ(gather->getOutput()->getDims(), (Shape{1, 2, 1, 2, 4, 4}));
|
||||||
}
|
}
|
||||||
{ // gather with output
|
{ // gather with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::UInt32, runtime);
|
||||||
auto index =
|
auto index =
|
||||||
|
@ -214,7 +214,7 @@ TEST(GraphFactory, ops) {
|
||||||
auto gather = gf->gather(input, index, output, 1);
|
auto gather = gf->gather(input, index, output, 1);
|
||||||
}
|
}
|
||||||
{ // reshape without output
|
{ // reshape without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||||
vector<int> dims = {3, 2, 4, 3};
|
vector<int> dims = {3, 2, 4, 3};
|
||||||
|
@ -222,7 +222,7 @@ TEST(GraphFactory, ops) {
|
||||||
EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3}));
|
EXPECT_EQ(reshape->getOutput()->getDims(), (Shape{3, 2, 4, 3}));
|
||||||
}
|
}
|
||||||
{ // reshape with output
|
{ // reshape with output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||||
vector<int> dims = {3, 2, 4, 3};
|
vector<int> dims = {3, 2, 4, 3};
|
||||||
|
@ -231,14 +231,14 @@ TEST(GraphFactory, ops) {
|
||||||
auto reshape = gf->reshape(input, output, dims);
|
auto reshape = gf->reshape(input, output, dims);
|
||||||
}
|
}
|
||||||
{ // flatten without output
|
{ // flatten without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||||
auto flatten = gf->flatten(input);
|
auto flatten = gf->flatten(input);
|
||||||
EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72}));
|
EXPECT_EQ(flatten->getOutput()->getDims(), (Shape{72}));
|
||||||
}
|
}
|
||||||
{ // flatten without output
|
{ // flatten without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||||
auto output =
|
auto output =
|
||||||
|
@ -246,14 +246,14 @@ TEST(GraphFactory, ops) {
|
||||||
auto flatten = gf->flatten(input, output);
|
auto flatten = gf->flatten(input, output);
|
||||||
}
|
}
|
||||||
{ // identity without output
|
{ // identity without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||||
auto identity = gf->identity(input);
|
auto identity = gf->identity(input);
|
||||||
EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
EXPECT_EQ(identity->getOutput()->getDims(), (Shape{2, 3, 3, 4}));
|
||||||
}
|
}
|
||||||
{ // identity without output
|
{ // identity without output
|
||||||
GraphFactory gf = make_ref<GraphFactoryObj>(runtime);
|
GraphBuilder gf = make_ref<GraphBuilderObj>(runtime);
|
||||||
auto input =
|
auto input =
|
||||||
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
make_ref<TensorObj>(Shape{2, 3, 3, 4}, DataType::Float32, runtime);
|
||||||
auto output =
|
auto output =
|
Loading…
Reference in New Issue