diff --git a/src/velocity_prediction/config.py b/src/velocity_prediction/config.py index f994ddf..8e1c128 100644 --- a/src/velocity_prediction/config.py +++ b/src/velocity_prediction/config.py @@ -10,9 +10,10 @@ from pathlib import Path DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset" -# Velocity normalization stats (computed from training set) -VELOCITY_MEAN = [0.859184, -0.783945] # [vx, vy] -VELOCITY_STD = [2.244513, 1.088335] # [vx, vy] +# Velocity normalization stats (computed from forward scenes only, 28363 frames) +# Yaw-compensated horizontal velocity: [v_right, v_forward] +VELOCITY_MEAN = [-2.902497, 3.837231] # [v_right, v_forward] +VELOCITY_STD = [3.453774, 3.722085] # [v_right, v_forward] # TRAIN_SCENES = [ # "indoor_forward_3", "indoor_forward_5", "indoor_forward_6", @@ -38,10 +39,10 @@ VAL_SCENES = [ # "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy ] TEST_SCENES = [ - "indoor_forward_7", # Hard 室内 - "outdoor_forward_1", # Easy 室外 - "outdoor_forward_5" # Hard 室外 - # "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy + # "indoor_forward_7", # Hard 室内 + # "outdoor_forward_1", # Easy 室外 + # "outdoor_forward_5" # Hard 室外 + "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy ] @@ -75,7 +76,7 @@ class GRUConfig: class HeadConfig: input_dim: int = 128 # GRU hidden_size hidden_dim: int = 64 - output_dim: int = 2 # [vx_body, vy_body] + output_dim: int = 2 # [v_right, v_forward] @dataclass diff --git a/src/velocity_prediction/train.py b/src/velocity_prediction/train.py index 6c45aad..34b88cb 100644 --- a/src/velocity_prediction/train.py +++ b/src/velocity_prediction/train.py @@ -3,6 +3,7 @@ Training loop for VelocityPredictionModel. Usage: python -m src.velocity_prediction.train [--device cuda:0] + python -m src.velocity_prediction.train --resume checkpoints/epoch_020_val_0.123456.pt """ import argparse @@ -35,8 +36,9 @@ def train_one_epoch( epoch: int, writer: SummaryWriter, log_interval: int = 50, -) -> float: - """Train for one epoch. Returns average loss.""" + global_step: int = 0, +) -> tuple[float, int]: + """Train for one epoch. Returns (avg_loss, updated_global_step).""" model.train() total_loss = 0.0 num_batches = 0 @@ -59,14 +61,16 @@ def train_one_epoch( total_loss += loss.item() num_batches += 1 + global_step += 1 if batch_idx % log_interval == 0: elapsed = time.time() - start_time print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s") - writer.add_scalar("train/loss_batch", loss.item(), batch_idx) + writer.add_scalar("train/loss_batch", loss.item(), global_step) avg_loss = total_loss / max(num_batches, 1) - return avg_loss + print(f" Epoch {epoch} | Avg Loss: {avg_loss:.6f}") + return avg_loss, global_step @torch.no_grad() @@ -101,6 +105,8 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cuda", help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')") + parser.add_argument("--resume", type=str, default=None, + help="Path to checkpoint .pt file to resume training from") args = parser.parse_args() set_seed(train_cfg.seed) @@ -143,27 +149,54 @@ def main(): # criterion = nn.SmoothL1Loss() criterion = nn.MSELoss() - # Logging - log_dir = Path(train_cfg.log_dir) - log_dir.mkdir(parents=True, exist_ok=True) - writer = SummaryWriter(log_dir=str(log_dir)) - ckpt_dir = Path(train_cfg.checkpoint_dir) ckpt_dir.mkdir(parents=True, exist_ok=True) + # ── Resume from checkpoint ──────────────────────────────────── + start_epoch = 1 + global_step = 0 best_val_loss = float("inf") + run_id = None + + if args.resume is not None: + ckpt_path = Path(args.resume) + if not ckpt_path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}") + ckpt = torch.load(ckpt_path, map_location="cpu") + + model.load_state_dict(ckpt["model_state_dict"]) + optimizer.load_state_dict(ckpt["optimizer_state_dict"]) + if "scheduler_state_dict" in ckpt: + scheduler.load_state_dict(ckpt["scheduler_state_dict"]) + start_epoch = ckpt.get("epoch", 0) + 1 + global_step = ckpt.get("global_step", 0) + best_val_loss = ckpt.get("best_val_loss", float("inf")) + run_id = ckpt.get("run_id", None) + + print(f"Resumed from checkpoint: {ckpt_path}") + print(f" Resumed epoch={ckpt.get('epoch', '?')}, global_step={global_step}, " + f"best_val_loss={best_val_loss:.6f}") + else: + print(f"\nStarting training for {train_cfg.epochs} epochs...") + + # Logging — run-specific subdirectory for isolation + resume continuity + if run_id is None: + run_id = time.strftime("run_%Y%m%d_%H%M%S") + log_dir = Path(train_cfg.log_dir) / run_id + log_dir.mkdir(parents=True, exist_ok=True) + writer = SummaryWriter(log_dir=str(log_dir)) - print(f"\nStarting training for {train_cfg.epochs} epochs...") print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}") print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}") print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n") - for epoch in range(1, train_cfg.epochs + 1): + for epoch in range(start_epoch, train_cfg.epochs + 1): epoch_start = time.time() - train_loss = train_one_epoch( + train_loss, global_step = train_one_epoch( model, train_loader, optimizer, criterion, device, epoch, writer, log_interval=train_cfg.log_interval, + global_step=global_step, ) val_loss = validate(model, val_loader, criterion, device) scheduler.step() @@ -187,7 +220,9 @@ def main(): "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "scheduler_state_dict": scheduler.state_dict(), - "train_loss": train_loss, + "global_step": global_step, + "best_val_loss": best_val_loss, + "run_id": run_id, "val_loss": val_loss, }, ckpt_path) print(f" Checkpoint saved: {ckpt_path}") @@ -200,6 +235,10 @@ def main(): "epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "global_step": global_step, + "best_val_loss": best_val_loss, + "run_id": run_id, "val_loss": val_loss, }, best_path) print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")