diff --git a/prepare_yaml_file.py b/prepare_yaml_file.py index 10931229..4c864328 100644 --- a/prepare_yaml_file.py +++ b/prepare_yaml_file.py @@ -12,7 +12,9 @@ def main(): yaml_file = './results/lora_sft_template.yaml' elif run_type == "inference": yaml_file = './results/predict_template.yaml' - + + model_name_or_path = "" + template = "" if model == "9g-8B": model_name_or_path = "../../models/sft_8b_v2" template = ""