Update the normalization scheme: map pixel values directly to [-1, 1] instead of standardizing with mean and standard deviation
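The gist of the change, as a minimal editorial sketch (not part of the diff): the old pipeline standardized the grayscale Y channel with a single-channel mean/std averaged from the ImageNet RGB statistics, while the new pipeline maps the [0, 1] tensor produced by ToTensor() straight to [-1, 1], matching the model's tanh output range.

import torch
from torchvision import transforms

gray = torch.rand(1, 224, 224)  # stand-in Y-channel tensor in [0, 1], as returned by ToTensor()

# Old scheme: mean/std standardization with ImageNet-derived single-channel statistics
y_mean = (0.485 + 0.456 + 0.406) / 3.0
y_std = (0.229 + 0.224 + 0.225) / 3.0
old = transforms.Normalize(mean=[y_mean], std=[y_std])(gray)

# New scheme: direct affine mapping [0, 1] -> [-1, 1]
new = gray * 2 - 1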
@@ -11,7 +11,7 @@ shift 2
 # Default parameters
 MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
-BATCH_SIZE=${BATCH_SIZE:-32}
+BATCH_SIZE=${BATCH_SIZE:-256}
 EPOCHS=${EPOCHS:-100}
 LR=${LR:-1e-3}
 OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
 
@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
 from util import *
 from models import *
 from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
-from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
+from util.video_dataset import VideoFrameDataset
 from util.frame_losses import MultiTaskLoss
 
 # Try to import TensorBoard
@@ -47,7 +47,7 @@ def get_args_parser():
                         help='Number of input frames (T)')
     parser.add_argument('--frame-size', default=224, type=int,
                         help='Input frame size')
-    parser.add_argument('--max-interval', default=1, type=int,
+    parser.add_argument('--max-interval', default=4, type=int,
                         help='Maximum interval between consecutive frames')
 
     # Model parameters
@@ -109,10 +109,10 @@ def get_args_parser():
                         help='Weight for frame prediction loss')
     parser.add_argument('--contrastive-weight', type=float, default=0.1,
                         help='Weight for contrastive loss')
-    parser.add_argument('--l1-weight', type=float, default=1.0,
-                        help='Weight for L1 loss')
-    parser.add_argument('--ssim-weight', type=float, default=0.1,
-                        help='Weight for SSIM loss')
+    # parser.add_argument('--l1-weight', type=float, default=1.0,
+    #                     help='Weight for L1 loss')
+    # parser.add_argument('--ssim-weight', type=float, default=0.1,
+    #                     help='Weight for SSIM loss')
     parser.add_argument('--no-contrastive', action='store_true',
                         help='Disable contrastive loss')
     parser.add_argument('--no-ssim', action='store_true',
@@ -326,7 +326,7 @@ def main(args):
         lr_scheduler.step(epoch)
 
         # Save checkpoint
-        if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
+        if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
             checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
             utils.save_on_master({
                 'model': model_without_ddp.state_dict(),
@@ -47,28 +47,40 @@ class VideoFrameDataset(Dataset):
         self.frame_size = frame_size
         self.is_train = is_train
         self.max_interval = max_interval
 
+        # if num_frames < 1:
+        #     raise ValueError("num_frames must be >= 1")
+        # if frame_size < 1:
+        #     raise ValueError("frame_size must be >= 1")
+        # if max_interval < 1:
+        #     raise ValueError("max_interval must be >= 1")
+
-        # Collect all video folders
+        # Collect all video folders and their frame files
         self.video_folders = []
+        self.video_frame_files = []  # list of list of Path objects
         for item in self.root_dir.iterdir():
             if item.is_dir():
                 self.video_folders.append(item)
+                # Get all frame files
+                frame_files = sorted([f for f in item.iterdir()
+                                      if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
+                self.video_frame_files.append(frame_files)
 
         if len(self.video_folders) == 0:
             raise ValueError(f"No video folders found in {root_dir}")
 
         # Build frame index: list of (video_idx, start_frame_idx)
         self.frame_indices = []
-        for video_idx, video_folder in enumerate(self.video_folders):
-            # Get all frame files
-            frame_files = sorted([f for f in video_folder.iterdir()
-                                  if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
-
-            if len(frame_files) < num_frames + 1:
+        for video_idx, frame_files in enumerate(self.video_frame_files):
+            # Minimum frames needed considering max interval
+            min_frames_needed = num_frames * max_interval + 1
+            if len(frame_files) < min_frames_needed:
                 continue  # Skip videos with insufficient frames
 
             # Add all possible starting positions
-            for start_idx in range(len(frame_files) - num_frames):
+            # Ensure that for any interval up to max_interval, all frames are within bounds
+            max_start = len(frame_files) - num_frames * max_interval
+            for start_idx in range(max_start):
                 self.frame_indices.append((video_idx, start_idx))
 
         if len(self.frame_indices) == 0:
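A quick sanity check of the new indexing bounds (editorial illustration with hypothetical numbers): assuming the target frame sits one interval after the last of the num_frames input frames, as the old num_frames + 1 requirement suggests, the largest index touched from start_idx is start_idx + num_frames * interval, so capping start_idx below len(frame_files) - num_frames * max_interval keeps every access in range.

num_frames, max_interval = 3, 4
frame_count = 20  # video is kept, since 20 >= num_frames * max_interval + 1 == 13

max_start = frame_count - num_frames * max_interval  # start_idx ranges over 0..7
for start_idx in range(max_start):
    last_index = start_idx + num_frames * max_interval  # worst case: interval == max_interval
    assert last_index <= frame_count - 1  # never exceeds index 19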
@@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset):
         else:
             self.transform = transform
 
-        # Normalization for Y channel (single channel)
-        # Compute average of ImageNet RGB means and stds
-        y_mean = (0.485 + 0.456 + 0.406) / 3.0
-        y_std = (0.229 + 0.224 + 0.225) / 3.0
-        self.normalize = transforms.Normalize(
-            mean=[y_mean],
-            std=[y_std]
-        )
+        # Simple normalization to [-1, 1] range (no ImageNet standardization)
+        # Convert pixel values [0, 255] to [-1, 1]
+        # This matches the model's tanh output range
+        self.normalize = None  # We'll handle normalization manually
+
+        # print(f"[Dataset init] Using simple normalization: pixel values [0,255] -> [-1,1]")
 
     def _default_transform(self):
         """Default transform with augmentation for training"""
@@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset):
 
     def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
         """Load a single frame as PIL Image"""
-        video_folder = self.video_folders[video_idx]
-        frame_files = sorted([f for f in video_folder.iterdir()
-                              if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
+        frame_files = self.video_frame_files[video_idx]
+        if frame_idx < 0 or frame_idx >= len(frame_files):
+            raise IndexError(
+                f"Frame index {frame_idx} out of range for video {video_idx} "
+                f"(0-{len(frame_files)-1})"
+            )
         frame_path = frame_files[frame_idx]
         return Image.open(frame_path).convert('RGB')
 
@@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset):
         if self.transform:
             target_frame = self.transform(target_frame)
 
-        # Convert to tensors, normalize, and convert to grayscale (Y channel)
+        # Convert to tensors and convert to grayscale (Y channel)
         input_tensors = []
         for frame in input_frames:
-            tensor = transforms.ToTensor()(frame)  # [3, H, W]
+            tensor = transforms.ToTensor()(frame)  # [3, H, W], range [0, 1]
             # Convert RGB to grayscale using weighted sum
             # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
-            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W]
-            gray = self.normalize(gray)  # normalize with single-channel stats (mean/std broadcast)
+            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W], range [0, 1]
+            # Normalize from [0, 1] to [-1, 1]
+            gray = gray * 2 - 1  # [0,1] -> [-1,1]
             input_tensors.append(gray)
 
-        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W]
+        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W], range [0, 1]
         target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
-        target_gray = self.normalize(target_gray)
+        # Normalize from [0, 1] to [-1, 1]
+        target_gray = target_gray * 2 - 1  # [0,1] -> [-1,1]
 
         # Concatenate input frames along channel dimension
         input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]
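Because both inputs and targets now live in [-1, 1], a predicted frame can be mapped back to 8-bit pixel values with the inverse transform; a minimal editorial sketch (not part of the diff):

import torch

pred = torch.tanh(torch.randn(1, 224, 224))                      # stand-in for a model output in [-1, 1]
pixels = ((pred + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)    # back to [0, 255]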
@@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset):
         return input_concatenated, target_gray, temporal_idx
 
 
-class SyntheticVideoDataset(Dataset):
-    """
-    Synthetic dataset for testing - generates random frames
-    """
-    def __init__(self,
-                 num_samples: int = 1000,
-                 num_frames: int = 3,
-                 frame_size: int = 224,
-                 is_train: bool = True):
-        self.num_samples = num_samples
-        self.num_frames = num_frames
-        self.frame_size = frame_size
-        self.is_train = is_train
+# class SyntheticVideoDataset(Dataset):
+#     """
+#     Synthetic dataset for testing - generates random frames
+#     """
+#     def __init__(self,
+#                  num_samples: int = 1000,
+#                  num_frames: int = 3,
+#                  frame_size: int = 224,
+#                  is_train: bool = True):
+#         self.num_samples = num_samples
+#         self.num_frames = num_frames
+#         self.frame_size = frame_size
+#         self.is_train = is_train
 
-        # Normalization for Y channel (single channel)
-        y_mean = (0.485 + 0.456 + 0.406) / 3.0
-        y_std = (0.229 + 0.224 + 0.225) / 3.0
-        self.normalize = transforms.Normalize(
-            mean=[y_mean],
-            std=[y_std]
-        )
+#         # Normalization for Y channel (single channel)
+#         y_mean = (0.485 + 0.456 + 0.406) / 3.0
+#         y_std = (0.229 + 0.224 + 0.225) / 3.0
+#         self.normalize = transforms.Normalize(
+#             mean=[y_mean],
+#             std=[y_std]
+#         )
 
-    def __len__(self):
-        return self.num_samples
+#     def __len__(self):
+#         return self.num_samples
 
-    def __getitem__(self, idx):
-        # Generate random "frames" (noise with temporal correlation)
-        input_frames = []
-        prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
+#     def __getitem__(self, idx):
+#         # Generate random "frames" (noise with temporal correlation)
+#         input_frames = []
+#         prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
 
-        for i in range(self.num_frames):
-            # Add some temporal correlation
-            frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
-            frame = torch.clamp(frame, -1, 1)
-            input_frames.append(self.normalize(frame))
-            prev_frame = frame
+#         for i in range(self.num_frames):
+#             # Add some temporal correlation
+#             frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
+#             frame = torch.clamp(frame, -1, 1)
+#             input_frames.append(self.normalize(frame))
+#             prev_frame = frame
 
-        # Target frame (next in sequence)
-        target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
-        target_frame = torch.clamp(target_frame, -1, 1)
-        target_tensor = self.normalize(target_frame)
+#         # Target frame (next in sequence)
+#         target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
+#         target_frame = torch.clamp(target_frame, -1, 1)
+#         target_tensor = self.normalize(target_frame)
 
-        # Concatenate inputs
-        input_concatenated = torch.cat(input_frames, dim=0)
+#         # Concatenate inputs
+#         input_concatenated = torch.cat(input_frames, dim=0)
 
-        # Temporal index
-        temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
+#         # Temporal index
+#         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
 
-        return input_concatenated, target_tensor, temporal_idx
+#         return input_concatenated, target_tensor, temporal_idx