forked from jiuyuan/InfiniTensor
Operators g2bmm&gbmm transplantation (#24)
* Function tune and corresponding testcase. *Add: Tune function in /src/kernel/cuda/conv.cc and corresponding testcase in test_conv. *Fix: A little bug of perfRecord using in /src/core/runtime.cc. * Tune part debug *Add: recover the code, fixed the commit error. *Add: some anotations in tune function * clang formmat test * Fix: mem leak in CUDA Runtime and Conv * Fix: sync in conv and default sync in timeit * Change the way to tune operator conv. Timeit function cudNNUnfused -> Timeit function cudnnConvolutionForward. * Change: merge the common part of cudnnunfused&tune into cudnndescriptoraccess * clang test * clang-format * clang-format bash. * Added operator G2BMM and corresponding testcase. *Added files related to operator G2BMM creating&calling. *Added custom_ops.cuh&custom_op.h. * Add operator GBMML * new version * Fix: G2BMM and GBMM kernel bugs * Added testcase of operator GBMML * clang format * Added cmake option REQUIRE_GCC9 * Delete redundent file * Renamed class GBMML into GBMM * clang format * Reviewed. * Added cudahostcompier option. * Add: explicit CMAKE_CUDA_HOST_COMPILER * Rename gbmm kernel * Fix: nvcc warning in GBMM and G2BMM Co-authored-by: wcz112 <wcz19@mails.tsinghua.edu.cn> Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
e1d43202d7
commit
0409eafb5f
|
@ -1,5 +1,4 @@
|
||||||
# TODO: check the minimum cmake version
|
cmake_minimum_required(VERSION 3.10) # Required by CMAKE_CUDA_HOST_COMPILER
|
||||||
cmake_minimum_required(VERSION 3.9) # Required by find_package(OpenMP)
|
|
||||||
include(CMakeDependentOption)
|
include(CMakeDependentOption)
|
||||||
project(InfiniTensor C CXX)
|
project(InfiniTensor C CXX)
|
||||||
|
|
||||||
|
@ -16,6 +15,7 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
||||||
|
|
||||||
set(CMAKE_CXX_STANDARD 17)
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
||||||
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
|
||||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||||
|
@ -67,10 +67,11 @@ if(USE_BACKTRACE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_CUDA)
|
if(USE_CUDA)
|
||||||
# set(CUDA_HOST_COMPILER /home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc)
|
# Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong
|
||||||
|
set(CMAKE_CUDA_HOST_COMPILER
|
||||||
|
${CMAKE_CXX_COMPILER}
|
||||||
|
CACHE STRING "Set cuda host compiler path")
|
||||||
enable_language(CUDA)
|
enable_language(CUDA)
|
||||||
# TODO: how to set option for CUDA_HOST_COMPILER. Now env var CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc takes effect.
|
|
||||||
# option(CUDA_HOST_COMPILER "" ${CMAKE_C_COMPILER})
|
|
||||||
# TODO: find_package seems unnecessary for CMake >= 3.8
|
# TODO: find_package seems unnecessary for CMake >= 3.8
|
||||||
find_package(CUDA REQUIRED)
|
find_package(CUDA REQUIRED)
|
||||||
# message("CUBLAS_LIBRARIES: ${CUDA_LIBRARIES}")
|
# message("CUBLAS_LIBRARIES: ${CUDA_LIBRARIES}")
|
||||||
|
|
|
@ -10,7 +10,7 @@ enum class OpType {
|
||||||
Matmul,
|
Matmul,
|
||||||
ConvTrans,
|
ConvTrans,
|
||||||
G2BMM,
|
G2BMM,
|
||||||
GBMML,
|
GBMM,
|
||||||
Pad,
|
Pad,
|
||||||
Slice,
|
Slice,
|
||||||
Concat,
|
Concat,
|
||||||
|
@ -53,7 +53,7 @@ class OpRegistry {
|
||||||
FOP(Matmul);
|
FOP(Matmul);
|
||||||
FOP(ConvTrans);
|
FOP(ConvTrans);
|
||||||
FOP(G2BMM);
|
FOP(G2BMM);
|
||||||
FOP(GBMML);
|
FOP(GBMM);
|
||||||
FOP(Pad);
|
FOP(Pad);
|
||||||
FOP(Slice);
|
FOP(Slice);
|
||||||
FOP(Concat);
|
FOP(Concat);
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,14 @@
|
||||||
|
#ifndef CUSTOM_OPS_H
|
||||||
|
#define CUSTOM_OPS_H
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void _sg2bmm(float *__restrict__ q, float *__restrict__ k,
|
||||||
|
float *__restrict__ y, int bs, int n, int m, int w, int d);
|
||||||
|
|
||||||
|
void _sgbmml(float *__restrict__ q, float *__restrict__ k,
|
||||||
|
float *__restrict__ y, int bs, int n, int m, int w, int d);
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
#endif // CUSTOM_OPS_H
|
|
@ -0,0 +1,51 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
#include <assert.h>
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class G2BMMObj : public OperatorObj {
|
||||||
|
private:
|
||||||
|
// to be implemented
|
||||||
|
int width, dilation;
|
||||||
|
ActType act;
|
||||||
|
|
||||||
|
int b, m, k;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief This comments show how operators is defined in InfiniTensor. The
|
||||||
|
* constructor can create output tensors for the operator or not, which
|
||||||
|
* depends on `graph`.
|
||||||
|
*
|
||||||
|
* @param graph If graph is not empty, create outputs in the constructor.
|
||||||
|
* Otherwise, check the provided shape with the results of `inferShape` in
|
||||||
|
* `checkValid`.
|
||||||
|
* @param C C is the output of G2BMM. If outputs are going to be created in
|
||||||
|
* the constructor, C should be an empty Ref.
|
||||||
|
*/
|
||||||
|
G2BMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int width,
|
||||||
|
const int dilation, Tensor bias = nullptr,
|
||||||
|
ActType act = ActType::None);
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
int numInputs() const override { return 2; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
int getWidth() const { return width; }
|
||||||
|
int getDilation() const { return dilation; }
|
||||||
|
Tensor getBias() const { return inputs[2]; }
|
||||||
|
ActType getAct() const { return act; }
|
||||||
|
|
||||||
|
int getB() const { return b; }
|
||||||
|
int getM() const { return m; }
|
||||||
|
int getK() const { return k; }
|
||||||
|
auto getBMKWD() const { return tuple{b, m, k, width, dilation}; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,49 @@
|
||||||
|
#pragma once
|
||||||
|
#include "core/operator.h"
|
||||||
|
#include <assert.h>
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class GBMMObj : public OperatorObj {
|
||||||
|
private:
|
||||||
|
int dilation;
|
||||||
|
ActType act;
|
||||||
|
|
||||||
|
int b, m, w, n;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/**
|
||||||
|
* @brief This comments show how operators is defined in InfiniTensor. The
|
||||||
|
* constructor can create output tensors for the operator or not, which
|
||||||
|
* depends on `graph`.
|
||||||
|
*
|
||||||
|
* @param graph If graph is not empty, create outputs in the constructor.
|
||||||
|
* Otherwise, check the provided shape with the results of `inferShape` in
|
||||||
|
* `checkValid`.
|
||||||
|
* @param C C is the output of GBMM. If outputs are going to be created in
|
||||||
|
* the constructor, C should be an empty Ref.
|
||||||
|
*/
|
||||||
|
GBMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, const int dilation,
|
||||||
|
Tensor bias = nullptr, ActType act = ActType::None);
|
||||||
|
|
||||||
|
std::string toString() const override;
|
||||||
|
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||||
|
|
||||||
|
int numInputs() const override { return 2; }
|
||||||
|
int numOutputs() const override { return 1; }
|
||||||
|
|
||||||
|
int getDilation() const { return dilation; }
|
||||||
|
Tensor getBias() const { return inputs[2]; }
|
||||||
|
ActType getAct() const { return act; }
|
||||||
|
|
||||||
|
int getB() const { return b; }
|
||||||
|
int getM() const { return m; }
|
||||||
|
int getW() const { return w; }
|
||||||
|
int getN() const { return n; }
|
||||||
|
auto getBMWND() const { return tuple{b, m, w, n, dilation}; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
vector<int> getWorkloadVector() const override;
|
||||||
|
vector<int> getOpAttrVector() const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -19,7 +19,7 @@ bool OperatorObj::isConcatOp() const { return type == OpType::Concat; }
|
||||||
bool OperatorObj::isComputeOp() const {
|
bool OperatorObj::isComputeOp() const {
|
||||||
return type == OpType::Conv || type == OpType::Matmul ||
|
return type == OpType::Conv || type == OpType::Matmul ||
|
||||||
type == OpType::ConvTrans || type == OpType::G2BMM ||
|
type == OpType::ConvTrans || type == OpType::G2BMM ||
|
||||||
type == OpType::GBMML;
|
type == OpType::GBMM;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
|
bool OperatorObj::isTransposeOp() const { return type == OpType::Transpose; }
|
||||||
|
@ -53,6 +53,7 @@ bool OperatorObj::checkValid(GraphObj *graph) {
|
||||||
auto optShapes = inferShape();
|
auto optShapes = inferShape();
|
||||||
if (!optShapes) // shape inference failed
|
if (!optShapes) // shape inference failed
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
const vector<Shape> &shapes = *optShapes;
|
const vector<Shape> &shapes = *optShapes;
|
||||||
if (shapes.size() != outputs.size())
|
if (shapes.size() != outputs.size())
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -48,8 +48,7 @@ void CudaRuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||||
void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
||||||
if (profiling)
|
if (profiling)
|
||||||
IT_TODO_HALT();
|
IT_TODO_HALT();
|
||||||
|
runWithoutSync(graph, tune, profiling);
|
||||||
runWithoutSync(graph, tune);
|
|
||||||
sync();
|
sync();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
#include "custom_ops.cuh"
|
||||||
|
#include "custom_ops.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
void _sg2bmm(float *__restrict__ q, float *__restrict__ k,
|
||||||
|
float *__restrict__ y, int bs, int n, int m, int w, int d) {
|
||||||
|
sg2bmm(q, k, y, bs, n, m, w, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
void _sgbmml(float *__restrict__ q, float *__restrict__ k,
|
||||||
|
float *__restrict__ y, int bs, int n, int m, int w, int d) {
|
||||||
|
sgbmml(q, k, y, bs, n, m, w, d);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
#include "operators/G2BMM.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "custom_ops.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <functional>
|
||||||
|
#include <tuple>
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class G2BMMCudnn : public Kernel {
|
||||||
|
|
||||||
|
bool g2bmmKernel(const Ref<G2BMMObj> &op,
|
||||||
|
const CudaRuntimeObj *context) const {
|
||||||
|
float *const inAData = (op->getInputs(0)->getRawDataPtr<float *>());
|
||||||
|
float *const inBData = (op->getInputs(1)->getRawDataPtr<float *>());
|
||||||
|
if (op->getInputs().size() > 2)
|
||||||
|
IT_TODO_HALT();
|
||||||
|
|
||||||
|
float *const outData = (op->getOutput()->getRawDataPtr<float *>());
|
||||||
|
|
||||||
|
const auto [b, n, m, width, dilation] = op->getBMKWD();
|
||||||
|
|
||||||
|
_sg2bmm(inAData, inBData, outData, b, n, m, width, dilation);
|
||||||
|
// checkCudaError(cudaDeviceSynchronize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||||
|
PerfRecord record;
|
||||||
|
compute(op, record, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
PerfRecord tune(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
PerfRecord record;
|
||||||
|
auto op = as<G2BMMObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
record.time = std::numeric_limits<double>::max();
|
||||||
|
const auto [warmupRounds, timingRounds] =
|
||||||
|
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
|
||||||
|
double tmp =
|
||||||
|
timeit([&]() { g2bmmKernel(op, context); },
|
||||||
|
[&]() { context->sync(); }, warmupRounds, timingRounds);
|
||||||
|
if (tmp < record.time)
|
||||||
|
record.time = tmp;
|
||||||
|
IT_ASSERT(record.time < std::numeric_limits<double>::max(),
|
||||||
|
"Error occured "
|
||||||
|
"during runtime");
|
||||||
|
return record;
|
||||||
|
}
|
||||||
|
void compute(const Operator &_op, const PerfRecord &_record,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<G2BMMObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
bool success = g2bmmKernel(op, context);
|
||||||
|
IT_ASSERT(success);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::G2BMM, DataType::Float32, G2BMMCudnn,
|
||||||
|
"G2BMM_cuDNN_CUDA_Float32");
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,65 @@
|
||||||
|
#include "operators/GBMM.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "custom_ops.h"
|
||||||
|
#include <chrono>
|
||||||
|
#include <functional>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
class GBMMCudnn : public Kernel {
|
||||||
|
|
||||||
|
bool gbmmKernel(const Ref<GBMMObj> &op,
|
||||||
|
const CudaRuntimeObj *context) const {
|
||||||
|
float *const inAData = (op->getInputs(0)->getRawDataPtr<float *>());
|
||||||
|
float *const inBData = (op->getInputs(1)->getRawDataPtr<float *>());
|
||||||
|
if (op->getInputs().size() > 2)
|
||||||
|
IT_TODO_HALT();
|
||||||
|
|
||||||
|
float *const outData = (op->getOutput()->getRawDataPtr<float *>());
|
||||||
|
|
||||||
|
const auto [b, m, w, n, dilation] = op->getBMWND();
|
||||||
|
// printf("%d %d %d %d %d\n", b, m, n, w, dilation);
|
||||||
|
_sgbmml(inAData, inBData, outData, b, m, n, w, dilation);
|
||||||
|
// checkCudaError(cudaDeviceSynchronize());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||||
|
PerfRecord record;
|
||||||
|
compute(op, record, context);
|
||||||
|
}
|
||||||
|
|
||||||
|
PerfRecord tune(const Operator &_op,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
PerfRecord record;
|
||||||
|
auto op = as<GBMMObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
|
||||||
|
record.time = std::numeric_limits<double>::max();
|
||||||
|
const auto [warmupRounds, timingRounds] =
|
||||||
|
op->getB() > 100 ? tuple{1, 3} : tuple{5, 15};
|
||||||
|
double tmp =
|
||||||
|
timeit([&]() { gbmmKernel(op, context); },
|
||||||
|
[&]() { context->sync(); }, warmupRounds, timingRounds);
|
||||||
|
if (tmp < record.time)
|
||||||
|
record.time = tmp;
|
||||||
|
IT_ASSERT(record.time < std::numeric_limits<double>::max(),
|
||||||
|
"Error occured "
|
||||||
|
"during runtime");
|
||||||
|
return record;
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute(const Operator &_op, const PerfRecord &_record,
|
||||||
|
const RuntimeObj *_context) const override {
|
||||||
|
auto op = as<GBMMObj>(_op);
|
||||||
|
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||||
|
bool success = gbmmKernel(op, context);
|
||||||
|
IT_ASSERT(success);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
REGISTER_KERNEL(Device::CUDA, OpType::GBMM, DataType::Float32, GBMMCudnn,
|
||||||
|
"GBMM_cuDNN_CUDA_Float32");
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -244,7 +244,6 @@ class convCudnn : public Kernel {
|
||||||
// Update the tune result
|
// Update the tune result
|
||||||
if (ret.time > record.time)
|
if (ret.time > record.time)
|
||||||
ret = record;
|
ret = record;
|
||||||
|
|
||||||
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
|
||||||
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
checkCudnnError(cudnnDestroyActivationDescriptor(actDesc));
|
||||||
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
#include "operators/G2BMM.h"
|
||||||
|
#include "custom_ops.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
G2BMMObj::G2BMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, int width,
|
||||||
|
int dilation, [[maybe_unused]] Tensor bias, ActType act)
|
||||||
|
: OperatorObj(OpType::G2BMM, {A, B}, {C}), width(width), dilation(dilation),
|
||||||
|
act(act), b(A->getDims()[0]), m(A->getDims()[1]), k(A->getDims()[2]) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
string G2BMMObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "G2BMM(["
|
||||||
|
<< "width=" << width << ",act=" << enum_to_underlying(act)
|
||||||
|
<< "],A=" << inputs[0]->getGuid() << ",B=" << inputs[1]->getGuid()
|
||||||
|
<< ",C=" << outputs[0]->getGuid() << ", TTbmnkd: " << this->getB()
|
||||||
|
<< ", " << this->getM() << ", " << this->getWidth() << ", "
|
||||||
|
<< inputs[1]->getDims()[2] << ", " << this->getDilation() << ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>> G2BMMObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
auto A = inputs[0], B = inputs[1];
|
||||||
|
|
||||||
|
if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
|
||||||
|
return {};
|
||||||
|
if (!(A->getDims()[0] == B->getDims()[0]))
|
||||||
|
return {};
|
||||||
|
if (!(A->getDims()[1] == B->getDims()[1]))
|
||||||
|
return {};
|
||||||
|
if (!(A->getDims()[2] == B->getDims()[2]))
|
||||||
|
return {};
|
||||||
|
if (width < 0)
|
||||||
|
return {};
|
||||||
|
int b(A->getDims()[0]), m(A->getDims()[1]), n(2 * width + 1);
|
||||||
|
return {{{b, m, n}}};
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> G2BMMObj::getWorkloadVector() const {
|
||||||
|
return {enum_to_underlying(type), b, m, k, width, dilation,
|
||||||
|
enum_to_underlying(act)};
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> G2BMMObj::getOpAttrVector() const {
|
||||||
|
return {enum_to_underlying(type), width, dilation, enum_to_underlying(act)};
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,48 @@
|
||||||
|
#include "operators/GBMM.h"
|
||||||
|
#include "custom_ops.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
GBMMObj::GBMMObj(GraphObj *graph, Tensor A, Tensor B, Tensor C, int dilation,
|
||||||
|
[[maybe_unused]] Tensor bias, ActType act)
|
||||||
|
: OperatorObj(OpType::GBMM, {A, B}, {C}), dilation(dilation), act(act),
|
||||||
|
b(A->getDims()[0]), m(A->getDims()[1]), w((A->getDims()[2] - 1) / 2),
|
||||||
|
n(B->getDims()[2]) {
|
||||||
|
IT_ASSERT(checkValid(graph));
|
||||||
|
}
|
||||||
|
|
||||||
|
string GBMMObj::toString() const {
|
||||||
|
std::ostringstream os;
|
||||||
|
os << "GBMM(["
|
||||||
|
<< ",act=" << (int)act << "],A=" << inputs[0]->getGuid()
|
||||||
|
<< ",B=" << inputs[1]->getGuid() << ",C=" << outputs[0]->getGuid()
|
||||||
|
<< ", TTbmwnd: " << this->getB() << ", " << this->getM() << ", "
|
||||||
|
<< this->getW() << ", " << this->getN() << ", " << this->getDilation()
|
||||||
|
<< ")";
|
||||||
|
return os.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
optional<vector<Shape>> GBMMObj::inferShape(const TensorVec &inputs) const {
|
||||||
|
auto A = inputs[0], B = inputs[1];
|
||||||
|
|
||||||
|
if (!(A->getDims().size() == 3 && B->getDims().size() == 3))
|
||||||
|
return {};
|
||||||
|
if (!(A->getDims()[0] == B->getDims()[0]))
|
||||||
|
return {};
|
||||||
|
if (!(A->getDims()[1] == B->getDims()[1]))
|
||||||
|
return {};
|
||||||
|
if (A->getDims()[2] % 2 == 0)
|
||||||
|
return {};
|
||||||
|
int b(A->getDims()[0]), m(A->getDims()[1]), k(B->getDims()[2]);
|
||||||
|
return {{{b, m, k}}};
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> GBMMObj::getWorkloadVector() const {
|
||||||
|
return {enum_to_underlying(type), b, m, w, n, dilation,
|
||||||
|
enum_to_underlying(act)};
|
||||||
|
}
|
||||||
|
|
||||||
|
vector<int> GBMMObj::getOpAttrVector() const {
|
||||||
|
return {enum_to_underlying(type), dilation, enum_to_underlying(act)};
|
||||||
|
}
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,37 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/G2BMM.h"
|
||||||
|
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
using ExpectOutput = vector<float>;
|
||||||
|
|
||||||
|
TEST(G2BMM, ShapeInference) {
|
||||||
|
const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
|
||||||
|
const int hidden = featlen, hiddenPerHead = hidden / heads;
|
||||||
|
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
auto ACpu = gCpu->addTensor(Shape{bs * heads, seqlen, hiddenPerHead},
|
||||||
|
DataType::Float32);
|
||||||
|
auto BCpu = gCpu->addTensor(Shape{bs * heads, seqlen, hiddenPerHead},
|
||||||
|
DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
ACpu->setData(IncrementalGenerator());
|
||||||
|
BCpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||||
|
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||||
|
auto G2BMM = gCuda->addOp<G2BMMObj>(ACuda, BCuda, nullptr, w, d);
|
||||||
|
EXPECT_EQ(G2BMM->getOutput()->getDims(),
|
||||||
|
(Shape{bs * heads, seqlen, 2 * w + 1}));
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
}
|
||||||
|
|
||||||
|
}; // namespace infini
|
|
@ -0,0 +1,37 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/GBMM.h"
|
||||||
|
#include "test.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
using ExpectOutput = vector<float>;
|
||||||
|
|
||||||
|
TEST(GBMM, ShapeInference) {
|
||||||
|
const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
|
||||||
|
const int hidden = featlen, hiddenPerHead = hidden / heads;
|
||||||
|
auto cpuRuntime = CpuRuntimeObj::getInstance();
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||||
|
auto ACpu = gCpu->addTensor(Shape{bs * heads, seqlen, w * 2 + 1},
|
||||||
|
DataType::Float32);
|
||||||
|
auto BCpu = gCpu->addTensor(Shape{bs * heads, seqlen, hiddenPerHead},
|
||||||
|
DataType::Float32);
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
ACpu->setData(IncrementalGenerator());
|
||||||
|
BCpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||||
|
auto gCuda = make_ref<GraphObj>(cudaRuntime);
|
||||||
|
auto ACuda = gCuda->cloneTensor(ACpu);
|
||||||
|
auto BCuda = gCuda->cloneTensor(BCpu);
|
||||||
|
auto GBMM = gCuda->addOp<GBMMObj>(ACuda, BCuda, nullptr, d);
|
||||||
|
EXPECT_EQ(GBMM->getOutput()->getDims(),
|
||||||
|
(Shape{bs * heads, seqlen, hiddenPerHead}));
|
||||||
|
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cudaRuntime->run(gCuda);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
Loading…
Reference in New Issue