
huggingface/trl

Train transformer language models with reinforcement learning.


Top Related Projects

  • trlx (4,647 stars): A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
  • CleanRL (7,085 stars): High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
  • baselines (16,242 stars): OpenAI Baselines: high-quality implementations of reinforcement learning algorithms
  • stable-baselines: A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
  • stable-baselines3: PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.

Quick Overview

huggingface/trl is a library for training language models with reinforcement learning techniques. It provides tools and utilities to fine-tune large language models (LLMs) using methods such as Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO).

Pros

  • Seamless integration with Hugging Face's Transformers library
  • Supports multiple RL algorithms for language model fine-tuning
  • Includes utilities for reward modeling and data collection
  • Actively maintained and regularly updated

Cons

  • Requires significant computational resources for training large models
  • Limited documentation and examples for advanced use cases
  • Steep learning curve for users unfamiliar with RL concepts
  • May require careful hyperparameter tuning for optimal results

Code Examples

  1. Training a model using PPO (note: the snippets in this section use the legacy PPOTrainer API from earlier TRL releases; current releases use the trainer classes shown in the README below):
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

# The legacy PPOTrainer expects a model with a value head
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

ppo_config = PPOConfig(
    batch_size=1,
    learning_rate=1.41e-5,
    mini_batch_size=1,
)

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer,
)

# Training loop (num_epochs and dataloader are placeholders you define)
for epoch in range(num_epochs):
    for queries, responses, rewards in dataloader:
        # step() takes lists of query tensors, response tensors, and scalar reward tensors
        stats = ppo_trainer.step(queries, responses, rewards)
  2. Using SFTTrainer for supervised fine-tuning (dataset_text_field and max_seq_length are constructor arguments in older TRL releases; newer releases move them into SFTConfig):
from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("imdb")
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset["train"],
    dataset_text_field="text",
    max_seq_length=512,
)

trainer.train()
  3. Implementing a custom reward function (a sketch of wiring it into the PPO loop follows this list):
from trl import AutoModelForCausalLMWithValueHead

def custom_reward_fn(responses, prompts):
    rewards = []
    for response in responses:
        # calculate_reward stands in for your own scoring logic (not defined here)
        reward = calculate_reward(response)
        rewards.append(reward)
    return rewards

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
# Use the custom reward function in your training loop
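
As a rough illustration, such a reward function can be plugged into the legacy PPO loop by reusing the ppo_trainer and tokenizer from the first example (a minimal sketch assuming the pre-0.12 PPOTrainer API; query_dataloader and the generation settings are placeholders):

import torch

generation_kwargs = {"max_new_tokens": 32, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}

for query_tensors in query_dataloader:  # user-defined batches of tokenized prompts
    # Generate responses with the current policy
    response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
    responses = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
    prompts = tokenizer.batch_decode(query_tensors, skip_special_tokens=True)

    # Score the responses and convert them to tensors for the PPO update
    rewards = [torch.tensor(float(r)) for r in custom_reward_fn(responses, prompts)]
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)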

Getting Started

To get started with huggingface/trl, follow these steps:

  1. Install the library:
pip install trl
  2. Import the necessary modules:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer
  3. Load a pre-trained model and tokenizer (the legacy PPOTrainer expects a value-head model):
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
  4. Create a PPOTrainer instance and start training:
ppo_config = PPOConfig(batch_size=1, learning_rate=1e-5)
ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer)
# Implement your training loop here
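
Once training finishes, the fine-tuned weights can be saved with the standard save_pretrained methods (the output directory name below is arbitrary):

model.save_pretrained("gpt2-ppo-finetuned")
tokenizer.save_pretrained("gpt2-ppo-finetuned")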

Competitor Comparisons


trlx (4,647 stars)

A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

Pros of trlx

  • More extensive documentation and examples
  • Supports a wider range of reinforcement learning algorithms (PPO, ILQL, etc.)
  • Active development with frequent updates and contributions

Cons of trlx

  • Steeper learning curve for beginners
  • Less integration with Hugging Face ecosystem
  • May require more computational resources for some tasks

Code Comparison

trlx:

from trlx.data.default_configs import default_ppo_config
from trlx import train

config = default_ppo_config()
# In practice you would also pass a reward_fn or prompts to train()
trainer = train(model_path="gpt2", config=config)

trl:

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

config = PPOConfig(model_name="gpt2")
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)

Both libraries offer similar functionality for training language models with reinforcement learning techniques. trlx provides more advanced features and algorithms, while trl integrates more seamlessly with the Hugging Face ecosystem. The choice between them depends on the specific requirements of your project and your familiarity with reinforcement learning concepts.

CleanRL (7,085 stars)

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)

Pros of CleanRL

  • Focuses on simplicity and readability, making it easier for beginners to understand RL algorithms
  • Implements a wide range of RL algorithms in a consistent style
  • Provides single-file implementations for each algorithm, enhancing portability

Cons of CleanRL

  • Less integration with pre-trained language models compared to TRL
  • May lack some advanced features and optimizations present in TRL
  • Smaller community and ecosystem compared to Hugging Face's offerings

Code Comparison

CleanRL (PPO implementation):

def compute_gae(next_value, rewards, masks, values, gamma=0.99, lam=0.95):
    values = values + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * lam * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns

TRL (PPO implementation):

def compute_advantages(
    rewards: torch.Tensor,
    values: torch.Tensor,
    mask: torch.Tensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    last_gae_lam = 0
    advantages_reversed = []
    length = rewards.shape[0]
    for t in reversed(range(length)):
        next_values = values[t + 1] if t < length - 1 else 0.0
        delta = rewards[t] + gamma * next_values * mask[t] - values[t]
        last_gae_lam = delta + gamma * lam * mask[t] * last_gae_lam
        advantages_reversed.append(last_gae_lam)
    advantages = torch.stack(advantages_reversed[::-1])
    returns = advantages + values
    return advantages, returns
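
Both snippets implement Generalized Advantage Estimation (GAE). In the notation of the code above, each step computes delta_t = r_t + gamma * V_{t+1} * m_t - V_t and accumulates A_t = delta_t + gamma * lam * m_t * A_{t+1}, where m_t is a mask that zeroes out terminated or padded steps; the returns used as value targets are then R_t = A_t + V_t. The two functions differ mainly in that the CleanRL version operates on Python lists while the TRL version is vectorized over torch tensors.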

baselines (16,242 stars)

OpenAI Baselines: high-quality implementations of reinforcement learning algorithms

Pros of baselines

  • Extensive collection of reinforcement learning algorithms
  • Well-established and widely used in the research community
  • Includes implementations for various environments (Atari, MuJoCo, etc.)

Cons of baselines

  • Less focus on transformer-based models and language tasks
  • Not actively maintained (last commit over 2 years ago)
  • Steeper learning curve for beginners

Code comparison

baselines (DQN implementation):

def learn(env, network, seed=None, lr=5e-4, total_timesteps=100000, buffer_size=50000,
          exploration_fraction=0.1, exploration_final_eps=0.02, train_freq=1,
          batch_size=32, print_freq=100, checkpoint_freq=10000, checkpoint_path=None,
          learning_starts=1000, gamma=1.0, target_network_update_freq=500,
          prioritized_replay=False, prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6, param_noise=False, callback=None,
          load_path=None, **network_kwargs):

trl (PPO implementation):

def train(
    self,
    max_steps: int,
    learning_rate: float = 1e-5,
    batch_size: int = 32,
    gradient_accumulation_steps: int = 1,
):

The trl library focuses on transformer-based models and provides a more streamlined API for fine-tuning language models with reinforcement learning techniques. It's actively maintained and integrates well with the Hugging Face ecosystem, making it easier to use for NLP tasks.

stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms

Pros of stable-baselines

  • More comprehensive set of RL algorithms, including DQN, PPO, and SAC
  • Better documentation and tutorials for beginners
  • Longer history and more established community support

Cons of stable-baselines

  • Less focus on language models and NLP tasks
  • Not as tightly integrated with the Hugging Face ecosystem
  • Slower development pace and fewer recent updates

Code Comparison

stable-baselines:

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10000)

trl:

from trl import PPOTrainer, PPOConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ppo_trainer = PPOTrainer(config=PPOConfig(), model=model, tokenizer=tokenizer)

The code snippets highlight the different focus areas of the two libraries. stable-baselines is geared towards traditional RL tasks, while trl is designed for fine-tuning language models using RL techniques.

stable-baselines3

PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.

Pros of stable-baselines3

  • More comprehensive set of RL algorithms implemented
  • Better documentation and examples for various environments
  • More mature and stable codebase with longer development history

Cons of stable-baselines3

  • Less focus on integration with large language models
  • Not specifically designed for fine-tuning transformer models
  • May require more setup and configuration for NLP tasks

Code Comparison

stable-baselines3:

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000)

trl:

from trl import PPOTrainer, PPOConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ppo_trainer = PPOTrainer(config=PPOConfig(), model=model, tokenizer=tokenizer)

The code snippets highlight the difference in focus between the two libraries. stable-baselines3 is geared towards traditional RL tasks, while trl is designed for fine-tuning language models using RL techniques.


README

TRL - Transformer Reinforcement Learning

TRL Banner


A comprehensive library to post-train foundation models


Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.

Highlights

  • Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, GRPOTrainer, DPOTrainer, RewardTrainer and more.

  • Efficient and scalable:

    • Leverages 🤗 Accelerate to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
    • Full integration with 🤗 PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA (see the sketch after this list).
    • Integrates 🦥 Unsloth for accelerating training using optimized kernels.
  • Command Line Interface (CLI): A simple interface lets you fine-tune models without needing to write code.
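
As a rough sketch of the PEFT integration mentioned above (assuming the peft package is installed; the model name, dataset, and LoRA hyperparameters are illustrative choices, not prescribed defaults):

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Train lightweight LoRA adapters instead of updating the full model weights
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()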

Installation

Python Package

Install the library using pip:

pip install trl

From source

If you want to use the latest features before an official release, you can install TRL from source:

pip install git+https://github.com/huggingface/trl.git

Repository

If you want to use the examples, you can clone the repository with the following command:

git clone https://github.com/huggingface/trl.git

Quick Start

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

SFTTrainer

Here is a basic example of how to use the SFTTrainer:

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
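
Training arguments such as the output directory or batch size can be passed through an SFTConfig; a minimal sketch (the values shown are illustrative):

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT",
    per_device_train_batch_size=2,
    num_train_epochs=1,
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()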

GRPOTrainer

GRPOTrainer implements the Group Relative Policy Optimization (GRPO) algorithm, which is more memory-efficient than PPO and was used to train DeepSeek AI's R1.

from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()
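
GRPOTrainer also accepts a list of reward functions, whose scores are combined for each completion. A hedged sketch building on the snippet above (the extra length-penalty reward is illustrative, not part of the library):

# A second, illustrative reward that discourages very long completions
def reward_short(completions, **kwargs):
    return [-len(c) / 100.0 for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[reward_num_unique_chars, reward_short],
    train_dataset=dataset,
)
trainer.train()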

DPOTrainer

DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer
)
trainer.train()

RewardTrainer

Here is a basic example of how to use the RewardTrainer:

from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
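
After training, the reward model can score text like any sequence-classification model; a minimal sketch reusing the model and tokenizer from above (the example text is arbitrary):

import torch

text = "What is the capital of France?\nThe capital of France is Paris."
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    # num_labels=1, so the single logit acts as the scalar reward score
    score = model(**inputs).logits[0].item()
print(score)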

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):

SFT:

trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT

DPO:

trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO 

Read more about the CLI in the relevant documentation section, or use --help for more details.

Development

If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and do a dev install:

git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]

Citation

@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}

License

This repository's source code is available under the Apache-2.0 License.