From feccd4f318d58d019508c60026de6f1fbee91bda Mon Sep 17 00:00:00 2001 From: constroy Li Date: Mon, 30 Oct 2023 15:04:16 +0800 Subject: [PATCH 01/28] fix tensor parallel for llama (#159) * fix Slice * change default rounds of timeit to 10 to reduce time * fix slice with large ends * Reshape support Int64 * support position_ids as input * skip last MatMul in Llama * skip infer_shapes to parse large model * update launch.py * fix split_concat_kernel * print more message in launch.py * Reshape supports both Int32 and Int64 * try infer_shapes and warn about failure * fix format --------- Co-authored-by: whjthu --- examples/distributed/launch.py | 43 +++++++++++++---------- examples/distributed/parallel_opt.py | 10 ++++-- include/core/common.h | 4 +-- pyinfinitensor/src/pyinfinitensor/onnx.py | 19 +++++++--- src/kernels/cuda/matmul.cc | 1 + src/kernels/cuda/reshape.cc | 4 +++ src/kernels/cuda/split_concat.cu | 8 +++-- src/operators/concat.cc | 6 ++-- src/operators/slice.cc | 15 +++++--- 9 files changed, 70 insertions(+), 40 deletions(-) diff --git a/examples/distributed/launch.py b/examples/distributed/launch.py index 64930e6e..58f7efb3 100644 --- a/examples/distributed/launch.py +++ b/examples/distributed/launch.py @@ -5,6 +5,7 @@ import multiprocessing as mp from pyinfinitensor.onnx import OnnxStub, backend import onnx from onnx.external_data_helper import convert_model_to_external_data +from onnx.shape_inference import infer_shapes_path import numpy as np from parallel_opt import parallel_model @@ -44,16 +45,18 @@ def parse_args(): ) -def run_model(model, runtime, inputs: np.array, n=20): +def run_model(model, runtime, inputs, n=10): stub = OnnxStub(model, runtime) - next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs) - stub.tune() + for tensor, input in zip(stub.inputs.values(), inputs): + tensor.copyin_numpy(input) + # stub.tune() stub.run() # get outputs - outputs = np.array(next(stub.outputs.items().__iter__())[1].copyout_float()) + outputs = next(stub.outputs.values().__iter__()).copyout_numpy() # bench - next(stub.inputs.items().__iter__())[1].copyin_numpy(inputs) + for tensor, input in zip(stub.inputs.values(), inputs): + tensor.copyin_numpy(input) begin = time.time() for _ in range(n): stub.run() @@ -64,13 +67,12 @@ def run_model(model, runtime, inputs: np.array, n=20): def run_and_compare(name, model, runtime): - data = np.load(f"{name}_inputs.npy") + input_ids = np.load(f"{name}_inputs.npy") + position_ids = np.arange(input_ids.shape[-1]) results = np.load(f"{name}_results.npy") - outputs = run_model(model, runtime, data) - print("outputs sum:", outputs.sum()) - print("max abs diff:", abs(outputs - results).max()) - print("max rel diff:", abs((outputs - results) / results).max()) - # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6) + outputs = run_model(model, runtime, (input_ids, position_ids)) + print("outputs abs mean:", abs(outputs).mean()) + np.testing.assert_allclose(outputs, results, rtol=1e-6, atol=1e-3) def start_worker( @@ -81,14 +83,13 @@ def start_worker( extern_path = f"./{dist_name}_rank{rank}.pb" if os.path.exists(extern_path): os.remove(extern_path) - convert_model_to_external_data( + onnx.save_model( model, - all_tensors_to_one_file=True, + f"./{dist_name}_rank{rank}.onnx", + save_as_external_data=True, location=extern_path, - size_threshold=1024, - convert_attribute=False, ) - onnx.save(model, f"./{dist_name}_rank{rank}.onnx") + infer_shapes_path(f"./{dist_name}_rank{rank}.onnx") runtime = backend.CudaRuntime(local_rank) # print("init comm") 
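    # join the collective communicator shared by all ranks (dist_name, world_size, rank)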
runtime.init_comm( @@ -106,10 +107,12 @@ def start_single(name, model): def gen_standard(name, model, voc_size, bs, len): # generate standard results - data = np.random.randint(0, voc_size, (bs, len), dtype=np.int32) - np.save(f"{name}_inputs", data) + input_ids = np.random.randint(0, voc_size, (bs, len)) + position_ids = np.arange(len) + np.save(f"{name}_inputs", input_ids) runtime = backend.CudaRuntime(0) - outputs = run_model(model, runtime, data, 1) + outputs = run_model(model, runtime, (input_ids, position_ids), 1) + print("outputs abs mean:", abs(outputs).mean()) np.save(f"{name}_results", outputs) @@ -128,12 +131,14 @@ def main(): # run single process. # use standalone process to isolate cuda. + print("run model by single GPU.") p = mp.Process(target=start_single, args=(name, model)) p.start() p.join() # run distributed parallel. world_size = nnodes * nproc_per_node + print(f"run model by {world_size} GPU in parallel.") workers = [ mp.Process( target=start_worker, diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index 42465a69..3ddf2ead 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -11,6 +11,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): vinfo = {info.name: info for info in model.graph.value_info} vinfo.update({info.name: info for info in model.graph.input}) vinfo.update({info.name: info for info in model.graph.output}) + output = {info.name: info for info in model.graph.output} place: Dict[str, Placement] = {} nodes: List[NodeProto] = [] @@ -56,7 +57,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): ndim = len(vinfo[output].type.tensor_type.shape.dim) out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial() place[node.output[0]] = out_plc - + def shard_concat(node: NodeProto): # hack for kvcache in_plc = place[node.input[1]] @@ -154,7 +155,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): ), f"{place[node.input[0]]} != {place[node.input[1]]}" place[node.output[0]] = place[node.input[0]] elif node.op_type == "Concat": - shard_concat(node) + shard_concat(node) def find_successor(op_type: str, idx: int, search_limit: int = 1): for node in model.graph.node[idx + 1 : idx + 1 + search_limit]: @@ -175,6 +176,9 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): if (node.op_type == "MatMul" or node.op_type == "Gemm") and any( input in data for input in node.input ): + # FIXME(constroy): the last MatMul should not be sharded as TP. 
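+            # In Llama this last MatMul is the lm_head / vocab projection; its output
+            # is a graph output, so it is left replicated (skipped just below).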
+ if node.output[0] in output: + continue groups = 1 # If the Gemm or Matmul is followed by a split, then the inputs are concatinated by groups split_node = find_successor("Split", index, search_limit=2) @@ -218,7 +222,7 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): new_input = [] for info in model.graph.input: new_input.append(vinfo[info.name]) - + graph = helper.make_graph( nodes, model.graph.name + f"_{tp_rank}", diff --git a/include/core/common.h b/include/core/common.h index 749caff2..81e704f8 100644 --- a/include/core/common.h +++ b/include/core/common.h @@ -75,7 +75,7 @@ template std::string vecToString(const std::vector &vec) { double timeit( const std::function &func, - const std::function &sync = []() {}, int warmupRounds = 200, - int timingRounds = 200); + const std::function &sync = []() {}, int warmupRounds = 10, + int timingRounds = 10); } // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index d11fbb90..6d0da9f8 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -28,6 +28,7 @@ from typing import Dict, List, Any, Tuple, Sequence, Union, Optional from functools import reduce from onnxsim import simplify import copy +import warnings class OnnxStub: @@ -48,7 +49,10 @@ class OnnxStub: self.inputs: Dict[str, backend.Tensor] = {} self.outputs: Dict[str, backend.Tensor] = {} self.initializer: Dict[int, TensorProto] = {} - model = infer_shapes(model) + try: + model = infer_shapes(model) + except: + warnings.warn("infer_shapes failed.") self.handler = backend.GraphHandler(runtime) tensors: Dict[str, backend.Tensor] = dict() @@ -603,15 +607,20 @@ class OnnxStub: != 0, ) elif node.op_type == "Slice": + + def clamp(nums): + MAX_INT = 0x7FFFFFFF + return [min(x, MAX_INT) for x in nums] + tensors[node.output[0]] = self.handler.slice( tensors[node.input[0]], tensors.get(node.output[0]), - _parse_data(data[node.input[1]]), - _parse_data(data[node.input[2]]), - _parse_data(data[node.input[3]]) + clamp(_parse_data(data[node.input[1]])), + clamp(_parse_data(data[node.input[2]])), + clamp(_parse_data(data[node.input[3]])) if len(node.input) > 3 else None, - _parse_data(data[node.input[4]]) + clamp(_parse_data(data[node.input[4]])) if len(node.input) > 4 else None, ) diff --git a/src/kernels/cuda/matmul.cc b/src/kernels/cuda/matmul.cc index 9cd4b0b3..2d457cbc 100644 --- a/src/kernels/cuda/matmul.cc +++ b/src/kernels/cuda/matmul.cc @@ -58,6 +58,7 @@ class matmulCublas : public Kernel { SmallArray inputShape, outputShape; int nDims = out->getRank(); IT_ASSERT(nDims <= SMALL_ARRAY_SIZE); + // FIXME(constroy): use size_t for outputsize. int outputsize = 1; // the length of the output vector after flatten int offset = nDims - inC->getRank(); for (int i = 0; i < offset; ++i) diff --git a/src/kernels/cuda/reshape.cc b/src/kernels/cuda/reshape.cc index 77070c23..7be6aca8 100644 --- a/src/kernels/cuda/reshape.cc +++ b/src/kernels/cuda/reshape.cc @@ -13,6 +13,10 @@ class CopyCuda : public CudaKernelWithoutConfig { // reshape/flatten/identity all act as copying from input to output. 
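// Integer tensors (e.g. position_ids in Llama) also go through Reshape, so the
// copy kernel is registered for Int64 and Int32 below in addition to Float32.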
REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Float32, CopyCuda, "Reshape_CUDA_Float32"); +REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int64, CopyCuda, + "Reshape_CUDA_Int64"); +REGISTER_KERNEL(Device::CUDA, OpType::Reshape, DataType::Int32, CopyCuda, + "Reshape_CUDA_Int32"); REGISTER_KERNEL(Device::CUDA, OpType::Flatten, DataType::Float32, CopyCuda, "Flatten_CUDA_Float32"); REGISTER_KERNEL(Device::CUDA, OpType::Identity, DataType::Float32, CopyCuda, diff --git a/src/kernels/cuda/split_concat.cu b/src/kernels/cuda/split_concat.cu index 73f29482..193501e0 100644 --- a/src/kernels/cuda/split_concat.cu +++ b/src/kernels/cuda/split_concat.cu @@ -51,13 +51,15 @@ __global__ void _split_concat_kernel(ElementTensorMetadata elemMeta, namespace infini { -// TODO: when dim=0, the operation can be executed in-place +// TODO: when dim=0, the operation can be executed in-place void split_concat_kernel(const ElementTensorMetadata &eleMeta, const ComposedTensorMetadata &compMeta, int dim, int batchSize, int nDims, bool isSplit) { dim3 blockSize = dim3(32 * 16); - // gridsize =n_elements / blockSize - int gridDimX = (eleMeta.nElements[0] - 1) / (32 * 16) + 1; + // gridsize = max_n_elements / blockSize + int max_n_elements = + *std::max_element(eleMeta.nElements, eleMeta.nElements + batchSize); + int gridDimX = (max_n_elements - 1) / (32 * 16) + 1; // each y is a split among the batch dim3 gridSize(gridDimX, batchSize); diff --git a/src/operators/concat.cc b/src/operators/concat.cc index de836d58..95535233 100644 --- a/src/operators/concat.cc +++ b/src/operators/concat.cc @@ -2,10 +2,10 @@ #include "utils/operator_utils.h" namespace infini { -ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int dim) - : OperatorObj(OpType::Concat, inputs, {output}), dim(dim) { +ConcatObj::ConcatObj(GraphObj *graph, TensorVec inputs, Tensor output, int _dim) + : OperatorObj(OpType::Concat, inputs, {output}) { int rank = inputs[0]->getRank(); - dim = get_real_axis(dim, rank); + dim = get_real_axis(_dim, rank); IT_ASSERT(checkValid(graph)); } diff --git a/src/operators/slice.cc b/src/operators/slice.cc index 1ded2745..0db3b1a2 100644 --- a/src/operators/slice.cc +++ b/src/operators/slice.cc @@ -43,17 +43,22 @@ SliceObj::SliceObj(GraphObj *graph, Tensor input, Tensor output, auto size = shape.size(); this->axes.reserve(size); - for (size_t i = 0; i < size; ++i) + for (size_t i = 0; i < size; ++i) { + auto len = shape[i]; if (auto _i = axes.find(i); _i != axes.end()) { auto __i = _i->second; auto start = starts[__i]; auto end = ends[__i]; - this->axes.push_back({start >= 0 ? start : start + shape[__i], - end >= 0 ? end : end + shape[__i], - steps[__i]}); + if (start > len) + start = len; + if (end > len) + end = len; + this->axes.push_back({start >= 0 ? start : start + len, + end >= 0 ? 
end : end + len, steps[__i]}); } else { - this->axes.push_back({0, shape[i], 1}); + this->axes.push_back({0, len, 1}); } + } IT_ASSERT(checkValid(graph)); } From 23b825efc48879974819eb5d6d2edd5618b55845 Mon Sep 17 00:00:00 2001 From: Bolun Zhang <48948016+Chamberlain0w0@users.noreply.github.com> Date: Mon, 30 Oct 2023 16:01:05 +0800 Subject: [PATCH 02/28] Xpu task4 support: add softmax (#172) * add softmax on kunlun * format --------- Co-authored-by: Bolun Co-authored-by: Haojie Wang --- src/kernels/kunlun/softmax.cc | 26 ++++ test/kernels/kunlun/test_kunlun_softmax.cc | 136 +++++++++++++++++++++ 2 files changed, 162 insertions(+) create mode 100644 src/kernels/kunlun/softmax.cc create mode 100644 test/kernels/kunlun/test_kunlun_softmax.cc diff --git a/src/kernels/kunlun/softmax.cc b/src/kernels/kunlun/softmax.cc new file mode 100644 index 00000000..56374766 --- /dev/null +++ b/src/kernels/kunlun/softmax.cc @@ -0,0 +1,26 @@ +#include "operators/softmax.h" +#include "kunlun/kunlun_kernel_without_config.h" +#include "kunlun/kunlun_runtime.h" + +namespace infini { +class SoftmaxXdnn : public KUNLUNKernelWithoutConfig { + void compute(const Operator &_op, + const RuntimeObj *_context) const override { + auto op = as(_op); + auto context = dynamic_cast(_context); + auto dim = op->getInputs(0)->getDims(); + auto axis = op->getAxis(); + + void *const aData = (op->getInputs(0)->getRawDataPtr()); + void *const cData = (op->getOutput()->getRawDataPtr()); + + auto ret = baidu::xpu::api::softmax( + context->KUNLUNHandle(), (float *)aData, (float *)cData, dim, axis); + assert(ret == 0); + return; + } +}; + +REGISTER_KERNEL(Device::KUNLUN, OpType::Softmax, DataType::Float32, SoftmaxXdnn, + "Softmax_xdnn_KUNLUN_Float32"); +}; // namespace infini diff --git a/test/kernels/kunlun/test_kunlun_softmax.cc b/test/kernels/kunlun/test_kunlun_softmax.cc new file mode 100644 index 00000000..77d6dbd8 --- /dev/null +++ b/test/kernels/kunlun/test_kunlun_softmax.cc @@ -0,0 +1,136 @@ +#include "core/graph.h" +#include "core/kernel.h" +#include "core/runtime.h" +#include "kunlun/kunlun_runtime.h" +#include "operators/softmax.h" +#include "test.h" +#include +namespace infini { + +TEST(XDNN_Softmax, run_axis1) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto kunlunRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 4}, DataType::Float32, cpuRuntime); + + // KUNLUN XPU + Graph kunlunGraph = make_ref(kunlunRuntime); + auto inputKunlun = kunlunGraph->cloneTensor(inputCpu); + auto kunlunOp = kunlunGraph->addOp(inputKunlun, nullptr, 1); + kunlunGraph->dataMalloc(); + inputKunlun->copyin(vector{0, 1, 2, 3, 10000, 10001, 10002, 10003}); + kunlunRuntime->run(kunlunGraph); + auto outputKunlun = kunlunOp->getOutput(); + auto outputKunlun2Cpu = outputKunlun->clone(cpuRuntime); + + // Check + EXPECT_TRUE(outputKunlun2Cpu->equalData( + vector{0.032058604, 0.08714432, 0.23688284, 0.6439143, + 0.032058604, 0.08714432, 0.23688284, 0.6439143})); +} + +TEST(XDNN_Softmax, run_axis0) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto kunlunRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 4}, DataType::Float32, cpuRuntime); + + // KUNLUN XPU + Graph kunlunGraph = make_ref(kunlunRuntime); + auto inputKunlun = kunlunGraph->cloneTensor(inputCpu); + auto kunlunOp = kunlunGraph->addOp(inputKunlun, nullptr, 0); + kunlunGraph->dataMalloc(); + inputKunlun->copyin(vector{0, 1, 2, 3, 10000, 10001, 10002, 
10003}); + kunlunRuntime->run(kunlunGraph); + auto outputKunlun = kunlunOp->getOutput(); + auto outputKunlun2Cpu = outputKunlun->clone(cpuRuntime); + + // Check + EXPECT_TRUE( + outputKunlun2Cpu->equalData(vector{0., 0., 0., 0., 1, 1, 1, 1})); +} + +TEST(XDNN_Softmax2, run_axis1) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto kunlunRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); + + // KUNLUN XPU + Graph kunlunGraph = make_ref(kunlunRuntime); + auto inputKunlun = kunlunGraph->cloneTensor(inputCpu); + auto kunlunOp = kunlunGraph->addOp(inputKunlun, nullptr, 1); + kunlunGraph->dataMalloc(); + inputKunlun->setData(IncrementalGenerator()); + kunlunRuntime->run(kunlunGraph); + auto outputKunlun = kunlunOp->getOutput(); + auto outputKunlun2Cpu = outputKunlun->clone(cpuRuntime); + + // Check + EXPECT_TRUE(outputKunlun2Cpu->equalData(vector{ + 0.0179862, 0.0179862, 0.0179862, 0.0179862, 0.9820138, 0.9820138, + 0.9820138, 0.9820138, 0.0179862, 0.0179862, 0.0179862, 0.0179862, + 0.9820138, 0.9820138, 0.9820138, 0.9820138})); +} + +TEST(XDNN_Softmax2, run_axis2) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto kunlunRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); + + // KUNLUN XPU + Graph kunlunGraph = make_ref(kunlunRuntime); + auto inputKunlun = kunlunGraph->cloneTensor(inputCpu); + auto kunlunOp = kunlunGraph->addOp(inputKunlun, nullptr, 2); + kunlunGraph->dataMalloc(); + inputKunlun->setData(IncrementalGenerator()); + kunlunRuntime->run(kunlunGraph); + auto outputKunlun = kunlunOp->getOutput(); + auto outputKunlun2Cpu = outputKunlun->clone(cpuRuntime); + + // Check + EXPECT_TRUE(outputKunlun2Cpu->equalData(vector{ + 0.1192029, 0.1192029, 0.8807971, 0.8807971, 0.1192029, 0.1192029, + 0.8807971, 0.8807971, 0.1192029, 0.1192029, 0.8807971, 0.8807971, + 0.1192029, 0.1192029, 0.8807971, 0.8807971})); +} + +TEST(XDNN_Softmax2, run_axis3) { + // Runtime + Runtime cpuRuntime = NativeCpuRuntimeObj::getInstance(); + auto kunlunRuntime = make_ref(); + + // Build input data on CPU + Tensor inputCpu = + make_ref(Shape{2, 2, 2, 2}, DataType::Float32, cpuRuntime); + + // KUNLUN XPU + Graph kunlunGraph = make_ref(kunlunRuntime); + auto inputKunlun = kunlunGraph->cloneTensor(inputCpu); + auto kunlunOp = kunlunGraph->addOp(inputKunlun, nullptr, 3); + kunlunGraph->dataMalloc(); + inputKunlun->setData(IncrementalGenerator()); + kunlunRuntime->run(kunlunGraph); + auto outputKunlun = kunlunOp->getOutput(); + auto outputKunlun2Cpu = outputKunlun->clone(cpuRuntime); + + // Check + EXPECT_TRUE(outputKunlun2Cpu->equalData(vector{ + 0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586, + 0.2689414, 0.7310586, 0.2689414, 0.7310586, 0.2689414, 0.7310586, + 0.2689414, 0.7310586, 0.2689414, 0.7310586})); +} +} // namespace infini From ec3adf6fa73cc6390f09a9bbd23910640d9ed000 Mon Sep 17 00:00:00 2001 From: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com> Date: Tue, 31 Oct 2023 10:47:36 +0800 Subject: [PATCH 03/28] support 8D tensor, add test example (#170) Co-authored-by: Haojie Wang --- include/cuda/cuda_split_concat.h | 2 +- test/kernels/cuda/test_cuda_concat.cc | 36 +++++++++++++++++++++++++++ test/kernels/cuda/test_cuda_split.cc | 34 +++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/include/cuda/cuda_split_concat.h 
b/include/cuda/cuda_split_concat.h index 699f15bc..58bdf330 100644 --- a/include/cuda/cuda_split_concat.h +++ b/include/cuda/cuda_split_concat.h @@ -3,7 +3,7 @@ #include const int BATCH_SIZE = 32; // parallel tensor number. -const int DIM_MAX_SIZE = 4; +const int DIM_MAX_SIZE = 8; // Concat operator acts like element tensors composing to one big tensor,and // split operator acts like one big tensor being composed by element diff --git a/test/kernels/cuda/test_cuda_concat.cc b/test/kernels/cuda/test_cuda_concat.cc index 013d25b5..2c76f405 100644 --- a/test/kernels/cuda/test_cuda_concat.cc +++ b/test/kernels/cuda/test_cuda_concat.cc @@ -122,4 +122,40 @@ TEST(Concat, Cuda_dim0) { EXPECT_TRUE(oCpu->equalData(vector{0, 1, 2, 1, 1, 1, 0, 1, 2})); } +TEST(Concat, CudaHigh) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto t1 = gCpu->addTensor({2, 2, 3, 1, 2}, DataType::Float32); + auto t2 = gCpu->addTensor({2, 2, 1, 1, 2}, DataType::Float32); + auto t3 = gCpu->addTensor({2, 2, 2, 1, 2}, DataType::Float32); + gCpu->dataMalloc(); + t1->setData(IncrementalGenerator()); + t2->setData(OneGenerator()); + t3->setData(OneGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto t1Gpu = gCuda->cloneTensor(t1); + auto t2Gpu = gCuda->cloneTensor(t2); + auto t3Gpu = gCuda->cloneTensor(t3); + + auto op = + gCuda->addOp(TensorVec{t1Gpu, t2Gpu, t3Gpu}, nullptr, 2); + gCuda->dataMalloc(); + t1Gpu->setData(IncrementalGenerator()); + t2Gpu->setData(OneGenerator()); + t3Gpu->setData(OneGenerator()); + cudaRuntime->run(gCuda); + + // cudaPrintTensor(op->getOutput()); + // copy output from CUDA to CPU + auto oCpu = gCpu->cloneTensor(op->getOutput()); + EXPECT_TRUE(oCpu->equalData( + vector{0., 1., 2., 3., 4., 5., 1., 1., 1., 1., 1., 1., + 6., 7., 8., 9., 10., 11., 1., 1., 1., 1., 1., 1., + 12., 13., 14., 15., 16., 17., 1., 1., 1., 1., 1., 1., + 18., 19., 20., 21., 22., 23., 1., 1., 1., 1., 1., 1.})); +} } // namespace infini diff --git a/test/kernels/cuda/test_cuda_split.cc b/test/kernels/cuda/test_cuda_split.cc index 5a32f27f..2cab944e 100644 --- a/test/kernels/cuda/test_cuda_split.cc +++ b/test/kernels/cuda/test_cuda_split.cc @@ -39,6 +39,40 @@ TEST(Split, Cuda) { 12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39})); } +TEST(Split, CudaHigh) { + Runtime runtime = NativeCpuRuntimeObj::getInstance(); + Graph gCpu = make_ref(runtime); + + auto input = gCpu->addTensor({2, 6, 2, 1, 2}, DataType::Float32); + gCpu->dataMalloc(); + input->setData(IncrementalGenerator()); + + auto cudaRuntime = make_ref(); + Graph gCuda = make_ref(cudaRuntime); + + auto inputGpu = gCuda->cloneTensor(input); + auto op = gCuda->addOp(inputGpu, std::nullopt, 1, 3); + gCuda->dataMalloc(); + inputGpu->setData(IncrementalGenerator()); + + cudaRuntime->run(gCuda); + + // copy output from CUDA to CPU + EXPECT_EQ(op->getOutputs().size(), (size_t)3); + auto o0Cpu = gCpu->cloneTensor(op->getOutput(0)); + auto o1Cpu = gCpu->cloneTensor(op->getOutput(1)); + auto o2Cpu = gCpu->cloneTensor(op->getOutput(2)); + EXPECT_TRUE( + o0Cpu->equalData(vector{0., 1., 2., 3., 4., 5., 6., 7., 24., 25., + 26., 27., 28., 29., 30., 31.})); + EXPECT_TRUE(o1Cpu->equalData(vector{8., 9., 10., 11., 12., 13., 14., + 15., 32., 33., 34., 35., 36., + 37., 38., 39.})); + EXPECT_TRUE(o2Cpu->equalData(vector{16., 17., 18., 19., 20., 21., + 22., 23., 40., 41., 42., 43., + 44., 45., 46., 47.})); +} + TEST(Split, Cuda_dim0) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); 
Graph gCpu = make_ref(runtime); From 1a6fccccbe884b5217524f235e7bcc09ebabae60 Mon Sep 17 00:00:00 2001 From: Derui Yang Date: Fri, 3 Nov 2023 13:21:49 +0800 Subject: [PATCH 04/28] =?UTF-8?q?test:=20=E6=94=AF=E6=8C=81=E7=BC=96?= =?UTF-8?q?=E8=AF=91=20einnet=20=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=EF=BC=8C=E4=BD=86=E4=B8=8D=E6=98=AF=E6=89=80=E6=9C=89=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E9=83=BD=E8=83=BD=E9=80=9A=E8=BF=87=20(#174)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: 支持编译 einnet 单元测试,但不是所有测试都能通过 Signed-off-by: YdrMaster * Fix: locating resource files and skip codegen - Change the path parameters in `matchExprResult` and `checkExprLogSame` to paths relative to the project home - Skip NNetMemboundOp tests as they require codegen --------- Signed-off-by: YdrMaster Co-authored-by: Haojie Wang Co-authored-by: Liyan Zheng --- CMakeLists.txt | 20 +++++++++++--------- Makefile | 9 ++++----- include/nnet/test.h | 4 ++-- src/nnet/test.cc | 33 +++++++++++++++++++++++++++------ test/nnet/test_OpSearch.cc | 11 +++++------ test/nnet/test_TConv2gemm.cc | 6 +++--- test/nnet/test_conv2conv.cc | 12 ++++-------- test/nnet/test_conv2gemm.cc | 8 ++++---- test/nnet/test_g2bmm.cc | 6 ++---- test/nnet/test_memboundOp.cc | 25 ++++++++++++++++++++++--- test/nnet/test_mutator.cc | 4 ++++ 11 files changed, 88 insertions(+), 50 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d942fcd4..b08d9f85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,7 +21,6 @@ project(InfiniTensor C CXX) cmake_dependent_option(BUILD_TEST_CORE "Build tests for core components" ON BUILD_TEST OFF) cmake_dependent_option(BUILD_TEST_PET "Build tests for PET" OFF BUILD_TEST OFF) -cmake_dependent_option(BUILD_TEST_EINNET "Build tests for EINNET" OFF BUILD_TEST OFF) set(DEFAULT_BUILD_TYPE "RelWithDebInfo") # Build Type @@ -95,16 +94,17 @@ add_subdirectory(3rd-party/nlohmann_json_cmake_fetchcontent) include_directories(3rd-party/nlohmann_json_cmake_fetchcontent/single_include) # TVM backend -if(BUILD_TEST_EINNET) - if (NOT TVM_INCLUDE_DIR OR NOT DMLC_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR OR NOT DLPACK_INCLUDE_DIR) - message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_TEST_EINNET is ON") - endif() +if(BUILD_NNET AND BUILD_TEST) # TVM and DMLC for invoking TVM packed functions include_directories(${TVM_INCLUDE_DIR}) include_directories(${DMLC_INCLUDE_DIR}) include_directories(${DLPACK_INCLUDE_DIR}) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels + if (TVM_INCLUDE_DIR AND DMLC_INCLUDE_DIR AND DLPACK_INCLUDE_DIR AND DLPACK_INCLUDE_DIR) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDMLC_USE_LOGGING_LIBRARY=\\\<${TVM_INCLUDE_DIR}/tvm/runtime/logging.h\\\> ") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DINFINI_USE_TVM=1") # Enable TVM codegen kernels + else() + # message(FATAL_ERROR "TVM_INCLUDE_DIR, DMLC_INCLUDE_DIR, and DLPACK_INCLUDE_DIR must be set when BUILD_NNET AND BUILD_TEST is ON") + endif() endif() if(BUILD_TEST) @@ -130,6 +130,8 @@ if(BUILD_NNET) add_compile_definitions(BUILD_NNET=1) file(GLOB_RECURSE SRC_NNET src/nnet/*.cc) list (APPEND SRC ${SRC_NNET}) + # For locating resource files + set_source_files_properties(src/nnet/test.cc PROPERTIES COMPILE_OPTIONS "-DINFINI_PROJECT_HOME=${CMAKE_CURRENT_SOURCE_DIR}") endif() if(USE_CUDA) @@ -161,7 +163,7 @@ 
endif() target_link_libraries(InfiniTensor pybind11::embed) # TVM backend -if(BUILD_TEST_EINNET) +if(BUILD_NNET AND BUILD_TEST AND TVM_LIB_DIR) target_link_libraries(InfiniTensor ${TVM_LIB_DIR}/libtvm.so) endif() @@ -333,7 +335,7 @@ if(BUILD_TEST) if(BUILD_TEST_PET) build_test(test/pet/*.cc) endif() - if(BUILD_TEST_EINNET) + if(BUILD_NNET AND BUILD_TEST) build_test(test/nnet/test_*.cc) # Build expression reader diff --git a/Makefile b/Makefile index 19f1b353..302f47b8 100644 --- a/Makefile +++ b/Makefile @@ -7,12 +7,13 @@ KUNLUN ?= OFF INTELCPU ?= off BACKTRACE ?= ON TEST ?= ON +NNET ?= OFF FORMAT_ORIGIN ?= # Docker build options DOCKER_NAME ?= infinitensor DOCKER_IMAGE_NAME ?= infinitensor DOCKER_FILE ?= infinitensor_ubuntu_22.04.dockerfile -DOCKER_RUN_OPTION ?= +DOCKER_RUN_OPTION ?= # CUDA option. ifeq ($(CUDA), ON) @@ -22,13 +23,13 @@ ifeq ($(CUDA), ON) DOCKER_RUN_OPTION += --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v `pwd`:`pwd` -w `pwd` endif - CMAKE_OPT = -DCMAKE_BUILD_TYPE=$(TYPE) CMAKE_OPT += -DUSE_CUDA=$(CUDA) CMAKE_OPT += -DUSE_BANG=$(BANG) CMAKE_OPT += -DUSE_KUNLUN=$(KUNLUN) CMAKE_OPT += -DUSE_BACKTRACE=$(BACKTRACE) CMAKE_OPT += -DBUILD_TEST=$(TEST) +CMAKE_OPT += -DBUILD_NNET=$(NNET) ifeq ($(INTELCPU), ON) CMAKE_OPT += -DUSE_INTELCPU=ON -DCMAKE_CXX_COMPILER=dpcpp @@ -60,7 +61,7 @@ test-api: @echo python3 pyinfinitensor/tests/test_api.py -docker-build: +docker-build: docker build -f scripts/dockerfile/$(DOCKER_FILE) -t $(DOCKER_NAME) . docker-run: @@ -71,5 +72,3 @@ docker-start: docker-exec: docker exec -it $(DOCKER_IMAGE_NAME) bash - - diff --git a/include/nnet/test.h b/include/nnet/test.h index 6cf75873..1d24bbd4 100644 --- a/include/nnet/test.h +++ b/include/nnet/test.h @@ -24,7 +24,7 @@ // clang-format on namespace nnet { -int matchExprResult(Derivator &derivator, string fn); -bool checkExprLogSame(string fnPrefix, int start, int end); +int matchExprResult(Derivator &derivator, string pathRelativeToProjectHome); +bool checkExprLogSame(string pathRelativeToProjectHome, int start, int end); bool checkExprsEquvivalence(VecExpr exprs); } // namespace nnet diff --git a/src/nnet/test.cc b/src/nnet/test.cc index 6c9738f2..fbf41e20 100644 --- a/src/nnet/test.cc +++ b/src/nnet/test.cc @@ -3,10 +3,30 @@ #include "nnet/Visitor/HashVisitor.h" #include "nnet/Visitor/Interpreter.h" #include "nnet/Visitor/Serializer.h" +#include namespace nnet { -int matchExprResult(Derivator &derivator, string fn) { - auto ans = Serializer().deserialize(fn); +std::filesystem::path getProjectHome() { +#ifndef INFINI_PROJECT_HOME +#error INFINI_PROJECT_HOME is not defined +#endif + +#define Q(x) #x +#define QUOTE(x) Q(x) +#define PROJECT_HOME QUOTE(INFINI_PROJECT_HOME) + return std::filesystem::path(PROJECT_HOME); +#undef PROJECT_HOME +#undef QUOTE +#undef Q +} + +string getResourceFilePath(string path) { + return (getProjectHome() / path).string(); +} + +int matchExprResult(Derivator &derivator, string pathRelativeToProjectHome) { + auto fullPath = getResourceFilePath(pathRelativeToProjectHome); + auto ans = Serializer().deserialize(fullPath); auto hashAns = HashVisitor()(ans); int match = 0; for (const auto &candidate : derivator.getCandidates()) { @@ -16,16 +36,17 @@ int matchExprResult(Derivator &derivator, string fn) { return match; } -bool checkExprLogSame(string fnPrefix, int start, int end) { +bool checkExprLogSame(string pathRelativeToProjectHome, int start, int end) { + auto fullPath = getResourceFilePath(pathRelativeToProjectHome); Serializer serializer; - string fn0 = 
fnPrefix + to_string(start) + ".expr"; + string fn0 = fullPath + to_string(start) + ".expr"; Expr expr0 = serializer.deserialize(fn0); RangeOp range0 = as(expr0); Interpreter interpreter(range0); auto ans0 = interpreter.interpretUniformSample(range0); dbg(expr0, ans0); for (int i = start + 1; i < end; ++i) { - string fn1 = fnPrefix + to_string(i) + ".expr"; + string fn1 = fullPath + to_string(i) + ".expr"; Expr expr1 = serializer.deserialize(fn1); RangeOp range1 = as(expr1); dbg(fn1, expr1); @@ -67,4 +88,4 @@ bool checkExprsEquvivalence(VecExpr exprs) { return true; } -} // namespace nnet \ No newline at end of file +} // namespace nnet diff --git a/test/nnet/test_OpSearch.cc b/test/nnet/test_OpSearch.cc index f969ccf7..14a6c737 100644 --- a/test/nnet/test_OpSearch.cc +++ b/test/nnet/test_OpSearch.cc @@ -85,7 +85,7 @@ TEST_F(OpSearch, Conv2gemm_NCHW_FCRS_search) { EXPECT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( - derivator, "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_FCRS_11.expr"); + derivator, "test/nnet/log/conv2gemm/Conv2gemm_NCHW_FCRS_11.expr"); EXPECT_GE(nMatches, 1); // derivator.print(); derivator.printStatistics(); @@ -160,7 +160,7 @@ TEST_F(OpSearch, TConv2gemm_TConv4x4_NHWF_RSFC_search) { EXPECT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( derivator, - "../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); + "test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); EXPECT_GE(nMatches, 1); derivator.printStatistics(); } @@ -197,8 +197,7 @@ TEST_F(OpSearch, Conv2conv_5x5_RuleBased_NCHW_FCRS) { derivator.search(conv_9x9, 0); int nMatches = matchExprResult( - derivator, - "../test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); // derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); @@ -236,8 +235,8 @@ TEST_F(OpSearch, G2BMM_RuleBased) { EXPECT_GE(derivator.getNumCandidates(), 1); int nMatches = - matchExprResult(derivator, "../test/nnet/log/g2bmm/G2BMM_9.expr"); + matchExprResult(derivator, "test/nnet/log/g2bmm/G2BMM_9.expr"); EXPECT_GE(nMatches, 1); // derivator.print(); derivator.printStatistics(); -} \ No newline at end of file +} diff --git a/test/nnet/test_TConv2gemm.cc b/test/nnet/test_TConv2gemm.cc index 73f32088..3b84b843 100644 --- a/test/nnet/test_TConv2gemm.cc +++ b/test/nnet/test_TConv2gemm.cc @@ -365,14 +365,14 @@ TEST(TConv2gemm, TConv4x4_NHWF_RSFC_search) { // } int nMatches = matchExprResult( derivator, - "../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); + "test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_18.expr"); EXPECT_GE(nMatches, 1); derivator.printStatistics(); } TEST(TConv2gemm, TConv4x4_NHWF_FRSC_CheckDerivationCorrectness_log) { const string fnPrefix = - "../test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_"; + "test/nnet/log/TConv4x4_NHWF_RSFC/TConv4x4_NHWF_RSFC_"; EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 11)); } @@ -388,4 +388,4 @@ TEST(Conv2conv, InfoGAN_ConvTranspose_3_OOB_Test) { dbg(expr); Derivator derivator; derivator.checkOOB(as(expr)); -} \ No newline at end of file +} diff --git a/test/nnet/test_conv2conv.cc b/test/nnet/test_conv2conv.cc index 8e961e95..a77ab39a 100644 --- a/test/nnet/test_conv2conv.cc +++ b/test/nnet/test_conv2conv.cc @@ -37,10 +37,8 @@ TEST(Conv2conv, 9x9_NCHW_FCRS) { derivator.ruleBasedDFS(conv_9x9, 0, rules, {}, true); } else derivator.search(conv_9x9, 0); - int nMatches = matchExprResult( - derivator, - 
"../test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_9x9_NCHW_FCRS_14.expr"); derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); @@ -81,8 +79,7 @@ TEST(Conv2conv, 6x6_RuleBased_NCHW_FCRS) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( - derivator, - "../test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_6x6_NCHW_FCRS_14.expr"); derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); @@ -121,9 +118,8 @@ TEST(Conv2conv, 5x5_RuleBased_NCHW_FCRS) { derivator.search(conv_9x9, 0); int nMatches = matchExprResult( - derivator, - "../test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); + derivator, "test/nnet/log/conv2conv/Conv2conv_5x5_NCHW_FCRS_15.expr"); derivator.print(); derivator.printStatistics(); EXPECT_GE(nMatches, 1); -} \ No newline at end of file +} diff --git a/test/nnet/test_conv2gemm.cc b/test/nnet/test_conv2gemm.cc index 9827497b..3f15d328 100644 --- a/test/nnet/test_conv2gemm.cc +++ b/test/nnet/test_conv2gemm.cc @@ -147,7 +147,7 @@ TEST(Conv2gemm, timing_NHWC_RSFC_search) { // Conv2gemm requires thorough update, this is disabled temporarily TEST(Conv2gemm, CheckCorrectness) { - const string fnPrefix = "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_"; + const string fnPrefix = "test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_"; // conv2gemm_7 has T3 EXPECT_TRUE(checkExprLogSame(fnPrefix, 0, 7)); } @@ -182,7 +182,7 @@ TEST(Conv2gemm, NCHW_RSFC_search) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( - derivator, "../test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr"); + derivator, "test/nnet/log/conv2gemm/Conv2gemm_NCHW_RSFC_11.expr"); EXPECT_GE(nMatches, 1); // derivator.print(); derivator.printStatistics(); @@ -278,6 +278,6 @@ TEST(Conv2gemm1x7, NCHW_FCRS_search) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = matchExprResult( derivator, - "../test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_11.expr"); + "test/nnet/log/conv2gemm_1x7/Conv2gemm_1x7_NCHW_FCRS_11.expr"); EXPECT_GE(nMatches, 1); -} \ No newline at end of file +} diff --git a/test/nnet/test_g2bmm.cc b/test/nnet/test_g2bmm.cc index e285fb98..6ae1150e 100644 --- a/test/nnet/test_g2bmm.cc +++ b/test/nnet/test_g2bmm.cc @@ -36,10 +36,8 @@ TEST(GBMM, RuleBased) { } else { derivator.search(dialted_g2bmm, 0); } - ASSERT_GE(derivator.getNumCandidates(), 1); - int nMatches = - matchExprResult(derivator, "../test/nnet/log/gbmm/GBMM_9.expr"); + int nMatches = matchExprResult(derivator, "test/nnet/log/gbmm/GBMM_9.expr"); EXPECT_GE(nMatches, 1); derivator.print(); derivator.printStatistics(); @@ -78,7 +76,7 @@ TEST(G2BMM, RuleBased) { ASSERT_GE(derivator.getNumCandidates(), 1); int nMatches = - matchExprResult(derivator, "../test/nnet/log/g2bmm/G2BMM_9.expr"); + matchExprResult(derivator, "test/nnet/log/g2bmm/G2BMM_9.expr"); EXPECT_GE(nMatches, 1); derivator.print(); derivator.printStatistics(); diff --git a/test/nnet/test_memboundOp.cc b/test/nnet/test_memboundOp.cc index 910344f2..49716161 100644 --- a/test/nnet/test_memboundOp.cc +++ b/test/nnet/test_memboundOp.cc @@ -1,3 +1,5 @@ +#ifdef USE_CUDA + #include "core/graph.h" #include "core/runtime.h" #include "cuda/cuda_runtime.h" @@ -12,7 +14,22 @@ using namespace infini; using namespace std; -TEST(nnet, MemboundOpInterpretation) { +class NNetMemboundOp : public ::testing::Test { + protected: + void SetUp() override { + if (!hasTVMBackend()) + GTEST_SKIP() 
<< "Skipping test since no TVM backend.\n"; + } + + private: + static bool hasTVMBackend() { + // TODO: as the dispatch mechanism of backend is going to change, this + // function is to be implemented. + return false; + } +}; + +TEST_F(NNetMemboundOp, MemboundOpInterpretation) { Runtime runtime = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(runtime); Tensor i0 = g->addTensor({1, 2, 3}, DataType::UInt32); @@ -41,7 +58,7 @@ TEST(nnet, MemboundOpInterpretation) { EXPECT_TRUE(membound->getOutput()->equalData(ans)); } -TEST(nnet, MemboundOp_Ansor_Codegen) { +TEST_F(NNetMemboundOp, MemboundOp_Ansor_Codegen) { auto runtime = make_ref(); Runtime cpu = NativeCpuRuntimeObj::getInstance(); Graph gCpu = make_ref(cpu); @@ -91,7 +108,7 @@ pair, nnet::Expr> getPReluExpr(int size) { return {{A, B}, ret}; } -TEST(nnet, PRelu_Ansor_Codegen) { +TEST_F(NNetMemboundOp, PRelu_Ansor_Codegen) { auto cuda = make_ref(); Runtime cpu = NativeCpuRuntimeObj::getInstance(); Graph g = make_ref(cuda); @@ -116,3 +133,5 @@ TEST(nnet, PRelu_Ansor_Codegen) { auto oCpu = gCpu->cloneTensor(o0); EXPECT_TRUE(oCpu->equalData(ans)); } + +#endif diff --git a/test/nnet/test_mutator.cc b/test/nnet/test_mutator.cc index cf4d8ab2..abc3f604 100644 --- a/test/nnet/test_mutator.cc +++ b/test/nnet/test_mutator.cc @@ -1,3 +1,5 @@ +#ifdef USE_CUDA + #include "core/blob.h" #include "core/dummy_mutator.h" #include "core/graph.h" @@ -477,3 +479,5 @@ TEST(Mutator, InfoGAN_TConv_3_correctness) { // EXPECT_TRUE(graph->verification(bestGraph.get(), true)); // } } // namespace infini + +#endif From d3e75432917addb80f97b851b1d681a0db52af89 Mon Sep 17 00:00:00 2001 From: xgqdut2016 <140036308+xgqdut2016@users.noreply.github.com> Date: Mon, 6 Nov 2023 08:56:23 +0800 Subject: [PATCH 05/28] Cuda softmax (#129) * "add softmax.cu,.cc,.h" * Modify cuda softmax * "modified the introduction of softmax.cu" * "add format of cuda_softmax.h" * "modified where.cc(.cu,.h) and softmax.cu" * "modified format" * Fix cpu softmax kernel * "modified the // introduction of softmax.cu" * "modified softmax.cu and use 1D block" * "modified softmax.cu,format, and use 1D block" * "introduce share mem to speed softmax" * "reduce the input of function" * modified the format * remodify 2D block softmax * remodify 1D block softmax * modified the share memory * add warp reduce * conflict solve two * remove extra space line * solve comment --------- Co-authored-by: Haojie Wang Co-authored-by: panzezhong --- include/cuda/cuda_softmax.h | 6 + include/cuda/softmax.h | 6 - src/kernels/cpu/unary.cc | 3 +- src/kernels/cuda/softmax.cc | 26 +-- src/kernels/cuda/softmax.cu | 236 +++++++++++++++++++-------- test/kernels/cuda/test_cuda_where.cc | 34 ++-- 6 files changed, 209 insertions(+), 102 deletions(-) create mode 100644 include/cuda/cuda_softmax.h delete mode 100644 include/cuda/softmax.h diff --git a/include/cuda/cuda_softmax.h b/include/cuda/cuda_softmax.h new file mode 100644 index 00000000..671f46f8 --- /dev/null +++ b/include/cuda/cuda_softmax.h @@ -0,0 +1,6 @@ +#pragma once +#include "utils/small_array.h" +namespace infini { +void softmax_kernel(int num_blocks, float *input, float *output, int size, + int dimsize, int stride); +} diff --git a/include/cuda/softmax.h b/include/cuda/softmax.h deleted file mode 100644 index 5c0eccf9..00000000 --- a/include/cuda/softmax.h +++ /dev/null @@ -1,6 +0,0 @@ -#pragma once - -namespace infini { -void softmax_kernel(int max_threadblock_size, int batch_size, float *x, - float *y, int dim, int stride); -} diff --git 
a/src/kernels/cpu/unary.cc b/src/kernels/cpu/unary.cc index 8975d7cd..3ea61b41 100644 --- a/src/kernels/cpu/unary.cc +++ b/src/kernels/cpu/unary.cc @@ -1,6 +1,7 @@ #include "operators/unary.h" #include "core/constants.h" #include "core/kernel.h" +#include "operators/softmax.h" namespace infini { template class NativeUnary : public CpuKernelWithoutConfig { @@ -22,7 +23,7 @@ template class NativeUnary : public CpuKernelWithoutConfig { template class NaiveSoftmax : public CpuKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *context) const override { - auto op = as(_op); + auto op = as(_op); T *inptr = op->getInputs(0)->getRawDataPtr(); T *outptr = op->getOutput()->getRawDataPtr(); diff --git a/src/kernels/cuda/softmax.cc b/src/kernels/cuda/softmax.cc index 437ed849..024288c2 100644 --- a/src/kernels/cuda/softmax.cc +++ b/src/kernels/cuda/softmax.cc @@ -1,30 +1,30 @@ #include "operators/softmax.h" #include "cuda/cuda_kernel_wihtout_config.h" #include "cuda/cuda_runtime.h" -#include "cuda/softmax.h" +#include "cuda/cuda_softmax.h" namespace infini { -class SoftmaxCudnn : public CudaKernelWithoutConfig { +class SoftmaxCuda : public CudaKernelWithoutConfig { void compute(const Operator &_op, const RuntimeObj *_context) const override { auto op = as(_op); - auto x = op->getInputs(0)->getRawDataPtr(); - auto y = op->getOutput(0)->getRawDataPtr(); + auto input = op->getInputs(0)->getRawDataPtr(); + auto output = op->getOutput(0)->getRawDataPtr(); + const auto &inShape = op->getInputs(0)->getDims(); // input shape auto dims = op->getInputs(0)->getDims(); - int batch_size = 1; - for (size_t i = 0; i < dims.size(); ++i) - batch_size *= dims[i]; - int dim = dims[op->getAxis()]; + int size; // size = i(JKS) + j(KS) + k(S) + s + size = op->getOutput(0)->size(); + int dimsize = dims[op->getAxis()]; + int stride = op->getInputs(0)->getStride().at(op->getAxis()); - int block_num = batch_size / dim; - int max_threadblock_size = batch_size / block_num; - softmax_kernel(max_threadblock_size, block_num, x, y, dim, - op->getInputs(0)->getStride().at(op->getAxis())); + int num_blocks = size / dimsize; + softmax_kernel(num_blocks, (float *)input, (float *)output, size, + dimsize, stride); } }; -REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCudnn, +REGISTER_KERNEL(Device::CUDA, OpType::Softmax, DataType::Float32, SoftmaxCuda, "Softmax_CUDA_Float32"); } // namespace infini diff --git a/src/kernels/cuda/softmax.cu b/src/kernels/cuda/softmax.cu index 1f7f39e6..7e85ec43 100644 --- a/src/kernels/cuda/softmax.cu +++ b/src/kernels/cuda/softmax.cu @@ -1,77 +1,183 @@ #include "cuda/cuda_common.h" -#include "cuda/softmax.h" #include -struct __align__(8) MD { - float data; - float d; +struct __align__(8) DataMaxSum { // update the global max and sum, store the + // output at max_tmp and sum_tmp + float max_tmp; // store max + float sum_tmp; // store sum +}; +__device__ __forceinline__ DataMaxSum reduce_dms_op(DataMaxSum a, + DataMaxSum b) { + bool a_bigger = (a.max_tmp > b.max_tmp); + DataMaxSum bigger = a_bigger ? a : b; + DataMaxSum smaller = a_bigger ? 
b : a; + bigger.sum_tmp = bigger.sum_tmp + + smaller.sum_tmp * __expf(smaller.max_tmp - bigger.max_tmp); + + return bigger; +} +template +__launch_bounds__(BLOCK_DIM) __global__ void _blockSoftmaxKernel( + float *__restrict input, float *__restrict output, int size, int dimsize, + int stride) { // if set axis = 1, inputShape=[I,J,K,S] + // tid = i(JKS) + j(KS) + k(S) + s + + // blockDim.x = size/dimsize = IKS + // blockIdx.x = i(KS) + k(S) + s,blockIdx.x%stride = k(S) + s + + int tid = + blockIdx.x % stride + (blockIdx.x - blockIdx.x % stride) * + dimsize; // now, tid = i(JKS) + k(S) + s; + + DataMaxSum dms_partial; + dms_partial.max_tmp = -__FLT_MAX__; + dms_partial.sum_tmp = 0.0f; + DataMaxSum dms_input; + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + + dms_input.max_tmp = + input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride]; + + dms_input.sum_tmp = 1.0f; + dms_partial = reduce_dms_op(dms_partial, + dms_input); // reduce the data to one block + } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ DataMaxSum dms_total; + DataMaxSum dms_block = + BlockReduce(temp_storage).Reduce(dms_partial, reduce_dms_op); + if (threadIdx.x == + 0) { // must set threadIdx.x = 0 write the output to memory + dms_total = dms_block; + } + __syncthreads(); + //----------------- + + for (int ph = 0; threadIdx.x + ph * BLOCK_DIM < dimsize; ph++) { + output[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] = + __expf(input[tid + (threadIdx.x + ph * BLOCK_DIM) * stride] - + dms_total.max_tmp) * + __fdividef(1.0F, dms_total.sum_tmp); + } +} + +template struct SumOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return a + b; + } }; -__device__ __forceinline__ MD reduce_md_op(MD a, MD b) { - bool a_bigger = (a.data > b.data); - MD bigger_m = a_bigger ? a : b; - MD smaller_m = a_bigger ? b : a; - MD res; - res.d = bigger_m.d + smaller_m.d * __expf(smaller_m.data - bigger_m.data); - res.data = bigger_m.data; - return res; -} - -template -__launch_bounds__(THREADBLOCK_SIZE) __global__ - void online_softmax(const float *__restrict in, float *__restrict out, - int dimSize, int stride) { - - // reposition in and out to data for the current vector - int blockOffset = blockIdx.x; - if (blockIdx.x >= stride) { - int tmp = blockIdx.x % stride; - blockOffset = tmp + (blockIdx.x - tmp) * dimSize; +template struct MaxOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return max(a, b); } - in += blockOffset; - out += blockOffset; - - MD md_partial; - md_partial.data = -FLT_MAX; - md_partial.d = 0.0F; - - for (int elem_id = threadIdx.x; elem_id < dimSize; - elem_id += THREADBLOCK_SIZE) { - MD new_elem; - new_elem.data = in[elem_id * stride]; - new_elem.d = 1.0F; - md_partial = reduce_md_op(md_partial, new_elem); +}; +template