Top Related Projects
An Open Source Machine Learning Framework for Everyone
Tensors and Dynamic neural networks in Python with strong GPU acceleration
ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
scikit-learn: machine learning in Python
Quick Overview
Keras is a high-level neural networks API written in Python. Keras 3 runs on top of JAX, TensorFlow, or PyTorch; earlier releases ran on TensorFlow, CNTK, or Theano. It was developed with a focus on fast experimentation and ease of use, and it is designed to be user-friendly, modular, and extensible, which makes it suitable for both research and production environments.
Pros
- Easy to use and intuitive API, allowing for quick prototyping of deep learning models
- Supports both convolutional networks and recurrent networks, as well as combinations of the two
- Runs seamlessly on CPU and GPU
- Has a large and active community, providing extensive documentation and support
Cons
- Can be slower compared to lower-level libraries for certain operations
- Less flexibility compared to using TensorFlow or PyTorch directly
- May have a steeper learning curve for users who want to implement custom layers or loss functions
- Occasional compatibility issues when transitioning between different backend engines
Code Examples
- Creating a simple sequential model:
from tensorflow import keras
model = keras.Sequential([
    keras.layers.Dense(64, activation='relu', input_shape=(784,)),
    keras.layers.Dense(64, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
- Compiling and training a model:
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5, batch_size=32)
- Making predictions with a trained model:
predictions = model.predict(x_test)
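The `categorical_crossentropy` loss used above expects one-hot label rows, and `model.predict` on the 10-way softmax model returns one row of class probabilities per sample. The pure-Python sketch below illustrates both conversions on a smaller 3-class case. The helper names `one_hot` and `argmax_rows` are hypothetical and the sample values are made up; in practice `keras.utils.to_categorical` and `numpy.argmax(predictions, axis=1)` do the same job.

```python
def one_hot(labels, num_classes):
    """Turn integer class labels into one-hot rows."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside 0..{num_classes - 1}")
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


def argmax_rows(probabilities):
    """Return the index of the largest probability in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]


# Made-up labels and softmax outputs for a 3-class problem.
y_onehot = one_hot([0, 2, 1], num_classes=3)
fake_predictions = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
predicted_classes = argmax_rows(fake_predictions)
```

If the labels are kept as integers instead, the `sparse_categorical_crossentropy` loss accepts them directly.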
Getting Started
To get started with Keras, follow these steps:
- Install Keras and TensorFlow:
pip install tensorflow
- Import Keras and create a simple model:
from tensorflow import keras
model = keras.Sequential([
    keras.layers.Dense(32, activation='relu', input_shape=(100,)),
    keras.layers.Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
- Train the model with your data:
model.fit(x_train, y_train, epochs=10, batch_size=32, validation_split=0.2)
This quick start guide assumes you have your data prepared as NumPy arrays (x_train and y_train). For more detailed information and advanced usage, refer to the official Keras documentation.
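The `validation_split=0.2` argument in the training call holds out a fraction of the training arrays for validation. Keras takes those samples from the end of the arrays, before any shuffling, so ordered data should be shuffled beforehand. The sketch below only illustrates the arithmetic; `validation_split_sizes` is a hypothetical helper, and the exact rounding used inside Keras is an assumption.

```python
import math


def validation_split_sizes(num_samples, validation_split):
    """Return (train_size, val_size) for a tail hold-out of the given fraction.

    Assumption: the split index is floor(num_samples * (1 - validation_split));
    the exact rounding inside Keras may differ.
    """
    if not 0.0 < validation_split < 1.0:
        raise ValueError("validation_split must be between 0 and 1")
    split_at = int(math.floor(num_samples * (1.0 - validation_split)))
    return split_at, num_samples - split_at


# With 1000 samples and validation_split=0.2, the leading rows are used for
# training and the remaining tail is held out for validation.
train_size, val_size = validation_split_sizes(1000, 0.2)
train_rows = range(0, train_size)
val_rows = range(train_size, 1000)
```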
Competitor Comparisons
An Open Source Machine Learning Framework for Everyone
Pros of TensorFlow
- More comprehensive and flexible, offering lower-level control
- Supports distributed computing and deployment on various platforms
- Extensive ecosystem with tools like TensorBoard for visualization
Cons of TensorFlow
- Steeper learning curve, especially for beginners
- More verbose code, requiring more lines to accomplish tasks
- Can be slower to prototype compared to Keras
Code Comparison
Keras:
from keras.models import Sequential
from keras.layers import Dense
model = Sequential([
    Dense(64, activation='relu', input_shape=(10,)),
    Dense(1, activation='sigmoid')
])
TensorFlow:
import tensorflow as tf
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(10,)),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
The code comparison shows that TensorFlow now includes Keras as its high-level API, making the syntax very similar. However, TensorFlow offers more flexibility and lower-level control when needed, while Keras focuses on simplicity and ease of use.
Tensors and Dynamic neural networks in Python with strong GPU acceleration
Pros of PyTorch
- More flexible and dynamic computational graph
- Better support for debugging and research-oriented tasks
- Closer to Python programming style, making it more intuitive for many developers
Cons of PyTorch
- Steeper learning curve for beginners
- Smaller ecosystem of pre-built models and tools compared to Keras
- Less integrated with production deployment tools
Code Comparison
Keras:
from keras.models import Sequential
from keras.layers import Dense
model = Sequential([
    Dense(64, activation='relu', input_shape=(10,)),
    Dense(1, activation='sigmoid')
])
PyTorch:
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(10, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.layers(x)

model = Model()
The code comparison shows that PyTorch requires more explicit definition of the model structure, while Keras offers a more concise and high-level API. PyTorch's approach provides greater flexibility but may require more code for simple models.
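Both snippets describe the same 10 → 64 → 1 fully connected network, so they should end up with the same number of trainable parameters. As a quick sanity check, the count can be worked out by hand; `dense_param_count` is a hypothetical helper written for this illustration, not part of either library.

```python
def dense_param_count(layer_sizes):
    """Count weights and biases in a stack of fully connected layers.

    layer_sizes lists the input width followed by each layer's output width.
    Each layer contributes n_in * n_out weights plus n_out biases.
    """
    return sum(
        n_in * n_out + n_out
        for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
    )


# 10 inputs -> 64 hidden units -> 1 output:
# (10 * 64 + 64) + (64 * 1 + 1) = 704 + 65 = 769
total_params = dense_param_count([10, 64, 1])
```

The result can be compared against `model.count_params()` in Keras or `sum(p.numel() for p in model.parameters())` in PyTorch.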
ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
Pros of ONNX Runtime
- Cross-platform inference optimization for various hardware
- Supports a wide range of ML frameworks and models
- Provides better performance and reduced inference times
Cons of ONNX Runtime
- Steeper learning curve for beginners
- Less focus on model training compared to Keras
- May require additional steps to convert models from other frameworks
Code Comparison
ONNX Runtime example:
import onnxruntime as ort
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
output = session.run(None, {input_name: input_data})
Keras example:
from keras.models import load_model
model = load_model('model.h5')
predictions = model.predict(input_data)
ONNX Runtime focuses on optimized inference across platforms, while Keras provides a more intuitive API for both model building and inference. ONNX Runtime excels in deployment scenarios, especially when performance is crucial. Keras, on the other hand, offers a more straightforward approach for beginners and is deeply integrated with TensorFlow, making it easier to build and train models from scratch.
Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
Pros of MXNet
- Better performance and scalability for distributed training
- More flexible and customizable for low-level operations
- Supports multiple programming languages (Python, C++, R, Julia, etc.)
Cons of MXNet
- Steeper learning curve and more complex API
- Smaller community and ecosystem compared to Keras
- Less frequent updates and maintenance
Code Comparison
MXNet:
import mxnet as mx
from mxnet import nd, autograd, gluon
data = nd.random.normal(shape=(100, 1))
label = 2 * data + 1 + 0.1 * nd.random.normal(shape=(100, 1))
Keras:
import tensorflow as tf
from tensorflow import keras
data = tf.random.normal((100, 1))
label = 2 * data + 1 + 0.1 * tf.random.normal((100, 1))
Both frameworks provide similar functionality for creating and manipulating tensors, but MXNet uses the nd module while Keras uses TensorFlow's built-in functions. MXNet's syntax is slightly more verbose, reflecting its lower-level nature compared to Keras' high-level abstractions.
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
Pros of JAX
- Offers automatic differentiation and GPU/TPU acceleration
- Provides a more flexible, lower-level API for numerical computing
- Supports just-in-time (JIT) compilation for improved performance
Cons of JAX
- Steeper learning curve compared to Keras' high-level API
- Smaller ecosystem and fewer pre-built models/layers
- Less focus on production deployment tools
Code Comparison
Keras example:
from keras.models import Sequential
from keras.layers import Dense
model = Sequential([
    Dense(64, activation='relu', input_shape=(10,)),
    Dense(1, activation='sigmoid')
])
JAX example:
import jax.numpy as jnp
from jax import random, grad
def model(params, x):
    w1, b1, w2, b2 = params
    h = jnp.tanh(jnp.dot(x, w1) + b1)
    return jnp.dot(h, w2) + b2
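JAX offers more flexibility but requires more low-level implementation, while Keras provides a higher-level abstraction for quick model building. Note that the two snippets are not line-for-line equivalent: the JAX function uses a tanh hidden activation and returns raw outputs without a sigmoid, and the parameters must be created and passed in by the caller. JAX is better suited for research and custom algorithm development, whereas Keras excels in rapid prototyping and production deployment of standard deep learning models.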
scikit-learn: machine learning in Python
Pros of scikit-learn
- Broader range of machine learning algorithms and tools
- Simpler API for traditional ML tasks
- Better integration with scientific Python ecosystem (NumPy, SciPy, Pandas)
Cons of scikit-learn
- Limited support for deep learning and neural networks
- Less optimized for GPU acceleration
- Slower for large-scale data processing compared to Keras
Code Comparison
scikit-learn:
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier()
model.fit(X_train, y_train)
predictions = model.predict(X_test)
Keras:
from keras.models import Sequential
from keras.layers import Dense
model = Sequential([
    Dense(64, activation='relu', input_shape=(10,)),
    Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy')
model.fit(X_train, y_train, epochs=10, batch_size=32)
scikit-learn excels in traditional machine learning tasks with a simple API, while Keras specializes in deep learning with a more flexible and powerful neural network architecture. scikit-learn integrates better with the scientific Python ecosystem, but Keras offers superior performance for large-scale data and deep learning tasks.
README
Keras 3: Deep Learning for Humans
Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, and PyTorch. Effortlessly build and train models for computer vision, natural language processing, audio processing, timeseries forecasting, recommender systems, etc.
- Accelerated model development: Ship deep learning solutions faster thanks to the high-level UX of Keras and the availability of easy-to-debug runtimes like PyTorch or JAX eager execution.
- State-of-the-art performance: By picking the backend that is the fastest for your model architecture (often JAX!), leverage speedups ranging from 20% to 350% compared to other frameworks. Benchmark here.
- Datacenter-scale training: Scale confidently from your laptop to large clusters of GPUs or TPUs.
Join nearly three million developers, from burgeoning startups to global enterprises, in harnessing the power of Keras 3.
Installation
Install with pip
Keras 3 is available on PyPI as keras. Note that Keras 2 remains available as the tf-keras package.
- Install keras:
pip install keras --upgrade
- Install backend package(s).
To use keras, you should also install the backend of choice: tensorflow, jax, or torch.
Note that tensorflow is required for using certain Keras 3 features: certain preprocessing layers as well as tf.data pipelines.
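Before choosing a backend, it can help to check which of the three backend packages are present in the current environment. The sketch below uses only the standard library and does not import the packages themselves; `installed_backends` is a hypothetical helper name, and the check only confirms that a package can be found, not that it works with your hardware.

```python
import importlib.util

# Package names as listed in the installation step above.
BACKEND_PACKAGES = ("tensorflow", "jax", "torch")


def installed_backends():
    """Map each backend package name to True if it can be located."""
    return {
        name: importlib.util.find_spec(name) is not None
        for name in BACKEND_PACKAGES
    }


available = installed_backends()
for name, found in available.items():
    print(f"{name}: {'installed' if found else 'missing'}")
```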
Local installation
Minimal installation
Keras 3 is compatible with Linux and macOS systems. For Windows users, we recommend using WSL2 to run Keras. To install a local development version:
- Install dependencies:
pip install -r requirements.txt
- Run installation command from the root directory.
python pip_build.py --install
- Run API generation script when creating PRs that update keras_export public APIs:
./shell/api_gen.sh
Adding GPU support
The requirements.txt file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also provide a separate requirements-{backend}-cuda.txt for TensorFlow, JAX, and PyTorch. These install all CUDA dependencies via pip and expect an NVIDIA driver to be pre-installed. We recommend a clean Python environment for each backend to avoid CUDA version mismatches. As an example, here is how to create a JAX GPU environment with conda:
conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install
Configuring your backend
You can export the environment variable KERAS_BACKEND or you can edit your local config file at ~/.keras/keras.json to configure your backend. Available backend options are: "tensorflow", "jax", "torch". Example:
export KERAS_BACKEND="jax"
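For the config-file route, the following sketch updates the backend entry in a JSON config while leaving other settings untouched. It assumes the file is a JSON object with a "backend" key; `set_backend_in_config` is a hypothetical helper, and the demo writes to a temporary directory with made-up sample content rather than to the real ~/.keras/keras.json.

```python
import json
import tempfile
from pathlib import Path

VALID_BACKENDS = ("tensorflow", "jax", "torch")


def set_backend_in_config(config_path, backend):
    """Set the "backend" key in a Keras-style JSON config file."""
    if backend not in VALID_BACKENDS:
        raise ValueError(f"backend must be one of {VALID_BACKENDS}")
    path = Path(config_path)
    # Preserve any settings that are already in the file.
    config = json.loads(path.read_text()) if path.exists() else {}
    config["backend"] = backend
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(config, indent=4))
    return config


# Demo against a throwaway file so the real config is not modified.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / ".keras" / "keras.json"
    demo_path.parent.mkdir()
    demo_path.write_text(json.dumps({"floatx": "float32"}))  # sample content
    updated = set_backend_in_config(demo_path, "jax")
    on_disk = json.loads(demo_path.read_text())
```

To apply it for real, pass `Path.home() / ".keras" / "keras.json"` as the path.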
In Colab, you can do:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
Note: The backend must be configured before importing keras, and the backend cannot be changed after the package has been imported.
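Because the setting is only read at import time, a small guard can catch the common mistake of setting the variable too late. This is a minimal sketch, assuming the three backend names listed above; `select_backend` is a hypothetical helper, not a Keras API.

```python
import os
import sys

VALID_BACKENDS = ("tensorflow", "jax", "torch")


def select_backend(name):
    """Set KERAS_BACKEND, refusing to do so once keras is already imported."""
    if name not in VALID_BACKENDS:
        raise ValueError(f"backend must be one of {VALID_BACKENDS}")
    if "keras" in sys.modules:
        raise RuntimeError(
            "keras is already imported; restart the interpreter to change backend"
        )
    os.environ["KERAS_BACKEND"] = name
    return name


chosen = select_backend("jax")
# import keras  # import only after the backend has been selected
```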
Backwards compatibility
Keras 3 is intended to work as a drop-in replacement for tf.keras (when using the TensorFlow backend). Just take your existing tf.keras code, make sure that your calls to model.save() are using the up-to-date .keras format, and you're done.
If your tf.keras model does not include custom components, you can start running it on top of JAX or PyTorch immediately.
If it does include custom components (e.g. custom layers or a custom train_step()), it is usually possible to convert it to a backend-agnostic implementation in just a few minutes.
In addition, Keras models can consume datasets in any format, regardless of the backend you're using: you can train your models with your existing tf.data.Dataset pipelines or PyTorch DataLoaders.
Why use Keras 3?
- Run your high-level Keras workflows on top of any framework -- benefiting at will from the advantages of each framework, e.g. the scalability and performance of JAX or the production ecosystem options of TensorFlow.
- Write custom components (e.g. layers, models, metrics) that you can use in low-level workflows in any framework.
- You can take a Keras model and train it in a training loop written from scratch in native TF, JAX, or PyTorch.
- You can take a Keras model and use it as part of a PyTorch-native Module or as part of a JAX-native model function.
- Make your ML code future-proof by avoiding framework lock-in.
- As a PyTorch user: get access to the power and usability of Keras, at last!
- As a JAX user: get access to a fully-featured, battle-tested, well-documented modeling and training library.
Read more in the Keras 3 release announcement.