Support kvcache (#134)

* add cmake bits about NCCL

* move example to examples/NNmodel

* impl NCCL communicator

* add comm related function to Runtime

* export runtime interface

* add launch.py

* use unique name to distingush the the NCCL ID file

* add timeout to communicator init

* expose communicator obj from runtime obj, add unit test for nccl communicator

* reformat files

* Add allReduce operator and cuda nccl allReduce kernel

* impl model parallel for resnet

* add allGather nccl kernel and operator

* Add allreduce allgather operator tests, change allgather kernel to output list of tensor, fix shape infer, handle nullptr output

* fix format of onnx.py

* use concat following AllGather

* get tensor parallel for resnet

* fix format of graph_handler.cc

* change BUILD_DIST default to OFF

* polish code of communicator

* update .gitignore

* export min/max to python

* fix MatMul

* modify launch.py to run opt

* hack to treat ReduceSum as AllReduceSum

* throw exception in cuda error

* fix parallel_opt.py

* improve the error prompt and cuda error check

* fix GatherObj::GatherObj member init

* fix size calculation for scalar (rank = 0) tensor

* MatMul supports bias

* fix add bias for row parallel gemm

* add --gen_std to launch.py

* fix AllReduceNCCL

* update launch.py

* less log

* update parallel_opt

* update launch.py

* add __eq__ for Placement sub-classes

* less benchmark run

* fix placement infer for matmul

* fix vacabuary size

* fix Exception

* Add shard tensor with group to support gpt2

* Add find successor function to find split op at different depth

* recover CommunicatorObj

* improve error mesasge

* optimize parallel_opt.py

* optimize launch.py

* recover docs for all_reduce and all_gather

* - support concat for kvcache

* - modify allocator

* - add tensorType
- modify allocator to support memory allocation based on tensorType

* - fix allocator init

* - support kvcache by running 2 stub distributively

* - fix name

* - remove unused flag

* - fix wrong pb name

* - fix as constroy suggessed

* - fix launch.py format

---------

Co-authored-by: constroy <constroy.li@gmail.com>
Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
This commit is contained in:
kilinchange 2023-09-18 14:17:02 +08:00 committed by GitHub
parent c6b82cfda0
commit 48ec730579
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 447 additions and 38 deletions

View File

@ -0,0 +1,245 @@
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
import numpy as np
from parallel_opt import parallel_model
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
def parse_args():
parser = argparse.ArgumentParser(description="launch distributed infinitensor")
parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
parser.add_argument(
"--nproc_per_node", type=int, default=1, help="number of processes per node"
)
parser.add_argument(
"--name", type=str, default="test", help="name of this instance."
)
parser.add_argument(
"--model1", type=str, required=True, help="path to the ONNX model file."
)
parser.add_argument(
"--model2", type=str, required=True, help="path to the ONNX model file."
)
parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
parser.add_argument("--length", type=int, default=1, help="sequence length.")
parser.add_argument(
"--gen_std",
action="store_true",
help="whether to generate the standard results.",
)
args = parser.parse_args()
print("arg setting: ", args)
return (
args.num_nodes,
args.nproc_per_node,
args.name,
args.model1,
args.model2,
args.batch_size,
args.length,
args.gen_std,
)
def run_model(model1, model2, runtime1, runtime2, inputs1: np.array, inputs2: np.array, n=20):
####################################
# run the first graph without kvcache
####################################
stub1 = OnnxStub(model1, runtime1)
stub1.inputs['onnx::Reshape_0'].copyin_int32(inputs1.reshape(-1).tolist())
stub1.tune()
stub1.run()
kvcache_it1 = []
count = 0
for output in stub1.outputs.items().__iter__():
if count == 0:
logits_it1 = np.array(output[1].copyout_float(), dtype=np.float32)
else:
kvcache_it1.append(np.array(output[1].copyout_float(), dtype=np.float32))
count = count + 1
# bench for stub1
next(stub1.inputs.items().__iter__())[1].copyin_int32(inputs1.reshape(-1).tolist())
begin = time.time()
for _ in range(n):
stub1.run()
end = time.time()
avg_time = (end - begin) / n
print(f"stub1 average time: {avg_time}")
####################################
# run the second graph with kvcache
####################################
i = 0
batchsize = 1
stub2 = OnnxStub(model2, runtime2)
past_kvcache_length = (i+2)*np.ones((batchsize, 1), dtype=np.int32)
# copyin input
stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist())
stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist())
count = -1
for input in stub2.inputs.items().__iter__():
if count in range(24):
# print(count, input[0])
# print(np.dtype(kvcache_it1[count][0]), kvcache_it1[count].shape)
input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist())
count = count + 1
stub2.tune()
stub2.run()
# copyout output
count = 0
kvcache_it2 = []
for output in stub2.outputs.items().__iter__():
if count == 0:
logits_it2 = np.array(output[1].copyout_float(), dtype=np.float32)
else:
kvcache_it2.append(np.array(output[1].copyout_float(), dtype=np.float32))
count = count + 1
# bench for stub2
# copyin input
stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist())
stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist())
count = -1
for input in stub2.inputs.items().__iter__():
if count in range(24):
input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist())
count = count + 1
begin = time.time()
for _ in range(n):
stub2.run()
end = time.time()
avg_time = (end - begin) / n
print(f"stub2 average time: {avg_time}")
return logits_it2
def run_and_compare(name, model1, model2, runtime1, runtime2):
data1 = np.load(f"{name}_inputs1.npy")
data2 = np.load(f"{name}_inputs2.npy")
results = np.load(f"{name}_results.npy")
outputs = run_model(model1, model2, runtime1, runtime2, data1, data2)
print("outputs sum:", outputs.sum())
print("max abs diff:", abs(outputs - results).max())
print("max rel diff:", abs((outputs - results) / results).max())
# assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)
def start_worker(
name: str, world_size: int, rank: int, local_rank: int, model1: onnx.ModelProto, model2: onnx.ModelProto
):
dist_name = name + "_dist"
####################################
# shard the first graph
####################################
model1 = parallel_model(model1, world_size, rank)
extern_path = f"./{dist_name}_stub1_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
model1,
all_tensors_to_one_file=True,
location=extern_path,
size_threshold=1024,
convert_attribute=False,
)
onnx.save(model1, f"./{dist_name}_stub1_rank{rank}.onnx")
runtime1 = backend.CudaRuntime(local_rank)
runtime1.init_comm(
dist_name,
world_size,
rank,
)
####################################
# shard the second graph
####################################
model2 = parallel_model(model2, world_size, rank)
extern_path = f"./{dist_name}_stub2_rank{rank}.pb"
if os.path.exists(extern_path):
os.remove(extern_path)
convert_model_to_external_data(
model2,
all_tensors_to_one_file=True,
location=extern_path,
size_threshold=1024,
convert_attribute=False,
)
onnx.save(model2, f"./{dist_name}_stub2_rank{rank}.onnx")
runtime2 = backend.CudaRuntime(local_rank)
# print("init comm")
runtime2.init_comm(
dist_name,
world_size,
rank,
)
# run the two graphs
run_and_compare(name, model1, model2, runtime1, runtime2)
def start_single(name, model1, model2):
runtime1 = backend.CudaRuntime(0)
runtime2 = backend.CudaRuntime(0)
run_and_compare(name, model1, model2, runtime1, runtime2)
def gen_standard(name, model1, model2, voc_size, bs, len):
# generate standard results
data1 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
data2 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
np.save(f"{name}_inputs1", data1)
np.save(f"{name}_inputs2", data2)
runtime1 = backend.CudaRuntime(0)
runtime2 = backend.CudaRuntime(0)
outputs = run_model(model1, model2, runtime1, runtime2, data1, data2, 1)
np.save(f"{name}_results", outputs)
def main():
nnodes, nproc_per_node, name, model1_path, model2_path, bs, length, gen_std = parse_args()
model1 = onnx.load(model1_path)
model2 = onnx.load(model2_path)
# generate standart output
if gen_std:
print(f"generate standard data for {name}.")
# a small vocabulary size to fit all LLM.
voc_size = 1000
gen_standard(name, model1, model2, voc_size, bs, length)
return
# run single process.
# use standalone process to isolate cuda.
p = mp.Process(target=start_single, args=(name, model1, model2))
p.start()
p.join()
# run distributed parallel.
world_size = nnodes * nproc_per_node
workers = [
mp.Process(
target=start_worker,
args=(name, world_size, rank, rank % nproc_per_node, model1, model2),
)
for rank in range(world_size)
]
for w in workers:
w.start()
for w in workers:
w.join()
if __name__ == "__main__":
main()

