feat: add checkpoint resume and fix train_loss tracking

- Add --resume CLI arg to resume training from a checkpoint
- Restore model, optimizer, scheduler state; continue from saved epoch+1
- Preserve global_step and best_val_loss across resume
- Save run_id in checkpoints for TensorBoard log continuity
- Use logs/run_<timestamp>/ subdirectories to isolate experiment logs
- Fix: replace train_loss in checkpoint dict with global_step to avoid
  KeyError when loading; track global_step through train_one_epoch
- Fix: use global_step (not batch_idx) as TensorBoard x-axis for batch loss
- Fix: print average loss at end of each epoch

Generated by Mistral Vibe (ds-v4-flash).
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
2026-06-04 22:55:31 +08:00
parent 0a504d648e
commit ec143868d0
2 changed files with 61 additions and 21 deletions

View File

@@ -10,9 +10,10 @@ from pathlib import Path
DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset" DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset"
# Velocity normalization stats (computed from training set) # Velocity normalization stats (computed from forward scenes only, 28363 frames)
VELOCITY_MEAN = [0.859184, -0.783945] # [vx, vy] # Yaw-compensated horizontal velocity: [v_right, v_forward]
VELOCITY_STD = [2.244513, 1.088335] # [vx, vy] VELOCITY_MEAN = [-2.902497, 3.837231] # [v_right, v_forward]
VELOCITY_STD = [3.453774, 3.722085] # [v_right, v_forward]
# TRAIN_SCENES = [ # TRAIN_SCENES = [
# "indoor_forward_3", "indoor_forward_5", "indoor_forward_6", # "indoor_forward_3", "indoor_forward_5", "indoor_forward_6",
@@ -38,10 +39,10 @@ VAL_SCENES = [
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy # "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
] ]
TEST_SCENES = [ TEST_SCENES = [
"indoor_forward_7", # Hard 室内 # "indoor_forward_7", # Hard 室内
"outdoor_forward_1", # Easy 室外 # "outdoor_forward_1", # Easy 室外
"outdoor_forward_5" # Hard 室外 # "outdoor_forward_5" # Hard 室外
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
] ]
@@ -75,7 +76,7 @@ class GRUConfig:
class HeadConfig: class HeadConfig:
input_dim: int = 128 # GRU hidden_size input_dim: int = 128 # GRU hidden_size
hidden_dim: int = 64 hidden_dim: int = 64
output_dim: int = 2 # [vx_body, vy_body] output_dim: int = 2 # [v_right, v_forward]
@dataclass @dataclass

View File

