test modify swiftformer to temporal input
This commit is contained in:
182
util/frame_losses.py
Normal file
182
util/frame_losses.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Loss functions for frame prediction and representation learning
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
class SSIMLoss(nn.Module):
|
||||
"""
|
||||
Structural Similarity Index Measure Loss
|
||||
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
|
||||
"""
|
||||
def __init__(self, window_size=11, size_average=True):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.size_average = size_average
|
||||
self.channel = 3
|
||||
self.window = self.create_window(window_size, self.channel)
|
||||
|
||||
def create_window(self, window_size, channel):
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
||||
return gauss/gauss.sum()
|
||||
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
||||
return window
|
||||
|
||||
def forward(self, img1, img2):
|
||||
# Ensure window is on correct device
|
||||
if self.window.device != img1.device:
|
||||
self.window = self.window.to(img1.device)
|
||||
|
||||
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
|
||||
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
|
||||
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
|
||||
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
|
||||
|
||||
C1 = 0.01**2
|
||||
C2 = 0.03**2
|
||||
|
||||
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
|
||||
|
||||
if self.size_average:
|
||||
return 1 - ssim_map.mean()
|
||||
else:
|
||||
return 1 - ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
|
||||
class FramePredictionLoss(nn.Module):
|
||||
"""
|
||||
Combined loss for frame prediction
|
||||
"""
|
||||
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
|
||||
super().__init__()
|
||||
self.l1_weight = l1_weight
|
||||
self.ssim_weight = ssim_weight
|
||||
self.use_ssim = use_ssim
|
||||
|
||||
self.l1_loss = nn.L1Loss()
|
||||
if use_ssim:
|
||||
self.ssim_loss = SSIMLoss()
|
||||
|
||||
def forward(self, pred, target):
|
||||
"""
|
||||
Args:
|
||||
pred: predicted frame [B, 3, H, W] in range [-1, 1]
|
||||
target: target frame [B, 3, H, W] in range [-1, 1]
|
||||
Returns:
|
||||
total_loss, loss_dict
|
||||
"""
|
||||
loss_dict = {}
|
||||
|
||||
# L1 loss
|
||||
l1_loss = self.l1_loss(pred, target)
|
||||
loss_dict['l1'] = l1_loss
|
||||
total_loss = self.l1_weight * l1_loss
|
||||
|
||||
# SSIM loss
|
||||
if self.use_ssim:
|
||||
ssim_loss = self.ssim_loss(pred, target)
|
||||
loss_dict['ssim'] = ssim_loss
|
||||
total_loss += self.ssim_weight * ssim_loss
|
||||
|
||||
loss_dict['total'] = total_loss
|
||||
return total_loss, loss_dict
|
||||
|
||||
|
||||
class ContrastiveLoss(nn.Module):
|
||||
"""
|
||||
Contrastive loss for representation learning
|
||||
Positive pairs: representations from adjacent frames
|
||||
Negative pairs: representations from distant frames
|
||||
"""
|
||||
def __init__(self, temperature=0.1, margin=1.0):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.margin = margin
|
||||
self.cosine_similarity = nn.CosineSimilarity(dim=-1)
|
||||
|
||||
def forward(self, representations, temporal_indices):
|
||||
"""
|
||||
Args:
|
||||
representations: [B, D] representation vectors
|
||||
temporal_indices: [B] temporal indices of each sample
|
||||
Returns:
|
||||
contrastive_loss
|
||||
"""
|
||||
batch_size = representations.size(0)
|
||||
|
||||
# Compute similarity matrix
|
||||
sim_matrix = torch.matmul(representations, representations.T) / self.temperature
|
||||
|
||||
# Create positive mask (adjacent frames)
|
||||
indices_expanded = temporal_indices.unsqueeze(0)
|
||||
diff = torch.abs(indices_expanded - indices_expanded.T)
|
||||
positive_mask = (diff == 1).float()
|
||||
|
||||
# Create negative mask (distant frames)
|
||||
negative_mask = (diff > 2).float()
|
||||
|
||||
# Positive loss
|
||||
pos_sim = sim_matrix * positive_mask
|
||||
pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
|
||||
pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
|
||||
|
||||
# Negative loss (push apart)
|
||||
neg_sim = sim_matrix * negative_mask
|
||||
neg_loss = torch.relu(neg_sim - self.margin).mean()
|
||||
|
||||
return pos_loss + 0.1 * neg_loss
|
||||
|
||||
|
||||
class MultiTaskLoss(nn.Module):
|
||||
"""
|
||||
Multi-task loss combining frame prediction and representation learning
|
||||
"""
|
||||
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
|
||||
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
|
||||
super().__init__()
|
||||
self.frame_weight = frame_weight
|
||||
self.contrastive_weight = contrastive_weight
|
||||
self.use_contrastive = use_contrastive
|
||||
|
||||
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
|
||||
if use_contrastive:
|
||||
self.contrastive_loss = ContrastiveLoss()
|
||||
|
||||
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
|
||||
"""
|
||||
Args:
|
||||
pred_frame: predicted frame [B, 3, H, W]
|
||||
target_frame: target frame [B, 3, H, W]
|
||||
representations: [B, D] representation vectors (optional)
|
||||
temporal_indices: [B] temporal indices (optional)
|
||||
Returns:
|
||||
total_loss, loss_dict
|
||||
"""
|
||||
loss_dict = {}
|
||||
|
||||
# Frame prediction loss
|
||||
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
|
||||
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
|
||||
total_loss = self.frame_weight * frame_loss
|
||||
|
||||
# Contrastive loss (if representations provided)
|
||||
if self.use_contrastive and representations is not None and temporal_indices is not None:
|
||||
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
|
||||
loss_dict['contrastive'] = contrastive_loss
|
||||
total_loss += self.contrastive_weight * contrastive_loss
|
||||
|
||||
loss_dict['total'] = total_loss
|
||||
return total_loss, loss_dict
|
||||
209
util/video_dataset.py
Normal file
209
util/video_dataset.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Video frame dataset for temporal self-supervised learning
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class VideoFrameDataset(Dataset):
|
||||
"""
|
||||
Dataset for loading consecutive frames from videos for frame prediction.
|
||||
|
||||
Assumes directory structure:
|
||||
dataset_root/
|
||||
video1/
|
||||
frame_0001.jpg
|
||||
frame_0002.jpg
|
||||
...
|
||||
video2/
|
||||
...
|
||||
"""
|
||||
def __init__(self,
|
||||
root_dir: str,
|
||||
num_frames: int = 3,
|
||||
frame_size: int = 224,
|
||||
is_train: bool = True,
|
||||
max_interval: int = 1,
|
||||
transform=None):
|
||||
"""
|
||||
Args:
|
||||
root_dir: Root directory containing video folders
|
||||
num_frames: Number of input frames (T)
|
||||
frame_size: Size to resize frames to
|
||||
is_train: Whether this is training set (affects augmentation)
|
||||
max_interval: Maximum interval between consecutive frames
|
||||
transform: Optional custom transform
|
||||
"""
|
||||
self.root_dir = Path(root_dir)
|
||||
self.num_frames = num_frames
|
||||
self.frame_size = frame_size
|
||||
self.is_train = is_train
|
||||
self.max_interval = max_interval
|
||||
|
||||
# Collect all video folders
|
||||
self.video_folders = []
|
||||
for item in self.root_dir.iterdir():
|
||||
if item.is_dir():
|
||||
self.video_folders.append(item)
|
||||
|
||||
if len(self.video_folders) == 0:
|
||||
raise ValueError(f"No video folders found in {root_dir}")
|
||||
|
||||
# Build frame index: list of (video_idx, start_frame_idx)
|
||||
self.frame_indices = []
|
||||
for video_idx, video_folder in enumerate(self.video_folders):
|
||||
# Get all frame files
|
||||
frame_files = sorted([f for f in video_folder.iterdir()
|
||||
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
|
||||
|
||||
if len(frame_files) < num_frames + 1:
|
||||
continue # Skip videos with insufficient frames
|
||||
|
||||
# Add all possible starting positions
|
||||
for start_idx in range(len(frame_files) - num_frames):
|
||||
self.frame_indices.append((video_idx, start_idx))
|
||||
|
||||
if len(self.frame_indices) == 0:
|
||||
raise ValueError("No valid frame sequences found in dataset")
|
||||
|
||||
# Default transforms
|
||||
if transform is None:
|
||||
self.transform = self._default_transform()
|
||||
else:
|
||||
self.transform = transform
|
||||
|
||||
# Normalization (ImageNet stats)
|
||||
self.normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
def _default_transform(self):
|
||||
"""Default transform with augmentation for training"""
|
||||
if self.is_train:
|
||||
return transforms.Compose([
|
||||
transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
|
||||
])
|
||||
else:
|
||||
return transforms.Compose([
|
||||
transforms.Resize(int(self.frame_size * 1.14)),
|
||||
transforms.CenterCrop(self.frame_size),
|
||||
])
|
||||
|
||||
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
|
||||
"""Load a single frame as PIL Image"""
|
||||
video_folder = self.video_folders[video_idx]
|
||||
frame_files = sorted([f for f in video_folder.iterdir()
|
||||
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
|
||||
frame_path = frame_files[frame_idx]
|
||||
return Image.open(frame_path).convert('RGB')
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.frame_indices)
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Returns:
|
||||
input_frames: [3 * num_frames, H, W] concatenated input frames
|
||||
target_frame: [3, H, W] target frame to predict
|
||||
temporal_idx: temporal index of target frame (for contrastive loss)
|
||||
"""
|
||||
video_idx, start_idx = self.frame_indices[idx]
|
||||
|
||||
# Determine frame interval (for temporal augmentation)
|
||||
interval = random.randint(1, self.max_interval) if self.is_train else 1
|
||||
|
||||
# Load input frames
|
||||
input_frames = []
|
||||
for i in range(self.num_frames):
|
||||
frame_idx = start_idx + i * interval
|
||||
frame = self._load_frame(video_idx, frame_idx)
|
||||
|
||||
# Apply transform (same for all frames in sequence)
|
||||
if self.transform:
|
||||
frame = self.transform(frame)
|
||||
|
||||
input_frames.append(frame)
|
||||
|
||||
# Load target frame (next frame after input sequence)
|
||||
target_idx = start_idx + self.num_frames * interval
|
||||
target_frame = self._load_frame(video_idx, target_idx)
|
||||
if self.transform:
|
||||
target_frame = self.transform(target_frame)
|
||||
|
||||
# Convert to tensors and normalize
|
||||
input_tensors = []
|
||||
for frame in input_frames:
|
||||
tensor = transforms.ToTensor()(frame)
|
||||
tensor = self.normalize(tensor)
|
||||
input_tensors.append(tensor)
|
||||
|
||||
target_tensor = transforms.ToTensor()(target_frame)
|
||||
target_tensor = self.normalize(target_tensor)
|
||||
|
||||
# Concatenate input frames along channel dimension
|
||||
input_concatenated = torch.cat(input_tensors, dim=0)
|
||||
|
||||
# Temporal index (for contrastive loss)
|
||||
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
|
||||
|
||||
return input_concatenated, target_tensor, temporal_idx
|
||||
|
||||
|
||||
class SyntheticVideoDataset(Dataset):
|
||||
"""
|
||||
Synthetic dataset for testing - generates random frames
|
||||
"""
|
||||
def __init__(self,
|
||||
num_samples: int = 1000,
|
||||
num_frames: int = 3,
|
||||
frame_size: int = 224,
|
||||
is_train: bool = True):
|
||||
self.num_samples = num_samples
|
||||
self.num_frames = num_frames
|
||||
self.frame_size = frame_size
|
||||
self.is_train = is_train
|
||||
|
||||
# Normalization
|
||||
self.normalize = transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225]
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Generate random "frames" (noise with temporal correlation)
|
||||
input_frames = []
|
||||
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
|
||||
|
||||
for i in range(self.num_frames):
|
||||
# Add some temporal correlation
|
||||
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||
frame = torch.clamp(frame, -1, 1)
|
||||
input_frames.append(self.normalize(frame))
|
||||
prev_frame = frame
|
||||
|
||||
# Target frame (next in sequence)
|
||||
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||
target_frame = torch.clamp(target_frame, -1, 1)
|
||||
target_tensor = self.normalize(target_frame)
|
||||
|
||||
# Concatenate inputs
|
||||
input_concatenated = torch.cat(input_frames, dim=0)
|
||||
|
||||
# Temporal index
|
||||
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
|
||||
|
||||
return input_concatenated, target_tensor, temporal_idx
|
||||
Reference in New Issue
Block a user