fix llama2 template
This commit is contained in:
parent
f30fc3b030
commit
e4a15f863c
|
@ -58,7 +58,7 @@ class Template:
|
|||
r"""
|
||||
Aligns inputs to a special format.
|
||||
"""
|
||||
prefix = [prefix] if prefix is not None else self.prefix # use prefix if provided
|
||||
prefix = [prefix] if prefix else self.prefix # use prefix if provided
|
||||
prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix
|
||||
history = history if (history and self.use_history) else []
|
||||
history = history + [(query, resp)]
|
||||
|
@ -124,6 +124,11 @@ class Llama2Template(Template):
|
|||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
"""
|
||||
if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional
|
||||
bos_token_id = [tokenizer.bos_token_id]
|
||||
else:
|
||||
bos_token_id = []
|
||||
eos_token_id = [tokenizer.eos_token_id] # eos token is required
|
||||
encoded_pairs = []
|
||||
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
|
@ -134,7 +139,7 @@ class Llama2Template(Template):
|
|||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((prefix_ids + query_ids, resp_ids))
|
||||
encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
|
||||
return encoded_pairs
|
||||
|
||||
|
||||
|
@ -154,8 +159,8 @@ def register_template(
|
|||
prefix=prefix,
|
||||
prompt=prompt,
|
||||
sep=sep,
|
||||
use_history=use_history,
|
||||
stop_words=stop_words
|
||||
stop_words=stop_words,
|
||||
use_history=use_history
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue