511 lines
19 KiB
Python
511 lines
19 KiB
Python
"""
|
||
Evaluation script for SwiftFormerTemporal frame prediction.

Outputs the predicted frames (note: de-normalized) together with the corresponding metrics (MSE, SSIM, PSNR).
|
||
"""
|
||
import argparse
|
||
import os
|
||
import torch
|
||
import torch.nn as nn
|
||
import pickle
|
||
import numpy as np
|
||
import random
|
||
from pathlib import Path
|
||
import json
|
||
import matplotlib.pyplot as plt
|
||
from PIL import Image
|
||
import torchvision.transforms as transforms
|
||
import torch.backends.cudnn as cudnn
|
||
|
||
from util.video_dataset import VideoFrameDataset
|
||
from models.swiftformer_temporal import (
|
||
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
|
||
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||
)
|
||
|
||
# 导入SSIM和PSNR计算
|
||
from skimage.metrics import structural_similarity as ssim
|
||
from skimage.metrics import peak_signal_noise_ratio as psnr
|
||
import warnings
|
||
warnings.filterwarnings('ignore')
|
||
|
||
|
||
def denormalize(tensor):
    """Map a tensor from the normalized [-1, 1] range to pixel values in [0, 255].

    Args:
        tensor: Tensor whose values are expected to lie in [-1, 1]. Values
            outside that interval are clipped before rescaling.

    Returns:
        A tensor of the same shape with values in [0, 255].
    """
    # Anything outside the normalized range is clipped first.
    clipped = tensor.clamp(min=-1, max=1)

    # [-1, 1] -> [0, 1] -> [0, 255]
    unit_range = (clipped + 1) / 2
    pixels = unit_range * 255

    # Final clamp guards the output interval.
    return pixels.clamp(min=0, max=255)
|
||
|
||
def minmax_denormalize(tensor):
    """Rescale a tensor to [0, 255] using its own global minimum and maximum.

    The smallest element maps to 0 and the largest to 255.

    Args:
        tensor: Input tensor; the min/max are taken over all elements.

    Returns:
        A tensor of the same shape with values in [0, 255]. If every element
        of the input is identical (zero value range) an all-zero tensor is
        returned instead of the NaNs a 0/0 division would produce.
    """
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    value_range = tensor_max - tensor_min

    if value_range == 0:
        # Constant input: there is no contrast to stretch, and dividing by the
        # zero range would fill the result with NaN.
        out_dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
        return torch.zeros_like(tensor, dtype=out_dtype)

    tensor = (tensor - tensor_min) / value_range
    tensor = tensor * 255
    return tensor.clamp(0, 255)
|
||
|
||
|
||
def calculate_metrics(pred, target, debug=False):
    """Compute MSE, SSIM and PSNR between a predicted and a target image.

    Args:
        pred: Predicted image, tensor or array-like, with values in [0, 255].
            A 3-D input has its leading axis squeezed away (that axis must
            have size 1).
        target: Target image, same conventions as ``pred``.
        debug: When True, print value ranges and a comparison between the
            correct MSE and the legacy (incorrect) formula.

    Returns:
        Tuple ``(mse, ssim_value, psnr_value)``; SSIM and PSNR are computed
        with a data range of 255.
    """
    arrays = []
    for image in (pred, target):
        if torch.is_tensor(image):
            # detach() so tensors that still track gradients can be converted.
            arr = image.detach().cpu().numpy()
        else:
            arr = np.asarray(image)

        # Reduce to a 2-D image.
        if arr.ndim == 3:
            arr = arr.squeeze(0)

        # Integer images (e.g. uint8) would wrap around in the subtraction
        # below; float inputs are left untouched so their results are unchanged.
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        arrays.append(arr)
    pred_np, target_np = arrays

    if debug:
        print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
        print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
        print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")

    # Mean squared pixel difference.
    mse = np.mean((pred_np - target_np) ** 2)

    if debug:
        # The legacy formula is only evaluated for the side-by-side debug
        # print; it is not part of the returned metrics.
        tmp = 1 - (pred_np - target_np) / 255 * 2
        wrong_mse = np.mean(tmp**2)
        print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")

    # SSIM and PSNR over the 0-255 data range.
    data_range = 255.0
    ssim_value = ssim(pred_np, target_np, data_range=data_range)
    psnr_value = psnr(target_np, pred_np, data_range=data_range)

    return mse, ssim_value, psnr_value
