Top Related Projects
- Acme: A library of reinforcement learning components and agents
- Gym: A toolkit for developing and comparing reinforcement learning algorithms
- TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning
- Stable-baselines: A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
- Gymnasium: An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)
- Ignite: High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently
Quick Overview
RLax is a library of reinforcement learning building blocks developed by DeepMind. It provides a collection of common RL functions and algorithms implemented in JAX, designed to be efficient, flexible, and easy to use in research projects.
Pros
- High performance due to JAX backend, enabling GPU/TPU acceleration
- Modular design allows for easy customization and experimentation
- Well-documented with clear examples and API references
- Integrates seamlessly with other JAX-based libraries
Cons
- Requires familiarity with JAX, which may have a learning curve for some users
- Limited to reinforcement learning tasks, not a general-purpose machine learning library
- May have fewer pre-built, end-to-end algorithms compared to some other RL libraries
- Ongoing development may lead to occasional API changes
Code Examples
- Calculating the n-step return:
import jax
import rlax
rewards = jax.numpy.array([1.0, 0.5, 0.0, -1.0])
discounts = jax.numpy.array([0.9, 0.9, 0.9, 0.0])
values = jax.numpy.array([2.0, 1.5, 1.0, 0.5])
n_step_return = rlax.n_step_bootstrapped_returns(rewards, discounts, values, n=2)
print(n_step_return)
- Implementing epsilon-greedy exploration:
import jax
import rlax
q_values = jax.numpy.array([1.0, 2.0, 0.5, 1.5])
epsilon = 0.1
action = rlax.epsilon_greedy(epsilon).sample(jax.random.PRNGKey(0), q_values)
print(action)
- Computing the Huber loss:
import jax
import rlax
predictions = jax.numpy.array([1.0, 2.0, 3.0])
targets = jax.numpy.array([1.2, 1.8, 3.1])
errors = predictions - targets  # rlax.huber_loss operates on the errors directly
loss = rlax.huber_loss(errors, delta=1.0)
print(loss)
Getting Started
To get started with RLax, first install the library:
pip install rlax
Then, import the library and start using its functions:
import jax
import rlax
# Example: Compute discounted returns
rewards = jax.numpy.array([1.0, 2.0, 3.0, 4.0])
discounts = jax.numpy.array([0.9, 0.9, 0.9, 0.0])
bootstrap_values = jax.numpy.zeros_like(rewards)  # nothing to bootstrap from
returns = rlax.discounted_returns(rewards, discounts, bootstrap_values)
print(returns)
For more detailed examples and API documentation, refer to the official RLax documentation.
Competitor Comparisons
Acme: A library of reinforcement learning components and agents
Pros of Acme
- Provides a full-featured RL framework with agents, environments, and training loops
- Offers distributed and single-process implementations for scalability
- Includes pre-built agents and baselines for quick experimentation
Cons of Acme
- More complex and heavyweight compared to RLax's focused approach
- Steeper learning curve due to its comprehensive nature
- May be overkill for simple RL tasks or research
Code Comparison
Acme (agent creation):
agent = dqn.DQN(
    environment_spec=env_spec,
    network=network,
    batch_size=batch_size,
    samples_per_insert=samples_per_insert,
    min_replay_size=min_replay_size
)
RLax (Q-learning update):
# rlax.q_learning returns the TD error for a single transition.
td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
loss = rlax.l2_loss(td_error)  # squared TD error loss
RLax focuses on providing building blocks for RL algorithms, while Acme offers a complete framework for agent development and training. RLax is more flexible and lightweight, suitable for researchers who want fine-grained control over their implementations. Acme is better suited for those who need a full-featured, production-ready RL system with pre-built components and distributed capabilities.
Gym: A toolkit for developing and comparing reinforcement learning algorithms.
Pros of Gym
- Widely adopted and supported by the RL community
- Provides a diverse set of pre-built environments for testing and benchmarking
- Offers a simple, consistent API for interacting with environments
Cons of Gym
- Limited to environment simulation, lacking built-in RL algorithms
- Tied to its own environment API, whereas RLax makes no assumptions about how environments are implemented
Code Comparison
Gym:
import gym
env = gym.make('CartPole-v1')
observation = env.reset()
for _ in range(1000):
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
RLax:
import jax
import jax.numpy as jnp
import rlax
q_values = jnp.array([1.0, 2.0, 3.0])
# rlax.softmax returns a distribution; sample an action from it.
action = rlax.softmax(temperature=1.0).sample(jax.random.PRNGKey(0), q_values)
RLax focuses on providing building blocks for RL algorithms, while Gym emphasizes environment simulation. RLax offers more flexibility for implementing custom RL algorithms, but requires more setup. Gym provides a ready-to-use framework for testing RL agents in various environments but may be less suitable for advanced algorithm development. The choice between the two depends on the specific needs of the project and the user's familiarity with JAX (for RLax) or more traditional Python libraries (for Gym).
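In practice the two are complementary: Gym supplies the transitions and RLax supplies the update rules. The sketch below assumes both libraries are installed and uses made-up placeholder action-value estimates in place of a learned network:
import gym
import jax.numpy as jnp
import rlax
env = gym.make('CartPole-v1')
observation = env.reset()
action = env.action_space.sample()
next_observation, reward, done, info = env.step(action)
# Placeholder Q-value estimates for the two observations (normally a network).
q_tm1 = jnp.array([0.1, 0.2])
q_t = jnp.array([0.0, 0.3])
discount = 0.0 if done else 0.99
td_error = rlax.q_learning(q_tm1, jnp.array(action), jnp.array(reward), jnp.array(discount), q_t)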
TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Pros of TF-Agents
- More comprehensive, offering a full suite of RL algorithms and tools
- Better integration with TensorFlow ecosystem and hardware acceleration
- Extensive documentation and examples for various use cases
Cons of TF-Agents
- Steeper learning curve due to its complexity and broader scope
- Potentially slower development and experimentation cycles
- Less flexibility for custom implementations compared to RLax
Code Comparison
RLax (simple Q-learning update):
import rlax
def q_learning_loss(q_tm1, a_tm1, r_t, discount_t, q_t):
    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
    return rlax.l2_loss(td_error)
TF-Agents (Q-learning agent setup):
agent = dqn_agent.DqnAgent(
    time_step_spec,
    action_spec,
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
RLax focuses on providing low-level, flexible building blocks for RL algorithms, while TF-Agents offers a more complete, high-level framework for implementing and training RL agents. RLax is better suited for researchers and those who need fine-grained control, while TF-Agents is more appropriate for practitioners and those who want a ready-to-use solution.
Stable-baselines: A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
Pros of Stable-baselines
- More user-friendly and easier to get started with for beginners
- Provides pre-implemented and optimized algorithms ready for use
- Includes a wide range of popular RL algorithms out of the box
Cons of Stable-baselines
- Less flexible and customizable compared to RLax
- May have performance limitations for advanced or custom implementations
- Potentially slower development cycle for new algorithms and features
Code Comparison
Stable-baselines example:
from stable_baselines3 import PPO
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10000)
RLax example:
import jax
import rlax
def loss_fn(prediction, target):
    return rlax.l2_loss(prediction, target)
grad_fn = jax.grad(loss_fn)  # gradient with respect to the prediction
The Stable-baselines code shows a higher-level API for quick implementation, while RLax provides lower-level building blocks for custom RL algorithms.
Gymnasium: An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)
Pros of Gymnasium
- More comprehensive and diverse set of environments for reinforcement learning
- Larger community and ecosystem, with more third-party extensions and integrations
- Better documentation and tutorials for beginners
Cons of Gymnasium
- Environments run as ordinary Python code and cannot be JIT-compiled on accelerators the way RLax's JAX functions can
- Less focus on advanced RL algorithms and research-oriented features
Code Comparison
Gymnasium:
import gymnasium as gym
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=42)
for _ in range(1000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
RLax:
import jax
import rlax
epsilon = 0.1
key, sample_key = jax.random.split(jax.random.PRNGKey(42))
q_values = jax.random.uniform(key, (4,))
action = rlax.epsilon_greedy(epsilon).sample(sample_key, q_values)
The code snippets highlight the different focus areas of the two libraries. Gymnasium provides a high-level interface for creating and interacting with environments, while RLax offers low-level building blocks for implementing RL algorithms using JAX.
Ignite: High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
Pros of Ignite
- More general-purpose, supporting various deep learning tasks beyond reinforcement learning
- Extensive documentation and tutorials for easier onboarding
- Larger community and more frequent updates
Cons of Ignite
- Less specialized for reinforcement learning tasks
- May have more overhead for simple RL projects
- Steeper learning curve for beginners in deep learning
Code Comparison
Ignite example:
from ignite.engine import Engine, Events
def train_step(engine, batch):
    # Training logic here
    return loss
trainer = Engine(train_step)
trainer.run(data_loader, max_epochs=10)
RLax example:
import jax
import rlax
def loss_fn(q_tm1, a_tm1, r_t, discount_t, q_t):
    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
    return rlax.l2_loss(td_error)
# vmap lifts the single-transition loss to a batch of transitions.
losses = jax.vmap(loss_fn)(q_tm1, a_tm1, r_t, discount_t, q_t)
Both libraries offer concise ways to implement training loops and loss functions, but RLax is more focused on RL-specific operations, while Ignite provides a more general framework for various deep learning tasks.
README
RLax
RLax (pronounced "relax") is a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents. Full documentation can be found at rlax.readthedocs.io.
Installation
You can install the latest released version of RLax from PyPI via:
pip install rlax
or you can install the latest development version from GitHub:
pip install git+https://github.com/deepmind/rlax.git
All RLax code may then be just-in-time compiled for different hardware (e.g. CPU, GPU, TPU) using jax.jit.
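For instance, an RLax value-update function such as rlax.td_learning can be wrapped in jax.jit so that repeated calls run as compiled XLA; the sketch below uses made-up scalar inputs purely for illustration:
import jax
import jax.numpy as jnp
import rlax
# Compile the TD(0) error computation once; later calls reuse the compiled code.
td_learning_fast = jax.jit(rlax.td_learning)
v_tm1 = jnp.array(1.0)      # value estimate in the source state
r_t = jnp.array(0.5)        # reward observed on the transition
discount_t = jnp.array(0.9)
v_t = jnp.array(1.2)        # value estimate in the destination state
td_error = td_learning_fast(v_tm1, r_t, discount_t, v_t)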
In order to run the examples/ you will also need to clone the repo and install the additional requirements: optax, haiku, and bsuite.
Content
The operations and functions provided are not complete algorithms, but implementations of reinforcement learning specific mathematical operations that are needed when building fully-functional agents capable of learning:
- Values, including both state and action-values;
- Values for non-linear generalizations of the Bellman equations;
- Return Distributions, aka distributional value functions;
- General Value Functions, for cumulants other than the main reward;
- Policies, via policy-gradients in both continuous and discrete action spaces.
The library supports both on-policy and off-policy learning (i.e. learning from data sampled from a policy different from the agent's policy).
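As a rough illustration of the distinction, the sketch below (with made-up transition values) compares the on-policy SARSA update, which bootstraps from the action the behaviour policy actually took, with the off-policy Q-learning update, which bootstraps from the greedy action:
import jax.numpy as jnp
import rlax
q_tm1 = jnp.array([1.0, 2.0, 0.5])  # action values in the source state
a_tm1 = jnp.array(1)                # action taken in the source state
r_t = jnp.array(0.5)
discount_t = jnp.array(0.9)
q_t = jnp.array([1.5, 0.2, 0.3])    # action values in the destination state
a_t = jnp.array(0)                  # next action actually taken (needed on-policy)
sarsa_td = rlax.sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, a_t)
q_learning_td = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)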
See file-level and function-level doc-strings for the documentation of these functions and for references to the papers that introduced and/or used them.
Usage
See examples/ for examples of using some of the functions in RLax to implement a few simple reinforcement learning agents, and demonstrate learning on BSuite's version of the Catch environment (a common unit-test for agent development in the reinforcement learning literature).
Other examples of JAX reinforcement learning agents using rlax can be found in bsuite.
Background
Reinforcement learning studies the problem of a learning system (the agent), which must learn to interact with the universe it is embedded in (the environment).
Agent and environment interact on discrete steps. On each step the agent selects an action, and is provided in return a (partial) snapshot of the state of the environment (the observation), and a scalar feedback signal (the reward).
The behaviour of the agent is characterized by a probability distribution over actions, conditioned on past observations of the environment (the policy). The agent seeks a policy that, from any given step, maximises the discounted cumulative reward that will be collected from that point onwards (the return).
Often the agent policy or the environment dynamics itself are stochastic. In this case the return is a random variable, and the optimal agent's policy is typically more precisely specified as a policy that maximises the expectation of the return (the value), under the agent's and environment's stochasticity.
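To make the notion of a return concrete, the sketch below computes the discounted return of a short, fixed trajectory both by hand and with rlax.discounted_returns (the rewards and discounts are made up, and a zero bootstrap value is used since the episode terminates):
import jax.numpy as jnp
import rlax
rewards = jnp.array([1.0, 0.0, 2.0])
discounts = jnp.array([0.9, 0.9, 0.0])  # a final discount of 0 ends the episode
# Return from the first step, written out by hand:
# G_0 = r_0 + 0.9 * r_1 + 0.9 * 0.9 * r_2 = 1.0 + 0.0 + 1.62 = 2.62
manual_return = 1.0 + 0.9 * 0.0 + 0.9 * 0.9 * 2.0
returns = rlax.discounted_returns(rewards, discounts, jnp.zeros_like(rewards))
# returns[0] matches manual_return.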
Reinforcement Learning Algorithms
There are three prototypical families of reinforcement learning algorithms:
- those that estimate the value of states and actions, and infer a policy by inspection (e.g. by selecting the action with the highest estimated value);
- those that learn a model of the environment (capable of predicting the observations and rewards) and infer a policy via planning;
- those that parameterize a policy that can be directly executed.
In any case, policies, values or models are just functions. In deep reinforcement learning such functions are represented by a neural network. In this setting, it is common to formulate reinforcement learning updates as differentiable pseudo-loss functions (analogously to (un-)supervised learning). Under automatic differentiation, the original update rule is recovered.
Note, however, that these updates are only valid if the input data is sampled in the correct manner. For example, a policy gradient loss is only valid if the input trajectory is an unbiased sample from the current policy; i.e. the data are on-policy. The library cannot check or enforce such constraints. Links to papers describing how each operation is used are however provided in the functions' doc-strings.
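As a small sketch of the pseudo-loss pattern (with made-up scalar inputs), the squared TD(0) error can serve as a pseudo-loss whose gradient with respect to the predicted value recovers the usual TD update direction:
import jax
import jax.numpy as jnp
import rlax
def td_pseudo_loss(v_tm1, r_t, discount_t, v_t):
    # rlax.td_learning returns the TD error; l2_loss halves its square.
    td_error = rlax.td_learning(v_tm1, r_t, discount_t, v_t)
    return rlax.l2_loss(td_error)
# The gradient with respect to v_tm1 equals minus the TD error, so gradient
# descent on this pseudo-loss moves v_tm1 towards the TD target.
grad_fn = jax.grad(td_pseudo_loss)
g = grad_fn(jnp.array(1.0), jnp.array(0.5), jnp.array(0.9), jnp.array(1.2))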
Naming Conventions and Developer Guidelines
We define functions and operations for agents interacting with a single stream of experience. The JAX construct vmap can be used to apply these same functions to batches (e.g. to support replay and parallel data generation).
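For example, the single-transition rlax.q_learning update can be lifted to a (made-up) batch of transitions with jax.vmap:
import jax
import jax.numpy as jnp
import rlax
# A toy batch of two transitions with three actions each.
q_tm1 = jnp.array([[1.0, 2.0, 0.5], [0.3, 0.1, 0.9]])
a_tm1 = jnp.array([1, 2])
r_t = jnp.array([0.5, -1.0])
discount_t = jnp.array([0.9, 0.0])
q_t = jnp.array([[1.5, 0.2, 0.3], [0.0, 0.0, 0.0]])
# vmap maps the per-transition function over the leading (batch) dimension.
td_errors = jax.vmap(rlax.q_learning)(q_tm1, a_tm1, r_t, discount_t, q_t)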
Many functions consider policies, actions, rewards, and values in consecutive timesteps in order to compute their outputs. In this case the suffixes _tm1 and _t are often used to clarify on which step each input was generated, e.g.:
- q_tm1: the action value in the source state of a transition.
- a_tm1: the action that was selected in the source state.
- r_t: the resulting reward collected in the destination state.
- discount_t: the discount associated with the transition.
- q_t: the action values in the destination state.
Extensive testing is provided for each function. All tests should also verify the output of rlax functions when compiled to XLA using jax.jit and when performing batch operations using jax.vmap.
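A minimal sketch of such a consistency check (not taken from the actual test suite) might compare the eager and jit-compiled outputs of a function:
import jax
import jax.numpy as jnp
import numpy as np
import rlax
inputs = (jnp.array(1.0), jnp.array(0.5), jnp.array(0.9), jnp.array(1.2))
eager_out = rlax.td_learning(*inputs)
jitted_out = jax.jit(rlax.td_learning)(*inputs)
np.testing.assert_allclose(eager_out, jitted_out, rtol=1e-6)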
Citing RLax
This repository is part of the DeepMind JAX Ecosystem; to cite RLax please use the citation:
@software{deepmind2020jax,
title = {The {D}eep{M}ind {JAX} {E}cosystem},
author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
url = {http://github.com/deepmind},
year = {2020},
}