forked from jiuyuan/CPM-9G-8B
add pretrain gradient acculation function
This commit is contained in:
parent
86d5191e13
commit
d8571281cc
|
@ -288,7 +288,7 @@ def pretrain(
|
|||
input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)
|
||||
|
||||
# ===========
|
||||
optim_manager.zero_grad()
|
||||
# optim_manager.zero_grad()
|
||||
# torch.cuda.empty_cache()
|
||||
mem_usage = {}
|
||||
tim_usage = {}
|
||||
|
@ -322,8 +322,16 @@ def pretrain(
|
|||
|
||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
||||
# ===========
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
# ===========
|
||||
if iteration % args.gradient_accumulation_steps == 0:
|
||||
grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
|
||||
optim_manager.step()
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
optim_manager.zero_grad()
|
||||
else:
|
||||
grad_norm = None
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
|
||||
mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
|
||||
# bmt.print_rank(torch.cuda.max_memory_allocated())
|
||||
|
||||
|
@ -410,7 +418,7 @@ def pretrain(
|
|||
"token_max": local_total_rate,
|
||||
"token_pass": global_token_pass,
|
||||
"throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
|
||||
"grad_norm": grad_norm.item(),
|
||||
"grad_norm": 0 if grad_norm == None else grad_norm.item(),
|
||||
"mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
|
||||
"num_gpus": global_world_size,
|
||||
"task_loss": task_loss_map,
|
||||
|
@ -433,7 +441,7 @@ def pretrain(
|
|||
(targets >= 0).sum(-1).float().mean()
|
||||
/ args.max_length
|
||||
/ (args.batch_size if args.flash == "cuda" else 1),
|
||||
grad_norm,
|
||||
0 if grad_norm == None else grad_norm,
|
||||
max(mem_usage["forward"][1], mem_usage["backward"][1]),
|
||||
)
|
||||
)
|
||||
|
@ -455,7 +463,7 @@ def pretrain(
|
|||
writer.add_scalar("Loss/train", global_loss, iteration)
|
||||
writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
|
||||
writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
|
||||
writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
|
||||
writer.add_scalar("Optimizer/grad_norm", 0 if grad_norm == None else grad_norm.item(), iteration)
|
||||
for task_name, loss in task_loss_map.items():
|
||||
writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
|
||||
|
||||
|
|
Loading…
Reference in New Issue