""" Training loop for VelocityPredictionModel. Usage: python -m src.velocity_prediction.train [--device cuda:0] python -m src.velocity_prediction.train --resume checkpoints/epoch_020_val_0.123456.pt """ import argparse import os import time import numpy as np from pathlib import Path import torch import torch.nn as nn from torch.utils.tensorboard import SummaryWriter from src.velocity_prediction.config import train_cfg, model_cfg from src.velocity_prediction.model import VelocityPredictionModel, count_parameters from src.velocity_prediction.dataset import create_train_loader, create_val_loader def set_seed(seed: int): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def train_one_epoch( model: nn.Module, loader, optimizer: torch.optim.Optimizer, criterion: nn.Module, device: torch.device, epoch: int, writer: SummaryWriter, log_interval: int = 50, global_step: int = 0, ) -> tuple[float, int]: """Train for one epoch. Returns (avg_loss, updated_global_step).""" model.train() total_loss = 0.0 num_batches = 0 start_time = time.time() for batch_idx, batch in enumerate(loader): events = batch["events"].to(device) # (B, S, 1, H, W) tilt = batch["tilt"].to(device) # (B, S, 3) target = batch["v_body_target"].to(device) # (B, S, 2) # Predict velocity for the last frame in the sequence pred = model(events, tilt) # (B, 2) target_last = target[:, -1, :] # (B, 2) loss = criterion(pred, target_last) optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() num_batches += 1 global_step += 1 if batch_idx % log_interval == 0: elapsed = time.time() - start_time print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s") writer.add_scalar("train/loss_batch", loss.item(), global_step) avg_loss = total_loss / max(num_batches, 1) print(f" Epoch {epoch} | Avg Loss: {avg_loss:.6f}") return avg_loss, global_step @torch.no_grad() def validate( model: nn.Module, loader, criterion: nn.Module, device: torch.device, ) -> float: """Validate. Returns average loss.""" model.eval() total_loss = 0.0 num_batches = 0 for batch in loader: events = batch["events"].to(device) tilt = batch["tilt"].to(device) target = batch["v_body_target"].to(device) pred = model(events, tilt) target_last = target[:, -1, :] loss = criterion(pred, target_last) total_loss += loss.item() num_batches += 1 avg_loss = total_loss / max(num_batches, 1) return avg_loss def main(): parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda", help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')") parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint .pt file to resume training from") args = parser.parse_args() set_seed(train_cfg.seed) device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu") print(f"Device: {device}") # Create model model = VelocityPredictionModel() model.to(device) total_params = count_parameters(model) print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)") # Data loaders train_loader = create_train_loader( seq_len=train_cfg.seq_len, batch_size=train_cfg.batch_size, num_workers=train_cfg.num_workers, event_threshold=train_cfg.event_threshold, event_use_log=train_cfg.event_use_log, ) val_loader = create_val_loader( seq_len=train_cfg.seq_len, batch_size=train_cfg.batch_size, num_workers=train_cfg.num_workers, event_threshold=train_cfg.event_threshold, event_use_log=train_cfg.event_use_log, ) # Optimizer & scheduler optimizer = torch.optim.AdamW( model.parameters(), lr=train_cfg.lr, weight_decay=train_cfg.weight_decay, ) scheduler = torch.optim.lr_scheduler.StepLR( optimizer, step_size=train_cfg.lr_scheduler_step, gamma=train_cfg.lr_scheduler_gamma, ) # criterion = nn.SmoothL1Loss() criterion = nn.MSELoss() ckpt_dir = Path(train_cfg.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) # ── Resume from checkpoint ──────────────────────────────────── start_epoch = 1 global_step = 0 best_val_loss = float("inf") run_id = None if args.resume is not None: ckpt_path = Path(args.resume) if not ckpt_path.exists(): raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") ckpt = torch.load(ckpt_path, map_location="cpu") model.load_state_dict(ckpt["model_state_dict"]) optimizer.load_state_dict(ckpt["optimizer_state_dict"]) if "scheduler_state_dict" in ckpt: scheduler.load_state_dict(ckpt["scheduler_state_dict"]) start_epoch = ckpt.get("epoch", 0) + 1 global_step = ckpt.get("global_step", 0) best_val_loss = ckpt.get("best_val_loss", float("inf")) run_id = ckpt.get("run_id", None) print(f"Resumed from checkpoint: {ckpt_path}") print(f" Resumed epoch={ckpt.get('epoch', '?')}, global_step={global_step}, " f"best_val_loss={best_val_loss:.6f}") else: print(f"\nStarting training for {train_cfg.epochs} epochs...") # Logging — run-specific subdirectory for isolation + resume continuity if run_id is None: run_id = time.strftime("run_%Y%m%d_%H%M%S") log_dir = Path(train_cfg.log_dir) / run_id log_dir.mkdir(parents=True, exist_ok=True) writer = SummaryWriter(log_dir=str(log_dir)) print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}") print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}") print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n") for epoch in range(start_epoch, train_cfg.epochs + 1): epoch_start = time.time() train_loss, global_step = train_one_epoch( model, train_loader, optimizer, criterion, device, epoch, writer, log_interval=train_cfg.log_interval, global_step=global_step, ) val_loss = validate(model, val_loader, criterion, device) scheduler.step() epoch_time = time.time() - epoch_start current_lr = scheduler.get_last_lr()[0] print(f"Epoch {epoch:3d}/{train_cfg.epochs} | " f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | " f"LR: {current_lr:.2e} | Time: {epoch_time:.1f}s") writer.add_scalar("train/loss_epoch", train_loss, epoch) writer.add_scalar("val/loss", val_loss, epoch) writer.add_scalar("lr", current_lr, epoch) # Save checkpoint if epoch % train_cfg.save_interval == 0: ckpt_path = ckpt_dir / f"epoch_{epoch:03d}_val_{val_loss:.6f}.pt" torch.save({ "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "global_step": global_step, "best_val_loss": best_val_loss, "run_id": run_id, "val_loss": val_loss, }, ckpt_path) print(f" Checkpoint saved: {ckpt_path}") # Save best model if val_loss < best_val_loss: best_val_loss = val_loss best_path = ckpt_dir / "best.pt" torch.save({ "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), "global_step": global_step, "best_val_loss": best_val_loss, "run_id": run_id, "val_loss": val_loss, }, best_path) print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})") writer.close() print(f"\nTraining complete. Best val loss: {best_val_loss:.6f}") print(f"Best checkpoint: {ckpt_dir / 'best.pt'}") if __name__ == "__main__": main()