Clean up code, remove the skip-connection parts

2026-01-11 13:25:34 +08:00
parent c5502cc87c
commit df703638da
3 changed files with 68 additions and 268 deletions


@@ -49,11 +49,6 @@ def get_args_parser():
     # Model parameters
     parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
                         help='Name of model to train')
-    parser.add_argument('--use-representation-head', action='store_true',
-                        help='Use representation head for pose/velocity prediction')
-    parser.add_argument('--representation-dim', default=128, type=int,
-                        help='Dimension of representation vector')
-    parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections')
     # Training parameters
     parser.add_argument('--batch-size', default=32, type=int)
@@ -207,9 +202,6 @@ def main(args):
     print(f"Creating model: {args.model}")
     model_kwargs = {
         'num_frames': args.num_frames,
-        'use_representation_head': args.use_representation_head,
-        'representation_dim': args.representation_dim,
-        'use_skip': args.use_skip,
     }
     if args.model == 'SwiftFormerTemporal_XS':
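After this hunk, model construction in main() no longer forwards any representation-head or skip-connection options. A minimal sketch of the resulting call path, assuming SwiftFormerTemporal_XS still takes num_frames as its only extra keyword (the rest of main() is not shown in this diff):

# sketch of the simplified model creation; device handling is assumed, not shown in the hunk
model_kwargs = {
    'num_frames': args.num_frames,  # only the temporal length remains configurable here
}
if args.model == 'SwiftFormerTemporal_XS':
    model = SwiftFormerTemporal_XS(**model_kwargs)
model.to(device)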
@@ -258,7 +250,7 @@
         super().__init__()
         self.mse = nn.MSELoss()
-    def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
+    def forward(self, pred_frame, target_frame, temporal_indices=None):
         loss = self.mse(pred_frame, target_frame)
         loss_dict = {'mse': loss}
         return loss, loss_dict
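With the representations argument gone, the criterion reduces to a plain per-frame MSE. A self-contained sketch of the simplified module and its use (the class name FrameLoss is assumed for illustration; the hunk header does not show it):

import torch.nn as nn

class FrameLoss(nn.Module):  # name assumed; only the method bodies above come from the diff
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, pred_frame, target_frame, temporal_indices=None):
        # temporal_indices stays in the signature for call-site compatibility but is unused here
        loss = self.mse(pred_frame, target_frame)
        loss_dict = {'mse': loss}
        return loss, loss_dict

criterion = FrameLoss()
# usage: loss, loss_dict = criterion(pred_frames, target_frames, temporal_indices)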
@@ -386,10 +378,10 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         # Forward pass
         with torch.amp.autocast(device_type='cuda'):
-            pred_frames, representations = model(input_frames)
+            pred_frames = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
-                representations, temporal_indices
+                temporal_indices
             )
         loss_value = loss.item()
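The training step now treats the model output as a single predicted-frame tensor. A minimal sketch of one AMP iteration built around the lines above; the GradScaler/optimizer handling and the batch layout are assumptions about code this hunk does not show:

scaler = torch.amp.GradScaler('cuda')  # assumed; not visible in this hunk

for input_frames, target_frames, temporal_indices in data_loader:  # batch layout assumed
    input_frames = input_frames.to(device, non_blocking=True)
    target_frames = target_frames.to(device, non_blocking=True)

    with torch.amp.autocast(device_type='cuda'):
        pred_frames = model(input_frames)  # single tensor now, no representations
        loss, loss_dict = criterion(pred_frames, target_frames, temporal_indices)

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    loss_value = loss.item()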
@@ -452,7 +444,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
             with torch.no_grad():
                 # Take first sample from batch for visualization
-                pred_vis, _ = model(input_frames[:1])
+                pred_vis = model(input_frames[:1])
                 # Convert to appropriate format for TensorBoard
                 # Assuming frames are in [B, C, H, W] format
                 writer.add_images('train/input', input_frames[:1], global_step)
@@ -497,10 +489,10 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
         # Compute output
         with torch.amp.autocast(device_type='cuda'):
-            pred_frames, representations = model(input_frames)
+            pred_frames = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
-                representations, temporal_indices
+                temporal_indices
             )
         # Compute diagnostic metrics
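The evaluation loop mirrors the training change: a single tensor comes back from the model and feeds the criterion directly. The diagnostic metrics themselves are not visible in this hunk; as a hypothetical illustration only (not necessarily what the code computes), PSNR can be derived from the per-batch MSE already carried in loss_dict:

import math
import torch

@torch.no_grad()
def psnr_from_mse(mse: torch.Tensor, max_val: float = 1.0) -> float:
    # PSNR for frames normalized to [0, max_val]; illustrative helper, not from the repo
    return 10.0 * math.log10((max_val ** 2) / max(mse.item(), 1e-12))

# e.g. psnr = psnr_from_mse(loss_dict['mse'])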
@@ -555,4 +547,4 @@ if __name__ == '__main__':
     if args.output_dir:
         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
-    main(args)
+    main(args)