Fix #8459: RuntimeError: Expect... when running Megatron DPO training of Qwen3-8B in an 8-card NPU environment #9014
Conversation
Replace CUDA-specific device placement with get_current_device() so ref models and batch data are placed on NPU in NPU environments. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Code Review
This pull request replaces hardcoded CUDA device references with a more generic get_current_device() utility to improve hardware abstraction. The review feedback suggests using get_device() instead of get_current_device() because the former returns a full device string (e.g., 'npu:0') rather than just an integer index, which provides more robust and explicit device placement across different accelerator backends.
```diff
 for ref_model in self.ref_models:
     if not args.use_cpu_initialization:
-        ref_model.cuda(torch.cuda.current_device())
+        ref_model.to(get_current_device())
```
Using get_current_device() returns an integer index (e.g., 0) for NPU and CUDA backends. While this works for tensors when the default device type is set, calling model.to(index) can be ambiguous in environments with multiple accelerator types or if the backend is not explicitly specified, potentially defaulting to CPU or CUDA even on NPU systems. Using get_device() is safer as it returns a full device string (e.g., 'npu:0'), ensuring correct placement.
Suggested change:

```diff
-ref_model.to(get_current_device())
+ref_model.to(get_device())
```
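To illustrate why a full device string is less ambiguous than a bare index, here is a minimal, torch-free sketch. The helper name `device_index_to_str` is hypothetical and exists only for this illustration; it is not part of swift:

```python
def device_index_to_str(index, backend):
    """Turn a bare device index into an explicit device string.

    A bare integer like 0 is interpreted relative to whatever backend
    the framework treats as the default (often CUDA), so on an NPU host
    model.to(0) may silently target the wrong device type. An explicit
    string such as 'npu:0' removes that ambiguity.
    """
    if backend not in ("cuda", "npu", "xpu", "mps"):
        raise ValueError(f"unknown backend: {backend}")
    return f"{backend}:{index}"


# On an 8-card NPU host, rank 3 should map to 'npu:3', never to CUDA.
print(device_index_to_str(3, "npu"))   # npu:3
print(device_index_to_str(0, "cuda"))  # cuda:0
```

This is the distinction the review is drawing: `get_device()` hands `model.to(...)` an unambiguous `'npu:0'`-style string, while an integer index relies on the default backend being set correctly.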
```diff
     get_router_replay_data, load_mcore_checkpoint, set_router_replay_data)
 from swift.rlhf_trainers.utils import identity_data_collator
-from swift.utils import get_logger, safe_snapshot_download
+from swift.utils import get_current_device, get_logger, safe_snapshot_download
```
```diff
 if 'loss_scale' in data:
     data['loss_scale'] = torch.roll(data['loss_scale'], -1, dims=-1)
-batch = to_device(data, 'cuda', non_blocking=True)
+batch = to_device(data, get_current_device(), non_blocking=True)
```
Similar to the change in rlhf_mixin.py, using get_device() instead of get_current_device() provides a more explicit device string (e.g., 'npu:0'), which is generally more robust than passing an integer index to to_device.
Suggested change:

```diff
-batch = to_device(data, get_current_device(), non_blocking=True)
+batch = to_device(data, get_device(), non_blocking=True)
```
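The effect of passing an explicit device string when moving a batch dict can be sketched without swift's own `to_device` helper. The `move_batch` name below is hypothetical; it only mimics the shape of the call being reviewed:

```python
def move_batch(batch, device, non_blocking=True):
    """Move every tensor-like value in a batch dict to `device`.

    Passing an explicit device string such as 'npu:0' (rather than the
    hardcoded 'cuda') keeps this working on any accelerator backend.
    Values without a .to() method pass through untouched.
    """
    return {
        k: v.to(device, non_blocking=non_blocking) if hasattr(v, "to") else v
        for k, v in batch.items()
    }


# Non-tensor entries pass through; tensors would follow `device`.
batch = {"labels": [1, 2, 3]}  # stand-in so the sketch runs without torch
print(move_batch(batch, "cpu"))
```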
Closes #8459
PR type
PR information
Fixes a device placement error when running Megatron DPO training on NPU (8-card) environments. The root cause was that two code paths unconditionally used CUDA-specific APIs to move tensors/models to device, which fails on NPU backends:
- `swift/megatron/trainers/rlhf_mixin.py`, line 37: `ref_model.cuda(torch.cuda.current_device())` was replaced with `ref_model.to(get_current_device())`. On NPU, the original call left the reference model on `cpu` instead of the active NPU device, triggering `RuntimeError: Expected all tensors to be on the same device, but found at least two devices, npu:0 and cpu!` during the forward pass in `dpo_trainer.py`.
- `swift/megatron/trainers/utils.py`, `get_batch_on_this_pp_rank()`: `to_device(data, 'cuda', non_blocking=True)` was replaced with `to_device(data, get_current_device(), non_blocking=True)`, ensuring batch tensors are moved to the correct accelerator device regardless of backend.

Both fixes use the existing `get_current_device()` utility (imported from `swift.utils`), which correctly resolves the active device for CUDA, NPU, and other backends.

Experiment results

Verified by the issue reporter that DPO training with Megatron on an 8-card NPU environment no longer raises the cross-device tensor error after this fix. CUDA environments are unaffected, since `get_current_device()` returns the correct CUDA device there as well.

This PR was created with AI assistance (Claude). The changes were reviewed by quality gates and a critic model before submission.