chore: update good to use tokenization_9g.py
This commit is contained in:
parent
e6e5cd97e1
commit
bab7110cf2
|
@ -67,6 +67,8 @@ class CPM9GTokenizer(PreTrainedTokenizer):
|
|||
if vocab_file:
|
||||
if 'vocab.txt' not in vocab_file:
|
||||
all_tokens = self.load_vocab(io.FileIO(vocab_file + VOCAB_FILES_NAMES['vocab_file'], "rb"))
|
||||
else:
|
||||
all_tokens = self.load_vocab(io.FileIO(vocab_file, "rb"))
|
||||
else:
|
||||
all_tokens = self.load_vocab(io.FileIO(VOCAB_FILES_NAMES['vocab_file'], "rb"))
|
||||
|
||||
|
@ -305,7 +307,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
|
|||
return text[0]
|
||||
|
||||
|
||||
def encode(self, text: str) -> List[int]:
|
||||
def encode(self, text: str, add_special_tokens: bool) -> List[int]:
|
||||
"""
|
||||
将文本编码为 ID 列表。
|
||||
|
||||
|
@ -380,7 +382,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
|
|||
st += 1
|
||||
return "".join(ret)
|
||||
|
||||
def decode(self, tokens: List[int]) -> str:
|
||||
def decode(self, tokens: List[int], skip_special_tokens: bool) -> str:
|
||||
"""
|
||||
将 ID 列表解码为字符串。
|
||||
|
||||
|
@ -440,7 +442,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
|
|||
#else:
|
||||
# ret.append(self.unk_token)
|
||||
# st += 1
|
||||
return ''.join(ret)
|
||||
return ''.join([str(item) for item in ret])
|
||||
|
||||
def _encode_unicode(self, token: str) -> List[int]:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue