Operators g2bmm&gbmm transplantation (#24)

* Function tune and corresponding testcase.

*Add: Tune function in /src/kernel/cuda/conv.cc and corresponding testcase in test_conv.

*Fix: A little bug of perfRecord using in /src/core/runtime.cc.

* Tune part debug

*Add: recover the code, fixed the commit error.

*Add: some annotations in tune function

* clang format test

* Fix: mem leak in CUDA Runtime and Conv

* Fix: sync in conv and default sync in timeit

* Change the way to tune operator conv.

Timeit function cudNNUnfused -> Timeit function cudnnConvolutionForward.

* Change: merge the common part of cudnnunfused&tune into cudnndescriptoraccess

* clang test

* clang-format

* clang-format bash.

* Added operator G2BMM and corresponding testcase.

*Added files related to operator G2BMM creating&calling.

*Added custom_ops.cuh&custom_op.h.

* Add operator GBMML

* new version

* Fix: G2BMM and GBMM kernel bugs

* Added testcase of operator GBMML

* clang format

* Added cmake option REQUIRE_GCC9

* Delete redundant file

* Renamed class GBMML into GBMM

* clang format

* Reviewed.

* Added cudahostcompiler option.

* Add: explicit CMAKE_CUDA_HOST_COMPILER

* Rename gbmm kernel

* Fix: nvcc warning in GBMM and G2BMM

Co-authored-by: wcz112 <wcz19@mails.tsinghua.edu.cn>
Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
Anmuliar 2022-09-08 21:31:35 +08:00 committed by GitHub
parent e1d43202d7
commit 0409eafb5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 6246 additions and 12 deletions

View File

@ -1,5 +1,4 @@
# TODO: check the minimum cmake version
cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP)
cmake_minimum_required(VERSION 3.10) # Required by CMAKE_CUDA_HOST_COMPILER
include(CMakeDependentOption)
project(InfiniTensor C CXX)
@ -16,6 +15,7 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
@ -67,11 +67,12 @@ if(USE_BACKTRACE)
endif()
if(USE_CUDA)
# set(CUDA_HOST_COMPILER /home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc)
# Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong
set(CMAKE_CUDA_HOST_COMPILER
${CMAKE_CXX_COMPILER}
CACHE STRING "Set cuda host compiler path")
enable_language(CUDA)
# TODO: how to set option for CUDA_HOST_COMPILER. Now env var CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc takes effect.
# option(CUDA_HOST_COMPILER "" ${CMAKE_C_COMPILER})
# TODO: find_package seems unnecessary for CMake >= 3.8
# TODO: find_package seems unnecessary for CMake >= 3.8
find_package(CUDA REQUIRED)
# message("CUBLAS_LIBRARIES: ${CUDA_LIBRARIES}")
target_link_libraries(InfiniTensor cudnn curand cublas ${CUDA_LIBRARIES})

View File

@ -10,7 +10,7 @@ enum class OpType {
Matmul,
ConvTrans,
G2BMM,
GBMML,
GBMM,
Pad,
Slice,
Concat,
@ -53,7 +53,7 @@ class OpRegistry {
FOP(Matmul);
FOP(ConvTrans);
FOP(G2BMM);
FOP(GBMML);
FOP(GBMM);
FOP(Pad);
FOP(Slice);
FOP(Concat);

5802
include/custom_ops.cuh Normal file

File diff suppressed because it is too large Load Diff

14
include/custom_ops.h Normal file
View File

@ -0,0 +1,14 @@
#ifndef CUSTOM_OPS_H
#define CUSTOM_OPS_H
namespace infini {
// Host-callable entry point for the G2BMM custom CUDA op. The definition in
// src/custom_ops.cu forwards every argument unchanged to `sg2bmm`.
// The G2BMM CUDA kernel calls it with q/k = the two operator inputs and
// y = the operator output, followed by the values of G2BMMObj::getBMKWD().
// NOTE(review): the exact meaning of bs/n/m/w/d is defined by the kernel in
// custom_ops.cuh, which is not visible here -- confirm before relying on it.
void _sg2bmm(float *__restrict__ q, float *__restrict__ k,
float *__restrict__ y, int bs, int n, int m, int w, int d);
// Host-callable entry point for the GBMM custom CUDA op. The definition in
// src/custom_ops.cu forwards every argument unchanged to `sgbmml`.
// NOTE(review): this symbol keeps the older "gbmml" spelling although the
// operator class was renamed to GBMM -- presumably intentional, confirm.
void _sgbmml(float *__restrict__ q, float *__restrict__ k,
float *__restrict__ y, int bs, int n, int m, int w, int d);
} // namespace infini
#endif // CUSTOM_OPS_H

51
include/operators/G2BMM.h Normal file
View File

@ -0,0 +1,51 @@
#pragma once
#include "core/operator.h"
#include <assert.h>
namespace infini {
/**
 * @brief Operator object describing a G2BMM computation on two rank-3 inputs.
 *
 * The output shape produced by `inferShape` is {b, m, 2 * width + 1}, where
 * b and m are the first two dimensions of input A.
 */
class G2BMMObj : public OperatorObj {
  private:
    int width, dilation;
    ActType act;
    // Cached in the constructor from A's dims: b = dims[0], m = dims[1],
    // k = dims[2].
    int b, m, k;

  public:
    /**
     * @brief Construct a G2BMM operator. The constructor can create the
     * output tensor for the operator or not, which depends on `graph`.
     *
     * @param graph If graph is not empty, create outputs in the constructor.
     * Otherwise, check the provided shape with the results of `inferShape` in
     * `checkValid`.
     * @param A First input tensor.
     * @param B Second input tensor.
     * @param C C is the output of G2BMM. If outputs are going to be created in
     * the constructor, C should be an empty Ref.
     * @param width Band half-width; the last output dimension is
     * 2 * width + 1.
     * @param dilation Dilation forwarded to the kernel.
     * @param bias Optional bias. The current constructor does not register it
     * as an input.
     * @param act Activation type stored on the operator.
     */
    G2BMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int width,
             const int dilation, Tensor bias = nullptr,
             ActType act = ActType::None);

    std::string toString() const override;
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    int numInputs() const override { return 2; }
    int numOutputs() const override { return 1; }

    int getWidth() const { return width; }
    int getDilation() const { return dilation; }

    /**
     * @brief Return the bias tensor, or an empty Ref when no bias input is
     * registered.
     *
     * The constructor only stores {A, B} as inputs, so indexing inputs[2]
     * unconditionally would read past the end of the input list.
     */
    Tensor getBias() const {
        if (inputs.size() > 2)
            return inputs[2];
        return nullptr;
    }

    ActType getAct() const { return act; }

    int getB() const { return b; }
    int getM() const { return m; }
    int getK() const { return k; }

    /// Bundle (b, m, k, width, dilation) for use with structured bindings.
    auto getBMKWD() const { return tuple{b, m, k, width, dilation}; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini

49
include/operators/GBMM.h Normal file
View File

@ -0,0 +1,49 @@
#pragma once
#include "core/operator.h"
#include <assert.h>
namespace infini {
/**
 * @brief Operator object describing a GBMM computation on two rank-3 inputs.
 *
 * `inferShape` requires the last dimension of A to be odd and yields an
 * output of shape {A.dims[0], A.dims[1], B.dims[2]}.
 */
class GBMMObj : public OperatorObj {
  private:
    int dilation;
    ActType act;
    // Cached in the constructor: b = A.dims[0], m = A.dims[1],
    // w = (A.dims[2] - 1) / 2, n = B.dims[2].
    int b, m, w, n;

  public:
    /**
     * @brief Construct a GBMM operator. The constructor can create the output
     * tensor for the operator or not, which depends on `graph`.
     *
     * @param graph If graph is not empty, create outputs in the constructor.
     * Otherwise, check the provided shape with the results of `inferShape` in
     * `checkValid`.
     * @param A First input tensor.
     * @param B Second input tensor.
     * @param C C is the output of GBMM. If outputs are going to be created in
     * the constructor, C should be an empty Ref.
     * @param dilation Dilation forwarded to the kernel.
     * @param bias Optional bias. The current constructor does not register it
     * as an input.
     * @param act Activation type stored on the operator.
     */
    GBMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int dilation,
            Tensor bias = nullptr, ActType act = ActType::None);

    std::string toString() const override;
    optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;

    int numInputs() const override { return 2; }
    int numOutputs() const override { return 1; }

    int getDilation() const { return dilation; }

    /**
     * @brief Return the bias tensor, or an empty Ref when no bias input is
     * registered.
     *
     * The constructor only stores {A, B} as inputs, so indexing inputs[2]
     * unconditionally would read past the end of the input list.
     */
    Tensor getBias() const {
        if (inputs.size() > 2)
            return inputs[2];
        return nullptr;
    }

    ActType getAct() const { return act; }

    int getB() const { return b; }
    int getM() const { return m; }
    int getW() const { return w; }
    int getN() const { return n; }

    /// Bundle (b, m, w, n, dilation) for use with structured bindings.
    auto getBMWND() const { return tuple{b, m, w, n, dilation}; }

  private:
    vector<int> getWorkloadVector() const override;
    vector<int> getOpAttrVector() const override;
};
} // namespace infini

View File

@ -19,7 +19,7 @@ bool OperatorObj::isConcatOp() const { return type == OpType::Concat; }
bool OperatorObj::isComputeOp() const {
return type == OpType::Conv || type == OpType::Matmul ||
type == OpType::ConvTrans || type == OpType::G2BMM ||
type == OpType::GBMML;
type == OpType::GBMM;
}
bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
@ -53,6 +53,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
auto optShapes = inferShape();
if (!optShapes) // shape inference failed
return false;
const vector<Shape> &shapes = *optShapes;
if (shapes.size() != outputs.size())
return false;

View File

@ -48,8 +48,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
if (profiling)
IT_TODO_HALT();
runWithoutSync(graph, tune);
runWithoutSync(graph, tune, profiling);
sync();
}

17
src/custom_ops.cu Normal file
View File

@ -0,0 +1,17 @@
#include "custom_ops.cuh"
#include "custom_ops.h"
namespace infini {
// Thin wrapper that forwards all arguments, unchanged and in the same order,
// to `sg2bmm`. Neither "custom_ops.cuh" nor "custom_ops.h" is a system header;
// `sg2bmm` is presumably defined in custom_ops.cuh (not visible here).
// NOTE(review): the wrapper presumably exists so that plain C++ translation
// units can call the op through custom_ops.h without including the .cuh file
// -- confirm.
void _sg2bmm(float *__restrict__ q, float *__restrict__ k,
float *__restrict__ y, int bs, int n, int m, int w, int d) {
sg2bmm(q, k, y, bs, n, m, w, d);
}
// Thin wrapper that forwards all arguments, unchanged and in the same order,
// to `sgbmml`. `sgbmml` is presumably defined in custom_ops.cuh (not visible
// here).
// NOTE(review): the "gbmml" spelling predates the GBMML -> GBMM class rename;
// presumably kept on purpose, confirm.
void _sgbmml(float *__restrict__ q, float *__restrict__ k,
float *__restrict__ y, int bs, int n, int m, int w, int d) {
sgbmml(q, k, y, bs, n, m, w, d);
}
} // namespace infini

64
src/kernels/cuda/G2BMM.cc Normal file
View File

@ -0,0 +1,64 @@
#include "operators/G2BMM.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "custom_ops.h"
#include <chrono>
#include <functional>
#include <tuple>
namespace infini {
/**
 * @brief CUDA kernel wrapper for the G2BMM operator.
 *
 * Despite the class name, the computation is dispatched to the custom op
 * `_sg2bmm` declared in custom_ops.h.
 */
class G2BMMCudnn : public Kernel {
    /**
     * @brief Launch the G2BMM custom op on the operator's raw data pointers.
     *
     * @param op The G2BMM operator providing inputs, output and sizes.
     * @param context Runtime the kernel is launched for (currently unused).
     * @return Always true; no synchronisation is performed here.
     */
    bool g2bmmKernel(const Ref<G2BMMObj> &op,
                     const CudaRuntimeObj *context) const {
        float *const inAData = (op->getInputs(0)->getRawDataPtr<float *>());
        float *const inBData = (op->getInputs(1)->getRawDataPtr<float *>());
        // A third input (bias) is not supported yet.
        if (op->getInputs().size() > 2)
            IT_TODO_HALT();
        float *const outData = (op->getOutput()->getRawDataPtr<float *>());

        // Names follow G2BMMObj::getBMKWD(); the values are passed
        // positionally to _sg2bmm(q, k, y, bs, n, m, w, d).
        const auto [b, m, k, width, dilation] = op->getBMKWD();
        _sg2bmm(inAData, inBData, outData, b, m, k, width, dilation);
        // checkCudaError(cudaDeviceSynchronize());
        return true;
    }

    /// Run the kernel with a default-constructed PerfRecord.
    void compute(const Operator &op, const RuntimeObj *context) const override {
        PerfRecord record;
        compute(op, record, context);
    }

    /**
     * @brief Time the kernel and return the measurement as a PerfRecord.
     *
     * Fewer warm-up/timing rounds are used when the operator's b is larger
     * than 100.
     */
    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
        PerfRecord record;
        auto op = as<G2BMMObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        // The sync callback below dereferences `context`; fail loudly instead
        // of dereferencing a null pointer if the runtime is not a CUDA one.
        IT_ASSERT(context != nullptr,
                  "G2BMM CUDA kernel requires a CudaRuntimeObj");

        record.time = std::numeric_limits<double>::max();
        const auto [warmupRounds, timingRounds] =
            op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
        double tmp =
            timeit([&]() { g2bmmKernel(op, context); },
                   [&]() { context->sync(); }, warmupRounds, timingRounds);
        if (tmp < record.time)
            record.time = tmp;
        IT_ASSERT(record.time < std::numeric_limits<double>::max(),
                  "Error occured "
                  "during runtime");
        return record;
    }

    /// Run the kernel; `_record` is accepted for the interface but not used.
    void compute(const Operator &_op, const PerfRecord &_record,
                 const RuntimeObj *_context) const override {
        auto op = as<G2BMMObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        bool success = g2bmmKernel(op, context);
        IT_ASSERT(success);
    }
};
// Registers G2BMMCudnn for (Device::CUDA, OpType::G2BMM, DataType::Float32).
// NOTE(review): the registered name mentions cuDNN, but the kernel class calls
// the custom op _sg2bmm rather than a cuDNN routine -- confirm the naming.
REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, DataType::Float32, G2BMMCudnn,
"G2BMM_cuDNN_CUDA_Float32");
} // namespace infini

