diff --git a/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java b/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java
index 987e5d4..cce58b7 100644
--- a/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java
+++ b/src/main/java/com/gpuopenanalytics/jenkins/remotedocker/config/NvidiaGpuDevicesConfigItem.java
@@ -26,12 +26,22 @@
import com.gpuopenanalytics.jenkins.remotedocker.AbstractDockerLauncher;
import hudson.Extension;
+import hudson.Launcher;
import hudson.model.Descriptor;
import hudson.util.ArgumentListBuilder;
import org.apache.commons.lang.StringUtils;
import org.jenkinsci.Symbol;
import org.kohsuke.stapler.DataBoundConstructor;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
/**
* Defines which GPU devices are visible in the container. Passes
* -e NVIDIA_VISIBLE_DEVICES=value
@@ -69,7 +79,13 @@ public void addCreateArgs(AbstractDockerLauncher launcher,
ArgumentListBuilder args) {
String value;
if ("executor".equals(getValue())) {
- value = launcher.getEnvironment().get("EXECUTOR_NUMBER");
+ String executorNum = launcher.getEnvironment().get("EXECUTOR_NUMBER");
+ String nvidiasmiOutput = executeWithOutput(launcher.getInner(), "nvidia-smi", "-L");
+ if (isMIG(nvidiasmiOutput)) {
+ value = getMIG(nvidiasmiOutput, executorNum);
+ } else {
+ value = executorNum;
+ }
} else {
value = getResolvedValue(launcher);
}
@@ -99,4 +115,45 @@ public String getDisplayName() {
return "NVIDIA Device Visibility";
}
}
+
+ private String executeWithOutput(Launcher launcher, String... args) {
+ try {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ int status = launcher.launch()
+ .cmds(args)
+ .stdout(baos)
+ .stderr(launcher.getListener().getLogger())
+ .join();
+ if (status != 0) {
+ throw new RuntimeException(
+ "Non-zero status " + status + ": " + Arrays
+ .toString(args));
+ }
+ return baos.toString(StandardCharsets.UTF_8.name()).trim();
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private boolean isMIG(String output) {
+ Pattern pattern = Pattern.compile("(MIG-GPU-[a-f0-9\\-\\/]+)");
+ Matcher m = pattern.matcher(output);
+
+ if (m.find()) {
+ return true;
+ }
+ return false;
+ }
+
+ private String getMIG(String output, String executor) {
+ int executorNum = Integer.parseInt(executor);
+ List uuids = new ArrayList();
+ Pattern pattern = Pattern.compile("(MIG-GPU-[a-f0-9\\-\\/]+)");
+ Matcher m = pattern.matcher(output);
+
+ while (m.find()) {
+ uuids.add(m.group());
+ }
+ return uuids.get(executorNum);
+ }
}