Compare commits


10 Commits

| Author | SHA1 | Message | Date |
| ------ | ---- | ------- | ---- |
|  | 7e9564ef20 | test modify swiftformer to temporal input | 2026-01-07 11:03:33 +08:00 |
| Abdelrahman Shaker | 4aa6cd6752 | Create LICENSE | 2025-07-18 16:04:30 +04:00 |
| Abdelrahman Shaker | 898d23ca89 | Update README.md | 2024-01-12 17:00:03 +04:00 |
| Abdelrahman Shaker | 3daedbd499 | Merge pull request #15 from escorciav/main (Update README.md) | 2024-01-12 16:41:43 +04:00 |
| Victor Escorcia | 28ce806f55 | Update README.md (Community drive contributions: SwiftFormer meets Android. Qualcomm S8G2 DSP/HTP hardware, via Qualcomm tooling (QNN). Details in #14. Work done by @3scorciav. Refer to his fork for details.) | 2024-01-12 10:27:15 +00:00 |
| Abdelrahman Shaker | 9b7df0d145 | Merge pull request #12 from ThomasCai/main (Fix the issue when the distillation type is set to none.) | 2023-11-30 15:41:26 +04:00 |
| caitianren | 0ddadad723 | Fix this bug when setting distillation-type to none | 2023-11-29 20:15:00 +08:00 |
| Abdelrahman Shaker | cd1f854e59 | Update README.md | 2023-10-02 21:54:23 +02:00 |
| Abdelrahman Shaker | 5c9b4ceece | Update README.md | 2023-08-17 21:23:06 +04:00 |
| Abdelrahman Shaker | 7d5ca0c25b | Update README.md | 2023-08-10 18:54:53 +04:00 |
9 changed files with 1319 additions and 20 deletions

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

README.md

@@ -6,7 +6,7 @@
Mohamed Bin Zayed University of Artificial Intelligence<sup>1</sup>, University of California Merced<sup>2</sup>, Google Research<sup>3</sup>, Linkoping University<sup>4</sup>
<!-- [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](site_url) -->
[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2303.15446)
[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://openaccess.thecvf.com/content/ICCV2023/papers/Shaker_SwiftFormer_Efficient_Additive_Attention_for_Transformer-based_Real-time_Mobile_Vision_Applications_ICCV_2023_paper.pdf)
<!-- [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](youtube_link) -->
<!-- [![slides](https://img.shields.io/badge/Presentation-Slides-B762C1)](presentation) -->
@@ -64,6 +64,28 @@ Self-attention has become a defacto choice for capturing global context in vario
The latency reported in SwiftFormer for iPhone 14 (iOS 16) uses the benchmark tool from [XCode 14](https://developer.apple.com/videos/play/wwdc2022/10027/).
### SwiftFormer meets Android
Community-driven results with [Samsung Galaxy S23 Ultra, with Qualcomm Snapdragon 8 Gen 2](https://www.qualcomm.com/snapdragon/device-finder/samsung-galaxy-s23-ultra):
1. [Export](https://github.com/escorciav/SwiftFormer/blob/main-v/export.py) & profiler results of [`SwiftFormer_L1`](./models/swiftformer.py):
| QNN | 2.16 | 2.17 | 2.18 |
| -------------- | -----| ----- | ------ |
| Latency (msec) | 2.63 | 2.26 | 2.43 |
2. [Export](https://github.com/escorciav/SwiftFormer/blob/main-v/export_block.py) & profiler results of SwiftFormerEncoder block:
| QNN | 2.16 | 2.17 | 2.18 |
| -------------- | -----| ----- | ------ |
| Latency (msec) | 2.17 | 1.69 | 1.7 |
Refer to the script above for details of the input & block parameters.
_Interested in reproducing the results above?_
Refer to [Issue #14](https://github.com/Amshaker/SwiftFormer/issues/14) for details about [exporting & profiling.](https://github.com/Amshaker/SwiftFormer/issues/14#issuecomment-1883351728)
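For reference, a minimal export sketch is shown below. It assumes `SwiftFormer_L1` from [`./models/swiftformer.py`](./models/swiftformer.py) and a generic ONNX target; the actual [export script](https://github.com/escorciav/SwiftFormer/blob/main-v/export.py) in the linked fork may add QNN-specific steps.
```python
# Hedged sketch only: export SwiftFormer_L1 to ONNX for on-device profiling.
# The real export.py / export_block.py in the fork above may differ.
import torch
from models.swiftformer import SwiftFormer_L1

model = SwiftFormer_L1().eval()
dummy = torch.randn(1, 3, 224, 224)  # standard ImageNet-sized input
torch.onnx.export(
    model, dummy, "swiftformer_l1.onnx",
    input_names=["images"], output_names=["logits"],
    opset_version=13,
)
```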
## ImageNet
### Prerequisites
@@ -78,7 +100,7 @@ pip install timm
pip install coremltools==5.2.0
```
-### Data preparation
+### Data Preparation
Download and extract ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` folder and `val` folder respectively:
```
@@ -87,7 +109,7 @@ Download and extract ImageNet train and val images from http://image-net.org. Th
|-- val
```
-### Single machine multi-GPU training
+### Single-machine multi-GPU training
We provide training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP).
@@ -107,7 +129,7 @@ On a Slurm-managed cluster, multi-node training can be launched as
sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS
```
-Note: specify slurm specific paramters in `slurm_train.sh` script.
+Note: specify slurm specific parameters in `slurm_train.sh` script.
### Testing
@@ -121,20 +143,22 @@ sh dist_test.sh SwiftFormer_XS 8 weights/SwiftFormer_XS_ckpt.pth
## Citation
if you use our work, please consider citing us:
```BibTeX
-@article{Shaker2023SwiftFormer,
-title={SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
-journal={arXiv:2303.15446},
-year={2023}
+@InProceedings{Shaker_2023_ICCV,
+author = {Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
+title = {SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
+booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
+year = {2023},
}
```
## Contact:
-If you have any question, please create an issue on this repository or contact at abdelrahman.youssief@mbzuai.ac.ae.
+If you have any questions, please create an issue on this repository or contact at abdelrahman.youssief@mbzuai.ac.ae.
## Acknowledgement
-Our code base is based on [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank authors for their open-source implementation.
+Our code base is based on [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank the authors for their open-source implementation.
I'd like to express my sincere appreciation to [Victor Escorcia](https://github.com/escorciav) for measuring & reporting the latency of SwiftFormer on Android (Samsung Galaxy S23 Ultra, with Qualcomm Snapdragon 8 Gen 2). Check [SwiftFormer Meets Android](https://github.com/escorciav/SwiftFormer) for more details!
## Our Related Works

373
main_temporal.py Normal file

@@ -0,0 +1,373 @@
"""
Main training script for SwiftFormerTemporal frame prediction
"""
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
from pathlib import Path
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from util import *
# NOTE: the `utils.*` helpers used below (init_distributed_mode, MetricLogger,
# save_on_master, ...) are assumed to live in the repo's top-level utils module.
import utils
from models import *
from models.swiftformer_temporal import SwiftFormerTemporal_XS, SwiftFormerTemporal_S, SwiftFormerTemporal_L1, SwiftFormerTemporal_L3
from util.video_dataset import VideoFrameDataset, SyntheticVideoDataset
from util.frame_losses import MultiTaskLoss
def get_args_parser():
parser = argparse.ArgumentParser(
'SwiftFormerTemporal training script', add_help=False)
# Dataset parameters
parser.add_argument('--data-path', default='./videos', type=str,
help='Path to video dataset')
parser.add_argument('--dataset-type', default='video', choices=['video', 'synthetic'],
type=str, help='Dataset type')
parser.add_argument('--num-frames', default=3, type=int,
help='Number of input frames (T)')
parser.add_argument('--frame-size', default=224, type=int,
help='Input frame size')
parser.add_argument('--max-interval', default=1, type=int,
help='Maximum interval between consecutive frames')
# Model parameters
parser.add_argument('--model', default='SwiftFormerTemporal_XS', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--use-representation-head', action='store_true',
help='Use representation head for pose/velocity prediction')
parser.add_argument('--representation-dim', default=128, type=int,
help='Dimension of representation vector')
    # Training parameters
    parser.add_argument('--batch-size', default=32, type=int)
    parser.add_argument('--epochs', default=100, type=int)
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Optimizer / LR-schedule arguments read by timm's create_optimizer and
    # create_scheduler (defaults below are common choices, adjust as needed)
    parser.add_argument('--opt', default='adamw', type=str, help='optimizer name')
    parser.add_argument('--momentum', default=0.9, type=float, help='optimizer momentum')
    parser.add_argument('--sched', default='cosine', type=str, help='LR schedule')
    parser.add_argument('--min-lr', default=1e-5, type=float, help='lower LR bound')
    parser.add_argument('--warmup-lr', default=1e-6, type=float, help='warmup starting LR')
    parser.add_argument('--warmup-epochs', default=5, type=int, help='LR warmup epochs')
    parser.add_argument('--decay-rate', default=0.1, type=float, help='LR decay rate')
    parser.add_argument('--cooldown-epochs', default=0, type=int, help='epochs at min LR after schedule ends')
# Loss parameters
parser.add_argument('--frame-weight', type=float, default=1.0,
help='Weight for frame prediction loss')
parser.add_argument('--contrastive-weight', type=float, default=0.1,
help='Weight for contrastive loss')
parser.add_argument('--l1-weight', type=float, default=1.0,
help='Weight for L1 loss')
parser.add_argument('--ssim-weight', type=float, default=0.1,
help='Weight for SSIM loss')
parser.add_argument('--no-contrastive', action='store_true',
help='Disable contrastive loss')
parser.add_argument('--no-ssim', action='store_true',
help='Disable SSIM loss')
# System parameters
parser.add_argument('--output-dir', default='./output',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--num-workers', default=4, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# Distributed training
parser.add_argument('--world-size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist-url', default='env://',
help='url used to set up distributed training')
return parser
def build_dataset(is_train, args):
"""Build video frame dataset"""
if args.dataset_type == 'synthetic':
dataset = SyntheticVideoDataset(
num_samples=1000 if is_train else 200,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train
)
else:
dataset = VideoFrameDataset(
root_dir=args.data_path,
num_frames=args.num_frames,
frame_size=args.frame_size,
is_train=is_train,
max_interval=args.max_interval
)
return dataset
def main(args):
utils.init_distributed_mode(args)
print(args)
device = torch.device(args.device)
# Fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
# Build datasets
dataset_train = build_dataset(is_train=True, args=args)
dataset_val = build_dataset(is_train=False, args=args)
# Create samplers
if args.distributed:
sampler_train = torch.utils.data.DistributedSampler(dataset_train)
sampler_val = torch.utils.data.DistributedSampler(dataset_val, shuffle=False)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
# Create model
print(f"Creating model: {args.model}")
model_kwargs = {
'num_frames': args.num_frames,
'use_representation_head': args.use_representation_head,
'representation_dim': args.representation_dim,
}
if args.model == 'SwiftFormerTemporal_XS':
model = SwiftFormerTemporal_XS(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_S':
model = SwiftFormerTemporal_S(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L1':
model = SwiftFormerTemporal_L1(**model_kwargs)
elif args.model == 'SwiftFormerTemporal_L3':
model = SwiftFormerTemporal_L3(**model_kwargs)
else:
raise ValueError(f"Unknown model: {args.model}")
model.to(device)
# Model EMA
model_ema = None
if hasattr(args, 'model_ema') and args.model_ema:
model_ema = ModelEma(
model,
decay=args.model_ema_decay if hasattr(args, 'model_ema_decay') else 0.9999,
device='cpu' if hasattr(args, 'model_ema_force_cpu') and args.model_ema_force_cpu else '',
resume='')
# Distributed training
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {n_parameters}')
# Create optimizer
optimizer = create_optimizer(args, model_without_ddp)
# Create loss scaler
loss_scaler = NativeScaler()
# Create scheduler
lr_scheduler, _ = create_scheduler(args, optimizer)
# Create loss function
criterion = MultiTaskLoss(
frame_weight=args.frame_weight,
contrastive_weight=args.contrastive_weight,
l1_weight=args.l1_weight,
ssim_weight=args.ssim_weight,
use_contrastive=not args.no_contrastive
)
# Resume from checkpoint
output_dir = Path(args.output_dir)
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if model_ema is not None:
utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
if args.eval:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Test stats: {test_stats}")
return
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
model_ema=model_ema
)
lr_scheduler.step(epoch)
# Save checkpoint
if args.output_dir and (epoch % 10 == 0 or epoch == args.epochs - 1):
checkpoint_path = output_dir / f'checkpoint_epoch{epoch}.pth'
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema) if model_ema else None,
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
# Evaluate
if epoch % 5 == 0 or epoch == args.epochs - 1:
test_stats = evaluate(data_loader_val, model, criterion, device)
print(f"Epoch {epoch}: Test stats: {test_stats}")
# Log stats
log_stats = {
**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters
}
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f'Training time {total_time_str}')
def train_one_epoch(model, criterion, data_loader, optimizer, device, epoch, loss_scaler,
                    clip_grad=None, clip_mode='norm', model_ema=None, **kwargs):
    # clip_grad=None disables clipping; passing 0 would clip gradients to zero norm
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = f'Epoch: [{epoch}]'
print_freq = 10
for input_frames, target_frames, temporal_indices in metric_logger.log_every(
data_loader, print_freq, header):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Forward pass
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
representations, temporal_indices
)
loss_value = loss.item()
if not torch.isfinite(torch.tensor(loss_value)):
print(f"Loss is {loss_value}, stopping training")
raise ValueError(f"Loss is {loss_value}")
optimizer.zero_grad()
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters())
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
# Update metrics
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, criterion, device):
model.eval()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
for input_frames, target_frames, temporal_indices in metric_logger.log_every(data_loader, 10, header):
input_frames = input_frames.to(device, non_blocking=True)
target_frames = target_frames.to(device, non_blocking=True)
temporal_indices = temporal_indices.to(device, non_blocking=True)
# Compute output
with torch.cuda.amp.autocast():
pred_frames, representations = model(input_frames)
loss, loss_dict = criterion(
pred_frames, target_frames,
representations, temporal_indices
)
# Update metrics
metric_logger.update(loss=loss.item())
for k, v in loss_dict.items():
metric_logger.update(**{k: v.item() if torch.is_tensor(v) else v})
metric_logger.synchronize_between_processes()
print('* Test stats:', metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'SwiftFormerTemporal training script', parents=[get_args_parser()])
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
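As a quick way to exercise the script end to end, the snippet below drives its own argument parser with the synthetic dataset, so no video data is required. It is a sketch only, assuming the repo root is on `PYTHONPATH`, a single CUDA device, and the optimizer/scheduler arguments defined in `get_args_parser`.
```python
# Smoke-test sketch for main_temporal.py on the synthetic dataset.
from pathlib import Path
from main_temporal import get_args_parser, main

args = get_args_parser().parse_args([
    '--dataset-type', 'synthetic',
    '--model', 'SwiftFormerTemporal_XS',
    '--batch-size', '4',
    '--epochs', '1',
    '--output-dir', './output_smoke',
])
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
```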

models/__init__.py

@@ -1 +1,7 @@
from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3
from .swiftformer_temporal import (
SwiftFormerTemporal_XS,
SwiftFormerTemporal_S,
SwiftFormerTemporal_L1,
SwiftFormerTemporal_L3
)

models/swiftformer.py

@@ -437,7 +437,7 @@ class SwiftFormer(nn.Module):
if not self.training:
cls_out = (cls_out[0] + cls_out[1]) / 2
else:
-cls_out = self.head(x.mean(-2))
+cls_out = self.head(x.flatten(2).mean(-1))
# For image classification
return cls_out
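The replaced line averaged over only one spatial axis (`x.mean(-2)`); the new line performs global average pooling over the whole H x W grid before the classification head. A small illustration with hypothetical shapes:
```python
# Illustration only (hypothetical feature-map shapes, not the repo's API).
import torch

x = torch.randn(2, 220, 7, 7)      # [B, C, H, W] final feature map
pooled = x.flatten(2).mean(-1)     # flatten H*W, then average -> [B, C]
assert pooled.shape == (2, 220)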

244
models/swiftformer_temporal.py Normal file

@@ -0,0 +1,244 @@
"""
SwiftFormerTemporal: Temporal extension of SwiftFormer for frame prediction
"""
import torch
import torch.nn as nn
from .swiftformer import (
SwiftFormer, SwiftFormer_depth, SwiftFormer_width,
stem, Embedding, Stage
)
from timm.models.layers import DropPath, trunc_normal_
class DecoderBlock(nn.Module):
"""Upsampling block for frame prediction decoder"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1):
super().__init__()
self.conv = nn.ConvTranspose2d(
in_channels, out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
class FramePredictionDecoder(nn.Module):
"""Lightweight decoder for frame prediction with optional skip connections"""
def __init__(self, embed_dims, output_channels=3, use_skip=False):
super().__init__()
self.use_skip = use_skip
# Reverse the embed_dims for decoder
decoder_dims = embed_dims[::-1]
self.blocks = nn.ModuleList()
# First upsampling from bottleneck to stage4 resolution
self.blocks.append(DecoderBlock(
decoder_dims[0], decoder_dims[1],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage4 to stage3
self.blocks.append(DecoderBlock(
decoder_dims[1], decoder_dims[2],
kernel_size=3, stride=2, padding=1, output_padding=1
))
# stage3 to stage2
self.blocks.append(DecoderBlock(
decoder_dims[2], decoder_dims[3],
kernel_size=3, stride=2, padding=1, output_padding=1
))
        # stage2 resolution (H/4) back to the original resolution: two more 2x
        # upsampling steps (4x in total) so the output matches the input H x W
        self.blocks.append(nn.Sequential(
            nn.ConvTranspose2d(
                decoder_dims[3], 32,
                kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(
                32, 32,
                kernel_size=3, stride=2, padding=1, output_padding=1
            ),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, output_channels, kernel_size=3, padding=1),
            nn.Tanh()  # Output in [-1, 1] range
        ))
        # Skip connections are fused additively (see forward); the encoder stage
        # outputs have the same channel counts and spatial sizes as the outputs of
        # the first three decoder blocks, so no extra channel adjustment is needed.

    def forward(self, x, skip_features=None):
        """
        Args:
            x: bottleneck features of shape [B, embed_dims[-1], H/32, W/32]
            skip_features: optional list of encoder features ordered shallow to
                deep [stage0 (H/4), stage1 (H/8), stage2 (H/16)]; their channel
                counts match the outputs of the first three decoder blocks.
        """
        if self.use_skip and skip_features is not None:
            assert len(skip_features) == 3, "Need 3 skip features for skip connections"
            # Reverse to deep-to-shallow order so skip_features[i] matches the
            # output of decoder block i: stage2 (H/16), stage1 (H/8), stage0 (H/4)
            skip_features = skip_features[::-1]
        for i, block in enumerate(self.blocks):
            x = block(x)
            if self.use_skip and skip_features is not None and i < 3:
                # Additive fusion; channel and spatial sizes match by construction
                x = x + skip_features[i]
        return x
class SwiftFormerTemporal(nn.Module):
"""
SwiftFormer with temporal input for frame prediction.
Input: [B, num_frames, H, W] (Y channel only)
Output: predicted frame [B, 3, H, W] and optional representation
"""
def __init__(self,
model_name='XS',
num_frames=3,
use_decoder=True,
use_representation_head=False,
representation_dim=128,
return_features=False,
**kwargs):
super().__init__()
# Get model configuration
layers = SwiftFormer_depth[model_name]
embed_dims = SwiftFormer_width[model_name]
# Store configuration
self.num_frames = num_frames
self.use_decoder = use_decoder
self.use_representation_head = use_representation_head
self.return_features = return_features
# Modify stem to accept multiple frames (only Y channel)
in_channels = num_frames
self.patch_embed = stem(in_channels, embed_dims[0])
# Build encoder network (same as SwiftFormer)
network = []
for i in range(len(layers)):
stage = Stage(embed_dims[i], i, layers, mlp_ratio=4,
act_layer=nn.GELU,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True,
layer_scale_init_value=1e-5,
vit_num=1)
network.append(stage)
if i >= len(layers) - 1:
break
if embed_dims[i] != embed_dims[i + 1]:
network.append(
Embedding(
patch_size=3, stride=2, padding=1,
in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
)
)
self.network = nn.ModuleList(network)
self.norm = nn.BatchNorm2d(embed_dims[-1])
# Frame prediction decoder
if use_decoder:
self.decoder = FramePredictionDecoder(embed_dims, output_channels=3)
# Representation head for pose/velocity prediction
if use_representation_head:
self.representation_head = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(embed_dims[-1], representation_dim),
nn.ReLU(),
nn.Linear(representation_dim, representation_dim)
)
else:
self.representation_head = None
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, (nn.LayerNorm)):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward_tokens(self, x):
"""Forward through encoder network, return list of stage features if return_features else final output"""
if self.return_features:
features = []
for idx, block in enumerate(self.network):
x = block(x)
# Collect output after each stage (indices 0,2,4,6 correspond to stages)
if idx in [0, 2, 4, 6]:
features.append(x)
return x, features
else:
for block in self.network:
x = block(x)
return x
def forward(self, x):
"""
Args:
x: input frames of shape [B, num_frames, H, W]
Returns:
If return_features is False:
pred_frame: predicted frame [B, 3, H, W] (or None)
representation: optional representation vector [B, representation_dim] (or None)
If return_features is True:
pred_frame, representation, features (list of stage features)
"""
# Encode
x = self.patch_embed(x)
if self.return_features:
x, features = self.forward_tokens(x)
else:
x = self.forward_tokens(x)
x = self.norm(x)
# Get representation if needed
representation = None
if self.representation_head is not None:
representation = self.representation_head(x)
# Decode to frame
pred_frame = None
if self.use_decoder:
pred_frame = self.decoder(x)
if self.return_features:
return pred_frame, representation, features
else:
return pred_frame, representation
# Factory functions for different model sizes
def SwiftFormerTemporal_XS(num_frames=3, **kwargs):
return SwiftFormerTemporal('XS', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_S(num_frames=3, **kwargs):
return SwiftFormerTemporal('S', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L1(num_frames=3, **kwargs):
return SwiftFormerTemporal('l1', num_frames=num_frames, **kwargs)
def SwiftFormerTemporal_L3(num_frames=3, **kwargs):
return SwiftFormerTemporal('l3', num_frames=num_frames, **kwargs)
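A quick shape check for the temporal model is sketched below, assuming the module layout above and the decoder's full-resolution output; `return_features=True` exposes the encoder pyramid that the optional skip connections consume.
```python
# Shape-check sketch for SwiftFormerTemporal (assumes models/ is importable).
import torch
from models.swiftformer_temporal import SwiftFormerTemporal_XS

model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True,
                               return_features=True).eval()
x = torch.randn(1, 3, 224, 224)  # [B, num_frames, H, W]: one Y channel per frame
with torch.no_grad():
    pred_frame, representation, features = model(x)

print(pred_frame.shape)                 # expected torch.Size([1, 3, 224, 224])
print(representation.shape)             # expected torch.Size([1, 128])
print([f.shape[-1] for f in features])  # stage resolutions, e.g. [56, 28, 14, 7]
```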

60
test_model.py Normal file

@@ -0,0 +1,60 @@
#!/usr/bin/env python3
"""
Test script for SwiftFormerTemporal model
"""
import torch
import sys
import os
# Add current directory to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from models.swiftformer_temporal import SwiftFormerTemporal_XS
def test_model():
print("Testing SwiftFormerTemporal model...")
# Create model
model = SwiftFormerTemporal_XS(num_frames=3, use_representation_head=True)
print(f'Model created: {model.__class__.__name__}')
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):,}')
# Test forward pass
batch_size = 2
num_frames = 3
height = width = 224
    # The temporal stem takes one channel per frame, i.e. input shape [B, num_frames, H, W]
    x = torch.randn(batch_size, num_frames, height, width)
print(f'\nInput shape: {x.shape}')
with torch.no_grad():
pred_frame, representation = model(x)
print(f'Predicted frame shape: {pred_frame.shape}')
print(f'Representation shape: {representation.shape if representation is not None else "None"}')
# Check output ranges
print(f'\nPredicted frame range: [{pred_frame.min():.3f}, {pred_frame.max():.3f}]')
# Test loss function
from util.frame_losses import MultiTaskLoss
criterion = MultiTaskLoss()
target = torch.randn_like(pred_frame)
temporal_indices = torch.tensor([3, 3], dtype=torch.long)
loss, loss_dict = criterion(pred_frame, target, representation, temporal_indices)
print(f'\nLoss test:')
for k, v in loss_dict.items():
print(f' {k}: {v:.4f}')
print('\nAll tests passed!')
return True
if __name__ == '__main__':
try:
test_model()
except Exception as e:
print(f'Test failed with error: {e}')
import traceback
traceback.print_exc()
sys.exit(1)

182
util/frame_losses.py Normal file

@@ -0,0 +1,182 @@
"""
Loss functions for frame prediction and representation learning
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SSIMLoss(nn.Module):
"""
Structural Similarity Index Measure Loss
Based on: https://github.com/Po-Hsun-Su/pytorch-ssim
"""
def __init__(self, window_size=11, size_average=True):
super().__init__()
self.window_size = window_size
self.size_average = size_average
self.channel = 3
self.window = self.create_window(window_size, self.channel)
def create_window(self, window_size, channel):
def gaussian(window_size, sigma):
gauss = torch.Tensor([math.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
return window
def forward(self, img1, img2):
# Ensure window is on correct device
if self.window.device != img1.device:
self.window = self.window.to(img1.device)
mu1 = F.conv2d(img1, self.window, padding=self.window_size//2, groups=self.channel)
mu2 = F.conv2d(img2, self.window, padding=self.window_size//2, groups=self.channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1*img1, self.window, padding=self.window_size//2, groups=self.channel) - mu1_sq
sigma2_sq = F.conv2d(img2*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu2_sq
sigma12 = F.conv2d(img1*img2, self.window, padding=self.window_size//2, groups=self.channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2)) / ((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
if self.size_average:
return 1 - ssim_map.mean()
else:
return 1 - ssim_map.mean(1).mean(1).mean(1)
class FramePredictionLoss(nn.Module):
"""
Combined loss for frame prediction
"""
def __init__(self, l1_weight=1.0, ssim_weight=0.1, use_ssim=True):
super().__init__()
self.l1_weight = l1_weight
self.ssim_weight = ssim_weight
self.use_ssim = use_ssim
self.l1_loss = nn.L1Loss()
if use_ssim:
self.ssim_loss = SSIMLoss()
def forward(self, pred, target):
"""
Args:
pred: predicted frame [B, 3, H, W] in range [-1, 1]
target: target frame [B, 3, H, W] in range [-1, 1]
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# L1 loss
l1_loss = self.l1_loss(pred, target)
loss_dict['l1'] = l1_loss
total_loss = self.l1_weight * l1_loss
# SSIM loss
if self.use_ssim:
ssim_loss = self.ssim_loss(pred, target)
loss_dict['ssim'] = ssim_loss
total_loss += self.ssim_weight * ssim_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
class ContrastiveLoss(nn.Module):
"""
Contrastive loss for representation learning
Positive pairs: representations from adjacent frames
Negative pairs: representations from distant frames
"""
def __init__(self, temperature=0.1, margin=1.0):
super().__init__()
self.temperature = temperature
self.margin = margin
self.cosine_similarity = nn.CosineSimilarity(dim=-1)
def forward(self, representations, temporal_indices):
"""
Args:
representations: [B, D] representation vectors
temporal_indices: [B] temporal indices of each sample
Returns:
contrastive_loss
"""
batch_size = representations.size(0)
# Compute similarity matrix
sim_matrix = torch.matmul(representations, representations.T) / self.temperature
# Create positive mask (adjacent frames)
indices_expanded = temporal_indices.unsqueeze(0)
diff = torch.abs(indices_expanded - indices_expanded.T)
positive_mask = (diff == 1).float()
# Create negative mask (distant frames)
negative_mask = (diff > 2).float()
# Positive loss
pos_sim = sim_matrix * positive_mask
pos_loss = -torch.log(torch.exp(pos_sim) / torch.exp(sim_matrix).sum(dim=-1, keepdim=True) + 1e-8)
pos_loss = (pos_loss * positive_mask).sum() / (positive_mask.sum() + 1e-8)
# Negative loss (push apart)
neg_sim = sim_matrix * negative_mask
neg_loss = torch.relu(neg_sim - self.margin).mean()
return pos_loss + 0.1 * neg_loss
class MultiTaskLoss(nn.Module):
"""
Multi-task loss combining frame prediction and representation learning
"""
def __init__(self, frame_weight=1.0, contrastive_weight=0.1,
l1_weight=1.0, ssim_weight=0.1, use_contrastive=True):
super().__init__()
self.frame_weight = frame_weight
self.contrastive_weight = contrastive_weight
self.use_contrastive = use_contrastive
self.frame_loss = FramePredictionLoss(l1_weight=l1_weight, ssim_weight=ssim_weight)
if use_contrastive:
self.contrastive_loss = ContrastiveLoss()
def forward(self, pred_frame, target_frame, representations=None, temporal_indices=None):
"""
Args:
pred_frame: predicted frame [B, 3, H, W]
target_frame: target frame [B, 3, H, W]
representations: [B, D] representation vectors (optional)
temporal_indices: [B] temporal indices (optional)
Returns:
total_loss, loss_dict
"""
loss_dict = {}
# Frame prediction loss
frame_loss, frame_loss_dict = self.frame_loss(pred_frame, target_frame)
loss_dict.update({f'frame_{k}': v for k, v in frame_loss_dict.items()})
total_loss = self.frame_weight * frame_loss
# Contrastive loss (if representations provided)
if self.use_contrastive and representations is not None and temporal_indices is not None:
contrastive_loss = self.contrastive_loss(representations, temporal_indices)
loss_dict['contrastive'] = contrastive_loss
total_loss += self.contrastive_weight * contrastive_loss
loss_dict['total'] = total_loss
return total_loss, loss_dict
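A minimal usage sketch with random tensors, just to show the keys that `MultiTaskLoss` reports (the values themselves are meaningless here):
```python
# Usage sketch for MultiTaskLoss on random tensors (no real data or model).
import torch
from util.frame_losses import MultiTaskLoss

criterion = MultiTaskLoss(frame_weight=1.0, contrastive_weight=0.1)
pred = torch.tanh(torch.randn(4, 3, 224, 224))    # decoder output range [-1, 1]
target = torch.tanh(torch.randn(4, 3, 224, 224))
representations = torch.randn(4, 128)
temporal_indices = torch.tensor([3, 4, 5, 8])     # adjacent pairs -> positives, gap > 2 -> negatives
loss, loss_dict = criterion(pred, target, representations, temporal_indices)
print(sorted(loss_dict.keys()))  # ['contrastive', 'frame_l1', 'frame_ssim', 'frame_total', 'total']
```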

209
util/video_dataset.py Normal file

@@ -0,0 +1,209 @@
"""
Video frame dataset for temporal self-supervised learning
"""
import os
import random
from pathlib import Path
from typing import Optional, Tuple, List
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
class VideoFrameDataset(Dataset):
"""
Dataset for loading consecutive frames from videos for frame prediction.
Assumes directory structure:
dataset_root/
video1/
frame_0001.jpg
frame_0002.jpg
...
video2/
...
"""
def __init__(self,
root_dir: str,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True,
max_interval: int = 1,
transform=None):
"""
Args:
root_dir: Root directory containing video folders
num_frames: Number of input frames (T)
frame_size: Size to resize frames to
is_train: Whether this is training set (affects augmentation)
max_interval: Maximum interval between consecutive frames
transform: Optional custom transform
"""
self.root_dir = Path(root_dir)
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
self.max_interval = max_interval
# Collect all video folders
self.video_folders = []
for item in self.root_dir.iterdir():
if item.is_dir():
self.video_folders.append(item)
if len(self.video_folders) == 0:
raise ValueError(f"No video folders found in {root_dir}")
# Build frame index: list of (video_idx, start_frame_idx)
self.frame_indices = []
for video_idx, video_folder in enumerate(self.video_folders):
# Get all frame files
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
            # Require room for num_frames input frames plus the target frame at
            # the largest sampling interval so __getitem__ never runs past the end
            if len(frame_files) < num_frames * max_interval + 1:
                continue  # Skip videos with insufficient frames
            # Add all valid starting positions
            for start_idx in range(len(frame_files) - num_frames * max_interval):
                self.frame_indices.append((video_idx, start_idx))
if len(self.frame_indices) == 0:
raise ValueError("No valid frame sequences found in dataset")
# Default transforms
if transform is None:
self.transform = self._default_transform()
else:
self.transform = transform
# Normalization (ImageNet stats)
self.normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
def _default_transform(self):
"""Default transform with augmentation for training"""
if self.is_train:
return transforms.Compose([
transforms.RandomResizedCrop(self.frame_size, scale=(0.8, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
])
else:
return transforms.Compose([
transforms.Resize(int(self.frame_size * 1.14)),
transforms.CenterCrop(self.frame_size),
])
def _load_frame(self, video_idx: int, frame_idx: int) -> Image.Image:
"""Load a single frame as PIL Image"""
video_folder = self.video_folders[video_idx]
frame_files = sorted([f for f in video_folder.iterdir()
if f.suffix.lower() in ['.jpg', '.jpeg', '.png', '.bmp']])
frame_path = frame_files[frame_idx]
return Image.open(frame_path).convert('RGB')
def __len__(self) -> int:
return len(self.frame_indices)
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Returns:
input_frames: [3 * num_frames, H, W] concatenated input frames
target_frame: [3, H, W] target frame to predict
temporal_idx: temporal index of target frame (for contrastive loss)
"""
video_idx, start_idx = self.frame_indices[idx]
# Determine frame interval (for temporal augmentation)
interval = random.randint(1, self.max_interval) if self.is_train else 1
# Load input frames
input_frames = []
for i in range(self.num_frames):
frame_idx = start_idx + i * interval
frame = self._load_frame(video_idx, frame_idx)
            # Apply transform (note: random augmentation parameters are re-sampled
            # for each frame, so crops/flips are not guaranteed to match across frames)
if self.transform:
frame = self.transform(frame)
input_frames.append(frame)
# Load target frame (next frame after input sequence)
target_idx = start_idx + self.num_frames * interval
target_frame = self._load_frame(video_idx, target_idx)
if self.transform:
target_frame = self.transform(target_frame)
# Convert to tensors and normalize
input_tensors = []
for frame in input_frames:
tensor = transforms.ToTensor()(frame)
tensor = self.normalize(tensor)
input_tensors.append(tensor)
target_tensor = transforms.ToTensor()(target_frame)
target_tensor = self.normalize(target_tensor)
# Concatenate input frames along channel dimension
input_concatenated = torch.cat(input_tensors, dim=0)
# Temporal index (for contrastive loss)
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
class SyntheticVideoDataset(Dataset):
"""
Synthetic dataset for testing - generates random frames
"""
def __init__(self,
num_samples: int = 1000,
num_frames: int = 3,
frame_size: int = 224,
is_train: bool = True):
self.num_samples = num_samples
self.num_frames = num_frames
self.frame_size = frame_size
self.is_train = is_train
# Normalization
self.normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# Generate random "frames" (noise with temporal correlation)
input_frames = []
prev_frame = torch.randn(3, self.frame_size, self.frame_size) * 0.1
for i in range(self.num_frames):
# Add some temporal correlation
frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
frame = torch.clamp(frame, -1, 1)
input_frames.append(self.normalize(frame))
prev_frame = frame
# Target frame (next in sequence)
target_frame = prev_frame + torch.randn(3, self.frame_size, self.frame_size) * 0.05
target_frame = torch.clamp(target_frame, -1, 1)
target_tensor = self.normalize(target_frame)
# Concatenate inputs
input_concatenated = torch.cat(input_frames, dim=0)
# Temporal index
temporal_idx = torch.tensor(self.num_frames, dtype=torch.long)
return input_concatenated, target_tensor, temporal_idx
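A quick sanity-check sketch with the synthetic dataset (no video files needed); note that both dataset classes emit stacked RGB frames, i.e. `3 * num_frames` channels per sample:
```python
# Sanity-check sketch for SyntheticVideoDataset batching.
import torch
from util.video_dataset import SyntheticVideoDataset

dataset = SyntheticVideoDataset(num_samples=8, num_frames=3, frame_size=224)
loader = torch.utils.data.DataLoader(dataset, batch_size=4)
inputs, targets, t_idx = next(iter(loader))
print(inputs.shape)   # torch.Size([4, 9, 224, 224]) -- 3 frames x 3 channels
print(targets.shape)  # torch.Size([4, 3, 224, 224])
print(t_idx)          # tensor([3, 3, 3, 3])
```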