#!/usr/bin/env python3
"""
Test script for SwiftFormerTemporal model
"""
import os
import sys
import traceback

import torch

# Add current directory to path so the project-local packages below resolve
# when this script is run directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from models.swiftformer_temporal import SwiftFormerTemporal_XS  # noqa: E402


def test_model():
    """Smoke-test SwiftFormerTemporal_XS construction, forward pass and loss.

    Builds the XS model with a representation head, runs one forward pass on
    random input under ``torch.no_grad()``, prints the output shapes and the
    predicted-frame value range, then feeds the outputs to
    ``util.frame_losses.MultiTaskLoss`` and prints every entry of the returned
    loss dictionary.

    Returns:
        True once every step has completed; any failure propagates as an
        exception to the caller.
    """
    print("Testing SwiftFormerTemporal model...")

    # Single source of truth for the input geometry, so the model config,
    # the input tensor and the loss inputs cannot drift apart.
    batch_size = 2
    num_frames = 3
    height = width = 224
    # NOTE(review): value carried over from the original hard-coded indices;
    # it presumably equals num_frames (the frame after the inputs) — confirm.
    target_temporal_index = 3

    # Create model
    model = SwiftFormerTemporal_XS(num_frames=num_frames, use_representation_head=True)
    print(f'Model created: {model.__class__.__name__}')
    print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')

    # Test forward pass. Channel dim is 3 * num_frames — presumably the frames
    # are stacked along channels (3 channels each); confirm against the model.
    x = torch.randn(batch_size, 3 * num_frames, height, width)
    print(f'\nInput shape: {x.shape}')

    with torch.no_grad():
        pred_frame, representation = model(x)

    print(f'Predicted frame shape: {pred_frame.shape}')
    print(f'Representation shape: {representation.shape if representation is not None else "None"}')

    # Check output ranges
    print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')

    # Test loss function (local import: a failure here is reported by the
    # __main__ handler rather than aborting module import).
    from util.frame_losses import MultiTaskLoss
    criterion = MultiTaskLoss()
    target = torch.randn_like(pred_frame)
    # One index per sample; sized from batch_size instead of a fixed-length
    # literal so changing batch_size keeps the loss inputs consistent.
    temporal_indices = torch.full((batch_size,), target_temporal_index, dtype=torch.long)
    loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)

    print('\nLoss test:')
    for k, v in loss_dict.items():
        print(f'  {k}: {v:.4f}')

    print('\nAll tests passed!')
    return True


if __name__ == '__main__':
    try:
        test_model()
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and signal
        # it via a non-zero exit status.
        print(f'Test failed with error: {e}')
        traceback.print_exc()
        sys.exit(1)