""" benchmark.py — Unified evaluation entry point. Two modes: 1. Single-model eval: python -m benchmark.benchmark --checkpoint 2. Compare mode: python -m benchmark.benchmark --compare Results are saved to benchmark/results//. """ import argparse import sys from pathlib import Path from typing import List, Optional import torch from benchmark.config import eval_cfg, TEST_SCENE_GROUPS from benchmark.evaluate import run_full_evaluation, save_results # Project root (two levels up from benchmark/benchmark.py) PROJECT_ROOT = Path(__file__).resolve().parents[1] RESULTS_DIR = PROJECT_ROOT / "benchmark" / "results" def load_checkpoint( checkpoint_path: Path, device: torch.device, ) -> torch.nn.Module: """Load a VelocityPredictionModel from a checkpoint file.""" from src.velocity_prediction.model import VelocityPredictionModel model = VelocityPredictionModel() ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) state_dict = ckpt.get("model_state_dict", ckpt) model.load_state_dict(state_dict) model.to(device) model.eval() return model def run_single_eval( checkpoint_path: Path, output_dir: Optional[Path] = None, device: torch.device = None, seq_len: Optional[int] = None, batch_size: Optional[int] = None, num_workers: Optional[int] = None, save_plots: bool = True, ) -> Path: """Evaluate a single checkpoint and save results. Returns the output directory path. """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") seq_len = seq_len or eval_cfg.seq_len batch_size = batch_size or eval_cfg.batch_size num_workers = num_workers or eval_cfg.num_workers # Derive experiment name from checkpoint filename (strip extension) exp_name = checkpoint_path.stem # e.g. "best" or "epoch_050_val_1.827390" if output_dir is None: output_dir = RESULTS_DIR / exp_name print(f"{'=' * 60}") print(f"Benchmark — Single Model Evaluation") print(f"{'=' * 60}") print(f" Checkpoint: {checkpoint_path}") print(f" Device: {device}") print(f" Seq len: {seq_len}") print(f" Batch size: {batch_size}") print(f" Output: {output_dir}") print() model = load_checkpoint(checkpoint_path, device) results = run_full_evaluation( model=model, device=device, seq_len=seq_len, batch_size=batch_size, num_workers=num_workers, event_threshold=eval_cfg.event_threshold, event_use_log=eval_cfg.event_use_log, scene_groups=TEST_SCENE_GROUPS, ) save_results(results, save_dir=output_dir, checkpoint_name=exp_name) return output_dir def run_compare( checkpoint_dir: Path, output_dir: Optional[Path] = None, device: torch.device = None, seq_len: Optional[int] = None, batch_size: Optional[int] = None, num_workers: Optional[int] = None, pattern: str = "*.pt", ) -> Path: """Evaluate all checkpoints in a directory and produce a comparison table. Returns the output directory path. """ if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") seq_len = seq_len or eval_cfg.seq_len batch_size = batch_size or eval_cfg.batch_size num_workers = num_workers or eval_cfg.num_workers checkpoint_paths = sorted(Path(checkpoint_dir).glob(pattern)) if not checkpoint_paths: print(f"No checkpoints found matching '{pattern}' in {checkpoint_dir}") sys.exit(1) if output_dir is None: output_dir = RESULTS_DIR / "compare" print(f"{'=' * 60}") print(f"Benchmark — Compare Mode ({len(checkpoint_paths)} checkpoints)") print(f"{'=' * 60}") print(f" Checkpoint dir: {checkpoint_dir}") print(f" Device: {device}") print(f" Seq len: {seq_len}") print(f" Batch size: {batch_size}") print(f" Output: {output_dir}") print() all_global_metrics = [] for ckpt_path in checkpoint_paths: exp_name = ckpt_path.stem print(f"\n── Evaluating {exp_name} ──") model = load_checkpoint(ckpt_path, device) results = run_full_evaluation( model=model, device=device, seq_len=seq_len, batch_size=batch_size, num_workers=num_workers, event_threshold=eval_cfg.event_threshold, event_use_log=eval_cfg.event_use_log, scene_groups=TEST_SCENE_GROUPS, ) # Save individual results ckpt_output_dir = output_dir / exp_name save_results(results, save_dir=ckpt_output_dir, checkpoint_name=exp_name) all_global_metrics.append((exp_name, results["global"])) # ── Comparison table ── print(f"\n\n{'=' * 60}") print("Comparison Summary") print(f"{'=' * 60}") header = f"{'Checkpoint':<30} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} {'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8}" sep = "-" * len(header) print(header) print(sep) rows = [] for name, metrics in all_global_metrics: row = ( f"{name:<30} " f"{metrics.get('rmse_vx', 0):>10.4f} " f"{metrics.get('rmse_vy', 0):>10.4f} " f"{metrics.get('rmse_xy', 0):>10.4f} " f"{metrics.get('mae_vx', 0):>10.4f} " f"{metrics.get('mae_vy', 0):>10.4f} " f"{metrics.get('r2_vx', 0):>8.4f} " f"{metrics.get('r2_vy', 0):>8.4f}" ) print(row) rows.append(row) # Save comparison CSV import csv csv_path = output_dir / "comparison.csv" output_dir.mkdir(parents=True, exist_ok=True) fieldnames = ["checkpoint", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy", "mae_xy", "r2_vx", "r2_vy", "count"] with open(csv_path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() for name, metrics in all_global_metrics: row = {"checkpoint": name, **metrics} writer.writerow(row) print(f"\nComparison CSV: {csv_path}") # Save comparison text txt_path = output_dir / "comparison.txt" with open(txt_path, "w") as f: f.write("Benchmark Comparison\n") f.write(f"{'=' * 60}\n\n") f.write(header + "\n") f.write(sep + "\n") for row in rows: f.write(row + "\n") print(f"Comparison TXT: {txt_path}") return output_dir def main(): parser = argparse.ArgumentParser( description="Unified benchmark for velocity prediction models.", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=( "Examples:\n" " # Single model evaluation\n" " python -m benchmark.benchmark --checkpoint checkpoints/best.pt\n\n" " # Compare all checkpoints in a directory\n" " python -m benchmark.benchmark --compare checkpoints/\n\n" " # Custom output directory\n" " python -m benchmark.benchmark --checkpoint checkpoints/best.pt --output my_results/\n" ), ) # Mutually exclusive mode selection mode = parser.add_mutually_exclusive_group(required=True) mode.add_argument( "--checkpoint", type=str, default=None, help="Path to a single checkpoint .pt file for single-model evaluation.", ) mode.add_argument( "--compare", type=str, default=None, help="Directory containing multiple .pt checkpoints for comparison.", ) # Optional overrides parser.add_argument( "--output", type=str, default=None, help="Output directory for results (default: benchmark/results//).", ) parser.add_argument( "--device", type=str, default=None, help="Device override, e.g. 'cuda:0' or 'cpu' (default: auto-detect).", ) parser.add_argument( "--seq-len", type=int, default=None, help=f"Sequence length override (default: {eval_cfg.seq_len}).", ) parser.add_argument( "--batch-size", type=int, default=None, help=f"Batch size override (default: {eval_cfg.batch_size}).", ) parser.add_argument( "--num-workers", type=int, default=None, help=f"DataLoader workers override (default: {eval_cfg.num_workers}).", ) parser.add_argument( "--pattern", type=str, default="*.pt", help="Glob pattern for --compare mode (default: '*.pt').", ) parser.add_argument( "--no-plots", action="store_true", help="Skip generating per-scene plots.", ) args = parser.parse_args() # Resolve device device = None if args.device is not None: device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu") # Resolve output directory output_dir = Path(args.output) if args.output else None if args.checkpoint: ckpt_path = Path(args.checkpoint) if not ckpt_path.exists(): print(f"Error: checkpoint not found: {ckpt_path}") sys.exit(1) run_single_eval( checkpoint_path=ckpt_path, output_dir=output_dir, device=device, seq_len=args.seq_len, batch_size=args.batch_size, num_workers=args.num_workers, save_plots=not args.no_plots, ) elif args.compare: ckpt_dir = Path(args.compare) if not ckpt_dir.is_dir(): print(f"Error: checkpoint directory not found: {ckpt_dir}") sys.exit(1) run_compare( checkpoint_dir=ckpt_dir, output_dir=output_dir, device=device, seq_len=args.seq_len, batch_size=args.batch_size, num_workers=args.num_workers, pattern=args.pattern, ) if __name__ == "__main__": main()