# forked from jiuyuan/CPM-9G-8B
# Standard library
import argparse
import json
import os

# Project-local inference runtime for the CPM-9G model
from libcpm import CPM9G
def main():
    """Load a CPM9G checkpoint from the command line and print replies for a few sample prompts.

    Command-line arguments:
        --pt:     path of the model checkpoint (.pt) file
        --config: path of the JSON model config file
        --vocab:  path of the vocabulary file
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pt", type=str, help="the path of ckpt")
    parser.add_argument("--config", type=str, help="the path of config file")
    parser.add_argument("--vocab", type=str, help="the path of vocab file")
    args = parser.parse_args()

    # Use a context manager so the config file handle is closed deterministically
    # (the original `json.load(open(...))` leaked the handle).
    with open(args.config, "r", encoding="utf-8") as config_file:
        model_config = json.load(config_file)
    model_config["new_vocab"] = True

    model = CPM9G(
        "",
        args.vocab,
        -1,
        # Set memory_limit according to the GPU's VRAM; on an A100 it can be
        # raised (e.g. 72 << 30) to make use of more memory.
        memory_limit=30 << 30,
        model_config=model_config,
        load_model=False,
    )
    model.load_model_pt(args.pt)

    # Sample chat-formatted prompts ("<用户>" = user turn, "<AI>" = model turn).
    datas = [
        '''<用户>马化腾是谁?<AI>''',
        '''<用户>你是谁?<AI>''',
        '''<用户>我要参加一个高性能会议,请帮我写一个致辞。<AI>''',
    ]

    for data in datas:
        res = model.inference(data, max_length=4096)
        print(res['result'])
||
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()