Skip to content

[BUG] Error in performing eval with Evaluation Agent with HF Dataset #273

@ParagEkbote

Description

@ParagEkbote

Describe the bug

If I attempt to perform an evaluation with the Evaluation Agent using an HF dataset such as data-is-better-together/open-image-preferences-v1-binarized, then after I rename the dataset columns in the PrunaDataModule to image and text, the following error is raised.

What I did

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import (
    LatencyMetric,
    TotalTimeMetric,
)
from pruna.evaluation.task import Task
from pruna import PrunaModel

# Load the smashed (optimized) model pipeline from Hugging Face Hub
smashed_pipe = PrunaModel.from_hub("AINovice2005/Sana_600M_ControlNet_HED-smashed")

# Define evaluation metrics (example: total time and latency)
metrics = [
    TotalTimeMetric(n_iterations=1, n_warmup_iterations=1),
    LatencyMetric(n_iterations=1, n_warmup_iterations=1),
]

# NOTE: `datamodule` and `device` are assumed to be defined earlier —
# `datamodule` is the PrunaDataModule wrapping the HF dataset (with columns
# renamed to "image" and "text"), and `device` is the target device (e.g. "cuda").
task = Task(metrics, datamodule=datamodule, device=device)

# Initialize the evaluation agent
eval_agent = EvaluationAgent(task)

# Move smashed model to evaluation device (GPU or CPU)
smashed_pipe.move_to_device(device)

# Evaluate the smashed model pipeline using the evaluation agent
smashed_model_results = eval_agent.evaluate(smashed_pipe)

# Optionally, print results for verification
print(smashed_model_results)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 31
     28 smashed_pipe.move_to_device(device)
     30 # Evaluate the smashed model pipeline using the evaluation agent
---> 31 smashed_model_results = eval_agent.evaluate(smashed_pipe)
     33 # Optionally, print results for verification
     34 print(smashed_model_results)

File ~/pruna/src/pruna/evaluation/evaluation_agent.py:107, in EvaluationAgent.evaluate(self, model)
    105 # Compute stateless metrics.
    106 pruna_logger.info("Evaluating isolated inference metrics.")
--> 107 results.extend(self.compute_stateless_metrics(model, stateless_metrics))
    109 model.move_to_device("cpu")
    110 safe_memory_cleanup()

File ~/pruna/src/pruna/evaluation/evaluation_agent.py:260, in EvaluationAgent.compute_stateless_metrics(self, model, stateless_metrics)
    257 parent_to_children, children_of_base = group_metrics_by_inheritance(stateless_metrics)
    258 for (parent, _), children in parent_to_children.items():
    259     # Get the metrics that share a common parent to share inference computation by calling the parent metric.
--> 260     raw_results = parent.compute(children[0], model, self.task.dataloader)
    261     for child in children:
    262         results.append(
    263             MetricResult.from_results_dict(child.metric_name, dict(children[0].__dict__), raw_results)
    264         )

File ~/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/pruna/src/pruna/evaluation/metrics/metric_elapsed_time.py:164, in InferenceTimeStats.compute(self, model, dataloader)
    161 model.move_to_device(self.device)
    163 # Warmup
--> 164 self._measure(
    165     model,
    166     dataloader,
    167     self.n_warmup_iterations,
    168     lambda m, x: (
    169         m(**x, **m.inference_handler.model_args)  # x is a dict
    170         if isinstance(x, dict)
    171         else m(x, **m.inference_handler.model_args)  # x is tensor/list
    172     ),
    173 )
    175 # Measurement
    176 list_elapsed_times = []

File ~/pruna/src/pruna/evaluation/metrics/metric_elapsed_time.py:96, in InferenceTimeStats._measure(self, model, dataloader, iterations, measure_fn)
     94 batch = model.inference_handler.move_inputs_to_device(batch, self.device)
     95 x = model.inference_handler.prepare_inputs(batch)
---> 96 measure_fn(model, x)
     97 c += 1
     98 if c >= iterations:

