Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 198 additions & 4 deletions cuda_core/cuda/core/experimental/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,118 @@ def _process_define_macro(formatted_options, macro):
raise RuntimeError(f"Expected define_macro {union_type}, list[{union_type}], got {macro}")


def _format_options_for_backend(options_dict: dict, backend: str) -> list[str]:
"""Format compilation options for a specific backend.

This helper function converts a dictionary of option names and values into
properly formatted string options for the specified backend. Different backends
(NVRTC, NVVM, nvJitLink) use slightly different option naming conventions and
value formats.

Parameters
----------
options_dict : dict
Dictionary mapping option names to their values. The keys should be
generic option names (e.g., "arch", "debug", "ftz").
backend : str
The backend to format options for. Must be one of "NVRTC", "NVVM", or "nvJitLink".

Returns
-------
list[str]
List of formatted option strings suitable for the specified backend.

Raises
------
ValueError
If an unsupported backend is specified.

Notes
-----
- NVRTC uses `--` prefix and "true"/"false" for booleans
- NVVM uses `-` prefix and "1"/"0" for booleans
- nvJitLink uses `-` prefix and "true"/"false" for booleans
"""
if backend not in ("NVRTC", "NVVM", "nvJitLink"):
raise ValueError(f"Unsupported backend '{backend}'. Must be one of: NVRTC, NVVM, nvJitLink")

formatted = []

for key, value in options_dict.items():
if value is None:
continue

if backend == "NVRTC":
# NVRTC uses -- prefix
if key == "arch":
formatted.append(f"-arch={value}")
elif key == "debug" and value:
formatted.append("--device-debug")
elif key == "lineinfo" and value:
formatted.append("--generate-line-info")
elif key == "max_register_count":
formatted.append(f"--maxrregcount={value}")
elif key in ("ftz", "prec_sqrt", "prec_div"):
bool_val = "true" if value else "false"
# NVRTC uses hyphens in option names
option_name = key.replace("_", "-")
formatted.append(f"--{option_name}={bool_val}")
elif key == "fma":
bool_val = "true" if value else "false"
formatted.append(f"--fmad={bool_val}")
elif key == "device_code_optimize" and value:
formatted.append("--dopt=on")
elif key == "use_fast_math" and value:
formatted.append("--use_fast_math")
elif key == "link_time_optimization" and value:
formatted.append("--dlink-time-opt")
# Add more NVRTC-specific options as needed

elif backend == "NVVM":
# NVVM uses - prefix and 1/0 for booleans
if key == "arch":
# NVVM uses compute_ instead of sm_
arch_val = value
if arch_val.startswith("sm_"):
arch_val = f"compute_{arch_val[3:]}"
formatted.append(f"-arch={arch_val}")
elif key == "debug" and value:
formatted.append("-g")
elif key == "device_code_optimize":
# NVVM explicitly handles both True and False
if value is False:
formatted.append("-opt=0")
elif value is True:
formatted.append("-opt=3")
elif key in ("ftz", "prec_sqrt", "prec_div", "fma"):
bool_val = "1" if value else "0"
# NVVM uses hyphens in option names
option_name = key.replace("_", "-")
formatted.append(f"-{option_name}={bool_val}")
# lineinfo and link_time_optimization are not supported by NVVM, skip them

elif backend == "nvJitLink":
# nvJitLink uses - prefix and true/false for booleans
if key == "arch":
formatted.append(f"-arch={value}")
elif key == "debug" and value:
formatted.append("-g")
elif key == "lineinfo" and value:
formatted.append("-lineinfo")
elif key == "max_register_count":
formatted.append(f"-maxrregcount={value}")
elif key in ("ftz", "prec_sqrt", "prec_div", "fma"):
bool_val = "true" if value else "false"
# nvJitLink uses hyphens in option names
option_name = key.replace("_", "-")
formatted.append(f"-{option_name}={bool_val}")
elif key == "link_time_optimization" and value:
formatted.append("-lto")
# device_code_optimize is not supported by nvJitLink, skip it

return formatted


@dataclass
class ProgramOptions:
"""Customizable options for configuring `Program`.
Expand Down Expand Up @@ -422,9 +534,91 @@ def __post_init__(self):
if self.numba_debug:
self._formatted_options.append("--numba-debug")

def _as_bytes(self):
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
return list(o.encode() for o in self._formatted_options)
def as_bytes(self, backend: str = "NVRTC") -> list[bytes]:
"""Convert the formatted program options to a list of byte strings.

This method encodes the options stored in this `ProgramOptions` instance
into byte strings formatted for the specified backend, suitable for passing
to C libraries that calls the underlying compiler library.

Parameters
----------
backend : str, optional
The compiler backend to format options for. Must be one of:

- "NVRTC" (default): NVIDIA NVRTC compiler, supports all ProgramOptions
- "NVVM": NVIDIA NVVM compiler, supports a subset of options
- "nvJitLink": NVIDIA nvJitLink linker, supports a subset of options

Different backends use different option naming conventions and support
different option subsets. This method will format and filter options
appropriately for the chosen backend.

Returns
-------
list[bytes]
A list of byte-encoded option strings. Each element represents
a single compilation option in the format expected by the underlying compiler library.

Raises
------
ValueError
If an unsupported backend is specified.

