Update the normalization scheme: map pixel values directly to [-1, 1] instead of standardizing with the mean and standard deviation
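The change in a nutshell, as a minimal standalone sketch (not code from the repository; the tensor names are illustrative): the old path standardized the Y channel with averaged ImageNet statistics, while the new path maps [0, 1] pixel values straight to [-1, 1] so targets match a tanh-bounded output, and the mapping inverts trivially for visualization.

import torch
from torchvision import transforms

gray = torch.rand(1, 224, 224)  # Y-channel frame in [0, 1], as produced by ToTensor()

# Old scheme: standardize with the average of the ImageNet RGB means and stds
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
old = transforms.Normalize(mean=[y_mean], std=[y_std])(gray)

# New scheme: direct linear map [0, 1] -> [-1, 1], matching a tanh output range
new = gray * 2 - 1

# The new mapping is trivially invertible, e.g. when visualizing predictions
restored = (new + 1) / 2
assert torch.allclose(restored, gray)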
@@ -11,7 +11,7 @@ shift 2
 
 # Default parameters
 MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
-BATCH_SIZE=${BATCH_SIZE:-32}
+BATCH_SIZE=${BATCH_SIZE:-256}
 EPOCHS=${EPOCHS:-100}
 LR=${LR:-1e-3}
 OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
 from util import *
 from models import *
 from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
-from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
+from util.video_dataset import VideoFrameDataset
 from util.frame_losses import MultiTaskLoss
 
 # Try to import TensorBoard
@@ -47,7 +47,7 @@ def get_args_parser():
                         help='Number of input frames (T)')
     parser.add_argument('--frame-size', default=224, type=int,
                         help='Input frame size')
-    parser.add_argument('--max-interval', default=1, type=int,
+    parser.add_argument('--max-interval', default=4, type=int,
                         help='Maximum interval between consecutive frames')
 
     # Model parameters
@@ -109,10 +109,10 @@ def get_args_parser():
                         help='Weight for frame prediction loss')
     parser.add_argument('--contrastive-weight', type=float, default=0.1,
                         help='Weight for contrastive loss')
-    parser.add_argument('--l1-weight', type=float, default=1.0,
-                        help='Weight for L1 loss')
-    parser.add_argument('--ssim-weight', type=float, default=0.1,
-                        help='Weight for SSIM loss')
+    # parser.add_argument('--l1-weight', type=float, default=1.0,
+    #                     help='Weight for L1 loss')
+    # parser.add_argument('--ssim-weight', type=float, default=0.1,
+    #                     help='Weight for SSIM loss')
     parser.add_argument('--no-contrastive', action='store_true',
                         help='Disable contrastive loss')
     parser.add_argument('--no-ssim', action='store_true',
@@ -326,7 +326,7 @@ def main(args):
         lr_scheduler.step(epoch)
 
         # Save checkpoint
-        if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
+        if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
             checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
             utils.save_on_master({
                 'model': model_without_ddp.state_dict(),
@@ -48,27 +48,39 @@ class VideoFrameDataset(Dataset):
         self.is_train = is_train
         self.max_interval = max_interval
 
-        # Collect all video folders
+        # if num_frames < 1:
+        #     raise ValueError("num_frames must be >= 1")
+        # if frame_size < 1:
+        #     raise ValueError("frame_size must be >= 1")
+        # if max_interval < 1:
+        #     raise ValueError("max_interval must be >= 1")
+
+        # Collect all video folders and their frame files
         self.video_folders = []
+        self.video_frame_files = []  # list of list of Path objects
         for item in self.root_dir.iterdir():
             if item.is_dir():
                 self.video_folders.append(item)
+                # Get all frame files
+                frame_files = sorted([f for f in item.iterdir()
+                                      if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
+                self.video_frame_files.append(frame_files)
 
         if len(self.video_folders) == 0:
             raise ValueError(f"No video folders found in {root_dir}")
 
         # Build frame index: list of (video_idx, start_frame_idx)
         self.frame_indices = []
-        for video_idx, video_folder in enumerate(self.video_folders):
-            # Get all frame files
-            frame_files = sorted([f for f in video_folder.iterdir()
-                                  if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
-
-            if len(frame_files) < num_frames + 1:
+        for video_idx, frame_files in enumerate(self.video_frame_files):
+            # Minimum frames needed considering max interval
+            min_frames_needed = num_frames * max_interval + 1
+            if len(frame_files) < min_frames_needed:
                 continue  # Skip videos with insufficient frames
 
             # Add all possible starting positions
-            for start_idx in range(len(frame_files) - num_frames):
+            # Ensure that for any interval up to max_interval, all frames are within bounds
+            max_start = len(frame_files) - num_frames * max_interval
+            for start_idx in range(max_start):
                 self.frame_indices.append((video_idx, start_idx))
 
         if len(self.frame_indices) == 0:
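A quick sanity check of the indexing arithmetic above (an assumed example, not repository code): with num_frames input frames plus one target frame sampled at a stride of up to max_interval, a clip needs at least num_frames * max_interval + 1 frames, and valid start indices run from 0 through len(frame_files) - num_frames * max_interval - 1.

def valid_start_positions(total_frames: int, num_frames: int, max_interval: int) -> range:
    # Start indices for which every sampled frame, including the target at
    # start + num_frames * interval, stays in bounds for any interval <= max_interval.
    min_frames_needed = num_frames * max_interval + 1
    if total_frames < min_frames_needed:
        return range(0)  # video too short, contributes no samples
    return range(total_frames - num_frames * max_interval)

# Example: 10 frames, 3 inputs, max_interval=2 -> valid starts 0..3
# (start 3 with interval 2 reads frames 3, 5, 7 and targets frame 9).
assert list(valid_start_positions(10, 3, 2)) == [0, 1, 2, 3]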
@@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset):
         else:
             self.transform = transform
 
-        # Normalization for Y channel (single channel)
-        # Compute average of ImageNet RGB means and stds
-        y_mean = (0.485 + 0.456 + 0.406) / 3.0
-        y_std = (0.229 + 0.224 + 0.225) / 3.0
-        self.normalize = transforms.Normalize(
-            mean=[y_mean],
-            std=[y_std]
-        )
+        # Simple normalization to the [-1, 1] range (no ImageNet standardization)
+        # Convert pixel values [0, 255] to [-1, 1]
+        # This matches the model's tanh output range
+        self.normalize = None  # We'll handle normalization manually
+
+        # print(f"[Dataset init] Using simple normalization: pixel values [0,255] -> [-1,1]")
 
     def _default_transform(self):
         """Default transform with augmentation for training"""
@@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset):
 
     def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
         """Load a single frame as PIL Image"""
-        video_folder = self.video_folders[video_idx]
-        frame_files = sorted([f for f in video_folder.iterdir()
-                              if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
+        frame_files = self.video_frame_files[video_idx]
+        if frame_idx < 0 or frame_idx >= len(frame_files):
+            raise IndexError(
+                f"Frame index {frame_idx} out of range for video {video_idx} "
+                f"(0-{len(frame_files)-1})"
+            )
         frame_path = frame_files[frame_idx]
         return Image.open(frame_path).convert('RGB')
 
@@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset):
         if self.transform:
             target_frame = self.transform(target_frame)
 
-        # Convert to tensors, normalize, and convert to grayscale (Y channel)
+        # Convert to tensors and convert to grayscale (Y channel)
         input_tensors = []
         for frame in input_frames:
-            tensor = transforms.ToTensor()(frame)  # [3, H, W]
+            tensor = transforms.ToTensor()(frame)  # [3, H, W], range [0, 1]
             # Convert RGB to grayscale using weighted sum
             # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
-            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W]
-            gray = self.normalize(gray)  # normalize with single-channel stats (mean/std broadcast)
+            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W], range [0, 1]
+            # Normalize from [0, 1] to [-1, 1]
+            gray = gray * 2 - 1  # [0,1] -> [-1,1]
             input_tensors.append(gray)
 
-        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W]
+        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W], range [0, 1]
         target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
-        target_gray = self.normalize(target_gray)
+        # Normalize from [0, 1] to [-1, 1]
+        target_gray = target_gray * 2 - 1  # [0,1] -> [-1,1]
 
         # Concatenate input frames along channel dimension
         input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]
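Restated as a standalone helper (an assumed sketch; frames stands in for the PIL images loaded by _load_frame), the per-sample pipeline after this hunk is: ToTensor, BT.601 luma mix, scale to [-1, 1], then concatenate one channel per input frame.

import torch
from torchvision import transforms

def frames_to_model_input(frames):
    # Convert a list of RGB PIL frames into a [num_frames, H, W] tensor in [-1, 1].
    tensors = []
    for frame in frames:
        rgb = transforms.ToTensor()(frame)  # [3, H, W], values in [0, 1]
        # ITU-R BT.601 luma weights, matching the dataset code above
        gray = (0.2989 * rgb[0] + 0.5870 * rgb[1] + 0.1140 * rgb[2]).unsqueeze(0)
        tensors.append(gray * 2 - 1)  # [0, 1] -> [-1, 1]
    return torch.cat(tensors, dim=0)  # one Y channel per input frame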
@@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset):
         return input_concatenated, target_gray, temporal_idx
 
 
-class SyntheticVideoDataset(Dataset):
-    """
-    Synthetic dataset for testing - generates random frames
-    """
-    def __init__(self,
-                 num_samples: int = 1000,
-                 num_frames: int = 3,
-                 frame_size: int = 224,
-                 is_train: bool = True):
-        self.num_samples = num_samples
-        self.num_frames = num_frames
-        self.frame_size = frame_size
-        self.is_train = is_train
+# class SyntheticVideoDataset(Dataset):
+#     """
+#     Synthetic dataset for testing - generates random frames
+#     """
+#     def __init__(self,
+#                  num_samples: int = 1000,
+#                  num_frames: int = 3,
+#                  frame_size: int = 224,
+#                  is_train: bool = True):
+#         self.num_samples = num_samples
+#         self.num_frames = num_frames
+#         self.frame_size = frame_size
+#         self.is_train = is_train
 
-        # Normalization for Y channel (single channel)
-        y_mean = (0.485 + 0.456 + 0.406) / 3.0
-        y_std = (0.229 + 0.224 + 0.225) / 3.0
-        self.normalize = transforms.Normalize(
-            mean=[y_mean],
-            std=[y_std]
-        )
+#         # Normalization for Y channel (single channel)
+#         y_mean = (0.485 + 0.456 + 0.406) / 3.0
+#         y_std = (0.229 + 0.224 + 0.225) / 3.0
+#         self.normalize = transforms.Normalize(
+#             mean=[y_mean],
+#             std=[y_std]
+#         )
 
-    def __len__(self):
-        return self.num_samples
+#     def __len__(self):
+#         return self.num_samples
 
-    def __getitem__(self, idx):
-        # Generate random "frames" (noise with temporal correlation)
-        input_frames = []
-        prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
+#     def __getitem__(self, idx):
+#         # Generate random "frames" (noise with temporal correlation)
+#         input_frames = []
+#         prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
 
-        for i in range(self.num_frames):
-            # Add some temporal correlation
-            frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
-            frame = torch.clamp(frame, -1, 1)
-            input_frames.append(self.normalize(frame))
-            prev_frame = frame
+#         for i in range(self.num_frames):
+#             # Add some temporal correlation
+#             frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
+#             frame = torch.clamp(frame, -1, 1)
+#             input_frames.append(self.normalize(frame))
+#             prev_frame = frame
 
-        # Target frame (next in sequence)
-        target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
-        target_frame = torch.clamp(target_frame, -1, 1)
-        target_tensor = self.normalize(target_frame)
+#         # Target frame (next in sequence)
+#         target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
+#         target_frame = torch.clamp(target_frame, -1, 1)
+#         target_tensor = self.normalize(target_frame)
 
-        # Concatenate inputs
-        input_concatenated = torch.cat(input_frames, dim=0)
+#         # Concatenate inputs
+#         input_concatenated = torch.cat(input_frames, dim=0)
 
-        # Temporal index
-        temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
+#         # Temporal index
+#         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
 
-        return input_concatenated, target_tensor, temporal_idx
+#         return input_concatenated, target_tensor, temporal_idx