
Fix #8459: RuntimeError: Expect... when using Megatron for DPO training of Qwen3-8B in an 8-card NPU environment #9014

Merged
Jintao-Huang merged 1 commit into modelscope:main from
JiwaniZakir:fix/8459-npu-8-megatron-dpo-qwen3-8b-runtimeerror
Apr 4, 2026

Conversation

@JiwaniZakir
Contributor

Closes #8459

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

Fixes a device placement error when running Megatron DPO training on NPU (8-card) environments. The root cause was that two code paths unconditionally used CUDA-specific APIs to move tensors/models to device, which fails on NPU backends:

  1. swift/megatron/trainers/rlhf_mixin.py, line 37: the original call ref_model.cuda(torch.cuda.current_device()) left the reference model on cpu instead of the active NPU device, triggering RuntimeError: Expected all tensors to be on the same device, but found at least two devices, npu:0 and cpu! during the forward pass in dpo_trainer.py. It was replaced with ref_model.to(get_current_device()).

  2. swift/megatron/trainers/utils.py, get_batch_on_this_pp_rank(): to_device(data, 'cuda', non_blocking=True) was replaced with to_device(data, get_current_device(), non_blocking=True), ensuring batch tensors are moved to the correct accelerator device regardless of backend.

Both fixes use the existing get_current_device() utility (imported from swift.utils) which correctly resolves the active device for CUDA, NPU, and other backends.
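For context, the idea behind such a device-resolution utility can be sketched in plain Python. This is a hypothetical illustration, not swift's actual code: resolve_device, the module-probing heuristic, and the hardcoded index 0 are all assumptions made for the sketch.

```python
import importlib.util


def resolve_device() -> str:
    """Return an explicit '<type>:<index>' device string for the active backend.

    Probes for known accelerator packages (a rough availability proxy) and
    falls back to 'cpu'. Index 0 stands in for the process's current device
    index; a real implementation would query the backend for it.
    """
    backends = [
        ("torch_npu", "npu"),  # Ascend NPU support ships as the torch_npu package
        ("torch", "cuda"),     # assume CUDA when only plain torch is present
    ]
    for module_name, device_type in backends:
        if importlib.util.find_spec(module_name) is not None:
            return f"{device_type}:0"
    return "cpu"
```

The point of routing every call site through one resolver is that code like ref_model.to(...) and to_device(batch, ...) never hardcodes 'cuda', so the same training code runs unchanged on CUDA, NPU, or CPU.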

Experiment results

The issue reporter verified that DPO training with Megatron on an 8-card NPU environment no longer raises the cross-device tensor error after this fix. CUDA environments are unaffected, since get_current_device() resolves the correct CUDA device there as well.


This PR was created with AI assistance (Claude). The changes were reviewed by quality gates and a critic model before submission.

Replace CUDA-specific device placement with get_current_device()
so ref models and batch data are placed on NPU in NPU environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request replaces hardcoded CUDA device references with a more generic get_current_device() utility to improve hardware abstraction. The review feedback suggests using get_device() instead of get_current_device() because the former returns a full device string (e.g., 'npu:0') rather than just an integer index, which provides more robust and explicit device placement across different accelerator backends.

  for ref_model in self.ref_models:
      if not args.use_cpu_initialization:
-         ref_model.cuda(torch.cuda.current_device())
+         ref_model.to(get_current_device())

Severity: medium

Using get_current_device() returns an integer index (e.g., 0) for NPU and CUDA backends. While this works for tensors when the default device type is set, calling model.to(index) can be ambiguous in environments with multiple accelerator types or if the backend is not explicitly specified, potentially defaulting to CPU or CUDA even on NPU systems. Using get_device() is safer as it returns a full device string (e.g., 'npu:0'), ensuring correct placement.

Suggested change
- ref_model.to(get_current_device())
+ ref_model.to(get_device())
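The reviewer's point about integer indices versus full device strings can be illustrated without touching real hardware: a bare index carries no device type, so the runtime must guess the backend, whereas a full string is unambiguous. A minimal sketch in plain Python (no torch dependency; normalize_device and DEFAULT_BACKEND are stand-ins invented for this illustration):

```python
DEFAULT_BACKEND = "cuda"  # what a framework typically assumes for a bare index


def normalize_device(spec) -> str:
    """Turn a device spec into an explicit '<type>:<index>' string.

    A bare int is ambiguous: the backend must be guessed (frameworks tend
    to pick CUDA), which is exactly the failure mode the review warns
    about on NPU systems. A full device string needs no guessing.
    """
    if isinstance(spec, int):
        return f"{DEFAULT_BACKEND}:{spec}"  # guessed backend!
    return spec


# A bare index silently becomes a CUDA device, even on an NPU machine:
normalize_device(0)        # -> 'cuda:0'
# A full device string passes through unchanged:
normalize_device("npu:0")  # -> 'npu:0'
```

This is why the review prefers a helper that returns 'npu:0' outright over one that returns 0 and relies on the caller's backend default.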

      get_router_replay_data, load_mcore_checkpoint, set_router_replay_data)
  from swift.rlhf_trainers.utils import identity_data_collator
- from swift.utils import get_logger, safe_snapshot_download
+ from swift.utils import get_current_device, get_logger, safe_snapshot_download

Severity: medium

Update the import to include get_device if you switch to using it for more robust device placement.

Suggested change
- from swift.utils import get_current_device, get_logger, safe_snapshot_download
+ from swift.utils import get_device, get_logger, safe_snapshot_download

  if 'loss_scale' in data:
      data['loss_scale'] = torch.roll(data['loss_scale'], -1, dims=-1)
- batch = to_device(data, 'cuda', non_blocking=True)
+ batch = to_device(data, get_current_device(), non_blocking=True)

Severity: medium

Similar to the change in rlhf_mixin.py, using get_device() instead of get_current_device() provides a more explicit device string (e.g., 'npu:0'), which is generally more robust than passing an integer index to to_device.

Suggested change
- batch = to_device(data, get_current_device(), non_blocking=True)
+ batch = to_device(data, get_device(), non_blocking=True)

@Jintao-Huang Jintao-Huang merged commit 662aa31 into modelscope:main Apr 4, 2026
3 checks passed