This commit is contained in:
hiyouga 2024-05-14 20:37:21 +08:00
parent e8b97d2f79
commit cfaee8b4cf
1 changed files with 9 additions and 0 deletions

View File

@ -21,6 +21,9 @@ def smooth(scalars: List[float]) -> List[float]:
r"""
EMA implementation according to TensorBoard.
"""
if len(scalars) == 0:
return []
last = scalars[0]
smoothed = []
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
@ -32,6 +35,9 @@ def smooth(scalars: List[float]) -> List[float]:
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
r"""
Plots loss curves in LlamaBoard.
"""
plt.close("all")
plt.switch_backend("agg")
fig = plt.figure()
@ -51,6 +57,9 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
r"""
Plots loss curves and saves the image.
"""
plt.switch_backend("agg")
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
data = json.load(f)