FlashAttention-2, distributed as the flash-attn Python package, can significantly improve the performance of attention in transformer-based models running on NVIDIA hardware[1], in some cases a 7-10x speedup over a standard attention implementation. It gets there through clever memory optimizations, chiefly tiling the computation so that more of the attention work runs in fast on-chip SRAM instead of constantly reading from and writing back to the GPU's main memory (HBM).
Installing it is often a beast. There are prebuilt wheels, both from the official source and from community builds that offer a more comprehensive collection, but each build has to target a specific combination of flash-attn version, GPU hardware generation (e.g. Blackwell), CUDA version, Torch version, and platform (e.g. x86-64, ARM64). The community builds have some 384 different variations, and even then the one you need often isn't among them.
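Before hunting for a prebuilt wheel, it helps to know exactly which combination you need. The following is a minimal sketch, assuming `python3` is the interpreter you will install into (torch is optional): it prints the Python version and platform and, when torch is present, torch's version, the CUDA version torch was built against, and the GPU's compute capability. The function name `wheel_fingerprint` is hypothetical.

```bash
#!/usr/bin/env bash
# Print the attributes a prebuilt flash-attn wheel has to match.
# Assumption: python3 is the interpreter you will install into; torch is optional.
wheel_fingerprint() {
  python3 - <<'EOF'
import platform
import sys

print(f"python   : {sys.version_info.major}.{sys.version_info.minor}")
print(f"platform : {platform.system()} {platform.machine()}")
try:
    import torch
except ImportError:
    print("torch    : not installed")
else:
    print(f"torch    : {torch.__version__}")
    # CUDA version torch was built against (None for CPU-only builds)
    print(f"cuda     : {torch.version.cuda}")
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        print(f"gpu arch : {major}.{minor}")
    else:
        print("gpu arch : no CUDA device visible")
EOF
}

wheel_fingerprint
```

The `gpu arch` line uses the same major.minor form as `TORCH_CUDA_ARCH_LIST` (for example 12.0), so it also tells you what to target if you end up building.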
So most of the time you end up building the wheel on your own hardware. This is where a lot of people run into trouble, abandon the effort, and just go flash-attn-less.
It's a resource-intensive build. The ninja build system will use all available cores, with every parallel job consuming enormous resources, which you can restrain by setting MAX_JOBS (read by PyTorch's extension builder and passed on to ninja) to some low number, maybe even 1. Even so, you will often find the build consuming enormous amounts of memory and grinding into paging hell: while ninja now runs only one job at a time, each job calls out to NVCC, which has its own horizontal-scaling tendencies and spawns many copies of CICC (CUDA Internal C/C++ Compiler). The flash-attn kernels are built on NVIDIA's heavily templated CUTLASS library, and every CICC instance compiling them will consume 5 to 8GB, even when doing seemingly marginal work.
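To watch the CICC fan-out for yourself, a small sketch like the one below can be run in a second terminal during a build. It assumes Linux with /proc mounted and that the compiler processes are named `cicc`; the function name `cicc_usage` is hypothetical, and resident memory (VmRSS) is only an approximation of a process's real footprint.

```bash
#!/usr/bin/env bash
# Count running cicc processes and sum their resident memory (Linux /proc only).
cicc_usage() {
  local n=0 kb=0 dir comm rss
  for dir in /proc/[0-9]*; do
    # Processes can exit mid-scan, so skip any we can no longer read.
    comm=$(cat "$dir/comm" 2>/dev/null) || continue
    [ "$comm" = "cicc" ] || continue
    # VmRSS is reported in KiB; treat an unreadable status file as 0.
    rss=$(awk '/^VmRSS:/ { print $2 }' "$dir/status" 2>/dev/null) || rss=0
    n=$((n + 1))
    kb=$((kb + ${rss:-0}))
  done
  awk -v n="$n" -v kb="$kb" \
    'BEGIN { printf "%d cicc processes, %.1f GiB resident\n", n, kb / 1048576 }'
}

cicc_usage
```

For a live view during the build, wrap it in a loop such as `while sleep 5; do cicc_usage; done`.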
You rein in CICC's insanity with NVCC_THREADS, which flash-attn's build passes to NVCC's --threads option to stop it going wild with its own process spawning.
```bash
# 1. Restrict the build system to 1 file at a time
export MAX_JOBS=1
# 2. Restrict the CUDA compiler to 1 frontend process at a time
export NVCC_THREADS=1
# 3. Target Blackwell (12.0 = consumer cards such as the RTX 50 series;
#    data-center Blackwell such as B200 is 10.0). Change for your target arch.
export TORCH_CUDA_ARCH_LIST="12.0"
# 4. Build without isolation, so the build uses the torch you already have installed
pip install flash-attn --no-build-isolation --force-reinstall
```
With both thread-scaling settings at 1, the build will take hours upon hours, but it won't churn through your flash storage's endurance with endless paging, and it will actually succeed. Higher values are viable on well-equipped systems.
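Higher values can be picked with a little arithmetic. The sketch below is a rough heuristic, not an official formula: it assumes peak memory scales as MAX_JOBS × NVCC_THREADS × the per-CICC footprint, takes 8GB per CICC from the upper end of the 5 to 8GB range, and reads available memory from Linux's /proc/meminfo. The names `suggest_max_jobs` and `GIB_PER_CICC` are hypothetical.

```bash
#!/usr/bin/env bash
# Rough heuristic (an assumption, not an official formula):
# peak RAM ~ MAX_JOBS x NVCC_THREADS x per-CICC footprint.
GIB_PER_CICC=8  # assumption: upper end of the 5-8GB per CICC range

# suggest_max_jobs AVAILABLE_GIB NVCC_THREADS -> prints a MAX_JOBS value >= 1
suggest_max_jobs() {
  local avail=${1:-0} threads=${2:-1} n_jobs
  if [ "$threads" -lt 1 ]; then threads=1; fi
  n_jobs=$(( avail / (threads * GIB_PER_CICC) ))
  if [ "$n_jobs" -lt 1 ]; then n_jobs=1; fi
  echo "$n_jobs"
}

# MemAvailable is reported in KiB; this read is Linux-only.
avail_gib=$(awk '/^MemAvailable:/ { print int($2 / 1048576) }' /proc/meminfo 2>/dev/null) || avail_gib=0
NVCC_THREADS=${NVCC_THREADS:-1}
MAX_JOBS=$(suggest_max_jobs "${avail_gib:-0}" "$NVCC_THREADS")
export NVCC_THREADS MAX_JOBS
echo "MemAvailable ~${avail_gib:-0} GiB -> MAX_JOBS=$MAX_JOBS NVCC_THREADS=$NVCC_THREADS"
```

Source it (`. ./script.sh`) so the exported values reach your current shell before the `pip install` step. Under this model only the product matters, so raising NVCC_THREADS and lowering MAX_JOBS is an equally valid split.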
Just an FYI: I've hit this on just about every new install and finally decided to write it up. NVCC_THREADS is the setting that seldom gets mentioned, yet it is by far the most effective way to rein in the outrageous memory needs of installing this package.
Footnotes
[1] There are ports for AMD, Intel, and Apple Silicon, but the main use is on NVIDIA hardware.