Files
uzh-fpv-sv-test/visualize/visualize_dataset.py
CaoWangrenbo 02d429282e feat: add --show-events overlay with raw log intensity
Visualize raw temporal brightness change (threshold=0, log domain)
as green(+)/red(-) gradient overlay proportional to |change|.
Supports video output and live display modes.
Enables EventProcessor threshold=0 for raw mode without clipping.

Generated by Mistral Vibe.
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
2026-06-08 11:47:19 +08:00

478 lines
16 KiB
Python

"""
Dataset visualization: overlay body-frame pose on images and produce a video.
Usage:
uv run python -m visualize.visualize_dataset \\
--scene indoor_forward_3 \\
--output videos/indoor_forward_3.mp4
# Visualize all scenes
uv run python -m visualize.visualize_dataset --all --output videos/
# Show on screen instead of saving video
uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --show
"""
import argparse
import io
import tarfile
from pathlib import Path
import cv2
import numpy as np
import torch
# Reuse the same coordinate transforms as the training pipeline
from src.velocity_prediction.utils import (
body_up_vector_np,
quat_to_euler_np,
world_vel_to_body_np,
quat_normalize,
quat_mul,
quat_from_matrix,
R_ODOM_TO_BODY_NP,
R_ODOM_TO_BODY,
)
from src.velocity_prediction.config import DATASET_ROOT, VELOCITY_MEAN, VELOCITY_STD
from src.event_utils import EventProcessor
# ──────────────────────────── Data loading ────────────────────────────
def load_scene_frames(scene_dir: Path):
    """
    Load all frames from a scene's shard tar files, in shard order then frame order.

    Each shard is expected to contain members named ``<sample_key>.<ext>``,
    where ``<sample_key>`` ends in ``_<frame_number>`` and every sample has the
    extensions ``jpg``, ``ts``, ``pose`` and ``vel``. Members that are not
    regular files (directories, links, ...) or that have no extension are
    ignored instead of aborting the load.

    Args:
        scene_dir: directory holding the ``shard_*.tar`` files of one scene.

    Yields:
        dict with keys: img (H,W uint8), ts (float), pose (7,) float32,
        vel (6,) float32.

    Raises:
        FileNotFoundError: if ``scene_dir`` contains no ``shard_*.tar`` file
            (raised on first iteration, since this is a generator).
        KeyError: if a sample is missing one of the expected extensions.
    """
    shard_files = sorted(scene_dir.glob("shard_*.tar"))
    if not shard_files:
        raise FileNotFoundError(f"No shard_*.tar files found in {scene_dir}")
    for shard_path in shard_files:
        # Group raw payloads by sample key; the archive is closed before any
        # decoding/yielding so it is not held open while the consumer works.
        samples: dict[str, dict[str, bytes]] = {}
        with tarfile.open(shard_path, "r") as tar:
            for member in tar.getmembers():
                # extractfile() returns None for directories and similar
                # entries; they carry no sample payload.
                if not member.isfile():
                    continue
                key, dot, ext = member.name.rpartition(".")
                if not dot:
                    continue  # no extension -> not part of a sample
                samples.setdefault(key, {})[ext] = tar.extractfile(member).read()
        # Sort by the trailing frame number to maintain temporal order
        for key in sorted(samples, key=lambda k: int(k.split("_")[-1])):
            data = samples[key]
            img = cv2.imdecode(
                np.frombuffer(data["jpg"], np.uint8), cv2.IMREAD_GRAYSCALE
            )
            ts = np.frombuffer(data["ts"], dtype=np.float64).item()
            # .copy(): frombuffer views are read-only; callers get owned arrays
            pose = np.frombuffer(data["pose"], dtype=np.float32).copy()
            vel = np.frombuffer(data["vel"], dtype=np.float32).copy()
            yield {"img": img, "ts": ts, "pose": pose, "vel": vel}
