add: optimization pass for metaGraph.

mazx 2022-10-28 20:42:10 +08:00
parent ec58c85505
commit 5ed540be6e
23 changed files with 200 additions and 201 deletions
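Overview (not part of the commit text): the pass added here applies the same greedy pairwise-merge loop at two levels. MetaGraph::optimize folds adjacent MetaOps through MetaOp::merge (renamed from buildByMerge), and MetaOp::optimize folds adjacent MicroOps through the new MicroOp::merge, then drops the resulting EmptyOps and any pointers they no longer reference. Below is a minimal, self-contained sketch of that loop shape only; Op and tryMerge are hypothetical stand-ins, not the repository's classes.

// Minimal sketch of the greedy pairwise-merge loop used by the new pass.
// Op and tryMerge are hypothetical stand-ins for MetaOp/MicroOp and their
// merge() functions; the real classes live under src/pfusion.
#include <iostream>
#include <memory>
#include <vector>

struct Op {
    int id;
    bool mergeable; // stand-in for the real compatibility checks
};

// Returns the fused op, or nullptr when the two ops cannot be merged
// (mirrors MetaOp::merge / MicroOp::merge returning nullptr).
std::shared_ptr<Op> tryMerge(const std::shared_ptr<Op> &a,
                             const std::shared_ptr<Op> &b) {
    if (!a->mergeable || !b->mergeable)
        return nullptr;
    return std::make_shared<Op>(Op{a->id * 100 + b->id, true});
}

// Greedy left-to-right fold: keep fusing the current op with the next one;
// when a fusion fails, flush the current op and restart from the next.
std::vector<std::shared_ptr<Op>> optimize(std::vector<std::shared_ptr<Op>> ops) {
    std::vector<std::shared_ptr<Op>> out;
    if (ops.empty())
        return out;
    int numOp = ops.size();
    int cur = 0;
    for (int i = 1; i < numOp; i++) {
        auto next = tryMerge(ops[cur], ops[i]);
        if (next == nullptr) {
            out.emplace_back(ops[cur]);
            cur = i;
        } else {
            // The fused op is appended and becomes the new "current",
            // exactly as in MetaGraph::optimize / MetaOp::optimize.
            cur = ops.size();
            ops.emplace_back(next);
        }
    }
    out.emplace_back(ops[cur]);
    return out;
}

int main() {
    std::vector<std::shared_ptr<Op>> ops = {
        std::make_shared<Op>(Op{0, true}), std::make_shared<Op>(Op{1, true}),
        std::make_shared<Op>(Op{2, false}), std::make_shared<Op>(Op{3, true})};
    for (const auto &op : optimize(ops))
        std::cout << op->id << std::endl; // prints 1 (the 0+1 fusion), 2, 3
}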

View File

@@ -1,8 +1,8 @@
rm ./eval_sar_drn_0 ./eval_sar_drn_1 ./eval_sar_drn_2
# rm ./eval_sar_drn_0 ./eval_sar_drn_1 ./eval_sar_drn_2
make -j && ./test_sar_drn
nvcc ../eval_pfusion/eval_sar_drn_0.cu ../generated_code/sar_drn_0.cu -I ../eval_pfusion -o eval_sar_drn_0
nvcc ../eval_pfusion/eval_sar_drn_1.cu ../generated_code/sar_drn_0.cu -I ../eval_pfusion -o eval_sar_drn_1
nvcc ../eval_pfusion/eval_sar_drn_2.cu ../generated_code/sar_drn_1.cu -I ../eval_pfusion -o eval_sar_drn_2
nvcc ../eval_pfusion/eval_sar_drn_1.cu ../generated_code/sar_drn_1.cu -I ../eval_pfusion -o eval_sar_drn_1
# nvcc ../eval_pfusion/eval_sar_drn_2.cu ../generated_code/sar_drn_1.cu -I ../eval_pfusion -o eval_sar_drn_2
./eval_sar_drn_0
./eval_sar_drn_1
./eval_sar_drn_2
# ./eval_sar_drn_2

View File

@@ -3,8 +3,7 @@
#include <vector>
void invoke_func_0(float *tensor_ptr_2, float *tensor_ptr_3);
void invoke_func_1(float *tensor_ptr_2, float *tensor_ptr_3,
void invoke_func_2(float *tensor_ptr_2, float *tensor_ptr_3,
float *tensor_ptr_4);
int main() {
@@ -25,11 +24,11 @@ int main() {
cudaEventCreate(&ed);
int cnt = 128;
for (int t = 0; t < cnt; t++) {
invoke_func_0(t0, t1);
invoke_func_2(t0, t1, t2);
}
cudaEventRecord(st, 0);
for (int t = 0; t < cnt; t++) {
invoke_func_0(t0, t1);
invoke_func_2(t0, t1, t2);
}
cudaEventRecord(ed, 0);
cudaEventSynchronize(st);

View File

@@ -3,12 +3,11 @@
#include <vector>
void invoke_func_0(float *tensor_ptr_2, float *tensor_ptr_3);
void invoke_func_1(float *tensor_ptr_2, float *tensor_ptr_3,
void invoke_func_5(float *tensor_ptr_2, float *tensor_ptr_3,
float *tensor_ptr_4);
int main() {
std::vector<int> shape = {1, 64, 512, 512};
std::vector<int> shape = {1, 1, 512, 512};
float *t0, *t1, *t2;
size_t size = 1;
for (auto x : shape) {
@@ -25,13 +24,11 @@ int main() {
cudaEventCreate(&ed);
int cnt = 128;
for (int t = 0; t < cnt; t++) {
invoke_func_0(t0, t1);
invoke_func_1(t0, t1, t2);
invoke_func_5(t0, t1, t2);
}
cudaEventRecord(st, 0);
for (int t = 0; t < cnt; t++) {
invoke_func_0(t0, t1);
invoke_func_1(t0, t1, t2);
invoke_func_5(t0, t1, t2);
}
cudaEventRecord(ed, 0);
cudaEventSynchronize(st);

View File

@@ -1,44 +0,0 @@
#include "cuda.h"
#include "cuda_utils.h"
#include <vector>
void invoke_func_2(float *tensor_ptr_2, float *tensor_ptr_3);
void invoke_func_3(float *tensor_ptr_2, float *tensor_ptr_3,
float *tensor_ptr_4);
int main() {
std::vector<int> shape = {1, 1, 512, 512};
float *t0, *t1, *t2;
size_t size = 1;
for (auto x : shape) {
size *= x;
}
cudaSafeCall(cudaMalloc((void **)&t0, size * sizeof(float)));
cudaSafeCall(cudaMalloc((void **)&t1, size * sizeof(float)));
cudaSafeCall(cudaMalloc((void **)&t2, size * sizeof(float)));
float duration = 0;
cudaEvent_t st, ed;
cudaEventCreate(&st);
cudaEventCreate(&ed);
int cnt = 128;
for (int t = 0; t < cnt; t++) {
invoke_func_2(t0, t1);
invoke_func_3(t0, t1, t2);
}
cudaEventRecord(st, 0);
for (int t = 0; t < cnt; t++) {
invoke_func_2(t0, t1);
invoke_func_3(t0, t1, t2);
}
cudaEventRecord(ed, 0);
cudaEventSynchronize(st);
cudaEventSynchronize(ed);
cudaEventElapsedTime(&duration, st, ed);
std::cout << "[INFO] time: " << duration / cnt << std::endl;
double perf = double(size) * 8.0f * cnt / (duration * 1e-3) / 1024.0f / 1024.0f / 1024.0f;
std::cout << "[INFO] Perf: " << perf << "GB/s" << std::endl;
std::cout << "[Exit] successful." << std::endl;
}
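For reference, the perf line in these harnesses reports effective bandwidth assuming 8 bytes of traffic per element per timed iteration (e.g. one 4-byte read plus one 4-byte write). A worked example of that arithmetic for the {1, 1, 512, 512} shape, with a hypothetical elapsed time plugged in:

// Worked example of the eval harnesses' bandwidth formula, assuming
// 8 bytes of memory traffic per element per timed iteration.
#include <iostream>

int main() {
    const double size = 1.0 * 1 * 512 * 512; // 262144 elements
    const int cnt = 128;                      // timed iterations
    const double duration = 1.0;              // hypothetical total elapsed ms
    double perf = size * 8.0 * cnt / (duration * 1e-3) / 1024.0 / 1024.0 / 1024.0;
    std::cout << "[INFO] Perf: " << perf << "GB/s" << std::endl; // 250GB/s here
    return 0;
}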

View File

@@ -1,10 +1,11 @@
#include "cuda_utils.h"
// Kernel
__global__ void kernel_func_0(float *tensor_ptr_2, float *tensor_ptr_3) {
__global__ void kernel_func_2(float *tensor_ptr_2, float *tensor_ptr_4,
float *tensor_ptr_5) {
int lane_id = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
int parallel_idx = blockIdx.x * 8 + warp_id;
float buf[8];
float buf[32];
for (int loop_idx = parallel_idx; loop_idx < 65536; loop_idx += 864) {
int offset_src = 0;
int offset_src_buf = loop_idx;
@@ -19,34 +20,10 @@ __global__ void kernel_func_0(float *tensor_ptr_2, float *tensor_ptr_3) {
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
buf[inst_idx] = (buf[inst_idx] > 0) ? buf[inst_idx] : 0;
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
tensor_ptr_3[0 + offset_src + inst_idx * 32 + lane_id] =
buf[inst_idx];
}
}
}
// Kernel
__global__ void kernel_func_1(float *tensor_ptr_2, float *tensor_ptr_3,
float *tensor_ptr_4) {
int lane_id = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
int parallel_idx = blockIdx.x * 8 + warp_id;
float buf[24];
for (int loop_idx = parallel_idx; loop_idx < 65536; loop_idx += 864) {
int offset_src = 0;
int offset_src_buf = loop_idx;
offset_src += offset_src_buf % 65536 * 256;
offset_src_buf /= 65536;
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
buf[inst_idx] =
tensor_ptr_2[0 + offset_src + inst_idx * 32 + lane_id];
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
buf[inst_idx + 8] =
tensor_ptr_3[0 + offset_src + inst_idx * 32 + lane_id];
tensor_ptr_4[0 + offset_src + inst_idx * 32 + lane_id];
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
@@ -54,22 +31,16 @@ __global__ void kernel_func_1(float *tensor_ptr_2, float *tensor_ptr_3,
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
tensor_ptr_4[0 + offset_src + inst_idx * 32 + lane_id] =
tensor_ptr_5[0 + offset_src + inst_idx * 32 + lane_id] =
buf[inst_idx + 16];
}
}
}
void invoke_func_0(float *tensor_ptr_2, float *tensor_ptr_3) {
void invoke_func_2(float *tensor_ptr_2, float *tensor_ptr_4,
float *tensor_ptr_5) {
dim3 gridDim(108, 1);
dim3 blockDim(256, 1);
kernel_func_0<<<gridDim, blockDim>>>(tensor_ptr_2, tensor_ptr_3);
cudaCheckError();
}
void invoke_func_1(float *tensor_ptr_2, float *tensor_ptr_3,
float *tensor_ptr_4) {
dim3 gridDim(108, 1);
dim3 blockDim(256, 1);
kernel_func_1<<<gridDim, blockDim>>>(tensor_ptr_2, tensor_ptr_3,
tensor_ptr_4);
kernel_func_2<<<gridDim, blockDim>>>(tensor_ptr_2, tensor_ptr_4,
tensor_ptr_5);
cudaCheckError();
}
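For orientation, the generated kernels assign one 256-float tile to each warp and stride by the total warp count: gridDim(108, 1) with blockDim(256, 1) gives 256 / 32 = 8 warps per block and 108 * 8 = 864 warps overall, which is the loop_idx += 864 stride above. A minimal sketch of that launch and indexing pattern, with a plain copy standing in for the generated READ/RELU/ADD/WRITE instructions:

// Sketch of the warp-per-tile, grid-stride pattern used by the generated
// kernels; the copy body is a hypothetical stand-in for the fused ops.
#include <cstdio>

__global__ void warp_tile_kernel(const float *src, float *dst, int numTiles) {
    int lane_id = threadIdx.x % 32;              // lane within the warp
    int warp_id = threadIdx.x / 32;              // warp within the block
    int parallel_idx = blockIdx.x * 8 + warp_id; // global warp index
    // 108 blocks * 8 warps/block = 864 warps, hence the += 864 stride.
    for (int loop_idx = parallel_idx; loop_idx < numTiles; loop_idx += 864) {
        int offset = loop_idx * 256; // each warp owns one 256-float tile
        for (int inst_idx = 0; inst_idx < 8; inst_idx++)
            dst[offset + inst_idx * 32 + lane_id] =
                src[offset + inst_idx * 32 + lane_id];
    }
}

int main() {
    float *src = nullptr, *dst = nullptr;
    int numTiles = 65536; // 1 * 64 * 512 * 512 elements / 256 per tile
    cudaMalloc((void **)&src, numTiles * 256 * sizeof(float));
    cudaMalloc((void **)&dst, numTiles * 256 * sizeof(float));
    // Launch shape mirrors the generated invoke_func_* wrappers.
    warp_tile_kernel<<<dim3(108, 1), dim3(256, 1)>>>(src, dst, numTiles);
    printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
    cudaDeviceSynchronize();
    cudaFree(src);
    cudaFree(dst);
    return 0;
}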

View File

@@ -1,10 +1,11 @@
#include "cuda_utils.h"
// Kernel
__global__ void kernel_func_2(float *tensor_ptr_9, float *tensor_ptr_10) {
__global__ void kernel_func_5(float *tensor_ptr_9, float *tensor_ptr_11,
float *tensor_ptr_12) {
int lane_id = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
int parallel_idx = blockIdx.x * 8 + warp_id;
float buf[8];
float buf[32];
for (int loop_idx = parallel_idx; loop_idx < 1024; loop_idx += 864) {
int offset_src = 0;
int offset_src_buf = loop_idx;
@@ -19,34 +20,10 @@ __global__ void kernel_func_2(float *tensor_ptr_9, float *tensor_ptr_10) {
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
buf[inst_idx] = (buf[inst_idx] > 0) ? buf[inst_idx] : 0;
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
tensor_ptr_10[0 + offset_src + inst_idx * 32 + lane_id] =
buf[inst_idx];
}
}
}
// Kernel
__global__ void kernel_func_3(float *tensor_ptr_9, float *tensor_ptr_10,
float *tensor_ptr_11) {
int lane_id = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
int parallel_idx = blockIdx.x * 8 + warp_id;
float buf[24];
for (int loop_idx = parallel_idx; loop_idx < 1024; loop_idx += 864) {
int offset_src = 0;
int offset_src_buf = loop_idx;
offset_src += offset_src_buf % 1024 * 256;
offset_src_buf /= 1024;
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
buf[inst_idx] =
tensor_ptr_9[0 + offset_src + inst_idx * 32 + lane_id];
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
buf[inst_idx + 8] =
tensor_ptr_10[0 + offset_src + inst_idx * 32 + lane_id];
tensor_ptr_11[0 + offset_src + inst_idx * 32 + lane_id];
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
@@ -54,22 +31,16 @@ __global__ void kernel_func_3(float *tensor_ptr_9, float *tensor_ptr_10,
}
#pragma unroll
for (int inst_idx = 0; inst_idx < 8; inst_idx++) {
tensor_ptr_11[0 + offset_src + inst_idx * 32 + lane_id] =
tensor_ptr_12[0 + offset_src + inst_idx * 32 + lane_id] =
buf[inst_idx + 16];
}
}
}
void invoke_func_2(float *tensor_ptr_9, float *tensor_ptr_10) {
void invoke_func_5(float *tensor_ptr_9, float *tensor_ptr_11,
float *tensor_ptr_12) {
dim3 gridDim(108, 1);
dim3 blockDim(256, 1);
kernel_func_2<<<gridDim, blockDim>>>(tensor_ptr_9, tensor_ptr_10);
cudaCheckError();
}
void invoke_func_3(float *tensor_ptr_9, float *tensor_ptr_10,
float *tensor_ptr_11) {
dim3 gridDim(108, 1);
dim3 blockDim(256, 1);
kernel_func_3<<<gridDim, blockDim>>>(tensor_ptr_9, tensor_ptr_10,
tensor_ptr_11);
kernel_func_5<<<gridDim, blockDim>>>(tensor_ptr_9, tensor_ptr_11,
tensor_ptr_12);
cudaCheckError();
}

View File

@@ -10,7 +10,8 @@
namespace memb {
enum OpType {
READ = 1,
EMPTY = 1,
READ,
WRITE,
RELU,
ADD,
@@ -26,6 +27,8 @@ enum MemType {
inline std::string getName(OpType opType) {
switch (opType) {
case (OpType::EMPTY):
return "EMPTY";
case (OpType::READ):
return "READ";
case (OpType::WRITE):

View File

@@ -23,6 +23,8 @@ class MetaGraph {
IT_ASSERT(metaOpMap.find(op2->id) != metaOpMap.end());
edges.emplace_back(metaOpMap[op1->id], metaOpMap[op2->id]);
}
void print();
void optimize();
std::string genHeader();
std::string genKernelFuncs();
std::string genInvokeFuncs();

View File

@@ -60,12 +60,12 @@ class MetaOp {
inline int getLoopSt() { return main_loop_st; }
inline int getLoopEd() { return main_loop_ed; }
void optimize();
std::string genKernelFunc();
std::string genInvokeFunc();
static std::shared_ptr<MetaOp>
buildByMerge(std::shared_ptr<MetaOp> metaOp0,
std::shared_ptr<MetaOp> metaOp1);
static std::shared_ptr<MetaOp> merge(std::shared_ptr<MetaOp> metaOp0,
std::shared_ptr<MetaOp> metaOp1);
inline void print() {
std::cout << "MetaOp: " << id << std::endl;

View File

@@ -5,17 +5,20 @@
namespace memb {
class BinaryOp : public MicroOp {
private:
OpType opType;
std::shared_ptr<Pointer> pSrc0, pSrc1, pDst;
size_t num, width;
public:
BinaryOp(OpType _opType, std::shared_ptr<Pointer> _pSrc0,
std::shared_ptr<Pointer> _pSrc1, std::shared_ptr<Pointer> _pDst,
size_t _num, size_t _width)
: opType(_opType), pSrc0(_pSrc0), pSrc1(_pSrc1), pDst(_pDst), num(_num),
width(_width) {}
: num(_num), width(_width) {
opType = _opType;
ptrs = {_pSrc0, _pSrc1, _pDst};
}
~BinaryOp() {}
std::shared_ptr<Pointer> getSrc0() { return ptrs[0]; }
std::shared_ptr<Pointer> getSrc1() { return ptrs[1]; }
std::shared_ptr<Pointer> getDst() { return ptrs[2]; }
// bool checkValid() override;
std::string generate() override;
inline void print() override {

View File

@@ -0,0 +1,17 @@
#pragma once
#include "pfusion/micro_op.h"
namespace memb {
class EmptyOp : public MicroOp {
public:
EmptyOp() { opType = EMPTY; }
~EmptyOp() {}
// bool checkValid() override;
std::string generate() override { return ""; };
inline void print() override {
std::cout << id << " " << getName(opType) << std::endl;
}
};
} // namespace memb

View File

@@ -5,23 +5,29 @@
namespace memb {
class MemoryOp : public MicroOp {
private:
OpType opType;
std::shared_ptr<Pointer> src, dst;
size_t num, width;
public:
MemoryOp(OpType _opType, std::shared_ptr<Pointer> _src,
std::shared_ptr<Pointer> _dst, size_t _num, size_t _width)
: opType(_opType), src(_src), dst(_dst), num(_num), width(_width) {}
: num(_num), width(_width) {
opType = _opType;
ptrs = {_src, _dst};
}
// bool checkValid() override;
~MemoryOp() {}
std::shared_ptr<Pointer> getSrc() { return ptrs[0]; }
std::shared_ptr<Pointer> getDst() { return ptrs[1]; }
std::string generate() override;
inline void print() override {
if (opType == READ) {
std::cout << id << " " << getName(opType) << " "
<< getName(src->getType()) << std::endl;
<< getName(getSrc()->getType()) << " "
<< getSrc()->getHash() << std::endl;
} else if (opType == WRITE) {
std::cout << id << " " << getName(opType) << " "
<< getName(dst->getType()) << std::endl;
<< getName(getDst()->getType()) << " "
<< getDst()->getHash() << std::endl;
} else {
IT_ASSERT(false);
}

View File

@@ -1,20 +1,24 @@
#pragma once
#include "pfusion/micro_op.h"
#include <string>
namespace memb {
class UnaryOp : public MicroOp {
private:
const OpType opType;
const std::shared_ptr<Pointer> src, dst;
const int num, width;
public:
UnaryOp(OpType _opType, std::shared_ptr<Pointer> _src,
std::shared_ptr<Pointer> _dst, int _num, int _width)
: opType(_opType), src(_src), dst(_dst), num(_num), width(_width) {}
: num(_num), width(_width) {
opType = _opType;
ptrs = {_src, _dst};
}
~UnaryOp() {}
std::shared_ptr<Pointer> getSrc() { return ptrs[0]; }
std::shared_ptr<Pointer> getDst() { return ptrs[1]; }
// bool checkValid() override;
std::string generate() override;
inline void print() override {

View File

@@ -6,28 +6,28 @@
namespace memb {
class MicroOp {
public:
enum MicroOpType {
memory = 1,
kernel,
};
protected:
MicroOpType type;
int id;
size_t id;
OpType opType;
std::vector<std::shared_ptr<Pointer>> ptrs;
public:
MicroOp() {
static int microOpId = 0;
id = microOpId++;
opType = OpType(0);
}
virtual ~MicroOp() {}
inline MicroOpType getType() { return type; }
inline OpType getType() { return opType; }
inline bool isMemoryOp() { return opType == READ || opType == WRITE; }
inline std::vector<std::shared_ptr<Pointer>> getPtrs() { return ptrs; }
// virtual bool checkValid() = 0;
virtual std::string generate() = 0;
virtual void print() = 0;
static std::shared_ptr<MicroOp> merge(std::shared_ptr<MicroOp> op0,
std::shared_ptr<MicroOp> op1);
};
class MicroGraph {
@@ -37,7 +37,7 @@ class MicroGraph {
private:
std::vector<std::shared_ptr<MicroOp>> microOps;
std::vector<std::pair<int, int>> deps;
std::vector<std::pair<int, int>> edges;
};
} // namespace memb

View File

@@ -34,28 +34,6 @@ class SearchGraph {
nodes[j].pred.emplace_back(i);
}
std::shared_ptr<MetaGraph> exportFirstMetaGraph();
inline void print() {
for (auto node : nodes) {
std::cout << node.id << "[(";
if (node.pred.size() > 0) {
std::cout << node.pred[0];
}
for (size_t i = 1; i < node.pred.size(); i++) {
std::cout << ", " << node.pred[i];
}
std::cout << ")(";
if (node.succ.size() > 0) {
std::cout << node.succ[0];
}
for (size_t i = 1; i < node.succ.size(); i++) {
std::cout << ", " << node.succ[i];
}
std::cout << ")]" << std::endl;
for (auto metaOp : node.metaOps) {
metaOp->print();
}
}
}
};
} // namespace memb

View File

@@ -112,6 +112,11 @@ std::string infini::MemoryCodegen::generate(Graph graph) {
auto searchGraph = instantiateGraph(graph);
auto metaGraph = searchGraph->exportFirstMetaGraph();
std::string code = "";
std::cout << "[INFO] before opt." << std::endl;
metaGraph->print();
metaGraph->optimize();
std::cout << "[INFO] after opt." << std::endl;
metaGraph->print();
code += metaGraph->genHeader();
code += metaGraph->genKernelFuncs();
code += metaGraph->genInvokeFuncs();

View File

@@ -2,6 +2,34 @@
namespace memb {
void MetaGraph::print() {
for (auto op : metaOps) {
op->print();
}
}
void MetaGraph::optimize() {
std::vector<std::shared_ptr<MetaOp>> ops;
int numOp = metaOps.size();
int cur = 0;
for (int i = 1; i < numOp; i++) {
auto next = MetaOp::merge(metaOps[cur], metaOps[i]);
if (next == nullptr) {
ops.emplace_back(metaOps[cur]);
cur = i;
} else {
cur = metaOps.size();
metaOps.emplace_back(next);
}
}
ops.emplace_back(metaOps[cur]);
metaOps.clear();
for (auto op : ops) {
op->optimize();
metaOps.emplace_back(op);
}
}
std::string MetaGraph::genHeader() {
std::string code = "#include \"cuda_utils.h\"\n";
return code;

View File

@@ -22,6 +22,43 @@ std::string TensorMapping::genOffset() {
return code;
}
void MetaOp::optimize() {
std::vector<std::shared_ptr<MicroOp>> ops;
int numOp = microOps.size();
int cur = 0;
for (int i = 1; i < numOp; i++) {
auto next = MicroOp::merge(microOps[cur], microOps[i]);
if (next == nullptr) {
ops.emplace_back(microOps[cur]);
cur = i;
} else {
cur = microOps.size();
microOps.emplace_back(next);
}
}
ops.emplace_back(microOps[cur]);
microOps.clear();
std::unordered_set<std::string> ptrSet;
for (auto op : ops) {
for (auto ptr : op->getPtrs()) {
ptrSet.emplace(ptr->getName());
}
if (op->getType() != EMPTY) {
microOps.emplace_back(op);
}
}
std::vector<std::shared_ptr<Pointer>> newPtrs;
for (auto ptr : ptrs) {
if (ptrSet.find(ptr->getName()) != ptrSet.end()) {
newPtrs.emplace_back(ptr);
}
}
ptrs.clear();
for (auto ptr : newPtrs) {
ptrs.emplace_back(ptr);
}
}
std::string MetaOp::genKernelFunc() {
std::string code = "";
code += "// Kernel\n";
@@ -84,24 +121,24 @@ std::string MetaOp::genInvokeFunc() {
return code;
}
std::shared_ptr<MetaOp> MetaOp::buildByMerge(std::shared_ptr<MetaOp> metaOp0,
std::shared_ptr<MetaOp> metaOp1) {
std::shared_ptr<MetaOp> MetaOp::merge(std::shared_ptr<MetaOp> metaOp0,
std::shared_ptr<MetaOp> metaOp1) {
IT_ASSERT(metaOp0->checkValid());
IT_ASSERT(metaOp1->checkValid());
// Check unmergeable
if (metaOp0->main_loop_st != metaOp1->main_loop_st ||
metaOp0->main_loop_ed != metaOp1->main_loop_ed ||
metaOp0->numBlocks != metaOp1->numBlocks ||
metaOp0->numReg != metaOp1->numReg ||
metaOp0->numSmem != metaOp1->numSmem) {
metaOp0->numWarps != metaOp1->numWarps) {
return nullptr;
}
auto metaOp = std::make_shared<MetaOp>();
metaOp->main_loop_st = metaOp0->main_loop_st;
metaOp->main_loop_ed = metaOp0->main_loop_ed;
metaOp->numBlocks = metaOp0->numBlocks;
metaOp->numReg = metaOp0->numReg;
metaOp->numSmem = metaOp0->numSmem;
metaOp->numWarps = metaOp0->numWarps;
metaOp->numReg = metaOp0->numReg + metaOp1->numReg;
metaOp->numSmem = metaOp0->numSmem + metaOp1->numSmem;
// Merge ptr
std::unordered_set<size_t> ptrSet;

View File

@@ -11,7 +11,7 @@ std::string BinaryOp::generate() {
code += "#pragma unroll\n";
code += "for (int inst_idx = 0; inst_idx < " + std::to_string(num) +
"; inst_idx++) {\n";
std::string opFunc = pDst->generate() + " = " + pSrc0->generate();
std::string opFunc = getDst()->generate() + " = " + getSrc0()->generate();
if (opType == ADD) {
opFunc += " + ";
} else if (opType == SUB) {
@@ -19,7 +19,7 @@ std::string BinaryOp::generate() {
} else {
IT_ASSERT(false);
}
opFunc += pSrc1->generate() + ";\n";
opFunc += getSrc1()->generate() + ";\n";
code += opFunc;
code += "}\n";

View File

@@ -11,11 +11,11 @@ std::string MemoryOp::generate() {
code += "#pragma unroll\n";
code += "for (int inst_idx = 0; inst_idx < " + std::to_string(num) +
"; inst_idx++) {\n";
if ((opType == OpType::READ && src->getType() != MemType::REG &&
dst->getType() == MemType::REG) ||
(opType == OpType::WRITE && src->getType() == MemType::REG &&
dst->getType() != MemType::REG)) {
code += dst->generate() + " = " + src->generate() + ";\n";
if ((opType == OpType::READ && getSrc()->getType() != MemType::REG &&
getDst()->getType() == MemType::REG) ||
(opType == OpType::WRITE && getSrc()->getType() == MemType::REG &&
getDst()->getType() != MemType::REG)) {
code += getDst()->generate() + " = " + getSrc()->generate() + ";\n";
} else {
IT_ASSERT(false);
}

View File

@@ -12,8 +12,8 @@ std::string UnaryOp::generate() {
code += "for (int inst_idx = 0; inst_idx < " + std::to_string(num) +
"; inst_idx++) {\n";
if (opType == RELU) {
code += dst->generate() + " = (" + src->generate() + " > 0) ? " +
src->generate() + " : 0;\n";
code += getDst()->generate() + " = (" + getSrc()->generate() +
" > 0) ? " + getSrc()->generate() + " : 0;\n";
} else {
IT_ASSERT(false);
}

src/pfusion/micro_op.cc (new file, +22 lines)
View File

@@ -0,0 +1,22 @@
#include "pfusion/micro_op.h"
#include "pfusion/micro_kernel/empty.h"
#include "pfusion/micro_kernel/memory.h"
namespace memb {
std::shared_ptr<MicroOp> MicroOp::merge(std::shared_ptr<MicroOp> op0,
std::shared_ptr<MicroOp> op1) {
if (op0->getType() == WRITE && op1->getType() == READ) {
auto memOp0 = std::dynamic_pointer_cast<MemoryOp>(op0);
auto memOp1 = std::dynamic_pointer_cast<MemoryOp>(op1);
if (memOp0->getDst()->getHash() == memOp1->getSrc()->getHash()) {
if (memOp0->getSrc()->getHash() == memOp1->getDst()->getHash()) {
return std::make_shared<EmptyOp>();
} else {
// TODO: gen reg to reg.
IT_ASSERT(false);
}
}
}
return nullptr;
}
} // namespace memb
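The rule above cancels a round trip: when one micro-op writes a register out to memory and the next op reads that same memory location back into the same register, the pair collapses into an EmptyOp, which MetaOp::optimize then drops (the matching-memory but different-register case is left as a TODO for a register-to-register move). A self-contained sketch of the hash comparison, with Mem as a hypothetical stand-in for the repository's MemoryOp/Pointer types:

// Sketch of the WRITE-then-READ cancellation in MicroOp::merge.
// Mem is a hypothetical stand-in; hashes identify storage locations.
#include <cassert>
#include <cstddef>

enum class Kind { Read, Write };

struct Mem {
    Kind kind;
    size_t srcHash; // where the data comes from
    size_t dstHash; // where the data goes
};

// True when the write/read pair is a round trip through the same memory
// location back into the same register, i.e. it collapses to an EmptyOp.
bool cancels(const Mem &op0, const Mem &op1) {
    return op0.kind == Kind::Write && op1.kind == Kind::Read &&
           op0.dstHash == op1.srcHash && op0.srcHash == op1.dstHash;
}

int main() {
    Mem write{Kind::Write, /*reg*/ 1, /*dram*/ 7};
    Mem readBack{Kind::Read, /*dram*/ 7, /*reg*/ 1};
    Mem readElsewhere{Kind::Read, /*dram*/ 9, /*reg*/ 2};
    assert(cancels(write, readBack));       // round trip: both ops vanish
    assert(!cancels(write, readElsewhere)); // different location: keep both
    return 0;
}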

View File

@@ -19,7 +19,7 @@ TEST(Graph, SAR_DRN_0) {
Tensor t3 = g->addTensor({1, 64, 512, 512}, DataType::Float32);
g->dataMalloc();
g->addOpWithOutputs<ReluObj>(t0, t1);
g->addOpWithOutputs<AddObj>(t0, t1, t2);
g->addOpWithOutputs<AddObj>(t1, t2, t3);
MemoryCodegen codegen;
codegen.exportCode(g, "sar_drn_0.cu");
}
@@ -33,7 +33,7 @@ TEST(Graph, SAR_DRN_1) {
Tensor t3 = g->addTensor({1, 1, 512, 512}, DataType::Float32);
g->dataMalloc();
g->addOpWithOutputs<ReluObj>(t0, t1);
g->addOpWithOutputs<SubObj>(t0, t1, t2);
g->addOpWithOutputs<SubObj>(t1, t2, t3);
MemoryCodegen codegen;
codegen.exportCode(g, "sar_drn_1.cu");
}