fix llama2 template

This commit is contained in:
hiyouga 2023-08-05 00:07:54 +08:00
parent f30fc3b030
commit e4a15f863c
1 changed files with 9 additions and 4 deletions

View File

@ -58,7 +58,7 @@ class Template:
r"""
Aligns inputs to a special format.
"""
prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided
prefix = [prefix] if prefix else self.prefix # use prefix if provided
prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix
history = history if (history and self.use_history) else []
history = history + [(query, resp)]
@ -124,6 +124,11 @@ class Llama2Template(Template):
r"""
Encodes formatted inputs to pairs of token ids.
"""
if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional
bos_token_id = [tokenizer.bos_token_id]
else:
bos_token_id = []
eos_token_id = [tokenizer.eos_token_id] # eos token is required
encoded_pairs = []
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
for turn_idx, (query, resp) in enumerate(history):
@ -134,7 +139,7 @@ class Llama2Template(Template):
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
encoded_pairs.append((prefix_ids + query_ids, resp_ids))
encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
return encoded_pairs
@ -154,8 +159,8 @@ def register_template(
prefix=prefix,
prompt=prompt,
sep=sep,
use_history=use_history,
stop_words=stop_words
stop_words=stop_words,
use_history=use_history
)