tiny fix
This commit is contained in:
parent
80a9e6bf94
commit
38b6b0f52e
|
@ -1,7 +1,7 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by Microsoft's DeepSpeed library.
|
# This code is inspired by the Microsoft's DeepSpeed library.
|
||||||
# https://www.deepspeed.ai/tutorials/flops-profiler/
|
# https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 imoneoi and the LlamaFactory team.
|
# Copyright 2024 imoneoi and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by imoneoi's OpenChat library.
|
# This code is inspired by the imoneoi's OpenChat library.
|
||||||
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
|
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# coding=utf-8
|
# coding=utf-8
|
||||||
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by Tencent's LLaMA-Pro library.
|
# This code is inspired by the Tencent's LLaMA-Pro library.
|
||||||
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's transformers library.
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by Dan's test library.
|
# This code is inspired by the Dan's test library.
|
||||||
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's transformers library.
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's transformers library.
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's Transformers and PEFT library.
|
# This code is inspired by the HuggingFace's Transformers and PEFT library.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
|
||||||
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
|
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
|
||||||
#
|
#
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
# Copyright 2024 EleutherAI, HuggingFace Inc., and the LlamaFactory team.
|
# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is based on the EleutherAI's GPT-NeoX and HuggingFace's Transformers libraries.
|
# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# This code is also inspired by the original LongLoRA implementation.
|
||||||
|
# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's Optimum library.
|
# This code is inspired by the HuggingFace's Optimum library.
|
||||||
# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
|
# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's TRL library.
|
# This code is inspired by the HuggingFace's Transformers library.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's TRL library.
|
# This code is inspired by the HuggingFace's TRL library.
|
||||||
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
|
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's TRL library.
|
# This code is inspired by the HuggingFace's TRL library.
|
||||||
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
|
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -114,8 +114,8 @@ class CustomKTOTrainer(KTOTrainer):
|
||||||
|
|
||||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||||
super()._save(output_dir, state_dict)
|
super()._save(output_dir, state_dict)
|
||||||
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
if self.processor is not None:
|
if self.processor is not None:
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
|
||||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's TRL library.
|
# This code is inspired by the HuggingFace's TRL library.
|
||||||
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
|
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's TRL library.
|
# This code is inspired by the HuggingFace's TRL library.
|
||||||
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
|
# https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's TRL library.
|
# This code is inspired by the HuggingFace's TRL library.
|
||||||
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
|
# https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's transformers library.
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by CarperAI's trlx library.
|
# This code is inspired by the CarperAI's trlx library.
|
||||||
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
|
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
@ -89,8 +89,8 @@ class PairwiseTrainer(Trainer):
|
||||||
|
|
||||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||||
super()._save(output_dir, state_dict)
|
super()._save(output_dir, state_dict)
|
||||||
|
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||||
if self.processor is not None:
|
if self.processor is not None:
|
||||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
|
||||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||||
|
|
||||||
def compute_loss(
|
def compute_loss(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 the LlamaFactory team.
|
# Copyright 2024 the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by CarperAI's trlx library.
|
# This code is inspired by the CarperAI's trlx library.
|
||||||
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
# https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's transformers library and THUDM's ChatGLM implementation.
|
# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
||||||
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
|
# https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
|
||||||
#
|
#
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||||
#
|
#
|
||||||
# This code is inspired by HuggingFace's transformers library.
|
# This code is inspired by the HuggingFace's transformers library.
|
||||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
|
@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
|
||||||
state_dict_b = model_b.state_dict()
|
state_dict_b = model_b.state_dict()
|
||||||
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
assert set(state_dict_a.keys()) == set(state_dict_b.keys())
|
||||||
for name in state_dict_a.keys():
|
for name in state_dict_a.keys():
|
||||||
assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
|
assert torch.allclose(state_dict_a[name], state_dict_b[name])
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
Loading…
Reference in New Issue