
feat: enhance model checks for transformers pipelines#281

Merged
davidberenstein1957 merged 5 commits into main from
feat/276-feature-add-transformerspipeline-support-to-smash
Aug 8, 2025

Conversation

davidberenstein1957 (Member) commented Jul 26, 2025

Description

  • Added functions to identify transformers pipelines for causal language models, sequence-to-sequence models, and speech recognition.
  • Updated various quantization and compilation algorithms to utilize the new model checks, allowing for better integration with transformers pipelines.
  • Introduced new test fixtures for transformers pipelines in the test suite to ensure comprehensive coverage.
Example: smashing the same text-generation pipeline with two different algorithms, then saving and reloading it:

from transformers import pipeline
from pruna import smash, PrunaModel, SmashConfig

model_name = "HuggingFaceTB/SmolLM2-360M-Instruct"
for algo_group, algo_name in [("compiler", "torch_compile"), ("quantizer", "hqq")]:
    pipe = pipeline(
        task="text-generation",
        model=model_name,
    )
    print(pipe("Hello, how are you?"))  # baseline output before smashing
    smash_config = SmashConfig(device="cuda")
    smash_config[algo_group] = algo_name
    smashed_pipe = smash(pipe, smash_config=smash_config)
    # Save and reload to verify the smashed pipeline round-trips through disk.
    smashed_pipe.save_pretrained("test_model")
    smashed_pipe = PrunaModel.from_pretrained("test_model")
    print(smashed_pipe("Hello, how are you?", max_new_tokens=100))
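The model checks mentioned in the description amount to task-based dispatch over the pipeline's task string. A minimal sketch of that idea (the function and category names here are illustrative assumptions, not the actual pruna internals):

```python
# Hypothetical sketch: classify a transformers pipeline by its task string.
# The task names are standard transformers tasks; everything else is illustrative.
CAUSAL_LM_TASKS = {"text-generation"}
SEQ2SEQ_TASKS = {"text2text-generation", "summarization", "translation"}
SPEECH_RECOGNITION_TASKS = {"automatic-speech-recognition"}


def classify_pipeline_task(task):
    """Return the model family a pipeline task belongs to, or None if unknown."""
    if task in CAUSAL_LM_TASKS:
        return "causal_lm"
    if task in SEQ2SEQ_TASKS:
        return "seq2seq_lm"
    if task in SPEECH_RECOGNITION_TASKS:
        return "speech_recognition"
    return None
```

An algorithm can then gate itself on the returned family rather than inspecting model classes directly.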

Related Issue

Fixes #276

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

from transformers import pipeline

from pruna import SmashConfig, smash

# CPU-only example using torchao quantization on a chat-style pipeline.
pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-360M-Instruct", device="cpu")
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)  # baseline output before smashing
config = SmashConfig(device="cpu")
config["quantizer"] = "torchao"
smashed_pipe = smash(pipe, config)
print(smashed_pipe(messages, max_new_tokens=100))

davidberenstein1957 linked an issue on Jul 26, 2025 that may be closed by this pull request
davidberenstein1957 requested review from begumcig and johannaSommer, and removed the request for johannaSommer, on July 28, 2025 07:31
davidberenstein1957 (Member, Author) commented:

@begumcig do we need to do anything special with the inference handler here? Not as far as I could see but I would love to hear your thoughts :)

johannaSommer (Member) left a comment:


Love it!!! Just a few questions 🙌

davidberenstein1957 force-pushed the feat/276-feature-add-transformerspipeline-support-to-smash branch from 9602323 to 389c900 on August 1, 2025 16:34
llcnt (Collaborator) left a comment:


So nice to support LM pipelines :)
I only have one question regarding save/load functionality: did you re-run the GPU tests, or try to save and re-load a pipeline smashed with hqq? (I wonder if we need to add the save_pipeline_info function inside hqq's custom save function.)
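For context, the concern here is that hqq takes a custom save path, which could skip the pipeline metadata. A rough sketch of what persisting and restoring that metadata could look like (purely illustrative; save_pipeline_info and the file layout are assumptions, not pruna's actual code):

```python
import json
from pathlib import Path


def save_pipeline_info(save_dir, task):
    """Illustrative: persist the pipeline task so loading can rebuild the pipeline."""
    path = Path(save_dir)
    path.mkdir(parents=True, exist_ok=True)
    (path / "pipeline_info.json").write_text(json.dumps({"task": task}))


def load_pipeline_info(save_dir):
    """Illustrative counterpart used at load time."""
    return json.loads((Path(save_dir) / "pipeline_info.json").read_text())
```

Any algorithm with a custom save routine would need to call the equivalent of save_pipeline_info itself, which is the gap being pointed out.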

johannaSommer (Member) left a comment:


Wonderful, looks good from my side! 🙌


- Added new functions to check for transformers pipelines with causal and seq2seq language models.
- Updated existing quantization and compilation algorithms to utilize the new pipeline checks.
- Enhanced the `PrunaAlgorithmBase` class to apply algorithms directly to transformers pipelines.
- Updated tests and fixtures to include transformers pipeline scenarios for better coverage.
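Conceptually, "applying an algorithm to a transformers pipeline" means unwrapping the inner model, transforming it, and putting it back, as in this rough sketch (class and attribute checks are assumptions for illustration, not the real PrunaAlgorithmBase):

```python
class AlgorithmSketch:
    """Illustrative only: apply an algorithm to a bare model or to the model
    wrapped inside a transformers pipeline."""

    def apply(self, model_or_pipeline):
        # Duck-typed check: transformers pipelines expose .task and .model.
        if hasattr(model_or_pipeline, "task") and hasattr(model_or_pipeline, "model"):
            # Transform the wrapped model in place and return the pipeline,
            # so callers keep working with the pipeline object.
            model_or_pipeline.model = self._apply_to_model(model_or_pipeline.model)
            return model_or_pipeline
        return self._apply_to_model(model_or_pipeline)

    def _apply_to_model(self, model):
        # Real algorithms would quantize/compile here; this sketch just tags it.
        model.smashed = True
        return model
```

This keeps the pipeline's tokenizer, task, and pre/post-processing intact while only the underlying model is modified.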
…ers pipelines

- Renamed `_apply_to_model_within_pipeline` to `_apply_to_model_within_transformers_pipeline` for consistency across various algorithms.
- Updated references in quantization and compilation classes to use the new method name.
- Enhanced test fixtures to support a more generalized method for obtaining transformers pipelines for specific tasks.
…ts and signature

- Updated the `track_usage` decorator to accept a callable directly, improving type safety with TypeVar.
- Enhanced the wrapper function to extract `smash_config` from the function's signature, ensuring better handling of arguments.
- Preserved the original function's type hints and signature for improved compatibility and clarity.
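The decorator change described above might be sketched as follows (a simplified illustration under assumed names; the real track_usage does more than record the config):

```python
import functools
import inspect
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def track_usage(func: F) -> F:
    """Illustrative sketch: find the smash_config argument of any call while
    preserving the wrapped function's signature via functools.wraps."""
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Bind the call to the original signature so smash_config is found
        # whether it was passed positionally or by keyword.
        bound = sig.bind_partial(*args, **kwargs)
        wrapper.last_smash_config = bound.arguments.get("smash_config")
        return func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]


@track_usage
def smash(model: Any, smash_config: Any = None) -> Any:
    return model
```

Binding against the original signature avoids the fragile "is it args[1] or kwargs?" argument juggling, which is the clarity win the commit describes.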
- Updated the `load_hqq` function to ensure proper handling of model paths and added support for loading tokenizer alongside the model.
- Enhanced the `save_model_hqq` function to save the tokenizer when saving a transformers pipeline model.
- Improved code clarity and maintainability by ensuring consistent handling of model and tokenizer paths.
davidberenstein1957 force-pushed the feat/276-feature-add-transformerspipeline-support-to-smash branch from 8eaf9ad to 308adf2 on August 7, 2025 16:48
davidberenstein1957 (Member, Author) commented:

@llcnt you were right. I fixed it now.

davidberenstein1957 requested a review from sharpenb, and removed the request for begumcig, on August 7, 2025 16:51

…luation

- Updated `save_model_hqq` to check for the presence of a tokenizer before saving, preventing potential errors.
- Modified model loading in `EnvironmentalImpactStats` to ensure the save path is converted to a string, enhancing compatibility.
- Refactored the `track_usage` decorator to streamline argument handling and improve clarity in function usage tracking.
llcnt (Collaborator) left a comment:


Thx for the hqq fix :) LGTM!

davidberenstein1957 merged commit e5ef1b7 into main on Aug 8, 2025
7 checks passed

Development

Successfully merging this pull request may close these issues.

[FEATURE] add transformers.pipeline support to smash