65
src/kernels/cuda/GBMM.cc Normal file
View File

@ -0,0 +1,65 @@
#include "operators/GBMM.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "custom_ops.h"
#include <chrono>
#include <functional>
#include <tuple>
namespace infini {
/**
 * @brief CUDA kernel wrapper for the GBMM operator.
 *
 * Despite the class name, the computation is dispatched to the custom op
 * `_sgbmml` declared in custom_ops.h.
 */
class GBMMCudnn : public Kernel {
    /**
     * @brief Launch the GBMM custom op on the operator's raw data pointers.
     *
     * @param op The GBMM operator providing inputs, output and sizes.
     * @param context Runtime the kernel is launched for (currently unused).
     * @return Always true; no synchronisation is performed here.
     */
    bool gbmmKernel(const Ref<GBMMObj> &op,
                    const CudaRuntimeObj *context) const {
        float *const inAData = (op->getInputs(0)->getRawDataPtr<float *>());
        float *const inBData = (op->getInputs(1)->getRawDataPtr<float *>());
        // A third input (bias) is not supported yet.
        if (op->getInputs().size() > 2)
            IT_TODO_HALT();
        float *const outData = (op->getOutput()->getRawDataPtr<float *>());

        // getBMWND() yields (b, m, w, n, dilation); _sgbmml takes its size
        // arguments as (bs, n, m, w, d), hence n is passed before w.
        const auto [b, m, w, n, dilation] = op->getBMWND();
        // printf("%d %d %d %d %d\n", b, m, n, w, dilation);
        _sgbmml(inAData, inBData, outData, b, m, n, w, dilation);
        // checkCudaError(cudaDeviceSynchronize());
        return true;
    }

