commit 9f0321eff8a145636de4dafcce6efc0f403d0df1 Author: CaoWangrenbo Date: Fri May 29 18:49:01 2026 +0800 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..dd2fd6a --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Python +__pycache__/ +*.py[cod] +*.pyo +*.egg-info/ +dist/ +build/ + +# Virtual environment +.venv/ +venv/ +env/ + +# IDE / Editor +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Data (large binary files) +bags/ +dataset/ + +# Model checkpoints / weights +checkpoints/ +*.pt + +# Logs (TensorBoard, etc.) +logs/ + +# Benchmark evaluation results +benchmark/results/ + +# Evaluation figures +*.png +*.jpg +*.jpeg +*.pdf +*.svg + +# Shell scripts (optional — uncomment if you want to ignore) +# *.sh + +# ROS bag files +*.bag +*.bag.active diff --git a/DATASET_FORMAT.md b/DATASET_FORMAT.md new file mode 100644 index 0000000..d80a45c --- /dev/null +++ b/DATASET_FORMAT.md @@ -0,0 +1,95 @@ +# UZH FPV Dataset Format + +> 由 `rosbag2wds.py` 从 DAVIS 事件相机 ROS bag 转换生成 + +## 目录结构 + +``` +dataset/ +├── <dataset_name>/ +│ ├── shard_0000.tar # WebDataset shard (图像 + 对齐的 GT) +│ ├── shard_0001.tar +│ ├── ... +│ ├── imu_sequence.npz # 完整 IMU 序列 (独立存储) +│ └── metadata.json # 数据集元信息 +``` + +## 文件说明 + +### 1. 
WebDataset Shard (`shard_*.tar`) + +每个 tar 文件包含 `shard_size` 个样本(默认 2000),每个样本的 key 为 `frame_<index>`,包含以下字段: + +| Key | 类型 | 内容 | +|------|-------------|-------------------------------------------------------------------| +| `jpg` | JPEG bytes | 灰度图,尺寸 `320×240`,JPEG quality=85 | +| `ts` | float64 | 图像时间戳(ROS bag 系统时间 `t.to_sec()`) | +| `pose`| float32[7] | 位姿:`[x, y, z, qx, qy, qz, qw]`(位置 + 单位四元数) | +| `vel` | float32[6] | 速度:`[vx, vy, vz, wx, wy, wz]`(线速度 + 角速度) | + +**读取示例 (Python):** + +```python +import webdataset as wds + +dataset = wds.WebDataset("dataset/<dataset_name>/shard_0000.tar") +for sample in dataset: + img = sample["jpg"] # JPEG bytes + ts = sample["ts"] # bytes -> np.frombuffer(..., dtype=np.float64) + pose = sample["pose"] # bytes -> np.frombuffer(..., dtype=np.float32).reshape(7) + vel = sample["vel"] # bytes -> np.frombuffer(..., dtype=np.float32).reshape(6) +``` + +### 2. IMU 序列 (`imu_sequence.npz`) + +独立存储的完整 IMU 数据(NPZ 压缩格式),包含三个数组: + +| Key | 类型 | 形状 | 内容 | +|--------------------|-------------|-------------|----------------------------| +| `timestamps` | float64 | (N,) | IMU 时间戳 | +| `accelerations` | float32 | (N, 3) | 线性加速度 `(ax, ay, az)` m/s² | +| `angular_velocities`| float32 | (N, 3) | 角速度 `(gx, gy, gz)` rad/s | + +**读取示例:** + +```python +import numpy as np +data = np.load("dataset/<dataset_name>/imu_sequence.npz") +timestamps = data["timestamps"] +acc = data["accelerations"] +gyro = data["angular_velocities"] +``` + +### 3. 
元信息 (`metadata.json`) + +```json +{ + "dataset_name": "indoor_forward_7", + "source_bag": "/mnt/indoor_forward_7_davis_with_gt.bag", + "num_images": 2459, + "num_imu_messages": 66632, + "num_ground_truth": 33350, + "image_size": [320, 240], + "imu_frequency_hz": 999.02, + "camera_frequency_hz": 36.89, + "gt_frequency_hz": 500.01, + "coordinate_system": "horizontal (z aligned with gravity, assumed from GT)", + "velocity_dimensions": 6 +} +``` + +## 数据来源 + +| Topic | 内容 | 频率 (典型值) | +|------------------------|------------------|-------------| +| `/dvs/image_raw` | 灰度图像 (mono8) | ~30–50 Hz | +| `/dvs/imu` | IMU (加速度+角速度)| ~1000 Hz | +| `/groundtruth/odometry`| 位姿真值 | ~500 Hz | + +## 预处理说明 + +- **时间戳**: 统一使用 ROS bag 系统时间 `t.to_sec()`,而非 `msg.header.stamp` +- **时间对齐**: 图像与 GT 通过最近邻时间戳匹配,最大允许偏差 0.1s +- **速度计算**: GT 速度由位姿差分计算(前向/后向有限差分 + 四元数旋转向量),忽略 bag 中原始 twist 数据 +- **时间裁剪**: 所有数据裁剪至 GT 时间范围内,去除首尾无 GT 的片段 +- **图像缩放**: 原始 DAVIS 分辨率 `240×180` → 缩放至 `320×240` (INTER_LINEAR) diff --git a/batch_convert.sh b/batch_convert.sh new file mode 100755 index 0000000..b59b082 --- /dev/null +++ b/batch_convert.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# batch_convert.sh +# 在已运行的 ROS 容器内执行,批量转换尚未转换的数据集 +# +# 用法(容器内): +# cd /mnt && bash batch_convert.sh +# +# 它会自动检测: +# - 哪些 .bag 文件尚未转换(通过检查 dataset//metadata.json) +# - 跳过已转换的数据集 + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +OUTPUT_DIR="${SCRIPT_DIR}/dataset" +BAG_DIR="${SCRIPT_DIR}/bags" + +# 已转换的数据集列表(通过检查 metadata.json) +echo "==========================================" +echo "批量转换 UZH FPV 数据集" +echo "==========================================" + +# 收集所有 bag 文件 +BAG_FILES=() +while IFS= read -r -d '' f; do + BAG_FILES+=("$f") +done < <(find "$BAG_DIR" -maxdepth 2 -name "*_davis_with_gt.bag" -print0) + +if [ ${#BAG_FILES[@]} -eq 0 ]; then + echo "❌ 未找到任何 .bag 文件" + exit 1 +fi + +echo "找到 ${#BAG_FILES[@]} 个 bag 文件" + +# 逐个检查并转换 +CONVERTED=0 +SKIPPED=0 +FAILED=0 + +for bag_path in "${BAG_FILES[@]}"; do + bag_name="$(basename 
"$bag_path")" + dataset_name="${bag_name%_davis_with_gt.bag}" + + # 检查是否已转换 + metadata_file="${OUTPUT_DIR}/${dataset_name}/metadata.json" + if [ -f "$metadata_file" ]; then + echo " ⏭️ 跳过 ${dataset_name} (已转换)" + SKIPPED=$((SKIPPED + 1)) + continue + fi + + echo "" + echo " 🔄 转换: ${dataset_name}" + + # 检查依赖 + if ! python3 -c "import webdataset" 2>/dev/null; then + echo " ⚠️ 正在安装 webdataset..." + pip3 install webdataset tqdm scipy 2>/dev/null || { + echo " ❌ pip install 失败,跳过 ${dataset_name}" + FAILED=$((FAILED + 1)) + continue + } + fi + + # 执行转换 + if python3 "${SCRIPT_DIR}/rosbag2wds.py" \ + --bag "$bag_path" \ + --output "$OUTPUT_DIR" \ + --name "$dataset_name" \ + --shard_size 2000 \ + --width 320 --height 240; then + echo " ✅ ${dataset_name} 转换完成" + CONVERTED=$((CONVERTED + 1)) + else + echo " ❌ ${dataset_name} 转换失败" + FAILED=$((FAILED + 1)) + fi +done + +echo "" +echo "==========================================" +echo "批量转换结束" +echo " 已转换: ${CONVERTED}" +echo " 已跳过: ${SKIPPED}" +echo " 失败: ${FAILED}" +echo "==========================================" diff --git a/benchmark/__init__.py b/benchmark/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py new file mode 100644 index 0000000..178742f --- /dev/null +++ b/benchmark/benchmark.py @@ -0,0 +1,309 @@ +""" +benchmark.py — Unified evaluation entry point. + +Two modes: + 1. Single-model eval: python -m benchmark.benchmark --checkpoint + 2. Compare mode: python -m benchmark.benchmark --compare + +Results are saved to benchmark/results//. 
+""" + +import argparse +import sys +from pathlib import Path +from typing import List, Optional + +import torch + +from benchmark.config import eval_cfg, TEST_SCENE_GROUPS +from benchmark.evaluate import run_full_evaluation, save_results + +# Project root (two levels up from benchmark/benchmark.py) +PROJECT_ROOT = Path(__file__).resolve().parents[1] +RESULTS_DIR = PROJECT_ROOT / "benchmark" / "results" + + +def load_checkpoint( + checkpoint_path: Path, + device: torch.device, +) -> torch.nn.Module: + """Load a VelocityPredictionModel from a checkpoint file.""" + from src.velocity_prediction.model import VelocityPredictionModel + + model = VelocityPredictionModel() + ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) + state_dict = ckpt.get("model_state_dict", ckpt) + model.load_state_dict(state_dict) + model.to(device) + model.eval() + return model + + +def run_single_eval( + checkpoint_path: Path, + output_dir: Optional[Path] = None, + device: torch.device = None, + seq_len: Optional[int] = None, + batch_size: Optional[int] = None, + num_workers: Optional[int] = None, + save_plots: bool = True, +) -> Path: + """Evaluate a single checkpoint and save results. + + Returns the output directory path. + """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + seq_len = seq_len or eval_cfg.seq_len + batch_size = batch_size or eval_cfg.batch_size + num_workers = num_workers or eval_cfg.num_workers + + # Derive experiment name from checkpoint filename (strip extension) + exp_name = checkpoint_path.stem # e.g. 
"best" or "epoch_050_val_1.827390" + + if output_dir is None: + output_dir = RESULTS_DIR / exp_name + + print(f"{'=' * 60}") + print(f"Benchmark — Single Model Evaluation") + print(f"{'=' * 60}") + print(f" Checkpoint: {checkpoint_path}") + print(f" Device: {device}") + print(f" Seq len: {seq_len}") + print(f" Batch size: {batch_size}") + print(f" Output: {output_dir}") + print() + + model = load_checkpoint(checkpoint_path, device) + + results = run_full_evaluation( + model=model, + device=device, + seq_len=seq_len, + batch_size=batch_size, + num_workers=num_workers, + event_threshold=eval_cfg.event_threshold, + event_use_log=eval_cfg.event_use_log, + scene_groups=TEST_SCENE_GROUPS, + ) + + save_results(results, save_dir=output_dir, checkpoint_name=exp_name) + + return output_dir + + +def run_compare( + checkpoint_dir: Path, + output_dir: Optional[Path] = None, + device: torch.device = None, + seq_len: Optional[int] = None, + batch_size: Optional[int] = None, + num_workers: Optional[int] = None, + pattern: str = "*.pt", +) -> Path: + """Evaluate all checkpoints in a directory and produce a comparison table. + + Returns the output directory path. 
+ """ + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + seq_len = seq_len or eval_cfg.seq_len + batch_size = batch_size or eval_cfg.batch_size + num_workers = num_workers or eval_cfg.num_workers + + checkpoint_paths = sorted(Path(checkpoint_dir).glob(pattern)) + if not checkpoint_paths: + print(f"No checkpoints found matching '{pattern}' in {checkpoint_dir}") + sys.exit(1) + + if output_dir is None: + output_dir = RESULTS_DIR / "compare" + + print(f"{'=' * 60}") + print(f"Benchmark — Compare Mode ({len(checkpoint_paths)} checkpoints)") + print(f"{'=' * 60}") + print(f" Checkpoint dir: {checkpoint_dir}") + print(f" Device: {device}") + print(f" Seq len: {seq_len}") + print(f" Batch size: {batch_size}") + print(f" Output: {output_dir}") + print() + + all_global_metrics = [] + + for ckpt_path in checkpoint_paths: + exp_name = ckpt_path.stem + print(f"\n── Evaluating {exp_name} ──") + + model = load_checkpoint(ckpt_path, device) + + results = run_full_evaluation( + model=model, + device=device, + seq_len=seq_len, + batch_size=batch_size, + num_workers=num_workers, + event_threshold=eval_cfg.event_threshold, + event_use_log=eval_cfg.event_use_log, + scene_groups=TEST_SCENE_GROUPS, + ) + + # Save individual results + ckpt_output_dir = output_dir / exp_name + save_results(results, save_dir=ckpt_output_dir, checkpoint_name=exp_name) + + all_global_metrics.append((exp_name, results["global"])) + + # ── Comparison table ── + print(f"\n\n{'=' * 60}") + print("Comparison Summary") + print(f"{'=' * 60}") + + header = f"{'Checkpoint':<30} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} {'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8}" + sep = "-" * len(header) + print(header) + print(sep) + + rows = [] + for name, metrics in all_global_metrics: + row = ( + f"{name:<30} " + f"{metrics.get('rmse_vx', 0):>10.4f} " + f"{metrics.get('rmse_vy', 0):>10.4f} " + f"{metrics.get('rmse_xy', 0):>10.4f} " + f"{metrics.get('mae_vx', 
0):>10.4f} " + f"{metrics.get('mae_vy', 0):>10.4f} " + f"{metrics.get('r2_vx', 0):>8.4f} " + f"{metrics.get('r2_vy', 0):>8.4f}" + ) + print(row) + rows.append(row) + + # Save comparison CSV + import csv + csv_path = output_dir / "comparison.csv" + output_dir.mkdir(parents=True, exist_ok=True) + fieldnames = ["checkpoint", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy", + "mae_xy", "r2_vx", "r2_vy", "count"] + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for name, metrics in all_global_metrics: + row = {"checkpoint": name, **metrics} + writer.writerow(row) + print(f"\nComparison CSV: {csv_path}") + + # Save comparison text + txt_path = output_dir / "comparison.txt" + with open(txt_path, "w") as f: + f.write("Benchmark Comparison\n") + f.write(f"{'=' * 60}\n\n") + f.write(header + "\n") + f.write(sep + "\n") + for row in rows: + f.write(row + "\n") + print(f"Comparison TXT: {txt_path}") + + return output_dir + + +def main(): + parser = argparse.ArgumentParser( + description="Unified benchmark for velocity prediction models.", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Examples:\n" + " # Single model evaluation\n" + " python -m benchmark.benchmark --checkpoint checkpoints/best.pt\n\n" + " # Compare all checkpoints in a directory\n" + " python -m benchmark.benchmark --compare checkpoints/\n\n" + " # Custom output directory\n" + " python -m benchmark.benchmark --checkpoint checkpoints/best.pt --output my_results/\n" + ), + ) + + # Mutually exclusive mode selection + mode = parser.add_mutually_exclusive_group(required=True) + mode.add_argument( + "--checkpoint", type=str, default=None, + help="Path to a single checkpoint .pt file for single-model evaluation.", + ) + mode.add_argument( + "--compare", type=str, default=None, + help="Directory containing multiple .pt checkpoints for comparison.", + ) + + # Optional overrides + parser.add_argument( + "--output", 
type=str, default=None, + help="Output directory for results (default: benchmark/results//).", + ) + parser.add_argument( + "--device", type=str, default=None, + help="Device override, e.g. 'cuda:0' or 'cpu' (default: auto-detect).", + ) + parser.add_argument( + "--seq-len", type=int, default=None, + help=f"Sequence length override (default: {eval_cfg.seq_len}).", + ) + parser.add_argument( + "--batch-size", type=int, default=None, + help=f"Batch size override (default: {eval_cfg.batch_size}).", + ) + parser.add_argument( + "--num-workers", type=int, default=None, + help=f"DataLoader workers override (default: {eval_cfg.num_workers}).", + ) + parser.add_argument( + "--pattern", type=str, default="*.pt", + help="Glob pattern for --compare mode (default: '*.pt').", + ) + parser.add_argument( + "--no-plots", action="store_true", + help="Skip generating per-scene plots.", + ) + + args = parser.parse_args() + + # Resolve device + device = None + if args.device is not None: + device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu") + + # Resolve output directory + output_dir = Path(args.output) if args.output else None + + if args.checkpoint: + ckpt_path = Path(args.checkpoint) + if not ckpt_path.exists(): + print(f"Error: checkpoint not found: {ckpt_path}") + sys.exit(1) + run_single_eval( + checkpoint_path=ckpt_path, + output_dir=output_dir, + device=device, + seq_len=args.seq_len, + batch_size=args.batch_size, + num_workers=args.num_workers, + save_plots=not args.no_plots, + ) + elif args.compare: + ckpt_dir = Path(args.compare) + if not ckpt_dir.is_dir(): + print(f"Error: checkpoint directory not found: {ckpt_dir}") + sys.exit(1) + run_compare( + checkpoint_dir=ckpt_dir, + output_dir=output_dir, + device=device, + seq_len=args.seq_len, + batch_size=args.batch_size, + num_workers=args.num_workers, + pattern=args.pattern, + ) + + +if __name__ == "__main__": + main() diff --git a/benchmark/config.py b/benchmark/config.py new 
file mode 100644 index 0000000..089b4c1 --- /dev/null +++ b/benchmark/config.py @@ -0,0 +1,98 @@ +""" +Benchmark configuration — evaluation-only scene splits and metric definitions. + +This config is independent from src.velocity_prediction.config so that +evaluation scenarios can be changed without touching training config. +""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Dict + + +# ──────────────────────────── Dataset root ──────────────────────────── + +DATASET_ROOT = Path(__file__).resolve().parents[1] / "dataset" + +# ──────────────────────────── Scene splits ──────────────────────────── + +# Each scene group has a name, a list of scene dirs, and a difficulty label. +# The test scenes are the primary evaluation set; val scenes are for +# checkpoint selection reference. + + +@dataclass +class SceneGroup: + name: str + scenes: List[str] + difficulty: str = "medium" # easy / medium / hard + + +# ── Validation scenes (for checkpoint selection reference) ── +VAL_SCENE_GROUPS: List[SceneGroup] = [ + SceneGroup("indoor_forward_7", ["indoor_forward_7"], "hard"), + SceneGroup("outdoor_forward_1", ["outdoor_forward_1"], "easy"), + # SceneGroup("indoor_forward_6", ["indoor_forward_6"], "medium"), + # SceneGroup("indoor_forward_9", ["indoor_forward_9"], "easy"), + # SceneGroup("indoor_forward_10", ["indoor_forward_10"], "easy"), + # SceneGroup("indoor_forward_5", ["indoor_forward_5"], "medium"), +] + +# ── Test scenes (primary evaluation) ── +TEST_SCENE_GROUPS: List[SceneGroup] = [ + SceneGroup("indoor_forward_7", ["indoor_forward_7"], "hard"), + SceneGroup("outdoor_forward_1", ["outdoor_forward_1"], "easy"), + SceneGroup("outdoor_forward_5", ["outdoor_forward_5"], "hard"), + SceneGroup("indoor_forward_6", ["indoor_forward_6"], "medium"), + SceneGroup("indoor_forward_9", ["indoor_forward_9"], "easy"), + SceneGroup("indoor_forward_10", ["indoor_forward_10"], "easy"), + SceneGroup("indoor_forward_5", ["indoor_forward_5"], 
"medium"), +] + +# Flat lists for convenience +VAL_SCENES: List[str] = [s for g in VAL_SCENE_GROUPS for s in g.scenes] +TEST_SCENES: List[str] = [s for g in TEST_SCENE_GROUPS for s in g.scenes] + +# Difficulty grouping +DIFFICULTY_GROUPS: Dict[str, List[str]] = {} +for g in TEST_SCENE_GROUPS: + DIFFICULTY_GROUPS.setdefault(g.difficulty, []).extend(g.scenes) + + +# ──────────────────────────── Evaluation parameters ──────────────────────────── + + +@dataclass +class EvalConfig: + """Parameters used when running evaluation.""" + + # Sequence length (must match what the model was trained with) + seq_len: int = 8 + + # Batch size for evaluation (can be larger than training) + batch_size: int = 64 + + # Data loading + num_workers: int = 2 + + # Event simulation (must match training config) + event_threshold: float = 0.1 + event_use_log: bool = True + + # Output directory (relative to benchmark/results/) + output_dir: str = "results" + + # Whether to generate per-scene plots + save_plots: bool = True + + # Device override (None = auto-detect) + device: str = "cuda" + + +# ──────────────────────────── Metrics definition ──────────────────────────── + +# Metrics computed per-axis and overall +METRICS = ["rmse", "mae", "r2"] + +# Singleton +eval_cfg = EvalConfig() diff --git a/benchmark/evaluate.py b/benchmark/evaluate.py new file mode 100644 index 0000000..e886771 --- /dev/null +++ b/benchmark/evaluate.py @@ -0,0 +1,458 @@ +""" +Core evaluation logic: run model on one or more scenes, compute metrics, +generate visualizations, and save structured results. + +This module is called by benchmark.py (the user-facing entry point). 
+""" + +import numpy as np +from pathlib import Path +from typing import List, Optional, Dict, Tuple + +import torch +import torch.nn as nn + +from src.velocity_prediction.model import VelocityPredictionModel +from src.velocity_prediction.dataset import create_val_loader +from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD + +from benchmark.config import ( + eval_cfg, + TEST_SCENE_GROUPS, + VAL_SCENE_GROUPS, + DIFFICULTY_GROUPS, + DATASET_ROOT, +) + + +# ──────────────────────────── Metrics ──────────────────────────── + + +def compute_metrics( + pred: np.ndarray, + target: np.ndarray, +) -> Dict[str, float]: + """ + Compute RMSE, MAE, R² for each axis and overall. + + Args: + pred: (N, 2) denormalized predictions + target: (N, 2) denormalized ground truth + + Returns: + dict with keys like rmse_vx, rmse_vy, rmse_xy, mae_vx, ... + """ + # Per-axis + rmse_x = float(np.sqrt(np.mean((pred[:, 0] - target[:, 0]) ** 2))) + rmse_y = float(np.sqrt(np.mean((pred[:, 1] - target[:, 1]) ** 2))) + rmse_xy = float(np.sqrt(np.mean(np.sum((pred - target) ** 2, axis=1)))) + + mae_x = float(np.mean(np.abs(pred[:, 0] - target[:, 0]))) + mae_y = float(np.mean(np.abs(pred[:, 1] - target[:, 1]))) + mae_xy = float(np.mean(np.sqrt(np.sum((pred - target) ** 2, axis=1)))) + + # R² per axis + def r2(p, t): + ss_res = np.sum((t - p) ** 2) + ss_tot = np.sum((t - np.mean(t)) ** 2) + return float(1 - ss_res / ss_tot) if ss_tot > 1e-12 else 0.0 + + r2_x = r2(pred[:, 0], target[:, 0]) + r2_y = r2(pred[:, 1], target[:, 1]) + + return { + "rmse_vx": rmse_x, + "rmse_vy": rmse_y, + "rmse_xy": rmse_xy, + "mae_vx": mae_x, + "mae_vy": mae_y, + "mae_xy": mae_xy, + "r2_vx": r2_x, + "r2_vy": r2_y, + "count": len(pred), + } + + +# ──────────────────────────── Per-scene evaluation ──────────────────────────── + + +@torch.no_grad() +def evaluate_scene( + model: nn.Module, + scene_names: List[str], + device: torch.device, + seq_len: int = 8, + batch_size: int = 64, + num_workers: int = 2, + 
event_threshold: float = 0.1, + event_use_log: bool = True, +) -> Dict: + """ + Evaluate model on one or more scenes. + + Returns: + dict with keys: + preds: (N, 2) denormalized predictions + targets: (N, 2) denormalized ground truth + metrics: dict of scalar metrics + """ + loader = create_val_loader( + scene_names=scene_names, + seq_len=seq_len, + batch_size=batch_size, + num_workers=num_workers, + event_threshold=event_threshold, + event_use_log=event_use_log, + ) + + model.eval() + all_preds = [] + all_targets = [] + + for batch in loader: + events = batch["events"].to(device) + tilt = batch["tilt"].to(device) + target = batch["v_body_target"].to(device) # (B, S, 2) normalized + + pred = model(events, tilt) # (B, 2) normalized + target_last = target[:, -1, :] # (B, 2) normalized + + all_preds.append(pred.cpu().numpy()) + all_targets.append(target_last.cpu().numpy()) + + if not all_preds: + return {"preds": np.zeros((0, 2)), "targets": np.zeros((0, 2)), "metrics": {}} + + preds = np.concatenate(all_preds, axis=0) + targets = np.concatenate(all_targets, axis=0) + + # Denormalize + mean = np.array(VELOCITY_MEAN, dtype=np.float32) + std = np.array(VELOCITY_STD, dtype=np.float32) + preds_denorm = preds * std + mean + targets_denorm = targets * std + mean + + metrics = compute_metrics(preds_denorm, targets_denorm) + + return { + "preds": preds_denorm, + "targets": targets_denorm, + "metrics": metrics, + } + + +# ──────────────────────────── Full evaluation suite ──────────────────────────── + + +def run_full_evaluation( + model: nn.Module, + device: torch.device, + seq_len: int = 8, + batch_size: int = 64, + num_workers: int = 2, + event_threshold: float = 0.1, + event_use_log: bool = True, + scene_groups=None, +) -> Dict: + """ + Run evaluation on all scene groups. + + Returns nested dict: + { + "global": { metrics... }, + "per_scene": { + "indoor_forward_7": { metrics..., "preds": ..., "targets": ... }, + ... + }, + "by_difficulty": { + "easy": { metrics... 
}, + "hard": { metrics... }, + } + } + """ + if scene_groups is None: + from benchmark.config import TEST_SCENE_GROUPS + scene_groups = TEST_SCENE_GROUPS + + per_scene = {} + all_preds = [] + all_targets = [] + + for group in scene_groups: + for scene_name in group.scenes: + result = evaluate_scene( + model, [scene_name], device, + seq_len=seq_len, batch_size=batch_size, + num_workers=num_workers, + event_threshold=event_threshold, + event_use_log=event_use_log, + ) + per_scene[scene_name] = result + if result["preds"].shape[0] > 0: + all_preds.append(result["preds"]) + all_targets.append(result["targets"]) + + # Global metrics (all scenes combined) + if all_preds: + global_preds = np.concatenate(all_preds, axis=0) + global_targets = np.concatenate(all_targets, axis=0) + global_metrics = compute_metrics(global_preds, global_targets) + else: + global_preds = np.zeros((0, 2)) + global_targets = np.zeros((0, 2)) + global_metrics = {} + + # By difficulty + by_difficulty = {} + for diff, scenes in DIFFICULTY_GROUPS.items(): + diff_preds = [] + diff_targets = [] + for s in scenes: + if s in per_scene and per_scene[s]["preds"].shape[0] > 0: + diff_preds.append(per_scene[s]["preds"]) + diff_targets.append(per_scene[s]["targets"]) + if diff_preds: + by_difficulty[diff] = compute_metrics( + np.concatenate(diff_preds, axis=0), + np.concatenate(diff_targets, axis=0), + ) + else: + by_difficulty[diff] = {} + + return { + "global": global_metrics, + "per_scene": per_scene, + "by_difficulty": by_difficulty, + } + + +# ──────────────────────────── Visualization ──────────────────────────── + + +def plot_scene_comparison( + preds: np.ndarray, + targets: np.ndarray, + scene_name: str, + save_dir: Path, + metrics: Optional[Dict] = None, +): + """Generate time-series and scatter plots for a single scene.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + time = 
np.arange(len(preds)) + + # ── Time-series plot ── + fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True) + + axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.7, linewidth=0.8) + axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.7, linewidth=0.8) + axes[0].set_ylabel("vx (m/s)") + axes[0].legend(fontsize=9) + axes[0].grid(True, alpha=0.3) + + axes[1].plot(time, targets[:, 1], label="GT vy", color="C0", alpha=0.7, linewidth=0.8) + axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.7, linewidth=0.8) + axes[1].set_ylabel("vy (m/s)") + axes[1].set_xlabel("Frame index") + axes[1].legend(fontsize=9) + axes[1].grid(True, alpha=0.3) + + title = f"Body-frame Velocity — {scene_name}" + if metrics: + title += f" | RMSE vx={metrics['rmse_vx']:.3f} vy={metrics['rmse_vy']:.3f}" + fig.suptitle(title) + try: + plt.tight_layout() + plt.savefig(save_dir / f"{scene_name}_timeseries.png", dpi=150, bbox_inches="tight") + except Exception as e: + print(f" [WARN] Failed to save {scene_name}_timeseries.png: {e}") + plt.close() + + # ── Scatter plot ── + fig, axes = plt.subplots(1, 2, figsize=(10, 4.5)) + + for ax, pred, target, label in zip( + axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"] + ): + ax.scatter(target, pred, s=3, alpha=0.4, c="C1", edgecolors="none") + lim_min = min(target.min(), pred.min()) + lim_max = max(target.max(), pred.max()) + margin = (lim_max - lim_min) * 0.05 + ax.plot([lim_min - margin, lim_max + margin], + [lim_min - margin, lim_max + margin], "r--", alpha=0.5, linewidth=1) + ax.set_xlabel(f"GT {label} (m/s)") + ax.set_ylabel(f"Pred {label} (m/s)") + ax.set_aspect("equal") + ax.grid(True, alpha=0.3) + if metrics: + ax.set_title(f"{label} — RMSE: {metrics[f'rmse_{label}']:.4f}") + + fig.suptitle(f"Scatter — {scene_name}") + try: + plt.tight_layout() + plt.savefig(save_dir / f"{scene_name}_scatter.png", dpi=150, bbox_inches="tight") + except Exception as e: + print(f" 
[WARN] Failed to save {scene_name}_scatter.png: {e}") + plt.close() + + +def plot_global_comparison( + results: Dict, + save_dir: Path, +): + """Generate a summary figure comparing all scenes.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + scenes = list(results["per_scene"].keys()) + if not scenes: + return + + # Bar chart: RMSE vx and vy per scene + rmse_vx = [results["per_scene"][s]["metrics"].get("rmse_vx", 0) for s in scenes] + rmse_vy = [results["per_scene"][s]["metrics"].get("rmse_vy", 0) for s in scenes] + rmse_xy = [results["per_scene"][s]["metrics"].get("rmse_xy", 0) for s in scenes] + + x = np.arange(len(scenes)) + width = 0.25 + + fig, ax = plt.subplots(figsize=(10, 4.5)) + bars1 = ax.bar(x - width, rmse_vx, width, label="RMSE vx", alpha=0.8) + bars2 = ax.bar(x, rmse_vy, width, label="RMSE vy", alpha=0.8) + bars3 = ax.bar(x + width, rmse_xy, width, label="RMSE xy", alpha=0.8) + + ax.set_xticks(x) + ax.set_xticklabels(scenes, rotation=15, ha="right") + ax.set_ylabel("RMSE (m/s)") + ax.set_title("Per-Scene RMSE Comparison") + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3, axis="y") + + # Annotate values + for bars in [bars1, bars2, bars3]: + for bar in bars: + height = bar.get_height() + ax.annotate(f"{height:.3f}", + xy=(bar.get_x() + bar.get_width() / 2, height), + xytext=(0, 2), textcoords="offset points", + ha="center", va="bottom", fontsize=7) + + try: + plt.tight_layout() + plt.savefig(save_dir / "per_scene_rmse.png", dpi=150, bbox_inches="tight") + except Exception as e: + print(f" [WARN] Failed to save per_scene_rmse.png: {e}") + plt.close() + + +# ──────────────────────────── Results serialization ──────────────────────────── + + +def save_results( + results: Dict, + save_dir: Path, + checkpoint_name: str = "model", +): + """Save all evaluation results to disk.""" + save_dir = Path(save_dir) + plots_dir = save_dir / "plots" + 
save_dir.mkdir(parents=True, exist_ok=True) + plots_dir.mkdir(parents=True, exist_ok=True) + + # ── 1. Global metrics ── + global_m = results["global"] + lines = [ + f"Benchmark Results: {checkpoint_name}", + f"{'=' * 50}", + f"Total samples: {global_m.get('count', '?')}", + "", + "── Global Metrics ──", + f" RMSE vx: {global_m.get('rmse_vx', 'N/A'):.4f} m/s", + f" RMSE vy: {global_m.get('rmse_vy', 'N/A'):.4f} m/s", + f" RMSE xy: {global_m.get('rmse_xy', 'N/A'):.4f} m/s", + f" MAE vx: {global_m.get('mae_vx', 'N/A'):.4f} m/s", + f" MAE vy: {global_m.get('mae_vy', 'N/A'):.4f} m/s", + f" MAE xy: {global_m.get('mae_xy', 'N/A'):.4f} m/s", + f" R² vx: {global_m.get('r2_vx', 'N/A'):.4f}", + f" R² vy: {global_m.get('r2_vy', 'N/A'):.4f}", + "", + ] + + # ── 2. Per-scene metrics ── + lines.append("── Per-Scene Metrics ──") + lines.append(f" {'Scene':<22} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} " + f"{'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8} {'Samples':>8}") + lines.append(" " + "-" * 96) + + for scene_name, scene_result in results["per_scene"].items(): + m = scene_result["metrics"] + lines.append( + f" {scene_name:<22} {m.get('rmse_vx', 0):>10.4f} {m.get('rmse_vy', 0):>10.4f} " + f"{m.get('rmse_xy', 0):>10.4f} {m.get('mae_vx', 0):>10.4f} " + f"{m.get('mae_vy', 0):>10.4f} {m.get('r2_vx', 0):>8.4f} " + f"{m.get('r2_vy', 0):>8.4f} {m.get('count', 0):>8}" + ) + + lines.append("") + + # ── 3. By difficulty ── + lines.append("── By Difficulty ──") + for diff, metrics in results["by_difficulty"].items(): + lines.append(f" {diff:<10} RMSE vx={metrics.get('rmse_vx', 0):.4f} " + f"RMSE vy={metrics.get('rmse_vy', 0):.4f} " + f"RMSE xy={metrics.get('rmse_xy', 0):.4f} " + f"(samples={metrics.get('count', 0)})") + + summary_text = "\n".join(lines) + with open(save_dir / "summary.txt", "w") as f: + f.write(summary_text) + print(summary_text) + + # ── 4. 
CSV: per-scene metrics ── + import csv + csv_path = save_dir / "per_scene_metrics.csv" + fieldnames = ["scene", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy", + "mae_xy", "r2_vx", "r2_vy", "count"] + with open(csv_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for scene_name, scene_result in results["per_scene"].items(): + row = {"scene": scene_name, **scene_result["metrics"]} + writer.writerow(row) + print(f"Per-scene CSV: {csv_path}") + + # ── 5. Global metrics CSV (single row) ── + csv_path_global = save_dir / "metrics.csv" + with open(csv_path_global, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=["checkpoint"] + list(global_m.keys())) + writer.writeheader() + row = {"checkpoint": checkpoint_name, **global_m} + writer.writerow(row) + print(f"Global metrics CSV: {csv_path_global}") + + # ── 6. Plots ── + for scene_name, scene_result in results["per_scene"].items(): + if scene_result["preds"].shape[0] > 0: + plot_scene_comparison( + scene_result["preds"], + scene_result["targets"], + scene_name, + plots_dir, + metrics=scene_result["metrics"], + ) + + plot_global_comparison(results, plots_dir) + + import matplotlib.pyplot as plt + plt.close("all") + + print(f"\nAll results saved to: {save_dir.resolve()}") diff --git a/download_davis_gt_rosbags.sh b/download_davis_gt_rosbags.sh new file mode 100755 index 0000000..fd0cd1e --- /dev/null +++ b/download_davis_gt_rosbags.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +# Download script for UZH FPV DAVIS rosbags with ground truth +# Skips files that already exist in the current directory. +# Uses the final download.ifi.uzh.ch URLs directly (avoids 301 redirect chain). 
# Strict mode: abort on any failing command, unset variable, or failed
# pipeline stage.
set -euo pipefail

# Final download host. Using it directly avoids the 301 redirect chain.
readonly BASE="https://download.ifi.uzh.ch/rpg/web/datasets/uzh-fpv-newer-versions/v3"

# DAVIS sequences that are published together with ground truth.
URLS=(
  # Indoor forward facing
  "${BASE}/indoor_forward_3_davis_with_gt.bag"
  "${BASE}/indoor_forward_5_davis_with_gt.bag"
  "${BASE}/indoor_forward_6_davis_with_gt.bag"
  "${BASE}/indoor_forward_9_davis_with_gt.bag"
  "${BASE}/indoor_forward_10_davis_with_gt.bag"

  # Indoor 45 degree downward
  "${BASE}/indoor_45_2_davis_with_gt.bag"
  "${BASE}/indoor_45_4_davis_with_gt.bag"
  "${BASE}/indoor_45_9_davis_with_gt.bag"
  "${BASE}/indoor_45_12_davis_with_gt.bag"
  "${BASE}/indoor_45_13_davis_with_gt.bag"
  "${BASE}/indoor_45_14_davis_with_gt.bag"

  # Outdoor forward facing
  "${BASE}/outdoor_forward_1_davis_with_gt.bag"
  "${BASE}/outdoor_forward_5_davis_with_gt.bag"

  # Outdoor 45 degree downward
  "${BASE}/outdoor_45_1_davis_with_gt.bag"
)
readonly URLS

# Fail early with a clear message instead of a bare "command not found".
if ! command -v wget >/dev/null; then
  echo "ERROR: wget is required but was not found in PATH" >&2
  exit 1
fi

for url in "${URLS[@]}"; do
  filename=${url##*/}          # last path component of the URL
  partial="${filename}.part"   # in-progress download, resumable

  if [[ -f "$filename" ]]; then
    echo "SKIP: $filename already exists"
    continue
  fi

  echo "DOWNLOAD: $filename"
  # Download into a ".part" file and rename only after wget succeeds.
  # Writing straight to the final name would leave a truncated .bag behind
  # when a transfer is interrupted (set -e exits here), and the existence
  # check above would then skip it forever. With the ".part" file, a re-run
  # resumes the transfer through --continue instead.
  wget --continue --show-progress "$url" -O "$partial"
  mv -- "$partial" "$filename"
done

echo ""
echo "Done. All DAVIS with-GT rosbags downloaded."
diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0871005 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +# ===== Core ===== +numpy +scipy + +# ===== PyTorch (CUDA 12.4) ===== +# Install via: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 +torch>=2.4 +torchvision + +# ===== Data loading ===== +webdataset +opencv-python + +# ===== Training / logging ===== +tensorboard +tqdm +matplotlib diff --git a/rosbag2wds.py b/rosbag2wds.py new file mode 100644 index 0000000..7a012a1 --- /dev/null +++ b/rosbag2wds.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +""" +ROS bag to WebDataset converter for DAVIS dataset +Extracts: grayscale images, IMU sequence, ground truth poses and velocities + +Usage: + python convert_bag_to_webdataset.py --bag --output --name +""" + +import argparse +import json +import os +import sys +from pathlib import Path +from typing import Dict, List, Tuple, Optional +import numpy as np +import rosbag +from cv_bridge import CvBridge +import cv2 +import webdataset as wds +from tqdm import tqdm +from scipy.spatial.transform import Rotation as R + + +class BagToWebDataset: + def __init__(self, bag_path: str, output_dir: str, dataset_name: str, + shard_size: int = 2000, image_width: int = 320, image_height: int = 240): + self.bag_path = Path(bag_path) + self.output_dir = Path(output_dir) / dataset_name + self.dataset_name = dataset_name + self.shard_size = shard_size + self.image_width = image_width + self.image_height = image_height + + self.bridge = CvBridge() + + # Data containers + self.images: List[Tuple[float, np.ndarray]] = [] # (timestamp, image) + self.imu_timestamps: List[float] = [] + self.imu_acc: List[np.ndarray] = [] # (ax, ay, az) + self.imu_gyro: List[np.ndarray] = [] # (gx, gy, gz) + self.gt_timestamps: List[float] = [] + self.gt_poses: List[np.ndarray] = [] # (x, y, z, qx, qy, qz, qw) + self.gt_velocities: List[np.ndarray] = [] # (vx, vy, vz, wx, wy, wz) + + def 
extract_all_data(self): + """Extract all data from ROS bag""" + print(f"Opening bag: {self.bag_path}") + bag = rosbag.Bag(str(self.bag_path), 'r') + + # Count messages for progress bar + topic_counts = {topic: bag.get_message_count(topic) for topic in + ['/dvs/image_raw', '/dvs/imu', '/groundtruth/odometry']} + total_msgs = sum(topic_counts.values()) + + print(f"Topics: {topic_counts}") + + with tqdm(total=total_msgs, desc="Extracting messages") as pbar: + for topic, msg, t in bag.read_messages(topics=['/dvs/image_raw', '/dvs/imu', '/groundtruth/odometry']): + + if topic == '/dvs/image_raw': + self._process_image(msg, t) # 传入 t + + elif topic == '/dvs/imu': + self._process_imu(msg, t) # 传入 t + + elif topic == '/groundtruth/odometry': + self._process_odometry(msg, t) # 传入 t + + pbar.update(1) + + bag.close() + + # Post-processing: compute velocities from poses if not directly available + self._ensure_velocities() + + # Print statistics + print(f"\nExtraction completed:") + print(f" Images: {len(self.images)}") + print(f" IMU messages: {len(self.imu_timestamps)}") + print(f" Ground truth poses: {len(self.gt_timestamps)}") + print(f" Ground truth velocities: {len(self.gt_velocities)}") + + def crop_to_gt_time_range(self): + """裁剪所有数据,只保留 GT 时间范围内的部分""" + + if len(self.gt_timestamps) == 0: + print("Warning: No GT data found, skipping crop") + return + + gt_start = min(self.gt_timestamps) + gt_end = max(self.gt_timestamps) + + print(f"\nCropping to GT time range: {gt_start:.3f} - {gt_end:.3f} ({gt_end - gt_start:.1f}s)") + + # 裁剪图像 + original_img_count = len(self.images) + self.images = [(ts, img) for ts, img in self.images if gt_start <= ts <= gt_end] + print(f" Images: {original_img_count} -> {len(self.images)}") + + # 裁剪 IMU + original_imu_count = len(self.imu_timestamps) + imu_filtered = [(ts, acc, gyro) for ts, acc, gyro + in zip(self.imu_timestamps, self.imu_acc, self.imu_gyro) + if gt_start <= ts <= gt_end] + + if imu_filtered: + self.imu_timestamps = [item[0] 
for item in imu_filtered] + self.imu_acc = [item[1] for item in imu_filtered] + self.imu_gyro = [item[2] for item in imu_filtered] + print(f" IMU: {original_imu_count} -> {len(self.imu_timestamps)}") + + # GT 数据本身已经在范围内,不需要裁剪 + print(f" GT: {len(self.gt_timestamps)} (unchanged)") + + def _process_image(self, msg, t): + """Process grayscale image message using system time""" + try: + # Convert ROS image to OpenCV format + cv_img = self.bridge.imgmsg_to_cv2(msg, desired_encoding='mono8') + + # Resize if needed + if self.image_width and self.image_height: + cv_img = cv2.resize(cv_img, (self.image_width, self.image_height), + interpolation=cv2.INTER_LINEAR) + + # 使用系统物理时间,而不是 msg.header.stamp + timestamp = t.to_sec() + self.images.append((timestamp, cv_img)) + + except Exception as e: + print(f"Error processing image: {e}") + + def _process_imu(self, msg, t): + """Process IMU message using system time""" + timestamp = t.to_sec() # 使用系统物理时间 + + # Linear acceleration (m/s^2) + acc = np.array([msg.linear_acceleration.x, + msg.linear_acceleration.y, + msg.linear_acceleration.z], dtype=np.float32) + + # Angular velocity (rad/s) + gyro = np.array([msg.angular_velocity.x, + msg.angular_velocity.y, + msg.angular_velocity.z], dtype=np.float32) + + self.imu_timestamps.append(timestamp) + self.imu_acc.append(acc) + self.imu_gyro.append(gyro) + + def _process_odometry(self, msg, t): + """Process ground truth odometry using system time""" + timestamp = t.to_sec() # 使用系统物理时间 + + # Position (x, y, z) + pos = np.array([msg.pose.pose.position.x, + msg.pose.pose.position.y, + msg.pose.pose.position.z], dtype=np.float32) + + # Orientation (qx, qy, qz, qw) - already normalized + quat = np.array([msg.pose.pose.orientation.x, + msg.pose.pose.orientation.y, + msg.pose.pose.orientation.z, + msg.pose.pose.orientation.w], dtype=np.float32) + + pose = np.concatenate([pos, quat]) + self.gt_timestamps.append(timestamp) + self.gt_poses.append(pose) + + # Velocity: always compute from pose 
differences in post-processing + vel = None + + self.gt_velocities.append(vel) + + def _ensure_velocities(self): + # 数据集中 twist 数据为 0 直接利用时间戳差值 + # """Compute velocities from pose differences if not directly available""" + # # Check if velocities are missing + # missing_velocities = any(v is None for v in self.gt_velocities) + + # if not missing_velocities: + # return + + print("Computing velocities from pose differences...") + + computed_velocities = [] + for i in range(len(self.gt_timestamps)): + if i == 0: + # Use forward difference for first frame + if len(self.gt_timestamps) > 1: + dt = self.gt_timestamps[1] - self.gt_timestamps[0] + if dt > 0: + # Linear velocity + v_lin = (self.gt_poses[1][:3] - self.gt_poses[0][:3]) / dt + # Angular velocity (from quaternion difference) + q0 = self.gt_poses[0][3:7] + q1 = self.gt_poses[1][3:7] + dq = R.from_quat(q1) * R.from_quat(q0).inv() + v_ang = dq.as_rotvec() / dt + computed_velocities.append(np.concatenate([v_lin, v_ang])) + else: + computed_velocities.append(np.zeros(6, dtype=np.float32)) + else: + computed_velocities.append(np.zeros(6, dtype=np.float32)) + else: + # Use backward difference + dt = self.gt_timestamps[i] - self.gt_timestamps[i-1] + if dt > 0: + v_lin = (self.gt_poses[i][:3] - self.gt_poses[i-1][:3]) / dt + q0 = self.gt_poses[i-1][3:7] + q1 = self.gt_poses[i][3:7] + dq = R.from_quat(q1) * R.from_quat(q0).inv() + v_ang = dq.as_rotvec() / dt + computed_velocities.append(np.concatenate([v_lin, v_ang])) + else: + computed_velocities.append(np.zeros(6, dtype=np.float32)) + + # Replace missing velocities + for i in range(len(self.gt_velocities)): + if self.gt_velocities[i] is None: + self.gt_velocities[i] = computed_velocities[i] + + def save_imu_sequence(self): + """Save IMU sequence as NPZ file""" + imu_data = { + 'timestamps': np.array(self.imu_timestamps, dtype=np.float64), + 'accelerations': np.array(self.imu_acc, dtype=np.float32), + 'angular_velocities': np.array(self.imu_gyro, dtype=np.float32) + } + 
+ imu_path = self.output_dir / 'imu_sequence.npz' + imu_path.parent.mkdir(parents=True, exist_ok=True) + np.savez_compressed(imu_path, **imu_data) + + print(f"Saved IMU sequence: {imu_path}") + return imu_path + + def align_ground_truth_to_images(self) -> List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]: + """Align ground truth (pose + velocity) to each image using nearest timestamp""" + aligned_gt = [] + + gt_timestamps = np.array(self.gt_timestamps) + gt_poses = np.array(self.gt_poses) + gt_vels = np.array(self.gt_velocities) + + for img_ts, img in tqdm(self.images, desc="Aligning ground truth to images"): + idx = np.argmin(np.abs(gt_timestamps - img_ts)) + time_diff = abs(gt_timestamps[idx] - img_ts) + + if time_diff < 0.1: + aligned_gt.append((img_ts, img, gt_poses[idx], gt_vels[idx])) # 保存图像 + + return aligned_gt + + def save_as_webdataset(self, aligned_gt: List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]): + """Save images and aligned ground truth as WebDataset tar files""" + + num_shards = (len(aligned_gt) + self.shard_size - 1) // self.shard_size + + print(f"Saving {len(aligned_gt)} samples into {num_shards} shards...") + + for shard_idx in range(num_shards): + start_idx = shard_idx * self.shard_size + end_idx = min((shard_idx + 1) * self.shard_size, len(aligned_gt)) + + tar_path = self.output_dir / f'shard_{shard_idx:04d}.tar' + + with wds.TarWriter(str(tar_path)) as sink: + for local_idx, (img_ts, img, pose, vel) in enumerate(aligned_gt): + + # Encode image as JPEG + _, img_encoded = cv2.imencode('.jpg', img, + [cv2.IMWRITE_JPEG_QUALITY, 85]) + img_bytes = img_encoded.tobytes() + + # Prepare metadata + sample_key = f'frame_{local_idx:08d}' + + # Write to tar + sink.write({ + '__key__': sample_key, + 'jpg': img_bytes, + 'ts': np.array([img_ts], dtype=np.float64).tobytes(), + 'pose': pose.astype(np.float32).tobytes(), + 'vel': vel.astype(np.float32).tobytes() + }) + + print(f" Saved {tar_path} ({end_idx - start_idx} samples)") + + def 
save_metadata(self): + """Save dataset metadata""" + metadata = { + 'dataset_name': self.dataset_name, + 'source_bag': str(self.bag_path), + 'num_images': len(self.images), + 'num_imu_messages': len(self.imu_timestamps), + 'num_ground_truth': len(self.gt_timestamps), + 'image_size': [self.image_width, self.image_height], + 'imu_frequency_hz': len(self.imu_timestamps) / (self.imu_timestamps[-1] - self.imu_timestamps[0]) if len(self.imu_timestamps) > 1 else 0, + 'camera_frequency_hz': len(self.images) / (self.images[-1][0] - self.images[0][0]) if len(self.images) > 1 else 0, + 'gt_frequency_hz': len(self.gt_timestamps) / (self.gt_timestamps[-1] - self.gt_timestamps[0]) if len(self.gt_timestamps) > 1 else 0, + 'coordinate_system': 'horizontal (z aligned with gravity, assumed from GT)', + 'velocity_dimensions': 6, # (vx, vy, vz, wx, wy, wz) + } + + metadata_path = self.output_dir / 'metadata.json' + with open(metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + + print(f"Saved metadata: {metadata_path}") + + def convert(self): + """Main conversion pipeline""" + print(f"\n{'='*60}") + print(f"Converting: {self.bag_path.name}") + print(f"Output: {self.output_dir}") + print(f"{'='*60}\n") + + # Step 1: Extract all data from bag + self.extract_all_data() + + # 裁剪掉无 GT 的时间段 + self.crop_to_gt_time_range() + + # Step 2: Save IMU sequence + self.save_imu_sequence() + + # # Step 3: Align ground truth to images + aligned_gt = self.align_ground_truth_to_images() + + if len(aligned_gt) == 0: + print("Error: No aligned ground truth found!") + sys.exit(1) + + # # Step 4: Save as WebDataset + self.save_as_webdataset(aligned_gt) + + # # Step 5: Save metadata + self.save_metadata() + + self.diagnose_timestamps() + + print(f"\n✅ Conversion completed for {self.bag_path.name}") + + + + def diagnose_timestamps(self): + """Print timestamp ranges for debugging""" + img_timestamps = [t for t, _ in self.images] + gt_timestamps = self.gt_timestamps + + print(f"Image timestamps: 
{min(img_timestamps):.3f} - {max(img_timestamps):.3f}") + print(f"GT timestamps: {min(gt_timestamps):.3f} - {max(gt_timestamps):.3f}") + print(f"Image duration: {max(img_timestamps) - min(img_timestamps):.3f}s") + print(f"GT duration: {max(gt_timestamps) - min(gt_timestamps):.3f}s") + + # Check if there's a constant offset + if len(img_timestamps) > 0 and len(gt_timestamps) > 0: + offset = gt_timestamps[0] - img_timestamps[0] + print(f"Initial offset (first GT - first image): {offset:.3f}s") + +def main(): + parser = argparse.ArgumentParser(description='Convert ROS bag to WebDataset format') + parser.add_argument('--bag', type=str, required=True, help='Path to ROS bag file') + parser.add_argument('--output', type=str, default='./dataset', help='Output directory') + parser.add_argument('--name', type=str, default=None, help='Dataset name (default: bag filename without extension)') + parser.add_argument('--shard_size', type=int, default=2000, help='Number of samples per shard') + parser.add_argument('--width', type=int, default=320, help='Image width (resize)') + parser.add_argument('--height', type=int, default=240, help='Image height (resize)') + + args = parser.parse_args() + + # Validate inputs + if not os.path.exists(args.bag): + print(f"Error: Bag file not found: {args.bag}") + sys.exit(1) + + # Set dataset name + if args.name is None: + args.name = Path(args.bag).stem + + # Run conversion + converter = BagToWebDataset( + bag_path=args.bag, + output_dir=args.output, + dataset_name=args.name, + shard_size=args.shard_size, + image_width=args.width, + image_height=args.height + ) + + converter.convert() + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/src/event_utils.py b/src/event_utils.py new file mode 100644 index 0000000..f37c69c --- /dev/null +++ b/src/event_utils.py @@ -0,0 +1,103 @@ +""" +Event camera simulation utilities for ML preprocessing. + +Core logic extracted from EventCameraSimulator (test.py). 
+Designed for frame-by-frame preprocessing in training pipelines. + +Output: + events_binary: (-1, 0, +1) hard threshold decision + events_strength: [-1, 1] continuous change intensity (clipped & normalized) +""" + +import numpy as np +from collections import deque + + +class EventProcessor: + """Lightweight event computation module. No visualization, no OpenCV dependency.""" + + def __init__(self, threshold=0.1, use_log=True, auto_threshold=False): + self.threshold = threshold + self.use_log = use_log + self.auto_threshold = auto_threshold + + self.prev_brightness = None + self.change_history = deque(maxlen=100) + self.threshold_scale = 1.5 + + def reset(self): + """Clear temporal state (call on video/reset).""" + self.prev_brightness = None + self.change_history.clear() + + def _to_grayscale(self, frame): + """Convert frame to grayscale float32.""" + if frame.ndim == 3: + # RGB/HWC -> gray via luminance weights + gray = 0.299 * frame[..., 0] + 0.587 * frame[..., 1] + 0.114 * frame[..., 2] + else: + gray = frame + return gray.astype(np.float32) + + def _compute_change(self, brightness): + """Compute log or linear brightness change.""" + if self.use_log: + eps = 1e-3 + return np.log(brightness + eps) - np.log(self.prev_brightness + eps) + else: + return brightness - self.prev_brightness + + def _update_auto_threshold(self, change): + """Adapt threshold based on global change statistics.""" + abs_change = np.abs(change) + mean_change = np.mean(abs_change) + self.change_history.append(mean_change) + + if len(self.change_history) > 10: + avg_change = np.mean(self.change_history) + new_threshold = max(avg_change * self.threshold_scale, 0.01) + self.threshold = self.threshold * 0.9 + new_threshold * 0.1 + + if self.use_log: + self.threshold = np.clip(self.threshold, 0.01, 0.5) + else: + self.threshold = np.clip(self.threshold, 1, 50) + + def __call__(self, frame): + """ + Process a single frame. + + Args: + frame: np.ndarray, shape (H, W) or (H, W, C), uint8 or float. 
+ + Returns: + events_binary: np.ndarray (H, W), values in {-1, 0, +1} + events_strength: np.ndarray (H, W), values in [-1, 1] + event_count: int, number of non-zero events + """ + brightness = self._to_grayscale(frame) + + # First frame — initialise, no events + if self.prev_brightness is None: + self.prev_brightness = brightness + h, w = brightness.shape + return np.zeros((h, w), dtype=np.int8), np.zeros((h, w), dtype=np.float32), 0 + + change = self._compute_change(brightness) + + if self.auto_threshold: + self._update_auto_threshold(change) + + # Binary events + events_binary = np.zeros_like(brightness, dtype=np.int8) + events_binary[change > self.threshold] = 1 + events_binary[change < -self.threshold] = -1 + + # Continuous strength: clip to [-threshold, threshold] then normalise to [-1, 1] + events_strength = np.clip(change, -self.threshold, self.threshold) / self.threshold + + event_count = int(np.count_nonzero(events_binary)) + + self.prev_brightness = brightness + + return events_binary, events_strength, event_count diff --git a/src/velocity_prediction/README.md b/src/velocity_prediction/README.md new file mode 100644 index 0000000..6de1b55 --- /dev/null +++ b/src/velocity_prediction/README.md @@ -0,0 +1,198 @@ +# Velocity Prediction from Event Frames + Attitude + +基于 UZH-FPV 数据集,通过模拟事件帧 + 姿态输入预测机体速度(机体系 vx, vy)。 + +--- + +## 项目结构 + +``` +uzh_fpv/ +├── dataset/ # UZH-FPV 数据集(WebDataset shards) +│ ├── indoor_forward_3/ +│ ├── indoor_forward_5/ +│ ├── ... 
+│ └── outdoor_45_1/ +├── src/ +│ ├── event_utils.py # 模拟事件帧生成(已有模块) +│ └── velocity_prediction/ # 本工程 +│ ├── __init__.py +│ ├── config.py # 全局配置 +│ ├── utils.py # 四元数运算、坐标变换 +│ ├── transforms.py # 数据预处理管线 +│ ├── dataset.py # WebDataset 加载 + 序列采样 +│ ├── model.py # 网络模型定义 +│ ├── train.py # 训练入口 +│ └── evaluate.py # 评估与可视化 +├── DATASET_FORMAT.md # 数据集格式说明 +└── requirements.txt +``` + +--- + +## 网络架构 + +``` +事件帧 (1, 240, 320) ──► CNN (4层 Conv+Pool+GAP, 256-d) + │ +姿态 tilt_angles (3,) ──► PoseMLP (3→32→64, 64-d) ───┤ + │ + concat (320-d) ← 每帧融合 + │ + GRU (hidden=128) + │ + Head MLP (128→64→2) + │ + [vx_body, vy_body] +``` + +### 各模块参数 + +| 模块 | 参数量 | 说明 | +|------|--------|------| +| CNN Encoder | 387,840 | 4 层 Conv2D(3×3) + BN + ReLU + MaxPool(2×2) + GAP,通道 1→32→64→128→256 | +| PoseMLP | 2,240 | 3→32→64,两层全连接 | +| GRU | 172,800 | 单层,input=320, hidden=128 | +| Head MLP | 8,386 | 128→64→2 | +| **总计** | **~571K** | FP32 约 2.3 MB | + +--- + +## 数据预处理 + +### 输入变换管线(transforms.py) + +每个 WebDataset 样本依次经过: + +1. **DecodeSample** — JPEG 解码为灰度图 (H, W),pose/vel 字节转 numpy +2. **SimulateEvents** — `EventProcessor` 计算帧间亮度变化,输出二值事件帧 (1, H, W),值域 {-1, 0, 1} +3. **ComputeTilt** — 从四元数 `[qx, qy, qz, qw]` 中提取偏航角 yaw,移除后得到 tilt 旋转向量 (3,) +4. 
**ComputeBodyVelocity** — 世界系速度 `[vx, vy, vz]` → 补偿偏航 → 转到机体系,取 `[vx_body, vy_body]` (2,) + +### 坐标系变换逻辑(utils.py) + +``` +输入: q_world_to_body (四元数), v_world (3,) + +Step 1: 从四元数分解偏航角 yaw +Step 2: 构造纯偏航四元数 q_yaw +Step 3: q_tilt = q_yaw^{-1} * q_world_to_body → tilt_angles (旋转向量) +Step 4: v_yaw_comp = q_yaw^{-1} * v_world → 偏航补偿 +Step 5: v_body = q_tilt^{-1} * v_yaw_comp → 转到机体系 +Step 6: 取 v_body[:2] 作为回归目标 +``` + +### 序列采样(dataset.py) + +- 从 shard 中取连续 `seq_len` 帧(默认 8 帧)构成一个训练样本 +- 输出 batch 维度:`(B, S, 1, H, W)` 事件帧, `(B, S, 3)` tilt, `(B, S, 2)` 速度 GT +- 模型预测最后一帧的速度 + +### 数据集划分 + +| 集 | 场景 | +|----|------| +| **训练** | indoor_forward_3/5/6/7/9/10, indoor_45_2/4/9/12 | +| **验证** | indoor_45_13/14 | +| **测试** | outdoor_forward_1/3/5, outdoor_45_1 | + +--- + +## 运行方式 + +### 1. 安装依赖 + +```bash +# 使用 uv(推荐) +uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +uv pip install webdataset opencv-python matplotlib tensorboard numpy scipy + +# 或使用 pip +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install webdataset opencv-python matplotlib tensorboard numpy scipy +``` + +### 2. 训练 + +```bash +# 从项目根目录执行 +cd /home/hexone/Workplace/ws_asmo/uzh_fpv + +# 激活虚拟环境后 +python -m src.velocity_prediction.train +``` + +训练参数在 `config.py` 的 `TrainConfig` 中配置,关键参数: + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| seq_len | 8 | 每序列帧数 | +| batch_size | 32 | 批次大小 | +| epochs | 100 | 训练轮数 | +| lr | 1e-3 | 学习率 | +| event_threshold | 0.1 | 事件模拟阈值 | + +训练输出: +- `logs/` — TensorBoard 日志 +- `checkpoints/` — 模型 checkpoint(每 10 轮保存 + 最优模型 `best.pt`) + +### 3. 评估 + +```bash +python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt +``` + +输出: +- 控制台打印 RMSE(vx, vy, xy) +- `eval_velocity.png` — 预测 vs GT 时序对比图 +- `eval_scatter.png` — 散点图 + +### 4. 
模型参数量检查 + +```bash +python -m src.velocity_prediction.model +``` + +输出类似: +``` +Total trainable parameters: 571,266 (0.571 M) +``` + +--- + +## 模型导出与部署 + +### 导出 TorchScript + +```python +import torch +from src.velocity_prediction.model import VelocityPredictionModel + +model = VelocityPredictionModel() +model.load_state_dict(torch.load("checkpoints/best.pt")["model_state_dict"]) +model.eval() + +# 导出 +traced = torch.jit.trace(model, (torch.randn(1, 8, 1, 240, 320), torch.randn(1, 8, 3))) +traced.save("velocity_model.pt") +``` + +### RV1106 部署注意事项 + +- 模型 ~0.57M 参数,FP32 ~2.3 MB,INT8 量化后 ~0.6 MB +- CNN 部分可跑 NPU(0.5 TOPS),GRU 需 ARM CPU 执行 +- 若需裁剪:`config.py` 中 `CNNConfig.channels` 减半(32→16, 64→32, 128→64, 256→128),或 `GRUConfig.hidden_size` 从 128 降至 64 + +--- + +## 文件职责速查 + +| 文件 | 职责 | 关键类/函数 | +|------|------|------------| +| `config.py` | 所有可配置参数 | `ModelConfig`, `TrainConfig`, `TRAIN_SCENES` | +| `utils.py` | 四元数运算、坐标变换 | `decompose_tilt`, `world_vel_to_body` | +| `transforms.py` | 数据预处理管线 | `DecodeSample`, `SimulateEvents`, `ComputeTilt`, `ComputeBodyVelocity` | +| `dataset.py` | WebDataset 加载 + 序列采样 | `create_train_loader`, `create_val_loader` | +| `model.py` | 网络模型 | `VelocityPredictionModel`, `CNNEncoder`, `PoseMLP` | +| `train.py` | 训练循环 | `train_one_epoch`, `validate`, `main` | +| `evaluate.py` | 评估与可视化 | `evaluate`, `plot_results`, `plot_scatter` | diff --git a/src/velocity_prediction/__init__.py b/src/velocity_prediction/__init__.py new file mode 100644 index 0000000..f9b7c7e --- /dev/null +++ b/src/velocity_prediction/__init__.py @@ -0,0 +1,7 @@ +""" +Velocity prediction from simulated event frames + attitude. 
+ +Pipeline: + Event frame (1, H, W) ──► CNN ──┐ + Tilt angles (3,) ──► MLP ──┤──► concat ──► GRU ──► Head ──► [vx_body, vy_body] +""" diff --git a/src/velocity_prediction/config.py b/src/velocity_prediction/config.py new file mode 100644 index 0000000..f994ddf --- /dev/null +++ b/src/velocity_prediction/config.py @@ -0,0 +1,118 @@ +""" +Global configuration for velocity prediction. +""" + +from dataclasses import dataclass, field +from pathlib import Path + + +# ──────────────────────────── Dataset paths ──────────────────────────── + +DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset" + +# Velocity normalization stats (computed from training set) +VELOCITY_MEAN = [0.859184, -0.783945] # [vx, vy] +VELOCITY_STD = [2.244513, 1.088335] # [vx, vy] + +# TRAIN_SCENES = [ +# "indoor_forward_3", "indoor_forward_5", "indoor_forward_6", +# "indoor_forward_7", "indoor_forward_9", "indoor_forward_10", +# "indoor_45_2", "indoor_45_4", "indoor_45_9", "indoor_45_12", +# ] +# VAL_SCENES = [ +# "indoor_45_13", "indoor_45_14", +# ] +# TEST_SCENES = [ +# "outdoor_forward_1", "outdoor_forward_3", "outdoor_forward_5", +# "outdoor_45_1", +# ] + +TRAIN_SCENES = [ + "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy + "indoor_forward_5", "indoor_forward_6", # Medium + "outdoor_forward_3" # Medium 室外 +] +VAL_SCENES = [ + "indoor_forward_7", # Hard 室内 + "outdoor_forward_1" # Easy 室外 + # "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy +] +TEST_SCENES = [ + "indoor_forward_7", # Hard 室内 + "outdoor_forward_1", # Easy 室外 + "outdoor_forward_5" # Hard 室外 + # "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy +] + + +# ──────────────────────────── Model architecture ──────────────────────────── + +@dataclass +class CNNConfig: + in_channels: int = 1 + channels: tuple = (32, 64, 128, 256) # per-layer output channels + kernel_size: int = 3 + pool_size: int = 2 + use_bn: bool = True + + +@dataclass +class PoseMLPConfig: + input_dim: 
int = 3 + hidden_dim: int = 32 + output_dim: int = 64 + + +@dataclass +class GRUConfig: + input_size: int = 320 # CNN(256) + PoseMLP(64) + hidden_size: int = 128 + num_layers: int = 1 + dropout: float = 0.0 + + +@dataclass +class HeadConfig: + input_dim: int = 128 # GRU hidden_size + hidden_dim: int = 64 + output_dim: int = 2 # [vx_body, vy_body] + + +@dataclass +class ModelConfig: + cnn: CNNConfig = field(default_factory=CNNConfig) + pose_mlp: PoseMLPConfig = field(default_factory=PoseMLPConfig) + gru: GRUConfig = field(default_factory=GRUConfig) + head: HeadConfig = field(default_factory=HeadConfig) + + +# ──────────────────────────── Training ──────────────────────────── + +@dataclass +class TrainConfig: + seq_len: int = 8 # frames per training sequence + batch_size: int = 32 + epochs: int = 100 + lr: float = 1e-3 + weight_decay: float = 1e-5 + lr_scheduler_step: int = 30 + lr_scheduler_gamma: float = 0.5 + num_workers: int = 4 + seed: int = 42 + + # Event simulation + event_threshold: float = 0.1 + event_use_log: bool = True + event_auto_threshold: bool = False + + # Logging / checkpoint + log_dir: str = "logs" + checkpoint_dir: str = "checkpoints" + log_interval: int = 10 + save_interval: int = 10 + + +# ──────────────────────────── Singleton instances ──────────────────────────── + +model_cfg = ModelConfig() +train_cfg = TrainConfig() diff --git a/src/velocity_prediction/dataset.py b/src/velocity_prediction/dataset.py new file mode 100644 index 0000000..88325c8 --- /dev/null +++ b/src/velocity_prediction/dataset.py @@ -0,0 +1,120 @@ +""" +WebDataset-based dataset with sequence sampling. + +Each sample from the dataset is a dict: + { + "events": np.ndarray (1, H, W), # simulated event frame + "tilt": np.ndarray (3,), # tilt rotation vector + "v_body_target": np.ndarray (2,), # body-frame [vx, vy] + } + +The dataset groups consecutive frames into sequences of length seq_len. 
+""" + +import webdataset as wds +from pathlib import Path +from typing import List, Callable, Optional + +from src.velocity_prediction.config import DATASET_ROOT, train_cfg +from src.velocity_prediction.transforms import build_train_transform, build_val_transform + + +def _scene_urls(scene_names: List[str], root: Path = DATASET_ROOT) -> List[str]: + """Build list of actual shard file paths for the given scene names.""" + urls = [] + for name in scene_names: + scene_dir = root / name + if not scene_dir.exists(): + print(f"Warning: scene directory not found: {scene_dir}") + continue + # List all shard_*.tar files in the scene directory + shard_files = sorted(scene_dir.glob("shard_*.tar")) + if not shard_files: + print(f"Warning: no shard files found in {scene_dir}") + continue + urls.extend(str(f) for f in shard_files) + return urls + + +def _build_pipeline( + urls: List[str], + transform: Callable, + seq_len: int, + shuffle: int = 1000, + deterministic: bool = False, +): + """ + Build a WebDataset pipeline that: + 1. Reads tar shards + 2. Decodes and transforms individual samples + 3. 
Groups consecutive samples into sequences + """ + dataset = wds.WebDataset(urls, shardshuffle=shuffle if not deterministic else 0, empty_check=False) + + if not deterministic: + dataset = dataset.shuffle(shuffle) + + dataset = dataset.decode().map(transform) + + # Group into sequences of seq_len consecutive frames + dataset = dataset.batched(seq_len, partial=False) + + return dataset + + +def create_train_loader( + scene_names: Optional[List[str]] = None, + seq_len: int = 8, + batch_size: int = 32, + num_workers: int = 4, + event_threshold: float = 0.1, + event_use_log: bool = True, +): + """Create a DataLoader for training.""" + if scene_names is None: + from src.velocity_prediction.config import TRAIN_SCENES + scene_names = TRAIN_SCENES + + urls = _scene_urls(scene_names) + transform = build_train_transform( + event_threshold=event_threshold, + event_use_log=event_use_log, + ) + pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=1000) + + loader = wds.WebLoader( + pipeline, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, # already shuffled in pipeline + ) + return loader + + +def create_val_loader( + scene_names: Optional[List[str]] = None, + seq_len: int = 8, + batch_size: int = 32, + num_workers: int = 4, + event_threshold: float = 0.1, + event_use_log: bool = True, +): + """Create a DataLoader for validation (deterministic order).""" + if scene_names is None: + from src.velocity_prediction.config import VAL_SCENES + scene_names = VAL_SCENES + + urls = _scene_urls(scene_names) + transform = build_val_transform( + event_threshold=event_threshold, + event_use_log=event_use_log, + ) + pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=0, deterministic=True) + + loader = wds.WebLoader( + pipeline, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + ) + return loader diff --git a/src/velocity_prediction/evaluate.py b/src/velocity_prediction/evaluate.py new file mode 100644 index 
0000000..d8b515d --- /dev/null +++ b/src/velocity_prediction/evaluate.py @@ -0,0 +1,212 @@ +""" +Evaluation and visualization for VelocityPredictionModel. + +Usage: + python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt +""" + +import argparse +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path + +import torch +import torch.nn as nn + +from src.velocity_prediction.model import VelocityPredictionModel +from src.velocity_prediction.dataset import create_val_loader +from src.velocity_prediction.config import train_cfg, VELOCITY_MEAN, VELOCITY_STD + + +@torch.no_grad() +def evaluate( + model: nn.Module, + loader, + device: torch.device, +) -> dict: + """ + Run evaluation on a dataloader. + + Returns: + dict with keys: + preds: np.ndarray (N, 2) predicted [vx, vy] + targets: np.ndarray (N, 2) ground truth [vx, vy] + """ + model.eval() + all_preds = [] + all_targets = [] + + for batch in loader: + events = batch["events"].to(device) + tilt = batch["tilt"].to(device) + target = batch["v_body_target"].to(device) # (B, S, 2) + + pred = model(events, tilt) # (B, 2) + target_last = target[:, -1, :] # (B, 2) + + all_preds.append(pred.cpu().numpy()) + all_targets.append(target_last.cpu().numpy()) + + preds = np.concatenate(all_preds, axis=0) + targets = np.concatenate(all_targets, axis=0) + + # Denormalize predictions back to original velocity space + mean = np.array(VELOCITY_MEAN, dtype=np.float32) + std = np.array(VELOCITY_STD, dtype=np.float32) + preds_denorm = preds * std + mean + targets_denorm = targets * std + mean + + # ── Diagnostics (in normalized space) ──────────────────────── + print("\n========== Evaluation Diagnostics (normalized space) ==========") + print(f"Total samples: {len(preds)}") + print(f"\n--- Targets (normalized) ---") + print(f" vx: mean={targets[:, 0].mean():.6f}, std={targets[:, 0].std():.6f}") + print(f" vy: mean={targets[:, 1].mean():.6f}, std={targets[:, 1].std():.6f}") + print(f"\n--- 
Predictions (normalized) ---") + print(f" vx: mean={preds[:, 0].mean():.6f}, std={preds[:, 0].std():.6f}, " + f"min={preds[:, 0].min():.6f}, max={preds[:, 0].max():.6f}") + print(f" vy: mean={preds[:, 1].mean():.6f}, std={preds[:, 1].std():.6f}, " + f"min={preds[:, 1].min():.6f}, max={preds[:, 1].max():.6f}") + print(f"\n--- Unique prediction values ---") + print(f" vx unique: {len(np.unique(preds[:, 0]))} / {len(preds)}") + print(f" vy unique: {len(np.unique(preds[:, 1]))} / {len(preds)}") + vx_range = preds[:, 0].max() - preds[:, 0].min() + vy_range = preds[:, 1].max() - preds[:, 1].min() + print(f"\n vx range: {vx_range:.8f} (constant if near 0)") + print(f" vy range: {vy_range:.8f} (constant if near 0)") + print(f"\n--- Constant prediction check ---") + print(f" pred vx mean ≈ 0? {abs(preds[:, 0].mean()):.6f} diff from zero") + print(f" pred vy mean ≈ 0? {abs(preds[:, 1].mean()):.6f} diff from zero") + print("=============================================\n") + + # Per-axis and overall RMSE (in original velocity space) + rmse_x = np.sqrt(np.mean((preds_denorm[:, 0] - targets_denorm[:, 0]) ** 2)) + rmse_y = np.sqrt(np.mean((preds_denorm[:, 1] - targets_denorm[:, 1]) ** 2)) + rmse_xy = np.sqrt(np.mean(np.sum((preds_denorm - targets_denorm) ** 2, axis=1))) + + return { + "preds": preds_denorm, # denormalized for plotting + "targets": targets_denorm, # denormalized for plotting + "rmse_x": rmse_x, + "rmse_y": rmse_y, + "rmse_xy": rmse_xy, + } + + +def plot_results( + preds: np.ndarray, + targets: np.ndarray, + save_path: str = "eval_plot.png", +): + """Plot predicted vs ground truth velocity traces.""" + fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=True) + + time = np.arange(len(preds)) + + axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.8) + axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.8) + axes[0].set_ylabel("vx (m/s)") + axes[0].legend() + axes[0].grid(True, alpha=0.3) + + axes[1].plot(time, targets[:, 1], 
label="GT vy", color="C0", alpha=0.8) + axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.8) + axes[1].set_ylabel("vy (m/s)") + axes[1].set_xlabel("Frame index") + axes[1].legend() + axes[1].grid(True, alpha=0.3) + + fig.suptitle("Body-frame Velocity Prediction") + plt.tight_layout() + plt.savefig(save_path, dpi=150) + print(f"Plot saved: {save_path}") + plt.close() + + +def plot_scatter( + preds: np.ndarray, + targets: np.ndarray, + save_path: str = "eval_scatter.png", +): + """Scatter plot: predicted vs ground truth.""" + fig, axes = plt.subplots(1, 2, figsize=(10, 4)) + + for ax, pred, target, label in zip( + axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"] + ): + ax.scatter(target, pred, s=2, alpha=0.5) + lim_min = min(target.min(), pred.min()) + lim_max = max(target.max(), pred.max()) + ax.plot([lim_min, lim_max], [lim_min, lim_max], "r--", alpha=0.5) + ax.set_xlabel(f"GT {label} (m/s)") + ax.set_ylabel(f"Pred {label} (m/s)") + ax.set_aspect("equal") + ax.grid(True, alpha=0.3) + ax.set_title(f"{label} — RMSE: {np.sqrt(np.mean((pred - target)**2)):.4f}") + + plt.tight_layout() + plt.savefig(save_path, dpi=150) + print(f"Scatter saved: {save_path}") + plt.close() + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt", + help="Path to model checkpoint") + parser.add_argument("--device", type=str, default="cuda", + help="Device to use (e.g. 
'cuda:2', 'cpu')") + parser.add_argument("--plot", action="store_true", default=True, + help="Generate evaluation plots") + args = parser.parse_args() + + device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu") + print(f"Device: {device}") + + # Load model + model = VelocityPredictionModel() + ckpt = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(ckpt["model_state_dict"]) + model.to(device) + print(f"Loaded checkpoint from {args.checkpoint} (epoch={ckpt.get('epoch', '?')})") + + # Validation loader (use test scenes for final eval) + from src.velocity_prediction.config import TEST_SCENES + loader = create_val_loader( + scene_names=TEST_SCENES, + seq_len=train_cfg.seq_len, + batch_size=train_cfg.batch_size, + num_workers=2, + event_threshold=train_cfg.event_threshold, + event_use_log=train_cfg.event_use_log, + ) + + # # ── Quick event diagnostics: inspect one batch ─────────────── + # print("\n========== Event Frame Diagnostics ==========") + # sample_batch = next(iter(loader)) + # ev = sample_batch["events"] # (B, S, 1, H, W) + # print(f"Events shape: {ev.shape}") + # print(f"Events dtype: {ev.dtype}") + # print(f"Events value counts: -1: {(ev == -1).sum().item()}, " + # f"0: {(ev == 0).sum().item()}, +1: {(ev == 1).sum().item()}") + # total_el = ev.numel() + # nonzero = (ev != 0).sum().item() + # print(f"Non-zero ratio: {nonzero / total_el:.6f} ({nonzero}/{total_el})") + # print(f"Per-sample non-zero: {[(ev[b] != 0).sum().item() for b in range(min(4, ev.shape[0]))]}") + # print("=============================================\n") + + # Evaluate + results = evaluate(model, loader, device) + print(f"\nEvaluation results on test scenes: {TEST_SCENES}") + print(f" RMSE vx: {results['rmse_x']:.4f} m/s") + print(f" RMSE vy: {results['rmse_y']:.4f} m/s") + print(f" RMSE xy: {results['rmse_xy']:.4f} m/s") + + # Plots + if args.plot: + plot_results(results["preds"], results["targets"], "eval_velocity.png") 
+ plot_scatter(results["preds"], results["targets"], "eval_scatter.png") + + +if __name__ == "__main__": + main() diff --git a/src/velocity_prediction/model.py b/src/velocity_prediction/model.py new file mode 100644 index 0000000..720a5c3 --- /dev/null +++ b/src/velocity_prediction/model.py @@ -0,0 +1,170 @@ +""" +VelocityPredictionModel: CNN + PoseMLP → concat → GRU → Head → [vx_body, vy_body]. + +Architecture: + Event frame (1, H, W) ──► CNN ──┐ + Tilt angles (3,) ──► MLP ──┤──► concat ──► GRU ──► Head ──► [vx, vy] +""" + +import torch +import torch.nn as nn + +from src.velocity_prediction.config import model_cfg + + +class CNNEncoder(nn.Module): + """ + 4-layer ConvNet with BatchNorm, ReLU, MaxPool, ending with Global Avg Pool. + + Input: (B, S, 1, H, W) — processed per-frame (flattened to (B*S, 1, H, W)) + Output: (B, S, C_out) — per-frame feature vectors + """ + + def __init__(self, cfg=model_cfg.cnn): + super().__init__() + channels = cfg.channels + in_ch = cfg.in_channels + + layers = [] + for out_ch in channels: + layers.extend([ + nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2), + # nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(), + nn.Identity(), + nn.LeakyReLU(inplace=True), + nn.MaxPool2d(cfg.pool_size), + ]) + in_ch = out_ch + + self.conv = nn.Sequential(*layers) + self.gap = nn.AdaptiveAvgPool2d(1) + self.out_dim = channels[-1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: (B, S, 1, H, W) event frame sequence + Returns: + features: (B, S, C_out) + """ + B, S, C, H, W = x.shape + x = x.view(B * S, C, H, W) # (B*S, 1, H, W) + x = self.conv(x) # (B*S, C_out, H', W') + x = self.gap(x) # (B*S, C_out, 1, 1) + x = x.view(B, S, self.out_dim) # (B, S, C_out) + return x + + +class PoseMLP(nn.Module): + """ + Encode tilt rotation vector (3,) into a compact feature vector. 
+ + Input: (B, S, 3) + Output: (B, S, output_dim) + """ + + def __init__(self, cfg=model_cfg.pose_mlp): + super().__init__() + self.net = nn.Sequential( + nn.Linear(cfg.input_dim, cfg.hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(cfg.hidden_dim, cfg.output_dim), + nn.LeakyReLU(inplace=True), + ) + self.out_dim = cfg.output_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: (B, S, 3) → (B, S, output_dim) + """ + B, S, D = x.shape + x = x.view(B * S, D) + x = self.net(x) + x = x.view(B, S, self.out_dim) + return x + + +class VelocityPredictionModel(nn.Module): + """ + Full model: CNN + PoseMLP → concat → GRU → Head → [vx, vy]. + + Input: + events: (B, S, 1, H, W) + tilt: (B, S, 3) + Output: + v_body: (B, 2) — body-frame [vx, vy] for the last frame in the sequence + """ + + def __init__(self, cnn_cfg=model_cfg.cnn, pose_cfg=model_cfg.pose_mlp, + gru_cfg=model_cfg.gru, head_cfg=model_cfg.head): + super().__init__() + + self.cnn = CNNEncoder(cnn_cfg) + self.pose_mlp = PoseMLP(pose_cfg) + + fused_dim = self.cnn.out_dim + self.pose_mlp.out_dim # 256 + 64 = 320 + + self.gru = nn.GRU( + input_size=fused_dim, + hidden_size=gru_cfg.hidden_size, + num_layers=gru_cfg.num_layers, + dropout=gru_cfg.dropout if gru_cfg.num_layers > 1 else 0.0, + batch_first=True, + ) + + self.head = nn.Sequential( + nn.Linear(gru_cfg.hidden_size, head_cfg.hidden_dim), + nn.LeakyReLU(inplace=True), + nn.Linear(head_cfg.hidden_dim, head_cfg.output_dim), + ) + + # # Small init for the final layer: start from near-zero output + # self.head[-1].weight.data.mul_(0.01) + # self.head[-1].bias.data.zero_() + + def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor: + """ + Args: + events: (B, S, 1, H, W) + tilt: (B, S, 3) + + Returns: + v_body: (B, 2) predicted body-frame [vx, vy] at the last timestep + """ + # Per-frame encoding + cnn_feat = self.cnn(events) # (B, S, 256) + pose_feat = self.pose_mlp(tilt) # (B, S, 64) + + # Fuse per frame + fused = 
torch.cat([cnn_feat, pose_feat], dim=-1) # (B, S, 320) + + # GRU temporal modelling + gru_out, h_n = self.gru(fused) # gru_out: (B, S, 128), h_n: (1, B, 128) + + # Use last hidden state + last_hidden = h_n[-1] # (B, 128) + + # Head regression + v_body = self.head(last_hidden) # (B, 2) + return v_body + + +def count_parameters(model: nn.Module) -> int: + """Count trainable parameters.""" + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +if __name__ == "__main__": + # Quick sanity check + model = VelocityPredictionModel() + total = count_parameters(model) + print(f"Total trainable parameters: {total:,} ({total/1e6:.3f} M)") + + # Forward pass test + B, S, H, W = 4, 8, 240, 320 + events = torch.randn(B, S, 1, H, W) + tilt = torch.randn(B, S, 3) + out = model(events, tilt) + print(f"Input events: {events.shape}") + print(f"Input tilt: {tilt.shape}") + print(f"Output: {out.shape} (should be [4, 2])") diff --git a/src/velocity_prediction/train.py b/src/velocity_prediction/train.py new file mode 100644 index 0000000..6c45aad --- /dev/null +++ b/src/velocity_prediction/train.py @@ -0,0 +1,213 @@ +""" +Training loop for VelocityPredictionModel. 
+ +Usage: + python -m src.velocity_prediction.train [--device cuda:0] +""" + +import argparse +import os +import time +import numpy as np +from pathlib import Path + +import torch +import torch.nn as nn +from torch.utils.tensorboard import SummaryWriter + +from src.velocity_prediction.config import train_cfg, model_cfg +from src.velocity_prediction.model import VelocityPredictionModel, count_parameters +from src.velocity_prediction.dataset import create_train_loader, create_val_loader + + +def set_seed(seed: int): + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def train_one_epoch( + model: nn.Module, + loader, + optimizer: torch.optim.Optimizer, + criterion: nn.Module, + device: torch.device, + epoch: int, + writer: SummaryWriter, + log_interval: int = 50, +) -> float: + """Train for one epoch. Returns average loss.""" + model.train() + total_loss = 0.0 + num_batches = 0 + start_time = time.time() + + for batch_idx, batch in enumerate(loader): + events = batch["events"].to(device) # (B, S, 1, H, W) + tilt = batch["tilt"].to(device) # (B, S, 3) + target = batch["v_body_target"].to(device) # (B, S, 2) + + # Predict velocity for the last frame in the sequence + pred = model(events, tilt) # (B, 2) + target_last = target[:, -1, :] # (B, 2) + + loss = criterion(pred, target_last) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + total_loss += loss.item() + num_batches += 1 + + if batch_idx % log_interval == 0: + elapsed = time.time() - start_time + print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s") + writer.add_scalar("train/loss_batch", loss.item(), batch_idx) + + avg_loss = total_loss / max(num_batches, 1) + return avg_loss + + +@torch.no_grad() +def validate( + model: nn.Module, + loader, + criterion: nn.Module, + device: torch.device, +) -> float: + """Validate. 
Returns average loss.""" + model.eval() + total_loss = 0.0 + num_batches = 0 + + for batch in loader: + events = batch["events"].to(device) + tilt = batch["tilt"].to(device) + target = batch["v_body_target"].to(device) + + pred = model(events, tilt) + target_last = target[:, -1, :] + + loss = criterion(pred, target_last) + total_loss += loss.item() + num_batches += 1 + + avg_loss = total_loss / max(num_batches, 1) + return avg_loss + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda", + help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')") + args = parser.parse_args() + + set_seed(train_cfg.seed) + device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu") + print(f"Device: {device}") + + # Create model + model = VelocityPredictionModel() + model.to(device) + total_params = count_parameters(model) + print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)") + + # Data loaders + train_loader = create_train_loader( + seq_len=train_cfg.seq_len, + batch_size=train_cfg.batch_size, + num_workers=train_cfg.num_workers, + event_threshold=train_cfg.event_threshold, + event_use_log=train_cfg.event_use_log, + ) + val_loader = create_val_loader( + seq_len=train_cfg.seq_len, + batch_size=train_cfg.batch_size, + num_workers=train_cfg.num_workers, + event_threshold=train_cfg.event_threshold, + event_use_log=train_cfg.event_use_log, + ) + + # Optimizer & scheduler + optimizer = torch.optim.AdamW( + model.parameters(), + lr=train_cfg.lr, + weight_decay=train_cfg.weight_decay, + ) + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, + step_size=train_cfg.lr_scheduler_step, + gamma=train_cfg.lr_scheduler_gamma, + ) + # criterion = nn.SmoothL1Loss() + criterion = nn.MSELoss() + + # Logging + log_dir = Path(train_cfg.log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + writer = SummaryWriter(log_dir=str(log_dir)) + + ckpt_dir = 
Path(train_cfg.checkpoint_dir) + ckpt_dir.mkdir(parents=True, exist_ok=True) + + best_val_loss = float("inf") + + print(f"\nStarting training for {train_cfg.epochs} epochs...") + print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}") + print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}") + print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n") + + for epoch in range(1, train_cfg.epochs + 1): + epoch_start = time.time() + + train_loss = train_one_epoch( + model, train_loader, optimizer, criterion, device, epoch, writer, + log_interval=train_cfg.log_interval, + ) + val_loss = validate(model, val_loader, criterion, device) + scheduler.step() + + epoch_time = time.time() - epoch_start + current_lr = scheduler.get_last_lr()[0] + + print(f"Epoch {epoch:3d}/{train_cfg.epochs} | " + f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | " + f"LR: {current_lr:.2e} | Time: {epoch_time:.1f}s") + + writer.add_scalar("train/loss_epoch", train_loss, epoch) + writer.add_scalar("val/loss", val_loss, epoch) + writer.add_scalar("lr", current_lr, epoch) + + # Save checkpoint + if epoch % train_cfg.save_interval == 0: + ckpt_path = ckpt_dir / f"epoch_{epoch:03d}_val_{val_loss:.6f}.pt" + torch.save({ + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict(), + "train_loss": train_loss, + "val_loss": val_loss, + }, ckpt_path) + print(f" Checkpoint saved: {ckpt_path}") + + # Save best model + if val_loss < best_val_loss: + best_val_loss = val_loss + best_path = ckpt_dir / "best.pt" + torch.save({ + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "val_loss": val_loss, + }, best_path) + print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})") + + writer.close() + print(f"\nTraining complete. 
Best val loss: {best_val_loss:.6f}") + print(f"Best checkpoint: {ckpt_dir / 'best.pt'}") + + +if __name__ == "__main__": + main() diff --git a/src/velocity_prediction/transforms.py b/src/velocity_prediction/transforms.py new file mode 100644 index 0000000..de44145 --- /dev/null +++ b/src/velocity_prediction/transforms.py @@ -0,0 +1,142 @@ +""" +Transforms: event frame generation + coordinate transforms for pose/velocity. + +Each transform operates on a single decoded sample dict: + { + "jpg": bytes, # JPEG-encoded grayscale image + "ts": bytes, # float64 timestamp + "pose": bytes, # float32[7] [x, y, z, qx, qy, qz, qw] + "vel": bytes, # float32[6] [vx, vy, vz, wx, wy, wz] + } +""" + +import io +import numpy as np +import cv2 + +from src.event_utils import EventProcessor +from src.velocity_prediction.utils import decompose_tilt_np, world_vel_to_body_np +from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD + + +class DecodeSample: + """Decode raw bytes from WebDataset tar entry into numpy arrays.""" + + def __call__(self, sample: dict) -> dict: + # Image: JPEG bytes → grayscale uint8 (H, W) + img = cv2.imdecode(np.frombuffer(sample["jpg"], np.uint8), cv2.IMREAD_GRAYSCALE) + + # Timestamp + ts = np.frombuffer(sample["ts"], dtype=np.float64).item() + + # Pose: [x, y, z, qx, qy, qz, qw] + pose = np.frombuffer(sample["pose"], dtype=np.float32).copy() + + # Velocity: [vx, vy, vz, wx, wy, wz] + vel = np.frombuffer(sample["vel"], dtype=np.float32).copy() + + return {"img": img, "ts": ts, "pose": pose, "vel": vel} + + +class SimulateEvents: + """Convert grayscale frame to binary event frame using EventProcessor.""" + + def __init__(self, threshold=0.1, use_log=True, auto_threshold=False, verbose=False): + self.processor = EventProcessor( + threshold=threshold, + use_log=use_log, + auto_threshold=auto_threshold, + ) + self.verbose = verbose + self._frame_count = 0 + + def __call__(self, sample: dict) -> dict: + img = sample["img"] + events_binary, 
events_strength, event_count = self.processor(img) + # Use binary events as network input: shape (1, H, W), values in {-1, 0, 1} + sample["events"] = events_binary.astype(np.float32)[np.newaxis, ...] + + if self.verbose: + self._frame_count += 1 + total_pixels = events_binary.shape[0] * events_binary.shape[1] + nonzero_ratio = event_count / total_pixels + pos_ratio = (events_binary > 0).sum() / total_pixels + neg_ratio = (events_binary < 0).sum() / total_pixels + if self._frame_count <= 5 or self._frame_count % 100 == 0: + print(f" [EventDiagnostics] frame={self._frame_count} | " + f"nonzero={nonzero_ratio:.4f} (+{pos_ratio:.4f}/-{neg_ratio:.4f}) | " + f"count={event_count}") + + return sample + + def reset(self): + self.processor.reset() + self._frame_count = 0 + + +class ComputeTilt: + """Extract tilt rotation vector from pose quaternion (discard position, discard yaw).""" + + def __call__(self, sample: dict) -> dict: + q = sample["pose"][3:7] # [qx, qy, qz, qw] + tilt = decompose_tilt_np(q) # (3,) rotation vector + sample["tilt"] = tilt.astype(np.float32) + return sample + + +class ComputeBodyVelocity: + """Transform world-frame velocity to body-frame (yaw-compensated).""" + + def __call__(self, sample: dict) -> dict: + v_world = sample["vel"][:3] # [vx, vy, vz] world frame + q = sample["pose"][3:7] # [qx, qy, qz, qw] + v_body = world_vel_to_body_np(v_world, q) # (3,) + # Only predict forward (x) and lateral (y) body velocity + sample["v_body_target"] = v_body[:2].astype(np.float32) # (2,) + return sample + + +class NormalizeVelocity: + """Normalize body-frame velocity to zero mean, unit variance.""" + + def __init__(self): + self.mean = np.array(VELOCITY_MEAN, dtype=np.float32) + self.std = np.array(VELOCITY_STD, dtype=np.float32) + + def __call__(self, sample: dict) -> dict: + sample["v_body_target"] = (sample["v_body_target"] - self.mean) / self.std + return sample + + +class Compose: + """Chain multiple transforms.""" + + def __init__(self, transforms): + 
self.transforms = transforms + + def __call__(self, sample: dict) -> dict: + for t in self.transforms: + sample = t(sample) + return sample + + +def build_train_transform(event_threshold=0.1, event_use_log=True): + """Build the full transform pipeline for training samples.""" + return Compose([ + DecodeSample(), + SimulateEvents(threshold=event_threshold, use_log=event_use_log), + ComputeTilt(), + ComputeBodyVelocity(), + NormalizeVelocity(), + ]) + + +def build_val_transform(event_threshold=0.1, event_use_log=True): + """Same as train but with a fresh EventProcessor per sample (no cross-contamination).""" + return Compose([ + DecodeSample(), + SimulateEvents(threshold=event_threshold, use_log=event_use_log), + ComputeTilt(), + ComputeBodyVelocity(), + NormalizeVelocity(), + ]) diff --git a/src/velocity_prediction/utils.py b/src/velocity_prediction/utils.py new file mode 100644 index 0000000..a64ce34 --- /dev/null +++ b/src/velocity_prediction/utils.py @@ -0,0 +1,165 @@ +""" +Quaternion and coordinate-frame utility functions. + +All quaternions follow [x, y, z, w] convention (matching dataset pose field). +""" + +import torch +import numpy as np + + +# ──────────────────────────── Quaternion operations ──────────────────────────── + +def quat_normalize(q: torch.Tensor) -> torch.Tensor: + """Normalize quaternion. q: (..., 4)""" + return q / torch.norm(q, dim=-1, keepdim=True).clamp(min=1e-12) + + +def quat_conjugate(q: torch.Tensor) -> torch.Tensor: + """Conjugate (inverse for unit quaternion). q: (..., 4)""" + return q * torch.tensor([-1, -1, -1, 1], device=q.device) + + +def quat_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """Hamilton product. 
a, b: (..., 4)""" + x1, y1, z1, w1 = a.unbind(-1) + x2, y2, z2, w2 = b.unbind(-1) + return torch.stack([ + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + ], dim=-1) + + +def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Rotate vector v by quaternion q. q: (..., 4), v: (..., 3)""" + q_conj = quat_conjugate(q) + v_pad = torch.zeros_like(v[..., :1]) # (..., 1) + v_q = torch.cat([v, v_pad], dim=-1) # (..., 4) pure quaternion + return quat_mul(quat_mul(q, v_q), q_conj)[..., :3] + + +# ──────────────────────────── Yaw decomposition ──────────────────────────── + +def quat_to_yaw(q: torch.Tensor) -> torch.Tensor: + """ + Extract yaw (heading) angle from a quaternion in gravity-aligned frame. + + Gravity axis is +z. Yaw is the rotation around z-axis. + Returns angle in radians, shape (...,). + """ + x, y, z, w = q.unbind(-1) + # From quaternion to Euler: yaw = atan2(2(w*z + x*y), 1 - 2(y² + z²)) + siny = 2.0 * (w * z + x * y) + cosy = 1.0 - 2.0 * (y * y + z * z) + return torch.atan2(siny, cosy) + + +def quat_from_yaw(yaw: torch.Tensor) -> torch.Tensor: + """ + Build a pure-yaw quaternion (rotation around +z). + yaw: (...,) → quat: (..., 4) + """ + half = yaw * 0.5 + cos = torch.cos(half) + sin = torch.sin(half) + z = torch.zeros_like(cos) + return torch.stack([z, z, sin, cos], dim=-1) # [0, 0, sin(yaw/2), cos(yaw/2)] + + +def decompose_tilt(q: torch.Tensor) -> torch.Tensor: + """ + Remove yaw from a quaternion, returning the residual tilt rotation vector. + + Given q = q_yaw * q_tilt (z-yaw first, then body tilt), + we compute q_tilt = q_yaw^{-1} * q, then convert to rotation vector. + + Args: + q: (..., 4) unit quaternion in world→body convention. + + Returns: + tilt_angles: (..., 3) rotation vector [rx, ry, rz] representing + the body's deviation from the heading direction. 
+ """ + yaw = quat_to_yaw(q) + q_yaw = quat_from_yaw(yaw) + q_yaw_inv = quat_conjugate(q_yaw) + q_tilt = quat_mul(q_yaw_inv, q) # remove yaw + q_tilt = quat_normalize(q_tilt) + return quat_to_rotvec(q_tilt) + + +def quat_to_rotvec(q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """ + Convert unit quaternion to rotation vector (axis * angle). + q: (..., 4) → rotvec: (..., 3) + """ + q = quat_normalize(q) + x, y, z, w = q.unbind(-1) + angle = 2.0 * torch.acos(w.clamp(-1.0, 1.0)) + sin_half = torch.sqrt((1.0 - w * w).clamp(min=eps)) + scale = angle / sin_half + # Avoid division by zero for small angles + mask = sin_half > eps + rx = torch.where(mask, x * scale, torch.zeros_like(x)) + ry = torch.where(mask, y * scale, torch.zeros_like(y)) + rz = torch.where(mask, z * scale, torch.zeros_like(z)) + return torch.stack([rx, ry, rz], dim=-1) + + +# ──────────────────────────── Velocity transformation ──────────────────────────── + +def world_vel_to_body( + v_world: torch.Tensor, + q_world_to_body: torch.Tensor, +) -> torch.Tensor: + """ + Transform world-frame velocity to body-frame velocity. + + Steps: + 1. Extract yaw from q_world_to_body. + 2. Build pure-yaw quaternion q_yaw. + 3. Remove yaw from velocity: v_yaw_compensated = q_yaw^{-1} * v_world + 4. 
Rotate to body frame: v_body = q_tilt^{-1} * v_yaw_compensated + where q_tilt = q_yaw^{-1} * q_world_to_body + + Args: + v_world: (..., 3) world-frame linear velocity [vx, vy, vz] + q_world_to_body: (..., 4) world→body unit quaternion + + Returns: + v_body: (..., 3) body-frame linear velocity + """ + yaw = quat_to_yaw(q_world_to_body) + q_yaw = quat_from_yaw(yaw) + q_yaw_inv = quat_conjugate(q_yaw) + + # Step 1: remove yaw from velocity (rotate to yaw-aligned intermediate frame) + v_yaw_comp = quat_rotate(q_yaw_inv, v_world) + + # Step 2: compute tilt quaternion + q_tilt = quat_mul(q_yaw_inv, q_world_to_body) + q_tilt = quat_normalize(q_tilt) + q_tilt_inv = quat_conjugate(q_tilt) + + # Step 3: rotate to body frame + v_body = quat_rotate(q_tilt_inv, v_yaw_comp) + return v_body + + +# ──────────────────────────── NumPy wrappers (for transforms.py) ──────────────────────────── + +def decompose_tilt_np(q: np.ndarray) -> np.ndarray: + """NumPy version of decompose_tilt.""" + q_t = torch.from_numpy(q) + tilt = decompose_tilt(q_t) + return tilt.numpy() + + +def world_vel_to_body_np(v_world: np.ndarray, q: np.ndarray) -> np.ndarray: + """NumPy version of world_vel_to_body.""" + v_t = torch.from_numpy(v_world) + q_t = torch.from_numpy(q) + vb = world_vel_to_body(v_t, q_t) + return vb.numpy() diff --git a/start_ros_container.sh b/start_ros_container.sh new file mode 100644 index 0000000..4c4a007 --- /dev/null +++ b/start_ros_container.sh @@ -0,0 +1,143 @@ +#!/bin/bash + +# ROS Podman容器启动脚本 +# 作者自动生成 +# 功能:启动ROS容器,挂载当前目录,设置环境变量,安装pip3,错误不退出 + +set +e # 遇到错误时不退出 + +# 颜色定义(用于美化输出) +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# 配置参数 +CONTAINER_IMAGE="69a38b2c0905" # ROS Noetic Desktop Full镜像ID +WORKSPACE_DIR="$(pwd)" # 当前工作目录 +MOUNT_POINT="/mnt" # 容器内挂载点 + +# ROS环境变量设置 +ROS_ENV_VARS=( + "ROS_DISTRO=noetic" + "ROS_PYTHON_VERSION=3" + "ROS_VERSION=1" + "ROS_ETC_DIR=/opt/ros/noetic/etc/ros" + "ROS_ROOT=/opt/ros/noetic/share/ros" + 
"ROS_PACKAGE_PATH=/opt/ros/noetic/share" + "PYTHONIOENCODING=utf-8" + "TZ=Asia/Shanghai" +) + +echo -e "${GREEN}========================================${NC}" +echo -e "${GREEN}ROS Podman容器启动脚本${NC}" +echo -e "${GREEN}========================================${NC}" + +# 检查Podman是否安装 +if ! command -v podman &> /dev/null; then + echo -e "${RED}错误: Podman未安装,请先安装Podman${NC}" + echo -e "${YELLOW}Ubuntu/Debian: sudo apt-get install podman${NC}" + echo -e "${YELLOW}CentOS/RHEL: sudo yum install podman${NC}" + exit 1 +fi + +# 检查镜像是否存在 +echo -e "${YELLOW}检查容器镜像...${NC}" +if ! podman image exists $CONTAINER_IMAGE; then + echo -e "${RED}警告: 镜像 $CONTAINER_IMAGE 不存在${NC}" + echo -e "${YELLOW}可用的ROS镜像:${NC}" + podman image list | grep ros + echo -e "${YELLOW}继续使用指定镜像,如果启动失败请更换镜像ID${NC}" +fi + +# 构建环境变量参数 +ENV_ARGS="" +for env_var in "${ROS_ENV_VARS[@]}"; do + ENV_ARGS="$ENV_ARGS -e $env_var" +done + +# 添加额外的环境变量(如果用户需要) +ENV_ARGS="$ENV_ARGS -e DISPLAY=$DISPLAY" # 支持GUI应用 +ENV_ARGS="$ENV_ARGS -e QT_X11_NO_MITSHM=1" + +# X11支持(用于GUI应用) +if [ -n "$DISPLAY" ]; then + echo -e "${YELLOW}启用X11支持,用于GUI应用...${NC}" + xhost +local:root 2>/dev/null + ENV_ARGS="$ENV_ARGS --volume /tmp/.X11-unix:/tmp/.X11-unix:rw" +fi + +echo -e "${GREEN}配置信息:${NC}" +echo -e " 工作目录: $WORKSPACE_DIR" +echo -e " 挂载点: $MOUNT_POINT" +echo -e " 镜像ID: $CONTAINER_IMAGE" +echo -e " ROS发行版: noetic" + +echo -e "${YELLOW}正在启动容器...${NC}" + +# 启动容器的命令 +# 使用bash -c来执行多个命令,确保pip3安装即使失败也不会退出容器 +PODMAN_CMD="podman run -it \ + --rm \ + --name ros_noetic_container_$(date +%s) \ + -v $WORKSPACE_DIR:$MOUNT_POINT:rw \ + $ENV_ARGS \ + $CONTAINER_IMAGE \ + /bin/bash -c \"\ + echo '========================================' && \ + echo 'ROS容器已启动' && \ + echo '========================================' && \ + echo '工作目录已挂载到: $MOUNT_POINT' && \ + echo '当前ROS版本: ' && \ + echo \\\$ROS_DISTRO && \ + echo '' && \ + echo '正在检查并安装pip3...' 
&& \ + if command -v pip3 &> /dev/null; then \ + echo 'pip3已安装,版本: ' && \ + pip3 --version; \ + else \ + echo 'pip3未安装,正在安装...' && \ + apt-get update 2>/dev/null && \ + apt-get install -y python3-pip 2>/dev/null; \ + if [ \\\$? -eq 0 ]; then \ + echo 'pip3安装成功!'; \ + pip3 --version; \ + else \ + echo '警告: pip3安装失败,请手动安装'; \ + fi; \ + fi && \ + echo '' && \ + echo '========================================' && \ + echo '环境变量已设置:' && \ + env | grep ROS_ && \ + echo '========================================' && \ + echo '容器已准备就绪,进入交互式shell...' && \ + echo '提示: 输入exit退出容器' && \ + echo '========================================' && \ + pip config set global.index-url https://mirrors.ustc.edu.cn/pypi/simple && \ + source /ros_entrypoint.sh && \ + cd $MOUNT_POINT && \ + exec /bin/bash -l\"" + +# 执行命令 +echo -e "${GREEN}执行启动命令...${NC}" +# echo -e "${YELLOW}命令详情:${NC}" +# echo "$PODMAN_CMD" +echo -e "${YELLOW}========================================${NC}" + +# 运行容器 +eval $PODMAN_CMD + +# 检查退出状态 +if [ $? -ne 0 ]; then + echo -e "${RED}容器已退出(退出码非0)${NC}" +else + echo -e "${GREEN}容器正常退出${NC}" +fi + +# 清理X11权限 +if [ -n "$DISPLAY" ]; then + xhost -local:root 2>/dev/null +fi + +echo -e "${GREEN}脚本执行完成${NC}" \ No newline at end of file