Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ First, load any pre-trained model. Here's an example using Stable Diffusion:

```python
from diffusers import StableDiffusionPipeline
base_model = StableDiffusionPipeline.from_pretrained("segmind/Segmind-Vega")
base_model = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
Comment thread
johannaSommer marked this conversation as resolved.
```

Then, use Pruna's `smash` function to optimize your model. Pruna provides a variety of different optimization algorithms, allowing you to combine different algorithms to get the best possible results. You can customize the optimization process using `SmashConfig`:
Expand Down
876 changes: 427 additions & 449 deletions docs/tutorials/image_generation.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/user_manual/configure.rst
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ Underneath you can find the list of all the available datasets.
- ``image_classification_collate``
- ``image: PIL.Image.Image``, ``label: int``
* - Audio Processing
- `CommonVoice <https://huggingface.co/datasets/mozilla-foundation/common_voice_1_0>`_, `AIPodcast <https://huggingface.co/datasets/reach-vb/random-audios>`_
- `CommonVoice <https://huggingface.co/datasets/mozilla-foundation/common_voice_1_0>`_, `AIPodcast <https://huggingface.co/datasets/reach-vb/random-audios/blob/main/sam_altman_lex_podcast_367.flac>`_, `MiniPresentation <https://huggingface.co/datasets/reach-vb/random-audios/blob/main/4469669-10.mp3>`_
- ``audio_processing_collate``
- ``audio: Optional[torch.Tensor]``, ``path: Optional[str]``, ``sentence: str``
* - Question Answering
Expand Down
4 changes: 2 additions & 2 deletions docs/user_manual/smash.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ Example 3: Speech Recognition Optimization
optimized_model = smash(model=model, smash_config=smash_config)

# Download and transcribe audio sample
audio_url = "https://huggingface.co/datasets/reach-vb/random-audios/resolve/main/sam_altman_lex_podcast_367.flac"
audio_file = "sam_altman_lex_podcast_367.flac"
audio_url = "https://huggingface.co/datasets/reach-vb/random-audios/resolve/main/4469669-10.mp3"
audio_file = "4469669-10.mp3"

# Download audio file
response = requests.get(audio_url)
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/batching/ws2t.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
processor = processor.backend_tokenizer
else:
processor = Tokenizer.from_pretrained(processor.tokenizer.name_or_path)
processor.save(Path(model.output_dir) / "tokenizer.json")
processor.save(str(Path(model.output_dir) / "tokenizer.json"))
else:
pruna_logger.error("Please pass a Huggingface Whisper Processor.")

Expand Down
7 changes: 6 additions & 1 deletion src/pruna/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

from typing import Any, Callable, Tuple

