feat: add checkpoint resume and fix train_loss tracking
- Add --resume CLI arg to resume training from a checkpoint - Restore model, optimizer, scheduler state; continue from saved epoch+1 - Preserve global_step and best_val_loss across resume - Save run_id in checkpoints for TensorBoard log continuity - Use logs/run_<timestamp>/ subdirectories to isolate experiment logs - Fix: replace train_loss in checkpoint dict with global_step to avoid KeyError when loading; track global_step through train_one_epoch - Fix: use global_step (not batch_idx) as TensorBoard x-axis for batch loss - Fix: print average loss at end of each epoch Generated by Mistral Vibe (ds-v4-flash). Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
@@ -10,9 +10,10 @@ from pathlib import Path
|
|||||||
|
|
||||||
DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset"
|
DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset"
|
||||||
|
|
||||||
# Velocity normalization stats (computed from training set)
|
# Velocity normalization stats (computed from forward scenes only, 28363 frames)
|
||||||
VELOCITY_MEAN = [0.859184, -0.783945] # [vx, vy]
|
# Yaw-compensated horizontal velocity: [v_right, v_forward]
|
||||||
VELOCITY_STD = [2.244513, 1.088335] # [vx, vy]
|
VELOCITY_MEAN = [-2.902497, 3.837231] # [v_right, v_forward]
|
||||||
|
VELOCITY_STD = [3.453774, 3.722085] # [v_right, v_forward]
|
||||||
|
|
||||||
# TRAIN_SCENES = [
|
# TRAIN_SCENES = [
|
||||||
# "indoor_forward_3", "indoor_forward_5", "indoor_forward_6",
|
# "indoor_forward_3", "indoor_forward_5", "indoor_forward_6",
|
||||||
@@ -38,10 +39,10 @@ VAL_SCENES = [
|
|||||||
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||||
]
|
]
|
||||||
TEST_SCENES = [
|
TEST_SCENES = [
|
||||||
"indoor_forward_7", # Hard 室内
|
# "indoor_forward_7", # Hard 室内
|
||||||
"outdoor_forward_1", # Easy 室外
|
# "outdoor_forward_1", # Easy 室外
|
||||||
"outdoor_forward_5" # Hard 室外
|
# "outdoor_forward_5" # Hard 室外
|
||||||
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
"indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -75,7 +76,7 @@ class GRUConfig:
|
|||||||
class HeadConfig:
|
class HeadConfig:
|
||||||
input_dim: int = 128 # GRU hidden_size
|
input_dim: int = 128 # GRU hidden_size
|
||||||
hidden_dim: int = 64
|
hidden_dim: int = 64
|
||||||
output_dim: int = 2 # [vx_body, vy_body]
|
output_dim: int = 2 # [v_right, v_forward]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ Training loop for VelocityPredictionModel.
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python -m src.velocity_prediction.train [--device cuda:0]
|
python -m src.velocity_prediction.train [--device cuda:0]
|
||||||
|
python -m src.velocity_prediction.train --resume checkpoints/epoch_020_val_0.123456.pt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
@@ -35,8 +36,9 @@ def train_one_epoch(
|
|||||||
epoch: int,
|
epoch: int,
|
||||||
writer: SummaryWriter,
|
writer: SummaryWriter,
|
||||||
log_interval: int = 50,
|
log_interval: int = 50,
|
||||||
) -> float:
|
global_step: int = 0,
|
||||||
"""Train for one epoch. Returns average loss."""
|
) -> tuple[float, int]:
|
||||||
|
"""Train for one epoch. Returns (avg_loss, updated_global_step)."""
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0.0
|
total_loss = 0.0
|
||||||
num_batches = 0
|
num_batches = 0
|
||||||
@@ -59,14 +61,16 @@ def train_one_epoch(
|
|||||||
|
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
num_batches += 1
|
num_batches += 1
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
if batch_idx % log_interval == 0:
|
if batch_idx % log_interval == 0:
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s")
|
print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s")
|
||||||
writer.add_scalar("train/loss_batch", loss.item(), batch_idx)
|
writer.add_scalar("train/loss_batch", loss.item(), global_step)
|
||||||
|
|
||||||
avg_loss = total_loss / max(num_batches, 1)
|
avg_loss = total_loss / max(num_batches, 1)
|
||||||
return avg_loss
|
print(f" Epoch {epoch} | Avg Loss: {avg_loss:.6f}")
|
||||||
|
return avg_loss, global_step
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -101,6 +105,8 @@ def main():
|
|||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--device", type=str, default="cuda",
|
parser.add_argument("--device", type=str, default="cuda",
|
||||||
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
|
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
|
||||||
|
parser.add_argument("--resume", type=str, default=None,
|
||||||
|
help="Path to checkpoint .pt file to resume training from")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
set_seed(train_cfg.seed)
|
set_seed(train_cfg.seed)
|
||||||
@@ -143,27 +149,54 @@ def main():
|
|||||||
# criterion = nn.SmoothL1Loss()
|
# criterion = nn.SmoothL1Loss()
|
||||||
criterion = nn.MSELoss()
|
criterion = nn.MSELoss()
|
||||||
|
|
||||||
# Logging
|
|
||||||
log_dir = Path(train_cfg.log_dir)
|
|
||||||
log_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
writer = SummaryWriter(log_dir=str(log_dir))
|
|
||||||
|
|
||||||
ckpt_dir = Path(train_cfg.checkpoint_dir)
|
ckpt_dir = Path(train_cfg.checkpoint_dir)
|
||||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ── Resume from checkpoint ────────────────────────────────────
|
||||||
|
start_epoch = 1
|
||||||
|
global_step = 0
|
||||||
best_val_loss = float("inf")
|
best_val_loss = float("inf")
|
||||||
|
run_id = None
|
||||||
|
|
||||||
|
if args.resume is not None:
|
||||||
|
ckpt_path = Path(args.resume)
|
||||||
|
if not ckpt_path.exists():
|
||||||
|
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
||||||
|
ckpt = torch.load(ckpt_path, map_location="cpu")
|
||||||
|
|
||||||
|
model.load_state_dict(ckpt["model_state_dict"])
|
||||||
|
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
|
||||||
|
if "scheduler_state_dict" in ckpt:
|
||||||
|
scheduler.load_state_dict(ckpt["scheduler_state_dict"])
|
||||||
|
start_epoch = ckpt.get("epoch", 0) + 1
|
||||||
|
global_step = ckpt.get("global_step", 0)
|
||||||
|
best_val_loss = ckpt.get("best_val_loss", float("inf"))
|
||||||
|
run_id = ckpt.get("run_id", None)
|
||||||
|
|
||||||
|
print(f"Resumed from checkpoint: {ckpt_path}")
|
||||||
|
print(f" Resumed epoch={ckpt.get('epoch', '?')}, global_step={global_step}, "
|
||||||
|
f"best_val_loss={best_val_loss:.6f}")
|
||||||
|
else:
|
||||||
|
print(f"\nStarting training for {train_cfg.epochs} epochs...")
|
||||||
|
|
||||||
|
# Logging — run-specific subdirectory for isolation + resume continuity
|
||||||
|
if run_id is None:
|
||||||
|
run_id = time.strftime("run_%Y%m%d_%H%M%S")
|
||||||
|
log_dir = Path(train_cfg.log_dir) / run_id
|
||||||
|
log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
writer = SummaryWriter(log_dir=str(log_dir))
|
||||||
|
|
||||||
print(f"\nStarting training for {train_cfg.epochs} epochs...")
|
|
||||||
print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}")
|
print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}")
|
||||||
print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}")
|
print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}")
|
||||||
print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n")
|
print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n")
|
||||||
|
|
||||||
for epoch in range(1, train_cfg.epochs + 1):
|
for epoch in range(start_epoch, train_cfg.epochs + 1):
|
||||||
epoch_start = time.time()
|
epoch_start = time.time()
|
||||||
|
|
||||||
train_loss = train_one_epoch(
|
train_loss, global_step = train_one_epoch(
|
||||||
model, train_loader, optimizer, criterion, device, epoch, writer,
|
model, train_loader, optimizer, criterion, device, epoch, writer,
|
||||||
log_interval=train_cfg.log_interval,
|
log_interval=train_cfg.log_interval,
|
||||||
|
global_step=global_step,
|
||||||
)
|
)
|
||||||
val_loss = validate(model, val_loader, criterion, device)
|
val_loss = validate(model, val_loader, criterion, device)
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
@@ -187,7 +220,9 @@ def main():
|
|||||||
"model_state_dict": model.state_dict(),
|
"model_state_dict": model.state_dict(),
|
||||||
"optimizer_state_dict": optimizer.state_dict(),
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
"scheduler_state_dict": scheduler.state_dict(),
|
"scheduler_state_dict": scheduler.state_dict(),
|
||||||
"train_loss": train_loss,
|
"global_step": global_step,
|
||||||
|
"best_val_loss": best_val_loss,
|
||||||
|
"run_id": run_id,
|
||||||
"val_loss": val_loss,
|
"val_loss": val_loss,
|
||||||
}, ckpt_path)
|
}, ckpt_path)
|
||||||
print(f" Checkpoint saved: {ckpt_path}")
|
print(f" Checkpoint saved: {ckpt_path}")
|
||||||
@@ -200,6 +235,10 @@ def main():
|
|||||||
"epoch": epoch,
|
"epoch": epoch,
|
||||||
"model_state_dict": model.state_dict(),
|
"model_state_dict": model.state_dict(),
|
||||||
"optimizer_state_dict": optimizer.state_dict(),
|
"optimizer_state_dict": optimizer.state_dict(),
|
||||||
|
"scheduler_state_dict": scheduler.state_dict(),
|
||||||
|
"global_step": global_step,
|
||||||
|
"best_val_loss": best_val_loss,
|
||||||
|
"run_id": run_id,
|
||||||
"val_loss": val_loss,
|
"val_loss": val_loss,
|
||||||
}, best_path)
|
}, best_path)
|
||||||
print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")
|
print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")
|
||||||
|
|||||||
Reference in New Issue
Block a user