Compare commits

..

9 Commits

10 changed files with 1271 additions and 475 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.vscode/
__pycache__/
venv/
runs/

58
dist_temporal_train.sh Executable file
View File

@@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Distributed training script for SwiftFormerTemporal
# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]
DATA_PATH=$1
NUM_GPUS=$2
# Shift arguments to pass remaining options to python script
shift 2
# Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-128}
EPOCHS=${EPOCHS:-100}
# LR=${LR:-1e-3}
LR=${LR:-0.01}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
echo "Starting distributed training with $NUM_GPUS GPUs"
echo "Data path: $DATA_PATH"
echo "Model: $MODEL"
echo "Batch size: $BATCH_SIZE"
echo "Epochs: $EPOCHS"
echo "Output dir: $OUTPUT_DIR"
# Check if torch.distributed.launch or torchrun should be used
# For newer PyTorch versions (>=1.9), torchrun is recommended
PYTHON_VERSION=$(python -c "import torch; print(torch.__version__)")
echo "PyTorch version: $PYTHON_VERSION"
# Use torchrun for newer PyTorch versions
if [[ "$PYTHON_VERSION" =~ ^2\. ]] || [[ "$PYTHON_VERSION" =~ ^1\.1[0-9]\. ]]; then
echo "Using torchrun (PyTorch >=1.10)"
torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
else
echo "Using torch.distributed.launch"
python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
--data-path "$DATA_PATH" \
--model "$MODEL" \
--batch-size $BATCH_SIZE \
--epochs $EPOCHS \
--lr $LR \
--output-dir "$OUTPUT_DIR" \
"$@"
fi
# For single-node multi-GPU training with specific options:
# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'
echo "Training completed. Check logs in $OUTPUT_DIR"

484
evaluate_temporal.py Normal file
View File

