hiyouga 2024-06-16 01:06:41 +08:00
parent 80a9e6bf94
commit 38b6b0f52e
22 changed files with 27 additions and 25 deletions

View File

@@ -1,7 +1,7 @@
 # coding=utf-8
 # Copyright 2024 Microsoft Corporation and the LlamaFactory team.
 #
-# This code is inspired by Microsoft's DeepSpeed library.
+# This code is inspired by the Microsoft's DeepSpeed library.
 # https://www.deepspeed.ai/tutorials/flops-profiler/
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,7 +1,7 @@
 # coding=utf-8
 # Copyright 2024 imoneoi and the LlamaFactory team.
 #
-# This code is inspired by imoneoi's OpenChat library.
+# This code is inspired by the imoneoi's OpenChat library.
 # https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,7 +1,7 @@
 # coding=utf-8
 # Copyright 2024 Tencent Inc. and the LlamaFactory team.
 #
-# This code is inspired by Tencent's LLaMA-Pro library.
+# This code is inspired by the Tencent's LLaMA-Pro library.
 # https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's transformers library.
+# This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 the LlamaFactory team.
 #
-# This code is inspired by Dan's test library.
+# This code is inspired by the Dan's test library.
 # https://github.com/hendrycks/test/blob/master/evaluate_flan.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's transformers library.
+# This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's transformers library.
+# This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's Transformers and PEFT library.
+# This code is inspired by the HuggingFace's Transformers and PEFT library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py
 # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py
 #

View File

@@ -1,7 +1,9 @@
-# Copyright 2024 EleutherAI, HuggingFace Inc., and the LlamaFactory team.
+# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team.
 #
-# This code is based on the EleutherAI's GPT-NeoX and HuggingFace's Transformers libraries.
+# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py
+# This code is also inspired by the original LongLoRA implementation.
+# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's Optimum library.
+# This code is inspired by the HuggingFace's Optimum library.
 # https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's TRL library.
+# This code is inspired by the HuggingFace's Transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/modeling_llava.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's TRL library.
+# This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/dpo.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's TRL library.
+# This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/kto_trainer.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -114,8 +114,8 @@ class CustomKTOTrainer(KTOTrainer):

     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
         super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
         if self.processor is not None:
+            output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)

     def forward(
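Note: both orderings in the hunk above write the same files; the change only affects where output_dir is resolved relative to the self.processor check. A minimal, self-contained sketch of that saving guard follows, with DummyProcessor, DummyImageProcessor, and save_processor as illustrative stand-ins rather than LLaMA-Factory code:

import os
import tempfile
from typing import Optional


class DummyImageProcessor:
    def save_pretrained(self, output_dir: str) -> None:
        # Stand-in for an image processor's save_pretrained()
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, "preprocessor_config.json"), "w") as f:
            f.write("{}")


class DummyProcessor:
    image_processor = DummyImageProcessor()


def save_processor(processor: Optional[DummyProcessor], output_dir: Optional[str], default_dir: str) -> None:
    # Only touch the image processor when a (multimodal) processor is attached,
    # falling back to the default output directory when none was given.
    if processor is not None:
        output_dir = output_dir if output_dir is not None else default_dir
        getattr(processor, "image_processor").save_pretrained(output_dir)


with tempfile.TemporaryDirectory() as tmp:
    save_processor(DummyProcessor(), None, tmp)  # writes tmp/preprocessor_config.json
    save_processor(None, None, tmp)              # no-op when no processor is set

Resolving the directory inside the guard simply keeps the fallback local to the only place it is used.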

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's TRL library.
+# This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/kto.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's TRL library.
+# This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/trl/trainer/ppo_trainer.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's TRL library.
+# This code is inspired by the HuggingFace's TRL library.
 # https://github.com/huggingface/trl/blob/v0.8.0/examples/scripts/ppo.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's transformers library.
+# This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 the LlamaFactory team.
 #
-# This code is inspired by CarperAI's trlx library.
+# This code is inspired by the CarperAI's trlx library.
 # https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/reward_model.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,8 +89,8 @@ class PairwiseTrainer(Trainer):

     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
         super()._save(output_dir, state_dict)
-        output_dir = output_dir if output_dir is not None else self.args.output_dir
         if self.processor is not None:
+            output_dir = output_dir if output_dir is not None else self.args.output_dir
             getattr(self.processor, "image_processor").save_pretrained(output_dir)

     def compute_loss(

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 the LlamaFactory team.
 #
-# This code is inspired by CarperAI's trlx library.
+# This code is inspired by the CarperAI's trlx library.
 # https://github.com/CarperAI/trlx/blob/v0.7.0/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc., THUDM, and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's transformers library and THUDM's ChatGLM implementation.
+# This code is inspired by the HuggingFace's transformers library and the THUDM's ChatGLM implementation.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
 # https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
 #

View File

@@ -1,6 +1,6 @@
 # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
 #
-# This code is inspired by HuggingFace's transformers library.
+# This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -41,7 +41,7 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module"):
     state_dict_b = model_b.state_dict()
     assert set(state_dict_a.keys()) == set(state_dict_b.keys())
     for name in state_dict_a.keys():
-        assert torch.allclose(state_dict_a[name], state_dict_b[name]) is True
+        assert torch.allclose(state_dict_a[name], state_dict_b[name])


 @pytest.fixture
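Note: the dropped "is True" is redundant because torch.allclose returns a plain Python bool, which assert evaluates directly. A quick standalone check (tensor values chosen only for illustration):

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-9  # well within torch.allclose's default tolerances

result = torch.allclose(a, b)
print(type(result), result)  # <class 'bool'> True
assert result                # behaves the same as "assert result is True"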