initial commit

This commit is contained in:
2026-05-29 18:49:01 +08:00
commit 9f0321eff8
21 changed files with 3143 additions and 0 deletions

View File

@@ -0,0 +1,213 @@
"""
Training loop for VelocityPredictionModel.
Usage:
python -m src.velocity_prediction.train [--device cuda:0]
"""
import argparse
import os
import time
import numpy as np
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from src.velocity_prediction.config import train_cfg, model_cfg
from src.velocity_prediction.model import VelocityPredictionModel, count_parameters
from src.velocity_prediction.dataset import create_train_loader, create_val_loader
def set_seed(seed: int):
    """Seed every random number generator this script can reach.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU generator plus all CUDA devices) with the same value.
    This function does not touch cuDNN determinism flags.

    Args:
        seed: Value passed to each generator's seeding function.
    """
    # Function-scope import: the stdlib ``random`` module is not imported at
    # file level, and code outside this file may draw from it.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Documented as a silent no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
def train_one_epoch(
    model: nn.Module,
    loader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epoch: int,
    writer: SummaryWriter,
    log_interval: int = 50,
) -> float:
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Each batch is a mapping with the keys ``"events"``, ``"tilt"`` and
    ``"v_body_target"``. The model is called as ``model(events, tilt)`` and its
    output is compared against the last sequence step of the target.

    Args:
        model: Network to optimise; switched to training mode here.
        loader: Iterable of batches. If it supports ``len()``, that length is
            used to give batch-level log points a step that grows across epochs.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss applied to ``(prediction, target[:, -1, :])``.
        device: Device the batch tensors are moved to.
        epoch: Epoch number (the caller in this file counts from 1); used for
            console output and for the TensorBoard step offset.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        log_interval: Log every ``log_interval`` batches. Values ``<= 0``
            disable batch-level logging.

    Returns:
        Sum of the per-batch losses divided by the number of batches, or
        ``0.0`` when ``loader`` yields nothing.
    """
    model.train()

    # Without an offset every epoch would log at steps 0..N-1 and the
    # "train/loss_batch" curves of different epochs would overlap.
    try:
        steps_per_epoch = len(loader)
    except TypeError:
        # Loader has no length (e.g. a generator): fall back to plain batch_idx.
        steps_per_epoch = 0
    step_offset = max(epoch - 1, 0) * steps_per_epoch

    total_loss = 0.0
    num_batches = 0
    start_time = time.time()

    for batch_idx, batch in enumerate(loader):
        # Shape notes carried over from the original comments.
        events = batch["events"].to(device)         # (B, S, 1, H, W)
        tilt = batch["tilt"].to(device)             # (B, S, 3)
        target = batch["v_body_target"].to(device)  # (B, S, 2)

        # Only the final step of each sequence is supervised.
        pred = model(events, tilt)                  # (B, 2)
        target_last = target[:, -1, :]              # (B, 2)
        loss = criterion(pred, target_last)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once; every .item() call copies from the device.
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1

        if log_interval > 0 and batch_idx % log_interval == 0:
            elapsed = time.time() - start_time
            print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss_value:.6f} | {elapsed:.1f}s")
            writer.add_scalar("train/loss_batch", loss_value, step_offset + batch_idx)

    return total_loss / max(num_batches, 1)
@torch.no_grad()
def validate(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """Evaluate ``model`` on ``loader`` and return the mean of the batch losses.

    The model is put into eval mode and gradients are disabled by the
    decorator. Each batch must provide ``"events"``, ``"tilt"`` and
    ``"v_body_target"``; the prediction is scored against the last sequence
    step of the target. Returns ``0.0`` if the loader yields no batches.
    """
    model.eval()

    batch_keys = ("events", "tilt", "v_body_target")
    per_batch = []
    for sample in loader:
        # Move the three tensors to the device in the same order as they are used.
        moved = {key: sample[key].to(device) for key in batch_keys}
        prediction = model(moved["events"], moved["tilt"])
        final_step = moved["v_body_target"][:, -1, :]
        per_batch.append(criterion(prediction, final_step).item())

    # max(..., 1) keeps the division defined for an empty loader.
    return sum(per_batch, 0.0) / max(len(per_batch), 1)
def _checkpoint_state(epoch, model, optimizer, scheduler, train_loss, val_loss):
    """Build the dict written to every checkpoint file (periodic and best)."""
    return {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "train_loss": train_loss,
        "val_loss": val_loss,
    }


def main():
    """Command-line entry point: train the model and write logs and checkpoints.

    Reads ``--device`` from the command line; all other settings come from
    ``train_cfg``. For each epoch it trains, validates, steps the LR
    scheduler, logs scalars to TensorBoard, saves a periodic checkpoint every
    ``train_cfg.save_interval`` epochs and overwrites ``best.pt`` whenever the
    validation loss improves.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda",
                        help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
    args = parser.parse_args()

    set_seed(train_cfg.seed)

    # Any request that is not a CUDA device, or CUDA being unavailable,
    # results in running on the CPU.
    use_cuda = torch.cuda.is_available() and "cuda" in args.device
    device = torch.device(args.device if use_cuda else "cpu")
    print(f"Device: {device}")

    # Create model
    model = VelocityPredictionModel()
    model.to(device)
    total_params = count_parameters(model)
    print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")

    # Data loaders: train and validation are built with identical settings.
    loader_kwargs = dict(
        seq_len=train_cfg.seq_len,
        batch_size=train_cfg.batch_size,
        num_workers=train_cfg.num_workers,
        event_threshold=train_cfg.event_threshold,
        event_use_log=train_cfg.event_use_log,
    )
    train_loader = create_train_loader(**loader_kwargs)
    val_loader = create_val_loader(**loader_kwargs)

    # Optimizer & scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_cfg.lr,
        weight_decay=train_cfg.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=train_cfg.lr_scheduler_step,
        gamma=train_cfg.lr_scheduler_gamma,
    )
    criterion = nn.MSELoss()

    # Logging
    log_dir = Path(train_cfg.log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    ckpt_dir = Path(train_cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    best_path = ckpt_dir / "best.pt"

    best_val_loss = float("inf")
    best_epoch = None  # stays None if no epoch ever improves on +inf (e.g. NaN losses)

    print(f"\nStarting training for {train_cfg.epochs} epochs...")
    print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}")
    print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}")
    print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n")

    # try/finally so buffered TensorBoard events are flushed even when
    # training is interrupted or raises.
    try:
        for epoch in range(1, train_cfg.epochs + 1):
            epoch_start = time.time()

            # Read the LR before scheduler.step() so the logged value is the
            # one this epoch actually trained with.
            current_lr = optimizer.param_groups[0]["lr"]

            train_loss = train_one_epoch(
                model, train_loader, optimizer, criterion, device, epoch, writer,
                log_interval=train_cfg.log_interval,
            )
            val_loss = validate(model, val_loader, criterion, device)
            scheduler.step()
            epoch_time = time.time() - epoch_start

            print(f"Epoch {epoch:3d}/{train_cfg.epochs} | "
                  f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | "
                  f"LR: {current_lr:.2e} | Time: {epoch_time:.1f}s")
            writer.add_scalar("train/loss_epoch", train_loss, epoch)
            writer.add_scalar("val/loss", val_loss, epoch)
            writer.add_scalar("lr", current_lr, epoch)

            # Save checkpoint
            if epoch % train_cfg.save_interval == 0:
                ckpt_path = ckpt_dir / f"epoch_{epoch:03d}_val_{val_loss:.6f}.pt"
                torch.save(
                    _checkpoint_state(epoch, model, optimizer, scheduler, train_loss, val_loss),
                    ckpt_path,
                )
                print(f" Checkpoint saved: {ckpt_path}")

            # Save best model; same payload as periodic checkpoints so that
            # best.pt also carries the scheduler state.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
                torch.save(
                    _checkpoint_state(epoch, model, optimizer, scheduler, train_loss, val_loss),
                    best_path,
                )
                print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")
    finally:
        writer.close()

    print(f"\nTraining complete. Best val loss: {best_val_loss:.6f}")
    if best_epoch is None:
        # Do not point the user at a file this run never wrote.
        print("No best checkpoint was saved during this run.")
    else:
        print(f"Best checkpoint: {best_path}")
# Run training only when executed as a script or via `python -m`, not on import.
if __name__ == "__main__":
    main()