清理代码,删除跳连接部分
This commit is contained in:
@@ -49,11 +49,6 @@ def get_args_parser():
|
||||
# Model parameters
|
||||
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
||||
help='Name of model to train')
|
||||
parser.add_argument('--use-representation-head', action='store_true',
|
||||
help='Use representation head for pose/velocity prediction')
|
||||
parser.add_argument('--representation-dim', default=128, type=int,
|
||||
help='Dimension of representation vector')
|
||||
parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections')
|
||||
|
||||
# Training parameters
|
||||
parser.add_argument('--batch-size', default=32, type=int)
|
||||
@@ -207,9 +202,6 @@ def main(args):
|
||||
print(f"Creating model: {args.model}")
|
||||
model_kwargs = {
|
||||
'num_frames': args.num_frames,
|
||||
'use_representation_head': args.use_representation_head,
|
||||
'representation_dim': args.representation_dim,
|
||||
'use_skip': args.use_skip,
|
||||
}
|
||||
|
||||
if args.model == 'SwiftFormerTemporal_XS':
|
||||
@@ -258,7 +250,7 @@ def main(args):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
|
||||
def forward(self, pred_frame, target_frame, temporal_indices=None):
|
||||
loss = self.mse(pred_frame, target_frame)
|
||||
loss_dict = {'mse': loss}
|
||||
return loss, loss_dict
|
||||
@@ -386,10 +378,10 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
|
||||
# Forward pass
|
||||
with torch.amp.autocast(device_type='cuda'):
|
||||
pred_frames, representations = model(input_frames)
|
||||
pred_frames = model(input_frames)
|
||||
loss, loss_dict = criterion(
|
||||
pred_frames, target_frames,
|
||||
representations, temporal_indices
|
||||
temporal_indices
|
||||
)
|
||||
|
||||
loss_value = loss.item()
|
||||
@@ -452,7 +444,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
|
||||
with torch.no_grad():
|
||||
# Take first sample from batch for visualization
|
||||
pred_vis, _ = model(input_frames[:1])
|
||||
pred_vis = model(input_frames[:1])
|
||||
# Convert to appropriate format for TensorBoard
|
||||
# Assuming frames are in [B, C, H, W] format
|
||||
writer.add_images('train/input', input_frames[:1], global_step)
|
||||
@@ -497,10 +489,10 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
|
||||
|
||||
# Compute output
|
||||
with torch.amp.autocast(device_type='cuda'):
|
||||
pred_frames, representations = model(input_frames)
|
||||
pred_frames = model(input_frames)
|
||||
loss, loss_dict = criterion(
|
||||
pred_frames, target_frames,
|
||||
representations, temporal_indices
|
||||
temporal_indices
|
||||
)
|
||||
|
||||
# 计算诊断指标
|
||||
@@ -555,4 +547,4 @@ if __name__ == '__main__':
|
||||
if args.output_dir:
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
main(args)
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user