From cb9936542e1aa46c98cceb04c0af45c02356c723 Mon Sep 17 00:00:00 2001 From: CaoWangrenbo Date: Fri, 5 Jun 2026 16:45:24 +0800 Subject: [PATCH] feat: replace non-overlapping windows with sliding-window sequence sampling - Remove sample-level shuffle before transforms (broke SimulateEvents) - Add _sliding_window_fn: yields overlapping sequences with configurable stride - Add sequence-level shuffle after grouping (preserves temporal coherence) - Add sliding_window_stride to TrainConfig (stride=1 for full overlap) - Update create_train/val_loader and train.py to pass stride - AGENTS.md: document known issues (cross-shard boundary, SimulateEvents state) - AGENTS.md: add cuda:7 device preference Generated by Mistral Vibe (deepseek-v4-flash). Co-Authored-By: Mistral Vibe --- AGENTS.md | 20 +++++++++++ src/velocity_prediction/config.py | 8 +++-- src/velocity_prediction/dataset.py | 57 ++++++++++++++++++++++-------- src/velocity_prediction/train.py | 2 ++ 4 files changed, 70 insertions(+), 17 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 2c864ff..48d2e07 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -169,3 +169,23 @@ uv run python -m benchmark.benchmark --checkpoint checkpoints/best.pt - 速度归一化统计量:待重新计算 - 模型预测 `[v_right, v_forward]`(右向和前向速度) - 所有代码在项目根目录下以 `uv run python -m ` 运行 +- GPU 优先使用 `cuda:7`,训练时添加 `--device cuda:7` + +## 已知问题 + +### 1. 滑窗跨 shard 边界 + +`dataset.py` 中滑窗实现基于 WebDataset 串联后的连续流,不感知 shard 边界。当样本恰好处于 shard 末尾时,序列会跨越到下一个 shard 的起始帧。 + +- 影响:每个 shard 边界处约有 `seq_len` 个序列包含跨 shard 样本(占总数 <1%) +- 修复思路:在 `_sliding_window_fn` 中注入 shard 边界标记,遇到边界时清空缓冲区 +- 严重程度:低。若 shard 内帧数远大于 seq_len,可忽略 + +### 2. `SimulateEvents` 跨 shard 状态残留 + +`EventProcessor` 内部维护 `_prev_frame` 用于帧差计算。跨 shard 时,新 shard 的第一帧会与上一个 shard 最后一帧计算差,产生错误的事件帧。 + +- 影响:每个 shard 的第 1 帧事件帧错误,涉及该帧的所有滑窗序列均受影响 +- 每 shard 错误帧数:1 帧(加上滑窗放大,约 `seq_len` 个序列各包含此帧) +- 修复:在 shard 边界处调用 `EventProcessor.reset()`。需在 `_build_pipeline` 中插入边界信号或改用按 shard 独立处理的方案 +- 严重程度:低。每 shard 仅 1 帧,训练数据量大时可忽略 diff --git a/src/velocity_prediction/config.py b/src/velocity_prediction/config.py index 8e1c128..2ba2000 100644 --- a/src/velocity_prediction/config.py +++ b/src/velocity_prediction/config.py @@ -39,10 +39,11 @@ VAL_SCENES = [ # "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy ] TEST_SCENES = [ + "indoor_forward_9","indoor_forward_3", # "indoor_forward_7", # Hard 室内 # "outdoor_forward_1", # Easy 室外 # "outdoor_forward_5" # Hard 室外 - "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy + # "indoor_forward_3", "indoor_forward_9", "indoor_forward_10", # Easy ] @@ -93,7 +94,7 @@ class ModelConfig: class TrainConfig: seq_len: int = 8 # frames per training sequence batch_size: int = 32 - epochs: int = 100 + epochs: int = 300 lr: float = 1e-3 weight_decay: float = 1e-5 lr_scheduler_step: int = 30 @@ -101,6 +102,9 @@ class TrainConfig: num_workers: int = 4 seed: int = 42 + # Sliding window: stride=1 → full overlap, stride=seq_len → non-overlapping + sliding_window_stride: int = 1 + # Event simulation event_threshold: float = 0.1 event_use_log: bool = True diff --git a/src/velocity_prediction/dataset.py b/src/velocity_prediction/dataset.py index 88325c8..364a189 100644 --- a/src/velocity_prediction/dataset.py +++ b/src/velocity_prediction/dataset.py @@ -1,5 +1,5 @@ """ -WebDataset-based dataset with sequence sampling. +WebDataset-based dataset with sliding-window sequence sampling. Each sample from the dataset is a dict: { @@ -8,7 +8,8 @@ Each sample from the dataset is a dict: "v_body_target": np.ndarray (2,), # body-frame [vx, vy] } -The dataset groups consecutive frames into sequences of length seq_len. +The dataset groups consecutive frames into overlapping sequences via a +sliding window (stride=1 by default for full overlap). """ import webdataset as wds @@ -36,28 +37,52 @@ def _scene_urls(scene_names: List[str], root: Path = DATASET_ROOT) -> List[str]: return urls +def _sliding_window_fn(seq_len: int, stride: int = 1): + """Convert a stream of individual sample dicts into overlapping sequences. + + Filters out non-numeric keys (e.g. ``__key__`` from tar metadata). + Yields dicts where each value is a stacked array of shape ``(seq_len, ...)``. + """ + def apply(iterator): + import numpy as np + _NUMERIC_TYPES = (np.ndarray, np.floating, np.integer, float, int) + buffer = [] + for item in iterator: + buffer.append(item) + while len(buffer) >= seq_len: + seq = buffer[:seq_len] + keys = [k for k in seq[0] if isinstance(seq[0][k], _NUMERIC_TYPES)] + yield {k: np.stack([s[k] for s in seq]) for k in keys} + buffer = buffer[stride:] + return apply + + def _build_pipeline( urls: List[str], transform: Callable, seq_len: int, + stride: int = 1, shuffle: int = 1000, deterministic: bool = False, ): """ Build a WebDataset pipeline that: - 1. Reads tar shards - 2. Decodes and transforms individual samples - 3. Groups consecutive samples into sequences + 1. Reads tar shards (shard-level shuffle only) + 2. Decodes and transforms individual samples ***in temporal order*** + 3. Groups consecutive samples into overlapping sequences via sliding window + 4. Shuffles at the ***sequence level*** """ dataset = wds.WebDataset(urls, shardshuffle=shuffle if not deterministic else 0, empty_check=False) - if not deterministic: - dataset = dataset.shuffle(shuffle) - + # NO sample-level shuffle → preserves temporal order for SimulateEvents dataset = dataset.decode().map(transform) - # Group into sequences of seq_len consecutive frames - dataset = dataset.batched(seq_len, partial=False) + # Sliding window: overlapping sequences with configurable stride + dataset = dataset.compose(_sliding_window_fn(seq_len, stride)) + + # Sequence-level shuffle (after grouping, so sequences are coherent) + if not deterministic: + dataset = dataset.shuffle(shuffle) return dataset @@ -65,12 +90,13 @@ def _build_pipeline( def create_train_loader( scene_names: Optional[List[str]] = None, seq_len: int = 8, + stride: int = 1, batch_size: int = 32, num_workers: int = 4, event_threshold: float = 0.1, event_use_log: bool = True, ): - """Create a DataLoader for training.""" + """Create a DataLoader for training with sliding-window sampling.""" if scene_names is None: from src.velocity_prediction.config import TRAIN_SCENES scene_names = TRAIN_SCENES @@ -80,13 +106,13 @@ def create_train_loader( event_threshold=event_threshold, event_use_log=event_use_log, ) - pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=1000) + pipeline = _build_pipeline(urls, transform, seq_len=seq_len, stride=stride, shuffle=1000) loader = wds.WebLoader( pipeline, batch_size=batch_size, num_workers=num_workers, - shuffle=False, # already shuffled in pipeline + shuffle=False, ) return loader @@ -94,12 +120,13 @@ def create_train_loader( def create_val_loader( scene_names: Optional[List[str]] = None, seq_len: int = 8, + stride: int = 1, batch_size: int = 32, num_workers: int = 4, event_threshold: float = 0.1, event_use_log: bool = True, ): - """Create a DataLoader for validation (deterministic order).""" + """Create a DataLoader for validation (deterministic order, sliding window).""" if scene_names is None: from src.velocity_prediction.config import VAL_SCENES scene_names = VAL_SCENES @@ -109,7 +136,7 @@ def create_val_loader( event_threshold=event_threshold, event_use_log=event_use_log, ) - pipeline = _build_pipeline(urls, transform, seq_len=seq_len, shuffle=0, deterministic=True) + pipeline = _build_pipeline(urls, transform, seq_len=seq_len, stride=stride, shuffle=0, deterministic=True) loader = wds.WebLoader( pipeline, diff --git a/src/velocity_prediction/train.py b/src/velocity_prediction/train.py index 34b88cb..a0f4ba2 100644 --- a/src/velocity_prediction/train.py +++ b/src/velocity_prediction/train.py @@ -122,6 +122,7 @@ def main(): # Data loaders train_loader = create_train_loader( seq_len=train_cfg.seq_len, + stride=train_cfg.sliding_window_stride, batch_size=train_cfg.batch_size, num_workers=train_cfg.num_workers, event_threshold=train_cfg.event_threshold, @@ -129,6 +130,7 @@ def main(): ) val_loader = create_val_loader( seq_len=train_cfg.seq_len, + stride=train_cfg.sliding_window_stride, batch_size=train_cfg.batch_size, num_workers=train_cfg.num_workers, event_threshold=train_cfg.event_threshold,