feat: replace non-overlapping windows with sliding-window sequence sampling
- Remove sample-level shuffle before transforms (broke SimulateEvents) - Add _sliding_window_fn: yields overlapping sequences with configurable stride - Add sequence-level shuffle after grouping (preserves temporal coherence) - Add sliding_window_stride to TrainConfig (stride=1 for full overlap) - Update create_train/val_loader and train.py to pass stride - AGENTS.md: document known issues (cross-shard boundary, SimulateEvents state) - AGENTS.md: add cuda:7 device preference Generated by Mistral Vibe (deepseek-v4-flash). Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
20
AGENTS.md
20
AGENTS.md
@@ -169,3 +169,23 @@ uv run python -m benchmark.benchmark --checkpoint checkpoints/best.pt
|
||||
- 速度归一化统计量:待重新计算
|
||||
- 模型预测 `[v_right, v_forward]`(右向和前向速度)
|
||||
- 所有代码在项目根目录下以 `uv run python -m <module>` 运行
|
||||
- GPU 优先使用 `cuda:7`,训练时添加 `--device cuda:7`
|
||||
|
||||
## 已知问题
|
||||
|
||||
### 1. 滑窗跨 shard 边界
|
||||
|
||||
`dataset.py` 中滑窗实现基于 WebDataset 串联后的连续流,不感知 shard 边界。当样本恰好处于 shard 末尾时,序列会跨越到下一个 shard 的起始帧。
|
||||
|
||||
- 影响:每个 shard 边界处约有 `seq_len` 个序列包含跨 shard 样本(占总数 <1%)
|
||||
- 修复思路:在 `_sliding_window_fn` 中注入 shard 边界标记,遇到边界时清空缓冲区
|
||||
- 严重程度:低。若 shard 内帧数远大于 seq_len,可忽略
|
||||
|
||||
### 2. `SimulateEvents` 跨 shard 状态残留
|
||||
|
||||
`EventProcessor` 内部维护 `_prev_frame` 用于帧差计算。跨 shard 时,新 shard 的第一帧会与上一个 shard 最后一帧计算差,产生错误的事件帧。
|
||||
|
||||
- 影响:每个 shard 的第 1 帧事件帧错误,涉及该帧的所有滑窗序列均受影响
|
||||
- 每 shard 错误帧数:1 帧(加上滑窗放大,约 `seq_len` 个序列各包含此帧)
|
||||
- 修复:在 shard 边界处调用 `EventProcessor.reset()`。需在 `_build_pipeline` 中插入边界信号或改用按 shard 独立处理的方案
|
||||
- 严重程度:低。每 shard 仅 1 帧,训练数据量大时可忽略
|
||||
|
||||
@@ -39,10 +39,11 @@ VAL_SCENES = [
|
||||
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||
]
|
||||
TEST_SCENES = [
|
||||
"indoor_forward_9","indoor_forward_3",
|
||||
# "indoor_forward_7", # Hard 室内
|
||||
# "outdoor_forward_1", # Easy 室外
|
||||
# "outdoor_forward_5" # Hard 室外
|
||||
"indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||
# "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy
|
||||
]
|
||||
|
||||
|
||||
@@ -93,7 +94,7 @@ class ModelConfig:
|
||||
class TrainConfig:
|
||||
seq_len: int = 8 # frames per training sequence
|
||||
batch_size: int = 32
|
||||
epochs: int = 100
|
||||
epochs: int = 300
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 1e-5
|
||||
lr_scheduler_step: int = 30
|
||||
@@ -101,6 +102,9 @@ class TrainConfig:
|
||||
num_workers: int = 4
|
||||
seed: int = 42
|
||||
|
||||
# Sliding window: stride=1 → full overlap, stride=seq_len → non-overlapping
|
||||
sliding_window_stride: int = 1
|
||||
|
||||
# Event simulation
|
||||
event_threshold: float = 0.1
|
||||
event_use_log: bool = True
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
WebDataset-based dataset with sequence sampling.
|
||||
WebDataset-based dataset with sliding-window sequence sampling.
|
||||
|
||||
Each sample from the dataset is a dict:
|
||||
{
|
||||
@@ -8,7 +8,8 @@ Each sample from the dataset is a dict:
|
||||
"v_body_target": np.ndarray (2,), # body-frame [vx, vy]
|
||||
}
|
||||
|
||||
The dataset groups consecutive frames into sequences of length seq_len.
|
||||
The dataset groups consecutive frames into overlapping sequences via a
|
||||
sliding window (stride=1 by default for full overlap).
|
||||
"""
|
||||
|
||||
import webdataset as wds
|
||||
@@ -36,28 +37,52 @@ def _scene_urls(scene_names: List[str], root: Path = DATASET_ROOT) -> List[str]:
|
||||
return urls
|
||||
|
||||
|
||||
def _sliding_window_fn(seq_len: int, stride: int = 1):
|
||||
"""Convert a stream of individual sample dicts into overlapping sequences.
|
||||
|
||||
Filters out non-numeric keys (e.g. ``__key__`` from tar metadata).
|
||||
Yields dicts where each value is a stacked array of shape ``(seq_len, ...)``.
|
||||
"""
|
||||
def apply(iterator):
|
||||
import numpy as np
|
||||
_NUMERIC_TYPES = (np.ndarray, np.floating, np.integer, float, int)
|
||||
buffer = []
|
||||
for item in iterator:
|
||||
buffer.append(item)
|
||||
while len(buffer) >= seq_len:
|
||||
seq = buffer[:seq_len]
|
||||
keys = [k for k in seq[0] if isinstance(seq[0][k], _NUMERIC_TYPES)]
|
||||
yield {k: np.stack([s[k] for s in seq]) for k in keys}
|
||||
buffer = buffer[stride:]
|
||||
return apply
|
||||
|
||||
|
||||
def _build_pipeline(
|
||||
urls: List[str],
|
||||
transform: Callable,
|
||||
seq_len: int,
|
||||
stride: int = 1,
|
||||
shuffle: int = 1000,
|
||||
deterministic: bool = False,
|
||||
):
|
||||
"""
|
||||
Build a WebDataset pipeline that:
|
||||
1. Reads tar shards
|
||||
2. Decodes and transforms individual samples
|
||||
3. Groups consecutive samples into sequences
|
||||
1. Reads tar shards (shard-level shuffle only)
|
||||
2. Decodes and transforms individual samples ***in temporal order***
|
||||
3. Groups consecutive samples into overlapping sequences via sliding window
|
||||
4. Shuffles at the ***sequence level***
|
||||
"""
|
||||
dataset = wds.WebDataset(urls, shardshuffle=shuffle if not deterministic else 0, empty_check=False)
|
||||
|
||||
if not deterministic:
|
||||
dataset = dataset.shuffle(shuffle)
|
||||
|
||||
# NO sample-level shuffle → preserves temporal order for SimulateEvents
|
||||
dataset = dataset.decode().map(transform)
|
||||
|
||||
# Group into sequences of seq_len consecutive frames
|
||||
dataset = dataset.batched(seq_len, partial=False)
|
||||
# Sliding window: overlapping sequences with configurable stride
|
||||
dataset = dataset.compose(_sliding_window_fn(seq_len, stride))
|
||||
|
||||
# Sequence-level shuffle (after grouping, so sequences are coherent)
|
||||
if not deterministic:
|
||||
dataset = dataset.shuffle(shuffle)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -65,12 +90,13 @@ def _build_pipeline(
|
||||
def create_train_loader(
|
||||
scene_names: Optional[List[str]] = None,
|
||||
seq_len: int = 8,
|
||||
stride: int = 1,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 4,
|
||||
event_threshold: float = 0.1,
|
||||
event_use_log: bool = True,
|
||||
):
|
||||
"""Create a DataLoader for training."""
|
||||
"""Create a DataLoader for training with sliding-window sampling."""
|
||||
if scene_names is None:
|
||||
from src.velocity_prediction.config import TRAIN_SCENES
|
||||
scene_names = TRAIN_SCENES
|
||||
@@ -80,13 +106,13 @@ def create_train_loader(
|
||||
event_threshold=event_threshold,
|
||||
event_use_log=event_use_log,
|
||||
)
|
||||
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=1000)
|
||||
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, stride=stride, shuffle=1000)
|
||||
|
||||
loader = wds.WebLoader(
|
||||
pipeline,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=False, # already shuffled in pipeline
|
||||
shuffle=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
@@ -94,12 +120,13 @@ def create_train_loader(
|
||||
def create_val_loader(
|
||||
scene_names: Optional[List[str]] = None,
|
||||
seq_len: int = 8,
|
||||
stride: int = 1,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 4,
|
||||
event_threshold: float = 0.1,
|
||||
event_use_log: bool = True,
|
||||
):
|
||||
"""Create a DataLoader for validation (deterministic order)."""
|
||||
"""Create a DataLoader for validation (deterministic order, sliding window)."""
|
||||
if scene_names is None:
|
||||
from src.velocity_prediction.config import VAL_SCENES
|
||||
scene_names = VAL_SCENES
|
||||
@@ -109,7 +136,7 @@ def create_val_loader(
|
||||
event_threshold=event_threshold,
|
||||
event_use_log=event_use_log,
|
||||
)
|
||||
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=0, deterministic=True)
|
||||
pipeline = _build_pipeline(urls, transform, seq_len=seq_len, stride=stride, shuffle=0, deterministic=True)
|
||||
|
||||
loader = wds.WebLoader(
|
||||
pipeline,
|
||||
|
||||
@@ -122,6 +122,7 @@ def main():
|
||||
# Data loaders
|
||||
train_loader = create_train_loader(
|
||||
seq_len=train_cfg.seq_len,
|
||||
stride=train_cfg.sliding_window_stride,
|
||||
batch_size=train_cfg.batch_size,
|
||||
num_workers=train_cfg.num_workers,
|
||||
event_threshold=train_cfg.event_threshold,
|
||||
@@ -129,6 +130,7 @@ def main():
|
||||
)
|
||||
val_loader = create_val_loader(
|
||||
seq_len=train_cfg.seq_len,
|
||||
stride=train_cfg.sliding_window_stride,
|
||||
batch_size=train_cfg.batch_size,
|
||||
num_workers=train_cfg.num_workers,
|
||||
event_threshold=train_cfg.event_threshold,
|
||||
|
||||
Reference in New Issue
Block a user