diff --git a/README.md b/README.md
index 23c351df..ea5d79bb 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
 [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
-[![Citation](https://img.shields.io/badge/citation-22-green)](#projects-using-llama-factory)
+[![Citation](https://img.shields.io/badge/citation-26-green)](#projects-using-llama-factory)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -261,18 +261,18 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.38.2    |
+| transformers | 4.37.2  | 4.39.1    |
 | datasets     | 2.14.3  | 2.17.1    |
-| accelerate   | 0.27.2  | 0.27.2    |
-| peft         | 0.9.0   | 0.9.0     |
-| trl          | 0.7.11  | 0.7.11    |
+| accelerate   | 0.27.2  | 0.28.0    |
+| peft         | 0.9.0   | 0.10.0    |
+| trl          | 0.8.1   | 0.8.1     |
 
 | Optional     | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA         | 11.6    | 12.2      |
-| deepspeed    | 0.10.0  | 0.13.1    |
-| bitsandbytes | 0.39.0  | 0.41.3    |
-| flash-attn   | 2.3.0   | 2.5.5     |
+| deepspeed    | 0.10.0  | 0.14.0    |
+| bitsandbytes | 0.39.0  | 0.43.0    |
+| flash-attn   | 2.3.0   | 2.5.6     |
 
 ### Hardware Requirement
 
@@ -690,6 +690,7 @@ docker compose -f ./docker-compose.yml up -d
 
 1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
 1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
 1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
@@ -705,6 +706,9 @@ docker compose -f ./docker-compose.yml up -d
 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
 1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
+1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
+1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
 1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
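Note on the dependency bumps above: the recommended stack moves to transformers 4.39.1, accelerate 0.28.0, peft 0.10.0, trl 0.8.1, deepspeed 0.14.0, bitsandbytes 0.43.0, and flash-attn 2.5.6. Below is a minimal sketch of how an environment could be sanity-checked against the minimum pins from this table, using the same `require_version` helper the patched sources rely on; the `check_minimum_versions` function is illustrative and not part of the repository.

```python
# Illustrative sketch only: verify the minimum versions listed in the README table.
# require_version() raises ImportError with the given hint when the installed
# package does not satisfy the pip-style specifier.
from transformers.utils.versions import require_version


def check_minimum_versions() -> None:  # hypothetical helper, not shipped with llmtuner
    require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
    require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
    require_version("accelerate>=0.27.2", "To fix: pip install accelerate>=0.27.2")
    require_version("peft>=0.9.0", "To fix: pip install peft>=0.9.0")
    require_version("trl>=0.8.1", "To fix: pip install trl>=0.8.1")


if __name__ == "__main__":
    check_minimum_versions()
```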
diff --git a/README_zh.md b/README_zh.md
index 9a7d1eae..e359d960 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -5,7 +5,7 @@
 [![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Factory)](https://github.com/hiyouga/LLaMA-Factory/commits/main)
 [![PyPI](https://img.shields.io/pypi/v/llmtuner)](https://pypi.org/project/llmtuner/)
 [![Downloads](https://static.pepy.tech/badge/llmtuner)](https://pypi.org/project/llmtuner/)
-[![Citation](https://img.shields.io/badge/citation-22-green)](#使用了-llama-factory-的项目)
+[![Citation](https://img.shields.io/badge/citation-26-green)](#使用了-llama-factory-的项目)
 [![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)](https://github.com/hiyouga/LLaMA-Factory/pulls)
 [![Discord](https://dcbadge.vercel.app/api/server/rKfvV9r9FK?compact=true&style=flat)](https://discord.gg/rKfvV9r9FK)
 [![Twitter](https://img.shields.io/twitter/follow/llamafactory_ai)](https://twitter.com/llamafactory_ai)
@@ -261,18 +261,18 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.38.2    |
+| transformers | 4.37.2  | 4.39.1    |
 | datasets     | 2.14.3  | 2.17.1    |
-| accelerate   | 0.27.2  | 0.27.2    |
-| peft         | 0.9.0   | 0.9.0     |
-| trl          | 0.7.11  | 0.7.11    |
+| accelerate   | 0.27.2  | 0.28.0    |
+| peft         | 0.9.0   | 0.10.0    |
+| trl          | 0.8.1   | 0.8.1     |
 
 | 可选项       | 至少    | 推荐      |
 | ------------ | ------- | --------- |
 | CUDA         | 11.6    | 12.2      |
-| deepspeed    | 0.10.0  | 0.13.1    |
-| bitsandbytes | 0.39.0  | 0.41.3    |
-| flash-attn   | 2.3.0   | 2.5.5     |
+| deepspeed    | 0.10.0  | 0.14.0    |
+| bitsandbytes | 0.39.0  | 0.43.0    |
+| flash-attn   | 2.3.0   | 2.5.6     |
 
 ### 硬件依赖
 
@@ -663,6 +663,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 
 1. Wang et al. ESRL: Efficient Sampling-based Reinforcement Learning for Sequence Generation. 2023. [[arxiv]](https://arxiv.org/abs/2308.02223)
 1. Yu et al. Open, Closed, or Small Language Models for Text Classification? 2023. [[arxiv]](https://arxiv.org/abs/2308.10092)
+1. Wang et al. UbiPhysio: Support Daily Functioning, Fitness, and Rehabilitation with Action Understanding and Feedback in Natural Language. 2023. [[arxiv]](https://arxiv.org/abs/2308.10526)
 1. Luceri et al. Leveraging Large Language Models to Detect Influence Campaigns in Social Media. 2023. [[arxiv]](https://arxiv.org/abs/2311.07816)
 1. Zhang et al. Alleviating Hallucinations of Large Language Models through Induced Hallucinations. 2023. [[arxiv]](https://arxiv.org/abs/2312.15710)
 1. Wang et al. Know Your Needs Better: Towards Structured Understanding of Marketer Demands with Analogical Reasoning Augmented LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2401.04319)
@@ -678,6 +679,9 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
 1. Zhang et al. Enhancing Multilingual Capabilities of Large Language Models through Self-Distillation from Resource-Rich Languages. 2024. [[arxiv]](https://arxiv.org/abs/2402.12204)
 1. Kim et al. Efficient and Effective Vocabulary Expansion Towards Multilingual Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.14714)
 1. Yu et al. KIEval: A Knowledge-grounded Interactive Evaluation Framework for Large Language Models. 2024. [[arxiv]](https://arxiv.org/abs/2402.15043)
+1. Huang et al. Key-Point-Driven Data Synthesis with its Enhancement on Mathematical Reasoning. 2024. [[arxiv]](https://arxiv.org/abs/2403.02333)
+1. Duan et al. Negating Negatives: Alignment without Human Positive Samples via Distributional Dispreference Optimization. 2024. [[arxiv]](https://arxiv.org/abs/2403.03419)
+1. Xie and Schwertfeger. Empowering Robotics with Large Language Models: osmAG Map Comprehension with LLMs. 2024. [[arxiv]](https://arxiv.org/abs/2403.08228)
 1. **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper，基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
 1. **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM，基于 Baichuan-13B 微调而得，具有法律推理和知识检索能力。
 1. **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao，基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py
index fa43f769..7fb9f9d6 100644
--- a/src/llmtuner/extras/patches/llama_patch.py
+++ b/src/llmtuner/extras/patches/llama_patch.py
@@ -11,12 +11,14 @@ from transformers.models.llama.modeling_llama import (
     repeat_kv,
 )
 from transformers.utils import logging
+from transformers.utils.versions import require_version
 
 
 logger = logging.get_logger(__name__)
 
 
-# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
 def llama_torch_attn_forward(
     self: "LlamaAttention",
     hidden_states: torch.Tensor,
@@ -24,6 +26,7 @@ def llama_torch_attn_forward(
     position_ids: Optional[torch.LongTensor] = None,
     past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
@@ -36,15 +39,12 @@ def llama_torch_attn_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    past_key_value = getattr(self, "past_key_value", past_key_value)
+    cos, sin = self.rotary_emb(value_states, position_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
-        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -96,14 +96,16 @@ def llama_torch_attn_forward(
     return attn_output, attn_weights, past_key_value
 
 
-# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
+# Modified from:
+# https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/llama/modeling_llama.py
 def llama_flash_attn_forward(
     self: "LlamaFlashAttention2",
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    past_key_value: Optional["Cache"] = None,
     output_attentions: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     # LlamaFlashAttention2 attention does not support output_attentions
@@ -120,15 +122,13 @@ def llama_flash_attn_forward(
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+    cos, sin = self.rotary_emb(value_states, position_ids)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    past_key_value = getattr(self, "past_key_value", past_key_value)
 
     if past_key_value is not None:
-        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
     key_states = repeat_kv(key_states, self.num_key_value_groups)
@@ -193,5 +193,6 @@ def llama_flash_attn_forward(
 
 
 def apply_llama_patch() -> None:
+    require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
     LlamaAttention.forward = llama_torch_attn_forward
     LlamaFlashAttention2.forward = llama_flash_attn_forward
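The hunks above track the transformers 4.39 Llama attention API: `self.rotary_emb` is now called with `position_ids` instead of `seq_len`, `apply_rotary_pos_emb` no longer takes `position_ids`, and `cache_position` is forwarded to `Cache.update`, which is why `apply_llama_patch` now pins `transformers==4.39.1`. A minimal usage sketch, assuming `llmtuner` is importable from this repository; the checkpoint name is only an example.

```python
# Sketch: applying the monkey patch defined in llama_patch.py. Because the patch
# replaces forward() on the attention classes themselves, calling it once affects
# every Llama model in the process (forward is looked up on the class).
from transformers import AutoModelForCausalLM

from llmtuner.extras.patches.llama_patch import apply_llama_patch

apply_llama_patch()  # raises ImportError unless transformers==4.39.1 is installed

# Example checkpoint only; any Llama-architecture model exercises the patched forward.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
```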
diff --git a/src/llmtuner/model/patcher.py b/src/llmtuner/model/patcher.py
index 712599c0..7fd5f573 100644
--- a/src/llmtuner/model/patcher.py
+++ b/src/llmtuner/model/patcher.py
@@ -183,9 +183,7 @@ def _configure_quantization(
             quantization_config["use_exllama"] = False  # disable exllama
 
         if quant_method == QuantizationMethod.AQLM:
-            require_version(
-                "transformers>=4.39.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
-            )
+            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
             require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0")
             quantization_config["bits"] = 2
 
@@ -210,6 +208,11 @@ def _configure_quantization(
         logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
 
     elif model_args.quantization_bit is not None:  # bnb
+        if is_deepspeed_zero3_enabled():
+            require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0")
+            require_version("accelerate>=0.28.0", "To fix: pip install accelerate>=0.28.0")
+            require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0")
+
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)