
google/flax

Flax is a neural network library for JAX that is designed for flexibility.


Top Related Projects

  • TensorFlow (185,446 ⭐): An Open Source Machine Learning Framework for Everyone
  • PyTorch (82,049 ⭐): Tensors and Dynamic neural networks in Python with strong GPU acceleration
  • DeepSpeed (34,658 ⭐): A deep learning optimization library that makes distributed training and inference easy, efficient, and effective
  • 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX
  • Horovod (14,221 ⭐): Distributed training framework for TensorFlow, Keras, PyTorch, and Apache MXNet
  • fairseq (30,331 ⭐): Facebook AI Research Sequence-to-Sequence Toolkit written in Python

Quick Overview

Flax is a neural network library for JAX that is designed for flexibility and ease of use. It provides a set of building blocks for machine learning research and development, focusing on simplicity and composability while leveraging JAX's powerful features like automatic differentiation and just-in-time compilation.
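
Because a Flax model's apply function is a pure function of its parameters, it composes directly with JAX transformations such as jit and grad. A minimal sketch using a single Dense layer (illustrative only, not taken from the project docs):

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))

fast_apply = jax.jit(model.apply)  # JIT-compile the forward pass
grads = jax.grad(lambda v, x: fast_apply(v, x).sum())(variables, jnp.ones((1, 4)))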

Pros

  • Highly flexible and customizable, allowing researchers to easily implement and experiment with new ideas
  • Seamless integration with JAX, providing access to its powerful features like automatic differentiation and GPU/TPU acceleration
  • Clean and intuitive API, making it easy for both beginners and experienced practitioners to use
  • Excellent documentation and growing community support

Cons

  • Smaller ecosystem compared to more established frameworks like PyTorch or TensorFlow
  • Learning curve for those unfamiliar with JAX's functional programming paradigm
  • Limited pre-trained models and high-level APIs compared to some other frameworks
  • JIT compilation adds startup overhead, which can outweigh its benefits for very small models or rapidly changing input shapes

Code Examples

  1. Creating a simple neural network:
import flax.linen as nn

class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
  2. Training a model using Flax and Optax (here state is a flax.training.train_state.TrainState; a sketch of how to create one follows this list):
import jax
import jax.numpy as jnp
import optax

def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['label'])
        return loss.mean()
    
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
  3. Defining a custom layer with learnable parameters:
class CustomLayer(nn.Module):
    features: int

    @nn.compact
    def __call__(self, inputs):
        # Learnable parameters are declared with self.param(name, init_fn, shape)
        weights = self.param('weights', nn.initializers.normal(), (inputs.shape[-1], self.features))
        bias = self.param('bias', nn.initializers.zeros, (self.features,))
        return jnp.dot(inputs, weights) + bias
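
Example 2 above assumes a state object of type flax.training.train_state.TrainState, which bundles the apply function, the parameters, and an Optax optimizer. A minimal sketch of creating one (the input size of 784 is just an illustrative choice):

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

model = SimpleNN()  # the module from example 1
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))['params']
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(learning_rate=1e-3),
)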

Getting Started

To get started with Flax, first install it using pip:

pip install flax

Then, you can create a simple model and train it:

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# Define your model
class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

# Initialize the model
model = MyModel()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))['params']  # keep only the 'params' collection

# Create an optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Training loop (simplified)
def train_step(params, opt_state, batch):
    def loss_fn(params):
        logits = model.apply({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['label']).mean()
        return loss
    
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Use train_step in your training loop
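
For completeness, a minimal sketch of that loop, assuming a train_loader that yields batches as dicts of arrays (the loader itself is hypothetical and not part of Flax):

# In practice the step function is usually JIT-compiled:
train_step = jax.jit(train_step)

for epoch in range(10):
    for batch in train_loader:  # any iterable of {'image': ..., 'label': ...} dicts
        params, opt_state, loss = train_step(params, opt_state, batch)
    print(f"epoch {epoch}: loss {float(loss):.4f}")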

Competitor Comparisons

TensorFlow (185,446 ⭐)

An Open Source Machine Learning Framework for Everyone

Pros of TensorFlow

  • Larger ecosystem with more tools, libraries, and community support
  • Better production deployment options, including TensorFlow Serving
  • More comprehensive documentation and learning resources

Cons of TensorFlow

  • More complex API, steeper learning curve
  • Slower development cycle and less flexibility for research
  • Heavier and more resource-intensive

Code Comparison

Flax:

import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=12)(x)
    x = nn.relu(x)
    return nn.Dense(features=1)(x)

TensorFlow:

import tensorflow as tf

class MyModel(tf.keras.Model):
  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(12, activation='relu')
    self.dense2 = tf.keras.layers.Dense(1)

  def call(self, x):
    x = self.dense1(x)
    return self.dense2(x)

Flax offers a more functional and concise approach, while TensorFlow uses an object-oriented style with explicit layer definitions. Flax leverages JAX's automatic differentiation and JIT compilation, potentially offering better performance for certain workloads.

PyTorch (82,049 ⭐)

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Pros of PyTorch

  • Larger community and ecosystem, with more resources and third-party libraries
  • Dynamic computational graphs, allowing for more flexible and intuitive debugging
  • Easier to use for researchers and those new to deep learning

Cons of PyTorch

  • Generally slower than Flax for large-scale training on TPUs
  • Less optimized for JAX-based acceleration and functional programming paradigms
  • More memory-intensive due to its eager execution model

Code Comparison

PyTorch example:

import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.add(x, y)

Flax example:

import jax.numpy as jnp
from flax import linen as nn

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
z = jnp.add(x, y)

Both examples perform simple array addition, but Flax uses JAX's NumPy-like interface, while PyTorch uses its own tensor operations. Flax's approach is more aligned with functional programming principles, which can lead to better performance optimizations in certain scenarios.

DeepSpeed (34,658 ⭐)

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.

Pros of DeepSpeed

  • Offers more advanced distributed training features, including ZeRO optimizer stages and pipeline parallelism
  • Provides better memory efficiency, allowing training of larger models on limited hardware
  • Includes a more comprehensive suite of optimization techniques for large-scale model training

Cons of DeepSpeed

  • Steeper learning curve and more complex setup compared to Flax's simplicity
  • Less integrated with JAX ecosystem, which may be a drawback for some researchers and developers
  • Primarily focused on PyTorch, limiting flexibility for users of other frameworks

Code Comparison

Flax example:

import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=12)(x)
        return nn.relu(x)

DeepSpeed example:

import torch
import deepspeed

model = MyModel()  # any torch.nn.Module; ds_config is a DeepSpeed config dict
engine, optimizer, _, _ = deepspeed.initialize(model=model, config_params=ds_config)
output = engine(input_data)

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

Pros of Transformers

  • Extensive pre-trained model library for various NLP tasks
  • User-friendly API with high-level abstractions for easy implementation
  • Strong community support and frequent updates

Cons of Transformers

  • Less flexible for custom model architectures
  • Potentially higher memory usage due to pre-built models
  • Steeper learning curve for advanced customization

Code Comparison

Transformers example:

from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)

Flax example:

from flax import linen as nn
import jax.numpy as jnp

class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=12)(x)
        return nn.relu(x)

Horovod (14,221 ⭐)

Distributed training framework for TensorFlow, Keras, PyTorch, and Apache MXNet.

Pros of Horovod

  • Supports multiple deep learning frameworks (TensorFlow, PyTorch, MXNet)
  • Designed for distributed training across multiple GPUs and nodes
  • Integrates well with existing codebases and requires minimal code changes

Cons of Horovod

  • Steeper learning curve for beginners compared to Flax
  • Less focus on functional programming paradigms
  • May have more overhead for single-machine training scenarios

Code Comparison

Horovod (with TensorFlow):

import tensorflow as tf
import horovod.tensorflow as hvd
hvd.init()
optimizer = tf.optimizers.Adam(0.001 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)

Flax:

import jax
import optax

optimizer = optax.adam(learning_rate=0.001)
# multi-device data parallelism comes from jax.pmap rather than an optimizer wrapper
p_train_step = jax.pmap(train_step, axis_name='batch')  # see the fuller sketch below

Horovod focuses on distributed training across multiple devices, while Flax emphasizes functional programming and JAX integration. Horovod requires explicit initialization and wrapping of optimizers, whereas Flax leverages JAX's parallelization capabilities more seamlessly. Flax's approach may be more intuitive for those familiar with JAX, while Horovod offers broader framework support and established distributed training patterns.
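
As a rough sketch of that contrast (assuming a TrainState as in the earlier examples, replicated across devices with flax.jax_utils.replicate), the JAX-side analogue of Horovod's allreduce is to pmap the step function and average gradients with lax.pmean:

import jax
from jax import lax
import optax

def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        return optax.softmax_cross_entropy_with_integer_labels(
            logits, batch['label']).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    grads = lax.pmean(grads, axis_name='batch')  # allreduce-style gradient averaging
    return state.apply_gradients(grads=grads), lax.pmean(loss, axis_name='batch')

p_train_step = jax.pmap(train_step, axis_name='batch')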

fairseq (30,331 ⭐)

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

Pros of fairseq

  • More comprehensive toolkit for sequence modeling tasks
  • Extensive support for machine translation and other NLP tasks
  • Larger community and more extensive documentation

Cons of fairseq

  • Steeper learning curve due to its complexity
  • Less flexibility for general-purpose machine learning tasks
  • Slower development cycle compared to Flax

Code Comparison

fairseq example:

from fairseq.models.transformer import TransformerModel

model = TransformerModel.from_pretrained('/path/to/model', checkpoint_file='model.pt')
translated = model.translate('Hello world!')  # tokenizes, translates and detokenizes in one call
print(translated)

Flax example:

import jax
import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=12)(x)
        return nn.relu(x)

model = MyModel()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)

fairseq is more focused on NLP tasks and provides pre-built models, while Flax offers a more flexible, low-level approach for building custom neural networks using JAX. fairseq is better suited for researchers working on specific NLP tasks, while Flax is more versatile for general machine learning applications.


README


Flax: A neural network library and ecosystem for JAX designed for flexibility


Overview | Quick install | What does Flax look like? | Documentation

📣 NEW: Check out the NNX API!

This README is a very short intro. To learn everything you need to know about Flax, refer to our full documentation.

Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.

Flax is being used by a growing community of hundreds of folks in various Alphabet research departments for their daily work, as well as a growing community of open source projects.

The Flax team's mission is to serve the growing JAX neural network research ecosystem -- both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. We are in the process of moving some remaining internal design docs and conversation threads to GitHub discussions, issues and pull requests. We hope to increasingly engage with the needs and clarifications of the broader ecosystem. Please let us know how we can help!

Please report any feature requests, issues, questions or concerns in our discussion forum, or just let us know what you're working on!

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

  • Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device (see the short serialization sketch after this list)

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b
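
As a quick, hedged illustration of the serialization utilities mentioned above (here params stands for any Flax variables PyTree, e.g. the result of model.init):

from flax import serialization

raw_bytes = serialization.to_bytes(params)               # msgpack-serialize the parameter PyTree
restored = serialization.from_bytes(params, raw_bytes)   # restore into a PyTree of the same structure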

Quick install

You will need Python 3.6 or later, and a working JAX installation (with or without GPU support - refer to the instructions). For a CPU-only version of JAX:

pip install --upgrade pip # To support manylinux2010 wheels.
pip install --upgrade jax jaxlib # CPU-only

Then, install Flax from PyPi:

pip install flax

To upgrade to the latest version of Flax, you can use:

pip install --upgrade git+https://github.com/google/flax.git

To install optional extra dependencies (like matplotlib) that some of the examples require but that are not installed by default, you can use:

pip install "flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the Module abstraction, check out our docs and our broad intro to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

from typing import Sequence

import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

model = MLP([12, 8, 4])
batch = jnp.ones((32, 10))
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

model = CNN()
batch = jnp.ones((32, 64, 64, 10))  # (N, H, W, C) format
variables = model.init(jax.random.key(0), batch)
output = model.apply(variables, batch)

class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Sequence[int]

  def setup(self):
    input_dim = np.prod(self.input_shape)
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(tuple(self.decoder_widths) + (input_dim,))  # tuple() so list-valued widths concatenate cleanly

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x

model = AutoEncoder(encoder_widths=[20, 10, 5],
                    decoder_widths=[5, 10, 20],
                    input_shape=(12,))
batch = jnp.ones((16, 12))
variables = model.init(jax.random.key(0), batch)
encoded = model.apply(variables, batch, method=model.encode)
decoded = model.apply(variables, encoded, method=model.decode)
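
Since variables in each of these examples is just a nested PyTree of arrays, it can be inspected with standard JAX tree utilities, for example to print every parameter's shape:

import jax

print(jax.tree_util.tree_map(jnp.shape, variables))  # prints the nested dict of parameter shapes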

🤗 Hugging Face

In-detail examples to train and evaluate a variety of Flax models for Natural Language Processing, Computer Vision, and Speech Recognition are actively maintained in the 🤗 Transformers repository.

As of October 2021, the 19 most-used Transformer architectures are supported in Flax and over 5000 pretrained checkpoints in Flax have been uploaded to the 🤗 Hub.
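
Loading one of those Flax checkpoints from the Hub looks roughly like the following sketch, which uses the transformers package rather than Flax itself:

from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, Flax!", return_tensors="np")  # NumPy arrays for the Flax model
outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (batch, sequence_length, hidden_size)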

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.9.0},
  year = {2024},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.