From 89398a1c57a7b28910de3c27194451c9c926bbc6 Mon Sep 17 00:00:00 2001 From: mazx Date: Tue, 15 Nov 2022 15:13:50 +0800 Subject: [PATCH] update fused kernels. --- eval_pfusion/eval_kernel.py | 21 +++++ eval_pfusion/eval_kernel.tmp | 41 +++++++++ eval_pfusion/eval_transpose.py | 21 +++++ eval_pfusion/eval_transpose.tmp | 84 +++++++++++++++++ generated_code/bias_0.cu | 42 +++++++++ generated_code/bias_1.cu | 42 +++++++++ generated_code/bias_2.cu | 42 +++++++++ generated_code/bias_3.cu | 42 +++++++++ generated_code/tmp.cu | 84 +++++++++++++++++ generated_code/transpose.cu | 10 +-- generated_code/transpose_0.cu | 56 ++++++++++++ generated_code/transpose_1.cu | 66 ++++++++++++++ generated_code/transpose_2.cu | 62 +++++++++++++ include/pfusion/common.h | 1 + include/pfusion/memory_codegen.h | 10 ++- include/pfusion/meta_op.h | 35 +++++++- include/pfusion/micro_kernel/memory.h | 23 ++++- include/pfusion/micro_op.h | 4 +- include/pfusion/pointer.h | 11 +++ src/pfusion/instantiate.cc | 45 +++++----- src/pfusion/memory_codegen.cc | 42 ++++++++- src/pfusion/meta_op.cc | 124 ++++++++++++++++++++++++-- src/pfusion/micro_kernel/memory.cc | 74 ++++++++++++++- test/pfusion/.test_sar_drn.cc.swp | Bin 12288 -> 0 bytes test/pfusion/test_bias.cc | 26 ++++++ test/pfusion/test_transpose.cc | 26 +++--- 26 files changed, 976 insertions(+), 58 deletions(-) create mode 100644 eval_pfusion/eval_kernel.py create mode 100644 eval_pfusion/eval_kernel.tmp create mode 100644 eval_pfusion/eval_transpose.py create mode 100644 eval_pfusion/eval_transpose.tmp create mode 100644 generated_code/bias_0.cu create mode 100644 generated_code/bias_1.cu create mode 100644 generated_code/bias_2.cu create mode 100644 generated_code/bias_3.cu create mode 100644 generated_code/tmp.cu create mode 100644 generated_code/transpose_0.cu create mode 100644 generated_code/transpose_1.cu create mode 100644 generated_code/transpose_2.cu delete mode 100644 test/pfusion/.test_sar_drn.cc.swp create mode 100644 test/pfusion/test_bias.cc diff --git a/eval_pfusion/eval_kernel.py b/eval_pfusion/eval_kernel.py new file mode 100644 index 00000000..94900d52 --- /dev/null +++ b/eval_pfusion/eval_kernel.py @@ -0,0 +1,21 @@ +import os + + +def eval(filename, kernel, shape): + with open("../eval_pfusion/eval_kernel.tmp", "r") as f: + code = f.read() + code = code.replace("%%invoke_func%%", kernel) + code = code.replace("%%shape%%", shape) + with open("../generated_code/tmp.cu", "w") as f: + f.write(code) + # os.system("make -j && ./test_bias") + os.system( + "nvcc ../generated_code/tmp.cu ../generated_code/" + filename + " -I ../eval_pfusion -o ./tmp") + os.system("./tmp") + + +if __name__ == "__main__": + eval("bias_0.cu", "invoke_func_0", "{28 * 28, 24}") + eval("bias_1.cu", "invoke_func_1", "{28 * 28, 58}") + eval("bias_2.cu", "invoke_func_2", "{14 * 14, 116}") + eval("bias_3.cu", "invoke_func_3", "{7 * 7, 232}") diff --git a/eval_pfusion/eval_kernel.tmp b/eval_pfusion/eval_kernel.tmp new file mode 100644 index 00000000..5102f74b --- /dev/null +++ b/eval_pfusion/eval_kernel.tmp @@ -0,0 +1,41 @@ +#include "cuda.h" +#include "cuda_utils.h" + +#include + +void %%invoke_func%%(float *tensor_ptr_0, float *tensor_ptr_1, + float *tensor_ptr_2); + +int main() { + std::vector shape = %%shape%%; + float *t0, *t1, *t2; + size_t size = 1; + for (auto x : shape) { + size *= x; + } + + cudaSafeCall(cudaMalloc((void **)&t0, size * sizeof(float))); + cudaSafeCall(cudaMalloc((void **)&t1, size * sizeof(float))); + cudaSafeCall(cudaMalloc((void **)&t2, size * sizeof(float))); + + float duration = 0; + cudaEvent_t st, ed; + cudaEventCreate(&st); + cudaEventCreate(&ed); + int cnt = 128; + for (int t = 0; t < cnt; t++) { + %%invoke_func%%(t0, t1, t2); + } + cudaEventRecord(st, 0); + for (int t = 0; t < cnt; t++) { + %%invoke_func%%(t0, t1, t2); + } + cudaEventRecord(ed, 0); + cudaEventSynchronize(st); + cudaEventSynchronize(ed); + cudaEventElapsedTime(&duration, st, ed); + std::cout << "[INFO] time: " << duration / cnt << std::endl; + // double perf = double(size) * 8.0f * cnt / (duration * 1e-3) / 1024.0f / 1024.0f / 1024.0f; + // std::cout << "[INFO] Perf: " << perf << "GB/s" << std::endl; + std::cout << "[Exit] successful." << std::endl; +} diff --git a/eval_pfusion/eval_transpose.py b/eval_pfusion/eval_transpose.py new file mode 100644 index 00000000..929032b8 --- /dev/null +++ b/eval_pfusion/eval_transpose.py @@ -0,0 +1,21 @@ +import os + + +def eval(filename, kernel, shape, perm): + with open("../eval_pfusion/eval_transpose.tmp", "r") as f: + code = f.read() + code = code.replace("%%invoke_func%%", kernel) + code = code.replace("%%shape%%", shape) + code = code.replace("%%perm%%", perm) + with open("../generated_code/tmp.cu", "w") as f: + f.write(code) + # os.system("make -j && ./test_bias") + os.system( + "nvcc ../generated_code/tmp.cu ../generated_code/" + filename + " -I ../eval_pfusion -o ./tmp") + os.system("./tmp") + + +if __name__ == "__main__": + eval("transpose_0.cu", "invoke_func_0", "{28 * 28, 58, 2}", "{0, 2, 1}") + eval("transpose_1.cu", "invoke_func_1", "{14 * 14, 116, 2}", "{0, 2, 1}") + eval("transpose_2.cu", "invoke_func_2", "{7 * 7, 232, 2}", "{0, 2, 1}") diff --git a/eval_pfusion/eval_transpose.tmp b/eval_pfusion/eval_transpose.tmp new file mode 100644 index 00000000..183066a8 --- /dev/null +++ b/eval_pfusion/eval_transpose.tmp @@ -0,0 +1,84 @@ +#include "cuda.h" +#include "cuda_utils.h" + +#include + +void %%invoke_func%%(float *src, float *dst); + +int main() { + std::vector shape = %%shape%%; + std::vector perm = %%perm%%; + float *src, *dst; + size_t size = 1; + for (auto x : shape) { + size *= x; + } + std::vector stride_src(shape.size()), stride_dst(shape.size()); + stride_dst[0] = 1; + for (int i = 1; i < shape.size(); i++) { + stride_dst[i] = stride_dst[i-1] * shape[i-1]; + } + size_t this_stride = 1; + for (int i = 0; i < shape.size(); i++) { + for (int j = 0; j < shape.size(); j++) { + if (perm[j] == i) { + stride_src[i] = this_stride; + this_stride *= shape[j]; + } + } + } + + cudaSafeCall(cudaMalloc((void **)&src, size * sizeof(float))); + cudaSafeCall(cudaMalloc((void **)&dst, size * sizeof(float))); + + float *src_host, *dst_host; + src_host = (float *)malloc(size * sizeof(float)); + dst_host = (float *)malloc(size * sizeof(float)); + for (size_t i = 0; i < size; i++) { + src_host[i] = i; + } + cudaSafeCall(cudaMemcpy(src, src_host, size * sizeof(float), cudaMemcpyHostToDevice)); + %%invoke_func%%(src, dst); + cudaSafeCall(cudaMemcpy(dst_host, dst, size * sizeof(float), cudaMemcpyDeviceToHost)); + bool flag = 0; + for (size_t i = 0; i < size; i++) { + size_t base = i; + size_t offset_src = 0; + for (int j = 0; j < shape.size(); j++) { + offset_src += base % shape[j] * stride_src[perm[j]]; + base /= shape[j]; + } + if (dst_host[i] != src_host[offset_src]) { + flag = 1; + std::cout << "[ERROR] at " << i << "," << offset_src << ":" << dst_host[i] << "," << src_host[offset_src] << std::endl; + break; + } + } + + if (!flag) { + std::cout << "[INFO] transpose correct." << std::endl; + } else { + std::cout << "[ERROR] transpose incorrect." << std::endl; + } + + float duration = 0; + cudaEvent_t st, ed; + cudaEventCreate(&st); + cudaEventCreate(&ed); + int cnt = 128; + for (int t = 0; t < cnt; t++) { + %%invoke_func%%(src, dst); + } + cudaEventRecord(st, 0); + for (int t = 0; t < cnt; t++) { + %%invoke_func%%(src, dst); + } + cudaEventRecord(ed, 0); + cudaEventSynchronize(st); + cudaEventSynchronize(ed); + cudaEventElapsedTime(&duration, st, ed); + std::cout << "[INFO] time: " << duration / cnt << std::endl; + double perf = double(size) * 8.0f * cnt / (duration * 1e-3) / 1024.0f / 1024.0f / 1024.0f; + std::cout << "[INFO] Perf: " << perf << "GB/s" << std::endl; + std::cout << "[Exit] successful." << std::endl; +} diff --git a/generated_code/bias_0.cu b/generated_code/bias_0.cu new file mode 100644 index 00000000..72bbf950 --- /dev/null +++ b/generated_code/bias_0.cu @@ -0,0 +1,42 @@ +#include "cuda_utils.h" +// Kernel +__global__ void kernel_func_0(float *input, float *bias, float *output) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + int parallel_idx = blockIdx.x * 4 + warp_id; + float buf[4]; + for (int loop_idx = parallel_idx; loop_idx < 144; loop_idx += 320) { + int offset_input = 0; + int offset_input_buf = loop_idx; + offset_input += offset_input_buf % 7 * 128; + offset_input_buf /= 7; + offset_input += offset_input_buf % 24 * 784; + offset_input_buf /= 24; + int offset_bias = 0; + int offset_bias_buf = loop_idx; + offset_bias += offset_bias_buf % 24 * 24; + offset_bias_buf /= 24; +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = input[0 + offset_input + inst_idx * 32 + lane_id]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 1; inst_idx++) { + buf[4] = bias[0 + offset_bias]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = buf[inst_idx] + buf[4]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 8; inst_idx++) { + bias[0 + offset_input + inst_idx * 32 + lane_id] = buf[inst_idx]; + } + } +} +void invoke_func_0(float *input, float *bias, float *output) { + dim3 gridDim(80, 1); + dim3 blockDim(128, 1); + kernel_func_0<<>>(input, bias, output); + cudaCheckError(); +} diff --git a/generated_code/bias_1.cu b/generated_code/bias_1.cu new file mode 100644 index 00000000..155afe2c --- /dev/null +++ b/generated_code/bias_1.cu @@ -0,0 +1,42 @@ +#include "cuda_utils.h" +// Kernel +__global__ void kernel_func_1(float *input, float *bias, float *output) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + int parallel_idx = blockIdx.x * 4 + warp_id; + float buf[4]; + for (int loop_idx = parallel_idx; loop_idx < 348; loop_idx += 320) { + int offset_input = 0; + int offset_input_buf = loop_idx; + offset_input += offset_input_buf % 7 * 128; + offset_input_buf /= 7; + offset_input += offset_input_buf % 58 * 784; + offset_input_buf /= 58; + int offset_bias = 0; + int offset_bias_buf = loop_idx; + offset_bias += offset_bias_buf % 58 * 58; + offset_bias_buf /= 58; +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = input[0 + offset_input + inst_idx * 32 + lane_id]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 1; inst_idx++) { + buf[4] = bias[0 + offset_bias]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = buf[inst_idx] + buf[4]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 8; inst_idx++) { + bias[0 + offset_input + inst_idx * 32 + lane_id] = buf[inst_idx]; + } + } +} +void invoke_func_1(float *input, float *bias, float *output) { + dim3 gridDim(80, 1); + dim3 blockDim(128, 1); + kernel_func_1<<>>(input, bias, output); + cudaCheckError(); +} diff --git a/generated_code/bias_2.cu b/generated_code/bias_2.cu new file mode 100644 index 00000000..90d08c61 --- /dev/null +++ b/generated_code/bias_2.cu @@ -0,0 +1,42 @@ +#include "cuda_utils.h" +// Kernel +__global__ void kernel_func_2(float *input, float *bias, float *output) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + int parallel_idx = blockIdx.x * 4 + warp_id; + float buf[4]; + for (int loop_idx = parallel_idx; loop_idx < 116; loop_idx += 320) { + int offset_input = 0; + int offset_input_buf = loop_idx; + offset_input += offset_input_buf % 2 * 128; + offset_input_buf /= 2; + offset_input += offset_input_buf % 116 * 196; + offset_input_buf /= 116; + int offset_bias = 0; + int offset_bias_buf = loop_idx; + offset_bias += offset_bias_buf % 116 * 116; + offset_bias_buf /= 116; +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = input[0 + offset_input + inst_idx * 32 + lane_id]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 1; inst_idx++) { + buf[4] = bias[0 + offset_bias]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = buf[inst_idx] + buf[4]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 8; inst_idx++) { + bias[0 + offset_input + inst_idx * 32 + lane_id] = buf[inst_idx]; + } + } +} +void invoke_func_2(float *input, float *bias, float *output) { + dim3 gridDim(80, 1); + dim3 blockDim(128, 1); + kernel_func_2<<>>(input, bias, output); + cudaCheckError(); +} diff --git a/generated_code/bias_3.cu b/generated_code/bias_3.cu new file mode 100644 index 00000000..ad83155f --- /dev/null +++ b/generated_code/bias_3.cu @@ -0,0 +1,42 @@ +#include "cuda_utils.h" +// Kernel +__global__ void kernel_func_3(float *input, float *bias, float *output) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + int parallel_idx = blockIdx.x * 4 + warp_id; + float buf[4]; + for (int loop_idx = parallel_idx; loop_idx < 0; loop_idx += 320) { + int offset_input = 0; + int offset_input_buf = loop_idx; + offset_input += offset_input_buf % 1 * 128; + offset_input_buf /= 1; + offset_input += offset_input_buf % 232 * 49; + offset_input_buf /= 232; + int offset_bias = 0; + int offset_bias_buf = loop_idx; + offset_bias += offset_bias_buf % 232 * 232; + offset_bias_buf /= 232; +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = input[0 + offset_input + inst_idx * 32 + lane_id]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 1; inst_idx++) { + buf[4] = bias[0 + offset_bias]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = buf[inst_idx] + buf[4]; + } +#pragma unroll + for (int inst_idx = 0; inst_idx < 8; inst_idx++) { + bias[0 + offset_input + inst_idx * 32 + lane_id] = buf[inst_idx]; + } + } +} +void invoke_func_3(float *input, float *bias, float *output) { + dim3 gridDim(80, 1); + dim3 blockDim(128, 1); + kernel_func_3<<>>(input, bias, output); + cudaCheckError(); +} diff --git a/generated_code/tmp.cu b/generated_code/tmp.cu new file mode 100644 index 00000000..15398d56 --- /dev/null +++ b/generated_code/tmp.cu @@ -0,0 +1,84 @@ +#include "cuda.h" +#include "cuda_utils.h" + +#include + +void invoke_func_2(float *src, float *dst); + +int main() { + std::vector shape = {7 * 7, 232, 2}; + std::vector perm = {0, 2, 1}; + float *src, *dst; + size_t size = 1; + for (auto x : shape) { + size *= x; + } + std::vector stride_src(shape.size()), stride_dst(shape.size()); + stride_dst[0] = 1; + for (int i = 1; i < shape.size(); i++) { + stride_dst[i] = stride_dst[i-1] * shape[i-1]; + } + size_t this_stride = 1; + for (int i = 0; i < shape.size(); i++) { + for (int j = 0; j < shape.size(); j++) { + if (perm[j] == i) { + stride_src[i] = this_stride; + this_stride *= shape[j]; + } + } + } + + cudaSafeCall(cudaMalloc((void **)&src, size * sizeof(float))); + cudaSafeCall(cudaMalloc((void **)&dst, size * sizeof(float))); + + float *src_host, *dst_host; + src_host = (float *)malloc(size * sizeof(float)); + dst_host = (float *)malloc(size * sizeof(float)); + for (size_t i = 0; i < size; i++) { + src_host[i] = i; + } + cudaSafeCall(cudaMemcpy(src, src_host, size * sizeof(float), cudaMemcpyHostToDevice)); + invoke_func_2(src, dst); + cudaSafeCall(cudaMemcpy(dst_host, dst, size * sizeof(float), cudaMemcpyDeviceToHost)); + bool flag = 0; + for (size_t i = 0; i < size; i++) { + size_t base = i; + size_t offset_src = 0; + for (int j = 0; j < shape.size(); j++) { + offset_src += base % shape[j] * stride_src[perm[j]]; + base /= shape[j]; + } + if (dst_host[i] != src_host[offset_src]) { + flag = 1; + std::cout << "[ERROR] at " << i << "," << offset_src << ":" << dst_host[i] << "," << src_host[offset_src] << std::endl; + break; + } + } + + if (!flag) { + std::cout << "[INFO] transpose correct." << std::endl; + } else { + std::cout << "[ERROR] transpose incorrect." << std::endl; + } + + float duration = 0; + cudaEvent_t st, ed; + cudaEventCreate(&st); + cudaEventCreate(&ed); + int cnt = 128; + for (int t = 0; t < cnt; t++) { + invoke_func_2(src, dst); + } + cudaEventRecord(st, 0); + for (int t = 0; t < cnt; t++) { + invoke_func_2(src, dst); + } + cudaEventRecord(ed, 0); + cudaEventSynchronize(st); + cudaEventSynchronize(ed); + cudaEventElapsedTime(&duration, st, ed); + std::cout << "[INFO] time: " << duration / cnt << std::endl; + double perf = double(size) * 8.0f * cnt / (duration * 1e-3) / 1024.0f / 1024.0f / 1024.0f; + std::cout << "[INFO] Perf: " << perf << "GB/s" << std::endl; + std::cout << "[Exit] successful." << std::endl; +} diff --git a/generated_code/transpose.cu b/generated_code/transpose.cu index b97f986e..059962b5 100644 --- a/generated_code/transpose.cu +++ b/generated_code/transpose.cu @@ -11,13 +11,13 @@ __global__ void kernel_func_0(float *tensor_ptr_2, float *tensor_ptr_3) { int offset_src_buf = loop_idx; offset_src += offset_src_buf % 32 * 32736; offset_src_buf /= 32; - offset_src += offset_src_buf % 33 * 32; + offset_src += offset_src_buf % 33 * 33; offset_src_buf /= 33; int offset_dst = 0; int offset_dst_buf = loop_idx; - offset_dst += offset_dst_buf % 32 * 992; + offset_dst += offset_dst_buf % 32 * 1024; offset_dst_buf /= 32; - offset_dst += offset_dst_buf % 33 * 31744; + offset_dst += offset_dst_buf % 33 * 33792; offset_dst_buf /= 33; #pragma unroll for (int inst_idx = 0; inst_idx < 31; inst_idx++) { @@ -26,13 +26,13 @@ __global__ void kernel_func_0(float *tensor_ptr_2, float *tensor_ptr_3) { } #pragma unroll for (int inst_idx = 0; inst_idx < 31; inst_idx++) { - smem[warp_id * 32 * 33 + inst_idx * 33 + lane_id] = buf[inst_idx]; + smem[group_id * 32 * 33 + inst_idx * 33 + lane_id] = buf[inst_idx]; } if (lane_id < 31) { #pragma unroll for (int inst_idx = 0; inst_idx < 32; inst_idx++) { buf[inst_idx] = - smem[warp_id * 32 * 33 + lane_id * 33 + inst_idx]; + smem[group_id * 32 * 33 + lane_id * 33 + inst_idx]; } } if (lane_id < 31) { diff --git a/generated_code/transpose_0.cu b/generated_code/transpose_0.cu new file mode 100644 index 00000000..e637469d --- /dev/null +++ b/generated_code/transpose_0.cu @@ -0,0 +1,56 @@ +#include "cuda_utils.h" +// Kernel +__global__ void kernel_func_0(float *input, float *output) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + int parallel_idx = blockIdx.x * 4 + warp_id; + float buf[4]; + for (int loop_idx = parallel_idx; loop_idx < 812; loop_idx += 320) { + int offset_input = 0; + int offset_input_buf = loop_idx; + offset_input += offset_input_buf % 7 * 128; + offset_input_buf /= 7; + offset_input += offset_input_buf % 58 * 1568; + offset_input_buf /= 58; + offset_input += offset_input_buf % 2 * 784; + offset_input_buf /= 2; + int offset_output = 0; + int offset_output_buf = loop_idx; + offset_output += offset_output_buf % 7 * 128; + offset_output_buf /= 7; + offset_output += offset_output_buf % 58 * 784; + offset_output_buf /= 58; + offset_output += offset_output_buf % 2 * 45472; + offset_output_buf /= 2; + if (loop_idx % 7 == 6) { + if (lane_id < 16) { + buf[0] = input[0 + offset_input + 0 * 32 + lane_id]; + } + } else { +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = + input[0 + offset_input + inst_idx * 32 + lane_id]; + } + } + // test + if (loop_idx % 7 == 6) { + if (lane_id < 16) { + output[0 + offset_output + 0 * 32 + lane_id] = buf[0]; + } + } else { +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + output[0 + offset_output + inst_idx * 32 + lane_id] = + buf[inst_idx]; + } + } + // test + } +} +void invoke_func_0(float *input, float *output) { + dim3 gridDim(80, 1); + dim3 blockDim(128, 1); + kernel_func_0<<>>(input, output); + cudaCheckError(); +} diff --git a/generated_code/transpose_1.cu b/generated_code/transpose_1.cu new file mode 100644 index 00000000..c2918d01 --- /dev/null +++ b/generated_code/transpose_1.cu @@ -0,0 +1,66 @@ +#include "cuda_utils.h" +// Kernel +__global__ void kernel_func_1(float *input, float *output) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + int parallel_idx = blockIdx.x * 4 + warp_id; + float buf[4]; + for (int loop_idx = parallel_idx; loop_idx < 464; loop_idx += 320) { + int offset_input = 0; + int offset_input_buf = loop_idx; + offset_input += offset_input_buf % 2 * 128; + offset_input_buf /= 2; + offset_input += offset_input_buf % 116 * 392; + offset_input_buf /= 116; + offset_input += offset_input_buf % 2 * 196; + offset_input_buf /= 2; + int offset_output = 0; + int offset_output_buf = loop_idx; + offset_output += offset_output_buf % 2 * 128; + offset_output_buf /= 2; + offset_output += offset_output_buf % 116 * 196; + offset_output_buf /= 116; + offset_output += offset_output_buf % 2 * 22736; + offset_output_buf /= 2; + if (loop_idx % 2 == 1) { +#pragma unroll + for (int inst_idx = 0; inst_idx < 2; inst_idx++) { + buf[inst_idx] = + input[0 + offset_input + inst_idx * 32 + lane_id]; + } + if (lane_id < 4) { + buf[2] = input[0 + offset_input + 2 * 32 + lane_id]; + } + } else { +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + buf[inst_idx] = + input[0 + offset_input + inst_idx * 32 + lane_id]; + } + } + // test + if (loop_idx % 2 == 1) { +#pragma unroll + for (int inst_idx = 0; inst_idx < 2; inst_idx++) { + output[0 + offset_output + inst_idx * 32 + lane_id] = + buf[inst_idx]; + } + if (lane_id < 4) { + output[0 + offset_output + 2 * 32 + lane_id] = buf[2]; + } + } else { +#pragma unroll + for (int inst_idx = 0; inst_idx < 4; inst_idx++) { + output[0 + offset_output + inst_idx * 32 + lane_id] = + buf[inst_idx]; + } + } + // test + } +} +void invoke_func_1(float *input, float *output) { + dim3 gridDim(80, 1); + dim3 blockDim(128, 1); + kernel_func_1<<>>(input, output); + cudaCheckError(); +} diff --git a/generated_code/transpose_2.cu b/generated_code/transpose_2.cu new file mode 100644 index 00000000..440f65c6 --- /dev/null +++ b/generated_code/transpose_2.cu @@ -0,0 +1,62 @@ +#include "cuda_utils.h" +// Kernel +__global__ void kernel_func_2(float *input, float *output) { + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + int parallel_idx = blockIdx.x * 4 + warp_id; + float buf[4]; + for (int loop_idx = parallel_idx; loop_idx < 464; loop_idx += 320) { + int offset_input = 0; + int offset_input_buf = loop_idx; + offset_input += offset_input_buf % 232 * 98; + offset_input_buf /= 232; + offset_input += offset_input_buf % 2 * 49; + offset_input_buf /= 2; + int offset_output = 0; + int offset_output_buf = loop_idx; + offset_output += offset_output_buf % 232 * 49; + offset_output_buf /= 232; + offset_output += offset_output_buf % 2 * 11368; + offset_output_buf /= 2; + if (loop_idx % 1 == 0) { +#pragma unroll + for (int inst_idx = 0; inst_idx < 1; inst_idx++) { + buf[inst_idx] = + input[0 + offset_input + inst_idx * 32 + lane_id]; + } + if (lane_id < 17) { + buf[1] = input[0 + offset_input + 1 * 32 + lane_id]; + } + } else { +#pragma unroll + for (int inst_idx = 0; inst_idx < 2; inst_idx++) { + buf[inst_idx] = + input[0 + offset_input + inst_idx * 32 + lane_id]; + } + } + // test + if (loop_idx % 1 == 0) { +#pragma unroll + for (int inst_idx = 0; inst_idx < 1; inst_idx++) { + output[0 + offset_output + inst_idx * 32 + lane_id] = + buf[inst_idx]; + } + if (lane_id < 17) { + output[0 + offset_output + 1 * 32 + lane_id] = buf[1]; + } + } else { +#pragma unroll + for (int inst_idx = 0; inst_idx < 2; inst_idx++) { + output[0 + offset_output + inst_idx * 32 + lane_id] = + buf[inst_idx]; + } + } + // test + } +} +void invoke_func_2(float *input, float *output) { + dim3 gridDim(80, 1); + dim3 blockDim(128, 1); + kernel_func_2<<>>(input, output); + cudaCheckError(); +} diff --git a/include/pfusion/common.h b/include/pfusion/common.h index 0cc8f7a7..6be7b6c3 100644 --- a/include/pfusion/common.h +++ b/include/pfusion/common.h @@ -10,6 +10,7 @@ namespace memb { enum OpType { + NONE = 0, EMPTY = 1, READ, WRITE, diff --git a/include/pfusion/memory_codegen.h b/include/pfusion/memory_codegen.h index 43f7fe9a..7eeb5c1b 100644 --- a/include/pfusion/memory_codegen.h +++ b/include/pfusion/memory_codegen.h @@ -5,7 +5,10 @@ namespace infini { class MemoryCodegen { private: - std::string generate(Graph graph); + std::string generateGraph(Graph graph); + std::string generateBias(const std::vector &shape); + std::string generateTranspose(const std::vector &shape, + const std::vector &perm); public: MemoryCodegen() {} @@ -17,5 +20,10 @@ class MemoryCodegen { void exportViT_LN(const std::string &filename); void exportViT_SM(const std::string &filename); void exportViT_GELU(const std::string &filename); + void exportBias(const std::string &filename, + const std::vector &shape); + void exportTranspose(const std::string &filename, + const std::vector &shape, + const std::vector &perm); }; } // namespace infini diff --git a/include/pfusion/meta_op.h b/include/pfusion/meta_op.h index 3d255d76..c8243ae2 100644 --- a/include/pfusion/meta_op.h +++ b/include/pfusion/meta_op.h @@ -7,23 +7,46 @@ namespace memb { class TensorMapping { private: - std::vector shape, map; + std::vector shape, stride, map; std::string name; public: - TensorMapping(const std::string _name, const std::vector &_shape, + TensorMapping(const std::string &_name, const std::vector &_shape, + const std::vector &_stride, const std::vector &_map) { name = "offset_" + _name; IT_ASSERT(_shape.size() > 0 && _shape.size() < 10); for (auto x : _shape) { shape.emplace_back(x); } + IT_ASSERT(_stride.size() > 0 && _stride.size() < 10); + for (auto x : _stride) { + stride.emplace_back(x); + } IT_ASSERT(_map.size() > 0 && _map.size() < 10); for (auto x : _map) { map.emplace_back(x); } } ~TensorMapping() {} + + static inline std::shared_ptr + buildWithMap(const std::string &name, const std::vector &shape, + const std::vector &map) { + std::vector stride(shape.size()); + stride[0] = 1; + for (size_t i = 1; i < stride.size(); i++) { + stride[i] = shape[i] * stride[i - 1]; + } + return std::make_shared(name, shape, stride, map); + } + + static inline std::shared_ptr + build(const std::string &name, const std::vector &shape, + const std::vector &stride, const std::vector &map) { + return std::make_shared(name, shape, stride, map); + } + inline std::string offset() { return name; } inline size_t getHash() { std::hash hasher; @@ -45,7 +68,8 @@ class TensorMapping { class MetaOp { public: int id; - int main_loop_st, main_loop_ed, numBlocks, numGroups, numReg, numSmem, numLanes; + int main_loop_st, main_loop_ed, numBlocks, numGroups, numReg, numSmem, + numLanes; std::vector> microOps; std::vector> ptrs; std::vector> mappings; @@ -77,6 +101,11 @@ class MetaOp { // TODO: check valid return true; }; + static std::shared_ptr + buildBiasOp(const std::vector &shape); + static std::shared_ptr + buildTransposeOp(const std::vector &shape, + const std::vector &perm); }; } // namespace memb \ No newline at end of file diff --git a/include/pfusion/micro_kernel/memory.h b/include/pfusion/micro_kernel/memory.h index 2756c825..7b524dc8 100644 --- a/include/pfusion/micro_kernel/memory.h +++ b/include/pfusion/micro_kernel/memory.h @@ -8,17 +8,36 @@ class MemoryOp : public MicroOp { size_t num, width; public: - MemoryOp(OpType _opType, std::shared_ptr _src, - std::shared_ptr _dst, size_t _num, size_t _width) + MemoryOp(const OpType _opType, const std::shared_ptr _src, + const std::shared_ptr _dst, const size_t _num, + const size_t _width, const std::vector &_cond) : num(_num), width(_width) { opType = _opType; ptrs = {_src, _dst}; + cond = _cond; } // bool checkValid() override; ~MemoryOp() {} + + static inline std::shared_ptr + build(const OpType opType, const std::shared_ptr src, + const std::shared_ptr dst, const size_t num, + const size_t width) { + return std::make_shared(opType, src, dst, num, width, + std::vector({})); + } + + static inline std::shared_ptr + build(const OpType opType, const std::shared_ptr src, + const std::shared_ptr dst, const size_t num, + const size_t width, const std::vector &cond) { + return std::make_shared(opType, src, dst, num, width, cond); + } + std::shared_ptr getSrc() { return ptrs[0]; } std::shared_ptr getDst() { return ptrs[1]; } std::string generate() override; + std::string generateWithCond(); inline void print() override { if (opType == READ) { std::cout << id << " " << getName(opType) << " " diff --git a/include/pfusion/micro_op.h b/include/pfusion/micro_op.h index 93ca2e19..0ac182ad 100644 --- a/include/pfusion/micro_op.h +++ b/include/pfusion/micro_op.h @@ -10,12 +10,12 @@ class MicroOp { size_t id; OpType opType; std::vector> ptrs; + std::vector cond; public: - MicroOp() { + MicroOp() : opType(NONE), cond(0) { static int microOpId = 0; id = microOpId++; - opType = OpType(0); } virtual ~MicroOp() {} diff --git a/include/pfusion/pointer.h b/include/pfusion/pointer.h index cc088cc9..491a4fbf 100644 --- a/include/pfusion/pointer.h +++ b/include/pfusion/pointer.h @@ -40,6 +40,17 @@ class Pointer { inline const std::string getName() { return name; } inline const std::string getOffset() { return offset; } inline const std::string generate() { return name + "[" + offset + "]"; } + inline const std::string generateWithInstIdx(std::string idx) { + std::string code = generate(); + size_t pos = 0, lengthA = 8, lengthB = idx.size(); + while ((pos = code.find("inst_idx", pos)) != std::string::npos) { + code.replace(pos, lengthA, idx); + pos += lengthB; + } + std::cout << "[INFO] " << idx << " " << lengthB << " " << code + << std::endl; + return code; + } inline bool equal(std::shared_ptr ptr) { if (name == ptr->getName() && offset == ptr->getOffset()) { IT_ASSERT(memType == ptr->getType()); diff --git a/src/pfusion/instantiate.cc b/src/pfusion/instantiate.cc index cdd3b22b..e2746477 100644 --- a/src/pfusion/instantiate.cc +++ b/src/pfusion/instantiate.cc @@ -33,21 +33,21 @@ instantiateUnary(const OpType opType, metaOp->numReg = 8; metaOp->numSmem = 0; - metaOp->mappings.emplace_back(std::make_shared( + metaOp->mappings.emplace_back(TensorMapping::buildWithMap( std::string("src"), std::vector({32 * 8, size / 32 / 8}), std::vector({1}))); metaOp->ptrs = ptrs; auto buf = Pointer::buildPtr(REG, "buf", "inst_idx"); - metaOp->microOps.emplace_back(std::make_shared( + metaOp->microOps.emplace_back(MemoryOp::build( READ, Pointer::buildPtr(ptrs[0], "offset_src + inst_idx * 32 + lane_id"), buf, 8, 32)); metaOp->microOps.emplace_back( std::make_shared(opType, buf, buf, 8, 32)); - metaOp->microOps.emplace_back(std::make_shared( + metaOp->microOps.emplace_back(MemoryOp::build( WRITE, buf, Pointer::buildPtr(ptrs[1], "offset_src + inst_idx * 32 + lane_id"), 8, 32)); @@ -72,7 +72,7 @@ instantiateBinary(const OpType opType, metaOp->numReg = 24; metaOp->numSmem = 0; - metaOp->mappings.emplace_back(std::make_shared( + metaOp->mappings.emplace_back(TensorMapping::buildWithMap( std::string("src"), std::vector({32 * 8, size / 32 / 8}), std::vector({1}))); @@ -81,17 +81,17 @@ instantiateBinary(const OpType opType, auto buf1 = Pointer::buildPtr(REG, "buf", "inst_idx + 8"); auto buf2 = Pointer::buildPtr(REG, "buf", "inst_idx + 16"); - metaOp->microOps.emplace_back(std::make_shared( + metaOp->microOps.emplace_back(MemoryOp::build( READ, Pointer::buildPtr(ptrs[0], "offset_src + inst_idx * 32 + lane_id"), buf0, 8, 32)); - metaOp->microOps.emplace_back(std::make_shared( + metaOp->microOps.emplace_back(MemoryOp::build( READ, Pointer::buildPtr(ptrs[1], "offset_src + inst_idx * 32 + lane_id"), buf1, 8, 32)); metaOp->microOps.emplace_back( std::make_shared(opType, buf0, buf1, buf2, 8, 32)); - metaOp->microOps.emplace_back(std::make_shared( + metaOp->microOps.emplace_back(MemoryOp::build( WRITE, buf2, Pointer::buildPtr(ptrs[2], "offset_src + inst_idx * 32 + lane_id"), 8, 32)); @@ -123,7 +123,7 @@ std::vector> instantiateTranspose( } } metaOp->mappings.emplace_back( - std::make_shared("src", srcShape, srcMap)); + TensorMapping::buildWithMap("src", srcShape, srcMap)); std::vector dstMap; for (size_t i = 0; i < shape.size(); i++) { @@ -132,7 +132,7 @@ std::vector> instantiateTranspose( } } metaOp->mappings.emplace_back( - std::make_shared("dst", shape, dstMap)); + TensorMapping::buildWithMap("dst", shape, dstMap)); metaOp->main_loop_st = 0; metaOp->main_loop_ed = parallelSize; @@ -160,7 +160,10 @@ std::vector> instantiateTranspose( // TODO: tiling is a metaOp or microOps? metaOp->ptrs = ptrs; - auto smem = Pointer::buildPtr(SRAM, "smem", "group_id * " + std::to_string(metaOp->numLanes) + " * " + std::to_string(metaOp->numLanes + 1)); + auto smem = + Pointer::buildPtr(SRAM, "smem", + "group_id * " + std::to_string(metaOp->numLanes) + + " * " + std::to_string(metaOp->numLanes + 1)); auto buf = Pointer::buildPtr(REG, "buf", "inst_idx"); for (int i = 0; i < numTileA; i++) { @@ -170,13 +173,13 @@ std::vector> instantiateTranspose( std::to_string(j * 32 * stride_src + i * 32) + "+" + "inst_idx * " + std::to_string(stride_src) + " + lane_id"); - metaOp->microOps.emplace_back(std::make_shared( - READ, src_ptr, buf, min(32u, shape[0]), - min(32, shape[perm[0]]))); - metaOp->microOps.emplace_back(std::make_shared( + metaOp->microOps.emplace_back( + MemoryOp::build(READ, src_ptr, buf, min(32u, shape[0]), + min(32, shape[perm[0]]))); + metaOp->microOps.emplace_back(MemoryOp::build( WRITE, buf, Pointer::buildPtr(smem, "inst_idx * 33 + lane_id"), min(32, shape[0]), min(32, shape[perm[0]]))); - metaOp->microOps.emplace_back(std::make_shared( + metaOp->microOps.emplace_back(MemoryOp::build( READ, Pointer::buildPtr(smem, "lane_id * 33 + inst_idx"), buf, min(32, shape[perm[0]]), min(32, shape[0]))); auto dst_ptr = Pointer::buildPtr( @@ -184,9 +187,9 @@ std::vector> instantiateTranspose( std::to_string(i * 32 * stride_dst + j * 32) + "+" + "inst_idx * " + std::to_string(stride_dst) + " + lane_id"); - metaOp->microOps.emplace_back(std::make_shared( - WRITE, buf, dst_ptr, min(32, shape[perm[0]]), - min(32, shape[0]))); + metaOp->microOps.emplace_back( + MemoryOp::build(WRITE, buf, dst_ptr, min(32, shape[perm[0]]), + min(32, shape[0]))); } } metaOps.emplace_back(metaOp); @@ -216,7 +219,7 @@ instantiateGather(const OpType opType, metaOp->numReg = 24; metaOp->numSmem = 0; - metaOp->mappings.emplace_back(std::make_shared( + metaOp->mappings.emplace_back(TensorMapping::buildWithMap( std::string("src"), std::vector({seq_size, par_size}), std::vector({1}))); @@ -252,7 +255,7 @@ instantiateReduce(const OpType opType, metaOp->numReg = inputShape[0] / 128; metaOp->numSmem = 0; - metaOp->mappings.emplace_back(std::make_shared( + metaOp->mappings.emplace_back(TensorMapping::buildWithMap( std::string("src"), std::vector({seq_size, par_size}), std::vector({1}))); @@ -283,7 +286,7 @@ instantiateBroadcast(const OpType opType, metaOp->numReg = 24; metaOp->numSmem = 0; - metaOp->mappings.emplace_back(std::make_shared( + metaOp->mappings.emplace_back(TensorMapping::buildWithMap( std::string("src"), std::vector({seq_size, par_size}), std::vector({1}))); diff --git a/src/pfusion/memory_codegen.cc b/src/pfusion/memory_codegen.cc index 16e775bf..19e40073 100644 --- a/src/pfusion/memory_codegen.cc +++ b/src/pfusion/memory_codegen.cc @@ -30,7 +30,7 @@ void exportCode(const std::string &filename, const std::string &code) { } void infini::MemoryCodegen::exportGraph(Graph graph, std::string filename) { - std::string code = generate(graph); + std::string code = generateGraph(graph); exportCode(filename, code); } @@ -64,6 +64,19 @@ void infini::MemoryCodegen::exportViT_GELU(const std::string &filename) { exportCode(filename, code); } +void infini::MemoryCodegen::exportBias(const std::string &filename, + const std::vector &shape) { + std::string code = generateBias(shape); + exportCode(filename, code); +} + +void infini::MemoryCodegen::exportTranspose(const std::string &filename, + const std::vector &shape, + const std::vector &perm) { + std::string code = generateTranspose(shape, perm); + exportCode(filename, code); +} + std::vector convertShape(const std::vector &_shape) { std::vector shape; for (int i = int(_shape.size()); i > 0; i--) { @@ -203,7 +216,7 @@ std::shared_ptr instantiateGraph(infini::Graph graph) { return searchGraph; } -std::string infini::MemoryCodegen::generate(Graph graph) { +std::string infini::MemoryCodegen::generateGraph(Graph graph) { auto searchGraph = instantiateGraph(graph); auto metaGraph = searchGraph->exportFirstMetaGraph(); std::string code = ""; @@ -217,3 +230,28 @@ std::string infini::MemoryCodegen::generate(Graph graph) { code += metaGraph->genInvokeFuncs(); return code; } + +std::string +infini::MemoryCodegen::generateBias(const std::vector &shape) { + auto metaGraph = std::make_shared(); + metaGraph->addOp(memb::MetaOp::buildBiasOp(shape)); + metaGraph->print(); + std::string code = ""; + code += metaGraph->genHeader(); + code += metaGraph->genKernelFuncs(); + code += metaGraph->genInvokeFuncs(); + return code; +} + +std::string +infini::MemoryCodegen::generateTranspose(const std::vector &shape, + const std::vector &perm) { + auto metaGraph = std::make_shared(); + metaGraph->addOp(memb::MetaOp::buildTransposeOp(shape, perm)); + metaGraph->print(); + std::string code = ""; + code += metaGraph->genHeader(); + code += metaGraph->genKernelFuncs(); + code += metaGraph->genInvokeFuncs(); + return code; +} diff --git a/src/pfusion/meta_op.cc b/src/pfusion/meta_op.cc index c7b09cac..5f7a858a 100644 --- a/src/pfusion/meta_op.cc +++ b/src/pfusion/meta_op.cc @@ -1,14 +1,11 @@ #include "pfusion/meta_op.h" +#include "pfusion/micro_kernel/binary.h" +#include "pfusion/micro_kernel/memory.h" namespace memb { std::string TensorMapping::genOffset() { std::string code = "int " + offset() + " = 0;\n"; - std::vector stride; - stride.emplace_back(1); - for (size_t i = 1; i < shape.size(); i++) { - stride.emplace_back(stride[i - 1] * shape[i - 1]); - } std::string bufName = name + "_buf"; code += "int " + bufName + " = loop_idx;\n"; @@ -131,7 +128,7 @@ std::shared_ptr MetaOp::merge(std::shared_ptr metaOp0, if (metaOp0->main_loop_st != metaOp1->main_loop_st || metaOp0->main_loop_ed != metaOp1->main_loop_ed || metaOp0->numBlocks != metaOp1->numBlocks || - metaOp0->numGroups != metaOp1->numGroups || + metaOp0->numGroups != metaOp1->numGroups || metaOp0->numLanes != metaOp1->numLanes) { return nullptr; } @@ -188,4 +185,119 @@ std::shared_ptr MetaOp::merge(std::shared_ptr metaOp0, return metaOp; } +std::shared_ptr MetaOp::buildBiasOp(const std::vector &shape) { + IT_ASSERT(shape.size() == 2); + auto metaOp = std::make_shared(); + metaOp->main_loop_st = 0; + metaOp->main_loop_ed = shape[1] * (shape[0] / 32 / 4); + metaOp->numBlocks = 80; + metaOp->numGroups = 4; + metaOp->numLanes = 32; + metaOp->numReg = 4; + metaOp->numSmem = 0; + + metaOp->mappings.emplace_back(TensorMapping::build( + std::string("input"), + std::vector({32 * 4, (shape[0] - 1) / (32 * 4) + 1, shape[1]}), + std::vector({1, 32 * 4, shape[0]}), + std::vector({1, 2}))); + metaOp->mappings.emplace_back(TensorMapping::buildWithMap( + std::string("bias"), std::vector({shape[0], shape[1]}), + std::vector({1}))); + + metaOp->ptrs = std::vector>(); + auto &ptrs = metaOp->ptrs; + ptrs.emplace_back(Pointer::buildPtr(DRAM, "input")); + ptrs.emplace_back(Pointer::buildPtr(DRAM, "bias")); + ptrs.emplace_back(Pointer::buildPtr(DRAM, "output")); + + auto buf_input = Pointer::buildPtr(REG, "buf", "inst_idx"); + auto buf_bias = Pointer::buildPtr(REG, "buf", "4"); + auto buf_output = Pointer::buildPtr(REG, "buf", "inst_idx"); + + // @cond group_id * 4 * 32 + inst_idx * 32 + lane_id < shape[0] + metaOp->microOps.emplace_back(MemoryOp::build( + READ, + Pointer::buildPtr(ptrs[0], "offset_input + inst_idx * 32 + lane_id"), + buf_input, 4, 32)); + metaOp->microOps.emplace_back(MemoryOp::build( + READ, Pointer::buildPtr(ptrs[1], "offset_bias"), buf_bias, 1, 32)); + metaOp->microOps.emplace_back(std::make_shared( + ADD, buf_input, buf_bias, buf_output, 4, 32)); + // @cond group_id * 4 * 32 + inst_idx * 32 + lane_id < shape[0] + metaOp->microOps.emplace_back(MemoryOp::build( + WRITE, buf_output, + Pointer::buildPtr(ptrs[1], "offset_input + inst_idx * 32 + lane_id"), 8, + 32)); + return metaOp; +} + +std::shared_ptr +MetaOp::buildTransposeOp(const std::vector &shape, + const std::vector &perm) { + IT_ASSERT(perm[0] == 0 && shape[0] >= 32); + IT_ASSERT(shape.size() == 3); + auto metaOp = std::make_shared(); + size_t numInst, extraDim; + std::vector map_shape, map_stride; + if (shape[0] <= 4 * 32) { + numInst = (shape[0] - 1) / 32 + 1; + extraDim = 1; + metaOp->mappings.emplace_back(TensorMapping::build( + std::string("input"), + std::vector({shape[0], shape[perm[1]], shape[perm[2]]}), + std::vector({1, shape[0], shape[0] * shape[perm[1]]}), + std::vector({perm[1], perm[2]}))); + metaOp->mappings.emplace_back(TensorMapping::build( + std::string("output"), + std::vector({shape[0], shape[1], shape[2]}), + std::vector({1, shape[0], shape[0] * shape[1]}), + std::vector({1, 2}))); + // cond: local_id < shape[0]; + } else { + numInst = 4; + extraDim = (shape[0] - 1) / 128 + 1; + metaOp->mappings.emplace_back(TensorMapping::build( + std::string("input"), + std::vector( + {128, extraDim, shape[perm[1]], shape[perm[2]]}), + std::vector({1, 128, shape[0], shape[0] * shape[perm[1]]}), + std::vector({1, perm[1] + 1, perm[2] + 1}))); + metaOp->mappings.emplace_back(TensorMapping::build( + std::string("output"), + std::vector({128, extraDim, shape[1], shape[2]}), + std::vector({1, 128, shape[0], shape[0] * shape[1]}), + std::vector({1, 2, 3}))); + // cond loop_idx % extraDim * 128 + local_id < shape[0]; + } + metaOp->main_loop_st = 0; + metaOp->main_loop_ed = shape[1] * shape[2] * extraDim; + metaOp->numBlocks = 80; + metaOp->numGroups = 4; + metaOp->numLanes = 32; + metaOp->numReg = 4; + metaOp->numSmem = 0; + + metaOp->ptrs = std::vector>(); + auto &ptrs = metaOp->ptrs; + ptrs.emplace_back(Pointer::buildPtr(DRAM, "input")); + ptrs.emplace_back(Pointer::buildPtr(DRAM, "output")); + + auto buf_input = Pointer::buildPtr(REG, "buf", "inst_idx"); + auto buf_output = Pointer::buildPtr(REG, "buf", "inst_idx"); + + // @cond group_id * 4 * 32 + inst_idx * 32 + lane_id < shape[0] + std::vector cond = {shape[0], extraDim, 128}; + auto inPtr = + Pointer::buildPtr(ptrs[0], "offset_input + inst_idx * 32 + lane_id"); + auto opRead = MemoryOp::build(READ, inPtr, buf_input, numInst, 32, cond); + auto outPtr = + Pointer::buildPtr(ptrs[1], "offset_output + inst_idx * 32 + lane_id"); + auto opWrite = + MemoryOp::build(WRITE, buf_output, outPtr, numInst, 32, cond); + metaOp->microOps = std::vector>({opRead, opWrite}); + + return metaOp; +} + } // namespace memb \ No newline at end of file diff --git a/src/pfusion/micro_kernel/memory.cc b/src/pfusion/micro_kernel/memory.cc index 122bd5ec..d16deec2 100644 --- a/src/pfusion/micro_kernel/memory.cc +++ b/src/pfusion/micro_kernel/memory.cc @@ -1,9 +1,79 @@ #include "pfusion/micro_kernel/memory.h" namespace memb { -std::string MemoryOp::generate() { - std::string code; +std::string MemoryOp::generateWithCond() { + IT_ASSERT(cond.size() == 3); + std::string code = ""; + int edge_length = cond[0] % cond[2]; + int edge_num = edge_length / width; + int edge_width = edge_length % width; + + if (edge_num > 0 || edge_width > 0) { + code += "if (loop_idx % " + std::to_string(cond[1]) + + " == " + std::to_string(cond[1] - 1) + ") {\n"; + } + + if (edge_num > 0) { + code += "#pragma unroll\n"; + code += "for (int inst_idx = 0; inst_idx < " + + std::to_string(edge_num) + "; inst_idx++) {\n"; + if ((opType == OpType::READ && getSrc()->getType() != MemType::REG && + getDst()->getType() == MemType::REG) || + (opType == OpType::WRITE && getSrc()->getType() == MemType::REG && + getDst()->getType() != MemType::REG)) { + code += getDst()->generate() + " = " + getSrc()->generate() + ";\n"; + } else { + IT_ASSERT(false); + } + code += "}\n"; + } + + if (edge_width > 0) { + code += "if (lane_id < " + std::to_string(edge_width) + ") {"; + if ((opType == OpType::READ && getSrc()->getType() != MemType::REG && + getDst()->getType() == MemType::REG) || + (opType == OpType::WRITE && getSrc()->getType() == MemType::REG && + getDst()->getType() != MemType::REG)) { + code += getDst()->generateWithInstIdx(std::to_string(edge_num)) + + " = " + + getSrc()->generateWithInstIdx(std::to_string(edge_num)) + + ";\n"; + } else { + IT_ASSERT(false); + } + code += "}\n"; + } + + if (edge_num > 0 || edge_width > 0) { + code += "} else {\n"; + } + + code += "#pragma unroll\n"; + code += "for (int inst_idx = 0; inst_idx < " + std::to_string(num) + + "; inst_idx++) {\n"; + if ((opType == OpType::READ && getSrc()->getType() != MemType::REG && + getDst()->getType() == MemType::REG) || + (opType == OpType::WRITE && getSrc()->getType() == MemType::REG && + getDst()->getType() != MemType::REG)) { + code += getDst()->generate() + " = " + getSrc()->generate() + ";\n"; + } else { + IT_ASSERT(false); + } + + code += "}\n"; + if (edge_num > 0 || edge_width > 0) { + code += "}\n"; + } + code += "// test\n"; + return code; +} + +std::string MemoryOp::generate() { + if (cond.size() != 0) { + return generateWithCond(); + } + std::string code; if (width < 32) { code += "if (lane_id < " + std::to_string(width) + ") {\n"; } diff --git a/test/pfusion/.test_sar_drn.cc.swp b/test/pfusion/.test_sar_drn.cc.swp deleted file mode 100644 index 9d6274f4dc1c103f14d745c784ca22b8f1db70ae..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI2y>HV%6u@6$;-i457?F-tNGS?&oRoq>sZ_OqI+RFh_*lTY*jLwJpDp`NZK(=7 zD`JI#g^7=ejsF2LBK`o>GO;l7>^P#WQZ;R*3edCii*?WMUM@d9TZ(e+%AJ{6cBNbe zJWc^TTf1jFUk`(Gb`%2scqt9~C~;d;uLk@KGPT0$maY$%Z^BbwKV z7s_m4B|NsASe9*j=c1HB2FO641_q&4H7jZIaN!&qI(@s(d6Y&5$N(8217v^EZkjuyrGJpW`;{?Du@&$Q?D5Qd{4+6YF9HfDq zM}8d#_=bEy-XU+17YIkjke|l@-Xj5W0Xc)L;krH{uaQMWBMM2^dmo{L43GgbKnBPF z86X2>;BPXp$_#@^9*893wqRbi?Gs=X7#?=$2v^9mSc<|}S?RFomRKR*{b`wHTNg&I zX1TK>2;=3c&W!HN&L#Ka1b3W9bk|c04Xq-rlJPm+ZnPgx6ji0fl!;XHI=zyc!q*$J z|71E`@mcQsp_ogJ%8)z~%GDbuFXjgwo;k*L>i_#s!d?h3jk zs{1ct?!QE3?b;|{&rV|F9IQzwdifFV280V@*8XO?Cgr0-laX_Z9i({28 * 28, 24})); +} + +TEST(Graph, bias_1) { + MemoryCodegen codegen; + codegen.exportBias("bias_1.cu", std::vector({28 * 28, 58})); +} + +TEST(Graph, bias_2) { + MemoryCodegen codegen; + codegen.exportBias("bias_2.cu", std::vector({14 * 14, 116})); +} + +TEST(Graph, bias_3) { + MemoryCodegen codegen; + codegen.exportBias("bias_3.cu", std::vector({7 * 7, 232})); +} + +} // namespace infini diff --git a/test/pfusion/test_transpose.cc b/test/pfusion/test_transpose.cc index 62f55295..116f3124 100644 --- a/test/pfusion/test_transpose.cc +++ b/test/pfusion/test_transpose.cc @@ -1,23 +1,21 @@ -#include "core/blob.h" -#include "core/graph.h" -#include "core/runtime.h" -#include "operators/matmul.h" -#include "operators/transpose.h" -#include "operators/unary.h" #include "pfusion/memory_codegen.h" #include "test.h" namespace infini { -TEST(Graph, transpose) { - Runtime runtime = CpuRuntimeObj::getInstance(); - Graph g = make_ref(runtime); - Tensor t0 = g->addTensor({32, 31, 33, 32}, DataType::Float32); - Tensor t1 = g->addTensor({33, 32, 32, 31}, DataType::Float32); - g->dataMalloc(); - g->addOpWithOutputs(t0, t1, Shape{2, 0, 3, 1}); +TEST(Graph, transpose_0) { MemoryCodegen codegen; - codegen.exportGraph(g, "transpose.cu"); + codegen.exportTranspose("transpose_0.cu", {28 * 28, 58, 2}, {0, 2, 1}); +} + +TEST(Graph, transpose_1) { + MemoryCodegen codegen; + codegen.exportTranspose("transpose_1.cu", {14 * 14, 116, 2}, {0, 2, 1}); +} + +TEST(Graph, transpose_2) { + MemoryCodegen codegen; + codegen.exportTranspose("transpose_2.cu", {7 * 7, 232, 2}, {0, 2, 1}); } } // namespace infini