Compare commits

...

3 Commits

Author SHA1 Message Date
1369edaad7 feat: enable BatchNorm2d in CNNEncoder and add AMP support
- Uncomment BatchNorm2d in CNNEncoder (activated when cfg.use_bn=True)
- Add torch.amp.GradScaler + autocast for mixed precision training
- Add --amp/--no-amp CLI argument (default: enabled)

Generated by Mistral Vibe (deepseek-v4-flash).
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
2026-06-08 00:41:34 +08:00
b5abbc239d feat: activate CNN encoder, enable head near-zero init, disable NormalizeVelocity
- Activate CNNEncoder forward (replace zero placeholder with actual inference)
- Enable near-zero weight init for head final layer (weight*=0.01, bias=0)
- Disable NormalizeVelocity transform to train on raw velocity scale
- (BatchNorm remains commented out)

Generated by deepseek-v4-flash.
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
2026-06-06 14:04:40 +08:00
e7e773a48f fix: evaluate each scene independently to avoid plot mixing
Multi-scene evaluation previously concatenated all scenes into one
continuous trace, causing scene boundary jumps to appear as glitches
in plots. Each scene is now evaluated separately, and NaN separators
are inserted between scenes when concatenating results for plotting.

Generated by Mistral Vibe (deepseek-v4-flash).
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
2026-06-05 16:47:42 +08:00
4 changed files with 64 additions and 50 deletions

View File

@@ -170,42 +170,43 @@ def main():
model.to(device)
print(f"Loaded checkpoint from {args.checkpoint} (epoch={ckpt.get('epoch', '?')})")
# Validation loader (use test scenes for final eval)
# Evaluate each scene independently → NaN gaps prevent plot mixing
from src.velocity_prediction.config import TEST_SCENES
loader = create_val_loader(
scene_names=TEST_SCENES,
seq_len=train_cfg.seq_len,
batch_size=train_cfg.batch_size,
num_workers=2,
event_threshold=train_cfg.event_threshold,
event_use_log=train_cfg.event_use_log,
)
all_preds, all_targets = [], []
scene_rmses = []
# # ── Quick event diagnostics: inspect one batch ───────────────
# print("\n========== Event Frame Diagnostics ==========")
# sample_batch = next(iter(loader))
# ev = sample_batch["events"] # (B, S, 1, H, W)
# print(f"Events shape: {ev.shape}")
# print(f"Events dtype: {ev.dtype}")
# print(f"Events value counts: -1: {(ev == -1).sum().item()}, "
# f"0: {(ev == 0).sum().item()}, +1: {(ev == 1).sum().item()}")
# total_el = ev.numel()
# nonzero = (ev != 0).sum().item()
# print(f"Non-zero ratio: {nonzero / total_el:.6f} ({nonzero}/{total_el})")
# print(f"Per-sample non-zero: {[(ev[b] != 0).sum().item() for b in range(min(4, ev.shape[0]))]}")
# print("=============================================\n")
for scene in TEST_SCENES:
loader = create_val_loader(
scene_names=[scene],
seq_len=train_cfg.seq_len,
batch_size=train_cfg.batch_size,
num_workers=2,
event_threshold=train_cfg.event_threshold,
event_use_log=train_cfg.event_use_log,
)
results = evaluate(model, loader, device)
n = len(results["preds"])
print(f" [{scene}] RMSE vx={results['rmse_x']:.4f} vy={results['rmse_y']:.4f} "
f"xy={results['rmse_xy']:.4f} samples={n}")
scene_rmses.append(results["rmse_xy"])
# Evaluate
results = evaluate(model, loader, device)
print(f"\nEvaluation results on test scenes: {TEST_SCENES}")
print(f" RMSE vx: {results['rmse_x']:.4f} m/s")
print(f" RMSE vy: {results['rmse_y']:.4f} m/s")
print(f" RMSE xy: {results['rmse_xy']:.4f} m/s")
all_preds.append(results["preds"])
all_targets.append(results["targets"])
# NaN separator → plot won't connect discontinuous scenes
sep = np.full((1, 2), np.nan, dtype=np.float32)
all_preds.append(sep)
all_targets.append(sep)
# Plots
# Overall RMSE = mean across scenes (unweighted, avoids scene size bias)
rmse_xy = np.mean(scene_rmses)
print(f"\nOverall ({len(TEST_SCENES)} scenes, mean across scenes): RMSE xy={rmse_xy:.4f} m/s")
# Plots (with NaN gaps between scenes)
if args.plot:
plot_results(results["preds"], results["targets"], "eval_velocity.png")
plot_scatter(results["preds"], results["targets"], "eval_scatter.png")
preds_cat = np.concatenate(all_preds, axis=0)
targets_cat = np.concatenate(all_targets, axis=0)
plot_results(preds_cat, targets_cat, "eval_velocity.png")
plot_scatter(preds_cat, targets_cat, "eval_scatter.png")
if __name__ == "__main__":

View File

