From d8571281ccb6d59994687868faf11957396832f2 Mon Sep 17 00:00:00 2001
From: anrongqiao <17710054230@163.com>
Date: Thu, 10 Oct 2024 16:26:31 +0800
Subject: [PATCH] add pretrain gradient accumulation function

---
 9G-Train/apps/cpm9g/pretrain_cpm9g.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/9G-Train/apps/cpm9g/pretrain_cpm9g.py b/9G-Train/apps/cpm9g/pretrain_cpm9g.py
index f64679d..e020ca5 100644
--- a/9G-Train/apps/cpm9g/pretrain_cpm9g.py
+++ b/9G-Train/apps/cpm9g/pretrain_cpm9g.py
@@ -288,7 +288,7 @@ def pretrain(
         input_span = torch.from_numpy(data["spans"]).cuda().to(torch.int32)

         # ===========
-        optim_manager.zero_grad()
+        # optim_manager.zero_grad()
         # torch.cuda.empty_cache()
         mem_usage = {}
         tim_usage = {}
@@ -322,8 +322,16 @@ def pretrain(
         # bmt.print_rank(torch.cuda.max_memory_allocated())

         # ===========
-        grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
-        optim_manager.step()
+        # ===========
+        if iteration % args.gradient_accumulation_steps == 0:
+            grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, args.clip_grad, norm_type=2)
+            optim_manager.step()
+            mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
+            optim_manager.zero_grad()
+        else:
+            grad_norm = None
+            mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)
+
         mem_usage, tim_usage = add_mem_time("optim", mem_usage, tim_usage)

         # bmt.print_rank(torch.cuda.max_memory_allocated())
@@ -410,7 +418,7 @@ def pretrain(
             "token_max": local_total_rate,
             "token_pass": global_token_pass,
             "throughout": args.max_length * args.batch_size * local_total_rate / avg_time,
-            "grad_norm": grad_norm.item(),
+            "grad_norm": 0 if grad_norm == None else grad_norm.item(),
             "mask_max": ((targets >= 0).sum(-1).float().mean() / args.max_length).item(),
             "num_gpus": global_world_size,
             "task_loss": task_loss_map,
@@ -433,7 +441,7 @@ def pretrain(
                 (targets >= 0).sum(-1).float().mean()
                 / args.max_length
                 / (args.batch_size if args.flash == "cuda" else 1),
-                grad_norm,
+                0 if grad_norm == None else grad_norm,
                 max(mem_usage["forward"][1], mem_usage["backward"][1]),
             )
         )
@@ -455,7 +463,7 @@ def pretrain(
             writer.add_scalar("Loss/train", global_loss, iteration)
             writer.add_scalar("Optimizer/lr", lr_scheduler.current_lr, iteration)
             writer.add_scalar("Optimizer/scale", optim_manager.loss_scale, iteration)
-            writer.add_scalar("Optimizer/grad_norm", grad_norm.item(), iteration)
+            writer.add_scalar("Optimizer/grad_norm", 0 if grad_norm == None else grad_norm.item(), iteration)
             for task_name, loss in task_loss_map.items():
                 writer.add_scalar("Loss/train/{}".format(task_name), loss, iteration)
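
Note on the pattern: below is a minimal, self-contained sketch of the gradient-accumulation loop this patch introduces, written against plain PyTorch rather than BMTrain's optim_manager (torch.nn.utils.clip_grad_norm_, optimizer.step, and optimizer.zero_grad stand in for optim_manager.clip_grad_norm, optim_manager.step, and optim_manager.zero_grad). The names model, loss_fn, batches, and clip_grad are hypothetical placeholders; only the `iteration % gradient_accumulation_steps == 0` update condition, the deferred clip/step/zero_grad, and the grad_norm = None fallback mirror the change above. The 1/N loss scaling is common practice to keep the accumulated gradient comparable to one large batch and is not part of the patch itself.

    import torch

    # Hypothetical stand-ins; the real script uses a CPM-9G model, BMTrain, and args.
    torch.manual_seed(0)
    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()
    gradient_accumulation_steps = 4   # mirrors args.gradient_accumulation_steps in the patch
    clip_grad = 1.0                   # mirrors args.clip_grad

    # Synthetic micro-batches of (input, target).
    batches = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(12)]

    for iteration, (x, y) in enumerate(batches, start=1):
        # Scale the loss so the accumulated gradient matches one large batch (not in the patch).
        loss = loss_fn(model(x), y) / gradient_accumulation_steps
        loss.backward()  # gradients accumulate in parameter.grad across micro-batches

        if iteration % gradient_accumulation_steps == 0:
            # Update only every N micro-batches, as in the patched branch.
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
            optimizer.zero_grad()  # clear accumulated gradients after the update
        else:
            grad_norm = None  # no optimizer step on this micro-batch

        norm_str = "-" if grad_norm is None else "{:.4f}".format(grad_norm.item())
        print("iter {:2d} | loss {:.4f} | grad_norm {}".format(iteration, loss.item(), norm_str))

On iterations without an optimizer step, grad_norm stays None, which is why the patched logging sites fall back to 0 before calling .item().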