From 5557f627b47f0b4e696e086aa59beffb6e8bb268 Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 29 Dec 2023 17:45:32 +0800 Subject: [PATCH 1/2] [math] fix the default setting in `brainpy.math` --- brainpy/_src/deprecations.py | 13 ++++----- brainpy/_src/math/defaults.py | 10 ------- brainpy/_src/math/tests/test_defaults.py | 36 ++++++++++++++++++++++++ brainpy/math/__init__.py | 3 +- 4 files changed, 44 insertions(+), 18 deletions(-) create mode 100644 brainpy/_src/math/tests/test_defaults.py diff --git a/brainpy/_src/deprecations.py b/brainpy/_src/deprecations.py index 4719d982e..74a0103da 100644 --- a/brainpy/_src/deprecations.py +++ b/brainpy/_src/deprecations.py @@ -41,7 +41,6 @@ def f_input_or_monitor(): ''' - def _deprecate(msg): warnings.simplefilter('always', DeprecationWarning) # turn off filter warnings.warn(msg, category=DeprecationWarning, stacklevel=2) @@ -61,10 +60,10 @@ def new_func(*args, **kwargs): return new_func -def deprecation_getattr(module, deprecations, redirects=None): +def deprecation_getattr(module, deprecations, redirects=None, redirect_module=None): redirects = redirects or {} - def getattr(name): + def get_attr(name): if name in deprecations: message, fn = deprecations[name] if fn is None: @@ -72,14 +71,14 @@ def getattr(name): _deprecate(message) return fn if name in redirects: - return redirects[name] + return getattr(redirect_module, name) raise AttributeError(f"module {module!r} has no attribute {name!r}") - return getattr + return get_attr def deprecation_getattr2(module, deprecations): - def getattr(name): + def get_attr(name): if name in deprecations: old_name, new_name, fn = deprecations[name] message = f"{old_name} is deprecated. " @@ -91,4 +90,4 @@ def getattr(name): return fn raise AttributeError(f"module {module!r} has no attribute {name!r}") - return getattr + return get_attr diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index ad91fa6ab..19aca92cf 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -36,13 +36,3 @@ # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 -# redirects -redirects = {'mode': mode, - 'membrane_scaling': membrane_scaling, - 'dt': dt, - 'bool_': bool_, - 'int_': int_, - 'ti_int': ti_int, - 'float_': float_, - 'ti_float': ti_float, - 'complex_': complex_} diff --git a/brainpy/_src/math/tests/test_defaults.py b/brainpy/_src/math/tests/test_defaults.py new file mode 100644 index 000000000..9076829b7 --- /dev/null +++ b/brainpy/_src/math/tests/test_defaults.py @@ -0,0 +1,36 @@ +import unittest + +import brainpy.math as bm + + +class TestDefaults(unittest.TestCase): + def test_dt(self): + with bm.environment(dt=1.0): + self.assertEqual(bm.dt, 1.0) + self.assertEqual(bm.get_dt(), 1.0) + + def test_bool(self): + with bm.environment(bool_=bm.int32): + self.assertTrue(bm.bool_ == bm.int32) + self.assertTrue(bm.get_bool() == bm.int32) + + def test_int(self): + with bm.environment(int_=bm.int32): + self.assertTrue(bm.int == bm.int32) + self.assertTrue(bm.get_int() == bm.int32) + + def test_float(self): + with bm.environment(float_=bm.float32): + self.assertTrue(bm.float_ == bm.float32) + self.assertTrue(bm.get_float() == bm.float32) + + def test_complex(self): + with bm.environment(complex_=bm.complex64): + self.assertTrue(bm.complex_ == bm.complex64) + self.assertTrue(bm.get_complex() == bm.complex64) + + def test_mode(self): + mode = bm.TrainingMode() + with bm.environment(mode=mode): + self.assertTrue(bm.mode == mode) + self.assertTrue(bm.get_mode() == mode) diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index d45df89d5..cf7a766b4 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -97,5 +97,6 @@ "Use brainpy.math.event.info instead.", event.info), } -__getattr__ = deprecation_getattr(__name__, __deprecations, defaults.redirects) + +__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) del deprecation_getattr, defaults From 260bf135902dd3033328d6924b362a9526d0e82b Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 29 Dec 2023 17:50:58 +0800 Subject: [PATCH 2/2] [doc] upgrade math doc --- docs/apis/brainpy.math.defaults.rst | 22 ++++++++++++++++++++++ docs/apis/brainpy.math.op_register.rst | 16 ++++++++++++++++ docs/apis/math.rst | 1 + 3 files changed, 39 insertions(+) create mode 100644 docs/apis/brainpy.math.defaults.rst diff --git a/docs/apis/brainpy.math.defaults.rst b/docs/apis/brainpy.math.defaults.rst new file mode 100644 index 000000000..515391dcf --- /dev/null +++ b/docs/apis/brainpy.math.defaults.rst @@ -0,0 +1,22 @@ + +Default Math Parameters +======================= + +.. currentmodule:: brainpy.math +.. automodule:: brainpy.math + +.. autosummary:: + :toctree: generated/ + :nosignatures: + + mode + membrane_scaling + dt + bool_ + int_ + ti_int + float_ + ti_float + complex_ + + diff --git a/docs/apis/brainpy.math.op_register.rst b/docs/apis/brainpy.math.op_register.rst index 7010b64eb..a50b4d300 100644 --- a/docs/apis/brainpy.math.op_register.rst +++ b/docs/apis/brainpy.math.op_register.rst @@ -6,6 +6,22 @@ Operator Registration :depth: 1 + +General Operator Customization Interface +---------------------------------------- + +.. currentmodule:: brainpy.math +.. automodule:: brainpy.math + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + XLACustomOp + + + CPU Operator Customization with Numba ------------------------------------- diff --git a/docs/apis/math.rst b/docs/apis/math.rst index e3f0b765a..f4b778aba 100644 --- a/docs/apis/math.rst +++ b/docs/apis/math.rst @@ -24,6 +24,7 @@ dynamics programming. For more information and usage examples, please refer to t :maxdepth: 1 brainpy.math.rst + brainpy.math.defaults.rst brainpy.math.delayvars.rst brainpy.math.oo_transform.rst brainpy.math.pre_syn_post.rst