feat: enable BatchNorm2d in CNNEncoder and add AMP support

- Uncomment BatchNorm2d in CNNEncoder (activated when cfg.use_bn=True)
- Add torch.amp.GradScaler + autocast for mixed-precision training (autocast is also applied during validation)
- Add --amp/--no-amp CLI argument (default: enabled)

Generated by Mistral Vibe (deepseek-v4-flash).
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
This commit is contained in:
2026-06-08 00:41:34 +08:00
parent b5abbc239d
commit 1369edaad7
2 changed files with 23 additions and 12 deletions

View File

@@ -29,7 +29,7 @@ class CNNEncoder(nn.Module):
for out_ch in channels: for out_ch in channels:
layers.extend([ layers.extend([
nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2), nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2),
# nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(), nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(),
nn.Identity(), nn.Identity(),
nn.LeakyReLU(inplace=True), nn.LeakyReLU(inplace=True),
nn.MaxPool2d(cfg.pool_size), nn.MaxPool2d(cfg.pool_size),

View File

@@ -32,11 +32,13 @@ def train_one_epoch(
loader, loader,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
criterion: nn.Module, criterion: nn.Module,
scaler: torch.cuda.amp.GradScaler,
device: torch.device, device: torch.device,
epoch: int, epoch: int,
writer: SummaryWriter, writer: SummaryWriter,
log_interval: int = 50, log_interval: int = 50,
global_step: int = 0, global_step: int = 0,
use_amp: bool = True,
) -> tuple[float, int]: ) -> tuple[float, int]:
"""Train for one epoch. Returns (avg_loss, updated_global_step).""" """Train for one epoch. Returns (avg_loss, updated_global_step)."""
model.train() model.train()
@@ -50,14 +52,15 @@ def train_one_epoch(
target = batch["v_body_target"].to(device) # (B, S, 2) target = batch["v_body_target"].to(device) # (B, S, 2)
# Predict velocity for the last frame in the sequence # Predict velocity for the last frame in the sequence
with torch.amp.autocast(device.type, enabled=use_amp):
pred = model(events, tilt) # (B, 2) pred = model(events, tilt) # (B, 2)
target_last = target[:, -1, :] # (B, 2) target_last = target[:, -1, :] # (B, 2)
loss = criterion(pred, target_last) loss = criterion(pred, target_last)
optimizer.zero_grad() optimizer.zero_grad()
loss.backward() scaler.scale(loss).backward()
optimizer.step() scaler.step(optimizer)
scaler.update()
total_loss += loss.item() total_loss += loss.item()
num_batches += 1 num_batches += 1
@@ -79,6 +82,7 @@ def validate(
loader, loader,
criterion: nn.Module, criterion: nn.Module,
device: torch.device, device: torch.device,
use_amp: bool = True,
) -> float: ) -> float:
"""Validate. Returns average loss.""" """Validate. Returns average loss."""
model.eval() model.eval()
@@ -90,10 +94,11 @@ def validate(
tilt = batch["tilt"].to(device) tilt = batch["tilt"].to(device)
target = batch["v_body_target"].to(device) target = batch["v_body_target"].to(device)
with torch.amp.autocast(device.type, enabled=use_amp):
pred = model(events, tilt) pred = model(events, tilt)
target_last = target[:, -1, :] target_last = target[:, -1, :]
loss = criterion(pred, target_last) loss = criterion(pred, target_last)
total_loss += loss.item() total_loss += loss.item()
num_batches += 1 num_batches += 1
@@ -107,7 +112,10 @@ def main():
help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')") help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')")
parser.add_argument("--resume", type=str, default=None, parser.add_argument("--resume", type=str, default=None,
help="Path to checkpoint .pt file to resume training from") help="Path to checkpoint .pt file to resume training from")
parser.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True,
help="Enable Automatic Mixed Precision (default: True)")
args = parser.parse_args() args = parser.parse_args()
use_amp = args.amp
set_seed(train_cfg.seed) set_seed(train_cfg.seed)
device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu") device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu")
@@ -116,8 +124,10 @@ def main():
# Create model # Create model
model = VelocityPredictionModel() model = VelocityPredictionModel()
model.to(device) model.to(device)
scaler = torch.amp.GradScaler(device.type, enabled=use_amp)
total_params = count_parameters(model) total_params = count_parameters(model)
print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)") print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)")
print(f"AMP: {'enabled' if use_amp else 'disabled'}")
# Data loaders # Data loaders
train_loader = create_train_loader( train_loader = create_train_loader(
@@ -196,11 +206,12 @@ def main():
epoch_start = time.time() epoch_start = time.time()
train_loss, global_step = train_one_epoch( train_loss, global_step = train_one_epoch(
model, train_loader, optimizer, criterion, device, epoch, writer, model, train_loader, optimizer, criterion, scaler, device, epoch, writer,
log_interval=train_cfg.log_interval, log_interval=train_cfg.log_interval,
global_step=global_step, global_step=global_step,
use_amp=use_amp,
) )
val_loss = validate(model, val_loader, criterion, device) val_loss = validate(model, val_loader, criterion, device, use_amp=use_amp)
scheduler.step() scheduler.step()
epoch_time = time.time() - epoch_start epoch_time = time.time() - epoch_start