feat: add pre-smash-hook for model preparation #309
gsprochette
left a comment
Looks almost good to me, I left a few comments:
- some docstring and name related stuff
- a discussion about undoing the setup after the smash.
It would be nice to have a unit test to make sure that pre_smash_setup is executed when the algo is activated and not executed otherwise. This can be done by monkey patching an existing method.
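A minimal sketch of such a test, assuming hypothetical stand-ins (`FakeQuantizer`, `run_pre_smash_setup`) rather than the real pruna classes. It patches the hook with a recording function and checks that the hook runs only when the algorithm is active:

```python
from typing import Any, Dict, List, Optional


class FakeQuantizer:
    """Hypothetical stand-in for an algorithm class; not part of pruna."""

    def pre_smash_setup(self, model: Any, smash_config: Dict[str, Any]) -> None:
        pass


def run_pre_smash_setup(model: Any, smash_config: Dict[str, Any], algorithms: Dict[str, Dict[str, Any]]) -> None:
    """Hypothetical dispatcher: call the hook of every active algorithm."""
    for group, name in smash_config.items():
        if name is not None:
            algorithms[group][name].pre_smash_setup(model, smash_config)


def hook_was_called(active: Optional[str]) -> bool:
    """Patch the hook with a recording function and report whether it ran."""
    calls: List[Any] = []
    original = FakeQuantizer.pre_smash_setup
    # Monkey patch: replace the method on the class with one that records calls.
    FakeQuantizer.pre_smash_setup = lambda self, model, cfg: calls.append(model)
    try:
        algorithms = {"quantizer": {"fake": FakeQuantizer()}}
        run_pre_smash_setup("model", {"quantizer": active}, algorithms)
    finally:
        # Always restore the original method so other tests are unaffected.
        FakeQuantizer.pre_smash_setup = original
    return bool(calls)
```

In a pytest suite the manual patch and restore would typically be replaced by the `monkeypatch` fixture, which undoes the patch automatically.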
src/pruna/engine/pre_smash_setup.py
Outdated
for current_group in ALGORITHM_GROUPS:
    algorithm = smash_config[current_group]
    if algorithm is not None:
        check_algorithm_availability(algorithm, current_group, algorithm_dict)
This call is repeated many times: (1) in this function, (2) in the smash loop, and (3) in check_model_compatibility. Should we take this opportunity to define a check_active_algorithm_availabilities function and run it at the beginning of smash?
There was a problem hiding this comment.
I'm not sure this PR is the correct place for it, but generally I agree. @johannaSommer, thoughts on this?
There was a problem hiding this comment.
I agree it's not really the PR for it. Could we clean this up while we're working on this part of the code, meaning in a follow-up PR (my favorite option) or directly here? Would you be OK with that? Johanna, do you have a strong opinion about this?
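For reference, a rough sketch of what the proposed helper could look like. The group names and the body of `check_algorithm_availability` below are assumptions made for illustration only; only the loop structure is taken from the diff above:

```python
from typing import Any, Dict, Optional

# Hypothetical group names; the real ALGORITHM_GROUPS is defined in pruna.
ALGORITHM_GROUPS = ["pruner", "quantizer"]


def check_algorithm_availability(algorithm: str, group: str, algorithm_dict: Dict[str, Dict[str, Any]]) -> None:
    """Stand-in with assumed behaviour: raise if the algorithm is unknown in its group."""
    if algorithm not in algorithm_dict.get(group, {}):
        raise ValueError(f"Algorithm {algorithm!r} is not available in group {group!r}.")


def check_active_algorithm_availabilities(
    smash_config: Dict[str, Optional[str]], algorithm_dict: Dict[str, Dict[str, Any]]
) -> None:
    """Run the availability check once for every active algorithm."""
    for group in ALGORITHM_GROUPS:
        algorithm = smash_config[group]
        if algorithm is not None:
            check_algorithm_availability(algorithm, group, algorithm_dict)
```

Calling this once at the beginning of smash would let the setup step, the smash loop, and the compatibility check assume that every active algorithm is available.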
src/pruna/algorithms/pruna_base.py
Outdated
        """
        return []

    def pre_smash_setup(self, model: Any, smash_config: SmashConfig) -> None:
Is there an argument for a post_smash_? as well? For example if the pre_smash_setup computes something based on the pre-smashed model and stores it in smash_config, a post_smash could be the opportunity to destroy it and restore the smash_config to its original state before the smash function was applied. What do you think?
There was a problem hiding this comment.
I think there could be an argument for it.
For recovery, for example, it could make sense to create the dataset in the pre-smash step, replace the current data with it, and then reinsert the original data in a potential post_smash step.
What do you think, @johannaSommer?
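To make the idea concrete, here is a small sketch under stated assumptions: `RecoveryAlgorithm`, `post_smash_setup`, and the dict-based config are hypothetical and not part of this PR. The pre-smash step swaps a dataset in, and the post-smash step restores the original even if smashing fails:

```python
from typing import Any, Dict


class RecoveryAlgorithm:
    """Hypothetical algorithm that swaps a dataset in before the smash and restores it after."""

    def __init__(self) -> None:
        self._original_data: Any = None

    def pre_smash_setup(self, model: Any, smash_config: Dict[str, Any]) -> None:
        # Remember the current data, then replace it in place.
        self._original_data = smash_config.get("data")
        smash_config["data"] = f"recovery dataset for {model}"

    def post_smash_setup(self, model: Any, smash_config: Dict[str, Any]) -> None:
        # Put the original data back so the config looks untouched to the caller.
        smash_config["data"] = self._original_data
        self._original_data = None


def smash_with_hooks(model: Any, smash_config: Dict[str, Any], algorithm: RecoveryAlgorithm) -> Any:
    """Run setup, a placeholder smash step, and teardown."""
    algorithm.pre_smash_setup(model, smash_config)
    try:
        # Placeholder for the actual smashing: just report which data it would see.
        seen = smash_config["data"]
    finally:
        algorithm.post_smash_setup(model, smash_config)
    return seen
```

The `try`/`finally` is one possible design choice: it keeps the config restored even when the smash step raises.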
src/pruna/smash.py
Outdated
check_model_compatibility(model, smash_config)

# perform any necessary setup steps before the smashing process begins
pre_smash_setup(model, smash_config)
Right now this only allows in-place operations. After discussion with Simon, we can leave it like that for now, but we should add documentation/justification as to why.
src/pruna/engine/pre_smash_setup.py
Outdated
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
I would slightly prefer not putting this into a separate file but into the "compatibility checks" file, since that is where we have all the pre-smash checks and setup (e.g. device casting). If you feel the naming of the file is a problem, feel free to change it to pre_smash_setup or similar.
There was a problem hiding this comment.
I'm fine with either option. The point is to keep both the engine dir and the pre_smash_setup file manageable. Either direction is a problem only if we have too many of those setup functions, in which case we can split them into multiple files in a pre_smash_setup directory instead. As long as we don't have that sort of problem, I don't have a strong opinion :)
src/pruna/engine/pre_smash_setup.py
Outdated
    algorithm = smash_config[current_group]
    if algorithm is not None:
        check_algorithm_availability(algorithm, current_group, algorithm_dict)
        algorithm_dict[current_group][algorithm].pre_smash_setup(model, smash_config)
Merge this with the existing function in the compatibility checks and possibly adjust the naming, e.g. pre_smash_hook?
src/pruna/algorithms/pruna_base.py
Outdated
    def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> None:
        """
        Function to be overridden by an algorithm to perform a pre-smash setup.
pre-smash-hook instead of setup in the doc?
    model, smash_config = model_fixture

    pre_smash_hook_called = False
    def mock_pre_smash_hook(self: LLMInt8Quantizer, model: Any, smash_config: SmashConfigPrefixWrapper) -> None:
gsprochette
left a comment
Thanks for addressing every comment, this looks super good :) I may have found a typo in the _pre_smash_hook docstring, but other than that it's ready to merge 🤩
As for the post_smash_hook, we can probably wait until we have a case where we need it, or simply add it in a follow-up PR.
johannaSommer
left a comment
Thanks Simon! Agree with @gsprochette that we should keep the post smash hook in mind but since no algorithm needs it at the moment let's keep it for later :)
Description
This PR introduces new functionality, such that any algorithm can implement a setup function which is called before any smashing algorithm is applied.
To do pre-smash setup, an algorithm has to override _pre_smash_setup to apply in-place operations on the model based on information provided in the SmashConfig for that specific algorithm.
Related Issue
Type of Change
How Has This Been Tested?
If no algorithm overrides _pre_smash_setup, there should be no change of functionality compared to the current version. Therefore, to test, I successfully ran the existing tests.
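The "no change unless overridden" behaviour can be illustrated with a small sketch. The classes and the dict-based model and config below are hypothetical stand-ins, not the actual pruna base class or SmashConfig:

```python
from typing import Any, Dict


class AlgorithmBase:
    """Hypothetical base class; the real one is in pruna_base.py and is not reproduced here."""

    def _pre_smash_setup(self, model: Dict[str, Any], smash_config: Dict[str, Any]) -> None:
        """Default: do nothing, so algorithms that do not override it behave as before."""


class PlainAlgorithm(AlgorithmBase):
    """Does not override the setup, so the model is left untouched."""


class PreparingAlgorithm(AlgorithmBase):
    """Overrides the setup to modify the model in place based on its config."""

    def _pre_smash_setup(self, model: Dict[str, Any], smash_config: Dict[str, Any]) -> None:
        # In-place operation: nothing is returned, the model object itself is changed.
        model["prepared_with"] = smash_config["bits"]
```

Because the setup returns `None`, any preparation has to happen in place on the model, which matches the in-place restriction discussed in the review.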
Checklist
Additional Notes