forked from jiuyuan/InfiniTensor
Add: expr simplication for TVM
This commit is contained in:
parent
0d8685b1e0
commit
e8b4e3f03f
|
@ -19,6 +19,9 @@ class GraphObj : public Object {
|
|||
Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
|
||||
Tensor addTensor(const Tensor &tensor);
|
||||
TensorVec addTensor(const TensorVec &tensors);
|
||||
/**
|
||||
* @brief Clone a tensor and add it to the graph.
|
||||
*/
|
||||
Tensor cloneTensor(const Tensor &tensor) {
|
||||
return addTensor(tensor->clone(runtime));
|
||||
}
|
||||
|
@ -79,6 +82,8 @@ class GraphObj : public Object {
|
|||
return ret;
|
||||
}
|
||||
|
||||
bool selfCheck(bool assert = false) const;
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Add reverse connections and Op relationship in ctor.
|
||||
|
|
|
@ -20,7 +20,11 @@ class MergeMemboundMutator : public Mutator {
|
|||
*/
|
||||
MergeMemboundMutator(const VecExpr &kernels)
|
||||
: Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {}
|
||||
Expr merge(bool allowEmptyMembound = false);
|
||||
/// @brief Merged multiple expressions into one with one or several stages.
|
||||
/// @param allowEmptyMembound
|
||||
/// @param allowFailure If true, return nullptr when merging fails. If false, assert will fail.
|
||||
/// @return
|
||||
Expr merge(bool allowEmptyMembound = false, bool allowFailure = false);
|
||||
};
|
||||
|
||||
} // namespace nnet
|
|
@ -10,6 +10,7 @@ class MemBoundObj : public OperatorObj {
|
|||
nnet::Expr expr;
|
||||
double exec_time;
|
||||
std::string hint;
|
||||
HashType hash;
|
||||
int n, f, h, w;
|
||||
|
||||
public:
|
||||
|
|
|
@ -5,19 +5,26 @@ import tvm
|
|||
from tvm import te, tir, auto_scheduler, topi
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
USE_CACHE = True
|
||||
logger = logging.getLogger('InfiniTensor')
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, func_name, input_names, output_name, hash_code=None):
|
||||
def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f,
|
||||
func_name, input_names, output_name, nnet_expression: str,
|
||||
nnet_simplified_expression: str, hash_code=None):
|
||||
assert len(input_tensors) == len(input_dtypes)
|
||||
assert len(input_tensors) == len(input_names)
|
||||
|
||||
|
||||
logging.debug(f'Work on hash {hash_code}')
|
||||
|
||||
dir_name = os.path.join(".cache", "generated_kernels", str(hash_code))
|
||||
func_code_fn = os.path.join(dir_name, "kernel.cu")
|
||||
invoke_code_fn = os.path.join(dir_name, "invoke.cpp")
|
||||
config_fn = os.path.join(dir_name, "config.json")
|
||||
|
||||
|
||||
if USE_CACHE and hash_code is not None:
|
||||
if os.path.exists(dir_name):
|
||||
print(f"Use cache in {dir_name}")
|
||||
|
@ -29,7 +36,8 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
|
|||
config = json.loads(config_fin.read().strip())
|
||||
conv_time = config["conv_time"]
|
||||
invoke_params = config["invoke_params"]
|
||||
|
||||
|
||||
logger.debug(f'Find tuning log for {hash_code}')
|
||||
return func_code, invoke_code, conv_time, invoke_params
|
||||
|
||||
print("Generating Ansor op: ")
|
||||
|
@ -65,7 +73,7 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
|
|||
|
||||
# Kill the measurement process
|
||||
del measure_ctx
|
||||
|
||||
|
||||
def test_mutator():
|
||||
# test part
|
||||
tgt_temp = tvm.target.Target(target="llvm", host="llvm")
|
||||
|
@ -74,13 +82,16 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
|
|||
args = all_tensors
|
||||
C0, K0, A0 = args
|
||||
func_temp = tvm.build(sch, args, tgt_temp, name="temp")
|
||||
|
||||
|
||||
# print result
|
||||
n, c, h, w, f, r, s = 1, 1, 2, 2, 1, 4, 4
|
||||
dev_temp = tvm.device(tgt_temp.kind.name, 0)
|
||||
A_temp = tvm.nd.array(np.arange(n*h*w*f).reshape(n,h,w,f).astype(A0.dtype), dev_temp)
|
||||
K_temp = tvm.nd.array(np.arange(f*r*s*c).reshape(f,r,s,c).astype(K0.dtype), dev_temp)
|
||||
C_temp = tvm.nd.array(np.zeros((1,4,4,1)).astype(C0.dtype), dev_temp)
|
||||
A_temp = tvm.nd.array(
|
||||
np.arange(n*h*w*f).reshape(n, h, w, f).astype(A0.dtype), dev_temp)
|
||||
K_temp = tvm.nd.array(
|
||||
np.arange(f*r*s*c).reshape(f, r, s, c).astype(K0.dtype), dev_temp)
|
||||
C_temp = tvm.nd.array(
|
||||
np.zeros((1, 4, 4, 1)).astype(C0.dtype), dev_temp)
|
||||
func_temp(C_temp, K_temp, A_temp)
|
||||
print("================= Test Result =====================")
|
||||
print(C_temp)
|
||||
|
@ -125,13 +136,16 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
|
|||
|
||||
print("Func Code")
|
||||
# Attach TVM code behind func_code
|
||||
func_code += "\n/* " + f + "*/"
|
||||
func_code += "\n/* NNET tensor expression \n" + nnet_expression + "\n*/\n"
|
||||
func_code += "\n/* NNET simplified tensor expression \n" + \
|
||||
nnet_simplified_expression + "\n*/\n"
|
||||
func_code += "\n/* TVM compute\n" + f + "\n*/\n"
|
||||
print(func_code)
|
||||
print("Invoke Code")
|
||||
print(invoke_code)
|
||||
print("Time")
|
||||
print(conv_time)
|
||||
|
||||
|
||||
if hash_code is not None:
|
||||
if not os.path.exists(dir_name):
|
||||
os.makedirs(dir_name)
|
||||
|
|
|
@ -147,4 +147,19 @@ OpVec GraphObj::getComputeOps() const {
|
|||
return opList;
|
||||
};
|
||||
|
||||
bool GraphObj::selfCheck(bool assert) const {
|
||||
std::set<UidBaseType> s;
|
||||
// check whether two tensors with the same FUID exist
|
||||
for (auto tensor : tensors) {
|
||||
int cnt = s.count(tensor->getFuid());
|
||||
|
||||
if (assert)
|
||||
IT_ASSERT(cnt == 0, std::to_string(tensor->getFuid()));
|
||||
else if (cnt > 0)
|
||||
return false;
|
||||
s.insert(tensor->getFuid());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace infini
|
||||
|
|
|
@ -82,6 +82,11 @@ void TensorObj::printDataFloat(float *ptr) const {
|
|||
for (int i = numDims - 1; i != 0; --i)
|
||||
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
|
||||
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
|
||||
if (iEnd > 1000 && i > 20 && i < iEnd - 20) {
|
||||
printf("... , ");
|
||||
i = iEnd - 20;
|
||||
continue;
|
||||
}
|
||||
for (size_t j = 0; j < numDims; ++j) {
|
||||
if (i % dimSzVec[j] == 0) {
|
||||
std::cout << "[";
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
#include "nnet/Visitor/AsTVMVisitor.h"
|
||||
#include "nnet/Visitor/CheckOOBVisitor.h"
|
||||
#include "nnet/Visitor/HashVisitor.h"
|
||||
#include "nnet/Visitor/MergeMemboundMutator.h"
|
||||
#include "nnet/dbg.h"
|
||||
#include "nvrtc.h"
|
||||
#include "operators/membound.h"
|
||||
#include "operators/pooling.h"
|
||||
|
@ -19,6 +21,7 @@ class TVMRecordObj : public PerfRecordObj {
|
|||
std::string log, ptx;
|
||||
std::vector<int> invokeParams;
|
||||
std::string kernelName;
|
||||
HashType simplifiedExprHash;
|
||||
};
|
||||
|
||||
using TVMRecord = Ref<TVMRecordObj>;
|
||||
|
@ -80,14 +83,16 @@ class MemboundTVM : public Kernel {
|
|||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
||||
// invoke Ansor to tune a membound kernel
|
||||
std::string func = "mem_bound_" + std::to_string(op->getGuid());
|
||||
std::string kernelName = func + "_kernel0";
|
||||
nnet::AsTVMVisitor visitor;
|
||||
IT_ASSERT(!checkOOB(op->getNnetExpr()));
|
||||
|
||||
// fuse stages in nnet expr to reduce kernels generated by TVM
|
||||
auto expr = op->getNnetExpr();
|
||||
if (auto mergedExpr =
|
||||
nnet::MergeMemboundMutator({expr}).merge(false, true))
|
||||
expr = mergedExpr;
|
||||
|
||||
nnet::HashVisitor hashVisitor;
|
||||
auto hashCode = hashVisitor.dispatch(expr);
|
||||
HashType hashCode = hashVisitor.getHash(expr);
|
||||
|
||||
visitor.dispatch(expr);
|
||||
auto &&stmts = visitor.getStmts();
|
||||
|
@ -98,10 +103,14 @@ class MemboundTVM : public Kernel {
|
|||
for (auto &&in : op->getInputs()) {
|
||||
inputs.emplace_back(getVarName(in));
|
||||
}
|
||||
std::string output = getVarName(op->getOutput());
|
||||
const std::string output = getVarName(op->getOutput());
|
||||
|
||||
const std::string func = "membound_" + std::to_string(hashCode);
|
||||
const std::string kernelName = func + "_kernel0";
|
||||
auto res = getAnsorCode(
|
||||
inShapes, std::vector<std::string>(inShapes.size(), "float32"),
|
||||
outShape, "float32", stmts, func, inputs, output, hashCode);
|
||||
outShape, "float32", stmts, func, inputs, output, op->toString(),
|
||||
expr->toReadable(), hashCode);
|
||||
|
||||
// compile the kernel
|
||||
auto funcCode = res.first;
|
||||
|
@ -132,6 +141,7 @@ class MemboundTVM : public Kernel {
|
|||
nvrtcGetPTX(prog, ret->ptx.data());
|
||||
ret->invokeParams = invokeParams;
|
||||
ret->kernelName = kernelName;
|
||||
ret->simplifiedExprHash = hashCode;
|
||||
|
||||
// prepare for evaluation
|
||||
CUmodule module;
|
||||
|
@ -157,6 +167,13 @@ class MemboundTVM : public Kernel {
|
|||
},
|
||||
[&]() { context->sync(); });
|
||||
|
||||
// dbg(kernelName, "Print input");
|
||||
// op->getInputs()[0]->print();
|
||||
// op->getInputs()[0]->printData();
|
||||
// dbg("Print output");
|
||||
// op->getOutput()->print();
|
||||
// op->getOutput()->printData();
|
||||
|
||||
// free module
|
||||
checkCUresult(cuModuleUnload(module));
|
||||
nvrtcDestroyProgram(&prog);
|
||||
|
@ -164,13 +181,31 @@ class MemboundTVM : public Kernel {
|
|||
return std::dynamic_pointer_cast<PerfRecordObj>(ret);
|
||||
}
|
||||
|
||||
/// @brief
|
||||
/// @param inDims
|
||||
/// @param inDTypes
|
||||
/// @param outDims
|
||||
/// @param outDType
|
||||
/// @param lambda
|
||||
/// @param funcName Generated function name
|
||||
/// @param inputNames Input array names in the generated invocation code.
|
||||
/// @param outputName Output array names in the generated invocation code.
|
||||
/// @param nnetExpressionString Save expr in string for logging.
|
||||
/// @param nnetSimplifiedExprString Save simplified expr in string for
|
||||
/// logging.
|
||||
/// @param hashCode (optional) Hash code of the input expression for kernel
|
||||
/// cache.
|
||||
/// @return
|
||||
std::pair<std::string, std::vector<int>>
|
||||
getAnsorCode(const std::vector<std::vector<int>> &inDims,
|
||||
const std::vector<std::string> &inDTypes,
|
||||
const std::vector<int> &outDims, const std::string &outDType,
|
||||
const std::string &lambda, const std::string &funcName,
|
||||
const std::vector<std::string> &inputNames,
|
||||
const std::string &outputName, const HashType hashCode) const {
|
||||
const std::string &outputName,
|
||||
const std::string &nnetExprString,
|
||||
const std::string &nnetSimplifiedExprString,
|
||||
const HashType hashCode) const {
|
||||
std::string funcCode;
|
||||
std::vector<int> invokeParams;
|
||||
try {
|
||||
|
@ -179,8 +214,10 @@ class MemboundTVM : public Kernel {
|
|||
// in cuBLAS failure, whose root cause is not identified yet.
|
||||
static auto func =
|
||||
py::module::import("cpp_plugin").attr("gen_ansor_op");
|
||||
py::tuple code = func(inDims, inDTypes, outDims, outDType, lambda,
|
||||
funcName, inputNames, outputName, hashCode);
|
||||
py::tuple code =
|
||||
func(inDims, inDTypes, outDims, outDType, lambda, funcName,
|
||||
inputNames, outputName, nnetExprString,
|
||||
nnetSimplifiedExprString, std::to_string(hashCode));
|
||||
funcCode = py::str(code[0]);
|
||||
auto temp = py::list(code[3]);
|
||||
for (int i = 0; i < 6; ++i) {
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
|
||||
namespace nnet {
|
||||
|
||||
Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
|
||||
Expr MergeMemboundMutator::merge(bool allowEmptyMembound, bool allowFailure) {
|
||||
// FIXME: fix empty expression in membound
|
||||
assert(kernels.size() >= 1);
|
||||
if (checkEmpty()) {
|
||||
|
@ -27,19 +27,25 @@ Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
|
|||
assert(CheckOOBVisitor().checkRangeOp(curRangeOp) == false);
|
||||
auto summand = curRangeOp->getSummand();
|
||||
if (auto subscriptOp = as<SubscriptNode>(summand)) {
|
||||
// Try merging the current and next stages
|
||||
if (auto mergedExpr = rule4StageMerging(*curExpr, true)) {
|
||||
// dbg(*curExpr, mergedExpr);
|
||||
*curExpr = mergedExpr;
|
||||
merged = true;
|
||||
break;
|
||||
}
|
||||
// If merging fails, try the next stage
|
||||
curExpr = subscriptOp->getObjectPtr();
|
||||
nnet_assert(*curExpr != nullptr, __LINE__);
|
||||
} else if (auto funcOp = as<FuncNode>(summand)) {
|
||||
// Relu({...}[i,j])
|
||||
curExpr = funcOp->getObject()->getObjectPtr();
|
||||
} else
|
||||
nnet_unimplemented_halt();
|
||||
} else {
|
||||
if (allowFailure)
|
||||
return nullptr;
|
||||
else
|
||||
nnet_unimplemented_halt();
|
||||
}
|
||||
}
|
||||
} while (merged);
|
||||
return expr;
|
||||
|
|
|
@ -10,6 +10,7 @@ MemBoundObj::MemBoundObj(GraphObj *graph, const TensorVec &input,
|
|||
: OperatorObj(OpType::MemBound, input, output), nnetInputs(nnetInputs),
|
||||
expr(expr), exec_time(exec_time), hint(hint) {
|
||||
IT_ASSERT(checkValid(graph));
|
||||
hash = getHash();
|
||||
}
|
||||
|
||||
string MemBoundObj::toString() const {
|
||||
|
@ -31,7 +32,8 @@ string MemBoundObj::toString() const {
|
|||
os << "NNet Inputs=[";
|
||||
for (const auto &tensor : nnetInputs)
|
||||
os << tensor->toReadable() << ",";
|
||||
os << "])";
|
||||
os << "]";
|
||||
os << ", ExprHash=" << hash << ")";
|
||||
os << "\n" << (expr ? expr->toReadable() : "Empty expression") << "\n";
|
||||
return os.str();
|
||||
}
|
||||
|
@ -47,7 +49,7 @@ optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) const {
|
|||
}
|
||||
|
||||
vector<int> MemBoundObj::getWorkloadVector() const {
|
||||
return {enum_to_underlying(type), (int)getHash()};
|
||||
return {enum_to_underlying(type), (int)hash};
|
||||
}
|
||||
|
||||
vector<int> MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); }
|
||||
|
|
|
@ -4,19 +4,16 @@
|
|||
#include "nnet/Visitor/HashVisitor.h"
|
||||
#include "nnet/Visitor/MergeMemboundMutator.h"
|
||||
#include "nnet/expr.h"
|
||||
#include "nnet/test.h"
|
||||
#include "gtest/gtest.h"
|
||||
using namespace nnet;
|
||||
using namespace std;
|
||||
#define DEFINE_VAR(name) auto name = make_ref<VarNode>(#name);
|
||||
|
||||
TEST(FuseMembound, Relu) {
|
||||
const int n_heads = 8, seq_len = 10000, feat_len = 512;
|
||||
// dilation_heads = 2;
|
||||
const int Batch = n_heads, M = seq_len, K = feat_len, W = 32;
|
||||
DEFINE_VAR(b);
|
||||
DEFINE_VAR(m);
|
||||
DEFINE_VAR(w);
|
||||
DEFINE_VAR(k);
|
||||
DEFINE_VAR(b, m, w, k);
|
||||
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
|
||||
vector<int>{0, 0, 0});
|
||||
|
||||
|
@ -35,10 +32,7 @@ TEST(FuseMembound, MemMemFusion) {
|
|||
const int n_heads = 8, seq_len = 100, feat_len = 100;
|
||||
// dilation_heads = 2;
|
||||
const int Batch = n_heads, M = seq_len, K = feat_len;
|
||||
DEFINE_VAR(b);
|
||||
DEFINE_VAR(m);
|
||||
DEFINE_VAR(w);
|
||||
DEFINE_VAR(k);
|
||||
DEFINE_VAR(b, m, w, k);
|
||||
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
|
||||
vector<int>{0, 0, 0});
|
||||
auto B = make_ref<TensorNode>("B", vector<int>({Batch, K, M}),
|
||||
|
@ -54,4 +48,27 @@ TEST(FuseMembound, MemMemFusion) {
|
|||
RangeOp ans = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}},
|
||||
{{k, {0, K}}}, makeSubscript(A, {b, m, k}));
|
||||
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
|
||||
}
|
||||
|
||||
TEST(FuseMembound, mergeNestedStagesInRangeOp) {
|
||||
// Case in ConvTranspose to Matmul
|
||||
// L<f:0:448><i39:0:4096>Sum ... [i39,f]
|
||||
// {L<i39:0:4096><f:0:448>Sum ... [f,(i39 / 1024),((i39 / 256) % 4),(i39
|
||||
// % 256)] {K}}
|
||||
DEFINE_VAR(f, i);
|
||||
const int I = 4096, F = 448;
|
||||
auto K = make_ref<TensorNode>("K", vector<int>({448, 4, 4, 256}));
|
||||
|
||||
auto subA = makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256});
|
||||
auto range = makeRangeOperator({{i, {0, I}}, {f, {0, F}}}, {}, subA);
|
||||
auto outerRange = makeRangeOperator({{f, {0, F}}, {i, {0, I}}}, {},
|
||||
makeSubscript(range, {i, f}));
|
||||
auto merged = MergeMemboundMutator({outerRange}).merge();
|
||||
|
||||
// Compare the result with answer
|
||||
RangeOp ans = makeRangeOperator(
|
||||
{{f, {0, F}}, {i, {0, I}}}, {},
|
||||
makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256}));
|
||||
dbg(merged, ans);
|
||||
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
|
||||
}
|
Loading…
Reference in New Issue