Add Blelloch parallel prefix scan for LASP#2
Open
petrpan26 wants to merge 6 commits intoOpenNLPLab:mainfrom
Open
Add Blelloch parallel prefix scan for LASP#2petrpan26 wants to merge 6 commits intoOpenNLPLab:mainfrom
petrpan26 wants to merge 6 commits intoOpenNLPLab:mainfrom
Conversation
134e5a6 to
75aca60
Compare
This PR implements Blelloch parallel prefix scan to reduce inter-GPU communication from O(P) sequential steps (ring) to O(log P) parallel steps (tree-based). Key improvements: - O(log P) communication complexity (e.g., 128 GPUs: 128 steps → 14 steps) - Work-efficient tree-based algorithm - Supports non-power-of-2 GPU counts - Reuses KV/DKV buffers to avoid allocation overhead Implementation details: 1. **BlellochScanner** (lasp/utils/blelloch_ops.py): - Tree-based up-sweep and down-sweep communication - Correct sender/receiver logic using "right edge" of subtrees - Distance-based decay in down-sweep for proper accumulation - Support for reverse scan (suffix) for backward pass - Global rank conversion for multi-group data parallelism 2. **lasp_blelloch** (lasp/lasp_blelloch.py): - Combines Blelloch scan with fused Triton kernels - Correct inclusive-to-exclusive conversion: λ^(-C) * (inclusive - local) - Buffer reuse pattern matching lasp_fuse_parallel - Forward: prefix scan, Backward: suffix scan 3. **Tests and benchmarks**: - test_blelloch_correctness.py: Gradient correctness tests - test_non_power_of_two.py: Non-power-of-2 world sizes - benchmark_blelloch.py: Performance benchmarks - benchmark_all_methods.py: Comprehensive comparison Tested with: - Single GPU and multi-GPU (4-8 GPUs) - Data parallelism (dp_size > 1) with sequence parallelism - Power-of-2 and non-power-of-2 world sizes - Forward and backward pass correctness
75aca60 to
c5cd122
Compare
Collaborator
|
Hi petrpan26, Nice work. Thanks for contributing LASP! I will check the code change and do some tests in a few days. |
Changed Blelloch scan to compute exclusive prefix directly instead of converting from inclusive, avoiding division by lambda^n which causes overflow when lambda is small. Implementation: 1. Compute inclusive prefix using standard up-sweep + down-sweep 2. Convert to exclusive via simple rank shift: each rank i receives inclusive[i-1] from rank i-1, rank 0 gets zero This matches the pattern used in lasp_naive where the ring naturally produces exclusive prefix, avoiding the numerical issues of computing 1/lambda^n which overflows to infinity when s >= 1.0. Fixes NaN gradients in backward pass.
6842b21 to
ac2f03b
Compare
Author
|
@weigao266 Sounds good I'm debugging why for large steps i think errors are accumulating but I added most of the result in and it should looks correct and add a few benchmarks file as well. Feel free to comment and let me know if I can change anything |
Root cause: In suffix scan (backward pass), the rank shift was sending in the wrong direction. For suffix scan, rank i should receive from rank i+1 (not i-1) and send to rank i-1 (not i+1). The bug: Used scan_rank±1 for both prefix and suffix, which worked for prefix but was backwards for suffix due to the scan_rank reversal. The fix: - Separate logic for prefix vs suffix scan in rank shift - Prefix: rank i receives from i-1, sends to i+1 (left to right) - Suffix: rank i receives from i+1, sends to i-1 (right to left) - Use actual rank (not scan_rank) for the shift communication - Add actual_to_global_rank() helper to avoid scan_rank confusion This should fix the 10x larger backward gradient errors (dk: 0.209, dv: 0.297) by ensuring the suffix scan produces correct exclusive values for each rank.
1834356 to
9881835
Compare
added 2 commits
November 4, 2025 09:53
Root cause: With 32+ GPUs, the rank shift was hanging because blocking send/recv created a sequential dependency chain. Each rank had to wait for the previous rank to send before it could send to the next rank, creating O(P) latency and potential deadlock. The fix: Use dist.irecv() and dist.isend() (non-blocking) instead of blocking send/recv. This allows all ranks to initiate their send/recv operations simultaneously, then wait for completion. Benefits: - Prevents deadlock with large GPU counts (tested hang at 32 GPUs) - Allows parallel execution of send/recv operations - Maintains O(1) latency for the rank shift step This preserves the O(log P) overall complexity of Blelloch scan.
890279c to
f84f1d4
Compare
Author
|
I was try running on smaller GPU but i cant so i tune the params a bit as well |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Todo:
For sequence parallelism, There is one thing that I think we are doing inefficiently right now. We are doing accumulation of KV in a linear way and this in turn incur more latencies as more GPU are added (O(n) in this case). This have a lot of GPU idling between waiting KV accumulation. I'm suggesting adding a blelloch prefix scan algorithm to help reduce this linear steps to logarithmic instead.
When testing in 8xH100SXM, I saw a 2x speed up over other methods.
dp-size=2 (Data Parallel: 2, Sequence Parallel: 4)
Forward-only throughput:
fuse: 49.56M tokens/s (3.59x speedup) — fastest
blelloch: 39.76M tokens/s (2.88x speedup)
fuse_parallel: 35.14M tokens/s (2.55x speedup)
cache: 26.78M tokens/s (1.94x speedup)
naive: 13.80M tokens/s (baseline)
Forward+backward throughput:
blelloch: 16.81M tokens/s (3.81x speedup) — fastest
fuse: 16.69M tokens/s (3.78x speedup)
fuse_parallel: 11.87M tokens/s (2.69x speedup)
cache: 6.68M tokens/s (1.51x speedup)
naive: 4.41M tokens/s (baseline)
gpu = 8, dp = 2
gpu = 8, dp = 1
At gpu = 16 RTX 5090
Forward + Backward we see more profound effect at more gpu
Forward only
Key improvements:
Implementation details:
BlellochScanner (lasp/utils/blelloch_ops.py):
lasp_blelloch (lasp/lasp_blelloch.py):
Tests and benchmarks:
Tested with: