This commit is contained in:
hiyouga 2024-01-08 21:42:25 +08:00
parent 0ed526cedf
commit 3ae735ffe8
2 changed files with 5 additions and 5 deletions

View File

@ -6,6 +6,7 @@ peft>=0.7.0
trl>=0.7.6
gradio>=3.38.0,<4.0.0
scipy
einops
sentencepiece
protobuf
tiktoken
@ -17,4 +18,3 @@ pydantic
fastapi
sse-starlette
matplotlib
einops

View File

@ -118,13 +118,13 @@ class Runner:
logging_steps=get("train.logging_steps"),
save_steps=get("train.save_steps"),
warmup_steps=get("train.warmup_steps"),
neftune_noise_alpha=get("train.neftune_alpha"),
neftune_noise_alpha=get("train.neftune_alpha") or None,
train_on_prompt=get("train.train_on_prompt"),
upcast_layernorm=get("train.upcast_layernorm"),
lora_rank=get("train.lora_rank"),
lora_dropout=get("train.lora_dropout"),
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
additional_target=get("train.additional_target") if get("train.additional_target") else None,
additional_target=get("train.additional_target") or None,
create_new_adapter=get("train.create_new_adapter"),
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
)
@ -164,7 +164,6 @@ class Runner:
args = dict(
stage="sft",
do_eval=True,
model_name_or_path=get("top.model_path"),
adapter_name_or_path=adapter_name_or_path,
cache_dir=user_config.get("cache_dir", None),
@ -187,8 +186,9 @@ class Runner:
)
if get("eval.predict"):
args.pop("do_eval", None)
args["do_predict"] = True
else:
args["do_eval"] = True
return args