test modify swiftformer to temporal input

2026-01-07 11:03:33 +08:00
parent 4aa6cd6752
commit 7e9564ef20
6 changed files with 1074 additions and 0 deletions

main_temporal.py Normal file (373 additions)

@@ -0,0 +1,373 @@
"""
Main training script for SwiftFormerTemporal frame prediction
"""
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import math
from pathlib import Path
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import utils  # explicit import: the script calls utils.<fn> throughout (assumes the DeiT-style util/utils.py helpers)
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.frame_losses import MultiTaskLoss
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal training script', add_help=False)
# Dataset parameters
parser.add_argument('--data-path', default='./videos', type=str,
help='Path to video dataset')
parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'],
type=str, help='Dataset type')
parser.add_argument('--num-frames', default=3, type=int,
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--use-representation-head', action='store_true',
help='Use representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of representation vector')
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Optimizer / scheduler parameters required by timm's create_optimizer
    # and create_scheduler (defaults mirror the DeiT training recipe)
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR')
    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR (step scheduler)')
    parser.add_argument('--decay-rate', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (step scheduler)')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr')
# Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0,
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
help='Disable SSIM loss')
# System parameters
parser.add_argument('--output-dir', default='./output',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# Distributed training
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://',
help='url used to set up distributed training')
return parser
def build_dataset(is_train, args):
"""Build video frame dataset"""
if args.dataset_type == 'synthetic':
dataset = SyntheticVideoDataset(
num_samples=1000 if is_train else 200,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train
)
else:
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
return dataset
def main(args):
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
# Fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# Build datasets
dataset_train = build_dataset(is_train=True, args=args)
dataset_val = build_dataset(is_train=False, args=args)
# Create samplers
if args.distributed:
sampler_train = torch.utils.data.DistributedSampler(dataset_train)
sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
# Create model
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"Unknown model: {args.model}")
model.to(device)
# Model EMA
model_ema = None
if hasattr(args, 'model_ema') and args.model_ema:
model_ema = ModelEma(
model,
decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999,
device='cpu' if hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '',
resume='')
# Distributed training
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {n_parameters}')
# Create optimizer
optimizer = create_optimizer(args, model_without_ddp)
# Create loss scaler
loss_scaler = NativeScaler()
# Create scheduler
lr_scheduler, _ = create_scheduler(args, optimizer)
# Create loss function
criterion = MultiTaskLoss(
frame_weight=args.frame_weight,
contrastive_weight=args.contrastive_weight,
l1_weight=args.l1_weight,
ssim_weight=args.ssim_weight,
use_contrastive=not args.no_contrastive
)
# Resume from checkpoint
output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if model_ema is not None:
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
if args.eval:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Test stats: {test_stats}")
return
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
model_ema=model_ema
)
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema) if model_ema else None,
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
# Evaluate
if epoch % 5 == 0 or epoch == args.epochs - 1:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Epoch {epoch}: Test stats: {test_stats}")
# Log stats
log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters
}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
                    clip_grad=None, clip_mode='norm', model_ema=None, **kwargs):
    # clip_grad=None disables clipping; a default of 0 would make timm's
    # NativeScaler clip every gradient norm to zero
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = f'Epoch: [{epoch}]'
print_freq = 10
for input_frames, target_frames, temporal_indices in metric_logger.log_every(
data_loader, print_freq, header):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Forward pass
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
representations, temporal_indices
)
loss_value = loss.item()
        if not math.isfinite(loss_value):
print(f"Loss is {loss_value}, stopping training")
raise ValueError(f"Loss is {loss_value}")
optimizer.zero_grad()
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters())
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
# Update metrics
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Compute output
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
representations, temporal_indices
)
# Update metrics
metric_logger.update(loss=loss.item())
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
metric_logger.synchronize_between_processes()
print('* Test stats:', metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal training script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
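# A minimal smoke run (a sketch, assuming the optimizer/scheduler defaults
# added above); it trains one epoch on the synthetic dataset, so no video
# data is required:
#
#   python main_temporal.py --dataset-type synthetic --batch-size 8 \
#       --epochs 1 --output-dir ./output_smoke --use-representation-head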

models/__init__.py (6 additions)

@@ -1 +1,7 @@
from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3
from .swiftformer_temporal import (
SwiftFormerTemporal_XS,
SwiftFormerTemporal_S,
SwiftFormerTemporal_L1,
SwiftFormerTemporal_L3
)

models/swiftformer_temporal.py Normal file (244 additions)

@@ -0,0 +1,244 @@
"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""
import torch
import torch.nn as nn
from .swiftformer import (
    SwiftFormer_depth, SwiftFormer_width,
    stem, Embedding, Stage
)
from timm.models.layers import trunc_normal_
class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
self.conv = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
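# A quick shape check (a sketch, using hypothetical channel counts): with
# kernel_size=3, stride=2, padding=1, output_padding=1 a DecoderBlock exactly
# doubles the spatial size,
#   H_out = (H_in - 1) * 2 - 2 * 1 + 3 + 1 = 2 * H_in,
# so a chain of them can walk the H/32 bottleneck back to full resolution:
#
#   blk = DecoderBlock(220, 112)
#   assert blk(torch.randn(1, 220, 7, 7)).shape == (1, 112, 14, 14)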
class FramePredictionDecoder(nn.Module):
"""Lightweight decoder for frame prediction with optional skip connections"""
    def __init__(self, embed_dims, output_channels=3, use_skip=False):
        super().__init__()
        self.use_skip = use_skip
        # Reverse the embed_dims for the decoder
        decoder_dims = embed_dims[::-1]
        # With skip connections, every block after the first receives the
        # previous decoder output concatenated with the matching encoder
        # stage, whose channel count equals the decoder output at that
        # resolution, hence the doubled input channels below
        skip_mult = 2 if use_skip else 1
        self.blocks = nn.ModuleList()
        # Bottleneck (H/32) to stage-2 resolution (H/16)
        self.blocks.append(DecoderBlock(
            decoder_dims[0], decoder_dims[1],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # H/16 to H/8
        self.blocks.append(DecoderBlock(
            decoder_dims[1] * skip_mult, decoder_dims[2],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # H/8 to H/4
        self.blocks.append(DecoderBlock(
            decoder_dims[2] * skip_mult, decoder_dims[3],
            kernel_size=3, stride=2, padding=1, output_padding=1
        ))
        # H/4 back to the original resolution: the stem downsamples 4x, so
        # this block needs two 2x upsamplings, not one
        self.blocks.append(nn.Sequential(
            nn.ConvTranspose2d(
                decoder_dims[3] * skip_mult, 32,
                kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                32, 32,
                kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
            nn.Tanh()  # Output in [-1, 1] range
        ))
    def forward(self, x, skip_features=None):
        """
        Args:
            x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
            skip_features: encoder features ordered [stage0, stage1, stage2]
                           (resolutions H/4, H/8, H/16)
        """
        if self.use_skip and skip_features is not None:
            assert len(skip_features) == 3, "Need 3 skip features for skip connections"
            # Reverse to decoder order: stage2 (H/16), stage1 (H/8), stage0 (H/4)
            skip_features = skip_features[::-1]
        for i, block in enumerate(self.blocks):
            # Concatenate the matching encoder feature after the previous
            # upsampling step, so the spatial sizes line up; the extra
            # channels are accounted for in __init__ via skip_mult
            if self.use_skip and skip_features is not None and 1 <= i <= 3:
                x = torch.cat([x, skip_features[i - 1]], dim=1)
            x = block(x)
        return x
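# Channel bookkeeping for the skip path, assuming the XS config from
# swiftformer.py (embed_dims = [48, 56, 112, 220], so
# decoder_dims = [220, 112, 56, 48]):
#
#   x [B, 220, H/32] --block0--> [B, 112, H/16] ++ stage2 (112ch) -> 224ch
#                    --block1--> [B,  56, H/8 ] ++ stage1 ( 56ch) -> 112ch
#                    --block2--> [B,  48, H/4 ] ++ stage0 ( 48ch) ->  96ch
#                    --final --> [B,   3, H   ]  (two 2x upsamplings inside)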
class SwiftFormerTemporal(nn.Module):
"""
SwiftFormer with temporal input for frame prediction.
    Input: [B, 3 * num_frames, H, W] (RGB input frames concatenated along the
    channel dimension, matching VideoFrameDataset and test_model.py)
Output: predicted frame [B, 3, H, W] and optional representation
"""
def __init__(self,
model_name='XS',
num_frames=3,
use_decoder=True,
use_representation_head=False,
representation_dim=128,
return_features=False,
**kwargs):
super().__init__()
# Get model configuration
layers = SwiftFormer_depth[model_name]
embed_dims = SwiftFormer_width[model_name]
# Store configuration
self.num_frames = num_frames
self.use_decoder = use_decoder
self.use_representation_head = use_representation_head
self.return_features = return_features
        # Modify the stem to accept the concatenated RGB input frames
        # (3 channels per frame, matching the datasets and the test script)
        in_channels = 3 * num_frames
self.patch_embed = stem(in_channels, embed_dims[0])
# Build encoder network (same as SwiftFormer)
network = []
for i in range(len(layers)):
stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
act_layer=nn.GELU,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True,
layer_scale_init_value=1e-5,
vit_num=1)
network.append(stage)
if i >= len(layers) - 1:
break
if embed_dims[i] != embed_dims[i + 1]:
network.append(
Embedding(
patch_size=3, stride=2, padding=1,
in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
)
)
self.network = nn.ModuleList(network)
self.norm = nn.BatchNorm2d(embed_dims[-1])
# Frame prediction decoder
if use_decoder:
self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
# Representation head for pose/velocity prediction
if use_representation_head:
self.representation_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(embed_dims[-1], representation_dim),
nn.ReLU(),
nn.Linear(representation_dim, representation_dim)
)
else:
self.representation_head = None
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
"""Forward through encoder network, return list of stage features if return_features else final output"""
if self.return_features:
features = []
for idx, block in enumerate(self.network):
x = block(x)
# Collect output after each stage (indices 0,2,4,6 correspond to stages)
if idx in [0, 2, 4, 6]:
features.append(x)
return x, features
else:
for block in self.network:
x = block(x)
return x
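    # Note: self.network alternates Stage and Embedding modules,
    #   [Stage0, Emb, Stage1, Emb, Stage2, Emb, Stage3]   (indices 0..6),
    # whenever consecutive embed_dims differ (true for every shipped config),
    # which is why the stage outputs sit at indices 0, 2, 4 and 6 above.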
def forward(self, x):
"""
Args:
            x: input frames of shape [B, 3 * num_frames, H, W]
Returns:
If return_features is False:
pred_frame: predicted frame [B, 3, H, W] (or None)
representation: optional representation vector [B, representation_dim] (or None)
If return_features is True:
pred_frame, representation, features (list of stage features)
"""
# Encode
x = self.patch_embed(x)
if self.return_features:
x, features = self.forward_tokens(x)
else:
x = self.forward_tokens(x)
x = self.norm(x)
# Get representation if needed
representation = None
if self.representation_head is not None:
representation = self.representation_head(x)
# Decode to frame
pred_frame = None
if self.use_decoder:
pred_frame = self.decoder(x)
if self.return_features:
return pred_frame, representation, features
else:
return pred_frame, representation
# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
    return SwiftFormerTemporal('L1', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
    return SwiftFormerTemporal('L3', num_frames=num_frames, **kwargs)
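# A shape sketch, assuming the 3 * num_frames RGB input convention used by
# the datasets and the test script:
#
#   model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
#   x = torch.randn(2, 9, 224, 224)   # three RGB frames, channel-concatenated
#   pred, rep = model(x)
#   # pred: [2, 3, 224, 224] in [-1, 1];  rep: [2, 128]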

test_model.py Normal file (60 additions)

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
"""
Test script for SwiftFormerTemporal model
"""
import torch
import sys
import os
# Add current directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.swiftformer_temporal import SwiftFormerTemporal_XS
def test_model():
print("Testing SwiftFormerTemporal model...")
# Create model
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
print(f'Model created: {model.__class__.__name__}')
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
# Test forward pass
batch_size = 2
num_frames = 3
height = width = 224
    x = torch.randn(batch_size, 3 * num_frames, height, width)  # RGB frames concatenated on channels
print(f'\nInput shape: {x.shape}')
with torch.no_grad():
pred_frame, representation = model(x)
print(f'Predicted frame shape: {pred_frame.shape}')
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
# Check output ranges
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
# Test loss function
from util.frame_losses import MultiTaskLoss
criterion = MultiTaskLoss()
target = torch.randn_like(pred_frame)
    temporal_indices = torch.tensor([3, 4], dtype=torch.long)  # adjacent indices -> one positive pair
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
print(f'\nLoss test:')
for k, v in loss_dict.items():
print(f' {k}: {v:.4f}')
print('\nAll tests passed!')
return True
if __name__ == '__main__':
try:
test_model()
except Exception as e:
print(f'Test failed with error: {e}')
import traceback
traceback.print_exc()
sys.exit(1)

util/frame_losses.py Normal file (182 additions)

@@ -0,0 +1,182 @@
"""
Loss functions for frame prediction and representation learning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SSIMLoss(nn.Module):
"""
Structural Similarity Index Measure Loss
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
"""
def __init__(self, window_size=11, size_average=True):
super().__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 3
self.window = self.create_window(window_size, self.channel)
def create_window(self, window_size, channel):
def gaussian(window_size, sigma):
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def forward(self, img1, img2):
# Ensure window is on correct device
if self.window.device != img1.device:
self.window = self.window.to(img1.device)
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if self.size_average:
return 1 - ssim_map.mean()
else:
return 1 - ssim_map.mean(1).mean(1).mean(1)
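# Note on the constants: C1 = (K1 * L)^2 and C2 = (K2 * L)^2 with K1 = 0.01,
# K2 = 0.03 assume a dynamic range L = 1; the decoder's tanh output spans
# [-1, 1] (L = 2), for which the canonical constants would be 4x larger.
# A quick sanity check (a sketch):
#
#   ssim = SSIMLoss()
#   x = torch.rand(2, 3, 64, 64)
#   assert ssim(x, x).item() < 1e-4   # identical images -> loss ~ 0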
class FramePredictionLoss(nn.Module):
"""
Combined loss for frame prediction
"""
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
super().__init__()
self.l1_weight = l1_weight
self.ssim_weight = ssim_weight
self.use_ssim = use_ssim
self.l1_loss = nn.L1Loss()
if use_ssim:
self.ssim_loss = SSIMLoss()
def forward(self, pred, target):
"""
Args:
pred: predicted frame [B, 3, H, W] in range [-1, 1]
target: target frame [B, 3, H, W] in range [-1, 1]
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# L1 loss
l1_loss = self.l1_loss(pred, target)
loss_dict['l1'] = l1_loss
total_loss = self.l1_weight * l1_loss
# SSIM loss
if self.use_ssim:
ssim_loss = self.ssim_loss(pred, target)
loss_dict['ssim'] = ssim_loss
total_loss += self.ssim_weight * ssim_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
class ContrastiveLoss(nn.Module):
"""
Contrastive loss for representation learning
Positive pairs: representations from adjacent frames
Negative pairs: representations from distant frames
"""
    def __init__(self, temperature=0.1, margin=1.0):
        super().__init__()
        self.temperature = temperature
        self.margin = margin
    def forward(self, representations, temporal_indices):
        """
        Args:
            representations: [B, D] representation vectors
            temporal_indices: [B] temporal index of each sample
        Returns:
            contrastive_loss
        """
        # L2-normalize so the dot products below are cosine similarities
        representations = F.normalize(representations, dim=-1)
        sim_matrix = torch.matmul(representations, representations.T) / self.temperature
        # Pairwise temporal distances
        indices_expanded = temporal_indices.unsqueeze(0)
        diff = torch.abs(indices_expanded - indices_expanded.T)
        # Positive pairs: adjacent frames; negative pairs: distant frames
        positive_mask = (diff == 1).float()
        negative_mask = (diff > 2).float()
        # Exclude self-similarity from the softmax denominator
        self_mask = torch.eye(sim_matrix.size(0), dtype=torch.bool, device=sim_matrix.device)
        exp_sim = torch.exp(sim_matrix).masked_fill(self_mask, 0.)
        # InfoNCE-style positive term, averaged over the positive pairs
        log_prob = sim_matrix - torch.log(exp_sim.sum(dim=-1, keepdim=True) + 1e-8)
        pos_loss = -(log_prob * positive_mask).sum() / (positive_mask.sum() + 1e-8)
        # Margin term pushing distant pairs apart
        neg_loss = torch.relu(sim_matrix * negative_mask - self.margin).mean()
        return pos_loss + 0.1 * neg_loss
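# How the masks pair up a toy batch: with temporal_indices = [1, 2, 5],
#
#   diff = |i - j| = [[0, 1, 4],
#                     [1, 0, 3],
#                     [4, 3, 0]]
#
# so (1, 2) forms the positive pair (diff == 1), (1, 5) and (2, 5) are
# negatives (diff > 2), and pairs with diff == 2 fall in neither mask.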
class MultiTaskLoss(nn.Module):
"""
Multi-task loss combining frame prediction and representation learning
"""
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
super().__init__()
self.frame_weight = frame_weight
self.contrastive_weight = contrastive_weight
self.use_contrastive = use_contrastive
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
if use_contrastive:
self.contrastive_loss = ContrastiveLoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
"""
Args:
pred_frame: predicted frame [B, 3, H, W]
target_frame: target frame [B, 3, H, W]
representations: [B, D] representation vectors (optional)
temporal_indices: [B] temporal indices (optional)
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# Frame prediction loss
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
total_loss = self.frame_weight * frame_loss
# Contrastive loss (if representations provided)
if self.use_contrastive and representations is not None and temporal_indices is not None:
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
loss_dict['contrastive'] = contrastive_loss
total_loss += self.contrastive_weight * contrastive_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
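# With the default weights the optimized scalar works out to
#
#   total = frame_weight * (l1_weight * L1 + ssim_weight * SSIM)
#         + contrastive_weight * contrastive
#         = 1.0 * (1.0 * L1 + 0.1 * SSIM) + 0.1 * contrastive
#
# so --frame-weight rescales the L1 and SSIM terms together.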

util/video_dataset.py Normal file (209 additions)

@@ -0,0 +1,209 @@
"""
Video frame dataset for temporal self-supervised learning
"""
import os
import random
from pathlib import Path
from typing import Optional, Tuple, List
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import numpy as np
class VideoFrameDataset(Dataset):
"""
Dataset for loading consecutive frames from videos for frame prediction.
Assumes directory structure:
dataset_root/
video1/
frame_0001.jpg
frame_0002.jpg
...
video2/
...
"""
def __init__(self,
root_dir: str,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True,
max_interval: int = 1,
transform=None):
"""
Args:
root_dir: Root directory containing video folders
num_frames: Number of input frames (T)
frame_size: Size to resize frames to
is_train: Whether this is training set (affects augmentation)
max_interval: Maximum interval between consecutive frames
transform: Optional custom transform
"""
self.root_dir = Path(root_dir)
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
self.max_interval = max_interval
# Collect all video folders
self.video_folders = []
for item in self.root_dir.iterdir():
if item.is_dir():
self.video_folders.append(item)
if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}")
        # Build the frame index as (video_idx, start_frame_idx) pairs, and
        # cache the sorted frame paths per video so __getitem__ does not
        # re-list the directory on every access
        self.frame_indices = []
        self.video_frames = []
        for video_idx, video_folder in enumerate(self.video_folders):
            frame_files = sorted([f for f in video_folder.iterdir()
                                  if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
            self.video_frames.append(frame_files)
            if len(frame_files) < num_frames + 1:
                continue  # Skip videos with insufficient frames
            # Add all possible starting positions
            for start_idx in range(len(frame_files) - num_frames):
                self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0:
raise ValueError("No valid frame sequences found in dataset")
        # Optional user-supplied transform; it is applied to each frame
        # independently, so random geometric augmentations in it will NOT
        # stay aligned across the sequence. The built-in pipeline below
        # samples its random parameters once per sequence instead (the
        # original per-frame ColorJitter is omitted, since independently
        # jittered frames break temporal photometric consistency)
        self.transform = transform
        # Normalization (ImageNet stats)
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    def _apply_shared_transform(self, frames):
        """Spatial augmentation with one set of random parameters shared by
        all frames, so the sequence stays geometrically aligned"""
        if self.is_train:
            i, j, h, w = transforms.RandomResizedCrop.get_params(
                frames[0], scale=(0.8, 1.0), ratio=(3. / 4., 4. / 3.))
            flip = random.random() < 0.5
            out = []
            for frame in frames:
                frame = TF.resized_crop(frame, i, j, h, w,
                                        [self.frame_size, self.frame_size])
                if flip:
                    frame = TF.hflip(frame)
                out.append(frame)
            return out
        out = []
        for frame in frames:
            frame = TF.resize(frame, int(self.frame_size * 1.14))
            frame = TF.center_crop(frame, self.frame_size)
            out.append(frame)
        return out
    def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
        """Load a single frame as a PIL image from the cached path list"""
        frame_path = self.video_frames[video_idx][frame_idx]
        return Image.open(frame_path).convert('RGB')
def __len__(self) -> int:
return len(self.frame_indices)
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns:
            input_frames: [3 * num_frames, H, W] concatenated input frames
            target_frame: [3, H, W] target frame to predict, in [-1, 1]
            temporal_idx: temporal index of the target frame (for contrastive loss)
        """
        video_idx, start_idx = self.frame_indices[idx]
        # Frame interval (temporal augmentation); clamp it so the target
        # index stays inside the video
        num_remaining = len(self.video_frames[video_idx]) - 1 - start_idx
        max_interval = max(1, min(self.max_interval, num_remaining // self.num_frames))
        interval = random.randint(1, max_interval) if self.is_train else 1
        # Load the input frames plus the target frame (the frame right after
        # the input sequence) as PIL images
        frames = [self._load_frame(video_idx, start_idx + i * interval)
                  for i in range(self.num_frames)]
        frames.append(self._load_frame(video_idx, start_idx + self.num_frames * interval))
        # Augment: a user-supplied transform runs per frame; the built-in
        # pipeline shares its random parameters across the whole sequence
        if self.transform is not None:
            frames = [self.transform(frame) for frame in frames]
        else:
            frames = self._apply_shared_transform(frames)
        # Inputs: ImageNet normalization for the encoder
        input_tensors = [self.normalize(transforms.ToTensor()(frame))
                         for frame in frames[:-1]]
        # Target: scale to [-1, 1] to match the decoder's tanh output
        # (see FramePredictionLoss)
        target_tensor = transforms.ToTensor()(frames[-1]) * 2.0 - 1.0
        # Concatenate input frames along the channel dimension
        input_concatenated = torch.cat(input_tensors, dim=0)
        # Temporal index of the target frame (for contrastive loss)
        temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
        return input_concatenated, target_tensor, temporal_idx
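# A sample-level sketch of what __getitem__ returns under the defaults
# (num_frames=3, frame_size=224), assuming './videos' holds frame folders:
#
#   ds = VideoFrameDataset('./videos', num_frames=3)
#   inputs, target, t_idx = ds[0]
#   # inputs: [9, 224, 224]  ImageNet-normalized RGB frames, concatenated
#   # target: [3, 224, 224]  scaled to [-1, 1] to match the tanh decoder
#   # t_idx:  tensor(3)      the target's offset within the sampled clip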
class SyntheticVideoDataset(Dataset):
"""
Synthetic dataset for testing - generates random frames
"""
def __init__(self,
num_samples: int = 1000,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True):
self.num_samples = num_samples
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
# Normalization
self.normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Generate random "frames" (noise with temporal correlation)
input_frames = []
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
for i in range(self.num_frames):
# Add some temporal correlation
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
frame = torch.clamp(frame, -1, 1)
input_frames.append(self.normalize(frame))
prev_frame = frame
        # Target frame (next in sequence); keep it in [-1, 1] to match the
        # decoder's tanh output rather than re-normalizing it
        target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
        target_tensor = torch.clamp(target_frame, -1, 1)
# Concatenate inputs
input_concatenated = torch.cat(input_frames, dim=0)
# Temporal index
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
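# An end-to-end sketch (assuming the shapes used throughout this commit):
# one forward/backward step wiring the synthetic dataset, the temporal
# model, and the multi-task loss together.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from models.swiftformer_temporal import SwiftFormerTemporal_XS
    from util.frame_losses import MultiTaskLoss

    ds = SyntheticVideoDataset(num_samples=8, num_frames=3, frame_size=224)
    loader = DataLoader(ds, batch_size=4)
    model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
    criterion = MultiTaskLoss()
    inputs, targets, t_idx = next(iter(loader))
    pred, rep = model(inputs)
    loss, parts = criterion(pred, targets, rep, t_idx)
    loss.backward()
    print({k: float(v) for k, v in parts.items()})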