Add CUDA runtime (#6)

* Fix: add warm-up and repetition in timing

* Add: CUDA runtime and float support

* Refactor: Cuda and Cpu runtimes inherit Runtime

* Add: environment script for Lotus

* Add: Lotus build instructions

* Update README.md

Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
zhengly123 2022-08-22 15:01:03 +08:00 committed by GitHub
parent 9303ddda8e
commit 04ea5eed38
29 changed files with 875 additions and 127 deletions


@ -4,6 +4,7 @@ include(CMakeDependentOption)
project(InfiniTensor C CXX)
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
option(USE_CUDA "Support CUDA GPU" ON)
option(BUILD_TEST "Build tests" ON)
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
@ -21,7 +22,6 @@ find_package(
Python
COMPONENTS Interpreter Development
REQUIRED)
find_package(CUDA REQUIRED)
# OpenMP
find_package(OpenMP)
if(OpenMP_C_FOUND)
@ -53,14 +53,19 @@ if(BUILD_TEST)
endif()
file(GLOB_RECURSE SRC src/*.cc src/*.cu)
# file(GLOB_RECURSE FFI src/ffi/ffi_pet.cc)
# list(REMOVE_ITEM SRC ${TEST} ${FFI})
add_library(InfiniTensor SHARED ${SRC})
# Target
# cuda_add_library(it SHARED ${SRC})
# cuda_add_cublas_to_target(it) # cublas
# # target_link_libraries(infini_cpp cudnn curand nlohmann_json::nlohmann_json pybind11::embed)
if(USE_CUDA)
# set(CUDA_HOST_COMPILER /home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc)
enable_language(CUDA)
# TODO: how to set CUDA_HOST_COMPILER via a CMake option. For now the env var CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc takes effect.
# option(CUDA_HOST_COMPILER "" ${CMAKE_C_COMPILER})
# TODO: find_package seems unnecessary for CMake >= 3.8
find_package(CUDA REQUIRED)
# message("CUBLAS_LIBRARIES: ${CUDA_LIBRARIES}")
target_link_libraries(InfiniTensor cudnn curand cublas ${CUDA_LIBRARIES})
endif()
# # Python bindings
# pybind11_add_module(infini MODULE ${FFI})


@ -1 +1,9 @@
# InfiniTensor
## Compilation on Lotus
``` bash
# Enter the root of InfiniTensor
source test/script/env_lotus.sh
mkdir build && cd build
cmake .. && make -j 12
```
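
CUDA support is controlled by the `USE_CUDA` CMake option added in this commit (ON by default). As a hedged sketch using the `-DOPTION=VALUE` mechanism mentioned in CMakeLists.txt, a CPU-only build would look like:
``` bash
# Enter the root of InfiniTensor
mkdir build && cd build
cmake -DUSE_CUDA=OFF .. && make -j 12
```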

24 include/core/blob.h Normal file

@ -0,0 +1,24 @@
#pragma once
#include "core/common.h"
#include "core/ref.h"
namespace infini {
class RuntimeObj;
using Runtime = Ref<RuntimeObj>;
class BlobObj {
// Runtime might be replaced with a raw pointer for optimization
Runtime runtime;
void *ptr;
public:
BlobObj(Runtime runtime, void *ptr) : runtime(runtime), ptr(ptr) {}
BlobObj(BlobObj &other) = delete;
BlobObj &operator=(BlobObj const &) = delete;
~BlobObj();
template <typename T> T getPtr() const { return reinterpret_cast<T>(ptr); }
};
} // namespace infini
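
A `Blob` couples a raw pointer to the runtime that allocated it, so the out-of-line destructor can return the memory to the right allocator. A minimal usage sketch, assuming the `Runtime` and `make_ref` facilities declared elsewhere in this commit:
```cpp
// Hedged sketch: a Blob's lifetime is tied to its Runtime.
Runtime rt = make_ref<CpuRuntimeObj>(); // from core/runtime.h
Blob b = rt->allocBlob(1024);           // wraps rt->alloc(1024)
float *p = b->getPtr<float *>();        // typed view of the raw pointer
// When the last Ref<BlobObj> is released, ~BlobObj() calls rt->dealloc(ptr).
```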


@ -69,6 +69,8 @@ template <typename T> std::string vecToString(const std::vector<T> &vec) {
return ret;
}
double timeit(const std::function<void()> &func);
double timeit(const std::function<void()> &func, int warmupRounds = 200,
int timingRounds = 200,
const std::function<void(void)> &sync = {});
} // namespace infini


@ -4,16 +4,16 @@
namespace infini {
// TODO: graph should be attached to a context
class GraphObj : public Object {
protected:
Runtime runtime;
TensorVec tensors;
TensorVec inputs;
TensorVec outputs;
OpVec ops;
public:
// Graph(OpVec oplist);
GraphObj(Runtime runtime) : runtime(runtime){};
string toString() const override;
Tensor addTensor(Shape dim, DataType dtype = DataType::UInt32);


@ -5,8 +5,14 @@
namespace infini {
class RuntimeObj; // Forward declaration for Kernel::compute
struct PerfRecord {
double time; // in milliseconds
PerfRecord(){};
PerfRecord(double time) : time(time){};
virtual ~PerfRecord() {}
double time = 0; // in milliseconds
};
class Kernel {
@ -20,14 +26,16 @@ class Kernel {
* are required, inherit from PerfRecord and add extra parameters.
* Otherwise, use PerfRecord directly.
*/
virtual void compute(const Operator &op,
const PerfRecord &record) const = 0;
virtual void compute(const Operator &op, const PerfRecord &record,
const RuntimeObj *context) const = 0;
/**
* @brief Executes an op with a default parameter.
*/
virtual void compute(const Operator &op) const = 0;
virtual void compute(const Operator &op,
const RuntimeObj *context) const = 0;
// Premise: op must be idempotent since it is called multiple times.
virtual PerfRecord tune(const Operator &op) const = 0;
virtual PerfRecord tune(const Operator &op,
const RuntimeObj *context) const = 0;
};
class KernelRegistry {
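
Every kernel entry point now receives the runtime as an execution context. A hedged skeleton of a conforming kernel (the `NaiveConv`, `NaiveMatmul`, and `convCudnn` kernels below all follow this shape):
```cpp
// Hedged sketch: a Kernel implementation under the new interface.
class MyKernel : public Kernel {
    void compute(const Operator &op, const PerfRecord &record,
                 const RuntimeObj *context) const override {
        // ... run the op, possibly using tuned parameters from record ...
    }
    void compute(const Operator &op,
                 const RuntimeObj *context) const override {
        compute(op, {}, context); // fall back to a default PerfRecord
    }
    PerfRecord tune(const Operator &op,
                    const RuntimeObj *context) const override {
        return PerfRecord(timeit([&]() { compute(op, context); }));
    }
};
```
Kernels are then registered with REGISTER_KERNEL, as the cuDNN conv kernel near the bottom of this commit does.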


@ -1,26 +0,0 @@
#pragma once
#include "core/graph.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
namespace infini {
class RunEngine {
private:
Device device;
public:
RunEngine(Device device) : device(device) {}
~RunEngine() {}
void run(const Graph &graph, bool tune = false,
bool profiling = false) const;
double getPerfTime(const Graph &graph, bool profiling = false) const;
private:
void printProfilingData(double totTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const;
};
} // namespace infini

62 include/core/runtime.h Normal file

@ -0,0 +1,62 @@
#pragma once
#include "core/graph.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
namespace infini {
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
protected:
Device device;
public:
RuntimeObj(Device device) : device(device) {}
RuntimeObj(RuntimeObj &other) = delete;
RuntimeObj &operator=(RuntimeObj const &) = delete;
virtual ~RuntimeObj() {}
/**
* @brief Execute a graph.
*
* @param graph
* @param tune If there is no performance record, whether to tune the
* kernel. This could be made an independent method.
* @param profiling Whether to print breakdown of time
*/
virtual void run(const Graph &graph, bool tune = false,
bool profiling = false) const = 0;
virtual void *alloc(size_t size) = 0;
virtual void dealloc(void *ptr) = 0;
/**
* @brief Get the execution time of each operator from the performance
* records. No real execution happens.
*
* @param graph
* @param profiling Whether to print breakdown of time
* @return double Return the sum of perf time for each operator
*/
double getPerfTime(const Graph &graph, bool profiling = false) const;
Blob allocBlob(size_t size);
protected:
void printProfilingData(double totTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const;
};
// TODO: change inheritance relation
class CpuRuntimeObj : public RuntimeObj {
public:
CpuRuntimeObj() : RuntimeObj(Device::CPU) {}
void run(const Graph &graph, bool tune = false,
bool profiling = false) const override;
void dealloc(void *ptr) override { return free(ptr); };
void *alloc(size_t size) override {
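// Zero-initialized; rounds the byte count up to whole uint64_t words.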
return calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t),
sizeof(uint64_t));
};
};
} // namespace infini


@ -1,5 +1,6 @@
#pragma once
#include "core/tensor_base.h"
#include <cmath>
namespace infini {
@ -16,21 +17,46 @@ class TensorObj : public TensorBaseObj {
string toString() const override;
size_t size() const;
void dataMalloc();
Shape getDims() const { return shape; }
size_t getOffset(const Shape &ds) const;
using TensorBaseObj::getData;
VType getData(const Shape &pos) const;
void copyData(VType *dptr);
void dataMalloc(const Runtime &runtime);
// void copyData(VType *dptr);
template <typename T> void copyData(const T *dptr);
void copyData(vector<VType> dataVector);
void copyData(vector<float> dataVector);
void printData() const;
// TODO: merge these methods
bool equalData(const Tensor &rhs) const;
void
setData(std::function<void(void *, size_t, DataType)> generator) const {
generator((void *)(data.get()), size(), dtype);
template <typename T> bool equalData(const Tensor &rhs) const {
IT_ASSERT(data != nullptr);
IT_ASSERT(rhs->data != nullptr);
// TODO: deal with data type
auto ptr = data->getPtr<T *>();
auto ptrRhs = rhs->data->getPtr<T *>();
if (shape != rhs->getDims())
return false;
size_t sz = size();
for (size_t i = 0; i < sz; ++i)
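// element-wise comparison with relative error, tolerance 1e-6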
if (fabs(ptr[i] - ptrRhs[i]) /
std::max(fabs(ptr[i]), fabs(ptrRhs[i])) >
1e-6) {
printf("Error on %lu: %f %f\n", i, ptr[i], ptrRhs[i]);
return false;
}
return true;
}
void setData(
const std::function<void(void *, size_t, DataType)> &generator) const {
generator(data->getPtr<void *>(), size(), dtype);
}
private:
void printDataFloat() const;
void printDataUint32_t() const;
// void setDims(const Dim &dms) { dims = dms; }
// bool dataRand(int seed = 0) {


@ -1,19 +1,23 @@
#pragma once
#include "core/blob.h"
#include "core/object.h"
#include "core/ref.h"
namespace infini {
// class Tensor;
class TensorBaseObj;
class TensorObj;
class OperatorObj;
class GraphObj;
class RuntimeObj;
class BlobObj;
using TensorBase = Ref<TensorBaseObj>;
using Tensor = Ref<TensorObj>;
using Operator = Ref<OperatorObj>;
using Graph = Ref<GraphObj>;
using Runtime = Ref<RuntimeObj>;
using Blob = Ref<BlobObj>;
using TensorVec = vector<Tensor>;
using OpVec = vector<Operator>;
@ -40,8 +44,7 @@ class TensorBaseObj : public Object {
DataType dtype;
vector<WRef<TensorBaseObj>> inputOf;
WRef<TensorBaseObj> outputOf;
// TODO: Ref<void> -> Ref<Blob>
Ref<VType[]> data;
Blob data;
// ComputeState computed;
// static int random_seed[256 * 16];
// static bool random_inited;
@ -50,7 +53,15 @@ class TensorBaseObj : public Object {
TensorBaseObj(int dim, DataType dtype);
virtual ~TensorBaseObj() {}
Ref<VType[]> getDataPtr() const { return data; }
void dataMalloc(const Blob &blob) {
IT_ASSERT(data == nullptr);
data = blob;
}
Blob getDataPtr() const { return data; }
template <typename T> T getDataRawPtr() const {
IT_ASSERT(data != nullptr);
return data->getPtr<T>();
}
VType getData(size_t offset) const;
DataType getDType() const { return dtype; }

111 include/cuda/cuda_common.h Normal file

@ -0,0 +1,111 @@
#pragma once
#include "core/common.h"
#include <cublas_v2.h>
#include <cuda.h>
#include <cudnn.h>
#include <curand.h>
// TODO: replace with Exception (IT_ASSERT)
#define checkCudaError(call) \
{ \
auto err = call; \
if (cudaSuccess != err) { \
fprintf(stderr, "Cuda error in %s:%i : %s.\n", __FILE__, __LINE__, \
cudaGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#define checkCublasError(call) \
{ \
auto err = call; \
if (CUBLAS_STATUS_SUCCESS != err) { \
fprintf(stderr, "cuBLAS error in %s:%i : %s.\n", __FILE__, \
__LINE__, cublasGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#define checkCudnnError(call) \
{ \
auto err = call; \
if (CUDNN_STATUS_SUCCESS != err) { \
fprintf(stderr, "cuDNN error in %s:%i : %s.\n", __FILE__, \
__LINE__, cudnnGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
#define checkCurandError(call) \
{ \
auto err = call; \
if (CURAND_STATUS_SUCCESS != err) { \
fprintf(stderr, "cuRAND error in %s:%i : %s.\n", __FILE__, \
__LINE__, curandGetErrorString(err)); \
exit(EXIT_FAILURE); \
} \
}
namespace infini {
inline const char *cublasGetErrorString(cublasStatus_t error) {
switch (error) {
case CUBLAS_STATUS_SUCCESS:
return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED:
return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED:
return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE:
return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH:
return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR:
return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED:
return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR:
return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED:
return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR:
return "CUBLAS_STATUS_LICENSE_ERROR";
}
return "<unknown>";
}
inline const char *curandGetErrorString(curandStatus_t error) {
switch (error) {
case CURAND_STATUS_SUCCESS:
return "CURAND_STATUS_SUCCESS";
case CURAND_STATUS_VERSION_MISMATCH:
return "CURAND_STATUS_VERSION_MISMATCH";
case CURAND_STATUS_NOT_INITIALIZED:
return "CURAND_STATUS_NOT_INITIALIZED";
case CURAND_STATUS_ALLOCATION_FAILED:
return "CURAND_STATUS_ALLOCATION_FAILED";
case CURAND_STATUS_TYPE_ERROR:
return "CURAND_STATUS_TYPE_ERROR";
case CURAND_STATUS_OUT_OF_RANGE:
return "CURAND_STATUS_OUT_OF_RANGE";
case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
case CURAND_STATUS_LAUNCH_FAILURE:
return "CURAND_STATUS_LAUNCH_FAILURE";
case CURAND_STATUS_PREEXISTING_FAILURE:
return "CURAND_STATUS_PREEXISTING_FAILURE";
case CURAND_STATUS_INITIALIZATION_FAILED:
return "CURAND_STATUS_INITIALIZATION_FAILED";
case CURAND_STATUS_ARCH_MISMATCH:
return "CURAND_STATUS_ARCH_MISMATCH";
case CURAND_STATUS_INTERNAL_ERROR:
return "CURAND_STATUS_INTERNAL_ERROR";
}
return "<unknown>";
}
using CudaPtr = void *;
} // namespace infini
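
Typical use wraps each CUDA/cuDNN call directly, aborting with file/line context on any non-success status. A hedged example (not part of the diff):
```cpp
// Hedged usage sketch for the error-checking macros above.
float *devPtr = nullptr;
checkCudaError(cudaMalloc(&devPtr, 1024 * sizeof(float)));
cudnnHandle_t handle;
checkCudnnError(cudnnCreate(&handle));
checkCudnnError(cudnnDestroy(handle));
checkCudaError(cudaFree(devPtr));
```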

49 include/cuda/cuda_runtime.h Normal file

@ -0,0 +1,49 @@
#pragma once
#include "core/runtime.h"
#include "cuda/cuda_common.h"
namespace infini {
class CudaRuntimeObj : public RuntimeObj {
private:
cudnnHandle_t cudnn;
cublasHandle_t cublas;
CudaPtr workspace;
size_t workspaceSize;
public:
CudaRuntimeObj() : RuntimeObj(Device::CUDA) {
checkCudnnError(cudnnCreate(&cudnn));
checkCublasError(cublasCreate(&cublas));
// 10GB for Longformer
// size_t longformerNum = 3lu * (1 << 30);
workspaceSize = 7ll << 30; // 7 GB
workspace = alloc(workspaceSize);
}
virtual ~CudaRuntimeObj() {
checkCudnnError(cudnnDestroy(cudnn));
checkCublasError(cublasDestroy(cublas));
}
void run(const Graph &graph, bool tune = false,
bool profiling = false) const;
// double runEvaluation(const Graph &graph, int nWarmups,
// int nEvaluations) const;
void sync() const;
CudaPtr alloc(size_t size) override {
void *ptr;
checkCudaError(cudaMalloc(&ptr, size));
return ptr;
}
void dealloc(void *ptr) override { checkCudaError(cudaFree(ptr)); }
cudnnHandle_t cudnnHandle() const { return cudnn; }
cublasHandle_t cublasHandle() const { return cublas; }
CudaPtr getWorkspace(size_t size) const {
IT_ASSERT(size <= workspaceSize);
return workspace;
}
private:
void runWithoutSync(const Graph &graph) const;
};
} // namespace infini
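
The constructor eagerly creates the cuDNN/cuBLAS handles and a fixed 7 GB workspace. A hedged sketch of the intended call pattern (the cuDNN conv test at the end of this commit does exactly this):
```cpp
// Hedged sketch: driving a graph through the CUDA runtime.
auto cuda = make_ref<CudaRuntimeObj>(); // creates handles + 7 GB workspace
Graph g = make_ref<GraphObj>(cuda);
// ... addTensor/addOp, then g->dataMalloc() to allocate device memory ...
cuda->run(g); // runWithoutSync(graph) followed by sync()
```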

11 include/cuda/cuda_utility.h Normal file

@ -0,0 +1,11 @@
#include "core/tensor.h"
namespace infini {
void cudaPrintFloat(float *x, int len);
void cudaPrintTensor(const Tensor &tensor) {
cudaPrintFloat(tensor->getDataRawPtr<float *>(), tensor->size());
}
} // namespace infini


@ -5,12 +5,14 @@
namespace infini {
// TODO: isolate these classes
class DataGenerator {
private:
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); };
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); };
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }
public:
virtual ~DataGenerator() {}
void operator()(void *data, size_t size, DataType dataType) {
switch (dataType) {
case DataType::UInt32:
@ -26,10 +28,36 @@ class DataGenerator {
};
class IncrementalGenerator : public DataGenerator {
void fill(uint32_t *data, size_t size) override {
public:
virtual ~IncrementalGenerator() {}
private:
template <typename T> void fill(T *data, size_t size) {
for (size_t i = 0; i < size; i++) {
data[i] = i;
}
}
void fill(uint32_t *data, size_t size) override {
fill<uint32_t>(data, size);
}
void fill(float *data, size_t size) override { fill<float>(data, size); }
};
class OneGenerator : public DataGenerator {
public:
virtual ~OneGenerator() {}
private:
template <typename T> void fill(T *data, size_t size) {
for (size_t i = 0; i < size; i++) {
data[i] = 1;
}
}
void fill(uint32_t *data, size_t size) override {
fill<uint32_t>(data, size);
}
void fill(float *data, size_t size) override { fill<float>(data, size); }
};
} // namespace infini
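
New generators follow the same pattern: implement the fill logic once as a template and forward both typed overloads to it. A hedged sketch of a hypothetical constant-value generator (not part of this commit):
```cpp
// Hedged sketch: hypothetical generator following the pattern above.
class ConstGenerator : public DataGenerator {
  public:
    ConstGenerator(float v) : val(v) {}
    virtual ~ConstGenerator() {}

  private:
    float val;
    template <typename T> void fill(T *data, size_t size) {
        for (size_t i = 0; i < size; i++)
            data[i] = static_cast<T>(val);
    }
    void fill(uint32_t *data, size_t size) override {
        fill<uint32_t>(data, size);
    }
    void fill(float *data, size_t size) override { fill<float>(data, size); }
};
// usage: tensor->setData(ConstGenerator(3.0f));
```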

10 src/core/blob.cc Normal file

@ -0,0 +1,10 @@
#include "core/runtime.h"
namespace infini {
BlobObj::~BlobObj() {
// Avoid cyclic inclusion
runtime->dealloc(ptr);
}
} // namespace infini


@ -4,11 +4,20 @@
namespace infini {
double timeit(const std::function<void()> &func) {
double timeit(const std::function<void()> &func, int warmupRounds,
int timingRounds, const std::function<void(void)> &sync) {
for (int i = 0; i < warmupRounds; ++i)
func();
if (sync)
sync();
auto start = std::chrono::high_resolution_clock::now();
func();
for (int i = 0; i < timingRounds; ++i)
func();
if (sync)
sync();
auto end = std::chrono::high_resolution_clock::now();
return std::chrono::duration<double, std::milli>(end - start).count();
return std::chrono::duration<double, std::milli>(end - start).count() /
timingRounds;
}
} // namespace infini
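
For asynchronous devices the optional `sync` callback matters: it drains queued work once after warm-up and once before the final clock read, so the measured interval covers exactly `timingRounds` executions. A hedged usage sketch (`kernel`, `op`, and `context` stand in for whatever is being measured):
```cpp
// Hedged sketch: timing an async CUDA launch with an explicit sync hook.
double ms = timeit(
    [&]() { kernel->compute(op, context); }, // body under measurement
    200, 200,                                // warmupRounds, timingRounds
    []() { cudaDeviceSynchronize(); });      // drain the device queue
```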


@ -13,8 +13,9 @@ string GraphObj::toString() const {
}
void GraphObj::dataMalloc() {
for (auto &tensor : tensors)
tensor->dataMalloc();
for (auto &tensor : tensors) {
tensor->dataMalloc(runtime);
}
}
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {


@ -1,9 +1,14 @@
#include "core/run_enigne.h"
#include "core/runtime.h"
#include "core/blob.h"
#include <chrono>
#include <cuda.h>
#include <cuda_profiler_api.h>
#include <cudnn.h>
#include <curand.h>
namespace infini {
void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
if (!tune && profiling)
IT_TODO_HALT();
const auto &kernelRegistry = KernelRegistry::getInstance();
@ -12,7 +17,6 @@ void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
double totalTime = 0;
std::map<OpType, double> opTime;
std::map<OpType, int> opCnt;
std::chrono::system_clock::time_point begin, end;
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
@ -24,7 +28,7 @@ void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
// If no record and disable tuning, run with the default argument
if (!perfData && !tune) {
kernel->compute(op);
kernel->compute(op, this);
continue;
}
@ -32,16 +36,18 @@ void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
PerfRecord record;
// Tune the kernel if there is no record
if (!perfData) {
record = kernel->tune(op);
// TODO: record is not used
record = kernel->tune(op, this);
perfEngine.setPerfData(perfKey, record);
} else
record = *perfData;
if (!profiling) {
kernel->compute(op, *perfData);
kernel->compute(op, *perfData, this);
continue;
} else {
double t = timeit([&]() { kernel->compute(op, *perfData); });
double t =
timeit([&]() { kernel->compute(op, *perfData, this); }, 1, 1);
op->print();
printf(" op_time %lf\n", t);
totalTime += t;
@ -53,7 +59,7 @@ void RunEngine::run(const Graph &graph, bool tune, bool profiling) const {
printProfilingData(totalTime, opTime, opCnt);
}
double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
double RuntimeObj::getPerfTime(const Graph &graph, bool profiling) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto perfEngine = PerfEngine::getInstance();
// Statistics
@ -72,7 +78,7 @@ double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
PerfRecord record;
// Tune the kernel if there is no record
if (!perfData) {
record = kernel->tune(op);
record = kernel->tune(op, this);
perfEngine.setPerfData(perfKey, record);
} else
record = *perfData;
@ -91,9 +97,9 @@ double RunEngine::getPerfTime(const Graph &graph, bool profiling) const {
return totalTime;
}
void RunEngine::printProfilingData(double totalTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const {
void RuntimeObj::printProfilingData(double totalTime,
const std::map<OpType, double> &opTime,
const std::map<OpType, int> &opCnt) const {
printf("%11s %3s %7s %7s %7s\n", "Op", "Cnt", "T_tot", "Percent", "T_mean");
for (const auto &[type, t] : opTime) {
printf("%11s %3d %7.3f %7.1f %7.3f\n",
@ -102,4 +108,8 @@ void RunEngine::printProfilingData(double totalTime,
}
}
Blob RuntimeObj::allocBlob(size_t size) {
return make_ref<BlobObj>(shared_from_this(), alloc(size));
}
} // namespace infini
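
One caveat: `allocBlob` calls `shared_from_this()`, which is only valid when the runtime is already owned by a `shared_ptr` (a `Ref`). A hedged illustration:
```cpp
// Hedged sketch: allocBlob requires shared ownership of the runtime.
auto rt = make_ref<CpuRuntimeObj>(); // OK: owned by a Ref/shared_ptr
Blob b = rt->allocBlob(256);
// CpuRuntimeObj local; local.allocBlob(256); // UB: no shared owner exists
```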


@ -1,15 +1,12 @@
#include <core/tensor.h>
#include "core/tensor.h"
#include "core/blob.h"
#include "core/runtime.h"
namespace infini {
TensorObj::TensorObj(const Shape &shape, DataType dtype)
: TensorBaseObj(shape.size(), dtype), shape(shape) {}
void TensorObj::dataMalloc() {
IT_ASSERT(data == nullptr);
// initialized to zero
data.reset(reinterpret_cast<VType *>(calloc(size(), sizeof(VType))));
}
VType TensorObj::getData(const Shape &pos) const {
return getData(getOffset(pos));
}
@ -37,24 +34,42 @@ size_t TensorObj::size() const {
return ret;
}
void TensorObj::copyData(VType *dptr) {
template <typename T> void TensorObj::copyData(const T *dptr) {
// TODO: cuda
IT_ASSERT(data != nullptr);
auto ptr = data->getPtr<T *>();
size_t sz = size();
#pragma omp parallel for
for (size_t i = 0; i < sz; ++i) {
data[i] = dptr[i];
ptr[i] = dptr[i];
}
}
void TensorObj::copyData(vector<VType> dataVector) {
IT_ASSERT(dataVector.size() >= size());
copyData(dataVector.data());
}
void TensorObj::copyData(vector<float> dataVector) {
IT_ASSERT(dataVector.size() >= size());
copyData(dataVector.data());
}
void TensorObj::printData() const {
IT_ASSERT(data != nullptr);
if (dtype == DataType::Float32)
printDataFloat();
else if (dtype == DataType::UInt32)
printDataUint32_t();
else
IT_TODO_HALT();
}
void TensorObj::printDataFloat() const {
std::cout << "Tensor: " << guid << std::endl;
auto numDims = shape.size();
auto dimSzVec = std::vector<int>(numDims, 1);
auto ptr = data->getPtr<float *>();
dimSzVec[numDims - 1] = shape[numDims - 1];
for (int i = numDims - 1; i != 0; --i)
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
@ -64,7 +79,35 @@ void TensorObj::printData() const {
std::cout << "[";
}
}
std::cout << data[i];
printf("%.1f", ptr[i]);
for (size_t j = 0; j < numDims; ++j) {
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
std::cout << "]";
}
}
if (i != size() - 1)
std::cout << ", ";
if ((int)i % dimSzVec[numDims - 1] == dimSzVec[numDims - 1] - 1)
std::cout << std::endl;
}
}
void TensorObj::printDataUint32_t() const {
IT_ASSERT(data != nullptr);
std::cout << "Tensor: " << guid << std::endl;
auto numDims = shape.size();
auto dimSzVec = std::vector<int>(numDims, 1);
auto ptr = data->getPtr<VType *>();
dimSzVec[numDims - 1] = shape[numDims - 1];
for (int i = numDims - 1; i != 0; --i)
dimSzVec[i - 1] = dimSzVec[i] * shape[i - 1];
for (size_t i = 0, iEnd = size(); i < iEnd; ++i) {
for (size_t j = 0; j < numDims; ++j) {
if (i % dimSzVec[j] == 0) {
std::cout << "[";
}
}
std::cout << ptr[i];
for (size_t j = 0; j < numDims; ++j) {
if ((int)i % dimSzVec[j] == dimSzVec[j] - 1) {
std::cout << "]";
@ -80,13 +123,26 @@ void TensorObj::printData() const {
bool TensorObj::equalData(const Tensor &rhs) const {
IT_ASSERT(data != nullptr);
IT_ASSERT(rhs->data != nullptr);
// TODO: deal with data type
auto ptr = data->getPtr<VType *>();
auto ptrRhs = rhs->data->getPtr<VType *>();
if (shape != rhs->getDims())
return false;
size_t sz = size();
for (size_t i = 0; i < sz; ++i)
if (data[i] != rhs->data[i])
if (ptr[i] != ptrRhs[i])
return false;
return true;
}
void TensorObj::dataMalloc(const Runtime &runtime) {
IT_ASSERT(data == nullptr);
size_t bytesPerElement;
if (getDType() == DataType::Float32)
bytesPerElement = sizeof(float);
else if (getDType() == DataType::UInt32)
bytesPerElement = sizeof(uint32_t);
data = runtime->allocBlob(size() * bytesPerElement);
}
}; // namespace infini


@ -1,9 +1,14 @@
#include <core/tensor_base.h>
#include "core/tensor_base.h"
#include "core/blob.h"
#include "core/runtime.h"
namespace infini {
TensorBaseObj::TensorBaseObj(int dim, DataType dtype)
: dim(dim), dtype(dtype) {}
VType TensorBaseObj::getData(size_t offset) const { return data[offset]; }
VType TensorBaseObj::getData(size_t offset) const {
// TODO: check cuda array
return (data->getPtr<VType *>())[offset];
}
}; // namespace infini

32 src/cuda/cuda_runtime.cc Normal file

@ -0,0 +1,32 @@
#include "cuda/cuda_runtime.h"
namespace infini {
void CudaRuntimeObj::runWithoutSync(const Graph &graph) const {
const auto &kernelRegistry = KernelRegistry::getInstance();
auto perfEngine = PerfEngine::getInstance();
for (auto &op : graph->getOperators()) {
// HACK: set correct data type
auto kernelAttrs =
KernelAttrs{device, op->getOpType(), DataType::Float32};
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
std::optional<PerfRecord> perfData = perfEngine.getPerfData(perfKey);
if (perfData)
kernel->compute(op, *perfData, this);
else
kernel->compute(op, this);
}
}
void CudaRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
if (tune || profiling)
IT_TODO_HALT();
runWithoutSync(graph);
sync();
}
void CudaRuntimeObj::sync() const { cudaDeviceSynchronize(); }
} // namespace infini

20 src/cuda/cuda_utility.cu Normal file

@ -0,0 +1,20 @@
#include "cuda/cuda_common.h"
#include <cstdio>
__global__ void cudaPrintFloatImpl(float *x, int len) {
int start = threadIdx.x + blockDim.x * blockIdx.x;
if (start == 0) {
for (int i = 0; i < len; ++i) {
printf("%.3f ", x[i]);
}
printf("\n");
}
}
namespace infini {
void cudaPrintFloat(float *x, int len) {
cudaPrintFloatImpl<<<1, 1>>>(x, len);
cudaDeviceSynchronize();
}
} // namespace infini


@ -4,11 +4,12 @@
namespace infini {
template <typename T> class NaiveConv : public Kernel {
void compute(const Operator &_op, const PerfRecord &record) const override {
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *context) const override {
auto op = as<ConvObj>(_op);
T *iptr = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get());
T *wptr = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get());
T *optr = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get());
T *iptr = op->getInputs(0)->getDataRawPtr<T *>();
T *wptr = op->getInputs(1)->getDataRawPtr<T *>();
T *optr = op->getOutput()->getDataRawPtr<T *>();
auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
int cpg = op->getChannelPerGroup();
@ -45,10 +46,13 @@ template <typename T> class NaiveConv : public Kernel {
}
}
void compute(const Operator &op) const override { compute(op, {}); }
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op) const override {
return PerfRecord{.time = timeit([this, &op]() { compute(op); })};
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
return PerfRecord(timeit([&]() { compute(op, context); }));
}
};


@ -4,11 +4,12 @@
namespace infini {
template <typename T> class NaiveMatmul : public Kernel {
void compute(const Operator &_op, const PerfRecord &record) const override {
void compute(const Operator &_op, const PerfRecord &record,
const RuntimeObj *context) const override {
auto op = as<MatmulObj>(_op);
T *A = reinterpret_cast<T *>(op->getInputs(0)->getDataPtr().get());
T *B = reinterpret_cast<T *>(op->getInputs(1)->getDataPtr().get());
T *C = reinterpret_cast<T *>(op->getOutput()->getDataPtr().get());
T *A = op->getInputs(0)->getDataRawPtr<T *>();
T *B = op->getInputs(1)->getDataRawPtr<T *>();
T *C = op->getOutput()->getDataRawPtr<T *>();
IT_ASSERT(op->getTransA() == false && op->getTransB() == false);
IT_ASSERT(op->getAct() == ActType::None);
IT_ASSERT(op->getB() == 1);
@ -23,10 +24,15 @@ template <typename T> class NaiveMatmul : public Kernel {
}
}
void compute(const Operator &op) const override { compute(op, {}); }
void compute(const Operator &op, const RuntimeObj *context) const override {
compute(op, {}, context);
}
PerfRecord tune(const Operator &op) const override {
return PerfRecord{.time = timeit([this, &op]() { compute(op); })};
PerfRecord tune(const Operator &op,
const RuntimeObj *context) const override {
PerfRecord ret;
ret.time = timeit([&]() { compute(op, context); });
return ret;
}
};

208 src/kerels/cuda/conv.cc Normal file

@ -0,0 +1,208 @@
#include "operators/conv.h"
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
namespace infini {
static constexpr int N_ALGO = 8;
static constexpr cudnnConvolutionFwdAlgo_t ALGOS[N_ALGO] = {
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT,
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED};
struct ConvCuDnnPerfRecord : public PerfRecord {
int algo = 0; // cudnnConvolutionFwdAlgo_t
size_t workspaceSize = 100000;
bool fuseAct = false;
};
class convCudnn : public Kernel {
bool cuDNNUnfused(const Ref<ConvObj> &op, const ConvCuDnnPerfRecord &record,
const CudaRuntimeObj *context) const {
cudnnStatus_t stat;
void *const inData = (op->getInputs(0)->getDataRawPtr<void *>());
void *const knData = (op->getInputs(1)->getDataRawPtr<void *>());
if (op->getInputs(2) != nullptr)
IT_TODO_HALT();
// void *const biasData = (op->getInputs(2)->getDataRawPtr<void *>());
void *const outData = (op->getOutput()->getDataRawPtr<void *>());
const auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
const int cpg = op->getChannelPerGroup();
const int g = c / cpg;
const auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
int channelsPerGrp = cpg, channels = c;
// get inputs
cudnnTensorDescriptor_t inDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&inDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
inDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, channels, h, w));
// get kernels
cudnnFilterDescriptor_t knDesc;
checkCudnnError(cudnnCreateFilterDescriptor(&knDesc));
checkCudnnError(cudnnSetFilter4dDescriptor(knDesc, CUDNN_DATA_FLOAT,
CUDNN_TENSOR_NCHW, f,
channelsPerGrp, r, s));
// get bias
cudnnTensorDescriptor_t biasDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&biasDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(
biasDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, f, 1, 1));
// get convolution descriptor
cudnnConvolutionDescriptor_t convDesc;
checkCudnnError(cudnnCreateConvolutionDescriptor(&convDesc));
// TODO: CUDNN_CONVOLUTION is a tunable argument
checkCudnnError(cudnnSetConvolution2dDescriptor(
convDesc, ph, pw, sh, sw, dh, dw, CUDNN_CROSS_CORRELATION,
CUDNN_DATA_FLOAT));
if (g > 1) {
checkCudnnError(cudnnSetConvolutionGroupCount(convDesc, g));
}
// get activation descriptor
cudnnActivationDescriptor_t actDesc;
checkCudnnError(cudnnCreateActivationDescriptor(&actDesc));
// NOT_PROPAGATE_NAN is required by
// cudnnConvolutionBiasActivationForward
switch (op->getAct()) {
case ActType::Relu:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::Sigmoid:
checkCudnnError(cudnnSetActivationDescriptor(
actDesc, CUDNN_ACTIVATION_SIGMOID, CUDNN_NOT_PROPAGATE_NAN, 0));
break;
case ActType::None:
checkCudnnError(
cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_IDENTITY,
CUDNN_NOT_PROPAGATE_NAN, 0));
break;
default:
assert(false);
}
// get outputs
int outn, outc, outh, outw;
checkCudnnError(cudnnGetConvolution2dForwardOutputDim(
convDesc, inDesc, knDesc, &outn, &outc, &outh, &outw));
cudnnTensorDescriptor_t outDesc;
checkCudnnError(cudnnCreateTensorDescriptor(&outDesc));
checkCudnnError(cudnnSetTensor4dDescriptor(outDesc, CUDNN_TENSOR_NCHW,
CUDNN_DATA_FLOAT, outn, outc,
outh, outw));
IT_ASSERT((vector{outn, outc, outh, outw}) ==
op->getOutput()->getDims(),
"cuDNN output shape mismatches with OP output shape");
// get workspace
size_t wsSize = record.workspaceSize;
stat = cudnnGetConvolutionForwardWorkspaceSize(
context->cudnnHandle(), inDesc, knDesc, convDesc, outDesc,
ALGOS[record.algo], &wsSize);
if (stat != CUDNN_STATUS_SUCCESS)
return false;
// assert(wsSize < (size_t)3 * 1024 * 1024 * 1024);
// if (wsSize >= (size_t)10 * 1024 * 1024 * 1024)
// continue;
CudaPtr wsData = context->getWorkspace(wsSize);
float alpha = 1.f, beta = 0.f;
stat = cudnnConvolutionForward(context->cudnnHandle(), &alpha, inDesc,
inData, knDesc, knData, convDesc,
ALGOS[record.algo], wsData, wsSize,
&beta, outDesc, outData);
if (stat != CUDNN_STATUS_SUCCESS)
return false;
// TODO:
// // bias
// if (bias != nullptr) {
// auto sz = op.getOutputs()[0]->size();
// // TODO: element wise
// t += sz * 2 / 400;
// }
// // act
// if (act != None) {
// stat = cudnnActivationForward(cudnnHandle(), actDesc,
// &alpha, inDesc, inData,
// &beta, outDesc, outData);
// checkCudaError(cudaDeviceSynchronize());
// end = ch::high_resolution_clock::now();
// if (stat != CUDNN_STATUS_SUCCESS) {
// durtime = INFINITY;
// break;
// }
// t +=
// ch::duration_cast<ch::duration<double>>(end -
// beg).count() * 1000; // ms
// }
// best = ConvResult{durtime, ALGOS[i], wsSize, false};
// // w/ bias & act
// for (int j = 0; j < rounds + warmupRounds; ++j) {
// cudnnStatus_t stat;
// if (j == warmupRounds) {
// checkCudaError(cudaDeviceSynchronize());
// beg = ch::high_resolution_clock::now();
// }
// stat = cudnnConvolutionBiasActivationForward(
// cudnnHandle(), &alpha, inDesc, inData, knDesc, knData,
// convDesc, ALGOS[i], wsData, wsSize, &beta, outDesc,
// outData, biasDesc, biasData, actDesc, outDesc, outData);
// if (stat != CUDNN_STATUS_SUCCESS) {
// // checkCudnnError(stat);
// // Do not checkCudnnError since not all algorithms are
// // supported
// durtime_fuse = INFINITY;
// break;
// }
// }
// Destroy calls in CUDA do not require sync. But cuDNN does not state
// whether sync is required before destroying descriptors.
checkCudnnError(cudnnDestroyTensorDescriptor(inDesc));
checkCudnnError(cudnnDestroyTensorDescriptor(outDesc));
checkCudnnError(cudnnDestroyFilterDescriptor(knDesc));
checkCudnnError(cudnnDestroyConvolutionDescriptor(convDesc));
return true;
}
void compute(const Operator &op, const RuntimeObj *context) const override {
ConvCuDnnPerfRecord record; // with parameters from the default ctor
compute(op, record, context);
}
PerfRecord tune(const Operator &_op,
const RuntimeObj *_context) const override {
// TODO: real tuning
ConvCuDnnPerfRecord ret;
ret.time = timeit([&]() { compute(_op, _context); });
return ret;
}
void compute(const Operator &_op, const PerfRecord &_record,
const RuntimeObj *_context) const override {
auto op = as<ConvObj>(_op);
auto &record = dynamic_cast<const ConvCuDnnPerfRecord &>(_record);
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
bool success = cuDNNUnfused(op, record, context);
IT_ASSERT(success);
}
};
REGISTER_KERNEL(Device::CUDA, OpType::Conv, DataType::Float32, convCudnn,
"Conv_cuDNN_CUDA_Float32");
} // namespace infini


@ -1,45 +1,48 @@
#include "core/blob.h"
#include "core/graph.h"
#include "core/run_enigne.h"
#include "core/runtime.h"
#include "operators/matmul.h"
#include "test.h"
namespace infini {
TEST(Graph, build_and_run) {
Graph g = make_ref<GraphObj>();
Runtime runtime = make_ref<CpuRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
g->dataMalloc();
i0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
g->addOpWithOutputs<MatmulObj>(i0, w0, o0);
RunEngine(Device::CPU).run(g);
runtime->run(g);
// check answer
auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32);
ans->dataMalloc();
ans->copyData({38, 44, 50, 56, 83, 98, 113, 128});
ans->dataMalloc(runtime);
ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
EXPECT_TRUE(o0->equalData(ans));
}
TEST(Graph, perf_engine) {
Graph g = make_ref<GraphObj>();
Runtime runtime = make_ref<CpuRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);
g->dataMalloc();
i0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyData({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
RunEngine(Device::CPU).run(g, true, true);
double perfTime = RunEngine(Device::CPU).getPerfTime(g);
i0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
w0->copyData(vector<uint32_t>{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
runtime->run(g, true, true);
double perfTime = runtime->getPerfTime(g);
// The example matmul takes 0.0036ms with one core
EXPECT_GT(perfTime, 0);
EXPECT_LT(perfTime, 0.01);
// check answer
auto ans = make_ref<TensorObj>(Shape{1, 2, 4}, DataType::UInt32);
ans->dataMalloc();
ans->copyData({38, 44, 50, 56, 83, 98, 113, 128});
ans->dataMalloc(runtime);
ans->copyData(vector<uint32_t>{38, 44, 50, 56, 83, 98, 113, 128});
EXPECT_TRUE(matmul->getOutput()->equalData(ans));
}


@ -1,5 +1,5 @@
#include "core/graph.h"
#include "core/run_enigne.h"
#include "core/runtime.h"
#include "operators/matmul.h"
#include "test.h"
@ -8,7 +8,7 @@ namespace infini {
TEST(Hash, OperatorHash) {
OpPerfKey key1(0, OpType::Unknown), key2(0, OpType::Unknown);
{ // build with addOpWithOutputs
Graph g = make_ref<GraphObj>();
Graph g = make_ref<GraphObj>(nullptr);
Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
Tensor o0 = g->addTensor({1, 2, 4}, DataType::UInt32);
@ -18,7 +18,7 @@ TEST(Hash, OperatorHash) {
EXPECT_GT(key1.attrs.size(), 5);
}
{ // build with addOp
Graph g = make_ref<GraphObj>();
Graph g = make_ref<GraphObj>(nullptr);
Tensor i0 = g->addTensor({2, 2, 3}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 4}, DataType::UInt32);
auto matmul = g->addOp<MatmulObj>(i0, w0, nullptr);


@ -1,21 +1,24 @@
#include "core/graph.h"
#include "core/run_enigne.h"
#include "core/runtime.h"
#include "cuda/cuda_runtime.h"
#include "cuda/cuda_utility.h"
#include "operators/conv.h"
#include "test.h"
namespace infini {
TEST(Conv, ShapeInference) {
auto runtime = make_ref<CpuRuntimeObj>();
// Padding modes
{
Graph g = make_ref<GraphObj>();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
auto conv = g->addOp<ConvObj>(i0, w0, nullptr, 1, 1);
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
}
{
Graph g = make_ref<GraphObj>();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
auto conv =
@ -23,7 +26,7 @@ TEST(Conv, ShapeInference) {
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 4, 4}));
}
{
Graph g = make_ref<GraphObj>();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
auto conv =
@ -31,7 +34,7 @@ TEST(Conv, ShapeInference) {
EXPECT_EQ(conv->getOutput()->getDims(), (Shape{1, 2, 2, 2}));
}
{ // dilation & stride
Graph g = make_ref<GraphObj>();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
auto conv = g->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
@ -40,7 +43,8 @@ TEST(Conv, ShapeInference) {
}
TEST(Conv, NaiveCPU) {
Graph g = make_ref<GraphObj>();
auto runtime = make_ref<CpuRuntimeObj>();
Graph g = make_ref<GraphObj>(runtime);
Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);
auto conv = g->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
@ -48,16 +52,75 @@ TEST(Conv, NaiveCPU) {
g->dataMalloc();
i0->setData(IncrementalGenerator());
w0->setData(IncrementalGenerator());
RunEngine(Device::CPU).run(g, true, true);
double perfTime = RunEngine(Device::CPU).getPerfTime(g);
// The example matmul takes 0.0036ms with one core
runtime->run(g, true, true);
double perfTime = runtime->getPerfTime(g);
// The example Conv takes 0.015ms with one core
EXPECT_GT(perfTime, 0);
EXPECT_LT(perfTime, 5);
EXPECT_LT(perfTime, 0.1);
// check answer
auto ans = make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::UInt32);
ans->dataMalloc();
ans->copyData({4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
ans->dataMalloc(runtime);
ans->copyData(
vector<uint32_t>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
EXPECT_TRUE(conv->getOutput()->equalData(ans));
}
void testConvCudnn(
const std::function<void(void *, size_t, DataType)> &generator,
vector<float> ansVec) {
auto cpuRuntime = make_ref<CpuRuntimeObj>();
auto cudaRuntime = make_ref<CudaRuntimeObj>();
// Build CUDA graph
Graph g = make_ref<GraphObj>(cudaRuntime);
Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::Float32);
Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::Float32);
auto conv = g->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
// allocate CUDA memory
g->dataMalloc();
// Build input and output data on CPU
auto cpui0 = make_ref<TensorObj>(Shape{1, 3, 4, 4}, DataType::Float32);
cpui0->dataMalloc(cpuRuntime);
cpui0->setData(generator);
auto cpuw0 = make_ref<TensorObj>(Shape{2, 3, 3, 3}, DataType::Float32);
cpuw0->dataMalloc(cpuRuntime);
cpuw0->setData(generator);
auto ans = make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::Float32);
ans->dataMalloc(cpuRuntime);
ans->copyData(ansVec);
// Copy inputs from CPU to CUDA
cudaMemcpy(i0->getDataRawPtr<void *>(), cpui0->getDataRawPtr<void *>(),
cpui0->size() * sizeof(float), cudaMemcpyHostToDevice);
cudaMemcpy(w0->getDataRawPtr<void *>(), cpuw0->getDataRawPtr<void *>(),
cpuw0->size() * sizeof(float), cudaMemcpyHostToDevice);
// Execute on CUDA
cudaRuntime->run(g);
// double perfTime = cudaRuntime->getPerfTime(g);
// // The example Conv takes 0.015ms with one core
// EXPECT_GT(perfTime, 0);
// EXPECT_LT(perfTime, 0.1);
// copy CUDA output to CPU
auto o0 = conv->getOutput();
auto cpuo0 = make_ref<TensorObj>(Shape{1, 2, 2, 2}, DataType::Float32);
cpuo0->dataMalloc(cpuRuntime);
cudaMemcpy(cpuo0->getDataRawPtr<void *>(),
conv->getOutput()->getDataRawPtr<void *>(),
cpuo0->size() * sizeof(float), cudaMemcpyDeviceToHost);
// check results on CPU
EXPECT_TRUE(cpuo0->equalData<float>(ans));
}
TEST(Conv, cuDNN) {
testConvCudnn(OneGenerator(),
vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
testConvCudnn(
IncrementalGenerator(),
vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
}
} // namespace infini

2 test/script/env_lotus.sh Normal file

@ -0,0 +1,2 @@
spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc