
Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.


Top Related Projects

  • Lightning-AI/lightning: Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
  • fastai/fastai: The fastai deep learning library
  • huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
  • catalyst-team/catalyst: Accelerated deep learning R&D
  • ray-project/ray: Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
  • optuna/optuna: A hyperparameter optimization framework

Quick Overview

PyTorch Lightning is a high-level deep learning framework that simplifies the process of building, training, and deploying PyTorch models. It provides a modular and scalable architecture, allowing developers to focus on the core aspects of their models rather than boilerplate code.

Pros

  • Simplifies PyTorch: PyTorch Lightning abstracts away many of the low-level details of PyTorch, making it easier for developers to focus on the core aspects of their models.
  • Modular and Scalable: The framework's modular design allows for easy customization and scalability, making it suitable for a wide range of deep learning projects.
  • Efficient Training: PyTorch Lightning includes built-in support for efficient training, including automatic mixed precision, gradient accumulation, and gradient clipping (see the sketch after this list).
  • Extensive Ecosystem: The project has a large and active community, with a wide range of pre-built components and integrations with other popular deep learning libraries and tools.
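
As a rough illustration of the efficient-training features mentioned above, these are typically enabled through Trainer arguments; the values below are arbitrary examples rather than recommendations:

import pytorch_lightning as pl

# Example Trainer configuration (values are illustrative)
trainer = pl.Trainer(
    max_epochs=10,
    precision="16-mixed",       # automatic mixed precision (precision=16 on older releases)
    accumulate_grad_batches=4,  # accumulate gradients over 4 batches before each optimizer step
    gradient_clip_val=1.0,      # clip the gradient norm at 1.0
)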

Cons

  • Steeper Learning Curve: While PyTorch Lightning simplifies certain aspects of PyTorch, it also introduces its own set of concepts and conventions that developers need to learn.
  • Limited Flexibility: In some cases, the abstraction provided by PyTorch Lightning may limit the flexibility and control that developers have over their models, compared to working directly with PyTorch.
  • Potential Performance Overhead: The additional layer of abstraction provided by PyTorch Lightning may introduce some performance overhead, especially for simple or highly customized models.
  • Dependency on PyTorch: PyTorch Lightning is tightly coupled with the PyTorch library, so developers who are not familiar with PyTorch may face a steeper learning curve.

Code Examples

Here are a few code examples demonstrating the usage of PyTorch Lightning:

Basic Model Definition

import torch
import torch.nn as nn
import pytorch_lightning as pl

class LightningMNISTClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = torch.relu(x)
        x = self.layer_2(x)
        return x

This code defines a simple PyTorch Lightning module for classifying MNIST digits.

Training a Model

from pytorch_lightning.trainer import Trainer

model = LightningMNISTClassifier()
trainer = Trainer(max_epochs=10)
# train_dataloader and val_dataloader are ordinary PyTorch DataLoaders prepared beforehand
trainer.fit(model, train_dataloader, val_dataloader)

This code creates an instance of the LightningMNISTClassifier model and trains it using the Trainer class, which handles the training loop and other boilerplate tasks.

Logging and Callbacks

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

logger = TensorBoardLogger("tb_logs", name="my_model")
early_stop_callback = EarlyStopping(monitor='val_loss', patience=3)
checkpoint_callback = ModelCheckpoint(monitor='val_acc')

trainer = Trainer(
    max_epochs=10,
    logger=logger,
    callbacks=[early_stop_callback, checkpoint_callback]
)

This code demonstrates how to use PyTorch Lightning's logging and callback features, including TensorBoard logging, early stopping, and model checkpointing.

Getting Started

To get started with PyTorch Lightning, follow these steps:

  1. Install the library using pip:

pip install pytorch-lightning

  2. Create a PyTorch Lightning module by subclassing pl.LightningModule and implementing the necessary methods, such as forward, training_step, validation_step, and configure_optimizers (a fuller end-to-end sketch follows this list).

  3. Prepare your data loaders for training and validation.

  4. Create a Trainer instance and call the fit() method to train your model:

from pytorch_lightning.trainer import Trainer

model = LightningMNISTClassifier()
trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloader, val_dataloader)
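
Putting steps 2-4 together, here is a minimal, hypothetical end-to-end sketch; the architecture, batch size, and learning rate are illustrative choices:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl


class LitMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", F.cross_entropy(self(x), y))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# Step 3: prepare data loaders (paths and split sizes are illustrative)
dataset = datasets.MNIST(".", download=True, transform=transforms.ToTensor())
train_set, val_set = random_split(dataset, [55000, 5000])
train_dataloader = DataLoader(train_set, batch_size=64)
val_dataloader = DataLoader(val_set, batch_size=64)

# Step 4: train
trainer = pl.Trainer(max_epochs=10)
trainer.fit(LitMNIST(), train_dataloader, val_dataloader)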

Competitor Comparisons

Lightning-AI/lightning: Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.

Pros of pytorch-lightning

  • Simplified and organized PyTorch code structure
  • Automatic GPU/TPU support and distributed training
  • Extensive ecosystem of plugins and integrations

Cons of pytorch-lightning

  • Steeper learning curve for beginners
  • Less flexibility in certain custom training scenarios
  • Potential overhead for simple projects

Code Comparison

pytorch-lightning:

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

PyTorch (without pytorch-lightning):

for epoch in range(num_epochs):
    for batch in dataloader:
        x, y = batch
        optimizer.zero_grad()
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()

pytorch-lightning abstracts away much of the boilerplate code, allowing developers to focus on the core logic of their models. It provides a more structured approach to PyTorch development, which can lead to cleaner and more maintainable code. However, this abstraction may come at the cost of some flexibility in certain scenarios.


fastai/fastai: The fastai deep learning library

Pros of fastai

  • Simpler API with high-level abstractions for common tasks
  • Integrated data preprocessing and augmentation pipelines
  • Extensive documentation and tutorials for beginners

Cons of fastai

  • Less flexibility for custom architectures and workflows
  • Smaller community and ecosystem compared to PyTorch Lightning
  • Steeper learning curve for those familiar with raw PyTorch

Code Comparison

fastai:

from fastai.vision.all import *
dls = ImageDataLoaders.from_folder(path)
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fit_one_cycle(4)

PyTorch Lightning:

import pytorch_lightning as pl
class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
trainer = pl.Trainer(max_epochs=4)
trainer.fit(model, train_loader)

The fastai code is more concise and abstracts away many details, while PyTorch Lightning provides a more structured approach with explicit definition of training steps and model components.

huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

Pros of transformers

  • Extensive pre-trained model library for NLP tasks
  • Easy-to-use API for fine-tuning and inference
  • Strong community support and frequent updates

Cons of transformers

  • Focused primarily on NLP, less versatile for other domains
  • Can be resource-intensive for large models
  • Steeper learning curve for customizing model architectures

Code Comparison

transformers:

from transformers import BertTokenizer, BertForSequenceClassification
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs)

pytorch-lightning:

import pytorch_lightning as pl
class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = self.model(batch)
        return loss
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader)

Both repositories serve different purposes: transformers focuses on pre-trained models for NLP tasks, while pytorch-lightning provides a high-level interface for organizing PyTorch code and simplifying the training process across various domains.
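
In practice the two are often combined: a pretrained transformers model can be wrapped in a LightningModule so that the Trainer handles the training loop. A minimal sketch, assuming batches are dictionaries of tokenized inputs that include a labels key:

import torch
import pytorch_lightning as pl
from transformers import BertForSequenceClassification


class LitBert(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Hugging Face model; returns an output object with a .loss when labels are provided
        self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased")

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)  # batch: dict with input_ids, attention_mask, labels
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=2e-5)


# trainer = pl.Trainer(max_epochs=3)
# trainer.fit(LitBert(), train_dataloader)  # train_dataloader yields tokenized batches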

catalyst-team/catalyst: Accelerated deep learning R&D

Pros of Catalyst

  • More flexible experiment configuration with YAML files
  • Built-in support for various ML tasks (classification, segmentation, etc.)
  • Extensive set of callbacks and metrics out-of-the-box

Cons of Catalyst

  • Steeper learning curve due to more complex API
  • Smaller community and fewer resources compared to PyTorch Lightning
  • Less frequent updates and maintenance

Code Comparison

Catalyst:

from catalyst import dl

class CustomRunner(dl.Runner):
    def predict_batch(self, batch):
        return self.model(batch[0])

runner = CustomRunner()
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=10,
    callbacks=[dl.AccuracyCallback()]
)

PyTorch Lightning:

import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader)

ray-project/ray: Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.

Pros of Ray

  • More general-purpose distributed computing framework, supporting various ML tasks beyond just PyTorch
  • Offers advanced features like distributed hyperparameter tuning and reinforcement learning
  • Provides a flexible and scalable architecture for distributed applications

Cons of Ray

  • Steeper learning curve due to its broader scope and more complex API
  • Less specialized for PyTorch-specific workflows compared to PyTorch Lightning
  • May require more setup and configuration for simple PyTorch projects

Code Comparison

Ray example:

import ray

@ray.remote
def train_model(hyperparameters):
    # Model training code here
    return results

results = ray.get([train_model.remote(hp) for hp in hyperparameters])

PyTorch Lightning example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(gpus=1, callbacks=[EarlyStopping()])
trainer.fit(model)

Ray is more flexible but requires explicit distributed setup, while PyTorch Lightning provides a higher-level abstraction for PyTorch training workflows.
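
To illustrate how the two can be combined rather than chosen between, the sketch below uses a Ray remote task to run several independent Lightning training jobs in parallel; MyLitModel, the dataloader, and the learning rates are hypothetical placeholders:

import ray
import pytorch_lightning as pl

ray.init()


@ray.remote
def train_one(lr):
    # MyLitModel and train_dataloader are hypothetical placeholders; real code would
    # construct them inside this function or pass them in as arguments.
    model = MyLitModel(lr=lr)
    trainer = pl.Trainer(max_epochs=1, enable_progress_bar=False)
    trainer.fit(model, train_dataloader)
    # callback_metrics holds values logged via self.log() during training
    return {name: float(value) for name, value in trainer.callback_metrics.items()}


# Launch three runs in parallel and gather the logged metrics from each
results = ray.get([train_one.remote(lr) for lr in (1e-2, 1e-3, 1e-4)])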


optuna/optuna: A hyperparameter optimization framework

Pros of Optuna

  • More flexible and framework-agnostic, supporting various ML libraries beyond PyTorch
  • Provides advanced hyperparameter optimization techniques like pruning and parallel optimization
  • Offers a wider range of built-in optimization algorithms (e.g., TPE, CMA-ES, NSGA-II)

Cons of Optuna

  • Lacks built-in training loop abstractions and progress tracking features
  • Requires more manual setup for integrating with deep learning workflows
  • Less focus on distributed training and multi-GPU support out of the box

Code Comparison

Optuna:

import optuna

def objective(trial):
    x = trial.suggest_float('x', -10, 10)
    return (x - 2) ** 2

study = optuna.create_study()
study.optimize(objective, n_trials=100)

PyTorch Lightning:

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss_fn(self(x), y)
        return loss

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloader)
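
The manual integration noted in the cons is typically small: an Optuna objective can build a fresh model and Trainer per trial and return a logged validation metric. A minimal sketch, assuming a hypothetical MyLitModel that accepts the sampled learning rate, existing train/val dataloaders, and a model that logs val_loss:

import optuna
import pytorch_lightning as pl


def objective(trial):
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    model = MyLitModel(lr=lr)  # hypothetical LightningModule accepting a learning rate
    trainer = pl.Trainer(max_epochs=3, enable_progress_bar=False)
    trainer.fit(model, train_dataloader, val_dataloader)  # dataloaders assumed to exist
    # callback_metrics holds values logged with self.log(); "val_loss" assumes the model logs it
    return trainer.callback_metrics["val_loss"].item()


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)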


README

Lightning

The deep learning framework to pretrain, finetune and deploy AI models.

NEW: Deploying models? Check out LitServe, the PyTorch Lightning for model serving.


Quick start • Examples • PyTorch Lightning • Fabric • Lightning AI • Community • Docs


Get started

 

Lightning has 2 core packages

PyTorch Lightning: Train and deploy PyTorch at scale.
Lightning Fabric: Expert control.

Lightning gives you granular control over how much abstraction you want to add over PyTorch.

 

Quick start

Install Lightning:

pip install lightning
Advanced install options

Install with optional dependencies

pip install lightning['extra']

Conda

conda install lightning -c conda-forge

Install stable version

Install future release from the source

pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/release/stable.zip -U

Install bleeding-edge

Install nightly from the source (no guarantees)

pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U

or from testing PyPI

pip install -U -i https://test.pypi.org/simple/ pytorch-lightning

PyTorch Lightning example

Define the training workflow. Here's a toy example (explore real examples):

# main.py
# ! pip install torchvision
import torch, torch.nn as nn, torch.utils.data as data, torchvision as tv, torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: Define a LightningModule
# --------------------------------
# A LightningModule (nn.Module subclass) defines a full *system*
# (ie: an LLM, diffusion model, autoencoder, or simple image classifier).


class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# -------------------
# Step 2: Define data
# -------------------
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
train, val = data.random_split(dataset, [55000, 5000])

# -------------------
# Step 3: Train
# -------------------
autoencoder = LitAutoEncoder()
trainer = L.Trainer()
trainer.fit(autoencoder, data.DataLoader(train), data.DataLoader(val))

Run the model on your terminal

pip install torchvision
python main.py

 

Why PyTorch Lightning?

PyTorch Lightning is just organized PyTorch - Lightning disentangles PyTorch code to decouple the science from the engineering.

[Figure: the same training code written in plain PyTorch and reorganized into PyTorch Lightning (PT to PL)]

 


Examples

Explore various types of training possible with PyTorch Lightning. Pretrain and finetune ANY kind of model to perform ANY task like classification, segmentation, summarization and more:

| Task | Description | Run |
| --- | --- | --- |
| Hello world | Pretrain - Hello world example | Open In Studio |
| Image classification | Finetune - ResNet-34 model to classify images of cars | Open In Studio |
| Image segmentation | Finetune - ResNet-50 model to segment images | Open In Studio |
| Object detection | Finetune - Faster R-CNN model to detect objects | Open In Studio |
| Text classification | Finetune - text classifier (BERT model) | Open In Studio |
| Text summarization | Finetune - text summarization (Hugging Face transformer model) | Open In Studio |
| Audio generation | Finetune - audio generator (transformer model) | Open In Studio |
| LLM finetuning | Finetune - LLM (Meta Llama 3.1 8B) | Open In Studio |
| Image generation | Pretrain - Image generator (diffusion model) | Open In Studio |
| Recommendation system | Train - recommendation system (factorization and embedding) | Open In Studio |
| Time-series forecasting | Train - Time-series forecasting with LSTM | Open In Studio |

Advanced features

Lightning has 40+ advanced features designed for professional AI research at scale.

Here are some examples:

Train on 1000s of GPUs without code changes
# 8 GPUs
# no code changes needed
trainer = Trainer(accelerator="gpu", devices=8)

# 256 GPUs
trainer = Trainer(accelerator="gpu", devices=8, num_nodes=32)
Train on other accelerators like TPUs without code changes
# no code changes needed
trainer = Trainer(accelerator="tpu", devices=8)
16-bit precision
# no code changes needed
trainer = Trainer(precision=16)
Experiment managers
from lightning.pytorch import loggers

# tensorboard
trainer = Trainer(logger=loggers.TensorBoardLogger("logs/"))

# weights and biases
trainer = Trainer(logger=loggers.WandbLogger())

# comet
trainer = Trainer(logger=loggers.CometLogger())

# mlflow
trainer = Trainer(logger=loggers.MLFlowLogger())

# neptune
trainer = Trainer(logger=loggers.NeptuneLogger())

# ... and dozens more
Early Stopping
from lightning.pytorch.callbacks import EarlyStopping

es = EarlyStopping(monitor="val_loss")
trainer = Trainer(callbacks=[es])
Checkpointing
from lightning.pytorch.callbacks import ModelCheckpoint

checkpointing = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpointing])
Export to torchscript (JIT) (production use)
import torch

# torchscript
autoencoder = LitAutoEncoder()
torch.jit.save(autoencoder.to_torchscript(), "model.pt")
Export to ONNX (production use)
import os
import tempfile
import torch

# onnx
with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:
    autoencoder = LitAutoEncoder()
    input_sample = torch.randn((1, 28 * 28))  # matches the encoder's expected input size
    autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)
    os.path.isfile(tmpfile.name)

Advantages over unstructured PyTorch

  • Models become hardware agnostic
  • Code is clear to read because engineering code is abstracted away
  • Easier to reproduce
  • Make fewer mistakes because Lightning handles the tricky engineering
  • Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate
  • Lightning has dozens of integrations with popular machine learning tools.
  • Tested rigorously with every new PR. We test every supported combination of PyTorch and Python versions, every OS, multi-GPU setups and even TPUs.
  • Minimal running speed overhead (about 300 ms per epoch compared with pure PyTorch).


   

Lightning Fabric: Expert control

Run on any device at any scale with expert-level control over PyTorch training loop and scaling strategy. You can even write your own Trainer.

Fabric is designed for the most complex models of any size: foundation model scaling, LLMs, diffusion models, transformers, reinforcement learning, and active learning.

What to change:
+ import lightning as L
  import torch; import torchvision as tv

 dataset = tv.datasets.CIFAR10("data", download=True,
                               train=True,
                               transform=tv.transforms.ToTensor())

+ fabric = L.Fabric()
+ fabric.launch()

  model = tv.models.resnet18()
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)

  dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)

  model.train()
  num_epochs = 10
  for epoch in range(num_epochs):
      for batch in dataloader:
          inputs, labels = batch
-         inputs, labels = inputs.to(device), labels.to(device)
          optimizer.zero_grad()
          outputs = model(inputs)
          loss = torch.nn.functional.cross_entropy(outputs, labels)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          print(loss.data)

Resulting Fabric code (copy me!):

import lightning as L
import torch; import torchvision as tv

dataset = tv.datasets.CIFAR10("data", download=True,
                              train=True,
                              transform=tv.transforms.ToTensor())

fabric = L.Fabric()
fabric.launch()

model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)

model.train()
num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, labels = batch
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        fabric.backward(loss)
        optimizer.step()
        print(loss.data)

Key features

Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training
# Use your available hardware
# no code changes needed
fabric = Fabric()

# Run on GPUs (CUDA or MPS)
fabric = Fabric(accelerator="gpu")

# 8 GPUs
fabric = Fabric(accelerator="gpu", devices=8)

# 256 GPUs, multi-node
fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)

# Run on TPUs
fabric = Fabric(accelerator="tpu")
Use state-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box
# Use state-of-the-art distributed training techniques
fabric = Fabric(strategy="ddp")
fabric = Fabric(strategy="deepspeed")
fabric = Fabric(strategy="fsdp")

# Switch the precision
fabric = Fabric(precision="16-mixed")
fabric = Fabric(precision="64")
All the device logic boilerplate is handled for you
  # no more of this!
- model.to(device)
- batch.to(device)
Build your own custom Trainer using Fabric primitives for training, checkpointing, logging, and more
import lightning as L


class MyCustomTrainer:
    def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
        self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)

    def fit(self, model, optimizer, dataloader, max_epochs):
        self.fabric.launch()

        model, optimizer = self.fabric.setup(model, optimizer)
        dataloader = self.fabric.setup_dataloaders(dataloader)
        model.train()

        for epoch in range(max_epochs):
            for batch in dataloader:
                input, target = batch
                optimizer.zero_grad()
                output = model(input)
                loss = loss_fn(output, target)  # loss_fn: any criterion, e.g. torch.nn.functional.cross_entropy, supplied by the user
                self.fabric.backward(loss)
                optimizer.step()
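
A hypothetical way to use this custom trainer, reusing the CIFAR10 setup from the Fabric example above (all objects are illustrative placeholders):

import torch
import torchvision as tv

# loss_fn is looked up by MyCustomTrainer.fit above; cross-entropy is an illustrative choice
loss_fn = torch.nn.functional.cross_entropy

dataset = tv.datasets.CIFAR10("data", download=True, train=True, transform=tv.transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

trainer = MyCustomTrainer(accelerator="auto", devices="auto")
trainer.fit(model, optimizer, dataloader, max_epochs=5)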

You can find a more extensive example in our examples



   

Examples

Self-supervised Learning
Convolutional Architectures
Reinforcement Learning
GANs
Classic ML

   

Continuous Integration

Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against major Python and PyTorch versions.

*Codecov is > 90%, but build delays may show less
Current build statuses
| System / PyTorch ver. | 1.13 | 2.0 | 2.1 |
| --- | --- | --- | --- |
| Linux py3.9 [GPUs] | Build Status | | |
| Linux py3.9 [TPUs] | Test PyTorch - TPU | | |
| Linux (multiple Python versions) | Test PyTorch | Test PyTorch | Test PyTorch |
| OSX (multiple Python versions) | Test PyTorch | Test PyTorch | Test PyTorch |
| Windows (multiple Python versions) | Test PyTorch | Test PyTorch | Test PyTorch |

   

Community

The Lightning community is maintained by:

  • 10+ core contributors, a mix of professional engineers, research scientists, and Ph.D. students from top AI labs.
  • 800+ community contributors.

Want to help us build Lightning and reduce boilerplate for thousands of researchers? Learn how to make your first contribution here.

Lightning is also part of the PyTorch ecosystem which requires projects to have solid testing, documentation and support.

Asking for help

If you have any questions please:

  1. Read the docs.
  2. Search through existing Discussions, or add a new question.
  3. Join our Discord.