CPM-9G-8B/9g_inference.py

32 lines
1.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from libcpm import CPM9G
import argparse, json, os
def main(argv=None):
    """Run a few demo prompts through a CPM-9G checkpoint and print the replies.

    Builds a ``CPM9G`` model from a JSON config file and a vocab file, loads
    the checkpoint weights from ``--pt``, then calls ``model.inference`` on a
    fixed list of chat-formatted prompts and prints each reply's ``result``.

    Args:
        argv: Optional argument list for ``argparse``. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as the script always did.

    Raises:
        SystemExit: (via argparse) when a required path is missing or a
            numeric option is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="CPM-9G inference demo")
    # All three paths are needed to build the model. Marking them required makes
    # a missing flag fail with a usage message instead of a later opaque
    # ``open(None)`` / loader error.
    parser.add_argument("--pt", type=str, required=True,
                        help="the path of ckpt")
    parser.add_argument("--config", type=str, required=True,
                        help="the path of config file")
    parser.add_argument("--vocab", type=str, required=True,
                        help="the path of vocab file")
    # Previously hard-coded as 30 << 30 with a note to edit the source for
    # larger GPUs (e.g. 72 << 30 on an A100); now adjustable from the CLI.
    parser.add_argument(
        "--memory-limit-gb",
        type=int,
        default=30,
        help="memory limit passed to CPM9G as N << 30 (default: 30); "
             "raise it on GPUs with more memory, e.g. 72 on an A100",
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=4096,
        help="max_length passed to model.inference (default: 4096)",
    )
    args = parser.parse_args(argv)
    if args.memory_limit_gb <= 0:
        parser.error("--memory-limit-gb must be a positive integer")
    if args.max_length <= 0:
        parser.error("--max-length must be a positive integer")

    # Close the config file promptly and decode it explicitly as UTF-8 rather
    # than depending on the platform's default encoding.
    with open(args.config, "r", encoding="utf-8") as config_file:
        model_config = json.load(config_file)
    # NOTE(review): this flag is always forced on; its meaning is defined by
    # libcpm -- confirm before changing.
    model_config["new_vocab"] = True

    model = CPM9G(
        # NOTE(review): the positional "" and -1 are passed through unchanged;
        # their semantics are defined by libcpm -- confirm.
        "",
        args.vocab,
        -1,
        # N << 30 == N * 2**30; presumably a byte count (N GiB) -- confirm
        # against libcpm.
        memory_limit=args.memory_limit_gb << 30,
        model_config=model_config,
        # Weights are loaded explicitly from --pt right below, not here.
        load_model=False,
    )
    model.load_model_pt(args.pt)

    # Demo prompts, each wrapped in the <用户> (user) ... <AI> turn markers.
    prompts = [
        "<用户>马化腾是谁?<AI>",
        "<用户>你是谁?<AI>",
        "<用户>我要参加一个高性能会议,请帮我写一个致辞。<AI>",
    ]
    for prompt in prompts:
        res = model.inference(prompt, max_length=args.max_length)
        print(res["result"])
# Script entry point: run the demo only when executed directly, never on import.
if __name__ == "__main__":
    main()