""" Video frame dataset for temporal self-supervised learning """ import os import random from pathlib import Path from typing import Optional, Tuple, List import torch from torch.utils.data import Dataset from torchvision import transforms from PIL import Image import numpy as np class VideoFrameDataset(Dataset): """ Dataset for loading consecutive frames from videos for frame prediction. Assumes directory structure: dataset_root/ video1/ frame_0001.jpg frame_0002.jpg ... video2/ ... """ def __init__(self, root_dir: str, num_frames: int = 3, frame_size: int = 224, is_train: bool = True, max_interval: int = 1, transform=None): """ Args: root_dir: Root directory containing video folders num_frames: Number of input frames (T) frame_size: Size to resize frames to is_train: Whether this is training set (affects augmentation) max_interval: Maximum interval between consecutive frames transform: Optional custom transform """ self.root_dir = Path(root_dir) self.num_frames = num_frames self.frame_size = frame_size self.is_train = is_train self.max_interval = max_interval # if num_frames < 1: # raise ValueError("num_frames must be >= 1") # if frame_size < 1: # raise ValueError("frame_size must be >= 1") # if max_interval < 1: # raise ValueError("max_interval must be >= 1") # Collect all video folders and their frame files self.video_folders = [] self.video_frame_files = [] # list of list of Path objects for item in self.root_dir.iterdir(): if item.is_dir(): self.video_folders.append(item) # Get all frame files frame_files = sorted([f for f in item.iterdir() if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) self.video_frame_files.append(frame_files) if len(self.video_folders) == 0: raise ValueError(f"No video folders found in {root_dir}") # Build frame index: list of (video_idx, start_frame_idx) self.frame_indices = [] for video_idx, frame_files in enumerate(self.video_frame_files): # Minimum frames needed considering max interval min_frames_needed = num_frames * max_interval + 1 if len(frame_files) < min_frames_needed: continue # Skip videos with insufficient frames # Add all possible starting positions # Ensure that for any interval up to max_interval, all frames are within bounds max_start = len(frame_files) - num_frames * max_interval for start_idx in range(max_start): self.frame_indices.append((video_idx, start_idx)) if len(self.frame_indices) == 0: raise ValueError("No valid frame sequences found in dataset") # Default transforms if transform is None: self.transform = self._default_transform() else: self.transform = transform # Simple normalization to [-1, 1] range (不使用ImageNet标准化) # Convert pixel values [0, 255] to [-1, 1] # This matches the model's tanh output range self.normalize = None # We'll handle normalization manually # print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]") def _default_transform(self): """Default transform with augmentation for training""" if self.is_train: return transforms.Compose([ transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)), transforms.RandomHorizontalFlip(), transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), ]) else: return transforms.Compose([ transforms.Resize(int(self.frame_size * 1.14)), transforms.CenterCrop(self.frame_size), ]) def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image: """Load a single frame as PIL Image""" frame_files = self.video_frame_files[video_idx] if frame_idx < 0 or frame_idx >= len(frame_files): raise IndexError( f"Frame index {frame_idx} out of range for video 
{video_idx} " f"(0-{len(frame_files)-1})" ) frame_path = frame_files[frame_idx] return Image.open(frame_path).convert('RGB') def __len__(self) -> int: return len(self.frame_indices) def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: input_frames: [num_frames, H, W] concatenated input frames (Y channel only) target_frame: [1, H, W] target frame to predict (Y channel only) temporal_idx: temporal index of target frame (for contrastive loss) """ video_idx, start_idx = self.frame_indices[idx] # Determine frame interval (for temporal augmentation) interval = random.randint(1, self.max_interval) if self.is_train else 1 # Load input frames input_frames = [] for i in range(self.num_frames): frame_idx = start_idx + i * interval frame = self._load_frame(video_idx, frame_idx) # Apply transform (same for all frames in sequence) if self.transform: frame = self.transform(frame) input_frames.append(frame) # Load target frame (next frame after input sequence) target_idx = start_idx + self.num_frames * interval target_frame = self._load_frame(video_idx, target_idx) if self.transform: target_frame = self.transform(target_frame) # Convert to tensors and convert to grayscale (Y channel) input_tensors = [] for frame in input_frames: tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1] # Convert RGB to grayscale using weighted sum # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL) gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1] # Normalize from [0, 1] to [-1, 1] gray = gray * 2 - 1 # [0,1] -> [-1,1] input_tensors.append(gray) target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1] target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0) # Normalize from [0, 1] to [-1, 1] target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1] # Concatenate input frames along channel dimension input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W] # Temporal index (for contrastive loss) temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) return input_concatenated, target_gray, temporal_idx # class SyntheticVideoDataset(Dataset): # """ # Synthetic dataset for testing - generates random frames # """ # def __init__(self, # num_samples: int = 1000, # num_frames: int = 3, # frame_size: int = 224, # is_train: bool = True): # self.num_samples = num_samples # self.num_frames = num_frames # self.frame_size = frame_size # self.is_train = is_train # # Normalization for Y channel (single channel) # y_mean = (0.485 + 0.456 + 0.406) / 3.0 # y_std = (0.229 + 0.224 + 0.225) / 3.0 # self.normalize = transforms.Normalize( # mean=[y_mean], # std=[y_std] # ) # def __len__(self): # return self.num_samples # def __getitem__(self, idx): # # Generate random "frames" (noise with temporal correlation) # input_frames = [] # prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1 # for i in range(self.num_frames): # # Add some temporal correlation # frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 # frame = torch.clamp(frame, -1, 1) # input_frames.append(self.normalize(frame)) # prev_frame = frame # # Target frame (next in sequence) # target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 # target_frame = torch.clamp(target_frame, -1, 1) # target_tensor = self.normalize(target_frame) # # Concatenate inputs # input_concatenated = torch.cat(input_frames, dim=0) # # 

# class SyntheticVideoDataset(Dataset):
#     """
#     Synthetic dataset for testing - generates random frames
#     """
#
#     def __init__(self,
#                  num_samples: int = 1000,
#                  num_frames: int = 3,
#                  frame_size: int = 224,
#                  is_train: bool = True):
#         self.num_samples = num_samples
#         self.num_frames = num_frames
#         self.frame_size = frame_size
#         self.is_train = is_train
#
#         # Normalization for Y channel (single channel)
#         y_mean = (0.485 + 0.456 + 0.406) / 3.0
#         y_std = (0.229 + 0.224 + 0.225) / 3.0
#         self.normalize = transforms.Normalize(
#             mean=[y_mean],
#             std=[y_std]
#         )
#
#     def __len__(self):
#         return self.num_samples
#
#     def __getitem__(self, idx):
#         # Generate random "frames" (noise with temporal correlation)
#         input_frames = []
#         prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
#         for i in range(self.num_frames):
#             # Add some temporal correlation
#             frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
#             frame = torch.clamp(frame, -1, 1)
#             input_frames.append(self.normalize(frame))
#             prev_frame = frame
#
#         # Target frame (next in sequence)
#         target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
#         target_frame = torch.clamp(target_frame, -1, 1)
#         target_tensor = self.normalize(target_frame)
#
#         # Concatenate inputs
#         input_concatenated = torch.cat(input_frames, dim=0)
#
#         # Temporal index
#         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
#
#         return input_concatenated, target_tensor, temporal_idx
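
# Minimal usage sketch (assumptions: the dataset root path, batch size, and worker count
# below are placeholders, and a standard torch DataLoader is used; adapt to the actual
# training script).
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = VideoFrameDataset(
        root_dir="data/videos",  # assumed path with one sub-folder of frames per video
        num_frames=3,
        frame_size=224,
        is_train=True,
        max_interval=2,
    )
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

    inputs, target, temporal_idx = next(iter(loader))
    # inputs: [B, num_frames, H, W], target: [B, 1, H, W], temporal_idx: [B]
    print(inputs.shape, target.shape, temporal_idx.shape)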