import sys

import yaml

# Supported run types and the YAML template each one starts from.
_TEMPLATE_BY_RUN_TYPE = {
    "lora_sft": './results/lora_sft_template.yaml',
    "inference": './results/predict_template.yaml',
}

# Supported models: display name -> (model_name_or_path, prompt template).
_MODEL_CONFIG = {
    "9g-8B": ("/home/ma-user/models/8b_sft_model", "cpm"),
    "Baichuan2-7B": ("/home/ma-user/models/Baichuan2-7B-Base", "baichuan2"),
    "ChatGLM2-6B": ("/home/ma-user/models/chatglm2-6b", "chatglm2"),
    "Llama2-7B": ("/home/ma-user/models/llama-2-7b-ms", "llama2"),
    "Qwen-7B": ("/home/ma-user/models/qwen", "qwen"),
}


def main():
    """Fill a YAML run template from CLI args and write it to the output dir.

    Usage: <script> RUN_TYPE MODEL MAX_STEPS RUN_NAME OUTPUT_DIR

    RUN_TYPE   "lora_sft" or "inference" (selects the template file)
    MODEL      one of the keys in _MODEL_CONFIG
    MAX_STEPS  integer; only written into the config for lora_sft runs
    RUN_NAME   basename of the generated YAML file
    OUTPUT_DIR directory the generated YAML (and the run's output) goes to
    """
    # FIX: the original indexed sys.argv unchecked -> raw IndexError on
    # missing arguments.
    if len(sys.argv) < 6:
        print("usage: run_type model max_steps run_name output_dir")
        sys.exit(1)

    run_type, model, max_steps, run_name, output_dir = sys.argv[1:6]

    # FIX: the original left `yaml_file` unassigned for an unknown
    # run_type, crashing later with NameError; fail fast and clearly.
    yaml_file = _TEMPLATE_BY_RUN_TYPE.get(run_type)
    if yaml_file is None:
        print("ERROR: run_type not supported")
        sys.exit(1)

    if model not in _MODEL_CONFIG:
        print("ERROR: model not supported or model name wrong")
        # FIX: bare sys.exit() exited with status 0 (success) on error.
        sys.exit(1)
    model_name_or_path, template = _MODEL_CONFIG[model]

    # safe_load is sufficient for a plain config template and avoids
    # arbitrary Python object construction.
    with open(yaml_file, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    config['model_name_or_path'] = model_name_or_path
    config['template'] = template
    config['output_dir'] = output_dir
    if run_type == "lora_sft":
        config['max_steps'] = int(max_steps)

    out_path = f'{output_dir}/{run_name}.yaml'
    with open(out_path, 'w', encoding='utf-8') as f:
        yaml.dump(data=config, stream=f, allow_unicode=True)
    print(f"yaml file saved to {out_path}")


if __name__ == "__main__":
    main()