feat: add janus support for quantization+torch.compile combo(s)#145
Conversation
fdf809e to 1a6c0a5
sharpenb left a comment
Very cool to have it working for Janus! The PR description is clear. Could we add a small benchmark for this model? My main points are about how to improve the code factorization to avoid redundant work in the future.
gsprochette left a comment
Awesome work, can't wait to have this working for Janus in Pruna! I also hope that the HF PR adding generate functions to LlamaGen models lands soon so we can remove the Generator code. For the rest of the code I left a few comments, nothing breaking, only trying to improve already good code and make some double checks here and there :)
1a6c0a5 to 92b69ab
2421573 to 4d8daad
Bug: Model Loading Fails Due to Incorrect `lm_head` Handling
When loading Janus-like models, the load_hqq_model function attempts to add a dummy lm_head. This process introduces several issues:
- The lm_head is created with hardcoded (1024, 1024) dimensions, which may not match the model's actual hidden size, leading to shape mismatches.
- Its randomly initialized weights are incorrectly added to the model's state dictionary as a nested dictionary under the "lm_head" key, instead of flattening its parameters (e.g., "lm_head.weight", "lm_head.bias"). This will cause loading failures.
- The function loads qmodel.pt, modifies its contents with these problematic lm_head weights, and then overwrites the original file in-place. This is an unexpected and potentially harmful side effect for a load operation.
src/pruna/engine/load.py#L359-L365
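A hedged sketch of the fix suggested by the report (pure Python; `add_dummy_lm_head` and the key names are illustrative, not the actual Pruna code): derive the shape from the model config and register flat parameter keys rather than a nested dict.

```python
# Illustrative sketch only -- not the actual load_hqq_model code.
# Two fixes implied by the report: (1) take dimensions from the model
# config instead of hardcoding (1024, 1024), and (2) store the weights
# under flat keys like "lm_head.weight" so state-dict loading works.
def add_dummy_lm_head(state_dict, hidden_size, vocab_size):
    dummy_weight = [[0.0] * hidden_size for _ in range(vocab_size)]
    state_dict["lm_head.weight"] = dummy_weight  # flat key, not a nested dict
    return state_dict
```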
gsprochette left a comment
Looks good to me, there's just this expected_quantized_model_path that could be re-used to make the code a bit cleaner, see comment. If Bertrand is satisfied with your answers this is ready to go for me, thanks for the updates :)
sharpenb left a comment
Thanks for the details! I left some comments. Happy to discuss them if needed :)
4d8daad to a9d3d84
I have edited the main comment, and I am fixing the code (it will appear in the next push) to enable speedups with torch==2.7.
35bb705 to 3341a90
Step 3. and Step 6. prepare inputs in some way. Could we factorize these in one step?
Nice catch! Indeed step 6 only depends on the previous step 3; I merged them and put step 6 into the sub-function "self.prepare_inputs_tokens" :)
sharpenb
left a comment
Thanks for tackling the comments! I left a couple more but I think that it should be good to go then
It is unclear why step 5 prepares logit processors while step 4 already used the user-passed processor. Could we also merge all processor preparation steps?
You are definitely right ;) I merged all processing related to logits_processor into a single sub-function.
(ps: it may be unclear from the function name self.model._get_logits_processor, but this function uses the function from transformers that merges the user-defined logits_processor with the logits_processor defined in the generation_config. I have added a small comment in the function to highlight this.)
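The merging behaviour described above can be sketched in isolation (a minimal illustration of the idea behind merging user-passed processors with config-derived ones; the classes and helper names here are hypothetical, not the transformers API):

```python
# Illustrative sketch: processors built from the generation config are
# combined with user-passed ones into a single list, applied in order.
class TemperatureProcessor:
    def __init__(self, temperature):
        self.temperature = temperature

    def __call__(self, logits):
        return [x / self.temperature for x in logits]


def merge_logits_processors(config_processors, user_processors):
    # config-derived processors first, then the user's
    return list(config_processors) + list(user_processors)


def apply_processors(processors, logits):
    for proc in processors:
        logits = proc(logits)
    return logits
```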
ae2d1f8 to 7e6603f
Description
The goal of this PR is to decrease the memory impact and the latency of the Janus (pro-7b) model.
These models are based on LlamaGen and compute tokens in a latent space in an autoregressive fashion, thanks to an attribute (here called `language_model`) defined as a Llama model. There is currently no standardization regarding LlamaGen AR models, so this PR is exclusively dedicated to Janus models (that are compatible with the `transformers` package). But we expect in the near future (see this thread) that LlamaGen AR models will have similar `.generate()` functions.
The idea of the code change is:
- `JanusGenerator` (similar to the `TransformersGenerator` we already had, and that is renamed `CausalLMGenerator`);
- … a `diffusers` pipeline, but more tricky as there is no (yet) pipeline for Janus);
The above points can be extended to other LLM quantizers. However, testing the code and adapting all save/load functions is time consuming. I leave this work for a future PR.
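As a rough illustration of the autoregressive loop described above (a library-free sketch; `step_fn`, the names, and the greedy decoding are hypothetical simplifications of what a `.generate()` function does, not the actual Janus code):

```python
# Minimal sketch of an autoregressive token loop in a latent token space.
# step_fn is a stand-in for the language_model forward pass: it maps the
# tokens generated so far to logits over the next token.
def generate_tokens(step_fn, bos_token, max_new_tokens):
    tokens = [bos_token]
    for _ in range(max_new_tokens):
        logits = step_fn(tokens)
        next_tok = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        tokens.append(next_tok)
    return tokens
```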
Related Issue
Fixes #(issue number)
Type of Change
How Has This Been Tested?
When quantized with hqq4bits and combined with torch.compile, we can obtain a ~x3 speedup.
I have added 2 unit tests, and provided (below) a notebook for reproducing the results.
Edit: The notebook works well for torch==2.5.1. For torch==2.7, major changes have been introduced in torch dynamo, so we have to slightly adapt the smash_config:

```
smash_config['torch_compile_fullgraph'] = False
smash_config['torch_compile_mode'] = 'default'
smash_config['torch_compile_backend'] = 'inductor'
```

Otherwise the compilation step takes very long and is reapplied at each step with torch==2.7.
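For reference, the adapted settings can be collected in one place (a sketch; the `smash_config` keys are copied from the comment above, and the commented-out call shows what I believe are the equivalent raw `torch.compile` flags):

```python
# torch==2.7-friendly compile settings, as dictated above.
smash_config = {
    "torch_compile_fullgraph": False,   # avoid full-graph capture requirements
    "torch_compile_mode": "default",
    "torch_compile_backend": "inductor",
}

# Presumed equivalent raw call (assuming `model` is a torch.nn.Module):
# compiled = torch.compile(model, fullgraph=False, mode="default", backend="inductor")
```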
Checklist
Additional Notes