Conversation

@StrongerXi

This effectively replaces the efficient impl from #220 with a more efficient and simpler compiled impl (see `logprobs_from_logits_naive`).
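
For reference, here is roughly what the compiled naive path looks like (a minimal illustrative sketch; the exact function names, signatures, and any dynamic-shape handling in the PR may differ):

```python
import torch

def logprobs_from_logits_naive(logits, input_ids):
    # log-prob of each observed token under the model's distribution
    log_probs = torch.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

# torch.compile can fuse the softmax and gather, avoiding the full
# [batch, seq, vocab] log-prob intermediate that the eager version materializes
logprobs_from_logits = torch.compile(logprobs_from_logits_naive)
```

The `compile_method` in the benchmark below is the mathematically equivalent cross-entropy formulation of the same idea.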

Results from `run_qwen3-8b.sh` with tp=1 on 8xH100 (tp=2 won't run for some reason):

```
                  |  max-reserved-memory  |  max-allocated-memory  |
old chunked impl  |       139.55gb        |       115.91gb         |
new compiled impl |       129.19gb        |       115.91gb         |
```

Also, here is a slightly modified test script from #220 showing that the compiled impl beats all the other tested ones (note the added `torch.cuda.synchronize()`, which makes the timing accurate since CUDA kernels run asynchronously).

```python
import time
import torch
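
# cross_entropy(..., reduction='none') returns the per-token negative log-likelihood,
# i.e. -log_softmax(logits)[target], so negating it recovers the per-token log-prob;
# torch.compile can then fuse the whole computation into fewer kernels.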

@torch.compile
def compile_method(logits, input_ids):
    return -torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)).float(),
        input_ids.view(-1),
        reduction='none'
    ).view_as(input_ids)

def naive_method(logits, input_ids):
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

def method_1(logits, input_ids):  # old logprobs_from_logits_v2 implementation
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.logsumexp(logits, dim=-1)
    token_log_probs = token_logits - logsumexp_values  # log_softmax(logits) = logits - log(sum(exp(logits)))
    return token_log_probs

def method_2(logits, input_ids):  # compute log_softmax in a loop to reduce peak memory
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def method_3(logits, input_ids):  # combine methods 1 and 2
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def efficient_method(logits, input_ids):  # pull everything out of the loop except logsumexp
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs

def measure_memory_and_time(func, logits, input_ids):
    torch.cuda.reset_peak_memory_stats()
    # warm up, especially for compile
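    # (peak stats were just reset above, so the reported peak also includes these warmup runs)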
    func(logits, input_ids)
    func(logits, input_ids)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    result = func(logits, input_ids)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak

torch.manual_seed(42)
vocab_size = 32768
seq_len = 1024
batch_size = 16

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
logit_mem = torch.cuda.max_memory_allocated()

naive_result, naive_time, naive_mem = measure_memory_and_time(naive_method, logits, input_ids)
method1_result, method1_time, method1_mem = measure_memory_and_time(method_1, logits, input_ids)
method2_result, method2_time, method2_mem = measure_memory_and_time(method_2, logits, input_ids)
method3_result, method3_time, method3_mem = measure_memory_and_time(method_3, logits, input_ids)
efficient_result, efficient_time, efficient_mem = measure_memory_and_time(efficient_method, logits, input_ids)
compile_result, compile_time, compile_mem = measure_memory_and_time(compile_method, logits, input_ids)

print("Max absolute difference (naive and 1):", (naive_result - method1_result).abs().max().item())
print("Max absolute difference (naive and 2):", (naive_result - method2_result).abs().max().item())
print("Max absolute difference (naive and 3):", (naive_result - method3_result).abs().max().item())
print("Max absolute difference (naive and efficient):", (naive_result - efficient_result).abs().max().item())
print("Max absolute difference (naive and compile):", (naive_result - compile_result).abs().max().item())
print("Memory consumed by logits: {:.2f} MB".format(logit_mem / 1e6))
print("Naive method time:      {:.6f} sec, Memory peak: {:.2f} MB".format(naive_time, naive_mem / 1e6))
print("Method 1 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method1_time, method1_mem / 1e6))
print("Method 2 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method2_time, method2_mem / 1e6))
print("Method 3 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method3_time, method3_mem / 1e6))
print("Efficient method time:  {:.6f} sec, Memory peak: {:.2f} MB".format(efficient_time, efficient_mem / 1e6))
print("Compile method time:    {:.6f} sec, Memory peak: {:.2f} MB".format(compile_time, compile_mem / 1e6))

Results:

```
Max absolute difference (naive and 1): 1.9073486328125e-06
Max absolute difference (naive and 2): 0.0
Max absolute difference (naive and 3): 1.9073486328125e-06
Max absolute difference (naive and efficient): 1.9073486328125e-06
Max absolute difference (naive and compile): 9.5367431640625e-07
Memory consumed by logits: 2147.61 MB
Naive method time:      0.005133 sec, Memory peak: 4295.16 MB
Method 1 time:          0.008121 sec, Memory peak: 4295.36 MB
Method 2 time:          0.004993 sec, Memory peak: 2416.24 MB
Method 3 time:          0.009120 sec, Memory peak: 2282.10 MB
Efficient method time:  0.008911 sec, Memory peak: 2282.23 MB
Compile method time:    0.001566 sec, Memory peak: 2190.02 MB
```

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request refactors the `logprobs_from_logits` implementation to use a compiled version of a naive implementation, which, according to the provided benchmarks, is more efficient. The changes simplify the code by removing a more complex, memory-efficient version and centralizing tensor reshaping logic. The addition of a new test for the naive implementation is also a welcome improvement. I have found one critical issue in the NPU-specific implementation that needs to be addressed.

@StrongerXi
Author

@vermouth1992 wdyt?

@StrongerXi changed the title from "Restructure logprobs_from_logits impl and always compile the naive impl" to "[fsdp] feat: Restructure logprobs_from_logits impl and always compile the naive impl" on Oct 23, 2025
@wuxibin89
Collaborator

wuxibin89 commented Oct 23, 2025

@StrongerXi There's a CI failure in model/model_engine, please take a look.

@wuxibin89
Collaborator

@StrongerXi The CI is still failing; you can fix it and run `pytest -s -x tests/models/test_engine.py` locally.

@CLAassistant

CLAassistant commented Nov 10, 2025

CLA assistant check
All committers have signed the CLA.

@StrongerXi
Author

@wuxibin89 just updated, could you trigger CI?

@StrongerXi requested a review from wuxibin89 on November 13, 2025, 23:58
@StrongerXi
Author

Sorry, I keep having trouble running the test locally; fixed another shape error.
