diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 26a6775a..9671ed97 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -794,9 +794,8 @@ def restore_original_order(self, batched_states: Sequence[T]) -> list[T]: ordered_results = batcher.restore_original_order(results) """ - all_states = [ - state[i] for state in batched_states for i in range(state.n_systems) - ] + all_states = [state.split() for state in batched_states] + all_states = list(chain.from_iterable(all_states)) original_indices = list(chain.from_iterable(self.index_bins)) if len(all_states) != len(original_indices):