|
||
|
||
|
||
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
                           input_frame_indices=None, target_frame_index=None):
    """Save a side-by-side figure of the input frames, target and prediction.

    The figure has one row with ``len(input_frames) + 2`` grayscale panels:
    every input frame, then the target frame, then the predicted frame.

    Args:
        input_frames: Sequence of input images with values in [0, 255].
        target_frame: Target image with values in [0, 255].
        pred_frame: Predicted image with values in [0, 255].
        save_path: Destination path of the image file.
        input_frame_indices: Optional labels for the input panels; when given
            it is indexed once per input frame.
        target_frame_index: Optional label for the target panel.
    """
    num_input = len(input_frames)

    # One (image, title) pair per panel, in display order.
    panels = []
    for i in range(num_input):
        if input_frame_indices is not None:
            title = f'Input Frame {input_frame_indices[i]}'
        else:
            title = f'Input {i+1}'
        panels.append((input_frames[i], title))

    if target_frame_index is not None:
        panels.append((target_frame, f'Target Frame {target_frame_index}'))
    else:
        panels.append((target_frame, 'Target'))
    panels.append((pred_frame, 'Predicted'))

    fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
    try:
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        # Summary statistics instead of dumping whole arrays to stdout.
        if isinstance(pred_frame, np.ndarray):
            print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
            # Share of pixels sitting near mid-gray (127.5).
            mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
            percent_near_127_5 = np.mean(mask_near_127_5) * 100
            print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Close this specific figure even if drawing or saving failed, so
        # figures do not accumulate across calls.
        plt.close(fig)
|
||
|
||
|
||
def evaluate_model(model, data_loader, device, args):
    """Run the model over a data loader and accumulate per-sample metrics.

    Args:
        model: Trained model; it is switched to eval mode and called on each
            batch of input frames.
        data_loader: Iterable yielding ``(input_frames, target_frames,
            temporal_indices)`` batches.
        device: Device the batches are moved to.
        args: Parsed options; reads ``num_samples_to_save``, ``max_samples``
            and ``num_frames``.

    Returns:
        Tuple ``(metrics_dict, sample_results)``. ``metrics_dict`` holds the
        average ``mse``, ``ssim`` and ``psnr`` plus ``num_samples`` (all
        averages are 0.0 when no sample was evaluated). ``sample_results`` is
        a random subset of at most ``args.num_samples_to_save`` per-sample
        records chosen by reservoir sampling.
    """
    model.eval()

    # Metric accumulators.
    total_mse = 0.0
    total_ssim = 0.0
    total_psnr = 0.0
    total_samples = 0

    sample_results = []
    max_samples_to_save = args.num_samples_to_save
    max_samples = args.max_samples

    with torch.no_grad():
        for batch_idx, (input_frames, target_frames, _temporal_indices) in enumerate(data_loader):
            input_frames = input_frames.to(device, non_blocking=True)
            target_frames = target_frames.to(device, non_blocking=True)

            pred_frames = model(input_frames)

            # De-normalize for metric computation.
            # NOTE(review): the original comments give these as [B, 1, H, W] -- confirm.
            pred_denorm = denormalize(pred_frames)
            target_denorm = denormalize(target_frames)

            batch_size = input_frames.size(0)

            for i in range(batch_size):
                # Stop as soon as the evaluation budget is used up.
                if max_samples is not None and total_samples >= max_samples:
                    break

                pred_i = pred_denorm[i]
                target_i = target_denorm[i]

                # Verbose diagnostics for the very first sample only.
                debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
                if debug_mode:
                    print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
                    print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
                    print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
                    print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")

                mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)

                total_mse += mse
                total_ssim += ssim_value
                total_psnr += psnr_value
                total_samples += 1

                # Reservoir sampling: decide first whether this sample is
                # kept, and only then pay for the de-normalization and
                # device-to-host copies needed to build its record.
                if total_samples <= max_samples_to_save:
                    # Reservoir not full yet: always keep.
                    sample_results.append(_build_sample_record(
                        input_frames[i], target_i, pred_i,
                        (mse, ssim_value, psnr_value),
                        args.num_frames, batch_idx, i))
                else:
                    # Replace a random slot with probability
                    # max_samples_to_save / total_samples.
                    r = random.randint(0, total_samples - 1)
                    if r < max_samples_to_save:
                        sample_results[r] = _build_sample_record(
                            input_frames[i], target_i, pred_i,
                            (mse, ssim_value, psnr_value),
                            args.num_frames, batch_idx, i)

            if max_samples is not None and total_samples >= max_samples:
                print(f"达到最大样本数限制: {max_samples}")
                break

            # Progress report every 10 batches.
            if (batch_idx + 1) % 10 == 0:
                print(f'Processed {batch_idx + 1} batches, {total_samples} samples')

    # Average the accumulated metrics.
    if total_samples > 0:
        avg_mse = float(total_mse / total_samples)
        avg_ssim = float(total_ssim / total_samples)
        avg_psnr = float(total_psnr / total_samples)
    else:
        avg_mse = avg_ssim = avg_psnr = 0.0

    metrics_dict = {
        'mse': avg_mse,
        'ssim': avg_ssim,
        'psnr': avg_psnr,
        'num_samples': total_samples
    }

    return metrics_dict, sample_results


