- Add --resume CLI arg to resume training from a checkpoint - Restore model, optimizer, scheduler state; continue from saved epoch+1 - Preserve global_step and best_val_loss across resume - Save run_id in checkpoints for TensorBoard log continuity - Use logs/run_<timestamp>/ subdirectories to isolate experiment logs - Fix: replace train_loss in checkpoint dict with global_step to avoid KeyError when loading; track global_step through train_one_epoch - Fix: use global_step (not batch_idx) as TensorBoard x-axis for batch loss - Fix: print average loss at end of each epoch Generated by Mistral Vibe (ds-v4-flash). Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
253 lines
8.5 KiB
Python
253 lines
8.5 KiB
Python
"""
Training loop for VelocityPredictionModel.

Usage:
    Start a fresh run (optionally selecting the CUDA device):
        python -m src.velocity_prediction.train [--device cuda:0]

    Resume training from a saved checkpoint:
        python -m src.velocity_prediction.train --resume checkpoints/epoch_020_val_0.123456.pt
"""
|
|
|
|
import argparse
|
|
import os
|
|
import time
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
from src.velocity_prediction.config import train_cfg, model_cfg
|
|
from src.velocity_prediction.model import VelocityPredictionModel, count_parameters
|
|
from src.velocity_prediction.dataset import create_train_loader, create_val_loader
|
|
|
|
|
|
def set_seed(seed: int):
    """Seed torch's CPU generator, all CUDA generators, and NumPy's global RNG.

    Args:
        seed: Integer seed applied to each of those generators.
    """
    # These generators are independent, so the order in which they are seeded
    # does not matter.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
