Top Related Projects
Fast and memory-efficient exact attention
A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
Hackable and optimized Transformers building blocks, supporting a composable construction.
Development repository for the Triton language and compiler
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Quick Overview
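Flash Attention is an optimized, exact attention algorithm that significantly speeds up transformer models while using less memory. Standard attention materializes the full seqlen × seqlen score matrix, so its memory grows quadratically with sequence length. FlashAttention computes the same result tile by tile without storing that matrix, so its memory grows linearly (the compute is still quadratic). This makes it particularly useful for training large language models and other transformer-based architectures on long sequences.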
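The key trick from the paper is an online (streaming) softmax: keys and values are visited one block at a time, and only a running maximum, a running sum and an unnormalized output are kept. Earlier partial results are rescaled whenever a larger score appears. The sketch below illustrates that recurrence for a single query row in pure Python. The function names are our own, and plain lists stand in for GPU tiles, so it shows the math rather than the CUDA kernel.

```python
import math

def _score(q, k, scale):
    # Scaled dot product between one query vector and one key vector
    return scale * sum(qi * ki for qi, ki in zip(q, k))

def naive_attention_row(q, keys, values, scale):
    """softmax(q . K^T * scale) @ V for one query, materializing every score."""
    scores = [_score(q, k, scale) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]

def tiled_attention_row(q, keys, values, scale, block_size=2):
    """Same result, visiting K/V one block at a time with an online softmax."""
    running_max = -math.inf          # largest score seen so far
    running_sum = 0.0                # sum of exp(score - running_max)
    acc = [0.0] * len(values[0])     # unnormalized weighted sum of V rows
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [_score(q, k, scale) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale the earlier partial results to the new running maximum
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vd for a, vd in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
scale = 1 / math.sqrt(len(q))
print(naive_attention_row(q, keys, values, scale))
print(tiled_attention_row(q, keys, values, scale, block_size=2))
```

Only the per-block scores exist at any moment, which is why memory stays linear in sequence length while the result matches the naive computation up to floating-point rounding.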
Pros
- Dramatically reduces memory usage and increases speed for attention computations
- Enables training with longer sequences and larger models than standard attention allows on the same hardware
- Compatible with various transformer architectures and can be easily integrated into existing codebases
- Provides optimized CUDA kernels plus an experimental Triton implementation
Cons
- Primarily focused on NVIDIA GPUs (AMD support via ROCm is limited to MI200/MI300), restricting its use on other hardware
- Requires some familiarity with CUDA or Triton for optimal usage and customization
- May require modifications to existing model architectures to fully leverage its benefits
- Still an evolving project, which may lead to frequent updates and potential breaking changes
Code Examples
- Basic usage of FlashAttention:
import torch
from flash_attn import flash_attn_func
# FlashAttention-2 expects (batch_size, seq_len, num_heads, head_dim) tensors in fp16 or bf16 on a CUDA GPU
batch_size, seq_len, num_heads, head_dim = 32, 1024, 8, 64
q = torch.randn(batch_size, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
output = flash_attn_func(q, k, v)  # same shape as q
- Using FlashAttention with a causal mask:
from flash_attn import flash_attn_func
# causal=True lets each query attend only to keys at the same or earlier positions
output = flash_attn_func(q, k, v, causal=True)
- Integrating FlashAttention into a custom attention module:
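import torch.nn as nn
from flash_attn import flash_attn_func
class CustomAttention(nn.Module):
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
    def forward(self, x):
        # x: (batch_size, seq_len, hidden_dim); module and input in fp16/bf16 on CUDA
        batch_size, seq_len, _ = x.shape
        heads_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
        q = self.q_proj(x).view(heads_shape)
        k = self.k_proj(x).view(heads_shape)
        v = self.v_proj(x).view(heads_shape)
        out = flash_attn_func(q, k, v)  # (batch_size, seq_len, num_heads, head_dim)
        return out.reshape(batch_size, seq_len, -1)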
Getting Started
To use Flash Attention in your project:
- Install the package:
pip install flash-attn --no-build-isolation
- Import and use in your code:
import torch
from flash_attn import flash_attn_func
# q, k, v: (batch_size, seqlen, nheads, headdim), fp16 or bf16, on a CUDA GPU
q, k, v = torch.randn(3, 2, 1024, 8, 64, device="cuda", dtype=torch.float16).unbind(0)
# Use in your model
output = flash_attn_func(q, k, v)
Note: FlashAttention requires a supported GPU and the matching CUDA or ROCm toolkit (see the hardware requirements below).
Competitor Comparisons
Fast and memory-efficient exact attention
Pros of flash-attention
- Faster and more memory-efficient attention mechanism
- Supports various attention patterns and sparsity structures
- Integrates well with popular deep learning frameworks
Cons of flash-attention
- May require additional setup and configuration
- Limited compatibility with older hardware
- Potential learning curve for developers new to the concept
Code Comparison
flash-attention:
from flash_attn import flash_attn_func
output = flash_attn_func(query, key, value)  # (batch, seqlen, nheads, headdim) tensors
flash-attention:
from flash_attn import flash_attn_func
output = flash_attn_func(query, key, value)  # (batch, seqlen, nheads, headdim) tensors
Note: The code comparison shows identical snippets because both repositories refer to the same project. The flash-attention repository is the main implementation, and there is no separate flash-attention repository for comparison.
In this case, the comparison is between the same repository, so the pros and cons listed are general characteristics of the flash-attention project rather than a direct comparison between two different repositories.
A PyTorch Extension: Tools for easy mixed precision and distributed training in Pytorch
Pros of Apex
- Broader scope: Offers a suite of optimization tools beyond just attention mechanisms
- NVIDIA-backed: Potentially better support for NVIDIA GPUs and integration with other NVIDIA tools
- Mature project: Longer development history and wider adoption in the community
Cons of Apex
- Less specialized: May not offer the same level of performance optimization for attention mechanisms
- Heavier: Larger codebase and potentially more overhead due to its broader feature set
- NVIDIA-focused: May not be as optimized for non-NVIDIA hardware
Code Comparison
Flash-Attention:
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)
Apex:
from apex import amp
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
Summary
Flash-Attention focuses specifically on optimizing attention mechanisms, while Apex provides a broader set of optimization tools. Flash-Attention may offer better performance for attention-specific tasks, while Apex provides a more comprehensive suite of optimization techniques backed by NVIDIA. The choice between the two depends on the specific requirements of the project and the hardware being used.
DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.
Pros of DeepSpeed
- Broader scope: Offers a comprehensive suite of optimization techniques beyond just attention mechanisms
- Flexibility: Supports various model architectures and training scenarios
- Integration: Seamlessly integrates with popular deep learning frameworks like PyTorch
Cons of DeepSpeed
- Complexity: May have a steeper learning curve due to its extensive feature set
- Overhead: Potential performance overhead for simpler models or smaller datasets
- Specificity: Less specialized in attention optimization compared to Flash-Attention
Code Comparison
Flash-Attention:
from flash_attn import flash_attn_func
attn_output = flash_attn_func(q, k, v, dropout_p=0.1, causal=True)
DeepSpeed:
import deepspeed
model_engine, optimizer, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=params)
loss = model_engine(inputs)
model_engine.backward(loss)
model_engine.step()
While Flash-Attention focuses specifically on optimizing attention mechanisms, DeepSpeed provides a more comprehensive toolkit for training large-scale models efficiently. Flash-Attention may be more suitable for projects primarily concerned with attention optimization, while DeepSpeed offers a broader range of optimizations and scaling techniques for various deep learning tasks.
Hackable and optimized Transformers building blocks, supporting a composable construction.
Pros of xformers
- Broader scope: Covers a wide range of transformer optimizations beyond attention mechanisms
- More extensive documentation and examples
- Integrated with PyTorch ecosystem and other Facebook AI tools
Cons of xformers
- Less specialized: May not achieve the same level of performance optimization for attention as Flash Attention
- Potentially more complex to use due to its broader feature set
- Slower development cycle compared to Flash Attention
Code Comparison
Flash Attention:
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)
xformers:
from xformers.ops import memory_efficient_attention
output = memory_efficient_attention(q, k, v, attn_bias=None)
Both libraries aim to optimize attention mechanisms, but Flash Attention focuses specifically on efficient attention computation, while xformers provides a broader set of transformer-related optimizations. Flash Attention may offer better performance for attention operations, while xformers provides more flexibility and integration with other PyTorch tools.
Development repository for the Triton language and compiler
Pros of Triton
- More general-purpose: Triton is a language and compiler for GPU programming, offering broader applicability beyond attention mechanisms
- Flexibility: Allows for custom kernel development and optimization across various GPU operations
- Active community: Larger user base and more frequent updates
Cons of Triton
- Steeper learning curve: Requires understanding of GPU programming concepts and Triton-specific syntax
- Less specialized: May not provide the same level of optimization for attention mechanisms as Flash-Attention
Code Comparison
Flash-Attention:
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)
Triton:
import triton
import triton.language as tl
@triton.jit
def attention_kernel(q_ptr, k_ptr, v_ptr, output_ptr, ...):
# Custom attention implementation
The Flash-Attention example shows a simple function call, while Triton requires writing custom CUDA-like kernels for GPU operations.
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Pros of transformers
- Comprehensive library with a wide range of pre-trained models
- Easy-to-use API for fine-tuning and inference
- Extensive documentation and community support
Cons of transformers
- Larger memory footprint and slower inference for some models
- Less optimized for specific hardware acceleration techniques
Code comparison
transformers:
from transformers import AutoModel, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("Hello world!", return_tensors="pt")
outputs = model(**inputs)
flash-attention:
import torch
from flash_attn import flash_attn_func
# Three (batch=4, seqlen=512, nheads=8, headdim=64) fp16 tensors on the GPU
q, k, v = torch.randn(3, 4, 512, 8, 64, device="cuda", dtype=torch.float16).unbind(0)
output = flash_attn_func(q, k, v, causal=True)
flash-attention focuses on optimizing attention computations, while transformers provides a higher-level interface for working with various transformer models. flash-attention offers faster and more memory-efficient attention operations, particularly beneficial for large-scale models and specific hardware setups. However, transformers provides a more comprehensive ecosystem for working with a wide range of pre-trained models and NLP tasks.
README
FlashAttention
This repository provides the official implementation of FlashAttention and FlashAttention-2 from the following papers.
FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
Paper: https://arxiv.org/abs/2205.14135
IEEE Spectrum article about our submission to the MLPerf 2.0 benchmark using FlashAttention.
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
Tri Dao
Paper: https://tridao.me/publications/flash2/flash2.pdf
Usage
We've been very happy to see FlashAttention being widely adopted in such a short time after its release. This page contains a partial list of places where FlashAttention is being used.
FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite and credit FlashAttention if you use it.
FlashAttention-3 beta release
FlashAttention-3 is optimized for Hopper GPUs (e.g. H100).
Blogpost: https://tridao.me/blog/2024/flash3/
Paper: https://tridao.me/publications/flash3/flash3.pdf
This is a beta release for testing / benchmarking before we integrate that with the rest of the repo.
Currently released:
- FP16 forward and backward
Coming soon in the next couple of days / next week:
- BF16
- Variable length (FP16, BF16)
- FP8 forward.
Requirements: H100 / H800 GPU, CUDA >= 12.3.
To install:
cd hopper
python setup.py install
To run the test:
export PYTHONPATH=$PWD
pytest -q -s test_flash_attn.py
Installation and features
Requirements:
- CUDA toolkit or ROCm toolkit
- PyTorch 1.12 and above.
- packaging Python package (pip install packaging)
- ninja Python package (pip install ninja) *
- Linux. Might work for Windows starting v2.3.2 (we've seen a few positive reports) but Windows compilation still requires more testing. If you have ideas on how to set up prebuilt CUDA wheels for Windows, please reach out via GitHub issue.
* Make sure that ninja is installed and that it works correctly (e.g. ninja --version then echo $? should return exit code 0). If not (sometimes ninja --version then echo $? returns a nonzero exit code), uninstall then reinstall ninja (pip uninstall -y ninja && pip install ninja). Without ninja, compiling can take a very long time (2h) since it does not use multiple CPU cores. With ninja, compiling takes 3-5 minutes on a 64-core machine using CUDA toolkit.
To install:
pip install flash-attn --no-build-isolation
Alternatively you can compile from source:
python setup.py install
If your machine has less than 96GB of RAM and lots of CPU cores, ninja might run too many parallel compilation jobs that could exhaust the amount of RAM. To limit the number of parallel compilation jobs, you can set the environment variable MAX_JOBS:
MAX_JOBS=4 pip install flash-attn --no-build-isolation
Interface: src/flash_attention_interface.py
NVIDIA CUDA Support
Requirements:
- CUDA 11.7 and above.
We recommend the Pytorch container from Nvidia, which has all the required tools to install FlashAttention.
FlashAttention-2 with CUDA currently supports:
- Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing GPUs for now.
- Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
- All head dimensions up to 256.
Head dim > 192 backward requires A100/A800 or H100/H800. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.
AMD ROCm Support
ROCm version uses composable_kernel as the backend. It provides the implementation of FlashAttention-2.
Requirements:
- ROCm 6.0 and above.
We recommend the Pytorch container from ROCm, which has all the required tools to install FlashAttention.
FlashAttention-2 with ROCm currently supports:
- MI200 or MI300 GPUs.
- Datatype fp16 and bf16
- Forward's head dimensions up to 256. Backward head dimensions up to 128.
How to use FlashAttention
The main functions implement scaled dot product attention (softmax(Q @ K^T * softmax_scale) @ V):
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
window_size=(-1, -1), alibi_slopes=None, deterministic=False):
"""dropout_p should be set to 0.0 during evaluation
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k: (batch_size, seqlen, nheads_k, headdim)
v: (batch_size, seqlen, nheads_k, headdim)
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
deterministic: bool. Whether to use the deterministic implementation of the backward pass,
which is slightly slower and uses more memory. The forward pass is always deterministic.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
rotary_interleaved=True,
alibi_slopes=None,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.
If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
Note: Does not support backward pass.
Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k.
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
softmax_scale: float. The scaling of QK^T before applying softmax.
Defaults to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
(-alibi_slope * |i + seqlen_k - seqlen_q - j|)
is added to the attention score of query i and key j.
Return:
out: (batch_size, seqlen, nheads, headdim).
"""
To see how these functions are used in a multi-head attention layer (which includes QKV projection, output projection), see the MHA implementation.
Changelog
2.0: Complete rewrite, 2x faster
Upgrading from FlashAttention (1.x) to FlashAttention-2
These functions have been renamed:
flash_attn_unpadded_func -> flash_attn_varlen_func
flash_attn_unpadded_qkvpacked_func -> flash_attn_varlen_qkvpacked_func
flash_attn_unpadded_kvpacked_func -> flash_attn_varlen_kvpacked_func
If the inputs have the same sequence lengths in the same batch, it is simpler and faster to use these functions:
flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False)
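When sequence lengths differ within a batch, the varlen functions work on tokens packed back to back without padding. As we understand the interface, they locate each sequence through cumulative offsets (cu_seqlens_q / cu_seqlens_k, int32 tensors in the real API) plus a maximum sequence length. This pure-Python sketch only shows how such offsets are derived; the helper names are hypothetical.

```python
from itertools import accumulate

def cu_seqlens_from_lengths(seqlens):
    """Cumulative offsets [0, l0, l0 + l1, ...] and the longest sequence length."""
    return [0] + list(accumulate(seqlens)), max(seqlens)

def pack(batch):
    """Concatenate variable-length sequences (lists of tokens) without padding."""
    return [tok for seq in batch for tok in seq]

batch = [["a", "b", "c"], ["d", "e", "f", "g", "h"], ["i", "j"]]
cu_seqlens, max_seqlen = cu_seqlens_from_lengths([len(s) for s in batch])
packed = pack(batch)
# Sequence b occupies packed[cu_seqlens[b]:cu_seqlens[b + 1]]
print(cu_seqlens, max_seqlen)  # [0, 3, 8, 10] 5
```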
2.1: Change behavior of causal flag
If seqlen_q != seqlen_k and causal=True, the causal mask is aligned to the bottom right corner of the attention matrix, instead of the top-left corner.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
v2.0:
1 0 0 0 0
1 1 0 0 0
v2.1:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
v2.0:
1 0
1 1
1 1
1 1
1 1
v2.1:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
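The two alignments differ only in an offset: v2.0 keeps key j for query i when j <= i (top-left), while v2.1 keeps it when j <= i + seqlen_k - seqlen_q (bottom-right). This small pure-Python sketch (causal_mask is our own helper) regenerates the four example masks listed above.

```python
def causal_mask(seqlen_q, seqlen_k, bottom_right=True):
    """1 = keep, 0 = masked out.

    bottom_right=True:  query i keeps keys j <= i + seqlen_k - seqlen_q (v2.1).
    bottom_right=False: query i keeps keys j <= i (v2.0, top-left).
    """
    offset = seqlen_k - seqlen_q if bottom_right else 0
    return [[1 if j <= i + offset else 0 for j in range(seqlen_k)]
            for i in range(seqlen_q)]

for sq, sk in ((2, 5), (5, 2)):
    for label, br in (("v2.0", False), ("v2.1", True)):
        print(f"seqlen_q={sq}, seqlen_k={sk}, {label}:")
        for row in causal_mask(sq, sk, bottom_right=br):
            print(" ".join(map(str, row)))
```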
2.2: Optimize for inference
Optimize for inference (iterative decoding) when query has very small sequence length (e.g., query sequence length = 1). The bottleneck here is to load KV cache as fast as possible, and we split the loading across different thread blocks, with a separate kernel to combine results.
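Splitting the KV cache across thread blocks works because softmax attention can be merged exactly: each split returns its output normalized within the split plus the log-sum-exp (LSE) of its scores, and the combine step weights each partial output by exp(LSE_split - LSE_total). The pure-Python sketch below checks that identity for one query; partial_attention and combine_splits are hypothetical names, and this is not the library's kernel code.

```python
import math

def partial_attention(q, keys, values, scale):
    """Attention of one query over one KV split.

    Returns (output normalized within the split, log-sum-exp of the split's scores).
    """
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(len(values[0]))]
    return out, m + math.log(total)

def combine_splits(partials):
    """Merge split results: weight each output by exp(lse_split - lse_total)."""
    lses = [lse for _, lse in partials]
    m = max(lses)
    lse_total = m + math.log(sum(math.exp(lse - m) for lse in lses))
    dim = len(partials[0][0])
    return [sum(math.exp(lse - lse_total) * out[d] for out, lse in partials)
            for d in range(dim)]

q = [0.2, -0.4]
keys = [[1.0, 0.5], [-0.3, 0.8], [0.9, -1.2], [0.0, 0.4], [2.0, 1.0], [-1.5, 0.1]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, -1.0], [0.5, 0.5], [-2.0, 4.0]]
full, _ = partial_attention(q, keys, values, scale=0.7)
splits = [partial_attention(q, keys[i:i + 2], values[i:i + 2], scale=0.7)
          for i in range(0, len(keys), 2)]
print(full, combine_splits(splits))
```

Each split can run independently, which is what lets a length-1 query keep many thread blocks busy while reading the cache.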
See the function flash_attn_with_kvcache, which has more features for inference (applying rotary embedding, updating the KV cache in place).
Thanks to the xformers team, and in particular Daniel Haziza, for this collaboration.
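flash_attn_with_kvcache's rotary_interleaved flag chooses which dimensions are rotated together: (0, 1), (2, 3), ... when True, or (0, d/2), (1, d/2 + 1), ... (GPT-NeoX style) when False. The real function takes precomputed rotary_cos / rotary_sin tables; this pure-Python sketch instead computes angles with the common RoPE schedule pos * base^(-2i/d) and base 10000, which is an assumption made only to show the two pairings.

```python
import math

def apply_rotary(x, pos, base=10000.0, interleaved=True):
    """Rotate vector x (even length d) as if it sits at position `pos`.

    interleaved=True pairs dimensions (0, 1), (2, 3), ...
    interleaved=False pairs (0, d/2), (1, d/2 + 1), ... (GPT-NeoX style).
    Pair i is rotated by the angle pos * base^(-2i/d).
    """
    d = len(x)
    half = d // 2
    out = list(x)
    for i in range(half):
        theta = pos * base ** (-2 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = (2 * i, 2 * i + 1) if interleaved else (i, i + half)
        out[a] = x[a] * c - x[b] * s
        out[b] = x[a] * s + x[b] * c
    return out

x = [1.0, 2.0, 3.0, 4.0]
print(apply_rotary(x, pos=1, interleaved=True))
print(apply_rotary(x, pos=1, interleaved=False))
```

The two layouts are the same rotation applied to differently ordered dimensions, so weights trained with one convention need the matching flag.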
2.3: Local (i.e., sliding window) attention
Implement sliding window attention (i.e., local attention). Thanks to Mistral AI and in particular Timothée Lacroix for this contribution. Sliding window was used in the Mistral 7B model.
2.4: ALiBi (attention with linear bias), deterministic backward pass.
Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
Implement deterministic backward pass. Thanks to engineers from Meituan for this contribution.
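flash-attn takes the ALiBi slopes as an input (alibi_slopes, fp32, shape (nheads,) or (batch_size, nheads)) and adds -slope * |i + seqlen_k - seqlen_q - j| to each score. The slope schedule below, a geometric sequence 2^(-8/n), 2^(-16/n), ... for a power-of-two head count n, comes from the ALiBi paper and is our assumption about how a caller might fill that tensor; both helpers are hypothetical.

```python
def alibi_slopes(nheads):
    """Per-head slopes 2^(-8/n), 2^(-16/n), ... for a power-of-two head count n."""
    if nheads <= 0 or nheads & (nheads - 1):
        raise ValueError("this sketch only handles power-of-two head counts")
    start = 2 ** (-8 / nheads)
    return [start ** (h + 1) for h in range(nheads)]

def alibi_bias(slope, seqlen_q, seqlen_k):
    """Bias -slope * |i + seqlen_k - seqlen_q - j| added to the score of query i, key j."""
    offset = seqlen_k - seqlen_q
    return [[-slope * abs(i + offset - j) for j in range(seqlen_k)]
            for i in range(seqlen_q)]

print(alibi_slopes(8))
print(alibi_bias(0.5, seqlen_q=2, seqlen_k=4))
```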
2.5: Paged KV cache.
Support paged KV cache (i.e., PagedAttention). Thanks to @beginlner for this contribution.
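With a paged KV cache, k_cache has shape (num_blocks, page_block_size, nheads_k, headdim) and block_table (batch_size, max_num_blocks_per_seq) lists which physical blocks each sequence owns. Assuming the usual paged-attention layout, logical token p of sequence b lives in block block_table[b][p // page_block_size] at offset p % page_block_size. The helper below is hypothetical and only illustrates that lookup; page_block_size defaults to 256, the smallest allowed multiple.

```python
def physical_slot(block_table, batch_idx, token_pos, page_block_size=256):
    """Locate logical token `token_pos` of sequence `batch_idx` in a paged KV cache.

    Returns (physical_block, offset_within_block).
    """
    logical_block, offset = divmod(token_pos, page_block_size)
    return block_table[batch_idx][logical_block], offset

# Hypothetical table: sequence 0 uses physical blocks 7 then 2, sequence 1 uses 0 then 5
block_table = [[7, 2], [0, 5]]
print(physical_slot(block_table, 0, 300))  # (2, 44)
```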
2.6: Softcapping.
Support attention with softcapping, as used in Gemma-2 and Grok models. Thanks to @Narsil and @lucidrains for this contribution.
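Softcapping, as popularized by Gemma-2, squashes attention scores smoothly into [-softcap, softcap] with softcap * tanh(score / softcap): small scores pass through almost unchanged and large ones saturate. This sketch assumes the cap is applied to the scaled scores before the softmax and treats 0 as "disabled"; softcap_scores is our own helper, not the library's API.

```python
import math

def softcap_scores(scores, softcap):
    """Squash scores into [-softcap, softcap] via softcap * tanh(score / softcap).

    A softcap of 0 is treated as 'disabled' in this sketch.
    """
    if softcap == 0:
        return list(scores)
    return [softcap * math.tanh(s / softcap) for s in scores]

print(softcap_scores([0.1, 5.0, 50.0, 1000.0], softcap=30.0))
```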
Performance
We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
We currently have benchmarks for these GPUs:
A100
We display FlashAttention speedup using these parameters:
- Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads).
- Sequence length 512, 1k, 2k, 4k, 8k, 16k.
- Batch size set to 16k / seqlen.
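The benchmark settings keep the number of tokens per batch fixed. Assuming "16k" means 16384 tokens, the head counts and batch sizes work out as follows.

```python
HIDDEN_DIM = 2048
TOKENS_PER_BATCH = 16 * 1024  # assumption: "16k" means 16384 tokens

def heads_for(head_dim):
    return HIDDEN_DIM // head_dim

def batch_size_for(seqlen):
    return TOKENS_PER_BATCH // seqlen

for head_dim in (64, 128):
    print(f"head_dim={head_dim}: {heads_for(head_dim)} heads")
for seqlen in (512, 1024, 2048, 4096, 8192, 16384):
    print(f"seqlen={seqlen}: batch_size={batch_size_for(seqlen)}")
```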
Speedup
Memory
We show memory savings in this graph (note that the memory footprint is the same whether or not you use dropout or masking). Memory savings are proportional to sequence length, since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. We see 10X memory savings at sequence length 2K, and 20X at 4K. As a result, FlashAttention can scale to much longer sequence lengths.
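Rough arithmetic shows why the gap widens with sequence length: doubling seqlen quadruples the score matrix that standard attention materializes, but only doubles a per-row quantity. This sketch assumes fp16 scores (2 bytes), batch 1 and 16 heads, and uses one fp32 value per query row as a stand-in for FlashAttention's linear-size state. It ignores Q, K, V and outputs, so it does not reproduce the measured savings figures.

```python
def score_matrix_bytes(batch, nheads, seqlen, bytes_per_elem=2):
    """Memory to materialize the (seqlen x seqlen) scores, as standard attention does."""
    return batch * nheads * seqlen * seqlen * bytes_per_elem

def per_row_stats_bytes(batch, nheads, seqlen, bytes_per_elem=4):
    """Memory for one fp32 value per query row, which grows linearly."""
    return batch * nheads * seqlen * bytes_per_elem

for seqlen in (1024, 2048, 4096):
    quad = score_matrix_bytes(1, 16, seqlen) / 2**20
    lin = per_row_stats_bytes(1, 16, seqlen) / 2**20
    print(f"seqlen={seqlen}: scores {quad:.0f} MiB vs per-row stats {lin:.2f} MiB")
```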
H100
Full model code and training script
We have released the full GPT model implementation. We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x compared to the baseline implementation from Huggingface, reaching up to 225 TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need any activation checkpointing).
We also include a training script to train GPT2 on Openwebtext and GPT3 on The Pile.
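The 72% model FLOPs utilization figure follows from dividing achieved throughput by the GPU's peak. Assuming an A100 dense FP16/BF16 tensor-core peak of 312 TFLOPs/sec:

```python
A100_PEAK_TFLOPS = 312.0  # assumed dense FP16/BF16 tensor-core peak of an A100
achieved_tflops = 225.0   # per-GPU throughput reported in the text

mfu = achieved_tflops / A100_PEAK_TFLOPS
print(f"model FLOPs utilization = {mfu:.1%}")  # model FLOPs utilization = 72.1%
```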
Triton implementation of FlashAttention
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
As Triton is a higher-level language than CUDA, it might be easier to understand and experiment with. The notations in the Triton implementation are also closer to what's used in our paper.
We also have an experimental implementation in Triton that supports attention bias (e.g. ALiBi): https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
Tests
We test that FlashAttention produces the same output and gradient as a reference implementation, up to some numerical tolerance. In particular, we check that the maximum numerical error of FlashAttention is at most twice the numerical error of a baseline implementation in Pytorch (for different head dimensions, input dtype, sequence length, causal / non-causal).
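The pass criterion can be written as a one-line comparison of maximum errors. This sketch assumes the reference output is computed in higher precision and the baseline is standard PyTorch attention in the same low precision as the inputs; the helper names are hypothetical and the numbers are synthetic.

```python
def max_abs_error(output, reference):
    return max(abs(a - b) for a, b in zip(output, reference))

def passes_flash_tolerance(flash_out, baseline_out, reference_out, factor=2.0):
    """True if FlashAttention's max error vs. the reference is at most `factor`
    times the max error of the baseline implementation."""
    return (max_abs_error(flash_out, reference_out)
            <= factor * max_abs_error(baseline_out, reference_out))

reference = [1.0, 2.0, 3.0]
baseline = [1.001, 2.0, 2.999]
print(passes_flash_tolerance([1.0015, 2.0, 3.0], baseline, reference))  # True
print(passes_flash_tolerance([1.01, 2.0, 3.0], baseline, reference))    # False
```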
To run the tests:
pytest -q -s tests/test_flash_attn.py
When you encounter issues
This new release of FlashAttention-2 has been tested on several GPT-style models, mostly on A100 GPUs.
If you encounter bugs, please open a GitHub Issue!
Tests (AMD ROCm)
To run the tests:
pytest tests/test_flash_attn_ck.py
Citation
If you use this codebase, or otherwise found our work valuable, please cite:
@inproceedings{dao2022flashattention,
title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
year={2022}
}
@inproceedings{dao2023flashattention2,
title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
author={Dao, Tri},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}