Add search engine (#64)

* Add: tensor fuid

* [Intermediate state] Add: Graph ctor for OpVec

* Add: clone for operators

* tmp: search_engine

* search: init search Engine.

* Add: dummy mutator for the test of search engine

* search: add print graph.

* search: add partition.

* search: update comments.

* Fix: remain FUID in Tensor::clone

* Chore: rename GUidBaseType to UidBaseType

* Fix: connect NMutator to SearchEngine

* Chore: output

* Fix test_memboundOp: nmutator uses input runtime

* Chore: clang-format

* Chore: clang-format

* Fix: comments in the review

---------

Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
Co-authored-by: mazx <dyxdy@live.com>
This commit is contained in:
zhengly123 2023-02-12 18:27:52 +08:00 committed by GitHub
parent 14c9c82dab
commit c7ec9ee6e7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 1440 additions and 521 deletions

View File

@ -0,0 +1,15 @@
#pragma once
#include "core/mutator.h"
namespace infini {
class DummyMutator : public Mutator {
public:
DummyMutator(int candidatesLimit) : Mutator(candidatesLimit){};
virtual vector<Graph> run(const Graph &inGraph) override;
virtual vector<Graph> mergeMultiBranch(const Graph &inGraph) override;
virtual bool isMultiBranchMergable(const Graph &inGraph) override;
};
} // namespace infini

View File

@ -8,19 +8,22 @@ class GraphObj : public Object {
protected:
Runtime runtime;
TensorVec tensors;
TensorVec inputs;
TensorVec outputs;
// TODO: whether to record input and output tensors
// TensorVec inputs;
// TensorVec outputs;
OpVec ops;
public:
GraphObj(Runtime runtime) : runtime(runtime){};
GraphObj(Runtime runtime, OpVec ops_in);
string toString() const override;
Runtime getRuntime() const { return runtime; }
Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
Tensor addTensor(const Tensor &tensor);
TensorVec addTensor(const TensorVec &tensors);
Tensor cloneTensor(const Tensor &tensor) {
auto ret = addTensor(tensor->getDims(), tensor->getDType());
ret->dataMalloc();
ret->copyData(tensor);
auto ret = addTensor(tensor->clone(runtime));
return ret;
}
@ -45,12 +48,22 @@ class GraphObj : public Object {
}
const TensorVec &getTensors() const { return tensors; }
const TensorVec &getInputs() const { return inputs; }
const TensorVec &getOutputs() const { return outputs; }
/**
 * @brief Collect the graph's input tensors.
 *
 * A tensor is reported when its getOutputOf() is empty, i.e. (presumably) no
 * operator of the graph produces it. The list is rebuilt on every call.
 */
const TensorVec getInputs() const {
    TensorVec graphInputs;
    for (const auto &tensor : tensors) {
        if (tensor->getOutputOf())
            continue;
        graphInputs.emplace_back(tensor);
    }
    return graphInputs;
}
/**
 * @brief Collect the graph's output tensors.
 *
 * A tensor is reported when its getInputOf() list is empty, i.e. (presumably)
 * no operator of the graph consumes it. The list is rebuilt on every call.
 */
const TensorVec getOutputs() const {
    TensorVec graphOutputs;
    for (const auto &tensor : tensors) {
        if (!tensor->getInputOf().empty())
            continue;
        graphOutputs.emplace_back(tensor);
    }
    return graphOutputs;
}
const OpVec &getOperators() const { return ops; }
OpVec getComputeOps() const;
// TensorVec &getInputs();
// TensorVec &getOutputs();
void dataMalloc();

View File

@ -106,7 +106,7 @@ class KernelRegistry {
"Kernel not found for key {" +
to_string(enum_to_underlying(std::get<0>(kernelAttrs))) +
", " + OpRegistry::getOpName(std::get<1>(kernelAttrs)) +
", " + std::get<2>(kernelAttrs).toString());
", " + std::get<2>(kernelAttrs).toString() + "}");
return std::get<0>(it->second);
}
const KernelRecord &getKernelItem(const KernelAttrs &kernelAttrs) const {

View File

@ -8,12 +8,27 @@ class Mutator {
int candidatesLimit;
// // Statistical data
// int numTotalCandidates;
protected:
Runtime runtime;
public:
Mutator(int candidatesLimit) : candidatesLimit(candidatesLimit){};
Mutator(int candidatesLimit, Runtime runtime = CpuRuntimeObj::getInstance())
: candidatesLimit(candidatesLimit), runtime(runtime){};
virtual ~Mutator(){};
virtual vector<Graph> run(const Graph &in_graph) = 0;
/**
* @brief Merge a multi-branch graph into single-branch graphs.
*
* The base implementation is an unimplemented placeholder (IT_TODO_HALT);
* mutators that support merging are expected to override it.
*
* @param in_graph The multi-branch graph to transform.
* @return vector<Graph> Transformed graphs, excluding the original one.
*/
virtual vector<Graph> mergeMultiBranch(const Graph &in_graph) {
IT_TODO_HALT();
}
/**
* @brief Tell whether in_graph can be handled by mergeMultiBranch.
*
* The base implementation is an unimplemented placeholder (IT_TODO_HALT);
* mutators that support merging are expected to override it.
*
* @param in_graph The candidate multi-branch graph.
*/
virtual bool isMultiBranchMergable(const Graph &in_graph) {
IT_TODO_HALT();
}
};
} // namespace infini

View File

