forked from p04798526/LLaMA-Factory-Mirror
fix ChatGLM lm_head #494
This commit is contained in:
parent
20a29297b1
commit
d019956808
|
@ -153,6 +153,10 @@ def load_model_and_tokenizer(
|
|||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2)
|
||||
if not hasattr(model, "lm_head"):
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
|
|
|
@ -32,11 +32,11 @@ class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
|
|||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
if ref_model is not None:
|
||||
if hasattr(self, "accelerator"):
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
else:
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
|
|
|
@ -45,7 +45,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
|
|||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
input_list = [
|
||||
input_components = [
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
|
@ -62,13 +62,13 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
|
|||
predict
|
||||
]
|
||||
|
||||
output_list = [
|
||||
output_components = [
|
||||
output_box,
|
||||
process_bar
|
||||
]
|
||||
|
||||
cmd_preview_btn.click(runner.preview_eval, input_list, output_list)
|
||||
start_btn.click(runner.run_eval, input_list, output_list)
|
||||
cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
|
||||
start_btn.click(runner.run_eval, input_components, output_components)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
return dict(
|
||||
|
|
Loading…
Reference in New Issue