
feat: improve overall device placement handling#148

Merged
davidberenstein1957 merged 11 commits into main from
feat/108-feature-improve-overall-device-placement-handling
Jun 3, 2025

Conversation

@davidberenstein1957 (Member) commented on May 22, 2025

Description

Previously, running the following code on a CPU-only machine would fail because the metric was forced onto CUDA.

from pruna.evaluation.metrics import CMMD

CMMD()

Related Issue

Fixes #108

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes


@guennemann left a comment

Very nice. See some minor requests inline.

Comment thread src/pruna/evaluation/metrics/metric_cmmd.py Outdated
Comment thread src/pruna/evaluation/metrics/metric_elapsed_time.py Outdated

@guennemann left a comment

LGTM! Thanks.

@johannaSommer (Member) left a comment

LGTM from my side w.r.t. device casting, etc, but I really feel @begumcig should review this.

@johannaSommer requested review from begumcig and removed the request for johnrachwan123 on May 28, 2025 at 11:11
@begumcig (Member) left a comment

Looks pretty good to me! Thank you especially for solving the device inconsistencies between the Task and metrics! Left some suggestions! Also, we need a change for the memory metric, as it is measuring GPU usage.

Comment thread src/pruna/evaluation/metrics/metric_memory.py Outdated
Comment thread src/pruna/evaluation/metrics/registry.py Outdated
Comment thread src/pruna/engine/utils.py Outdated
Comment thread src/pruna/evaluation/evaluation_agent.py Outdated
Comment thread src/pruna/evaluation/task.py Outdated
- Added `check_device_compatibility` utility to ensure proper device assignment in various metrics.
- Updated device handling in `Task`, `CMMD`, `InferenceTimeStats`, and other metrics to utilize the new compatibility check.
- Improved docstrings to clarify device parameter usage and fallback behavior.
- Streamlined device assignment in model evaluation and metric calculations for better robustness.
- Updated all instances of `check_device_compatibility` to `set_to_best_available_device` across various modules, including `SmashConfig`, evaluation metrics, and task handling.
- Modified the `run_inference` method in `PrunaModel` to accept a device parameter that defaults to None, utilizing the new utility function `set_to_best_available_device` for improved device management.
- Cleaned up the import statements in `pruna_model.py` to include the new device utility.
- Removed unnecessary blank line in the test file `test_cmmd.py` for better code cleanliness.
- Updated the `set_to_best_available_device` function to raise a ValueError for unsupported devices, improving error handling.
- Modified the `Task` class to directly use the provided device parameter instead of relying on the utility function for device assignment.
- Enhanced the `get_metrics` and `_process_metric_names` functions to accept and utilize the device parameter, ensuring consistent device management across metric processing.
- Improved docstrings to clarify the usage of the device parameter in various functions.
- Updated the event handling in the `InferenceTimeStats` class to improve device attribute access, enhancing code clarity and maintainability.
- Replaced direct calls to `getattr(torch, self.device)` with a local variable for better performance and readability.
- Revised docstrings in multiple classes and functions to clarify the usage of the device parameter, removing redundant phrasing related to "smashing."
- Ensured consistency in the description of device handling across various evaluation metrics and configurations.
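The `InferenceTimeStats` changes above (looking up the torch backend once in a local variable, and raising a `ValueError` for unsupported sync devices) can be sketched as follows. `torch_mod` stands in for the real `torch` module so the snippet is self-contained, and the helper name is hypothetical:

```python
from types import SimpleNamespace

# Stand-in for the torch module so the sketch runs anywhere;
# in the real code this would be torch itself (torch.cuda / torch.mps).
torch_mod = SimpleNamespace(
    cuda=SimpleNamespace(synchronize=lambda: "cuda-synced"),
    mps=SimpleNamespace(synchronize=lambda: "mps-synced"),
)

def sync_for_timing(device: str) -> str:
    """Look up the backend module once, then synchronize before timing."""
    try:
        backend = getattr(torch_mod, device)  # replaces repeated getattr calls
    except AttributeError as exc:
        raise ValueError(f"Unsupported device for sync timing: {device!r}") from exc
    return backend.synchronize()
```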
…tats

- Added a try-except block around the device attribute access to raise a ValueError when an unsupported device is specified for sync timing, ensuring clearer error reporting and fallback behavior.
- Improved error handling for the 'accelerate' device to raise a ValueError when neither CUDA nor MPS is available.
- Streamlined checks for 'cuda' and 'mps' devices, ensuring warnings are logged when the requested device is unavailable, and fallback behavior is maintained.
- Updated the GPUMemoryStats class to include support for MPS devices, using -1 as a placeholder.
- Improved device index retrieval logic for both CUDA and MPS, ensuring consistent handling of device types.
- Streamlined inference method calls to utilize the best available device, enhancing overall device management.
- Revised the documentation for the `mode` parameter in the `GPUMemoryStats` class to specify the correct options as 'disk_memory', 'inference_memory', and 'training_memory', enhancing clarity for users.
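The MPS placeholder and index-retrieval logic for `GPUMemoryStats` might look roughly like this (a sketch; the helper name and the parsing details are assumptions):

```python
def get_device_index(device: str) -> int:
    """Return a device index for memory stats; MPS gets -1 as a placeholder."""
    if device == "mps":
        return -1  # MPS exposes no per-device index, so -1 marks "the MPS device"
    if device.startswith("cuda"):
        # "cuda" -> current device (index 0 here); "cuda:1" -> 1
        _, _, idx = device.partition(":")
        return int(idx) if idx else 0
    raise ValueError(f"No GPU index for device {device!r}")
```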
@davidberenstein1957 force-pushed the feat/108-feature-improve-overall-device-placement-handling branch from dfe6534 to 59ec8f6 on June 3, 2025 at 12:42
- Modified the `prepare_inputs` method signatures across multiple handler classes to accept a more flexible input type, allowing for `List[str]`, `torch.Tensor`, or a tuple of these types.
- Removed redundant type annotations to enhance clarity and maintainability in the codebase.
- Adjusted the `CTranslateCompiler` class to remove an unnecessary line related to model configuration.
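The widened `prepare_inputs` signature can be sketched like this (the normalization body is an assumption; `"Tensor"` is a forward reference standing in for `torch.Tensor` so the sketch stays self-contained):

```python
from typing import List, Tuple, Union

# Handlers now accept a prompt list, a tensor, or a tuple mixing both.
HandlerInput = Union[List[str], "Tensor", Tuple[Union[List[str], "Tensor"], ...]]

def prepare_inputs(batch: HandlerInput) -> tuple:
    """Normalize any accepted input shape to a tuple for uniform iteration."""
    if isinstance(batch, tuple):
        return batch
    return (batch,)
```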
@davidberenstein1957 merged commit da0ccdd into main on Jun 3, 2025
6 checks passed
@johannaSommer mentioned this pull request on Jun 30, 2025

Development

Successfully merging this pull request may close these issues.

[FEATURE] Improve overall device placement handling

4 participants