fix: add not supported model err msg

This commit is contained in:
wql 2024-09-05 13:21:38 +08:00
parent 8162a54aa5
commit 64044380bd
2 changed files with 8 additions and 4 deletions

View File

@ -1 +1 @@
bash run_once.sh lora_sft Baichuan-7B 4 50 bash run_once.sh lora_sft Qwen-7B 4 50

View File

@ -19,8 +19,8 @@ def main():
model_name_or_path = "../../models/sft_8b_v2" model_name_or_path = "../../models/sft_8b_v2"
template = "" template = ""
elif model == "Baichuan2-7B": elif model == "Baichuan2-7B":
model_name_or_path = "../../models/Baichuan-7B" model_name_or_path = "../../models/Baichuan2-7B"
template = "baichuan" template = "baichuan2"
elif model == "ChatGLM2-6B": elif model == "ChatGLM2-6B":
model_name_or_path = "../../models/chatglm2-6b" model_name_or_path = "../../models/chatglm2-6b"
template = "chatglm2" template = "chatglm2"
@ -30,11 +30,15 @@ def main():
elif model == "Qwen-7B": elif model == "Qwen-7B":
model_name_or_path = "../../models/Qwen-7B" model_name_or_path = "../../models/Qwen-7B"
template = "qwen" template = "qwen"
else:
print("ERROR: model not supported.")
sys.exit()
config = None config = None
with open(yaml_file, 'r', encoding='utf-8') as f: with open(yaml_file, 'r', encoding='utf-8') as f:
config = yaml.load(f.read(), Loader=yaml.FullLoader) config = yaml.load(f.read(), Loader=yaml.FullLoader)
config['model_name_or_path'] = model_name_or_path config['model_name_or_path'] = model_name_or_path
config['template'] = template config['template'] = template
config['output_dir'] = output_dir config['output_dir'] = output_dir