Initial release of SwiftFormer

This commit is contained in:
amshaker
2023-03-26 23:31:59 +04:00
commit 574907f49b
19 changed files with 1751 additions and 0 deletions

6
util/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
import util.utils as utils
from .datasets import build_dataset
from .engine import train_one_epoch, evaluate
from .losses import DistillationLoss
from .samplers import RASampler

120
util/datasets.py Normal file
View File

@@ -0,0 +1,120 @@
import os
import json
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader
import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
class INatDataset(ImageFolder):
def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, category='name',
loader=default_loader):
super().__init__(root, transform, target_transform, loader)
self.transform = transform
self.loader = loader
self.target_transform = target_transform
self.year = year
# assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
path_json = os.path.join(
root, f'{"train" if train else "val"}{year}.json')
with open(path_json) as json_file:
data = json.load(json_file)
with open(os.path.join(root, 'categories.json')) as json_file:
data_catg = json.load(json_file)
path_json_for_targeter = os.path.join(root, f"train{year}.json")
with open(path_json_for_targeter) as json_file:
data_for_targeter = json.load(json_file)
targeter = {}
indexer = 0
for elem in data_for_targeter['annotations']:
king = []
king.append(data_catg[int(elem['category_id'])][category])
if king[0] not in targeter.keys():
targeter[king[0]] = indexer
indexer += 1
self.nb_classes = len(targeter)
self.samples = []
for elem in data['images']:
cut = elem['file_name'].split('/')
target_current = int(cut[2])
path_current = os.path.join(root, cut[0], cut[2], cut[3])
categors = data_catg[target_current]
target_current_true = targeter[categors[category]]
self.samples.append((path_current, target_current_true))
# __getitem__ and __len__ inherited from ImageFolder
def build_dataset(is_train, args):
transform = build_transform(is_train, args)
if args.data_set == 'CIFAR':
dataset = datasets.CIFAR100(
args.data_path, train=is_train, transform=transform)
nb_classes = 100
elif args.data_set == 'IMNET':
root = os.path.join(args.data_path, 'train' if is_train else 'val')
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif args.data_set == 'FLOWERS':
root = os.path.join(args.data_path, 'train' if is_train else 'test')
dataset = datasets.ImageFolder(root, transform=transform)
if is_train:
dataset = torch.utils.data.ConcatDataset(
[dataset for _ in range(100)])
nb_classes = 102
elif args.data_set == 'INAT':
dataset = INatDataset(args.data_path, train=is_train, year=2018,
category=args.inat_category, transform=transform)
nb_classes = dataset.nb_classes
elif args.data_set == 'INAT19':
dataset = INatDataset(args.data_path, train=is_train, year=2019,
category=args.inat_category, transform=transform)
nb_classes = dataset.nb_classes
else:
raise NotImplementedError
return dataset, nb_classes
def build_transform(is_train, args):
resize_im = args.input_size > 32
if is_train:
# This should always dispatch to transforms_imagenet_train
transform = create_transform(
input_size=args.input_size,
is_training=True,
color_jitter=args.color_jitter,
auto_augment=args.aa,
interpolation=args.train_interpolation,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
)
if not resize_im:
# Replace RandomResizedCropAndInterpolation with RandomCrop
transform.transforms[0] = transforms.RandomCrop(
args.input_size, padding=4)
return transform
t = []
if resize_im:
size = int((256 / 224) * args.input_size)
t.append(
# to maintain same ratio w.r.t. 224 images
transforms.Resize(size, interpolation=3),
)
t.append(transforms.CenterCrop(args.input_size))
t.append(transforms.ToTensor())
t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
return transforms.Compose(t)

101
util/engine.py Normal file
View File

