initial commit
48
.gitignore
vendored
Normal file
@@ -0,0 +1,48 @@
# Python
__pycache__/
*.py[cod]
*.pyo
*.egg-info/
dist/
build/

# Virtual environment
.venv/
venv/
env/

# IDE / Editor
.vscode/
.idea/
*.swp
*.swo
*~
.DS_Store

# Data (large binary files)
bags/
dataset/

# Model checkpoints / weights
checkpoints/
*.pt

# Logs (TensorBoard, etc.)
logs/

# Benchmark evaluation results
benchmark/results/

# Evaluation figures
*.png
*.jpg
*.jpeg
*.pdf
*.svg

# Shell scripts (optional — uncomment if you want to ignore)
# *.sh

# ROS bag files
*.bag
*.bag.active
95
DATASET_FORMAT.md
Normal file
@@ -0,0 +1,95 @@
# UZH FPV Dataset Format

> Generated by `rosbag2wds.py` from DAVIS event-camera ROS bags

## Directory Structure

```
dataset/
├── <dataset_name>/
│   ├── shard_0000.tar       # WebDataset shard (images + aligned GT)
│   ├── shard_0001.tar
│   ├── ...
│   ├── imu_sequence.npz     # Full IMU sequence (stored separately)
│   └── metadata.json        # Dataset metadata
```

## File Descriptions

### 1. WebDataset Shard (`shard_*.tar`)

Each tar file holds up to `shard_size` samples (default 2000). Each sample has the key `frame_<index>` and contains the following fields:

| Key    | Type       | Content                                                          |
|--------|------------|------------------------------------------------------------------|
| `jpg`  | JPEG bytes | Grayscale image, size `320×240`, JPEG quality=85                 |
| `ts`   | float64    | Image timestamp (ROS bag system time, `t.to_sec()`)              |
| `pose` | float32[7] | Pose: `[x, y, z, qx, qy, qz, qw]` (position + unit quaternion)   |
| `vel`  | float32[6] | Velocity: `[vx, vy, vz, wx, wy, wz]` (linear + angular velocity) |

**Reading example (Python):**

```python
import webdataset as wds

dataset = wds.WebDataset("dataset/<name>/shard_0000.tar")
for sample in dataset:
    img = sample["jpg"]    # JPEG bytes
    ts = sample["ts"]      # bytes -> np.frombuffer(..., dtype=np.float64)
    pose = sample["pose"]  # bytes -> np.frombuffer(..., dtype=np.float32).reshape(7)
    vel = sample["vel"]    # bytes -> np.frombuffer(..., dtype=np.float32).reshape(6)
```
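
The byte fields can be decoded as sketched below. This is a minimal illustration that assumes the fields are stored as raw little-endian buffers, as the comments in the reading example suggest; `decode_sample` is a hypothetical helper name, not part of `rosbag2wds.py`, and the sample dictionary is synthetic.

```python
import numpy as np


def decode_sample(sample: dict) -> dict:
    """Decode the raw byte fields of one shard sample into NumPy values.

    Assumes ts is a single float64, pose is float32[7] and vel is float32[6].
    """
    return {
        "jpg": sample["jpg"],  # left JPEG-encoded; decode with an image library
        "ts": float(np.frombuffer(sample["ts"], dtype=np.float64)[0]),
        "pose": np.frombuffer(sample["pose"], dtype=np.float32).reshape(7),
        "vel": np.frombuffer(sample["vel"], dtype=np.float32).reshape(6),
    }


# Synthetic sample standing in for one item yielded by wds.WebDataset.
raw = {
    "jpg": b"",
    "ts": np.float64(12.5).tobytes(),
    "pose": np.array([0, 0, 0, 0, 0, 0, 1], dtype=np.float32).tobytes(),
    "vel": np.zeros(6, dtype=np.float32).tobytes(),
}
decoded = decode_sample(raw)
print(decoded["ts"], decoded["pose"].shape, decoded["vel"].shape)
```

In a real loop, `decode_sample(sample)` would be called on each item of `dataset`.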

### 2. IMU Sequence (`imu_sequence.npz`)

The full IMU data is stored separately in compressed NPZ format and contains three arrays:

| Key                  | Type    | Shape  | Content                                   |
|----------------------|---------|--------|-------------------------------------------|
| `timestamps`         | float64 | (N,)   | IMU timestamps                            |
| `accelerations`      | float32 | (N, 3) | Linear acceleration `(ax, ay, az)`, m/s²  |
| `angular_velocities` | float32 | (N, 3) | Angular velocity `(gx, gy, gz)`, rad/s    |

**Reading example:**

```python
import numpy as np

data = np.load("dataset/<name>/imu_sequence.npz")
timestamps = data["timestamps"]
acc = data["accelerations"]
gyro = data["angular_velocities"]
```
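
Because the IMU stream is stored apart from the image shards, pairing it with frames usually means slicing by timestamp. The sketch below shows one way to do that with `np.searchsorted`; it assumes `timestamps` is sorted in ascending order, and `imu_window` is a hypothetical helper. Synthetic 1 kHz data stands in for the real arrays.

```python
import numpy as np


def imu_window(timestamps: np.ndarray, t_start: float, t_end: float) -> slice:
    """Return a slice selecting IMU samples with t_start <= t < t_end.

    Assumes timestamps is sorted in ascending order.
    """
    lo = int(np.searchsorted(timestamps, t_start, side="left"))
    hi = int(np.searchsorted(timestamps, t_end, side="left"))
    return slice(lo, hi)


# Synthetic stand-ins for the arrays in imu_sequence.npz (1 kHz, 1 second).
timestamps = np.arange(1000, dtype=np.float64) / 1000.0
gyro = np.zeros((1000, 3), dtype=np.float32)

# Mean sampling rate estimated from the timestamps.
rate_hz = 1.0 / np.mean(np.diff(timestamps))

# IMU samples between two (hypothetical) consecutive image timestamps.
window = imu_window(timestamps, 0.1, 0.2)
gyro_between_frames = gyro[window]
print(rate_hz, gyro_between_frames.shape)
```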

### 3. Metadata (`metadata.json`)

```json
{
  "dataset_name": "indoor_forward_7",
  "source_bag": "/mnt/indoor_forward_7_davis_with_gt.bag",
  "num_images": 2459,
  "num_imu_messages": 66632,
  "num_ground_truth": 33350,
  "image_size": [320, 240],
  "imu_frequency_hz": 999.02,
  "camera_frequency_hz": 36.89,
  "gt_frequency_hz": 500.01,
  "coordinate_system": "horizontal (z aligned with gravity, assumed from GT)",
  "velocity_dimensions": 6
}
```
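
The count and frequency fields allow a quick sanity check: dividing each message count by its mean frequency should give roughly the same duration for all three streams. The sketch below uses the example values above and assumes each frequency is a mean rate over the whole sequence, which the document does not state explicitly.

```python
import json

# Values copied from the metadata.json example.
metadata = json.loads("""
{
  "num_images": 2459,
  "num_imu_messages": 66632,
  "num_ground_truth": 33350,
  "imu_frequency_hz": 999.02,
  "camera_frequency_hz": 36.89,
  "gt_frequency_hz": 500.01
}
""")

# Approximate duration of each stream: message count / mean frequency.
durations = {
    "camera": metadata["num_images"] / metadata["camera_frequency_hz"],
    "imu": metadata["num_imu_messages"] / metadata["imu_frequency_hz"],
    "gt": metadata["num_ground_truth"] / metadata["gt_frequency_hz"],
}

for stream, seconds in durations.items():
    print(f"{stream:>6}: {seconds:.2f} s")
```

For these values all three streams come out at roughly 66.7 s, which is what one would expect after cropping every stream to the same time range.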

## Data Sources

| Topic                   | Content                                | Frequency (typical) |
|-------------------------|----------------------------------------|---------------------|
| `/dvs/image_raw`        | Grayscale image (mono8)                | ~30–50 Hz           |
| `/dvs/imu`              | IMU (acceleration + angular velocity)  | ~1000 Hz            |
| `/groundtruth/odometry` | Ground-truth pose                      | ~500 Hz             |

## Preprocessing Notes

- **Timestamps**: All streams use the ROS bag system time `t.to_sec()` rather than `msg.header.stamp`.
- **Time alignment**: Images are matched to GT by nearest-neighbor timestamp, with a maximum allowed offset of 0.1 s.
- **Velocity computation**: GT velocity is computed from pose differences (forward/backward finite differences plus quaternion rotation vectors); the raw twist data in the bag is ignored.
- **Time cropping**: All data is cropped to the GT time range, removing the leading and trailing segments that have no GT.
- **Image resizing**: The original DAVIS resolution `240×180` is resized to `320×240` (INTER_LINEAR).
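
The alignment and velocity steps can be sketched as follows. This is an illustration, not the code in `rosbag2wds.py`: `align_nearest` and `finite_difference_velocity` are hypothetical names, the forward/backward scheme shown (forward differences everywhere, a backward difference at the last sample) is one plausible reading of the note above, and only the linear part of the velocity is covered. The angular part, derived from quaternion rotation vectors, is omitted.

```python
import numpy as np


def align_nearest(query_ts, ref_ts, max_dt=0.1):
    """Match each query timestamp to the nearest reference timestamp.

    Returns (idx, valid): idx[i] indexes ref_ts, and valid[i] is True when the
    match lies within max_dt seconds. Assumes ref_ts is sorted with >= 2 entries.
    """
    query_ts = np.asarray(query_ts, dtype=np.float64)
    ref_ts = np.asarray(ref_ts, dtype=np.float64)
    right = np.clip(np.searchsorted(ref_ts, query_ts), 1, len(ref_ts) - 1)
    left = right - 1
    use_left = (query_ts - ref_ts[left]) <= (ref_ts[right] - query_ts)
    idx = np.where(use_left, left, right)
    valid = np.abs(ref_ts[idx] - query_ts) <= max_dt
    return idx, valid


def finite_difference_velocity(positions, ts):
    """Linear velocity from positions: forward differences, backward at the end."""
    positions = np.asarray(positions, dtype=np.float64)
    ts = np.asarray(ts, dtype=np.float64)
    vel = np.empty_like(positions)
    dt = np.diff(ts)[:, None]
    vel[:-1] = np.diff(positions, axis=0) / dt  # forward difference
    vel[-1] = vel[-2]  # backward difference at the last sample
    return vel


# Synthetic GT at 500 Hz moving along x at 2 m/s.
gt_ts = np.arange(500, dtype=np.float64) / 500.0
gt_pos = np.stack([2.0 * gt_ts, np.zeros_like(gt_ts), np.zeros_like(gt_ts)], axis=1)
gt_vel = finite_difference_velocity(gt_pos, gt_ts)

# Three image timestamps; the last one is far outside the GT range.
img_ts = np.array([0.0031, 0.5, 5.0])
idx, valid = align_nearest(img_ts, gt_ts, max_dt=0.1)
img_vel = gt_vel[idx[valid]]
print(idx, valid, img_vel)
```

Frames whose nearest GT sample is more than `max_dt` away are dropped through the `valid` mask, which also covers images recorded before or after the GT range.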
87
batch_convert.sh
Executable file
@@ -0,0 +1,87 @@
#!/bin/bash
# batch_convert.sh
# Run inside an already-running ROS container to batch-convert datasets that have not been converted yet.
#
# Usage (inside the container):
#   cd /mnt && bash batch_convert.sh
#
# The script automatically:
#   - detects which .bag files have not been converted yet (by checking dataset/<name>/metadata.json)
#   - skips datasets that are already converted

set -e

SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
OUTPUT_DIR="${SCRIPT_DIR}/dataset"
BAG_DIR="${SCRIPT_DIR}/bags"

# Already-converted datasets are identified by their metadata.json
echo "=========================================="
echo "Batch-converting UZH FPV datasets"
echo "=========================================="

# Collect all bag files
BAG_FILES=()
while IFS= read -r -d '' f; do
    BAG_FILES+=("$f")
done < <(find "$BAG_DIR" -maxdepth 2 -name "*_davis_with_gt.bag" -print0)

if [ ${#BAG_FILES[@]} -eq 0 ]; then
    echo "❌ No *_davis_with_gt.bag files found in ${BAG_DIR}"
    exit 1
fi

echo "Found ${#BAG_FILES[@]} bag file(s)"

# Check and convert each bag in turn
CONVERTED=0
SKIPPED=0
FAILED=0

for bag_path in "${BAG_FILES[@]}"; do
    bag_name="$(basename "$bag_path")"
    dataset_name="${bag_name%_davis_with_gt.bag}"

    # Skip datasets that are already converted
    metadata_file="${OUTPUT_DIR}/${dataset_name}/metadata.json"
    if [ -f "$metadata_file" ]; then
        echo " ⏭️ Skipping ${dataset_name} (already converted)"
        SKIPPED=$((SKIPPED + 1))
        continue
    fi

    echo ""
    echo " 🔄 Converting: ${dataset_name}"

    # Check dependencies
    if ! python3 -c "import webdataset" 2>/dev/null; then
        echo " ⚠️ Installing webdataset..."
        pip3 install webdataset tqdm scipy 2>/dev/null || {
            echo " ❌ pip install failed, skipping ${dataset_name}"
            FAILED=$((FAILED + 1))
            continue
        }
    fi

    # Run the conversion
    if python3 "${SCRIPT_DIR}/rosbag2wds.py" \
        --bag "$bag_path" \
        --output "$OUTPUT_DIR" \
        --name "$dataset_name" \
        --shard_size 2000 \
        --width 320 --height 240; then
        echo " ✅ ${dataset_name} converted"
        CONVERTED=$((CONVERTED + 1))
    else
        echo " ❌ ${dataset_name} conversion failed"
        FAILED=$((FAILED + 1))
    fi
done

echo ""
echo "=========================================="
echo "Batch conversion finished"
echo " Converted: ${CONVERTED}"
echo " Skipped: ${SKIPPED}"
echo " Failed: ${FAILED}"
echo "=========================================="
0
benchmark/__init__.py
Normal file
309
benchmark/benchmark.py
Normal file
@@ -0,0 +1,309 @@
"""
benchmark.py — Unified evaluation entry point.

Two modes:
  1. Single-model eval: python -m benchmark.benchmark --checkpoint <path>
  2. Compare mode: python -m benchmark.benchmark --compare <checkpoint_dir>

Results are saved to benchmark/results/<exp_name>/.
"""

import argparse
import sys
from pathlib import Path
from typing import List, Optional

import torch

from benchmark.config import eval_cfg, TEST_SCENE_GROUPS
from benchmark.evaluate import run_full_evaluation, save_results

# Project root (two levels up from benchmark/benchmark.py)
PROJECT_ROOT = Path(__file__).resolve().parents[1]
RESULTS_DIR = PROJECT_ROOT / "benchmark" / "results"


def load_checkpoint(
    checkpoint_path: Path,
    device: torch.device,
) -> torch.nn.Module:
    """Load a VelocityPredictionModel from a checkpoint file."""
    from src.velocity_prediction.model import VelocityPredictionModel

    model = VelocityPredictionModel()
    # weights_only=False unpickles arbitrary objects, so only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    # Accept both full training checkpoints and bare state dicts.
    state_dict = ckpt.get("model_state_dict", ckpt)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


def run_single_eval(
    checkpoint_path: Path,
    output_dir: Optional[Path] = None,
    device: Optional[torch.device] = None,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    save_plots: bool = True,
) -> Path:
    """Evaluate a single checkpoint and save results.

    Returns the output directory path.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fall back to the config defaults only when no override was given
    # (an explicit num_workers=0 is a valid DataLoader setting).
    seq_len = eval_cfg.seq_len if seq_len is None else seq_len
    batch_size = eval_cfg.batch_size if batch_size is None else batch_size
    num_workers = eval_cfg.num_workers if num_workers is None else num_workers

    # Derive experiment name from checkpoint filename (strip extension)
    exp_name = checkpoint_path.stem  # e.g. "best" or "epoch_050_val_1.827390"

    if output_dir is None:
        output_dir = RESULTS_DIR / exp_name

    print(f"{'=' * 60}")
    print(f"Benchmark — Single Model Evaluation")
    print(f"{'=' * 60}")
    print(f" Checkpoint: {checkpoint_path}")
    print(f" Device: {device}")
    print(f" Seq len: {seq_len}")
    print(f" Batch size: {batch_size}")
    print(f" Output: {output_dir}")
    print()

    model = load_checkpoint(checkpoint_path, device)

    results = run_full_evaluation(
        model=model,
        device=device,
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=num_workers,
        event_threshold=eval_cfg.event_threshold,
        event_use_log=eval_cfg.event_use_log,
        scene_groups=TEST_SCENE_GROUPS,
    )

    # NOTE: save_plots is not forwarded here; save_results() is called with its defaults.
    save_results(results, save_dir=output_dir, checkpoint_name=exp_name)

    return output_dir


def run_compare(
    checkpoint_dir: Path,
    output_dir: Optional[Path] = None,
    device: Optional[torch.device] = None,
    seq_len: Optional[int] = None,
    batch_size: Optional[int] = None,
    num_workers: Optional[int] = None,
    pattern: str = "*.pt",
) -> Path:
    """Evaluate all checkpoints in a directory and produce a comparison table.

    Returns the output directory path.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fall back to the config defaults only when no override was given
    # (an explicit num_workers=0 is a valid DataLoader setting).
    seq_len = eval_cfg.seq_len if seq_len is None else seq_len
    batch_size = eval_cfg.batch_size if batch_size is None else batch_size
    num_workers = eval_cfg.num_workers if num_workers is None else num_workers

    checkpoint_paths = sorted(Path(checkpoint_dir).glob(pattern))
    if not checkpoint_paths:
        print(f"No checkpoints found matching '{pattern}' in {checkpoint_dir}")
        sys.exit(1)

    if output_dir is None:
        output_dir = RESULTS_DIR / "compare"

    print(f"{'=' * 60}")
    print(f"Benchmark — Compare Mode ({len(checkpoint_paths)} checkpoints)")
    print(f"{'=' * 60}")
    print(f" Checkpoint dir: {checkpoint_dir}")
    print(f" Device: {device}")
    print(f" Seq len: {seq_len}")
    print(f" Batch size: {batch_size}")
    print(f" Output: {output_dir}")
    print()

    all_global_metrics = []

    for ckpt_path in checkpoint_paths:
        exp_name = ckpt_path.stem
        print(f"\n── Evaluating {exp_name} ──")

        model = load_checkpoint(ckpt_path, device)

        results = run_full_evaluation(
            model=model,
            device=device,
            seq_len=seq_len,
            batch_size=batch_size,
            num_workers=num_workers,
            event_threshold=eval_cfg.event_threshold,
            event_use_log=eval_cfg.event_use_log,
            scene_groups=TEST_SCENE_GROUPS,
        )

        # Save individual results
        ckpt_output_dir = output_dir / exp_name
        save_results(results, save_dir=ckpt_output_dir, checkpoint_name=exp_name)

        all_global_metrics.append((exp_name, results["global"]))

    # ── Comparison table ──
    print(f"\n\n{'=' * 60}")
    print("Comparison Summary")
    print(f"{'=' * 60}")

    header = f"{'Checkpoint':<30} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} {'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8}"
    sep = "-" * len(header)
    print(header)
    print(sep)

    rows = []
    for name, metrics in all_global_metrics:
        row = (
            f"{name:<30} "
            f"{metrics.get('rmse_vx', 0):>10.4f} "
            f"{metrics.get('rmse_vy', 0):>10.4f} "
            f"{metrics.get('rmse_xy', 0):>10.4f} "
            f"{metrics.get('mae_vx', 0):>10.4f} "
            f"{metrics.get('mae_vy', 0):>10.4f} "
            f"{metrics.get('r2_vx', 0):>8.4f} "
            f"{metrics.get('r2_vy', 0):>8.4f}"
        )
        print(row)
        rows.append(row)

    # Save comparison CSV
    import csv
    csv_path = output_dir / "comparison.csv"
    output_dir.mkdir(parents=True, exist_ok=True)
    fieldnames = ["checkpoint", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy",
                  "mae_xy", "r2_vx", "r2_vy", "count"]
    with open(csv_path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for name, metrics in all_global_metrics:
            row = {"checkpoint": name, **metrics}
            writer.writerow(row)
    print(f"\nComparison CSV: {csv_path}")

    # Save comparison text
    txt_path = output_dir / "comparison.txt"
    with open(txt_path, "w") as f:
        f.write("Benchmark Comparison\n")
        f.write(f"{'=' * 60}\n\n")
        f.write(header + "\n")
        f.write(sep + "\n")
        for row in rows:
            f.write(row + "\n")
    print(f"Comparison TXT: {txt_path}")

    return output_dir


def main():
    parser = argparse.ArgumentParser(
        description="Unified benchmark for velocity prediction models.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Examples:\n"
            " # Single model evaluation\n"
            " python -m benchmark.benchmark --checkpoint checkpoints/best.pt\n\n"
            " # Compare all checkpoints in a directory\n"
            " python -m benchmark.benchmark --compare checkpoints/\n\n"
            " # Custom output directory\n"
            " python -m benchmark.benchmark --checkpoint checkpoints/best.pt --output my_results/\n"
        ),
    )

    # Mutually exclusive mode selection
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument(
        "--checkpoint", type=str, default=None,
        help="Path to a single checkpoint .pt file for single-model evaluation.",
    )
    mode.add_argument(
        "--compare", type=str, default=None,
        help="Directory containing multiple .pt checkpoints for comparison.",
    )

    # Optional overrides
    parser.add_argument(
        "--output", type=str, default=None,
        help="Output directory for results (default: benchmark/results/<exp_name>/).",
    )
    parser.add_argument(
        "--device", type=str, default=None,
        help="Device override, e.g. 'cuda:0' or 'cpu' (default: auto-detect).",
    )
    parser.add_argument(
        "--seq-len", type=int, default=None,
        help=f"Sequence length override (default: {eval_cfg.seq_len}).",
    )
    parser.add_argument(
        "--batch-size", type=int, default=None,
        help=f"Batch size override (default: {eval_cfg.batch_size}).",
    )
    parser.add_argument(
        "--num-workers", type=int, default=None,
        help=f"DataLoader workers override (default: {eval_cfg.num_workers}).",
    )
    parser.add_argument(
        "--pattern", type=str, default="*.pt",
        help="Glob pattern for --compare mode (default: '*.pt').",
    )
    parser.add_argument(
        "--no-plots", action="store_true",
        help="Skip generating per-scene plots.",
    )

    args = parser.parse_args()

    # Resolve device: a requested CUDA device is used only when CUDA is
    # available; every other value resolves to CPU.
    device = None
    if args.device is not None:
        device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")

    # Resolve output directory
    output_dir = Path(args.output) if args.output else None

    if args.checkpoint:
        ckpt_path = Path(args.checkpoint)
        if not ckpt_path.exists():
            print(f"Error: checkpoint not found: {ckpt_path}")
            sys.exit(1)
        run_single_eval(
            checkpoint_path=ckpt_path,
            output_dir=output_dir,
            device=device,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            save_plots=not args.no_plots,
        )
    elif args.compare:
        ckpt_dir = Path(args.compare)
        if not ckpt_dir.is_dir():
            print(f"Error: checkpoint directory not found: {ckpt_dir}")
            sys.exit(1)
        run_compare(
            checkpoint_dir=ckpt_dir,
            output_dir=output_dir,
            device=device,
            seq_len=args.seq_len,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pattern=args.pattern,
        )


if __name__ == "__main__":
    main()
98
benchmark/config.py
Normal file
@@ -0,0 +1,98 @@
"""
Benchmark configuration — evaluation-only scene splits and metric definitions.

This config is independent from src.velocity_prediction.config so that
evaluation scenarios can be changed without touching training config.
"""

from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict


# ──────────────────────────── Dataset root ────────────────────────────

DATASET_ROOT = Path(__file__).resolve().parents[1] / "dataset"

# ──────────────────────────── Scene splits ────────────────────────────

# Each scene group has a name, a list of scene dirs, and a difficulty label.
# The test scenes are the primary evaluation set; val scenes are for
# checkpoint selection reference.


@dataclass
class SceneGroup:
    name: str
    scenes: List[str]
    difficulty: str = "medium"  # easy / medium / hard


# ── Validation scenes (for checkpoint selection reference) ──
VAL_SCENE_GROUPS: List[SceneGroup] = [
    SceneGroup("indoor_forward_7", ["indoor_forward_7"], "hard"),
    SceneGroup("outdoor_forward_1", ["outdoor_forward_1"], "easy"),
    # SceneGroup("indoor_forward_6", ["indoor_forward_6"], "medium"),
    # SceneGroup("indoor_forward_9", ["indoor_forward_9"], "easy"),
    # SceneGroup("indoor_forward_10", ["indoor_forward_10"], "easy"),
    # SceneGroup("indoor_forward_5", ["indoor_forward_5"], "medium"),
]

# ── Test scenes (primary evaluation) ──
TEST_SCENE_GROUPS: List[SceneGroup] = [
    SceneGroup("indoor_forward_7", ["indoor_forward_7"], "hard"),
    SceneGroup("outdoor_forward_1", ["outdoor_forward_1"], "easy"),
    SceneGroup("outdoor_forward_5", ["outdoor_forward_5"], "hard"),
    SceneGroup("indoor_forward_6", ["indoor_forward_6"], "medium"),
    SceneGroup("indoor_forward_9", ["indoor_forward_9"], "easy"),
    SceneGroup("indoor_forward_10", ["indoor_forward_10"], "easy"),
    SceneGroup("indoor_forward_5", ["indoor_forward_5"], "medium"),
]

# Flat lists for convenience
VAL_SCENES: List[str] = [s for g in VAL_SCENE_GROUPS for s in g.scenes]
TEST_SCENES: List[str] = [s for g in TEST_SCENE_GROUPS for s in g.scenes]

# Difficulty grouping
DIFFICULTY_GROUPS: Dict[str, List[str]] = {}
for g in TEST_SCENE_GROUPS:
    DIFFICULTY_GROUPS.setdefault(g.difficulty, []).extend(g.scenes)


# ──────────────────────────── Evaluation parameters ────────────────────────────


@dataclass
class EvalConfig:
    """Parameters used when running evaluation."""

    # Sequence length (must match what the model was trained with)
    seq_len: int = 8

    # Batch size for evaluation (can be larger than training)
    batch_size: int = 64

    # Data loading
    num_workers: int = 2

    # Event simulation (must match training config)
    event_threshold: float = 0.1
    event_use_log: bool = True

    # Output directory (relative to benchmark/results/)
    output_dir: str = "results"

    # Whether to generate per-scene plots
    save_plots: bool = True

    # Default device; benchmark.py auto-detects CUDA/CPU itself when --device is omitted
    device: str = "cuda"


# ──────────────────────────── Metrics definition ────────────────────────────

# Metrics computed per-axis and overall
METRICS = ["rmse", "mae", "r2"]

# Singleton
eval_cfg = EvalConfig()
458
benchmark/evaluate.py
Normal file
@@ -0,0 +1,458 @@
"""
Core evaluation logic: run model on one or more scenes, compute metrics,
generate visualizations, and save structured results.

This module is called by benchmark.py (the user-facing entry point).
"""

import numpy as np
from pathlib import Path
from typing import List, Optional, Dict, Tuple

import torch
import torch.nn as nn

from src.velocity_prediction.model import VelocityPredictionModel
from src.velocity_prediction.dataset import create_val_loader
from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD

from benchmark.config import (
    eval_cfg,
    TEST_SCENE_GROUPS,
    VAL_SCENE_GROUPS,
    DIFFICULTY_GROUPS,
    DATASET_ROOT,
)


# ──────────────────────────── Metrics ────────────────────────────


def compute_metrics(
    pred: np.ndarray,
    target: np.ndarray,
) -> Dict[str, float]:
    """
    Compute RMSE, MAE, R² for each axis and overall.

    Args:
        pred: (N, 2) denormalized predictions
        target: (N, 2) denormalized ground truth

    Returns:
        dict with keys like rmse_vx, rmse_vy, rmse_xy, mae_vx, ...
    """
    # Per-axis
    rmse_x = float(np.sqrt(np.mean((pred[:, 0] - target[:, 0]) ** 2)))
    rmse_y = float(np.sqrt(np.mean((pred[:, 1] - target[:, 1]) ** 2)))
    rmse_xy = float(np.sqrt(np.mean(np.sum((pred - target) ** 2, axis=1))))

    mae_x = float(np.mean(np.abs(pred[:, 0] - target[:, 0])))
    mae_y = float(np.mean(np.abs(pred[:, 1] - target[:, 1])))
    mae_xy = float(np.mean(np.sqrt(np.sum((pred - target) ** 2, axis=1))))

    # R² per axis
    def r2(p, t):
        ss_res = np.sum((t - p) ** 2)
        ss_tot = np.sum((t - np.mean(t)) ** 2)
        return float(1 - ss_res / ss_tot) if ss_tot > 1e-12 else 0.0

    r2_x = r2(pred[:, 0], target[:, 0])
    r2_y = r2(pred[:, 1], target[:, 1])

    return {
        "rmse_vx": rmse_x,
        "rmse_vy": rmse_y,
        "rmse_xy": rmse_xy,
        "mae_vx": mae_x,
        "mae_vy": mae_y,
        "mae_xy": mae_xy,
        "r2_vx": r2_x,
        "r2_vy": r2_y,
        "count": len(pred),
    }


# ──────────────────────────── Per-scene evaluation ────────────────────────────


@torch.no_grad()
def evaluate_scene(
    model: nn.Module,
    scene_names: List[str],
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
) -> Dict:
    """
    Evaluate model on one or more scenes.

    Returns:
        dict with keys:
            preds: (N, 2) denormalized predictions
            targets: (N, 2) denormalized ground truth
            metrics: dict of scalar metrics
    """
    loader = create_val_loader(
        scene_names=scene_names,
        seq_len=seq_len,
        batch_size=batch_size,
        num_workers=num_workers,
        event_threshold=event_threshold,
        event_use_log=event_use_log,
    )

    model.eval()
    all_preds = []
    all_targets = []

    for batch in loader:
        events = batch["events"].to(device)
        tilt = batch["tilt"].to(device)
        target = batch["v_body_target"].to(device)  # (B, S, 2) normalized

        pred = model(events, tilt)  # (B, 2) normalized
        target_last = target[:, -1, :]  # (B, 2) normalized

        all_preds.append(pred.cpu().numpy())
        all_targets.append(target_last.cpu().numpy())

    if not all_preds:
        return {"preds": np.zeros((0, 2)), "targets": np.zeros((0, 2)), "metrics": {}}

    preds = np.concatenate(all_preds, axis=0)
    targets = np.concatenate(all_targets, axis=0)

    # Denormalize
    mean = np.array(VELOCITY_MEAN, dtype=np.float32)
    std = np.array(VELOCITY_STD, dtype=np.float32)
    preds_denorm = preds * std + mean
    targets_denorm = targets * std + mean

    metrics = compute_metrics(preds_denorm, targets_denorm)

    return {
        "preds": preds_denorm,
        "targets": targets_denorm,
        "metrics": metrics,
    }


# ──────────────────────────── Full evaluation suite ────────────────────────────


def run_full_evaluation(
    model: nn.Module,
    device: torch.device,
    seq_len: int = 8,
    batch_size: int = 64,
    num_workers: int = 2,
    event_threshold: float = 0.1,
    event_use_log: bool = True,
    scene_groups=None,
) -> Dict:
    """
    Run evaluation on all scene groups.

    Returns nested dict:
        {
            "global": { metrics... },
            "per_scene": {
                "indoor_forward_7": { metrics..., "preds": ..., "targets": ... },
                ...
            },
            "by_difficulty": {
                "easy": { metrics... },
                "hard": { metrics... },
            }
        }
    """
    if scene_groups is None:
        # TEST_SCENE_GROUPS is already imported at module level
        scene_groups = TEST_SCENE_GROUPS

    per_scene = {}
    all_preds = []
    all_targets = []

    for group in scene_groups:
        for scene_name in group.scenes:
            result = evaluate_scene(
                model, [scene_name], device,
                seq_len=seq_len, batch_size=batch_size,
                num_workers=num_workers,
                event_threshold=event_threshold,
                event_use_log=event_use_log,
            )
            per_scene[scene_name] = result
            if result["preds"].shape[0] > 0:
                all_preds.append(result["preds"])
                all_targets.append(result["targets"])

    # Global metrics (all scenes combined)
    if all_preds:
        global_preds = np.concatenate(all_preds, axis=0)
        global_targets = np.concatenate(all_targets, axis=0)
        global_metrics = compute_metrics(global_preds, global_targets)
    else:
        global_preds = np.zeros((0, 2))
        global_targets = np.zeros((0, 2))
        global_metrics = {}

    # By difficulty
    by_difficulty = {}
    for diff, scenes in DIFFICULTY_GROUPS.items():
        diff_preds = []
        diff_targets = []
        for s in scenes:
            if s in per_scene and per_scene[s]["preds"].shape[0] > 0:
                diff_preds.append(per_scene[s]["preds"])
                diff_targets.append(per_scene[s]["targets"])
        if diff_preds:
            by_difficulty[diff] = compute_metrics(
                np.concatenate(diff_preds, axis=0),
                np.concatenate(diff_targets, axis=0),
            )
        else:
            by_difficulty[diff] = {}

    return {
        "global": global_metrics,
        "per_scene": per_scene,
        "by_difficulty": by_difficulty,
    }


# ──────────────────────────── Visualization ────────────────────────────


def plot_scene_comparison(
    preds: np.ndarray,
    targets: np.ndarray,
    scene_name: str,
    save_dir: Path,
    metrics: Optional[Dict] = None,
):
    """Generate time-series and scatter plots for a single scene."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    time = np.arange(len(preds))

    # ── Time-series plot ──
    fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)

    axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.7, linewidth=0.8)
    axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.7, linewidth=0.8)
    axes[0].set_ylabel("vx (m/s)")
    axes[0].legend(fontsize=9)
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(time, targets[:, 1], label="GT vy", color="C0", alpha=0.7, linewidth=0.8)
    axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.7, linewidth=0.8)
    axes[1].set_ylabel("vy (m/s)")
    axes[1].set_xlabel("Frame index")
    axes[1].legend(fontsize=9)
    axes[1].grid(True, alpha=0.3)

    title = f"Body-frame Velocity — {scene_name}"
    if metrics:
        title += f" | RMSE vx={metrics['rmse_vx']:.3f} vy={metrics['rmse_vy']:.3f}"
    fig.suptitle(title)
    try:
        plt.tight_layout()
        plt.savefig(save_dir / f"{scene_name}_timeseries.png", dpi=150, bbox_inches="tight")
    except Exception as e:
        print(f" [WARN] Failed to save {scene_name}_timeseries.png: {e}")
    plt.close()

    # ── Scatter plot ──
    fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))

    for ax, pred, target, label in zip(
        axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"]
|
||||||
|
):
|
||||||
|
ax.scatter(target, pred, s=3, alpha=0.4, c="C1", edgecolors="none")
|
||||||
|
lim_min = min(target.min(), pred.min())
|
||||||
|
lim_max = max(target.max(), pred.max())
|
||||||
|
margin = (lim_max - lim_min) * 0.05
|
||||||
|
ax.plot([lim_min - margin, lim_max + margin],
|
||||||
|
[lim_min - margin, lim_max + margin], "r--", alpha=0.5, linewidth=1)
|
||||||
|
ax.set_xlabel(f"GT {label} (m/s)")
|
||||||
|
ax.set_ylabel(f"Pred {label} (m/s)")
|
||||||
|
ax.set_aspect("equal")
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
if metrics:
|
||||||
|
ax.set_title(f"{label} — RMSE: {metrics[f'rmse_{label}']:.4f}")
|
||||||
|
|
||||||
|
fig.suptitle(f"Scatter — {scene_name}")
|
||||||
|
try:
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(save_dir / f"{scene_name}_scatter.png", dpi=150, bbox_inches="tight")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" [WARN] Failed to save {scene_name}_scatter.png: {e}")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_global_comparison(
|
||||||
|
results: Dict,
|
||||||
|
save_dir: Path,
|
||||||
|
):
|
||||||
|
"""Generate a summary figure comparing all scenes."""
|
||||||
|
import matplotlib
|
||||||
|
matplotlib.use("Agg")
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
save_dir = Path(save_dir)
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
scenes = list(results["per_scene"].keys())
|
||||||
|
if not scenes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Bar chart: RMSE vx and vy per scene
|
||||||
|
rmse_vx = [results["per_scene"][s]["metrics"].get("rmse_vx", 0) for s in scenes]
|
||||||
|
rmse_vy = [results["per_scene"][s]["metrics"].get("rmse_vy", 0) for s in scenes]
|
||||||
|
rmse_xy = [results["per_scene"][s]["metrics"].get("rmse_xy", 0) for s in scenes]
|
||||||
|
|
||||||
|
x = np.arange(len(scenes))
|
||||||
|
width = 0.25
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 4.5))
|
||||||
|
bars1 = ax.bar(x - width, rmse_vx, width, label="RMSE vx", alpha=0.8)
|
||||||
|
bars2 = ax.bar(x, rmse_vy, width, label="RMSE vy", alpha=0.8)
|
||||||
|
bars3 = ax.bar(x + width, rmse_xy, width, label="RMSE xy", alpha=0.8)
|
||||||
|
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(scenes, rotation=15, ha="right")
|
||||||
|
ax.set_ylabel("RMSE (m/s)")
|
||||||
|
ax.set_title("Per-Scene RMSE Comparison")
|
||||||
|
ax.legend(fontsize=9)
|
||||||
|
ax.grid(True, alpha=0.3, axis="y")
|
||||||
|
|
||||||
|
# Annotate values
|
||||||
|
for bars in [bars1, bars2, bars3]:
|
||||||
|
for bar in bars:
|
||||||
|
height = bar.get_height()
|
||||||
|
ax.annotate(f"{height:.3f}",
|
||||||
|
xy=(bar.get_x() + bar.get_width() / 2, height),
|
||||||
|
xytext=(0, 2), textcoords="offset points",
|
||||||
|
ha="center", va="bottom", fontsize=7)
|
||||||
|
|
||||||
|
try:
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(save_dir / "per_scene_rmse.png", dpi=150, bbox_inches="tight")
|
||||||
|
except Exception as e:
|
||||||
|
print(f" [WARN] Failed to save per_scene_rmse.png: {e}")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────── Results serialization ────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
results: Dict,
|
||||||
|
save_dir: Path,
|
||||||
|
checkpoint_name: str = "model",
|
||||||
|
):
|
||||||
|
"""Save all evaluation results to disk."""
|
||||||
|
save_dir = Path(save_dir)
|
||||||
|
plots_dir = save_dir / "plots"
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
plots_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# ── 1. Global metrics ──
|
||||||
|
global_m = results["global"]
|
||||||
|
lines = [
|
||||||
|
f"Benchmark Results: {checkpoint_name}",
|
||||||
|
f"{'=' * 50}",
|
||||||
|
f"Total samples: {global_m.get('count', '?')}",
|
||||||
|
"",
|
||||||
|
"── Global Metrics ──",
|
||||||
|
f" RMSE vx: {global_m.get('rmse_vx', float('nan')):.4f} m/s",
f" RMSE vy: {global_m.get('rmse_vy', float('nan')):.4f} m/s",
f" RMSE xy: {global_m.get('rmse_xy', float('nan')):.4f} m/s",
f" MAE vx: {global_m.get('mae_vx', float('nan')):.4f} m/s",
f" MAE vy: {global_m.get('mae_vy', float('nan')):.4f} m/s",
f" MAE xy: {global_m.get('mae_xy', float('nan')):.4f} m/s",
f" R² vx: {global_m.get('r2_vx', float('nan')):.4f}",
f" R² vy: {global_m.get('r2_vy', float('nan')):.4f}",
|
||||||
|
"",
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── 2. Per-scene metrics ──
|
||||||
|
lines.append("── Per-Scene Metrics ──")
|
||||||
|
lines.append(f" {'Scene':<22} {'RMSE vx':>10} {'RMSE vy':>10} {'RMSE xy':>10} "
|
||||||
|
f"{'MAE vx':>10} {'MAE vy':>10} {'R² vx':>8} {'R² vy':>8} {'Samples':>8}")
|
||||||
|
lines.append(" " + "-" * 96)
|
||||||
|
|
||||||
|
for scene_name, scene_result in results["per_scene"].items():
|
||||||
|
m = scene_result["metrics"]
|
||||||
|
lines.append(
|
||||||
|
f" {scene_name:<22} {m.get('rmse_vx', 0):>10.4f} {m.get('rmse_vy', 0):>10.4f} "
|
||||||
|
f"{m.get('rmse_xy', 0):>10.4f} {m.get('mae_vx', 0):>10.4f} "
|
||||||
|
f"{m.get('mae_vy', 0):>10.4f} {m.get('r2_vx', 0):>8.4f} "
|
||||||
|
f"{m.get('r2_vy', 0):>8.4f} {m.get('count', 0):>8}"
|
||||||
|
)
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
# ── 3. By difficulty ──
|
||||||
|
lines.append("── By Difficulty ──")
|
||||||
|
for diff, metrics in results["by_difficulty"].items():
|
||||||
|
lines.append(f" {diff:<10} RMSE vx={metrics.get('rmse_vx', 0):.4f} "
|
||||||
|
f"RMSE vy={metrics.get('rmse_vy', 0):.4f} "
|
||||||
|
f"RMSE xy={metrics.get('rmse_xy', 0):.4f} "
|
||||||
|
f"(samples={metrics.get('count', 0)})")
|
||||||
|
|
||||||
|
summary_text = "\n".join(lines)
|
||||||
|
with open(save_dir / "summary.txt", "w") as f:
|
||||||
|
f.write(summary_text)
|
||||||
|
print(summary_text)
|
||||||
|
|
||||||
|
# ── 4. CSV: per-scene metrics ──
|
||||||
|
import csv
|
||||||
|
csv_path = save_dir / "per_scene_metrics.csv"
|
||||||
|
fieldnames = ["scene", "rmse_vx", "rmse_vy", "rmse_xy", "mae_vx", "mae_vy",
|
||||||
|
"mae_xy", "r2_vx", "r2_vy", "count"]
|
||||||
|
with open(csv_path, "w", newline="") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=fieldnames)
|
||||||
|
writer.writeheader()
|
||||||
|
for scene_name, scene_result in results["per_scene"].items():
|
||||||
|
row = {"scene": scene_name, **scene_result["metrics"]}
|
||||||
|
writer.writerow(row)
|
||||||
|
print(f"Per-scene CSV: {csv_path}")
|
||||||
|
|
||||||
|
# ── 5. Global metrics CSV (single row) ──
|
||||||
|
csv_path_global = save_dir / "metrics.csv"
|
||||||
|
with open(csv_path_global, "w", newline="") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=["checkpoint"] + list(global_m.keys()))
|
||||||
|
writer.writeheader()
|
||||||
|
row = {"checkpoint": checkpoint_name, **global_m}
|
||||||
|
writer.writerow(row)
|
||||||
|
print(f"Global metrics CSV: {csv_path_global}")
|
||||||
|
|
||||||
|
# ── 6. Plots ──
|
||||||
|
for scene_name, scene_result in results["per_scene"].items():
|
||||||
|
if scene_result["preds"].shape[0] > 0:
|
||||||
|
plot_scene_comparison(
|
||||||
|
scene_result["preds"],
|
||||||
|
scene_result["targets"],
|
||||||
|
scene_name,
|
||||||
|
plots_dir,
|
||||||
|
metrics=scene_result["metrics"],
|
||||||
|
)
|
||||||
|
|
||||||
|
plot_global_comparison(results, plots_dir)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
plt.close("all")
|
||||||
|
|
||||||
|
print(f"\nAll results saved to: {save_dir.resolve()}")
|
||||||
44
download_davis_gt_rosbags.sh
Executable file
@@ -0,0 +1,44 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Download script for UZH FPV DAVIS rosbags with ground truth
|
||||||
|
# Skips files that already exist in the current directory.
|
||||||
|
# Uses the final download.ifi.uzh.ch URLs directly (avoids 301 redirect chain).
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
BASE="https://download.ifi.uzh.ch/rpg/web/datasets/uzh-fpv-newer-versions/v3"
|
||||||
|
|
||||||
|
URLS=(
|
||||||
|
# Indoor forward facing
|
||||||
|
"${BASE}/indoor_forward_3_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_forward_5_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_forward_6_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_forward_9_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_forward_10_davis_with_gt.bag"
|
||||||
|
|
||||||
|
# Indoor 45 degree downward
|
||||||
|
"${BASE}/indoor_45_2_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_45_4_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_45_9_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_45_12_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_45_13_davis_with_gt.bag"
|
||||||
|
"${BASE}/indoor_45_14_davis_with_gt.bag"
|
||||||
|
|
||||||
|
# Outdoor forward facing
|
||||||
|
"${BASE}/outdoor_forward_1_davis_with_gt.bag"
|
||||||
|
"${BASE}/outdoor_forward_5_davis_with_gt.bag"
|
||||||
|
|
||||||
|
# Outdoor 45 degree downward
|
||||||
|
"${BASE}/outdoor_45_1_davis_with_gt.bag"
|
||||||
|
)
|
||||||
|
|
||||||
|
for url in "${URLS[@]}"; do
|
||||||
|
filename=$(basename "$url")
|
||||||
|
if [ -f "$filename" ]; then
|
||||||
|
echo "SKIP: $filename already exists"
|
||||||
|
else
|
||||||
|
echo "DOWNLOAD: $filename"
|
||||||
|
wget --continue --show-progress "$url" -O "$filename"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Done. All DAVIS with-GT rosbags downloaded."
|
||||||
17
requirements.txt
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# ===== Core =====
|
||||||
|
numpy
|
||||||
|
scipy
|
||||||
|
|
||||||
|
# ===== PyTorch (CUDA 12.4) =====
|
||||||
|
# Install via: pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
|
||||||
|
torch>=2.4
|
||||||
|
torchvision
|
||||||
|
|
||||||
|
# ===== Data loading =====
|
||||||
|
webdataset
|
||||||
|
opencv-python
|
||||||
|
|
||||||
|
# ===== Training / logging =====
|
||||||
|
tensorboard
|
||||||
|
tqdm
|
||||||
|
matplotlib
|
||||||
396
rosbag2wds.py
Normal file
@@ -0,0 +1,396 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
ROS bag to WebDataset converter for DAVIS dataset
|
||||||
|
Extracts: grayscale images, IMU sequence, ground truth poses and velocities
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python rosbag2wds.py --bag <path_to.bag> --output <output_dir> --name <dataset_name>
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
import numpy as np
|
||||||
|
import rosbag
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
import cv2
|
||||||
|
import webdataset as wds
|
||||||
|
from tqdm import tqdm
|
||||||
|
from scipy.spatial.transform import Rotation as R
|
||||||
|
|
||||||
|
|
||||||
|
class BagToWebDataset:
|
||||||
|
def __init__(self, bag_path: str, output_dir: str, dataset_name: str,
|
||||||
|
shard_size: int = 2000, image_width: int = 320, image_height: int = 240):
|
||||||
|
self.bag_path = Path(bag_path)
|
||||||
|
self.output_dir = Path(output_dir) / dataset_name
|
||||||
|
self.dataset_name = dataset_name
|
||||||
|
self.shard_size = shard_size
|
||||||
|
self.image_width = image_width
|
||||||
|
self.image_height = image_height
|
||||||
|
|
||||||
|
self.bridge = CvBridge()
|
||||||
|
|
||||||
|
# Data containers
|
||||||
|
self.images: List[Tuple[float, np.ndarray]] = [] # (timestamp, image)
|
||||||
|
self.imu_timestamps: List[float] = []
|
||||||
|
self.imu_acc: List[np.ndarray] = [] # (ax, ay, az)
|
||||||
|
self.imu_gyro: List[np.ndarray] = [] # (gx, gy, gz)
|
||||||
|
self.gt_timestamps: List[float] = []
|
||||||
|
self.gt_poses: List[np.ndarray] = [] # (x, y, z, qx, qy, qz, qw)
|
||||||
|
self.gt_velocities: List[np.ndarray] = [] # (vx, vy, vz, wx, wy, wz)
|
||||||
|
|
||||||
|
def extract_all_data(self):
|
||||||
|
"""Extract all data from ROS bag"""
|
||||||
|
print(f"Opening bag: {self.bag_path}")
|
||||||
|
bag = rosbag.Bag(str(self.bag_path), 'r')
|
||||||
|
|
||||||
|
# Count messages for progress bar
|
||||||
|
topic_counts = {topic: bag.get_message_count(topic) for topic in
|
||||||
|
['/dvs/image_raw', '/dvs/imu', '/groundtruth/odometry']}
|
||||||
|
total_msgs = sum(topic_counts.values())
|
||||||
|
|
||||||
|
print(f"Topics: {topic_counts}")
|
||||||
|
|
||||||
|
with tqdm(total=total_msgs, desc="Extracting messages") as pbar:
|
||||||
|
for topic, msg, t in bag.read_messages(topics=['/dvs/image_raw', '/dvs/imu', '/groundtruth/odometry']):
|
||||||
|
|
||||||
|
if topic == '/dvs/image_raw':
|
||||||
|
self._process_image(msg, t) # pass in t
|
||||||
|
|
||||||
|
elif topic == '/dvs/imu':
|
||||||
|
self._process_imu(msg, t) # pass in t
|
||||||
|
|
||||||
|
elif topic == '/groundtruth/odometry':
|
||||||
|
self._process_odometry(msg, t) # pass in t
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
bag.close()
|
||||||
|
|
||||||
|
# Post-processing: compute velocities from poses if not directly available
|
||||||
|
self._ensure_velocities()
|
||||||
|
|
||||||
|
# Print statistics
|
||||||
|
print(f"\nExtraction completed:")
|
||||||
|
print(f" Images: {len(self.images)}")
|
||||||
|
print(f" IMU messages: {len(self.imu_timestamps)}")
|
||||||
|
print(f" Ground truth poses: {len(self.gt_timestamps)}")
|
||||||
|
print(f" Ground truth velocities: {len(self.gt_velocities)}")
|
||||||
|
|
||||||
|
def crop_to_gt_time_range(self):
|
||||||
|
"""Crop all data, keeping only the portion inside the GT time range."""
|
||||||
|
|
||||||
|
if len(self.gt_timestamps) == 0:
|
||||||
|
print("Warning: No GT data found, skipping crop")
|
||||||
|
return
|
||||||
|
|
||||||
|
gt_start = min(self.gt_timestamps)
|
||||||
|
gt_end = max(self.gt_timestamps)
|
||||||
|
|
||||||
|
print(f"\nCropping to GT time range: {gt_start:.3f} - {gt_end:.3f} ({gt_end - gt_start:.1f}s)")
|
||||||
|
|
||||||
|
# Crop images
|
||||||
|
original_img_count = len(self.images)
|
||||||
|
self.images = [(ts, img) for ts, img in self.images if gt_start <= ts <= gt_end]
|
||||||
|
print(f" Images: {original_img_count} -> {len(self.images)}")
|
||||||
|
|
||||||
|
# Crop IMU
|
||||||
|
original_imu_count = len(self.imu_timestamps)
|
||||||
|
imu_filtered = [(ts, acc, gyro) for ts, acc, gyro
|
||||||
|
in zip(self.imu_timestamps, self.imu_acc, self.imu_gyro)
|
||||||
|
if gt_start <= ts <= gt_end]
|
||||||
|
|
||||||
|
if imu_filtered:
|
||||||
|
self.imu_timestamps = [item[0] for item in imu_filtered]
|
||||||
|
self.imu_acc = [item[1] for item in imu_filtered]
|
||||||
|
self.imu_gyro = [item[2] for item in imu_filtered]
|
||||||
|
print(f" IMU: {original_imu_count} -> {len(self.imu_timestamps)}")
|
||||||
|
|
||||||
|
# GT data is already inside the range, so no cropping is needed
|
||||||
|
print(f" GT: {len(self.gt_timestamps)} (unchanged)")
|
||||||
|
|
||||||
|
def _process_image(self, msg, t):
|
||||||
|
"""Process grayscale image message using system time"""
|
||||||
|
try:
|
||||||
|
# Convert ROS image to OpenCV format
|
||||||
|
cv_img = self.bridge.imgmsg_to_cv2(msg, desired_encoding='mono8')
|
||||||
|
|
||||||
|
# Resize if needed
|
||||||
|
if self.image_width and self.image_height:
|
||||||
|
cv_img = cv2.resize(cv_img, (self.image_width, self.image_height),
|
||||||
|
interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
|
# Use the system (bag record) time rather than msg.header.stamp
|
||||||
|
timestamp = t.to_sec()
|
||||||
|
self.images.append((timestamp, cv_img))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing image: {e}")
|
||||||
|
|
||||||
|
def _process_imu(self, msg, t):
|
||||||
|
"""Process IMU message using system time"""
|
||||||
|
timestamp = t.to_sec() # use the system (bag record) time
|
||||||
|
|
||||||
|
# Linear acceleration (m/s^2)
|
||||||
|
acc = np.array([msg.linear_acceleration.x,
|
||||||
|
msg.linear_acceleration.y,
|
||||||
|
msg.linear_acceleration.z], dtype=np.float32)
|
||||||
|
|
||||||
|
# Angular velocity (rad/s)
|
||||||
|
gyro = np.array([msg.angular_velocity.x,
|
||||||
|
msg.angular_velocity.y,
|
||||||
|
msg.angular_velocity.z], dtype=np.float32)
|
||||||
|
|
||||||
|
self.imu_timestamps.append(timestamp)
|
||||||
|
self.imu_acc.append(acc)
|
||||||
|
self.imu_gyro.append(gyro)
|
||||||
|
|
||||||
|
def _process_odometry(self, msg, t):
|
||||||
|
"""Process ground truth odometry using system time"""
|
||||||
|
timestamp = t.to_sec() # use the system (bag record) time
|
||||||
|
|
||||||
|
# Position (x, y, z)
|
||||||
|
pos = np.array([msg.pose.pose.position.x,
|
||||||
|
msg.pose.pose.position.y,
|
||||||
|
msg.pose.pose.position.z], dtype=np.float32)
|
||||||
|
|
||||||
|
# Orientation (qx, qy, qz, qw) - already normalized
|
||||||
|
quat = np.array([msg.pose.pose.orientation.x,
|
||||||
|
msg.pose.pose.orientation.y,
|
||||||
|
msg.pose.pose.orientation.z,
|
||||||
|
msg.pose.pose.orientation.w], dtype=np.float32)
|
||||||
|
|
||||||
|
pose = np.concatenate([pos, quat])
|
||||||
|
self.gt_timestamps.append(timestamp)
|
||||||
|
self.gt_poses.append(pose)
|
||||||
|
|
||||||
|
# Velocity: always compute from pose differences in post-processing
|
||||||
|
vel = None
|
||||||
|
|
||||||
|
self.gt_velocities.append(vel)
|
||||||
|
|
||||||
|
def _ensure_velocities(self):
|
||||||
|
# The twist data in this dataset is all zeros, so velocities are computed directly from timestamp differences
|
||||||
|
# """Compute velocities from pose differences if not directly available"""
|
||||||
|
# # Check if velocities are missing
|
||||||
|
# missing_velocities = any(v is None for v in self.gt_velocities)
|
||||||
|
|
||||||
|
# if not missing_velocities:
|
||||||
|
# return
|
||||||
|
|
||||||
|
print("Computing velocities from pose differences...")
|
||||||
|
|
||||||
|
computed_velocities = []
|
||||||
|
for i in range(len(self.gt_timestamps)):
|
||||||
|
if i == 0:
|
||||||
|
# Use forward difference for first frame
|
||||||
|
if len(self.gt_timestamps) > 1:
|
||||||
|
dt = self.gt_timestamps[1] - self.gt_timestamps[0]
|
||||||
|
if dt > 0:
|
||||||
|
# Linear velocity
|
||||||
|
v_lin = (self.gt_poses[1][:3] - self.gt_poses[0][:3]) / dt
|
||||||
|
# Angular velocity (from quaternion difference)
|
||||||
|
q0 = self.gt_poses[0][3:7]
|
||||||
|
q1 = self.gt_poses[1][3:7]
|
||||||
|
dq = R.from_quat(q1) * R.from_quat(q0).inv()
|
||||||
|
v_ang = dq.as_rotvec() / dt
|
||||||
|
computed_velocities.append(np.concatenate([v_lin, v_ang]))
|
||||||
|
else:
|
||||||
|
computed_velocities.append(np.zeros(6, dtype=np.float32))
|
||||||
|
else:
|
||||||
|
computed_velocities.append(np.zeros(6, dtype=np.float32))
|
||||||
|
else:
|
||||||
|
# Use backward difference
|
||||||
|
dt = self.gt_timestamps[i] - self.gt_timestamps[i-1]
|
||||||
|
if dt > 0:
|
||||||
|
v_lin = (self.gt_poses[i][:3] - self.gt_poses[i-1][:3]) / dt
|
||||||
|
q0 = self.gt_poses[i-1][3:7]
|
||||||
|
q1 = self.gt_poses[i][3:7]
|
||||||
|
dq = R.from_quat(q1) * R.from_quat(q0).inv()
|
||||||
|
v_ang = dq.as_rotvec() / dt
|
||||||
|
computed_velocities.append(np.concatenate([v_lin, v_ang]))
|
||||||
|
else:
|
||||||
|
computed_velocities.append(np.zeros(6, dtype=np.float32))
|
||||||
|
|
||||||
|
# Replace missing velocities
|
||||||
|
for i in range(len(self.gt_velocities)):
|
||||||
|
if self.gt_velocities[i] is None:
|
||||||
|
self.gt_velocities[i] = computed_velocities[i]
|
||||||
|
|
||||||
|
def save_imu_sequence(self):
|
||||||
|
"""Save IMU sequence as NPZ file"""
|
||||||
|
imu_data = {
|
||||||
|
'timestamps': np.array(self.imu_timestamps, dtype=np.float64),
|
||||||
|
'accelerations': np.array(self.imu_acc, dtype=np.float32),
|
||||||
|
'angular_velocities': np.array(self.imu_gyro, dtype=np.float32)
|
||||||
|
}
|
||||||
|
|
||||||
|
imu_path = self.output_dir / 'imu_sequence.npz'
|
||||||
|
imu_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
np.savez_compressed(imu_path, **imu_data)
|
||||||
|
|
||||||
|
print(f"Saved IMU sequence: {imu_path}")
|
||||||
|
return imu_path
|
||||||
|
|
||||||
|
def align_ground_truth_to_images(self) -> List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]:
|
||||||
|
"""Align ground truth (pose + velocity) to each image using nearest timestamp"""
|
||||||
|
aligned_gt = []
|
||||||
|
|
||||||
|
gt_timestamps = np.array(self.gt_timestamps)
|
||||||
|
gt_poses = np.array(self.gt_poses)
|
||||||
|
gt_vels = np.array(self.gt_velocities)
|
||||||
|
|
||||||
|
for img_ts, img in tqdm(self.images, desc="Aligning ground truth to images"):
|
||||||
|
idx = np.argmin(np.abs(gt_timestamps - img_ts))
|
||||||
|
time_diff = abs(gt_timestamps[idx] - img_ts)
|
||||||
|
|
||||||
|
if time_diff < 0.1:
|
||||||
|
aligned_gt.append((img_ts, img, gt_poses[idx], gt_vels[idx])) # keep the image with its aligned GT
|
||||||
|
|
||||||
|
return aligned_gt
|
||||||
|
|
||||||
|
def save_as_webdataset(self, aligned_gt: List[Tuple[float, np.ndarray, np.ndarray, np.ndarray]]):
|
||||||
|
"""Save images and aligned ground truth as WebDataset tar files"""
|
||||||
|
|
||||||
|
num_shards = (len(aligned_gt) + self.shard_size - 1) // self.shard_size
|
||||||
|
|
||||||
|
print(f"Saving {len(aligned_gt)} samples into {num_shards} shards...")
|
||||||
|
|
||||||
|
for shard_idx in range(num_shards):
|
||||||
|
start_idx = shard_idx * self.shard_size
|
||||||
|
end_idx = min((shard_idx + 1) * self.shard_size, len(aligned_gt))
|
||||||
|
|
||||||
|
tar_path = self.output_dir / f'shard_{shard_idx:04d}.tar'
|
||||||
|
|
||||||
|
with wds.TarWriter(str(tar_path)) as sink:
|
||||||
|
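# Write only this shard's slice; start=start_idx keeps frame keys unique across shards
for local_idx, (img_ts, img, pose, vel) in enumerate(aligned_gt[start_idx:end_idx], start=start_idx):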
|
||||||
|
|
||||||
|
# Encode image as JPEG
|
||||||
|
_, img_encoded = cv2.imencode('.jpg', img,
|
||||||
|
[cv2.IMWRITE_JPEG_QUALITY, 85])
|
||||||
|
img_bytes = img_encoded.tobytes()
|
||||||
|
|
||||||
|
# Prepare metadata
|
||||||
|
sample_key = f'frame_{local_idx:08d}'
|
||||||
|
|
||||||
|
# Write to tar
|
||||||
|
sink.write({
|
||||||
|
'__key__': sample_key,
|
||||||
|
'jpg': img_bytes,
|
||||||
|
'ts': np.array([img_ts], dtype=np.float64).tobytes(),
|
||||||
|
'pose': pose.astype(np.float32).tobytes(),
|
||||||
|
'vel': vel.astype(np.float32).tobytes()
|
||||||
|
})
|
||||||
|
|
||||||
|
print(f" Saved {tar_path} ({end_idx - start_idx} samples)")
|
||||||
|
|
||||||
|
def save_metadata(self):
|
||||||
|
"""Save dataset metadata"""
|
||||||
|
metadata = {
|
||||||
|
'dataset_name': self.dataset_name,
|
||||||
|
'source_bag': str(self.bag_path),
|
||||||
|
'num_images': len(self.images),
|
||||||
|
'num_imu_messages': len(self.imu_timestamps),
|
||||||
|
'num_ground_truth': len(self.gt_timestamps),
|
||||||
|
'image_size': [self.image_width, self.image_height],
|
||||||
|
'imu_frequency_hz': len(self.imu_timestamps) / (self.imu_timestamps[-1] - self.imu_timestamps[0]) if len(self.imu_timestamps) > 1 else 0,
|
||||||
|
'camera_frequency_hz': len(self.images) / (self.images[-1][0] - self.images[0][0]) if len(self.images) > 1 else 0,
|
||||||
|
'gt_frequency_hz': len(self.gt_timestamps) / (self.gt_timestamps[-1] - self.gt_timestamps[0]) if len(self.gt_timestamps) > 1 else 0,
|
||||||
|
'coordinate_system': 'horizontal (z aligned with gravity, assumed from GT)',
|
||||||
|
'velocity_dimensions': 6, # (vx, vy, vz, wx, wy, wz)
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata_path = self.output_dir / 'metadata.json'
|
||||||
|
with open(metadata_path, 'w') as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Saved metadata: {metadata_path}")
|
||||||
|
|
||||||
|
def convert(self):
|
||||||
|
"""Main conversion pipeline"""
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"Converting: {self.bag_path.name}")
|
||||||
|
print(f"Output: {self.output_dir}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
# Step 1: Extract all data from bag
|
||||||
|
self.extract_all_data()
|
||||||
|
|
||||||
|
# Crop away the time spans that have no GT
|
||||||
|
self.crop_to_gt_time_range()
|
||||||
|
|
||||||
|
# Step 2: Save IMU sequence
|
||||||
|
self.save_imu_sequence()
|
||||||
|
|
||||||
|
# Step 3: Align ground truth to images
|
||||||
|
aligned_gt = self.align_ground_truth_to_images()
|
||||||
|
|
||||||
|
if len(aligned_gt) == 0:
|
||||||
|
print("Error: No aligned ground truth found!")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Step 4: Save as WebDataset
|
||||||
|
self.save_as_webdataset(aligned_gt)
|
||||||
|
|
||||||
|
# Step 5: Save metadata
|
||||||
|
self.save_metadata()
|
||||||
|
|
||||||
|
self.diagnose_timestamps()
|
||||||
|
|
||||||
|
print(f"\n✅ Conversion completed for {self.bag_path.name}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def diagnose_timestamps(self):
|
||||||
|
"""Print timestamp ranges for debugging"""
|
||||||
|
img_timestamps = [t for t, _ in self.images]
|
||||||
|
gt_timestamps = self.gt_timestamps
|
||||||
|
|
||||||
|
print(f"Image timestamps: {min(img_timestamps):.3f} - {max(img_timestamps):.3f}")
|
||||||
|
print(f"GT timestamps: {min(gt_timestamps):.3f} - {max(gt_timestamps):.3f}")
|
||||||
|
print(f"Image duration: {max(img_timestamps) - min(img_timestamps):.3f}s")
|
||||||
|
print(f"GT duration: {max(gt_timestamps) - min(gt_timestamps):.3f}s")
|
||||||
|
|
||||||
|
# Check if there's a constant offset
|
||||||
|
if len(img_timestamps) > 0 and len(gt_timestamps) > 0:
|
||||||
|
offset = gt_timestamps[0] - img_timestamps[0]
|
||||||
|
print(f"Initial offset (first GT - first image): {offset:.3f}s")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Convert ROS bag to WebDataset format')
|
||||||
|
parser.add_argument('--bag', type=str, required=True, help='Path to ROS bag file')
|
||||||
|
parser.add_argument('--output', type=str, default='./dataset', help='Output directory')
|
||||||
|
parser.add_argument('--name', type=str, default=None, help='Dataset name (default: bag filename without extension)')
|
||||||
|
parser.add_argument('--shard_size', type=int, default=2000, help='Number of samples per shard')
|
||||||
|
parser.add_argument('--width', type=int, default=320, help='Image width (resize)')
|
||||||
|
parser.add_argument('--height', type=int, default=240, help='Image height (resize)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Validate inputs
|
||||||
|
if not os.path.exists(args.bag):
|
||||||
|
print(f"Error: Bag file not found: {args.bag}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Set dataset name
|
||||||
|
if args.name is None:
|
||||||
|
args.name = Path(args.bag).stem
|
||||||
|
|
||||||
|
# Run conversion
|
||||||
|
converter = BagToWebDataset(
|
||||||
|
bag_path=args.bag,
|
||||||
|
output_dir=args.output,
|
||||||
|
dataset_name=args.name,
|
||||||
|
shard_size=args.shard_size,
|
||||||
|
image_width=args.width,
|
||||||
|
image_height=args.height
|
||||||
|
)
|
||||||
|
|
||||||
|
converter.convert()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
103
src/event_utils.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
"""
|
||||||
|
Event camera simulation utilities for ML preprocessing.
|
||||||
|
|
||||||
|
Core logic extracted from EventCameraSimulator (test.py).
|
||||||
|
Designed for frame-by-frame preprocessing in training pipelines.
|
||||||
|
|
||||||
|
Output:
|
||||||
|
events_binary: (-1, 0, +1) hard threshold decision
|
||||||
|
events_strength: [-1, 1] continuous change intensity (clipped & normalized)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class EventProcessor:
|
||||||
|
"""Lightweight event computation module. No visualization, no OpenCV dependency."""
|
||||||
|
|
||||||
|
def __init__(self, threshold=0.1, use_log=True, auto_threshold=False):
|
||||||
|
self.threshold = threshold
|
||||||
|
self.use_log = use_log
|
||||||
|
self.auto_threshold = auto_threshold
|
||||||
|
|
||||||
|
self.prev_brightness = None
|
||||||
|
self.change_history = deque(maxlen=100)
|
||||||
|
self.threshold_scale = 1.5
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Clear temporal state (call when starting a new video or sequence)."""
|
||||||
|
self.prev_brightness = None
|
||||||
|
self.change_history.clear()
|
||||||
|
|
||||||
|
def _to_grayscale(self, frame):
|
||||||
|
"""Convert frame to grayscale float32."""
|
||||||
|
if frame.ndim == 3:
|
||||||
|
# RGB/HWC -> gray via luminance weights
|
||||||
|
gray = 0.299 * frame[..., 0] + 0.587 * frame[..., 1] + 0.114 * frame[..., 2]
|
||||||
|
else:
|
||||||
|
gray = frame
|
||||||
|
return gray.astype(np.float32)
|
||||||
|
|
||||||
|
def _compute_change(self, brightness):
|
||||||
|
"""Compute log or linear brightness change."""
|
||||||
|
if self.use_log:
|
||||||
|
eps = 1e-3
|
||||||
|
return np.log(brightness + eps) - np.log(self.prev_brightness + eps)
|
||||||
|
else:
|
||||||
|
return brightness - self.prev_brightness
|
||||||
|
|
||||||
|
def _update_auto_threshold(self, change):
|
||||||
|
"""Adapt threshold based on global change statistics."""
|
||||||
|
abs_change = np.abs(change)
|
||||||
|
mean_change = np.mean(abs_change)
|
||||||
|
self.change_history.append(mean_change)
|
||||||
|
|
||||||
|
if len(self.change_history) > 10:
|
||||||
|
avg_change = np.mean(self.change_history)
|
||||||
|
new_threshold = max(avg_change * self.threshold_scale, 0.01)
|
||||||
|
self.threshold = self.threshold * 0.9 + new_threshold * 0.1
|
||||||
|
|
||||||
|
if self.use_log:
|
||||||
|
self.threshold = np.clip(self.threshold, 0.01, 0.5)
|
||||||
|
else:
|
||||||
|
self.threshold = np.clip(self.threshold, 1, 50)
|
||||||
|
|
||||||
|
def __call__(self, frame):
|
||||||
|
"""
|
||||||
|
Process a single frame.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame: np.ndarray, shape (H, W) or (H, W, C), uint8 or float.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
events_binary: np.ndarray (H, W), values in {-1, 0, +1}
|
||||||
|
events_strength: np.ndarray (H, W), values in [-1, 1]
|
||||||
|
event_count: int, number of non-zero events
|
||||||
|
"""
|
||||||
|
brightness = self._to_grayscale(frame)
|
||||||
|
|
||||||
|
# First frame — initialise, no events
|
||||||
|
if self.prev_brightness is None:
|
||||||
|
self.prev_brightness = brightness
|
||||||
|
h, w = brightness.shape
|
||||||
|
return np.zeros((h, w), dtype=np.int8), np.zeros((h, w), dtype=np.float32), 0
|
||||||
|
|
||||||
|
change = self._compute_change(brightness)
|
||||||
|
|
||||||
|
if self.auto_threshold:
|
||||||
|
self._update_auto_threshold(change)
|
||||||
|
|
||||||
|
# Binary events
|
||||||
|
events_binary = np.zeros_like(brightness, dtype=np.int8)
|
||||||
|
events_binary[change > self.threshold] = 1
|
||||||
|
events_binary[change < -self.threshold] = -1
|
||||||
|
|
||||||
|
# Continuous strength: clip to [-threshold, threshold] then normalise to [-1, 1]
|
||||||
|
events_strength = np.clip(change, -self.threshold, self.threshold) / self.threshold
|
||||||
|
|
||||||
|
event_count = int(np.count_nonzero(events_binary))
|
||||||
|
|
||||||
|
self.prev_brightness = brightness
|
||||||
|
|
||||||
|
return events_binary, events_strength, event_count
|
||||||
198
src/velocity_prediction/README.md
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# Velocity Prediction from Event Frames + Attitude
|
||||||
|
|
||||||
|
Predicts body-frame velocity (vx, vy) from simulated event frames plus attitude input, using the UZH-FPV dataset.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
uzh_fpv/
|
||||||
|
├── dataset/ # UZH-FPV dataset (WebDataset shards)
|
||||||
|
│ ├── indoor_forward_3/
|
||||||
|
│ ├── indoor_forward_5/
|
||||||
|
│ ├── ...
|
||||||
|
│ └── outdoor_45_1/
|
||||||
|
├── src/
|
||||||
|
│ ├── event_utils.py # simulated event-frame generation (existing module)
|
||||||
|
│ └── velocity_prediction/ # this project
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── config.py # global configuration
|
||||||
|
│ ├── utils.py # quaternion math, coordinate transforms
|
||||||
|
│ ├── transforms.py # data preprocessing pipeline
|
||||||
|
│ ├── dataset.py # WebDataset loading + sequence sampling
|
||||||
|
│ ├── model.py # network definition
|
||||||
|
│ ├── train.py # training entry point
|
||||||
|
│ └── evaluate.py # evaluation and visualization
|
||||||
|
├── DATASET_FORMAT.md # dataset format description
|
||||||
|
└── requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Network Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Event frame (1, 240, 320) ──► CNN (4 Conv+Pool layers, then GAP, 256-d)
|
||||||
|
│
|
||||||
|
Attitude tilt_angles (3,) ──► PoseMLP (3→32→64, 64-d) ───┤
|
||||||
|
│
|
||||||
|
concat (320-d) ← fused per frame
|
||||||
|
│
|
||||||
|
GRU (hidden=128)
|
||||||
|
│
|
||||||
|
Head MLP (128→64→2)
|
||||||
|
│
|
||||||
|
[vx_body, vy_body]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Per-Module Parameters
|
||||||
|
|
||||||
|
| Module | Parameters | Notes |
|
||||||
|
|------|--------|------|
|
||||||
|
| CNN Encoder | 387,840 | 4 × [Conv2D(3×3) + LeakyReLU + MaxPool(2×2)] + GAP, channels 1→32→64→128→256; BatchNorm is disabled in `model.py`, so no BN parameters are counted |
|
||||||
|
| PoseMLP | 2,240 | 3→32→64, two fully connected layers |
|
||||||
|
| GRU | 172,800 | single layer, input=320, hidden=128 |
|
||||||
|
| Head MLP | 8,386 | 128→64→2 |
|
||||||
|
| **Total** | **~571K** | about 2.3 MB in FP32 |
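The per-module counts in the table can be reproduced by hand. The sketch below is a minimal pure-Python check, not code from the repository: the helper names (`conv2d_params`, `linear_params`, `gru_params`) are illustrative. It assumes the standard PyTorch parameter layout (weights plus biases; for `nn.GRU`, three gates, each with input weights, hidden weights and two bias vectors) and counts no BatchNorm parameters.

```python
def conv2d_params(c_in, c_out, k=3):
    """Weights (c_out * c_in * k * k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out


def linear_params(n_in, n_out):
    """Weight matrix plus bias vector of a fully connected layer."""
    return n_in * n_out + n_out


def gru_params(n_in, hidden):
    """Single-layer GRU: 3 gates, each with W_ih, W_hh, b_ih and b_hh."""
    return 3 * (n_in * hidden + hidden * hidden + 2 * hidden)


channels = (1, 32, 64, 128, 256)
cnn = sum(conv2d_params(a, b) for a, b in zip(channels, channels[1:]))
pose_mlp = linear_params(3, 32) + linear_params(32, 64)
gru = gru_params(320, 128)
head = linear_params(128, 64) + linear_params(64, 2)
total = cnn + pose_mlp + gru + head

print(cnn, pose_mlp, gru, head, total)
print(f"FP32 size: {total * 4 / 1e6:.2f} MB")
```

Running it prints `387840 2240 172800 8386 571266`, which matches the table row by row, followed by an FP32 size of 2.29 MB (four bytes per parameter).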
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Preprocessing
|
||||||
|
|
||||||
|
### Input Transform Pipeline (transforms.py)
|
||||||
|
|
||||||
|
Each WebDataset sample passes through the following stages in order:
|
||||||
|
|
||||||
|
1. **DecodeSample**: decodes the JPEG into a grayscale image (H, W) and converts the pose/vel bytes to numpy arrays
|
||||||
|
2. **SimulateEvents**: `EventProcessor` computes the inter-frame brightness change and outputs a binary event frame (1, H, W) with values in {-1, 0, 1}
|
||||||
|
3. **ComputeTilt**: extracts the yaw angle from the quaternion `[qx, qy, qz, qw]` and removes it, leaving the tilt rotation vector (3,)
|
||||||
|
4. **ComputeBodyVelocity**: world-frame velocity `[vx, vy, vz]` → yaw compensation → body frame, keeping `[vx_body, vy_body]` (2,)
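The event-simulation stage can be illustrated with a small pure-Python sketch. It mirrors the log-brightness thresholding visible in `event_utils.py` (`eps = 1e-3`, positive events above the threshold, negative events below its negative), but the function name `simulate_events` and the nested-list frame representation are assumptions made for this example only; the real `EventProcessor` works on numpy arrays and keeps the previous frame as internal state.

```python
import math


def simulate_events(prev, curr, threshold=0.1, eps=1e-3):
    """Return a polarity map in {-1, 0, +1} from two grayscale frames.

    prev, curr: nested lists (rows of non-negative brightness values).
    """
    events = []
    for prev_row, curr_row in zip(prev, curr):
        row = []
        for p, c in zip(prev_row, curr_row):
            # Log-brightness change; eps avoids log(0) on black pixels.
            change = math.log(c + eps) - math.log(p + eps)
            if change > threshold:
                row.append(1)
            elif change < -threshold:
                row.append(-1)
            else:
                row.append(0)
        events.append(row)
    return events


prev_frame = [[100, 100], [100, 100]]
curr_frame = [[150, 100], [60, 104]]
print(simulate_events(prev_frame, curr_frame))
```

With the default threshold of 0.1 this prints `[[1, 0], [-1, 0]]`: the 4% brightening from 100 to 104 stays below the threshold and produces no event.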
|
||||||
|
|
||||||
|
### Coordinate Transform Logic (utils.py)
|
||||||
|
|
||||||
|
```
|
||||||
|
Input: q_world_to_body (quaternion), v_world (3,)
|
||||||
|
|
||||||
|
Step 1: decompose the yaw angle from the quaternion
|
||||||
|
Step 2: build the pure-yaw quaternion q_yaw
|
||||||
|
Step 3: q_tilt = q_yaw^{-1} * q_world_to_body → tilt_angles (rotation vector)
|
||||||
|
Step 4: v_yaw_comp = q_yaw^{-1} * v_world → yaw compensation
|
||||||
|
Step 5: v_body = q_tilt^{-1} * v_yaw_comp → into the body frame
|
||||||
|
Step 6: take v_body[:2] as the regression target
|
||||||
|
```
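The six steps can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the code in `utils.py`: the helper names are hypothetical, quaternions are stored as `(x, y, z, w)`, yaw is the Z angle of a ZYX Euler decomposition, and the input quaternion is treated as the rotation that maps body-frame vectors into the world frame. Under that convention Steps 3 to 5 compose to applying the inverse of the full rotation to `v_world`. The conversion of `q_tilt` to a rotation vector is left out.

```python
import math


def quat_mul(a, b):
    """Hamilton product of two quaternions stored as (x, y, z, w)."""
    ax, ay, az, aw = a
    bx, by, bz, bw = b
    return (
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
        aw * bw - ax * bx - ay * by - az * bz,
    )


def quat_inv(q):
    """Inverse of a unit quaternion (its conjugate)."""
    x, y, z, w = q
    return (-x, -y, -z, w)


def rotate(q, v):
    """Rotate 3-vector v by unit quaternion q."""
    x, y, z, _ = quat_mul(quat_mul(q, (v[0], v[1], v[2], 0.0)), quat_inv(q))
    return (x, y, z)


def world_vel_to_body_sketch(q, v_world):
    """Follow Steps 1-6; q is assumed to rotate body vectors into the world frame."""
    x, y, z, w = q
    # Step 1: yaw angle (rotation about the world Z axis).
    yaw = math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))
    # Step 2: pure-yaw quaternion.
    q_yaw = (0.0, 0.0, math.sin(yaw / 2.0), math.cos(yaw / 2.0))
    # Step 3: remove yaw, leaving the tilt rotation.
    q_tilt = quat_mul(quat_inv(q_yaw), q)
    # Step 4: yaw-compensate the world velocity.
    v_yaw_comp = rotate(quat_inv(q_yaw), v_world)
    # Step 5: rotate into the body frame.
    v_body = rotate(quat_inv(q_tilt), v_yaw_comp)
    # Step 6: keep the horizontal components as the regression target.
    return q_tilt, v_body[:2]


# Pure 90-degree yaw: the tilt part should be (close to) the identity quaternion.
q = (0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4))
q_tilt, v_xy = world_vel_to_body_sketch(q, (1.0, 0.0, 0.0))
print(q_tilt, v_xy)
```

For the pure-yaw example, a world velocity along +x comes out as approximately `(0, -1)` in the body frame, and `q_tilt` is the identity up to floating-point error.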
|
||||||
|
|
||||||
|
### Sequence Sampling (dataset.py)
|
||||||
|
|
||||||
|
- Takes `seq_len` consecutive frames (default 8) from a shard to form one training sample
|
||||||
|
- Batch shapes: `(B, S, 1, H, W)` event frames, `(B, S, 3)` tilt, `(B, S, 2)` velocity GT
|
||||||
|
- The model predicts the velocity of the last frame in the sequence
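The grouping step can be pictured as cutting an ordered frame stream into non-overlapping windows, which is what `batched(seq_len, partial=False)` does in `dataset.py`. The sketch below is a plain-Python illustration of that windowing; the function name `group_into_sequences` is hypothetical, and it assumes the frames are still in temporal order when they are grouped.

```python
def group_into_sequences(frames, seq_len=8):
    """Split an ordered frame list into consecutive windows of seq_len frames.

    A trailing remainder shorter than seq_len is dropped.
    """
    n_full = len(frames) // seq_len
    return [frames[i * seq_len:(i + 1) * seq_len] for i in range(n_full)]


frame_ids = list(range(20))
sequences = group_into_sequences(frame_ids, seq_len=8)
print(len(sequences), sequences[-1])
```

Twenty frames yield two sequences (frames 0-7 and 8-15); the last four frames are discarded because they cannot fill a full window.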
|
||||||
|
|
||||||
|
### Dataset Split
|
||||||
|
|
||||||
|
| Split | Scenes |
|
||||||
|
|----|------|
|
||||||
|
| **Train** | indoor_forward_3/5/6/9/10, outdoor_forward_3 (active `TRAIN_SCENES` in `config.py`) |
|
||||||
|
| **Validation** | indoor_forward_7, outdoor_forward_1 (active `VAL_SCENES` in `config.py`) |
|
||||||
|
| **Test** | indoor_forward_7, outdoor_forward_1, outdoor_forward_5 (active `TEST_SCENES` in `config.py`) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### 1. Install Dependencies
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Using uv (recommended)
|
||||||
|
uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
uv pip install webdataset opencv-python matplotlib tensorboard numpy scipy
|
||||||
|
|
||||||
|
# Or using pip
|
||||||
|
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install webdataset opencv-python matplotlib tensorboard numpy scipy
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Training
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run from the project root
|
||||||
|
cd /home/hexone/Workplace/ws_asmo/uzh_fpv
|
||||||
|
|
||||||
|
# After activating the virtual environment
|
||||||
|
python -m src.velocity_prediction.train
|
||||||
|
```
|
||||||
|
|
||||||
|
Training parameters are configured in `TrainConfig` in `config.py`. Key parameters:
|
||||||
|
|
||||||
|
| Parameter | Default | Description |
|
||||||
|
|------|--------|------|
|
||||||
|
| seq_len | 8 | frames per sequence |
|
||||||
|
| batch_size | 32 | batch size |
|
||||||
|
| epochs | 100 | number of training epochs |
|
||||||
|
| lr | 1e-3 | learning rate |
|
||||||
|
| event_threshold | 0.1 | event simulation threshold |
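`TrainConfig` also defines `lr_scheduler_step = 30` and `lr_scheduler_gamma = 0.5`. Assuming these drive a step-decay schedule (the scheduler code in `train.py` is not shown here, so this is an assumption), the learning rate would halve every 30 epochs. The helper `step_decay_lr` below is a hypothetical illustration of that arithmetic, not a function from the repository.

```python
def step_decay_lr(epoch, base_lr=1e-3, step=30, gamma=0.5):
    """Learning rate after `epoch` completed epochs under step decay."""
    return base_lr * gamma ** (epoch // step)


for epoch in (0, 29, 30, 60, 99):
    print(epoch, step_decay_lr(epoch))
```

Under this assumption the rate stays at 1e-3 for epochs 0-29, drops to 5e-4 at epoch 30, and ends the default 100-epoch run at 1.25e-4.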
|
||||||
|
|
||||||
|
Training outputs:
|
||||||
|
- `logs/`: TensorBoard logs
|
||||||
|
- `checkpoints/`: model checkpoints (saved every 10 epochs, plus the best model `best.pt`)
|
||||||
|
|
||||||
|
### 3. Evaluation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- RMSE (vx, vy, xy) printed to the console
|
||||||
|
- `eval_velocity.png`: predicted vs. GT time-series comparison
|
||||||
|
- `eval_scatter.png`: scatter plot
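The three RMSE values follow the formulas used in `evaluate.py`: a per-axis RMSE for vx and vy, and an overall RMSE taken over the squared Euclidean error of each sample. The sketch below restates them in plain Python on toy data; the function name `rmse_metrics` and the list-of-tuples input are assumptions for the example, whereas the real code works on numpy arrays after denormalizing to m/s.

```python
import math


def rmse_metrics(preds, targets):
    """preds/targets: lists of (vx, vy) pairs in m/s.

    Returns (rmse_x, rmse_y, rmse_xy).
    """
    n = len(preds)
    sq_x = sum((p[0] - t[0]) ** 2 for p, t in zip(preds, targets))
    sq_y = sum((p[1] - t[1]) ** 2 for p, t in zip(preds, targets))
    rmse_x = math.sqrt(sq_x / n)
    rmse_y = math.sqrt(sq_y / n)
    # Overall RMSE uses the squared Euclidean error per sample.
    rmse_xy = math.sqrt((sq_x + sq_y) / n)
    return rmse_x, rmse_y, rmse_xy


preds = [(1.0, 0.0), (0.0, 2.0)]
targets = [(0.0, 0.0), (0.0, 0.0)]
rmse_x, rmse_y, rmse_xy = rmse_metrics(preds, targets)
print(f"{rmse_x:.4f} {rmse_y:.4f} {rmse_xy:.4f}")
```

This prints `0.7071 1.4142 1.5811`. The combined value always satisfies `rmse_xy**2 == rmse_x**2 + rmse_y**2`, so it is never smaller than either per-axis figure.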
|
||||||
|
|
||||||
|
### 4. Checking the Parameter Count
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.velocity_prediction.model
|
||||||
|
```
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
```
|
||||||
|
Total trainable parameters: 571,266 (0.571 M)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model Export and Deployment
|
||||||
|
|
||||||
|
### Exporting to TorchScript
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from src.velocity_prediction.model import VelocityPredictionModel
|
||||||
|
|
||||||
|
model = VelocityPredictionModel()
|
||||||
|
model.load_state_dict(torch.load("checkpoints/best.pt", map_location="cpu")["model_state_dict"])
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Export by tracing with a fixed-shape example input (batch 1, 8 frames)
|
||||||
|
traced = torch.jit.trace(model, (torch.randn(1, 8, 1, 240, 320), torch.randn(1, 8, 3)))
|
||||||
|
traced.save("velocity_model.pt")
|
||||||
|
```
|
||||||
|
|
||||||
|
### RV1106 Deployment Notes
|
||||||
|
|
||||||
|
- The model has ~0.57M parameters: ~2.3 MB in FP32, ~0.6 MB after INT8 quantization
|
||||||
|
- The CNN can run on the NPU (0.5 TOPS); the GRU has to run on the ARM CPU
|
||||||
|
- To shrink the model: halve `CNNConfig.channels` in `config.py` (32→16, 64→32, 128→64, 256→128), or reduce `GRUConfig.hidden_size` from 128 to 64
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## File Responsibilities at a Glance
|
||||||
|
|
||||||
|
| File | Responsibility | Key classes/functions |
|
||||||
|
|------|------|------------|
|
||||||
|
| `config.py` | all configurable parameters | `ModelConfig`, `TrainConfig`, `TRAIN_SCENES` |
|
||||||
|
| `utils.py` | quaternion math, coordinate transforms | `decompose_tilt`, `world_vel_to_body` |
|
||||||
|
| `transforms.py` | data preprocessing pipeline | `DecodeSample`, `SimulateEvents`, `ComputeTilt`, `ComputeBodyVelocity` |
|
||||||
|
| `dataset.py` | WebDataset loading + sequence sampling | `create_train_loader`, `create_val_loader` |
|
||||||
|
| `model.py` | network model | `VelocityPredictionModel`, `CNNEncoder`, `PoseMLP` |
|
||||||
|
| `train.py` | training loop | `train_one_epoch`, `validate`, `main` |
|
||||||
|
| `evaluate.py` | evaluation and visualization | `evaluate`, `plot_results`, `plot_scatter` |
|
||||||
7
src/velocity_prediction/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
Velocity prediction from simulated event frames + attitude.
|
||||||
|
|
||||||
|
Pipeline:
|
||||||
|
Event frame (1, H, W) ──► CNN ──┐
|
||||||
|
Tilt angles (3,) ──► MLP ──┤──► concat ──► GRU ──► Head ──► [vx_body, vy_body]
|
||||||
|
"""
|
||||||
118
src/velocity_prediction/config.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
Global configuration for velocity prediction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────── Dataset paths ────────────────────────────
|
||||||
|
|
||||||
|
DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset"
|
||||||
|
|
||||||
|
# Velocity normalization stats (computed from training set)
|
||||||
|
VELOCITY_MEAN = [0.859184, -0.783945] # [vx, vy]
|
||||||
|
VELOCITY_STD = [2.244513, 1.088335] # [vx, vy]
|
||||||
|
|
||||||
|
# TRAIN_SCENES = [
|
||||||
|
# "indoor_forward_3", "indoor_forward_5", "indoor_forward_6",
|
||||||
|
# "indoor_forward_7", "indoor_forward_9", "indoor_forward_10",
|
||||||
|
# "indoor_45_2", "indoor_45_4", "indoor_45_9", "indoor_45_12",
|
||||||
|
# ]
|
||||||
|
# VAL_SCENES = [
|
||||||
|
# "indoor_45_13", "indoor_45_14",
|
||||||
|
# ]
|
||||||
|
# TEST_SCENES = [
|
||||||
|
# "outdoor_forward_1", "outdoor_forward_3", "outdoor_forward_5",
|
||||||
|
# "outdoor_45_1",
|
||||||
|
# ]
|
||||||
|
|
||||||
|
TRAIN_SCENES = [
|
||||||
|
"indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||||
|
"indoor_forward_5", "indoor_forward_6", # Medium
|
||||||
|
"outdoor_forward_3" # Medium, outdoor
|
||||||
|
]
|
||||||
|
VAL_SCENES = [
|
||||||
|
"indoor_forward_7", # Hard, indoor
|
||||||
|
"outdoor_forward_1" # Easy, outdoor
|
||||||
|
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||||
|
]
|
||||||
|
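# NOTE: indoor_forward_7 and outdoor_forward_1 are also in VAL_SCENES, so test metrics on
# those scenes are not independent of any model selection done on the validation set.
TEST_SCENES = [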
|
||||||
|
"indoor_forward_7", # Hard, indoor
|
||||||
|
"outdoor_forward_1", # Easy, outdoor
|
||||||
|
"outdoor_forward_5" # Hard, outdoor
|
||||||
|
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────── Model architecture ────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CNNConfig:
|
||||||
|
in_channels: int = 1
|
||||||
|
channels: tuple = (32, 64, 128, 256) # per-layer output channels
|
||||||
|
kernel_size: int = 3
|
||||||
|
pool_size: int = 2
|
||||||
|
use_bn: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PoseMLPConfig:
|
||||||
|
input_dim: int = 3
|
||||||
|
hidden_dim: int = 32
|
||||||
|
output_dim: int = 64
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GRUConfig:
|
||||||
|
input_size: int = 320 # CNN(256) + PoseMLP(64)
|
||||||
|
hidden_size: int = 128
|
||||||
|
num_layers: int = 1
|
||||||
|
dropout: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HeadConfig:
|
||||||
|
input_dim: int = 128 # GRU hidden_size
|
||||||
|
hidden_dim: int = 64
|
||||||
|
output_dim: int = 2 # [vx_body, vy_body]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
cnn: CNNConfig = field(default_factory=CNNConfig)
|
||||||
|
pose_mlp: PoseMLPConfig = field(default_factory=PoseMLPConfig)
|
||||||
|
gru: GRUConfig = field(default_factory=GRUConfig)
|
||||||
|
head: HeadConfig = field(default_factory=HeadConfig)
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────── Training ────────────────────────────
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainConfig:
|
||||||
|
seq_len: int = 8 # frames per training sequence
|
||||||
|
batch_size: int = 32
|
||||||
|
epochs: int = 100
|
||||||
|
lr: float = 1e-3
|
||||||
|
weight_decay: float = 1e-5
|
||||||
|
lr_scheduler_step: int = 30
|
||||||
|
lr_scheduler_gamma: float = 0.5
|
||||||
|
num_workers: int = 4
|
||||||
|
seed: int = 42
|
||||||
|
|
||||||
|
# Event simulation
|
||||||
|
event_threshold: float = 0.1
|
||||||
|
event_use_log: bool = True
|
||||||
|
event_auto_threshold: bool = False
|
||||||
|
|
||||||
|
# Logging / checkpoint
|
||||||
|
log_dir: str = "logs"
|
||||||
|
checkpoint_dir: str = "checkpoints"
|
||||||
|
log_interval: int = 10
|
||||||
|
save_interval: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────── Singleton instances ────────────────────────────
|
||||||
|
|
||||||
|
model_cfg = ModelConfig()
|
||||||
|
train_cfg = TrainConfig()
|
||||||
120
src/velocity_prediction/dataset.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""
|
||||||
|
WebDataset-based dataset with sequence sampling.
|
||||||
|
|
||||||
|
Each sample from the dataset is a dict:
|
||||||
|
{
|
||||||
|
"events": np.ndarray (1, H, W), # simulated event frame
|
||||||
|
"tilt": np.ndarray (3,), # tilt rotation vector
|
||||||
|
"v_body_target": np.ndarray (2,), # body-frame [vx, vy]
|
||||||
|
}
|
||||||
|
|
||||||
|
The dataset groups consecutive frames into sequences of length seq_len.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import webdataset as wds
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Callable, Optional
|
||||||
|
|
||||||
|
from src.velocity_prediction.config import DATASET_ROOT, train_cfg
|
||||||
|
from src.velocity_prediction.transforms import build_train_transform, build_val_transform
|
||||||
|
|
||||||
|
|
||||||
|
def _scene_urls(scene_names: List[str], root: Path = DATASET_ROOT) -> List[str]:
|
||||||
|
"""Build list of actual shard file paths for the given scene names."""
|
||||||
|
urls = []
|
||||||
|
for name in scene_names:
|
||||||
|
scene_dir = root / name
|
||||||
|
if not scene_dir.exists():
|
||||||
|
print(f"Warning: scene directory not found: {scene_dir}")
|
||||||
|
continue
|
||||||
|
# List all shard_*.tar files in the scene directory
|
||||||
|
shard_files = sorted(scene_dir.glob("shard_*.tar"))
|
||||||
|
if not shard_files:
|
||||||
|
print(f"Warning: no shard files found in {scene_dir}")
|
||||||
|
continue
|
||||||
|
urls.extend(str(f) for f in shard_files)
|
||||||
|
return urls
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pipeline(
|
||||||
|
urls: List[str],
|
||||||
|
transform: Callable,
|
||||||
|
seq_len: int,
|
||||||
|
shuffle: int = 1000,
|
||||||
|
deterministic: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Build a WebDataset pipeline that:
|
||||||
|
1. Reads tar shards
|
||||||
|
2. Decodes and transforms individual samples
|
||||||
|
3. Groups consecutive samples into sequences
|
||||||
|
"""
|
||||||
|
dataset = wds.WebDataset(urls, shardshuffle=shuffle if not deterministic else 0, empty_check=False)
|
||||||
|
|
||||||
|
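# NOTE: this shuffles individual frames before they are batched into sequences below,
# so outside deterministic mode a sequence is not made of temporally consecutive frames.
if not deterministic: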
|
||||||
|
dataset = dataset.shuffle(shuffle)
|
||||||
|
|
||||||
|
dataset = dataset.decode().map(transform)
|
||||||
|
|
||||||
|
# Group into sequences of seq_len consecutive frames
|
||||||
|
dataset = dataset.batched(seq_len, partial=False)
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def create_train_loader(
|
||||||
|
scene_names: Optional[List[str]] = None,
|
||||||
|
seq_len: int = 8,
|
||||||
|
batch_size: int = 32,
|
||||||
|
num_workers: int = 4,
|
||||||
|
event_threshold: float = 0.1,
|
||||||
|
event_use_log: bool = True,
|
||||||
|
):
|
||||||
|
"""Create a DataLoader for training."""
|
||||||
|
if scene_names is None:
|
||||||
|
from src.velocity_prediction.config import TRAIN_SCENES
|
||||||
|
scene_names = TRAIN_SCENES
|
||||||
|
|
||||||
|
urls = _scene_urls(scene_names)
|
||||||
|
transform = build_train_transform(
|
||||||
|
event_threshold=event_threshold,
|
||||||
|
event_use_log=event_use_log,
|
||||||
|
)
|
||||||
|
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=1000)
|
||||||
|
|
||||||
|
loader = wds.WebLoader(
|
||||||
|
pipeline,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=num_workers,
|
||||||
|
shuffle=False, # already shuffled in pipeline
|
||||||
|
)
|
||||||
|
return loader
|
||||||
|
|
||||||
|
|
||||||
|
def create_val_loader(
|
||||||
|
scene_names: Optional[List[str]] = None,
|
||||||
|
seq_len: int = 8,
|
||||||
|
batch_size: int = 32,
|
||||||
|
num_workers: int = 4,
|
||||||
|
event_threshold: float = 0.1,
|
||||||
|
event_use_log: bool = True,
|
||||||
|
):
|
||||||
|
"""Create a DataLoader for validation (deterministic order)."""
|
||||||
|
if scene_names is None:
|
||||||
|
from src.velocity_prediction.config import VAL_SCENES
|
||||||
|
scene_names = VAL_SCENES
|
||||||
|
|
||||||
|
urls = _scene_urls(scene_names)
|
||||||
|
transform = build_val_transform(
|
||||||
|
event_threshold=event_threshold,
|
||||||
|
event_use_log=event_use_log,
|
||||||
|
)
|
||||||
|
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=0, deterministic=True)
|
||||||
|
|
||||||
|
loader = wds.WebLoader(
|
||||||
|
pipeline,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=num_workers,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
return loader
|
||||||
212
src/velocity_prediction/evaluate.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""
|
||||||
|
Evaluation and visualization for VelocityPredictionModel.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from src.velocity_prediction.model import VelocityPredictionModel
|
||||||
|
from src.velocity_prediction.dataset import create_val_loader
|
||||||
|
from src.velocity_prediction.config import train_cfg, VELOCITY_MEAN, VELOCITY_STD
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def evaluate(
|
||||||
|
model: nn.Module,
|
||||||
|
loader,
|
||||||
|
device: torch.device,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Run evaluation on a dataloader.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
preds: np.ndarray (N, 2) predicted [vx, vy], denormalized to m/s
|
||||||
|
targets: np.ndarray (N, 2) ground truth [vx, vy], denormalized to m/s
rmse_x, rmse_y, rmse_xy: per-axis and overall RMSE in m/s
|
||||||
|
"""
|
||||||
|
model.eval()
|
||||||
|
all_preds = []
|
||||||
|
all_targets = []
|
||||||
|
|
||||||
|
for batch in loader:
|
||||||
|
events = batch["events"].to(device)
|
||||||
|
tilt = batch["tilt"].to(device)
|
||||||
|
target = batch["v_body_target"].to(device) # (B, S, 2)
|
||||||
|
|
||||||
|
pred = model(events, tilt) # (B, 2)
|
||||||
|
target_last = target[:, -1, :] # (B, 2)
|
||||||
|
|
||||||
|
all_preds.append(pred.cpu().numpy())
|
||||||
|
all_targets.append(target_last.cpu().numpy())
|
||||||
|
|
||||||
|
preds = np.concatenate(all_preds, axis=0)
|
||||||
|
targets = np.concatenate(all_targets, axis=0)
|
||||||
|
|
||||||
|
# Denormalize predictions and targets back to m/s (both are assumed to be in normalized space here)
|
||||||
|
mean = np.array(VELOCITY_MEAN, dtype=np.float32)
|
||||||
|
std = np.array(VELOCITY_STD, dtype=np.float32)
|
||||||
|
preds_denorm = preds * std + mean
|
||||||
|
targets_denorm = targets * std + mean
|
||||||
|
|
||||||
|
# ── Diagnostics (in normalized space) ────────────────────────
|
||||||
|
print("\n========== Evaluation Diagnostics (normalized space) ==========")
|
||||||
|
print(f"Total samples: {len(preds)}")
|
||||||
|
print(f"\n--- Targets (normalized) ---")
|
||||||
|
print(f" vx: mean={targets[:, 0].mean():.6f}, std={targets[:, 0].std():.6f}")
|
||||||
|
print(f" vy: mean={targets[:, 1].mean():.6f}, std={targets[:, 1].std():.6f}")
|
||||||
|
print(f"\n--- Predictions (normalized) ---")
|
||||||
|
print(f" vx: mean={preds[:, 0].mean():.6f}, std={preds[:, 0].std():.6f}, "
|
||||||
|
f"min={preds[:, 0].min():.6f}, max={preds[:, 0].max():.6f}")
|
||||||
|
print(f" vy: mean={preds[:, 1].mean():.6f}, std={preds[:, 1].std():.6f}, "
|
||||||
|
f"min={preds[:, 1].min():.6f}, max={preds[:, 1].max():.6f}")
|
||||||
|
print(f"\n--- Unique prediction values ---")
|
||||||
|
print(f" vx unique: {len(np.unique(preds[:, 0]))} / {len(preds)}")
|
||||||
|
print(f" vy unique: {len(np.unique(preds[:, 1]))} / {len(preds)}")
|
||||||
|
vx_range = preds[:, 0].max() - preds[:, 0].min()
|
||||||
|
vy_range = preds[:, 1].max() - preds[:, 1].min()
|
||||||
|
print(f"\n vx range: {vx_range:.8f} (constant if near 0)")
|
||||||
|
print(f" vy range: {vy_range:.8f} (constant if near 0)")
|
||||||
|
print(f"\n--- Constant prediction check ---")
|
||||||
|
print(f" pred vx mean ≈ 0? {abs(preds[:, 0].mean()):.6f} diff from zero")
|
||||||
|
print(f" pred vy mean ≈ 0? {abs(preds[:, 1].mean()):.6f} diff from zero")
|
||||||
|
print("=============================================\n")
|
||||||
|
|
||||||
|
# Per-axis and overall RMSE (in original velocity space)
|
||||||
|
rmse_x = np.sqrt(np.mean((preds_denorm[:, 0] - targets_denorm[:, 0]) ** 2))
|
||||||
|
rmse_y = np.sqrt(np.mean((preds_denorm[:, 1] - targets_denorm[:, 1]) ** 2))
|
||||||
|
rmse_xy = np.sqrt(np.mean(np.sum((preds_denorm - targets_denorm) ** 2, axis=1)))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"preds": preds_denorm, # denormalized for plotting
|
||||||
|
"targets": targets_denorm, # denormalized for plotting
|
||||||
|
"rmse_x": rmse_x,
|
||||||
|
"rmse_y": rmse_y,
|
||||||
|
"rmse_xy": rmse_xy,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def plot_results(
|
||||||
|
preds: np.ndarray,
|
||||||
|
targets: np.ndarray,
|
||||||
|
save_path: str = "eval_plot.png",
|
||||||
|
):
|
||||||
|
"""Plot predicted vs ground truth velocity traces."""
|
||||||
|
fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=True)
|
||||||
|
|
||||||
|
time = np.arange(len(preds))
|
||||||
|
|
||||||
|
axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.8)
|
||||||
|
axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.8)
|
||||||
|
axes[0].set_ylabel("vx (m/s)")
|
||||||
|
axes[0].legend()
|
||||||
|
axes[0].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
axes[1].plot(time, targets[:, 1], label="GT vy", color="C0", alpha=0.8)
|
||||||
|
axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.8)
|
||||||
|
axes[1].set_ylabel("vy (m/s)")
|
||||||
|
axes[1].set_xlabel("Frame index")
|
||||||
|
axes[1].legend()
|
||||||
|
axes[1].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
fig.suptitle("Body-frame Velocity Prediction")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(save_path, dpi=150)
|
||||||
|
print(f"Plot saved: {save_path}")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_scatter(
|
||||||
|
preds: np.ndarray,
|
||||||
|
targets: np.ndarray,
|
||||||
|
save_path: str = "eval_scatter.png",
|
||||||
|
):
|
||||||
|
"""Scatter plot: predicted vs ground truth."""
|
||||||
|
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
|
||||||
|
|
||||||
|
for ax, pred, target, label in zip(
|
||||||
|
axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"]
|
||||||
|
):
|
||||||
|
ax.scatter(target, pred, s=2, alpha=0.5)
|
||||||
|
lim_min = min(target.min(), pred.min())
|
||||||
|
lim_max = max(target.max(), pred.max())
|
||||||
|
ax.plot([lim_min, lim_max], [lim_min, lim_max], "r--", alpha=0.5)
|
||||||
|
ax.set_xlabel(f"GT {label} (m/s)")
|
||||||
|
ax.set_ylabel(f"Pred {label} (m/s)")
|
||||||
|
ax.set_aspect("equal")
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_title(f"{label} — RMSE: {np.sqrt(np.mean((pred - target)**2)):.4f}")
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(save_path, dpi=150)
|
||||||
|
print(f"Scatter saved: {save_path}")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt",
|
||||||
|
help="Path to model checkpoint")
|
||||||
|
parser.add_argument("--device", type=str, default="cuda",
|
||||||
|
help="Device to use (e.g. 'cuda:2', 'cpu')")
|
||||||
|
parser.add_argument("--plot", action=argparse.BooleanOptionalAction, default=True,
|
||||||
|
help="Generate evaluation plots")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
|
||||||
|
print(f"Device: {device}")
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model = VelocityPredictionModel()
|
||||||
|
ckpt = torch.load(args.checkpoint, map_location="cpu")
|
||||||
|
model.load_state_dict(ckpt["model_state_dict"])
|
||||||
|
model.to(device)
|
||||||
|
print(f"Loaded checkpoint from {args.checkpoint} (epoch={ckpt.get('epoch', '?')})")
|
||||||
|
|
||||||
|
# Deterministic loader over the test scenes (reuses create_val_loader for the final eval)
|
||||||
|
from src.velocity_prediction.config import TEST_SCENES
|
||||||
|
loader = create_val_loader(
|
||||||
|
scene_names=TEST_SCENES,
|
||||||
|
seq_len=train_cfg.seq_len,
|
||||||
|
batch_size=train_cfg.batch_size,
|
||||||
|
num_workers=2,
|
||||||
|
event_threshold=train_cfg.event_threshold,
|
||||||
|
event_use_log=train_cfg.event_use_log,
|
||||||
|
)
|
||||||
|
|
||||||
|
# # ── Quick event diagnostics: inspect one batch ───────────────
|
||||||
|
# print("\n========== Event Frame Diagnostics ==========")
|
||||||
|
# sample_batch = next(iter(loader))
|
||||||
|
# ev = sample_batch["events"] # (B, S, 1, H, W)
|
||||||
|
# print(f"Events shape: {ev.shape}")
|
||||||
|
# print(f"Events dtype: {ev.dtype}")
|
||||||
|
# print(f"Events value counts: -1: {(ev == -1).sum().item()}, "
|
||||||
|
# f"0: {(ev == 0).sum().item()}, +1: {(ev == 1).sum().item()}")
|
||||||
|
# total_el = ev.numel()
|
||||||
|
# nonzero = (ev != 0).sum().item()
|
||||||
|
# print(f"Non-zero ratio: {nonzero / total_el:.6f} ({nonzero}/{total_el})")
|
||||||
|
# print(f"Per-sample non-zero: {[(ev[b] != 0).sum().item() for b in range(min(4, ev.shape[0]))]}")
|
||||||
|
# print("=============================================\n")
|
||||||
|
|
||||||
|
# Evaluate
|
||||||
|
results = evaluate(model, loader, device)
|
||||||
|
print(f"\nEvaluation results on test scenes: {TEST_SCENES}")
|
||||||
|
print(f" RMSE vx: {results['rmse_x']:.4f} m/s")
|
||||||
|
print(f" RMSE vy: {results['rmse_y']:.4f} m/s")
|
||||||
|
print(f" RMSE xy: {results['rmse_xy']:.4f} m/s")
|
||||||
|
|
||||||
|
# Plots
|
||||||
|
if args.plot:
|
||||||
|
plot_results(results["preds"], results["targets"], "eval_velocity.png")
|
||||||
|
plot_scatter(results["preds"], results["targets"], "eval_scatter.png")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
170
src/velocity_prediction/model.py
Normal file
@@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
VelocityPredictionModel: CNN + PoseMLP → concat → GRU → Head → [vx_body, vy_body].
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
Event frame (1, H, W) ──► CNN ──┐
|
||||||
|
Tilt angles (3,) ──► MLP ──┤──► concat ──► GRU ──► Head ──► [vx, vy]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from src.velocity_prediction.config import model_cfg
|
||||||
|
|
||||||
|
|
||||||
|
class CNNEncoder(nn.Module):
|
||||||
|
"""
|
||||||
|
4-layer ConvNet (Conv, LeakyReLU, MaxPool per layer), ending with Global Avg Pool.
BatchNorm is currently disabled, so cfg.use_bn has no effect.
|
||||||
|
|
||||||
|
Input: (B, S, 1, H, W) — processed per-frame (flattened to (B*S, 1, H, W))
|
||||||
|
Output: (B, S, C_out) — per-frame feature vectors
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg=model_cfg.cnn):
|
||||||
|
super().__init__()
|
||||||
|
channels = cfg.channels
|
||||||
|
in_ch = cfg.in_channels
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
for out_ch in channels:
|
||||||
|
layers.extend([
|
||||||
|
nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2),
|
||||||
|
# nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(),
|
||||||
|
nn.Identity(),
|
||||||
|
nn.LeakyReLU(inplace=True),
|
||||||
|
nn.MaxPool2d(cfg.pool_size),
|
||||||
|
])
|
||||||
|
in_ch = out_ch
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(*layers)
|
||||||
|
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||||
|
self.out_dim = channels[-1]
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (B, S, 1, H, W) event frame sequence
|
||||||
|
Returns:
|
||||||
|
features: (B, S, C_out)
|
||||||
|
"""
|
||||||
|
B, S, C, H, W = x.shape
|
||||||
|
x = x.reshape(B * S, C, H, W) # (B*S, 1, H, W); reshape also handles non-contiguous input
|
||||||
|
x = self.conv(x) # (B*S, C_out, H', W')
|
||||||
|
x = self.gap(x) # (B*S, C_out, 1, 1)
|
||||||
|
x = x.view(B, S, self.out_dim) # (B, S, C_out)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PoseMLP(nn.Module):
|
||||||
|
"""
|
||||||
|
Encode tilt rotation vector (3,) into a compact feature vector.
|
||||||
|
|
||||||
|
Input: (B, S, 3)
|
||||||
|
Output: (B, S, output_dim)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg=model_cfg.pose_mlp):
|
||||||
|
super().__init__()
|
||||||
|
self.net = nn.Sequential(
|
||||||
|
nn.Linear(cfg.input_dim, cfg.hidden_dim),
|
||||||
|
nn.LeakyReLU(inplace=True),
|
||||||
|
nn.Linear(cfg.hidden_dim, cfg.output_dim),
|
||||||
|
nn.LeakyReLU(inplace=True),
|
||||||
|
)
|
||||||
|
self.out_dim = cfg.output_dim
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: (B, S, 3) → (B, S, output_dim)
|
||||||
|
"""
|
||||||
|
B, S, D = x.shape
|
||||||
|
x = x.view(B * S, D)
|
||||||
|
x = self.net(x)
|
||||||
|
x = x.view(B, S, self.out_dim)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VelocityPredictionModel(nn.Module):
|
||||||
|
"""
|
||||||
|
Full model: CNN + PoseMLP → concat → GRU → Head → [vx, vy].
|
||||||
|
|
||||||
|
Input:
|
||||||
|
events: (B, S, 1, H, W)
|
||||||
|
tilt: (B, S, 3)
|
||||||
|
Output:
|
||||||
|
v_body: (B, 2) — body-frame [vx, vy] for the last frame in the sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cnn_cfg=model_cfg.cnn, pose_cfg=model_cfg.pose_mlp,
|
||||||
|
gru_cfg=model_cfg.gru, head_cfg=model_cfg.head):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cnn = CNNEncoder(cnn_cfg)
|
||||||
|
self.pose_mlp = PoseMLP(pose_cfg)
|
||||||
|
|
||||||
|
fused_dim = self.cnn.out_dim + self.pose_mlp.out_dim # 256 + 64 = 320
|
||||||
|
|
||||||
|
self.gru = nn.GRU(
|
||||||
|
input_size=fused_dim,
|
||||||
|
hidden_size=gru_cfg.hidden_size,
|
||||||
|
num_layers=gru_cfg.num_layers,
|
||||||
|
dropout=gru_cfg.dropout if gru_cfg.num_layers > 1 else 0.0,
|
||||||
|
batch_first=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.head = nn.Sequential(
|
||||||
|
nn.Linear(gru_cfg.hidden_size, head_cfg.hidden_dim),
|
||||||
|
nn.LeakyReLU(inplace=True),
|
||||||
|
nn.Linear(head_cfg.hidden_dim, head_cfg.output_dim),
|
||||||
|
)
|
||||||
|
|
||||||
|
# # Small init for the final layer: start from near-zero output
|
||||||
|
# self.head[-1].weight.data.mul_(0.01)
|
||||||
|
# self.head[-1].bias.data.zero_()
|
||||||
|
|
||||||
|
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
events: (B, S, 1, H, W)
|
||||||
|
tilt: (B, S, 3)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
v_body: (B, 2) predicted body-frame [vx, vy] at the last timestep
|
||||||
|
"""
|
||||||
|
# Per-frame encoding
|
||||||
|
cnn_feat = self.cnn(events) # (B, S, 256)
|
||||||
|
pose_feat = self.pose_mlp(tilt) # (B, S, 64)
|
||||||
|
|
||||||
|
# Fuse per frame
|
||||||
|
fused = torch.cat([cnn_feat, pose_feat], dim=-1) # (B, S, 320)
|
||||||
|
|
||||||
|
# GRU temporal modelling
|
||||||
|
gru_out, h_n = self.gru(fused) # gru_out: (B, S, 128), h_n: (1, B, 128)
|
||||||
|
|
||||||
|
# Use last hidden state
|
||||||
|
last_hidden = h_n[-1] # (B, 128)
|
||||||
|
|
||||||
|
# Head regression
|
||||||
|
v_body = self.head(last_hidden) # (B, 2)
|
||||||
|
return v_body
|
||||||
|
|
||||||
|
|
||||||
|
def count_parameters(model: nn.Module) -> int:
|
||||||
|
"""Count trainable parameters."""
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    # Quick sanity check
    model = VelocityPredictionModel()
    total = count_parameters(model)
    print(f"Total trainable parameters: {total:,} ({total/1e6:.3f} M)")

    # Forward pass test
    B, S, H, W = 4, 8, 240, 320
    events = torch.randn(B, S, 1, H, W)
    tilt = torch.randn(B, S, 3)
    out = model(events, tilt)
    print(f"Input events: {events.shape}")
    print(f"Input tilt: {tilt.shape}")
    print(f"Output: {out.shape} (should be [4, 2])")
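
As a cross-check on `count_parameters`, the GRU and head sizes can be worked out by hand. The sketch below assumes the sizes quoted in the comments of `model.py` (fused input 320, GRU hidden size 128, one layer, 2 outputs). The head width of 64 is a hypothetical value, since `head_cfg.hidden_dim` is not shown in this part of the commit. It relies on the PyTorch GRU layout, where each layer stores three gates, each with an input weight, a hidden weight, and two bias vectors.

```python
def gru_param_count(input_size: int, hidden_size: int, num_layers: int = 1) -> int:
    """Parameter count of a unidirectional GRU with biases enabled."""
    total = 0
    for layer in range(num_layers):
        layer_in = input_size if layer == 0 else hidden_size
        # 3 gates x (W_ih + W_hh + b_ih + b_hh)
        total += 3 * (layer_in * hidden_size + hidden_size * hidden_size + 2 * hidden_size)
    return total


def linear_param_count(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


HEAD_HIDDEN = 64  # hypothetical; the real value lives in head_cfg.hidden_dim

gru_params = gru_param_count(320, 128)
head_params = linear_param_count(128, HEAD_HIDDEN) + linear_param_count(HEAD_HIDDEN, 2)
print(f"GRU: {gru_params:,}  head: {head_params:,}")
```

The CNN encoder and pose MLP are not included here because their layer sizes are defined elsewhere.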
213
src/velocity_prediction/train.py
Normal file
@@ -0,0 +1,213 @@
"""
Training loop for VelocityPredictionModel.

Usage:
    python -m src.velocity_prediction.train [--device cuda:0]
"""

import argparse
import os
import time
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

from src.velocity_prediction.config import train_cfg, model_cfg
from src.velocity_prediction.model import VelocityPredictionModel, count_parameters
from src.velocity_prediction.dataset import create_train_loader, create_val_loader


def set_seed(seed: int):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def train_one_epoch(
    model: nn.Module,
    loader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epoch: int,
    writer: SummaryWriter,
    log_interval: int = 50,
) -> float:
    """Train for one epoch. Returns average loss."""
    model.train()
    total_loss = 0.0
    num_batches = 0
    start_time = time.time()

    for batch_idx, batch in enumerate(loader):
        events = batch["events"].to(device)  # (B, S, 1, H, W)
        tilt = batch["tilt"].to(device)  # (B, S, 3)
        target = batch["v_body_target"].to(device)  # (B, S, 2)

        # Predict velocity for the last frame in the sequence
        pred = model(events, tilt)  # (B, 2)
        target_last = target[:, -1, :]  # (B, 2)

        loss = criterion(pred, target_last)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if batch_idx % log_interval == 0:
            elapsed = time.time() - start_time
            print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s")
            # NOTE: batch_idx restarts at 0 every epoch, so this curve reuses
            # the same step values across epochs.
            writer.add_scalar("train/loss_batch", loss.item(), batch_idx)

    avg_loss = total_loss / max(num_batches, 1)
    return avg_loss
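
`train_one_epoch` passes `batch_idx` as the TensorBoard step, and that index starts again at 0 in every epoch. A minimal sketch of one way to derive a step that keeps increasing across epochs, assuming the number of batches per epoch is known in advance. `global_step` and `batches_per_epoch` are hypothetical names, not part of this commit.

```python
def global_step(epoch: int, batch_idx: int, batches_per_epoch: int) -> int:
    """Monotonic step index for 1-based epochs and 0-based batch indices."""
    if epoch < 1 or not 0 <= batch_idx < batches_per_epoch:
        raise ValueError("epoch must be >= 1 and batch_idx must lie within the epoch")
    return (epoch - 1) * batches_per_epoch + batch_idx


# Two epochs of three batches each map to six distinct steps.
steps = [global_step(epoch, batch_idx, 3) for epoch in (1, 2) for batch_idx in range(3)]
print(steps)  # [0, 1, 2, 3, 4, 5]
```

If the loader streams samples and its length is unknown, a plain counter carried from one epoch to the next serves the same purpose.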


@torch.no_grad()
def validate(
    model: nn.Module,
    loader,
    criterion: nn.Module,
    device: torch.device,
) -> float:
    """Validate. Returns average loss."""
    model.eval()
    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        events = batch["events"].to(device)
        tilt = batch["tilt"].to(device)
        target = batch["v_body_target"].to(device)

        pred = model(events, tilt)
        target_last = target[:, -1, :]

        loss = criterion(pred, target_last)
        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / max(num_batches, 1)
    return avg_loss


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda",
                        help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
    args = parser.parse_args()

    set_seed(train_cfg.seed)
    device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
    print(f"Device: {device}")

    # Create model
    model = VelocityPredictionModel()
    model.to(device)
    total_params = count_parameters(model)
    print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")

    # Data loaders
    train_loader = create_train_loader(
        seq_len=train_cfg.seq_len,
        batch_size=train_cfg.batch_size,
        num_workers=train_cfg.num_workers,
        event_threshold=train_cfg.event_threshold,
        event_use_log=train_cfg.event_use_log,
    )
    val_loader = create_val_loader(
        seq_len=train_cfg.seq_len,
        batch_size=train_cfg.batch_size,
        num_workers=train_cfg.num_workers,
        event_threshold=train_cfg.event_threshold,
        event_use_log=train_cfg.event_use_log,
    )

    # Optimizer & scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=train_cfg.lr,
        weight_decay=train_cfg.weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=train_cfg.lr_scheduler_step,
        gamma=train_cfg.lr_scheduler_gamma,
    )
    # criterion = nn.SmoothL1Loss()
    criterion = nn.MSELoss()

    # Logging
    log_dir = Path(train_cfg.log_dir)
    log_dir.mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir=str(log_dir))

    ckpt_dir = Path(train_cfg.checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    best_val_loss = float("inf")

    print(f"\nStarting training for {train_cfg.epochs} epochs...")
    print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}")
    print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}")
    print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n")

    for epoch in range(1, train_cfg.epochs + 1):
        epoch_start = time.time()

        train_loss = train_one_epoch(
            model, train_loader, optimizer, criterion, device, epoch, writer,
            log_interval=train_cfg.log_interval,
        )
        val_loss = validate(model, val_loader, criterion, device)
        scheduler.step()

        epoch_time = time.time() - epoch_start
        current_lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch:3d}/{train_cfg.epochs} | "
              f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | "
              f"LR: {current_lr:.2e} | Time: {epoch_time:.1f}s")

        writer.add_scalar("train/loss_epoch", train_loss, epoch)
        writer.add_scalar("val/loss", val_loss, epoch)
        writer.add_scalar("lr", current_lr, epoch)

        # Save checkpoint
        if epoch % train_cfg.save_interval == 0:
            ckpt_path = ckpt_dir / f"epoch_{epoch:03d}_val_{val_loss:.6f}.pt"
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "train_loss": train_loss,
                "val_loss": val_loss,
            }, ckpt_path)
            print(f" Checkpoint saved: {ckpt_path}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_path = ckpt_dir / "best.pt"
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_loss,
            }, best_path)
            print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")

    writer.close()
    print(f"\nTraining complete. Best val loss: {best_val_loss:.6f}")
    print(f"Best checkpoint: {ckpt_dir / 'best.pt'}")


if __name__ == "__main__":
    main()
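
In `main`, `scheduler.step()` runs once per epoch after validation, so the value read from `get_last_lr()` right afterwards is the rate the next epoch will use, not the one just applied. The sketch below reproduces the closed form of `StepLR` in plain Python to make the difference visible. The base rate, `gamma`, and `step_size` are hypothetical values, not the ones in `train_cfg`.

```python
def step_lr(base_lr: float, gamma: float, step_size: int, steps_taken: int) -> float:
    """Learning rate of StepLR after `steps_taken` calls to scheduler.step()."""
    return base_lr * gamma ** (steps_taken // step_size)


BASE_LR, GAMMA, STEP_SIZE = 1e-3, 0.5, 10  # hypothetical settings

for epoch in (1, 9, 10, 11, 20):
    used = step_lr(BASE_LR, GAMMA, STEP_SIZE, epoch - 1)  # rate applied during this epoch
    logged = step_lr(BASE_LR, GAMMA, STEP_SIZE, epoch)    # rate read after scheduler.step()
    print(f"epoch {epoch:2d}: used={used:.2e} logged={logged:.2e}")
```

With these settings the two columns differ at epochs 10 and 20, where the logged value has already dropped by `gamma`.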
142
src/velocity_prediction/transforms.py
Normal file
@@ -0,0 +1,142 @@
"""
Transforms: event frame generation + coordinate transforms for pose/velocity.

Each transform operates on a single decoded sample dict:
    {
        "jpg": bytes,   # JPEG-encoded grayscale image
        "ts": bytes,    # float64 timestamp
        "pose": bytes,  # float32[7] [x, y, z, qx, qy, qz, qw]
        "vel": bytes,   # float32[6] [vx, vy, vz, wx, wy, wz]
    }
"""

import io
import numpy as np
import cv2

from src.event_utils import EventProcessor
from src.velocity_prediction.utils import decompose_tilt_np, world_vel_to_body_np
from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD


class DecodeSample:
    """Decode raw bytes from WebDataset tar entry into numpy arrays."""

    def __call__(self, sample: dict) -> dict:
        # Image: JPEG bytes → grayscale uint8 (H, W)
        img = cv2.imdecode(np.frombuffer(sample["jpg"], np.uint8), cv2.IMREAD_GRAYSCALE)

        # Timestamp
        ts = np.frombuffer(sample["ts"], dtype=np.float64).item()

        # Pose: [x, y, z, qx, qy, qz, qw]
        pose = np.frombuffer(sample["pose"], dtype=np.float32).copy()

        # Velocity: [vx, vy, vz, wx, wy, wz]
        vel = np.frombuffer(sample["vel"], dtype=np.float32).copy()

        return {"img": img, "ts": ts, "pose": pose, "vel": vel}
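
`DecodeSample` reads the `ts`, `pose`, and `vel` entries as raw binary buffers. The sketch below shows the same byte layout using only the standard-library `struct` module. It assumes little-endian storage (`np.frombuffer` uses the platform's native byte order, which is little-endian on common x86 and ARM machines), and the sample values are made up for illustration.

```python
import struct

# Encode one pose (float32[7]) and one timestamp (float64) the way a shard entry stores them.
pose_values = (1.0, 2.0, 0.5, 0.0, 0.0, 0.0, 1.0)  # x, y, z, qx, qy, qz, qw
pose_bytes = struct.pack("<7f", *pose_values)
ts_bytes = struct.pack("<d", 1540000000.25)

# Decode them again and split the pose into position and quaternion.
pose = struct.unpack("<7f", pose_bytes)
(ts,) = struct.unpack("<d", ts_bytes)
position, quat_xyzw = pose[:3], pose[3:]

print(len(pose_bytes), position, quat_xyzw, ts)
```

The 28-byte length of the pose buffer (7 values of 4 bytes each) is a quick check that an entry has the expected layout.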


class SimulateEvents:
    """Convert grayscale frame to binary event frame using EventProcessor."""

    def __init__(self, threshold=0.1, use_log=True, auto_threshold=False, verbose=False):
        self.processor = EventProcessor(
            threshold=threshold,
            use_log=use_log,
            auto_threshold=auto_threshold,
        )
        self.verbose = verbose
        self._frame_count = 0

    def __call__(self, sample: dict) -> dict:
        img = sample["img"]
        events_binary, events_strength, event_count = self.processor(img)
        # Use binary events as network input: shape (1, H, W), values in {-1, 0, 1}
        sample["events"] = events_binary.astype(np.float32)[np.newaxis, ...]

        if self.verbose:
            self._frame_count += 1
            total_pixels = events_binary.shape[0] * events_binary.shape[1]
            nonzero_ratio = event_count / total_pixels
            pos_ratio = (events_binary > 0).sum() / total_pixels
            neg_ratio = (events_binary < 0).sum() / total_pixels
            if self._frame_count <= 5 or self._frame_count % 100 == 0:
                print(f" [EventDiagnostics] frame={self._frame_count} | "
                      f"nonzero={nonzero_ratio:.4f} (+{pos_ratio:.4f}/-{neg_ratio:.4f}) | "
                      f"count={event_count}")

        return sample

    def reset(self):
        self.processor.reset()
        self._frame_count = 0
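
The internals of `EventProcessor` live in `src/event_utils` and are not shown here. As an illustration only, the sketch below implements one common frame-difference model: a pixel fires a positive or negative event when its (optionally log-scaled) intensity changes by at least `threshold` between two frames. Whether `EventProcessor` works exactly this way is an assumption; `simulate_events` is a hypothetical helper.

```python
import math


def simulate_events(prev, curr, threshold=0.1, use_log=True):
    """Return a polarity map with values in {-1, 0, 1} for two grayscale frames."""
    def intensity(value):
        # log1p keeps zero-valued pixels finite
        return math.log1p(value) if use_log else float(value)

    events = []
    for prev_row, curr_row in zip(prev, curr):
        row = []
        for p, c in zip(prev_row, curr_row):
            diff = intensity(c) - intensity(p)
            if diff >= threshold:
                row.append(1)
            elif diff <= -threshold:
                row.append(-1)
            else:
                row.append(0)
        events.append(row)
    return events


prev_frame = [[100, 100], [100, 100]]
curr_frame = [[100, 150], [60, 101]]
print(simulate_events(prev_frame, curr_frame))  # [[0, 1], [-1, 0]]
```

The small change from 100 to 101 stays below the threshold in log space, which is why that pixel produces no event.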


class ComputeTilt:
    """Extract tilt rotation vector from pose quaternion (discard position, discard yaw)."""

    def __call__(self, sample: dict) -> dict:
        q = sample["pose"][3:7]  # [qx, qy, qz, qw]
        tilt = decompose_tilt_np(q)  # (3,) rotation vector
        sample["tilt"] = tilt.astype(np.float32)
        return sample


class ComputeBodyVelocity:
    """Transform world-frame velocity to body-frame (yaw-compensated)."""

    def __call__(self, sample: dict) -> dict:
        v_world = sample["vel"][:3]  # [vx, vy, vz] world frame
        q = sample["pose"][3:7]  # [qx, qy, qz, qw]
        v_body = world_vel_to_body_np(v_world, q)  # (3,)
        # Only predict forward (x) and lateral (y) body velocity
        sample["v_body_target"] = v_body[:2].astype(np.float32)  # (2,)
        return sample


class NormalizeVelocity:
    """Normalize body-frame velocity to zero mean, unit variance."""

    def __init__(self):
        self.mean = np.array(VELOCITY_MEAN, dtype=np.float32)
        self.std = np.array(VELOCITY_STD, dtype=np.float32)

    def __call__(self, sample: dict) -> dict:
        sample["v_body_target"] = (sample["v_body_target"] - self.mean) / self.std
        return sample
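
Because the training target is normalized, the network output is in normalized units and has to be mapped back before it can be read as a velocity. The sketch below shows the forward and inverse mapping in plain Python. The statistics are placeholders rather than the real `VELOCITY_MEAN` and `VELOCITY_STD`, and `denormalize` is a hypothetical helper that does not appear in this commit.

```python
def normalize(v, mean, std):
    """Per-component (v - mean) / std, as in NormalizeVelocity."""
    return [(x - m) / s for x, m, s in zip(v, mean, std)]


def denormalize(v_norm, mean, std):
    """Inverse mapping: v_norm * std + mean."""
    return [x * s + m for x, m, s in zip(v_norm, mean, std)]


MEAN = [0.5, -0.25]  # placeholder statistics for [vx, vy]
STD = [2.0, 4.0]

v_body = [2.5, 1.75]
v_norm = normalize(v_body, MEAN, STD)
restored = denormalize(v_norm, MEAN, STD)
print(v_norm, restored)  # [1.0, 0.5] [2.5, 1.75]
```

The same statistics must be used at training and inference time, otherwise the restored velocities are scaled or shifted.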


class Compose:
    """Chain multiple transforms."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample: dict) -> dict:
        for t in self.transforms:
            sample = t(sample)
        return sample


def build_train_transform(event_threshold=0.1, event_use_log=True):
    """Build the full transform pipeline for training samples."""
    return Compose([
        DecodeSample(),
        SimulateEvents(threshold=event_threshold, use_log=event_use_log),
        ComputeTilt(),
        ComputeBodyVelocity(),
        NormalizeVelocity(),
    ])


def build_val_transform(event_threshold=0.1, event_use_log=True):
    """Same pipeline as training.

    Each call builds its own SimulateEvents (and EventProcessor), so event
    state is not shared with the training transform.
    """
    return Compose([
        DecodeSample(),
        SimulateEvents(threshold=event_threshold, use_log=event_use_log),
        ComputeTilt(),
        ComputeBodyVelocity(),
        NormalizeVelocity(),
    ])
165
src/velocity_prediction/utils.py
Normal file
@@ -0,0 +1,165 @@
"""
Quaternion and coordinate-frame utility functions.

All quaternions follow [x, y, z, w] convention (matching dataset pose field).
"""

import torch
import numpy as np


# ──────────────────────────── Quaternion operations ────────────────────────────

def quat_normalize(q: torch.Tensor) -> torch.Tensor:
    """Normalize quaternion. q: (..., 4)"""
    return q / torch.norm(q, dim=-1, keepdim=True).clamp(min=1e-12)


def quat_conjugate(q: torch.Tensor) -> torch.Tensor:
    """Conjugate (inverse for unit quaternion). q: (..., 4)"""
    return q * torch.tensor([-1, -1, -1, 1], device=q.device)


def quat_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hamilton product. a, b: (..., 4)"""
    x1, y1, z1, w1 = a.unbind(-1)
    x2, y2, z2, w2 = b.unbind(-1)
    return torch.stack([
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    ], dim=-1)


def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Rotate vector v by quaternion q. q: (..., 4), v: (..., 3)"""
    q_conj = quat_conjugate(q)
    v_pad = torch.zeros_like(v[..., :1])  # (..., 1)
    v_q = torch.cat([v, v_pad], dim=-1)  # (..., 4) pure quaternion
    return quat_mul(quat_mul(q, v_q), q_conj)[..., :3]


# ──────────────────────────── Yaw decomposition ────────────────────────────

def quat_to_yaw(q: torch.Tensor) -> torch.Tensor:
    """
    Extract yaw (heading) angle from a quaternion in gravity-aligned frame.

    Gravity axis is +z. Yaw is the rotation around z-axis.
    Returns angle in radians, shape (...,).
    """
    x, y, z, w = q.unbind(-1)
    # From quaternion to Euler: yaw = atan2(2(w*z + x*y), 1 - 2(y² + z²))
    siny = 2.0 * (w * z + x * y)
    cosy = 1.0 - 2.0 * (y * y + z * z)
    return torch.atan2(siny, cosy)


def quat_from_yaw(yaw: torch.Tensor) -> torch.Tensor:
    """
    Build a pure-yaw quaternion (rotation around +z).
    yaw: (...,) → quat: (..., 4)
    """
    half = yaw * 0.5
    cos = torch.cos(half)
    sin = torch.sin(half)
    z = torch.zeros_like(cos)
    return torch.stack([z, z, sin, cos], dim=-1)  # [0, 0, sin(yaw/2), cos(yaw/2)]


def decompose_tilt(q: torch.Tensor) -> torch.Tensor:
    """
    Remove yaw from a quaternion, returning the residual tilt rotation vector.

    Given q = q_yaw * q_tilt (z-yaw first, then body tilt),
    we compute q_tilt = q_yaw^{-1} * q, then convert to rotation vector.

    Args:
        q: (..., 4) unit quaternion in world→body convention.

    Returns:
        tilt_angles: (..., 3) rotation vector [rx, ry, rz] representing
            the body's deviation from the heading direction.
    """
    yaw = quat_to_yaw(q)
    q_yaw = quat_from_yaw(yaw)
    q_yaw_inv = quat_conjugate(q_yaw)
    q_tilt = quat_mul(q_yaw_inv, q)  # remove yaw
    q_tilt = quat_normalize(q_tilt)
    return quat_to_rotvec(q_tilt)
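
The yaw/tilt split in `decompose_tilt` can be checked on a hand-built orientation. The pure-Python sketch below mirrors the formulas above with tuples instead of tensors (the helper names are reused only for readability). It composes a 60 degree heading with a 20 degree roll, then separates them again. The chosen angles are arbitrary test values.

```python
import math


def quat_mul(a, b):
    """Hamilton product for quaternions stored as (x, y, z, w)."""
    x1, y1, z1, w1 = a
    x2, y2, z2, w2 = b
    return (
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    )


def quat_conjugate(q):
    x, y, z, w = q
    return (-x, -y, -z, w)


def quat_to_yaw(q):
    x, y, z, w = q
    return math.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))


def quat_from_yaw(yaw):
    return (0.0, 0.0, math.sin(yaw / 2.0), math.cos(yaw / 2.0))


# Compose a 60 degree heading with a 20 degree roll about the body x-axis.
yaw_in = math.radians(60.0)
roll_in = math.radians(20.0)
q_roll = (math.sin(roll_in / 2.0), 0.0, 0.0, math.cos(roll_in / 2.0))
q = quat_mul(quat_from_yaw(yaw_in), q_roll)

# Split it again: q_tilt = q_yaw^{-1} * q
yaw_out = quat_to_yaw(q)
q_tilt = quat_mul(quat_conjugate(quat_from_yaw(yaw_out)), q)
tilt_angle = 2.0 * math.acos(max(-1.0, min(1.0, q_tilt[3])))

print(math.degrees(yaw_out), math.degrees(tilt_angle))
```

The recovered tilt quaternion matches the roll that was put in, and its own yaw is zero, which is the property the network input relies on.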


def quat_to_rotvec(q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """
    Convert unit quaternion to rotation vector (axis * angle).
    q: (..., 4) → rotvec: (..., 3)
    """
    q = quat_normalize(q)
    x, y, z, w = q.unbind(-1)
    angle = 2.0 * torch.acos(w.clamp(-1.0, 1.0))
    sin_half = torch.sqrt((1.0 - w * w).clamp(min=eps))
    scale = angle / sin_half
    # Avoid division by zero for small angles
    mask = sin_half > eps
    rx = torch.where(mask, x * scale, torch.zeros_like(x))
    ry = torch.where(mask, y * scale, torch.zeros_like(y))
    rz = torch.where(mask, z * scale, torch.zeros_like(z))
    return torch.stack([rx, ry, rz], dim=-1)


# ──────────────────────────── Velocity transformation ────────────────────────────

def world_vel_to_body(
    v_world: torch.Tensor,
    q_world_to_body: torch.Tensor,
) -> torch.Tensor:
    """
    Transform world-frame velocity to body-frame velocity.

    Steps:
        1. Extract yaw from q_world_to_body.
        2. Build pure-yaw quaternion q_yaw.
        3. Remove yaw from velocity: v_yaw_compensated = q_yaw^{-1} * v_world
        4. Rotate to body frame: v_body = q_tilt^{-1} * v_yaw_compensated
           where q_tilt = q_yaw^{-1} * q_world_to_body

    Args:
        v_world: (..., 3) world-frame linear velocity [vx, vy, vz]
        q_world_to_body: (..., 4) world→body unit quaternion

    Returns:
        v_body: (..., 3) body-frame linear velocity
    """
    # Steps 1-2: extract yaw and build the pure-yaw quaternion
    yaw = quat_to_yaw(q_world_to_body)
    q_yaw = quat_from_yaw(yaw)
    q_yaw_inv = quat_conjugate(q_yaw)

    # Step 3: remove yaw from velocity (rotate to yaw-aligned intermediate frame)
    v_yaw_comp = quat_rotate(q_yaw_inv, v_world)

    # Step 4: compute the tilt quaternion and rotate to body frame
    q_tilt = quat_mul(q_yaw_inv, q_world_to_body)
    q_tilt = quat_normalize(q_tilt)
    q_tilt_inv = quat_conjugate(q_tilt)
    v_body = quat_rotate(q_tilt_inv, v_yaw_comp)
    return v_body
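
The two stages in `world_vel_to_body` (remove yaw, then remove tilt) compose to a single rotation by the conjugate of the full quaternion, since `q_tilt^{-1} * q_yaw^{-1} = (q_yaw * q_tilt)^{-1} = q^{-1}`. The pure-Python sketch below uses that shortcut on a simple case: a vehicle yawed 90 degrees that moves along world +y is moving along its own +x axis. `world_vel_to_body_direct` is a hypothetical name, and the example values are arbitrary.

```python
import math


def quat_mul(a, b):
    """Hamilton product for quaternions stored as (x, y, z, w)."""
    x1, y1, z1, w1 = a
    x2, y2, z2, w2 = b
    return (
        w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
        w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
        w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
        w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
    )


def quat_conjugate(q):
    x, y, z, w = q
    return (-x, -y, -z, w)


def quat_rotate(q, v):
    """Rotate 3-vector v by unit quaternion q: q * (v, 0) * conj(q)."""
    v_q = (v[0], v[1], v[2], 0.0)
    return quat_mul(quat_mul(q, v_q), quat_conjugate(q))[:3]


def world_vel_to_body_direct(v_world, q):
    """Single-rotation equivalent of the two-stage transform."""
    return quat_rotate(quat_conjugate(q), v_world)


yaw = math.radians(90.0)
q = (0.0, 0.0, math.sin(yaw / 2.0), math.cos(yaw / 2.0))
v_world = (0.0, 1.0, 0.0)

v_body = world_vel_to_body_direct(v_world, q)
print([round(c, 6) for c in v_body])
```

Keeping the two-stage form in the library code still has a use: the intermediate yaw-compensated velocity and the tilt quaternion are available separately if they are needed.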


# ──────────────────────────── NumPy wrappers (for transforms.py) ────────────────────────────

def decompose_tilt_np(q: np.ndarray) -> np.ndarray:
    """NumPy version of decompose_tilt."""
    q_t = torch.from_numpy(q)
    tilt = decompose_tilt(q_t)
    return tilt.numpy()


def world_vel_to_body_np(v_world: np.ndarray, q: np.ndarray) -> np.ndarray:
    """NumPy version of world_vel_to_body."""
    v_t = torch.from_numpy(v_world)
    q_t = torch.from_numpy(q)
    vb = world_vel_to_body(v_t, q_t)
    return vb.numpy()
143
start_ros_container.sh
Normal file
@@ -0,0 +1,143 @@
#!/bin/bash

# ROS Podman container launch script
# Author: auto-generated
# Purpose: start the ROS container, mount the current directory, set environment
# variables, install pip3, and keep going when a command fails

set +e  # do not exit on errors

# Color definitions (for nicer output)
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Configuration
CONTAINER_IMAGE="69a38b2c0905"  # ROS Noetic Desktop Full image ID
WORKSPACE_DIR="$(pwd)"  # current working directory
MOUNT_POINT="/mnt"  # mount point inside the container

# ROS environment variables
ROS_ENV_VARS=(
    "ROS_DISTRO=noetic"
    "ROS_PYTHON_VERSION=3"
    "ROS_VERSION=1"
    "ROS_ETC_DIR=/opt/ros/noetic/etc/ros"
    "ROS_ROOT=/opt/ros/noetic/share/ros"
    "ROS_PACKAGE_PATH=/opt/ros/noetic/share"
    "PYTHONIOENCODING=utf-8"
    "TZ=Asia/Shanghai"
)

echo -e "${GREEN}========================================${NC}"
echo -e "${GREEN}ROS Podman container launch script${NC}"
echo -e "${GREEN}========================================${NC}"

# Check that Podman is installed
if ! command -v podman &> /dev/null; then
    echo -e "${RED}Error: Podman is not installed; please install Podman first${NC}"
    echo -e "${YELLOW}Ubuntu/Debian: sudo apt-get install podman${NC}"
    echo -e "${YELLOW}CentOS/RHEL: sudo yum install podman${NC}"
    exit 1
fi

# Check that the image exists
echo -e "${YELLOW}Checking container image...${NC}"
if ! podman image exists $CONTAINER_IMAGE; then
    echo -e "${RED}Warning: image $CONTAINER_IMAGE does not exist${NC}"
    echo -e "${YELLOW}Available ROS images:${NC}"
    podman image list | grep ros
    echo -e "${YELLOW}Continuing with the specified image; if startup fails, change the image ID${NC}"
fi

# Build the environment variable arguments
ENV_ARGS=""
for env_var in "${ROS_ENV_VARS[@]}"; do
    ENV_ARGS="$ENV_ARGS -e $env_var"
done

# Add extra environment variables (if needed)
ENV_ARGS="$ENV_ARGS -e DISPLAY=$DISPLAY"  # GUI application support
ENV_ARGS="$ENV_ARGS -e QT_X11_NO_MITSHM=1"

# X11 support (for GUI applications)
if [ -n "$DISPLAY" ]; then
    echo -e "${YELLOW}Enabling X11 support for GUI applications...${NC}"
    xhost +local:root 2>/dev/null
    ENV_ARGS="$ENV_ARGS --volume /tmp/.X11-unix:/tmp/.X11-unix:rw"
fi

echo -e "${GREEN}Configuration:${NC}"
echo -e " Working directory: $WORKSPACE_DIR"
echo -e " Mount point: $MOUNT_POINT"
echo -e " Image ID: $CONTAINER_IMAGE"
echo -e " ROS distribution: noetic"

echo -e "${YELLOW}Starting container...${NC}"

# Command that starts the container
# Use bash -c to run several commands, so that a failed pip3 install does not exit the container
PODMAN_CMD="podman run -it \
    --rm \
    --name ros_noetic_container_$(date +%s) \
    -v $WORKSPACE_DIR:$MOUNT_POINT:rw \
    $ENV_ARGS \
    $CONTAINER_IMAGE \
    /bin/bash -c \"\
    echo '========================================' && \
    echo 'ROS container started' && \
    echo '========================================' && \
    echo 'Working directory mounted at: $MOUNT_POINT' && \
    echo 'Current ROS version: ' && \
    echo \\\$ROS_DISTRO && \
    echo '' && \
    echo 'Checking for pip3 and installing it if needed...' && \
    if command -v pip3 &> /dev/null; then \
        echo 'pip3 is already installed, version: ' && \
        pip3 --version; \
    else \
        echo 'pip3 is not installed, installing...' && \
        apt-get update 2>/dev/null && \
        apt-get install -y python3-pip 2>/dev/null; \
        if [ \\\$? -eq 0 ]; then \
            echo 'pip3 installed successfully'; \
            pip3 --version; \
        else \
            echo 'Warning: pip3 installation failed, please install it manually'; \
        fi; \
    fi && \
    echo '' && \
    echo '========================================' && \
    echo 'Environment variables set:' && \
    env | grep ROS_ && \
    echo '========================================' && \
    echo 'Container is ready, entering interactive shell...' && \
    echo 'Tip: type exit to leave the container' && \
    echo '========================================' && \
    pip config set global.index-url https://mirrors.ustc.edu.cn/pypi/simple && \
    source /ros_entrypoint.sh && \
    cd $MOUNT_POINT && \
    exec /bin/bash -l\""

# Run the command
echo -e "${GREEN}Running launch command...${NC}"
# echo -e "${YELLOW}Command details:${NC}"
# echo "$PODMAN_CMD"
echo -e "${YELLOW}========================================${NC}"

# Run the container
eval $PODMAN_CMD

# Check the exit status
if [ $? -ne 0 ]; then
    echo -e "${RED}Container exited (non-zero exit code)${NC}"
else
    echo -e "${GREEN}Container exited normally${NC}"
fi

# Revoke the X11 permission
if [ -n "$DISPLAY" ]; then
    xhost -local:root 2>/dev/null
fi

echo -e "${GREEN}Script finished${NC}"
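
The script assembles the `podman run` invocation as one string and runs it through `eval`, which splits the string on whitespace again, so a workspace path containing a space would be broken into separate arguments. A minimal sketch of an alternative, written in Python to match the rest of the repository: keep the arguments in a list so each one stays intact. `build_podman_argv` and the example path are hypothetical. Only the `run`, `-it`, `--rm`, `-v`, and `-e` options already used by the script appear.

```python
import shlex


def build_podman_argv(image, workspace, mount_point, env):
    """Assemble the podman command as a list, one element per argument."""
    argv = ["podman", "run", "-it", "--rm", "-v", f"{workspace}:{mount_point}:rw"]
    for key, value in env.items():
        argv += ["-e", f"{key}={value}"]
    argv += [image, "/bin/bash", "-l"]
    return argv


argv = build_podman_argv(
    image="69a38b2c0905",
    workspace="/home/user/my project",  # path with a space, on purpose
    mount_point="/mnt",
    env={"ROS_DISTRO": "noetic", "ROS_VERSION": "1"},
)

# shlex.join renders a copy-pasteable command line with quoting where needed.
print(shlex.join(argv))
```

Passing such a list to `subprocess.run(argv)` would start the container without any intermediate shell parsing.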