Support training from scratch #4033 #4075

This commit is contained in:
hiyouga 2024-06-06 02:43:19 +08:00
parent 946f601136
commit a12a506c3d
2 changed files with 6 additions and 0 deletions

View File

@ -101,6 +101,10 @@ class ModelArguments:
default=False,
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
)
# Opt-in switch (off by default) for training from scratch. When set,
# load_model builds the causal LM via AutoModelForCausalLM.from_config(config)
# instead of from_pretrained(**init_kwargs). The visual_inputs branch is
# checked before this one in load_model, so it takes precedence.
train_from_scratch: bool = field(
default=False,
metadata={"help": "Whether or not to randomly initialize the model weights."},
)
infer_backend: Literal["huggingface", "vllm"] = field(
default="huggingface",
metadata={"help": "Backend engine used at inference."},

View File

@ -131,6 +131,8 @@ def load_model(
model = load_mod_pretrained_model(**init_kwargs)
elif model_args.visual_inputs:
model = AutoModelForVision2Seq.from_pretrained(**init_kwargs)
elif model_args.train_from_scratch:
model = AutoModelForCausalLM.from_config(config)
else:
model = AutoModelForCausalLM.from_pretrained(**init_kwargs)