---
license: apache-2.0
---
# DashAttention
Differentiable and Adaptive Sparse Hierarchical Attention
## Installation
To use the DashAttention kernels and run the example, install the package in editable mode from the repository root:
```bash
pip install -e .
```
For benchmark environment setup, refer to the corresponding folder under `benchmarks/`.
## Usage
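The DashAttention interface can be used as follows:
```python
import torch
from dash_attn import dash_attn

# Illustrative sizes; query_heads must be a multiple of kv_heads when enable_gqa=True.
batch, query_heads, kv_heads, seq_len, head_dim = 1, 32, 8, 4096, 128
chunk_size = 64
device, dtype = "cuda", torch.bfloat16  # CUDA is required; bf16 is an illustrative choice
queries = torch.randn(batch, query_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
keys = torch.randn(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
values = torch.randn(batch, kv_heads, seq_len, head_dim, device=device, dtype=dtype).contiguous()
# One vector per KV head: shape [kv_heads, head_dim].
head_cls = torch.randn(kv_heads, head_dim, device=device, dtype=dtype).contiguous()
model = dash_attn(
    chunk_size=chunk_size,
    enable_gqa=True,
    estimate_diagonal=True,
    return_active_blocks=True,
)
# active_blocks: number of active routed blocks per token.
out, active_blocks = model(queries, keys, values, head_cls)
```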
We also provide an [example](./example/run_niah.py) of using DashAttention in Llama-architecture models:
```bash
python ./example/run_niah.py
```
## Documentation
DashAttention implements the attention mechanism introduced in [DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention](https://arxiv.org/abs/2605.18753). The method replaces fixed-budget top-k block routing with an adaptive, differentiable sparse router, then refines the selected regions with token-level softmax attention.
### How it works
The implementation follows the three-stage hierarchy described in the paper:
1. **Local chunk summarization**: `dash_attn.prefill.summarize_chunk` and `dash_attn.decoding.summarize_chunk` build one learned key summary per KV chunk.
2. **Entmax block routing**: `score_blocks` computes sparse chunk supports and routing priors from query-to-summary scores.
3. **Prior-induced sparse softmax**: `full_attn` applies token-level attention only over routed chunks, using the routing prior from step 2 to preserve differentiability through the hierarchy.
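The three steps can be illustrated with a toy, pure-PyTorch reference for a single query and a single head. This is a sketch under stated assumptions, not the repository's kernel: `summarize_chunks`, `sparsemax`, and `dash_decode_step` are hypothetical names; `head_cls` is assumed to act as a learned pooling query inside each chunk; sparsemax (entmax with α = 2) stands in for the entmax router; and the prior is assumed to enter the token-level softmax as an additive `sigma * log(prior)` bias.

```python
import torch


def summarize_chunks(keys, head_cls, chunk_size):
    """Stage 1 sketch: one summary per *complete* chunk of keys [seq_len, head_dim].

    Assumption: head_cls [head_dim] is a learned pooling query over each chunk.
    """
    n_full, d = keys.shape[0] // chunk_size, keys.shape[1]
    k = keys[: n_full * chunk_size].reshape(n_full, chunk_size, d)
    pool = torch.softmax(k @ head_cls, dim=-1)  # [n_full, chunk_size]
    return (pool.unsqueeze(-1) * k).sum(dim=1)  # [n_full, head_dim]


def sparsemax(z):
    """Sparse probability vector (entmax with alpha=2), a stand-in for entmax."""
    z_sorted, _ = torch.sort(z, descending=True)
    k = torch.arange(1, z.numel() + 1, dtype=z.dtype, device=z.device)
    cumsum = z_sorted.cumsum(0)
    k_z = int((1 + k * z_sorted > cumsum).sum())  # size of the support
    tau = (cumsum[k_z - 1] - 1) / k_z
    return torch.clamp(z - tau, min=0)


def dash_decode_step(q, keys, values, head_cls, chunk_size,
                     scaling_factor=1.0, sigma=1.0):
    """Toy single-query, single-head version of the three stages (not the kernel)."""
    d = q.shape[-1]
    n_full = keys.shape[0] // chunk_size
    # Stage 1: summarize complete chunks.
    summaries = summarize_chunks(keys, head_cls, chunk_size)
    # Stage 2: sparse routing prior over complete chunks; many entries are exactly 0.
    if n_full > 0:
        prior = sparsemax(scaling_factor * (summaries @ q) / d ** 0.5)
    else:
        prior = q.new_zeros(0)
    # The trailing partial chunk is always attended (prior 1), loosely mirroring
    # the special handling of the current/diagonal chunk.
    tail = keys.shape[0] - n_full * chunk_size
    token_prior = torch.cat([prior.repeat_interleave(chunk_size), q.new_ones(tail)])
    # Stage 3: token-level softmax over routed tokens only, biased by sigma * log(prior).
    scores = (keys @ q) / d ** 0.5 + sigma * torch.log(token_prior.clamp_min(1e-30))
    scores = scores.masked_fill(token_prior == 0, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ values, weights, prior


torch.manual_seed(0)
q, head_cls = torch.randn(16), torch.randn(16)
keys, values = torch.randn(200, 16), torch.randn(200, 16)
out, weights, prior = dash_decode_step(q, keys, values, head_cls, chunk_size=32, scaling_factor=4.0)
print(prior)  # routing prior over the 6 complete chunks; zeros are skipped chunks
```

Skipped chunks receive exactly zero attention weight, while for routed chunks gradients flow from the token-level softmax back into the router through `log(prior)`. The sketch ignores batching, multiple heads, GQA, and causal masking across many queries.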
The public kernel wrapper is [`dash_attn.dash_attn_interface.dash_attn`](./dash_attn/dash_attn_interface.py). It supports both prefill and decoding: prefill summarizes the current sequence and stores complete chunk summaries, while decoding reuses the chunk-summary cache and appends newly completed chunks.
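The prefill/decoding split can be pictured as a cache that only stores summaries of complete chunks. The class below is a hypothetical illustration, not the repository's cache: the name `ChunkSummaryCache`, the `summarize` callable, and the `pending` buffer holding the trailing partial chunk are assumptions, while the fixed capacity mirrors the `max_chunks` argument.

```python
import torch


class ChunkSummaryCache:
    """Hypothetical chunk-summary cache (illustration only)."""

    def __init__(self, max_chunks, chunk_size, head_dim, summarize, dtype=torch.float32):
        # Preallocate room for max_chunks summaries, like a fixed-capacity cache.
        self.buf = torch.zeros(max_chunks, head_dim, dtype=dtype)
        self.n = 0  # number of complete chunks summarized so far
        self.chunk_size = chunk_size
        self.summarize = summarize  # maps [chunk_size, head_dim] -> [head_dim]
        # Keys of the trailing, not-yet-complete chunk.
        self.pending = torch.zeros(0, head_dim, dtype=dtype)

    def append(self, new_keys):
        """Prefill passes the whole prompt; decoding passes one token at a time."""
        self.pending = torch.cat([self.pending, new_keys])
        # Summarize every chunk that has just become complete.
        while self.pending.shape[0] >= self.chunk_size:
            if self.n == self.buf.shape[0]:
                raise RuntimeError("chunk-summary cache is full (max_chunks reached)")
            self.buf[self.n] = self.summarize(self.pending[: self.chunk_size])
            self.n += 1
            self.pending = self.pending[self.chunk_size :]

    @property
    def summaries(self):
        """Summaries of all complete chunks, shape [n, head_dim]."""
        return self.buf[: self.n]


# Usage: mean pooling stands in for the learned summary.
cache = ChunkSummaryCache(max_chunks=8, chunk_size=4, head_dim=2, summarize=lambda k: k.mean(dim=0))
cache.append(torch.arange(20, dtype=torch.float32).reshape(10, 2))  # prefill: 10 tokens
print(cache.summaries.shape, cache.pending.shape)  # 2 complete chunks, 2 pending tokens
```

Because only complete chunks are summarized, a stored summary never changes once written, so each chunk is summarized exactly once during decoding.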
### Core API
```python
from dash_attn import dash_attn

attn = dash_attn(
    chunk_size=64,
    enable_gqa=True,
    estimate_diagonal=True,
    scaling_factor=1.0,
    return_active_blocks=False,
)
```
Important arguments:
| Argument | Description |
|:-|:-|
| `chunk_size` | Number of tokens per routed KV chunk. |
| `enable_gqa` | Enables grouped-query attention support when query heads outnumber KV heads. |
| `estimate_diagonal` | Includes special handling for the current or near-diagonal chunk. |
| `scaling_factor` | Scales routing logits before sparse block selection; this is the main knob for sparsity. |
| `return_active_blocks` | Returns the number of active routed blocks per token for sparsity analysis. |
| `max_chunks` | Preallocated chunk-summary cache capacity used during decoding. |
| `sigma` | Controls the strength of the routing prior applied by the token-level sparse softmax. |
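To see why a single logit scale can act as the sparsity knob, the sketch below applies sparsemax (entmax with α = 2, used here as a stand-in for the repository's entmax router) to the same made-up routing scores at several scales and counts the chunks that remain routed. It assumes `scaling_factor` multiplies the routing logits.

```python
import torch


def sparsemax(z):
    """Euclidean projection of a 1-D score vector onto the probability simplex."""
    z_sorted, _ = torch.sort(z, descending=True)
    k = torch.arange(1, z.numel() + 1, dtype=z.dtype)
    cumsum = z_sorted.cumsum(0)
    # The support is the largest prefix k with 1 + k * z_(k) > sum_{j<=k} z_(j).
    k_z = int((1 + k * z_sorted > cumsum).sum())
    tau = (cumsum[k_z - 1] - 1) / k_z
    return torch.clamp(z - tau, min=0)


# Hypothetical query-to-summary scores for five chunks.
scores = torch.tensor([1.0, 0.8, 0.5, 0.1, -0.3])
for scale in (0.1, 1.0, 3.0, 10.0):
    active = int((sparsemax(scale * scores) > 0).sum())
    print(f"scaling_factor={scale}: {active} of {scores.numel()} chunks routed")
```

This prints 5, 3, 2, and 1 routed chunks respectively: larger scales concentrate the routing distribution on fewer chunks.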
Inputs are expected in `[batch, heads, seq_len, head_dim]` layout for `queries`, `keys`, and `values`; `head_cls` has shape `[kv_heads, head_dim]`.
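A quick pre-flight check of this layout can catch shape mistakes before launching a kernel. `check_dash_inputs` is a hypothetical helper, not part of `dash_attn`; it assumes that with `enable_gqa=True` the number of query heads must be a multiple of the number of KV heads, and it leaves sequence lengths unconstrained so that a single decoding query can attend to a longer KV cache.

```python
import torch


def check_dash_inputs(queries, keys, values, head_cls, enable_gqa=True):
    """Hypothetical shape check; returns how many query heads share each KV head."""
    for name, t in (("queries", queries), ("keys", keys), ("values", values)):
        if t.dim() != 4:
            raise ValueError(f"{name} must be [batch, heads, seq_len, head_dim], got {tuple(t.shape)}")
    if keys.shape != values.shape:
        raise ValueError("keys and values must have the same shape")
    batch, q_heads, _, head_dim = queries.shape
    kv_batch, kv_heads, _, kv_dim = keys.shape
    if batch != kv_batch or head_dim != kv_dim:
        raise ValueError("batch size and head_dim must match across queries and keys")
    if tuple(head_cls.shape) != (kv_heads, head_dim):
        raise ValueError(f"head_cls must be [kv_heads, head_dim] = [{kv_heads}, {head_dim}]")
    if q_heads % kv_heads != 0 or (q_heads != kv_heads and not enable_gqa):
        raise ValueError("query heads must equal KV heads, or be a multiple of them with enable_gqa=True")
    return q_heads // kv_heads


# Only shapes matter here, so CPU tensors are fine: 8 query heads share 2 KV heads.
q = torch.empty(2, 8, 1, 64)     # one decoding query per head
kv = torch.empty(2, 2, 256, 64)  # 256 cached tokens
print(check_dash_inputs(q, kv, kv, torch.empty(2, 64)))  # 4
```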
### Llama integration
DashAttention includes a Llama-compatible modeling implementation in [`dash_attn.models.llama`](./dash_attn/models/llama). `LlamaConfig` defaults to `attn_implementation="dash_attn"` and adds DashAttention-specific fields such as `chunk_size`, `estimate_diagonal`, `sigma`, and `scaling_factor`.
```python
from dash_attn.models.llama import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "fasa-org/MiniCPM-4-8B-DashAttention",
    attn_implementation="dash_attn",
    torch_dtype="auto",
)
```
To inspect routing behavior, call the model with `return_active_blocks=True`, then read `model.get_active_blocks()`.
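Per-token active-block counts can be reduced to a single sparsity number. The helper below is a hypothetical sketch: it assumes the counts form an integer tensor whose last dimension indexes token positions, and that a causal token at position `t` can attend to at most `t // chunk_size + 1` chunks. The exact format returned by `model.get_active_blocks()` may differ, so adapt the reduction accordingly.

```python
import torch


def routing_sparsity(active_blocks, chunk_size):
    """Fraction of causally available chunks that routing skipped, averaged over tokens."""
    seq_len = active_blocks.shape[-1]
    positions = torch.arange(seq_len, device=active_blocks.device)
    # Chunks a causal query at position t could see: its own chunk plus all earlier ones.
    available = positions // chunk_size + 1
    visited = active_blocks.float() / available  # broadcasts over leading dims
    return 1.0 - visited.mean().item()


# Toy example: 8 tokens with chunk_size=4, so tokens 4-7 can see 2 chunks.
counts = torch.tensor([1, 1, 1, 1, 1, 1, 2, 1])
print(routing_sparsity(counts, chunk_size=4))  # 0.1875
```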
### Examples and tests
- [`example/run_niah.py`](./example/run_niah.py) runs a needle-in-a-haystack style generation example and reports measured sparsity.
- [`test/test_smoke.py`](./test/test_smoke.py) checks the standalone DashAttention kernel wrapper.
- [`test/test_llama_dash_attn.py`](./test/test_llama_dash_attn.py) checks the Llama integration and active-block reporting.
Run the test suite with:
```bash
pytest
```
The current kernels require CUDA-capable hardware.
## Models
We release our 8B models for reproducibility.
| Model | Link |
|:-:|:-:|
| 8B-FullAttn | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-FullAttn) |
| 8B-InfLLMv2 | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-InfLLMv2) |
| 8B-NSA | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-NSA) |
| 8B-DashAttention | [Hugging Face](https://huggingface.co/fasa-org/MiniCPM-4-8B-DashAttention) |
The base models we use are [MiniCPM4-1B-Base](https://modelscope.cn/models/OpenBMB/MiniCPM4-1B-Base), [MiniCPM4-3B-Base](https://modelscope.cn/models/OpenBMB/MiniCPM4-3B-Base), and [MiniCPM4-8B-Base](https://modelscope.cn/models/OpenBMB/MiniCPM4-8B-Base).
## Benchmarks
- Performance: Please refer to [README](./benchmarks/performance/README.md).
## License
This project is released under the [BSD-3-Clause License](./LICENSE).
## Acknowledgement
This repository is developed with the aid of [RULER](https://github.com/NVIDIA/RULER), [OLMES](https://github.com/allenai/olmes), [InfLLMv2](https://github.com/OpenBMB/infllmv2_cuda_impl), and [NSA-triton](https://github.com/XunhaoLai/native-sparse-attention-triton).
## Citation
```bibtex
@article{dash-attention,
  title={DashAttention: Differentiable and Adaptive Sparse Hierarchical Attention},
  author={Huang, Yuxiang and Gon{\c{c}}alves, Nuno M. T. and Alvetreti, Federico and Li, Lei and Han, Xu and Ponti, Edoardo M. and Martins, Andr{\'e} F. T. and Treviso, Marcos V.},
  journal={arXiv preprint arXiv:2605.18753},
  year={2026}
}
```