diff --git a/brainpy/__init__.py b/brainpy/__init__.py index a3a1de694..79aa216ba 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -61,6 +61,10 @@ Sequential as Sequential, Dynamic as Dynamic, # category Projection as Projection, + receive_update_input, # decorators + receive_update_output, + not_receive_update_input, + not_receive_update_output, ) DynamicalSystemNS = DynamicalSystem Network = DynSysGroup @@ -84,7 +88,6 @@ load_state as load_state, clear_input as clear_input) - # Part: Running # # --------------- # from brainpy._src.runners import (DSRunner as DSRunner) diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index ee0be5763..66530a5b1 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -28,7 +28,21 @@ ] -delay_identifier = '_*_delay_*_' +delay_identifier = '_*_delay_of_' + + +def _get_delay(delay_time, delay_step): + if delay_time is None: + if delay_step is None: + return None, None + else: + assert isinstance(delay_step, int), '"delay_step" should be an integer.' + delay_time = delay_step * bm.get_dt() + else: + assert delay_step is None, '"delay_step" should be None if "delay_time" is given.' + assert isinstance(delay_time, (int, float)) + delay_step = math.ceil(delay_time / bm.get_dt()) + return delay_time, delay_step class Delay(DynamicalSystem, ParamDesc): @@ -97,13 +111,15 @@ def __init__( def register_entry( self, entry: str, - delay_time: Optional[Union[float, bm.Array, Callable]], + delay_time: Optional[Union[float, bm.Array, Callable]] = None, + delay_step: Optional[int] = None ) -> 'Delay': """Register an entry to access the data. Args: entry: str. The entry to access the delay data. delay_time: The delay time of the entry (can be a float). + delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``. Returns: Return the self. 
@@ -237,13 +253,15 @@ def __init__( def register_entry( self, entry: str, - delay_time: Optional[Union[int, float]], + delay_time: Optional[Union[int, float]] = None, + delay_step: Optional[int] = None, ) -> 'Delay': """Register an entry to access the data. Args: entry: str. The entry to access the delay data. delay_time: The delay time of the entry (can be a float). + delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``. Returns: Return the self. @@ -258,12 +276,7 @@ def register_entry( assert delay_time.size == 1 and delay_time.ndim == 0 delay_time = delay_time.item() - if delay_time is None: - delay_step = None - delay_time = 0. - else: - assert isinstance(delay_time, (int, float)) - delay_step = math.ceil(delay_time / bm.get_dt()) + _, delay_step = _get_delay(delay_time, delay_step) # delay variable if delay_step is not None: @@ -354,6 +367,8 @@ def update( """Update delay variable with the new data. """ if self.data is not None: + # jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value)) + # get the latest target value if latest_value is None: latest_value = self.target.value @@ -361,17 +376,20 @@ def update( # update the delay data at the rotation index if self.method == ROTATE_UPDATE: i = share.load('i') - idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32) - self.data[idx] = latest_value + idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32) + self.data[jax.lax.stop_gradient(idx)] = latest_value # update the delay data at the first position elif self.method == CONCAT_UPDATE: if self.max_length > 1: latest_value = bm.expand_dims(latest_value, 0) - self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0) + self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0) else: self.data[0] = latest_value + else: + raise ValueError(f'Unknown updating method "{self.method}"') + def reset_state(self, batch_size: int = None, **kwargs): """Reset the delay data. 
""" diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py index c7a902f01..2e214ed29 100644 --- a/brainpy/_src/dynold/synapses/abstract_models.py +++ b/brainpy/_src/dynold/synapses/abstract_models.py @@ -115,12 +115,7 @@ def __init__( self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr') # register delay - self.pre.register_local_delay("spike", self.name, delay_step) - - def reset_state(self, batch_size=None): - self.output.reset_state(batch_size) - if self.stp is not None: - self.stp.reset_state(batch_size) + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) def update(self, pre_spike=None): # pre-synaptic spikes @@ -232,7 +227,6 @@ class Exponential(TwoEndConn): method: str The numerical integration methods. - """ def __init__( @@ -283,17 +277,16 @@ def __init__( else: raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".') - # variables - self.g = self.syn.g - # delay - self.pre.register_local_delay("spike", self.name, delay_step) + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) - def reset_state(self, batch_size=None): - self.syn.reset_state(batch_size) - self.output.reset_state(batch_size) - if self.stp is not None: - self.stp.reset_state(batch_size) + @property + def g(self): + return self.syn.g + + @g.setter + def g(self, value): + self.syn.g = value def update(self, pre_spike=None): # delays diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py index 55bac7111..5ceeb4e23 100644 --- a/brainpy/_src/dynold/synapses/base.py +++ b/brainpy/_src/dynold/synapses/base.py @@ -10,8 +10,7 @@ from brainpy._src.dyn.base import NeuDyn from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import parameter -from brainpy._src.mixin import (ParamDesc, JointType, - SupportAutoDelay, BindCondData, ReturnInfo) +from brainpy._src.mixin import 
(ParamDesc, JointType, SupportAutoDelay, BindCondData, ReturnInfo) from brainpy.errors import UnsupportedError from brainpy.types import ArrayType @@ -47,9 +46,6 @@ def isregistered(self, val: bool): raise ValueError('Must be an instance of bool.') self._registered = val - def reset_state(self, batch_size=None): - pass - def register_master(self, master: SynConn): if not isinstance(master, SynConn): raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') @@ -296,7 +292,7 @@ def __init__( mode=mode) # delay - self.pre.register_local_delay("spike", self.name, delay_step) + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) # synaptic dynamics self.syn = syn @@ -340,11 +336,5 @@ def g_max(self, v): UserWarning) self.comm.weight = v - def reset_state(self, *args, **kwargs): - self.syn.reset(*args, **kwargs) - self.comm.reset(*args, **kwargs) - self.output.reset(*args, **kwargs) - if self.stp is not None: - self.stp.reset(*args, **kwargs) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index cb086b10d..a6fcc16a7 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -93,17 +93,41 @@ def __init__( # Attribute for "SupportInputProj" # each instance of "SupportInputProj" should have a "cur_inputs" attribute - self.current_inputs = bm.node_dict() - self.delta_inputs = bm.node_dict() + self._current_inputs: Optional[Dict[str, Callable]] = None + self._delta_inputs: Optional[Dict[str, Callable]] = None # the before- / after-updates used for computing # added after the version of 2.4.3 - self.before_updates: Dict[str, Callable] = bm.node_dict() - self.after_updates: Dict[str, Callable] = bm.node_dict() + self._before_updates: Optional[Dict[str, Callable]] = None + self._after_updates: Optional[Dict[str, Callable]] = None # super initialization super().__init__(name=name) + @property + def current_inputs(self): + if self._current_inputs is None: + self._current_inputs = bm.node_dict() + return 
self._current_inputs + + @property + def delta_inputs(self): + if self._delta_inputs is None: + self._delta_inputs = bm.node_dict() + return self._delta_inputs + + @property + def before_updates(self): + if self._before_updates is None: + self._before_updates = bm.node_dict() + return self._before_updates + + @property + def after_updates(self): + if self._after_updates is None: + self._after_updates = bm.node_dict() + return self._after_updates + def add_bef_update(self, key: Any, fun: Callable): """Add the before update into this node""" if key in self.before_updates: @@ -220,25 +244,32 @@ def register_local_delay( self, var_name: str, delay_name: str, - delay: Union[numbers.Number, ArrayType] = None, + delay_time: Union[numbers.Number, ArrayType] = None, + delay_step: Union[numbers.Number, ArrayType] = None, ): """Register local relay at the given delay time. Args: var_name: str. The name of the delay target variable. delay_name: str. The name of the current delay data. - delay: The delay time. + delay_time: The delay time. Float. + delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``. 
""" delay_identifier, init_delay_by_return = _get_delay_tool() delay_identifier = delay_identifier + var_name + # check whether the "var_name" has been registered try: target = getattr(self, var_name) except AttributeError: raise AttributeError(f'This node {self} does not has attribute of "{var_name}".') if not self.has_aft_update(delay_identifier): - self.add_aft_update(delay_identifier, init_delay_by_return(target)) + # add a model to receive the return of the target model + # moreover, the model should not receive the return of the update function + model = not_receive_update_output(init_delay_by_return(target)) + # register the model + self.add_aft_update(delay_identifier, model) delay_cls = self.get_aft_update(delay_identifier) - delay_cls.register_entry(delay_name, delay) + delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step) def get_local_delay(self, var_name, delay_name): """Get the delay at the given identifier (`name`). @@ -381,14 +412,20 @@ def __call__(self, *args, **kwargs): # ``before_updates`` for model in self.before_updates.values(): - model() + if hasattr(model, '_receive_update_input'): + model(*args, **kwargs) + else: + model() # update the model self ret = self.update(*args, **kwargs) # ``after_updates`` for model in self.after_updates.values(): - model(ret) + if hasattr(model, '_not_receive_update_output'): + model() + else: + model(ret) return ret def __rrshift__(self, other): @@ -832,3 +869,75 @@ def _slice_to_num(slice_: slice, length: int): start += step num += 1 return num + + +def receive_update_output(cls: object): + """ + The decorator to mark the object (as the after updates) to receive the output of the update function. + + That is, the `aft_update` will receive the return of the update function:: + + ret = model.update(*args, **kwargs) + for fun in model.aft_updates: + fun(ret) + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' 
+ if hasattr(cls, '_not_receive_update_output'): + delattr(cls, '_not_receive_update_output') + return cls + + +def not_receive_update_output(cls: object): + """ + The decorator to mark the object (as the after updates) to not receive the output of the update function. + + That is, the `aft_update` will not receive the return of the update function:: + + ret = model.update(*args, **kwargs) + for fun in model.aft_updates: + fun() + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + cls._not_receive_update_output = True + return cls + + +def receive_update_input(cls: object): + """ + The decorator to mark the object (as the before updates) to receive the input of the update function. + + That is, the `bef_update` will receive the input of the update function:: + + + for fun in model.bef_updates: + fun(*args, **kwargs) + model.update(*args, **kwargs) + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + cls._receive_update_input = True + return cls + + +def not_receive_update_input(cls: object): + """ + The decorator to mark the object (as the before updates) to not receive the input of the update function. + + That is, the `bef_update` will not receive the input of the update function:: + + for fun in model.bef_updates: + fun() + model.update() + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' 
+ if hasattr(cls, '_receive_update_input'): + delattr(cls, '_receive_update_input') + return cls + + + + + diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 936f62386..de64f94e7 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -14,6 +14,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class +from brainpy._src.math import defaults from brainpy._src.math.modes import Mode from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) @@ -22,13 +23,11 @@ from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) from brainpy._src.math.sharding import BATCH_AXIS -from brainpy._src.math import defaults variable_ = None StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) registered = set() - __all__ = [ 'BrainPyObject', 'Base', 'FunAsObject', 'ObjectTransform', @@ -103,11 +102,23 @@ def __init__(self, name=None): # Used to wrap the implicit variables # which cannot be accessed by self.xxx - self.implicit_vars: ArrayCollector = ArrayCollector() + self._implicit_vars: Optional[ArrayCollector] = None # Used to wrap the implicit children nodes # which cannot be accessed by self.xxx - self.implicit_nodes: Collector = Collector() + self._implicit_nodes: Optional[Collector] = None + + @property + def implicit_vars(self): + if self._implicit_vars is None: + self._implicit_vars = ArrayCollector() + return self._implicit_vars + + @property + def implicit_nodes(self): + if self._implicit_nodes is None: + self._implicit_nodes = Collector() + return self._implicit_nodes def setattr(self, key: str, value: Any) -> None: super().__setattr__(key, value) @@ -225,7 +236,7 @@ def tree_flatten(self): static_values = [] for k, v in self.__dict__.items(): if isinstance(v, (BrainPyObject, Variable, 
NodeList, NodeDict, VarList, VarDict)): - # if isinstance(v, (BrainPyObject, Variable)): + # if isinstance(v, (BrainPyObject, Variable)): dynamic_names.append(k) dynamic_values.append(v) else: diff --git a/brainpy/_src/tests/test_base_classes.py b/brainpy/_src/tests/test_base_classes.py index 9c095a30e..3534f0a48 100644 --- a/brainpy/_src/tests/test_base_classes.py +++ b/brainpy/_src/tests/test_base_classes.py @@ -3,6 +3,7 @@ import unittest import brainpy as bp +import brainpy.math as bm class TestDynamicalSystem(unittest.TestCase): @@ -17,4 +18,53 @@ def test_delay(self): runner = bp.DSRunner(net,) runner.run(10.) + bm.clear_buffer_memory() + + def test_receive_update_output(self): + def aft_update(inp): + assert inp is not None + + hh = bp.dyn.HH(1) + hh.add_aft_update('aft_update', aft_update) + bp.share.save(i=0, t=0.) + hh(1.) + + bm.clear_buffer_memory() + + def test_do_not_receive_update_output(self): + def aft_update(): + pass + + hh = bp.dyn.HH(1) + hh.add_aft_update('aft_update', bp.not_receive_update_output(aft_update)) + bp.share.save(i=0, t=0.) + hh(1.) + + bm.clear_buffer_memory() + + def test_not_receive_update_input(self): + def bef_update(): + pass + + hh = bp.dyn.HH(1) + hh.add_bef_update('bef_update', bef_update) + bp.share.save(i=0, t=0.) + hh(1.) + + bm.clear_buffer_memory() + + def test_receive_update_input(self): + def bef_update(inp): + assert inp is not None + + hh = bp.dyn.HH(1) + hh.add_bef_update('bef_update', bp.receive_update_input(bef_update)) + bp.share.save(i=0, t=0.) + hh(1.) 
+ + bm.clear_buffer_memory() + + + + diff --git a/brainpy/_src/tests/test_delay.py b/brainpy/_src/tests/test_delay.py index 20d49937c..b7bd44ead 100644 --- a/brainpy/_src/tests/test_delay.py +++ b/brainpy/_src/tests/test_delay.py @@ -1,13 +1,15 @@ +import unittest + +import jax.numpy as jnp import brainpy as bp -import unittest class TestVarDelay(unittest.TestCase): def test_delay1(self): bp.math.random.seed() a = bp.math.Variable((10, 20)) - delay = bp.VarDelay(a,) + delay = bp.VarDelay(a, ) delay.register_entry('a', 1.) delay.register_entry('b', 2.) delay.register_entry('c', None) @@ -15,8 +17,44 @@ def test_delay1(self): delay.register_entry('c', 10.) bp.math.clear_buffer_memory() + def test_rotation_delay(self): + a = bp.math.Variable((1,)) + rotation_delay = bp.VarDelay(a) + t0 = 0. + t1, n1 = 1., 10 + t2, n2 = 2., 20 + + rotation_delay.register_entry('a', t0) + rotation_delay.register_entry('b', t1) + rotation_delay.register_entry('c', t2) + + print() + for i in range(100): + bp.share.save(i=i) + a.value = jnp.ones((1,)) * i + rotation_delay() + self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) + self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) + bp.math.clear_buffer_memory() - - - - + def test_concat_delay(self): + a = bp.math.Variable((1,)) + rotation_delay = bp.VarDelay(a, method='concat') + t0 = 0. 
+ t1, n1 = 1., 10 + t2, n2 = 2., 20 + + rotation_delay.register_entry('a', t0) + rotation_delay.register_entry('b', t1) + rotation_delay.register_entry('c', t2) + + print() + for i in range(100): + bp.share.save(i=i) + a.value = jnp.ones((1,)) * i + rotation_delay() + self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) + self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) + bp.math.clear_buffer_memory() diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py index 962b76cb9..e864fd647 100644 --- a/brainpy/_src/tests/test_mixin.py +++ b/brainpy/_src/tests/test_mixin.py @@ -42,7 +42,7 @@ class TestDelayRegister(unittest.TestCase): def test2(self): bp.share.save(i=0) lif = bp.dyn.Lif(10) - lif.register_local_delay('spike', 'a', 10.) + lif.register_local_delay('spike', 'a', delay_time=10.) data = lif.get_local_delay('spike', 'a') self.assertTrue(bm.allclose(data, bm.zeros(10)))