""" Video frame dataset for temporal self-supervised learning """ import os import random from pathlib import Path from typing import Optional, Tuple, List import torch from torch.utils.data import Dataset from torchvision import transforms from PIL import Image import numpy as np class VideoFrameDataset(Dataset): """ Dataset for loading consecutive frames from videos for frame prediction. Assumes directory structure: dataset_root/ video1/ frame_0001.jpg frame_0002.jpg ... video2/ ... """ def __init__(self, root_dir: str, num_frames: int = 3, frame_size: int = 224, is_train: bool = True, max_interval: int = 1, transform=None): """ Args: root_dir: Root directory containing video folders num_frames: Number of input frames (T) frame_size: Size to resize frames to is_train: Whether this is training set (affects augmentation) max_interval: Maximum interval between consecutive frames transform: Optional custom transform """ self.root_dir = Path(root_dir) self.num_frames = num_frames self.frame_size = frame_size self.is_train = is_train self.max_interval = max_interval # Collect all video folders self.video_folders = [] for item in self.root_dir.iterdir(): if item.is_dir(): self.video_folders.append(item) if len(self.video_folders) == 0: raise ValueError(f"No video folders found in {root_dir}") # Build frame index: list of (video_idx, start_frame_idx) self.frame_indices = [] for video_idx, video_folder in enumerate(self.video_folders): # Get all frame files frame_files = sorted([f for f in video_folder.iterdir() if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) if len(frame_files) < num_frames + 1: continue # Skip videos with insufficient frames # Add all possible starting positions for start_idx in range(len(frame_files) - num_frames): self.frame_indices.append((video_idx, start_idx)) if len(self.frame_indices) == 0: raise ValueError("No valid frame sequences found in dataset") # Default transforms if transform is None: self.transform = self._default_transform() else: self.transform = transform # Normalization for Y channel (single channel) # Compute average of ImageNet RGB means and stds y_mean = (0.485 + 0.456 + 0.406) / 3.0 y_std = (0.229 + 0.224 + 0.225) / 3.0 self.normalize = transforms.Normalize( mean=[y_mean], std=[y_std] ) def _default_transform(self): """Default transform with augmentation for training""" if self.is_train: return transforms.Compose([ transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), ]) else: return transforms.Compose([ transforms.Resize(int(self.frame_size * 1.14)), transforms.CenterCrop(self.frame_size), ]) def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image: """Load a single frame as PIL Image""" video_folder = self.video_folders[video_idx] frame_files = sorted([f for f in video_folder.iterdir() if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) frame_path = frame_files[frame_idx] return Image.open(frame_path).convert('RGB') def __len__(self) -> int: return len(self.frame_indices) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: input_frames: [num_frames, H, W] concatenated input frames (Y channel only) target_frame: [1, H, W] target frame to predict (Y channel only) temporal_idx: temporal index of target frame (for contrastive loss) """ video_idx, start_idx = self.frame_indices[idx] # Determine frame interval (for temporal augmentation) interval = 
random.randint(1, self.max_interval) if self.is_train else 1 # Load input frames input_frames = [] for i in range(self.num_frames): frame_idx = start_idx + i * interval frame = self._load_frame(video_idx, frame_idx) # Apply transform (same for all frames in sequence) if self.transform: frame = self.transform(frame) input_frames.append(frame) # Load target frame (next frame after input sequence) target_idx = start_idx + self.num_frames * interval target_frame = self._load_frame(video_idx, target_idx) if self.transform: target_frame = self.transform(target_frame) # Convert to tensors, normalize, and convert to grayscale (Y channel) input_tensors = [] for frame in input_frames: tensor = transforms.ToTensor()(frame) # [3, H, W] # Convert RGB to grayscale using weighted sum # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL) gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W] gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast) input_tensors.append(gray) target_tensor = transforms.ToTensor()(target_frame) # [3, H, W] target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0) target_gray = self.normalize(target_gray) # Concatenate input frames along channel dimension input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W] # Temporal index (for contrastive loss) temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) return input_concatenated, target_gray, temporal_idx class SyntheticVideoDataset(Dataset): """ Synthetic dataset for testing - generates random frames """ def __init__(self, num_samples: int = 1000, num_frames: int = 3, frame_size: int = 224, is_train: bool = True): self.num_samples = num_samples self.num_frames = num_frames self.frame_size = frame_size self.is_train = is_train # Normalization for Y channel (single channel) y_mean = (0.485 + 0.456 + 0.406) / 3.0 y_std = (0.229 + 0.224 + 0.225) / 3.0 self.normalize = transforms.Normalize( mean=[y_mean], std=[y_std] ) def __len__(self): return self.num_samples def __getitem__(self, idx): # Generate random "frames" (noise with temporal correlation) input_frames = [] prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1 for i in range(self.num_frames): # Add some temporal correlation frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 frame = torch.clamp(frame, -1, 1) input_frames.append(self.normalize(frame)) prev_frame = frame # Target frame (next in sequence) target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 target_frame = torch.clamp(target_frame, -1, 1) target_tensor = self.normalize(target_frame) # Concatenate inputs input_concatenated = torch.cat(input_frames, dim=0) # Temporal index temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) return input_concatenated, target_tensor, temporal_idx
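

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# The batch size, frame size, and the "data/frames" path below are
# placeholder assumptions; SyntheticVideoDataset needs no files on disk, so
# it serves as the quickest smoke test of the output shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = SyntheticVideoDataset(num_samples=8, num_frames=3, frame_size=64)
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    inputs, target, temporal_idx = next(iter(loader))
    print(inputs.shape)        # torch.Size([4, 3, 64, 64]) - num_frames channels per sample
    print(target.shape)        # torch.Size([4, 1, 64, 64]) - single Y-channel target
    print(temporal_idx)        # tensor([3, 3, 3, 3])

    # For real data, point VideoFrameDataset at a root of per-video frame folders:
    # dataset = VideoFrameDataset("data/frames", num_frames=3, frame_size=224,
    #                             is_train=True, max_interval=2)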