"""
Core evaluation logic: run model on one or more scenes, compute metrics,
generate visualizations, and save structured results.

This module is called by benchmark.py (the user-facing entry point).
"""

import csv
from pathlib import Path
from typing import List, Optional, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn

from src.velocity_prediction.model import VelocityPredictionModel
from src.velocity_prediction.dataset import create_val_loader
from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD
from benchmark.config import (
    eval_cfg,
    TEST_SCENE_GROUPS,
    VAL_SCENE_GROUPS,
    DIFFICULTY_GROUPS,
    DATASET_ROOT,
)


# ──────────────────────────── Metrics ────────────────────────────

def compute_metrics(
    pred: np.ndarray,
    target: np.ndarray,
) -> Dict[str, float]:
    """
    Compute RMSE, MAE and R² for each axis, plus combined 2-D (xy) RMSE/MAE.

    Args:
        pred: (N, 2) denormalized predictions
        target: (N, 2) denormalized ground truth

    Returns:
        dict with keys rmse_vx, rmse_vy, rmse_xy, mae_vx, mae_vy, mae_xy,
        r2_vx, r2_vy and count (= N).
        - rmse_xy is the root of the mean squared Euclidean error.
        - mae_xy is the mean Euclidean error.
        - R² falls back to 0.0 when the target has (near-)zero variance.
        Callers in this module only pass N > 0.
    """
    err = pred - target

    # Per-axis
    rmse_x = float(np.sqrt(np.mean(err[:, 0] ** 2)))
    rmse_y = float(np.sqrt(np.mean(err[:, 1] ** 2)))
    rmse_xy = float(np.sqrt(np.mean(np.sum(err ** 2, axis=1))))

    mae_x = float(np.mean(np.abs(err[:, 0])))
    mae_y = float(np.mean(np.abs(err[:, 1])))
    mae_xy = float(np.mean(np.sqrt(np.sum(err ** 2, axis=1))))

    # R² per axis
    def r2(p, t):
        ss_res = np.sum((t - p) ** 2)
        ss_tot = np.sum((t - np.mean(t)) ** 2)
        # Guard against division by ~0 when the ground truth is constant.
        return float(1 - ss_res / ss_tot) if ss_tot > 1e-12 else 0.0

    r2_x = r2(pred[:, 0], target[:, 0])
    r2_y = r2(pred[:, 1], target[:, 1])

    return {
        "rmse_vx": rmse_x,
        "rmse_vy": rmse_y,
        "rmse_xy": rmse_xy,
        "mae_vx": mae_x,
        "mae_vy": mae_y,
        "mae_xy": mae_xy,
        "r2_vx": r2_x,
        "r2_vy": r2_y,
        "count": len(pred),
    }


# ──────────────────────────── Per-scene evaluation ────────────────────────────

@torch.no_grad()
def evaluate_scene(
    model: nn.Module,
    scene_names: List[str],
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
) -> Dict:
    """
    Evaluate model on one or more scenes.

    The model is put in eval mode, and gradients are disabled via the
    ``torch.no_grad`` decorator. For each sample, the prediction is compared
    against the target at the LAST step of the input sequence. Predictions
    and targets are denormalized with VELOCITY_MEAN / VELOCITY_STD before
    metrics are computed.

    Returns:
        dict with keys:
            preds:   (N, 2) denormalized predictions
            targets: (N, 2) denormalized ground truth
            metrics: dict of scalar metrics (empty dict when N == 0)
    """
    loader = create_val_loader(
        scene_names=scene_names,
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=num_workers,
        event_threshold=event_threshold,
        event_use_log=event_use_log,
    )

    model.eval()
    all_preds = []
    all_targets = []

    for batch in loader:
        events = batch["events"].to(device)
        tilt = batch["tilt"].to(device)
        target = batch["v_body_target"].to(device)  # (B, S, 2) normalized

        pred = model(events, tilt)  # (B, 2) normalized
        # The model predicts a single velocity per sequence; score it against
        # the target of the final time step.
        target_last = target[:, -1, :]  # (B, 2) normalized

        all_preds.append(pred.cpu().numpy())
        all_targets.append(target_last.cpu().numpy())

    if not all_preds:
        return {"preds": np.zeros((0, 2)), "targets": np.zeros((0, 2)), "metrics": {}}

    preds = np.concatenate(all_preds, axis=0)
    targets = np.concatenate(all_targets, axis=0)

    # Denormalize
    mean = np.array(VELOCITY_MEAN, dtype=np.float32)
    std = np.array(VELOCITY_STD, dtype=np.float32)
    preds_denorm = preds * std + mean
    targets_denorm = targets * std + mean

    metrics = compute_metrics(preds_denorm, targets_denorm)

    return {
        "preds": preds_denorm,
        "targets": targets_denorm,
        "metrics": metrics,
    }


