54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
import sys
|
|
import yaml
|
|
|
|
# Supported run types mapped to the YAML template file each one starts from.
RUN_TYPE_TEMPLATES = {
    "lora_sft": './results/lora_sft_template.yaml',
    "inference": './results/predict_template.yaml',
}

# Supported model names mapped to (model_name_or_path, value of the config's
# `template` key).
MODEL_SETTINGS = {
    "9g-8B": ("/home/ma-user/models/8b_sft_model", "cpm"),
    "Baichuan2-7B": ("/home/ma-user/models/Baichuan2-7B-Base", "baichuan2"),
    "ChatGLM2-6B": ("/home/ma-user/models/chatglm2-6b", "chatglm2"),
    "Llama2-7B": ("/home/ma-user/models/llama-2-7b-ms", "llama2"),
    "Qwen-7B": ("/home/ma-user/models/qwen", "qwen"),
}


def _fail(message):
    """Print *message* to stderr and terminate with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def main(argv=None):
    """Render a run-specific YAML config from a template and save it.

    Positional arguments (taken from ``sys.argv[1:]`` by default)::

        RUN_TYPE MODEL MAX_STEPS RUN_NAME OUTPUT_DIR

    * RUN_TYPE   -- a key of ``RUN_TYPE_TEMPLATES`` (``lora_sft`` or
      ``inference``); selects which template file is loaded.
    * MODEL      -- a key of ``MODEL_SETTINGS``; sets ``model_name_or_path``
      and ``template`` in the config.
    * MAX_STEPS  -- written to ``max_steps`` as an int for ``lora_sft`` only;
      ignored (and not validated) for ``inference``.
    * RUN_NAME   -- base name of the generated ``<RUN_NAME>.yaml`` file.
    * OUTPUT_DIR -- stored as ``output_dir`` in the config and used as the
      directory the generated file is written into (created if missing).

    Args:
        argv: Argument list WITHOUT the program name. ``None`` means
            ``sys.argv[1:]``, so existing ``main()`` calls behave as before.
            Extra trailing arguments are ignored.

    Exits with status 1 (message on stderr) when fewer than five arguments
    are given, the run type or model is unknown, MAX_STEPS is not an integer
    for ``lora_sft``, or the template does not parse to a YAML mapping.
    """
    import os  # only needed to create the output directory

    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 5:
        _fail("ERROR: expected 5 arguments: "
              "RUN_TYPE MODEL MAX_STEPS RUN_NAME OUTPUT_DIR")
    run_type, model, max_steps, run_name, output_dir = args[:5]

    # Validate every argument before touching the filesystem, so a bad
    # invocation fails fast and never leaves a partial result behind.
    yaml_file = RUN_TYPE_TEMPLATES.get(run_type)
    if yaml_file is None:
        _fail(f"ERROR: run type {run_type!r} not supported "
              f"(expected one of: {', '.join(RUN_TYPE_TEMPLATES)})")

    if model not in MODEL_SETTINGS:
        _fail("ERROR: model not supported or model name wrong")
    model_name_or_path, template = MODEL_SETTINGS[model]

    steps = None
    if run_type == "lora_sft":
        try:
            steps = int(max_steps)
        except ValueError:
            _fail(f"ERROR: max_steps must be an integer, got {max_steps!r}")

    with open(yaml_file, 'r', encoding='utf-8') as f:
        # FullLoader is kept for compatibility with the original script; the
        # templates are local files. Prefer yaml.safe_load if they could ever
        # come from an untrusted source.
        config = yaml.load(f, Loader=yaml.FullLoader)
    if not isinstance(config, dict):
        # An empty template loads as None; item assignment below would fail.
        _fail(f"ERROR: template {yaml_file} does not contain a YAML mapping")

    config['model_name_or_path'] = model_name_or_path
    config['template'] = template
    config['output_dir'] = output_dir
    if steps is not None:
        config['max_steps'] = steps

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    out_path = f'{output_dir}/{run_name}.yaml'
    with open(out_path, 'w', encoding='utf-8') as f:
        yaml.dump(data=config, stream=f, allow_unicode=True)

    print(f"yaml file saved to {out_path}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build and save the config from the command line.
    # main() returns None on success, so the process exits with status 0.
    sys.exit(main())