google-research/vision_transformer

Top Related Projects

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".

Official DeiT repository

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

PyTorch implementation of MAE: https://arxiv.org/abs/2111.06377

Quick Overview

The google-research/vision_transformer repository contains the official implementation of the Vision Transformer (ViT) model, as described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." This project demonstrates how transformer architectures, traditionally used in natural language processing, can be effectively applied to computer vision tasks.

Pros

  • Achieves state-of-the-art performance on image classification tasks when pre-trained on large datasets
  • Scales well to large datasets and model sizes
  • Provides a novel approach to image processing, potentially opening new avenues for computer vision research
  • Includes pre-trained models and code for easy experimentation and fine-tuning

Cons

  • Requires large amounts of data and computational resources for optimal performance
  • May underperform compared to traditional convolutional neural networks when trained on smaller datasets
  • Limited to image classification tasks in the current implementation
  • Complexity of the transformer architecture may make it challenging for beginners to understand and modify

Code Examples

  1. Loading a pre-trained ViT model:

import tensorflow as tf
from vit_jax import models

# Build a ViT-B/16 model that classifies from the class token.
model = models.vit_b16(
    num_classes=1000,
    representation_size=None,
    classifier='token'
)

  2. Preprocessing an image for ViT:

import tensorflow as tf
from vit_jax import preprocess

# Decode a JPEG and resize it to the model's 224x224 input resolution.
image = tf.io.read_file('image.jpg')
image = tf.image.decode_jpeg(image, channels=3)
image = preprocess.preprocess_image(image, 224, 224)

  3. Making predictions with ViT:

# Add a batch dimension, run the model, and take the most likely class.
logits = model(image[None, ...])
predicted_class = tf.argmax(logits, axis=-1)

Getting Started

To get started with the Vision Transformer:

  1. Clone the repository:

    git clone https://github.com/google-research/vision_transformer.git
    cd vision_transformer
    
  2. Install dependencies:

    pip install -r requirements.txt
    
  3. Download pre-trained weights and run inference:

    from vit_jax import models, preprocess
    import jax.numpy as jnp
    
    model = models.vit_b16(pretrained=True)
    image = preprocess.preprocess_image(your_image, 224, 224)
    logits = model(image[None, ...])
    

For more detailed instructions and examples, refer to the repository's README and documentation.

Competitor Comparisons

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

Pros of vit-pytorch

  • More lightweight and focused implementation
  • Easier to integrate into PyTorch projects
  • Includes additional features like distillation and adaptive token sampling

Cons of vit-pytorch

  • Less comprehensive documentation
  • Fewer pre-trained models available
  • May lack some advanced features present in the Google implementation

Code Comparison

vision_transformer:

class VisionTransformer(nn.Module):
  def __init__(self, config: ml_collections.ConfigDict, num_classes: int):
    super().__init__()
    self.num_classes = num_classes
    self.transformer = Transformer(config)
    self.head = nn.Linear(config.hidden_size, num_classes)

vit-pytorch:

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

The vision_transformer implementation is more closely aligned with the original paper, while vit-pytorch offers a more flexible and customizable approach. The vit-pytorch code is generally more concise and easier to understand for those familiar with PyTorch.

The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more

Pros of pytorch-image-models

  • Broader collection of image models and architectures
  • More active development and frequent updates
  • Extensive documentation and examples

Cons of pytorch-image-models

  • Larger codebase, potentially more complex to navigate
  • May include experimental or less stable models

Code Comparison

vision_transformer:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

pytorch-image-models:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

Both repositories implement similar attention mechanisms, with pytorch-image-models including an additional assertion for dimension compatibility. The core functionality remains largely the same, showcasing the common approach to implementing attention in vision transformers across different libraries.
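For readers who want to see what those constructor arguments actually set up, here is a typical forward pass for this style of attention module. This is a sketch in the spirit of both repositories, not code copied from either:

import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # Scaled dot-product attention over the patch tokens.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn.softmax(dim=-1))
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))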

This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".

Pros of Swin-Transformer

  • Hierarchical structure allows for better handling of varying scales in images
  • More efficient computation and memory usage for high-resolution images
  • Achieves state-of-the-art performance on various vision tasks

Cons of Swin-Transformer

  • More complex architecture, potentially harder to implement and fine-tune
  • May require more training data to fully leverage its capabilities
  • Slightly higher computational cost for smaller image sizes

Code Comparison

Vision Transformer:

class VisionTransformer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim):
        super().__init__()
        self.patch_embed = PatchEmbedding(image_size, patch_size, dim)
        self.transformer = Transformer(dim, depth, heads, mlp_dim)
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

Swin Transformer:

class SwinTransformer(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, num_classes, embed_dim, depths, num_heads):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer])
            self.layers.append(layer)

Official DeiT repository

Pros of DeiT

  • Offers distillation techniques for more efficient training
  • Includes a teacher-student architecture for knowledge transfer
  • Provides pre-trained models with various configurations

Cons of DeiT

  • More complex implementation due to distillation components
  • May require additional computational resources for distillation process

Code Comparison

Vision Transformer:

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True,
                 representation_size=None, distilled=False, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None, weight_init=''):
        super().__init__()
        # ... (implementation details)

DeiT:

class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        # ... (additional implementation details)
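For context, the extra distillation token shown above feeds a second classification head; at inference time DeiT-style models typically average the two heads. A minimal sketch of that step (illustrative, not the repository's exact code):

import torch

def distilled_logits(head_cls: torch.Tensor, head_dist: torch.Tensor) -> torch.Tensor:
    # During training the two heads receive separate losses (class labels vs.
    # teacher predictions); at inference their logits are simply averaged.
    return (head_cls + head_dist) / 2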

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

Pros of transformers

  • Broader scope: Covers various NLP tasks and modalities, not just vision
  • Extensive documentation and community support
  • Regular updates and maintenance

Cons of transformers

  • Larger codebase, potentially more complex to navigate
  • May have higher computational requirements due to its comprehensive nature

Code Comparison

vision_transformer:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

transformers:

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

Both repositories implement attention mechanisms, but transformers offers a more generalized approach suitable for various transformer architectures, while vision_transformer focuses specifically on vision tasks. The transformers library provides a wider range of pre-trained models and utilities, making it more versatile for different NLP and vision tasks. However, this breadth can make it more complex to use for specific vision-only applications compared to the more focused vision_transformer repository.
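For comparison, loading a ViT checkpoint through transformers looks roughly like the sketch below. The model name is an example from the Hugging Face Hub, and the exact processor class depends on your transformers version:

from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image

# Example checkpoint name; other ViT checkpoints follow the same pattern.
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

image = Image.open("image.jpg")
inputs = processor(images=image, return_tensors="pt")
logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])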

PyTorch implementation of MAE: https://arxiv.org/abs/2111.06377

Pros of MAE

  • Implements self-supervised learning, potentially reducing the need for large labeled datasets
  • Focuses on masked autoencoders, which can be more efficient for pretraining
  • Includes PyTorch implementation, offering flexibility for researchers using this framework

Cons of MAE

  • More complex architecture, potentially requiring more computational resources
  • Narrower focus on masked autoencoders, while Vision Transformer covers a broader range of vision transformer applications
  • May require more fine-tuning for specific downstream tasks

Code Comparison

Vision Transformer:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

MAE:

class MaskedAutoencoderViT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3,
                 embed_dim=1024, depth=24, num_heads=16,
                 decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                 mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
        super().__init__()
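The core idea behind the masked autoencoder is to encode only a small random subset of patch tokens and reconstruct the rest. A minimal sketch of that masking step (illustrative only, not the repository's implementation):

import numpy as np

def random_masking(tokens, mask_ratio=0.75, rng=np.random.default_rng(0)):
    # tokens: (num_patches, dim). Keep a random subset and return the kept
    # tokens plus a binary mask marking which patches were hidden.
    n = tokens.shape[0]
    n_keep = int(n * (1 - mask_ratio))
    order = rng.permutation(n)
    keep = np.sort(order[:n_keep])
    mask = np.ones(n, dtype=bool)
    mask[keep] = False          # False = visible, True = masked
    return tokens[keep], mask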

README

Vision Transformer and MLP-Mixer Architectures

In this repository we release models from the following papers:

  • "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
  • "MLP-Mixer: An all-MLP Architecture for Vision"
  • "How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers"
  • "When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations"
  • "LiT: Zero-Shot Transfer with Locked-image Text Tuning"
  • "Surrogate Gap Minimization Improves Sharpness-Aware Training"

The models were pre-trained on the ImageNet and ImageNet-21k datasets. We provide the code for fine-tuning the released models in JAX/Flax.

The models from this codebase were originally trained in https://github.com/google-research/big_vision/ where you can find more advanced code (e.g. multi-host training), as well as some of the original training scripts (e.g. configs/vit_i21k.py for pre-training a ViT, or configs/transfer.py for transferring a model).

Table of contents:

  • Colab
  • Installation
  • Fine-tuning a model
  • Vision Transformer
  • MLP-Mixer
  • LiT models
  • Running on cloud
  • Bibtex
  • Changelog
  • Disclaimers

Colab

The Colabs below run with both GPUs and TPUs (8 cores, data parallelism).

The first Colab demonstrates the JAX code of Vision Transformers and MLP-Mixers. It lets you edit the files from the repository directly in the Colab UI, walks you through the code step by step with annotated cells, and lets you interact with the data.

https://colab.research.google.com/github/google-research/vision_transformer/blob/main/vit_jax.ipynb

The second Colab allows you to explore the >50k Vision Transformer and hybrid checkpoints that were used to generate the data of the third paper "How to train your ViT? ...". The Colab includes code to explore and select checkpoints, and to do inference both using the JAX code from this repo, and also using the popular timm PyTorch library that can directly load these checkpoints as well. Note that a handful of models are also available directly from TF-Hub: sayakpaul/collections/vision_transformer (external contribution by Sayak Paul).

The second Colab also lets you fine-tune the checkpoints on any tfds dataset and your own dataset with examples in individual JPEG files (optionally directly reading from Google Drive).

https://colab.research.google.com/github/google-research/vision_transformer/blob/main/vit_jax_augreg.ipynb
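Since the AugReg checkpoints can also be loaded through timm (as noted above), a quick way to try one locally looks like the sketch below. The checkpoint name is given as an example; exact names depend on your timm version:

import timm
import torch

# Example AugReg checkpoint exposed through timm.
model = timm.create_model("vit_base_patch16_224.augreg_in21k_ft_in1k", pretrained=True)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # dummy input just to check shapes
print(logits.shape)  # (1, 1000)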

Note: As of now (6/20/21) Google Colab only supports a single GPU (Nvidia Tesla T4), and TPUs (currently TPUv2-8) are attached indirectly to the Colab VM and communicate over a slow network, which leads to pretty bad training speed. You would usually want to set up a dedicated machine if you have a non-trivial amount of data to fine-tune on. For details see the Running on cloud section.

Installation

Make sure you have Python>=3.10 installed on your machine.

Install JAX and python dependencies by running:

# If using GPU:
pip install -r vit_jax/requirements.txt

# If using TPU:
pip install -r vit_jax/requirements-tpu.txt

For newer versions of JAX, follow the installation instructions in the JAX repository. Note that the installation instructions for CPU, GPU, and TPU differ slightly.

To install Flaxformer, follow the instructions in the Flaxformer repository.

For more details refer to the section Running on cloud below.

Fine-tuning a model

You can run fine-tuning of the downloaded model on your dataset of interest. All models share the same command line interface.

For example for fine-tuning a ViT-B/16 (pre-trained on imagenet21k) on CIFAR10 (note how we specify b16,cifar10 as arguments to the config, and how we instruct the code to access the models directly from a GCS bucket instead of first downloading them into the local directory):

python -m vit_jax.main --workdir=/tmp/vit-$(date +%s) \
    --config=$(pwd)/vit_jax/configs/vit.py:b16,cifar10 \
    --config.pretrained_dir='gs://vit_models/imagenet21k'

In order to fine-tune a Mixer-B/16 (pre-trained on imagenet21k) on CIFAR10:

python -m vit_jax.main --workdir=/tmp/vit-$(date +%s) \
    --config=$(pwd)/vit_jax/configs/mixer_base16_cifar10.py \
    --config.pretrained_dir='gs://mixer_models/imagenet21k'

The "How to train your ViT? ..." paper added >50k checkpoints that you can fine-tune with the configs/augreg.py config. When you only specify the model name (the config.name value from configs/model.py), then the best i21k checkpoint by upstream validation accuracy ("recommended" checkpoint, see section 4.5 of the paper) is chosen. To make up your mind which model you want to use, have a look at Figure 3 in the paper. It's also possible to choose a different checkpoint (see Colab vit_jax_augreg.ipynb) and then specify the value from the filename or adapt_filename column, which correspond to the filenames without .npz from the gs://vit_models/augreg directory.

python -m vit_jax.main --workdir=/tmp/vit-$(date +%s) \
    --config=$(pwd)/vit_jax/configs/augreg.py:R_Ti_16 \
    --config.dataset=oxford_iiit_pet \
    --config.base_lr=0.01

Currently, the code automatically downloads the CIFAR-10 and CIFAR-100 datasets. Other public or custom datasets can be easily integrated using the tensorflow_datasets library. Note that you will also need to update vit_jax/input_pipeline.py to specify some parameters about any added dataset.
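For orientation, integrating a public dataset typically starts from tensorflow_datasets. A minimal sketch of what the input side looks like (illustrative only; the actual pipeline lives in vit_jax/input_pipeline.py, and the preprocessing parameters here are assumptions):

import tensorflow as tf
import tensorflow_datasets as tfds

# Load a public dataset with (image, label) pairs and dataset metadata.
ds, info = tfds.load('oxford_iiit_pet', split='train', as_supervised=True, with_info=True)

def preprocess(image, label, size=384):
    # Resize to the fine-tuning resolution and scale pixel values to [-1, 1].
    image = tf.image.resize(image, (size, size))
    image = (tf.cast(image, tf.float32) - 127.5) / 127.5
    return image, label

ds = ds.map(preprocess).shuffle(1_000).batch(32).prefetch(tf.data.AUTOTUNE)
print(info.features['label'].num_classes)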

Note that our code uses all available GPUs/TPUs for fine-tuning.

To see a detailed list of all available flags, run python3 -m vit_jax.main --help.

Notes on memory:

  • Different models require different amounts of memory. Available memory also depends on the accelerator configuration (both type and count). If you encounter an out-of-memory error, you can increase the value of --config.accum_steps=8 -- alternatively, you could decrease --config.batch=512 (and decrease --config.base_lr accordingly).
  • The host keeps a shuffle buffer in memory. If you encounter a host OOM (as opposed to an accelerator OOM), you can decrease the default --config.shuffle_buffer=50000.

Vision Transformer

by Alexey Dosovitskiy*†, Lucas Beyer*, Alexander Kolesnikov*, Dirk Weissenborn*, Xiaohua Zhai*, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit and Neil Houlsby*†.

(*) equal technical contribution, (†) equal advising.

Figure 1 from paper

Overview of the model: we split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard Transformer encoder. In order to perform classification, we use the standard approach of adding an extra learnable "classification token" to the sequence.
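To make the tokenization step concrete, here is a plain-NumPy sketch of the patch embedding described above. The projection matrix, class token, and position embeddings are random stand-ins for parameters that are learned in the real model:

import numpy as np

def patchify_and_embed(image, patch_size=16, hidden_dim=768, seed=0):
    """Illustrative only: split an (H, W, C) image into patches, linearly embed
    them, prepend a class token, and add position embeddings."""
    rng = np.random.default_rng(seed)
    h, w, c = image.shape
    gh, gw = h // patch_size, w // patch_size
    patches = image[:gh * patch_size, :gw * patch_size].reshape(
        gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4).reshape(gh * gw, -1)
    # Random stand-ins for learned parameters.
    w_embed = rng.normal(scale=0.02, size=(patches.shape[-1], hidden_dim))
    cls_token = rng.normal(scale=0.02, size=(1, hidden_dim))
    pos_embed = rng.normal(scale=0.02, size=(gh * gw + 1, hidden_dim))
    tokens = np.concatenate([cls_token, patches @ w_embed], axis=0) + pos_embed
    return tokens  # (num_patches + 1, hidden_dim), fed to a standard Transformer encoder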

Available ViT models

We provide a variety of ViT models in different GCS buckets. The models can be downloaded with e.g.:

wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz

The model filenames (without the .npz extension) correspond to the config.model_name in vit_jax/configs/models.py.
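Once downloaded, a checkpoint is simply a NumPy .npz archive of named parameter arrays, so you can inspect it directly. A quick sanity check, assuming the wget command above has been run:

import numpy as np

ckpt = np.load('ViT-B_16.npz')
# Print a few parameter names and shapes to see the checkpoint layout.
for name in sorted(ckpt.keys())[:5]:
    print(name, ckpt[name].shape)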

We recommend using the following checkpoints, trained with AugReg, which have the best pre-training metrics:

| Model | Pre-trained checkpoint | Size | Fine-tuned checkpoint | Resolution | Img/sec | ImageNet accuracy |
|---|---|---|---|---|---|---|
| L/16 | gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0.npz | 1243 MiB | gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_strong1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz | 384 | 50 | 85.59% |
| B/16 | gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz | 391 MiB | gs://vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz | 384 | 138 | 85.49% |
| S/16 | gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz | 115 MiB | gs://vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz | 384 | 300 | 83.73% |
| R50+L/32 | gs://vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz | 1337 MiB | gs://vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz | 384 | 327 | 85.99% |
| R26+S/32 | gs://vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz | 170 MiB | gs://vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz | 384 | 560 | 83.85% |
| Ti/16 | gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz | 37 MiB | gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz | 384 | 610 | 78.22% |
| B/32 | gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0.npz | 398 MiB | gs://vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz | 384 | 955 | 83.59% |
| S/32 | gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0.npz | 118 MiB | gs://vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_none-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz | 384 | 2154 | 79.58% |
| R+Ti/16 | gs://vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz | 40 MiB | gs://vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz | 384 | 2426 | 75.40% |

The results from the original ViT paper (https://arxiv.org/abs/2010.11929) have been replicated using the models from gs://vit_models/imagenet21k:

| model | dataset | dropout=0.0 | dropout=0.1 |
|---|---|---|---|
| R50+ViT-B_16 | cifar10 | 98.72%, 3.9h (A100), tb.dev | 98.94%, 10.1h (V100), tb.dev |
| R50+ViT-B_16 | cifar100 | 90.88%, 4.1h (A100), tb.dev | 92.30%, 10.1h (V100), tb.dev |
| R50+ViT-B_16 | imagenet2012 | 83.72%, 9.9h (A100), tb.dev | 85.08%, 24.2h (V100), tb.dev |
| ViT-B_16 | cifar10 | 99.02%, 2.2h (A100), tb.dev | 98.76%, 7.8h (V100), tb.dev |
| ViT-B_16 | cifar100 | 92.06%, 2.2h (A100), tb.dev | 91.92%, 7.8h (V100), tb.dev |
| ViT-B_16 | imagenet2012 | 84.53%, 6.5h (A100), tb.dev | 84.12%, 19.3h (V100), tb.dev |
| ViT-B_32 | cifar10 | 98.88%, 0.8h (A100), tb.dev | 98.75%, 1.8h (V100), tb.dev |
| ViT-B_32 | cifar100 | 92.31%, 0.8h (A100), tb.dev | 92.05%, 1.8h (V100), tb.dev |
| ViT-B_32 | imagenet2012 | 81.66%, 3.3h (A100), tb.dev | 81.31%, 4.9h (V100), tb.dev |
| ViT-L_16 | cifar10 | 99.13%, 6.9h (A100), tb.dev | 99.14%, 24.7h (V100), tb.dev |
| ViT-L_16 | cifar100 | 92.91%, 7.1h (A100), tb.dev | 93.22%, 24.4h (V100), tb.dev |
| ViT-L_16 | imagenet2012 | 84.47%, 16.8h (A100), tb.dev | 85.05%, 59.7h (V100), tb.dev |
| ViT-L_32 | cifar10 | 99.06%, 1.9h (A100), tb.dev | 99.09%, 6.1h (V100), tb.dev |
| ViT-L_32 | cifar100 | 93.29%, 1.9h (A100), tb.dev | 93.34%, 6.2h (V100), tb.dev |
| ViT-L_32 | imagenet2012 | 81.89%, 7.5h (A100), tb.dev | 81.13%, 15.0h (V100), tb.dev |

We also would like to emphasize that high-quality results can be achieved with shorter training schedules and encourage users of our code to play with hyper-parameters to trade-off accuracy and computational budget. Some examples for CIFAR-10/100 datasets are presented in the table below.

| upstream | model | dataset | total_steps / warmup_steps | accuracy | wall-clock time | link |
|---|---|---|---|---|---|---|
| imagenet21k | ViT-B_16 | cifar10 | 500 / 50 | 98.59% | 17m | tensorboard.dev |
| imagenet21k | ViT-B_16 | cifar10 | 1000 / 100 | 98.86% | 39m | tensorboard.dev |
| imagenet21k | ViT-B_16 | cifar100 | 500 / 50 | 89.17% | 17m | tensorboard.dev |
| imagenet21k | ViT-B_16 | cifar100 | 1000 / 100 | 91.15% | 39m | tensorboard.dev |

MLP-Mixer

by Ilya Tolstikhin*, Neil Houlsby*, Alexander Kolesnikov*, Lucas Beyer*, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy.

(*) equal contribution.

Figure 1 from paper

MLP-Mixer (Mixer for short) consists of per-patch linear embeddings, Mixer layers, and a classifier head. Mixer layers contain one token-mixing MLP and one channel-mixing MLP, each consisting of two fully-connected layers and a GELU nonlinearity. Other components include: skip-connections, dropout, and linear classifier head.
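The following plain-NumPy sketch shows the shape of one Mixer layer as described above. Layer norms and the MLP weights are simplified stand-ins, not the repository's implementation:

import numpy as np

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

def mlp(x, w1, b1, w2, b2):
    # Two fully-connected layers with a GELU nonlinearity in between.
    return gelu(x @ w1 + b1) @ w2 + b2

def layer_norm(x, eps=1e-6):
    return (x - x.mean(-1, keepdims=True)) / (x.std(-1, keepdims=True) + eps)

def mixer_layer(tokens, token_mlp_params, channel_mlp_params):
    """tokens: (num_patches, channels); each *_params is a (w1, b1, w2, b2) tuple."""
    # Token-mixing MLP acts across patches (note the transposes), with a skip connection.
    y = tokens + mlp(layer_norm(tokens).T, *token_mlp_params).T
    # Channel-mixing MLP acts across channels, with a skip connection.
    return y + mlp(layer_norm(y), *channel_mlp_params)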

For installation follow the same steps as above.

Available Mixer models

We provide the Mixer-B/16 and Mixer-L/16 models pre-trained on the ImageNet and ImageNet-21k datasets. Details can be found in Table 3 of the Mixer paper. All the models can be found at:

https://console.cloud.google.com/storage/mixer_models/

Note that these models are also available directly from TF-Hub: sayakpaul/collections/mlp-mixer (external contribution by Sayak Paul).

Expected Mixer results

We ran the fine-tuning code on a Google Cloud machine with four V100 GPUs, using the default adaptation parameters from this repository. Here are the results:

| upstream | model | dataset | accuracy | wall_clock_time | link |
|---|---|---|---|---|---|
| ImageNet | Mixer-B/16 | cifar10 | 96.72% | 3.0h | tensorboard.dev |
| ImageNet | Mixer-L/16 | cifar10 | 96.59% | 3.0h | tensorboard.dev |
| ImageNet-21k | Mixer-B/16 | cifar10 | 96.82% | 9.6h | tensorboard.dev |
| ImageNet-21k | Mixer-L/16 | cifar10 | 98.34% | 10.0h | tensorboard.dev |

LiT models

For details, refer to the Google AI blog post LiT: adding language understanding to image models, or read the CVPR paper "LiT: Zero-Shot Transfer with Locked-image text Tuning" (https://arxiv.org/abs/2111.07991).

We published a Transformer B/16-base model with an ImageNet zero-shot accuracy of 72.1%, and an L/16-large model with an ImageNet zero-shot accuracy of 75.7%. For more details about these models, please refer to the LiT model card.

We provide an in-browser demo with small text encoders for interactive use (the smallest models should even run on a modern cell phone):

https://google-research.github.io/vision_transformer/lit/

And finally a Colab to use the JAX models with both image and text encoders:

https://colab.research.google.com/github/google-research/vision_transformer/blob/main/lit.ipynb

Note that none of the above models support multi-lingual inputs yet, but we're working on publishing such models and will update this repository once they become available.

This repository only contains evaluation code for LiT models. You can find the training code in the big_vision repository:

https://github.com/google-research/big_vision/tree/main/big_vision/configs/proj/image_text

Expected zero-shot results from model_cards/lit.md (note that the zero-shot evaluation is slightly different from the simplified evaluation in the Colab):

| Model | B16B_2 | L16L |
|---|---|---|
| ImageNet zero-shot | 73.9% | 75.7% |
| ImageNet v2 zero-shot | 65.1% | 66.6% |
| CIFAR100 zero-shot | 79.0% | 80.5% |
| Pets37 zero-shot | 83.3% | 83.3% |
| Resisc45 zero-shot | 25.3% | 25.6% |
| MS-COCO Captions image-to-text retrieval | 51.6% | 48.5% |
| MS-COCO Captions text-to-image retrieval | 31.8% | 31.1% |
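For context on how such zero-shot numbers are obtained: the image embedding is compared against text embeddings of class prompts, and the most similar prompt wins. A minimal sketch with a hypothetical encode_text function (not this repository's API):

import numpy as np

def zero_shot_classify(image_emb, class_names, encode_text):
    """image_emb: embedding from the (locked) image tower.
    encode_text: hypothetical stand-in for the LiT text tower."""
    prompts = [f"a photo of a {name}" for name in class_names]
    text_embs = np.stack([encode_text(p) for p in prompts])
    # Cosine similarity between the image embedding and each class prompt.
    image_emb = image_emb / np.linalg.norm(image_emb)
    text_embs = text_embs / np.linalg.norm(text_embs, axis=-1, keepdims=True)
    return class_names[int(np.argmax(text_embs @ image_emb))]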

Running on cloud

While the Colabs above are useful for getting started, you will usually want to train on a larger machine with more powerful accelerators.

Create a VM

You can use the following commands to set up a VM with GPUs on Google Cloud:

# Set variables used by all commands below.
# Note that project must have accounting set up.
# For a list of zones with GPUs refer to
# https://cloud.google.com/compute/docs/gpus/gpu-regions-zones
PROJECT=my-awesome-gcp-project  # Project must have billing enabled.
VM_NAME=vit-jax-vm-gpu
ZONE=europe-west4-b

# The settings below have been tested with this repository. To choose other
# combinations of images & machines, refer to the corresponding gcloud commands, e.g.:
# gcloud compute images list --project ml-images
# gcloud compute machine-types list
# etc.
gcloud compute instances create $VM_NAME \
    --project=$PROJECT --zone=$ZONE \
    --image=c1-deeplearning-tf-2-5-cu110-v20210527-debian-10 \
    --image-project=ml-images --machine-type=n1-standard-96 \
    --scopes=cloud-platform,storage-full --boot-disk-size=256GB \
    --boot-disk-type=pd-ssd --metadata=install-nvidia-driver=True \
    --maintenance-policy=TERMINATE \
    --accelerator=type=nvidia-tesla-v100,count=8

# Connect to VM (after some minutes needed to setup & start the machine).
gcloud compute ssh --project $PROJECT --zone $ZONE $VM_NAME

# Stop the VM after use (only storage is billed for a stopped VM).
gcloud compute instances stop --project $PROJECT --zone $ZONE $VM_NAME

# Delete VM after use (this will also remove all data stored on VM).
gcloud compute instances delete --project $PROJECT --zone $ZONE $VM_NAME

Alternatively, you can use the following similar commands to set up a Cloud VM with TPUs attached (the commands below are copied from the TPU tutorial):

PROJECT=my-awesome-gcp-project  # Project must have billing enabled.
VM_NAME=vit-jax-vm-tpu
ZONE=europe-west4-a

# Required to set up service identity initially.
gcloud beta services identity create --service tpu.googleapis.com

# Create a VM with TPUs directly attached to it.
gcloud alpha compute tpus tpu-vm create $VM_NAME \
    --project=$PROJECT --zone=$ZONE \
    --accelerator-type v3-8 \
    --version tpu-vm-base

# Connect to VM (after some minutes needed to setup & start the machine).
gcloud alpha compute tpus tpu-vm ssh --project $PROJECT --zone $ZONE $VM_NAME

# Stop the VM after use (only storage is billed for a stopped VM).
gcloud alpha compute tpus tpu-vm stop --project $PROJECT --zone $ZONE $VM_NAME

# Delete VM after use (this will also remove all data stored on VM).
gcloud alpha compute tpus tpu-vm delete --project $PROJECT --zone $ZONE $VM_NAME

Setup VM

Then fetch the repository and install the dependencies (including jaxlib with TPU support) as usual:

git clone --depth=1 --branch=master https://github.com/google-research/vision_transformer
cd vision_transformer

# optional: install virtualenv
pip3 install virtualenv
python3 -m virtualenv env
. env/bin/activate

If you're connected to a VM with GPUs attached, install JAX and other dependencies with the following command:

pip install -r vit_jax/requirements.txt

If you're connected to a VM with TPUs attached, install JAX and other dependencies with the following command:

pip install -r vit_jax/requirements-tpu.txt

To install Flaxformer, follow the instructions in the Flaxformer repository.

For both GPUs and TPUs, check that JAX can connect to the attached accelerators with the command:

python -c 'import jax; print(jax.devices())'

Finally, execute one of the commands from the Fine-tuning a model section.

Bibtex

@article{dosovitskiy2020vit,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and  Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={ICLR},
  year={2021}
}

@article{tolstikhin2021mixer,
  title={MLP-Mixer: An all-MLP Architecture for Vision},
  author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner, Thomas and Yung, Jessica and Steiner, Andreas and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
  journal={arXiv preprint arXiv:2105.01601},
  year={2021}
}

@article{steiner2021augreg,
  title={How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers},
  author={Steiner, Andreas and Kolesnikov, Alexander and Zhai, Xiaohua and Wightman, Ross and Uszkoreit, Jakob and Beyer, Lucas},
  journal={arXiv preprint arXiv:2106.10270},
  year={2021}
}

@article{chen2021outperform,
  title={When Vision Transformers Outperform ResNets without Pretraining or Strong Data Augmentations},
  author={Chen, Xiangning and Hsieh, Cho-Jui and Gong, Boqing},
  journal={arXiv preprint arXiv:2106.01548},
  year={2021},
}

@article{zhuang2022gsam,
  title={Surrogate Gap Minimization Improves Sharpness-Aware Training},
  author={Zhuang, Juntang and Gong, Boqing and Yuan, Liangzhe and Cui, Yin and Adam, Hartwig and Dvornek, Nicha and Tatikonda, Sekhar and Duncan, James and Liu, Ting},
  journal={ICLR},
  year={2022},
}

@article{zhai2022lit,
  title={LiT: Zero-Shot Transfer with Locked-image Text Tuning},
  author={Zhai, Xiaohua and Wang, Xiao and Mustafa, Basil and Steiner, Andreas and Keysers, Daniel and Kolesnikov, Alexander and Beyer, Lucas},
  journal={CVPR},
  year={2022}
}

Changelog

In reverse chronological order:

  • 2022-08-18: Added LiT-B16B_2 model that was trained for 60k steps (LiT_B16B: 30k) without linear head on the image side (LiT_B16B: 768) and has better performance.

  • 2022-06-09: Added the ViT and Mixer models trained from scratch using GSAM on ImageNet without strong data augmentations. The resultant ViTs outperform those of similar sizes trained using AdamW optimizer or the original SAM algorithm, or with strong data augmentations.

  • 2022-04-14: Added models and Colab for LiT models.

  • 2021-07-29: Added ViT-B/8 AugReg models (3 upstream checkpoints and adaptations with resolution=224).

  • 2021-07-02: Added the "When Vision Transformers Outperform ResNets..." paper

  • 2021-07-02: Added SAM (Sharpness-Aware Minimization) optimized ViT and MLP-Mixer checkpoints.

  • 2021-06-20: Added the "How to train your ViT? ..." paper, and a new Colab to explore the >50k pre-trained and fine-tuned checkpoints mentioned in the paper.

  • 2021-06-18: This repository was rewritten to use Flax Linen API and ml_collections.ConfigDict for configuration.

  • 2021-05-19: With publication of the "How to train your ViT? ..." paper, we added more than 50k ViT and hybrid models pre-trained on ImageNet and ImageNet-21k with various degrees of data augmentation and model regularization, and fine-tuned on ImageNet, Pets37, Kitti-distance, CIFAR-100, and Resisc45. Check out vit_jax_augreg.ipynb to navigate this treasure trove of models! For example, you can use that Colab to fetch the filenames of recommended pre-trained and fine-tuned checkpoints from the i21k_300 column of Table 3 in the paper.

  • 2020-12-01: Added the R50+ViT-B/16 hybrid model (ViT-B/16 on top of a Resnet-50 backbone). When pretrained on imagenet21k, this model achieves almost the performance of the L/16 model with less than half the computational finetuning cost. Note that "R50" is somewhat modified for the B/16 variant: The original ResNet-50 has [3,4,6,3] blocks, each reducing the resolution of the image by a factor of two. In combination with the ResNet stem this would result in a reduction of 32x so even with a patch size of (1,1) the ViT-B/16 variant cannot be realized anymore. For this reason we instead use [3,4,9] blocks for the R50+B/16 variant.

  • 2020-11-09: Added the ViT-L/16 model.

  • 2020-10-29: Added ViT-B/16 and ViT-L/16 models pretrained on ImageNet-21k and then fine-tuned on ImageNet at 224x224 resolution (instead of default 384x384). These models have the suffix "-224" in their name. They are expected to achieve 81.2% and 82.7% top-1 accuracies respectively.

Disclaimers

Open source release prepared by Andreas Steiner.

Note: This repository was forked and modified from google-research/big_transfer.

This is not an official Google product.