ViT-pytorch
Pytorch reimplementation of the Vision Transformer (An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale)
Top Related Projects
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
Official DeiT repository
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Quick Overview
ViT-pytorch is a PyTorch implementation of the Vision Transformer (ViT) model introduced in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" by Google Research. This repository provides a clean and efficient implementation of ViT, allowing researchers and practitioners to easily experiment with and deploy vision transformers for various computer vision tasks.
Pros
- Clean and well-organized implementation of Vision Transformer in PyTorch
- Includes pre-trained models and easy-to-use inference code
- Supports various ViT configurations (ViT-Base, ViT-Large, ViT-Huge)
- Provides example scripts for fine-tuning on custom datasets
Cons
- Limited documentation and explanations of the implementation details
- Lacks extensive benchmarking results or comparisons with other models
- May require significant computational resources for training large ViT models
- Does not include implementations of more recent ViT variants or improvements
Code Examples
- Loading a pre-trained ViT model:
import numpy as np
import torch
from models.modeling import VisionTransformer, CONFIGS

# Select the ViT-Base/16 configuration and build the model
config = CONFIGS['ViT-B_16']
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)

# Load the converted Google checkpoint (.npz format)
model.load_from(np.load("ViT-B_16.npz"))
- Performing inference on an image (a short post-processing sketch follows these examples):
import torch
from PIL import Image
import torchvision.transforms as transforms

# Preprocessing: resize to 224x224 and normalize to [-1, 1]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

image = Image.open("example.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)  # add a batch dimension

model.eval()
with torch.no_grad():
    output = model(input_tensor)
predicted_class = torch.argmax(output, dim=1).item()
- Fine-tuning ViT on a custom dataset:
from train import train

# Placeholder: get_custom_data_loaders() stands in for your own data-loading code
train_loader, val_loader = get_custom_data_loaders()

# Fine-tune the model (example hyperparameters)
args = {
    'model_type': 'ViT-B_16',
    'num_classes': 10,       # number of classes in your dataset
    'lr': 3e-2,
    'weight_decay': 0,
    'num_steps': 10000,
}
train(args, train_loader, val_loader)
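Returning to the inference example above, here is a minimal post-processing sketch. It assumes `output` is a plain logits tensor of shape `(1, num_classes)`; depending on the model configuration, the forward pass may also return attention weights, in which case take the logits first.

```python
import torch
import torch.nn.functional as F

# Hypothetical post-processing: assumes `output` holds raw logits of shape (1, num_classes)
probs = F.softmax(output, dim=1)

# Show the five most likely classes with their probabilities
top_probs, top_classes = probs.topk(5, dim=1)
for p, c in zip(top_probs[0].tolist(), top_classes[0].tolist()):
    print(f"class {c}: probability {p:.4f}")
```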
Getting Started
1. Clone the repository:
git clone https://github.com/jeonsworld/ViT-pytorch.git
cd ViT-pytorch
2. Install dependencies:
pip install -r requirements.txt
3. Download pre-trained weights:
wget https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz
4. Run inference or fine-tuning:
python train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir ViT-B_16.npz
Competitor Comparisons
Pros of vision_transformer
- Official implementation from Google Research, likely more authoritative
- Supports multiple model variants (ViT-B, ViT-L, ViT-H)
- Includes pre-trained weights and evaluation scripts
Cons of vision_transformer
- Written in TensorFlow, which may be less familiar to PyTorch users
- Less focused on ease of use and integration with other projects
- Fewer examples and documentation for quick start
Code Comparison
vision_transformer:
class Encoder1D(tf.keras.layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               add_position_embedding=True,
               name='encoder'):
    super().__init__(name=name)
    self.num_layers = num_layers
    self.mlp_dim = mlp_dim
    self.num_heads = num_heads
    self.dropout_rate = dropout_rate
    self.attention_dropout_rate = attention_dropout_rate
    self.add_position_embedding = add_position_embedding
ViT-pytorch:
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch
Pros of vit-pytorch
- More modular and flexible implementation
- Includes additional features like distillation and various attention mechanisms
- Better documentation and examples
Cons of vit-pytorch
- May be more complex for beginners
- Less focus on reproducing original ViT paper results
Code Comparison
ViT-pytorch:
x = self.patch_embedding(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
x = self.dropout(x)
vit-pytorch:
x = self.to_patch_embedding(x)
b, n, _ = x.shape
cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)
x += self.pos_embedding[:, :(n + 1)]
Both repositories implement Vision Transformers (ViT) in PyTorch, but they have different approaches and features. ViT-pytorch focuses on reproducing the original paper's results, while vit-pytorch offers a more flexible and feature-rich implementation. The code comparison shows similarities in the basic structure, but vit-pytorch uses more concise and modular code. Overall, the choice between the two depends on the user's specific needs and level of expertise.
The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
Pros of pytorch-image-models
- Extensive collection of pre-trained models and architectures
- Regular updates and active community support
- Comprehensive documentation and examples
Cons of pytorch-image-models
- Larger codebase, potentially more complex for beginners
- May include unnecessary features for those focused solely on ViT
Code Comparison
ViT-pytorch:
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    num_classes=1000,
    embed_dim=768,
    depth=12,
    num_heads=12,
)
pytorch-image-models:
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, num_classes)
The pytorch-image-models example demonstrates easier model creation and fine-tuning, while ViT-pytorch offers more granular control over model parameters.
pytorch-image-models provides a more comprehensive suite of tools and models, making it suitable for a wider range of computer vision tasks. However, ViT-pytorch's focused approach on Vision Transformers may be preferable for those specifically working with this architecture.
Both repositories offer valuable resources for implementing Vision Transformers in PyTorch, with the choice depending on the user's specific needs and preferences.
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
Pros of Swin-Transformer
- Hierarchical structure allows for better performance on various vision tasks
- More efficient computation and memory usage for high-resolution images
- Supports a wider range of vision applications beyond image classification
Cons of Swin-Transformer
- More complex architecture, potentially harder to understand and implement
- May require more computational resources for training and inference
- Less straightforward to adapt for non-vision tasks compared to ViT-pytorch
Code Comparison
ViT-pytorch:
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
Swin-Transformer:
class BasicLayer(nn.Module):
    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint
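To make the efficiency point above more concrete, here is a minimal window-partitioning sketch in the spirit of Swin's windowed attention (an illustrative helper under assumed tensor shapes, not the repository's exact code): attention is computed inside each fixed-size window, so its cost depends on the window size rather than the full image resolution.

```python
import torch

def window_partition(x, window_size):
    """Split a (B, H, W, C) feature map into non-overlapping windows.

    Returns (num_windows * B, window_size, window_size, C). Self-attention is then
    computed inside each window, so the quadratic cost depends on window_size**2
    rather than on the full H * W resolution.
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows.view(-1, window_size, window_size, C)

# Example: a 56x56 feature map with 7x7 windows yields 64 windows per image
print(window_partition(torch.randn(1, 56, 56, 96), 7).shape)  # torch.Size([64, 7, 7, 96])
```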
Official DeiT repository
Pros of DeiT
- More comprehensive and feature-rich implementation
- Includes distillation techniques for improved performance
- Better documentation and examples for usage
Cons of DeiT
- More complex codebase, potentially harder to understand for beginners
- Requires more computational resources due to additional features
Code Comparison
ViT-pytorch:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
DeiT:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
The main difference in the code is that DeiT allows a custom qk_scale parameter, providing more flexibility in attention scaling.
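To make that difference concrete, here is a minimal, hypothetical attention sketch in the same style as the snippets above; when `qk_scale` is provided it overrides the default `head_dim ** -0.5` factor, otherwise the two implementations behave identically.

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    # Minimal sketch, not the exact DeiT implementation: a single projection to q, k, v
    # with an optional qk_scale override of the default 1/sqrt(head_dim) scaling.
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # DeiT-style override
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)           # each: (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # scaled dot-product scores
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)
```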
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Pros of transformers
- Comprehensive library with support for multiple architectures and tasks
- Extensive documentation and community support
- Regular updates and maintenance
Cons of transformers
- Larger codebase, potentially more complex for beginners
- May include unnecessary components for specific use cases
Code Comparison
ViT-pytorch:
import torch
from vit_pytorch import ViT
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)
img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
transformers:
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')

image = Image.open("example.jpg")  # any input image
inputs = feature_extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
ViT-pytorch offers a more straightforward implementation for Vision Transformer models, while transformers provides a more versatile and feature-rich environment for working with various transformer architectures, including ViT.
README
Vision Transformer
Pytorch reimplementation of Google's repository for the ViT model that was released with the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
This paper shows that Transformers applied directly to image patches and pre-trained on large datasets perform very well on image recognition tasks.
The Vision Transformer achieves state-of-the-art results in image recognition with a standard Transformer encoder and fixed-size patches. To perform classification, the authors use the standard approach of adding an extra learnable "classification token" to the sequence.
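For illustration, here is a minimal, hypothetical sketch of the input pipeline just described (module and parameter names are invented for clarity, not taken from this repository): the image is cut into fixed-size patches, each patch is linearly projected, a learnable classification token is prepended, and position embeddings are added before the sequence enters the Transformer encoder.

```python
import torch
import torch.nn as nn

class PatchEmbedWithClsToken(nn.Module):
    # Illustrative sketch only; names are hypothetical, not the repo's exact modules.
    def __init__(self, img_size=224, patch_size=16, in_chans=3, dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        # A strided convolution cuts the image into patches and linearly projects
        # each patch to the embedding dimension in one step.
        self.proj = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))

    def forward(self, x):
        B = x.shape[0]
        x = self.proj(x)                        # (B, dim, H/ps, W/ps)
        x = x.flatten(2).transpose(1, 2)        # (B, num_patches, dim)
        cls = self.cls_token.expand(B, -1, -1)  # (B, 1, dim)
        x = torch.cat((cls, x), dim=1)          # prepend classification token
        return x + self.pos_embed               # add position embeddings

tokens = PatchEmbedWithClsToken()(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 197, 768])
```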
Usage
1. Download Pre-trained model (Google's Official Checkpoint)
- Available models: ViT-B_16(85.8M), R50+ViT-B_16(97.96M), ViT-B_32(87.5M), ViT-L_16(303.4M), ViT-L_32(305.5M), ViT-H_14(630.8M)
- imagenet21k pre-train models
  - ViT-B_16, ViT-B_32, ViT-L_16, ViT-L_32, ViT-H_14
- imagenet21k pre-train + imagenet2012 fine-tuned models
  - ViT-B_16-224, ViT-B_16, ViT-B_32, ViT-L_16-224, ViT-L_16, ViT-L_32
- Hybrid Model (Resnet50 + Transformer)
  - R50-ViT-B_16
# imagenet21k pre-train
wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz
# imagenet21k pre-train + imagenet2012 fine-tuning
wget https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{MODEL_NAME}.npz
2. Train Model
python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz
CIFAR-10 and CIFAR-100 are downloaded and used for training automatically. To use a different dataset, customize data_utils.py.
The default batch size is 512. If GPU memory is insufficient, you can still train by adjusting the value of --gradient_accumulation_steps.
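The idea behind gradient accumulation is to split one large effective batch into several smaller micro-batches and only step the optimizer once their gradients have all been accumulated. A generic PyTorch sketch (not this repository's exact training loop; `model` and `train_loader` stand in for whatever the script builds):

```python
import torch
import torch.nn as nn

# Generic gradient-accumulation sketch; `model` and `train_loader` are stand-ins.
accumulation_steps = 4                      # effective batch = loader batch size * 4
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=3e-2, momentum=0.9)

optimizer.zero_grad()
for step, (images, labels) in enumerate(train_loader):
    loss = criterion(model(images), labels)
    (loss / accumulation_steps).backward()  # scale so accumulated gradients average out
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                    # one weight update per accumulated batch
        optimizer.zero_grad()
```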
You can also use Automatic Mixed Precision (AMP) to reduce memory usage and train faster:
python3 train.py --name cifar10-100_500 --dataset cifar10 --model_type ViT-B_16 --pretrained_dir checkpoint/ViT-B_16.npz --fp16 --fp16_opt_level O2
Results
To verify that the converted model weights are correct, we simply compare our results with the authors' reported numbers. We trained with mixed precision, with --fp16_opt_level set to O2.
imagenet-21k
model | dataset | resolution | acc(official) | acc(this repo) | time |
---|---|---|---|---|---|
ViT-B_16 | CIFAR-10 | 224x224 | - | 0.9908 | 3h 13m |
ViT-B_16 | CIFAR-10 | 384x384 | 0.9903 | 0.9906 | 12h 25m |
ViT-B_16 | CIFAR-100 | 224x224 | - | 0.923 | 3h 9m |
ViT-B_16 | CIFAR-100 | 384x384 | 0.9264 | 0.9228 | 12h 31m |
R50-ViT-B_16 | CIFAR-10 | 224x224 | - | 0.9892 | 4h 23m |
R50-ViT-B_16 | CIFAR-10 | 384x384 | 0.99 | 0.9904 | 15h 40m |
R50-ViT-B_16 | CIFAR-100 | 224x224 | - | 0.9231 | 4h 18m |
R50-ViT-B_16 | CIFAR-100 | 384x384 | 0.9231 | 0.9197 | 15h 53m |
ViT-L_32 | CIFAR-10 | 224x224 | - | 0.9903 | 2h 11m |
ViT-L_32 | CIFAR-100 | 224x224 | - | 0.9276 | 2h 9m |
ViT-H_14 | CIFAR-100 | 224x224 | - | WIP | - |
imagenet-21k + imagenet2012
model | dataset | resolution | acc |
---|---|---|---|
ViT-B_16-224 | CIFAR-10 | 224x224 | 0.99 |
ViT-B_16-224 | CIFAR-100 | 224x224 | 0.9245 |
ViT-L_32 | CIFAR-10 | 224x224 | 0.9903 |
ViT-L_32 | CIFAR-100 | 224x224 | 0.9285 |
Shorter training
- In the experiments below, we used a resolution of 224x224.
- tensorboard
upstream | model | dataset | total_steps / warmup_steps | acc(official) | acc(this repo) |
---|---|---|---|---|---|
imagenet21k | ViT-B_16 | CIFAR-10 | 500/100 | 0.9859 | 0.9859 |
imagenet21k | ViT-B_16 | CIFAR-10 | 1000/100 | 0.9886 | 0.9878 |
imagenet21k | ViT-B_16 | CIFAR-100 | 500/100 | 0.8917 | 0.9072 |
imagenet21k | ViT-B_16 | CIFAR-100 | 1000/100 | 0.9115 | 0.9216 |
Visualization
The ViT consists of a standard Transformer encoder, and the encoder is built from self-attention and MLP modules. The attention map for an input image can be visualized from the self-attention scores.
Visualization code can be found at visualize_attention_map.
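As an illustration of the idea, here is a minimal, hypothetical attention-rollout sketch (not the notebook's exact code). It assumes the per-layer attention weights are available as tensors of shape `(batch, heads, tokens, tokens)`, with the classification token at index 0:

```python
import torch

def attention_rollout(attn_weights):
    """Fuse per-layer attention into a single map for the classification token.

    attn_weights: list of tensors, one per encoder layer, each shaped
    (batch, heads, tokens, tokens). Generic sketch, not the repo's notebook code.
    """
    result = None
    for attn in attn_weights:
        attn = attn.mean(dim=1)                               # average over heads
        eye = torch.eye(attn.size(-1), device=attn.device)    # account for residual connections
        attn = (attn + eye) / (attn + eye).sum(dim=-1, keepdim=True)
        result = attn if result is None else attn @ result    # compose layers
    # Attention from the classification token (index 0) to every image patch token
    return result[:, 0, 1:]

# Usage: cls_to_patches = attention_rollout(attn_list)  # then reshape to (H/ps, W/ps)
```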
Reference
Citations
@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}