Files
uzh-fpv-sv-test/benchmark/evaluate.py
2026-05-29 18:49:01 +08:00

459 lines
15 KiB
Python

"""
Core evaluation logic: run model on one or more scenes, compute metrics,
generate visualizations, and save structured results.
This module is called by benchmark.py (the user-facing entry point).
"""
import numpy as np
from pathlib import Path
from typing import List, Optional, Dict, Tuple
import torch
import torch.nn as nn
from src.velocity_prediction.model import VelocityPredictionModel
from src.velocity_prediction.dataset import create_val_loader
from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD
from benchmark.config import (
eval_cfg,
TEST_SCENE_GROUPS,
VAL_SCENE_GROUPS,
DIFFICULTY_GROUPS,
DATASET_ROOT,
)
# ──────────────────────────── Metrics ────────────────────────────
def _r2_score(p: np.ndarray, t: np.ndarray) -> float:
    """Coefficient of determination of ``p`` against ``t``; 0.0 when ``t`` is (near-)constant."""
    total = np.sum((t - np.mean(t)) ** 2)
    # R² is undefined for a zero-variance target; report 0.0 instead of dividing by ~0.
    if not total > 1e-12:
        return 0.0
    residual = np.sum((t - p) ** 2)
    return float(1 - residual / total)


def compute_metrics(
    pred: np.ndarray,
    target: np.ndarray,
) -> Dict[str, float]:
    """
    Compute RMSE, MAE and R² per velocity axis, plus planar (xy) error metrics.

    Args:
        pred: (N, 2) denormalized predictions, columns (vx, vy)
        target: (N, 2) denormalized ground truth, same layout as ``pred``
    Returns:
        dict with keys, in this order: rmse_vx, rmse_vy, rmse_xy, mae_vx,
        mae_vy, mae_xy, r2_vx, r2_vy, count. The *_xy entries are based on the
        per-sample Euclidean error norm; ``count`` is the number of samples.
        The key order is significant: callers use it as a CSV header.
    """
    err_x = pred[:, 0] - target[:, 0]
    err_y = pred[:, 1] - target[:, 1]
    # Squared Euclidean norm of the 2-D error, one value per sample.
    sq_dist = np.sum((pred - target) ** 2, axis=1)

    return {
        "rmse_vx": float(np.sqrt(np.mean(err_x ** 2))),
        "rmse_vy": float(np.sqrt(np.mean(err_y ** 2))),
        "rmse_xy": float(np.sqrt(np.mean(sq_dist))),
        "mae_vx": float(np.mean(np.abs(err_x))),
        "mae_vy": float(np.mean(np.abs(err_y))),
        "mae_xy": float(np.mean(np.sqrt(sq_dist))),
        "r2_vx": _r2_score(pred[:, 0], target[:, 0]),
        "r2_vy": _r2_score(pred[:, 1], target[:, 1]),
        "count": len(pred),
    }
