forked from jiuyuan/InfiniTensor
ADD: add MKL runtime for Intel CPU, and add MKL kernels for matmul/conv/convtransposed. (#61)

* move memory format transformation to TensorObj; clang format; add MemoryFormat for TensorObj; use post_ops for fused conv/deconv; distinguish mkl op_timer from cuda op_timer; add act optype to conv and deconv; add operator timer; add mkl kernel for convTransposed; minor fix for group conv; do not use cblas_sgemm_batch; CpuRuntimeObj -> NativeCpuRuntimeObj; add matmul op for mkl
* fix: fix bugs when rebasing from master
* fix: update api after rebasing
* fix: fix format; fix onnx import
* fix: fix clang-format
* [fix] fix conv_transpose test
* [fix] use stronger test case for transposed conv
* [fix] remove tensor memory format; fix mkl transpose conv
* [fix] add FIXME tag for op_timer python api

---------

Co-authored-by: whjthu <haojie0429@gmail.com>
This commit is contained in:
parent 65a3abf5dc
commit 86ec4036ce
@@ -5,6 +5,7 @@ project(InfiniTensor C CXX)
 # Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
 option(USE_CUDA "Support CUDA GPU" OFF)
 option(USE_BANG "Support BANG MLU" OFF)
+option(USE_MKL "Support MKL" OFF)
 option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
 option(USE_PROTOBUF "Serialize and deserialize tensors" ON)
 option(BUILD_TEST "Build tests" ON)
@@ -86,6 +87,11 @@ if(USE_BANG)
     list (APPEND SRC ${SRC_BANG})
 endif()
 
+if(USE_MKL)
+    file(GLOB_RECURSE SRC_MKL src/mkl/*.cc src/kernels/mkl/*.cc )
+    list (APPEND SRC ${SRC_MKL})
+endif()
+
 # Libraries
 add_library(InfiniTensor SHARED ${SRC})
 if(USE_PROTOBUF)
@@ -107,6 +113,21 @@ if(USE_BACKTRACE)
     target_link_libraries(InfiniTensor dw)
 endif()
 
+if(USE_MKL)
+    find_package(MKL CONFIG REQUIRED)
+    target_link_libraries(InfiniTensor $<LINK_ONLY:MKL::MKL>)
+    set(DNNL_CONFIGURATION "cpu_gomp")
+    find_package(dnnl CONFIG REQUIRED)
+    if(dnnl_FOUND)
+        add_compile_definitions(USE_MKL=1)
+        include_directories(BEFORE ${dnnl_DIR}/../../../cpu_gomp/include/)
+        link_directories(${dnnl_DIR}/../../../cpu_gomp/lib)
+        target_link_libraries(InfiniTensor dnnl)
+    else()
+        message(FATAL_ERROR "dnnl library not found")
+    endif()
+endif()
+
 if(USE_CUDA)
     add_compile_definitions(USE_CUDA=1)
     # Since enable_language only executes once, rerun cmake is required if CMAKE_CUDA_HOST_COMPILER is wrong
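With the block above, MKL support stays off by default and is opted into at configure time; assuming the cmake.config flow mentioned at the top of this file, a configure line would look like:

    cmake -DUSE_MKL=ON -DBUILD_TEST=ON .. && make

(a sketch; the exact invocation depends on where MKL and oneDNN are installed, e.g. whether the oneAPI setvars.sh environment has been sourced so find_package can locate both CONFIG packages).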
@@ -189,6 +210,9 @@ if(BUILD_TEST)
     if (USE_BANG)
         build_test(test/kernels/bang/*.cc)
     endif()
+    if (USE_MKL)
+        build_test(test/kernels/mkl/*.cc)
+    endif()
 endif()
 if(BUILD_TEST_PET)
     build_test(test/pet/*.cc)
@@ -12,7 +12,8 @@ class Mutator {
     Runtime runtime;
 
   public:
-    Mutator(int candidatesLimit, Runtime runtime = CpuRuntimeObj::getInstance())
+    Mutator(int candidatesLimit,
+            Runtime runtime = NativeCpuRuntimeObj::getInstance())
         : candidatesLimit(candidatesLimit), runtime(runtime){};
     virtual ~Mutator(){};
 
@@ -28,7 +28,7 @@ using OpVec = vector<Operator>;
 
 using VType = uint32_t;
 
-enum class Device { CPU = 1, CUDA, BANG };
+enum class Device { CPU = 1, CUDA, BANG, MKL };
 /***************** Forward declaration end *****************/
 
 class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
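Extending this enum is what the kernel registry keys on: KernelAttrs pairs a device tag with an op type and data type, so the MKL kernels added later in this commit attach themselves with the existing macro, e.g.:

    // registry key = {Device::MKL, OpType::Conv, DataType::Float32}
    REGISTER_KERNEL(Device::MKL, OpType::Conv, DataType::Float32, MklConv,
                    "MklConv_CPU_float32");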
@@ -53,7 +53,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
                      bool profiling = false) const = 0;
     virtual void *alloc(size_t size) = 0;
     virtual void dealloc(void *ptr) = 0;
 
     void prepareAndRun(Graph &graph, bool tune = false, bool profiling = false);
     /**
      * @brief Get the execution time of each operator in performance record. No
      * execution happens.
@@ -64,7 +64,9 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
      */
     double getPerfTime(const Graph &graph, bool profiling = false) const;
     Blob allocBlob(size_t size);
-    bool isCpu() const { return device == Device::CPU; }
+    bool isCpu() const {
+        return device == Device::CPU || device == Device::MKL;
+    }
     bool isCuda() const { return device == Device::CUDA; }
     bool isBang() const { return device == Device::BANG; }
     void copyBlob(const TensorObj *dst, const TensorObj *src) const;
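Treating Device::MKL as a CPU device means every host-side path (direct pointer access, data generators, intra-runtime copies) keeps working unchanged for the MKL runtime. A minimal sketch of what this buys, assuming the MklRuntimeObj from this commit is linked in:

    Runtime mkl = MklRuntimeObj::getInstance();
    IT_ASSERT(mkl->isCpu());                // Device::MKL reports as a CPU runtime
    // host code can touch MKL allocations directly; no staging copy is needed
    float *p = (float *)mkl->alloc(4 * sizeof(float));
    p[0] = 1.0f;
    mkl->dealloc(p);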
@@ -85,26 +87,33 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
 
 class CpuRuntimeObj : public RuntimeObj {
   public:
-    CpuRuntimeObj() : RuntimeObj(Device::CPU) {}
-    static Ref<CpuRuntimeObj> &getInstance() {
-        static Ref<CpuRuntimeObj> instance = make_ref<CpuRuntimeObj>();
-        return instance;
-    }
+    CpuRuntimeObj(Device dev) : RuntimeObj(dev) {}
 
     void run(const Graph &graph, bool tune = false,
              bool profiling = false) const override;
-    void dealloc(void *ptr) override { return free(ptr); };
-
-    void *alloc(size_t size) override {
-        return calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t),
-                      sizeof(uint64_t));
-    };
-
     void copyBlobFromCPU(void *dst, const void *src,
                          size_t bytes) const override;
     void copyBlobToCPU(void *dst, const void *src, size_t bytes) const override;
     void copyBlobInsideRuntime(void *dst, const void *src,
                                size_t bytes) const override;
 };
 
+class NativeCpuRuntimeObj : public CpuRuntimeObj {
+  public:
+    NativeCpuRuntimeObj() : CpuRuntimeObj(Device::CPU) {}
+
+    static Ref<NativeCpuRuntimeObj> &getInstance() {
+        static Ref<NativeCpuRuntimeObj> instance =
+            make_ref<NativeCpuRuntimeObj>();
+        return instance;
+    }
+    void dealloc(void *ptr) override { return free(ptr); };
+
+    void *alloc(size_t size) override {
+        return calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t),
+                      sizeof(uint64_t));
+    };
+    string toString() const override;
+};
+
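A side note on the allocator: it rounds the byte count up to a whole number of 8-byte words before zero-initializing. A worked check of the arithmetic (standalone, not from the source):

    #include <cstdint>
    #include <cstdlib>

    int main() {
        size_t size = 10; // request 10 bytes
        size_t words = (size + sizeof(uint64_t) - 1) / sizeof(uint64_t); // (10+7)/8 == 2
        void *p = calloc(words, sizeof(uint64_t)); // 16 zeroed bytes cover the request
        free(p);
    }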
@@ -5,7 +5,6 @@
 #include "core/runtime.h"
 
 namespace infini {
 
 class TensorBaseObj : public Object {
   public:
     // enum TensorType {
@@ -0,0 +1,33 @@
+#pragma once
+#include "core/runtime.h"
+#include "dnnl.h"
+#include "oneapi/dnnl/dnnl.h"
+#include "oneapi/dnnl/dnnl.hpp"
+#include "oneapi/dnnl/dnnl_types.h"
+#include <dnnl_debug.h>
+#include <mkl.h>
+namespace infini {
+// TODO: move utility functions to a separate file
+class MklRuntimeObj : public CpuRuntimeObj {
+    dnnl_engine_t engine;
+
+  public:
+    MklRuntimeObj();
+    static Ref<MklRuntimeObj> &getInstance() {
+        static Ref<MklRuntimeObj> instance = make_ref<MklRuntimeObj>();
+        return instance;
+    }
+
+    virtual ~MklRuntimeObj();
+    void dealloc(void *ptr) override { return mkl_free(ptr); };
+
+    void *alloc(size_t size) override {
+        return mkl_calloc((size + sizeof(uint64_t) - 1) / sizeof(uint64_t),
+                          sizeof(uint64_t), 64);
+    };
+
+    string toString() const override { return "CPU MKL Runtime"; };
+    dnnl::engine getEngine() const { return dnnl::engine(engine, true); }
+};
+
+} // namespace infini
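For orientation, the new runtime is used like any other InfiniTensor runtime; this mirrors getPerfMatmulMkl further down in this commit (the shapes here are illustrative):

    Runtime runtime = MklRuntimeObj::getInstance(); // process-wide singleton
    Graph g = make_ref<GraphObj>(runtime);
    Tensor a = g->addTensor({1, 8, 16}, DataType::Float32);
    Tensor b = g->addTensor({1, 16, 4}, DataType::Float32);
    g->addOp<MatmulObj>(a, b, nullptr);
    g->dataMalloc();
    a->setData(IncrementalGenerator());
    b->setData(IncrementalGenerator());
    runtime->run(g); // dispatches to the Device::MKL kernels registered below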
@@ -0,0 +1,15 @@
+#pragma once
+namespace infini {
+namespace opTimer {
+double getPerfConvMkl(int n, int c, int h, int w, int f, int r, int s, int padh,
+                      int padw, int strideh, int stridew, int dilationh,
+                      int dilationw, int group);
+
+double getPerfConvTransposed2dMkl(int n, int c, int h, int w, int f, int r,
+                                  int s, int padh, int padw, int strideh,
+                                  int stridew, int dilationh, int dilationw,
+                                  int oph, int opw, int group);
+
+double getPerfMatmulMkl(int b, int m, int n, int k);
+} // namespace opTimer
+} // namespace infini
@@ -2,14 +2,28 @@ from tokenize import Double
 import pyinfinitensor  # import getPerfConv, getPerfMatmul
 
 
-def getPerfConv(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""):
-    return pyinfinitensor.getPerfConvCudnn(n, c, h, w, f, r, s, padh, padw,
+# FIXME: change API from getPerfOpDevice(...) to getPerfOp(device='dev', ...)
+def getPerfConvCuda(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""):
+    return pyinfinitensor.getPerfConvCuda(n, c, h, w, f, r, s, padh, padw,
                                           strideh, stridew, dilationh, dilationw, group, name)
 
 
-def getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group):
-    return pyinfinitensor.getPerfConvTransposed2dCudnn(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group)
+def getPerfConvTransposed2dCuda(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group):
+    return pyinfinitensor.getPerfConvTransposed2dCuda(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group)
 
 
-def getPerfMatmul(b, m, n, k, name=""):
-    return pyinfinitensor.getPerfMatmulCublas(b, m, n, k, name)
+def getPerfMatmulCuda(b, m, n, k, name=""):
+    return pyinfinitensor.getPerfMatmulCuda(b, m, n, k, name)
+
+
+def getPerfConvMkl(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, group, name=""):
+    return pyinfinitensor.getPerfConvMkl(n, c, h, w, f, r, s, padh, padw,
+                                         strideh, stridew, dilationh, dilationw, group)
+
+
+def getPerfConvTransposed2dMkl(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group):
+    return pyinfinitensor.getPerfConvTransposed2dMkl(n, c, h, w, f, r, s, padh, padw, strideh, stridew, dilationh, dilationw, oph, opw, group)
+
+
+def getPerfMatmulMkl(b, m, n, k, name=""):
+    return pyinfinitensor.getPerfMatmulMkl(b, m, n, k)
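Assuming this wrapper is importable as a module (the module name below is hypothetical), a timing session would look like:

    import operator_timer as ot  # hypothetical module name

    # n=1, c=32, h=w=14, f=64 filters of r=s=3; pad 1, stride 1, dilation 1, group 1
    t_conv = ot.getPerfConvMkl(1, 32, 14, 14, 64, 3, 3, 1, 1, 1, 1, 1, 1, 1)
    t_matmul = ot.getPerfMatmulMkl(1, 512, 512, 512)
    print(t_conv, t_matmul)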
@@ -6,6 +6,9 @@
 #include <chrono>
 #include <cstring>
 namespace infini {
+void RuntimeObj::prepareAndRun(Graph &graph, bool tune, bool profiling) {
+    run(graph, tune, profiling);
+}
 
 void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
     if (!tune && profiling)
@@ -159,6 +162,6 @@ void CpuRuntimeObj::copyBlobInsideRuntime(void *dst, const void *src,
     memcpy(dst, src, bytes);
 }
 
-string CpuRuntimeObj::toString() const { return "CPU Runtime"; }
+string NativeCpuRuntimeObj::toString() const { return "CPU Runtime"; }
 
 } // namespace infini
@@ -3,6 +3,7 @@
 #include "core/operator.h"
 #include "core/runtime.h"
 #include "utils/dataloader.h"
 #include <cstring>
 #include <numeric>
 
 namespace infini {
@@ -157,7 +158,7 @@ void TensorObj::setData(
         generator(getRawDataPtr<void *>(), size(), dtype);
     } else {
         // Create a CPU buffer for the generator and copy results to the device
-        auto cpuRuntime = CpuRuntimeObj::getInstance();
+        auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
         size_t nBytes = size() * dtype.getSize();
         Blob buffer = cpuRuntime->allocBlob(nBytes);
         generator(buffer->getPtr<void *>(), size(), dtype);
@@ -200,5 +201,4 @@ size_t TensorObj::getOffsetByBroadcastOffset(size_t bcOffset,
     }
     return getOffsetByPos(pos, shape);
 }
 
 }; // namespace infini
@@ -17,7 +17,7 @@ double getPerfConvCudnn(int n, int c, int h, int w, int f, int r, int s,
     // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
     //             dilationh, dilationw, group] =
     //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -51,7 +51,7 @@ double getPerfConvTransposed2dCudnn(int n, int c, int h, int w, int f, int r,
     // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
     //             dilationh, dilationw, group] =
     //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -83,7 +83,7 @@ double getPerfMatmulCublas(int b, int m, int n, int k, const char *name) {
     // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
     //             dilationh, dilationw, group] =
     //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -109,4 +109,4 @@ double getPerfMatmulCublas(int b, int m, int n, int k, const char *name) {
 }
 
 } // namespace opTimer
-} // namespace infini
+} // namespace infini
@@ -12,7 +12,9 @@
 #include "cuda/cuda_runtime.h"
 #include "cuda/operator_timer.h"
 #endif
 
+#ifdef USE_MKL
+#include "mkl/operator_timer.h"
+#endif
 namespace py = pybind11;
 
 namespace infini {
@@ -27,6 +29,13 @@ void register_operator_timer(py::module &m) {
     m.def("getPerfConvTransposed2dCudnn", &getPerfConvTransposed2dCudnn);
     m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
 #endif
+
+#ifdef USE_MKL
+    using namespace opTimer;
+    m.def("getPerfConvMkl", &getPerfConvMkl);
+    m.def("getPerfConvTransposed2dMkl", &getPerfConvTransposed2dMkl);
+    m.def("getPerfMatmulMkl", &getPerfMatmulMkl);
+#endif
 }
 
 void export_values(py::module &m) {
@@ -149,7 +158,7 @@ static Shape reshape_shape_of(Operator op) {
 
 void export_functions(py::module &m) {
 #define FUNCTION(NAME) def(#NAME, &NAME)
-    m.def("cpu_runtime", &CpuRuntimeObj::getInstance)
+    m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
 #ifdef USE_CUDA
         .FUNCTION(cuda_runtime)
 #endif
@@ -168,8 +177,8 @@ void init_graph_builder(py::module &m) {
     using Handler = GraphHandlerObj;
 
     py::class_<RuntimeObj, std::shared_ptr<RuntimeObj>>(m, "Runtime");
-    py::class_<CpuRuntimeObj, std::shared_ptr<CpuRuntimeObj>, RuntimeObj>(
-        m, "CpuRuntime");
+    py::class_<NativeCpuRuntimeObj, std::shared_ptr<NativeCpuRuntimeObj>,
+               RuntimeObj>(m, "CpuRuntime");
 #ifdef USE_CUDA
     py::class_<CudaRuntimeObj, std::shared_ptr<CudaRuntimeObj>, RuntimeObj>(
         m, "CudaRuntime");
@@ -184,7 +193,8 @@ void init_graph_builder(py::module &m) {
         .def("copyout_int32", &TensorObj::copyout<int32_t>, policy::move)
         .def("copyout_int64", &TensorObj::copyout<int64_t>, policy::move)
         .def("has_target", &TensorObj::hasTarget, policy::automatic)
-        .def("src", &TensorObj::getSource, policy::move);
+        .def("src", &TensorObj::getSource, policy::move)
+        .def("printData", &TensorObj::printData, policy::automatic);
     py::class_<OperatorObj, std::shared_ptr<OperatorObj>>(m, "Operator")
         .def("op_type", &OperatorObj::getOpType, policy::automatic)
         .def("inputs", py::overload_cast<>(&OperatorObj::getInputs, py::const_),
@@ -18,8 +18,6 @@ class BatchNormCudnn : public CudaKernelWithoutConfig {
         void *const biasData = (op->getInputs(4)->getRawDataPtr<void *>());
 
         auto dims = op->getInputs(0)->getDims();
-        if (dims.size() == 2)
-            IT_TODO_HALT();
         // Only 4D and 5D tensors are supported by
         // cudnnBatchNormalizationForwardInference
         IT_ASSERT(dims.size() == 4 || dims.size() == 5);
@@ -0,0 +1,237 @@
+#include "operators/conv.h"
+#include "core/kernel.h"
+#include "mkl/mkl_runtime.h"
+
+namespace infini {
+struct ConvMklPerfRecordObj : public PerfRecordObj {
+    dnnl::algorithm algo = dnnl::algorithm::convolution_auto;
+    void to_json(json &j) override {
+        j["type"] = 1;
+        j["data"] = std::make_tuple(enum_to_underlying(algo), time);
+    }
+    static PerfRecord from_json(const json &j) {
+        ConvMklPerfRecordObj tmp;
+        auto [Algo, Time] = j["data"].get<tuple<int, double>>();
+        tmp.algo = (dnnl::algorithm)Algo;
+        tmp.time = Time;
+        return make_ref<ConvMklPerfRecordObj>(tmp);
+    }
+};
+
+using ConvMklPerfRecord = Ref<ConvMklPerfRecordObj>;
+class MklConv : public Kernel {
+    bool createPrimitives(
+        const Ref<ConvObj> &op, const ConvMklPerfRecord &record,
+        const MklRuntimeObj *context, bool allowEmpty,
+        std::vector<dnnl::primitive> &prims,
+        std::vector<std::unordered_map<int, dnnl::memory>> &primArgs) const {
+        auto srcData = op->getInputs(0)->getRawDataPtr<float *>();
+        auto wData = op->getInputs(1)->getRawDataPtr<float *>();
+        auto dstData = op->getOutput(0)->getRawDataPtr<float *>();
+
+        auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
+        auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+        const int cpg = op->getChannelPerGroup();
+
+        auto oDims = op->getOutput(0)->getDims();
+        int oH = oDims[oDims.size() - 2];
+        int oW = oDims[oDims.size() - 1];
+
+        // create user memory that describes data layout in the buffers
+        auto userSrcMd =
+            dnnl::memory::desc({n, c, h, w}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::nchw);
+        auto userSrcMemory =
+            dnnl::memory(userSrcMd, context->getEngine(), srcData);
+
+        auto userWMd =
+            dnnl::memory::desc({f, cpg, r, s}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::oihw);
+        auto userWMemory = dnnl::memory(userWMd, context->getEngine(), wData);
+        auto userDstMd =
+            dnnl::memory::desc({n, f, oH, oW}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::nchw);
+
+        // create memory descriptors with layout tag::any, to let convolution
+        // choose memory format
+        // Convolution and inner product primitives choose the memory format
+        // when you create them with the placeholder memory format
+        // dnnl::memory::format_tag::any for input or output. The memory format
+        // chosen is based on different circumstances such as hardware and
+        // convolutional parameters. Using the placeholder memory format is the
+        // recommended practice for convolutions, since they are the most
+        // compute-intensive operations in most topologies where they are
+        // present.
+        auto srcMd =
+            dnnl::memory::desc({n, c, h, w}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::any);
+        auto wMd =
+            dnnl::memory::desc({f, cpg, r, s}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::any);
+        auto dstMd =
+            dnnl::memory::desc({n, f, oH, oW}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::any);
+
+        // create convolution descriptor
+        dnnl::memory::dims strides = {sh, sw};
+        dnnl::memory::dims pads = {ph, pw};
+        dnnl::memory::dims dilations = {dh - 1, dw - 1};
+        auto convDesc = dnnl::convolution_forward::desc(
+            dnnl::prop_kind::forward_inference, record->algo, srcMd, wMd, dstMd,
+            strides, dilations, pads, pads);
+
+        dnnl::convolution_forward::primitive_desc primDesc;
+
+        // fused convolution
+        // The non-intensive operation is added as a post-op attribute to the
+        // compute intensive primitive descriptor
+        if (ActType::None != op->getAct()) {
+            dnnl::algorithm algo;
+            switch (op->getAct()) {
+            case ActType::Relu:
+                algo = dnnl::algorithm::eltwise_relu;
+                break;
+            case ActType::Sigmoid:
+                algo = dnnl::algorithm::eltwise_logsigmoid;
+                break;
+            case ActType::Tanh:
+                algo = dnnl::algorithm::eltwise_tanh;
+                break;
+
+            default:
+                IT_ASSERT(0);
+            }
+            dnnl::primitive_attr attr;
+            dnnl::post_ops po;
+            po.append_eltwise(1.f, algo, 0.f, 0.f);
+            attr.set_post_ops(po);
+
+            primDesc = dnnl::convolution_forward::primitive_desc(
+                convDesc, attr, context->getEngine(), allowEmpty);
+
+        } else {
+            primDesc = dnnl::convolution_forward::primitive_desc(
+                convDesc, context->getEngine(), allowEmpty);
+        }
+
+        if (primDesc.get(allowEmpty) == nullptr)
+            return false;
+
+        // reorder data and weight
+        auto srcMemory = userSrcMemory;
+        if (primDesc.src_desc() != userSrcMemory.get_desc()) {
+            srcMemory = dnnl::memory(primDesc.src_desc(), context->getEngine());
+
+            prims.push_back(dnnl::reorder(userSrcMemory, srcMemory));
+            primArgs.push_back(
+                {{DNNL_ARG_FROM, userSrcMemory}, {DNNL_ARG_TO, srcMemory}});
+        }
+
+        auto wMemory = userWMemory;
+        if (primDesc.weights_desc() != userWMemory.get_desc()) {
+            wMemory =
+                dnnl::memory(primDesc.weights_desc(), context->getEngine());
+
+            prims.push_back(dnnl::reorder(userWMemory, wMemory));
+            primArgs.push_back(
+                {{DNNL_ARG_FROM, userWMemory}, {DNNL_ARG_TO, wMemory}});
+        }
+
+        // Create memory for output
+        if (primDesc.dst_desc() == userDstMd) {
+            auto output = dnnl::memory(primDesc.dst_desc(),
+                                       context->getEngine(), dstData);
+
+            // create convolution primitive
+            prims.push_back(dnnl::convolution_forward(primDesc));
+            primArgs.push_back({{DNNL_ARG_SRC, srcMemory},
+                                {DNNL_ARG_WEIGHTS, wMemory},
+                                {DNNL_ARG_DST, output}});
+        } else {
+            auto dstMemory =
+                dnnl::memory(primDesc.dst_desc(), context->getEngine());
+
+            // create convolution primitive
+            prims.push_back(dnnl::convolution_forward(primDesc));
+            primArgs.push_back({{DNNL_ARG_SRC, srcMemory},
+                                {DNNL_ARG_WEIGHTS, wMemory},
+                                {DNNL_ARG_DST, dstMemory}});
+
+            auto output =
+                dnnl::memory(userDstMd, context->getEngine(), dstData);
+            prims.push_back(dnnl::reorder(dstMemory, output));
+            primArgs.push_back(
+                {{DNNL_ARG_FROM, dstMemory}, {DNNL_ARG_TO, output}});
+        }
+        return true;
+    }
+
+    void compute(const Operator &_op, const PerfRecord &_record,
+                 const RuntimeObj *_context) const {
+        auto op = as<ConvObj>(_op);
+        auto context = dynamic_cast<const MklRuntimeObj *>(_context);
+        auto record = as<ConvMklPerfRecordObj>(_record);
+
+        dnnl::stream stream(context->getEngine());
+        std::vector<dnnl::primitive> prims;
+        std::vector<std::unordered_map<int, dnnl::memory>> primArgs;
+        IT_ASSERT(createPrimitives(op, record, context, true, prims, primArgs));
+
+        IT_ASSERT(prims.size() == primArgs.size());
+        for (size_t i = 0; i < prims.size(); ++i)
+            prims.at(i).execute(stream, primArgs.at(i));
+        stream.wait();
+    }
+
+    void compute(const Operator &op, const RuntimeObj *context) const override {
+        auto record = make_ref<ConvMklPerfRecordObj>();
+        compute(op, record, context);
+    }
+
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        ConvMklPerfRecordObj ret;
+        ret.time = std::numeric_limits<double>::max();
+        auto context = dynamic_cast<const MklRuntimeObj *>(_context);
+        auto op = as<ConvObj>(_op);
+
+        // Try every possible algorithm of convolution
+        for (auto algo : {dnnl::algorithm::convolution_auto,
+                          dnnl::algorithm::convolution_direct,
+                          dnnl::algorithm::convolution_winograd}) {
+            ConvMklPerfRecordObj record;
+            record.algo = algo;
+
+            std::vector<dnnl::primitive> prims;
+            std::vector<std::unordered_map<int, dnnl::memory>> primArgs;
+            if (!createPrimitives(op, make_ref<ConvMklPerfRecordObj>(record),
+                                  context, true, prims, primArgs))
+                continue;
+
+            IT_ASSERT(prims.size() == primArgs.size());
+            dnnl::stream stream(context->getEngine());
+            for (size_t i = 0; i < prims.size(); ++i)
+                prims.at(i).execute(stream, primArgs.at(i));
+            stream.wait();
+
+            record.time = timeit(
+                [&]() {
+                    for (size_t i = 0; i < prims.size(); ++i)
+                        prims.at(i).execute(stream, primArgs.at(i));
+                },
+                [&]() { stream.wait(); });
+
+            // Update the tune result
+            if (ret.time > record.time)
+                ret = record;
+        }
+
+        IT_ASSERT(ret.time < std::numeric_limits<double>::max(),
+                  "No valid algorithm found");
+        return make_ref<ConvMklPerfRecordObj>(ret);
+    }
+};
+REGISTER_KERNEL(Device::MKL, OpType::Conv, DataType::Float32, MklConv,
+                "MklConv_CPU_float32");
+} // namespace infini
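One easy-to-miss convention above: oneDNN counts dilation as the number of zeros inserted between kernel taps, so the framework's dilation d maps to d - 1 on the oneDNN side, which is why the code passes dilations = {dh - 1, dw - 1}. A worked check:

    // framework: effective kernel extent = (r - 1) * dh + 1          (dh >= 1)
    // oneDNN:    effective kernel extent = (r - 1) * (DH + 1) + 1    (DH >= 0)
    // the two agree when DH = dh - 1; e.g. r = 3, dh = 2 -> extent 5, oneDNN DH = 1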
@@ -0,0 +1,250 @@
+#include "core/kernel.h"
+#include "mkl/mkl_runtime.h"
+#include "operators/conv.h"
+
+namespace infini {
+struct ConvTransposeMklPerfRecordObj : public PerfRecordObj {
+    dnnl::algorithm algo = dnnl::algorithm::deconvolution_direct;
+    void to_json(json &j) override {
+        j["type"] = 1;
+        j["data"] = std::make_tuple(enum_to_underlying(algo), time);
+    }
+    static PerfRecord from_json(const json &j) {
+        ConvTransposeMklPerfRecordObj tmp;
+        auto [Algo, Time] = j["data"].get<tuple<int, double>>();
+        tmp.algo = (dnnl::algorithm)Algo;
+        tmp.time = Time;
+        return make_ref<ConvTransposeMklPerfRecordObj>(tmp);
+    }
+};
+
+using ConvTransposeMklPerfRecord = Ref<ConvTransposeMklPerfRecordObj>;
+class MklConvTranspose : public Kernel {
+  private:
+    bool createPrimitives(
+        const Ref<ConvTransposed2dObj> &op,
+        const ConvTransposeMklPerfRecord &record, const MklRuntimeObj *context,
+        bool allowEmpty, std::vector<dnnl::primitive> &prims,
+        std::vector<std::unordered_map<int, dnnl::memory>> &primArgs) const {
+        auto srcData = op->getInputs(0)->getRawDataPtr<float *>();
+        auto wData = op->getInputs(1)->getRawDataPtr<float *>();
+        // FIXME: iohw2iohwData
+        auto dstData = op->getOutput(0)->getRawDataPtr<float *>();
+
+        auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
+        auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
+        const int cpg = op->getChannelPerGroup();
+        if (cpg != c)
+            IT_TODO_HALT();
+
+        auto oDims = op->getOutput(0)->getDims();
+        int oH = oDims[oDims.size() - 2];
+        int oW = oDims[oDims.size() - 1];
+
+        // create user memory that describes data layout in the buffers
+        auto userSrcMd =
+            dnnl::memory::desc({n, f, h, w}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::nchw);
+        auto userSrcMemory =
+            dnnl::memory(userSrcMd, context->getEngine(), srcData);
+
+        // DNNL deconvolution expects the logical order of weights (parameters)
+        // dimensions to be in order {o, i, h, w}. So need to reorder wData.
+        // TODO: to make reorder happen only once when inference (because
+        // weights are fixed).
+        // TODO: Fix by whj, change memory format tag from oihw to iohw to
+        // remove extra transpose. Correctness to be confirmed.
+        auto userWMd =
+            dnnl::memory::desc({cpg, f, r, s}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::iohw);
+
+        auto userWMemory = dnnl::memory(userWMd, context->getEngine(), wData);
+
+        auto userDstMd =
+            dnnl::memory::desc({n, c, oH, oW}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::nchw);
+
+        // create memory descriptors with layout tag::any, to let convolution
+        // choose memory format
+        // Convolution and inner product primitives choose the memory format
+        // when you create them with the placeholder memory format
+        // dnnl::memory::format_tag::any for input or output. The memory format
+        // chosen is based on different circumstances such as hardware and
+        // convolutional parameters. Using the placeholder memory format is the
+        // recommended practice for convolutions, since they are the most
+        // compute-intensive operations in most topologies where they are
+        // present.
+        auto srcMd =
+            dnnl::memory::desc({n, f, h, w}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::any);
+        auto wMd =
+            dnnl::memory::desc({cpg, f, r, s}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::any);
+        auto dstMd =
+            dnnl::memory::desc({n, c, oH, oW}, dnnl::memory::data_type::f32,
+                               dnnl::memory::format_tag::any);
+
+        // create convolution descriptor
+        dnnl::memory::dims strides = {sh, sw};
+        dnnl::memory::dims pads = {ph, pw};
+        dnnl::memory::dims dilations = {dh - 1, dw - 1};
+        auto deconvDesc = dnnl::deconvolution_forward::desc(
+            dnnl::prop_kind::forward_inference, record->algo, srcMd, wMd, dstMd,
+            strides, dilations, pads, pads);
+
+        dnnl::deconvolution_forward::primitive_desc primDesc;
+        // fused convolution
+        // The non-intensive operation is added as a post-op attribute to the
+        // compute intensive primitive descriptor
+        if (ActType::None != op->getAct()) {
+            dnnl::algorithm algo;
+            switch (op->getAct()) {
+            case ActType::Relu:
+                algo = dnnl::algorithm::eltwise_relu;
+                break;
+            case ActType::Sigmoid:
+                algo = dnnl::algorithm::eltwise_logsigmoid;
+                break;
+            case ActType::Tanh:
+                algo = dnnl::algorithm::eltwise_tanh;
+                break;
+
+            default:
+                IT_ASSERT(0);
+            }
+            dnnl::primitive_attr attr;
+            dnnl::post_ops po;
+            po.append_eltwise(1.f, algo, 0.f, 0.f);
+            attr.set_post_ops(po);
+
+            primDesc = dnnl::deconvolution_forward::primitive_desc(
+                deconvDesc, attr, context->getEngine(), allowEmpty);
+
+        } else {
+            primDesc = dnnl::deconvolution_forward::primitive_desc(
+                deconvDesc, context->getEngine(), allowEmpty);
+        }
+
+        if (primDesc.get(allowEmpty) == nullptr)
+            return false;
+
+        // reorder data and weight
+        auto srcMemory = userSrcMemory;
+        if (primDesc.src_desc() != userSrcMemory.get_desc()) {
+            srcMemory = dnnl::memory(primDesc.src_desc(), context->getEngine());
+
+            prims.push_back(dnnl::reorder(userSrcMemory, srcMemory));
+            primArgs.push_back(
+                {{DNNL_ARG_FROM, userSrcMemory}, {DNNL_ARG_TO, srcMemory}});
+        }
+
+        auto wMemory = userWMemory;
+        if (primDesc.weights_desc() != userWMemory.get_desc()) {
+            wMemory =
+                dnnl::memory(primDesc.weights_desc(), context->getEngine());
+
+            prims.push_back(dnnl::reorder(userWMemory, wMemory));
+            primArgs.push_back(
+                {{DNNL_ARG_FROM, userWMemory}, {DNNL_ARG_TO, wMemory}});
+        }
+
+        if (primDesc.dst_desc() == userDstMd) {
+            // Create memory for output
+            auto dstMemory = dnnl::memory(primDesc.dst_desc(),
+                                          context->getEngine(), dstData);
+
+            // create convolution primitive
+            prims.push_back(dnnl::deconvolution_forward(primDesc));
+            primArgs.push_back({{DNNL_ARG_SRC, srcMemory},
+                                {DNNL_ARG_WEIGHTS, wMemory},
+                                {DNNL_ARG_DST, dstMemory}});
+        } else {
+            auto dstMemory =
+                dnnl::memory(primDesc.dst_desc(), context->getEngine());
+
+            // create convolution primitive
+            prims.push_back(dnnl::deconvolution_forward(primDesc));
+            primArgs.push_back({{DNNL_ARG_SRC, srcMemory},
+                                {DNNL_ARG_WEIGHTS, wMemory},
+                                {DNNL_ARG_DST, dstMemory}});
+
+            auto output =
+                dnnl::memory(userDstMd, context->getEngine(), dstData);
+
+            prims.push_back(dnnl::reorder(dstMemory, output));
+            primArgs.push_back(
+                {{DNNL_ARG_FROM, dstMemory}, {DNNL_ARG_TO, output}});
+        }
+        return true;
+    }
+
+    void compute(const Operator &_op, const PerfRecord &_record,
+                 const RuntimeObj *_context) const {
+        auto op = as<ConvTransposed2dObj>(_op);
+        auto context = dynamic_cast<const MklRuntimeObj *>(_context);
+        auto record = as<ConvTransposeMklPerfRecordObj>(_record);
+
+        dnnl::stream stream(context->getEngine());
+        std::vector<dnnl::primitive> prims;
+        std::vector<std::unordered_map<int, dnnl::memory>> primArgs;
+        IT_ASSERT(createPrimitives(op, record, context, true, prims, primArgs));
+
+        IT_ASSERT(prims.size() == primArgs.size());
+        for (size_t i = 0; i < prims.size(); ++i)
+            prims.at(i).execute(stream, primArgs.at(i));
+        stream.wait();
+    }
+
+    void compute(const Operator &op, const RuntimeObj *context) const override {
+        auto record = make_ref<ConvTransposeMklPerfRecordObj>();
+        compute(op, record, context);
+    }
+
+    PerfRecord tune(const Operator &_op,
+                    const RuntimeObj *_context) const override {
+        ConvTransposeMklPerfRecordObj ret;
+        ret.time = std::numeric_limits<double>::max();
+        auto context = dynamic_cast<const MklRuntimeObj *>(_context);
+        auto op = as<ConvTransposed2dObj>(_op);
+
+        // Try every possible algorithm of convolution
+        for (auto algo : {dnnl::algorithm::deconvolution_direct,
+                          dnnl::algorithm::deconvolution_winograd}) {
+            ConvTransposeMklPerfRecordObj record;
+            record.algo = algo;
+
+            std::vector<dnnl::primitive> prims;
+            std::vector<std::unordered_map<int, dnnl::memory>> primArgs;
+            if (!createPrimitives(
+                    op, make_ref<ConvTransposeMklPerfRecordObj>(record),
+                    context, true, prims, primArgs))
+                continue;
+
+            IT_ASSERT(prims.size() == primArgs.size());
+            dnnl::stream stream(context->getEngine());
+            for (size_t i = 0; i < prims.size(); ++i)
+                prims.at(i).execute(stream, primArgs.at(i));
+            stream.wait();
+
+            record.time = timeit(
+                [&]() {
+                    for (size_t i = 0; i < prims.size(); ++i)
+                        prims.at(i).execute(stream, primArgs.at(i));
+                },
+                [&]() { stream.wait(); });
+
+            // Update the tune result
+            if (ret.time > record.time)
+                ret = record;
+        }
+
+        IT_ASSERT(ret.time < std::numeric_limits<double>::max(),
+                  "No valid algorithm found");
+        return make_ref<ConvTransposeMklPerfRecordObj>(ret);
+    }
+};
+REGISTER_KERNEL(Device::MKL, OpType::ConvTrans, DataType::Float32,
+                MklConvTranspose, "MklConvTrans_CPU_float32");
+
+} // namespace infini
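The activation fusion in both kernels is the standard oneDNN post-ops pattern; isolated, against the v2.x API used here (append_eltwise takes scale, algorithm, alpha, beta):

    dnnl::primitive_attr attr;
    dnnl::post_ops po;
    // fuse y = relu(deconv(x, w)) into a single primitive; alpha/beta unused for relu
    po.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
    attr.set_post_ops(po);
    auto pd = dnnl::deconvolution_forward::primitive_desc(deconvDesc, attr, engine);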
@@ -0,0 +1,38 @@
+#include "operators/matmul.h"
+#include "core/kernel.h"
+#include "mkl/mkl_runtime.h"
+
+namespace infini {
+
+template <typename T> class MklMatmul : public CpuKernelWithoutConfig {
+    void compute(const Operator &_op,
+                 const RuntimeObj *context) const override {
+        auto op = as<MatmulObj>(_op);
+        IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
+        const T *A = op->getInputs(0)->getRawDataPtr<T *>();
+        const T *B = op->getInputs(1)->getRawDataPtr<T *>();
+        T *C = op->getOutput()->getRawDataPtr<T *>();
+        IT_ASSERT(op->getAct() == ActType::None);
+        const int m = op->getM(), n = op->getN(), k = op->getK(),
+                  b = op->getB();
+
+        auto opA = op->getTransA() ? CblasTrans : CblasNoTrans;
+        auto opB = op->getTransB() ? CblasTrans : CblasNoTrans;
+        // lda is always a.col, and ldb is always b.col when row major
+        const int lda = std::max((opA == CblasNoTrans) ? k : m, 1);
+        const int ldb = std::max((opB == CblasNoTrans) ? n : k, 1);
+        const int ldc = std::max(n, 1);
+
+        const float alpha = 1.f, beta = 0.f;
+        // TODO: Intel MKL ERROR will occur when using cblas_sgemm_batch
+        for (int i = 0; i < b; ++i) {
+            cblas_sgemm(CblasRowMajor, opA, opB, m, n, k, alpha, A + m * k * i,
+                        lda, B + k * n * i, ldb, beta, C + m * n * i, ldc);
+        }
+    }
+};
+
+REGISTER_KERNEL(Device::MKL, OpType::Matmul, DataType::Float32,
+                MklMatmul<float>, "MklMatmul_CPU_float32");
+
+} // namespace infini
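The leading dimensions deserve a note: for row-major BLAS, ld is the in-memory row stride of the buffer as stored, i.e. the column count of the untransposed matrix. A worked instance (shapes assumed for illustration):

    // C[2x3] = A[2x4] * B[4x3], row-major, no transposes:
    //   lda = k = 4 (row stride of A), ldb = n = 3, ldc = n = 3
    // if A arrives transposed (opA == CblasTrans, buffer laid out k x m = 4 x 2):
    //   lda = m = 2, matching "(opA == CblasNoTrans) ? k : m" above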
@@ -0,0 +1,13 @@
+#include "mkl/mkl_runtime.h"
+#include "core/graph.h"
+#include "core/kernel.h"
+namespace infini {
+MklRuntimeObj::MklRuntimeObj() : CpuRuntimeObj(Device::MKL) {
+    dnnl_engine_create(&engine, dnnl_engine_kind_t::dnnl_cpu, 0);
+}
+
+MklRuntimeObj::~MklRuntimeObj() {
+    mkl_free_buffers();
+    dnnl_engine_destroy(engine);
+}
+} // namespace infini
@@ -0,0 +1,82 @@
+#include "core/graph.h"
+#include "core/kernel.h"
+#include "core/runtime.h"
+#include "mkl/mkl_runtime.h"
+#include "operators/conv.h"
+#include "operators/matmul.h"
+#include "utils/data_generator.h"
+
+namespace infini {
+namespace opTimer {
+
+double getPerfConvMkl(int n, int c, int h, int w, int f, int r, int s, int padh,
+                      int padw, int strideh, int stridew, int dilationh,
+                      int dilationw, int group) {
+    // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
+    //             dilationh, dilationw, group] =
+    //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
+    Runtime runtime = MklRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph g = make_ref<GraphObj>(runtime);
+    IT_ASSERT(c % group == 0);
+    Tensor i0 = g->addTensor({n, c, h, w}, DataType::Float32);
+    Tensor w0 = g->addTensor({f, c / group, r, s}, DataType::Float32);
+    auto conv = g->addOp<ConvObj>(i0, w0, nullptr, padh, padw, strideh, stridew,
+                                  dilationh, dilationw);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    g->dataMalloc();
+    i0->setData(IncrementalGenerator());
+    w0->setData(IncrementalGenerator());
+
+    bool tune = true;
+    runtime->run(g, tune);
+    return runtime->getPerfTime(g);
+}
+
+double getPerfConvTransposed2dMkl(int n, int c, int h, int w, int f, int r,
+                                  int s, int padh, int padw, int strideh,
+                                  int stridew, int dilationh, int dilationw,
+                                  int oph, int opw, int group) {
+    // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
+    //             dilationh, dilationw, group] =
+    //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
+    Runtime runtime = MklRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph g = make_ref<GraphObj>(runtime);
+    // Set input data on CPU in a CPU Graph
+    IT_ASSERT(c % group == 0);
+    Tensor i0 = g->addTensor({n, f, h, w}, DataType::Float32);
+    Tensor w0 = g->addTensor({f, c / group, r, s}, DataType::Float32);
+    auto conv = g->addOp<ConvTransposed2dObj>(i0, w0, nullptr, padh, padw,
+                                              strideh, stridew, dilationh,
+                                              dilationw, oph, opw, group);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    g->dataMalloc();
+    i0->setData(IncrementalGenerator());
+    w0->setData(IncrementalGenerator());
+
+    bool tune = true;
+    runtime->run(g, tune);
+    return runtime->getPerfTime(g);
+}
+
+double getPerfMatmulMkl(int b, int m, int n, int k) {
+    // const auto &[n, c, h, w, f, r, s, padh, padw, strideh, stridew,
+    //             dilationh, dilationw, group] =
+    //     tuple{1, 512, 14, 14, 512, 3, 3, 2, 2, 1, 1, 2, 2, 1};
+    Runtime runtime = MklRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph g = make_ref<GraphObj>(runtime);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0 = g->addTensor({b, m, k}, DataType::Float32);
+    Tensor w0 = g->addTensor({b, k, n}, DataType::Float32);
+    auto conv = g->addOp<MatmulObj>(i0, w0, nullptr);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    g->dataMalloc();
+    i0->setData(IncrementalGenerator());
+    w0->setData(IncrementalGenerator());
+
+    bool tune = true;
+    runtime->run(g, tune);
+    return runtime->getPerfTime(g);
+}
+
+} // namespace opTimer
+} // namespace infini
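All three timers follow the same contract: run(g, tune=true) measures and caches the best perf record per operator, and getPerfTime(g) then sums the cached records without executing anything (see the RuntimeObj doc comment earlier in this diff). Condensed:

    runtime->run(g, /*tune=*/true);      // profile once; cache best algorithm per op
    double ms = runtime->getPerfTime(g); // replay cached records; no execution happens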
@@ -607,4 +607,4 @@ double NMutator::memboundTime(const Shape &dims) {
     // return graph;
     // }
 
-} // namespace infini
+} // namespace infini
@@ -38,7 +38,7 @@ bool GatherObj::CheckIndexValid() const {
     if (index->getDataBlob() == nullptr)
        return true;
 
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     int *data = (int *)runtime->alloc(index->getBytes());
     index->getRuntime()->copyBlobToCPU(
         (void *)data, index->getRawDataPtr<void *>(), index->getBytes());
@@ -57,7 +57,7 @@ void ResizeObj::init(const Tensor &input, const Tensor &sizes,
         this->roi.emplace_back(1);
     }
 
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     std::shared_ptr<float> dataObj((float *)runtime->alloc(roi->getBytes()),
                                    [&](float *p) { runtime->dealloc(p); });
     auto data = dataObj.get();
@@ -117,7 +117,7 @@ void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
 
     // copy sizes data to host.
     IT_ASSERT(sizes->getDataBlob() != nullptr);
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     std::shared_ptr<int> dataObj((int *)runtime->alloc(sizes->getBytes()),
                                  [&](int *p) { runtime->dealloc(p); });
     auto data = dataObj.get();
@@ -166,7 +166,7 @@ void ResizeObj::InitByScales(Tensor input, Tensor scales,
 
     // copy scales data to host.
     IT_ASSERT(scales->getDataBlob() != nullptr);
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     std::shared_ptr<float> dataObj((float *)runtime->alloc(scales->getBytes()),
                                    [&](float *p) { runtime->dealloc(p); });
     auto data = dataObj.get();
@@ -9,7 +9,7 @@
 namespace infini {
 
 TEST(Graph, build_and_run) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
     Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
@@ -38,7 +38,7 @@ TEST(Graph, build_and_run) {
 }
 
 TEST(Graph, topological) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
     Tensor a = g->addTensor({1, 2, 3}, DataType::UInt32);
     Tensor b = g->addTensor({1, 2, 3}, DataType::UInt32);
@@ -77,7 +77,7 @@ TEST(Graph, topological) {
 } // namespace infini
 
 TEST(Graph, perf_engine) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
     Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
@@ -99,7 +99,7 @@ TEST(Graph, perf_engine) {
 }
 
 TEST(Graph, test_tensor_id) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
     Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
@@ -117,7 +117,7 @@ TEST(Graph, test_tensor_id) {
 }
 
 TEST(Graph, test_OpVec_ctor) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
     Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
@@ -5,7 +5,7 @@
 namespace infini {
 
 TEST(Handler, matmul) {
-    auto runtime = CpuRuntimeObj::getInstance();
+    auto runtime = NativeCpuRuntimeObj::getInstance();
     auto handler = make_ref<GraphHandlerObj>(runtime);
     auto i = handler->tensor({1, 2, 3}, OnnxDType::UINT32);
     auto w = handler->tensor({1, 3, 4}, OnnxDType::UINT32);
@@ -13,7 +13,7 @@
 namespace infini {
 
 // TEST(Graph, search) {
-//     Runtime runtime = CpuRuntimeObj::getInstance();
+//     Runtime runtime = NativeCpuRuntimeObj::getInstance();
 //     Graph g = make_ref<GraphObj>(runtime);
 //     Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
 //     Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);
@@ -30,7 +30,7 @@ namespace infini {
 // }
 
 TEST(Graph, search_withdm) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
     Tensor t0 = g->addTensor({1, 3, 224, 224});
     Tensor w0 = g->addTensor({3, 3, 3, 3});
@@ -53,7 +53,7 @@ TEST(Graph, search_withdm) {
 }
 
 // TEST(DummyMutator, run) {
-//     Runtime runtime = CpuRuntimeObj::getInstance();
+//     Runtime runtime = NativeCpuRuntimeObj::getInstance();
 //     Graph g = make_ref<GraphObj>(runtime);
 //     Tensor i0 = g->addTensor({1, 3, 224, 224});
 //     Tensor w0 = g->addTensor({2, 3, 3, 3});
@@ -67,7 +67,7 @@ TEST(Graph, search_withdm) {
 // }
 
 // TEST(DummyMutator, fuse) {
-//     Runtime runtime = CpuRuntimeObj::getInstance();
+//     Runtime runtime = NativeCpuRuntimeObj::getInstance();
 //     Graph g = make_ref<GraphObj>(runtime);
 //     Tensor i0 = g->addTensor({1, 2, 3});
 //     Tensor w0 = g->addTensor({1, 3, 4});
@@ -7,7 +7,7 @@
 namespace infini {
 
 TEST(Prtotbuf, save_and_load) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph g = make_ref<GraphObj>(runtime);
     Tensor i0 = g->addTensor({1, 3, 4}, DataType::Float32);
     Tensor w0 = g->addTensor({1, 3, 4}, DataType::Float32);
@@ -15,7 +15,7 @@ void testBangcKernel(
     const std::function<void(void *, size_t, DataType)> &generator,
     const Shape &shape) {
     // Runtime
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
 
     // Build input data on CPU
@@ -13,7 +13,7 @@ template <class T>
 void testElementWiseCnnl(
     const std::function<void(void *, size_t, DataType)> &generator,
     const Shape &shape, const ExpectOutput &ansVec) {
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
 
     // Build input data on CPU
@@ -13,7 +13,7 @@ void testOptensor(
     const std::function<void(void *, size_t, DataType)> &generator,
     const Shape &shape) {
     // Runtime
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
     auto bangRuntime = make_ref<BangRuntimeObj>();
 
     // Build input data on CPU
@@ -13,7 +13,7 @@ using ExpectOutput = vector<float>;
 TEST(CUDA_G2BMM, ShapeInference) {
     const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
     const int hidden = featlen, hiddenPerHead = hidden / heads;
-    auto cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(cpuRuntime);
     auto ACpu = gCpu->addTensor(Shape{bs * heads, seqlen, hiddenPerHead},
                                 DataType::Float32);
@@ -12,7 +12,7 @@ using ExpectOutput = vector<float>;
 TEST(CUDA_GBMM, ShapeInference) {
     const int bs = 1, seqlen = 10000, w = 1000, featlen = 512, heads = 8, d = 4;
     const int hidden = featlen, hiddenPerHead = hidden / heads;
-    auto cpuRuntime = CpuRuntimeObj::getInstance();
+    auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(cpuRuntime);
     auto ACpu = gCpu->addTensor(Shape{bs * heads, seqlen, w * 2 + 1},
                                 DataType::Float32);
@@ -8,16 +8,16 @@
 namespace infini {
 
 TEST(CUDA_BatchNorm, run) {
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
 
     // Build cpu graph
     Graph gCpu = make_ref<GraphObj>(cpuRuntime);
     auto iCpu = gCpu->addTensor(Shape{1, 3, 2, 2}, DataType::Float32);
-    auto meanCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
-    auto varCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
-    auto scaleCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
-    auto biasCpu = gCpu->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
+    auto meanCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
+    auto varCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
+    auto scaleCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
+    auto biasCpu = gCpu->addTensor(Shape{3}, DataType::Float32);
 
     // Build input data on CPU
     gCpu->dataMalloc();
@@ -44,7 +44,7 @@ TEST(Concat, OffsetTrans) {
 }
 */
 TEST(Concat, Cuda) {
-    Runtime runtime = CpuRuntimeObj::getInstance();
+    Runtime runtime = NativeCpuRuntimeObj::getInstance();
     Graph gCpu = make_ref<GraphObj>(runtime);
 
     auto t1 = gCpu->addTensor({2, 2, 3, 1}, DataType::Float32);
@@ -13,7 +13,7 @@ void testConvCudnn(
     const std::function<void(void *, size_t, DataType)> &generator,
     vector<float> ansVec) {
     // Construct Runtime and graph for CPU and CUDA
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -52,7 +52,7 @@ TEST(cuDNN_Conv, run) {
 }
 
 TEST(cuDNN_Conv, tune) {
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -16,7 +16,7 @@ void testConvTransposedCudnn(
     const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
     const int stride = 1, padding = 0, dilation = 1;
     // Construct Runtime and graph for CPU and CUDA
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -50,7 +50,7 @@ void testConvTransposedNHWCCudnn(
     const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 2, 4, 4};
     const int stride = 1, padding = 0, dilation = 1;
     // Construct Runtime and graph for CPU and CUDA
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -94,8 +94,42 @@ TEST(cuDNN_ConvTransposedNHWC, run) {
                           465, 487, 509, 307});
 }
 
+TEST(cuDNN_ConvTransposed, run1) {
+    // Construct Runtime and graph for CPU and CUDA
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Graph gCpu = make_ref<GraphObj>(cpu);
+    Runtime cuda = make_ref<CudaRuntimeObj>();
+    Graph gCuda = make_ref<GraphObj>(cuda);
+    // Set input data on CPU in a CPU Graph
+    Tensor i0Cpu = gCpu->addTensor({1, 2, 3, 3}, DataType::Float32);
+    Tensor w0Cpu = gCpu->addTensor({2, 2, 3, 3}, DataType::Float32);
+    // Malloc data for all tensors in a graph. Do we need implicit allocation?
+    gCpu->dataMalloc();
+    i0Cpu->setData(IncrementalGenerator());
+    w0Cpu->setData(IncrementalGenerator());
+
+    // Copy input tensors from CPU to CUDA
+    Tensor i0Cuda = gCuda->cloneTensor(i0Cpu);
+    Tensor w0Cuda = gCuda->cloneTensor(w0Cpu);
+    // Build CUDA graph
+    auto conv =
+        gCuda->addOp<ConvTransposed2dObj>(i0Cuda, w0Cuda, nullptr, 0, 0);
+    gCuda->dataMalloc();
+    // Execute on CUDA
+    cuda->run(gCuda);
+    // copy output from CUDA to CPU
+    auto o0Cpu = gCpu->cloneTensor(conv->getOutput());
+    // check results on CPU
+    EXPECT_TRUE(o0Cpu->equalData(vector<float>{
+        162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553,
+        747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835,
+        396, 843, 1343, 953, 506, 243, 531, 866, 629, 341,
+        621, 1344, 2173, 1564, 841, 1152, 2475, 3975, 2841, 1518,
+        963, 2052, 3271, 2320, 1231, 585, 1239, 1964, 1385, 731}));
+}
+
 TEST(cuDNN_ConvTransposed, tune) {
-    Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+    Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
     Graph gCpu = make_ref<GraphObj>(cpu);
     Runtime cuda = make_ref<CudaRuntimeObj>();
     Graph gCuda = make_ref<GraphObj>(cuda);
@@ -117,8 +151,6 @@ TEST(cuDNN_ConvTransposed, tune) {
     // Execute on CUDA
     bool tune = true;
     cuda->run(gCuda, tune);
-    // print a tensor/operator/graph by print()
-    gCuda->print();
     // check record
     auto kernelAttrs =
         KernelAttrs{Device::CUDA, conv->getOpType(), DataType::Float32};
@@ -14,7 +14,7 @@ template <class T>
 void testElementWiseCudnn(
     const std::function<void(void *, size_t, DataType)> &generator,
     const Shape &shape, const ExpectOutput &ansVec) {
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
 
     // Build input data on CPU
@@ -10,7 +10,7 @@
 namespace infini {
 
 TEST(CUDA_Extend, run) {
-    Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+    Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
     auto cudaRuntime = make_ref<CudaRuntimeObj>();
 
     // Build input data on CPU
@ -176,7 +176,7 @@ TEST(Gather, offsetTrans) {
|
|||
|
||||
TEST(Gather, Cuda) {
|
||||
{
|
||||
Runtime runtime = CpuRuntimeObj::getInstance();
|
||||
Runtime runtime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(runtime);
|
||||
auto input = gCpu->addTensor({3, 2}, DataType::Float32);
|
||||
auto index = gCpu->addTensor({2, 2}, DataType::UInt32);
|
||||
|
@@ -197,7 +197,7 @@ TEST(Gather, Cuda) {
        EXPECT_TRUE(oCpu->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
    }
    {
-       Runtime runtime = CpuRuntimeObj::getInstance();
+       Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({3, 3}, DataType::Float32);
        auto index = gCpu->addTensor({1, 2}, DataType::UInt32);
@@ -218,7 +218,7 @@
        EXPECT_TRUE(oCpu->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
    }
    {
-       Runtime runtime = CpuRuntimeObj::getInstance();
+       Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph gCpu = make_ref<GraphObj>(runtime);
        auto input = gCpu->addTensor({2, 4, 2}, DataType::Float32);
        auto index = gCpu->addTensor({3, 1}, DataType::UInt32);

@@ -16,7 +16,7 @@ void testMatmulCuda(
    const std::function<void(void *, size_t, DataType)> &generatorB,
    bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
    const ExpectOutput &ansVec) {
-   auto cpuRuntime = CpuRuntimeObj::getInstance();
+   auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
    auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
    auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);

@@ -54,7 +54,7 @@ TEST(cuBLAS_Matmul, run) {
}

TEST(cuBLAS_Matmul, tune) {
-   auto cpuRuntime = CpuRuntimeObj::getInstance();
+   auto cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
    auto ACpu = gCpu->addTensor(Shape{1, 3, 5}, DataType::Float32);
    auto BCpu = gCpu->addTensor(Shape{1, 5, 2}, DataType::Float32);

@@ -7,7 +7,7 @@

namespace infini {
TEST(Pad, Cuda) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU

@@ -14,7 +14,7 @@ void testPoolCudnn(
    const std::function<void(void *, size_t, DataType)> &generator,
    const Shape &shape, const KDPS &kdps, const ExpectOutput &ansVec) {
    EXPECT_TRUE(kdps.size() == 8);
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU

@@ -12,7 +12,7 @@ namespace infini {
void test_reducemean(const Shape &shape, const vector<float> &data,
                     const optional<const vector<int>> &axis, bool keepDims,
                     const vector<float> &ExpectData) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU

@@ -10,7 +10,7 @@
namespace infini {

TEST(CUDA_Reshape, run) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU
@@ -39,7 +39,7 @@ TEST(CUDA_Reshape, run) {
}

TEST(CUDA_Flatten, run) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU
@@ -68,7 +68,7 @@ TEST(CUDA_Flatten, run) {
}

TEST(CUDA_Identity, run) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU

@@ -7,7 +7,7 @@
#include "test.h"
namespace infini {
TEST(Resize, Cuda_downsample_sizes_nearest) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);

@@ -32,7 +32,7 @@ TEST(Resize, Cuda_downsample_sizes_nearest) {
}

TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);

@@ -62,7 +62,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notlarger) {
}

TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);

@@ -92,7 +92,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_notsmaller) {
}

TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -125,7 +125,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_ceil_half_pixel) {
}

TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -158,7 +158,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_floor_align_corners) {
}

TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -191,7 +191,7 @@ TEST(Resize, Cuda_upsample_sizes_nearest_round_prefer_ceil_asymmetri) {
}

TEST(Resize, Cuda_downsample_scales_nearest) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);

@@ -215,7 +215,7 @@ TEST(Resize, Cuda_downsample_scales_nearest) {
}

TEST(Resize, Cuda_upsample_scales_nearest) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);

@@ -241,7 +241,7 @@ TEST(Resize, Cuda_upsample_scales_nearest) {
}

TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);

@@ -267,7 +267,7 @@ TEST(Resize, Cuda_upsample_scales_nearest_axes_3_2) {
}

TEST(Resize, Cuda_downsample_scales_linear) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);

@@ -291,7 +291,7 @@ TEST(Resize, Cuda_downsample_scales_linear) {
}

TEST(Resize, Cuda_downsample_scales_linear_aligncorners) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);

@@ -317,7 +317,7 @@ TEST(Resize, Cuda_downsample_scales_linear_aligncorners) {
}

TEST(Resize, Cuda_upsample_scales_linear) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);

@@ -343,7 +343,7 @@ TEST(Resize, Cuda_upsample_scales_linear) {
}

TEST(Resize, Cuda_upsample_scales_linear_align_corners) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 2, 2}, DataType::Float32);

@@ -371,7 +371,7 @@ TEST(Resize, Cuda_upsample_scales_linear_align_corners) {
}

TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -399,7 +399,7 @@ TEST(Resize, Cuda_downsample_sizes_linear_pytorchhalfpixel) {
}

TEST(Resize, Cuda_tf_crop_and_resize) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -430,7 +430,7 @@ TEST(Resize, Cuda_tf_crop_and_resize) {
}

TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -461,7 +461,7 @@ TEST(Resize, Cuda_tf_crop_and_resize_axes_3_2) {
}

TEST(Resize, Cuda_downsample_scales_cubic) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -488,7 +488,7 @@ TEST(Resize, Cuda_downsample_scales_cubic) {
}

TEST(Resize, Cuda_downsample_scales_cubic_align_corners) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -516,7 +516,7 @@ TEST(Resize, Cuda_downsample_scales_cubic_align_corners) {
}

TEST(Resize, Cuda_upsample_scales_cubic) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -553,7 +553,7 @@ TEST(Resize, Cuda_upsample_scales_cubic) {
}

TEST(Resize, Cuda_upsample_scales_cubic_align_corners) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -592,7 +592,7 @@ TEST(Resize, Cuda_upsample_scales_cubic_align_corners) {
}

TEST(Resize, Cuda_upsample_scales_cubic_asymmetric) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -627,7 +627,7 @@ TEST(Resize, Cuda_upsample_scales_cubic_asymmetric) {

//
TEST(Resize, Cuda_downsample_sizes_cubic) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -661,7 +661,7 @@ TEST(Resize, Cuda_downsample_sizes_cubic) {
}

TEST(Resize, Cuda_upsample_sizes_cubic) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({1, 1, 4, 4}, DataType::Float32);

@@ -7,7 +7,7 @@

namespace infini {
TEST(CUDA_Slice, run) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU

@@ -9,7 +9,7 @@
namespace infini {

TEST(Split, Cuda) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(runtime);

    auto input = gCpu->addTensor({2, 10, 2, 1}, DataType::Float32);

@@ -13,7 +13,7 @@ template <class T>
void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
               const Shape &shape) {
    // Runtime
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    auto cudaRuntime = make_ref<CudaRuntimeObj>();

    // Build input data on CPU

@@ -11,7 +11,7 @@
namespace infini {

TEST(PerfEngine, save_and_load) {
-   Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+   Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPU runtime is a singleton
    Graph gCpu = make_ref<GraphObj>(cpu);
    Runtime cuda = make_ref<CudaRuntimeObj>();
    { // Conv

@@ -0,0 +1,66 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
#include "core/runtime.h"
#include "mkl/mkl_runtime.h"
#include "operators/conv.h"

#include "test.h"

namespace infini {

void testConvDnnl(
    const std::function<void(void *, size_t, DataType)> &generator,
    vector<float> ansVec) {
    auto mklRuntime = MklRuntimeObj::getInstance();
    Graph gMkl = make_ref<GraphObj>(mklRuntime);

    Tensor i0 = gMkl->addTensor({1, 3, 4, 4}, DataType::Float32);
    Tensor w0 = gMkl->addTensor({2, 3, 3, 3}, DataType::Float32);
    // Malloc data for all tensors in a graph.
    gMkl->dataMalloc();
    i0->setData(generator);
    w0->setData(generator);

    // Build graph
    auto conv = gMkl->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
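    // (reading the trailing ConvObj arguments as padding (1, 1), stride
    // (2, 1) and dilation (1, 2) gives a 2x2 output per channel, matching the
    // 8-element expected vectors in the tests below)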
    // allocate memory for the output tensor created by addOp
    gMkl->dataMalloc();
    // Execute on MKL
    mklRuntime->run(gMkl);
    // check results on CPU
    EXPECT_TRUE(conv->getOutput(0)->equalData(ansVec));
}

TEST(dnnl_Conv, run) {
    testConvDnnl(OneGenerator(), vector<float>{12, 12, 18, 18, 12, 12, 18, 18});
    testConvDnnl(
        IncrementalGenerator(),
        vector<float>{4794, 4386, 8199, 7506, 11274, 10542, 20835, 19656});
}

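// (in the tune test below, tune = true is expected to make the runtime time
// the MKL kernel and store the best record in the global PerfEngine, which
// the test then queries)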
TEST(mkl_Conv, tune) {
    auto mklRuntime = MklRuntimeObj::getInstance();
    Graph gMkl = make_ref<GraphObj>(mklRuntime);

    Tensor i0 = gMkl->addTensor({1, 3, 224, 224}, DataType::Float32);
    Tensor w0 = gMkl->addTensor({2, 3, 3, 3}, DataType::Float32);
    auto conv = gMkl->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 1, 1, 1, 1);
    gMkl->dataMalloc();

    i0->setData(IncrementalGenerator());
    w0->setData(IncrementalGenerator());

    // Execute on MKL with tuning enabled
    bool tune = true;
    mklRuntime->run(gMkl, tune);

    // check record
    auto kernelAttrs =
        KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32};
    auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
    std::optional<PerfRecord> perfData =
        PerfEngine::getInstance().getPerfData(perfKey);
    ASSERT_TRUE(perfData.has_value());
}
} // namespace infini

@@ -0,0 +1,84 @@
#include "core/graph.h"
#include "core/kernel.h"
#include "core/perf_engine.h"
#include "mkl/mkl_runtime.h"
#include "operators/conv.h"

#include "test.h"

namespace infini {

void testConvTransposedMkl(
    const std::function<void(void *, size_t, DataType)> &generator,
    vector<float> ansVec) {
    const auto &[N, C, H, W, F, R, S] = tuple{1, 1, 2, 2, 1, 4, 4};
    const int stride = 1, padding = 0, dilation = 1;
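    // (a 2x2 input with a 4x4 kernel, stride 1 and no padding upsamples each
    // spatial dim to (2 - 1) * 1 + (4 - 1) + 1 = 5, so ansVec should hold
    // 5 x 5 = 25 values)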

    Runtime runtime = MklRuntimeObj::getInstance();
    Graph gMkl = make_ref<GraphObj>(runtime);
    // Set input data on the MKL (CPU) graph
    Tensor i0 = gMkl->addTensor({N, F, H, H}, DataType::Float32);
    Tensor w0 = gMkl->addTensor({F, C, R, S}, DataType::Float32);
    auto conv = gMkl->addOp<ConvTransposed2dObj>(
        i0, w0, nullptr, padding, padding, stride, stride, dilation, dilation);

    gMkl->dataMalloc();
    i0->setData(generator);
    w0->setData(generator);

    runtime->prepareAndRun(gMkl);
    EXPECT_TRUE(conv->getOutput()->equalData(ansVec));
}

TEST(mkl_ConvTransposed, run) {
    testConvTransposedMkl(IncrementalGenerator(),
                          vector<float>{0., 0., 1., 2., 3., 0., 6.,
                                        12., 18., 16., 8., 30., 36., 42.,
                                        32., 16., 54., 60., 66., 48., 24.,
                                        62., 67., 72., 45.});
}

TEST(mkl_ConvTransposed, run1) {
    Runtime runtime = MklRuntimeObj::getInstance();
    Graph gMkl = make_ref<GraphObj>(runtime);
    // Set input data on the MKL (CPU) graph
    Tensor i0 = gMkl->addTensor({1, 2, 3, 3}, DataType::Float32);
    Tensor w0 = gMkl->addTensor({2, 2, 3, 3}, DataType::Float32);
    auto conv = gMkl->addOp<ConvTransposed2dObj>(i0, w0, nullptr, 0, 0);
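    // (same shapes and expected data as the cuDNN ConvTransposed run1 test
    // above, so the MKL and CUDA backends can be checked against each other)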

    gMkl->dataMalloc();
    i0->setData(IncrementalGenerator());
    w0->setData(IncrementalGenerator());

    runtime->prepareAndRun(gMkl);
    EXPECT_TRUE(conv->getOutput()->equalData(vector<float>{
        162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553,
        747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835,
        396, 843, 1343, 953, 506, 243, 531, 866, 629, 341,
        621, 1344, 2173, 1564, 841, 1152, 2475, 3975, 2841, 1518,
        963, 2052, 3271, 2320, 1231, 585, 1239, 1964, 1385, 731}));
}

TEST(mkl_ConvTransposed, tune) {
    Runtime runtime = MklRuntimeObj::getInstance();
    Graph gMkl = make_ref<GraphObj>(runtime);

    Tensor i0 = gMkl->addTensor({1, 448, 2, 2}, DataType::Float32);
    Tensor w0 = gMkl->addTensor({448, 256, 4, 4}, DataType::Float32);
    auto conv = gMkl->addOp<ConvTransposed2dObj>(i0, w0, nullptr);
    gMkl->dataMalloc();
    i0->setData(IncrementalGenerator());
    w0->setData(IncrementalGenerator());

    bool tune = true;
    runtime->prepareAndRun(gMkl, tune);
    // check record
    auto kernelAttrs =
        KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32};
    auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
    std::optional<PerfRecord> perfData =
        PerfEngine::getInstance().getPerfData(perfKey);
    ASSERT_TRUE(perfData.has_value());
}

} // namespace infini

@@ -0,0 +1,44 @@

#include "core/graph.h"
#include "core/kernel.h"
#include "core/runtime.h"
#include "mkl/mkl_runtime.h"
#include "operators/matmul.h"

#include "test.h"

namespace infini {
using ExpectOutput = vector<float>;

void testMatmulMkl(
    const std::function<void(void *, size_t, DataType)> &generatorA,
    const std::function<void(void *, size_t, DataType)> &generatorB,
    bool transA, bool transB, const Shape &shapeA, const Shape &shapeB,
    const ExpectOutput &ansVec) {
    auto cpuRuntime = MklRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(cpuRuntime);
    auto ACpu = gCpu->addTensor(shapeA, DataType::Float32);
    auto BCpu = gCpu->addTensor(shapeB, DataType::Float32);
    gCpu->dataMalloc();
    ACpu->setData(generatorA);
    BCpu->setData(generatorB);

    auto matmul = gCpu->addOp<MatmulObj>(ACpu, BCpu, nullptr, transA, transB);
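    // (transA / transB mark the corresponding operand as stored transposed:
    // e.g. with transA = true below, an A of shape {2, 3, 4} multiplies as
    // {2, 4, 3}, so {2, 4, 3} x {2, 3, 2} yields the 16 expected values)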

    gCpu->dataMalloc();
    cpuRuntime->run(gCpu);
    matmul->getOutput()->printData();
    EXPECT_TRUE(matmul->getOutput()->equalData(ansVec));
}

TEST(mkl_Matmul, run) {
    testMatmulMkl(IncrementalGenerator(), OneGenerator(), false, false,
                  Shape{1, 3, 5}, Shape{1, 5, 2},
                  ExpectOutput{10, 10, 35, 35, 60, 60});
    testMatmulMkl(IncrementalGenerator(), IncrementalGenerator(), true, false,
                  Shape{2, 3, 4}, Shape{2, 3, 2},
                  ExpectOutput{40, 52, 46, 61, 52, 70, 58, 79, 400, 448, 424,
                               475, 448, 502, 472, 529});
}

} // namespace infini

@@ -12,7 +12,7 @@ using namespace infini;
using namespace std;

TEST(nnet, MemboundOpInterpretation) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32);
    Tensor w0 = g->addTensor({1, 3, 4}, DataType::UInt32);

@@ -42,7 +42,7 @@ TEST(nnet, MemboundOpInterpretation) {

TEST(nnet, MemboundOp_Ansor_Codegen) {
    auto runtime = make_ref<CudaRuntimeObj>();
-   Runtime cpu = CpuRuntimeObj::getInstance();
+   Runtime cpu = NativeCpuRuntimeObj::getInstance();
    Graph gCpu = make_ref<GraphObj>(cpu);
    Graph g = make_ref<GraphObj>(runtime);
    Tensor i0 = g->addTensor({1, 2, 3}, DataType::Float32);
@@ -76,4 +76,4 @@ TEST(nnet, MemboundOp_Ansor_Codegen) {
    // Timing
    // double time = timeit([&]() { runtime->run(gNew, false); }); // tune kernels
    // std::cout << "Time (ms):" << time << std::endl;
}
}

@@ -13,7 +13,7 @@ namespace infini {
TEST(Mutator, NaiveConvWithInterpreter) {
    // verifyNaiveMembound True: subgraph after transformation
    // verifyNaiveMembound False: subgraph of one single membound (eOP)
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    // const bool verifyNaiveMembound = false;

@@ -61,7 +61,7 @@ TEST(Mutator, InfoGAN_TConv_3_correctness) {
    // const bool verifyNaiveMembound = false;
    Runtime runtime = make_ref<CudaRuntimeObj>();
    Graph g = make_ref<GraphObj>(runtime);
-   Runtime cpu = CpuRuntimeObj::getInstance(); // CPUruntime is singleton
+   Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPU runtime is a singleton
    Graph gCpu = make_ref<GraphObj>(cpu);

    // {n, h, w, f} * {f, r, s, c}

@@ -5,7 +5,7 @@

namespace infini {
TEST(BatchNorm, ShapeInference) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(cpuRuntime);
        Tensor i = g->addTensor({1, 3, 2, 2}, DataType::UInt32);

@@ -5,7 +5,7 @@

namespace infini {
TEST(Concat, ShapeInfer) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    auto t1 = g->addTensor({1, 3, 2, 4}, DataType::Float32);
    auto t2 = g->addTensor({1, 3, 2, 5}, DataType::Float32);

@@ -8,7 +8,7 @@
namespace infini {

TEST(Conv, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    // Padding modes
    {
        Graph g = make_ref<GraphObj>(runtime);
@@ -43,7 +43,7 @@ TEST(Conv, ShapeInference) {
}

TEST(Conv, NaiveCPU) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(runtime);
    Tensor i0 = g->addTensor({1, 3, 4, 4}, DataType::UInt32);
    Tensor w0 = g->addTensor({2, 3, 3, 3}, DataType::UInt32);

@@ -9,7 +9,7 @@
namespace infini {

TEST(ConvTransposed, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    { // No pad: InfoGAN ConvTranspose_0
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i0 = g->addTensor({1, 228, 1, 1});

@@ -9,7 +9,7 @@ namespace infini {

using ExpectOutput = vector<float>;
TEST(ElementWise, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i0 = g->addTensor({2, 3, 3, 4}, DataType::UInt32);

@@ -8,7 +8,7 @@
namespace infini {

TEST(Extend, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);

@@ -8,7 +8,7 @@
namespace infini {

TEST(Gather, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();

    Graph g = make_ref<GraphObj>(runtime);
    Tensor i = g->addTensor({1, 3, 4, 4}, DataType::UInt32);

@@ -10,7 +10,7 @@ namespace infini {
using ExpectOutput = vector<float>;

TEST(Matmul, ShapeInference) {
-   auto runtime = CpuRuntimeObj::getInstance();
+   auto runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        auto A = g->addTensor(Shape{1, 3, 5});

@@ -5,7 +5,7 @@

namespace infini {
TEST(Pad, ShapeInference) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(cpuRuntime);
        Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);

@@ -7,7 +7,7 @@ namespace infini {
using KDPS = vector<int>;
using ExpectOutput = vector<float>;
TEST(MaxPool, ShapeInference) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(cpuRuntime);
        Tensor i = g->addTensor({1, 64, 162, 162}, DataType::UInt32);
@@ -27,7 +27,7 @@ TEST(MaxPool, ShapeInference) {
}

TEST(MaxPool, NaiveCPU) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(cpuRuntime);
    Tensor i = g->addTensor({1, 2, 5, 5}, DataType::UInt32);
    auto op = g->addOp<MaxPoolObj>(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2);
@@ -46,7 +46,7 @@ TEST(MaxPool, NaiveCPU) {
}

TEST(AvgPool, NaiveCPU) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    Graph g = make_ref<GraphObj>(cpuRuntime);
    Tensor i = g->addTensor({1, 2, 5, 5}, DataType::Float32);
    auto op = g->addOp<AvgPoolObj>(i, nullptr, 3, 3, 1, 1, 1, 1, 2, 2);

@@ -8,7 +8,7 @@
namespace infini {

TEST(ReduceMean, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);

@@ -8,7 +8,7 @@
namespace infini {

TEST(Reshape, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
@@ -17,7 +17,7 @@ TEST(Reshape, ShapeInference) {
    }
}
TEST(Flatten, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
@@ -27,7 +27,7 @@ TEST(Flatten, ShapeInference) {
}

TEST(Identity, ShapeInference) {
-   Runtime runtime = CpuRuntimeObj::getInstance();
+   Runtime runtime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(runtime);
        Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);

@@ -5,7 +5,7 @@

namespace infini {
TEST(Resize, ShapeInference) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    // downsample_sizes_nearest no axes
    {
        Graph g = make_ref<GraphObj>(cpuRuntime);

@@ -5,7 +5,7 @@

namespace infini {
TEST(Slice, ShapeInference) {
-   Runtime cpuRuntime = CpuRuntimeObj::getInstance();
+   Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
    {
        Graph g = make_ref<GraphObj>(cpuRuntime);
        Tensor i = g->addTensor({10, 64, 162, 162}, DataType::UInt32);

@@ -7,7 +7,7 @@
namespace infini {
TEST(Split, ShapeInfer) {
    {
-       Runtime runtime = CpuRuntimeObj::getInstance();
+       Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph g = make_ref<GraphObj>(runtime);
        auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);

@@ -21,7 +21,7 @@ TEST(Split, ShapeInfer) {
    }

    {
-       Runtime runtime = CpuRuntimeObj::getInstance();
+       Runtime runtime = NativeCpuRuntimeObj::getInstance();
        Graph g = make_ref<GraphObj>(runtime);
        auto input = g->addTensor({1, 3, 2, 15}, DataType::Float32);

@@ -1,3 +1,5 @@
. /home/spack/spack/share/spack/setup-env.sh
-spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
+spack load cuda@11.0.2 cudnn@8.0.3.33-11.0 intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc
+# The default dnnl configuration pulled in by "spack load" is cpu_dpcpp_gpu_dpcpp, which requires libsycl.so; point LD_LIBRARY_PATH at the cpu_gomp build explicitly.
+export LD_LIBRARY_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-dnn-2022.1.0-7rs6ht57zozyxhxx6s2qlrqzmqknhgzx/dnnl/2022.1.0/cpu_gomp/lib/:$LD_LIBRARY_PATH