更新归一化方式,当前直接映射,不利用均值标准差进行标准化

This commit is contained in:
2026-01-08 16:10:24 +08:00
parent f7601e9170
commit 500c2eb18f
3 changed files with 89 additions and 74 deletions

View File

@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.video_dataset import VideoFrameDataset
from util.frame_losses import MultiTaskLoss
# Try to import TensorBoard
@@ -47,7 +47,7 @@ def get_args_parser():
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int,
parser.add_argument('--max-interval', default=4, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
@@ -109,10 +109,10 @@ def get_args_parser():
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss')
# parser.add_argument('--l1-weight', type=float, default=1.0,
# help='Weight for L1 loss')
# parser.add_argument('--ssim-weight', type=float, default=0.1,
# help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
@@ -326,7 +326,7 @@ def main(args):
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),