fix #242
parent 2c31e05e63
commit 00efa8a07f
@@ -342,6 +342,12 @@ python src/export_model.py \
     --output_dir path_to_export
 ```
 
+## TODO
+
+- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
+- [ ] Implementing multi-query attention for faster inference.
+- [ ] Supporting full-parameter RLHF training.
+
 ## License
 
 This repository is licensed under the [Apache-2.0 License](LICENSE).
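The first TODO item points at PyTorch's fused attention kernel. As a minimal sketch of the call that integration would revolve around (tensor shapes, dtype, and device handling are assumptions for illustration, not code from this repository):

```python
import torch
import torch.nn.functional as F

# Toy tensors in the usual (batch, num_heads, seq_len, head_dim) layout;
# the sizes are made up for illustration.
device = "cuda" if torch.cuda.is_available() else "cpu"
batch, num_heads, seq_len, head_dim = 1, 8, 128, 64
query = torch.randn(batch, num_heads, seq_len, head_dim, device=device)
key = torch.randn_like(query)
value = torch.randn_like(query)

# is_causal=True applies the autoregressive mask; PyTorch selects a fused
# (flash / memory-efficient) backend automatically when one is available.
output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(output.shape)  # torch.Size([1, 8, 128, 64])
```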
@@ -342,6 +342,12 @@ python src/export_model.py \
     --output_dir path_to_export
 ```
 
+## TODO
+
+- [ ] Implement flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
+- [ ] Use multi-query attention to speed up inference.
+- [ ] Support full-parameter fine-tuning for RLHF.
+
 ## License
 
 The code in this repository is open-sourced under the [Apache-2.0](LICENSE) license.
@@ -19,6 +19,14 @@ class ChatModel:
         generating_args: GeneratingArguments
     ) -> None:
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+
+        if torch.cuda.device_count() > 1:
+            from accelerate import dispatch_model, infer_auto_device_map
+            device_map = infer_auto_device_map(self.model)
+            self.model = dispatch_model(self.model, device_map)
+        else:
+            self.model = self.model.cuda()
+
         self.template = get_template(data_args.prompt_template)
         self.source_prefix = data_args.source_prefix or ""
         self.generating_args = generating_args
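The lines added to `ChatModel.__init__` above replace a single `.cuda()` call with accelerate's device-map dispatch when more than one GPU is visible. A standalone sketch of the same pattern, assuming a Hugging Face causal LM; the checkpoint name is a placeholder rather than anything this commit loads:

```python
import torch
from accelerate import dispatch_model, infer_auto_device_map
from transformers import AutoModelForCausalLM

# Placeholder checkpoint for illustration only.
model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", torch_dtype=torch.float16)

if torch.cuda.device_count() > 1:
    # Let accelerate split the layers across all visible GPUs.
    device_map = infer_auto_device_map(model)
    model = dispatch_model(model, device_map=device_map)
else:
    model = model.cuda()
```

If finer control is needed, `infer_auto_device_map` also accepts `max_memory` and `no_split_module_classes` arguments.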
@@ -32,6 +40,7 @@ class ChatModel:
         inputs = inputs.to(self.model.device)
         prompt_length = len(inputs["input_ids"][0])
 
+        do_sample = input_kwargs.pop("do_sample", None)
         temperature = input_kwargs.pop("temperature", None)
         top_p = input_kwargs.pop("top_p", None)
         top_k = input_kwargs.pop("top_k", None)
@@ -42,6 +51,7 @@ class ChatModel:
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
             input_ids=inputs["input_ids"],
+            do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
            temperature=temperature or gen_kwargs["temperature"],
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
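The asymmetry between `do_sample` and the other overrides in the two hunks above is deliberate: `False` is a meaningful value for `do_sample`, so a plain `or` fallback would silently replace it with the default, while `temperature`, `top_p`, and `top_k` are simply treated as unset when falsy. A small self-contained illustration (the default values are made up):

```python
# Stand-in for self.generating_args.to_dict(); the defaults are made up.
gen_kwargs = {"do_sample": True, "temperature": 0.95, "top_p": 0.7, "top_k": 50}

def merge(do_sample=None, temperature=None, top_p=None, top_k=None):
    return dict(
        # An `or` fallback would turn an explicit False back into True,
        # hence the explicit None check for do_sample.
        do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
        temperature=temperature or gen_kwargs["temperature"],
        top_p=top_p or gen_kwargs["top_p"],
        top_k=top_k or gen_kwargs["top_k"],
    )

print(merge(do_sample=False))   # do_sample stays False (greedy decoding)
print(merge(temperature=0.5))   # only temperature is overridden
```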
@@ -93,9 +93,6 @@ def load_model_and_tokenizer(
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if not is_trainable: # `device_map=auto` should be used for inference only
-        config_kwargs["device_map"] = "auto"
-
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
         model_to_load = model_args.checkpoint_dir[0]
     else:
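The removed branch forced `device_map="auto"` for every non-trainable load, which overrode the `LOCAL_RANK` placement set for quantized models just above; multi-GPU placement at inference time is now handled by `dispatch_model` in `ChatModel` instead. A rough sketch of how the surviving `config_kwargs` would flow into `from_pretrained`; the checkpoint name and the 4-bit flag are assumptions for illustration:

```python
import os

import torch
from transformers import AutoModelForCausalLM

config_kwargs = {}
quantization_bit = 4  # pretend model_args.quantization_bit == 4

if quantization_bit is not None:
    # Pin the quantized weights to this process's GPU instead of "auto".
    config_kwargs["load_in_4bit"] = True
    config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder checkpoint
    torch_dtype=torch.float16,
    **config_kwargs,
)
```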