forked from jiuyuan/CPM-9G-8B
add: add new sh
parent 7149c4298d
commit da66a01213
@@ -0,0 +1,32 @@
import argparse
import json

from libcpm import CPM9G


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pt", type=str, help="the path of ckpt")
    parser.add_argument("--config", type=str, help="the path of config file")
    parser.add_argument("--vocab", type=str, help="the path of vocab file")
    args = parser.parse_args()

    model_config = json.load(open(args.config, 'r'))
    model_config["new_vocab"] = True

    model = CPM9G(
        "",
        args.vocab,
        -1,
        memory_limit=30 << 30,  # 30 << 30 bytes = 30 GiB; set according to the GPU's memory. On an A100 this can be raised to 72 << 30 to use more of the available memory.
        model_config=model_config,
        load_model=False,
    )
    model.load_model_pt(args.pt)

    # Sample prompts in the model's <用户>/<AI> chat format (in Chinese):
    # "Who is Ma Huateng?", "Who are you?",
    # "I am attending a high-performance computing conference; please write an opening address for me."
    datas = [
        '''<用户>马化腾是谁?<AI>''',
        '''<用户>你是谁?<AI>''',
        '''<用户>我要参加一个高性能会议,请帮我写一个致辞。<AI>''',
    ]
    for data in datas:
        res = model.inference(data, max_length=4096)
        print(res['result'])


if __name__ == "__main__":
    main()
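A minimal invocation sketch for the inference script above, assuming it is saved as inference.py; the filename and all paths are placeholders, not part of this commit:

python inference.py \
    --pt /path/to/cpm9g_checkpoint.pt \
    --config /path/to/config.json \
    --vocab /path/to/vocab.txt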
@@ -0,0 +1,57 @@
#! /bin/bash

export MASTER_ADDR="localhost"
export MASTER_PORT=12370

CPM_PATH="/workspace/repo/CPM-9G-8B/9G-Train"

NO=0
GPU_NUM=multi
MAX_STEP=500
EXP_PATH=/workspace/repo/CPM-9G-8B/results/lora/${GPU_NUM}/${MAX_STEP}/${NO}/models/
MODEL_NAME="9g-sft"
TB_PATH="/workspace/repo/CPM-9G-8B/results/lora/${GPU_NUM}/${MAX_STEP}/${NO}/logs/"

OPTS=""
OPTS+=" --vocab /v2/sft_8b_v2/vocab.txt"
OPTS+=" --model-config /v2/sft_8b_v2/config.json"

OPTS+=" --train-iters 500"
OPTS+=" --lr_decay_iters 500"
OPTS+=" --inspect-iters 500"
OPTS+=" --warmup-iters 20"

OPTS+=" --lr-decay-style cosine"
OPTS+=" --weight-decay 0.01"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --max-loss-scale 33554432"
OPTS+=" --min-loss-scale 1"
OPTS+=" --loss-scale-steps 32"

OPTS+=" --offload"
OPTS+=" --batch-size 2"
OPTS+=" --max-length 4096"
OPTS+=" --lr 3e-4"
OPTS+=" --start-step 0"
OPTS+=" --epoch 2"
OPTS+=" --load /v2/sft_8b_v2/cpm_live_8b-1500-float16.pt"
OPTS+=" --dataset /workspace/repo/CPM-9G-8B/dataset_bin"
# TODO: on the Qiyuan (启元) machines, these /data paths need to be changed to paths under /home
OPTS+=" --save ${EXP_PATH}/checkpoints"
OPTS+=" --save-name ${MODEL_NAME}"
OPTS+=" --tensorboard ${TB_PATH}"

OPTS+=" --delta-tuning"
OPTS+=" --delta-type lora"
OPTS+=" --lora-r 64" # commonly used LoRA rank
OPTS+=" --lora-dropout 0.05"
OPTS+=" --lora-alpha 64" # commonly used LoRA alpha value
OPTS+=" --lora-layer project_q project_v project_k w_0 w_1 w_out"

OPTS+=" $@"

CMD="torchrun --nnodes=1 --nproc_per_node=7 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g_delta.py ${OPTS}"

echo "${CMD}"
$CMD
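Because the script appends "$@" to OPTS, any extra flags given on the command line are simply forwarded to sft_cpm9g_delta.py after the defaults. A usage sketch, assuming the script above is saved as sft_cpm9g_lora.sh (a hypothetical name, not part of this commit):

bash sft_cpm9g_lora.sh --lr 1e-4 --batch-size 1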