Top Related Projects
- tensorflow/tensorflow: An Open Source Machine Learning Framework for Everyone
- pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration
- microsoft/DeepSpeed: DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
- huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
- horovod/horovod: Distributed training framework for TensorFlow, Keras, PyTorch, and Apache MXNet.
- facebookresearch/fairseq: Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
Quick Overview
Flax is a neural network library for JAX that is designed for flexibility and ease of use. It provides a set of building blocks for machine learning research and development, focusing on simplicity and composability while leveraging JAX's powerful features like automatic differentiation and just-in-time compilation.
Pros
- Highly flexible and customizable, allowing researchers to easily implement and experiment with new ideas
- Seamless integration with JAX, providing access to its powerful features like automatic differentiation and GPU/TPU acceleration
- Clean and intuitive API, making it easy for both beginners and experienced practitioners to use
- Excellent documentation and growing community support
Cons
- Smaller ecosystem compared to more established frameworks like PyTorch or TensorFlow
- Learning curve for those unfamiliar with JAX's functional programming paradigm
- Limited pre-trained models and high-level APIs compared to some other frameworks
- Potential performance overhead for small models due to JIT compilation (see the sketch below)
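A rough sketch of that compile-time effect (illustrative only; actual timings depend on hardware):

import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  return jnp.dot(x, x)

x = jnp.ones((64, 64))

t0 = time.perf_counter()
f(x).block_until_ready()   # first call pays the compilation cost
t1 = time.perf_counter()
f(x).block_until_ready()   # later calls reuse the compiled computation
t2 = time.perf_counter()
print(f"first: {t1 - t0:.4f}s  second: {t2 - t1:.4f}s")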
Code Examples
- Creating a simple neural network:
import flax.linen as nn

class SimpleNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x
- Training a model using Flax and Optax:
import jax
import jax.numpy as jnp
import optax

def train_step(state, batch):
  # `state` is a flax.training.train_state.TrainState (see the sketch after these examples).
  def loss_fn(params):
    logits = state.apply_fn({'params': params}, batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['label'])
    return loss.mean()

  loss, grads = jax.value_and_grad(loss_fn)(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss
- Defining a custom layer:
class CustomLayer(nn.Module):
  features: int

  @nn.compact
  def __call__(self, inputs):
    # Declare a trainable parameter directly with self.param.
    weights = self.param('weights', nn.initializers.normal(),
                         (inputs.shape[-1], self.features))
    # Sinusoidal activation applied to the learned linear projection.
    return jnp.sin(jnp.dot(inputs, weights))
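The training example above assumes state is a flax.training.train_state.TrainState. A minimal sketch of wiring the pieces together, reusing SimpleNN from the first example (input shape and learning rate are illustrative):

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

model = SimpleNN()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optax.adam(1e-3),
)

# One update, given a batch dict with 'image' and 'label' arrays:
# state, loss = train_step(state, batch)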
Getting Started
To get started with Flax, first install it using pip:
pip install flax
Then, you can create a simple model and train it:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# Define your model
class MyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x

# Initialize the model
model = MyModel()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
params = variables['params']

# Create an optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Training loop (simplified)
def train_step(params, opt_state, batch):
  def loss_fn(params):
    logits = model.apply({'params': params}, batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['label']).mean()
    return loss

  loss, grads = jax.value_and_grad(loss_fn)(params)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

# Use train_step in your training loop (see the sketch below)
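A minimal sketch of the surrounding loop (the dataset iterator and epoch count are placeholders, not part of the example above):

jit_train_step = jax.jit(train_step)
num_epochs = 10                                  # illustrative

for epoch in range(num_epochs):
  for batch in train_batches:                    # placeholder iterator of {'image', 'label'} dicts
    params, opt_state, loss = jit_train_step(params, opt_state, batch)
  print(f"epoch {epoch}: loss {loss:.4f}")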
Competitor Comparisons
TensorFlow: An Open Source Machine Learning Framework for Everyone
Pros of TensorFlow
- Larger ecosystem with more tools, libraries, and community support
- Better production deployment options, including TensorFlow Serving
- More comprehensive documentation and learning resources
Cons of TensorFlow
- More complex API, steeper learning curve
- Slower development cycle and less flexibility for research
- Heavier and more resource-intensive
Code Comparison
Flax:
import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=12)(x)
    x = nn.relu(x)
    return nn.Dense(features=1)(x)
TensorFlow:
import tensorflow as tf

class MyModel(tf.keras.Model):
  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(12, activation='relu')
    self.dense2 = tf.keras.layers.Dense(1)

  def call(self, x):
    x = self.dense1(x)
    return self.dense2(x)
Flax offers a more functional and concise approach, while TensorFlow uses an object-oriented style with explicit layer definitions. Flax leverages JAX's automatic differentiation and JIT compilation, potentially offering better performance for certain workloads.
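To make the JIT point concrete, a small sketch reusing the Flax MyModel above (input shape is illustrative):

import jax

model = MyModel()
x = jnp.ones((8, 4))
variables = model.init(jax.random.PRNGKey(0), x)

# model.apply is a pure function of (variables, inputs), so it can be
# compiled once with jax.jit and reused.
fast_apply = jax.jit(model.apply)
y = fast_apply(variables, x)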
PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Pros of PyTorch
- Larger community and ecosystem, with more resources and third-party libraries
- Dynamic computational graphs, allowing for more flexible and intuitive debugging
- Easier to use for researchers and those new to deep learning
Cons of PyTorch
- Generally slower than Flax for large-scale training on TPUs
- Less optimized for JAX-based acceleration and functional programming paradigms
- More memory-intensive due to its eager execution model
Code Comparison
PyTorch example:
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.add(x, y)
Flax example:
import jax.numpy as jnp
from flax import linen as nn
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
z = jnp.add(x, y)
Both examples perform simple array addition, but Flax uses JAX's NumPy-like interface, while PyTorch uses its own tensor operations. Flax's approach is more aligned with functional programming principles, which can lead to better performance optimizations in certain scenarios.
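A small sketch of that functional style: gradients are taken of pure functions over explicit arrays rather than accumulated on tensors in place:

import jax

def loss_fn(x):
  return jnp.sum(x ** 2)

grads = jax.grad(loss_fn)(jnp.array([1.0, 2.0, 3.0]))
print(grads)   # [2. 4. 6.]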
DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
Pros of DeepSpeed
- Offers more advanced distributed training features, including ZeRO optimizer stages and pipeline parallelism
- Provides better memory efficiency, allowing training of larger models on limited hardware
- Includes a more comprehensive suite of optimization techniques for large-scale model training
Cons of DeepSpeed
- Steeper learning curve and more complex setup compared to Flax's simplicity
- Less integrated with JAX ecosystem, which may be a drawback for some researchers and developers
- Primarily focused on PyTorch, limiting flexibility for users of other frameworks
Code Comparison
Flax example:
import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=12)(x)
    return nn.relu(x)
DeepSpeed example:
import torch
import deepspeed

model = MyModel()  # a torch.nn.Module
# deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler).
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model, config_params=ds_config)
output = model_engine(input_data)
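The ds_config referenced above is a DeepSpeed configuration dict (or JSON file); a minimal sketch of what it might contain (values are illustrative):

ds_config = {
    "train_batch_size": 32,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
}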
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Pros of Transformers
- Extensive pre-trained model library for various NLP tasks
- User-friendly API with high-level abstractions for easy implementation
- Strong community support and frequent updates
Cons of Transformers
- Less flexible for custom model architectures
- Potentially higher memory usage due to pre-built models
- Steeper learning curve for advanced customization
Code Comparison
Transformers example:
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
Flax example:
from flax import linen as nn
import jax.numpy as jnp

class MyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=12)(x)
    return nn.relu(x)
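Transformers also ships Flax/JAX model classes, so the BERT example above has a direct Flax counterpart; a sketch using the same checkpoint:

from transformers import BertTokenizer, FlaxBertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = FlaxBertModel.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
outputs = model(**inputs)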
Horovod: Distributed training framework for TensorFlow, Keras, PyTorch, and Apache MXNet.
Pros of Horovod
- Supports multiple deep learning frameworks (TensorFlow, PyTorch, MXNet)
- Designed for distributed training across multiple GPUs and nodes
- Integrates well with existing codebases and requires minimal code changes
Cons of Horovod
- Steeper learning curve for beginners compared to Flax
- Less focus on functional programming paradigms
- May have more overhead for single-machine training scenarios
Code Comparison
Horovod (with TensorFlow):
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
optimizer = tf.optimizers.Adam(0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)
Flax:
import optax

# Optimizers come from Optax (the older flax.optim API has been removed).
optimizer = optax.adam(learning_rate=0.001)
# Data parallelism is expressed by wrapping the train step in jax.pmap
# rather than by wrapping the optimizer (see the sketch below).
Horovod focuses on distributed training across multiple devices, while Flax emphasizes functional programming and JAX integration. Horovod requires explicit initialization and wrapping of optimizers, whereas Flax leverages JAX's parallelization capabilities more seamlessly. Flax's approach may be more intuitive for those familiar with JAX, while Horovod offers broader framework support and established distributed training patterns.
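A sketch of the pmap approach mentioned above, assuming a train_step like the one in the Getting Started section and per-device batch shards (gradient averaging with jax.lax.pmean inside the step is omitted for brevity):

import jax

# Compile and replicate the step across all local devices.
p_train_step = jax.pmap(train_step, axis_name='batch')

# Replicate parameters and optimizer state so each device holds a copy.
devices = jax.local_devices()
replicated_params = jax.device_put_replicated(params, devices)
replicated_opt_state = jax.device_put_replicated(opt_state, devices)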
fairseq: Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
Pros of fairseq
- More comprehensive toolkit for sequence modeling tasks
- Extensive support for machine translation and other NLP tasks
- Larger community and more extensive documentation
Cons of fairseq
- Steeper learning curve due to its complexity
- Less flexibility for general-purpose machine learning tasks
- Slower development cycle compared to Flax
Code Comparison
fairseq example:
from fairseq.models.transformer import TransformerModel

model = TransformerModel.from_pretrained('/path/to/model', checkpoint_file='model.pt')
translated = model.translate('Hello world!')  # handles tokenization internally
print(translated)
Flax example:
import jax
import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=12)(x)
    return nn.relu(x)

model = MyModel()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)
fairseq is more focused on NLP tasks and provides pre-built models, while Flax offers a more flexible, low-level approach for building custom neural networks using JAX. fairseq is better suited for researchers working on specific NLP tasks, while Flax is more versatile for general machine learning applications.
README
Flax: A neural network library and ecosystem for JAX designed for flexibility
Overview | Quick install | What does Flax look like? | Documentation
📣 NEW: Check out the NNX API!
This README is a very short intro. To learn everything you need to know about Flax, refer to our full documentation.
Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.
Flax is being used by a growing community of hundreds of folks in various Alphabet research departments for their daily work, as well as a growing community of open source projects.
The Flax team's mission is to serve the growing JAX neural network research ecosystem -- both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. We are in the process of moving some remaining internal design docs and conversation threads to GitHub discussions, issues and pull requests. We hope to increasingly engage with the needs and clarifications of the broader ecosystem. Please let us know how we can help!
Please report any feature requests, issues, questions or concerns in our discussion forum, or just let us know what you're working on!
We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.
In case you want to reach us directly, we're at flax-dev@google.com.
Overview
Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.
Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:
- Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout (a brief usage sketch follows this list)
- Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device
- Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging
- Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
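For example, stateful and stochastic layers such as BatchNorm and Dropout are used with explicit mutable collections and RNG keys; a brief sketch (layer sizes are illustrative):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
  @nn.compact
  def __call__(self, x, train: bool):
    x = nn.Dense(16)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.relu(x)
    x = nn.Dropout(rate=0.5, deterministic=not train)(x)
    return x

model = Block()
x = jnp.ones((4, 8))
variables = model.init(jax.random.key(0), x, train=False)
y, updates = model.apply(
    variables, x, train=True,
    rngs={'dropout': jax.random.key(1)},   # RNG for Dropout
    mutable=['batch_stats'])               # returns updated BatchNorm statistics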
Quick install
You will need Python 3.6 or later, and a working JAX installation (with or without GPU support - refer to the instructions). For a CPU-only version of JAX:
pip install --upgrade pip # To support manylinux2010 wheels.
pip install --upgrade jax jaxlib # CPU-only
Then, install Flax from PyPi:
pip install flax
To upgrade to the latest version of Flax, you can use:
pip install --upgrade git+https://github.com/google/flax.git
To install additional dependencies (like matplotlib) that are required by some examples but not installed by default, you can use:
pip install "flax[all]"
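To verify the installation from Python:

import flax
print(flax.__version__)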
What does Flax look like?
We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.
To learn more about the Module abstraction, check out our docs, which include a broad intro to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.
from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)
class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(tuple(self.decoder_widths) + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

model = AutoEncoder(encoder_widths=[20, 10, 5],
                    decoder_widths=[5, 10, 20],
                    input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
🤗 Hugging Face
In-detail examples to train and evaluate a variety of Flax models for Natural Language Processing, Computer Vision, and Speech Recognition are actively maintained in the 🤗 Transformers repository.
As of October 2021, the 19 most-used Transformer architectures are supported in Flax and over 5000 pretrained checkpoints in Flax have been uploaded to the 🤗 Hub.
Citing Flax
To cite this repository:
@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.9.0},
  year = {2024},
}
In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.
Note
Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.