forked from jiuyuan/CPM-9G-8B
需要修改补充的地方标记了TODO
This commit is contained in:
parent
01e0f316e2
commit
78fdcfc0d9
|
@ -39,10 +39,13 @@
|
|||
|
||||
# 九格大模型使用文档
|
||||
## 目录
|
||||
|
||||
- [模型推理](https://www.osredm.com/jiuyuan/CPM-9G-8B/tree/master/quick_start_clean/readmes/README_ALL.md?tab=readme-ov-file#模型推理)
|
||||
<!-- - [仓库目录结构](#仓库目录结构) -->
|
||||
- [九格大模型使用文档](#九格大模型使用文档)
|
||||
- [目录](#目录)
|
||||
- [环境配置](#环境配置)
|
||||
- [开源模型](#开源模型)
|
||||
- [数据处理流程](#数据处理流程)
|
||||
- [单个数据集处理](#单个数据集处理)
|
||||
- [多个数据集混合](#多个数据集混合)
|
||||
|
@ -50,6 +53,7 @@
|
|||
- [多机训练](#多机训练)
|
||||
- [参数详细介绍](#参数详细介绍)
|
||||
- [查看训练情况](#查看训练情况)
|
||||
- [模型推理 TODO:需要补充](#模型推理-todo需要补充)
|
||||
- [常见问题](#常见问题)
|
||||
|
||||
<!-- ## 仓库目录结构
|
||||
|
@ -101,8 +105,21 @@ pip install sentencepiece
|
|||
pip install protobuf==3.20.0 #protobuf版本过高时无法适配tensorboard
|
||||
pip install tensorboard
|
||||
pip install tensorboardX
|
||||
|
||||
|
||||
8.安装vllm
|
||||
我们提供了两种vllm的安装方式
|
||||
请直接安装/quick_start_clean/tools/vllm-0.5.0.dev0+cu122-cp310-cp310-linux_x86_64.whl
|
||||
如果不成功,请通过源码安装vllm,即通过/quick_start_clean/tools/vllm-0.5.0.dev0.tar中的setup.py安装
|
||||
```
|
||||
|
||||
## 开源模型
|
||||
1. 8B的百亿SFT模型,v2版本是在v1基础上精度和对话能力的优化模型,下载链接:
|
||||
[8b_sft_model_v1](https://qy-obs-6d58.obs.cn-north-4.myhuaweicloud.com/checkpoints-epoch-1.tar.gz), [8b_sft_model_v2](https://qy-obs-6d58.obs.cn-north-4.myhuaweicloud.com/sft_8b_v2.zip)
|
||||
|
||||
2. 端侧2B模型,下载链接:
|
||||
[2b—sft-model] # TODO
|
||||
|
||||
## 数据处理流程
|
||||
### 单个数据集处理
|
||||
预训练语料为无监督形式,不需要区分问题与答案,但需要将数据转为index后进行模型训练。我们拿到的原始数据可能是两种形式:
|
||||
|
@ -374,6 +391,40 @@ tensorboard –-logdir /apps/fm9g_2b/data/tensorboard/2b_0701 #存放.events文
|
|||
TypeError: MessageToJson() got an unexpected keyword argument 'including_default_value_fields'
|
||||
```
|
||||
|
||||
## 模型推理 TODO:需要补充
|
||||
```python
|
||||
import os
|
||||
from libcpm import CPM9G
|
||||
import argparse, json, os
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pt", type=str, help="the path of ckpt")
|
||||
parser.add_argument("--config", type=str, help="the path of config file")
|
||||
parser.add_argument("--vocab", type=str, help="the path of vocab file")
|
||||
args = parser.parse_args()
|
||||
model_config = json.load(open(args.config, 'r'))
|
||||
model_config["new_vocab"] = True
|
||||
model = CPM9G(
|
||||
"",
|
||||
args.vocab,
|
||||
-1,
|
||||
memory_limit = 30 << 30,#memory limit 左边的参数根据gpu的显存设置,如果是A100,可以设置为 72 << 30,这样的话就可以用到更多的显存
|
||||
model_config=model_config,
|
||||
load_model=False,
|
||||
)
|
||||
model.load_model_pt(args.pt)
|
||||
datas = [
|
||||
'''<用户>马化腾是谁?<AI>''',
|
||||
'''<用户>你是谁?<AI>''',
|
||||
'''<用户>我要参加一个高性能会议,请帮我写一个致辞。<AI>''',
|
||||
]
|
||||
for data in datas:
|
||||
res = model.inference(data, max_length=4096)
|
||||
print(res['result'])
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
## 常见问题
|
||||
1. Conda安装pytorch时卡在solving environment:网络问题。
|
||||
解决方法:
|
||||
|
|
Loading…
Reference in New Issue