fix rm server

This commit is contained in:
hiyouga 2024-01-03 15:30:46 +08:00
parent 3014e3c189
commit 55021097d5
3 changed files with 4 additions and 2 deletions

View File

@ -87,11 +87,11 @@ def load_model_and_tokenizer(
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
patch_model(model, tokenizer, model_args, is_trainable)
register_autoclass(config, model, tokenizer)
@ -113,6 +113,7 @@ def load_model_and_tokenizer(
if not is_trainable:
model.requires_grad_(False)
model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
model.eval()
else:
model.train()

View File

@ -276,5 +276,6 @@ def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
setattr(model, "_keys_to_ignore_on_save", ignore_modules)
setattr(model, "_no_split_modules", getattr(model.pretrained_model, "_no_split_modules", None))
setattr(model, "tie_weights", MethodType(tie_weights, model))
setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))

View File

@ -27,7 +27,7 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
from accelerate import dispatch_model
from accelerate.utils import infer_auto_device_map, get_balanced_memory
if model._no_split_modules is None:
if getattr(model, "_no_split_modules", None) is None:
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}