tensor2tensor
Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Top Related Projects
Trax — Deep Learning with Clear Code and Speed
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
Open Source Neural Machine Translation and (Large) Language Models in PyTorch
Conditional Transformer Language Model for Controllable Generation
An open-source NLP research library, built on PyTorch.
Quick Overview
Tensor2Tensor (T2T) is a library of deep learning models and datasets designed to make deep learning more accessible and easier to experiment with. It's built on top of TensorFlow and provides a wide range of pre-implemented models, particularly for natural language processing and computer vision tasks.
Pros
- Extensive collection of state-of-the-art models and datasets
- Easy-to-use API for training, evaluating, and deploying models
- Highly configurable and customizable
- Well-integrated with TensorFlow ecosystem
Cons
- Learning curve can be steep for beginners
- Documentation can be sparse or outdated in some areas
- Some models may require significant computational resources
- Deprecated: it is kept running and accepts bug fixes, but its authors encourage users to switch to the successor library Trax
Code Examples
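- Defining a basic Transformer model:
import tensorflow as tf
from tensor2tensor.models import transformer
# transformer_base() returns the default hyperparameter set for the Transformer
hparams = transformer.transformer_base()
model = transformer.Transformer(hparams, tf.estimator.ModeKeys.TRAIN)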
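- Preparing data for training:
import tensorflow as tf
from tensor2tensor.data_generators import translate_ende
data_dir = "/tmp/t2t_data"  # must already contain the files written by t2t-datagen
problem = translate_ende.TranslateEndeWmt32k()
train_dataset = problem.dataset(tf.estimator.ModeKeys.TRAIN, data_dir)
eval_dataset = problem.dataset(tf.estimator.ModeKeys.EVAL, data_dir)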
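- Training a model:
from tensor2tensor import models  # importing registers the built-in models
from tensor2tensor import problems  # importing registers the built-in problems
from tensor2tensor.utils import trainer_lib
hparams = trainer_lib.create_hparams(
    "transformer_base", data_dir=data_dir, problem_name="translate_ende_wmt32k")
run_config = trainer_lib.create_run_config(model_name="transformer", model_dir="/tmp/t2t_train")
experiment = trainer_lib.create_experiment(
    run_config=run_config,
    hparams=hparams,
    model_name="transformer",
    problem_name="translate_ende_wmt32k",
    data_dir=data_dir,
    train_steps=100000,
    eval_steps=100)
experiment.train_and_evaluate()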
Getting Started
To get started with Tensor2Tensor:
- Install the library:
pip install tensor2tensor
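- Import the necessary modules:
import os
from tensor2tensor import models  # importing registers the built-in models
from tensor2tensor import problems  # importing registers the built-in problems
from tensor2tensor.utils import trainer_lib
- Choose a problem and model, and generate the problem's data:
PROBLEM = "translate_ende_wmt32k"
MODEL = "transformer"
DATA_DIR = "/tmp/t2t_data"
TMP_DIR = "/tmp/t2t_datagen"
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)
problem = problems.problem(PROBLEM)
problem.generate_data(DATA_DIR, TMP_DIR)  # downloads and preprocesses the WMT data (large)
- Set up training:
hparams = trainer_lib.create_hparams("transformer_base", data_dir=DATA_DIR, problem_name=PROBLEM)
run_config = trainer_lib.create_run_config(model_name=MODEL, model_dir="/tmp/t2t_train")
experiment = trainer_lib.create_experiment(
    run_config=run_config, hparams=hparams, model_name=MODEL, problem_name=PROBLEM,
    data_dir=DATA_DIR, train_steps=1000, eval_steps=100)
- Start training:
experiment.train()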
Competitor Comparisons
Trax — Deep Learning with Clear Code and Speed
Pros of Trax
- Simpler and more intuitive API, making it easier for beginners to get started
- Better support for JAX and TPU acceleration, potentially offering improved performance
- More focused on sequence models and transformers, which can be beneficial for specific use cases
Cons of Trax
- Smaller community and ecosystem compared to Tensor2Tensor
- Less comprehensive documentation and fewer examples available
- More limited in scope, primarily focusing on sequence models and transformers
Code Comparison
Trax example:
import trax
from trax import layers as tl
model = tl.Serial(
tl.Embedding(vocab_size=1000, d_feature=32),
tl.LSTM(n_units=64),
tl.Dense(10)
)
Tensor2Tensor example:
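import tensorflow as tf
from tensor2tensor import problems  # importing registers the built-in problems
from tensor2tensor.models import transformer
from tensor2tensor.utils import trainer_lib
# create_hparams attaches the problem's hparams (its vocab files must exist in data_dir)
hparams = trainer_lib.create_hparams(
    "transformer_base", data_dir="/tmp/t2t_data", problem_name="translate_ende_wmt32k")
model = transformer.Transformer(hparams, tf.estimator.ModeKeys.TRAIN)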
Both repositories offer powerful tools for working with sequence models and transformers, but Trax provides a more streamlined experience with a focus on simplicity and performance. Tensor2Tensor, on the other hand, offers a broader range of features and a larger ecosystem, making it more suitable for complex projects and research applications.
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Pros of Transformers
- Broader support for various deep learning frameworks (PyTorch, TensorFlow, JAX)
- More extensive model zoo with pre-trained models and easy fine-tuning
- Active development and frequent updates
Cons of Transformers
- Steeper learning curve for beginners
- Less focus on end-to-end training pipelines
Code Comparison
Tensor2Tensor:
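import tensorflow as tf
from tensor2tensor import problems  # importing registers the built-in problems
from tensor2tensor.models import transformer
from tensor2tensor.utils import trainer_lib
hparams = trainer_lib.create_hparams(
    "transformer_base", data_dir="/tmp/t2t_data", problem_name="translate_ende_wmt32k")
model = transformer.Transformer(hparams, tf.estimator.ModeKeys.TRAIN)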
Transformers:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
Both repositories focus on implementing transformer-based models, but Transformers offers a more flexible and extensive ecosystem. Tensor2Tensor is more tightly integrated with TensorFlow and provides end-to-end training pipelines, while Transformers emphasizes ease of use and model accessibility across multiple frameworks. The code comparison illustrates the difference in approach, with Transformers offering a more straightforward API for loading pre-trained models and tokenizers.
Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
Pros of fairseq
- Built on PyTorch, offering dynamic computational graphs and easier debugging
- More extensive support for various NLP tasks and architectures
- Active development and frequent updates from Facebook AI Research
Cons of fairseq
- Steeper learning curve for beginners compared to tensor2tensor
- Less integration with TensorFlow ecosystem and tools
- May require more manual configuration for some tasks
Code Comparison
tensor2tensor:
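import tensorflow as tf
from tensor2tensor import problems  # importing registers the built-in problems
from tensor2tensor.models import transformer
from tensor2tensor.utils import trainer_lib
hparams = trainer_lib.create_hparams(
    "transformer_base", data_dir="/tmp/t2t_data", problem_name="translate_ende_wmt32k")
model = transformer.Transformer(hparams, tf.estimator.ModeKeys.TRAIN)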
fairseq:
from fairseq.models.transformer import TransformerModel
from fairseq.tasks.translation import TranslationTask
task = TranslationTask.setup_task(args)  # args: parsed fairseq command-line options (not shown)
model = TransformerModel.build_model(args, task)
Both repositories provide powerful tools for sequence-to-sequence tasks, particularly in machine translation. tensor2tensor offers a more streamlined approach with pre-configured problems and models, while fairseq provides greater flexibility and customization options. The choice between them often depends on the user's familiarity with TensorFlow or PyTorch, as well as the specific requirements of the project at hand.
Open Source Neural Machine Translation and (Large) Language Models in PyTorch
Pros of OpenNMT-py
- More focused on neural machine translation tasks
- Easier to use and customize for specific NMT applications
- Better documentation and examples for NMT-specific use cases
Cons of OpenNMT-py
- Less versatile for general sequence-to-sequence tasks
- Smaller community and fewer contributions compared to Tensor2Tensor
- Limited support for non-NMT tasks and architectures
Code Comparison
OpenNMT-py:
import onmt
model = onmt.models.build_model(opt, fields, checkpoint)
translator = onmt.translate.Translator(model, fields, opt, cuda=opt.cuda)
translated = translator.translate(src_data, src_lengths, tgt=tgt_data)
Tensor2Tensor:
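import tensorflow as tf
from tensor2tensor import problems  # importing registers the built-in problems
from tensor2tensor.models import transformer
from tensor2tensor.utils import trainer_lib
hparams = trainer_lib.create_hparams(
    "transformer_base", data_dir="/tmp/t2t_data", problem_name="translate_ende_wmt32k")
model = transformer.Transformer(hparams, tf.estimator.ModeKeys.TRAIN)
# T2T models take a dict of features; text features are shaped [batch, length, 1, 1] (TF1 graph mode)
features = {"inputs": tf.placeholder(tf.int32, shape=[None, None, 1, 1]),
            "targets": tf.placeholder(tf.int32, shape=[None, None, 1, 1])}
logits, losses = model(features)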
Both libraries offer powerful tools for sequence-to-sequence tasks, but OpenNMT-py is more specialized for neural machine translation, while Tensor2Tensor provides a broader range of models and applications. The code examples demonstrate the different approaches: OpenNMT-py has a more streamlined API for translation tasks, while Tensor2Tensor offers more flexibility in model selection and problem definition.
Conditional Transformer Language Model for Controllable Generation
Pros of CTRL
- Focuses on controllable text generation with a simpler, more targeted approach
- Provides pre-trained models for immediate use in various applications
- Offers better control over generated text through the use of control codes
Cons of CTRL
- Less flexible for general-purpose machine learning tasks
- Smaller community and fewer resources compared to Tensor2Tensor
- Limited to text generation tasks, while Tensor2Tensor covers a broader range of applications
Code Comparison
CTRL example:
from transformers import CTRLTokenizer, CTRLModel
tokenizer = CTRLTokenizer.from_pretrained("ctrl")
model = CTRLModel.from_pretrained("ctrl")
Tensor2Tensor example:
from tensor2tensor.models import transformer
from tensor2tensor.data_generators import problem
model = transformer.Transformer(...)
Key Differences
- CTRL is specifically designed for controllable text generation, while Tensor2Tensor is a more general-purpose library for various machine learning tasks.
- Tensor2Tensor offers a wider range of models and problems, making it more versatile for different applications.
- CTRL provides pre-trained models that can be easily fine-tuned, whereas Tensor2Tensor often requires more setup and training from scratch.
- Tensor2Tensor has a larger community and more extensive documentation due to its association with TensorFlow.
An open-source NLP research library, built on PyTorch.
Pros of AllenNLP
- More focused on NLP tasks, providing specialized tools and models
- Cleaner, more modular codebase with better documentation
- Built on PyTorch, offering dynamic computation graphs and easier debugging
Cons of AllenNLP
- Smaller community and ecosystem compared to TensorFlow-based projects
- Less versatile for non-NLP machine learning tasks
- Potentially slower training speed on some hardware configurations
Code Comparison
AllenNLP:
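from typing import Iterable
from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import TextField
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers import WhitespaceTokenizer
class MyDatasetReader(DatasetReader):
    def text_to_instance(self, text: str) -> Instance:
        tokens = WhitespaceTokenizer().tokenize(text)
        return Instance({"text": TextField(tokens, {"tokens": SingleIdTokenIndexer()})})
    def _read(self, file_path: str) -> Iterable[Instance]:
        with open(file_path, "r") as f:
            for line in f:
                yield self.text_to_instance(line.strip())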
Tensor2Tensor:
import os
import tensorflow as tf
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
@registry.register_problem
class MyProblem(text_problems.Text2SelfProblem):
    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        # Hypothetical corpus file in tmp_dir; Text2Self problems yield only "targets"
        data_path = os.path.join(tmp_dir, "my_corpus.txt")
        with tf.gfile.GFile(data_path, "r") as f:
            for line in f:
                yield {"targets": line.strip()}
README
Tensor2Tensor
Tensor2Tensor, or T2T for short, is a library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
T2T was developed by researchers and engineers in the Google Brain team and a community of users. It is now deprecated — we keep it running and welcome bug-fixes, but encourage users to use the successor library Trax.
Quick Start
This IPython notebook explains T2T and runs in your browser using a free VM from Google; no installation is needed. Alternatively, here is a one-command version that installs T2T, downloads MNIST, trains a model and evaluates it:
pip install tensor2tensor && t2t-trainer \
--generate_data \
--data_dir=~/t2t_data \
--output_dir=~/t2t_train/mnist \
--problem=image_mnist \
--model=shake_shake \
--hparams_set=shake_shake_quick \
--train_steps=1000 \
--eval_steps=100
Contents
- Suggested Datasets and Models
- Basics
- T2T Overview
- Adding your own components
- Adding a dataset
- Papers
- Run on FloydHub
Suggested Datasets and Models
Below we list a number of tasks that can be solved with T2T when you train the appropriate model on the appropriate problem. We give the problem and model below and we suggest a setting of hyperparameters that we know works well in our setup. We usually run either on Cloud TPUs or on 8-GPU machines; you might need to modify the hyperparameters if you run on a different setup.
Mathematical Language Understanding
For evaluating mathematical expressions at the character level involving addition, subtraction and multiplication of both positive and negative decimal numbers with variable digits assigned to symbolic variables, use
- the MLU data-set:
--problem=algorithmic_math_two_variables
You can try solving the problem with different transformer models and hyperparameters as described in the paper:
- Standard transformer:
--model=transformer
--hparams_set=transformer_tiny
- Universal transformer:
--model=universal_transformer
--hparams_set=universal_transformer_tiny
- Adaptive universal transformer:
--model=universal_transformer
--hparams_set=adaptive_universal_transformer_tiny
Story, Question and Answer
For answering questions based on a story, use
- the bAbi data-set:
--problem=babi_qa_concat_task1_1k
You can choose the bAbi task from the range [1,20] and the subset from 1k or
10k. To combine test data from all tasks into a single test set, use
--problem=babi_qa_concat_all_tasks_10k
Image Classification
For image classification, we have a number of standard data-sets:
- ImageNet (a large data-set): --problem=image_imagenet, or one of the re-scaled versions (image_imagenet224, image_imagenet64, image_imagenet32)
- CIFAR-10: --problem=image_cifar10 (or --problem=image_cifar10_plain to turn off data augmentation)
- CIFAR-100: --problem=image_cifar100
- MNIST: --problem=image_mnist
For ImageNet, we suggest using ResNet or Xception, i.e., use --model=resnet --hparams_set=resnet_50 or --model=xception --hparams_set=xception_base. ResNet should reach above 76% top-1 accuracy on ImageNet.
For CIFAR and MNIST, we suggest trying the shake-shake model: --model=shake_shake --hparams_set=shakeshake_big. This setting, trained for --train_steps=700000, should yield close to 97% accuracy on CIFAR-10.
Image Generation
For (un)conditional image generation, we have a number of standard data-sets:
- CelebA: --problem=img2img_celeba for image-to-image translation, namely, superresolution from 8x8 to 32x32.
- CelebA-HQ: --problem=image_celeba256_rev for a downsampled 256x256.
- CIFAR-10: --problem=image_cifar10_plain_gen_rev for class-conditional 32x32 generation.
- LSUN Bedrooms: --problem=image_lsun_bedrooms_rev
- MS-COCO: --problem=image_text_ms_coco_rev for text-to-image generation.
- Small ImageNet (a large data-set): --problem=image_imagenet32_gen_rev for 32x32 or --problem=image_imagenet64_gen_rev for 64x64.
We suggest using the Image Transformer, i.e., --model=imagetransformer, or the Image Transformer Plus, i.e., --model=imagetransformerpp, which uses a discretized mixture of logistics, or a variational auto-encoder, i.e., --model=transformer_ae.
For CIFAR-10, using --hparams_set=imagetransformer_cifar10_base or --hparams_set=imagetransformer_cifar10_base_dmol yields 2.90 bits per dimension. For ImageNet-32, using --hparams_set=imagetransformer_imagenet32_base yields 3.77 bits per dimension.
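Bits per dimension, the metric quoted above, is the negative log-likelihood of an image expressed in bits and averaged over its subpixel values. The following minimal sketch assumes the per-image NLL is available in nats and counts height * width * channels dimensions. It illustrates the standard definition only; it is not T2T's evaluation code.

```python
import math


def bits_per_dim(nll_nats: float, height: int, width: int, channels: int = 3) -> float:
    """Convert a per-image negative log-likelihood (in nats) to bits per dimension.

    Each subpixel (one channel value at one pixel) counts as one dimension.
    """
    num_dims = height * width * channels
    # Divide by ln(2) to turn nats into bits, then normalize per dimension.
    return nll_nats / (num_dims * math.log(2))


# A 32x32 RGB image has 3072 dimensions, so an NLL of 3072 * ln(2) * 2.90 nats
# corresponds to 2.90 bits per dimension.
example_nll = 3072 * math.log(2) * 2.90
print(round(bits_per_dim(example_nll, 32, 32), 2))  # 2.9
```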
Language Modeling
For language modeling, we have these data-sets in T2T:
- PTB (a small data-set): --problem=languagemodel_ptb10k for word-level modeling and --problem=languagemodel_ptb_characters for character-level modeling.
- LM1B (a billion-word corpus): --problem=languagemodel_lm1b32k for subword-level modeling and --problem=languagemodel_lm1b_characters for character-level modeling.
We suggest starting with --model=transformer on this task and using --hparams_set=transformer_small for PTB and --hparams_set=transformer_base for LM1B.
Sentiment Analysis
For the task of recognizing the sentiment of a sentence, use
- the IMDB data-set: --problem=sentiment_imdb
We suggest using --model=transformer_encoder here. Since it is a small data-set, try --hparams_set=transformer_tiny and train for a few steps (e.g., --train_steps=2000).
Speech Recognition
For speech-to-text, we have these data-sets in T2T:
- Librispeech (US English): --problem=librispeech for the whole set and --problem=librispeech_clean for a smaller but nicely filtered part.
- Mozilla Common Voice (US English): --problem=common_voice for the whole set and --problem=common_voice_clean for a quality-checked subset.
Summarization
For summarizing longer text into shorter text, we have these data-sets:
- CNN/DailyMail articles summarized into a few sentences: --problem=summarize_cnn_dailymail32k
We suggest using --model=transformer and --hparams_set=transformer_prepend for this task. This yields good ROUGE scores.
Translation
There are a number of translation data-sets in T2T:
- English-German: --problem=translate_ende_wmt32k
- English-French: --problem=translate_enfr_wmt32k
- English-Czech: --problem=translate_encs_wmt32k
- English-Chinese: --problem=translate_enzh_wmt32k
- English-Vietnamese: --problem=translate_envi_iwslt32k
- English-Spanish: --problem=translate_enes_wmt32k
You can get translations in the other direction by appending _rev to the problem name, e.g., for German-English use --problem=translate_ende_wmt32k_rev (note that you still need to download the original data with t2t-datagen --problem=translate_ende_wmt32k).
For all translation problems, we suggest trying the Transformer model: --model=transformer. At first it is best to try the base setting, --hparams_set=transformer_base. When trained on 8 GPUs for 300K steps, this should reach a BLEU score of about 28 on the English-German data-set, which is close to the state of the art. If training on a single GPU, try the --hparams_set=transformer_base_single_gpu setting. For very good results or larger data-sets (e.g., for English-French), try the big model with --hparams_set=transformer_big.
See this example to learn how translation works.
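BLEU, the score quoted above, multiplies a brevity penalty by the geometric mean of clipped n-gram precisions (n = 1 to 4). The plain-Python sketch below computes corpus-level BLEU on whitespace-tokenized text. It assumes one reference per sentence and applies no smoothing. It only shows the computation and is not the implementation behind t2t-bleu, whose tokenization and details may differ.

```python
import math
from collections import Counter
from typing import List


def ngrams(tokens: List[str], n: int) -> Counter:
    """Count all n-grams of length n in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(hypotheses: List[str], references: List[str], max_n: int = 4) -> float:
    """Corpus BLEU in [0, 100] with one reference per hypothesis and no smoothing."""
    matches = [0] * max_n  # clipped n-gram matches per order
    totals = [0] * max_n   # hypothesis n-gram counts per order
    hyp_len = ref_len = 0
    for hyp, ref in zip(hypotheses, references):
        h, r = hyp.split(), ref.split()
        hyp_len += len(h)
        ref_len += len(r)
        for n in range(1, max_n + 1):
            h_counts, r_counts = ngrams(h, n), ngrams(r, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, r_counts[g]) for g, c in h_counts.items())
            totals[n - 1] += sum(h_counts.values())
    if min(matches) == 0:
        return 0.0  # without smoothing, any zero precision makes the geometric mean zero
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Brevity penalty: penalize output that is shorter than the references.
    bp = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return 100.0 * bp * math.exp(log_precision)


print(corpus_bleu(["the cat sat on the mat"], ["the cat sat on the mat"]))  # 100.0
```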
Basics
Walkthrough
Here's a walkthrough training a good English-to-German translation model using the Transformer model from Attention Is All You Need on WMT data.
pip install tensor2tensor
# See what problems, models, and hyperparameter sets are available.
# You can easily swap between them (and add new ones).
t2t-trainer --registry_help
PROBLEM=translate_ende_wmt32k
MODEL=transformer
HPARAMS=transformer_base_single_gpu
DATA_DIR=$HOME/t2t_data
TMP_DIR=/tmp/t2t_datagen
TRAIN_DIR=$HOME/t2t_train/$PROBLEM/$MODEL-$HPARAMS
mkdir -p $DATA_DIR $TMP_DIR $TRAIN_DIR
# Generate data
t2t-datagen \
--data_dir=$DATA_DIR \
--tmp_dir=$TMP_DIR \
--problem=$PROBLEM
# Train
# * If you run out of memory, add --hparams='batch_size=1024'.
t2t-trainer \
--data_dir=$DATA_DIR \
--problem=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR
# Decode
DECODE_FILE=$DATA_DIR/decode_this.txt
echo "Hello world" >> $DECODE_FILE
echo "Goodbye world" >> $DECODE_FILE
echo -e 'Hallo Welt\nAuf Wiedersehen Welt' > ref-translation.de
BEAM_SIZE=4
ALPHA=0.6
t2t-decoder \
--data_dir=$DATA_DIR \
--problem=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
--decode_from_file=$DECODE_FILE \
--decode_to_file=translation.en
# See the translations
cat translation.en
# Evaluate the BLEU score
# Note: Report this BLEU score in papers, not the internal approx_bleu metric.
t2t-bleu --translation=translation.en --reference=ref-translation.de
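The ALPHA value passed through --decode_hparams sets the beam-search length penalty. The sketch below assumes the GNMT-style penalty ((5 + length) / 6) ** alpha, which T2T's beam search is assumed here to divide log-probabilities by. The two candidate scores are made up. The sketch shows why alpha = 0 prefers the short candidate while larger values such as 0.6 favor the longer one.

```python
def length_penalty(length: int, alpha: float) -> float:
    """GNMT-style length penalty: ((5 + length) / 6) ** alpha."""
    return ((5.0 + length) / 6.0) ** alpha


def normalized_score(log_prob: float, length: int, alpha: float) -> float:
    """Beam-search score: the log-probability divided by the length penalty."""
    return log_prob / length_penalty(length, alpha)


# Two hypothetical candidates: a short one with a higher raw log-probability
# and a longer one with a lower raw log-probability.
short_hyp = (-4.0, 3)  # (log_prob, length)
long_hyp = (-4.6, 9)

for alpha in (0.0, 0.6, 1.0):
    short_score = normalized_score(*short_hyp, alpha)
    long_score = normalized_score(*long_hyp, alpha)
    winner = "short" if short_score > long_score else "long"
    print(f"alpha={alpha}: short={short_score:.3f} long={long_score:.3f} -> {winner}")
```

With alpha = 0 the penalty is 1 and raw log-probabilities decide. As alpha grows, longer hypotheses are divided by a larger penalty, which shrinks the magnitude of their negative scores.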
Installation
# Assumes tensorflow or tensorflow-gpu installed
pip install tensor2tensor
# Installs with tensorflow-gpu requirement
pip install tensor2tensor[tensorflow_gpu]
# Installs with tensorflow (cpu) requirement
pip install tensor2tensor[tensorflow]
Binaries:
# Data generator
t2t-datagen
# Trainer
t2t-trainer --registry_help
Library usage:
python -c "from tensor2tensor.models.transformer import Transformer"
Features
- Many state-of-the-art and baseline models are built in, and new models can be added easily (open an issue or pull request!).
- Many datasets across modalities (text, audio, image) are available for generation and use, and new ones can be added easily (open an issue or pull request for public datasets!).
- Models can be used with any dataset and input mode (or even multiple); all modality-specific processing (e.g. embedding lookups for text tokens) is done with bottom and top transformations, which are specified per-feature in the model.
- Support for multi-GPU machines and synchronous (1 master, many workers) and asynchronous (independent workers synchronizing through a parameter server) distributed training.
- Easily swap amongst datasets and models by command-line flag with the data generation script t2t-datagen and the training script t2t-trainer.
- Train on Google Cloud ML and Cloud TPUs.
T2T overview
Problems
Problems consist of features such as inputs and targets, and metadata such as each feature's modality (e.g. symbol, image, audio) and vocabularies. Problem features are given by a dataset, which is stored as a TFRecord file with tensorflow.Example protocol buffers. All problems are imported in all_problems.py or are registered with @registry.register_problem. Run t2t-datagen to see the list of available problems and download them.
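The sketch below shows in miniature how a text problem might turn a sample's features into integer ids. It uses a tiny hypothetical whitespace vocabulary. It assumes T2T's text-encoder convention of reserving id 0 for padding and id 1 for end-of-sequence, and it appends EOS to each feature. Real problems use T2T's own encoders (e.g. subword vocabularies stored in the data directory), not this dictionary.

```python
from typing import Dict, List

PAD_ID = 0  # reserved for padding (assumed convention)
EOS_ID = 1  # reserved for end-of-sequence (assumed convention)


def build_vocab(texts: List[str]) -> Dict[str, int]:
    """Hypothetical toy vocabulary: ids in first-seen order, after the reserved ids."""
    vocab: Dict[str, int] = {}
    for text in texts:
        for token in text.split():
            if token not in vocab:
                vocab[token] = len(vocab) + 2  # ids 0 and 1 are reserved
    return vocab


def encode_sample(sample: Dict[str, str], vocab: Dict[str, int]) -> Dict[str, List[int]]:
    """Encode each text feature (e.g. "inputs", "targets") and append EOS."""
    return {name: [vocab[t] for t in text.split()] + [EOS_ID] for name, text in sample.items()}


sample = {"inputs": "hello world", "targets": "hallo welt"}
vocab = build_vocab(list(sample.values()))
print(encode_sample(sample, vocab))
# {'inputs': [2, 3, 1], 'targets': [4, 5, 1]}
```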
Models
T2TModels define the core tensor-to-tensor computation. They apply a default transformation to each input and output so that models may deal with modality-independent tensors (e.g. embeddings at the input; and a linear transform at the output to produce logits for a softmax over classes). All models are imported in the models subpackage, inherit from T2TModel, and are registered with @registry.register_model.
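The split between input transformation, core computation, and output transformation can be pictured with a toy forward pass in pure Python. In this sketch, the input step looks up an embedding per token id, the body works on modality-independent vectors, and the output step projects each vector to per-class logits. All names, sizes, and weights are hypothetical. T2T implements these pieces as TensorFlow layers in T2TModel and its modalities, not as Python lists.

```python
import random
from typing import List

Vector = List[float]


def bottom(token_ids: List[int], embedding: List[Vector]) -> List[Vector]:
    """Input transformation (hypothetical): one embedding vector per token id."""
    return [embedding[i] for i in token_ids]


def body(vectors: List[Vector]) -> List[Vector]:
    """Modality-independent core computation (here just a stand-in ReLU)."""
    return [[max(0.0, x) for x in v] for v in vectors]


def top(vectors: List[Vector], weights: List[Vector]) -> List[Vector]:
    """Output transformation: linear map from hidden vectors to per-class logits."""
    return [[sum(x * w for x, w in zip(v, row)) for row in weights] for v in vectors]


random.seed(0)
vocab_size, hidden_size, num_classes = 6, 4, 6
embedding = [[random.uniform(-1, 1) for _ in range(hidden_size)] for _ in range(vocab_size)]
# One weight row per output class.
output_weights = [[random.uniform(-1, 1) for _ in range(hidden_size)] for _ in range(num_classes)]

logits = top(body(bottom([2, 3, 1], embedding)), output_weights)
print(len(logits), len(logits[0]))  # 3 6
```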
Hyperparameter Sets
Hyperparameter sets are encoded in HParams objects and are registered with @registry.register_hparams. Every model and problem has an HParams. A basic set of hyperparameters is defined in common_hparams.py, and hyperparameter set functions can compose other hyperparameter set functions.
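Hyperparameter set functions compose by calling another set function and overriding a few values. The sketch below mimics that pattern. It uses Python's types.SimpleNamespace in place of T2T's HParams class. The function names, fields, and values are hypothetical illustrations, not T2T's actual sets or defaults.

```python
from types import SimpleNamespace


def basic_params():
    """Hypothetical shared defaults, analogous in spirit to common_hparams.py."""
    return SimpleNamespace(batch_size=4096, learning_rate=0.1, num_hidden_layers=6,
                           hidden_size=512, dropout=0.1)


def my_transformer_base():
    """Hypothetical base set built on the shared defaults."""
    hparams = basic_params()
    hparams.learning_rate = 0.2
    return hparams


def my_transformer_small():
    """Hypothetical smaller set: start from the base set and shrink the model."""
    hparams = my_transformer_base()
    hparams.num_hidden_layers = 2
    hparams.hidden_size = 256
    return hparams


small = my_transformer_small()
print(small.num_hidden_layers, small.hidden_size, small.learning_rate)  # 2 256 0.2
```

Because each set function returns a fresh object, a derived set only states its differences. Everything else is inherited from the set it calls.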
Trainer
The trainer binary is the entrypoint for training, evaluation, and inference. Users can easily switch between problems, models, and hyperparameter sets by using the --model, --problem, and --hparams_set flags. Specific hyperparameters can be overridden with the --hparams flag. --schedule and related flags control local and distributed training/evaluation (distributed training documentation).
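The --hparams flag takes comma-separated name=value overrides; the walkthrough, for example, uses --hparams='batch_size=1024'. The sketch below shows one way such a string could be applied to a parameter object. It casts each value to the type of the existing default and rejects unknown names. It is a simplified, hypothetical stand-in for HParams parsing, which handles more cases (such as list values) than this.

```python
from types import SimpleNamespace


def apply_overrides(hparams: SimpleNamespace, overrides: str) -> SimpleNamespace:
    """Apply "name=value,name=value" overrides, casting to each default's type."""
    for item in filter(None, (part.strip() for part in overrides.split(","))):
        name, _, raw = item.partition("=")
        name = name.strip()
        if not hasattr(hparams, name):
            raise ValueError(f"unknown hyperparameter: {name}")
        default = getattr(hparams, name)
        if isinstance(default, bool):  # check bool before int: bool subclasses int
            value = raw.strip().lower() in ("true", "1")
        else:
            value = type(default)(raw.strip())
        setattr(hparams, name, value)
    return hparams


hparams = SimpleNamespace(batch_size=4096, learning_rate=0.2, shared_embedding=False)
apply_overrides(hparams, "batch_size=1024,learning_rate=0.05")
print(hparams.batch_size, hparams.learning_rate)  # 1024 0.05
```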
Adding your own components
T2T's components are registered using a central registration mechanism that enables easily adding new ones and easily swapping amongst them by command-line flag. You can add your own components without editing the T2T codebase by specifying the --t2t_usr_dir flag in t2t-trainer.
You can do so for models, hyperparameter sets, modalities, and problems. Please do submit a pull request if your component might be useful to others.
See the example_usr_dir for an example user directory.
Adding a dataset
To add a new dataset, subclass Problem and register it with @registry.register_problem. See TranslateEndeWmt8k for an example. Also see the data generators README.
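For a translation-style dataset, much of the work of a new problem is a generator that yields one dict per example, with "inputs" and "targets" strings. In T2T that generator would live in a Problem subclass (for example, a Text2TextProblem's generate_samples method), which is not shown here. The plain-Python sketch below reads two hypothetical line-aligned files, a source file and its translation, and yields such dicts. Note that zip stops silently at the end of the shorter file.

```python
import os
import tempfile
from typing import Dict, Iterator


def parallel_samples(source_path: str, target_path: str) -> Iterator[Dict[str, str]]:
    """Yield {"inputs", "targets"} dicts from two line-aligned text files."""
    with open(source_path, encoding="utf-8") as src, open(target_path, encoding="utf-8") as tgt:
        for source, target in zip(src, tgt):
            source, target = source.strip(), target.strip()
            if not source or not target:
                continue  # skip pairs where either side is empty
            yield {"inputs": source, "targets": target}


# Usage with two small temporary files (hypothetical file names).
tmp_dir = tempfile.mkdtemp()
src_file = os.path.join(tmp_dir, "train.en")
tgt_file = os.path.join(tmp_dir, "train.de")
with open(src_file, "w", encoding="utf-8") as f:
    f.write("Hello world\nGoodbye world\n")
with open(tgt_file, "w", encoding="utf-8") as f:
    f.write("Hallo Welt\nAuf Wiedersehen Welt\n")

for sample in parallel_samples(src_file, tgt_file):
    print(sample)
```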
Run on FloydHub
Click this button to open a Workspace on FloydHub. You can use the workspace to develop and test your code on a fully configured cloud GPU machine.
Tensor2Tensor comes preinstalled in the environment; you can simply open a Terminal and run your code.
# Test the quick-start on a Workspace's Terminal with this command
t2t-trainer \
--generate_data \
--data_dir=./t2t_data \
--output_dir=./t2t_train/mnist \
--problem=image_mnist \
--model=shake_shake \
--hparams_set=shake_shake_quick \
--train_steps=1000 \
--eval_steps=100
Note: Ensure compliance with the FloydHub Terms of Service.
Papers
When referencing Tensor2Tensor, please cite this paper.
@article{tensor2tensor,
author = {Ashish Vaswani and Samy Bengio and Eugene Brevdo and
Francois Chollet and Aidan N. Gomez and Stephan Gouws and Llion Jones and
\L{}ukasz Kaiser and Nal Kalchbrenner and Niki Parmar and Ryan Sepassi and
Noam Shazeer and Jakob Uszkoreit},
title = {Tensor2Tensor for Neural Machine Translation},
journal = {CoRR},
volume = {abs/1803.07416},
year = {2018},
url = {http://arxiv.org/abs/1803.07416},
}
Tensor2Tensor was used to develop a number of state-of-the-art models and deep learning methods. Here we list some papers that were based on T2T from the start and benefited from its features and architecture in ways described in the Google Research Blog post introducing T2T.
- Attention Is All You Need
- Depthwise Separable Convolutions for Neural Machine Translation
- One Model To Learn Them All
- Discrete Autoencoders for Sequence Models
- Generating Wikipedia by Summarizing Long Sequences
- Image Transformer
- Training Tips for the Transformer Model
- Self-Attention with Relative Position Representations
- Fast Decoding in Sequence Models using Discrete Latent Variables
- Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
- Universal Transformers
- Attending to Mathematical Language with Transformers
- The Evolved Transformer
- Model-Based Reinforcement Learning for Atari
- VideoFlow: A Flow-Based Generative Model for Video
NOTE: This is not an official Google product.