    /// Run the kernel with a default-constructed PerfRecord.
    void compute(const Operator &op, const RuntimeObj *context) const override {
        PerfRecord record;
        compute(op, record, context);
    }

    /**
     * @brief Time the kernel and return the measurement as a PerfRecord.
     *
     * Fewer warm-up/timing rounds are used when the operator's b is larger
     * than 100.
     */
    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
        PerfRecord record;
        auto op = as<GBMMObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        // The sync callback below dereferences `context`; fail loudly instead
        // of dereferencing a null pointer if the runtime is not a CUDA one.
        IT_ASSERT(context != nullptr,
                  "GBMM CUDA kernel requires a CudaRuntimeObj");

        record.time = std::numeric_limits<double>::max();
        const auto [warmupRounds, timingRounds] =
            op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
        double tmp =
            timeit([&]() { gbmmKernel(op, context); },
                   [&]() { context->sync(); }, warmupRounds, timingRounds);
        if (tmp < record.time)
            record.time = tmp;
        IT_ASSERT(record.time < std::numeric_limits<double>::max(),
                  "Error occured "
                  "during runtime");
        return record;
    }

    /// Run the kernel; `_record` is accepted for the interface but not used.
    void compute(const Operator &_op, const PerfRecord &_record,
                 const RuntimeObj *_context) const override {
        auto op = as<GBMMObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        bool success = gbmmKernel(op, context);
        IT_ASSERT(success);
    }
};
// Registers GBMMCudnn for (Device::CUDA, OpType::GBMM, DataType::Float32).
// NOTE(review): the registered name mentions cuDNN, but the kernel class calls
// the custom op _sgbmml rather than a cuDNN routine -- confirm the naming.
REGISTER_KERNEL(Device::CUDA, OpType::GBMM, DataType::Float32, GBMMCudnn,
"GBMM_cuDNN_CUDA_Float32");
} // namespace infini

