LLaMA-Factory-310P3/prepare_yaml_file.py

54 lines
1.6 KiB
Python
Raw Normal View History

2024-09-05 11:28:19 +08:00
import os
import sys

import yaml
# YAML template used for each supported run type.
_TEMPLATE_FILES = {
    "lora_sft": './results/lora_sft_template.yaml',
    "inference": './results/predict_template.yaml',
}

# Supported model name -> (value for `model_name_or_path`, value for `template`).
_MODEL_CONFIGS = {
    "9g-8B": ("../../../models/sft_8b_v2", "default"),
    "Baichuan2-7B": ("../../../models/Baichuan2-7B-Base", "baichuan2"),
    "ChatGLM2-6B": ("../../../models/chatglm2-6b", "chatglm2"),
    "Llama2-7B": ("../../../models/llama-2-7b-ms", "llama2"),
    "Qwen-7B": ("../../../models/qwen", "qwen"),
}

_USAGE = ("usage: prepare_yaml_file.py "
          "<run_type> <model> <max_steps> <run_name> <output_dir>")


def main(argv=None):
    """Render a run-specific YAML config from a template and save it.

    Args:
        argv: Positional arguments
            ``[run_type, model, max_steps, run_name, output_dir]``.
            Defaults to ``sys.argv[1:]``. Extra arguments are ignored.

            * ``run_type`` -- ``"lora_sft"`` or ``"inference"``; selects the
              template file under ``./results/``.
            * ``model`` -- a key of ``_MODEL_CONFIGS``; sets the config's
              ``model_name_or_path`` and ``template`` keys.
            * ``max_steps`` -- integer written to ``max_steps``; only used
              (and only validated) for ``lora_sft``.
            * ``run_name`` -- base name of the generated ``.yaml`` file.
            * ``output_dir`` -- written to the config's ``output_dir`` key
              and used as the directory the generated file is saved in
              (created if it does not exist).

    Raises:
        SystemExit: with status 1 when arguments are missing, the run type
            or model is unsupported, or ``max_steps`` is not an integer
            for ``lora_sft``.
    """
    if argv is None:
        argv = sys.argv[1:]
    if len(argv) < 5:
        print(_USAGE, file=sys.stderr)
        sys.exit(1)
    run_type, model, max_steps, run_name, output_dir = argv[:5]

    # Reject unknown run types up front; otherwise there is no template to load.
    yaml_file = _TEMPLATE_FILES.get(run_type)
    if yaml_file is None:
        print(f"ERROR: run type not supported: {run_type}", file=sys.stderr)
        sys.exit(1)

    if model not in _MODEL_CONFIGS:
        print("ERROR: model not supported or model name wrong", file=sys.stderr)
        # Non-zero status so calling shell scripts can detect the failure.
        sys.exit(1)
    model_name_or_path, template = _MODEL_CONFIGS[model]

    if run_type == "lora_sft":
        # Validate before reading or writing any files so bad input fails fast.
        try:
            max_steps = int(max_steps)
        except ValueError:
            print(f"ERROR: max_steps must be an integer, got {max_steps!r}",
                  file=sys.stderr)
            sys.exit(1)

    # Templates are local project files; FullLoader kept as in the original.
    with open(yaml_file, 'r', encoding='utf-8') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    config['model_name_or_path'] = model_name_or_path
    config['template'] = template
    config['output_dir'] = output_dir
    if run_type == "lora_sft":
        config['max_steps'] = max_steps

    # Make sure the destination exists; os.makedirs('') would raise, so skip it.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f'{run_name}.yaml')
    with open(out_path, 'w', encoding='utf-8') as f:
        yaml.dump(data=config, stream=f, allow_unicode=True)
    print(f"yaml file saved to {out_path}")
if __name__ == "__main__":
    # Run only when executed as a script, never on import. main() returns
    # None, so SystemExit(None) ends the process with status 0, exactly
    # like falling off the end of the module.
    raise SystemExit(main())