pytorch-metric-learning
The easiest way to use deep metric learning in your application. Modular, flexible, and extensible. Written in PyTorch.
Top Related Projects
faiss: A library for efficient similarity search and clustering of dense vectors.
tensorflow/similarity: TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.
pytorch3d: PyTorch3D is FAIR's library of reusable components for deep learning with 3D data.
Qdrant: High-performance, massive-scale vector database for the next generation of AI. Also available in the cloud at https://cloud.qdrant.io/
annoy: Approximate Nearest Neighbors in C++/Python, optimized for memory usage and loading/saving to disk.
Quick Overview
PyTorch Metric Learning is a library that provides a comprehensive collection of metric learning losses, miners, samplers, and trainers for PyTorch. It aims to make metric learning more accessible and easier to implement in various deep learning projects, offering a wide range of algorithms and techniques for learning similarity metrics between data points.
Pros
- Extensive collection of metric learning algorithms and techniques
- Easy integration with PyTorch projects
- Well-documented with clear examples and tutorials
- Actively maintained and regularly updated
Cons
- Steep learning curve for beginners in metric learning
- Some advanced features may require in-depth understanding of metric learning concepts
- Performance can vary depending on the specific use case and dataset
Code Examples
- Basic usage of Contrastive Loss:
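import torch
from pytorch_metric_learning import losses
embeddings = torch.randn(32, 128)  # (batch_size, embedding_size); normally your model's output
labels = torch.randint(0, 10, (32,))  # one class label per embedding
loss_func = losses.ContrastiveLoss()
loss = loss_func(embeddings, labels)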
- Using a mining function with Triplet Margin Loss:
from pytorch_metric_learning import losses, miners
miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()
hard_pairs = miner(embeddings, labels)
loss = loss_func(embeddings, labels, hard_pairs)
- Implementing a basic metric learning trainer:
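from pytorch_metric_learning import trainers
# trunk, embedder, their optimizers, loss_func, miner and train_dataset are assumed to be defined already
trainer = trainers.MetricLossOnly(
    models={"trunk": trunk, "embedder": embedder},
    optimizers={"trunk_optimizer": trunk_optimizer, "embedder_optimizer": embedder_optimizer},
    batch_size=32,
    loss_funcs={"metric_loss": loss_func},
    dataset=train_dataset,
    mining_funcs={"tuple_miner": miner},
)
# the trainer builds its own dataloader from `dataset`, so train() only takes epoch settings
trainer.train(num_epochs=10)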
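To make the first example more concrete, here is a minimal plain-Python sketch of the per-pair rule a contrastive loss applies. It assumes Euclidean distance, a positive margin of 0 and a negative margin of 1 (believed to match the library's defaults, but check the ContrastiveLoss docs), averages only the non-zero losses, and reduces positive and negative pairs separately. For simplicity it skips any embedding normalization the library's default distance may apply. `contrastive_loss` and `avg_nonzero` are hypothetical helpers, not the library's implementation.

```python
import math
from itertools import combinations


def avg_nonzero(values):
    # average only the losses that are still "active" (greater than zero)
    nonzero = [v for v in values if v > 0]
    return sum(nonzero) / len(nonzero) if nonzero else 0.0


def contrastive_loss(embeddings, labels, pos_margin=0.0, neg_margin=1.0):
    """Hypothetical sketch of a contrastive loss over all pairs in a batch."""
    pos_losses, neg_losses = [], []
    for i, j in combinations(range(len(embeddings)), 2):
        d = math.dist(embeddings[i], embeddings[j])
        if labels[i] == labels[j]:
            # positive pair: penalize distance above pos_margin
            pos_losses.append(max(d - pos_margin, 0.0))
        else:
            # negative pair: penalize distance below neg_margin
            neg_losses.append(max(neg_margin - d, 0.0))
    return avg_nonzero(pos_losses) + avg_nonzero(neg_losses)
```

With well-separated classes and tight clusters every pair loss is zero, so the sketch returns 0.0; only same-class pairs that are far apart or different-class pairs that are closer than the margin contribute.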
Getting Started
To get started with PyTorch Metric Learning, follow these steps:
- Install the library:
pip install pytorch-metric-learning
- Import the necessary modules:
import torch
from pytorch_metric_learning import losses, miners, distances, reducers, testers
- Define your model, loss function, and miner:
model = YourModel()
loss_func = losses.ContrastiveLoss()
miner = miners.MultiSimilarityMiner()
- Use the loss function and miner in your training loop:
num_epochs = 10  # example value
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(num_epochs):
    for data, labels in dataloader:  # the dataloader must yield (data, labels) batches
        optimizer.zero_grad()
        embeddings = model(data)
        hard_pairs = miner(embeddings, labels)
        loss = loss_func(embeddings, labels, hard_pairs)
        loss.backward()
        optimizer.step()
This basic setup will get you started with metric learning using PyTorch Metric Learning. For more advanced usage and detailed examples, refer to the library's documentation.
Competitor Comparisons
faiss: A library for efficient similarity search and clustering of dense vectors.
Pros of faiss
- Highly optimized for large-scale similarity search and clustering
- Supports GPU acceleration for faster processing
- Provides efficient indexing structures for billion-scale datasets
Cons of faiss
- Steeper learning curve due to its C++ core and Python bindings
- Less integrated with PyTorch ecosystem
- Focused primarily on similarity search, not as flexible for general metric learning tasks
Code Comparison
faiss:
import faiss
index = faiss.IndexFlatL2(d)
index.add(xb)
D, I = index.search(xq, k)
pytorch-metric-learning:
from pytorch_metric_learning import losses, miners, distances
loss_func = losses.TripletMarginLoss()
mining_func = miners.TripletMarginMiner()
loss = loss_func(embeddings, labels, mining_func(embeddings, labels))
faiss excels at efficient similarity search and indexing on large-scale datasets, whereas pytorch-metric-learning offers a flexible, PyTorch-integrated approach to training metric learning models. faiss is better suited to production-scale similarity search; pytorch-metric-learning is more appropriate for research and experimentation with metric learning techniques.
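IndexFlatL2 performs exact, brute-force search. The plain-Python sketch below shows what `index.search(xq, k)` returns conceptually: for each query, the k smallest squared L2 distances (D) and the positions of the matching database vectors (I). faiss reports squared distances for this index type; the helper name `flat_l2_search` is hypothetical, and it is far slower than faiss's optimized kernels.

```python
def flat_l2_search(xb, xq, k):
    """Hypothetical brute-force equivalent of IndexFlatL2.add(xb) + search(xq, k)."""
    D, I = [], []
    for q in xq:
        # squared L2 distance from this query to every database vector
        dists = [sum((a - b) ** 2 for a, b in zip(q, x)) for x in xb]
        # indices of the k closest database vectors, nearest first
        order = sorted(range(len(xb)), key=lambda i: dists[i])[:k]
        D.append([dists[i] for i in order])
        I.append(order)
    return D, I


D, I = flat_l2_search([[0, 0], [1, 0], [5, 5]], [[1, 1]], k=2)
print(D, I)  # [[1, 2]] [[1, 0]]
```

In a metric learning workflow, the embeddings produced by a model trained with pytorch-metric-learning are what would be added to such an index.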
tensorflow/similarity: TensorFlow Similarity is a Python package focused on making similarity learning quick and easy.
Pros of tensorflow/similarity
- Seamless integration with TensorFlow ecosystem
- Optimized for distributed training and large-scale deployments
- Comprehensive documentation and examples for various use cases
Cons of tensorflow/similarity
- Less flexible compared to pytorch-metric-learning
- Smaller community and fewer third-party contributions
- Limited support for custom loss functions and mining strategies
Code Comparison
pytorch-metric-learning:
from pytorch_metric_learning import losses, miners, distances
distance = distances.CosineSimilarity()
loss_func = losses.TripletMarginLoss(distance=distance)
mining_func = miners.TripletMarginMiner(distance=distance)
loss = loss_func(embeddings, labels, mining_func(embeddings, labels))
tensorflow/similarity:
import tensorflow_similarity as tfsim
# inputs/outputs: the Keras input tensor and embedding output of your backbone (assumed defined)
model = tfsim.models.SimilarityModel(inputs, outputs)
loss = tfsim.losses.MultiSimilarityLoss(distance="cosine")
model.compile(optimizer="adam", loss=loss)
Both libraries offer concise ways to implement metric learning, but pytorch-metric-learning provides more granular control over individual components, while tensorflow/similarity offers a more integrated approach within the TensorFlow ecosystem.
pytorch3d: PyTorch3D is FAIR's library of reusable components for deep learning with 3D data.
Pros of pytorch3d
- Comprehensive 3D computer vision library with a wide range of functionalities
- Optimized for GPU acceleration, providing faster computations for 3D tasks
- Seamless integration with PyTorch ecosystem and neural networks
Cons of pytorch3d
- Steeper learning curve due to its focus on 3D vision tasks
- Less specialized in metric learning compared to pytorch-metric-learning
- Larger codebase and potentially higher overhead for simpler tasks
Code Comparison
pytorch3d:
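from pytorch3d.io import load_obj
from pytorch3d.structures import Meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import chamfer_distance
verts, faces, _ = load_obj("mesh.obj")  # hypothetical file path
mesh = Meshes(verts=[verts], faces=[faces.verts_idx])
sample_points = sample_points_from_meshes(mesh, num_samples=5000)
# target_points: a (1, P, 3) tensor of points to compare against (assumed defined)
loss, _ = chamfer_distance(sample_points, target_points)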
pytorch-metric-learning:
from pytorch_metric_learning import losses, miners, distances
distance = distances.CosineSimilarity()
loss_func = losses.TripletMarginLoss(distance=distance)
mining_func = miners.TripletMarginMiner(distance=distance)
hard_triplets = mining_func(embeddings, labels)
loss = loss_func(embeddings, labels, hard_triplets)
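The pytorch3d snippet above compares sampled mesh points to a target cloud with chamfer_distance. A minimal plain-Python sketch of the symmetric Chamfer distance follows, assuming squared Euclidean distances averaged over points in each direction and then summed; this is our understanding of PyTorch3D's default reductions, so check `point_reduction` and `batch_reduction` in its docs. `chamfer` is a hypothetical helper.

```python
def chamfer(x, y):
    """Hypothetical symmetric Chamfer distance between two point clouds."""

    def sq_dist(a, b):
        return sum((p - q) ** 2 for p, q in zip(a, b))

    # for every point in x, squared distance to its nearest neighbour in y (and vice versa)
    x_to_y = sum(min(sq_dist(p, q) for q in y) for p in x) / len(x)
    y_to_x = sum(min(sq_dist(q, p) for p in x) for q in y) / len(y)
    return x_to_y + y_to_x


print(chamfer([[0, 0, 0], [1, 0, 0]], [[0, 0, 0]]))  # 0.5
```

Unlike the pytorch-metric-learning losses, this compares point sets geometrically and needs no class labels.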
Qdrant - High-performance, massive-scale Vector Database for the next generation of AI. Also available in the cloud https://cloud.qdrant.io/
Pros of Qdrant
- Designed specifically for vector similarity search and storage
- Supports real-time updates and efficient querying
- Offers a RESTful API and various client libraries
Cons of Qdrant
- Limited to vector similarity search use cases
- Requires separate infrastructure setup and maintenance
- Less flexible for custom metric learning tasks
Code Comparison
Qdrant (Python client):
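from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams
client = QdrantClient("localhost", port=6333)
# the vector size must match the vectors you upsert (3 here)
client.create_collection("my_collection", vectors_config=VectorParams(size=3, distance=Distance.COSINE))
client.upsert("my_collection", points=[PointStruct(id=1, vector=[0.1, 0.2, 0.3], payload={"metadata": "value"})])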
pytorch-metric-learning:
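from pytorch_metric_learning import losses, miners, distances
distance = distances.CosineSimilarity()
loss_func = losses.TripletMarginLoss(distance=distance)
mining_func = miners.TripletMarginMiner(distance=distance)
loss = loss_func(embeddings, labels, mining_func(embeddings, labels))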
Summary
Qdrant is a specialized vector database for similarity search, offering efficient querying and storage. pytorch-metric-learning is a library for developing and training custom metric learning models. While Qdrant excels in production deployments for vector search, pytorch-metric-learning provides more flexibility for research and experimentation in metric learning tasks.
annoy: Approximate Nearest Neighbors in C++/Python, optimized for memory usage and loading/saving to disk.
Pros of annoy
- Lightweight and efficient for approximate nearest neighbor search
- Language-agnostic with bindings for multiple programming languages
- Optimized for memory usage and fast index building
Cons of annoy
- Limited to specific use cases (nearest neighbor search)
- Less flexibility for custom metric learning tasks
- Fewer built-in distance metrics compared to pytorch-metric-learning
Code Comparison
annoy:
import random
from annoy import AnnoyIndex
f = 40  # vector dimensionality
t = AnnoyIndex(f, 'angular')
for i in range(1000):
    v = [random.gauss(0, 1) for _ in range(f)]
    t.add_item(i, v)
t.build(10)  # build a forest of 10 trees
pytorch-metric-learning:
from pytorch_metric_learning import losses, miners, distances
distance = distances.CosineSimilarity()
loss_func = losses.TripletMarginLoss(distance=distance)
mining_func = miners.TripletMarginMiner(distance=distance)
loss = loss_func(embeddings, labels, mining_func(embeddings, labels))
Summary
annoy is a specialized library for efficient approximate nearest neighbor search, while pytorch-metric-learning is a more comprehensive framework for metric learning tasks in PyTorch. annoy excels in performance and language support, but pytorch-metric-learning offers greater flexibility and a wider range of metric learning techniques.
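The annoy snippet builds its index with the 'angular' metric. Annoy's documentation describes angular distance as sqrt(2(1 - cos(u, v))); the plain-Python sketch below computes it directly so it can be compared with cosine similarity as used in pytorch-metric-learning. `angular_distance` is a hypothetical helper, not part of either library.

```python
import math


def angular_distance(u, v):
    """sqrt(2 * (1 - cos(u, v))), the 'angular' metric as described by annoy."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    cos = dot / (norm_u * norm_v)
    # clamp tiny negative values caused by floating-point rounding
    return math.sqrt(max(2.0 * (1.0 - cos), 0.0))


print(angular_distance([1, 0], [-1, 0]))  # 2.0
```

For unit-length vectors, ||u - v||^2 = 2 - 2cos(u, v), so this is simply the Euclidean distance between the normalized vectors. An embedding space trained with cosine similarity therefore maps naturally onto an 'angular' annoy index.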
README
News
July 24: v2.6.0
- Changed the emb argument of DistributedLossWrapper.forward to embeddings, to be consistent with the rest of the library.
- Added a warning and early-return when DistributedLossWrapper is being used in a non-distributed setting.
- Thank you elisim.
April 1: v2.5.0
- Improved get_all_triplets_indices so that large batch sizes don't trigger the INT_MAX error.
- See the release notes.
- Thank you mkmenta.
Documentation
- View the documentation here
- View the installation instructions here
- View the available losses, miners etc. here
Google Colab Examples
See the examples folder for notebooks you can download or run on Google Colab.
PyTorch Metric Learning Overview
This library contains 9 modules, each of which can be used independently within your existing codebase, or combined together for a complete train/test workflow.
How loss functions work
Using losses and miners in your training loop
Let's initialize a plain TripletMarginLoss:
from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss()
To compute the loss in your training loop, pass in the embeddings computed by your model, and the corresponding labels. The embeddings should have size (N, embedding_size), and the labels should have size (N), where N is the batch size.
# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    loss = loss_func(embeddings, labels)
    loss.backward()
    optimizer.step()
The TripletMarginLoss computes all possible triplets within the batch, based on the labels you pass into it. Anchor-positive pairs are formed by embeddings that share the same label, and anchor-negative pairs are formed by embeddings that have different labels.
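The rule for forming triplets from labels can be sketched in plain Python. The library does this with tensor operations (for example via get_all_triplets_indices, mentioned in the v2.5.0 news entry), so treat `all_triplets` below as a hypothetical illustration of the rule rather than the library's code: the anchor and positive share a label and are distinct samples, and the negative has a different label.

```python
def all_triplets(labels):
    """Hypothetical: every (anchor, positive, negative) index triplet in a batch."""
    triplets = []
    n = len(labels)
    for a in range(n):
        for p in range(n):
            # positive: a different sample with the same label as the anchor
            if p == a or labels[p] != labels[a]:
                continue
            for neg in range(n):
                # negative: any sample with a different label
                if labels[neg] != labels[a]:
                    triplets.append((a, p, neg))
    return triplets


print(all_triplets([0, 0, 1]))  # [(0, 1, 2), (1, 0, 2)]
```

The count grows roughly cubically with batch size, which is one reason miners that select a smaller set of informative tuples can help.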
Sometimes it can help to add a mining function:
from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()
# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    hard_pairs = miner(embeddings, labels)
    loss = loss_func(embeddings, labels, hard_pairs)
    loss.backward()
    optimizer.step()
In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it's still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
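The pair/triplet conversion can be illustrated with a conceptual plain-Python sketch. It assumes pairs are lists of (anchor, other) index tuples and that a triplet is formed whenever a positive pair and a negative pair share the same anchor. The library's internal representation (tuples of index tensors) and its conversion functions may differ; `pairs_to_triplets` and `triplets_to_pairs` are hypothetical.

```python
def pairs_to_triplets(pos_pairs, neg_pairs):
    """Hypothetical: combine positive and negative pairs that share an anchor."""
    triplets = []
    for a1, p in pos_pairs:
        for a2, n in neg_pairs:
            if a1 == a2:  # same anchor, so (anchor, positive, negative) is a triplet
                triplets.append((a1, p, n))
    return triplets


def triplets_to_pairs(triplets):
    """Hypothetical: split each triplet into its anchor-positive and anchor-negative pairs."""
    pos_pairs = sorted({(a, p) for a, p, _ in triplets})
    neg_pairs = sorted({(a, n) for a, _, n in triplets})
    return pos_pairs, neg_pairs


print(pairs_to_triplets([(0, 1), (2, 3)], [(0, 2), (0, 3), (1, 2)]))  # [(0, 1, 2), (0, 1, 3)]
```

Note that positive pair (2, 3) produces no triplet because no mined negative pair uses anchor 2.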
Customizing loss functions
Loss functions can be customized using distances, reducers, and regularizers. In the diagram below, a miner finds the indices of hard pairs within a batch. These are used to index into the distance matrix, computed by the distance object. For this diagram, the loss function is pair-based, so it computes a loss per pair. In addition, a regularizer has been supplied, so a regularization loss is computed for each embedding in the batch. The per-pair and per-element losses are passed to the reducer, which (in this diagram) only keeps losses with a high value. The averages are computed for the high-valued pair and element losses, and are then added together to obtain the final loss.
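The data flow described above (mined indices, distance matrix, per-pair losses, per-element regularization losses, a reducer that keeps only high-valued losses, then a sum of the two averages) can be sketched in plain Python. This is a conceptual sketch under stated assumptions, not the library's implementation: the pair loss is a contrastive-style rule with a negative margin of 1, the regularizer is the L2 norm of each embedding, the reducer keeps losses above a hypothetical `low` threshold, and the mined pairs are supplied directly as index tuples.

```python
import math


def distance_matrix(embeddings):
    # Euclidean distance between every pair of embeddings
    return [[math.dist(u, v) for v in embeddings] for u in embeddings]


def keep_high_then_average(losses, low):
    # reducer step: discard losses at or below `low`, average the rest
    kept = [x for x in losses if x > low]
    return sum(kept) / len(kept) if kept else 0.0


def pipeline_loss(embeddings, labels, mined_pairs, low=0.1, neg_margin=1.0):
    dmat = distance_matrix(embeddings)
    pair_losses = []
    for i, j in mined_pairs:
        d = dmat[i][j]  # index into the distance matrix with the mined pair
        if labels[i] == labels[j]:
            pair_losses.append(d)  # positive pair: pull together
        else:
            pair_losses.append(max(neg_margin - d, 0.0))  # negative pair: push apart up to the margin
    # regularizer: one loss per embedding (here its L2 norm)
    element_losses = [math.dist(e, [0.0] * len(e)) for e in embeddings]
    return keep_high_then_average(pair_losses, low) + keep_high_then_average(element_losses, low)


print(pipeline_loss([[0, 0], [0, 0.5], [2, 0]], [0, 0, 1], [(0, 1), (0, 2), (1, 2)]))  # 1.75
```

In the example, only the positive pair (0, 1) survives the reducer (loss 0.5), while the element losses 0.5 and 2.0 average to 1.25.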
Now here's an example of a customized TripletMarginLoss:
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer
from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(distance = CosineSimilarity(),
reducer = ThresholdReducer(high=0.3),
embedding_regularizer = LpRegularizer())
This customized triplet loss has the following properties:
- The loss will be computed using cosine similarity instead of Euclidean distance.
- All triplet losses that are higher than 0.3 will be discarded.
- The embeddings will be L2 regularized.
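The three properties above can be mirrored in a plain-Python sketch. Because cosine similarity is a similarity (higher means closer), the triplet rule becomes max(s_an - s_ap + margin, 0). The sketch assumes a margin of 0.05 (believed to be the library's default), discards losses above `high` and averages what remains, and uses each embedding's plain L2 norm as the regularization term. These details are assumptions, and `custom_triplet_loss` is hypothetical; it returns the two terms separately rather than claiming how the library combines them.

```python
import math


def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def custom_triplet_loss(embeddings, triplets, margin=0.05, high=0.3):
    """Hypothetical sketch: cosine triplet loss, threshold reducer, L2 regularization term."""
    losses = []
    for a, p, n in triplets:
        s_ap = cosine(embeddings[a], embeddings[p])
        s_an = cosine(embeddings[a], embeddings[n])
        # the positive should be more similar to the anchor than the negative, by at least `margin`
        losses.append(max(s_an - s_ap + margin, 0.0))
    kept = [x for x in losses if x <= high]  # discard triplet losses above `high`
    triplet_term = sum(kept) / len(kept) if kept else 0.0
    reg_term = sum(math.sqrt(sum(x * x for x in e)) for e in embeddings) / len(embeddings)
    return triplet_term, reg_term


emb = [[1.0, 0.0], [0.6, 0.8], [0.8, 0.6]]
print(custom_triplet_loss(emb, [(0, 1, 2), (0, 2, 1)]))  # approximately (0.125, 1.0)
```

The triplet (0, 1, 2) has loss 0.8 - 0.6 + 0.05 = 0.25, which is kept because it is below 0.3; a triplet whose loss exceeded 0.3 would be dropped before averaging.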
Using loss functions for unsupervised / self-supervised learning
A SelfSupervisedLoss wrapper is provided for self-supervised learning:
from pytorch_metric_learning.losses import SelfSupervisedLoss, TripletMarginLoss
loss_func = SelfSupervisedLoss(TripletMarginLoss())
# your training for-loop
for i, data in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = your_model(data)
    augmented = your_model(your_augmentation(data))
    loss = loss_func(embeddings, augmented)
    loss.backward()
    optimizer.step()
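The wrapper needs no labels because each embedding is paired with its own augmented view. A plain-Python sketch of that labeling idea follows, assuming sample i and its augmentation both receive label i, so that (i, i) is the only positive pair for row i and every other combination is a negative. This reflects our understanding of how the wrapper works (see the SelfSupervisedLoss docs); `self_supervised_pairs` is hypothetical. In each pair, the first index refers to `embeddings` and the second to `augmented`.

```python
def self_supervised_pairs(n):
    """Hypothetical: positive/negative index pairs between n embeddings and their n augmentations."""
    labels = list(range(n))  # sample i gets its own label i
    ref_labels = list(range(n))  # augmented view i gets the same label i
    positives = [(i, j) for i in range(n) for j in range(n) if labels[i] == ref_labels[j]]
    negatives = [(i, j) for i in range(n) for j in range(n) if labels[i] != ref_labels[j]]
    return positives, negatives


positives, negatives = self_supervised_pairs(3)
print(positives)  # [(0, 0), (1, 1), (2, 2)]
```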
If you're interested in MoCo-style self-supervision, take a look at the MoCo on CIFAR10 notebook. It uses CrossBatchMemory to implement the momentum encoder queue, which means you can use any tuple loss, and any tuple miner to extract hard samples from the queue.
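CrossBatchMemory keeps embeddings from previous batches so that the current batch can be compared against a larger pool. The core data structure is a fixed-size first-in, first-out queue. The plain-Python sketch below shows the queue and how current-batch samples could be paired with memory entries by label. `EmbeddingQueue` and `pairs_against_memory` are hypothetical and are not the library's class, which stores tensors and works with the configured loss and miner.

```python
from collections import deque


class EmbeddingQueue:
    """Hypothetical FIFO memory of (embedding, label) pairs from past batches."""

    def __init__(self, memory_size):
        # a deque with maxlen drops the oldest entries once it is full
        self.buffer = deque(maxlen=memory_size)

    def enqueue(self, embeddings, labels):
        for emb, label in zip(embeddings, labels):
            self.buffer.append((emb, label))

    def labels(self):
        return [label for _, label in self.buffer]


def pairs_against_memory(batch_labels, memory_labels):
    """Index pairs (batch_idx, memory_idx), split into positives and negatives by label."""
    positives, negatives = [], []
    for i, label in enumerate(batch_labels):
        for j, mem_label in enumerate(memory_labels):
            (positives if label == mem_label else negatives).append((i, j))
    return positives, negatives


queue = EmbeddingQueue(memory_size=3)
queue.enqueue([[0.0], [1.0]], [0, 1])
queue.enqueue([[2.0], [3.0]], [2, 3])
print(queue.labels())  # [1, 2, 3]; the oldest entry (label 0) was evicted
```

A miner applied to these query-versus-memory pairs could then select the hardest ones, which is the idea behind combining any tuple miner with the queue.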
Highlights of the rest of the library
- For a convenient way to train your model, take a look at the trainers.
- Want to test your model's accuracy on a dataset? Try the testers.
- To compute the accuracy of an embedding space directly, use AccuracyCalculator.
If you're short of time and want a complete train/test workflow, check out the example Google Colab notebooks.
To learn more about all of the above, see the documentation.
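As a rough idea of what an embedding-space accuracy metric measures, here is a plain-Python sketch of precision at 1: for each query embedding, find the nearest reference embedding and check whether it has the same label. AccuracyCalculator reports a metric with this name among others, but this sketch is hypothetical. It keeps queries and references as separate sets and ignores details such as excluding self-matches when they coincide.

```python
import math


def precision_at_1(query, query_labels, reference, reference_labels):
    """Hypothetical: fraction of queries whose nearest reference shares their label."""
    correct = 0
    for q, q_label in zip(query, query_labels):
        # brute-force nearest neighbour by Euclidean distance
        nearest = min(range(len(reference)), key=lambda j: math.dist(q, reference[j]))
        correct += reference_labels[nearest] == q_label
    return correct / len(query)


ref, ref_labels = [[0, 0], [10, 10]], ["a", "b"]
print(precision_at_1([[1, 1], [9, 9], [0, 1]], ["a", "b", "b"], ref, ref_labels))  # 0.666...
```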
Installation
Required PyTorch version
pytorch-metric-learning >= v0.9.90 requires torch >= 1.6
pytorch-metric-learning < v0.9.90 doesn't have a version requirement, but was tested with torch >= 1.2
Other dependencies: numpy, scikit-learn, tqdm, torchvision
Pip
pip install pytorch-metric-learning
To get the latest dev version:
pip install pytorch-metric-learning --pre
To install on Windows:
pip install torch===1.6.0 torchvision===0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install pytorch-metric-learning
To install with evaluation and logging capabilities
(This will install the unofficial pypi version of faiss-gpu, plus record-keeper and tensorboard):
pip install pytorch-metric-learning[with-hooks]
To install with evaluation and logging capabilities (CPU)
(This will install the unofficial pypi version of faiss-cpu, plus record-keeper and tensorboard):
pip install pytorch-metric-learning[with-hooks-cpu]
Conda
conda install -c conda-forge pytorch-metric-learning
To use the testing module, you'll need faiss, which can be installed via conda as well. See the installation instructions for faiss.
Benchmark results
See powerful-benchmarker to view benchmark results and to use the benchmarking tool.
Development
Development is done on the dev branch:
git checkout dev
Unit tests can be run with the default unittest library:
python -m unittest discover
You can specify the test datatypes and test device as environment variables. For example, to test using float32 and float64 on the CPU:
TEST_DTYPES=float32,float64 TEST_DEVICE=cpu python -m unittest discover
To run a single test file instead of the entire test suite, specify the file name:
python -m unittest tests/losses/test_angular_loss.py
Code is formatted using black and isort:
pip install black isort
./format_code.sh
Acknowledgements
Contributors
Thanks to the contributors who made pull requests!
Facebook AI
Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI where I received valuable feedback from Ser-Nam, and his team of computer vision and machine learning engineers and research scientists. In particular, thanks to Ashish Shah and Austin Reiter for reviewing my code during its early stages of development.
Open-source repos
This library contains code that has been adapted and modified from the following great open-source repos:
- https://github.com/bnu-wangxun/Deep_Metric
- https://github.com/chaoyuaw/incubator-mxnet/blob/master/example/gluon/embedding_learning
- https://github.com/facebookresearch/deepcluster
- https://github.com/geonm/proxy-anchor-loss
- https://github.com/idstcv/SoftTriple
- https://github.com/kunhe/FastAP-metric-learning
- https://github.com/ronekko/deep_metric_learning
- https://github.com/tjddus9597/Proxy-Anchor-CVPR2020
- http://kaizhao.net/regularface
- https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts
Logo
Thanks to Jeff Musgrave for designing the logo.
Citing this library
If you'd like to cite pytorch-metric-learning in your paper, you can use this bibtex:
@article{Musgrave2020PyTorchML,
title={PyTorch Metric Learning},
author={Kevin Musgrave and Serge J. Belongie and Ser-Nam Lim},
journal={ArXiv},
year={2020},
volume={abs/2008.09164}
}