initial commit
This commit is contained in:
0
benchmark/__init__.py
Normal file
0
benchmark/__init__.py
Normal file
309
benchmark/benchmark.py
Normal file
309
benchmark/benchmark.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
benchmark.py — Unified evaluation entry point.
|
||||
|
||||
Two modes:
|
||||
1. Single-model eval: python -m benchmark.benchmark --checkpoint <path>
|
||||
2. Compare mode: python -m benchmark.benchmark --compare <checkpoint_dir>
|
||||
|
||||
Results are saved to benchmark/results/<exp_name>/.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from benchmark.config import eval_cfg, TEST_SCENE_GROUPS
|
||||
from benchmark.evaluate import run_full_evaluation, save_results
|
||||
|
||||
# Project root (two levels up from benchmark/benchmark.py)
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[1]
|
||||
RESULTS_DIR = PROJECT_ROOT / "benchmark" / "results"
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
checkpoint_path: Path,
|
||||
device: torch.device,
|
||||
) -> torch.nn.Module:
|
||||
"""Load a VelocityPredictionModel from a checkpoint file."""
|
||||
from src.velocity_prediction.model import VelocityPredictionModel
|
||||
|
||||
model = VelocityPredictionModel()
|
||||
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
||||
state_dict = ckpt.get("model_state_dict", ckpt)
|
||||
model.load_state_dict(state_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def run_single_eval(
|
||||
checkpoint_path: Path,
|
||||
output_dir: Optional[Path] = None,
|
||||
device: torch.device = None,
|
||||
seq_len: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
save_plots: bool = True,
|
||||
) -> Path:
|
||||
"""Evaluate a single checkpoint and save results.
|
||||
|
||||
Returns the output directory path.
|
||||
"""
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
seq_len = seq_len or eval_cfg.seq_len
|
||||
batch_size = batch_size or eval_cfg.batch_size
|
||||
num_workers = num_workers or eval_cfg.num_workers
|
||||
|
||||
# Derive experiment name from checkpoint filename (strip extension)
|
||||
exp_name = checkpoint_path.stem # e.g. "best" or "epoch_050_val_1.827390"
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = RESULTS_DIR / exp_name
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print(f"Benchmark — Single Model Evaluation")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Checkpoint: {checkpoint_path}")
|
||||
print(f" Device: {device}")
|
||||
print(f" Seq len: {seq_len}")
|
||||
print(f" Batch size: {batch_size}")
|
||||
print(f" Output: {output_dir}")
|
||||
print()
|
||||
|
||||
model = load_checkpoint(checkpoint_path, device)
|
||||
|
||||
results = run_full_evaluation(
|
||||
model=model,
|
||||
device=device,
|
||||
seq_len=seq_len,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
event_threshold=eval_cfg.event_threshold,
|
||||
event_use_log=eval_cfg.event_use_log,
|
||||
scene_groups=TEST_SCENE_GROUPS,
|
||||
)
|
||||
|
||||
save_results(results, save_dir=output_dir, checkpoint_name=exp_name)
|
||||
|
||||
return output_dir
|
||||
|
||||
|
||||
def run_compare(
|
||||
checkpoint_dir: Path,
|
||||
output_dir: Optional[Path] = None,
|
||||
device: torch.device = None,
|
||||
seq_len: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
pattern: str = "*.pt",
|
||||
) -> Path:
|
||||
"""Evaluate all checkpoints in a directory and produce a comparison table.
|
||||
|
||||
Returns the output directory path.
|
||||
"""
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
seq_len = seq_len or eval_cfg.seq_len
|
||||
batch_size = batch_size or eval_cfg.batch_size
|
||||
num_workers = num_workers or eval_cfg.num_workers
|
||||
|
||||
checkpoint_paths = sorted(Path(checkpoint_dir).glob(pattern))
|
||||
if not checkpoint_paths:
|
||||
print(f"No checkpoints found matching '{pattern}' in {checkpoint_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = RESULTS_DIR / "compare"
|
||||
|
||||
print(f"{'=' * 60}")
|
||||
print(f"Benchmark — Compare Mode ({len(checkpoint_paths)} checkpoints)")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Checkpoint dir: {checkpoint_dir}")
|
||||
print(f" Device: {device}")
|
||||
print(f" Seq len: {seq_len}")
|
||||
print(f" Batch size: {batch_size}")
|
||||
print(f" Output: {output_dir}")
|
||||
print()
|
||||
|
||||
all_global_metrics = []
|
||||
|
||||
for ckpt_path in checkpoint_paths:
|
||||
exp_name = ckpt_path.stem
|
||||
print(f"\n── Evaluating {exp_name} ──")
|
||||
|
||||
model = load_checkpoint(ckpt_path, device)
|
||||
|
||||
results = run_full_evaluation(
|
||||
model=model,
|
||||
device=device,
|
||||
seq_len=seq_len,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
event_threshold=eval_cfg.event_threshold,
|
||||
event_use_log=eval_cfg.event_use_log,
|
||||
scene_groups=TEST_SCENE_GROUPS,
|
||||
)
|
||||
|
||||
# Save individual results
|
||||
ckpt_output_dir = output_dir / exp_name
|
||||
save_results(results, save_dir=ckpt_output_dir, checkpoint_name=exp_name)
|
||||
|
||||
all_global_metrics.append((exp_name, results["global"]))
|
||||
|
||||
# ── Comparison table ──
|
||||
print(f"\n\n{'=' * 60}")
|
||||
print("Comparison Summary")
|
||||
print(f"{'=' * 60}")
|
||||
|
||||
header = f"{'Checkpoint':<30} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} {'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8}"
|
||||
sep = "-" * len(header)
|
||||
print(header)
|
||||
print(sep)
|
||||
|
||||
rows = []
|
||||
for name, metrics in all_global_metrics:
|
||||
row = (
|
||||
f"{name:<30} "
|
||||
f"{metrics.get('rmse_vx', 0):>10.4f} "
|
||||
f"{metrics.get('rmse_vy', 0):>10.4f} "
|
||||
f"{metrics.get('rmse_xy', 0):>10.4f} "
|
||||
f"{metrics.get('mae_vx', 0):>10.4f} "
|
||||
f"{metrics.get('mae_vy', 0):>10.4f} "
|
||||
f"{metrics.get('r2_vx', 0):>8.4f} "
|
||||
f"{metrics.get('r2_vy', 0):>8.4f}"
|
||||
)
|
||||
print(row)
|
||||
rows.append(row)
|
||||
|
||||
# Save comparison CSV
|
||||
import csv
|
||||
csv_path = output_dir / "comparison.csv"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
fieldnames = ["checkpoint", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy",
|
||||
"mae_xy", "r2_vx", "r2_vy", "count"]
|
||||
with open(csv_path, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for name, metrics in all_global_metrics:
|
||||
row = {"checkpoint": name, **metrics}
|
||||
writer.writerow(row)
|
||||
print(f"\nComparison CSV: {csv_path}")
|
||||
|
||||
# Save comparison text
|
||||
txt_path = output_dir / "comparison.txt"
|
||||
with open(txt_path, "w") as f:
|
||||
f.write("Benchmark Comparison\n")
|
||||
f.write(f"{'=' * 60}\n\n")
|
||||
f.write(header + "\n")
|
||||
f.write(sep + "\n")
|
||||
for row in rows:
|
||||
f.write(row + "\n")
|
||||
print(f"Comparison TXT: {txt_path}")
|
||||
|
||||
return output_dir
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Unified benchmark for velocity prediction models.",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=(
|
||||
"Examples:\n"
|
||||
" # Single model evaluation\n"
|
||||
" python -m benchmark.benchmark --checkpoint checkpoints/best.pt\n\n"
|
||||
" # Compare all checkpoints in a directory\n"
|
||||
" python -m benchmark.benchmark --compare checkpoints/\n\n"
|
||||
" # Custom output directory\n"
|
||||
" python -m benchmark.benchmark --checkpoint checkpoints/best.pt --output my_results/\n"
|
||||
),
|
||||
)
|
||||
|
||||
# Mutually exclusive mode selection
|
||||
mode = parser.add_mutually_exclusive_group(required=True)
|
||||
mode.add_argument(
|
||||
"--checkpoint", type=str, default=None,
|
||||
help="Path to a single checkpoint .pt file for single-model evaluation.",
|
||||
)
|
||||
mode.add_argument(
|
||||
"--compare", type=str, default=None,
|
||||
help="Directory containing multiple .pt checkpoints for comparison.",
|
||||
)
|
||||
|
||||
# Optional overrides
|
||||
parser.add_argument(
|
||||
"--output", type=str, default=None,
|
||||
help="Output directory for results (default: benchmark/results/<exp_name>/).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", type=str, default=None,
|
||||
help="Device override, e.g. 'cuda:0' or 'cpu' (default: auto-detect).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--seq-len", type=int, default=None,
|
||||
help=f"Sequence length override (default: {eval_cfg.seq_len}).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=None,
|
||||
help=f"Batch size override (default: {eval_cfg.batch_size}).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=None,
|
||||
help=f"DataLoader workers override (default: {eval_cfg.num_workers}).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pattern", type=str, default="*.pt",
|
||||
help="Glob pattern for --compare mode (default: '*.pt').",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-plots", action="store_true",
|
||||
help="Skip generating per-scene plots.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Resolve device
|
||||
device = None
|
||||
if args.device is not None:
|
||||
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
|
||||
|
||||
# Resolve output directory
|
||||
output_dir = Path(args.output) if args.output else None
|
||||
|
||||
if args.checkpoint:
|
||||
ckpt_path = Path(args.checkpoint)
|
||||
if not ckpt_path.exists():
|
||||
print(f"Error: checkpoint not found: {ckpt_path}")
|
||||
sys.exit(1)
|
||||
run_single_eval(
|
||||
checkpoint_path=ckpt_path,
|
||||
output_dir=output_dir,
|
||||
device=device,
|
||||
seq_len=args.seq_len,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
save_plots=not args.no_plots,
|
||||
)
|
||||
elif args.compare:
|
||||
ckpt_dir = Path(args.compare)
|
||||
if not ckpt_dir.is_dir():
|
||||
print(f"Error: checkpoint directory not found: {ckpt_dir}")
|
||||
sys.exit(1)
|
||||
run_compare(
|
||||
checkpoint_dir=ckpt_dir,
|
||||
output_dir=output_dir,
|
||||
device=device,
|
||||
seq_len=args.seq_len,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pattern=args.pattern,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
98
benchmark/config.py
Normal file
98
benchmark/config.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
Benchmark configuration — evaluation-only scene splits and metric definitions.
|
||||
|
||||
This config is independent from src.velocity_prediction.config so that
|
||||
evaluation scenarios can be changed without touching training config.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
# ──────────────────────────── Dataset root ────────────────────────────
|
||||
|
||||
DATASET_ROOT = Path(__file__).resolve().parents[1] / "dataset"
|
||||
|
||||
# ──────────────────────────── Scene splits ────────────────────────────
|
||||
|
||||
# Each scene group has a name, a list of scene dirs, and a difficulty label.
|
||||
# The test scenes are the primary evaluation set; val scenes are for
|
||||
# checkpoint selection reference.
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneGroup:
|
||||
name: str
|
||||
scenes: List[str]
|
||||
difficulty: str = "medium" # easy / medium / hard
|
||||
|
||||
|
||||
# ── Validation scenes (for checkpoint selection reference) ──
|
||||
VAL_SCENE_GROUPS: List[SceneGroup] = [
|
||||
SceneGroup("indoor_forward_7", ["indoor_forward_7"], "hard"),
|
||||
SceneGroup("outdoor_forward_1", ["outdoor_forward_1"], "easy"),
|
||||
# SceneGroup("indoor_forward_6", ["indoor_forward_6"], "medium"),
|
||||
# SceneGroup("indoor_forward_9", ["indoor_forward_9"], "easy"),
|
||||
# SceneGroup("indoor_forward_10", ["indoor_forward_10"], "easy"),
|
||||
# SceneGroup("indoor_forward_5", ["indoor_forward_5"], "medium"),
|
||||
]
|
||||
|
||||
# ── Test scenes (primary evaluation) ──
|
||||
TEST_SCENE_GROUPS: List[SceneGroup] = [
|
||||
SceneGroup("indoor_forward_7", ["indoor_forward_7"], "hard"),
|
||||
SceneGroup("outdoor_forward_1", ["outdoor_forward_1"], "easy"),
|
||||
SceneGroup("outdoor_forward_5", ["outdoor_forward_5"], "hard"),
|
||||
SceneGroup("indoor_forward_6", ["indoor_forward_6"], "medium"),
|
||||
SceneGroup("indoor_forward_9", ["indoor_forward_9"], "easy"),
|
||||
SceneGroup("indoor_forward_10", ["indoor_forward_10"], "easy"),
|
||||
SceneGroup("indoor_forward_5", ["indoor_forward_5"], "medium"),
|
||||
]
|
||||
|
||||
# Flat lists for convenience
|
||||
VAL_SCENES: List[str] = [s for g in VAL_SCENE_GROUPS for s in g.scenes]
|
||||
TEST_SCENES: List[str] = [s for g in TEST_SCENE_GROUPS for s in g.scenes]
|
||||
|
||||
# Difficulty grouping
|
||||
DIFFICULTY_GROUPS: Dict[str, List[str]] = {}
|
||||
for g in TEST_SCENE_GROUPS:
|
||||
DIFFICULTY_GROUPS.setdefault(g.difficulty, []).extend(g.scenes)
|
||||
|
||||
|
||||
# ──────────────────────────── Evaluation parameters ────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalConfig:
|
||||
"""Parameters used when running evaluation."""
|
||||
|
||||
# Sequence length (must match what the model was trained with)
|
||||
seq_len: int = 8
|
||||
|
||||
# Batch size for evaluation (can be larger than training)
|
||||
batch_size: int = 64
|
||||
|
||||
# Data loading
|
||||
num_workers: int = 2
|
||||
|
||||
# Event simulation (must match training config)
|
||||
event_threshold: float = 0.1
|
||||
event_use_log: bool = True
|
||||
|
||||
# Output directory (relative to benchmark/results/)
|
||||
output_dir: str = "results"
|
||||
|
||||
# Whether to generate per-scene plots
|
||||
save_plots: bool = True
|
||||
|
||||
# Device override (None = auto-detect)
|
||||
device: str = "cuda"
|
||||
|
||||
|
||||
# ──────────────────────────── Metrics definition ────────────────────────────
|
||||
|
||||
# Metrics computed per-axis and overall
|
||||
METRICS = ["rmse", "mae", "r2"]
|
||||
|
||||
# Singleton
|
||||
eval_cfg = EvalConfig()
|
||||
458
benchmark/evaluate.py
Normal file
458
benchmark/evaluate.py
Normal file
@@ -0,0 +1,458 @@
|
||||
"""
|
||||
Core evaluation logic: run model on one or more scenes, compute metrics,
|
||||
generate visualizations, and save structured results.
|
||||
|
||||
This module is called by benchmark.py (the user-facing entry point).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from src.velocity_prediction.model import VelocityPredictionModel
|
||||
from src.velocity_prediction.dataset import create_val_loader
|
||||
from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD
|
||||
|
||||
from benchmark.config import (
|
||||
eval_cfg,
|
||||
TEST_SCENE_GROUPS,
|
||||
VAL_SCENE_GROUPS,
|
||||
DIFFICULTY_GROUPS,
|
||||
DATASET_ROOT,
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────── Metrics ────────────────────────────
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
pred: np.ndarray,
|
||||
target: np.ndarray,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Compute RMSE, MAE, R² for each axis and overall.
|
||||
|
||||
Args:
|
||||
pred: (N, 2) denormalized predictions
|
||||
target: (N, 2) denormalized ground truth
|
||||
|
||||
Returns:
|
||||
dict with keys like rmse_vx, rmse_vy, rmse_xy, mae_vx, ...
|
||||
"""
|
||||
# Per-axis
|
||||
rmse_x = float(np.sqrt(np.mean((pred[:, 0] - target[:, 0]) ** 2)))
|
||||
rmse_y = float(np.sqrt(np.mean((pred[:, 1] - target[:, 1]) ** 2)))
|
||||
rmse_xy = float(np.sqrt(np.mean(np.sum((pred - target) ** 2, axis=1))))
|
||||
|
||||
mae_x = float(np.mean(np.abs(pred[:, 0] - target[:, 0])))
|
||||
mae_y = float(np.mean(np.abs(pred[:, 1] - target[:, 1])))
|
||||
mae_xy = float(np.mean(np.sqrt(np.sum((pred - target) ** 2, axis=1))))
|
||||
|
||||
# R² per axis
|
||||
def r2(p, t):
|
||||
ss_res = np.sum((t - p) ** 2)
|
||||
ss_tot = np.sum((t - np.mean(t)) ** 2)
|
||||
return float(1 - ss_res / ss_tot) if ss_tot > 1e-12 else 0.0
|
||||
|
||||
r2_x = r2(pred[:, 0], target[:, 0])
|
||||
r2_y = r2(pred[:, 1], target[:, 1])
|
||||
|
||||
return {
|
||||
"rmse_vx": rmse_x,
|
||||
"rmse_vy": rmse_y,
|
||||
"rmse_xy": rmse_xy,
|
||||
"mae_vx": mae_x,
|
||||
"mae_vy": mae_y,
|
||||
"mae_xy": mae_xy,
|
||||
"r2_vx": r2_x,
|
||||
"r2_vy": r2_y,
|
||||
"count": len(pred),
|
||||
}
|
||||
|
||||
|
||||
# ──────────────────────────── Per-scene evaluation ────────────────────────────
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_scene(
|
||||
model: nn.Module,
|
||||
scene_names: List[str],
|
||||
device: torch.device,
|
||||
seq_len: int = 8,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 2,
|
||||
event_threshold: float = 0.1,
|
||||
event_use_log: bool = True,
|
||||
) -> Dict:
|
||||
"""
|
||||
Evaluate model on one or more scenes.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
preds: (N, 2) denormalized predictions
|
||||
targets: (N, 2) denormalized ground truth
|
||||
metrics: dict of scalar metrics
|
||||
"""
|
||||
loader = create_val_loader(
|
||||
scene_names=scene_names,
|
||||
seq_len=seq_len,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
event_threshold=event_threshold,
|
||||
event_use_log=event_use_log,
|
||||
)
|
||||
|
||||
model.eval()
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
|
||||
for batch in loader:
|
||||
events = batch["events"].to(device)
|
||||
tilt = batch["tilt"].to(device)
|
||||
target = batch["v_body_target"].to(device) # (B, S, 2) normalized
|
||||
|
||||
pred = model(events, tilt) # (B, 2) normalized
|
||||
target_last = target[:, -1, :] # (B, 2) normalized
|
||||
|
||||
all_preds.append(pred.cpu().numpy())
|
||||
all_targets.append(target_last.cpu().numpy())
|
||||
|
||||
if not all_preds:
|
||||
return {"preds": np.zeros((0, 2)), "targets": np.zeros((0, 2)), "metrics": {}}
|
||||
|
||||
preds = np.concatenate(all_preds, axis=0)
|
||||
targets = np.concatenate(all_targets, axis=0)
|
||||
|
||||
# Denormalize
|
||||
mean = np.array(VELOCITY_MEAN, dtype=np.float32)
|
||||
std = np.array(VELOCITY_STD, dtype=np.float32)
|
||||
preds_denorm = preds * std + mean
|
||||
targets_denorm = targets * std + mean
|
||||
|
||||
metrics = compute_metrics(preds_denorm, targets_denorm)
|
||||
|
||||
return {
|
||||
"preds": preds_denorm,
|
||||
"targets": targets_denorm,
|
||||
"metrics": metrics,
|
||||
}
|
||||
|
||||
|
||||
# ──────────────────────────── Full evaluation suite ────────────────────────────
|
||||
|
||||
|
||||
def run_full_evaluation(
|
||||
model: nn.Module,
|
||||
device: torch.device,
|
||||
seq_len: int = 8,
|
||||
batch_size: int = 64,
|
||||
num_workers: int = 2,
|
||||
event_threshold: float = 0.1,
|
||||
event_use_log: bool = True,
|
||||
scene_groups=None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Run evaluation on all scene groups.
|
||||
|
||||
Returns nested dict:
|
||||
{
|
||||
"global": { metrics... },
|
||||
"per_scene": {
|
||||
"indoor_forward_7": { metrics..., "preds": ..., "targets": ... },
|
||||
...
|
||||
},
|
||||
"by_difficulty": {
|
||||
"easy": { metrics... },
|
||||
"hard": { metrics... },
|
||||
}
|
||||
}
|
||||
"""
|
||||
if scene_groups is None:
|
||||
from benchmark.config import TEST_SCENE_GROUPS
|
||||
scene_groups = TEST_SCENE_GROUPS
|
||||
|
||||
per_scene = {}
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
|
||||
for group in scene_groups:
|
||||
for scene_name in group.scenes:
|
||||
result = evaluate_scene(
|
||||
model, [scene_name], device,
|
||||
seq_len=seq_len, batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
event_threshold=event_threshold,
|
||||
event_use_log=event_use_log,
|
||||
)
|
||||
per_scene[scene_name] = result
|
||||
if result["preds"].shape[0] > 0:
|
||||
all_preds.append(result["preds"])
|
||||
all_targets.append(result["targets"])
|
||||
|
||||
# Global metrics (all scenes combined)
|
||||
if all_preds:
|
||||
global_preds = np.concatenate(all_preds, axis=0)
|
||||
global_targets = np.concatenate(all_targets, axis=0)
|
||||
global_metrics = compute_metrics(global_preds, global_targets)
|
||||
else:
|
||||
global_preds = np.zeros((0, 2))
|
||||
global_targets = np.zeros((0, 2))
|
||||
global_metrics = {}
|
||||
|
||||
# By difficulty
|
||||
by_difficulty = {}
|
||||
for diff, scenes in DIFFICULTY_GROUPS.items():
|
||||
diff_preds = []
|
||||
diff_targets = []
|
||||
for s in scenes:
|
||||
if s in per_scene and per_scene[s]["preds"].shape[0] > 0:
|
||||
diff_preds.append(per_scene[s]["preds"])
|
||||
diff_targets.append(per_scene[s]["targets"])
|
||||
if diff_preds:
|
||||
by_difficulty[diff] = compute_metrics(
|
||||
np.concatenate(diff_preds, axis=0),
|
||||
np.concatenate(diff_targets, axis=0),
|
||||
)
|
||||
else:
|
||||
by_difficulty[diff] = {}
|
||||
|
||||
return {
|
||||
"global": global_metrics,
|
||||
"per_scene": per_scene,
|
||||
"by_difficulty": by_difficulty,
|
||||
}
|
||||
|
||||
|
||||
# ──────────────────────────── Visualization ────────────────────────────
|
||||
|
||||
|
||||
def plot_scene_comparison(
|
||||
preds: np.ndarray,
|
||||
targets: np.ndarray,
|
||||
scene_name: str,
|
||||
save_dir: Path,
|
||||
metrics: Optional[Dict] = None,
|
||||
):
|
||||
"""Generate time-series and scatter plots for a single scene."""
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
time = np.arange(len(preds))
|
||||
|
||||
# ── Time-series plot ──
|
||||
fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)
|
||||
|
||||
axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.7, linewidth=0.8)
|
||||
axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.7, linewidth=0.8)
|
||||
axes[0].set_ylabel("vx (m/s)")
|
||||
axes[0].legend(fontsize=9)
|
||||
axes[0].grid(True, alpha=0.3)
|
||||
|
||||
axes[1].plot(time, targets[:, 1], label="GT vy", color="C0", alpha=0.7, linewidth=0.8)
|
||||
axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.7, linewidth=0.8)
|
||||
axes[1].set_ylabel("vy (m/s)")
|
||||
axes[1].set_xlabel("Frame index")
|
||||
axes[1].legend(fontsize=9)
|
||||
axes[1].grid(True, alpha=0.3)
|
||||
|
||||
title = f"Body-frame Velocity — {scene_name}"
|
||||
if metrics:
|
||||
title += f" | RMSE vx={metrics['rmse_vx']:.3f} vy={metrics['rmse_vy']:.3f}"
|
||||
fig.suptitle(title)
|
||||
try:
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_dir / f"{scene_name}_timeseries.png", dpi=150, bbox_inches="tight")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to save {scene_name}_timeseries.png: {e}")
|
||||
plt.close()
|
||||
|
||||
# ── Scatter plot ──
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
|
||||
|
||||
for ax, pred, target, label in zip(
|
||||
axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"]
|
||||
):
|
||||
ax.scatter(target, pred, s=3, alpha=0.4, c="C1", edgecolors="none")
|
||||
lim_min = min(target.min(), pred.min())
|
||||
lim_max = max(target.max(), pred.max())
|
||||
margin = (lim_max - lim_min) * 0.05
|
||||
ax.plot([lim_min - margin, lim_max + margin],
|
||||
[lim_min - margin, lim_max + margin], "r--", alpha=0.5, linewidth=1)
|
||||
ax.set_xlabel(f"GT {label} (m/s)")
|
||||
ax.set_ylabel(f"Pred {label} (m/s)")
|
||||
ax.set_aspect("equal")
|
||||
ax.grid(True, alpha=0.3)
|
||||
if metrics:
|
||||
ax.set_title(f"{label} — RMSE: {metrics[f'rmse_{label}']:.4f}")
|
||||
|
||||
fig.suptitle(f"Scatter — {scene_name}")
|
||||
try:
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_dir / f"{scene_name}_scatter.png", dpi=150, bbox_inches="tight")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to save {scene_name}_scatter.png: {e}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_global_comparison(
|
||||
results: Dict,
|
||||
save_dir: Path,
|
||||
):
|
||||
"""Generate a summary figure comparing all scenes."""
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
save_dir = Path(save_dir)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
scenes = list(results["per_scene"].keys())
|
||||
if not scenes:
|
||||
return
|
||||
|
||||
# Bar chart: RMSE vx and vy per scene
|
||||
rmse_vx = [results["per_scene"][s]["metrics"].get("rmse_vx", 0) for s in scenes]
|
||||
rmse_vy = [results["per_scene"][s]["metrics"].get("rmse_vy", 0) for s in scenes]
|
||||
rmse_xy = [results["per_scene"][s]["metrics"].get("rmse_xy", 0) for s in scenes]
|
||||
|
||||
x = np.arange(len(scenes))
|
||||
width = 0.25
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 4.5))
|
||||
bars1 = ax.bar(x - width, rmse_vx, width, label="RMSE vx", alpha=0.8)
|
||||
bars2 = ax.bar(x, rmse_vy, width, label="RMSE vy", alpha=0.8)
|
||||
bars3 = ax.bar(x + width, rmse_xy, width, label="RMSE xy", alpha=0.8)
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(scenes, rotation=15, ha="right")
|
||||
ax.set_ylabel("RMSE (m/s)")
|
||||
ax.set_title("Per-Scene RMSE Comparison")
|
||||
ax.legend(fontsize=9)
|
||||
ax.grid(True, alpha=0.3, axis="y")
|
||||
|
||||
# Annotate values
|
||||
for bars in [bars1, bars2, bars3]:
|
||||
for bar in bars:
|
||||
height = bar.get_height()
|
||||
ax.annotate(f"{height:.3f}",
|
||||
xy=(bar.get_x() + bar.get_width() / 2, height),
|
||||
xytext=(0, 2), textcoords="offset points",
|
||||
ha="center", va="bottom", fontsize=7)
|
||||
|
||||
try:
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_dir / "per_scene_rmse.png", dpi=150, bbox_inches="tight")
|
||||
except Exception as e:
|
||||
print(f" [WARN] Failed to save per_scene_rmse.png: {e}")
|
||||
plt.close()
|
||||
|
||||
|
||||
# ──────────────────────────── Results serialization ────────────────────────────
|
||||
|
||||
|
||||
def save_results(
|
||||
results: Dict,
|
||||
save_dir: Path,
|
||||
checkpoint_name: str = "model",
|
||||
):
|
||||
"""Save all evaluation results to disk."""
|
||||
save_dir = Path(save_dir)
|
||||
plots_dir = save_dir / "plots"
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
plots_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ── 1. Global metrics ──
|
||||
global_m = results["global"]
|
||||
lines = [
|
||||
f"Benchmark Results: {checkpoint_name}",
|
||||
f"{'=' * 50}",
|
||||
f"Total samples: {global_m.get('count', '?')}",
|
||||
"",
|
||||
"── Global Metrics ──",
|
||||
f" RMSE vx: {global_m.get('rmse_vx', 'N/A'):.4f} m/s",
|
||||
f" RMSE vy: {global_m.get('rmse_vy', 'N/A'):.4f} m/s",
|
||||
f" RMSE xy: {global_m.get('rmse_xy', 'N/A'):.4f} m/s",
|
||||
f" MAE vx: {global_m.get('mae_vx', 'N/A'):.4f} m/s",
|
||||
f" MAE vy: {global_m.get('mae_vy', 'N/A'):.4f} m/s",
|
||||
f" MAE xy: {global_m.get('mae_xy', 'N/A'):.4f} m/s",
|
||||
f" R² vx: {global_m.get('r2_vx', 'N/A'):.4f}",
|
||||
f" R² vy: {global_m.get('r2_vy', 'N/A'):.4f}",
|
||||
"",
|
||||
]
|
||||
|
||||
# ── 2. Per-scene metrics ──
|
||||
lines.append("── Per-Scene Metrics ──")
|
||||
lines.append(f" {'Scene':<22} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} "
|
||||
f"{'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8} {'Samples':>8}")
|
||||
lines.append(" " + "-" * 96)
|
||||
|
||||
for scene_name, scene_result in results["per_scene"].items():
|
||||
m = scene_result["metrics"]
|
||||
lines.append(
|
||||
f" {scene_name:<22} {m.get('rmse_vx', 0):>10.4f} {m.get('rmse_vy', 0):>10.4f} "
|
||||
f"{m.get('rmse_xy', 0):>10.4f} {m.get('mae_vx', 0):>10.4f} "
|
||||
f"{m.get('mae_vy', 0):>10.4f} {m.get('r2_vx', 0):>8.4f} "
|
||||
f"{m.get('r2_vy', 0):>8.4f} {m.get('count', 0):>8}"
|
||||
)
|
||||
|
||||
lines.append("")
|
||||
|
||||
# ── 3. By difficulty ──
|
||||
lines.append("── By Difficulty ──")
|
||||
for diff, metrics in results["by_difficulty"].items():
|
||||
lines.append(f" {diff:<10} RMSE vx={metrics.get('rmse_vx', 0):.4f} "
|
||||
f"RMSE vy={metrics.get('rmse_vy', 0):.4f} "
|
||||
f"RMSE xy={metrics.get('rmse_xy', 0):.4f} "
|
||||
f"(samples={metrics.get('count', 0)})")
|
||||
|
||||
summary_text = "\n".join(lines)
|
||||
with open(save_dir / "summary.txt", "w") as f:
|
||||
f.write(summary_text)
|
||||
print(summary_text)
|
||||
|
||||
# ── 4. CSV: per-scene metrics ──
|
||||
import csv
|
||||
csv_path = save_dir / "per_scene_metrics.csv"
|
||||
fieldnames = ["scene", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy",
|
||||
"mae_xy", "r2_vx", "r2_vy", "count"]
|
||||
with open(csv_path, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for scene_name, scene_result in results["per_scene"].items():
|
||||
row = {"scene": scene_name, **scene_result["metrics"]}
|
||||
writer.writerow(row)
|
||||
print(f"Per-scene CSV: {csv_path}")
|
||||
|
||||
# ── 5. Global metrics CSV (single row) ──
|
||||
csv_path_global = save_dir / "metrics.csv"
|
||||
with open(csv_path_global, "w", newline="") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=["checkpoint"] + list(global_m.keys()))
|
||||
writer.writeheader()
|
||||
row = {"checkpoint": checkpoint_name, **global_m}
|
||||
writer.writerow(row)
|
||||
print(f"Global metrics CSV: {csv_path_global}")
|
||||
|
||||
# ── 6. Plots ──
|
||||
for scene_name, scene_result in results["per_scene"].items():
|
||||
if scene_result["preds"].shape[0] > 0:
|
||||
plot_scene_comparison(
|
||||
scene_result["preds"],
|
||||
scene_result["targets"],
|
||||
scene_name,
|
||||
plots_dir,
|
||||
metrics=scene_result["metrics"],
|
||||
)
|
||||
|
||||
plot_global_comparison(results, plots_dir)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
plt.close("all")
|
||||
|
||||
print(f"\nAll results saved to: {save_dir.resolve()}")
|
||||
Reference in New Issue
Block a user