feat: activate CNN encoder, enable head near-zero init, disable NormalizeVelocity
- Activate CNNEncoder forward (replace zero placeholder with actual inference) - Enable near-zero weight init for head final layer (weight*=0.01, bias=0) - Disable NormalizeVelocity transform to train on raw velocity scale - (BatchNorm remains commented out) Generated by deepseek-v4-flash. Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
@@ -119,8 +119,10 @@ class VelocityPredictionModel(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# # Small init for the final layer: start from near-zero output
|
# # Small init for the final layer: start from near-zero output
|
||||||
# self.head[-1].weight.data.mul_(0.01)
|
self.head[-1].weight.data.mul_(0.01)
|
||||||
# self.head[-1].bias.data.zero_()
|
self.head[-1].bias.data.zero_()
|
||||||
|
# nn.init.uniform_(self.head[-1].weight, -0.001, 0.001)
|
||||||
|
# nn.init.zeros_(self.head[-1].bias)
|
||||||
|
|
||||||
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
|
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -132,9 +134,9 @@ class VelocityPredictionModel(nn.Module):
|
|||||||
v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep
|
v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep
|
||||||
"""
|
"""
|
||||||
# Per-frame encoding
|
# Per-frame encoding
|
||||||
# cnn_feat = self.cnn(events) # (B, S, 256)
|
cnn_feat = self.cnn(events) # (B, S, 256)
|
||||||
B, S = events.shape[:2]
|
# B, S = events.shape[:2]
|
||||||
cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代
|
# cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代
|
||||||
|
|
||||||
pose_feat = self.pose_mlp(tilt) # (B, S, 64)
|
pose_feat = self.pose_mlp(tilt) # (B, S, 64)
|
||||||
|
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ def build_train_transform(event_threshold=0.1, event_use_log=True):
|
|||||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||||
ComputeTilt(),
|
ComputeTilt(),
|
||||||
ComputeBodyVelocity(),
|
ComputeBodyVelocity(),
|
||||||
NormalizeVelocity(),
|
# NormalizeVelocity(),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
@@ -150,5 +150,5 @@ def build_val_transform(event_threshold=0.1, event_use_log=True):
|
|||||||
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
|
||||||
ComputeTilt(),
|
ComputeTilt(),
|
||||||
ComputeBodyVelocity(),
|
ComputeBodyVelocity(),
|
||||||
NormalizeVelocity(),
|
# NormalizeVelocity(),
|
||||||
])
|
])
|
||||||
|
|||||||
Reference in New Issue
Block a user