- Remove sample-level shuffle before transforms (broke SimulateEvents) - Add _sliding_window_fn: yields overlapping sequences with configurable stride - Add sequence-level shuffle after grouping (preserves temporal coherence) - Add sliding_window_stride to TrainConfig (stride=1 for full overlap) - Update create_train/val_loader and train.py to pass stride - AGENTS.md: document known issues (cross-shard boundary, SimulateEvents state) - AGENTS.md: add cuda:7 device preference Generated by Mistral Vibe (deepseek-v4-flash). Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
8.2 KiB
8.2 KiB
UZH-FPV Velocity Prediction
从 DAVIS 事件相机灰度图像序列中预测机体速度(body-frame forward/lateral velocity)。
项目结构
uzh_fpv/
├── AGENTS.md # ← 本文件
├── requirements.txt # Python 依赖
├── DATASET_FORMAT.md # 数据集格式详细说明
├── rosbag2wds.py # ROS bag → WebDataset shard 转换脚本
├── batch_convert.sh # 批量转换脚本
├── dataset/ # 数据集(.gitignore 忽略)
│ └── <scene_name>/
│ ├── shard_0000.tar # WebDataset shard(图像+GT)
│ ├── imu_sequence.npz # 完整 IMU 序列
│ └── metadata.json # 元信息
├── src/
│ ├── event_utils.py # EventProcessor: 帧间亮度变化 → 模拟事件帧
│ └── velocity_prediction/ # 主项目代码
│ ├── __init__.py # 模块说明
│ ├── config.py # 路径、模型架构、训练超参数
│ ├── utils.py # 四元数运算(torch + numpy 封装)
│ ├── transforms.py # 数据预处理管线
│ ├── dataset.py # WebDataset 加载 + 序列采样
│ ├── model.py # CNN + PoseMLP + GRU + Head
│ ├── train.py # 训练循环
│ └── evaluate.py # 评估 + 绘图
├── visualize/
│ ├── __init__.py
│ └── visualize_dataset.py # 数据集可视化:叠加位姿信息并生成视频
├── benchmark/
│ ├── __init__.py
│ ├── config.py # 评估配置
│ ├── evaluate.py # 完整评估管线
│ └── benchmark.py # 统一评估入口
├── checkpoints/ # 模型权重(.gitignore 忽略)
├── logs/ # TensorBoard 日志(.gitignore 忽略)
└── videos/ # 可视化输出视频
运行环境
uv run python -m <module> # 使用 uv 虚拟环境运行
依赖见 requirements.txt,核心依赖:PyTorch、WebDataset、OpenCV、NumPy、Matplotlib。
数据集
UZH-FPV 数据集,由 DAVIS 事件相机采集。每个场景目录包含:
| 文件 | 格式 | 内容 |
|---|---|---|
shard_*.tar |
WebDataset | 灰度图 (320×240) + 位姿 + 速度 + 时间戳 |
imu_sequence.npz |
NPZ | 完整 IMU 序列(加速度+角速度) |
metadata.json |
JSON | 场景元信息 |
shard 中每个样本的字段:
| Key | 类型 | 说明 |
|---|---|---|
jpg |
JPEG bytes | 灰度图 320×240 |
ts |
float64 | 时间戳 |
pose |
float32[7] | [x, y, z, qx, qy, qz, qw] 世界→机体四元数 |
vel |
float32[6] | [vx, vy, vz, wx, wy, wz] 世界线速度 + 角速度 |
坐标系:z 轴与重力对齐(水平坐标系)。
场景列表
| 场景 | 帧数 | 类型 |
|---|---|---|
| indoor_forward_3/5/6/7/9/10 | 627~4918 | 室内前飞 |
| indoor_45_2/4/9/12/13/14 | 656~1472 | 室内 45° 飞行 |
| outdoor_forward_1/3/5 | 907~13299 | 室外前飞 |
| outdoor_45_1 | 799 | 室外 45° 飞行 |
模型
架构
Event frame (1, 240, 320) ──► CNN (4 Conv+Pool+GAP, 256-d)
│
Body up (3,) ──► PoseMLP (3→32→64, 64-d) ────────────────────────┤
│
concat (320-d) ← per-frame
│
GRU (hidden=128)
│
Head MLP (128→64→2)
│
[v_right, v_forward]
注意:当前 CNN 编码器被禁用(输出全零),模型仅依赖 PoseMLP + GRU + Head。
输入
events:(B, S, 1, H, W)— 模拟事件帧,值域{-1, 0, +1}tilt:(B, S, 3)— body up 向量(世界 up 旋转到机体坐标系),仅含 pitch/roll,不含 yaw,单位向量
输出
v_body:(B, 2)— 机体坐标系[v_right, v_forward]速度 (m/s)
数据预处理管线
shard_*.tar → DecodeSample → SimulateEvents → ComputeTilt → ComputeBodyVelocity → NormalizeVelocity
- DecodeSample: JPEG → 灰度图 uint8 (H,W);bytes → float32 数组
- SimulateEvents: 帧间亮度变化 → 二值事件帧
{-1, 0, +1} - ComputeTilt: 四元数 (world→odom) → 应用 R_odom_to_body → 旋转 world-up [0,0,1] → body up 向量 (3,)
- ComputeBodyVelocity: 世界速度 → 应用 R_odom_to_body → yaw 补偿(仅去除偏航,保留 tilt)→ 水平面
[v_right, v_forward] - NormalizeVelocity: 归一化
训练配置
- seq_len=8, batch_size=32, epochs=100
- lr=1e-3, AdamW, StepLR (step=30, gamma=0.5)
- Loss: MSELoss
- 训练/验证/测试场景见
config.py
关键命令
# 训练
uv run python -m src.velocity_prediction.train --device cuda:0
# 评估
uv run python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt
# 数据集可视化(单场景)
uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --output videos/scene.mp4
# 数据集可视化(全部场景)
uv run python -m visualize.visualize_dataset --all --output videos/
# 数据集可视化(实时显示)
uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --show
# Benchmark 评估
uv run python -m benchmark.benchmark --checkpoint checkpoints/best.pt
可视化说明
visualize/visualize_dataset.py 在每帧图像上叠加:
- 帧号、时间戳、世界坐标位置
- 欧拉角
[roll, pitch, yaw](从 body 四元数计算) - Body up 向量
[x, y, z] - 机体速度
v_body [forward, lateral] - 世界速度
v_world [vx, vy, vz] - 机体坐标系三轴箭头(左下角)
- 机体速度方向箭头(图像中心)
关键约定
- 四元数格式:
[x, y, z, w](不是[w, x, y, z]) - GT 四元数表示 world→odom(不是 world→body),通过静态 R_odom_to_body 校正
- Body 坐标系(ROS 右手系):
body_x=右, body_y=前, body_z=上 - R_odom_to_body = R_y(45°) @ R_x(90°):先绕 odom_x 转 +90°,再绕 odom_y 转 +45°
- 速度归一化统计量:待重新计算
- 模型预测
[v_right, v_forward](右向和前向速度) - 所有代码在项目根目录下以
uv run python -m <module>运行 - GPU 优先使用
cuda:7,训练时添加--device cuda:7
已知问题
1. 滑窗跨 shard 边界
dataset.py 中滑窗实现基于 WebDataset 串联后的连续流,不感知 shard 边界。当样本恰好处于 shard 末尾时,序列会跨越到下一个 shard 的起始帧。
- 影响:每个 shard 边界处约有
seq_len个序列包含跨 shard 样本(占总数 <1%) - 修复思路:在
_sliding_window_fn中注入 shard 边界标记,遇到边界时清空缓冲区 - 严重程度:低。若 shard 内帧数远大于 seq_len,可忽略
2. SimulateEvents 跨 shard 状态残留
EventProcessor 内部维护 _prev_frame 用于帧差计算。跨 shard 时,新 shard 的第一帧会与上一个 shard 最后一帧计算差,产生错误的事件帧。
- 影响:每个 shard 的第 1 帧事件帧错误,涉及该帧的所有滑窗序列均受影响
- 每 shard 错误帧数:1 帧(加上滑窗放大,约
seq_len个序列各包含此帧) - 修复:在 shard 边界处调用
EventProcessor.reset()。需在_build_pipeline中插入边界信号或改用按 shard 独立处理的方案 - 严重程度:低。每 shard 仅 1 帧,训练数据量大时可忽略