fix #242
parent 2c31e05e63
commit 00efa8a07f
@@ -342,6 +342,12 @@ python src/export_model.py \
    --output_dir path_to_export
```

## TODO

- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
- [ ] Implementing multi-query attention for faster inference.
- [ ] Supporting full-parameter RLHF training.

## License

This repository is licensed under the [Apache-2.0 License](LICENSE).
@@ -342,6 +342,12 @@ python src/export_model.py \
    --output_dir path_to_export
```

## TODO

- [ ] Implement flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
- [ ] Use multi-query attention to speed up inference.
- [ ] Support full-parameter fine-tuning for RLHF.

## License

The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.
@@ -19,6 +19,14 @@ class ChatModel:
        generating_args: GeneratingArguments
    ) -> None:
        self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)

        if torch.cuda.device_count() > 1:
            from accelerate import dispatch_model, infer_auto_device_map
            device_map = infer_auto_device_map(self.model)
            self.model = dispatch_model(self.model, device_map)
        else:
            self.model = self.model.cuda()

        self.template = get_template(data_args.prompt_template)
        self.source_prefix = data_args.source_prefix or ""
        self.generating_args = generating_args
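The hunk above is the multi-GPU change: when more than one GPU is visible, the model is split across devices with accelerate's `infer_auto_device_map` and `dispatch_model` instead of being moved to a single card. A minimal standalone sketch of the same dispatch pattern, assuming a generic Hugging Face causal LM (the checkpoint name and dtype below are placeholders, not taken from this commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import dispatch_model, infer_auto_device_map

model_name = "huggyllama/llama-7b"  # placeholder checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

if torch.cuda.device_count() > 1:
    # Split the layers across all visible GPUs and register the hooks that
    # move activations between devices during the forward pass.
    device_map = infer_auto_device_map(model)
    model = dispatch_model(model, device_map)
else:
    model = model.cuda()
```

`infer_auto_device_map` assigns contiguous blocks of layers to each device based on available memory, so this is pipeline-style sharding rather than tensor parallelism; during generation only one GPU is busy at a time.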
@@ -32,6 +40,7 @@ class ChatModel:
        inputs = inputs.to(self.model.device)
        prompt_length = len(inputs["input_ids"][0])

        do_sample = input_kwargs.pop("do_sample", None)
        temperature = input_kwargs.pop("temperature", None)
        top_p = input_kwargs.pop("top_p", None)
        top_k = input_kwargs.pop("top_k", None)
@@ -42,6 +51,7 @@ class ChatModel:
        gen_kwargs = self.generating_args.to_dict()
        gen_kwargs.update(dict(
            input_ids=inputs["input_ids"],
            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
            temperature=temperature or gen_kwargs["temperature"],
            top_p=top_p or gen_kwargs["top_p"],
            top_k=top_k or gen_kwargs["top_k"],
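The update above merges per-call overrides into the defaults from `generating_args`: `do_sample` is compared against `None` explicitly so an explicit `do_sample=False` survives, while the other knobs use `or` and simply fall back when the caller passes `None`. A small self-contained sketch of that fallback pattern (the `merge_generation_kwargs` helper and the `defaults` values are illustrative, not part of the repository):

```python
def merge_generation_kwargs(defaults: dict, **overrides) -> dict:
    do_sample = overrides.pop("do_sample", None)
    temperature = overrides.pop("temperature", None)
    top_p = overrides.pop("top_p", None)

    gen_kwargs = dict(defaults)
    gen_kwargs.update(
        # `is not None` preserves an explicit do_sample=False,
        # while `or` lets None fall back to the default value.
        do_sample=do_sample if do_sample is not None else defaults["do_sample"],
        temperature=temperature or defaults["temperature"],
        top_p=top_p or defaults["top_p"],
    )
    return gen_kwargs

defaults = {"do_sample": True, "temperature": 0.95, "top_p": 0.7}
print(merge_generation_kwargs(defaults, do_sample=False))
# {'do_sample': False, 'temperature': 0.95, 'top_p': 0.7}
```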
@@ -93,9 +93,6 @@ def load_model_and_tokenizer(
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

    if not is_trainable: # `device_map=auto` should be used for inference only
        config_kwargs["device_map"] = "auto"

    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
        model_to_load = model_args.checkpoint_dir[0]
    else:
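The two `device_map` settings in this hunk cover different situations: `{"": LOCAL_RANK}` pins one full copy of the model to the GPU owned by the current process, while `"auto"` lets accelerate spread the weights over all available devices, which the inline comment reserves for inference. A hedged sketch contrasting the two, assuming a generic Hugging Face causal LM (the checkpoint name and the standalone `is_trainable` flag are placeholders):

```python
import os
import torch
from transformers import AutoModelForCausalLM

model_name = "huggyllama/llama-7b"  # placeholder checkpoint, not from this commit
is_trainable = False                # placeholder for the argument used in the diff

if is_trainable:
    # One full copy of the model per process, pinned to that process's GPU (DDP-style).
    device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}
else:
    # Inference only: let accelerate spread the layers over all available devices.
    device_map = "auto"

model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map=device_map
)
```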