Examples
--------
>>> options = ProgramOptions(arch="sm_80", debug=True)
>>> # Get options for NVRTC (default)
>>> nvrtc_options = options.as_bytes()
>>> print(nvrtc_options)
[b'-arch=sm_80', b'--device-debug']
>>>
>>> # Get options for NVVM
>>> nvvm_options = options.as_bytes("NVVM")
>>> print(nvvm_options)
[b'-arch=compute_80', b'-g']
>>>
>>> # Get options for nvJitLink
>>> nvjitlink_options = options.as_bytes("nvJitLink")
>>> print(nvjitlink_options)
[b'-arch=sm_80', b'-g']
"""
if backend == "NVRTC":
# For NVRTC, use the pre-formatted options (backward compatible)
# TODO: allow tuples once NVIDIA/cuda-python#72 is resolved
return list(o.encode() for o in self._formatted_options)

elif backend in ("NVVM", "nvJitLink"):
# For NVVM and nvJitLink, extract common options and format appropriately
options_dict = {}

# Common options supported by multiple backends
if self.arch is not None:
options_dict["arch"] = self.arch
if self.debug is not None:
options_dict["debug"] = self.debug
if self.lineinfo is not None:
options_dict["lineinfo"] = self.lineinfo
if self.max_register_count is not None:
options_dict["max_register_count"] = self.max_register_count
if self.ftz is not None:
options_dict["ftz"] = self.ftz
if self.prec_sqrt is not None:
options_dict["prec_sqrt"] = self.prec_sqrt
if self.prec_div is not None:
options_dict["prec_div"] = self.prec_div
if self.fma is not None:
options_dict["fma"] = self.fma
if self.device_code_optimize is not None:
options_dict["device_code_optimize"] = self.device_code_optimize
if self.link_time_optimization is not None:
options_dict["link_time_optimization"] = self.link_time_optimization

formatted_options = _format_options_for_backend(options_dict, backend)
return list(o.encode() for o in formatted_options)

else:
raise ValueError(f"Unsupported backend '{backend}'. Must be one of: NVRTC, NVVM, nvJitLink")

def __repr__(self):
# __TODO__ improve this
Expand Down Expand Up @@ -609,7 +803,7 @@ def compile(self, target_type, name_expressions=(), logs=None):
nvrtc.nvrtcAddNameExpression(self._mnff.handle, n.encode()),
handle=self._mnff.handle,
)
options = self._options._as_bytes()
options = self._options.as_bytes()
handle_return(
nvrtc.nvrtcCompileProgram(self._mnff.handle, len(options), options),
handle=self._mnff.handle,
Expand Down
106 changes: 106 additions & 0 deletions cuda_core/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,109 @@ def test_nvvm_program_options(init_cuda, nvvm_ir, options):
assert ".visible .entry simple(" in ptx_text

program.close()


def test_program_options_as_bytes():
"""Test that ProgramOptions.as_bytes() returns correct byte strings"""
# Test with various options
options = ProgramOptions(
arch="sm_80",
debug=True,
lineinfo=True,
max_register_count=32,
ftz=True,
use_fast_math=True,
)

byte_options = options.as_bytes()

# Verify the return type
assert isinstance(byte_options, list)
assert all(isinstance(opt, bytes) for opt in byte_options)

# Verify specific options are present in byte format
assert b"-arch=sm_80" in byte_options
assert b"--device-debug" in byte_options
assert b"--generate-line-info" in byte_options
assert b"--maxrregcount=32" in byte_options
assert b"--ftz=true" in byte_options
assert b"--use_fast_math" in byte_options


def test_program_options_as_bytes_empty():
"""Test that ProgramOptions.as_bytes() works with minimal options"""
# Test with minimal options (only defaults)
options = ProgramOptions()

byte_options = options.as_bytes()

# Should at least have arch option (automatically set based on Device if not provided)
assert isinstance(byte_options, list)
assert len(byte_options) > 0
assert all(isinstance(opt, bytes) for opt in byte_options)
# The arch option should be present (automatically determined from current device)
assert any(b"-arch=" in opt for opt in byte_options)


def test_program_options_as_bytes_nvvm_backend():
"""Test that ProgramOptions.as_bytes() formats options correctly for NVVM backend"""
options = ProgramOptions(
arch="sm_80",
debug=True,
ftz=True,
prec_sqrt=False,
prec_div=True,
fma=False,
device_code_optimize=True,
)

byte_options = options.as_bytes("NVVM")

# Verify the return type
assert isinstance(byte_options, list)
assert all(isinstance(opt, bytes) for opt in byte_options)

# NVVM uses compute_ instead of sm_ and 1/0 for booleans, with hyphens in option names
assert b"-arch=compute_80" in byte_options
assert b"-g" in byte_options
assert b"-ftz=1" in byte_options
assert b"-prec-sqrt=0" in byte_options
assert b"-prec-div=1" in byte_options
assert b"-fma=0" in byte_options
assert b"-opt=3" in byte_options


def test_program_options_as_bytes_nvjitlink_backend():
"""Test that ProgramOptions.as_bytes() formats options correctly for nvJitLink backend"""
options = ProgramOptions(
arch="sm_80",
debug=True,
lineinfo=True,
max_register_count=32,
ftz=False,
prec_sqrt=True,
link_time_optimization=True,
)

byte_options = options.as_bytes("nvJitLink")

# Verify the return type
assert isinstance(byte_options, list)
assert all(isinstance(opt, bytes) for opt in byte_options)

# nvJitLink uses - prefix and true/false for booleans, with hyphens in option names
assert b"-arch=sm_80" in byte_options
assert b"-g" in byte_options
assert b"-lineinfo" in byte_options
assert b"-maxrregcount=32" in byte_options
assert b"-ftz=false" in byte_options
assert b"-prec-sqrt=true" in byte_options
assert b"-lto" in byte_options


def test_program_options_as_bytes_invalid_backend():
"""Test that ProgramOptions.as_bytes() raises error for invalid backend"""
options = ProgramOptions()

with pytest.raises(ValueError, match="Unsupported backend 'invalid'"):
options.as_bytes("invalid")