Compare commits

..

6 Commits

11 changed files with 752 additions and 551 deletions

View File

@@ -11,9 +11,10 @@ shift 2
# Default parameters # Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"} MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-32} BATCH_SIZE=${BATCH_SIZE:-128}
EPOCHS=${EPOCHS:-100} EPOCHS=${EPOCHS:-100}
LR=${LR:-1e-3} # LR=${LR:-1e-3}
LR=${LR:-0.01}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"} OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
echo "Starting distributed training with $NUM_GPUS GPUs" echo "Starting distributed training with $NUM_GPUS GPUs"

evaluate_temporal.py (new file, 484 lines)
View File

@@ -0,0 +1,484 @@
"""
Evaluation script for SwiftFormerTemporal frame prediction.
Outputs prediction images (note the denormalization) and the corresponding metrics: MSE, SSIM and PSNR.
"""
import argparse
import os
import torch
import torch.nn as nn
import pickle
import numpy as np
import random
from pathlib import Path
import json
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from util.video_dataset import VideoFrameDataset
from models.swiftformer_temporal import (
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
)
# Import SSIM and PSNR metrics
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
warnings.filterwarnings('ignore')
def denormalize(tensor):
"""
Denormalize a tensor from the [-1, 1] range back to the [0, 255] range.
Args:
tensor: shape [B, C, H, W] or [C, H, W], values in [-1, 1]
Returns:
Denormalized tensor with values in [0, 255]
"""
# Clip to the [-1, 1] range
tensor = tensor.clamp(-1, 1)
# [-1, 1] -> [0, 1]
tensor = (tensor + 1) / 2
# [0, 1] -> [0, 255]
tensor = tensor * 255
return tensor.clamp(0, 255)
def minmax_denormalize(tensor):
tensor_min = tensor.min()
tensor_max = tensor.max()
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
# tensor = tensor*2-1
tensor = tensor*255
return tensor.clamp(0, 255)
def calculate_metrics(pred, target, debug=False):
"""
Compute the MSE, SSIM and PSNR metrics.
Args:
pred: predicted image, shape [H, W], values in [0, 255]
target: target image, shape [H, W], values in [0, 255]
debug: whether to print debug information
Returns:
mse, ssim_value, psnr_value
"""
# Convert to numpy arrays
pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
target_np = target.cpu().numpy() if torch.is_tensor(target) else target
# Make sure the arrays are 2D
if pred_np.ndim == 3:
pred_np = pred_np.squeeze(0)
if target_np.ndim == 3:
target_np = target_np.squeeze(0)
# if debug:
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
mse = np.mean((pred_np - target_np) ** 2)
data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range)
psnr_value = psnr(target_np, pred_np, data_range=data_range)
return mse, ssim_value, psnr_value
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
input_frame_indices=None, target_frame_index=None):
"""
Save a comparison figure: input frames, target frame and predicted frame.
Args:
input_frames: list of input frames, each of shape [H, W], values in [0, 255]
target_frame: target frame, shape [H, W], values in [0, 255]
pred_frame: predicted frame, shape [H, W], values in [0, 255]
save_path: path to save the figure
input_frame_indices: list of input frame indices (optional)
target_frame_index: index of the target frame (optional)
"""
num_input = len(input_frames)
fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
# Plot the input frames
for i in range(num_input):
ax = axes[i]
ax.imshow(input_frames[i], cmap='gray')
if input_frame_indices is not None:
ax.set_title(f'Input Frame {input_frame_indices[i]}')
else:
ax.set_title(f'Input {i+1}')
ax.axis('off')
# Plot the target frame
ax = axes[num_input]
ax.imshow(target_frame, cmap='gray')
if target_frame_index is not None:
ax.set_title(f'Target Frame {target_frame_index}')
else:
ax.set_title('Target')
ax.axis('off')
# Plot the predicted frame
ax = axes[num_input + 1]
ax.imshow(pred_frame, cmap='gray')
ax.set_title('Predicted')
ax.axis('off')
# Debug print
print(target_frame)
print(pred_frame)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
def evaluate_model(model, data_loader, device, args):
"""
Evaluate the model and compute metrics.
Args:
model: trained model
data_loader: data loader
device: device to run on
args: command-line arguments
Returns:
metrics_dict: dictionary containing all metrics
sample_results: sampled results used for visualization
"""
model.eval()
# model.train() # temporarily use training mode
# Initialize metric accumulators
total_mse = 0.0
total_ssim = 0.0
total_psnr = 0.0
total_samples = 0
# Store sample results for visualization (selected at random via reservoir sampling)
sample_results = []
max_samples_to_save = args.num_samples_to_save
max_samples = args.max_samples
# Counter for reservoir sampling: number of samples processed (excluding samples skipped due to the max_samples limit)
sample_count = 0
with torch.no_grad():
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
# Forward pass
pred_frames = model(input_frames)
# Denormalize for metric computation
# pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
pred_denorm = denormalize(pred_frames)
target_denorm = denormalize(target_frames) # [B, 1, H, W]
batch_size = input_frames.size(0)
# Compute per-sample metrics
for i in range(batch_size):
# Check whether the maximum number of samples has been reached
if max_samples is not None and total_samples >= max_samples:
break
pred_i = pred_denorm[i] # [1, H, W]
target_i = target_denorm[i] # [1, H, W]
# Enable debugging for the first sample
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
# if debug_mode:
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
total_mse += mse
total_ssim += ssim_value
total_psnr += psnr_value
total_samples += 1
sample_count += 1
# Build the sample data dictionary
input_denorm = denormalize(input_frames[i]) # [num_frames, H, W]
# Split the input frames
input_frames_list = []
for j in range(args.num_frames):
input_frame_j = input_denorm[j].squeeze(0) # [H, W]
input_frames_list.append(input_frame_j.cpu().numpy())
sample_data = {
'input_frames': input_frames_list,
'target_frame': target_i.squeeze(0).cpu().numpy(),
'pred_frame': pred_i.squeeze(0).cpu().numpy(),
'metrics': {
'mse': mse,
'ssim': ssim_value,
'psnr': psnr_value
},
'batch_idx': batch_idx,
'sample_idx': i
}
# Reservoir sampling
if sample_count <= max_samples_to_save:
# Reservoir is not full yet; append directly
sample_results.append(sample_data)
else:
# With probability max_samples_to_save / sample_count, replace a random slot in the reservoir
r = random.randint(0, sample_count - 1)
if r < max_samples_to_save:
sample_results[r] = sample_data
# Check whether the maximum number of samples has been reached
if max_samples is not None and total_samples >= max_samples:
print(f"Reached the maximum number of samples: {max_samples}")
break
# Progress logging
if (batch_idx + 1) % 10 == 0:
print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
# Compute average metrics
if total_samples > 0:
avg_mse = float(total_mse / total_samples)
avg_ssim = float(total_ssim / total_samples)
avg_psnr = float(total_psnr / total_samples)
else:
avg_mse = avg_ssim = avg_psnr = 0.0
metrics_dict = {
'mse': avg_mse,
'ssim': avg_ssim,
'psnr': avg_psnr,
'num_samples': total_samples
}
return metrics_dict, sample_results
def main(args):
print("评估参数:", args)
device = torch.device(args.device)
# 设置随机种子
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
cudnn.benchmark = True
# Build the dataset
print("Building dataset...")
dataset_val = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=False,
max_interval=args.max_interval
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
shuffle=False,
drop_last=False
)
# Create the model
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"未知模型: {args.model}")
model.to(device)
# 加载检查点
if args.resume:
print(f"加载检查点: {args.resume}")
try:
# 尝试使用weights_only=False加载PyTorch 2.6+需要)
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
except (pickle.UnpicklingError, TypeError) as e:
print(f"使用weights_only=False加载失败: {e}")
print("尝试使用torch.serialization.add_safe_globals...")
# Handle the state dict (it may contain a 'module.' prefix)
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# Remove the 'module.' prefix (if present)
if hasattr(model, 'module'):
model.module.load_state_dict(state_dict)
else:
# If the state dict has a 'module.' prefix but the model does not, strip the prefix
if any(key.startswith('module.') for key in state_dict.keys()):
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
state_dict = new_state_dict
model.load_state_dict(state_dict)
print(f"检查点加载成功epoch: {checkpoint.get('epoch', 'unknown')}")
else:
print("警告: 未提供检查点路径,使用随机初始化的模型")
# 创建输出目录
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# Evaluate the model
print("Starting evaluation...")
metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
# Print metrics
print("\n" + "="*50)
print("Evaluation results:")
print(f"MSE: {metrics['mse']:.6f}")
print(f"SSIM: {metrics['ssim']:.6f}")
print(f"PSNR: {metrics['psnr']:.6f} dB")
print(f"样本数量: {metrics['num_samples']}")
print("="*50)
# 保存指标到JSON文件
metrics_file = output_dir / 'evaluation_metrics.json'
with open(metrics_file, 'w') as f:
json.dump(metrics, f, indent=4)
print(f"指标已保存到: {metrics_file}")
# 保存示例可视化
if sample_results:
print(f"\n保存 {len(sample_results)} 个示例可视化...")
samples_dir = output_dir / 'sample_predictions'
samples_dir.mkdir(exist_ok=True)
for i, sample in enumerate(sample_results):
save_path = samples_dir / f'sample_{i:03d}.png'
# Generate input frame indices (assumed to be consecutive)
input_frame_indices = list(range(1, args.num_frames + 1))
target_frame_index = args.num_frames + 1
save_comparison_figure(
sample['input_frames'],
sample['target_frame'],
sample['pred_frame'],
save_path,
input_frame_indices=input_frame_indices,
target_frame_index=target_frame_index
)
# Save this sample's metrics
sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
with open(sample_metrics_file, 'w') as f:
f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
print(f"示例可视化已保存到: {samples_dir}")
# 生成汇总报告
report_file = output_dir / 'evaluation_report.txt'
with open(report_file, 'w') as f:
f.write("SwiftFormerTemporal 帧预测评估报告\n")
f.write("="*50 + "\n")
f.write(f"模型: {args.model}\n")
f.write(f"检查点: {args.resume}\n")
f.write(f"数据集: {args.data_path}\n")
f.write(f"输入帧数: {args.num_frames}\n")
f.write(f"帧大小: {args.frame_size}\n")
f.write(f"批次大小: {args.batch_size}\n")
f.write(f"样本总数: {metrics['num_samples']}\n\n")
f.write("评估指标:\n")
f.write(f" MSE: {metrics['mse']:.6f}\n")
f.write(f" SSIM: {metrics['ssim']:.6f}\n")
f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")
print(f"评估报告已保存到: {report_file}")
print("\n评估完成!")
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal evaluation script', add_help=False)
# Dataset parameters
parser.add_argument('--data-path', default='./videos', type=str,
help='Path to the video dataset')
parser.add_argument('--num-frames', default=3, type=int,
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=4, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of the model to evaluate')
# Evaluation parameters
parser.add_argument('--batch-size', default=16, type=int,
help='Evaluation batch size')
parser.add_argument('--num-samples-to-save', default=10, type=int,
help='Number of samples to save as visualizations')
parser.add_argument('--max-samples', default=None, type=int,
help='Maximum number of samples to evaluate (None means all)')
# System parameters
parser.add_argument('--output-dir', default='./evaluation_results',
help='Path to save results')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
help='Device to use')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='Checkpoint path')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
return parser
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal evaluation', parents=[get_args_parser()])
args = parser.parse_args()
# Make sure the output directory exists
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
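The evaluation loop above keeps a fixed-size random subset of samples for visualization using reservoir sampling, which draws a uniform sample from a stream without knowing its length in advance. A minimal standalone sketch of the same idea (hypothetical stream and pool size, not part of the repository):

import random

def reservoir_sample(stream, k):
    """Keep a uniform random sample of k items from a stream of unknown length."""
    reservoir = []
    for count, item in enumerate(stream, start=1):
        if count <= k:
            reservoir.append(item)      # fill the reservoir first
        else:
            r = random.randint(0, count - 1)
            if r < k:
                reservoir[r] = item     # replace with probability k / count
    return reservoir

# e.g. keep 10 visualization slots out of 1000 evaluated samples
print(reservoir_sample(range(1000), 10))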

View File

@@ -19,19 +19,15 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import * from util import *
from models import * from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset from util.video_dataset import VideoFrameDataset
from util.frame_losses import MultiTaskLoss # from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard # Try to import TensorBoard
try: try:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True TENSORBOARD_AVAILABLE = True
except ImportError: except ImportError:
try: TENSORBOARD_AVAILABLE = False
from tensorboardX import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
def get_args_parser(): def get_args_parser():
@@ -47,16 +43,12 @@ def get_args_parser():
help='Number of input frames (T)') help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int, parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size') help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int, parser.add_argument('--max-interval', default=10, type=int,
help='Maximum interval between consecutive frames') help='Maximum interval between consecutive frames')
# Model parameters # Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of model to train') help='Name of model to train')
parser.add_argument('--use-representation-head', action='store_true',
help='Use representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of representation vector')
# Training parameters # Training parameters
parser.add_argument('--batch-size', default=32, type=int) parser.add_argument('--batch-size', default=32, type=int)
@@ -77,7 +69,7 @@ def get_args_parser():
help='SGD momentum (default: 0.9)') help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05, parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)') help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 1e-3)') help='learning rate (default: 1e-3)')
# Learning rate schedule parameters (required by timm's create_scheduler) # Learning rate schedule parameters (required by timm's create_scheduler)
@@ -89,7 +81,7 @@ def get_args_parser():
help='learning rate noise limit percent (default: 0.67)') help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)') help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
help='warmup learning rate (default: 1e-6)') help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
@@ -109,10 +101,10 @@ def get_args_parser():
help='Weight for frame prediction loss') help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1, parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss') help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0, # parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss') # help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1, # parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss') # help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true', parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss') help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true', parser.add_argument('--no-ssim', action='store_true',
@@ -129,7 +121,7 @@ def get_args_parser():
help='start epoch') help='start epoch')
parser.add_argument('--eval', action='store_true', parser.add_argument('--eval', action='store_true',
help='Perform evaluation only') help='Perform evaluation only')
parser.add_argument('--num-workers', default=4, type=int) parser.add_argument('--num-workers', default=16, type=int)
parser.add_argument('--pin-mem', action='store_true', parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader') help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem') parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
@@ -210,8 +202,6 @@ def main(args):
print(f"Creating model: {args.model}") print(f"Creating model: {args.model}")
model_kwargs = { model_kwargs = {
'num_frames': args.num_frames, 'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
} }
if args.model == 'SwiftFormerTemporal_XS': if args.model == 'SwiftFormerTemporal_XS':
@@ -260,7 +250,7 @@ def main(args):
super().__init__() super().__init__()
self.mse = nn.MSELoss() self.mse = nn.MSELoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None): def forward(self, pred_frame, target_frame, temporal_indices=None):
loss = self.mse(pred_frame, target_frame) loss = self.mse(pred_frame, target_frame)
loss_dict = {'mse': loss} loss_dict = {'mse': loss}
return loss, loss_dict return loss, loss_dict
@@ -274,7 +264,7 @@ def main(args):
checkpoint = torch.hub.load_state_dict_from_url( checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True) args.resume, map_location='cpu', check_hash=True)
else: else:
checkpoint = torch.load(args.resume, map_location='cpu') checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
model_without_ddp.load_state_dict(checkpoint['model']) model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
@@ -318,7 +308,7 @@ def main(args):
train_stats, global_step = train_one_epoch( train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train, model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler, optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
model_ema=model_ema, writer=writer, model_ema=model_ema, writer=writer,
global_step=global_step, args=args global_step=global_step, args=args
) )
@@ -326,7 +316,7 @@ def main(args):
lr_scheduler.step(epoch) lr_scheduler.step(epoch)
# Save checkpoint # Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1): if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth' checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({ utils.save_on_master({
'model': model_without_ddp.state_dict(), 'model': model_without_ddp.state_dict(),
@@ -366,7 +356,7 @@ def main(args):
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0, clip_mode='norm', model_ema=None, writer=None, clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs): global_step=0, args=None, **kwargs):
model.train() model.train()
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
@@ -374,6 +364,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
header = f'Epoch: [{epoch}]' header = f'Epoch: [{epoch}]'
print_freq = 10 print_freq = 10
# Add diagnostic metrics
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate( for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)): metric_logger.log_every(data_loader, print_freq, header)):
@@ -382,11 +377,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
temporal_indices = temporal_indices.to(device, non_blocking=True) temporal_indices = temporal_indices.to(device, non_blocking=True)
# Forward pass # Forward pass
with torch.cuda.amp.autocast(): with torch.amp.autocast(device_type='cuda'):
pred_frames, representations = model(input_frames) pred_frames = model(input_frames)
loss, loss_dict = criterion( loss, loss_dict = criterion(
pred_frames, target_frames, pred_frames, target_frames,
representations, temporal_indices temporal_indices
) )
loss_value = loss.item() loss_value = loss.item()
@@ -395,6 +390,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
raise ValueError(f"Loss is {loss_value}") raise ValueError(f"Loss is {loss_value}")
optimizer.zero_grad() optimizer.zero_grad()
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode, loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters()) parameters=model.parameters())
@@ -402,6 +398,30 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
if model_ema is not None: if model_ema is not None:
model_ema.update(model) model_ema.update(model)
# Compute diagnostic metrics
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
# Compute the gradient norm
total_grad_norm = 0.0
for param in model.parameters():
if param.grad is not None:
total_grad_norm += param.grad.norm().item()
# Record diagnostic metrics
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(grad_norm=total_grad_norm)
# Print diagnostics every 50 batches
if batch_idx % 50 == 0:
print(f"[Diag] batch {batch_idx}: pred mean={pred_mean:.4f}, pred std={pred_std:.4f}, grad norm={total_grad_norm:.4f}")
# # Check the running statistics of one BatchNorm layer
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f"[Diag] {name}: running mean={module.running_mean[0].item():.6f}, running var={module.running_var[0].item():.6f}")
# break
# Log to TensorBoard # Log to TensorBoard
if writer is not None: if writer is not None:
# Log scalar metrics every iteration # Log scalar metrics every iteration
@@ -415,11 +435,16 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
else: else:
writer.add_scalar(f'train/{k}', v, global_step) writer.add_scalar(f'train/{k}', v, global_step)
# Log diagnostic metrics
writer.add_scalar('train/pred_mean', pred_mean, global_step)
writer.add_scalar('train/pred_std', pred_std, global_step)
writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
# Log images periodically # Log images periodically
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0: if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
with torch.no_grad(): with torch.no_grad():
# Take first sample from batch for visualization # Take first sample from batch for visualization
pred_vis, _ = model(input_frames[:1]) pred_vis = model(input_frames[:1])
# Convert to appropriate format for TensorBoard # Convert to appropriate format for TensorBoard
# Assuming frames are in [B, C, H, W] format # Assuming frames are in [B, C, H, W] format
writer.add_images('train/input', input_frames[:1], global_step) writer.add_images('train/input', input_frames[:1], global_step)
@@ -451,19 +476,53 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
metric_logger = utils.MetricLogger(delimiter=" ") metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:' header = 'Test:'
for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header): # Add diagnostic metrics
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
input_frames = input_frames.to(device, non_blocking=True) input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True) target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True) temporal_indices = temporal_indices.to(device, non_blocking=True)
# Compute output # Compute output
with torch.cuda.amp.autocast(): with torch.amp.autocast(device_type='cuda'):
pred_frames, representations = model(input_frames) pred_frames = model(input_frames)
loss, loss_dict = criterion( loss, loss_dict = criterion(
pred_frames, target_frames, pred_frames, target_frames,
representations, temporal_indices temporal_indices
) )
# Compute diagnostic metrics
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
target_mean = target_frames.mean().item()
target_std = target_frames.std().item()
# Update diagnostic metrics
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(target_mean=target_mean)
metric_logger.update(target_std=target_std)
# # Print detailed diagnostics for the first batch
# if batch_idx == 0:
# print(f"[Eval diag] batch 0:")
# print(f" pred range: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
# print(f" pred mean: {pred_mean:.4f}, pred std: {pred_std:.4f}")
# print(f" target range: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
# print(f" target mean: {target_mean:.4f}, target std: {target_std:.4f}")
# # Check BatchNorm running statistics
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f" {name}: running mean={module.running_mean[0].item():.6f}, running var={module.running_var[0].item():.6f}")
# if module.running_var[0].item() < 1e-6:
# print(f" Warning: BatchNorm running variance is close to zero!")
# break
# Update metrics # Update metrics
metric_logger.update(loss=loss.item()) metric_logger.update(loss=loss.item())
for k, v in loss_dict.items(): for k, v in loss_dict.items():
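The training-loop changes above switch from torch.cuda.amp.autocast to torch.amp.autocast(device_type='cuda') and log prediction statistics plus a gradient norm on every step. A rough standalone sketch of that pattern, assuming a recent PyTorch where torch.amp.GradScaler is available and a CUDA device is present; the toy model, optimizer and clip value are made up, and torch.nn.utils.clip_grad_norm_ stands in for the timm NativeScaler the repository actually uses:

import torch
import torch.nn as nn

# Toy single-layer "model" standing in for SwiftFormerTemporal (requires CUDA)
model = nn.Conv2d(3, 1, kernel_size=3, padding=1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.amp.GradScaler('cuda')

x = torch.randn(4, 3, 224, 224, device='cuda')
target = torch.randn(4, 1, 224, 224, device='cuda')

with torch.amp.autocast(device_type='cuda'):
    pred = model(x)
    loss = nn.functional.mse_loss(pred, target)

optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale before measuring or clipping gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.01)
scaler.step(optimizer)
scaler.update()

print(f"pred mean={pred.mean().item():.4f}, "
      f"pred std={pred.std().item():.4f}, grad norm={grad_norm.item():.4f}")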

View File

@@ -11,92 +11,113 @@ from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module): class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder""" """Upsampling block for frame prediction decoder without residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1): def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__() super().__init__()
self.conv = nn.ConvTranspose2d( # Main path: transposed convolution + two conv layers
self.conv_transpose = nn.ConvTranspose2d(
in_channels, out_channels, in_channels, out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
padding=padding, padding=padding,
output_padding=output_padding, output_padding=output_padding,
bias=False bias=False # Disable bias because BN follows
) )
self.bn = nn.BatchNorm2d(out_channels) self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True) self.conv1 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
# Use the ReLU activation function
self.activation = nn.ReLU(inplace=True)
# Initialize weights
self._init_weights()
def _init_weights(self):
# Initialize the transposed convolution layer
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
# Initialize the convolution layers
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
# Initialize BN layers (default initialization)
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x): def forward(self, x):
return self.relu(self.bn(self.conv(x))) # Main path
x = self.conv_transpose(x)
x = self.bn1(x)
x = self.activation(x)
x = self.conv1(x)
x = self.bn2(x)
x = self.activation(x)
x = self.conv2(x)
x = self.bn3(x)
x = self.activation(x)
return x
class FramePredictionDecoder(nn.Module): class FramePredictionDecoder(nn.Module):
"""Lightweight decoder for frame prediction with optional skip connections""" """Improved decoder for frame prediction"""
def __init__(self, embed_dims, output_channels=1, use_skip=False): def __init__(self, embed_dims, output_channels=1):
super().__init__() super().__init__()
self.use_skip = use_skip # Define decoder dimensions independently (no skip connections)
# Reverse the embed_dims for decoder start_dim = embed_dims[-1]
decoder_dims = embed_dims[::-1] decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
# First upsampling from bottleneck to stage4 resolution
# First block: stride=2 (decoder_dims[0] -> decoder_dims[1])
self.blocks.append(DecoderBlock( self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1], decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1 kernel_size=3, stride=2, padding=1, output_padding=1
)) ))
# stage4 to stage3 # Second block: stride=2 (decoder_dims[1] -> decoder_dims[2])
self.blocks.append(DecoderBlock( self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2], decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1 kernel_size=3, stride=2, padding=1, output_padding=1
)) ))
# stage3 to stage2 # Third block: stride=2 (decoder_dims[2] -> decoder_dims[3])
self.blocks.append(DecoderBlock( self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3], decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1 kernel_size=3, stride=2, padding=1, output_padding=1
)) ))
# stage2 to original resolution (now 8x upsampling total with stride 4) # Fourth block: stride=4 (decoder_dims[3] -> 64), placed second to last
self.blocks.append(nn.Sequential( self.blocks.append(DecoderBlock(
nn.ConvTranspose2d( decoder_dims[3], 64,
decoder_dims[3], 32, kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4 goes here
kernel_size=3, stride=4, padding=1, output_padding=3
),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
nn.Tanh() # Output in [-1, 1] range
)) ))
# If using skip connections, we need to adjust input channels for each block self.final_block = nn.Sequential(
if use_skip: nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
# We'll modify the first three blocks to accept concatenated features nn.ReLU(inplace=True),
# Instead of modifying existing blocks, we'll replace them with custom blocks nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
# For simplicity, we'll keep the same architecture but forward will handle concatenation nn.ReLU(inplace=True),
pass nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
nn.Tanh()
)
def forward(self, x, skip_features=None): def forward(self, x):
""" """
Args: Args:
x: input tensor of shape [B, embed_dims[-1], H/32, W/32] x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
skip_features: list of encoder features from stages [stage2, stage1, stage0]
each of shape [B, C, H', W'] where C matches decoder dims?
""" """
if self.use_skip and skip_features is not None: # Skip connections are not used
# Ensure we have exactly 3 skip features (for the first three blocks) for i in range(4):
assert len(skip_features) == 3, "Need 3 skip features for skip connections" x = self.blocks[i](x)
# Reverse skip_features to match decoder order: stage2, stage1, stage0
# skip_features[0] should be stage2 (H/16), [1] stage1 (H/8), [2] stage0 (H/4)
skip_features = skip_features[::-1] # Now index 0: stage2, 1: stage1, 2: stage0
for i, block in enumerate(self.blocks): # Final output layer: feature refinement only, no upsampling
if self.use_skip and skip_features is not None and i < 3: x = self.final_block(x)
# Concatenate skip feature along channel dimension
# Ensure spatial dimensions match (they should because of upsampling)
x = torch.cat([x, skip_features[i]], dim=1)
# Need to adjust block to accept extra channels? We'll create a separate block.
# For now, we'll just pass through, but this will cause channel mismatch.
# Instead, we should have created custom blocks with appropriate in_channels.
# This is a placeholder; we need to implement properly.
pass
x = block(x)
return x return x
@@ -110,9 +131,6 @@ class SwiftFormerTemporal(nn.Module):
model_name='XS', model_name='XS',
num_frames=3, num_frames=3,
use_decoder=True, use_decoder=True,
use_representation_head=False,
representation_dim=128,
return_features=False,
**kwargs): **kwargs):
super().__init__() super().__init__()
@@ -123,8 +141,6 @@ class SwiftFormerTemporal(nn.Module):
# Store configuration # Store configuration
self.num_frames = num_frames self.num_frames = num_frames
self.use_decoder = use_decoder self.use_decoder = use_decoder
self.use_representation_head = use_representation_head
self.return_features = return_features
# Modify stem to accept multiple frames (only Y channel) # Modify stem to accept multiple frames (only Y channel)
in_channels = num_frames in_channels = num_frames
@@ -155,79 +171,51 @@ class SwiftFormerTemporal(nn.Module):
# Frame prediction decoder # Frame prediction decoder
if use_decoder: if use_decoder:
self.decoder = FramePredictionDecoder(embed_dims, output_channels=1) self.decoder = FramePredictionDecoder(
embed_dims,
# Representation head for pose/velocity prediction output_channels=1
if use_representation_head:
self.representation_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(embed_dims[-1], representation_dim),
nn.ReLU(),
nn.Linear(representation_dim, representation_dim)
) )
else:
self.representation_head = None
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)): if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02) # Use Kaiming initialization (suited for ReLU)
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)): elif isinstance(m, nn.ConvTranspose2d):
# Transposed convolution layers use a dedicated initialization
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0) nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x): def forward_tokens(self, x):
"""Forward through encoder network, return list of stage features if return_features else final output""" for block in self.network:
if self.return_features: x = block(x)
features = [] return x
for idx, block in enumerate(self.network):
x = block(x)
# Collect output after each stage (indices 0,2,4,6 correspond to stages)
if idx in [0, 2, 4, 6]:
features.append(x)
return x, features
else:
for block in self.network:
x = block(x)
return x
def forward(self, x): def forward(self, x):
""" """
Args: Args:
x: input frames of shape [B, num_frames, H, W] x: input frames of shape [B, num_frames, H, W]
Returns: Returns:
If return_features is False: pred_frame: predicted frame [B, 1, H, W] (or None)
pred_frame: predicted frame [B, 1, H, W] (or None)
representation: optional representation vector [B, representation_dim] (or None)
If return_features is True:
pred_frame, representation, features (list of stage features)
""" """
# Encode # Encode
x = self.patch_embed(x) x = self.patch_embed(x)
if self.return_features: x = self.forward_tokens(x)
x, features = self.forward_tokens(x)
else:
x = self.forward_tokens(x)
x = self.norm(x) x = self.norm(x)
# Get representation if needed
representation = None
if self.representation_head is not None:
representation = self.representation_head(x)
# Decode to frame # Decode to frame
pred_frame = None pred_frame = None
if self.use_decoder: if self.use_decoder:
pred_frame = self.decoder(x) pred_frame = self.decoder(x)
if self.return_features: return pred_frame
return pred_frame, representation, features
else:
return pred_frame, representation
# Factory functions for different model sizes # Factory functions for different model sizes
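The rewritten decoder undoes the encoder's 32x downsampling with three stride-2 blocks followed by one stride-4 block (2 * 2 * 2 * 4 = 32). A quick standalone check of the ConvTranspose2d output sizes used above, assuming a 224x224 input (7x7 bottleneck) and embed_dims[-1] = 220 for the XS variant, as noted in the diff comment:

import torch
import torch.nn as nn

# Output size formula: out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
x = torch.randn(1, 220, 7, 7)   # hypothetical bottleneck for a 224x224 input, embed_dims[-1] = 220
up2 = nn.ConvTranspose2d(220, 110, kernel_size=3, stride=2, padding=1, output_padding=1)
print(up2(x).shape)             # torch.Size([1, 110, 14, 14]) -> doubles the spatial size

y = torch.randn(1, 27, 56, 56)  # after three stride-2 blocks: 7 -> 14 -> 28 -> 56
up4 = nn.ConvTranspose2d(27, 64, kernel_size=3, stride=4, padding=1, output_padding=3)
print(up4(y).shape)             # torch.Size([1, 64, 224, 224]) -> quadruples back to full resolution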

View File

@@ -1,26 +0,0 @@
#!/usr/bin/env bash
# Simple multi-GPU training script for SwiftFormerTemporal
# Usage: ./multi_gpu_temporal_train.sh <NUM_GPUS> [OPTIONS]
NUM_GPUS=${1:-2}
shift
echo "Starting multi-GPU training with $NUM_GPUS GPUs"
# Set environment variables for distributed training
export MASTER_PORT=12345
export MASTER_ADDR=localhost
export WORLD_SIZE=$NUM_GPUS
# Launch training
torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \
--data-path "./videos" \
--model SwiftFormerTemporal_XS \
--batch-size 32 \
--epochs 100 \
--lr 1e-3 \
--output-dir "./temporal_output_multi" \
--num-workers 8 \
--pin-mem \
"$@"

View File

View File

@@ -1,45 +0,0 @@
import torch
def test_cuda_availability():
"""全面测试CUDA可用性"""
print("="*50)
print("PyTorch CUDA 测试")
print("="*50)
# Basic information
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
print("CUDA不可用可能原因")
print("1. 未安装CUDA驱动")
print("2. 安装的是CPU版本的PyTorch")
print("3. CUDA版本与PyTorch不匹配")
return False
# Device information
device_count = torch.cuda.device_count()
print(f"Found {device_count} CUDA device(s)")
for i in range(device_count):
print(f"\n设备 {i}:")
print(f" 名称: {torch.cuda.get_device_name(i)}")
print(f" 内存总量: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
print(f" 计算能力: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}")
# Simple tensor test
print("\nRunning CUDA test...")
try:
x = torch.randn(3, 3).cuda()
y = torch.randn(3, 3).cuda()
z = x + y
print("CUDA计算测试: 成功!")
print(f"设备上的张量形状: {z.shape}")
return True
except Exception as e:
print(f"CUDA计算测试失败: {e}")
return False
if __name__ == "__main__":
test_cuda_availability()

View File

@@ -1,33 +0,0 @@
#!/usr/bin/env python3
"""
Test whether the timm imports work correctly
"""
import sys
print("Python version:", sys.version)
try:
from timm.layers import to_2tuple, DropPath, trunc_normal_
from timm.models import register_model
print("✓ 成功导入 timm.layers.to_2tuple")
print("✓ 成功导入 timm.layers.DropPath")
print("✓ 成功导入 timm.layers.trunc_normal_")
print("✓ 成功导入 timm.models.register_model")
except ImportError as e:
print(f"✗ 导入失败: {e}")
sys.exit(1)
try:
from models.swiftformer import SwiftFormer_XS
print("✓ 成功导入 SwiftFormer_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormer_XS 失败: {e}")
sys.exit(1)
try:
from models.swiftformer_temporal import SwiftFormerTemporal_XS
print("✓ 成功导入 SwiftFormerTemporal_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormerTemporal_XS 失败: {e}")
sys.exit(1)
print("\n✅ 所有导入测试通过!")

View File

@@ -1,60 +0,0 @@
#!/usr/bin/env python3
"""
Test script for SwiftFormerTemporal model
"""
import torch
import sys
import os
# Add current directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.swiftformer_temporal import SwiftFormerTemporal_XS
def test_model():
print("Testing SwiftFormerTemporal model...")
# Create model
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
print(f'Model created: {model.__class__.__name__}')
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
# Test forward pass
batch_size = 2
num_frames = 3
height = width = 224
x = torch.randn(batch_size, 3 * num_frames, height, width)
print(f'\nInput shape: {x.shape}')
with torch.no_grad():
pred_frame, representation = model(x)
print(f'Predicted frame shape: {pred_frame.shape}')
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
# Check output ranges
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
# Test loss function
from util.frame_losses import MultiTaskLoss
criterion = MultiTaskLoss()
target = torch.randn_like(pred_frame)
temporal_indices = torch.tensor([3, 3], dtype=torch.long)
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
print(f'\nLoss test:')
for k, v in loss_dict.items():
print(f' {k}: {v:.4f}')
print('\nAll tests passed!')
return True
if __name__ == '__main__':
try:
test_model()
except Exception as e:
print(f'Test failed with error: {e}')
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -1,182 +0,0 @@
"""
Loss functions for frame prediction and representation learning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SSIMLoss(nn.Module):
"""
Structural Similarity Index Measure Loss
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
"""
def __init__(self, window_size=11, size_average=True):
super().__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 3
self.window = self.create_window(window_size, self.channel)
def create_window(self, window_size, channel):
def gaussian(window_size, sigma):
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def forward(self, img1, img2):
# Ensure window is on correct device
if self.window.device != img1.device:
self.window = self.window.to(img1.device)
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if self.size_average:
return 1 - ssim_map.mean()
else:
return 1 - ssim_map.mean(1).mean(1).mean(1)
class FramePredictionLoss(nn.Module):
"""
Combined loss for frame prediction
"""
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
super().__init__()
self.l1_weight = l1_weight
self.ssim_weight = ssim_weight
self.use_ssim = use_ssim
self.l1_loss = nn.L1Loss()
if use_ssim:
self.ssim_loss = SSIMLoss()
def forward(self, pred, target):
"""
Args:
pred: predicted frame [B, 3, H, W] in range [-1, 1]
target: target frame [B, 3, H, W] in range [-1, 1]
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# L1 loss
l1_loss = self.l1_loss(pred, target)
loss_dict['l1'] = l1_loss
total_loss = self.l1_weight * l1_loss
# SSIM loss
if self.use_ssim:
ssim_loss = self.ssim_loss(pred, target)
loss_dict['ssim'] = ssim_loss
total_loss += self.ssim_weight * ssim_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
class ContrastiveLoss(nn.Module):
"""
Contrastive loss for representation learning
Positive pairs: representations from adjacent frames
Negative pairs: representations from distant frames
"""
def __init__(self, temperature=0.1, margin=1.0):
super().__init__()
self.temperature = temperature
self.margin = margin
self.cosine_similarity = nn.CosineSimilarity(dim=-1)
def forward(self, representations, temporal_indices):
"""
Args:
representations: [B, D] representation vectors
temporal_indices: [B] temporal indices of each sample
Returns:
contrastive_loss
"""
batch_size = representations.size(0)
# Compute similarity matrix
sim_matrix = torch.matmul(representations, representations.T) / self.temperature
# Create positive mask (adjacent frames)
indices_expanded = temporal_indices.unsqueeze(0)
diff = torch.abs(indices_expanded - indices_expanded.T)
positive_mask = (diff == 1).float()
# Create negative mask (distant frames)
negative_mask = (diff > 2).float()
# Positive loss
pos_sim = sim_matrix * positive_mask
pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
# Negative loss (push apart)
neg_sim = sim_matrix * negative_mask
neg_loss = torch.relu(neg_sim - self.margin).mean()
return pos_loss + 0.1 * neg_loss
class MultiTaskLoss(nn.Module):
"""
Multi-task loss combining frame prediction and representation learning
"""
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
super().__init__()
self.frame_weight = frame_weight
self.contrastive_weight = contrastive_weight
self.use_contrastive = use_contrastive
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
if use_contrastive:
self.contrastive_loss = ContrastiveLoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
"""
Args:
pred_frame: predicted frame [B, 3, H, W]
target_frame: target frame [B, 3, H, W]
representations: [B, D] representation vectors (optional)
temporal_indices: [B] temporal indices (optional)
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# Frame prediction loss
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
total_loss = self.frame_weight * frame_loss
# Contrastive loss (if representations provided)
if self.use_contrastive and representations is not None and temporal_indices is not None:
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
loss_dict['contrastive'] = contrastive_loss
total_loss += self.contrastive_weight * contrastive_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
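For reference, the removed ContrastiveLoss builds its positive and negative masks purely from pairwise differences of temporal indices: a difference of exactly 1 marks a positive pair, a difference greater than 2 a negative pair. A tiny illustration of that mask construction with made-up indices:

import torch

temporal_indices = torch.tensor([1, 2, 4, 7])  # hypothetical per-sample temporal indices
diff = torch.abs(temporal_indices.unsqueeze(0) - temporal_indices.unsqueeze(0).T)

positive_mask = (diff == 1).float()  # adjacent frames
negative_mask = (diff > 2).float()   # distant frames

print(diff)
# tensor([[0, 1, 3, 6],
#         [1, 0, 2, 5],
#         [3, 2, 0, 3],
#         [6, 5, 3, 0]])
print(positive_mask)  # only the (0, 1) / (1, 0) pair counts as positive here
print(negative_mask)  # pairs whose index difference exceeds 2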

View File

@@ -48,27 +48,39 @@ class VideoFrameDataset(Dataset):
self.is_train = is_train self.is_train = is_train
self.max_interval = max_interval self.max_interval = max_interval
# Collect all video folders # if num_frames < 1:
# raise ValueError("num_frames must be >= 1")
# if frame_size < 1:
# raise ValueError("frame_size must be >= 1")
# if max_interval < 1:
# raise ValueError("max_interval must be >= 1")
# Collect all video folders and their frame files
self.video_folders = [] self.video_folders = []
self.video_frame_files = [] # list of list of Path objects
for item in self.root_dir.iterdir(): for item in self.root_dir.iterdir():
if item.is_dir(): if item.is_dir():
self.video_folders.append(item) self.video_folders.append(item)
# Get all frame files
frame_files = sorted([f for f in item.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
self.video_frame_files.append(frame_files)
if len(self.video_folders) == 0: if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}") raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx) # Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = [] self.frame_indices = []
for video_idx, video_folder in enumerate(self.video_folders): for video_idx, frame_files in enumerate(self.video_frame_files):
# Get all frame files # Minimum frames needed considering max interval
frame_files = sorted([f for f in video_folder.iterdir() min_frames_needed = num_frames * max_interval + 1
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) if len(frame_files) < min_frames_needed:
if len(frame_files) < num_frames + 1:
continue # Skip videos with insufficient frames continue # Skip videos with insufficient frames
# Add all possible starting positions # Add all possible starting positions
for start_idx in range(len(frame_files) - num_frames): # Ensure that for any interval up to max_interval, all frames are within bounds
max_start = len(frame_files) - num_frames * max_interval
for start_idx in range(max_start):
self.frame_indices.append((video_idx, start_idx)) self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0: if len(self.frame_indices) == 0:
@@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset):
else: else:
self.transform = transform self.transform = transform
# Normalization for Y channel (single channel) # Simple normalization to [-1, 1] range (no ImageNet normalization)
# Compute average of ImageNet RGB means and stds # Convert pixel values [0, 255] to [-1, 1]
y_mean = (0.485 + 0.456 + 0.406) / 3.0 # This matches the model's tanh output range
y_std = (0.229 + 0.224 + 0.225) / 3.0 self.normalize = None # We'll handle normalization manually
self.normalize = transforms.Normalize(
mean=[y_mean], # print(f"[Dataset init] Using simple normalization: pixel values [0,255] -> [-1,1]")
std=[y_std]
)
def _default_transform(self): def _default_transform(self):
"""Default transform with augmentation for training""" """Default transform with augmentation for training"""
@@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset):
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image: def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
"""Load a single frame as PIL Image""" """Load a single frame as PIL Image"""
video_folder = self.video_folders[video_idx] frame_files = self.video_frame_files[video_idx]
frame_files = sorted([f for f in video_folder.iterdir() if frame_idx < 0 or frame_idx >= len(frame_files):
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) raise IndexError(
f"Frame index {frame_idx} out of range for video {video_idx} "
f"(0-{len(frame_files)-1})"
)
frame_path = frame_files[frame_idx] frame_path = frame_files[frame_idx]
return Image.open(frame_path).convert('RGB') return Image.open(frame_path).convert('RGB')
@@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset):
if self.transform: if self.transform:
target_frame = self.transform(target_frame) target_frame = self.transform(target_frame)
# Convert to tensors, normalize, and convert to grayscale (Y channel) # Convert to tensors and convert to grayscale (Y channel)
input_tensors = [] input_tensors = []
for frame in input_frames: for frame in input_frames:
tensor = transforms.ToTensor()(frame) # [3, H, W] tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
# Convert RGB to grayscale using weighted sum # Convert RGB to grayscale using weighted sum
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL) # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W] gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast) # Normalize from [0, 1] to [-1, 1]
gray = gray * 2 - 1 # [0,1] -> [-1,1]
input_tensors.append(gray) input_tensors.append(gray)
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W] target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0) target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
target_gray = self.normalize(target_gray) # Normalize from [0, 1] to [-1, 1]
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
# Concatenate input frames along channel dimension # Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W] input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
@@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset):
return input_concatenated, target_gray, temporal_idx return input_concatenated, target_gray, temporal_idx
class SyntheticVideoDataset(Dataset): # class SyntheticVideoDataset(Dataset):
""" # """
Synthetic dataset for testing - generates random frames # Synthetic dataset for testing - generates random frames
""" # """
def __init__(self, # def __init__(self,
num_samples: int = 1000, # num_samples: int = 1000,
num_frames: int = 3, # num_frames: int = 3,
frame_size: int = 224, # frame_size: int = 224,
is_train: bool = True): # is_train: bool = True):
self.num_samples = num_samples # self.num_samples = num_samples
self.num_frames = num_frames # self.num_frames = num_frames
self.frame_size = frame_size # self.frame_size = frame_size
self.is_train = is_train # self.is_train = is_train
# Normalization for Y channel (single channel) # # Normalization for Y channel (single channel)
y_mean = (0.485 + 0.456 + 0.406) / 3.0 # y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0 # y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize( # self.normalize = transforms.Normalize(
mean=[y_mean], # mean=[y_mean],
std=[y_std] # std=[y_std]
) # )
def __len__(self): # def __len__(self):
return self.num_samples # return self.num_samples
def __getitem__(self, idx): # def __getitem__(self, idx):
# Generate random "frames" (noise with temporal correlation) # # Generate random "frames" (noise with temporal correlation)
input_frames = [] # input_frames = []
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1 # prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
for i in range(self.num_frames): # for i in range(self.num_frames):
# Add some temporal correlation # # Add some temporal correlation
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 # frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
frame = torch.clamp(frame, -1, 1) # frame = torch.clamp(frame, -1, 1)
input_frames.append(self.normalize(frame)) # input_frames.append(self.normalize(frame))
prev_frame = frame # prev_frame = frame
# Target frame (next in sequence) # # Target frame (next in sequence)
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 # target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
target_frame = torch.clamp(target_frame, -1, 1) # target_frame = torch.clamp(target_frame, -1, 1)
target_tensor = self.normalize(target_frame) # target_tensor = self.normalize(target_frame)
# Concatenate inputs # # Concatenate inputs
input_concatenated = torch.cat(input_frames, dim=0) # input_concatenated = torch.cat(input_frames, dim=0)
# Temporal index # # Temporal index
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) # temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx # return input_concatenated, target_tensor, temporal_idx
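The new indexing logic requires at least num_frames * max_interval + 1 frames per video and caps the start index at len(frame_files) - num_frames * max_interval, so every sampled window stays in bounds for any interval up to max_interval (assuming the target is the frame that follows the num_frames inputs at the chosen spacing). A quick sanity check of that arithmetic with hypothetical numbers:

num_frames, max_interval = 3, 4
n_files = 20                     # hypothetical number of frames in one video folder

min_frames_needed = num_frames * max_interval + 1   # 13: enough for one window at the largest spacing
max_start = n_files - num_frames * max_interval     # 8: valid start indices are 0 .. 7

for start in range(max_start):
    # worst case: interval == max_interval; indices of the num_frames inputs plus the target
    idxs = [start + i * max_interval for i in range(num_frames + 1)]
    assert max(idxs) < n_files, (start, idxs)

print(f"min_frames_needed={min_frames_needed}, max_start={max_start}, last window={idxs}")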