删除残差路径和shortcut,镜像问题仍存在
This commit is contained in:
@@ -45,7 +45,6 @@ def denormalize(tensor):
|
|||||||
# [0, 1] -> [0, 255]
|
# [0, 1] -> [0, 255]
|
||||||
tensor = tensor * 255
|
tensor = tensor * 255
|
||||||
return tensor.clamp(0, 255)
|
return tensor.clamp(0, 255)
|
||||||
# return tensor
|
|
||||||
|
|
||||||
def minmax_denormalize(tensor):
|
def minmax_denormalize(tensor):
|
||||||
tensor_min = tensor.min()
|
tensor_min = tensor.min()
|
||||||
@@ -76,28 +75,16 @@ def calculate_metrics(pred, target, debug=False):
|
|||||||
if target_np.ndim == 3:
|
if target_np.ndim == 3:
|
||||||
target_np = target_np.squeeze(0)
|
target_np = target_np.squeeze(0)
|
||||||
|
|
||||||
if debug:
|
# if debug:
|
||||||
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
||||||
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
||||||
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
||||||
|
|
||||||
# 计算MSE - 修复错误的tmp公式
|
|
||||||
# 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2
|
|
||||||
# 正确公式: 直接计算像素差的平方
|
|
||||||
mse = np.mean((pred_np - target_np) ** 2)
|
mse = np.mean((pred_np - target_np) ** 2)
|
||||||
|
|
||||||
# 同时计算错误公式的MSE用于对比
|
|
||||||
tmp = 1 - (pred_np - target_np) / 255 * 2
|
|
||||||
wrong_mse = np.mean(tmp**2)
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
|
|
||||||
|
|
||||||
# 计算SSIM (数据范围0-255)
|
|
||||||
data_range = 255.0
|
data_range = 255.0
|
||||||
ssim_value = ssim(pred_np, target_np, data_range=data_range)
|
ssim_value = ssim(pred_np, target_np, data_range=data_range)
|
||||||
|
|
||||||
# 计算PSNR
|
|
||||||
psnr_value = psnr(target_np, pred_np, data_range=data_range)
|
psnr_value = psnr(target_np, pred_np, data_range=data_range)
|
||||||
|
|
||||||
return mse, ssim_value, psnr_value
|
return mse, ssim_value, psnr_value
|
||||||
@@ -146,16 +133,6 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
|
|||||||
#debug print
|
#debug print
|
||||||
print(target_frame)
|
print(target_frame)
|
||||||
print(pred_frame)
|
print(pred_frame)
|
||||||
|
|
||||||
# # debug print - 改进为更有信息量的输出
|
|
||||||
# if isinstance(pred_frame, np.ndarray):
|
|
||||||
# print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
|
|
||||||
# # 检查是否有大量值在127.5附近
|
|
||||||
# mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
|
|
||||||
# percent_near_127_5 = np.mean(mask_near_127_5) * 100
|
|
||||||
# print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
|
|
||||||
# else:
|
|
||||||
# print(pred_frame)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||||
@@ -216,13 +193,13 @@ def evaluate_model(model, data_loader, device, args):
|
|||||||
|
|
||||||
# 对第一个样本启用调试
|
# 对第一个样本启用调试
|
||||||
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
|
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
|
||||||
if debug_mode:
|
# if debug_mode:
|
||||||
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
||||||
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
||||||
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
||||||
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
||||||
|
|
||||||
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
|
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
|
||||||
|
|
||||||
total_mse += mse
|
total_mse += mse
|
||||||
total_ssim += ssim_value
|
total_ssim += ssim_value
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ def get_args_parser():
|
|||||||
help='Number of input frames (T)')
|
help='Number of input frames (T)')
|
||||||
parser.add_argument('--frame-size', default=224, type=int,
|
parser.add_argument('--frame-size', default=224, type=int,
|
||||||
help='Input frame size')
|
help='Input frame size')
|
||||||
parser.add_argument('--max-interval', default=4, type=int,
|
parser.add_argument('--max-interval', default=10, type=int,
|
||||||
help='Maximum interval between consecutive frames')
|
help='Maximum interval between consecutive frames')
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
@@ -121,7 +121,7 @@ def get_args_parser():
|
|||||||
help='start epoch')
|
help='start epoch')
|
||||||
parser.add_argument('--eval', action='store_true',
|
parser.add_argument('--eval', action='store_true',
|
||||||
help='Perform evaluation only')
|
help='Perform evaluation only')
|
||||||
parser.add_argument('--num-workers', default=4, type=int)
|
parser.add_argument('--num-workers', default=16, type=int)
|
||||||
parser.add_argument('--pin-mem', action='store_true',
|
parser.add_argument('--pin-mem', action='store_true',
|
||||||
help='Pin CPU memory in DataLoader')
|
help='Pin CPU memory in DataLoader')
|
||||||
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
||||||
@@ -264,7 +264,7 @@ def main(args):
|
|||||||
checkpoint = torch.hub.load_state_dict_from_url(
|
checkpoint = torch.hub.load_state_dict_from_url(
|
||||||
args.resume, map_location='cpu', check_hash=True)
|
args.resume, map_location='cpu', check_hash=True)
|
||||||
else:
|
else:
|
||||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
|
||||||
|
|
||||||
model_without_ddp.load_state_dict(checkpoint['model'])
|
model_without_ddp.load_state_dict(checkpoint['model'])
|
||||||
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
|
||||||
@@ -308,7 +308,7 @@ def main(args):
|
|||||||
|
|
||||||
train_stats, global_step = train_one_epoch(
|
train_stats, global_step = train_one_epoch(
|
||||||
model, criterion, data_loader_train,
|
model, criterion, data_loader_train,
|
||||||
optimizer, device, epoch, loss_scaler,
|
optimizer, device, epoch, loss_scaler, args.clip_grad, args.clip_mode,
|
||||||
model_ema=model_ema, writer=writer,
|
model_ema=model_ema, writer=writer,
|
||||||
global_step=global_step, args=args
|
global_step=global_step, args=args
|
||||||
)
|
)
|
||||||
@@ -356,7 +356,7 @@ def main(args):
|
|||||||
|
|
||||||
|
|
||||||
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
|
||||||
clip_grad=None, clip_mode='norm', model_ema=None, writer=None,
|
clip_grad=0.01, clip_mode='norm', model_ema=None, writer=None,
|
||||||
global_step=0, args=None, **kwargs):
|
global_step=0, args=None, **kwargs):
|
||||||
model.train()
|
model.train()
|
||||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from timm.layers import DropPath, trunc_normal_
|
|||||||
|
|
||||||
|
|
||||||
class DecoderBlock(nn.Module):
|
class DecoderBlock(nn.Module):
|
||||||
"""Upsampling block for frame prediction decoder with residual connections"""
|
"""Upsampling block for frame prediction decoder without residual connections"""
|
||||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
|
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 主路径:反卷积 + 两个卷积层
|
# 主路径:反卷积 + 两个卷积层
|
||||||
@@ -31,28 +31,6 @@ class DecoderBlock(nn.Module):
|
|||||||
kernel_size=3, padding=1, bias=False)
|
kernel_size=3, padding=1, bias=False)
|
||||||
self.bn3 = nn.BatchNorm2d(out_channels)
|
self.bn3 = nn.BatchNorm2d(out_channels)
|
||||||
|
|
||||||
# 残差路径:如果需要改变通道数或空间尺寸
|
|
||||||
self.shortcut = nn.Identity()
|
|
||||||
if in_channels != out_channels or stride != 1:
|
|
||||||
# 使用1x1卷积调整通道数,如果需要上采样则使用反卷积
|
|
||||||
if stride == 1:
|
|
||||||
self.shortcut = nn.Sequential(
|
|
||||||
nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
|
|
||||||
nn.BatchNorm2d(out_channels)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.shortcut = nn.Sequential(
|
|
||||||
nn.ConvTranspose2d(
|
|
||||||
in_channels, out_channels,
|
|
||||||
kernel_size=1,
|
|
||||||
stride=stride,
|
|
||||||
padding=0,
|
|
||||||
output_padding=output_padding,
|
|
||||||
bias=False
|
|
||||||
),
|
|
||||||
nn.BatchNorm2d(out_channels)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 使用ReLU激活函数
|
# 使用ReLU激活函数
|
||||||
self.activation = nn.ReLU(inplace=True)
|
self.activation = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
@@ -67,13 +45,6 @@ class DecoderBlock(nn.Module):
|
|||||||
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
|
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu')
|
||||||
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
|
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='relu')
|
||||||
|
|
||||||
# 初始化shortcut
|
|
||||||
if not isinstance(self.shortcut, nn.Identity):
|
|
||||||
# shortcut现在是Sequential,需要初始化其中的卷积层
|
|
||||||
for module in self.shortcut:
|
|
||||||
if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
|
|
||||||
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
|
||||||
|
|
||||||
# 初始化BN层(使用默认初始化)
|
# 初始化BN层(使用默认初始化)
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.BatchNorm2d):
|
if isinstance(m, nn.BatchNorm2d):
|
||||||
@@ -81,8 +52,6 @@ class DecoderBlock(nn.Module):
|
|||||||
nn.init.constant_(m.bias, 0)
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
identity = self.shortcut(x)
|
|
||||||
|
|
||||||
# 主路径
|
# 主路径
|
||||||
x = self.conv_transpose(x)
|
x = self.conv_transpose(x)
|
||||||
x = self.bn1(x)
|
x = self.bn1(x)
|
||||||
@@ -94,9 +63,6 @@ class DecoderBlock(nn.Module):
|
|||||||
|
|
||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
x = self.bn3(x)
|
x = self.bn3(x)
|
||||||
|
|
||||||
# 残差连接
|
|
||||||
x = x + identity
|
|
||||||
x = self.activation(x)
|
x = self.activation(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@@ -105,28 +71,28 @@ class FramePredictionDecoder(nn.Module):
|
|||||||
"""Improved decoder for frame prediction"""
|
"""Improved decoder for frame prediction"""
|
||||||
def __init__(self, embed_dims, output_channels=1):
|
def __init__(self, embed_dims, output_channels=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# Reverse the embed_dims for decoder
|
# Define decoder dimensions independently (no skip connections)
|
||||||
decoder_dims = embed_dims[::-1]
|
start_dim = embed_dims[-1]
|
||||||
|
decoder_dims = [start_dim // (2 ** i) for i in range(4)] # e.g., [220, 110, 55, 27] for XS
|
||||||
|
|
||||||
self.blocks = nn.ModuleList()
|
self.blocks = nn.ModuleList()
|
||||||
|
|
||||||
# 调整顺序:将stride=4放在倒数第二的位置
|
# 第一个block:stride=2 (decoder_dims[0] -> decoder_dims[1])
|
||||||
# 第一个block:stride=2 (220 -> 112)
|
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[0], decoder_dims[1],
|
decoder_dims[0], decoder_dims[1],
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
))
|
))
|
||||||
# 第二个block:stride=2 (112 -> 56)
|
# 第二个block:stride=2 (decoder_dims[1] -> decoder_dims[2])
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[1], decoder_dims[2],
|
decoder_dims[1], decoder_dims[2],
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
))
|
))
|
||||||
# 第三个block:stride=2 (56 -> 48)
|
# 第三个block:stride=2 (decoder_dims[2] -> decoder_dims[3])
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[2], decoder_dims[3],
|
decoder_dims[2], decoder_dims[3],
|
||||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||||
))
|
))
|
||||||
# 第四个block:stride=4 (48 -> 64),放在倒数第二的位置
|
# 第四个block:stride=4 (decoder_dims[3] -> 64),放在倒数第二的位置
|
||||||
self.blocks.append(DecoderBlock(
|
self.blocks.append(DecoderBlock(
|
||||||
decoder_dims[3], 64,
|
decoder_dims[3], 64,
|
||||||
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
|
kernel_size=3, stride=4, padding=1, output_padding=3 # stride=4放在这里
|
||||||
@@ -138,7 +104,7 @@ class FramePredictionDecoder(nn.Module):
|
|||||||
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
|
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
|
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True),
|
||||||
nn.Tanh() # 添加Tanh激活函数,约束输出在[-1, 1]范围内
|
nn.Tanh()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|||||||
Reference in New Issue
Block a user