state-spaces/mamba

Mamba SSM architecture

Top Related Projects

  • Flash-Attention: Fast and memory-efficient exact attention
  • DeepSpeed: A deep learning optimization library that makes distributed training and inference easy, efficient, and effective
  • xformers: Hackable and optimized Transformers building blocks, supporting a composable construction
  • gpt-neox: An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries
  • t5x

Quick Overview

Mamba is a state-space model architecture that serves as an efficient alternative to Transformers. It is designed for sequence modeling tasks: its computation scales linearly with sequence length, and during autoregressive generation it carries a fixed-size recurrent state instead of a key-value cache that grows with every token. This makes it particularly suitable for processing long sequences.

Pros

  • Linear time and memory complexity, enabling efficient processing of long sequences
  • Competitive performance with Transformers on various language tasks
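  • Higher inference throughput than comparably sized Transformers (the Mamba paper reports 5×), and training cost that scales linearly rather than quadratically with sequence length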
  • Flexible architecture that can be adapted to different sequence modeling tasks

Cons

  • Relatively new technology, which may lack extensive community support and resources
  • May require additional optimization for specific use cases
  • Limited pre-trained models available compared to established Transformer architectures
  • Potential learning curve for developers familiar with traditional Transformer models

Code Examples

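  1. Installing Mamba:
pip install mamba-ssm
  2. Importing and initializing a Mamba model (the fused kernels require a supported NVIDIA or AMD GPU):
import torch
from mamba_ssm import Mamba

model = Mamba(
    d_model=768,
    d_state=16,
    d_conv=4,
    expand=2,
).to("cuda")
  3. Processing a sequence with Mamba:
# Create a random input sequence on the GPU
x = torch.randn(1, 1024, 768, device="cuda")  # (batch_size, sequence_length, d_model)

# Process the sequence; the output has the same shape as the input
output = model(x)
print(output.shape)  # torch.Size([1, 1024, 768])
  4. Computing next-token logits with a pretrained Mamba language model:
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# The pretrained Mamba checkpoints use the GPT-NeoX tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device="cuda", dtype=torch.float16)
input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids.to("cuda")
outputs = model(input_ids)  # pass token ids only (no attention_mask)
logits = outputs.logits     # (batch, sequence_length, vocab_size)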

Getting Started

To get started with Mamba:

  1. Install the library:

    pip install mamba-ssm
    
  2. Import and initialize a Mamba model:

    from mamba_ssm import Mamba
    
    model = Mamba(d_model=768, d_state=16, d_conv=4, expand=2).to("cuda")
    
  3. Prepare your input data on the GPU and process it:

    import torch
    
    x = torch.randn(1, 1024, 768, device="cuda")  # Example input: (batch, length, d_model)
    output = model(x)
    
  4. For more advanced usage, refer to the documentation and examples in the GitHub repository.

Competitor Comparisons

Flash-Attention: Fast and memory-efficient exact attention

Pros of Flash-Attention

  • Optimized for efficient attention computation in transformers
  • Supports various attention patterns (e.g., causal, bidirectional)
  • Widely adopted in the ML community for performance improvements

Cons of Flash-Attention

  • Primarily focused on attention mechanisms, not a complete architecture
  • May require more integration effort for full model implementation
  • Requires a supported GPU, although Mamba's fused kernels share this requirement

Code Comparison

Flash-Attention:

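import torch
from flash_attn import flash_attn_func

# q, k, v: (batch, seqlen, nheads, headdim), fp16 or bf16 on the GPU
q, k, v = (torch.randn(1, 1024, 12, 64, device="cuda", dtype=torch.float16) for _ in range(3))
attn_output = flash_attn_func(q, k, v, causal=True)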

Mamba:

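import torch
from mamba_ssm import Mamba

model = Mamba(d_model=768, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(1, 1024, 768, device="cuda")  # hidden states (batch, length, d_model), not token ids
output = model(x)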

Key Differences

  • Flash-Attention focuses on optimizing attention computation, while Mamba introduces a new architecture based on state spaces
  • Mamba aims to replace attention mechanisms entirely, potentially offering better scalability
  • Flash-Attention is more of a drop-in replacement for existing transformer models, whereas Mamba requires a different approach to model design
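To make the distinction concrete, here is a naive pure-Python sketch of exact causal softmax attention for a single head, with the usual 1/sqrt(d) scaling. As far as we understand, this is the quantity `flash_attn_func(q, k, v, causal=True)` returns; FlashAttention changes how it is computed (in tiles, without storing the full L × L score matrix), not the result. The function name `causal_attention` is hypothetical.

```python
import math


def causal_attention(q, k, v):
    """Naive exact causal softmax attention for one head.

    q, k, v: lists of length L, each element a list of d floats.
    Every position scores all earlier positions, so time is O(L^2).
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, qi in enumerate(q):
        # Causal mask: position i only attends to positions 0..i
        scores = [scale * sum(a * b for a, b in zip(qi, k[j])) for j in range(i + 1)]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v[j][c] for j, w in enumerate(weights)) / total
            for c in range(len(v[0]))
        ])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = causal_attention(q, k, v)
print(out[0])  # → [1.0, 2.0]: the first position can only attend to itself
```

The inner loop over all earlier positions is the quadratic cost that Mamba's recurrence is designed to avoid.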

Use Cases

  • Flash-Attention: Improving performance of existing transformer-based models
  • Mamba: Exploring alternative architectures for sequence modeling, potentially surpassing transformers in efficiency and scalability

DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.

Pros of DeepSpeed

  • Extensive optimization techniques for large-scale model training
  • Works with a wide range of PyTorch model architectures rather than a single model family
  • More comprehensive documentation and community support

Cons of DeepSpeed

  • Higher complexity and steeper learning curve
  • May require more setup and configuration for optimal performance

Code Comparison

Mamba:

from mamba_ssm import Mamba

model = Mamba(
    d_model=768,
    d_state=16,
    d_conv=4,
    expand=2
)

DeepSpeed:

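import deepspeed

# Illustrative config values; `model` is any torch.nn.Module (e.g. the Mamba block above)
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)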

Key Differences

  • Mamba focuses on state space models and selective state updates
  • DeepSpeed offers a broader range of optimization techniques
  • Mamba offers a simpler API when all you need is an efficient sequence-modeling layer
  • DeepSpeed provides more flexibility for various model architectures

Use Cases

  • Mamba: Ideal for sequence modeling tasks with long-range dependencies
  • DeepSpeed: Suitable for large-scale model training across different domains

Community and Support

  • Mamba: Growing community, focused on state space models
  • DeepSpeed: Large, established community with extensive resources

xformers: Hackable and optimized Transformers building blocks, supporting a composable construction.

Pros of xformers

  • Broader scope: Offers a wide range of transformer-related components and optimizations
  • Established ecosystem: Part of the Facebook Research suite, with extensive documentation and community support
  • Flexibility: Provides modular building blocks for custom transformer architectures

Cons of xformers

  • Complexity: May have a steeper learning curve due to its extensive feature set
  • Performance: Potentially slower for certain tasks compared to Mamba's state space models
  • Resource usage: Can be more memory-intensive for large-scale transformer models

Code Comparison

xformers:

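from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct

# MultiHeadDispatch wraps an attention mechanism passed in via `attention`
attention = MultiHeadDispatch(
    dim_model=512,
    num_heads=8,
    attention=ScaledDotProduct(dropout=0.1),  # attention dropout is set on the mechanism
    residual_dropout=0.1,
)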

Mamba:

from mamba_ssm import Mamba

model = Mamba(
    d_model=512,
    d_state=16,
    d_conv=4,
    expand=2
)

Both repositories offer powerful tools for building efficient deep learning models, but they focus on different approaches. xformers provides a comprehensive toolkit for transformer-based architectures, while Mamba specializes in state space models, potentially offering better performance for certain tasks with lower computational requirements.

gpt-neox: An implementation of model parallel autoregressive transformers on GPUs, based on the Megatron and DeepSpeed libraries

Pros of gpt-neox

  • Designed for large-scale language model training
  • Extensive documentation and community support
  • Optimized for distributed training across multiple GPUs/nodes

Cons of gpt-neox

  • Higher computational requirements
  • More complex setup and configuration
  • Potentially slower inference for smaller models

Code Comparison

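gpt-neox (the repository itself is configured through YAML files; the same architecture in Hugging Face transformers looks like this):

from transformers import GPTNeoXConfig, GPTNeoXModel
config = GPTNeoXConfig(num_hidden_layers=12, hidden_size=768, num_attention_heads=12)
model = GPTNeoXModel(config)
output = model(input_ids, attention_mask=attention_mask)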

mamba:

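from mamba_ssm import Mamba
model = Mamba(d_model=768, d_state=16, d_conv=4, expand=2).to("cuda")
output = model(x)  # x: float tensor of shape (batch, length, 768) on the GPU, not token ids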

Key Differences

  • gpt-neox focuses on traditional transformer architecture, while mamba introduces state-space models
  • mamba aims for faster inference and lower memory usage
  • gpt-neox provides more flexibility for large-scale training scenarios
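A back-of-the-envelope comparison of per-layer inference memory for one sequence, counted in tensor elements. It assumes Mamba caches a convolution state of shape (expand·d_model, d_conv) and an SSM state of shape (expand·d_model, d_state) per layer, while a Transformer caches one key and one value vector of size d_model per past token; the helper names are hypothetical.

```python
def mamba_state_elems(d_model, d_state=16, d_conv=4, expand=2):
    """Per-layer recurrent state for one sequence (assumed shapes, see above).

    Independent of how many tokens have been generated.
    """
    d_inner = expand * d_model
    return d_inner * d_conv + d_inner * d_state


def kv_cache_elems(d_model, seq_len):
    """Per-layer key/value cache for one sequence: one key and one value per token."""
    return 2 * d_model * seq_len


d_model = 768
print(mamba_state_elems(d_model))  # → 30720, whatever the sequence length
for seq_len in (1_024, 8_192):
    # Grows linearly with the number of cached tokens
    print(seq_len, kv_cache_elems(d_model, seq_len))
```

Keep in mind that a Mamba model uses roughly twice as many layers as a Transformer of similar size, so the per-layer numbers are not a direct model-to-model comparison.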

Use Cases

  • gpt-neox: Large language model training, research on transformer architectures
  • mamba: Efficient sequence modeling, real-time applications, resource-constrained environments

Community and Support

  • gpt-neox: Larger community, more third-party resources
  • mamba: Growing community, active development, newer technology

Pros of t5x

  • Extensive documentation and examples for training and fine-tuning T5 models
  • Built on JAX and Flax, offering efficient and scalable training on TPUs
  • Supports a wide range of NLP tasks and model architectures

Cons of t5x

  • Primarily focused on T5 models, limiting flexibility for other architectures
  • Steeper learning curve due to JAX/Flax ecosystem complexity
  • Less suitable for deployment in production environments

Code Comparison

t5x example:

import jax
from t5x import models
from t5x import trainer as trainer_lib

model = models.EncoderDecoderModel(...)
trainer = trainer_lib.Trainer(model=model, ...)
trainer.train(...)

Mamba example:

import torch
from mamba_ssm import Mamba

model = Mamba(d_model=128, d_state=16, d_conv=4, expand=2).to("cuda")
x = torch.randn(1, 1024, 128, device="cuda")
y = model(x)

Key Differences

  • t5x is tailored for T5 models and NLP tasks, while Mamba focuses on state space models
  • t5x uses JAX/Flax, whereas Mamba is built on PyTorch
  • Mamba offers a simpler API for working with state space models
  • t5x provides more comprehensive tools for training and fine-tuning

README

Mamba

Mamba: Linear-Time Sequence Modeling with Selective State Spaces
Albert Gu*, Tri Dao*
Paper: https://arxiv.org/abs/2312.00752

Mamba-2

Transformers are SSMs: Generalized Models and Efficient Algorithms
Through Structured State Space Duality
Tri Dao*, Albert Gu*
Paper: https://arxiv.org/abs/2405.21060

About

Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers. It is based on the line of progress on structured state space models, with an efficient hardware-aware design and implementation in the spirit of FlashAttention.

Installation

  • [Optional] pip install "causal-conv1d>=1.4.0": an efficient implementation of a simple causal Conv1d layer used inside the Mamba block (quote the requirement so the shell does not treat > as a redirect).
  • pip install mamba-ssm: the core Mamba package.
  • pip install "mamba-ssm[causal-conv1d]": to install the core Mamba package and causal-conv1d.
  • pip install "mamba-ssm[dev]": to install the core Mamba package and dev dependencies.

It can also be built from source with pip install . from this repository.

If pip complains about PyTorch versions, try passing --no-build-isolation to pip.

Other requirements:

  • Linux
  • NVIDIA GPU
  • PyTorch 1.12+
  • CUDA 11.6+

For AMD cards, see additional prerequisites below.

Usage

We expose several levels of interface with the Mamba model.

Selective SSM

Mamba is based on a selective SSM layer, which is the focus of the paper (Section 3; Algorithm 2).

Source: ops/selective_scan_interface.py.
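A slow, sequential pure-Python sketch of what a selective scan computes for one channel. It assumes a diagonal A, per-step Δ_t, B_t and C_t supplied as plain lists (in Mamba they are produced from the input, which is what makes the SSM selective), and the simplified discretization Ā = exp(Δ·A), B̄ ≈ Δ·B. The real code in ops/selective_scan_interface.py is a fused, hardware-aware GPU kernel, and the function name here is hypothetical.

```python
import math


def selective_scan(u, delta, A, B, C, D=0.0):
    """Naive sequential selective scan for a single channel.

    u:     inputs, length L
    delta: per-step step sizes Δ_t > 0, length L
    A:     diagonal state matrix, length N (negative entries mean decay)
    B, C:  per-step input/output projections, L lists of length N each
    D:     skip connection weight
    """
    N = len(A)
    h = [0.0] * N  # fixed-size state: N numbers, regardless of L
    ys = []
    for t in range(len(u)):
        for n in range(N):
            # h_t = exp(Δ_t A) h_{t-1} + Δ_t B_t u_t  (assumed discretization)
            h[n] = math.exp(delta[t] * A[n]) * h[n] + delta[t] * B[t][n] * u[t]
        ys.append(sum(C[t][n] * h[n] for n in range(N)) + D * u[t])
    return ys


u = [1.0, 0.0, 0.0, 0.0]  # a single impulse at t = 0
A = [-1.0, -0.5]
B = [[1.0, 1.0]] * 4
C = [[1.0, 1.0]] * 4
print(selective_scan(u, [0.1] * 4, A, B, C))  # small Δ: the impulse persists
print(selective_scan(u, [2.0] * 4, A, B, C))  # large Δ: the state forgets quickly
```

The loop visits each time step once and keeps only N numbers of state, which is where the linear time and fixed memory come from. Making Δ, B and C depend on the input lets the model choose, per token, how much to remember.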

Mamba Block

The main module of this repository is the Mamba architecture block wrapping the selective SSM.

Source: modules/mamba_simple.py.

Usage:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

Mamba-2

The Mamba-2 block is implemented at modules/mamba2.py.

A simpler version is at modules/mamba2_simple.py.

The usage is similar to Mamba(-1):

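import torch
from mamba_ssm import Mamba2

# d_model * expand must be a multiple of headdim (default 64),
# so the tiny dim=16 from the Mamba example above is too small here
batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape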

SSD

A minimal version of the inner SSD module (Listing 1 from the Mamba-2 paper) with conversion between "discrete" and "continuous" SSM versions is at modules/ssd_minimal.py.
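The duality behind SSD can be checked numerically in a few lines. The sketch below assumes the Mamba-2 restriction of a scalar decay a_t per time step (one head, one channel) and computes the same output twice: once as a linear-time recurrence, and once as a lower-triangular matrix M[t][s] = (C_t · B_s) · a_{s+1} ⋯ a_t applied to the whole sequence, which has the shape of masked attention. It is written from the paper's description, not ported from modules/ssd_minimal.py, and the function names are hypothetical.

```python
def ssd_recurrent(x, a, B, C):
    """h_t = a_t * h_{t-1} + B_t * x_t,  y_t = C_t . h_t  (state h has size N)."""
    N = len(B[0])
    h = [0.0] * N
    ys = []
    for t in range(len(x)):
        h = [a[t] * h[n] + B[t][n] * x[t] for n in range(N)]
        ys.append(sum(C[t][n] * h[n] for n in range(N)))
    return ys


def ssd_quadratic(x, a, B, C):
    """The same map as y = M x with M[t][s] = (C_t . B_s) * a[s+1] * ... * a[t], s <= t."""
    ys = []
    for t in range(len(x)):
        y = 0.0
        for s in range(t + 1):
            decay = 1.0
            for r in range(s + 1, t + 1):
                decay *= a[r]
            y += sum(c * b for c, b in zip(C[t], B[s])) * decay * x[s]
        ys.append(y)
    return ys


x = [1.0, -2.0, 0.5, 3.0]
a = [0.9, 0.5, 0.8, 0.7]  # a[0] never matters because the initial state is zero
B = [[1.0, 0.0], [0.5, 1.0], [0.0, 2.0], [1.0, 1.0]]
C = [[1.0, 1.0], [0.0, 1.0], [2.0, 0.5], [1.0, -1.0]]
print(ssd_recurrent(x, a, B, C))
print(ssd_quadratic(x, a, B, C))  # same values up to floating-point round-off
```

The recurrent form takes O(L) sequential steps, while the matrix form takes O(L²) work but maps onto matrix multiplications; the Mamba-2 paper's SSD algorithm combines the two by working on chunks of the sequence.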

Mamba Language Model

Finally, we provide an example of a complete language model: a deep sequence model backbone (with repeating Mamba blocks) + language model head.

Source: models/mixer_seq_simple.py.

This is an example of how to integrate Mamba into an end-to-end neural network. This example is used in the generation scripts below.

Pretrained Models

Pretrained models are uploaded to Hugging Face: mamba-130m, mamba-370m, mamba-790m, mamba-1.4b, mamba-2.8b, mamba2-130m, mamba2-370m, mamba2-780m, mamba2-1.3b, mamba2-2.7b, transformerpp-2.7b, mamba2attn-2.7b, trained on 300B tokens on the Pile, as well as mamba-2.8b-slimpj (trained on 600B tokens on the SlimPajama dataset).

The models will be autodownloaded by the generation script below.

These models were trained on the Pile, and follow the standard model dimensions described by GPT-3 and followed by many open source models:

Parameters | Layers | Model dim.
130M       | 24     | 768
370M       | 48     | 1024
790M       | 48     | 1536
1.4B       | 48     | 2048
2.8B       | 64     | 2560

(The layer count of Mamba doubles that of a Transformer with similar size, as two Mamba blocks are needed for each "layer" (MHA block + MLP block) of a Transformer.)
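As a rough sanity check on the table, the "3 * expand * d_model^2 parameters" figure from the Mamba block's code comment can be combined with an embedding matrix. This sketch assumes a vocabulary of 50,280 tokens (the GPT-NeoX tokenizer, padded), an embedding shared with the LM head, and ignores the small convolution, SSM and norm parameters; the function name is hypothetical.

```python
def approx_mamba_lm_params(n_layers, d_model, expand=2, vocab_size=50280):
    """Rough parameter count for a Mamba language model (assumptions above)."""
    per_block = 3 * expand * d_model ** 2  # in_proj + out_proj dominate
    embedding = vocab_size * d_model       # assumed tied with the LM head
    return n_layers * per_block + embedding


for name, layers, dim in [("130M", 24, 768), ("370M", 48, 1024), ("790M", 48, 1536),
                          ("1.4B", 48, 2048), ("2.8B", 64, 2560)]:
    print(name, f"{approx_mamba_lm_params(layers, dim) / 1e6:.0f}M")
```

The estimates land somewhat below the nominal sizes, which is consistent with the neglected terms.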

Note: these are base models trained only for 300B tokens, without any form of downstream modification (instruction tuning, etc.). Performance is expected to be comparable to or better than other architectures trained on similar data, but not to match larger or fine-tuned models.

Evaluations

To run zero-shot evaluations of models (corresponding to Table 3 of the paper), we use the lm-evaluation-harness library.

  1. Install lm-evaluation-harness by pip install lm-eval==0.4.2.
  2. Run evaluation with (more documentation at the lm-evaluation-harness repo):
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-130m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
python evals/lm_harness_eval.py --model hf --model_args pretrained=EleutherAI/pythia-160m --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande --device cuda --batch_size 64

To reproduce the results on the mamba-2.8b-slimpj model reported in the blog posts:

lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks boolq,piqa,hellaswag,winogrande,arc_easy,arc_challenge,openbookqa,race,truthfulqa_mc2 --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba-2.8b-slimpj --tasks mmlu --num_fewshot 5 --device cuda --batch_size 256

To run evaluations on Mamba-2 models, simply replace the model names:

lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/transformerpp-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256
lm_eval --model mamba_ssm --model_args pretrained=state-spaces/mamba2attn-2.7b --tasks lambada_openai,hellaswag,piqa,arc_easy,arc_challenge,winogrande,openbookqa --device cuda --batch_size 256

Note that the result of each task might differ from reported values by 0.1-0.3 due to noise in the evaluation process.

Inference

The script benchmarks/benchmark_generation_mamba_simple.py

  1. autoloads a model from the Hugging Face Hub,
  2. generates completions of a user-specified prompt,
  3. benchmarks the inference speed of this generation.

Other configurable options include the top-p (nucleus sampling) probability, and the softmax temperature.
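To illustrate what the temperature, top-p (`--topp`) and min-p (`--minp`, used in one of the examples below) options do, here is a generic sketch of sampling one token from a vector of logits. It is not the repository's implementation (top-k and the repetition penalty are left out), and the function name is hypothetical.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_p=1.0, min_p=0.0, rng=random):
    """Pick a token id from raw logits with temperature, min-p and top-p filtering."""
    # Temperature: divide logits before the softmax (lower = sharper distribution)
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids sorted by probability, highest first
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)

    # min-p: drop tokens below min_p times the top probability
    threshold = min_p * probs[order[0]]
    order = [i for i in order if probs[i] >= threshold]

    # top-p (nucleus): keep the smallest prefix whose mass reaches top_p
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break

    # Sample from the kept tokens, renormalized
    r = rng.random() * sum(probs[i] for i in kept)
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]  # guard against floating-point round-off


rng = random.Random(0)
print(sample_next_token([2.0, 1.0, 0.5, -1.0], temperature=0.7, top_p=0.9, rng=rng))
```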

Examples

To test generation latency (e.g. batch size = 1) with different sampling strategies:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --minp 0.05 --topk 0 --temperature 0.7 --repetition-penalty 1.2

To test generation throughput with random prompts (e.g. large batch size):

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --batch 64
python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --batch 64

With Mamba-2, you just need to change the model name:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2

Troubleshooting

Precision

Our models were trained using PyTorch AMP for mixed precision. AMP keeps model parameters in float32 and casts to half precision when necessary. On the other hand, other frameworks like DeepSpeed store parameters in float16 and upcast when necessary (e.g. for optimizer accumulation).

We've observed that higher precision for the main model parameters may be necessary, because SSMs are sensitive to their recurrent dynamics. If you are experiencing instabilities, as a first step please try a framework storing parameters in fp32 (such as AMP).
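A toy illustration of why recurrent dynamics are sensitive to parameter precision: a slow per-step decay of 0.9999 cannot be represented in float16 and rounds to exactly 1.0, so a state that should fade after many steps never decays. This is not a reproduction of any training run; Python floats (float64) stand in for fp32, and `to_fp16` is a hypothetical helper built on the standard `struct` half-precision format.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


decay = 0.9999  # slow per-step decay, as in a long-memory recurrent state
steps = 10_000

h_full = 1.0
h_half = 1.0
decay_half = to_fp16(decay)  # the parameter as stored in fp16
for _ in range(steps):
    h_full *= decay
    h_half = to_fp16(h_half * decay_half)

print(decay_half)  # 1.0: the decay is rounded away entirely
print(h_full)      # about 0.37
print(h_half)      # 1.0: the state never decays
```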

Initialization

Some parts of the model have initializations inherited from prior work on S4 models. For example, the $\Delta$ parameter has a targeted range by initializing the bias of its linear projection. However, some frameworks may have post-initialization hooks (e.g. setting all bias terms in nn.Linear modules to zero). If this is the case, you may have to add custom logic (e.g. this line turns off re-initializing in our trainer, but would be a no-op in any other framework) that is specific to the training framework.
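A sketch of the kind of bias initialization described above, assuming (as in the Mamba paper) that Δ is obtained by applying softplus to the output of its linear projection. Δ is sampled log-uniformly and the bias is set to the inverse softplus of that value; the range [0.001, 0.1] and the floor 1e-4 are assumed values, and `init_dt_bias` is a hypothetical name. The last line shows what a generic "zero all biases" hook would do to Δ.

```python
import math
import random


def init_dt_bias(n, dt_min=0.001, dt_max=0.1, floor=1e-4, rng=random):
    """Sample Δ log-uniformly in [dt_min, dt_max]; return biases b with softplus(b) = Δ."""
    biases = []
    for _ in range(n):
        dt = math.exp(rng.uniform(math.log(dt_min), math.log(dt_max)))
        dt = max(dt, floor)
        # Inverse of softplus(b) = log(1 + exp(b)):  b = dt + log(1 - exp(-dt))
        biases.append(dt + math.log(-math.expm1(-dt)))
    return biases


def softplus(b):
    return math.log1p(math.exp(b))


biases = init_dt_bias(8, rng=random.Random(0))
print([round(softplus(b), 4) for b in biases])  # every Δ falls in [0.001, 0.1]

# A hook that zeroes every Linear bias would set Δ = softplus(0) = log 2, about 0.693
print(softplus(0.0))
```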

Additional Prerequisites for AMD cards

Patching ROCm

If you are on ROCm 6.0, run the following steps to avoid errors during compilation. This is not required for ROCm 6.1 onwards.

  1. Locate your ROCm installation directory. This is typically found at /opt/rocm/, but may vary depending on your installation.

  2. Apply the Patch. Run with sudo in case you encounter permission issues.

     patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h < rocm_patch/rocm6_0.patch 
    

Citation

If you use this codebase, or otherwise find our work valuable, please cite Mamba:

@article{mamba,
  title={Mamba: Linear-Time Sequence Modeling with Selective State Spaces},
  author={Gu, Albert and Dao, Tri},
  journal={arXiv preprint arXiv:2312.00752},
  year={2023}
}

@inproceedings{mamba2,
  title={Transformers are {SSM}s: Generalized Models and Efficient Algorithms Through Structured State Space Duality},
  author={Dao, Tri and Gu, Albert},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2024}
}