fix: ensure moffat derivs are not nan and increase performance for trunc=0#188
fix: ensure moffat derivs are not nan and increase performance for trunc=0#188
trunc=0#188Conversation
Merging this PR will improve performance by ×3.7
|
| Mode | Benchmark | BASE |
HEAD |
Efficiency | |
|---|---|---|---|---|---|
| ⚡ | Simulation | test_benchmark_moffat_conv[run] |
10.4 s | 2.8 s | ×3.7 |
| ⚡ | Simulation | test_benchmark_moffat_conv_grad[run] |
13.6 s | 6.2 s | ×2.2 |
| ⚡ | WallTime | test_benchmark_moffat_conv[run] |
681.8 ms | 301.2 ms | ×2.3 |
| ⚡ | WallTime | test_benchmark_moffat_conv_grad[run] |
1,279.6 ms | 899.7 ms | +42.22% |
Comparing debug-moffat-perf (ecba9da) with main (51df286)
|
pre-commit.ci autofix |
for more information, see https://pre-commit.ci
|
@EiffL This version of the moffat uses an interpolant in k space internally and is a lot faster. I need to work on the accuracy, but I think it may help. I am confused on one thing with the "workspace" trick which I pulled from def tree_flatten(self):
"""This function flattens the GSObject into a list of children
nodes that will be traced by JAX and auxiliary static data."""
# Define the children nodes of the PyTree that need tracing
children = (self.params,)
# Define auxiliary static data that doesn’t need to be traced
aux_data = {"gsparams": self.gsparams}
return (children, aux_data)The derivs appear correct in the tests above, but there is something basic I don't follow and I suspect things might break in other contexts. |
|
workspaces are pretty dangerous generally, and there is a known issue in jax-cosmo that shows up from time to time. I would avoid propagating the workspace in the tree flatten, and instead rebuild it everytime the object is inflated. But here, I think here the best thing to do is not to have a workspace, but just compute the interpolation array at the beginning of each function that needs to compute a bessel function, use it, and discard it. This way it's safe, and still way more efficient than computing the value for all elements of an image. |
Can you point me to this issue or an example of it? |
…ers/JAX-GalSim into debug-moffat-perf
trunc=0
This PR ensures that the Moffat derivs wrt scale_radius are not Nan. I added tests as well for the other profiles.