fix webui

This commit is contained in:
hiyouga 2023-08-03 12:43:12 +08:00
parent 08f180e788
commit e23a3a366c
3 changed files with 3 additions and 2 deletions

View File

@ -128,7 +128,7 @@ def load_model_and_tokenizer(
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
if stage == "rm" or stage == "ppo": # add value head if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model) model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging() reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model

View File

@ -10,7 +10,7 @@ from llmtuner.webui.locales import ALERTS
class WebChatModel(ChatModel): class WebChatModel(ChatModel):
def __init__(self, args: Optional[Dict[str, Any]]) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model = None self.model = None
self.tokenizer = None self.tokenizer = None
self.generating_args = GeneratingArguments() self.generating_args = GeneratingArguments()

View File

@ -1,5 +1,6 @@
from llmtuner import run_exp from llmtuner import run_exp
def main(): def main():
run_exp() run_exp()