# ──────────────────────────── Per-scene evaluation ────────────────────────────
@torch.no_grad()
def evaluate_scene(
    model: nn.Module,
    scene_names: List[str],
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
) -> Dict:
    """
    Evaluate model on one or more scenes.

    Builds a validation loader for ``scene_names`` with ``create_val_loader``,
    switches ``model`` to eval mode (it is left in eval mode afterwards) and,
    without gradient tracking, compares each prediction with the LAST timestep
    of the batch's ``v_body_target`` sequence. Predictions and targets are
    denormalized with ``VELOCITY_MEAN`` / ``VELOCITY_STD`` before the metrics
    are computed.

    Args:
        model: called as ``model(events, tilt)``; returns normalized (B, 2) velocities
        scene_names: scenes pooled into a single evaluation
        device: device the model inputs are moved to
        seq_len, batch_size, num_workers, event_threshold, event_use_log:
            forwarded unchanged to ``create_val_loader``
    Returns:
        dict with keys:
            preds: (N, 2) denormalized predictions
            targets: (N, 2) denormalized ground truth
            metrics: dict of scalar metrics ({} when the loader yields no batches)
    """
    loader = create_val_loader(
        scene_names=scene_names,
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=num_workers,
        event_threshold=event_threshold,
        event_use_log=event_use_log,
    )
    model.eval()
    all_preds = []
    all_targets = []
    for batch in loader:
        events = batch["events"].to(device)
        tilt = batch["tilt"].to(device)
        pred = model(events, tilt)  # (B, 2) normalized
        # Only the last timestep of the (B, S, 2) normalized target is scored.
        # Slice it where it already lives instead of copying the whole sequence
        # to the device and straight back to the host.
        target_last = batch["v_body_target"][:, -1, :]  # (B, 2) normalized
        all_preds.append(pred.cpu().numpy())
        all_targets.append(target_last.cpu().numpy())
    if not all_preds:
        return {"preds": np.zeros((0, 2)), "targets": np.zeros((0, 2)), "metrics": {}}
    preds = np.concatenate(all_preds, axis=0)
    targets = np.concatenate(all_targets, axis=0)
    # Undo the dataset normalization so metrics are in physical units.
    mean = np.array(VELOCITY_MEAN, dtype=np.float32)
    std = np.array(VELOCITY_STD, dtype=np.float32)
    preds_denorm = preds * std + mean
    targets_denorm = targets * std + mean
    metrics = compute_metrics(preds_denorm, targets_denorm)
    return {
        "preds": preds_denorm,
        "targets": targets_denorm,
        "metrics": metrics,
    }
# ──────────────────────────── Full evaluation suite ────────────────────────────
def _pooled_metrics(per_scene: Dict[str, Dict], scene_names) -> Dict[str, float]:
    """Compute metrics over the concatenated samples of ``scene_names`` ({} if there are none).

    Scenes absent from ``per_scene`` or with zero samples are skipped, and a
    scene listed more than once contributes its samples only once.
    """
    preds, targets = [], []
    # dict.fromkeys de-duplicates while keeping the first-seen order.
    for name in dict.fromkeys(scene_names):
        result = per_scene.get(name)
        if result is None or result["preds"].shape[0] == 0:
            continue
        preds.append(result["preds"])
        targets.append(result["targets"])
    if not preds:
        return {}
    return compute_metrics(
        np.concatenate(preds, axis=0),
        np.concatenate(targets, axis=0),
    )


def run_full_evaluation(
    model: nn.Module,
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
    scene_groups=None,
) -> Dict:
    """
    Run evaluation on all scene groups.

    Every scene in ``scene_groups`` (objects exposing a ``.scenes`` iterable;
    defaults to ``benchmark.config.TEST_SCENE_GROUPS``) is evaluated separately
    with ``evaluate_scene``. A scene listed in several groups is evaluated only
    once, so its samples are not double-counted. Global and per-difficulty
    metrics pool the raw samples of the relevant scenes. They are not averages
    of per-scene metrics.

    Returns nested dict:
        {
            "global": { metrics... },        # {} if no scene produced samples
            "per_scene": {
                "indoor_forward_7": {"preds": ..., "targets": ..., "metrics": {...}},
                ...
            },
            "by_difficulty": {
                <difficulty>: { metrics... },  # one entry per DIFFICULTY_GROUPS key,
                ...                            # {} if none of its scenes has samples
            },
        }
    """
    if scene_groups is None:
        # Resolved from benchmark.config at call time rather than via the
        # module-level name bound at import.
        from benchmark.config import TEST_SCENE_GROUPS
        scene_groups = TEST_SCENE_GROUPS

    per_scene: Dict[str, Dict] = {}
    for group in scene_groups:
        for scene_name in group.scenes:
            if scene_name in per_scene:
                # Already evaluated via an earlier group. Evaluating it again would
                # waste time and count its samples twice in the global metrics.
                continue
            per_scene[scene_name] = evaluate_scene(
                model, [scene_name], device,
                seq_len=seq_len, batch_size=batch_size,
                num_workers=num_workers,
                event_threshold=event_threshold,
                event_use_log=event_use_log,
            )

    # Global metrics: all evaluated scenes combined.
    global_metrics = _pooled_metrics(per_scene, list(per_scene))

    # By difficulty: only scenes that were actually evaluated contribute.
    by_difficulty = {
        diff: _pooled_metrics(per_scene, scenes)
        for diff, scenes in DIFFICULTY_GROUPS.items()
    }

    return {
        "global": global_metrics,
        "per_scene": per_scene,
        "by_difficulty": by_difficulty,
    }
