Merge branch 'hiyouga:main' into main

commit 63da6294dd
@@ -1,5 +1,5 @@
 torch>=1.13.1
-transformers>=4.36.2
+transformers>=4.37.2
 datasets>=2.14.3
 accelerate>=0.21.0
 peft>=0.7.0
@@ -120,6 +120,9 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":

     def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
         if request.stream:
+            if tools:
+                raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
+
             generate = stream_chat_completion(messages, system, tools, request)
             return EventSourceResponse(generate, media_type="text/event-stream")

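Note on the hunk above: streaming responses and tool calls are now mutually exclusive at the API layer. A minimal sketch of a request the new guard rejects, assuming the server runs locally on port 8000 and exposes the usual OpenAI-compatible route (both are assumptions, not taken from this diff):

    import requests  # hypothetical client; endpoint path and port are assumptions

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "default",
            "messages": [{"role": "user", "content": "What is the weather in Berlin?"}],
            "tools": [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}],
            "stream": True,  # streaming + tools now yields HTTP 400 ("Cannot stream function calls.")
        },
    )
    print(resp.status_code)  # expected: 400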
@@ -155,9 +155,6 @@ def get_dataset(
             dataset = dataset.to_iterable_dataset()
         return dataset

-    if data_args.streaming:
-        raise ValueError("Turn off dataset streaming to save cache files.")
-
     with training_args.main_process_first(desc="load dataset"):
         all_datasets = []
         for dataset_attr in get_dataset_list(data_args):  # TODO: add split
@@ -22,12 +22,8 @@ def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
 ) -> Dict[str, List[List[int]]]:
     # build grouped texts with format `X1 X2 X3 ...`
-    text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
+    text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
     tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
-    for i in range(len(tokenized_examples["input_ids"])):
-        tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
-        tokenized_examples["attention_mask"][i] += [1]
-
     concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
     total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
     block_size = data_args.cutoff_len
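Note on the hunk above: the EOS token is now appended to the raw text before tokenization instead of appending `eos_token_id` (plus a matching attention-mask entry) afterwards. A minimal sketch of the equivalence, using GPT-2 purely as an illustrative tokenizer (not something this commit specifies):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative choice, not from this commit

    text = "Hello world" + tokenizer.eos_token          # new approach: EOS appended as text
    ids = tokenizer(text, add_special_tokens=False)["input_ids"]

    # The old approach tokenized "Hello world" and then appended tokenizer.eos_token_id manually.
    # For tokenizers whose EOS string maps to a single id, both paths end with the same token:
    assert ids[-1] == tokenizer.eos_token_id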
@@ -59,7 +55,12 @@ def preprocess_supervised_dataset(
         input_ids, labels = [], []
         for turn_idx, (source_ids, target_ids) in enumerate(
             template.encode_multiturn(
-                tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+                tokenizer,
+                messages,
+                examples["system"][i],
+                examples["tools"][i],
+                data_args.cutoff_len,
+                data_args.reserved_label_len,
             )
         ):
             if data_args.train_on_prompt:
@@ -147,7 +148,12 @@ def preprocess_unsupervised_dataset(
         messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]

         input_ids, labels = template.encode_oneturn(
-            tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )

         if template.efficient_eos:
@@ -176,10 +182,20 @@ def preprocess_pairwise_dataset(
         rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]

         prompt_ids, chosen_ids = template.encode_oneturn(
-            tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            chosen_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )
         _, rejected_ids = template.encode_oneturn(
-            tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
+            tokenizer,
+            rejected_messages,
+            examples["system"][i],
+            examples["tools"][i],
+            data_args.cutoff_len,
+            data_args.reserved_label_len,
         )

         if template.efficient_eos:
@@ -37,7 +37,7 @@ class Template:
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Tuple[List[int], List[int]]:
         r"""
         Returns a single pair of token ids representing prompt and response respectively.
@@ -57,7 +57,7 @@ class Template:
         system: Optional[str] = None,
         tools: Optional[str] = None,
         cutoff_len: Optional[int] = 1_000_000,
-        reserved_label_len: Optional[int] = 16,
+        reserved_label_len: Optional[int] = 1,
     ) -> Sequence[Tuple[List[int], List[int]]]:
         r"""
         Returns multiple pairs of token ids representing prompts and responses respectively.
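Note on the two hunks above: the encoder defaults now match the `reserved_label_len=1` default declared in `DataArguments`, and the preprocessing code passes the value through explicitly instead of relying on the previously hard-coded 16. A rough sketch of the kind of length budget this parameter controls (illustrative only; the repository's own helper that performs this split is not shown in this diff):

    def split_budget(source_len: int, target_len: int, cutoff_len: int, reserved_label_len: int):
        # Illustrative only: reserve at least `reserved_label_len` tokens for the label,
        # then give the prompt whatever remains of the cutoff budget.
        max_target_len = max(min(target_len, cutoff_len - source_len), reserved_label_len)
        max_source_len = cutoff_len - max_target_len
        return max_source_len, max_target_len

    print(split_budget(1020, 100, 1024, reserved_label_len=16))  # -> (1008, 16)
    print(split_budget(1020, 100, 1024, reserved_label_len=1))   # -> (1020, 4)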
@@ -144,10 +144,10 @@ class Template:
                 max_len=(cutoff_len - total_length),
                 reserved_label_len=reserved_label_len,
             )
-            encoded_messages[i] = encoded_messages[i][:max_source_len]
-            encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
-            total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
-            encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
+            source_ids = encoded_messages[i][:max_source_len]
+            target_ids = encoded_messages[i + 1][:max_target_len]
+            total_length += len(source_ids) + len(target_ids)
+            encoded_pairs.append((source_ids, target_ids))

         return encoded_pairs

@@ -218,7 +218,7 @@ def register_template(
     default_user_formatter = StringFormatter(slots=["{{content}}"])
     default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
     default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
-    default_tool_formatter = ToolFormatter(slots="default")
+    default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
     templates[name] = template_class(
         format_user=format_user or default_user_formatter,
@@ -356,6 +356,14 @@ register_template(
 )


+register_template(
+    name="cpm",
+    format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
+    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    force_system=True,
+)
+
+
 register_template(
     name="deepseek",
     format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
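Note on the hunk above: the new `cpm` template wraps each user turn in `<用户>…<AI>` and, because `force_system=True`, always emits the BOS-prefixed system text first. A rough sketch of the resulting prompt string for a single turn (the BOS token and the empty system prompt are assumptions, not taken from this diff):

    bos_token = "<s>"   # assumption: whatever the checkpoint's tokenizer defines
    system = ""         # assumption: no system prompt supplied
    user = "你好"

    prompt = bos_token + system + "<用户>" + user + "<AI>"
    print(prompt)  # -> "<s><用户>你好<AI>"; the model's reply is generated after "<AI>"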
@@ -464,7 +472,7 @@ register_template(

 register_template(
     name="orion",
-    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: </s>"]),
+    format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
     format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
     force_system=True,
 )
@@ -21,10 +21,10 @@ class DataArguments:
         default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
     )
     cutoff_len: Optional[int] = field(
-        default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
+        default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}
     )
     reserved_label_len: Optional[int] = field(
-        default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
+        default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}
     )
     train_on_prompt: Optional[bool] = field(
         default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
@@ -57,7 +57,7 @@ class DataArguments:
     ignore_pad_token_for_loss: Optional[bool] = field(
         default=True,
         metadata={
-            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
+            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
         },
     )
     val_size: Optional[float] = field(
@@ -17,6 +17,7 @@ class FreezeArguments:
                     BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
                     Qwen choices: ["mlp", "attn"], \
                     Phi choices: ["mlp", "mixer"], \
+                    InternLM2 choices: ["feed_forward", "attention"], \
                     Others choices: the same as LLaMA.'
         },
     )
@@ -63,13 +63,12 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
     if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter:
         raise ValueError("Cannot create new adapter upon a quantized model.")

     if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
-        if finetuning_args.finetuning_type != "lora":
-            raise ValueError("Multiple adapters are only available for LoRA tuning.")
-
-        if model_args.quantization_bit is not None:
-            raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
+        raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
+
+    if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Only LoRA method has adapters.")


 def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
@@ -52,8 +52,18 @@ def init_adapter(
         else:  # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]  # noqa: C416

+        freeze_modules = set()
+        for name, _ in model.named_modules():
+            if "0." in name:
+                freeze_modules.add(name.split("0.")[-1].split(".")[0])
+
         trainable_layers = []
         for module_name in finetuning_args.name_module_trainable:
+            if module_name not in freeze_modules:
+                raise ValueError(
+                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
+                )
+
             for idx in trainable_layer_ids:
                 trainable_layers.append("{:d}.{}".format(idx, module_name))

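Note on the hunk above: valid choices for `name_module_trainable` are now discovered from `model.named_modules()` by splitting on the first layer index. A small illustration of that string manipulation (the module path is a made-up example, not taken from this commit):

    name = "model.layers.0.self_attn.q_proj"        # hypothetical named_modules() entry
    suffix = name.split("0.")[-1].split(".")[0]
    print(suffix)  # -> "self_attn", which becomes a valid choice for name_module_trainable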
@@ -21,7 +21,7 @@ if TYPE_CHECKING:
 logger = get_logger(__name__)


-require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2")
+require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
 require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
@@ -41,7 +41,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
         # Make sure tied weights are tied before creating the device map.
         model.tie_weights()
         device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
-        device_map_kwargs = {"device_map": device_map}
+        device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
         if "skip_keys" in inspect.signature(dispatch_model).parameters:
             device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
         return dispatch_model(model, **device_map_kwargs)
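Note on the hunk above: `accelerate.dispatch_model` takes an `offload_dir` argument naming the directory for weights that the inferred device map assigns to disk; supplying it up front avoids an error when any module ends up offloaded. A minimal standalone sketch of the same call (the pre-loaded `model` and the directory name are assumptions):

    from accelerate import dispatch_model, infer_auto_device_map

    # `model` is assumed to be an already-loaded PreTrainedModel
    device_map = infer_auto_device_map(model)
    model = dispatch_model(model, device_map=device_map, offload_dir="offload")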
@@ -1,7 +1,9 @@
 import json
+from contextlib import nullcontext
 from typing import TYPE_CHECKING, Dict, List, Literal, Optional

 import torch
+from transformers.integrations import is_deepspeed_zero3_enabled

 from ...extras.packages import is_requests_available

@@ -23,18 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.


 def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
-    if target == "reward":  # save default head temporarily
-        valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
-        setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
-        setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
-
-    model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
-    model.v_head.load_state_dict(
-        {
-            "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
-            "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
-        }
-    )
+    if is_deepspeed_zero3_enabled():
+        import deepspeed  # type: ignore
+
+        params = [model.v_head.summary.weight, model.v_head.summary.bias]
+        context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
+    else:
+        context_maybe_zero3 = nullcontext()
+
+    with context_maybe_zero3:
+        if target == "reward":  # save default head temporarily
+            setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
+            setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
+
+        model.pretrained_model.set_adapter(target)  # set the LoRA adapter to be active
+        model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
+        model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()


 def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
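Note on the hunk above: under DeepSpeed ZeRO-3 each rank holds only a shard of every parameter, so the value-head weights must be gathered before they can be read or swapped. `deepspeed.zero.GatheredParameters` is the context manager used for that; a stripped-down sketch of the pattern (the parameter and surrounding training setup are assumptions, not taken from this commit):

    import torch
    from contextlib import nullcontext

    def read_param(param: torch.nn.Parameter, zero3_enabled: bool) -> torch.Tensor:
        # Illustrative pattern only: gather the sharded parameter, let rank 0 read or
        # modify it (modifier_rank=0), then re-shard when the context exits.
        if zero3_enabled:
            import deepspeed  # type: ignore
            ctx = deepspeed.zero.GatheredParameters([param], modifier_rank=0)
        else:
            ctx = nullcontext()
        with ctx:
            return param.data.detach().clone()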