def _build_sample_record(sample_inputs, target_i, pred_i, metrics, num_frames,
                         batch_idx, sample_idx):
    """Assemble the CPU-side visualisation record for one evaluated sample."""
    mse, ssim_value, psnr_value = metrics

    # De-normalize the inputs and split them into one array per frame.
    input_denorm = denormalize(sample_inputs)
    input_frames_list = [
        input_denorm[j].squeeze(0).cpu().numpy() for j in range(num_frames)
    ]

    return {
        'input_frames': input_frames_list,
        'target_frame': target_i.squeeze(0).cpu().numpy(),
        'pred_frame': pred_i.squeeze(0).cpu().numpy(),
        'metrics': {
            'mse': mse,
            'ssim': ssim_value,
            'psnr': psnr_value
        },
        'batch_idx': batch_idx,
        'sample_idx': sample_idx
    }
|
||
|
||
|
||
def main(args):
    """Run the complete evaluation: data, model, checkpoint, metrics, reports.

    Builds the validation dataset and loader, instantiates the requested
    model, optionally restores a checkpoint, evaluates the model and writes
    ``evaluation_metrics.json``, per-sample visualisations under
    ``sample_predictions/`` and ``evaluation_report.txt`` into
    ``args.output_dir``.

    Args:
        args: Parsed command-line options (see ``get_args_parser``).

    Raises:
        ValueError: If ``args.model`` is not one of the known model names.
    """
    print("评估参数:", args)

    device = torch.device(args.device)

    # Seed every RNG in use.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    cudnn.benchmark = True

    # Dataset and loader.
    print("构建数据集...")
    dataset_val = VideoFrameDataset(
        root_dir=args.data_path,
        num_frames=args.num_frames,
        frame_size=args.frame_size,
        is_train=False,
        max_interval=args.max_interval
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        shuffle=False,
        drop_last=False
    )

    # Model.
    print(f"创建模型: {args.model}")
    model_builders = {
        'SwiftFormerTemporal_XS': SwiftFormerTemporal_XS,
        'SwiftFormerTemporal_S': SwiftFormerTemporal_S,
        'SwiftFormerTemporal_L1': SwiftFormerTemporal_L1,
        'SwiftFormerTemporal_L3': SwiftFormerTemporal_L3,
    }
    if args.model not in model_builders:
        raise ValueError(f"未知模型: {args.model}")
    model = model_builders[args.model](num_frames=args.num_frames)

    model.to(device)

    # Checkpoint.
    if args.resume:
        print(f"加载检查点: {args.resume}")
        try:
            # weights_only=False is needed on PyTorch 2.6+ for checkpoints
            # that contain more than plain tensors.
            checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
        except (pickle.UnpicklingError, TypeError) as e:
            print(f"使用weights_only=False加载失败: {e}")
            print("尝试使用torch.serialization.add_safe_globals...")
            # Previously this branch only printed and fell through, leaving
            # `checkpoint` undefined. Retry without the keyword (older
            # PyTorch rejects it), allow-listing Namespace where supported.
            add_safe_globals = getattr(torch.serialization, 'add_safe_globals', None)
            if add_safe_globals is not None:
                add_safe_globals([argparse.Namespace])
            checkpoint = torch.load(args.resume, map_location='cpu')

        # The weights may be nested under 'model' or 'state_dict'.
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        elif 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        if hasattr(model, 'module'):
            model.module.load_state_dict(state_dict)
        else:
            # Strip a 'module.' prefix when the keys carry one but the
            # model itself is not wrapped.
            if any(key.startswith('module.') for key in state_dict.keys()):
                from collections import OrderedDict
                state_dict = OrderedDict(
                    (k[7:] if k.startswith('module.') else k, v)
                    for k, v in state_dict.items()
                )
            model.load_state_dict(state_dict)

        print(f"检查点加载成功,epoch: {checkpoint.get('epoch', 'unknown')}")
    else:
        print("警告: 未提供检查点路径,使用随机初始化的模型")

    # Output directory.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Evaluation.
    print("开始评估...")
    metrics, sample_results = evaluate_model(model, data_loader_val, device, args)

    print("\n" + "="*50)
    print("评估结果:")
    print(f"MSE: {metrics['mse']:.6f}")
    print(f"SSIM: {metrics['ssim']:.6f}")
    print(f"PSNR: {metrics['psnr']:.6f} dB")
    print(f"样本数量: {metrics['num_samples']}")
    print("="*50)

    # Metrics as JSON.
    metrics_file = output_dir / 'evaluation_metrics.json'
    with open(metrics_file, 'w', encoding='utf-8') as f:
        json.dump(metrics, f, indent=4)
    print(f"指标已保存到: {metrics_file}")

    # Sample visualisations.
    if sample_results:
        print(f"\n保存 {len(sample_results)} 个示例可视化...")
        samples_dir = output_dir / 'sample_predictions'
        samples_dir.mkdir(exist_ok=True)

        # Frame labels assume consecutive frames and are the same for every sample.
        input_frame_indices = list(range(1, args.num_frames + 1))
        target_frame_index = args.num_frames + 1

        for i, sample in enumerate(sample_results):
            save_path = samples_dir / f'sample_{i:03d}.png'

            save_comparison_figure(
                sample['input_frames'],
                sample['target_frame'],
                sample['pred_frame'],
                save_path,
                input_frame_indices=input_frame_indices,
                target_frame_index=target_frame_index
            )

            # Per-sample metrics next to the image.
            sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
            with open(sample_metrics_file, 'w', encoding='utf-8') as f:
                f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
                f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
                f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
                f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")

        print(f"示例可视化已保存到: {samples_dir}")

    # Summary report. It contains non-ASCII text, so the encoding is pinned
    # rather than left to the platform default.
    report_file = output_dir / 'evaluation_report.txt'
    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("SwiftFormerTemporal 帧预测评估报告\n")
        f.write("="*50 + "\n")
        f.write(f"模型: {args.model}\n")
        f.write(f"检查点: {args.resume}\n")
        f.write(f"数据集: {args.data_path}\n")
        f.write(f"输入帧数: {args.num_frames}\n")
        f.write(f"帧大小: {args.frame_size}\n")
        f.write(f"批次大小: {args.batch_size}\n")
        f.write(f"样本总数: {metrics['num_samples']}\n\n")

        f.write("评估指标:\n")
        f.write(f" MSE: {metrics['mse']:.6f}\n")
        f.write(f" SSIM: {metrics['ssim']:.6f}\n")
        f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")

    print(f"评估报告已保存到: {report_file}")
    print("\n评估完成!")
