feat: add janus support for quantization+torch.compile combo(s)#145

Merged
llcnt merged 14 commits into main from feat/llamagen_ar_janus_support
Jul 8, 2025
Conversation

Collaborator

@llcnt llcnt commented May 21, 2025

Description

The goal of this PR is to decrease the memory footprint and the latency of the Janus(-Pro-7B) model.
These models are based on LlamaGen and compute tokens in a latent space in an autoregressive fashion, thanks to an attribute (here called language_model) defined as a Llama model.
There is currently no standardization of LlamaGen AR models, so this PR is exclusively dedicated to Janus models (which are compatible with the transformers package). But we expect (see this thread) that LlamaGen AR models will get similar .generate() functions in the near future.

The idea of the code change is:

  • create a JanusGenerator (similar to the TransformersGenerator we already had, and that is renamed CausalLMGenerator);
  • adapt the context manager to be able to deal with Janus models;
  • adapt HQQ to be able to work on Janus;
  • adapt the hqq save function (distinct saves for the LM model and the rest: similar to what is done for diffusers pipelines, but trickier as there is no pipeline for Janus yet);
  • adapt the hqq load function (also distinct loads for the LM model and the rest).
    The above points can be extended to other LLM quantizers. However, testing the code and adapting all save/load functions is time-consuming; I leave this work for a future PR.
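The autoregressive structure these changes wrap can be illustrated with a minimal, dependency-free sketch (the class and the stub model below are illustrative, not the actual Pruna or Janus API): a generator repeatedly calls the wrapped language_model on the running token sequence.

```python
# Illustrative sketch of the autoregressive token loop a Janus-style
# generator wraps. `language_model` is a stand-in callable here, not
# the real Llama-based attribute of the Janus model.

class ToyJanusGenerator:
    def __init__(self, language_model, max_new_tokens=4):
        self.language_model = language_model
        self.max_new_tokens = max_new_tokens

    def generate(self, input_tokens):
        tokens = list(input_tokens)
        for _ in range(self.max_new_tokens):
            # One autoregressive step: predict the next latent token
            # from everything generated so far.
            next_token = self.language_model(tokens)
            tokens.append(next_token)
        return tokens

# Stub "model": next token is the sum of previous tokens, mod 10.
gen = ToyJanusGenerator(lambda toks: sum(toks) % 10)
print(gen.generate([1, 2]))  # → [1, 2, 3, 6, 2, 4]
```

Quantization and torch.compile both target the language_model call inside this loop, which is why the generator had to be adapted alongside the quantizer.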

Related Issue

Fixes #(issue number)

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

When quantized with HQQ 4-bit and combined with torch.compile, we can obtain a ~3x speedup.
I have added 2 unit tests and provided (below) a notebook for reproducing the results.
Edit: The notebook works well for torch==2.5.1. For torch==2.7, major changes have been introduced in torch dynamo, so we have to slightly adapt the smash_config:
smash_config['torch_compile_fullgraph'] = False
smash_config['torch_compile_mode'] = 'default'
smash_config['torch_compile_backend'] = 'inductor'
Otherwise the compilation step takes very long and is reapplied at each generation step with torch==2.7.
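For reference, the three torch==2.7 settings collected in one runnable snippet (a plain dict stands in for Pruna's SmashConfig object here; the keys are taken from the PR description):

```python
# torch.compile settings that avoid per-step recompilation with
# torch==2.7. A plain dict is used for illustration; in practice
# these keys would be set on Pruna's SmashConfig.
smash_config = {}
smash_config['torch_compile_fullgraph'] = False
smash_config['torch_compile_mode'] = 'default'
smash_config['torch_compile_backend'] = 'inductor'
print(smash_config)
```

Disabling fullgraph lets dynamo fall back to graph breaks instead of retracing the whole generate loop on every step.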

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

  • A notebook for reproducing results is provided here;
  • The current Janus models lag behind the current SOTA flow-matching models. They are more comparable to the first versions of Midjourney and DALL·E, as you can see in this blog.

@llcnt llcnt requested review from johnrachwan123 and sharpenb May 23, 2025 15:13
@llcnt llcnt marked this pull request as ready for review May 23, 2025 15:17
@llcnt llcnt force-pushed the feat/llamagen_ar_janus_support branch from fdf809e to 1a6c0a5 Compare June 2, 2025 08:14
Member

@sharpenb sharpenb left a comment


Very cool to have it working for Janus! The PR description is clear. Could we add a small benchmark for this model? The main points are about how to improve the code factorization to avoid redundant work in the future.

@llcnt llcnt requested review from gsprochette and removed request for johnrachwan123 June 3, 2025 08:29
Collaborator

@gsprochette gsprochette left a comment


Awesome work, can't wait to have this working for Janus in Pruna! I also hope that the HF PR adding generate functions to LlamaGen models will land soon so we can remove the Generator code. For the rest of the code I left a few comments: nothing breaking, just trying to improve already good code and add some double checks here and there :)

@llcnt llcnt force-pushed the feat/llamagen_ar_janus_support branch from 1a6c0a5 to 92b69ab Compare June 17, 2025 16:34
@llcnt llcnt requested review from gsprochette and sharpenb June 17, 2025 16:34
cursor[bot]

This comment was marked as outdated.

@llcnt llcnt force-pushed the feat/llamagen_ar_janus_support branch from 2421573 to 4d8daad Compare June 18, 2025 09:29

@cursor cursor bot left a comment


Bug: Model Loading Fails Due to Incorrect `lm_head` Handling

When loading Janus-like models, the load_hqq_model function attempts to add a dummy lm_head. This process introduces several issues:

  1. The lm_head is created with hardcoded (1024, 1024) dimensions, which may not match the model's actual hidden size, leading to shape mismatches.
  2. Its randomly initialized weights are incorrectly added to the model's state dictionary as a nested dictionary under the "lm_head" key, instead of flattening its parameters (e.g., "lm_head.weight", "lm_head.bias"). This will cause loading failures.
  3. The function loads qmodel.pt, modifies its contents with these problematic lm_head weights, and then overwrites the original file in-place. This is an unexpected and potentially harmful side effect for a load operation.

src/pruna/engine/load.py#L359-L365

https://github.com/PrunaAI/pruna/blob/4d8daadf1d2cdbcd6a9ae71f37e274e2e1dfb33f/src/pruna/engine/load.py#L359-L365
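Point 2 above — state-dict keys must be flat, dotted paths rather than nested dicts — can be illustrated with plain Python dicts (lists stand in for tensors, and flatten_module_weights is a hypothetical helper, not Pruna code):

```python
# A nested entry like {"lm_head": {"weight": ..., "bias": ...}} is not
# a valid state-dict layout; loaders expect flat dotted keys such as
# "lm_head.weight". Plain lists stand in for tensors here.

def flatten_module_weights(prefix, weights):
    """Turn {"weight": w, "bias": b} into {"<prefix>.weight": w, ...}."""
    return {f"{prefix}.{name}": value for name, value in weights.items()}

nested = {"weight": [[0.0]], "bias": [0.0]}
flat = flatten_module_weights("lm_head", nested)
print(sorted(flat))  # → ['lm_head.bias', 'lm_head.weight']
```

The dimensions of any dummy lm_head should likewise be read from the model config rather than hardcoded, and the fix should avoid rewriting qmodel.pt in place during a load.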


Collaborator

@gsprochette gsprochette left a comment


Looks good to me; there's just this expected_quantized_model_path that could be re-used to make the code a bit cleaner, see comment. If Bertrand is satisfied with your answers, this is ready to go for me. Thanks for the updates :)

Member

@sharpenb sharpenb left a comment


Thanks for the details! I left some comments. Happy to discuss them if needed :)

@llcnt llcnt force-pushed the feat/llamagen_ar_janus_support branch from 4d8daad to a9d3d84 Compare June 30, 2025 13:30
@llcnt
Collaborator Author

llcnt commented Jul 1, 2025

I have edited the main comment, and I am fixing the code (it will appear in the next push) to enable speedups with torch==2.7.

@llcnt llcnt force-pushed the feat/llamagen_ar_janus_support branch from 35bb705 to 3341a90 Compare July 2, 2025 17:10
@llcnt llcnt requested review from gsprochette and sharpenb July 2, 2025 17:17
Member


Steps 3 and 6 both prepare inputs in some way. Could we factorize these into one step?

Collaborator Author


Nice catch! Indeed, step 6 only depends on the preceding step 3; I merged them by moving step 6 into the sub-function "self.prepare_inputs_tokens" :)

Member

@sharpenb sharpenb left a comment


Thanks for tackling the comments! I left a couple more, but I think it should be good to go then.

Member


It is unclear why step 5 prepares logits processors while step 4 already used the user-passed processor. Could we also merge all processor-preparation steps?

Collaborator Author


You are definitely right ;) I merged all processing related to logits_processor into a single sub-function.

(PS: it is maybe unclear from the name self.model._get_logits_processor, but this function uses the transformers helper that merges the user-defined logits_processor with the logits_processor defined in the generation_config. I have added a small comment in the function to highlight this.)
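What this merging amounts to can be sketched without transformers (the function names below are illustrative, not the actual transformers API): the config-derived processors and the user-passed ones are combined into one ordered list and applied to the logits in sequence.

```python
# Dependency-free sketch of merging logits processors. transformers'
# LogitsProcessorList behaves like a plain list of callables applied
# in order; the names here are illustrative stand-ins.

def merge_processors(user_processors, config_processors):
    # Config-derived processors run first; later processors see the
    # output of earlier ones.
    return list(config_processors) + list(user_processors)

def apply_processors(processors, logits):
    for proc in processors:
        logits = proc(logits)
    return logits

halve = lambda logits: [x / 2 for x in logits]
add_one = lambda logits: [x + 1 for x in logits]

merged = merge_processors([add_one], [halve])
print(apply_processors(merged, [4.0, 8.0]))  # → [3.0, 5.0]
```

Collapsing all of this into one sub-function keeps the generate loop itself free of processor bookkeeping.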

@llcnt llcnt force-pushed the feat/llamagen_ar_janus_support branch from ae2d1f8 to 7e6603f Compare July 8, 2025 13:19
@llcnt llcnt merged commit ec79dcd into main Jul 8, 2025
6 checks passed
@johannaSommer johannaSommer deleted the feat/llamagen_ar_janus_support branch July 9, 2025 07:21