Initial release of SwiftFormer
142 README.md Normal file
@@ -0,0 +1,142 @@
# SwiftFormer

### **SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications**

[Abdelrahman Shaker](https://scholar.google.com/citations?hl=en&user=eEz4Wu4AAAAJ),
[Muhammad Maaz](https://scholar.google.com/citations?user=vTy9Te8AAAAJ&hl=en&authuser=1&oi=sra),
[Hanoona Rasheed](https://scholar.google.com/citations?user=yhDdEuEAAAAJ&hl=en&authuser=1&oi=sra),
[Salman Khan](https://salman-h-khan.github.io),
[Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en),
and [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en)

<!-- [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](site_url) -->
[![paper](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](arxiv_link)
<!-- [![video](https://img.shields.io/badge/Presentation-Video-F9D371)](youtube_link) -->
<!-- [![slides](https://img.shields.io/badge/Presentation-Slides-B762C1)](presentation) -->

## :rocket: News

* **(Mar 27, 2023):** Classification training and evaluation code, along with pre-trained models, is released.

<hr />

<p align="center">
<img src="images/Swiftformer_performance.png" width=60%> <br>
Comparison of our SwiftFormer models with the state of the art on ImageNet-1K. Latency is measured on the iPhone 14 Neural Engine (iOS 16).
</p>

<p align="center">
<img src="images/attentions_comparison.png" width=99%> <br>
</p>
<p align="left">
Comparison of different self-attention modules. (a) is typical self-attention. (b) is transpose self-attention, where the attention operation is applied across the channel dimension (d×d) instead of the spatial dimension (n×n). (c) is the separable self-attention of MobileViT-v2, which uses element-wise operations to compute a context vector from the interactions of the Q and K matrices; the context vector is then multiplied by the V matrix to produce the final output. (d) is our proposed efficient additive self-attention: the query matrix is multiplied by learnable weights and pooled to produce global queries, then the K matrix is element-wise multiplied by the broadcasted global queries, resulting in the global context representation.
</p>
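
For reference, the core computation in (d) can be written in a few lines of PyTorch. This is a simplified, single-head sketch, not the released module: `EfficientAdditiveAttention` in `models/swiftformer.py` below additionally normalizes the query and key matrices, applies a linear projection to the context, and adds a residual connection.

```python
import torch
import torch.nn as nn

class AdditiveAttentionSketch(nn.Module):
    """Simplified sketch of efficient additive attention (illustrative only)."""
    def __init__(self, dim):
        super().__init__()
        self.to_query = nn.Linear(dim, dim)
        self.to_key = nn.Linear(dim, dim)
        self.w_g = nn.Parameter(torch.randn(dim, 1))  # learnable weights for pooling Q
        self.scale = dim ** -0.5

    def forward(self, x):                     # x: [B, N, D] token sequence
        q = self.to_query(x)                  # [B, N, D]
        k = self.to_key(x)                    # [B, N, D]
        a = (q @ self.w_g) * self.scale       # per-token score, [B, N, 1]
        a = a.softmax(dim=1)                  # attention weights over the N tokens
        g = (a * q).sum(dim=1, keepdim=True)  # pooled global query, [B, 1, D]
        return g * k                          # element-wise global context, [B, N, D]
```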
<details>
<summary>
<font size="+1">Abstract</font>
</summary>
Self-attention has become a de facto choice for capturing global context in various vision applications. However, its quadratic computational complexity with respect to image resolution limits its use in real-time applications, especially for deployment on resource-constrained mobile devices. Although hybrid approaches have been proposed to combine the advantages of convolutions and self-attention for a better speed-accuracy trade-off, the expensive matrix multiplication operations in self-attention remain a bottleneck. In this work, we introduce a novel efficient additive attention mechanism that effectively replaces the quadratic matrix multiplication operations with linear element-wise multiplications. Our design shows that the key-value interaction can be replaced with a linear layer without sacrificing any accuracy. Unlike previous state-of-the-art methods, our efficient formulation of self-attention enables its usage at all stages of the network. Using our proposed efficient additive attention, we build a series of models called "SwiftFormer" which achieves state-of-the-art performance in terms of both accuracy and mobile inference speed. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only ~0.8 ms latency on iPhone 14, which is more accurate and 2x faster than MobileViT-v2.
</details>

<br>

## Classification on ImageNet-1K

### Models

| Model | Top-1 accuracy | #params | GMACs | Latency | Ckpt | CoreML |
|:---------------|:----:|:---:|:--:|:--:|:--:|:--:|
| SwiftFormer-XS | 75.7% | 3.5M | 0.4G | 0.7ms | [XS](https://drive.google.com/file/d/15Ils-U96pQePXQXx2MpmaI-yAceFAr2x/view?usp=sharing) | [XS](https://drive.google.com/file/d/1tZVxtbtAZoLLoDc5qqoUGulilksomLeK/view?usp=sharing) |
| SwiftFormer-S | 78.5% | 6.1M | 1.0G | 0.8ms | [S](https://drive.google.com/file/d/1_0eWwgsejtS0bWGBQS3gwAtYjXdPRGlu/view?usp=sharing) | [S](https://drive.google.com/file/d/13EOCZmtvbMR2V6UjezSZnbBz2_-59Fva/view?usp=sharing) |
| SwiftFormer-L1 | 80.9% | 12.1M | 1.6G | 1.1ms | [L1](https://drive.google.com/file/d/1jlwrwWQ0SQzDRc5adtWIwIut5d1g9EsM/view?usp=sharing) | [L1](https://drive.google.com/file/d/1c3VUsi4q7QQ2ykXVS2d4iCRL478fWF3e/view?usp=sharing) |
| SwiftFormer-L3 | 83.0% | 26.5M | 4.0G | 1.9ms | [L3](https://drive.google.com/file/d/1ypBcjx04ShmPYRhhjBRubiVjbExUgSa7/view?usp=sharing) | [L3](https://drive.google.com/file/d/1svahgIjh7da781jHOHjX58mtzCzYXSsJ/view?usp=sharing) |
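
The checkpoints store the model weights under the `model` key (the format written by `main.py`). A minimal loading sketch, assuming the XS checkpoint has been downloaded as `SwiftFormer_XS_ckpt.pth` (the local file name is illustrative):

```python
import torch
from models import SwiftFormer_XS

model = SwiftFormer_XS()
ckpt = torch.load('SwiftFormer_XS_ckpt.pth', map_location='cpu')
model.load_state_dict(ckpt['model'])  # weights are stored under the 'model' key
model.eval()
```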

## Detection and Segmentation Qualitative Results

<p align="center">
<img src="images/detection_seg.png" width=100%> <br>
</p>
<p align="center">
<img src="images/semantic_seg.png" width=100%> <br>
</p>

## Latency Measurement

The latency reported for SwiftFormer on iPhone 14 (iOS 16) is measured with the benchmark tool from [Xcode 14](https://developer.apple.com/videos/play/wwdc2022/10027/).
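
This commit does not include a CoreML export script, but the `CoreMLConversion` flag in `models/swiftformer.py` hints at the path taken: it bypasses `torch.nn.functional.normalize`, which is not supported by the ANE. A sketch of one plausible export route via `coremltools` (the conversion arguments and output name are assumptions, not part of this repo):

```python
import coremltools as ct  # assumption: coremltools is installed separately
import torch

import models.swiftformer as swiftformer
from models import SwiftFormer_XS

swiftformer.CoreMLConversion = True       # skip F.normalize, unsupported on the ANE
model = SwiftFormer_XS().eval()

example = torch.rand(1, 3, 224, 224)
traced = torch.jit.trace(model, example)  # in eval mode the two heads are averaged into one output
mlmodel = ct.convert(traced, inputs=[ct.TensorType(shape=example.shape)])
mlmodel.save('SwiftFormer_XS.mlmodel')
```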

## ImageNet

### Prerequisites

A `conda` virtual environment is recommended:

```shell
conda create --name=swiftformer python=3.9
conda activate swiftformer

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install timm
```

### Data preparation

Download and extract the ImageNet train and val images from http://image-net.org. The training and validation data are expected to be in the `train` and `val` folders, respectively:

```
|-- /path/to/imagenet/
    |-- train
    |-- val
```

### Single machine multi-GPU training

We provide a training script for all models in `dist_train.sh` using PyTorch distributed data parallel (DDP).

To train SwiftFormer models on an 8-GPU machine:

```
sh dist_train.sh /path/to/imagenet 8
```

Note: specify in the script which model command you want to run. To reproduce the results of the paper, use a 16-GPU machine with a batch size of 128, or an 8-GPU machine with a batch size of 256 (an effective batch size of 2048 in both cases). AutoAugment, CutMix, and MixUp are disabled for SwiftFormer-XS only.

### Multi-node training

On a Slurm-managed cluster, multi-node training can be launched as:

```
sbatch slurm_train.sh /path/to/imagenet SwiftFormer_XS
```

Note: specify the Slurm-specific parameters in the `slurm_train.sh` script.

### Testing

We provide an example test script `dist_test.sh` using PyTorch distributed data parallel (DDP).

For example, to test SwiftFormer-XS on an 8-GPU machine:

```
sh dist_test.sh /path/to/imagenet SwiftFormer_XS weights/SwiftFormer_XS_ckpt.pth 8
```

## Citation

If you use our work, please consider citing us:

```BibTeX
@article{Shaker2023SwiftFormer,
  title={SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications},
  author={Shaker, Abdelrahman and Maaz, Muhammad and Rasheed, Hanoona and Khan, Salman and Yang, Ming-Hsuan and Khan, Fahad Shahbaz},
  journal={arXiv preprint arXiv:X.X},
  year={2023}
}
```

## Contact

If you have any questions, please create an issue on this repository or contact us at abdelrahman.youssief@mbzuai.ac.ae.

## Acknowledgement

Our codebase is based on the [LeViT](https://github.com/facebookresearch/LeViT) and [EfficientFormer](https://github.com/snap-research/EfficientFormer) repositories. We thank the authors for their open-source implementations.

## Our Related Works

- EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications, CADL'22, ECCV. [Paper](https://arxiv.org/abs/2206.10589) | [Code](https://github.com/mmaaz60/EdgeNeXt)
11 dist_test.sh Normal file
@@ -0,0 +1,11 @@
#!/usr/bin/env bash
# Usage: sh dist_test.sh /path/to/imagenet <model> <checkpoint> <num_gpus>

IMAGENET_PATH=$1
MODEL=$2
CHECKPOINT=$3
nGPUs=$4

python -m torch.distributed.launch --master_addr="127.0.0.1" --master_port=1234 --nproc_per_node=$nGPUs --use_env main.py --model "$MODEL" \
    --resume "$CHECKPOINT" --eval \
    --data-path "$IMAGENET_PATH" \
    --output_dir SwiftFormer_test
21 dist_train.sh Normal file
@@ -0,0 +1,21 @@
#!/usr/bin/env bash
# Usage: sh dist_train.sh /path/to/imagenet <num_gpus>
# Comment out the model commands you do not want to run (see the README).

IMAGENET_PATH=$1
nGPUs=$2

## SwiftFormer-XS (AutoAugment, MixUp, and CutMix are disabled for XS only)
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_XS --aa="" --mixup 0 --cutmix 0 --data-path "$IMAGENET_PATH" \
    --output_dir SwiftFormer_XS_results

## SwiftFormer-S
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_S --data-path "$IMAGENET_PATH" \
    --output_dir SwiftFormer_S_results

## SwiftFormer-L1
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_L1 --data-path "$IMAGENET_PATH" \
    --output_dir SwiftFormer_L1_results

## SwiftFormer-L3
python -m torch.distributed.launch --nproc_per_node=$nGPUs --use_env main.py --model SwiftFormer_L3 --data-path "$IMAGENET_PATH" \
    --output_dir SwiftFormer_L3_results
BIN images/Swiftformer_performance.png Normal file (binary file not shown; 477 KiB)
BIN images/attention_comparison.pdf Normal file (binary file not shown)
BIN images/attentions_comparison.png Normal file (binary file not shown; 669 KiB)
BIN images/detection_seg.png Normal file (binary file not shown; 3.7 MiB)
BIN images/semantic_seg.png Normal file (binary file not shown; 3.9 MiB)
412 main.py Normal file
@@ -0,0 +1,412 @@
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
from pathlib import Path

from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma

from util import *
from models import *


def get_args_parser():
    parser = argparse.ArgumentParser(
        'SwiftFormer training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=128, type=int)
    parser.add_argument('--epochs', default=300, type=int)

    # Model parameters
    parser.add_argument('--model', default='SwiftFormer_XS', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224,
                        type=int, help='images input size')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument(
        '--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float,
                        default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu',
                        action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=0.01, metavar='NORM',
                        help='Clip gradient norm (default: 0.01)')
    parser.add_argument('--clip-mode', type=str, default='agc',
                        help='Gradient clipping mode. One of ("norm", "value", "agc")')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.025,
                        help='weight decay (default: 0.025)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=2e-3, metavar='LR',
                        help='learning rate (default: 2e-3)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy, e.g. "v0" or "original" '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug',
                        action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160")')
    parser.add_argument('--teacher-path', type=str,
                        default='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth')
    parser.add_argument('--distillation-type', default='hard',
                        choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha',
                        default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Finetuning params
    parser.add_argument('--finetune', default='',
                        help='finetune from checkpoint')

    # Dataset parameters
    parser.add_argument('--data-path', default='./imagenet', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='dataset to use')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order',
                                 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true',
                        help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true',
                        default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # Distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    return parser


def main(args):
    utils.init_distributed_mode(args)

    print(args)

    if args.distillation_type != 'none' and args.finetune and not args.eval:
        raise NotImplementedError(
            "Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # Fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        num_classes=args.nb_classes,
        distillation=(args.distillation_type != 'none'),
        pretrained=args.eval,
        fuse=args.eval,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        for k in ['head.weight', 'head.bias',
                  'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but
        # before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel()
                       for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # better not to scale up lr for AdamW optimizer
    # linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    # args.lr = linear_scaled_lr

    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        assert args.teacher_path, 'need to specify teacher-path when using distillation'
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.teacher_path, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')
        teacher_model.load_state_dict(checkpoint['model'])
        teacher_model.to(device)
        teacher_model.eval()

    # Wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is
    # 'none'
    criterion = DistillationLoss(
        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
    )

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema:
                utils._load_checkpoint_for_ema(
                    model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(
            f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, args.clip_mode, model_ema, mixup_fn,
            set_training_mode=args.finetune == ''  # keep in eval mode during finetuning
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint.pth']
            for checkpoint_path in checkpoint_paths:
                utils.save_on_master({
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'model_ema': get_state_dict(model_ema),
                    'scaler': loss_scaler.state_dict(),
                    'args': args,
                }, checkpoint_path)

        if epoch % 20 == 19:
            # Evaluate every 20 epochs
            test_stats = evaluate(data_loader_val, model, device)
            print(
                f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        else:
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'SwiftFormer training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
1 models/__init__.py Normal file
@@ -0,0 +1 @@
from .swiftformer import SwiftFormer_XS, SwiftFormer_S, SwiftFormer_L1, SwiftFormer_L3
507 models/swiftformer.py Normal file
@@ -0,0 +1,507 @@
"""
SwiftFormer
"""
import os
import copy
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple
import einops

SwiftFormer_width = {
    'XS': [48, 56, 112, 220],
    'S': [48, 64, 168, 224],
    'l1': [48, 96, 192, 384],
    'l3': [64, 128, 320, 512],
}

SwiftFormer_depth = {
    'XS': [3, 3, 6, 4],
    'S': [3, 3, 9, 6],
    'l1': [4, 3, 10, 5],
    'l3': [4, 4, 12, 6],
}

# Set to True when exporting to CoreML: it bypasses torch.nn.functional.normalize,
# which is not supported by the ANE of iPhone devices (see EfficientAdditiveAttention).
CoreMLConversion = False


def stem(in_chs, out_chs):
    """
    Stem layer implemented with two convolution layers.
    Output: sequence of layers with final shape of [B, C, H/4, W/4]
    """
    return nn.Sequential(
        nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_chs // 2),
        nn.ReLU(),
        nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),
        nn.BatchNorm2d(out_chs),
        nn.ReLU(), )


class Embedding(nn.Module):
    """
    Patch embedding implemented as a single convolution layer.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H/stride, W/stride]
    """

    def __init__(self, patch_size=16, stride=16, padding=0,
                 in_chans=3, embed_dim=768, norm_layer=nn.BatchNorm2d):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=stride, padding=padding)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        x = self.norm(x)
        return x


class ConvEncoder(nn.Module):
    """
    Implementation of ConvEncoder with 3x3 depth-wise and 1x1 point-wise convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(self, dim, hidden_dim=64, kernel_size=3, drop_path=0., use_layer_scale=True):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = input + self.drop_path(self.layer_scale * x)
        else:
            x = input + self.drop_path(x)
        return x


class Mlp(nn.Module):
    """
    Implementation of MLP layer with 1x1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.norm1 = nn.BatchNorm2d(in_features)
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.norm1(x)
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class EfficientAdditiveAttention(nn.Module):
    """
    Efficient Additive Attention module for SwiftFormer.
    Input: tensor of tokens in shape [B, N, D]
    Output: tensor in shape [B, N, token_dim]
    """

    def __init__(self, in_dims=512, token_dim=256, num_heads=2):
        super().__init__()

        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)

        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = self.to_query(x)
        key = self.to_key(x)

        if not CoreMLConversion:
            # torch.nn.functional.normalize is not supported by the ANE of iPhone devices.
            # Using this layer improves the accuracy by ~0.1-0.2%.
            query = torch.nn.functional.normalize(query, dim=-1)
            key = torch.nn.functional.normalize(key, dim=-1)

        query_weight = query @ self.w_g  # per-token scores, [B, N, 1]
        A = query_weight * self.scale_factor

        A = A.softmax(dim=1)  # attention weights over the N tokens

        G = torch.sum(A * query, dim=1)  # pooled global query, [B, D]

        G = einops.repeat(
            G, "b d -> b repeat d", repeat=key.shape[1]
        )  # broadcast the global query to all tokens, [B, N, D]

        out = self.Proj(G * key) + query

        out = self.final(out)

        return out


class SwiftFormerLocalRepresentation(nn.Module):
    """
    Local Representation module for SwiftFormer, implemented with 3x3 depth-wise
    and point-wise convolutions.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """

    def __init__(self, dim, kernel_size=3, drop_path=0., use_layer_scale=True):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = input + self.drop_path(self.layer_scale * x)
        else:
            x = input + self.drop_path(x)
        return x


class SwiftFormerEncoder(nn.Module):
    """
    SwiftFormer encoder block. It consists of (1) a local representation module,
    (2) EfficientAdditiveAttention, and (3) an MLP block.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """

    def __init__(self, dim, mlp_ratio=4.,
                 act_layer=nn.GELU,
                 drop=0., drop_path=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5):

        super().__init__()

        self.local_representation = SwiftFormerLocalRepresentation(dim=dim, kernel_size=3, drop_path=0.,
                                                                   use_layer_scale=True)
        self.attn = EfficientAdditiveAttention(in_dims=dim, token_dim=dim, num_heads=1)
        self.linear = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True)

    def forward(self, x):
        x = self.local_representation(x)
        B, C, H, W = x.shape
        # Flatten [B, C, H, W] into a token sequence [B, H*W, C] for the attention, then restore.
        if self.use_layer_scale:
            x = x + self.drop_path(
                self.layer_scale_1 * self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(
                    0, 3, 1, 2))
            x = x + self.drop_path(self.layer_scale_2 * self.linear(x))

        else:
            x = x + self.drop_path(
                self.attn(x.permute(0, 2, 3, 1).reshape(B, H * W, C)).reshape(B, H, W, C).permute(0, 3, 1, 2))
            x = x + self.drop_path(self.linear(x))
        return x


def Stage(dim, index, layers, mlp_ratio=4.,
          act_layer=nn.GELU,
          drop_rate=.0, drop_path_rate=0.,
          use_layer_scale=True, layer_scale_init_value=1e-5, vit_num=1):
    """
    Implementation of each SwiftFormer stage. The SwiftFormerEncoder is used as the
    last block(s) of every stage, while the ConvEncoder is used for the remaining blocks.
    Input: tensor in shape [B, C, H, W]
    Output: tensor in shape [B, C, H, W]
    """
    blocks = []

    for block_idx in range(layers[index]):
        block_dpr = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)

        if layers[index] - block_idx <= vit_num:
            blocks.append(SwiftFormerEncoder(
                dim, mlp_ratio=mlp_ratio,
                act_layer=act_layer, drop_path=block_dpr,
                use_layer_scale=use_layer_scale,
                layer_scale_init_value=layer_scale_init_value))

        else:
            blocks.append(ConvEncoder(dim=dim, hidden_dim=int(mlp_ratio * dim), kernel_size=3))

    blocks = nn.Sequential(*blocks)
    return blocks


class SwiftFormer(nn.Module):

    def __init__(self, layers, embed_dims=None,
                 mlp_ratios=4, downsamples=None,
                 act_layer=nn.GELU,
                 num_classes=1000,
                 down_patch_size=3, down_stride=2, down_pad=1,
                 drop_rate=0., drop_path_rate=0.,
                 use_layer_scale=True, layer_scale_init_value=1e-5,
                 fork_feat=False,
                 init_cfg=None,
                 pretrained=None,
                 vit_num=1,
                 distillation=True,
                 **kwargs):
        super().__init__()

        if not fork_feat:
            self.num_classes = num_classes
        self.fork_feat = fork_feat

        self.patch_embed = stem(3, embed_dims[0])

        network = []
        for i in range(len(layers)):
            stage = Stage(embed_dims[i], i, layers, mlp_ratio=mlp_ratios,
                          act_layer=act_layer,
                          drop_rate=drop_rate,
                          drop_path_rate=drop_path_rate,
                          use_layer_scale=use_layer_scale,
                          layer_scale_init_value=layer_scale_init_value,
                          vit_num=vit_num)
            network.append(stage)
            if i >= len(layers) - 1:
                break
            if downsamples[i] or embed_dims[i] != embed_dims[i + 1]:
                # Downsampling between two stages
                network.append(
                    Embedding(
                        patch_size=down_patch_size, stride=down_stride,
                        padding=down_pad,
                        in_chans=embed_dims[i], embed_dim=embed_dims[i + 1]
                    )
                )

        self.network = nn.ModuleList(network)

        if self.fork_feat:
            # Add a norm layer for each output
            self.out_indices = [0, 2, 4, 6]
            for i_emb, i_layer in enumerate(self.out_indices):
                if i_emb == 0 and os.environ.get('FORK_LAST3', None):
                    layer = nn.Identity()
                else:
                    layer = nn.BatchNorm2d(embed_dims[i_emb])
                layer_name = f'norm{i_layer}'
                self.add_module(layer_name, layer)
        else:
            # Classifier head
            self.norm = nn.BatchNorm2d(embed_dims[-1])
            self.head = nn.Linear(
                embed_dims[-1], num_classes) if num_classes > 0 \
                else nn.Identity()
            self.dist = distillation
            if self.dist:
                self.dist_head = nn.Linear(
                    embed_dims[-1], num_classes) if num_classes > 0 \
                    else nn.Identity()

        # self.apply(self.cls_init_weights)
        self.apply(self._init_weights)

        self.init_cfg = copy.deepcopy(init_cfg)
        # Load pre-trained model
        if self.fork_feat and (
                self.init_cfg is not None or pretrained is not None):
            self.init_weights()

    # Init for mmdetection or mmsegmentation by loading ImageNet pre-trained weights.
    # Note: `get_root_logger` and `_load_checkpoint` are expected to be provided by the
    # mmcv/mmsegmentation toolchain when this model is used as a dense-prediction backbone.
    def init_weights(self, pretrained=None):
        logger = get_root_logger()
        if self.init_cfg is None and pretrained is None:
            logger.warn(f'No pre-trained weights for '
                        f'{self.__class__.__name__}, '
                        f'training starts from scratch')
            pass
        else:
            assert 'checkpoint' in self.init_cfg, f'Only support ' \
                                                  f'specify `Pretrained` in ' \
                                                  f'`init_cfg` in ' \
                                                  f'{self.__class__.__name__} '
            if self.init_cfg is not None:
                ckpt_path = self.init_cfg['checkpoint']
            elif pretrained is not None:
                ckpt_path = pretrained

            ckpt = _load_checkpoint(
                ckpt_path, logger=logger, map_location='cpu')
            if 'state_dict' in ckpt:
                _state_dict = ckpt['state_dict']
            elif 'model' in ckpt:
                _state_dict = ckpt['model']
            else:
                _state_dict = ckpt

            state_dict = _state_dict
            missing_keys, unexpected_keys = \
                self.load_state_dict(state_dict, False)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_tokens(self, x):
        outs = []
        for idx, block in enumerate(self.network):
            x = block(x)
            if self.fork_feat and idx in self.out_indices:
                norm_layer = getattr(self, f'norm{idx}')
                x_out = norm_layer(x)
                outs.append(x_out)
        if self.fork_feat:
            return outs
        return x

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.forward_tokens(x)
        if self.fork_feat:
            # Output features of four stages for dense prediction
            return x

        x = self.norm(x)
        if self.dist:
            cls_out = self.head(x.flatten(2).mean(-1)), self.dist_head(x.flatten(2).mean(-1))
            if not self.training:
                cls_out = (cls_out[0] + cls_out[1]) / 2
        else:
            cls_out = self.head(x.flatten(2).mean(-1))  # global average pool over spatial dims
        # For image classification
        return cls_out


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'classifier': 'head',
        **kwargs
    }


@register_model
def SwiftFormer_XS(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['XS'],
        embed_dims=SwiftFormer_width['XS'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_S(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['S'],
        embed_dims=SwiftFormer_width['S'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_L1(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['l1'],
        embed_dims=SwiftFormer_width['l1'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model


@register_model
def SwiftFormer_L3(pretrained=False, **kwargs):
    model = SwiftFormer(
        layers=SwiftFormer_depth['l3'],
        embed_dims=SwiftFormer_width['l3'],
        downsamples=[True, True, True, True],
        vit_num=1,
        **kwargs)
    model.default_cfg = _cfg(crop_pct=0.9)
    return model
3 requirements.txt Normal file
@@ -0,0 +1,3 @@
torch==1.11.0+cu113
torchvision==0.12.0+cu113
timm==0.5.4
23 slurm_train.sh Normal file
@@ -0,0 +1,23 @@
#!/bin/sh
#SBATCH --job-name=swiftformer
#SBATCH --partition=your_partition
#SBATCH --time=48:00:00
#SBATCH --nodes=4
#SBATCH --ntasks=16
#SBATCH --cpus-per-task=16
#SBATCH --gres=gpu:4
#SBATCH --mem-per-cpu=8000

IMAGENET_PATH=$1
MODEL=$2

srun python main.py --model "$MODEL" \
    --data-path "$IMAGENET_PATH" \
    --batch-size 128 \
    --epochs 300 \
    --aa="" --mixup 0 --cutmix 0

## Note: Disable aa, mixup, and cutmix for SwiftFormer-XS only.
## By default, this script requests a total of 16 GPUs on 4 nodes. The batch size per GPU is set to 128,
## which sums to 128*16 = 2048 in total.
6 util/__init__.py Normal file
@@ -0,0 +1,6 @@
import util.utils as utils
from .datasets import build_dataset
from .engine import train_one_epoch, evaluate
from .losses import DistillationLoss
from .samplers import RASampler
120 util/datasets.py Normal file
@@ -0,0 +1,120 @@
import os
import json

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
import torch

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


class INatDataset(ImageFolder):
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, category='name',
                 loader=default_loader):
        super().__init__(root, transform, target_transform, loader)
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(
            root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        path_json_for_targeter = os.path.join(root, f"train{year}.json")

        with open(path_json_for_targeter) as json_file:
            data_for_targeter = json.load(json_file)

        targeter = {}
        indexer = 0
        for elem in data_for_targeter['annotations']:
            king = []
            king.append(data_catg[int(elem['category_id'])][category])
            if king[0] not in targeter.keys():
                targeter[king[0]] = indexer
                indexer += 1
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(
            args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'FLOWERS':
        root = os.path.join(args.data_path, 'train' if is_train else 'test')
        dataset = datasets.ImageFolder(root, transform=transform)
        if is_train:
            dataset = torch.utils.data.ConcatDataset(
                [dataset for _ in range(100)])
        nb_classes = 102
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    else:
        raise NotImplementedError

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # This should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # Replace RandomResizedCropAndInterpolation with RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            # to maintain same ratio w.r.t. 224 images
            transforms.Resize(size, interpolation=3),  # 3 = PIL.Image.BICUBIC
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
101 util/engine.py Normal file
@@ -0,0 +1,101 @@
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional

import torch

from timm.data import Mixup
from timm.utils import accuracy, ModelEma

from .losses import DistillationLoss
import util.utils as utils


def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler,
                    clip_grad: float = 0,
                    clip_mode: str = 'norm',
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(
        window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    for samples, targets in metric_logger.log_every(
            data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if True:  # AMP is disabled for training; replace with `with torch.cuda.amp.autocast():` to enable it
            outputs = model(samples)
            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # This attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
                    parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # Gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    # Switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # Compute output
        with torch.cuda.amp.autocast():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # Gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
    print(output.mean().item(), output.std().item())  # Debug: statistics of the last batch's logits

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
64 util/losses.py Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Implements the knowledge distillation loss
|
||||
"""
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class DistillationLoss(torch.nn.Module):
|
||||
"""
|
||||
This module wraps a standard criterion and adds an extra knowledge distillation loss by
|
||||
taking a teacher model prediction and using it as additional supervision.
|
||||
"""
|
||||
|
||||
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
|
||||
distillation_type: str, alpha: float, tau: float):
|
||||
super().__init__()
|
||||
self.base_criterion = base_criterion
|
||||
self.teacher_model = teacher_model
|
||||
assert distillation_type in ['none', 'soft', 'hard']
|
||||
self.distillation_type = distillation_type
|
||||
self.alpha = alpha
|
||||
self.tau = tau
|
||||
|
||||
def forward(self, inputs, outputs, labels):
|
||||
"""
|
||||
Args:
|
||||
inputs: The original inputs that are feed to the teacher model
|
||||
outputs: the outputs of the model to be trained. It is expected to be
|
||||
either a Tensor, or a Tuple[Tensor, Tensor], with the original output
|
||||
in the first position and the distillation predictions as the second output
|
||||
labels: the labels for the base criterion
|
||||
"""
|
||||
outputs_kd = None
|
||||
if not isinstance(outputs, torch.Tensor):
|
||||
# assume that the model outputs a tuple of [outputs, outputs_kd]
|
||||
outputs, outputs_kd = outputs
|
||||
base_loss = self.base_criterion(outputs, labels)
|
||||
if self.distillation_type == 'none':
|
||||
return base_loss
|
||||
|
||||
if outputs_kd is None:
|
||||
raise ValueError("When knowledge distillation is enabled, the model is "
|
||||
"expected to return a Tuple[Tensor, Tensor] with the output of the "
|
||||
"class_token and the dist_token")
|
||||
# Don't backprop throught the teacher
|
||||
with torch.no_grad():
|
||||
teacher_outputs = self.teacher_model(inputs)
|
||||
|
||||
if self.distillation_type == 'soft':
|
||||
T = self.tau
|
||||
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
|
||||
# with slight modifications
|
||||
distillation_loss = F.kl_div(
|
||||
F.log_softmax(outputs_kd / T, dim=1),
|
||||
F.log_softmax(teacher_outputs / T, dim=1),
|
||||
reduction='sum',
|
||||
log_target=True
|
||||
) * (T * T) / outputs_kd.numel()
|
||||
elif self.distillation_type == 'hard':
|
||||
distillation_loss = F.cross_entropy(
|
||||
outputs_kd, teacher_outputs.argmax(dim=1))
|
||||
|
||||
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
|
||||
return loss
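
# A minimal usage sketch (the teacher/student names are illustrative, not
# part of this commit):
#
#     base = torch.nn.CrossEntropyLoss()
#     criterion = DistillationLoss(base, teacher_model=teacher,
#                                  distillation_type='hard', alpha=0.5, tau=1.0)
#     loss = criterion(images, student(images), labels)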

60
util/samplers.py
Normal file
@@ -0,0 +1,60 @@
import torch
import torch.distributed as dist
import math


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample is visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError(
                    "Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Each sample is repeated 3 times, so every replica draws
        # ceil(3 * len(dataset) / num_replicas) indices per epoch
        self.num_samples = int(
            math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.num_selected_samples = int(math.floor(
            len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # Deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # Repeat each index 3 times (repeated augmentation), then add extra
        # samples to make the list evenly divisible across replicas
        indices = [ele for ele in indices for i in range(3)]
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # Subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
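
# A minimal usage sketch under distributed training (the dataset/loader
# names are illustrative):
#
#     sampler = RASampler(train_dataset, num_replicas=world_size, rank=rank)
#     loader = torch.utils.data.DataLoader(train_dataset, batch_size=128,
#                                          sampler=sampler, num_workers=8)
#     for epoch in range(num_epochs):
#         sampler.set_epoch(epoch)  # reshuffle deterministically per epoch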

280
util/utils.py
Normal file
@@ -0,0 +1,280 @@
"""
|
||||
Misc functions, including distributed helpers.
|
||||
|
||||
Mostly copy-paste from torchvision references.
|
||||
"""
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import subprocess
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
if not is_dist_avail_and_initialized():
|
||||
return
|
||||
t = torch.tensor([self.count, self.total],
|
||||
dtype=torch.float64, device='cuda')
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median,
|
||||
avg=self.avg,
|
||||
global_avg=self.global_avg,
|
||||
max=self.max,
|
||||
value=self.value)
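
# A minimal usage sketch (the values are illustrative):
#
#     meter = SmoothedValue(window_size=20)
#     for batch_loss in [0.9, 0.7, 0.6]:
#         meter.update(batch_loss)
#     print(str(meter))  # "0.7000 (0.7333)": windowed median (global average)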


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
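
# A minimal usage sketch (the loop body is illustrative):
#
#     logger = MetricLogger(delimiter="  ")
#     for batch in logger.log_every(range(1000), print_freq=100, header='Epoch: [0]'):
#         logger.update(loss=0.5)  # any keyword becomes a SmoothedValue meter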


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object
    """
    mem_file = io.BytesIO()
    torch.save(checkpoint, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in the master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
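
# After setup_for_distributed(False), only a print that passes force=True
# still reaches stdout on non-master ranks, e.g.:
#
#     print('rank-local debug message', force=True)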


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # Launched with torch.distributed (e.g. torchrun), which sets
        # RANK / WORLD_SIZE / LOCAL_RANK in the environment
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
        args.dist_url = 'env://'
        os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
        print('Using distributed mode: 1')
    elif 'SLURM_PROCID' in os.environ:
        # Launched under SLURM: derive the rendezvous settings from the
        # SLURM environment and the node list
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        addr = subprocess.getoutput(
            'scontrol show hostname {} | head -n1'.format(node_list))
        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
        os.environ['LOCAL_SIZE'] = str(num_gpus)
        args.dist_url = 'env://'
        args.world_size = ntasks
        args.rank = proc_id
        args.gpu = proc_id % num_gpus
        print('Using distributed mode: slurm')
        print(f"world: {os.environ['WORLD_SIZE']}, rank: {os.environ['RANK']},"
              f" local_rank: {os.environ['LOCAL_RANK']}, local_size: {os.environ['LOCAL_SIZE']}")
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
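
# A minimal launch sketch for the env:// path (the script name is
# illustrative, not part of this commit):
#
#     torchrun --nproc_per_node=8 main.py ...
#
# torchrun exports RANK, WORLD_SIZE and LOCAL_RANK, so the first branch
# above is taken and NCCL process groups are initialized per GPU.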


def replace_batchnorm(net):
    # Recursively fuse conv+bn pairs (via a module's own `fuse()` method)
    # and drop the remaining BatchNorm2d layers for faster inference
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse'):
            setattr(net, child_name, child.fuse())
        elif isinstance(child, torch.nn.Conv2d):
            child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0)))
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)


def replace_layernorm(net):
    # Swap torch.nn.LayerNorm modules for apex's FusedLayerNorm in-place
    import apex
    for child_name, child in net.named_children():
        if isinstance(child, torch.nn.LayerNorm):
            setattr(net, child_name, apex.normalization.FusedLayerNorm(
                child.weight.size(0)))
        else:
            replace_layernorm(child)
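
# A minimal usage sketch (the model name is illustrative): call after loading
# weights and switching to eval mode, since fusion assumes frozen BN statistics:
#
#     model.eval()
#     replace_batchnorm(model)  # model now runs without BatchNorm2d layers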