diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 90fe1b81..1ef99d9f 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
             messages[0]["content"] = "<image>" + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or generating_args["default_system"]
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
         )
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index ba0cc1b3..2e8ecd0c 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
             messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index e792c003..0ee17d1a 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 
 @dataclass
@@ -46,6 +46,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Default system message to use in chat completion."},
+    )
 
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
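
For context, the fallback added in both engines means an explicitly passed `system` string always wins, and `default_system` is only consulted when the caller supplies none. Below is a minimal, self-contained sketch of that semantics; the trimmed-down `GeneratingArguments` here is illustrative and omits the other fields from the real class:

```python
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional


@dataclass
class GeneratingArguments:
    # Only the field added by this diff; the real class has many more.
    default_system: Optional[str] = field(
        default=None,
        metadata={"help": "Default system message to use in chat completion."},
    )

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


generating_args = GeneratingArguments(default_system="You are a helpful assistant.").to_dict()

# Caller passed no system message: the default is used.
system = None
system = system or generating_args["default_system"]
assert system == "You are a helpful assistant."

# Caller-supplied system message takes precedence over the default.
system = "Answer in French."
system = system or generating_args["default_system"]
assert system == "Answer in French."
```

One subtlety of the `system or ...` pattern: Python's `or` treats the empty string as falsy, so a caller who explicitly passes `system=""` will also receive the default rather than an empty system prompt.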