Refined the skip connections and added a feature refinement layer after the upsampling decode block; effect not yet tested.

2026-01-09 18:23:45 +08:00
parent 500c2eb18f
commit 12de74f130
8 changed files with 893 additions and 244 deletions
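Note: the model-side change named in the commit message (the feature refinement layer after the upsampling decode block) lives in the model files and is not part of the excerpt below, which only covers the training script. As a minimal sketch of what such a block could look like, assuming a standard convolutional decoder; the class names (RefineBlock, UpDecodeBlock), channel arithmetic, and layer choices are hypothetical, not the repository's actual code:

import torch
import torch.nn as nn

class RefineBlock(nn.Module):
    """Hypothetical feature refinement layer: a small residual conv block
    applied to decoder features after upsampling and skip fusion."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        x = self.act(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        return self.act(x + residual)

class UpDecodeBlock(nn.Module):
    """Hypothetical upsampling decoder block with an optional skip connection
    (cf. the new --use-skip flag) followed by the refinement layer."""
    def __init__(self, in_ch, skip_ch, out_ch, use_skip=True):
        super().__init__()
        self.use_skip = use_skip
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        fuse_in = out_ch + (skip_ch if use_skip else 0)
        self.fuse = nn.Conv2d(fuse_in, out_ch, kernel_size=1)
        self.refine = RefineBlock(out_ch)

    def forward(self, x, skip=None):
        x = self.up(x)  # double the spatial resolution
        if self.use_skip and skip is not None:
            # skip must match x spatially (encoder features at the same scale)
            x = torch.cat([x, skip], dim=1)
        x = self.fuse(x)
        return self.refine(x)

With use_skip=False the decode block degrades to plain upsampling plus refinement, which is presumably what the new --use-skip flag toggles.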


@@ -57,6 +57,7 @@ def get_args_parser():
                         help='Use representation head for pose/velocity prediction')
     parser.add_argument('--representation-dim', default=128, type=int,
                         help='Dimension of representation vector')
+    parser.add_argument('--use-skip', default=True, action=argparse.BooleanOptionalAction, help='use skip connections (disable with --no-use-skip)')
 
     # Training parameters
     parser.add_argument('--batch-size', default=32, type=int)
@@ -77,7 +78,7 @@ def get_args_parser():
                         help='SGD momentum (default: 0.9)')
     parser.add_argument('--weight-decay', type=float, default=0.05,
                         help='weight decay (default: 0.05)')
-    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
-                        help='learning rate (default: 1e-3)')
+    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
+                        help='learning rate (default: 0.1)')
 
     # Learning rate schedule parameters (required by timm's create_scheduler)
@@ -89,7 +90,7 @@ def get_args_parser():
                         help='learning rate noise limit percent (default: 0.67)')
     parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                         help='learning rate noise std-dev (default: 1.0)')
-    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
-                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
+                        help='warmup learning rate (default: 1e-3)')
     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                         help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
@@ -212,6 +213,7 @@ def main(args):
         'num_frames': args.num_frames,
         'use_representation_head': args.use_representation_head,
         'representation_dim': args.representation_dim,
+        'use_skip': args.use_skip,
     }
 
     if args.model == 'SwiftFormerTemporal_XS':
@@ -373,6 +375,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = f'Epoch: [{epoch}]'
     print_freq = 10
+
+    # Add diagnostic metrics
+    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
 
     for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
             metric_logger.log_every(data_loader, print_freq, header)):
@@ -382,7 +389,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         temporal_indices = temporal_indices.to(device, non_blocking=True)
 
         # Forward pass
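+        # torch.cuda.amp.autocast is deprecated in recent PyTorch; torch.amp.autocast(device_type='cuda') is the replacement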
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(device_type='cuda'):
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
@@ -395,6 +402,8 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
             raise ValueError(f"Loss is {loss_value}")
 
         optimizer.zero_grad()
+        # loss_scaler runs the backward pass and optimizer step;
+        # the resulting gradients are inspected for diagnostics below
         loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                     parameters=model.parameters())
@@ -402,6 +411,30 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         if model_ema is not None:
             model_ema.update(model)
 
+        # Compute diagnostic metrics
+        pred_mean = pred_frames.mean().item()
+        pred_std = pred_frames.std().item()
+
+        # Total gradient norm: accumulate squared per-parameter L2 norms, then take
+        # the square root (summing the norms directly would overestimate the global norm)
+        total_grad_norm = 0.0
+        for param in model.parameters():
+            if param.grad is not None:
+                total_grad_norm += param.grad.norm().item() ** 2
+        total_grad_norm = total_grad_norm ** 0.5
+
+        # Record diagnostic metrics
+        metric_logger.update(pred_mean=pred_mean)
+        metric_logger.update(pred_std=pred_std)
+        metric_logger.update(grad_norm=total_grad_norm)
+
+        # Print BatchNorm statistics every 50 batches
+        if batch_idx % 50 == 0:
+            print(f"[diag] batch {batch_idx}: pred_mean={pred_mean:.4f}, pred_std={pred_std:.4f}, grad_norm={total_grad_norm:.4f}")
+            # Inspect the running statistics of one BatchNorm layer
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+                    print(f"[diag] {name}: running_mean={module.running_mean[0].item():.6f}, running_var={module.running_var[0].item():.6f}")
+                    break
 
         # Log to TensorBoard
         if writer is not None:
             # Log scalar metrics every iteration
@@ -415,6 +448,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
                 else:
                     writer.add_scalar(f'train/{k}', v, global_step)
 
+            # Log diagnostic metrics
+            writer.add_scalar('train/pred_mean', pred_mean, global_step)
+            writer.add_scalar('train/pred_std', pred_std, global_step)
+            writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
 
             # Log images periodically
             if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
                 with torch.no_grad():
@@ -450,20 +488,54 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
     model.eval()
 
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
+
+    # Add diagnostic metrics
+    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
-    for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
+    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
         input_frames = input_frames.to(device, non_blocking=True)
         target_frames = target_frames.to(device, non_blocking=True)
         temporal_indices = temporal_indices.to(device, non_blocking=True)
 
         # Compute output
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(device_type='cuda'):
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
                 representations, temporal_indices
             )
+        # Compute diagnostic metrics
+        pred_mean = pred_frames.mean().item()
+        pred_std = pred_frames.std().item()
+        target_mean = target_frames.mean().item()
+        target_std = target_frames.std().item()
+
+        # Update diagnostic metrics
+        metric_logger.update(pred_mean=pred_mean)
+        metric_logger.update(pred_std=pred_std)
+        metric_logger.update(target_mean=target_mean)
+        metric_logger.update(target_std=target_std)
+
+        # Print detailed diagnostics on the first batch
+        if batch_idx == 0:
+            print("[eval diag] batch 0:")
+            print(f"  pred range: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
+            print(f"  pred mean: {pred_mean:.4f}, pred std: {pred_std:.4f}")
+            print(f"  target range: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
+            print(f"  target mean: {target_mean:.4f}, target std: {target_std:.4f}")
+            # Inspect BatchNorm running statistics
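+            # (a running variance near zero makes eval-mode normalization divide by roughly sqrt(eps), inflating activations)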
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+                    print(f"  {name}: running_mean={module.running_mean[0].item():.6f}, running_var={module.running_var[0].item():.6f}")
+                    if module.running_var[0].item() < 1e-6:
+                        print("  WARNING: BatchNorm running variance is near zero!")
+                    break
 
         # Update metrics
         metric_logger.update(loss=loss.item())
         for k, v in loss_dict.items():
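The prediction and BatchNorm diagnostics above are duplicated almost verbatim between train_one_epoch and evaluate; a small helper along the following lines could consolidate them. This is a sketch only: the diagnose name and signature are illustrative, not part of the repository.

import torch

def diagnose(tag, pred, target=None, model=None, bn_substr='decoder.blocks.0.bn'):
    """Print prediction statistics and, optionally, the running statistics of
    the first BatchNorm2d layer whose name contains bn_substr."""
    print(f"[{tag}] pred range: [{pred.min().item():.4f}, {pred.max().item():.4f}], "
          f"mean={pred.mean().item():.4f}, std={pred.std().item():.4f}")
    if target is not None:
        print(f"[{tag}] target mean={target.mean().item():.4f}, std={target.std().item():.4f}")
    if model is not None:
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.BatchNorm2d) and bn_substr in name:
                print(f"[{tag}] {name}: running_mean={module.running_mean[0].item():.6f}, "
                      f"running_var={module.running_var[0].item():.6f}")
                if module.running_var[0].item() < 1e-6:
                    print(f"[{tag}] WARNING: BatchNorm running variance is near zero!")
                break

Both loops could then call, e.g., diagnose('train', pred_frames, model=model) and diagnose('eval', pred_frames, target_frames, model).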