# ──────────────────────────── Visualization ────────────────────────────
def plot_scene_comparison(
    preds: np.ndarray,
    targets: np.ndarray,
    scene_name: str,
    save_dir: Path,
    metrics: Optional[Dict] = None,
):
    """
    Generate time-series and scatter plots for a single scene.

    Writes ``{scene_name}_timeseries.png`` (GT vs. predicted vx/vy per frame)
    and ``{scene_name}_scatter.png`` (predicted vs. GT with an identity line)
    into ``save_dir``, creating the directory if needed. A failed save is
    printed as a warning instead of raised. Nothing is written for empty input.

    Args:
        preds: (N, 2) predictions (vx, vy) in m/s
        targets: (N, 2) ground truth, same layout as ``preds``
        scene_name: used in figure titles and output file names
        save_dir: output directory
        metrics: optional metrics dict. When non-empty, its ``rmse_vx`` and
            ``rmse_vy`` entries are shown in the titles.
    """
    # The scatter axis limits use min()/max(), which raise on empty arrays.
    # Bail out before producing a half-finished set of plots.
    if len(preds) == 0 or len(targets) == 0:
        return

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    time = np.arange(len(preds))

    # ── Time-series plot ──
    fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)
    axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.7, linewidth=0.8)
    axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.7, linewidth=0.8)
    axes[0].set_ylabel("vx (m/s)")
    axes[0].legend(fontsize=9)
    axes[0].grid(True, alpha=0.3)
    axes[1].plot(time, targets[:, 1], label="GT vy", color="C0", alpha=0.7, linewidth=0.8)
    axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.7, linewidth=0.8)
    axes[1].set_ylabel("vy (m/s)")
    axes[1].set_xlabel("Frame index")
    axes[1].legend(fontsize=9)
    axes[1].grid(True, alpha=0.3)
    title = f"Body-frame Velocity — {scene_name}"
    if metrics:
        title += f" | RMSE vx={metrics['rmse_vx']:.3f} vy={metrics['rmse_vy']:.3f}"
    fig.suptitle(title)
    try:
        plt.tight_layout()
        plt.savefig(save_dir / f"{scene_name}_timeseries.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        # Best effort: a failed plot must not abort the benchmark run.
        print(f" [WARN] Failed to save {scene_name}_timeseries.png: {e}")
    plt.close(fig)

    # ── Scatter plot ──
    fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
    for ax, pred, target, label in zip(
        axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"]
    ):
        ax.scatter(target, pred, s=3, alpha=0.4, c="C1", edgecolors="none")
        # Identity line spanning the joint data range, padded by 5% on each side.
        lim_min = min(target.min(), pred.min())
        lim_max = max(target.max(), pred.max())
        margin = (lim_max - lim_min) * 0.05
        ax.plot([lim_min - margin, lim_max + margin],
                [lim_min - margin, lim_max + margin], "r--", alpha=0.5, linewidth=1)
        ax.set_xlabel(f"GT {label} (m/s)")
        ax.set_ylabel(f"Pred {label} (m/s)")
        ax.set_aspect("equal")
        ax.grid(True, alpha=0.3)
        if metrics:
            ax.set_title(f"{label} — RMSE: {metrics[f'rmse_{label}']:.4f}")
    fig.suptitle(f"Scatter — {scene_name}")
    try:
        plt.tight_layout()
        plt.savefig(save_dir / f"{scene_name}_scatter.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        print(f" [WARN] Failed to save {scene_name}_scatter.png: {e}")
    plt.close(fig)
def plot_global_comparison(
    results: Dict,
    save_dir: Path,
):
    """
    Draw a grouped bar chart of per-scene RMSE (vx, vy, xy).

    Saves ``per_scene_rmse.png`` into ``save_dir``, creating the directory if
    missing. A scene whose metrics dict lacks a value is drawn at 0. When
    ``results["per_scene"]`` is empty, only the directory is created. A failed
    save is printed as a warning instead of raised.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    per_scene = results["per_scene"]
    scene_names = list(per_scene.keys())
    if not scene_names:
        return

    bar_width = 0.25
    centers = np.arange(len(scene_names))
    # (metric key, legend label, bar offset in units of bar_width)
    series = (
        ("rmse_vx", "RMSE vx", -1),
        ("rmse_vy", "RMSE vy", 0),
        ("rmse_xy", "RMSE xy", 1),
    )

    fig, ax = plt.subplots(figsize=(10, 4.5))
    bar_groups = []
    for key, legend_label, offset in series:
        heights = [per_scene[name]["metrics"].get(key, 0) for name in scene_names]
        bar_groups.append(
            ax.bar(centers + offset * bar_width, heights, bar_width,
                   label=legend_label, alpha=0.8)
        )

    ax.set_xticks(centers)
    ax.set_xticklabels(scene_names, rotation=15, ha="right")
    ax.set_ylabel("RMSE (m/s)")
    ax.set_title("Per-Scene RMSE Comparison")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis="y")

    # Label every bar with its height, just above the bar top.
    for bar in (b for group in bar_groups for b in group):
        top = bar.get_height()
        ax.annotate(f"{top:.3f}",
                    xy=(bar.get_x() + bar.get_width() / 2, top),
                    xytext=(0, 2), textcoords="offset points",
                    ha="center", va="bottom", fontsize=7)

    try:
        plt.tight_layout()
        plt.savefig(out_dir / "per_scene_rmse.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        print(f" [WARN] Failed to save per_scene_rmse.png: {e}")
    plt.close()
# ──────────────────────────── Results serialization ────────────────────────────
def _format_metric(metrics: Dict, key: str) -> str:
    """Return ``metrics[key]`` formatted with 4 decimals, or "N/A" when the key is absent."""
    value = metrics.get(key)
    return "N/A" if value is None else f"{value:.4f}"


def _format_summary(results: Dict, checkpoint_name: str) -> str:
    """Render the plain-text report: global, per-scene and per-difficulty metrics."""
    global_m = results["global"]
    # ── 1. Global metrics ── (global_m is {} when no scene produced samples)
    lines = [
        f"Benchmark Results: {checkpoint_name}",
        f"{'=' * 50}",
        f"Total samples: {global_m.get('count', '?')}",
        "",
        "── Global Metrics ──",
        f" RMSE vx: {_format_metric(global_m, 'rmse_vx')} m/s",
        f" RMSE vy: {_format_metric(global_m, 'rmse_vy')} m/s",
        f" RMSE xy: {_format_metric(global_m, 'rmse_xy')} m/s",
        f" MAE vx: {_format_metric(global_m, 'mae_vx')} m/s",
        f" MAE vy: {_format_metric(global_m, 'mae_vy')} m/s",
        f" MAE xy: {_format_metric(global_m, 'mae_xy')} m/s",
        f" R² vx: {_format_metric(global_m, 'r2_vx')}",
        f" R² vy: {_format_metric(global_m, 'r2_vy')}",
        "",
    ]
    # ── 2. Per-scene metrics ── (scenes without samples are shown as zeros)
    lines.append("── Per-Scene Metrics ──")
    lines.append(f" {'Scene':<22} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} "
                 f"{'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8} {'Samples':>8}")
    lines.append(" " + "-" * 96)
    for scene_name, scene_result in results["per_scene"].items():
        m = scene_result["metrics"]
        lines.append(
            f" {scene_name:<22} {m.get('rmse_vx', 0):>10.4f} {m.get('rmse_vy', 0):>10.4f} "
            f"{m.get('rmse_xy', 0):>10.4f} {m.get('mae_vx', 0):>10.4f} "
            f"{m.get('mae_vy', 0):>10.4f} {m.get('r2_vx', 0):>8.4f} "
            f"{m.get('r2_vy', 0):>8.4f} {m.get('count', 0):>8}"
        )
    lines.append("")
    # ── 3. By difficulty ──
    lines.append("── By Difficulty ──")
    for diff, metrics in results["by_difficulty"].items():
        lines.append(f" {diff:<10} RMSE vx={metrics.get('rmse_vx', 0):.4f} "
                     f"RMSE vy={metrics.get('rmse_vy', 0):.4f} "
                     f"RMSE xy={metrics.get('rmse_xy', 0):.4f} "
                     f"(samples={metrics.get('count', 0)})")
    return "\n".join(lines)


def _write_per_scene_csv(results: Dict, csv_path: Path) -> None:
    """Write one CSV row of metrics per scene. Scenes without metrics get empty cells."""
    import csv
    fieldnames = ["scene", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy",
                  "mae_xy", "r2_vx", "r2_vy", "count"]
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for scene_name, scene_result in results["per_scene"].items():
            writer.writerow({"scene": scene_name, **scene_result["metrics"]})


def _write_global_csv(global_m: Dict, checkpoint_name: str, csv_path: Path) -> None:
    """Write a single-row CSV with the checkpoint name followed by the global metrics."""
    import csv
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["checkpoint"] + list(global_m.keys()))
        writer.writeheader()
        writer.writerow({"checkpoint": checkpoint_name, **global_m})


def _save_plots(results: Dict, plots_dir: Path) -> None:
    """Render per-scene plots (scenes with samples only) and the global RMSE chart."""
    for scene_name, scene_result in results["per_scene"].items():
        if scene_result["preds"].shape[0] > 0:
            plot_scene_comparison(
                scene_result["preds"],
                scene_result["targets"],
                scene_name,
                plots_dir,
                metrics=scene_result["metrics"],
            )
    plot_global_comparison(results, plots_dir)
    import matplotlib.pyplot as plt
    plt.close("all")


def save_results(
    results: Dict,
    save_dir: Path,
    checkpoint_name: str = "model",
):
    """
    Save all evaluation results to disk.

    Writes into ``save_dir``: ``summary.txt`` (also printed to stdout),
    ``per_scene_metrics.csv``, ``metrics.csv`` (global metrics, one row) and
    PNG plots under ``plots/``.

    Args:
        results: dict shaped like the output of ``run_full_evaluation``
            (keys "global", "per_scene", "by_difficulty")
        save_dir: output directory, created if missing
        checkpoint_name: label used in the summary header and in metrics.csv
    """
    save_dir = Path(save_dir)
    plots_dir = save_dir / "plots"
    save_dir.mkdir(parents=True, exist_ok=True)
    plots_dir.mkdir(parents=True, exist_ok=True)

    summary_text = _format_summary(results, checkpoint_name)
    # Explicit UTF-8: the report contains box-drawing characters ("──") that
    # locale default encodings such as cp1252 cannot represent.
    with open(save_dir / "summary.txt", "w", encoding="utf-8") as f:
        f.write(summary_text)
    print(summary_text)

    csv_path = save_dir / "per_scene_metrics.csv"
    _write_per_scene_csv(results, csv_path)
    print(f"Per-scene CSV: {csv_path}")

    csv_path_global = save_dir / "metrics.csv"
    _write_global_csv(results["global"], checkpoint_name, csv_path_global)
    print(f"Global metrics CSV: {csv_path_global}")

    _save_plots(results, plots_dir)
    print(f"\nAll results saved to: {save_dir.resolve()}")