From f7601e91708c8add403b7efd0a2c0305d02d169f Mon Sep 17 00:00:00 2001 From: CaoWangrenbo Date: Thu, 8 Jan 2026 09:43:23 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=8F=AF=E8=B7=91=E9=80=9A?= =?UTF-8?q?=EF=BC=8C=E4=BD=86loss=E8=AE=A1=E7=AE=97=E6=9C=89=E9=97=AE?= =?UTF-8?q?=E9=A2=98=EF=BC=8C=E4=B8=8D=E6=94=B6=E6=95=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 +- dist_temporal_train.sh | 57 +++++++ main_temporal.py | 190 +++++++++++++++++---- models/swiftformer.py | 6 +- models/swiftformer_temporal.py | 14 +- multi_gpu_temporal_train.sh | 26 +++ temporal_train.sh | 0 test_cuda.py | 45 +++++ test_import.py | 33 ++++ util/video_dataset.py | 41 +++-- video_preprocessor.py | 303 +++++++++++++++++++++++++++++++++ 11 files changed, 656 insertions(+), 63 deletions(-) create mode 100755 dist_temporal_train.sh create mode 100755 multi_gpu_temporal_train.sh create mode 100644 temporal_train.sh create mode 100644 test_cuda.py create mode 100644 test_import.py create mode 100644 video_preprocessor.py diff --git a/.gitignore b/.gitignore index 0540009..73f5545 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ +.vscode/ __pycache__/ -venv/ \ No newline at end of file +venv/ +runs/ \ No newline at end of file diff --git a/dist_temporal_train.sh b/dist_temporal_train.sh new file mode 100755 index 0000000..1c35ec6 --- /dev/null +++ b/dist_temporal_train.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +# Distributed training script for SwiftFormerTemporal +# Usage: ./dist_temporal_train.sh [OPTIONS] + +DATA_PATH=$1 +NUM_GPUS=$2 + +# Shift arguments to pass remaining options to python script +shift 2 + +# Default parameters +MODEL=${MODEL:-"SwiftFormerTemporal_XS"} +BATCH_SIZE=${BATCH_SIZE:-32} +EPOCHS=${EPOCHS:-100} +LR=${LR:-1e-3} +OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"} + +echo "Starting distributed training with $NUM_GPUS GPUs" +echo "Data path: $DATA_PATH" +echo "Model: $MODEL" +echo "Batch size: $BATCH_SIZE" +echo "Epochs: $EPOCHS" +echo "Output dir: $OUTPUT_DIR" + +# Check if torch.distributed.launch or torchrun should be used +# For newer PyTorch versions (>=1.9), torchrun is recommended +PYTHON_VERSION=$(python -c "import torch; print(torch.__version__)") +echo "PyTorch version: $PYTHON_VERSION" + +# Use torchrun for newer PyTorch versions +if [[ "$PYTHON_VERSION" =~ ^2\. ]] || [[ "$PYTHON_VERSION" =~ ^1\.1[0-9]\. ]]; then + echo "Using torchrun (PyTorch >=1.10)" + torchrun --nproc_per_node=$NUM_GPUS --master_port=12345 main_temporal.py \ + --data-path "$DATA_PATH" \ + --model "$MODEL" \ + --batch-size $BATCH_SIZE \ + --epochs $EPOCHS \ + --lr $LR \ + --output-dir "$OUTPUT_DIR" \ + "$@" +else + echo "Using torch.distributed.launch" + python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port=12345 --use_env main_temporal.py \ + --data-path "$DATA_PATH" \ + --model "$MODEL" \ + --batch-size $BATCH_SIZE \ + --epochs $EPOCHS \ + --lr $LR \ + --output-dir "$OUTPUT_DIR" \ + "$@" +fi + +# For single-node multi-GPU training with specific options: +# --world-size 1 --rank 0 --dist-url 'tcp://localhost:12345' + +echo "Training completed. Check logs in $OUTPUT_DIR" \ No newline at end of file diff --git a/main_temporal.py b/main_temporal.py index 6cd8126..56f0a98 100644 --- a/main_temporal.py +++ b/main_temporal.py @@ -6,8 +6,10 @@ import datetime import numpy as np import time import torch +import torch.nn as nn import torch.backends.cudnn as cudnn import json +import os from pathlib import Path from timm.scheduler import create_scheduler @@ -20,6 +22,17 @@ from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTempo from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset from util.frame_losses import MultiTaskLoss +# Try to import TensorBoard +try: + from torch.utils.tensorboard import SummaryWriter + TENSORBOARD_AVAILABLE = True +except ImportError: + try: + from tensorboardX import SummaryWriter + TENSORBOARD_AVAILABLE = True + except ImportError: + TENSORBOARD_AVAILABLE = False + def get_args_parser(): parser = argparse.ArgumentParser( @@ -48,10 +61,48 @@ def get_args_parser(): # Training parameters parser.add_argument('--batch-size', default=32, type=int) parser.add_argument('--epochs', default=100, type=int) - parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', - help='learning rate (default: 1e-3)') + + # Optimizer parameters (required by timm's create_optimizer) + parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') + parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') + parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') + parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') + parser.add_argument('--clip-mode', type=str, default='agc', + help='Gradient clipping mode. One of ("norm", "value", "agc")') + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') parser.add_argument('--weight-decay', type=float, default=0.05, help='weight decay (default: 0.05)') + parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', + help='learning rate (default: 1e-3)') + + # Learning rate schedule parameters (required by timm's create_scheduler) + parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') + parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') + parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') + parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') + parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') + parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') + parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', + help='epoch interval to decay LR') + parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') + parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') + parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') # Loss parameters parser.add_argument('--frame-weight', type=float, default=1.0, @@ -90,26 +141,26 @@ def get_args_parser(): parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') + # TensorBoard logging + parser.add_argument('--tensorboard-logdir', default='./runs', + type=str, help='TensorBoard log directory') + parser.add_argument('--log-images', action='store_true', + help='Log sample images to TensorBoard') + parser.add_argument('--image-log-freq', default=100, type=int, + help='Frequency of logging images (in iterations)') + return parser def build_dataset(is_train, args): """Build video frame dataset""" - if args.dataset_type == 'synthetic': - dataset = SyntheticVideoDataset( - num_samples=1000 if is_train else 200, - num_frames=args.num_frames, - frame_size=args.frame_size, - is_train=is_train - ) - else: - dataset = VideoFrameDataset( - root_dir=args.data_path, - num_frames=args.num_frames, - frame_size=args.frame_size, - is_train=is_train, - max_interval=args.max_interval - ) + dataset = VideoFrameDataset( + root_dir=args.data_path, + num_frames=args.num_frames, + frame_size=args.frame_size, + is_train=is_train, + max_interval=args.max_interval + ) return dataset @@ -203,14 +254,18 @@ def main(args): # Create scheduler lr_scheduler, _ = create_scheduler(args, optimizer) - # Create loss function - criterion = MultiTaskLoss( - frame_weight=args.frame_weight, - contrastive_weight=args.contrastive_weight, - l1_weight=args.l1_weight, - ssim_weight=args.ssim_weight, - use_contrastive=not args.no_contrastive - ) + # Create loss function - simple MSE for Y channel prediction + class MSELossWrapper(nn.Module): + def __init__(self): + super().__init__() + self.mse = nn.MSELoss() + + def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None): + loss = self.mse(pred_frame, target_frame) + loss_dict = {'mse': loss} + return loss, loss_dict + + criterion = MSELossWrapper() # Resume from checkpoint output_dir = Path(args.output_dir) @@ -231,6 +286,21 @@ def main(args): if 'scaler' in checkpoint: loss_scaler.load_state_dict(checkpoint['scaler']) + # Initialize TensorBoard writer + writer = None + if TENSORBOARD_AVAILABLE and utils.is_main_process(): + from datetime import datetime + # Create log directory with timestamp + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}") + os.makedirs(log_dir, exist_ok=True) + writer = SummaryWriter(log_dir=log_dir) + print(f"TensorBoard logs will be saved to: {log_dir}") + print(f"To view logs, run: tensorboard --logdir={log_dir}") + elif not TENSORBOARD_AVAILABLE and utils.is_main_process(): + print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.") + print("Training will continue without TensorBoard logging.") + if args.eval: test_stats = evaluate(data_loader_val, model, criterion, device) print(f"Test stats: {test_stats}") @@ -239,14 +309,18 @@ def main(args): print(f"Start training for {args.epochs} epochs") start_time = time.time() + # Global step counter for TensorBoard + global_step = 0 + for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) - train_stats = train_one_epoch( + train_stats, global_step = train_one_epoch( model, criterion, data_loader_train, optimizer, device, epoch, loss_scaler, - model_ema=model_ema + model_ema=model_ema, writer=writer, + global_step=global_step, args=args ) lr_scheduler.step(epoch) @@ -266,10 +340,10 @@ def main(args): # Evaluate if epoch % 5 == 0 or epoch == args.epochs - 1: - test_stats = evaluate(data_loader_val, model, criterion, device) + test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch) print(f"Epoch {epoch}: Test stats: {test_stats}") - # Log stats + # Log stats to text file log_stats = { **{f'train_{k}': v for k, v in train_stats.items()}, **{f'test_{k}': v for k, v in test_stats.items()}, @@ -284,18 +358,24 @@ def main(args): total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print(f'Training time {total_time_str}') + + # Close TensorBoard writer + if writer is not None: + writer.close() + print(f"TensorBoard logs saved to: {writer.log_dir}") -def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, - clip_grad=0, clip_mode='norm', model_ema=None, **kwargs): +def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler, + clip_grad=0, clip_mode='norm', model_ema=None, writer=None, + global_step=0, args=None, **kwargs): model.train() metric_logger = utils.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = f'Epoch: [{epoch}]' print_freq = 10 - for input_frames, target_frames, temporal_indices in metric_logger.log_every( - data_loader, print_freq, header): + for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate( + metric_logger.log_every(data_loader, print_freq, header)): input_frames = input_frames.to(device, non_blocking=True) target_frames = target_frames.to(device, non_blocking=True) @@ -305,7 +385,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los with torch.cuda.amp.autocast(): pred_frames, representations = model(input_frames) loss, loss_dict = criterion( - pred_frames, target_frames, + pred_frames, target_frames, representations, temporal_indices ) @@ -322,19 +402,51 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los if model_ema is not None: model_ema.update(model) + # Log to TensorBoard + if writer is not None: + # Log scalar metrics every iteration + writer.add_scalar('train/loss', loss_value, global_step) + writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step) + + # Log individual loss components + for k, v in loss_dict.items(): + if torch.is_tensor(v): + writer.add_scalar(f'train/{k}', v.item(), global_step) + else: + writer.add_scalar(f'train/{k}', v, global_step) + + # Log images periodically + if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0: + with torch.no_grad(): + # Take first sample from batch for visualization + pred_vis, _ = model(input_frames[:1]) + # Convert to appropriate format for TensorBoard + # Assuming frames are in [B, C, H, W] format + writer.add_images('train/input', input_frames[:1], global_step) + writer.add_images('train/target', target_frames[:1], global_step) + writer.add_images('train/predicted', pred_vis[:1], global_step) + # Update metrics metric_logger.update(loss=loss_value) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) for k, v in loss_dict.items(): metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v}) + + global_step += 1 metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) - return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + # Log epoch-level metrics + if writer is not None: + for k, meter in metric_logger.meters.items(): + writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch) + + return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step @torch.no_grad() -def evaluate(data_loader, model, criterion, device): +def evaluate(data_loader, model, criterion, device, writer=None, epoch=0): model.eval() metric_logger = utils.MetricLogger(delimiter=" ") header = 'Test:' @@ -359,6 +471,12 @@ def evaluate(data_loader, model, criterion, device): metric_logger.synchronize_between_processes() print('* Test stats:', metric_logger) + + # Log validation metrics to TensorBoard + if writer is not None: + for k, meter in metric_logger.meters.items(): + writer.add_scalar(f'val/{k}', meter.global_avg, epoch) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} diff --git a/models/swiftformer.py b/models/swiftformer.py index b545557..a93e3dd 100644 --- a/models/swiftformer.py +++ b/models/swiftformer.py @@ -6,9 +6,9 @@ import copy import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.layers import DropPath, trunc_normal_ -from timm.models.registry import register_model -from timm.models.layers.helpers import to_2tuple +from timm.layers import DropPath, trunc_normal_ +from timm.models import register_model +from timm.layers import to_2tuple import einops SwiftFormer_width = { diff --git a/models/swiftformer_temporal.py b/models/swiftformer_temporal.py index 6a105ca..3cac757 100644 --- a/models/swiftformer_temporal.py +++ b/models/swiftformer_temporal.py @@ -7,7 +7,7 @@ from .swiftformer import ( SwiftFormer, SwiftFormer_depth, SwiftFormer_width, stem, Embedding, Stage ) -from timm.models.layers import DropPath, trunc_normal_ +from timm.layers import DropPath, trunc_normal_ class DecoderBlock(nn.Module): @@ -31,7 +31,7 @@ class DecoderBlock(nn.Module): class FramePredictionDecoder(nn.Module): """Lightweight decoder for frame prediction with optional skip connections""" - def __init__(self, embed_dims, output_channels=3, use_skip=False): + def __init__(self, embed_dims, output_channels=1, use_skip=False): super().__init__() self.use_skip = use_skip # Reverse the embed_dims for decoder @@ -53,11 +53,11 @@ class FramePredictionDecoder(nn.Module): decoder_dims[2], decoder_dims[3], kernel_size=3, stride=2, padding=1, output_padding=1 )) - # stage2 to original resolution (4x upsampling total) + # stage2 to original resolution (now 8x upsampling total with stride 4) self.blocks.append(nn.Sequential( nn.ConvTranspose2d( decoder_dims[3], 32, - kernel_size=3, stride=2, padding=1, output_padding=1 + kernel_size=3, stride=4, padding=1, output_padding=3 ), nn.BatchNorm2d(32), nn.ReLU(inplace=True), @@ -104,7 +104,7 @@ class SwiftFormerTemporal(nn.Module): """ SwiftFormer with temporal input for frame prediction. Input: [B, num_frames, H, W] (Y channel only) - Output: predicted frame [B, 3, H, W] and optional representation + Output: predicted frame [B, 1, H, W] and optional representation """ def __init__(self, model_name='XS', @@ -155,7 +155,7 @@ class SwiftFormerTemporal(nn.Module): # Frame prediction decoder if use_decoder: - self.decoder = FramePredictionDecoder(embed_dims, output_channels=3) + self.decoder = FramePredictionDecoder(embed_dims, output_channels=1) # Representation head for pose/velocity prediction if use_representation_head: @@ -201,7 +201,7 @@ class SwiftFormerTemporal(nn.Module): x: input frames of shape [B, num_frames, H, W] Returns: If return_features is False: - pred_frame: predicted frame [B, 3, H, W] (or None) + pred_frame: predicted frame [B, 1, H, W] (or None) representation: optional representation vector [B, representation_dim] (or None) If return_features is True: pred_frame, representation, features (list of stage features) diff --git a/multi_gpu_temporal_train.sh b/multi_gpu_temporal_train.sh new file mode 100755 index 0000000..2ee1403 --- /dev/null +++ b/multi_gpu_temporal_train.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +# Simple multi-GPU training script for SwiftFormerTemporal +# Usage: ./multi_gpu_temporal_train.sh [OPTIONS] + +NUM_GPUS=${1:-2} +shift + +echo "Starting multi-GPU training with $NUM_GPUS GPUs" + +# Set environment variables for distributed training +export MASTER_PORT=12345 +export MASTER_ADDR=localhost +export WORLD_SIZE=$NUM_GPUS + +# Launch training +torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \ + --data-path "./videos" \ + --model SwiftFormerTemporal_XS \ + --batch-size 32 \ + --epochs 100 \ + --lr 1e-3 \ + --output-dir "./temporal_output_multi" \ + --num-workers 8 \ + --pin-mem \ + "$@" \ No newline at end of file diff --git a/temporal_train.sh b/temporal_train.sh new file mode 100644 index 0000000..e69de29 diff --git a/test_cuda.py b/test_cuda.py new file mode 100644 index 0000000..7dfa3aa --- /dev/null +++ b/test_cuda.py @@ -0,0 +1,45 @@ +import torch + +def test_cuda_availability(): + """全面测试CUDA可用性""" + + print("="*50) + print("PyTorch CUDA 测试") + print("="*50) + + # 基本信息 + print(f"PyTorch版本: {torch.__version__}") + print(f"CUDA可用: {torch.cuda.is_available()}") + + if not torch.cuda.is_available(): + print("CUDA不可用,可能原因:") + print("1. 未安装CUDA驱动") + print("2. 安装的是CPU版本的PyTorch") + print("3. CUDA版本与PyTorch不匹配") + return False + + # 设备信息 + device_count = torch.cuda.device_count() + print(f"发现 {device_count} 个CUDA设备") + + for i in range(device_count): + print(f"\n设备 {i}:") + print(f" 名称: {torch.cuda.get_device_name(i)}") + print(f" 内存总量: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB") + print(f" 计算能力: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}") + + # 简单张量测试 + print("\n运行CUDA测试...") + try: + x = torch.randn(3, 3).cuda() + y = torch.randn(3, 3).cuda() + z = x + y + print("CUDA计算测试: 成功!") + print(f"设备上的张量形状: {z.shape}") + return True + except Exception as e: + print(f"CUDA计算测试失败: {e}") + return False + +if __name__ == "__main__": + test_cuda_availability() \ No newline at end of file diff --git a/test_import.py b/test_import.py new file mode 100644 index 0000000..3d764b6 --- /dev/null +++ b/test_import.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +""" +测试 timm 导入是否正常工作 +""" +import sys +print("Python version:", sys.version) + +try: + from timm.layers import to_2tuple, DropPath, trunc_normal_ + from timm.models import register_model + print("✓ 成功导入 timm.layers.to_2tuple") + print("✓ 成功导入 timm.layers.DropPath") + print("✓ 成功导入 timm.layers.trunc_normal_") + print("✓ 成功导入 timm.models.register_model") +except ImportError as e: + print(f"✗ 导入失败: {e}") + sys.exit(1) + +try: + from models.swiftformer import SwiftFormer_XS + print("✓ 成功导入 SwiftFormer_XS") +except ImportError as e: + print(f"✗ 导入 SwiftFormer_XS 失败: {e}") + sys.exit(1) + +try: + from models.swiftformer_temporal import SwiftFormerTemporal_XS + print("✓ 成功导入 SwiftFormerTemporal_XS") +except ImportError as e: + print(f"✗ 导入 SwiftFormerTemporal_XS 失败: {e}") + sys.exit(1) + +print("\n✅ 所有导入测试通过!") \ No newline at end of file diff --git a/util/video_dataset.py b/util/video_dataset.py index 8b6d57f..50ce612 100644 --- a/util/video_dataset.py +++ b/util/video_dataset.py @@ -80,10 +80,13 @@ class VideoFrameDataset(Dataset): else: self.transform = transform - # Normalization (ImageNet stats) + # Normalization for Y channel (single channel) + # Compute average of ImageNet RGB means and stds + y_mean = (0.485 + 0.456 + 0.406) / 3.0 + y_std = (0.229 + 0.224 + 0.225) / 3.0 self.normalize = transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225] + mean=[y_mean], + std=[y_std] ) def _default_transform(self): @@ -114,8 +117,8 @@ class VideoFrameDataset(Dataset): def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Returns: - input_frames: [3 * num_frames, H, W] concatenated input frames - target_frame: [3, H, W] target frame to predict + input_frames: [num_frames, H, W] concatenated input frames (Y channel only) + target_frame: [1, H, W] target frame to predict (Y channel only) temporal_idx: temporal index of target frame (for contrastive loss) """ video_idx, start_idx = self.frame_indices[idx] @@ -141,23 +144,27 @@ class VideoFrameDataset(Dataset): if self.transform: target_frame = self.transform(target_frame) - # Convert to tensors and normalize + # Convert to tensors, normalize, and convert to grayscale (Y channel) input_tensors = [] for frame in input_frames: - tensor = transforms.ToTensor()(frame) - tensor = self.normalize(tensor) - input_tensors.append(tensor) + tensor = transforms.ToTensor()(frame) # [3, H, W] + # Convert RGB to grayscale using weighted sum + # Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL) + gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W] + gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast) + input_tensors.append(gray) - target_tensor = transforms.ToTensor()(target_frame) - target_tensor = self.normalize(target_tensor) + target_tensor = transforms.ToTensor()(target_frame) # [3, H, W] + target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0) + target_gray = self.normalize(target_gray) # Concatenate input frames along channel dimension - input_concatenated = torch.cat(input_tensors, dim=0) + input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W] # Temporal index (for contrastive loss) temporal_idx = torch.tensor(self.num_frames, dtype=torch.long) - return input_concatenated, target_tensor, temporal_idx + return input_concatenated, target_gray, temporal_idx class SyntheticVideoDataset(Dataset): @@ -174,10 +181,12 @@ class SyntheticVideoDataset(Dataset): self.frame_size = frame_size self.is_train = is_train - # Normalization + # Normalization for Y channel (single channel) + y_mean = (0.485 + 0.456 + 0.406) / 3.0 + y_std = (0.229 + 0.224 + 0.225) / 3.0 self.normalize = transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225] + mean=[y_mean], + std=[y_std] ) def __len__(self): diff --git a/video_preprocessor.py b/video_preprocessor.py new file mode 100644 index 0000000..7ff2aac --- /dev/null +++ b/video_preprocessor.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +""" +视频预处理脚本 - 将MP4视频转换为224x224帧图像 +支持多线程并发处理、进度条显示和中断恢复功能 +""" + +import os +import sys +import json +import argparse +import subprocess +import threading +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed +from tqdm import tqdm +import time +from typing import List, Dict, Optional + + +class VideoPreprocessor: + """视频预处理器,支持多线程和中断恢复""" + + def __init__(self, + input_dir: str, + output_dir: str, + frame_size: int = 224, + fps: int = 30, + num_workers: int = 4, + quality: int = 2, + resume: bool = True): + """ + 初始化预处理器 + + Args: + input_dir: 输入视频目录 + output_dir: 输出帧目录 + frame_size: 帧大小(正方形) + fps: 提取帧率 + num_workers: 并发工作线程数 + quality: JPEG质量 (1-31, 数值越小质量越高) + resume: 是否启用中断恢复 + """ + self.input_dir = Path(input_dir) + self.output_dir = Path(output_dir) + self.frame_size = frame_size + self.fps = fps + self.num_workers = num_workers + self.quality = quality + self.resume = resume + + # 状态文件路径 + self.state_file = self.output_dir / ".preprocessing_state.json" + + # 创建输出目录 + self.output_dir.mkdir(parents=True, exist_ok=True) + + # 初始化状态 + self.state = self._load_state() + + # 收集所有视频文件 + self.video_files = self._collect_video_files() + + def _load_state(self) -> Dict: + """加载处理状态""" + if self.resume and self.state_file.exists(): + try: + with open(self.state_file, 'r') as f: + return json.load(f) + except (json.JSONDecodeError, IOError): + print(f"警告: 无法读取状态文件,将重新开始处理") + + return { + "completed": [], + "failed": [], + "total_processed": 0, + "start_time": None, + "last_update": None + } + + def _save_state(self): + """保存处理状态""" + self.state["last_update"] = time.time() + try: + with open(self.state_file, 'w') as f: + json.dump(self.state, f, indent=2) + except IOError as e: + print(f"警告: 无法保存状态文件: {e}") + + def _collect_video_files(self) -> List[Path]: + """收集所有需要处理的视频文件""" + video_files = [] + for file_path in self.input_dir.glob("*.mp4"): + if file_path.name not in self.state["completed"]: + video_files.append(file_path) + + return sorted(video_files) + + def _parse_video_name(self, video_path: Path) -> Dict[str, str]: + """解析视频文件名,使用完整文件名作为ID""" + name_without_ext = video_path.stem + + # 直接使用完整文件名作为ID,确保每个mp4文件有独立的输出目录 + return { + "video_id": name_without_ext, + "start_frame": "unknown", + "end_frame": "unknown", + "full_name": name_without_ext + } + + def _extract_frames(self, video_path: Path) -> bool: + """提取单个视频的帧""" + try: + # 解析视频名称 + video_info = self._parse_video_name(video_path) + output_subdir = self.output_dir / video_info["video_id"] + output_subdir.mkdir(exist_ok=True) + + # 构建FFmpeg命令 + output_pattern = output_subdir / "frame_%04d.jpg" + + cmd = [ + "ffmpeg", + "-i", str(video_path), + "-vf", f"fps={self.fps},scale={self.frame_size}:{self.frame_size}", + "-q:v", str(self.quality), + "-y", # 覆盖输出文件 + str(output_pattern) + ] + + # 执行FFmpeg命令 + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=300 # 5分钟超时 + ) + + if result.returncode != 0: + print(f"FFmpeg错误处理 {video_path.name}: {result.stderr}") + return False + + # 验证输出帧数量 + output_frames = list(output_subdir.glob("frame_*.jpg")) + if len(output_frames) == 0: + print(f"警告: {video_path.name} 没有生成任何帧") + return False + + return True + + except subprocess.TimeoutExpired: + print(f"超时处理 {video_path.name}") + return False + except Exception as e: + print(f"处理 {video_path.name} 时发生错误: {e}") + return False + + def _process_video(self, video_path: Path) -> tuple[bool, str]: + """处理单个视频文件""" + video_name = video_path.name + + try: + success = self._extract_frames(video_path) + + if success: + self.state["completed"].append(video_name) + if video_name in self.state["failed"]: + self.state["failed"].remove(video_name) + return True, video_name + else: + self.state["failed"].append(video_name) + return False, video_name + + except Exception as e: + print(f"处理 {video_name} 时发生异常: {e}") + self.state["failed"].append(video_name) + return False, video_name + + def process_all_videos(self): + """处理所有视频文件""" + if not self.video_files: + print("没有找到需要处理的视频文件") + return + + print(f"找到 {len(self.video_files)} 个待处理视频文件") + print(f"输出目录: {self.output_dir}") + print(f"帧大小: {self.frame_size}x{self.frame_size}") + print(f"帧率: {self.fps} fps") + print(f"并发线程数: {self.num_workers}") + + if self.state["completed"]: + print(f"跳过 {len(self.state['completed'])} 个已处理的视频") + + # 记录开始时间 + if self.state["start_time"] is None: + self.state["start_time"] = time.time() + + # 创建进度条 + with tqdm(total=len(self.video_files), desc="处理视频", unit="个") as pbar: + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + # 提交所有任务 + future_to_video = { + executor.submit(self._process_video, video_path): video_path + for video_path in self.video_files + } + + # 处理完成的任务 + for future in as_completed(future_to_video): + video_path = future_to_video[future] + try: + success, video_name = future.result() + if success: + pbar.set_postfix({"状态": "成功", "文件": video_name[:20]}) + else: + pbar.set_postfix({"状态": "失败", "文件": video_name[:20]}) + except Exception as e: + print(f"处理 {video_path.name} 时发生异常: {e}") + pbar.set_postfix({"状态": "异常", "文件": video_path.name[:20]}) + + pbar.update(1) + self.state["total_processed"] += 1 + + # 定期保存状态 + if self.state["total_processed"] % 5 == 0: + self._save_state() + + # 最终保存状态 + self._save_state() + + # 打印处理结果 + self._print_summary() + + def _print_summary(self): + """打印处理摘要""" + print("\n" + "="*50) + print("处理完成摘要:") + print(f"总处理视频数: {len(self.state['completed'])}") + print(f"失败视频数: {len(self.state['failed'])}") + + if self.state["failed"]: + print("\n失败的视频:") + for video_name in self.state["failed"]: + print(f" - {video_name}") + + if self.state["start_time"]: + elapsed_time = time.time() - self.state["start_time"] + print(f"\n总耗时: {elapsed_time:.2f} 秒") + if self.state["total_processed"] > 0: + avg_time = elapsed_time / self.state["total_processed"] + print(f"平均每个视频: {avg_time:.2f} 秒") + + print("="*50) + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description="视频预处理脚本") + parser.add_argument("--input_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/sekai-real-drone", help="输入视频目录") + parser.add_argument("--output_dir", type=str, default="/home/hexone/Workplace/ws_asmo/vhead/sekai-real-drone/processed", help="输出帧目录") + parser.add_argument("--size", type=int, default=224, help="帧大小 (默认: 224)") + parser.add_argument("--fps", type=int, default=10, help="提取帧率 (默认: 30)") + parser.add_argument("--workers", type=int, default=32, help="并发线程数 (默认: 4)") + parser.add_argument("--quality", type=int, default=2, help="JPEG质量 1-31 (默认: 2)") + parser.add_argument("--no-resume", action="store_true", help="不启用中断恢复") + + args = parser.parse_args() + + # 检查输入目录 + if not Path(args.input_dir).exists(): + print(f"错误: 输入目录不存在: {args.input_dir}") + sys.exit(1) + + # 检查FFmpeg是否可用 + try: + subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True) + except (subprocess.CalledProcessError, FileNotFoundError): + print("错误: FFmpeg未安装或不在PATH中") + sys.exit(1) + + # 创建预处理器并开始处理 + preprocessor = VideoPreprocessor( + input_dir=args.input_dir, + output_dir=args.output_dir, + frame_size=args.size, + fps=args.fps, + num_workers=args.workers, + quality=args.quality, + resume=not args.no_resume + ) + + try: + preprocessor.process_all_videos() + except KeyboardInterrupt: + print("\n\n用户中断处理,状态已保存") + preprocessor._save_state() + print("可以使用相同命令恢复处理") + except Exception as e: + print(f"\n处理过程中发生错误: {e}") + preprocessor._save_state() + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file