@@ -0,0 +1,484 @@
"""
评估脚本 for SwiftFormerTemporal frame prediction
输出预测图注意反归一化以及对应指标mse&ssim&psnr
"""
import argparse
import os
import torch
import torch.nn as nn
import pickle
import numpy as np
import random
from pathlib import Path
import json
import matplotlib.pyplot as plt
from PIL import Image
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from util.video_dataset import VideoFrameDataset
from models.swiftformer_temporal import (
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
)
# 导入SSIM和PSNR计算
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import warnings
warnings.filterwarnings('ignore')
def denormalize(tensor):
"""
将[-1, 1]范围的张量反归一化到[0, 255]范围
Args:
tensor: 形状为[B, C, H, W]或[C, H, W],值在[-1, 1]
Returns:
反归一化后的张量,值在[0, 255]
"""
# clip 到 [-1, 1] 范围
tensor = tensor.clamp(-1, 1)
# [-1, 1] -> [0, 1]
tensor = (tensor + 1) / 2
# [0, 1] -> [0, 255]
tensor = tensor * 255
return tensor.clamp(0, 255)
def minmax_denormalize(tensor):
tensor_min = tensor.min()
tensor_max = tensor.max()
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
# tensor = tensor*2-1
tensor = tensor*255
return tensor.clamp(0, 255)
def calculate_metrics(pred, target, debug=False):
"""
计算MSE, SSIM, PSNR指标
Args:
pred: 预测图像,形状[H, W],值在[0, 255]
target: 目标图像,形状[H, W],值在[0, 255]
debug: 是否输出调试信息
Returns:
mse, ssim_value, psnr_value
"""
# 转换为numpy数组
pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
target_np = target.cpu().numpy() if torch.is_tensor(target) else target
# 确保是2D数组
if pred_np.ndim == 3:
pred_np = pred_np.squeeze(0)
if target_np.ndim == 3:
target_np = target_np.squeeze(0)
# if debug:
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
mse = np.mean((pred_np - target_np) ** 2)
data_range = 255.0
ssim_value = ssim(pred_np, target_np, data_range=data_range)
psnr_value = psnr(target_np, pred_np, data_range=data_range)
return mse, ssim_value, psnr_value
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
input_frame_indices=None, target_frame_index=None):
"""
保存对比图:输入帧、目标帧、预测帧
Args:
input_frames: 输入帧列表,每个形状为[H, W],值在[0, 255]
target_frame: 目标帧,形状[H, W],值在[0, 255]
pred_frame: 预测帧,形状[H, W],值在[0, 255]
save_path: 保存路径
input_frame_indices: 输入帧的索引列表(可选)
target_frame_index: 目标帧索引(可选)
"""
num_input = len(input_frames)
fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
# 绘制输入帧
for i in range(num_input):
ax = axes[i]
ax.imshow(input_frames[i], cmap='gray')
if input_frame_indices is not None:
ax.set_title(f'Input Frame {input_frame_indices[i]}')
else:
ax.set_title(f'Input {i+1}')
ax.axis('off')
# 绘制目标帧
ax = axes[num_input]
ax.imshow(target_frame, cmap='gray')
if target_frame_index is not None:
ax.set_title(f'Target Frame {target_frame_index}')
else:
ax.set_title('Target')
ax.axis('off')
# 绘制预测帧
ax = axes[num_input + 1]
ax.imshow(pred_frame, cmap='gray')
ax.set_title('Predicted')
ax.axis('off')
#debug print
print(target_frame)
print(pred_frame)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
plt.close()
def evaluate_model(model, data_loader, device, args):
"""
评估模型并计算指标
Args:
model: 训练好的模型
data_loader: 数据加载器
device: 设备
args: 命令行参数
Returns:
metrics_dict: 包含所有指标的字典
sample_results: 示例结果用于可视化
"""
model.eval()
# model.train() # 临时使用训练模式
# 初始化指标累加器
total_mse = 0.0
total_ssim = 0.0
total_psnr = 0.0
total_samples = 0
# 存储示例结果用于可视化(使用蓄水池抽样随机选择)
sample_results = []
max_samples_to_save = args.num_samples_to_save
max_samples = args.max_samples
# 用于蓄水池抽样的计数器已处理的样本数不包括因max_samples限制而跳过的样本
sample_count = 0
with torch.no_grad():
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
# 前向传播
pred_frames = model(input_frames)
# 反归一化用于指标计算
# pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
pred_denorm = denormalize(pred_frames)
target_denorm = denormalize(target_frames) # [B, 1, H, W]
batch_size = input_frames.size(0)
# 计算每个样本的指标
for i in range(batch_size):
# 检查是否达到最大样本数限制
if max_samples is not None and total_samples >= max_samples:
break
pred_i = pred_denorm[i] # [1, H, W]
target_i = target_denorm[i] # [1, H, W]
# 对第一个样本启用调试
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
# if debug_mode:
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
total_mse += mse
total_ssim += ssim_value
total_psnr += psnr_value
total_samples += 1
sample_count += 1
# 构建样本数据字典
input_denorm = denormalize(input_frames[i]) # [num_frames, H, W]
# 分离输入帧
input_frames_list = []
for j in range(args.num_frames):
input_frame_j = input_denorm[j].squeeze(0) # [H, W]
input_frames_list.append(input_frame_j.cpu().numpy())
sample_data = {
'input_frames': input_frames_list,
'target_frame': target_i.squeeze(0).cpu().numpy(),
'pred_frame': pred_i.squeeze(0).cpu().numpy(),
'metrics': {
'mse': mse,
'ssim': ssim_value,
'psnr': psnr_value
},
'batch_idx': batch_idx,
'sample_idx': i
}
# 蓄水池抽样 (Reservoir Sampling)
if sample_count <= max_samples_to_save:
# 蓄水池未满,直接加入
sample_results.append(sample_data)
else:
# 以 max_samples_to_save / sample_count 的概率替换蓄水池中的一个随机位置
r = random.randint(0, sample_count - 1)
if r < max_samples_to_save:
sample_results[r] = sample_data
# 检查是否达到最大样本数限制
if max_samples is not None and total_samples >= max_samples:
print(f"达到最大样本数限制: {max_samples}")
break
# 进度打印
if (batch_idx + 1) % 10 == 0:
print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
# 计算平均指标
if total_samples > 0:
avg_mse = float(total_mse / total_samples)
avg_ssim = float(total_ssim / total_samples)
avg_psnr = float(total_psnr / total_samples)
else:
avg_mse = avg_ssim = avg_psnr = 0.0
metrics_dict = {
'mse': avg_mse,
'ssim': avg_ssim,
'psnr': avg_psnr,
'num_samples': total_samples
}
return metrics_dict, sample_results
def main(args):
print("评估参数:", args)
device = torch.device(args.device)
# 设置随机种子
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
cudnn.benchmark = True
# 构建数据集
print("构建数据集...")
dataset_val = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=False,
max_interval=args.max_interval
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
shuffle=False,
drop_last=False
)
# 创建模型
print(f"创建模型: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"未知模型: {args.model}")
model.to(device)
# 加载检查点
if args.resume:
print(f"加载检查点: {args.resume}")
try:
# 尝试使用weights_only=False加载PyTorch 2.6+需要)
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
except (pickle.UnpicklingError, TypeError) as e:
print(f"使用weights_only=False加载失败: {e}")
print("尝试使用torch.serialization.add_safe_globals...")
# 处理状态字典(可能包含'module.'前缀)
if 'model' in checkpoint:
state_dict = checkpoint['model']
elif 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
# 移除'module.'前缀(如果存在)
if hasattr(model, 'module'):
model.module.load_state_dict(state_dict)
else:
# 如果状态字典有'module.'前缀但模型没有,需要移除前缀
if any(key.startswith('module.') for key in state_dict.keys()):
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
state_dict = new_state_dict
model.load_state_dict(state_dict)
print(f"检查点加载成功epoch: {checkpoint.get('epoch', 'unknown')}")
else:
print("警告: 未提供检查点路径,使用随机初始化的模型")
# 创建输出目录
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
# 评估模型
print("开始评估...")
metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
# 打印指标
print("\n" + "="*50)
print("评估结果:")
print(f"MSE: {metrics['mse']:.6f}")
print(f"SSIM: {metrics['ssim']:.6f}")
print(f"PSNR: {metrics['psnr']:.6f} dB")
print(f"样本数量: {metrics['num_samples']}")
print("="*50)
# 保存指标到JSON文件
metrics_file = output_dir / 'evaluation_metrics.json'
with open(metrics_file, 'w') as f:
json.dump(metrics, f, indent=4)
print(f"指标已保存到: {metrics_file}")
# 保存示例可视化
if sample_results:
print(f"\n保存 {len(sample_results)} 个示例可视化...")
samples_dir = output_dir / 'sample_predictions'
samples_dir.mkdir(exist_ok=True)
for i, sample in enumerate(sample_results):
save_path = samples_dir / f'sample_{i:03d}.png'
# 生成输入帧索引(假设连续)
input_frame_indices = list(range(1, args.num_frames + 1))
target_frame_index = args.num_frames + 1
save_comparison_figure(
sample['input_frames'],
sample['target_frame'],
sample['pred_frame'],
save_path,
input_frame_indices=input_frame_indices,
target_frame_index=target_frame_index
)
# 保存该样本的指标
sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
with open(sample_metrics_file, 'w') as f:
f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
print(f"示例可视化已保存到: {samples_dir}")
# 生成汇总报告
report_file = output_dir / 'evaluation_report.txt'
with open(report_file, 'w') as f:
f.write("SwiftFormerTemporal 帧预测评估报告\n")
f.write("="*50 + "\n")
f.write(f"模型: {args.model}\n")
f.write(f"检查点: {args.resume}\n")
f.write(f"数据集: {args.data_path}\n")
f.write(f"输入帧数: {args.num_frames}\n")
f.write(f"帧大小: {args.frame_size}\n")
f.write(f"批次大小: {args.batch_size}\n")
f.write(f"样本总数: {metrics['num_samples']}\n\n")
f.write("评估指标:\n")
f.write(f" MSE: {metrics['mse']:.6f}\n")
f.write(f" SSIM: {metrics['ssim']:.6f}\n")
f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")
print(f"评估报告已保存到: {report_file}")
print("\n评估完成!")
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal 评估脚本', add_help=False)
# 数据集参数
parser.add_argument('--data-path', default='./videos', type=str,
help='视频数据集路径')
parser.add_argument('--num-frames', default=3, type=int,
help='输入帧数 (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='输入帧大小')
parser.add_argument('--max-interval', default=4, type=int,
help='连续帧之间的最大间隔')
# 模型参数
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='要评估的模型名称')
# 评估参数
parser.add_argument('--batch-size', default=16, type=int,
help='评估批次大小')
parser.add_argument('--num-samples-to-save', default=10, type=int,
help='保存可视化的样本数量')
parser.add_argument('--max-samples', default=None, type=int,
help='最大评估样本数None表示全部')
# 系统参数
parser.add_argument('--output-dir', default='./evaluation_results',
help='保存结果的路径')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
help='使用的设备')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='检查点路径')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='在DataLoader中固定CPU内存')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
return parser
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal 评估', parents=[get_args_parser()])
args = parser.parse_args()
# 确保输出目录存在
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)

