Files
uzh-fpv-sv-test/AGENTS.md
CaoWangrenbo cb9936542e feat: replace non-overlapping windows with sliding-window sequence sampling
- Remove sample-level shuffle before transforms (broke SimulateEvents)
- Add _sliding_window_fn: yields overlapping sequences with configurable stride
- Add sequence-level shuffle after grouping (preserves temporal coherence)
- Add sliding_window_stride to TrainConfig (stride=1 for full overlap)
- Update create_train/val_loader and train.py to pass stride
- AGENTS.md: document known issues (cross-shard boundary, SimulateEvents state)
- AGENTS.md: add cuda:7 device preference

Generated by Mistral Vibe (deepseek-v4-flash).
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
2026-06-05 16:45:24 +08:00

8.2 KiB
Raw Permalink Blame History

UZH-FPV Velocity Prediction

从 DAVIS 事件相机灰度图像序列中预测机体速度body-frame forward/lateral velocity

项目结构

uzh_fpv/
├── AGENTS.md                          # ← 本文件
├── requirements.txt                   # Python 依赖
├── DATASET_FORMAT.md                  # 数据集格式详细说明
├── rosbag2wds.py                      # ROS bag → WebDataset shard 转换脚本
├── batch_convert.sh                   # 批量转换脚本
├── dataset/                           # 数据集(.gitignore 忽略)
│   └── <scene_name>/
│       ├── shard_0000.tar             # WebDataset shard图像+GT
│       ├── imu_sequence.npz           # 完整 IMU 序列
│       └── metadata.json              # 元信息
├── src/
│   ├── event_utils.py                 # EventProcessor: 帧间亮度变化 → 模拟事件帧
│   └── velocity_prediction/           # 主项目代码
│       ├── __init__.py                # 模块说明
│       ├── config.py                  # 路径、模型架构、训练超参数
│       ├── utils.py                   # 四元数运算torch + numpy 封装)
│       ├── transforms.py              # 数据预处理管线
│       ├── dataset.py                 # WebDataset 加载 + 序列采样
│       ├── model.py                   # CNN + PoseMLP + GRU + Head
│       ├── train.py                   # 训练循环
│       └── evaluate.py                # 评估 + 绘图
├── visualize/
│   ├── __init__.py
│   └── visualize_dataset.py           # 数据集可视化:叠加位姿信息并生成视频
├── benchmark/
│   ├── __init__.py
│   ├── config.py                      # 评估配置
│   ├── evaluate.py                    # 完整评估管线
│   └── benchmark.py                   # 统一评估入口
├── checkpoints/                       # 模型权重(.gitignore 忽略)
├── logs/                              # TensorBoard 日志(.gitignore 忽略)
└── videos/                            # 可视化输出视频

运行环境

uv run python -m <module>    # 使用 uv 虚拟环境运行

依赖见 requirements.txt核心依赖PyTorch、WebDataset、OpenCV、NumPy、Matplotlib。

数据集

UZH-FPV 数据集,由 DAVIS 事件相机采集。每个场景目录包含:

文件 格式 内容
shard_*.tar WebDataset 灰度图 (320×240) + 位姿 + 速度 + 时间戳
imu_sequence.npz NPZ 完整 IMU 序列(加速度+角速度)
metadata.json JSON 场景元信息

shard 中每个样本的字段:

Key 类型 说明
jpg JPEG bytes 灰度图 320×240
ts float64 时间戳
pose float32[7] [x, y, z, qx, qy, qz, qw] 世界→机体四元数
vel float32[6] [vx, vy, vz, wx, wy, wz] 世界线速度 + 角速度

坐标系z 轴与重力对齐(水平坐标系)。

场景列表

场景 帧数 类型
indoor_forward_3/5/6/7/9/10 627~4918 室内前飞
indoor_45_2/4/9/12/13/14 656~1472 室内 45° 飞行
outdoor_forward_1/3/5 907~13299 室外前飞
outdoor_45_1 799 室外 45° 飞行

模型

架构

Event frame (1, 240, 320) ──► CNN (4 Conv+Pool+GAP, 256-d)
                                                                  │
Body up (3,) ──► PoseMLP (3→32→64, 64-d) ────────────────────────┤
                                                                  │
                                                          concat (320-d)  ← per-frame
                                                                  │
                                                          GRU (hidden=128)
                                                                  │
                                                          Head MLP (128→64→2)
                                                                  │
                                                          [v_right, v_forward]

