Skip to content
Merged
10 changes: 5 additions & 5 deletions src/pruna/engine/hf_hub_utils/model_card_template.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
---
library_name: {library_name}
tags:
- pruna-ai
- {pruna_library}-ai
---

# Model Card for {repo_id}
Expand All @@ -13,17 +13,17 @@ This model was created using the [pruna](https://github.com/PrunaAI/pruna) libra
First things first, you need to install the pruna library:

```bash
pip install pruna
pip install {pruna_library}
```

You can [use the {library_name} library to load the model](https://huggingface.co/{repo_id}?library={library_name}) but this might not include all optimizations by default.

To ensure that all optimizations are applied, load the model with the pruna library using the following code:

```python
from pruna import PrunaModel
from {pruna_library} import {pruna_model_class}

loaded_model = PrunaModel.from_hub(
loaded_model = {pruna_model_class}.from_hub(
"{repo_id}"
)
```
Expand All @@ -44,4 +44,4 @@ The compression configuration of the model is stored in the `smash_config.json`
[![GitHub](https://img.shields.io/github/followers/PrunaAI?label=Follow%20%40PrunaAI&style=social)](https://github.com/PrunaAI)
[![LinkedIn](https://img.shields.io/badge/LinkedIn-Connect-blue)](https://www.linkedin.com/company/93832878/)
[![Discord](https://img.shields.io/badge/Discord-Join%20Us-blue?style=social&logo=discord)](https://discord.com/invite/rskEr4BZJx)
[![Reddit](https://img.shields.io/reddit/subreddit-subscribers/PrunaAI?style=social)](https://www.reddit.com/r/PrunaAI/)
[![Reddit](https://img.shields.io/reddit/subreddit-subscribers/PrunaAI?style=social)](https://www.reddit.com/r/PrunaAI/)
1 change: 1 addition & 0 deletions src/pruna/engine/pruna_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def save_to_hub(
The number of steps to print the report of the saved model.
"""
save_pruna_model_to_hub(
instance=self,
model=self.model,
smash_config=self.smash_config,
repo_id=repo_id,
Expand Down
21 changes: 16 additions & 5 deletions src/pruna/engine/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, List
from typing import TYPE_CHECKING, Any, List

import torch
import transformers
from huggingface_hub import upload_large_folder

from pruna.config.smash_config import SMASH_CONFIG_FILE_NAME, SmashConfig
from pruna.config.smash_config import SMASH_CONFIG_FILE_NAME
from pruna.engine.load import (
LOAD_FUNCTIONS,
PICKLED_FILE_NAME,
Expand All @@ -37,6 +37,10 @@
from pruna.engine.utils import determine_dtype
from pruna.logging.logger import pruna_logger

if TYPE_CHECKING:
from pruna.config.smash_config import SmashConfig
from pruna.engine.pruna_model import PrunaModel


def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:
"""
Expand Down Expand Up @@ -90,8 +94,9 @@ def save_pruna_model(model: Any, model_path: str | Path, smash_config: SmashConf


def save_pruna_model_to_hub(
instance: "PrunaModel" | Any,
model: Any,
smash_config: SmashConfig,
smash_config: "SmashConfig" | Any,
repo_id: str,
model_path: str | Path | None = None,
*,
Expand All @@ -104,13 +109,15 @@ def save_pruna_model_to_hub(
print_report_every: int = 60,
) -> None:
"""
Save the model to the specified directory.
Save the model to the Hugging Face Hub.

Parameters
----------
instance : PrunaModel | Any
The PrunaModel instance to save.
model : Any
Comment thread
davidberenstein1957 marked this conversation as resolved.
Outdated
The model to save.
smash_config : SmashConfig
smash_config : Union[SmashConfig, Any]
The SmashConfig object containing the save and load functions.
repo_id : str
The repository ID.
Expand Down Expand Up @@ -155,10 +162,13 @@ def save_pruna_model_to_hub(
# Format the content for the README using the template and the loaded configuration data
template_path = Path(__file__).parent / "hf_hub_utils" / "model_card_template.md"
template = template_path.read_text()
pruna_library = instance.__module__.split(".")[0] if "." in instance.__module__ else None
content = template.format(
repo_id=repo_id,
smash_config=json.dumps(smash_config_data, indent=4),
library_name=library_name,
pruna_model_class=instance.__class__.__name__,
pruna_library=pruna_library,
)

# Define the path for the README file and write the formatted content to it
Expand Down Expand Up @@ -305,6 +315,7 @@ def save_model_hqq(model: Any, model_path: str | Path, smash_config: SmashConfig
The SmashConfig object containing the save and load functions.
"""
from pruna.algorithms.quantization.hqq import HQQQuantizer

algorithm_packages = HQQQuantizer().import_algorithm_packages()

if isinstance(model, algorithm_packages["HQQModelForCausalLM"]):
Expand Down
5 changes: 4 additions & 1 deletion tests/engine/test_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pruna.engine.load import load_pruna_model
from pruna.config.smash_config import SmashConfig
from diffusers import DiffusionPipeline
from pruna.engine.pruna_model import PrunaModel



Expand Down Expand Up @@ -137,9 +138,11 @@ def test_save_to_hub_path_types(tmp_path) -> None:
config = SmashConfig()
string_path = str(tmp_path / "string_test")
pathlib_path = Path(tmp_path / "pathlib_test")
pruna_model = PrunaModel(dummy_model, config)

with patch('pruna.engine.save.upload_large_folder') as mock_upload:
save_pruna_model_to_hub(
instance=pruna_model,
model=dummy_model,
smash_config=config,
repo_id="test/repo",
Expand All @@ -151,11 +154,11 @@ def test_save_to_hub_path_types(tmp_path) -> None:
mock_upload.reset_mock()

save_pruna_model_to_hub(
instance=pruna_model,
model=dummy_model,
smash_config=config,
repo_id="test/repo2",
model_path=pathlib_path,
private=True
)
assert mock_upload.called

Loading