# ──────────────────────────── Full evaluation suite ────────────────────────────

def run_full_evaluation(
    model: nn.Module,
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
    scene_groups=None,
) -> Dict:
    """
    Run evaluation on all scene groups.

    Args:
        model, device, seq_len, batch_size, num_workers, event_threshold,
        event_use_log: forwarded to ``evaluate_scene`` for every scene.
        scene_groups: iterable of groups, each with a ``scenes`` attribute
            listing scene names. Defaults to TEST_SCENE_GROUPS. A scene
            listed in more than one group is evaluated only once, so it is
            not double-counted in the global metrics.

    Returns nested dict:
        {
            "global": { metrics... },
            "per_scene": {
                "indoor_forward_7": {"preds": ..., "targets": ..., "metrics": {...}},
                ...
            },
            "by_difficulty": {
                "easy": { metrics... },
                "hard": { metrics... },
            },
        }
        Metric dicts are empty when no samples were available for them.
    """
    if scene_groups is None:
        scene_groups = TEST_SCENE_GROUPS

    per_scene = {}
    all_preds = []
    all_targets = []

    for group in scene_groups:
        for scene_name in group.scenes:
            if scene_name in per_scene:
                # Already evaluated through another group. Adding its samples
                # again would weight it twice in the global metrics while
                # per_scene only keeps one copy.
                continue
            result = evaluate_scene(
                model,
                [scene_name],
                device,
                seq_len=seq_len,
                batch_size=batch_size,
                num_workers=num_workers,
                event_threshold=event_threshold,
                event_use_log=event_use_log,
            )
            per_scene[scene_name] = result
            if result["preds"].shape[0] > 0:
                all_preds.append(result["preds"])
                all_targets.append(result["targets"])

    # Global metrics (all scenes combined)
    if all_preds:
        global_metrics = compute_metrics(
            np.concatenate(all_preds, axis=0),
            np.concatenate(all_targets, axis=0),
        )
    else:
        global_metrics = {}

    # By difficulty: pool samples of every evaluated scene in each group.
    # Scenes in DIFFICULTY_GROUPS that were not evaluated are ignored.
    by_difficulty = {}
    for diff, scenes in DIFFICULTY_GROUPS.items():
        diff_preds = []
        diff_targets = []
        for s in scenes:
            if s in per_scene and per_scene[s]["preds"].shape[0] > 0:
                diff_preds.append(per_scene[s]["preds"])
                diff_targets.append(per_scene[s]["targets"])
        if diff_preds:
            by_difficulty[diff] = compute_metrics(
                np.concatenate(diff_preds, axis=0),
                np.concatenate(diff_targets, axis=0),
            )
        else:
            by_difficulty[diff] = {}

    return {
        "global": global_metrics,
        "per_scene": per_scene,
        "by_difficulty": by_difficulty,
    }


# ──────────────────────────── Visualization ────────────────────────────

