清理代码,删除跳连接部分
This commit is contained in:
@@ -45,6 +45,15 @@ def denormalize(tensor):
|
||||
# [0, 1] -> [0, 255]
|
||||
tensor = tensor * 255
|
||||
return tensor.clamp(0, 255)
|
||||
# return tensor
|
||||
|
||||
def minmax_denormalize(tensor):
|
||||
tensor_min = tensor.min()
|
||||
tensor_max = tensor.max()
|
||||
tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
|
||||
# tensor = tensor*2-1
|
||||
tensor = tensor*255
|
||||
return tensor.clamp(0, 255)
|
||||
|
||||
|
||||
def calculate_metrics(pred, target, debug=False):
|
||||
@@ -134,6 +143,10 @@ def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
|
||||
ax.set_title('Predicted')
|
||||
ax.axis('off')
|
||||
|
||||
#debug print
|
||||
print(target_frame)
|
||||
print(pred_frame)
|
||||
|
||||
# debug print - 改进为更有信息量的输出
|
||||
if isinstance(pred_frame, np.ndarray):
|
||||
print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
|
||||
@@ -161,8 +174,8 @@ def evaluate_model(model, data_loader, device, args):
|
||||
metrics_dict: 包含所有指标的字典
|
||||
sample_results: 示例结果用于可视化
|
||||
"""
|
||||
# model.eval()
|
||||
model.train() # 临时使用训练模式
|
||||
model.eval()
|
||||
# model.train() # 临时使用训练模式
|
||||
|
||||
# 初始化指标累加器
|
||||
total_mse = 0.0
|
||||
@@ -183,10 +196,11 @@ def evaluate_model(model, data_loader, device, args):
|
||||
target_frames = target_frames.to(device, non_blocking=True)
|
||||
|
||||
# 前向传播
|
||||
pred_frames, _ = model(input_frames)
|
||||
pred_frames = model(input_frames)
|
||||
|
||||
# 反归一化用于指标计算
|
||||
pred_denorm = denormalize(pred_frames) # [B, 1, H, W]
|
||||
# pred_denorm = minmax_denormalize(pred_frames) # [B, 1, H, W]
|
||||
pred_denorm = denormalize(pred_frames)
|
||||
target_denorm = denormalize(target_frames) # [B, 1, H, W]
|
||||
|
||||
batch_size = input_frames.size(0)
|
||||
@@ -309,8 +323,6 @@ def main(args):
|
||||
print(f"创建模型: {args.model}")
|
||||
model_kwargs = {
|
||||
'num_frames': args.num_frames,
|
||||
'use_representation_head': args.use_representation_head,
|
||||
'representation_dim': args.representation_dim,
|
||||
}
|
||||
|
||||
if args.model == 'SwiftFormerTemporal_XS':
|
||||
@@ -335,10 +347,10 @@ def main(args):
|
||||
except (pickle.UnpicklingError, TypeError) as e:
|
||||
print(f"使用weights_only=False加载失败: {e}")
|
||||
print("尝试使用torch.serialization.add_safe_globals...")
|
||||
from argparse import Namespace
|
||||
# 添加安全全局变量
|
||||
torch.serialization.add_safe_globals([Namespace])
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
# from argparse import Namespace
|
||||
# # 添加安全全局变量
|
||||
# torch.serialization.add_safe_globals([Namespace])
|
||||
# checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
|
||||
# 处理状态字典(可能包含'module.'前缀)
|
||||
if 'model' in checkpoint:
|
||||
@@ -462,10 +474,6 @@ def get_args_parser():
|
||||
# 模型参数
|
||||
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
||||
help='要评估的模型名称')
|
||||
parser.add_argument('--use-representation-head', action='store_true',
|
||||
help='使用表示头进行姿态/速度预测')
|
||||
parser.add_argument('--representation-dim', default=128, type=int,
|
||||
help='表示向量的维度')
|
||||
|
||||
# 评估参数
|
||||
parser.add_argument('--batch-size', default=16, type=int,
|
||||
|
||||
Reference in New Issue
Block a user