Compare commits
2 Commits
f7601e9170
...
12de74f130
| Author | SHA1 | Date | |
|---|---|---|---|
| 12de74f130 | |||
| 500c2eb18f |
@@ -11,9 +11,10 @@ shift 2
|
||||
|
||||
# Default parameters
|
||||
MODEL=${MODEL:-"SwiftFormerTemporal_XS"}
|
||||
BATCH_SIZE=${BATCH_SIZE:-32}
|
||||
BATCH_SIZE=${BATCH_SIZE:-128}
|
||||
EPOCHS=${EPOCHS:-100}
|
||||
LR=${LR:-1e-3}
|
||||
# LR=${LR:-1e-3}
|
||||
LR=${LR:-0.01}
|
||||
OUTPUT_DIR=${OUTPUT_DIR:-"./temporal_output"}
|
||||
|
||||
echo "Starting distributed training with $NUM_GPUS GPUs"
|
||||
|
||||
503
evaluate_temporal.py
Normal file
503
evaluate_temporal.py
Normal file
@@ -0,0 +1,503 @@
|
||||
"""
|
||||
评估脚本 for SwiftFormerTemporal frame prediction
|
||||
输出预测图(注意反归一化)以及对应指标(mse&ssim&psnr)
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pickle
|
||||
import numpy as np
|
||||
import random
|
||||
from pathlib import Path
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from util.video_dataset import VideoFrameDataset
|
||||
from models.swiftformer_temporal import (
|
||||
SwiftFormerTemporal_XS, SwiftFormerTemporal_S,
|
||||
SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||
)
|
||||
|
||||
# 导入SSIM和PSNR计算
|
||||
from skimage.metrics import structural_similarity as ssim
|
||||
from skimage.metrics import peak_signal_noise_ratio as psnr
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def denormalize(tensor):
|
||||
"""
|
||||
将[-1, 1]范围的张量反归一化到[0, 255]范围
|
||||
Args:
|
||||
tensor: 形状为[B, C, H, W]或[C, H, W],值在[-1, 1]
|
||||
Returns:
|
||||
反归一化后的张量,值在[0, 255]
|
||||
"""
|
||||
# clip 到 [-1, 1] 范围
|
||||
tensor = tensor.clamp(-1, 1)
|
||||
|
||||
# [-1, 1] -> [0, 1]
|
||||
tensor = (tensor + 1) / 2
|
||||
# [0, 1] -> [0, 255]
|
||||
tensor = tensor * 255
|
||||
return tensor.clamp(0, 255)
|
||||
|
||||
|
||||
def calculate_metrics(pred, target, debug=False):
|
||||
"""
|
||||
计算MSE, SSIM, PSNR指标
|
||||
Args:
|
||||
pred: 预测图像,形状[H, W],值在[0, 255]
|
||||
target: 目标图像,形状[H, W],值在[0, 255]
|
||||
debug: 是否输出调试信息
|
||||
Returns:
|
||||
mse, ssim_value, psnr_value
|
||||
"""
|
||||
# 转换为numpy数组
|
||||
pred_np = pred.cpu().numpy() if torch.is_tensor(pred) else pred
|
||||
target_np = target.cpu().numpy() if torch.is_tensor(target) else target
|
||||
|
||||
# 确保是2D数组
|
||||
if pred_np.ndim == 3:
|
||||
pred_np = pred_np.squeeze(0)
|
||||
if target_np.ndim == 3:
|
||||
target_np = target_np.squeeze(0)
|
||||
|
||||
if debug:
|
||||
print(f"[DEBUG] pred_np range: [{pred_np.min():.2f}, {pred_np.max():.2f}], mean: {pred_np.mean():.2f}")
|
||||
print(f"[DEBUG] target_np range: [{target_np.min():.2f}, {target_np.max():.2f}], mean: {target_np.mean():.2f}")
|
||||
print(f"[DEBUG] pred_np sample values (first 5): {pred_np.ravel()[:5]}")
|
||||
|
||||
# 计算MSE - 修复错误的tmp公式
|
||||
# 原错误公式: tmp = 1 - (pred_np - target_np) / 255 * 2
|
||||
# 正确公式: 直接计算像素差的平方
|
||||
mse = np.mean((pred_np - target_np) ** 2)
|
||||
|
||||
# 同时计算错误公式的MSE用于对比
|
||||
tmp = 1 - (pred_np - target_np) / 255 * 2
|
||||
wrong_mse = np.mean(tmp**2)
|
||||
|
||||
if debug:
|
||||
print(f"[DEBUG] Correct MSE: {mse:.6f}, Wrong MSE (tmp formula): {wrong_mse:.6f}")
|
||||
|
||||
# 计算SSIM (数据范围0-255)
|
||||
data_range = 255.0
|
||||
ssim_value = ssim(pred_np, target_np, data_range=data_range)
|
||||
|
||||
# 计算PSNR
|
||||
psnr_value = psnr(target_np, pred_np, data_range=data_range)
|
||||
|
||||
return mse, ssim_value, psnr_value
|
||||
|
||||
|
||||
def save_comparison_figure(input_frames, target_frame, pred_frame, save_path,
|
||||
input_frame_indices=None, target_frame_index=None):
|
||||
"""
|
||||
保存对比图:输入帧、目标帧、预测帧
|
||||
Args:
|
||||
input_frames: 输入帧列表,每个形状为[H, W],值在[0, 255]
|
||||
target_frame: 目标帧,形状[H, W],值在[0, 255]
|
||||
pred_frame: 预测帧,形状[H, W],值在[0, 255]
|
||||
save_path: 保存路径
|
||||
input_frame_indices: 输入帧的索引列表(可选)
|
||||
target_frame_index: 目标帧索引(可选)
|
||||
"""
|
||||
num_input = len(input_frames)
|
||||
fig, axes = plt.subplots(1, num_input + 2, figsize=(4*(num_input+2), 4))
|
||||
|
||||
# 绘制输入帧
|
||||
for i in range(num_input):
|
||||
ax = axes[i]
|
||||
ax.imshow(input_frames[i], cmap='gray')
|
||||
if input_frame_indices is not None:
|
||||
ax.set_title(f'Input Frame {input_frame_indices[i]}')
|
||||
else:
|
||||
ax.set_title(f'Input {i+1}')
|
||||
ax.axis('off')
|
||||
|
||||
# 绘制目标帧
|
||||
ax = axes[num_input]
|
||||
ax.imshow(target_frame, cmap='gray')
|
||||
if target_frame_index is not None:
|
||||
ax.set_title(f'Target Frame {target_frame_index}')
|
||||
else:
|
||||
ax.set_title('Target')
|
||||
ax.axis('off')
|
||||
|
||||
# 绘制预测帧
|
||||
ax = axes[num_input + 1]
|
||||
ax.imshow(pred_frame, cmap='gray')
|
||||
ax.set_title('Predicted')
|
||||
ax.axis('off')
|
||||
|
||||
# debug print - 改进为更有信息量的输出
|
||||
if isinstance(pred_frame, np.ndarray):
|
||||
print(f"[DEBUG IMAGE] Pred frame shape: {pred_frame.shape}, range: [{pred_frame.min():.2f}, {pred_frame.max():.2f}], mean: {pred_frame.mean():.2f}")
|
||||
# 检查是否有大量值在127.5附近
|
||||
mask_near_127_5 = np.abs(pred_frame - 127.5) < 1.0
|
||||
percent_near_127_5 = np.mean(mask_near_127_5) * 100
|
||||
print(f"[DEBUG IMAGE] Percentage of values near 127.5 (±1.0): {percent_near_127_5:.2f}%")
|
||||
else:
|
||||
print(pred_frame)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
def evaluate_model(model, data_loader, device, args):
|
||||
"""
|
||||
评估模型并计算指标
|
||||
Args:
|
||||
model: 训练好的模型
|
||||
data_loader: 数据加载器
|
||||
device: 设备
|
||||
args: 命令行参数
|
||||
Returns:
|
||||
metrics_dict: 包含所有指标的字典
|
||||
sample_results: 示例结果用于可视化
|
||||
"""
|
||||
# model.eval()
|
||||
model.train() # 临时使用训练模式
|
||||
|
||||
# 初始化指标累加器
|
||||
total_mse = 0.0
|
||||
total_ssim = 0.0
|
||||
total_psnr = 0.0
|
||||
total_samples = 0
|
||||
|
||||
# 存储示例结果用于可视化(使用蓄水池抽样随机选择)
|
||||
sample_results = []
|
||||
max_samples_to_save = args.num_samples_to_save
|
||||
max_samples = args.max_samples
|
||||
# 用于蓄水池抽样的计数器(已处理的样本数,不包括因max_samples限制而跳过的样本)
|
||||
sample_count = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(data_loader):
|
||||
input_frames = input_frames.to(device, non_blocking=True)
|
||||
target_frames = target_frames.to(device, non_blocking=True)
|
||||
|
||||
# 前向传播
|
||||
pred_frames, _ = model(input_frames)
|
||||
|
||||
# 反归一化用于指标计算
|
||||
pred_denorm = denormalize(pred_frames) # [B, 1, H, W]
|
||||
target_denorm = denormalize(target_frames) # [B, 1, H, W]
|
||||
|
||||
batch_size = input_frames.size(0)
|
||||
|
||||
# 计算每个样本的指标
|
||||
for i in range(batch_size):
|
||||
# 检查是否达到最大样本数限制
|
||||
if max_samples is not None and total_samples >= max_samples:
|
||||
break
|
||||
|
||||
pred_i = pred_denorm[i] # [1, H, W]
|
||||
target_i = target_denorm[i] # [1, H, W]
|
||||
|
||||
# 对第一个样本启用调试
|
||||
debug_mode = (batch_idx == 0 and i == 0 and total_samples == 0)
|
||||
if debug_mode:
|
||||
print(f"[DEBUG] Raw pred_frames range: [{pred_frames.min():.4f}, {pred_frames.max():.4f}], mean: {pred_frames.mean():.4f}")
|
||||
print(f"[DEBUG] Raw target_frames range: [{target_frames.min():.4f}, {target_frames.max():.4f}], mean: {target_frames.mean():.4f}")
|
||||
print(f"[DEBUG] Pred_denorm range: [{pred_denorm.min():.2f}, {pred_denorm.max():.2f}], mean: {pred_denorm.mean():.2f}")
|
||||
print(f"[DEBUG] Target_denorm range: [{target_denorm.min():.2f}, {target_denorm.max():.2f}], mean: {target_denorm.mean():.2f}")
|
||||
|
||||
mse, ssim_value, psnr_value = calculate_metrics(pred_i, target_i, debug=debug_mode)
|
||||
|
||||
total_mse += mse
|
||||
total_ssim += ssim_value
|
||||
total_psnr += psnr_value
|
||||
total_samples += 1
|
||||
sample_count += 1
|
||||
|
||||
# 构建样本数据字典
|
||||
input_denorm = denormalize(input_frames[i]) # [num_frames, H, W]
|
||||
# 分离输入帧
|
||||
input_frames_list = []
|
||||
for j in range(args.num_frames):
|
||||
input_frame_j = input_denorm[j].squeeze(0) # [H, W]
|
||||
input_frames_list.append(input_frame_j.cpu().numpy())
|
||||
|
||||
sample_data = {
|
||||
'input_frames': input_frames_list,
|
||||
'target_frame': target_i.squeeze(0).cpu().numpy(),
|
||||
'pred_frame': pred_i.squeeze(0).cpu().numpy(),
|
||||
'metrics': {
|
||||
'mse': mse,
|
||||
'ssim': ssim_value,
|
||||
'psnr': psnr_value
|
||||
},
|
||||
'batch_idx': batch_idx,
|
||||
'sample_idx': i
|
||||
}
|
||||
|
||||
# 蓄水池抽样 (Reservoir Sampling)
|
||||
if sample_count <= max_samples_to_save:
|
||||
# 蓄水池未满,直接加入
|
||||
sample_results.append(sample_data)
|
||||
else:
|
||||
# 以 max_samples_to_save / sample_count 的概率替换蓄水池中的一个随机位置
|
||||
r = random.randint(0, sample_count - 1)
|
||||
if r < max_samples_to_save:
|
||||
sample_results[r] = sample_data
|
||||
|
||||
# 检查是否达到最大样本数限制
|
||||
if max_samples is not None and total_samples >= max_samples:
|
||||
print(f"达到最大样本数限制: {max_samples}")
|
||||
break
|
||||
|
||||
# 进度打印
|
||||
if (batch_idx + 1) % 10 == 0:
|
||||
print(f'Processed {batch_idx + 1} batches, {total_samples} samples')
|
||||
|
||||
# 计算平均指标
|
||||
if total_samples > 0:
|
||||
avg_mse = float(total_mse / total_samples)
|
||||
avg_ssim = float(total_ssim / total_samples)
|
||||
avg_psnr = float(total_psnr / total_samples)
|
||||
else:
|
||||
avg_mse = avg_ssim = avg_psnr = 0.0
|
||||
|
||||
metrics_dict = {
|
||||
'mse': avg_mse,
|
||||
'ssim': avg_ssim,
|
||||
'psnr': avg_psnr,
|
||||
'num_samples': total_samples
|
||||
}
|
||||
|
||||
return metrics_dict, sample_results
|
||||
|
||||
|
||||
def main(args):
|
||||
print("评估参数:", args)
|
||||
|
||||
device = torch.device(args.device)
|
||||
|
||||
# 设置随机种子
|
||||
torch.manual_seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
random.seed(args.seed)
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
# 构建数据集
|
||||
print("构建数据集...")
|
||||
dataset_val = VideoFrameDataset(
|
||||
root_dir=args.data_path,
|
||||
num_frames=args.num_frames,
|
||||
frame_size=args.frame_size,
|
||||
is_train=False,
|
||||
max_interval=args.max_interval
|
||||
)
|
||||
|
||||
data_loader_val = torch.utils.data.DataLoader(
|
||||
dataset_val,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
pin_memory=args.pin_mem,
|
||||
shuffle=False,
|
||||
drop_last=False
|
||||
)
|
||||
|
||||
# 创建模型
|
||||
print(f"创建模型: {args.model}")
|
||||
model_kwargs = {
|
||||
'num_frames': args.num_frames,
|
||||
'use_representation_head': args.use_representation_head,
|
||||
'representation_dim': args.representation_dim,
|
||||
}
|
||||
|
||||
if args.model == 'SwiftFormerTemporal_XS':
|
||||
model = SwiftFormerTemporal_XS(**model_kwargs)
|
||||
elif args.model == 'SwiftFormerTemporal_S':
|
||||
model = SwiftFormerTemporal_S(**model_kwargs)
|
||||
elif args.model == 'SwiftFormerTemporal_L1':
|
||||
model = SwiftFormerTemporal_L1(**model_kwargs)
|
||||
elif args.model == 'SwiftFormerTemporal_L3':
|
||||
model = SwiftFormerTemporal_L3(**model_kwargs)
|
||||
else:
|
||||
raise ValueError(f"未知模型: {args.model}")
|
||||
|
||||
model.to(device)
|
||||
|
||||
# 加载检查点
|
||||
if args.resume:
|
||||
print(f"加载检查点: {args.resume}")
|
||||
try:
|
||||
# 尝试使用weights_only=False加载(PyTorch 2.6+需要)
|
||||
checkpoint = torch.load(args.resume, map_location='cpu', weights_only=False)
|
||||
except (pickle.UnpicklingError, TypeError) as e:
|
||||
print(f"使用weights_only=False加载失败: {e}")
|
||||
print("尝试使用torch.serialization.add_safe_globals...")
|
||||
from argparse import Namespace
|
||||
# 添加安全全局变量
|
||||
torch.serialization.add_safe_globals([Namespace])
|
||||
checkpoint = torch.load(args.resume, map_location='cpu')
|
||||
|
||||
# 处理状态字典(可能包含'module.'前缀)
|
||||
if 'model' in checkpoint:
|
||||
state_dict = checkpoint['model']
|
||||
elif 'state_dict' in checkpoint:
|
||||
state_dict = checkpoint['state_dict']
|
||||
else:
|
||||
state_dict = checkpoint
|
||||
|
||||
# 移除'module.'前缀(如果存在)
|
||||
if hasattr(model, 'module'):
|
||||
model.module.load_state_dict(state_dict)
|
||||
else:
|
||||
# 如果状态字典有'module.'前缀但模型没有,需要移除前缀
|
||||
if any(key.startswith('module.') for key in state_dict.keys()):
|
||||
from collections import OrderedDict
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith('module.'):
|
||||
new_state_dict[k[7:]] = v
|
||||
else:
|
||||
new_state_dict[k] = v
|
||||
state_dict = new_state_dict
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
print(f"检查点加载成功,epoch: {checkpoint.get('epoch', 'unknown')}")
|
||||
else:
|
||||
print("警告: 未提供检查点路径,使用随机初始化的模型")
|
||||
|
||||
# 创建输出目录
|
||||
output_dir = Path(args.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 评估模型
|
||||
print("开始评估...")
|
||||
metrics, sample_results = evaluate_model(model, data_loader_val, device, args)
|
||||
|
||||
# 打印指标
|
||||
print("\n" + "="*50)
|
||||
print("评估结果:")
|
||||
print(f"MSE: {metrics['mse']:.6f}")
|
||||
print(f"SSIM: {metrics['ssim']:.6f}")
|
||||
print(f"PSNR: {metrics['psnr']:.6f} dB")
|
||||
print(f"样本数量: {metrics['num_samples']}")
|
||||
print("="*50)
|
||||
|
||||
# 保存指标到JSON文件
|
||||
metrics_file = output_dir / 'evaluation_metrics.json'
|
||||
with open(metrics_file, 'w') as f:
|
||||
json.dump(metrics, f, indent=4)
|
||||
print(f"指标已保存到: {metrics_file}")
|
||||
|
||||
# 保存示例可视化
|
||||
if sample_results:
|
||||
print(f"\n保存 {len(sample_results)} 个示例可视化...")
|
||||
samples_dir = output_dir / 'sample_predictions'
|
||||
samples_dir.mkdir(exist_ok=True)
|
||||
|
||||
for i, sample in enumerate(sample_results):
|
||||
save_path = samples_dir / f'sample_{i:03d}.png'
|
||||
|
||||
# 生成输入帧索引(假设连续)
|
||||
input_frame_indices = list(range(1, args.num_frames + 1))
|
||||
target_frame_index = args.num_frames + 1
|
||||
|
||||
save_comparison_figure(
|
||||
sample['input_frames'],
|
||||
sample['target_frame'],
|
||||
sample['pred_frame'],
|
||||
save_path,
|
||||
input_frame_indices=input_frame_indices,
|
||||
target_frame_index=target_frame_index
|
||||
)
|
||||
|
||||
# 保存该样本的指标
|
||||
sample_metrics_file = samples_dir / f'sample_{i:03d}_metrics.txt'
|
||||
with open(sample_metrics_file, 'w') as f:
|
||||
f.write(f"Sample {i} (batch {sample['batch_idx']}, idx {sample['sample_idx']})\n")
|
||||
f.write(f"MSE: {sample['metrics']['mse']:.6f}\n")
|
||||
f.write(f"SSIM: {sample['metrics']['ssim']:.6f}\n")
|
||||
f.write(f"PSNR: {sample['metrics']['psnr']:.6f} dB\n")
|
||||
|
||||
print(f"示例可视化已保存到: {samples_dir}")
|
||||
|
||||
# 生成汇总报告
|
||||
report_file = output_dir / 'evaluation_report.txt'
|
||||
with open(report_file, 'w') as f:
|
||||
f.write("SwiftFormerTemporal 帧预测评估报告\n")
|
||||
f.write("="*50 + "\n")
|
||||
f.write(f"模型: {args.model}\n")
|
||||
f.write(f"检查点: {args.resume}\n")
|
||||
f.write(f"数据集: {args.data_path}\n")
|
||||
f.write(f"输入帧数: {args.num_frames}\n")
|
||||
f.write(f"帧大小: {args.frame_size}\n")
|
||||
f.write(f"批次大小: {args.batch_size}\n")
|
||||
f.write(f"样本总数: {metrics['num_samples']}\n\n")
|
||||
|
||||
f.write("评估指标:\n")
|
||||
f.write(f" MSE: {metrics['mse']:.6f}\n")
|
||||
f.write(f" SSIM: {metrics['ssim']:.6f}\n")
|
||||
f.write(f" PSNR: {metrics['psnr']:.6f} dB\n")
|
||||
|
||||
print(f"评估报告已保存到: {report_file}")
|
||||
print("\n评估完成!")
|
||||
|
||||
|
||||
def get_args_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
'SwiftFormerTemporal 评估脚本', add_help=False)
|
||||
|
||||
# 数据集参数
|
||||
parser.add_argument('--data-path', default='./videos', type=str,
|
||||
help='视频数据集路径')
|
||||
parser.add_argument('--num-frames', default=3, type=int,
|
||||
help='输入帧数 (T)')
|
||||
parser.add_argument('--frame-size', default=224, type=int,
|
||||
help='输入帧大小')
|
||||
parser.add_argument('--max-interval', default=4, type=int,
|
||||
help='连续帧之间的最大间隔')
|
||||
|
||||
# 模型参数
|
||||
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
|
||||
help='要评估的模型名称')
|
||||
parser.add_argument('--use-representation-head', action='store_true',
|
||||
help='使用表示头进行姿态/速度预测')
|
||||
parser.add_argument('--representation-dim', default=128, type=int,
|
||||
help='表示向量的维度')
|
||||
|
||||
# 评估参数
|
||||
parser.add_argument('--batch-size', default=16, type=int,
|
||||
help='评估批次大小')
|
||||
parser.add_argument('--num-samples-to-save', default=10, type=int,
|
||||
help='保存可视化的样本数量')
|
||||
parser.add_argument('--max-samples', default=None, type=int,
|
||||
help='最大评估样本数(None表示全部)')
|
||||
|
||||
# 系统参数
|
||||
parser.add_argument('--output-dir', default='./evaluation_results',
|
||||
help='保存结果的路径')
|
||||
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
help='使用的设备')
|
||||
parser.add_argument('--seed', default=0, type=int)
|
||||
parser.add_argument('--resume', default='', help='检查点路径')
|
||||
parser.add_argument('--num-workers', default=4, type=int)
|
||||
parser.add_argument('--pin-mem', action='store_true',
|
||||
help='在DataLoader中固定CPU内存')
|
||||
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
|
||||
parser.set_defaults(pin_mem=True)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
'SwiftFormerTemporal 评估', parents=[get_args_parser()])
|
||||
args = parser.parse_args()
|
||||
|
||||
# 确保输出目录存在
|
||||
if args.output_dir:
|
||||
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
main(args)
|
||||
@@ -19,7 +19,7 @@ from timm.utils import NativeScaler, get_state_dict, ModelEma
|
||||
from util import *
|
||||
from models import *
|
||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
|
||||
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
|
||||
from util.video_dataset import VideoFrameDataset
|
||||
from util.frame_losses import MultiTaskLoss
|
||||
|
||||
# Try to import TensorBoard
|
||||
@@ -47,7 +47,7 @@ def get_args_parser():
|
||||
help='Number of input frames (T)')
|
||||
parser.add_argument('--frame-size', default=224, type=int,
|
||||
help='Input frame size')
|
||||
parser.add_argument('--max-interval', default=1, type=int,
|
||||
parser.add_argument('--max-interval', default=4, type=int,
|
||||
help='Maximum interval between consecutive frames')
|
||||
|
||||
# Model parameters
|
||||
@@ -57,6 +57,7 @@ def get_args_parser():
|
||||
help='Use representation head for pose/velocity prediction')
|
||||
parser.add_argument('--representation-dim', default=128, type=int,
|
||||
help='Dimension of representation vector')
|
||||
parser.add_argument('--use-skip', default=True, type=bool, help='using skip connections')
|
||||
|
||||
# Training parameters
|
||||
parser.add_argument('--batch-size', default=32, type=int)
|
||||
@@ -77,7 +78,7 @@ def get_args_parser():
|
||||
help='SGD momentum (default: 0.9)')
|
||||
parser.add_argument('--weight-decay', type=float, default=0.05,
|
||||
help='weight decay (default: 0.05)')
|
||||
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
|
||||
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
|
||||
help='learning rate (default: 1e-3)')
|
||||
|
||||
# Learning rate schedule parameters (required by timm's create_scheduler)
|
||||
@@ -89,7 +90,7 @@ def get_args_parser():
|
||||
help='learning rate noise limit percent (default: 0.67)')
|
||||
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
|
||||
help='learning rate noise std-dev (default: 1.0)')
|
||||
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
|
||||
parser.add_argument('--warmup-lr', type=float, default=1e-3, metavar='LR',
|
||||
help='warmup learning rate (default: 1e-6)')
|
||||
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
|
||||
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
|
||||
@@ -109,10 +110,10 @@ def get_args_parser():
|
||||
help='Weight for frame prediction loss')
|
||||
parser.add_argument('--contrastive-weight', type=float, default=0.1,
|
||||
help='Weight for contrastive loss')
|
||||
parser.add_argument('--l1-weight', type=float, default=1.0,
|
||||
help='Weight for L1 loss')
|
||||
parser.add_argument('--ssim-weight', type=float, default=0.1,
|
||||
help='Weight for SSIM loss')
|
||||
# parser.add_argument('--l1-weight', type=float, default=1.0,
|
||||
# help='Weight for L1 loss')
|
||||
# parser.add_argument('--ssim-weight', type=float, default=0.1,
|
||||
# help='Weight for SSIM loss')
|
||||
parser.add_argument('--no-contrastive', action='store_true',
|
||||
help='Disable contrastive loss')
|
||||
parser.add_argument('--no-ssim', action='store_true',
|
||||
@@ -212,6 +213,7 @@ def main(args):
|
||||
'num_frames': args.num_frames,
|
||||
'use_representation_head': args.use_representation_head,
|
||||
'representation_dim': args.representation_dim,
|
||||
'use_skip': args.use_skip,
|
||||
}
|
||||
|
||||
if args.model == 'SwiftFormerTemporal_XS':
|
||||
@@ -326,7 +328,7 @@ def main(args):
|
||||
lr_scheduler.step(epoch)
|
||||
|
||||
# Save checkpoint
|
||||
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
|
||||
if args.output_dir and (epoch % 2 == 0 or epoch == args.epochs - 1):
|
||||
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
|
||||
utils.save_on_master({
|
||||
'model': model_without_ddp.state_dict(),
|
||||
@@ -374,6 +376,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
header = f'Epoch: [{epoch}]'
|
||||
print_freq = 10
|
||||
|
||||
# 添加诊断指标
|
||||
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
|
||||
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(
|
||||
metric_logger.log_every(data_loader, print_freq, header)):
|
||||
|
||||
@@ -382,7 +389,7 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
temporal_indices = temporal_indices.to(device, non_blocking=True)
|
||||
|
||||
# Forward pass
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast(device_type='cuda'):
|
||||
pred_frames, representations = model(input_frames)
|
||||
loss, loss_dict = criterion(
|
||||
pred_frames, target_frames,
|
||||
@@ -395,6 +402,8 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
raise ValueError(f"Loss is {loss_value}")
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# 在反向传播前保存梯度用于诊断
|
||||
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
|
||||
parameters=model.parameters())
|
||||
|
||||
@@ -402,6 +411,30 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
if model_ema is not None:
|
||||
model_ema.update(model)
|
||||
|
||||
# 计算诊断指标
|
||||
pred_mean = pred_frames.mean().item()
|
||||
pred_std = pred_frames.std().item()
|
||||
|
||||
# 计算梯度范数
|
||||
total_grad_norm = 0.0
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
total_grad_norm += param.grad.norm().item()
|
||||
|
||||
# 记录诊断指标
|
||||
metric_logger.update(pred_mean=pred_mean)
|
||||
metric_logger.update(pred_std=pred_std)
|
||||
metric_logger.update(grad_norm=total_grad_norm)
|
||||
|
||||
# 每50个批次打印一次BatchNorm统计
|
||||
if batch_idx % 50 == 0:
|
||||
print(f"[诊断] 批次 {batch_idx}: 预测均值={pred_mean:.4f}, 预测标准差={pred_std:.4f}, 梯度范数={total_grad_norm:.4f}")
|
||||
# 检查一个BatchNorm层的运行统计
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
print(f"[诊断] {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
break
|
||||
|
||||
# Log to TensorBoard
|
||||
if writer is not None:
|
||||
# Log scalar metrics every iteration
|
||||
@@ -415,6 +448,11 @@ def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, los
|
||||
else:
|
||||
writer.add_scalar(f'train/{k}', v, global_step)
|
||||
|
||||
# Log diagnostic metrics
|
||||
writer.add_scalar('train/pred_mean', pred_mean, global_step)
|
||||
writer.add_scalar('train/pred_std', pred_std, global_step)
|
||||
writer.add_scalar('train/grad_norm', total_grad_norm, global_step)
|
||||
|
||||
# Log images periodically
|
||||
if args is not None and getattr(args, 'log_images', False) and global_step % getattr(args, 'image_log_freq', 100) == 0:
|
||||
with torch.no_grad():
|
||||
@@ -451,19 +489,53 @@ def evaluate(data_loader, model, criterion, device, writer=None, epoch=0):
|
||||
metric_logger = utils.MetricLogger(delimiter=" ")
|
||||
header = 'Test:'
|
||||
|
||||
for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
|
||||
# 添加诊断指标
|
||||
metric_logger.add_meter('pred_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
metric_logger.add_meter('pred_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
metric_logger.add_meter('target_mean', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
metric_logger.add_meter('target_std', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
|
||||
|
||||
for batch_idx, (input_frames, target_frames, temporal_indices) in enumerate(metric_logger.log_every(data_loader, 10, header)):
|
||||
input_frames = input_frames.to(device, non_blocking=True)
|
||||
target_frames = target_frames.to(device, non_blocking=True)
|
||||
temporal_indices = temporal_indices.to(device, non_blocking=True)
|
||||
|
||||
# Compute output
|
||||
with torch.cuda.amp.autocast():
|
||||
with torch.amp.autocast(device_type='cuda'):
|
||||
pred_frames, representations = model(input_frames)
|
||||
loss, loss_dict = criterion(
|
||||
pred_frames, target_frames,
|
||||
representations, temporal_indices
|
||||
)
|
||||
|
||||
# 计算诊断指标
|
||||
pred_mean = pred_frames.mean().item()
|
||||
pred_std = pred_frames.std().item()
|
||||
target_mean = target_frames.mean().item()
|
||||
target_std = target_frames.std().item()
|
||||
|
||||
# 更新诊断指标
|
||||
metric_logger.update(pred_mean=pred_mean)
|
||||
metric_logger.update(pred_std=pred_std)
|
||||
metric_logger.update(target_mean=target_mean)
|
||||
metric_logger.update(target_std=target_std)
|
||||
|
||||
# 第一个批次打印详细诊断信息
|
||||
if batch_idx == 0:
|
||||
print(f"[评估诊断] 批次 0:")
|
||||
print(f" 预测范围: [{pred_frames.min().item():.4f}, {pred_frames.max().item():.4f}]")
|
||||
print(f" 预测均值: {pred_mean:.4f}, 预测标准差: {pred_std:.4f}")
|
||||
print(f" 目标范围: [{target_frames.min().item():.4f}, {target_frames.max().item():.4f}]")
|
||||
print(f" 目标均值: {target_mean:.4f}, 目标标准差: {target_std:.4f}")
|
||||
|
||||
# 检查BatchNorm运行统计
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.BatchNorm2d) and 'decoder.blocks.0.bn' in name:
|
||||
print(f" {name}: 运行均值={module.running_mean[0].item():.6f}, 运行方差={module.running_var[0].item():.6f}")
|
||||
if module.running_var[0].item() < 1e-6:
|
||||
print(f" 警告: BatchNorm运行方差接近零!")
|
||||
break
|
||||
|
||||
# Update metrics
|
||||
metric_logger.update(loss=loss.item())
|
||||
for k, v in loss_dict.items():
|
||||
|
||||
@@ -11,26 +11,188 @@ from timm.layers import DropPath, trunc_normal_
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
"""Upsampling block for frame prediction decoder"""
|
||||
"""Upsampling block for frame prediction decoder with residual connections"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
|
||||
super().__init__()
|
||||
self.conv = nn.ConvTranspose2d(
|
||||
# 主路径:反卷积 + 两个卷积层
|
||||
self.conv_transpose = nn.ConvTranspose2d(
|
||||
in_channels, out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
bias=False
|
||||
bias=True # 启用bias,因为移除了BN
|
||||
)
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv1 = nn.Conv2d(out_channels, out_channels,
|
||||
kernel_size=3, padding=1, bias=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels,
|
||||
kernel_size=3, padding=1, bias=True)
|
||||
|
||||
# 残差路径:如果需要改变通道数或空间尺寸
|
||||
self.shortcut = nn.Identity()
|
||||
if in_channels != out_channels or stride != 1:
|
||||
# 使用1x1卷积调整通道数,如果需要上采样则使用反卷积
|
||||
if stride == 1:
|
||||
self.shortcut = nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=1, bias=True)
|
||||
else:
|
||||
self.shortcut = nn.ConvTranspose2d(
|
||||
in_channels, out_channels,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
output_padding=output_padding,
|
||||
bias=True
|
||||
)
|
||||
|
||||
# 使用LeakyReLU避免死亡神经元
|
||||
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
||||
|
||||
# 初始化权重
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
# 初始化反卷积层
|
||||
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.conv_transpose.bias is not None:
|
||||
nn.init.constant_(self.conv_transpose.bias, 0)
|
||||
|
||||
# 初始化卷积层
|
||||
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.conv1.bias is not None:
|
||||
nn.init.constant_(self.conv1.bias, 0)
|
||||
|
||||
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.conv2.bias is not None:
|
||||
nn.init.constant_(self.conv2.bias, 0)
|
||||
|
||||
# 初始化shortcut
|
||||
if not isinstance(self.shortcut, nn.Identity):
|
||||
if isinstance(self.shortcut, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
elif isinstance(self.shortcut, nn.ConvTranspose2d):
|
||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.shortcut.bias is not None:
|
||||
nn.init.constant_(self.shortcut.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
return self.relu(self.bn(self.conv(x)))
|
||||
identity = self.shortcut(x)
|
||||
|
||||
# 主路径
|
||||
x = self.conv_transpose(x)
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
|
||||
# 残差连接
|
||||
x = x + identity
|
||||
x = self.activation(x)
|
||||
return x
|
||||
|
||||
|
||||
class DecoderBlockWithSkip(nn.Module):
|
||||
"""Decoder block with skip connection support"""
|
||||
def __init__(self, in_channels, out_channels, skip_channels=0, kernel_size=3, stride=2, padding=1, output_padding=1):
|
||||
super().__init__()
|
||||
# 总输入通道 = 输入通道 + skip通道
|
||||
total_in_channels = in_channels + skip_channels
|
||||
|
||||
# 主路径:反卷积 + 两个卷积层
|
||||
self.conv_transpose = nn.ConvTranspose2d(
|
||||
total_in_channels, out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
output_padding=output_padding,
|
||||
bias=True
|
||||
)
|
||||
self.conv1 = nn.Conv2d(out_channels, out_channels,
|
||||
kernel_size=3, padding=1, bias=True)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels,
|
||||
kernel_size=3, padding=1, bias=True)
|
||||
|
||||
# 残差路径:如果需要改变通道数或空间尺寸
|
||||
self.shortcut = nn.Identity()
|
||||
if total_in_channels != out_channels or stride != 1:
|
||||
if stride == 1:
|
||||
self.shortcut = nn.Conv2d(total_in_channels, out_channels,
|
||||
kernel_size=1, bias=True)
|
||||
else:
|
||||
self.shortcut = nn.ConvTranspose2d(
|
||||
total_in_channels, out_channels,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
output_padding=output_padding,
|
||||
bias=True
|
||||
)
|
||||
|
||||
# 使用LeakyReLU避免死亡神经元
|
||||
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
||||
|
||||
# 初始化权重
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
# 初始化反卷积层
|
||||
nn.init.kaiming_normal_(self.conv_transpose.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.conv_transpose.bias is not None:
|
||||
nn.init.constant_(self.conv_transpose.bias, 0)
|
||||
|
||||
# 初始化卷积层
|
||||
nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.conv1.bias is not None:
|
||||
nn.init.constant_(self.conv1.bias, 0)
|
||||
|
||||
nn.init.kaiming_normal_(self.conv2.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.conv2.bias is not None:
|
||||
nn.init.constant_(self.conv2.bias, 0)
|
||||
|
||||
# 初始化shortcut
|
||||
if not isinstance(self.shortcut, nn.Identity):
|
||||
if isinstance(self.shortcut, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
elif isinstance(self.shortcut, nn.ConvTranspose2d):
|
||||
nn.init.kaiming_normal_(self.shortcut.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if self.shortcut.bias is not None:
|
||||
nn.init.constant_(self.shortcut.bias, 0)
|
||||
|
||||
def forward(self, x, skip_feature=None):
|
||||
# 如果有skip feature,将其与输入拼接
|
||||
if skip_feature is not None:
|
||||
# 确保skip特征的空间尺寸与x匹配
|
||||
if skip_feature.shape[2:] != x.shape[2:]:
|
||||
# 使用双线性插值进行上采样或下采样
|
||||
skip_feature = torch.nn.functional.interpolate(
|
||||
skip_feature,
|
||||
size=x.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
x = torch.cat([x, skip_feature], dim=1)
|
||||
|
||||
identity = self.shortcut(x)
|
||||
|
||||
# 主路径
|
||||
x = self.conv_transpose(x)
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.activation(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
|
||||
# 残差连接
|
||||
x = x + identity
|
||||
x = self.activation(x)
|
||||
return x
|
||||
|
||||
|
||||
class FramePredictionDecoder(nn.Module):
|
||||
"""Lightweight decoder for frame prediction with optional skip connections"""
|
||||
"""Improved decoder for frame prediction with better upsampling strategy"""
|
||||
def __init__(self, embed_dims, output_channels=1, use_skip=False):
|
||||
super().__init__()
|
||||
self.use_skip = use_skip
|
||||
@@ -38,65 +200,109 @@ class FramePredictionDecoder(nn.Module):
|
||||
decoder_dims = embed_dims[::-1]
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
# First upsampling from bottleneck to stage4 resolution
|
||||
self.blocks.append(DecoderBlock(
|
||||
decoder_dims[0], decoder_dims[1],
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
# stage4 to stage3
|
||||
self.blocks.append(DecoderBlock(
|
||||
decoder_dims[1], decoder_dims[2],
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
# stage3 to stage2
|
||||
self.blocks.append(DecoderBlock(
|
||||
decoder_dims[2], decoder_dims[3],
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
# stage2 to original resolution (now 8x upsampling total with stride 4)
|
||||
self.blocks.append(nn.Sequential(
|
||||
nn.ConvTranspose2d(
|
||||
decoder_dims[3], 32,
|
||||
kernel_size=3, stride=4, padding=1, output_padding=3
|
||||
),
|
||||
nn.BatchNorm2d(32),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
|
||||
nn.Tanh() # Output in [-1, 1] range
|
||||
))
|
||||
|
||||
# If using skip connections, we need to adjust input channels for each block
|
||||
if use_skip:
|
||||
# We'll modify the first three blocks to accept concatenated features
|
||||
# Instead of modifying existing blocks, we'll replace them with custom blocks
|
||||
# For simplicity, we'll keep the same architecture but forward will handle concatenation
|
||||
pass
|
||||
# 使用支持skip connections的block
|
||||
# 第一个block:从bottleneck到stage4,使用大步长stride=4,skip来自stage3
|
||||
self.blocks.append(DecoderBlockWithSkip(
|
||||
decoder_dims[0], decoder_dims[1],
|
||||
skip_channels=embed_dims[3], # stage3的通道数
|
||||
kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4
|
||||
))
|
||||
# 第二个block:stage4到stage3,stride=2,skip来自stage2
|
||||
self.blocks.append(DecoderBlockWithSkip(
|
||||
decoder_dims[1], decoder_dims[2],
|
||||
skip_channels=embed_dims[2], # stage2的通道数
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
# 第三个block:stage3到stage2,stride=2,skip来自stage1
|
||||
self.blocks.append(DecoderBlockWithSkip(
|
||||
decoder_dims[2], decoder_dims[3],
|
||||
skip_channels=embed_dims[1], # stage1的通道数
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
# 第四个block:stage2到stage1,stride=2,skip来自stage0
|
||||
self.blocks.append(DecoderBlockWithSkip(
|
||||
decoder_dims[3], 64, # 输出到64通道
|
||||
skip_channels=embed_dims[0], # stage0的通道数
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
else:
|
||||
# 使用普通的DecoderBlock,第一个block使用大步长
|
||||
self.blocks.append(DecoderBlock(
|
||||
decoder_dims[0], decoder_dims[1],
|
||||
kernel_size=3, stride=4, padding=1, output_padding=3 # 改为stride=4
|
||||
))
|
||||
self.blocks.append(DecoderBlock(
|
||||
decoder_dims[1], decoder_dims[2],
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
self.blocks.append(DecoderBlock(
|
||||
decoder_dims[2], decoder_dims[3],
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
# 第四个block:增加到64通道
|
||||
self.blocks.append(DecoderBlock(
|
||||
decoder_dims[3], 64,
|
||||
kernel_size=3, stride=2, padding=1, output_padding=1
|
||||
))
|
||||
|
||||
# 改进的最终输出层:不使用反卷积,只进行特征精炼
|
||||
# 输入尺寸已经是目标尺寸,只需要调整通道数和进行特征融合
|
||||
self.final_block = nn.Sequential(
|
||||
nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=True),
|
||||
nn.LeakyReLU(0.2, inplace=True),
|
||||
nn.Conv2d(32, output_channels, kernel_size=3, padding=1, bias=True)
|
||||
# 移除Tanh,让输出在任意范围,由损失函数和归一化处理
|
||||
)
|
||||
|
||||
def forward(self, x, skip_features=None):
|
||||
"""
|
||||
Args:
|
||||
x: input tensor of shape [B, embed_dims[-1], H/32, W/32]
|
||||
skip_features: list of encoder features from stages [stage2, stage1, stage0]
|
||||
each of shape [B, C, H', W'] where C matches decoder dims?
|
||||
skip_features: list of encoder features from stages [stage3, stage2, stage1, stage0]
|
||||
each of shape [B, C, H', W'] where C matches encoder dims
|
||||
"""
|
||||
if self.use_skip and skip_features is not None:
|
||||
# Ensure we have exactly 3 skip features (for the first three blocks)
|
||||
assert len(skip_features) == 3, "Need 3 skip features for skip connections"
|
||||
# Reverse skip_features to match decoder order: stage2, stage1, stage0
|
||||
# skip_features[0] should be stage2 (H/16), [1] stage1 (H/8), [2] stage0 (H/4)
|
||||
skip_features = skip_features[::-1] # Now index 0: stage2, 1: stage1, 2: stage0
|
||||
if self.use_skip:
|
||||
if skip_features is None:
|
||||
raise ValueError("skip_features must be provided when use_skip=True")
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
if self.use_skip and skip_features is not None and i < 3:
|
||||
# Concatenate skip feature along channel dimension
|
||||
# Ensure spatial dimensions match (they should because of upsampling)
|
||||
x = torch.cat([x, skip_features[i]], dim=1)
|
||||
# Need to adjust block to accept extra channels? We'll create a separate block.
|
||||
# For now, we'll just pass through, but this will cause channel mismatch.
|
||||
# Instead, we should have created custom blocks with appropriate in_channels.
|
||||
# This is a placeholder; we need to implement properly.
|
||||
pass
|
||||
x = block(x)
|
||||
# 确保有4个skip features
|
||||
assert len(skip_features) == 4, f"Need 4 skip features, got {len(skip_features)}"
|
||||
|
||||
# 反转顺序以匹配解码器:stage3, stage2, stage1, stage0
|
||||
skip_features = skip_features[::-1]
|
||||
|
||||
# 调整skip特征的尺寸以匹配新的上采样策略
|
||||
adjusted_skip_features = []
|
||||
for i, skip in enumerate(skip_features):
|
||||
if skip is not None:
|
||||
# 计算目标尺寸:4, 2, 2, 2倍上采样
|
||||
upsample_factors = [4, 2, 2, 2]
|
||||
target_height = x.shape[2] * upsample_factors[i]
|
||||
target_width = x.shape[3] * upsample_factors[i]
|
||||
|
||||
if skip.shape[2:] != (target_height, target_width):
|
||||
skip = torch.nn.functional.interpolate(
|
||||
skip,
|
||||
size=(target_height, target_width),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
adjusted_skip_features.append(skip)
|
||||
|
||||
# 四个block使用skip connections
|
||||
for i in range(4):
|
||||
x = self.blocks[i](x, adjusted_skip_features[i])
|
||||
else:
|
||||
# 不使用skip connections
|
||||
for i in range(4):
|
||||
x = self.blocks[i](x)
|
||||
|
||||
# 最终输出层:只进行特征精炼,不上采样
|
||||
x = self.final_block(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -110,6 +316,7 @@ class SwiftFormerTemporal(nn.Module):
|
||||
model_name='XS',
|
||||
num_frames=3,
|
||||
use_decoder=True,
|
||||
use_skip=True, # 新增:是否使用skip connections
|
||||
use_representation_head=False,
|
||||
representation_dim=128,
|
||||
return_features=False,
|
||||
@@ -123,6 +330,7 @@ class SwiftFormerTemporal(nn.Module):
|
||||
# Store configuration
|
||||
self.num_frames = num_frames
|
||||
self.use_decoder = use_decoder
|
||||
self.use_skip = use_skip # 保存skip connections设置
|
||||
self.use_representation_head = use_representation_head
|
||||
self.return_features = return_features
|
||||
|
||||
@@ -155,7 +363,11 @@ class SwiftFormerTemporal(nn.Module):
|
||||
|
||||
# Frame prediction decoder
|
||||
if use_decoder:
|
||||
self.decoder = FramePredictionDecoder(embed_dims, output_channels=1)
|
||||
self.decoder = FramePredictionDecoder(
|
||||
embed_dims,
|
||||
output_channels=1,
|
||||
use_skip=use_skip # 传递skip connections设置
|
||||
)
|
||||
|
||||
# Representation head for pose/velocity prediction
|
||||
if use_representation_head:
|
||||
@@ -173,22 +385,31 @@ class SwiftFormerTemporal(nn.Module):
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
# 使用Kaiming初始化,适合ReLU/LeakyReLU
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.LayerNorm)):
|
||||
elif isinstance(m, nn.ConvTranspose2d):
|
||||
# 反卷积层使用特定的初始化
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def forward_tokens(self, x):
|
||||
"""Forward through encoder network, return list of stage features if return_features else final output"""
|
||||
if self.return_features:
|
||||
if self.return_features or self.use_skip:
|
||||
features = []
|
||||
stage_idx = 0
|
||||
for idx, block in enumerate(self.network):
|
||||
x = block(x)
|
||||
# Collect output after each stage (indices 0,2,4,6 correspond to stages)
|
||||
# 收集每个stage的输出(stage0, stage1, stage2, stage3)
|
||||
# 根据SwiftFormer结构,stage在索引0,2,4,6位置
|
||||
if idx in [0, 2, 4, 6]:
|
||||
features.append(x)
|
||||
stage_idx += 1
|
||||
return x, features
|
||||
else:
|
||||
for block in self.network:
|
||||
@@ -208,7 +429,7 @@ class SwiftFormerTemporal(nn.Module):
|
||||
"""
|
||||
# Encode
|
||||
x = self.patch_embed(x)
|
||||
if self.return_features:
|
||||
if self.return_features or self.use_skip:
|
||||
x, features = self.forward_tokens(x)
|
||||
else:
|
||||
x = self.forward_tokens(x)
|
||||
@@ -222,7 +443,23 @@ class SwiftFormerTemporal(nn.Module):
|
||||
# Decode to frame
|
||||
pred_frame = None
|
||||
if self.use_decoder:
|
||||
pred_frame = self.decoder(x)
|
||||
if self.use_skip:
|
||||
# 提取用于skip connections的特征
|
||||
# features包含所有stage的输出,我们需要stage0, stage1, stage2, stage3
|
||||
# 根据SwiftFormer结构,应该有4个stage特征
|
||||
if len(features) >= 4:
|
||||
# 取四个stage的特征:stage0, stage1, stage2, stage3
|
||||
skip_features = [features[0], features[1], features[2], features[3]]
|
||||
else:
|
||||
# 如果特征不够,使用可用的特征
|
||||
skip_features = features[:4]
|
||||
# 如果特征仍然不够,使用None填充
|
||||
while len(skip_features) < 4:
|
||||
skip_features.append(None)
|
||||
|
||||
pred_frame = self.decoder(x, skip_features)
|
||||
else:
|
||||
pred_frame = self.decoder(x)
|
||||
|
||||
if self.return_features:
|
||||
return pred_frame, representation, features
|
||||
@@ -231,14 +468,14 @@ class SwiftFormerTemporal(nn.Module):
|
||||
|
||||
|
||||
# Factory functions for different model sizes
|
||||
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
|
||||
return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
|
||||
def SwiftFormerTemporal_XS(num_frames=3, use_skip=True, **kwargs):
|
||||
return SwiftFormerTemporal('XS', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
||||
|
||||
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
|
||||
return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
|
||||
def SwiftFormerTemporal_S(num_frames=3, use_skip=True, **kwargs):
|
||||
return SwiftFormerTemporal('S', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
||||
|
||||
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
|
||||
return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
|
||||
def SwiftFormerTemporal_L1(num_frames=3, use_skip=True, **kwargs):
|
||||
return SwiftFormerTemporal('l1', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
||||
|
||||
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
|
||||
return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
|
||||
def SwiftFormerTemporal_L3(num_frames=3, use_skip=True, **kwargs):
|
||||
return SwiftFormerTemporal('l3', num_frames=num_frames, use_skip=use_skip, **kwargs)
|
||||
@@ -1,26 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Simple multi-GPU training script for SwiftFormerTemporal
|
||||
# Usage: ./multi_gpu_temporal_train.sh <NUM_GPUS> [OPTIONS]
|
||||
|
||||
NUM_GPUS=${1:-2}
|
||||
shift
|
||||
|
||||
echo "Starting multi-GPU training with $NUM_GPUS GPUs"
|
||||
|
||||
# Set environment variables for distributed training
|
||||
export MASTER_PORT=12345
|
||||
export MASTER_ADDR=localhost
|
||||
export WORLD_SIZE=$NUM_GPUS
|
||||
|
||||
# Launch training
|
||||
torchrun --nproc_per_node=$NUM_GPUS --master_port=$MASTER_PORT main_temporal.py \
|
||||
--data-path "./videos" \
|
||||
--model SwiftFormerTemporal_XS \
|
||||
--batch-size 32 \
|
||||
--epochs 100 \
|
||||
--lr 1e-3 \
|
||||
--output-dir "./temporal_output_multi" \
|
||||
--num-workers 8 \
|
||||
--pin-mem \
|
||||
"$@"
|
||||
45
test_cuda.py
45
test_cuda.py
@@ -1,45 +0,0 @@
|
||||
import torch
|
||||
|
||||
def test_cuda_availability():
|
||||
"""全面测试CUDA可用性"""
|
||||
|
||||
print("="*50)
|
||||
print("PyTorch CUDA 测试")
|
||||
print("="*50)
|
||||
|
||||
# 基本信息
|
||||
print(f"PyTorch版本: {torch.__version__}")
|
||||
print(f"CUDA可用: {torch.cuda.is_available()}")
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA不可用,可能原因:")
|
||||
print("1. 未安装CUDA驱动")
|
||||
print("2. 安装的是CPU版本的PyTorch")
|
||||
print("3. CUDA版本与PyTorch不匹配")
|
||||
return False
|
||||
|
||||
# 设备信息
|
||||
device_count = torch.cuda.device_count()
|
||||
print(f"发现 {device_count} 个CUDA设备")
|
||||
|
||||
for i in range(device_count):
|
||||
print(f"\n设备 {i}:")
|
||||
print(f" 名称: {torch.cuda.get_device_name(i)}")
|
||||
print(f" 内存总量: {torch.cuda.get_device_properties(i).total_memory / 1e9:.2f} GB")
|
||||
print(f" 计算能力: {torch.cuda.get_device_properties(i).major}.{torch.cuda.get_device_properties(i).minor}")
|
||||
|
||||
# 简单张量测试
|
||||
print("\n运行CUDA测试...")
|
||||
try:
|
||||
x = torch.randn(3, 3).cuda()
|
||||
y = torch.randn(3, 3).cuda()
|
||||
z = x + y
|
||||
print("CUDA计算测试: 成功!")
|
||||
print(f"设备上的张量形状: {z.shape}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"CUDA计算测试失败: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_cuda_availability()
|
||||
@@ -1,33 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
测试 timm 导入是否正常工作
|
||||
"""
|
||||
import sys
|
||||
print("Python version:", sys.version)
|
||||
|
||||
try:
|
||||
from timm.layers import to_2tuple, DropPath, trunc_normal_
|
||||
from timm.models import register_model
|
||||
print("✓ 成功导入 timm.layers.to_2tuple")
|
||||
print("✓ 成功导入 timm.layers.DropPath")
|
||||
print("✓ 成功导入 timm.layers.trunc_normal_")
|
||||
print("✓ 成功导入 timm.models.register_model")
|
||||
except ImportError as e:
|
||||
print(f"✗ 导入失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from models.swiftformer import SwiftFormer_XS
|
||||
print("✓ 成功导入 SwiftFormer_XS")
|
||||
except ImportError as e:
|
||||
print(f"✗ 导入 SwiftFormer_XS 失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS
|
||||
print("✓ 成功导入 SwiftFormerTemporal_XS")
|
||||
except ImportError as e:
|
||||
print(f"✗ 导入 SwiftFormerTemporal_XS 失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
print("\n✅ 所有导入测试通过!")
|
||||
@@ -1,60 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for SwiftFormerTemporal model
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add current directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from models.swiftformer_temporal import SwiftFormerTemporal_XS
|
||||
|
||||
def test_model():
|
||||
print("Testing SwiftFormerTemporal model...")
|
||||
|
||||
# Create model
|
||||
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
|
||||
print(f'Model created: {model.__class__.__name__}')
|
||||
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
|
||||
|
||||
# Test forward pass
|
||||
batch_size = 2
|
||||
num_frames = 3
|
||||
height = width = 224
|
||||
x = torch.randn(batch_size, 3 * num_frames, height, width)
|
||||
|
||||
print(f'\nInput shape: {x.shape}')
|
||||
|
||||
with torch.no_grad():
|
||||
pred_frame, representation = model(x)
|
||||
|
||||
print(f'Predicted frame shape: {pred_frame.shape}')
|
||||
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
|
||||
|
||||
# Check output ranges
|
||||
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
|
||||
|
||||
# Test loss function
|
||||
from util.frame_losses import MultiTaskLoss
|
||||
criterion = MultiTaskLoss()
|
||||
target = torch.randn_like(pred_frame)
|
||||
temporal_indices = torch.tensor([3, 3], dtype=torch.long)
|
||||
|
||||
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
|
||||
print(f'\nLoss test:')
|
||||
for k, v in loss_dict.items():
|
||||
print(f' {k}: {v:.4f}')
|
||||
|
||||
print('\nAll tests passed!')
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
test_model()
|
||||
except Exception as e:
|
||||
print(f'Test failed with error: {e}')
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -48,27 +48,39 @@ class VideoFrameDataset(Dataset):
|
||||
self.is_train = is_train
|
||||
self.max_interval = max_interval
|
||||
|
||||
# Collect all video folders
|
||||
# if num_frames < 1:
|
||||
# raise ValueError("num_frames must be >= 1")
|
||||
# if frame_size < 1:
|
||||
# raise ValueError("frame_size must be >= 1")
|
||||
# if max_interval < 1:
|
||||
# raise ValueError("max_interval must be >= 1")
|
||||
|
||||
# Collect all video folders and their frame files
|
||||
self.video_folders = []
|
||||
self.video_frame_files = [] # list of list of Path objects
|
||||
for item in self.root_dir.iterdir():
|
||||
if item.is_dir():
|
||||
self.video_folders.append(item)
|
||||
# Get all frame files
|
||||
frame_files = sorted([f for f in item.iterdir()
|
||||
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
|
||||
self.video_frame_files.append(frame_files)
|
||||
|
||||
if len(self.video_folders) == 0:
|
||||
raise ValueError(f"No video folders found in {root_dir}")
|
||||
|
||||
# Build frame index: list of (video_idx, start_frame_idx)
|
||||
self.frame_indices = []
|
||||
for video_idx, video_folder in enumerate(self.video_folders):
|
||||
# Get all frame files
|
||||
frame_files = sorted([f for f in video_folder.iterdir()
|
||||
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
|
||||
|
||||
if len(frame_files) < num_frames + 1:
|
||||
for video_idx, frame_files in enumerate(self.video_frame_files):
|
||||
# Minimum frames needed considering max interval
|
||||
min_frames_needed = num_frames * max_interval + 1
|
||||
if len(frame_files) < min_frames_needed:
|
||||
continue # Skip videos with insufficient frames
|
||||
|
||||
# Add all possible starting positions
|
||||
for start_idx in range(len(frame_files) - num_frames):
|
||||
# Ensure that for any interval up to max_interval, all frames are within bounds
|
||||
max_start = len(frame_files) - num_frames * max_interval
|
||||
for start_idx in range(max_start):
|
||||
self.frame_indices.append((video_idx, start_idx))
|
||||
|
||||
if len(self.frame_indices) == 0:
|
||||
@@ -80,14 +92,12 @@ class VideoFrameDataset(Dataset):
|
||||
else:
|
||||
self.transform = transform
|
||||
|
||||
# Normalization for Y channel (single channel)
|
||||
# Compute average of ImageNet RGB means and stds
|
||||
y_mean = (0.485 + 0.456 + 0.406) / 3.0
|
||||
y_std = (0.229 + 0.224 + 0.225) / 3.0
|
||||
self.normalize = transforms.Normalize(
|
||||
mean=[y_mean],
|
||||
std=[y_std]
|
||||
)
|
||||
# Simple normalization to [-1, 1] range (不使用ImageNet标准化)
|
||||
# Convert pixel values [0, 255] to [-1, 1]
|
||||
# This matches the model's tanh output range
|
||||
self.normalize = None # We'll handle normalization manually
|
||||
|
||||
# print(f"[数据集初始化] 使用简单归一化: 像素值[0,255] -> [-1,1]")
|
||||
|
||||
def _default_transform(self):
|
||||
"""Default transform with augmentation for training"""
|
||||
@@ -105,9 +115,12 @@ class VideoFrameDataset(Dataset):
|
||||
|
||||
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
|
||||
"""Load a single frame as PIL Image"""
|
||||
video_folder = self.video_folders[video_idx]
|
||||
frame_files = sorted([f for f in video_folder.iterdir()
|
||||
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
|
||||
frame_files = self.video_frame_files[video_idx]
|
||||
if frame_idx < 0 or frame_idx >= len(frame_files):
|
||||
raise IndexError(
|
||||
f"Frame index {frame_idx} out of range for video {video_idx} "
|
||||
f"(0-{len(frame_files)-1})"
|
||||
)
|
||||
frame_path = frame_files[frame_idx]
|
||||
return Image.open(frame_path).convert('RGB')
|
||||
|
||||
@@ -144,19 +157,21 @@ class VideoFrameDataset(Dataset):
|
||||
if self.transform:
|
||||
target_frame = self.transform(target_frame)
|
||||
|
||||
# Convert to tensors, normalize, and convert to grayscale (Y channel)
|
||||
# Convert to tensors and convert to grayscale (Y channel)
|
||||
input_tensors = []
|
||||
for frame in input_frames:
|
||||
tensor = transforms.ToTensor()(frame) # [3, H, W]
|
||||
tensor = transforms.ToTensor()(frame) # [3, H, W], range [0, 1]
|
||||
# Convert RGB to grayscale using weighted sum
|
||||
# Y = 0.2989 * R + 0.5870 * G + 0.1140 * B (same as PIL)
|
||||
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W]
|
||||
gray = self.normalize(gray) # normalize with single-channel stats (mean/std broadcast)
|
||||
gray = (0.2989 * tensor[0] + 0.5870 * tensor[1] + 0.1140 * tensor[2]).unsqueeze(0) # [1, H, W], range [0, 1]
|
||||
# Normalize from [0, 1] to [-1, 1]
|
||||
gray = gray * 2 - 1 # [0,1] -> [-1,1]
|
||||
input_tensors.append(gray)
|
||||
|
||||
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W]
|
||||
target_tensor = transforms.ToTensor()(target_frame) # [3, H, W], range [0, 1]
|
||||
target_gray = (0.2989 * target_tensor[0] + 0.5870 * target_tensor[1] + 0.1140 * target_tensor[2]).unsqueeze(0)
|
||||
target_gray = self.normalize(target_gray)
|
||||
# Normalize from [0, 1] to [-1, 1]
|
||||
target_gray = target_gray * 2 - 1 # [0,1] -> [-1,1]
|
||||
|
||||
# Concatenate input frames along channel dimension
|
||||
input_concatenated = torch.cat(input_tensors, dim=0) # [num_frames, H, W]
|
||||
@@ -167,52 +182,52 @@ class VideoFrameDataset(Dataset):
|
||||
return input_concatenated, target_gray, temporal_idx
|
||||
|
||||
|
||||
class SyntheticVideoDataset(Dataset):
|
||||
"""
|
||||
Synthetic dataset for testing - generates random frames
|
||||
"""
|
||||
def __init__(self,
|
||||
num_samples: int = 1000,
|
||||
num_frames: int = 3,
|
||||
frame_size: int = 224,
|
||||
is_train: bool = True):
|
||||
self.num_samples = num_samples
|
||||
self.num_frames = num_frames
|
||||
self.frame_size = frame_size
|
||||
self.is_train = is_train
|
||||
# class SyntheticVideoDataset(Dataset):
|
||||
# """
|
||||
# Synthetic dataset for testing - generates random frames
|
||||
# """
|
||||
# def __init__(self,
|
||||
# num_samples: int = 1000,
|
||||
# num_frames: int = 3,
|
||||
# frame_size: int = 224,
|
||||
# is_train: bool = True):
|
||||
# self.num_samples = num_samples
|
||||
# self.num_frames = num_frames
|
||||
# self.frame_size = frame_size
|
||||
# self.is_train = is_train
|
||||
|
||||
# Normalization for Y channel (single channel)
|
||||
y_mean = (0.485 + 0.456 + 0.406) / 3.0
|
||||
y_std = (0.229 + 0.224 + 0.225) / 3.0
|
||||
self.normalize = transforms.Normalize(
|
||||
mean=[y_mean],
|
||||
std=[y_std]
|
||||
)
|
||||
# # Normalization for Y channel (single channel)
|
||||
# y_mean = (0.485 + 0.456 + 0.406) / 3.0
|
||||
# y_std = (0.229 + 0.224 + 0.225) / 3.0
|
||||
# self.normalize = transforms.Normalize(
|
||||
# mean=[y_mean],
|
||||
# std=[y_std]
|
||||
# )
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
# def __len__(self):
|
||||
# return self.num_samples
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Generate random "frames" (noise with temporal correlation)
|
||||
input_frames = []
|
||||
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
|
||||
# def __getitem__(self, idx):
|
||||
# # Generate random "frames" (noise with temporal correlation)
|
||||
# input_frames = []
|
||||
# prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
|
||||
|
||||
for i in range(self.num_frames):
|
||||
# Add some temporal correlation
|
||||
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||
frame = torch.clamp(frame, -1, 1)
|
||||
input_frames.append(self.normalize(frame))
|
||||
prev_frame = frame
|
||||
# for i in range(self.num_frames):
|
||||
# # Add some temporal correlation
|
||||
# frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||
# frame = torch.clamp(frame, -1, 1)
|
||||
# input_frames.append(self.normalize(frame))
|
||||
# prev_frame = frame
|
||||
|
||||
# Target frame (next in sequence)
|
||||
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||
target_frame = torch.clamp(target_frame, -1, 1)
|
||||
target_tensor = self.normalize(target_frame)
|
||||
# # Target frame (next in sequence)
|
||||
# target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
|
||||
# target_frame = torch.clamp(target_frame, -1, 1)
|
||||
# target_tensor = self.normalize(target_frame)
|
||||
|
||||
# Concatenate inputs
|
||||
input_concatenated = torch.cat(input_frames, dim=0)
|
||||
# # Concatenate inputs
|
||||
# input_concatenated = torch.cat(input_frames, dim=0)
|
||||
|
||||
# Temporal index
|
||||
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
|
||||
# # Temporal index
|
||||
# temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
|
||||
|
||||
return input_concatenated, target_tensor, temporal_idx
|
||||
# return input_concatenated, target_tensor, temporal_idx
|
||||
Reference in New Issue
Block a user