|
|
|
|
|
|
def train_one_epoch(
    model: nn.Module,
    loader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epoch: int,
    writer: "SummaryWriter",
    log_interval: int = 50,
    global_step: int = 0,
) -> tuple[float, int]:
    """Train ``model`` for one pass over ``loader``.

    Each batch is a mapping with ``"events"``, ``"tilt"`` and
    ``"v_body_target"`` tensors. The model receives the whole sequence. Its
    output is compared only against the target of the last time step.

    Args:
        model: Network called as ``model(events, tilt)``.
        loader: Iterable of batch mappings (see above).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss module applied to ``(pred, target_last)``.
        device: Device the batch tensors are moved to.
        epoch: Epoch number, used only in log messages.
        writer: TensorBoard writer receiving ``train/loss_batch`` scalars.
        log_interval: Print and log the batch loss every ``log_interval``
            batches, starting with batch 0. Values <= 0 disable per-batch
            logging.
        global_step: Optimizer-step counter carried over from earlier epochs.
            It is the x-axis for ``train/loss_batch``.

    Returns:
        ``(avg_loss, global_step)``: the mean of the per-batch losses (0.0 for
        an empty loader) and the step counter after this epoch.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    start_time = time.time()

    for batch_idx, batch in enumerate(loader):
        events = batch["events"].to(device)  # (B, S, 1, H, W)
        tilt = batch["tilt"].to(device)  # (B, S, 3)
        target = batch["v_body_target"].to(device)  # (B, S, 2)

        # Predict velocity for the last frame in the sequence
        pred = model(events, tilt)  # (B, 2)
        target_last = target[:, -1, :]  # (B, 2)

        loss = criterion(pred, target_last)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once. Each .item() is a device->host copy, and a
        # sync point on CUDA.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        global_step += 1

        # Guard log_interval so a non-positive value disables logging
        # instead of raising ZeroDivisionError in the modulo.
        if log_interval > 0 and batch_idx % log_interval == 0:
            elapsed = time.time() - start_time
            print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss_value:.6f} | {elapsed:.1f}s")
            writer.add_scalar("train/loss_batch", loss_value, global_step)

    avg_loss = total_loss / max(num_batches, 1)
    print(f" Epoch {epoch} | Avg Loss: {avg_loss:.6f}")
    return avg_loss, global_step
|
|
|
|
|
|
@torch.no_grad()
def validate(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """Evaluate ``model`` on ``loader`` in eval mode with gradients disabled.

    Mirrors the training objective: each prediction is scored only against the
    target of the last time step of its sequence.

    Returns:
        Mean of the per-batch ``criterion`` values (0.0 for an empty loader).
    """
    model.eval()

    running_total = 0.0
    batch_count = 0
    for batch_count, sample in enumerate(loader, start=1):
        events = sample["events"].to(device)
        tilt = sample["tilt"].to(device)
        last_target = sample["v_body_target"].to(device)[:, -1, :]

        prediction = model(events, tilt)
        running_total += criterion(prediction, last_target).item()

    return running_total / max(batch_count, 1)
|
|
|
|
|
|
def _save_checkpoint(
    path: Path,
    *,
    epoch: int,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    global_step: int,
    best_val_loss: float,
    run_id: str,
    val_loss: float,
) -> None:
    """Serialize a resumable training checkpoint to ``path`` with ``torch.save``.

    Stores every key that the ``--resume`` branch of ``main`` reads back:
    ``epoch``, ``model_state_dict``, ``optimizer_state_dict``,
    ``scheduler_state_dict``, ``global_step``, ``best_val_loss`` and ``run_id``.
    ``val_loss`` is also stored, for reference.
    """
    torch.save({
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "global_step": global_step,
        "best_val_loss": best_val_loss,
        "run_id": run_id,
        "val_loss": val_loss,
    }, path)


def main():
    """Parse CLI arguments, build model/data/optimizer, and run the training loop.

    With ``--resume PATH``:
    - Model and optimizer state (and scheduler state, if present) are restored
      from the checkpoint.
    - Training continues at the saved epoch + 1.
    - ``global_step`` and ``best_val_loss`` carry over.
    - TensorBoard logs go to the checkpoint's ``run_id`` subdirectory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda",
                        help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to checkpoint .pt file to resume training from")
    args = parser.parse_args()

    set_seed(train_cfg.seed)
    # Use the requested device only if it names CUDA and CUDA is available;
    # otherwise fall back to CPU.
    device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
    print(f"Device: {device}")

    # Create model
    model = VelocityPredictionModel()
    model.to(device)
    total_params = count_parameters(model)
    print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")

    # Data loaders
    train_loader = create_train_loader(
        seq_len=train_cfg.seq_len,
        batch_size=train_cfg.batch_size,
        num_workers=train_cfg.num_workers,
        event_threshold=train_cfg.event_threshold,
        event_use_log=train_cfg.event_use_log,
    )
    val_loader = create_val_loader(
        seq_len=train_cfg.seq_len,
        batch_size=train_cfg.batch_size,
        num_workers=train_cfg.num_workers,
        event_threshold=train_cfg.event_threshold,
        event_use_log=train_cfg.event_use_log,
    )

    # Optimizer & scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_cfg.lr,
        weight_decay=train_cfg.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=train_cfg.lr_scheduler_step,
        gamma=train_cfg.lr_scheduler_gamma,
    )
    # criterion = nn.SmoothL1Loss()
    criterion = nn.MSELoss()

    ckpt_dir = Path(train_cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # ── Resume from checkpoint ────────────────────────────────────
    start_epoch = 1
    global_step = 0
    best_val_loss = float("inf")
    run_id = None

    if args.resume is not None:
        resume_path = Path(args.resume)
        if not resume_path.exists():
            raise FileNotFoundError(f"Checkpoint not found: {resume_path}")
        # NOTE(review): torch.load unpickles the file, so only resume from
        # checkpoints you trust. The saved dict holds only tensors and plain
        # Python values, so weights_only=True should work on torch >= 1.13.
        ckpt = torch.load(resume_path, map_location="cpu")

        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if "scheduler_state_dict" in ckpt:
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        start_epoch = ckpt.get("epoch", 0) + 1
        global_step = ckpt.get("global_step", 0)
        best_val_loss = ckpt.get("best_val_loss", float("inf"))
        run_id = ckpt.get("run_id", None)

        print(f"Resumed from checkpoint: {resume_path}")
        print(f" Resumed epoch={ckpt.get('epoch', '?')}, global_step={global_step}, "
              f"best_val_loss={best_val_loss:.6f}")
    else:
        print(f"\nStarting training for {train_cfg.epochs} epochs...")

    # Logging: a run-specific subdirectory isolates each experiment, and
    # reusing run_id on resume keeps the TensorBoard curves in one place.
    if run_id is None:
        run_id = time.strftime("run_%Y%m%d_%H%M%S")
    log_dir = Path(train_cfg.log_dir) / run_id
    log_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    try:
        print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}")
        print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}")
        print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n")

        for epoch in range(start_epoch, train_cfg.epochs + 1):
            epoch_start = time.time()

            train_loss, global_step = train_one_epoch(
                model, train_loader, optimizer, criterion, device, epoch, writer,
                log_interval=train_cfg.log_interval,
                global_step=global_step,
            )
            val_loss = validate(model, val_loader, criterion, device)
            scheduler.step()

            epoch_time = time.time() - epoch_start
            current_lr = scheduler.get_last_lr()[0]

            print(f"Epoch {epoch:3d}/{train_cfg.epochs} | "
                  f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | "
                  f"LR: {current_lr:.2e} | Time: {epoch_time:.1f}s")

            writer.add_scalar("train/loss_epoch", train_loss, epoch)
            writer.add_scalar("val/loss", val_loss, epoch)
            writer.add_scalar("lr", current_lr, epoch)

            # Update the running best BEFORE writing any checkpoint. A periodic
            # checkpoint taken on a new-best epoch must record this epoch's
            # best_val_loss. Otherwise, resuming from it could later overwrite
            # best.pt with a worse model.
            is_best = val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss

            save_kwargs = dict(
                epoch=epoch,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                global_step=global_step,
                best_val_loss=best_val_loss,
                run_id=run_id,
                val_loss=val_loss,
            )

            # Save checkpoint
            if epoch % train_cfg.save_interval == 0:
                periodic_path = ckpt_dir / f"epoch_{epoch:03d}_val_{val_loss:.6f}.pt"
                _save_checkpoint(periodic_path, **save_kwargs)
                print(f" Checkpoint saved: {periodic_path}")

            # Save best model
            if is_best:
                best_path = ckpt_dir / "best.pt"
                _save_checkpoint(best_path, **save_kwargs)
                print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")
    finally:
        # Flush and close the event file even if training raises or is
        # interrupted, so already-logged scalars are not lost.
        writer.close()

    print(f"\nTraining complete. Best val loss: {best_val_loss:.6f}")
    print(f"Best checkpoint: {ckpt_dir / 'best.pt'}")


if __name__ == "__main__":
    main()
|