From 00efa8a07fe5a69bac545675696b2a19b7b811ed Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 25 Jul 2023 17:04:02 +0800
Subject: [PATCH] fix #242

---
 README.md                         |  6 ++++++
 README_zh.md                      |  6 ++++++
 src/llmtuner/chat/stream_chat.py  | 10 ++++++++++
 src/llmtuner/tuner/core/loader.py |  3 ---
 4 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index b19b3143..675b0c7d 100644
--- a/README.md
+++ b/README.md
@@ -342,6 +342,12 @@ python src/export_model.py \
     --output_dir path_to_export
 ```
 
+## TODO
+
+- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
+- [ ] Implementing multi-query attention for faster inference.
+- [ ] Supporting full-parameter RLHF training.
+
 ## License
 
 This repository is licensed under the [Apache-2.0 License](LICENSE).
diff --git a/README_zh.md b/README_zh.md
index 73b50e95..bb9236a4 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -342,6 +342,12 @@ python src/export_model.py \
     --output_dir path_to_export
 ```
 
+## TODO
+
+- [ ] 实现 flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention))。
+- [ ] 在推理阶段使用 Multi-query attention 进行加速。
+- [ ] 支持 RLHF 的全参数微调。
+
 ## 协议
 
 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index abd0a446..a4b46dd6 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -19,6 +19,14 @@ class ChatModel:
         generating_args: GeneratingArguments
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+
+        if torch.cuda.device_count() > 1:
+            from accelerate import dispatch_model, infer_auto_device_map
+            device_map = infer_auto_device_map(self.model)
+            self.model = dispatch_model(self.model, device_map)
+        else:
+            self.model = self.model.cuda()
+
         self.template = get_template(data_args.prompt_template)
         self.source_prefix = data_args.source_prefix or ""
         self.generating_args = generating_args
@@ -32,6 +40,7 @@ class ChatModel:
         inputs = inputs.to(self.model.device)
         prompt_length = len(inputs["input_ids"][0])
 
+        do_sample = input_kwargs.pop("do_sample", None)
         temperature = input_kwargs.pop("temperature", None)
         top_p = input_kwargs.pop("top_p", None)
         top_k = input_kwargs.pop("top_k", None)
@@ -42,6 +51,7 @@ class ChatModel:
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
             input_ids=inputs["input_ids"],
+            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
             temperature=temperature or gen_kwargs["temperature"],
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/tuner/core/loader.py
index a4e2e7ea..31509b72 100644
--- a/src/llmtuner/tuner/core/loader.py
+++ b/src/llmtuner/tuner/core/loader.py
@@ -93,9 +93,6 @@ def load_model_and_tokenizer(
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if not is_trainable: # `device_map=auto` should be used for inference only
-        config_kwargs["device_map"] = "auto"
-
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
         model_to_load = model_args.checkpoint_dir[0]
     else:
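
Note: the ChatModel hunk above moves multi-GPU placement out of the loader and into chat initialization, using accelerate's infer_auto_device_map to split the already-loaded model across visible GPUs and dispatch_model to attach the hooks that move activations between devices. A minimal standalone sketch of that pattern follows; it is illustrative only and not part of this patch, and the model name is a placeholder assumption rather than anything used by this repository.

    # Minimal sketch of the multi-GPU dispatch pattern the patch adds (illustrative only).
    # Assumes transformers and accelerate are installed; "facebook/opt-1.3b" is a placeholder model.
    import torch
    from accelerate import dispatch_model, infer_auto_device_map
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b", torch_dtype=torch.float16)

    if torch.cuda.device_count() > 1:
        # Let accelerate assign layers to the available GPUs, then add forward hooks
        # that shuttle activations between devices during generation.
        device_map = infer_auto_device_map(model)
        model = dispatch_model(model, device_map)
    else:
        model = model.cuda()

    inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
    output_ids = model.generate(**inputs, do_sample=True, max_new_tokens=16)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

The remaining stream_chat.py hunks simply let a caller override do_sample per request while falling back to the configured default, and the loader.py hunk removes the old device_map="auto" inference path now that dispatching happens in ChatModel.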