forked from jiuyuan/InfiniTensor
Cpu backend2 (#77)
fix review change Device::MKL to Device::INTELCPU fix mkl linkage fix errors according to merge from master now can call mkl backend fix softmax/flatten with axis from onnx. modify README.md fix memory refree add env_lotus_intelcpu.sh fix compile merge from branch cpu_backend fix something add gather fix something FIX: directory rename from "mkl" to "intelcpu" ADD: use oneMKL dpcpp interface to implement matmul kernel. ADD: add dpcpp as compiler for mkl, and fix warnings for clang compiling. add dpcpp kernel for pow. ADD: mkl kernel for pad. ADD: slice mkl kernel. ADD: reshape/flatten/identity mkl kernel. ADD: split mkl kernel. fix compile error FIX: fix flattenObj with axis. ADD reduce_mean mkl kernel. Add concat mkl kernel. bathNorm for mkl kernel. sigmoid mkl kernel. ADD:add mkl kernel for pooling add more tests for softmax Now softmax cuda kernel supports any axises. mkl kernel for softmax softmax add axis to softmax operator add mkl kernel for abs tanh ADD: relu kernel for mkl fix binary mkl primitives. add mkl kernel for binary operators fix compiler error move stream to runtime clang format add MemoryFormat for tensorObj. use post_ops for fused conv/deconv Distinguish mkl op_timer from cuda op timer. add act optype to conv and deconv add operator timer add mkl kernel for convTransposed minor fix for group conv do not use cblas_sgemm_batch CpuRuntimeObj->NativeCpuRuntimeObj add matmul op for mkl
This commit is contained in:
parent
fe1afe38fa
commit
c8b2c8ed32
|
@ -5,7 +5,7 @@ project(InfiniTensor C CXX)
|
|||
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
|
||||
option(USE_CUDA "Support CUDA GPU" OFF)
|
||||
option(USE_BANG "Support BANG MLU" OFF)
|
||||
option(USE_MKL "Support MKL" OFF)
|
||||
option(USE_INTELCPU "Support INTELCPU" OFF)
|
||||
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
|
||||
option(USE_PROTOBUF "Serialize and deserialize tensors" ON)
|
||||
option(BUILD_TEST "Build tests" ON)
|
||||
|
@ -19,10 +19,6 @@ set(DEFAULT_BUILD_TYPE "RelWithDebInfo")
|
|||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF) # -std=gnu++11 when on, -std=c++11 when off
|
||||
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||
|
||||
find_package(
|
||||
Python
|
||||
COMPONENTS Interpreter Development
|
||||
|
@ -35,6 +31,20 @@ endif()
|
|||
if(OpenMP_CXX_FOUND)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
|
||||
if(BUILD_TEST)
|
||||
set(BUILD_GMOCK
|
||||
OFF
|
||||
CACHE BOOL "Do not build gmock" FORCE)
|
||||
set(INSTALL_GTEST
|
||||
OFF
|
||||
CACHE BOOL "Do not install gtest" FORCE)
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall ")
|
||||
add_subdirectory(3rd-party/googletest)
|
||||
include_directories(SYSTEM 3rd-party/googletest/googletest/include)
|
||||
endif()
|
||||
|
||||
#Protobuf
|
||||
if(USE_PROTOBUF)
|
||||
add_definitions(-D TENSOR_PROTOBUF)
|
||||
|
@ -47,14 +57,12 @@ if(USE_PROTOBUF)
|
|||
set(PROTO_PATH "${CMAKE_CURRENT_SOURCE_DIR}/proto")
|
||||
file(GLOB PROTO_FILES "${PROTO_PATH}/data.proto")
|
||||
protobuf_generate_cpp(PROTO_SRCS PROTO_HDRS ${PROTO_FILES})
|
||||
message(${PROTO_SRCS} "-----------" ${PROTO_FILES})
|
||||
message(${PROTO_HDRS} "-----------" ${PROTO_FILES})
|
||||
set_source_files_properties (${PROTO_SRCS} PROPERTIES COMPILE_FLAGS -Wno-unused-variable)
|
||||
add_library(tensor_proto SHARED ${PROTO_SRCS} ${PROTO_HDRS})
|
||||
target_link_libraries(tensor_proto PUBLIC ${PROTOBUF_LIBRARIES})
|
||||
endif()
|
||||
|
||||
include_directories(include)
|
||||
|
||||
# Pybind11
|
||||
add_subdirectory(3rd-party/pybind11)
|
||||
include_directories(3rd-party/pybind11/include)
|
||||
|
@ -63,16 +71,9 @@ include_directories(3rd-party/pybind11/include)
|
|||
add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent)
|
||||
include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include)
|
||||
|
||||
if(BUILD_TEST)
|
||||
set(BUILD_GMOCK
|
||||
OFF
|
||||
CACHE BOOL "Do not build gmock" FORCE)
|
||||
set(INSTALL_GTEST
|
||||
OFF
|
||||
CACHE BOOL "Do not install gtest" FORCE)
|
||||
add_subdirectory(3rd-party/googletest)
|
||||
include_directories(3rd-party/googletest/googletest/include)
|
||||
endif()
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g -Wall -Werror -Wno-error=deprecated-declarations")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -UNDEBUG") # Enable assertion
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -UNDEBUG") # Enable assertion
|
||||
|
||||
# Source files
|
||||
file(GLOB_RECURSE SRC src/ffi/*.cc src/core/*.cc src/kernels/cpu/*.cc src/nnet/*.cc src/operators/*.cc src/utils/*.cc)
|
||||
|
@ -87,9 +88,9 @@ if(USE_BANG)
|
|||
list (APPEND SRC ${SRC_BANG})
|
||||
endif()
|
||||
|
||||
if(USE_MKL)
|
||||
file(GLOB_RECURSE SRC_MKL src/mkl/*.cc src/kernels/mkl/*.cc )
|
||||
list (APPEND SRC ${SRC_MKL})
|
||||
if(USE_INTELCPU)
|
||||
file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc )
|
||||
list (APPEND SRC ${SRC_INTELCPU})
|
||||
endif()
|
||||
|
||||
# Libraries
|
||||
|
@ -113,19 +114,28 @@ if(USE_BACKTRACE)
|
|||
target_link_libraries(InfiniTensor dw)
|
||||
endif()
|
||||
|
||||
if(USE_MKL)
|
||||
if(USE_INTELCPU)
|
||||
add_compile_definitions(USE_INTELCPU=1)
|
||||
find_package(MKL CONFIG REQUIRED)
|
||||
target_link_libraries(InfiniTensor $<LINK_ONLY:MKL::MKL>)
|
||||
|
||||
# Refer to https://www.intel.com/content/www/us/en/developer/tools/oneapi/onemkl-link-line-advisor.html
|
||||
target_link_libraries(InfiniTensor sycl OpenCL)
|
||||
|
||||
set(DNNL_CONFIGURATION "cpu_gomp")
|
||||
find_package(dnnl CONFIG REQUIRED)
|
||||
if(dnnl_FOUND)
|
||||
add_compile_definitions(USE_MKL=1)
|
||||
include_directories(BEFORE ${dnnl_DIR}/../../../cpu_gomp/include/)
|
||||
link_directories(${dnnl_DIR}/../../../cpu_gomp/lib)
|
||||
target_link_libraries(InfiniTensor dnnl)
|
||||
else()
|
||||
message(FATAL_ERROR ”dnnl library not found”)
|
||||
message(FATAL_ERROR "dnnl library not found")
|
||||
endif()
|
||||
set(WNO_ERRORS "-Wno-error=unused-parameter -Wno-error=unused-function -Wno-error=unused-private-field -Wno-error=ignored-attributes -Wno-error=unused-const-variable -Wno-error=inconsistent-missing-override -Wno-error=unused-variable -Wno-error=tautological-constant-compare")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMKL_ILP64 -qmkl=parallel -Werror ${WNO_ERRORS}")
|
||||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DMKL_ILP64 -qmkl=parallel ${WNO_ERRORS}") # Enable assertion
|
||||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -DMKL_ILP64 -qmkl=parallel ${WNO_ERRORS}") # Enable assertion
|
||||
|
||||
find_package(IntelDPCPP REQUIRED)
|
||||
endif()
|
||||
|
||||
if(USE_CUDA)
|
||||
|
@ -210,8 +220,8 @@ if(BUILD_TEST)
|
|||
if (USE_BANG)
|
||||
build_test(test/kernels/bang/*.cc)
|
||||
endif()
|
||||
if (USE_MKL)
|
||||
build_test(test/kernels/mkl/*.cc)
|
||||
if (USE_INTELCPU)
|
||||
build_test(test/kernels/intelcpu/*.cc)
|
||||
endif()
|
||||
endif()
|
||||
if(BUILD_TEST_PET)
|
||||
|
|
7
Makefile
7
Makefile
|
@ -2,6 +2,7 @@
|
|||
|
||||
TYPE ?= release
|
||||
CUDA ?= off
|
||||
INTELCPU ?= off
|
||||
|
||||
CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE)
|
||||
|
||||
|
@ -9,9 +10,13 @@ ifeq ($(CUDA), ON)
|
|||
CMAKE_OPT += -DUSE_CUDA=ON
|
||||
endif
|
||||
|
||||
ifeq ($(INTELCPU), ON)
|
||||
CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp
|
||||
endif
|
||||
|
||||
build:
|
||||
mkdir -p build/$(TYPE)
|
||||
cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j8
|
||||
cd build/$(TYPE) && cmake $(CMAKE_OPT) ../.. && make -j22
|
||||
|
||||
clean:
|
||||
rm -rf build
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
# InfiniTensor
|
||||
|
||||
## Compilation on Lotus
|
||||
|
||||
# Compilation for cuda
|
||||
``` bash
|
||||
# Enter the root of InfiniTensor
|
||||
source test/script/env_lotus.sh
|
||||
make CUDA=ON
|
||||
```
|
||||
## Compilation for intelcpu
|
||||
``` bash
|
||||
# Enter the root of InfiniTensor
|
||||
source test/script/env_lotus.sh intelcpu
|
||||
mkdir build && cd build
|
||||
cmake -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp .. && make -j 12
|
||||
```
|
||||
|
||||
### Make Commands
|
||||
|
||||
|
|
|
@ -66,10 +66,10 @@ class GraphHandlerObj {
|
|||
Tensor relu(Tensor x, Tensor y);
|
||||
Tensor sigmoid(Tensor x, Tensor y);
|
||||
Tensor tanh(Tensor x, Tensor y);
|
||||
Tensor softmax(Tensor x, Tensor y);
|
||||
Tensor softmax(Tensor x, Tensor y, int axis);
|
||||
Tensor abs(Tensor x, Tensor y);
|
||||
Tensor identity(Tensor x, Tensor y);
|
||||
Tensor flatten(Tensor s, Tensor y);
|
||||
Tensor flatten(Tensor s, Tensor y, int axis);
|
||||
Tensor reshape(Tensor data, Tensor reshaped, Shape shape);
|
||||
Tensor concat(TensorVec inputs, Tensor output, int dim);
|
||||
Tensor gather(Tensor data, Tensor indices, Tensor output, int axis);
|
||||
|
|
|
@ -28,7 +28,7 @@ using OpVec = vector<Operator>;
|
|||
|
||||
using VType = uint32_t;
|
||||
|
||||
enum class Device { CPU = 1, CUDA, BANG, MKL };
|
||||
enum class Device { CPU = 1, CUDA, BANG, INTELCPU };
|
||||
/***************** Forward declaration end *****************/
|
||||
|
||||
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
||||
|
@ -53,7 +53,6 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
|||
bool profiling = false) const = 0;
|
||||
virtual void *alloc(size_t size) = 0;
|
||||
virtual void dealloc(void *ptr) = 0;
|
||||
void prepareAndRun(Graph &graph, bool tune = false, bool profiling = false);
|
||||
/**
|
||||
* @brief Get the execution time of each operator in performance record. No
|
||||
* execution happens.
|
||||
|
@ -65,7 +64,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
|||
double getPerfTime(const Graph &graph, bool profiling = false) const;
|
||||
Blob allocBlob(size_t size);
|
||||
bool isCpu() const {
|
||||
return device == Device::CPU || device == Device::MKL;
|
||||
return device == Device::CPU || device == Device::INTELCPU;
|
||||
}
|
||||
bool isCuda() const { return device == Device::CUDA; }
|
||||
bool isBang() const { return device == Device::BANG; }
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
namespace infini {
|
||||
void softmax_kernel(int max_threadblock_size, int batch_size, float *x,
|
||||
float *y, int dim, int stride);
|
||||
}
|
|
@ -0,0 +1,40 @@
|
|||
#pragma once
|
||||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class MklKernelWithoutConfig : public Kernel {
|
||||
public:
|
||||
virtual void compute(const Operator &op, const PerfRecord &record,
|
||||
const RuntimeObj *_context) const override {
|
||||
compute(op, _context);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
context->sync();
|
||||
}
|
||||
virtual void compute(const Operator &op,
|
||||
const RuntimeObj *context) const = 0;
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
virtual PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
|
||||
[&]() { context->sync(); }));
|
||||
}
|
||||
|
||||
protected:
|
||||
dnnl::memory::format_tag getUserFormatTag(int nDim) const {
|
||||
if (nDim == 2)
|
||||
return dnnl::memory::format_tag::nc;
|
||||
else if (nDim == 3)
|
||||
return dnnl::memory::format_tag::ncw;
|
||||
else if (nDim == 4)
|
||||
return dnnl::memory::format_tag::nchw;
|
||||
else if (nDim == 5)
|
||||
return dnnl::memory::format_tag::ncdhw;
|
||||
else
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -7,9 +7,9 @@
|
|||
#include <dnnl_debug.h>
|
||||
#include <mkl.h>
|
||||
namespace infini {
|
||||
// TODO move utility function to alone file
|
||||
class MklRuntimeObj : public CpuRuntimeObj {
|
||||
dnnl_engine_t engine;
|
||||
dnnl_stream_t stream;
|
||||
|
||||
public:
|
||||
MklRuntimeObj();
|
||||
|
@ -26,8 +26,10 @@ class MklRuntimeObj : public CpuRuntimeObj {
|
|||
sizeof(uint64_t), 64);
|
||||
};
|
||||
|
||||
string toString() const override { return "CPU MKL Runtime"; };
|
||||
string toString() const override { return "INTELCPU Runtime"; };
|
||||
dnnl::engine getEngine() const { return dnnl::engine(engine, true); }
|
||||
dnnl::stream getStream() const { return dnnl::stream(stream, true); }
|
||||
void sync() const;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -22,6 +22,7 @@ class RoutineNode {
|
|||
|
||||
public:
|
||||
RoutineNode(Expr _expr, const vector<Tensor> &_inputs);
|
||||
virtual ~RoutineNode() {}
|
||||
virtual string toReadable() const = 0;
|
||||
const Expr &getExpr() const { return expr; }
|
||||
const vector<Tensor> &getInputs() const { return inputs; }
|
||||
|
|
|
@ -42,6 +42,7 @@ class ReshapeObj : public OperatorObj {
|
|||
*
|
||||
*/
|
||||
class FlattenObj : public OperatorObj {
|
||||
int axis;
|
||||
|
||||
public:
|
||||
/**
|
||||
|
@ -51,7 +52,7 @@ class FlattenObj : public OperatorObj {
|
|||
* @param input The input tensor.
|
||||
* @param output The output one-dimensional tensor.
|
||||
*/
|
||||
FlattenObj(GraphObj *graph, Tensor input, Tensor output);
|
||||
FlattenObj(GraphObj *graph, Tensor input, Tensor output, int axis);
|
||||
OP_CLONE(FlattenObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override;
|
||||
|
|
|
@ -75,6 +75,9 @@ class ResizeObj : public OperatorObj {
|
|||
IT_ASSERT((size_t)i < scales.size());
|
||||
return scales.at(i);
|
||||
}
|
||||
|
||||
vector<float> getScales() const { return scales; }
|
||||
|
||||
float getRoi(int i) const {
|
||||
if (coMode == ECoordinateTransMode::tfCropAndResize) {
|
||||
IT_ASSERT(size_t(i) < roi.size());
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
#pragma once
|
||||
#include "core/operator.h"
|
||||
|
||||
namespace infini {
|
||||
class SoftmaxObj : public OperatorObj {
|
||||
int axis;
|
||||
|
||||
public:
|
||||
SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int axis);
|
||||
|
||||
OP_CLONE(SoftmaxObj);
|
||||
|
||||
optional<vector<Shape>> inferShape(const TensorVec &inputs) const override {
|
||||
return {{inputs[0]->getDims()}};
|
||||
};
|
||||
|
||||
std::string toString() const override;
|
||||
int numInputs() const override { return 1; }
|
||||
int numOutputs() const override { return 1; }
|
||||
|
||||
int getAxis() const { return axis; }
|
||||
|
||||
private:
|
||||
vector<int> getWorkloadVector() const override;
|
||||
vector<int> getOpAttrVector() const override;
|
||||
};
|
||||
} // namespace infini
|
|
@ -39,6 +39,7 @@ class UnaryObj : public OperatorObj {
|
|||
DEFINE_UNARY_OBJ(Relu, OpType::Relu)
|
||||
DEFINE_UNARY_OBJ(Sigmoid, OpType::Sigmoid)
|
||||
DEFINE_UNARY_OBJ(Tanh, OpType::Tanh)
|
||||
DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
|
||||
// DEFINE_UNARY_OBJ(Softmax, OpType::Softmax)
|
||||
DEFINE_UNARY_OBJ(Abs, OpType::Abs)
|
||||
|
||||
}; // namespace infini
|
||||
|
|
|
@ -25,12 +25,7 @@ from onnx.shape_inference import infer_shapes
|
|||
from typing import Dict, List, Any, Tuple, Sequence, Union, Optional
|
||||
from functools import reduce
|
||||
|
||||
cpu_runtime = backend.cpu_runtime()
|
||||
|
||||
|
||||
def cuda_runtime():
|
||||
return backend.cuda_runtime()
|
||||
|
||||
runtime = backend.runtime()
|
||||
|
||||
class OnnxStub:
|
||||
inputs: Dict[str, backend.Tensor] = {}
|
||||
|
@ -253,6 +248,7 @@ class OnnxStub:
|
|||
tensors[node.output[0]] = self.handler.softmax(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "Abs":
|
||||
tensors[node.output[0]] = self.handler.abs(
|
||||
|
@ -265,14 +261,11 @@ class OnnxStub:
|
|||
tensors.get(node.output[0]),
|
||||
)
|
||||
elif node.op_type == "Flatten":
|
||||
# FIXME axis must be 1
|
||||
axis = next(
|
||||
(attr.i for attr in node.attribute if attr.name == "axis"), None
|
||||
)
|
||||
assert axis == None or axis == 1
|
||||
|
||||
tensors[node.output[0]] = self.handler.flatten(
|
||||
tensors[node.input[0]],
|
||||
tensors.get(node.output[0]),
|
||||
next((attr.i for attr in node.attribute if attr.name == "axis")),
|
||||
)
|
||||
elif node.op_type == "Reshape":
|
||||
input_shape = next(
|
||||
|
@ -583,6 +576,9 @@ def from_onnx(model: ModelProto, runtime):
|
|||
stub = OnnxStub(model, runtime)
|
||||
return stub.inputs, stub.outputs, stub.handler
|
||||
|
||||
def run_onnx(model: ModelProto, runtime):
|
||||
stub = OnnxStub(model, runtime)
|
||||
stub.run()
|
||||
|
||||
def _parse_attribute(node: NodeProto, attrs: Dict[str, Any] = dict()) -> Dict[str, Any]:
|
||||
for attr in node.attribute:
|
||||
|
|
|
@ -8,16 +8,28 @@ from onnx.helper import (
|
|||
make_tensor_value_info,
|
||||
)
|
||||
from onnx.checker import check_model
|
||||
from pyinfinitensor.onnx import from_onnx, backend, cpu_runtime
|
||||
from pyinfinitensor.onnx import from_onnx, backend, runtime, run_onnx
|
||||
|
||||
|
||||
def make_and_import_model(graph: onnx.GraphProto):
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model, cpu_runtime)
|
||||
from_onnx(model, runtime)
|
||||
|
||||
|
||||
class TestStringMethods(unittest.TestCase):
|
||||
#def test_run(self):
|
||||
# model_file = next(
|
||||
# (name for name in os.listdir() if name.endswith(".onnx")), None
|
||||
# )
|
||||
# if model_file != None:
|
||||
# print(
|
||||
# "model: {file}({size:.2f} MiB)".format(
|
||||
# file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
||||
# )
|
||||
# )
|
||||
# run_onnx(onnx.load(model_file), runtime)
|
||||
|
||||
def test_load(self):
|
||||
model_file = next(
|
||||
(name for name in os.listdir() if name.endswith(".onnx")), None
|
||||
|
@ -28,7 +40,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
file=model_file, size=os.path.getsize(model_file) / 1024 / 1024
|
||||
)
|
||||
)
|
||||
from_onnx(onnx.load(model_file), cpu_runtime)
|
||||
from_onnx(onnx.load(model_file), runtime)
|
||||
|
||||
def test_tensor(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 2, 3])
|
||||
|
@ -177,7 +189,7 @@ class TestStringMethods(unittest.TestCase):
|
|||
def test_softmax(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
softmax = make_node("Softmax", ["x"], ["y"], name="softmax")
|
||||
softmax = make_node("Softmax", ["x"], ["y"], axis=2, name="softmax")
|
||||
make_and_import_model(make_graph([softmax], "softmax", [x], [y]))
|
||||
|
||||
def test_abs(self):
|
||||
|
@ -194,9 +206,8 @@ class TestStringMethods(unittest.TestCase):
|
|||
|
||||
def test_flatten(self):
|
||||
x = make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 5, 7])
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1, 1 * 3 * 5 * 7])
|
||||
flatten = make_node("Flatten", ["x"], ["y"], name="flatten")
|
||||
# FIXME 后端要求产生 Π(dims) 长的一维张量,onnx 产生 1×Π(dims) 的二维张量
|
||||
y = make_tensor_value_info("y", TensorProto.FLOAT, [1*3, 5 * 7])
|
||||
flatten = make_node("Flatten", ["x"], ["y"], axis=2, name="flatten")
|
||||
# make_and_import_model(
|
||||
make_graph([flatten], "flatten", [x], [y])
|
||||
# )
|
||||
|
@ -289,10 +300,10 @@ class TestStringMethods(unittest.TestCase):
|
|||
graph = make_graph([matmul, add], "lr", [x, a, b], [y])
|
||||
model = make_model(graph)
|
||||
check_model(model)
|
||||
from_onnx(model, cpu_runtime)
|
||||
from_onnx(model, runtime)
|
||||
|
||||
def test_frontend(self):
|
||||
handler = backend.GraphHandler(cpu_runtime)
|
||||
handler = backend.GraphHandler(runtime)
|
||||
a = handler.tensor([1, 2, 3], 12)
|
||||
b = handler.tensor([1, 2, 3], 12)
|
||||
c = handler.tensor([1, 2, 3], 12)
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "operators/reduce_mean.h"
|
||||
#include "operators/reshape.h"
|
||||
#include "operators/slice.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -126,11 +127,29 @@ DEFINE_ELEMENT_WISE_METHOD(pow, Pow)
|
|||
DEFINE_UNARY_METHOD(relu, Relu)
|
||||
DEFINE_UNARY_METHOD(sigmoid, Sigmoid)
|
||||
DEFINE_UNARY_METHOD(tanh, Tanh)
|
||||
DEFINE_UNARY_METHOD(softmax, Softmax)
|
||||
DEFINE_UNARY_METHOD(abs, Abs)
|
||||
// see operators/reshape.h
|
||||
DEFINE_UNARY_METHOD(identity, Identity)
|
||||
DEFINE_UNARY_METHOD(flatten, Flatten)
|
||||
|
||||
Tensor GraphHandlerObj::softmax(Tensor input, Tensor output, int axis) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<SoftmaxObj>(std::move(input), output, axis);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<SoftmaxObj>(std::move(input), output, axis)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::flatten(Tensor input, Tensor output, int axis) {
|
||||
if (output) {
|
||||
g->addOpWithOutputs<FlattenObj>(std::move(input), output, axis);
|
||||
return output;
|
||||
} else {
|
||||
return g->addOp<FlattenObj>(std::move(input), output, axis)
|
||||
->getOutput();
|
||||
}
|
||||
}
|
||||
|
||||
Tensor GraphHandlerObj::reshape(Tensor data, Tensor reshaped, Shape shape) {
|
||||
if (reshaped) {
|
||||
|
|
|
@ -6,10 +6,6 @@
|
|||
#include <chrono>
|
||||
#include <cstring>
|
||||
namespace infini {
|
||||
void RuntimeObj::prepareAndRun(Graph &graph, bool tune, bool profiling) {
|
||||
run(graph, tune, profiling);
|
||||
}
|
||||
|
||||
void CpuRuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
||||
if (!tune && profiling)
|
||||
IT_TODO_HALT();
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
namespace infini {
|
||||
|
||||
TensorObj::TensorObj(Shape shape_, DataType dtype, Runtime runtime)
|
||||
: TensorBaseObj(shape.size(), dtype, runtime), shape(std::move(shape_)),
|
||||
: TensorBaseObj(shape_.size(), dtype, runtime), shape(std::move(shape_)),
|
||||
_size(shape.empty()
|
||||
? 0
|
||||
: std::accumulate(shape.begin(), shape.end(), 1,
|
||||
|
|
|
@ -5,7 +5,7 @@ __global__ void cudaPrintFloatImpl(float *x, int len) {
|
|||
int start = threadIdx.x + blockDim.x * blockIdx.x;
|
||||
if (start == 0) {
|
||||
for (int i = 0; i < len; ++i) {
|
||||
printf("%.3f ", x[i]);
|
||||
printf("%.7f ", x[i]);
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
|
|
@ -12,8 +12,9 @@
|
|||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/operator_timer.h"
|
||||
#endif
|
||||
#ifdef USE_MKL
|
||||
#include "mkl/operator_timer.h"
|
||||
#ifdef USE_INTELCPU
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "intelcpu/operator_timer.h"
|
||||
#endif
|
||||
namespace py = pybind11;
|
||||
|
||||
|
@ -30,7 +31,7 @@ void register_operator_timer(py::module &m) {
|
|||
m.def("getPerfMatmulCublas", &getPerfMatmulCublas);
|
||||
#endif
|
||||
|
||||
#ifdef USE_MKL
|
||||
#ifdef USE_INTELCPU
|
||||
using namespace opTimer;
|
||||
m.def("getPerfConvMkl", &getPerfConvMkl);
|
||||
m.def("getPerfConvTransposed2dMkl", &getPerfConvTransposed2dMkl);
|
||||
|
@ -111,6 +112,10 @@ static int tensor_dtype(Tensor t) {
|
|||
static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
|
||||
#endif
|
||||
|
||||
#ifdef USE_INTELCPU
|
||||
static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
|
||||
#endif
|
||||
|
||||
static std::tuple<int, int, int, int, int, int> conv_attrs_of(Operator op) {
|
||||
IT_ASSERT(op->getOpType() == OpType::Conv);
|
||||
auto conv = dynamic_cast<const ConvObj *>(op.get());
|
||||
|
@ -158,10 +163,14 @@ static Shape reshape_shape_of(Operator op) {
|
|||
|
||||
void export_functions(py::module &m) {
|
||||
#define FUNCTION(NAME) def(#NAME, &NAME)
|
||||
m.def("cpu_runtime", &NativeCpuRuntimeObj::getInstance)
|
||||
#ifdef USE_CUDA
|
||||
.FUNCTION(cuda_runtime)
|
||||
m.def("runtime", cuda_runtime)
|
||||
#elif USE_INTELCPU
|
||||
m.def("runtime", intelcpu_runtime)
|
||||
#else
|
||||
m.def("runtime", &NativeCpuRuntimeObj::getInstance)
|
||||
#endif
|
||||
|
||||
.FUNCTION(conv_attrs_of)
|
||||
.FUNCTION(batch_norm_attrs_of)
|
||||
.FUNCTION(pool_attrs_of)
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
namespace infini {
|
||||
MklRuntimeObj::MklRuntimeObj() : CpuRuntimeObj(Device::INTELCPU) {
|
||||
dnnl_engine_create(&engine, dnnl_engine_kind_t::dnnl_cpu, 0);
|
||||
dnnl_stream_create(
|
||||
&stream, engine,
|
||||
static_cast<dnnl_stream_flags_t>(dnnl_stream_default_flags));
|
||||
}
|
||||
|
||||
MklRuntimeObj::~MklRuntimeObj() {
|
||||
mkl_free_buffers();
|
||||
dnnl_stream_destroy(stream);
|
||||
dnnl_engine_destroy(engine);
|
||||
}
|
||||
|
||||
void MklRuntimeObj::sync() const { getStream().wait(); }
|
||||
} // namespace infini
|
|
@ -1,7 +1,7 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "mkl/mkl_runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "utils/data_generator.h"
|
|
@ -10,8 +10,13 @@ template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
|
|||
T *iptr = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
T *wptr = op->getInputs(1)->getRawDataPtr<T *>();
|
||||
T *optr = op->getOutput()->getRawDataPtr<T *>();
|
||||
auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
// Clang will give an error of " reference to local binding 'sh'
|
||||
// declared in enclosing function" if we write like this:
|
||||
// auto [n, c, h, w, f, r, s] = op->getNCHWFRS();
|
||||
int n, c, h, w, f, r, s;
|
||||
std::tie(n, c, h, w, f, r, s) = op->getNCHWFRS();
|
||||
int ph, pw, sh, sw, dh, dw;
|
||||
std::tie(ph, pw, sh, sw, dh, dw) = op->getPadStrideDilation();
|
||||
int cpg = op->getChannelPerGroup();
|
||||
int g = op->getNumGroups();
|
||||
IT_ASSERT(f % g == 0, "Illegal number of channel");
|
||||
|
@ -23,7 +28,7 @@ template <typename T> class NaiveConv : public CpuKernelWithoutConfig {
|
|||
for (int hh = 0; hh < oh; hh++)
|
||||
for (int ww = 0; ww < ow; ww++) {
|
||||
int gidx = ff / (f / g);
|
||||
VType val = 0;
|
||||
T val = 0;
|
||||
for (int cc = 0; cc < cpg; cc++)
|
||||
for (int rr = 0; rr < r; rr++)
|
||||
for (int ss = 0; ss < s; ss++) {
|
||||
|
|
|
@ -30,8 +30,8 @@ class MemboundInterpreter : public Kernel {
|
|||
// }
|
||||
|
||||
nnet::RangeOp range = nnet::as<nnet::RangeOpNode>(op->getNnetExpr());
|
||||
const auto &rangeShape = range->getOutputShape();
|
||||
const auto &outputShape = output->getDims();
|
||||
// const auto &rangeShape = range->getOutputShape();
|
||||
// const auto &outputShape = output->getDims();
|
||||
// rangeShape and outputShape may extra dims of length 1.
|
||||
// But their sizes should be the same.
|
||||
IT_ASSERT((ssize_t)range->getOutputSize() == (ssize_t)output->size());
|
||||
|
|
|
@ -213,7 +213,7 @@ void resize_kernel_nearest(float *in, float *out, const MetaData &metaData,
|
|||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
IT_ASSERT(nearestMode <
|
||||
sizeof(p_nearest_mode_fun) / sizeof(p_nearest_mode_fun[0]));
|
||||
_resize_kernel_nearest<<<blocksize, gridsize>>>(
|
||||
_resize_kernel_nearest<<<gridsize, blocksize>>>(
|
||||
in, out, metaData, num, coordinateMode, nearestMode);
|
||||
}
|
||||
|
||||
|
@ -223,7 +223,7 @@ void resize_kernel_linear(float *in, float *out, const MetaData &metaData,
|
|||
auto gridsize = (num + blocksize - 1) / blocksize;
|
||||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
_resize_kernel_linear_coeff<<<blocksize, gridsize>>>(in, out, metaData, num,
|
||||
_resize_kernel_linear_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
|
||||
coordinateMode);
|
||||
}
|
||||
|
||||
|
@ -233,7 +233,7 @@ void resize_kernel_cubic(float *in, float *out, const MetaData &metaData,
|
|||
auto gridsize = (num + blocksize - 1) / blocksize;
|
||||
IT_ASSERT(coordinateMode < sizeof(p_cooridnate_trans_mode_func) /
|
||||
sizeof(p_cooridnate_trans_mode_func[0]));
|
||||
_resize_kernel_cubic_coeff<<<blocksize, gridsize>>>(in, out, metaData, num,
|
||||
_resize_kernel_cubic_coeff<<<gridsize, blocksize>>>(in, out, metaData, num,
|
||||
coordinateMode);
|
||||
}
|
||||
} // namespace infini
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
#include "operators/softmax.h"
|
||||
#include "cuda/cuda_kernel_wihtout_config.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/softmax.h"
|
||||
|
||||
namespace infini {
|
||||
class SoftmaxCudnn : public CudaKernelWithoutConfig {
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SoftmaxObj>(_op);
|
||||
auto x = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
auto y = op->getOutput(0)->getRawDataPtr<float *>();
|
||||
auto dims = op->getInputs(0)->getDims();
|
||||
|
||||
int batch_size = 1;
|
||||
for (size_t i = 0; i < dims.size(); ++i)
|
||||
batch_size *= dims[i];
|
||||
int dim = dims[op->getAxis()];
|
||||
|
||||
int block_num = batch_size / dim;
|
||||
int max_threadblock_size = batch_size / block_num;
|
||||
softmax_kernel(max_threadblock_size, block_num, x, y, dim,
|
||||
op->getInputs(0)->getStride().at(op->getAxis()));
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCudnn,
|
||||
"Softmax_CUDA_Float32");
|
||||
} // namespace infini
|
|
@ -0,0 +1,77 @@
|
|||
#include "cuda/cuda_common.h"
|
||||
#include "cuda/softmax.h"
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
struct __align__(8) MD {
|
||||
float data;
|
||||
float d;
|
||||
};
|
||||
|
||||
__device__ __forceinline__ MD reduce_md_op(MD a, MD b) {
|
||||
bool a_bigger = (a.data > b.data);
|
||||
MD bigger_m = a_bigger ? a : b;
|
||||
MD smaller_m = a_bigger ? b : a;
|
||||
MD res;
|
||||
res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.data - bigger_m.data);
|
||||
res.data = bigger_m.data;
|
||||
return res;
|
||||
}
|
||||
|
||||
template <int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
void online_softmax(const float *__restrict in, float *__restrict out,
|
||||
int dimSize, int stride) {
|
||||
|
||||
// reposition in and out to data for the current vector
|
||||
int blockOffset = blockIdx.x;
|
||||
if (blockIdx.x >= stride) {
|
||||
int tmp = blockIdx.x % stride;
|
||||
blockOffset = tmp + (blockIdx.x - tmp) * dimSize;
|
||||
}
|
||||
in += blockOffset;
|
||||
out += blockOffset;
|
||||
|
||||
MD md_partial;
|
||||
md_partial.data = -FLT_MAX;
|
||||
md_partial.d = 0.0F;
|
||||
|
||||
for (int elem_id = threadIdx.x; elem_id < dimSize;
|
||||
elem_id += THREADBLOCK_SIZE) {
|
||||
MD new_elem;
|
||||
new_elem.data = in[elem_id * stride];
|
||||
new_elem.d = 1.0F;
|
||||
md_partial = reduce_md_op(md_partial, new_elem);
|
||||
}
|
||||
|
||||
// blockreduce for THREADBLOCK_SIZE threads.
|
||||
// The actrual threads num used in the block is "dimsSize"
|
||||
typedef cub::BlockReduce<MD, THREADBLOCK_SIZE> BlockReduce;
|
||||
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
__shared__ MD md_total;
|
||||
|
||||
MD md = BlockReduce(temp_storage).Reduce(md_partial, reduce_md_op);
|
||||
if (threadIdx.x == 0)
|
||||
md_total = md;
|
||||
__syncthreads();
|
||||
|
||||
float d_total_inverse = __fdividef(1.0F, md_total.d);
|
||||
for (int elem_id = threadIdx.x; elem_id < dimSize;
|
||||
elem_id += THREADBLOCK_SIZE)
|
||||
out[elem_id * stride] =
|
||||
__expf(in[elem_id * stride] - md_total.data) * d_total_inverse;
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
void softmax_kernel(int max_threadblock_size, int blockNum, float *in,
|
||||
float *out, int dimSize, int stride) {
|
||||
if (max_threadblock_size >= 255)
|
||||
online_softmax<256><<<blockNum, 256>>>(in, out, dimSize, stride);
|
||||
else if (max_threadblock_size >= 128)
|
||||
online_softmax<128><<<blockNum, 128>>>(in, out, dimSize, stride);
|
||||
else if (max_threadblock_size >= 64)
|
||||
online_softmax<64><<<blockNum, 64>>>(in, out, dimSize, stride);
|
||||
else
|
||||
online_softmax<32><<<blockNum, 32>>>(in, out, dimSize, stride);
|
||||
}
|
||||
} // namespace infini
|
|
@ -60,48 +60,6 @@ class ActivationCudnn : public CudaKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
class SoftmaxCudnn : public CudaKernelWithoutConfig {
|
||||
virtual cudnnSoftmaxAlgorithm_t getAlgorithmType() const = 0;
|
||||
virtual cudnnSoftmaxMode_t getModeType() const = 0;
|
||||
virtual tuple<float, float> getAlphBeta() const { return {1.f, 0.f}; }
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
|
||||
|
||||
void *const inputData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const outputData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
cudnnTensorDescriptor_t inputDesc, outputDesc;
|
||||
auto dim = op->getInputs(0)->getDims();
|
||||
if (dim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
int n = dim[0], c = dim[1], h = dim[2], w = dim[3];
|
||||
|
||||
// get inputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&inputDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
inputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
|
||||
|
||||
// get outputs
|
||||
checkCudnnError(cudnnCreateTensorDescriptor(&outputDesc));
|
||||
checkCudnnError(cudnnSetTensor4dDescriptor(
|
||||
outputDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, n, c, h, w));
|
||||
|
||||
auto [alpha, beta] = getAlphBeta();
|
||||
cudnnStatus_t stat = cudnnSoftmaxForward(
|
||||
context->cudnnHandle(), getAlgorithmType(), getModeType(), &alpha,
|
||||
inputDesc, inputData, &beta, outputDesc, outputData);
|
||||
if (stat != CUDNN_STATUS_SUCCESS)
|
||||
return;
|
||||
|
||||
// Destories in CUDA does not require sync. But cuDNN does not state
|
||||
// whether sync is required before destories.
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(inputDesc));
|
||||
checkCudnnError(cudnnDestroyTensorDescriptor(outputDesc));
|
||||
}
|
||||
};
|
||||
|
||||
class ReluCudnn : public ActivationCudnn {
|
||||
cudnnActivationMode_t getOpType() const override {
|
||||
return CUDNN_ACTIVATION_RELU;
|
||||
|
@ -120,17 +78,6 @@ class TanhCudnn : public ActivationCudnn {
|
|||
}
|
||||
};
|
||||
|
||||
class NormalSoftmaxCudnn : public SoftmaxCudnn {
|
||||
cudnnSoftmaxAlgorithm_t getAlgorithmType() const override {
|
||||
return CUDNN_SOFTMAX_ACCURATE;
|
||||
}
|
||||
cudnnSoftmaxMode_t getModeType() const override {
|
||||
return CUDNN_SOFTMAX_MODE_INSTANCE;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32,
|
||||
NormalSoftmaxCudnn, "Softmax_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Relu, DataType::Float32, ReluCudnn,
|
||||
"Relu_CUDA_Float32");
|
||||
REGISTER_KERNEL(Device::CUDA, OpType::Sigmoid, DataType::Float32, SigmoidCudnn,
|
||||
|
|
|
@ -0,0 +1,68 @@
|
|||
#include "operators/batch_norm.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklBatchNorm : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<BatchNormObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
float *const srcData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
float *const dstData = op->getOutput()->getRawDataPtr<float *>();
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i)
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
|
||||
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData);
|
||||
|
||||
auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(), dstData);
|
||||
|
||||
std::vector<dnnl_dim_t> meanDims(op->getInputs(0)->getDims().size(), 1);
|
||||
meanDims[1] = op->getInputs(0)->getDims()[1];
|
||||
auto meanMd = dnnl::memory::desc(meanDims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(meanDims.size()));
|
||||
|
||||
auto meanMemory =
|
||||
dnnl::memory(meanMd, context->getEngine(),
|
||||
op->getInputs(1)->getRawDataPtr<float *>());
|
||||
auto varMemory =
|
||||
dnnl::memory(meanMd, context->getEngine(),
|
||||
op->getInputs(2)->getRawDataPtr<float *>());
|
||||
auto scaleMemory =
|
||||
dnnl::memory(meanMd, context->getEngine(),
|
||||
op->getInputs(3)->getRawDataPtr<float *>());
|
||||
auto baisMemory =
|
||||
dnnl::memory(meanMd, context->getEngine(),
|
||||
op->getInputs(4)->getRawDataPtr<float *>());
|
||||
using op_desc_t = dnnl::batch_normalization_forward::desc;
|
||||
using pd_t = dnnl::batch_normalization_forward::primitive_desc;
|
||||
|
||||
// use_global_stats stands for use mean and var as inputs
|
||||
auto opDesc =
|
||||
op_desc_t(dnnl::prop_kind::forward_inference, srcMd, op->getEps(),
|
||||
dnnl::normalization_flags::use_global_stats |
|
||||
dnnl::normalization_flags::use_shift |
|
||||
dnnl::normalization_flags::use_scale);
|
||||
auto primDesc = pd_t(opDesc, context->getEngine());
|
||||
|
||||
// create and execute primitive
|
||||
dnnl::batch_normalization_forward(primDesc).execute(
|
||||
context->getStream(), {{DNNL_ARG_SRC, srcMemory},
|
||||
{DNNL_ARG_DST, output},
|
||||
{DNNL_ARG_MEAN, meanMemory},
|
||||
{DNNL_ARG_VARIANCE, varMemory},
|
||||
{DNNL_ARG_SCALE, scaleMemory},
|
||||
{DNNL_ARG_SHIFT, baisMemory}});
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::BatchNorm, DataType::Float32,
|
||||
MklBatchNorm, "BatchNorm_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,58 @@
|
|||
#include "operators/concat.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklConcat : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConcatObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
std::vector<dnnl::memory::desc> srcsMd;
|
||||
std::vector<dnnl::memory> srcs;
|
||||
|
||||
for (size_t i = 0; i < op->getInputs().size(); i++) {
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
auto inDims = op->getInputs(i)->getDims();
|
||||
int ndim = inDims.size();
|
||||
for (int j = 0; j < ndim; ++j)
|
||||
dims.push_back(inDims.at(j));
|
||||
|
||||
auto md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
srcsMd.push_back(md);
|
||||
|
||||
auto srcMemory =
|
||||
dnnl::memory(md, context->getEngine(),
|
||||
op->getInputs(i)->getRawDataPtr<float *>());
|
||||
srcs.push_back(srcMemory);
|
||||
}
|
||||
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
auto oDims = op->getOutput(0)->getDims();
|
||||
int ndim = oDims.size();
|
||||
for (int i = 0; i < ndim; ++i)
|
||||
dims.push_back(oDims.at(i));
|
||||
|
||||
auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto primDesc =
|
||||
dnnl::concat::primitive_desc(dstMd, static_cast<int>(op->getDim()),
|
||||
srcsMd, context->getEngine());
|
||||
|
||||
float *const dstData = op->getOutput()->getRawDataPtr<float *>();
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(), dstData);
|
||||
|
||||
// create and execute primitive
|
||||
std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_DST, output}};
|
||||
for (int i = 0; i < (int)srcs.size(); i++) {
|
||||
args.insert({DNNL_ARG_MULTIPLE_SRC + i, srcs.at(i)});
|
||||
}
|
||||
dnnl::concat(primDesc).execute(context->getStream(), args);
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Concat, DataType::Float32, MklConcat,
|
||||
"Concat_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -1,6 +1,6 @@
|
|||
#include "operators/conv.h"
|
||||
#include "core/kernel.h"
|
||||
#include "mkl/mkl_runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
struct ConvMklPerfRecordObj : public PerfRecordObj {
|
||||
|
@ -167,20 +167,19 @@ class MklConv : public Kernel {
|
|||
}
|
||||
|
||||
void compute(const Operator &_op, const PerfRecord &_record,
|
||||
const RuntimeObj *_context) const {
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ConvObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
auto record = as<ConvMklPerfRecordObj>(_record);
|
||||
|
||||
dnnl::stream stream(context->getEngine());
|
||||
std::vector<dnnl::primitive> prims;
|
||||
std::vector<std::unordered_map<int, dnnl::memory>> primArgs;
|
||||
IT_ASSERT(createPrimitives(op, record, context, true, prims, primArgs));
|
||||
|
||||
IT_ASSERT(prims.size() == primArgs.size());
|
||||
for (size_t i = 0; i < prims.size(); ++i)
|
||||
prims.at(i).execute(stream, primArgs.at(i));
|
||||
stream.wait();
|
||||
prims.at(i).execute(context->getStream(), primArgs.at(i));
|
||||
context->getStream().wait();
|
||||
}
|
||||
|
||||
void compute(const Operator &op, const RuntimeObj *context) const override {
|
||||
|
@ -209,17 +208,19 @@ class MklConv : public Kernel {
|
|||
continue;
|
||||
|
||||
IT_ASSERT(prims.size() == primArgs.size());
|
||||
dnnl::stream stream(context->getEngine());
|
||||
// does context->getStream() need to be attached to runtime, and
|
||||
// delete after each use?
|
||||
for (size_t i = 0; i < prims.size(); ++i)
|
||||
prims.at(i).execute(stream, primArgs.at(i));
|
||||
stream.wait();
|
||||
prims.at(i).execute(context->getStream(), primArgs.at(i));
|
||||
context->getStream().wait();
|
||||
|
||||
record.time = timeit(
|
||||
[&]() {
|
||||
for (size_t i = 0; i < prims.size(); ++i)
|
||||
prims.at(i).execute(stream, primArgs.at(i));
|
||||
prims.at(i).execute(context->getStream(),
|
||||
primArgs.at(i));
|
||||
},
|
||||
[&]() { stream.wait(); });
|
||||
[&]() { context->getStream().wait(); });
|
||||
|
||||
// Update the tune result
|
||||
if (ret.time > record.time)
|
||||
|
@ -232,6 +233,6 @@ class MklConv : public Kernel {
|
|||
return make_ref<ConvMklPerfRecordObj>(ret);
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::MKL, OpType::Conv, DataType::Float32, MklConv,
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Conv, DataType::Float32, MklConv,
|
||||
"MklConv_CPU_float32");
|
||||
} // namespace infini
|
|
@ -1,5 +1,5 @@
|
|||
#include "core/kernel.h"
|
||||
#include "mkl/mkl_runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
namespace infini {
|
||||
|
@ -244,7 +244,7 @@ class MklConvTranspose : public Kernel {
|
|||
return make_ref<ConvTransposeMklPerfRecordObj>(ret);
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::MKL, OpType::ConvTrans, DataType::Float32,
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::ConvTrans, DataType::Float32,
|
||||
MklConvTranspose, "MklConvTrans_CPU_float32");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,133 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/unary.h"
|
||||
|
||||
namespace infini {
|
||||
class MklBinary : public MklKernelWithoutConfig {
|
||||
dnnl::algorithm getAlgorithem(const Ref<ElementWiseObj> &op) const {
|
||||
switch (op->getOpType()) {
|
||||
case OpType::Add:
|
||||
return dnnl::algorithm::binary_add;
|
||||
case OpType::Sub:
|
||||
return dnnl::algorithm::binary_sub;
|
||||
case OpType::Mul:
|
||||
return dnnl::algorithm::binary_mul;
|
||||
case OpType::Div:
|
||||
return dnnl::algorithm::binary_div;
|
||||
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
return dnnl::algorithm::undef;
|
||||
}
|
||||
|
||||
// Binary primitives support elementwise broadcast
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i)
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
|
||||
auto srcMd1 = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto srcMemory1 = dnnl::memory(srcMd1, context->getEngine(), aData);
|
||||
|
||||
auto srcMd2 = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto srcMemory2 = dnnl::memory(srcMd2, context->getEngine(), bData);
|
||||
|
||||
auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(), cData);
|
||||
|
||||
auto binaryDesc =
|
||||
dnnl::binary::desc(getAlgorithem(op), srcMd1, srcMd2, dstMd);
|
||||
auto primDesc =
|
||||
dnnl::binary::primitive_desc(binaryDesc, context->getEngine());
|
||||
|
||||
// create and execute binary primitive
|
||||
dnnl::binary(primDesc).execute(context->getStream(),
|
||||
{{DNNL_ARG_SRC_0, srcMemory1},
|
||||
{DNNL_ARG_SRC_1, srcMemory2},
|
||||
{DNNL_ARG_DST, output}});
|
||||
}
|
||||
};
|
||||
|
||||
class MklUnary : public MklKernelWithoutConfig {
|
||||
dnnl::algorithm getAlgorithem(const Ref<UnaryObj> &op) const {
|
||||
switch (op->getOpType()) {
|
||||
case OpType::Relu:
|
||||
return dnnl::algorithm::eltwise_relu;
|
||||
case OpType::Tanh:
|
||||
return dnnl::algorithm::eltwise_tanh;
|
||||
case OpType::Abs:
|
||||
return dnnl::algorithm::eltwise_abs;
|
||||
case OpType::Sigmoid:
|
||||
return dnnl::algorithm::eltwise_logistic;
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
return dnnl::algorithm::undef;
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<UnaryObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
void *const srcData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const dstData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i)
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
|
||||
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()), false);
|
||||
auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData);
|
||||
|
||||
auto output = dnnl::memory(srcMd, context->getEngine(), dstData);
|
||||
|
||||
const float negative1_slope = 0.0f;
|
||||
|
||||
auto unaryDesc = dnnl::eltwise_forward::desc(
|
||||
dnnl::prop_kind::forward_inference, getAlgorithem(op), srcMd,
|
||||
negative1_slope);
|
||||
auto primDesc = dnnl::eltwise_forward::primitive_desc(
|
||||
unaryDesc, context->getEngine());
|
||||
|
||||
// create and execute binary primitive
|
||||
dnnl::eltwise_forward(primDesc).execute(
|
||||
context->getStream(),
|
||||
{{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}});
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Add, DataType::Float32, MklBinary,
|
||||
"Add_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Sub, DataType::Float32, MklBinary,
|
||||
"Sub_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Mul, DataType::Float32, MklBinary,
|
||||
"Mul_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Div, DataType::Float32, MklBinary,
|
||||
"Div_Mkl_Float32");
|
||||
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Relu, DataType::Float32, MklUnary,
|
||||
"Relu_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Sigmoid, DataType::Float32, MklUnary,
|
||||
"Sigmoid_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Tanh, DataType::Float32, MklUnary,
|
||||
"Tanh_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Abs, DataType::Float32, MklUnary,
|
||||
"Abs_Mkl_Float32");
|
||||
} // namespace infini
|
|
@ -0,0 +1,45 @@
|
|||
#include "operators/extend.h"
|
||||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include <CL/sycl.hpp>
|
||||
#include <math.h>
|
||||
|
||||
namespace infini {
|
||||
class MklExtend : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ExtendObj>(_op);
|
||||
auto inData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
auto outData = op->getOutput(0)->getRawDataPtr<float *>();
|
||||
int iSize = op->getInputs(0)->size();
|
||||
int oSize = op->getOutput(0)->size();
|
||||
|
||||
sycl::queue q(sycl::cpu_selector{});
|
||||
auto inDevice = sycl::malloc_device<float>(iSize, q);
|
||||
auto outDevice = sycl::malloc_device<float>(oSize, q);
|
||||
|
||||
q.memcpy(inDevice, inData, iSize * sizeof(float));
|
||||
q.wait();
|
||||
|
||||
int blockSize = 1;
|
||||
auto iDim = op->getInputs(0)->getDims();
|
||||
for (size_t i = iDim.size() - 1;
|
||||
i >= (size_t)op->getDim() && i != (size_t)-1; --i)
|
||||
blockSize *= iDim[i];
|
||||
auto blockSizeOuter = (op->getNum() + 1) * blockSize;
|
||||
|
||||
q.parallel_for(sycl::range<1>(oSize), [=](sycl::id<1> index) {
|
||||
auto iIdx = index % blockSize + index / blockSizeOuter * blockSize;
|
||||
outDevice[index] = inDevice[iIdx];
|
||||
}).wait();
|
||||
|
||||
q.memcpy(outData, outDevice, oSize * sizeof(float));
|
||||
q.wait();
|
||||
sycl::free(inDevice, q);
|
||||
sycl::free(outDevice, q);
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Extend, DataType::Float32, MklExtend,
|
||||
"Extend_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,86 @@
|
|||
#include "operators/gather.h"
|
||||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include <CL/sycl.hpp>
|
||||
#include <math.h>
|
||||
|
||||
namespace infini {
|
||||
class MklGather : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<GatherObj>(_op);
|
||||
auto in = op->getInputs(0);
|
||||
auto index = op->getInputs(1);
|
||||
auto out = op->getOutput();
|
||||
int iSize = in->size();
|
||||
int oSize = out->size();
|
||||
int idxSize = index->size();
|
||||
|
||||
int inNDim = in->getDims().size();
|
||||
int oNDim = out->getDims().size();
|
||||
int idxNDim = index->getDims().size();
|
||||
int axis = op->getAxis();
|
||||
|
||||
int outDim[4] = {0};
|
||||
int idxDim[4] = {0};
|
||||
int idxStride[4] = {0};
|
||||
int inStride[4] = {0};
|
||||
for (int i = 0; i < oNDim; ++i)
|
||||
outDim[i] = out->getDims()[i];
|
||||
for (int i = 0; i < idxNDim; ++i) {
|
||||
idxDim[i] = index->getDims()[i];
|
||||
idxStride[i] = index->getStride()[i];
|
||||
}
|
||||
for (int i = 0; i < inNDim; ++i) {
|
||||
inStride[i] = in->getStride()[i];
|
||||
}
|
||||
|
||||
sycl::queue q(sycl::cpu_selector{});
|
||||
auto inDevice = sycl::malloc_device<float>(iSize, q);
|
||||
auto indexDevice = sycl::malloc_device<uint32_t>(idxSize, q);
|
||||
auto outDevice = sycl::malloc_device<float>(oSize, q);
|
||||
|
||||
q.memcpy(inDevice, in->getRawDataPtr<float *>(), iSize * sizeof(float));
|
||||
q.memcpy(indexDevice, index->getRawDataPtr<uint32_t *>(),
|
||||
idxSize * sizeof(uint32_t));
|
||||
q.wait();
|
||||
|
||||
q.parallel_for(sycl::range<1>(oSize), [=](sycl::id<1> index) {
|
||||
int offset = 0;
|
||||
int gOffset = index;
|
||||
for (int i = inNDim - 1, k = oNDim - 1; i >= 0; --i) {
|
||||
int idx = 0;
|
||||
if (i == axis) {
|
||||
int idxOffset = 0;
|
||||
for (int j = idxNDim - 1; j >= 0; --j) {
|
||||
int p = gOffset % idxDim[j];
|
||||
gOffset = gOffset / idxDim[j];
|
||||
idxOffset += p * idxStride[j];
|
||||
}
|
||||
|
||||
idx = indexDevice[idxOffset];
|
||||
k = k - idxNDim;
|
||||
|
||||
} else {
|
||||
idx = gOffset % outDim[k];
|
||||
gOffset = gOffset / outDim[k];
|
||||
--k;
|
||||
}
|
||||
offset += idx * inStride[i];
|
||||
}
|
||||
|
||||
outDevice[index] = inDevice[offset];
|
||||
}).wait();
|
||||
|
||||
q.memcpy(out->getRawDataPtr<float *>(), outDevice,
|
||||
oSize * sizeof(float));
|
||||
q.wait();
|
||||
sycl::free(inDevice, q);
|
||||
sycl::free(outDevice, q);
|
||||
sycl::free(indexDevice, q);
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Gather, DataType::Float32, MklGather,
|
||||
"Gather_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -1,9 +1,8 @@
|
|||
#include "operators/matmul.h"
|
||||
#include "core/kernel.h"
|
||||
#include "mkl/mkl_runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <typename T> class MklMatmul : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
|
@ -32,7 +31,7 @@ template <typename T> class MklMatmul : public CpuKernelWithoutConfig {
|
|||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::MKL, OpType::Matmul, DataType::Float32,
|
||||
MklMatmul<float>, "MklMatmul_CPU_float32");
|
||||
/*REGISTER_KERNEL(Device::INTELCPU, OpType::Matmul, DataType::Float32,
|
||||
MklMatmul<float>, "MklMatmul_CPU_float32");*/
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,75 @@
|
|||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "mkl.h"
|
||||
#include "oneapi/mkl/blas.hpp"
|
||||
#include "operators/matmul.h"
|
||||
#include <CL/sycl.hpp>
|
||||
|
||||
namespace infini {
|
||||
template <typename T> class MklDpcppMatmul : public CpuKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *context) const override {
|
||||
auto op = as<MatmulObj>(_op);
|
||||
IT_ASSERT(op->getInputs().size() == 2, "Bias is not supported yet.");
|
||||
const T *A = op->getInputs(0)->getRawDataPtr<T *>();
|
||||
const T *B = op->getInputs(1)->getRawDataPtr<T *>();
|
||||
T *C = op->getOutput()->getRawDataPtr<T *>();
|
||||
IT_ASSERT(op->getAct() == ActType::None);
|
||||
const int m = op->getM(), n = op->getN(), k = op->getK(),
|
||||
b = op->getB();
|
||||
|
||||
auto opA = op->getTransA() ? oneapi::mkl::transpose::trans
|
||||
: oneapi::mkl::transpose::nontrans;
|
||||
auto opB = op->getTransB() ? oneapi::mkl::transpose::trans
|
||||
: oneapi::mkl::transpose::nontrans;
|
||||
// ldA is always a.col, and ldB is always b.col when row major
|
||||
const int ldA =
|
||||
std::max((opA == oneapi::mkl::transpose::nontrans) ? k : m, 1);
|
||||
const int ldB =
|
||||
std::max((opB == oneapi::mkl::transpose::nontrans) ? n : k, 1);
|
||||
const int ldC = std::max(n, 1);
|
||||
|
||||
const float alpha = 1.f, beta = 0.f;
|
||||
// TODO: Intel MKL ERROR will occur when using cblas_sgemm_batch
|
||||
/*for (int i = 0; i < b; ++i) {
|
||||
cblas_sgemm(CblasRowMajor, opA, opB, m, n, k, alpha, A + m * k * i,
|
||||
ldA, B + k * n * i, ldB, beta, C + m * n * i, ldC);
|
||||
}*/
|
||||
|
||||
sycl::queue q(sycl::cpu_selector{});
|
||||
// Catch asynchronous exceptions
|
||||
auto exception_handler = [](cl::sycl::exception_list exceptions) {
|
||||
for (std::exception_ptr const &e : exceptions) {
|
||||
try {
|
||||
std::rethrow_exception(e);
|
||||
} catch (cl::sycl::exception const &e) {
|
||||
std::cout
|
||||
<< "Caught asynchronous SYCL exception during GEMM:\n"
|
||||
<< e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// create execution queue and buffers of matrix data
|
||||
cl::sycl::queue main_queue(sycl::cpu_selector{}, exception_handler);
|
||||
|
||||
cl::sycl::buffer<float, 1> A_buffer(A, op->getInputs(0)->size());
|
||||
cl::sycl::buffer<float, 1> B_buffer(B, op->getInputs(1)->size());
|
||||
cl::sycl::buffer<float, 1> C_buffer(C, op->getOutput(0)->size());
|
||||
|
||||
// add oneapi::mkl::blas::gemm to execution queue
|
||||
try {
|
||||
oneapi::mkl::blas::row_major::gemm_batch(
|
||||
main_queue, opA, opB, m, n, k, alpha, A_buffer, ldA, m * k,
|
||||
B_buffer, ldB, k * n, beta, C_buffer, ldC, m * n, b);
|
||||
} catch (cl::sycl::exception const &e) {
|
||||
std::cout << "\t\tCaught synchronous SYCL exception during GEMM:\n"
|
||||
<< e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Matmul, DataType::Float32,
|
||||
MklDpcppMatmul<float>, "MklDpcppMatmul_CPU_float32");
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,58 @@
|
|||
#include "operators/pad.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklPad : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PadObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) {
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
}
|
||||
auto paddedMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
|
||||
// dst md
|
||||
auto oDims = op->getOutput(0)->getDims();
|
||||
int ndim = oDims.size();
|
||||
std::vector<dnnl_dim_t> paddedDims, offsets;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
paddedDims.push_back(oDims.at(i));
|
||||
paddedMd.data.padded_dims[i] = oDims.at(i);
|
||||
paddedMd.data.padded_offsets[i] = op->getPads().at(i);
|
||||
offsets.push_back(op->getPads().at(i));
|
||||
}
|
||||
// will fill padded area with zero.
|
||||
auto paddedMemory =
|
||||
dnnl::memory(paddedMd, context->getEngine(),
|
||||
op->getOutput(0)->getRawDataPtr<float *>());
|
||||
|
||||
auto dstMd =
|
||||
dnnl::memory::desc(paddedDims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(paddedDims.size()));
|
||||
|
||||
// copy src to the submemory of dst
|
||||
// create submemory
|
||||
auto md = dstMd.submemory_desc(dims, offsets);
|
||||
auto mem = dnnl::memory(md, context->getEngine(),
|
||||
op->getOutput(0)->getRawDataPtr<float *>());
|
||||
|
||||
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto srcMemory =
|
||||
dnnl::memory(srcMd, context->getEngine(),
|
||||
op->getInputs(0)->getRawDataPtr<float *>());
|
||||
|
||||
// copy data to submemory
|
||||
dnnl::reorder(srcMemory, mem)
|
||||
.execute(context->getStream(),
|
||||
{{DNNL_ARG_FROM, srcMemory}, {DNNL_ARG_TO, mem}});
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Pad, DataType::Float32, MklPad,
|
||||
"Pad_Mkl_Float32");
|
||||
} // namespace infini
|
|
@ -0,0 +1,84 @@
|
|||
#include "operators/pooling.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklPooling : public MklKernelWithoutConfig {
|
||||
virtual dnnl::algorithm getAlgorithm() const = 0;
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PoolingObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
float *const srcData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
float *const dstData = op->getOutput()->getRawDataPtr<float *>();
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
auto [n, c, h, w, r, s] = op->getNCHWRS();
|
||||
auto [ph, pw, sh, sw, dh, dw] = op->getPadStrideDilation();
|
||||
auto nDim = op->getOutput()->getDims().size();
|
||||
auto oh = op->getOutput()->getDims()[nDim - 2];
|
||||
auto ow = op->getOutput()->getDims()[nDim - 1];
|
||||
|
||||
auto srcMd = dnnl::memory::desc(
|
||||
{n, c, h, w}, dnnl::memory::data_type::f32, getUserFormatTag(nDim));
|
||||
auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData);
|
||||
|
||||
auto userDstMd =
|
||||
dnnl::memory::desc({n, c, oh, ow}, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(nDim));
|
||||
|
||||
auto dstMd =
|
||||
dnnl::memory::desc({n, c, oh, ow}, dnnl::memory::data_type::f32,
|
||||
dnnl::memory::format_tag::any);
|
||||
|
||||
using op_desc_t = dnnl::pooling_v2_forward::desc;
|
||||
using pd_t = dnnl::pooling_v2_forward::primitive_desc;
|
||||
|
||||
auto opDesc = op_desc_t(dnnl::prop_kind::forward_inference,
|
||||
getAlgorithm(), srcMd, dstMd, {sh, sw}, {r, s},
|
||||
{dh - 1, dw - 1}, {ph, pw}, {ph, pw});
|
||||
auto primDesc = pd_t(opDesc, context->getEngine());
|
||||
|
||||
if (primDesc.dst_desc() == userDstMd) {
|
||||
auto output = dnnl::memory(primDesc.dst_desc(),
|
||||
context->getEngine(), dstData);
|
||||
|
||||
dnnl::pooling_v2_forward(primDesc).execute(
|
||||
context->getStream(),
|
||||
{{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}});
|
||||
} else {
|
||||
auto dstMemory =
|
||||
dnnl::memory(primDesc.dst_desc(), context->getEngine());
|
||||
|
||||
dnnl::pooling_v2_forward(primDesc).execute(
|
||||
context->getStream(),
|
||||
{{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, dstMemory}});
|
||||
|
||||
auto output =
|
||||
dnnl::memory(userDstMd, context->getEngine(), dstData);
|
||||
dnnl::reorder(dstMemory, output)
|
||||
.execute(context->getStream(),
|
||||
{{DNNL_ARG_FROM, dstMemory}, {DNNL_ARG_TO, output}});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class MklAvgPool : public MklPooling {
|
||||
dnnl::algorithm getAlgorithm() const override {
|
||||
return dnnl::algorithm::pooling_avg_include_padding;
|
||||
}
|
||||
};
|
||||
|
||||
class MklMaxPool : public MklPooling {
|
||||
dnnl::algorithm getAlgorithm() const override {
|
||||
return dnnl::algorithm::pooling_max;
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::AvgPool, DataType::Float32,
|
||||
MklAvgPool, "AvgPool_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::MaxPool, DataType::Float32,
|
||||
MklMaxPool, "MaxPool_Mkl_Float32");
|
||||
} // namespace infini
|
|
@ -0,0 +1,43 @@
|
|||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include <CL/sycl.hpp>
|
||||
#include <math.h>
|
||||
|
||||
namespace infini {
|
||||
class MklPow : public MklKernelWithoutConfig {
|
||||
// TODO: not need to copy memory??
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<PowObj>(_op);
|
||||
auto in0Data = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
auto in1Data = op->getInputs(1)->getRawDataPtr<float *>();
|
||||
auto outData = op->getOutput(0)->getRawDataPtr<float *>();
|
||||
int size = op->getInputs(0)->size();
|
||||
|
||||
// cpu_selector using openCL as backend;and host_selector bypasses the
|
||||
// OnenCL backend and runs directly on CPU hardware
|
||||
sycl::queue q(sycl::cpu_selector{});
|
||||
auto in0Device = sycl::malloc_device<float>(size, q);
|
||||
auto in1Device = sycl::malloc_device<float>(size, q);
|
||||
auto outDevice = sycl::malloc_device<float>(size, q);
|
||||
q.memcpy(in0Device, in0Data, size * sizeof(float));
|
||||
q.wait();
|
||||
q.memcpy(in1Device, in1Data, size * sizeof(float));
|
||||
q.wait();
|
||||
|
||||
q.parallel_for(sycl::range<1>(size), [=](sycl::id<1> i) {
|
||||
outDevice[i] = pow(in0Device[i], in1Device[i]);
|
||||
}).wait();
|
||||
q.memcpy(outData, outDevice, size * sizeof(float));
|
||||
q.wait();
|
||||
sycl::free(in0Device, q);
|
||||
sycl::free(in1Device, q);
|
||||
sycl::free(outDevice, q);
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Pow, DataType::Float32, MklPow,
|
||||
"Pow_Mkl_Float32");
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,69 @@
|
|||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
|
||||
namespace infini {
|
||||
class MklReduce : public MklKernelWithoutConfig {
|
||||
dnnl::algorithm getAlgorithm() const {
|
||||
return dnnl::algorithm::reduction_mean;
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ReduceMeanObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
float *const srcData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
float *const dstData = op->getOutput()->getRawDataPtr<float *>();
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
std::vector<dnnl_dim_t> inDims, inStrides;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) {
|
||||
inDims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
inStrides.push_back(op->getInputs(0)->getStride()[i]);
|
||||
}
|
||||
|
||||
std::vector<dnnl_dim_t> oDims(op->getInputs(0)->getDims().size(), 0),
|
||||
oStrides(op->getInputs(0)->getDims().size(), 1);
|
||||
if (!op->getKeepDims()) {
|
||||
oDims = inDims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) {
|
||||
if (op->isReduced(i)) {
|
||||
oDims[i] = 1;
|
||||
}
|
||||
}
|
||||
int stride = 1;
|
||||
for (int i = (int)oDims.size() - 1; i >= 0; --i) {
|
||||
oStrides[i] = stride;
|
||||
stride *= oDims[i];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < op->getOutput(0)->getDims().size(); ++i) {
|
||||
oDims[i] = op->getOutput(0)->getDims()[i];
|
||||
oStrides[i] = op->getOutput(0)->getStride()[i];
|
||||
}
|
||||
}
|
||||
|
||||
auto srcMd =
|
||||
dnnl::memory::desc(inDims, dnnl::memory::data_type::f32, inStrides);
|
||||
auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData);
|
||||
|
||||
auto dstMd =
|
||||
dnnl::memory::desc(oDims, dnnl::memory::data_type::f32, oStrides);
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(), dstData);
|
||||
|
||||
using op_desc_t = dnnl::reduction::desc;
|
||||
using pd_t = dnnl::reduction::primitive_desc;
|
||||
|
||||
auto opDesc = op_desc_t(getAlgorithm(), srcMd, dstMd, 0, 0);
|
||||
auto primDesc = pd_t(opDesc, context->getEngine());
|
||||
|
||||
// create and execute primitive
|
||||
dnnl::reduction(primDesc).execute(
|
||||
context->getStream(),
|
||||
{{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}});
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::ReduceMean, DataType::Float32,
|
||||
MklReduce, "ReduceMean_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,50 @@
|
|||
#include "operators/reshape.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklReshape : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &op,
|
||||
const RuntimeObj *_context) const override {
|
||||
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i)
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
|
||||
// create src md and src memory
|
||||
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
|
||||
// dst md
|
||||
auto oDims = op->getOutput(0)->getDims();
|
||||
int ndim = oDims.size();
|
||||
std::vector<dnnl_dim_t> reshapeDims;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
reshapeDims.push_back(oDims.at(i));
|
||||
}
|
||||
auto reshapeMd = srcMd.reshape(reshapeDims);
|
||||
auto reshapeMemory =
|
||||
dnnl::memory(reshapeMd, context->getEngine(),
|
||||
op->getInputs(0)->getRawDataPtr<float *>());
|
||||
|
||||
auto dstMd =
|
||||
dnnl::memory::desc(reshapeDims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(reshapeDims.size()));
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(),
|
||||
op->getOutput(0)->getRawDataPtr<float *>());
|
||||
|
||||
// copy data to dst
|
||||
dnnl::reorder(reshapeMemory, output)
|
||||
.execute(context->getStream(),
|
||||
{{DNNL_ARG_FROM, reshapeMemory}, {DNNL_ARG_TO, output}});
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Reshape, DataType::Float32,
|
||||
MklReshape, "Reshape_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Identity, DataType::Float32,
|
||||
MklReshape, "Identify_Mkl_Float32");
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Flatten, DataType::Float32,
|
||||
MklReshape, "Flatten_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,80 @@
|
|||
#include "operators/resize.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklResize : public MklKernelWithoutConfig {
|
||||
dnnl::algorithm getAlgorithm(Ref<ResizeObj> op) const {
|
||||
switch (op->getMode()) {
|
||||
case ResizeObj::ECoeffMode::nearest: {
|
||||
if (op->getNearestMode() !=
|
||||
enum_to_underlying(ResizeObj::ENearestMode::ceil))
|
||||
IT_TODO_HALT();
|
||||
return dnnl::algorithm::resampling_nearest;
|
||||
}
|
||||
case ResizeObj::ECoeffMode::linear:
|
||||
return dnnl::algorithm::resampling_linear;
|
||||
|
||||
default:
|
||||
IT_TODO_HALT();
|
||||
}
|
||||
return dnnl::algorithm::resampling_nearest;
|
||||
}
|
||||
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ResizeObj>(_op);
|
||||
|
||||
// only support default coordinate transmode??
|
||||
if (op->getCoordinateTransMode() !=
|
||||
enum_to_underlying(ResizeObj::ECoordinateTransMode::halfPixel))
|
||||
IT_TODO_HALT();
|
||||
|
||||
int nDim = op->getInputs(0)->getDims().size();
|
||||
IT_ASSERT(nDim == 3 || nDim == 4 ||
|
||||
nDim == 5 &&
|
||||
(op->getInputs(0)->getDims()[0] == 1 &&
|
||||
op->getInputs(0)->getDims()[1] == 1) &&
|
||||
(op->getOutput(0)->getDims()[0] == 1 &&
|
||||
op->getOutput(0)->getDims()[1] == 1));
|
||||
|
||||
IT_ASSERT(op->getScales().size() == nDim);
|
||||
std::vector<float>::iterator beg = op->getScales().begin() + 2;
|
||||
std::vector<float> scales(beg, op->getScales().end());
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
std::vector<dnnl_dim_t> idims, odims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i) {
|
||||
idims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
odims.push_back(op->getOutput(0)->getDims()[i]);
|
||||
}
|
||||
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
float *const srcData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
float *const dstData = op->getOutput()->getRawDataPtr<float *>();
|
||||
|
||||
auto srcMd = dnnl::memory::desc(idims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(idims.size()));
|
||||
auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData);
|
||||
|
||||
auto dstMd = dnnl::memory::desc(odims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(odims.size()));
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(), dstData);
|
||||
|
||||
using op_desc_t = dnnl::resampling_forward::desc;
|
||||
using pd_t = dnnl::resampling_forward::primitive_desc;
|
||||
|
||||
auto opDesc = op_desc_t(dnnl::prop_kind::forward_inference,
|
||||
getAlgorithm(op), scales, srcMd, dstMd);
|
||||
auto primDesc = pd_t(opDesc, context->getEngine());
|
||||
|
||||
// create and execute primitive
|
||||
dnnl::resampling_forward(primDesc).execute(
|
||||
context->getStream(),
|
||||
{{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}});
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Resize, DataType::Float32, MklResize,
|
||||
"Resize_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,46 @@
|
|||
#include "operators/slice.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklSlice : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SliceObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i)
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
|
||||
// create src md
|
||||
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
|
||||
// dst md
|
||||
auto oDims = op->getOutput(0)->getDims();
|
||||
int ndim = oDims.size();
|
||||
std::vector<dnnl_dim_t> sDims, offsets;
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
sDims.push_back(oDims.at(i));
|
||||
offsets.push_back(op->getStart().at(i));
|
||||
}
|
||||
auto sliceMd = srcMd.submemory_desc(sDims, offsets);
|
||||
auto sliceMemory =
|
||||
dnnl::memory(sliceMd, context->getEngine(),
|
||||
op->getInputs(0)->getRawDataPtr<float *>());
|
||||
|
||||
auto dstMd = dnnl::memory::desc(sDims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(sDims.size()));
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(),
|
||||
op->getOutput(0)->getRawDataPtr<float *>());
|
||||
|
||||
// copy data to dst
|
||||
dnnl::reorder(sliceMemory, output)
|
||||
.execute(context->getStream(),
|
||||
{{DNNL_ARG_FROM, sliceMemory}, {DNNL_ARG_TO, output}});
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Slice, DataType::Float32, MklSlice,
|
||||
"Slice_Mkl_Float32");
|
||||
} // namespace infini
|
|
@ -0,0 +1,43 @@
|
|||
#include "operators/softmax.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklSoftmax : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SoftmaxObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
float *const srcData = op->getInputs(0)->getRawDataPtr<float *>();
|
||||
float *const dstData = op->getOutput()->getRawDataPtr<float *>();
|
||||
|
||||
// create user memory that describes data layout in the buffers
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i)
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
|
||||
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto srcMemory = dnnl::memory(srcMd, context->getEngine(), srcData);
|
||||
|
||||
auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto output = dnnl::memory(dstMd, context->getEngine(), dstData);
|
||||
|
||||
using op_desc_t = dnnl::softmax_forward::desc;
|
||||
using pd_t = dnnl::softmax_forward::primitive_desc;
|
||||
|
||||
auto opDesc =
|
||||
op_desc_t(dnnl::prop_kind::forward_inference, srcMd, op->getAxis());
|
||||
auto primDesc = pd_t(opDesc, context->getEngine());
|
||||
|
||||
// create and execute primitive
|
||||
dnnl::softmax_forward(primDesc).execute(
|
||||
context->getStream(),
|
||||
{{DNNL_ARG_SRC, srcMemory}, {DNNL_ARG_DST, output}});
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Softmax, DataType::Float32,
|
||||
MklSoftmax, "Softmax_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,54 @@
|
|||
#include "operators/split.h"
|
||||
#include "intelcpu/mkl_kernel_without_config.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class MklSplit : public MklKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<SplitObj>(_op);
|
||||
auto context = dynamic_cast<const MklRuntimeObj *>(_context);
|
||||
|
||||
std::vector<dnnl_dim_t> dims;
|
||||
for (size_t i = 0; i < op->getInputs(0)->getDims().size(); ++i)
|
||||
dims.push_back(op->getInputs(0)->getDims()[i]);
|
||||
|
||||
// create src md
|
||||
auto srcMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
|
||||
// dst md
|
||||
std::vector<dnnl::memory::desc> dstsMd;
|
||||
std::vector<dnnl::memory> dsts;
|
||||
int offset = 0;
|
||||
for (size_t i = 0; i < op->getOutputs().size(); i++) {
|
||||
auto oDims = op->getOutput(i)->getDims();
|
||||
int ndim = oDims.size();
|
||||
std::vector<dnnl_dim_t> dims, offsets(ndim, 0);
|
||||
for (int i = 0; i < ndim; ++i) {
|
||||
dims.push_back(oDims.at(i));
|
||||
}
|
||||
offsets[op->getDim()] = offset;
|
||||
auto splitMd = srcMd.submemory_desc(dims, offsets);
|
||||
auto splitMemory =
|
||||
dnnl::memory(splitMd, context->getEngine(),
|
||||
op->getInputs(0)->getRawDataPtr<float *>());
|
||||
|
||||
auto dstMd = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
|
||||
getUserFormatTag(dims.size()));
|
||||
auto output =
|
||||
dnnl::memory(dstMd, context->getEngine(),
|
||||
op->getOutput(i)->getRawDataPtr<float *>());
|
||||
|
||||
// copy data to dst
|
||||
dnnl::reorder(splitMemory, output)
|
||||
.execute(context->getStream(),
|
||||
{{DNNL_ARG_FROM, splitMemory}, {DNNL_ARG_TO, output}});
|
||||
|
||||
offset += dims.at(op->getDim());
|
||||
}
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL(Device::INTELCPU, OpType::Split, DataType::Float32, MklSplit,
|
||||
"Split_Mkl_Float32");
|
||||
}; // namespace infini
|
|
@ -1,13 +0,0 @@
|
|||
#include "mkl/mkl_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
namespace infini {
|
||||
MklRuntimeObj::MklRuntimeObj() : CpuRuntimeObj(Device::MKL) {
|
||||
dnnl_engine_create(&engine, dnnl_engine_kind_t::dnnl_cpu, 0);
|
||||
}
|
||||
|
||||
MklRuntimeObj::~MklRuntimeObj() {
|
||||
mkl_free_buffers();
|
||||
dnnl_engine_destroy(engine);
|
||||
}
|
||||
} // namespace infini
|
|
@ -39,18 +39,25 @@ vector<int> ReshapeObj::getOpAttrVector() const {
|
|||
return ret;
|
||||
}
|
||||
|
||||
FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output)
|
||||
FlattenObj::FlattenObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
|
||||
: OperatorObj(OpType::Flatten, {input}, {output}) {
|
||||
if (_axis >= 0 && (size_t)_axis < input->getDims().size())
|
||||
axis = _axis;
|
||||
else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size())
|
||||
axis = _axis + input->getDims().size();
|
||||
else
|
||||
IT_ASSERT(0);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
optional<vector<Shape>> FlattenObj::inferShape(const TensorVec &inputs) const {
|
||||
int size = 1;
|
||||
int sizeB = 1, sizeE = 1;
|
||||
auto dims = getInputs(0)->getDims();
|
||||
for (size_t i = 0; i < dims.size(); ++i)
|
||||
size *= dims.at(i);
|
||||
int ndim = dims.size();
|
||||
for (int i = 0; i < ndim; ++i)
|
||||
((i < axis) ? sizeB : sizeE) *= dims.at(i);
|
||||
|
||||
return {{{size}}};
|
||||
return {{{sizeB, sizeE}}};
|
||||
}
|
||||
|
||||
std::string FlattenObj::toString() const {
|
||||
|
@ -59,18 +66,20 @@ std::string FlattenObj::toString() const {
|
|||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ")";
|
||||
os << "output=" << outputs[0]->getGuid() << ",";
|
||||
os << "axis=" << axis << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> FlattenObj::getWorkloadVector() const {
|
||||
vector<int> ret = inputs[0]->getDims();
|
||||
ret.emplace(ret.begin(), axis);
|
||||
ret.emplace(ret.begin(), enum_to_underlying(type));
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> FlattenObj::getOpAttrVector() const {
|
||||
return {enum_to_underlying(type)};
|
||||
return {enum_to_underlying(type), axis};
|
||||
}
|
||||
|
||||
IdentityObj::IdentityObj(GraphObj *graph, Tensor input, Tensor output)
|
||||
|
|
|
@ -70,25 +70,6 @@ void ResizeObj::init(const Tensor &input, const Tensor &sizes,
|
|||
}
|
||||
}
|
||||
}
|
||||
/*
|
||||
Operator ResizeObj::clone(TensorVec inputs, TensorVec outputs) {
|
||||
Tensor roi{nullptr}, sizes{nullptr}, scales{nullptr};
|
||||
if (inputs.size() == 3)
|
||||
roi = inputs[2];
|
||||
if (isResizeBySizes())
|
||||
sizes = inputs[1];
|
||||
else
|
||||
scales = inputs[1];
|
||||
|
||||
if (mode == ECoeffMode::nearest)
|
||||
return make_ref<ResizeObj>(nullptr, inputs[0], outputs[0], axes,
|
||||
inputs[1], nullptr, roi, ratioPolicy,
|
||||
nearestMode, coMode);
|
||||
else
|
||||
return make_ref<ResizeObj>(nullptr, inputs[0], outputs[0], axes,
|
||||
inputs[1], nullptr, roi, mode, ratioPolicy,
|
||||
coMode);
|
||||
}*/
|
||||
|
||||
void ResizeObj::InitBySizes(Tensor input, Tensor sizes,
|
||||
const std::optional<vector<int>> &axes) {
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
#include "operators/softmax.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
SoftmaxObj::SoftmaxObj(GraphObj *graph, Tensor input, Tensor output, int _axis)
|
||||
: OperatorObj(OpType::Softmax, {input}, {output}) {
|
||||
if (_axis >= 0 && (size_t)_axis < input->getDims().size())
|
||||
axis = _axis;
|
||||
else if (_axis <= -1 && (size_t)_axis >= -input->getDims().size())
|
||||
axis = _axis + input->getDims().size();
|
||||
else
|
||||
IT_ASSERT(0);
|
||||
IT_ASSERT(checkValid(graph));
|
||||
}
|
||||
|
||||
std::string SoftmaxObj::toString() const {
|
||||
std::ostringstream os;
|
||||
os << OpRegistry::getOpName(type) << "[" << getGuid() << "]";
|
||||
os << "(";
|
||||
os << vecToString(inputs[0]->getDims()) << ",";
|
||||
os << "input=" << inputs[0]->getGuid() << ",";
|
||||
os << "output=" << outputs[0]->getGuid() << ",";
|
||||
os << "axis=" << axis << ")";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
vector<int> SoftmaxObj::getWorkloadVector() const {
|
||||
vector<int> ret{enum_to_underlying(type), axis};
|
||||
const Shape shape = outputs[0]->getDims();
|
||||
ret.insert(ret.end(), shape.begin(), shape.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
vector<int> SoftmaxObj::getOpAttrVector() const {
|
||||
return {enum_to_underlying(type), axis};
|
||||
}
|
||||
} // namespace infini
|
|
@ -51,7 +51,7 @@ TEST(CUDA_Flatten, run) {
|
|||
// Build CUDA graph
|
||||
Graph g = make_ref<GraphObj>(cudaRuntime);
|
||||
auto i = g->cloneTensor(icpu);
|
||||
auto op = g->addOp<FlattenObj>(i, nullptr);
|
||||
auto op = g->addOp<FlattenObj>(i, nullptr, 2);
|
||||
|
||||
// allocate CUDA memory
|
||||
g->dataMalloc();
|
||||
|
|
|
@ -0,0 +1,142 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "cuda/cuda_runtime.h"
|
||||
#include "cuda/cuda_utility.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "test.h"
|
||||
#include <cmath>
|
||||
namespace infini {
|
||||
|
||||
TEST(cuDNN_Softmax, run_axis1) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu =
|
||||
make_ref<TensorObj>(Shape{2, 4}, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->copyin(vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
|
||||
|
||||
// GPU
|
||||
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
|
||||
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = cudaGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 1);
|
||||
cudaGraph->dataMalloc();
|
||||
cudaRuntime->run(cudaGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
cudaPrintTensor(outputGpu);
|
||||
// Check
|
||||
EXPECT_TRUE(outputGpu2Cpu->equalData(
|
||||
vector<float>{0.032058604, 0.08714432, 0.23688284, 0.6439143,
|
||||
0.032058604, 0.08714432, 0.23688284, 0.6439143}));
|
||||
}
|
||||
|
||||
TEST(cuDNN_Softmax, run_axis0) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu =
|
||||
make_ref<TensorObj>(Shape{2, 4}, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->copyin(vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
|
||||
|
||||
// GPU
|
||||
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
|
||||
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = cudaGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 0);
|
||||
cudaGraph->dataMalloc();
|
||||
cudaRuntime->run(cudaGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
cudaPrintTensor(outputGpu);
|
||||
// Check
|
||||
EXPECT_TRUE(
|
||||
outputGpu2Cpu->equalData(vector<float>{0., 0., 0., 0., 1, 1, 1, 1}));
|
||||
}
|
||||
|
||||
TEST(cuDNN_Softmax2, run_axis1) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu =
|
||||
make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(IncrementalGenerator());
|
||||
|
||||
// GPU
|
||||
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
|
||||
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = cudaGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 1);
|
||||
cudaGraph->dataMalloc();
|
||||
cudaRuntime->run(cudaGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
cudaPrintTensor(outputGpu);
|
||||
// Check
|
||||
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
|
||||
0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138,
|
||||
0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862,
|
||||
0.9820138, 0.9820138, 0.9820138, 0.9820138}));
|
||||
}
|
||||
|
||||
TEST(cuDNN_Softmax2, run_axis2) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu =
|
||||
make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(IncrementalGenerator());
|
||||
|
||||
// GPU
|
||||
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
|
||||
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = cudaGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 2);
|
||||
cudaGraph->dataMalloc();
|
||||
cudaRuntime->run(cudaGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
cudaPrintTensor(outputGpu);
|
||||
// Check
|
||||
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
|
||||
0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029,
|
||||
0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, 0.8807971,
|
||||
0.1192029, 0.1192029, 0.8807971, 0.8807971}));
|
||||
}
|
||||
|
||||
TEST(cuDNN_Softmax2, run_axis3) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto cudaRuntime = make_ref<CudaRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu =
|
||||
make_ref<TensorObj>(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime);
|
||||
inputCpu->dataMalloc();
|
||||
inputCpu->setData(IncrementalGenerator());
|
||||
|
||||
// GPU
|
||||
Graph cudaGraph = make_ref<GraphObj>(cudaRuntime);
|
||||
auto inputGpu = cudaGraph->cloneTensor(inputCpu);
|
||||
auto gpuOp = cudaGraph->addOp<SoftmaxObj>(inputGpu, nullptr, 3);
|
||||
cudaGraph->dataMalloc();
|
||||
cudaRuntime->run(cudaGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
cudaPrintTensor(outputGpu);
|
||||
// Check
|
||||
EXPECT_TRUE(outputGpu2Cpu->equalData(vector<float>{
|
||||
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
|
||||
0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586,
|
||||
0.2689414, 0.7310586, 0.2689414, 0.7310586}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -41,7 +41,6 @@ void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
|
|||
|
||||
TEST(cuDNN_Unary, run) {
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SoftmaxObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/batch_norm.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(MklBatchNorm, run) {
|
||||
// Runtime
|
||||
auto runtime = make_ref<MklRuntimeObj>();
|
||||
|
||||
// Build graph
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
auto i = g->addTensor(Shape{1, 3, 2, 2}, DataType::Float32);
|
||||
auto mean = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
|
||||
auto var = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
|
||||
auto scale = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
|
||||
auto bias = g->addTensor(Shape{1, 3, 1, 1}, DataType::Float32);
|
||||
auto op =
|
||||
g->addOp<BatchNormObj>(i, nullptr, mean, var, scale, bias, 0.9, 0);
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
mean->copyin(vector<float>{1, 6, 9});
|
||||
var->copyin(vector<float>{4, 1, 9});
|
||||
scale->setData(OneGenerator());
|
||||
bias->setData(ZeroGenerator());
|
||||
|
||||
runtime->run(g);
|
||||
|
||||
auto o = op->getOutput();
|
||||
EXPECT_TRUE(o->equalData(vector<float>{
|
||||
-0.5, 0, 0.5, 1, -2, -1, 0, 1, -0.3333333, 0, 0.3333333, 0.6666667}));
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,29 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/concat.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Concat, Mkl) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto t1 = g->addTensor({2, 2, 3, 1}, DataType::Float32);
|
||||
auto t2 = g->addTensor({2, 2, 1, 1}, DataType::Float32);
|
||||
auto t3 = g->addTensor({2, 2, 2, 1}, DataType::Float32);
|
||||
auto op = g->addOp<ConcatObj>(TensorVec{t1, t2, t3}, nullptr, 2);
|
||||
g->dataMalloc();
|
||||
t1->setData(IncrementalGenerator());
|
||||
t2->setData(OneGenerator());
|
||||
t3->setData(OneGenerator());
|
||||
|
||||
runtime->run(g);
|
||||
EXPECT_TRUE(op->getOutput()->equalData(
|
||||
vector<float>{0, 1, 2, 1, 1, 1, 3, 4, 5, 1, 1, 1,
|
||||
6, 7, 8, 1, 1, 1, 9, 10, 11, 1, 1, 1}));
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -2,7 +2,7 @@
|
|||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#include "core/runtime.h"
|
||||
#include "mkl/mkl_runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
@ -17,18 +17,15 @@ void testConvDnnl(
|
|||
|
||||
Tensor i0 = gMkl->addTensor({1, 3, 4, 4}, DataType::Float32);
|
||||
Tensor w0 = gMkl->addTensor({2, 3, 3, 3}, DataType::Float32);
|
||||
|
||||
// Build graph
|
||||
auto conv = gMkl->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
|
||||
// Malloc data for all tensors in a graph.
|
||||
gMkl->dataMalloc();
|
||||
i0->setData(generator);
|
||||
w0->setData(generator);
|
||||
|
||||
// Build graph
|
||||
auto conv = gMkl->addOp<ConvObj>(i0, w0, nullptr, 1, 1, 2, 1, 1, 2);
|
||||
// allocate CUDA memory
|
||||
gMkl->dataMalloc();
|
||||
// Execute on CUDA
|
||||
mklRuntime->run(gMkl);
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(conv->getOutput(0)->equalData(ansVec));
|
||||
}
|
||||
|
||||
|
@ -57,7 +54,7 @@ TEST(mkl_Conv, tune) {
|
|||
|
||||
// check record
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32};
|
||||
KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32};
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData =
|
||||
PerfEngine::getInstance().getPerfData(perfKey);
|
|
@ -1,7 +1,7 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
#include "mkl/mkl_runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/conv.h"
|
||||
|
||||
#include "test.h"
|
||||
|
@ -26,7 +26,7 @@ void testConvTransposedMkl(
|
|||
i0->setData(generator);
|
||||
w0->setData(generator);
|
||||
|
||||
runtime->prepareAndRun(gMkl);
|
||||
runtime->run(gMkl);
|
||||
EXPECT_TRUE(conv->getOutput()->equalData(ansVec));
|
||||
}
|
||||
|
||||
|
@ -50,7 +50,7 @@ TEST(mkl_ConvTransposed, run1) {
|
|||
i0->setData(IncrementalGenerator());
|
||||
w0->setData(IncrementalGenerator());
|
||||
|
||||
runtime->prepareAndRun(gMkl);
|
||||
runtime->run(gMkl);
|
||||
EXPECT_TRUE(conv->getOutput()->equalData(vector<float>{
|
||||
162, 351, 569, 413, 224, 405, 876, 1417, 1024, 553,
|
||||
747, 1611, 2598, 1869, 1005, 639, 1368, 2191, 1564, 835,
|
||||
|
@ -71,10 +71,10 @@ TEST(mkl_ConvTransposed, tune) {
|
|||
w0->setData(IncrementalGenerator());
|
||||
|
||||
bool tune = true;
|
||||
runtime->prepareAndRun(gMkl, tune);
|
||||
runtime->run(gMkl, tune);
|
||||
// check record
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{Device::MKL, conv->getOpType(), DataType::Float32};
|
||||
KernelAttrs{Device::INTELCPU, conv->getOpType(), DataType::Float32};
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, conv->getOpPerfKey()};
|
||||
std::optional<PerfRecord> perfData =
|
||||
PerfEngine::getInstance().getPerfData(perfKey);
|
|
@ -0,0 +1,84 @@
|
|||
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
#include "operators/unary.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
using ExpectOutput = vector<float>;
|
||||
template <class T>
|
||||
void testBinary(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const ExpectOutput &ansVec) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
auto a = g->addTensor(shape, DataType::Float32);
|
||||
auto b = g->addTensor(shape, DataType::Float32);
|
||||
auto op = g->addOp<T>(a, b, nullptr);
|
||||
g->dataMalloc();
|
||||
a->setData(generator);
|
||||
b->setData(generator);
|
||||
|
||||
runtime->run(g);
|
||||
|
||||
auto c = op->getOutput();
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(c->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(dnnl_Binary, run) {
|
||||
testBinary<AddObj>(IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22});
|
||||
testBinary<SubObj>(IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
|
||||
testBinary<MulObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{0, 1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121});
|
||||
|
||||
testBinary<DivObj>(OneGenerator(), Shape{1, 2, 2, 3},
|
||||
ExpectOutput{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
|
||||
}
|
||||
|
||||
TEST(sycl_Pow, run) {
|
||||
testBinary<PowObj>(IncrementalGenerator(), Shape{1, 2, 2, 1},
|
||||
ExpectOutput{1, 1, 4, 27});
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void testUnary(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime rCpu = NativeCpuRuntimeObj::getInstance();
|
||||
auto rMkl = make_ref<MklRuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
|
||||
Graph gCpu = make_ref<GraphObj>(rCpu);
|
||||
Tensor iCpu = gCpu->addTensor(shape, DataType::Float32);
|
||||
auto opCpu = gCpu->addOp<T>(iCpu, nullptr);
|
||||
gCpu->dataMalloc();
|
||||
iCpu->setData(generator);
|
||||
rCpu->run(gCpu);
|
||||
|
||||
// MKL
|
||||
Graph gMkl = make_ref<GraphObj>(rMkl);
|
||||
auto iMkl = gMkl->addTensor(shape, DataType::Float32);
|
||||
auto opMkl = gMkl->addOp<T>(iMkl, nullptr);
|
||||
gMkl->dataMalloc();
|
||||
iMkl->setData(generator);
|
||||
rMkl->run(gMkl);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(opCpu->getOutput()->equalData(opMkl->getOutput()));
|
||||
}
|
||||
|
||||
TEST(dnnl_Unary, run) {
|
||||
testUnary<ReluObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<SigmoidObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<AbsObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
testUnary<TanhObj>(IncrementalGenerator(), Shape{1, 2, 2, 3});
|
||||
}
|
||||
|
||||
}; // namespace infini
|
|
@ -0,0 +1,31 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/extend.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(MKL_Extend, run) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor(Shape{2, 3, 2, 2}, DataType::Float32);
|
||||
auto op = g->addOp<ExtendObj>(i, nullptr, 1, 1);
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
|
||||
// Execute
|
||||
runtime->run(g);
|
||||
|
||||
auto o = op->getOutput();
|
||||
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(o->equalData(vector<float>{
|
||||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3,
|
||||
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
|
||||
20, 21, 22, 23, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,60 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/gather.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(Gather, Cuda) {
|
||||
{
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
auto input = g->addTensor({3, 2}, DataType::Float32);
|
||||
auto index = g->addTensor({2, 2}, DataType::UInt32);
|
||||
g->dataMalloc();
|
||||
input->copyin(vector<float>{1, 2, 3, 4, 5, 6});
|
||||
index->copyin(vector<uint32_t>{0, 1, 1, 2});
|
||||
|
||||
auto op = g->addOp<GatherObj>(input, index, nullptr, 0);
|
||||
g->dataMalloc();
|
||||
runtime->run(g);
|
||||
|
||||
EXPECT_TRUE(
|
||||
op->getOutput()->equalData(vector<float>{1, 2, 3, 4, 3, 4, 5, 6}));
|
||||
}
|
||||
{
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
auto input = g->addTensor({3, 3}, DataType::Float32);
|
||||
auto index = g->addTensor({1, 2}, DataType::UInt32);
|
||||
g->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
index->copyin(vector<uint32_t>{0, 2});
|
||||
|
||||
auto op = g->addOp<GatherObj>(input, index, nullptr, 1);
|
||||
g->dataMalloc();
|
||||
runtime->run(g);
|
||||
|
||||
EXPECT_TRUE(
|
||||
op->getOutput()->equalData(vector<float>{0, 2, 3, 5, 6, 8}));
|
||||
}
|
||||
{
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
auto input = g->addTensor({2, 4, 2}, DataType::Float32);
|
||||
auto index = g->addTensor({3, 1}, DataType::UInt32);
|
||||
g->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
index->copyin(vector<uint32_t>{0, 3, 1});
|
||||
|
||||
auto op = g->addOp<GatherObj>(input, index, nullptr, 1);
|
||||
g->dataMalloc();
|
||||
runtime->run(g);
|
||||
|
||||
EXPECT_TRUE(op->getOutput()->equalData(
|
||||
vector<float>{0, 1, 6, 7, 2, 3, 8, 9, 14, 15, 10, 11}));
|
||||
}
|
||||
}
|
||||
} // namespace infini
|
|
@ -2,7 +2,7 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "mkl/mkl_runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/matmul.h"
|
||||
|
||||
#include "test.h"
|
||||
|
@ -27,7 +27,6 @@ void testMatmulMkl(
|
|||
|
||||
gCpu->dataMalloc();
|
||||
cpuRuntime->run(gCpu);
|
||||
matmul->getOutput()->printData();
|
||||
EXPECT_TRUE(matmul->getOutput()->equalData(ansVec));
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/pad.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(Pad, Mkl) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
// Build input data
|
||||
Tensor i = g->addTensor(Shape{1, 2, 3, 2}, DataType::Float32);
|
||||
auto op = g->addOp<PadObj>(i, nullptr, vector<int>{1, 0, 1, 1},
|
||||
vector<int>{0, 3});
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
|
||||
// Execute
|
||||
runtime->run(g);
|
||||
|
||||
auto o = op->getOutput();
|
||||
|
||||
// check results
|
||||
EXPECT_TRUE(o->equalData(
|
||||
vector<float>{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 1, 0, 2, 3, 0, 4, 5, 0, 6, 7, 0, 8, 9, 0, 10, 11, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,47 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/pooling.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
using KDPS = vector<int>;
|
||||
using ExpectOutput = vector<float>;
|
||||
|
||||
template <class T>
|
||||
void testPoolMkl(const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape, const KDPS &kdps,
|
||||
const ExpectOutput &ansVec) {
|
||||
EXPECT_TRUE(kdps.size() == 8);
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
// Build input data
|
||||
Tensor i0 = g->addTensor(shape, DataType::Float32);
|
||||
auto pool = g->addOp<T>(i0, nullptr, kdps[0], kdps[1], kdps[2], kdps[3],
|
||||
kdps[4], kdps[5], kdps[6], kdps[7]);
|
||||
g->dataMalloc();
|
||||
i0->setData(generator);
|
||||
|
||||
runtime->run(g);
|
||||
// check results on CPU
|
||||
EXPECT_TRUE(pool->getOutput()->equalData(ansVec));
|
||||
}
|
||||
|
||||
TEST(mkl_MaxPool, run) {
|
||||
testPoolMkl<MaxPoolObj>(IncrementalGenerator(), Shape{1, 2, 5, 5},
|
||||
KDPS{3, 3, 1, 1, 1, 1, 2, 2},
|
||||
ExpectOutput{6, 8, 9, 16, 18, 19, 21, 23, 24, 31,
|
||||
33, 34, 41, 43, 44, 46, 48, 49});
|
||||
}
|
||||
|
||||
TEST(mkl_AvgPool, run) {
|
||||
testPoolMkl<AvgPoolObj>(
|
||||
IncrementalGenerator(), Shape{1, 2, 5, 5}, KDPS{3, 3, 1, 1, 1, 1, 2, 2},
|
||||
ExpectOutput{1.333333, 3.0000, 2.666667, 7.0000, 12.0000, 9.0000,
|
||||
8.0000, 13.0000, 9.333333, 12.44444, 19.666667, 13.777778,
|
||||
23.666667, 37.0000, 25.666667, 19.111111, 29.666667,
|
||||
20.444444});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,52 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/reduce_mean.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void test_reducemean(const Shape &shape, const vector<float> &data,
|
||||
const optional<const vector<int>> &axis, bool keepDims,
|
||||
const vector<float> &ExpectData) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor(shape, DataType::Float32);
|
||||
auto op = g->addOp<ReduceMeanObj>(i, nullptr, axis, keepDims);
|
||||
|
||||
g->dataMalloc();
|
||||
i->copyin(data);
|
||||
|
||||
// Execute
|
||||
runtime->run(g);
|
||||
|
||||
auto o = op->getOutput();
|
||||
|
||||
// check results
|
||||
EXPECT_TRUE(o->equalData(ExpectData));
|
||||
}
|
||||
|
||||
TEST(MKL_ReduceMean, run) {
|
||||
test_reducemean(Shape{3, 2, 2},
|
||||
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
|
||||
std::nullopt, true, vector<float>{18.25});
|
||||
test_reducemean(Shape{1, 3, 2, 2, 1},
|
||||
vector<float>{5, 1, 20, 2, 30, 1, 40, 2, 55, 1, 60, 2},
|
||||
std::nullopt, false, vector<float>{18.25});
|
||||
|
||||
test_reducemean(Shape{2, 3, 2, 2},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, false, vector<float>{5, 6, 17, 18});
|
||||
test_reducemean(Shape{2, 3, 2, 2, 1},
|
||||
vector<float>{0, 1, 2, 3, 4, 5, 6, 7,
|
||||
8, 9, 10, 11, 12, 13, 14, 15,
|
||||
16, 17, 18, 19, 20, 21, 22, 23},
|
||||
vector<int>{1, 2}, true, vector<float>{5, 6, 17, 18});
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,57 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/reshape.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Reshape, Mkl) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<ReshapeObj>(input, nullptr, Shape{3, 2, 4, 3});
|
||||
g->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
|
||||
runtime->run(g);
|
||||
|
||||
auto o = g->cloneTensor(op->getOutput(0));
|
||||
// check results
|
||||
EXPECT_TRUE(o->equalData(input));
|
||||
}
|
||||
|
||||
TEST(Flatten, Mkl) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<FlattenObj>(input, nullptr, 2);
|
||||
g->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
|
||||
runtime->run(g);
|
||||
|
||||
auto o = g->cloneTensor(op->getOutput(0));
|
||||
// check results
|
||||
EXPECT_TRUE(o->equalData(input));
|
||||
}
|
||||
|
||||
TEST(Identify, Mkl) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<IdentityObj>(input, nullptr);
|
||||
g->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
|
||||
runtime->run(g);
|
||||
|
||||
auto o = g->cloneTensor(op->getOutput(0));
|
||||
// check results
|
||||
EXPECT_TRUE(o->equalData(input));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,30 @@
|
|||
#include "cmath"
|
||||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/resize.h"
|
||||
#include "test.h"
|
||||
namespace infini {
|
||||
TEST(Resize, Mkl_downsample_sizes_nearest) {
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
Graph gCpu = make_ref<GraphObj>(cpuRuntime);
|
||||
|
||||
auto input = gCpu->addTensor({1, 1, 2, 4}, DataType::Float32);
|
||||
auto sizes = gCpu->addTensor({4}, DataType::UInt32);
|
||||
gCpu->dataMalloc();
|
||||
input->copyin(vector<float>{1, 2, 3, 4, 5, 6, 7, 8});
|
||||
sizes->copyin(vector<uint32_t>{1, 1, 1, 3});
|
||||
|
||||
auto runtime = make_ref<MklRuntimeObj>();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto op = g->addOp<ResizeObj>(g->cloneTensor(input), nullptr, std::nullopt,
|
||||
g->cloneTensor(sizes), nullptr, nullptr,
|
||||
ResizeObj::EKeepAspectRatioPolicy::stretch,
|
||||
ResizeObj::ENearestMode::ceil);
|
||||
g->dataMalloc();
|
||||
runtime->run(g);
|
||||
|
||||
EXPECT_TRUE(op->getOutput(0)->equalData(vector<float>{5, 7, 8}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,26 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/slice.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(MKL_Slice, run) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
// Build input data
|
||||
Tensor i = g->addTensor(Shape{3, 2, 1, 5}, DataType::Float32);
|
||||
auto op =
|
||||
g->addOp<SliceObj>(i, nullptr, vector<int>{1, 1}, vector<int>{1, 4},
|
||||
vector<int>{0, 3}, std::nullopt);
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
|
||||
// Execute
|
||||
runtime->run(g);
|
||||
|
||||
auto o = op->getOutput();
|
||||
EXPECT_TRUE(o->equalData(vector<float>{11, 12, 13, 14, 16, 17, 18, 19}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,83 @@
|
|||
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/softmax.h"
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
TEST(MklSoftmax, run) {
|
||||
// Runtime
|
||||
auto runtime = make_ref<MklRuntimeObj>();
|
||||
|
||||
// Build input data on intelcpu
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor(Shape{2, 4}, DataType::Float32);
|
||||
auto op = g->addOp<SoftmaxObj>(i, nullptr, 1);
|
||||
g->dataMalloc();
|
||||
i->copyin(vector<float>{0, 1, 2, 3, 10000, 10001, 10002, 10003});
|
||||
runtime->run(g);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(op->getOutput(0)->equalData(
|
||||
vector<float>{0.032058604, 0.08714432, 0.23688284, 0.6439143,
|
||||
0.032058604, 0.08714432, 0.23688284, 0.6439143}));
|
||||
}
|
||||
|
||||
TEST(MklSoftmax, run_axis1) {
|
||||
// Runtime
|
||||
auto runtime = make_ref<MklRuntimeObj>();
|
||||
|
||||
// Build input data on intelcpu
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor(Shape{2, 2, 2, 2}, DataType::Float32);
|
||||
auto op = g->addOp<SoftmaxObj>(i, nullptr, 1);
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
runtime->run(g);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(op->getOutput(0)->equalData(vector<float>{
|
||||
0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138,
|
||||
0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862,
|
||||
0.9820138, 0.9820138, 0.9820138, 0.9820138}));
|
||||
}
|
||||
|
||||
TEST(MklSoftmax, run_axis2) {
|
||||
// Runtime
|
||||
auto runtime = make_ref<MklRuntimeObj>();
|
||||
|
||||
// Build input data on intelcpu
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor(Shape{2, 2, 2, 2}, DataType::Float32);
|
||||
auto op = g->addOp<SoftmaxObj>(i, nullptr, 2);
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
runtime->run(g);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(op->getOutput(0)->equalData(vector<float>{
|
||||
0.119203, 0.119203, 0.880797, 0.880797, 0.119203, 0.119203, 0.880797,
|
||||
0.880797, 0.119203, 0.119203, 0.880797, 0.880797, 0.119203, 0.119203,
|
||||
0.880797, 0.880797}));
|
||||
}
|
||||
|
||||
TEST(MklSoftmax, run_axis3) {
|
||||
// Runtime
|
||||
auto runtime = make_ref<MklRuntimeObj>();
|
||||
|
||||
// Build input data on intelcpu
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor(Shape{2, 2, 2, 2}, DataType::Float32);
|
||||
auto op = g->addOp<SoftmaxObj>(i, nullptr, 3);
|
||||
g->dataMalloc();
|
||||
i->setData(IncrementalGenerator());
|
||||
runtime->run(g);
|
||||
|
||||
// Check
|
||||
EXPECT_TRUE(op->getOutput(0)->equalData(vector<float>{
|
||||
0.2689414, 0.7310585, 0.2689414, 0.7310585, 0.2689414, 0.7310585,
|
||||
0.2689414, 0.7310585, 0.2689414, 0.7310585, 0.2689414, 0.7310585,
|
||||
0.2689414, 0.7310585, 0.2689414, 0.7310585}));
|
||||
}
|
||||
} // namespace infini
|
|
@ -0,0 +1,33 @@
|
|||
#include "core/graph.h"
|
||||
#include "core/runtime.h"
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "operators/split.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
TEST(Split, Mkl) {
|
||||
Runtime runtime = MklRuntimeObj::getInstance();
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
|
||||
auto input = g->addTensor({2, 10, 2, 1}, DataType::Float32);
|
||||
auto op = g->addOp<SplitObj>(input, std::nullopt, 1, 3);
|
||||
g->dataMalloc();
|
||||
input->setData(IncrementalGenerator());
|
||||
|
||||
runtime->run(g);
|
||||
|
||||
EXPECT_EQ(op->getOutputs().size(), (size_t)3);
|
||||
auto o0 = g->cloneTensor(op->getOutput(0));
|
||||
auto o1 = g->cloneTensor(op->getOutput(1));
|
||||
auto o2 = g->cloneTensor(op->getOutput(2));
|
||||
EXPECT_TRUE(
|
||||
o0->equalData(vector<float>{0, 1, 2, 3, 4, 5, 20, 21, 22, 23, 24, 25}));
|
||||
EXPECT_TRUE(o1->equalData(
|
||||
vector<float>{6, 7, 8, 9, 10, 11, 26, 27, 28, 29, 30, 31}));
|
||||
EXPECT_TRUE(o2->equalData(vector<float>{12, 13, 14, 15, 16, 17, 18, 19, 32,
|
||||
33, 34, 35, 36, 37, 38, 39}));
|
||||
}
|
||||
|
||||
} // namespace infini
|
|
@ -21,8 +21,26 @@ TEST(Flatten, ShapeInference) {
|
|||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<FlattenObj>(i, nullptr);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{72}));
|
||||
auto op = g->addOp<FlattenObj>(i, nullptr, 1);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{2, 36}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<FlattenObj>(i, nullptr, 0);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{1, 72}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<FlattenObj>(i, nullptr, -1);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{18, 4}));
|
||||
}
|
||||
{
|
||||
Graph g = make_ref<GraphObj>(runtime);
|
||||
Tensor i = g->addTensor({2, 3, 3, 4}, DataType::Float32);
|
||||
auto op = g->addOp<FlattenObj>(i, nullptr, -2);
|
||||
EXPECT_EQ(op->getOutput()->getDims(), (Shape{6, 12}));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,26 @@
|
|||
#!/bin/bash
|
||||
|
||||
. /home/spack/spack/share/spack/setup-env.sh
|
||||
spack load cuda@11.0.2 cudnn@8.0.3.33-11.0 intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0
|
||||
if [ "$#" == 0 ] || [ "$1" == "cuda" ]
|
||||
then
|
||||
echo "Load CUDA environment."
|
||||
spack load cuda@11.0.2 cudnn@8.0.3.33-11.0
|
||||
export CUDAHOSTCXX=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-9.4.0/gcc-9.4.0-st36klijpsnquihiy463hmedsyhoc3g6/bin/gcc
|
||||
elif [ "$1" == "intelcpu" ]
|
||||
then
|
||||
echo "Load INTELCPU environment."
|
||||
spack load intel-oneapi-dnn@2022.1.0 intel-oneapi-mkl@2022.1.0 intel-oneapi-compilers@2022.1.0
|
||||
# The default dnnl library is cpu_dpcpp_gpu_dpcpp which requires libsycl.so, after "spack load", and need to change to gomp explicitly.
|
||||
export LD_LIBRARY_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-dnn-2022.1.0-7rs6ht57zozyxhxx6s2qlrqzmqknhgzx/dnnl/2022.1.0/cpu_gomp/lib/:$LD_LIBRARY_PATH
|
||||
|
||||
|
||||
# flopen mkl libs will fail when used by python.
|
||||
# Refering to "https://groups.google.com/g/kaldi-help/c/m3nyQke0HS0/m/4fj8gkSWAgAJ", it is recommended to use mkl_rt instead,
|
||||
# but mkl_rt do not support dpc++ refered to https://www.intel.com/content/www/us/en/docs/onemkl/developer-guide-linux/2023-0/using-the-single-dynamic-library.html
|
||||
# Preloading the missing libs will work, refered to https://community.intel.com/t5/Intel-oneAPI-Math-Kernel-Library/mkl-fails-to-load/m-p/1155538
|
||||
|
||||
export MKLLIB_PATH=/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-12.1.0/intel-oneapi-mkl-2022.1.0-mf6te62fo6wxlo33jwwwgg5kljoagc6g/mkl/2022.1.0/
|
||||
export LD_PRELOAD=$MKLLIB_PATH/lib/intel64/libmkl_def.so.2:$MKLLIB_PATH/lib/intel64/libmkl_avx2.so.2:$MKLLIB_PATH/lib/intel64/libmkl_core.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_lp64.so:$MKLLIB_PATH/lib/intel64/libmkl_intel_thread.so:/home/spack/spack/opt/spack/linux-ubuntu22.04-broadwell/gcc-11.3.0/intel-oneapi-compilers-2022.1.0-qrq4a63scjip455bpxvl5ipgqbllwecj/compiler/2022.1.0/linux/compiler/lib/intel64_lin/libiomp5.so
|
||||
else
|
||||
echo "Bad option. Please enter 'cuda' or 'intelcpu'. CUDA will be loaded by default if nothing specified."
|
||||
fi
|
||||
|
|
Loading…
Reference in New Issue