fix #3324
This commit is contained in:
parent
3b43a3b7c5
commit
942362d008
|
@ -46,7 +46,7 @@ Choose your path:
|
||||||
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
- **Various models**: LLaMA, Mistral, Mixtral-MoE, Qwen, Yi, Gemma, Baichuan, ChatGLM, Phi, etc.
|
||||||
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
- **Integrated methods**: (Continuous) pre-training, supervised fine-tuning, reward modeling, PPO, DPO and ORPO.
|
||||||
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
- **Scalable resources**: 32-bit full-tuning, 16-bit freeze-tuning, 16-bit LoRA and 2/4/8-bit QLoRA via AQLM/AWQ/GPTQ/LLM.int8.
|
||||||
- **Advanced algorithms**: GaLore, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
|
- **Advanced algorithms**: GaLore, BAdam, DoRA, LongLoRA, LLaMA Pro, LoRA+, LoftQ and Agent tuning.
|
||||||
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
- **Practical tricks**: FlashAttention-2, Unsloth, RoPE scaling, NEFTune and rsLoRA.
|
||||||
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
- **Experiment monitors**: LlamaBoard, TensorBoard, Wandb, MLflow, etc.
|
||||||
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
- **Faster inference**: OpenAI-style API, Gradio UI and CLI with vLLM worker.
|
||||||
|
|
|
@ -132,8 +132,9 @@ def gradient_checkpointing_enable(
|
||||||
|
|
||||||
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format
|
||||||
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
self.apply(partial(self._set_gradient_checkpointing, value=True))
|
||||||
|
self.enable_input_require_grads()
|
||||||
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
|
||||||
else:
|
else: # have already enabled input require gradients
|
||||||
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue