forked from jiuyuan/InfiniTensor
Add python interface for CUDA operator evaluation (#42)
* Refactor: seperate data generator * Add: python bindings for opTimer * Fix: test_perfengine Co-authored-by: Liyan Zheng <liyan-zheng@outlook.com>
This commit is contained in:
parent
11d5aa1ccc
commit
1aefc1b27e
|
@ -94,6 +94,11 @@ endif()
|
||||||
|
|
||||||
target_link_libraries(InfiniTensor pybind11::embed)
|
target_link_libraries(InfiniTensor pybind11::embed)
|
||||||
|
|
||||||
|
# Python bindings
|
||||||
|
file(GLOB_RECURSE FFIS src/ffi/ffi_infinitensor.cc)
|
||||||
|
pybind11_add_module(pyinfinitensor MODULE ${FFIS})
|
||||||
|
target_link_libraries(pyinfinitensor PRIVATE InfiniTensor)
|
||||||
|
|
||||||
if(USE_BACKTRACE)
|
if(USE_BACKTRACE)
|
||||||
add_definitions(-D BACKWARD_TRACE)
|
add_definitions(-D BACKWARD_TRACE)
|
||||||
add_subdirectory(3rd-party/backward-cpp)
|
add_subdirectory(3rd-party/backward-cpp)
|
||||||
|
@ -103,6 +108,7 @@ if(USE_BACKTRACE)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_CUDA)
|
if(USE_CUDA)
|
||||||
|
add_compile_definitions(USE_CUDA=1)
|
||||||
# Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong
|
# Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong
|
||||||
set(CMAKE_CUDA_HOST_COMPILER
|
set(CMAKE_CUDA_HOST_COMPILER
|
||||||
${CMAKE_CXX_COMPILER}
|
${CMAKE_CXX_COMPILER}
|
||||||
|
@ -111,14 +117,7 @@ if(USE_CUDA)
|
||||||
set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80")
|
set_target_properties(InfiniTensor PROPERTIES CUDA_ARCHITECTURES "70;80")
|
||||||
enable_language(CUDA)
|
enable_language(CUDA)
|
||||||
find_package(CUDAToolkit) # For nvrtc and cuda driver
|
find_package(CUDAToolkit) # For nvrtc and cuda driver
|
||||||
target_link_libraries(
|
target_link_libraries(InfiniTensor cudnn CUDA::curand CUDA::cublas CUDA::nvrtc CUDA::cudart CUDA::cuda_driver)
|
||||||
InfiniTensor
|
|
||||||
cudnn
|
|
||||||
CUDA::curand
|
|
||||||
CUDA::cublas
|
|
||||||
CUDA::nvrtc
|
|
||||||
CUDA::cudart
|
|
||||||
CUDA::cuda_driver)
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_BANG)
|
if(USE_BANG)
|
||||||
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
#pragma once
|
||||||
|
namespace infini {
|
||||||
|
namespace opTimer {
|
||||||
|
double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
|
||||||
|
int padh, int padw, int strideh, int stridew,
|
||||||
|
int dilationh, int dilationw, int group,
|
||||||
|
const char *name);
|
||||||
|
|
||||||
|
double getPerfMatmulCublas(int b, int m, int n, int k, const char *name);
|
||||||
|
} // namespace opTimer
|
||||||
|
} // namespace infini
|
|
@ -1,59 +1,5 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
#include "core/common.h"
|
#include "core/common.h"
|
||||||
#include "core/tensor_base.h"
|
#include "core/tensor_base.h"
|
||||||
|
#include "utils/data_generator.h"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
|
|
||||||
namespace infini {
|
|
||||||
|
|
||||||
// TODO: isolate these class
|
|
||||||
class DataGenerator {
|
|
||||||
private:
|
|
||||||
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }
|
|
||||||
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }
|
|
||||||
|
|
||||||
public:
|
|
||||||
virtual ~DataGenerator() {}
|
|
||||||
void operator()(void *data, size_t size, DataType dataType) {
|
|
||||||
if (dataType == DataType::UInt32)
|
|
||||||
fill(reinterpret_cast<uint32_t *>(data), size);
|
|
||||||
else if (dataType == DataType::Float32)
|
|
||||||
fill(reinterpret_cast<float *>(data), size);
|
|
||||||
else
|
|
||||||
IT_TODO_HALT();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class IncrementalGenerator : public DataGenerator {
|
|
||||||
public:
|
|
||||||
virtual ~IncrementalGenerator() {}
|
|
||||||
|
|
||||||
private:
|
|
||||||
template <typename T> void fill(T *data, size_t size) {
|
|
||||||
for (size_t i = 0; i < size; i++) {
|
|
||||||
data[i] = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void fill(uint32_t *data, size_t size) override {
|
|
||||||
fill<uint32_t>(data, size);
|
|
||||||
}
|
|
||||||
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
|
||||||
};
|
|
||||||
|
|
||||||
class OneGenerator : public DataGenerator {
|
|
||||||
public:
|
|
||||||
virtual ~OneGenerator() {}
|
|
||||||
|
|
||||||
private:
|
|
||||||
template <typename T> void fill(T *data, size_t size) {
|
|
||||||
for (size_t i = 0; i < size; i++) {
|
|
||||||
data[i] = 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void fill(uint32_t *data, size_t size) override {
|
|
||||||
fill<uint32_t>(data, size);
|
|
||||||
}
|
|
||||||
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
|
||||||
};
|
|
||||||
} // namespace infini
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
#include "core/common.h"
|
||||||
|
#include "core/tensor_base.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
// TODO: isolate these class
|
||||||
|
class DataGenerator {
|
||||||
|
private:
|
||||||
|
virtual void fill(uint32_t *data, size_t size) { IT_TODO_HALT(); }
|
||||||
|
virtual void fill(float *data, size_t size) { IT_TODO_HALT(); }
|
||||||
|
|
||||||
|
public:
|
||||||
|
virtual ~DataGenerator() {}
|
||||||
|
void operator()(void *data, size_t size, DataType dataType) {
|
||||||
|
if (dataType == DataType::UInt32)
|
||||||
|
fill(reinterpret_cast<uint32_t *>(data), size);
|
||||||
|
else if (dataType == DataType::Float32)
|
||||||
|
fill(reinterpret_cast<float *>(data), size);
|
||||||
|
else
|
||||||
|
IT_TODO_HALT();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class IncrementalGenerator : public DataGenerator {
|
||||||
|
public:
|
||||||
|
virtual ~IncrementalGenerator() {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename T> void fill(T *data, size_t size) {
|
||||||
|
for (size_t i = 0; i < size; i++) {
|
||||||
|
data[i] = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill(uint32_t *data, size_t size) override {
|
||||||
|
fill<uint32_t>(data, size);
|
||||||
|
}
|
||||||
|
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
||||||
|
};
|
||||||
|
|
||||||
|
class OneGenerator : public DataGenerator {
|
||||||
|
public:
|
||||||
|
virtual ~OneGenerator() {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename T> void fill(T *data, size_t size) {
|
||||||
|
for (size_t i = 0; i < size; i++) {
|
||||||
|
data[i] = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fill(uint32_t *data, size_t size) override {
|
||||||
|
fill<uint32_t>(data, size);
|
||||||
|
}
|
||||||
|
void fill(float *data, size_t size) override { fill<float>(data, size); }
|
||||||
|
};
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,13 @@
|
||||||
|
from tokenize import Double
|
||||||
|
import pyinfinitensor # import getPerfConv, getPerfMatmul
|
||||||
|
|
||||||
|
|
||||||
|
def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name):
|
||||||
|
return pyinfinitensor.getPerfConvCudnn(n, c, h, w, f, r, s, padh, padw,
|
||||||
|
strideh, stridew, dilationh, dilationw, group, name)
|
||||||
|
|
||||||
|
|
||||||
|
def getPerfMatmul(b, m, n, k, name):
|
||||||
|
return pyinfinitensor.getPerfMatmulCublas(b, m, n, k, name)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
#include "core/graph.h"
|
||||||
|
#include "core/kernel.h"
|
||||||
|
#include "core/runtime.h"
|
||||||
|
#include "cuda/cuda_runtime.h"
|
||||||
|
#include "cuda/cuda_utility.h"
|
||||||
|
#include "operators/conv.h"
|
||||||
|
#include "operators/matmul.h"
|
||||||
|
#include "utils/data_generator.h"
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
namespace opTimer {
|
||||||
|
|
||||||
|
double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
|
||||||
|
int padh, int padw, int strideh, int stridew,
|
||||||
|
int dilationh, int dilationw, int group,
|
||||||
|
const char *name) {
|
||||||
|
// const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
|
||||||
|
// dilationh, dilationw, group] =
|
||||||
|
// tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
|
||||||
|
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||||
|
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||||
|
// Set input data on CPU in a CPU Graph
|
||||||
|
Tensor i0Cpu = gCpu->addTensor({n, c, h, w}, DataType::Float32);
|
||||||
|
Tensor w0Cpu = gCpu->addTensor({f, c, r, s}, DataType::Float32);
|
||||||
|
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
i0Cpu->setData(IncrementalGenerator());
|
||||||
|
w0Cpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Copy input tensors from CPU to CUDA
|
||||||
|
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||||
|
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||||
|
// Build CUDA graph
|
||||||
|
auto conv = gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, padh, padw,
|
||||||
|
strideh, stridew, dilationh, dilationw);
|
||||||
|
// allocate CUDA memory
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
// Execute on CUDA
|
||||||
|
bool tune = true;
|
||||||
|
cuda->run(gCuda, tune);
|
||||||
|
return cuda->getPerfTime(gCuda);
|
||||||
|
}
|
||||||
|
|
||||||
|
double getPerfMatmulCublas(int b, int m, int n, int k, const char *name) {
|
||||||
|
// const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
|
||||||
|
// dilationh, dilationw, group] =
|
||||||
|
// tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
|
||||||
|
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||||
|
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||||
|
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||||
|
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||||
|
// Set input data on CPU in a CPU Graph
|
||||||
|
Tensor i0Cpu = gCpu->addTensor({b, m, k}, DataType::Float32);
|
||||||
|
Tensor w0Cpu = gCpu->addTensor({b, k, n}, DataType::Float32);
|
||||||
|
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||||
|
gCpu->dataMalloc();
|
||||||
|
i0Cpu->setData(IncrementalGenerator());
|
||||||
|
w0Cpu->setData(IncrementalGenerator());
|
||||||
|
|
||||||
|
// Copy input tensors from CPU to CUDA
|
||||||
|
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
||||||
|
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
||||||
|
// Build CUDA graph
|
||||||
|
auto conv = gCuda->addOp<MatmulObj>(i0Cuda, w0Cuda, nullptr);
|
||||||
|
// allocate CUDA memory
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
// Execute on CUDA
|
||||||
|
bool tune = true;
|
||||||
|
cuda->run(gCuda, tune);
|
||||||
|
return cuda->getPerfTime(gCuda);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace opTimer
|
||||||
|
} // namespace infini
|
|
@ -0,0 +1,22 @@
|
||||||
|
#include <pybind11/stl.h>
|
||||||
|
#ifdef USE_CUDA
|
||||||
|
#include "cuda/operator_timer.h"
|
||||||
|
#endif
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
namespace infini {
|
||||||
|
|
||||||
|
using namespace py::literals;
|
||||||
|
using policy = py::return_value_policy;
|
||||||
|
|
||||||
|
void register_operator_timer(py::module &m) {
|
||||||
|
#ifdef USE_CUDA
|
||||||
|
using namespace opTimer;
|
||||||
|
m.def("getPerfConvCudnn", &getPerfConvCudnn);
|
||||||
|
m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace infini
|
||||||
|
|
||||||
|
PYBIND11_MODULE(pyinfinitensor, m) { infini::register_operator_timer(m); }
|
|
@ -14,36 +14,27 @@ TEST(PerfEngine, save_and_load) {
|
||||||
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||||
Runtime cuda = make_ref<CudaRuntimeObj>();
|
Runtime cuda = make_ref<CudaRuntimeObj>();
|
||||||
Graph gCuda = make_ref<GraphObj>(cuda);
|
{ // Conv
|
||||||
// Set input data on CPU in a CPU Graph
|
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||||
Tensor i0Cpu = gCpu->addTensor({1, 3, 224, 224}, DataType::Float32);
|
|
||||||
Tensor w0Cpu = gCpu->addTensor({2, 3, 3, 3}, DataType::Float32);
|
|
||||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
|
||||||
gCpu->dataMalloc();
|
|
||||||
i0Cpu->setData(IncrementalGenerator());
|
|
||||||
w0Cpu->setData(IncrementalGenerator());
|
|
||||||
|
|
||||||
// Copy input tensors from CPU to CUDA
|
// Copy input tensors from CPU to CUDA
|
||||||
Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
|
Tensor i0Cuda = gCuda->addTensor({1, 3, 224, 224}, DataType::Float32);
|
||||||
Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
|
Tensor w0Cuda = gCuda->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||||
// Build CUDA graph
|
// Build CUDA graph
|
||||||
auto conv =
|
auto conv =
|
||||||
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
|
gCuda->addOp<ConvObj>(i0Cuda, w0Cuda, nullptr, 1, 1, 1, 1, 1, 1);
|
||||||
|
gCuda->dataMalloc();
|
||||||
|
cuda->run(gCuda, true);
|
||||||
|
}
|
||||||
|
|
||||||
auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
|
{ // Matmul
|
||||||
auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);
|
Graph gCuda = make_ref<GraphObj>(cuda);
|
||||||
gCpu->dataMalloc();
|
auto ACuda = gCuda->addTensor(Shape{1, 3, 5}, DataType::Float32);
|
||||||
ACpu->setData(IncrementalGenerator());
|
auto BCuda = gCuda->addTensor(Shape{1, 5, 2}, DataType::Float32);
|
||||||
BCpu->setData(IncrementalGenerator());
|
auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
|
||||||
|
gCuda->dataMalloc();
|
||||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
cuda->run(gCuda, true);
|
||||||
|
}
|
||||||
auto ACuda = gCuda->cloneTensor(ACpu);
|
|
||||||
auto BCuda = gCuda->cloneTensor(BCpu);
|
|
||||||
auto matmul = gCuda->addOp<MatmulObj>(ACuda, BCuda, nullptr);
|
|
||||||
|
|
||||||
gCuda->dataMalloc();
|
|
||||||
cudaRuntime->run(gCuda, true);
|
|
||||||
auto &perfEngine = PerfEngine::getInstance();
|
auto &perfEngine = PerfEngine::getInstance();
|
||||||
|
|
||||||
json j0 = perfEngine;
|
json j0 = perfEngine;
|
||||||
|
|
Loading…
Reference in New Issue