fix webui

This commit is contained in:
hiyouga 2023-10-13 16:27:59 +08:00
parent cb42676694
commit b240b6792f
2 changed files with 6 additions and 3 deletions

View File

@ -65,6 +65,8 @@ def get_dataset(
max_samples_temp = min(len(dataset), max_samples)
dataset = dataset.select(range(max_samples_temp))
# TODO: adapt to the sharegpt format
for column_name in ["prompt", "query", "response", "history"]: # align datasets
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)

View File

@ -138,13 +138,14 @@ class Runner:
lora_rank=lora_rank,
lora_dropout=lora_dropout,
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
resume_lora_training=(
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
),
resume_lora_training=resume_lora_training,
output_dir=output_dir
)
args[compute_type] = True
if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
args["resume_lora_training"] = False
if args["quantization_bit"] is not None:
args["upcast_layernorm"] = True