
split instead of sequential getitem (slicing under the hood) #546

Merged
CompRhys merged 1 commit into TorchSim:main from thomasloux:feat/improve_speed_autobatching
Apr 16, 2026

Conversation

@thomasloux
Collaborator

@thomasloux thomasloux commented Apr 15, 2026

Summary

Small change to improve speed: __getitem__ does a lot of slicing and reindexing. In a typical setting the change gives about a 3x overall speedup, but the absolute durations are low, so it matters little there.
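The idea can be illustrated outside torch-sim: one torch.split call hands back all chunks at once, whereas sequential __getitem__-style slicing re-derives each chunk separately (a minimal sketch, not the actual state code):

```python
import torch

x = torch.arange(12).reshape(6, 2)
sizes = [2, 1, 3]  # hypothetical per-system atom counts

# One pass: torch.split returns all chunks from a single call.
parts_split = torch.split(x, sizes)

# Sequential slicing: each chunk is rebuilt from offsets, one call per system.
offsets = [0, 2, 3, 6]  # cumulative sums of sizes
parts_getitem = [x[offsets[i] : offsets[i + 1]] for i in range(len(sizes))]

assert all(torch.equal(a, b) for a, b in zip(parts_split, parts_getitem))
```

Both paths yield the same views; the difference is per-call bookkeeping, which in torch-sim also includes constraint reindexing.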

Interestingly, it combines very poorly with FixSymmetry (script at the end).
On CPU, for 100 systems:
CustomAutoBatcher duration: 0.0091 seconds
Original BinningAutoBatcher duration: 6.1686 seconds
On GPU:
CustomAutoBatcher duration: 0.0409 seconds
Original BinningAutoBatcher duration: 0.9110 seconds

I'm currently investigating FixSymmetry, as the slicing is super costly.

import time
import warnings
from itertools import chain
from typing import Sequence, TypeVar

import numpy as np
import torch
from ase.build import bulk

import torch_sim as ts
from torch_sim.autobatching import BinningAutoBatcher
from torch_sim.optimizers import OPTIM_REGISTRY

warnings.filterwarnings("ignore", category=UserWarning)

T = TypeVar("T")


class CustomAutoBatcher(BinningAutoBatcher):
    """Inherits from BinningAutoBatcher to modify restore_original_order:
    one split() per batch instead of sequential __getitem__ calls."""

    def restore_original_order(self, batched_states: Sequence[T]) -> list[T]:
        # batched_states holds one batched state per bin; flatten into
        # individual states and reorder according to the original indices.
        all_states = list(
            chain.from_iterable(state.split() for state in batched_states)
        )
        original_indices = list(chain.from_iterable(self.index_bins))

        if len(all_states) != len(original_indices):
            raise ValueError(
                f"Number of states ({len(all_states)}) does not match "
                f"number of original indices ({len(original_indices)})"
            )

        # sort states by original indices
        indexed_states = zip(original_indices, all_states, strict=True)
        return [state for _, state in sorted(indexed_states, key=lambda x: x[0])]


def speed_auto_batching(autobatcher: BinningAutoBatcher) -> float:
    torch_dtype = torch.float32
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(device)

    init_fn, step_fn = OPTIM_REGISTRY["bfgs"]

    structure = bulk("Al", "fcc", a=4.05, cubic=True)
    structures = [structure * np.random.randint(1, 3) for _ in range(100)]
    print(len(structures), "structures created")
    initial_state = ts.initialize_state(structures, device=device, dtype=torch_dtype)

    initial_state.constraints = ts.constraints.FixSymmetry.from_state(
        initial_state, symprec=0.1
    )

    autobatcher.load_states(initial_state)
    initialized = [batch for batch, _indices in autobatcher]
    print("Number of batches:", len(initialized))

    start = time.time()
    # step_fn should be called with the original order of states, not the
    # batched order, so restore_original_order is the method under test.
    _ = autobatcher.restore_original_order(initialized)
    return time.time() - start


original_autobatcher = BinningAutoBatcher(
    model=None,
    max_memory_scaler=10_000,
    memory_scales_with="n_atoms",
)

autobatcher = CustomAutoBatcher(
    model=None,
    max_memory_scaler=10_000,
    memory_scales_with="n_atoms",
)

duration = speed_auto_batching(autobatcher)
original_duration = speed_auto_batching(original_autobatcher)
print(f"CustomAutoBatcher duration: {duration:.4f} seconds")
print(f"Original BinningAutoBatcher duration: {original_duration:.4f} seconds")

Checklist

Before a pull request can be merged, the following items must be checked:

  • Doc strings have been added in the Google docstring format.
  • Run ruff on your code.
  • Tests have been added for any new functionality or bug fixes.

@thomasloux
Collaborator Author

thomasloux commented Apr 15, 2026

Claude suggestion:

Performance Bottlenecks in FixSymmetry + State Operations

I identified several sources of slowness, with the GPU-vs-CPU difference explained by implicit synchronization points.

  1. copy.deepcopy in _filter_attrs_by_index (state.py:1079) — Biggest bottleneck

for con in copy.deepcopy(state.constraints)

copy.deepcopy on a FixSymmetry constraint deep-copies every tensor in rotations (list of (n_ops, 3, 3) tensors), symm_maps (list of (n_ops, n_atoms) tensors), and reference_cells (list of (3, 3) tensors). On GPU, each
tensor deep-copy triggers a CUDA synchronization — the Python-level deepcopy calls .clone() then waits for it. For N systems with many symmetry ops, this is N×3 GPU syncs per slice.

This is called by both getitem (→ _slice_state) and _pop_states, which are core to autobatching.

Fix: select_constraint already produces a new object with selected data — the deep copy is redundant. Replace copy.deepcopy(state.constraints) with direct iteration, letting select_constraint handle the copying
internally. Or make select_constraint explicitly copy only what it needs (it already does this for FixSymmetry).
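The redundancy can be shown without torch-sim: in this sketch, Constraint is a hypothetical stand-in whose select() already returns an independent object, which makes the surrounding deepcopy pure overhead.

```python
import copy


class Constraint:
    """Hypothetical stand-in for a torch-sim constraint (illustration only)."""

    def __init__(self, data: list[int]) -> None:
        self.data = data

    def select(self, idx: list[int]) -> "Constraint":
        # Builds a new, independent object from the selected data, so the
        # caller never needs to deep-copy the original first.
        return Constraint([self.data[i] for i in idx])


constraints = [Constraint(list(range(5)))]

# Before: deep-copy every constraint, then select (everything copied twice).
selected_old = [con.select([0, 2]) for con in copy.deepcopy(constraints)]

# After: iterate directly; select() already yields an independent object.
selected_new = [con.select([0, 2]) for con in constraints]

assert selected_new[0].data == selected_old[0].data == [0, 2]
assert selected_new[0].data is not constraints[0].data
```

With tensors instead of lists, the deep copy additionally forces per-tensor clones and CUDA syncs, which is where the measured cost comes from.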

  2. SystemConstraint.__init__ duplicate check (constraints.py:336) — GPU sync per construction

if len(system_idx) != len(torch.unique(system_idx)):

torch.unique on a GPU tensor forces a sync. This runs every time a FixSymmetry is constructed — during reindex, merge, select_constraint, select_sub_constraint, and to. In autobatching workflows this adds up fast.

Fix: Skip the uniqueness check in internal construction paths (e.g., add a _skip_validation parameter or a classmethod factory), or move the check to CPU.
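The move-to-CPU option can be sketched generically (plain torch, not torch-sim internals):

```python
import torch


def has_duplicates_sync(system_idx: torch.Tensor) -> bool:
    # torch.unique on a CUDA tensor forces a device synchronization.
    return len(system_idx) != len(torch.unique(system_idx))


def has_duplicates_cpu(system_idx: torch.Tensor) -> bool:
    # The index tensor is small: move it to CPU once and compare there,
    # keeping the GPU pipeline free of the implicit sync.
    idx = system_idx.cpu().tolist()
    return len(idx) != len(set(idx))


idx = torch.tensor([0, 1, 2, 2])
assert has_duplicates_sync(idx) == has_duplicates_cpu(idx) == True
```

The CPU variant still transfers once, but the cost is explicit and avoids stalling queued GPU work behind a torch.unique kernel plus readback.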

  3. select_sub_constraint — Multiple GPU syncs per system (constraints.py:1108-1126)

if sys_idx not in self.system_idx: # GPU sync: contains on CUDA tensor
local = (self.system_idx == sys_idx).nonzero(as_tuple=True)[0].item() # GPU sync: .item()

Called per system during _split_state (state.py:1192-1195). For N systems, that's 2N GPU syncs minimum — plus N more from the SystemConstraint.__init__ uniqueness check in the constructor.

Fix: Batch the lookups. Convert self.system_idx to a CPU dict/set once, or rewrite _split_state to pass all system indices at once and let the constraint do a single batched selection.
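The dict-based variant of that fix looks like this in plain torch (system_idx is a hypothetical bookkeeping tensor, not the real constraint attribute):

```python
import torch

system_idx = torch.tensor([3, 0, 7])  # hypothetical per-constraint bookkeeping


def local_index_per_call(sys_idx: int) -> int:
    # On CUDA this costs two syncs per call: the nonzero readback
    # and the .item() transfer.
    return (system_idx == sys_idx).nonzero(as_tuple=True)[0].item()


# Batched alternative: transfer once, then O(1) CPU dict lookups per system.
local_of = {s: i for i, s in enumerate(system_idx.cpu().tolist())}

assert all(local_of[s] == local_index_per_call(s) for s in (3, 0, 7))
```

Building the mapping once turns 2N per-system syncs into a single device-to-host copy.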

  4. select_constraint uses .tolist() (constraints.py:1092)

local_idx = mask.nonzero(as_tuple=False).flatten().tolist()

.tolist() on a CUDA tensor forces a GPU→CPU transfer + sync. Then it uses the Python list to index into Python lists (self.rotations[idx]), which is fine, but the sync is the cost.

Fix: Move this to CPU before the .tolist(), or keep system_idx on CPU for index-tracking purposes.
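A sketch of making the transfer explicit (the mask here is a toy stand-in for the constraint's selection mask):

```python
import torch

mask = torch.tensor([True, False, True, True])  # toy selection mask

# Implicit transfer: .tolist() on a CUDA tensor syncs inside the call.
local_idx_implicit = mask.nonzero(as_tuple=False).flatten().tolist()

# Explicit transfer: moving the small index tensor to CPU first makes the
# single device->host copy visible, so it can be hoisted out of hot loops.
local_idx_explicit = mask.nonzero(as_tuple=False).flatten().cpu().tolist()

assert local_idx_implicit == local_idx_explicit == [0, 2, 3]
```

Both produce the same Python list; the point is controlling when the one unavoidable sync happens.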

For the deepcopy, removing it is indeed 10x faster. I'm not super comfortable not copying the tensors, because of weird side-effect errors. But at the same time it's a really big difference, and there's no reason to manually change a constraint like this.

There is almost no difference once the fix in this PR is applied, so maybe constraints need a dedicated split method.

@CompRhys
Member

Looks fine to me. For sanity, can you time some of the methods here before and after the change, to show that it doesn't have deleterious effects elsewhere: https://github.com/TorchSim/torch-sim/blob/main/examples/benchmarking/scaling.py

Collaborator

@orionarcher orionarcher left a comment


Good catch

@CompRhys CompRhys merged commit ee60700 into TorchSim:main Apr 16, 2026
63 checks passed
@thomasloux
Collaborator Author

thomasloux commented Apr 16, 2026

@CompRhys It's marginally better with the changes for benchmarking/scaling.py (hardly measurable in that case), but it makes a dramatic difference because of the deepcopy of constraints, so the change is expected to be much better in that case.

