diff --git a/examples/distributed/launch_kvcache.py b/examples/distributed/launch_kvcache.py new file mode 100644 index 00000000..13f908af --- /dev/null +++ b/examples/distributed/launch_kvcache.py @@ -0,0 +1,245 @@ +import argparse +import os +import time +import multiprocessing as mp +from pyinfinitensor.onnx import OnnxStub, backend +import onnx +from onnx.external_data_helper import convert_model_to_external_data +import numpy as np +from parallel_opt import parallel_model + + +os.environ["NVIDIA_TF32_OVERRIDE"] = "0" + + +def parse_args(): + parser = argparse.ArgumentParser(description="launch distributed infinitensor") + parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes") + parser.add_argument( + "--nproc_per_node", type=int, default=1, help="number of processes per node" + ) + parser.add_argument( + "--name", type=str, default="test", help="name of this instance." + ) + parser.add_argument( + "--model1", type=str, required=True, help="path to the ONNX model file." + ) + parser.add_argument( + "--model2", type=str, required=True, help="path to the ONNX model file." + ) + parser.add_argument("--batch_size", type=int, default=1, help="batch size.") + parser.add_argument("--length", type=int, default=1, help="sequence length.") + parser.add_argument( + "--gen_std", + action="store_true", + help="whether to generate the standard results.", + ) + args = parser.parse_args() + print("arg setting: ", args) + return ( + args.num_nodes, + args.nproc_per_node, + args.name, + args.model1, + args.model2, + args.batch_size, + args.length, + args.gen_std, + ) + + +def run_model(model1, model2, runtime1, runtime2, inputs1: np.array, inputs2: np.array, n=20): + #################################### + # run the first graph without kvcache + #################################### + stub1 = OnnxStub(model1, runtime1) + stub1.inputs['onnx::Reshape_0'].copyin_int32(inputs1.reshape(-1).tolist()) + stub1.tune() + stub1.run() + kvcache_it1 = [] + count = 0 + for output in stub1.outputs.items().__iter__(): + if count == 0: + logits_it1 = np.array(output[1].copyout_float(), dtype=np.float32) + else: + kvcache_it1.append(np.array(output[1].copyout_float(), dtype=np.float32)) + count = count + 1 + + # bench for stub1 + next(stub1.inputs.items().__iter__())[1].copyin_int32(inputs1.reshape(-1).tolist()) + begin = time.time() + for _ in range(n): + stub1.run() + end = time.time() + avg_time = (end - begin) / n + print(f"stub1 average time: {avg_time}") + + #################################### + # run the second graph with kvcache + #################################### + i = 0 + batchsize = 1 + stub2 = OnnxStub(model2, runtime2) + past_kvcache_length = (i+2)*np.ones((batchsize, 1), dtype=np.int32) + # copyin input + stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist()) + stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist()) + count = -1 + for input in stub2.inputs.items().__iter__(): + if count in range(24): + # print(count, input[0]) + # print(np.dtype(kvcache_it1[count][0]), kvcache_it1[count].shape) + input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist()) + count = count + 1 + stub2.tune() + stub2.run() + + # copyout output + count = 0 + kvcache_it2 = [] + for output in stub2.outputs.items().__iter__(): + if count == 0: + logits_it2 = np.array(output[1].copyout_float(), dtype=np.float32) + else: + kvcache_it2.append(np.array(output[1].copyout_float(), dtype=np.float32)) + count = count + 1 + + # bench for stub2 + # copyin input + 
stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist()) + stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist()) + count = -1 + for input in stub2.inputs.items().__iter__(): + if count in range(24): + input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist()) + count = count + 1 + begin = time.time() + for _ in range(n): + stub2.run() + end = time.time() + avg_time = (end - begin) / n + print(f"stub2 average time: {avg_time}") + return logits_it2 + + +def run_and_compare(name, model1, model2, runtime1, runtime2): + data1 = np.load(f"{name}_inputs1.npy") + data2 = np.load(f"{name}_inputs2.npy") + results = np.load(f"{name}_results.npy") + outputs = run_model(model1, model2, runtime1, runtime2, data1, data2) + print("outputs sum:", outputs.sum()) + print("max abs diff:", abs(outputs - results).max()) + print("max rel diff:", abs((outputs - results) / results).max()) + # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6) + + +def start_worker( + name: str, world_size: int, rank: int, local_rank: int, model1: onnx.ModelProto, model2: onnx.ModelProto +): + dist_name = name + "_dist" + #################################### + # shard the first graph + #################################### + model1 = parallel_model(model1, world_size, rank) + extern_path = f"./{dist_name}_stub1_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + convert_model_to_external_data( + model1, + all_tensors_to_one_file=True, + location=extern_path, + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(model1, f"./{dist_name}_stub1_rank{rank}.onnx") + runtime1 = backend.CudaRuntime(local_rank) + runtime1.init_comm( + dist_name, + world_size, + rank, + ) + + #################################### + # shard the second graph + #################################### + model2 = parallel_model(model2, world_size, rank) + extern_path = f"./{dist_name}_stub2_rank{rank}.pb" + if os.path.exists(extern_path): + os.remove(extern_path) + convert_model_to_external_data( + model2, + all_tensors_to_one_file=True, + location=extern_path, + size_threshold=1024, + convert_attribute=False, + ) + onnx.save(model2, f"./{dist_name}_stub2_rank{rank}.onnx") + runtime2 = backend.CudaRuntime(local_rank) + # print("init comm") + runtime2.init_comm( + dist_name, + world_size, + rank, + ) + + # run the two graphs + run_and_compare(name, model1, model2, runtime1, runtime2) + + +def start_single(name, model1, model2): + runtime1 = backend.CudaRuntime(0) + runtime2 = backend.CudaRuntime(0) + run_and_compare(name, model1, model2, runtime1, runtime2) + + +def gen_standard(name, model1, model2, voc_size, bs, len): + # generate standard results + data1 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32) + data2 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32) + np.save(f"{name}_inputs1", data1) + np.save(f"{name}_inputs2", data2) + runtime1 = backend.CudaRuntime(0) + runtime2 = backend.CudaRuntime(0) + outputs = run_model(model1, model2, runtime1, runtime2, data1, data2, 1) + np.save(f"{name}_results", outputs) + + +def main(): + nnodes, nproc_per_node, name, model1_path, model2_path, bs, length, gen_std = parse_args() + + model1 = onnx.load(model1_path) + model2 = onnx.load(model2_path) + + # generate standart output + if gen_std: + print(f"generate standard data for {name}.") + # a small vocabulary size to fit all LLM. + voc_size = 1000 + gen_standard(name, model1, model2, voc_size, bs, length) + return + + # run single process. 
+ # use standalone process to isolate cuda. + p = mp.Process(target=start_single, args=(name, model1, model2)) + p.start() + p.join() + + # run distributed parallel. + world_size = nnodes * nproc_per_node + workers = [ + mp.Process( + target=start_worker, + args=(name, world_size, rank, rank % nproc_per_node, model1, model2), + ) + for rank in range(world_size) + ] + + for w in workers: + w.start() + + for w in workers: + w.join() + + +if __name__ == "__main__": + main() diff --git a/examples/distributed/parallel_opt.py b/examples/distributed/parallel_opt.py index c152f6be..b16386a7 100644 --- a/examples/distributed/parallel_opt.py +++ b/examples/distributed/parallel_opt.py @@ -56,6 +56,16 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): ndim = len(vinfo[output].type.tensor_type.shape.dim) out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial() place[node.output[0]] = out_plc + + def shard_concat(node: NodeProto): + # hack for kvcache + in_plc = place[node.input[1]] + if in_plc.is_sharded(): + seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1) + seq_len_dim.dim_value //= tp_world_size + vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim) + place[node.input[0]] = in_plc + place[node.output[0]] = in_plc def shard_binary(node: NodeProto, groups: int = 1): # print("binary", node.name, node.input[0], place[node.input[0]]) @@ -143,6 +153,8 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): place[node.input[0]] == place[node.input[1]] ), f"{place[node.input[0]]} != {place[node.input[1]]}" place[node.output[0]] = place[node.input[0]] + elif node.op_type == "Concat": + shard_concat(node) def find_successor(op_type: str, idx: int, search_limit: int = 1): for node in model.graph.node[idx + 1 : idx + 1 + search_limit]: @@ -203,10 +215,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0): continue shard_node(node) + new_input = [] + for info in model.graph.input: + new_input.append(vinfo[info.name]) + graph = helper.make_graph( nodes, model.graph.name + f"_{tp_rank}", - model.graph.input, + new_input, model.graph.output, data.values(), doc_string=model.graph.doc_string, diff --git a/include/core/graph.h b/include/core/graph.h index 3efd893f..8415b15a 100644 --- a/include/core/graph.h +++ b/include/core/graph.h @@ -120,6 +120,11 @@ class GraphObj : public Object { * @brief If the nodes is sorted in topological order. */ bool sorted; + + /** + * @brief If the weight tensors are allocated. 
+ */ + bool weightAllocated = false; }; } // namespace infini diff --git a/include/core/lazy_allocator.h b/include/core/lazy_allocator.h index 228639a3..5f073845 100644 --- a/include/core/lazy_allocator.h +++ b/include/core/lazy_allocator.h @@ -20,14 +20,23 @@ class LazyAllocator { Runtime runtime; - size_t used; + size_t used = 0; - size_t peak; + size_t peak = 0; + + size_t weightPeak = 0; size_t alignment; // pointer to the memory actually allocated - void *ptr; + void *ptr = nullptr; + + // pointer to the weight memory space + void *weightPtr = nullptr; + + // // a cache designed for a batch size that has already occurred + // std::unordered_map> + // batchsizeToTensorOffset; struct freeBlockInfo { size_t addr; @@ -57,12 +66,16 @@ class LazyAllocator { virtual ~LazyAllocator(); + void init(); + // function: simulate memory allocation // arguments: // size: size of memory block to be allocated // return: head address offset of the allocated memory block size_t alloc(size_t size); + size_t allocWeight(size_t size); + // function: simulate memory free // arguments: // addr: head address offset of memory block to be free @@ -73,6 +86,12 @@ class LazyAllocator { // return: pointer to the head address of the allocated memory void *getPtr(); + // void addCache(size_t batchsize, std::unordered_map); + + // std::unordered_map getCache(size_t batchsize); + + void *getWeightPtr(); + void info(); private: diff --git a/include/core/tensor.h b/include/core/tensor.h index 03e1b20c..edaa8655 100644 --- a/include/core/tensor.h +++ b/include/core/tensor.h @@ -1,5 +1,6 @@ #pragma once #include "core/tensor_base.h" +#include "core/tensor_type.h" #include "utils/data_convert.h" #include #include @@ -19,6 +20,8 @@ class TensorObj : public TensorBaseObj { size_t _size; // Cache of Π(shape). Fuid fuid; // Cloned tensors share the same id. Tensors constructed from // scratch have a new id. 
+ TensorType tensorType = TensorType::others; + public: TensorObj(Shape shape, DataType dtype, Runtime runtime); virtual ~TensorObj() {} @@ -33,6 +36,33 @@ class TensorObj : public TensorBaseObj { size_t getOffset(const vector &ds) const; void dataMalloc(); UidBaseType getFuid() const { return fuid; } + bool isWeight() const { return tensorType == TensorType::weight; } + bool isInput() const { return tensorType == TensorType::input; } + bool isOutput() const { return tensorType == TensorType::output; } + bool isOthers() const { return tensorType == TensorType::others; } + void setWeight() { tensorType = TensorType::weight; } + void setInput() { tensorType = TensorType::input; } + void setOutput() { tensorType = TensorType::output; } + string tensorTypeToString() const { + switch (tensorType) { + case TensorType::weight: + return "weight"; + break; + case TensorType::input: + return "input"; + break; + case TensorType::output: + return "output"; + break; + case TensorType::others: + return "others"; + break; + + default: + return "unknown tensor type"; + break; + } + } void load(std::string file_path); void save(std::string file_path); diff --git a/include/core/tensor_type.h b/include/core/tensor_type.h new file mode 100644 index 00000000..46df6073 --- /dev/null +++ b/include/core/tensor_type.h @@ -0,0 +1,7 @@ +#pragma once + +namespace infini { + +enum class TensorType { weight, input, output, others }; + +} // namespace infini diff --git a/pyinfinitensor/src/pyinfinitensor/onnx.py b/pyinfinitensor/src/pyinfinitensor/onnx.py index bade70f8..813a5e8e 100644 --- a/pyinfinitensor/src/pyinfinitensor/onnx.py +++ b/pyinfinitensor/src/pyinfinitensor/onnx.py @@ -32,24 +32,27 @@ class OnnxStub: The Onnx model imported into infinitensor. It can be generated from an Onnx model object. 
""" - - inputs: Dict[str, backend.Tensor] = {} - outputs: Dict[str, backend.Tensor] = {} - initializer: Dict[int, TensorProto] = {} - handler: backend.GraphHandler - def __init__(self, model: ModelProto, runtime): + self.inputs: Dict[str, backend.Tensor] = {} + self.outputs: Dict[str, backend.Tensor] = {} + self.initializer: Dict[int, TensorProto] = {} model = infer_shapes(model) self.handler = backend.GraphHandler(runtime) tensors: Dict[str, backend.Tensor] = dict() data: Dict[str, TensorProto] = dict() + for initializer in model.graph.initializer: + dims = [d for d in initializer.dims] + tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type) + data[initializer.name] = initializer + for input in model.graph.input: dims = _take_shape_dim(input.type.tensor_type.shape) - tensors[input.name] = self.handler.tensor( - dims, input.type.tensor_type.elem_type - ) + if input.name not in tensors.keys(): + tensors[input.name] = self.handler.tensor( + dims, input.type.tensor_type.elem_type + ) for output in model.graph.output: dims = _take_shape_dim(output.type.tensor_type.shape) @@ -57,10 +60,6 @@ class OnnxStub: dims, output.type.tensor_type.elem_type ) - for initializer in model.graph.initializer: - dims = [d for d in initializer.dims] - tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type) - data[initializer.name] = initializer node_name = [] new_node_name = [] @@ -667,6 +666,19 @@ class OnnxStub: # update the node_list node_list = list(set(node_name) - set(new_node_name)) + ################################ + # Set tensor type + ################################ + for initializer in model.graph.initializer: + tensors[initializer.name].set_weight() + + for input in model.graph.input: + tensors[input.name].set_input() + + for output in model.graph.output: + tensors[output.name].set_output() + + ################################ # Allocate memory space for data ################################ diff --git a/src/core/graph.cc b/src/core/graph.cc index 05f45fae..0f844c34 100644 --- a/src/core/graph.cc +++ b/src/core/graph.cc @@ -131,30 +131,63 @@ void GraphObj::dataMalloc() { // record the memory address offsets of all tensors to be allocated std::unordered_map tensorToOffset; - // record all constant tensors, including weight tensors and input tensors - std::unordered_set constTensor; + // reinit allocator + allocator.init(); + + // record all weight tensors, including weight tensors and kvcache + // tensors + std::unordered_set weightTensors; for (auto &tensor : tensors) { - if (tensor.get()->getSource() == nullptr) { - // allocate memory for all constant tensors first, and this memory + if (tensor->isWeight()) { + // allocate memory for all weight tensors first, and this memory + // will not be freed until the graph is destroyed + weightTensors.insert(tensor.get()); + if (!this->weightAllocated) { + tensorToOffset[tensor.get()] = + allocator.allocWeight(tensor->getBytes()); + } + } else if (tensor->isInput() || tensor->isOutput()) { + // allocate memory for all input and output tensors, and this memory // will not be reused later - constTensor.insert(tensor.get()); tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes()); } else { tensorToRefCount[tensor.get()] = tensor->getTargets().size(); + // allocate memory for all user-created tensors + if (tensor.get()->getSource() == nullptr) { + tensorToOffset[tensor.get()] = + allocator.alloc(tensor->getBytes()); + } + } + } + // if memory has not yet been allocated for weight tensors, + // allocate 
memory now and do not allocate again in the future. + if (!this->weightAllocated) { + this->weightAllocated = true; + // only allocate once for weight tensors + for (auto &tensor : weightTensors) { + IT_ASSERT(tensorToOffset.find(tensor) != tensorToOffset.end()); + tensor->setDataBlob(make_ref( + tensor->runtime, + static_cast(allocator.getWeightPtr()) + + tensorToOffset[tensor])); } } // traverse in topological order and simulate memory allocation for (auto &op : ops) { - // memory should be allocated for the output first + // memory should be allocated for the op's output first auto outputs = op->getOutputs(); for (auto &tensor : outputs) { - tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes()); + if (tensor->isOthers()) { + tensorToOffset[tensor.get()] = + allocator.alloc(tensor->getBytes()); + } } auto inputs = op->getInputs(); for (auto &tensor : inputs) { - if (constTensor.find(tensor.get()) == constTensor.end()) { + if (tensor->isOthers()) { auto tensorIter = tensorToRefCount.find(tensor.get()); IT_ASSERT(tensorIter != tensorToRefCount.end()); + IT_ASSERT(tensorToRefCount[tensor.get()] > 0); tensorToRefCount[tensor.get()] -= 1; if (tensorToRefCount[tensor.get()] == 0) { // indicate that this tensor will no longer be used and @@ -167,15 +200,20 @@ void GraphObj::dataMalloc() { } } - // perform actual memory allocation + // perform actual memory allocation for non-weight tensors for (auto &tensor : tensors) { - IT_ASSERT(tensorToOffset.find(tensor.get()) != tensorToOffset.end()); - tensor->setDataBlob(make_ref( - tensor->runtime, static_cast(allocator.getPtr()) + - tensorToOffset[tensor.get()])); + if (!tensor->isWeight()) { + IT_ASSERT(tensorToOffset.find(tensor.get()) != + tensorToOffset.end()); + tensor->setDataBlob(make_ref( + tensor->runtime, static_cast(allocator.getPtr()) + + tensorToOffset[tensor.get()])); + } } +#ifdef DEBUG_MODE allocator.info(); +#endif } Tensor GraphObj::addTensor(Shape dim, DataType dtype) { diff --git a/src/core/lazy_allocator.cc b/src/core/lazy_allocator.cc index bb7f766f..a5014e5c 100644 --- a/src/core/lazy_allocator.cc +++ b/src/core/lazy_allocator.cc @@ -11,9 +11,6 @@ namespace infini { constexpr size_t alignmentInBytesForCUDA = 256; LazyAllocator::LazyAllocator(Runtime runtime) : runtime(runtime) { - used = 0; - peak = 0; - ptr = nullptr; if (runtime->isCuda()) { // TODO: the alignment on cuda might need further discussion alignment = alignmentInBytesForCUDA; @@ -30,10 +27,24 @@ LazyAllocator::~LazyAllocator() { if (this->ptr != nullptr) { runtime->dealloc(this->ptr); } + if (this->weightPtr != nullptr) { + runtime->dealloc(this->weightPtr); + } +} + +void LazyAllocator::init() { + used = 0; + peak = 0; + freeBlocks.clear(); + headAddrToBlockSize.clear(); + tailAddrToBlockSize.clear(); + if (this->ptr != nullptr) { + runtime->dealloc(this->ptr); + } + this->ptr = nullptr; } size_t LazyAllocator::alloc(size_t size) { - IT_ASSERT(this->ptr == nullptr); // pad the size to the multiple of alignment size = this->getAlignedSize(size); auto it = this->freeBlocks.lower_bound(freeBlockInfo{(size_t)0, size}); @@ -83,6 +94,14 @@ size_t LazyAllocator::alloc(size_t size) { return retAddr; } +size_t LazyAllocator::allocWeight(size_t size) { + IT_ASSERT(this->weightPtr == nullptr); + size = this->getAlignedSize(size); + size_t retAddr = this->weightPeak; + this->weightPeak += size; + return retAddr; +} + void LazyAllocator::free(size_t addr, size_t size) { IT_ASSERT(this->ptr == nullptr); size = getAlignedSize(size); @@ -126,18 +145,33 @@ void 
LazyAllocator::free(size_t addr, size_t size) { void *LazyAllocator::getPtr() { if (this->ptr == nullptr) { this->ptr = runtime->alloc(this->peak); - printf("LazyAllocator really alloc: %p %lu bytes\n", this->ptr, peak); +#ifdef DEBUG_MODE + printf("LazyAllocator really alloc non-weight: %p %lu bytes\n", + this->ptr, peak); +#endif } return this->ptr; } +void *LazyAllocator::getWeightPtr() { + if (this->weightPtr == nullptr) { + this->weightPtr = runtime->alloc(this->weightPeak); +#ifdef DEBUG_MODE + printf("LazyAllocator really alloc weight: %p %lu bytes\n", + this->weightPtr, weightPeak); +#endif + } + return this->weightPtr; +} + size_t LazyAllocator::getAlignedSize(size_t size) { return ((size - 1) / this->alignment + 1) * this->alignment; } void LazyAllocator::info() { - std::cout << "Used memory: " << this->used - << ", peak memory: " << this->peak << std::endl; + std::cout << "Used memory: " << this->used + this->weightPeak + << ", peak memory: " << this->peak + this->weightPeak + << std::endl; } } // namespace infini diff --git a/src/core/tensor.cc b/src/core/tensor.cc index 2d786f7b..d318f014 100644 --- a/src/core/tensor.cc +++ b/src/core/tensor.cc @@ -23,7 +23,7 @@ string TensorObj::toString() const { string ret = "Tensor " + std::to_string(guid) + ", Fuid " + std::to_string(fuid) + ", shape " + vecToString(shape) + ", dtype " + dtype.toString() + ", " + runtime->toString() + - ", " + ss.str() + "\n"; + ", " + ss.str() + ", " + tensorTypeToString() + "\n"; vector targetGuids; for (const auto &op : targets) targetGuids.emplace_back(op.lock()->getGuid()); diff --git a/src/ffi/ffi_infinitensor.cc b/src/ffi/ffi_infinitensor.cc index 28b17bd4..bea3f4bc 100644 --- a/src/ffi/ffi_infinitensor.cc +++ b/src/ffi/ffi_infinitensor.cc @@ -364,6 +364,9 @@ void init_graph_builder(py::module &m) { py::buffer_protocol()) .def("fuid", &TensorObj::getFuid, policy::automatic) .def("shape", &TensorObj::getDims, policy::move) + .def("set_weight", &TensorObj::setWeight, policy::move) + .def("set_input", &TensorObj::setInput, policy::move) + .def("set_output", &TensorObj::setOutput, policy::move) .def("dtype", &TensorObj::getDTypeIndex, policy::automatic) .def("copyin_float", &TensorObj::copyin, policy::move) .def("copyin_int32", &TensorObj::copyin, policy::move)
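A condensed sketch of the prefill/decode hand-off that examples/distributed/launch_kvcache.py implements, using only calls that appear in this patch (OnnxStub, copyin_int32, copyin_float, copyout_float, run). The function name prefill_then_decode is hypothetical; the input names "onnx::Reshape_0" and "input.3", the logits-first output ordering, and the assumption that the decode graph's remaining inputs line up one-to-one with the prefill graph's KV outputs are properties of the exported example model, not of the API.

import numpy as np
import onnx
from pyinfinitensor.onnx import OnnxStub, backend


def prefill_then_decode(prefill_path, decode_path, prompt_ids, next_token_ids):
    # Two runtimes on the same device, mirroring start_single() in the example.
    runtime1 = backend.CudaRuntime(0)
    runtime2 = backend.CudaRuntime(0)

    # Prefill: run the graph without KV cache; collect logits and KV outputs.
    prefill = OnnxStub(onnx.load(prefill_path), runtime1)
    prefill.inputs["onnx::Reshape_0"].copyin_int32(prompt_ids.reshape(-1).tolist())
    prefill.run()
    outs = list(prefill.outputs.values())
    logits = np.array(outs[0].copyout_float(), dtype=np.float32)
    kv_cache = [np.array(t.copyout_float(), dtype=np.float32) for t in outs[1:]]

    # Decode: feed the new token, the past sequence length, and the cached
    # keys/values into the KV-cache graph.
    decode = OnnxStub(onnx.load(decode_path), runtime2)
    decode.inputs["onnx::Reshape_0"].copyin_int32(next_token_ids.reshape(-1).tolist())
    past_len = np.full((1, 1), prompt_ids.shape[-1], dtype=np.int32)
    decode.inputs["input.3"].copyin_int32(past_len.reshape(-1).tolist())
    # Assumption: the remaining decode inputs appear in the same order as the
    # prefill KV outputs (the example script relies on the same ordering).
    kv_inputs = [t for name, t in decode.inputs.items()
                 if name not in ("onnx::Reshape_0", "input.3")]
    for cached, tensor in zip(kv_cache, kv_inputs):
        tensor.copyin_float(cached.reshape(-1).tolist())
    decode.run()
    new_logits = np.array(list(decode.outputs.values())[0].copyout_float(),
                          dtype=np.float32)
    return logits, new_logits

The memory-management changes are what make the second graph cheap to re-plan: GraphObj::dataMalloc() relies on the TensorType tags that OnnxStub sets (set_weight / set_input / set_output), plans weights through LazyAllocator::allocWeight() exactly once (guarded by weightAllocated), and re-plans only the non-weight pool after allocator.init(), so re-running dataMalloc for a graph whose activation shapes change does not move or re-allocate the weight memory.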