Support kvcache (#134)

* add cmake bits about NCCL

* move example to examples/NNmodel

* impl NCCL communicator

* add comm related function to Runtime

* export runtime interface

* add launch.py

* use unique name to distinguish the NCCL ID file

* add timeout to communicator init

* expose communicator obj from runtime obj, add unit test for nccl communicator

* reformat files

* Add allReduce operator and cuda nccl allReduce kernel

* impl model parallel for resnet

* add allGather nccl kernel and operator

* Add allreduce/allgather operator tests, change allgather kernel to output a list of tensors, fix shape infer, handle nullptr output

* fix format of onnx.py

* use concat following AllGather

* get tensor parallel for resnet

* fix format of graph_handler.cc

* change BUILD_DIST default to OFF

* polish code of communicator

* update .gitignore

* export min/max to python

* fix MatMul

* modify launch.py to run opt

* hack to treat ReduceSum as AllReduceSum

* throw exception in cuda error

* fix parallel_opt.py

* improve the error prompt and cuda error check

* fix GatherObj::GatherObj member init

* fix size calculation for scalar (rank = 0) tensor

* MatMul supports bias

* fix add bias for row parallel gemm

* add --gen_std to launch.py

* fix AllReduceNCCL

* update launch.py

* less log

* update parallel_opt

* update launch.py

* add __eq__ for Placement sub-classes

* less benchmark run

* fix placement infer for matmul

* fix vocabulary size

* fix Exception

* Add shard tensor with group to support gpt2

* Add find successor function to find split op at different depths

* recover CommunicatorObj

* improve error message

* optimize parallel_opt.py

* optimize launch.py

* recover docs for all_reduce and all_gather

* - support concat for kvcache

* - modify allocator

* - add tensorType
- modify allocator to support memory allocation based on tensorType

* - fix allocator init

* - support kvcache by running 2 stub distributively

* - fix name

* - remove unused flag

* - fix wrong pb name

* - fix as constroy suggested

* - fix launch.py format

---------

Co-authored-by: constroy <constroy.li@gmail.com>
Co-authored-by: panzezhong <panzezhong@qiyuanlab.com>
kilinchange 2023-09-18 14:17:02 +08:00 committed by GitHub
parent c6b82cfda0
commit 48ec730579
11 changed files with 447 additions and 38 deletions

View File

@@ -0,0 +1,245 @@
import argparse
import os
import time
import multiprocessing as mp
from pyinfinitensor.onnx import OnnxStub, backend
import onnx
from onnx.external_data_helper import convert_model_to_external_data
import numpy as np
from parallel_opt import parallel_model

os.environ["NVIDIA_TF32_OVERRIDE"] = "0"


def parse_args():
    parser = argparse.ArgumentParser(description="launch distributed infinitensor")
    parser.add_argument("--num_nodes", type=int, default=1, help="number of nodes")
    parser.add_argument(
        "--nproc_per_node", type=int, default=1, help="number of processes per node"
    )
    parser.add_argument(
        "--name", type=str, default="test", help="name of this instance."
    )
    parser.add_argument(
        "--model1", type=str, required=True, help="path to the ONNX model file."
    )
    parser.add_argument(
        "--model2", type=str, required=True, help="path to the ONNX model file."
    )
    parser.add_argument("--batch_size", type=int, default=1, help="batch size.")
    parser.add_argument("--length", type=int, default=1, help="sequence length.")
    parser.add_argument(
        "--gen_std",
        action="store_true",
        help="whether to generate the standard results.",
    )
    args = parser.parse_args()
    print("arg setting: ", args)
    return (
        args.num_nodes,
        args.nproc_per_node,
        args.name,
        args.model1,
        args.model2,
        args.batch_size,
        args.length,
        args.gen_std,
    )


def run_model(model1, model2, runtime1, runtime2, inputs1: np.array, inputs2: np.array, n=20):
    ####################################
    # run the first graph without kvcache
    ####################################
    stub1 = OnnxStub(model1, runtime1)
    stub1.inputs['onnx::Reshape_0'].copyin_int32(inputs1.reshape(-1).tolist())
    stub1.tune()
    stub1.run()
    kvcache_it1 = []
    count = 0
    for output in stub1.outputs.items().__iter__():
        if count == 0:
            logits_it1 = np.array(output[1].copyout_float(), dtype=np.float32)
        else:
            kvcache_it1.append(np.array(output[1].copyout_float(), dtype=np.float32))
        count = count + 1

    # bench for stub1
    next(stub1.inputs.items().__iter__())[1].copyin_int32(inputs1.reshape(-1).tolist())
    begin = time.time()
    for _ in range(n):
        stub1.run()
    end = time.time()
    avg_time = (end - begin) / n
    print(f"stub1 average time: {avg_time}")

    ####################################
    # run the second graph with kvcache
    ####################################
    i = 0
    batchsize = 1
    stub2 = OnnxStub(model2, runtime2)
    past_kvcache_length = (i + 2) * np.ones((batchsize, 1), dtype=np.int32)
    # copyin input
    stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist())
    stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist())
    count = -1
    for input in stub2.inputs.items().__iter__():
        if count in range(24):
            # print(count, input[0])
            # print(np.dtype(kvcache_it1[count][0]), kvcache_it1[count].shape)
            input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist())
        count = count + 1
    stub2.tune()
    stub2.run()

    # copyout output
    count = 0
    kvcache_it2 = []
    for output in stub2.outputs.items().__iter__():
        if count == 0:
            logits_it2 = np.array(output[1].copyout_float(), dtype=np.float32)
        else:
            kvcache_it2.append(np.array(output[1].copyout_float(), dtype=np.float32))
        count = count + 1

    # bench for stub2
    # copyin input
    stub2.inputs['onnx::Reshape_0'].copyin_int32(inputs2.reshape(-1).tolist())
    stub2.inputs['input.3'].copyin_int32(past_kvcache_length.reshape(-1).tolist())
    count = -1
    for input in stub2.inputs.items().__iter__():
        if count in range(24):
            input[1].copyin_float(kvcache_it1[count].reshape(-1).tolist())
        count = count + 1
    begin = time.time()
    for _ in range(n):
        stub2.run()
    end = time.time()
    avg_time = (end - begin) / n
    print(f"stub2 average time: {avg_time}")
    return logits_it2


def run_and_compare(name, model1, model2, runtime1, runtime2):
    data1 = np.load(f"{name}_inputs1.npy")
    data2 = np.load(f"{name}_inputs2.npy")
    results = np.load(f"{name}_results.npy")
    outputs = run_model(model1, model2, runtime1, runtime2, data1, data2)
    print("outputs sum:", outputs.sum())
    print("max abs diff:", abs(outputs - results).max())
    print("max rel diff:", abs((outputs - results) / results).max())
    # assert np.allclose(outputs, results, rtol=1e-3, atol=1e-6)


def start_worker(
    name: str, world_size: int, rank: int, local_rank: int, model1: onnx.ModelProto, model2: onnx.ModelProto
):
    dist_name = name + "_dist"

    ####################################
    # shard the first graph
    ####################################
    model1 = parallel_model(model1, world_size, rank)
    extern_path = f"./{dist_name}_stub1_rank{rank}.pb"
    if os.path.exists(extern_path):
        os.remove(extern_path)
    convert_model_to_external_data(
        model1,
        all_tensors_to_one_file=True,
        location=extern_path,
        size_threshold=1024,
        convert_attribute=False,
    )
    onnx.save(model1, f"./{dist_name}_stub1_rank{rank}.onnx")
    runtime1 = backend.CudaRuntime(local_rank)
    runtime1.init_comm(
        dist_name,
        world_size,
        rank,
    )

    ####################################
    # shard the second graph
    ####################################
    model2 = parallel_model(model2, world_size, rank)
    extern_path = f"./{dist_name}_stub2_rank{rank}.pb"
    if os.path.exists(extern_path):
        os.remove(extern_path)
    convert_model_to_external_data(
        model2,
        all_tensors_to_one_file=True,
        location=extern_path,
        size_threshold=1024,
        convert_attribute=False,
    )
    onnx.save(model2, f"./{dist_name}_stub2_rank{rank}.onnx")
    runtime2 = backend.CudaRuntime(local_rank)
    # print("init comm")
    runtime2.init_comm(
        dist_name,
        world_size,
        rank,
    )

    # run the two graphs
    run_and_compare(name, model1, model2, runtime1, runtime2)


def start_single(name, model1, model2):
    runtime1 = backend.CudaRuntime(0)
    runtime2 = backend.CudaRuntime(0)
    run_and_compare(name, model1, model2, runtime1, runtime2)


def gen_standard(name, model1, model2, voc_size, bs, len):
    # generate standard results
    data1 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
    data2 = np.random.randint(0, voc_size, (bs, len), dtype=np.int32)
    np.save(f"{name}_inputs1", data1)
    np.save(f"{name}_inputs2", data2)
    runtime1 = backend.CudaRuntime(0)
    runtime2 = backend.CudaRuntime(0)
    outputs = run_model(model1, model2, runtime1, runtime2, data1, data2, 1)
    np.save(f"{name}_results", outputs)


def main():
    nnodes, nproc_per_node, name, model1_path, model2_path, bs, length, gen_std = parse_args()
    model1 = onnx.load(model1_path)
    model2 = onnx.load(model2_path)

    # generate standard output
    if gen_std:
        print(f"generate standard data for {name}.")
        # a small vocabulary size to fit all LLMs.
        voc_size = 1000
        gen_standard(name, model1, model2, voc_size, bs, length)
        return

    # run single process.
    # use standalone process to isolate cuda.
    p = mp.Process(target=start_single, args=(name, model1, model2))
    p.start()
    p.join()

    # run distributed parallel.
    world_size = nnodes * nproc_per_node
    workers = [
        mp.Process(
            target=start_worker,
            args=(name, world_size, rank, rank % nproc_per_node, model1, model2),
        )
        for rank in range(world_size)
    ]
    for w in workers:
        w.start()
    for w in workers:
        w.join()


if __name__ == "__main__":
    main()
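
A brief usage sketch of the script above; the model file names are placeholders, the flags follow parse_args():

# Hypothetical two-step invocation (the ONNX paths are placeholders, not files from this PR):
import subprocess
cmd = ["python", "launch.py", "--model1", "./gpt2_prefill.onnx",
       "--model2", "./gpt2_decode.onnx"]
# step 1: write {name}_inputs1.npy / {name}_inputs2.npy / {name}_results.npy (default --name is "test")
subprocess.run(cmd + ["--gen_std"], check=True)
# step 2: single-process check followed by a 2-way tensor-parallel run; run_and_compare()
# reloads the saved .npy files and reports the max abs/rel difference per rank
subprocess.run(cmd + ["--nproc_per_node", "2"], check=True)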

View File

@@ -56,6 +56,16 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
         ndim = len(vinfo[output].type.tensor_type.shape.dim)
         out_plc = Shard(ndim - 1) if in_plc.is_replicate() else _Partial()
         place[node.output[0]] = out_plc

+    def shard_concat(node: NodeProto):
+        # hack for kvcache
+        in_plc = place[node.input[1]]
+        if in_plc.is_sharded():
+            seq_len_dim = vinfo[node.input[0]].type.tensor_type.shape.dim.pop(1)
+            seq_len_dim.dim_value //= tp_world_size
+            vinfo[node.input[0]].type.tensor_type.shape.dim.insert(1, seq_len_dim)
+            place[node.input[0]] = in_plc
+        place[node.output[0]] = in_plc
+
     def shard_binary(node: NodeProto, groups: int = 1):
         # print("binary", node.name, node.input[0], place[node.input[0]])

@@ -143,6 +153,8 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
                 place[node.input[0]] == place[node.input[1]]
             ), f"{place[node.input[0]]} != {place[node.input[1]]}"
             place[node.output[0]] = place[node.input[0]]
+        elif node.op_type == "Concat":
+            shard_concat(node)

     def find_successor(op_type: str, idx: int, search_limit: int = 1):
         for node in model.graph.node[idx + 1 : idx + 1 + search_limit]:

@@ -203,10 +215,14 @@ def parallel_model(model: ModelProto, tp_world_size: int = 1, tp_rank: int = 0):
             continue
         shard_node(node)

+    new_input = []
+    for info in model.graph.input:
+        new_input.append(vinfo[info.name])
+
     graph = helper.make_graph(
         nodes,
         model.graph.name + f"_{tp_rank}",
-        model.graph.input,
+        new_input,
         model.graph.output,
         data.values(),
         doc_string=model.graph.doc_string,
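
The dim-1 rewrite that shard_concat performs can be reproduced on a standalone value_info; the 4-D shape below is an assumed kvcache layout for illustration, not taken from the model:

# Standalone illustration of the seq-len adjustment done by shard_concat().
# The (batch, seq_len, heads, head_dim) layout is an assumption for this example.
from onnx import helper, TensorProto

tp_world_size = 2
vi = helper.make_tensor_value_info("past_kv", TensorProto.FLOAT, [1, 8, 12, 64])
seq_len_dim = vi.type.tensor_type.shape.dim.pop(1)
seq_len_dim.dim_value //= tp_world_size
vi.type.tensor_type.shape.dim.insert(1, seq_len_dim)
print([d.dim_value for d in vi.type.tensor_type.shape.dim])  # [1, 4, 12, 64]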

View File

@@ -120,6 +120,11 @@ class GraphObj : public Object {
      * @brief If the nodes is sorted in topological order.
      */
     bool sorted;
+
+    /**
+     * @brief If the weight tensors are allocated.
+     */
+    bool weightAllocated = false;
 };

 } // namespace infini

View File

@@ -20,14 +20,23 @@ class LazyAllocator {
     Runtime runtime;

-    size_t used;
-    size_t peak;
+    size_t used = 0;
+    size_t peak = 0;
+    size_t weightPeak = 0;

     size_t alignment;

     // pointer to the memory actually allocated
-    void *ptr;
+    void *ptr = nullptr;
+    // pointer to the weight memory space
+    void *weightPtr = nullptr;
+
+    // // a cache designed for a batch size that has already occurred
+    // std::unordered_map<size_t, std::unordered_map<TensorObj *, size_t>>
+    //     batchsizeToTensorOffset;

     struct freeBlockInfo {
         size_t addr;

@@ -57,12 +66,16 @@ class LazyAllocator {
     virtual ~LazyAllocator();

+    void init();
+
     // function: simulate memory allocation
     // arguments
     //     size: size of memory block to be allocated
     // return: head address offset of the allocated memory block
     size_t alloc(size_t size);

+    size_t allocWeight(size_t size);
+
     // function: simulate memory free
     // arguments:
     //     addr: head address offset of memory block to be free

@@ -73,6 +86,12 @@ class LazyAllocator {
     // return: pointer to the head address of the allocated memory
     void *getPtr();

+    // void addCache(size_t batchsize, std::unordered_map<TensorObj *, size_t>);
+
+    // std::unordered_map<TensorObj *, size_t> getCache(size_t batchsize);
+
+    void *getWeightPtr();
+
     void info();

   private:

View File

@@ -1,5 +1,6 @@
 #pragma once
 #include "core/tensor_base.h"
+#include "core/tensor_type.h"
 #include "utils/data_convert.h"
 #include <cmath>
 #include <cstring>

@@ -19,6 +20,8 @@ class TensorObj : public TensorBaseObj {
     size_t _size; // Cache of Π(shape).
     Fuid fuid;    // Cloned tensors share the same id. Tensors constructed from
                   // scratch have a new id.
+    TensorType tensorType = TensorType::others;

   public:
     TensorObj(Shape shape, DataType dtype, Runtime runtime);
     virtual ~TensorObj() {}

@@ -33,6 +36,33 @@ class TensorObj : public TensorBaseObj {
     size_t getOffset(const vector<int> &ds) const;
     void dataMalloc();
     UidBaseType getFuid() const { return fuid; }
+    bool isWeight() const { return tensorType == TensorType::weight; }
+    bool isInput() const { return tensorType == TensorType::input; }
+    bool isOutput() const { return tensorType == TensorType::output; }
+    bool isOthers() const { return tensorType == TensorType::others; }
+    void setWeight() { tensorType = TensorType::weight; }
+    void setInput() { tensorType = TensorType::input; }
+    void setOutput() { tensorType = TensorType::output; }
+    string tensorTypeToString() const {
+        switch (tensorType) {
+        case TensorType::weight:
+            return "weight";
+            break;
+        case TensorType::input:
+            return "input";
+            break;
+        case TensorType::output:
+            return "output";
+            break;
+        case TensorType::others:
+            return "others";
+            break;
+        default:
+            return "unknown tensor type";
+            break;
+        }
+    }

     void load(std::string file_path);
     void save(std::string file_path);

View File

@@ -0,0 +1,7 @@
#pragma once

namespace infini {

enum class TensorType { weight, input, output, others };

} // namespace infini

View File

@@ -32,24 +32,27 @@ class OnnxStub:
     The Onnx model imported into infinitensor.
     It can be generated from an Onnx model object.
     """

-    inputs: Dict[str, backend.Tensor] = {}
-    outputs: Dict[str, backend.Tensor] = {}
-    initializer: Dict[int, TensorProto] = {}
-    handler: backend.GraphHandler
-
     def __init__(self, model: ModelProto, runtime):
+        self.inputs: Dict[str, backend.Tensor] = {}
+        self.outputs: Dict[str, backend.Tensor] = {}
+        self.initializer: Dict[int, TensorProto] = {}
         model = infer_shapes(model)
         self.handler = backend.GraphHandler(runtime)

         tensors: Dict[str, backend.Tensor] = dict()
         data: Dict[str, TensorProto] = dict()

+        for initializer in model.graph.initializer:
+            dims = [d for d in initializer.dims]
+            tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
+            data[initializer.name] = initializer
+
         for input in model.graph.input:
             dims = _take_shape_dim(input.type.tensor_type.shape)
-            tensors[input.name] = self.handler.tensor(
-                dims, input.type.tensor_type.elem_type
-            )
+            if input.name not in tensors.keys():
+                tensors[input.name] = self.handler.tensor(
+                    dims, input.type.tensor_type.elem_type
+                )

         for output in model.graph.output:
             dims = _take_shape_dim(output.type.tensor_type.shape)

@@ -57,10 +60,6 @@ class OnnxStub:
                 dims, output.type.tensor_type.elem_type
             )

-        for initializer in model.graph.initializer:
-            dims = [d for d in initializer.dims]
-            tensors[initializer.name] = self.handler.tensor(dims, initializer.data_type)
-            data[initializer.name] = initializer
-
         node_name = []
         new_node_name = []

@@ -667,6 +666,19 @@ class OnnxStub:
         # update the node_list
         node_list = list(set(node_name) - set(new_node_name))

+        ################################
+        # Set tensor type
+        ################################
+        for initializer in model.graph.initializer:
+            tensors[initializer.name].set_weight()
+
+        for input in model.graph.input:
+            tensors[input.name].set_input()
+
+        for output in model.graph.output:
+            tensors[output.name].set_output()
+
         ################################
         # Allocate memory space for data
         ################################

View File

@@ -131,30 +131,63 @@ void GraphObj::dataMalloc() {
     // record the memory address offsets of all tensors to be allocated
     std::unordered_map<TensorObj *, size_t> tensorToOffset;

-    // record all constant tensors, including weight tensors and input tensors
-    std::unordered_set<TensorObj *> constTensor;
+    // reinit allocator
+    allocator.init();
+
+    // record all weight tensors, including weight tensors and kvcache
+    // tensors
+    std::unordered_set<TensorObj *> weightTensors;
     for (auto &tensor : tensors) {
-        if (tensor.get()->getSource() == nullptr) {
-            // allocate memory for all constant tensors first, and this memory
+        if (tensor->isWeight()) {
+            // allocate memory for all weight tensors first, and this memory
+            // will not be freed until the graph is destroyed
+            weightTensors.insert(tensor.get());
+            if (!this->weightAllocated) {
+                tensorToOffset[tensor.get()] =
+                    allocator.allocWeight(tensor->getBytes());
+            }
+        } else if (tensor->isInput() || tensor->isOutput()) {
+            // allocate memory for all input and output tensors, and this memory
             // will not be reused later
-            constTensor.insert(tensor.get());
             tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes());
         } else {
             tensorToRefCount[tensor.get()] = tensor->getTargets().size();
+
+            // allocate memory for all user-created tensors
+            if (tensor.get()->getSource() == nullptr) {
+                tensorToOffset[tensor.get()] =
+                    allocator.alloc(tensor->getBytes());
+            }
         }
     }
+
+    // if memory has not yet been allocated for weight tensors,
+    // allocate memory now and do not allocate again in the future.
+    if (!this->weightAllocated) {
+        this->weightAllocated = true;
+        // only allocate once for weight tensors
+        for (auto &tensor : weightTensors) {
+            IT_ASSERT(tensorToOffset.find(tensor) != tensorToOffset.end());
+            tensor->setDataBlob(make_ref<BlobObj>(
+                tensor->runtime,
+                static_cast<uint8_t *>(allocator.getWeightPtr()) +
+                    tensorToOffset[tensor]));
+        }
+    }
+
     // traverse in topological order and simulate memory allocation
     for (auto &op : ops) {
-        // memory should be allocated for the output first
+        // memory should be allocated for the op's output first
         auto outputs = op->getOutputs();
         for (auto &tensor : outputs) {
-            tensorToOffset[tensor.get()] = allocator.alloc(tensor->getBytes());
+            if (tensor->isOthers()) {
+                tensorToOffset[tensor.get()] =
+                    allocator.alloc(tensor->getBytes());
+            }
         }
         auto inputs = op->getInputs();
         for (auto &tensor : inputs) {
-            if (constTensor.find(tensor.get()) == constTensor.end()) {
+            if (tensor->isOthers()) {
                 auto tensorIter = tensorToRefCount.find(tensor.get());
                 IT_ASSERT(tensorIter != tensorToRefCount.end());
+                IT_ASSERT(tensorToRefCount[tensor.get()] > 0);
                 tensorToRefCount[tensor.get()] -= 1;
                 if (tensorToRefCount[tensor.get()] == 0) {
                     // indicate that this tensor will no longer be used and

@@ -167,15 +200,20 @@ void GraphObj::dataMalloc() {
             }
         }
     }

-    // perform actual memory allocation
+    // perform actual memory allocation for non-weight tensors
     for (auto &tensor : tensors) {
-        IT_ASSERT(tensorToOffset.find(tensor.get()) != tensorToOffset.end());
-        tensor->setDataBlob(make_ref<BlobObj>(
-            tensor->runtime, static_cast<uint8_t *>(allocator.getPtr()) +
-                                 tensorToOffset[tensor.get()]));
+        if (!tensor->isWeight()) {
+            IT_ASSERT(tensorToOffset.find(tensor.get()) !=
+                      tensorToOffset.end());
+            tensor->setDataBlob(make_ref<BlobObj>(
+                tensor->runtime, static_cast<uint8_t *>(allocator.getPtr()) +
+                                     tensorToOffset[tensor.get()]));
+        }
     }

+#ifdef DEBUG_MODE
     allocator.info();
+#endif
 }

 Tensor GraphObj::addTensor(Shape dim, DataType dtype) {
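
In pseudocode, the allocation policy of dataMalloc() after this change is roughly the following; this is an illustrative Python sketch only, and the attribute names stand in for the C++ getters:

# Sketch of the dataMalloc() policy (illustrative names, not the real API).
def plan_offsets(tensors, ops, allocator):
    offsets, refcount = {}, {}
    for t in tensors:
        if t.is_weight():
            offsets[t] = allocator.alloc_weight(t.bytes)   # allocated once, never reset
        elif t.is_input() or t.is_output():
            offsets[t] = allocator.alloc(t.bytes)          # kept for the whole run
        else:
            refcount[t] = len(t.targets)                   # intermediates are refcounted
    for op in ops:                                         # topological order
        for t in op.outputs:
            if t.is_others():
                offsets[t] = allocator.alloc(t.bytes)
        for t in op.inputs:
            if t.is_others():
                refcount[t] -= 1
                if refcount[t] == 0:
                    allocator.free(offsets[t], t.bytes)    # offset becomes reusable
    return offsets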

View File

@@ -11,9 +11,6 @@ namespace infini {
 constexpr size_t alignmentInBytesForCUDA = 256;

 LazyAllocator::LazyAllocator(Runtime runtime) : runtime(runtime) {
-    used = 0;
-    peak = 0;
-    ptr = nullptr;
     if (runtime->isCuda()) {
         // TODO: the alignment on cuda might need further discussion
         alignment = alignmentInBytesForCUDA;

@@ -30,10 +27,24 @@ LazyAllocator::~LazyAllocator() {
     if (this->ptr != nullptr) {
         runtime->dealloc(this->ptr);
     }
+    if (this->weightPtr != nullptr) {
+        runtime->dealloc(this->weightPtr);
+    }
 }

+void LazyAllocator::init() {
+    used = 0;
+    peak = 0;
+    freeBlocks.clear();
+    headAddrToBlockSize.clear();
+    tailAddrToBlockSize.clear();
+    if (this->ptr != nullptr) {
+        runtime->dealloc(this->ptr);
+    }
+    this->ptr = nullptr;
+}
+
 size_t LazyAllocator::alloc(size_t size) {
+    IT_ASSERT(this->ptr == nullptr);
     // pad the size to the multiple of alignment
     size = this->getAlignedSize(size);
     auto it = this->freeBlocks.lower_bound(freeBlockInfo{(size_t)0, size});

@@ -83,6 +94,14 @@ size_t LazyAllocator::alloc(size_t size) {
     return retAddr;
 }

+size_t LazyAllocator::allocWeight(size_t size) {
+    IT_ASSERT(this->weightPtr == nullptr);
+    size = this->getAlignedSize(size);
+    size_t retAddr = this->weightPeak;
+    this->weightPeak += size;
+    return retAddr;
+}
+
 void LazyAllocator::free(size_t addr, size_t size) {
     IT_ASSERT(this->ptr == nullptr);
     size = getAlignedSize(size);

@@ -126,18 +145,33 @@ void LazyAllocator::free(size_t addr, size_t size) {
 void *LazyAllocator::getPtr() {
     if (this->ptr == nullptr) {
         this->ptr = runtime->alloc(this->peak);
-        printf("LazyAllocator really alloc: %p %lu bytes\n", this->ptr, peak);
+#ifdef DEBUG_MODE
+        printf("LazyAllocator really alloc non-weight: %p %lu bytes\n",
+               this->ptr, peak);
+#endif
     }
     return this->ptr;
 }

+void *LazyAllocator::getWeightPtr() {
+    if (this->weightPtr == nullptr) {
+        this->weightPtr = runtime->alloc(this->weightPeak);
+#ifdef DEBUG_MODE
+        printf("LazyAllocator really alloc weight: %p %lu bytes\n",
+               this->weightPtr, weightPeak);
+#endif
+    }
+    return this->weightPtr;
+}
+
 size_t LazyAllocator::getAlignedSize(size_t size) {
     return ((size - 1) / this->alignment + 1) * this->alignment;
 }

 void LazyAllocator::info() {
-    std::cout << "Used memory: " << this->used
-              << ", peak memory: " << this->peak << std::endl;
+    std::cout << "Used memory: " << this->used + this->weightPeak
+              << ", peak memory: " << this->peak + this->weightPeak
+              << std::endl;
 }

 } // namespace infini
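
Conceptually, the allocator now manages two regions: weights go to a bump-allocated region that is sized once and never reset, while activation offsets are re-planned every time init() is called. The Python sketch below models that policy only; it is not the C++ API, and the free-list reuse of the real allocator is omitted:

# Conceptual model of the weight/activation split (Python sketch, not the C++ API).
class TwoArenaPlanner:
    def __init__(self, alignment=256):
        self.alignment = alignment
        self.weight_peak = 0   # grows monotonically, backed once by getWeightPtr()
        self.peak = 0          # activation arena, rebuilt after each init()

    def _aligned(self, size):
        return ((size - 1) // self.alignment + 1) * self.alignment

    def alloc_weight(self, size):
        offset = self.weight_peak
        self.weight_peak += self._aligned(size)
        return offset

    def alloc(self, size):
        # the real allocator reuses freed blocks; this sketch only tracks the peak
        offset = self.peak
        self.peak += self._aligned(size)
        return offset

    def init(self):
        # mirrors LazyAllocator::init(): forget activation offsets, keep weights
        self.peak = 0

planner = TwoArenaPlanner()
print(planner.alloc_weight(1000), planner.alloc(4000))  # 0 0
planner.init()
print(planner.alloc(512))  # 0 again; weight_peak stays at 1024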

View File

@@ -23,7 +23,7 @@ string TensorObj::toString() const {
     string ret = "Tensor " + std::to_string(guid) + ", Fuid " +
                  std::to_string(fuid) + ", shape " + vecToString(shape) +
                  ", dtype " + dtype.toString() + ", " + runtime->toString() +
-                 ", " + ss.str() + "\n";
+                 ", " + ss.str() + ", " + tensorTypeToString() + "\n";
     vector<UidBaseType> targetGuids;
     for (const auto &op : targets)
         targetGuids.emplace_back(op.lock()->getGuid());

View File

@@ -364,6 +364,9 @@ void init_graph_builder(py::module &m) {
                  py::buffer_protocol())
         .def("fuid", &TensorObj::getFuid, policy::automatic)
         .def("shape", &TensorObj::getDims, policy::move)
+        .def("set_weight", &TensorObj::setWeight, policy::move)
+        .def("set_input", &TensorObj::setInput, policy::move)
+        .def("set_output", &TensorObj::setOutput, policy::move)
         .def("dtype", &TensorObj::getDTypeIndex, policy::automatic)
         .def("copyin_float", &TensorObj::copyin<float>, policy::move)
         .def("copyin_int32", &TensorObj::copyin<int32_t>, policy::move)
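
For completeness, the new bindings as seen from Python; the snippet mirrors the tagging loops in the OnnxStub diff above, the model path is hypothetical, and OnnxStub.__init__ already performs this tagging itself:

# Sketch of the new tensor-type bindings from Python (model path is a placeholder).
import onnx
from pyinfinitensor.onnx import OnnxStub, backend

model = onnx.load("model.onnx")              # hypothetical path
stub = OnnxStub(model, backend.CudaRuntime(0))
for name, tensor in stub.inputs.items():
    tensor.set_input()                       # exported by the .def("set_input", ...) line above
for name, tensor in stub.outputs.items():
    tensor.set_output()                      # weights are tagged via set_weight()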