diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index 50892a54..98957bc1 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
-from ..extras.packages import is_fastapi_available, is_pillow_available
+from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -29,10 +29,13 @@ if is_fastapi_available():
 
 
 if is_pillow_available():
-    import requests
     from PIL import Image
 
 
+if is_requests_available():
+    import requests
+
+
 if TYPE_CHECKING:
     from numpy.typing import NDArray
 
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 2bbe6a06..6f1da34e 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -187,13 +187,7 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()
 
         with torch.no_grad(), ref_context:
-            (
-                reference_chosen_logps,
-                reference_rejected_logps,
-                _,
-                _,
-                _,
-            ) = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
 
         return reference_chosen_logps, reference_rejected_logps
 
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index f29945f5..03cad5a7 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -146,15 +146,8 @@ class CustomKTOTrainer(KTOTrainer):
         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")
 
-        chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
-        rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
-
-        chosen_logps = target_logps[chosen_idx, ...]
-        rejected_logps = target_logps[rejected_idx, ...]
-
-        chosen_logits = target_logits[chosen_idx, ...]
-        rejected_logits = target_logits[rejected_idx, ...]
-
+        chosen_logps, rejected_logps = target_logps[batch["kto_tags"]], target_logps[~batch["kto_tags"]]
+        chosen_logits, rejected_logits = target_logits[batch["kto_tags"]], target_logits[~batch["kto_tags"]]
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
 
     def compute_reference_log_probs(
diff --git a/src/llamafactory/train/ppo/utils.py b/src/llamafactory/train/ppo/utils.py
index e6bdb89c..e5025581 100644
--- a/src/llamafactory/train/ppo/utils.py
+++ b/src/llamafactory/train/ppo/utils.py
@@ -8,13 +8,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from ...extras.packages import is_requests_available
 
 
+if is_requests_available():
+    import requests
+
+
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead
 
-if is_requests_available():
-    import requests
-
 
 def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
     headers = {"Content-Type": "application/json"}