diff --git a/src/pruna/engine/hf_hub_utils/model_card_template.md b/src/pruna/engine/hf_hub_utils/model_card_template.md index 6ed7a818..c6f89771 100644 --- a/src/pruna/engine/hf_hub_utils/model_card_template.md +++ b/src/pruna/engine/hf_hub_utils/model_card_template.md @@ -1,7 +1,7 @@ --- library_name: {library_name} tags: -- pruna-ai +- {pruna_library}-ai --- # Model Card for {repo_id} @@ -13,7 +13,7 @@ This model was created using the [pruna](https://github.com/PrunaAI/pruna) libra First things first, you need to install the pruna library: ```bash -pip install pruna +pip install {pruna_library} ``` You can [use the {library_name} library to load the model](https://huggingface.co/{repo_id}?library={library_name}) but this might not include all optimizations by default. @@ -21,9 +21,9 @@ You can [use the {library_name} library to load the model](https://huggingface.c To ensure that all optimizations are applied, use the pruna library to load the model using the following code: ```python -from pruna import PrunaModel +from {pruna_library} import {pruna_model_class} -loaded_model = PrunaModel.from_hub( +loaded_model = {pruna_model_class}.from_hub( "{repo_id}" ) ``` @@ -44,4 +44,4 @@ The compression configuration of the model is stored in the `smash_config.json` [![GitHub](https://img.shields.io/github/followers/PrunaAI?label=Follow%20%40PrunaAI&style=social)](https://github.com/PrunaAI) [![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://www.linkedin.com/company/93832878/admin/feed/posts/?feedType=following) [![Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?style=social&logo=discord)](https://discord.com/invite/rskEr4BZJx) -[![Reddit](https://img.shields.io/reddit/subreddit-subscribers/PrunaAI?style=social)](https://www.reddit.com/r/PrunaAI/) \ No newline at end of file +[![Reddit](https://img.shields.io/reddit/subreddit-subscribers/PrunaAI?style=social)](https://www.reddit.com/r/PrunaAI/) diff --git 
a/src/pruna/engine/pruna_model.py b/src/pruna/engine/pruna_model.py index 022c5594..07d445ab 100644 --- a/src/pruna/engine/pruna_model.py +++ b/src/pruna/engine/pruna_model.py @@ -225,6 +225,7 @@ def save_to_hub( The number of steps to print the report of the saved model. """ save_pruna_model_to_hub( + instance=self, model=self.model, smash_config=self.smash_config, repo_id=repo_id, diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index 49be74f1..ec40d0b3 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -20,13 +20,13 @@ from enum import Enum from functools import partial from pathlib import Path -from typing import Any, List +from typing import TYPE_CHECKING, Any, List import torch import transformers from huggingface_hub import upload_large_folder -from pruna.config.smash_config import SMASH_CONFIG_FILE_NAME, SmashConfig +from pruna.config.smash_config import SMASH_CONFIG_FILE_NAME from pruna.engine.load import ( LOAD_FUNCTIONS, PICKLED_FILE_NAME, @@ -37,6 +37,10 @@ from pruna.engine.utils import determine_dtype from pruna.logging.logger import pruna_logger +if TYPE_CHECKING: + from pruna.config.smash_config import SmashConfig + from pruna.engine.pruna_model import PrunaModel + def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None: """ @@ -90,8 +94,9 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf def save_pruna_model_to_hub( + instance: "PrunaModel | Any", model: Any, - smash_config: SmashConfig, + smash_config: "SmashConfig | Any", repo_id: str, model_path: str | Path | None = None, *, @@ -104,13 +109,15 @@ def save_pruna_model_to_hub( print_report_every: int = 60, ) -> None: """ - Save the model to the specified directory. + Save the model to the Hugging Face Hub. Parameters ---------- + instance : PrunaModel | Any + The PrunaModel instance to save. model : Any The model to save. 
- smash_config : SmashConfig + smash_config : SmashConfig | Any The SmashConfig object containing the save and load functions. repo_id : str The repository ID. @@ -155,10 +162,13 @@ def save_pruna_model_to_hub( # Format the content for the README using the template and the loaded configuration data template_path = Path(__file__).parent / "hf_hub_utils" / "model_card_template.md" template = template_path.read_text() + pruna_library = instance.__module__.split(".")[0] content = template.format( repo_id=repo_id, smash_config=json.dumps(smash_config_data, indent=4), library_name=library_name, + pruna_model_class=instance.__class__.__name__, + pruna_library=pruna_library, ) # Define the path for the README file and write the formatted content to it @@ -305,6 +315,7 @@ def save_model_hqq(model: Any, model_path: str | Path, smash_config: SmashConfig The SmashConfig object containing the save and load functions. """ from pruna.algorithms.quantization.hqq import HQQQuantizer + algorithm_packages = HQQQuantizer().import_algorithm_packages() if isinstance(model, algorithm_packages["HQQModelForCausalLM"]): diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py index 831e5c58..60c00318 100644 --- a/tests/engine/test_save.py +++ b/tests/engine/test_save.py @@ -12,6 +12,7 @@ from pruna.engine.load import load_pruna_model from pruna.config.smash_config import SmashConfig from diffusers import DiffusionPipeline +from pruna.engine.pruna_model import PrunaModel @@ -137,9 +138,11 @@ def test_save_to_hub_path_types(tmp_path) -> None: config = SmashConfig() string_path = str(tmp_path / "string_test") pathlib_path = Path(tmp_path / "pathlib_test") + pruna_model = PrunaModel(dummy_model, config) with patch('pruna.engine.save.upload_large_folder') as mock_upload: save_pruna_model_to_hub( + instance=pruna_model, model=dummy_model, smash_config=config, repo_id="test/repo", @@ -151,6 +154,7 @@ def test_save_to_hub_path_types(tmp_path) 
-> None: mock_upload.reset_mock() save_pruna_model_to_hub( + instance=pruna_model, model=dummy_model, smash_config=config, repo_id="test/repo2", @@ -158,4 +162,3 @@ def test_save_to_hub_path_types(tmp_path) -> None: private=True ) assert mock_upload.called -