Top Related Projects
PyTorch implementation of MoCo v3 https://arxiv.org/abs/2104.02057
SimCLRv2 - Big Self-Supervised Models are Strong Semi-Supervised Learners
CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image
VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
Official repository for the "Big Transfer (BiT): General Visual Representation Learning" paper.
Quick Overview
SwAV (Swapping Assignments between Views) is a self-supervised method for learning visual representations. It combines multi-crop augmentation with an online clustering objective to learn features from unlabeled images. Without using any labels during pre-training, SwAV reaches 75.3% ImageNet top-1 accuracy with a ResNet-50 under linear evaluation (see the Model Zoo below) and transfers well to downstream tasks.
Pros
- Efficient and scalable, suitable for large-scale datasets
- Achieves competitive results on various downstream tasks
- Does not require large batches or memory banks
- Can be trained on a single 8-GPU machine
Cons
- Requires careful hyperparameter tuning for optimal performance
- May be sensitive to the choice of data augmentation techniques
- Performance can vary depending on the specific downstream task
- Requires significant computational resources for training on large datasets
Code Examples
- Loading a pre-trained SwAV model:
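import torch
from torchvision import models
# Build a randomly initialized ResNet-50, then load the SwAV weights into it
model = models.resnet50()
state_dict = torch.hub.load_state_dict_from_url(
    'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar', map_location='cpu')
# Keys were saved from a DistributedDataParallel model: strip the "module." prefix
state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
# strict=False: the projection head and prototypes are not part of torchvision's ResNet-50,
# and fc is not in the checkpoint (it keeps its random initialization)
model.load_state_dict(state_dict, strict=False)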
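- Extracting features using SwAV:
import torch
from torchvision import transforms
from PIL import Image
# Prepare image transformation (standard ImageNet preprocessing)
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Load and transform an image (convert to RGB in case it is grayscale or has an alpha channel)
image = Image.open('path/to/image.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0)
# Drop the fc layer (not trained by SwAV) to get the 2048-d pooled features
backbone = torch.nn.Sequential(*list(model.children())[:-1])
backbone.eval()
# Extract features
with torch.no_grad():
    features = torch.flatten(backbone(input_tensor), 1)  # shape: (1, 2048)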
- Fine-tuning SwAV for a downstream task (here: training a linear classifier on frozen features):
import torch.nn as nn
# Replace the last layer with a new classifier
num_classes = 10  # Number of classes in your downstream task
model.fc = nn.Linear(model.fc.in_features, num_classes)
# Freeze all layers except the last one (skip this step for full fine-tuning)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True
# Train the model on your downstream task
# ... (Add your training loop here)
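The placeholder above leaves the training loop open. The sketch below is one way to fill it in for the frozen-backbone case, under stated assumptions: `train_linear_probe` is a hypothetical helper, `loader` is any iterable of `(images, targets)` batches (for example a `torch.utils.data.DataLoader`), and SGD with momentum 0.9 is an illustrative choice rather than the settings of the repository's eval_linear.py. Only parameters that still have `requires_grad=True` are optimized, and the model stays in `eval()` mode so the frozen BatchNorm statistics are not updated.

```python
import torch
import torch.nn as nn


def train_linear_probe(model, loader, epochs=1, lr=0.1, device="cpu"):
    """Hypothetical helper: train only the parameters that still require grad (e.g. model.fc)."""
    model.to(device)
    # eval() keeps the frozen BatchNorm running statistics fixed (use train() for full fine-tuning)
    model.eval()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9)  # assumed optimizer settings
    criterion = nn.CrossEntropyLoss()
    losses = []
    for _ in range(epochs):
        for images, targets in loader:
            images, targets = images.to(device), targets.to(device)
            loss = criterion(model(images), targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
    return losses
```

With the model prepared above, a call might look like `losses = train_linear_probe(model, train_loader, epochs=100, device="cuda")`, where `train_loader` is your own (hypothetical) data loader.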
Getting Started
To get started with SwAV, follow these steps:
- Clone the repository:
git clone https://github.com/facebookresearch/swav.git
cd swav
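- Install the dependencies listed under Requirements below (PyTorch, torchvision, Apex with CUDA extension, scipy, pandas, numpy); Apex is installed separately, the rest can be installed with:
pip install torch torchvision scipy pandas numpy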
- Download a pre-trained model or train your own:
import torch
# Download a pre-trained model checkpoint
checkpoint = torch.hub.load_state_dict_from_url('https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar', map_location='cpu')
# Or train your own model (requires significant computational resources): use the single-node
# or multi-node commands under "Running SwAV unsupervised training" below, e.g.
# sbatch ./scripts/swav_800ep_pretrain.sh
# (--batch_size is per GPU, so the 4096-image batch comes from 64 GPUs, not from --batch_size 4096)
- Use the model for feature extraction or fine-tuning as shown in the code examples above.
Competitor Comparisons
PyTorch implementation of MoCo v3 https://arxiv.org/abs/2104.02057
Pros of MoCo-v3
- Improved performance on downstream tasks compared to SwAV
- More flexible architecture allowing for various backbone networks
- Better scalability to larger datasets and model sizes
Cons of MoCo-v3
- Potentially higher computational requirements than SwAV
- May require more careful hyperparameter tuning
- Slightly more complex implementation due to additional components
Code Comparison
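SwAV:
import src.resnet50 as resnet_models  # model definitions from the SwAV repository
# ResNet-50 trunk + projection head + prototypes; no momentum encoder and no key queue
model = resnet_models.resnet50(normalize=True, hidden_mlp=2048, output_dim=128, nmb_prototypes=3000)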
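MoCo-v3:
import torch.nn as nn
class MoCo(nn.Module):
    def __init__(self, base_encoder, dim=256, mlp_dim=4096, T=1.0):
        super(MoCo, self).__init__()
        self.T = T  # softmax temperature of the contrastive loss
        # Base and momentum encoders; no key queue, the momentum m is passed to forward()
        self.base_encoder = base_encoder(num_classes=mlp_dim)
        self.momentum_encoder = base_encoder(num_classes=mlp_dim)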
Both SwAV and MoCo-v3 are self-supervised methods for visual representation learning, but they build their training signal differently. MoCo-v3 contrasts two views of each image using a base encoder and a momentum encoder, drawing negatives from the batch rather than from a memory queue, so it relies on large batches. SwAV instead compares cluster assignments computed against a set of learnable prototypes and needs neither a momentum network nor a memory bank. MoCo-v3 may offer better performance and backbone flexibility (its paper focuses on Vision Transformers), at the cost of potentially higher computational requirements and implementation complexity.
SimCLRv2 - Big Self-Supervised Models are Strong Semi-Supervised Learners
Pros of SimCLR
- Simpler architecture and training process
- Better performance on smaller datasets
- More extensive documentation and community support
Cons of SimCLR
- Higher computational requirements for large batch sizes
- Less effective at capturing global context in images
- May struggle with fine-grained distinctions between similar classes
Code Comparison
SimCLR:
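import tensorflow as tf
LARGE_NUM = 1e9  # large constant used to mask out self-similarities
def contrastive_loss(hidden1, hidden2, temperature=0.5):
    # hidden1, hidden2: projections of the two views, shape (batch_size, dim); single-device version
    hidden1 = tf.math.l2_normalize(hidden1, -1)
    hidden2 = tf.math.l2_normalize(hidden2, -1)
    batch_size = tf.shape(hidden1)[0]
    # The positive for example i sits in column i of [logits_ab | logits_aa]
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)
    logits_aa = tf.matmul(hidden1, hidden1, transpose_b=True) / temperature
    logits_aa = logits_aa - masks * LARGE_NUM
    logits_bb = tf.matmul(hidden2, hidden2, transpose_b=True) / temperature
    logits_bb = logits_bb - masks * LARGE_NUM
    logits_ab = tf.matmul(hidden1, hidden2, transpose_b=True) / temperature
    logits_ba = tf.matmul(hidden2, hidden1, transpose_b=True) / temperature
    loss_a = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([logits_ab, logits_aa], 1))
    loss_b = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([logits_ba, logits_bb], 1))
    return tf.reduce_mean(loss_a + loss_b)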
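SwAV:
import torch.nn.functional as F
def swav_loss(z1, z2, q1, q2, prototypes, temperature=0.1):
    # z1, z2: features of two views of the same images, shape (B, D)
    # q1, q2: their codes, computed without gradient by the Sinkhorn-Knopp step
    c = F.normalize(prototypes, dim=1, p=2)
    p1 = F.log_softmax(F.normalize(z1, dim=1, p=2) @ c.t() / temperature, dim=1)
    p2 = F.log_softmax(F.normalize(z2, dim=1, p=2) @ c.t() / temperature, dim=1)
    # Swapped prediction: view 1 predicts the code of view 2, and vice versa
    return -0.5 * ((q2 * p1).sum(dim=1).mean() + (q1 * p2).sum(dim=1).mean())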
CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image
Pros of CLIP
- Multimodal learning: CLIP can understand both images and text, enabling versatile applications
- Zero-shot learning capabilities: Can perform tasks without specific training
- Larger dataset: Trained on 400 million image-text pairs, potentially offering broader knowledge
Cons of CLIP
- Higher computational requirements: More complex architecture may need more resources
- Not image-only: CLIP needs paired image-text data, whereas SwAV learns from unlabeled images alone, which may be more practical or efficient for some tasks
- Potential biases: Large-scale web-crawled data may introduce unwanted biases
Code Comparison
CLIP (Python):
import torch
from PIL import Image
import clip
model, preprocess = clip.load("ViT-B/32", device="cuda")
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
text = clip.tokenize(["a dog", "a cat"]).to("cuda")
with torch.no_grad():
image_features = model.encode_image(image)
text_features = model.encode_text(text)
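SwAV (Python):
import torch
import torchvision.transforms as transforms
from PIL import Image
transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])
model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
model.fc = torch.nn.Identity()  # fc is randomly initialized; keep the 2048-d pooled features
model.eval()
img = Image.open("image.jpg").convert("RGB")
input_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
    features = model(input_tensor)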
VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
Pros of VISSL
- More comprehensive library with support for multiple self-supervised learning methods
- Highly modular and customizable architecture
- Extensive documentation and tutorials for easier adoption
Cons of VISSL
- Steeper learning curve due to its broader scope
- Potentially higher computational requirements for some tasks
Code Comparison
VISSL:
from vissl.engines.train import train_main
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
config = compose_hydra_configuration(["config=test/integration_test/quick_simclr"])
_, cfg = convert_to_attrdict(config)  # returns (args, config)
train_main(cfg)
SwAV:
import src.resnet50 as resnet_models  # run from the root of the swav repository
model = resnet_models.resnet50(
    normalize=True,       # L2-normalize the projected features
    hidden_mlp=2048,      # hidden size of the projection head
    output_dim=128,       # projected feature dimension
    nmb_prototypes=3000,  # number of prototypes
)
VISSL offers a more flexible and extensive framework for various self-supervised learning tasks, while SwAV focuses specifically on the SwAV method. VISSL's code structure emphasizes configuration and modularity, whereas SwAV's implementation is more straightforward for its specific use case.
Official repository for the "Big Transfer (BiT): General Visual Representation Learning" paper.
Pros of Big Transfer
- Focuses on transfer learning for image classification tasks
- Provides pre-trained ResNet-v2 models of several depths and widths (e.g. BiT-M-R50x1)
- Offers extensive documentation and usage examples
Cons of Big Transfer
- Limited to image classification tasks
- Requires more computational resources for fine-tuning large models
- Less emphasis on self-supervised learning techniques
Code Comparison
SwAV:
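import torch
import src.resnet50 as resnet_models  # from the SwAV repository
model = resnet_models.resnet50(normalize=True, hidden_mlp=2048, output_dim=128, nmb_prototypes=3000)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# There is no SwAV criterion class: main_swav.py computes the swapped-prediction loss
# (Sinkhorn codes, temperature 0.1, epsilon 0.05) inside its training loop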
Big Transfer:
import torch
import bit_pytorch.models as bit_models  # from the big_transfer repository
model = bit_models.KNOWN_MODELS['BiT-M-R50x1'](head_size=10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.003)
criterion = torch.nn.CrossEntropyLoss()
Key Differences
- SwAV focuses on self-supervised learning, while Big Transfer emphasizes transfer learning
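- SwAV learns from unlabeled images with a clustering-based objective that avoids pairwise feature comparisons, whereas Big Transfer relies on large-scale supervised pre-training followed by fine-tuning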
- Big Transfer provides more out-of-the-box models for various architectures
- SwAV offers more flexibility in terms of training objectives and data augmentation
Both repositories provide valuable tools for computer vision tasks, with SwAV being more suitable for self-supervised learning and Big Transfer excelling in transfer learning for image classification.
Unsupervised Learning of Visual Features by Contrasting Cluster Assignments
This code provides a PyTorch implementation and pretrained models for SwAV (Swapping Assignments between Views), as described in the paper Unsupervised Learning of Visual Features by Contrasting Cluster Assignments.
SwAV is an efficient and simple method for pre-training convnets without using annotations. Similarly to contrastive approaches, SwAV learns representations by comparing transformations of an image, but unlike contrastive methods, it does not require computing pairwise feature comparisons. This makes our framework more efficient, since it does not require a large memory bank or an auxiliary momentum network. Specifically, our method simultaneously clusters the data while enforcing consistency between cluster assignments produced for different augmentations (or "views") of the same image, instead of comparing features directly. Simply put, we use a "swapped" prediction mechanism where we predict the cluster assignment of a view from the representation of another view. Our method can be trained with large and small batches and can scale to unlimited amounts of data.
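The swapped prediction can be sketched in a few lines of NumPy. This is a single-process illustration, not the repository's distributed PyTorch code: `sinkhorn` turns prototype scores into soft codes whose mass is spread equally over the prototypes (the equipartition constraint), and `swapped_loss` predicts each view's code from the other view's features. `epsilon=0.05` matches the default mentioned under Common Issues; `temperature=0.1` and three Sinkhorn iterations are assumptions, and all function names are hypothetical.

```python
import numpy as np


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def sinkhorn(scores, epsilon=0.05, n_iters=3):
    """Soft codes for a (B, K) score matrix under the equipartition constraint."""
    Q = np.exp((scores - scores.max()) / epsilon).T  # (K, B); shifting by the max avoids overflow
    Q /= Q.sum()
    K, B = Q.shape
    for _ in range(n_iters):
        Q /= Q.sum(axis=1, keepdims=True) * K  # every prototype receives total mass 1/K
        Q /= Q.sum(axis=0, keepdims=True) * B  # every sample receives total mass 1/B
    return (Q * B).T  # (B, K); each row is a distribution over prototypes


def log_softmax(x):
    x = x - x.max(axis=1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))


def swapped_loss(z1, z2, prototypes, temperature=0.1, epsilon=0.05):
    """Predict the code of each view from the features of the other view."""
    c = l2_normalize(prototypes)
    s1 = l2_normalize(z1) @ c.T  # cosine scores of view 1, (B, K)
    s2 = l2_normalize(z2) @ c.T  # cosine scores of view 2, (B, K)
    q1, q2 = sinkhorn(s1, epsilon), sinkhorn(s2, epsilon)  # targets (no gradient in practice)
    p1, p2 = log_softmax(s1 / temperature), log_softmax(s2 / temperature)
    return -0.5 * ((q2 * p1).sum(axis=1).mean() + (q1 * p2).sum(axis=1).mean())


rng = np.random.default_rng(0)
loss = swapped_loss(rng.normal(size=(8, 16)), rng.normal(size=(8, 16)), rng.normal(size=(5, 16)))
```

The codes are computed on the current batch only; spreading mass equally over prototypes is what rules out the trivial solution where every image lands in the same cluster.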
Model Zoo
We release several models pre-trained with SwAV with the hope that other researchers might also benefit by replacing the ImageNet supervised network with a SwAV backbone. To load our best SwAV pre-trained ResNet-50 model, simply do:
import torch
model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
We provide several baseline SwAV pre-trained models with ResNet-50 architecture in torchvision format. We also provide models pre-trained with DeepCluster-v2 and SeLa-v2 obtained by applying improvements from the self-supervised community to DeepCluster and SeLa (see details in the appendix of our paper).
method | epochs | batch-size | multi-crop | ImageNet top-1 acc. | url | args |
---|---|---|---|---|---|---|
SwAV | 800 | 4096 | 2x224 + 6x96 | 75.3 | model | script |
SwAV | 400 | 4096 | 2x224 + 6x96 | 74.6 | model | script |
SwAV | 200 | 4096 | 2x224 + 6x96 | 73.9 | model | script |
SwAV | 100 | 4096 | 2x224 + 6x96 | 72.1 | model | script |
SwAV | 200 | 256 | 2x224 + 6x96 | 72.7 | model | script |
SwAV | 400 | 256 | 2x224 + 6x96 | 74.3 | model | script |
SwAV | 400 | 4096 | 2x224 | 70.1 | model | script |
DeepCluster-v2 | 800 | 4096 | 2x224 + 6x96 | 75.2 | model | script |
DeepCluster-v2 | 400 | 4096 | 2x160 + 4x96 | 74.3 | model | script |
DeepCluster-v2 | 400 | 4096 | 2x224 | 70.2 | model | script |
SeLa-v2 | 400 | 4096 | 2x160 + 4x96 | 71.8 | model | - |
SeLa-v2 | 400 | 4096 | 2x224 | 67.2 | model | - |
Larger architectures
We provide SwAV models with ResNet-50 networks where we multiply the width by a factor ×2, ×4, and ×5. To load the corresponding backbone you can use:
import torch
rn50w2 = torch.hub.load('facebookresearch/swav:main', 'resnet50w2')
rn50w4 = torch.hub.load('facebookresearch/swav:main', 'resnet50w4')
rn50w5 = torch.hub.load('facebookresearch/swav:main', 'resnet50w5')
network | parameters | epochs | ImageNet top-1 acc. | url | args |
---|---|---|---|---|---|
RN50-w2 | 94M | 400 | 77.3 | model | script |
RN50-w4 | 375M | 400 | 77.9 | model | script |
RN50-w5 | 586M | 400 | 78.5 | model | - |
Running times
We provide the running times for some of our runs:
method | batch-size | multi-crop | scripts | time per epoch |
---|---|---|---|---|
SwAV | 4096 | 2x224 + 6x96 | * * * * | 3min40s |
SwAV | 256 | 2x224 + 6x96 | * * | 52min10s |
DeepCluster-v2 | 4096 | 2x160 + 4x96 | * | 3min13s |
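As a rough worked example, the per-epoch times above can be turned into total wall-clock time. The table does not state the number of epochs or GPUs for each row, so this sketch assumes an 800-epoch run for the batch-4096 SwAV row (the 64-GPU setup from the multi-node section) and a 400-epoch run for the batch-256 row; `to_seconds` and `wall_clock_hours` are hypothetical helpers.

```python
def to_seconds(duration):
    """Parse durations written like '3min40s' (the format used in the table)."""
    minutes, seconds = duration.rstrip("s").split("min")
    return int(minutes) * 60 + int(seconds)


def wall_clock_hours(per_epoch, epochs):
    return to_seconds(per_epoch) * epochs / 3600


# SwAV, batch 4096, assumed 800 epochs
print(round(wall_clock_hours("3min40s", 800), 1))   # 48.9
# SwAV, batch 256, assumed 400 epochs
print(round(wall_clock_hours("52min10s", 400), 1))  # 347.8
```

Under these assumptions the large-batch run finishes in about two days, while the small-batch run takes roughly two weeks.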
Running SwAV unsupervised training
Requirements
- Python 3.6
- PyTorch 1.4.0
- torchvision
- CUDA 10.1
- Apex with CUDA extension (see how I installed apex)
- Other dependencies: scipy, pandas, numpy
Single-node training
SwAV is very simple to implement and experiment with. Our implementation consists of a main_swav.py file, which imports the dataset definition (src/multicropdataset.py), the model architecture (src/resnet50.py) and some miscellaneous training utilities (src/utils.py).
For example, to train the SwAV baseline on a single node with 8 GPUs for 400 epochs, run:
python -m torch.distributed.launch --nproc_per_node=8 main_swav.py \
--data_path /path/to/imagenet/train \
--epochs 400 \
--base_lr 0.6 \
--final_lr 0.0006 \
--warmup_epochs 0 \
--batch_size 32 \
--size_crops 224 96 \
--nmb_crops 2 6 \
--min_scale_crops 0.14 0.05 \
--max_scale_crops 1. 0.14 \
--use_fp16 true \
--freeze_prototypes_niters 5005 \
--queue_length 3840 \
--epoch_queue_starts 15
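The flags above can be sanity-checked with a little arithmetic. This sketch assumes that `--batch_size` is the per-GPU batch size (so 8 processes × 32 images give a global batch of 256) and that the learning rate follows a linear warmup and then a cosine decay from `--base_lr` to `--final_lr`; `global_batch` and `lr_at` are hypothetical helpers, and 1,281,167 is the size of the ImageNet-1k training set. Under these assumptions one epoch is about 5,005 iterations, so `--freeze_prototypes_niters 5005` keeps the prototypes fixed for roughly the first epoch, and `--queue_length 3840` holds the features of 15 global batches.

```python
import math

IMAGENET_TRAIN_IMAGES = 1_281_167  # ImageNet-1k training set size


def global_batch(nproc_per_node, batch_size_per_gpu, nodes=1):
    # Assumes --batch_size is per GPU
    return nodes * nproc_per_node * batch_size_per_gpu


def lr_at(it, total_iters, base_lr=0.6, final_lr=0.0006, warmup_iters=0, start_warmup=0.0):
    """Assumed schedule: linear warmup to base_lr, then cosine decay to final_lr."""
    if it < warmup_iters:
        return start_warmup + (base_lr - start_warmup) * it / warmup_iters
    t = (it - warmup_iters) / (total_iters - warmup_iters)
    return final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * t))


bs = global_batch(8, 32)
iters_per_epoch = math.ceil(IMAGENET_TRAIN_IMAGES / bs)
print(bs, iters_per_epoch, 3840 / bs)  # 256 5005 15.0
```

With `--warmup_epochs 0`, as in the command above, `lr_at` reduces to a pure cosine decay from 0.6 to 0.0006 over all iterations.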
Multi-node training
Distributed training is available via Slurm. We provide several SBATCH scripts to reproduce our SwAV models. For example, to train SwAV on 8 nodes and 64 GPUs with a batch size of 4096 for 800 epochs, run:
sbatch ./scripts/swav_800ep_pretrain.sh
Note that you might need to remove the copyright header from the sbatch file to launch it.
Set up the dist_url parameter: We refer the user to the PyTorch distributed documentation (env, file or tcp) for setting the distributed initialization method (parameter dist_url) correctly. In the provided sbatch files, we use the tcp init method (see * for example).
Evaluating models
Evaluate models: Linear classification on ImageNet
To train a supervised linear classifier on frozen features/weights on a single node with 8 GPUs, run:
python -m torch.distributed.launch --nproc_per_node=8 eval_linear.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar
The resulting linear classifier can be downloaded here.
Evaluate models: Semi-supervised learning on ImageNet
To reproduce our results and fine-tune a network with 1% or 10% of ImageNet labels on a single node with 8 GPUs, run:
- 10% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "10" \
--lr 0.01 \
--lr_last_layer 0.2
- 1% labels
python -m torch.distributed.launch --nproc_per_node=8 eval_semisup.py \
--data_path /path/to/imagenet \
--pretrained /path/to/checkpoints/swav_800ep_pretrain.pth.tar \
--labels_perc "1" \
--lr 0.02 \
--lr_last_layer 5
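Both semi-supervised commands use one learning rate for the last layer (`--lr_last_layer`) and another for the rest of the network (`--lr`). One way to express that in PyTorch is a single optimizer with two parameter groups; the sketch below illustrates the idea and is not eval_semisup.py itself. `build_semisup_optimizer` is a hypothetical helper, the classifier is assumed to live in `model.fc` (as in torchvision ResNets), and momentum 0.9 is an assumption.

```python
import torch
import torch.nn as nn


def build_semisup_optimizer(model: nn.Module, lr=0.01, lr_last_layer=0.2, momentum=0.9):
    """Hypothetical helper: one SGD optimizer, two learning rates.

    Parameters of model.fc get lr_last_layer; every other parameter gets lr.
    """
    head_params = list(model.fc.parameters())
    head_ids = {id(p) for p in head_params}
    trunk_params = [p for p in model.parameters() if id(p) not in head_ids]
    return torch.optim.SGD(
        [{"params": trunk_params}, {"params": head_params, "lr": lr_last_layer}],
        lr=lr,  # default learning rate, used by the trunk group
        momentum=momentum,
    )
```

For the 1% setting this would be `build_semisup_optimizer(model, lr=0.02, lr_last_layer=5)`; the much larger last-layer rate lets the freshly initialized classifier catch up with the pre-trained trunk.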
Evaluate models: Transferring to Detection with DETR
DETR is a recent object detection framework that reaches competitive performance with Faster R-CNN while being conceptually simpler and trainable end-to-end. We evaluate our SwAV ResNet-50 backbone on object detection on the COCO dataset using the DETR framework with full fine-tuning. Here are the instructions for reproducing our experiments:
- Install detr and prepare the COCO dataset following these instructions.
- Apply the changes highlighted in this gist to the detr backbone file in order to load the SwAV backbone instead of ImageNet supervised weights.
- Launch training from the detr repository with run_with_submitit.py.
python run_with_submitit.py --batch_size 4 --nodes 2 --lr_backbone 5e-5
Common Issues
For help or issues using SwAV, please submit a GitHub issue.
The loss does not decrease and is stuck at ln(nmb_prototypes) (8.006 for 3000 prototypes).
It sometimes happens that the system collapses at the beginning and does not manage to converge. We have found the following empirical workarounds to improve convergence and avoid collapsing at the beginning:
- use a lower epsilon value (--epsilon 0.03 instead of the default 0.05)
- carefully tune the hyper-parameters
- freeze the prototypes during the first iterations (freeze_prototypes_niters argument)
- switch to hard assignment
- remove the batch-normalization layer from the projection head
- reduce the difficulty of the problem (fewer crops or softer data augmentation)
We now analyze the collapsing problem: it happens when all examples are mapped to the same unique representation.
In other words, the convnet always has the same output regardless of its input; it is a constant function.
All examples get the same cluster assignment because they are identical, and the only valid assignment that satisfies the equipartition constraint in this case is the uniform assignment (1/K, where K is the number of prototypes).
In turn, this uniform assignment is trivial to predict since it is the same for all examples.
Reducing the epsilon parameter (see Eq. (3) of our paper) encourages the assignments Q to be sharper (i.e. less uniform), which strongly helps avoid collapse.
However, using too low a value for epsilon may lead to numerical instability.
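The numbers in this section can be checked directly. When both the code and the prediction are uniform over K prototypes, the loss is -sum_k (1/K) log(1/K) = ln K, which gives 8.006 for K = 3000. The sketch also illustrates why a lower epsilon sharpens assignments: the Sinkhorn step starts from the kernel exp(score / epsilon), whose entropy shrinks as epsilon decreases, while identical (collapsed) scores give a uniform kernel whatever epsilon is. `uniform_swapped_loss` and `kernel_entropy` are hypothetical helpers, and the scores are toy values.

```python
import math


def uniform_swapped_loss(num_prototypes):
    """Cross-entropy when code q and prediction p are both uniform (1/K each): ln K."""
    q = p = 1.0 / num_prototypes
    return -num_prototypes * q * math.log(p)


def kernel_entropy(scores, epsilon):
    """Entropy of exp(score / epsilon), normalized over prototypes, for one sample."""
    top = max(scores)
    weights = [math.exp((s - top) / epsilon) for s in scores]
    total = sum(weights)
    return -sum(w / total * math.log(w / total) for w in weights)


print(round(uniform_swapped_loss(3000), 3))  # 8.006
scores = [0.9, 0.5, 0.1]  # toy cosine scores of one sample against 3 prototypes
print(kernel_entropy(scores, 0.03) < kernel_entropy(scores, 0.05))  # True
```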
Training gets unstable when using the queue.
The queue is composed of feature representations from the previous batches.
These lines discard the oldest feature representations from the queue and add the newest ones (i.e. those from the current batch) in a round-robin fashion.
This way, the assignment problem is performed on more samples: without the queue we assign B examples to num_prototypes clusters, where B is the total batch size, while with the queue we assign (B + queue_length) examples to num_prototypes clusters.
This is especially useful when working with small batches because it improves the precision of the assignment.
If you start using the queue too early or if you use too large a queue, this can considerably disturb training, because the queue members are too inconsistent. After introducing the queue, the loss should be lower than it was without the queue. On the following loss curve (first 30 epochs of this script), we introduced the queue at epoch 15 and observed that it lowered the loss further.
If the loss goes up when the queue is introduced and does not decrease afterwards, stop training and change the queue parameters. We recommend (i) using a smaller queue and (ii) starting the queue later in training.
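A minimal NumPy sketch of the queue mechanics described above, assuming the queue is already full; `push_to_queue` and `scores_for_assignment` are hypothetical helpers, and the shapes are toy-sized. The newest batch enters at the front while the oldest rows fall off the end, and the scores used for the assignment stack the queued features before the current batch, so the equipartition step sees queue_length + B rows while only the last B codes would be kept for the loss.

```python
import numpy as np


def push_to_queue(queue, batch):
    """Shift the queue by one batch: oldest rows fall off the end, newest go in front."""
    bs = batch.shape[0]
    queue[bs:] = queue[:-bs].copy()
    queue[:bs] = batch
    return queue


def scores_for_assignment(queue, batch, prototypes):
    """Score queued and current features against the prototypes.

    The Sinkhorn step would run on all queue_length + B rows; only the
    last B rows (the current batch) would be kept as codes.
    """
    feats = np.concatenate([queue, batch], axis=0)  # (queue_length + B, D)
    return feats @ prototypes.T                     # (queue_length + B, K)


queue = np.zeros((6, 4))   # queue_length = 6, feature dim D = 4
batch = np.ones((2, 4))    # B = 2
prototypes = np.eye(3, 4)  # K = 3
scores = scores_for_assignment(queue, batch, prototypes)  # shape (8, 3)
push_to_queue(queue, batch)
```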
License
See the LICENSE file for more details.
See also
PyTorch Lightning Bolts: Implementation by the Lightning team.
SwAV-TF: A TensorFlow re-implementation.
Citation
If you find this repository useful in your research, please cite:
@inproceedings{caron2020unsupervised,
title={Unsupervised Learning of Visual Features by Contrasting Cluster Assignments},
author={Caron, Mathilde and Misra, Ishan and Mairal, Julien and Goyal, Priya and Bojanowski, Piotr and Joulin, Armand},
booktitle={Proceedings of Advances in Neural Information Processing Systems (NeurIPS)},
year={2020}
}