修改梯度裁剪的恶性bug,当前能进行训练,但是无论是否使用跳连接,预测帧总是输出对称的的效果,mse收敛到0.10
This commit is contained in:
@@ -20,18 +20,14 @@ from util import *
|
||||
from models import *
|
||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||
from util.video_dataset import VideoFrameDataset
|
||||
from util.frame_losses import MultiTaskLoss
|
||||
# from util.frame_losses import MultiTaskLoss
|
||||
|
||||
# Try to import TensorBoard
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
TENSORBOARD_AVAILABLE = True
|
||||
except ImportError:
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
TENSORBOARD_AVAILABLE = True
|
||||
except ImportError:
|
||||
TENSORBOARD_AVAILABLE = False
|
||||
TENSORBOARD_AVAILABLE = False
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
@@ -57,7 +53,7 @@ def get_args_parser():
|
||||
help='Use representation head for pose/velocity prediction')
|
||||
parser.add_argument('--representation-dim', default=128, type=int,
|
||||
help='Dimension of representation vector')
|
||||
parser.add_argument('--use-skip', default=True, type=bool, help='using skip connections')
|
||||
parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections')
|
||||
|
||||
# Training parameters
|
||||
parser.add_argument('--batch-size', default=32, type=int)
|
||||
@@ -328,7 +324,7 @@ def main(args):
|
||||
lr_scheduler.step(epoch)
|
||||
|
||||
# Save checkpoint
|
||||
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
|
||||
if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
|
||||
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
|
||||
utils.save_on_master({
|
||||
'model': model_without_ddp.state_dict(),
|
||||
@@ -368,7 +364,7 @@ def main(args):
|
||||
|
||||
|
||||
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
||||
clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
|
||||
clip_grad=None, clip_mode='norm', model_ema=None, writer=None,
|
||||
global_step=0, args=None, **kwargs):
|
||||
model.train()
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
@@ -403,7 +399,6 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 在反向传播前保存梯度用于诊断
|
||||
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
|
||||
parameters=model.parameters())
|
||||
|
||||
@@ -426,14 +421,14 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
metric_logger.update(pred_std=pred_std)
|
||||
metric_logger.update(grad_norm=total_grad_norm)
|
||||
|
||||
# 每50个批次打印一次BatchNorm统计
|
||||
# # 每50个批次打印一次BatchNorm统计
|
||||
if batch_idx % 50 == 0:
|
||||
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
|
||||
# 检查一个BatchNorm层的运行统计
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
break
|
||||
# # 检查一个BatchNorm层的运行统计
|
||||
# for name, module in model.named_modules():
|
||||
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
# print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
# break
|
||||
|
||||
# Log to TensorBoard
|
||||
if writer is not None:
|
||||
@@ -520,21 +515,21 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
|
||||
metric_logger.update(target_mean=target_mean)
|
||||
metric_logger.update(target_std=target_std)
|
||||
|
||||
# 第一个批次打印详细诊断信息
|
||||
if batch_idx == 0:
|
||||
print(f"[评估诊断] 批次 0:")
|
||||
print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
||||
print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
||||
print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
||||
print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
||||
# # 第一个批次打印详细诊断信息
|
||||
# if batch_idx == 0:
|
||||
# print(f"[评估诊断] 批次 0:")
|
||||
# print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
||||
# print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
||||
# print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
||||
# print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
||||
|
||||
# 检查BatchNorm运行统计
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
if module.running_var[0].item() < 1e-6:
|
||||
print(f" 警告: BatchNorm运行方差接近零!")
|
||||
break
|
||||
# # 检查BatchNorm运行统计
|
||||
# for name, module in model.named_modules():
|
||||
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
# print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
# if module.running_var[0].item() < 1e-6:
|
||||
# print(f" 警告: BatchNorm运行方差接近零!")
|
||||
# break
|
||||
|
||||
# Update metrics
|
||||
metric_logger.update(loss=loss.item())
|
||||
|
||||
Reference in New Issue
Block a user