diff --git a/src/velocity_prediction/model.py b/src/velocity_prediction/model.py index 42d784c..196cc00 100644 --- a/src/velocity_prediction/model.py +++ b/src/velocity_prediction/model.py @@ -29,7 +29,7 @@ class CNNEncoder(nn.Module): for out_ch in channels: layers.extend([ nn.Conv2d(in_ch, out_ch, kernel_size=cfg.kernel_size, padding=cfg.kernel_size // 2), - # nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(), + nn.BatchNorm2d(out_ch) if cfg.use_bn else nn.Identity(), nn.Identity(), nn.LeakyReLU(inplace=True), nn.MaxPool2d(cfg.pool_size), diff --git a/src/velocity_prediction/train.py b/src/velocity_prediction/train.py index a0f4ba2..62b6b93 100644 --- a/src/velocity_prediction/train.py +++ b/src/velocity_prediction/train.py @@ -32,11 +32,13 @@ def train_one_epoch( loader, optimizer: torch.optim.Optimizer, criterion: nn.Module, + scaler: torch.cuda.amp.GradScaler, device: torch.device, epoch: int, writer: SummaryWriter, log_interval: int = 50, global_step: int = 0, + use_amp: bool = True, ) -> tuple[float, int]: """Train for one epoch. Returns (avg_loss, updated_global_step).""" model.train() @@ -50,14 +52,15 @@ def train_one_epoch( target = batch["v_body_target"].to(device) # (B, S, 2) # Predict velocity for the last frame in the sequence - pred = model(events, tilt) # (B, 2) - target_last = target[:, -1, :] # (B, 2) - - loss = criterion(pred, target_last) + with torch.amp.autocast(device.type, enabled=use_amp): + pred = model(events, tilt) # (B, 2) + target_last = target[:, -1, :] # (B, 2) + loss = criterion(pred, target_last) optimizer.zero_grad() - loss.backward() - optimizer.step() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() total_loss += loss.item() num_batches += 1 @@ -79,6 +82,7 @@ def validate( loader, criterion: nn.Module, device: torch.device, + use_amp: bool = True, ) -> float: """Validate. Returns average loss.""" model.eval() @@ -90,10 +94,11 @@ def validate( tilt = batch["tilt"].to(device) target = batch["v_body_target"].to(device) - pred = model(events, tilt) - target_last = target[:, -1, :] + with torch.amp.autocast(device.type, enabled=use_amp): + pred = model(events, tilt) + target_last = target[:, -1, :] + loss = criterion(pred, target_last) - loss = criterion(pred, target_last) total_loss += loss.item() num_batches += 1 @@ -107,7 +112,10 @@ def main(): help="CUDA device, e.g. 'cuda:0', 'cuda:1' (default: 'cuda')") parser.add_argument("--resume", type=str, default=None, help="Path to checkpoint .pt file to resume training from") + parser.add_argument("--amp", action=argparse.BooleanOptionalAction, default=True, + help="Enable Automatic Mixed Precision (default: True)") args = parser.parse_args() + use_amp = args.amp set_seed(train_cfg.seed) device = torch.device(args.device if torch.cuda.is_available() and "cuda" in args.device else "cpu") @@ -116,8 +124,10 @@ def main(): # Create model model = VelocityPredictionModel() model.to(device) + scaler = torch.amp.GradScaler(device.type, enabled=use_amp) total_params = count_parameters(model) print(f"Model parameters: {total_params:,} ({total_params/1e6:.3f} M)") + print(f"AMP: {'enabled' if use_amp else 'disabled'}") # Data loaders train_loader = create_train_loader( @@ -196,11 +206,12 @@ def main(): epoch_start = time.time() train_loss, global_step = train_one_epoch( - model, train_loader, optimizer, criterion, device, epoch, writer, + model, train_loader, optimizer, criterion, scaler, device, epoch, writer, log_interval=train_cfg.log_interval, global_step=global_step, + use_amp=use_amp, ) - val_loss = validate(model, val_loader, criterion, device) + val_loss = validate(model, val_loader, criterion, device, use_amp=use_amp) scheduler.step() epoch_time = time.time() - epoch_start