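Update trainers: guard the `requests` import with `is_requests_available()`, simplify DPO reference log-prob unpacking, and split KTO chosen/rejected outputs with boolean masks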
This commit is contained in:
parent
67aa78cde0
commit
fad2591e31
|
@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from ..data import Role as DataRole
|
from ..data import Role as DataRole
|
||||||
from ..extras.logging import get_logger
|
from ..extras.logging import get_logger
|
||||||
from ..extras.packages import is_fastapi_available, is_pillow_available
|
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
|
||||||
from .common import dictify, jsonify
|
from .common import dictify, jsonify
|
||||||
from .protocol import (
|
from .protocol import (
|
||||||
ChatCompletionMessage,
|
ChatCompletionMessage,
|
||||||
|
@ -29,10 +29,13 @@ if is_fastapi_available():
|
||||||
|
|
||||||
|
|
||||||
if is_pillow_available():
|
if is_pillow_available():
|
||||||
import requests
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
if is_requests_available():
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from numpy.typing import NDArray
|
from numpy.typing import NDArray
|
||||||
|
|
||||||
|
|
|
@ -187,13 +187,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||||
ref_context = nullcontext()
|
ref_context = nullcontext()
|
||||||
|
|
||||||
with torch.no_grad(), ref_context:
|
with torch.no_grad(), ref_context:
|
||||||
(
|
reference_chosen_logps, reference_rejected_logps, *_ = self.concatenated_forward(ref_model, batch)
|
||||||
reference_chosen_logps,
|
|
||||||
reference_rejected_logps,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) = self.concatenated_forward(ref_model, batch)
|
|
||||||
|
|
||||||
return reference_chosen_logps, reference_rejected_logps
|
return reference_chosen_logps, reference_rejected_logps
|
||||||
|
|
||||||
|
|
|
@ -146,15 +146,8 @@ class CustomKTOTrainer(KTOTrainer):
|
||||||
if len(target_logps) != len(batch["kto_tags"]):
|
if len(target_logps) != len(batch["kto_tags"]):
|
||||||
raise ValueError("Mismatched shape of inputs and labels.")
|
raise ValueError("Mismatched shape of inputs and labels.")
|
||||||
|
|
||||||
chosen_idx = [i for i in range(len(target_logps)) if batch["kto_tags"][i]]
|
chosen_logps, rejected_logps = target_logps[batch["kto_tags"]], target_logps[~batch["kto_tags"]]
|
||||||
rejected_idx = [i for i in range(len(target_logps)) if not batch["kto_tags"][i]]
|
chosen_logits, rejected_logits = target_logits[batch["kto_tags"]], target_logits[~batch["kto_tags"]]
|
||||||
|
|
||||||
chosen_logps = target_logps[chosen_idx, ...]
|
|
||||||
rejected_logps = target_logps[rejected_idx, ...]
|
|
||||||
|
|
||||||
chosen_logits = target_logits[chosen_idx, ...]
|
|
||||||
rejected_logits = target_logits[rejected_idx, ...]
|
|
||||||
|
|
||||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
|
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, kl_logps
|
||||||
|
|
||||||
def compute_reference_log_probs(
|
def compute_reference_log_probs(
|
||||||
|
|
|
@ -8,13 +8,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
from ...extras.packages import is_requests_available
|
from ...extras.packages import is_requests_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_requests_available():
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
from trl import AutoModelForCausalLMWithValueHead
|
from trl import AutoModelForCausalLMWithValueHead
|
||||||
|
|
||||||
if is_requests_available():
|
|
||||||
import requests
|
|
||||||
|
|
||||||
|
|
||||||
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
|
def get_rewards_from_server(server_url: str, messages: List[str]) -> List[torch.Tensor]:
|
||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
Loading…
Reference in New Issue