File ~/pruna/src/pruna/evaluation/metrics/metric_elapsed_time.py:171, in InferenceTimeStats.compute.<locals>.<lambda>(m, x)
    161 model.move_to_device(self.device)
    163 # Warmup
    164 self._measure(
    165     model,
    166     dataloader,
    167     self.n_warmup_iterations,
    168     lambda m, x: (
    169         m(**x, **m.inference_handler.model_args)  # x is a dict
    170         if isinstance(x, dict)
--> 171         else m(x, **m.inference_handler.model_args)  # x is tensor/list
    172     ),
    173 )
    175 # Measurement
    176 list_elapsed_times = []

File ~/pruna/src/pruna/telemetry/metrics.py:218, in track_usage.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    216 smash_config = repr(smash_config) if smash_config is not None else ""
    217 try:
--> 218     result = func(*args, **kwargs)
    219     increment_counter(function_name, success=True, smash_config=smash_config)
    220     return result

File ~/pruna/src/pruna/engine/pruna_model.py:75, in PrunaModel.__call__(self, *args, **kwargs)
     73 else:
     74     with torch.no_grad():
---> 75         return self.model.__call__(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/diffusers/pipelines/sana/pipeline_sana.py:942, in SanaPipeline.__call__(self, prompt, negative_prompt, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, height, width, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, clean_caption, use_resolution_binning, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, complex_human_instruction)
    939 timestep = timestep * self.transformer.config.timestep_scale
    941 # predict noise model_output
--> 942 noise_pred = self.transformer(
    943     latent_model_input.to(dtype=transformer_dtype),
    944     encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
    945     encoder_attention_mask=prompt_attention_mask,
    946     timestep=timestep,
    947     return_dict=False,
    948     attention_kwargs=self.attention_kwargs,
    949 )[0]
    950 noise_pred = noise_pred.float()
    952 # perform guidance

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/.venv/lib/python3.11/site-packages/diffusers/models/transformers/sana_transformer.py:532, in SanaTransformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, guidance, encoder_attention_mask, attention_mask, attention_kwargs, controlnet_block_samples, return_dict)
    529 p = self.config.patch_size
    530 post_patch_height, post_patch_width = height // p, width // p
--> 532 hidden_states = self.patch_embed(hidden_states)
    534 if guidance is not None:
    535     timestep, embedded_timestep = self.time_embed(
    536         timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
    537     )

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/.venv/lib/python3.11/site-packages/diffusers/models/embeddings.py:547, in PatchEmbed.forward(self, latent)
    545 else:
    546     height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
--> 547 latent = self.proj(latent)
    548 if self.flatten:
    549     latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
   1749     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1750 else:
-> 1751     return self._call_impl(*args, **kwargs)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
   1757 # If we don't have any hooks, we want to skip the rest of the logic in
   1758 # this function, and just call forward.
   1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1760         or _global_backward_pre_hooks or _global_backward_hooks
   1761         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762     return forward_call(*args, **kwargs)
   1764 result = None
   1765 called_always_called_hooks = set()

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:554, in Conv2d.forward(self, input)
    553 def forward(self, input: Tensor) -> Tensor:
--> 554     return self._conv_forward(input, self.weight, self.bias)

File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:549, in Conv2d._conv_forward(self, input, weight, bias)
    537 if self.padding_mode != "zeros":
    538     return F.conv2d(
    539         F.pad(
    540             input, self._reversed_padding_repeated_twice, mode=self.padding_mode
   (...)    547         self.groups,
    548     )
--> 549 return F.conv2d(
    550     input, weight, bias, self.stride, self.padding, self.dilation, self.groups
    551 )

RuntimeError: Input type (float) and bias type (c10::Half) should be the same

I also attempted to change the data types in the collate function to int, float, and normalized, but the evaluation still fails. WDYT?

Expected behavior

The Evaluation Agent should evaluate the model successfully, without a dtype mismatch error.

Environment

  • pruna version: 0.2.6
  • python version: 3.11.2
  • Operating System: Linux-5.15.0-1084-aws-x86_64-with-glibc2.31

Metadata

Metadata

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions