Reproducibility in large language model (LLM) training is a growing concern, and researchers are actively seeking solutions that don’t come at the cost of performance. A team from Shanghai Jiao Tong University and ByteDance Seed has introduced DASH (Deterministic Attention Scheduling for High-Throughput), a new approach designed to address the performance penalty often associated with deterministic attention mechanisms. Their work, detailed in a recently published paper, demonstrates that DASH can improve throughput by up to 1.28x while maintaining the crucial benefit of bitwise-reproducible results.
The challenge stems from the fact that deterministic attention, necessary for ensuring bitwise identical results across training runs, can substantially reduce training speed. Existing attention implementations, such as FlashAttention-3, can suffer a throughput reduction of up to 37.9% when deterministic backward passes are enabled. The slowdown arises because floating-point addition is not associative: to guarantee numerically consistent results, gradient accumulation operations must be serialized in a fixed order, which leaves hardware underutilized.
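To see why a fixed order matters, consider that summing the same partial gradients in two different orders can change the low-order bits of the result. The toy Python snippet below (an illustration, not code from the paper) makes this concrete:

```python
import random

# Floating-point addition is not associative: accumulating the same
# partial values in a different order can change the low-order bits.
random.seed(0)
partials = [random.uniform(-1.0, 1.0) for _ in range(10_000)]

forward = 0.0
for x in partials:
    forward += x          # accumulate in index order 0, 1, 2, ...

reverse = 0.0
for x in reversed(partials):
    reverse += x          # accumulate in the opposite order

print(forward, reverse)
print("bitwise identical:", forward == reverse)  # typically False
```

On a GPU, atomic adds from concurrent thread blocks land in a different order on every run, so a deterministic backward pass must instead pin down the accumulation order ahead of time – and that is exactly the serialization that costs throughput.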
Formulating Attention as a Scheduling Problem
The researchers tackled this problem by reframing the deterministic attention backward pass as a scheduling problem on a Directed Acyclic Graph (DAG). This allowed them to derive schedules specifically designed to minimize the critical path length – the longest chain of dependent operations, which lower-bounds the overall execution time. By optimizing the order in which computations are performed, DASH aims to reduce bottlenecks and maximize hardware efficiency.
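A minimal sketch of what such a formulation looks like, with hypothetical node names and unit costs rather than the paper’s exact graph: nodes are tile computations and serialized accumulation steps, edges are "must finish first" dependencies, and no schedule can beat the longest dependency chain.

```python
from functools import lru_cache

# Toy DAG (hypothetical node names): deps[node] lists the operations
# that must finish before `node` can start. The accumulation steps
# form a serialized chain, which is what stretches the critical path.
deps = {
    "compute_t0": [],
    "compute_t1": [],
    "accum_t0": ["compute_t0"],
    "accum_t1": ["compute_t1", "accum_t0"],  # serialized accumulation
}
cost = {node: 1 for node in deps}            # unit cost per operation

@lru_cache(maxsize=None)
def longest_path_to(node: str) -> int:
    # Critical path ending at `node`: its own cost plus the longest
    # path among its prerequisites (0 if it has none).
    return cost[node] + max(
        (longest_path_to(d) for d in deps[node]), default=0)

print(max(longest_path_to(n) for n in deps))  # critical path length: 3
```

Any valid schedule needs at least this many steps, so reordering tiles to shorten the critical path directly raises the achievable throughput.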
DASH incorporates two complementary scheduling strategies. The first, Descending Q-Tile Iteration, focuses on causal attention mechanisms. It employs a reversed query-block traversal to reduce pipeline stalls, optimizing data flow and minimizing dependencies. The second, Shift Scheduling, is a theoretically optimal schedule within the DAG model, designed to reduce pipeline stalls for both full and causal attention masks. This approach addresses the core issue of misalignment between tile execution and accumulation ordering, a key contributor to performance degradation.
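The sketch below shows what these two orderings might look like in schedule form. It is a simplified reconstruction based on the descriptions above – the function names and flat tile indexing are assumptions, not the paper’s CUDA/Triton kernels:

```python
def descending_q_tiles(num_q_tiles: int) -> list[int]:
    # Descending Q-Tile Iteration: traverse query blocks in reverse.
    # Under a causal mask, later query tiles attend to more key tiles,
    # so launching them first overlaps their longer accumulation chains
    # with the shorter ones instead of stalling behind them.
    return list(range(num_q_tiles - 1, -1, -1))

def shift_schedule(num_ctas: int, num_tiles: int) -> dict[int, list[int]]:
    # Shift Scheduling: CTA c visits tile (c + s) % num_tiles at step s.
    # At every step the CTAs touch distinct accumulators, so the fixed
    # (deterministic) per-tile accumulation order never forces two CTAs
    # to wait on the same tile at the same time.
    return {c: [(c + s) % num_tiles for s in range(num_tiles)]
            for c in range(num_ctas)}
```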
Performance Gains on NVIDIA H800 GPUs
Empirical evaluations conducted on NVIDIA H800 GPUs demonstrated the effectiveness of DASH. The team’s experiments, using CUDA 12.6 and Triton 3.4, showed that DASH improved the throughput of the attention backward pass by up to 1.28x over the baseline deterministic FlashAttention-3 implementation. Benchmarks used a fixed total of 16,384 tokens with sequence lengths varying from 512 to 16,384, a hidden dimension of 2,048, head dimensions of 64 and 128, and random inputs in BF16 precision.
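Those shapes are easy to reproduce. The sketch below (runnable on a CUDA-capable machine) only recreates the benchmark inputs and drives a reference PyTorch backward pass – the DASH kernels themselves live in the linked repository – so treat it as an illustration of the configuration rather than the paper’s actual harness:

```python
import itertools
import torch
import torch.nn.functional as F

TOTAL_TOKENS = 16_384   # fixed token budget across all configurations
HIDDEN = 2_048

for seq_len, head_dim in itertools.product(
        [512, 1024, 2048, 4096, 8192, 16_384], [64, 128]):
    batch = TOTAL_TOKENS // seq_len       # shrink batch as seq_len grows
    heads = HIDDEN // head_dim
    q, k, v = (torch.randn(batch, heads, seq_len, head_dim,
                           dtype=torch.bfloat16, device="cuda",
                           requires_grad=True) for _ in range(3))
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out.sum().backward()                  # exercise the attention backward
```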
The research identified that inter-SM (Streaming Multiprocessor) communication latency, particularly accesses to remote L2 cache segments (ranging from 200 to over 500 cycles), can become a limiting factor at extreme parallelism. While Shift Scheduling offered computational benefits, it proved more susceptible to this communication overhead in some scenarios. For causal attention masks, however, both Descending Q-Tile Iteration and Symmetric Shift Scheduling – a variant of Shift Scheduling – consistently improved throughput, with Symmetric Shift Scheduling demonstrating superior workload balancing at a head dimension of 64.
Parallelizing Reduction and Optimizing Tile Execution
The core innovation of DASH lies in its ability to parallelize the reduction process. Unlike naive scheduling approaches that force reductions to proceed strictly sequentially, DASH allows CTAs (Cooperative Thread Arrays) to begin reduction on different tiles concurrently, thereby avoiding bottlenecks. This is achieved by carefully aligning tile execution with the accumulation order; the misalignment between the two was previously identified as a primary cause of performance loss.
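A toy step simulation (again an illustrative construction, not the paper’s DAG formalism) shows the effect. With a naive schedule every CTA starts on the same tile and the reductions pipeline one behind another; with a shifted schedule every accumulator is busy from the first step:

```python
def makespan(schedules):
    # schedules[c] is CTA c's ordered tile list; each accumulator (tile)
    # accepts one addition per step, with ties resolved by CTA index so
    # the per-tile accumulation order is fixed and results reproducible.
    pos, steps = [0] * len(schedules), 0
    while any(p < len(s) for p, s in zip(pos, schedules)):
        claimed = set()
        for c, s in enumerate(schedules):
            if pos[c] < len(s) and s[pos[c]] not in claimed:
                claimed.add(s[pos[c]])    # this CTA reduces its tile now
                pos[c] += 1
        steps += 1
    return steps

T = C = 8
naive = [list(range(T)) for _ in range(C)]                  # all CTAs: 0,1,2,...
shift = [[(c + s) % T for s in range(T)] for c in range(C)]
print(makespan(naive))   # 15 steps: reductions pipeline serially
print(makespan(shift))   # 8 steps: every accumulator busy each step
```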
The researchers emphasize that the performance gap isn’t inherent to the serialization process itself, but rather a consequence of suboptimal tile scheduling and a rigid accumulation order. By modeling the deterministic backward pass as a DAG, they were able to design strategies that minimize the critical path length, ensuring a more balanced workload and reducing contention during serial reduction operations.
Open-Source Availability and Future Directions
To facilitate further research and adoption within the LLM community, the team has open-sourced the DASH code at https://github.com/SJTU-Liquid/deterministic-FA3. The researchers acknowledge that theoretical optimality doesn’t always translate to practical superiority, and that hardware limitations like register pressure and inter-SM communication latency play a crucial role in performance. Future work could explore further optimizations that account for these hardware realities.
The study highlights the importance of co-optimizing execution and accumulation order, rather than focusing solely on bandwidth or memory. DASH provides a suite of solutions that enable practitioners to achieve high-throughput attention while maintaining reproducibility in LLM training, suggesting that a nuanced approach – balancing theoretical optimality with practical hardware constraints – is essential for maximizing performance in this domain.
