diff --git a/src/velocity_prediction/model.py b/src/velocity_prediction/model.py index dc294a3..42d784c 100644 --- a/src/velocity_prediction/model.py +++ b/src/velocity_prediction/model.py @@ -119,8 +119,10 @@ class VelocityPredictionModel(nn.Module): ) # # Small init for the final layer: start from near-zero output - # self.head[-1].weight.data.mul_(0.01) - # self.head[-1].bias.data.zero_() + self.head[-1].weight.data.mul_(0.01) + self.head[-1].bias.data.zero_() + # nn.init.uniform_(self.head[-1].weight, -0.001, 0.001) + # nn.init.zeros_(self.head[-1].bias) def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor: """ @@ -132,9 +134,9 @@ class VelocityPredictionModel(nn.Module): v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep """ # Per-frame encoding - # cnn_feat = self.cnn(events) # (B, S, 256) - B, S = events.shape[:2] - cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代 + cnn_feat = self.cnn(events) # (B, S, 256) + # B, S = events.shape[:2] + # cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代 pose_feat = self.pose_mlp(tilt) # (B, S, 64) diff --git a/src/velocity_prediction/transforms.py b/src/velocity_prediction/transforms.py index cdcc118..5f3aa39 100644 --- a/src/velocity_prediction/transforms.py +++ b/src/velocity_prediction/transforms.py @@ -139,7 +139,7 @@ def build_train_transform(event_threshold=0.1, event_use_log=True): SimulateEvents(threshold=event_threshold, use_log=event_use_log), ComputeTilt(), ComputeBodyVelocity(), - NormalizeVelocity(), + # NormalizeVelocity(), ]) @@ -150,5 +150,5 @@ def build_val_transform(event_threshold=0.1, event_use_log=True): SimulateEvents(threshold=event_threshold, use_log=event_use_log), ComputeTilt(), ComputeBodyVelocity(), - NormalizeVelocity(), + # NormalizeVelocity(), ])