注意:当前 CNN 编码器被禁用(输出全零),模型仅依赖 PoseMLP + GRU + Head

输入

  • events: (B, S, 1, H, W) — 模拟事件帧,值域 {-1, 0, +1}
  • tilt: (B, S, 3) — body up 向量(世界 up 旋转到机体坐标系),仅含 pitch/roll不含 yaw单位向量

输出

  • v_body: (B, 2) — 机体坐标系 [v_right, v_forward] 速度 (m/s)

数据预处理管线

shard_*.tar → DecodeSample → SimulateEvents → ComputeTilt → ComputeBodyVelocity → NormalizeVelocity
  1. DecodeSample: JPEG → 灰度图 uint8 (H,W)bytes → float32 数组
  2. SimulateEvents: 帧间亮度变化 → 二值事件帧 {-1, 0, +1}
  3. ComputeTilt: 四元数 (world→odom) → 应用 R_odom_to_body → 旋转 world-up [0,0,1] → body up 向量 (3,)
  4. ComputeBodyVelocity: 世界速度 → 应用 R_odom_to_body → yaw 补偿(仅去除偏航,保留 tilt→ 水平面 [v_right, v_forward]
  5. NormalizeVelocity: 归一化

训练配置

  • seq_len=8, batch_size=32, epochs=100
  • lr=1e-3, AdamW, StepLR (step=30, gamma=0.5)
  • Loss: MSELoss
  • 训练/验证/测试场景见 config.py

关键命令

# 训练
uv run python -m src.velocity_prediction.train --device cuda:0

# 评估
uv run python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt

# 数据集可视化(单场景)
uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --output videos/scene.mp4

# 数据集可视化(全部场景)
uv run python -m visualize.visualize_dataset --all --output videos/

# 数据集可视化(实时显示)
uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --show

# Benchmark 评估
uv run python -m benchmark.benchmark --checkpoint checkpoints/best.pt

可视化说明

visualize/visualize_dataset.py 在每帧图像上叠加:

  • 帧号、时间戳、世界坐标位置
  • 欧拉角 [roll, pitch, yaw](从 body 四元数计算)
  • Body up 向量 [x, y, z]
  • 机体速度 v_body [forward, lateral]
  • 世界速度 v_world [vx, vy, vz]
  • 机体坐标系三轴箭头(左下角)
  • 机体速度方向箭头(图像中心)

关键约定

  • 四元数格式:[x, y, z, w](不是 [w, x, y, z]
  • GT 四元数表示 world→odom(不是 world→body通过静态 R_odom_to_body 校正
  • Body 坐标系ROS 右手系):body_x=右, body_y=前, body_z=上
  • R_odom_to_body = R_y(45°) @ R_x(90°):先绕 odom_x 转 +90°再绕 odom_y 转 +45°
  • 速度归一化统计量:待重新计算
  • 模型预测 [v_right, v_forward](右向和前向速度)
  • 所有代码在项目根目录下以 uv run python -m <module> 运行
  • GPU 优先使用 cuda:7,训练时添加 --device cuda:7

已知问题

1. 滑窗跨 shard 边界

dataset.py 中滑窗实现基于 WebDataset 串联后的连续流,不感知 shard 边界。当样本恰好处于 shard 末尾时,序列会跨越到下一个 shard 的起始帧。

  • 影响:每个 shard 边界处约有 seq_len 个序列包含跨 shard 样本(占总数 <1%
  • 修复思路:在 _sliding_window_fn 中注入 shard 边界标记,遇到边界时清空缓冲区
  • 严重程度:低。若 shard 内帧数远大于 seq_len可忽略

2. SimulateEvents 跨 shard 状态残留

EventProcessor 内部维护 _prev_frame 用于帧差计算。跨 shard 时,新 shard 的第一帧会与上一个 shard 最后一帧计算差,产生错误的事件帧。

  • 影响:每个 shard 的第 1 帧事件帧错误,涉及该帧的所有滑窗序列均受影响
  • 每 shard 错误帧数1 帧(加上滑窗放大,约 seq_len 个序列各包含此帧)
  • 修复:在 shard 边界处调用 EventProcessor.reset()。需在 _build_pipeline 中插入边界信号或改用按 shard 独立处理的方案
  • 严重程度:低。每 shard 仅 1 帧,训练数据量大时可忽略