add pretrain gradient acculation function

This commit is contained in:
anrongqiao 2024-10-10 16:26:31 +08:00
parent 86d5191e13
commit d8571281cc
1 changed files with 14 additions and 6 deletions

View File

@ -288,7 +288,7 @@ def pretrain(
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
# ===========
optim_manager.zero_grad()
# optim_manager.zero_grad()
# torch.cuda.empty_cache()
mem_usage = {}
tim_usage = {}
@ -322,8 +322,16 @@ def pretrain(
# bmt.print_rank(torch.cuda.max_memory_allocated())
# ===========
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
# ===========
if iteration % args.gradient_accumulation_steps == 0:
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
optim_manager.step()
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
optim_manager.zero_grad()
else:
grad_norm = None
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
# bmt.print_rank(torch.cuda.max_memory_allocated())
@ -410,7 +418,7 @@ def pretrain(
"token_max": local_total_rate,
"token_pass": global_token_pass,
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
"grad_norm": grad_norm.item(),
"grad_norm": 0 if grad_norm == None else grad_norm.item(),
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
"num_gpus": global_world_size,
"task_loss": task_loss_map,
@ -433,7 +441,7 @@ def pretrain(
(targets >= 0).sum(-1).float().mean()
/ args.max_length
/ (args.batch_size if args.flash == "cuda" else 1),
grad_norm,
0 if grad_norm == None else grad_norm,
max(mem_usage["forward"][1], mem_usage["backward"][1]),
)
)
@ -455,7 +463,7 @@ def pretrain(
writer.add_scalar("Loss/train", global_loss, iteration)
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
writer.add_scalar("Optimizer/grad_norm", 0 if grad_norm == None else grad_norm.item(), iteration)
for task_name, loss in task_loss_map.items():
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)