# ──────────────────────────── Pose computation ────────────────────────────
def compute_body_state(q_raw: np.ndarray, v_world: np.ndarray):
    """
    Project a world-frame linear velocity onto the yaw-aligned horizontal frame.

    The dataset quaternion is world→odom rather than world→body.
    ``world_vel_to_body_np`` applies the static R_odom_to_body calibration and
    compensates yaw only. This function keeps the first two components of its
    (3,) result.

    Args:
        q_raw: (4,) raw dataset quaternion [qx, qy, qz, qw] (world→odom).
        v_world: (3,) world-frame linear velocity.

    Returns:
        (2,) float32 array [v_right, v_forward] in the yaw-aligned horizontal frame.
    """
    horizontal = world_vel_to_body_np(v_world, q_raw)  # (3,) yaw-compensated
    v_right, v_forward = horizontal[0], horizontal[1]
    return np.array((v_right, v_forward), dtype=np.float32)
# ──────────────────────────── Attitude correction ────────────────────────────
#
# The GT quaternion is world→odom, not world→body. We apply the static
# calibration R_odom_to_body to obtain the true body orientation.
#
# q_world_to_body = q_world_to_odom * R_odom_to_body
_Q_R: torch.Tensor | None = None
def reset_attitude_offset():
"""Reset cached R quaternion (call before processing a new scene)."""
global _Q_R
_Q_R = None
def correct_attitude(q: np.ndarray) -> torch.Tensor:
    """
    Convert a dataset world→odom quaternion into a world→body quaternion.

    Computes normalize(q_world_to_odom * q_R), where q_R is the quaternion form
    of the static calibration R_ODOM_TO_BODY. q_R is built on the first call
    and kept in the module-level cache ``_Q_R`` until reset_attitude_offset().

    Args:
        q: (4,) raw quaternion [qx, qy, qz, qw] from the dataset (world→odom).

    Returns:
        (4,) torch tensor: the normalized world→body quaternion.
    """
    global _Q_R
    # torch.from_numpy shares memory with q (no copy).
    q_world_to_odom = torch.from_numpy(q)
    if _Q_R is None:
        _Q_R = quat_from_matrix(R_ODOM_TO_BODY)
    return quat_normalize(quat_mul(q_world_to_odom, _Q_R))
# ──────────────────────────── Drawing ────────────────────────────
def _draw_text_lines(
    display: np.ndarray,
    lines: list[str],
    origin: tuple[int, int],
    color: tuple[int, int, int],
    line_height: int = 14,
    font_scale: float = 0.28,
    thickness: int = 1,
) -> None:
    """Draw ``lines`` top-down onto ``display`` in place, starting at ``origin``."""
    x, y = origin
    for text in lines:
        cv2.putText(
            display,
            text,
            (x, y),
            cv2.FONT_HERSHEY_SIMPLEX,
            font_scale,
            color,
            thickness,
            cv2.LINE_AA,
        )
        y += line_height


def _event_color_overlay(events, shape: tuple[int, int]) -> np.ndarray:
    """
    Map a signed per-pixel intensity map to a BGR uint8 overlay.

    Positive values go to the green channel and negative values to the red
    channel. Brightness is proportional to |value| normalized by the frame's
    peak magnitude, so the strongest change is drawn at 255.

    Raises:
        ValueError: if ``events`` does not have exactly ``shape`` (H, W).
    """
    ev = np.asarray(events)
    if ev.shape != tuple(shape):
        raise ValueError(
            f"events shape {ev.shape} does not match image shape {tuple(shape)}"
        )
    # Floor avoids division by zero on an all-zero frame.
    limit = max(float(np.abs(ev).max()), 1e-6)
    norm = np.clip(ev / limit, -1.0, 1.0)
    magnitude = (255 * np.abs(norm)).astype(np.uint8)  # dark -> bright
    overlay = np.zeros((*shape, 3), dtype=np.uint8)
    pos = norm > 0
    neg = norm < 0
    overlay[pos, 1] = magnitude[pos]  # green channel
    overlay[neg, 2] = magnitude[neg]  # red channel
    return overlay


