删除残差路径和shortcut,镜像问题仍存在
This commit is contained in:
@@ -45,7 +45,6 @@ def denormalize(tensor):
|
||||
# [0, 1] -> [0, 255]
|
||||
tensor = tensor * 255
|
||||
return tensor.clamp(0, 255)
|
||||
# return tensor
|
||||
|
||||
def minmax_denormalize(tensor):
|
||||
tensor_min = tensor.min()
|
||||
@@ -76,28 +75,16 @@ def calculate_metrics(pred, target, debug=False):
|
||||
if target_np.ndim == 3:
|
||||
target_np = target_np.squeeze(0)
|
||||
|
||||
if debug:
|
||||
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
||||
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
||||
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
||||
# if debug:
|
||||
# print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
||||
# print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
||||
# print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
||||
|
||||
# 计算MSE - 修复错误的tmp公式
|
||||
# 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2
|
||||
# 正确公式: 直接计算像素差的平方
|
||||
mse = np.mean((pred_np - target_np) ** 2)
|
||||
|
||||
# 同时计算错误公式的MSE用于对比
|
||||
tmp = 1 - (pred_np - target_np) / 255 * 2
|
||||
wrong_mse = np.mean(tmp**2)
|
||||
|
||||
if debug:
|
||||
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
|
||||
|
||||
# 计算SSIM (数据范围0-255)
|
||||
data_range = 255.0
|
||||
ssim_value = ssim(pred_np, target_np, data_range=data_range)
|
||||
|
||||
# 计算PSNR
|
||||
psnr_value = psnr(target_np, pred_np, data_range=data_range)
|
||||
|
||||
return mse, ssim_value, psnr_value
|
||||
@@ -146,16 +133,6 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
|
||||
#debug print
|
||||
print(target_frame)
|
||||
print(pred_frame)
|
||||
|
||||
# # debug print - 改进为更有信息量的输出
|
||||
# if isinstance(pred_frame, np.ndarray):
|
||||
# print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
|
||||
# # 检查是否有大量值在127.5附近
|
||||
# mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
|
||||
# percent_near_127_5 = np.mean(mask_near_127_5) * 100
|
||||
# print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
|
||||
# else:
|
||||
# print(pred_frame)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
@@ -216,13 +193,13 @@ def evaluate_model(model, data_loader, device, args):
|
||||
|
||||
# 对第一个样本启用调试
|
||||
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
|
||||
if debug_mode:
|
||||
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
||||
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
||||
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
||||
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
||||
# if debug_mode:
|
||||
# print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
||||
# print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
||||
# print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
||||
# print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
||||
|
||||
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
|
||||
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=False)
|
||||
|
||||
total_mse += mse
|
||||
total_ssim += ssim_value
|
||||
|
||||
Reference in New Issue
Block a user