google/flax

Flax is a neural network library for JAX that is designed for flexibility.

Top Related Projects

  • tensorflow/tensorflow (186,879 stars): An Open Source Machine Learning Framework for Everyone
  • pytorch/pytorch (85,015 stars): Tensors and Dynamic neural networks in Python with strong GPU acceleration
  • microsoft/DeepSpeed (35,868 stars): DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
  • huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
  • horovod/horovod (14,221 stars): Distributed training framework for TensorFlow, Keras, PyTorch, and Apache MXNet.
  • facebookresearch/fairseq (30,331 stars): Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

Quick Overview

Flax is a neural network library for JAX that is designed for flexibility and ease of use. It provides a set of building blocks for machine learning research and development, focusing on simplicity and composability while leveraging JAX's powerful features like automatic differentiation and just-in-time compilation.
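
For instance, a Flax model's forward pass composes directly with JAX transformations. The short sketch below (a toy single-layer model, not taken from the Flax docs) jit-compiles the forward pass and differentiates a loss with respect to the parameters:

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=4)                       # a single dense layer as a stand-in model
x = jnp.ones((2, 3))
params = model.init(jax.random.PRNGKey(0), x)      # initialize parameters from a PRNG key

forward = jax.jit(model.apply)                     # just-in-time compile the forward pass
loss_fn = lambda p: jnp.mean(forward(p, x) ** 2)   # a toy loss over the outputs
grads = jax.grad(loss_fn)(params)                  # gradients via automatic differentiation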

Pros

  • Highly flexible and customizable, allowing researchers to easily implement and experiment with new ideas
  • Seamless integration with JAX, providing access to its powerful features like automatic differentiation and GPU/TPU acceleration
  • Clean and intuitive API, making it easy for both beginners and experienced practitioners to use
  • Excellent documentation and growing community support

Cons

  • Smaller ecosystem compared to more established frameworks like PyTorch or TensorFlow
  • Learning curve for those unfamiliar with JAX's functional programming paradigm
  • Limited pre-trained models and high-level APIs compared to some other frameworks
  • Potential performance overhead for small models due to JIT compilation

Code Examples

  1. Creating a simple neural network:
import flax.linen as nn

class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x
  2. Training a model using Flax and Optax (a sketch for constructing the training state follows these examples):
import jax
import jax.numpy as jnp
import optax

def train_step(state, batch):
    def loss_fn(params):
        logits = state.apply_fn({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['label'])
        return loss.mean()
    
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss
  3. Defining a custom layer:
class CustomLayer(nn.Module):
    features: int

    @nn.compact
    def __call__(self, inputs):
        # Declare trainable parameters explicitly with self.param
        kernel = self.param('kernel', nn.initializers.normal(), (inputs.shape[-1], self.features))
        bias = self.param('bias', nn.initializers.zeros, (self.features,))
        return jnp.dot(inputs, kernel) + bias
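
The train_step in example 2 expects a flax.training.train_state.TrainState. Below is a minimal sketch of how that state could be constructed for the SimpleNN from example 1 (the input shape and learning rate are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

model = SimpleNN()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))['params']
state = train_state.TrainState.create(
    apply_fn=model.apply,     # used by train_step via state.apply_fn
    params=params,
    tx=optax.adam(1e-3),      # the optimizer consumed by state.apply_gradients
)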

Getting Started

To get started with Flax, first install it using pip:

pip install flax

Then, you can create a simple model and train it:

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

# Define your model
class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

# Initialize the model
model = MyModel()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))['params']  # keep only the 'params' collection

# Create an optimizer
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(params)

# Training loop (simplified)
def train_step(params, opt_state, batch):
    def loss_fn(params):
        logits = model.apply({'params': params}, batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch['label']).mean()
        return loss
    
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Use train_step in your training loop
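
# For example, one might jit-compile the step and feed it batches
# (the random batch below is only an illustrative stand-in for real data):
train_step = jax.jit(train_step)

batch = {
    'image': jax.random.normal(jax.random.PRNGKey(1), (32, 784)),
    'label': jnp.zeros((32,), dtype=jnp.int32),
}
for step in range(100):
    params, opt_state, loss = train_step(params, opt_state, batch)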

Competitor Comparisons

tensorflow/tensorflow (186,879 stars)

An Open Source Machine Learning Framework for Everyone

Pros of TensorFlow

  • Larger ecosystem with more tools, libraries, and community support
  • Better production deployment options, including TensorFlow Serving
  • More comprehensive documentation and learning resources

Cons of TensorFlow

  • More complex API, steeper learning curve
  • Slower development cycle and less flexibility for research
  • Heavier and more resource-intensive

Code Comparison

Flax:

import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=12)(x)
    x = nn.relu(x)
    return nn.Dense(features=1)(x)

TensorFlow:

import tensorflow as tf

class MyModel(tf.keras.Model):
  def __init__(self):
    super().__init__()
    self.dense1 = tf.keras.layers.Dense(12, activation='relu')
    self.dense2 = tf.keras.layers.Dense(1)

  def call(self, x):
    x = self.dense1(x)
    return self.dense2(x)

Flax offers a more functional and concise approach, while TensorFlow uses an object-oriented style with explicit layer definitions. Flax leverages JAX's automatic differentiation and JIT compilation, potentially offering better performance for certain workloads.
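
As a small illustration of that last point (a sketch that assumes the MyModel class from the Flax snippet above), the whole forward pass can be compiled with a single jax.jit call:

import jax
import jax.numpy as jnp

model = MyModel()                        # the Flax model defined in the snippet above
x = jnp.ones((1, 4))                     # toy input
variables = model.init(jax.random.PRNGKey(0), x)
apply_jit = jax.jit(model.apply)         # XLA-compiled forward pass
y = apply_jit(variables, x)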

pytorch/pytorch (85,015 stars)

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Pros of PyTorch

  • Larger community and ecosystem, with more resources and third-party libraries
  • Dynamic computational graphs, allowing for more flexible and intuitive debugging
  • Easier to use for researchers and those new to deep learning

Cons of PyTorch

  • Generally slower than Flax for large-scale training on TPUs
  • Less optimized for JAX-based acceleration and functional programming paradigms
  • More memory-intensive due to its eager execution model

Code Comparison

PyTorch example:

import torch

x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.add(x, y)

Flax example:

import jax.numpy as jnp
from flax import linen as nn

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])
z = jnp.add(x, y)

Both examples perform simple array addition, but the Flax version simply uses JAX's NumPy-like interface (jnp), since Flax delegates array operations to JAX, while PyTorch uses its own tensor operations. Flax's approach is more aligned with functional programming principles, which can enable better performance optimizations in certain scenarios.
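
To make the functional angle concrete, here is a minimal sketch (not tied to either snippet above): in JAX, gradients come from transforming a pure function rather than from tracking operations on tensor objects.

import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(x ** 2)      # a pure function of its inputs

grad_loss = jax.grad(loss)      # a new function computing dloss/dx
print(grad_loss(jnp.array([1.0, 2.0, 3.0])))   # [2. 4. 6.]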

microsoft/DeepSpeed (35,868 stars)

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.

Pros of DeepSpeed

  • Offers more advanced distributed training features, including ZeRO optimizer stages and pipeline parallelism
  • Provides better memory efficiency, allowing training of larger models on limited hardware
  • Includes a more comprehensive suite of optimization techniques for large-scale model training

Cons of DeepSpeed

  • Steeper learning curve and more complex setup compared to Flax's simplicity
  • Less integrated with JAX ecosystem, which may be a drawback for some researchers and developers
  • Primarily focused on PyTorch, limiting flexibility for users of other frameworks

Code Comparison

Flax example:

import jax.numpy as jnp
from flax import linen as nn

class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=12)(x)
        return nn.relu(x)

DeepSpeed example:

import torch
import deepspeed

model = MyModel()  # a torch.nn.Module defined elsewhere
# deepspeed.initialize returns a tuple: (engine, optimizer, dataloader, lr_scheduler)
model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)
output = model_engine(input_data)

huggingface/transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.

Pros of Transformers

  • Extensive pre-trained model library for various NLP tasks
  • User-friendly API with high-level abstractions for easy implementation
  • Strong community support and frequent updates

Cons of Transformers

  • Less flexible for custom model architectures
  • Potentially higher memory usage due to pre-built models
  • Steeper learning curve for advanced customization

Code Comparison

Transformers example:

from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)

Flax example:

from flax import linen as nn
import jax.numpy as jnp

class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=12)(x)
        return nn.relu(x)

horovod/horovod (14,221 stars)

Distributed training framework for TensorFlow, Keras, PyTorch, and Apache MXNet.

Pros of Horovod

  • Supports multiple deep learning frameworks (TensorFlow, PyTorch, MXNet)
  • Designed for distributed training across multiple GPUs and nodes
  • Integrates well with existing codebases and requires minimal code changes

Cons of Horovod

  • Steeper learning curve for beginners compared to Flax
  • Less focus on functional programming paradigms
  • May have more overhead for single-machine training scenarios

Code Comparison

Horovod (with TensorFlow):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
optimizer = tf.optimizers.Adam(0.001 * hvd.size())   # scale learning rate by number of workers
optimizer = hvd.DistributedOptimizer(optimizer)      # wrap with Horovod's allreduce-based optimizer

Flax:

import optax

# flax.optim has been removed; Optax is the standard optimizer library for Flax/JAX
optimizer = optax.adam(learning_rate=0.001)
# parallelism comes from JAX itself, e.g. jax.pmap over a train step (see the sketch below)

Horovod focuses on distributed training across multiple devices, while Flax emphasizes functional programming and JAX integration. Horovod requires explicit initialization and wrapping of optimizers, whereas Flax leverages JAX's parallelization capabilities more seamlessly. Flax's approach may be more intuitive for those familiar with JAX, while Horovod offers broader framework support and established distributed training patterns.
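
As a hedged sketch of what that looks like in practice (the model, names, and shapes here are illustrative, not Flax API requirements), the rough analog of Horovod's allreduce-wrapped optimizer is a jax.lax.pmean inside a pmapped train step:

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-3)

def train_step(params, opt_state, batch):
    def loss_fn(p):
        preds = jnp.dot(batch['x'], p['w'])              # stand-in linear model
        return jnp.mean((preds - batch['y']) ** 2)

    grads = jax.grad(loss_fn)(params)
    grads = jax.lax.pmean(grads, axis_name='batch')      # average gradients across devices (allreduce)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

p_train_step = jax.pmap(train_step, axis_name='batch')   # one replica per device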

facebookresearch/fairseq (30,331 stars)

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

Pros of fairseq

  • More comprehensive toolkit for sequence modeling tasks
  • Extensive support for machine translation and other NLP tasks
  • Larger community and more extensive documentation

Cons of fairseq

  • Steeper learning curve due to its complexity
  • Less flexibility for general-purpose machine learning tasks
  • Slower development cycle compared to Flax

Code Comparison

fairseq example:

from fairseq.models.transformer import TransformerModel

model = TransformerModel.from_pretrained('/path/to/model', checkpoint_file='model.pt')
translated = model.translate('Hello world!')  # encodes, generates and decodes internally
print(translated)

Flax example:

from flax import linen as nn
import jax
import jax.numpy as jnp

class MyModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=12)(x)
        return nn.relu(x)

model = MyModel()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)

fairseq is more focused on NLP tasks and provides pre-built models, while Flax offers a more flexible, low-level approach for building custom neural networks using JAX. fairseq is better suited for researchers working on specific NLP tasks, while Flax is more versatile for general machine learning applications.

README

Flax: A neural network library and ecosystem for JAX designed for flexibility

Overview | Quick install | What does Flax look like? | Documentation

Released in 2024, Flax NNX is a new simplified Flax API that is designed to make it easier to create, inspect, debug, and analyze neural networks in JAX. It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, enabling reference sharing and mutability.
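
A minimal sketch of what that means in practice (the layer sizes are arbitrary): NNX modules are ordinary Python objects, so the same submodule can be shared by reference and its parameters can be inspected or mutated in place.

from flax import nnx

class TwoHeads(nnx.Module):
  def __init__(self, shared: nnx.Linear):
    self.head_a = shared        # both attributes reference the same layer object
    self.head_b = shared

shared = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
model = TwoHeads(shared)
assert model.head_a is model.head_b                           # reference sharing
model.head_a.kernel.value = model.head_a.kernel.value * 0.0   # in-place mutation of parameters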

Flax NNX evolved from the Flax Linen API, which was released in 2020 by engineers and researchers at Google Brain in close collaboration with the JAX team.

You can learn more about Flax NNX on the dedicated Flax documentation site, including its guides and tutorials.

Note: Flax Linen's documentation has its own site.

The Flax team's mission is to serve the growing JAX neural network research ecosystem - both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads.

You can make feature requests, let us know what you are working on, report issues, and ask questions in our Flax GitHub discussion forum.

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at flax-dev@google.com.

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: try new forms of training by forking an example and modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research.

Quick install

Flax uses JAX, so check out the JAX installation instructions for CPUs, GPUs and TPUs.

You will need Python 3.8 or later. Install Flax from PyPI:

pip install flax

To upgrade to the latest version of Flax from the GitHub repository, you can use:

pip install --upgrade git+https://github.com/google/flax.git

To install some additional dependencies (like matplotlib) that are required by some examples but not included by default, you can use:

pip install "flax[all]"

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the Module abstraction, check out our docs, which include a broad intro to the Module abstraction. For additional concrete demonstrations of best practices, refer to our guides and developer notes.

Example of an MLP:

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)
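
A minimal usage sketch (the dimensions are arbitrary); unlike Linen, there is no separate init/apply step because the parameters already live on the object:

import jax.numpy as jnp
from flax import nnx

model = MLP(din=2, dmid=16, dout=4, rngs=nnx.Rngs(0))
y = model(jnp.ones((3, 2)))     # parameters were created eagerly in __init__
print(y.shape)                  # (3, 4)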

Example of a CNN:

class CNN(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    self.linear1 = nnx.Linear(3136, 256, rngs=rngs)  # 3136 = 64 * 7 * 7 for 28x28 inputs after two 2x2 poolings
    self.linear2 = nnx.Linear(256, 10, rngs=rngs)

  def __call__(self, x):
    x = self.avg_pool(nnx.relu(self.conv1(x)))
    x = self.avg_pool(nnx.relu(self.conv2(x)))
    x = x.reshape(x.shape[0], -1)  # flatten
    x = nnx.relu(self.linear1(x))
    x = self.linear2(x)
    return x

Example of an autoencoder:

Encoder = lambda rngs: nnx.Linear(2, 10, rngs=rngs)
Decoder = lambda rngs: nnx.Linear(10, 2, rngs=rngs)

class AutoEncoder(nnx.Module):
  def __init__(self, rngs):
    self.encoder = Encoder(rngs)
    self.decoder = Decoder(rngs)

  def __call__(self, x) -> jax.Array:
    return self.decoder(self.encoder(x))

  def encode(self, x) -> jax.Array:
    return self.encoder(x)

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.10.2},
  year = {2024},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team at Google DeepMind, but is not an official Google product.