fix: device_map specification for accelerate-compatible quantizers#226
Merged
johannaSommer merged 6 commits intomainfrom Jul 7, 2025
Merged
fix: device_map specification for accelerate-compatible quantizers#226johannaSommer merged 6 commits intomainfrom
device_map specification for accelerate-compatible quantizers#226johannaSommer merged 6 commits intomainfrom
Conversation
Co-authored-by: simlang <simon.langrieger@pruna.ai>
Co-authored-by: simglang <simon.langrieger@pruna.ai>
gsprochette
reviewed
Jul 2, 2025
Collaborator
There was a problem hiding this comment.
Looks great ! Thanks for the improvement and for future-proofing pruna :)
While we're on this subject I think the get_device function in pruna.engine.utils could be improved:
- most important: the
return_device_mapis only handled in the case of thehf_device_mapcheck at the end, but themodel_device.typedoes not seem to be a valid device_map as suggested by the argument name and docstring. This will probably cause issues in the future. - readability: the first if statement would be slightly cleaner with the else case first
- readability: the initial
model_deviceis overwritten at the end, there should maybe be aif..elseto make the different cases more explicit
We can make this into a separate PR but the first point seems related.
simlang
reviewed
Jul 2, 2025
Member
simlang
left a comment
There was a problem hiding this comment.
Changes so far look good to me, thanks for fixing it so fast! 👍
Will approve, depending on how we handle @gsprochette comment 🙂
gsprochette
reviewed
Jul 7, 2025
| else: | ||
| return model.hf_device_map[subset_key] | ||
| else: | ||
| device = "cuda:0" if model_device == "cuda" else "cpu" |
Collaborator
There was a problem hiding this comment.
There should probably be mps case here
simlang
reviewed
Jul 7, 2025
simlang
approved these changes
Jul 7, 2025
Member
simlang
left a comment
There was a problem hiding this comment.
Looks good! Very clean now 🫧🧼
gsprochette
approved these changes
Jul 7, 2025
Collaborator
gsprochette
left a comment
There was a problem hiding this comment.
Looks absolutely perfect, thanks for all the "just this last update" iterations it's so clean now :)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR fixes a small bug in the device map specification in
pruna, namely, the device map should be specified as{"":"cuda:0"}instead of{"":"cuda"}. Otherwise, the resulting model will work at inference time but will no longer be compatible with otheracceleratefunctionality.Related Issue
None.
Type of Change
How Has This Been Tested?
Reran all algorithms tests in pruna (due to change in the load-function for transformers models), including the accelerate tests for both quantizers.
Checklist
Additional Notes
At the moment we assume throughout the pruna repository that the device index is 0, hence the changes in this PR. We will update this soon and be compatible with specifying the device index when using only "cuda".