60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for SwiftFormerTemporal model
|
|
"""
|
|
import torch
|
|
import sys
|
|
import os
|
|
|
|
# Add current directory to path
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from models.swiftformer_temporal import SwiftFormerTemporal_XS
|
|
|
|
def test_model():
    """Smoke-test SwiftFormerTemporal_XS with a forward pass and MultiTaskLoss.

    Builds the XS model with a representation head, then runs one forward pass
    on random input under ``torch.no_grad()`` and prints shapes and the output
    range. Finally it feeds the prediction into ``MultiTaskLoss`` against a
    random target and prints every entry of the returned loss dict.

    Returns:
        True once every step has completed and the sanity checks hold.

    Raises:
        AssertionError: if the prediction's first dimension does not equal the
            input batch size, or if the prediction or the total loss contains
            NaN/Inf values.
        Any exception from model construction, the forward pass or the loss
        computation propagates unchanged.
    """
    print("Testing SwiftFormerTemporal model...")

    # Input configuration, defined once so the model, the input tensor and the
    # per-sample temporal indices all agree.
    batch_size = 2
    num_frames = 3
    height = width = 224

    # Create model
    model = SwiftFormerTemporal_XS(num_frames=num_frames, use_representation_head=True)
    print(f'Model created: {model.__class__.__name__}')
    print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')

    # Test forward pass. The frames are stacked along the channel axis, giving
    # 3 * num_frames channels (presumably 3 colour channels per frame; confirm).
    x = torch.randn(batch_size, 3 * num_frames, height, width)
    print(f'\nInput shape: {x.shape}')

    with torch.no_grad():
        pred_frame, representation = model(x)

    print(f'Predicted frame shape: {pred_frame.shape}')
    print(f'Representation shape: {representation.shape if representation is not None else "None"}')

    # Minimal sanity checks, so "All tests passed!" reflects an actual check.
    assert pred_frame.shape[0] == batch_size, (
        f'Expected batch dimension {batch_size}, got {pred_frame.shape[0]}'
    )
    assert torch.isfinite(pred_frame).all(), 'Predicted frame contains NaN/Inf values'

    # Check output ranges
    print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')

    # Test loss function
    from util.frame_losses import MultiTaskLoss

    criterion = MultiTaskLoss()
    target = torch.randn_like(pred_frame)
    # One temporal index per sample, sized from batch_size so it stays valid
    # if the batch size changes. NOTE(review): the value 3 equals num_frames
    # here; confirm whether it is meant to track num_frames.
    temporal_indices = torch.full((batch_size,), 3, dtype=torch.long)

    loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
    # as_tensor accepts either a tensor or a plain Python number for the loss.
    assert torch.isfinite(torch.as_tensor(loss)).all(), 'Loss is NaN/Inf'

    print('\nLoss test:')
    for k, v in loss_dict.items():
        print(f' {k}: {v:.4f}')

    print('\nAll tests passed!')
    return True
|
|
|
|
if __name__ == '__main__':
    # Run the smoke test as a script. On any failure, report the error and the
    # full traceback, then exit with status 1 so shells/CI see the failure.
    try:
        test_model()
    except Exception as e:
        import traceback

        # The summary goes to stderr, the same stream traceback.print_exc()
        # uses, so the two stay together when stdout is redirected.
        print(f'Test failed with error: {e}', file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)