[fsdp] feat: Restructure logprobs_from_logits impl and always compile the naive impl #3852
Conversation
Code Review
This pull request refactors the logprobs_from_logits implementation to use a compiled version of a naive implementation, which, according to the provided benchmarks, is more efficient. The changes simplify the code by removing a more complex, memory-efficient version and centralizing tensor reshaping logic. The addition of a new test for the naive implementation is also a welcome improvement. I have found one critical issue in the NPU-specific implementation that needs to be addressed.
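For context, here is a minimal sketch of what the compiled naive implementation could look like, based on `naive_method` from the benchmark script later in this thread. The function name is taken from the PR description, but the exact signature, decorator placement, and dispatch logic in the actual diff may differ:

```python
import torch


# Sketch only: the PR's actual helper may differ. The naive
# log_softmax + gather is wrapped in torch.compile so the backend can
# fuse the normalization and gather instead of materializing the full
# (batch, seq, vocab) log-probs tensor.
@torch.compile
def logprobs_from_logits_naive(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
```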
@vermouth1992 wdyt?
@StrongerXi There's a CI failure in
@StrongerXi The CI is still failing; you can fix it and run
@wuxibin89 just updated, could you trigger CI?
This effectively replaces the efficient impl from volcengine#220 with a more efficient and simpler compiled impl (see `logprobs_from_logits_naive`). Results from `run_qwen3-8b.sh` with tp=1 on 8xH100 (tp=2 won't run for some reason):

```
|                   | max-reserved-memory | max-allocated-memory |
| old chunked impl  | 139.55gb            | 115.91gb             |
| new compiled impl | 129.19gb            | 115.91gb             |
```

Also, a slightly modified test script from volcengine#220 to show that the compiled impl is superior to all the other tested ones (note the added `torch.cuda.synchronize()`, which makes the timing more accurate).

```python
import time

import torch


@torch.compile
def compile_method(logits, input_ids):
    return -torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)).float(),
        input_ids.view(-1),
        reduction='none'
    ).view_as(input_ids)


def naive_method(logits, input_ids):
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)


def method_1(logits, input_ids):
    # old logprobs_from_logits_v2 implementation
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.logsumexp(logits, dim=-1)
    # log_softmax(logits) = logits - log(sum(exp(logits)))
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs


def method_2(logits, input_ids):
    # compute log_softmax in a loop to reduce peak memory
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def method_3(logits, input_ids):
    # combine methods 1 and 2
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def efficient_method(logits, input_ids):
    # pull everything out of the loop except logsumexp
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs


def measure_memory_and_time(func, logits, input_ids):
    torch.cuda.reset_peak_memory_stats()
    # warm up, especially for compile
    func(logits, input_ids)
    func(logits, input_ids)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    result = func(logits, input_ids)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak


torch.manual_seed(42)
vocab_size = 32768
seq_len = 1024
batch_size = 16
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
logit_mem = torch.cuda.max_memory_allocated()

naive_result, naive_time, naive_mem = measure_memory_and_time(naive_method, logits, input_ids)
method1_result, method1_time, method1_mem = measure_memory_and_time(method_1, logits, input_ids)
method2_result, method2_time, method2_mem = measure_memory_and_time(method_2, logits, input_ids)
method3_result, method3_time, method3_mem = measure_memory_and_time(method_3, logits, input_ids)
efficient_result, efficient_time, efficient_mem = measure_memory_and_time(efficient_method, logits, input_ids)
compile_result, compile_time, compile_mem = measure_memory_and_time(compile_method, logits, input_ids)

print("Max absolute difference (naive and 1):", (naive_result - method1_result).abs().max().item())
print("Max absolute difference (naive and 2):", (naive_result - method2_result).abs().max().item())
print("Max absolute difference (naive and 3):", (naive_result - method3_result).abs().max().item())
print("Max absolute difference (naive and efficient):", (naive_result - efficient_result).abs().max().item())
print("Max absolute difference (naive and compile):", (naive_result - compile_result).abs().max().item())
print("Memory consumed by logits: {:.2f} MB".format(logit_mem / 1e6))
print("Naive method time: {:.6f} sec, Memory peak: {:.2f} MB".format(naive_time, naive_mem / 1e6))
print("Method 1 time: {:.6f} sec, Memory peak: {:.2f} MB".format(method1_time, method1_mem / 1e6))
print("Method 2 time: {:.6f} sec, Memory peak: {:.2f} MB".format(method2_time, method2_mem / 1e6))
print("Method 3 time: {:.6f} sec, Memory peak: {:.2f} MB".format(method3_time, method3_mem / 1e6))
print("Efficient method time: {:.6f} sec, Memory peak: {:.2f} MB".format(efficient_time, efficient_mem / 1e6))
print("Compile method time: {:.6f} sec, Memory peak: {:.2f} MB".format(compile_time, compile_mem / 1e6))
```

Results:

```
Max absolute difference (naive and 1): 1.9073486328125e-06
Max absolute difference (naive and 2): 0.0
Max absolute difference (naive and 3): 1.9073486328125e-06
Max absolute difference (naive and efficient): 1.9073486328125e-06
Max absolute difference (naive and compile): 9.5367431640625e-07
Memory consumed by logits: 2147.61 MB
Naive method time: 0.005133 sec, Memory peak: 4295.16 MB
Method 1 time: 0.008121 sec, Memory peak: 4295.36 MB
Method 2 time: 0.004993 sec, Memory peak: 2416.24 MB
Method 3 time: 0.009120 sec, Memory peak: 2282.10 MB
Efficient method time: 0.008911 sec, Memory peak: 2282.23 MB
Compile method time: 0.001566 sec, Memory peak: 2190.02 MB
```
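Worth noting: `compile_method` relies on the identity that a token's log-probability equals its negated unreduced cross-entropy loss, which is why it agrees with the naive method up to fp32 rounding (~1e-6 in the results above). A standalone sanity check of that identity (not from the PR):

```python
import torch

# -cross_entropy(logits, y, reduction='none') picks out log_softmax(logits)
# at the target index, i.e. the per-token log-probability.
logits = torch.randn(4, 10)
y = torch.randint(0, 10, (4,))
lhs = -torch.nn.functional.cross_entropy(logits, y, reduction="none")
rhs = logits.log_softmax(dim=-1).gather(-1, y.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(lhs, rhs, atol=1e-6)
```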
Sorry, I keep having trouble running the test locally; fixed another shape error.