View File

@ -57,6 +57,16 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
place[node.output[0]] = out_plc
def shard_concat(node: NodeProto):
# hack for kvcache
in_plc = place[node.input[1]]
if in_plc.is_sharded():
seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1)
seq_len_dim.dim_value //= tp_world_size
vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim)
place[node.input[0]] = in_plc
place[node.output[0]] = in_plc
def shard_binary(node: NodeProto, groups: int = 1):
# print("binary", node.name, node.input[0], place[node.input[0]])
a = node.input[0]
@ -143,6 +153,8 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
place[node.input[0]] == place[node.input[1]]
), f"{place[node.input[0]]} != {place[node.input[1]]}"
place[node.output[0]] = place[node.input[0]]
elif node.op_type == "Concat":
shard_concat(node)
def find_successor(op_type: str, idx: int, search_limit: int = 1):
for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:
@ -203,10 +215,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
continue
shard_node(node)
new_input = []
for info in model.graph.input:
new_input.append(vinfo[info.name])
graph = helper.make_graph(
nodes,
model.graph.name + f"_{tp_rank}",
model.graph.input,
new_input,
model.graph.output,
data.values(),
doc_string=model.graph.doc_string,

View File

@ -120,6 +120,11 @@ class GraphObj : public Object {
* @brief If the nodes is sorted in topological order.
*/
bool sorted;
/**
* @brief If the weight tensors are allocated.
*/
bool weightAllocated = false;
};
} // namespace infini

View File

@ -20,14 +20,23 @@ class LazyAllocator {
Runtime runtime;
size_t used;
size_t used = 0;
size_t peak;
size_t peak = 0;
size_t weightPeak = 0;
size_t alignment;
// pointer to the memory actually allocated
void *ptr;
void *ptr = nullptr;
// pointer to the weight memory space
void *weightPtr = nullptr;
// // a cache designed for a batch size that has already occurred
// std::unordered_map<size_t, std::unordered_map<TensorObj *, size_t>>
// batchsizeToTensorOffset;
struct freeBlockInfo {
size_t addr;
@ -57,12 +66,16 @@ class LazyAllocator {
virtual ~LazyAllocator();
void init();
// function: simulate memory allocation
// arguments
// size: size of memory block to be allocated
// return: head address offset of the allocated memory block
size_t alloc(size_t size);
size_t allocWeight(size_t size);
// function: simulate memory free
// arguments:
// addr: head address offset of memory block to be free
@ -73,6 +86,12 @@ class LazyAllocator {
// return: pointer to the head address of the allocated memory
void *getPtr();
// void addCache(size_t batchsize, std::unordered_map<TensorObj *, size_t>);
// std::unordered_map<TensorObj *, size_t> getCache(size_t batchsize);
void *getWeightPtr();
void info();
private:

View File

@ -1,5 +1,6 @@
#pragma once
#include "core/tensor_base.h"
#include "core/tensor_type.h"
#include "utils/data_convert.h"
#include <cmath>
#include <cstring>
@ -19,6 +20,8 @@ class TensorObj : public TensorBaseObj {
size_t _size; // Cache of Π(shape).
Fuid fuid; // Cloned tensors share the same id. Tensors constructed from
// scratch have a new id.
TensorType tensorType = TensorType::others;
public:
TensorObj(Shape shape, DataType dtype, Runtime runtime);
virtual ~TensorObj() {}
@ -33,6 +36,33 @@ class TensorObj : public TensorBaseObj {
size_t getOffset(const vector<int> &ds) const;
void dataMalloc();
UidBaseType getFuid() const { return fuid; }
bool isWeight() const { return tensorType == TensorType::weight; }
bool isInput() const { return tensorType == TensorType::input; }
bool isOutput() const { return tensorType == TensorType::output; }
bool isOthers() const { return tensorType == TensorType::others; }
void setWeight() { tensorType = TensorType::weight; }
void setInput() { tensorType = TensorType::input; }
void setOutput() { tensorType = TensorType::output; }
string tensorTypeToString() const {
switch (tensorType) {
case TensorType::weight:
return "weight";
break;
case TensorType::input:
return "input";
break;
case TensorType::output:
return "output";
break;
case TensorType::others:
return "others";
break;
default:
return "unknown tensor type";
break;
}
}
void load(std::string file_path);
void save(std::string file_path);

View File

@ -0,0 +1,7 @@
#pragma once
namespace infini {
enum class TensorType { weight, input, output, others };
} // namespace infini

View File

@ -32,21 +32,24 @@ class OnnxStub:
The Onnx model imported into infinitensor.
It can be generated from an Onnx model object.
"""
inputs: Dict[str, backend.Tensor] = {}
outputs: Dict[str, backend.Tensor] = {}
initializer: Dict[int, TensorProto] = {}
handler: backend.GraphHandler
def __init__(self, model: ModelProto, runtime):
self.inputs: Dict[str, backend.Tensor] = {}
self.outputs: Dict[str, backend.Tensor] = {}
self.initializer: Dict[int, TensorProto] = {}
model = infer_shapes(model)
self.handler = backend.GraphHandler(runtime)
tensors: Dict[str, backend.Tensor] = dict()
data: Dict[str, TensorProto] = dict()
for initializer in model.graph.initializer:
dims = [d for d in initializer.dims]
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
data[initializer.name] = initializer
for input in model.graph.input:
dims = _take_shape_dim(input.type.tensor_type.shape)
if input.name not in tensors.keys():
tensors[input.name] = self.handler.tensor(
dims, input.type.tensor_type.elem_type
)
@ -57,10 +60,6 @@ class OnnxStub:
dims, output.type.tensor_type.elem_type
)
for initializer in model.graph.initializer:
dims = [d for d in initializer.dims]
tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
data[initializer.name] = initializer
node_name = []
new_node_name = []
@ -667,6 +666,19 @@ class OnnxStub:
# update the node_list
node_list = list(set(node_name) - set(new_node_name))
################################
# Set tensor type
################################
for initializer in model.graph.initializer:
tensors[initializer.name].set_weight()
for input in model.graph.input:
tensors[input.name].set_input()
for output in model.graph.output:
tensors[output.name].set_output()
################################
# Allocate memory space for data
################################

View File

@ -131,30 +131,63 @@ void GraphObj::dataMalloc() {
// record the memory address offsets of all tensors to be allocated
std::unordered_map<TensorObj *, size_t> tensorToOffset;
// record all constant tensors, including weight tensors and input tensors
std::unordered_set<TensorObj *> constTensor;
// reinit allocator
allocator.init();
// record all weight tensors, including weight tensors and kvcache
// tensors
std::unordered_set<TensorObj *> weightTensors;
for (auto &tensor : tensors) {
if (tensor.get()->getSource() == nullptr) {
// allocate memory for all constant tensors first, and this memory
if (tensor->isWeight()) {
// allocate memory for all weight tensors first, and this memory
// will not be freed until the graph is destroyed
weightTensors.insert(tensor.get());
if (!this->weightAllocated) {
tensorToOffset[tensor.get()] =
allocator.allocWeight(tensor->getBytes());
}
} else if (tensor->isInput() || tensor->isOutput()) {
// allocate memory for all input and output tensors, and this memory
// will not be reused later
constTensor.insert(tensor.get());
tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes());
} else {
tensorToRefCount[tensor.get()] = tensor->getTargets().size();
// allocate memory for all user-created tensors
if (tensor.get()->getSource() == nullptr) {
tensorToOffset[tensor.get()] =
allocator.alloc(tensor->getBytes());
}
}
}
// if memory has not yet been allocated for weight tensors,
// allocate memory now and do not allocate again in the future.
if (!this->weightAllocated) {
this->weightAllocated = true;
// only allocate once for weight tensors
for (auto &tensor : weightTensors) {
IT_ASSERT(tensorToOffset.find(tensor) != tensorToOffset.end());
tensor->setDataBlob(make_ref<BlobObj>(
tensor->runtime,
static_cast<uint8_t *>(allocator.getWeightPtr()) +
tensorToOffset[tensor]));
}
}
// traverse in topological order and simulate memory allocation
for (auto &op : ops) {
// memory should be allocated for the output first
// memory should be allocated for the op's output first
auto outputs = op->getOutputs();
for (auto &tensor : outputs) {
tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes());
if (tensor->isOthers()) {
tensorToOffset[tensor.get()] =
allocator.alloc(tensor->getBytes());
}
}
auto inputs = op->getInputs();
for (auto &tensor : inputs) {
if (constTensor.find(tensor.get()) == constTensor.end()) {
if (tensor->isOthers()) {
auto tensorIter = tensorToRefCount.find(tensor.get());
IT_ASSERT(tensorIter != tensorToRefCount.end());
IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
tensorToRefCount[tensor.get()] -= 1;
if (tensorToRefCount[tensor.get()] == 0) {
// indicate that this tensor will no longer be used and
@ -167,15 +200,20 @@ void GraphObj::dataMalloc() {
}
}
// perform actual memory allocation
// perform actual memory allocation for non-weight tensors
for (auto &tensor : tensors) {
IT_ASSERT(tensorToOffset.find(tensor.get()) != tensorToOffset.end());
if (!tensor->isWeight()) {
IT_ASSERT(tensorToOffset.find(tensor.get()) !=
tensorToOffset.end());
tensor->setDataBlob(make_ref<BlobObj>(
tensor->runtime, static_cast<uint8_t *>(allocator.getPtr()) +
tensorToOffset[tensor.get()]));
}
}
#ifdef DEBUG_MODE
allocator.info();
#endif
}
Tensor GraphObj::addTensor(Shape dim, DataType dtype) {

View File

@ -11,9 +11,6 @@ namespace infini {
constexpr size_t alignmentInBytesForCUDA = 256;
LazyAllocator::LazyAllocator(Runtime runtime) : runtime(runtime) {
used = 0;
peak = 0;
ptr = nullptr;
if (runtime->isCuda()) {
// TODO: the alignment on cuda might need further discussion
alignment = alignmentInBytesForCUDA;
@ -30,10 +27,24 @@ LazyAllocator::~LazyAllocator() {
if (this->ptr != nullptr) {
runtime->dealloc(this->ptr);
}
if (this->weightPtr != nullptr) {
runtime->dealloc(this->weightPtr);
}
}
void LazyAllocator::init() {
used = 0;
peak = 0;
freeBlocks.clear();
headAddrToBlockSize.clear();
tailAddrToBlockSize.clear();
if (this->ptr != nullptr) {
runtime->dealloc(this->ptr);
}
this->ptr = nullptr;
}
size_t LazyAllocator::alloc(size_t size) {
IT_ASSERT(this->ptr == nullptr);
// pad the size to the multiple of alignment
size = this->getAlignedSize(size);
auto it = this->freeBlocks.lower_bound(freeBlockInfo{(size_t)0, size});
@ -83,6 +94,14 @@ size_t LazyAllocator::alloc(size_t size) {
return retAddr;
}
size_t LazyAllocator::allocWeight(size_t size) {
IT_ASSERT(this->weightPtr == nullptr);
size = this->getAlignedSize(size);
size_t retAddr = this->weightPeak;
this->weightPeak += size;
return retAddr;
}
void LazyAllocator::free(size_t addr, size_t size) {
IT_ASSERT(this->ptr == nullptr);
size = getAlignedSize(size);
@ -126,18 +145,33 @@ void LazyAllocator::free(size_t addr, size_t size) {
void *LazyAllocator::getPtr() {
if (this->ptr == nullptr) {
this->ptr = runtime->alloc(this->peak);
printf("LazyAllocator really alloc: %p %lu bytes\n", this->ptr, peak);
#ifdef DEBUG_MODE
printf("LazyAllocator really alloc non-weight: %p %lu bytes\n",
this->ptr, peak);
#endif
}
return this->ptr;
}
void *LazyAllocator::getWeightPtr() {
if (this->weightPtr == nullptr) {
this->weightPtr = runtime->alloc(this->weightPeak);
#ifdef DEBUG_MODE
printf("LazyAllocator really alloc weight: %p %lu bytes\n",
this->weightPtr, weightPeak);
#endif
}
return this->weightPtr;
}
size_t LazyAllocator::getAlignedSize(size_t size) {
return ((size - 1) / this->alignment + 1) * this->alignment;
}
void LazyAllocator::info() {
std::cout << "Used memory: " << this->used
<< ", peak memory: " << this->peak << std::endl;
std::cout << "Used memory: " << this->used + this->weightPeak
<< ", peak memory: " << this->peak + this->weightPeak
<< std::endl;
}
} // namespace infini

View File

@ -23,7 +23,7 @@ string TensorObj::toString() const {
string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
std::to_string(fuid) + ", shape " + vecToString(shape) +
", dtype " + dtype.toString() + ", " + runtime->toString() +
", " + ss.str() + "\n";
", " + ss.str() + ", " + tensorTypeToString() + "\n";
vector<UidBaseType> targetGuids;
for (const auto &op : targets)
targetGuids.emplace_back(op.lock()->getGuid());

View File

@ -364,6 +364,9 @@ void init_graph_builder(py::module &m) {
py::buffer_protocol())
.def("fuid", &TensorObj::getFuid, policy::automatic)
.def("shape", &TensorObj::getDims, policy::move)
.def("set_weight", &TensorObj::setWeight, policy::move)
.def("set_input", &TensorObj::setInput, policy::move)
.def("set_output", &TensorObj::setOutput, policy::move)
.def("dtype", &TensorObj::getDTypeIndex, policy::automatic)
.def("copyin_float", &TensorObj::copyin<float>, policy::move)
.def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)