View File

@@ -6,8 +6,10 @@ import datetime
import numpy as np
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.scheduler import create_scheduler
@@ -17,8 +19,15 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.frame_losses import MultiTaskLoss
from util.video_dataset import VideoFrameDataset
# from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
def get_args_parser():
@@ -34,34 +43,68 @@ def get_args_parser():
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int,
parser.add_argument('--max-interval', default=10, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--use-representation-head', action='store_true',
help='Use representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of representation vector')
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
# Optimizer parameters (required by timm's create_optimizer)
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='agc',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
help='learning rate (default: 1e-3)')
# Learning rate schedule parameters (required by timm's create_scheduler)
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0,
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss')
# parser.add_argument('--l1-weight', type=float, default=1.0,
# help='Weight for L1 loss')
# parser.add_argument('--ssim-weight', type=float, default=0.1,
# help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
@@ -78,7 +121,7 @@ def get_args_parser():
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--num-workers', default=16, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
@@ -90,26 +133,26 @@ def get_args_parser():
parser.add_argument('--dist-url', default='env://',
help='url used to set up distributed training')
# TensorBoard logging
parser.add_argument('--tensorboard-logdir', default='./runs',
type=str, help='TensorBoard log directory')
parser.add_argument('--log-images', action='store_true',
help='Log sample images to TensorBoard')
parser.add_argument('--image-log-freq', default=100, type=int,
help='Frequency of logging images (in iterations)')
return parser
def build_dataset(is_train, args):
"""Build video frame dataset"""
if args.dataset_type == 'synthetic':
dataset = SyntheticVideoDataset(
num_samples=1000 if is_train else 200,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train
)
else:
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
return dataset
@@ -159,8 +202,6 @@ def main(args):
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
}
if args.model == 'SwiftFormerTemporal_XS':
@@ -203,14 +244,18 @@ def main(args):
# Create scheduler
lr_scheduler, _ = create_scheduler(args, optimizer)
# Create loss function
criterion = MultiTaskLoss(
frame_weight=args.frame_weight,
contrastive_weight=args.contrastive_weight,
l1_weight=args.l1_weight,
ssim_weight=args.ssim_weight,
use_contrastive=not args.no_contrastive
)
# Create loss function - simple MSE for Y channel prediction
class MSELossWrapper(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, pred_frame, target_frame, temporal_indices=None):
loss = self.mse(pred_frame, target_frame)
loss_dict = {'mse': loss}
return loss, loss_dict
criterion = MSELossWrapper()
# Resume from checkpoint
output_dir = Path(args.output_dir)
@@ -219,7 +264,7 @@ def main(args):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
@@ -231,6 +276,21 @@ def main(args):
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
# Initialize TensorBoard writer
writer = None
if TENSORBOARD_AVAILABLE and utils.is_main_process():
from datetime import datetime
# Create log directory with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print(f"To view logs, run: tensorboard --logdir={log_dir}")
elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
print("Training will continue without TensorBoard logging.")
if args.eval:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Test stats: {test_stats}")
@@ -239,20 +299,24 @@ def main(args):
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
# Global step counter for TensorBoard
global_step = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
model_ema=model_ema
optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
model_ema=model_ema, writer=writer,
global_step=global_step, args=args
)
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),
@@ -266,10 +330,10 @@ def main(args):
# Evaluate
if epoch % 5 == 0 or epoch == args.epochs - 1:
test_stats = evaluate(data_loader_val, model, criterion, device)
test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
print(f"Epoch {epoch}: Test stats: {test_stats}")
# Log stats
# Log stats to text file
log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
@@ -285,28 +349,39 @@ def main(args):
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')
# Close TensorBoard writer
if writer is not None:
writer.close()
print(f"TensorBoard logs saved to: {writer.log_dir}")
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = f'Epoch: [{epoch}]'
print_freq = 10
for input_frames, target_frames, temporal_indices in metric_logger.log_every(
data_loader, print_freq, header):
# 添加诊断指标
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Forward pass
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
with torch.amp.autocast(device_type='cuda'):
pred_frames = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
representations, temporal_indices
temporal_indices
)
loss_value = loss.item()
@@ -315,6 +390,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
raise ValueError(f"Loss is {loss_value}")
optimizer.zero_grad()
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters())
@@ -322,36 +398,131 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
if model_ema is not None:
model_ema.update(model)
# 计算诊断指标
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
# 计算梯度范数
total_grad_norm = 0.0
for param in model.parameters():
if param.grad is not None:
total_grad_norm += param.grad.norm().item()
# 记录诊断指标
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(grad_norm=total_grad_norm)
# # 每50个批次打印一次BatchNorm统计
if batch_idx % 50 == 0:
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
# # 检查一个BatchNorm层的运行统计
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
# break
# Log to TensorBoard
if writer is not None:
# Log scalar metrics every iteration
writer.add_scalar('train/loss', loss_value, global_step)
writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
# Log individual loss components
for k, v in loss_dict.items():
if torch.is_tensor(v):
writer.add_scalar(f'train/{k}', v.item(), global_step)
else:
writer.add_scalar(f'train/{k}', v, global_step)
# Log diagnostic metrics
writer.add_scalar('train/pred_mean', pred_mean, global_step)
writer.add_scalar('train/pred_std', pred_std, global_step)
writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
# Log images periodically
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
with torch.no_grad():
# Take first sample from batch for visualization
pred_vis = model(input_frames[:1])
# Convert to appropriate format for TensorBoard
# Assuming frames are in [B, C, H, W] format
writer.add_images('train/input', input_frames[:1], global_step)
writer.add_images('train/target', target_frames[:1], global_step)
writer.add_images('train/predicted', pred_vis[:1], global_step)
# Update metrics
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
global_step += 1
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
# Log epoch-level metrics
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
# 添加诊断指标
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Compute output
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
with torch.amp.autocast(device_type='cuda'):
pred_frames = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
representations, temporal_indices
temporal_indices
)
# 计算诊断指标
pred_mean = pred_frames.mean().item()
pred_std = pred_frames.std().item()
target_mean = target_frames.mean().item()
target_std = target_frames.std().item()
# 更新诊断指标
metric_logger.update(pred_mean=pred_mean)
metric_logger.update(pred_std=pred_std)
metric_logger.update(target_mean=target_mean)
metric_logger.update(target_std=target_std)
# # 第一个批次打印详细诊断信息
# if batch_idx == 0:
# print(f"[评估诊断] 批次 0:")
# print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
# print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
# print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
# print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
# # 检查BatchNorm运行统计
# for name, module in model.named_modules():
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
# print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
# if module.running_var[0].item() < 1e-6:
# print(f" 警告: BatchNorm运行方差接近零!")
# break
# Update metrics
metric_logger.update(loss=loss.item())
for k, v in loss_dict.items():
@@ -359,6 +530,12 @@ def evaluate(data_loader, model, criterion, device):
metric_logger.synchronize_between_processes()
print('* Test stats:', metric_logger)
# Log validation metrics to TensorBoard
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

