From c5502cc87c30c3de436fe221bef58f628ca1ba23 Mon Sep 17 00:00:00 2001 From: CaoWangrenbo Date: Sun, 11 Jan 2026 10:50:11 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A2=AF=E5=BA=A6=E8=A3=81?= =?UTF-8?q?=E5=89=AA=E7=9A=84=E6=81=B6=E6=80=A7bug=EF=BC=8C=E5=BD=93?= =?UTF-8?q?=E5=89=8D=E8=83=BD=E8=BF=9B=E8=A1=8C=E8=AE=AD=E7=BB=83=EF=BC=8C?= =?UTF-8?q?=E4=BD=86=E6=98=AF=E6=97=A0=E8=AE=BA=E6=98=AF=E5=90=A6=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E8=B7=B3=E8=BF=9E=E6=8E=A5=EF=BC=8C=E9=A2=84=E6=B5=8B?= =?UTF-8?q?=E5=B8=A7=E6=80=BB=E6=98=AF=E8=BE=93=E5=87=BA=E5=AF=B9=E7=A7=B0?= =?UTF-8?q?=E7=9A=84=E7=9A=84=E6=95=88=E6=9E=9C=EF=BC=8Cmse=E6=94=B6?= =?UTF-8?q?=E6=95=9B=E5=88=B00.10?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main_temporal.py | 55 ++++++------- temporal_train.sh | 0 util/frame_losses.py | 182 ------------------------------------------- 3 files changed, 25 insertions(+), 212 deletions(-) delete mode 100644 temporal_train.sh delete mode 100644 util/frame_losses.py diff --git a/main_temporal.py b/main_temporal.py index 57cfb3c..5a137e1 100644 --- a/main_temporal.py +++ b/main_temporal.py @@ -20,18 +20,14 @@ from util import * from models import * from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 from util.video_dataset import VideoFrameDataset -from util.frame_losses import MultiTaskLoss +# from util.frame_losses import MultiTaskLoss # Try to import TensorBoard try: from torch.utils.tensorboard import SummaryWriter TENSORBOARD_AVAILABLE = True except ImportError: - try: - from tensorboardX import SummaryWriter - TENSORBOARD_AVAILABLE = True - except ImportError: - TENSORBOARD_AVAILABLE = False + TENSORBOARD_AVAILABLE = False def get_args_parser(): @@ -57,7 +53,7 @@ def get_args_parser(): help='Use representation head for pose/velocity prediction') parser.add_argument('--representation-dim', default=128, type=int, help='Dimension of representation vector') - parser.add_argument('--use-skip', default=True, type=bool, help='using skip connections') + parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections') # Training parameters parser.add_argument('--batch-size', default=32, type=int) @@ -328,7 +324,7 @@ def main(args): lr_scheduler.step(epoch) # Save checkpoint - if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1): + if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1): checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth' utils.save_on_master({ 'model': model_without_ddp.state_dict(), @@ -368,7 +364,7 @@ def main(args): def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, - clip_grad=0, clip_mode='norm', model_ema=None, writer=None, + clip_grad=None, clip_mode='norm', model_ema=None, writer=None, global_step=0, args=None, **kwargs): model.train() metric_logger = utils.MetricLogger(delimiter=" ") @@ -403,7 +399,6 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los optimizer.zero_grad() - # 在反向传播前保存梯度用于诊断 loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, parameters=model.parameters()) @@ -426,14 +421,14 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los metric_logger.update(pred_std=pred_std) metric_logger.update(grad_norm=total_grad_norm) - # 每50个批次打印一次BatchNorm统计 + # # 每50个批次打印一次BatchNorm统计 if batch_idx % 50 == 0: print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}") - # 检查一个BatchNorm层的运行统计 - for name, module in model.named_modules(): - if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name: - print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}") - break + # # 检查一个BatchNorm层的运行统计 + # for name, module in model.named_modules(): + # if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name: + # print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}") + # break # Log to TensorBoard if writer is not None: @@ -520,21 +515,21 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0): metric_logger.update(target_mean=target_mean) metric_logger.update(target_std=target_std) - # 第一个批次打印详细诊断信息 - if batch_idx == 0: - print(f"[评估诊断] 批次 0:") - print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]") - print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}") - print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]") - print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}") + # # 第一个批次打印详细诊断信息 + # if batch_idx == 0: + # print(f"[评估诊断] 批次 0:") + # print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]") + # print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}") + # print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]") + # print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}") - # 检查BatchNorm运行统计 - for name, module in model.named_modules(): - if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name: - print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}") - if module.running_var[0].item() < 1e-6: - print(f" 警告: BatchNorm运行方差接近零!") - break + # # 检查BatchNorm运行统计 + # for name, module in model.named_modules(): + # if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name: + # print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}") + # if module.running_var[0].item() < 1e-6: + # print(f" 警告: BatchNorm运行方差接近零!") + # break # Update metrics metric_logger.update(loss=loss.item()) diff --git a/temporal_train.sh b/temporal_train.sh deleted file mode 100644 index e69de29..0000000 diff --git a/util/frame_losses.py b/util/frame_losses.py deleted file mode 100644 index e27cc05..0000000 --- a/util/frame_losses.py +++ /dev/null @@ -1,182 +0,0 @@ -""" -Loss functions for frame prediction and representation learning -""" -import torch -import torch.nn as nn -import torch.nn.functional as F -import math - - -class SSIMLoss(nn.Module): - """ - Structural Similarity Index Measure Loss - Based on: https://github.com/Po-Hsun-Su/pytorch-ssim - """ - def __init__(self, window_size=11, size_average=True): - super().__init__() - self.window_size = window_size - self.size_average = size_average - self.channel = 3 - self.window = self.create_window(window_size, self.channel) - - def create_window(self, window_size, channel): - def gaussian(window_size, sigma): - gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) - return gauss/gauss.sum() - - _1D_window = gaussian(window_size, 1.5).unsqueeze(1) - _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) - window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() - return window - - def forward(self, img1, img2): - # Ensure window is on correct device - if self.window.device != img1.device: - self.window = self.window.to(img1.device) - - mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel) - mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel) - - mu1_sq = mu1.pow(2) - mu2_sq = mu2.pow(2) - mu1_mu2 = mu1 * mu2 - - sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq - sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq - sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2 - - C1 = 0.01**2 - C2 = 0.03**2 - - ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) - - if self.size_average: - return 1 - ssim_map.mean() - else: - return 1 - ssim_map.mean(1).mean(1).mean(1) - - -class FramePredictionLoss(nn.Module): - """ - Combined loss for frame prediction - """ - def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True): - super().__init__() - self.l1_weight = l1_weight - self.ssim_weight = ssim_weight - self.use_ssim = use_ssim - - self.l1_loss = nn.L1Loss() - if use_ssim: - self.ssim_loss = SSIMLoss() - - def forward(self, pred, target): - """ - Args: - pred: predicted frame [B, 3, H, W] in range [-1, 1] - target: target frame [B, 3, H, W] in range [-1, 1] - Returns: - total_loss, loss_dict - """ - loss_dict = {} - - # L1 loss - l1_loss = self.l1_loss(pred, target) - loss_dict['l1'] = l1_loss - total_loss = self.l1_weight * l1_loss - - # SSIM loss - if self.use_ssim: - ssim_loss = self.ssim_loss(pred, target) - loss_dict['ssim'] = ssim_loss - total_loss += self.ssim_weight * ssim_loss - - loss_dict['total'] = total_loss - return total_loss, loss_dict - - -class ContrastiveLoss(nn.Module): - """ - Contrastive loss for representation learning - Positive pairs: representations from adjacent frames - Negative pairs: representations from distant frames - """ - def __init__(self, temperature=0.1, margin=1.0): - super().__init__() - self.temperature = temperature - self.margin = margin - self.cosine_similarity = nn.CosineSimilarity(dim=-1) - - def forward(self, representations, temporal_indices): - """ - Args: - representations: [B, D] representation vectors - temporal_indices: [B] temporal indices of each sample - Returns: - contrastive_loss - """ - batch_size = representations.size(0) - - # Compute similarity matrix - sim_matrix = torch.matmul(representations, representations.T) / self.temperature - - # Create positive mask (adjacent frames) - indices_expanded = temporal_indices.unsqueeze(0) - diff = torch.abs(indices_expanded - indices_expanded.T) - positive_mask = (diff == 1).float() - - # Create negative mask (distant frames) - negative_mask = (diff > 2).float() - - # Positive loss - pos_sim = sim_matrix * positive_mask - pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8) - pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8) - - # Negative loss (push apart) - neg_sim = sim_matrix * negative_mask - neg_loss = torch.relu(neg_sim - self.margin).mean() - - return pos_loss + 0.1 * neg_loss - - -class MultiTaskLoss(nn.Module): - """ - Multi-task loss combining frame prediction and representation learning - """ - def __init__(self, frame_weight=1.0, contrastive_weight=0.1, - l1_weight=1.0, ssim_weight=0.1, use_contrastive=True): - super().__init__() - self.frame_weight = frame_weight - self.contrastive_weight = contrastive_weight - self.use_contrastive = use_contrastive - - self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight) - if use_contrastive: - self.contrastive_loss = ContrastiveLoss() - - def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None): - """ - Args: - pred_frame: predicted frame [B, 3, H, W] - target_frame: target frame [B, 3, H, W] - representations: [B, D] representation vectors (optional) - temporal_indices: [B] temporal indices (optional) - Returns: - total_loss, loss_dict - """ - loss_dict = {} - - # Frame prediction loss - frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame) - loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()}) - total_loss = self.frame_weight * frame_loss - - # Contrastive loss (if representations provided) - if self.use_contrastive and representations is not None and temporal_indices is not None: - contrastive_loss = self.contrastive_loss(representations, temporal_indices) - loss_dict['contrastive'] = contrastive_loss - total_loss += self.contrastive_weight * contrastive_loss - - loss_dict['total'] = total_loss - return total_loss, loss_dict \ No newline at end of file