View File

@ -244,7 +244,6 @@ class convCudnn : public Kernel {
// Update the tune result
if (ret.time > record.time)
ret = record;
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));

50
src/operators/G2BMM.cc Normal file
View File

@ -0,0 +1,50 @@
#include "operators/G2BMM.h"
#include "custom_ops.h"
namespace infini {
// Builds a G2BMM operator with inputs {A, B} and output {C}, caches
// b/m/k from the three dimensions of A, and asserts that checkValid(graph)
// succeeds.
// `bias` is accepted but ignored: it is not added to the input list.
// NOTE(review): A->getDims()[0..2] is read in the initializer list, i.e.
// before checkValid() runs; the rank-3 check lives in inferShape. Presumably
// callers always pass rank-3 tensors -- confirm.
G2BMMObj::G2BMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, int width,
int dilation, [[maybe_unused]] Tensor bias, ActType act)
: OperatorObj(OpType::G2BMM, {A, B}, {C}), width(width), dilation(dilation),
act(act), b(A->getDims()[0]), m(A->getDims()[1]), k(A->getDims()[2]) {
IT_ASSERT(checkValid(graph));
}
// Renders the operator as a single-line description: attributes, the GUIDs of
// both inputs and the output, then the problem sizes.
string G2BMMObj::toString() const {
    std::ostringstream repr;

    // Attribute section.
    repr << "G2BMM([width=" << width << ",act=" << enum_to_underlying(act)
         << "]";

    // GUIDs of the tensors attached to this operator.
    repr << ",A=" << inputs[0]->getGuid();
    repr << ",B=" << inputs[1]->getGuid();
    repr << ",C=" << outputs[0]->getGuid();

    // Problem sizes: b, m, width, last dimension of B, dilation.
    repr << ", TTbmnkd: " << b << ", " << m << ", " << width << ", "
         << inputs[1]->getDims()[2] << ", " << dilation << ")";

    return repr.str();
}
// Computes the output shape {b, m, 2 * width + 1} from the two inputs.
// Returns an empty optional when either input is not rank 3, when any of the
// three dimensions differ between the inputs, or when width is negative.
optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const {
    const auto &dimsA = inputs[0]->getDims();
    const auto &dimsB = inputs[1]->getDims();

    if (dimsA.size() != 3 || dimsB.size() != 3)
        return {};

    // All three dimensions must agree between A and B.
    for (int axis = 0; axis < 3; ++axis) {
        if (dimsA[axis] != dimsB[axis])
            return {};
    }

    if (width < 0)
        return {};

    int batch = dimsA[0];
    int rows = dimsA[1];
    int band = 2 * width + 1;
    return {{{batch, rows, band}}};
}
// Workload key: operator type, the cached sizes b/m/k, width, dilation and
// the activation type.
vector<int> G2BMMObj::getWorkloadVector() const {
    vector<int> workload{enum_to_underlying(type), b,        m,
                         k,                        width,    dilation,
                         enum_to_underlying(act)};
    return workload;
}
// Attribute key: operator type, width, dilation and the activation type.
vector<int> G2BMMObj::getOpAttrVector() const {
    vector<int> attrs{enum_to_underlying(type), width, dilation,
                      enum_to_underlying(act)};
    return attrs;
}
} // namespace infini

