fix baichuan templates

parent 0531886e1f
commit 85b1f6632a
@@ -55,7 +55,7 @@
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
-| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
@@ -55,7 +55,7 @@
 | [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
 | [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B | query_key_value | - |
-| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B | W_pack | baichuan |
+| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
 | [InternLM](https://github.com/InternLM/InternLM) | 7B | q_proj,v_proj | intern |
 | [Qwen](https://github.com/QwenLM/Qwen-7B) | 7B | c_attn | chatml |
@@ -49,7 +49,7 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
+            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             pad_token_id=self.tokenizer.pad_token_id,
             logits_processor=get_logits_processor()
         ))
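Dropping the `list(set(...))` wrapper is safe because of the template changes below: `get_template_and_fix_tokenizer` no longer re-registers the eos token among the additional special tokens, so the two id lists are disjoint by construction. A minimal sketch with illustrative ids (not from any real tokenizer):

# Illustrative ids only: previously the eos token could also appear among
# the additional special tokens, so the combined list needed deduplication.
eos_token_id = 2
old_extra_ids = [2, 64000]   # old behavior: eos re-registered as a stop word
new_extra_ids = [64000]      # new behavior: stop words only
old_ids = list(set([eos_token_id] + old_extra_ids))
new_ids = [eos_token_id] + new_extra_ids
assert sorted(old_ids) == sorted(new_ids) == [2, 64000]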
@@ -63,7 +63,9 @@ def preprocess_dataset(
         for query, response, history, system in construct_example(examples):
             input_ids, labels = [], []

-            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
+            for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
+                tokenizer, query, response, history, system
+            )):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
                 if len(target_ids) > data_args.max_target_length:
@@ -72,8 +74,17 @@ def preprocess_dataset(
                 if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break

+                if turn_idx != 0 and template.efficient_eos: # used in baichuan, qwen and gpt2 models
+                    source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
+                else:
+                    source_mask = [IGNORE_INDEX] * len(source_ids)
+
                 input_ids += source_ids + target_ids
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids
+                labels += source_mask + target_ids
+
+            if template.efficient_eos:
+                input_ids += [tokenizer.eos_token_id]
+                labels += [tokenizer.eos_token_id]

             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
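This hunk is the core of the "efficient eos" scheme: templates such as baichuan, qwen and gpt2 end each turn with their own separator token rather than the tokenizer eos, so the eos that closes reply N is supervised at the first position of turn N+1's source span, and one final eos is appended after the last turn. A minimal sketch with made-up token ids:

# Made-up ids: 2 = eos, the rest are arbitrary.
IGNORE_INDEX = -100
EOS = 2
turns = [([5, 6], [7, 8]), ([9], [10])]   # (source_ids, target_ids) per turn

input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(turns):
    if turn_idx != 0:                     # efficient_eos=True assumed throughout
        source_mask = [EOS] + [IGNORE_INDEX] * (len(source_ids) - 1)
    else:
        source_mask = [IGNORE_INDEX] * len(source_ids)
    input_ids += source_ids + target_ids
    labels += source_mask + target_ids
input_ids += [EOS]
labels += [EOS]

assert labels == [-100, -100, 7, 8, 2, 10, 2]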
@@ -20,6 +20,7 @@ class Template:
     sep: List[Union[str, Dict[str, str]]]
     stop_words: List[str]
     use_history: bool
+    efficient_eos: bool

     def encode_oneturn(
         self,
@@ -74,19 +75,19 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer"
     ) -> Tuple[List[int], List[int]]:
-        if (
-            tokenizer.bos_token_id is not None
-            and getattr(tokenizer, "add_bos_token", True)
-        ): # baichuan-13b has no bos token
+        if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
             bos_ids = [tokenizer.bos_token_id]
-        else:
-            bos_ids = [] # bos token is optional
+        else: # baichuan, qwen and gpt2 models have no bos token
+            bos_ids = []

-        if tokenizer.eos_token_id is not None:
-            eos_ids = [tokenizer.eos_token_id]
-        else:
+        if tokenizer.eos_token_id is None:
             raise ValueError("EOS token is required.")

+        if self.efficient_eos: # used in baichuan, qwen and gpt2 models
+            eos_ids = []
+        else:
+            eos_ids = [tokenizer.eos_token_id]
+
         return bos_ids, eos_ids

     def _encode(
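The rewritten helper makes the contract explicit: bos is optional (Baichuan-13B, Qwen and GPT-2 style tokenizers have none), eos is mandatory, and efficient_eos templates return no per-turn eos because the data pipeline now appends it (see the preprocessing hunk above). A sketch against a hypothetical tokenizer stub:

# Hypothetical stand-in for a PreTrainedTokenizer; ids are illustrative.
class StubTokenizer:
    bos_token_id = None   # baichuan/qwen/gpt2 style: no bos token
    eos_token_id = 2
    add_bos_token = True

tokenizer = StubTokenizer()
efficient_eos = True      # as set for the templates registered below

if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
    bos_ids = [tokenizer.bos_token_id]
else:
    bos_ids = []

if tokenizer.eos_token_id is None:
    raise ValueError("EOS token is required.")
eos_ids = [] if efficient_eos else [tokenizer.eos_token_id]

assert (bos_ids, eos_ids) == ([], [])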
@@ -186,7 +187,8 @@ def register_template(
     system: str,
     sep: List[Union[str, Dict[str, str]]],
     stop_words: Optional[List[str]] = [],
-    use_history: Optional[bool] = True
+    use_history: Optional[bool] = True,
+    efficient_eos: Optional[bool] = False
 ) -> None:
     template_class = Llama2Template if "llama2" in name else Template
     templates[name] = template_class(
@@ -195,7 +197,8 @@ def register_template(
         system=system,
         sep=sep,
         stop_words=stop_words,
-        use_history=use_history
+        use_history=use_history,
+        efficient_eos=efficient_eos
     )


@@ -206,15 +209,6 @@ def get_template_and_fix_tokenizer(
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)

-    additional_special_tokens = template.stop_words
-    if len(template.stop_words): # inplace method
-        if tokenizer.eos_token_id is not None:
-            additional_special_tokens.append(tokenizer.eos_token)
-
-        tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
-        additional_special_tokens.pop(0)
-        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
-
     if tokenizer.eos_token_id is None:
         tokenizer.eos_token = "<|endoftext|>"
         logger.info("Add eos token: {}".format(tokenizer.eos_token))
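The deleted block was the heart of the bug: `additional_special_tokens = template.stop_words` aliases the template's own list, so popping the first stop word to repurpose it as the eos token mutated the template in place on every call. A small illustration of the aliasing, with a made-up stop word:

# The assignment below copies a reference, not the list itself.
stop_words = ["<eoa>"]                  # imagine this is template.stop_words
additional_special_tokens = stop_words  # alias, not a copy
additional_special_tokens.pop(0)        # "use the first stop word as eos token"
assert stop_words == []                 # the template's state was silently destroyed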
@@ -227,7 +221,7 @@ def get_template_and_fix_tokenizer(
         logger.info("Add pad token: {}".format(tokenizer.pad_token))

     tokenizer.add_special_tokens(
-        dict(additional_special_tokens=additional_special_tokens),
+        dict(additional_special_tokens=template.stop_words),
         replace_additional_special_tokens=False
     )
     return template
@@ -466,18 +460,18 @@ register_template(
     ],
     system="",
     sep=[
+        {"token": "<eoa>"},
         "\n"
     ],
     stop_words=[
-        "</s>", # internlm cannot replace eos token
         "<eoa>"
-    ]
+    ],
+    efficient_eos=True
 )


 r"""
 Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-Used for training and inference of the fine-tuned models.
 """
 register_template(
     name="baichuan",
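For internlm, `<eoa>` now terminates every assistant turn through sep and efficient_eos lets the pipeline supervise the real eos, so the extra `</s>` stop word is no longer needed. Roughly, in string form (a hedged sketch; the prompt body itself is elided in this hunk):

# sep=[{"token": "<eoa>"}, "\n"] renders as "<eoa>\n" between turns.
sep = "<eoa>" + "\n"
turns = ["<turn 1>", "<turn 2>"]      # placeholders for fully rendered turns
assert sep.join(turns) == "<turn 1><eoa>\n<turn 2>"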
@@ -487,39 +481,17 @@ register_template(
     prompt=[
         {"token": "<reserved_102>"}, # user token
         "{{query}}",
         {"token": "<reserved_103>"} # assistant token
     ],
     system="",
-    sep=[]
-)
-
-
-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
-Used for inference of the original model.
-"""
-register_template(
-    name="baichuan_eval",
-    prefix=[
-        "{{system}}",
-        {"token": "<reserved_102>"} # user token
-    ],
-    prompt=[
-        "{{query}}",
-        {"token": "<reserved_103>"} # assistant token
-    ],
-    system="",
-    sep=[],
-    stop_words=[
-        "<reserved_102>" # user token
-    ]
+    sep=[],
+    efficient_eos=True
 )


 r"""
 Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
           https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
-Used for training and inference of the fine-tuned models.
 """
 register_template(
     name="baichuan2",
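With efficient_eos the separate `baichuan_eval` template becomes redundant: one `baichuan` template now covers both training (the pipeline appends the real eos after each reply) and inference (generation stops on that eos rather than on the user token `<reserved_102>`, which previously had to double as a stop word). A rough prompt layout with assumed ids:

# Assumed ids for illustration; real values come from the Baichuan tokenizer.
RESERVED_102 = 195            # user token (hypothetical id)
RESERVED_103 = 196            # assistant token (hypothetical id)
query_ids = [1001, 1002]      # tokenized "{{query}}", made up

prompt_ids = [RESERVED_102] + query_ids + [RESERVED_103]
print(prompt_ids)             # the model continues from here and ends with its own eos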
@@ -529,33 +501,11 @@ register_template(
     prompt=[
         {"token": "<reserved_106>"}, # user token
         "{{query}}",
         {"token": "<reserved_107>"} # assistant token
     ],
     system="",
-    sep=[]
-)
-
-
-r"""
-Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
-          https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
-Used for inference of the original model.
-"""
-register_template(
-    name="baichuan2_eval",
-    prefix=[
-        "{{system}}",
-        {"token": "<reserved_106>"} # user token
-    ],
-    prompt=[
-        "{{query}}",
-        {"token": "<reserved_107>"} # assistant token
-    ],
-    system="",
-    sep=[],
-    stop_words=[
-        "<reserved_106>" # user token
-    ]
+    sep=[],
+    efficient_eos=True
 )


@@ -568,7 +518,6 @@ register_template(
     prefix=[
         {"token": "<|system|>"},
         "\n{{system}}",
-        {"token": "<|end|>"}
     ],
     prompt=[
         {"token": "<|user|>"},
@@ -579,11 +528,13 @@ register_template(
     ],
     system="",
     sep=[
+        {"token": "<|end|>"},
         "\n"
     ],
     stop_words=[
         "<|end|>"
-    ]
+    ],
+    efficient_eos=True
 )


@@ -594,8 +545,7 @@ register_template(
     name="chatml",
     prefix=[
         {"token": "<|im_start|>"},
-        "system\n{{system}}",
-        {"token": "<|im_end|>"}
+        "system\n{{system}}"
     ],
     prompt=[
         {"token": "<|im_start|>"},
@@ -607,11 +557,13 @@ register_template(
     ],
     system="You are a helpful assistant.",
     sep=[
+        {"token": "<|im_end|>"},
         "\n"
     ],
     stop_words=[
         "<|im_end|>"
-    ]
+    ],
+    efficient_eos=True
 )


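For chatml (and the analogous `<|end|>` template above), the end-of-message token moves out of the prefix and into sep, so the system block and every turn are closed by the same `<|im_end|>\n` terminator instead of the system block carrying its own. In string form (hedged; the full prompt lines are elided in these hunks):

system_block = "<|im_start|>" + "system\n" + "You are a helpful assistant."
sep = "<|im_end|>" + "\n"   # sep now supplies the terminator for every segment
assert system_block + sep == "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"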
@@ -15,9 +15,13 @@ from transformers import (
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
-from transformers.deepspeed import is_deepspeed_zero3_enabled
 from trl import AutoModelForCausalLMWithValueHead

+try:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
+except ImportError:
+    from transformers.integrations import is_deepspeed_zero3_enabled
+
 from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
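The guarded import keeps the loader working across transformers releases: `is_deepspeed_zero3_enabled` historically lived in `transformers.deepspeed`, and newer versions re-home it in `transformers.integrations` (the exact cutover version is not stated in this commit). Usage is unchanged, e.g.:

try:
    from transformers.deepspeed import is_deepspeed_zero3_enabled
except ImportError:  # newer transformers moved the helper
    from transformers.integrations import is_deepspeed_zero3_enabled

print(is_deepspeed_zero3_enabled())  # False when no DeepSpeed ZeRO-3 config is active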
@@ -91,7 +95,7 @@ def load_model_and_tokenizer(
             setattr(config, "use_logn_attn", True)
             logger.info("Using dynamic NTK scaling.")

-    elif hasattr(config, "rope_scaling"): # for LLaMA models
+    elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
         require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")

         if is_trainable:
@@ -76,7 +76,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):

         # Keyword arguments for `model.generate`
         gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
+        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
         gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
         gen_kwargs["logits_processor"] = get_logits_processor()

@@ -6,7 +6,6 @@ from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
-from transformers.utils.versions import require_version

 from llmtuner.dsets import get_dataset, preprocess_dataset
 from llmtuner.extras.ploting import plot_loss
@@ -54,7 +54,7 @@ def run_sft(

     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
     gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
