import argparse
import json


def load_model_config(path):
    """Parse the model config JSON at *path* and force the new vocab format.

    Returns the parsed config dict with ``new_vocab`` set to ``True``, which
    the CPM9G loader below requires for this checkpoint family.
    """
    # Use `with` so the handle is closed; the original leaked it via
    # json.load(open(...)).
    with open(path, "r") as fh:
        config = json.load(fh)
    config["new_vocab"] = True
    return config


def main():
    """Load a CPM-9G checkpoint and print completions for sample prompts.

    Command-line arguments:
        --pt:     path of the model checkpoint (.pt) file
        --config: path of the model config JSON file
        --vocab:  path of the vocabulary file
    """
    # Imported lazily: libcpm is the project's GPU inference runtime and is
    # not needed just to import this module (e.g. to reuse the helper above).
    from libcpm import CPM9G

    parser = argparse.ArgumentParser()
    parser.add_argument("--pt", type=str, help="the path of ckpt")
    parser.add_argument("--config", type=str, help="the path of config file")
    parser.add_argument("--vocab", type=str, help="the path of vocab file")
    args = parser.parse_args()

    model_config = load_model_config(args.config)
    model = CPM9G(
        "",
        args.vocab,
        -1,
        # Memory limit in bytes (30 GiB). On an A100 this can be raised,
        # e.g. to 72 << 30, to use more of the available GPU memory.
        memory_limit=30 << 30,
        model_config=model_config,
        load_model=False,
    )
    model.load_model_pt(args.pt)

    datas = [
        '''<用户>马化腾是谁?''',
        '''<用户>你是谁?''',
        '''<用户>我要参加一个高性能会议,请帮我写一个致辞。''',
    ]
    for data in datas:
        res = model.inference(data, max_length=4096)
        print(res['result'])


if __name__ == "__main__":
    main()
#!/bin/bash
# Single-node LoRA (delta-tuning) SFT launcher for CPM-9G-8B via torchrun.
# Any extra arguments passed to this script are forwarded to the trainer.
set -e  # abort the launch if any setup step fails

export MASTER_ADDR="localhost"
export MASTER_PORT=12370

CPM_PATH="/workspace/repo/CPM-9G-8B/9G-Train"

# Experiment identifiers; they only shape the results directory layout.
NO=0
GPU_NUM=multi
MAX_STEP=500
EXP_PATH=/workspace/repo/CPM-9G-8B/results/lora/${GPU_NUM}/${MAX_STEP}/${NO}/models/
MODEL_NAME="9g-sft"
TB_PATH="/workspace/repo/CPM-9G-8B/results/lora/${GPU_NUM}/${MAX_STEP}/${NO}/logs/"

# Create output directories up front so the trainer does not fail on save.
mkdir -p "${EXP_PATH}/checkpoints" "${TB_PATH}"

OPTS=""
# Tokenizer / model definition.
OPTS+=" --vocab /v2/sft_8b_v2/vocab.txt"
OPTS+=" --model-config /v2/sft_8b_v2/config.json"

# Schedule lengths.
OPTS+=" --train-iters 500"
OPTS+=" --lr_decay_iters 500"
OPTS+=" --inspect-iters 500"
OPTS+=" --warmup-iters 20"

# Optimizer / loss-scaling settings.
OPTS+=" --lr-decay-style cosine"
OPTS+=" --weight-decay 0.01"
OPTS+=" --clip-grad 1.0"
OPTS+=" --loss-scale 1048576"
OPTS+=" --max-loss-scale 33554432"
OPTS+=" --min-loss-scale 1"
OPTS+=" --loss-scale-steps 32"

# Runtime, data and checkpoint paths.
OPTS+=" --offload"
OPTS+=" --batch-size 2"
OPTS+=" --max-length 4096"
OPTS+=" --lr 3e-4"
OPTS+=" --start-step 0"
OPTS+=" --epoch 2"
OPTS+=" --load /v2/sft_8b_v2/cpm_live_8b-1500-float16.pt"
OPTS+=" --dataset /workspace/repo/CPM-9G-8B/dataset_bin"
# TODO: on the QiYuan machines these /data paths must be changed to /home paths.
OPTS+=" --save ${EXP_PATH}/checkpoints"
OPTS+=" --save-name ${MODEL_NAME}"
OPTS+=" --tensorboard ${TB_PATH}"

# LoRA (delta-tuning) configuration.
OPTS+=" --delta-tuning"
OPTS+=" --delta-type lora"
OPTS+=" --lora-r 64"          # commonly used LoRA rank
OPTS+=" --lora-dropout 0.05"
OPTS+=" --lora-alpha 64"      # commonly used LoRA alpha
OPTS+=" --lora-layer project_q project_v project_k w_0 w_1 w_out"

# Forward extra CLI flags. In a scalar string assignment $* joins the
# arguments with spaces — the correct spelling of what `$@` did here.
OPTS+=" $*"

CMD="torchrun --nnodes=1 --nproc_per_node=7 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} ${CPM_PATH}/apps/cpm9g/sft_cpm9g_delta.py ${OPTS}"

echo "${CMD}"
# Intentionally unquoted: the command string must undergo word splitting.
$CMD