fix: device_map specification for accelerate-compatible quantizers #226

Merged
johannaSommer merged 6 commits into main from fix/quantizer-device-casting
Jul 7, 2025

Conversation

@johannaSommer
Member

@johannaSommer johannaSommer commented Jul 1, 2025

Description

This PR fixes a small bug in the device map specification in pruna: the device map should be specified as {"":"cuda:0"} instead of {"":"cuda"}. Otherwise, the resulting model still works at inference time but is no longer compatible with other accelerate functionality.
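
As a minimal sketch of the fix described above, the helper below rewrites a bare "cuda" entry in a device map to "cuda:0". The function name `with_explicit_cuda_index` is hypothetical and is not part of pruna or accelerate, and index 0 reflects the assumption stated in this PR rather than a general rule.

```python
def with_explicit_cuda_index(device_map: dict) -> dict:
    """Return a copy of device_map where a bare "cuda" becomes "cuda:0".

    Hypothetical helper for illustration only; device index 0 is an assumption.
    """
    return {
        module: "cuda:0" if device == "cuda" else device
        for module, device in device_map.items()
    }


print(with_explicit_cuda_index({"": "cuda"}))
```

Entries that already carry an index, or that name another device such as "cpu", are returned unchanged.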

Related Issue

None.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

How Has This Been Tested?

Reran all algorithm tests in pruna (because the load function for transformers models changed), including the accelerate tests for both quantizers.

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Additional Notes

At the moment we assume throughout the pruna repository that the device index is 0, hence the changes in this PR. We will update this soon so that pruna is compatible with specifying the device index when using only "cuda".

johannaSommer and others added 2 commits July 1, 2025 15:55
Co-authored-by: simlang <simon.langrieger@pruna.ai>
Co-authored-by: simglang <simon.langrieger@pruna.ai>
Collaborator

@gsprochette gsprochette left a comment

Looks great! Thanks for the improvement and for future-proofing pruna :)
While we're on this subject I think the get_device function in pruna.engine.utils could be improved:

  • most important: the return_device_map argument is only handled in the hf_device_map check at the end, but model_device.type does not seem to be a valid device_map, contrary to what the argument name and docstring suggest. This will probably cause issues in the future.
  • readability: the first if statement would be slightly cleaner with the else case first
  • readability: the initial model_device is overwritten at the end; an if/else might make the different cases more explicit

We can make this into a separate PR but the first point seems related.
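
To illustrate the first point, here is a minimal sketch in which `return_device_map` is honored in every branch rather than only in the `hf_device_map` one. It is not pruna's `get_device`: the function name, the `device_type` attribute on the stand-in model, and the branch layout are all assumptions made for illustration.

```python
from types import SimpleNamespace


def get_device_sketch(model, return_device_map: bool = False):
    """Hypothetical stand-in for a device lookup; not pruna's get_device."""
    hf_device_map = getattr(model, "hf_device_map", None)
    if hf_device_map is not None:
        # Model was dispatched by accelerate: a device map already exists.
        if return_device_map:
            return dict(hf_device_map)
        # Assumption: report the device of the first mapped module.
        return next(iter(hf_device_map.values()))
    else:
        # Plain model: `device_type` is an assumed attribute ("cuda", "cpu", ...).
        device = "cuda:0" if model.device_type == "cuda" else model.device_type
        # Wrap the device so callers asking for a map get a valid one.
        return {"": device} if return_device_map else device


plain = SimpleNamespace(device_type="cuda")
print(get_device_sketch(plain))
print(get_device_sketch(plain, return_device_map=True))
```

Making both branches explicit also avoids overwriting an initial value at the end, which is the third point above.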

Member

@simlang simlang left a comment

Changes so far look good to me, thanks for fixing it so fast! 👍
Will approve, depending on how we handle @gsprochette's comment 🙂

Comment thread src/pruna/engine/utils.py Outdated
        else:
            return model.hf_device_map[subset_key]
    else:
        device = "cuda:0" if model_device == "cuda" else "cpu"
Collaborator

There should probably be an mps case here
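
One way to cover that case is sketched below, assuming `model_device` is a device-type string as in the line under review. The function name `explicit_device` is hypothetical, and passing "mps" through without an index is an assumption, not something settled in this thread.

```python
def explicit_device(model_device: str) -> str:
    """Map a device-type string to an explicit device string (illustrative)."""
    if model_device == "cuda":
        return "cuda:0"  # index 0 assumed, as elsewhere in this PR
    if model_device == "mps":
        return "mps"  # assumption: keep mps instead of falling back to cpu
    return "cpu"
```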

Comment thread src/pruna/engine/load.py
Member

@simlang simlang left a comment

Looks good! Very clean now 🫧🧼

@gsprochette gsprochette self-requested a review July 7, 2025 14:49
Collaborator

@gsprochette gsprochette left a comment

Looks absolutely perfect, thanks for all the "just this last update" iterations, it's so clean now :)

@johannaSommer johannaSommer merged commit 41a1ad1 into main Jul 7, 2025
6 checks passed
@johannaSommer johannaSommer deleted the fix/quantizer-device-casting branch July 7, 2025 15:20
3 participants