diff --git a/evaluate_temporal.py b/evaluate_temporal.py index c1553d7..9e59292 100644 --- a/evaluate_temporal.py +++ b/evaluate_temporal.py @@ -147,15 +147,15 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path, print(target_frame) print(pred_frame) - # debug print - 改进为更有信息量的输出 - if isinstance(pred_frame, np.ndarray): - print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}") - # 检查是否有大量值在127.5附近 - mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0 - percent_near_127_5 = np.mean(mask_near_127_5) * 100 - print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%") - else: - print(pred_frame) + # # debug print - 改进为更有信息量的输出 + # if isinstance(pred_frame, np.ndarray): + # print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}") + # # 检查是否有大量值在127.5附近 + # mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0 + # percent_near_127_5 = np.mean(mask_near_127_5) * 100 + # print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%") + # else: + # print(pred_frame) plt.tight_layout() plt.savefig(save_path, dpi=150, bbox_inches='tight') @@ -347,10 +347,6 @@ def main(args): except (pickle.UnpicklingError, TypeError) as e: print(f"使用weights_only=False加载失败: {e}") print("尝试使用torch.serialization.add_safe_globals...") - # from argparse import Namespace - # # 添加安全全局变量 - # torch.serialization.add_safe_globals([Namespace]) - # checkpoint = torch.load(args.resume, map_location='cpu') # 处理状态字典(可能包含'module.'前缀) if 'model' in checkpoint: diff --git a/models/swiftformer_temporal.py b/models/swiftformer_temporal.py index 31e8406..eeaa29b 100644 --- a/models/swiftformer_temporal.py +++ b/models/swiftformer_temporal.py @@ -21,71 +21,79 @@ class DecoderBlock(nn.Module): stride=stride, padding=padding, output_padding=output_padding, - bias=True # 启用bias,因为移除了BN + bias=False # 禁用bias,因为使用BN ) + self.bn1 = nn.BatchNorm2d(out_channels) self.conv1 = nn.Conv2d(out_channels, out_channels, - kernel_size=3, padding=1, bias=True) + kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, - kernel_size=3, padding=1, bias=True) + kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(out_channels) # 残差路径:如果需要改变通道数或空间尺寸 self.shortcut = nn.Identity() if in_channels != out_channels or stride != 1: # 使用1x1卷积调整通道数,如果需要上采样则使用反卷积 if stride == 1: - self.shortcut = nn.Conv2d(in_channels, out_channels, - kernel_size=1, bias=True) + self.shortcut = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False), + nn.BatchNorm2d(out_channels) + ) else: - self.shortcut = nn.ConvTranspose2d( - in_channels, out_channels, - kernel_size=1, - stride=stride, - padding=0, - output_padding=output_padding, - bias=True + self.shortcut = nn.Sequential( + nn.ConvTranspose2d( + in_channels, out_channels, + kernel_size=1, + stride=stride, + padding=0, + output_padding=output_padding, + bias=False + ), + nn.BatchNorm2d(out_channels) ) - # 使用LeakyReLU避免死亡神经元 - self.activation = nn.LeakyReLU(0.2, inplace=True) + # 使用ReLU激活函数 + self.activation = nn.ReLU(inplace=True) # 初始化权重 self._init_weights() def _init_weights(self): # 初始化反卷积层 - nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.conv_transpose.bias is not None: - nn.init.constant_(self.conv_transpose.bias, 0) + nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='relu') # 初始化卷积层 - nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.conv1.bias is not None: - nn.init.constant_(self.conv1.bias, 0) - - nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.conv2.bias is not None: - nn.init.constant_(self.conv2.bias, 0) + nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu') # 初始化shortcut if not isinstance(self.shortcut, nn.Identity): - if isinstance(self.shortcut, nn.Conv2d): - nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu') - elif isinstance(self.shortcut, nn.ConvTranspose2d): - nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu') - if self.shortcut.bias is not None: - nn.init.constant_(self.shortcut.bias, 0) + # shortcut现在是Sequential,需要初始化其中的卷积层 + for module in self.shortcut: + if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + + # 初始化BN层(使用默认初始化) + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) def forward(self, x): identity = self.shortcut(x) # 主路径 x = self.conv_transpose(x) + x = self.bn1(x) x = self.activation(x) x = self.conv1(x) + x = self.bn2(x) x = self.activation(x) x = self.conv2(x) + x = self.bn3(x) # 残差连接 x = x + identity @@ -102,34 +110,35 @@ class FramePredictionDecoder(nn.Module): self.blocks = nn.ModuleList() - # 使用普通的DecoderBlock,第一个block使用大步长 + # 调整顺序:将stride=4放在倒数第二的位置 + # 第一个block:stride=2 (220 -> 112) self.blocks.append(DecoderBlock( decoder_dims[0], decoder_dims[1], - kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4 + kernel_size=3, stride=2, padding=1, output_padding=1 )) + # 第二个block:stride=2 (112 -> 56) self.blocks.append(DecoderBlock( decoder_dims[1], decoder_dims[2], kernel_size=3, stride=2, padding=1, output_padding=1 )) + # 第三个block:stride=2 (56 -> 48) self.blocks.append(DecoderBlock( decoder_dims[2], decoder_dims[3], kernel_size=3, stride=2, padding=1, output_padding=1 )) - # 第四个block:增加到64通道 + # 第四个block:stride=4 (48 -> 64),放在倒数第二的位置 self.blocks.append(DecoderBlock( decoder_dims[3], 64, - kernel_size=3, stride=2, padding=1, output_padding=1 + kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里 )) - # 改进的最终输出层:不使用反卷积,只进行特征精炼 - # 输入尺寸已经是目标尺寸,只需要调整通道数和进行特征融合 self.final_block = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True), - nn.LeakyReLU(0.2, inplace=True), + nn.ReLU(inplace=True), nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True), - nn.LeakyReLU(0.2, inplace=True), - nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True) - # 移除Tanh,让输出在任意范围,由损失函数和归一化处理 + nn.ReLU(inplace=True), + nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True), + nn.Tanh() # 添加Tanh激活函数,约束输出在[-1, 1]范围内 ) def forward(self, x): @@ -156,7 +165,6 @@ class SwiftFormerTemporal(nn.Module): model_name='XS', num_frames=3, use_decoder=True, - return_features=False, **kwargs): super().__init__() @@ -167,7 +175,6 @@ class SwiftFormerTemporal(nn.Module): # Store configuration self.num_frames = num_frames self.use_decoder = use_decoder - self.return_features = return_features # Modify stem to accept multiple frames (only Y channel) in_channels = num_frames @@ -207,13 +214,13 @@ class SwiftFormerTemporal(nn.Module): def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): - # 使用Kaiming初始化,适合ReLU/LeakyReLU - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # 使用Kaiming初始化,适合ReLU + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.ConvTranspose2d): # 反卷积层使用特定的初始化 - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): @@ -221,39 +228,20 @@ class SwiftFormerTemporal(nn.Module): nn.init.constant_(m.weight, 1.0) def forward_tokens(self, x): - """Forward through encoder network, return list of stage features if return_features else final output""" - if self.return_features: - features = [] - stage_idx = 0 - for idx, block in enumerate(self.network): - x = block(x) - # 收集每个stage的输出(stage0, stage1, stage2, stage3) - # 根据SwiftFormer结构,stage在索引0,2,4,6位置 - if idx in [0, 2, 4, 6]: - features.append(x) - stage_idx += 1 - return x, features - else: - for block in self.network: - x = block(x) - return x + for block in self.network: + x = block(x) + return x def forward(self, x): """ Args: x: input frames of shape [B, num_frames, H, W] Returns: - If return_features is False: - pred_frame: predicted frame [B, 1, H, W] (or None) - If return_features is True: - pred_frame, features (list of stage features) + pred_frame: predicted frame [B, 1, H, W] (or None) """ # Encode x = self.patch_embed(x) - if self.return_features: - x, features = self.forward_tokens(x) - else: - x = self.forward_tokens(x) + x = self.forward_tokens(x) x = self.norm(x) # Decode to frame @@ -261,10 +249,7 @@ class SwiftFormerTemporal(nn.Module): if self.use_decoder: pred_frame = self.decoder(x) - if self.return_features: - return pred_frame, features - else: - return pred_frame + return pred_frame # Factory functions for different model sizes