If I attempt to perform evaluation with the Evaluation Agent on a HF dataset such as `data-is-better-together/open-image-preferences-v1-binarized`, after I rename the dataset columns in the `PrunaDataModule` to `image` and `text`, the following error is seen.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[3], line 31
28 smashed_pipe.move_to_device(device)
30 # Evaluate the smashed model pipeline using the evaluation agent
---> 31 smashed_model_results = eval_agent.evaluate(smashed_pipe)
33 # Optionally, print results for verification
34 print(smashed_model_results)
File ~/pruna/src/pruna/evaluation/evaluation_agent.py:107, in EvaluationAgent.evaluate(self, model)
105 # Compute stateless metrics.
106 pruna_logger.info("Evaluating isolated inference metrics.")
--> 107 results.extend(self.compute_stateless_metrics(model, stateless_metrics))
109 model.move_to_device("cpu")
110 safe_memory_cleanup()
File ~/pruna/src/pruna/evaluation/evaluation_agent.py:260, in EvaluationAgent.compute_stateless_metrics(self, model, stateless_metrics)
257 parent_to_children, children_of_base = group_metrics_by_inheritance(stateless_metrics)
258 for (parent, _), children in parent_to_children.items():
259 # Get the metrics that share a common parent to share inference computation by calling the parent metric.
--> 260 raw_results = parent.compute(children[0], model, self.task.dataloader)
261 for child in children:
262 results.append(
263 MetricResult.from_results_dict(child.metric_name, dict(children[0].__dict__), raw_results)
264 )
File ~/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/pruna/src/pruna/evaluation/metrics/metric_elapsed_time.py:164, in InferenceTimeStats.compute(self, model, dataloader)
161 model.move_to_device(self.device)
163 # Warmup
--> 164 self._measure(
165 model,
166 dataloader,
167 self.n_warmup_iterations,
168 lambda m, x: (
169 m(**x, **m.inference_handler.model_args) # x is a dict
170 if isinstance(x, dict)
171 else m(x, **m.inference_handler.model_args) # x is tensor/list
172 ),
173 )
175 # Measurement
176 list_elapsed_times = []
File ~/pruna/src/pruna/evaluation/metrics/metric_elapsed_time.py:96, in InferenceTimeStats._measure(self, model, dataloader, iterations, measure_fn)
94 batch = model.inference_handler.move_inputs_to_device(batch, self.device)
95 x = model.inference_handler.prepare_inputs(batch)
---> 96 measure_fn(model, x)
97 c += 1
98 if c >= iterations:
File ~/pruna/src/pruna/evaluation/metrics/metric_elapsed_time.py:171, in InferenceTimeStats.compute.<locals>.<lambda>(m, x)
161 model.move_to_device(self.device)
163 # Warmup
164 self._measure(
165 model,
166 dataloader,
167 self.n_warmup_iterations,
168 lambda m, x: (
169 m(**x, **m.inference_handler.model_args) # x is a dict
170 if isinstance(x, dict)
--> 171 else m(x, **m.inference_handler.model_args) # x is tensor/list
172 ),
173 )
175 # Measurement
176 list_elapsed_times = []
File ~/pruna/src/pruna/telemetry/metrics.py:218, in track_usage.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
216 smash_config = repr(smash_config) if smash_config is not None else ""
217 try:
--> 218 result = func(*args, **kwargs)
219 increment_counter(function_name, success=True, smash_config=smash_config)
220 return result
File ~/pruna/src/pruna/engine/pruna_model.py:75, in PrunaModel.__call__(self, *args, **kwargs)
73 else:
74 with torch.no_grad():
---> 75 return self.model.__call__(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/diffusers/pipelines/sana/pipeline_sana.py:942, in SanaPipeline.__call__(self, prompt, negative_prompt, num_inference_steps, timesteps, sigmas, guidance_scale, num_images_per_prompt, height, width, eta, generator, latents, prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask, output_type, return_dict, clean_caption, use_resolution_binning, attention_kwargs, callback_on_step_end, callback_on_step_end_tensor_inputs, max_sequence_length, complex_human_instruction)
939 timestep = timestep * self.transformer.config.timestep_scale
941 # predict noise model_output
--> 942 noise_pred = self.transformer(
943 latent_model_input.to(dtype=transformer_dtype),
944 encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
945 encoder_attention_mask=prompt_attention_mask,
946 timestep=timestep,
947 return_dict=False,
948 attention_kwargs=self.attention_kwargs,
949 )[0]
950 noise_pred = noise_pred.float()
952 # perform guidance
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File ~/.venv/lib/python3.11/site-packages/diffusers/models/transformers/sana_transformer.py:532, in SanaTransformer2DModel.forward(self, hidden_states, encoder_hidden_states, timestep, guidance, encoder_attention_mask, attention_mask, attention_kwargs, controlnet_block_samples, return_dict)
529 p = self.config.patch_size
530 post_patch_height, post_patch_width = height // p, width // p
--> 532 hidden_states = self.patch_embed(hidden_states)
534 if guidance is not None:
535 timestep, embedded_timestep = self.time_embed(
536 timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
537 )
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File ~/.venv/lib/python3.11/site-packages/diffusers/models/embeddings.py:547, in PatchEmbed.forward(self, latent)
545 else:
546 height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
--> 547 latent = self.proj(latent)
548 if self.flatten:
549 latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1751, in Module._wrapped_call_impl(self, *args, **kwargs)
1749 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1750 else:
-> 1751 return self._call_impl(*args, **kwargs)
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1762, in Module._call_impl(self, *args, **kwargs)
1757 # If we don't have any hooks, we want to skip the rest of the logic in
1758 # this function, and just call forward.
1759 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1760 or _global_backward_pre_hooks or _global_backward_hooks
1761 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1762 return forward_call(*args, **kwargs)
1764 result = None
1765 called_always_called_hooks = set()
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:554, in Conv2d.forward(self, input)
553 def forward(self, input: Tensor) -> Tensor:
--> 554 return self._conv_forward(input, self.weight, self.bias)
File ~/.venv/lib/python3.11/site-packages/torch/nn/modules/conv.py:549, in Conv2d._conv_forward(self, input, weight, bias)
537 if self.padding_mode != "zeros":
538 return F.conv2d(
539 F.pad(
540 input, self._reversed_padding_repeated_twice, mode=self.padding_mode
(...) 547 self.groups,
548 )
--> 549 return F.conv2d(
550 input, weight, bias, self.stride, self.padding, self.dilation, self.groups
551 )
RuntimeError: Input type (float) and bias type (c10::Half) should be the same
The Evaluation Agent should evaluate the model.
Describe the bug
If I attempt to perform evaluation with the Evaluation Agent on a HF dataset such as `data-is-better-together/open-image-preferences-v1-binarized`, after I rename the dataset columns in the `PrunaDataModule` to `image` and `text`, the following error is seen.

What I did
I also attempted to change the data types in the collate function to `int`, `float`, and `normalized`, but the evaluation still fails. WDYT?

Expected behavior
The Evaluation Agent should evaluate the model.
Environment