chore: update good to use tokenization_9g.py

wql 2024-10-23 16:16:50 +08:00
parent e6e5cd97e1
commit bab7110cf2
1 changed file with 5 additions and 3 deletions


@@ -67,6 +67,8 @@ class CPM9GTokenizer(PreTrainedTokenizer):
         if vocab_file:
             if 'vocab.txt' not in vocab_file:
                 all_tokens = self.load_vocab(io.FileIO(vocab_file + VOCAB_FILES_NAMES['vocab_file'], "rb"))
+            else:
+                all_tokens = self.load_vocab(io.FileIO(vocab_file, "rb"))
         else:
             all_tokens = self.load_vocab(io.FileIO(VOCAB_FILES_NAMES['vocab_file'], "rb"))
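As the old code stood, a vocab_file path that already contained 'vocab.txt' matched neither branch, leaving all_tokens unassigned; the two added lines close that gap. A minimal, self-contained sketch of the resulting resolution order (the VOCAB_FILES_NAMES mapping here is an assumption inferred from the 'vocab.txt' check, not taken from this diff):

import io
from typing import Optional

VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}  # assumed default mapping

def open_vocab(vocab_file: Optional[str]) -> io.FileIO:
    if vocab_file:
        if 'vocab.txt' not in vocab_file:
            # Treat the argument as a directory/prefix and append the default name.
            return io.FileIO(vocab_file + VOCAB_FILES_NAMES['vocab_file'], "rb")
        # New branch: the argument already names the vocab file; open it as-is.
        return io.FileIO(vocab_file, "rb")
    # No argument: fall back to the default name in the working directory.
    return io.FileIO(VOCAB_FILES_NAMES['vocab_file'], "rb")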
@@ -305,7 +307,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
         return text[0]
-    def encode(self, text: str) -> List[int]:
+    def encode(self, text: str, add_special_tokens: bool) -> List[int]:
         """
         Encode text into a list of token IDs
@@ -380,7 +382,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
             st += 1
         return "".join(ret)
-    def decode(self, tokens: List[int]) -> str:
+    def decode(self, tokens: List[int], skip_special_tokens: bool) -> str:
         """
         Decode a list of IDs into a string
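These two hunks widen encode and decode toward the transformers PreTrainedTokenizer calling convention. A hypothetical call-site sketch (the constructor arguments are assumptions for illustration, not taken from this diff):

from typing import List

# Assumed setup; the real vocab path depends on the deployment.
tokenizer = CPM9GTokenizer(vocab_file="vocab.txt")

ids: List[int] = tokenizer.encode("some text", add_special_tokens=False)
text: str = tokenizer.decode(ids, skip_special_tokens=True)

Note that as written in the hunks the new flags carry no default values, so existing callers that pass only text or tokens will raise TypeError until they are updated.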
@@ -440,7 +442,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
             #else:
             #    ret.append(self.unk_token)
             #    st += 1
-        return ''.join(ret)
+        return ''.join([str(item) for item in ret])
     def _encode_unicode(self, token: str) -> List[int]:
         """