更新归一化方式,当前直接映射,不利用均值标准差进行标准化

This commit is contained in:
2026-01-08 16:10:24 +08:00
parent f7601e9170
commit 500c2eb18f
3 changed files with 89 additions and 74 deletions

View File

@@ -11,7 +11,7 @@ shift 2
# Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-32}
BATCH_SIZE=${BATCH_SIZE:-256}
EPOCHS=${EPOCHS:-100}
LR=${LR:-1e-3}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}

View File

@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.video_dataset import VideoFrameDataset
from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
@@ -47,7 +47,7 @@ def get_args_parser():
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int,
parser.add_argument('--max-interval', default=4, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
@@ -109,10 +109,10 @@ def get_args_parser():
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss')
# parser.add_argument('--l1-weight', type=float, default=1.0,
# help='Weight for L1 loss')
# parser.add_argument('--ssim-weight', type=float, default=0.1,
# help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
@@ -326,7 +326,7 @@ def main(args):
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),

View File

@@ -47,28 +47,40 @@ class VideoFrameDataset(Dataset):
self.frame_size = frame_size
self.is_train = is_train
self.max_interval = max_interval
# if num_frames < 1:
# raise ValueError("num_frames must be >= 1")
# if frame_size < 1:
# raise ValueError("frame_size must be >= 1")
# if max_interval < 1:
# raise ValueError("max_interval must be >= 1")
# Collect all video folders
# Collect all video folders and their frame files
self.video_folders = []
self.video_frame_files = [] # list of list of Path objects
for item in self.root_dir.iterdir():
if item.is_dir():
self.video_folders.append(item)
# Get all frame files
frame_files = sorted([f for f in item.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
self.video_frame_files.append(frame_files)
if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = []
for video_idx, video_folder in enumerate(self.video_folders):
# Get all frame files
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
if len(frame_files) < num_frames + 1:
for video_idx, frame_files in enumerate(self.video_frame_files):
# Minimum frames needed considering max interval
min_frames_needed = num_frames * max_interval + 1
if len(frame_files) < min_frames_needed:
continue # Skip videos with insufficient frames
# Add all possible starting positions
for start_idx in range(len(frame_files) - num_frames):
# Ensure that for any interval up to max_interval, all frames are within bounds
max_start = len(frame_files) - num_frames * max_interval
for start_idx in range(max_start):
self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0:
@@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset):
else:
self.transform = transform
# Normalization for Y channel (single channel)
# Compute average of ImageNet RGB means and stds
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize(
mean=[y_mean],
std=[y_std]
)
# Simple normalization to [-1, 1] range (不使用ImageNet标准化)
# Convert pixel values [0, 255] to [-1, 1]
# This matches the model's tanh output range
self.normalize = None # We'll handle normalization manually
# print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]")
def _default_transform(self):
"""Default transform with augmentation for training"""
@@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset):
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
"""Load a single frame as PIL Image"""
video_folder = self.video_folders[video_idx]
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
frame_files = self.video_frame_files[video_idx]
if frame_idx < 0 or frame_idx >= len(frame_files):
raise IndexError(
f"Frame index {frame_idx} out of range for video {video_idx} "
f"(0-{len(frame_files)-1})"
)
frame_path = frame_files[frame_idx]
return Image.open(frame_path).convert('RGB')
@@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset):
if self.transform:
target_frame = self.transform(target_frame)
# Convert to tensors, normalize, and convert to grayscale (Y channel)
# Convert to tensors and convert to grayscale (Y channel)
input_tensors = []
for frame in input_frames:
tensor = transforms.ToTensor()(frame) # [3, H, W]
tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
# Convert RGB to grayscale using weighted sum
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W]
gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast)
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
# Normalize from [0, 1] to [-1, 1]
gray = gray * 2 - 1 # [0,1] -> [-1,1]
input_tensors.append(gray)
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W]
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
target_gray = self.normalize(target_gray)
# Normalize from [0, 1] to [-1, 1]
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
# Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
@@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset):
return input_concatenated, target_gray, temporal_idx
class SyntheticVideoDataset(Dataset):
"""
Synthetic dataset for testing - generates random frames
"""
def __init__(self,
num_samples: int = 1000,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True):
self.num_samples = num_samples
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
# class SyntheticVideoDataset(Dataset):
# """
# Synthetic dataset for testing - generates random frames
# """
# def __init__(self,
# num_samples: int = 1000,
# num_frames: int = 3,
# frame_size: int = 224,
# is_train: bool = True):
# self.num_samples = num_samples
# self.num_frames = num_frames
# self.frame_size = frame_size
# self.is_train = is_train
# Normalization for Y channel (single channel)
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
self.normalize = transforms.Normalize(
mean=[y_mean],
std=[y_std]
)
# # Normalization for Y channel (single channel)
# y_mean = (0.485 + 0.456 + 0.406) / 3.0
# y_std = (0.229 + 0.224 + 0.225) / 3.0
# self.normalize = transforms.Normalize(
# mean=[y_mean],
# std=[y_std]
# )
def __len__(self):
return self.num_samples
# def __len__(self):
# return self.num_samples
def __getitem__(self, idx):
# Generate random "frames" (noise with temporal correlation)
input_frames = []
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
# def __getitem__(self, idx):
# # Generate random "frames" (noise with temporal correlation)
# input_frames = []
# prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
for i in range(self.num_frames):
# Add some temporal correlation
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
frame = torch.clamp(frame, -1, 1)
input_frames.append(self.normalize(frame))
prev_frame = frame
# for i in range(self.num_frames):
# # Add some temporal correlation
# frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# frame = torch.clamp(frame, -1, 1)
# input_frames.append(self.normalize(frame))
# prev_frame = frame
# Target frame (next in sequence)
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
target_frame = torch.clamp(target_frame, -1, 1)
target_tensor = self.normalize(target_frame)
# # Target frame (next in sequence)
# target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
# target_frame = torch.clamp(target_frame, -1, 1)
# target_tensor = self.normalize(target_frame)
# Concatenate inputs
input_concatenated = torch.cat(input_frames, dim=0)
# # Concatenate inputs
# input_concatenated = torch.cat(input_frames, dim=0)
# Temporal index
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
# # Temporal index
# temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
# return input_concatenated, target_tensor, temporal_idx