Fix torch.compile recompilation issue with HF modeling + TP
#7
Fixes bug #6
TODO: this change also needs to be applied in transformers V5. That requires waiting for V5 to be a bit more stable before switching the torchtitan transformers modeling backend to v5 (for now, it relies on 4.57.1).

Issue

When compiling the HF Llama modeling code with tensor parallelism (TP), `torch.compile` triggers a recompilation for every attention layer.
Fix
In `transformers`, at `modeling_llama.py`, change the attention call so that `hidden_states` is no longer passed as a keyword argument (see the Explanation below). Verified with `./tooling_dev/debug_local.sh debugperf_large --compile`.
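For illustration, a minimal sketch of the kind of call-site change, assuming the usual HF Llama attention call (the real call site in 4.57.1 passes more arguments, and the exact diff may differ):

```python
# Before: hidden_states is passed as a keyword argument, so TP input
# preparation for this module has to go through a kwargs-aware forward
# pre-hook (with_kwargs=True).
attn_output, attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
)

# After: hidden_states is passed positionally. Assuming the TP plan only
# redistributes hidden_states, the plain pre-hook path is used and
# _forward_pre_hooks_with_kwargs stays empty on every layer.
attn_output, attn_weights = self.self_attn(
    hidden_states,
    attention_mask=attention_mask,
)
```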
Explanation

- When `torch.compile` traces the model, it creates a compiled graph along with guards. Guards are conditions that must hold for that graph to be reused; if a guard fails, `torch.compile` recompiles.
- In `modeling_llama.py`, `self.self_attn(hidden_states=hidden_states, ...)` is called with kwargs, so the TP plan has to prepare the module inputs through a kwargs-aware `register_forward_pre_hook`. Depending on whether kwargs are used or not, a different hook is registered (cf https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L576): `module.register_forward_pre_hook(lambda _, inputs, kwargs: some_fn(inputs, kwargs), with_kwargs=True)`.
- In that case `nn.Module` records the hook's id in `self._forward_pre_hooks_with_kwargs` and its call path checks `if hook_id in self._forward_pre_hooks_with_kwargs:` (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1808). Each attention layer gets a different `hook_id`, so the guard compiled for one layer fails on the next, e.g. `___dict_contains(148, self._modules['_checkpoint_wrapped_module']._modules['self_attn']._forward_pre_hooks_with_kwargs)`, hence the recompilations (see the small snippet after this list).
- Without kwargs, `self._forward_pre_hooks_with_kwargs` is always empty (cf https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py#L1679C13-L1679C48), so the `if` check is never taken and the guarded state is identical for every attention layer regardless of its `hook_id`, thus no recompile.
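For intuition, here is a small standalone snippet (illustrative only, no TP or compile needed) showing the state dynamo ends up guarding on: kwargs-aware pre-hooks, like the one the TP style registers, record a per-layer handle id in `_forward_pre_hooks_with_kwargs`, while plain pre-hooks leave that dict empty.

```python
import torch.nn as nn

# Two stand-in "attention layers" for each case.
kwarg_layers = [nn.Linear(4, 4) for _ in range(2)]
plain_layers = [nn.Linear(4, 4) for _ in range(2)]

# kwargs-aware hooks: each layer records a *different* handle id in
# _forward_pre_hooks_with_kwargs, so a guard like
# ___dict_contains(<id>, ..._forward_pre_hooks_with_kwargs) compiled for one
# layer fails on the next one -> one recompilation per layer.
for layer in kwarg_layers:
    handle = layer.register_forward_pre_hook(
        lambda mod, args, kwargs: None, with_kwargs=True
    )
    print(handle.id, dict(layer._forward_pre_hooks_with_kwargs))
# e.g. 0 {0: True}
#      1 {1: True}

# plain positional hooks: _forward_pre_hooks_with_kwargs stays empty on every
# layer, so the guarded state is identical across layers -> no recompilation.
for layer in plain_layers:
    handle = layer.register_forward_pre_hook(lambda mod, args: None)
    print(handle.id, dict(layer._forward_pre_hooks_with_kwargs))
# e.g. 2 {}
#      3 {}
```

On the real run, the failing guard can be confirmed by logging recompile reasons, e.g. with `TORCH_LOGS="recompiles"`.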