feat: activate CNN encoder, enable head near-zero init, disable NormalizeVelocity
- Activate CNNEncoder forward (replace zero placeholder with actual inference) - Enable near-zero weight init for head final layer (weight*=0.01, bias=0) - Disable NormalizeVelocity transform to train on raw velocity scale - (BatchNorm remains commented out) Generated by deepseek-v4-flash. Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
@@ -119,8 +119,10 @@ class VelocityPredictionModel(nn.Module):
|
||||
)
|
||||
|
||||
# # Small init for the final layer: start from near-zero output
|
||||
# self.head[-1].weight.data.mul_(0.01)
|
||||
# self.head[-1].bias.data.zero_()
|
||||
self.head[-1].weight.data.mul_(0.01)
|
||||
self.head[-1].bias.data.zero_()
|
||||
# nn.init.uniform_(self.head[-1].weight, -0.001, 0.001)
|
||||
# nn.init.zeros_(self.head[-1].bias)
|
||||
|
||||
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
@@ -132,9 +134,9 @@ class VelocityPredictionModel(nn.Module):
|
||||
v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep
|
||||
"""
|
||||
# Per-frame encoding
|
||||
# cnn_feat = self.cnn(events) # (B, S, 256)
|
||||
B, S = events.shape[:2]
|
||||
cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代
|
||||
cnn_feat = self.cnn(events) # (B, S, 256)
|
||||
# B, S = events.shape[:2]
|
||||
# cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代
|
||||
|
||||
pose_feat = self.pose_mlp(tilt) # (B, S, 64)
|
||||
|
||||
|
||||
@@ -139,7 +139,7 @@ def build_train_transform(event_threshold=0.1, event_use_log=True):
|
||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||
ComputeTilt(),
|
||||
ComputeBodyVelocity(),
|
||||
NormalizeVelocity(),
|
||||
# NormalizeVelocity(),
|
||||
])
|
||||
|
||||
|
||||
@@ -150,5 +150,5 @@ def build_val_transform(event_threshold=0.1, event_use_log=True):
|
||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||
ComputeTilt(),
|
||||
ComputeBodyVelocity(),
|
||||
NormalizeVelocity(),
|
||||
# NormalizeVelocity(),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user