@@ -3,6 +3,7 @@ Training loop for VelocityPredictionModel.
Usage: Usage:
python -m src.velocity_prediction.train [--device cuda:0] python -m src.velocity_prediction.train [--device cuda:0]
python -m src.velocity_prediction.train --resume checkpoints/epoch_020_val_0.123456.pt
""" """
import argparse import argparse
@@ -35,8 +36,9 @@ def train_one_epoch(
epoch: int, epoch: int,
writer: SummaryWriter, writer: SummaryWriter,
log_interval: int = 50, log_interval: int = 50,
) -> float: global_step: int = 0,
"""Train for one epoch. Returns average loss.""" ) -> tuple[float, int]:
"""Train for one epoch. Returns (avg_loss, updated_global_step)."""
model.train() model.train()
total_loss = 0.0 total_loss = 0.0
num_batches = 0 num_batches = 0
@@ -59,14 +61,16 @@ def train_one_epoch(
total_loss += loss.item() total_loss += loss.item()
num_batches += 1 num_batches += 1
global_step += 1
if batch_idx % log_interval == 0: if batch_idx % log_interval == 0:
elapsed = time.time() - start_time elapsed = time.time() - start_time
print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s") print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s")
writer.add_scalar("train/loss_batch", loss.item(), batch_idx) writer.add_scalar("train/loss_batch", loss.item(), global_step)
avg_loss = total_loss / max(num_batches, 1) avg_loss = total_loss / max(num_batches, 1)
return avg_loss print(f" Epoch {epoch} | Avg Loss: {avg_loss:.6f}")
return avg_loss, global_step
@torch.no_grad() @torch.no_grad()
@@ -101,6 +105,8 @@ def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda", parser.add_argument("--device", type=str, default="cuda",
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')") help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
parser.add_argument("--resume", type=str, default=None,
help="Path to checkpoint .pt file to resume training from")
args = parser.parse_args() args = parser.parse_args()
set_seed(train_cfg.seed) set_seed(train_cfg.seed)
@@ -143,27 +149,54 @@ def main():
# criterion = nn.SmoothL1Loss() # criterion = nn.SmoothL1Loss()
criterion = nn.MSELoss() criterion = nn.MSELoss()
# Logging
log_dir = Path(train_cfg.log_dir)
log_dir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(log_dir=str(log_dir))
ckpt_dir = Path(train_cfg.checkpoint_dir) ckpt_dir = Path(train_cfg.checkpoint_dir)
ckpt_dir.mkdir(parents=True, exist_ok=True) ckpt_dir.mkdir(parents=True, exist_ok=True)
# ── Resume from checkpoint ────────────────────────────────────
start_epoch = 1
global_step = 0
best_val_loss = float("inf") best_val_loss = float("inf")
run_id = None
if args.resume is not None:
ckpt_path = Path(args.resume)
if not ckpt_path.exists():
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
if "scheduler_state_dict" in ckpt:
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
start_epoch = ckpt.get("epoch", 0) + 1
global_step = ckpt.get("global_step", 0)
best_val_loss = ckpt.get("best_val_loss", float("inf"))
run_id = ckpt.get("run_id", None)
print(f"Resumed from checkpoint: {ckpt_path}")
print(f" Resumed epoch={ckpt.get('epoch', '?')}, global_step={global_step}, "
f"best_val_loss={best_val_loss:.6f}")
else:
print(f"\nStarting training for {train_cfg.epochs} epochs...") print(f"\nStarting training for {train_cfg.epochs} epochs...")
# Logging — run-specific subdirectory for isolation + resume continuity
if run_id is None:
run_id = time.strftime("run_%Y%m%d_%H%M%S")
log_dir = Path(train_cfg.log_dir) / run_id
log_dir.mkdir(parents=True, exist_ok=True)
writer = SummaryWriter(log_dir=str(log_dir))
print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}") print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}")
print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}") print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}")
print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n") print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n")
for epoch in range(1, train_cfg.epochs + 1): for epoch in range(start_epoch, train_cfg.epochs + 1):
epoch_start = time.time() epoch_start = time.time()
train_loss = train_one_epoch( train_loss, global_step = train_one_epoch(
model, train_loader, optimizer, criterion, device, epoch, writer, model, train_loader, optimizer, criterion, device, epoch, writer,
log_interval=train_cfg.log_interval, log_interval=train_cfg.log_interval,
global_step=global_step,
) )
val_loss = validate(model, val_loader, criterion, device) val_loss = validate(model, val_loader, criterion, device)
scheduler.step() scheduler.step()
@@ -187,7 +220,9 @@ def main():
"model_state_dict": model.state_dict(), "model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(), "scheduler_state_dict": scheduler.state_dict(),
"train_loss": train_loss, "global_step": global_step,
"best_val_loss": best_val_loss,
"run_id": run_id,
"val_loss": val_loss, "val_loss": val_loss,
}, ckpt_path) }, ckpt_path)
print(f" Checkpoint saved: {ckpt_path}") print(f" Checkpoint saved: {ckpt_path}")
@@ -200,6 +235,10 @@ def main():
"epoch": epoch, "epoch": epoch,
"model_state_dict": model.state_dict(), "model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(), "optimizer_state_dict": optimizer.state_dict(),
"scheduler_state_dict": scheduler.state_dict(),
"global_step": global_step,
"best_val_loss": best_val_loss,
"run_id": run_id,
"val_loss": val_loss, "val_loss": val_loss,
}, best_path) }, best_path)
print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})") print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")