from pruna.data.datasets.audio import setup_commonvoice_dataset, setup_podcast_dataset
from pruna.data.datasets.audio import (
setup_commonvoice_dataset,
setup_mini_presentation_audio_dataset,
setup_podcast_dataset,
)
from pruna.data.datasets.image import (
setup_cifar10_dataset,
setup_imagenet_dataset,
Expand All @@ -40,6 +44,7 @@
"LAION256": (setup_laion256_dataset, "image_generation_collate", {"img_size": 512}),
"CommonVoice": (setup_commonvoice_dataset, "audio_collate", {}),
"AIPodcast": (setup_podcast_dataset, "audio_collate", {}),
"MiniPresentation": (setup_mini_presentation_audio_dataset, "audio_collate", {}),
"ImageNet": (setup_imagenet_dataset, "image_classification_collate", {"img_size": 224}),
"MNIST": (setup_mnist_dataset, "image_classification_collate", {"img_size": 28}),
"WikiText": (setup_wikitext_dataset, "text_generation_collate", {}),
Expand Down
20 changes: 19 additions & 1 deletion src/pruna/data/datasets/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,31 @@ def setup_podcast_dataset() -> Tuple[Dataset, Dataset, Dataset]:
Tuple[Dataset, Dataset, Dataset]
The AI Podcast dataset.
"""
return _download_audio_and_select_sample("sam_altman_lex_podcast_367.flac")


def setup_mini_presentation_audio_dataset() -> Tuple[Dataset, Dataset, Dataset]:
    """
    Set up the Mini Presentation audio dataset.

    Downloads the "reach-vb/random-audios" repository and selects the
    single-sample presentation recording ("4469669-10.mp3").

    License: unspecified

    Returns
    -------
    Tuple[Dataset, Dataset, Dataset]
        The Mini Presentation dataset (same single sample used for the
        train/validation/test splits).
    """
    return _download_audio_and_select_sample("4469669-10.mp3")


def _download_audio_and_select_sample(file_id: str) -> Tuple[Dataset, Dataset, Dataset]:
load_dataset("reach-vb/random-audios", trust_remote_code=True)
cache_path = os.environ.get("HUGGINGFACE_HUB_CACHE")
if cache_path is None:
cache_path = str(Path.home() / ".cache" / "huggingface" / "hub")

dataset_path = Path(cache_path) / "datasets--reach-vb--random-audios"
path_to_podcast_file = str(list(dataset_path.rglob("sam_altman_lex_podcast_367.flac"))[0])
path_to_podcast_file = str(list(dataset_path.rglob(file_id))[0])

ds = Dataset.from_dict({"audio": [{"path": path_to_podcast_file}], "sentence": [""]})
pruna_logger.info(
Expand Down
11 changes: 8 additions & 3 deletions src/pruna/engine/pruna_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from pruna.engine.handler.handler_utils import register_inference_handler
from pruna.engine.load import load_pruna_model, load_pruna_model_from_pretrained
from pruna.engine.save import save_pruna_model, save_pruna_model_to_hub
from pruna.engine.utils import get_nn_modules, move_to_device, set_to_best_available_device, set_to_eval
from pruna.engine.utils import get_device, get_nn_modules, move_to_device, set_to_eval
from pruna.logging.filter import apply_warning_filter
from pruna.telemetry import increment_counter, track_usage

Expand Down Expand Up @@ -91,8 +91,13 @@ def run_inference(self, batch: Any, device: torch.device | str | None = None) ->
Any
The processed output.
"""
device = set_to_best_available_device(device)
batch = self.inference_handler.move_inputs_to_device(batch, device) # type: ignore
if self.model is None:
raise ValueError("No more model available, this model is likely destroyed.")

# Rather than giving a device to the inference call,
# we should run the inference on the device of the model.
model_device = get_device(self.model)
batch = self.inference_handler.move_inputs_to_device(batch, model_device)
Comment thread
johannaSommer marked this conversation as resolved.
Outdated

if not isinstance(batch, tuple):
batch = (batch, {})
Expand Down
1 change: 0 additions & 1 deletion tests/algorithms/test_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None:
[
("sd_tiny_random", dict(cacher="deepcache", compiler="stable_fast"), False, 'cmmd'),
("mobilenet_v2", dict(pruner="torch_unstructured", quantizer="half"), True, 'latency'),
("whisper_tiny_random", dict(batcher="whisper_s2t", compiler="c_whisper"), False, 'latency'),
("sd_tiny_random", dict(quantizer="hqq_diffusers", compiler="torch_compile"), False, 'cmmd'),
("flux_tiny_random", dict(quantizer="hqq_diffusers", compiler="torch_compile"), False, 'cmmd'),
("sd_tiny_random", dict(quantizer="diffusers_int8", compiler="torch_compile"), False, 'cmmd'),
Expand Down
4 changes: 2 additions & 2 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ def dataloader_fixture(request: pytest.FixtureRequest) -> Any:

def whisper_tiny_random_model() -> tuple[Any, SmashConfig]:
"""Whisper tiny random model for speech recognition."""
model_id = "yujiepan/whisper-v3-tiny-random"
model_id = "PrunaAI/whisper-v3-tiny-random"
model = pipeline(
"automatic-speech-recognition",
model=model_id,
torch_dtype=torch.float16,
device_map="cpu",
)
smash_config = SmashConfig()
smash_config.add_data("AIPodcast")
smash_config.add_data("MiniPresentation")
smash_config.add_tokenizer(model_id)
smash_config.add_processor(model_id)
return model, smash_config
Expand Down