CPM-9G-8B/FM_9G/fm9g/tokenizer/fm9g.py

212 lines
7.6 KiB
Python
Raw Normal View History

2024-02-27 14:33:33 +08:00
import io
import json
from typing import Dict
from typing import IO
from typing import List
import pkg_resources
from pytrie import StringTrie
def load_vocab(fp: IO[bytes]) -> Dict[str, int]:
    """Read a vocabulary stream and map every token to an integer id.

    The stream is decoded as UTF-8 and read line by line. Each non-blank line
    (after stripping surrounding whitespace) must be a JSON value -- in
    practice a JSON string -- which becomes the token. A token's id is the
    number of entries already stored in the mapping when it is inserted.

    Args:
        fp: Binary file-like object holding one JSON-encoded token per line.

    Returns:
        Dictionary mapping each token to its id.
    """
    token_to_id: Dict[str, int] = {}
    for raw_line in io.TextIOWrapper(fp, encoding="utf-8"):
        stripped = raw_line.strip()
        if not stripped:
            # Blank lines carry no token and do not consume an id.
            continue
        token_to_id[json.loads(stripped)] = len(token_to_id)
    return token_to_id
2024-07-15 14:27:10 +08:00
class FM9GTokenizer(object):
    """Greedy longest-match tokenizer with UTF-8 byte fallback.

    The vocabulary is split into two groups:

    * regular tokens (``self.encoder`` / ``self.decoder``), matched greedily
      against the input text, and
    * special tokens (``self._special_encoder``): ``<unk>``, ``<s>``, ``</s>``
      and the 256 byte tokens ``<0x00>`` .. ``<0xFF>``.

    Text pieces that are not regular tokens are encoded as the byte tokens of
    their UTF-8 encoding.
    """

    def __init__(self, path=None):
        """Load the vocabulary and build the lookup tables.

        Args:
            path: Optional path to a vocabulary file. When falsy, the
                vocabulary bundled as a package resource of ``fm9g`` is used.

        Raises:
            KeyError: If one of the 256 byte tokens is missing from the
                vocabulary.
        """
        self.unk_token = "<unk>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        # "<0x00>" .. "<0xFF>": one token per possible byte value.
        self.byte_list = ["<0x{:02X}>".format(i) for i in range(0x100)]
        self._special_token_set = {self.unk_token, self.bos_token, self.eos_token, *self.byte_list}

        # Close the vocabulary stream deterministically instead of relying on
        # garbage collection.
        if path:
            with open(path, "rb") as fp:
                all_tokens = load_vocab(fp)
        else:
            with pkg_resources.resource_stream("fm9g", "/fm9g/vocabs/fm9g.txt") as fp:
                all_tokens = load_vocab(fp)

        self.encoder: Dict[str, int] = {}
        self._special_encoder: Dict[str, int] = {}
        for token, token_id in all_tokens.items():
            if token in self._special_token_set:
                self._special_encoder[token] = token_id
            else:
                self.encoder[token] = token_id

        self.decoder = {v: k for k, v in self.encoder.items()}
        # id of byte token -> byte value (0..255)
        self._byte_decoder = {self._special_encoder[token]: i for i, token in enumerate(self.byte_list)}

        # default=0 keeps construction working when there are no regular tokens.
        self._max_word_len = max((len(x) for x in self.encoder), default=0)

        # first character -> length of the longest regular token starting
        # with that character; bounds the search window in get_piece().
        self._len_word_first = {}
        for token in self.encoder:
            first = token[0]
            if len(token) > self._len_word_first.get(first, 0):
                self._len_word_first[first] = len(token)

        self.tencoder = StringTrie(self.encoder)

    def get_piece(self, text: str) -> str:
        """Return the longest regular token that is a prefix of ``text``.

        If no regular token matches, the first character of ``text`` is
        returned instead.

        Args:
            text: Non-empty text to match against.

        Returns:
            The matched token, or ``text[0]`` when nothing matches.
        """
        first_char = text[0]
        max_len = self._len_word_first.get(first_char)
        if max_len is None:
            # No regular token starts with this character, so no prefix can
            # match; skip the (otherwise linear) scan over the whole text.
            return first_char
        candidate = text[:max_len]
        for end in range(len(candidate), 0, -1):
            prefix = candidate[:end]
            if prefix in self.encoder:
                return prefix
        return first_char

    @property
    def vocab_size(self):
        """Total number of tokens (regular plus special)."""
        return len(self)

    @property
    def eos_id(self):
        """Id of the end-of-sequence token ``</s>``."""
        return self._special_encoder[self.eos_token]

    @property
    def bos_id(self):
        """Id of the beginning-of-sequence token ``<s>``."""
        return self._special_encoder[self.bos_token]

    @property
    def unk_id(self):
        """Id of the unknown token ``<unk>``."""
        return self._special_encoder[self.unk_token]

    def __len__(self):
        return len(self.encoder) + len(self._special_encoder)

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into pieces by repeated greedy longest-prefix matching.

        Args:
            text: Text to split.

        Returns:
            The pieces in order; concatenating them reproduces ``text``.
        """
        output_tokens: List[str] = []
        st = 0
        while st < len(text):
            piece = self.get_piece(text[st:])
            output_tokens.append(piece)
            st += len(piece)
        return output_tokens

    @staticmethod
    def escape(text: str) -> str:
        """Return ``text`` unchanged."""
        return text

    @staticmethod
    def unescape(text: str) -> str:
        """Return ``text`` unchanged."""
        return text

    def encode(self, text: str) -> List[int]:
        """Convert ``text`` into a list of token ids.

        Pieces that are regular tokens map to their id; any other piece is
        emitted as the byte tokens of its UTF-8 encoding.

        Args:
            text: Text to encode.

        Returns:
            The token ids.
        """
        ret: List[int] = []
        for piece in self.tokenize(text):
            if piece in self.encoder:
                ret.append(self.encoder[piece])
            else:
                ret.extend(self._encode_unicode(piece))
        return ret

    def decode(self, tokens: List[int]) -> str:
        """Decode ids into a string.

        Consecutive byte tokens are collected into a single byte run and
        decoded together as UTF-8, so adjacent multi-byte characters are
        reconstructed correctly no matter how many bytes each one occupies.
        Bytes that do not form valid UTF-8 become U+FFFD. Ids that are neither
        regular nor byte tokens are rendered as ``</s>``, ``<s>`` or ``<unk>``.

        Args:
            tokens: Token ids to decode.

        Returns:
            The decoded text.
        """
        pieces: List[str] = []
        pending = bytearray()  # byte-fallback run not yet turned into text
        for token_id in tokens:
            is_regular = token_id in self.decoder
            if not is_regular and token_id in self._byte_decoder:
                pending.append(self._byte_decoder[token_id])
                continue
            if pending:
                # A non-byte token ends the current run of raw bytes.
                pieces.append(pending.decode("utf-8", errors="replace"))
                pending = bytearray()
            if is_regular:
                pieces.append(self.decoder[token_id])
            elif token_id == self.eos_id:
                pieces.append(self.eos_token)
            elif token_id == self.bos_id:
                pieces.append(self.bos_token)
            else:
                pieces.append(self.unk_token)
        if pending:
            pieces.append(pending.decode("utf-8", errors="replace"))
        return "".join(pieces)

    def _encode_unicode(self, token):
        """Return the byte-token ids for the UTF-8 encoding of ``token``."""
        return [self._special_encoder[self.byte_list[byte]] for byte in token.encode("utf-8")]

    def next_token(self, text):
        """Match the next token at the start of ``text`` using the trie.

        Args:
            text: Text to match; expected to be non-empty.

        Returns:
            A ``(token, token_ids)`` pair. When a regular token is a prefix of
            ``text``, the longest one is returned with its single id.
            Otherwise the first character is returned together with its
            byte-token ids.
        """
        token, token_id = self.tencoder.longest_prefix_item(text, (None, None))
        if token is None:
            token = text[0]
            token_ids = self._encode_unicode(token)
        else:
            token_ids = [token_id]
        return token, token_ids