48
src/operators/GBMM.cc Normal file
View File

@ -0,0 +1,48 @@
#include "operators/GBMM.h"
#include "custom_ops.h"
namespace infini {
// Builds a GBMM operator with inputs {A, B} and output {C}, and asserts that
// checkValid(graph) succeeds. Cached sizes: b and m are A's first two
// dimensions, w is (A.dims[2] - 1) / 2 and n is B's last dimension.
// `bias` is accepted but ignored: it is not added to the input list.
// NOTE(review): the dims of A and B are indexed in the initializer list, i.e.
// before checkValid() runs; the rank-3 check lives in inferShape. Presumably
// callers always pass rank-3 tensors -- confirm.
GBMMObj::GBMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, int dilation,
[[maybe_unused]] Tensor bias, ActType act)
: OperatorObj(OpType::GBMM, {A, B}, {C}), dilation(dilation), act(act),
b(A->getDims()[0]), m(A->getDims()[1]), w((A->getDims()[2] - 1) / 2),
n(B->getDims()[2]) {
IT_ASSERT(checkValid(graph));
}
// Renders the operator as a single-line description: the activation
// attribute, the GUIDs of both inputs and the output, then the sizes
// b, m, w, n and the dilation.
string GBMMObj::toString() const {
    std::ostringstream os;
    // The attribute list used to start with a stray comma ("GBMM([,act=").
    // The cast is now a named cast instead of a C-style one; the printed
    // value is unchanged.
    os << "GBMM(["
       << "act=" << static_cast<int>(act) << "],A=" << inputs[0]->getGuid()
       << ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
       << ", TTbmwnd: " << this->getB() << ", " << this->getM() << ", "
       << this->getW() << ", " << this->getN() << ", " << this->getDilation()
       << ")";
    return os.str();
}
// Computes the output shape {A.dims[0], A.dims[1], B.dims[2]}.
// Returns an empty optional when either input is not rank 3, when the first
// two dimensions differ between the inputs, or when A's last dimension is
// even.
optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const {
    const auto &dimsA = inputs[0]->getDims();
    const auto &dimsB = inputs[1]->getDims();

    if (dimsA.size() != 3 || dimsB.size() != 3)
        return {};

    // The two leading dimensions must agree between A and B.
    if (dimsA[0] != dimsB[0] || dimsA[1] != dimsB[1])
        return {};

    // The last dimension of A must be odd.
    if (dimsA[2] % 2 == 0)
        return {};

    int batch = dimsA[0];
    int rows = dimsA[1];
    int cols = dimsB[2];
    return {{{batch, rows, cols}}};
}
// Workload key: operator type, the cached sizes b/m/w/n, dilation and the
// activation type.
vector<int> GBMMObj::getWorkloadVector() const {
    vector<int> workload{enum_to_underlying(type), b,        m,
                         w,                        n,        dilation,
                         enum_to_underlying(act)};
    return workload;
}
// Attribute key: operator type, dilation and the activation type.
vector<int> GBMMObj::getOpAttrVector() const {
    vector<int> attrs{enum_to_underlying(type), dilation,
                      enum_to_underlying(act)};
    return attrs;
}
} // namespace infini

