
google-deepmind/rlax


Top Related Projects

  • Acme: A library of reinforcement learning components and agents
  • Gym: A toolkit for developing and comparing reinforcement learning algorithms.
  • TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
  • Stable-baselines: A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
  • Gymnasium: An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)
  • Ignite: High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.

Quick Overview

RLax is a library of reinforcement learning building blocks developed by DeepMind. It provides a collection of common RL functions and algorithms implemented in JAX, designed to be efficient, flexible, and easy to use in research projects.

Pros

  • High performance due to JAX backend, enabling GPU/TPU acceleration
  • Modular design allows for easy customization and experimentation
  • Well-documented with clear examples and API references
  • Integrates seamlessly with other JAX-based libraries

Cons

  • Requires familiarity with JAX, which may have a learning curve for some users
  • Limited to reinforcement learning tasks, not a general-purpose machine learning library
  • May have fewer pre-built, end-to-end algorithms compared to some other RL libraries
  • Ongoing development may lead to occasional API changes

Code Examples

  1. Calculating the n-step return:
import jax
import rlax

rewards = jax.numpy.array([1.0, 0.5, 0.0, -1.0])
discounts = jax.numpy.array([0.9, 0.9, 0.9, 0.0])
values = jax.numpy.array([2.0, 1.5, 1.0, 0.5])

n_step_return = rlax.n_step_bootstrapped_returns(rewards, discounts, values, n=2)
print(n_step_return)
  2. Implementing epsilon-greedy exploration:
import jax
import rlax

q_values = jax.numpy.array([1.0, 2.0, 0.5, 1.5])
epsilon = 0.1

action = rlax.epsilon_greedy(epsilon).sample(jax.random.PRNGKey(0), q_values)
print(action)
  3. Computing the Huber loss:
import jax
import rlax

predictions = jax.numpy.array([1.0, 2.0, 3.0])
targets = jax.numpy.array([1.2, 1.8, 3.1])

# rlax.huber_loss takes the error itself, not (predictions, targets); returns per-element losses
loss = rlax.huber_loss(targets - predictions, delta=1.0)
print(loss)

Getting Started

To get started with RLax, first install the library:

pip install rlax

Then, import the library and start using its functions:

import jax
import rlax

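# Example: Compute discounted returns
rewards = jax.numpy.array([1.0, 2.0, 3.0, 4.0])
discounts = jax.numpy.array([0.9, 0.9, 0.9, 0.0])
bootstrap_values = jax.numpy.zeros_like(rewards)  # values to bootstrap from (unused here, as the last discount is 0)
returns = rlax.discounted_returns(rewards, discounts, bootstrap_values)
print(returns)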

For more detailed examples and API documentation, refer to the official RLax documentation.

Competitor Comparisons

Acme: A library of reinforcement learning components and agents

Pros of Acme

  • Provides a full-featured RL framework with agents, environments, and training loops
  • Offers distributed and single-process implementations for scalability
  • Includes pre-built agents and baselines for quick experimentation

Cons of Acme

  • More complex and heavyweight compared to RLax's focused approach
  • Steeper learning curve due to its comprehensive nature
  • May be overkill for simple RL tasks or research

Code Comparison

Acme (agent creation):

agent = dqn.DQN(
    environment_spec=env_spec,
    network=network,
    batch_size=batch_size,
    samples_per_insert=samples_per_insert,
    min_replay_size=min_replay_size
)

RLax (Q-learning update):

# rlax.q_learning returns the TD error (target gradients are stopped by default), not an updated Q-value
td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
loss = rlax.l2_loss(td_error)

RLax focuses on providing building blocks for RL algorithms, while Acme offers a complete framework for agent development and training. RLax is more flexible and lightweight, suitable for researchers who want fine-grained control over their implementations. Acme is better suited for those who need a full-featured, production-ready RL system with pre-built components and distributed capabilities.

Gym: A toolkit for developing and comparing reinforcement learning algorithms.

Pros of Gym

  • Widely adopted and supported by the RL community
  • Provides a diverse set of pre-built environments for testing and benchmarking
  • Offers a simple, consistent API for interacting with environments

Cons of Gym

  • Limited to environment simulation, lacking built-in RL algorithms
  • Less flexibility for custom environment creation compared to RLax

Code Comparison

Gym:

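import gym
env = gym.make('CartPole-v1')
observation = env.reset()  # classic (pre-0.26) Gym API
for _ in range(1000):
    action = env.action_space.sample()
    observation, reward, done, info = env.step(action)
    if done:  # start a new episode instead of stepping a finished one
        observation = env.reset()
env.close()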

RLax:

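import jax
import jax.numpy as jnp
import rlax

q_values = jnp.array([1.0, 2.0, 3.0])
# rlax.softmax returns a distribution object; sample an action from it
action = rlax.softmax(temperature=1.0).sample(jax.random.PRNGKey(0), q_values)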

RLax focuses on providing building blocks for RL algorithms, while Gym emphasizes environment simulation. RLax offers more flexibility for implementing custom RL algorithms, but requires more setup. Gym provides a ready-to-use framework for testing RL agents in various environments but may be less suitable for advanced algorithm development. The choice between the two depends on the specific needs of the project and the user's familiarity with JAX (for RLax) or more traditional Python libraries (for Gym).

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.

Pros of TF-Agents

  • More comprehensive, offering a full suite of RL algorithms and tools
  • Better integration with TensorFlow ecosystem and hardware acceleration
  • Extensive documentation and examples for various use cases

Cons of TF-Agents

  • Steeper learning curve due to its complexity and broader scope
  • Potentially slower development and experimentation cycles
  • Less flexibility for custom implementations compared to RLax

Code Comparison

RLax (simple Q-learning update):

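def q_learning_update(q_tm1, a_tm1, r_t, discount_t, q_t, learning_rate=0.1):
    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)  # r_t + discount_t * max(q_t) - q_tm1[a_tm1]
    return q_tm1.at[a_tm1].add(learning_rate * td_error)  # tabular update of the taken action only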

TF-Agents (Q-learning agent setup):

agent = dqn_agent.DqnAgent(
    time_step_spec,
    action_spec,
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

RLax focuses on providing low-level, flexible building blocks for RL algorithms, while TF-Agents offers a more complete, high-level framework for implementing and training RL agents. RLax is better suited for researchers and those who need fine-grained control, while TF-Agents is more appropriate for practitioners and those who want a ready-to-use solution.

Stable-baselines: A fork of OpenAI Baselines, implementations of reinforcement learning algorithms

Pros of Stable-baselines

  • More user-friendly and easier to get started with for beginners
  • Provides pre-implemented and optimized algorithms ready for use
  • Includes a wide range of popular RL algorithms out of the box

Cons of Stable-baselines

  • Less flexible and customizable compared to RLax
  • May have performance limitations for advanced or custom implementations
  • Potentially slower development cycle for new algorithms and features

Code Comparison

Stable-baselines example:

from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10000)

RLax example:

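import jax
import jax.numpy as jnp
import rlax

def loss_fn(params, inputs, target):
    prediction = inputs @ params  # linear value estimate, so the loss actually depends on params
    return jnp.mean(rlax.l2_loss(prediction, target))

grad_fn = jax.grad(loss_fn)  # gradient with respect to params (the first argument)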

The Stable-baselines code shows a higher-level API for quick implementation, while RLax provides lower-level building blocks for custom RL algorithms.

Gymnasium: An API standard for single-agent reinforcement learning environments, with popular reference environments and related utilities (formerly Gym)

Pros of Gymnasium

  • More comprehensive and diverse set of environments for reinforcement learning
  • Larger community and ecosystem, with more third-party extensions and integrations
  • Better documentation and tutorials for beginners

Cons of Gymnasium

  • Potentially slower performance due to Python implementation
  • Less focus on advanced RL algorithms and research-oriented features

Code Comparison

Gymnasium:

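import gymnasium as gym
env = gym.make("CartPole-v1")
observation, info = env.reset(seed=42)
for _ in range(1000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:  # episode over: reset before stepping again
        observation, info = env.reset()
env.close()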

RLax:

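import jax
import rlax
epsilon = 0.1
key = jax.random.PRNGKey(42)
value_key, action_key = jax.random.split(key)  # independent keys for independent random draws
q_values = jax.random.uniform(value_key, (4,))
action = rlax.epsilon_greedy(epsilon).sample(action_key, q_values)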

The code snippets highlight the different focus areas of the two libraries. Gymnasium provides a high-level interface for creating and interacting with environments, while RLax offers low-level building blocks for implementing RL algorithms using JAX.

Ignite: High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.

Pros of Ignite

  • More general-purpose, supporting various deep learning tasks beyond reinforcement learning
  • Extensive documentation and tutorials for easier onboarding
  • Larger community and more frequent updates

Cons of Ignite

  • Less specialized for reinforcement learning tasks
  • May have more overhead for simple RL projects
  • Steeper learning curve for beginners in deep learning

Code Comparison

Ignite example:

from ignite.engine import Engine, Events

def train_step(engine, batch):
    # Training logic here
    return loss

trainer = Engine(train_step)
trainer.run(data_loader, max_epochs=10)

RLax example:

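import jax
import rlax

def loss_fn(params, obs_tm1, a_tm1, r_t, discount_t, obs_t):
    q_tm1 = network(params, obs_tm1)  # `network`: a user-defined Q-network, not part of RLax
    q_t = network(params, obs_t)
    td_error = rlax.q_learning(q_tm1, a_tm1, r_t, discount_t, q_t)
    return rlax.l2_loss(td_error)

# vmap over the batch of transitions; params are shared across the batch (in_axes=None)
batched_loss = jax.vmap(loss_fn, in_axes=(None, 0, 0, 0, 0, 0))
loss = batched_loss(params, obs_tm1, a_tm1, r_t, discount_t, obs_t).mean()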

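Ignite provides a general engine for running training loops, whereas RLax provides only RL-specific operations (such as TD errors and losses) that users compose into their own training code. RLax is therefore more focused on reinforcement learning, while Ignite provides a more general framework for various deep learning tasks.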


README

RLax


RLax (pronounced "relax") is a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents. Full documentation can be found at rlax.readthedocs.io.

Installation

You can install the latest released version of RLax from PyPI via:

pip install rlax

or you can install the latest development version from GitHub:

pip install git+https://github.com/deepmind/rlax.git

All RLax code can then be just-in-time compiled for different hardware (e.g. CPU, GPU, TPU) using jax.jit.

To run the examples in examples/, you will also need to clone the repository and install the additional requirements: optax, haiku, and bsuite.

Content

The operations and functions provided are not complete algorithms, but implementations of reinforcement learning specific mathematical operations that are needed when building fully-functional agents capable of learning:

  • Values, including both state and action-values;
  • Values for non-linear generalizations of the Bellman equations;
  • Return Distributions, aka distributional value functions;
  • General Value Functions, for cumulants other than the main reward;
  • Policies, via policy-gradients in both continuous and discrete action spaces.

The library supports both on-policy and off-policy learning (i.e. learning from data sampled from a policy different from the agent's policy).
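Off-policy methods commonly correct for the mismatch between the behaviour policy μ that generated the data and the target policy π by reweighting each sample with the importance ratio π(a)/μ(a). The sketch below illustrates this general idea in plain Python on a hypothetical two-armed bandit; the function name and the policies are illustrative assumptions, not part of RLax's API.

```python
import random

def importance_weighted_mean(samples, pi, mu):
    """Estimate the expected reward under pi from (action, reward) pairs drawn under mu."""
    total = 0.0
    for action, reward in samples:
        total += (pi[action] / mu[action]) * reward  # per-sample importance ratio
    return total / len(samples)

# Hypothetical two-armed bandit with deterministic rewards.
rewards = {0: 1.0, 1: 0.0}
mu = {0: 0.5, 1: 0.5}  # behaviour policy that generated the data
pi = {0: 0.9, 1: 0.1}  # target policy we want to evaluate

rng = random.Random(0)
samples = []
for _ in range(100_000):
    action = 0 if rng.random() < mu[0] else 1
    samples.append((action, rewards[action]))

naive = sum(r for _, r in samples) / len(samples)     # estimates mu's value (about 0.5)
estimate = importance_weighted_mean(samples, pi, mu)  # estimates pi's value (about 0.9)
print(naive, estimate)
```

The plain sample mean evaluates the behaviour policy, while the reweighted mean targets the policy being learned about.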

See file-level and function-level doc-strings for the documentation of these functions and for references to the papers that introduced and/or used them.

Usage

See examples/ for examples of using some of the functions in RLax to implement a few simple reinforcement learning agents, and to demonstrate learning on BSuite's version of the Catch environment (a common unit test for agent development in the reinforcement learning literature).

Other examples of JAX reinforcement learning agents using rlax can be found in bsuite.

Background

Reinforcement learning studies the problem of a learning system (the agent), which must learn to interact with the universe it is embedded in (the environment).

Agent and environment interact on discrete steps. On each step the agent selects an action, and is provided in return a (partial) snapshot of the state of the environment (the observation), and a scalar feedback signal (the reward).
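The interaction described above can be sketched as a loop. Everything here is a hypothetical toy in plain Python: `CoinFlipEnv` stands in for an environment and a uniformly random function stands in for the agent's policy; real environments and agents expose richer interfaces.

```python
import random

class CoinFlipEnv:
    """Hypothetical toy environment: reward 1 if the action matches a hidden coin flip."""

    def __init__(self, seed=0, episode_length=5):
        self._rng = random.Random(seed)
        self._episode_length = episode_length

    def reset(self):
        self._t = 0
        self._coin = self._rng.randint(0, 1)
        return 0  # the initial observation carries no information about the coin

    def step(self, action):
        reward = 1.0 if action == self._coin else 0.0
        observation = self._coin  # partial view: the agent only sees the previous coin
        self._coin = self._rng.randint(0, 1)
        self._t += 1
        done = self._t >= self._episode_length
        return observation, reward, done

def run_episode(env, policy):
    observation = env.reset()
    rewards = []
    done = False
    while not done:
        action = policy(observation)                   # agent selects an action
        observation, reward, done = env.step(action)   # env returns observation and reward
        rewards.append(reward)
    return rewards

agent_rng = random.Random(1)
rewards = run_episode(CoinFlipEnv(), lambda obs: agent_rng.randint(0, 1))
print(rewards)
```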

The behaviour of the agent is characterized by a probability distribution over actions, conditioned on past observations of the environment (the policy). The agent seeks a policy that, from any given step, maximises the discounted cumulative reward that will be collected from that point onwards (the return).
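Because the return from step t satisfies G_t = r_t + γ·G_{t+1}, it can be computed for every step in a single backward pass over an episode. A minimal plain-Python sketch, assuming a constant discount γ and an episode that terminates after the last reward (no bootstrapping); `compute_returns` is an illustrative name, and RLax's own functions take per-step discounts instead.

```python
def compute_returns(rewards, gamma):
    """Return G_t = sum_k gamma**k * rewards[t + k] for every step t."""
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g  # recursive definition of the return
        returns[t] = g
    return returns

print(compute_returns([1.0, 0.0, 2.0], gamma=0.5))  # [1.5, 1.0, 2.0]
```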

Often the agent's policy or the environment dynamics are stochastic. In this case the return is a random variable, and the optimal agent's policy is typically more precisely specified as a policy that maximises the expectation of the return (the value), under the agent's and environment's stochasticity.
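Since the value is an expectation, averaging many sampled returns gives a Monte Carlo estimate of it. The sketch below assumes a hypothetical one-step episode whose reward is 1 with probability 0.3 and 0 otherwise, so the true value is 0.3; the function names are illustrative.

```python
import random

def sample_return(rng):
    # Hypothetical one-step episode: reward 1.0 with probability 0.3, else 0.0.
    return 1.0 if rng.random() < 0.3 else 0.0

def monte_carlo_value(num_episodes, seed=0):
    """Average of independently sampled returns, an unbiased estimate of the value."""
    rng = random.Random(seed)
    return sum(sample_return(rng) for _ in range(num_episodes)) / num_episodes

estimate = monte_carlo_value(100_000)
print(estimate)  # close to the true value 0.3
```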

Reinforcement Learning Algorithms

There are three prototypical families of reinforcement learning algorithms:

  1. those that estimate the value of states and actions, and infer a policy by inspection (e.g. by selecting the action with the highest estimated value);
  2. those that learn a model of the environment (capable of predicting the observations and rewards) and infer a policy via planning;
  3. those that parameterize a policy that can be directly executed.
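For the first family, inferring a policy "by inspection" typically means acting greedily with respect to the estimated action values. A minimal plain-Python sketch with hypothetical values; breaking ties in favour of the lowest index is an arbitrary choice made here.

```python
def greedy_action(q_values):
    """Index of the highest estimated action value (the first one on ties)."""
    best = 0
    for a in range(1, len(q_values)):
        if q_values[a] > q_values[best]:
            best = a
    return best

print(greedy_action([1.0, 2.0, 0.5, 2.0]))  # 1
```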

In any case, policies, values or models are just functions. In deep reinforcement learning such functions are represented by a neural network. In this setting, it is common to formulate reinforcement learning updates as differentiable pseudo-loss functions (analogously to (un-)supervised learning). Under automatic differentiation, the original update rule is recovered.
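For instance, take a tabular value estimate q and a target y treated as a constant (not differentiated through). The pseudo-loss L(q) = ½(y − q)² has gradient −(y − q), so one gradient-descent step with step size α gives q ← q + α(y − q), the familiar TD-style update. The plain-Python check below approximates the gradient with a finite difference; in JAX one would instead use jax.grad and wrap the target in jax.lax.stop_gradient. The numbers are hypothetical.

```python
def pseudo_loss(q, target):
    # The target is a plain constant here, mimicking a stop-gradient.
    return 0.5 * (target - q) ** 2

def numerical_grad(f, x, eps=1e-6):
    # Central finite-difference approximation of df/dx.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

q, target, alpha = 2.0, 3.0, 0.1
grad = numerical_grad(lambda v: pseudo_loss(v, target), q)
q_sgd = q - alpha * grad                  # gradient-descent step on the pseudo-loss
q_update_rule = q + alpha * (target - q)  # the original update rule
print(q_sgd, q_update_rule)
```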

Note, however, that the updates are only valid if the input data is sampled in the correct manner. For example, a policy gradient loss is only valid if the input trajectory is an unbiased sample from the current policy, i.e. the data are on-policy. The library cannot check or enforce such constraints; however, links to papers describing how each operation is used are provided in the functions' doc-strings.

Naming Conventions and Developer Guidelines

We define functions and operations for agents interacting with a single stream of experience. The JAX construct vmap can be used to apply these same functions to batches (e.g. to support replay and parallel data generation).

Many functions consider policies, actions, rewards and values in consecutive timesteps in order to compute their outputs. In this case the suffixes _t and _tm1 (read "t minus 1") are used to clarify on which step each input was generated, e.g.:

  • q_tm1: the action value in the source state of a transition.
  • a_tm1: the action that was selected in the source state.
  • r_t: the resulting rewards collected in the destination state.
  • discount_t: the discount associated with a transition.
  • q_t: the action values in the destination state.
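With these conventions, the one-step Q-learning TD error for a single transition is r_t + discount_t · max(q_t) − q_tm1[a_tm1]. The plain-Python sketch below computes it for one transition and then applies the same single-stream function to each element of a small batch; this per-example mapping is the semantics that jax.vmap vectorizes. The name `q_learning_td_error` and all numbers are hypothetical; RLax provides its own version as `rlax.q_learning`.

```python
def q_learning_td_error(q_tm1, a_tm1, r_t, discount_t, q_t):
    """One-step Q-learning TD error for a single transition."""
    target_tm1 = r_t + discount_t * max(q_t)
    return target_tm1 - q_tm1[a_tm1]

# A single transition.
print(q_learning_td_error([1.0, 2.0], 0, 1.0, 0.5, [3.0, 4.0]))  # 2.0

# A batch of transitions: apply the same single-stream function per element.
batch = {
    "q_tm1": [[1.0, 2.0], [0.0, 0.0]],
    "a_tm1": [0, 1],
    "r_t": [1.0, 0.0],
    "discount_t": [0.5, 0.0],  # a discount of 0 marks a terminal transition
    "q_t": [[3.0, 4.0], [5.0, 5.0]],
}
td_errors = [
    q_learning_td_error(*args)
    for args in zip(batch["q_tm1"], batch["a_tm1"], batch["r_t"],
                    batch["discount_t"], batch["q_t"])
]
print(td_errors)  # [2.0, 0.0]
```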

Extensive testing is provided for each function. All tests should also verify the output of rlax functions when compiled to XLA using jax.jit and when performing batch operations using jax.vmap.

Citing RLax

This repository is part of the DeepMind JAX Ecosystem; to cite RLax, please use the following citation:

@software{deepmind2020jax,
  title = {The {D}eep{M}ind {JAX} {E}cosystem},
  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
  url = {http://github.com/deepmind},
  year = {2020},
}