
Conversation

@tyler-romero (Contributor) commented Feb 7, 2025

The existing `logprobs_from_logits_v2` doesn't achieve the memory savings it claims: `logsumexp` still allocates a `bs * seqlen * vocab` tensor internally to hold the element-wise `exp`. By looping over the batch and applying `logsumexp` one row at a time, we can compute the same outputs with much smaller temporaries.
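For the benchmark shapes below (batch 16, seqlen 1024, vocab 32768 in float32), that temporary alone is 16 × 1024 × 32768 × 4 bytes ≈ 2.15 GB, the same size as the logits themselves.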

Benchmarks show this uses significantly less memory to compute logprobs.

A fix is provided, along with a separate memory-efficient approach for the bfloat16 case.
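
The bfloat16 path isn't shown inline here; as a rough sketch of the idea (an assumption on my part, not necessarily the PR's exact diff), each row can be upcast to float32 inside the loop, so the `exp` temporary stays accurate while never exceeding one `seqlen * vocab` row:

```python
import torch

def logprobs_from_bf16_logits_sketch(logits, input_ids):
    # logits: [bs, seqlen, vocab] in bfloat16; input_ids: [bs, seqlen]
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    # Upcast one row at a time: the float32 exp() temporary is only
    # [seqlen, vocab], while the full logits stay in bfloat16.
    logsumexp_values = torch.stack([torch.logsumexp(row.float(), dim=-1) for row in logits])
    return token_logits.float() - logsumexp_values
```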

@tyler-romero (Contributor, Author)

Benchmarks:

```python
import time
import torch

def naive_method(logits, input_ids):
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

def method_1(logits, input_ids):  # old logprobs_from_logits_v2 implementation
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.logsumexp(logits, dim=-1)
    token_log_probs = token_logits - logsumexp_values  # log_softmax(logits) = logits - log(sum(exp(logits)))
    return token_log_probs

def method_2(logits, input_ids):  # compute log_softmax in a loop to reduce peak memory
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def method_3(logits, input_ids):  # combine methods 1 and 2
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def efficient_method(logits, input_ids):  # pull everything out of the loop except logsumexp
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs

def measure_memory_and_time(func, logits, input_ids):
    torch.cuda.reset_peak_memory_stats()
    start_time = time.perf_counter()
    result = func(logits, input_ids)
    end_time = time.perf_counter()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak

# Simulated data
torch.manual_seed(42)
vocab_size = 32768
seq_len = 1024
batch_size = 16

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
logit_mem = torch.cuda.max_memory_allocated()

# Run all methods
naive_result, naive_time, naive_mem = measure_memory_and_time(naive_method, logits, input_ids)
method1_result, method1_time, method1_mem = measure_memory_and_time(method_1, logits, input_ids)
method2_result, method2_time, method2_mem = measure_memory_and_time(method_2, logits, input_ids)
method3_result, method3_time, method3_mem = measure_memory_and_time(method_3, logits, input_ids)
efficient_result, efficient_time, efficient_mem = measure_memory_and_time(efficient_method, logits, input_ids)

# Check equivalence
print("Max absolute difference (naive and 1):", (naive_result - method1_result).abs().max().item())
print("Max absolute difference (naive and 2):", (naive_result - method2_result).abs().max().item())
print("Max absolute difference (naive and 3):", (naive_result - method3_result).abs().max().item())
print("Max absolute difference (naive and efficient):", (naive_result - efficient_result).abs().max().item())
print("Memory consumed by logits: {:.2f} MB".format(logit_mem / 1e6))
print("Naive method time:      {:.6f} sec, Memory peak: {:.2f} MB".format(naive_time, naive_mem / 1e6))
print("Method 1 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method1_time, method1_mem / 1e6))
print("Method 2 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method2_time, method2_mem / 1e6))
print("Method 3 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method3_time, method3_mem / 1e6))
print("Efficient method time:  {:.6f} sec, Memory peak: {:.2f} MB".format(efficient_time, efficient_mem / 1e6))

# Results:
# > Max absolute difference (naive and 1): 1.9073486328125e-06
# > Max absolute difference (naive and 2): 0.0
# > Max absolute difference (naive and 3): 1.9073486328125e-06
# > Max absolute difference (naive and efficient): 1.9073486328125e-06
# > Memory consumed by logits: 2147.61 MB
# > Naive method time:      0.036307 sec, Memory peak: 4295.16 MB
# > Method 1 time:          0.134651 sec, Memory peak: 4295.43 MB
# > Method 2 time:          0.012156 sec, Memory peak: 2416.18 MB
# > Method 3 time:          0.001496 sec, Memory peak: 2282.10 MB
# > Efficient method time:  0.000918 sec, Memory peak: 2282.23 MB
```
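
In other words, the looped variants peak at roughly the logits themselves (~2148 MB) plus one row-sized temporary (1024 × 32768 × 4 bytes ≈ 134 MB), rather than at a second full-sized copy of the logits.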

@tyler-romero tyler-romero marked this pull request as ready for review February 7, 2025 07:57
@vermouth1992 (Collaborator)

Hi @tyler-romero,

Great catch! Could you please add your test cases to our CI system so that future PRs won't break this? Thanks. You can move your tests here, following pytest style: https://github.com/volcengine/verl/blob/main/tests/gpu_utility/test_torch_functional.py
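
A minimal pytest-style test along those lines might look like the sketch below (names are illustrative, not the exact ones added by the PR; it inlines the looped-`logsumexp` implementation from the benchmark above):

```python
import pytest
import torch

def logprobs_from_logits_chunked(logits, input_ids):
    # Looped-logsumexp implementation, as benchmarked above.
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(row, dim=-1) for row in logits])
    return token_logits - logsumexp_values

@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_logprobs_from_logits_matches_naive():
    torch.manual_seed(42)
    logits = torch.randn(4, 128, 1024, device="cuda")
    input_ids = torch.randint(0, 1024, (4, 128), device="cuda")
    expected = torch.gather(
        logits.log_softmax(dim=-1), dim=-1, index=input_ids.unsqueeze(-1)
    ).squeeze(-1)
    actual = logprobs_from_logits_chunked(logits, input_ids)
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
```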

@tyler-romero (Contributor, Author)

Added tests, and they're passing locally for me.

@vermouth1992 vermouth1992 merged commit 4b51624 into volcengine:main Feb 8, 2025
11 checks passed
sunyi0505 pushed a commit to sunyi0505/verl that referenced this pull request Feb 20, 2025
yuchenwang3 pushed a commit to yuchenwang3/verl that referenced this pull request Apr 25, 2025
histmeisah pushed a commit to SJTU-IAAR/verl that referenced this pull request Apr 27, 2025
StrongerXi added a commit to StrongerXi/verl that referenced this pull request Oct 22, 2025
This effectively replaces the efficient impl from volcengine#220 with a more
efficient and simpler compiled impl (see `logprobs_from_logits_naive`).

Results from `run_qwen3-8b.sh` with tp=1 on 8xH100 (tp=2 won't run for
some reason):
```
                  |  max-reserved-memory  |  max-allocated-memory  |
old chunked impl  |       139.55gb        |       115.91gb         |
new compiled impl |       129.19gb        |       115.91gb         |
```

Also included is a slightly modified test script from volcengine#220, showing that the compiled impl is superior to all the others tested (note the added `torch.cuda.synchronize()`, which makes the benchmark more accurate).
```python
import time
import torch

@torch.compile
def compile_method(logits, input_ids):
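    # F.cross_entropy with reduction='none' returns the per-token negative
    # log-likelihood, so negating it yields per-token log-probs.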
    return -torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)).float(),
        input_ids.view(-1),
        reduction='none'
    ).view_as(input_ids)

def naive_method(logits, input_ids):
    log_probs = logits.log_softmax(dim=-1)
    return torch.gather(log_probs, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

def method_1(logits, input_ids):  # old logprobs_from_logits_v2 implementation
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.logsumexp(logits, dim=-1)
    token_log_probs = token_logits - logsumexp_values  # log_softmax(logits) = logits - log(sum(exp(logits)))
    return token_log_probs

def method_2(logits, input_ids):  # compute log_softmax in a loop to reduce peak memory
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def method_3(logits, input_ids):  # combine methods 1 and 2
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids):
        token_logits = torch.gather(logits_row, dim=-1, index=input_ids_row.unsqueeze(-1)).squeeze(-1)
        token_log_prob = token_logits - torch.logsumexp(logits_row, dim=-1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)

def efficient_method(logits, input_ids):  # pull everything out of the loop except logsumexp
    token_logits = torch.gather(logits, dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    logsumexp_values = torch.stack([torch.logsumexp(l, dim=-1) for l in logits])
    token_log_probs = token_logits - logsumexp_values
    return token_log_probs

def measure_memory_and_time(func, logits, input_ids):
    torch.cuda.reset_peak_memory_stats()
    # warm up, especially for compile
    func(logits, input_ids)
    func(logits, input_ids)
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    result = func(logits, input_ids)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    mem_peak = torch.cuda.max_memory_allocated()
    return result, end_time - start_time, mem_peak

torch.manual_seed(42)
vocab_size = 32768
seq_len = 1024
batch_size = 16

device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(batch_size, seq_len, vocab_size, device=device, dtype=torch.float32)
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
logit_mem = torch.cuda.max_memory_allocated()

naive_result, naive_time, naive_mem = measure_memory_and_time(naive_method, logits, input_ids)
method1_result, method1_time, method1_mem = measure_memory_and_time(method_1, logits, input_ids)
method2_result, method2_time, method2_mem = measure_memory_and_time(method_2, logits, input_ids)
method3_result, method3_time, method3_mem = measure_memory_and_time(method_3, logits, input_ids)
efficient_result, efficient_time, efficient_mem = measure_memory_and_time(efficient_method, logits, input_ids)
compile_result, compile_time, compile_mem = measure_memory_and_time(compile_method, logits, input_ids)

print("Max absolute difference (naive and 1):", (naive_result - method1_result).abs().max().item())
print("Max absolute difference (naive and 2):", (naive_result - method2_result).abs().max().item())
print("Max absolute difference (naive and 3):", (naive_result - method3_result).abs().max().item())
print("Max absolute difference (naive and efficient):", (naive_result - efficient_result).abs().max().item())
print("Max absolute difference (naive and compile):", (naive_result - compile_result).abs().max().item())
print("Memory consumed by logits: {:.2f} MB".format(logit_mem / 1e6))
print("Naive method time:      {:.6f} sec, Memory peak: {:.2f} MB".format(naive_time, naive_mem / 1e6))
print("Method 1 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method1_time, method1_mem / 1e6))
print("Method 2 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method2_time, method2_mem / 1e6))
print("Method 3 time:          {:.6f} sec, Memory peak: {:.2f} MB".format(method3_time, method3_mem / 1e6))
print("Efficient method time:  {:.6f} sec, Memory peak: {:.2f} MB".format(efficient_time, efficient_mem / 1e6))
print("Compile method time:    {:.6f} sec, Memory peak: {:.2f} MB".format(compile_time, compile_mem / 1e6))
```

Results:
```
Max absolute difference (naive and 1): 1.9073486328125e-06
Max absolute difference (naive and 2): 0.0
Max absolute difference (naive and 3): 1.9073486328125e-06
Max absolute difference (naive and efficient): 1.9073486328125e-06
Max absolute difference (naive and compile): 9.5367431640625e-07
Memory consumed by logits: 2147.61 MB
Naive method time:      0.005133 sec, Memory peak: 4295.16 MB
Method 1 time:          0.008121 sec, Memory peak: 4295.36 MB
Method 2 time:          0.004993 sec, Memory peak: 2416.24 MB
Method 3 time:          0.009120 sec, Memory peak: 2282.10 MB
Efficient method time:  0.008911 sec, Memory peak: 2282.23 MB
Compile method time:    0.001566 sec, Memory peak: 2190.02 MB
```
StrongerXi added a commit to StrongerXi/verl that referenced this pull request Oct 24, 2025
StrongerXi added a commit to StrongerXi/verl that referenced this pull request Nov 13, 2025
chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
StrongerXi added a commit to StrongerXi/verl that referenced this pull request Nov 20, 2025
TimurTaepov pushed a commit to giorgossideris/verl that referenced this pull request Dec 20, 2025