tiny fix
parent dbe886ae5c
commit f6ae4e75dd
@@ -107,15 +107,14 @@ def _process_request(
                     input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                 else:
                     image_url = input_item.image_url.url
-                    if re.match("^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):
-                        image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
-                        image_path = io.BytesIO(image_data)
+                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
+                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
                     elif os.path.isfile(image_url):  # local file
-                        image_path = open(image_url, "rb")
+                        image_stream = open(image_url, "rb")
                     else:  # web uri
-                        image_path = requests.get(image_url, stream=True).raw
+                        image_stream = requests.get(image_url, stream=True).raw

-                    image = np.array(Image.open(image_path).convert("RGB"))
+                    image = np.array(Image.open(image_stream).convert("RGB"))
             else:
                 input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
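The hunk above renames `image_path` to `image_stream` and adds the missing `r` prefix on the regex (so `\/` is no longer an invalid escape sequence). A minimal, self-contained sketch of the three-way branch, assuming a hypothetical `load_image_stream` helper that is not part of this commit:

import base64
import io
import os
import re

import requests


def load_image_stream(image_url: str):
    # Data URI: strip the "data:image/...;base64," prefix and decode the payload.
    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):
        return io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
    elif os.path.isfile(image_url):  # local file on disk
        return open(image_url, "rb")
    else:  # web URI: hand back the raw response body instead of buffering it
        return requests.get(image_url, stream=True).raw


# PIL.Image.open() accepts any of the three stream types returned above:
#     image = np.array(Image.open(load_image_stream(url)).convert("RGB"))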
@@ -156,6 +156,18 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor


+def get_peak_memory() -> Tuple[int, int]:
+    r"""
+    Gets the peak memory usage for the current device (in Bytes).
+    """
+    if is_torch_npu_available():
+        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_cuda_available():
+        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
+    else:
+        return 0, 0
+
+
 def has_tokenized_data(path: "os.PathLike") -> bool:
     r"""
     Checks if the path has a tokenized dataset.
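For context on the new helper: on CUDA, `max_memory_allocated` is the peak number of bytes occupied by live tensors, while `max_memory_reserved` is the (larger) peak held by the caching allocator. A standalone illustration using only public torch APIs, not code from this commit:

import torch

if torch.cuda.is_available():
    x = torch.empty(1024, 1024, device="cuda")  # ~4 MiB of float32
    del x  # freeing the tensor does not lower the recorded peaks
    allocated = torch.cuda.max_memory_allocated()  # peak bytes in live tensors
    reserved = torch.cuda.max_memory_reserved()    # peak bytes held by the allocator
    print(f"allocated={allocated} B, reserved={reserved} B")  # reserved >= allocated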
@@ -35,6 +35,7 @@ from transformers.utils import (

 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import LoggerHandler, get_logger
+from ..extras.misc import get_peak_memory


 if is_safetensors_available():
@@ -304,14 +305,21 @@ class LogCallback(TrainerCallback):
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
-            throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
-            total_tokens=state.num_input_tokens_seen,
         )

+        if state.num_input_tokens_seen:
+            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
+            logs["total_tokens"] = state.num_input_tokens_seen
+
+        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
+            vram_allocated, vram_reserved = get_peak_memory()
+            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
+            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
+
         logs = {k: v for k, v in logs.items() if v is not None}
         if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
             logger.info(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
+                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput")
                 )
             )
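Pulled out of the callback for readability, a sketch of the new logging flow; `augment_logs` is an illustrative name, and the absolute import path is an assumption based on the relative import added above:

import os
import time

from llamafactory.extras.misc import get_peak_memory  # assumed package path


def augment_logs(logs: dict, num_input_tokens_seen: int, start_time: float) -> dict:
    # Throughput is only meaningful when the trainer counts input tokens;
    # the counter stays 0 otherwise, so both fields are skipped entirely.
    if num_input_tokens_seen:
        logs["throughput"] = round(num_input_tokens_seen / (time.time() - start_time), 2)
        logs["total_tokens"] = num_input_tokens_seen

    # Opt-in VRAM reporting: export RECORD_VRAM=1 (or "true") before training.
    if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
        vram_allocated, vram_reserved = get_peak_memory()  # bytes
        logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)  # GB
        logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)  # GB

    # Dropping None values is what lets the later format string use
    # logs.get("throughput") safely when the field was never set.
    return {k: v for k, v in logs.items() if v is not None}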