From 500c2eb18f8503af67bed1f3380bdd6ba489079e Mon Sep 17 00:00:00 2001 From: CaoWangrenbo Date: Thu, 8 Jan 2026 16:10:24 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E5=BD=92=E4=B8=80=E5=8C=96?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=EF=BC=8C=E5=BD=93=E5=89=8D=E7=9B=B4=E6=8E=A5?= =?UTF-8?q?=E6=98=A0=E5=B0=84=EF=BC=8C=E4=B8=8D=E5=88=A9=E7=94=A8=E5=9D=87?= =?UTF-8?q?=E5=80=BC=E6=A0=87=E5=87=86=E5=B7=AE=E8=BF=9B=E8=A1=8C=E6=A0=87?= =?UTF-8?q?=E5=87=86=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dist_temporal_train.sh | 2 +- main_temporal.py | 14 ++-- util/video_dataset.py | 147 +++++++++++++++++++++++------------------ 3 files changed, 89 insertions(+), 74 deletions(-) diff --git a/dist_temporal_train.sh b/dist_temporal_train.sh index 1c35ec6..ce10ceb 100755 --- a/dist_temporal_train.sh +++ b/dist_temporal_train.sh @@ -11,7 +11,7 @@ shift 2 # Default parameters MODEL=${MODEL:-"SwiftFormerTemporal_XS"} -BATCH_SIZE=${BATCH_SIZE:-32} +BATCH_SIZE=${BATCH_SIZE:-256} EPOCHS=${EPOCHS:-100} LR=${LR:-1e-3} OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"} diff --git a/main_temporal.py b/main_temporal.py index 56f0a98..d1fab34 100644 --- a/main_temporal.py +++ b/main_temporal.py @@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma from util import * from models import * from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3 -from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset +from util.video_dataset import VideoFrameDataset from util.frame_losses import MultiTaskLoss # Try to import TensorBoard @@ -47,7 +47,7 @@ def get_args_parser(): help='Number of input frames (T)') parser.add_argument('--frame-size', default=224, type=int, help='Input frame size') - parser.add_argument('--max-interval', default=1, type=int, + parser.add_argument('--max-interval', default=4, type=int, help='Maximum interval between consecutive frames') # Model parameters @@ -109,10 +109,10 @@ def get_args_parser(): help='Weight for frame prediction loss') parser.add_argument('--contrastive-weight', type=float, default=0.1, help='Weight for contrastive loss') - parser.add_argument('--l1-weight', type=float, default=1.0, - help='Weight for L1 loss') - parser.add_argument('--ssim-weight', type=float, default=0.1, - help='Weight for SSIM loss') + # parser.add_argument('--l1-weight', type=float, default=1.0, + # help='Weight for L1 loss') + # parser.add_argument('--ssim-weight', type=float, default=0.1, + # help='Weight for SSIM loss') parser.add_argument('--no-contrastive', action='store_true', help='Disable contrastive loss') parser.add_argument('--no-ssim', action='store_true', @@ -326,7 +326,7 @@ def main(args): lr_scheduler.step(epoch) # Save checkpoint - if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1): + if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1): checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth' utils.save_on_master({ 'model': model_without_ddp.state_dict(), diff --git a/util/video_dataset.py b/util/video_dataset.py index 50ce612..1823f9d 100644 --- a/util/video_dataset.py +++ b/util/video_dataset.py @@ -47,28 +47,40 @@ class VideoFrameDataset(Dataset): self.frame_size = frame_size self.is_train = is_train self.max_interval = max_interval + + # if num_frames < 1: + # raise ValueError("num_frames must be >= 1") + # if frame_size < 1: + # raise ValueError("frame_size must be >= 1") + # if max_interval < 1: + # raise ValueError("max_interval must be >= 1") - # Collect all video folders + # Collect all video folders and their frame files self.video_folders = [] + self.video_frame_files = [] # list of list of Path objects for item in self.root_dir.iterdir(): if item.is_dir(): self.video_folders.append(item) + # Get all frame files + frame_files = sorted([f for f in item.iterdir() + if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) + self.video_frame_files.append(frame_files) if len(self.video_folders) == 0: raise ValueError(f"No video folders found in {root_dir}") # Build frame index: list of (video_idx, start_frame_idx) self.frame_indices = [] - for video_idx, video_folder in enumerate(self.video_folders): - # Get all frame files - frame_files = sorted([f for f in video_folder.iterdir() - if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) - - if len(frame_files) < num_frames + 1: + for video_idx, frame_files in enumerate(self.video_frame_files): + # Minimum frames needed considering max interval + min_frames_needed = num_frames * max_interval + 1 + if len(frame_files) < min_frames_needed: continue # Skip videos with insufficient frames # Add all possible starting positions - for start_idx in range(len(frame_files) - num_frames): + # Ensure that for any interval up to max_interval, all frames are within bounds + max_start = len(frame_files) - num_frames * max_interval + for start_idx in range(max_start): self.frame_indices.append((video_idx, start_idx)) if len(self.frame_indices) == 0: @@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset): else: self.transform = transform - # Normalization for Y channel (single channel) - # Compute average of ImageNet RGB means and stds - y_mean = (0.485 + 0.456 + 0.406) / 3.0 - y_std = (0.229 + 0.224 + 0.225) / 3.0 - self.normalize = transforms.Normalize( - mean=[y_mean], - std=[y_std] - ) + # Simple normalization to [-1, 1] range (不使用ImageNet标准化) + # Convert pixel values [0, 255] to [-1, 1] + # This matches the model's tanh output range + self.normalize = None # We'll handle normalization manually + + # print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]") def _default_transform(self): """Default transform with augmentation for training""" @@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset): def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image: """Load a single frame as PIL Image""" - video_folder = self.video_folders[video_idx] - frame_files = sorted([f for f in video_folder.iterdir() - if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']]) + frame_files = self.video_frame_files[video_idx] + if frame_idx < 0 or frame_idx >= len(frame_files): + raise IndexError( + f"Frame index {frame_idx} out of range for video {video_idx} " + f"(0-{len(frame_files)-1})" + ) frame_path = frame_files[frame_idx] return Image.open(frame_path).convert('RGB') @@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset): if self.transform: target_frame = self.transform(target_frame) - # Convert to tensors, normalize, and convert to grayscale (Y channel) + # Convert to tensors and convert to grayscale (Y channel) input_tensors = [] for frame in input_frames: - tensor = transforms.ToTensor()(frame) # [3, H, W] + tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1] # Convert RGB to grayscale using weighted sum # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL) - gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W] - gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast) + gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1] + # Normalize from [0, 1] to [-1, 1] + gray = gray * 2 - 1 # [0,1] -> [-1,1] input_tensors.append(gray) - target_tensor = transforms.ToTensor()(target_frame) # [3, H, W] + target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1] target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0) - target_gray = self.normalize(target_gray) + # Normalize from [0, 1] to [-1, 1] + target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1] # Concatenate input frames along channel dimension input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W] @@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset): return input_concatenated, target_gray, temporal_idx -class SyntheticVideoDataset(Dataset): - """ - Synthetic dataset for testing - generates random frames - """ - def __init__(self, - num_samples: int = 1000, - num_frames: int = 3, - frame_size: int = 224, - is_train: bool = True): - self.num_samples = num_samples - self.num_frames = num_frames - self.frame_size = frame_size - self.is_train = is_train +# class SyntheticVideoDataset(Dataset): +# """ +# Synthetic dataset for testing - generates random frames +# """ +# def __init__(self, +# num_samples: int = 1000, +# num_frames: int = 3, +# frame_size: int = 224, +# is_train: bool = True): +# self.num_samples = num_samples +# self.num_frames = num_frames +# self.frame_size = frame_size +# self.is_train = is_train - # Normalization for Y channel (single channel) - y_mean = (0.485 + 0.456 + 0.406) / 3.0 - y_std = (0.229 + 0.224 + 0.225) / 3.0 - self.normalize = transforms.Normalize( - mean=[y_mean], - std=[y_std] - ) +# # Normalization for Y channel (single channel) +# y_mean = (0.485 + 0.456 + 0.406) / 3.0 +# y_std = (0.229 + 0.224 + 0.225) / 3.0 +# self.normalize = transforms.Normalize( +# mean=[y_mean], +# std=[y_std] +# ) - def __len__(self): - return self.num_samples +# def __len__(self): +# return self.num_samples - def __getitem__(self, idx): - # Generate random "frames" (noise with temporal correlation) - input_frames = [] - prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1 +# def __getitem__(self, idx): +# # Generate random "frames" (noise with temporal correlation) +# input_frames = [] +# prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1 - for i in range(self.num_frames): - # Add some temporal correlation - frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 - frame = torch.clamp(frame, -1, 1) - input_frames.append(self.normalize(frame)) - prev_frame = frame +# for i in range(self.num_frames): +# # Add some temporal correlation +# frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 +# frame = torch.clamp(frame, -1, 1) +# input_frames.append(self.normalize(frame)) +# prev_frame = frame - # Target frame (next in sequence) - target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 - target_frame = torch.clamp(target_frame, -1, 1) - target_tensor = self.normalize(target_frame) +# # Target frame (next in sequence) +# target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05 +# target_frame = torch.clamp(target_frame, -1, 1) +# target_tensor = self.normalize(target_frame) - # Concatenate inputs - input_concatenated = torch.cat(input_frames, dim=0) +# # Concatenate inputs +# input_concatenated = torch.cat(input_frames, dim=0) - # Temporal index - temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) +# # Temporal index +# temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) - return input_concatenated, target_tensor, temporal_idx \ No newline at end of file +# return input_concatenated, target_tensor, temporal_idx \ No newline at end of file