-
Notifications
You must be signed in to change notification settings - Fork 13
Open
Description
In the paper, you state that the codebook is reset every 20 iterations to prevent codebook collapse. However, the training loop (excerpt below) contains no explicit codebook reset:
Lines 156 to 193 in 68d8500
| for nb_iter in tqdm(range(1, args.total_iter + 1)): | |
| gt_motion = next(train_loader_iter) | |
| gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len | |
| if args.sep_uplow: | |
| pred_motion, loss_commit, perplexity = net(gt_motion, idx_noise=0) | |
| else: | |
| pred_motion, loss_commit, perplexity = net(gt_motion) | |
| loss_motion = Loss(pred_motion, gt_motion) | |
| loss_vel = Loss.forward_joint(pred_motion, gt_motion) | |
| loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel | |
| optimizer.zero_grad() | |
| loss.backward() | |
| optimizer.step() | |
| scheduler.step() | |
| avg_recons += loss_motion.item() | |
| avg_perplexity += perplexity.item() | |
| avg_commit += loss_commit.item() | |
| if nb_iter % args.print_iter == 0 : | |
| avg_recons /= args.print_iter | |
| avg_perplexity /= args.print_iter | |
| avg_commit /= args.print_iter | |
| writer.add_scalar('./Train/L1', avg_recons, nb_iter) | |
| writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter) | |
| writer.add_scalar('./Train/Commit', avg_commit, nb_iter) | |
| logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}") | |
| avg_recons, avg_perplexity, avg_commit = 0., 0., 0., | |
| if nb_iter % args.eval_iter==0 : | |
| best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper) |
Metadata
Metadata
Assignees
Labels
No labels