@ -4,27 +4,44 @@
namespace infini {
using GuidBaseType = int;
using UidBaseType = int;
class Guid {
class Uid {
private:
GuidBaseType guid;
UidBaseType uid;
public:
Uid(UidBaseType uid) : uid(uid) {}
Uid &operator=(const Uid &rhs) = delete;
operator UidBaseType() const { return uid; }
};
class Guid : public Uid {
private:
GuidBaseType generateGuid() {
static GuidBaseType guidCnt = 0;
UidBaseType generateGuid() {
static UidBaseType guidCnt = 0;
return ++guidCnt;
}
public:
Guid() { guid = generateGuid(); }
Guid(const Guid &rhs) { guid = generateGuid(); }
Guid &operator=(const Guid &rhs) {
guid = generateGuid();
return *this;
Guid() : Uid(generateGuid()) {}
Guid(const Guid &rhs) : Uid(generateGuid()) {}
};
/**
* @brief Family unique ID. Cloned tensors shared the same FUID.
*/
class Fuid : public Uid {
private:
UidBaseType generateFuid() {
static UidBaseType fuidCnt = 0;
return ++fuidCnt;
}
operator GuidBaseType() const { return guid; }
public:
Fuid() : Uid(generateFuid()) {}
Fuid(const Fuid &fuid) : Uid(fuid) {}
};
class Object {
@ -35,7 +52,7 @@ class Object {
virtual ~Object(){};
virtual string toString() const = 0;
void print() { std::cout << toString() << std::endl; }
GuidBaseType getGuid() const { return guid; }
UidBaseType getGuid() const { return guid; }
};
inline std::ostream &operator<<(std::ostream &os, const Object &obj) {

View File

@ -197,6 +197,16 @@ class OperatorObj : public Object {
virtual int numInputs() const = 0;
virtual int numOutputs() const = 0;
/**
* @brief Clone this operator and replace its inputs and outputs.
*
* Pure virtual: every concrete operator must provide an implementation.
*
* @param newInputs Tensors the clone uses as inputs.
* @param newOutputs Tensors the clone uses as outputs.
* @return Operator The cloned operator.
*/
virtual Operator clone(const TensorVec &newInputs,
const TensorVec &newOutputs) const = 0;
protected:
optional<vector<Shape>> inferShape() const;
vector<DataType> inferDataType() const;
@ -215,6 +225,18 @@ class OperatorObj : public Object {
virtual vector<int> getWorkloadVector() const { IT_TODO_HALT(); }
};
// Generates the clone() override for the operator class OpObj.
// The generated method copy-constructs a new OpObj from *this, swaps in the
// given input/output tensors, clears the copied predecessor/successor links,
// and asserts that checkValid(nullptr) succeeds before returning the clone.
// Intended to be placed inside the class definition of OpObj.
#define OP_CLONE(OpObj) \
virtual Operator clone(const TensorVec &newInputs, \
const TensorVec &newOutputs) const override { \
auto op = infini::make_ref<OpObj>(*this); \
op->inputs = newInputs; \
op->outputs = newOutputs; \
op->predecessors.clear(); \
op->successors.clear(); \
IT_ASSERT(op->checkValid(nullptr)); \
return op; \
}
} // namespace infini
namespace std {

View File

@ -0,0 +1,80 @@
#pragma once
#include "common.h"
#include "graph.h"
#include "mutator.h"
#include <unordered_map>
namespace infini {
/**
 * @brief Searches for faster variants of a computation graph.
 *
 * The engine partitions the input graph, explores operator merges and
 * mutations (delegated to a Mutator) for each partition, and ranks the
 * resulting candidates by the perf time reported by the runtime.
 */
class SearchEngine {
  private:
    Runtime runtimeExec;  // runtime used to build and profile candidates
    Ref<Mutator> mutator; // produces alternative graphs

  public:
    /**
     * @param _runtime Runtime on which candidate graphs are built and timed.
     * @param _mutator Mutator used for mutation and multi-branch merging.
     *
     * Both mutator handles are initialised from _mutator. Previously
     * mutationEngine stayed null, so getMutationEngine() returned null and
     * isMultiBranchMergable() dereferenced a null pointer.
     */
    SearchEngine(Runtime _runtime, Ref<Mutator> _mutator)
        : runtimeExec(_runtime), mutator(_mutator), mutationEngine(_mutator) {}
    ~SearchEngine() {}

  private: // Configurations
    size_t partitionThreshold =
        3; // cut nodes whose #in + #out >= partitionThreshold
    size_t GRAPH_SIZE = 16; // num of best graphs.

  private: // Composed objects
    // Refers to the same Mutator as `mutator`.
    std::shared_ptr<Mutator> mutationEngine;

  public:
    /// Returns the mutator this engine was constructed with.
    std::shared_ptr<Mutator> getMutationEngine() { return mutationEngine; };

    struct GroupEdge {
        int v, next;
        GroupEdge() = delete;
    };

    struct Candidate { // a graph with perf
        std::shared_ptr<Graph> graph;
        double perf = INFINITY;
    };

    class MetaGraph { // a graph of subgraphs, for searching.
      public:
        MetaGraph() {}
        ~MetaGraph() {}
        struct Node {
            Graph graph;          // operators grouped into this node
            std::vector<int> suc; // indices of successor nodes
            std::vector<int> pre; // indices of predecessor nodes
            int type, cnt; // type: isComputeOp flag; cnt: pending predecessors
        };
        std::vector<Node> nodes;
    };

    Graph run(const Graph graph); // entrance of search engine.
    std::vector<Graph> search(const Graph &graph); // search for a partition.

  private:
    std::vector<Graph> partitionGraph(const Graph graph);
    std::shared_ptr<MetaGraph> buildMetaGraphWithGraph(const Graph graph);
    std::shared_ptr<MetaGraph>
    buildMetaGraphWithPlan(const std::shared_ptr<MetaGraph> metaGraph,
                           const std::vector<int> &plan);
    // search horizontal merges
    std::vector<std::shared_ptr<MetaGraph>>
    searchMerge(std::shared_ptr<MetaGraph> &metaGraph);
    void searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
                        std::vector<int> &plan, std::vector<int> &frontier,
                        std::vector<std::vector<int>> &plans,
                        std::unordered_set<uint64_t> &planSet);
    std::vector<Graph>
    searchMutation(const std::shared_ptr<MetaGraph> &metaGraph);
    void printMetaGraph(Ref<SearchEngine::MetaGraph> metaGraph);
    /**
     * @brief Check whether a multi-branch graph can be merged into a single
     * branch.
     */
    bool isMultiBranchMergable(const Graph graph);
};
} // namespace infini

View File

@ -10,6 +10,8 @@ using Shape = vector<ShapeElem>;
class TensorObj : public TensorBaseObj {
private:
Shape shape;
Fuid fuid; // Cloned tensors share the same id. Tensors constructed from
// scratch have a new id.
public:
TensorObj(const Shape &shape, DataType dtype, Runtime runtime);
@ -25,6 +27,7 @@ class TensorObj : public TensorBaseObj {
using TensorBaseObj::getData;
VType getData(const Shape &pos) const;
void dataMalloc();
UidBaseType getFuid() const { return fuid; }
void load(std::string file_path);
void save(std::string file_path);
@ -51,10 +54,24 @@ class TensorObj : public TensorBaseObj {
}
generator(data->getPtr<void *>(), size(), dtype);
}
Tensor clone(Runtime runtime) {
auto obj = make_ref<TensorObj>(shape, dtype, runtime);
obj->dataMalloc();
obj->copyData(this);
/**
 * @brief Clone this tensor without its data and without graph connections.
 *
 * The copy is made with the copy constructor, so it keeps this tensor's FUID
 * ("cloned tensors share the same id"). Its data blob reference and its
 * producer/consumer links are then dropped.
 */
Tensor clone() const {
    auto copy = make_ref<TensorObj>(*this);
    copy->outputOf.reset();
    copy->inputOf.clear();
    copy->freeData();
    return copy;
}
// TODO: clarify whether clone copies data
/**
 * @brief Clone this tensor onto another runtime.
 *
 * The copy keeps this tensor's FUID but has no producer/consumer links. If
 * this tensor currently holds data, the copy gets freshly allocated storage
 * filled via copyData(); otherwise the copy is left without data.
 *
 * @param runtime Runtime assigned to the clone.
 */
Tensor clone(Runtime runtime) const {
    auto copy = make_ref<TensorObj>(*this);
    copy->runtime = runtime;
    copy->inputOf.clear();
    copy->outputOf.reset();
    // Drop the blob reference shared with the source before re-allocating.
    copy->freeData();
    if (!hasData())
        return copy;
    copy->dataMalloc();
    copy->copyData(this);
    return copy;
}

View File

@ -33,6 +33,8 @@ class TensorBaseObj : public Object {
data = blob;
}
Blob getDataBlob() const { return data; }
// Whether this tensor currently references a data blob.
bool hasData() const { return data != nullptr; }
// Drops this tensor's reference to its data blob. Whether the underlying
// memory is released depends on Blob's ownership semantics (not visible here).
void freeData() { data = nullptr; }
template <typename T> T getRawDataPtr() const {
static_assert(std::is_pointer_v<T>,
"Raw data pointer has a type of pointer");

View File

@ -47,6 +47,7 @@ class CudaRuntimeObj : public RuntimeObj {
/**
 * @brief Allocate size bytes with cudaMalloc.
 *
 * The cudaMalloc status is passed to checkCudaError.
 */
CudaPtr alloc(size_t size) override {
    void *devicePtr = nullptr;
    checkCudaError(cudaMalloc(&devicePtr, size));
    // Debug aid: printf("cuda malloc: %p %lu bytes\n", devicePtr, size);
    return devicePtr;
}
void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }

View File

@ -26,6 +26,7 @@ class G2BMMObj : public OperatorObj {
G2BMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int width,
const int dilation, Tensor bias = nullptr,
ActType act = ActType::None);
OP_CLONE(G2BMMObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -24,6 +24,7 @@ class GBMMObj : public OperatorObj {
*/
GBMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int dilation,
Tensor bias = nullptr, ActType act = ActType::None);
OP_CLONE(GBMMObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -10,6 +10,7 @@ class BatchNormObj : public OperatorObj {
BatchNormObj(GraphObj *graph, Tensor input, Tensor output, Tensor mean,
Tensor var, Tensor scale, Tensor bias, float momentum = 0.9,
float eps = 1e-5, bool training = false);
OP_CLONE(BatchNormObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -7,6 +7,7 @@ class ConcatObj : public OperatorObj {
public:
ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim);
OP_CLONE(ConcatObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -77,6 +77,7 @@ class ConvObj : public ConvBaseObj {
PaddingMode mode = PaddingMode::Same, int sh = 1, int sw = 1,
int dh = 1, int dw = 1, Tensor bias = nullptr,
ActType act = ActType::None);
OP_CLONE(ConvObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }
@ -104,6 +105,7 @@ class ConvTransposed2dObj : public ConvBaseObj {
int sh = 1, int sw = 1, int dh = 1, int dw = 1,
int oph = 0, int opw = 0, int group = 1,
Tensor bias = nullptr, ActType act = ActType::None);
OP_CLONE(ConvTransposed2dObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
ActType getAct() const { return act; }

View File

@ -23,6 +23,7 @@ class ElementWiseObj : public OperatorObj {
prefix##Obj(GraphObj *graph, Tensor input0, Tensor input1, \
Tensor output) \
: ElementWiseObj(type, graph, input0, input1, output) {} \
OP_CLONE(prefix##Obj); \
};
DEFINE_ELEMENT_WISE_OBJ(Add, OpType::Add)

View File

@ -8,6 +8,7 @@ class ExtendObj : public OperatorObj {
public:
ExtendObj(GraphObj *graph, Tensor input, Tensor output, int dim,
int num = 1);
OP_CLONE(ExtendObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -9,6 +9,7 @@ class GatherObj : public OperatorObj {
public:
GatherObj(GraphObj *graph, Tensor input, Tensor index, Tensor output,
int axis);
OP_CLONE(GatherObj);
std::string toString() const override;
int numInputs() const override { return 2; }
int numOutputs() const override { return 1; }

View File

@ -29,6 +29,7 @@ class MatmulObj : public OperatorObj {
MatmulObj(GraphObj *graph, Tensor A, Tensor B, Tensor C,
bool transA = false, bool transB = false, Tensor bias = nullptr,
ActType act = ActType::None);
OP_CLONE(MatmulObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -17,6 +17,7 @@ class MemBoundObj : public OperatorObj {
const TensorVec &output,
const std::vector<nnet::Tensor> &nnetInputs, nnet::Expr expr,
double exec_time, std::string hint = {});
OP_CLONE(MemBoundObj);
std::string toString() const override;
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -10,6 +10,7 @@ class PadObj : public OperatorObj {
// pad for appointed axises,if axis is empty,then pad for all axises.
PadObj(GraphObj *graph, Tensor input, Tensor output,
const vector<int> &pads, const optional<const vector<int>> &axis);
OP_CLONE(PadObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -14,6 +14,7 @@ class PoolingObj : public OperatorObj {
public:
PoolingObj(GraphObj *graph, OpType optype, Tensor input, Tensor output,
int kh, int kw, int dh, int dw, int ph, int pw, int sh, int sw);
OP_CLONE(PoolingObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -10,6 +10,7 @@ class ReduceMeanObj : public OperatorObj {
ReduceMeanObj(GraphObj *graph, Tensor input, Tensor output,
const optional<const vector<int>> &axis,
bool keepDims = true);
OP_CLONE(ReduceMeanObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -8,6 +8,7 @@ class ReshapeObj : public OperatorObj {
public:
ReshapeObj(GraphObj *graph, Tensor input, Tensor output, const Shape &dims);
OP_CLONE(ReshapeObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
@ -24,6 +25,7 @@ class FlattenObj : public OperatorObj {
public:
FlattenObj(GraphObj *graph, Tensor input, Tensor output);
OP_CLONE(FlattenObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
@ -40,6 +42,7 @@ class IdentityObj : public OperatorObj {
public:
IdentityObj(GraphObj *graph, Tensor input, Tensor output);
OP_CLONE(IdentityObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -51,6 +51,7 @@ class ResizeObj : public OperatorObj {
Tensor roi, ECoeffMode mode,
EKeepAspectRatioPolicy ratioPolicy = EKeepAspectRatioPolicy::none,
ECoordinateTransMode coordTransMode = ECoordinateTransMode::halfPixel);
OP_CLONE(ResizeObj);
// Operator clone(TensorVec inputs, TensorVec outputs) override;
vector<DataType> inferDataType(const TensorVec &inputs) const override;

View File

@ -10,6 +10,7 @@ class SliceObj : public OperatorObj {
const vector<int> &starts, const vector<int> &ends,
const optional<vector<int>> &axis,
const optional<vector<int>> &steps);
OP_CLONE(SliceObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
std::string toString() const override;

View File

@ -10,6 +10,7 @@ class SplitObj : public OperatorObj {
int dim, int num);
SplitObj(GraphObj *graph, Tensor input, std::optional<TensorVec> outputs,
int dim, const vector<int> &ratio);
OP_CLONE(SplitObj);
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

View File

@ -21,6 +21,7 @@ class UnaryObj : public OperatorObj {
public: \
prefix##Obj(GraphObj *graph, Tensor input, Tensor output) \
: UnaryObj(type, graph, input, output) {} \
OP_CLONE(prefix##Obj); \
};
DEFINE_UNARY_OBJ(Relu, OpType::Relu)

View File

@ -68,10 +68,9 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
func_code = func.imported_modules[0].get_source()
invoke_code = "%s_kernel0<<<dim3(%s), dim3(%s)>>>(%s, %s);" % (
func_name, ", ".join(map(str, block_dim)), ", ".join(
map(str, thread_dim)),
output_name, ", ".join(input_names))
map(str, thread_dim)), ", ".join(input_names), output_name)
invoke_params = block_dim + thread_dim
ctx = tvm.cuda(0)
input_a = []
for i, (shape, dtype) in enumerate(zip(input_tensors, input_dtypes)):
@ -91,4 +90,4 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
print("Time")
print(conv_time)
return func_code, invoke_code, conv_time, invoke_params # ms
return func_code, invoke_code, conv_time, invoke_params # ms

65
src/core/dummy_mutator.cc Normal file
View File

@ -0,0 +1,65 @@
#include "core/dummy_mutator.h"
#include "operators/concat.h"
#include "operators/conv.h"
#include "operators/matmul.h"
#include "operators/split.h"
#include "operators/unary.h"
namespace infini {
/**
 * @brief Dummy mutation used for testing: Conv -> Conv + Relu.
 *
 * For a graph holding exactly one Conv operator, returns the original graph
 * plus a new graph in which the Conv is followed by a Relu writing to a clone
 * of the original output tensor. Any other graph (empty, more than one
 * operator, or a single non-Conv operator) is returned unchanged as the only
 * candidate.
 *
 * @param inGraph Graph to mutate.
 * @return Candidate graphs; inGraph is always the first element.
 */
vector<Graph> DummyMutator::run(const Graph &inGraph) {
    const auto &ops = inGraph->getOperators();
    // Also rejects the empty graph, which would otherwise index ops[0]
    // out of range.
    if (ops.size() != 1)
        return {inGraph};
    // Conv -> Conv + Relu
    auto op0 = as<ConvObj>(ops[0]);
    // The single operator may not be a Conv; nothing to mutate in that case.
    if (!op0)
        return {inGraph};
    auto g = make_ref<GraphObj>(runtime);
    auto a0 = g->cloneTensor(op0->getInputs()[0]),
         w0 = g->cloneTensor(op0->getInputs()[1]),
         o0 = g->cloneTensor(op0->getOutput());
    auto [ph, pw, sh, sw, dh, dw] = op0->getPadStrideDilation();
    auto t =
        g->addOp<ConvObj>(a0, w0, nullptr, ph, pw, sh, sw, dh, dw)->getOutput();
    g->addOpWithOutputs<ReluObj>(t, o0);
    return {inGraph, g};
}
/**
 * @brief Dummy merge used for testing: two Matmuls of the same shape are
 * replaced by Concat + one Matmul + Split.
 *
 * @param inGraph Graph accepted by isMultiBranchMergable().
 * @return The merged graph, or an empty vector when inGraph is not mergable.
 */
vector<Graph> DummyMutator::mergeMultiBranch(const Graph &inGraph) {
    // Two Matmuls of the same shape -> one batched Matmul
    if (!isMultiBranchMergable(inGraph))
        return {};
    auto first = as<MatmulObj>(inGraph->getOperators()[0]);
    auto second = as<MatmulObj>(inGraph->getOperators()[1]);
    auto [b, m, n, k, transA, transB] = first->getBMNKTransAB();
    auto merged = make_ref<GraphObj>(runtime);

    // Tensors are cloned branch by branch: input, weight, output.
    auto a0 = merged->cloneTensor(first->getInputs()[0]);
    auto w0 = merged->cloneTensor(first->getInputs()[1]);
    auto o0 = merged->cloneTensor(first->getOutput());
    auto a1 = merged->cloneTensor(second->getInputs()[0]);
    auto w1 = merged->cloneTensor(second->getInputs()[1]);
    auto o1 = merged->cloneTensor(second->getOutput());

    // Stack both branches along dim 0, multiply once, then split back.
    auto stackedA =
        merged->addOp<ConcatObj>(TensorVec{a0, a1}, nullptr, 0)->getOutput();
    auto stackedW =
        merged->addOp<ConcatObj>(TensorVec{w0, w1}, nullptr, 0)->getOutput();
    auto matmul =
        merged->addOp<MatmulObj>(stackedA, stackedW, nullptr, transA, transB);
    merged->addOpWithOutputs<SplitObj>(matmul->getOutput(), TensorVec{o0, o1},
                                       0, 2);
    return {merged};
}
/**
 * @brief Accepts graphs made of exactly two unconnected Matmul operators
 * whose getBMNKTransAB() results compare equal.
 */
bool DummyMutator::isMultiBranchMergable(const Graph &inGraph) {
    const auto &ops = inGraph->getOperators();
    if (ops.size() != 2)
        return false;
    for (auto op : ops) {
        const bool isolatedMatmul = op->getOpType() == OpType::Matmul &&
                                    op->getPredecessors().size() == 0 &&
                                    op->getSuccessors().size() == 0;
        if (!isolatedMatmul)
            return false;
    }
    auto lhsArgs = as<MatmulObj>(ops[0])->getBMNKTransAB();
    auto rhsArgs = as<MatmulObj>(ops[1])->getBMNKTransAB();
    return lhsArgs == rhsArgs;
}
} // namespace infini

View File

@ -1,7 +1,32 @@
#include "core/graph.h"
#include <queue>
namespace infini {
/**
 * @brief Build a graph from existing operators by cloning them.
 *
 * Every tensor referenced by ops_in is cloned once per FUID via
 * Tensor::clone() and registered in FUID order. Each operator is then cloned
 * onto the new tensors and connected in the order given by ops_in.
 *
 * @param runtime Runtime of the new graph; addTensor() asserts that every
 *                cloned tensor belongs to it.
 * @param ops_in Operators to clone into the graph.
 */
GraphObj::GraphObj(Runtime runtime, OpVec ops_in) : runtime(runtime) {
    map<UidBaseType, Tensor> clonedByFuid;

    // Clone tensors, one clone per FUID.
    const auto cloneMissing = [&clonedByFuid](const TensorVec &originals) {
        for (const auto &original : originals) {
            const auto fuid = original->getFuid();
            if (clonedByFuid.count(fuid) == 0)
                clonedByFuid[fuid] = original->clone();
        }
    };
    for (const auto &op : ops_in) {
        cloneMissing(op->getInputs());
        cloneMissing(op->getOutputs());
    }
    for (const auto &entry : clonedByFuid)
        addTensor(entry.second);

    // Clone operators and add connections.
    const auto remap = [&clonedByFuid](const TensorVec &originals) {
        TensorVec mapped;
        for (const auto &original : originals)
            mapped.emplace_back(clonedByFuid.at(original->getFuid()));
        return mapped;
    };
    for (const auto &op : ops_in) {
        TensorVec newInputs = remap(op->getInputs());
        TensorVec newOutputs = remap(op->getOutputs());
        addOperatorAndConnect(op->clone(newInputs, newOutputs));
    }
}
void GraphObj::addOperatorAndConnect(const Operator &op) {
ops.push_back(op);
for (auto &input : op->getInputs()) {
@ -28,7 +53,7 @@ string GraphObj::toString() const {
oss << "Graph operators:\n";
for (const auto &op : ops) {
vector<GuidBaseType> preds, succs;
vector<UidBaseType> preds, succs;
for (auto &o : op->getPredecessors())
preds.emplace_back(o->getGuid());
for (auto &o : op->getSuccessors())
@ -53,6 +78,18 @@ Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
return tensor;
}
/**
 * @brief Register an existing tensor with this graph.
 *
 * Asserts that the tensor lives on the graph's runtime, then appends it to
 * the tensor list.
 *
 * @return The same tensor that was passed in.
 */
Tensor GraphObj::addTensor(const Tensor &tensor) {
    IT_ASSERT(tensor->getRuntime() == runtime, "Tensor runtime mismatch");
    tensors.push_back(tensor);
    return tensor;
}
/**
 * @brief Register several existing tensors with this graph, in order.
 *
 * Note that the parameter shadows the member of the same name; registration
 * goes through the single-tensor overload.
 *
 * @return A copy of the tensors that were passed in.
 */
TensorVec GraphObj::addTensor(const TensorVec &tensors) {
    for (size_t i = 0; i < tensors.size(); ++i)
        addTensor(tensors[i]);
    return tensors;
}
OpVec GraphObj::getComputeOps() const {
OpVec opList;
for (auto op : ops)

View File

@ -2,6 +2,7 @@
#include "core/blob.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
#include "utils/data_generator.h"
#include <chrono>
#include <cstring>
namespace infini {
@ -73,8 +74,27 @@ double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
PerfRecord record;
// Tune the kernel if there is no record
if (!perfData) {
// TODO: should tenosrs automatically allocate when access data?
// allocate memory for empty tensors and release it after profiling
TensorVec allocatedTensors;
for (auto t : op->getInputs())
if (!t->hasData())
allocatedTensors.emplace_back(t);
for (auto t : op->getOutputs())
if (!t->hasData())
allocatedTensors.emplace_back(t);
for (auto t : allocatedTensors) {
t->dataMalloc();
t->setData(IncrementalGenerator());
}
// Profile operators and record the results
record = kernel->tune(op, this);
perfEngine.setPerfData(perfKey, record);
// Free allocated memory
for (auto t : allocatedTensors)
t->freeData();
} else
record = perfData;

441
src/core/search_engine.cc Normal file
View File

@ -0,0 +1,441 @@
#include "core/search_engine.h"
#include "core/hash.h"
#include "core/runtime.h"
#include <algorithm>
#include <iostream>
#include <unordered_set>
namespace infini {
/**
 * @brief Dump every node of a meta graph to stdout: index, contained graph,
 * type flag, and predecessor/successor indices.
 */
void SearchEngine::printMetaGraph(Ref<SearchEngine::MetaGraph> metaGraph) {
    const auto printIds = [](const char *label, const std::vector<int> &ids) {
        std::cout << label;
        for (int id : ids)
            std::cout << id << " ";
        std::cout << std::endl;
    };

    size_t index = 0;
    for (auto &node : metaGraph->nodes) {
        std::cout << "id: " << index << std::endl;
        ++index;
        node.graph->print();
        std::cout << "type: " << node.type << std::endl;
        printIds("pre: ", node.pre);
        printIds("suc: ", node.suc);
    }
    std::cout << std::endl;
}
/**
 * @brief Entry point of the search engine.
 *
 * Partitions the graph, searches each partition for candidates, and combines
 * them with the best graphs found so far. At most GRAPH_SIZE combined graphs,
 * ordered by perf time, are kept after each partition.
 *
 * @param graph Graph to optimise; must live on the engine's runtime.
 * @return The combined graph with the lowest perf time.
 */
Graph SearchEngine::run(const Graph graph) {
    IT_ASSERT(runtimeExec == graph->getRuntime());
    std::cout << "[INFO] original graph: " << std::endl;
    std::cout << graph->toString();
    std::cout << "[INFO] perf: " << runtimeExec->getPerfTime(graph)
              << std::endl;

    std::vector<Graph> partitions = partitionGraph(graph);
    std::cout << "[INFO] Partition num: " << partitions.size() << std::endl;

    // Appends the operators of src (if any) to dst.
    const auto appendOps = [](std::vector<Operator> &dst, const Graph &src) {
        if (src == nullptr)
            return;
        for (auto op : src->getOperators())
            dst.emplace_back(op);
    };
    const auto fasterThan = [&](Graph lhs, Graph rhs) {
        return runtimeExec->getPerfTime(lhs) < runtimeExec->getPerfTime(rhs);
    };

    // The null entry stands for "nothing chosen yet" before the first
    // partition is processed.
    std::vector<Graph> bestGraphs = {nullptr};
    size_t pid = 0;
    for (auto &subGraph : partitions) {
        std::cout << "[INFO] Partition: " << pid << std::endl;
        ++pid;
        std::vector<Graph> candidates = search(subGraph);
        std::cout << "[INFO] size: " << candidates.size() << std::endl;
        IT_ASSERT(candidates.size() > 0);
        std::cout << subGraph->toString() << std::endl;

        // Pair every previous best graph with every new candidate.
        std::vector<Graph> nextGraphs;
        for (auto lastGraph : bestGraphs) {
            for (auto thisGraph : candidates) {
                std::vector<Operator> ops;
                appendOps(ops, lastGraph);
                appendOps(ops, thisGraph);
                auto combined = make_ref<GraphObj>(runtimeExec, ops);
                combined->dataMalloc();
                nextGraphs.emplace_back(combined);
            }
        }
        std::sort(nextGraphs.begin(), nextGraphs.end(), fasterThan);
        if (nextGraphs.size() > GRAPH_SIZE) {
            nextGraphs.resize(GRAPH_SIZE);
        }
        bestGraphs = nextGraphs;
    }

    std::cout << "[INFO] unfused graph: " << std::endl;
    size_t rank = 0;
    for (auto &best : bestGraphs) {
        std::cout << "bestGraph " << rank << ":" << std::endl;
        ++rank;
        std::cout << best->toString();
        std::cout << "[INFO] perf: " << runtimeExec->getPerfTime(best)
                  << std::endl;
    }
    return bestGraphs[0];
}
std::vector<Graph> SearchEngine::search(const Graph &graph) {
auto metaGraph = buildMetaGraphWithGraph(graph);
auto mergedGraphs = searchMerge(metaGraph);
std::cout << "[INFO] merged graphs: " << mergedGraphs.size() << std::endl;
std::vector<Graph> results;
for (auto mergedGraph : mergedGraphs) {
auto mutatedGraphs = searchMutation(mergedGraph);
for (size_t i = 0; i < std::min(mutatedGraphs.size(), GRAPH_SIZE);
i++) {
results.emplace_back(mutatedGraphs[i]);
}
}
sort(results.begin(), results.end(), [&](Graph x, Graph y) {
return runtimeExec->getPerfTime(x) < runtimeExec->getPerfTime(y);
}); // compare with perf time
if (results.size() > GRAPH_SIZE) {
results.resize(GRAPH_SIZE);
}
return results;
}
// Build metagraph with a graph, each operator is a node.
std::shared_ptr<SearchEngine::MetaGraph>
SearchEngine::buildMetaGraphWithGraph(const Graph graph) {
auto metaGraph = std::make_shared<MetaGraph>();
int numOps = graph->getOperators().size();
std::vector<int> cnt(numOps, 0);
std::unordered_map<int, int> opMap;
metaGraph->nodes.clear();
std::vector<int> q(0);
for (size_t i = 0; i < graph->getOperators().size(); i++) {
auto &op = graph->getOperators()[i];
MetaGraph::Node node;
std::vector<Operator> ops;
ops.emplace_back(op);
node.graph = make_ref<GraphObj>(runtimeExec, ops);
node.type = op->isComputeOp();
node.cnt = op->getPredecessors().size();
opMap.emplace(op->getGuid(), i);
metaGraph->nodes.emplace_back(node);
}
for (size_t i = 0; i < graph->getOperators().size(); i++) {
auto &op = graph->getOperators()[i];
std::unordered_set<int> set;
set.clear();
set.emplace(i);
for (auto preOp : op->getPredecessors()) {
int id = opMap[preOp->getGuid()];
if (set.find(id) == set.end()) {
metaGraph->nodes[i].pre.emplace_back(id);
set.emplace(id);
}
}
for (auto sucOp : op->getSuccessors()) {
int id = opMap[sucOp->getGuid()];
if (set.find(id) == set.end()) {
metaGraph->nodes[i].suc.emplace_back(id);
set.emplace(id);
}
}
}
return metaGraph;
}
// Build a metagraph from an existing metagraph and a plan. plan[i] is the id
// of the group that node i of the input metagraph belongs to; every group
// becomes exactly one node of the result, so result node g holds the
// operators of all members of group g and its pre/suc lists contain group
// ids.
//
// Previously a Node was created and pushed once per group member, so a group
// with k members produced k nodes and node indices no longer matched the
// group ids stored in pre/suc.
std::shared_ptr<SearchEngine::MetaGraph> SearchEngine::buildMetaGraphWithPlan(
    const std::shared_ptr<SearchEngine::MetaGraph> metaGraph,
    const std::vector<int> &plan) {
    int numGroups = 0;
    for (auto i : plan) {
        if (i > numGroups) {
            numGroups = i;
        }
    }

    // groups[g] lists the input-node indices assigned to group g.
    std::vector<std::vector<int>> groups(numGroups + 1, std::vector<int>(0));
    for (size_t i = 0; i < plan.size(); i++) {
        groups[plan[i]].emplace_back(i);
    }

    auto resultMetaGraph = make_ref<MetaGraph>();
    for (auto &group : groups) {
        // Group ids must be dense, otherwise result indices and group ids
        // would diverge (and ops[0] below would be out of range).
        IT_ASSERT(!group.empty());
        std::vector<Operator> ops;
        std::unordered_set<int> preSet, sucSet;
        MetaGraph::Node node;
        for (auto id : group) {
            for (auto op : metaGraph->nodes[id].graph->getOperators()) {
                ops.emplace_back(op);
            }
            for (auto suc : metaGraph->nodes[id].suc) {
                if (sucSet.find(plan[suc]) == sucSet.end()) {
                    node.suc.emplace_back(plan[suc]);
                    sucSet.emplace(plan[suc]);
                }
            }
            for (auto pre : metaGraph->nodes[id].pre) {
                // A group must not be both predecessor and successor.
                IT_ASSERT(sucSet.find(plan[pre]) == sucSet.end());
                if (preSet.find(plan[pre]) == preSet.end()) {
                    node.pre.emplace_back(plan[pre]);
                    preSet.emplace(plan[pre]);
                }
            }
        }
        node.graph = make_ref<GraphObj>(runtimeExec, ops);
        node.cnt = node.pre.size();
        node.type = ops[0]->isComputeOp();
        resultMetaGraph->nodes.emplace_back(node);
    }
    return resultMetaGraph;
}
// Search how to merge multiple ops: enumerate merge plans with
// searchMergeDfs and build one metagraph per distinct plan.
std::vector<std::shared_ptr<SearchEngine::MetaGraph>>
SearchEngine::searchMerge(std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
    IT_ASSERT(metaGraph != nullptr);

    // Start with every node in its own group; the initial frontier holds the
    // nodes that have no pending predecessors.
    const size_t numNodes = metaGraph->nodes.size();
    std::vector<int> plan(numNodes);
    std::vector<int> frontier;
    for (size_t i = 0; i < numNodes; i++) {
        plan[i] = i;
        if (metaGraph->nodes[i].cnt == 0) {
            frontier.emplace_back(i);
        }
    }

    std::vector<std::vector<int>> plans;
    std::unordered_set<HashType> planSet;
    searchMergeDfs(metaGraph, plan, frontier, plans, planSet);

    std::vector<std::shared_ptr<SearchEngine::MetaGraph>> metaGraphs;
    for (auto &curPlan : plans) {
        metaGraphs.emplace_back(buildMetaGraphWithPlan(metaGraph, curPlan));
    }
    return metaGraphs;
}
// DFS impl for search merge.
void SearchEngine::searchMergeDfs(std::shared_ptr<MetaGraph> &metaGraph,
std::vector<int> &plan,
std::vector<int> &frontier,
std::vector<std::vector<int>> &plans,
std::unordered_set<uint64_t> &planSet) {
if (frontier.size() == 0) {
// remark id
std::unordered_map<int, int> id_map;
int cnt = 0;
for (size_t i = 0; i < plan.size(); i++) {
if (id_map.find(plan[i]) == id_map.end()) {
id_map.emplace(plan[i], cnt++);
}
plan[i] = id_map[plan[i]];
}
auto hash = hashVector(plan);
if (planSet.find(hash) != planSet.end()) {
return;
}
planSet.emplace(hash);
plans.emplace_back(plan);
return;
}
int numNonCompute = 0;
for (auto x : frontier) {
if (metaGraph->nodes[x].type == 0) {
numNonCompute++;
}
}
auto planBackup = plan;
auto metaGraphBackup = metaGraph;
// DFS non compute ops.
if (numNonCompute > 0) {
std::vector<int> nextFrontier;
for (auto x : frontier) {
if (metaGraph->nodes[x].type == 0) {
for (auto y : metaGraph->nodes[x].suc) {
metaGraph->nodes[y].cnt--;
if (metaGraph->nodes[y].cnt == 0) {
nextFrontier.emplace_back(y);
}
}
} else {
nextFrontier.emplace_back(x);
}
}
searchMergeDfs(metaGraph, plan, nextFrontier, plans, planSet);
metaGraph = metaGraphBackup;
return;
}
// DFS compute ops.
for (int mask = (1 << frontier.size()) - 1; mask > 0; mask--) {
int mergedId = -1;
std::vector<int> nextFrontier;
std::vector<Operator> ops;
for (size_t i = 0; i < frontier.size(); i++) {
if ((1 << i) & mask) {
if (mergedId == -1) {
mergedId = plan[frontier[i]];
} else {
plan[frontier[i]] = mergedId;
}
for (auto y : metaGraph->nodes[frontier[i]].suc) {
metaGraph->nodes[y].cnt--;
if (metaGraph->nodes[y].cnt == 0) {
nextFrontier.emplace_back(y);
}
}
for (auto op :
metaGraph->nodes[frontier[i]].graph->getOperators()) {
ops.emplace_back(op);
}
} else {
nextFrontier.emplace_back(frontier[i]);
}
}
auto graph = make_ref<GraphObj>(runtimeExec, ops);
if (ops.size() == 1 || isMultiBranchMergable(graph)) {
searchMergeDfs(metaGraph, plan, nextFrontier, plans, planSet);
}
plan = planBackup;
metaGraph = metaGraphBackup;
}
return;
}
// Search mutation for each compute op.
//
// Nodes are visited in order. Every graph kept so far is extended by each
// alternative of the current node: the mutator's results for compute nodes,
// the node's own graph otherwise. After each node the extended graphs are
// allocated, sorted by perf time and truncated to GRAPH_SIZE.
std::vector<Graph> SearchEngine::searchMutation(
    const std::shared_ptr<SearchEngine::MetaGraph> &metaGraph) {
    // Concatenates the operators of prefix (may be null) and suffix.
    const auto joinOps = [](const Graph &prefix, const Graph &suffix) {
        std::vector<Operator> ops;
        if (prefix != nullptr) {
            for (auto op : prefix->getOperators())
                ops.emplace_back(op);
        }
        for (auto op : suffix->getOperators())
            ops.emplace_back(op);
        return ops;
    };
    const auto fasterThan = [&](Graph lhs, Graph rhs) {
        return runtimeExec->getPerfTime(lhs) < runtimeExec->getPerfTime(rhs);
    };

    // The null entry stands for the empty prefix before the first node.
    std::vector<Graph> graphs = {nullptr};
    for (auto &node : metaGraph->nodes) {
        // Alternatives for this node.
        std::vector<Graph> alternatives;
        if (node.type == 1) { // If it has computing OPs
            alternatives = mutator->run(node.graph);
        } else {
            alternatives.emplace_back(node.graph);
        }

        // Append the node to all existing candidates.
        std::vector<Graph> nextGraphs;
        for (auto graph : graphs) {
            for (auto alternative : alternatives) {
                std::vector<Operator> ops = joinOps(graph, alternative);
                nextGraphs.emplace_back(make_ref<GraphObj>(runtimeExec, ops));
            }
        }

        for (auto g : nextGraphs) {
            g->dataMalloc();
        }
        std::sort(nextGraphs.begin(), nextGraphs.end(), fasterThan);
        if (nextGraphs.size() > GRAPH_SIZE) {
            nextGraphs.resize(GRAPH_SIZE);
        }
        graphs = nextGraphs;
    }
    return graphs;
}
// Forwards the question "can the branches of `graph` be merged?" to
// mutationEngine and returns its answer unchanged.
bool SearchEngine::isMultiBranchMergable(const Graph graph) {
return mutationEngine->isMultiBranchMergable(graph);
}
// Split a graph into multiple independent graphs. Search engine will search
// for each one.
//
// The operators of `graph` are first ordered by a DFS over getSuccessors()
// (post-order, so an operator is listed after everything reachable from it).
// The list is then scanned once; operators are collected into the current
// partition, and the partition is closed right after an operator that
//   * is not a compute op,
//   * has at least partitionThreshold predecessors + successors, and
//   * passes the pre/post-order check in crossesCut() below.
// Operators left over at the end form the last partition. Every partition gets
// its data allocated, and the list is reversed before it is returned.
std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
    std::vector<Graph> partitions;
    // Reversed DFS post-order is topo-order.
    std::unordered_map<size_t, size_t> preOrder, postOrder;
    std::vector<Operator> ops;
    int preCnt = 0, postCnt = 0;
    std::function<void(Operator)> dfs = [&](Operator op) {
        if (preOrder.count(op->getGuid())) {
            return; // already visited
        }
        preOrder[op->getGuid()] = preCnt++;
        for (auto &&next : op->getSuccessors()) {
            dfs(next);
        }
        postOrder[op->getGuid()] = postCnt++;
        ops.emplace_back(op);
    };
    for (auto &&op : graph->getOperators()) {
        dfs(op);
    }

    // Returns true when cutting after ops[i] has to be rejected: some operator
    // that was both discovered and finished before ops[i] (labelled "true
    // predecessor" by the original code) has a successor that also finished
    // before ops[i].
    auto crossesCut = [&](size_t i) {
        const auto preOrderI = preOrder[ops[i]->getGuid()];
        const auto postOrderI = postOrder[ops[i]->getGuid()];
        for (size_t j = 0; j < i; j++) {
            // True predecessor
            if (preOrder[ops[j]->getGuid()] < preOrderI) {
                for (auto nextOp : ops[j]->getSuccessors()) {
                    if (postOrder[nextOp->getGuid()] < postOrderI) {
                        return true;
                    }
                }
            }
        }
        return false;
    };

    std::vector<Operator> headOps;
    // Turns the operators collected so far into a partition.
    auto emitPartition = [&]() {
        auto part = make_ref<GraphObj>(runtimeExec, headOps);
        part->dataMalloc();
        partitions.emplace_back(part);
    };

    for (size_t i = 0; i < ops.size(); i++) {
        auto &op = ops[i];
        headOps.emplace_back(op);
        const bool isCutCandidate =
            op->getPredecessors().size() + op->getSuccessors().size() >=
                (size_t)partitionThreshold &&
            !op->isComputeOp();
        if (!isCutCandidate || crossesCut(i)) {
            continue; // keep collecting into the current partition
        }
        std::cout << "partition!!!: " << i << std::endl;
        for (auto headOp : headOps) {
            std::cout << headOp->toString() << std::endl;
        }
        emitPartition();
        headOps.clear();
    }
    if (!headOps.empty()) {
        emitPartition();
    }
    std::reverse(partitions.begin(), partitions.end());
    return partitions;
}
} // namespace infini

View File

@ -14,9 +14,10 @@ VType TensorObj::getData(const Shape &pos) const {
}
string TensorObj::toString() const {
string ret = "Tensor " + std::to_string(guid) + ", shape " +
vecToString(shape) + ", dtype " + dtype.toString();
vector<GuidBaseType> inputOfGuid;
string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
std::to_string(fuid) + ", shape " + vecToString(shape) +
", dtype " + dtype.toString();
vector<UidBaseType> inputOfGuid;
for (const auto &op : inputOf)
inputOfGuid.emplace_back(op.lock()->getGuid());
if (auto o = outputOf.lock())

View File

@ -69,7 +69,7 @@ void CudaRuntimeObj::run(const Graph &graph, bool runTune,
sync();
}
void CudaRuntimeObj::sync() const { cudaDeviceSynchronize(); }
// Calls cudaDeviceSynchronize() and hands its return value to checkCudaError.
void CudaRuntimeObj::sync() const { checkCudaError(cudaDeviceSynchronize()); }
// Returns the fixed display name of this runtime, "CUDA Runtime".
string CudaRuntimeObj::toString() const { return "CUDA Runtime"; }

View File

@ -127,19 +127,16 @@ class MemboundTVM : public Kernel {
cuModuleLoadDataEx(&module, ret->ptx.data(), 0, nullptr, nullptr));
checkCUresult(cuModuleGetFunction(&kernel, module, kernelName.c_str()));
std::vector<void *> args;
for (auto &&in : op->getInputs()) {
for (auto &&in : op->getInputs())
args.push_back(in->getRawDataPtr<void *>());
}
args.push_back(op->getOutput()->getRawDataPtr<void *>());
std::vector<void *> argsPtr;
for (auto &arg : args) {
for (auto &arg : args)
argsPtr.push_back(&arg);
}
// Evaluate the kernel
ret->time = timeit(
[&]() {
// TODO: run the kernel
cuLaunchKernel(kernel, invokeParams[0], invokeParams[1],
invokeParams[2], invokeParams[3],
invokeParams[4], invokeParams[5], 0, NULL,

View File

@ -22,18 +22,12 @@ NMutator::~NMutator() {}
// Switches this mutator into Mode::ToNaiveMembound.
void NMutator::setToNaiveMembound() { mode = Mode::ToNaiveMembound; }
vector<Graph> NMutator::run(const Graph &in_graph) {
vector<Graph> out_graphs;
vector<Graph> out_graphs{in_graph};
// Test helper: naively transform one Op to Membound
if (mode == Mode::ToNaiveMembound) {
runSingleOpToNaiveMembound(in_graph, out_graphs);
dbg(out_graphs.size());
return out_graphs;
}
// // Hack for HetConv fusion
// if (statGraph(in_graph) == NMutator::SGType::HetConv) {
// dbg("Start fuse HetConv");
// out_graphs.emplace_back(fuseHetConv(nullptr, in_graph));
// }
// Clear input names maps with tensor
inputsNameNToTensorT.clear();
OpVec computeOps = in_graph->getComputeOps();
@ -51,15 +45,25 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
OpVec computeOps = in_graph->getComputeOps();
assert(computeOps.size() == 1);
const auto &computeOp = computeOps[0];
auto g = infini::make_ref<GraphObj>(CpuRuntimeObj::getInstance());
auto g = infini::make_ref<GraphObj>(in_graph->getRuntime());
auto expr = opToExpression(computeOp);
auto inputsN = nnet::GetTensorsVisitor().get(expr);
dbg(inputsN);
dbg(expr);
// FIXME: tensors should be copied?
g->addOpWithOutputs<MemBoundObj>(
computeOp->getInputs(), computeOp->getOutputs(),
vector<nnet::Tensor>{inputsN.at("A"), inputsN.at("B")}, expr, 0.);
dbg(inputsN, expr);
IT_ASSERT(inputsN.count("B") + inputsN.count("K") == 1,
"Which one is the second input tensor?");
vector<nnet::Tensor> inputsVectorN = {inputsN.at("A")};
if (inputsN.count("B"))
inputsVectorN.emplace_back(inputsN["B"]);
else
inputsVectorN.emplace_back(inputsN["K"]);
// clone IF inputs and outputs into the new graph
TensorVec inputsT, outputsT;
for (auto t : computeOp->getInputs())
inputsT.emplace_back(g->cloneTensor(t));
for (auto t : computeOp->getOutputs())
outputsT.emplace_back(g->cloneTensor(t));
g->addOpWithOutputs<MemBoundObj>(inputsT, outputsT, inputsVectorN, expr,
0.);
g->print();
out_graphs.emplace_back(g);
}
@ -226,62 +230,62 @@ void NMutator::runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs) {
nnet::Expr NMutator::opToExpression(Operator op) {
// IT_TODO_HALT();
// if (auto convOp = dynamic_cast<ConvOp *>(op)) {
// const auto &inputs = convOp->getInputs();
// const auto &AT = inputs[0];
// const auto &KT = inputs[1];
// const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi, ac]
// =
// convOp->getArgs(0);
// dbg(n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw);
// if (!(sh == 1 && sw == 1 && dh == 1 && dw == 1))
// return nullptr;
// assert(sh == 1 && sw == 1 && dh == 1 && dw == 1);
// inputsNameNToTensorT["A"] = AT;
// inputsNameNToTensorT["K"] = KT;
// const auto A = nnet::makeTensor("A", AT->getDims(),
// std::vector<int>{0, 0, ph, pw});
// const auto K = nnet::makeTensor("K", KT->getDims());
// return nnet::ConvPattern::getExpr(A, K, n, c, h, w, f, r, s);
// } else if (auto convOp = dynamic_cast<ConvTransOp *>(op)) {
// const auto &AT = convOp->getInputs()[0];
// const auto &KT = convOp->getInputs()[1];
// inputsNameNToTensorT["A"] = AT;
// inputsNameNToTensorT["K"] = KT;
// const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi, ac]
// =
// convOp->getArgs(0);
// if (r != 4) {
// dbg("ConvTranspose R!=4. Skipped.", r);
// return nullptr;
// }
// int padding = 1 * (r - 1) - 1;
// const auto A = nnet::makeTensor(
// "A", AT->getDims(), std::vector<int>{0, padding, padding, 0});
// const auto K = nnet::makeTensor("K", KT->getDims());
// return nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s);
// } else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(op)) {
// const auto &AT = g2bmmOp->getInputs()[0];
// const auto &BT = g2bmmOp->getInputs()[1];
// const auto [b, m, k, width, dilation] = g2bmmOp->getArgs();
if (auto convOp = as<ConvObj>(op)) {
const auto &inputs = convOp->getInputs();
const auto &AT = inputs[0];
const auto &KT = inputs[1];
const auto &[n, c, h, w, f, r, s] = convOp->getNCHWFRS();
const auto &[ph, pw, sh, sw, dh, dw] = convOp->getPadStrideDilation();
if (!(sh == 1 && sw == 1 && dh == 1 && dw == 1))
return nullptr;
assert(sh == 1 && sw == 1 && dh == 1 && dw == 1);
inputsNameNToTensorT["A"] = AT;
inputsNameNToTensorT["K"] = KT;
const auto A = nnet::makeTensor("A", AT->getDims(),
std::vector<int>{0, 0, ph, pw});
const auto K = nnet::makeTensor("K", KT->getDims());
return nnet::ConvPattern::getExpr(A, K, n, c, h, w, f, r, s);
// } else if (auto convOp = dynamic_cast<ConvTransOp *>(op)) {
// const auto &AT = convOp->getInputs()[0];
// const auto &KT = convOp->getInputs()[1];
// inputsNameNToTensorT["A"] = AT;
// inputsNameNToTensorT["K"] = KT;
// const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, g, bi,
// ac]
// =
// convOp->getArgs(0);
// if (r != 4) {
// dbg("ConvTranspose R!=4. Skipped.", r);
// return nullptr;
// }
// int padding = 1 * (r - 1) - 1;
// const auto A = nnet::makeTensor(
// "A", AT->getDims(), std::vector<int>{0, padding, padding,
// 0});
// const auto K = nnet::makeTensor("K", KT->getDims());
// return nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r,
// s);
// } else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(op)) {
// const auto &AT = g2bmmOp->getInputs()[0];
// const auto &BT = g2bmmOp->getInputs()[1];
// const auto [b, m, k, width, dilation] = g2bmmOp->getArgs();
// const auto &[expr, inputsN] =
// nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// return expr;
// } else if (auto gbmmlOp = dynamic_cast<GBMMLOp *>(op)) {
// const auto &AT = gbmmlOp->getInputs()[0];
// const auto &BT = gbmmlOp->getInputs()[1];
// const auto [b, m, w, k, dilation] = gbmmlOp->getArgs();
// const auto &[expr, inputsN] =
// nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// dbg(b, m, w, k, dilation, expr);
// return expr;
// } else
if (auto matmulOp = as<MatmulObj>(op)) {
// const auto &[expr, inputsN] =
// nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// return expr;
// } else if (auto gbmmlOp = dynamic_cast<GBMMLOp *>(op)) {
// const auto &AT = gbmmlOp->getInputs()[0];
// const auto &BT = gbmmlOp->getInputs()[1];
// const auto [b, m, w, k, dilation] = gbmmlOp->getArgs();
// const auto &[expr, inputsN] =
// nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// dbg(b, m, w, k, dilation, expr);
// return expr;
} else if (auto matmulOp = as<MatmulObj>(op)) {
const auto &AT = matmulOp->getInputs()[0];
const auto &BT = matmulOp->getInputs()[1];
const auto [b, m, n, k, transA, transB] = matmulOp->getBMNKTransAB();

View File

@ -2,6 +2,7 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "operators/matmul.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
@ -57,4 +58,55 @@ TEST(Graph, perf_engine) {
EXPECT_TRUE(matmul->getOutput()->equalData(ans));
}
// Checks the identifiers of a cloned tensor: after i0->clone() is added to the
// graph, the new tensor must have a different guid, the same fuid, and no data
// blob, while the original keeps its blob.
TEST(Graph, test_tensor_id) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
// Allocate before cloning so that i0 already owns a data blob.
g->dataMalloc();
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
// i1 is a clone of i0 registered as a separate tensor of the same graph.
auto i1 = g->addTensor(i0->clone());
// The returned operator is not inspected; the op only populates the graph.
auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
g->print();
EXPECT_NE(i0->getGuid(), i1->getGuid());
EXPECT_EQ(i0->getFuid(), i1->getFuid());
EXPECT_NE(i0->getDataBlob(), nullptr);
EXPECT_EQ(i1->getDataBlob(), nullptr);
}
// Builds a graph in which two tensors (o0 and its clone o1) share one fuid,
// then constructs a second graph from the operator list alone and checks that
// the second graph contains the expected tensors and connections.
TEST(Graph, test_OpVec_ctor) {
    Runtime runtime = CpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
    Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
    Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
    g->dataMalloc();
    i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
    // o1 shares the fuid of o0; Relu reads o1 while Matmul writes o0.
    auto o1 = g->addTensor(o0->clone());
    g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
    g->addOp<ReluObj>(o1, nullptr);
    g->print();
    puts("=========");
    OpVec ops = g->getOperators();
    Graph g2 = make_ref<GraphObj>(runtime, ops);
    g2->print();
    // Check if the two tensors with the same FUID (o0,o1) remain only one in g2
    EXPECT_EQ(g2->getTensors().size(), 4u);
    EXPECT_EQ(g2->getOperators().size(), 2u);
    // Expected number of tensors per {number of consumers, has a producer}.
    map<pair<int, int>, int> inputOutput2Cnt = {
        {{1, 0}, 2}, {{1, 1}, 1}, {{0, 1}, 1}};
    for (auto t : g2->getTensors()) {
        pair<int, int> key = {t->getInputOf().size(),
                              t->getOutputOf() != nullptr};
        // There must still be an expected tensor of this kind left before it
        // is consumed; an unexpected kind default-inserts 0 and fails here.
        EXPECT_GT(inputOutput2Cnt[key], 0);
        inputOutput2Cnt[key]--;
    }
    // Every expected tensor kind must have been seen exactly as often as
    // listed above.
    for (auto [u, v] : inputOutput2Cnt) {
        EXPECT_EQ(v, 0);
    }
}
} // namespace infini

86
test/core/test_search.cc Normal file
View File

@ -0,0 +1,86 @@
#include "core/blob.h"
#include "core/dummy_mutator.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "core/search_engine.h"
#include "nnet/nmutator.h"
#include "operators/conv.h"
#include "operators/element_wise.h"
#include "operators/matmul.h"
#include "operators/unary.h"
#include "test.h"
namespace infini {
// TEST(Graph, search) {
// Runtime runtime = CpuRuntimeObj::getInstance();
// Graph g = make_ref<GraphObj>(runtime);
// Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
// Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
// Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
// g->dataMalloc();
// i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
// w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
// auto matmul = g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
// g->print();
// // check inputOf and outputsOf for tensor
// SearchEngine searchEngine(runtime, make_ref<NMutator>());
// searchEngine.run(g);
// // check execution results
// }
// Smoke test: runs the SearchEngine with a DummyMutator on a
// Conv -> Add -> Conv -> Add chain. The test contains no assertions; it only
// exercises searchEngine.run().
TEST(Graph, search_withdm) {
Runtime runtime = CpuRuntimeObj::getInstance();
Graph g = make_ref<GraphObj>(runtime);
Tensor t0 = g->addTensor({1, 3, 224, 224});
Tensor w0 = g->addTensor({3, 3, 3, 3});
Tensor t1 = g->addTensor({1, 3, 224, 224});
Tensor t2 = g->addTensor({1, 3, 224, 224});
Tensor t3 = g->addTensor({1, 3, 224, 224});
Tensor w3 = g->addTensor({3, 3, 3, 3});
Tensor t4 = g->addTensor({1, 3, 224, 224});
Tensor t5 = g->addTensor({1, 3, 224, 224});
Tensor t6 = g->addTensor({1, 3, 224, 224});
// t0 -conv0-> t1, t1 + t2 -> t3, t3 -conv1-> t4, t4 + t5 -> t6.
// The returned operator handles are not used afterwards.
auto conv0 = g->addOpWithOutputs<ConvObj>(t0, w0, t1, 1, 1);
auto add0 = g->addOpWithOutputs<AddObj>(t1, t2, t3);
auto conv1 = g->addOpWithOutputs<ConvObj>(t3, w3, t4, 1, 1);
auto add1 = g->addOpWithOutputs<AddObj>(t4, t5, t6);
g->dataMalloc();
// TODO: check inputOf and outputsOf for tensor (not checked yet)
SearchEngine searchEngine(runtime, make_ref<DummyMutator>(10));
searchEngine.run(g);
// TODO: check execution results (not checked yet)
}
// TEST(DummyMutator, run) {
// Runtime runtime = CpuRuntimeObj::getInstance();
// Graph g = make_ref<GraphObj>(runtime);
// Tensor i0 = g->addTensor({1, 3, 224, 224});
// Tensor w0 = g->addTensor({2, 3, 3, 3});
// auto matmul = g->addOp<ConvObj>(i0, w0, nullptr, 1, 1);
// DummyMutator m(10);
// auto mutations = m.run(g);
// g->print();
// for (auto gg : mutations) {
// gg->print();
// }
// }
// TEST(DummyMutator, fuse) {
// Runtime runtime = CpuRuntimeObj::getInstance();
// Graph g = make_ref<GraphObj>(runtime);
// Tensor i0 = g->addTensor({1, 2, 3});
// Tensor w0 = g->addTensor({1, 3, 4});
// Tensor i1 = g->addTensor({1, 2, 3});
// Tensor w1 = g->addTensor({1, 3, 4});
// auto matmul0 = g->addOp<MatmulObj>(i0, w0, nullptr);
// auto matmul1 = g->addOp<MatmulObj>(i1, w1, nullptr);
// DummyMutator m(10);
// auto mutations = m.mergeMultiBranch(g);
// g->print();
// for (auto gg : mutations) {
// gg->print();
// }
// }
} // namespace infini

View File

@ -1,407 +0,0 @@
#include "code_engine.h"
#include "nnet/nmutator.h"
#include "operator.h"
#include "search_engine.h"
#include "tensor.h"
#include "gtest/gtest.h"
using namespace std;
using namespace infini;
// TEST(Mutator, Conv9x9) {
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 1, 224, 224});
// auto w1 = g->tensor({64, 1, 9, 9});
// g->conv(i0, w1, 4, 4);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, TConv_1) {
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 1, 1, 228});
// auto w1 = g->tensor({228, 2, 2, 448});
// // g->conv(i0, w1, 4, 4);
// g->convTrans(i0, w1, 0, 0, 1, 1);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, TConv_3) {
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 2, 2, 448});
// auto w1 = g->tensor({448, 4, 4, 256});
// // g->conv(i0, w1, 4, 4);
// g->convTrans(i0, w1, 1, 1, 2, 2, 1, 1);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// FIXME: failed since implicit transpose for DLT
// Disabled test (DISABLED_ prefix). Builds a graph with one transposed
// convolution, runs NMutator on it and verifies the last mutated sub-graph
// against the original sub-graph.
TEST(Mutator, DISABLED_InfoGAN_TConv_3_correctness) {
// verifyNaiveMembound True: subgraph after transformation
// verifyNaiveMembound False: subgraph of one single membound (eOP)
const bool verifyNaiveMembound = false;
auto g = new tpm::Graph();
// {n, h, w, f} * {r, s, f, c}
// {n, f, h, w} * {f, c, r, s}
auto i0 = g->tensor({1, 448, 2, 2});
auto w1 = g->tensor({448, 256, 4, 4});
g->convTrans(i0, w1, 1, 1, 2, 2, 1, 1);
g->updateConnection();
printf("--- Init Finished ---\n");
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
vector<tpm::SubGraph *> outGraphs;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
graph->print();
printf("--- Graph Finished ---\n");
auto mutationEngine = make_shared<tpm::NMutator>();
if (verifyNaiveMembound)
mutationEngine->setToNaiveMembound();
// The search engine is only constructed here; its use is in the
// commented-out codegen section at the end of the test.
tpm::SearchEngine searchEngine(mutationEngine);
printf("--- SearchEngine Finished ---\n");
// A second, stack-allocated mutator produces the graphs that are verified.
tpm::NMutator mutator;
if (verifyNaiveMembound)
mutator.setToNaiveMembound();
mutator.run(graph.get(), outGraphs);
printf("--- Mutator Finished ---\n");
// Only the last raw pointer in outGraphs is wrapped in a shared_ptr.
bestGraph = shared_ptr<tpm::SubGraph>(outGraphs.back());
bestGraph->print();
printf("--- BestGraph Finished ---\n");
EXPECT_TRUE(graph->verification(bestGraph.get(), true));
// // Codegen (independent from the above)
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
}
// TEST(Mutator, G2BMM) {
// auto g = new tpm::Graph();
// int nHeads = 8, seq_len = 10000, feat_len = 64, w = 1000, d = 4;
// auto i0 = g->tensor({nHeads, seq_len, feat_len});
// auto i1 = g->tensor({nHeads, seq_len, feat_len});
// g->g2bmm(i0, i1, w, d);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(graph, "res.cu");
// // codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, GBMML) {
// auto g = new tpm::Graph();
// int nHeads = 8, seq_len = 10000, feat_len = 64, w = 1000, d = 4;
// auto i0 = g->tensor({nHeads, seq_len, 2 * w + 1});
// auto i1 = g->tensor({nHeads, seq_len, feat_len});
// g->gbmml(i0, i1, d);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(graph, "res.cu");
// // codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, Conv5x5) {
// //
// conv7x7->relu->conv3x3->relu->conv3x3->relu->conv3x3->relu->conv3x3->relu
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 32, 224, 224});
// auto w1 = g->tensor({1, 32, 5, 5});
// g->conv(i0, w1, tpm::ConvOp::PaddingMode::Same);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, BMM) {
// const int m = 16, n = 1024, k = 1024;
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, m, k});
// auto w0 = g->tensor({1, k, n});
// auto w1 = g->tensor({1, k, n});
// auto w2 = g->tensor({1, k, n});
// g->matmul(i0, w0);
// g->matmul(i0, w1);
// g->matmul(i0, w2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// Runs NMutator with an explicit rule list on a single 1x1 convolution and
// checks the number of Reshape operators in every mutated sub-graph.
TEST(Mutator, Conv2gemm1x1_bs1_mutator) {
const int N = 1, H = 7, W = 7, C = 512, F = 512, R = 1, S = 1;
auto g = new tpm::Graph();
auto i0 = g->tensor({N, C, H, W});
auto w1 = g->tensor({F, C, R, S});
g->conv(i0, w1, R / 2, S / 2);
g->updateConnection();
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
vector<tpm::SubGraph *> out_graphs;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
const vector<int> rules = {3, 2, 2, 8, 8, 6, 6};
auto mutator = make_shared<tpm::NMutator>(rules);
mutator->run(graph.get(), out_graphs);
tpm::SearchEngine searchEngine(mutator);
int maxNReshapes = 0;
// Note: the loop variable `graph` shadows the outer shared_ptr `graph`.
for (const auto &graph : out_graphs) {
searchEngine.getPerf(make_shared<tpm::SubGraph>(*graph), true);
int nReshapes = 0, nTrans = 0;
for (auto op : graph->getOperators()) {
nReshapes += op->isReshapeOp();
// nTrans is assigned (not accumulated), so it reflects the last Matmul seen.
if (auto matmul = dynamic_cast<MatmulOp *>(op))
nTrans = matmul->getTransA() + matmul->getTransB();
}
maxNReshapes = max(maxNReshapes, nReshapes);
// Number of Reshapes for KxA and AxK
EXPECT_TRUE((nReshapes == 3 - nTrans) || (nReshapes == nTrans));
}
// Matmul K^N A^N -> no Membound
EXPECT_EQ(maxNReshapes, 3);
}
// Runs the search engine with a rule-based NMutator on a single 1x1
// convolution and expects the best graph to consist of one Matmul and three
// Reshape operators.
TEST(Mutator, Conv2gemm1x1_searchEngine_ruleBased) {
const int N = 1, H = 7, W = 7, C = 512, F = 512, R = 1, S = 1;
auto g = new tpm::Graph();
auto i0 = g->tensor({N, C, H, W});
auto w1 = g->tensor({F, C, R, S});
g->conv(i0, w1, R / 2, S / 2);
g->updateConnection();
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
const vector<int> rules = {3, 2, 2, 8, 8, 6, 6};
tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>(rules));
searchEngine.run(graph, bestGraph);
// Sample output of a previous run, kept for reference:
// clang-format off
// ========== PET graph getPerf ============
// Reshape(in=0,out=126)
// op_time 0.000000
// Reshape(in=1,out=125)
// op_time 0.000000
// Matmul([A,B,act=0],A=125,B=126,C=124, TTbmnk: 0, 0, 1, 512, 49, 512)
// op_time 0.013799
// Reshape(in=124,out=3)
// op_time 0.000000
// Op Cnt T_tot Percent T_mean
// Matmul 1 0.014 100.0 0.014
// Reshape 3 0.000 0.0 0.000
// Origin Perf: 0.0553319
// Best Perf without correction: 0.0137989
// Best Perf with correction: 0.0137989
// clang-format on
EXPECT_EQ(bestGraph->getOperators().size(), 4u);
auto cntOps = bestGraph->countOps();
EXPECT_EQ(cntOps["Matmul"], 1);
EXPECT_EQ(cntOps["Reshape"], 3);
bestGraph->print();
}
// Same graph and expectations as the rule-based 1x1 test, but the NMutator is
// default-constructed (no rule list is passed).
TEST(Mutator, Conv2gemm1x1_searchEngine_search) {
const int N = 1, H = 7, W = 7, C = 512, F = 512, R = 1, S = 1;
auto g = new tpm::Graph();
auto i0 = g->tensor({N, C, H, W});
auto w1 = g->tensor({F, C, R, S});
g->conv(i0, w1, R / 2, S / 2);
g->updateConnection();
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>());
searchEngine.run(graph, bestGraph);
// Expected best graph: one Matmul plus three Reshape operators.
EXPECT_EQ(bestGraph->getOperators().size(), 4u);
auto cntOps = bestGraph->countOps();
EXPECT_EQ(cntOps["Matmul"], 1);
EXPECT_EQ(cntOps["Reshape"], 3);
bestGraph->print();
}
// Runs the search engine with a rule-based NMutator on a single convolution
// with a 1x7 kernel (R = 1, S = 7). Expects a best graph of one Matmul, one
// Reshape and two MemBound operators that also passes verification against
// the original sub-graph.
TEST(Mutator, Conv2gemm1x7_searchEngine_ruleBased) {
const int N = 1, C = 2048, H = 7, W = 7, F = 128, R = 1,
S = 7; // gcn_Conv_137
auto g = new tpm::Graph();
auto i0 = g->tensor({N, C, H, W});
auto w1 = g->tensor({F, C, R, S});
g->conv(i0, w1, R / 2, S / 2);
g->updateConnection();
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>(rules));
searchEngine.run(graph, bestGraph);
// Sample output of a previous run, kept for reference:
// clang-format off
// ========== PET graph getPerf ============
// Reshape(in=0,out=309)
// op_time 0.000000
// MemBound[124644277](i0=1, o0=308, exec_time=0.0683594, NNet Inputs=[K,])
// L<c:0:2048><i52:0:896>Sum ... [i52,c]
// {L<i52:0:896><c:0:2048>Sum ... [(i52 / 7),c,((i52 / 7) % 1),(i52 % 7)]
// {K}}
// op_time 0.000000
// Matmul([A^T,B,act=0],A=308,B=309,C=307, TTbmnk: 1, 0, 1, 896, 49, 2048)
// op_time 0.024471
// MemBound[124644277](i0=307, o0=3, exec_time=0.001, NNet Inputs=[T49,])
// L<n:0:1><f:0:128><h:0:7><w:0:7>Sum<r:0:1><s:0:7> ... [(h + r),r,(w + s),s,n,f]
// {L<i45:0:7><i46:0:1><i26:3:10><i27:0:7><n:0:1><f:0:128><pad=0,0,3,0,0,0,>Sum ... [(((7 * f) + (7 * i46)) + i27),(((49 * n) + (7 * i45)) + (i26 + -3))]
// {T49}}
// op_time 0.001000
// Op Cnt T_tot Percent T_mean
// Matmul 1 0.024 96.1 0.024
// Reshape 1 0.000 0.0 0.000
// MemBound 2 0.001 3.9 0.001
// Origin Perf: 0.405595
// Best Perf without correction: 0.0254715
// Best Perf with correction: 0.0254715
// Transpose perf: 0
// clang-format on
EXPECT_EQ(bestGraph->getOperators().size(), 4u);
auto cntOps = bestGraph->countOps();
EXPECT_EQ(cntOps["Matmul"], 1);
EXPECT_EQ(cntOps["Reshape"], 1);
EXPECT_EQ(cntOps["MemBound"], 2);
bestGraph->print();
EXPECT_TRUE(graph->verification(bestGraph.get(), true));
}
// Counterpart of the 1x7 rule-based test with a 7x1 kernel (R = 7, S = 1);
// the rule list and the expectations on the best graph are the same.
TEST(Mutator, Conv2gemm7x1_searchEngine_ruleBased) {
const int N = 1, C = 2048, H = 7, W = 7, F = 128, R = 7,
S = 1; // gcn_Conv_137
auto g = new tpm::Graph();
auto i0 = g->tensor({N, C, H, W});
auto w1 = g->tensor({F, C, R, S});
g->conv(i0, w1, R / 2, S / 2);
g->updateConnection();
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>(rules));
searchEngine.run(graph, bestGraph);
// Expected best graph: one Matmul, one Reshape and two MemBound operators.
EXPECT_EQ(bestGraph->getOperators().size(), 4u);
auto cntOps = bestGraph->countOps();
EXPECT_EQ(cntOps["Matmul"], 1);
EXPECT_EQ(cntOps["Reshape"], 1);
EXPECT_EQ(cntOps["MemBound"], 2);
bestGraph->print();
EXPECT_TRUE(graph->verification(bestGraph.get(), true));
}
// Same 7x1 convolution and expectations as the rule-based variant, but the
// NMutator is default-constructed (the rule list is commented out).
TEST(Mutator, Conv2gemm7x1_searchEngine_search) {
const int N = 1, C = 2048, H = 7, W = 7, F = 128, R = 7,
S = 1; // gcn_Conv_137
auto g = new tpm::Graph();
auto i0 = g->tensor({N, C, H, W});
auto w1 = g->tensor({F, C, R, S});
g->conv(i0, w1, R / 2, S / 2);
g->updateConnection();
std::shared_ptr<tpm::SubGraph> graph, bestGraph;
graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>());
searchEngine.run(graph, bestGraph);
// Expected best graph: one Matmul, one Reshape and two MemBound operators.
EXPECT_EQ(bestGraph->getOperators().size(), 4u);
auto cntOps = bestGraph->countOps();
EXPECT_EQ(cntOps["Matmul"], 1);
EXPECT_EQ(cntOps["Reshape"], 1);
EXPECT_EQ(cntOps["MemBound"], 2);
bestGraph->print();
EXPECT_TRUE(graph->verification(bestGraph.get(), true));
}

View File

@ -23,8 +23,8 @@ TEST(nnet, MemboundOpInterpretation) {
g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
NMutator nmutator(NMutator::Mode::ToNaiveMembound);
auto mutations = nmutator.run(g);
ASSERT_EQ(mutations.size(), 1u);
Graph gNew = mutations[0];
ASSERT_EQ(mutations.size(), 2u);
Graph gNew = mutations[1];
gNew->print();
gNew->dataMalloc();
@ -54,8 +54,8 @@ TEST(nnet, MemboundOp_Ansor_Codegen) {
g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
NMutator nmutator(NMutator::Mode::ToNaiveMembound);
auto mutations = nmutator.run(g);
ASSERT_EQ(mutations.size(), 1u);
Graph gNew = mutations[0];
ASSERT_EQ(mutations.size(), 2u);
Graph gNew = mutations[1];
gNew->print();
gNew->dataMalloc();
runtime->run(gNew, true); // tune kernels

421
test/nnet/test_mutator.cc Normal file
View File

@ -0,0 +1,421 @@
#include "core/blob.h"
#include "core/dummy_mutator.h"
#include "core/graph.h"
#include "core/runtime.h"
#include "core/search_engine.h"
#include "nnet/nmutator.h"
#include "operators/conv.h"
#include "test.h"
namespace infini {
// TEST(Mutator, Conv9x9) {
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 1, 224, 224});
// auto w1 = g->tensor({64, 1, 9, 9});
// g->conv(i0, w1, 4, 4);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, TConv_1) {
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 1, 1, 228});
// auto w1 = g->tensor({228, 2, 2, 448});
// // g->conv(i0, w1, 4, 4);
// g->convTrans(i0, w1, 0, 0, 1, 1);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, TConv_3) {
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 2, 2, 448});
// auto w1 = g->tensor({448, 4, 4, 256});
// // g->conv(i0, w1, 4, 4);
// g->convTrans(i0, w1, 1, 1, 2, 2, 1, 1);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// // FIXME: failed since implicit transpose for DLT
// TEST(Mutator, InfoGAN_TConv_3_correctness) {
// // verifyNaiveMembound True: subgraph after transformation
// // verifyNaiveMembound False: subgraph of one single membound (eOP)
// const bool verifyNaiveMembound = false;
// auto g = new tpm::Graph();
// // {n, h, w, f} * {r, s, f, c}
// // {n, f, h, w} * {f, c, r, s}
// auto i0 = g->tensor({1, 448, 2, 2});
// auto w1 = g->tensor({448, 256, 4, 4});
// g->convTrans(i0, w1, 1, 1, 2, 2, 1, 1);
// }
// Builds a graph holding a single ConvObj, feeds it to a SearchEngine driven by
// an NMutator on which setToNaiveMembound() has been called, and then compares
// the original graph against the second graph returned by NMutator::run():
// both are executed on the same runtime and must yield equal output data that
// lives in different buffers.
TEST(Mutator, NaiveConvWithInterpreter) {
    Runtime runtime = CpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);

    // The convolution under test: UInt32 input and weight, no bias tensor.
    auto input = g->addTensor({1, 3, 32, 32}, DataType::UInt32);
    auto weight = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
    g->addOp<ConvObj>(input, weight, nullptr, 1, 1);
    printf("--- Init Finished ---\n");

    auto mutator = make_ref<NMutator>();
    mutator->setToNaiveMembound();

    // The search runs before any data is allocated for the graph.
    SearchEngine searchEngine(runtime, mutator);
    auto bestGraph = searchEngine.run(g);
    bestGraph->print();
    printf("--- SearchEngine Finished ---\n");

    // Exactly two graphs are expected from the mutator; the second one is the
    // candidate that gets executed below.
    auto mutatedGraphs = mutator->run(g);
    IT_ASSERT(mutatedGraphs.size() == 2);
    printf("--- Mutator Finished ---\n");
    auto mutatedGraph = mutatedGraphs[1];

    g->dataMalloc();
    mutatedGraph->dataMalloc();

    // Give both graphs the same incremental contents in every tensor whose
    // Fuid is at most 2.
    // NOTE(review): Fuid <= 2 presumably selects the input and weight tensors
    // created above -- confirm against how Fuids are assigned.
    const auto fillLowFuidTensors = [](const Graph &graph) {
        for (const auto &tensor : graph->getTensors()) {
            if (tensor->getFuid() <= 2)
                tensor->setData(IncrementalGenerator());
        }
    };
    fillLowFuidTensors(g);
    fillLowFuidTensors(mutatedGraph);

    runtime->run(g);
    runtime->run(mutatedGraph);
    mutatedGraph->print();

    // Same values, but not the very same storage.
    const auto originOutput = g->getOutputs()[0];
    const auto mutatedOutput = mutatedGraph->getOutputs()[0];
    EXPECT_TRUE(originOutput->equalData(mutatedOutput));
    EXPECT_TRUE(originOutput->getRawDataPtr<void *>() !=
                mutatedOutput->getRawDataPtr<void *>());
}
// TEST(Mutator, G2BMM) {
// auto g = new tpm::Graph();
// int nHeads = 8, seq_len = 10000, feat_len = 64, w = 1000, d = 4;
// auto i0 = g->tensor({nHeads, seq_len, feat_len});
// auto i1 = g->tensor({nHeads, seq_len, feat_len});
// g->g2bmm(i0, i1, w, d);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(graph, "res.cu");
// // codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, GBMML) {
// auto g = new tpm::Graph();
// int nHeads = 8, seq_len = 10000, feat_len = 64, w = 1000, d = 4;
// auto i0 = g->tensor({nHeads, seq_len, 2 * w + 1});
// auto i1 = g->tensor({nHeads, seq_len, feat_len});
// g->gbmml(i0, i1, d);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(graph, "res.cu");
// // codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, Conv5x5) {
// //
// conv7x7->relu->conv3x3->relu->conv3x3->relu->conv3x3->relu->conv3x3->relu
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, 32, 224, 224});
// auto w1 = g->tensor({1, 32, 5, 5});
// g->conv(i0, w1, tpm::ConvOp::PaddingMode::Same);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, BMM) {
// const int m = 16, n = 1024, k = 1024;
// auto g = new tpm::Graph();
// auto i0 = g->tensor({1, m, k});
// auto w0 = g->tensor({1, k, n});
// auto w1 = g->tensor({1, k, n});
// auto w2 = g->tensor({1, k, n});
// g->matmul(i0, w0);
// g->matmul(i0, w1);
// g->matmul(i0, w2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine<tpm::NMutator> searchEngine;
// searchEngine.run(graph, bestGraph);
// tpm::CodeEngine codeEngine;
// auto perfEngine = searchEngine.exportPerfEngine();
// codeEngine.importPerfEngine(perfEngine);
// codeEngine.genCode(bestGraph, "res.cu");
// }
// TEST(Mutator, Conv2gemm1x1_bs1_mutator) {
// const int N = 1, H = 7, W = 7, C = 512, F = 512, R = 1, S = 1;
// auto g = new tpm::Graph();
// auto i0 = g->tensor({N, C, H, W});
// auto w1 = g->tensor({F, C, R, S});
// g->conv(i0, w1, R / 2, S / 2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// vector<tpm::SubGraph *> out_graphs;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// const vector<int> rules = {3, 2, 2, 8, 8, 6, 6};
// auto mutator = make_shared<tpm::NMutator>(rules);
// mutator->run(graph.get(), out_graphs);
// tpm::SearchEngine searchEngine(mutator);
// int maxNReshapes = 0;
// for (const auto &graph : out_graphs) {
// searchEngine.getPerf(make_shared<tpm::SubGraph>(*graph), true);
// int nReshapes = 0, nTrans = 0;
// for (auto op : graph->getOperators()) {
// nReshapes += op->isReshapeOp();
// if (auto matmul = dynamic_cast<MatmulOp *>(op))
// nTrans = matmul->getTransA() + matmul->getTransB();
// }
// maxNReshapes = max(maxNReshapes, nReshapes);
// // Number of Reshapes for KxA and AxK
// EXPECT_TRUE((nReshapes == 3 - nTrans) || (nReshapes == nTrans));
// }
// // Matmul K^N A^N -> no Membound
// EXPECT_EQ(maxNReshapes, 3);
// }
// TEST(Mutator, Conv2gemm1x1_searchEngine_ruleBased) {
// const int N = 1, H = 7, W = 7, C = 512, F = 512, R = 1, S = 1;
// auto g = new tpm::Graph();
// auto i0 = g->tensor({N, C, H, W});
// auto w1 = g->tensor({F, C, R, S});
// g->conv(i0, w1, R / 2, S / 2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// const vector<int> rules = {3, 2, 2, 8, 8, 6, 6};
// tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>(rules));
// searchEngine.run(graph, bestGraph);
// // clang-format off
// // ========== PET graph getPerf ============
// // Reshape(in=0,out=126)
// // op_time 0.000000
// // Reshape(in=1,out=125)
// // op_time 0.000000
// // Matmul([A,B,act=0],A=125,B=126,C=124, TTbmnk: 0, 0, 1, 512, 49, 512)
// // op_time 0.013799
// // Reshape(in=124,out=3)
// // op_time 0.000000
// // Op Cnt T_tot Percent T_mean
// // Matmul 1 0.014 100.0 0.014
// // Reshape 3 0.000 0.0 0.000
// // Origin Perf: 0.0553319
// // Best Perf without correction: 0.0137989
// // Best Perf with correction: 0.0137989
// // clang-format on
// EXPECT_EQ(bestGraph->getOperators().size(), 4u);
// auto cntOps = bestGraph->countOps();
// EXPECT_EQ(cntOps["Matmul"], 1);
// EXPECT_EQ(cntOps["Reshape"], 3);
// bestGraph->print();
// }
// TEST(Mutator, Conv2gemm1x1_searchEngine_search) {
// const int N = 1, H = 7, W = 7, C = 512, F = 512, R = 1, S = 1;
// auto g = new tpm::Graph();
// auto i0 = g->tensor({N, C, H, W});
// auto w1 = g->tensor({F, C, R, S});
// g->conv(i0, w1, R / 2, S / 2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>());
// searchEngine.run(graph, bestGraph);
// EXPECT_EQ(bestGraph->getOperators().size(), 4u);
// auto cntOps = bestGraph->countOps();
// EXPECT_EQ(cntOps["Matmul"], 1);
// EXPECT_EQ(cntOps["Reshape"], 3);
// bestGraph->print();
// }
// TEST(Mutator, Conv2gemm1x7_searchEngine_ruleBased) {
// const int N = 1, C = 2048, H = 7, W = 7, F = 128, R = 1,
// S = 7; // gcn_Conv_137
// auto g = new tpm::Graph();
// auto i0 = g->tensor({N, C, H, W});
// auto w1 = g->tensor({F, C, R, S});
// g->conv(i0, w1, R / 2, S / 2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
// tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>(rules));
// searchEngine.run(graph, bestGraph);
// // clang-format off
// // ========== PET graph getPerf ============
// // Reshape(in=0,out=309)
// // op_time 0.000000
// // MemBound[124644277](i0=1, o0=308, exec_time=0.0683594, NNet
// Inputs=[K,])
// // L<c:0:2048><i52:0:896>Sum ... [i52,c]
// // {L<i52:0:896><c:0:2048>Sum ... [(i52 / 7),c,((i52 / 7) % 1),(i52
// % 7)]
// // {K}}
// // op_time 0.000000
// // Matmul([A^T,B,act=0],A=308,B=309,C=307, TTbmnk: 1, 0, 1, 896, 49,
// 2048)
// // op_time 0.024471
// // MemBound[124644277](i0=307, o0=3, exec_time=0.001, NNet Inputs=[T49,])
// // L<n:0:1><f:0:128><h:0:7><w:0:7>Sum<r:0:1><s:0:7> ... [(h + r),r,(w +
// s),s,n,f]
// //
// {L<i45:0:7><i46:0:1><i26:3:10><i27:0:7><n:0:1><f:0:128><pad=0,0,3,0,0,0,>Sum
// ... [(((7 * f) + (7 * i46)) + i27),(((49 * n) + (7 * i45)) + (i26 +
// -3))]
// // {T49}}
// // op_time 0.001000
// // Op Cnt T_tot Percent T_mean
// // Matmul 1 0.024 96.1 0.024
// // Reshape 1 0.000 0.0 0.000
// // MemBound 2 0.001 3.9 0.001
// // Origin Perf: 0.405595
// // Best Perf without correction: 0.0254715
// // Best Perf with correction: 0.0254715
// // Transpose perf: 0
// // clang-format on
// EXPECT_EQ(bestGraph->getOperators().size(), 4u);
// auto cntOps = bestGraph->countOps();
// EXPECT_EQ(cntOps["Matmul"], 1);
// EXPECT_EQ(cntOps["Reshape"], 1);
// EXPECT_EQ(cntOps["MemBound"], 2);
// bestGraph->print();
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
// }
// TEST(Mutator, Conv2gemm7x1_searchEngine_ruleBased) {
// const int N = 1, C = 2048, H = 7, W = 7, F = 128, R = 7,
// S = 1; // gcn_Conv_137
// auto g = new tpm::Graph();
// auto i0 = g->tensor({N, C, H, W});
// auto w1 = g->tensor({F, C, R, S});
// g->conv(i0, w1, R / 2, S / 2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
// tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>(rules));
// searchEngine.run(graph, bestGraph);
// EXPECT_EQ(bestGraph->getOperators().size(), 4u);
// auto cntOps = bestGraph->countOps();
// EXPECT_EQ(cntOps["Matmul"], 1);
// EXPECT_EQ(cntOps["Reshape"], 1);
// EXPECT_EQ(cntOps["MemBound"], 2);
// bestGraph->print();
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
// }
// TEST(Mutator, Conv2gemm7x1_searchEngine_search) {
// const int N = 1, C = 2048, H = 7, W = 7, F = 128, R = 7,
// S = 1; // gcn_Conv_137
// auto g = new tpm::Graph();
// auto i0 = g->tensor({N, C, H, W});
// auto w1 = g->tensor({F, C, R, S});
// g->conv(i0, w1, R / 2, S / 2);
// g->updateConnection();
// std::shared_ptr<tpm::SubGraph> graph, bestGraph;
// graph = std::make_shared<tpm::SubGraph>(g->getOperators());
// // const vector<int> rules = {3, 2, 2, 5, 8, 8, 6, 90};
// tpm::SearchEngine searchEngine(make_shared<tpm::NMutator>());
// searchEngine.run(graph, bestGraph);
// EXPECT_EQ(bestGraph->getOperators().size(), 4u);
// auto cntOps = bestGraph->countOps();
// EXPECT_EQ(cntOps["Matmul"], 1);
// EXPECT_EQ(cntOps["Reshape"], 1);
// EXPECT_EQ(cntOps["MemBound"], 2);
// bestGraph->print();
// EXPECT_TRUE(graph->verification(bestGraph.get(), true));
// }
} // namespace infini