Merge branch 'hiyouga:main' into main

This commit is contained in:
Johann-Peter Hartmann 2024-02-04 12:51:25 +00:00 committed by GitHub
commit 63da6294dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 83 additions and 43 deletions

View File

@ -1,5 +1,5 @@
torch>=1.13.1
transformers>=4.36.2
transformers>=4.37.2
datasets>=2.14.3
accelerate>=0.21.0
peft>=0.7.0

View File

@ -120,6 +120,9 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
def chat_completion(messages: Sequence[Dict[str, str]], system: str, tools: str, request: ChatCompletionRequest):
if request.stream:
if tools:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
generate = stream_chat_completion(messages, system, tools, request)
return EventSourceResponse(generate, media_type="text/event-stream")

View File

@ -155,9 +155,6 @@ def get_dataset(
dataset = dataset.to_iterable_dataset()
return dataset
if data_args.streaming:
raise ValueError("Turn off dataset streaming to save cache files.")
with training_args.main_process_first(desc="load dataset"):
all_datasets = []
for dataset_attr in get_dataset_list(data_args): # TODO: add split

View File

@ -22,12 +22,8 @@ def preprocess_pretrain_dataset(
examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
# build grouped texts with format `X1 X2 X3 ...`
text_examples = [examples["prompt"][i][0]["content"] for i in range(len(examples["prompt"]))]
text_examples = [messages[0]["content"] + tokenizer.eos_token for messages in examples["prompt"]]
tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
for i in range(len(tokenized_examples["input_ids"])):
tokenized_examples["input_ids"][i] += [tokenizer.eos_token_id]
tokenized_examples["attention_mask"][i] += [1]
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
block_size = data_args.cutoff_len
@ -59,7 +55,12 @@ def preprocess_supervised_dataset(
input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(
template.encode_multiturn(
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
):
if data_args.train_on_prompt:
@ -147,7 +148,12 @@ def preprocess_unsupervised_dataset(
messages = examples["prompt"][i] + [{"role": Role.ASSISTANT, "content": ""}]
input_ids, labels = template.encode_oneturn(
tokenizer, messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:
@ -176,10 +182,20 @@ def preprocess_pairwise_dataset(
rejected_messages = examples["prompt"][i] + [examples["response"][i][1]]
prompt_ids, chosen_ids = template.encode_oneturn(
tokenizer, chosen_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
chosen_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
_, rejected_ids = template.encode_oneturn(
tokenizer, rejected_messages, examples["system"][i], examples["tools"][i], data_args.cutoff_len
tokenizer,
rejected_messages,
examples["system"][i],
examples["tools"][i],
data_args.cutoff_len,
data_args.reserved_label_len,
)
if template.efficient_eos:

View File

@ -37,7 +37,7 @@ class Template:
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 16,
reserved_label_len: Optional[int] = 1,
) -> Tuple[List[int], List[int]]:
r"""
Returns a single pair of token ids representing prompt and response respectively.
@ -57,7 +57,7 @@ class Template:
system: Optional[str] = None,
tools: Optional[str] = None,
cutoff_len: Optional[int] = 1_000_000,
reserved_label_len: Optional[int] = 16,
reserved_label_len: Optional[int] = 1,
) -> Sequence[Tuple[List[int], List[int]]]:
r"""
Returns multiple pairs of token ids representing prompts and responses respectively.
@ -144,10 +144,10 @@ class Template:
max_len=(cutoff_len - total_length),
reserved_label_len=reserved_label_len,
)
encoded_messages[i] = encoded_messages[i][:max_source_len]
encoded_messages[i + 1] = encoded_messages[i + 1][:max_target_len]
total_length += len(encoded_messages[i]) + len(encoded_messages[i + 1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i + 1]))
source_ids = encoded_messages[i][:max_source_len]
target_ids = encoded_messages[i + 1][:max_target_len]
total_length += len(source_ids) + len(target_ids)
encoded_pairs.append((source_ids, target_ids))
return encoded_pairs
@ -218,7 +218,7 @@ def register_template(
default_user_formatter = StringFormatter(slots=["{{content}}"])
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
default_tool_formatter = ToolFormatter(slots="default")
default_tool_formatter = ToolFormatter(tool_format="default")
default_separator_formatter = EmptyFormatter()
templates[name] = template_class(
format_user=format_user or default_user_formatter,
@ -356,6 +356,14 @@ register_template(
)
register_template(
name="cpm",
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)
register_template(
name="deepseek",
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
@ -464,7 +472,7 @@ register_template(
register_template(
name="orion",
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: </s>"]),
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
force_system=True,
)

View File

@ -21,10 +21,10 @@ class DataArguments:
default="train", metadata={"help": "Which dataset split to use for training and evaluation."}
)
cutoff_len: Optional[int] = field(
default=1024, metadata={"help": "The maximum length of the model inputs after tokenization."}
default=1024, metadata={"help": "The cutoff length of the model inputs after tokenization."}
)
reserved_label_len: Optional[int] = field(
default=1, metadata={"help": "The maximum length reserved for label after tokenization."}
default=1, metadata={"help": "The minimum cutoff length reserved for label after tokenization."}
)
train_on_prompt: Optional[bool] = field(
default=False, metadata={"help": "Whether to disable the mask on the prompt or not."}
@ -57,7 +57,7 @@ class DataArguments:
ignore_pad_token_for_loss: Optional[bool] = field(
default=True,
metadata={
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
},
)
val_size: Optional[float] = field(

View File

@ -17,6 +17,7 @@ class FreezeArguments:
BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
Qwen choices: ["mlp", "attn"], \
Phi choices: ["mlp", "mixer"], \
InternLM2 choices: ["feed_forward", "attention"], \
Others choices: the same as LLaMA.'
},
)

View File

@ -64,12 +64,11 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
raise ValueError("Cannot create new adapter upon a quantized model.")
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
if finetuning_args.finetuning_type != "lora":
raise ValueError("Multiple adapters are only available for LoRA tuning.")
if model_args.quantization_bit is not None:
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Only LoRA method has adapters.")
def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
parser = HfArgumentParser(_TRAIN_ARGS)

View File

@ -52,8 +52,18 @@ def init_adapter(
else: # fine-tuning the first n layers if num_layer_trainable < 0
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] # noqa: C416
freeze_modules = set()
for name, _ in model.named_modules():
if "0." in name:
freeze_modules.add(name.split("0.")[-1].split(".")[0])
trainable_layers = []
for module_name in finetuning_args.name_module_trainable:
if module_name not in freeze_modules:
raise ValueError(
"Module {} is not found, please choose from {}".format(module_name, ", ".join(freeze_modules))
)
for idx in trainable_layer_ids:
trainable_layers.append("{:d}.{}".format(idx, module_name))

View File

@ -21,7 +21,7 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2")
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")

View File

@ -41,7 +41,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
device_map_kwargs = {"device_map": device_map}
device_map_kwargs = {"device_map": device_map, "offload_dir": "offload"}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
return dispatch_model(model, **device_map_kwargs)

View File

@ -1,7 +1,9 @@
import json
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Literal, Optional
import torch
from transformers.integrations import is_deepspeed_zero3_enabled
from ...extras.packages import is_requests_available
@ -23,18 +25,22 @@ def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
if is_deepspeed_zero3_enabled():
import deepspeed # type: ignore
params = [model.v_head.summary.weight, model.v_head.summary.bias]
context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0)
else:
context_maybe_zero3 = nullcontext()
with context_maybe_zero3:
if target == "reward": # save default head temporarily
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
setattr(model, "default_head_weight", model.v_head.summary.weight.data.detach().clone())
setattr(model, "default_head_bias", model.v_head.summary.bias.data.detach().clone())
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
model.v_head.load_state_dict(
{
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone(),
}
)
model.v_head.summary.weight.data = model.get_buffer("{}_head_weight".format(target)).detach().clone()
model.v_head.summary.bias.data = model.get_buffer("{}_head_bias".format(target)).detach().clone()
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: