diff --git a/docs/user_manual/configure.rst b/docs/user_manual/configure.rst index 092fa5f5..2609a7ab 100644 --- a/docs/user_manual/configure.rst +++ b/docs/user_manual/configure.rst @@ -269,9 +269,9 @@ Underneath you can find the list of all the available datasets. - ``text_generation_collate`` - ``text: str`` * - Image Generation - - `LAION256 `_, `OpenImage `_, `COCO `_ - - ``image_generation_collate`` - - ``image: PIL.Image.Image``, ``text: str`` + - `LAION256 `_, `OpenImage `_, `COCO `_, `DrawBench `_, `PartiPrompts `_, `GenAIBench `_ + - ``image_generation_collate``, ``prompt_collate`` + - ``text: str``, ``image: Optional[PIL.Image.Image]`` * - Image Classification - `ImageNet `_, `MNIST `_, `CIFAR10 `_ - ``image_classification_collate`` diff --git a/src/pruna/data/__init__.py b/src/pruna/data/__init__.py index d2e369ef..e2c6266b 100644 --- a/src/pruna/data/__init__.py +++ b/src/pruna/data/__init__.py @@ -24,6 +24,11 @@ setup_imagenet_dataset, setup_mnist_dataset, ) +from pruna.data.datasets.prompt import ( + setup_drawbench_dataset, + setup_genai_bench_dataset, + setup_parti_prompts_dataset, +) from pruna.data.datasets.question_answering import setup_polyglot_dataset from pruna.data.datasets.text_generation import ( setup_c4_dataset, @@ -56,4 +61,7 @@ "Polyglot": (setup_polyglot_dataset, "question_answering_collate", {}), "OpenImage": (setup_open_image_dataset, "image_generation_collate", {"img_size": 1024}), "CIFAR10": (setup_cifar10_dataset, "image_classification_collate", {"img_size": 32}), + "DrawBench": (setup_drawbench_dataset, "prompt_collate", {}), + "PartiPrompts": (setup_parti_prompts_dataset, "prompt_collate", {}), + "GenAIBench": (setup_genai_bench_dataset, "prompt_collate", {}), } diff --git a/src/pruna/data/collate.py b/src/pruna/data/collate.py index 939ca119..8cb91b69 100644 --- a/src/pruna/data/collate.py +++ b/src/pruna/data/collate.py @@ -92,6 +92,25 @@ def image_generation_collate(data: Any, img_size: int, output_format: str = "int return texts, images_tensor +def prompt_collate(data: Any) -> Tuple[List[str], None]: + """ + Custom collation function for prompt datasets. + + Expects a ``text`` column containing the clear-text prompt in the dataset. + + Parameters + ---------- + data : Any + The data to collate. + + Returns + ------- + Tuple[List[str], None] + The collated data. + """ + return [item["text"] for item in data], None + + def audio_collate(data: Any) -> Tuple[List[str], List[str]]: """ Custom collation function for audio datasets. @@ -226,4 +245,5 @@ def question_answering_collate( "image_classification_collate": image_classification_collate, "text_generation_collate": text_generation_collate, "question_answering_collate": question_answering_collate, + "prompt_collate": prompt_collate, } diff --git a/src/pruna/data/datasets/prompt.py b/src/pruna/data/datasets/prompt.py new file mode 100644 index 00000000..c97d8d62 --- /dev/null +++ b/src/pruna/data/datasets/prompt.py @@ -0,0 +1,85 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +from datasets import Dataset, load_dataset + +from pruna.logging.logger import pruna_logger + + +def setup_drawbench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the DrawBench dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The DrawBench dataset. + """ + ds = load_dataset("sayakpaul/drawbench", trust_remote_code=True)["train"] + ds = ds.rename_column("Prompts", "text") + pruna_logger.info("DrawBench is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +def setup_parti_prompts_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the Parti Prompts dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The Parti Prompts dataset. + """ + ds = load_dataset("nateraw/parti-prompts")["train"] + ds = ds.rename_column("Prompt", "text") + pruna_logger.info("PartiPrompts is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds + + +def setup_genai_bench_dataset(seed: int) -> Tuple[Dataset, Dataset, Dataset]: + """ + Setup the GenAI Bench dataset. + + License: Apache 2.0 + + Parameters + ---------- + seed : int + The seed to use. + + Returns + ------- + Tuple[Dataset, Dataset, Dataset] + The GenAI Bench dataset. + """ + ds = load_dataset("BaiqiL/GenAI-Bench")["train"] + ds = ds.rename_column("Prompt", "text") + pruna_logger.info("GenAI-Bench is a test-only dataset. Do not use it for training or validation.") + return ds.select([0]), ds.select([0]), ds diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index 1473f0ac..40a609ab 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -36,6 +36,9 @@ def iterate_dataloaders(datamodule: PrunaDataModule) -> None: pytest.param("OpenAssistant", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("C4", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), pytest.param("Polyglot", dict(tokenizer=bert_tokenizer), marks=pytest.mark.slow), + pytest.param("DrawBench", dict(), marks=pytest.mark.slow), + pytest.param("PartiPrompts", dict(), marks=pytest.mark.slow), + pytest.param("GenAIBench", dict(), marks=pytest.mark.slow), ], ) def test_dm_from_string(dataset_name: str, collate_fn_args: dict[str, Any]) -> None: