Add: expr simplication for TVM

This commit is contained in:
Liyan Zheng 2023-04-07 13:48:29 +08:00
parent 0d8685b1e0
commit e8b4e3f03f
10 changed files with 141 additions and 35 deletions

View File

@ -19,6 +19,9 @@ class GraphObj : public Object {
Tensor addTensor(Shape dim, DataType dtype = DataType::Float32);
Tensor addTensor(const Tensor &tensor);
TensorVec addTensor(const TensorVec &tensors);
/**
* @brief Clone a tensor and add it to the graph.
*/
Tensor cloneTensor(const Tensor &tensor) {
return addTensor(tensor->clone(runtime));
}
@ -79,6 +82,8 @@ class GraphObj : public Object {
return ret;
}
bool selfCheck(bool assert = false) const;
private:
/**
* @brief Add reverse connections and Op relationship in ctor.

View File

@ -20,7 +20,11 @@ class MergeMemboundMutator : public Mutator {
*/
MergeMemboundMutator(const VecExpr &kernels)
: Mutator(), kernels(kernels), curDepth(kernels.size() - 1) {}
Expr merge(bool allowEmptyMembound = false);
/// @brief Merged multiple expressions into one with one or several stages.
/// @param allowEmptyMembound
/// @param allowFailure If true, return nullptr when merging fails. If false, assert will fail.
/// @return
Expr merge(bool allowEmptyMembound = false, bool allowFailure = false);
};
} // namespace nnet

View File

@ -10,6 +10,7 @@ class MemBoundObj : public OperatorObj {
nnet::Expr expr;
double exec_time;
std::string hint;
HashType hash;
int n, f, h, w;
public:

View File

@ -5,19 +5,26 @@ import tvm
from tvm import te, tir, auto_scheduler, topi
import os
import json
import logging
USE_CACHE = True
logger = logging.getLogger('InfiniTensor')
logger.setLevel(logging.DEBUG)
def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, func_name, input_names, output_name, hash_code=None):
def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f,
func_name, input_names, output_name, nnet_expression: str,
nnet_simplified_expression: str, hash_code=None):
assert len(input_tensors) == len(input_dtypes)
assert len(input_tensors) == len(input_names)
logging.debug(f'Work on hash {hash_code}')
dir_name = os.path.join(".cache", "generated_kernels", str(hash_code))
func_code_fn = os.path.join(dir_name, "kernel.cu")
invoke_code_fn = os.path.join(dir_name, "invoke.cpp")
config_fn = os.path.join(dir_name, "config.json")
if USE_CACHE and hash_code is not None:
if os.path.exists(dir_name):
print(f"Use cache in {dir_name}")
@ -29,7 +36,8 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
config = json.loads(config_fin.read().strip())
conv_time = config["conv_time"]
invoke_params = config["invoke_params"]
logger.debug(f'Find tuning log for {hash_code}')
return func_code, invoke_code, conv_time, invoke_params
print("Generating Ansor op: ")
@ -65,7 +73,7 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
# Kill the measurement process
del measure_ctx
def test_mutator():
# test part
tgt_temp = tvm.target.Target(target="llvm", host="llvm")
@ -74,13 +82,16 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
args = all_tensors
C0, K0, A0 = args
func_temp = tvm.build(sch, args, tgt_temp, name="temp")
# print result
n, c, h, w, f, r, s = 1, 1, 2, 2, 1, 4, 4
dev_temp = tvm.device(tgt_temp.kind.name, 0)
A_temp = tvm.nd.array(np.arange(n*h*w*f).reshape(n,h,w,f).astype(A0.dtype), dev_temp)
K_temp = tvm.nd.array(np.arange(f*r*s*c).reshape(f,r,s,c).astype(K0.dtype), dev_temp)
C_temp = tvm.nd.array(np.zeros((1,4,4,1)).astype(C0.dtype), dev_temp)
A_temp = tvm.nd.array(
np.arange(n*h*w*f).reshape(n, h, w, f).astype(A0.dtype), dev_temp)
K_temp = tvm.nd.array(
np.arange(f*r*s*c).reshape(f, r, s, c).astype(K0.dtype), dev_temp)
C_temp = tvm.nd.array(
np.zeros((1, 4, 4, 1)).astype(C0.dtype), dev_temp)
func_temp(C_temp, K_temp, A_temp)
print("================= Test Result =====================")
print(C_temp)
@ -125,13 +136,16 @@ def gen_ansor_op(input_tensors, input_dtypes, output_tensor, output_dtype, f, fu
print("Func Code")
# Attach TVM code behind func_code
func_code += "\n/* " + f + "*/"
func_code += "\n/* NNET tensor expression \n" + nnet_expression + "\n*/\n"
func_code += "\n/* NNET simplified tensor expression \n" + \
nnet_simplified_expression + "\n*/\n"
func_code += "\n/* TVM compute\n" + f + "\n*/\n"
print(func_code)
print("Invoke Code")
print(invoke_code)
print("Time")
print(conv_time)
if hash_code is not None:
if not os.path.exists(dir_name):
os.makedirs(dir_name)

View File

@ -147,4 +147,19 @@ OpVec GraphObj::getComputeOps() const {
return opList;
};
bool GraphObj::selfCheck(bool assert) const {
std::set<UidBaseType> s;
// check whether two tensors with the same FUID exist
for (auto tensor : tensors) {
int cnt = s.count(tensor->getFuid());
if (assert)
IT_ASSERT(cnt == 0, std::to_string(tensor->getFuid()));
else if (cnt > 0)
return false;
s.insert(tensor->getFuid());
}
return true;
}
} // namespace infini

