This commit is contained in:
hiyouga 2024-08-27 12:49:32 +08:00
parent dbe886ae5c
commit f6ae4e75dd
3 changed files with 28 additions and 9 deletions

View File

@ -107,15 +107,14 @@ def _process_request(
input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
else:
image_url = input_item.image_url.url
if re.match("^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):
image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
image_path = io.BytesIO(image_data)
if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url): # base64 image
image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
elif os.path.isfile(image_url): # local file
image_path = open(image_url, "rb")
image_stream = open(image_url, "rb")
else: # web uri
image_path = requests.get(image_url, stream=True).raw
image_stream = requests.get(image_url, stream=True).raw
image = np.array(Image.open(image_path).convert("RGB"))
image = np.array(Image.open(image_stream).convert("RGB"))
else:
input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

View File

@ -156,6 +156,18 @@ def get_logits_processor() -> "LogitsProcessorList":
return logits_processor
def get_peak_memory() -> Tuple[int, int]:
r"""
Gets the peak memory usage for the current device (in Bytes).
"""
if is_torch_npu_available():
return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
elif is_torch_cuda_available():
return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
else:
return 0, 0
def has_tokenized_data(path: "os.PathLike") -> bool:
r"""
Checks if the path has a tokenized dataset.

View File

@ -35,6 +35,7 @@ from transformers.utils import (
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger
from ..extras.misc import get_peak_memory
if is_safetensors_available():
@ -304,14 +305,21 @@ class LogCallback(TrainerCallback):
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
elapsed_time=self.elapsed_time,
remaining_time=self.remaining_time,
throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
total_tokens=state.num_input_tokens_seen,
)
if state.num_input_tokens_seen:
logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
logs["total_tokens"] = state.num_input_tokens_seen
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
vram_allocated, vram_reserved = get_peak_memory()
logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
logs = {k: v for k, v in logs.items() if v is not None}
if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
logger.info(
"{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput")
)
)