update trainers

hiyouga 2024-06-06 18:45:49 +08:00
parent 67aa78cde0
commit fad2591e31
4 changed files with 12 additions and 21 deletions


@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
 from ..data import Role as DataRole
 from ..extras.logging import get_logger
-from ..extras.packages import is_fastapi_available, is_pillow_available
+from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
     ChatCompletionMessage,
@@ -29,10 +29,13 @@ if is_fastapi_available():
 if is_pillow_available():
-    import requests
     from PIL import Image


+if is_requests_available():
+    import requests
+
+
 if TYPE_CHECKING:
     from numpy.typing import NDArray
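
Note on the hunk above: `import requests` now sits behind its own `is_requests_available()` guard instead of piggybacking on the Pillow check, so each optional dependency is probed independently. As a rough sketch only (the helper name `_is_package_available` is an assumption, not necessarily the repo's exact code), such availability checks are usually built on importlib:

import importlib.util

def _is_package_available(name: str) -> bool:
    # True if the package is installed, without importing it yet.
    return importlib.util.find_spec(name) is not None

def is_requests_available() -> bool:
    return _is_package_available("requests")

def is_pillow_available() -> bool:
    return _is_package_available("PIL")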


@@ -187,13 +187,7 @@ class CustomDPOTrainer(DPOTrainer):
             ref_context = nullcontext()

         with torch.no_grad(), ref_context:
-            (
-                reference_chosen_logps,
-                reference_rejected_logps,
-                _,
-                _,
-                _,
-            ) = self.concatenated_forward(ref_model, batch)
+            reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)

         return reference_chosen_logps, reference_rejected_logps
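
Note on the hunk above: the one-liner relies on Python's starred assignment, where `*_` absorbs however many trailing return values `concatenated_forward` yields, so the call site no longer lists one underscore per discarded element. A self-contained illustration (the stub is hypothetical and only mimics a five-value return):

def concatenated_forward_stub():
    # Hypothetical stand-in for self.concatenated_forward(ref_model, batch),
    # returning five values like the trainer method does.
    return -1.5, -2.5, "chosen_logits", "rejected_logits", "kl_logps"

# Old style: one placeholder per ignored value.
chosen, rejected, _, _, _ = concatenated_forward_stub()

# New style: *_ collects the rest, however many values follow.
chosen, rejected, *_ = concatenated_forward_stub()
assert (chosen, rejected) == (-1.5, -2.5)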


@@ -146,15 +146,8 @@ class CustomKTOTrainer(KTOTrainer):
         if len(target_logps) != len(batch["kto_tags"]):
             raise ValueError("Mismatched shape of inputs and labels.")

-        chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
-        rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
-        chosen_logps = target_logps[chosen_idx, ...]
-        rejected_logps = target_logps[rejected_idx, ...]
-        chosen_logits = target_logits[chosen_idx, ...]
-        rejected_logits = target_logits[rejected_idx, ...]
+        chosen_logps, rejected_logps = target_logps[batch["kto_tags"]], target_logps[~batch["kto_tags"]]
+        chosen_logits, rejected_logits = target_logits[batch["kto_tags"]], target_logits[~batch["kto_tags"]]

         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps

     def compute_reference_log_probs(
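
Note on the hunk above: the two new lines replace index-list comprehensions with boolean-mask indexing, which assumes `batch["kto_tags"]` is a boolean torch tensor aligned with `target_logps` (the length check right before it points the same way). A minimal sketch with toy values:

import torch

target_logps = torch.tensor([-1.0, -2.0, -3.0, -4.0])
kto_tags = torch.tensor([True, False, True, False])  # True marks a "chosen" example

chosen_logps = target_logps[kto_tags]      # tensor([-1., -3.])
rejected_logps = target_logps[~kto_tags]   # tensor([-2., -4.])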


@@ -8,13 +8,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled

 from ...extras.packages import is_requests_available

+if is_requests_available():
+    import requests
+
+
 if TYPE_CHECKING:
     from transformers import PreTrainedModel
     from trl import AutoModelForCausalLMWithValueHead


-if is_requests_available():
-    import requests
-
-
 def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
     headers = {"Content-Type": "application/json"}
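
Note on the hunk above: `get_rewards_from_server` posts the rendered conversations to an external reward-model server and converts the returned scores into tensors. A hedged sketch of what such a call might look like, assuming the server accepts a JSON payload of messages and answers with a JSON body containing a `scores` list (the payload shape and field names are assumptions, not taken from this diff):

import json
from typing import List

import requests
import torch


def get_rewards_from_server_sketch(server_url: str, messages: List[str]) -> List[torch.Tensor]:
    headers = {"Content-Type": "application/json"}
    payload = {"model": "model", "messages": messages}  # assumed request schema
    response = requests.post(server_url, json=payload, headers=headers)
    scores = json.loads(response.text)["scores"]  # assumed response schema
    return [torch.tensor(score) for score in scores]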