46 changes: 30 additions & 16 deletions src/pruna/algorithms/compilation/torch_compile.py
@@ -20,13 +20,14 @@
 from ConfigSpace import CategoricalHyperparameter, OrdinalHyperparameter

 from pruna.algorithms.compilation import PrunaCompiler
-from pruna.algorithms.compilation.utils import TransformersGenerator
+from pruna.algorithms.compilation.utils import CausalLMGenerator, JanusGenerator
 from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
 from pruna.config.smash_space import Boolean
 from pruna.engine.model_checks import (
     get_diffusers_transformer_models,
     get_diffusers_unet_models,
     is_causal_lm,
+    is_janus_llamagen_ar,
     is_opt_model,
 )
 from pruna.engine.save import SAVE_FUNCTIONS
@@ -205,8 +206,9 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         ):
             return unet_transformer_pipeline_logic(model, smash_config)

-        if is_causal_lm(model):
-            return causal_lm_logic(model, smash_config)
+        if is_causal_lm(model) or is_janus_llamagen_ar(model):
+            return causal_lm_or_janus_logic(model, smash_config)
+
         return compile_callable(model, smash_config)

     def import_algorithm_packages(self) -> Dict[str, Any]:
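
For context, a rough sketch of how this dispatch would be reached through pruna's public smash API. The config key names and the model choice below are illustrative assumptions, not taken from this PR:

    # Hypothetical usage sketch -- config key names are assumptions, not from this PR.
    from pruna import SmashConfig, smash
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

    smash_config = SmashConfig()
    smash_config["compiler"] = "torch_compile"

    # With this change, a Janus LlamaGen AR model passed here is routed to
    # causal_lm_or_janus_logic via is_janus_llamagen_ar instead of falling
    # through to the generic compile_callable path.
    smashed_model = smash(model=model, smash_config=smash_config)
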
@@ -381,9 +383,9 @@ def unet_transformer_pipeline_logic(model: Any, smash_config: SmashConfigPrefixW
     return model


-def causal_lm_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
+def causal_lm_or_janus_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
     """
-    Apply compilation logic for causal language models.
+    Apply compilation logic for causal language models or Janus LlamaGen AR models.

     Parameters
     ----------
@@ -407,17 +409,29 @@ def causal_lm_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
     # https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig.temperature
     temperature = 1.0

-    # We use a generator as in https://github.com/mobiusml/hqq/blob/1f052eb5a0aab0572d380d48b708ae1c74936d23/hqq/utils/generation_hf.py
-    gen = TransformersGenerator(
-        model,
-        max_kv_cache_size=smash_config["max_kv_cache_size"],
-        temperature=temperature,
-        top_k=top_k,
-        compile_mode=smash_config["mode"],
-        compile_fullgraph=smash_config["fullgraph"],
-        batch_size=smash_config.batch_size,
-        device=smash_config.device,
-    )
+    if is_causal_lm(model):
+        # We use a generator as in https://github.com/mobiusml/hqq/blob/1f052eb5a0aab0572d380d48b708ae1c74936d23/hqq/utils/generation_hf.py
+        gen = CausalLMGenerator(
+            model,
+            max_kv_cache_size=smash_config["max_kv_cache_size"],
+            temperature=temperature,
+            top_k=top_k,
+            compile_mode=smash_config["mode"],
+            compile_fullgraph=smash_config["fullgraph"],
+            batch_size=smash_config.batch_size,
+            device=smash_config.device,
+        )
+    elif is_janus_llamagen_ar(model):
+        gen = JanusGenerator(  # type: ignore
+            model,
+            temperature=temperature,
+            top_k=top_k,
+            compile_mode=smash_config["mode"],
+            compile_fullgraph=smash_config["fullgraph"],
+        )
+    else:
+        raise ValueError(f"Model {model} is not a causal language model or Janus LlamaGen AR model.")
+
     # If we are using max-autotune-no-cudagraphs, we need to handle the cudagraphs manually.
     if smash_config["mode"] == "max-autotune-no-cudagraphs":
         pruna_logger.error("max-autotune-no-cudagraphs is not supported for causal language models.")
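
To make the generator pattern concrete, here is a minimal, self-contained sketch of the idea behind the hqq generator linked in the diff: compile only the per-token forward pass and keep the sampling loop in eager Python. Everything below (the class name, greedy decoding, the absence of attention-mask and static-cache handling) is an illustrative simplification under those assumptions, not pruna's actual CausalLMGenerator:

    import torch

    # Hypothetical sketch -- not pruna's CausalLMGenerator.
    class CompiledGreedyGenerator:
        def __init__(self, model, compile_mode="max-autotune", compile_fullgraph=True):
            self.model = model.eval()
            # Compile only the forward pass; the Python sampling loop stays eager,
            # which avoids graph breaks from data-dependent control flow.
            self.compiled_forward = torch.compile(
                model.forward, mode=compile_mode, fullgraph=compile_fullgraph
            )

        @torch.no_grad()
        def generate(self, input_ids, max_new_tokens=32):
            # Prefill: run the full prompt once eagerly to build the KV cache.
            out = self.model(input_ids, use_cache=True)
            next_token = out.logits[:, -1:].argmax(dim=-1)
            generated = [next_token]
            past = out.past_key_values
            # Decode: one token at a time through the compiled forward.
            for _ in range(max_new_tokens - 1):
                out = self.compiled_forward(next_token, past_key_values=past, use_cache=True)
                past = out.past_key_values
                next_token = out.logits[:, -1:].argmax(dim=-1)
                generated.append(next_token)
            return torch.cat([input_ids, *generated], dim=-1)

A production version, like the linked hqq generator, additionally pins a static KV cache and a fixed batch size so that compiled shapes stay stable across decode steps; that is presumably why CausalLMGenerator takes max_kv_cache_size and batch_size here, while JanusGenerator is constructed without them.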