facebookresearch/dino

PyTorch code for training Vision Transformers with the self-supervised learning method DINO


Top Related Projects

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more

Quick Overview

DINO (self-DIstillation with NO labels) is a self-supervised learning method for visual representations. It pairs Vision Transformers (ViT) with a self-distillation approach to learn high-quality image features without any labeled data. DINO performs strongly on a variety of downstream tasks, and its features exhibit properties, such as explicit object-segmentation information in the self-attention maps, that do not emerge as clearly in supervised ViTs.
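
To make the self-distillation idea concrete, here is a minimal, simplified sketch of the core update: the student is trained to match a centered, sharpened version of the teacher's output distribution, and the teacher's weights are an exponential moving average (EMA) of the student's. The function names, temperature values, and momentum below are illustrative defaults, not the repository's exact API.

import torch
import torch.nn.functional as F

def dino_loss(student_out, teacher_out, center, student_temp=0.1, teacher_temp=0.04):
    # Teacher targets: centered and sharpened; no gradient flows through them.
    targets = F.softmax((teacher_out - center) / teacher_temp, dim=-1).detach()
    log_probs = F.log_softmax(student_out / student_temp, dim=-1)
    return -(targets * log_probs).sum(dim=-1).mean()

@torch.no_grad()
def ema_update(student, teacher, momentum=0.996):
    # Teacher weights track an exponential moving average of the student weights.
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.data.mul_(momentum).add_(p_s.data, alpha=1 - momentum)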

Pros

  • Achieves state-of-the-art performance on self-supervised learning tasks
  • Learns features that are highly transferable to various downstream tasks
  • Requires no labeled data for training, reducing annotation costs
  • Demonstrates emergent properties like segmentation and object discovery

Cons

  • Computationally intensive, requiring significant resources for training
  • May not perform as well on small-scale datasets or specific domain tasks
  • Requires careful hyperparameter tuning for optimal performance
  • Limited interpretability of the learned features compared to some supervised approaches

Code Examples

  1. Loading a pre-trained DINO model (here via PyTorch Hub, which downloads the official pretrained backbones):
import torch
import torch.nn as nn

def get_dino_model(model_name='dino_vits16'):
    # Other available entries include 'dino_vits8', 'dino_vitb16', 'dino_vitb8' and 'dino_resnet50'.
    return torch.hub.load('facebookresearch/dino:main', model_name)

dino_model = get_dino_model()
dino_model.eval()
  2. Extracting features from an image:
from torchvision import transforms
from PIL import Image

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

image = Image.open('path/to/image.jpg').convert('RGB')  # ensure 3 channels
input_tensor = transform(image).unsqueeze(0)

with torch.no_grad():
    features = dino_model(input_tensor)  # global image embedding (the ViT [CLS] token)
  3. Fine-tuning DINO for a classification task (attaching a linear head on top of the backbone features):
num_classes = 10
classifier = nn.Linear(dino_model.embed_dim, num_classes)

optimizer = torch.optim.Adam(
    list(dino_model.parameters()) + list(classifier.parameters()), lr=1e-4
)
criterion = nn.CrossEntropyLoss()

dino_model.train()
# `num_epochs` and `dataloader` are assumed to be defined by the user.
for epoch in range(num_epochs):
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = classifier(dino_model(inputs))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

Getting Started

To get started with DINO, follow these steps:

  1. Install the required dependencies:
pip install torch torchvision
  2. Clone the DINO repository:
git clone https://github.com/facebookresearch/dino.git
cd dino
  3. Load a pre-trained DINO model and use it for feature extraction or fine-tuning (using the repository's own vision_transformer module, run from inside the cloned dino directory):
import torch
import vision_transformer as vits

model = vits.__dict__['vit_small'](patch_size=16, num_classes=0)
url = "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu')
model.load_state_dict(state_dict, strict=True)
model.eval()

# Use the model for feature extraction or fine-tuning
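
As a quick, illustrative smoke test (the random tensor merely stands in for a properly preprocessed image batch):

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = model(x)
print(features.shape)  # expected: torch.Size([1, 384]) for ViT-S/16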

Competitor Comparisons

Pros of vision_transformer

  • Implements the original Vision Transformer (ViT) architecture
  • Provides pre-trained models for various ViT variants
  • Includes extensive documentation and usage examples

Cons of vision_transformer

  • Limited to ViT architecture, less flexible than DINO
  • Fewer self-supervised learning options compared to DINO
  • Less focus on multi-crop augmentation strategies

Code Comparison

vision_transformer:

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
                 representation_size=None, distilled=False, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        super().__init__()
        # ... (implementation details)

DINO:

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            # ... (implementation details)

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

Pros of transformers

  • Broader scope: Covers a wide range of NLP tasks and models
  • Extensive documentation and community support
  • Regular updates and new model implementations

Cons of transformers

  • Larger codebase, potentially more complex to navigate
  • May have higher computational requirements for some models

Code comparison

DINO (self-supervised vision transformer):

import torch
from dino import utils, vision_transformer as vits

model = vits.__dict__["vit_small"](patch_size=16, num_classes=0)
utils.load_pretrained_weights(model, "path/to/checkpoint.pth", "teacher")

transformers (BERT for sequence classification):

from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image

Pros of CLIP

  • Multimodal learning: CLIP can understand both images and text, enabling versatile applications
  • Zero-shot learning capabilities: Can classify images into arbitrary categories without fine-tuning
  • Robust performance across diverse datasets and tasks

Cons of CLIP

  • Computationally intensive: Requires significant resources for training and inference
  • Limited interpretability: Complex model architecture makes it challenging to understand decision-making process
  • Potential biases: May inherit biases present in the large-scale training data

Code Comparison

CLIP example:

import torch
from PIL import Image
import clip

model, preprocess = clip.load("ViT-B/32", device="cuda")
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
text = clip.tokenize(["a dog", "a cat"]).to("cuda")

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

DINO example:

import torch
import torchvision.transforms as transforms
from PIL import Image
from dino import utils, vision_transformer as vits

model = vits.__dict__["vit_small"](patch_size=16, num_classes=0)
utils.load_pretrained_weights(model, "dino_deitsmall16_pretrain.pth", "teacher")
model.eval()

transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()])
image = transform(Image.open("image.jpg")).unsqueeze(0)

with torch.no_grad():
    features = model(image)

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".

Pros of Swin-Transformer

  • Hierarchical structure allows for better handling of varying scales in images
  • More efficient for high-resolution images due to local attention mechanism
  • Demonstrates superior performance on various vision tasks, including object detection and semantic segmentation

Cons of Swin-Transformer

  • More complex architecture, potentially harder to implement and fine-tune
  • May require more computational resources for training and inference
  • Less effective for self-supervised learning tasks compared to DINO

Code Comparison

DINO (PyTorch):

class DINOHead(nn.Module):
    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    layers.append(nn.BatchNorm1d(hidden_dim))
                layers.append(nn.GELU())
            layers.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x

Swin-Transformer (PyTorch):

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        # ... (implementation details)

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Pros of vit-pytorch

  • Lightweight and focused implementation of Vision Transformer (ViT)
  • Easy to understand and modify for experimentation
  • Includes additional variants and improvements on the original ViT

Cons of vit-pytorch

  • Less comprehensive feature set compared to DINO
  • May require more manual setup for training and evaluation
  • Limited to ViT architecture, while DINO offers more flexibility

Code Comparison

vit-pytorch:

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

DINO:

model = vits.__dict__['vit_small'](patch_size=16, num_classes=0)
model = utils.MultiCropWrapper(
    model,
    DINOHead(embed_dim, args.out_dim, use_bn=args.use_bn_in_head),
)

The vit-pytorch example shows a more straightforward instantiation of a ViT model, while the DINO example demonstrates its integration with additional components for self-supervised learning.

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more

Pros of pytorch-image-models

  • Extensive collection of pre-trained models and architectures
  • Regular updates and active community support
  • Comprehensive documentation and examples

Cons of pytorch-image-models

  • Less focused on self-supervised learning techniques
  • May require more setup and configuration for specific tasks

Code Comparison

DINO:

model = vit_small(patch_size=16, num_classes=0)
model = utils.MultiCropWrapper(
    model,
    DINOHead(embed_dim, args.out_dim, use_bn=args.use_bn_in_head),
)

pytorch-image-models:

model = timm.create_model('vit_small_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, num_classes)

Summary

DINO focuses on self-supervised learning for vision transformers, while pytorch-image-models offers a broader range of pre-trained models and architectures. DINO provides a more specialized approach to self-supervised learning, whereas pytorch-image-models offers greater flexibility and a wider variety of models for different computer vision tasks. The code comparison shows that DINO requires more setup for its specific self-supervised learning approach, while pytorch-image-models allows for easier model creation and fine-tuning.

README

:new: Please check out our more recent DINOv2 effort in the same line of work.

Self-Supervised Vision Transformers with DINO

PyTorch implementation and pretrained models for DINO. For details, see Emerging Properties in Self-Supervised Vision Transformers.
[blogpost] [arXiv] [Yannic Kilcher's video]

DINO illustration

Pretrained models

You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint, which contains the backbone and projection head weights for both the student and teacher networks. We also provide the backbones in ONNX format, as well as detailed arguments and training/evaluation logs. Note that the names DeiT-S and ViT-S refer to exactly the same architecture.

arch params k-nn linear download
ViT-S/16 21M 74.5% 77.0% backbone only full ckpt onnx args logs eval logs
ViT-S/8 21M 78.3% 79.7% backbone only full ckpt onnx args logs eval logs
ViT-B/16 85M 76.1% 78.2% backbone only full ckpt onnx args logs eval logs
ViT-B/8 85M 77.4% 80.1% backbone only full ckpt onnx args logs eval logs
ResNet-50 23M 67.5% 75.3% backbone only full ckpt onnx args logs eval logs

We also release XCiT models ([arXiv] [code]) trained with DINO:

arch params k-nn linear download
xcit_small_12_p16 26M 76.0% 77.8% backbone only full ckpt args logs eval
xcit_small_12_p8 26M 77.1% 79.2% backbone only full ckpt args logs eval
xcit_medium_24_p16 84M 76.4% 78.8% backbone only full ckpt args logs eval
xcit_medium_24_p8 84M 77.9% 80.3% backbone only full ckpt args logs eval
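
If you download one of the full checkpoints above rather than a backbone-only file, the teacher backbone weights can be extracted roughly as follows. Treat this as a sketch: the 'teacher' key and the 'module.'/'backbone.' prefixes reflect the training-time wrappers used in this codebase, and the filename is a placeholder.

import torch

ckpt = torch.load('full_checkpoint.pth', map_location='cpu')
state_dict = ckpt['teacher']
# Strip the prefixes added by DistributedDataParallel and MultiCropWrapper.
state_dict = {k.replace('module.', '').replace('backbone.', ''): v for k, v in state_dict.items()}
# The cleaned state_dict can then be loaded into the matching backbone
# (e.g. vits.__dict__['vit_small'](patch_size=16, num_classes=0)) with strict=False,
# since the checkpoint also contains projection-head weights.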

Pretrained models on PyTorch Hub

import torch
vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
xcit_small_12_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p16')
xcit_small_12_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_small_12_p8')
xcit_medium_24_p16 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p16')
xcit_medium_24_p8 = torch.hub.load('facebookresearch/dino:main', 'dino_xcit_medium_24_p8')
resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
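
Each hub entry above returns a backbone that maps a normalized image batch to one feature vector per image. A quick, illustrative check of the embedding sizes (the random tensor stands in for a preprocessed image; the shapes in the comments are the expected ones):

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(vits16(x).shape)    # torch.Size([1, 384])  -- ViT-S embedding
    print(vitb16(x).shape)    # torch.Size([1, 768])  -- ViT-B embedding
    print(resnet50(x).shape)  # torch.Size([1, 2048]) -- ResNet-50 embedding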

Training

Documentation

Please install PyTorch and download the ImageNet dataset. This codebase has been developed with Python 3.6, PyTorch 1.7.1, CUDA 11.0 and torchvision 0.8.2. The exact arguments used to reproduce the models presented in our paper can be found in the args column of the pretrained models section. For a glimpse at the full documentation of DINO training, please run:

python main_dino.py --help

Vanilla DINO training :sauropod:

Run DINO with the ViT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 days and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide training and linear evaluation logs (with batch size 256 at evaluation time) for this run to help reproducibility.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Multi-node training

We use Slurm and submitit (pip install submitit). To train on 2 nodes with 8 GPUs each (total 16 GPUs):

python run_with_submitit.py --nodes 2 --ngpus 8 --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

To train DINO with a ViT-base network instead:
python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base  --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Boosting DINO performance :t-rex:

You can improve the performance of the vanilla run by:

  • training for more epochs: --epochs 300,
  • increasing the teacher temperature: --teacher_temp 0.07 --warmup_teacher_temp_epochs 30,
  • removing last layer normalization (only safe with --arch vit_small): --norm_last_layer false.

Full command:
python run_with_submitit.py --arch vit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training time is 2.6 days with 16 GPUs. We provide training and linear evaluation logs (with batch size 256 at evaluation time) for this run to help reproducibility.

ResNet-50 and other convnets trainings

This code also works for training DINO on convolutional networks, such as ResNet-50. In that case, we highly recommend adapting some of the optimization arguments. For example, the following command trains DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide training logs and the final checkpoint for this run.

python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch resnet50 --optimizer sgd --lr 0.03 --weight_decay 1e-4 --weight_decay_end 1e-4 --global_crops_scale 0.14 1 --local_crops_scale 0.05 0.14 --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir

Self-attention visualization

You can look at the self-attention of the [CLS] token on the different heads of the last layer by running:

python visualize_attention.py
Self-attention from a Vision Transformer with 8x8 patches trained with DINO
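
If you prefer to access the attention maps programmatically rather than through the script, the DINO Vision Transformer exposes a get_last_selfattention method (the same one visualize_attention.py relies on). A minimal sketch, assuming a normalized input whose side lengths are divisible by the patch size:

import torch

model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
model.eval()

img = torch.randn(1, 3, 480, 480)  # stand-in for a normalized image
with torch.no_grad():
    attn = model.get_last_selfattention(img)  # (1, num_heads, num_tokens, num_tokens)
# Attention of the [CLS] token over the patch tokens, one map per head:
cls_attn = attn[0, :, 0, 1:]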

Self-attention video generation

You can generate videos like the one on the blog post with video_generation.py.

https://user-images.githubusercontent.com/46140458/116817761-47885e80-ab68-11eb-9975-d61d5a919e13.mp4

Extract frames from input video and generate attention video:

python video_generation.py  --pretrained_weights dino_deitsmall8_pretrain.pth \
    --input_path input/video.mp4 \
    --output_path output/ \
    --fps 25

Use folder of frames already extracted and generate attention video:

python video_generation.py  --pretrained_weights dino_deitsmall8_pretrain.pth \
    --input_path output/frames/ \
    --output_path output/ \
    --resize 256

Only generate video from folder of attention maps images:

python video_generation.py --input_path output/attention \
    --output_path output/ \
    --video_only \
    --video_format avi

Evaluation: k-NN classification on ImageNet

To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --data_path /path/to/imagenet

If you choose not to specify --pretrained_weights, then DINO reference weights are used by default. If you want instead to evaluate checkpoints from a run of your own, you can run for example:

python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet 
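
Conceptually, eval_knn.py extracts frozen features for the train and validation sets and classifies each validation image by a similarity-weighted vote among its nearest training neighbors. The following is a simplified sketch of that idea, not the script's exact implementation; the feature tensors and integer labels are assumed to be precomputed.

import torch
import torch.nn.functional as F

def knn_classify(train_feats, train_labels, test_feats, k=20, T=0.07, num_classes=1000):
    # Cosine similarity on L2-normalized features, with temperature-weighted voting.
    train_feats = F.normalize(train_feats, dim=1)
    test_feats = F.normalize(test_feats, dim=1)
    sim = test_feats @ train_feats.t()              # (num_test, num_train)
    topk_sim, topk_idx = sim.topk(k, dim=1)
    topk_labels = train_labels[topk_idx]            # (num_test, k)
    weights = (topk_sim / T).exp()
    votes = torch.zeros(test_feats.size(0), num_classes)
    votes.scatter_add_(1, topk_labels, weights)
    return votes.argmax(dim=1)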

Evaluation: Linear classification on ImageNet

To train a supervised linear classifier on frozen weights on a single node with 8 GPUs, run:

python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py --data_path /path/to/imagenet
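
For intuition, linear evaluation freezes the pretrained backbone and trains only a linear classifier on top of its features. The sketch below illustrates the idea with a ViT-S/16 backbone loaded from PyTorch Hub; it is not eval_linear.py itself (which adds options such as concatenating the last blocks), and `train_loader` is assumed to be defined.

import torch
import torch.nn as nn

backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
backbone.eval().requires_grad_(False)   # frozen backbone

linear = nn.Linear(384, 1000)           # 384 = ViT-S embedding dim, 1000 ImageNet classes
optimizer = torch.optim.SGD(linear.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

for images, labels in train_loader:
    with torch.no_grad():
        feats = backbone(images)
    loss = criterion(linear(feats), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()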

We release the logs and weights from evaluating the different models:

arch top-1 ImageNet linear evaluation
ViT-S/16 77.0% linear weights logs
ViT-S/8 79.7% linear weights logs
ViT-B/16 78.2% linear weights logs
ViT-B/8 80.1% linear weights logs
xcit_small_12_p16 77.8% linear weights logs
xcit_small_12_p8 79.2% linear weights logs
xcit_medium_24_p16 78.8% linear weights logs
xcit_medium_24_p8 80.3% linear weights logs
ResNet-50 75.3% linear weights logs

You can check the performance of the pretrained weights on the ImageNet validation set by running the following command lines:

python eval_linear.py --evaluate --arch vit_small --patch_size 16 --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch vit_small --patch_size 8 --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch vit_base --patch_size 16 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch vit_base --patch_size 8 --n_last_blocks 1 --avgpool_patchtokens true --data_path /path/to/imagenet/train
python eval_linear.py --evaluate --arch resnet50 --data_path /path/to/imagenet/train

Evaluation: DAVIS 2017 Video object segmentation

Please verify that you're using PyTorch version 1.7.1, since we are not able to reproduce the results with the more recent PyTorch 1.8.1 at the moment.

Step 1: Prepare DAVIS 2017 data

cd $HOME
git clone https://github.com/davisvideochallenge/davis-2017 && cd davis-2017
./data/get_davis.sh

Step 2: Video object segmentation

python eval_video_segmentation.py --data_path $HOME/davis-2017/DAVIS/ --output_dir /path/to/saving_dir

Step 3: Evaluate the obtained segmentation

git clone https://github.com/davisvideochallenge/davis2017-evaluation $HOME/davis2017-evaluation
python $HOME/davis2017-evaluation/evaluation_method.py --task semi-supervised --results_path /path/to/saving_dir --davis_path $HOME/davis-2017/DAVIS/

Evaluation: Image Retrieval on revisited Oxford and Paris

Step 1: Prepare revisited Oxford and Paris by following this repo.

Step 2: Image retrieval (if you do not specify weights with --pretrained_weights, DINO weights pretrained on the Google Landmarks v2 dataset are used by default).

Paris:

python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 512 --multiscale 1 --data_path /path/to/revisited_paris_oxford/ --dataset rparis6k

Oxford:

python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 0 --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k

Evaluation: Copy detection on Copydays

Step 1: Prepare Copydays dataset.

Step 2 (optional): Prepare a set of image distractors and a set of images on which to learn the whitening operator. In our paper, we use 10k random images from YFCC100M as distractors and another 20k random images from YFCC100M (disjoint from the distractors) to compute the whitening operation.

Step 3: Run copy detection:

python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_copy_detection.py --data_path /path/to/copydays/ --whitening_path /path/to/whitening_data/ --distractors_path /path/to/distractors/

We report results on the strong subset. For example, the stdout from the command above should contain a line like: eval on strong mAP=0.858.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Citation

If you find this repository useful, please consider giving a star :star: and citation :t-rex::

@inproceedings{caron2021emerging,
  title={Emerging Properties in Self-Supervised Vision Transformers},
  author={Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e  and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  booktitle={Proceedings of the International Conference on Computer Vision (ICCV)},
  year={2021}
}