From 12de74f130f998cdd3f21f59d8a314cb10fe3a2b Mon Sep 17 00:00:00 2001
From: CaoWangrenbo
Date: Fri, 9 Jan 2026 18:23:45 +0800
Subject: [PATCH] Refine the skip connections and add a feature-refinement
 layer after the upsampling decoder blocks; effect not yet tested
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 dist_temporal_train.sh         |   5 +-
 evaluate_temporal.py           | 503 +++++++++++++++++++++++++++++++++
 main_temporal.py               |  86 +++++-
 models/swiftformer_temporal.py | 383 ++++++++++++++++++++-----
 multi_gpu_temporal_train.sh    |  26 --
 test_cuda.py                   |  45 ---
 test_import.py                 |  33 ---
 test_model.py                  |  60 ----
 8 files changed, 895 insertions(+), 246 deletions(-)
 create mode 100644 evaluate_temporal.py
 delete mode 100755 multi_gpu_temporal_train.sh
 delete mode 100644 test_cuda.py
 delete mode 100644 test_import.py
 delete mode 100644 test_model.py

diff --git a/dist_temporal_train.sh b/dist_temporal_train.sh
index ce10ceb..ac87d03 100755
--- a/dist_temporal_train.sh
+++ b/dist_temporal_train.sh
@@ -11,9 +11,10 @@ shift 2
 
 # Default parameters
 MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
-BATCH_SIZE=${BATCH_SIZE:-256}
+BATCH_SIZE=${BATCH_SIZE:-128}
 EPOCHS=${EPOCHS:-100}
-LR=${LR:-1e-3}
+# LR=${LR:-1e-3}
+LR=${LR:-0.01}
 OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
 
 echo "Starting distributed training with $NUM_GPUS GPUs"
diff --git a/evaluate_temporal.py b/evaluate_temporal.py
new file mode 100644
index 0000000..3c7b2ae
--- /dev/null
+++ b/evaluate_temporal.py
@@ -0,0 +1,503 @@
+"""
+Evaluation script for SwiftFormerTemporal frame prediction.
+Saves the predicted frames (mind the de-normalization) together with the matching metrics (MSE, SSIM, PSNR).
+"""
+import argparse
+import os
+import torch
+import torch.nn as nn
+import pickle
+import numpy as np
+import random
+from pathlib import Path
+import json
+import matplotlib.pyplot as plt
+from PIL import Image
+import torchvision.transforms as transforms
+import torch.backends.cudnn as cudnn
+
+from util.video_dataset import VideoFrameDataset
+from models.swiftformer_temporal import (
+    SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
+    SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
+)
+
+# SSIM and PSNR metrics
+from skimage.metrics import structural_similarity as ssim
+from skimage.metrics import peak_signal_noise_ratio as psnr
+import warnings
+warnings.filterwarnings('ignore')
+
+
+def denormalize(tensor):
+    """
+    De-normalize a tensor from the [-1, 1] range back to [0, 255].
+    Args:
+        tensor: shape [B, C, H, W] or [C, H, W], values in [-1, 1]
+    Returns:
+        The de-normalized tensor, values in [0, 255]
+    """
+    # Clamp to [-1, 1]
+    tensor = tensor.clamp(-1, 1)
+
+    # [-1, 1] -> [0, 1]
+    tensor = (tensor + 1) / 2
+    # [0, 1] -> [0, 255]
+    tensor = tensor * 255
+    return tensor.clamp(0, 255)
+
+
+def calculate_metrics(pred, target, debug=False):
+    """
+    Compute the MSE, SSIM and PSNR metrics.
+    Args:
+        pred: predicted image, shape [H, W], values in [0, 255]
+        target: target image, shape [H, W], values in [0, 255]
+        debug: whether to print debug information
+    Returns:
+        mse, ssim_value, psnr_value
+    """
+    # Convert to numpy arrays
+    pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
+    target_np = target.cpu().numpy() if torch.is_tensor(target) else target
+
+    # Make sure they are 2D arrays
+    if pred_np.ndim == 3:
+        pred_np = pred_np.squeeze(0)
+    if target_np.ndim == 3:
+        target_np = target_np.squeeze(0)
+
+    if debug:
+        print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
+        print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
+        print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
+
+    # MSE -- replaces the earlier, broken "tmp" formula
+    # broken formula: tmp = 1 - (pred_np - target_np) / 255 * 2
+    # correct formula: mean of the squared pixel differences
+    mse = np.mean((pred_np - target_np) ** 2)
+
+    # Also compute the MSE from the broken formula, for comparison
+    tmp = 1 - (pred_np - target_np) / 255 * 2
+    wrong_mse = np.mean(tmp**2)
+
+    if debug:
+        print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
+
+    # SSIM (data range 0-255)
+    data_range = 255.0
+    ssim_value = ssim(pred_np, target_np, data_range=data_range)
+
+    # PSNR
+    psnr_value = psnr(target_np, pred_np, data_range=data_range)
+
+    return mse, ssim_value, psnr_value
+
+
+def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
+                           input_frame_indices=None, target_frame_index=None):
+    """
+    Save a comparison figure: input frames, target frame, predicted frame.
+    Args:
+        input_frames: list of input frames, each of shape [H, W], values in [0, 255]
+        target_frame: target frame, shape [H, W], values in [0, 255]
+        pred_frame: predicted frame, shape [H, W], values in [0, 255]
+        save_path: output path
+        input_frame_indices: indices of the input frames (optional)
+        target_frame_index: index of the target frame (optional)
+    """
+    num_input = len(input_frames)
+    fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
+
+    # Plot the input frames
+    for i in range(num_input):
+        ax = axes[i]
+        ax.imshow(input_frames[i], cmap='gray')
+        if input_frame_indices is not None:
+            ax.set_title(f'Input Frame {input_frame_indices[i]}')
+        else:
+            ax.set_title(f'Input {i+1}')
+        ax.axis('off')
+
+    # Plot the target frame
+    ax = axes[num_input]
+    ax.imshow(target_frame, cmap='gray')
+    if target_frame_index is not None:
+        ax.set_title(f'Target Frame {target_frame_index}')
+    else:
+        ax.set_title('Target')
+    ax.axis('off')
+
+    # Plot the predicted frame
+    ax = axes[num_input + 1]
+    ax.imshow(pred_frame, cmap='gray')
+    ax.set_title('Predicted')
+    ax.axis('off')
+
+    # Debug print -- made more informative
+    if isinstance(pred_frame, np.ndarray):
+        print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
+        # Check whether many values sit near 127.5 (i.e. a flat, gray output)
+        mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
+        percent_near_127_5 = np.mean(mask_near_127_5) * 100
+        print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
+    else:
+        print(pred_frame)
+
+    plt.tight_layout()
+    plt.savefig(save_path, dpi=150, bbox_inches='tight')
+    plt.close()
+
+
+def evaluate_model(model, data_loader, device, args):
+    """
+    Evaluate the model and compute the metrics.
+    Args:
+        model: the trained model
+        data_loader: data loader
+        device: device
+        args: command-line arguments
+    Returns:
+        metrics_dict: dictionary holding all metrics
+        sample_results: sample results for visualization
+    """
+    # model.eval()
+    model.train()  # Temporary: run in training mode
+
+    # Metric accumulators
+    total_mse = 0.0
+    total_ssim = 0.0
+    total_psnr = 0.0
+    total_samples = 0
+
+    # Sample results kept for visualization (chosen at random via reservoir sampling)
+    sample_results = []
+    max_samples_to_save = args.num_samples_to_save
+    max_samples = args.max_samples
+    # Counter for reservoir sampling (samples actually processed, excluding any skipped by the max_samples cap)
+    sample_count = 0
+
+    with torch.no_grad():
+        for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
+            input_frames = input_frames.to(device, non_blocking=True)
+            target_frames = target_frames.to(device, non_blocking=True)
+
+            # Forward pass
+            pred_frames, _ = model(input_frames)
+
+            # De-normalize for the metric computation
+            pred_denorm = denormalize(pred_frames)  # [B, 1, H, W]
+            target_denorm = denormalize(target_frames)  # [B, 1, H, W]
+
+            batch_size = input_frames.size(0)
+
+            # Per-sample metrics
+            for i in range(batch_size):
+                # Stop once the sample cap is reached
+                if max_samples is not None and total_samples >= max_samples:
+                    break
+
+                pred_i = pred_denorm[i]  # [1, H, W]
+                target_i = target_denorm[i]  # [1, H, W]
+
+                # Enable debugging for the very first sample
+                debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
+                if debug_mode:
+                    print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
+                    print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
+                    print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
+                    print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
+
+                mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
+
+                total_mse += mse
+                total_ssim += ssim_value
+                total_psnr += psnr_value
+                total_samples += 1
+                sample_count += 1
+
+                # Build the sample record
+                input_denorm = denormalize(input_frames[i])  # [num_frames, H, W]
+                # Split out the individual input frames
+                input_frames_list = []
+                for j in range(args.num_frames):
+                    input_frame_j = input_denorm[j].squeeze(0)  # [H, W]
+                    input_frames_list.append(input_frame_j.cpu().numpy())
+
+                sample_data = {
+                    'input_frames': input_frames_list,
+                    'target_frame': target_i.squeeze(0).cpu().numpy(),
+                    'pred_frame': pred_i.squeeze(0).cpu().numpy(),
+                    'metrics': {
+                        'mse': mse,
+                        'ssim': ssim_value,
+                        'psnr': psnr_value
+                    },
+                    'batch_idx': batch_idx,
+                    'sample_idx': i
+                }
+
+                # Reservoir sampling
+                if sample_count <= max_samples_to_save:
+                    # Reservoir not full yet: append directly
+                    sample_results.append(sample_data)
+                else:
+                    # Replace a random reservoir slot with probability max_samples_to_save / sample_count
+                    r = random.randint(0, sample_count - 1)
+                    if r < max_samples_to_save:
+                        sample_results[r] = sample_data
+
+            # Stop once the sample cap is reached
+            if max_samples is not None and total_samples >= max_samples:
+                print(f"Reached the sample cap: {max_samples}")
+                break
+
+            # Progress
+            if (batch_idx + 1) % 10 == 0:
+                print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
+
+    # Average metrics
+    if total_samples > 0:
+        avg_mse = float(total_mse / total_samples)
+        avg_ssim = float(total_ssim / total_samples)
+        avg_psnr = float(total_psnr / total_samples)
+    else:
+        avg_mse = avg_ssim = avg_psnr = 0.0
+
+    metrics_dict = {
+        'mse': avg_mse,
+        'ssim': avg_ssim,
+        'psnr': avg_psnr,
+        'num_samples': total_samples
+    }
+
+    return metrics_dict, sample_results
+
+
+def main(args):
+    print("Evaluation args:", args)
+
+    device = torch.device(args.device)
+
+    # Seed the RNGs
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    random.seed(args.seed)
+
+    cudnn.benchmark = True
+
+    # Build the dataset
+    print("Building dataset...")
+    dataset_val = VideoFrameDataset(
+        root_dir=args.data_path,
+        num_frames=args.num_frames,
+        frame_size=args.frame_size,
+        is_train=False,
+        max_interval=args.max_interval
+    )
+
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset_val,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        pin_memory=args.pin_mem,
+        shuffle=False,
+        drop_last=False
+    )
+
+    # Create the model
+    print(f"Creating model: {args.model}")
+    model_kwargs = {
+        'num_frames': args.num_frames,
+        'use_representation_head': args.use_representation_head,
+        'representation_dim': args.representation_dim,
+    }
+
+    if args.model == 'SwiftFormerTemporal_XS':
+        model = SwiftFormerTemporal_XS(**model_kwargs)
+    elif args.model == 'SwiftFormerTemporal_S':
+        model = SwiftFormerTemporal_S(**model_kwargs)
+    elif args.model == 'SwiftFormerTemporal_L1':
+        model = SwiftFormerTemporal_L1(**model_kwargs)
+    elif args.model == 'SwiftFormerTemporal_L3':
+        model = SwiftFormerTemporal_L3(**model_kwargs)
+    else:
+        raise ValueError(f"Unknown model: {args.model}")
+
+    model.to(device)
+
+    # Load the checkpoint
+    if args.resume:
+        print(f"Loading checkpoint: {args.resume}")
+        try:
+            # Try weights_only=False first (needed on PyTorch 2.6+)
+            checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
+        except (pickle.UnpicklingError, TypeError) as e:
+            print(f"Loading with weights_only=False failed: {e}")
+            print("Falling back to torch.serialization.add_safe_globals...")
+            from argparse import Namespace
+            # Whitelist the globals needed by the safe unpickler
+            torch.serialization.add_safe_globals([Namespace])
+            checkpoint = torch.load(args.resume, map_location='cpu')
+
+        # Unpack the state dict (it may carry a 'module.' prefix)
+        if 'model' in checkpoint:
+            state_dict = checkpoint['model']
+        elif 'state_dict' in checkpoint:
+            state_dict = checkpoint['state_dict']
+        else:
+            state_dict = checkpoint
+
+        # Strip the 'module.' prefix if present
+        if hasattr(model, 'module'):
+            model.module.load_state_dict(state_dict)
+        else:
+            # The state dict carries a 'module.' prefix but the model does not; strip it
+            if any(key.startswith('module.') for key in state_dict.keys()):
+                from collections import OrderedDict
+                new_state_dict = OrderedDict()
+                for k, v in state_dict.items():
+                    if k.startswith('module.'):
+                        new_state_dict[k[7:]] = v
+                    else:
+                        new_state_dict[k] = v
+                state_dict = new_state_dict
+            model.load_state_dict(state_dict)
+
+        print(f"Checkpoint loaded, epoch: {checkpoint.get('epoch', 'unknown')}")
+    else:
+        print("Warning: no checkpoint given; using a randomly initialized model")
+
+    # Create the output directory
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    # Evaluate
+    print("Starting evaluation...")
+    metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
+
+    # Print the metrics
+    print("\n" + "="*50)
+    print("Evaluation results:")
+    print(f"MSE: {metrics['mse']:.6f}")
+    print(f"SSIM: {metrics['ssim']:.6f}")
+    print(f"PSNR: {metrics['psnr']:.6f} dB")
+    print(f"Number of samples: {metrics['num_samples']}")
+    print("="*50)
+
+    # Save the metrics to a JSON file
+    metrics_file = output_dir / 'evaluation_metrics.json'
+    with open(metrics_file, 'w') as f:
+        json.dump(metrics, f, indent=4)
+    print(f"Metrics saved to: {metrics_file}")
+
+    # Save the sample visualizations
+    if sample_results:
+        print(f"\nSaving {len(sample_results)} sample visualizations...")
+        samples_dir = output_dir / 'sample_predictions'
+        samples_dir.mkdir(exist_ok=True)
+
+        for i, sample in enumerate(sample_results):
+            save_path = samples_dir / f'sample_{i:03d}.png'
+
+            # Input frame indices (assumed consecutive)
+            input_frame_indices = list(range(1, args.num_frames + 1))
+            target_frame_index = args.num_frames + 1
+
+            save_comparison_figure(
+                sample['input_frames'],
+                sample['target_frame'],
+                sample['pred_frame'],
+                save_path,
+                input_frame_indices=input_frame_indices,
+                target_frame_index=target_frame_index
+            )
+
+            # Save this sample's metrics
+            sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
+            with open(sample_metrics_file, 'w') as f:
+                f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
+                f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
+                f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
+                f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
+
+        print(f"Sample visualizations saved to: {samples_dir}")
+
+    # Write the summary report
+    report_file = output_dir / 'evaluation_report.txt'
+    with open(report_file, 'w') as f:
+        f.write("SwiftFormerTemporal Frame Prediction Evaluation Report\n")
+        f.write("="*50 + "\n")
+        f.write(f"Model: {args.model}\n")
+        f.write(f"Checkpoint: {args.resume}\n")
+        f.write(f"Dataset: {args.data_path}\n")
+        f.write(f"Input frames: {args.num_frames}\n")
+        f.write(f"Frame size: {args.frame_size}\n")
+        f.write(f"Batch size: {args.batch_size}\n")
+        f.write(f"Total samples: {metrics['num_samples']}\n\n")
+
+        f.write("Metrics:\n")
+        f.write(f"  MSE: {metrics['mse']:.6f}\n")
+        f.write(f"  SSIM: {metrics['ssim']:.6f}\n")
+        f.write(f"  PSNR: {metrics['psnr']:.6f} dB\n")
+
+    print(f"Evaluation report saved to: {report_file}")
+    print("\nEvaluation finished!")
+
+
+def get_args_parser():
+    parser = argparse.ArgumentParser(
+        'SwiftFormerTemporal evaluation script', add_help=False)
+
+    # Dataset parameters
+    parser.add_argument('--data-path', default='./videos', type=str,
+                        help='path to the video dataset')
+    parser.add_argument('--num-frames', default=3, type=int,
+                        help='number of input frames (T)')
+    parser.add_argument('--frame-size', default=224, type=int,
+                        help='input frame size')
+    parser.add_argument('--max-interval', default=4, type=int,
+                        help='maximum interval between consecutive frames')
+
+    # Model parameters
+    parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
+                        help='name of the model to evaluate')
+    parser.add_argument('--use-representation-head', action='store_true',
+                        help='use the representation head for pose/velocity prediction')
+    parser.add_argument('--representation-dim', default=128, type=int,
+                        help='dimension of the representation vector')
+
+    # Evaluation parameters
+    parser.add_argument('--batch-size', default=16, type=int,
+                        help='evaluation batch size')
+    parser.add_argument('--num-samples-to-save', default=10, type=int,
+                        help='number of samples to save as visualizations')
+    parser.add_argument('--max-samples', default=None, type=int,
+                        help='maximum number of samples to evaluate (None = all)')
+
+    # System parameters
+    parser.add_argument('--output-dir', default='./evaluation_results',
+                        help='where to save the results')
+    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
+                        help='device to use')
+    parser.add_argument('--seed', default=0, type=int)
+    parser.add_argument('--resume', default='', help='checkpoint path')
+    parser.add_argument('--num-workers', default=4, type=int)
+    parser.add_argument('--pin-mem', action='store_true',
+                        help='pin CPU memory in the DataLoader')
+    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
+    parser.set_defaults(pin_mem=True)
+
+    return parser
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        'SwiftFormerTemporal evaluation', parents=[get_args_parser()])
+    args = parser.parse_args()
+
+    # Make sure the output directory exists
+    if args.output_dir:
+        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+    main(args)
\ No newline at end of file
diff --git a/main_temporal.py b/main_temporal.py
index d1fab34..57cfb3c 100644
--- a/main_temporal.py
+++ b/main_temporal.py
@@ -57,6 +57,7 @@ def get_args_parser():
                         help='Use representation head for pose/velocity prediction')
     parser.add_argument('--representation-dim', default=128, type=int,
                         help='Dimension of representation vector')
+    parser.add_argument('--use-skip', default=True, action=argparse.BooleanOptionalAction, help='use skip connections in the decoder')
 
     # Training parameters
     parser.add_argument('--batch-size', default=32, type=int)
@@ -77,7 +78,7 @@ def get_args_parser():
                         help='SGD momentum (default: 0.9)')
     parser.add_argument('--weight-decay', type=float, default=0.05,
                         help='weight decay (default: 0.05)')
-    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
-                        help='learning rate (default: 1e-3)')
+    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
+                        help='learning rate (default: 0.1)')
 
     # Learning rate schedule parameters (required by timm's create_scheduler)
@@ -89,7 +90,7 @@ def get_args_parser():
                         help='learning rate noise limit percent (default: 0.67)')
     parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                         help='learning rate noise std-dev (default: 1.0)')
-    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
-                        help='warmup learning rate (default: 1e-6)')
+    parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
+                        help='warmup learning rate (default: 1e-3)')
     parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                         help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
@@ -212,6 +213,7 @@ def main(args):
         'num_frames': args.num_frames,
         'use_representation_head': args.use_representation_head,
         'representation_dim': args.representation_dim,
+        'use_skip': args.use_skip,
     }
 
     if args.model == 'SwiftFormerTemporal_XS':
@@ -373,6 +375,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
     header = f'Epoch: [{epoch}]'
     print_freq = 10
+
+    # Diagnostic meters
+    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
 
     for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
             metric_logger.log_every(data_loader, print_freq, header)):
@@ -382,7 +389,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         temporal_indices = temporal_indices.to(device, non_blocking=True)
 
         # Forward pass
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(device_type='cuda'):
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
@@ -395,6 +402,8 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
             raise ValueError(f"Loss is {loss_value}")
 
         optimizer.zero_grad()
+
+        # The backward pass below populates the gradients inspected for diagnostics
         loss_scaler(loss, optimizer, clip_grad=clip_grad,
                     clip_mode=clip_mode,
                     parameters=model.parameters())
@@ -402,6 +411,30 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
         if model_ema is not None:
             model_ema.update(model)
 
+        # Diagnostics: prediction statistics
+        pred_mean = pred_frames.mean().item()
+        pred_std = pred_frames.std().item()
+
+        # Global gradient norm: sqrt of the summed squared per-parameter norms
+        total_grad_norm = sum(
+            p.grad.norm().item() ** 2
+            for p in model.parameters()
+            if p.grad is not None) ** 0.5
+
+        # Record the diagnostics
+        metric_logger.update(pred_mean=pred_mean)
+        metric_logger.update(pred_std=pred_std)
+        metric_logger.update(grad_norm=total_grad_norm)
+
+        # Print BatchNorm statistics every 50 batches
+        if batch_idx % 50 == 0:
+            print(f"[diag] batch {batch_idx}: pred mean={pred_mean:.4f}, pred std={pred_std:.4f}, grad norm={total_grad_norm:.4f}")
+            # Inspect one BatchNorm layer's running statistics
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+                    print(f"[diag] {name}: running mean={module.running_mean[0].item():.6f}, running var={module.running_var[0].item():.6f}")
+                    break
+
         # Log to TensorBoard
         if writer is not None:
             # Log scalar metrics every iteration
@@ -415,6 +448,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
             else:
                 writer.add_scalar(f'train/{k}', v, global_step)
 
+            # Log diagnostic metrics
+            writer.add_scalar('train/pred_mean', pred_mean, global_step)
+            writer.add_scalar('train/pred_std', pred_std, global_step)
+            writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
+
         # Log images periodically
         if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
             with torch.no_grad():
@@ -450,20 +488,54 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = 'Test:'
+
+    # Diagnostic meters
+    metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+    metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
 
-    for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
+    for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
         input_frames = input_frames.to(device, non_blocking=True)
         target_frames = target_frames.to(device, non_blocking=True)
         temporal_indices = temporal_indices.to(device, non_blocking=True)
 
         # Compute output
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(device_type='cuda'):
             pred_frames, representations = model(input_frames)
             loss, loss_dict = criterion(
                 pred_frames, target_frames,
                 representations, temporal_indices
             )
 
+        # Diagnostics: prediction vs. target statistics
+        pred_mean = pred_frames.mean().item()
+        pred_std = pred_frames.std().item()
+        target_mean = target_frames.mean().item()
+        target_std = target_frames.std().item()
+
+        # Record the diagnostics
+        metric_logger.update(pred_mean=pred_mean)
+        metric_logger.update(pred_std=pred_std)
+        metric_logger.update(target_mean=target_mean)
+        metric_logger.update(target_std=target_std)
+
+        # Print detailed diagnostics for the first batch
+        if batch_idx == 0:
+            print("[eval diag] batch 0:")
+            print(f"  pred range: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
+            print(f"  pred mean: {pred_mean:.4f}, pred std: {pred_std:.4f}")
+            print(f"  target range: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
+            print(f"  target mean: {target_mean:.4f}, target std: {target_std:.4f}")
+
+            # Inspect the BatchNorm running statistics
+            for name, module in model.named_modules():
+                if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
+                    print(f"  {name}: running mean={module.running_mean[0].item():.6f}, running var={module.running_var[0].item():.6f}")
+                    if module.running_var[0].item() < 1e-6:
+                        print("  WARNING: BatchNorm running variance is close to zero!")
+                    break
+
         # Update metrics
         metric_logger.update(loss=loss.item())
         for k, v in loss_dict.items():
diff --git a/models/swiftformer_temporal.py b/models/swiftformer_temporal.py
index 3cac757..e91b569 100644
--- a/models/swiftformer_temporal.py
+++ b/models/swiftformer_temporal.py
@@ -11,26 +11,188 @@ from timm.layers import DropPath, trunc_normal_
 
 class DecoderBlock(nn.Module):
-    """Upsampling block for frame prediction decoder"""
+    """Upsampling block for frame prediction decoder with residual connections"""
     def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
         super().__init__()
-        self.conv = nn.ConvTranspose2d(
+        # Main path: transposed conv + two conv layers
+        self.conv_transpose = nn.ConvTranspose2d(
             in_channels, out_channels,
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
             output_padding=output_padding,
-            bias=False
+            bias=True  # bias enabled, since BN was removed
         )
-        self.bn = nn.BatchNorm2d(out_channels)
-        self.relu = nn.ReLU(inplace=True)
+        self.conv1 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)
+
+        # Residual path: needed when the channel count or spatial size changes
+        self.shortcut = nn.Identity()
+        if in_channels != out_channels or stride != 1:
+            # 1x1 conv to adjust channels; a transposed conv when upsampling is needed
+            if stride == 1:
+                self.shortcut = nn.Conv2d(in_channels, out_channels,
+                                          kernel_size=1, bias=True)
+            else:
+                self.shortcut = nn.ConvTranspose2d(
+                    in_channels, out_channels,
+                    kernel_size=1,
+                    stride=stride,
+                    padding=0,
+                    output_padding=output_padding,
+                    bias=True
+                )
+
+        # LeakyReLU to avoid dead neurons
+        self.activation = nn.LeakyReLU(0.2, inplace=True)
+
+        # Initialize the weights
+        self._init_weights()
+
+    def _init_weights(self):
+        # Transposed conv
+        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv_transpose.bias is not None:
+            nn.init.constant_(self.conv_transpose.bias, 0)
+
+        # Conv layers
+        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv1.bias is not None:
+            nn.init.constant_(self.conv1.bias, 0)
+
+        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv2.bias is not None:
+            nn.init.constant_(self.conv2.bias, 0)
+
+        # Shortcut
+        if not isinstance(self.shortcut, nn.Identity):
+            if isinstance(self.shortcut, nn.Conv2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            elif isinstance(self.shortcut, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            if self.shortcut.bias is not None:
+                nn.init.constant_(self.shortcut.bias, 0)
 
     def forward(self, x):
-        return self.relu(self.bn(self.conv(x)))
+        identity = self.shortcut(x)
+
+        # Main path
+        x = self.conv_transpose(x)
+        x = self.activation(x)
+
+        x = self.conv1(x)
+        x = self.activation(x)
+
+        x = self.conv2(x)
+
+        # Residual connection
+        x = x + identity
+        x = self.activation(x)
+        return x
+
+
+class DecoderBlockWithSkip(nn.Module):
+    """Decoder block with skip connection support"""
+    def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
+        super().__init__()
+        # Total input channels = input channels + skip channels
+        total_in_channels = in_channels + skip_channels
+
+        # Main path: transposed conv + two conv layers
+        self.conv_transpose = nn.ConvTranspose2d(
+            total_in_channels, out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            bias=True
+        )
+        self.conv1 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)
+        self.conv2 = nn.Conv2d(out_channels, out_channels,
+                               kernel_size=3, padding=1, bias=True)
+
+        # Residual path: needed when the channel count or spatial size changes
+        self.shortcut = nn.Identity()
+        if total_in_channels != out_channels or stride != 1:
+            if stride == 1:
+                self.shortcut = nn.Conv2d(total_in_channels, out_channels,
+                                          kernel_size=1, bias=True)
+            else:
+                self.shortcut = nn.ConvTranspose2d(
+                    total_in_channels, out_channels,
+                    kernel_size=1,
+                    stride=stride,
+                    padding=0,
+                    output_padding=output_padding,
+                    bias=True
+                )
+
+        # LeakyReLU to avoid dead neurons
+        self.activation = nn.LeakyReLU(0.2, inplace=True)
+
+        # Initialize the weights
+        self._init_weights()
+
+    def _init_weights(self):
+        # Transposed conv
+        nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv_transpose.bias is not None:
+            nn.init.constant_(self.conv_transpose.bias, 0)
+
+        # Conv layers
+        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv1.bias is not None:
+            nn.init.constant_(self.conv1.bias, 0)
+
+        nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
+        if self.conv2.bias is not None:
+            nn.init.constant_(self.conv2.bias, 0)
+
+        # Shortcut
+        if not isinstance(self.shortcut, nn.Identity):
+            if isinstance(self.shortcut, nn.Conv2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            elif isinstance(self.shortcut, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
+            if self.shortcut.bias is not None:
+                nn.init.constant_(self.shortcut.bias, 0)
+
+    def forward(self, x, skip_feature=None):
+        # Concatenate the skip feature with the input, if given
+        if skip_feature is not None:
+            # Make sure the skip feature matches x spatially
+            if skip_feature.shape[2:] != x.shape[2:]:
+                # Bilinear interpolation up- or down-samples as needed
+                skip_feature = torch.nn.functional.interpolate(
+                    skip_feature,
+                    size=x.shape[2:],
+                    mode='bilinear',
+                    align_corners=False
+                )
+            x = torch.cat([x, skip_feature], dim=1)
+
+        identity = self.shortcut(x)
+
+        # Main path
+        x = self.conv_transpose(x)
+        x = self.activation(x)
+
+        x = self.conv1(x)
+        x = self.activation(x)
+
+        x = self.conv2(x)
+
+        # Residual connection
+        x = x + identity
+        x = self.activation(x)
+        return x
 
 
 class FramePredictionDecoder(nn.Module):
-    """Lightweight decoder for frame prediction with optional skip connections"""
+    """Improved decoder for frame prediction with better upsampling strategy"""
     def __init__(self, embed_dims, output_channels=1, use_skip=False):
         super().__init__()
         self.use_skip = use_skip
@@ -38,65 +200,109 @@ class FramePredictionDecoder(nn.Module):
         decoder_dims = embed_dims[::-1]
 
         self.blocks = nn.ModuleList()
-        # First upsampling from bottleneck to stage4 resolution
-        self.blocks.append(DecoderBlock(
-            decoder_dims[0], decoder_dims[1],
-            kernel_size=3, stride=2, padding=1, output_padding=1
-        ))
-        # stage4 to stage3
-        self.blocks.append(DecoderBlock(
-            decoder_dims[1], decoder_dims[2],
-            kernel_size=3, stride=2, padding=1, output_padding=1
-        ))
-        # stage3 to stage2
-        self.blocks.append(DecoderBlock(
-            decoder_dims[2], decoder_dims[3],
-            kernel_size=3, stride=2, padding=1, output_padding=1
-        ))
-        # stage2 to original resolution (now 8x upsampling total with stride 4)
-        self.blocks.append(nn.Sequential(
-            nn.ConvTranspose2d(
-                decoder_dims[3], 32,
-                kernel_size=3, stride=4, padding=1, output_padding=3
-            ),
-            nn.BatchNorm2d(32),
-            nn.ReLU(inplace=True),
-            nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
-            nn.Tanh()  # Output in [-1, 1] range
-        ))
 
-        # If using skip connections, we need to adjust input channels for each block
         if use_skip:
-            # We'll modify the first three blocks to accept concatenated features
-            # Instead of modifying existing blocks, we'll replace them with custom blocks
-            # For simplicity, we'll keep the same architecture but forward will handle concatenation
-            pass
+            # Blocks that accept skip connections
+            # Block 1: bottleneck to stage4 resolution, large stride=4; skip from stage3
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[0], decoder_dims[1],
+                skip_channels=embed_dims[3],  # stage3 channel count
+                kernel_size=3, stride=4, padding=1, output_padding=3  # now stride=4
+            ))
+            # Block 2: stage4 to stage3, stride=2; skip from stage2
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[1], decoder_dims[2],
+                skip_channels=embed_dims[2],  # stage2 channel count
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            # Block 3: stage3 to stage2, stride=2; skip from stage1
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[2], decoder_dims[3],
+                skip_channels=embed_dims[1],  # stage1 channel count
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            # Block 4: stage2 to stage1, stride=2; skip from stage0
+            self.blocks.append(DecoderBlockWithSkip(
+                decoder_dims[3], 64,  # project down to 64 channels
+                skip_channels=embed_dims[0],  # stage0 channel count
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+        else:
+            # Plain DecoderBlocks; the first block uses the large stride
+            self.blocks.append(DecoderBlock(
+                decoder_dims[0], decoder_dims[1],
+                kernel_size=3, stride=4, padding=1, output_padding=3  # now stride=4
+            ))
+            self.blocks.append(DecoderBlock(
+                decoder_dims[1], decoder_dims[2],
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            self.blocks.append(DecoderBlock(
+                decoder_dims[2], decoder_dims[3],
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+            # Block 4: out to 64 channels
+            self.blocks.append(DecoderBlock(
+                decoder_dims[3], 64,
+                kernel_size=3, stride=2, padding=1, output_padding=1
+            ))
+
+        # Improved final output layer: no transposed conv, feature refinement only
+        # The input is already at the target resolution; only adjust channels and fuse features
+        self.final_block = nn.Sequential(
+            nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
+            nn.LeakyReLU(0.2, inplace=True),
+            nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
+            # Tanh removed: the output range is left unconstrained and handled by the loss and normalization
+        )
 
     def forward(self, x, skip_features=None):
         """
         Args:
             x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
-            skip_features: list of encoder features from stages [stage2, stage1, stage0]
-                          each of shape [B, C, H', W'] where C matches decoder dims?
+            skip_features: list of encoder features from stages [stage3, stage2, stage1, stage0]
+                          each of shape [B, C, H', W'] where C matches encoder dims
         """
-        if self.use_skip and skip_features is not None:
-            # Ensure we have exactly 3 skip features (for the first three blocks)
-            assert len(skip_features) == 3, "Need 3 skip features for skip connections"
-            # Reverse skip_features to match decoder order: stage2, stage1, stage0
-            # skip_features[0] should be stage2 (H/16), [1] stage1 (H/8), [2] stage0 (H/4)
-            skip_features = skip_features[::-1]  # Now index 0: stage2, 1: stage1, 2: stage0
+        if self.use_skip:
+            if skip_features is None:
+                raise ValueError("skip_features must be provided when use_skip=True")
+
+            # Exactly four skip features are expected
+            assert len(skip_features) == 4, f"Need 4 skip features, got {len(skip_features)}"
+
+            # Reverse to match the decoder order: stage3, stage2, stage1, stage0
+            skip_features = skip_features[::-1]
+
+            # Resize the skip features to match the new upsampling schedule
+            adjusted_skip_features = []
+            for i, skip in enumerate(skip_features):
+                if skip is not None:
+                    # Target size = each block's input resolution: x1, x4, x8, x16 of the bottleneck
+                    upsample_factors = [1, 4, 8, 16]
+                    target_height = x.shape[2] * upsample_factors[i]
+                    target_width = x.shape[3] * upsample_factors[i]
+
+                    if skip.shape[2:] != (target_height, target_width):
+                        skip = torch.nn.functional.interpolate(
+                            skip,
+                            size=(target_height, target_width),
+                            mode='bilinear',
+                            align_corners=False
+                        )
+                adjusted_skip_features.append(skip)
+
+            # All four blocks consume skip connections
+            for i in range(4):
+                x = self.blocks[i](x, adjusted_skip_features[i])
+        else:
+            # No skip connections
+            for i in range(4):
+                x = self.blocks[i](x)
 
-        for i, block in enumerate(self.blocks):
-            if self.use_skip and skip_features is not None and i < 3:
-                # Concatenate skip feature along channel dimension
-                # Ensure spatial dimensions match (they should because of upsampling)
-                x = torch.cat([x, skip_features[i]], dim=1)
-                # Need to adjust block to accept extra channels? We'll create a separate block.
-                # For now, we'll just pass through, but this will cause channel mismatch.
-                # Instead, we should have created custom blocks with appropriate in_channels.
-                # This is a placeholder; we need to implement properly.
-                pass
-            x = block(x)
+        # Final output layer: feature refinement only, no upsampling
+        x = self.final_block(x)
 
         return x
@@ -106,10 +312,11 @@ class SwiftFormerTemporal(nn.Module):
     Input: [B, num_frames, H, W] (Y channel only)
     Output: predicted frame [B, 1, H, W] and optional representation
     """
-    def __init__(self, 
+    def __init__(self,
                  model_name='XS',
                  num_frames=3,
                  use_decoder=True,
+                 use_skip=True,  # new: whether to use skip connections
                  use_representation_head=False,
                  representation_dim=128,
                  return_features=False,
@@ -123,6 +330,7 @@ class SwiftFormerTemporal(nn.Module):
         # Store configuration
         self.num_frames = num_frames
         self.use_decoder = use_decoder
+        self.use_skip = use_skip  # remember the skip-connection setting
        self.use_representation_head = use_representation_head
         self.return_features = return_features
@@ -155,7 +363,11 @@ class SwiftFormerTemporal(nn.Module):
 
         # Frame prediction decoder
         if use_decoder:
-            self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
+            self.decoder = FramePredictionDecoder(
+                embed_dims,
+                output_channels=1,
+                use_skip=use_skip  # pass the skip-connection setting through
+            )
 
         # Representation head for pose/velocity prediction
         if use_representation_head:
@@ -173,22 +385,31 @@ class SwiftFormerTemporal(nn.Module):
 
     def _init_weights(self, m):
         if isinstance(m, (nn.Conv2d, nn.Linear)):
-            trunc_normal_(m.weight, std=.02)
+            # Kaiming init, suited to ReLU/LeakyReLU activations
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, (nn.LayerNorm)):
+        elif isinstance(m, nn.ConvTranspose2d):
+            # Transposed convs get the same Kaiming init
+            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
 
     def forward_tokens(self, x):
         """Forward through encoder network, return list of stage features if return_features else final output"""
-        if self.return_features:
+        if self.return_features or self.use_skip:
             features = []
+            stage_idx = 0
             for idx, block in enumerate(self.network):
                 x = block(x)
-                # Collect output after each stage (indices 0,2,4,6 correspond to stages)
+                # Collect each stage's output (stage0, stage1, stage2, stage3)
+                # In the SwiftFormer layout, the stages sit at indices 0, 2, 4, 6
                 if idx in [0, 2, 4, 6]:
                     features.append(x)
+                    stage_idx += 1
             return x, features
         else:
             for block in self.network:
@@ -208,7 +429,7 @@ class SwiftFormerTemporal(nn.Module):
         """
         # Encode
         x = self.patch_embed(x)
-        if self.return_features:
+        if self.return_features or self.use_skip:
            x, features = self.forward_tokens(x)
         else:
             x = self.forward_tokens(x)
@@ -222,7 +443,23 @@ class SwiftFormerTemporal(nn.Module):
         # Decode to frame
         pred_frame = None
         if self.use_decoder:
-            pred_frame = self.decoder(x)
+            if self.use_skip:
+                # Gather the encoder features for the skip connections
+                # `features` holds every stage output; we need stage0 through stage3
+                # SwiftFormer has four stages, so four features are expected
+                if len(features) >= 4:
+                    # Take the four stage features: stage0, stage1, stage2, stage3
+                    skip_features = [features[0], features[1], features[2], features[3]]
+                else:
+                    # Fall back to whatever features are available
+                    skip_features = features[:4]
+                    # Pad with None if still short
+                    while len(skip_features) < 4:
+                        skip_features.append(None)
+
+                pred_frame = self.decoder(x, skip_features)
+            else:
+                pred_frame = self.decoder(x)
 
         if self.return_features:
             return pred_frame, representation, features
@@ -231,14 +468,14 @@
 
 
 # Factory functions for different model sizes
-def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
-    return SwiftFormerTemporal('XS', num_frames=num_frames, 
**kwargs) +def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs): + return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs) -def SwiftFormerTemporal_S(num_frames=3, **kwargs): - return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs) +def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs): + return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs) -def SwiftFormerTemporal_L1(num_frames=3, **kwargs): - return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs) +def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs): + return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs) -def SwiftFormerTemporal_L3(num_frames=3, **kwargs): - return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs) \ No newline at end of file +def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs): + return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs) \ No newline at end of file diff --git a/multi_gpu_temporal_train.sh b/multi_gpu_temporal_train.sh deleted file mode 100755 index 2ee1403..0000000 --- a/multi_gpu_temporal_train.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash - -# Simple multi-GPU training script for SwiftFormerTemporal -# Usage: ./multi_gpu_temporal_train.sh [OPTIONS] - -NUM_GPUS=${1:-2} -shift - -echo "Starting multi-GPU training with $NUM_GPUS GPUs" - -# Set environment variables for distributed training -export MASTER_PORT=12345 -export MASTER_ADDR=localhost -export WORLD_SIZE=$NUM_GPUS - -# Launch training -torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \ - --data-path "./videos" \ - --model SwiftFormerTemporal_XS \ - --batch-size 32 \ - --epochs 100 \ - --lr 1e-3 \ - --output-dir "./temporal_output_multi" \ - --num-workers 8 \ - --pin-mem \ - "$@" \ No newline at end of file diff --git a/test_cuda.py b/test_cuda.py deleted file mode 100644 index 7dfa3aa..0000000 --- a/test_cuda.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch - -def test_cuda_availability(): - """全面测试CUDA可用性""" - - print("="*50) - print("PyTorch CUDA 测试") - print("="*50) - - # 基本信息 - print(f"PyTorch版本: {torch.__version__}") - print(f"CUDA可用: {torch.cuda.is_available()}") - - if not torch.cuda.is_available(): - print("CUDA不可用,可能原因:") - print("1. 未安装CUDA驱动") - print("2. 安装的是CPU版本的PyTorch") - print("3. 
CUDA版本与PyTorch不匹配") - return False - - # 设备信息 - device_count = torch.cuda.device_count() - print(f"发现 {device_count} 个CUDA设备") - - for i in range(device_count): - print(f"\n设备 {i}:") - print(f" 名称: {torch.cuda.get_device_name(i)}") - print(f" 内存总量: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB") - print(f" 计算能力: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}") - - # 简单张量测试 - print("\n运行CUDA测试...") - try: - x = torch.randn(3, 3).cuda() - y = torch.randn(3, 3).cuda() - z = x + y - print("CUDA计算测试: 成功!") - print(f"设备上的张量形状: {z.shape}") - return True - except Exception as e: - print(f"CUDA计算测试失败: {e}") - return False - -if __name__ == "__main__": - test_cuda_availability() \ No newline at end of file diff --git a/test_import.py b/test_import.py deleted file mode 100644 index 3d764b6..0000000 --- a/test_import.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -""" -测试 timm 导入是否正常工作 -""" -import sys -print("Python version:", sys.version) - -try: - from timm.layers import to_2tuple, DropPath, trunc_normal_ - from timm.models import register_model - print("✓ 成功导入 timm.layers.to_2tuple") - print("✓ 成功导入 timm.layers.DropPath") - print("✓ 成功导入 timm.layers.trunc_normal_") - print("✓ 成功导入 timm.models.register_model") -except ImportError as e: - print(f"✗ 导入失败: {e}") - sys.exit(1) - -try: - from models.swiftformer import SwiftFormer_XS - print("✓ 成功导入 SwiftFormer_XS") -except ImportError as e: - print(f"✗ 导入 SwiftFormer_XS 失败: {e}") - sys.exit(1) - -try: - from models.swiftformer_temporal import SwiftFormerTemporal_XS - print("✓ 成功导入 SwiftFormerTemporal_XS") -except ImportError as e: - print(f"✗ 导入 SwiftFormerTemporal_XS 失败: {e}") - sys.exit(1) - -print("\n✅ 所有导入测试通过!") \ No newline at end of file diff --git a/test_model.py b/test_model.py deleted file mode 100644 index 0caab88..0000000 --- a/test_model.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for SwiftFormerTemporal model -""" -import torch -import sys -import os - -# Add current directory to path -sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) - -from models.swiftformer_temporal import SwiftFormerTemporal_XS - -def test_model(): - print("Testing SwiftFormerTemporal model...") - - # Create model - model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True) - print(f'Model created: {model.__class__.__name__}') - print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}') - - # Test forward pass - batch_size = 2 - num_frames = 3 - height = width = 224 - x = torch.randn(batch_size, 3 * num_frames, height, width) - - print(f'\nInput shape: {x.shape}') - - with torch.no_grad(): - pred_frame, representation = model(x) - - print(f'Predicted frame shape: {pred_frame.shape}') - print(f'Representation shape: {representation.shape if representation is not None else "None"}') - - # Check output ranges - print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]') - - # Test loss function - from util.frame_losses import MultiTaskLoss - criterion = MultiTaskLoss() - target = torch.randn_like(pred_frame) - temporal_indices = torch.tensor([3, 3], dtype=torch.long) - - loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices) - print(f'\nLoss test:') - for k, v in loss_dict.items(): - print(f' {k}: {v:.4f}') - - print('\nAll tests passed!') - return True - -if __name__ == '__main__': - try: - test_model() - except Exception as e: - print(f'Test failed with error: {e}') - 
import traceback - traceback.print_exc() - sys.exit(1) \ No newline at end of file
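
-- 
A minimal usage sketch for the new evaluation script, assuming a trained
checkpoint was written to ./temporal_output/checkpoint.pth (a placeholder
path); every flag below is defined in evaluate_temporal.py's get_args_parser:

    python evaluate_temporal.py \
        --model SwiftFormerTemporal_XS \
        --data-path ./videos \
        --num-frames 3 \
        --frame-size 224 \
        --batch-size 16 \
        --resume ./temporal_output/checkpoint.pth \
        --output-dir ./evaluation_results \
        --num-samples-to-save 10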