删除残差路径和shortcut,镜像问题仍存在

This commit is contained in:
2026-01-16 15:21:47 +08:00
parent a92a0b29e9
commit 543beefa2a
3 changed files with 24 additions and 81 deletions

View File

@@ -11,7 +11,7 @@ from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder with residual connections"""
"""Upsampling block for frame prediction decoder without residual connections"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
# 主路径:反卷积 + 两个卷积层
@@ -31,28 +31,6 @@ class DecoderBlock(nn.Module):
kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
# 残差路径:如果需要改变通道数或空间尺寸
self.shortcut = nn.Identity()
if in_channels != out_channels or stride != 1:
# 使用1x1卷积调整通道数如果需要上采样则使用反卷积
if stride == 1:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels)
)
else:
self.shortcut = nn.Sequential(
nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=1,
stride=stride,
padding=0,
output_padding=output_padding,
bias=False
),
nn.BatchNorm2d(out_channels)
)
# 使用ReLU激活函数
self.activation = nn.ReLU(inplace=True)
@@ -67,13 +45,6 @@ class DecoderBlock(nn.Module):
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
# 初始化shortcut
if not isinstance(self.shortcut, nn.Identity):
# shortcut现在是Sequential需要初始化其中的卷积层
for module in self.shortcut:
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
# 初始化BN层使用默认初始化
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
@@ -81,8 +52,6 @@ class DecoderBlock(nn.Module):
nn.init.constant_(m.bias, 0)
def forward(self, x):
identity = self.shortcut(x)
# 主路径
x = self.conv_transpose(x)
x = self.bn1(x)
@@ -94,9 +63,6 @@ class DecoderBlock(nn.Module):
x = self.conv2(x)
x = self.bn3(x)
# 残差连接
x = x + identity
x = self.activation(x)
return x
@@ -105,28 +71,28 @@ class FramePredictionDecoder(nn.Module):
"""Improved decoder for frame prediction"""
def __init__(self, embed_dims, output_channels=1):
super().__init__()
# Reverse the embed_dims for decoder
decoder_dims = embed_dims[::-1]
# Define decoder dimensions independently (no skip connections)
start_dim = embed_dims[-1]
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
self.blocks = nn.ModuleList()
# 调整顺序将stride=4放在倒数第二的位置
# 第一个blockstride=2 (220 -> 112)
# 第一个blockstride=2 (decoder_dims[0] -> decoder_dims[1])
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第二个blockstride=2 (112 -> 56)
# 第二个blockstride=2 (decoder_dims[1] -> decoder_dims[2])
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第三个blockstride=2 (56 -> 48)
# 第三个blockstride=2 (decoder_dims[2] -> decoder_dims[3])
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# 第四个blockstride=4 (48 -> 64),放在倒数第二的位置
# 第四个blockstride=4 (decoder_dims[3] -> 64),放在倒数第二的位置
self.blocks.append(DecoderBlock(
decoder_dims[3], 64,
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
@@ -138,7 +104,7 @@ class FramePredictionDecoder(nn.Module):
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
nn.Tanh() # 添加Tanh激活函数约束输出在[-1, 1]范围内
nn.Tanh()
)
def forward(self, x):