fix webui
This commit is contained in:
parent
cb42676694
commit
b240b6792f
|
@ -65,6 +65,8 @@ def get_dataset(
|
|||
max_samples_temp = min(len(dataset), max_samples)
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
|
||||
# TODO: adapt to the sharegpt format
|
||||
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
|
|
@ -138,13 +138,14 @@ class Runner:
|
|||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
resume_lora_training=(
|
||||
False if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] else resume_lora_training
|
||||
),
|
||||
resume_lora_training=resume_lora_training,
|
||||
output_dir=output_dir
|
||||
)
|
||||
args[compute_type] = True
|
||||
|
||||
if TRAINING_STAGES[training_stage] in ["rm", "ppo", "dpo"] and args["quantization_bit"] is None:
|
||||
args["resume_lora_training"] = False
|
||||
|
||||
if args["quantization_bit"] is not None:
|
||||
args["upcast_layernorm"] = True
|
||||
|
||||
|
|
Loading…
Reference in New Issue