fix: ensure moffat derivs are not nan and increase performance for trunc=0 #188

Draft
beckermr wants to merge 19 commits into main from debug-moffat-perf

Conversation

@beckermr
Collaborator

This PR ensures that the Moffat derivatives with respect to scale_radius are not NaN. I also added equivalent tests for the other profiles.
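A check of this kind could be sketched as follows. This is only an illustration: `moffat_profile` and `summed_profile` are hypothetical stand-ins written for this example (a toy real-space Moffat), not the library's profile or the PR's actual tests.

```python
import jax
import jax.numpy as jnp


def moffat_profile(r, scale_radius, beta=2.5):
    # Unnormalized real-space Moffat profile (toy stand-in, an assumption
    # for illustration; not the library's implementation).
    return (1.0 + (r / scale_radius) ** 2) ** (-beta)


def summed_profile(scale_radius):
    # Scalar function of scale_radius so that jax.grad applies directly.
    r = jnp.linspace(0.0, 5.0, 64)
    return jnp.sum(moffat_profile(r, scale_radius))


# Differentiate with respect to scale_radius and check the result is finite.
grad_value = jax.grad(summed_profile)(1.3)
grads_are_finite = bool(jnp.isfinite(grad_value))
```

The same pattern (take `jax.grad` with respect to one parameter, then assert `jnp.isfinite`) carries over to any profile parameter.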

@beckermr marked this pull request as draft February 10, 2026 13:12
@codspeed-hq

codspeed-hq bot commented Feb 10, 2026

Merging this PR will improve performance by ×3.7

⚠️ Unknown Walltime execution environment detected

Using the Walltime instrument on standard Hosted Runners will lead to inconsistent data.

For the most accurate results, we recommend using CodSpeed Macro Runners: bare-metal machines fine-tuned for performance measurement consistency.

⚡ 4 improved benchmarks
✅ 32 untouched benchmarks

Performance Changes

| Mode | Benchmark | BASE | HEAD | Efficiency |
|---|---|---|---|---|
| Simulation | test_benchmark_moffat_conv[run] | 10.4 s | 2.8 s | ×3.7 |
| Simulation | test_benchmark_moffat_conv_grad[run] | 13.6 s | 6.2 s | ×2.2 |
| WallTime | test_benchmark_moffat_conv[run] | 681.8 ms | 301.2 ms | ×2.3 |
| WallTime | test_benchmark_moffat_conv_grad[run] | 1,279.6 ms | 899.7 ms | +42.22% |

Comparing debug-moffat-perf (ecba9da) with main (51df286)

@beckermr
Collaborator Author

pre-commit.ci autofix

@beckermr
Collaborator Author

@EiffL This version of the moffat uses an interpolant in k space internally and is a lot faster. I need to work on the accuracy, but I think it may help.

I am confused about one thing with the "workspace" trick, which I pulled from jax-cosmo. Do we need to pass the workspace into the tree_flatten functions?

    def tree_flatten(self):
        """This function flattens the GSObject into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        # Define the children nodes of the PyTree that need tracing
        children = (self.params,)
        # Define auxiliary static data that doesn’t need to be traced
        aux_data = {"gsparams": self.gsparams}
        return (children, aux_data)

The derivs appear correct in the tests above, but there is something basic I don't follow and I suspect things might break in other contexts.

@EiffL
Member

EiffL commented Feb 10, 2026

Workspaces are pretty dangerous generally, and there is a known issue in jax-cosmo that shows up from time to time. I would avoid propagating the workspace in the tree flatten, and instead rebuild it every time the object is inflated.
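A minimal sketch of that option, assuming a hypothetical `ToyProfile` class and a made-up lookup table (neither comes from the PR): the cached workspace appears in neither `children` nor `aux_data`, so an unflattened copy starts with an empty cache and rebuilds it on first use.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyProfile:
    """Hypothetical profile with a lazily built lookup table."""

    def __init__(self, scale_radius, n_table=32):
        self.scale_radius = scale_radius
        self.n_table = n_table
        self._workspace = None  # cache; deliberately never flattened

    @property
    def workspace(self):
        # Build the table on first access from the current parameters.
        if self._workspace is None:
            k = jnp.linspace(0.0, 10.0, self.n_table)
            self._workspace = (k, jnp.exp(-k * self.scale_radius))
        return self._workspace

    def tree_flatten(self):
        children = (self.scale_radius,)  # traced by JAX
        aux_data = (self.n_table,)       # static and hashable
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (scale_radius,) = children
        (n_table,) = aux_data
        # A fresh object: its cache is empty and will be rebuilt on demand.
        return cls(scale_radius, n_table=n_table)


def value_at(profile, k0):
    k, values = profile.workspace
    return jnp.interp(k0, k, values)


profile = ToyProfile(1.0)
_ = profile.workspace  # populate the cache on the original object
leaves, treedef = jax.tree_util.tree_flatten(profile)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
jitted_value = jax.jit(value_at)(profile, 0.5)
```

Inside `jax.jit`, the function receives an unflattened copy, so any table built from traced parameters stays on that temporary copy rather than on the original object.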

But here, I think the best thing to do is not to have a workspace at all, but just compute the interpolation array at the beginning of each function that needs to compute a Bessel function, use it, and discard it.

This way it's safe, and still way more efficient than computing the value for all elements of an image.
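The compute, use, and discard pattern could look roughly like the sketch below. The function names, the table range, and the table size are assumptions chosen for illustration, and `expensive_radial_function` is a cheap placeholder for the real Bessel-based evaluation.

```python
import jax
import jax.numpy as jnp


def expensive_radial_function(k, scale_radius):
    # Hypothetical stand-in for a costly special-function evaluation.
    return jnp.exp(-k * scale_radius)


@jax.jit
def kvalue_image(kx, ky, scale_radius):
    k = jnp.sqrt(kx**2 + ky**2)
    # Build a small 1D table locally: 256 expensive evaluations instead of
    # one per pixel. Nothing is stored on any object.
    k_table = jnp.linspace(0.0, 10.0, 256)
    f_table = expensive_radial_function(k_table, scale_radius)
    # Linear interpolation onto every pixel; the table is dropped on return.
    return jnp.interp(k, k_table, f_table)


kx, ky = jnp.meshgrid(jnp.linspace(-3.0, 3.0, 64), jnp.linspace(-3.0, 3.0, 64))
image = kvalue_image(kx, ky, 1.0)
exact = expensive_radial_function(jnp.sqrt(kx**2 + ky**2), 1.0)
max_abs_error = float(jnp.max(jnp.abs(image - exact)))

# Gradients with respect to scale_radius flow through the table values.
grad_wrt_scale = jax.grad(lambda s: jnp.sum(kvalue_image(kx, ky, s)))(1.0)
```

Because the table is an ordinary local array built from the traced parameter, there is no state to flatten and nothing that can outlive the trace.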

@beckermr
Collaborator Author

> there is a known issue in jax-cosmo that shows up from time to time

Can you point me to this issue or an example of it?

@EiffL
Member

EiffL commented Feb 10, 2026

DifferentiableUniverseInitiative/jax_cosmo#140

@beckermr changed the title from "fix: ensure moffat derivs are not nan" to "fix: ensure moffat derivs are not nan and increase performance for trunc=0" Feb 11, 2026

2 participants