Compare commits

...

3 Commits

SHA1 Message Date
f7601e9170 First runnable version, but the loss computation is wrong and training does not converge 2026-01-08 09:43:23 +08:00
efd76bccd2 update .gitignore 2026-01-07 15:54:52 +08:00
4888619f9d init .gitignore 2026-01-07 15:54:20 +08:00
11 changed files with 657 additions and 62 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.vscode/
__pycache__/
venv/
runs/

57
dist_temporal_train.sh Executable file
View File

@@ -0,0 +1,57 @@
#!/usr/bin/env bash
# Distributed training script for SwiftFormerTemporal
# Usage: ./dist_temporal_train.sh <DATA_PATH> <NUM_GPUS> [OPTIONS]
DATA_PATH=$1
NUM_GPUS=$2
# Shift arguments to pass remaining options to python script
shift 2
# Default parameters
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
BATCH_SIZE=${BATCH_SIZE:-32}
EPOCHS=${EPOCHS:-100}
LR=${LR:-1e-3}
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
echo "Starting distributed training with $NUM_GPUS GPUs"
echo "Data path: $DATA_PATH"
echo "Model: $MODEL"
echo "Batch size: $BATCH_SIZE"
echo "Epochs: $EPOCHS"
echo "Output dir: $OUTPUT_DIR"
# Decide whether to use torch.distributed.launch or torchrun
# For newer PyTorch versions (>=1.10), torchrun is recommended
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
echo "PyTorch version: $TORCH_VERSION"
# Use torchrun for newer PyTorch versions
if [[ "$TORCH_VERSION" =~ ^2\. ]] || [[ "$TORCH_VERSION" =~ ^1\.1[0-9]\. ]]; then
    echo "Using torchrun (PyTorch >=1.10)"
    torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \
        --data-path "$DATA_PATH" \
        --model "$MODEL" \
        --batch-size $BATCH_SIZE \
        --epochs $EPOCHS \
        --lr $LR \
        --output-dir "$OUTPUT_DIR" \
        "$@"
else
    echo "Using torch.distributed.launch"
    python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \
        --data-path "$DATA_PATH" \
        --model "$MODEL" \
        --batch-size $BATCH_SIZE \
        --epochs $EPOCHS \
        --lr $LR \
        --output-dir "$OUTPUT_DIR" \
        "$@"
fi
# For single-node multi-GPU training with specific options:
# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345'
echo "Training completed. Check logs in $OUTPUT_DIR"

main_temporal.py
View File

@@ -6,8 +6,10 @@ import datetime
 import numpy as np
 import time
 import torch
+import torch.nn as nn
 import torch.backends.cudnn as cudnn
 import json
+import os
 from pathlib import Path
 from timm.scheduler import create_scheduler
@@ -20,6 +22,17 @@ from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTempo
 from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
 from util.frame_losses import MultiTaskLoss
 
+# Try to import TensorBoard
+try:
+    from torch.utils.tensorboard import SummaryWriter
+    TENSORBOARD_AVAILABLE = True
+except ImportError:
+    try:
+        from tensorboardX import SummaryWriter
+        TENSORBOARD_AVAILABLE = True
+    except ImportError:
+        TENSORBOARD_AVAILABLE = False
+
 
 def get_args_parser():
     parser = argparse.ArgumentParser(
@@ -48,10 +61,48 @@ def get_args_parser():
# Training parameters # Training parameters
parser.add_argument('--batch-size', default=32, type=int) parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int) parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)') # Optimizer parameters (required by timm's create_optimizer)
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw"')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
parser.add_argument('--clip-mode', type=str, default='agc',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05, parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)') help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
# Learning rate schedule parameters (required by timm's create_scheduler)
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine"')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Loss parameters # Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0, parser.add_argument('--frame-weight', type=float, default=1.0,
@@ -90,26 +141,26 @@ def get_args_parser():
     parser.add_argument('--dist-url', default='env://',
                         help='url used to set up distributed training')
 
+    # TensorBoard logging
+    parser.add_argument('--tensorboard-logdir', default='./runs',
+                        type=str, help='TensorBoard log directory')
+    parser.add_argument('--log-images', action='store_true',
+                        help='Log sample images to TensorBoard')
+    parser.add_argument('--image-log-freq', default=100, type=int,
+                        help='Frequency of logging images (in iterations)')
+
     return parser
 
 
 def build_dataset(is_train, args):
     """Build video frame dataset"""
-    if args.dataset_type == 'synthetic':
-        dataset = SyntheticVideoDataset(
-            num_samples=1000 if is_train else 200,
-            num_frames=args.num_frames,
-            frame_size=args.frame_size,
-            is_train=is_train
-        )
-    else:
-        dataset = VideoFrameDataset(
-            root_dir=args.data_path,
-            num_frames=args.num_frames,
-            frame_size=args.frame_size,
-            is_train=is_train,
-            max_interval=args.max_interval
-        )
+    dataset = VideoFrameDataset(
+        root_dir=args.data_path,
+        num_frames=args.num_frames,
+        frame_size=args.frame_size,
+        is_train=is_train,
+        max_interval=args.max_interval
+    )
     return dataset
@@ -203,14 +254,18 @@ def main(args):
     # Create scheduler
     lr_scheduler, _ = create_scheduler(args, optimizer)
 
-    # Create loss function
-    criterion = MultiTaskLoss(
-        frame_weight=args.frame_weight,
-        contrastive_weight=args.contrastive_weight,
-        l1_weight=args.l1_weight,
-        ssim_weight=args.ssim_weight,
-        use_contrastive=not args.no_contrastive
-    )
+    # Create loss function - simple MSE for Y channel prediction
+    class MSELossWrapper(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mse = nn.MSELoss()
+
+        def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
+            loss = self.mse(pred_frame, target_frame)
+            loss_dict = {'mse': loss}
+            return loss, loss_dict
+
+    criterion = MSELossWrapper()
 
     # Resume from checkpoint
     output_dir = Path(args.output_dir)
@@ -231,6 +286,21 @@ def main(args):
         if 'scaler' in checkpoint:
             loss_scaler.load_state_dict(checkpoint['scaler'])
 
+    # Initialize TensorBoard writer
+    writer = None
+    if TENSORBOARD_AVAILABLE and utils.is_main_process():
+        from datetime import datetime
+        # Create log directory with timestamp
+        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+        log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
+        os.makedirs(log_dir, exist_ok=True)
+        writer = SummaryWriter(log_dir=log_dir)
+        print(f"TensorBoard logs will be saved to: {log_dir}")
+        print(f"To view logs, run: tensorboard --logdir={log_dir}")
+    elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
+        print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
+        print("Training will continue without TensorBoard logging.")
+
     if args.eval:
         test_stats = evaluate(data_loader_val, model, criterion, device)
         print(f"Test stats: {test_stats}")
@@ -239,14 +309,18 @@ def main(args):
     print(f"Start training for {args.epochs} epochs")
     start_time = time.time()
 
+    # Global step counter for TensorBoard
+    global_step = 0
+
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             data_loader_train.sampler.set_epoch(epoch)
 
-        train_stats = train_one_epoch(
+        train_stats, global_step = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,
-            model_ema=model_ema
+            model_ema=model_ema, writer=writer,
+            global_step=global_step, args=args
         )
 
         lr_scheduler.step(epoch)
@@ -266,10 +340,10 @@ def main(args):
         # Evaluate
         if epoch % 5 == 0 or epoch == args.epochs - 1:
-            test_stats = evaluate(data_loader_val, model, criterion, device)
+            test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
             print(f"Epoch {epoch}: Test stats: {test_stats}")
 
-        # Log stats
+        # Log stats to text file
         log_stats = {
             **{f'train_{k}': v for k, v in train_stats.items()},
             **{f'test_{k}': v for k, v in test_stats.items()},
@@ -285,17 +359,23 @@ def main(args):
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     print(f'Training time {total_time_str}')
 
+    # Close TensorBoard writer
+    if writer is not None:
+        writer.close()
+        print(f"TensorBoard logs saved to: {writer.log_dir}")
+
 
 def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
-                    clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
+                    clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
+                    global_step=0, args=None, **kwargs):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = f'Epoch: [{epoch}]'
     print_freq = 10
 
-    for input_frames, target_frames, temporal_indices in metric_logger.log_every(
-            data_loader, print_freq, header):
+    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
+            metric_logger.log_every(data_loader, print_freq, header)):
         input_frames = input_frames.to(device, non_blocking=True)
         target_frames = target_frames.to(device, non_blocking=True)
@@ -322,19 +402,51 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         if model_ema is not None:
             model_ema.update(model)
 
+        # Log to TensorBoard
+        if writer is not None:
+            # Log scalar metrics every iteration
+            writer.add_scalar('train/loss', loss_value, global_step)
+            writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
+
+            # Log individual loss components
+            for k, v in loss_dict.items():
+                if torch.is_tensor(v):
+                    writer.add_scalar(f'train/{k}', v.item(), global_step)
+                else:
+                    writer.add_scalar(f'train/{k}', v, global_step)
+
+            # Log images periodically
+            if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
+                with torch.no_grad():
+                    # Take first sample from batch for visualization
+                    pred_vis, _ = model(input_frames[:1])
+                    # Convert to appropriate format for TensorBoard
+                    # Assuming frames are in [B, C, H, W] format
+                    writer.add_images('train/input', input_frames[:1], global_step)
+                    writer.add_images('train/target', target_frames[:1], global_step)
+                    writer.add_images('train/predicted', pred_vis[:1], global_step)
+
         # Update metrics
         metric_logger.update(loss=loss_value)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
         for k, v in loss_dict.items():
             metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
 
+        global_step += 1
+
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+    # Log epoch-level metrics
+    if writer is not None:
+        for k, meter in metric_logger.meters.items():
+            writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
 
 
 @torch.no_grad()
-def evaluate(data_loader, model, criterion, device):
+def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
@@ -359,6 +471,12 @@ def evaluate(data_loader, model, criterion, device):
     metric_logger.synchronize_between_processes()
     print('* Test stats:', metric_logger)
 
+    # Log validation metrics to TensorBoard
+    if writer is not None:
+        for k, meter in metric_logger.meters.items():
+            writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
+
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
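Since the commit message flags a non-converging loss, a quick way to sanity-check the new criterion in isolation is a standalone smoke test. This is a hedged sketch: it re-declares the MSELossWrapper from the hunk above (in the source it is defined inside main()) and feeds it dummy Y-channel tensors with the assumed [B, 1, H, W] shape.

import torch
import torch.nn as nn

class MSELossWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
        loss = self.mse(pred_frame, target_frame)
        return loss, {'mse': loss}

criterion = MSELossWrapper()
pred = torch.randn(2, 1, 224, 224)    # dummy predicted frames [B, 1, H, W]
target = torch.randn(2, 1, 224, 224)  # dummy target frames [B, 1, H, W]
loss, loss_dict = criterion(pred, target)
print(loss.item(), loss_dict['mse'].item())  # both print the same scalar MSE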

models/swiftformer.py
View File

@@ -6,9 +6,9 @@ import copy
 import torch
 import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.models.layers import DropPath, trunc_normal_
-from timm.models.registry import register_model
-from timm.models.layers.helpers import to_2tuple
+from timm.layers import DropPath, trunc_normal_
+from timm.models import register_model
+from timm.layers import to_2tuple
 import einops
 
 SwiftFormer_width = {

models/swiftformer_temporal.py
View File

@@ -7,7 +7,7 @@ from .swiftformer import (
     SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
     stem, Embedding, Stage
 )
-from timm.models.layers import DropPath, trunc_normal_
+from timm.layers import DropPath, trunc_normal_
 
 
 class DecoderBlock(nn.Module):
@@ -31,7 +31,7 @@ class DecoderBlock(nn.Module):
 class FramePredictionDecoder(nn.Module):
     """Lightweight decoder for frame prediction with optional skip connections"""
 
-    def __init__(self, embed_dims, output_channels=3, use_skip=False):
+    def __init__(self, embed_dims, output_channels=1, use_skip=False):
         super().__init__()
         self.use_skip = use_skip
 
         # Reverse the embed_dims for decoder
@@ -53,11 +53,11 @@ class FramePredictionDecoder(nn.Module):
             decoder_dims[2], decoder_dims[3],
             kernel_size=3, stride=2, padding=1, output_padding=1
         ))
 
-        # stage2 to original resolution (4x upsampling total)
+        # stage2 to original resolution (now 8x upsampling total with stride 4)
         self.blocks.append(nn.Sequential(
             nn.ConvTranspose2d(
                 decoder_dims[3], 32,
-                kernel_size=3, stride=2, padding=1, output_padding=1
+                kernel_size=3, stride=4, padding=1, output_padding=3
             ),
             nn.BatchNorm2d(32),
             nn.ReLU(inplace=True),
@@ -104,7 +104,7 @@ class SwiftFormerTemporal(nn.Module):
     """
     SwiftFormer with temporal input for frame prediction.
     Input: [B, num_frames, H, W] (Y channel only)
-    Output: predicted frame [B, 3, H, W] and optional representation
+    Output: predicted frame [B, 1, H, W] and optional representation
     """
 
     def __init__(self,
                  model_name='XS',
@@ -155,7 +155,7 @@ class SwiftFormerTemporal(nn.Module):
         # Frame prediction decoder
         if use_decoder:
-            self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
+            self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
 
         # Representation head for pose/velocity prediction
         if use_representation_head:
@@ -201,7 +201,7 @@ class SwiftFormerTemporal(nn.Module):
             x: input frames of shape [B, num_frames, H, W]
         Returns:
             If return_features is False:
-                pred_frame: predicted frame [B, 3, H, W] (or None)
+                pred_frame: predicted frame [B, 1, H, W] (or None)
                 representation: optional representation vector [B, representation_dim] (or None)
             If return_features is True:
                 pred_frame, representation, features (list of stage features)
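To see what the stride change in the last decoder block does to the output resolution, here is a small sketch. The channel count 48 is a placeholder for decoder_dims[3] (the real value depends on SwiftFormer_width), and the 56x56 input stands in for a stage feature map; the point is that the old layer upsamples 2x while the new one upsamples 4x.

import torch
import torch.nn as nn

x = torch.randn(1, 48, 56, 56)  # hypothetical stage feature map [B, C, H, W]
old = nn.ConvTranspose2d(48, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
new = nn.ConvTranspose2d(48, 32, kernel_size=3, stride=4, padding=1, output_padding=3)
print(old(x).shape)  # torch.Size([1, 32, 112, 112]) -> 2x upsampling
print(new(x).shape)  # torch.Size([1, 32, 224, 224]) -> 4x upsampling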

26
multi_gpu_temporal_train.sh Executable file
View File

@@ -0,0 +1,26 @@
#!/usr/bin/env bash
# Simple multi-GPU training script for SwiftFormerTemporal
# Usage: ./multi_gpu_temporal_train.sh <NUM_GPUS> [OPTIONS]
NUM_GPUS=${1:-2}
# Drop the GPU-count argument (if given) so "$@" keeps only the extra options
[ $# -gt 0 ] && shift
echo "Starting multi-GPU training with $NUM_GPUS GPUs"
# Set environment variables for distributed training
export MASTER_PORT=12345
export MASTER_ADDR=localhost
export WORLD_SIZE=$NUM_GPUS
# Launch training
torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \
--data-path "./videos" \
--model SwiftFormerTemporal_XS \
--batch-size 32 \
--epochs 100 \
--lr 1e-3 \
--output-dir "./temporal_output_multi" \
--num-workers 8 \
--pin-mem \
"$@"

0
temporal_train.sh Normal file
View File

45
test_cuda.py Normal file
View File

@@ -0,0 +1,45 @@
import torch


def test_cuda_availability():
    """Comprehensive CUDA availability test"""
    print("=" * 50)
    print("PyTorch CUDA test")
    print("=" * 50)

    # Basic information
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")

    if not torch.cuda.is_available():
        print("CUDA is not available. Possible reasons:")
        print("1. CUDA driver not installed")
        print("2. CPU-only build of PyTorch installed")
        print("3. CUDA version does not match the PyTorch build")
        return False

    # Device information
    device_count = torch.cuda.device_count()
    print(f"Found {device_count} CUDA device(s)")

    for i in range(device_count):
        print(f"\nDevice {i}:")
        print(f"  Name: {torch.cuda.get_device_name(i)}")
        print(f"  Total memory: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
        print(f"  Compute capability: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}")

    # Simple tensor test
    print("\nRunning CUDA test...")
    try:
        x = torch.randn(3, 3).cuda()
        y = torch.randn(3, 3).cuda()
        z = x + y
        print("CUDA compute test: passed!")
        print(f"Tensor shape on device: {z.shape}")
        return True
    except Exception as e:
        print(f"CUDA compute test failed: {e}")
        return False


if __name__ == "__main__":
    test_cuda_availability()

33
test_import.py Normal file
View File

@@ -0,0 +1,33 @@
#!/usr/bin/env python3
"""
测试 timm 导入是否正常工作
"""
import sys
print("Python version:", sys.version)
try:
from timm.layers import to_2tuple, DropPath, trunc_normal_
from timm.models import register_model
print("✓ 成功导入 timm.layers.to_2tuple")
print("✓ 成功导入 timm.layers.DropPath")
print("✓ 成功导入 timm.layers.trunc_normal_")
print("✓ 成功导入 timm.models.register_model")
except ImportError as e:
print(f"✗ 导入失败: {e}")
sys.exit(1)
try:
from models.swiftformer import SwiftFormer_XS
print("✓ 成功导入 SwiftFormer_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormer_XS 失败: {e}")
sys.exit(1)
try:
from models.swiftformer_temporal import SwiftFormerTemporal_XS
print("✓ 成功导入 SwiftFormerTemporal_XS")
except ImportError as e:
print(f"✗ 导入 SwiftFormerTemporal_XS 失败: {e}")
sys.exit(1)
print("\n✅ 所有导入测试通过!")

util/video_dataset.py
View File

@@ -80,10 +80,13 @@
         else:
             self.transform = transform
 
-        # Normalization (ImageNet stats)
+        # Normalization for Y channel (single channel)
+        # Compute average of ImageNet RGB means and stds
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
+            mean=[y_mean],
+            std=[y_std]
         )
 
     def _default_transform(self):
@@ -114,8 +117,8 @@
     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Returns:
-            input_frames: [3 * num_frames, H, W] concatenated input frames
-            target_frame: [3, H, W] target frame to predict
+            input_frames: [num_frames, H, W] concatenated input frames (Y channel only)
+            target_frame: [1, H, W] target frame to predict (Y channel only)
             temporal_idx: temporal index of target frame (for contrastive loss)
         """
         video_idx, start_idx = self.frame_indices[idx]
@@ -141,23 +144,27 @@
         if self.transform:
             target_frame = self.transform(target_frame)
 
-        # Convert to tensors and normalize
+        # Convert to tensors, normalize, and convert to grayscale (Y channel)
        input_tensors = []
         for frame in input_frames:
-            tensor = transforms.ToTensor()(frame)
-            tensor = self.normalize(tensor)
-            input_tensors.append(tensor)
+            tensor = transforms.ToTensor()(frame)  # [3, H, W]
+            # Convert RGB to grayscale using weighted sum
+            # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
+            gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0)  # [1, H, W]
+            gray = self.normalize(gray)  # normalize with single-channel stats (mean/std broadcast)
+            input_tensors.append(gray)
 
-        target_tensor = transforms.ToTensor()(target_frame)
-        target_tensor = self.normalize(target_tensor)
+        target_tensor = transforms.ToTensor()(target_frame)  # [3, H, W]
+        target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
+        target_gray = self.normalize(target_gray)
 
         # Concatenate input frames along channel dimension
-        input_concatenated = torch.cat(input_tensors, dim=0)
+        input_concatenated = torch.cat(input_tensors, dim=0)  # [num_frames, H, W]
 
         # Temporal index (for contrastive loss)
         temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
 
-        return input_concatenated, target_tensor, temporal_idx
+        return input_concatenated, target_gray, temporal_idx
 
 
 class SyntheticVideoDataset(Dataset):
@@ -174,10 +181,12 @@
         self.frame_size = frame_size
         self.is_train = is_train
 
-        # Normalization
+        # Normalization for Y channel (single channel)
+        y_mean = (0.485 + 0.456 + 0.406) / 3.0
+        y_std = (0.229 + 0.224 + 0.225) / 3.0
         self.normalize = transforms.Normalize(
-            mean=[0.485, 0.456, 0.406],
-            std=[0.229, 0.224, 0.225]
+            mean=[y_mean],
+            std=[y_std]
        )
 
     def __len__(self):

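The 0.2989/0.5870/0.1140 weights above are the ITU-R 601 luma coefficients that PIL also uses for its 'L' mode, which can be checked directly. A small sketch (solid-color image as the example):

from PIL import Image
import torch
from torchvision import transforms

img = Image.new('RGB', (8, 8), (120, 200, 40))
t = transforms.ToTensor()(img)  # [3, H, W], values in [0, 1]
y_manual = 0.2989 * t[0] + 0.5870 * t[1] + 0.1140 * t[2]
y_pil = transforms.ToTensor()(img.convert('L'))[0]  # PIL's grayscale conversion
print(torch.allclose(y_manual, y_pil, atol=1 / 255))  # True, up to 8-bit rounding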
303
video_preprocessor.py Normal file
View File

@@ -0,0 +1,303 @@
#!/usr/bin/env python3
"""
Video preprocessing script - converts MP4 videos into 224x224 frame images.
Supports multi-threaded processing, a progress bar, and resume after interruption.
"""
import os
import sys
import json
import argparse
import subprocess
import threading
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import time
from typing import List, Dict, Optional, Tuple
class VideoPreprocessor:
    """Video preprocessor with multi-threading and resume support"""

    def __init__(self,
                 input_dir: str,
                 output_dir: str,
                 frame_size: int = 224,
                 fps: int = 30,
                 num_workers: int = 4,
                 quality: int = 2,
                 resume: bool = True):
        """
        Initialize the preprocessor

        Args:
            input_dir: input video directory
            output_dir: output frame directory
            frame_size: frame size (square)
            fps: frame extraction rate
            num_workers: number of concurrent worker threads
            quality: JPEG quality (1-31, lower means higher quality)
            resume: whether to enable resume after interruption
        """
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.frame_size = frame_size
        self.fps = fps
        self.num_workers = num_workers
        self.quality = quality
        self.resume = resume

        # State file path
        self.state_file = self.output_dir / ".preprocessing_state.json"

        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Initialize state
        self.state = self._load_state()

        # Collect all video files
        self.video_files = self._collect_video_files()
    def _load_state(self) -> Dict:
        """Load processing state"""
        if self.resume and self.state_file.exists():
            try:
                with open(self.state_file, 'r') as f:
                    return json.load(f)
            except (json.JSONDecodeError, IOError):
                print("Warning: could not read state file, starting from scratch")
        return {
            "completed": [],
            "failed": [],
            "total_processed": 0,
            "start_time": None,
            "last_update": None
        }

    def _save_state(self):
        """Save processing state"""
        self.state["last_update"] = time.time()
        try:
            with open(self.state_file, 'w') as f:
                json.dump(self.state, f, indent=2)
        except IOError as e:
            print(f"Warning: could not save state file: {e}")

    def _collect_video_files(self) -> List[Path]:
        """Collect all video files that still need processing"""
        video_files = []
        for file_path in self.input_dir.glob("*.mp4"):
            if file_path.name not in self.state["completed"]:
                video_files.append(file_path)
        return sorted(video_files)

    def _parse_video_name(self, video_path: Path) -> Dict[str, str]:
        """Parse the video file name, using the full name as the ID"""
        name_without_ext = video_path.stem
        # Use the full file name as the ID so each mp4 gets its own output directory
        return {
            "video_id": name_without_ext,
            "start_frame": "unknown",
            "end_frame": "unknown",
            "full_name": name_without_ext
        }
    def _extract_frames(self, video_path: Path) -> bool:
        """Extract frames from a single video"""
        try:
            # Parse video name
            video_info = self._parse_video_name(video_path)
            output_subdir = self.output_dir / video_info["video_id"]
            output_subdir.mkdir(exist_ok=True)

            # Build the FFmpeg command
            output_pattern = output_subdir / "frame_%04d.jpg"
            cmd = [
                "ffmpeg",
                "-i", str(video_path),
                "-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}",
                "-q:v", str(self.quality),
                "-y",  # overwrite output files
                str(output_pattern)
            ]

            # Run the FFmpeg command
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=300  # 5-minute timeout
            )

            if result.returncode != 0:
                print(f"FFmpeg error while processing {video_path.name}: {result.stderr}")
                return False

            # Verify that frames were produced
            output_frames = list(output_subdir.glob("frame_*.jpg"))
            if len(output_frames) == 0:
                print(f"Warning: {video_path.name} produced no frames")
                return False

            return True

        except subprocess.TimeoutExpired:
            print(f"Timeout while processing {video_path.name}")
            return False
        except Exception as e:
            print(f"Error while processing {video_path.name}: {e}")
            return False
    def _process_video(self, video_path: Path) -> Tuple[bool, str]:
        """Process a single video file"""
        video_name = video_path.name

        try:
            success = self._extract_frames(video_path)
            if success:
                self.state["completed"].append(video_name)
                if video_name in self.state["failed"]:
                    self.state["failed"].remove(video_name)
                return True, video_name
            else:
                self.state["failed"].append(video_name)
                return False, video_name
        except Exception as e:
            print(f"Exception while processing {video_name}: {e}")
            self.state["failed"].append(video_name)
            return False, video_name
    def process_all_videos(self):
        """Process all video files"""
        if not self.video_files:
            print("No video files to process")
            return

        print(f"Found {len(self.video_files)} video file(s) to process")
        print(f"Output directory: {self.output_dir}")
        print(f"Frame size: {self.frame_size}x{self.frame_size}")
        print(f"Frame rate: {self.fps} fps")
        print(f"Worker threads: {self.num_workers}")

        if self.state["completed"]:
            print(f"Skipping {len(self.state['completed'])} already processed video(s)")

        # Record start time
        if self.state["start_time"] is None:
            self.state["start_time"] = time.time()

        # Create progress bar
        with tqdm(total=len(self.video_files), desc="Processing videos", unit="video") as pbar:
            with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
                # Submit all tasks
                future_to_video = {
                    executor.submit(self._process_video, video_path): video_path
                    for video_path in self.video_files
                }

                # Handle completed tasks
                for future in as_completed(future_to_video):
                    video_path = future_to_video[future]
                    try:
                        success, video_name = future.result()
                        if success:
                            pbar.set_postfix({"status": "ok", "file": video_name[:20]})
                        else:
                            pbar.set_postfix({"status": "failed", "file": video_name[:20]})
                    except Exception as e:
                        print(f"Exception while processing {video_path.name}: {e}")
                        pbar.set_postfix({"status": "error", "file": video_path.name[:20]})

                    pbar.update(1)
                    self.state["total_processed"] += 1

                    # Save state periodically
                    if self.state["total_processed"] % 5 == 0:
                        self._save_state()

        # Final state save
        self._save_state()

        # Print summary
        self._print_summary()
    def _print_summary(self):
        """Print processing summary"""
        print("\n" + "=" * 50)
        print("Processing summary:")
        print(f"Videos processed: {len(self.state['completed'])}")
        print(f"Videos failed: {len(self.state['failed'])}")

        if self.state["failed"]:
            print("\nFailed videos:")
            for video_name in self.state["failed"]:
                print(f"  - {video_name}")

        if self.state["start_time"]:
            elapsed_time = time.time() - self.state["start_time"]
            print(f"\nTotal time: {elapsed_time:.2f} s")
            if self.state["total_processed"] > 0:
                avg_time = elapsed_time / self.state["total_processed"]
                print(f"Average per video: {avg_time:.2f} s")

        print("=" * 50)
def main():
    """Entry point"""
    parser = argparse.ArgumentParser(description="Video preprocessing script")
    parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="Input video directory")
    parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="Output frame directory")
    parser.add_argument("--size", type=int, default=224, help="Frame size (default: 224)")
    parser.add_argument("--fps", type=int, default=10, help="Frame extraction rate (default: 10)")
    parser.add_argument("--workers", type=int, default=32, help="Number of worker threads (default: 32)")
    parser.add_argument("--quality", type=int, default=2, help="JPEG quality 1-31 (default: 2)")
    parser.add_argument("--no-resume", action="store_true", help="Disable resume after interruption")

    args = parser.parse_args()

    # Check the input directory
    if not Path(args.input_dir).exists():
        print(f"Error: input directory does not exist: {args.input_dir}")
        sys.exit(1)

    # Check that FFmpeg is available
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Error: FFmpeg is not installed or not on PATH")
        sys.exit(1)

    # Create the preprocessor and start processing
    preprocessor = VideoPreprocessor(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        frame_size=args.size,
        fps=args.fps,
        num_workers=args.workers,
        quality=args.quality,
        resume=not args.no_resume
    )

    try:
        preprocessor.process_all_videos()
    except KeyboardInterrupt:
        print("\n\nInterrupted by user; state has been saved")
        preprocessor._save_state()
        print("Re-run the same command to resume")
    except Exception as e:
        print(f"\nError during processing: {e}")
        preprocessor._save_state()
        sys.exit(1)


if __name__ == "__main__":
    main()
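A hedged usage sketch (paths are placeholders): the preprocessor can also be driven from Python rather than the CLI, mirroring the argparse flags above.

from video_preprocessor import VideoPreprocessor

preprocessor = VideoPreprocessor(
    input_dir="./videos",      # directory containing .mp4 files
    output_dir="./processed",  # one frame_XXXX.jpg subdirectory per video
    frame_size=224,
    fps=10,
    num_workers=8,
    quality=2,                 # FFmpeg -q:v, lower is higher quality
    resume=True,               # skip videos recorded in .preprocessing_state.json
)
preprocessor.process_all_videos()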