forked from jiuyuan/InfiniTensor

Commit 8409c1f9d4: tested fsrcnn
@@ -63,6 +63,7 @@ class GraphObj : public Object {
void optimize();

void dataMalloc();
void dataFree();

/**
* @brief Add an operator and create its outputs. Output tensor arguments
@@ -30,7 +30,14 @@ class Mutator {
virtual bool isMultiBranchMergable(const Graph &in_graph) {
IT_TODO_HALT();
}

/// @brief Fuse memory bound operators.
/// @return The graph after fusion. Return `nullptr` if fails.
virtual Graph fuseVertically(const Graph &inputGraph) { IT_TODO_HALT(); }

/// @brief Eliminate transpose and reshape.
/// @return The graph after elimination. Return `nullptr` if fails.
virtual Graph eliminateVertically(const Graph &in_graph) { IT_TODO_HALT(); }
};

} // namespace infini
@@ -105,7 +105,8 @@ enum class OpType {
MemBound = 300,
//
Conv2dReduce = 400,
-Conv2dReduceTranspose
+Conv2dReduceTranspose,
+Any
};

using KernelAttrs = std::tuple<Device, OpType, DataType>;
@@ -217,6 +218,7 @@ class OpRegistry {
//
FOP(Conv2dReduce);
FOP(Conv2dReduceTranspose);
FOP(Any);
default:
IT_ASSERT(false, "Unknown OpType " +
std::to_string(enum_to_underlying(opType)));
@@ -78,25 +78,8 @@ class TensorObj : public TensorBaseObj {
void setData(
std::function<void(void *, size_t, DataType)> const &generator) const;
void setData(const Blob &_blob) { data = _blob; }
-Tensor clone() const {
-auto obj = make_ref<TensorObj>(*this);
-obj->freeData();
-obj->targets.clear();
-obj->source.reset();
-return obj;
-}
-Tensor clone(Runtime runtime) const {
-auto obj = make_ref<TensorObj>(*this);
-obj->runtime = runtime;
-obj->freeData();
-obj->targets.clear();
-obj->source.reset();
-// if (hasData()) {
-// obj->dataMalloc();
-// obj->copyData(this);
-// }
-return obj;
-}
+Tensor clone() const;
+Tensor clone(Runtime runtime) const;

void printData() const;
bool equalData(const Tensor &rhs, double relativeError = 1e-6) const;
@@ -127,6 +110,12 @@ class TensorObj : public TensorBaseObj {
if (i % dimSzVec[j] == 0)
builder << "[";

if (iEnd > 1000 && i > 20 && i < iEnd - 20) {
printf("... , ");
i = iEnd - 20;
continue;
}

builder << ptr[i];
for (size_t j = 0; j < numDims; ++j)
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1)
@@ -0,0 +1,10 @@
#pragma once

#include "operators/any.h"

namespace infini {

void any_kernel_mapping(vector<float *> input, vector<float *> output,
const string &kernel_name, const vector<int> &attr);

} // namespace infini
@@ -12,6 +12,11 @@ class CudaRuntimeObj : public RuntimeObj {
cublasHandle_t cublas;
CudaPtr workspace;
size_t workspaceSize;

// Memory information
size_t allocatedGPUMemorySize = 0;
map<void *, size_t> allocationMap;

bool cudaGraphStatus; // Whether CUDA graph stream capture is enabled

public:
@@ -28,10 +33,20 @@ class CudaRuntimeObj : public RuntimeObj {
void *ptr;
// dbg(size);
checkCudaError(cudaMalloc(&ptr, size));
// printf("cuda malloc: %p %lu bytes\n", ptr, size);
allocatedGPUMemorySize += size;
allocationMap[ptr] = size;
// printf("cuda malloc: %p %lu bytes, total %lu bytes (%.2lf GB)\n",
// ptr, size, allocatedGPUMemorySize,
// double(allocatedGPUMemorySize) / 1024 / 1024 / 1024);
return ptr;
}
-void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
+void dealloc(void *ptr) override {
+checkCudaError(cudaFree(ptr));
+allocatedGPUMemorySize -= allocationMap.at(ptr);
+allocationMap.erase(ptr);
+// printf("cuda dealloc: %p %lu bytes, total %lu\n", ptr,
+// allocationMap.at(ptr), allocatedGPUMemorySize);
+}
cudnnHandle_t cudnnHandle() const { return cudnn; }
cublasHandle_t cublasHandle() const { return cublas; }
size_t getWorkspaceSize() const { return workspaceSize; }
@@ -60,7 +75,7 @@ class CudaRuntimeObj : public RuntimeObj {
bool isInCudaGraph() const { return cudaGraphStatus; }
cudaStream_t getStream() const { return stream; }

-double timeWithCudaGraph(Graph graph);
+double timeWithCudaGraph(Graph graph, int rounds = 1000);

private:
void tune(const Graph &graph, bool profiling) const;
@@ -0,0 +1,11 @@
#pragma once

#include "operators/transpose.h"
#include "utils/small_array.h"

namespace infini {

void transpose_kernel(float *input, float *output, int nDims, int size,
SmallArray strides, SmallArray outputShape);

} // namespace infini

(File diff suppressed because it is too large)
@@ -26,6 +26,8 @@ class NMutator : public Mutator {

vector<Graph> run(const Graph &in_graph) override;
Graph fuseVertically(const Graph &in_graph) override;
Graph eliminateVertically(const Graph &in_graph) override;
bool isMultiBranchMergable(const Graph &in_graph) override;

void setToNaiveMembound();
void setMaxDepth(int _maxDepth) { maxDepth = _maxDepth; }
@@ -57,14 +59,26 @@ class NMutator : public Mutator {
double memboundTime(const Shape &dims);

// TODO: recover these rules
// Graph fuseHetConv(nnet::Expr expr, Graph in_graph);
Graph transformConvtransposed1x1(Operator _op);
// Graph transformConvtransposed(Operator op);
Graph transformConv1x1(Operator op);
Graph transformG2bmm(Operator op);
Graph transformGbmm(Operator op);
Graph transformDialtedConv(Operator _op);
// Graph transformConv1x1(Operator op);
// Graph transformConv1xk(Operator op);
Graph transformConvToGEMMReduce(Operator _op);
Graph transformConvTranposeToGEMMReduce(Operator _op);

Tensor splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
Tensor output = nullptr);

/// @brief Construct a new graph with a chain of operators. Use the output
/// from the previous operator as the input of the next operator. While
/// constructing, the input and output tensors from inputGraph are used as
/// new constructed graph.
/// @param op The operator chain. It can have wrong input/output shapes.
/// @return
Graph constructGraphByOperatorChain(vector<Operator> ops, Graph inputGraph);
};

} // namespace infini
@@ -7,6 +7,7 @@ namespace infini {

Graph getGANGraph(int batch, Runtime runtime, int nLayers, int modelId);
Graph getFSRCNNGraph(int batch, Runtime runtime);
Graph getLongformer(Runtime runtime, int bs);
vector<Tensor> runInfoGAN(int nLayers);
Graph getConvtransposedNHWC(Runtime runtime, Shape shape, int layerId);
Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
@@ -49,7 +49,7 @@ template <typename R, typename... Args> class Functor<R(Args...)> {
virtual R visit_(const Tensor &c, Args... args) FUNCTOR_DEFAULT;
virtual R visit_(const Func &c, Args... args) FUNCTOR_DEFAULT;
virtual R visitDefault(const Expr &c, [[maybe_unused]] Args... args) {
-dbg(*c);
+dbg(*c, c->getType());
nnet_assert(0, "Reach unimplemented visit function.");
return R();
};
@@ -0,0 +1,29 @@
#pragma once
#include "core/operator.h"

namespace infini {

class AnyObj : public OperatorObj {
private:
string kernelName;
vector<int> attr;

public:
AnyObj(GraphObj *graph, const TensorVec &inputs, const TensorVec &outputs,
string &kernelName, const vector<int> &attr);

OP_CLONE(AnyObj);

string toString() const override;

optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

int numInputs() const override { return inputs.size(); }
int numOutputs() const override { return outputs.size(); }

const string getKernelName() const;
vector<int> getOpAttrVector() const override;
vector<int> getWorkloadVector() const override;
};

} // namespace infini
@@ -19,7 +19,7 @@ class ReshapeObj : public OperatorObj {
* @param output The output tensor.
* @param dims The shape of the output tensor.
*/
-ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims);
+ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims = {});
OP_CLONE(ReshapeObj);

optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
@@ -3,6 +3,8 @@

namespace infini {
class TransposeObj : public OperatorObj {
vector<int> transposePermute;

public:
TransposeObj(GraphObj *graph, Tensor input, Tensor output,
vector<int> permute);
@@ -15,7 +17,6 @@ class TransposeObj : public OperatorObj {
std::vector<int> getPermute() const { return transposePermute; }

private:
-vector<int> transposePermute = {1, 1, 1, 1};
vector<int> getWorkloadVector() const override;
vector<int> getOpAttrVector() const override;
};
@@ -0,0 +1,8 @@
namespace infini {

#define SMALL_ARRAY_SIZE 8
struct SmallArray {
int data[SMALL_ARRAY_SIZE];
};

} // namespace infini
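Note: SmallArray exists so that a runtime shape or stride vector can be passed to a CUDA kernel by value instead of through a device pointer. A minimal sketch (not part of this commit) of how a shape might be packed into it on the host; the helper name packSmallArray is hypothetical, and the bounds check mirrors the one used by the transpose kernel below:

#include <cassert>
#include <vector>
// Hypothetical helper: copy a dynamic shape into a SmallArray so it can be
// handed to a kernel launch by value.
infini::SmallArray packSmallArray(const std::vector<int> &dims) {
    assert(dims.size() <= SMALL_ARRAY_SIZE); // the struct holds at most 8 entries
    infini::SmallArray arr{};
    for (size_t i = 0; i < dims.size(); ++i)
        arr.data[i] = dims[i];
    return arr;
}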
@@ -639,9 +639,9 @@ class OnnxStub:
if name is None:
self.count_in += 1
if tensor.getTensorType() == backend.TensorType.Input:
-name = "input{}".format(self.count_in)
+name = f"input{self.count_in}_{tensor.guid()}"
else:
-name = "weight{}".format(self.count_in)
+name = f"weight{self.count_in}_{tensor.guid()}"
self.names[tensor] = name
if init != None:
init.name = name
@@ -706,7 +706,7 @@ class OnnxStub:
for it in op.inputs()
]
outputs = [
-ctx.push_output("{}_{}".format(name, i), it)
+ctx.push_output(f"{name}_{i}_{it.guid()}", it)
for (i, it) in enumerate(op.outputs())
]
if ty == backend.OpType.Conv or ty == backend.OpType.ConvNHWC:
@@ -884,7 +884,8 @@ class OnnxStub:
ctx.push_data_input(name, "max", TensorProto.FLOAT, [], [])
)
ctx.push_node(make_node(ty.name, inputs, outputs, name))
-elif ty == backend.OpType.ConvTransNHWC:
+elif ty in [backend.OpType.ConvTransNHWC, backend.OpType.GBMM,
+backend.OpType.G2BMM]:
ctx.push_node(
make_node(
ty.name,
@@ -1003,3 +1004,9 @@ def _parse_data(tensor: TensorProto) -> List[Any]:

def _take_shape_dim(shape: TensorShapeProto) -> List[int]:
return [(d.dim_value if d.dim_value > 0 else 1) for d in shape.dim]


def save_onnx(opt_g, filename: str):
stub = OnnxStub.from_graph(opt_g)
with open(filename, "wb") as f:
f.write(stub.to_onnx("optimized").SerializeToString())
@@ -29,6 +29,7 @@ void BangRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
perfEngine.setPerfData(perfKey, record);
} else
record = perfData;
std::cout << 5 << std::endl;

double t = record->time;
totalTime += t;
@@ -125,7 +125,24 @@ void GraphObj::optimize() {

void GraphObj::dataMalloc() {
for (auto &tensor : tensors) {
-tensor->dataMalloc();
+if (tensor->getSource() && tensor->getTargets().size() > 0 &&
+tensor->getSource()->getOpType() == OpType::Reshape) {
+continue;
+} else
+tensor->dataMalloc();
}
// Fill reshape output for avoiding nullptr
for (auto &tensor : tensors) {
if (tensor->getSource() &&
tensor->getSource()->getOpType() == OpType::Reshape) {
tensor->setData(tensor->getSource()->getInputs(0)->getDataBlob());
}
}
}

void GraphObj::dataFree() {
for (auto &tensor : tensors) {
tensor->freeData();
}
}
@@ -35,9 +35,10 @@ bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
bool OperatorObj::isReshapeOp() const { return type == OpType::Reshape; }

bool OperatorObj::isMemBoundOp() const {
-return type == OpType::MemBound || type == OpType::Activation ||
-type == OpType::Transpose || type == OpType::Relu ||
-type == OpType::Tanh;
+return type == OpType::MemBound || type == OpType::Reshape ||
+type == OpType::Activation || type == OpType::Transpose ||
+type == OpType::Relu || type == OpType::Tanh ||
+type == OpType::Softmax;
}

void OperatorObj::removePredecessors(const Operator &op) {
@@ -1,6 +1,7 @@
#include "core/search_engine.h"
#include "core/hash.h"
#include "core/runtime.h"
#include "ffi/ffi_callback.h"
#include "nnet/dbg.h"

#include <algorithm>
@@ -70,7 +71,6 @@ Graph SearchEngine::run(const Graph graph) {
}
}
auto tmp = make_ref<GraphObj>(runtimeExec, ops);
-tmp->dataMalloc();
nextGraphs.emplace_back(tmp);
}
}
@@ -376,9 +376,6 @@ std::vector<Graph> SearchEngine::searchMutation(const MetaGraph &metaGraph) {
nextGraphs.emplace_back(make_ref<GraphObj>(runtimeExec, ops));
}
}
-for (auto g : nextGraphs) {
-g->dataMalloc();
-}
dbg("===Num" + std::to_string(nextGraphs.size()));
std::sort(nextGraphs.begin(), nextGraphs.end(), graphTimeComparer);
if (nextGraphs.size() > GRAPH_SIZE) {
@@ -441,7 +438,6 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
std::cout << op->toString() << std::endl;
}
auto tmp = make_ref<GraphObj>(runtimeExec, headOps);
-tmp->dataMalloc();
partitions.emplace_back(tmp);
headOps.clear();
}
@@ -449,7 +445,6 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
}
if (!headOps.empty()) {
auto tmp = make_ref<GraphObj>(runtimeExec, headOps);
-tmp->dataMalloc();
partitions.emplace_back(tmp);
}
std::reverse(partitions.begin(), partitions.end());
@@ -457,6 +452,9 @@ std::vector<Graph> SearchEngine::partitionGraph(const Graph graph) {
}

double SearchEngine::getEstimatedGraphPerf(Graph graph) {
// dbg(graph);
// // hkz
// callback::exportONNX(graph, "a.onnx");
return runtimeExec->getPerfTime(graph, false, true, true);
}
@@ -502,9 +500,15 @@ Graph SearchEngine::fuseVertically(const Graph &graph) {
}
make_ref<GraphObj>(runtimeExec, chainOps)->print();

-Graph optGraph =
-mutator->fuseVertically(make_ref<GraphObj>(runtimeExec, chainOps));
-for (auto op : optGraph->getOperators()) {
+auto bestGraph = make_ref<GraphObj>(runtimeExec, chainOps);
+// Eliminate transpose and reshape operators
+// if (auto eliminatedGraph = mutator->eliminateVertically(
+// make_ref<GraphObj>(runtimeExec, chainOps)))
+// bestGraph = eliminatedGraph;
+// Fuse membound operators
+if (auto optGraph = mutator->fuseVertically(bestGraph))
+bestGraph = optGraph;
+for (auto op : bestGraph->getOperators()) {
ops.emplace_back(op);
}
}
@@ -14,7 +14,7 @@ TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime,
: TensorBaseObj(shape_.size(), dtype, runtime), shape(std::move(shape_)),
_size(shape.empty()
? 0
-: std::accumulate(shape.begin(), shape.end(), 1,
+: std::accumulate(shape.begin(), shape.end(), 1lu,
[](auto acc, auto x) { return acc * x; })),
tensorType(tensorType) {}
@@ -124,8 +124,6 @@ bool TensorObj::equalData(const Tensor &rhs, double relativeError) const {

void TensorObj::dataMalloc() {
if (!data) {
-dbg(toString());
-dbg(getBytes());
data = runtime->allocBlob(getBytes());
}
}
@@ -186,4 +184,26 @@ size_t TensorObj::getOffsetByBroadcastOffset(size_t bcOffset,

return getOffsetByPos(pos, shape);
}

Tensor TensorObj::clone() const {
auto obj = make_ref<TensorObj>(*this);
obj->freeData();
obj->targets.clear();
obj->source.reset();
return obj;
}

Tensor TensorObj::clone(Runtime runtime) const {
auto obj = make_ref<TensorObj>(*this);
obj->runtime = runtime;
obj->freeData();
obj->targets.clear();
obj->source.reset();
if (hasData()) {
obj->dataMalloc();
obj->copyData(this);
}
return obj;
}

}; // namespace infini
@@ -17,9 +17,7 @@ CudaRuntimeObj::CudaRuntimeObj()
checkCublasError(cublasCreate(&cublas));
checkCudnnError(cudnnSetStream(cudnn, stream));
checkCublasError(cublasSetStream(cublas, stream));
-// 10GB for Longformer
-// size_t longformerNum = 3lu * (1 << 30);
-workspaceSize = 7ll << 30; // 7 GB
+workspaceSize = 2ll << 30; // 2 GB
workspace = alloc(workspaceSize);
}
@@ -121,7 +119,7 @@ void CudaRuntimeObj::sync() const { checkCudaError(cudaDeviceSynchronize()); }

string CudaRuntimeObj::toString() const { return "CUDA Runtime"; }

-double CudaRuntimeObj::timeWithCudaGraph(Graph graph) {
+double CudaRuntimeObj::timeWithCudaGraph(Graph graph, int rounds) {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto &perfEngine = PerfEngine::getInstance();
// compile-time computable
@@ -141,6 +139,7 @@ double CudaRuntimeObj::timeWithCudaGraph(Graph graph) {
kernel->compute(op, perfData, this);
else
kernel->compute(op, this);
// FIXME: transpose
if (!ctcMap.at(op->getGuid()) && op->getOpType() != OpType::Reshape)
kernels.emplace_back(op, kernel, perfData);
}
@@ -13,7 +13,7 @@ static std::function<void(const Graph &, string)> exportONNXImpl;
void exportONNX(const Graph &graph, const string &path) {
IT_ASSERT(Py_IsInitialized(), "Python interpreter is not running.");
static auto exportONNXImpl =
-py::module_::import("infinitensor.if_onnx").attr("export_onnx");
+py::module_::import("pyinfinitensor.onnx").attr("save_onnx");
exportONNXImpl(graph, path);
}
|
@ -407,6 +407,7 @@ void export_test_model(py::module &m) {
|
|||
m.def("runInfoGAN", &runInfoGAN)
|
||||
.def("getGANGraph", &getGANGraph)
|
||||
.def("getFSRCNNGraph", &getFSRCNNGraph)
|
||||
.def("getLongformer", &getLongformer)
|
||||
.def("getConvtransposedNHWC", &getConvtransposedNHWC)
|
||||
.def("optimizeGraph", &optimizeGraph, "graph"_a, "runtime"_a,
|
||||
"tuning"_a = false, "mode"_a = NMutator::Mode::Normal,
|
||||
|
|
|
@@ -33,7 +33,7 @@ class G2BMMCudnn : public CudaKernelWithoutConfig {
auto record =
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
const auto [warmupRounds, timingRounds] =
-op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
+op->getB() > 100 ? tuple{1, 1} : tuple{1, 2};
double tmp =
timeit([&]() { g2bmmKernel(op, context); },
[&]() { context->sync(); }, warmupRounds, timingRounds);
@@ -34,7 +34,7 @@ class GBMMCudnn : public CudaKernelWithoutConfig {
auto record =
make_ref<PerfRecordObj>(std::numeric_limits<double>::max());
const auto [warmupRounds, timingRounds] =
-op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
+op->getB() > 100 ? tuple{1, 1} : tuple{1, 3};
double tmp =
timeit([&]() { gbmmKernel(op, context); },
[&]() { context->sync(); }, warmupRounds, timingRounds);
@@ -0,0 +1,52 @@
#include "operators/any.h"
#include "cuda/cuda_any.h"
#include "cuda/cuda_conv2dreduce.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"

namespace infini {

class AnyCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<AnyObj>(_op);

auto inputs = op->getInputs();
auto outputs = op->getOutputs();

vector<float *> inputsRawPtr;
for (auto &input : inputs) {
inputsRawPtr.emplace_back(input->getRawDataPtr<float *>());
}
vector<float *> outputsRawPtr;
for (auto &output : outputs) {
outputsRawPtr.emplace_back(output->getRawDataPtr<float *>());
}

any_kernel_mapping(inputsRawPtr, outputsRawPtr, op->getKernelName(),
op->getOpAttrVector());
}
};

void any_kernel_mapping(vector<float *> inputs, vector<float *> outputs,
const string &kernelName, const vector<int> &attr) {
if (kernelName == "conv2dreduce_kernel") {
IT_ASSERT(attr.size() == 15);
IT_ASSERT(inputs.size() == 1 || inputs.size() == 2)
IT_ASSERT(outputs.size() == 1);
conv2dreduce_kernel(inputs[0], inputs.size() > 1 ? inputs[1] : nullptr,
outputs[0], attr[0] != 0, attr[1], attr[2], attr[3],
attr[4], attr[5], attr[6], attr[7], attr[8],
attr[9], attr[10], attr[11], attr[12], attr[13],
attr[14]);
} else {
std::cout << "Unimplemented AnyOp cuda kernel: " << kernelName
<< std::endl;
IT_TODO_HALT();
}
}

REGISTER_KERNEL(Device::CUDA, OpType::Any, DataType::Float32, AnyCuda,
"Any_CUDA_Float32");

} // namespace infini
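Note: the dispatcher above picks the CUDA implementation purely from the string kernel name carried by AnyObj, with the integer attr vector acting as an untyped argument list. A minimal sketch (not part of this commit) of how a further kernel could be hooked into the same switch; my_elementwise_kernel and the meaning of attr[0] are hypothetical placeholders:

// Inside any_kernel_mapping, before the final else branch:
} else if (kernelName == "my_elementwise_kernel") {   // hypothetical kernel name
    IT_ASSERT(inputs.size() == 1 && outputs.size() == 1);
    IT_ASSERT(attr.size() == 1);                       // attr[0]: element count (assumed)
    my_elementwise_kernel(inputs[0], outputs[0], attr[0]); // hypothetical launcher
}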
@@ -24,4 +24,4 @@ class ClipCuda : public CudaKernelWithoutConfig {
REGISTER_KERNEL(Device::CUDA, OpType::Clip, DataType::Float32, ClipCuda,
"Clip_CUDA_Float32");

-}; // namespace infini
+} // namespace infini
@@ -1,9 +1,7 @@
#include "core/common.h"
#include "core/constants.h"
#include "cuda/cuda_common.h"
#include <math.h>

using infini::E_CONSTANT;
constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }
@@ -29,4 +27,4 @@ void clip_kernel(float *input, float *output, int num, float minValue,
maxValue);
}

-}; // namespace infini
+} // namespace infini
@@ -0,0 +1,49 @@
#include "operators/transpose.h"
#include "cuda/cuda_kernel_wihtout_config.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_transpose.h"

namespace infini {

class TransposeCuda : public CudaKernelWithoutConfig {
void compute(const Operator &_op,
const RuntimeObj *_context) const override {
auto op = as<TransposeObj>(_op);

auto input = op->getInputs(0);
auto output = op->getOutput();
void *const inputData = input->getRawDataPtr<void *>();
void *const outputData = output->getRawDataPtr<void *>();
const auto &inputShape = input->getDims();
const auto &outputShape = output->getDims();

const auto &perm = op->getPermute();
int size = input->size();
int nDims = input->getDims().size();

// Compute strides
SmallArray strides, buffer;
IT_ASSERT(nDims <= SMALL_ARRAY_SIZE);
int curStride = 1;
for (int i = nDims - 1; i >= 0; --i) {
buffer.data[i] = curStride;
curStride *= inputShape[i];
}
for (int i = 0; i < nDims; ++i) {
strides.data[i] = buffer.data[perm[i]];
}

SmallArray outputDims;
for (int i = 0; i < nDims; ++i) {
outputDims.data[i] = outputShape[i];
}

transpose_kernel((float *)inputData, (float *)outputData, nDims, size,
strides, outputDims);
}
};

REGISTER_KERNEL(Device::CUDA, OpType::Transpose, DataType::Float32,
TransposeCuda, "Transpose_CUDA_Float32");

} // namespace infini
@@ -0,0 +1,37 @@
#include "core/common.h"
#include "cuda/cuda_common.h"
#include "utils/small_array.h"

constexpr unsigned int num_threads() { return 32 * 4; }
constexpr int thread_work_size() { return 4; }
constexpr int block_work_size() { return thread_work_size() * num_threads(); }

__global__ void _transpose_kernel(float *input, float *output, int nDims,
int size, infini::SmallArray strides,
infini::SmallArray outputShape) {
int outputIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (outputIdx < size) {
int inputIdx = 0;
int v = outputIdx;
for (int i = nDims - 1; i >= 0; --i) {
inputIdx += v % outputShape.data[i] * strides.data[i];
v /= outputShape.data[i];
}
#if __CUDA_ARCH__ >= 350 || defined(USE_ROCM)
output[outputIdx] = __ldg(input + inputIdx);
#else
output[outputIdx] = input[inputIdx];
#endif
}
}

namespace infini {
void transpose_kernel(float *input, float *output, int nDims, int size,
SmallArray strides, SmallArray outputShape) {
int blocksize = block_work_size();
int gridsize = (size + block_work_size() - 1) / block_work_size();
_transpose_kernel<<<gridsize, blocksize>>>(input, output, nDims, size,
strides, outputShape);
}

} // namespace infini
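Note: the device loop converts each flat output index back to an input offset by peeling off output coordinates (least-significant dimension first) and weighting them with the permuted input strides prepared on the host. A CPU reference with the same arithmetic (not part of this commit) is a convenient sanity check for small tensors:

// CPU reference mirroring _transpose_kernel's index arithmetic.
void transpose_reference(const float *input, float *output, int nDims, int size,
                         const infini::SmallArray &strides,
                         const infini::SmallArray &outputShape) {
    for (int outputIdx = 0; outputIdx < size; ++outputIdx) {
        int inputIdx = 0, v = outputIdx;
        for (int i = nDims - 1; i >= 0; --i) {
            inputIdx += v % outputShape.data[i] * strides.data[i]; // coordinate * permuted stride
            v /= outputShape.data[i];
        }
        output[outputIdx] = input[inputIdx];
    }
}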
@@ -25,9 +25,16 @@ class ActivationCudnn : public CudaKernelWithoutConfig {

cudnnTensorDescriptor_t inputDesc, outputDesc;
auto dim = op->getInputs(0)->getDims();
if (dim.size() != 4)
IT_ASSERT_TODO(dim.size() <= 4);
int n, c, h, w;
if (dim.size() == 4) {
n = dim[0], c = dim[1], h = dim[2], w = dim[3];
} else if (dim.size() == 3) {
n = 1, c = dim[0], h = dim[1], w = dim[2];
} else {
dbg(vecToString(dim));
IT_TODO_HALT();
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
}

// get inputs
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
@ -7,7 +7,14 @@
|
|||
#include "cuda/cuda_runtime.h"
|
||||
#include "ffi/ffi_callback.h"
|
||||
#include "nnet/nmutator.h"
|
||||
#include "operators/G2BMM.h"
|
||||
#include "operators/GBMM.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "operators/transpose.h"
|
||||
#include "operators/unary.h"
|
||||
#include "test.h"
|
||||
#include <pybind11/stl.h>
|
||||
|
@@ -130,7 +137,113 @@ Graph getFSRCNNGraph(int batch, Runtime runtime) {
pad, stride, stride, 1, 1)
->getOutput();
}
return g;
}

Graph getLongformer(Runtime runtime, int bs) {
const int seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
const int hidden = featlen, hiddenPerHead = hidden / heads;
assert(hidden % heads == 0);
Graph g = make_ref<GraphObj>(runtime);

auto i0 = g->addTensor({bs, seqlen, featlen}, DataType::Float32,
TensorType::Input);
auto w0 = g->addTensor({featlen, hidden}, DataType::Float32,
TensorType::Initialized);
auto w1 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
auto w2 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
// Feed forward
auto w3 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
auto bias3 =
g->addTensor({512}, DataType::Float32, TensorType::Initialized);
auto w4 =
g->addTensor({512, 512}, DataType::Float32, TensorType::Initialized);
auto bias4 =
g->addTensor({512}, DataType::Float32, TensorType::Initialized);

auto q0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto k0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto v0 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);

auto q1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto k1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto v1 = g->addTensor({bs, seqlen, heads, hiddenPerHead},
DataType::Float32, TensorType::Other);

auto q2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto k2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto v2 = g->addTensor({bs, heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);

auto q3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto k3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);
auto v3 = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);

auto prob = g->addTensor({bs * heads, seqlen, 2 * w + 1}, DataType::Float32,
TensorType::Other);
auto probSoftmax = g->addTensor({bs * heads, seqlen, 2 * w + 1},
DataType::Float32, TensorType::Other);
auto attn = g->addTensor({bs * heads, seqlen, hiddenPerHead},
DataType::Float32, TensorType::Other);

auto t00 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto t01 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto t02 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
// auto t10 = g->addTensor({bs, seqlen, hidden});
auto t11 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto t12 = g->addTensor({bs, seqlen, hidden}, DataType::Float32,
TensorType::Other);
auto output = g->addTensor({bs, seqlen, featlen}, DataType::Float32,
TensorType::Other);

g->addOpWithOutputs<MatmulObj>(i0, w0, q0, false, true);
g->addOpWithOutputs<MatmulObj>(i0, w1, k0, false, true);
g->addOpWithOutputs<MatmulObj>(i0, w2, v0, false, true);
g->addOpWithOutputs<ReshapeObj>(q0, q1);
g->addOpWithOutputs<ReshapeObj>(k0, k1);
g->addOpWithOutputs<ReshapeObj>(v0, v1);
// For example, when perm=(1, 0, 2), given an input tensor of shape (1,
// 2, 3), the output shape will be (2, 1, 3).
g->addOpWithOutputs<TransposeObj>(q1, q2, vector{0, 2, 1, 3});
g->addOpWithOutputs<TransposeObj>(k1, k2, vector{0, 2, 1, 3});
g->addOpWithOutputs<TransposeObj>(v1, v2, vector{0, 2, 1, 3});
g->addOpWithOutputs<ReshapeObj>(q2, q3);
g->addOpWithOutputs<ReshapeObj>(k2, k3);
g->addOpWithOutputs<ReshapeObj>(v2, v3);
// Attention
g->addOpWithOutputs<G2BMMObj>(q3, k3, prob, w, d);
g->addOpWithOutputs<SoftmaxObj>(prob, probSoftmax, 2);
g->addOpWithOutputs<GBMMObj>(probSoftmax, v3, attn, d);
auto attn2 = g->addOp<ReshapeObj>(attn, nullptr,
vector{bs, heads, seqlen, hiddenPerHead})
->getOutput();
auto t000 =
g->addOp<TransposeObj>(attn2, nullptr, vector{0, 2, 1, 3})->getOutput();
g->addOpWithOutputs<ReshapeObj>(t000, t00);

// Feed forward
g->addOpWithOutputs<MatmulObj>(t00, w3, t01, false, true, bias3);
g->addOpWithOutputs<ReluObj>(t01, t02);
g->addOpWithOutputs<MatmulObj>(t02, w4, t11, false, true, bias4);
g->addOpWithOutputs<ReluObj>(t11, t12);
g->addOpWithOutputs<AddObj>(t12, i0, output);
return g;
}
@@ -173,6 +286,7 @@ void printGraph(Graph g) {
}

void initializeGraphTensors(Graph g, double l, double r, bool useInt) {
g->dataMalloc();
auto gen = RandomGenerator(-0.1, 0.1, 0, useInt);
for (auto t : g->getInputs()) {
t->setData(gen);
@@ -201,6 +315,8 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
IT_TODO_HALT();
vector<Graph> bestGraphs;
SearchEngine searchEngine(runtime, mutator);
return searchEngine.run(g);

bestGraphs.emplace_back(searchEngine.run(g));
g->topo_sort();
dbg(g, bestGraphs[0], bestGraphs.size());
@@ -224,6 +340,7 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
// dbg("Baseline graph");
// printGraph(g);
// dbg(runtme->getPerfTime(g, true));
g->dataFree();

for (size_t i = 0; i < bestGraphs.size(); i++) {
auto bestGraphCpu = bestGraphs[i];
@@ -231,32 +348,33 @@ Graph optimizeGraph(Graph g, Runtime _runtime, bool tuning, NMutator::Mode mode,
make_ref<GraphObj>(runtime, bestGraphCpu->getOperators());
bestGraph->topo_sort();

bestGraph->dataMalloc();
// Initialize inputs with random data
for (auto t : bestGraph->getInputs()) {
t->copyData(fuidToInputTensor[t->getFuid()]);
}
// bestGraph->dataMalloc();
// // Initialize inputs with random data
// for (auto t : bestGraph->getInputs()) {
// t->copyData(fuidToInputTensor[t->getFuid()]);
// }

// Initialize outputs with zeros
for (auto t : bestGraph->getOutputs()) {
t->setData(ZeroGenerator());
}
// // Initialize outputs with zeros
// for (auto t : bestGraph->getOutputs()) {
// t->setData(ZeroGenerator());
// }

dbg(bestGraph);
dbg(bestGraph->getOutputs());
// dbg(bestGraph);
// dbg(bestGraph->getOutputs());

if (tuning) {
runtime->run(bestGraph, true); // Tune kernels
runtime->run(bestGraph, false); // Execute transfomraed graph
// if (tuning) {
// runtime->run(bestGraph, true); // Tune kernels
// runtime->run(bestGraph, false); // Execute transfomraed graph

auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
// EXPECT_TRUE(go0->equalData(bgo0, 1e-3));
dbg(go0->equalData(bgo0, 1e-3));
dbg(runtime->getPerfTime(bestGraph, true));
dbg(runtime->timeNonCtcOperators(bestGraph));
// dbg(runtime->timeWithCudaGraph(bestGraph));
}
// // FIXME: g is freed
// auto go0 = gCpu->cloneTensor(g->getOutputs()[0]);
// auto bgo0 = gCpu->cloneTensor(bestGraph->getOutputs()[0]);
// // EXPECT_TRUE(go0->equalData(bgo0, 1e-3));
// dbg(go0->equalData(bgo0, 1e-3));
// dbg(runtime->getPerfTime(bestGraph, true));
// dbg(runtime->timeNonCtcOperators(bestGraph));
// // dbg(runtime->timeWithCudaGraph(bestGraph));
// }

// dbg("Best graph");
// printGraph(bestGraph);
@@ -6,11 +6,15 @@
#include "nnet/Visitor/MatchReshapeVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/derivator.h"
#include "operators/G2BMM.h"
#include "operators/GBMM.h"
#include "operators/conv.h"
#include "operators/conv2dreduce.h"
#include "operators/matmul.h"
#include "operators/membound.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "operators/unary.h"

namespace infini {
@@ -39,7 +43,6 @@ vector<Graph> NMutator::run(const Graph &in_graph) {
inputsNameNToTensorT.clear();
OpVec computeOps = in_graph->getComputeOps();
// assert(computeOps.size() == 1);
-printf("hkz: in run\n");
if (computeOps.size() == 1)
runSingleOp(in_graph, out_graphs);
// FIXME: runMultipleOps results in segfault
@@ -48,6 +51,12 @@ vector<Graph> NMutator::run(const Graph &in_graph) {
return out_graphs;
}

bool NMutator::isMultiBranchMergable(const Graph &in_graph) {
// TODO
// dbg("Skip mergable Multi-Branch", in_graph);
return false;
}

void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
std::vector<Graph> &out_graphs) {
OpVec computeOps = in_graph->getComputeOps();
@@ -79,16 +88,24 @@ void NMutator::runSingleOpToNaiveMembound(Graph in_graph,
void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
OpVec computeOps = in_graph->getComputeOps();
IT_ASSERT(computeOps.size() == 1);
-printf("hkz: try conv transpose\n");
if (Graph g = transformConvtransposed1x1(computeOps[0])) {
-printf("hkz: apply conv transpose\n");
out_graphs.emplace_back(g);
return;
}
-printf("hkz: try conv diated\n");
if (Graph g = transformConv1x1(computeOps[0])) {
out_graphs.emplace_back(g);
return;
}
if (Graph g = transformG2bmm(computeOps[0])) {
out_graphs.emplace_back(g);
return;
}
if (Graph g = transformGbmm(computeOps[0])) {
out_graphs.emplace_back(g);
return;
}
if (infini::Graph g = transformDialtedConv(computeOps[0])) {
out_graphs.emplace_back(g);
-printf("hkz: apply conv diated\n");
return;
}
if (infini::Graph g = transformConvToGEMMReduce(computeOps[0])) {
@@ -96,10 +113,8 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
return;
}

-printf("hkz: try conv transpose gemm\n");
if (infini::Graph g = transformConvTranposeToGEMMReduce(computeOps[0])) {
out_graphs.emplace_back(g);
-printf("hkz: apply conv transpose gemm\n");
return;
}
@@ -114,7 +129,8 @@ void NMutator::runSingleOp(Graph in_graph, std::vector<Graph> &out_graphs) {
// // return;
// // }

-const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC};
+const set<OpType> opSet{OpType::Conv, OpType::ConvTransNHWC, OpType::G2BMM,
+OpType::GBMM};
if (opSet.count(computeOps[0]->getOpType()) == 0)
return;
@@ -249,6 +265,8 @@ void NMutator::runMultipleOps(Graph in_graph, std::vector<Graph> &out_graphs) {

nnet::Expr NMutator::opToExpression(Operator opT) {
auto [expr, mapNameNToTensorT] = extractOp(opT);
IT_ASSERT(expr,
"Cannot convert " + opT->toString() + " to an NNet expression");
for (auto &[name, tensorT] : mapNameNToTensorT) {
IT_ASSERT(inputsNameNToTensorT.count(name) == 0);
inputsNameNToTensorT[name] = tensorT;
@@ -292,26 +310,25 @@ pair<nnet::Expr, NMutator::NameNToTensorT> NMutator::extractOp(Operator opT) {
const auto K = nnet::makeTensor("K", KT->getDims());
return {nnet::ConvTransPattern::getExpr(A, K, n, c, h, w, f, r, s),
{{"A", AT}, {"K", KT}}};
// } else if (auto g2bmmOp = dynamic_cast<G2BMMOp *>(opT)) {
// const auto &AT = g2bmmOp->getInputs()[0];
// const auto &BT = g2bmmOp->getInputs()[1];
// const auto [b, m, k, width, dilation] = g2bmmOp->getArgs();
} else if (auto g2bmmOp = as<G2BMMObj>(opT)) {
const auto &AT = g2bmmOp->getInputs()[0];
const auto &BT = g2bmmOp->getInputs()[1];
const auto [b, m, k, width, dilation] = g2bmmOp->getBMKWD();

// const auto &[expr, inputsN] =
// nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// return expr;
// } else if (auto gbmmlOp = dynamic_cast<GBMMLOp *>(opT)) {
// const auto &AT = gbmmlOp->getInputs()[0];
// const auto &BT = gbmmlOp->getInputs()[1];
// const auto [b, m, w, k, dilation] = gbmmlOp->getArgs();
// const auto &[expr, inputsN] =
// nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
// inputsNameNToTensorT[inputsN.first->getName()] = AT;
// inputsNameNToTensorT[inputsN.second->getName()] = BT;
// dbg(b, m, w, k, dilation, expr);
// return expr;
const auto &[expr, inputsN] =
nnet::Sg2bmmPattern::getExpr(b, m, k, width, dilation);
return {
expr,
{{inputsN.first->getName(), AT}, {inputsN.second->getName(), BT}}};
} else if (auto gbmmlOp = as<GBMMObj>(opT)) {
const auto &AT = gbmmlOp->getInputs()[0];
const auto &BT = gbmmlOp->getInputs()[1];
const auto [b, m, w, k, dilation] = gbmmlOp->getBMWND();
const auto &[expr, inputsN] =
nnet::LongformerGBMMPattern::getExpr(b, m, w, k, dilation);
return {
expr,
{{inputsN.first->getName(), AT}, {inputsN.second->getName(), BT}}};
} else if (auto matmulOp = as<MatmulObj>(opT)) {
const auto &AT = matmulOp->getInputs()[0];
const auto &BT = matmulOp->getInputs()[1];
@@ -333,8 +350,6 @@ pair<nnet::Expr, NMutator::NameNToTensorT> NMutator::extractOp(Operator opT) {
// // else if (auto transposeOp = dynamic_cast<TransposeOp *>(opT)) {
// // return transposeOpToExpression(transposeOp);
// // }
-IT_TODO_HALT_MSG("Cannot convert " + opT->toString() +
-" to an NNet expression");
return {};
}
@@ -390,7 +405,13 @@ infini::Graph NMutator::expressionToGraph(nnet::Expr expr, Graph in_graph) {
const auto &[ph, pw, sh, sw, dh, dw] = op->getArgs();
g->addOpWithOutputs<ConvObj>(A, K, output, ph, pw, sh, sw, dh, dw);
} else if (auto op = nnet::as<nnet::ElementWiseNode>(routineN)) {
assert(op->getInputs().size() == 1);
// dbg(op, op->getExpr());
// TODO: For a single input channel conv, it can be transformed into
// vec X vec ---> matrix --reduce--> result
// This transformation only introduce membound Ops and can have a
// wrong estimated execution time, so we skip it now.
if (op->getInputs().size() != 1)
return nullptr;
nnet::MatchReshapeVisitor matchReshapeVisitor;
// If this routine only change the shape, translate it to a Reshape
if (matchReshapeVisitor(op->getExpr())) {
@@ -465,34 +486,6 @@ double NMutator::memboundTime(const Shape &dims) {
return memboundTime(dims.size());
}

// infini::Graph NMutator::fuseHetConv(nnet::Expr expr, Graph in_graph) {
// // Conv3x3+Conv1x1 => Gemm(nhw, f(rs+1), c) + Reduce
// auto g = std::make_shared<infini::Graph>();
// in_graph->print();
// assert(in_graph->getInputs().size() == 3);
// auto input = in_graph->getOperators()[0]->getInputs(0);
// auto conv = dynamic_cast<ConvOp *>(in_graph->getOperators()[0]);
// auto output = conv->getOutput();
// // auto input = g->reshape(input);
// auto inputTrans = g->transpose(input, 0, {-1, {0, 2, 3}, 1}, -1);
// // dbg(inputTrans->getOutput()->getDims());
// const auto &[n, c, h, w, f, r, s, ph, pw, sh, sw, dh, dw, G, bi, ac] =
// conv->getArgs(0);
// auto weight = g->tensor({1, c, f * (3 * 3 + 1)});
// dbg(weight->getDims());
// auto matmul = g->matmul(inputTrans->getOutput(), weight, false, false);
// auto bias = g->tensor({f});
// const double size = n * f * h * w * (3 * 3 + 1) * 4;
// // FIXME: add NNET tensors for verfication
// auto membound =
// g->membound({matmul->getOutput(), bias}, {output}, {}, nullptr,
// memboundTime(size), "Reduce_conv3x3+1x1");
// dbg(n, f, h, w);
// dynamic_cast<MemBoundOp *>(membound)->setNFHW(n, f, h, w);

// return new Graph(g->getOperators());
// }

Graph NMutator::transformDialtedConv(Operator _op) {
auto op = as<ConvObj>(_op);
if (!op)
@@ -695,61 +688,72 @@ Graph NMutator::transformConvTranposeToGEMMReduce(Operator _op) {
// return graph;
// }

// Graph NMutator::transformTConv1x1(Operator op) {
// if (auto tconvOp = dynamic_cast<ConvTransOp *>(op)) {
// if (tconvOp->getPh() == 0 && tconvOp->getSh() == 1) {
// auto g = new infini::Graph();
// auto inputDims = tconvOp->getInputs(0)->getDims();
// auto weightDims = tconvOp->getInputs(1)->getDims();
// auto outputDims = tconvOp->getOutput()->getDims();
// auto newA = g->tensor(
// {inputDims[0] * inputDims[1] * inputDims[2], inputDims[3]});
// auto newW = g->tensor(
// {weightDims[0] * weightDims[1] * weightDims[3],
// weightDims[2]});
// auto newO =
// g->tensor({inputDims[0] * inputDims[1] * inputDims[2],
// weightDims[0] * weightDims[1] * weightDims[3]});
// g->reshape(tconvOp->getInputs(0), newA);
// g->reshape(tconvOp->getInputs(1), newW);
// g->matmul(newA, newW, newO, 0, 1);
// g->reshape(newO, tconvOp->getOutput());
// g->updateConnection();
// Graph graph = new Graph(g->getOperators());
// return graph;
// }
// }
// return nullptr;
// }
Graph NMutator::transformG2bmm(Operator _op) {
auto op = as<G2BMMObj>(_op);
if (!op)
return nullptr;
const auto [b, m, k, width, dilation] = op->getBMKWD();
if (dilation == 1 || m % dilation != 0)
return nullptr;
auto g = make_ref<GraphObj>(runtime);
auto A = g->cloneTensor(op->getInputs(0));
auto B = g->cloneTensor(op->getInputs(1));
auto O = g->cloneTensor(op->getOutput());
auto A3 = splitTransposeMerge(g, A, 1, dilation),
B3 = splitTransposeMerge(g, B, 1, dilation);
auto O3 = g->addOp<G2BMMObj>(A3, B3, nullptr, width, 1)->getOutput();
splitTransposeMerge(g, O3, 1, m / dilation, O);
g->checkValid();
return g;
}

// Graph NMutator::transformConv1x1(Operator op) {
// auto convOp = dynamic_cast<ConvOp *>(op);
// if (!convOp)
// return nullptr;
// if (convOp->getPh() == 0 && convOp->getSh() == 1 &&
// convOp->getInputs()[1]->getDims()[2] == 1 &&
// convOp->getInputs()[1]->getDims()[3] == 1) {
// // Transpose is requrired for BS>1
// // if (convOp->getInputs()[0]->getDims()[0] == 1) {
// auto g = new infini::Graph();
// auto inputDims = convOp->getInputs(0)->getDims();
// auto weightDims = convOp->getInputs(1)->getDims();
// auto outputDims = convOp->getOutput()->getDims();
// auto newA = g->tensor(
// {inputDims[1], inputDims[0] * inputDims[2] * inputDims[3]});
// auto newW = g->tensor({weightDims[0], weightDims[1]});
// auto newO = g->tensor(
// {weightDims[0], inputDims[0] * inputDims[2] * inputDims[3]});
// g->reshape(convOp->getInputs(0), newA);
// g->reshape(convOp->getInputs(1), newW);
// g->matmul(newW, newA, newO, 0, 0);
// g->reshape(newO, convOp->getOutput());
// g->updateConnection();
// Graph graph = new Graph(g->getOperators());
// return graph;
// }
// return nullptr;
// }
Graph NMutator::transformGbmm(Operator _op) {
auto op = as<GBMMObj>(_op);
if (!op)
return nullptr;
const auto [b, m, width, k, dilation] = op->getBMWND();
if (dilation == 1 || m % dilation != 0)
return nullptr;
auto g = make_ref<GraphObj>(runtime);
auto A = g->cloneTensor(op->getInputs(0)); // [b,m,2w+1]
auto B = g->cloneTensor(op->getInputs(1)); // [b,m,n]
auto O = g->cloneTensor(op->getOutput()); // [b,m,n]
auto A3 = splitTransposeMerge(g, A, 1, dilation),
B3 = splitTransposeMerge(g, B, 1, dilation);
auto O3 = g->addOp<GBMMObj>(A3, B3, nullptr, 1)->getOutput();
splitTransposeMerge(g, O3, 1, m / dilation, O);
g->checkValid();
return g;
}

Graph NMutator::transformConv1x1(Operator _op) {
auto op = as<ConvObj>(_op);
if (!op)
return nullptr;
Shape shapeA = op->getInputs(0)->getDims();
Shape shapeW = op->getInputs(1)->getDims();
// TODO: support batch size > 1
if (shapeA[0] != 1)
return nullptr;
if (op->getPh() == 0 && op->getSh() == 1 && shapeW[2] == 1 &&
shapeW[3] == 1) {
auto g = make_ref<GraphObj>(runtime);
auto A =
g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(0)), nullptr,
vector{shapeA[1], shapeA[0] * shapeA[2] *
shapeA[3]}) // [C, N*H*W]
->getOutput();
auto B = g->addOp<ReshapeObj>(g->cloneTensor(op->getInputs(1)), nullptr,
vector{shapeW[0], shapeW[1]}) // [F, C]
->getOutput();
auto O =
g->addOp<MatmulObj>(B, A, nullptr, 0, 0)->getOutput(); // [F, N*H*W]
g->addOpWithOutputs<ReshapeObj>(O, g->cloneTensor(op->getOutput()),
op->getOutput()->getDims());
return g;
}
return nullptr;
}

// Graph NMutator::transformConv1xk(Operator op) {
// auto convOp = dynamic_cast<ConvOp *>(op);
@@ -788,6 +792,142 @@ Graph NMutator::transformConvTranposeToGEMMReduce(Operator _op) {
// return graph;
// }

Graph NMutator::constructGraphByOperatorChain(vector<Operator> ops,
Graph inputGraph) {
// Construct new graph
auto g = make_ref<GraphObj>(runtime);
IT_ASSERT(inputGraph->getInputs().size() == 1);
IT_ASSERT(inputGraph->getOutputs().size() == 1);
IT_ASSERT(ops.size() > 0,
"TODO: If there is no op left, how to return an empty graph?");
auto input = g->cloneTensor(inputGraph->getInputs()[0]);
for (size_t i = 0; i < ops.size(); ++i) {
auto output = (i + 1 == ops.size())
? inputGraph->getOutputs()[0]
: g->addTensor(ops[i]->getOutput()->getDims());
dbg(input->getDims(), output->getDims());
input = g->cloneOperator(ops[i], {input}, {output})->getOutput();
}
return g;
}

Graph NMutator::eliminateVertically(const Graph &inputGraph) {
auto ops = inputGraph->getOperators();

IT_ASSERT(!ops.empty());
for (auto &op : ops) {
IT_ASSERT(op->isMemBoundOp());
IT_ASSERT_TODO(op->getInputs().size() == 1);
IT_ASSERT(op->getOutputs().size() == 1);
}
if (ops.size() == 1) {
return make_ref<GraphObj>(runtime, ops);
}

// Set attributs for operators.
// isComputation: is computaiton
// isElementwise: do elementwise computations
// lastRowSwapable: do last-channel-wise computations, which includes
// elementwise as a special case.
auto classifyOperator = [](Operator op) {
auto type = op->getOpType();
bool isComputation =
type != OpType::Reshape && type != OpType::Transpose;
bool isElementwise =
!isComputation || (type == OpType::Relu || type == OpType::Tanh);
bool lastRowSwapable = false;
if (isComputation)
lastRowSwapable = isElementwise || // Softmax along the last dim
(type == OpType::Softmax &&
as<SoftmaxObj>(op)->getAxis() ==
int(op->getOutput()->getDims().size()) - 1);
else {
if (auto t = as<TransposeObj>(op)) {
// Last dim remains unchanged
lastRowSwapable =
(t->getPermute().back() == int(t->getPermute().size()) - 1);
} else if (auto t = as<ReshapeObj>(op)) {
// Last dim remains unchanged
lastRowSwapable = (t->getInputs(0)->getDims().back() ==
t->getOutput()->getDims().back());
}
}
return tuple{isComputation, isElementwise, lastRowSwapable};
};

// Reorder operators: move computatation operators to the tail
for (int i = ops.size() - 2; i >= 0; --i) {
for (int j = i; j < int(ops.size()) - 1; ++j) {
bool swapable = false;
const set<OpType> unaryElementwise{OpType::Relu, OpType::PRelu,
OpType::Tanh};
auto [aIsC, aEw, aLRS] = classifyOperator(ops[j]);
auto [bIsC, bEw, bLRS] = classifyOperator(ops[j + 1]);
if (aIsC && !bIsC && (aEw || (aLRS && bLRS))) // Swap condition
swapable = true;
if (swapable) {
swap(ops[j], ops[j + 1]);
}
}
}

Graph g = constructGraphByOperatorChain(ops, inputGraph);
// Eliminate operators
bool haveElimination;
do {
haveElimination = false;
ops = g->getOperators();
vector<Operator> newOps;
for (int i = 0; i < int(ops.size()); ++i) {
// Eliminate identity operators
if (auto op = as<TransposeObj>(ops[i])) {
auto perm = op->getPermute();
int j = 0;
for (j = 0; j < int(perm.size()); ++j)
if (j != perm[j])
break;
if (j == int(perm.size())) {
haveElimination = true;
continue;
}
} else if (auto op = as<ReshapeObj>(ops[i])) {
if (op->getShape() == op->getInputs(0)->getDims()) {
haveElimination = true;
continue;
}
}

// Eliminate reciprocal operators
if (i + 1 == (int)ops.size() ||
(ops[i]->getOpType() != ops[i + 1]->getOpType())) {
newOps.push_back(ops[i]);
continue;
}
if (ops[i]->getOpType() == OpType::Reshape) {
newOps.push_back(make_ref<ReshapeObj>(
nullptr, ops[i]->getInputs(0), ops[i + 1]->getOutput()));
++i;
haveElimination = true;
} else if (ops[i]->getOpType() == OpType::Transpose) {
auto permuteA = as<TransposeObj>(ops[i])->getPermute();
auto permuteB = as<TransposeObj>(ops[i + 1])->getPermute();
vector<int> permute;
for (auto p : permuteB)
permute.push_back(permuteA[p]);
newOps.push_back(
make_ref<TransposeObj>(nullptr, ops[i]->getInputs(0),
ops[i + 1]->getOutput(), permute));
++i;
haveElimination = true;
} else {
newOps.push_back(ops[i]);
}
}
g = constructGraphByOperatorChain(newOps, inputGraph);
} while (haveElimination);
return g;
}
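Note: the pair-elimination branch for Transpose relies on the fact that applying permutation permA and then permB is equivalent to a single transpose whose k-th output dimension reads input dimension permuteA[permuteB[k]], which is exactly what the loop above builds. A small standalone check (not part of this commit) of that composition rule:

#include <cassert>
#include <vector>
// Compose two transpose permutations, matching the loop in eliminateVertically.
std::vector<int> composePermute(const std::vector<int> &permA,
                                const std::vector<int> &permB) {
    std::vector<int> permute;
    for (int p : permB)
        permute.push_back(permA[p]);
    return permute;
}

// Example: {0, 2, 1, 3} followed by {0, 2, 1, 3} composes to the identity,
// so the fused transpose can be dropped by the identity-elimination branch.
void composePermuteExample() {
    auto c = composePermute({0, 2, 1, 3}, {0, 2, 1, 3});
    assert((c == std::vector<int>{0, 1, 2, 3}));
}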

Graph NMutator::fuseVertically(const Graph &inputGraph) {
Graph optGraph = make_ref<GraphObj>(runtime);
@@ -804,6 +944,8 @@ Graph NMutator::fuseVertically(const Graph &inputGraph) {
std::vector<nnet::Expr> exprs;
for (const auto &op : chainOps) {
auto [expr, _] = extractOp(op);
if (!expr)
return nullptr;
exprs.emplace_back(expr);
// dbg(op, infini::as<nnet::RangeOpNode>(expr)->getFullExpression());
}
@ -870,4 +1012,28 @@ pair<nnet::Expr, vector<nnet::Tensor>> NMutator::generateRevert(Tensor in) {
|
|||
return {range, {tensor}};
|
||||
}

Tensor NMutator::splitTransposeMerge(Graph g, Tensor A, int dim, int chunkSize,
                                     Tensor output) {
    IT_ASSERT(A->getDims().size() == 3);
    Shape shapeOrignial = A->getDims();
    Shape shapeNew;
    // Construct new shape
    for (int i = 0; i < dim; ++i)
        shapeNew.emplace_back(shapeOrignial[i]);
    shapeNew.emplace_back(shapeOrignial[dim] / chunkSize);
    shapeNew.emplace_back(chunkSize);
    for (size_t i = dim + 1; i < shapeOrignial.size(); ++i)
        shapeNew.emplace_back(shapeOrignial[i]);
    auto A1 = g->addOp<ReshapeObj>(A, nullptr, shapeNew)->getOutput();
    auto A2 =
        g->addOp<TransposeObj>(A1, nullptr, vector{0, 2, 1, 3})->getOutput();
    Tensor A3;
    if (output)
        A3 = g->addOpWithOutputs<ReshapeObj>(A2, output, shapeOrignial)
                 ->getOutput();
    else
        A3 = g->addOp<ReshapeObj>(A2, nullptr, shapeOrignial)->getOutput();
    return A3;
};

} // namespace infini
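A rough shape walkthrough of the Reshape-Transpose-Reshape sandwich that splitTransposeMerge builds; the dims, dim index, and chunk size below are made up for the example (illustration only, not part of this commit):

#include <iostream>
#include <vector>

// Shape walkthrough only; the real routine adds Reshape/Transpose operators.
int main() {
    std::vector<int> original{8, 20, 30}; // a 3-D tensor
    int dim = 1, chunkSize = 5;

    // Step 1: split dimension `dim` into (original[dim] / chunkSize, chunkSize).
    std::vector<int> split;
    for (int i = 0; i < dim; ++i)
        split.push_back(original[i]);
    split.push_back(original[dim] / chunkSize);
    split.push_back(chunkSize);
    for (size_t i = dim + 1; i < original.size(); ++i)
        split.push_back(original[i]);
    // split == {8, 4, 5, 30}

    // Step 2: transpose with permutation {0, 2, 1, 3}.
    std::vector<int> perm{0, 2, 1, 3}, transposed(4);
    for (int i = 0; i < 4; ++i)
        transposed[i] = split[perm[i]];

    // Step 3 would reshape back to the original {8, 20, 30}.
    for (int d : transposed)
        std::cout << d << ' ';
    std::cout << '\n'; // prints: 8 5 4 30
    return 0;
}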

@@ -0,0 +1,74 @@
#include "operators/any.h"

namespace infini {

AnyObj::AnyObj(GraphObj *graph, const TensorVec &inputs,
               const TensorVec &outputs, string &kernelName,
               const vector<int> &attr)
    : OperatorObj(OpType::Any, inputs, outputs), kernelName(kernelName),
      attr(attr) {
    IT_ASSERT(checkValid(graph));
    // Outputs must be assigned when constructing an AnyObj
    IT_ASSERT(!outputs.empty());
    for (auto &output : outputs)
        IT_ASSERT(output != nullptr && output->size() > 0);
}

string AnyObj::toString() const {
    std::ostringstream os;
    os << "Any[" << getGuid() << "](";
    for (size_t i = 0; i < inputs.size(); ++i) {
        os << "i" << i << "=" << inputs[i]->getGuid();
        if (i != inputs.size() - 1)
            os << " ";
    }
    os << ", ";
    for (size_t i = 0; i < outputs.size(); ++i) {
        os << "o" << i << "=" << outputs[i]->getGuid();
        if (i != outputs.size() - 1)
            os << " ";
    }
    os << ", ";
    os << "kernel name: " << kernelName << ", ";
    os << "attr = [";
    for (size_t i = 0; i < attr.size(); ++i) {
        os << attr[i];
        if (i != attr.size() - 1)
            os << ", ";
    }
    os << "])\n";
    return os.str();
}

optional<vector<Shape>> AnyObj::inferShape(const TensorVec &inputs) const {
    vector<Shape> ret;
    for (auto output : outputs) {
        ret.emplace_back(output->getDims());
    }
    return ret;
}

const string AnyObj::getKernelName() const { return kernelName; }

vector<int> AnyObj::getOpAttrVector() const { return attr; };

vector<int> AnyObj::getWorkloadVector() const {
    vector<int> ret = {};
    for (auto &input : inputs) {
        auto inputDims = input->getDims();
        ret.insert(ret.end(), inputDims.begin(), inputDims.end());
    }
    for (auto &output : outputs) {
        auto outputDims = output->getDims();
        ret.insert(ret.end(), outputDims.begin(), outputDims.end());
    }
    for (auto c : kernelName) {
        ret.emplace_back(c);
    }
    for (auto at : attr) {
        ret.emplace_back(at);
    }
    return ret;
}

} // namespace infini
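For reference, getWorkloadVector() flattens the input dims, output dims, kernel-name characters, and attributes into a single integer key. A standalone sketch of that layout with made-up shapes and attributes (illustration only, not part of this commit):

#include <iostream>
#include <string>
#include <vector>

// Standalone sketch, not InfiniTensor code: mirrors how
// AnyObj::getWorkloadVector() concatenates input dims, output dims,
// the kernel-name characters, and the attribute list into one key.
int main() {
    std::vector<std::vector<int>> inputDims{{1, 2, 3}, {2, 2, 3}};
    std::vector<std::vector<int>> outputDims{{3, 2, 3}};
    std::string kernelName = "fake_kernel_name"; // toy name
    std::vector<int> attr{2, 3, 2, 1, 4, 4};     // toy attributes

    std::vector<int> key;
    for (auto &d : inputDims)
        key.insert(key.end(), d.begin(), d.end());
    for (auto &d : outputDims)
        key.insert(key.end(), d.begin(), d.end());
    for (char c : kernelName)
        key.push_back(c); // characters are stored as their integer codes
    key.insert(key.end(), attr.begin(), attr.end());

    for (int v : key)
        std::cout << v << ' ';
    std::cout << '\n';
    return 0;
}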

@@ -2,7 +2,8 @@

namespace infini {
ReshapeObj::ReshapeObj(GraphObj *graph, Tensor input, Tensor output, Shape dims)
    : OperatorObj(OpType::Reshape, {input}, {output}), dims(std::move(dims)) {
    : OperatorObj(OpType::Reshape, {input}, {output}),
      dims(dims.size() == 0 ? output->getDims() : dims) {
    IT_ASSERT(checkValid(graph));
}
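With this change an empty dims argument falls back to the output tensor's shape. A minimal sketch of just that fallback expression, using plain vectors in place of the operator classes (illustration only, not part of this commit):

#include <cassert>
#include <vector>

using Shape = std::vector<int>;

// Mirrors the new initializer: an empty `dims` falls back to the
// output tensor's dims (here just a plain Shape standing in for it).
Shape resolveDims(const Shape &dims, const Shape &outputDims) {
    return dims.size() == 0 ? outputDims : dims;
}

int main() {
    Shape outputDims{2, 3, 4};
    assert((resolveDims({}, outputDims) == Shape{2, 3, 4}));  // fallback
    assert((resolveDims({6, 4}, outputDims) == Shape{6, 4})); // explicit dims win
    return 0;
}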

@@ -19,9 +20,9 @@ optional<vector<Shape>> ReshapeObj::inferShape(const TensorVec &inputs) const {
std::string ReshapeObj::toString() const {
    std::ostringstream os;
    os << "Reshape[" << getGuid() << "]";
    os << "(";
    os << "(input dim=";
    os << vecToString(inputs[0]->getDims()) << ",";
    os << "dims=" << vecToString(dims) << ",";
    os << "output dims=" << vecToString(dims) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    return os.str();
@@ -4,13 +4,7 @@ namespace infini {
TransposeObj::TransposeObj(GraphObj *graph, Tensor input, Tensor output,
                           vector<int> permute)
    : OperatorObj(OpType::Transpose, {input}, {output}) {
    if (permute.size() != 4) {
        IT_TODO_HALT();
    }
    transposePermute[0] = permute[0];
    transposePermute[1] = permute[1];
    transposePermute[2] = permute[2];
    transposePermute[3] = permute[3];
    transposePermute = permute;
    IT_ASSERT(checkValid(graph));
}

@@ -20,7 +14,8 @@ TransposeObj::inferShape(const TensorVec &inputs) const {
    auto input = A->getDims();
    auto output = input;

    for (int i = 0; i < 4; ++i) {
    auto nDims = input.size();
    for (size_t i = 0; i < nDims; ++i) {
        output[i] = input[transposePermute[i]];
    }
    return {{output}};

@@ -32,7 +27,8 @@ std::string TransposeObj::toString() const {
    os << "(";
    os << vecToString(inputs[0]->getDims()) << ",";
    os << "input=" << inputs[0]->getGuid() << ",";
    os << "output=" << outputs[0]->getGuid() << ")";
    os << "output=" << outputs[0]->getGuid() << ",";
    os << "perm=" << vecToString(transposePermute) << ")";
    return os.str();
}
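The inferShape change above drops the hard-coded rank of 4: output dim i is simply input[transposePermute[i]] for whatever rank the tensor has. A small standalone sketch of that rule (illustration only, not part of this commit):

#include <cassert>
#include <vector>

// output[i] = input[permute[i]], for any rank (no longer fixed to 4).
std::vector<int> permuteDims(const std::vector<int> &input,
                             const std::vector<int> &permute) {
    std::vector<int> output(input.size());
    for (size_t i = 0; i < input.size(); ++i)
        output[i] = input[permute[i]];
    return output;
}

int main() {
    assert((permuteDims({1, 2, 3, 4}, {0, 2, 1, 3}) ==
            std::vector<int>{1, 3, 2, 4}));
    assert((permuteDims({5, 6, 7}, {2, 0, 1}) == std::vector<int>{7, 5, 6}));
    return 0;
}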

@@ -48,7 +48,7 @@ TEST(SubGraphRewriter, subGraphMatch1) {
    SubGraphRewriter v(g);
    vector<MatchGraph> subgs = v.findMatch(subG);

    EXPECT_TRUE(subgs.size() == 2);
    EXPECT_TRUE(subgs.size() == 2u);
}

TEST(MatchGraph, single_input) {

@@ -116,12 +116,12 @@ TEST(MatchGraph, single_input) {

    auto o4 = v.addSubGraph(subG, TensorVec{add1->getOutput(0)});

    EXPECT_EQ(g->getOperators().size(), 52);
    EXPECT_EQ(g->getOperators().size(), 52u);
    vector<MatchGraph> subgs = v.findMatch(subG);
    EXPECT_TRUE(subgs.size() == 5);
    EXPECT_TRUE(subgs.size() == 5u);

    vector<MatchGraph> subgs1 = v.findMatch(subG1);
    EXPECT_TRUE(subgs1.size() == 4);
    EXPECT_TRUE(subgs1.size() == 4u);

    // test replace
    Tensor sii0 =

@@ -135,7 +135,7 @@ TEST(MatchGraph, single_input) {
    }

    v.replaceSubGraph(subG, subG2);
    EXPECT_EQ(g->getOperators().size(), 37);
    EXPECT_EQ(g->getOperators().size(), 37u);
}

TEST(MatchGraph, multi_input) {

@@ -186,17 +186,17 @@ TEST(MatchGraph, multi_input) {
                                       nullptr);

        auto matches = v.findMatch(subG);
        EXPECT_EQ(2, matches.size());
        EXPECT_EQ(2u, matches.size());

        auto div0 = g->addOp<DivObj>(reduce1->getOutput(0), i2, nullptr);
        auto add1 =
            g->addOp<AddObj>(sub0->getOutput(), div0->getOutput(), nullptr);
        matches = v.findMatch(subG);
        EXPECT_EQ(1, matches.size());
        EXPECT_EQ(1u, matches.size());

        // The two matched subgraphs overlap, so only one subgraph is replaced
        v.replaceSubGraph(subG, replaceG);
        EXPECT_EQ(1, v.findMatch(replaceG).size());
        EXPECT_EQ(1u, v.findMatch(replaceG).size());
    }
}

@@ -240,7 +240,7 @@ TEST(MatchGraph, multi_output) {
    {
        auto input = g->cloneTensor(i);
        auto outs = v.addSubGraph(subg0, {input});
        EXPECT_EQ(2, outs.size());
        EXPECT_EQ(2u, outs.size());
        Tensor w0 = g->addTensor(Shape{96, 64, 3, 3}, DataType::UInt32);
        auto conv0 = g->addOp<ConvObj>(outs[0], w0, nullptr, 1, 1);
        auto relu0 = g->addOp<ReluObj>(conv0->getOutput(0), nullptr);

@@ -263,11 +263,11 @@ TEST(MatchGraph, multi_output) {
    }

    auto matches = v.findMatch(subg0);
    EXPECT_EQ(1, matches.size());
    EXPECT_EQ(1u, matches.size());

    v.replaceSubGraph(subg0, subg1);
    auto matches2 = v.findMatch(subg1);
    EXPECT_EQ(1, matches2.size());
    EXPECT_EQ(1u, matches2.size());
}

// gcn

@@ -354,16 +354,16 @@ TEST(MatchGraph, multi_input_output) {
            v.addSubGraph(subg0, {relu->getOutput(0), maxPool->getOutput(0)});
        auto out1 =
            v.addSubGraph(subg1, {maxPool->getOutput(0), relu->getOutput(0)});
        EXPECT_EQ(2, out0.size());
        EXPECT_EQ(2, out1.size());
        EXPECT_EQ(2u, out0.size());
        EXPECT_EQ(2u, out1.size());
        auto div = g->addOp<DivObj>(out0[0], out1[1], nullptr);
        auto sub = g->addOp<SubObj>(out0[1], out1[0], nullptr);
    }

    EXPECT_EQ(2, v.findMatch(subg0).size());
    EXPECT_EQ(2, v.findMatch(subg1).size());
    EXPECT_EQ(2u, v.findMatch(subg0).size());
    EXPECT_EQ(2u, v.findMatch(subg1).size());
    v.replaceSubGraph(subg0, subg2);
    EXPECT_EQ(v.findMatch(subg2).size(), 2);
    EXPECT_EQ(v.findMatch(subg2).size(), 2u);
}

/* One Node having two or more successors is not supported yet.
@@ -0,0 +1,57 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "operators/any.h"

#include "test.h"

namespace infini {
TEST(cuda_Any, anyKernel) {
    // conv2dreduce
    {
        // Construct Runtime and graph for CPU and CUDA
        Runtime cpu =
            NativeCpuRuntimeObj::getInstance(); // CPU runtime is a singleton
        Graph gCpu = make_ref<GraphObj>(cpu);
        Runtime cuda = make_ref<CudaRuntimeObj>();
        Graph gCuda = make_ref<GraphObj>(cuda);

        auto generator = IncrementalGenerator();

        int PRelu = 0, n = 1, h = 4, w = 4, f = 2, r = 3, s = 3, oh = 4, ow = 4,
            ph = 1, pw = 1, sh = 1, sw = 1, dh = 1, dw = 1;
        string kernelName = "conv2dreduce_kernel";
        vector<int> attr{PRelu, n, h, w, f, r, s, oh,
                         ow, ph, pw, sh, sw, dh, dw};

        // Build input data on CPU
        Tensor i0Cpu = gCpu->addTensor({n, 1, h, w}, DataType::Float32);
        Tensor w0Cpu = gCpu->addTensor({f, 1, r, s}, DataType::Float32);
        // Malloc data for all tensors in a graph. Do we need implicit
        // allocation?
        gCpu->dataMalloc();
        i0Cpu->setData(generator);
        w0Cpu->setData(generator);
        // Copy input tensors from CPU to CUDA
        Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
        Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
        Tensor o0Cuda = gCuda->addTensor({n, f, oh, ow});
        auto anyOp = gCuda->addOpWithOutputs<AnyObj>(
            TensorVec{i0Cuda, w0Cuda}, TensorVec{o0Cuda}, kernelName, attr);
        anyOp->print();
        // Allocate CUDA memory
        gCuda->dataMalloc();
        std::cout << "data malloc success..." << std::endl;
        // Execute on CUDA
        cuda->run(gCuda);
        std::cout << "cuda run success..." << std::endl;
        // Copy output from CUDA to CPU
        auto o0Cpu = gCpu->cloneTensor(anyOp->getOutput());
        // Check results on CPU
        EXPECT_TRUE(1);
        // Print a tensor/operator/graph by print()
        gCuda->print();
    }
}
} // namespace infini
@@ -0,0 +1,45 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "operators/transpose.h"

#include "test.h"

namespace infini {

template <class T>
void testTranspose(
    const std::function<void(void *, size_t, DataType)> &generator,
    const Shape &shape) {
    // Runtime
    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU
    Tensor inputCpu = make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
    inputCpu->dataMalloc();
    inputCpu->setData(generator);

    // GPU
    Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
    auto inputGpu = cudaGraph->cloneTensor(inputCpu);
    vector<int> permute = {0, 2, 1, 3};
    auto gpuOp = cudaGraph->addOp<T>(inputGpu, nullptr, permute);
    cudaGraph->dataMalloc();
    cudaRuntime->run(cudaGraph);
    auto outputGpu = gpuOp->getOutput();
    auto oCpu = outputGpu->clone(cpuRuntime);
    // Check
    // inputCpu->printData();
    // oCpu->printData();
    EXPECT_TRUE(oCpu->equalData(vector<float>{0, 1, 2, 3, 12, 13, 14, 15,
                                              4, 5, 6, 7, 16, 17, 18, 19,
                                              8, 9, 10, 11, 20, 21, 22, 23}));
}

TEST(cuda_Transpose, run) {
    testTranspose<TransposeObj>(IncrementalGenerator(), Shape{1, 2, 3, 4});
}

} // namespace infini
@@ -26,10 +26,11 @@ def load_onnx(runtime, filename: str) -> ft.Graph:


def run_and_evaluate(runtime, g):
    ft.initializeGraphTensors(g)
    runtime.run(g, True)
    print(f'getPerfTime = {runtime.getPerfTime(g, True, False, False)}')
    print(f'Non-ctc time = {runtime.timeNonCtcOperators(g, 1000, 1000)}')
    print(f'Cuda graph time = {runtime.timeWithCudaGraph(g)}')
    print(f'Cuda graph time = {runtime.timeWithCudaGraph(g, 100)}')


def run_graph_get_output_as_torch_tensor(runtime, g):

@@ -85,10 +86,17 @@ def evluate_GANs():
    run_and_evaluate(runtime, g)


def construct_convTranspose2d(runtime):
# def construct_convTranspose2d(runtime):
#     handler = ft.GraphHandler(runtime)
#     input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
#     w = handler.tensor([56, 1, 9, 9], tensor_type=ft.TensorType.Initialized)
#     handler.convTransposed2d(input, w, None, 3, 3, 4, 4, 1, 1, 1, 1)
#     return handler.getGraph()


def construct_convTranspose2d(runtime, n, c, h, w, f, r, s, pad, stride, dilation):
    handler = ft.GraphHandler(runtime)
    input = handler.tensor([1, 56, 32, 32], tensor_type=ft.TensorType.Input)
    w = handler.tensor([56, 1, 9, 9], tensor_type=ft.TensorType.Initialized)
    input = handler.tensor([n, f, h, w], tensor_type=ft.TensorType.Input)
    w = handler.tensor([f, c, r, s], tensor_type=ft.TensorType.Initialized)
    handler.convTransposed2d(input, w, None, 3, 3, 4, 4, 1, 1, 1, 1)
    return handler.getGraph()


@@ -121,14 +129,32 @@ def construct_convtranposed_nhwc(runtime, n, c, h, w, f, r, s, pad, stride, dila
    return handler.getGraph()


def export_op_level_onnx(runtime):
    graphs = [
        (construct_conv(runtime, 1, 512, 7, 7, 512, 3, 3,
                        1, 1, 1), "orig_Conv3x3"),  # ResNet18 Conv_37
        # 16, 256, 2, 2, 448, 4, 4, 1, 2, 1 # CelebA_ConvTranspose_0
        # TODO
        (construct_convTranspose2d(), "orig_ConvTranspose"),
        (construct_conv(runtime, 16, 32, 224, 224, 1, 5,
                        5, 2, 1, 1, 1), "orig_Conv5x5"),  # SRCNN_Conv_4
        (construct_convTranspose2d(), "orig_G2BMM"),
    ]
    for g, name in graphs:
        save_onnx(g, f"opt_{name}.onnx")


if __name__ == "__main__":
    runtime = ft.cuda_runtime()
    graphs = [
        # (construct_conv(runtime, 16, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1'), # FSRCNN Conv_2 1x1
        # (construct_conv(runtime, 1, 12, 32, 32, 12, 3, 3, 1, 1, 1), 'conv3x3'), # FSRCNN Conv_4 3x3
        # ft.getGANGraph(batch, runtime, 5, 1)
        # (ft.getLongformer(runtime, 1), 'longformer.bs1'),
        # (ft.getLongformer(runtime, 16), 'longformer.bs16'),
        # construct_convTranspose2d(runtime)
        # (load_onnx(runtime, '/mnt/auxHome/models/einnet/fsrcnn.bs1.onnx'), 'fsrcnn.bs1'),
        (ft.getFSRCNNGraph(1, runtime), "fsrcnn.bs1"),
        (ft.getFSRCNNGraph(16, runtime), "fsrcnn.bs16")
        # (construct_conv_nhwc(runtime, 1, 56, 32, 32, 12, 1, 1, 0, 1, 1), 'conv1x1')
    ]

@@ -138,9 +164,12 @@ if __name__ == "__main__":
    if True:  # Optimization
        save_onnx(original_g, f"orig_{name}.onnx")
        g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
                             [3, 2, 2, 5, 8, 8, 6, 90])
                             [1, 7, 7, 2, 8, 6, 6])  # G2BMM/GBMM
        # g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.RuleBased,
        #                      [3, 2, 2, 5, 8, 8, 6, 90]) # Conv2conv
        # g = ft.optimizeGraph(original_g, runtime, False, ft.NMutatorMode.Normal)

        save_onnx(g, f"optimized_{name}.onnx")
        verify_graphs(runtime, original_g, g)
        save_onnx(g, f"opt_{name}.onnx")
        # verify_graphs(runtime, original_g, g)
        # run_and_evaluate(runtime, original_g)
        run_and_evaluate(runtime, g)
@@ -7,6 +7,9 @@
#include "nnet/nmutator.h"
#include "nnet/test.h"
#include "operators/conv.h"
#include "operators/reshape.h"
#include "operators/softmax.h"
#include "operators/transpose.h"
#include "test.h"

namespace infini {

@@ -479,4 +482,73 @@ TEST(NMutator, InfoGAN_TConv_3_correctness) {
//     bestGraph->print();
//     EXPECT_TRUE(graph->verification(bestGraph.get(), true));
// }

TEST(NMutator, eliminateVertically_RTSTR) {
    Runtime runtime = make_ref<CudaRuntimeObj>();
    Graph g = make_ref<GraphObj>(runtime);
    const int a = 8, b = 4, c = 5, d = 30;
    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
    auto input = t0;
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
    auto mutator = make_ref<NMutator>();
    auto optG = mutator->eliminateVertically(g);
    dbg(optG);
    ASSERT_EQ(optG->getOperators().size(), 1u);
    auto op = optG->getOperators()[0];
    EXPECT_EQ(op->getOpType(), OpType::Softmax);
    EXPECT_EQ(op->getInputs(0)->getFuid(), input->getFuid());
    EXPECT_EQ(op->getOutput()->getFuid(), t0->getFuid());
}

TEST(NMutator, eliminateVertically_RTST) {
    Runtime runtime = make_ref<CudaRuntimeObj>();
    Graph g = make_ref<GraphObj>(runtime);
    const int a = 8, b = 4, c = 5, d = 30;
    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{0, 2, 1, 3})->getOutput();
    auto mutator = make_ref<NMutator>();
    auto optG = mutator->eliminateVertically(g);
    dbg(optG);
    ASSERT_EQ(optG->getOperators().size(), 2u);
}

TEST(NMutator, eliminateVertically_RTSTR_3d) {
    Runtime runtime = make_ref<CudaRuntimeObj>();
    Graph g = make_ref<GraphObj>(runtime);
    const int a = 8, b = 4, c = 5, d = 30;
    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 3)->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
    auto mutator = make_ref<NMutator>();
    auto optG = mutator->eliminateVertically(g);
    dbg(optG);
    EXPECT_EQ(optG->getOperators().size(), 1u);
}

TEST(NMutator, eliminateVertically_RTSTR_softmax_non_last_dim) {
    Runtime runtime = make_ref<CudaRuntimeObj>();
    Graph g = make_ref<GraphObj>(runtime);
    const int a = 8, b = 4, c = 5, d = 30;
    auto t0 = g->addTensor({a, b * c, d}, DataType::Float32, TensorType::Input);
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b, c, d})->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{1, 2, 0, 3})->getOutput();
    t0 = g->addOp<SoftmaxObj>(t0, nullptr, 2)->getOutput();
    t0 = g->addOp<TransposeObj>(t0, nullptr, Shape{2, 0, 1, 3})->getOutput();
    t0 = g->addOp<ReshapeObj>(t0, nullptr, Shape{a, b * c, d})->getOutput();
    auto mutator = make_ref<NMutator>();
    auto optG = mutator->eliminateVertically(g);
    dbg(optG);
    EXPECT_EQ(optG->getOperators().size(), 5u);
}

} // namespace infini

@@ -0,0 +1,48 @@
#include "core/graph.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "operators/any.h"
#include "test.h"
using namespace infini;
using namespace std;

TEST(Any, ShapeInference) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    vector<int> attr;
    string kernelName = "fake_kernel_name";
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32);
        Tensor i1 = g->addTensor({2, 2, 3}, DataType::Float32);
        Tensor o0 = g->addTensor({3, 2, 3}, DataType::Float32);
        auto anyOp = g->addOpWithOutputs<AnyObj>(
            TensorVec{i0, i1}, TensorVec{o0}, kernelName, attr);
        EXPECT_TRUE(anyOp->getOutputs().size() == 1);
        EXPECT_EQ(anyOp->getOutput()->getDims(), (Shape{3, 2, 3}));
    }
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32);
        Tensor i1 = g->addTensor({2, 2, 3}, DataType::Float32);
        Tensor o0 = g->addTensor({2, 2, 3}, DataType::Float32);
        Tensor o1 = g->addTensor({1, 2, 3}, DataType::Float32);
        auto anyOp = g->addOpWithOutputs<AnyObj>(
            TensorVec{i0, i1}, TensorVec{o0, o1}, kernelName, attr);
        EXPECT_TRUE(anyOp->getOutputs().size() == 2);
        EXPECT_EQ(anyOp->getOutput(0)->getDims(), (Shape{2, 2, 3}));
        EXPECT_EQ(anyOp->getOutput(1)->getDims(), (Shape{1, 2, 3}));
    }
}

TEST(Any, Attr) {
    Runtime runtime = NativeCpuRuntimeObj::getInstance();
    string kernelName = "fake_kernel_name";
    vector<int> attr = {2, 3, 2, 1, 4, 4};
    Graph g = make_ref<GraphObj>(runtime);
    Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32);
    Tensor i1 = g->addTensor({2, 2, 3}, DataType::Float32);
    Tensor o0 = g->addTensor({3, 2, 3}, DataType::Float32);
    auto anyOp = g->addOpWithOutputs<AnyObj>(TensorVec{i0, i1}, TensorVec{o0},
                                             kernelName, attr);
    EXPECT_EQ(anyOp->getOpAttrVector(), attr);
}