View File

@ -0,0 +1,37 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/G2BMM.h"
#include "test.h"
namespace infini {
using ExpectOutput = vector<float>;
// Builds a G2BMM operator on CUDA from two CPU-initialised tensors, checks
// the inferred output shape, then executes the graph once.
TEST(G2BMM, ShapeInference) {
    // Problem configuration.
    const int bs = 1;
    const int seqlen = 10000;
    const int w = 1000;
    const int featlen = 512;
    const int heads = 8;
    const int d = 4;
    const int hidden = featlen;
    const int hiddenPerHead = hidden / heads;

    // Host side: allocate both inputs and fill them with incremental data.
    auto hostRuntime = CpuRuntimeObj::getInstance();
    Graph hostGraph = make_ref<GraphObj>(hostRuntime);
    auto hostA = hostGraph->addTensor(
        Shape{bs * heads, seqlen, hiddenPerHead}, DataType::Float32);
    auto hostB = hostGraph->addTensor(
        Shape{bs * heads, seqlen, hiddenPerHead}, DataType::Float32);
    hostGraph->dataMalloc();
    hostA->setData(IncrementalGenerator());
    hostB->setData(IncrementalGenerator());

    // Device side: clone the inputs into a CUDA graph and add the operator.
    auto deviceRuntime = make_ref<CudaRuntimeObj>();
    auto deviceGraph = make_ref<GraphObj>(deviceRuntime);
    auto deviceA = deviceGraph->cloneTensor(hostA);
    auto deviceB = deviceGraph->cloneTensor(hostB);
    auto g2bmmOp = deviceGraph->addOp<G2BMMObj>(deviceA, deviceB, nullptr, w, d);

    EXPECT_EQ(g2bmmOp->getOutput()->getDims(),
              (Shape{bs * heads, seqlen, 2 * w + 1}));

    deviceGraph->dataMalloc();
    deviceRuntime->run(deviceGraph);
}
}; // namespace infini

View File

@ -0,0 +1,37 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/GBMM.h"
#include "test.h"
namespace infini {
using ExpectOutput = vector<float>;
// Builds a GBMM operator on CUDA from two CPU-initialised tensors, checks the
// inferred output shape, then executes the graph once.
TEST(GBMM, ShapeInference) {
    // Problem configuration.
    const int bs = 1;
    const int seqlen = 10000;
    const int w = 1000;
    const int featlen = 512;
    const int heads = 8;
    const int d = 4;
    const int hidden = featlen;
    const int hiddenPerHead = hidden / heads;

    // Host side: A has an odd last dimension (2 * w + 1); B holds the
    // per-head features. Both are filled with incremental data.
    auto hostRuntime = CpuRuntimeObj::getInstance();
    Graph hostGraph = make_ref<GraphObj>(hostRuntime);
    auto hostA = hostGraph->addTensor(Shape{bs * heads, seqlen, w * 2 + 1},
                                      DataType::Float32);
    auto hostB = hostGraph->addTensor(
        Shape{bs * heads, seqlen, hiddenPerHead}, DataType::Float32);
    hostGraph->dataMalloc();
    hostA->setData(IncrementalGenerator());
    hostB->setData(IncrementalGenerator());

    // Device side: clone the inputs into a CUDA graph and add the operator.
    auto deviceRuntime = make_ref<CudaRuntimeObj>();
    auto deviceGraph = make_ref<GraphObj>(deviceRuntime);
    auto deviceA = deviceGraph->cloneTensor(hostA);
    auto deviceB = deviceGraph->cloneTensor(hostB);
    auto gbmmOp = deviceGraph->addOp<GBMMObj>(deviceA, deviceB, nullptr, d);

    EXPECT_EQ(gbmmOp->getOutput()->getDims(),
              (Shape{bs * heads, seqlen, hiddenPerHead}));

    deviceGraph->dataMalloc();
    deviceRuntime->run(deviceGraph);
}
} // namespace infini