更新模型结构,大步长反卷积后移,启用BN和tanh

This commit is contained in:
2026-01-15 21:12:27 +08:00
parent df703638da
commit a92a0b29e9
2 changed files with 67 additions and 86 deletions

View File

@@ -147,15 +147,15 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
print(target_frame)
print(pred_frame)
# debug print - 改进为更有信息量的输出
if isinstance(pred_frame, np.ndarray):
print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
# 检查是否有大量值在127.5附近
mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
percent_near_127_5 = np.mean(mask_near_127_5) * 100
print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
else:
print(pred_frame)
# # debug print - 改进为更有信息量的输出
# if isinstance(pred_frame, np.ndarray):
# print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
# # 检查是否有大量值在127.5附近
# mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
# percent_near_127_5 = np.mean(mask_near_127_5) * 100
# print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
# else:
# print(pred_frame)
plt.tight_layout()
plt.savefig(save_path, dpi=150, bbox_inches='tight')
@@ -347,10 +347,6 @@ def main(args):
except (pickle.UnpicklingError, TypeError) as e:
print(f"使用weights_only=False加载失败: {e}")
print("尝试使用torch.serialization.add_safe_globals...")
# from argparse import Namespace
# # 添加安全全局变量
# torch.serialization.add_safe_globals([Namespace])
# checkpoint = torch.load(args.resume, map_location='cpu')
# 处理状态字典(可能包含'module.'前缀)
if 'model' in checkpoint: