forked from jiuyuan/InfiniTensor
support kunlun xpu and add an operator named Add
This commit is contained in:
parent
26f0d13c26
commit
13d6a6144f
|
@ -1,14 +1,15 @@
|
|||
cmake_minimum_required(VERSION 3.17) # FindCUDAToolkit
|
||||
cmake_minimum_required(VERSION 3.12) # FindCUDAToolkit
|
||||
include(CMakeDependentOption)
|
||||
project(InfiniTensor C CXX)
|
||||
|
||||
# Do not change these options in this file. Use cmake.config, cmake -DOPTION=VALUE, or ccmake to specify them.
|
||||
option(USE_CUDA "Support CUDA GPU" OFF)
|
||||
option(USE_BANG "Support BANG MLU" OFF)
|
||||
option(USE_XPU "Support KUNLUN MLU" ON)
|
||||
option(USE_INTELCPU "Support INTELCPU" OFF)
|
||||
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" ON)
|
||||
option(USE_BACKTRACE "Print backtrace on exception and segmentation fault" OFF)
|
||||
option(USE_PROTOBUF "Serialize and deserialize tensors" OFF)
|
||||
option(BUILD_TEST "Build tests" OFF)
|
||||
option(BUILD_TEST "Build tests" ON)
|
||||
|
||||
cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF)
|
||||
cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF)
|
||||
|
@ -110,6 +111,11 @@ if(USE_BANG)
|
|||
list (APPEND SRC ${SRC_BANG})
|
||||
endif()
|
||||
|
||||
if(USE_XPU)
|
||||
file(GLOB_RECURSE SRC_XPU src/xpu/*.cc src/kernels/xpu/*.cc )
|
||||
list (APPEND SRC ${SRC_XPU})
|
||||
endif()
|
||||
|
||||
if(USE_INTELCPU)
|
||||
file(GLOB_RECURSE SRC_INTELCPU src/intelcpu/*.cc src/kernels/intelcpu/*.cc )
|
||||
list (APPEND SRC ${SRC_INTELCPU})
|
||||
|
@ -218,6 +224,37 @@ if(USE_BANG)
|
|||
target_link_libraries(InfiniTensor ${CAMBRICON_CNNL} ${CAMBRICON_CNRT} ${CAMBRICON_CNDRV} stdc++)
|
||||
endif()
|
||||
|
||||
if(USE_XPU)
|
||||
add_compile_definitions(USE_XPU=1)
|
||||
include_directories(src/kernels/xpu/include)
|
||||
if ((NOT DEFINED XPU_HOME) AND (NOT DEFINED ENV{XPU_HOME}))
|
||||
message(FATAL_ERROR "XPU_HOME is not defined from cmake or env")
|
||||
elseif (DEFINED XPU_HOME)
|
||||
set(XPU_HOME ${XPU_HOME} CACHE STRING "XPU_HOME directory for Kunlun development")
|
||||
else()
|
||||
set(XPU_HOME $ENV{XPU_HOME} CACHE STRING "XPU_HOME directory for Kunlun development")
|
||||
endif()
|
||||
message(STATUS "XPU_HOME: ${XPU_HOME}")
|
||||
|
||||
include_directories("${XPU_HOME}/include/")
|
||||
include_directories("${XPU_HOME}/XTDK/include/")
|
||||
find_library(XPU_RT libxpurt.so "${XPU_HOME}/lib64")
|
||||
find_library(XPU_DNN libxpuapi.so "${XPU_HOME}/XTDK/shlib")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lstdc++ -Wall -Werror")
|
||||
|
||||
if ((NOT DEFINED TARGET_CPU_ARCH) AND (NOT DEFINED ENV{TARGET_CPU_ARCH}))
|
||||
execute_process(COMMAND uname -m OUTPUT_VARIABLE _uname_m OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
set(TARGET_CPU_ARCH "${_uname_m}" CACHE STRING "Target CPU ARCH")
|
||||
elseif(DEFINED TARGET_CPU_ARCH)
|
||||
set(TARGET_CPU_ARCH ${TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
|
||||
else()
|
||||
set(TARGET_CPU_ARCH $ENV{TARGET_CPU_ARCH} CACHE STRING "Target CPU ARCH")
|
||||
endif()
|
||||
message(STATUS "TARGET_CPU_ARCH: ${TARGET_CPU_ARCH}")
|
||||
|
||||
target_link_libraries(InfiniTensor ${XPU_RT} ${XPU_DNN} stdc++)
|
||||
endif()
|
||||
|
||||
# # Python bindings
|
||||
# pybind11_add_module(infini MODULE ${FFI})
|
||||
# target_link_libraries(infini PRIVATE infini_cpp)
|
||||
|
@ -247,6 +284,9 @@ if(BUILD_TEST)
|
|||
if (USE_BANG)
|
||||
build_test(test/kernels/bang/*.cc)
|
||||
endif()
|
||||
if (USE_XPU)
|
||||
build_test(test/kernels/xpu/*.cc)
|
||||
endif()
|
||||
if (USE_INTELCPU)
|
||||
build_test(test/kernels/intelcpu/*.cc)
|
||||
endif()
|
||||
|
|
|
@ -29,7 +29,7 @@ using OpLists = list<Operator>;
|
|||
|
||||
using VType = uint32_t;
|
||||
|
||||
enum class Device { CPU = 1, CUDA, BANG, INTELCPU };
|
||||
enum class Device { CPU = 1, CUDA, BANG, INTELCPU, XPU };
|
||||
/***************** Forward declaration end *****************/
|
||||
|
||||
class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
||||
|
@ -69,6 +69,7 @@ class RuntimeObj : public std::enable_shared_from_this<RuntimeObj> {
|
|||
}
|
||||
bool isCuda() const { return device == Device::CUDA; }
|
||||
bool isBang() const { return device == Device::BANG; }
|
||||
bool isXPU() const { return device == Device::XPU; }
|
||||
void copyBlob(const TensorObj *dst, const TensorObj *src) const;
|
||||
// TODO: unify these copy APIs
|
||||
virtual void copyBlobFromCPU(void *dst, const void *src,
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
#pragma once
|
||||
namespace infini {
|
||||
namespace opTimer {
|
||||
double getPerfConvXdnn(int n, int c, int h, int w, int f, int r, int s,
|
||||
int padh, int padw, int strideh, int stridew,
|
||||
int dilationh, int dilationw, int group,
|
||||
const char *name);
|
||||
double getPerfMatmulXdnn(int b, int m, int n, int k, const char *name);
|
||||
} // namespace opTimer
|
||||
} // namespace infini
|
|
@ -0,0 +1,20 @@
|
|||
#pragma once
|
||||
#include "xpu/runtime_ex.h"
|
||||
#include "xpu/xdnn.h"
|
||||
#include "core/common.h"
|
||||
|
||||
#define checkXPUError(call) \
|
||||
{ \
|
||||
auto err = call; \
|
||||
if (XPU_SUCCESS != err) { \
|
||||
fprintf(stderr, "XPU error in %s:%i : %s.\n", __FILE__, __LINE__, \
|
||||
xpu_strerror(err)); \
|
||||
exit(EXIT_FAILURE); \
|
||||
} \
|
||||
}
|
||||
|
||||
namespace infini {
|
||||
|
||||
using XPUPtr = void *;
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,24 @@
|
|||
#pragma once
|
||||
#include "xpu/xpu_runtime.h"
|
||||
#include "core/kernel.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class XPUKernelWithoutConfig : public Kernel {
|
||||
public:
|
||||
virtual void compute(const Operator &op, const PerfRecord &record,
|
||||
const RuntimeObj *context) const {
|
||||
compute(op, context);
|
||||
}
|
||||
virtual void compute(const Operator &op,
|
||||
const RuntimeObj *context) const = 0;
|
||||
// Premise: op is idempotent since it is called multiple times.
|
||||
virtual PerfRecord tune(const Operator &op,
|
||||
const RuntimeObj *_context) const {
|
||||
auto context = dynamic_cast<const XPURuntimeObj *>(_context);
|
||||
return make_ref<PerfRecordObj>(timeit([&]() { compute(op, _context); },
|
||||
[&]() { context->sync(); }));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,63 @@
|
|||
#pragma once
|
||||
#include "xpu/xpu_common.h"
|
||||
#include "core/runtime.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
class XPURuntimeObj : public RuntimeObj {
|
||||
private:
|
||||
baidu::xpu::api::Context* xdnn;
|
||||
XPUPtr workspace;
|
||||
size_t workspaceSize;
|
||||
|
||||
public:
|
||||
XPURuntimeObj() : RuntimeObj(Device::XPU) {
|
||||
xdnn = baidu::xpu::api::create_context();
|
||||
// 10GB for Longformer
|
||||
// size_t longformerNum = 3lu * (1 << 30);
|
||||
workspaceSize = 7ll << 30; // 7 GB
|
||||
workspace = alloc(workspaceSize);
|
||||
}
|
||||
virtual ~XPURuntimeObj() {
|
||||
dealloc(workspace);
|
||||
baidu::xpu::api::destroy_context(xdnn);
|
||||
}
|
||||
string toString() const override;
|
||||
|
||||
void run(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const;
|
||||
// double runEvaluation(const Graph &graph, int nWarmups,
|
||||
// int nEvaluations) const;
|
||||
void sync() const;
|
||||
XPUPtr alloc(size_t size) override {
|
||||
void *ptr;
|
||||
xpu_malloc(&ptr, size);
|
||||
return ptr;
|
||||
}
|
||||
void dealloc(void *ptr) override { xpu_free(ptr); }
|
||||
baidu::xpu::api::Context* XPUHandle() const { return xdnn; }
|
||||
XPUPtr getWorkspace(size_t size) const {
|
||||
IT_ASSERT(size <= workspaceSize);
|
||||
return workspace;
|
||||
}
|
||||
|
||||
void copyBlobFromCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
xpu_memcpy(dst, const_cast<void *>(src), bytes, XPUMemcpyKind::XPU_HOST_TO_DEVICE);
|
||||
}
|
||||
|
||||
void copyBlobToCPU(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
xpu_memcpy(dst, const_cast<void *>(src), bytes, XPUMemcpyKind::XPU_DEVICE_TO_HOST);
|
||||
}
|
||||
|
||||
void copyBlobInsideRuntime(void *dst, const void *src,
|
||||
size_t bytes) const override {
|
||||
xpu_memcpy(dst, const_cast<void *>(src), bytes, XPUMemcpyKind::XPU_DEVICE_TO_DEVICE);
|
||||
}
|
||||
|
||||
private:
|
||||
void runWithoutSync(const Graph &graph, bool tune, bool profiling) const;
|
||||
};
|
||||
|
||||
} // namespace infini
|
|
@ -21,6 +21,9 @@
|
|||
#ifdef USE_BANG
|
||||
#include "bang/bang_runtime.h"
|
||||
#endif
|
||||
#ifdef USE_XPU
|
||||
#include "xpu/xpu_runtime.h"
|
||||
#endif
|
||||
#ifdef USE_INTELCPU
|
||||
#include "intelcpu/mkl_runtime.h"
|
||||
#include "intelcpu/operator_timer.h"
|
||||
|
@ -127,6 +130,10 @@ static Ref<CudaRuntimeObj> cuda_runtime() { return make_ref<CudaRuntimeObj>(); }
|
|||
static Ref<BangRuntimeObj> bang_runtime() { return make_ref<BangRuntimeObj>(); }
|
||||
#endif
|
||||
|
||||
#ifdef USE_XPU
|
||||
static Ref<XPURuntimeObj> xpu_runtime() { return make_ref<XPURuntimeObj>(); }
|
||||
#endif
|
||||
|
||||
#ifdef USE_INTELCPU
|
||||
static Ref<RuntimeObj> intelcpu_runtime() { return make_ref<MklRuntimeObj>(); }
|
||||
#endif
|
||||
|
@ -239,6 +246,10 @@ void export_functions(py::module &m) {
|
|||
#ifdef USE_BANG
|
||||
.FUNCTION(bang_runtime)
|
||||
#endif
|
||||
|
||||
#ifdef USE_XPU
|
||||
.FUNCTION(xpu_runtime)
|
||||
#endif
|
||||
.FUNCTION(conv_attrs_of)
|
||||
.FUNCTION(conv_trans_attrs_of)
|
||||
.FUNCTION(matmul_attrs_of)
|
||||
|
@ -269,6 +280,10 @@ void init_graph_builder(py::module &m) {
|
|||
#ifdef USE_BANG
|
||||
py::class_<BangRuntimeObj, std::shared_ptr<BangRuntimeObj>, RuntimeObj>(
|
||||
m, "BangRuntime");
|
||||
#endif
|
||||
#ifdef USE_XPU
|
||||
py::class_<XPURuntimeObj, std::shared_ptr<XPURuntimeObj>, RuntimeObj>(
|
||||
m, "XPURuntime");
|
||||
#endif
|
||||
py::class_<TensorObj, std::shared_ptr<TensorObj>>(m, "Tensor")
|
||||
.def("fuid", &TensorObj::getFuid, policy::automatic)
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
#include "operators/element_wise.h"
|
||||
#include "xpu/xpu_kernel_without_config.h"
|
||||
#include "xpu/xpu_runtime.h"
|
||||
|
||||
namespace infini {
|
||||
class AddXdnn : public XPUKernelWithoutConfig {
|
||||
void compute(const Operator &_op,
|
||||
const RuntimeObj *_context) const override {
|
||||
auto op = as<ElementWiseObj>(_op);
|
||||
auto context = dynamic_cast<const XPURuntimeObj *>(_context);
|
||||
|
||||
void *const aData = (op->getInputs(0)->getRawDataPtr<void *>());
|
||||
void *const bData = (op->getInputs(1)->getRawDataPtr<void *>());
|
||||
void *const cData = (op->getOutput()->getRawDataPtr<void *>());
|
||||
|
||||
auto aDim = op->getInputs(0)->getDims();
|
||||
auto bDim = op->getInputs(1)->getDims();
|
||||
if (aDim.size() != 4 || bDim.size() != 4)
|
||||
IT_TODO_HALT();
|
||||
auto ret = baidu::xpu::api::broadcast_add<float>(context->XPUHandle(), (float*)aData, (float*)bData, (float*)cData, aDim, bDim);
|
||||
assert(ret == 0);
|
||||
return;
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL(Device::XPU, OpType::Add, DataType::Float32, AddXdnn,
|
||||
"Add_xdnn_XPU_Float32");
|
||||
}; // namespace infini
|
|
@ -0,0 +1,71 @@
|
|||
#include "xpu/operator_timer.h"
|
||||
#include "xpu/xpu_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/conv.h"
|
||||
#include "operators/matmul.h"
|
||||
#include "utils/data_generator.h"
|
||||
|
||||
namespace infini {
|
||||
namespace opTimer {
|
||||
|
||||
double getPerfConvXPU(int n, int c, int h, int w, int f, int r, int s,
|
||||
int padh, int padw, int strideh, int stridew,
|
||||
int dilationh, int dilationw, int group,
|
||||
const char *name) {
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime xpu = make_ref<XPURuntimeObj>();
|
||||
Graph gXpu = make_ref<GraphObj>(xpu);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
IT_ASSERT(c % group == 0);
|
||||
Tensor i0Cpu = gCpu->addTensor({n, h, w, c}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({f, r, s, c / group}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to Xpu
|
||||
Tensor i0XPU = gXpu->cloneTensor(i0Cpu);
|
||||
Tensor w0XPU = gXpu->cloneTensor(w0Cpu);
|
||||
// Build Xpu graph
|
||||
auto conv = gXpu->addOp<ConvObj>(i0XPU, w0XPU, nullptr, padh, padw,
|
||||
strideh, stridew, dilationh, dilationw);
|
||||
// allocate Xpu memory
|
||||
gXpu->dataMalloc();
|
||||
// Execute on Xpu
|
||||
bool tune = true;
|
||||
xpu->run(gXpu, tune);
|
||||
return xpu->getPerfTime(gXpu);
|
||||
}
|
||||
|
||||
double getPerfMatmulXPU(int b, int m, int n, int k, const char *name) {
|
||||
Runtime cpu = NativeCpuRuntimeObj::getInstance(); // CPUruntime is singleton
|
||||
Graph gCpu = make_ref<GraphObj>(cpu);
|
||||
Runtime xpu = make_ref<XPURuntimeObj>();
|
||||
Graph gXpu = make_ref<GraphObj>(xpu);
|
||||
// Set input data on CPU in a CPU Graph
|
||||
Tensor i0Cpu = gCpu->addTensor({b, m, k}, DataType::Float32);
|
||||
Tensor w0Cpu = gCpu->addTensor({b, k, n}, DataType::Float32);
|
||||
// Malloc data for all tensors in a graph. Do we need implicit allocation?
|
||||
gCpu->dataMalloc();
|
||||
i0Cpu->setData(IncrementalGenerator());
|
||||
w0Cpu->setData(IncrementalGenerator());
|
||||
|
||||
// Copy input tensors from CPU to Xpu
|
||||
Tensor i0XPU = gXpu->cloneTensor(i0Cpu);
|
||||
Tensor w0XPU = gXpu->cloneTensor(w0Cpu);
|
||||
// Build Xpu graph
|
||||
auto conv = gXpu->addOp<MatmulObj>(i0XPU, w0XPU, nullptr);
|
||||
// allocate Xpu memory
|
||||
gXpu->dataMalloc();
|
||||
// Execute on Xpu
|
||||
bool tune = true;
|
||||
xpu->run(gXpu, tune);
|
||||
return xpu->getPerfTime(gXpu);
|
||||
}
|
||||
|
||||
} // namespace opTimer
|
||||
} // namespace infini
|
|
@ -0,0 +1,59 @@
|
|||
#include "xpu/xpu_runtime.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/perf_engine.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
void XPURuntimeObj::runWithoutSync(const Graph &graph, bool tune = false,
|
||||
bool profiling = false) const {
|
||||
const auto &kernelRegistry = KernelRegistry::getInstance();
|
||||
auto &perfEngine = PerfEngine::getInstance();
|
||||
double totalTime = 0;
|
||||
std::map<OpType, double> opTime;
|
||||
std::map<OpType, int> opCnt;
|
||||
for (auto &op : graph->getOperators()) {
|
||||
// HACK: set correct data type
|
||||
auto kernelAttrs =
|
||||
KernelAttrs{device, op->getOpType(), DataType::Float32};
|
||||
Kernel *kernel = kernelRegistry.getKernel(kernelAttrs);
|
||||
auto perfKey = PerfEngine::Key{kernelAttrs, op->getOpPerfKey()};
|
||||
auto perfData = perfEngine.getPerfData(perfKey);
|
||||
if (!perfData && !tune) {
|
||||
kernel->compute(op, this);
|
||||
continue;
|
||||
}
|
||||
|
||||
PerfRecord record;
|
||||
if (!perfData) {
|
||||
record = kernel->tune(op, this);
|
||||
perfEngine.setPerfData(perfKey, record);
|
||||
} else
|
||||
record = perfData;
|
||||
|
||||
double t = record->time;
|
||||
totalTime += t;
|
||||
|
||||
if (profiling) {
|
||||
double t = timeit([&]() { kernel->compute(op, record, this); },
|
||||
[&]() { sync(); }, 1, 1);
|
||||
op->print();
|
||||
printf(" op_time on xpu %lf\n", t);
|
||||
totalTime += t;
|
||||
opTime[op->getOpType()] += t;
|
||||
opCnt[op->getOpType()]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void XPURuntimeObj::run(const Graph &graph, bool tune, bool profiling) const {
|
||||
if (profiling)
|
||||
IT_TODO_HALT();
|
||||
runWithoutSync(graph, tune, profiling);
|
||||
sync();
|
||||
}
|
||||
|
||||
void XPURuntimeObj::sync() const { ; }
|
||||
|
||||
string XPURuntimeObj::toString() const { return "XPU Runtime"; }
|
||||
|
||||
} // namespace infini
|
|
@ -0,0 +1,54 @@
|
|||
#include "xpu/xpu_runtime.h"
|
||||
#include "core/graph.h"
|
||||
#include "core/kernel.h"
|
||||
#include "core/runtime.h"
|
||||
#include "operators/element_wise.h"
|
||||
|
||||
#include "test.h"
|
||||
|
||||
namespace infini {
|
||||
|
||||
template <class T>
|
||||
void testAdd(
|
||||
const std::function<void(void *, size_t, DataType)> &generator,
|
||||
const Shape &shape) {
|
||||
// Runtime
|
||||
Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance();
|
||||
auto xpuRuntime = make_ref<XPURuntimeObj>();
|
||||
|
||||
// Build input data on CPU
|
||||
Tensor inputCpu1 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu1->dataMalloc();
|
||||
inputCpu1->setData(generator);
|
||||
Tensor inputCpu2 =
|
||||
make_ref<TensorObj>(shape, DataType::Float32, cpuRuntime);
|
||||
inputCpu2->dataMalloc();
|
||||
inputCpu2->setData(generator);
|
||||
|
||||
// GPU
|
||||
Graph xpuGraph = make_ref<GraphObj>(xpuRuntime);
|
||||
auto inputGpu1 = xpuGraph->cloneTensor(inputCpu1);
|
||||
auto inputGpu2 = xpuGraph->cloneTensor(inputCpu2);
|
||||
auto gpuOp = xpuGraph->addOp<T>(inputGpu1, inputGpu2, nullptr);
|
||||
xpuGraph->dataMalloc();
|
||||
xpuRuntime->run(xpuGraph);
|
||||
auto outputGpu = gpuOp->getOutput();
|
||||
auto outputGpu2Cpu = outputGpu->clone(cpuRuntime);
|
||||
// CPU
|
||||
Graph cpuGraph = make_ref<GraphObj>(cpuRuntime);
|
||||
auto cpuOp = cpuGraph->addOp<T>(inputCpu1, inputCpu2, nullptr);
|
||||
cpuGraph->dataMalloc();
|
||||
cpuRuntime->run(cpuGraph);
|
||||
auto outputCpu = cpuOp->getOutput();
|
||||
// Check
|
||||
//outputCpu->printData();
|
||||
//outputGpu2Cpu->printData();
|
||||
EXPECT_TRUE(outputCpu->equalData(outputGpu2Cpu));
|
||||
}
|
||||
|
||||
TEST(xpu_add, run) {
|
||||
testAdd<AddObj>(IncrementalGenerator(), Shape{10, 256, 256, 3});
|
||||
}
|
||||
|
||||
} // namespace infini
|
Loading…
Reference in New Issue