From b5abbc239d1c383be7bebc99dfbe2e3cdfab2cc2 Mon Sep 17 00:00:00 2001 From: CaoWangrenbo Date: Sat, 6 Jun 2026 14:04:40 +0800 Subject: [PATCH] feat: activate CNN encoder, enable head near-zero init, disable NormalizeVelocity - Activate CNNEncoder forward (replace zero placeholder with actual inference) - Enable near-zero weight init for head final layer (weight*=0.01, bias=0) - Disable NormalizeVelocity transform to train on raw velocity scale - (BatchNorm remains commented out) Generated by deepseek-v4-flash. Co-Authored-By: Mistral Vibe --- src/velocity_prediction/model.py | 12 +++++++----- src/velocity_prediction/transforms.py | 4 ++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/velocity_prediction/model.py b/src/velocity_prediction/model.py index dc294a3..42d784c 100644 --- a/src/velocity_prediction/model.py +++ b/src/velocity_prediction/model.py @@ -119,8 +119,10 @@ class VelocityPredictionModel(nn.Module): ) # # Small init for the final layer: start from near-zero output - # self.head[-1].weight.data.mul_(0.01) - # self.head[-1].bias.data.zero_() + self.head[-1].weight.data.mul_(0.01) + self.head[-1].bias.data.zero_() + # nn.init.uniform_(self.head[-1].weight, -0.001, 0.001) + # nn.init.zeros_(self.head[-1].bias) def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor: """ @@ -132,9 +134,9 @@ class VelocityPredictionModel(nn.Module): v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep """ # Per-frame encoding - # cnn_feat = self.cnn(events) # (B, S, 256) - B, S = events.shape[:2] - cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代 + cnn_feat = self.cnn(events) # (B, S, 256) + # B, S = events.shape[:2] + # cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代 pose_feat = self.pose_mlp(tilt) # (B, S, 64) diff --git a/src/velocity_prediction/transforms.py b/src/velocity_prediction/transforms.py index cdcc118..5f3aa39 100644 --- a/src/velocity_prediction/transforms.py +++ b/src/velocity_prediction/transforms.py @@ -139,7 +139,7 @@ def build_train_transform(event_threshold=0.1, event_use_log=True): SimulateEvents(threshold=event_threshold, use_log=event_use_log), ComputeTilt(), ComputeBodyVelocity(), - NormalizeVelocity(), + # NormalizeVelocity(), ]) @@ -150,5 +150,5 @@ def build_val_transform(event_threshold=0.1, event_use_log=True): SimulateEvents(threshold=event_threshold, use_log=event_use_log), ComputeTilt(), ComputeBodyVelocity(), - NormalizeVelocity(), + # NormalizeVelocity(), ])