diff --git a/brainpy/_src/_delay.py b/brainpy/_src/_delay.py index a646fd159..bac73e012 100644 --- a/brainpy/_src/_delay.py +++ b/brainpy/_src/_delay.py @@ -144,7 +144,7 @@ def register_entry( delay_type = 'homo' else: delay_type = 'heter' - delay_step = bm.Array(delay_step) + delay_step = delay_step elif callable(delay_step): delay_step = delay_step(self.delay_target_shape) delay_type = 'heter' diff --git a/brainpy/_src/connect/tests/test_random_conn_visualize.py b/brainpy/_src/connect/tests/test_random_conn_visualize.py index 9cd64821c..ba0d95f13 100644 --- a/brainpy/_src/connect/tests/test_random_conn_visualize.py +++ b/brainpy/_src/connect/tests/test_random_conn_visualize.py @@ -2,176 +2,178 @@ import pytest +pytest.skip('skip', allow_module_level=True) + import brainpy as bp def test_random_fix_pre1(): - for num in [0.4, 20]: - conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat1 = conn1.require(bp.connect.CONN_MAT) + for num in [0.4, 20]: + conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat1 = conn1.require(bp.connect.CONN_MAT) - conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat2 = conn2.require(bp.connect.CONN_MAT) + conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat2 = conn2.require(bp.connect.CONN_MAT) - print() - print(f'num = {num}') - print('conn_mat 1\n', mat1) - print(mat1.sum()) - print('conn_mat 2\n', mat2) - print(mat2.sum()) + print() + print(f'num = {num}') + print('conn_mat 1\n', mat1) + print(mat1.sum()) + print('conn_mat 2\n', mat2) + print(mat2.sum()) - assert bp.math.array_equal(mat1, mat2) - bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) + assert bp.math.array_equal(mat1, mat2) + bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) def test_random_fix_pre2(): - for num in [0.5, 3]: - conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4) - mat1 = conn1.require(bp.connect.CONN_MAT) - print() - print(mat1) + for num in [0.5, 3]: + conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4) + mat1 = conn1.require(bp.connect.CONN_MAT) + print() + print(mat1) - bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num) + bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num) def test_random_fix_pre3(): - with pytest.raises(bp.errors.ConnectorError): - conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4) - conn1.require(bp.connect.CONN_MAT) + with pytest.raises(bp.errors.ConnectorError): + conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4) + conn1.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4') + bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4') def test_random_fix_post1(): - for num in [0.4, 20]: - conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat1 = conn1.require(bp.connect.CONN_MAT) + for num in [0.4, 20]: + conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat1 = conn1.require(bp.connect.CONN_MAT) - conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat2 = conn2.require(bp.connect.CONN_MAT) + conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat2 = conn2.require(bp.connect.CONN_MAT) - print() - print('conn_mat 1\n', mat1) - print('conn_mat 2\n', mat2) + print() + print('conn_mat 1\n', mat1) + print('conn_mat 2\n', mat2) - assert bp.math.array_equal(mat1, mat2) - bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) + assert bp.math.array_equal(mat1, mat2) + bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) def test_random_fix_post2(): - for num in [0.5, 3]: - conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4) - mat1 = conn1.require(bp.connect.CONN_MAT) - print(mat1) - bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num) + for num in [0.5, 3]: + conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4) + mat1 = conn1.require(bp.connect.CONN_MAT) + print(mat1) + bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num) def test_random_fix_post3(): - with pytest.raises(bp.errors.ConnectorError): - conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) - conn1.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4') + with pytest.raises(bp.errors.ConnectorError): + conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) + conn1.require(bp.connect.CONN_MAT) + bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4') def test_gaussian_prob1(): - conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100') def test_gaussian_prob2(): - conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50)) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50)) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)') def test_gaussian_prob3(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50)) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50)) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)') def test_gaussian_prob4(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10)) - conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - mat = conn.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)') + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10)) + conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + mat = conn.require(bp.connect.CONN_MAT) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)') def test_SmallWorld1(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=10, post_size=10) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=10, post_size=10) - mat = conn.require(bp.connect.CONN_MAT) + mat = conn.require(bp.connect.CONN_MAT) - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10') + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10') def test_SmallWorld3(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True) - conn(pre_size=20, post_size=20) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True) + conn(pre_size=20, post_size=20) - mat = conn.require(bp.connect.CONN_MAT) + mat = conn.require(bp.connect.CONN_MAT) - print('conn_mat', mat) + print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20') + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20') def test_SmallWorld2(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5) - conn(pre_size=(100,), post_size=(100,)) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5) + conn(pre_size=(100,), post_size=(100,)) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)') + + +def test_ScaleFreeBA(): + conn = bp.connect.ScaleFreeBA(m=2) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, bp.connect.PRE_IDS, bp.connect.POST_IDS, bp.connect.PRE2POST, bp.connect.POST_IDS) print() print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)') - - -def test_ScaleFreeBA(): - conn = bp.connect.ScaleFreeBA(m=2) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size)) + bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size)) def test_ScaleFreeBADual(): - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) def test_PowerLaw(): - conn = bp.connect.PowerLaw(m=3, p=0.4) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) + conn = bp.connect.PowerLaw(m=3, p=0.4) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index 3102bc1d0..de559de56 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -49,7 +49,6 @@ # operators from .op_register import * from .pre_syn_post import * -from .surrogate._compt import * from . import surrogate, event, sparse, jitconn # Variable and Objects for object-oriented JAX transformations diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py index 213185df1..0eb391458 100644 --- a/brainpy/_src/math/compat_numpy.py +++ b/brainpy/_src/math/compat_numpy.py @@ -103,6 +103,10 @@ _max = max +def _return(a): + return Array(a) + + def fill_diagonal(a, val, inplace=True): if a.ndim < 2: raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') @@ -120,30 +124,30 @@ def fill_diagonal(a, val, inplace=True): def zeros(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def ones(shape, dtype=None): - return Array(jnp.ones(shape, dtype=dtype)) + return _return(jnp.ones(shape, dtype=dtype)) def empty(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def zeros_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def ones_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.ones_like(a, dtype=dtype, shape=shape)) + return _return(jnp.ones_like(a, dtype=dtype, shape=shape)) def empty_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: @@ -155,7 +159,7 @@ def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: leaves = [_as_jax_array_(l) for l in leaves] a = tree_unflatten(tree, leaves) res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - return Array(res) + return _return(res) def asarray(a, dtype=None, order=None): @@ -167,13 +171,13 @@ def asarray(a, dtype=None, order=None): leaves = [_as_jax_array_(l) for l in leaves] arrays = tree_unflatten(tree, leaves) res = jnp.asarray(a=arrays, dtype=dtype, order=order) - return Array(res) + return _return(res) def arange(*args, **kwargs): args = [_as_jax_array_(a) for a in args] kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.arange(*args, **kwargs)) + return _return(jnp.arange(*args, **kwargs)) def linspace(*args, **kwargs): @@ -181,15 +185,15 @@ def linspace(*args, **kwargs): kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} res = jnp.linspace(*args, **kwargs) if isinstance(res, tuple): - return Array(res[0]), res[1] + return _return(res[0]), res[1] else: - return Array(res) + return _return(res) def logspace(*args, **kwargs): args = [_as_jax_array_(a) for a in args] kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.logspace(*args, **kwargs)) + return _return(jnp.logspace(*args, **kwargs)) def asanyarray(a, dtype=None, order=None): diff --git a/brainpy/_src/math/compat_tensorflow.py b/brainpy/_src/math/compat_tensorflow.py index 7e9168cfa..e9e87e24c 100644 --- a/brainpy/_src/math/compat_tensorflow.py +++ b/brainpy/_src/math/compat_tensorflow.py @@ -259,13 +259,13 @@ def segment_sum(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_sum(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_sum(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_prod(data: Union[Array, jnp.ndarray], @@ -311,13 +311,13 @@ def segment_prod(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_prod(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_prod(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_max(data: Union[Array, jnp.ndarray], @@ -363,13 +363,13 @@ def segment_max(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_max(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_max(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_min(data: Union[Array, jnp.ndarray], @@ -415,13 +415,13 @@ def segment_min(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_min(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_min(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def cast(x, dtype): diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 9f3c50454..eab8b9b66 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -12,32 +12,37 @@ # Default computation mode. mode = NonBatchingMode() -# '''Default computation mode.''' +# Default computation mode. membrane_scaling = IdScaling() -# '''Default time step.''' +# Default time step. dt = 0.1 -# '''Default bool data type.''' +# Default bool data type. bool_ = jnp.bool_ -# '''Default integer data type.''' +# Default integer data type. int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -# '''Default float data type.''' +# Default float data type. float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 -# '''Default complex data type.''' +# Default complex data type. complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 # register brainpy object as pytree bp_object_as_pytree = False + +# default return array type +numpy_func_return = 'bp_array' # 'bp_array','jax_array' + + if ti is not None: - # '''Default integer data type in Taichi.''' + # Default integer data type in Taichi. ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - # '''Default float data type in Taichi.''' + # Default float data type in Taichi. ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 else: diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index d49e70f51..ebbb8b6a3 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -169,6 +169,7 @@ def __init__( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ) -> None: super().__init__() @@ -208,6 +209,12 @@ def __init__( assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.' self.old_bp_object_as_pytree = defaults.bp_object_as_pytree + if numpy_func_return is not None: + assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.' + assert numpy_func_return in ['bp_array', 'jax_array'], \ + f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.' + self.old_numpy_func_return = defaults.numpy_func_return + self.dt = dt self.mode = mode self.membrane_scaling = membrane_scaling @@ -217,6 +224,7 @@ def __init__( self.int_ = int_ self.bool_ = bool_ self.bp_object_as_pytree = bp_object_as_pytree + self.numpy_func_return = numpy_func_return def __enter__(self) -> 'environment': if self.dt is not None: set_dt(self.dt) @@ -228,6 +236,7 @@ def __enter__(self) -> 'environment': if self.complex_ is not None: set_complex(self.complex_) if self.bool_ is not None: set_bool(self.bool_) if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree + if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.numpy_func_return return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: @@ -240,6 +249,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.complex_ is not None: set_complex(self.old_complex) if self.bool_ is not None: set_bool(self.old_bool) if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree + if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.old_numpy_func_return def clone(self): return self.__class__(dt=self.dt, @@ -250,7 +260,8 @@ def clone(self): complex_=self.complex_, float_=self.float_, int_=self.int_, - bp_object_as_pytree=self.bp_object_as_pytree) + bp_object_as_pytree=self.bp_object_as_pytree, + numpy_func_return=self.numpy_func_return) def __eq__(self, other): return id(self) == id(other) @@ -279,6 +290,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ): super().__init__(dt=dt, x64=x64, @@ -288,7 +300,8 @@ def __init__( bool_=bool_, membrane_scaling=membrane_scaling, mode=modes.TrainingMode(batch_size), - bp_object_as_pytree=bp_object_as_pytree) + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) class batching_environment(environment): @@ -315,6 +328,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ): super().__init__(dt=dt, x64=x64, @@ -324,7 +338,8 @@ def __init__( bool_=bool_, mode=modes.BatchingMode(batch_size), membrane_scaling=membrane_scaling, - bp_object_as_pytree=bp_object_as_pytree) + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) def set( @@ -337,6 +352,7 @@ def set( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ): """Set the default computation environment. @@ -360,6 +376,8 @@ def set( The bool data type. bp_object_as_pytree: bool Whether to register brainpy object as pytree. + numpy_func_return: str + The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. """ if dt is not None: assert isinstance(dt, float), '"dt" must a float.' @@ -396,6 +414,10 @@ def set( if bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree + if numpy_func_return is not None: + assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".' + defaults.__dict__['numpy_func_return'] = numpy_func_return + set_environment = set diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 6c0a2ed47..181ee5520 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -14,6 +14,12 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 6fb8d02ec..dd1bafded 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -11,7 +11,12 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + shapes = [(100, 200), (1000, 10)] diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 67c18124f..e42bd3695 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -11,6 +11,12 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + shapes = [(100, 200), (1000, 10)] diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index cf2b2343d..791c8d9fe 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -10,6 +10,7 @@ from jax.tree_util import register_pytree_node_class from brainpy.errors import MathError +from . import defaults bm = None @@ -41,8 +42,8 @@ def _check_input_array(array): def _return(a): - if isinstance(a, jax.Array) and a.ndim > 0: - return Array(a) + if defaults.numpy_func_return == 'bp_array' and isinstance(a, jax.Array) and a.ndim > 0: + return Array(a) return a @@ -1087,7 +1088,7 @@ def unsqueeze(self, dim: int) -> 'Array': See :func:`brainpy.math.unsqueeze` """ - return Array(jnp.expand_dims(self.value, dim)) + return _return(jnp.expand_dims(self.value, dim)) def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': """ @@ -1119,7 +1120,7 @@ def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': self.expand_dims(axis)==self.expand_dims(axis[0]).expand_dims(axis[1])... expand_dims(axis[len(axis)-1]) """ - return Array(jnp.expand_dims(self.value, axis)) + return _return(jnp.expand_dims(self.value, axis)) def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': """ @@ -1136,9 +1137,7 @@ def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': typically not contiguous. Furthermore, more than one element of a expanded array may refer to a single memory location. """ - if not isinstance(array, Array): - array = Array(array) - return Array(jnp.broadcast_to(self.value, array.value.shape)) + return _return(jnp.broadcast_to(self.value, array)) def pow(self, index: int): return _return(self.value ** index) @@ -1228,7 +1227,7 @@ def absolute_(self): return self.abs_() def mul(self, value): - return Array(self.value * value) + return _return(self.value * value) def mul_(self, value): """ @@ -1404,7 +1403,7 @@ def clip_(self, return self def clone(self) -> 'Array': - return Array(self.value.copy()) + return _return(self.value.copy()) def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array': self.value = jnp.copy(_as_jax_array_(src)) @@ -1423,7 +1422,7 @@ def cov_with( fweights = _as_jax_array_(fweights) aweights = _as_jax_array_(aweights) r = jnp.cov(self.value, y, rowvar, bias, fweights, aweights) - return Array(r) + return _return(r) def expand(self, *sizes) -> 'Array': """ @@ -1459,7 +1458,7 @@ def expand(self, *sizes) -> 'Array': raise ValueError( f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton ' f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}') - return Array(jnp.broadcast_to(self.value, sizes_list)) + return _return(jnp.broadcast_to(self.value, sizes_list)) def tree_flatten(self): return (self.value,), None diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 53346a7d1..b21ed2af3 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -557,7 +557,7 @@ def load_state_dict( missing_keys = [] unexpected_keys = [] for name, node in nodes.items(): - r = node.load_state(state_dict[name], **kwargs) + r = node.load_state(state_dict[name] if name in state_dict else {}, **kwargs) if r is not None: missing, unexpected = r missing_keys.extend([f'{name}.{key}' for key in missing]) diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index c6f8f90d4..ebad7eb06 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -106,8 +106,7 @@ def update(self, x): self.assertTrue(len(obj.nodes()) == 7) print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", - obj.nodes(method='relative').keys()) + print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) # print(jax.tree_util.tree_structure(obj)) with bm.environment(mode=bm.TrainingMode()): @@ -116,8 +115,7 @@ def update(self, x): self.assertTrue(len(obj.nodes()) == 7) print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", - obj.nodes(method='relative').keys()) + print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) # print(jax.tree_util.tree_structure(obj)) @@ -248,5 +246,49 @@ def test1(self): print() +class TestStateSavingAndLoading(unittest.TestCase): + def test_load_states(self): + class Object(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.l1 = bp.layers.Dense(5, 10) + self.ls = bm.NodeList([bp.layers.Dense(10, 4), + bp.layers.Activation(bm.tanh), + bp.layers.Dropout(0.1), + bp.layers.Dense(4, 5), + bp.layers.Activation(bm.relu)]) + self.lif = bp.dyn.LifRef(5) + + def update(self, x): + x = self.l1(x) + for l in self.ls: + x = l(x) + return x + + with bm.training_environment(): + obj = Object() + variables = {k: dict(n.vars()) for k, n in obj.nodes(include_self=False).items()} + variables = {k: v for k, v in variables.items() if len(v) > 0} + + all_states = obj.state_dict() + all_states = {k: v for k, v in all_states.items() if len(v) > 0} + print(set(all_states.keys())) + print(set(variables.keys())) + + def not_close(x, y): + assert not bm.allclose(x, y) + def all_close(x, y): + assert bm.allclose(x, y) + + jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + + random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + + obj.load_state_dict(random_state) + jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array) + + diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index ed687eea5..21c222c00 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -3,6 +3,6 @@ compile_cpu_signature_with_numba) from .base import XLACustomOp from .utils import register_general_batching -from .taichi_aot_based import clean_caches, check_kernels_count +from .taichi_aot_based import clear_taichi_aot_caches, count_taichi_aot_kernels from .base import XLACustomOp from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index f9328906e..595460ea0 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -8,7 +8,7 @@ import re import shutil from functools import partial, reduce -from typing import Any, Sequence +from typing import Any, Sequence, Union import jax.core import numpy as np @@ -16,14 +16,17 @@ from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call -from brainpy.errors import PackageMissingError from brainpy._src.dependency_check import (import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops) +from brainpy.errors import PackageMissingError from .utils import _shape_to_layout -### UTILS ### +taichi_cache_path = None + + +# --- UTILS ### # get the path of home directory on Linux, Windows, Mac def get_home_dir(): @@ -43,8 +46,18 @@ def encode_md5(source: str) -> str: return md5.hexdigest() + # check kernels count -def check_kernels_count() -> int: +def count_taichi_aot_kernels() -> int: + """ + Count the number of AOT compiled kernels. + + Returns + ------- + kernels_count: int + The number of AOT compiled kernels. + + """ if not os.path.exists(kernels_aot_path): return 0 kernels_count = 0 @@ -54,23 +67,37 @@ def check_kernels_count() -> int: kernels_count += len(dir2) return kernels_count -# clean caches -def clean_caches(kernels_name: list[str]=None): - if kernels_name is None: - if not os.path.exists(kernels_aot_path): - raise FileNotFoundError("The kernels cache folder does not exist. \ - Please define a kernel using `taichi.kernel` \ - and customize the operator using `bm.XLACustomOp` \ - before calling the operator.") - shutil.rmtree(kernels_aot_path) - print('Clean all kernel\'s cache successfully') + +def clear_taichi_aot_caches(kernels: Union[str, Sequence[str]] = None): + """ + Clean the cache of the AOT compiled kernels. + + Parameters + ---------- + kernels: str or list of str + The name of the kernel to be cleaned. If None, all the kernels will be cleaned. + """ + if kernels is None: + global taichi_cache_path + if taichi_cache_path is None: + from taichi._lib.utils import import_ti_python_core + taichi_cache_path = import_ti_python_core().get_repo_dir() + # clean taichi cache + if os.path.exists(taichi_cache_path): + shutil.rmtree(taichi_cache_path) + # clean brainpy-taichi AOT cache + if os.path.exists(kernels_aot_path): + shutil.rmtree(kernels_aot_path) return - for kernel_name in kernels_name: - try: + if isinstance(kernels, str): + kernels = [kernels] + if not isinstance(kernels, list): + raise TypeError(f'kernels_name must be a list of str, but got {type(kernels)}') + # clear brainpy kernel cache + for kernel_name in kernels: + if os.path.exists(os.path.join(kernels_aot_path, kernel_name)): shutil.rmtree(os.path.join(kernels_aot_path, kernel_name)) - except FileNotFoundError: - raise FileNotFoundError(f'Kernel {kernel_name} does not exist.') - print('Clean kernel\'s cache successfully') + # TODO # not a very good way @@ -104,7 +131,7 @@ def is_metal_supported(): return True -### VARIABLES ### +# --- VARIABLES ### home_path = get_home_dir() kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels') is_metal_device = is_metal_supported() @@ -122,7 +149,7 @@ def _check_kernel_exist(source_md5_encode: str) -> bool: return False -### KERNEL AOT BUILD ### +# --- KERNEL AOT BUILD ### def _array_to_field(dtype, shape) -> Any: @@ -212,7 +239,7 @@ def _build_kernel( kernel.__name__ = kernel_name -### KERNEL CALL PREPROCESS ### +# --- KERNEL CALL PREPROCESS ### # convert type to number type_number_map = { @@ -334,9 +361,6 @@ def _preprocess_kernel_call_gpu( return opaque - - - def _XlaOp_to_ShapedArray(c, xla_op): xla_op = c.get_shape(xla_op) return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type()) @@ -376,7 +400,7 @@ def _compile_kernel(abs_ins, kernel, platform: str, **kwargs): try: os.removedirs(os.path.join(kernels_aot_path, source_md5_encode)) except Exception: - raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e + raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e # returns diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 5b27b2fd5..b534435dc 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -51,10 +51,10 @@ def test_taichi_clean_cache(): print(out) bm.clear_buffer_memory() - print('kernels: ', bm.check_kernels_count()) + print('kernels: ', bm.count_taichi_aot_kernels()) - bm.clean_caches() + bm.clear_taichi_aot_caches() - print('kernels: ', bm.check_kernels_count()) + print('kernels: ', bm.count_taichi_aot_kernels()) # test_taichi_clean_cache() diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index 94aeebb16..59588d3b9 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -11,7 +11,7 @@ from .compat_numpy import fill_diagonal from .environment import get_dt, get_int from .interoperability import as_jax -from .ndarray import Array +from .ndarray import Array, _return __all__ = [ 'shared_args_over_time', @@ -79,7 +79,7 @@ def remove_diag(arr): """ if arr.ndim != 2: raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') - eyes = Array(jnp.ones(arr.shape, dtype=bool)) + eyes = _return(jnp.ones(arr.shape, dtype=bool)) fill_diagonal(eyes, False) return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index d0f74bf23..9ae012bc4 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1232,9 +1232,10 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona a = _check_py_seq(_as_jax_array(a)) if size is None: size = jnp.shape(a) - r = call(lambda x: np.random.zipf(x, size), + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = call(lambda x: np.random.zipf(x, size).astype(dtype), a, - result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): @@ -1242,8 +1243,10 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option if size is None: size = jnp.shape(a) size = _size2shape(size) - r = call(lambda a: np.random.power(a=a, size=size), - a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) + r = call(lambda a: np.random.power(a=a, size=size).astype(dtype), + a, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, @@ -1256,11 +1259,12 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) r = call(lambda x: np.random.f(dfnum=x['dfnum'], dfden=x['dfden'], - size=size), + size=size).astype(dtype), d, - result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, @@ -1274,12 +1278,14 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc jnp.shape(nbad), jnp.shape(nsample)) size = _size2shape(size) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - r = call(lambda x: np.random.hypergeometric(ngood=x['ngood'], - nbad=x['nbad'], - nsample=x['nsample'], - size=size), - d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'], + nbad=d['nbad'], + nsample=d['nsample'], + size=size).astype(dtype), + d, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, @@ -1288,8 +1294,10 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, if size is None: size = jnp.shape(p) size = _size2shape(size) - r = call(lambda p: np.random.logseries(p=p, size=size), - p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype), + p, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, @@ -1303,11 +1311,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in jnp.shape(nonc)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], dfden=x['dfden'], nonc=x['nonc'], - size=size), - d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + size=size).astype(dtype), + d, result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) # PyTorch compatibility # diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 40bcbb706..acedcff12 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -9,10 +9,15 @@ import brainpy as bp import brainpy.math as bm from brainpy._src.dependency_check import import_taichi - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/surrogate/_compt.py b/brainpy/_src/math/surrogate/_compt.py deleted file mode 100644 index 67b7d5158..000000000 --- a/brainpy/_src/math/surrogate/_compt.py +++ /dev/null @@ -1,247 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -from jax import custom_gradient, numpy as jnp - -from brainpy._src.math.compat_numpy import asarray -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.environment import get_float -from brainpy._src.math.ndarray import Array - -__all__ = [ - 'spike_with_sigmoid_grad', - 'spike_with_linear_grad', - 'spike_with_gaussian_grad', - 'spike_with_mg_grad', - - 'spike2_with_sigmoid_grad', - 'spike2_with_linear_grad', -] - - -def _consistent_type(target, compare): - return as_jax(target) if not isinstance(compare, Array) else asarray(target) - - -@custom_gradient -def spike_with_sigmoid_grad(x: Array, scale: float = 100.): - """Spike function with the sigmoid surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.sigmoid_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: Array - The input data. - scale: float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.inv_square_grad()` instead.', UserWarning) - - x = as_jax(x) - z = jnp.asarray(x >= 0, dtype=get_float()) - - def grad(dE_dz): - dE_dz = as_jax(dE_dz) - dE_dx = dE_dz / (scale * jnp.abs(x) + 1.0) ** 2 - if scale is None: - return (_consistent_type(dE_dx, x),) - else: - dscale = jnp.zeros_like(scale) - return (dE_dx, dscale) - - return z, grad - - -@custom_gradient -def spike2_with_sigmoid_grad(x_new: Array, x_old: Array, scale: float = None): - """Spike function with the sigmoid surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.inv_square_grad2()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x_new: Array - The input data. - x_old: Array - The input data. - scale: optional, float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.inv_square_grad2()` instead.', UserWarning) - - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=get_float()) - - def grad(dE_dz): - _scale = 100. if scale is None else scale - dx_new = (dE_dz / (_scale * jnp.abs(x_new) + 1.0) ** 2) * jnp.asarray(x_old_comp, dtype=get_float()) - dx_old = -(dE_dz / (_scale * jnp.abs(x_old) + 1.0) ** 2) * jnp.asarray(x_new_comp, dtype=get_float()) - if scale is None: - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old)) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old), - _consistent_type(dscale, scale)) - - return z, grad - - -@custom_gradient -def spike_with_linear_grad(x: Array, scale: float = None): - """Spike function with the relu surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.relu_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: Array - The input data. - scale: float - The scaling factor. - """ - - warnings.warn('Use `brainpy.math.surrogate.relu_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _scale = 0.3 if scale is None else scale - dE_dx = dE_dz * jnp.maximum(1 - jnp.abs(x), 0) * _scale - if scale is None: - return (_consistent_type(dE_dx, x),) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dE_dx, x), _consistent_type(dscale, _scale)) - - return z, grad - - -@custom_gradient -def spike2_with_linear_grad(x_new: Array, x_old: Array, scale: float = 10.): - """Spike function with the linear surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.relu_grad2()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x_new: Array - The input data. - x_old: Array - The input data. - scale: float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.relu_grad2()` instead.', UserWarning) - - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=get_float()) - - def grad(dE_dz): - _scale = 0.3 if scale is None else scale - dx_new = (dE_dz * jnp.maximum(1 - jnp.abs(x_new), 0) * _scale) * jnp.asarray(x_old_comp, dtype=get_float()) - dx_old = -(dE_dz * jnp.maximum(1 - jnp.abs(x_old), 0) * _scale) * jnp.asarray(x_new_comp, dtype=get_float()) - if scale is None: - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old)) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old), - _consistent_type(dscale, scale)) - - return z, grad - - -def _gaussian(x, mu, sigma): - return jnp.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / jnp.sqrt(2 * jnp.pi) / sigma - - -@custom_gradient -def spike_with_gaussian_grad(x, sigma=None, scale=None): - """Spike function with the Gaussian surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.gaussian_grad()`` instead. - Will be removed after version 2.4.0. - - """ - - warnings.warn('Use `brainpy.math.surrogate.gaussian_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _scale = 0.5 if scale is None else scale - _sigma = 0.5 if sigma is None else sigma - dE_dx = dE_dz * _gaussian(x, 0., _sigma) * _scale - returns = (_consistent_type(dE_dx, x),) - if sigma is not None: - returns += (_consistent_type(jnp.zeros_like(_sigma), sigma),) - if scale is not None: - returns += (_consistent_type(jnp.zeros_like(_scale), scale),) - return returns - - return z, grad - - -@custom_gradient -def spike_with_mg_grad(x, h=None, s=None, sigma=None, scale=None): - """Spike function with the multi-Gaussian surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.multi_sigmoid_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: ndarray - The variable to judge spike. - h: float - The hyper-parameters of approximate function - s: float - The hyper-parameters of approximate function - sigma: float - The gaussian sigma. - scale: float - The gradient scale. - """ - - warnings.warn('Use `brainpy.math.surrogate.multi_sigmoid_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _sigma = 0.5 if sigma is None else sigma - _scale = 0.5 if scale is None else scale - _s = 6.0 if s is None else s - _h = 0.15 if h is None else h - dE_dx = dE_dz * (_gaussian(x, mu=0., sigma=_sigma) * (1. + _h) - - _gaussian(x, mu=_sigma, sigma=_s * _sigma) * _h - - _gaussian(x, mu=-_sigma, sigma=_s * _sigma) * _h) * _scale - returns = (_consistent_type(dE_dx, x),) - if h is not None: - returns += (_consistent_type(jnp.zeros_like(_h), h),) - if s is not None: - returns += (_consistent_type(jnp.zeros_like(_s), s),) - if sigma is not None: - returns += (_consistent_type(jnp.zeros_like(_sigma), sigma),) - if scale is not None: - returns += (_consistent_type(jnp.zeros_like(_scale), scale),) - return returns - - return z, grad - diff --git a/brainpy/_src/math/tests/test_environment.py b/brainpy/_src/math/tests/test_environment.py new file mode 100644 index 000000000..83315899f --- /dev/null +++ b/brainpy/_src/math/tests/test_environment.py @@ -0,0 +1,15 @@ +import unittest + +import jax + +import brainpy.math as bm + + +class TestEnvironment(unittest.TestCase): + def test_numpy_func_return(self): + with bm.environment(numpy_func_return='jax_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, jax.Array)) + with bm.environment(numpy_func_return='bp_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, bm.Array)) diff --git a/brainpy/_src/math/tests/test_random.py b/brainpy/_src/math/tests/test_random.py index 63b770646..1621f43df 100644 --- a/brainpy/_src/math/tests/test_random.py +++ b/brainpy/_src/math/tests/test_random.py @@ -1,8 +1,10 @@ +import platform import unittest import jax.numpy as jnp import jax.random as jr import numpy as np +import pytest import brainpy.math as bm import brainpy.math.random as br @@ -354,11 +356,13 @@ def test_hypergeometric1(self): a = bm.random.hypergeometric(10, 10, 10, 20) self.assertTupleEqual(a.shape, (20,)) + @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') def test_hypergeometric2(self): br.seed() a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]]) self.assertTupleEqual(a.shape, (2, 2)) + @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') def test_hypergeometric3(self): br.seed() a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2)) diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 9a64f9f25..08a070f02 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -44,13 +44,6 @@ del jnp, config -from brainpy._src.math.surrogate._compt import ( - spike_with_sigmoid_grad as spike_with_sigmoid_grad, - spike_with_linear_grad as spike_with_linear_grad, - spike_with_gaussian_grad as spike_with_gaussian_grad, - spike_with_mg_grad as spike_with_mg_grad, -) - from brainpy._src.math import defaults from brainpy._src.deprecations import deprecation_getattr from brainpy._src.dependency_check import import_taichi, import_numba diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index c0fcb67ae..f383c1a20 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -2,8 +2,8 @@ from brainpy._src.math.op_register import ( CustomOpByNumba, compile_cpu_signature_with_numba, - clean_caches, - check_kernels_count, + clear_taichi_aot_caches, + count_taichi_aot_kernels, ) from brainpy._src.math.op_register.base import XLACustomOp diff --git a/docs/apis/brainpy.math.op_register.rst b/docs/apis/brainpy.math.op_register.rst index a50b4d300..13ce518cb 100644 --- a/docs/apis/brainpy.math.op_register.rst +++ b/docs/apis/brainpy.math.op_register.rst @@ -22,6 +22,23 @@ General Operator Customization Interface +CPU Operator Customization with Taichi +------------------------------------- + +.. currentmodule:: brainpy.math +.. automodule:: brainpy.math + +.. autosummary:: + :toctree: generated/ + + clear_taichi_aot_caches + count_taichi_aot_kernels + + + + + + CPU Operator Customization with Numba ------------------------------------- @@ -34,7 +51,6 @@ CPU Operator Customization with Numba :template: classtemplate.rst CustomOpByNumba - XLACustomOp .. autosummary:: @@ -43,3 +59,17 @@ CPU Operator Customization with Numba register_op_with_numba compile_cpu_signature_with_numba + + +Operator Autograd Customization +------------------------------- + +.. currentmodule:: brainpy.math +.. automodule:: brainpy.math + +.. autosummary:: + :toctree: generated/ + + defjvp + + diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index f98527458..7923c93d0 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -213,7 +213,7 @@ def __init__(self): super().__init__() self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 1.)) - self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) + self.delay = bp.VarDelay(self.N.spike, entries={'delay': 0.}) self.syn1 = bp.dyn.Expon(size=3200, tau=5.) self.syn2 = bp.dyn.Expon(size=800, tau=10.) self.E = bp.dyn.VanillaProj( @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('I') + spk = self.delay.at('delay') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input))