fix baichuan templates

This commit is contained in:
hiyouga 2023-09-07 18:54:14 +08:00
parent 0531886e1f
commit 85b1f6632a
9 changed files with 53 additions and 87 deletions

View File

@ -55,7 +55,7 @@
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |

View File

@ -55,7 +55,7 @@
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
| [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
| [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |

View File

@ -49,7 +49,7 @@ class ChatModel:
top_p=top_p or gen_kwargs["top_p"],
top_k=top_k or gen_kwargs["top_k"],
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id,
logits_processor=get_logits_processor()
))

View File

@ -63,7 +63,9 @@ def preprocess_dataset(
for query, response, history, system in construct_example(examples):
input_ids, labels = [], []
for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
tokenizer, query, response, history, system
)):
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(target_ids) > data_args.max_target_length:
@ -72,8 +74,17 @@ def preprocess_dataset(
if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
break
if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
else:
source_mask = [IGNORE_INDEX] * len(source_ids)
input_ids += source_ids + target_ids
labels += [IGNORE_INDEX] * len(source_ids) + target_ids
labels += source_mask + target_ids
if template.efficient_eos:
input_ids += [tokenizer.eos_token_id]
labels += [tokenizer.eos_token_id]
model_inputs["input_ids"].append(input_ids)
model_inputs["attention_mask"].append([1] * len(input_ids))

View File

@ -20,6 +20,7 @@ class Template:
sep: List[Union[str, Dict[str, str]]]
stop_words: List[str]
use_history: bool
efficient_eos: bool
def encode_oneturn(
self,
@ -74,19 +75,19 @@ class Template:
self,
tokenizer: "PreTrainedTokenizer"
) -> Tuple[List[int], List[int]]:
if (
tokenizer.bos_token_id is not None
and getattr(tokenizer, "add_bos_token", True)
): # baichuan-13b has no bos token
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
bos_ids = [tokenizer.bos_token_id]
else:
bos_ids = [] # bos token is optional
else: # baichuan, qwen and gpt2 models has no bos token
bos_ids = []
if tokenizer.eos_token_id is not None:
eos_ids = [tokenizer.eos_token_id]
else:
if tokenizer.eos_token_id is None:
raise ValueError("EOS token is required.")
if self.efficient_eos: # used in baichuan, qwen and gpt2 models
eos_ids = []
else:
eos_ids = [tokenizer.eos_token_id]
return bos_ids, eos_ids
def _encode(
@ -186,7 +187,8 @@ def register_template(
system: str,
sep: List[Union[str, Dict[str, str]]],
stop_words: Optional[List[str]] = [],
use_history: Optional[bool] = True
use_history: Optional[bool] = True,
efficient_eos: Optional[bool] = False
) -> None:
template_class = Llama2Template if "llama2" in name else Template
templates[name] = template_class(
@ -195,7 +197,8 @@ def register_template(
system=system,
sep=sep,
stop_words=stop_words,
use_history=use_history
use_history=use_history,
efficient_eos=efficient_eos
)
@ -206,15 +209,6 @@ def get_template_and_fix_tokenizer(
template = templates.get(name, None)
assert template is not None, "Template {} does not exist.".format(name)
additional_special_tokens = template.stop_words
if len(template.stop_words): # inplace method
if tokenizer.eos_token_id is not None:
additional_special_tokens.append(tokenizer.eos_token)
tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
additional_special_tokens.pop(0)
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
if tokenizer.eos_token_id is None:
tokenizer.eos_token = "<|endoftext|>"
logger.info("Add eos token: {}".format(tokenizer.eos_token))
@ -227,7 +221,7 @@ def get_template_and_fix_tokenizer(
logger.info("Add pad token: {}".format(tokenizer.pad_token))
tokenizer.add_special_tokens(
dict(additional_special_tokens=additional_special_tokens),
dict(additional_special_tokens=template.stop_words),
replace_additional_special_tokens=False
)
return template
@ -466,18 +460,18 @@ register_template(
],
system="",
sep=[
{"token": "<eoa>"},
"\n"
],
stop_words=[
"</s>", # internlm cannot replace eos token
"<eoa>"
]
],
efficient_eos=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
Used for training and inference of the fine-tuned models.
"""
register_template(
name="baichuan",
@ -487,39 +481,17 @@ register_template(
prompt=[
{"token": "<reserved_102>"}, # user token
"{{query}}",
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[]
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
Used for inference of the original model.
"""
register_template(
name="baichuan_eval",
prefix=[
"{{system}}",
{"token": "<reserved_102>"} # user token
],
prompt=[
"{{query}}",
{"token": "<reserved_103>"} # assistant token
{"token": "<reserved_103>"} # assistant token
],
system="",
sep=[],
stop_words=[
"<reserved_102>" # user token
]
efficient_eos=True
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
Used for training and inference of the fine-tuned models.
"""
register_template(
name="baichuan2",
@ -529,33 +501,11 @@ register_template(
prompt=[
{"token": "<reserved_106>"}, # user token
"{{query}}",
{"token": "<reserved_107>"} # assistant token
],
system="",
sep=[]
)
r"""
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
Used for inference of the original model.
"""
register_template(
name="baichuan2_eval",
prefix=[
"{{system}}",
{"token": "<reserved_106>"} # user token
],
prompt=[
"{{query}}",
{"token": "<reserved_107>"} # assistant token
{"token": "<reserved_107>"} # assistant token
],
system="",
sep=[],
stop_words=[
"<reserved_106>" # user token
]
efficient_eos=True
)
@ -568,7 +518,6 @@ register_template(
prefix=[
{"token": "<|system|>"},
"\n{{system}}",
{"token": "<|end|>"}
],
prompt=[
{"token": "<|user|>"},
@ -579,11 +528,13 @@ register_template(
],
system="",
sep=[
{"token": "<|end|>"},
"\n"
],
stop_words=[
"<|end|>"
]
],
efficient_eos=True
)
@ -594,8 +545,7 @@ register_template(
name="chatml",
prefix=[
{"token": "<|im_start|>"},
"system\n{{system}}",
{"token": "<|im_end|>"}
"system\n{{system}}"
],
prompt=[
{"token": "<|im_start|>"},
@ -607,11 +557,13 @@ register_template(
],
system="You are a helpful assistant.",
sep=[
{"token": "<|im_end|>"},
"\n"
],
stop_words=[
"<|im_end|>"
]
],
efficient_eos=True
)

View File

@ -15,9 +15,13 @@ from transformers import (
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.deepspeed import is_deepspeed_zero3_enabled
from trl import AutoModelForCausalLMWithValueHead
try:
from transformers.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
from transformers.integrations import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
@ -91,7 +95,7 @@ def load_model_and_tokenizer(
setattr(config, "use_logn_attn", True)
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA models
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
if is_trainable:

View File

@ -76,7 +76,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Keyword arguments for `model.generate`
gen_kwargs = self.generating_args.to_dict()
gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()

View File

@ -6,7 +6,6 @@ from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler
from transformers.utils.versions import require_version
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.ploting import plot_loss

View File

@ -54,7 +54,7 @@ def run_sft(
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()