diff --git a/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java b/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java index 987e5d4..cce58b7 100644 --- a/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java +++ b/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java @@ -26,12 +26,22 @@ import com.gpuopenanalytics.jenkins.remotedocker.AbstractDockerLauncher; import hudson.Extension; +import hudson.Launcher; import hudson.model.Descriptor; import hudson.util.ArgumentListBuilder; import org.apache.commons.lang.StringUtils; import org.jenkinsci.Symbol; import org.kohsuke.stapler.DataBoundConstructor; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + /** * Defines which GPU devices are visible in the container. Passes * -e NVIDIA_VISIBLE_DEVICES=value @@ -69,7 +79,13 @@ public void addCreateArgs(AbstractDockerLauncher launcher, ArgumentListBuilder args) { String value; if ("executor".equals(getValue())) { - value = launcher.getEnvironment().get("EXECUTOR_NUMBER"); + String executorNum = launcher.getEnvironment().get("EXECUTOR_NUMBER"); + String nvidiasmiOutput = executeWithOutput(launcher.getInner(), "nvidia-smi", "-L"); + if (isMIG(nvidiasmiOutput)) { + value = getMIG(nvidiasmiOutput, executorNum); + } else { + value = executorNum; + } } else { value = getResolvedValue(launcher); } @@ -99,4 +115,45 @@ public String getDisplayName() { return "NVIDIA Device Visibility"; } } + + private String executeWithOutput(Launcher launcher, String... args) { + try { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + int status = launcher.launch() + .cmds(args) + .stdout(baos) + .stderr(launcher.getListener().getLogger()) + .join(); + if (status != 0) { + throw new RuntimeException( + "Non-zero status " + status + ": " + Arrays + .toString(args)); + } + return baos.toString(StandardCharsets.UTF_8.name()).trim(); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + private boolean isMIG(String output) { + Pattern pattern = Pattern.compile("(MIG-GPU-[a-f0-9\\-\\/]+)"); + Matcher m = pattern.matcher(output); + + if (m.find()) { + return true; + } + return false; + } + + private String getMIG(String output, String executor) { + int executorNum = Integer.parseInt(executor); + List uuids = new ArrayList(); + Pattern pattern = Pattern.compile("(MIG-GPU-[a-f0-9\\-\\/]+)"); + Matcher m = pattern.matcher(output); + + while (m.find()) { + uuids.add(m.group()); + } + return uuids.get(executorNum); + } }