Files
asmo_vhead/models/swiftformer_temporal.py

481 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""
import torch
import torch.nn as nn
from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
    """Residual upsampling block for the frame-prediction decoder.

    Main path:  ConvTranspose2d -> LeakyReLU -> Conv2d(3x3) -> LeakyReLU -> Conv2d(3x3)
    Shortcut:   Identity when channels and stride are unchanged, a 1x1 Conv2d
                when only the channel count changes (stride == 1), otherwise a
                1x1 ConvTranspose2d using the same stride / output_padding.
    The two paths are summed and passed through a final LeakyReLU(0.2).

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        kernel_size, stride, padding, output_padding: forwarded to the main
            path's ConvTranspose2d; stride and output_padding are also used by
            the transposed-conv shortcut so both paths produce the same size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
        super().__init__()
        # Every conv carries a bias because the block has no normalization layer.
        self.conv_transpose = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            output_padding=output_padding, bias=True,
        )
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # Residual path: only project when the shape actually changes.
        if in_channels == out_channels and stride == 1:
            self.shortcut = nn.Identity()
        elif stride == 1:
            # Channel change only: plain 1x1 projection.
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)
        else:
            # Upsampling: 1x1 transposed conv matching the main path's output size.
            self.shortcut = nn.ConvTranspose2d(
                in_channels, out_channels,
                kernel_size=1, stride=stride, padding=0,
                output_padding=output_padding, bias=True,
            )

        # LeakyReLU keeps a non-zero gradient for negative activations.
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out) weights and zero biases for every conv layer."""
        # Order matters for RNG reproducibility: main path first, then shortcut.
        targets = [self.conv_transpose, self.conv1, self.conv2]
        if not isinstance(self.shortcut, nn.Identity):
            targets.append(self.shortcut)
        for layer in targets:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='leaky_relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Upsample ``x`` and return ``LeakyReLU(main(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.activation(self.conv_transpose(x))
        out = self.activation(self.conv1(out))
        out = self.conv2(out)
        return self.activation(out + residual)
