FlashMLA,deepseek发布的为Hopper GPU优化的MLA解码内核,专为可变长度序列进行了优化
支持BF16格式
带有64块大小的分页KV缓存
在 H800 上实现:
内存受限情况下:3000 GB/s
计算受限情况下:580 TFLOPS
在AI服务部署时,适合需要快速响应用户请求的场景,对于要处理大量文本的应用特别有用
现已在生产环境中使用
当前发布:
python setup.py install
python tests/test_flash_mla.py
使用 CUDA 12.6,在 H800 SXM5 上,在内存绑定配置下实现高达 3000 GB/s,在计算绑定配置下实现 580 TFLOPS。
from flash_mla import get_mla_metadata, flash_mla_with_kvcache
tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)
for i in range(num_layers):
...
o_i, lse_i = flash_mla_with_kvcache(
q_i, kvcache_i, block_table, cache_seqlens, dv,
tile_scheduler_metadata, num_splits, causal=True,
)
...