def plot_scene_comparison(
    preds: np.ndarray,
    targets: np.ndarray,
    scene_name: str,
    save_dir: Path,
    metrics: Optional[Dict] = None,
):
    """
    Generate time-series and scatter plots for a single scene.

    Writes ``<scene_name>_timeseries.png`` and ``<scene_name>_scatter.png``
    into *save_dir*, creating it if needed. Save failures are reported as
    warnings rather than raised. When *metrics* is given, it must contain
    ``rmse_vx`` and ``rmse_vy``; they are shown in the titles.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend: only files are produced
    import matplotlib.pyplot as plt

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    time = np.arange(len(preds))

    # ── Time-series plot ──
    fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)

    axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.7, linewidth=0.8)
    axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.7, linewidth=0.8)
    axes[0].set_ylabel("vx (m/s)")
    axes[0].legend(fontsize=9)
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(time, targets[:, 1], label="GT vy", color="C0", alpha=0.7, linewidth=0.8)
    axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.7, linewidth=0.8)
    axes[1].set_ylabel("vy (m/s)")
    axes[1].set_xlabel("Frame index")
    axes[1].legend(fontsize=9)
    axes[1].grid(True, alpha=0.3)

    title = f"Body-frame Velocity — {scene_name}"
    if metrics:
        title += f" | RMSE vx={metrics['rmse_vx']:.3f} vy={metrics['rmse_vy']:.3f}"
    fig.suptitle(title)
    try:
        plt.tight_layout()
        plt.savefig(save_dir / f"{scene_name}_timeseries.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        print(f" [WARN] Failed to save {scene_name}_timeseries.png: {e}")
    plt.close(fig)

    # ── Scatter plot ──
    fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
    for ax, pred, target, label in zip(
        axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"]
    ):
        ax.scatter(target, pred, s=3, alpha=0.4, c="C1", edgecolors="none")
        # Identity line (perfect prediction), padded by 5% of the data range.
        lim_min = min(target.min(), pred.min())
        lim_max = max(target.max(), pred.max())
        margin = (lim_max - lim_min) * 0.05
        ax.plot([lim_min - margin, lim_max + margin],
                [lim_min - margin, lim_max + margin],
                "r--", alpha=0.5, linewidth=1)
        ax.set_xlabel(f"GT {label} (m/s)")
        ax.set_ylabel(f"Pred {label} (m/s)")
        ax.set_aspect("equal")
        ax.grid(True, alpha=0.3)
        if metrics:
            ax.set_title(f"{label} — RMSE: {metrics[f'rmse_{label}']:.4f}")

    fig.suptitle(f"Scatter — {scene_name}")
    try:
        plt.tight_layout()
        plt.savefig(save_dir / f"{scene_name}_scatter.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        print(f" [WARN] Failed to save {scene_name}_scatter.png: {e}")
    plt.close(fig)


def plot_global_comparison(
    results: Dict,
    save_dir: Path,
):
    """
    Generate a summary figure comparing all scenes.

    Draws grouped bars of RMSE vx / vy / xy per scene into
    ``per_scene_rmse.png`` in *save_dir*. Scenes without metrics are drawn
    as 0. Does nothing when ``results["per_scene"]`` is empty.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    scenes = list(results["per_scene"].keys())
    if not scenes:
        return

    # Bar chart: RMSE vx and vy per scene
    rmse_vx = [results["per_scene"][s]["metrics"].get("rmse_vx", 0) for s in scenes]
    rmse_vy = [results["per_scene"][s]["metrics"].get("rmse_vy", 0) for s in scenes]
    rmse_xy = [results["per_scene"][s]["metrics"].get("rmse_xy", 0) for s in scenes]

    x = np.arange(len(scenes))
    width = 0.25

    fig, ax = plt.subplots(figsize=(10, 4.5))
    bars1 = ax.bar(x - width, rmse_vx, width, label="RMSE vx", alpha=0.8)
    bars2 = ax.bar(x, rmse_vy, width, label="RMSE vy", alpha=0.8)
    bars3 = ax.bar(x + width, rmse_xy, width, label="RMSE xy", alpha=0.8)

    ax.set_xticks(x)
    ax.set_xticklabels(scenes, rotation=15, ha="right")
    ax.set_ylabel("RMSE (m/s)")
    ax.set_title("Per-Scene RMSE Comparison")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis="y")

    # Annotate values
    for bars in [bars1, bars2, bars3]:
        for bar in bars:
            height = bar.get_height()
            ax.annotate(f"{height:.3f}",
                        xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 2), textcoords="offset points",
                        ha="center", va="bottom", fontsize=7)

    try:
        plt.tight_layout()
        plt.savefig(save_dir / "per_scene_rmse.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        print(f" [WARN] Failed to save per_scene_rmse.png: {e}")
    plt.close(fig)


# ──────────────────────────── Results serialization ────────────────────────────

def _fmt_metric(metrics: Dict, key: str, spec: str = ".4f") -> str:
    """
    Format ``metrics[key]`` with *spec*, or return "N/A" if the key is missing.

    Formatting a ``metrics.get(key, 'N/A')`` fallback directly with ``:.4f``
    raises ValueError, so the fallback must be handled before formatting.
    """
    value = metrics.get(key)
    return "N/A" if value is None else format(value, spec)


def save_results(
    results: Dict,
    save_dir: Path,
    checkpoint_name: str = "model",
):
    """
    Save all evaluation results to disk.

    Writes into *save_dir*:
        summary.txt            human-readable report (also printed to stdout)
        per_scene_metrics.csv  one row per scene
        metrics.csv            single row of global metrics
        plots/                 per-scene and summary figures

    Text files are written as UTF-8 because the report contains non-ASCII
    characters ("──", "R²").
    """
    save_dir = Path(save_dir)
    plots_dir = save_dir / "plots"
    save_dir.mkdir(parents=True, exist_ok=True)
    plots_dir.mkdir(parents=True, exist_ok=True)

    # ── 1. Global metrics ──
    global_m = results["global"]
    lines = [
        f"Benchmark Results: {checkpoint_name}",
        f"{'=' * 50}",
        f"Total samples: {global_m.get('count', '?')}",
        "",
        "── Global Metrics ──",
        f" RMSE vx: {_fmt_metric(global_m, 'rmse_vx')} m/s",
        f" RMSE vy: {_fmt_metric(global_m, 'rmse_vy')} m/s",
        f" RMSE xy: {_fmt_metric(global_m, 'rmse_xy')} m/s",
        f" MAE vx: {_fmt_metric(global_m, 'mae_vx')} m/s",
        f" MAE vy: {_fmt_metric(global_m, 'mae_vy')} m/s",
        f" MAE xy: {_fmt_metric(global_m, 'mae_xy')} m/s",
        f" R² vx: {_fmt_metric(global_m, 'r2_vx')}",
        f" R² vy: {_fmt_metric(global_m, 'r2_vy')}",
        "",
    ]

    # ── 2. Per-scene metrics ──
    lines.append("── Per-Scene Metrics ──")
    lines.append(f" {'Scene':<22} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} "
                 f"{'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8} {'Samples':>8}")
    lines.append(" " + "-" * 96)
    for scene_name, scene_result in results["per_scene"].items():
        m = scene_result["metrics"]
        lines.append(
            f" {scene_name:<22} {m.get('rmse_vx', 0):>10.4f} {m.get('rmse_vy', 0):>10.4f} "
            f"{m.get('rmse_xy', 0):>10.4f} {m.get('mae_vx', 0):>10.4f} "
            f"{m.get('mae_vy', 0):>10.4f} {m.get('r2_vx', 0):>8.4f} "
            f"{m.get('r2_vy', 0):>8.4f} {m.get('count', 0):>8}"
        )
    lines.append("")

    # ── 3. By difficulty ──
    lines.append("── By Difficulty ──")
    for diff, metrics in results["by_difficulty"].items():
        lines.append(f" {diff:<10} RMSE vx={metrics.get('rmse_vx', 0):.4f} "
                     f"RMSE vy={metrics.get('rmse_vy', 0):.4f} "
                     f"RMSE xy={metrics.get('rmse_xy', 0):.4f} "
                     f"(samples={metrics.get('count', 0)})")

    summary_text = "\n".join(lines)
    with open(save_dir / "summary.txt", "w", encoding="utf-8") as f:
        f.write(summary_text)
    print(summary_text)

    # ── 4. CSV: per-scene metrics ──
    csv_path = save_dir / "per_scene_metrics.csv"
    fieldnames = ["scene", "rmse_vx", "rmse_vy", "rmse_xy",
                  "mae_vx", "mae_vy", "mae_xy", "r2_vx", "r2_vy", "count"]
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for scene_name, scene_result in results["per_scene"].items():
            # Scenes with no samples have empty metrics; DictWriter leaves
            # those columns blank.
            row = {"scene": scene_name, **scene_result["metrics"]}
            writer.writerow(row)
    print(f"Per-scene CSV: {csv_path}")

    # ── 5. Global metrics CSV (single row) ──
    csv_path_global = save_dir / "metrics.csv"
    with open(csv_path_global, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["checkpoint"] + list(global_m.keys()))
        writer.writeheader()
        row = {"checkpoint": checkpoint_name, **global_m}
        writer.writerow(row)
    print(f"Global metrics CSV: {csv_path_global}")

    # ── 6. Plots ──
    for scene_name, scene_result in results["per_scene"].items():
        if scene_result["preds"].shape[0] > 0:
            plot_scene_comparison(
                scene_result["preds"],
                scene_result["targets"],
                scene_name,
                plots_dir,
                metrics=scene_result["metrics"],
            )
    plot_global_comparison(results, plots_dir)

    import matplotlib.pyplot as plt
    plt.close("all")  # safety net: release any figures still open

    print(f"\nAll results saved to: {save_dir.resolve()}")