459 lines
15 KiB
Python
459 lines
15 KiB
Python
"""
|
|
Core evaluation logic: run model on one or more scenes, compute metrics,
|
|
generate visualizations, and save structured results.
|
|
|
|
This module is called by benchmark.py (the user-facing entry point).
|
|
"""
|
|
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import List, Optional, Dict, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from src.velocity_prediction.model import VelocityPredictionModel
|
|
from src.velocity_prediction.dataset import create_val_loader
|
|
from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD
|
|
|
|
from benchmark.config import (
|
|
eval_cfg,
|
|
TEST_SCENE_GROUPS,
|
|
VAL_SCENE_GROUPS,
|
|
DIFFICULTY_GROUPS,
|
|
DATASET_ROOT,
|
|
)
|
|
|
|
|
|
# ──────────────────────────── Metrics ────────────────────────────
|
|
|
|
|
|
def compute_metrics(
    pred: np.ndarray,
    target: np.ndarray,
) -> Dict[str, float]:
    """
    Compute RMSE, MAE and R² for each velocity axis, plus 2-D (xy) RMSE/MAE.

    Args:
        pred: (N, 2) denormalized predictions; column 0 is vx, column 1 is vy.
        target: (N, 2) denormalized ground truth, same shape as ``pred``.

    Returns:
        dict with keys rmse_vx, rmse_vy, rmse_xy, mae_vx, mae_vy, mae_xy,
        r2_vx, r2_vy and count (the number of rows N).

        rmse_xy is the square root of the mean per-sample squared Euclidean
        error; mae_xy is the mean per-sample Euclidean error. An axis whose
        ground truth is (numerically) constant gets R² = 0.0, since R² is
        undefined there.

    Raises:
        ValueError: if ``pred`` and ``target`` are not 2-D arrays of the same
            shape. Without this check NumPy broadcasting would silently
            produce meaningless numbers, e.g. for (1, 2) vs (N, 2).
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.ndim != 2 or pred.shape != target.shape:
        raise ValueError(
            "pred and target must be 2-D arrays of equal shape, "
            f"got {pred.shape} and {target.shape}"
        )

    # Signed error computed once and reused by every metric below.
    err = pred - target
    sq_err = err ** 2
    abs_err = np.abs(err)
    eucl_sq = np.sum(sq_err, axis=1)  # per-sample squared Euclidean error

    def r2(p, t):
        """Coefficient of determination; 0.0 when ``t`` has ~zero variance."""
        ss_res = np.sum((t - p) ** 2)
        ss_tot = np.sum((t - np.mean(t)) ** 2)
        return float(1 - ss_res / ss_tot) if ss_tot > 1e-12 else 0.0

    return {
        "rmse_vx": float(np.sqrt(np.mean(sq_err[:, 0]))),
        "rmse_vy": float(np.sqrt(np.mean(sq_err[:, 1]))),
        "rmse_xy": float(np.sqrt(np.mean(eucl_sq))),
        "mae_vx": float(np.mean(abs_err[:, 0])),
        "mae_vy": float(np.mean(abs_err[:, 1])),
        "mae_xy": float(np.mean(np.sqrt(eucl_sq))),
        "r2_vx": r2(pred[:, 0], target[:, 0]),
        "r2_vy": r2(pred[:, 1], target[:, 1]),
        "count": len(pred),
    }
|
|
|
|
|
|
# ──────────────────────────── Per-scene evaluation ────────────────────────────
|
|
|
|
|
|
@torch.no_grad()
def evaluate_scene(
    model: nn.Module,
    scene_names: List[str],
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
) -> Dict:
    """
    Evaluate model on one or more scenes.

    Builds a validation loader for ``scene_names`` via create_val_loader and
    runs ``model(events, tilt)`` on every batch. Each prediction is compared
    with the LAST timestep of the batch's normalized ``v_body_target``
    sequence. Both arrays are then denormalized with VELOCITY_MEAN /
    VELOCITY_STD before the metrics are computed.

    The model is put in eval mode for the pass. Every submodule's previous
    train/eval flag is restored afterwards, even if an exception is raised.

    Args:
        model: network called as ``model(events, tilt)``.
        scene_names: scenes to include in the loader.
        device: device that ``events`` and ``tilt`` are moved to.
        seq_len, batch_size, num_workers, event_threshold, event_use_log:
            forwarded unchanged to create_val_loader.

    Returns:
        dict with keys:
            preds:   (N, 2) denormalized predictions
            targets: (N, 2) denormalized ground truth
            metrics: dict of scalar metrics (see compute_metrics); an empty
                     dict when the loader yields no batches.
    """
    loader = create_val_loader(
        scene_names=scene_names,
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=num_workers,
        event_threshold=event_threshold,
        event_use_log=event_use_log,
    )

    # Remember each submodule's own mode so mixed setups (e.g. a frozen,
    # eval-mode backbone inside a training model) are restored exactly.
    prev_modes = [(module, module.training) for module in model.modules()]
    model.eval()

    pred_chunks: List[np.ndarray] = []
    target_chunks: List[np.ndarray] = []
    try:
        for batch in loader:
            events = batch["events"].to(device)
            tilt = batch["tilt"].to(device)
            pred = model(events, tilt)  # (B, 2) normalized

            # Only the final timestep of the (B, S, 2) normalized target is
            # scored. The target is never fed to the model, so slice it where
            # it already lives instead of copying it to `device` and back.
            target_last = batch["v_body_target"][:, -1, :]

            pred_chunks.append(pred.cpu().numpy())
            target_chunks.append(target_last.cpu().numpy())
    finally:
        for module, was_training in prev_modes:
            module.training = was_training

    if not pred_chunks:
        return {"preds": np.zeros((0, 2)), "targets": np.zeros((0, 2)), "metrics": {}}

    preds = np.concatenate(pred_chunks, axis=0)
    targets = np.concatenate(target_chunks, axis=0)

    # Denormalize back to physical units before computing metrics.
    mean = np.array(VELOCITY_MEAN, dtype=np.float32)
    std = np.array(VELOCITY_STD, dtype=np.float32)
    preds_denorm = preds * std + mean
    targets_denorm = targets * std + mean

    return {
        "preds": preds_denorm,
        "targets": targets_denorm,
        "metrics": compute_metrics(preds_denorm, targets_denorm),
    }
|
|
|
|
|
|
# ──────────────────────────── Full evaluation suite ────────────────────────────
|
|
|
|
|
|
def _pooled_metrics(scene_results) -> Dict:
    """Concatenate the non-empty per-scene results and compute metrics on the pool; {} if none have samples."""
    non_empty = [r for r in scene_results if r["preds"].shape[0] > 0]
    if not non_empty:
        return {}
    return compute_metrics(
        np.concatenate([r["preds"] for r in non_empty], axis=0),
        np.concatenate([r["targets"] for r in non_empty], axis=0),
    )


def run_full_evaluation(
    model: nn.Module,
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
    scene_groups=None,
) -> Dict:
    """
    Run evaluation on all scene groups.

    Each scene named in ``scene_groups`` (default: TEST_SCENE_GROUPS; every
    group must expose a ``.scenes`` iterable of scene names) is evaluated with
    evaluate_scene. A scene listed in several groups is evaluated only the
    first time, so it is not double-counted in the global metrics. The
    remaining keyword arguments are forwarded unchanged to evaluate_scene.

    Returns nested dict:
        {
            "global": { metrics... },        # all scenes pooled; {} if no samples
            "per_scene": {
                "indoor_forward_7": {"preds": ..., "targets": ..., "metrics": {...}},
                ...
            },
            "by_difficulty": {               # one entry per DIFFICULTY_GROUPS key
                "easy": { metrics... },
                "hard": { metrics... },
            }
        }

    A difficulty bucket pools only those of its scenes that were evaluated
    here and produced samples. A bucket with no such scenes maps to {}.
    """
    if scene_groups is None:
        scene_groups = TEST_SCENE_GROUPS

    per_scene: Dict[str, Dict] = {}
    for group in scene_groups:
        for scene_name in group.scenes:
            if scene_name in per_scene:
                # Already evaluated via an earlier group; re-running would
                # double-count this scene in the pooled global metrics.
                continue
            per_scene[scene_name] = evaluate_scene(
                model, [scene_name], device,
                seq_len=seq_len, batch_size=batch_size,
                num_workers=num_workers,
                event_threshold=event_threshold,
                event_use_log=event_use_log,
            )

    # Global metrics: all evaluated scenes pooled, in evaluation order.
    global_metrics = _pooled_metrics(per_scene.values())

    by_difficulty = {
        diff: _pooled_metrics(per_scene[s] for s in scenes if s in per_scene)
        for diff, scenes in DIFFICULTY_GROUPS.items()
    }

    return {
        "global": global_metrics,
        "per_scene": per_scene,
        "by_difficulty": by_difficulty,
    }
|
|
|
|
|
|
# ──────────────────────────── Visualization ────────────────────────────
|
|
|
|
|
|
def plot_scene_comparison(
    preds: np.ndarray,
    targets: np.ndarray,
    scene_name: str,
    save_dir: Path,
    metrics: Optional[Dict] = None,
):
    """Generate time-series and scatter plots for a single scene.

    Writes two images into ``save_dir``, creating the directory if needed:

    * ``{scene_name}_timeseries.png``: ground truth vs. prediction for vx
      and vy over the frame index.
    * ``{scene_name}_scatter.png``: prediction vs. ground truth per axis,
      with a dashed y = x reference line.

    Args:
        preds: (N, 2) denormalized predictions, columns vx, vy.
        targets: (N, 2) denormalized ground truth, same layout.
        scene_name: used in figure titles and output file names.
        save_dir: output directory.
        metrics: optional metrics dict (as from compute_metrics). The
            rmse_vx / rmse_vy entries are shown in titles when present.

    Saving is best-effort: a failed write is printed as a warning rather
    than raised. With N == 0 nothing is drawn, though the directory is
    still created.
    """
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    if len(preds) == 0:
        # min()/max() for the scatter axis limits are undefined on empty arrays.
        return

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    metrics = metrics or {}

    def _save(fig, filename: str) -> None:
        """Lay out and write ``fig``; always close it so figures never leak."""
        try:
            fig.tight_layout()
            fig.savefig(save_dir / filename, dpi=150, bbox_inches="tight")
        except Exception as e:  # best-effort: one bad plot must not abort the run
            print(f" [WARN] Failed to save {filename}: {e}")
        finally:
            plt.close(fig)

    axis_labels = ("vx", "vy")
    frames = np.arange(len(preds))

    # ── Time-series plot ──
    fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)
    for col, (ax, label) in enumerate(zip(axes, axis_labels)):
        ax.plot(frames, targets[:, col], label=f"GT {label}", color="C0", alpha=0.7, linewidth=0.8)
        ax.plot(frames, preds[:, col], label=f"Pred {label}", color="C1", alpha=0.7, linewidth=0.8)
        ax.set_ylabel(f"{label} (m/s)")
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
    axes[-1].set_xlabel("Frame index")

    title = f"Body-frame Velocity — {scene_name}"
    if "rmse_vx" in metrics and "rmse_vy" in metrics:
        title += f" | RMSE vx={metrics['rmse_vx']:.3f} vy={metrics['rmse_vy']:.3f}"
    fig.suptitle(title)
    _save(fig, f"{scene_name}_timeseries.png")

    # ── Scatter plot ──
    fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
    for col, (ax, label) in enumerate(zip(axes, axis_labels)):
        gt = targets[:, col]
        pr = preds[:, col]
        ax.scatter(gt, pr, s=3, alpha=0.4, c="C1", edgecolors="none")

        # y = x reference line spanning both series plus a 5% margin.
        lo = min(gt.min(), pr.min())
        hi = max(gt.max(), pr.max())
        margin = (hi - lo) * 0.05
        ax.plot([lo - margin, hi + margin],
                [lo - margin, hi + margin], "r--", alpha=0.5, linewidth=1)

        ax.set_xlabel(f"GT {label} (m/s)")
        ax.set_ylabel(f"Pred {label} (m/s)")
        ax.set_aspect("equal")
        ax.grid(True, alpha=0.3)
        rmse_key = f"rmse_{label}"
        if rmse_key in metrics:
            ax.set_title(f"{label} — RMSE: {metrics[rmse_key]:.4f}")

    fig.suptitle(f"Scatter — {scene_name}")
    _save(fig, f"{scene_name}_scatter.png")
|
|
|
|
|
|
def plot_global_comparison(
    results: Dict,
    save_dir: Path,
):
    """Draw a grouped bar chart of per-scene RMSE and save it as per_scene_rmse.png.

    For every scene in ``results["per_scene"]`` three bars are drawn (RMSE
    vx, RMSE vy, RMSE xy), each annotated with its value. A scene whose
    metrics dict lacks a key is drawn at height 0. ``save_dir`` is created
    if missing. With no scenes, no figure is produced. A failed save is
    printed as a warning instead of raised.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    per_scene = results["per_scene"]
    scene_names = list(per_scene.keys())
    if not scene_names:
        return

    # (metrics key, legend label, bar offset in units of bar width)
    series_spec = (
        ("rmse_vx", "RMSE vx", -1),
        ("rmse_vy", "RMSE vy", 0),
        ("rmse_xy", "RMSE xy", 1),
    )
    heights_by_key = {
        key: [per_scene[name]["metrics"].get(key, 0) for name in scene_names]
        for key, _legend, _offset in series_spec
    }

    bar_width = 0.25
    centers = np.arange(len(scene_names))

    fig, ax = plt.subplots(figsize=(10, 4.5))
    bar_groups = [
        ax.bar(centers + offset * bar_width, heights_by_key[key], bar_width,
               label=legend, alpha=0.8)
        for key, legend, offset in series_spec
    ]

    ax.set_xticks(centers)
    ax.set_xticklabels(scene_names, rotation=15, ha="right")
    ax.set_ylabel("RMSE (m/s)")
    ax.set_title("Per-Scene RMSE Comparison")
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3, axis="y")

    # Print each bar's value just above its top edge.
    for group in bar_groups:
        for rect in group:
            top = rect.get_height()
            ax.annotate(
                f"{top:.3f}",
                xy=(rect.get_x() + rect.get_width() / 2, top),
                xytext=(0, 2),
                textcoords="offset points",
                ha="center",
                va="bottom",
                fontsize=7,
            )

    try:
        fig.tight_layout()
        fig.savefig(out_dir / "per_scene_rmse.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        print(f" [WARN] Failed to save per_scene_rmse.png: {e}")
    plt.close(fig)
|
|
|
|
|
|
# ──────────────────────────── Results serialization ────────────────────────────
|
|
|
|
|
|
def _fmt_metric(value, spec: str = ".4f") -> str:
    """Format a metric with ``spec``; return "N/A" if it is missing or not numeric."""
    if value is None:
        return "N/A"
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return "N/A"


def _format_summary(results: Dict, checkpoint_name: str) -> str:
    """Render the human-readable report (global, per-scene, by-difficulty) as one string."""
    global_m = results["global"]

    # ── 1. Global metrics ──
    lines = [
        f"Benchmark Results: {checkpoint_name}",
        f"{'=' * 50}",
        f"Total samples: {global_m.get('count', '?')}",
        "",
        "── Global Metrics ──",
    ]
    for label, key, unit in (
        ("RMSE vx", "rmse_vx", " m/s"),
        ("RMSE vy", "rmse_vy", " m/s"),
        ("RMSE xy", "rmse_xy", " m/s"),
        ("MAE vx", "mae_vx", " m/s"),
        ("MAE vy", "mae_vy", " m/s"),
        ("MAE xy", "mae_xy", " m/s"),
        ("R² vx", "r2_vx", ""),
        ("R² vy", "r2_vy", ""),
    ):
        text = _fmt_metric(global_m.get(key))
        # A unit after "N/A" would be noise; only attach it to real numbers.
        lines.append(f" {label}: {text}{unit if text != 'N/A' else ''}")
    lines.append("")

    # ── 2. Per-scene metrics ──
    header = (f" {'Scene':<22} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} "
              f"{'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8} {'Samples':>8}")
    lines.append("── Per-Scene Metrics ──")
    lines.append(header)
    # Separator spans exactly the header's columns.
    lines.append(" " + "-" * (len(header) - 1))

    for scene_name, scene_result in results["per_scene"].items():
        m = scene_result["metrics"]
        lines.append(
            f" {scene_name:<22} {_fmt_metric(m.get('rmse_vx')):>10} "
            f"{_fmt_metric(m.get('rmse_vy')):>10} "
            f"{_fmt_metric(m.get('rmse_xy')):>10} {_fmt_metric(m.get('mae_vx')):>10} "
            f"{_fmt_metric(m.get('mae_vy')):>10} {_fmt_metric(m.get('r2_vx')):>8} "
            f"{_fmt_metric(m.get('r2_vy')):>8} {m.get('count', 0):>8}"
        )
    lines.append("")

    # ── 3. By difficulty ──
    lines.append("── By Difficulty ──")
    for diff, metrics in results["by_difficulty"].items():
        lines.append(f" {diff:<10} RMSE vx={_fmt_metric(metrics.get('rmse_vx'))} "
                     f"RMSE vy={_fmt_metric(metrics.get('rmse_vy'))} "
                     f"RMSE xy={_fmt_metric(metrics.get('rmse_xy'))} "
                     f"(samples={metrics.get('count', 0)})")

    return "\n".join(lines)


def save_results(
    results: Dict,
    save_dir: Path,
    checkpoint_name: str = "model",
):
    """Save all evaluation results to disk.

    Expects ``results`` with "global", "per_scene" and "by_difficulty" keys,
    in the layout produced by run_full_evaluation. Writes into ``save_dir``:

    * summary.txt            - human-readable report (also printed to stdout)
    * per_scene_metrics.csv  - one row of metrics per scene
    * metrics.csv            - one row with the global metrics
    * plots/                 - per-scene and cross-scene figures

    Missing metrics, e.g. for a scene or difficulty bucket with no samples,
    appear as "N/A" in the summary. They neither crash the formatter nor
    masquerade as a perfect 0.0000 error.
    """
    import csv

    save_dir = Path(save_dir)
    plots_dir = save_dir / "plots"
    save_dir.mkdir(parents=True, exist_ok=True)
    plots_dir.mkdir(parents=True, exist_ok=True)

    global_m = results["global"]

    # ── Summary text ──
    summary_text = _format_summary(results, checkpoint_name)
    # Explicit UTF-8: the report contains box-drawing characters such as "─",
    # which some platform default encodings (e.g. cp1252) cannot encode.
    with open(save_dir / "summary.txt", "w", encoding="utf-8") as f:
        f.write(summary_text)
    print(summary_text)

    # ── CSV: per-scene metrics ──
    csv_path = save_dir / "per_scene_metrics.csv"
    fieldnames = ["scene", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy",
                  "mae_xy", "r2_vx", "r2_vy", "count"]
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for scene_name, scene_result in results["per_scene"].items():
            # Scenes with empty metrics get blank cells (DictWriter restval).
            writer.writerow({"scene": scene_name, **scene_result["metrics"]})
    print(f"Per-scene CSV: {csv_path}")

    # ── CSV: global metrics (single row) ──
    csv_path_global = save_dir / "metrics.csv"
    with open(csv_path_global, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["checkpoint"] + list(global_m.keys()))
        writer.writeheader()
        writer.writerow({"checkpoint": checkpoint_name, **global_m})
    print(f"Global metrics CSV: {csv_path_global}")

    # ── Plots ──
    for scene_name, scene_result in results["per_scene"].items():
        if scene_result["preds"].shape[0] > 0:
            plot_scene_comparison(
                scene_result["preds"],
                scene_result["targets"],
                scene_name,
                plots_dir,
                metrics=scene_result["metrics"],
            )

    plot_global_comparison(results, plots_dir)

    import matplotlib.pyplot as plt
    plt.close("all")

    print(f"\nAll results saved to: {save_dir.resolve()}")
|