pytorch-lightning
Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
Top Related Projects
Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
The fastai deep learning library
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Accelerated deep learning R&D
Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
A hyperparameter optimization framework
Quick Overview
PyTorch Lightning is a high-level deep learning framework that simplifies the process of building, training, and deploying PyTorch models. It provides a modular and scalable architecture, allowing developers to focus on the core aspects of their models rather than boilerplate code.
Pros
- Simplifies PyTorch: PyTorch Lightning abstracts away many of the low-level details of PyTorch, making it easier for developers to focus on the core aspects of their models.
- Modular and Scalable: The framework's modular design allows for easy customization and scalability, making it suitable for a wide range of deep learning projects.
- Efficient Training: PyTorch Lightning includes built-in support for efficient training, including automatic mixed precision, gradient accumulation, and gradient clipping.
- Extensive Ecosystem: The project has a large and active community, with a wide range of pre-built components and integrations with other popular deep learning libraries and tools.
Cons
- Steeper Learning Curve: While PyTorch Lightning simplifies certain aspects of PyTorch, it also introduces its own set of concepts and conventions that developers need to learn.
- Limited Flexibility: In some cases, the abstraction provided by PyTorch Lightning may limit the flexibility and control that developers have over their models, compared to working directly with PyTorch.
- Potential Performance Overhead: The additional layer of abstraction provided by PyTorch Lightning may introduce some performance overhead, especially for simple or highly customized models.
- Dependency on PyTorch: PyTorch Lightning is tightly coupled with the PyTorch library, so developers who are not familiar with PyTorch may face a steeper learning curve.
Code Examples
Here are a few code examples demonstrating the usage of PyTorch Lightning:
Basic Model Definition
import torch
import torch.nn as nn
import pytorch_lightning as pl

class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        return x
This code defines a simple PyTorch Lightning module for classifying MNIST digits.
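The module above only defines the network and its forward pass; before it can be trained with the Trainer shown next, it also needs a training_step and configure_optimizers. A minimal sketch, assuming standard (image, label) MNIST batches:
import torch.nn.functional as F

class LightningMNISTClassifier(pl.LightningModule):
    # __init__ and forward as defined above ...

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)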
Training a Model
from pytorch_lightning.trainer import Trainer
model = LightningMNISTClassifier()
trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)
This code creates an instance of LightningMNISTClassifier and trains it using the Trainer class, which handles the training loop and other boilerplate tasks.
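The train_dataloader and val_dataloader passed to fit() are assumed to exist; a minimal sketch of building them with torchvision (not part of the original example):
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

dataset = datasets.MNIST(".", download=True, transform=transforms.ToTensor())
train_set, val_set = random_split(dataset, [55000, 5000])
train_dataloader = DataLoader(train_set, batch_size=64)
val_dataloader = DataLoader(val_set, batch_size=64)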
Logging and Callbacks
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

logger = TensorBoardLogger("tb_logs", name="my_model")
early_stop_callback = EarlyStopping(monitor='val_loss', patience=3)
checkpoint_callback = ModelCheckpoint(monitor='val_acc')

trainer = Trainer(
    max_epochs=10,
    logger=logger,
    callbacks=[early_stop_callback, checkpoint_callback]
)
This code demonstrates how to use PyTorch Lightning's logging and callback features, including TensorBoard logging, early stopping, and model checkpointing.
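Note that EarlyStopping and ModelCheckpoint monitor val_loss and val_acc, so the LightningModule must log those keys. A minimal sketch of a validation_step that does so, continuing the classifier above:
def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    acc = (logits.argmax(dim=1) == y).float().mean()
    # values logged with self.log are visible to loggers and callbacks
    self.log("val_loss", loss)
    self.log("val_acc", acc)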
Getting Started
To get started with PyTorch Lightning, follow these steps:
- Install the library using pip:
pip install pytorch-lightning
- Create a PyTorch Lightning module by subclassing pl.LightningModule and implementing the necessary methods, such as forward, training_step, validation_step, and configure_optimizers.
- Prepare your data loaders for training and validation.
- Create a Trainer instance and call the fit() method to train your model:
from pytorch_lightning.trainer import Trainer

model = LightningMNISTClassifier()  # or your own pl.LightningModule subclass
trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)
Competitor Comparisons
Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
Pros of pytorch-lightning
- Simplified and organized PyTorch code structure
- Automatic GPU/TPU support and distributed training
- Extensive ecosystem of plugins and integrations
Cons of pytorch-lightning
- Steeper learning curve for beginners
- Less flexibility in certain custom training scenarios
- Potential overhead for simple projects
Code Comparison
pytorch-lightning:
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
PyTorch (without pytorch-lightning):
for epoch in range(num_epochs):
    for batch in dataloader:
        x, y = batch
        optimizer.zero_grad()
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()
pytorch-lightning abstracts away much of the boilerplate code, allowing developers to focus on the core logic of their models. It provides a more structured approach to PyTorch development, which can lead to cleaner and more maintainable code. However, this abstraction may come at the cost of some flexibility in certain scenarios.
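For completeness: the optimizer handling that the manual loop does explicitly (zero_grad, backward, step) is driven by Lightning once configure_optimizers is defined. A minimal sketch, assuming the LitModel above:
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    # training_step as above ...

    def configure_optimizers(self):
        # Lightning calls zero_grad(), backward() and step() around training_step
        return torch.optim.Adam(self.parameters(), lr=1e-3)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(LitModel(), dataloader)  # dataloader as in the plain PyTorch loop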
The fastai deep learning library
Pros of fastai
- Simpler API with high-level abstractions for common tasks
- Integrated data preprocessing and augmentation pipelines
- Extensive documentation and tutorials for beginners
Cons of fastai
- Less flexibility for custom architectures and workflows
- Smaller community and ecosystem compared to PyTorch Lightning
- Steeper learning curve for those familiar with raw PyTorch
Code Comparison
fastai:
from fastai.vision.all import *
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
PyTorch Lightning:
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

trainer = pl.Trainer(max_epochs=4)
trainer.fit(model, train_loader)
The fastai code is more concise and abstracts away many details, while PyTorch Lightning provides a more structured approach with explicit definition of training steps and model components.
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Pros of transformers
- Extensive pre-trained model library for NLP tasks
- Easy-to-use API for fine-tuning and inference
- Strong community support and frequent updates
Cons of transformers
- Focused primarily on NLP, less versatile for other domains
- Can be resource-intensive for large models
- Steeper learning curve for customizing model architectures
Code Comparison
transformers:
from transformers import BertTokenizer, BertForSequenceClassification
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)
pytorch-lightning:
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.model(batch)
        return loss

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader)
Both repositories serve different purposes: transformers focuses on pre-trained models for NLP tasks, while pytorch-lightning provides a high-level interface for organizing PyTorch code and simplifying the training process across various domains.
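The two libraries are also commonly combined: a transformers model can be wrapped in a LightningModule so that Lightning handles the training loop. A minimal sketch, assuming batches are dicts containing input_ids, attention_mask, and labels:
import torch
import pytorch_lightning as pl
from transformers import BertForSequenceClassification

class LitBert(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    def training_step(self, batch, batch_idx):
        # the model returns a loss when labels are provided
        outputs = self.model(**batch)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-5)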
Accelerated deep learning R&D
Pros of Catalyst
- More flexible experiment configuration with YAML files
- Built-in support for various ML tasks (classification, segmentation, etc.)
- Extensive set of callbacks and metrics out-of-the-box
Cons of Catalyst
- Steeper learning curve due to more complex API
- Smaller community and fewer resources compared to PyTorch Lightning
- Less frequent updates and maintenance
Code Comparison
Catalyst:
from catalyst import dl
class CustomRunner(dl.Runner):
    def predict_batch(self, batch):
        return self.model(batch[0])

runner = CustomRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=10,
    callbacks=[dl.AccuracyCallback()]
)
PyTorch Lightning:
import pytorch_lightning as pl
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)
Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
Pros of Ray
- More general-purpose distributed computing framework, supporting various ML tasks beyond just PyTorch
- Offers advanced features like distributed hyperparameter tuning and reinforcement learning
- Provides a flexible and scalable architecture for distributed applications
Cons of Ray
- Steeper learning curve due to its broader scope and more complex API
- Less specialized for PyTorch-specific workflows compared to PyTorch Lightning
- May require more setup and configuration for simple PyTorch projects
Code Comparison
Ray example:
import ray
@ray.remote
def train_model(hyperparameters):
    # Model training code here
    return results

results = ray.get([train_model.remote(hp) for hp in hyperparameters])
PyTorch Lightning example:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
trainer = Trainer(gpus=1, callbacks=[EarlyStopping()])
trainer.fit(model)
Ray is more flexible but requires explicit distributed setup, while PyTorch Lightning provides a higher-level abstraction for PyTorch training workflows.
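For comparison, a minimal sketch of what multi-GPU training looks like on the Lightning side, assuming a defined LightningModule and dataloader (the model code itself does not change):
import pytorch_lightning as pl

# the same script scales across GPUs via Trainer arguments
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp")
trainer.fit(model, train_dataloader)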
A hyperparameter optimization framework
Pros of Optuna
- More flexible and framework-agnostic, supporting various ML libraries beyond PyTorch
- Provides advanced hyperparameter optimization techniques like pruning and parallel optimization
- Offers a wider range of built-in optimization algorithms (e.g., TPE, CMA-ES, NSGA-II)
Cons of Optuna
- Lacks built-in training loop abstractions and progress tracking features
- Requires more manual setup for integrating with deep learning workflows
- Less focus on distributed training and multi-GPU support out of the box
Code Comparison
Optuna:
import optuna
def objective(trial):
    x = trial.suggest_float('x', -10, 10)
    return (x - 2) ** 2

study = optuna.create_study()
study.optimize(objective, n_trials=100)
PyTorch Lightning:
import pytorch_lightning as pl
class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self(x), y)
        return loss

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader)
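The two can also be used together, with Optuna driving hyperparameter search over Lightning training runs. A minimal sketch, assuming a hypothetical MyModel that accepts a learning rate and logs val_loss in its validation_step:
import optuna
import pytorch_lightning as pl

def objective(trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    model = MyModel(lr=lr)  # hypothetical constructor argument
    trainer = pl.Trainer(max_epochs=5, enable_progress_bar=False)
    trainer.fit(model, train_dataloader, val_dataloader)
    # assumes "val_loss" is logged during validation
    return trainer.callback_metrics["val_loss"].item()

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)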
README
The deep learning framework to pretrain, finetune and deploy AI models.
NEW - Deploying models? Check out LitServe, the PyTorch Lightning for model serving.
Quick start • Examples • PyTorch Lightning • Fabric • Lightning AI • Community • Docs
Lightning has 2 core packages
PyTorch Lightning: Train and deploy PyTorch at scale.
Lightning Fabric: Expert control.
Lightning gives you granular control over how much abstraction you want to add over PyTorch.
Quick start
Install Lightning:
pip install lightning
Advanced install options
Install with optional dependencies
pip install lightning['extra']
Conda
conda install lightning -c conda-forge
Install stable version
Install future release from the source
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U
Install bleeding-edge
Install nightly from the source (no guarantees)
pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U
or from testing PyPI
pip install -iU https://test.pypi.org/simple/ pytorch-lightning
PyTorch Lightning example
Define the training workflow. Here's a toy example (explore real examples):
# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: Define a LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (i.e., an LLM, diffusion model, autoencoder, or simple image classifier).
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: Train
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))
Run the model on your terminal
pip install torchvision
python main.py
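Once training finishes, the learned encoder can be used directly for inference. A minimal sketch, continuing main.py above:
# embed a batch of (fake) flattened images with the trained encoder
autoencoder.eval()
fake_image_batch = torch.rand(4, 28 * 28)
with torch.no_grad():
    embeddings = autoencoder.encoder(fake_image_batch)
print(embeddings.shape)  # torch.Size([4, 3])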
Why PyTorch Lightning?
PyTorch Lightning is just organized PyTorch - Lightning disentangles PyTorch code to decouple the science from the engineering.
Examples
Explore various types of training possible with PyTorch Lightning. Pretrain and finetune ANY kind of model to perform ANY task like classification, segmentation, summarization and more:
Task | Description
---|---
Hello world | Pretrain - Hello world example
Image classification | Finetune - ResNet-34 model to classify images of cars
Image segmentation | Finetune - ResNet-50 model to segment images
Object detection | Finetune - Faster R-CNN model to detect objects
Text classification | Finetune - text classifier (BERT model)
Text summarization | Finetune - text summarization (Hugging Face transformer model)
Audio generation | Finetune - audio generator (transformer model)
LLM finetuning | Finetune - LLM (Meta Llama 3.1 8B)
Image generation | Pretrain - Image generator (diffusion model)
Recommendation system | Train - recommendation system (factorization and embedding)
Time-series forecasting | Train - Time-series forecasting with LSTM
Advanced features
Lightning has 40+ advanced features designed for professional AI research at scale.
Here are some examples:
Train on 1000s of GPUs without code changes
# 8 GPUs
# no code changes needed
trainer = Trainer(accelerator="gpu", devices=8)
# 256 GPUs
trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
Train on other accelerators like TPUs without code changes
# no code changes needed
trainer = Trainer(accelerator="tpu", devices=8)
16-bit precision
# no code changes needed
trainer = Trainer(precision=16)
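Gradient accumulation and gradient clipping are likewise plain Trainer arguments; a minimal sketch:
# accumulate gradients over 4 batches and clip the gradient norm at 0.5
trainer = Trainer(accumulate_grad_batches=4, gradient_clip_val=0.5)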
Experiment managers
from lightning import loggers

# tensorboard
trainer = Trainer(logger=loggers.TensorBoardLogger("logs/"))
# weights and biases
trainer = Trainer(logger=loggers.WandbLogger())
# comet
trainer = Trainer(logger=loggers.CometLogger())
# mlflow
trainer = Trainer(logger=loggers.MLFlowLogger())
# neptune
trainer = Trainer(logger=loggers.NeptuneLogger())
# ... and dozens more
Early Stopping
from lightning.pytorch.callbacks import EarlyStopping

es = EarlyStopping(monitor="val_loss")
trainer = Trainer(callbacks=[es])
Checkpointing
from lightning.pytorch.callbacks import ModelCheckpoint

checkpointing = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpointing])
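After training, the best checkpoint tracked by the callback can be reloaded; a minimal sketch, assuming the LitAutoEncoder from the quick-start example:
# reload the weights that scored best on the monitored metric
best_model = LitAutoEncoder.load_from_checkpoint(checkpointing.best_model_path)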
Export to torchscript (JIT) (production use)
# torchscript
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
Export to ONNX (production use)
# onnx
import os, tempfile
import torch

with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
    autoencoder = LitAutoEncoder()
    input_sample = torch.randn((1, 28 * 28))  # sample must match the encoder's input size
    autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
    os.path.isfile(tmpfile.name)
Advantages over unstructured PyTorch
- Models become hardware agnostic
- Code is clear to read because engineering code is abstracted away
- Easier to reproduce
- Make fewer mistakes because lightning handles the tricky engineering
- Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate
- Lightning has dozens of integrations with popular machine learning tools.
- Tested rigorously with every new PR. We test every combination of PyTorch and Python supported versions, every OS, multi GPUs and even TPUs.
- Minimal running speed overhead (about 300 ms per epoch compared with pure PyTorch).
Lightning Fabric: Expert control
Run on any device at any scale with expert-level control over PyTorch training loop and scaling strategy. You can even write your own Trainer.
Fabric is designed for the most complex models like foundation model scaling, LLMs, diffusion, transformers, reinforcement learning, active learning. Of any size.
The original README includes a side-by-side table ("What to change" vs. "Resulting Fabric Code (copy me!)") showing the few-line diff needed to convert a plain PyTorch loop to Fabric.
Key features
Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training
# Use your available hardware
# no code changes needed
fabric = Fabric()
# Run on GPUs (CUDA or MPS)
fabric = Fabric(accelerator="gpu")
# 8 GPUs
fabric = Fabric(accelerator="gpu", devices=8)
# 256 GPUs, multi-node
fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)
# Run on TPUs
fabric = Fabric(accelerator="tpu")
Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box
# Use state-of-the-art distributed training techniques
fabric = Fabric(strategy="ddp")
fabric = Fabric(strategy="deepspeed")
fabric = Fabric(strategy="fsdp")
# Switch the precision
fabric = Fabric(precision="16-mixed")
fabric = Fabric(precision="64")
All the device logic boilerplate is handled for you
# no more of this!
- model.to(device)
- batch.to(device)
Build your own custom Trainer using Fabric primitives for training, checkpointing, logging, and more
import lightning as L

class MyCustomTrainer:
    def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
        self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)

    def fit(self, model, optimizer, dataloader, max_epochs):
        self.fabric.launch()
        # prepare the model, optimizer and dataloader for the chosen device/strategy
        model, optimizer = self.fabric.setup(model, optimizer)
        dataloader = self.fabric.setup_dataloaders(dataloader)
        model.train()

        for epoch in range(max_epochs):
            for batch in dataloader:
                input, target = batch
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)  # loss_fn: a user-supplied loss function
                self.fabric.backward(loss)
                optimizer.step()
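A usage sketch of the custom trainer above, with a hypothetical toy dataset and loss function (not part of the original example):
import torch
import torch.nn as nn
import torch.utils.data as data

loss_fn = nn.functional.mse_loss  # user-supplied loss, referenced inside fit()

model = nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataset = data.TensorDataset(torch.randn(64, 32), torch.randn(64, 2))  # hypothetical toy data
dataloader = data.DataLoader(dataset, batch_size=8)

trainer = MyCustomTrainer(accelerator="cpu")
trainer.fit(model, optimizer, dataloader, max_epochs=3)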
You can find a more extensive example in our examples
Examples
Self-supervised Learning
Convolutional Architectures
Reinforcement Learning
GANs
Classic ML
Continuous Integration
Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against major Python and PyTorch versions.
*Codecov coverage is above 90%, but build delays may show a lower number.
Current build statuses
Lightning tracks build status for Linux py3.9 (GPUs), Linux py3.9 (TPUs), and Linux, OSX, and Windows across multiple Python versions, each tested against PyTorch 1.13, 2.0, and 2.1 (status badges are shown in the original README).
Community
The Lightning community is maintained by
- 10+ core contributors: a mix of professional engineers, research scientists, and Ph.D. students from top AI labs.
- 800+ community contributors.
Want to help us build Lightning and reduce boilerplate for thousands of researchers? Learn how to make your first contribution here
Lightning is also part of the PyTorch ecosystem which requires projects to have solid testing, documentation and support.
Asking for help
If you have any questions please: