initial commit
This commit is contained in:
309
benchmark/benchmark.py
Normal file
309
benchmark/benchmark.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
benchmark.py — Unified evaluation entry point.
|
||||
|
||||
Two modes:
|
||||
1. Single-model eval: python -m benchmark.benchmark --checkpoint <path>
|
||||
2. Compare mode: python -m benchmark.benchmark --compare <checkpoint_dir>
|
||||
|
||||
Results are saved to benchmark/results/<exp_name>/ (single-model mode) or benchmark/results/compare/ (compare mode) unless --output is given.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from benchmark.config import eval_cfg, TEST_SCENE_GROUPS
|
||||
from benchmark.evaluate import run_full_evaluation, save_results
|
||||
|
||||
# Repository root: the parent of the directory holding this file
# (this file lives at benchmark/benchmark.py).
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Default base directory under which evaluation outputs are written.
RESULTS_DIR = PROJECT_ROOT.joinpath("benchmark", "results")
|
||||
|
||||
|
||||
def load_checkpoint(
    checkpoint_path: Path,
    device: torch.device,
) -> torch.nn.Module:
    """Build a ``VelocityPredictionModel`` and restore its weights from disk.

    Args:
        checkpoint_path: File handed to ``torch.load``.
        device: Device the stored tensors are mapped to while loading, and
            the device the returned model is moved to.

    Returns:
        The model with the loaded weights, placed on ``device`` and switched
        to evaluation mode.
    """
    # Imported at call time rather than at module import time.
    from src.velocity_prediction.model import VelocityPredictionModel

    network = VelocityPredictionModel()

    # NOTE(review): weights_only=False allows torch.load to unpickle arbitrary
    # Python objects -- only use this on checkpoints from a trusted source.
    payload = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # The weights may be nested under "model_state_dict"; when that key is
    # absent the loaded object itself is treated as the state dict.
    weights = payload.get("model_state_dict", payload)
    network.load_state_dict(weights)

    network.to(device)
    network.eval()
    return network