@@ -0,0 +1,101 @@
"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from .losses import DistillationLoss
import util.utils as utils
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
clip_grad: float = 0,
clip_mode: str = 'norm',
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
set_training_mode=True):
model.train(set_training_mode)
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(
window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 100
for samples, targets in metric_logger.log_every(
data_loader, print_freq, header):
samples = samples.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
if True: # with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(samples, outputs, targets)
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
optimizer.zero_grad()
# This attribute is added by timm on one optimizer (adahessian)
is_second_order = hasattr(
optimizer, 'is_second_order') and optimizer.is_second_order
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters(), create_graph=is_second_order)
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
criterion = torch.nn.CrossEntropyLoss()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# Switch to evaluation mode
model.eval()
for images, target in metric_logger.log_every(data_loader, 10, header):
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
# Compute output
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
batch_size = images.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# Gather the stats from all processes
metric_logger.synchronize_between_processes()
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
print(output.mean().item(), output.std().item())
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

64
util/losses.py Normal file
View File

@@ -0,0 +1,64 @@
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F
class DistillationLoss(torch.nn.Module):
"""
This module wraps a standard criterion and adds an extra knowledge distillation loss by
taking a teacher model prediction and using it as additional supervision.
"""
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
distillation_type: str, alpha: float, tau: float):
super().__init__()
self.base_criterion = base_criterion
self.teacher_model = teacher_model
assert distillation_type in ['none', 'soft', 'hard']
self.distillation_type = distillation_type
self.alpha = alpha
self.tau = tau
def forward(self, inputs, outputs, labels):
"""
Args:
inputs: The original inputs that are feed to the teacher model
outputs: the outputs of the model to be trained. It is expected to be
either a Tensor, or a Tuple[Tensor, Tensor], with the original output
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss
if outputs_kd is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
# Don't backprop throught the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)
if self.distillation_type == 'soft':
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs_kd / T, dim=1),
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs_kd.numel()
elif self.distillation_type == 'hard':
distillation_loss = F.cross_entropy(
outputs_kd, teacher_outputs.argmax(dim=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return loss

60
util/samplers.py Normal file
View File

@@ -0,0 +1,60 @@
import torch
import torch.distributed as dist
import math
class RASampler(torch.utils.data.Sampler):
"""Sampler that restricts data loading to a subset of the dataset for distributed,
with repeated augmentation.
It ensures that different each augmented version of a sample will be visible to a
different process (GPU)
Heavily based on torch.utils.data.DistributedSampler
"""
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
"Requires distributed package to be available")
rank = dist.get_rank()
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(
math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.num_selected_samples = int(math.floor(
len(self.dataset) // 256 * 256 / self.num_replicas))
self.shuffle = shuffle
def __iter__(self):
# Deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# Add extra samples to make it evenly divisible
indices = [ele for ele in indices for i in range(3)]
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# Subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices[:self.num_selected_samples])
def __len__(self):
return self.num_selected_samples
def set_epoch(self, epoch):
self.epoch = epoch

280
util/utils.py Normal file
View File

@@ -0,0 +1,280 @@
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime
import torch
import torch.distributed as dist
import subprocess
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total],
dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def _load_checkpoint_for_ema(model_ema, checkpoint):
"""
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
"""
mem_file = io.BytesIO()
torch.save(checkpoint, mem_file)
mem_file.seek(0)
model_ema._load_checkpoint(mem_file)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
args.dist_url = 'env://'
os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
print('Using distributed mode: 1')
elif 'SLURM_PROCID' in os.environ:
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
addr = subprocess.getoutput(
'scontrol show hostname {} | head -n1'.format(node_list))
os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['RANK'] = str(proc_id)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['LOCAL_SIZE'] = str(num_gpus)
args.dist_url = 'env://'
args.world_size = ntasks
args.rank = proc_id
args.gpu = proc_id % num_gpus
print('Using distributed mode: slurm')
print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']},"
f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}")
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def replace_batchnorm(net):
for child_name, child in net.named_children():
if hasattr(child, 'fuse'):
setattr(net, child_name, child.fuse())
elif isinstance(child, torch.nn.Conv2d):
child.bias = torch.nn.Parameter(torch.zeros(child.weight.size(0)))
elif isinstance(child, torch.nn.BatchNorm2d):
setattr(net, child_name, torch.nn.Identity())
else:
replace_batchnorm(child)
def replace_layernorm(net):
import apex
for child_name, child in net.named_children():
if isinstance(child, torch.nn.LayerNorm):
setattr(net, child_name, apex.normalization.FusedLayerNorm(
child.weight.size(0)))
else:
replace_layernorm(child)