feat: add unconstrained hyperparameter#263
Conversation
sharpenb
left a comment
There was a problem hiding this comment.
Thanks for this PR! This is a really important one :)
49603c2 to
7f16056
Compare
29be5b1 to
4a0cff4
Compare
sharpenb
left a comment
There was a problem hiding this comment.
Thanks for addressing my comments! Left some last points but this is ready to be merged ;) A question is which PR will include the small tutorial? The users will like it!
There was a problem hiding this comment.
I was wondering if we could get rid of this 0 index by removing () in the return of get_default_hyperparameters :)
There was a problem hiding this comment.
Sure, I loosened the return type of the method in PrunaAlgorithmBase so that I don't have to return a tuple. In the future if we need to return multiple defaults we can still make it a tuple for other algorithms.
There was a problem hiding this comment.
Sensational work!
Outside of the comments above, I feel one thing that we are missing is informing the user about which Target Modules they are selecting (because who can really write correct regex/fnmatch patterns lol). I think what I would be expecting is that either in a seperate function before smash, or better, in the to_list_of_module_paths function, we add some significant logging.
Some ideas but you know best:
- raise an error if a pattern (exclude/include) doesnt find any hits? or at least log some error?
- raise a controlled/informative error if you are requesting some module that doesnt exist?
- at either "INFO" or possibly "DEBUG" level log the modules we are actually selecting. I want to avoid that the user misspecifies the pattern and e.g. the algo isnt applied at all.
src/pruna/algorithms/pruna_base.py
Outdated
There was a problem hiding this comment.
I think this naming can be a bit confusing because i would expect this to return e.g. also the default of the other hyperparameters like cache_interval=2 or so. Something like get_user_specified_hyperparameter_default (ok thats way too long but something along those lines?
There was a problem hiding this comment.
Makes sense. All the hyperparameters are user-specified so I changed it to get_unconstrained_hyperparameter_defaults to reference the fact that is only gets the defaults for hyperparameters which are instances of UnconstrainedHyperparameter. It's long but making it shorter would make it less explicit...
There was a problem hiding this comment.
We should use the functionality as defined above here. I know you can't directly use the target module result here but we should use it in some capacity so that the logic for quanto is centralized if we adjust this in the future (probs we wont adjust it but you get me, for duplication reasons)
There was a problem hiding this comment.
In fact working model has no place here, it can simply be the model itself when it's a LLM. For diffusers, the if statement should ignore that anyway (although you can technically add a tokenizer in the smash config after the dataloader and get an error here because it doesn't run).
In all cases, we can't use parts of the model here because it's called directly on the data, meaning if we get the diffuser's transformer the prompt won't be tokenized.
As you say, we should fix that but this is for an other PR. I'll leave it in the "calibrate LLM only" state as it was before, if you agree.
There was a problem hiding this comment.
I would say we should only have one loop through the modules_with_subpaths and then to the quanto logic. From what I understand, it makes only sense to freeze it if we actually quantized it
There was a problem hiding this comment.
hm nvm we raise an error... Still, maybe it would be clearer to loop once and then apply the algorithm logic inside?
There was a problem hiding this comment.
I agree it would be cleaner but we still need to calibrate inbetween. Calibration needs the whole quantization to already be done since we need to run the full model, meaning we can't really freeze inside the first loop.
Here we only freeze the attributes (e.g. unet or transformer) that contain targeted modules, and the freeze function checks each nn.Module to see if it's been quantized anyway. The only reason for iterating over modules_with_subpath is to guarentee that module is a nn.Module, because for diffusers model is not a nn.Module. I changed it to get_nn_modules which is a simpler function and achieves the same thing. Is it ok for you?
2ecc2b6 to
a27bdd1
Compare
49b427b to
fca277e
Compare
|
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
2430c10 to
b77d954
Compare
f495878 to
3fc590d
Compare
cdea28d to
79eff17
Compare
79eff17 to
731b716
Compare
src/pruna/algorithms/pruna_base.py
Outdated
There was a problem hiding this comment.
This function is both too broad and too narrow. At the moment we only have target_modules as the only unconstrained hyperparameter, so for now we could change this to get_target_module_defaults.
I am however for making it applicable to any future unconstrained hyperparameter. This is not how this function is implemented, or at least not clear by the signature (as the return type is Any, technically everything is supported)
In the case where there is a future second unconstrained hyperparam, we don't know which default belongs to what.
-> change the return type to Dict[str, Any] - so for any hyperparameter we can get a default, but indexed by the param it concerns
Is there an argument to extend it to something like get_model_dependent_hyperparameter_defaults - why keep it constrained to the unconstrained hyperparameters? This way any parameter which might need information from the model/SmashConfig can get a default
There was a problem hiding this comment.
I'm torn, because on one hand I agree with you and I do think a dictionary would be the better option, both general and very readable.
On the other hand, this will probably most ofter be used only for target modules, in which case it's much clearer to just return default_target_modules instead of {"target_modules": default_target_modules}.
Note that this is why the return type is Any currently, you can handle both: returning just the default hyperparam if there is a single one, and return a dict if there are multiple. The method in child classes can always make the return type more specific. I didn't want to force a heavy TypedDict class every time we redefine this method.
I agree with the get_model_dependent_hyperparameter_defaults, it's a better name.
There was a problem hiding this comment.
As a conclusion: I'm in favor of
- renaming the method
- keep the signature as is to leave more flexibility
Refactoring this return type wouldn't be too much of a problem anyway.
There was a problem hiding this comment.
This is unnecessary duplicate code in every _apply.
I would be for moving this into the apply wrapper. Connected to the proposition to change the return type of get_unconstrained_hyperparameter_defaults.
In pruna_base.apply before calling _apply add something like:
defaults = self.get_unconstrained_hyperparameter_defaults(model, smash_config)
for key in defaults.keys() if smash_config[key] is None:
smash_config[key] = defaults[key]
self._apply(model, smash_config)
There was a problem hiding this comment.
The reason I don't want to do that is that it modifies the smash_config in-place.
In the case where a user would use the same smash_config to smash two different models, the second would have the defaults of the first one, instead of looking for its own default.
We can discuss doing a copy of the smash config at the start of smash, but that would be a refactor PR and not this one.
For now, I personnally prefer the duplicated "check for default None and replace the value" which is pretty standard in python I would say. WDYT?
There was a problem hiding this comment.
the in-place problem is a very good point - have not thought of this!
For now it won't hurt anyone, i still think it's cleaner - maybe part of a future algorithm application refactoring
simlang
left a comment
There was a problem hiding this comment.
Very cool! Apart from some naming stuff, I added a proposition to handle the defaults, let me know what you think about that
d57923f to
63e4bb4
Compare
| value["include"] = ["*"] | ||
| elif "exclude" not in value: | ||
| value["exclude"] = [] # for consistency | ||
|
|
There was a problem hiding this comment.
Bug: Inconsistent Defaulting and Parameter Mutation
The legal_value method modifies the input value dictionary in-place by adding default "include" and "exclude" keys. This violates the principle of not mutating input parameters, which can lead to unexpected side effects if the dictionary is reused. Additionally, the defaulting logic is inconsistent, as the "exclude" key might not be added when "include" is added.
Description
Separate smash_space and custom hyperparameters, and add a new hyperparameter class:
Related Issue
Fixes #(issue number)
Type of Change
How Has This Been Tested?
Checklist
Additional Notes