class DecoderBlockWithSkip(nn.Module):
    """Residual upsampling block that can fuse an encoder skip feature.

    If ``forward`` receives a skip tensor, it is bilinearly resized to the
    input's spatial size (when sizes differ) and concatenated on the channel
    axis before the residual block runs. Every input-side layer therefore
    expects ``in_channels + skip_channels`` channels. With ``skip_channels=0``
    the layer structure matches ``DecoderBlock``.

    Args:
        in_channels: channels of the decoder input.
        out_channels: channels of the block output.
        skip_channels: channels of the skip tensor that will be concatenated.
        kernel_size, stride, padding, output_padding: forwarded to the main
            path's ConvTranspose2d; stride and output_padding are also used by
            the transposed-conv shortcut.
    """

    def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
        super().__init__()
        fused_channels = in_channels + skip_channels

        # Main path: transposed conv followed by two 3x3 convs (biases on; no BN).
        self.conv_transpose = nn.ConvTranspose2d(
            fused_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            output_padding=output_padding, bias=True,
        )
        self.conv1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=True)

        # Residual path: identity only when neither channels nor resolution change.
        if fused_channels == out_channels and stride == 1:
            self.shortcut = nn.Identity()
        elif stride == 1:
            self.shortcut = nn.Conv2d(fused_channels, out_channels, kernel_size=1, bias=True)
        else:
            self.shortcut = nn.ConvTranspose2d(
                fused_channels, out_channels,
                kernel_size=1, stride=stride, padding=0,
                output_padding=output_padding, bias=True,
            )

        # LeakyReLU keeps a non-zero gradient for negative activations.
        self.activation = nn.LeakyReLU(0.2, inplace=True)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out) weights and zero biases for every conv layer."""
        # Same order as construction so RNG consumption is reproducible.
        targets = [self.conv_transpose, self.conv1, self.conv2]
        if not isinstance(self.shortcut, nn.Identity):
            targets.append(self.shortcut)
        for layer in targets:
            nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='leaky_relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x, skip_feature=None):
        """Optionally fuse ``skip_feature`` into ``x``, then upsample.

        Args:
            x: decoder input of shape [B, in_channels, h, w].
            skip_feature: optional [B, skip_channels, h', w'] tensor; resized to
                (h, w) with bilinear interpolation when (h', w') != (h, w).
        """
        if skip_feature is not None:
            target_size = x.shape[2:]
            if skip_feature.shape[2:] != target_size:
                skip_feature = torch.nn.functional.interpolate(
                    skip_feature, size=target_size, mode='bilinear', align_corners=False,
                )
            x = torch.cat([x, skip_feature], dim=1)

        residual = self.shortcut(x)
        out = self.activation(self.conv_transpose(x))
        out = self.activation(self.conv1(out))
        out = self.conv2(out)
        return self.activation(out + residual)
class FramePredictionDecoder(nn.Module):
    """Decoder that upsamples the encoder bottleneck back to a full frame.

    Four residual upsampling blocks with strides 4, 2, 2, 2 (32x in total)
    walk the channel widths back through ``embed_dims`` in reverse and end at
    64 channels. ``final_block`` then refines the result with three 3x3
    convolutions down to ``output_channels`` without further upsampling.
    No output activation is applied, so predictions are unbounded and any
    range handling is left to the loss / normalization.

    Args:
        embed_dims: per-stage encoder widths, shallowest first. Positions 0-3
            are indexed, so at least four entries are required.
        output_channels: channels of the predicted frame.
        use_skip: if True, each block concatenates one encoder feature map
            (resized by the block to its own input size) before upsampling.
    """

    def __init__(self, embed_dims, output_channels=1, use_skip=False):
        super().__init__()
        self.use_skip = use_skip
        decoder_dims = embed_dims[::-1]
        hidden_dim = 64

        # (in_channels, out_channels, stride, output_padding) per block.
        # With kernel_size=3 and padding=1, output_padding = stride - 1 makes
        # each block scale the spatial size by exactly `stride`.
        specs = [
            (decoder_dims[0], decoder_dims[1], 4, 3),
            (decoder_dims[1], decoder_dims[2], 2, 1),
            (decoder_dims[2], decoder_dims[3], 2, 1),
            (decoder_dims[3], hidden_dim, 2, 1),
        ]
        self.blocks = nn.ModuleList()
        for i, (in_ch, out_ch, stride, out_pad) in enumerate(specs):
            if use_skip:
                # Block i consumes encoder stage (3 - i): deepest features first.
                block = DecoderBlockWithSkip(
                    in_ch, out_ch,
                    skip_channels=embed_dims[3 - i],
                    kernel_size=3, stride=stride, padding=1, output_padding=out_pad,
                )
            else:
                block = DecoderBlock(
                    in_ch, out_ch,
                    kernel_size=3, stride=stride, padding=1, output_padding=out_pad,
                )
            self.blocks.append(block)

        # Refinement head at the final resolution: channel reduction only.
        self.final_block = nn.Sequential(
            nn.Conv2d(hidden_dim, 64, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
        )

    def forward(self, x, skip_features=None):
        """Decode the bottleneck ``x`` into a frame.

        Args:
            x: bottleneck tensor of shape [B, embed_dims[-1], h, w].
            skip_features: required when ``use_skip``. A sequence of exactly 4
                encoder feature maps ordered SHALLOWEST first
                ([stage0, stage1, stage2, stage3]); stage3 is fed to the first
                block and stage0 to the last.

        Returns:
            Tensor of shape [B, output_channels, 32 * h, 32 * w].

        Raises:
            ValueError: if ``use_skip`` is set and ``skip_features`` is missing,
                does not have 4 entries, or contains None.
        """
        if not self.use_skip:
            for block in self.blocks:
                x = block(x)
            return self.final_block(x)

        if skip_features is None:
            raise ValueError("skip_features must be provided when use_skip=True")
        if len(skip_features) != 4:
            raise ValueError(f"Need 4 skip features, got {len(skip_features)}")
        if any(feature is None for feature in skip_features):
            raise ValueError("skip_features must not contain None when use_skip=True")

        # Each DecoderBlockWithSkip resizes its skip map to its own input size,
        # so the maps are passed through untouched. Pre-resizing them relative
        # to the bottleneck would shrink the shallow (high-resolution) stages
        # and discard their detail.
        for block, skip in zip(self.blocks, skip_features[::-1]):
            x = block(x, skip)
        return self.final_block(x)
class SwiftFormerTemporal(nn.Module):
    """SwiftFormer with multi-frame input for frame prediction.

    ``num_frames`` single-channel (Y) frames are stacked on the channel axis
    and encoded by a SwiftFormer backbone. Optional heads:

    * a ``FramePredictionDecoder`` producing a 1-channel predicted frame,
      with or without encoder skip connections;
    * an MLP representation head on the globally pooled bottleneck.

    Input:  [B, num_frames, H, W]
    Output: ``(pred_frame, representation)``, or
            ``(pred_frame, representation, features)`` when ``return_features``;
            disabled outputs are returned as None.
    """

    def __init__(self,
                 model_name='XS',
                 num_frames=3,
                 use_decoder=True,
                 use_skip=True,
                 use_representation_head=False,
                 representation_dim=128,
                 return_features=False,
                 **kwargs):
        """
        Args:
            model_name: key into ``SwiftFormer_depth`` / ``SwiftFormer_width``.
            num_frames: number of stacked input frames (stem input channels).
            use_decoder: build the frame-prediction decoder.
            use_skip: feed per-stage encoder features to the decoder.
            use_representation_head: build the pooled MLP head.
            representation_dim: hidden and output width of that head.
            return_features: also return the per-stage encoder features.
            **kwargs: accepted for interface compatibility; not used.
        """
        super().__init__()
        layers = SwiftFormer_depth[model_name]
        embed_dims = SwiftFormer_width[model_name]

        self.num_frames = num_frames
        self.use_decoder = use_decoder
        self.use_skip = use_skip
        self.use_representation_head = use_representation_head
        self.return_features = return_features

        # The stem takes the stacked frames as its input channels.
        self.patch_embed = stem(num_frames, embed_dims[0])

        # Encoder: one Stage per entry of `layers`, with an Embedding
        # (patch_size=3, stride=2) inserted only between stages whose widths
        # differ. Because that insertion is conditional, record where each
        # Stage lands so forward_tokens picks the right outputs.
        network = []
        stage_indices = []
        num_stages = len(layers)
        for i in range(num_stages):
            stage_indices.append(len(network))
            network.append(Stage(embed_dims[i], i, layers, mlp_ratio=4,
                                 act_layer=nn.GELU,
                                 drop_rate=0., drop_path_rate=0.,
                                 use_layer_scale=True,
                                 layer_scale_init_value=1e-5,
                                 vit_num=1))
            is_last = i == num_stages - 1
            if not is_last and embed_dims[i] != embed_dims[i + 1]:
                network.append(
                    Embedding(
                        patch_size=3, stride=2, padding=1,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )
        self.network = nn.ModuleList(network)
        # Positions in self.network whose outputs are stage features.
        self.stage_indices = tuple(stage_indices)
        self.norm = nn.BatchNorm2d(embed_dims[-1])

        if use_decoder:
            self.decoder = FramePredictionDecoder(
                embed_dims,
                output_channels=1,
                use_skip=use_skip,
            )

        # Representation head (e.g. for pose/velocity prediction).
        if use_representation_head:
            self.representation_head = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Linear(embed_dims[-1], representation_dim),
                nn.ReLU(),
                nn.Linear(representation_dim, representation_dim)
            )
        else:
            self.representation_head = None

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Re-initialize one submodule (applied to every submodule via ``apply``).

        Conv2d / ConvTranspose2d / Linear get Kaiming-normal weights
        (mode='fan_out', leaky_relu gain) and zero bias; LayerNorm / BatchNorm2d
        get weight 1 and bias 0. This also overwrites the encoder's and the
        decoder blocks' own initialization.
        NOTE(review): the SwiftFormer encoder may have been designed for a
        different init (``trunc_normal_`` is imported but unused) — confirm.
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        """Run the encoder.

        Returns ``(x, features)`` when ``return_features`` or ``use_skip`` is
        set, where ``features`` holds each Stage's output in order
        (stage0 first); otherwise returns only the final encoder output.
        """
        if not (self.return_features or self.use_skip):
            for block in self.network:
                x = block(x)
            return x

        stage_positions = set(self.stage_indices)
        features = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if idx in stage_positions:
                features.append(x)
        return x, features

    def forward(self, x):
        """
        Args:
            x: input frames of shape [B, num_frames, H, W]
        Returns:
            If return_features is False:
                pred_frame: predicted frame [B, 1, H, W] (or None)
                representation: [B, representation_dim] (or None)
            If return_features is True:
                pred_frame, representation, features (list of stage features,
                shallowest first)
        """
        x = self.patch_embed(x)
        features = None
        if self.return_features or self.use_skip:
            x, features = self.forward_tokens(x)
        else:
            x = self.forward_tokens(x)
        x = self.norm(x)

        representation = None
        if self.representation_head is not None:
            representation = self.representation_head(x)

        pred_frame = None
        if self.use_decoder:
            if self.use_skip:
                # Stage features ordered shallowest first; the decoder validates
                # that exactly four are provided.
                pred_frame = self.decoder(x, features[:4])
            else:
                pred_frame = self.decoder(x)

        if self.return_features:
            return pred_frame, representation, features
        return pred_frame, representation
# Factory functions for different model sizes
def _build_swiftformer_temporal(model_name, num_frames, use_skip, extra_kwargs):
    """Shared constructor behind the size-specific factory functions."""
    return SwiftFormerTemporal(model_name, num_frames=num_frames, use_skip=use_skip, **extra_kwargs)


def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs):
    """Build a SwiftFormerTemporal with the 'XS' SwiftFormer configuration."""
    return _build_swiftformer_temporal('XS', num_frames, use_skip, kwargs)


def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs):
    """Build a SwiftFormerTemporal with the 'S' SwiftFormer configuration."""
    return _build_swiftformer_temporal('S', num_frames, use_skip, kwargs)


def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs):
    """Build a SwiftFormerTemporal with the 'l1' SwiftFormer configuration."""
    return _build_swiftformer_temporal('l1', num_frames, use_skip, kwargs)


def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs):
    """Build a SwiftFormerTemporal with the 'l3' SwiftFormer configuration."""
    return _build_swiftformer_temporal('l3', num_frames, use_skip, kwargs)