fix ChatGLM lm_head #494

This commit is contained in:
hiyouga 2023-08-14 14:14:48 +08:00
parent 20a29297b1
commit d019956808
3 changed files with 12 additions and 8 deletions

View File

@ -153,6 +153,10 @@ def load_model_and_tokenizer(
if "GenerationMixin" not in str(model.generate.__func__):
model.generate = MethodType(PreTrainedModel.generate, model)
# Fix LM head (for ChatGLM2)
if not hasattr(model, "lm_head"):
setattr(model, "lm_head", model.transformer.output_layer)
# Register auto class to save the custom code files.
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
config.__class__.register_for_auto_class()

View File

@ -32,11 +32,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
self._stored_metrics = defaultdict(lambda: defaultdict(list))
Trainer.__init__(self, **kwargs)
if not hasattr(self, "accelerator"):
raise AttributeError("Please update `transformers`.")
if ref_model is not None:
if hasattr(self, "accelerator"):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
else:
raise AttributeError("Please update `transformers`.")
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_forward(
self,

View File

@ -45,7 +45,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
with gr.Box():
output_box = gr.Markdown()
input_list = [
input_components = [
top_elems["lang"],
top_elems["model_name"],
top_elems["checkpoints"],
@ -62,13 +62,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
predict
]
output_list = [
output_components = [
output_box,
process_bar
]
cmd_preview_btn.click(runner.preview_eval, input_list, output_list)
start_btn.click(runner.run_eval, input_list, output_list)
cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
start_btn.click(runner.run_eval, input_components, output_components)
stop_btn.click(runner.set_abort, queue=False)
return dict(