Top Related Projects
An Open Source Machine Learning Framework for Everyone
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Tensors and Dynamic neural networks in Python with strong GPU acceleration
Deep Learning for humans
DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
Quick Overview
Trax is a deep learning library developed by Google Research. It focuses on clear code and speed, particularly for sequence models and reinforcement learning. Trax is designed to be easy to use, especially for researchers and students working on neural network projects.
Pros
- Clear and readable code structure, making it easier for researchers to understand and modify
- Optimized for speed, particularly in sequence models and reinforcement learning tasks
- Built on top of JAX, providing efficient automatic differentiation and GPU/TPU support
- Includes a variety of pre-built models and layers for quick experimentation
Cons
- Smaller community compared to more established frameworks like TensorFlow or PyTorch
- Less extensive documentation and fewer tutorials available
- May have a steeper learning curve for those unfamiliar with JAX
- Limited ecosystem of third-party extensions and tools
Code Examples
- Creating a simple feed-forward neural network:
import trax
from trax import layers as tl
model = tl.Serial(
tl.Dense(64),
tl.Relu(),
tl.Dense(10),
tl.LogSoftmax()
)
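- Training a model on the MNIST dataset:
import trax
from trax import layers as tl
from trax.supervised import training
# Load MNIST through TensorFlow Datasets, then shuffle and batch it.
train_stream = trax.data.TFDS('mnist', keys=('image', 'label'), train=True)()
eval_stream = trax.data.TFDS('mnist', keys=('image', 'label'), train=False)()
batcher = trax.data.Serial(trax.data.Shuffle(), trax.data.Batch(128))
# Images are cast to float and flattened before the Dense layers of `model` above.
mnist_model = tl.Serial(tl.ToFloat(), tl.Flatten(), model)
# Create the training loop
train_task = training.TrainTask(
    labeled_data=batcher(train_stream),
    loss_layer=tl.CategoryCrossEntropy(),
    optimizer=trax.optimizers.Adam(0.001)
)
eval_task = training.EvalTask(
    labeled_data=batcher(eval_stream),
    metrics=[tl.CategoryCrossEntropy(), tl.CategoryAccuracy()]
)
loop = training.Loop(mnist_model, train_task, eval_tasks=[eval_task])
loop.run(n_steps=1000)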
- Implementing a transformer model:
import trax
# Decoder-only Transformer language model from trax.models.
transformer = trax.models.TransformerLM(
    vocab_size=32000,
    d_model=512,
    d_ff=2048,
    n_layers=6,
    n_heads=8,
    dropout=0.1,
    mode='train'
)
Getting Started
To get started with Trax, first install it using pip:
pip install trax
Then, you can import and use Trax in your Python code:
import trax
from trax import layers as tl
# Create a simple model
model = tl.Serial(
tl.Dense(64),
tl.Relu(),
tl.Dense(10)
)
# Print the model layout
print(model)
This will create a simple feed-forward neural network and print its structure. From here, you can explore more complex models, datasets, and training procedures using the Trax documentation and examples.
Competitor Comparisons
An Open Source Machine Learning Framework for Everyone
Pros of TensorFlow
- Larger ecosystem with more tools, libraries, and community support
- Extensive documentation and learning resources
- Widely adopted in industry and research
Cons of TensorFlow
- More complex API, steeper learning curve
- Slower development cycle for new features
- Heavier and more resource-intensive
Code Comparison
TensorFlow:
import tensorflow as tf
model = tf.keras.Sequential([
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
Trax:
import trax
model = trax.layers.Serial(
trax.layers.Dense(64),
trax.layers.Relu(),
trax.layers.Dense(10),
trax.layers.LogSoftmax()
)
Key Differences
- Trax focuses on simplicity and ease of use, while TensorFlow offers more flexibility and features
- Trax is designed for fast prototyping and research, whereas TensorFlow is more suited for production-ready applications
- Trax has a more functional programming style, while TensorFlow follows an object-oriented approach
Use Cases
- Choose TensorFlow for large-scale, production-ready projects with complex requirements
- Opt for Trax for quick experimentation, research, and projects that prioritize simplicity and readability
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Pros of Transformers
- Larger community and more extensive documentation
- Wider range of pre-trained models and tasks supported
- More frequent updates and active development
Cons of Transformers
- Can be more complex for beginners due to its extensive features
- Potentially higher memory usage for some models
Code Comparison
Transformers:
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)
Trax:
from trax import layers as tl
from trax.fastmath import numpy as jnp
model = tl.Serial(
tl.Embedding(vocab_size=32000, d_feature=512),
tl.LSTM(n_units=512),
tl.Dense(10)
)
The code snippets demonstrate the different approaches to model initialization and usage. Transformers offers a more high-level API with pre-trained models, while Trax provides a lower-level approach for building custom architectures.
Tensors and Dynamic neural networks in Python with strong GPU acceleration
Pros of PyTorch
- Larger community and ecosystem, with more resources and third-party libraries
- More flexible and Pythonic API, allowing for easier debugging and customization
- Better support for dynamic computational graphs, beneficial for natural language processing tasks
Cons of PyTorch
- Slightly steeper learning curve for beginners compared to Trax's simplicity
- Less focus on TPU support, which may be important for some large-scale projects
- Potentially slower execution speed in some scenarios due to its dynamic nature
Code Comparison
PyTorch example:
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.add(x, y)
Trax example:
import numpy as np
from trax.fastmath import numpy as fastnp
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
z = fastnp.add(x, y)  # Runs on the active fastmath backend (JAX by default).
Both frameworks offer similar functionality, but PyTorch's syntax is often more intuitive for Python developers. Trax, however, integrates more seamlessly with NumPy arrays and provides a simpler API for some tasks. The choice between the two depends on specific project requirements, team expertise, and the need for advanced features or community support.
Deep Learning for humans
Pros of Keras
- Wider adoption and larger community support
- More extensive documentation and tutorials
- Easier to use for beginners and rapid prototyping
Cons of Keras
- Less flexibility for advanced users and custom implementations
- Slower execution compared to lower-level frameworks
- Limited support for distributed training
Code Comparison
Keras:
from keras.models import Sequential
from keras.layers import Dense
model = Sequential([
Dense(64, activation='relu', input_shape=(10,)),
Dense(1, activation='sigmoid')
])
Trax:
import trax.layers as tl
model = tl.Serial(
tl.Dense(64),
tl.Relu(),
tl.Dense(1),
tl.Sigmoid()
)
Both Keras and Trax offer high-level APIs for building neural networks, but Trax focuses more on advanced research and scalability. Keras provides a more user-friendly interface and extensive ecosystem, while Trax offers better performance for large-scale models and distributed training. The code comparison shows that both frameworks allow for concise model definitions, with Trax using a more functional approach compared to Keras' object-oriented style.
DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
Pros of DeepSpeed
- Offers more advanced distributed training features, including ZeRO optimizer stages and pipeline parallelism
- Provides better integration with popular frameworks like PyTorch and Hugging Face Transformers
- Includes a more comprehensive set of optimization techniques for large-scale model training
Cons of DeepSpeed
- Steeper learning curve due to more complex configuration options
- May require more setup and fine-tuning for optimal performance compared to Trax's simpler approach
Code Comparison
Trax example:
import trax
def TransformerLM(vocab_size, d_model=512, d_ff=2048, n_layers=6, n_heads=8):
    # Delegate to the Transformer language model that ships with Trax.
    return trax.models.TransformerLM(
        vocab_size,
        d_model=d_model,
        d_ff=d_ff,
        n_layers=n_layers,
        n_heads=n_heads,
        mode='train',
    )
DeepSpeed example:
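import torch
import deepspeed
# MyModel and args are placeholders: any torch.nn.Module and the parsed
# command-line arguments that carry the DeepSpeed configuration.
model = MyModel()
model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
                                                     model=model,
                                                     model_parameters=model.parameters())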
Summary
While Trax focuses on simplicity and ease of use, DeepSpeed offers more advanced features for large-scale model training and optimization. DeepSpeed provides better integration with popular frameworks but may require more setup and configuration. Trax's code tends to be more concise, while DeepSpeed offers more flexibility and control over the training process.
Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
Pros of fairseq
- More extensive documentation and examples
- Wider range of pre-trained models available
- Stronger focus on machine translation tasks
Cons of fairseq
- Steeper learning curve for beginners
- Less emphasis on reinforcement learning
Code Comparison
fairseq:
from fairseq.models.transformer import TransformerModel
en2de = TransformerModel.from_pretrained(
'/path/to/checkpoints',
checkpoint_file='checkpoint_best.pt',
data_name_or_path='data-bin/wmt16_en_de_bpe32k'
)
en2de.translate('Hello world!')
Trax:
from trax import layers as tl
from trax.models import transformer
model = transformer.TransformerLM(
vocab_size=33000,
d_model=512,
d_ff=2048,
n_layers=6,
n_heads=8,
max_len=2048,
mode='train',
)
Both repositories offer powerful tools for natural language processing tasks. fairseq provides a more comprehensive set of pre-trained models and examples, particularly for machine translation. Trax, on the other hand, offers a more streamlined approach with a focus on ease of use and integration with other Google AI tools. The code examples show the difference: fairseq loads a pre-trained translation model from a checkpoint, while Trax constructs an untrained Transformer language model from hyperparameters.
README
Trax — Deep Learning with Clear Code and Speed
Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team. This notebook (run it in colab) shows how to use Trax and where you can find more information.
- Run a pre-trained Transformer: create a translator in a few lines of code
- Features and resources: API docs, where to talk to us, how to open an issue and more
- Walkthrough: how Trax works, how to make new models and train on your own data
We welcome contributions to Trax! We welcome PRs with code for new models and layers as well as improvements to our code and documentation. We especially love notebooks that explain how models work and show how to use them to solve problems!
Here are a few example notebooks:
- trax.data API explained: Explains some of the major functions in the trax.data API.
- Named Entity Recognition using Reformer: Uses a Kaggle dataset to implement Named Entity Recognition with the Reformer architecture.
- Deep N-Gram models: Implementation of deep n-gram models trained on Shakespeare's works.
General Setup
Execute the following cell (once) before running any of the code samples.
import os
import numpy as np
!pip install -q -U trax
import trax
1. Run a pre-trained Transformer
Here is how you create an English-German translator in a few lines of code:
- create a Transformer model in Trax with trax.models.Transformer
- initialize it from a file with pre-trained weights with model.init_from_file
- tokenize your input sentence to input into the model with trax.data.tokenize
- decode from the Transformer with trax.supervised.decoding.autoregressive_sample
- de-tokenize the decoded result to get the translation with trax.data.detokenize
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
input_vocab_size=33300,
d_model=512, d_ff=2048,
n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
max_len=2048, mode='predict')
# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
weights_only=True)
# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]), # Operates on streams.
vocab_dir='gs://trax-ml/vocabs/',
vocab_file='ende_32k.subword'))[0]
# Decode from the Transformer.
tokenized = tokenized[None, :] # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
model, tokenized, temperature=0.0) # Higher temperature: more diverse results.
# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1] # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
vocab_dir='gs://trax-ml/vocabs/',
vocab_file='ende_32k.subword')
print(translation)
Es ist schön, heute neue Dinge zu lernen!
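The `temperature` argument above controls how random decoding is. As a rough sketch in plain Python (not Trax's implementation; `sample_token` is a hypothetical helper), a temperature of 0 can be treated as greedy argmax, while higher temperatures flatten the distribution before sampling:

```python
import math
import random

def sample_token(logits, temperature=1.0, rng=random):
    """Pick a token index from unnormalized scores (hypothetical helper)."""
    if temperature == 0.0:
        # Greedy decoding: always take the highest-scoring token.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Scale the scores, then apply a numerically stable softmax.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # random.choices normalizes the weights before drawing.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

logits = [2.0, 1.0, 0.1]
print(sample_token(logits, temperature=0.0))  # always index 0
print(sample_token(logits, temperature=1.5))  # varies from run to run
```

With `temperature=0.0` the same input always yields the same translation, which is why the example above is reproducible.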
2. Features and resources
Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and new RL algorithms like AWR. Trax has bindings to a large number of deep learning datasets, including Tensor2Tensor and TensorFlow datasets.
You can use Trax either as a library from your own python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. It runs without any changes on CPUs, GPUs and TPUs.
- API docs
- chat with us
- open an issue
- subscribe to trax-discuss for news
3. Walkthrough
You can learn here how Trax works, how to create new models and how to train them on your own data.
Tensors and Fast Math
The basic units flowing through Trax models are tensors: multi-dimensional arrays, sometimes also known as numpy arrays, after numpy, the most widely used package for tensor operations. You should take a look at the numpy guide if you don't know how to operate on tensors: Trax also uses the numpy API for that.
In Trax we want numpy operations to run very fast, making use of GPUs and TPUs to accelerate them. We also want to automatically compute gradients of functions on tensors. This is done in the trax.fastmath package thanks to its backends, JAX and TensorFlow numpy.
from trax.fastmath import numpy as fastnp
trax.fastmath.use_backend('jax') # Can be 'jax' or 'tensorflow-numpy'.
matrix = fastnp.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(f'matrix = \n{matrix}')
vector = fastnp.ones(3)
print(f'vector = {vector}')
product = fastnp.dot(vector, matrix)
print(f'product = {product}')
tanh = fastnp.tanh(product)
print(f'tanh(product) = {tanh}')
matrix =
[[1 2 3]
[4 5 6]
[7 8 9]]
vector = [1. 1. 1.]
product = [12. 15. 18.]
tanh(product) = [0.99999994 0.99999994 0.99999994]
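To make the printed numbers concrete, here is the same computation written out in plain Python, as an illustration only and without `trax.fastmath`. The product of a vector with a matrix is a weighted sum of the matrix rows, which for a vector of ones is just the column sums:

```python
import math

matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
vector = [1.0, 1.0, 1.0]

# dot(vector, matrix)[j] = sum over i of vector[i] * matrix[i][j].
product = [sum(v * row[j] for v, row in zip(vector, matrix)) for j in range(3)]
print(product)  # [12.0, 15.0, 18.0]

# tanh saturates quickly, so every entry is already very close to 1.
tanh = [math.tanh(p) for p in product]
```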
Gradients can be calculated using trax.fastmath.grad.
def f(x):
    return 2.0 * x * x
grad_f = trax.fastmath.grad(f)
print(f'grad(2x^2) at 1 = {grad_f(1.0)}')
grad(2x^2) at 1 = 4.0
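One way to sanity-check a gradient like this is a finite-difference estimate. The sketch below uses plain Python and a hypothetical helper `numerical_grad`. It is not how `trax.fastmath.grad` works internally (that relies on automatic differentiation), but it should agree with the analytic derivative 4x:

```python
def f(x):
    return 2.0 * x * x

def numerical_grad(fn, x, eps=1e-6):
    """Central-difference estimate of the derivative of fn at x (hypothetical helper)."""
    return (fn(x + eps) - fn(x - eps)) / (2.0 * eps)

estimate = numerical_grad(f, 1.0)
print(abs(estimate - 4.0) < 1e-4)  # True
```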
Layers
Layers are basic building blocks of Trax models. You will learn all about them in the layers intro but for now, just take a look at the implementation of one core Trax layer, Embedding:
class Embedding(base.Layer):
  """Trainable layer that maps discrete tokens/IDs to vectors."""

  def __init__(self,
               vocab_size,
               d_feature,
               kernel_initializer=init.RandomNormalInitializer(1.0)):
    """Returns an embedding layer with given vocabulary size and vector size.

    Args:
      vocab_size: Size of the input vocabulary. The layer will assign a unique
          vector to each ID in `range(vocab_size)`.
      d_feature: Dimensionality/depth of the output vectors.
      kernel_initializer: Function that creates (random) initial vectors for
          the embedding.
    """
    super().__init__(name=f'Embedding_{vocab_size}_{d_feature}')
    self._d_feature = d_feature  # feature dimensionality
    self._vocab_size = vocab_size
    self._kernel_initializer = kernel_initializer

  def forward(self, x):
    """Returns embedding vectors corresponding to input token IDs.

    Args:
      x: Tensor of token IDs.

    Returns:
      Tensor of embedding vectors.
    """
    return jnp.take(self.weights, x, axis=0, mode='clip')

  def init_weights_and_state(self, input_signature):
    """Returns tensor of newly initialized embedding vectors."""
    del input_signature
    shape_w = (self._vocab_size, self._d_feature)
    w = self._kernel_initializer(shape_w, self.rng)
    self.weights = w
Layers with trainable weights like Embedding need to be initialized with the signature (shape and dtype) of the input, and then can be run by calling them.
from trax import layers as tl
# Create an input tensor x.
x = np.arange(15)
print(f'x = {x}')
# Create the embedding layer.
embedding = tl.Embedding(vocab_size=20, d_feature=32)
embedding.init(trax.shapes.signature(x))
# Run the layer -- y = embedding(x).
y = embedding(x)
print(f'shape of y = {y.shape}')
x = [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14]
shape of y = (15, 32)
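Conceptually, the forward pass of an embedding layer is a table lookup. The following plain-Python sketch (an illustration, not Trax code; `table` and `embed` are hypothetical names) mirrors the shapes above and the `mode='clip'` behavior of the `take` call in the layer source, where out-of-range IDs are clipped to the nearest valid row:

```python
import random

vocab_size, d_feature = 20, 32
rng = random.Random(0)
# One row of d_feature numbers per token ID, randomly initialized.
table = [[rng.gauss(0.0, 1.0) for _ in range(d_feature)]
         for _ in range(vocab_size)]

def embed(token_ids):
    """Look up one row per token ID, clipping out-of-range IDs."""
    return [table[min(max(t, 0), vocab_size - 1)] for t in token_ids]

y = embed(range(15))
print((len(y), len(y[0])))  # (15, 32)
```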
Models
Models in Trax are built from layers most often using the Serial and Branch combinators. You can read more about those combinators in the layers intro and see the code for many models in trax/models/, e.g., this is how the Transformer Language Model is implemented. Below is an example of how to build a sentiment classification model.
model = tl.Serial(
tl.Embedding(vocab_size=8192, d_feature=256),
tl.Mean(axis=1), # Average on axis 1 (length of sentence).
tl.Dense(2), # Classify 2 classes.
tl.LogSoftmax() # Produce log-probabilities.
)
# You can print model structure.
print(model)
Serial[
Embedding_8192_256
Mean
Dense_2
LogSoftmax
]
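As a mental model, a serial combinator can be thought of as left-to-right function composition. The sketch below is plain Python with hypothetical stand-ins (`serial`, `mean_over_positions`, `double`). The real `tl.Serial` additionally manages weights, state and a data stack, none of which is modeled here:

```python
def serial(*layers):
    """Compose layers left to right (hypothetical stand-in for a combinator)."""
    def run(x):
        for layer in layers:
            x = layer(x)
        return x
    return run

def mean_over_positions(batch):
    # batch: list of sentences, each a list of equal-length feature vectors.
    return [[sum(col) / len(sentence) for col in zip(*sentence)]
            for sentence in batch]

def double(batch):
    return [[2.0 * v for v in row] for row in batch]

model_sketch = serial(mean_over_positions, double)
batch = [[[1.0, 2.0], [3.0, 4.0]]]  # 1 sentence, 2 positions, 2 features
print(model_sketch(batch))  # [[4.0, 6.0]]
```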
Data
To train your model, you need data. In Trax, data streams are represented as python iterators, so you can call next(data_stream) and get a tuple, e.g., (inputs, targets). Trax allows you to use TensorFlow Datasets easily and you can also get an iterator from your own text file using the standard open('my_file.txt').
train_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=True)()
eval_stream = trax.data.TFDS('imdb_reviews', keys=('text', 'label'), train=False)()
print(next(train_stream)) # See one example.
(b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.", 0)
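Because a stream is just a Python iterator, you can write one yourself as a generator. This is a minimal sketch with made-up data; `labeled_stream` is a hypothetical helper, and looping forever is an assumption that suits training loops that draw an arbitrary number of batches:

```python
def labeled_stream(examples):
    """Yield (input, target) tuples forever, cycling over the examples."""
    while True:
        for text, label in examples:
            yield text, label

toy_examples = [("great film", 1), ("terrible plot", 0)]  # made-up data
stream = labeled_stream(toy_examples)
print(next(stream))  # ('great film', 1)
```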
Using the trax.data module you can create input processing pipelines, e.g., to tokenize and shuffle your data. You create data pipelines using trax.data.Serial and they are functions that you apply to streams to create processed streams.
data_pipeline = trax.data.Serial(
trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
trax.data.Shuffle(),
trax.data.FilterByLength(max_length=2048, length_keys=[0]),
trax.data.BucketByLength(boundaries=[ 32, 128, 512, 2048],
batch_sizes=[256, 64, 16, 4, 1],
length_keys=[0]),
trax.data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}') # Check the shapes.
shapes = [(4, 1024), (4,), (4,)]
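The batch size of 4 in the shapes above follows from the bucketing settings: there is one more batch size than there are boundaries, and longer sequences get smaller batches. The plain-Python sketch below shows the selection rule. It is not the Trax implementation; `batch_size_for` is a hypothetical helper, and treating a length equal to a boundary as belonging to that boundary's bucket is an assumption:

```python
from bisect import bisect_left

boundaries = [32, 128, 512, 2048]
batch_sizes = [256, 64, 16, 4, 1]

def batch_size_for(length):
    """Return the batch size of the bucket a sequence of this length falls in."""
    # Assumption: a length equal to a boundary belongs to that boundary's bucket.
    return batch_sizes[bisect_left(boundaries, length)]

for length in (20, 100, 1000, 5000):
    print(length, batch_size_for(length))
```

A review between 513 and 2048 tokens lands in the bucket with batch size 4, matching the example batch. The final batch size of 1 only applies to sequences longer than the last boundary, which the length filter in the pipeline above removes.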
Supervised training
When you have the model and the data, use trax.supervised.training to define training and eval tasks and create a training loop. The Trax training loop optimizes training and will create TensorBoard logs and model checkpoints for you.
from trax.supervised import training
# Training task.
train_task = training.TrainTask(
labeled_data=train_batches_stream,
loss_layer=tl.WeightedCategoryCrossEntropy(),
optimizer=trax.optimizers.Adam(0.01),
n_steps_per_checkpoint=500,
)
# Evaluation task.
eval_task = training.EvalTask(
labeled_data=eval_batches_stream,
metrics=[tl.WeightedCategoryCrossEntropy(), tl.WeightedCategoryAccuracy()],
n_eval_batches=20 # For less variance in eval numbers.
)
# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
!rm -rf {output_dir}
training_loop = training.Loop(model,
train_task,
eval_tasks=[eval_task],
output_dir=output_dir)
# Run 2000 steps (batches).
training_loop.run(2000)
Step 1: Ran 1 train steps in 0.78 secs
Step 1: train WeightedCategoryCrossEntropy | 1.33800304
Step 1: eval WeightedCategoryCrossEntropy | 0.71843582
Step 1: eval WeightedCategoryAccuracy | 0.56562500
Step 500: Ran 499 train steps in 5.77 secs
Step 500: train WeightedCategoryCrossEntropy | 0.62914723
Step 500: eval WeightedCategoryCrossEntropy | 0.49253047
Step 500: eval WeightedCategoryAccuracy | 0.74062500
Step 1000: Ran 500 train steps in 5.03 secs
Step 1000: train WeightedCategoryCrossEntropy | 0.42949259
Step 1000: eval WeightedCategoryCrossEntropy | 0.35451687
Step 1000: eval WeightedCategoryAccuracy | 0.83750000
Step 1500: Ran 500 train steps in 4.80 secs
Step 1500: train WeightedCategoryCrossEntropy | 0.41843575
Step 1500: eval WeightedCategoryCrossEntropy | 0.35207348
Step 1500: eval WeightedCategoryAccuracy | 0.82109375
Step 2000: Ran 500 train steps in 5.35 secs
Step 2000: train WeightedCategoryCrossEntropy | 0.38129005
Step 2000: eval WeightedCategoryCrossEntropy | 0.33760912
Step 2000: eval WeightedCategoryAccuracy | 0.85312500
After training the model, run it like any layer to get results.
example_input = next(eval_batches_stream)[0][0]
example_input_str = trax.data.detokenize(example_input, vocab_file='en_8k.subword')
print(f'example input_str: {example_input_str}')
sentiment_log_probs = model(example_input[None, :]) # Add batch dimension.
print(f'Model returned sentiment probabilities: {np.exp(sentiment_log_probs)}')
example input_str: I first saw this when I was a teen in my last year of Junior High. I was riveted to it! I loved the special effects, the fantastic places and the trial-aspect and flashback method of telling the story.<br /><br />Several years later I read the book and while it was interesting and I could definitely see what Swift was trying to say, I think that while it's not as perfect as the book for social commentary, as a story the movie is better. It makes more sense to have it be one long adventure than having Gulliver return after each voyage and making a profit by selling the tiny Lilliput sheep or whatever.<br /><br />It's much more arresting when everyone thinks he's crazy and the sheep DO make a cameo anyway. As a side note, when I saw Laputa I was stunned. It looks very much like the Kingdom of Zeal from the Chrono Trigger video game (1995) that also made me like this mini-series even more.<br /><br />I saw it again about 4 years ago, and realized that I still enjoyed it just as much. Really high quality stuff and began an excellent run of Sweeps mini-series for NBC who followed it up with the solid Merlin and interesting Alice in Wonderland.<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
Model returned sentiment probabilities: [[3.984500e-04 9.996014e-01]]
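The model ends in a log-softmax, so its outputs are log-probabilities and exponentiating them gives probabilities that sum to 1. Here is a small plain-Python sketch of that relationship, using made-up logits rather than the trained model's values. In the IMDB labels, class 0 is a negative review and class 1 a positive one:

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]

log_probs = log_softmax([-3.0, 4.8])  # made-up logits for (negative, positive)
probs = [math.exp(lp) for lp in log_probs]
predicted = max(range(len(probs)), key=probs.__getitem__)
print(predicted, round(sum(probs), 6))  # 1 1.0
```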