Compare commits
4 Commits
12de74f130
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 543beefa2a | |||
| a92a0b29e9 | |||
| df703638da | |||
| c5502cc87c |
@@ -46,6 +46,14 @@ def denormalize(tensor):
|
|||||||
tensor = tensor * 255
|
tensor = tensor * 255
|
||||||
return tensor.clamp(0, 255)
|
return tensor.clamp(0, 255)
|
||||||
|
|
||||||
|
def minmax_denormalize(tensor):
|
||||||
|
tensor_min = tensor.min()
|
||||||
|
tensor_max = tensor.max()
|
||||||
|
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
|
||||||
|
# tensor = tensor*2-1
|
||||||
|
tensor = tensor*255
|
||||||
|
return tensor.clamp(0, 255)
|
||||||
|
|
||||||
|
|
||||||
def calculate_metrics(pred, target, debug=False):
|
def calculate_metrics(pred, target, debug=False):
|
||||||
"""
|
"""
|
||||||
@@ -67,28 +75,16 @@ def calculate_metrics(pred, target, debug=False):
|
|||||||
if target_np.ndim == 3:
|
if target_np.ndim == 3:
|
||||||
target_np = target_np.squeeze(0)
|
target_np = target_np.squeeze(0)
|
||||||
|
|
||||||
if debug:
|
# if debug:
|
||||||
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
||||||
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
||||||
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
||||||
|
|
||||||
# 计算MSE - 修复错误的tmp公式
|
|
||||||
# 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2
|
|
||||||
# 正确公式: 直接计算像素差的平方
|
|
||||||
mse = np.mean((pred_np - target_np) ** 2)
|
mse = np.mean((pred_np - target_np) ** 2)
|
||||||
|
|
||||||
# 同时计算错误公式的MSE用于对比
|
|
||||||
tmp = 1 - (pred_np - target_np) / 255 * 2
|
|
||||||
wrong_mse = np.mean(tmp**2)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
|
|
||||||
|
|
||||||
# 计算SSIM (数据范围0-255)
|
|
||||||
data_range = 255.0
|
data_range = 255.0
|
||||||
ssim_value = ssim(pred_np, target_np, data_range=data_range)
|
ssim_value = ssim(pred_np, target_np, data_range=data_range)
|
||||||
|
|
||||||
# 计算PSNR
|
|
||||||
psnr_value = psnr(target_np, pred_np, data_range=data_range)
|
psnr_value = psnr(target_np, pred_np, data_range=data_range)
|
||||||
|
|
||||||
return mse, ssim_value, psnr_value
|
return mse, ssim_value, psnr_value
|
||||||
@@ -134,14 +130,8 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
|
|||||||
ax.set_title('Predicted')
|
ax.set_title('Predicted')
|
||||||
ax.axis('off')
|
ax.axis('off')
|
||||||
|
|
||||||
# debug print - 改进为更有信息量的输出
|
#debug print
|
||||||
if isinstance(pred_frame, np.ndarray):
|
print(target_frame)
|
||||||
print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
|
|
||||||
# 检查是否有大量值在127.5附近
|
|
||||||
mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
|
|
||||||
percent_near_127_5 = np.mean(mask_near_127_5) * 100
|
|
||||||
print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
|
|
||||||
else:
|
|
||||||
print(pred_frame)
|
print(pred_frame)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
@@ -161,8 +151,8 @@ def evaluate_model(model, data_loader, device, args):
|
|||||||
metrics_dict: 包含所有指标的字典
|
metrics_dict: 包含所有指标的字典
|
||||||
sample_results: 示例结果用于可视化
|
sample_results: 示例结果用于可视化
|
||||||
"""
|
"""
|
||||||
# model.eval()
|
model.eval()
|
||||||
model.train() # 临时使用训练模式
|
# model.train() # 临时使用训练模式
|
||||||
|
|
||||||
# 初始化指标累加器
|
# 初始化指标累加器
|
||||||
total_mse = 0.0
|
total_mse = 0.0
|
||||||
@@ -183,10 +173,11 @@ def evaluate_model(model, data_loader, device, args):
|
|||||||
target_frames = target_frames.to(device, non_blocking=True)
|
target_frames = target_frames.to(device, non_blocking=True)
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
pred_frames, _ = model(input_frames)
|
pred_frames = model(input_frames)
|
||||||
|
|
||||||
# 反归一化用于指标计算
|
# 反归一化用于指标计算
|
||||||
pred_denorm = denormalize(pred_frames) # [B, 1, H, W]
|
# pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
|
||||||
|
pred_denorm = denormalize(pred_frames)
|
||||||
target_denorm = denormalize(target_frames) # [B, 1, H, W]
|
target_denorm = denormalize(target_frames) # [B, 1, H, W]
|
||||||
|
|
||||||
batch_size = input_frames.size(0)
|
batch_size = input_frames.size(0)
|
||||||
@@ -202,13 +193,13 @@ def evaluate_model(model, data_loader, device, args):
|
|||||||
|
|
||||||
# 对第一个样本启用调试
|
# 对第一个样本启用调试
|
||||||
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
|
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
|
||||||
if debug_mode:
|
# if debug_mode:
|
||||||
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
||||||
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
||||||
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
||||||
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
||||||
|
|
||||||
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
|
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
|
||||||
|
|
||||||
total_mse += mse
|
total_mse += mse
|
||||||
total_ssim += ssim_value
|
total_ssim += ssim_value
|
||||||
@@ -309,8 +300,6 @@ def main(args):
|
|||||||
print(f"创建模型: {args.model}")
|
print(f"创建模型: {args.model}")
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
'num_frames': args.num_frames,
|
'num_frames': args.num_frames,
|
||||||
'use_representation_head': args.use_representation_head,
|
|
||||||
'representation_dim': args.representation_dim,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.model == 'SwiftFormerTemporal_XS':
|
if args.model == 'SwiftFormerTemporal_XS':
|
||||||
@@ -335,10 +324,6 @@ def main(args):
|
|||||||
except (pickle.UnpicklingError, TypeError) as e:
|
except (pickle.UnpicklingError, TypeError) as e:
|
||||||
print(f"使用weights_only=False加载失败: {e}")
|
print(f"使用weights_only=False加载失败: {e}")
|
||||||
print("尝试使用torch.serialization.add_safe_globals...")
|
print("尝试使用torch.serialization.add_safe_globals...")
|
||||||
from argparse import Namespace
|
|
||||||
# 添加安全全局变量
|
|
||||||
torch.serialization.add_safe_globals([Namespace])
|
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
|
||||||
|
|
||||||
# 处理状态字典(可能包含'module.'前缀)
|
# 处理状态字典(可能包含'module.'前缀)
|
||||||
if 'model' in checkpoint:
|
if 'model' in checkpoint:
|
||||||
@@ -462,10 +447,6 @@ def get_args_parser():
|
|||||||
# 模型参数
|
# 模型参数
|
||||||
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
||||||
help='要评估的模型名称')
|
help='要评估的模型名称')
|
||||||
parser.add_argument('--use-representation-head', action='store_true',
|
|
||||||
help='使用表示头进行姿态/速度预测')
|
|
||||||
parser.add_argument('--representation-dim', default=128, type=int,
|
|
||||||
help='表示向量的维度')
|
|
||||||
|
|
||||||
# 评估参数
|
# 评估参数
|
||||||
parser.add_argument('--batch-size', default=16, type=int,
|
parser.add_argument('--batch-size', default=16, type=int,
|
||||||
|
|||||||
@@ -20,17 +20,13 @@ from util import *
|
|||||||
from models import *
|
from models import *
|
||||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||||
from util.video_dataset import VideoFrameDataset
|
from util.video_dataset import VideoFrameDataset
|
||||||
from util.frame_losses import MultiTaskLoss
|
# from util.frame_losses import MultiTaskLoss
|
||||||
|
|
||||||
# Try to import TensorBoard
|
# Try to import TensorBoard
|
||||||
try:
|
try:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
TENSORBOARD_AVAILABLE = True
|
TENSORBOARD_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
try:
|
|
||||||
from tensorboardX import SummaryWriter
|
|
||||||
TENSORBOARD_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
TENSORBOARD_AVAILABLE = False
|
TENSORBOARD_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
@@ -47,17 +43,12 @@ def get_args_parser():
|
|||||||
help='Number of input frames (T)')
|
help='Number of input frames (T)')
|
||||||
parser.add_argument('--frame-size', default=224, type=int,
|
parser.add_argument('--frame-size', default=224, type=int,
|
||||||
help='Input frame size')
|
help='Input frame size')
|
||||||
parser.add_argument('--max-interval', default=4, type=int,
|
parser.add_argument('--max-interval', default=10, type=int,
|
||||||
help='Maximum interval between consecutive frames')
|
help='Maximum interval between consecutive frames')
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
||||||
help='Name of model to train')
|
help='Name of model to train')
|
||||||
parser.add_argument('--use-representation-head', action='store_true',
|
|
||||||
help='Use representation head for pose/velocity prediction')
|
|
||||||
parser.add_argument('--representation-dim', default=128, type=int,
|
|
||||||
help='Dimension of representation vector')
|
|
||||||
parser.add_argument('--use-skip', default=True, type=bool, help='using skip connections')
|
|
||||||
|
|
||||||
# Training parameters
|
# Training parameters
|
||||||
parser.add_argument('--batch-size', default=32, type=int)
|
parser.add_argument('--batch-size', default=32, type=int)
|
||||||
@@ -130,7 +121,7 @@ def get_args_parser():
|
|||||||
help='start epoch')
|
help='start epoch')
|
||||||
parser.add_argument('--eval', action='store_true',
|
parser.add_argument('--eval', action='store_true',
|
||||||
help='Perform evaluation only')
|
help='Perform evaluation only')
|
||||||
parser.add_argument('--num-workers', default=4, type=int)
|
parser.add_argument('--num-workers', default=16, type=int)
|
||||||
parser.add_argument('--pin-mem', action='store_true',
|
parser.add_argument('--pin-mem', action='store_true',
|
||||||
help='Pin CPU memory in DataLoader')
|
help='Pin CPU memory in DataLoader')
|
||||||
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
||||||
@@ -211,9 +202,6 @@ def main(args):
|
|||||||
print(f"Creating model: {args.model}")
|
print(f"Creating model: {args.model}")
|
||||||
model_kwargs = {
|
model_kwargs = {
|
||||||
'num_frames': args.num_frames,
|
'num_frames': args.num_frames,
|
||||||
'use_representation_head': args.use_representation_head,
|
|
||||||
'representation_dim': args.representation_dim,
|
|
||||||
'use_skip': args.use_skip,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if args.model == 'SwiftFormerTemporal_XS':
|
if args.model == 'SwiftFormerTemporal_XS':
|
||||||
@@ -262,7 +250,7 @@ def main(args):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.mse = nn.MSELoss()
|
self.mse = nn.MSELoss()
|
||||||
|
|
||||||
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
|
def forward(self, pred_frame, target_frame, temporal_indices=None):
|
||||||
loss = self.mse(pred_frame, target_frame)
|
loss = self.mse(pred_frame, target_frame)
|
||||||
loss_dict = {'mse': loss}
|
loss_dict = {'mse': loss}
|
||||||
return loss, loss_dict
|
return loss, loss_dict
|
||||||
@@ -276,7 +264,7 @@ def main(args):
|
|||||||
checkpoint = torch.hub.load_state_dict_from_url(
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
args.resume, map_location='cpu', check_hash=True)
|
args.resume, map_location='cpu', check_hash=True)
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
|
||||||
|
|
||||||
model_without_ddp.load_state_dict(checkpoint['model'])
|
model_without_ddp.load_state_dict(checkpoint['model'])
|
||||||
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
||||||
@@ -320,7 +308,7 @@ def main(args):
|
|||||||
|
|
||||||
train_stats, global_step = train_one_epoch(
|
train_stats, global_step = train_one_epoch(
|
||||||
model, criterion, data_loader_train,
|
model, criterion, data_loader_train,
|
||||||
optimizer, device, epoch, loss_scaler,
|
optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
|
||||||
model_ema=model_ema, writer=writer,
|
model_ema=model_ema, writer=writer,
|
||||||
global_step=global_step, args=args
|
global_step=global_step, args=args
|
||||||
)
|
)
|
||||||
@@ -328,7 +316,7 @@ def main(args):
|
|||||||
lr_scheduler.step(epoch)
|
lr_scheduler.step(epoch)
|
||||||
|
|
||||||
# Save checkpoint
|
# Save checkpoint
|
||||||
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
|
if args.output_dir and (epoch % 1 == 0 or epoch == args.epochs - 1):
|
||||||
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
|
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
|
||||||
utils.save_on_master({
|
utils.save_on_master({
|
||||||
'model': model_without_ddp.state_dict(),
|
'model': model_without_ddp.state_dict(),
|
||||||
@@ -368,7 +356,7 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
||||||
clip_grad=0, clip_mode='norm', model_ema=None, writer=None,
|
clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
|
||||||
global_step=0, args=None, **kwargs):
|
global_step=0, args=None, **kwargs):
|
||||||
model.train()
|
model.train()
|
||||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||||
@@ -390,10 +378,10 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
|||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
with torch.amp.autocast(device_type='cuda'):
|
with torch.amp.autocast(device_type='cuda'):
|
||||||
pred_frames, representations = model(input_frames)
|
pred_frames = model(input_frames)
|
||||||
loss, loss_dict = criterion(
|
loss, loss_dict = criterion(
|
||||||
pred_frames, target_frames,
|
pred_frames, target_frames,
|
||||||
representations, temporal_indices
|
temporal_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_value = loss.item()
|
loss_value = loss.item()
|
||||||
@@ -403,7 +391,6 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
|||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
# 在反向传播前保存梯度用于诊断
|
|
||||||
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
|
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
|
||||||
parameters=model.parameters())
|
parameters=model.parameters())
|
||||||
|
|
||||||
@@ -426,14 +413,14 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
|||||||
metric_logger.update(pred_std=pred_std)
|
metric_logger.update(pred_std=pred_std)
|
||||||
metric_logger.update(grad_norm=total_grad_norm)
|
metric_logger.update(grad_norm=total_grad_norm)
|
||||||
|
|
||||||
# 每50个批次打印一次BatchNorm统计
|
# # 每50个批次打印一次BatchNorm统计
|
||||||
if batch_idx % 50 == 0:
|
if batch_idx % 50 == 0:
|
||||||
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
|
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
|
||||||
# 检查一个BatchNorm层的运行统计
|
# # 检查一个BatchNorm层的运行统计
|
||||||
for name, module in model.named_modules():
|
# for name, module in model.named_modules():
|
||||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||||
print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
# print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||||
break
|
# break
|
||||||
|
|
||||||
# Log to TensorBoard
|
# Log to TensorBoard
|
||||||
if writer is not None:
|
if writer is not None:
|
||||||
@@ -457,7 +444,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
|||||||
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
|
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Take first sample from batch for visualization
|
# Take first sample from batch for visualization
|
||||||
pred_vis, _ = model(input_frames[:1])
|
pred_vis = model(input_frames[:1])
|
||||||
# Convert to appropriate format for TensorBoard
|
# Convert to appropriate format for TensorBoard
|
||||||
# Assuming frames are in [B, C, H, W] format
|
# Assuming frames are in [B, C, H, W] format
|
||||||
writer.add_images('train/input', input_frames[:1], global_step)
|
writer.add_images('train/input', input_frames[:1], global_step)
|
||||||
@@ -502,10 +489,10 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
|
|||||||
|
|
||||||
# Compute output
|
# Compute output
|
||||||
with torch.amp.autocast(device_type='cuda'):
|
with torch.amp.autocast(device_type='cuda'):
|
||||||
pred_frames, representations = model(input_frames)
|
pred_frames = model(input_frames)
|
||||||
loss, loss_dict = criterion(
|
loss, loss_dict = criterion(
|
||||||
pred_frames, target_frames,
|
pred_frames, target_frames,
|
||||||
representations, temporal_indices
|
temporal_indices
|
||||||
)
|
)
|
||||||
|
|
||||||
# 计算诊断指标
|
# 计算诊断指标
|
||||||
@@ -520,21 +507,21 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
|
|||||||
metric_logger.update(target_mean=target_mean)
|
metric_logger.update(target_mean=target_mean)
|
||||||
metric_logger.update(target_std=target_std)
|
metric_logger.update(target_std=target_std)
|
||||||
|
|
||||||
# 第一个批次打印详细诊断信息
|
# # 第一个批次打印详细诊断信息
|
||||||
if batch_idx == 0:
|
# if batch_idx == 0:
|
||||||
print(f"[评估诊断] 批次 0:")
|
# print(f"[评估诊断] 批次 0:")
|
||||||
print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
# print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
||||||
print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
# print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
||||||
print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
# print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
||||||
print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
# print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
||||||
|
|
||||||
# 检查BatchNorm运行统计
|
# # 检查BatchNorm运行统计
|
||||||
for name, module in model.named_modules():
|
# for name, module in model.named_modules():
|
||||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
# if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||||
print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
# print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||||
if module.running_var[0].item() < 1e-6:
|
# if module.running_var[0].item() < 1e-6:
|
||||||
print(f" 警告: BatchNorm运行方差接近零!")
|
# print(f" 警告: BatchNorm运行方差接近零!")
|
||||||
break
|
# break
|
||||||
|
|
||||||
# Update metrics
|
# Update metrics
|
||||||
metric_logger.update(loss=loss.item())
|
metric_logger.update(loss=loss.item())
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from timm.layers import DropPath, trunc_normal_
|
|||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
class DecoderBlock(nn.Module):
|
||||||
"""Upsampling block for frame prediction decoder with residual connections"""
|
"""Upsampling block for frame prediction decoder without residual connections"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 主路径:反卷积 + 两个卷积层
|
# 主路径:反卷积 + 两个卷积层
|
||||||
@@ -21,282 +21,97 @@ class DecoderBlock(nn.Module):
|
|||||||
stride=stride,
|
stride=stride,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
output_padding=output_padding,
|
output_padding=output_padding,
|
||||||
bias=True # 启用bias,因为移除了BN
|
bias=False # 禁用bias,因为使用BN
|
||||||
)
|
)
|
||||||
|
self.bn1 = nn.BatchNorm2d(out_channels)
|
||||||
self.conv1 = nn.Conv2d(out_channels, out_channels,
|
self.conv1 = nn.Conv2d(out_channels, out_channels,
|
||||||
kernel_size=3, padding=1, bias=True)
|
kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||||
self.conv2 = nn.Conv2d(out_channels, out_channels,
|
self.conv2 = nn.Conv2d(out_channels, out_channels,
|
||||||
kernel_size=3, padding=1, bias=True)
|
kernel_size=3, padding=1, bias=False)
|
||||||
|
self.bn3 = nn.BatchNorm2d(out_channels)
|
||||||
|
|
||||||
# 残差路径:如果需要改变通道数或空间尺寸
|
# 使用ReLU激活函数
|
||||||
self.shortcut = nn.Identity()
|
self.activation = nn.ReLU(inplace=True)
|
||||||
if in_channels != out_channels or stride != 1:
|
|
||||||
# 使用1x1卷积调整通道数,如果需要上采样则使用反卷积
|
|
||||||
if stride == 1:
|
|
||||||
self.shortcut = nn.Conv2d(in_channels, out_channels,
|
|
||||||
kernel_size=1, bias=True)
|
|
||||||
else:
|
|
||||||
self.shortcut = nn.ConvTranspose2d(
|
|
||||||
in_channels, out_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=stride,
|
|
||||||
padding=0,
|
|
||||||
output_padding=output_padding,
|
|
||||||
bias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用LeakyReLU避免死亡神经元
|
|
||||||
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
|
||||||
|
|
||||||
# 初始化权重
|
# 初始化权重
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
||||||
def _init_weights(self):
|
def _init_weights(self):
|
||||||
# 初始化反卷积层
|
# 初始化反卷积层
|
||||||
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu')
|
||||||
if self.conv_transpose.bias is not None:
|
|
||||||
nn.init.constant_(self.conv_transpose.bias, 0)
|
|
||||||
|
|
||||||
# 初始化卷积层
|
# 初始化卷积层
|
||||||
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
|
||||||
if self.conv1.bias is not None:
|
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
|
||||||
nn.init.constant_(self.conv1.bias, 0)
|
|
||||||
|
|
||||||
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
|
# 初始化BN层(使用默认初始化)
|
||||||
if self.conv2.bias is not None:
|
for m in self.modules():
|
||||||
nn.init.constant_(self.conv2.bias, 0)
|
if isinstance(m, nn.BatchNorm2d):
|
||||||
|
nn.init.constant_(m.weight, 1)
|
||||||
# 初始化shortcut
|
nn.init.constant_(m.bias, 0)
|
||||||
if not isinstance(self.shortcut, nn.Identity):
|
|
||||||
if isinstance(self.shortcut, nn.Conv2d):
|
|
||||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
|
||||||
elif isinstance(self.shortcut, nn.ConvTranspose2d):
|
|
||||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
|
||||||
if self.shortcut.bias is not None:
|
|
||||||
nn.init.constant_(self.shortcut.bias, 0)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
identity = self.shortcut(x)
|
|
||||||
|
|
||||||
# 主路径
|
# 主路径
|
||||||
x = self.conv_transpose(x)
|
x = self.conv_transpose(x)
|
||||||
|
x = self.bn1(x)
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
|
|
||||||
x = self.conv1(x)
|
x = self.conv1(x)
|
||||||
|
x = self.bn2(x)
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
|
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
|
x = self.bn3(x)
|
||||||
# 残差连接
|
|
||||||
x = x + identity
|
|
||||||
x = self.activation(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class DecoderBlockWithSkip(nn.Module):
|
|
||||||
"""Decoder block with skip connection support"""
|
|
||||||
def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
|
|
||||||
super().__init__()
|
|
||||||
# 总输入通道 = 输入通道 + skip通道
|
|
||||||
total_in_channels = in_channels + skip_channels
|
|
||||||
|
|
||||||
# 主路径:反卷积 + 两个卷积层
|
|
||||||
self.conv_transpose = nn.ConvTranspose2d(
|
|
||||||
total_in_channels, out_channels,
|
|
||||||
kernel_size=kernel_size,
|
|
||||||
stride=stride,
|
|
||||||
padding=padding,
|
|
||||||
output_padding=output_padding,
|
|
||||||
bias=True
|
|
||||||
)
|
|
||||||
self.conv1 = nn.Conv2d(out_channels, out_channels,
|
|
||||||
kernel_size=3, padding=1, bias=True)
|
|
||||||
self.conv2 = nn.Conv2d(out_channels, out_channels,
|
|
||||||
kernel_size=3, padding=1, bias=True)
|
|
||||||
|
|
||||||
# 残差路径:如果需要改变通道数或空间尺寸
|
|
||||||
self.shortcut = nn.Identity()
|
|
||||||
if total_in_channels != out_channels or stride != 1:
|
|
||||||
if stride == 1:
|
|
||||||
self.shortcut = nn.Conv2d(total_in_channels, out_channels,
|
|
||||||
kernel_size=1, bias=True)
|
|
||||||
else:
|
|
||||||
self.shortcut = nn.ConvTranspose2d(
|
|
||||||
total_in_channels, out_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=stride,
|
|
||||||
padding=0,
|
|
||||||
output_padding=output_padding,
|
|
||||||
bias=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用LeakyReLU避免死亡神经元
|
|
||||||
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
|
||||||
|
|
||||||
# 初始化权重
|
|
||||||
self._init_weights()
|
|
||||||
|
|
||||||
def _init_weights(self):
|
|
||||||
# 初始化反卷积层
|
|
||||||
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
|
|
||||||
if self.conv_transpose.bias is not None:
|
|
||||||
nn.init.constant_(self.conv_transpose.bias, 0)
|
|
||||||
|
|
||||||
# 初始化卷积层
|
|
||||||
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
|
|
||||||
if self.conv1.bias is not None:
|
|
||||||
nn.init.constant_(self.conv1.bias, 0)
|
|
||||||
|
|
||||||
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
|
|
||||||
if self.conv2.bias is not None:
|
|
||||||
nn.init.constant_(self.conv2.bias, 0)
|
|
||||||
|
|
||||||
# 初始化shortcut
|
|
||||||
if not isinstance(self.shortcut, nn.Identity):
|
|
||||||
if isinstance(self.shortcut, nn.Conv2d):
|
|
||||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
|
||||||
elif isinstance(self.shortcut, nn.ConvTranspose2d):
|
|
||||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
|
||||||
if self.shortcut.bias is not None:
|
|
||||||
nn.init.constant_(self.shortcut.bias, 0)
|
|
||||||
|
|
||||||
def forward(self, x, skip_feature=None):
|
|
||||||
# 如果有skip feature,将其与输入拼接
|
|
||||||
if skip_feature is not None:
|
|
||||||
# 确保skip特征的空间尺寸与x匹配
|
|
||||||
if skip_feature.shape[2:] != x.shape[2:]:
|
|
||||||
# 使用双线性插值进行上采样或下采样
|
|
||||||
skip_feature = torch.nn.functional.interpolate(
|
|
||||||
skip_feature,
|
|
||||||
size=x.shape[2:],
|
|
||||||
mode='bilinear',
|
|
||||||
align_corners=False
|
|
||||||
)
|
|
||||||
x = torch.cat([x, skip_feature], dim=1)
|
|
||||||
|
|
||||||
identity = self.shortcut(x)
|
|
||||||
|
|
||||||
# 主路径
|
|
||||||
x = self.conv_transpose(x)
|
|
||||||
x = self.activation(x)
|
|
||||||
|
|
||||||
x = self.conv1(x)
|
|
||||||
x = self.activation(x)
|
|
||||||
|
|
||||||
x = self.conv2(x)
|
|
||||||
|
|
||||||
# 残差连接
|
|
||||||
x = x + identity
|
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class FramePredictionDecoder(nn.Module):
|
class FramePredictionDecoder(nn.Module):
|
||||||
"""Improved decoder for frame prediction with better upsampling strategy"""
|
"""Improved decoder for frame prediction"""
|
||||||
def __init__(self, embed_dims, output_channels=1, use_skip=False):
|
def __init__(self, embed_dims, output_channels=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.use_skip = use_skip
|
# Define decoder dimensions independently (no skip connections)
|
||||||
# Reverse the embed_dims for decoder
|
start_dim = embed_dims[-1]
|
||||||
decoder_dims = embed_dims[::-1]
|
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
|
||||||
|
|
||||||
self.blocks = nn.ModuleList()
|
self.blocks = nn.ModuleList()
|
||||||
|
|
||||||
if use_skip:
|
# 第一个block:stride=2 (decoder_dims[0] -> decoder_dims[1])
|
||||||
# 使用支持skip connections的block
|
|
||||||
# 第一个block:从bottleneck到stage4,使用大步长stride=4,skip来自stage3
|
|
||||||
self.blocks.append(DecoderBlockWithSkip(
|
|
||||||
decoder_dims[0], decoder_dims[1],
|
|
||||||
skip_channels=embed_dims[3], # stage3的通道数
|
|
||||||
kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4
|
|
||||||
))
|
|
||||||
# 第二个block:stage4到stage3,stride=2,skip来自stage2
|
|
||||||
self.blocks.append(DecoderBlockWithSkip(
|
|
||||||
decoder_dims[1], decoder_dims[2],
|
|
||||||
skip_channels=embed_dims[2], # stage2的通道数
|
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
|
||||||
))
|
|
||||||
# 第三个block:stage3到stage2,stride=2,skip来自stage1
|
|
||||||
self.blocks.append(DecoderBlockWithSkip(
|
|
||||||
decoder_dims[2], decoder_dims[3],
|
|
||||||
skip_channels=embed_dims[1], # stage1的通道数
|
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
|
||||||
))
|
|
||||||
# 第四个block:stage2到stage1,stride=2,skip来自stage0
|
|
||||||
self.blocks.append(DecoderBlockWithSkip(
|
|
||||||
decoder_dims[3], 64, # 输出到64通道
|
|
||||||
skip_channels=embed_dims[0], # stage0的通道数
|
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
# 使用普通的DecoderBlock,第一个block使用大步长
|
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[0], decoder_dims[1],
|
decoder_dims[0], decoder_dims[1],
|
||||||
kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
))
|
))
|
||||||
|
# 第二个block:stride=2 (decoder_dims[1] -> decoder_dims[2])
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[1], decoder_dims[2],
|
decoder_dims[1], decoder_dims[2],
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
))
|
))
|
||||||
|
# 第三个block:stride=2 (decoder_dims[2] -> decoder_dims[3])
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[2], decoder_dims[3],
|
decoder_dims[2], decoder_dims[3],
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
))
|
))
|
||||||
# 第四个block:增加到64通道
|
# 第四个block:stride=4 (decoder_dims[3] -> 64),放在倒数第二的位置
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[3], 64,
|
decoder_dims[3], 64,
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
|
||||||
))
|
))
|
||||||
|
|
||||||
# 改进的最终输出层:不使用反卷积,只进行特征精炼
|
|
||||||
# 输入尺寸已经是目标尺寸,只需要调整通道数和进行特征融合
|
|
||||||
self.final_block = nn.Sequential(
|
self.final_block = nn.Sequential(
|
||||||
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
|
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
|
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
|
||||||
nn.LeakyReLU(0.2, inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
|
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
|
||||||
# 移除Tanh,让输出在任意范围,由损失函数和归一化处理
|
nn.Tanh()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, skip_features=None):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
|
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
|
||||||
skip_features: list of encoder features from stages [stage3, stage2, stage1, stage0]
|
|
||||||
each of shape [B, C, H', W'] where C matches encoder dims
|
|
||||||
"""
|
"""
|
||||||
if self.use_skip:
|
|
||||||
if skip_features is None:
|
|
||||||
raise ValueError("skip_features must be provided when use_skip=True")
|
|
||||||
|
|
||||||
# 确保有4个skip features
|
|
||||||
assert len(skip_features) == 4, f"Need 4 skip features, got {len(skip_features)}"
|
|
||||||
|
|
||||||
# 反转顺序以匹配解码器:stage3, stage2, stage1, stage0
|
|
||||||
skip_features = skip_features[::-1]
|
|
||||||
|
|
||||||
# 调整skip特征的尺寸以匹配新的上采样策略
|
|
||||||
adjusted_skip_features = []
|
|
||||||
for i, skip in enumerate(skip_features):
|
|
||||||
if skip is not None:
|
|
||||||
# 计算目标尺寸:4, 2, 2, 2倍上采样
|
|
||||||
upsample_factors = [4, 2, 2, 2]
|
|
||||||
target_height = x.shape[2] * upsample_factors[i]
|
|
||||||
target_width = x.shape[3] * upsample_factors[i]
|
|
||||||
|
|
||||||
if skip.shape[2:] != (target_height, target_width):
|
|
||||||
skip = torch.nn.functional.interpolate(
|
|
||||||
skip,
|
|
||||||
size=(target_height, target_width),
|
|
||||||
mode='bilinear',
|
|
||||||
align_corners=False
|
|
||||||
)
|
|
||||||
adjusted_skip_features.append(skip)
|
|
||||||
|
|
||||||
# 四个block使用skip connections
|
|
||||||
for i in range(4):
|
|
||||||
x = self.blocks[i](x, adjusted_skip_features[i])
|
|
||||||
else:
|
|
||||||
# 不使用skip connections
|
# 不使用skip connections
|
||||||
for i in range(4):
|
for i in range(4):
|
||||||
x = self.blocks[i](x)
|
x = self.blocks[i](x)
|
||||||
@@ -316,10 +131,6 @@ class SwiftFormerTemporal(nn.Module):
|
|||||||
model_name='XS',
|
model_name='XS',
|
||||||
num_frames=3,
|
num_frames=3,
|
||||||
use_decoder=True,
|
use_decoder=True,
|
||||||
use_skip=True, # 新增:是否使用skip connections
|
|
||||||
use_representation_head=False,
|
|
||||||
representation_dim=128,
|
|
||||||
return_features=False,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -330,9 +141,6 @@ class SwiftFormerTemporal(nn.Module):
|
|||||||
# Store configuration
|
# Store configuration
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
self.use_decoder = use_decoder
|
self.use_decoder = use_decoder
|
||||||
self.use_skip = use_skip # 保存skip connections设置
|
|
||||||
self.use_representation_head = use_representation_head
|
|
||||||
self.return_features = return_features
|
|
||||||
|
|
||||||
# Modify stem to accept multiple frames (only Y channel)
|
# Modify stem to accept multiple frames (only Y channel)
|
||||||
in_channels = num_frames
|
in_channels = num_frames
|
||||||
@@ -365,33 +173,20 @@ class SwiftFormerTemporal(nn.Module):
|
|||||||
if use_decoder:
|
if use_decoder:
|
||||||
self.decoder = FramePredictionDecoder(
|
self.decoder = FramePredictionDecoder(
|
||||||
embed_dims,
|
embed_dims,
|
||||||
output_channels=1,
|
output_channels=1
|
||||||
use_skip=use_skip # 传递skip connections设置
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Representation head for pose/velocity prediction
|
|
||||||
if use_representation_head:
|
|
||||||
self.representation_head = nn.Sequential(
|
|
||||||
nn.AdaptiveAvgPool2d(1),
|
|
||||||
nn.Flatten(),
|
|
||||||
nn.Linear(embed_dims[-1], representation_dim),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Linear(representation_dim, representation_dim)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.representation_head = None
|
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
self.apply(self._init_weights)
|
||||||
|
|
||||||
def _init_weights(self, m):
|
def _init_weights(self, m):
|
||||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||||
# 使用Kaiming初始化,适合ReLU/LeakyReLU
|
# 使用Kaiming初始化,适合ReLU
|
||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, nn.ConvTranspose2d):
|
elif isinstance(m, nn.ConvTranspose2d):
|
||||||
# 反卷积层使用特定的初始化
|
# 反卷积层使用特定的初始化
|
||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
if m.bias is not None:
|
if m.bias is not None:
|
||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||||
@@ -399,19 +194,6 @@ class SwiftFormerTemporal(nn.Module):
|
|||||||
nn.init.constant_(m.weight, 1.0)
|
nn.init.constant_(m.weight, 1.0)
|
||||||
|
|
||||||
def forward_tokens(self, x):
|
def forward_tokens(self, x):
|
||||||
"""Forward through encoder network, return list of stage features if return_features else final output"""
|
|
||||||
if self.return_features or self.use_skip:
|
|
||||||
features = []
|
|
||||||
stage_idx = 0
|
|
||||||
for idx, block in enumerate(self.network):
|
|
||||||
x = block(x)
|
|
||||||
# 收集每个stage的输出(stage0, stage1, stage2, stage3)
|
|
||||||
# 根据SwiftFormer结构,stage在索引0,2,4,6位置
|
|
||||||
if idx in [0, 2, 4, 6]:
|
|
||||||
features.append(x)
|
|
||||||
stage_idx += 1
|
|
||||||
return x, features
|
|
||||||
else:
|
|
||||||
for block in self.network:
|
for block in self.network:
|
||||||
x = block(x)
|
x = block(x)
|
||||||
return x
|
return x
|
||||||
@@ -421,61 +203,30 @@ class SwiftFormerTemporal(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
x: input frames of shape [B, num_frames, H, W]
|
x: input frames of shape [B, num_frames, H, W]
|
||||||
Returns:
|
Returns:
|
||||||
If return_features is False:
|
|
||||||
pred_frame: predicted frame [B, 1, H, W] (or None)
|
pred_frame: predicted frame [B, 1, H, W] (or None)
|
||||||
representation: optional representation vector [B, representation_dim] (or None)
|
|
||||||
If return_features is True:
|
|
||||||
pred_frame, representation, features (list of stage features)
|
|
||||||
"""
|
"""
|
||||||
# Encode
|
# Encode
|
||||||
x = self.patch_embed(x)
|
x = self.patch_embed(x)
|
||||||
if self.return_features or self.use_skip:
|
|
||||||
x, features = self.forward_tokens(x)
|
|
||||||
else:
|
|
||||||
x = self.forward_tokens(x)
|
x = self.forward_tokens(x)
|
||||||
x = self.norm(x)
|
x = self.norm(x)
|
||||||
|
|
||||||
# Get representation if needed
|
|
||||||
representation = None
|
|
||||||
if self.representation_head is not None:
|
|
||||||
representation = self.representation_head(x)
|
|
||||||
|
|
||||||
# Decode to frame
|
# Decode to frame
|
||||||
pred_frame = None
|
pred_frame = None
|
||||||
if self.use_decoder:
|
if self.use_decoder:
|
||||||
if self.use_skip:
|
|
||||||
# 提取用于skip connections的特征
|
|
||||||
# features包含所有stage的输出,我们需要stage0, stage1, stage2, stage3
|
|
||||||
# 根据SwiftFormer结构,应该有4个stage特征
|
|
||||||
if len(features) >= 4:
|
|
||||||
# 取四个stage的特征:stage0, stage1, stage2, stage3
|
|
||||||
skip_features = [features[0], features[1], features[2], features[3]]
|
|
||||||
else:
|
|
||||||
# 如果特征不够,使用可用的特征
|
|
||||||
skip_features = features[:4]
|
|
||||||
# 如果特征仍然不够,使用None填充
|
|
||||||
while len(skip_features) < 4:
|
|
||||||
skip_features.append(None)
|
|
||||||
|
|
||||||
pred_frame = self.decoder(x, skip_features)
|
|
||||||
else:
|
|
||||||
pred_frame = self.decoder(x)
|
pred_frame = self.decoder(x)
|
||||||
|
|
||||||
if self.return_features:
|
return pred_frame
|
||||||
return pred_frame, representation, features
|
|
||||||
else:
|
|
||||||
return pred_frame, representation
|
|
||||||
|
|
||||||
|
|
||||||
# Factory functions for different model sizes
|
# Factory functions for different model sizes
|
||||||
def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs):
|
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
|
||||||
return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
|
||||||
|
|
||||||
def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs):
|
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
|
||||||
return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
|
||||||
|
|
||||||
def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs):
|
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
|
||||||
return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
|
||||||
|
|
||||||
def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs):
|
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
|
||||||
return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
"""
|
|
||||||
Loss functions for frame prediction and representation learning
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import math
|
|
||||||
|
|
||||||
|
|
||||||
class SSIMLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Structural Similarity Index Measure Loss
|
|
||||||
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
|
|
||||||
"""
|
|
||||||
def __init__(self, window_size=11, size_average=True):
|
|
||||||
super().__init__()
|
|
||||||
self.window_size = window_size
|
|
||||||
self.size_average = size_average
|
|
||||||
self.channel = 3
|
|
||||||
self.window = self.create_window(window_size, self.channel)
|
|
||||||
|
|
||||||
def create_window(self, window_size, channel):
|
|
||||||
def gaussian(window_size, sigma):
|
|
||||||
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
|
|
||||||
return gauss/gauss.sum()
|
|
||||||
|
|
||||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
|
||||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
|
||||||
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
|
|
||||||
return window
|
|
||||||
|
|
||||||
def forward(self, img1, img2):
|
|
||||||
# Ensure window is on correct device
|
|
||||||
if self.window.device != img1.device:
|
|
||||||
self.window = self.window.to(img1.device)
|
|
||||||
|
|
||||||
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
|
|
||||||
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
|
|
||||||
|
|
||||||
mu1_sq = mu1.pow(2)
|
|
||||||
mu2_sq = mu2.pow(2)
|
|
||||||
mu1_mu2 = mu1 * mu2
|
|
||||||
|
|
||||||
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
|
|
||||||
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
|
|
||||||
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
|
|
||||||
|
|
||||||
C1 = 0.01**2
|
|
||||||
C2 = 0.03**2
|
|
||||||
|
|
||||||
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
|
|
||||||
|
|
||||||
if self.size_average:
|
|
||||||
return 1 - ssim_map.mean()
|
|
||||||
else:
|
|
||||||
return 1 - ssim_map.mean(1).mean(1).mean(1)
|
|
||||||
|
|
||||||
|
|
||||||
class FramePredictionLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Combined loss for frame prediction
|
|
||||||
"""
|
|
||||||
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
|
|
||||||
super().__init__()
|
|
||||||
self.l1_weight = l1_weight
|
|
||||||
self.ssim_weight = ssim_weight
|
|
||||||
self.use_ssim = use_ssim
|
|
||||||
|
|
||||||
self.l1_loss = nn.L1Loss()
|
|
||||||
if use_ssim:
|
|
||||||
self.ssim_loss = SSIMLoss()
|
|
||||||
|
|
||||||
def forward(self, pred, target):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
pred: predicted frame [B, 3, H, W] in range [-1, 1]
|
|
||||||
target: target frame [B, 3, H, W] in range [-1, 1]
|
|
||||||
Returns:
|
|
||||||
total_loss, loss_dict
|
|
||||||
"""
|
|
||||||
loss_dict = {}
|
|
||||||
|
|
||||||
# L1 loss
|
|
||||||
l1_loss = self.l1_loss(pred, target)
|
|
||||||
loss_dict['l1'] = l1_loss
|
|
||||||
total_loss = self.l1_weight * l1_loss
|
|
||||||
|
|
||||||
# SSIM loss
|
|
||||||
if self.use_ssim:
|
|
||||||
ssim_loss = self.ssim_loss(pred, target)
|
|
||||||
loss_dict['ssim'] = ssim_loss
|
|
||||||
total_loss += self.ssim_weight * ssim_loss
|
|
||||||
|
|
||||||
loss_dict['total'] = total_loss
|
|
||||||
return total_loss, loss_dict
|
|
||||||
|
|
||||||
|
|
||||||
class ContrastiveLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Contrastive loss for representation learning
|
|
||||||
Positive pairs: representations from adjacent frames
|
|
||||||
Negative pairs: representations from distant frames
|
|
||||||
"""
|
|
||||||
def __init__(self, temperature=0.1, margin=1.0):
|
|
||||||
super().__init__()
|
|
||||||
self.temperature = temperature
|
|
||||||
self.margin = margin
|
|
||||||
self.cosine_similarity = nn.CosineSimilarity(dim=-1)
|
|
||||||
|
|
||||||
def forward(self, representations, temporal_indices):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
representations: [B, D] representation vectors
|
|
||||||
temporal_indices: [B] temporal indices of each sample
|
|
||||||
Returns:
|
|
||||||
contrastive_loss
|
|
||||||
"""
|
|
||||||
batch_size = representations.size(0)
|
|
||||||
|
|
||||||
# Compute similarity matrix
|
|
||||||
sim_matrix = torch.matmul(representations, representations.T) / self.temperature
|
|
||||||
|
|
||||||
# Create positive mask (adjacent frames)
|
|
||||||
indices_expanded = temporal_indices.unsqueeze(0)
|
|
||||||
diff = torch.abs(indices_expanded - indices_expanded.T)
|
|
||||||
positive_mask = (diff == 1).float()
|
|
||||||
|
|
||||||
# Create negative mask (distant frames)
|
|
||||||
negative_mask = (diff > 2).float()
|
|
||||||
|
|
||||||
# Positive loss
|
|
||||||
pos_sim = sim_matrix * positive_mask
|
|
||||||
pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
|
|
||||||
pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
|
|
||||||
|
|
||||||
# Negative loss (push apart)
|
|
||||||
neg_sim = sim_matrix * negative_mask
|
|
||||||
neg_loss = torch.relu(neg_sim - self.margin).mean()
|
|
||||||
|
|
||||||
return pos_loss + 0.1 * neg_loss
|
|
||||||
|
|
||||||
|
|
||||||
class MultiTaskLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Multi-task loss combining frame prediction and representation learning
|
|
||||||
"""
|
|
||||||
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
|
|
||||||
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
|
|
||||||
super().__init__()
|
|
||||||
self.frame_weight = frame_weight
|
|
||||||
self.contrastive_weight = contrastive_weight
|
|
||||||
self.use_contrastive = use_contrastive
|
|
||||||
|
|
||||||
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
|
|
||||||
if use_contrastive:
|
|
||||||
self.contrastive_loss = ContrastiveLoss()
|
|
||||||
|
|
||||||
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
pred_frame: predicted frame [B, 3, H, W]
|
|
||||||
target_frame: target frame [B, 3, H, W]
|
|
||||||
representations: [B, D] representation vectors (optional)
|
|
||||||
temporal_indices: [B] temporal indices (optional)
|
|
||||||
Returns:
|
|
||||||
total_loss, loss_dict
|
|
||||||
"""
|
|
||||||
loss_dict = {}
|
|
||||||
|
|
||||||
# Frame prediction loss
|
|
||||||
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
|
|
||||||
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
|
|
||||||
total_loss = self.frame_weight * frame_loss
|
|
||||||
|
|
||||||
# Contrastive loss (if representations provided)
|
|
||||||
if self.use_contrastive and representations is not None and temporal_indices is not None:
|
|
||||||
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
|
|
||||||
loss_dict['contrastive'] = contrastive_loss
|
|
||||||
total_loss += self.contrastive_weight * contrastive_loss
|
|
||||||
|
|
||||||
loss_dict['total'] = total_loss
|
|
||||||
return total_loss, loss_dict
|
|
||||||
Reference in New Issue
Block a user