big_transfer
Official repository for the "Big Transfer (BiT): General Visual Representation Learning" paper.
Top Related Projects
Official DeiT repository
VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image
Quick Overview
Big Transfer (BiT) is a research project by Google that focuses on transfer learning for computer vision tasks. It introduces pre-trained models that can be fine-tuned on small datasets to achieve state-of-the-art performance across a wide range of visual tasks, including image classification, object detection, and semantic segmentation.
Pros
- Achieves excellent performance on various visual tasks with minimal fine-tuning
- Provides pre-trained models of different sizes to suit various computational requirements
- Demonstrates strong transfer learning capabilities, even on small datasets
- Includes detailed documentation and example code for easy implementation
Cons
- Requires significant computational resources for training and fine-tuning large models
- May not be suitable for highly specialized or domain-specific tasks without extensive modifications
- Limited to visual tasks and may not generalize well to other types of data
- Dependency on specific versions of TensorFlow and other libraries may cause compatibility issues
Code Examples
- Loading a pre-trained BiT model:
import numpy as np
import bit_pytorch.models as models
# BiT-M-R50x1 with a new, zero-initialized 10-class head
model = models.KNOWN_MODELS['BiT-M-R50x1'](head_size=10, zero_head=True)
model.load_from(np.load('BiT-M-R50x1.npz'))
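- Fine-tuning the model on a custom dataset:
import torch
import bit_hyperrule
model.train()
optim = torch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
for step, (x, y) in enumerate(dataloader):
    # BiT-HyperRule learning rate for this step; None means the schedule is finished
    lr = bit_hyperrule.get_lr(step, len(dataloader.dataset))
    if lr is None:
        break
    for group in optim.param_groups:
        group['lr'] = lr
    logits = model(x)
    loss = loss_fn(logits, y)
    loss.backward()
    optim.step()
    optim.zero_grad()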
- Performing inference with a fine-tuned model:
import torch
model.eval()
with torch.no_grad():
    # input_image: a float tensor batch of shape (N, 3, H, W)
    logits = model(input_image)
predicted_class = logits.argmax(dim=1)
Getting Started
To get started with Big Transfer:
- Clone the repository:
git clone https://github.com/google-research/big_transfer.git
cd big_transfer
- Install dependencies (shown for PyTorch; use bit_tf2 or bit_jax for the other frameworks):
pip install -r bit_pytorch/requirements.txt
- Download pre-trained models:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz
- Use the provided examples to fine-tune on your dataset or perform inference:
import numpy as np
import bit_pytorch.models as models
# num_classes: the number of classes in your dataset
model = models.KNOWN_MODELS['BiT-M-R50x1'](head_size=num_classes, zero_head=True)
model.load_from(np.load('BiT-M-R50x1.npz'))
# Continue with fine-tuning or inference as shown in the code examples
Competitor Comparisons
Official DeiT repository
Pros of DeiT
- Focuses on Vision Transformers, offering a more specialized approach for image classification tasks
- Provides data-efficient training methods, reducing the need for large datasets
- Includes distillation techniques to improve model performance and efficiency
Cons of DeiT
- Limited to Vision Transformer architectures, whereas Big Transfer provides ResNet-v2 CNNs of several depths and widths
- May require more computational resources for training compared to traditional CNN approaches
- Less versatile for general transfer learning tasks across different domains
Code Comparison
DeiT:
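import torch.nn as nn
from models import deit_small_patch16_224  # models.py from the DeiT repository
model = deit_small_patch16_224(pretrained=True)
model.head = nn.Linear(model.head.in_features, num_classes)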
Big Transfer:
import numpy as np
import bit_pytorch.models as models
model = models.KNOWN_MODELS['BiT-M-R50x1'](head_size=num_classes, zero_head=True)
model.load_from(np.load('BiT-M-R50x1.npz'))
Both repositories provide pre-trained models and methods for fine-tuning on custom datasets. DeiT focuses on Vision Transformers and includes distillation techniques, while Big Transfer offers ResNet-v2 models of several sizes, pre-trained on ILSVRC-2012 or ImageNet-21k, for transfer learning across various tasks.
VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
Pros of VISSL
- More comprehensive and flexible framework for self-supervised learning
- Supports a wider range of SSL algorithms and techniques
- Offers extensive documentation and tutorials for easier adoption
Cons of VISSL
- Steeper learning curve due to its complexity and extensive features
- May be overkill for simpler self-supervised learning tasks
- Requires more computational resources for training and experimentation
Code Comparison
VISSL example:
from vissl.engines.train import train_main
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
cfg = compose_hydra_configuration(["config=test/integration_test/quick_simclr"])
cfg = convert_to_attrdict(cfg)
train_main(cfg)
Big Transfer example:
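import numpy as np
import bit_pytorch.models as models
model = models.KNOWN_MODELS['BiT-M-R50x1'](head_size=10, zero_head=True)
model.load_from(np.load('BiT-M-R50x1.npz'))
# Fine-tuning is driven by the training script rather than a library call:
# python3 -m bit_pytorch.train --name cifar10_run --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10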
The VISSL code demonstrates its configuration-based approach, while Big Transfer builds the model directly in Python and runs fine-tuning through a single command-line script.
The largest collection of PyTorch image encoders / backbones. Including train, eval, inference, export scripts, and pretrained weights -- ResNet, ResNeXT, EfficientNet, NFNet, Vision Transformer (ViT), MobileNetV4, MobileNet-V3 & V2, RegNet, DPN, CSPNet, Swin Transformer, MaxViT, CoAtNet, ConvNeXt, and more
Pros of pytorch-image-models
- Extensive collection of pre-trained models and architectures
- Active community and frequent updates
- Seamless integration with PyTorch ecosystem
Cons of pytorch-image-models
- Less focus on transfer learning compared to big_transfer
- May require more manual configuration for specific tasks
Code Comparison
big_transfer:
import numpy as np
import bit_pytorch.models as models
model = models.KNOWN_MODELS['BiT-M-R50x1'](head_size=10, zero_head=True)
model.load_from(np.load('BiT-M-R50x1.npz'))
# Training: python3 -m bit_pytorch.train --name cifar10_run --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
pytorch-image-models:
import torch
import timm
model = timm.create_model('resnet50', pretrained=True, num_classes=10)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()
Summary
Both repositories offer valuable tools for computer vision tasks. big_transfer focuses on transfer learning with a specific set of models, while pytorch-image-models provides a broader range of pre-trained models and architectures. The choice between them depends on the specific requirements of your project and your familiarity with the respective ecosystems.
This is an official implementation for "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows".
Pros of Swin-Transformer
- Hierarchical structure allows for better handling of varying scales in images
- Shifted window approach reduces computational complexity while maintaining performance
- More versatile for various vision tasks beyond image classification
Cons of Swin-Transformer
- More complex architecture, potentially harder to implement and fine-tune
- May require more computational resources for training and inference
- Less focus on transfer learning across diverse tasks compared to BiT
Code Comparison
Swin-Transformer:
import torch.nn as nn
class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        # ... (implementation details)
BiT:
import tensorflow as tf
class ResNetV2(tf.keras.Model):
    def __init__(self, num_classes, width_factor, depth):
        super().__init__(name='resnet')
        self.width = int(64 * width_factor)
        self.num_blocks = _NUM_BLOCKS[depth]
        # ... (implementation details)
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Pros of transformers
- Extensive library with support for numerous pre-trained models and architectures
- Active community and frequent updates
- Comprehensive documentation and examples
Cons of transformers
- Larger codebase, potentially more complex for beginners
- May have higher computational requirements for some models
Code Comparison
transformers:
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
big_transfer:
import numpy as np
import bit_pytorch.models as models
model = models.KNOWN_MODELS['BiT-M-R50x1'](head_size=10, zero_head=True)
model.load_from(np.load('BiT-M-R50x1.npz'))
Summary
transformers offers a wide range of pre-trained models and architectures with excellent documentation and community support. However, it may be more complex for beginners and require more computational resources. big_transfer focuses on transfer learning for computer vision tasks, providing a simpler interface but with a narrower scope. The code examples demonstrate the ease of use for both libraries, with transformers showing more flexibility in model selection.
CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image
Pros of CLIP
- Multimodal learning: CLIP can understand both images and text, enabling more versatile applications
- Zero-shot learning capabilities: Can perform tasks without specific training on those tasks
- Larger dataset: Trained on 400 million image-text pairs, potentially offering broader knowledge
Cons of CLIP
- More complex architecture: Requires both image and text encoders, increasing computational demands
- Less specialized for transfer learning: BiT is specifically designed for efficient transfer learning
Code Comparison
CLIP usage example:
import torch
from PIL import Image
import clip
model, preprocess = clip.load("ViT-B/32", device="cuda")
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
text = clip.tokenize(["a dog", "a cat"]).to("cuda")
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
BiT usage example:
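import numpy as np
import torch
from torchvision import transforms
import bit_pytorch.models as models
model = models.KNOWN_MODELS['BiT-M-R50x1'](head_size=10, zero_head=True)
model.load_from(np.load('BiT-M-R50x1.npz'))  # .npz weights are read with NumPy, not torch.load
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale pixels to [-1, 1]
])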
README
Big Transfer (BiT): General Visual Representation Learning
by Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby
Update 18/06/2021: We release new high-performing BiT-R50x1 models, which were distilled from BiT-M-R152x2; see the Distilled models section below. More details are in our paper "Knowledge distillation: A good teacher is patient and consistent".
Update 08/02/2021: We also release ALL BiT-M models fine-tuned on ALL 19 VTAB-1k datasets, see below.
Introduction
In this repository we release multiple models from the Big Transfer (BiT): General Visual Representation Learning paper that were pre-trained on the ILSVRC-2012 and ImageNet-21k datasets. We provide the code to fine-tune the released models in the major deep learning frameworks: TensorFlow 2, PyTorch and Jax/Flax.
We hope that the computer vision community will benefit by employing more powerful ImageNet-21k pretrained models as opposed to conventional models pre-trained on the ILSVRC-2012 dataset.
We also provide colabs for a more exploratory interactive use: a TensorFlow 2 colab, a PyTorch colab, and a Jax colab.
Installation
Make sure you have Python>=3.6 installed on your machine.
To set up TensorFlow 2, PyTorch or Jax, follow the installation instructions provided by the corresponding project.
In addition, install the Python dependencies by running the command below, selecting tf2, pytorch or jax:
pip install -r bit_{tf2|pytorch|jax}/requirements.txt
How to fine-tune BiT
First, download the BiT model. We provide models pre-trained on ILSVRC-2012 (BiT-S) or ImageNet-21k (BiT-M) for 5 different architectures: ResNet-50x1, ResNet-101x1, ResNet-50x3, ResNet-101x3, and ResNet-152x4.
For example, if you would like to download the ResNet-50x1 pre-trained on ImageNet-21k, run the following command:
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1.{npz|h5}
Other models can be downloaded accordingly by plugging the name of the model (BiT-S or BiT-M) and architecture in the above command.
Note that we provide models in two formats: npz (for PyTorch and Jax) and h5 (for TF2). By default we expect that model weights are stored in the root folder of this repository.
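The download naming scheme can be captured in a small helper. This is a hypothetical convenience function (the name `bit_model_url` and the validation sets are illustrative, not part of the repository); it only assembles the URL pattern used in the wget command, choosing npz for PyTorch and Jax and h5 for TF2.

```python
# Hypothetical helper: assembles BiT download URLs from the pattern
# https://storage.googleapis.com/bit_models/<variant>-<arch>.<ext>
BASE_URL = "https://storage.googleapis.com/bit_models"

VARIANTS = {"BiT-S", "BiT-M"}  # BiT-S: ILSVRC-2012, BiT-M: ImageNet-21k
ARCHS = {"R50x1", "R101x1", "R50x3", "R101x3", "R152x4"}
FORMATS = {"pytorch": "npz", "jax": "npz", "tf2": "h5"}


def bit_model_url(variant: str, arch: str, framework: str) -> str:
    """Return the download URL of one pre-trained BiT checkpoint."""
    if variant not in VARIANTS:
        raise ValueError(f"unknown variant {variant!r}")
    if arch not in ARCHS:
        raise ValueError(f"unknown architecture {arch!r}")
    if framework not in FORMATS:
        raise ValueError(f"unknown framework {framework!r}")
    ext = FORMATS[framework]  # npz for PyTorch/Jax, h5 for TF2
    return f"{BASE_URL}/{variant}-{arch}.{ext}"


if __name__ == "__main__":
    print(bit_model_url("BiT-M", "R50x1", "pytorch"))
```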
Then, you can run fine-tuning of the downloaded model on your dataset of interest in any of the three frameworks. All frameworks share the same command-line interface:
python3 -m bit_{pytorch|jax|tf2}.train --name cifar10_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset cifar10
Currently, all frameworks will automatically download the CIFAR-10 and CIFAR-100 datasets. Other public or custom datasets can be easily integrated: in TF2 and JAX we rely on the extensible TensorFlow Datasets library. In PyTorch, we use torchvision's data input pipeline.
Note that our code uses all available GPUs for fine-tuning.
We also support training in the low-data regime: the --examples_per_class <K> option will randomly draw K samples per class for training.
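A minimal sketch of what per-class subsampling for `--examples_per_class K` involves, written in plain Python. It is an illustration, not the repository's implementation: the function name, the fixed seed, and the choice to raise an error when a class has fewer than K examples are all assumptions.

```python
import random
from collections import defaultdict


def sample_k_per_class(labels, k, seed=0):
    """Return sorted indices of k randomly chosen examples for every class."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    chosen = []
    for label, indices in by_class.items():
        if len(indices) < k:
            raise ValueError(f"class {label!r} has only {len(indices)} examples")
        chosen.extend(rng.sample(indices, k))  # sample without replacement
    return sorted(chosen)
```

The returned indices can then be used to select a subset of any indexable dataset.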
To see a detailed list of all available flags, run python3 -m bit_{pytorch|jax|tf2}.train --help.
BiT-M models fine-tuned on ILSVRC-2012
For convenience, we provide BiT-M models that were already fine-tuned on the ILSVRC-2012 dataset. The models can be downloaded by adding the -ILSVRC2012 postfix, e.g.
wget https://storage.googleapis.com/bit_models/BiT-M-R50x1-ILSVRC2012.npz
Available architectures
We release all architectures mentioned in the paper, so that you can trade off accuracy against speed: R50x1, R101x1, R50x3, R101x3, R152x4.
In the above path to the model file, simply replace R50x1 with your architecture of choice.
We further investigated more architectures after the paper's publication and found R152x2 to have a nice trade-off between speed and accuracy, hence we also include this in the release and provide a few numbers below.
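The architecture names encode depth and width: R152x4 is a 152-layer ResNet whose channels are four times wider than the standard one. The hypothetical parser below makes that explicit; it assumes the usual ResNet-50/101/152 stage layouts and a 64-channel stem scaled by the width factor, so treat it as a sketch rather than the repository's model code.

```python
import re

# Assumed: BiT's ResNet-v2 backbones use the standard ResNet stage layouts.
BLOCK_UNITS = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}


def parse_arch(name):
    """Split e.g. 'R152x4' into (depth, width_factor, block_units, stem_width)."""
    match = re.fullmatch(r"R(\d+)x(\d+)", name)
    if match is None:
        raise ValueError(f"not a BiT architecture name: {name!r}")
    depth, width = int(match.group(1)), int(match.group(2))
    if depth not in BLOCK_UNITS:
        raise ValueError(f"unsupported depth {depth}")
    # A plain ResNet stem has 64 channels; 'xN' widens every layer N times.
    return depth, width, BLOCK_UNITS[depth], 64 * width
```

As a consistency check, `3 * sum(block_units) + 2` reproduces each depth: every bottleneck block contributes three convolutions, plus the stem convolution and the final classifier layer.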
BiT-M models fine-tuned on the 19 VTAB-1k tasks
We also release the fine-tuned models for each of the 19 tasks included in the VTAB-1k benchmark. We ran each model three times and release each of these runs. This means we release a total of 5x19x3=285 models, and hope these can be useful in further analysis of transfer learning.
The files can be downloaded via the following pattern:
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-{R50x1,R101x1,R50x3,R101x3,R152x4}-run{0,1,2}-{caltech101,diabetic_retinopathy,dtd,oxford_flowers102,oxford_iiit_pet,resisc45,sun397,cifar100,eurosat,patch_camelyon,smallnorb-elevation,svhn,dsprites-orientation,smallnorb-azimuth,clevr-distance,clevr-count,dmlab,kitti-distance,dsprites-xpos}.npz
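To see where the 285 files come from, the brace pattern can be expanded with `itertools.product`. This sketch only builds the URL strings from the pattern (the constant and function names are illustrative) and does not download anything.

```python
from itertools import product

BASE = "https://storage.googleapis.com/bit_models/vtab"
ARCHS = ["R50x1", "R101x1", "R50x3", "R101x3", "R152x4"]
RUNS = [0, 1, 2]
TASKS = ["caltech101", "diabetic_retinopathy", "dtd", "oxford_flowers102",
         "oxford_iiit_pet", "resisc45", "sun397", "cifar100", "eurosat",
         "patch_camelyon", "smallnorb-elevation", "svhn",
         "dsprites-orientation", "smallnorb-azimuth", "clevr-distance",
         "clevr-count", "dmlab", "kitti-distance", "dsprites-xpos"]


def vtab_urls():
    """Expand the brace pattern into one URL per (architecture, run, task)."""
    return [f"{BASE}/BiT-M-{arch}-run{run}-{task}.npz"
            for arch, run, task in product(ARCHS, RUNS, TASKS)]
```

Five architectures times three runs times 19 tasks gives the 285 checkpoints.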
We did not convert these models to TF2 (hence there is no corresponding .h5 file); however, we also uploaded TFHub models, which can be used in TF1 and TF2. An example sequence of commands for downloading one such model is:
mkdir BiT-M-R50x1-run0-caltech101.tfhub && cd BiT-M-R50x1-run0-caltech101.tfhub
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/{saved_model.pb,tfhub_module.pb}
mkdir variables && cd variables
wget https://storage.googleapis.com/bit_models/vtab/BiT-M-R50x1-run0-caltech101.tfhub/variables/variables.{data@1,index}
Hyper-parameters
For reproducibility, our training script uses the hyper-parameters (BiT-HyperRule) that were used in the original paper. Note, however, that BiT models were trained and fine-tuned using Cloud TPU hardware, so on a typical GPU setup our default hyper-parameters could require too much memory or result in very slow progress. Moreover, BiT-HyperRule is designed to generalize across many datasets, so it is typically possible to devise more efficient application-specific hyper-parameters. Thus, we encourage users to try more lightweight settings, as they require far fewer resources and often result in similar accuracy.
For example, we tested our code using an 8xV100 GPU machine on the CIFAR-10 and CIFAR-100 datasets, while reducing the batch size from 512 to 128 and the learning rate from 0.003 to 0.001. This setup resulted in nearly identical performance (see Expected results below) compared to BiT-HyperRule, despite being less computationally demanding.
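A learning-rate schedule of the kind used for fine-tuning can be sketched as a linear warm-up followed by step decays by a factor of 10. This is a hedged illustration of that shape only: the base learning rate of 0.003 comes from the text, but the function names, the dataset-size thresholds and the step boundaries in `boundaries_for` are assumptions for illustration, so consult bit_hyperrule.py for the values the repository actually uses. Passing `base_lr=0.001` corresponds to the lighter CIFAR setting.

```python
def lr_schedule(step, boundaries, base_lr=0.003):
    """Linear warm-up, then divide the learning rate by 10 at each decay point.

    boundaries = [warmup_end, decay_1, ..., decay_n, total_steps]
    Returns None once step reaches total_steps (training is over).
    """
    warmup_end, *decays, total = boundaries
    if step < warmup_end:
        return base_lr * step / warmup_end  # linear warm-up from 0
    if step >= total:
        return None
    lr = base_lr
    for boundary in decays:
        if step >= boundary:
            lr /= 10  # staircase decay
    return lr


def boundaries_for(dataset_size):
    """Illustrative (assumed) schedule lengths that grow with dataset size."""
    if dataset_size < 20_000:
        return [100, 200, 300, 400, 500]
    if dataset_size < 500_000:
        return [500, 3000, 6000, 9000, 10_000]
    return [500, 6000, 12_000, 18_000, 20_000]
```

Returning None lets a training loop stop as soon as the schedule is exhausted.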
Below, we provide more suggestions on how to optimize our paper's setup.
Tips for optimizing memory or speed
The default BiT-HyperRule was developed on Cloud TPUs and is quite memory-hungry. This is mainly due to the large batch size (512) and image resolution (up to 480x480). Here are some tips if you are running out of memory:
- In bit_hyperrule.py we specify the input resolution. Reducing it saves a lot of memory and compute, at the expense of accuracy.
- The batch size can be reduced to lower memory consumption. However, the learning rate and schedule (steps) then also need to be tuned to maintain the desired accuracy.
- The PyTorch codebase supports a batch-splitting technique ("micro-batching") via the --batch_split option. For example, running the fine-tuning with --batch_split 8 reduces the memory requirement by a factor of 8.
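The idea behind micro-batching (gradient accumulation) can be shown in dependency-free Python: the full batch is processed in N equal micro-batches, each micro-batch gradient is scaled by 1/N, and their sum equals the full-batch gradient, so only one micro-batch of activations has to be in memory at a time. This is an illustration of the technique on a toy scalar model with hypothetical function names, not the repository's PyTorch implementation.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, batch_split):
    """Sum per-micro-batch gradients, each scaled by 1 / batch_split."""
    if len(xs) % batch_split:
        raise ValueError("batch size must be divisible by batch_split")
    size = len(xs) // batch_split
    total = 0.0
    for i in range(batch_split):
        chunk = slice(i * size, (i + 1) * size)
        # Equivalent to dividing each micro-batch loss by batch_split
        # before calling backward() and letting gradients accumulate.
        total += mse_grad(w, xs[chunk], ys[chunk]) / batch_split
    return total
```

The equivalence is exact here because each example's loss does not depend on the other examples in the batch; BiT's use of Group Normalization instead of Batch Normalization keeps its per-example computation batch-independent in the same way.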
Expected results
We verified that when using the BiT-HyperRule, the code in this repository reproduces the paper's results.
CIFAR results (few-shot and full)
For these common benchmarks, the aforementioned changes to the BiT-HyperRule (--batch 128 --base_lr 0.001) lead to the following, very similar results.
The table shows the min / median / max result of at least five runs.
NOTE: This is not a comparison of frameworks, just evidence that all code-bases can be trusted to reproduce results.
BiT-M-R101x3
Dataset | Ex/cls | TF2 | Jax | PyTorch |
---|---|---|---|---|
CIFAR10 | 1 | 52.5 / 55.8 / 60.2 | 48.7 / 53.9 / 65.0 | 56.4 / 56.7 / 73.1 |
CIFAR10 | 5 | 85.3 / 87.2 / 89.1 | 80.2 / 85.8 / 88.6 | 84.8 / 85.8 / 89.6 |
CIFAR10 | full | 98.5 | 98.4 | 98.5 / 98.6 / 98.6 |
CIFAR100 | 1 | 34.8 / 35.7 / 37.9 | 32.1 / 35.0 / 37.1 | 31.6 / 33.8 / 36.9 |
CIFAR100 | 5 | 68.8 / 70.4 / 71.4 | 68.6 / 70.8 / 71.6 | 70.6 / 71.6 / 71.7 |
CIFAR100 | full | 90.8 | 91.2 | 91.1 / 91.2 / 91.4 |
BiT-M-R152x2
Dataset | Ex/cls | Jax | PyTorch |
---|---|---|---|
CIFAR10 | 1 | 44.0 / 56.7 / 65.0 | 50.9 / 55.5 / 59.5 |
CIFAR10 | 5 | 85.3 / 87.0 / 88.2 | 85.3 / 85.8 / 88.6 |
CIFAR10 | full | 98.5 | 98.5 / 98.5 / 98.6 |
CIFAR100 | 1 | 36.4 / 37.2 / 38.9 | 34.3 / 36.8 / 39.0 |
CIFAR100 | 5 | 69.3 / 70.5 / 72.0 | 70.3 / 72.0 / 72.3 |
CIFAR100 | full | 91.2 | 91.2 / 91.3 / 91.4 |
(TF2 models not yet available.)
BiT-M-R50x1
Dataset | Ex/cls | TF2 | Jax | PyTorch |
---|---|---|---|---|
CIFAR10 | 1 | 49.9 / 54.4 / 60.2 | 48.4 / 54.1 / 66.1 | 45.8 / 57.9 / 65.7 |
CIFAR10 | 5 | 80.8 / 83.3 / 85.5 | 76.7 / 82.4 / 85.4 | 80.3 / 82.3 / 84.9 |
CIFAR10 | full | 97.2 | 97.3 | 97.4 |
CIFAR100 | 1 | 35.3 / 37.1 / 38.2 | 32.0 / 35.2 / 37.8 | 34.6 / 35.2 / 38.6 |
CIFAR100 | 5 | 63.8 / 65.0 / 66.5 | 63.4 / 64.8 / 66.5 | 64.7 / 65.5 / 66.0 |
CIFAR100 | full | 86.5 | 86.4 | 86.6 |
ImageNet results
These results were obtained using BiT-HyperRule.
However, because this results in a large batch size and high resolution, memory can be an issue.
The PyTorch code supports batch-splitting, so we can still run things there without resorting to Cloud TPUs by adding the --batch_split N flag, where N is a power of two.
For instance, the following command produces a validation accuracy of 80.68% on a machine with 8 V100 GPUs:
python3 -m bit_pytorch.train --name ilsvrc_`date +%F_%H%M%S` --model BiT-M-R50x1 --logdir /tmp/bit_logs --dataset imagenet2012 --batch_split 4
Further increase to --batch_split 8 when running with 4 V100 GPUs, etc.
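The pairing of GPU count and --batch_split is easier to see with the arithmetic written out. Assuming the HyperRule batch of 512 is first split into --batch_split micro-batches and each micro-batch is then divided evenly across the GPUs (an assumption about data-parallel execution, not something the text states), both configurations give every GPU 16 images per forward/backward pass. The function name is hypothetical.

```python
def per_gpu_micro_batch(global_batch, num_gpus, batch_split):
    """Images each GPU processes per forward/backward pass (assumed model)."""
    if batch_split < 1 or batch_split & (batch_split - 1):
        raise ValueError("batch_split must be a power of two")
    if global_batch % batch_split:
        raise ValueError("batch size must be divisible by batch_split")
    micro_batch = global_batch // batch_split
    if micro_batch % num_gpus:
        raise ValueError("micro-batch must be divisible by the number of GPUs")
    return micro_batch // num_gpus


if __name__ == "__main__":
    print(per_gpu_micro_batch(512, 8, 4))  # 8 V100s, --batch_split 4
    print(per_gpu_micro_batch(512, 4, 8))  # 4 V100s, --batch_split 8
```

Under this model, halving the number of GPUs while doubling --batch_split keeps the per-GPU memory load constant.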
Full results achieved that way in some test runs were:
Ex/cls | R50x1 | R152x2 | R101x3 |
---|---|---|---|
1 | 18.36 | 24.5 | 25.55 |
5 | 50.64 | 64.5 | 64.18 |
full | 80.68 | 85.15 | WIP |
VTAB-1k results
These are re-runs and not the exact paper models. The expected VTAB scores for two of the models are:
Model | Full | Natural | Structured | Specialized |
---|---|---|---|---|
BiT-M-R152x4 | 73.51 | 80.77 | 61.08 | 85.67 |
BiT-M-R101x3 | 72.65 | 80.29 | 59.40 | 85.75 |
Out of context dataset
In Appendix G of our paper, we investigate whether BiT improves out-of-context robustness. To do this, we created a dataset comprising foreground objects corresponding to 21 ILSVRC-2012 classes pasted onto 41 miscellaneous backgrounds.
To download the dataset, run
wget https://storage.googleapis.com/bit-out-of-context-dataset/bit_out_of_context_dataset.zip
Images from each of the 21 classes are kept in a directory with the name of the class.
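Because each class sits in its own directory, the unzipped dataset can be indexed with a few lines of standard-library Python; torchvision's `ImageFolder` also accepts this root/class/image layout. The helper name and the accepted file extensions below are assumptions, since the archive's exact file types are not listed here.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}  # assumed image file types


def index_class_folders(root):
    """Map each class directory under root to a sorted list of its image paths."""
    root = Path(root)
    index = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        images = sorted(p for p in class_dir.iterdir()
                        if p.suffix.lower() in IMAGE_SUFFIXES)
        index[class_dir.name] = images  # directory name is the class label
    return index
```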
Distilled models
We release top-performing compressed BiT models from our knowledge distillation paper "Knowledge distillation: A good teacher is patient and consistent". In particular, we distill the BiT-M-R152x2 model (which was pre-trained on ImageNet-21k) into BiT-R50x1 models. As a result, we obtain compact models with very competitive performance.
Model | Download link | Resolution | ImageNet top-1 acc. (paper) |
---|---|---|---|
BiT-R50x1 | link | 224 | 82.8 |
BiT-R50x1 | link | 160 | 80.5 |
For reproducibility, we also release weights of two BiT-M-R152x2 teacher models: pretrained at resolution 224 and resolution 384. See the paper for details on how these teachers were used.
Distillation code
We have no concrete plans for publishing the distillation code, as the recipe is simple and we imagine most people would integrate it in their existing training code. However, Sayak Paul has independently re-implemented the distillation setup in TensorFlow and nearly reproduced our results in several settings.
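As a hedged illustration only (the authors' distillation code is not published), a generic knowledge-distillation objective measures the KL divergence between the teacher's and the student's class distributions. The paper's title refers to additionally being "consistent" (teacher and student see the same augmented view) and "patient" (training for a long schedule). The functions and the `temperature` parameter below are illustrative assumptions, not the authors' recipe.

```python
import math


def softmax(logits, temperature=1.0):
    """Class probabilities from raw logits, softened by temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(teacher_logits, student_logits, temperature=1.0):
    """KL(teacher || student) between the two softened class distributions."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)
```

In a consistent setup, both logit vectors would come from the same augmented image for each training step.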