feat: improve overall device placement handling #148
Merged
davidberenstein1957 merged 11 commits into main on Jun 3, 2025
Conversation
guennemann requested changes on May 23, 2025:
Very nice. See some minor requests inline.
johannaSommer (Member) approved these changes on May 28, 2025:
LGTM from my side w.r.t. device casting, etc., but I really feel @begumcig should review this.
begumcig (Member) previously requested changes on May 28, 2025:
Looks pretty good to me! Thank you especially for solving the device inconsistencies between the Task and metrics! Left some suggestions. Also, we need a change for the memory metric, as it is measuring the GPU usage.
- Added `check_device_compatibility` utility to ensure proper device assignment in various metrics.
- Updated device handling in `Task`, `CMMD`, `InferenceTimeStats`, and other metrics to utilize the new compatibility check.
- Improved docstrings to clarify device parameter usage and fallback behavior.
- Streamlined device assignment in model evaluation and metric calculations for better robustness.
- Updated all instances of `check_device_compatibility` to `set_to_best_available_device` across various modules, including `SmashConfig`, evaluation metrics, and task handling.
- Modified the `run_inference` method in `PrunaModel` to accept a device parameter that defaults to None, utilizing the new utility function `set_to_best_available_device` for improved device management.
- Cleaned up the import statements in `pruna_model.py` to include the new device utility.
- Removed an unnecessary blank line in the test file `test_cmmd.py` for better code cleanliness.
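The `run_inference` change described above can be sketched as follows. This is an illustrative stand-in, not the actual `PrunaModel` implementation: `predict_fn` and the inline CPU fallback are hypothetical, and the real code resolves `device=None` via `set_to_best_available_device` rather than the simple stand-in used here.

```python
class PrunaModelSketch:
    """Illustrative stand-in for `PrunaModel`'s new `run_inference` signature:
    `device` defaults to None and is resolved before inference."""

    def __init__(self, predict_fn):
        self.predict_fn = predict_fn

    def run_inference(self, batch, device=None):
        # None means "pick the best available device"; this sketch simply
        # falls back to CPU, while the real utility probes CUDA/MPS first.
        resolved = device if device is not None else "cpu"
        return [self.predict_fn(sample, resolved) for sample in batch]
```

The point of the signature change is that callers no longer need to know the device up front; passing nothing yields a sensible placement, while an explicit device is honored.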
- Updated the `set_to_best_available_device` function to raise a ValueError for unsupported devices, improving error handling.
- Modified the `Task` class to directly use the provided device parameter instead of relying on the utility function for device assignment.
- Enhanced the `get_metrics` and `_process_metric_names` functions to accept and utilize the device parameter, ensuring consistent device management across metric processing.
- Improved docstrings to clarify the usage of the device parameter in various functions.
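The device-threading pattern described above can be sketched like this. `MetricSketch` is a hypothetical stand-in for the real metric classes, and the real `get_metrics` resolves its argument with `set_to_best_available_device` before fanning it out.

```python
class MetricSketch:
    """Minimal stand-in for an evaluation metric that records its device."""

    def __init__(self, name, device):
        self.name = name
        self.device = device


def get_metrics(metric_names, device="cpu"):
    """Accept the device once and thread it to every metric, so the Task
    and all metrics agree on placement (sketch of the change above)."""
    return [MetricSketch(name, device=device) for name in metric_names]
```

Resolving the device once at the entry point, instead of inside each metric, is what removes the Task-versus-metric inconsistencies the reviewers mention.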
- Updated the event handling in the `InferenceTimeStats` class to improve device attribute access, enhancing code clarity and maintainability.
- Replaced direct calls to `getattr(torch, self.device)` with a local variable for better performance and readability.
- Revised docstrings in multiple classes and functions to clarify the usage of the device parameter, removing redundant phrasing related to "smashing."
- Ensured consistency in the description of device handling across various evaluation metrics and configurations.
…tats
- Added a try-except block around the device attribute access to raise a ValueError when an unsupported device is specified for sync timing, ensuring clearer error reporting and fallback behavior.
- Improved error handling for the 'accelerate' device to raise a ValueError when neither CUDA nor MPS is available.
- Streamlined checks for 'cuda' and 'mps' devices, ensuring warnings are logged when the requested device is unavailable, and fallback behavior is maintained.
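The fallback rules described above can be sketched as a small pure function. This is a simplified, hypothetical version: the availability flags are parameters here for testability, whereas the real utility queries `torch.cuda.is_available()` and `torch.backends.mps.is_available()` directly, and it logs through Pruna's logger rather than the `warnings` module.

```python
import warnings


def set_to_best_available_device(device=None, cuda_available=False, mps_available=False):
    """Sketch of the fallback logic: 'accelerate' must resolve to a real
    accelerator or error out; 'cuda'/'mps' fall back to CPU with a warning;
    unknown names raise ValueError."""
    if device in (None, "accelerate"):
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        if device == "accelerate":
            raise ValueError("'accelerate' requires CUDA or MPS to be available.")
        return "cpu"
    if device == "cuda":
        if cuda_available:
            return "cuda"
        warnings.warn("CUDA requested but unavailable; falling back to CPU.")
        return "cpu"
    if device == "mps":
        if mps_available:
            return "mps"
        warnings.warn("MPS requested but unavailable; falling back to CPU.")
        return "cpu"
    if device == "cpu":
        return "cpu"
    raise ValueError(f"Unsupported device: {device!r}")
```

The asymmetry is deliberate: an unavailable but valid accelerator degrades gracefully, while a name the library has never heard of fails fast.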
- Updated the GPUMemoryStats class to include support for MPS devices, using -1 as a placeholder.
- Improved device index retrieval logic for both CUDA and MPS, ensuring consistent handling of device types.
- Streamlined inference method calls to utilize the best available device, enhancing overall device management.
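The device-index convention described above can be sketched as a small helper. The function name `get_gpu_device_index` is hypothetical; it only illustrates the stated rule that CUDA devices report their numeric index while MPS uses -1 as a placeholder.

```python
def get_gpu_device_index(device):
    """CUDA devices report their numeric index ('cuda:1' -> 1, bare 'cuda'
    -> 0 by convention), while MPS returns -1 since it exposes a single
    logical device with no index."""
    if device == "mps":
        return -1
    if device == "cuda" or device.startswith("cuda:"):
        # Split off an explicit index if present, e.g. "cuda:1" -> "1".
        _, _, index = device.partition(":")
        return int(index) if index else 0
    raise ValueError(f"{device!r} is not a supported GPU device")
```

Using a sentinel index keeps GPUMemoryStats' bookkeeping uniform across backends instead of special-casing MPS at every call site.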
- Revised the documentation for the `mode` parameter in the `GPUMemoryStats` class to specify the correct options as 'disk_memory', 'inference_memory', and 'training_memory', enhancing clarity for users.
Force-pushed from dfe6534 to 59ec8f6.
- Modified the `prepare_inputs` method signatures across multiple handler classes to accept a more flexible input type, allowing for `List[str]`, `torch.Tensor`, or a tuple of these types.
- Removed redundant type annotations to enhance clarity and maintainability in the codebase.
- Adjusted the `CTranslateCompiler` class to remove an unnecessary line related to model configuration.
Description
Before this change, running the following code on a CPU-only machine would fail because execution was being forced onto CUDA.
Related Issue
Fixes #108
Type of Change
How Has This Been Tested?
Checklist
Additional Notes