initial commit
This commit is contained in:
103
src/event_utils.py
Normal file
103
src/event_utils.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Event camera simulation utilities for ML preprocessing.
|
||||
|
||||
Core logic extracted from EventCameraSimulator (test.py).
|
||||
Designed for frame-by-frame preprocessing in training pipelines.
|
||||
|
||||
Output:
|
||||
events_binary: (-1, 0, +1) hard threshold decision
|
||||
events_strength: [-1, 1] continuous change intensity (clipped & normalized)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
|
||||
|
||||
class EventProcessor:
|
||||
"""Lightweight event computation module. No visualization, no OpenCV dependency."""
|
||||
|
||||
def __init__(self, threshold=0.1, use_log=True, auto_threshold=False):
|
||||
self.threshold = threshold
|
||||
self.use_log = use_log
|
||||
self.auto_threshold = auto_threshold
|
||||
|
||||
self.prev_brightness = None
|
||||
self.change_history = deque(maxlen=100)
|
||||
self.threshold_scale = 1.5
|
||||
|
||||
def reset(self):
|
||||
"""Clear temporal state (call on video/reset)."""
|
||||
self.prev_brightness = None
|
||||
self.change_history.clear()
|
||||
|
||||
def _to_grayscale(self, frame):
|
||||
"""Convert frame to grayscale float32."""
|
||||
if frame.ndim == 3:
|
||||
# RGB/HWC -> gray via luminance weights
|
||||
gray = 0.299 * frame[..., 0] + 0.587 * frame[..., 1] + 0.114 * frame[..., 2]
|
||||
else:
|
||||
gray = frame
|
||||
return gray.astype(np.float32)
|
||||
|
||||
def _compute_change(self, brightness):
|
||||
"""Compute log or linear brightness change."""
|
||||
if self.use_log:
|
||||
eps = 1e-3
|
||||
return np.log(brightness + eps) - np.log(self.prev_brightness + eps)
|
||||
else:
|
||||
return brightness - self.prev_brightness
|
||||
|
||||
def _update_auto_threshold(self, change):
|
||||
"""Adapt threshold based on global change statistics."""
|
||||
abs_change = np.abs(change)
|
||||
mean_change = np.mean(abs_change)
|
||||
self.change_history.append(mean_change)
|
||||
|
||||
if len(self.change_history) > 10:
|
||||
avg_change = np.mean(self.change_history)
|
||||
new_threshold = max(avg_change * self.threshold_scale, 0.01)
|
||||
self.threshold = self.threshold * 0.9 + new_threshold * 0.1
|
||||
|
||||
if self.use_log:
|
||||
self.threshold = np.clip(self.threshold, 0.01, 0.5)
|
||||
else:
|
||||
self.threshold = np.clip(self.threshold, 1, 50)
|
||||
|
||||
def __call__(self, frame):
|
||||
"""
|
||||
Process a single frame.
|
||||
|
||||
Args:
|
||||
frame: np.ndarray, shape (H, W) or (H, W, C), uint8 or float.
|
||||
|
||||
Returns:
|
||||
events_binary: np.ndarray (H, W), values in {-1, 0, +1}
|
||||
events_strength: np.ndarray (H, W), values in [-1, 1]
|
||||
event_count: int, number of non-zero events
|
||||
"""
|
||||
brightness = self._to_grayscale(frame)
|
||||
|
||||
# First frame — initialise, no events
|
||||
if self.prev_brightness is None:
|
||||
self.prev_brightness = brightness
|
||||
h, w = brightness.shape
|
||||
return np.zeros((h, w), dtype=np.int8), np.zeros((h, w), dtype=np.float32), 0
|
||||
|
||||
change = self._compute_change(brightness)
|
||||
|
||||
if self.auto_threshold:
|
||||
self._update_auto_threshold(change)
|
||||
|
||||
# Binary events
|
||||
events_binary = np.zeros_like(brightness, dtype=np.int8)
|
||||
events_binary[change > self.threshold] = 1
|
||||
events_binary[change < -self.threshold] = -1
|
||||
|
||||
# Continuous strength: clip to [-threshold, threshold] then normalise to [-1, 1]
|
||||
events_strength = np.clip(change, -self.threshold, self.threshold) / self.threshold
|
||||
|
||||
event_count = int(np.count_nonzero(events_binary))
|
||||
|
||||
self.prev_brightness = brightness
|
||||
|
||||
return events_binary, events_strength, event_count
|
||||
198
src/velocity_prediction/README.md
Normal file
198
src/velocity_prediction/README.md
Normal file
@@ -0,0 +1,198 @@
|
||||
# Velocity Prediction from Event Frames + Attitude
|
||||
|
||||
基于 UZH-FPV 数据集,通过模拟事件帧 + 姿态输入预测机体速度(机体系 vx, vy)。
|
||||
|
||||
---
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
uzh_fpv/
|
||||
├── dataset/ # UZH-FPV 数据集(WebDataset shards)
|
||||
│ ├── indoor_forward_3/
|
||||
│ ├── indoor_forward_5/
|
||||
│ ├── ...
|
||||
│ └── outdoor_45_1/
|
||||
├── src/
|
||||
│ ├── event_utils.py # 模拟事件帧生成(已有模块)
|
||||
│ └── velocity_prediction/ # 本工程
|
||||
│ ├── __init__.py
|
||||
│ ├── config.py # 全局配置
|
||||
│ ├── utils.py # 四元数运算、坐标变换
|
||||
│ ├── transforms.py # 数据预处理管线
|
||||
│ ├── dataset.py # WebDataset 加载 + 序列采样
|
||||
│ ├── model.py # 网络模型定义
|
||||
│ ├── train.py # 训练入口
|
||||
│ └── evaluate.py # 评估与可视化
|
||||
├── DATASET_FORMAT.md # 数据集格式说明
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 网络架构
|
||||
|
||||
```
|
||||
事件帧 (1, 240, 320) ──► CNN (4层 Conv+Pool+GAP, 256-d)
|
||||
│
|
||||
姿态 tilt_angles (3,) ──► PoseMLP (3→32→64, 64-d) ───┤
|
||||
│
|
||||
concat (320-d) ← 每帧融合
|
||||
│
|
||||
GRU (hidden=128)
|
||||
│
|
||||
Head MLP (128→64→2)
|
||||
│
|
||||
[vx_body, vy_body]
|
||||
```
|
||||
|
||||
### 各模块参数
|
||||
|
||||
| 模块 | 参数量 | 说明 |
|
||||
|------|--------|------|
|
||||
| CNN Encoder | 387,840 | 4 层 Conv2D(3×3) + BN + ReLU + MaxPool(2×2) + GAP,通道 1→32→64→128→256 |
|
||||
| PoseMLP | 2,240 | 3→32→64,两层全连接 |
|
||||
| GRU | 172,800 | 单层,input=320, hidden=128 |
|
||||
| Head MLP | 8,386 | 128→64→2 |
|
||||
| **总计** | **~571K** | FP32 约 2.3 MB |
|
||||
|
||||
---
|
||||
|
||||
## 数据预处理
|
||||
|
||||
### 输入变换管线(transforms.py)
|
||||
|
||||
每个 WebDataset 样本依次经过:
|
||||
|
||||
1. **DecodeSample** — JPEG 解码为灰度图 (H, W),pose/vel 字节转 numpy
|
||||
2. **SimulateEvents** — `EventProcessor` 计算帧间亮度变化,输出二值事件帧 (1, H, W),值域 {-1, 0, 1}
|
||||
3. **ComputeTilt** — 从四元数 `[qx, qy, qz, qw]` 中提取偏航角 yaw,移除后得到 tilt 旋转向量 (3,)
|
||||
4. **ComputeBodyVelocity** — 世界系速度 `[vx, vy, vz]` → 补偿偏航 → 转到机体系,取 `[vx_body, vy_body]` (2,)
|
||||
|
||||
### 坐标系变换逻辑(utils.py)
|
||||
|
||||
```
|
||||
输入: q_world_to_body (四元数), v_world (3,)
|
||||
|
||||
Step 1: 从四元数分解偏航角 yaw
|
||||
Step 2: 构造纯偏航四元数 q_yaw
|
||||
Step 3: q_tilt = q_yaw^{-1} * q_world_to_body → tilt_angles (旋转向量)
|
||||
Step 4: v_yaw_comp = q_yaw^{-1} * v_world → 偏航补偿
|
||||
Step 5: v_body = q_tilt^{-1} * v_yaw_comp → 转到机体系
|
||||
Step 6: 取 v_body[:2] 作为回归目标
|
||||
```
|
||||
|
||||
### 序列采样(dataset.py)
|
||||
|
||||
- 从 shard 中取连续 `seq_len` 帧(默认 8 帧)构成一个训练样本
|
||||
- 输出 batch 维度:`(B, S, 1, H, W)` 事件帧, `(B, S, 3)` tilt, `(B, S, 2)` 速度 GT
|
||||
- 模型预测最后一帧的速度
|
||||
|
||||
### 数据集划分
|
||||
|
||||
| 集 | 场景 |
|
||||
|----|------|
|
||||
| **训练** | indoor_forward_3/5/6/7/9/10, indoor_45_2/4/9/12 |
|
||||
| **验证** | indoor_45_13/14 |
|
||||
| **测试** | outdoor_forward_1/3/5, outdoor_45_1 |
|
||||
|
||||
---
|
||||
|
||||
## 运行方式
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
# 使用 uv(推荐)
|
||||
uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
uv pip install webdataset opencv-python matplotlib tensorboard numpy scipy
|
||||
|
||||
# 或使用 pip
|
||||
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install webdataset opencv-python matplotlib tensorboard numpy scipy
|
||||
```
|
||||
|
||||
### 2. 训练
|
||||
|
||||
```bash
|
||||
# 从项目根目录执行
|
||||
cd /home/hexone/Workplace/ws_asmo/uzh_fpv
|
||||
|
||||
# 激活虚拟环境后
|
||||
python -m src.velocity_prediction.train
|
||||
```
|
||||
|
||||
训练参数在 `config.py` 的 `TrainConfig` 中配置,关键参数:
|
||||
|
||||
| 参数 | 默认值 | 说明 |
|
||||
|------|--------|------|
|
||||
| seq_len | 8 | 每序列帧数 |
|
||||
| batch_size | 32 | 批次大小 |
|
||||
| epochs | 100 | 训练轮数 |
|
||||
| lr | 1e-3 | 学习率 |
|
||||
| event_threshold | 0.1 | 事件模拟阈值 |
|
||||
|
||||
训练输出:
|
||||
- `logs/` — TensorBoard 日志
|
||||
- `checkpoints/` — 模型 checkpoint(每 10 轮保存 + 最优模型 `best.pt`)
|
||||
|
||||
### 3. 评估
|
||||
|
||||
```bash
|
||||
python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt
|
||||
```
|
||||
|
||||
输出:
|
||||
- 控制台打印 RMSE(vx, vy, xy)
|
||||
- `eval_velocity.png` — 预测 vs GT 时序对比图
|
||||
- `eval_scatter.png` — 散点图
|
||||
|
||||
### 4. 模型参数量检查
|
||||
|
||||
```bash
|
||||
python -m src.velocity_prediction.model
|
||||
```
|
||||
|
||||
输出类似:
|
||||
```
|
||||
Total trainable parameters: 571,266 (0.571 M)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 模型导出与部署
|
||||
|
||||
### 导出 TorchScript
|
||||
|
||||
```python
|
||||
import torch
|
||||
from src.velocity_prediction.model import VelocityPredictionModel
|
||||
|
||||
model = VelocityPredictionModel()
|
||||
model.load_state_dict(torch.load("checkpoints/best.pt")["model_state_dict"])
|
||||
model.eval()
|
||||
|
||||
# 导出
|
||||
traced = torch.jit.trace(model, (torch.randn(1, 8, 1, 240, 320), torch.randn(1, 8, 3)))
|
||||
traced.save("velocity_model.pt")
|
||||
```
|
||||
|
||||
### RV1106 部署注意事项
|
||||
|
||||
- 模型 ~0.57M 参数,FP32 ~2.3 MB,INT8 量化后 ~0.6 MB
|
||||
- CNN 部分可跑 NPU(0.5 TOPS),GRU 需 ARM CPU 执行
|
||||
- 若需裁剪:`config.py` 中 `CNNConfig.channels` 减半(32→16, 64→32, 128→64, 256→128),或 `GRUConfig.hidden_size` 从 128 降至 64
|
||||
|
||||
---
|
||||
|
||||
## 文件职责速查
|
||||
|
||||
| 文件 | 职责 | 关键类/函数 |
|
||||
|------|------|------------|
|
||||
| `config.py` | 所有可配置参数 | `ModelConfig`, `TrainConfig`, `TRAIN_SCENES` |
|
||||
| `utils.py` | 四元数运算、坐标变换 | `decompose_tilt`, `world_vel_to_body` |
|
||||
| `transforms.py` | 数据预处理管线 | `DecodeSample`, `SimulateEvents`, `ComputeTilt`, `ComputeBodyVelocity` |
|
||||
| `dataset.py` | WebDataset 加载 + 序列采样 | `create_train_loader`, `create_val_loader` |
|
||||
| `model.py` | 网络模型 | `VelocityPredictionModel`, `CNNEncoder`, `PoseMLP` |
|
||||
| `train.py` | 训练循环 | `train_one_epoch`, `validate`, `main` |
|
||||
| `evaluate.py` | 评估与可视化 | `evaluate`, `plot_results`, `plot_scatter` |
|
||||
7
src/velocity_prediction/__init__.py
Normal file
7
src/velocity_prediction/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Velocity prediction from simulated event frames + attitude.
|
||||
|
||||
Pipeline:
|
||||
Event frame (1, H, W) ──► CNN ──┐
|
||||
Tilt angles (3,) ──► MLP ──┤──► concat ──► GRU ──► Head ──► [vx_body, vy_body]
|
||||
"""
|
||||
118
src/velocity_prediction/config.py
Normal file
118
src/velocity_prediction/config.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Global configuration for velocity prediction.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# ──────────────────────────── Dataset paths ────────────────────────────
|
||||
|
||||
DATASET_ROOT = Path(__file__).resolve().parents[2] / "dataset"
|
||||
|
||||
# Velocity normalization stats (computed from training set)
|
||||
VELOCITY_MEAN = [0.859184, -0.783945] # [vx, vy]
|
||||
VELOCITY_STD = [2.244513, 1.088335] # [vx, vy]
|
||||
|
||||
# TRAIN_SCENES = [
|
||||
# "indoor_forward_3", "indoor_forward_5", "indoor_forward_6",
|
||||
# "indoor_forward_7", "indoor_forward_9", "indoor_forward_10",
|
||||
# "indoor_45_2", "indoor_45_4", "indoor_45_9", "indoor_45_12",
|
||||
# ]
|
||||
# VAL_SCENES = [
|
||||
# "indoor_45_13", "indoor_45_14",
|
||||
# ]
|
||||
# TEST_SCENES = [
|
||||
# "outdoor_forward_1", "outdoor_forward_3", "outdoor_forward_5",
|
||||
# "outdoor_45_1",
|
||||
# ]
|
||||
|
||||
TRAIN_SCENES = [
|
||||
"indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||
"indoor_forward_5", "indoor_forward_6", # Medium
|
||||
"outdoor_forward_3" # Medium 室外
|
||||
]
|
||||
VAL_SCENES = [
|
||||
"indoor_forward_7", # Hard 室内
|
||||
"outdoor_forward_1" # Easy 室外
|
||||
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||
]
|
||||
TEST_SCENES = [
|
||||
"indoor_forward_7", # Hard 室内
|
||||
"outdoor_forward_1", # Easy 室外
|
||||
"outdoor_forward_5" # Hard 室外
|
||||
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||
]
|
||||
|
||||
|
||||
# ──────────────────────────── Model architecture ────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class CNNConfig:
|
||||
in_channels: int = 1
|
||||
channels: tuple = (32, 64, 128, 256) # per-layer output channels
|
||||
kernel_size: int = 3
|
||||
pool_size: int = 2
|
||||
use_bn: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoseMLPConfig:
|
||||
input_dim: int = 3
|
||||
hidden_dim: int = 32
|
||||
output_dim: int = 64
|
||||
|
||||
|
||||
@dataclass
|
||||
class GRUConfig:
|
||||
input_size: int = 320 # CNN(256) + PoseMLP(64)
|
||||
hidden_size: int = 128
|
||||
num_layers: int = 1
|
||||
dropout: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class HeadConfig:
|
||||
input_dim: int = 128 # GRU hidden_size
|
||||
hidden_dim: int = 64
|
||||
output_dim: int = 2 # [vx_body, vy_body]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
cnn: CNNConfig = field(default_factory=CNNConfig)
|
||||
pose_mlp: PoseMLPConfig = field(default_factory=PoseMLPConfig)
|
||||
gru: GRUConfig = field(default_factory=GRUConfig)
|
||||
head: HeadConfig = field(default_factory=HeadConfig)
|
||||
|
||||
|
||||
# ──────────────────────────── Training ────────────────────────────
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
seq_len: int = 8 # frames per training sequence
|
||||
batch_size: int = 32
|
||||
epochs: int = 100
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 1e-5
|
||||
lr_scheduler_step: int = 30
|
||||
lr_scheduler_gamma: float = 0.5
|
||||
num_workers: int = 4
|
||||
seed: int = 42
|
||||
|
||||
# Event simulation
|
||||
event_threshold: float = 0.1
|
||||
event_use_log: bool = True
|
||||
event_auto_threshold: bool = False
|
||||
|
||||
# Logging / checkpoint
|
||||
log_dir: str = "logs"
|
||||
checkpoint_dir: str = "checkpoints"
|
||||
log_interval: int = 10
|
||||
save_interval: int = 10
|
||||
|
||||
|
||||
# ──────────────────────────── Singleton instances ────────────────────────────
|
||||
|
||||
model_cfg = ModelConfig()
|
||||
train_cfg = TrainConfig()
|
||||
120
src/velocity_prediction/dataset.py
Normal file
120
src/velocity_prediction/dataset.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
WebDataset-based dataset with sequence sampling.
|
||||
|
||||
Each sample from the dataset is a dict:
|
||||
{
|
||||
"events": np.ndarray (1, H, W), # simulated event frame
|
||||
"tilt": np.ndarray (3,), # tilt rotation vector
|
||||
"v_body_target": np.ndarray (2,), # body-frame [vx, vy]
|
||||
}
|
||||
|
||||
The dataset groups consecutive frames into sequences of length seq_len.
|
||||
"""
|
||||
|
||||
import webdataset as wds
|
||||
from pathlib import Path
|
||||
from typing import List, Callable, Optional
|
||||
|
||||
from src.velocity_prediction.config import DATASET_ROOT, train_cfg
|
||||
from src.velocity_prediction.transforms import build_train_transform, build_val_transform
|
||||
|
||||
|
||||
def _scene_urls(scene_names: List[str], root: Path = DATASET_ROOT) -> List[str]:
|
||||
"""Build list of actual shard file paths for the given scene names."""
|
||||
urls = []
|
||||
for name in scene_names:
|
||||
scene_dir = root / name
|
||||
if not scene_dir.exists():
|
||||
print(f"Warning: scene directory not found: {scene_dir}")
|
||||
continue
|
||||
# List all shard_*.tar files in the scene directory
|
||||
shard_files = sorted(scene_dir.glob("shard_*.tar"))
|
||||
if not shard_files:
|
||||
print(f"Warning: no shard files found in {scene_dir}")
|
||||
continue
|
||||
urls.extend(str(f) for f in shard_files)
|
||||
return urls
|
||||
|
||||
|
||||
def _build_pipeline(
|
||||
urls: List[str],
|
||||
transform: Callable,
|
||||
seq_len: int,
|
||||
shuffle: int = 1000,
|
||||
deterministic: bool = False,
|
||||
):
|
||||
"""
|
||||
Build a WebDataset pipeline that:
|
||||
1. Reads tar shards
|
||||
2. Decodes and transforms individual samples
|
||||
3. Groups consecutive samples into sequences
|
||||
"""
|
||||
dataset = wds.WebDataset(urls, shardshuffle=shuffle if not deterministic else 0, empty_check=False)
|
||||
|
||||
if not deterministic:
|
||||
dataset = dataset.shuffle(shuffle)
|
||||
|
||||
dataset = dataset.decode().map(transform)
|
||||
|
||||
# Group into sequences of seq_len consecutive frames
|
||||
dataset = dataset.batched(seq_len, partial=False)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def create_train_loader(
|
||||
scene_names: Optional[List[str]] = None,
|
||||
seq_len: int = 8,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 4,
|
||||
event_threshold: float = 0.1,
|
||||
event_use_log: bool = True,
|
||||
):
|
||||
"""Create a DataLoader for training."""
|
||||
if scene_names is None:
|
||||
from src.velocity_prediction.config import TRAIN_SCENES
|
||||
scene_names = TRAIN_SCENES
|
||||
|
||||
urls = _scene_urls(scene_names)
|
||||
transform = build_train_transform(
|
||||
event_threshold=event_threshold,
|
||||
event_use_log=event_use_log,
|
||||
)
|
||||
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=1000)
|
||||
|
||||
loader = wds.WebLoader(
|
||||
pipeline,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=False, # already shuffled in pipeline
|
||||
)
|
||||
return loader
|
||||
|
||||
|
||||
def create_val_loader(
|
||||
scene_names: Optional[List[str]] = None,
|
||||
seq_len: int = 8,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 4,
|
||||
event_threshold: float = 0.1,
|
||||
event_use_log: bool = True,
|
||||
):
|
||||
"""Create a DataLoader for validation (deterministic order)."""
|
||||
if scene_names is None:
|
||||
from src.velocity_prediction.config import VAL_SCENES
|
||||
scene_names = VAL_SCENES
|
||||
|
||||
urls = _scene_urls(scene_names)
|
||||
transform = build_val_transform(
|
||||
event_threshold=event_threshold,
|
||||
event_use_log=event_use_log,
|
||||
)
|
||||
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=0, deterministic=True)
|
||||
|
||||
loader = wds.WebLoader(
|
||||
pipeline,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=False,
|
||||
)
|
||||
return loader
|
||||
212
src/velocity_prediction/evaluate.py
Normal file
212
src/velocity_prediction/evaluate.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Evaluation and visualization for VelocityPredictionModel.
|
||||
|
||||
Usage:
|
||||
python -m src.velocity_prediction.evaluate --checkpoint checkpoints/best.pt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from src.velocity_prediction.model import VelocityPredictionModel
|
||||
from src.velocity_prediction.dataset import create_val_loader
|
||||
from src.velocity_prediction.config import train_cfg, VELOCITY_MEAN, VELOCITY_STD
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(
|
||||
model: nn.Module,
|
||||
loader,
|
||||
device: torch.device,
|
||||
) -> dict:
|
||||
"""
|
||||
Run evaluation on a dataloader.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
preds: np.ndarray (N, 2) predicted [vx, vy]
|
||||
targets: np.ndarray (N, 2) ground truth [vx, vy]
|
||||
"""
|
||||
model.eval()
|
||||
all_preds = []
|
||||
all_targets = []
|
||||
|
||||
for batch in loader:
|
||||
events = batch["events"].to(device)
|
||||
tilt = batch["tilt"].to(device)
|
||||
target = batch["v_body_target"].to(device) # (B, S, 2)
|
||||
|
||||
pred = model(events, tilt) # (B, 2)
|
||||
target_last = target[:, -1, :] # (B, 2)
|
||||
|
||||
all_preds.append(pred.cpu().numpy())
|
||||
all_targets.append(target_last.cpu().numpy())
|
||||
|
||||
preds = np.concatenate(all_preds, axis=0)
|
||||
targets = np.concatenate(all_targets, axis=0)
|
||||
|
||||
# Denormalize predictions back to original velocity space
|
||||
mean = np.array(VELOCITY_MEAN, dtype=np.float32)
|
||||
std = np.array(VELOCITY_STD, dtype=np.float32)
|
||||
preds_denorm = preds * std + mean
|
||||
targets_denorm = targets * std + mean
|
||||
|
||||
# ── Diagnostics (in normalized space) ────────────────────────
|
||||
print("\n========== Evaluation Diagnostics (normalized space) ==========")
|
||||
print(f"Total samples: {len(preds)}")
|
||||
print(f"\n--- Targets (normalized) ---")
|
||||
print(f" vx: mean={targets[:, 0].mean():.6f}, std={targets[:, 0].std():.6f}")
|
||||
print(f" vy: mean={targets[:, 1].mean():.6f}, std={targets[:, 1].std():.6f}")
|
||||
print(f"\n--- Predictions (normalized) ---")
|
||||
print(f" vx: mean={preds[:, 0].mean():.6f}, std={preds[:, 0].std():.6f}, "
|
||||
f"min={preds[:, 0].min():.6f}, max={preds[:, 0].max():.6f}")
|
||||
print(f" vy: mean={preds[:, 1].mean():.6f}, std={preds[:, 1].std():.6f}, "
|
||||
f"min={preds[:, 1].min():.6f}, max={preds[:, 1].max():.6f}")
|
||||
print(f"\n--- Unique prediction values ---")
|
||||
print(f" vx unique: {len(np.unique(preds[:, 0]))} / {len(preds)}")
|
||||
print(f" vy unique: {len(np.unique(preds[:, 1]))} / {len(preds)}")
|
||||
vx_range = preds[:, 0].max() - preds[:, 0].min()
|
||||
vy_range = preds[:, 1].max() - preds[:, 1].min()
|
||||
print(f"\n vx range: {vx_range:.8f} (constant if near 0)")
|
||||
print(f" vy range: {vy_range:.8f} (constant if near 0)")
|
||||
print(f"\n--- Constant prediction check ---")
|
||||
print(f" pred vx mean ≈ 0? {abs(preds[:, 0].mean()):.6f} diff from zero")
|
||||
print(f" pred vy mean ≈ 0? {abs(preds[:, 1].mean()):.6f} diff from zero")
|
||||
print("=============================================\n")
|
||||
|
||||
# Per-axis and overall RMSE (in original velocity space)
|
||||
rmse_x = np.sqrt(np.mean((preds_denorm[:, 0] - targets_denorm[:, 0]) ** 2))
|
||||
rmse_y = np.sqrt(np.mean((preds_denorm[:, 1] - targets_denorm[:, 1]) ** 2))
|
||||
rmse_xy = np.sqrt(np.mean(np.sum((preds_denorm - targets_denorm) ** 2, axis=1)))
|
||||
|
||||
return {
|
||||
"preds": preds_denorm, # denormalized for plotting
|
||||
"targets": targets_denorm, # denormalized for plotting
|
||||
"rmse_x": rmse_x,
|
||||
"rmse_y": rmse_y,
|
||||
"rmse_xy": rmse_xy,
|
||||
}
|
||||
|
||||
|
||||
def plot_results(
|
||||
preds: np.ndarray,
|
||||
targets: np.ndarray,
|
||||
save_path: str = "eval_plot.png",
|
||||
):
|
||||
"""Plot predicted vs ground truth velocity traces."""
|
||||
fig, axes = plt.subplots(2, 1, figsize=(12, 6), sharex=True)
|
||||
|
||||
time = np.arange(len(preds))
|
||||
|
||||
axes[0].plot(time, targets[:, 0], label="GT vx", color="C0", alpha=0.8)
|
||||
axes[0].plot(time, preds[:, 0], label="Pred vx", color="C1", alpha=0.8)
|
||||
axes[0].set_ylabel("vx (m/s)")
|
||||
axes[0].legend()
|
||||
axes[0].grid(True, alpha=0.3)
|
||||
|
||||
axes[1].plot(time, targets[:, 1], label="GT vy", color="C0", alpha=0.8)
|
||||
axes[1].plot(time, preds[:, 1], label="Pred vy", color="C1", alpha=0.8)
|
||||
axes[1].set_ylabel("vy (m/s)")
|
||||
axes[1].set_xlabel("Frame index")
|
||||
axes[1].legend()
|
||||
axes[1].grid(True, alpha=0.3)
|
||||
|
||||
fig.suptitle("Body-frame Velocity Prediction")
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_path, dpi=150)
|
||||
print(f"Plot saved: {save_path}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_scatter(
|
||||
preds: np.ndarray,
|
||||
targets: np.ndarray,
|
||||
save_path: str = "eval_scatter.png",
|
||||
):
|
||||
"""Scatter plot: predicted vs ground truth."""
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
|
||||
|
||||
for ax, pred, target, label in zip(
|
||||
axes, [preds[:, 0], preds[:, 1]], [targets[:, 0], targets[:, 1]], ["vx", "vy"]
|
||||
):
|
||||
ax.scatter(target, pred, s=2, alpha=0.5)
|
||||
lim_min = min(target.min(), pred.min())
|
||||
lim_max = max(target.max(), pred.max())
|
||||
ax.plot([lim_min, lim_max], [lim_min, lim_max], "r--", alpha=0.5)
|
||||
ax.set_xlabel(f"GT {label} (m/s)")
|
||||
ax.set_ylabel(f"Pred {label} (m/s)")
|
||||
ax.set_aspect("equal")
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_title(f"{label} — RMSE: {np.sqrt(np.mean((pred - target)**2)):.4f}")
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_path, dpi=150)
|
||||
print(f"Scatter saved: {save_path}")
|
||||
plt.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint", type=str, default="checkpoints/best.pt",
|
||||
help="Path to model checkpoint")
|
||||
parser.add_argument("--device", type=str, default="cuda",
|
||||
help="Device to use (e.g. 'cuda:2', 'cpu')")
|
||||
parser.add_argument("--plot", action="store_true", default=True,
|
||||
help="Generate evaluation plots")
|
||||
args = parser.parse_args()
|
||||
|
||||
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
|
||||
print(f"Device: {device}")
|
||||
|
||||
# Load model
|
||||
model = VelocityPredictionModel()
|
||||
ckpt = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(ckpt["model_state_dict"])
|
||||
model.to(device)
|
||||
print(f"Loaded checkpoint from {args.checkpoint} (epoch={ckpt.get('epoch', '?')})")
|
||||
|
||||
# Validation loader (use test scenes for final eval)
|
||||
from src.velocity_prediction.config import TEST_SCENES
|
||||
loader = create_val_loader(
|
||||
scene_names=TEST_SCENES,
|
||||
seq_len=train_cfg.seq_len,
|
||||
batch_size=train_cfg.batch_size,
|
||||
num_workers=2,
|
||||
event_threshold=train_cfg.event_threshold,
|
||||
event_use_log=train_cfg.event_use_log,
|
||||
)
|
||||
|
||||
# # ── Quick event diagnostics: inspect one batch ───────────────
|
||||
# print("\n========== Event Frame Diagnostics ==========")
|
||||
# sample_batch = next(iter(loader))
|
||||
# ev = sample_batch["events"] # (B, S, 1, H, W)
|
||||
# print(f"Events shape: {ev.shape}")
|
||||
# print(f"Events dtype: {ev.dtype}")
|
||||
# print(f"Events value counts: -1: {(ev == -1).sum().item()}, "
|
||||
# f"0: {(ev == 0).sum().item()}, +1: {(ev == 1).sum().item()}")
|
||||
# total_el = ev.numel()
|
||||
# nonzero = (ev != 0).sum().item()
|
||||
# print(f"Non-zero ratio: {nonzero / total_el:.6f} ({nonzero}/{total_el})")
|
||||
# print(f"Per-sample non-zero: {[(ev[b] != 0).sum().item() for b in range(min(4, ev.shape[0]))]}")
|
||||
# print("=============================================\n")
|
||||
|
||||
# Evaluate
|
||||
results = evaluate(model, loader, device)
|
||||
print(f"\nEvaluation results on test scenes: {TEST_SCENES}")
|
||||
print(f" RMSE vx: {results['rmse_x']:.4f} m/s")
|
||||
print(f" RMSE vy: {results['rmse_y']:.4f} m/s")
|
||||
print(f" RMSE xy: {results['rmse_xy']:.4f} m/s")
|
||||
|
||||
# Plots
|
||||
if args.plot:
|
||||
plot_results(results["preds"], results["targets"], "eval_velocity.png")
|
||||
plot_scatter(results["preds"], results["targets"], "eval_scatter.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
170
src/velocity_prediction/model.py
Normal file
170
src/velocity_prediction/model.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
VelocityPredictionModel: CNN + PoseMLP → concat → GRU → Head → [vx_body, vy_body].
|
||||
|
||||
Architecture:
|
||||
Event frame (1, H, W) ──► CNN ──┐
|
||||
Tilt angles (3,) ──► MLP ──┤──► concat ──► GRU ──► Head ──► [vx, vy]
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from src.velocity_prediction.config import model_cfg
|
||||
|
||||
|
||||
class CNNEncoder(nn.Module):
|
||||
"""
|
||||
4-layer ConvNet with BatchNorm, ReLU, MaxPool, ending with Global Avg Pool.
|
||||
|
||||
Input: (B, S, 1, H, W) — processed per-frame (flattened to (B*S, 1, H, W))
|
||||
Output: (B, S, C_out) — per-frame feature vectors
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=model_cfg.cnn):
|
||||
super().__init__()
|
||||
channels = cfg.channels
|
||||
in_ch = cfg.in_channels
|
||||
|
||||
layers = []
|
||||
for out_ch in channels:
|
||||
layers.extend([
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2),
|
||||
# nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(),
|
||||
nn.Identity(),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
nn.MaxPool2d(cfg.pool_size),
|
||||
])
|
||||
in_ch = out_ch
|
||||
|
||||
self.conv = nn.Sequential(*layers)
|
||||
self.gap = nn.AdaptiveAvgPool2d(1)
|
||||
self.out_dim = channels[-1]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: (B, S, 1, H, W) event frame sequence
|
||||
Returns:
|
||||
features: (B, S, C_out)
|
||||
"""
|
||||
B, S, C, H, W = x.shape
|
||||
x = x.view(B * S, C, H, W) # (B*S, 1, H, W)
|
||||
x = self.conv(x) # (B*S, C_out, H', W')
|
||||
x = self.gap(x) # (B*S, C_out, 1, 1)
|
||||
x = x.view(B, S, self.out_dim) # (B, S, C_out)
|
||||
return x
|
||||
|
||||
|
||||
class PoseMLP(nn.Module):
|
||||
"""
|
||||
Encode tilt rotation vector (3,) into a compact feature vector.
|
||||
|
||||
Input: (B, S, 3)
|
||||
Output: (B, S, output_dim)
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=model_cfg.pose_mlp):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(cfg.input_dim, cfg.hidden_dim),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
nn.Linear(cfg.hidden_dim, cfg.output_dim),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
)
|
||||
self.out_dim = cfg.output_dim
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: (B, S, 3) → (B, S, output_dim)
|
||||
"""
|
||||
B, S, D = x.shape
|
||||
x = x.view(B * S, D)
|
||||
x = self.net(x)
|
||||
x = x.view(B, S, self.out_dim)
|
||||
return x
|
||||
|
||||
|
||||
class VelocityPredictionModel(nn.Module):
|
||||
"""
|
||||
Full model: CNN + PoseMLP → concat → GRU → Head → [vx, vy].
|
||||
|
||||
Input:
|
||||
events: (B, S, 1, H, W)
|
||||
tilt: (B, S, 3)
|
||||
Output:
|
||||
v_body: (B, 2) — body-frame [vx, vy] for the last frame in the sequence
|
||||
"""
|
||||
|
||||
def __init__(self, cnn_cfg=model_cfg.cnn, pose_cfg=model_cfg.pose_mlp,
|
||||
gru_cfg=model_cfg.gru, head_cfg=model_cfg.head):
|
||||
super().__init__()
|
||||
|
||||
self.cnn = CNNEncoder(cnn_cfg)
|
||||
self.pose_mlp = PoseMLP(pose_cfg)
|
||||
|
||||
fused_dim = self.cnn.out_dim + self.pose_mlp.out_dim # 256 + 64 = 320
|
||||
|
||||
self.gru = nn.GRU(
|
||||
input_size=fused_dim,
|
||||
hidden_size=gru_cfg.hidden_size,
|
||||
num_layers=gru_cfg.num_layers,
|
||||
dropout=gru_cfg.dropout if gru_cfg.num_layers > 1 else 0.0,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
self.head = nn.Sequential(
|
||||
nn.Linear(gru_cfg.hidden_size, head_cfg.hidden_dim),
|
||||
nn.LeakyReLU(inplace=True),
|
||||
nn.Linear(head_cfg.hidden_dim, head_cfg.output_dim),
|
||||
)
|
||||
|
||||
# # Small init for the final layer: start from near-zero output
|
||||
# self.head[-1].weight.data.mul_(0.01)
|
||||
# self.head[-1].bias.data.zero_()
|
||||
|
||||
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
events: (B, S, 1, H, W)
|
||||
tilt: (B, S, 3)
|
||||
|
||||
Returns:
|
||||
v_body: (B, 2) predicted body-frame [vx, vy] at the last timestep
|
||||
"""
|
||||
# Per-frame encoding
|
||||
cnn_feat = self.cnn(events) # (B, S, 256)
|
||||
pose_feat = self.pose_mlp(tilt) # (B, S, 64)
|
||||
|
||||
# Fuse per frame
|
||||
fused = torch.cat([cnn_feat, pose_feat], dim=-1) # (B, S, 320)
|
||||
|
||||
# GRU temporal modelling
|
||||
gru_out, h_n = self.gru(fused) # gru_out: (B, S, 128), h_n: (1, B, 128)
|
||||
|
||||
# Use last hidden state
|
||||
last_hidden = h_n[-1] # (B, 128)
|
||||
|
||||
# Head regression
|
||||
v_body = self.head(last_hidden) # (B, 2)
|
||||
return v_body
|
||||
|
||||
|
||||
def count_parameters(model: nn.Module) -> int:
|
||||
"""Count trainable parameters."""
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Quick sanity check
|
||||
model = VelocityPredictionModel()
|
||||
total = count_parameters(model)
|
||||
print(f"Total trainable parameters: {total:,} ({total/1e6:.3f} M)")
|
||||
|
||||
# Forward pass test
|
||||
B, S, H, W = 4, 8, 240, 320
|
||||
events = torch.randn(B, S, 1, H, W)
|
||||
tilt = torch.randn(B, S, 3)
|
||||
out = model(events, tilt)
|
||||
print(f"Input events: {events.shape}")
|
||||
print(f"Input tilt: {tilt.shape}")
|
||||
print(f"Output: {out.shape} (should be [4, 2])")
|
||||
213
src/velocity_prediction/train.py
Normal file
213
src/velocity_prediction/train.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
Training loop for VelocityPredictionModel.
|
||||
|
||||
Usage:
|
||||
python -m src.velocity_prediction.train [--device cuda:0]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from src.velocity_prediction.config import train_cfg, model_cfg
|
||||
from src.velocity_prediction.model import VelocityPredictionModel, count_parameters
|
||||
from src.velocity_prediction.dataset import create_train_loader, create_val_loader
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
model: nn.Module,
|
||||
loader,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
criterion: nn.Module,
|
||||
device: torch.device,
|
||||
epoch: int,
|
||||
writer: SummaryWriter,
|
||||
log_interval: int = 50,
|
||||
) -> float:
|
||||
"""Train for one epoch. Returns average loss."""
|
||||
model.train()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
start_time = time.time()
|
||||
|
||||
for batch_idx, batch in enumerate(loader):
|
||||
events = batch["events"].to(device) # (B, S, 1, H, W)
|
||||
tilt = batch["tilt"].to(device) # (B, S, 3)
|
||||
target = batch["v_body_target"].to(device) # (B, S, 2)
|
||||
|
||||
# Predict velocity for the last frame in the sequence
|
||||
pred = model(events, tilt) # (B, 2)
|
||||
target_last = target[:, -1, :] # (B, 2)
|
||||
|
||||
loss = criterion(pred, target_last)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
elapsed = time.time() - start_time
|
||||
print(f" Epoch {epoch} | Batch {batch_idx} | Loss: {loss.item():.6f} | {elapsed:.1f}s")
|
||||
writer.add_scalar("train/loss_batch", loss.item(), batch_idx)
|
||||
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
return avg_loss
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def validate(
|
||||
model: nn.Module,
|
||||
loader,
|
||||
criterion: nn.Module,
|
||||
device: torch.device,
|
||||
) -> float:
|
||||
"""Validate. Returns average loss."""
|
||||
model.eval()
|
||||
total_loss = 0.0
|
||||
num_batches = 0
|
||||
|
||||
for batch in loader:
|
||||
events = batch["events"].to(device)
|
||||
tilt = batch["tilt"].to(device)
|
||||
target = batch["v_body_target"].to(device)
|
||||
|
||||
pred = model(events, tilt)
|
||||
target_last = target[:, -1, :]
|
||||
|
||||
loss = criterion(pred, target_last)
|
||||
total_loss += loss.item()
|
||||
num_batches += 1
|
||||
|
||||
avg_loss = total_loss / max(num_batches, 1)
|
||||
return avg_loss
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--device", type=str, default="cuda",
|
||||
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
|
||||
args = parser.parse_args()
|
||||
|
||||
set_seed(train_cfg.seed)
|
||||
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
|
||||
print(f"Device: {device}")
|
||||
|
||||
# Create model
|
||||
model = VelocityPredictionModel()
|
||||
model.to(device)
|
||||
total_params = count_parameters(model)
|
||||
print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")
|
||||
|
||||
# Data loaders
|
||||
train_loader = create_train_loader(
|
||||
seq_len=train_cfg.seq_len,
|
||||
batch_size=train_cfg.batch_size,
|
||||
num_workers=train_cfg.num_workers,
|
||||
event_threshold=train_cfg.event_threshold,
|
||||
event_use_log=train_cfg.event_use_log,
|
||||
)
|
||||
val_loader = create_val_loader(
|
||||
seq_len=train_cfg.seq_len,
|
||||
batch_size=train_cfg.batch_size,
|
||||
num_workers=train_cfg.num_workers,
|
||||
event_threshold=train_cfg.event_threshold,
|
||||
event_use_log=train_cfg.event_use_log,
|
||||
)
|
||||
|
||||
# Optimizer & scheduler
|
||||
optimizer = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=train_cfg.lr,
|
||||
weight_decay=train_cfg.weight_decay,
|
||||
)
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(
|
||||
optimizer,
|
||||
step_size=train_cfg.lr_scheduler_step,
|
||||
gamma=train_cfg.lr_scheduler_gamma,
|
||||
)
|
||||
# criterion = nn.SmoothL1Loss()
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
# Logging
|
||||
log_dir = Path(train_cfg.log_dir)
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
writer = SummaryWriter(log_dir=str(log_dir))
|
||||
|
||||
ckpt_dir = Path(train_cfg.checkpoint_dir)
|
||||
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
best_val_loss = float("inf")
|
||||
|
||||
print(f"\nStarting training for {train_cfg.epochs} epochs...")
|
||||
print(f" seq_len={train_cfg.seq_len}, batch_size={train_cfg.batch_size}")
|
||||
print(f" lr={train_cfg.lr}, weight_decay={train_cfg.weight_decay}")
|
||||
print(f" log_dir={log_dir}, checkpoint_dir={ckpt_dir}\n")
|
||||
|
||||
for epoch in range(1, train_cfg.epochs + 1):
|
||||
epoch_start = time.time()
|
||||
|
||||
train_loss = train_one_epoch(
|
||||
model, train_loader, optimizer, criterion, device, epoch, writer,
|
||||
log_interval=train_cfg.log_interval,
|
||||
)
|
||||
val_loss = validate(model, val_loader, criterion, device)
|
||||
scheduler.step()
|
||||
|
||||
epoch_time = time.time() - epoch_start
|
||||
current_lr = scheduler.get_last_lr()[0]
|
||||
|
||||
print(f"Epoch {epoch:3d}/{train_cfg.epochs} | "
|
||||
f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f} | "
|
||||
f"LR: {current_lr:.2e} | Time: {epoch_time:.1f}s")
|
||||
|
||||
writer.add_scalar("train/loss_epoch", train_loss, epoch)
|
||||
writer.add_scalar("val/loss", val_loss, epoch)
|
||||
writer.add_scalar("lr", current_lr, epoch)
|
||||
|
||||
# Save checkpoint
|
||||
if epoch % train_cfg.save_interval == 0:
|
||||
ckpt_path = ckpt_dir / f"epoch_{epoch:03d}_val_{val_loss:.6f}.pt"
|
||||
torch.save({
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"scheduler_state_dict": scheduler.state_dict(),
|
||||
"train_loss": train_loss,
|
||||
"val_loss": val_loss,
|
||||
}, ckpt_path)
|
||||
print(f" Checkpoint saved: {ckpt_path}")
|
||||
|
||||
# Save best model
|
||||
if val_loss < best_val_loss:
|
||||
best_val_loss = val_loss
|
||||
best_path = ckpt_dir / "best.pt"
|
||||
torch.save({
|
||||
"epoch": epoch,
|
||||
"model_state_dict": model.state_dict(),
|
||||
"optimizer_state_dict": optimizer.state_dict(),
|
||||
"val_loss": val_loss,
|
||||
}, best_path)
|
||||
print(f" Best model updated: {best_path} (val_loss={val_loss:.6f})")
|
||||
|
||||
writer.close()
|
||||
print(f"\nTraining complete. Best val loss: {best_val_loss:.6f}")
|
||||
print(f"Best checkpoint: {ckpt_dir / 'best.pt'}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
142
src/velocity_prediction/transforms.py
Normal file
142
src/velocity_prediction/transforms.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Transforms: event frame generation + coordinate transforms for pose/velocity.
|
||||
|
||||
Each transform operates on a single decoded sample dict:
|
||||
{
|
||||
"jpg": bytes, # JPEG-encoded grayscale image
|
||||
"ts": bytes, # float64 timestamp
|
||||
"pose": bytes, # float32[7] [x, y, z, qx, qy, qz, qw]
|
||||
"vel": bytes, # float32[6] [vx, vy, vz, wx, wy, wz]
|
||||
}
|
||||
"""
|
||||
|
||||
import io
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from src.event_utils import EventProcessor
|
||||
from src.velocity_prediction.utils import decompose_tilt_np, world_vel_to_body_np
|
||||
from src.velocity_prediction.config import VELOCITY_MEAN, VELOCITY_STD
|
||||
|
||||
|
||||
class DecodeSample:
|
||||
"""Decode raw bytes from WebDataset tar entry into numpy arrays."""
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
# Image: JPEG bytes → grayscale uint8 (H, W)
|
||||
img = cv2.imdecode(np.frombuffer(sample["jpg"], np.uint8), cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
# Timestamp
|
||||
ts = np.frombuffer(sample["ts"], dtype=np.float64).item()
|
||||
|
||||
# Pose: [x, y, z, qx, qy, qz, qw]
|
||||
pose = np.frombuffer(sample["pose"], dtype=np.float32).copy()
|
||||
|
||||
# Velocity: [vx, vy, vz, wx, wy, wz]
|
||||
vel = np.frombuffer(sample["vel"], dtype=np.float32).copy()
|
||||
|
||||
return {"img": img, "ts": ts, "pose": pose, "vel": vel}
|
||||
|
||||
|
||||
class SimulateEvents:
|
||||
"""Convert grayscale frame to binary event frame using EventProcessor."""
|
||||
|
||||
def __init__(self, threshold=0.1, use_log=True, auto_threshold=False, verbose=False):
|
||||
self.processor = EventProcessor(
|
||||
threshold=threshold,
|
||||
use_log=use_log,
|
||||
auto_threshold=auto_threshold,
|
||||
)
|
||||
self.verbose = verbose
|
||||
self._frame_count = 0
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
img = sample["img"]
|
||||
events_binary, events_strength, event_count = self.processor(img)
|
||||
# Use binary events as network input: shape (1, H, W), values in {-1, 0, 1}
|
||||
sample["events"] = events_binary.astype(np.float32)[np.newaxis, ...]
|
||||
|
||||
if self.verbose:
|
||||
self._frame_count += 1
|
||||
total_pixels = events_binary.shape[0] * events_binary.shape[1]
|
||||
nonzero_ratio = event_count / total_pixels
|
||||
pos_ratio = (events_binary > 0).sum() / total_pixels
|
||||
neg_ratio = (events_binary < 0).sum() / total_pixels
|
||||
if self._frame_count <= 5 or self._frame_count % 100 == 0:
|
||||
print(f" [EventDiagnostics] frame={self._frame_count} | "
|
||||
f"nonzero={nonzero_ratio:.4f} (+{pos_ratio:.4f}/-{neg_ratio:.4f}) | "
|
||||
f"count={event_count}")
|
||||
|
||||
return sample
|
||||
|
||||
def reset(self):
|
||||
self.processor.reset()
|
||||
self._frame_count = 0
|
||||
|
||||
|
||||
class ComputeTilt:
|
||||
"""Extract tilt rotation vector from pose quaternion (discard position, discard yaw)."""
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
q = sample["pose"][3:7] # [qx, qy, qz, qw]
|
||||
tilt = decompose_tilt_np(q) # (3,) rotation vector
|
||||
sample["tilt"] = tilt.astype(np.float32)
|
||||
return sample
|
||||
|
||||
|
||||
class ComputeBodyVelocity:
|
||||
"""Transform world-frame velocity to body-frame (yaw-compensated)."""
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
v_world = sample["vel"][:3] # [vx, vy, vz] world frame
|
||||
q = sample["pose"][3:7] # [qx, qy, qz, qw]
|
||||
v_body = world_vel_to_body_np(v_world, q) # (3,)
|
||||
# Only predict forward (x) and lateral (y) body velocity
|
||||
sample["v_body_target"] = v_body[:2].astype(np.float32) # (2,)
|
||||
return sample
|
||||
|
||||
|
||||
class NormalizeVelocity:
|
||||
"""Normalize body-frame velocity to zero mean, unit variance."""
|
||||
|
||||
def __init__(self):
|
||||
self.mean = np.array(VELOCITY_MEAN, dtype=np.float32)
|
||||
self.std = np.array(VELOCITY_STD, dtype=np.float32)
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
sample["v_body_target"] = (sample["v_body_target"] - self.mean) / self.std
|
||||
return sample
|
||||
|
||||
|
||||
class Compose:
|
||||
"""Chain multiple transforms."""
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = transforms
|
||||
|
||||
def __call__(self, sample: dict) -> dict:
|
||||
for t in self.transforms:
|
||||
sample = t(sample)
|
||||
return sample
|
||||
|
||||
|
||||
def build_train_transform(event_threshold=0.1, event_use_log=True):
|
||||
"""Build the full transform pipeline for training samples."""
|
||||
return Compose([
|
||||
DecodeSample(),
|
||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||
ComputeTilt(),
|
||||
ComputeBodyVelocity(),
|
||||
NormalizeVelocity(),
|
||||
])
|
||||
|
||||
|
||||
def build_val_transform(event_threshold=0.1, event_use_log=True):
|
||||
"""Same as train but with a fresh EventProcessor per sample (no cross-contamination)."""
|
||||
return Compose([
|
||||
DecodeSample(),
|
||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||
ComputeTilt(),
|
||||
ComputeBodyVelocity(),
|
||||
NormalizeVelocity(),
|
||||
])
|
||||
165
src/velocity_prediction/utils.py
Normal file
165
src/velocity_prediction/utils.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
Quaternion and coordinate-frame utility functions.
|
||||
|
||||
All quaternions follow [x, y, z, w] convention (matching dataset pose field).
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
# ──────────────────────────── Quaternion operations ────────────────────────────
|
||||
|
||||
def quat_normalize(q: torch.Tensor) -> torch.Tensor:
|
||||
"""Normalize quaternion. q: (..., 4)"""
|
||||
return q / torch.norm(q, dim=-1, keepdim=True).clamp(min=1e-12)
|
||||
|
||||
|
||||
def quat_conjugate(q: torch.Tensor) -> torch.Tensor:
|
||||
"""Conjugate (inverse for unit quaternion). q: (..., 4)"""
|
||||
return q * torch.tensor([-1, -1, -1, 1], device=q.device)
|
||||
|
||||
|
||||
def quat_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
||||
"""Hamilton product. a, b: (..., 4)"""
|
||||
x1, y1, z1, w1 = a.unbind(-1)
|
||||
x2, y2, z2, w2 = b.unbind(-1)
|
||||
return torch.stack([
|
||||
w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2,
|
||||
w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2,
|
||||
w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2,
|
||||
w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2,
|
||||
], dim=-1)
|
||||
|
||||
|
||||
def quat_rotate(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
"""Rotate vector v by quaternion q. q: (..., 4), v: (..., 3)"""
|
||||
q_conj = quat_conjugate(q)
|
||||
v_pad = torch.zeros_like(v[..., :1]) # (..., 1)
|
||||
v_q = torch.cat([v, v_pad], dim=-1) # (..., 4) pure quaternion
|
||||
return quat_mul(quat_mul(q, v_q), q_conj)[..., :3]
|
||||
|
||||
|
||||
# ──────────────────────────── Yaw decomposition ────────────────────────────
|
||||
|
||||
def quat_to_yaw(q: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Extract yaw (heading) angle from a quaternion in gravity-aligned frame.
|
||||
|
||||
Gravity axis is +z. Yaw is the rotation around z-axis.
|
||||
Returns angle in radians, shape (...,).
|
||||
"""
|
||||
x, y, z, w = q.unbind(-1)
|
||||
# From quaternion to Euler: yaw = atan2(2(w*z + x*y), 1 - 2(y² + z²))
|
||||
siny = 2.0 * (w * z + x * y)
|
||||
cosy = 1.0 - 2.0 * (y * y + z * z)
|
||||
return torch.atan2(siny, cosy)
|
||||
|
||||
|
||||
def quat_from_yaw(yaw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Build a pure-yaw quaternion (rotation around +z).
|
||||
yaw: (...,) → quat: (..., 4)
|
||||
"""
|
||||
half = yaw * 0.5
|
||||
cos = torch.cos(half)
|
||||
sin = torch.sin(half)
|
||||
z = torch.zeros_like(cos)
|
||||
return torch.stack([z, z, sin, cos], dim=-1) # [0, 0, sin(yaw/2), cos(yaw/2)]
|
||||
|
||||
|
||||
def decompose_tilt(q: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Remove yaw from a quaternion, returning the residual tilt rotation vector.
|
||||
|
||||
Given q = q_yaw * q_tilt (z-yaw first, then body tilt),
|
||||
we compute q_tilt = q_yaw^{-1} * q, then convert to rotation vector.
|
||||
|
||||
Args:
|
||||
q: (..., 4) unit quaternion in world→body convention.
|
||||
|
||||
Returns:
|
||||
tilt_angles: (..., 3) rotation vector [rx, ry, rz] representing
|
||||
the body's deviation from the heading direction.
|
||||
"""
|
||||
yaw = quat_to_yaw(q)
|
||||
q_yaw = quat_from_yaw(yaw)
|
||||
q_yaw_inv = quat_conjugate(q_yaw)
|
||||
q_tilt = quat_mul(q_yaw_inv, q) # remove yaw
|
||||
q_tilt = quat_normalize(q_tilt)
|
||||
return quat_to_rotvec(q_tilt)
|
||||
|
||||
|
||||
def quat_to_rotvec(q: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
|
||||
"""
|
||||
Convert unit quaternion to rotation vector (axis * angle).
|
||||
q: (..., 4) → rotvec: (..., 3)
|
||||
"""
|
||||
q = quat_normalize(q)
|
||||
x, y, z, w = q.unbind(-1)
|
||||
angle = 2.0 * torch.acos(w.clamp(-1.0, 1.0))
|
||||
sin_half = torch.sqrt((1.0 - w * w).clamp(min=eps))
|
||||
scale = angle / sin_half
|
||||
# Avoid division by zero for small angles
|
||||
mask = sin_half > eps
|
||||
rx = torch.where(mask, x * scale, torch.zeros_like(x))
|
||||
ry = torch.where(mask, y * scale, torch.zeros_like(y))
|
||||
rz = torch.where(mask, z * scale, torch.zeros_like(z))
|
||||
return torch.stack([rx, ry, rz], dim=-1)
|
||||
|
||||
|
||||
# ──────────────────────────── Velocity transformation ────────────────────────────
|
||||
|
||||
def world_vel_to_body(
|
||||
v_world: torch.Tensor,
|
||||
q_world_to_body: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Transform world-frame velocity to body-frame velocity.
|
||||
|
||||
Steps:
|
||||
1. Extract yaw from q_world_to_body.
|
||||
2. Build pure-yaw quaternion q_yaw.
|
||||
3. Remove yaw from velocity: v_yaw_compensated = q_yaw^{-1} * v_world
|
||||
4. Rotate to body frame: v_body = q_tilt^{-1} * v_yaw_compensated
|
||||
where q_tilt = q_yaw^{-1} * q_world_to_body
|
||||
|
||||
Args:
|
||||
v_world: (..., 3) world-frame linear velocity [vx, vy, vz]
|
||||
q_world_to_body: (..., 4) world→body unit quaternion
|
||||
|
||||
Returns:
|
||||
v_body: (..., 3) body-frame linear velocity
|
||||
"""
|
||||
yaw = quat_to_yaw(q_world_to_body)
|
||||
q_yaw = quat_from_yaw(yaw)
|
||||
q_yaw_inv = quat_conjugate(q_yaw)
|
||||
|
||||
# Step 1: remove yaw from velocity (rotate to yaw-aligned intermediate frame)
|
||||
v_yaw_comp = quat_rotate(q_yaw_inv, v_world)
|
||||
|
||||
# Step 2: compute tilt quaternion
|
||||
q_tilt = quat_mul(q_yaw_inv, q_world_to_body)
|
||||
q_tilt = quat_normalize(q_tilt)
|
||||
q_tilt_inv = quat_conjugate(q_tilt)
|
||||
|
||||
# Step 3: rotate to body frame
|
||||
v_body = quat_rotate(q_tilt_inv, v_yaw_comp)
|
||||
return v_body
|
||||
|
||||
|
||||
# ──────────────────────────── NumPy wrappers (for transforms.py) ────────────────────────────
|
||||
|
||||
def decompose_tilt_np(q: np.ndarray) -> np.ndarray:
|
||||
"""NumPy version of decompose_tilt."""
|
||||
q_t = torch.from_numpy(q)
|
||||
tilt = decompose_tilt(q_t)
|
||||
return tilt.numpy()
|
||||
|
||||
|
||||
def world_vel_to_body_np(v_world: np.ndarray, q: np.ndarray) -> np.ndarray:
|
||||
"""NumPy version of world_vel_to_body."""
|
||||
v_t = torch.from_numpy(v_world)
|
||||
q_t = torch.from_numpy(q)
|
||||
vb = world_vel_to_body(v_t, q_t)
|
||||
return vb.numpy()
|
||||
Reference in New Issue
Block a user