View File

@@ -6,9 +6,9 @@ import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
from timm.layers import DropPath, trunc_normal_
from timm.models import register_model
from timm.layers import to_2tuple
import einops
SwiftFormer_width = {

View File

@@ -7,96 +7,117 @@ from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.models.layers import DropPath, trunc_normal_
from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder"""
"""Upsampling block for frame prediction decoder without residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
self.conv = nn.ConvTranspose2d(
# 主路径:反卷积 + 两个卷积层
self.conv_transpose = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=False
bias=False # 禁用bias因为使用BN
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv1 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
# 使用ReLU激活函数
self.activation = nn.ReLU(inplace=True)
# 初始化权重
self._init_weights()
def _init_weights(self):
# 初始化反卷积层
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
# 初始化卷积层
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
# 初始化BN层使用默认初始化
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
# 主路径
x = self.conv_transpose(x)
x = self.bn1(x)
x = self.activation(x)
x = self.conv1(x)
x = self.bn2(x)
x = self.activation(x)
x = self.conv2(x)
x = self.bn3(x)
x = self.activation(x)
return x
class FramePredictionDecoder(nn.Module):
"""Lightweight decoder for frame prediction with optional skip connections"""
def __init__(self, embed_dims, output_channels=3, use_skip=False):
"""Improved decoder for frame prediction"""
def __init__(self, embed_dims, output_channels=1):
super().__init__()
self.use_skip = use_skip
# Reverse the embed_dims for decoder
decoder_dims = embed_dims[::-1]
# Define decoder dimensions independently (no skip connections)
start_dim = embed_dims[-1]
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
self.blocks = nn.ModuleList()
# First upsampling from bottleneck to stage4 resolution
# 第一个blockstride=2 (decoder_dims[0] -> decoder_dims[1])
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage4 to stage3
# 第二个blockstride=2 (decoder_dims[1] -> decoder_dims[2])
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage3 to stage2
# 第三个blockstride=2 (decoder_dims[2] -> decoder_dims[3])
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage2 to original resolution (4x upsampling total)
self.blocks.append(nn.Sequential(
nn.ConvTranspose2d(
decoder_dims[3], 32,
kernel_size=3, stride=2, padding=1, output_padding=1
),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
nn.Tanh() # Output in [-1, 1] range
# 第四个blockstride=4 (decoder_dims[3] -> 64),放在倒数第二的位置
self.blocks.append(DecoderBlock(
decoder_dims[3], 64,
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
))
# If using skip connections, we need to adjust input channels for each block
if use_skip:
# We'll modify the first three blocks to accept concatenated features
# Instead of modifying existing blocks, we'll replace them with custom blocks
# For simplicity, we'll keep the same architecture but forward will handle concatenation
pass
self.final_block = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
nn.Tanh()
)
def forward(self, x, skip_features=None):
def forward(self, x):
"""
Args:
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
skip_features: list of encoder features from stages [stage2, stage1, stage0]
each of shape [B, C, H', W'] where C matches decoder dims?
"""
if self.use_skip and skip_features is not None:
# Ensure we have exactly 3 skip features (for the first three blocks)
assert len(skip_features) == 3, "Need 3 skip features for skip connections"
# Reverse skip_features to match decoder order: stage2, stage1, stage0
# skip_features[0] should be stage2 (H/16), [1] stage1 (H/8), [2] stage0 (H/4)
skip_features = skip_features[::-1] # Now index 0: stage2, 1: stage1, 2: stage0
# 不使用skip connections
for i in range(4):
x = self.blocks[i](x)
for i, block in enumerate(self.blocks):
if self.use_skip and skip_features is not None and i < 3:
# Concatenate skip feature along channel dimension
# Ensure spatial dimensions match (they should because of upsampling)
x = torch.cat([x, skip_features[i]], dim=1)
# Need to adjust block to accept extra channels? We'll create a separate block.
# For now, we'll just pass through, but this will cause channel mismatch.
# Instead, we should have created custom blocks with appropriate in_channels.
# This is a placeholder; we need to implement properly.
pass
x = block(x)
# 最终输出层:只进行特征精炼,不上采样
x = self.final_block(x)
return x
@@ -104,15 +125,12 @@ class SwiftFormerTemporal(nn.Module):
"""
SwiftFormer with temporal input for frame prediction.
Input: [B, num_frames, H, W] (Y channel only)
Output: predicted frame [B, 3, H, W] and optional representation
Output: predicted frame [B, 1, H, W] and optional representation
"""
def __init__(self,
model_name='XS',
num_frames=3,
use_decoder=True,
use_representation_head=False,
representation_dim=128,
return_features=False,
**kwargs):
super().__init__()
@@ -123,8 +141,6 @@ class SwiftFormerTemporal(nn.Module):
# Store configuration
self.num_frames = num_frames
self.use_decoder = use_decoder
self.use_representation_head = use_representation_head
self.return_features = return_features
# Modify stem to accept multiple frames (only Y channel)
in_channels = num_frames
@@ -155,79 +171,51 @@ class SwiftFormerTemporal(nn.Module):
# Frame prediction decoder
if use_decoder:
self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
# Representation head for pose/velocity prediction
if use_representation_head:
self.representation_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(embed_dims[-1], representation_dim),
nn.ReLU(),
nn.Linear(representation_dim, representation_dim)
self.decoder = FramePredictionDecoder(
embed_dims,
output_channels=1
)
else:
self.representation_head = None
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
# 使用Kaiming初始化适合ReLU
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)):
elif isinstance(m, nn.ConvTranspose2d):
# 反卷积层使用特定的初始化
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
"""Forward through encoder network, return list of stage features if return_features else final output"""
if self.return_features:
features = []
for idx, block in enumerate(self.network):
x = block(x)
# Collect output after each stage (indices 0,2,4,6 correspond to stages)
if idx in [0, 2, 4, 6]:
features.append(x)
return x, features
else:
for block in self.network:
x = block(x)
return x
for block in self.network:
x = block(x)
return x
def forward(self, x):
"""
Args:
x: input frames of shape [B, num_frames, H, W]
Returns:
If return_features is False:
pred_frame: predicted frame [B, 3, H, W] (or None)
representation: optional representation vector [B, representation_dim] (or None)
If return_features is True:
pred_frame, representation, features (list of stage features)
pred_frame: predicted frame [B, 1, H, W] (or None)
"""
# Encode
x = self.patch_embed(x)
if self.return_features:
x, features = self.forward_tokens(x)
else:
x = self.forward_tokens(x)
x = self.forward_tokens(x)
x = self.norm(x)
# Get representation if needed
representation = None
if self.representation_head is not None:
representation = self.representation_head(x)
# Decode to frame
pred_frame = None
if self.use_decoder:
pred_frame = self.decoder(x)
if self.return_features:
return pred_frame, representation, features
else:
return pred_frame, representation
return pred_frame
# Factory functions for different model sizes

View File

@@ -1,60 +0,0 @@
#!/usr/bin/env python3
"""
Test script for SwiftFormerTemporal model
"""
import torch
import sys
import os
# Add current directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.swiftformer_temporal import SwiftFormerTemporal_XS
def test_model():
print("Testing SwiftFormerTemporal model...")
# Create model
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
print(f'Model created: {model.__class__.__name__}')
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
# Test forward pass
batch_size = 2
num_frames = 3
height = width = 224
x = torch.randn(batch_size, 3 * num_frames, height, width)
print(f'\nInput shape: {x.shape}')
with torch.no_grad():
pred_frame, representation = model(x)
print(f'Predicted frame shape: {pred_frame.shape}')
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
# Check output ranges
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
# Test loss function
from util.frame_losses import MultiTaskLoss
criterion = MultiTaskLoss()
target = torch.randn_like(pred_frame)
temporal_indices = torch.tensor([3, 3], dtype=torch.long)
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
print(f'\nLoss test:')
for k, v in loss_dict.items():
print(f' {k}: {v:.4f}')
print('\nAll tests passed!')
return True
if __name__ == '__main__':
try:
test_model()
except Exception as e:
print(f'Test failed with error: {e}')
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -1,182 +0,0 @@
"""
Loss functions for frame prediction and representation learning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SSIMLoss(nn.Module):
"""
Structural Similarity Index Measure Loss
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
"""
def __init__(self, window_size=11, size_average=True):
super().__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 3
self.window = self.create_window(window_size, self.channel)
def create_window(self, window_size, channel):
def gaussian(window_size, sigma):
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def forward(self, img1, img2):
# Ensure window is on correct device
if self.window.device != img1.device:
self.window = self.window.to(img1.device)
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if self.size_average:
return 1 - ssim_map.mean()
else:
return 1 - ssim_map.mean(1).mean(1).mean(1)
class FramePredictionLoss(nn.Module):
"""
Combined loss for frame prediction
"""
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
super().__init__()
self.l1_weight = l1_weight
self.ssim_weight = ssim_weight
self.use_ssim = use_ssim
self.l1_loss = nn.L1Loss()
if use_ssim:
self.ssim_loss = SSIMLoss()
def forward(self, pred, target):
"""
Args:
pred: predicted frame [B, 3, H, W] in range [-1, 1]
target: target frame [B, 3, H, W] in range [-1, 1]
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# L1 loss
l1_loss = self.l1_loss(pred, target)
loss_dict['l1'] = l1_loss
total_loss = self.l1_weight * l1_loss
# SSIM loss
if self.use_ssim:
ssim_loss = self.ssim_loss(pred, target)
loss_dict['ssim'] = ssim_loss
total_loss += self.ssim_weight * ssim_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
class ContrastiveLoss(nn.Module):
"""
Contrastive loss for representation learning
Positive pairs: representations from adjacent frames
Negative pairs: representations from distant frames
"""
def __init__(self, temperature=0.1, margin=1.0):
super().__init__()
self.temperature = temperature
self.margin = margin
self.cosine_similarity = nn.CosineSimilarity(dim=-1)
def forward(self, representations, temporal_indices):
"""
Args:
representations: [B, D] representation vectors
temporal_indices: [B] temporal indices of each sample
Returns:
contrastive_loss
"""
batch_size = representations.size(0)
# Compute similarity matrix
sim_matrix = torch.matmul(representations, representations.T) / self.temperature
# Create positive mask (adjacent frames)
indices_expanded = temporal_indices.unsqueeze(0)
diff = torch.abs(indices_expanded - indices_expanded.T)
positive_mask = (diff == 1).float()
# Create negative mask (distant frames)
negative_mask = (diff > 2).float()
# Positive loss
pos_sim = sim_matrix * positive_mask
pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
# Negative loss (push apart)
neg_sim = sim_matrix * negative_mask
neg_loss = torch.relu(neg_sim - self.margin).mean()
return pos_loss + 0.1 * neg_loss
class MultiTaskLoss(nn.Module):
"""
Multi-task loss combining frame prediction and representation learning
"""
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
super().__init__()
self.frame_weight = frame_weight
self.contrastive_weight = contrastive_weight
self.use_contrastive = use_contrastive
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
if use_contrastive:
self.contrastive_loss = ContrastiveLoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
"""
Args:
pred_frame: predicted frame [B, 3, H, W]
target_frame: target frame [B, 3, H, W]
representations: [B, D] representation vectors (optional)
temporal_indices: [B] temporal indices (optional)
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# Frame prediction loss
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
total_loss = self.frame_weight * frame_loss
# Contrastive loss (if representations provided)
if self.use_contrastive and representations is not None and temporal_indices is not None:
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
loss_dict['contrastive'] = contrastive_loss
total_loss += self.contrastive_weight * contrastive_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict

View File

@@ -48,27 +48,39 @@ class VideoFrameDataset(Dataset):
self.is_train = is_train
self.max_interval = max_interval
# Collect all video folders
# if num_frames < 1:
# raise ValueError("num_frames must be >= 1")
# if frame_size < 1:
# raise ValueError("frame_size must be >= 1")
# if max_interval < 1:
# raise ValueError("max_interval must be >= 1")
# Collect all video folders and their frame files
self.video_folders = []
self.video_frame_files = [] # list of list of Path objects
for item in self.root_dir.iterdir():
if item.is_dir():
self.video_folders.append(item)
# Get all frame files
frame_files = sorted([f for f in item.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
self.video_frame_files.append(frame_files)
if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = []
for video_idx, video_folder in enumerate(self.video_folders):
# Get all frame files
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
if len(frame_files) < num_frames + 1:
for video_idx, frame_files in enumerate(self.video_frame_files):
# Minimum frames needed considering max interval
min_frames_needed = num_frames * max_interval + 1
if len(frame_files) < min_frames_needed:
continue # Skip videos with insufficient frames
# Add all possible starting positions
for start_idx in range(len(frame_files) - num_frames):
# Ensure that for any interval up to max_interval, all frames are within bounds
max_start = len(frame_files) - num_frames * max_interval
for start_idx in range(max_start):
self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0:
@@ -80,11 +92,12 @@ class VideoFrameDataset(Dataset):
else:
self.transform = transform
# Normalization (ImageNet stats)
self.normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
# Simple normalization to [-1, 1] range (不使用ImageNet标准化)
# Convert pixel values [0, 255] to [-1, 1]
# This matches the model's tanh output range
self.normalize = None # We'll handle normalization manually
# print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]")
def _default_transform(self):
"""Default transform with augmentation for training"""
@@ -102,9 +115,12 @@ class VideoFrameDataset(Dataset):
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
"""Load a single frame as PIL Image"""
video_folder = self.video_folders[video_idx]
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
frame_files = self.video_frame_files[video_idx]
if frame_idx < 0 or frame_idx >= len(frame_files):
raise IndexError(
f"Frame index {frame_idx} out of range for video {video_idx} "
f"(0-{len(frame_files)-1})"
)
frame_path = frame_files[frame_idx]
return Image.open(frame_path).convert('RGB')
@@ -114,8 +130,8 @@ class VideoFrameDataset(Dataset):
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
input_frames: [3 * num_frames, H, W] concatenated input frames
target_frame: [3, H, W] target frame to predict
input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
target_frame: [1, H, W] target frame to predict (Y channel only)
temporal_idx: temporal index of target frame (for contrastive loss)
"""
video_idx, start_idx = self.frame_indices[idx]
@@ -141,69 +157,77 @@ class VideoFrameDataset(Dataset):
if self.transform:
target_frame = self.transform(target_frame)
# Convert to tensors and normalize
# Convert to tensors and convert to grayscale (Y channel)
input_tensors = []
for frame in input_frames:
tensor = transforms.ToTensor()(frame)
tensor = self.normalize(tensor)
input_tensors.append(tensor)
tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
# Convert RGB to grayscale using weighted sum
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
# Normalize from [0, 1] to [-1, 1]
gray = gray * 2 - 1 # [0,1] -> [-1,1]
input_tensors.append(gray)
target_tensor = transforms.ToTensor()(target_frame)
target_tensor = self.normalize(target_tensor)
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
# Normalize from [0, 1] to [-1, 1]
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
# Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0)
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
# Temporal index (for contrastive loss)
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
return input_concatenated, target_gray, temporal_idx
class SyntheticVideoDataset(Dataset):
"""
Synthetic dataset for testing - generates random frames
"""
def __init__(self,
num_samples: int = 1000,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True):
self.num_samples = num_samples
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
# class SyntheticVideoDataset(Dataset):
# """
# Synthetic dataset for testing - generates random frames
# """
# def __init__(self,
# num_samples: int = 1000,
# num_frames: int = 3,
# frame_size: int = 224,
# is_train: bool = True):
# self.num_samples = num_samples
# self.num_frames = num_frames
# self.frame_size = frame_size
# self.is_train = is_train
# Normalization
self.normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
# # Normalization for Y channel (single channel)
# y_mean = (0.485 + 0.456 + 0.406) / 3.0
# y_std = (0.229 + 0.224 + 0.225) / 3.0
# self.normalize = transforms.Normalize(
# mean=[y_mean],
# std=[y_std]
# )
def __len__(self):
return self.num_samples
# def __len__(self):
# return self.num_samples
def __getitem__(self, idx):
# Generate random "frames" (noise with temporal correlation)
input_frames = []
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
# def __getitem__(self, idx):
# # Generate random "frames" (noise with temporal correlation)
# input_frames = []
# prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
for i in range(self.num_frames):
# Add some temporal correlation
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
frame = torch.clamp(frame, -1, 1)
input_frames.append(self.normalize(frame))
prev_frame = frame
# for i in range(self.num_frames):
# # Add some temporal correlation
# frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# frame = torch.clamp(frame, -1, 1)
# input_frames.append(self.normalize(frame))
# prev_frame = frame
# Target frame (next in sequence)
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
target_frame = torch.clamp(target_frame, -1, 1)
target_tensor = self.normalize(target_frame)
# # Target frame (next in sequence)
# target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# target_frame = torch.clamp(target_frame, -1, 1)
# target_tensor = self.normalize(target_frame)
# Concatenate inputs
input_concatenated = torch.cat(input_frames, dim=0)
# # Concatenate inputs
# input_concatenated = torch.cat(input_frames, dim=0)
# Temporal index
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
# # Temporal index
# temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
# return input_concatenated, target_tensor, temporal_idx

303
video_preprocessor.py Normal file
View File

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
视频预处理脚本 - 将MP4视频转换为224x224帧图像
支持多线程并发处理、进度条显示和中断恢复功能
"""
import os
import sys
import json
import argparse
import subprocess
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import time
from typing import List, Dict, Optional
class VideoPreprocessor:
"""视频预处理器,支持多线程和中断恢复"""
def __init__(self,
input_dir: str,
output_dir: str,
frame_size: int = 224,
fps: int = 30,
num_workers: int = 4,
quality: int = 2,
resume: bool = True):
"""
初始化预处理器
Args:
input_dir: 输入视频目录
output_dir: 输出帧目录
frame_size: 帧大小(正方形)
fps: 提取帧率
num_workers: 并发工作线程数
quality: JPEG质量 (1-31, 数值越小质量越高)
resume: 是否启用中断恢复
"""
self.input_dir = Path(input_dir)
self.output_dir = Path(output_dir)
self.frame_size = frame_size
self.fps = fps
self.num_workers = num_workers
self.quality = quality
self.resume = resume
# 状态文件路径
self.state_file = self.output_dir / ".preprocessing_state.json"
# 创建输出目录
self.output_dir.mkdir(parents=True, exist_ok=True)
# 初始化状态
self.state = self._load_state()
# 收集所有视频文件
self.video_files = self._collect_video_files()
def _load_state(self) -> Dict:
"""加载处理状态"""
if self.resume and self.state_file.exists():
try:
with open(self.state_file, 'r') as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
print(f"警告: 无法读取状态文件,将重新开始处理")
return {
"completed": [],
"failed": [],
"total_processed": 0,
"start_time": None,
"last_update": None
}
def _save_state(self):
"""保存处理状态"""
self.state["last_update"] = time.time()
try:
with open(self.state_file, 'w') as f:
json.dump(self.state, f, indent=2)
except IOError as e:
print(f"警告: 无法保存状态文件: {e}")
def _collect_video_files(self) -> List[Path]:
"""收集所有需要处理的视频文件"""
video_files = []
for file_path in self.input_dir.glob("*.mp4"):
if file_path.name not in self.state["completed"]:
video_files.append(file_path)
return sorted(video_files)
def _parse_video_name(self, video_path: Path) -> Dict[str, str]:
"""解析视频文件名使用完整文件名作为ID"""
name_without_ext = video_path.stem
# 直接使用完整文件名作为ID确保每个mp4文件有独立的输出目录
return {
"video_id": name_without_ext,
"start_frame": "unknown",
"end_frame": "unknown",
"full_name": name_without_ext
}
def _extract_frames(self, video_path: Path) -> bool:
"""提取单个视频的帧"""
try:
# 解析视频名称
video_info = self._parse_video_name(video_path)
output_subdir = self.output_dir / video_info["video_id"]
output_subdir.mkdir(exist_ok=True)
# 构建FFmpeg命令
output_pattern = output_subdir / "frame_%04d.jpg"
cmd = [
"ffmpeg",
"-i", str(video_path),
"-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}",
"-q:v", str(self.quality),
"-y", # 覆盖输出文件
str(output_pattern)
]
# 执行FFmpeg命令
result = subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300 # 5分钟超时
)
if result.returncode != 0:
print(f"FFmpeg错误处理 {video_path.name}: {result.stderr}")
return False
# 验证输出帧数量
output_frames = list(output_subdir.glob("frame_*.jpg"))
if len(output_frames) == 0:
print(f"警告: {video_path.name} 没有生成任何帧")
return False
return True
except subprocess.TimeoutExpired:
print(f"超时处理 {video_path.name}")
return False
except Exception as e:
print(f"处理 {video_path.name} 时发生错误: {e}")
return False
def _process_video(self, video_path: Path) -> tuple[bool, str]:
"""处理单个视频文件"""
video_name = video_path.name
try:
success = self._extract_frames(video_path)
if success:
self.state["completed"].append(video_name)
if video_name in self.state["failed"]:
self.state["failed"].remove(video_name)
return True, video_name
else:
self.state["failed"].append(video_name)
return False, video_name
except Exception as e:
print(f"处理 {video_name} 时发生异常: {e}")
self.state["failed"].append(video_name)
return False, video_name
def process_all_videos(self):
"""处理所有视频文件"""
if not self.video_files:
print("没有找到需要处理的视频文件")
return
print(f"找到 {len(self.video_files)} 个待处理视频文件")
print(f"输出目录: {self.output_dir}")
print(f"帧大小: {self.frame_size}x{self.frame_size}")
print(f"帧率: {self.fps} fps")
print(f"并发线程数: {self.num_workers}")
if self.state["completed"]:
print(f"跳过 {len(self.state['completed'])} 个已处理的视频")
# 记录开始时间
if self.state["start_time"] is None:
self.state["start_time"] = time.time()
# 创建进度条
with tqdm(total=len(self.video_files), desc="处理视频", unit="") as pbar:
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# 提交所有任务
future_to_video = {
executor.submit(self._process_video, video_path): video_path
for video_path in self.video_files
}
# 处理完成的任务
for future in as_completed(future_to_video):
video_path = future_to_video[future]
try:
success, video_name = future.result()
if success:
pbar.set_postfix({"状态": "成功", "文件": video_name[:20]})
else:
pbar.set_postfix({"状态": "失败", "文件": video_name[:20]})
except Exception as e:
print(f"处理 {video_path.name} 时发生异常: {e}")
pbar.set_postfix({"状态": "异常", "文件": video_path.name[:20]})
pbar.update(1)
self.state["total_processed"] += 1
# 定期保存状态
if self.state["total_processed"] % 5 == 0:
self._save_state()
# 最终保存状态
self._save_state()
# 打印处理结果
self._print_summary()
def _print_summary(self):
"""打印处理摘要"""
print("\n" + "="*50)
print("处理完成摘要:")
print(f"总处理视频数: {len(self.state['completed'])}")
print(f"失败视频数: {len(self.state['failed'])}")
if self.state["failed"]:
print("\n失败的视频:")
for video_name in self.state["failed"]:
print(f" - {video_name}")
if self.state["start_time"]:
elapsed_time = time.time() - self.state["start_time"]
print(f"\n总耗时: {elapsed_time:.2f}")
if self.state["total_processed"] > 0:
avg_time = elapsed_time / self.state["total_processed"]
print(f"平均每个视频: {avg_time:.2f}")
print("="*50)
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="视频预处理脚本")
parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="输入视频目录")
parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="输出帧目录")
parser.add_argument("--size", type=int, default=224, help="帧大小 (默认: 224)")
parser.add_argument("--fps", type=int, default=10, help="提取帧率 (默认: 30)")
parser.add_argument("--workers", type=int, default=32, help="并发线程数 (默认: 4)")
parser.add_argument("--quality", type=int, default=2, help="JPEG质量 1-31 (默认: 2)")
parser.add_argument("--no-resume", action="store_true", help="不启用中断恢复")
args = parser.parse_args()
# 检查输入目录
if not Path(args.input_dir).exists():
print(f"错误: 输入目录不存在: {args.input_dir}")
sys.exit(1)
# 检查FFmpeg是否可用
try:
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
except (subprocess.CalledProcessError, FileNotFoundError):
print("错误: FFmpeg未安装或不在PATH中")
sys.exit(1)
# 创建预处理器并开始处理
preprocessor = VideoPreprocessor(
input_dir=args.input_dir,
output_dir=args.output_dir,
frame_size=args.size,
fps=args.fps,
num_workers=args.workers,
quality=args.quality,
resume=not args.no_resume
)
try:
preprocessor.process_all_videos()
except KeyboardInterrupt:
print("\n\n用户中断处理,状态已保存")
preprocessor._save_state()
print("可以使用相同命令恢复处理")
except Exception as e:
print(f"\n处理过程中发生错误: {e}")
preprocessor._save_state()
sys.exit(1)
if __name__ == "__main__":
main()