Initial version runs end to end, but the loss computation is buggy and training does not converge
main_temporal.py
@@ -6,8 +6,10 @@ import datetime
 import numpy as np
 import time
 import torch
+import torch.nn as nn
 import torch.backends.cudnn as cudnn
 import json
+import os
 from pathlib import Path

 from timm.scheduler import create_scheduler
@@ -20,6 +22,17 @@ from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTempo
 from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
 from util.frame_losses import MultiTaskLoss

+# Try to import TensorBoard
+try:
+    from torch.utils.tensorboard import SummaryWriter
+    TENSORBOARD_AVAILABLE = True
+except ImportError:
+    try:
+        from tensorboardX import SummaryWriter
+        TENSORBOARD_AVAILABLE = True
+    except ImportError:
+        TENSORBOARD_AVAILABLE = False
+

 def get_args_parser():
     parser = argparse.ArgumentParser(
@@ -48,10 +61,48 @@ def get_args_parser():
     # Training parameters
     parser.add_argument('--batch-size', default=32, type=int)
     parser.add_argument('--epochs', default=100, type=int)
-    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
-                        help='learning rate (default: 1e-3)')

+    # Optimizer parameters (required by timm's create_optimizer)
+    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+                        help='Optimizer (default: "adamw"')
+    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                        help='Optimizer Epsilon (default: 1e-8)')
+    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+                        help='Optimizer Betas (default: None, use opt default)')
+    parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
+                        help='Clip gradient norm (default: None, no clipping)')
+    parser.add_argument('--clip-mode', type=str, default='agc',
+                        help='Gradient clipping mode. One of ("norm", "value", "agc")')
+    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                        help='SGD momentum (default: 0.9)')
+    parser.add_argument('--weight-decay', type=float, default=0.05,
+                        help='weight decay (default: 0.05)')
+    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
+                        help='learning rate (default: 1e-3)')

+    # Learning rate schedule parameters (required by timm's create_scheduler)
+    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+                        help='LR scheduler (default: "cosine"')
+    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                        help='learning rate noise on/off epoch percentages')
+    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+                        help='learning rate noise limit percent (default: 0.67)')
+    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                        help='learning rate noise std-dev (default: 1.0)')
+    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+                        help='epoch interval to decay LR')
+    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+                        help='epochs to warmup LR, if scheduler supports')
+    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+                        help='patience epochs for Plateau LR scheduler (default: 10')
+    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+                        help='LR decay rate (default: 0.1)')

     # Loss parameters
     parser.add_argument('--frame-weight', type=float, default=1.0,
@@ -90,26 +141,26 @@ def get_args_parser():
     parser.add_argument('--dist-url', default='env://',
                         help='url used to set up distributed training')

+    # TensorBoard logging
+    parser.add_argument('--tensorboard-logdir', default='./runs',
+                        type=str, help='TensorBoard log directory')
+    parser.add_argument('--log-images', action='store_true',
+                        help='Log sample images to TensorBoard')
+    parser.add_argument('--image-log-freq', default=100, type=int,
+                        help='Frequency of logging images (in iterations)')

     return parser


 def build_dataset(is_train, args):
     """Build video frame dataset"""
     if args.dataset_type == 'synthetic':
         dataset = SyntheticVideoDataset(
             num_samples=1000 if is_train else 200,
             num_frames=args.num_frames,
             frame_size=args.frame_size,
             is_train=is_train
         )
     else:
         dataset = VideoFrameDataset(
             root_dir=args.data_path,
             num_frames=args.num_frames,
             frame_size=args.frame_size,
             is_train=is_train,
             max_interval=args.max_interval
         )
-    dataset = VideoFrameDataset(
-        root_dir=args.data_path,
-        num_frames=args.num_frames,
-        frame_size=args.frame_size,
-        is_train=is_train,
-        max_interval=args.max_interval
-    )

     return dataset
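For context, the dataset returned by build_dataset feeds the data_loader_train / data_loader_val loaders used later in main(). A minimal sketch of that wiring follows; the DataLoader keyword values are assumptions rather than part of this commit, and the real code also swaps in a DistributedSampler when args.distributed is set:

from torch.utils.data import DataLoader

# Sketch only; num_workers / pin_memory / drop_last are assumed values, not from this commit.
dataset_train = build_dataset(is_train=True, args=args)
dataset_val = build_dataset(is_train=False, args=args)

data_loader_train = DataLoader(dataset_train, batch_size=args.batch_size, shuffle=True,
                               num_workers=4, pin_memory=True, drop_last=True)
data_loader_val = DataLoader(dataset_val, batch_size=args.batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)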
@@ -203,14 +254,18 @@ def main(args):
     # Create scheduler
     lr_scheduler, _ = create_scheduler(args, optimizer)

-    # Create loss function
-    criterion = MultiTaskLoss(
-        frame_weight=args.frame_weight,
-        contrastive_weight=args.contrastive_weight,
-        l1_weight=args.l1_weight,
-        ssim_weight=args.ssim_weight,
-        use_contrastive=not args.no_contrastive
-    )
+    # Create loss function - simple MSE for Y channel prediction
+    class MSELossWrapper(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.mse = nn.MSELoss()
+
+        def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
+            loss = self.mse(pred_frame, target_frame)
+            loss_dict = {'mse': loss}
+            return loss, loss_dict
+
+    criterion = MSELossWrapper()

     # Resume from checkpoint
     output_dir = Path(args.output_dir)
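Note that MSELossWrapper keeps the same call signature as the MultiTaskLoss it replaces (representations and temporal_indices are accepted but ignored), so train_one_epoch and evaluate can call it unchanged. A minimal, self-contained sanity check of that interface, using hypothetical tensor shapes (batch of 2 single-channel Y frames at 64x64, not taken from this commit):

import torch
import torch.nn as nn

class MSELossWrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
        loss = self.mse(pred_frame, target_frame)
        return loss, {'mse': loss}

# Hypothetical shapes: batch of 2, 1 channel (Y), 64x64.
pred = torch.rand(2, 1, 64, 64)
target = torch.rand(2, 1, 64, 64)
loss, loss_dict = MSELossWrapper()(pred, target)
print(loss.item(), {k: v.item() for k, v in loss_dict.items()})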
@@ -231,6 +286,21 @@ def main(args):
         if 'scaler' in checkpoint:
             loss_scaler.load_state_dict(checkpoint['scaler'])

+    # Initialize TensorBoard writer
+    writer = None
+    if TENSORBOARD_AVAILABLE and utils.is_main_process():
+        from datetime import datetime
+        # Create log directory with timestamp
+        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+        log_dir = os.path.join(args.tensorboard_logdir, f"exp_{timestamp}")
+        os.makedirs(log_dir, exist_ok=True)
+        writer = SummaryWriter(log_dir=log_dir)
+        print(f"TensorBoard logs will be saved to: {log_dir}")
+        print(f"To view logs, run: tensorboard --logdir={log_dir}")
+    elif not TENSORBOARD_AVAILABLE and utils.is_main_process():
+        print("Warning: TensorBoard not available. Install tensorboard or tensorboardX.")
+        print("Training will continue without TensorBoard logging.")
+
     if args.eval:
         test_stats = evaluate(data_loader_val, model, criterion, device)
         print(f"Test stats: {test_stats}")
@@ -239,14 +309,18 @@ def main(args):
     print(f"Start training for {args.epochs} epochs")
     start_time = time.time()

+    # Global step counter for TensorBoard
+    global_step = 0
+
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             data_loader_train.sampler.set_epoch(epoch)

-        train_stats = train_one_epoch(
+        train_stats, global_step = train_one_epoch(
             model, criterion, data_loader_train,
             optimizer, device, epoch, loss_scaler,
-            model_ema=model_ema
+            model_ema=model_ema, writer=writer,
+            global_step=global_step, args=args
         )

         lr_scheduler.step(epoch)
@@ -266,10 +340,10 @@ def main(args):

         # Evaluate
         if epoch % 5 == 0 or epoch == args.epochs - 1:
-            test_stats = evaluate(data_loader_val, model, criterion, device)
+            test_stats = evaluate(data_loader_val, model, criterion, device, writer=writer, epoch=epoch)
             print(f"Epoch {epoch}: Test stats: {test_stats}")

-        # Log stats
+        # Log stats to text file
         log_stats = {
             **{f'train_{k}': v for k, v in train_stats.items()},
             **{f'test_{k}': v for k, v in test_stats.items()},
@@ -284,18 +358,24 @@ def main(args):
     total_time = time.time() - start_time
     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
     print(f'Training time {total_time_str}')

+    # Close TensorBoard writer
+    if writer is not None:
+        writer.close()
+        print(f"TensorBoard logs saved to: {writer.log_dir}")
+

-def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
-                    clip_grad=0, clip_mode='norm', model_ema=None, **kwargs):
+def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
+                    clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
+                    global_step=0, args=None, **kwargs):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = f'Epoch: [{epoch}]'
     print_freq = 10

-    for input_frames, target_frames, temporal_indices in metric_logger.log_every(
-            data_loader, print_freq, header):
+    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
+            metric_logger.log_every(data_loader, print_freq, header)):

         input_frames = input_frames.to(device, non_blocking=True)
         target_frames = target_frames.to(device, non_blocking=True)
@@ -305,7 +385,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         with torch.cuda.amp.autocast():
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
-                pred_frames, target_frames,
+                pred_frames, target_frames,
                 representations, temporal_indices
             )

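The backward pass and optimizer step for this iteration fall between the hunks shown here. Given the loss_scaler, clip_grad, and clip_mode arguments visible in this file, it presumably follows timm's NativeScaler pattern; the following is a sketch under that assumption, not part of this diff:

# Sketch assuming loss_scaler is timm.utils.NativeScaler, as in the DeiT/SwiftFormer training loop.
loss_value = loss.item()

optimizer.zero_grad()
loss_scaler(loss, optimizer,
            clip_grad=clip_grad, clip_mode=clip_mode,
            parameters=model.parameters(),
            create_graph=False)
torch.cuda.synchronize()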
@@ -322,19 +402,51 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         if model_ema is not None:
             model_ema.update(model)

+        # Log to TensorBoard
+        if writer is not None:
+            # Log scalar metrics every iteration
+            writer.add_scalar('train/loss', loss_value, global_step)
+            writer.add_scalar('train/lr', optimizer.param_groups[0]["lr"], global_step)
+
+            # Log individual loss components
+            for k, v in loss_dict.items():
+                if torch.is_tensor(v):
+                    writer.add_scalar(f'train/{k}', v.item(), global_step)
+                else:
+                    writer.add_scalar(f'train/{k}', v, global_step)
+
+            # Log images periodically
+            if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
+                with torch.no_grad():
+                    # Take first sample from batch for visualization
+                    pred_vis, _ = model(input_frames[:1])
+                    # Convert to appropriate format for TensorBoard
+                    # Assuming frames are in [B, C, H, W] format
+                    writer.add_images('train/input', input_frames[:1], global_step)
+                    writer.add_images('train/target', target_frames[:1], global_step)
+                    writer.add_images('train/predicted', pred_vis[:1], global_step)
+
         # Update metrics
         metric_logger.update(loss=loss_value)
         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
         for k, v in loss_dict.items():
             metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})

+        global_step += 1
+
     metric_logger.synchronize_between_processes()
     print("Averaged stats:", metric_logger)
-    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+    # Log epoch-level metrics
+    if writer is not None:
+        for k, meter in metric_logger.meters.items():
+            writer.add_scalar(f'train_epoch/{k}', meter.global_avg, epoch)
+
+    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, global_step


 @torch.no_grad()
-def evaluate(data_loader, model, criterion, device):
+def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter=" ")
     header = 'Test:'
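One possible simplification of the image-logging branch above: it currently runs a second forward pass on input_frames[:1] inside torch.no_grad(), even though pred_frames from the main forward pass is already in scope. A hedged alternative sketch (not part of this commit) that reuses the existing prediction:

# Alternative sketch: reuse the prediction already computed in the training step.
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
    writer.add_images('train/input', input_frames[:1], global_step)
    writer.add_images('train/target', target_frames[:1], global_step)
    # detach/clamp so the logged tensor is grad-free and within the [0, 1] range add_images expects
    writer.add_images('train/predicted', pred_frames[:1].detach().float().clamp(0, 1), global_step)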
@@ -359,6 +471,12 @@ def evaluate(data_loader, model, criterion, device):

     metric_logger.synchronize_between_processes()
     print('* Test stats:', metric_logger)

+    # Log validation metrics to TensorBoard
+    if writer is not None:
+        for k, meter in metric_logger.meters.items():
+            writer.add_scalar(f'val/{k}', meter.global_avg, epoch)
+
     return {k: meter.global_avg for k, meter in metric_logger.meters.items()}