diff --git a/9g_config/tokenization_9g.py b/9g_config/tokenization_9g.py
index e105acf3..385bd7bd 100644
--- a/9g_config/tokenization_9g.py
+++ b/9g_config/tokenization_9g.py
@@ -67,6 +67,8 @@ class CPM9GTokenizer(PreTrainedTokenizer):
         if vocab_file:
             if 'vocab.txt' not in vocab_file:
                 all_tokens = self.load_vocab(io.FileIO(vocab_file + VOCAB_FILES_NAMES['vocab_file'], "rb"))
+            else:
+                all_tokens = self.load_vocab(io.FileIO(vocab_file, "rb"))
         else:
             all_tokens = self.load_vocab(io.FileIO(VOCAB_FILES_NAMES['vocab_file'], "rb"))
 
@@ -305,7 +307,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
 
         return text[0]
 
-    def encode(self, text: str) -> List[int]:
+    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
         """
         将文本编码为 ID 列表。
 
@@ -380,7 +382,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
             st += 1
         return "".join(ret)
 
-    def decode(self, tokens: List[int]) -> str:
+    def decode(self, tokens: List[int], skip_special_tokens: bool = False) -> str:
         """
         将 ID 列表解码为字符串。
 
@@ -440,7 +442,7 @@ class CPM9GTokenizer(PreTrainedTokenizer):
         #else:
         #    ret.append(self.unk_token)
         #    st += 1
-        return ''.join(ret)
+        return ''.join([str(item) for item in ret])
 
     def _encode_unicode(self, token: str) -> List[int]:
         """