|
||||
|
||||
|
||||
def run_single_eval(
    checkpoint_path: Path,
    output_dir: Optional[Path] = None,
    device: Optional[torch.device] = None,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    save_plots: bool = True,
) -> Path:
    """Evaluate a single checkpoint and save its results.

    Args:
        checkpoint_path: Checkpoint file to evaluate. Its stem (file name
            without extension) is used as the experiment name.
        output_dir: Where results are saved. Defaults to
            ``RESULTS_DIR / <checkpoint stem>``.
        device: Device to evaluate on. Defaults to CUDA when available,
            otherwise CPU.
        seq_len: Sequence length override; falls back to ``eval_cfg.seq_len``.
        batch_size: Batch size override; falls back to
            ``eval_cfg.batch_size``.
        num_workers: Worker-count override; falls back to
            ``eval_cfg.num_workers`` only when ``None`` so that an explicit
            ``0`` is honoured.
        save_plots: Accepted for interface compatibility but not used by this
            function. NOTE(review): nothing here forwards it to
            ``save_results``; confirm whether plot generation can be skipped
            there and wire it through.

    Returns:
        The directory the results were saved to.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = seq_len or eval_cfg.seq_len
    batch_size = batch_size or eval_cfg.batch_size
    # Use an explicit None check here: 0 is a legitimate worker count (load
    # data in the main process) and ``0 or default`` would silently discard it.
    if num_workers is None:
        num_workers = eval_cfg.num_workers

    # Derive experiment name from checkpoint filename (strip extension),
    # e.g. "best" or "epoch_050_val_1.827390".
    exp_name = checkpoint_path.stem

    if output_dir is None:
        output_dir = RESULTS_DIR / exp_name

    banner = "=" * 60
    print(banner)
    print("Benchmark — Single Model Evaluation")
    print(banner)
    print(f" Checkpoint: {checkpoint_path}")
    print(f" Device: {device}")
    print(f" Seq len: {seq_len}")
    print(f" Batch size: {batch_size}")
    print(f" Output: {output_dir}")
    print()

    model = load_checkpoint(checkpoint_path, device)

    results = run_full_evaluation(
        model=model,
        device=device,
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=num_workers,
        event_threshold=eval_cfg.event_threshold,
        event_use_log=eval_cfg.event_use_log,
        scene_groups=TEST_SCENE_GROUPS,
    )

    save_results(results, save_dir=output_dir, checkpoint_name=exp_name)

    return output_dir
|
||||
|
||||
|
||||
def run_compare(
    checkpoint_dir: Path,
    output_dir: Optional[Path] = None,
    device: Optional[torch.device] = None,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    pattern: str = "*.pt",
) -> Path:
    """Evaluate all checkpoints in a directory and produce a comparison table.

    Each checkpoint matching ``pattern`` is evaluated in sorted path order and
    its results are saved to ``output_dir / <checkpoint stem>``. The
    ``"global"`` entry of every result is then summarised on stdout and in
    ``comparison.csv`` and ``comparison.txt`` inside ``output_dir``.

    Args:
        checkpoint_dir: Directory searched (non-recursively) with ``pattern``.
        output_dir: Where results are written. Defaults to
            ``RESULTS_DIR / "compare"``.
        device: Device to evaluate on. Defaults to CUDA when available,
            otherwise CPU.
        seq_len: Sequence length override; falls back to ``eval_cfg.seq_len``.
        batch_size: Batch size override; falls back to
            ``eval_cfg.batch_size``.
        num_workers: Worker-count override; falls back to
            ``eval_cfg.num_workers`` only when ``None`` so that an explicit
            ``0`` is honoured.
        pattern: Glob pattern selecting checkpoint files.

    Returns:
        The output directory path.

    Raises:
        SystemExit: With status 1 when no file matches ``pattern``.
    """
    import csv

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seq_len = seq_len or eval_cfg.seq_len
    batch_size = batch_size or eval_cfg.batch_size
    # Use an explicit None check here: 0 is a legitimate worker count (load
    # data in the main process) and ``0 or default`` would silently discard it.
    if num_workers is None:
        num_workers = eval_cfg.num_workers

    checkpoint_paths = sorted(Path(checkpoint_dir).glob(pattern))
    if not checkpoint_paths:
        print(f"No checkpoints found matching '{pattern}' in {checkpoint_dir}")
        sys.exit(1)

    output_dir = RESULTS_DIR / "compare" if output_dir is None else Path(output_dir)

    banner = "=" * 60
    print(banner)
    print(f"Benchmark — Compare Mode ({len(checkpoint_paths)} checkpoints)")
    print(banner)
    print(f" Checkpoint dir: {checkpoint_dir}")
    print(f" Device: {device}")
    print(f" Seq len: {seq_len}")
    print(f" Batch size: {batch_size}")
    print(f" Output: {output_dir}")
    print()

    # (experiment name, global metrics) per checkpoint, in evaluation order.
    all_global_metrics = []

    for ckpt_path in checkpoint_paths:
        exp_name = ckpt_path.stem
        print(f"\n── Evaluating {exp_name} ──")

        model = load_checkpoint(ckpt_path, device)

        results = run_full_evaluation(
            model=model,
            device=device,
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=num_workers,
            event_threshold=eval_cfg.event_threshold,
            event_use_log=eval_cfg.event_use_log,
            scene_groups=TEST_SCENE_GROUPS,
        )

        # Save individual results
        ckpt_output_dir = output_dir / exp_name
        save_results(results, save_dir=ckpt_output_dir, checkpoint_name=exp_name)

        all_global_metrics.append((exp_name, results["global"]))

    # ── Comparison table ──
    print(f"\n\n{banner}")
    print("Comparison Summary")
    print(banner)

    # (metrics key, column title, column width) for the printed/text table.
    # Header and rows are both generated from this so they cannot drift apart.
    table_columns = [
        ("rmse_vx", "RMSE vx", 10),
        ("rmse_vy", "RMSE vy", 10),
        ("rmse_xy", "RMSE xy", 10),
        ("mae_vx", "MAE vx", 10),
        ("mae_vy", "MAE vy", 10),
        ("r2_vx", "R² vx", 8),
        ("r2_vy", "R² vy", 8),
    ]
    header = " ".join(
        [f"{'Checkpoint':<30}"]
        + [f"{title:>{width}}" for _, title, width in table_columns]
    )
    sep = "-" * len(header)
    print(header)
    print(sep)

    rows = []
    for name, metrics in all_global_metrics:
        # Metrics missing from a result are shown as 0.
        row = " ".join(
            [f"{name:<30}"]
            + [f"{metrics.get(key, 0):>{width}.4f}" for key, _, width in table_columns]
        )
        print(row)
        rows.append(row)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Save comparison CSV
    csv_path = output_dir / "comparison.csv"
    fieldnames = ["checkpoint", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy",
                  "mae_xy", "r2_vx", "r2_vy", "count"]
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        # extrasaction="ignore": a metric outside ``fieldnames`` must not
        # abort the summary with ValueError after every evaluation has run.
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        # "checkpoint" goes last so a same-named metrics key cannot shadow it.
        writer.writerows(
            {**metrics, "checkpoint": name} for name, metrics in all_global_metrics
        )
    print(f"\nComparison CSV: {csv_path}")

    # Save comparison text. The header contains "²", so the encoding is pinned
    # to UTF-8 instead of relying on the locale default.
    txt_path = output_dir / "comparison.txt"
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write("Benchmark Comparison\n")
        f.write(f"{banner}\n\n")
        f.write(header + "\n")
        f.write(sep + "\n")
        f.writelines(row + "\n" for row in rows)
    print(f"Comparison TXT: {txt_path}")

    return output_dir
|
||||
|
||||
|
||||
def _resolve_device(requested: Optional[str]) -> Optional[torch.device]:
    """Convert the ``--device`` string to a ``torch.device``; ``None`` means auto-detect."""
    if requested is None:
        return None
    device = torch.device(requested)  # RuntimeError on a malformed string
    if device.type == "cuda" and not torch.cuda.is_available():
        # Keep the existing CPU fallback, but no longer do it silently.
        print(
            f"Warning: device '{requested}' requested but CUDA is not available; "
            "falling back to CPU.",
            file=sys.stderr,
        )
        return torch.device("cpu")
    return device


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse arguments and run the selected mode.

    Exactly one of ``--checkpoint`` (single-model evaluation) or ``--compare``
    (evaluate every checkpoint in a directory) is required. ``--pattern`` is
    only forwarded in compare mode and ``--no-plots`` only in single-model
    mode.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``.

    Raises:
        SystemExit: On argument errors (including an invalid ``--device``),
            or with status 1 when the checkpoint file / directory is missing.
    """
    parser = argparse.ArgumentParser(
        description="Unified benchmark for velocity prediction models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Examples:\n"
            " # Single model evaluation\n"
            " python -m benchmark.benchmark --checkpoint checkpoints/best.pt\n\n"
            " # Compare all checkpoints in a directory\n"
            " python -m benchmark.benchmark --compare checkpoints/\n\n"
            " # Custom output directory\n"
            " python -m benchmark.benchmark --checkpoint checkpoints/best.pt --output my_results/\n"
        ),
    )

    # Mutually exclusive mode selection
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument(
        "--checkpoint", type=str, default=None,
        help="Path to a single checkpoint .pt file for single-model evaluation.",
    )
    mode.add_argument(
        "--compare", type=str, default=None,
        help="Directory containing multiple .pt checkpoints for comparison.",
    )

    # Optional overrides
    parser.add_argument(
        "--output", type=str, default=None,
        help="Output directory for results (default: benchmark/results/<exp_name>/).",
    )
    parser.add_argument(
        "--device", type=str, default=None,
        help="Device override, e.g. 'cuda:0' or 'cpu' (default: auto-detect).",
    )
    parser.add_argument(
        "--seq-len", type=int, default=None,
        help=f"Sequence length override (default: {eval_cfg.seq_len}).",
    )
    parser.add_argument(
        "--batch-size", type=int, default=None,
        help=f"Batch size override (default: {eval_cfg.batch_size}).",
    )
    parser.add_argument(
        "--num-workers", type=int, default=None,
        help=f"DataLoader workers override (default: {eval_cfg.num_workers}).",
    )
    parser.add_argument(
        "--pattern", type=str, default="*.pt",
        help="Glob pattern for --compare mode (default: '*.pt').",
    )
    parser.add_argument(
        "--no-plots", action="store_true",
        help="Skip generating per-scene plots.",
    )

    args = parser.parse_args(argv)

    # Resolve device. Any valid device string is honoured as given; only a
    # CUDA request on a machine without CUDA falls back to CPU.
    try:
        device = _resolve_device(args.device)
    except RuntimeError as exc:
        parser.error(f"invalid --device {args.device!r}: {exc}")

    # Resolve output directory
    output_dir = Path(args.output) if args.output else None

    if args.checkpoint:
        ckpt_path = Path(args.checkpoint)
        if not ckpt_path.exists():
            print(f"Error: checkpoint not found: {ckpt_path}")
            sys.exit(1)
        run_single_eval(
            checkpoint_path=ckpt_path,
            output_dir=output_dir,
            device=device,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            save_plots=not args.no_plots,
        )
    elif args.compare:
        ckpt_dir = Path(args.compare)
        if not ckpt_dir.is_dir():
            print(f"Error: checkpoint directory not found: {ckpt_dir}")
            sys.exit(1)
        run_compare(
            checkpoint_dir=ckpt_dir,
            output_dir=output_dir,
            device=device,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pattern=args.pattern,
        )


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user