View File

@ -82,6 +82,11 @@ void TensorObj::printDataFloat(float *ptr) const {
for (int i = numDims - 1; i != 0; --i)
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
if (iEnd > 1000 && i > 20 && i < iEnd - 20) {
printf("... , ");
i = iEnd - 20;
continue;
}
for (size_t j = 0; j < numDims; ++j) {
if (i % dimSzVec[j] == 0) {
std::cout << "[";

View File

@ -4,6 +4,8 @@
#include "nnet/Visitor/AsTVMVisitor.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/dbg.h"
#include "nvrtc.h"
#include "operators/membound.h"
#include "operators/pooling.h"
@ -19,6 +21,7 @@ class TVMRecordObj : public PerfRecordObj {
std::string log, ptx;
std::vector<int> invokeParams;
std::string kernelName;
HashType simplifiedExprHash;
};
using TVMRecord = Ref<TVMRecordObj>;
@ -80,14 +83,16 @@ class MemboundTVM : public Kernel {
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
// invoke Ansor to tune a membound kernel
std::string func = "mem_bound_" + std::to_string(op->getGuid());
std::string kernelName = func + "_kernel0";
nnet::AsTVMVisitor visitor;
IT_ASSERT(!checkOOB(op->getNnetExpr()));
// fuse stages in nnet expr to reduce kernels generated by TVM
auto expr = op->getNnetExpr();
if (auto mergedExpr =
nnet::MergeMemboundMutator({expr}).merge(false, true))
expr = mergedExpr;
nnet::HashVisitor hashVisitor;
auto hashCode = hashVisitor.dispatch(expr);
HashType hashCode = hashVisitor.getHash(expr);
visitor.dispatch(expr);
auto &&stmts = visitor.getStmts();
@ -98,10 +103,14 @@ class MemboundTVM : public Kernel {
for (auto &&in : op->getInputs()) {
inputs.emplace_back(getVarName(in));
}
std::string output = getVarName(op->getOutput());
const std::string output = getVarName(op->getOutput());
const std::string func = "membound_" + std::to_string(hashCode);
const std::string kernelName = func + "_kernel0";
auto res = getAnsorCode(
inShapes, std::vector<std::string>(inShapes.size(), "float32"),
outShape, "float32", stmts, func, inputs, output, hashCode);
outShape, "float32", stmts, func, inputs, output, op->toString(),
expr->toReadable(), hashCode);
// compile the kernel
auto funcCode = res.first;
@ -132,6 +141,7 @@ class MemboundTVM : public Kernel {
nvrtcGetPTX(prog, ret->ptx.data());
ret->invokeParams = invokeParams;
ret->kernelName = kernelName;
ret->simplifiedExprHash = hashCode;
// prepare for evaluation
CUmodule module;
@ -157,6 +167,13 @@ class MemboundTVM : public Kernel {
},
[&]() { context->sync(); });
// dbg(kernelName, "Print input");
// op->getInputs()[0]->print();
// op->getInputs()[0]->printData();
// dbg("Print output");
// op->getOutput()->print();
// op->getOutput()->printData();
// free module
checkCUresult(cuModuleUnload(module));
nvrtcDestroyProgram(&prog);
@ -164,13 +181,31 @@ class MemboundTVM : public Kernel {
return std::dynamic_pointer_cast<PerfRecordObj>(ret);
}
/// @brief
/// @param inDims
/// @param inDTypes
/// @param outDims
/// @param outDType
/// @param lambda
/// @param funcName Generated function name
/// @param inputNames Input array names in the generated invocation code.
/// @param outputName Output array names in the generated invocation code.
/// @param nnetExpressionString Save expr in string for logging.
/// @param nnetSimplifiedExprString Save simplified expr in string for
/// logging.
/// @param hashCode (optional) Hash code of the input expression for kernel
/// cache.
/// @return
std::pair<std::string, std::vector<int>>
getAnsorCode(const std::vector<std::vector<int>> &inDims,
const std::vector<std::string> &inDTypes,
const std::vector<int> &outDims, const std::string &outDType,
const std::string &lambda, const std::string &funcName,
const std::vector<std::string> &inputNames,
const std::string &outputName, const HashType hashCode) const {
const std::string &outputName,
const std::string &nnetExprString,
const std::string &nnetSimplifiedExprString,
const HashType hashCode) const {
std::string funcCode;
std::vector<int> invokeParams;
try {
@ -179,8 +214,10 @@ class MemboundTVM : public Kernel {
// in cuBLAS failure, whose root cause is not identified yet.
static auto func =
py::module::import("cpp_plugin").attr("gen_ansor_op");
py::tuple code = func(inDims, inDTypes, outDims, outDType, lambda,
funcName, inputNames, outputName, hashCode);
py::tuple code =
func(inDims, inDTypes, outDims, outDType, lambda, funcName,
inputNames, outputName, nnetExprString,
nnetSimplifiedExprString, std::to_string(hashCode));
funcCode = py::str(code[0]);
auto temp = py::list(code[3]);
for (int i = 0; i < 6; ++i) {

View File

@ -5,7 +5,7 @@
namespace nnet {
Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
Expr MergeMemboundMutator::merge(bool allowEmptyMembound, bool allowFailure) {
// FIXME: fix empty expression in membound
assert(kernels.size() >= 1);
if (checkEmpty()) {
@ -27,19 +27,25 @@ Expr MergeMemboundMutator::merge(bool allowEmptyMembound) {
assert(CheckOOBVisitor().checkRangeOp(curRangeOp) == false);
auto summand = curRangeOp->getSummand();
if (auto subscriptOp = as<SubscriptNode>(summand)) {
// Try merging the current and next stages
if (auto mergedExpr = rule4StageMerging(*curExpr, true)) {
// dbg(*curExpr, mergedExpr);
*curExpr = mergedExpr;
merged = true;
break;
}
// If merging fails, try the next stage
curExpr = subscriptOp->getObjectPtr();
nnet_assert(*curExpr != nullptr, __LINE__);
} else if (auto funcOp = as<FuncNode>(summand)) {
// Relu({...}[i,j])
curExpr = funcOp->getObject()->getObjectPtr();
} else
nnet_unimplemented_halt();
} else {
if (allowFailure)
return nullptr;
else
nnet_unimplemented_halt();
}
}
} while (merged);
return expr;

View File

@ -10,6 +10,7 @@ MemBoundObj::MemBoundObj(GraphObj *graph, const TensorVec &input,
: OperatorObj(OpType::MemBound, input, output), nnetInputs(nnetInputs),
expr(expr), exec_time(exec_time), hint(hint) {
IT_ASSERT(checkValid(graph));
hash = getHash();
}
string MemBoundObj::toString() const {
@ -31,7 +32,8 @@ string MemBoundObj::toString() const {
os << "NNet Inputs=[";
for (const auto &tensor : nnetInputs)
os << tensor->toReadable() << ",";
os << "])";
os << "]";
os << ", ExprHash=" << hash << ")";
os << "\n" << (expr ? expr->toReadable() : "Empty expression") << "\n";
return os.str();
}
@ -47,7 +49,7 @@ optional<vector<Shape>> MemBoundObj::inferShape(const TensorVec &inputs) const {
}
vector<int> MemBoundObj::getWorkloadVector() const {
return {enum_to_underlying(type), (int)getHash()};
return {enum_to_underlying(type), (int)hash};
}
vector<int> MemBoundObj::getOpAttrVector() const { return getWorkloadVector(); }

View File

@ -4,19 +4,16 @@
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nnet/expr.h"
#include "nnet/test.h"
#include "gtest/gtest.h"
using namespace nnet;
using namespace std;
#define DEFINE_VAR(name) auto name = make_ref<VarNode>(#name);
TEST(FuseMembound, Relu) {
const int n_heads = 8, seq_len = 10000, feat_len = 512;
// dilation_heads = 2;
const int Batch = n_heads, M = seq_len, K = feat_len, W = 32;
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
DEFINE_VAR(b, m, w, k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
@ -35,10 +32,7 @@ TEST(FuseMembound, MemMemFusion) {
const int n_heads = 8, seq_len = 100, feat_len = 100;
// dilation_heads = 2;
const int Batch = n_heads, M = seq_len, K = feat_len;
DEFINE_VAR(b);
DEFINE_VAR(m);
DEFINE_VAR(w);
DEFINE_VAR(k);
DEFINE_VAR(b, m, w, k);
auto A = make_ref<TensorNode>("A", vector<int>({Batch, M, K}),
vector<int>{0, 0, 0});
auto B = make_ref<TensorNode>("B", vector<int>({Batch, K, M}),
@ -54,4 +48,27 @@ TEST(FuseMembound, MemMemFusion) {
RangeOp ans = makeRangeOperator({{b, {0, Batch}}, {m, {0, M}}},
{{k, {0, K}}}, makeSubscript(A, {b, m, k}));
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
}
TEST(FuseMembound, mergeNestedStagesInRangeOp) {
// Case in ConvTranspose to Matmul
// L<f:0:448><i39:0:4096>Sum ... [i39,f]
// {L<i39:0:4096><f:0:448>Sum ... [f,(i39 / 1024),((i39 / 256) % 4),(i39
// % 256)] {K}}
DEFINE_VAR(f, i);
const int I = 4096, F = 448;
auto K = make_ref<TensorNode>("K", vector<int>({448, 4, 4, 256}));
auto subA = makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256});
auto range = makeRangeOperator({{i, {0, I}}, {f, {0, F}}}, {}, subA);
auto outerRange = makeRangeOperator({{f, {0, F}}, {i, {0, I}}}, {},
makeSubscript(range, {i, f}));
auto merged = MergeMemboundMutator({outerRange}).merge();
// Compare the result with answer
RangeOp ans = makeRangeOperator(
{{f, {0, F}}, {i, {0, I}}}, {},
makeSubscript(K, {f, i / 1024, (i / 256) % 4, i % 256}));
dbg(merged, ans);
EXPECT_EQ(HashVisitor().getHash(merged), HashVisitor().getHash(ans));
}