def _draw_attitude_indicator(
    display: np.ndarray, roll_deg: float, pitch_deg: float, half_len: int = 40
) -> None:
    """Draw an artificial-horizon style pitch/roll indicator at the image center (in place)."""
    h, w = display.shape[:2]
    ax, ay = w // 2, h // 2
    # Fixed reference line (light gray, horizontal)
    cv2.line(
        display, (ax - half_len, ay), (ax + half_len, ay), (200, 200, 200), 1, cv2.LINE_AA
    )
    # Moving attitude line (green):
    # Roll rotates the line around the center (positive roll = clockwise = -angle in image).
    # Pitch offsets it vertically (positive pitch = nose up = line moves down), 1 px/deg.
    pitch_offset = int(pitch_deg * 1.0)
    angle_rad = np.deg2rad(-roll_deg)  # negate: right bank -> clockwise in image
    cos_a = np.cos(angle_rad)
    sin_a = np.sin(angle_rad)
    x1 = int(ax - half_len * cos_a)
    y1 = int(ay + pitch_offset - half_len * sin_a)
    x2 = int(ax + half_len * cos_a)
    y2 = int(ay + pitch_offset + half_len * sin_a)
    cv2.line(display, (x1, y1), (x2, y2), (0, 255, 0), 2, cv2.LINE_AA)
    # Small center dot
    cv2.circle(display, (ax, ay), 2, (0, 255, 0), -1)
    # Labels
    cv2.putText(
        display,
        f"P{pitch_deg:+.0f}",
        (ax + half_len + 6, ay + pitch_offset + 4),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.3,
        (0, 255, 0),
        1,
        cv2.LINE_AA,
    )
    cv2.putText(
        display,
        f"R{roll_deg:+.0f}",
        (ax + half_len + 6, ay + 14),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.3,
        (0, 255, 0),
        1,
        cv2.LINE_AA,
    )


def draw_pose_overlay(
    canvas: np.ndarray,
    pose: np.ndarray,
    vel: np.ndarray,
    tilt: np.ndarray,
    v_body: np.ndarray,
    euler: np.ndarray,
    frame_idx: int,
    ts: float,
    events: np.ndarray | None = None,
    show_events: bool = False,
):
    """
    Draw body-frame pose and velocity information onto the image.

    Args:
        canvas: (H, W) grayscale uint8 image. It is converted to a new BGR image
            and is not modified.
        pose: world-frame pose. pose[0:3] is shown as the position.
        vel: world-frame velocity. vel[0:3] is shown.
        tilt: (3,) body up vector (create_video passes body_up_vector_np output).
        v_body: (2,) body-frame [v_right, v_forward] in m/s.
        euler: (3,) [roll, pitch, yaw] in degrees from the body quaternion.
        frame_idx: current frame number.
        ts: timestamp in seconds.
        events: optional (H, W) signed per-pixel intensity map. The sign selects
            the color (green > 0, red < 0) and |value| the brightness.
        show_events: blend ``events`` into the image only when True and
            ``events`` is given.

    Returns:
        (H, W, 3) BGR uint8 image with the overlay drawn.

    Raises:
        ValueError: if ``events`` is shown and its shape is not (H, W).
    """
    # Convert to BGR for color overlay
    display = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
    h, w = display.shape[:2]

    # ── Event overlay (gradient temporal intensity) ──
    # Blend into the camera image *before* drawing the HUD. Blending
    # afterwards would halve the brightness of every text line, arrow and
    # indicator whenever events are shown.
    if show_events and events is not None:
        overlay = _event_color_overlay(events, (h, w))
        display = cv2.addWeighted(display, 0.5, overlay, 1.0, 0)

    # ── Info lines ──
    _draw_text_lines(
        display,
        [
            f"Frame: {frame_idx}",
            f"Time: {ts:.3f}s",
            f"Pos: ({pose[0]:.2f}, {pose[1]:.2f}, {pose[2]:.2f}) m",
        ],
        origin=(10, 20),
        color=(0, 255, 0),
    )

    # ── Euler angles (from body quaternion) ──
    roll_deg, pitch_deg, yaw_deg = euler
    _draw_text_lines(
        display,
        [
            f"Roll: {roll_deg:+.1f} deg",
            f"Pitch: {pitch_deg:+.1f} deg",
            f"Yaw: {yaw_deg:+.1f} deg",
        ],
        origin=(10, 62),
        color=(0, 200, 255),
    )

    # ── Body up vector (pitch & roll only) ──
    _draw_text_lines(
        display,
        [f"Body up: ({tilt[0]:+.3f}, {tilt[1]:+.3f}, {tilt[2]:+.3f})"],
        origin=(10, 104),
        color=(0, 200, 255),
    )

    # ── Body-frame velocity ──
    v_right, v_forward = v_body  # [v_right, v_forward]
    _draw_text_lines(
        display,
        [
            f"v_body: forward={v_forward:+.3f} right={v_right:+.3f} m/s",
            f" speed={np.sqrt(v_right**2 + v_forward**2):.3f} m/s",
        ],
        origin=(10, 132),
        color=(255, 100, 100),
    )

    # ── World-frame velocity ──
    _draw_text_lines(
        display,
        [f"v_world: ({vel[0]:+.3f}, {vel[1]:+.3f}, {vel[2]:+.3f}) m/s"],
        origin=(10, 160),
        color=(180, 180, 180),
    )

    # ── Velocity arrow (body frame): right -> +x, forward -> up in the image ──
    center = (w // 2, h // 2)
    vel_scale = 8.0  # pixels per m/s
    arrow_end = (
        center[0] + int(v_right * vel_scale),
        center[1] + int(-v_forward * vel_scale),
    )
    cv2.arrowedLine(display, center, arrow_end, (255, 0, 255), 2, tipLength=0.3)
    cv2.putText(
        display,
        "v_body",
        (arrow_end[0] + 8, arrow_end[1]),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.4,
        (255, 0, 255),
        1,
        cv2.LINE_AA,
    )

    # ── Attitude indicator (pitch & roll) ──
    _draw_attitude_indicator(display, roll_deg, pitch_deg)
    return display
# ──────────────────────────── Video generation ────────────────────────────
def _render_frame(
    frame_data: dict,
    frame_idx: int,
    event_processor,
    show_events: bool,
) -> np.ndarray:
    """Compute attitude/velocity quantities for one frame and return the annotated BGR image."""
    q_raw = frame_data["pose"][3:7]  # [qx, qy, qz, qw] world→odom
    # Compute events if enabled
    if event_processor is not None:
        events, _, _ = event_processor(frame_data["img"])
    else:
        events = None
    # Body up vector (pitch & roll only, no yaw) — matches DiffPhysDrone
    body_up = body_up_vector_np(q_raw)  # (3,) unit vector
    # Euler angles from the calibrated body quaternion, for display
    q_body = correct_attitude(q_raw)
    euler_deg = np.rad2deg(quat_to_euler_np(q_body.numpy()))  # [roll, pitch, yaw] deg
    # Body-frame velocity from the raw quaternion
    v_body = compute_body_state(q_raw, frame_data["vel"][:3])
    return draw_pose_overlay(
        canvas=frame_data["img"],
        pose=frame_data["pose"],
        vel=frame_data["vel"],
        tilt=body_up,
        v_body=v_body,
        euler=euler_deg,
        frame_idx=frame_idx,
        ts=frame_data["ts"],
        events=events,
        show_events=show_events,
    )


def create_video(
    scene_name: str,
    output_path: str | Path,
    fps: float = 30.0,
    max_frames: int | None = None,
    show: bool = False,
    show_events: bool = False,
    event_threshold: float = 0.3,
):
    """
    Read scene data, overlay pose info, and write it to a video file (or show it).

    Args:
        scene_name: scene directory name under DATASET_ROOT.
        output_path: video file to write. Ignored when ``show`` is True.
        fps: output frame rate, also used as the on-screen playback rate.
            Must be > 0.
        max_frames: if not None, only the first ``max_frames`` frames are used
            (0 means no frames).
        show: display frames in a window instead of writing a video.
        show_events: overlay EventProcessor output on each frame.
        event_threshold: ``threshold`` passed to EventProcessor.
            NOTE(review): an earlier comment here claimed threshold=0 yields raw
            temporal intensity, but the value actually used was 0.3. The default
            keeps 0.3; confirm which one is intended.

    Raises:
        ValueError: if ``fps`` is not positive.
        FileNotFoundError: if the scene directory (or its shards) is missing.
        RuntimeError: if the video writer cannot be opened for ``output_path``.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    scene_dir = DATASET_ROOT / scene_name
    if not scene_dir.exists():
        raise FileNotFoundError(f"Scene directory not found: {scene_dir}")
    print(f"Loading scene: {scene_name}")
    frames = list(load_scene_frames(scene_dir))
    print(f" Total frames: {len(frames)}")
    if max_frames is not None:
        frames = frames[:max_frames]
        print(f" Using first {max_frames} frames")
    if not frames:
        print(" No frames to process.")
        return
    # Clear the cached calibration quaternion before a new scene
    reset_attitude_offset()
    event_processor = (
        EventProcessor(threshold=event_threshold, use_log=True) if show_events else None
    )
    # Get dimensions from first frame
    h, w = frames[0]["img"].shape
    writer = None
    if show:
        print(f" Showing on screen (press ESC or 'q' to quit)")
    else:
        output_path = Path(output_path)
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h))
        # An unopened writer silently drops every frame (e.g. unsupported
        # extension or output path is a directory) — fail loudly instead.
        if not writer.isOpened():
            raise RuntimeError(f"Could not open video writer for {output_path}")
        print(f" Output: {output_path} ({w}x{h} @ {fps}fps)")
    # waitKey(0) blocks forever, so never let the delay truncate to 0 ms.
    delay_ms = max(1, int(1000 / fps))
    processed = 0
    try:
        for i, frame_data in enumerate(frames):
            display = _render_frame(frame_data, i, event_processor, show_events)
            processed = i + 1
            if writer is None:
                cv2.imshow(f"UZH-FPV: {scene_name}", display)
                key = cv2.waitKey(delay_ms) & 0xFF
                if key in (27, ord("q")):  # ESC or q
                    print(" Interrupted by user")
                    break
            else:
                writer.write(display)
                if processed % 500 == 0:
                    print(f" Processed {processed}/{len(frames)} frames")
    finally:
        # Release resources even if rendering raises mid-scene.
        if writer is not None:
            writer.release()
        if show:
            cv2.destroyAllWindows()
    if writer is not None:
        print(f" Video saved: {output_path}")
    print(f" Done. Processed {processed} frames.")
# ──────────────────────────── Main ────────────────────────────
def main():
    """
    Command-line entry point: render one scene (``--scene``) or every scene (``--all``).

    ``--output`` is treated as a directory when ``--all`` is given, or when it
    is an existing directory, or when it has no file extension (e.g. the default
    ``videos``). In those cases each scene is written to ``<output>/<scene>.mp4``.
    Otherwise it is used as the video file path.
    """
    parser = argparse.ArgumentParser(
        description="Visualize UZH-FPV dataset with body-frame pose overlay"
    )
    parser.add_argument(
        "--scene", type=str, default=None, help="Scene name (e.g. indoor_forward_3)"
    )
    parser.add_argument("--all", action="store_true", help="Process all scenes")
    parser.add_argument(
        "--output",
        type=str,
        default="videos",
        help="Output video path or directory (default: videos/)",
    )
    parser.add_argument(
        "--fps", type=float, default=30.0, help="Output video framerate (default: 30)"
    )
    parser.add_argument(
        "--max-frames", type=int, default=None, help="Limit number of frames to process"
    )
    parser.add_argument(
        "--show", action="store_true", help="Display on screen instead of saving video"
    )
    parser.add_argument(
        "--show-events",
        action="store_true",
        help="Overlay temporal brightness change (green=positive, red=negative)",
    )
    args = parser.parse_args()
    # Collect scenes to process
    if args.all:
        scenes = sorted(
            d.name
            for d in DATASET_ROOT.iterdir()
            if d.is_dir() and any(d.glob("shard_*.tar"))
        )
        if not scenes:
            print("No scenes with shard files found.")
            return
        print(f"Processing all {len(scenes)} scenes: {scenes}")
    elif args.scene:
        scenes = [args.scene]
    else:
        # Prints usage to stderr and exits with status 2.
        parser.error("specify --scene <name> or --all")
    output = Path(args.output)
    # A bare "videos" (the default) would otherwise be handed to VideoWriter as
    # an extension-less file name, which OpenCV cannot open.
    output_is_dir = args.all or output.is_dir() or not output.suffix
    for scene in scenes:
        out_path = output / f"{scene}.mp4" if output_is_dir else output
        create_video(
            scene_name=scene,
            output_path=out_path,
            fps=args.fps,
            max_frames=args.max_frames,
            show=args.show,
            show_events=args.show_events,
        )
if __name__ == "__main__":
    main()