|
||
|
||
|
||
def get_args_parser():
    """Build the option parser shared by the evaluation entry point.

    The parser is created with ``add_help=False`` so it can be passed as a
    parent to another ``ArgumentParser``.

    Returns:
        An ``argparse.ArgumentParser`` defining the dataset, model,
        evaluation and system options.
    """
    parser = argparse.ArgumentParser(
        'SwiftFormerTemporal 评估脚本', add_help=False)

    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # (flag, add_argument keyword options), in registration order.
    option_specs = [
        # Dataset options
        ('--data-path', dict(default='./videos', type=str,
                             help='视频数据集路径')),
        ('--num-frames', dict(default=3, type=int,
                              help='输入帧数 (T)')),
        ('--frame-size', dict(default=224, type=int,
                              help='输入帧大小')),
        ('--max-interval', dict(default=4, type=int,
                                help='连续帧之间的最大间隔')),
        # Model options
        ('--model', dict(default='SwiftFormerTemporal_XS', type=str,
                         metavar='MODEL', help='要评估的模型名称')),
        # Evaluation options
        ('--batch-size', dict(default=16, type=int,
                              help='评估批次大小')),
        ('--num-samples-to-save', dict(default=10, type=int,
                                       help='保存可视化的样本数量')),
        ('--max-samples', dict(default=None, type=int,
                               help='最大评估样本数(None表示全部)')),
        # System options
        ('--output-dir', dict(default='./evaluation_results',
                              help='保存结果的路径')),
        ('--device', dict(default=default_device,
                          help='使用的设备')),
        ('--seed', dict(default=0, type=int)),
        ('--resume', dict(default='', help='检查点路径')),
        ('--num-workers', dict(default=4, type=int)),
        ('--pin-mem', dict(action='store_true',
                           help='在DataLoader中固定CPU内存')),
        ('--no-pin-mem', dict(action='store_false', dest='pin_mem')),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    # Pinned memory is on unless --no-pin-mem is given.
    parser.set_defaults(pin_mem=True)

    return parser
|
||
|
||
|
||
if __name__ == '__main__':
    # The outer parser supplies -h/--help; the options themselves come from
    # the shared parent parser.
    args = argparse.ArgumentParser(
        'SwiftFormerTemporal 评估', parents=[get_args_parser()]).parse_args()

    # Make sure the output directory exists before the evaluation starts.
    results_dir = args.output_dir
    if results_dir:
        Path(results_dir).mkdir(parents=True, exist_ok=True)

    main(args)