Top Related Projects
- Train transformer language models with reinforcement learning.
- A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
- High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
- OpenAI Baselines: high-quality implementations of reinforcement learning algorithms
- A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
- PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
Quick Overview
huggingface/trl is a library for training language models with reinforcement learning techniques. It provides tools and utilities to fine-tune large language models (LLMs) using post-training methods such as supervised fine-tuning (SFT), PPO (Proximal Policy Optimization), and DPO (Direct Preference Optimization).
Pros
- Seamless integration with Hugging Face's Transformers library
- Supports multiple RL algorithms for language model fine-tuning
- Includes utilities for reward modeling and data collection
- Actively maintained and regularly updated
Cons
- Requires significant computational resources for training large models
- Limited documentation and examples for advanced use cases
- Steep learning curve for users unfamiliar with RL concepts
- May require careful hyperparameter tuning for optimal results
Code Examples
- Training a model using PPO (the snippet below follows the PPOTrainer API from older trl releases):
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

# The legacy PPOTrainer expects a causal LM wrapped with a value head
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

ppo_config = PPOConfig(
    batch_size=1,
    learning_rate=1.41e-5,
    mini_batch_size=1,
)

ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer,
)

# Training loop (num_epochs, dataloader, and the batch contents are placeholders you define)
for epoch in range(num_epochs):
    for batch in dataloader:
        query_tensors = batch["query_tensors"]        # list of prompt token tensors
        response_tensors = batch["response_tensors"]  # list of generated token tensors
        rewards = batch["rewards"]                    # list of scalar reward tensors
        # The PPO optimization step takes queries, responses, and rewards
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
- Using SFTTrainer for supervised fine-tuning:
from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("imdb")

# Older trl releases accept dataset_text_field and max_seq_length as trainer
# arguments; newer releases move these options into SFTConfig.
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset["train"],
    dataset_text_field="text",
    max_seq_length=512,
)
trainer.train()
- Implementing a custom reward function:
from trl import AutoModelForCausalLMWithValueHead

def custom_reward_fn(responses, prompts):
    rewards = []
    for response in responses:
        # calculate_reward is a placeholder for your own scoring logic,
        # e.g. a classifier score, a heuristic, or human feedback;
        # prompts is available here if the reward depends on the query
        reward = calculate_reward(response)
        rewards.append(reward)
    return rewards

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
# Use the custom reward function in your training loop
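As a rough sketch of how such a reward function can feed the legacy PPO loop, the snippet below generates continuations for two illustrative prompts, scores them with custom_reward_fn, and passes the results to ppo_trainer.step; the prompts, generation settings, and PPOConfig values are assumptions for illustration only.
import torch
from transformers import AutoTokenizer
from trl import PPOConfig, PPOTrainer, AutoModelForCausalLMWithValueHead

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

# batch_size must match the number of queries passed to step()
config = PPOConfig(batch_size=2, mini_batch_size=1)
ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)

prompts = ["The movie was", "I think this product"]  # illustrative prompts
query_tensors = [tokenizer(p, return_tensors="pt").input_ids.squeeze(0) for p in prompts]

# Generate a continuation for each prompt and strip the prompt tokens
response_tensors = []
for query in query_tensors:
    output = ppo_trainer.generate(query, max_new_tokens=16, do_sample=True)
    response_tensors.append(output.squeeze(0)[query.shape[0]:])

responses = [tokenizer.decode(r, skip_special_tokens=True) for r in response_tensors]

# Score the decoded responses; step() expects one scalar tensor per sample
rewards = [torch.tensor(float(r)) for r in custom_reward_fn(responses, prompts)]
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)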
Getting Started
To get started with huggingface/trl, follow these steps:
- Install the library:
pip install trl
- Import the necessary modules:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer
- Load a pre-trained model and tokenizer:
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
- Create a PPOTrainer instance and start training:
ppo_config = PPOConfig(batch_size=1, learning_rate=1e-5)
ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer)
# Implement your training loop here
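Once training is done, you can persist the fine-tuned weights with the standard save_pretrained methods; the output directory below is just an example.
# Save the trained policy and tokenizer (example output directory)
model.save_pretrained("gpt2-ppo-finetuned")
tokenizer.save_pretrained("gpt2-ppo-finetuned")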
Competitor Comparisons
Train transformer language models with reinforcement learning.
Pros of trl
- More comprehensive and feature-rich library for reinforcement learning in transformers
- Supports a wider range of RL algorithms and techniques
- Better documentation and examples for various use cases
Cons of trl
- Potentially more complex to set up and use for simpler tasks
- May have a steeper learning curve for beginners
- Requires more computational resources due to its extensive features
Code Comparison
trl:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ppo_trainer = PPOTrainer(config=PPOConfig(), model=model, tokenizer=tokenizer)
trl (value-head model):
from trl import AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
The code comparison shows that trl offers a dedicated PPOTrainer with explicit configuration options alongside a simpler setup built around the specialized AutoModelForCausalLMWithValueHead model class. The first style exposes the full training machinery, while the second keeps the entry point minimal.
A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)
Pros of trlx
- More extensive documentation and examples
- Supports a wider range of reinforcement learning algorithms (PPO, ILQL, etc.)
- Active development with frequent updates and contributions
Cons of trlx
- Steeper learning curve for beginners
- Less integration with Hugging Face ecosystem
- May require more computational resources for some tasks
Code Comparison
trlx:
from trlx import train
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
# In practice you would also pass a reward_fn, samples, or prompts to train()
trainer = train("gpt2", config=config)
trl:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

config = PPOConfig(model_name="gpt2")
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
Both libraries offer similar functionality for training language models with reinforcement learning techniques. trlx provides more advanced features and algorithms, while trl integrates more seamlessly with the Hugging Face ecosystem. The choice between them depends on the specific requirements of your project and your familiarity with reinforcement learning concepts.
High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
Pros of CleanRL
- Focuses on simplicity and readability, making it easier for beginners to understand RL algorithms
- Implements a wide range of RL algorithms in a consistent style
- Provides single-file implementations for each algorithm, enhancing portability
Cons of CleanRL
- Less integration with pre-trained language models compared to TRL
- May lack some advanced features and optimizations present in TRL
- Smaller community and ecosystem compared to Hugging Face's offerings
Code Comparison
CleanRL (PPO implementation):
def compute_gae(next_value, rewards, masks, values, gamma=0.99, lam=0.95):
    values = values + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * lam * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns
TRL (PPO implementation):
def compute_advantages(
    rewards: torch.Tensor,
    values: torch.Tensor,
    mask: torch.Tensor,
    gamma: float,
    lam: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    last_gae_lam = 0
    advantages_reversed = []
    length = rewards.shape[0]
    for t in reversed(range(length)):
        next_values = values[t + 1] if t < length - 1 else 0.0
        delta = rewards[t] + gamma * next_values * mask[t] - values[t]
        last_gae_lam = delta + gamma * lam * mask[t] * last_gae_lam
        advantages_reversed.append(last_gae_lam)
    advantages = torch.stack(advantages_reversed[::-1])
    returns = advantages + values
    return advantages, returns
OpenAI Baselines: high-quality implementations of reinforcement learning algorithms
Pros of baselines
- Extensive collection of reinforcement learning algorithms
- Well-established and widely used in the research community
- Includes implementations for various environments (Atari, MuJoCo, etc.)
Cons of baselines
- Less focus on transformer-based models and language tasks
- Not actively maintained (last commit over 2 years ago)
- Steeper learning curve for beginners
Code comparison
baselines (DQN implementation):
def learn(env, network, seed=None, lr=5e-4, total_timesteps=100000, buffer_size=50000,
          exploration_fraction=0.1, exploration_final_eps=0.02, train_freq=1,
          batch_size=32, print_freq=100, checkpoint_freq=10000, checkpoint_path=None,
          learning_starts=1000, gamma=1.0, target_network_update_freq=500,
          prioritized_replay=False, prioritized_replay_alpha=0.6,
          prioritized_replay_beta0=0.4, prioritized_replay_beta_iters=None,
          prioritized_replay_eps=1e-6, param_noise=False, callback=None,
          load_path=None, **network_kwargs):
trl (PPO implementation):
def train(
    self,
    max_steps: int,
    learning_rate: float = 1e-5,
    batch_size: int = 32,
    gradient_accumulation_steps: int = 1,
):
The trl library focuses on transformer-based models and provides a more streamlined API for fine-tuning language models with reinforcement learning techniques. It's actively maintained and integrates well with the Hugging Face ecosystem, making it easier to use for NLP tasks.
A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
Pros of stable-baselines
- More comprehensive set of RL algorithms, including DQN, PPO, and SAC
- Better documentation and tutorials for beginners
- Longer history and more established community support
Cons of stable-baselines
- Less focus on language models and NLP tasks
- Not as tightly integrated with the Hugging Face ecosystem
- Slower development pace and fewer recent updates
Code Comparison
stable-baselines:
from stable_baselines import PPO2

model = PPO2("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10000)
trl:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ppo_trainer = PPOTrainer(config=PPOConfig(), model=model, tokenizer=tokenizer)
The code snippets highlight the different focus areas of the two libraries. stable-baselines is geared towards traditional RL tasks, while trl is designed for fine-tuning language models using RL techniques.
PyTorch version of Stable Baselines, reliable implementations of reinforcement learning algorithms.
Pros of stable-baselines3
- More comprehensive set of RL algorithms implemented
- Better documentation and examples for various environments
- More mature and stable codebase with longer development history
Cons of stable-baselines3
- Less focus on integration with large language models
- Not specifically designed for fine-tuning transformer models
- May require more setup and configuration for NLP tasks
Code Comparison
stable-baselines3:
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000)
trl:
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from transformers import AutoTokenizer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
ppo_trainer = PPOTrainer(config=PPOConfig(), model=model, tokenizer=tokenizer)
The code snippets highlight the difference in focus between the two libraries. stable-baselines3 is geared towards traditional RL tasks, while trl is designed for fine-tuning language models using RL techniques.
README
TRL - Transformer Reinforcement Learning

A comprehensive library to post-train foundation models
Overview
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.
Highlights
- Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, GRPOTrainer, DPOTrainer, RewardTrainer, and more.
- Efficient and scalable:
  - Leverages 🤗 Accelerate to scale from a single GPU to multi-node clusters using methods like DDP and DeepSpeed.
  - Full integration with 🤗 PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA (see the sketch after this list).
  - Integrates 🦥 Unsloth for accelerating training using optimized kernels.
- Command Line Interface (CLI): A simple interface lets you fine-tune models without needing to write code.
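As a rough illustration of the PEFT integration mentioned above, the sketch below fine-tunes a small model with a LoRA adapter through the SFTTrainer's peft_config argument; the dataset, model name, and LoRA hyperparameters are illustrative assumptions rather than recommendations.
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# LoRA adapter settings (rank, alpha, and dropout are placeholder values)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT-LoRA"),
    peft_config=peft_config,
)
trainer.train()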
Installation
Python Package
Install the library using pip:
pip install trl
From source
If you want to use the latest features before an official release, you can install TRL from source:
pip install git+https://github.com/huggingface/trl.git
Repository
If you want to use the examples, you can clone the repository with the following command:
git clone https://github.com/huggingface/trl.git
Quick Start
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
SFTTrainer
Here is a basic example of how to use the SFTTrainer:
from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
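To control standard training arguments in this example, you can pass an SFTConfig, which extends transformers.TrainingArguments; the hyperparameter values below are illustrative placeholders rather than recommendations.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Training arguments inherited from transformers.TrainingArguments;
# the specific values here are placeholders, not recommendations.
training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    num_train_epochs=1,
    logging_steps=10,
)

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()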
GRPOTrainer
GRPOTrainer implements the Group Relative Policy Optimization (GRPO) algorithm, which is more memory-efficient than PPO and was used to train DeepSeek AI's R1.
from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()
DPOTrainer
DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
RewardTrainer
Here is a basic example of how to use the RewardTrainer:
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
Command Line Interface (CLI)
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
SFT:
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara \
--output_dir Qwen2.5-0.5B-SFT
DPO:
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name argilla/Capybara-Preferences \
--output_dir Qwen2.5-0.5B-DPO
Read more about the CLI in the relevant documentation section or use --help for more details.
Development
If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and do a dev install:
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]
Citation
@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}
License
This repository's source code is available under the Apache-2.0 License.