Skip to content

Rework Model Context#323

Merged
simlang merged 9 commits intomainfrom
refactor/rework-model-context
Aug 28, 2025
Merged

Rework Model Context#323
simlang merged 9 commits intomainfrom
refactor/rework-model-context

Conversation

@simlang
Copy link
Copy Markdown
Member

@simlang simlang commented Aug 28, 2025

Description

This PR refactors the ModelContext abstraction
Before the ModelContext used the incoming pipeline as a storage, for the smashed model, now the context itself is returned to get resources from.
Using pipeline as storage device, lead to problems when e.g. using the combination hqq+torch.compile

The changes include:

  1. instead of providing: pipeline, working model and denoiser_type, only provide model_context and working_model (denoiser_type was only used by one algorithm and can now be accessed via model_context.denoiser_type if needed)
  2. at end of smashing, set_smashed_working_model(smashed_model) has to be called to tell the context, that the working model has changed
  3. on context exit, the internal state of the pipeline/model given to the context is updated, but only if set_smashed_working_model has been called before - otherwise nothing happens. This allows us to use the context also, if the working model is not adapted
  4. since the model given to the context might be immutable (if it's the working model and not a pipeline) we can't directly change it - this means to get the updated pipeline/model we have to call model_context.get_smashed() to return after smashing

Related Issue

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Locally ran all tests for algorithms which use ModelContext

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

There might be a point of moving the ModelContext to the apply wrapper, to avoid duplicate code

@simlang simlang added the bug Something isn't working label Aug 28, 2025
@simlang simlang requested review from gsprochette and llcnt August 28, 2025 12:24
Comment thread src/pruna/engine/utils.py Outdated
Comment thread src/pruna/engine/utils.py Outdated
Comment thread src/pruna/engine/utils.py Outdated
Comment on lines +583 to +603
def set_smashed_working_model(self, working_model: Any) -> None:
"""
Set the smashed working model.

Parameters
----------
working_model : Any
The smashed working model.
"""
self.smashed_working_model = working_model

def get_smashed(self) -> "ModelMixin":
"""
Get the smashed model.

Returns
-------
ModelMixin
The smashed model.
"""
return self.smashed_pipeline
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason for using this instead of mc.smashed_working_model = working_model and mc.smashed_pipeline from outside?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This self.smashed_working_model = working_model was super cryptic to me, when i first read it before. So i add functions to name what is happening - so just readability

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in all cases the user needs to know what they are doing with the ModelContext, and these setter and getter are adding complexity... Using these variables could be explained within an error raised in the __exit__ if smashed_working_model wasn't set.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would rename the functions, as with the current names there is not really a difference to just assigning and reading.
i like having this abstraction however, without ever seeing the inside of this context it is hard to understand from the outside what is happening

Comment thread src/pruna/engine/utils.py Outdated
Comment thread src/pruna/engine/utils.py Outdated
Comment thread src/pruna/engine/utils.py Outdated
Comment thread src/pruna/engine/save.py Outdated
Comment thread src/pruna/engine/model_checks.py Outdated
Copy link
Copy Markdown
Collaborator

@gsprochette gsprochette left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job fixing this ! We're using a context in a weird way so we should spend a minute making it extra clear and extra clean, all my comments go in that direction :)

Copy link
Copy Markdown
Collaborator

@llcnt llcnt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for the fix, it is much more clean :)

Comment thread src/pruna/engine/utils.py


class ModelContext:
class ModelContext(AbstractContextManager):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this AbstractContextManager provide us ? Any reason why we want our mc to depend on it? :)

Copy link
Copy Markdown
Member Author

@simlang simlang Aug 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbh nothing - just readability that this is a ContextManager - should i remove it?

Comment thread src/pruna/engine/utils.py Outdated
@simlang simlang requested a review from gsprochette August 28, 2025 15:00
cursor[bot]

This comment was marked as outdated.

Copy link
Copy Markdown
Collaborator

@gsprochette gsprochette left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's almost ready to go, I left a couple of comments in the read_only check because it can still be improved. Once this is done, you can merge :)

Comment thread src/pruna/engine/utils.py
"""
if self.smashed_working_model is None:
return
if self.read_only:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also check if self.read_only and self.smashed_working_model is not None because this is bound to produce a cryptic bug

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also good point!

Comment thread src/pruna/engine/utils.py Outdated
Comment thread src/pruna/engine/utils.py
@simlang simlang merged commit 179c3b2 into main Aug 28, 2025
7 checks passed
@simlang simlang deleted the refactor/rework-model-context branch August 28, 2025 15:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants