Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
default_language_version:
python: python3

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.14.10
hooks:
# Run the linter
- id: ruff
args: [ --fix, --config, pyproject.toml ]
# Run the formatter
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
hooks:
- id: codespell
additional_dependencies:
- tomli
5 changes: 2 additions & 3 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ cuslines/
cu_propagate_seeds.py # SeedBatchPropagator: chunked seed processing
cu_direction_getters.py # Direction getter ABC + Boot/Prob/PTT implementations
cutils.py # REAL_DTYPE, REAL3_DTYPE, checkCudaErrors(), ModelType enum
_globals.py # AUTO-GENERATED from globals.h (never edit manually)
_globals.py # Global constants useful for all languages
cuda_c/ # CUDA kernel source
globals.h # Source-of-truth for constants (REAL_SIZE, thread config)
globals.h # CUDA specific global constants
generate_streamlines_cuda.cu, boot.cu, ptt.cu, tracking_helpers.cu, utils.cu
cudamacro.h, cuwsort.cuh, ptt.cuh, disc.h
metal/ # Metal backend (mirrors cuda_python/)
Expand Down Expand Up @@ -82,7 +82,6 @@ Each has `from_dipy_*()` class methods for initialization from DIPY models.

## Critical Conventions

- **`_globals.py` is auto-generated** from `cuslines/cuda_c/globals.h` during `setup.py` build via `defines_to_python()`. Never edit it manually; change `globals.h` and rebuild.
- **GPU arrays must be C-contiguous** — always use `np.ascontiguousarray()` and project scalar types (`REAL_DTYPE`, `REAL_SIZE` from `cutils.py` or `mutils.py`).
- **All CUDA API calls must be wrapped** with `checkCudaErrors()`.
- **Angle units**: CLI accepts degrees, internals convert to radians before the GPU layer.
Expand Down
20 changes: 16 additions & 4 deletions cuslines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,37 @@ def _detect_backend():
BACKEND = _detect_backend()

if BACKEND == "metal":
from cuslines.metal import (
MetalBootDirectionGetter as BootDirectionGetter,
)
from cuslines.metal import (
MetalGPUTracker as GPUTracker,
)
from cuslines.metal import (
MetalProbDirectionGetter as ProbDirectionGetter,
)
from cuslines.metal import (
MetalPttDirectionGetter as PttDirectionGetter,
MetalBootDirectionGetter as BootDirectionGetter,
)
elif BACKEND == "cuda":
from cuslines.cuda_python import (
BootDirectionGetter,
GPUTracker,
ProbDirectionGetter,
PttDirectionGetter,
BootDirectionGetter,
)
elif BACKEND == "webgpu":
from cuslines.webgpu import (
WebGPUTracker as GPUTracker,
WebGPUBootDirectionGetter as BootDirectionGetter,
)
from cuslines.webgpu import (
WebGPUProbDirectionGetter as ProbDirectionGetter,
)
from cuslines.webgpu import (
WebGPUPttDirectionGetter as PttDirectionGetter,
WebGPUBootDirectionGetter as BootDirectionGetter,
)
from cuslines.webgpu import (
WebGPUTracker as GPUTracker,
)
else:
raise ImportError(
Expand Down
54 changes: 38 additions & 16 deletions cuslines/boot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,62 @@
from dipy.reconst import shm


def prepare_opdt(gtab, sphere, sh_order_max=6, full_basis=False,
sh_lambda=0.006, min_signal=1):
def prepare_opdt(
gtab, sphere, sh_order_max=6, full_basis=False, sh_lambda=0.006, min_signal=1
):
"""Build bootstrap matrices for the OPDT model.

Returns dict with keys: model_type, min_signal, H, R, delta_b,
delta_q, sampling_matrix, b0s_mask.
"""
sampling_matrix, _, _ = shm.real_sh_descoteaux(
sh_order_max, sphere.theta, sphere.phi,
full_basis=full_basis, legacy=True,
sh_order_max,
sphere.theta,
sphere.phi,
full_basis=full_basis,
legacy=True,
)
model = shm.OpdtModel(
gtab, sh_order_max=sh_order_max, smooth=sh_lambda,
gtab,
sh_order_max=sh_order_max,
smooth=sh_lambda,
min_signal=min_signal,
)
delta_b, delta_q = model._fit_matrix

H, R = _hat_and_lcr(gtab, model, sh_order_max)

return dict(
model_type="OPDT", min_signal=min_signal,
H=H, R=R, delta_b=delta_b, delta_q=delta_q,
sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask,
model_type="OPDT",
min_signal=min_signal,
H=H,
R=R,
delta_b=delta_b,
delta_q=delta_q,
sampling_matrix=sampling_matrix,
b0s_mask=gtab.b0s_mask,
)


def prepare_csa(gtab, sphere, sh_order_max=6, full_basis=False,
sh_lambda=0.006, min_signal=1):
def prepare_csa(
gtab, sphere, sh_order_max=6, full_basis=False, sh_lambda=0.006, min_signal=1
):
"""Build bootstrap matrices for the CSA model.

Returns dict with keys: model_type, min_signal, H, R, delta_b,
delta_q, sampling_matrix, b0s_mask.
"""
sampling_matrix, _, _ = shm.real_sh_descoteaux(
sh_order_max, sphere.theta, sphere.phi,
full_basis=full_basis, legacy=True,
sh_order_max,
sphere.theta,
sphere.phi,
full_basis=full_basis,
legacy=True,
)
model = shm.CsaOdfModel(
gtab, sh_order_max=sh_order_max, smooth=sh_lambda,
gtab,
sh_order_max=sh_order_max,
smooth=sh_lambda,
min_signal=min_signal,
)
delta_b = model._fit_matrix
Expand All @@ -55,9 +72,14 @@ def prepare_csa(gtab, sphere, sh_order_max=6, full_basis=False,
H, R = _hat_and_lcr(gtab, model, sh_order_max)

return dict(
model_type="CSA", min_signal=min_signal,
H=H, R=R, delta_b=delta_b, delta_q=delta_q,
sampling_matrix=sampling_matrix, b0s_mask=gtab.b0s_mask,
model_type="CSA",
min_signal=min_signal,
H=H,
R=R,
delta_b=delta_b,
delta_q=delta_q,
sampling_matrix=sampling_matrix,
b0s_mask=gtab.b0s_mask,
)


Expand Down
Loading
Loading