diff --git a/src/pruna/algorithms/compilation/torch_compile.py b/src/pruna/algorithms/compilation/torch_compile.py index 83d192ab..03ce7c4b 100644 --- a/src/pruna/algorithms/compilation/torch_compile.py +++ b/src/pruna/algorithms/compilation/torch_compile.py @@ -20,13 +20,14 @@ from ConfigSpace import CategoricalHyperparameter, OrdinalHyperparameter from pruna.algorithms.compilation import PrunaCompiler -from pruna.algorithms.compilation.utils import TransformersGenerator +from pruna.algorithms.compilation.utils import CausalLMGenerator, JanusGenerator from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper from pruna.config.smash_space import Boolean from pruna.engine.model_checks import ( get_diffusers_transformer_models, get_diffusers_unet_models, is_causal_lm, + is_janus_llamagen_ar, is_opt_model, ) from pruna.engine.save import SAVE_FUNCTIONS @@ -205,8 +206,9 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: ): return unet_transformer_pipeline_logic(model, smash_config) - if is_causal_lm(model): - return causal_lm_logic(model, smash_config) + if is_causal_lm(model) or is_janus_llamagen_ar(model): + return causal_lm_or_janus_logic(model, smash_config) + return compile_callable(model, smash_config) def import_algorithm_packages(self) -> Dict[str, Any]: @@ -381,9 +383,9 @@ def unet_transformer_pipeline_logic(model: Any, smash_config: SmashConfigPrefixW return model -def causal_lm_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: +def causal_lm_or_janus_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ - Apply compilation logic for causal language models. + Apply compilation logic for causal language models or Janus LlamaGen AR models. 
Parameters ---------- @@ -407,17 +409,29 @@ def causal_lm_logic(model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: # https://huggingface.co/docs/transformers/en/main_classes/text_generation#transformers.GenerationConfig.temperature temperature = 1.0 - # We use a generator as in https://github.com/mobiusml/hqq/blob/1f052eb5a0aab0572d380d48b708ae1c74936d23/hqq/utils/generation_hf.py - gen = TransformersGenerator( - model, - max_kv_cache_size=smash_config["max_kv_cache_size"], - temperature=temperature, - top_k=top_k, - compile_mode=smash_config["mode"], - compile_fullgraph=smash_config["fullgraph"], - batch_size=smash_config.batch_size, - device=smash_config.device, - ) + if is_causal_lm(model): + # We use a generator as in https://github.com/mobiusml/hqq/blob/1f052eb5a0aab0572d380d48b708ae1c74936d23/hqq/utils/generation_hf.py + gen = CausalLMGenerator( + model, + max_kv_cache_size=smash_config["max_kv_cache_size"], + temperature=temperature, + top_k=top_k, + compile_mode=smash_config["mode"], + compile_fullgraph=smash_config["fullgraph"], + batch_size=smash_config.batch_size, + device=smash_config.device, + ) + elif is_janus_llamagen_ar(model): + gen = JanusGenerator( # type: ignore + model, + temperature=temperature, + top_k=top_k, + compile_mode=smash_config["mode"], + compile_fullgraph=smash_config["fullgraph"], + ) + else: + raise ValueError(f"Model {model} is not a causal language model or Janus LlamaGen AR model.") + # If we are using max-autotune-no-cudagraphs, we need to handle the cudagraphs manually. 
if smash_config["mode"] == "max-autotune-no-cudagraphs": pruna_logger.error("max-autotune-no-cudagraphs is not supported for causal language models.") diff --git a/src/pruna/algorithms/compilation/utils.py b/src/pruna/algorithms/compilation/utils.py index e8e5090a..24464b50 100644 --- a/src/pruna/algorithms/compilation/utils.py +++ b/src/pruna/algorithms/compilation/utils.py @@ -14,15 +14,19 @@ from __future__ import annotations import contextlib +import copy +from typing import Optional import torch from torch.nn.attention import SDPBackend, sdpa_kernel -from transformers.cache_utils import StaticCache +from transformers.cache_utils import Cache, StaticCache +from transformers.generation import ClassifierFreeGuidanceLogitsProcessor, GenerationMode, LogitsProcessorList +from transformers.generation.utils import GenerateDecoderOnlyOutput from pruna.logging.logger import pruna_logger -class TransformersGenerator: +class CausalLMGenerator: """ A class for generating text using a Hugging Face model, and using torch.compile. @@ -61,7 +65,7 @@ def __init__( device: str = "cuda", ): """ - Initialize the TransformersGenerator. + Initialize the CausalLMGenerator. Parameters ---------- @@ -509,7 +513,7 @@ def next_token_iterator( return output_tokens - @torch.inference_mode() + @torch.no_grad() def generate(self, *args, **kwargs) -> torch.Tensor: """ Generate tokens using the model. @@ -550,3 +554,522 @@ def generate(self, *args, **kwargs) -> torch.Tensor: max_new_tokens=kwargs["max_new_tokens"] if "max_new_tokens" in kwargs else args[1], ) return self.next_token_iterator(self.prefill(), kwargs["max_new_tokens"]) + + +class JanusGenerator: + """ + A class for generating images using a Janus model, and using torch.compile. + + The code is adapted from https://github.com/huggingface/transformers/blob/4542086db764080c4333beef7b9f4327b4f8ff64/src/transformers/models/janus/modular_janus.py#L1147. 
+ + Parameters + ---------- + model : transformers.PreTrainedModel + The Hugging Face model to use for image generation. + temperature : float, default=0.6 + The sampling temperature to use for image generation. Higher values increase randomness. + top_k : int, default=5 + The number of highest probability vocabulary tokens to keep for top-k filtering. + compile_mode : str, default='reduce-overhead' + The compilation mode to use with torch.compile(). Options include 'reduce-overhead', 'max-autotune', etc. + compile_fullgraph : bool, default=True + Whether to compile the full computation graph or use partial graph compilation. + compile_backend : str, default='inductor' + The backend to use for compilation. Options include 'inductor', 'cudagraphs', etc. + """ + + def __init__( + self, + model, + temperature: float = 0.6, + top_k: int = 5, + compile_mode: str = "reduce-overhead", + compile_fullgraph: bool = True, + compile_backend: str = "inductor", + ): + """ + Initialize the JanusGenerator. + + Parameters + ---------- + model : transformers.PreTrainedModel + The Hugging Face model to use for image generation. + temperature : float + The sampling temperature to use for image generation. Higher values increase randomness. + top_k : int + The number of highest probability vocabulary tokens to keep for top-k filtering. + compile_mode : str + The compilation mode to use with torch.compile(). Options include 'reduce-overhead', 'max-autotune', etc. + compile_fullgraph : bool + Whether to compile the full computation graph or use partial graph compilation. + compile_backend : str + The backend to use for compilation. Options include 'inductor', 'cudagraphs', etc. 
+ + Returns + ------- + None + """ + super().__init__() + + self.model = model + self.temperature = temperature + self.top_k = top_k + self.compile_mode = compile_mode + self.compile_fullgraph = compile_fullgraph + self.compile_backend = compile_backend + + self.compiled_language_model = torch.compile( + self.model.model.language_model, + mode=self.compile_mode, + fullgraph=self.compile_fullgraph, + backend=self.compile_backend, + ) + + self.model.eval() + + def validate_config_and_model_kwargs(self, generation_config, model_kwargs): + """ + Validate the generation config and model kwargs. + + This function is adapted from the `_validate_model_kwargs` function in the `transformers` library. + + Parameters + ---------- + generation_config : GenerationConfig + The generation config. + model_kwargs : dict + The model kwargs. + """ + generation_config.validate() + self.model._validate_model_kwargs(model_kwargs.copy()) + + def prepare_logits_processor(self, generation_config, input_ids, device, logits_processor): + """ + Prepare (and merge) the logits processor. + + Parameters + ---------- + generation_config : GenerationConfig + The generation config. + input_ids : torch.Tensor + The input token ids that serve as the prompt. + device : torch.device + The device to use for the input tokens. + logits_processor : LogitsProcessorList | None + The logits processor for the input tokens. + + Returns + ------- + LogitsProcessorList + The logits processor. + """ + # Initialize logit processors + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + + # Add CFG processor along with user passed logit processor. + if generation_config.guidance_scale and generation_config.guidance_scale > 1: + logits_processor.append(ClassifierFreeGuidanceLogitsProcessor(generation_config.guidance_scale)) + generation_config.guidance_scale = None # Reset to prevent processor duplication. 
+ + # Prepare and merge logits processor + logits_processor = self.model._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids.shape[1], + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=None, + logits_processor=logits_processor, + device=device, + ) + return logits_processor + + def prepare_inputs_tokens(self, inputs, generation_config, model_kwargs, attention_mask): + """ + Check inputs shapes, and setup special tokens and model kwargs. + + Parameters + ---------- + inputs : torch.Tensor + The input tokens. + generation_config : GenerationConfig + The generation config. + model_kwargs : dict + The model kwargs. + attention_mask : torch.Tensor | None + The attention mask. + + Returns + ------- + tuple[torch.Tensor, dict, torch.dtype, torch.device] + The input ids, model kwargs, dtype, and device. + """ + input_ids, _, model_kwargs = self.model._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs + ) + dtype, device = input_ids.dtype, input_ids.device + + if len(input_ids.shape) != 2: + raise ValueError( + f"Expected input ids of shape (batch_size, seq_len), but got {input_ids.shape}" + "Passing `inputs embeds` is not supported currently." + ) + + # Prepare special tokens which will be used generate internally. + kwargs_has_attention_mask = attention_mask is not None + self.model._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=input_ids.device) + + # Expand inputs for multiple image generations per prompt. + input_ids, model_kwargs = self.model._expand_inputs_for_generation( + input_ids=input_ids, + attention_mask=attention_mask, + expand_size=generation_config.num_return_sequences, + **model_kwargs, + ) + + return input_ids, model_kwargs, dtype, device + + def get_initial_cache_position(self, input_ids, model_kwargs): + """ + Get the initial cache position for the model. 
+ + This function is adapted from the `get_initial_cache_position` function in the `transformers` library. + + Parameters + ---------- + input_ids : torch.Tensor + The input token ids that serve as the prompt. + model_kwargs : dict + The model kwargs. + + Returns + ------- + dict + The model kwargs with the initial cache position. + """ + # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` + if "inputs_embeds" in model_kwargs: + cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + elif "decoder_inputs_embeds" in model_kwargs: + cache_position = ( + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + ) + else: + cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 + + past_length = 0 + if model_kwargs.get("past_key_values") is not None: + cache = model_kwargs["past_key_values"] + past_length = 0 + if not isinstance(cache, Cache): + past_length = cache[0][0].shape[2] + elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: + past_length = cache.get_seq_length() + + cache_position = cache_position[past_length:] + + model_kwargs["cache_position"] = cache_position + return model_kwargs + + def prepare_input_and_cache(self, input_ids, model_kwargs, attention_mask, generation_config, device): + """ + Setup input tokens, mask and cache. + + Prepare the input tokens, inputs embeddings, model kwargs, batch size, the number of image tokens, + and setup the KV cache. + + Parameters + ---------- + input_ids : torch.Tensor + The input token ids that serve as the prompt. + model_kwargs : dict + The model kwargs. + attention_mask : torch.Tensor | None + The attention mask. + generation_config : GenerationConfig + The generation config. + device : torch.device + The device to use for the input tokens. 
+ + Returns + ------- + tuple[torch.Tensor, torch.Tensor, dict, int, int] + The input tokens, inputs embeddings, model kwargs, batch size, and number of image tokens. + """ + num_image_tokens = self.model.model.vision_model.config.num_image_tokens + batch_size, seq_len = input_ids.shape + + input_tokens = input_ids.repeat(2, 1) # Double batch size for conditional/unconditional logits + attention_mask = model_kwargs.pop("attention_mask", None) + attention_mask = attention_mask.repeat(2, 1) # type: ignore + model_kwargs["attention_mask"] = attention_mask + + # Mask all the tokens that are neither BOS nor BOI with pad token in the unconditional logits. + mask = (input_tokens[batch_size:, :] != generation_config.bos_token_id) & ( + input_tokens[batch_size:, :] != generation_config.generation_kwargs["boi_token_id"] + ) + input_tokens[batch_size:, :].masked_fill_(mask, generation_config.pad_token_id) + + inputs_embeds = self.model.get_input_embeddings()(input_tokens) + + model_kwargs = self.get_initial_cache_position(input_ids, model_kwargs) + + if model_kwargs.get("past_key_values", None) is None: + # Prepare cache if not provided. + model_kwargs["past_key_values"] = self.model._get_cache( + cache_implementation=generation_config.cache_implementation or "static", + # batch_size should account for both conditional/unconditional input; hence multiplied by 2. + batch_size=batch_size * 2, + # we should have at least a cache len of seq_len + num_image_tokens. + max_cache_len=max(generation_config.max_length, num_image_tokens + seq_len), + device=device, + model_kwargs=model_kwargs, + ) + + return input_tokens, inputs_embeds, model_kwargs, batch_size, num_image_tokens + + def loop_over_latent_tokens( + self, + input_tokens, + input_ids, + model_kwargs, + num_image_tokens, + output_attentions, + output_hidden_states, + inputs_embeds, + generated_tokens, + logits_processor, + generation_config, + ): + """ + Loop over the latent tokens. 
+ + Parameters + ---------- + input_tokens : torch.Tensor + The input token ids that serve as the prompt. + input_ids : torch.Tensor + The input token ids that serve as the prompt. + model_kwargs : dict + The model kwargs. + num_image_tokens : int + The number of image tokens. + output_attentions : bool + Whether to output attentions. + output_hidden_states : bool + Whether to output hidden states. + inputs_embeds : torch.Tensor + The input embeddings. + generated_tokens : torch.Tensor + The generated tokens. + logits_processor : LogitsProcessorList + The logits processor. + generation_config : GenerationConfig + The generation config. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + The scores, hidden state, and outputs. + """ + for i in range(num_image_tokens): + model_inputs = self.model.prepare_inputs_for_generation( + inputs_embeds=inputs_embeds, input_ids=input_tokens, **model_kwargs + ) + + model_inputs["attention_mask"] = model_inputs["attention_mask"].to(inputs_embeds.device) + model_inputs["cache_position"] = model_inputs["cache_position"].to(inputs_embeds.device) + + # Pad attention mask to max length to avoid dynamic shapes error during compilation. + max_length = model_inputs["past_key_values"].get_max_cache_shape() + current_length = model_inputs["attention_mask"].shape[1] + if current_length < max_length: + padding = torch.zeros( + (model_inputs["attention_mask"].shape[0], max_length - current_length), + dtype=model_inputs["attention_mask"].dtype, + device=model_inputs["attention_mask"].device, + ) + model_inputs["attention_mask"] = torch.cat([model_inputs["attention_mask"], padding], dim=1) + + # no compilation for the prefill. + if i == 0: + outputs = self.model.model.language_model( + **model_inputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + # compilation for the decoding phase (one token at a time). 
+ else: + outputs = self.compiled_language_model( + **model_inputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + # Update model_kwargs like cache_position for next generation. + model_kwargs = self.model._update_model_kwargs_for_generation(outputs, model_kwargs) + hidden_state = outputs.last_hidden_state[:, -1, :].clone() + + # Generate scores using the generation head (Not using above defined LM Head) + scores = self.model.model.generation_head(hidden_state) + next_token_scores = logits_processor(input_ids, scores) if logits_processor is not None else scores + + # Sample next token. + if generation_config.do_sample: + probs = torch.softmax(next_token_scores, dim=-1) + next_token = torch.multinomial(probs, num_samples=1).squeeze(-1) + else: + next_token = torch.argmax(next_token_scores, dim=-1) + + generated_tokens[:, i] = next_token + + # Prepare embeddings for the next step. + next_token = torch.cat([next_token, next_token]) + next_token = next_token.unsqueeze(-1) + + inputs_embeds = self.model.prepare_embeddings_for_image_generation(next_token) + + return scores, hidden_state, outputs + + @torch.no_grad() + def generate( + self, + inputs: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + logits_processor: Optional[LogitsProcessorList] = None, + **kwargs, + ) -> torch.Tensor | GenerateDecoderOnlyOutput: + """ + Generate latent tokens using the model. + + Parameters + ---------- + inputs : torch.Tensor | None + The input token ids that serve as the prompt. + attention_mask : torch.LongTensor | None + The attention mask for the input tokens. + logits_processor : LogitsProcessorList | None + The logits processor for the input tokens. + **kwargs : dict + Keyword arguments dictionary. + + Returns + ------- + torch.Tensor | GenerateDecoderOnlyOutput + The generated latent tokens. 
+ """ + # Extract parameters from kwargs with defaults from instance variables + self.temperature = kwargs.pop("temperature", self.temperature) + self.top_k = kwargs.pop("top_k", self.top_k) + + # 1. Handle generation config and model kwargs + generation_config = kwargs.pop("generation_config", self.model.generation_config) + generation_config = copy.deepcopy(generation_config) + + # Default to "text" generation if mode isn't provided + generation_mode = kwargs.pop("generation_mode", "text") + if generation_mode == "text": + # Set guidance_scale=None to prevent running UnbatchedCFG processor. + return self.model.generate( + inputs=inputs, + attention_mask=attention_mask, + generation_config=generation_config, + guidance_scale=None, + **kwargs, + ) + + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + + # Validate generation mode + if generation_config.get_generation_mode() not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH): + raise ValueError( + "Got incompatible mode for Image Generation, should be one of greedy or sampling. " + "Ensure that beam search is de-activated by setting `num_beams=1` and `num_beam_groups=1`." + ) + + # Validate the configuration and model kwargs + self.validate_config_and_model_kwargs(generation_config, model_kwargs) + + # Set `use_cache=True` as we will be using input embeds for generation. + model_kwargs["use_cache"] = True + + # Check if guidance_scale is provided. + if generation_config.guidance_scale is None: + pruna_logger.warning("`guidance_scale` is required for CFG but not provided. Setting to default value of 5.") + generation_config.guidance_scale = 5 + model_kwargs["guidance_scale"] = generation_config.guidance_scale + + # 2. Prepare model inputs shapes, and check special tokens. + input_ids, model_kwargs, dtype, device = self.prepare_inputs_tokens( + inputs, generation_config, model_kwargs, attention_mask + ) + + # 3. 
Prepare logits processor + logits_processor = self.prepare_logits_processor(generation_config, input_ids, device, logits_processor) + + # 4. Prepare input and model caches + input_tokens, inputs_embeds, model_kwargs, batch_size, num_image_tokens = self.prepare_input_and_cache( + input_ids, + model_kwargs, + attention_mask, + generation_config, + device, + ) + + # Placeholder for generated tokens. + generated_tokens = torch.zeros((batch_size, num_image_tokens), dtype=dtype, device=device) + + # 5. init attention / hidden states / scores tuples + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + + raw_scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + + # 6. Loop over the latent tokens. + scores, hidden_state, outputs = self.loop_over_latent_tokens( + input_tokens, + input_ids, + model_kwargs, + num_image_tokens, + output_attentions, + output_hidden_states, + inputs_embeds, + generated_tokens, + logits_processor, + generation_config, + ) + + # 7. Return the results. 
+ if return_dict_in_generate: + if output_scores: + raw_scores = tuple(raw_scores) + (scores,) if raw_scores is not None else (scores,) + if output_logits: + raw_logits = ( + tuple(raw_logits) + (hidden_state.float(),) if raw_logits is not None else (hidden_state.float(),) + ) + if output_attentions: + decoder_attentions = ( + tuple(decoder_attentions) + (outputs.attentions,) + if decoder_attentions is not None + else (outputs.attentions,) + ) + if output_hidden_states: + decoder_hidden_states = ( + tuple(decoder_hidden_states) + (outputs.hidden_states,) + if decoder_hidden_states is not None + else (outputs.hidden_states,) + ) + return GenerateDecoderOnlyOutput( + sequences=generated_tokens, # type: ignore + scores=scores, # type: ignore + logits=raw_logits, # type: ignore + attentions=decoder_attentions, # type: ignore + hidden_states=decoder_hidden_states, # type: ignore + past_key_values=outputs.past_key_values, + ) + else: + return generated_tokens diff --git a/src/pruna/algorithms/quantization/hqq.py b/src/pruna/algorithms/quantization/hqq.py index 24913e2c..2dd9834a 100644 --- a/src/pruna/algorithms/quantization/hqq.py +++ b/src/pruna/algorithms/quantization/hqq.py @@ -22,9 +22,9 @@ from pruna.algorithms.quantization import PrunaQuantizer from pruna.config.smash_config import SmashConfigPrefixWrapper -from pruna.engine.model_checks import is_causal_lm +from pruna.engine.model_checks import is_causal_lm, is_janus_llamagen_ar from pruna.engine.save import SAVE_FUNCTIONS -from pruna.engine.utils import move_to_device, safe_memory_cleanup +from pruna.engine.utils import ModelContext, move_to_device, safe_memory_cleanup from pruna.logging.filter import SuppressOutput from pruna.logging.logger import pruna_logger @@ -92,9 +92,9 @@ def model_check_fn(self, model: Any) -> bool: Returns ------- bool - True if the model is a causal language model, False otherwise. + True if the model is a causal language model or a Janus LlamaGen AR model, False otherwise. 
""" - return is_causal_lm(model) + return is_causal_lm(model) or is_janus_llamagen_ar(model) def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: """ @@ -121,39 +121,45 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: quant_config_hf = imported_modules["HqqConfig"](nbits=weight_quantization_bits, group_size=group_size) move_to_device(model, "cpu") safe_memory_cleanup() - try: # Try to quantize the model using HQQ - model = imported_modules["AutoHQQHFModel"].quantize_model( - model, - quant_config=quant_config_hqq, - device=smash_config["device"], - compute_dtype=torch.float16 if smash_config["compute_dtype"] == "torch.float16" else torch.bfloat16, - ) - except Exception: # Default to generic HF quantization if it fails - pruna_logger.info("Could not quantize model using specialized HQQ pipeline, trying generic interface...") - # Create a temporary directory in a specific location - base_temp_dir = smash_config["cache_dir"] - temp_dir = tempfile.mkdtemp(dir=base_temp_dir) - model.save_pretrained(temp_dir) - - model = AutoModelForCausalLM.from_pretrained( - temp_dir, - quantization_config=quant_config_hf, - trust_remote_code=True, - device_map="auto", - torch_dtype=torch.float16 if smash_config["compute_dtype"] == "torch.float16" else torch.bfloat16, - ) - - # Delete the temporary directory and its contents - shutil.rmtree(temp_dir) - - # Prepare the model for fast inference - try: - if weight_quantization_bits == 4: - imported_modules["prepare_for_inference"](model, backend=smash_config["backend"]) - except Exception as e: - pruna_logger.error(f"Error: {e}") - pass - return model + with ModelContext(model) as (pipeline, working_model, denoiser_type): + try: # Try to quantize the model using HQQ + working_model = imported_modules["AutoHQQHFModel"].quantize_model( + working_model, + quant_config=quant_config_hqq, + device=smash_config["device"], + compute_dtype=torch.float16 if smash_config["compute_dtype"] == 
"torch.float16" else torch.bfloat16, + ) + except Exception: # Default to generic HF quantization if it fails + pruna_logger.info("Could not quantize model using specialized HQQ pipeline, trying generic interface...") + # Create a temporary directory in a specific location + base_temp_dir = smash_config["cache_dir"] + temp_dir = tempfile.mkdtemp(dir=base_temp_dir) + working_model.save_pretrained(temp_dir) + + working_model = AutoModelForCausalLM.from_pretrained( + temp_dir, + quantization_config=quant_config_hf, + trust_remote_code=True, + device_map="auto", + torch_dtype=torch.float16 if smash_config["compute_dtype"] == "torch.float16" else torch.bfloat16, + ) + + # Delete the temporary directory and its contents + shutil.rmtree(temp_dir) + + # Prepare the model for fast inference + try: + if weight_quantization_bits == 4: + imported_modules["prepare_for_inference"](working_model, backend=smash_config["backend"]) + except Exception as e: + pruna_logger.error(f"Error: {e}") + pass + # redefining the working_model breaks links with context manager + # so we need to re-define the working_model as an attribute of the model. + pipeline.working_model = working_model + # as we have moved the model to cpu for cleaning, but only one of its attribute was put back on cuda. 
+ move_to_device(model, smash_config["device"]) + return model def import_algorithm_packages(self) -> Dict[str, Any]: """ diff --git a/src/pruna/engine/load.py b/src/pruna/engine/load.py index 9760a7e8..bc633a95 100644 --- a/src/pruna/engine/load.py +++ b/src/pruna/engine/load.py @@ -353,21 +353,53 @@ def load_hqq(model_path: str | Path, smash_config: SmashConfig, **kwargs) -> Any algorithm_packages = HQQQuantizer().import_algorithm_packages() + # if the model is a janus like model, we need to load the quantized model from the hqq_language_model directory + if os.path.exists(os.path.join(model_path, "hqq_language_model")): + quantized_path = str(os.path.join(model_path, "hqq_language_model")) + quantized_model_path = os.path.join(quantized_path, "qmodel.pt") + # load the weight on cpu to rename attr -> model.attr, + # and also artificially add a random lm_head to the weights. + weights = torch.load(quantized_model_path, map_location="cpu", weights_only=True) + weights = {f"model.{k}" if not k.startswith("model.") else k: v for k, v in weights.items()} + weights["lm_head"] = torch.nn.Linear(1024, 1024).state_dict() + # hqq expects the qmodel.pt file to be in the quantized_path directory. 
+ torch.save(weights, quantized_model_path) + else: + quantized_path = str(model_path) + try: # Try to use pipeline for HF specific HQQ quantization - model = algorithm_packages["HQQModelForCausalLM"].from_quantized( - model_path, + quantized_model = algorithm_packages["HQQModelForCausalLM"].from_quantized( + quantized_path, device=smash_config.device, **filter_load_kwargs(algorithm_packages["HQQModelForCausalLM"].from_quantized, kwargs), ) except Exception: # Default to generic HQQ pipeline if it fails pruna_logger.info("Could not load HQQ model using pipeline, trying generic HQQ pipeline...") - model = algorithm_packages["AutoHQQHFModel"].from_quantized( - model_path, + if "compute_dtype" in kwargs: + compute_dtype = kwargs.pop("compute_dtype") + else: + saved_smash_config = SmashConfig() + saved_smash_config.load_from_json(model_path) + compute_dtype = ( + torch.float16 if saved_smash_config["hqq_compute_dtype"] == "torch.float16" else torch.bfloat16 + ) + quantized_model = algorithm_packages["AutoHQQHFModel"].from_quantized( + quantized_path, device=smash_config.device, + compute_dtype=compute_dtype, **filter_load_kwargs(algorithm_packages["AutoHQQHFModel"].from_quantized, kwargs), ) - return model + original_config = load_json_config(model_path, "config.json") + if original_config["architectures"][0] == "JanusForConditionalGeneration": + cls = getattr(transformers, "JanusForConditionalGeneration") + model = cls.from_pretrained(model_path, **kwargs) + model.model.language_model = quantized_model.model + # some weights of the language_model are not on the correct device, so we move it afterwards. 
+ move_to_device(model, smash_config.device) + return model + else: + return quantized_model def load_torch_artifacts(model_path: str | Path, **kwargs) -> None: diff --git a/src/pruna/engine/model_checks.py b/src/pruna/engine/model_checks.py index a758bf62..240de4d1 100644 --- a/src/pruna/engine/model_checks.py +++ b/src/pruna/engine/model_checks.py @@ -570,3 +570,20 @@ def is_opt_model(model: Any) -> bool: """ opt_mapping = {k: v for k, v in MODEL_FOR_CAUSAL_LM_MAPPING.items() if "opt" in str(k).lower()} return isinstance(model, tuple(opt_mapping.values())) + + +def is_janus_llamagen_ar(model: Any) -> bool: + """ + Check if the model is a Janus LlamaGen AR model. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is a Janus LlamaGen AR model, False otherwise. + """ + return model.__class__.__name__ == "JanusForConditionalGeneration" diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py index ec40d0b3..41bc80c1 100644 --- a/src/pruna/engine/save.py +++ b/src/pruna/engine/save.py @@ -13,6 +13,7 @@ # limitations under the License. 
from __future__ import annotations +import copy import json import os import shutil @@ -34,7 +35,7 @@ SAVE_BEFORE_SMASH_CACHE_DIR, ) from pruna.engine.model_checks import get_helpers -from pruna.engine.utils import determine_dtype +from pruna.engine.utils import ModelContext, determine_dtype from pruna.logging.logger import pruna_logger if TYPE_CHECKING: @@ -318,10 +319,36 @@ def save_model_hqq(model: Any, model_path: str | Path, smash_config: SmashConfig algorithm_packages = HQQQuantizer().import_algorithm_packages() - if isinstance(model, algorithm_packages["HQQModelForCausalLM"]): - model.save_quantized(model_path) + # we need to create a separate path for the quantized model + if hasattr(model, "model") and hasattr(model.model, "language_model"): + quantized_path = os.path.join(str(model_path), "hqq_language_model") else: - algorithm_packages["AutoHQQHFModel"].save_quantized(model, str(model_path)) + quantized_path = str(model_path) + + # save the quantized model only. + with ModelContext(model) as (pipeline, working_model, denoiser_type): + if isinstance(working_model, algorithm_packages["HQQModelForCausalLM"]): + working_model.save_quantized(quantized_path) + else: + algorithm_packages["AutoHQQHFModel"].save_quantized(working_model, str(quantized_path)) + # redefining the working_model breaks links with context manager + # so we need to re-define the working_model as an attribute of the model. + pipeline.working_model = working_model + + # save the rest of the model, if it is a janus like model, + # and add a config file to the quantized model path. 
+ if hasattr(model, "model") and hasattr(model.model, "language_model"): + transformer_backup = model.model.language_model + model.model.language_model = None + model.save_pretrained(model_path) + # Create a copy to avoid modifying the original config + hqq_config = copy.deepcopy(model.config.text_config) + # for re-loading the model, hqq expects the architecture to be LlamaForCausalLM + hqq_config.architectures = ["LlamaForCausalLM"] + os.makedirs(quantized_path, exist_ok=True) + with open(os.path.join(quantized_path, "config.json"), "w") as f: + json.dump(hqq_config.to_dict(), f, indent=2) + model.model.language_model = transformer_backup smash_config.load_fns.append(LOAD_FUNCTIONS.hqq.name) diff --git a/src/pruna/engine/utils.py b/src/pruna/engine/utils.py index 74b02598..548c3638 100644 --- a/src/pruna/engine/utils.py +++ b/src/pruna/engine/utils.py @@ -489,6 +489,9 @@ def __enter__(self) -> tuple[ModelMixin, Any, str | None]: elif hasattr(self.pipeline, "unet"): self.working_model = self.pipeline.unet self.denoiser_type = "unet" + elif hasattr(self.pipeline, "model") and hasattr(self.pipeline.model, "language_model"): + self.working_model = self.pipeline.model.language_model + self.denoiser_type = "language_model" else: self.working_model = self.pipeline self.denoiser_type = None # type: ignore [assignment] @@ -511,6 +514,8 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.pipeline.transformer = self.pipeline.working_model elif hasattr(self.pipeline, "unet"): self.pipeline.unet = self.pipeline.working_model + elif hasattr(self.pipeline, "model") and hasattr(self.pipeline.model, "language_model"): + self.pipeline.model.language_model = self.pipeline.working_model else: self.pipeline = self.pipeline.working_model del self.pipeline.working_model diff --git a/tests/algorithms/test_combinations.py b/tests/algorithms/test_combinations.py index 23b7c885..18952e2d 100644 --- a/tests/algorithms/test_combinations.py +++ 
b/tests/algorithms/test_combinations.py @@ -51,6 +51,7 @@ def prepare_smash_config(self, smash_config: SmashConfig, device: str) -> None: ("flux_tiny_random", dict(cacher="fora", quantizer="diffusers_int8"), False), ("flux_tiny_random", dict(cacher="fora", compiler="torch_compile"), False), ("flux_tiny_random", dict(cacher="fora", compiler="stable_fast"), False), + ("tiny_janus_pro", dict(quantizer="hqq", compiler="torch_compile"), False), ], indirect=["model_fixture"], ) diff --git a/tests/algorithms/testers/quantization.py b/tests/algorithms/testers/quantization.py index d9489fa4..3506e7dd 100644 --- a/tests/algorithms/testers/quantization.py +++ b/tests/algorithms/testers/quantization.py @@ -56,7 +56,7 @@ class TestDiffusersInt8(AlgorithmTesterBase): class TestHQQ(AlgorithmTesterBase): """Test the HQQ quantizer.""" - models = ["llama_3_tiny_random"] + models = ["llama_3_tiny_random", "tiny_janus_pro"] reject_models = ["sd_tiny_random"] allow_pickle_files = False algorithm_class = HQQQuantizer diff --git a/tests/fixtures.py b/tests/fixtures.py index 53010cd6..5a198880 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -6,7 +6,7 @@ import torch from huggingface_hub import snapshot_download from torchvision.models import get_model as torchvision_get_model -from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline +from transformers import AutoModelForCausalLM, AutoTokenizer, JanusForConditionalGeneration, pipeline from pruna import SmashConfig from pruna.data.pruna_datamodule import PrunaDataModule @@ -117,6 +117,13 @@ def get_torchvision_model(name: str) -> tuple[Any, SmashConfig]: return model, smash_config +def get_janus_model(model_id: str) -> tuple[Any, SmashConfig]: + """Get a Janus model for image generation.""" + model = JanusForConditionalGeneration.from_pretrained(model_id) + smash_config = SmashConfig() + return model, smash_config + + MODEL_FACTORY: dict[str, Callable] = { # whisper models "whisper_tiny_random": 
whisper_tiny_random_model, @@ -148,4 +155,7 @@ def get_torchvision_model(name: str) -> tuple[Any, SmashConfig]: "llama_3_1_8b": partial(get_automodel_transformers, "NousResearch/Hermes-3-Llama-3.1-8B"), "llama_3_tiny_random": partial(get_automodel_transformers, "llamafactory/tiny-random-Llama-3"), "dummy_lambda": dummy_model, + + # image generation AR models + "tiny_janus_pro": partial(get_janus_model, "loulou2/tiny_janus"), }