diff --git a/src/velocity_prediction/transforms.py b/src/velocity_prediction/transforms.py index de44145..42f040c 100644 --- a/src/velocity_prediction/transforms.py +++ b/src/velocity_prediction/transforms.py @@ -15,7 +15,7 @@ import numpy as np import cv2 from src.event_utils import EventProcessor -from src.velocity_prediction.utils import decompose_tilt_np, world_vel_to_body_np +from src.velocity_prediction.utils import body_attitude_np, world_vel_to_body_np from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD @@ -75,24 +75,36 @@ class SimulateEvents: class ComputeTilt: - """Extract tilt rotation vector from pose quaternion (discard position, discard yaw).""" + """Compute body attitude rotation vector from pose quaternion. + + Applies the static calibration R_odom_to_body to obtain the true + world→body quaternion, then converts to a rotation vector. + Unlike the old approach, yaw is preserved — the model can decide + how to use it. + """ def __call__(self, sample: dict) -> dict: - q = sample["pose"][3:7] # [qx, qy, qz, qw] - tilt = decompose_tilt_np(q) # (3,) rotation vector - sample["tilt"] = tilt.astype(np.float32) + q = sample["pose"][3:7] # [qx, qy, qz, qw] world→odom + att = body_attitude_np(q) # (3,) rotation vector of true body + sample["tilt"] = att.astype(np.float32) return sample class ComputeBodyVelocity: - """Transform world-frame velocity to body-frame (yaw-compensated).""" + """Transform world-frame velocity to yaw-compensated horizontal velocity. + + The GT quaternion is world→odom (not world→body). A static calibration + R_odom_to_body is applied, then only yaw is compensated (no pitch/roll). + + Output: [v_right, v_forward] in the horizontal plane, aligned with heading. + """ def __call__(self, sample: dict) -> dict: v_world = sample["vel"][:3] # [vx, vy, vz] world frame - q = sample["pose"][3:7] # [qx, qy, qz, qw] - v_body = world_vel_to_body_np(v_world, q) # (3,) - # Only predict forward (x) and lateral (y) body velocity - sample["v_body_target"] = v_body[:2].astype(np.float32) # (2,) + q = sample["pose"][3:7] # [qx, qy, qz, qw] world→odom + v_horiz = world_vel_to_body_np(v_world, q) # (3,) yaw-compensated + # [v_right, v_forward] = [vx, vy] in yaw-aligned horizontal frame + sample["v_body_target"] = np.array([v_horiz[0], v_horiz[1]], dtype=np.float32) return sample diff --git a/src/velocity_prediction/utils.py b/src/velocity_prediction/utils.py index a64ce34..7cb1df4 100644 --- a/src/velocity_prediction/utils.py +++ b/src/velocity_prediction/utils.py @@ -108,44 +108,185 @@ def quat_to_rotvec(q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: return torch.stack([rx, ry, rz], dim=-1) +# ──────────────────────────── Static odom→body calibration ──────────────────────────── +# +# The GT pose from the motion-capture system gives world→odom, NOT world→body. +# There is a static rotation R_odom_to_body that corrects this. +# +# R = R_y(45°) @ R_x(90°): first rotate +90° around odom_x, then +45° around odom_y. +# This maps odom-frame vectors to the true body frame (ROS convention): +# body_x = right, body_y = forward, body_z = up +# +# At t=0 (FPV level on ground): +# body_z+ (up) ≈ world_z+ +# body_y+ (forward) ≈ world_x- (i.e. [-1, 0, 0]) +# body_x+ (right) ≈ world_y+ (i.e. [0, 1, 0]) + +R_ODOM_TO_BODY_NP = np.array([ + [ 0.70710678, 0.70710678, 0. ], + [ 0., 0., -1. ], + [-0.70710678, 0.70710678, 0. ], +], dtype=np.float64) + +R_ODOM_TO_BODY = torch.from_numpy(R_ODOM_TO_BODY_NP) + + # ──────────────────────────── Velocity transformation ──────────────────────────── def world_vel_to_body( v_world: torch.Tensor, - q_world_to_body: torch.Tensor, + q_world_to_odom: torch.Tensor, ) -> torch.Tensor: """ - Transform world-frame velocity to body-frame velocity. + Transform world-frame velocity to yaw-compensated horizontal velocity. + + The GT quaternion is world→odom (not world→body). We apply the static + calibration R_odom_to_body, then extract only the yaw to rotate the + world velocity into a yaw-aligned horizontal frame. + + Only yaw is compensated — pitch/roll (tilt) are NOT included, so the + output is the horizontal-plane velocity in a frame aligned with the + body's heading. Steps: - 1. Extract yaw from q_world_to_body. - 2. Build pure-yaw quaternion q_yaw. - 3. Remove yaw from velocity: v_yaw_compensated = q_yaw^{-1} * v_world - 4. Rotate to body frame: v_body = q_tilt^{-1} * v_yaw_compensated - where q_tilt = q_yaw^{-1} * q_world_to_body - - Args: - v_world: (..., 3) world-frame linear velocity [vx, vy, vz] - q_world_to_body: (..., 4) world→body unit quaternion + 1. Compute world→body quaternion: q_world_to_body = q_world_to_odom * R + 2. Extract yaw from q_world_to_body. + 3. Remove yaw from velocity: v_horiz = q_yaw^{-1} * v_world Returns: - v_body: (..., 3) body-frame linear velocity + v_horiz: (..., 3) yaw-compensated horizontal velocity + [v_right, v_forward, v_up] where v_up ≈ vertical """ + # Step 0: apply static calibration + q_R = quat_from_matrix(R_ODOM_TO_BODY.to(q_world_to_odom.device)) + q_world_to_body = quat_mul(q_world_to_odom, q_R) + q_world_to_body = quat_normalize(q_world_to_body) + + # Step 1: extract yaw only yaw = quat_to_yaw(q_world_to_body) q_yaw = quat_from_yaw(yaw) q_yaw_inv = quat_conjugate(q_yaw) - # Step 1: remove yaw from velocity (rotate to yaw-aligned intermediate frame) - v_yaw_comp = quat_rotate(q_yaw_inv, v_world) + # Step 2: remove yaw from velocity (rotate to yaw-aligned horizontal frame) + v_horiz = quat_rotate(q_yaw_inv, v_world) + return v_horiz - # Step 2: compute tilt quaternion - q_tilt = quat_mul(q_yaw_inv, q_world_to_body) - q_tilt = quat_normalize(q_tilt) - q_tilt_inv = quat_conjugate(q_tilt) - # Step 3: rotate to body frame - v_body = quat_rotate(q_tilt_inv, v_yaw_comp) - return v_body +def quat_from_matrix(R: torch.Tensor) -> torch.Tensor: + """ + Convert a 3x3 rotation matrix to a unit quaternion [x, y, z, w]. + + Args: + R: (3, 3) rotation matrix + + Returns: + q: (4,) unit quaternion + """ + trace = R[0, 0] + R[1, 1] + R[2, 2] + if trace > 0: + s = 0.5 / torch.sqrt(trace + 1.0) + w = 0.25 / s + x = (R[2, 1] - R[1, 2]) * s + y = (R[0, 2] - R[2, 0]) * s + z = (R[1, 0] - R[0, 1]) * s + elif R[0, 0] > R[1, 1] and R[0, 0] > R[2, 2]: + s = 2.0 * torch.sqrt(1.0 + R[0, 0] - R[1, 1] - R[2, 2]) + w = (R[2, 1] - R[1, 2]) / s + x = 0.25 * s + y = (R[0, 1] + R[1, 0]) / s + z = (R[0, 2] + R[2, 0]) / s + elif R[1, 1] > R[2, 2]: + s = 2.0 * torch.sqrt(1.0 + R[1, 1] - R[0, 0] - R[2, 2]) + w = (R[0, 2] - R[2, 0]) / s + x = (R[0, 1] + R[1, 0]) / s + y = 0.25 * s + z = (R[1, 2] + R[2, 1]) / s + else: + s = 2.0 * torch.sqrt(1.0 + R[2, 2] - R[0, 0] - R[1, 1]) + w = (R[1, 0] - R[0, 1]) / s + x = (R[0, 2] + R[2, 0]) / s + y = (R[1, 2] + R[2, 1]) / s + z = 0.25 * s + return torch.stack([x, y, z, w]) + + +def decompose_tilt_from_odom(q_world_to_odom: torch.Tensor) -> torch.Tensor: + """ + Decompose tilt from the GT quaternion, applying the static calibration. + + The returned tilt is the pitch/roll of the true body relative to its + heading direction (yaw removed). + + Args: + q_world_to_odom: (..., 4) world→odom unit quaternion from GT + + Returns: + tilt_angles: (..., 3) rotation vector [rx, ry, rz] + """ + q_R = quat_from_matrix(R_ODOM_TO_BODY.to(q_world_to_odom.device)) + q_world_to_body = quat_mul(q_world_to_odom, q_R) + q_world_to_body = quat_normalize(q_world_to_body) + return decompose_tilt(q_world_to_body) + + +# ──────────────────────────── Body attitude (new approach) ──────────────────────────── +# +# Instead of removing yaw from the body quaternion, we directly use the +# corrected world→body quaternion's rotation vector. This preserves yaw +# information and lets the model decide how to use it — analogous to how +# DiffPhysDrone uses the body-up vector as a tilt feature. + +def body_attitude(q_world_to_odom: torch.Tensor) -> torch.Tensor: + """ + Compute the true body attitude rotation vector from GT odom quaternion. + + Applies the static calibration R_odom_to_body, then converts the + resulting world→body quaternion directly to a rotation vector. + Unlike decompose_tilt, this preserves yaw information. + + Args: + q_world_to_odom: (..., 4) world→odom unit quaternion from GT + + Returns: + attitude: (..., 3) rotation vector [rx, ry, rz] of the true body + """ + q_R = quat_from_matrix(R_ODOM_TO_BODY.to(q_world_to_odom.device)) + q_world_to_body = quat_mul(q_world_to_odom, q_R) + q_world_to_body = quat_normalize(q_world_to_body) + return quat_to_rotvec(q_world_to_body) + + +def quat_to_euler(q: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to ZYX Euler angles (yaw, pitch, roll). + + Follows ROS convention: R = R_z(yaw) @ R_y(pitch) @ R_x(roll) + Gravity axis is +z. + + Args: + q: (..., 4) unit quaternion [x, y, z, w] + + Returns: + euler: (..., 3) [roll, pitch, yaw] in radians + """ + x, y, z, w = q.unbind(-1) + + # roll (x-axis rotation) + sinr_cosp = 2.0 * (w * x + y * z) + cosr_cosp = 1.0 - 2.0 * (x * x + y * y) + roll = torch.atan2(sinr_cosp, cosr_cosp) + + # pitch (y-axis rotation) + sinp = 2.0 * (w * y - z * x) + sinp = sinp.clamp(-1.0, 1.0) + pitch = torch.asin(sinp) + + # yaw (z-axis rotation) + siny_cosp = 2.0 * (w * z + x * y) + cosy_cosp = 1.0 - 2.0 * (y * y + z * z) + yaw = torch.atan2(siny_cosp, cosy_cosp) + + return torch.stack([roll, pitch, yaw], dim=-1) # ──────────────────────────── NumPy wrappers (for transforms.py) ──────────────────────────── @@ -157,9 +298,23 @@ def decompose_tilt_np(q: np.ndarray) -> np.ndarray: return tilt.numpy() +def body_attitude_np(q: np.ndarray) -> np.ndarray: + """NumPy version of body_attitude.""" + q_t = torch.from_numpy(q) + att = body_attitude(q_t) + return att.numpy() + + +def quat_to_euler_np(q: np.ndarray) -> np.ndarray: + """NumPy version of quat_to_euler.""" + q_t = torch.from_numpy(q) + euler = quat_to_euler(q_t) + return euler.numpy() + + def world_vel_to_body_np(v_world: np.ndarray, q: np.ndarray) -> np.ndarray: """NumPy version of world_vel_to_body.""" - v_t = torch.from_numpy(v_world) - q_t = torch.from_numpy(q) + v_t = torch.from_numpy(v_world.copy()) + q_t = torch.from_numpy(q.copy()) vb = world_vel_to_body(v_t, q_t) return vb.numpy() diff --git a/visualize/visualize_dataset.py b/visualize/visualize_dataset.py new file mode 100644 index 0000000..2fdce21 --- /dev/null +++ b/visualize/visualize_dataset.py @@ -0,0 +1,443 @@ +""" +Dataset visualization: overlay body-frame pose on images and produce a video. + +Usage: + uv run python -m visualize.visualize_dataset \\ + --scene indoor_forward_3 \\ + --output videos/indoor_forward_3.mp4 + + # Visualize all scenes + uv run python -m visualize.visualize_dataset --all --output videos/ + + # Show on screen instead of saving video + uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --show +""" + +import argparse +import io +import tarfile +from pathlib import Path + +import cv2 +import numpy as np +import torch + +# Reuse the same coordinate transforms as the training pipeline +from src.velocity_prediction.utils import ( + body_attitude_np, + quat_to_euler_np, + world_vel_to_body_np, + quat_normalize, + quat_mul, + quat_from_matrix, + R_ODOM_TO_BODY_NP, + R_ODOM_TO_BODY, +) +from src.velocity_prediction.config import DATASET_ROOT, VELOCITY_MEAN, VELOCITY_STD + + +# ──────────────────────────── Data loading ──────────────────────────── + + +def load_scene_frames(scene_dir: Path): + """ + Load all frames from a scene's shard tar files. + + Yields: + dict with keys: img (H,W uint8), ts (float), pose (7,), vel (6,) + """ + shard_files = sorted(scene_dir.glob("shard_*.tar")) + if not shard_files: + raise FileNotFoundError(f"No shard_*.tar files found in {scene_dir}") + + for shard_path in shard_files: + with tarfile.open(shard_path, "r") as tar: + # Group entries by sample index + members = tar.getmembers() + samples: dict[str, dict[str, bytes]] = {} + for m in members: + idx, ext = m.name.rsplit(".", 1) + samples.setdefault(idx, {})[ext] = tar.extractfile(m).read() + + # Sort by frame index to maintain temporal order + for idx in sorted(samples.keys(), key=lambda k: int(k.split("_")[-1])): + data = samples[idx] + img = cv2.imdecode( + np.frombuffer(data["jpg"], np.uint8), cv2.IMREAD_GRAYSCALE + ) + ts = np.frombuffer(data["ts"], dtype=np.float64).item() + pose = np.frombuffer(data["pose"], dtype=np.float32).copy() + vel = np.frombuffer(data["vel"], dtype=np.float32).copy() + yield {"img": img, "ts": ts, "pose": pose, "vel": vel} + + +# ──────────────────────────── Pose computation ──────────────────────────── + + +def compute_body_state(q_raw: np.ndarray, v_world: np.ndarray): + """ + Compute yaw-compensated horizontal velocity from raw GT pose quaternion. + + The GT quaternion is world→odom (not world→body). The static + calibration R_odom_to_body is applied, then only yaw is compensated. + + Args: + q_raw: (4,) numpy array — raw quaternion [qx, qy, qz, qw] from dataset (world→odom). + v_world: (3,) numpy array — world-frame linear velocity. + + Returns: + v_horiz_xy: (2,) [v_right, v_forward] in yaw-aligned horizontal frame. + """ + v_horiz = world_vel_to_body_np(v_world, q_raw) # (3,) yaw-compensated + return np.array([v_horiz[0], v_horiz[1]], dtype=np.float32) + + +# ──────────────────────────── Attitude correction ──────────────────────────── +# +# The GT quaternion is world→odom, not world→body. We apply the static +# calibration R_odom_to_body to obtain the true body orientation. +# +# q_world_to_body = q_world_to_odom * R_odom_to_body + +_Q_R: torch.Tensor | None = None + + +def reset_attitude_offset(): + """Reset cached R quaternion (call before processing a new scene).""" + global _Q_R + _Q_R = None + + +def correct_attitude(q: np.ndarray) -> torch.Tensor: + """ + Apply static calibration R_odom_to_body to obtain true body orientation. + + q_corrected = q_world_to_odom * R_odom_to_body + + Args: + q: (4,) raw quaternion [qx, qy, qz, qw] from dataset (world→odom). + + Returns: + q_corrected: (4,) torch tensor, world→body quaternion. + """ + global _Q_R + q_t = torch.from_numpy(q) + if _Q_R is None: + _Q_R = quat_from_matrix(R_ODOM_TO_BODY) + q_corrected = quat_mul(q_t, _Q_R) + return quat_normalize(q_corrected) + + +# ──────────────────────────── Drawing ──────────────────────────── + + +def draw_pose_overlay( + canvas: np.ndarray, + pose: np.ndarray, + vel: np.ndarray, + tilt: np.ndarray, + v_body: np.ndarray, + euler: np.ndarray, + frame_idx: int, + ts: float, +): + """ + Draw body-frame pose and velocity information onto the image. + + Args: + canvas: (H, W) grayscale uint8 — will be converted to BGR for drawing + pose: (7,) world-frame pose + vel: (6,) world-frame velocity + tilt: (3,) body attitude rotation vector (from body_attitude_np) + v_body: (2,) body-frame [v_right, v_forward] + euler: (3,) [roll, pitch, yaw] in degrees from body quaternion + frame_idx: current frame number + ts: timestamp + """ + # Convert to BGR for color overlay + display = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR) + + h, w = display.shape[:2] + + # ── Helper ── + def put_text( + lines, + origin=(10, 20), + line_height=14, + font_scale=0.28, + color=(0, 255, 0), + thickness=1, + ): + x, y = origin + for text in lines: + cv2.putText( + display, + text, + (x, y), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + color, + thickness, + cv2.LINE_AA, + ) + y += line_height + + # ── Info lines ── + info = [ + f"Frame: {frame_idx}", + f"Time: {ts:.3f}s", + f"Pos: ({pose[0]:.2f}, {pose[1]:.2f}, {pose[2]:.2f}) m", + ] + put_text(info, origin=(10, 20), color=(0, 255, 0)) + + # ── Euler angles (from body quaternion) ── + roll_deg, pitch_deg, yaw_deg = euler + euler_lines = [ + f"Roll: {roll_deg:+.1f} deg", + f"Pitch: {pitch_deg:+.1f} deg", + f"Yaw: {yaw_deg:+.1f} deg", + ] + put_text(euler_lines, origin=(10, 62), color=(0, 200, 255)) + + # ── Body attitude (rotation vector) ── + tilt_lines = [ + f"Att: rx={tilt[0]:+.3f} ry={tilt[1]:+.3f} rz={tilt[2]:+.3f}", + ] + put_text(tilt_lines, origin=(10, 104), color=(0, 200, 255)) + + # ── Body-frame velocity ── + v_right, v_forward = v_body # [v_right, v_forward] + vel_lines = [ + f"v_body: forward={v_forward:+.3f} right={v_right:+.3f} m/s", + f" speed={np.sqrt(v_right**2 + v_forward**2):.3f} m/s", + ] + put_text(vel_lines, origin=(10, 132), color=(255, 100, 100)) + + # ── World-frame velocity ── + wvel_lines = [ + f"v_world: ({vel[0]:+.3f}, {vel[1]:+.3f}, {vel[2]:+.3f}) m/s", + ] + put_text(wvel_lines, origin=(10, 160), color=(180, 180, 180)) + + # ── Velocity arrow (body frame) ── + center = (w // 2, h // 2) + vel_scale = 8.0 # pixels per m/s + v_right, v_forward = v_body + arrow_dx = int(v_right * vel_scale) + arrow_dy = int(-v_forward * vel_scale) + arrow_end = (center[0] + arrow_dx, center[1] + arrow_dy) + cv2.arrowedLine(display, center, arrow_end, (255, 0, 255), 2, tipLength=0.3) + cv2.putText( + display, + "v_body", + (arrow_end[0] + 8, arrow_end[1]), + cv2.FONT_HERSHEY_SIMPLEX, + 0.4, + (255, 0, 255), + 1, + cv2.LINE_AA, + ) + + # ── Attitude indicator (pitch & roll) ── + ah = 40 # half-length of the attitude line in pixels + + # Center of attitude indicator + ax, ay = w // 2, h // 2 + + # Draw fixed reference line (white, horizontal) + cv2.line(display, (ax - ah, ay), (ax + ah, ay), (200, 200, 200), 1, cv2.LINE_AA) + + # Draw moving attitude line (green) + # Roll: rotate line around center (positive roll = clockwise = -angle in image) + # Pitch: offset line vertically (positive pitch = nose up = line moves down) + pitch_offset = int(pitch_deg * 1.0) # pixels per degree + angle_rad = np.deg2rad(-roll_deg) # negate: right bank -> clockwise in image + cos_a = np.cos(angle_rad) + sin_a = np.sin(angle_rad) + x1 = int(ax + (-ah) * cos_a - 0 * sin_a) + y1 = int(ay + pitch_offset + (-ah) * sin_a + 0 * cos_a) + x2 = int(ax + (+ah) * cos_a - 0 * sin_a) + y2 = int(ay + pitch_offset + (+ah) * sin_a + 0 * cos_a) + cv2.line(display, (x1, y1), (x2, y2), (0, 255, 0), 2, cv2.LINE_AA) + + # Small center dot + cv2.circle(display, (ax, ay), 2, (0, 255, 0), -1) + + # Labels + cv2.putText( + display, + f"P{pitch_deg:+.0f}", + (ax + ah + 6, ay + pitch_offset + 4), + cv2.FONT_HERSHEY_SIMPLEX, + 0.3, + (0, 255, 0), + 1, + cv2.LINE_AA, + ) + cv2.putText( + display, + f"R{roll_deg:+.0f}", + (ax + ah + 6, ay + 14), + cv2.FONT_HERSHEY_SIMPLEX, + 0.3, + (0, 255, 0), + 1, + cv2.LINE_AA, + ) + + return display + + +# ──────────────────────────── Video generation ──────────────────────────── + + +def create_video( + scene_name: str, + output_path: str | Path, + fps: float = 30.0, + max_frames: int | None = None, + show: bool = False, +): + """ + Read scene data, overlay pose info, and write to video file (or show). + """ + scene_dir = DATASET_ROOT / scene_name + if not scene_dir.exists(): + raise FileNotFoundError(f"Scene directory not found: {scene_dir}") + + print(f"Loading scene: {scene_name}") + frames = list(load_scene_frames(scene_dir)) + print(f" Total frames: {len(frames)}") + + if max_frames: + frames = frames[:max_frames] + print(f" Using first {max_frames} frames") + + # Reset attitude offset for this scene + reset_attitude_offset() + + # Get dimensions from first frame + h, w = frames[0]["img"].shape + + # Video writer + if not show: + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + writer = cv2.VideoWriter(str(output_path), fourcc, fps, (w, h)) + print(f" Output: {output_path} ({w}x{h} @ {fps}fps)") + else: + writer = None + print(f" Showing on screen (press ESC or 'q' to quit)") + + # Process each frame + for i, frame_data in enumerate(frames): + q_raw = frame_data["pose"][3:7] # [qx, qy, qz, qw] world→odom + + # True body attitude rotation vector (preserves yaw) + tilt = body_attitude_np(q_raw) # (3,) + + # Euler angles from body quaternion for display + q_body = correct_attitude(q_raw) + euler_rad = quat_to_euler_np(q_body.numpy()) # [roll, pitch, yaw] rad + euler_deg = np.rad2deg(euler_rad) # [roll, pitch, yaw] deg + + # Compute body-frame velocity from raw quaternion + v_body = compute_body_state(q_raw, frame_data["vel"][:3]) + + display = draw_pose_overlay( + canvas=frame_data["img"], + pose=frame_data["pose"], + vel=frame_data["vel"], + tilt=tilt, + v_body=v_body, + euler=euler_deg, + frame_idx=i, + ts=frame_data["ts"], + ) + + if show: + cv2.imshow(f"UZH-FPV: {scene_name}", display) + key = cv2.waitKey(int(1000 / fps)) & 0xFF + if key in (27, ord("q")): # ESC or q + print(" Interrupted by user") + break + else: + writer.write(display) + + if (i + 1) % 500 == 0: + print(f" Processed {i + 1}/{len(frames)} frames") + + if writer: + writer.release() + print(f" Video saved: {output_path}") + + if show: + cv2.destroyAllWindows() + + print(f" Done. Processed {i + 1} frames.") + + +# ──────────────────────────── Main ──────────────────────────── + + +def main(): + parser = argparse.ArgumentParser( + description="Visualize UZH-FPV dataset with body-frame pose overlay" + ) + parser.add_argument( + "--scene", type=str, default=None, help="Scene name (e.g. indoor_forward_3)" + ) + parser.add_argument("--all", action="store_true", help="Process all scenes") + parser.add_argument( + "--output", + type=str, + default="videos", + help="Output video path or directory (default: videos/)", + ) + parser.add_argument( + "--fps", type=float, default=30.0, help="Output video framerate (default: 30)" + ) + parser.add_argument( + "--max-frames", type=int, default=None, help="Limit number of frames to process" + ) + parser.add_argument( + "--show", action="store_true", help="Display on screen instead of saving video" + ) + args = parser.parse_args() + + # Collect scenes to process + if args.all: + scenes = sorted( + d.name + for d in DATASET_ROOT.iterdir() + if d.is_dir() and any(d.glob("shard_*.tar")) + ) + if not scenes: + print("No scenes with shard files found.") + return + print(f"Processing all {len(scenes)} scenes: {scenes}") + elif args.scene: + scenes = [args.scene] + else: + parser.print_help() + print("\nError: specify --scene or --all") + return + + for scene in scenes: + if args.all and not args.show: + out_path = Path(args.output) / f"{scene}.mp4" + else: + out_path = args.output + + create_video( + scene_name=scene, + output_path=out_path, + fps=args.fps, + max_frames=args.max_frames, + show=args.show, + ) + + +if __name__ == "__main__": + main()