清理代码,删除跳连接部分

This commit is contained in:
2026-01-11 13:25:34 +08:00
parent c5502cc87c
commit df703638da
3 changed files with 68 additions and 268 deletions

View File

@@ -45,6 +45,15 @@ def denormalize(tensor):
# [0, 1] -> [0, 255] # [0, 1] -> [0, 255]
tensor = tensor * 255 tensor = tensor * 255
return tensor.clamp(0, 255) return tensor.clamp(0, 255)
# return tensor
def minmax_denormalize(tensor):
tensor_min = tensor.min()
tensor_max = tensor.max()
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
# tensor = tensor*2-1
tensor = tensor*255
return tensor.clamp(0, 255)
def calculate_metrics(pred, target, debug=False): def calculate_metrics(pred, target, debug=False):
@@ -134,6 +143,10 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
ax.set_title('Predicted') ax.set_title('Predicted')
ax.axis('off') ax.axis('off')
#debug print
print(target_frame)
print(pred_frame)
# debug print - 改进为更有信息量的输出 # debug print - 改进为更有信息量的输出
if isinstance(pred_frame, np.ndarray): if isinstance(pred_frame, np.ndarray):
print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}") print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
@@ -161,8 +174,8 @@ def evaluate_model(model, data_loader, device, args):
metrics_dict: 包含所有指标的字典 metrics_dict: 包含所有指标的字典
sample_results: 示例结果用于可视化 sample_results: 示例结果用于可视化
""" """
# model.eval() model.eval()
model.train() # 临时使用训练模式 # model.train() # 临时使用训练模式
# 初始化指标累加器 # 初始化指标累加器
total_mse = 0.0 total_mse = 0.0
@@ -183,10 +196,11 @@ def evaluate_model(model, data_loader, device, args):
target_frames = target_frames.to(device, non_blocking=True) target_frames = target_frames.to(device, non_blocking=True)
# 前向传播 # 前向传播
pred_frames, _ = model(input_frames) pred_frames = model(input_frames)
# 反归一化用于指标计算 # 反归一化用于指标计算
pred_denorm = denormalize(pred_frames) # [B, 1, H, W] # pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
pred_denorm = denormalize(pred_frames)
target_denorm = denormalize(target_frames) # [B, 1, H, W] target_denorm = denormalize(target_frames) # [B, 1, H, W]
batch_size = input_frames.size(0) batch_size = input_frames.size(0)
@@ -309,8 +323,6 @@ def main(args):
print(f"创建模型: {args.model}") print(f"创建模型: {args.model}")
model_kwargs = { model_kwargs = {
'num_frames': args.num_frames, 'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
} }
if args.model == 'SwiftFormerTemporal_XS': if args.model == 'SwiftFormerTemporal_XS':
@@ -335,10 +347,10 @@ def main(args):
except (pickle.UnpicklingError, TypeError) as e: except (pickle.UnpicklingError, TypeError) as e:
print(f"使用weights_only=False加载失败: {e}") print(f"使用weights_only=False加载失败: {e}")
print("尝试使用torch.serialization.add_safe_globals...") print("尝试使用torch.serialization.add_safe_globals...")
from argparse import Namespace # from argparse import Namespace
# 添加安全全局变量 # # 添加安全全局变量
torch.serialization.add_safe_globals([Namespace]) # torch.serialization.add_safe_globals([Namespace])
checkpoint = torch.load(args.resume, map_location='cpu') # checkpoint = torch.load(args.resume, map_location='cpu')
# 处理状态字典(可能包含'module.'前缀) # 处理状态字典(可能包含'module.'前缀)
if 'model' in checkpoint: if 'model' in checkpoint:
@@ -462,10 +474,6 @@ def get_args_parser():
# 模型参数 # 模型参数
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='要评估的模型名称') help='要评估的模型名称')
parser.add_argument('--use-representation-head', action='store_true',
help='使用表示头进行姿态/速度预测')
parser.add_argument('--representation-dim', default=128, type=int,
help='表示向量的维度')
# 评估参数 # 评估参数
parser.add_argument('--batch-size', default=16, type=int, parser.add_argument('--batch-size', default=16, type=int,

View File

@@ -49,11 +49,6 @@ def get_args_parser():
# Model parameters # Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL', parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of model to train') help='Name of model to train')
parser.add_argument('--use-representation-head', action='store_true',
help='Use representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of representation vector')
parser.add_argument('--use-skip', default=False, type=bool, help='using skip connections')
# Training parameters # Training parameters
parser.add_argument('--batch-size', default=32, type=int) parser.add_argument('--batch-size', default=32, type=int)
@@ -207,9 +202,6 @@ def main(args):
print(f"Creating model: {args.model}") print(f"Creating model: {args.model}")
model_kwargs = { model_kwargs = {
'num_frames': args.num_frames, 'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
'use_skip': args.use_skip,
} }
if args.model == 'SwiftFormerTemporal_XS': if args.model == 'SwiftFormerTemporal_XS':
@@ -258,7 +250,7 @@ def main(args):
super().__init__() super().__init__()
self.mse = nn.MSELoss() self.mse = nn.MSELoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None): def forward(self, pred_frame, target_frame, temporal_indices=None):
loss = self.mse(pred_frame, target_frame) loss = self.mse(pred_frame, target_frame)
loss_dict = {'mse': loss} loss_dict = {'mse': loss}
return loss, loss_dict return loss, loss_dict
@@ -386,10 +378,10 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
# Forward pass # Forward pass
with torch.amp.autocast(device_type='cuda'): with torch.amp.autocast(device_type='cuda'):
pred_frames, representations = model(input_frames) pred_frames = model(input_frames)
loss, loss_dict = criterion( loss, loss_dict = criterion(
pred_frames, target_frames, pred_frames, target_frames,
representations, temporal_indices temporal_indices
) )
loss_value = loss.item() loss_value = loss.item()
@@ -452,7 +444,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0: if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
with torch.no_grad(): with torch.no_grad():
# Take first sample from batch for visualization # Take first sample from batch for visualization
pred_vis, _ = model(input_frames[:1]) pred_vis = model(input_frames[:1])
# Convert to appropriate format for TensorBoard # Convert to appropriate format for TensorBoard
# Assuming frames are in [B, C, H, W] format # Assuming frames are in [B, C, H, W] format
writer.add_images('train/input', input_frames[:1], global_step) writer.add_images('train/input', input_frames[:1], global_step)
@@ -497,10 +489,10 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
# Compute output # Compute output
with torch.amp.autocast(device_type='cuda'): with torch.amp.autocast(device_type='cuda'):
pred_frames, representations = model(input_frames) pred_frames = model(input_frames)
loss, loss_dict = criterion( loss, loss_dict = criterion(
pred_frames, target_frames, pred_frames, target_frames,
representations, temporal_indices temporal_indices
) )
# 计算诊断指标 # 计算诊断指标

View File

@@ -93,159 +93,33 @@ class DecoderBlock(nn.Module):
return x return x
class DecoderBlockWithSkip(nn.Module):
"""Decoder block with skip connection support"""
def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
# 总输入通道 = 输入通道 + skip通道
total_in_channels = in_channels + skip_channels
# 主路径:反卷积 + 两个卷积层
self.conv_transpose = nn.ConvTranspose2d(
total_in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=True
)
self.conv1 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=True)
self.conv2 = nn.Conv2d(out_channels, out_channels,
kernel_size=3, padding=1, bias=True)
# 残差路径:如果需要改变通道数或空间尺寸
self.shortcut = nn.Identity()
if total_in_channels != out_channels or stride != 1:
if stride == 1:
self.shortcut = nn.Conv2d(total_in_channels, out_channels,
kernel_size=1, bias=True)
else:
self.shortcut = nn.ConvTranspose2d(
total_in_channels, out_channels,
kernel_size=1,
stride=stride,
padding=0,
output_padding=output_padding,
bias=True
)
# 使用LeakyReLU避免死亡神经元
self.activation = nn.LeakyReLU(0.2, inplace=True)
# 初始化权重
self._init_weights()
def _init_weights(self):
# 初始化反卷积层
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv_transpose.bias is not None:
nn.init.constant_(self.conv_transpose.bias, 0)
# 初始化卷积层
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv1.bias is not None:
nn.init.constant_(self.conv1.bias, 0)
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.conv2.bias is not None:
nn.init.constant_(self.conv2.bias, 0)
# 初始化shortcut
if not isinstance(self.shortcut, nn.Identity):
if isinstance(self.shortcut, nn.Conv2d):
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
elif isinstance(self.shortcut, nn.ConvTranspose2d):
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
if self.shortcut.bias is not None:
nn.init.constant_(self.shortcut.bias, 0)
def forward(self, x, skip_feature=None):
# 如果有skip feature将其与输入拼接
if skip_feature is not None:
# 确保skip特征的空间尺寸与x匹配
if skip_feature.shape[2:] != x.shape[2:]:
# 使用双线性插值进行上采样或下采样
skip_feature = torch.nn.functional.interpolate(
skip_feature,
size=x.shape[2:],
mode='bilinear',
align_corners=False
)
x = torch.cat([x, skip_feature], dim=1)
identity = self.shortcut(x)
# 主路径
x = self.conv_transpose(x)
x = self.activation(x)
x = self.conv1(x)
x = self.activation(x)
x = self.conv2(x)
# 残差连接
x = x + identity
x = self.activation(x)
return x
class FramePredictionDecoder(nn.Module): class FramePredictionDecoder(nn.Module):
"""Improved decoder for frame prediction with better upsampling strategy""" """Improved decoder for frame prediction"""
def __init__(self, embed_dims, output_channels=1, use_skip=False): def __init__(self, embed_dims, output_channels=1):
super().__init__() super().__init__()
self.use_skip = use_skip
# Reverse the embed_dims for decoder # Reverse the embed_dims for decoder
decoder_dims = embed_dims[::-1] decoder_dims = embed_dims[::-1]
self.blocks = nn.ModuleList() self.blocks = nn.ModuleList()
if use_skip: # 使用普通的DecoderBlock第一个block使用大步长
# 使用支持skip connections的block self.blocks.append(DecoderBlock(
# 第一个block从bottleneck到stage4使用大步长stride=4skip来自stage3 decoder_dims[0], decoder_dims[1],
self.blocks.append(DecoderBlockWithSkip( kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4
decoder_dims[0], decoder_dims[1], ))
skip_channels=embed_dims[3], # stage3的通道数 self.blocks.append(DecoderBlock(
kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4 decoder_dims[1], decoder_dims[2],
)) kernel_size=3, stride=2, padding=1, output_padding=1
# 第二个blockstage4到stage3stride=2skip来自stage2 ))
self.blocks.append(DecoderBlockWithSkip( self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2], decoder_dims[2], decoder_dims[3],
skip_channels=embed_dims[2], # stage2的通道数 kernel_size=3, stride=2, padding=1, output_padding=1
kernel_size=3, stride=2, padding=1, output_padding=1 ))
)) # 第四个block增加到64通道
# 第三个blockstage3到stage2stride=2skip来自stage1 self.blocks.append(DecoderBlock(
self.blocks.append(DecoderBlockWithSkip( decoder_dims[3], 64,
decoder_dims[2], decoder_dims[3], kernel_size=3, stride=2, padding=1, output_padding=1
skip_channels=embed_dims[1], # stage1的通道数 ))
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第四个blockstage2到stage1stride=2skip来自stage0
self.blocks.append(DecoderBlockWithSkip(
decoder_dims[3], 64, # 输出到64通道
skip_channels=embed_dims[0], # stage0的通道数
kernel_size=3, stride=2, padding=1, output_padding=1
))
else:
# 使用普通的DecoderBlock第一个block使用大步长
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4
))
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第四个block增加到64通道
self.blocks.append(DecoderBlock(
decoder_dims[3], 64,
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 改进的最终输出层:不使用反卷积,只进行特征精炼 # 改进的最终输出层:不使用反卷积,只进行特征精炼
# 输入尺寸已经是目标尺寸,只需要调整通道数和进行特征融合 # 输入尺寸已经是目标尺寸,只需要调整通道数和进行特征融合
@@ -258,48 +132,14 @@ class FramePredictionDecoder(nn.Module):
# 移除Tanh让输出在任意范围由损失函数和归一化处理 # 移除Tanh让输出在任意范围由损失函数和归一化处理
) )
def forward(self, x, skip_features=None): def forward(self, x):
""" """
Args: Args:
x: input tensor of shape [B, embed_dims[-1], H/32, W/32] x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
skip_features: list of encoder features from stages [stage3, stage2, stage1, stage0]
each of shape [B, C, H', W'] where C matches encoder dims
""" """
if self.use_skip: # 不使用skip connections
if skip_features is None: for i in range(4):
raise ValueError("skip_features must be provided when use_skip=True") x = self.blocks[i](x)
# 确保有4个skip features
assert len(skip_features) == 4, f"Need 4 skip features, got {len(skip_features)}"
# 反转顺序以匹配解码器stage3, stage2, stage1, stage0
skip_features = skip_features[::-1]
# 调整skip特征的尺寸以匹配新的上采样策略
adjusted_skip_features = []
for i, skip in enumerate(skip_features):
if skip is not None:
# 计算目标尺寸4, 2, 2, 2倍上采样
upsample_factors = [4, 2, 2, 2]
target_height = x.shape[2] * upsample_factors[i]
target_width = x.shape[3] * upsample_factors[i]
if skip.shape[2:] != (target_height, target_width):
skip = torch.nn.functional.interpolate(
skip,
size=(target_height, target_width),
mode='bilinear',
align_corners=False
)
adjusted_skip_features.append(skip)
# 四个block使用skip connections
for i in range(4):
x = self.blocks[i](x, adjusted_skip_features[i])
else:
# 不使用skip connections
for i in range(4):
x = self.blocks[i](x)
# 最终输出层:只进行特征精炼,不上采样 # 最终输出层:只进行特征精炼,不上采样
x = self.final_block(x) x = self.final_block(x)
@@ -316,9 +156,6 @@ class SwiftFormerTemporal(nn.Module):
model_name='XS', model_name='XS',
num_frames=3, num_frames=3,
use_decoder=True, use_decoder=True,
use_skip=True, # 新增是否使用skip connections
use_representation_head=False,
representation_dim=128,
return_features=False, return_features=False,
**kwargs): **kwargs):
super().__init__() super().__init__()
@@ -330,8 +167,6 @@ class SwiftFormerTemporal(nn.Module):
# Store configuration # Store configuration
self.num_frames = num_frames self.num_frames = num_frames
self.use_decoder = use_decoder self.use_decoder = use_decoder
self.use_skip = use_skip # 保存skip connections设置
self.use_representation_head = use_representation_head
self.return_features = return_features self.return_features = return_features
# Modify stem to accept multiple frames (only Y channel) # Modify stem to accept multiple frames (only Y channel)
@@ -365,22 +200,9 @@ class SwiftFormerTemporal(nn.Module):
if use_decoder: if use_decoder:
self.decoder = FramePredictionDecoder( self.decoder = FramePredictionDecoder(
embed_dims, embed_dims,
output_channels=1, output_channels=1
use_skip=use_skip # 传递skip connections设置
) )
# Representation head for pose/velocity prediction
if use_representation_head:
self.representation_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(embed_dims[-1], representation_dim),
nn.ReLU(),
nn.Linear(representation_dim, representation_dim)
)
else:
self.representation_head = None
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
@@ -400,7 +222,7 @@ class SwiftFormerTemporal(nn.Module):
def forward_tokens(self, x): def forward_tokens(self, x):
"""Forward through encoder network, return list of stage features if return_features else final output""" """Forward through encoder network, return list of stage features if return_features else final output"""
if self.return_features or self.use_skip: if self.return_features:
features = [] features = []
stage_idx = 0 stage_idx = 0
for idx, block in enumerate(self.network): for idx, block in enumerate(self.network):
@@ -423,59 +245,37 @@ class SwiftFormerTemporal(nn.Module):
Returns: Returns:
If return_features is False: If return_features is False:
pred_frame: predicted frame [B, 1, H, W] (or None) pred_frame: predicted frame [B, 1, H, W] (or None)
representation: optional representation vector [B, representation_dim] (or None)
If return_features is True: If return_features is True:
pred_frame, representation, features (list of stage features) pred_frame, features (list of stage features)
""" """
# Encode # Encode
x = self.patch_embed(x) x = self.patch_embed(x)
if self.return_features or self.use_skip: if self.return_features:
x, features = self.forward_tokens(x) x, features = self.forward_tokens(x)
else: else:
x = self.forward_tokens(x) x = self.forward_tokens(x)
x = self.norm(x) x = self.norm(x)
# Get representation if needed
representation = None
if self.representation_head is not None:
representation = self.representation_head(x)
# Decode to frame # Decode to frame
pred_frame = None pred_frame = None
if self.use_decoder: if self.use_decoder:
if self.use_skip: pred_frame = self.decoder(x)
# 提取用于skip connections的特征
# features包含所有stage的输出我们需要stage0, stage1, stage2, stage3
# 根据SwiftFormer结构应该有4个stage特征
if len(features) >= 4:
# 取四个stage的特征stage0, stage1, stage2, stage3
skip_features = [features[0], features[1], features[2], features[3]]
else:
# 如果特征不够,使用可用的特征
skip_features = features[:4]
# 如果特征仍然不够使用None填充
while len(skip_features) < 4:
skip_features.append(None)
pred_frame = self.decoder(x, skip_features)
else:
pred_frame = self.decoder(x)
if self.return_features: if self.return_features:
return pred_frame, representation, features return pred_frame, features
else: else:
return pred_frame, representation return pred_frame
# Factory functions for different model sizes # Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs): def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs) return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs): def SwiftFormerTemporal_S(num_frames=3, **kwargs):
return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs) return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs): def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs) return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs): def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs) return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)