需要修改补充的地方标记了TODO

This commit is contained in:
chaoyu@qiyuanlab.com 2024-07-16 18:25:17 +08:00
parent 01e0f316e2
commit 78fdcfc0d9
1 changed files with 51 additions and 0 deletions

View File

@ -39,10 +39,13 @@
# 九格大模型使用文档
## 目录
- [模型推理](https://www.osredm.com/jiuyuan/CPM-9G-8B/tree/master/quick_start_clean/readmes/README_ALL.md?tab=readme-ov-file#模型推理)
<!-- - [仓库目录结构](#仓库目录结构) -->
- [九格大模型使用文档](#九格大模型使用文档)
- [目录](#目录)
- [环境配置](#环境配置)
- [开源模型](#开源模型)
- [数据处理流程](#数据处理流程)
- [单个数据集处理](#单个数据集处理)
- [多个数据集混合](#多个数据集混合)
@ -50,6 +53,7 @@
- [多机训练](#多机训练)
- [参数详细介绍](#参数详细介绍)
- [查看训练情况](#查看训练情况)
- [模型推理 TODO:需要补充](#模型推理-todo需要补充)
- [常见问题](#常见问题)
<!-- ## 仓库目录结构
@ -101,8 +105,21 @@ pip install sentencepiece
pip install protobuf==3.20.0 #protobuf版本过高时无法适配tensorboard
pip install tensorboard
pip install tensorboardX
8.安装vllm
我们提供了两种vllm的安装方式
请直接安装/quick_start_clean/tools/vllm-0.5.0.dev0+cu122-cp310-cp310-linux_x86_64.whl
如果不成功请通过源码安装vllm即通过/quick_start_clean/tools/vllm-0.5.0.dev0.tar中的setup.py安装
```
## 开源模型
1. 8B的百亿SFT模型v2版本是在v1基础上精度和对话能力的优化模型下载链接
[8b_sft_model_v1](https://qy-obs-6d58.obs.cn-north-4.myhuaweicloud.com/checkpoints-epoch-1.tar.gz), [8b_sft_model_v2](https://qy-obs-6d58.obs.cn-north-4.myhuaweicloud.com/sft_8b_v2.zip)
2. 端侧2B模型下载链接
[2b—sft-model] # TODO
## 数据处理流程
### 单个数据集处理
预训练语料为无监督形式不需要区分问题与答案但需要将数据转为index后进行模型训练。我们拿到的原始数据可能是两种形式
@ -374,6 +391,40 @@ tensorboard -logdir /apps/fm9g_2b/data/tensorboard/2b_0701 #存放.events文
TypeError: MessageToJson() got an unexpected keyword argument 'including_default_value_fields'
```
## 模型推理 TODO:需要补充
```python
import os
from libcpm import CPM9G
import argparse, json, os
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--pt", type=str, help="the path of ckpt")
parser.add_argument("--config", type=str, help="the path of config file")
parser.add_argument("--vocab", type=str, help="the path of vocab file")
args = parser.parse_args()
model_config = json.load(open(args.config, 'r'))
model_config["new_vocab"] = True
model = CPM9G(
"",
args.vocab,
-1,
memory_limit = 30 << 30,#memory limit 左边的参数根据gpu的显存设置如果是A100可以设置为 72 << 30这样的话就可以用到更多的显存
model_config=model_config,
load_model=False,
)
model.load_model_pt(args.pt)
datas = [
'''<用户>马化腾是谁?<AI>''',
'''<用户>你是谁?<AI>''',
'''<用户>我要参加一个高性能会议,请帮我写一个致辞。<AI>''',
]
for data in datas:
res = model.inference(data, max_length=4096)
print(res['result'])
if __name__ == "__main__":
main()
```
## 常见问题
1. Conda安装pytorch时卡在solving environment网络问题。
解决方法: