diff --git a/batch_run.sh b/batch_run.sh index 1f3329e1..fd9d5f59 100644 --- a/batch_run.sh +++ b/batch_run.sh @@ -1 +1 @@ -bash run_once.sh lora_sft Baichuan-7B 4 50 +bash run_once.sh lora_sft Qwen-7B 4 50 diff --git a/prepare_yaml_file.py b/prepare_yaml_file.py index 9ea58734..69cb300e 100644 --- a/prepare_yaml_file.py +++ b/prepare_yaml_file.py @@ -19,8 +19,8 @@ def main(): model_name_or_path = "../../models/sft_8b_v2" template = "" elif model == "Baichuan2-7B": - model_name_or_path = "../../models/Baichuan-7B" - template = "baichuan" + model_name_or_path = "../../models/Baichuan2-7B" + template = "baichuan2" elif model == "ChatGLM2-6B": model_name_or_path = "../../models/chatglm2-6b" template = "chatglm2" @@ -30,11 +30,15 @@ def main(): elif model == "Qwen-7B": model_name_or_path = "../../models/Qwen-7B" template = "qwen" + else: + print("ERROR: model not supported.") + sys.exit() config = None with open(yaml_file, 'r', encoding='utf-8') as f: config = yaml.load(f.read(), Loader=yaml.FullLoader) - + + config['model_name_or_path'] = model_name_or_path config['template'] = template config['output_dir'] = output_dir