@@ -29,7 +29,7 @@ class CNNEncoder(nn.Module):
for out_ch in channels:
layers.extend([
nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2),
# nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(),
nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(),
nn.Identity(),
nn.LeakyReLU(inplace=True),
nn.MaxPool2d(cfg.pool_size),
@@ -119,8 +119,10 @@ class VelocityPredictionModel(nn.Module):
)
# # Small init for the final layer: start from near-zero output
# self.head[-1].weight.data.mul_(0.01)
# self.head[-1].bias.data.zero_()
self.head[-1].weight.data.mul_(0.01)
self.head[-1].bias.data.zero_()
# nn.init.uniform_(self.head[-1].weight, -0.001, 0.001)
# nn.init.zeros_(self.head[-1].bias)
def forward(self, events: torch.Tensor, tilt: torch.Tensor) -> torch.Tensor:
"""
@@ -132,9 +134,9 @@ class VelocityPredictionModel(nn.Module):
v_body: (B, 2) predicted body-frame [v_forward, v_lateral] at the last timestep
"""
# Per-frame encoding
# cnn_feat = self.cnn(events) # (B, S, 256)
B, S = events.shape[:2]
cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代
cnn_feat = self.cnn(events) # (B, S, 256)
# B, S = events.shape[:2]
# cnn_feat = events.new_zeros(B, S, self.cnn.out_dim) # 全零替代
pose_feat = self.pose_mlp(tilt) # (B, S, 64)

View File

@@ -32,11 +32,13 @@ def train_one_epoch(
loader,
optimizer: torch.optim.Optimizer,
criterion: nn.Module,
scaler: torch.cuda.amp.GradScaler,
device: torch.device,
epoch: int,
writer: SummaryWriter,
log_interval: int = 50,
global_step: int = 0,
use_amp: bool = True,
) -> tuple[float, int]:
"""Train for one epoch. Returns (avg_loss, updated_global_step)."""
model.train()
@@ -50,14 +52,15 @@ def train_one_epoch(
target = batch["v_body_target"].to(device) # (B, S, 2)
# Predict velocity for the last frame in the sequence
pred = model(events, tilt) # (B, 2)
target_last = target[:, -1, :] # (B, 2)
loss = criterion(pred, target_last)
with torch.amp.autocast(device.type, enabled=use_amp):
pred = model(events, tilt) # (B, 2)
target_last = target[:, -1, :] # (B, 2)
loss = criterion(pred, target_last)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
total_loss += loss.item()
num_batches += 1
@@ -79,6 +82,7 @@ def validate(
loader,
criterion: nn.Module,
device: torch.device,
use_amp: bool = True,
) -> float:
"""Validate. Returns average loss."""
model.eval()
@@ -90,10 +94,11 @@ def validate(
tilt = batch["tilt"].to(device)
target = batch["v_body_target"].to(device)
pred = model(events, tilt)
target_last = target[:, -1, :]
with torch.amp.autocast(device.type, enabled=use_amp):
pred = model(events, tilt)
target_last = target[:, -1, :]
loss = criterion(pred, target_last)
loss = criterion(pred, target_last)
total_loss += loss.item()
num_batches += 1
@@ -107,7 +112,10 @@ def main():
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
parser.add_argument("--resume", type=str, default=None,
help="Path to checkpoint .pt file to resume training from")
parser.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True,
help="Enable Automatic Mixed Precision (default: True)")
args = parser.parse_args()
use_amp = args.amp
set_seed(train_cfg.seed)
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
@@ -116,8 +124,10 @@ def main():
# Create model
model = VelocityPredictionModel()
model.to(device)
scaler = torch.amp.GradScaler(device.type, enabled=use_amp)
total_params = count_parameters(model)
print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")
print(f"AMP: {'enabled' if use_amp else 'disabled'}")
# Data loaders
train_loader = create_train_loader(
@@ -196,11 +206,12 @@ def main():
epoch_start = time.time()
train_loss, global_step = train_one_epoch(
model, train_loader, optimizer, criterion, device, epoch, writer,
model, train_loader, optimizer, criterion, scaler, device, epoch, writer,
log_interval=train_cfg.log_interval,
global_step=global_step,
use_amp=use_amp,
)
val_loss = validate(model, val_loader, criterion, device)
val_loss = validate(model, val_loader, criterion, device, use_amp=use_amp)
scheduler.step()
epoch_time = time.time() - epoch_start

View File

@@ -139,7 +139,7 @@ def build_train_transform(event_threshold=0.1, event_use_log=True):
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
ComputeTilt(),
ComputeBodyVelocity(),
NormalizeVelocity(),
# NormalizeVelocity(),
])
@@ -150,5 +150,5 @@ def build_val_transform(event_threshold=0.1, event_use_log=True):
SimulateEvents(threshold=event_threshold, use_log=event_use_log),
ComputeTilt(),
ComputeBodyVelocity(),
NormalizeVelocity(),
# NormalizeVelocity(),
])