Top Related Projects
Quick Overview
SimCLR is a simple framework for contrastive learning of visual representations, developed by Google Research. It aims to improve self-supervised learning techniques for computer vision tasks, particularly in scenarios with limited labeled data. The project provides implementations and pre-trained models for the SimCLR approach.
Pros
- Achieved state-of-the-art results in self-supervised and semi-supervised learning on ImageNet at the time of publication (2020)
- Simplifies the self-supervised learning pipeline by removing specialized architectures or memory banks
- Demonstrates strong transfer learning capabilities to downstream tasks
- Provides pre-trained models and implementations for easy adoption and experimentation
Cons
- Requires large batch sizes and long training times for optimal performance
- May not perform as well on smaller datasets or with limited computational resources
- The approach might not generalize equally well to all types of visual data or tasks
- Limited documentation and examples for customization and integration into other projects
Code Examples
- Loading a pre-trained SimCLR model:
import tensorflow as tf
import tensorflow_hub as hub
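# Hub handle as given in this overview (not verified); the README below distributes checkpoints and hub modules via gs://simclr-checkpoints
model = hub.load('https://tfhub.dev/google/simclr/2/r50_2x_ft_in1k/1')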
- Extracting features from an image:
import numpy as np
from PIL import Image
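# Force 3 channels and scale pixels to [0, 1] as float32
image = Image.open('example.jpg').convert('RGB').resize((224, 224))
image_array = np.asarray(image, dtype=np.float32) / 255.0
# 'default' is the base-network representation (see the signature note in the README)
features = model.signatures['default'](tf.constant(image_array[np.newaxis, ...]))['default']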
- Fine-tuning SimCLR on a custom dataset:
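# num_classes, train_dataset and val_dataset are placeholders for your own data
num_classes = 10  # hypothetical value; set to the number of classes in your dataset
# If the handle is a legacy TF1 Hub module, hub.KerasLayer does not allow trainable=True;
# the TF2 checkpoints in the tf2/ folder may be a better starting point for fine-tuning
base_model = tf.keras.Sequential([
    hub.KerasLayer('https://tfhub.dev/google/simclr/2/r50_2x_ft_in1k/1', trainable=True),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])
base_model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    # CategoricalCrossentropy expects one-hot labels; use SparseCategoricalCrossentropy for integer labels
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy']
)
base_model.fit(train_dataset, epochs=10, validation_data=val_dataset)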
Getting Started
To get started with SimCLR:
- Clone the repository:
git clone https://github.com/google-research/simclr.git
cd simclr
- Install dependencies:
pip install -r requirements.txt
- Download pre-trained models or prepare your dataset for training.
- Use the provided scripts to train or evaluate models:
python run.py --mode=train --train_mode=pretrain --train_batch_size=512 --train_epochs=1000 --learning_rate=0.3 --weight_decay=1e-6 --temperature=0.1 --dataset=imagenet2012 --image_size=224 --eval_split=validation --use_blur=True --color_jitter_strength=0.5 --model_dir=/tmp/simclr_model --use_tpu=False
Refer to the repository's README for more detailed instructions and configuration options.
Competitor Comparisons
PyTorch implementation of MoCo: https://arxiv.org/abs/1911.05722
Pros of MoCo
- Utilizes a momentum encoder, which provides more consistent representations
- Decouples the number of negatives from the batch size by keeping a queue of past keys, so it does not need very large batches
- Offers a simpler implementation with fewer hyperparameters to tune
Cons of MoCo
- May require more GPU memory due to the momentum encoder
- Potentially slower convergence compared to SimCLR in some scenarios
- Limited flexibility in terms of data augmentation strategies
Code Comparison
MoCo:
# Momentum update
self._momentum_update_key_encoder()
# Compute loss
loss = self.contrastive_loss(q, k)
SimCLR:
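# Apply data augmentation twice to get two views of each image
x_i, x_j = self.data_augmentation(x)
# Encode and project both views (hypothetical encoder / projection_head attributes)
h_i = self.projection_head(self.encoder(x_i))
h_j = self.projection_head(self.encoder(x_j))
# Compute loss
loss = self.nt_xent_loss(h_i, h_j)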
Both repositories implement self-supervised methods for visual representation learning. MoCo maintains a dynamic dictionary of negatives with a queue and a moving-average (momentum) encoder, while SimCLR relies on strong data augmentation and large batch sizes, using the other examples in each batch as negatives. MoCo's approach may be more memory-efficient for large-scale training, whereas SimCLR's method might be easier to implement and adapt to different datasets.
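The two MoCo ingredients mentioned above, a moving-average (momentum) key encoder and a queue of past keys used as negatives, can be sketched in NumPy. This is a minimal illustration, not MoCo's code: parameters are a plain dict of arrays, and the helper names (momentum_update, KeyQueue, moco_logits) are hypothetical. The defaults m=0.999 and temperature=0.07 follow the MoCo paper.

```python
import numpy as np

def momentum_update(key_params, query_params, m=0.999):
    """Move each key-encoder parameter toward the query encoder: k <- m * k + (1 - m) * q."""
    return {name: m * key_params[name] + (1.0 - m) * query_params[name] for name in key_params}

class KeyQueue:
    """Fixed-size FIFO of past key embeddings that serve as negatives (hypothetical helper)."""

    def __init__(self, dim, size):
        self.buffer = np.zeros((size, dim), dtype=np.float32)
        self.ptr = 0

    def enqueue(self, keys):
        # Overwrite the oldest slots; assumes size is a multiple of the batch size
        n = keys.shape[0]
        self.buffer[self.ptr:self.ptr + n] = keys
        self.ptr = (self.ptr + n) % self.buffer.shape[0]

def moco_logits(q, k, queue, temperature=0.07):
    """InfoNCE logits for L2-normalized queries q and keys k: column 0 is the positive."""
    l_pos = np.sum(q * k, axis=1, keepdims=True)  # [N, 1]
    l_neg = q @ queue.T                            # [N, queue size]
    return np.concatenate([l_pos, l_neg], axis=1) / temperature
```

Because the queue keeps keys from earlier batches, the number of negatives (the queue size) is independent of the batch size, which is the main contrast with SimCLR's in-batch negatives.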
PyTorch implementation of SwAV: https://arxiv.org/abs/2006.09882
Pros of SwAV
- Implements online clustering, allowing for more efficient training on large datasets
- Supports multi-crop augmentation, which can improve performance on downstream tasks
- Does not need a large memory bank or very large batches, since it compares cluster assignments rather than individual features, which makes it more memory-efficient
Cons of SwAV
- May require more careful hyperparameter tuning compared to SimCLR
- Can be more sensitive to the choice of clustering algorithm and parameters
Code Comparison
SimCLR:
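# Define the contrastive (NT-Xent) loss function
import tensorflow as tf
def nt_xent_loss(features, temperature=0.5):
    # features: [2N, d]; rows i and i + N are two augmented views of the same image
    features = tf.math.l2_normalize(features, axis=1)
    n = tf.shape(features)[0] // 2
    # Mask self-similarity so an example is never its own positive or negative
    logits = tf.matmul(features, features, transpose_b=True) / temperature - 1e9 * tf.eye(2 * n)
    # The positive for row i is row (i + N) mod 2N
    labels = tf.concat([tf.range(n, 2 * n), tf.range(0, n)], axis=0)
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))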
SwAV:
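# Simplified SwAV-style loss: real SwAV computes soft codes with Sinkhorn-Knopp
# and swaps predictions between two views; this sketch uses hard argmax codes
import torch
from torch import nn
def swav_loss(q, k, prototype_vectors):
    q = nn.functional.normalize(q, dim=1, p=2)
    k = nn.functional.normalize(k, dim=1, p=2)
    prototypes = nn.functional.normalize(prototype_vectors, dim=1, p=2)
    scores = torch.mm(q, prototypes.t())
    # Hard cluster assignment of the other view, used as a fixed target
    with torch.no_grad():
        targets = torch.mm(k, prototypes.t()).argmax(dim=1)
    return nn.CrossEntropyLoss()(scores, targets)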
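SwAV's online clustering replaces hard argmax targets with soft codes computed by a few Sinkhorn-Knopp iterations, and each view predicts the code of the other view. The NumPy sketch below works under simplifying assumptions: scores are already cosine similarities between normalized embeddings and prototypes, there is no queue and no multi-crop, and eps=0.05 and temperature=0.1 follow values reported in the SwAV paper. The function names are ours, not the SwAV repository's.

```python
import numpy as np

def sinkhorn(scores, eps=0.05, n_iters=3):
    """Turn prototype scores [B, K] into soft cluster codes [B, K]; each row sums to 1."""
    # Subtracting the global max only rescales Q, which the normalization below undoes
    q = np.exp((scores - scores.max()) / eps).T  # [K, B]
    q /= q.sum()
    k, b = q.shape
    for _ in range(n_iters):
        # Rows: give every prototype an equal share (1/K) of the batch
        q *= (1.0 / k) / q.sum(axis=1, keepdims=True)
        # Columns: give every sample a total mass of 1/B
        q *= (1.0 / b) / q.sum(axis=0, keepdims=True)
    # Rescale columns so each sample's code is a probability distribution
    return (q / q.sum(axis=0, keepdims=True)).T

def log_softmax(x):
    x = x - x.max(axis=1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))

def swapped_prediction_loss(scores_1, scores_2, temperature=0.1):
    # Codes act as fixed targets; each view predicts the code of the other view
    q1, q2 = sinkhorn(scores_1), sinkhorn(scores_2)
    loss_1 = -np.sum(q2 * log_softmax(scores_1 / temperature), axis=1)
    loss_2 = -np.sum(q1 * log_softmax(scores_2 / temperature), axis=1)
    return float(np.mean(loss_1 + loss_2) / 2)
```

The equal-share constraint on prototypes is what keeps the online clustering from collapsing all samples into a single cluster.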
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations
Pros of SimCLR
- More beginner-friendly implementation with clearer code structure
- Includes detailed documentation and explanations within the code
- Easier to set up and run on local machines or smaller datasets
Cons of SimCLR
- Less optimized for large-scale training compared to the Google implementation
- May not include all the latest features and improvements from the original paper
- Limited support for distributed training and TPU acceleration
Code Comparison
SimCLR (sthalles):
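import torch
import torch.nn.functional as F
def nt_xent_loss(out_1, out_2, temperature):
    # L2-normalize so dot products are cosine similarities
    out_1, out_2 = F.normalize(out_1, dim=-1), F.normalize(out_2, dim=-1)
    out = torch.cat([out_1, out_2], dim=0)
    n_samples = len(out)
    sim = torch.exp(torch.mm(out, out.t().contiguous()) / temperature)
    # Drop self-similarities; the remaining row sum (positive included) is the denominator
    mask = ~torch.eye(n_samples, device=sim.device).bool()
    neg = sim.masked_select(mask).view(n_samples, -1)
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    # Each pair is an anchor twice (i -> j and j -> i), so repeat the positives to length 2N
    pos = torch.cat([pos, pos], dim=0)
    loss = -torch.log(pos / neg.sum(dim=-1)).mean()
    return loss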
SimCLR (google-research):
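import tensorflow as tf
LARGE_NUM = 1e9
def nt_xent_loss(hidden1, hidden2, temperature):
    hidden1, hidden2 = tf.math.l2_normalize(hidden1, -1), tf.math.l2_normalize(hidden2, -1)
    batch_size = tf.shape(hidden1)[0]
    # Targets over [cross-view | same-view] logits: the positive is the matching cross-view column
    labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
    masks = tf.one_hot(tf.range(batch_size), batch_size)
    # Same-view similarities, with self-pairs pushed to -LARGE_NUM
    logits_aa = tf.matmul(hidden1, hidden1, transpose_b=True) / temperature - masks * LARGE_NUM
    logits_bb = tf.matmul(hidden2, hidden2, transpose_b=True) / temperature - masks * LARGE_NUM
    # Cross-view similarities; the diagonal holds the positive pairs
    logits_ab = tf.matmul(hidden1, hidden2, transpose_b=True) / temperature
    logits_ba = tf.matmul(hidden2, hidden1, transpose_b=True) / temperature
    loss_a = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([logits_ab, logits_aa], 1))
    loss_b = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([logits_ba, logits_bb], 1))
    return tf.reduce_mean(loss_a + loss_b)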
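The NT-Xent loss these snippets aim to compute is: for each of the 2N augmented views, a softmax over its similarities to the other 2N-1 views, with the other view of the same image as the target. A framework-free NumPy version (the name nt_xent_numpy is hypothetical) can be handy for checking an implementation on small inputs; it averages over all 2N anchors.

```python
import numpy as np

def nt_xent_numpy(z1, z2, temperature=0.5):
    """NT-Xent over 2N embeddings; row i of z1 and row i of z2 are a positive pair."""
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # cosine similarity via L2 normalization
    n = z1.shape[0]
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)  # an embedding is never compared with itself
    # The positive for row i is row i + N, and vice versa
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    log_prob = sim - np.log(np.exp(sim).sum(axis=1, keepdims=True))
    return float(-log_prob[np.arange(2 * n), pos_idx].mean())
```

With two identical, orthogonal pairs the loss is log(2 + e^2) - 2 (about 0.24) at temperature 0.5; swapping which rows are paired raises it, as expected.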
README
SimCLR - A Simple Framework for Contrastive Learning of Visual Representations
News! We have released a TF2 implementation of SimCLR (along with checkpoints converted to TF2) in the tf2/ folder.
News! Colabs for Intriguing Properties of Contrastive Losses have been added.
Pre-trained models for SimCLRv2
We open-sourced a total of 65 pretrained models, corresponding to those in Table 1 of the SimCLRv2 paper (accuracy columns are ImageNet top-1, %):
Depth | Width | SK | Param (M) | F-T (1%) | F-T (10%) | F-T (100%) | Linear eval | Supervised |
---|---|---|---|---|---|---|---|---|
50 | 1X | False | 24 | 57.9 | 68.4 | 76.3 | 71.7 | 76.6 |
50 | 1X | True | 35 | 64.5 | 72.1 | 78.7 | 74.6 | 78.5 |
50 | 2X | False | 94 | 66.3 | 73.9 | 79.1 | 75.6 | 77.8 |
50 | 2X | True | 140 | 70.6 | 77.0 | 81.3 | 77.7 | 79.3 |
101 | 1X | False | 43 | 62.1 | 71.4 | 78.2 | 73.6 | 78.0 |
101 | 1X | True | 65 | 68.3 | 75.1 | 80.6 | 76.3 | 79.6 |
101 | 2X | False | 170 | 69.1 | 75.8 | 80.7 | 77.0 | 78.9 |
101 | 2X | True | 257 | 73.2 | 78.8 | 82.4 | 79.0 | 80.1 |
152 | 1X | False | 58 | 64.0 | 73.0 | 79.3 | 74.5 | 78.3 |
152 | 1X | True | 89 | 70.0 | 76.5 | 81.3 | 77.2 | 79.9 |
152 | 2X | False | 233 | 70.2 | 76.6 | 81.1 | 77.4 | 79.1 |
152 | 2X | True | 354 | 74.2 | 79.4 | 82.9 | 79.4 | 80.4 |
152 | 3X | True | 795 | 74.9 | 80.1 | 83.1 | 79.8 | 80.5 |
These checkpoints are stored in Google Cloud Storage:
- Pretrained SimCLRv2 models (with linear eval head): gs://simclr-checkpoints/simclrv2/pretrained
- Fine-tuned SimCLRv2 models on 1% of labels: gs://simclr-checkpoints/simclrv2/finetuned_1pct
- Fine-tuned SimCLRv2 models on 10% of labels: gs://simclr-checkpoints/simclrv2/finetuned_10pct
- Fine-tuned SimCLRv2 models on 100% of labels: gs://simclr-checkpoints/simclrv2/finetuned_100pct
- Supervised models with the same architectures: gs://simclr-checkpoints/simclrv2/supervised
- The distilled / self-trained models (after fine-tuning) are also provided:
We also provide examples of how to use the checkpoints in the colabs/ folder.
Pre-trained models for SimCLRv1
The pre-trained models (base network with linear classifier layer) can be found below. Note that for these SimCLRv1 checkpoints, the projection head is not available.
Model checkpoint and hub-module | ImageNet Top-1 |
---|---|
ResNet50 (1x) | 69.1 |
ResNet50 (2x) | 74.2 |
ResNet50 (4x) | 76.6 |
Additional SimCLRv1 checkpoints are available: gs://simclr-checkpoints/simclrv1.
A note on the signatures of the TensorFlow Hub module: default is the representation output of the base network; logits_sup gives the supervised classification logits for the 1000 ImageNet categories. Others (e.g. initial_max_pool, block_group1) are middle layers of the ResNet; refer to resnet.py for the specifics. See the TensorFlow Hub documentation for additional information on using TensorFlow Hub modules.
Environment setup
Our models are trained with TPUs. It is recommended to run distributed training with TPUs when using our code for pretraining.
Our code can also run on a single GPU. It does not support multi-GPU training, for reasons such as global BatchNorm and contrastive loss across cores.
The code is compatible with both TensorFlow v1 and v2. See requirements.txt for all prerequisites, and you can also install them using the following command.
pip install -r requirements.txt
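The global BatchNorm requirement mentioned above exists because, as the SimCLR paper explains, positive pairs are computed on the same device, so per-device BN statistics can leak information that lets the model solve the contrastive task without learning better representations. The NumPy sketch below contrasts per-replica and global (cross-replica) normalization; it is illustrative only, omits the learned scale and shift, and the function names are hypothetical.

```python
import numpy as np

EPS = 1e-5

def batch_norm(x):
    """Normalize each feature with the mean/variance of the rows in x (no learned scale or shift)."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + EPS)

def per_replica_bn(shards):
    # Each replica normalizes with its own local statistics
    return np.concatenate([batch_norm(s) for s in shards], axis=0)

def global_bn(shards):
    # Statistics are aggregated over all replicas before normalizing each shard
    full = np.concatenate(shards, axis=0)
    mean, var = full.mean(axis=0), full.var(axis=0)
    return np.concatenate([(s - mean) / np.sqrt(var + EPS) for s in shards], axis=0)
```

Global BN produces exactly what a single device would compute on the whole batch, while per-replica BN does not whenever the shards differ in distribution.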
Pretraining
To pretrain the model on CIFAR-10 with a single GPU, try the following command:
python run.py --train_mode=pretrain \
--train_batch_size=512 --train_epochs=1000 \
--learning_rate=1.0 --weight_decay=1e-4 --temperature=0.5 \
--dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
--use_blur=False --color_jitter_strength=0.5 \
--model_dir=/tmp/simclr_test --use_tpu=False
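To get a feel for the length of the CIFAR-10 run above, the number of optimizer steps is roughly the number of training examples times epochs divided by the batch size. The helper below is hypothetical and rounds up; run.py may round slightly differently.

```python
import math

def approx_train_steps(num_examples, epochs, batch_size):
    """Approximate number of optimizer steps; the training script's exact rounding may differ."""
    return math.ceil(num_examples * epochs / batch_size)

CIFAR10_TRAIN_EXAMPLES = 50_000  # size of the CIFAR-10 training split

steps = approx_train_steps(CIFAR10_TRAIN_EXAMPLES, epochs=1000, batch_size=512)
print(steps)
```

At about 97.7 steps per epoch, the command runs roughly 97.7k steps. The same helper gives about 31.3k steps for a 100-epoch ImageNet run (1,281,167 training images) at batch size 4096.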
To pretrain the model on ImageNet with Cloud TPUs, first check out the Google Cloud TPU tutorial for basic information on how to use Google Cloud TPUs.
Once you have created a virtual machine with Cloud TPUs and pre-downloaded the ImageNet data for tensorflow_datasets, set the following environment variables:
TPU_NAME=<tpu-name>
STORAGE_BUCKET=gs://<storage-bucket>
DATA_DIR=$STORAGE_BUCKET/<path-to-tensorflow-dataset>
MODEL_DIR=$STORAGE_BUCKET/<path-to-store-checkpoints>
The following command can be used to pretrain a ResNet-50 on ImageNet (which reflects the default hyperparameters in our paper):
python run.py --train_mode=pretrain \
--train_batch_size=4096 --train_epochs=100 --temperature=0.1 \
--learning_rate=0.075 --learning_rate_scaling=sqrt --weight_decay=1e-4 \
--dataset=imagenet2012 --image_size=224 --eval_split=validation \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR \
--use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
A batch size of 4096 requires at least 32 TPUs. 100 epochs takes around 6 hours with 32 TPU v3s. Note that a learning rate of 0.3 with learning_rate_scaling=linear is equivalent to 0.075 with learning_rate_scaling=sqrt when the batch size is 4096. However, sqrt scaling trains better when a smaller batch size is used.
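The equivalence above can be checked with a short calculation, assuming the two rules scale the base rate as base_lr × batch_size / 256 (linear) and base_lr × sqrt(batch_size) (sqrt): at batch size 4096 both give 4.8. The sketch also includes the schedule described in the SimCLR paper (linear warmup followed by cosine decay without restarts); the function names are hypothetical.

```python
import math

def scaled_learning_rate(base_lr, batch_size, scaling):
    """Peak learning rate under the two scaling rules (formulas assumed, see lead-in)."""
    if scaling == "linear":
        return base_lr * batch_size / 256
    if scaling == "sqrt":
        return base_lr * math.sqrt(batch_size)
    raise ValueError(f"unknown scaling: {scaling}")

def learning_rate_at(step, peak_lr, warmup_steps, total_steps):
    """Linear warmup to peak_lr, then cosine decay to zero at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1 + math.cos(math.pi * min(1.0, progress)))
```

At batch size 512 the same base rates give 0.6 (linear) versus about 1.70 (sqrt), so sqrt scaling keeps the learning rate relatively higher for small batches.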
Finetuning the linear head (linear eval)
To fine-tune a linear head (with a single GPU), try the following command:
python run.py --mode=train_then_eval --train_mode=finetune \
--fine_tune_after_block=4 --zero_init_logits_layer=True \
--variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
--global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=0.0 \
--train_epochs=100 --train_batch_size=512 --warmup_epochs=0 \
--dataset=cifar10 --image_size=32 --eval_split=test --resnet_depth=18 \
--checkpoint=/tmp/simclr_test --model_dir=/tmp/simclr_test_ft --use_tpu=False
You can check the results using TensorBoard, for example:
python -m tensorboard.main --logdir=/tmp/simclr_test
As a reference, the above runs on CIFAR-10 should give you around 91% accuracy, though it can be further optimized.
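The variable_schema flag is a regular expression built from a negative lookahead: a variable is selected for restoring only if its name does not start with global_step or head and has no Momentum component (the optimizer slots). The sketch below assumes the pattern is applied with Python's re.match against variable names, which, as far as we know, is how TensorFlow filters collections by scope; the variable names are made up for illustration. The full fine-tuning command later in this README uses head_supervised instead of head, so the projection head is restored as well.

```python
import re

LINEAR_EVAL_SCHEMA = r"(?!global_step|(?:.*/|^)Momentum|head)"
FULL_FT_SCHEMA = r"(?!global_step|(?:.*/|^)Momentum|head_supervised)"

def restored(names, schema):
    """Names selected by the schema; re.match anchors at the start of each name."""
    return [n for n in names if re.match(schema, n)]

# Hypothetical variable names, for illustration only
names = [
    "global_step",
    "base_model/conv2d/kernel",
    "base_model/conv2d/kernel/Momentum",
    "head_contrastive/nl_0/dense/kernel",
    "head_supervised/linear_layer/dense/kernel",
]

print(restored(names, LINEAR_EVAL_SCHEMA))
print(restored(names, FULL_FT_SCHEMA))
```

A pattern made only of a lookahead produces a zero-length match, but Python match objects are always truthy, so the filter works as a plain include/exclude test.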
For fine-tuning a linear head on ImageNet using Cloud TPUs, first set CHKPT_DIR to the pretrained model directory and set a new MODEL_DIR, then use the following command:
python run.py --mode=train_then_eval --train_mode=finetune \
--fine_tune_after_block=4 --zero_init_logits_layer=True \
--variable_schema='(?!global_step|(?:.*/|^)Momentum|head)' \
--global_bn=False --optimizer=momentum --learning_rate=0.1 --weight_decay=1e-6 \
--train_epochs=90 --train_batch_size=4096 --warmup_epochs=0 \
--dataset=imagenet2012 --image_size=224 --eval_split=validation \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR \
--use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0
As a reference, the above runs on ImageNet should give you around 64.5% accuracy.
Semi-supervised learning and fine-tuning the whole network
You can access the 1% and 10% ImageNet subsets used for semi-supervised learning via TensorFlow Datasets: simply set dataset=imagenet2012_subset/1pct or dataset=imagenet2012_subset/10pct in the command line for fine-tuning on these subsets.
You can also find the image IDs of these subsets in imagenet_subsets/.
To fine-tune the whole network on ImageNet (1% of labels), refer to the following command:
python run.py --mode=train_then_eval --train_mode=finetune \
--fine_tune_after_block=-1 --zero_init_logits_layer=True \
--variable_schema='(?!global_step|(?:.*/|^)Momentum|head_supervised)' \
--global_bn=True --optimizer=lars --learning_rate=0.005 \
--learning_rate_scaling=sqrt --weight_decay=0 \
--train_epochs=60 --train_batch_size=1024 --warmup_epochs=0 \
--dataset=imagenet2012_subset/1pct --image_size=224 --eval_split=validation \
--data_dir=$DATA_DIR --model_dir=$MODEL_DIR --checkpoint=$CHKPT_DIR \
--use_tpu=True --tpu_name=$TPU_NAME --train_summary_steps=0 \
--num_proj_layers=3 --ft_proj_selector=1
Set the checkpoint to one that is only pre-trained, not fine-tuned. Given that SimCLRv1 checkpoints do not contain the projection head, it is recommended to run with SimCLRv2 checkpoints (you can still run with SimCLRv1 checkpoints, but variable_schema needs to exclude head). The num_proj_layers and ft_proj_selector flags need to be adjusted accordingly, following the SimCLRv2 paper, to obtain the best performance.
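The num_proj_layers and ft_proj_selector flags refer to where fine-tuning starts inside the projection head; the SimCLRv2 paper fine-tunes from the first layer of a 3-layer head. Below is a NumPy sketch, assuming selector 0 means the encoder output, k means the output of the k-th projection layer, and -1 means the last layer. Batch normalization in the real head is omitted, and the function names are hypothetical.

```python
import numpy as np

def projection_head_outputs(h, weights):
    """Return [h, out_1, ..., out_L] for an MLP projection head (ReLU on all but the last layer)."""
    outputs = [h]
    for i, w in enumerate(weights):
        h = h @ w
        if i < len(weights) - 1:
            h = np.maximum(h, 0.0)
        outputs.append(h)
    return outputs

def finetune_features(h, weights, ft_proj_selector):
    # 0 selects the encoder output; k selects the k-th projection layer; -1 the last one
    return projection_head_outputs(h, weights)[ft_proj_selector]
```

With num_proj_layers=3 and ft_proj_selector=1, the classifier is attached after the first projection layer, so that layer is kept and fine-tuned together with the encoder.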
Other resources
Model conversion to PyTorch format
This repo provides a solution for converting the pretrained SimCLRv1 TensorFlow checkpoints into PyTorch ones.
This repo provides a solution for converting the pretrained SimCLRv2 TensorFlow checkpoints into PyTorch ones.
Other non-official / unverified implementations
(Feel free to share your implementation by creating an issue)
Implementations in PyTorch:
Implementations in Tensorflow 2 / Keras (official TF2 implementation was added in tf2/ folder):
Known issues
- Batch size: the original SimCLR results were tuned under a large batch size (i.e. 4096), which leads to suboptimal results when training with a smaller batch size. However, with a good set of hyper-parameters (mainly learning rate, temperature, projection head depth), small batch sizes can yield results that are on par with large batch sizes (e.g., see Table 2 in this paper).
- Pretrained models / Checkpoints: SimCLRv1 and SimCLRv2 are pretrained with different weight decays, so the pretrained models from the two versions have very different weight norm scales (convolutional weights in SimCLRv1 ResNet-50 are on average 16.8X of those in SimCLRv2). For fine-tuning the pretrained models from both versions, a LARS optimizer works fine, but the momentum optimizer requires very different hyperparameters (e.g. learning rate, weight decay). So in the latter case, you may want to either search for very different hparams depending on which version is used, or re-scale the weights (i.e. the conv kernel parameters of base_model in the checkpoints) so that they are roughly on the same scale.
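The re-scaling option above can be sketched as dividing the encoder's convolution kernels by the average ratio quoted (16.8X). This is reasonable because each convolution in the ResNet is followed by BatchNorm, which largely cancels a constant rescaling of the kernel. The sketch assumes the checkpoint has been loaded into a dict of NumPy arrays keyed by variable name; the function name, the example names, and the 4-D shape test for conv kernels are assumptions.

```python
import numpy as np

def rescale_conv_kernels(variables, factor=16.8, prefix="base_model/"):
    """Divide encoder conv kernels by `factor`; all other variables are returned unchanged.

    `variables` maps (hypothetical) checkpoint variable names to arrays. A single average
    factor is a rough approximation; a per-layer ratio would be more precise.
    """
    out = {}
    for name, value in variables.items():
        is_conv_kernel = name.startswith(prefix) and name.endswith("kernel") and value.ndim == 4
        out[name] = value / factor if is_conv_kernel else value
    return out
```

Only the 4-D kernels under base_model are touched; BatchNorm parameters and the supervised head keep their original values.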
Cite
@article{chen2020simple,
title={A Simple Framework for Contrastive Learning of Visual Representations},
author={Chen, Ting and Kornblith, Simon and Norouzi, Mohammad and Hinton, Geoffrey},
journal={arXiv preprint arXiv:2002.05709},
year={2020}
}
@article{chen2020big,
title={Big Self-Supervised Models are Strong Semi-Supervised Learners},
author={Chen, Ting and Kornblith, Simon and Swersky, Kevin and Norouzi, Mohammad and Hinton, Geoffrey},
journal={arXiv preprint arXiv:2006.10029},
year={2020}
}
Disclaimer
This is not an official Google product.