Runs end to end for now, but the loss computation is broken and training does not converge

2026-01-08 09:43:23 +08:00
parent efd76bccd2
commit f7601e9170
11 changed files with 656 additions and 63 deletions


@@ -6,8 +6,10 @@ import datetime
import numpy as np
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.scheduler import create_scheduler
@@ -20,6 +22,17 @@ from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTempo
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
try:
from tensorboardX import SummaryWriter
TENSORBOARD_AVAILABLE = True
except ImportError:
TENSORBOARD_AVAILABLE = False
def get_args_parser():
parser = argparse.ArgumentParser(
@@ -48,10 +61,48 @@ def get_args_parser():
# Training parameters
parser.add_argument('--batch-size', default=32, type=int)
parser.add_argument('--epochs', default=100, type=int)
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
# Optimizer parameters (required by timm's create_optimizer)
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
help='Clip gradient norm (default: 0.01)')
parser.add_argument('--clip-mode', type=str, default='agc',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.05,
help='weight decay (default: 0.05)')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
# Learning rate schedule parameters (required by timm's create_scheduler)
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
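For reference, a minimal sketch of how timm's args-based factories consume the fields above (this assumes an older timm release that still exposes create_optimizer(args, model) alongside create_scheduler(args, optimizer); the nn.Linear stand-in model and the trimmed field set are illustrative, not the script's actual wiring):

import argparse
import torch.nn as nn
from timm.optim import create_optimizer
from timm.scheduler import create_scheduler

# Illustrative namespace mirroring the argparse defaults defined above.
args = argparse.Namespace(
    opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9,
    opt_eps=1e-8, opt_betas=None,
    sched='cosine', epochs=100, warmup_lr=1e-6, min_lr=1e-5,
    warmup_epochs=5, cooldown_epochs=10, decay_epochs=30, decay_rate=0.1,
)
model = nn.Linear(8, 8)  # stand-in for the real model
optimizer = create_optimizer(args, model)                      # reads opt / lr / weight_decay / momentum / opt_eps / opt_betas
lr_scheduler, num_epochs = create_scheduler(args, optimizer)   # reads sched / epochs / warmup_* / min_lr / cooldown_epochs
lr_scheduler.step(0)                                           # main() steps this once per epoch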
# Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0,
@@ -90,26 +141,26 @@ def get_args_parser():
parser.add_argument('--dist-url', default='env://',
help='url used to set up distributed training')
# TensorBoard logging
parser.add_argument('--tensorboard-logdir', default='./runs',
type=str, help='TensorBoard log directory')
parser.add_argument('--log-images', action='store_true',
help='Log sample images to TensorBoard')
parser.add_argument('--image-log-freq', default=100, type=int,
help='Frequency of logging images (in iterations)')
return parser
def build_dataset(is_train, args):
"""Build video frame dataset"""
if args.dataset_type == 'synthetic':
dataset = SyntheticVideoDataset(
num_samples=1000 if is_train else 200,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train
)
else:
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
return dataset
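The samplers and data loaders are created in an elided part of main(); below is a minimal sketch of the wiring that data_loader_train.sampler.set_epoch(epoch) further down relies on (the TensorDataset stand-in and the literal batch size are illustrative; only args.batch_size appears in the script):

import torch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, TensorDataset

# Stand-in for the (input_frames, target_frames, temporal_indices) triples the real datasets yield.
dataset_train = TensorDataset(
    torch.rand(64, 4, 224, 224),   # 4 stacked Y-channel input frames (illustrative shape)
    torch.rand(64, 1, 224, 224),   # target next frame
    torch.arange(64),              # temporal index
)

distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
if distributed:
    # One shard per process; sampler.set_epoch(epoch) reshuffles the shards every epoch.
    sampler_train = DistributedSampler(dataset_train, shuffle=True)
else:
    sampler_train = RandomSampler(dataset_train)

data_loader_train = DataLoader(
    dataset_train, sampler=sampler_train, batch_size=32, drop_last=True)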
@@ -203,14 +254,18 @@ def main(args):
# Create scheduler
lr_scheduler, _ = create_scheduler(args, optimizer)
# Create loss function
criterion = MultiTaskLoss(
frame_weight=args.frame_weight,
contrastive_weight=args.contrastive_weight,
l1_weight=args.l1_weight,
ssim_weight=args.ssim_weight,
use_contrastive=not args.no_contrastive
)
# Create loss function - simple MSE for Y channel prediction
class MSELossWrapper(nn.Module):
def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
loss = self.mse(pred_frame, target_frame)
loss_dict = {'mse': loss}
return loss, loss_dict
criterion = MSELossWrapper()
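For orientation, a quick check of the wrapper's call convention as train_one_epoch uses it (reusing the MSELossWrapper defined above; the dummy tensor shapes are illustrative):

import torch

criterion = MSELossWrapper()
pred_frames = torch.rand(2, 1, 224, 224, requires_grad=True)   # stand-in model output (Y channel)
target_frames = torch.rand(2, 1, 224, 224)                     # stand-in ground-truth frames
loss, loss_dict = criterion(pred_frames, target_frames,
                            representations=None, temporal_indices=None)
loss.backward()                  # single scalar; the extra arguments are accepted but ignored
print(loss_dict['mse'].item())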
# Resume from checkpoint
output_dir = Path(args.output_dir)
@@ -231,6 +286,21 @@ def main(args):
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
# Initialize TensorBoard writer
writer = None
if TENSORBOARD_AVAILABLE and utils.is_main_process():
from datetime import datetime
# Create log directory with timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
print(f"TensorBoard logs will be saved to: {log_dir}")
print(f"To view logs, run: tensorboard --logdir={log_dir}")
elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
print("Training will continue without TensorBoard logging.")
if args.eval:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Test stats: {test_stats}")
@@ -239,14 +309,18 @@ def main(args):
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
# Global step counter for TensorBoard
global_step = 0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
train_stats, global_step = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
model_ema=model_ema
model_ema=model_ema, writer=writer,
global_step=global_step, args=args
)
lr_scheduler.step(epoch)
@@ -266,10 +340,10 @@ def main(args):
# Evaluate
if epoch % 5 == 0 or epoch == args.epochs - 1:
test_stats = evaluate(data_loader_val, model, criterion, device)
test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
print(f"Epoch {epoch}: Test stats: {test_stats}")
# Log stats
# Log stats to text file
log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
@@ -284,18 +358,24 @@ def main(args):
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')
# Close TensorBoard writer
if writer is not None:
writer.close()
print(f"TensorBoard logs saved to: {writer.log_dir}")
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
global_step=0, args=None, **kwargs):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = f'Epoch: [{epoch}]'
print_freq = 10
for input_frames, target_frames, temporal_indices in metric_logger.log_every(
data_loader, print_freq, header):
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
metric_logger.log_every(data_loader, print_freq, header)):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
@@ -305,7 +385,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
pred_frames, target_frames,
representations, temporal_indices
)
@@ -322,19 +402,51 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
if model_ema is not None:
model_ema.update(model)
# Log to TensorBoard
if writer is not None:
# Log scalar metrics every iteration
writer.add_scalar('train/loss', loss_value, global_step)
writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
# Log individual loss components
for k, v in loss_dict.items():
if torch.is_tensor(v):
writer.add_scalar(f'train/{k}', v.item(), global_step)
else:
writer.add_scalar(f'train/{k}', v, global_step)
# Log images periodically
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
with torch.no_grad():
# Take first sample from batch for visualization
pred_vis, _ = model(input_frames[:1])
# Convert to appropriate format for TensorBoard
# Assuming frames are in [B, C, H, W] format
writer.add_images('train/input', input_frames[:1], global_step)
writer.add_images('train/target', target_frames[:1], global_step)
writer.add_images('train/predicted', pred_vis[:1], global_step)
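Note that the grid conversion behind add_images expects single- or 3-channel images, so if the model input stacks several frames along the channel axis, a small reshape keeps the logging call valid (frames_to_grid is a hypothetical helper, not part of this commit; it assumes single-channel Y frames):

import torch

def frames_to_grid(x, num_frames):
    # [1, T*C, H, W] -> [T, C, H, W]: each stacked frame becomes one image in the grid
    b, tc, h, w = x.shape
    c = tc // num_frames
    return x.reshape(b * num_frames, c, h, w)

# usage sketch with 4 stacked Y-channel frames
dummy = torch.rand(1, 4, 64, 64)
# writer.add_images('train/input', frames_to_grid(dummy, num_frames=4), global_step)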
# Update metrics
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
global_step += 1
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
# Log epoch-level metrics
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step
@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
@@ -359,6 +471,12 @@ def evaluate(data_loader, model, criterion, device):
metric_logger.synchronize_between_processes()
print('* Test stats:', metric_logger)
# Log validation metrics to TensorBoard
if writer is not None:
for k, meter in metric_logger.meters.items():
writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
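Since the validation metric logged here is a plain Y-channel MSE, converting the averaged value to PSNR can make the convergence trend easier to read; a small helper (not part of this commit, assumes frames normalized to [0, 1]):

import math

def mse_to_psnr(mse, max_val=1.0):
    # PSNR = 10 * log10(MAX^2 / MSE); higher is better
    if mse <= 0:
        return float('inf')
    return 10.0 * math.log10((max_val ** 2) / mse)

print(mse_to_psnr(0.01))   # a val MSE of 0.01 on [0, 1] frames -> 20.0 dB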