refactor: replace rotation vector with body up vector for tilt input
- Replace body_attitude() with body_up_vector(): rotate world-up [0,0,1] by corrected world→body quaternion to get body up vector (pitch/roll only, no yaw). Matches DiffPhysDrone's env.R[:, 2] approach. - Update ComputeTilt transform to use body_up_vector_np - Update visualize_dataset.py to display Euler angles and body up vector - Update model.py comments and disable CNN (zero output) - Sync AGENTS.md with new architecture description Generated by Mistral Vibe (ds-v4-flash). Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
171
AGENTS.md
Normal file
171
AGENTS.md
Normal file
@@ -0,0 +1,171 @@
|
||||
# UZH-FPV Velocity Prediction
|
||||
|
||||
从 DAVIS 事件相机灰度图像序列中预测**机体速度**(body-frame forward/lateral velocity)。
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
uzh_fpv/
|
||||
├── AGENTS.md # ← 本文件
|
||||
├── requirements.txt # Python 依赖
|
||||
├── DATASET_FORMAT.md # 数据集格式详细说明
|
||||
├── rosbag2wds.py # ROS bag → WebDataset shard 转换脚本
|
||||
├── batch_convert.sh # 批量转换脚本
|
||||
├── dataset/ # 数据集(.gitignore 忽略)
|
||||
│ └── <scene_name>/
|
||||
│ ├── shard_0000.tar # WebDataset shard(图像+GT)
|
||||
│ ├── imu_sequence.npz # 完整 IMU 序列
|
||||
│ └── metadata.json # 元信息
|
||||
├── src/
|
||||
│ ├── event_utils.py # EventProcessor: 帧间亮度变化 → 模拟事件帧
|
||||
│ └── velocity_prediction/ # 主项目代码
|
||||
│ ├── __init__.py # 模块说明
|
||||
│ ├── config.py # 路径、模型架构、训练超参数
|
||||
│ ├── utils.py # 四元数运算(torch + numpy 封装)
|
||||
│ ├── transforms.py # 数据预处理管线
|
||||
│ ├── dataset.py # WebDataset 加载 + 序列采样
|
||||
│ ├── model.py # CNN + PoseMLP + GRU + Head
|
||||
│ ├── train.py # 训练循环
|
||||
│ └── evaluate.py # 评估 + 绘图
|
||||
├── visualize/
|
||||
│ ├── __init__.py
|
||||
│ └── visualize_dataset.py # 数据集可视化:叠加位姿信息并生成视频
|
||||
├── benchmark/
|
||||
│ ├── __init__.py
|
||||
│ ├── config.py # 评估配置
|
||||
│ ├── evaluate.py # 完整评估管线
|
||||
│ └── benchmark.py # 统一评估入口
|
||||
├── checkpoints/ # 模型权重(.gitignore 忽略)
|
||||
├── logs/ # TensorBoard 日志(.gitignore 忽略)
|
||||
└── videos/ # 可视化输出视频
|
||||
```
|
||||
|
||||
## 运行环境
|
||||
|
||||
```bash
|
||||
uv run python -m <module> # 使用 uv 虚拟环境运行
|
||||
```
|
||||
|
||||
依赖见 `requirements.txt`,核心依赖:PyTorch、WebDataset、OpenCV、NumPy、Matplotlib。
|
||||
|
||||
## 数据集
|
||||
|
||||
UZH-FPV 数据集,由 DAVIS 事件相机采集。每个场景目录包含:
|
||||
|
||||
| 文件 | 格式 | 内容 |
|
||||
|------|------|------|
|
||||
| `shard_*.tar` | WebDataset | 灰度图 (320×240) + 位姿 + 速度 + 时间戳 |
|
||||
| `imu_sequence.npz` | NPZ | 完整 IMU 序列(加速度+角速度) |
|
||||
| `metadata.json` | JSON | 场景元信息 |
|
||||
|
||||
shard 中每个样本的字段:
|
||||
|
||||
| Key | 类型 | 说明 |
|
||||
|-----|------|------|
|
||||
| `jpg` | JPEG bytes | 灰度图 320×240 |
|
||||
| `ts` | float64 | 时间戳 |
|
||||
| `pose` | float32[7] | `[x, y, z, qx, qy, qz, qw]` 世界→机体四元数 |
|
||||
| `vel` | float32[6] | `[vx, vy, vz, wx, wy, wz]` 世界线速度 + 角速度 |
|
||||
|
||||
坐标系:z 轴与重力对齐(水平坐标系)。
|
||||
|
||||
### 场景列表
|
||||
|
||||
| 场景 | 帧数 | 类型 |
|
||||
|------|------|------|
|
||||
| indoor_forward_3/5/6/7/9/10 | 627~4918 | 室内前飞 |
|
||||
| indoor_45_2/4/9/12/13/14 | 656~1472 | 室内 45° 飞行 |
|
||||
| outdoor_forward_1/3/5 | 907~13299 | 室外前飞 |
|
||||
| outdoor_45_1 | 799 | 室外 45° 飞行 |
|
||||
|
||||
## 模型
|
||||
|
||||
### 架构
|
||||
|
||||
```
|
||||
Event frame (1, 240, 320) ──► CNN (4 Conv+Pool+GAP, 256-d)
|
||||
│
|
||||
Body up (3,) ──► PoseMLP (3→32→64, 64-d) ────────────────────────┤
|
||||
│
|
||||
concat (320-d) ← per-frame
|
||||
│
|
||||
GRU (hidden=128)
|
||||
│
|
||||
Head MLP (128→64→2)
|
||||
│
|
||||
[v_right, v_forward]
|
||||
```
|
||||
|
||||
**注意**:当前 CNN 编码器被禁用(输出全零),模型仅依赖 `PoseMLP + GRU + Head`。
|
||||
|
||||
### 输入
|
||||
|
||||
- `events`: `(B, S, 1, H, W)` — 模拟事件帧,值域 `{-1, 0, +1}`
|
||||
- `tilt`: `(B, S, 3)` — body up 向量(世界 up 旋转到机体坐标系),仅含 pitch/roll,不含 yaw,单位向量
|
||||
|
||||
### 输出
|
||||
|
||||
- `v_body`: `(B, 2)` — 机体坐标系 `[v_right, v_forward]` 速度 (m/s)
|
||||
|
||||
### 数据预处理管线
|
||||
|
||||
```
|
||||
shard_*.tar → DecodeSample → SimulateEvents → ComputeTilt → ComputeBodyVelocity → NormalizeVelocity
|
||||
```
|
||||
|
||||
1. **DecodeSample**: JPEG → 灰度图 uint8 (H,W);bytes → float32 数组
|
||||
2. **SimulateEvents**: 帧间亮度变化 → 二值事件帧 `{-1, 0, +1}`
|
||||
3. **ComputeTilt**: 四元数 (world→odom) → 应用 R_odom_to_body → 旋转 world-up [0,0,1] → body up 向量 (3,)
|
||||
4. **ComputeBodyVelocity**: 世界速度 → 应用 R_odom_to_body → yaw 补偿(仅去除偏航,保留 tilt)→ 水平面 `[v_right, v_forward]`
|
||||
5. **NormalizeVelocity**: 归一化
|
||||
|
||||
### 训练配置
|
||||
|
||||
- seq_len=8, batch_size=32, epochs=100
|
||||
- lr=1e-3, AdamW, StepLR (step=30, gamma=0.5)
|
||||
- Loss: MSELoss
|
||||
- 训练/验证/测试场景见 `config.py`
|
||||
|
||||
## 关键命令
|
||||
|
||||
```bash
|
||||
# 训练
|
||||
uv run python -m src.velocity_prediction.train --device cuda:0
|
||||
|
||||
# 评估
|
||||
uv run python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt
|
||||
|
||||
# 数据集可视化(单场景)
|
||||
uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --output videos/scene.mp4
|
||||
|
||||
# 数据集可视化(全部场景)
|
||||
uv run python -m visualize.visualize_dataset --all --output videos/
|
||||
|
||||
# 数据集可视化(实时显示)
|
||||
uv run python -m visualize.visualize_dataset --scene indoor_forward_3 --show
|
||||
|
||||
# Benchmark 评估
|
||||
uv run python -m benchmark.benchmark --checkpoint checkpoints/best.pt
|
||||
```
|
||||
|
||||
## 可视化说明
|
||||
|
||||
`visualize/visualize_dataset.py` 在每帧图像上叠加:
|
||||
|
||||
- 帧号、时间戳、世界坐标位置
|
||||
- 欧拉角 `[roll, pitch, yaw]`(从 body 四元数计算)
|
||||
- Body up 向量 `[x, y, z]`
|
||||
- 机体速度 `v_body [forward, lateral]`
|
||||
- 世界速度 `v_world [vx, vy, vz]`
|
||||
- 机体坐标系三轴箭头(左下角)
|
||||
- 机体速度方向箭头(图像中心)
|
||||
|
||||
## 关键约定
|
||||
|
||||
- 四元数格式:`[x, y, z, w]`(不是 `[w, x, y, z]`)
|
||||
- GT 四元数表示 **world→odom**(不是 world→body),通过静态 R_odom_to_body 校正
|
||||
- Body 坐标系(ROS 右手系):`body_x=右, body_y=前, body_z=上`
|
||||
- R_odom_to_body = R_y(45°) @ R_x(90°):先绕 odom_x 转 +90°,再绕 odom_y 转 +45°
|
||||
- 速度归一化统计量:待重新计算
|
||||
- 模型预测 `[v_right, v_forward]`(右向和前向速度)
|
||||
- 所有代码在项目根目录下以 `uv run python -m <module>` 运行
|
||||
Reference in New Issue
Block a user