From b28cb6a7e2b942235b7d1e4ec22681feb15cfb1d Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 28 Dec 2023 18:39:49 +0800 Subject: [PATCH 1/4] [dyn] synaptic projection updates 1. reorganize the projection structures; 2. rename previous reduced projections with intuitive names 3. add `brainpy.dyn.HalfProjDelta` and `brainpy.dyn.FullProjDelta` --- brainpy/_add_deprecations.py | 10 + brainpy/_src/dyn/neurons/hh.py | 21 +- brainpy/_src/dyn/neurons/lif.py | 70 +- brainpy/_src/dyn/others/common.py | 2 +- brainpy/_src/dyn/outs/outputs.py | 6 +- brainpy/_src/dyn/projections/__init__.py | 5 - brainpy/_src/dyn/projections/align_post.py | 442 +++++++ brainpy/_src/dyn/projections/align_pre.py | 524 ++++++++ brainpy/_src/dyn/projections/aligns.py | 1053 ----------------- brainpy/_src/dyn/projections/base.py | 12 + brainpy/_src/dyn/projections/delta.py | 203 ++++ brainpy/_src/dyn/projections/inputs.py | 237 ++-- brainpy/_src/dyn/projections/others.py | 81 -- brainpy/_src/dyn/projections/plasticity.py | 7 +- .../_src/dyn/projections/tests/test_STDP.py | 2 +- .../_src/dyn/projections/tests/test_aligns.py | 176 +-- .../_src/dyn/projections/tests/test_delta.py | 51 + brainpy/_src/dyn/projections/vanilla.py | 83 ++ brainpy/_src/dyn/synapses/abstract_models.py | 66 +- brainpy/_src/dynold/synapses/base.py | 14 +- brainpy/_src/dynsys.py | 3 +- brainpy/_src/mixin.py | 98 +- brainpy/dyn/projections.py | 34 +- brainpy/dyn/synapses.py | 1 - docs/apis/brainpy.dyn.projections.rst | 18 +- docs/apis/brainpy.dyn.synapses.rst | 1 - docs/apis/losses.rst | 8 + examples/dynamics_simulation/COBA.py | 16 +- examples/dynamics_simulation/COBA_parallel.py | 6 +- .../decision_making_network.py | 4 +- examples/dynamics_simulation/ei_nets.py | 160 +-- 31 files changed, 1844 insertions(+), 1570 deletions(-) create mode 100644 brainpy/_src/dyn/projections/align_post.py create mode 100644 brainpy/_src/dyn/projections/align_pre.py delete mode 100644 brainpy/_src/dyn/projections/aligns.py create mode 100644 brainpy/_src/dyn/projections/base.py create mode 100644 brainpy/_src/dyn/projections/delta.py delete mode 100644 brainpy/_src/dyn/projections/others.py create mode 100644 brainpy/_src/dyn/projections/tests/test_delta.py create mode 100644 brainpy/_src/dyn/projections/vanilla.py diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py index 17edcff31..d04c3aa2e 100644 --- a/brainpy/_add_deprecations.py +++ b/brainpy/_add_deprecations.py @@ -88,6 +88,16 @@ # neurons 'NeuGroup': ('brainpy.dyn.NeuGroup', 'brainpy.dyn.NeuDyn', NeuDyn), + # projections + 'ProjAlignPostMg1': ('brainpy.dyn.ProjAlignPostMg1', 'brainpy.dyn.HalfProjAlignPostMg', dyn.HalfProjAlignPostMg), + 'ProjAlignPostMg2': ('brainpy.dyn.ProjAlignPostMg2', 'brainpy.dyn.FullProjAlignPostMg', dyn.FullProjAlignPostMg), + 'ProjAlignPost1': ('brainpy.dyn.ProjAlignPost1', 'brainpy.dyn.HalfProjAlignPost', dyn.HalfProjAlignPost), + 'ProjAlignPost2': ('brainpy.dyn.ProjAlignPost2', 'brainpy.dyn.FullProjAlignPost', dyn.FullProjAlignPost), + 'ProjAlignPreMg1': ('brainpy.dyn.ProjAlignPreMg1', 'brainpy.dyn.FullProjAlignPreSDMg', dyn.FullProjAlignPreSDMg), + 'ProjAlignPreMg2': ('brainpy.dyn.ProjAlignPreMg2', 'brainpy.dyn.FullProjAlignPreDSMg', dyn.FullProjAlignPreDSMg), + 'ProjAlignPre1': ('brainpy.dyn.ProjAlignPre1', 'brainpy.dyn.FullProjAlignPreSD', dyn.FullProjAlignPreSD), + 'ProjAlignPre2': ('brainpy.dyn.ProjAlignPre2', 'brainpy.dyn.FullProjAlignPreDS', dyn.FullProjAlignPreDS), + # synapses 'TwoEndConn': ('brainpy.dyn.TwoEndConn', 'brainpy.synapses.TwoEndConn', synapses.TwoEndConn), 'SynSTP': ('brainpy.dyn.SynSTP', 'brainpy.synapses.SynSTP', synapses.SynSTP), diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py index 97e612097..fca13e8e1 100644 --- a/brainpy/_src/dyn/neurons/hh.py +++ b/brainpy/_src/dyn/neurons/hh.py @@ -117,7 +117,7 @@ def __init__( def derivative(self, V, t, I): # synapses - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) # channels for ch in self.nodes(level=1, include_self=False).subset(IonChaDyn).unique().values(): I = I + ch.current(V) @@ -140,7 +140,7 @@ def update(self, x=None): x = x * (1e-3 / self.A) # integral - V = self.integral(self.V.value, share['t'], x, share['dt']) + V = self.integral(self.V.value, share['t'], x, share['dt']) + self.sum_delta_inputs() # check whether the children channels have the correct parents. channels = self.nodes(level=1, include_self=False).subset(IonChaDyn).unique() @@ -176,7 +176,7 @@ def derivative(self, V, t, I): def update(self, x=None): # inputs x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -384,7 +384,7 @@ def reset_state(self, batch_size=None, **kwargs): self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_size) def dV(self, V, t, m, h, n, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) I_Na = (self.gNa * m * m * m * h) * (V - self.ENa) n2 = n * n I_K = (self.gK * n2 * n2) * (V - self.EK) @@ -402,6 +402,7 @@ def update(self, x=None): x = 0. if x is None else x V, m, h, n = self.integral(self.V.value, self.m.value, self.h.value, self.n.value, t, x, dt) + V += self.sum_delta_inputs() self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.m.value = m @@ -532,7 +533,7 @@ def derivative(self): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -662,7 +663,7 @@ def reset_state(self, batch_or_mode=None, **kwargs): self.spike = self.init_variable(partial(bm.zeros, dtype=bool), batch_or_mode) def dV(self, V, t, W, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) M_inf = (1 / 2) * (1 + bm.tanh((V - self.V1) / self.V2)) I_Ca = self.g_Ca * M_inf * (V - self.V_Ca) I_K = self.g_K * W * (V - self.V_K) @@ -685,6 +686,7 @@ def update(self, x=None): dt = share.load('dt') x = 0. if x is None else x V, W = self.integral(self.V, self.W, t, x, dt) + V += self.sum_delta_inputs() spike = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.W.value = W @@ -761,7 +763,7 @@ def dV(self, V, t, W, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -951,7 +953,7 @@ def dn(self, n, t, V): return self.phi * dndt def dV(self, V, t, h, n, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) INa = self.gNa * self.m_inf(V) ** 3 * h * (V - self.ENa) IK = self.gK * n ** 4 * (V - self.EK) IL = self.gL * (V - self.EL) @@ -968,6 +970,7 @@ def update(self, x=None): x = 0. if x is None else x V, h, n = self.integral(self.V, self.h, self.n, t, x, dt) + V += self.sum_delta_inputs() self.spike.value = bm.logical_and(self.V < self.V_th, V >= self.V_th) self.V.value = V self.h.value = h @@ -1091,5 +1094,5 @@ def dV(self, V, t, h, n, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index 988c915ac..d4599ebca 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -119,7 +119,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) return (-V + self.V_rest + self.R * I) / self.tau def reset_state(self, batch_size=None, **kwargs): @@ -132,7 +132,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - self.V.value = self.integral(self.V.value, t, x, dt) + self.V.value = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() return self.V.value @@ -146,7 +146,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -252,7 +252,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) return (-V + self.V_rest + self.R * I) / self.tau def reset_state(self, batch_size=None, **kwargs): @@ -265,7 +265,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -337,7 +337,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -464,7 +464,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -552,7 +552,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -723,7 +723,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau return dvdt @@ -738,7 +738,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -880,7 +880,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -1076,7 +1076,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -1228,7 +1228,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -1400,7 +1400,7 @@ def __init__( self.reset_state(self.mode) def dV(self, V, t, w, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) exp = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dVdt = (- V + self.V_rest + exp - self.R * w + self.R * I) / self.tau return dVdt @@ -1424,7 +1424,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -1559,7 +1559,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -1756,7 +1756,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -1901,7 +1901,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2040,7 +2040,7 @@ def __init__( self.reset_state(self.mode) def derivative(self, V, t, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau return dVdt @@ -2054,7 +2054,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -2166,7 +2166,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2330,7 +2330,7 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V = self.integral(self.V.value, t, x, dt) + V = self.integral(self.V.value, t, x, dt) + self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -2451,7 +2451,7 @@ def derivative(self, V, t, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2609,7 +2609,7 @@ def __init__( self.reset_state(self.mode) def dV(self, V, t, w, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) - w + I) / self.tau return dVdt @@ -2633,6 +2633,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V = V + self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -2756,7 +2757,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -2939,6 +2940,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -3072,7 +3074,7 @@ def dV(self, V, t, w, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -3279,7 +3281,7 @@ def dVth(self, V_th, t, V): return self.a * (V - self.V_rest) - self.b * (V_th - self.V_th_inf) def dV(self, V, t, I1, I2, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau @property @@ -3300,6 +3302,7 @@ def update(self, x=None): # integrate membrane potential I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -3452,7 +3455,7 @@ def dV(self, V, t, I1, I2, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -3680,6 +3683,7 @@ def update(self, x=None): # integrate membrane potential I1, I2, V_th, V = self.integral(self.I1.value, self.I2.value, self.V_th.value, self.V.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -3846,7 +3850,7 @@ def dV(self, V, t, I1, I2, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -4012,7 +4016,7 @@ def __init__( self.reset_state(self.mode) def dV(self, V, t, u, I): - I = self.sum_inputs(V, init=I) + I = self.sum_current_inputs(V, init=I) dVdt = self.p1 * V * V + self.p2 * V + self.p3 - u + I return dVdt @@ -4040,6 +4044,7 @@ def update(self, x=None): # integrate membrane potential V, u = self.integral(self.V.value, self.u.value, t, x, dt) + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -4161,7 +4166,7 @@ def dV(self, V, t, u, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) @@ -4351,6 +4356,7 @@ def update(self, x=None): # integrate membrane potential V, u = self.integral(self.V.value, self.u.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -4485,7 +4491,7 @@ def dV(self, V, t, u, I): def update(self, x=None): x = 0. if x is None else x - x = self.sum_inputs(self.V.value, init=x) + x = self.sum_current_inputs(self.V.value, init=x) return super().update(x) diff --git a/brainpy/_src/dyn/others/common.py b/brainpy/_src/dyn/others/common.py index 7cf4f98b8..812375787 100644 --- a/brainpy/_src/dyn/others/common.py +++ b/brainpy/_src/dyn/others/common.py @@ -77,7 +77,7 @@ def update(self, inp=None): dt = share.load('dt') self.x.value = self.integral(self.x.value, t, dt) if inp is None: inp = 0. - inp = self.sum_inputs(self.x.value, init=inp) + inp = self.sum_current_inputs(self.x.value, init=inp) self.x += inp return self.x.value diff --git a/brainpy/_src/dyn/outs/outputs.py b/brainpy/_src/dyn/outs/outputs.py index 5dc54a232..8171367d7 100644 --- a/brainpy/_src/dyn/outs/outputs.py +++ b/brainpy/_src/dyn/outs/outputs.py @@ -82,7 +82,7 @@ def __init__( super().__init__(name=name, scaling=scaling) def update(self, conductance, potential=None): - return self.std_scaling(conductance) + return conductance class MgBlock(SynOut): @@ -138,5 +138,5 @@ def __init__( self.beta = init.parameter(beta, np.shape(beta), sharding=sharding) def update(self, conductance, potential): - return conductance *\ - (self.E - potential) / (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential))) + norm = (1 + self.cc_Mg / self.beta * bm.exp(self.alpha * (self.V_offset - potential))) + return conductance * (self.E - potential) / norm diff --git a/brainpy/_src/dyn/projections/__init__.py b/brainpy/_src/dyn/projections/__init__.py index 8a7040824..e69de29bb 100644 --- a/brainpy/_src/dyn/projections/__init__.py +++ b/brainpy/_src/dyn/projections/__init__.py @@ -1,5 +0,0 @@ - -from .aligns import * -from .conn import * -from .others import * -from .inputs import * diff --git a/brainpy/_src/dyn/projections/align_post.py b/brainpy/_src/dyn/projections/align_post.py new file mode 100644 index 000000000..217045032 --- /dev/null +++ b/brainpy/_src/dyn/projections/align_post.py @@ -0,0 +1,442 @@ +from typing import Optional, Callable, Union + +from brainpy import math as bm, check +from brainpy._src.delay import (delay_identifier, + register_delay_by_return) +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost) + +__all__ = [ + 'HalfProjAlignPostMg', 'FullProjAlignPostMg', + 'HalfProjAlignPost', 'FullProjAlignPost', + +] + + +def get_post_repr(out_label, syn, out): + return f'{out_label} // {syn.identifier} // {out.identifier}' + + +def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name): + # synapse and output initialization + _post_repr = get_post_repr(out_label, syn_desc, out_desc) + if not post.has_bef_update(_post_repr): + syn_cls = syn_desc() + out_cls = out_desc() + + # synapse and output initialization + post.add_inp_fun(proj_name, out_cls, label=out_label) + post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls)) + syn = post.get_bef_update(_post_repr).syn + out = post.get_bef_update(_post_repr).out + return syn, out + + +class _AlignPost(DynamicalSystem): + def __init__(self, + syn: Callable, + out: JointType[DynamicalSystem, BindCondData]): + super().__init__() + self.syn = syn + self.out = out + + def update(self, *args, **kwargs): + self.out.bind_cond(self.syn(*args, **kwargs)) + + def reset_state(self, *args, **kwargs): + pass + + +class HalfProjAlignPostMg(Projection): + r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + + **Code Examples** + + To define an E/I balanced network model. + + .. code-block:: python + + import brainpy as bp + import brainpy.math as bm + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=4000, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPostMg(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=4000, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + Args: + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + out_label: str. The prefix of the output function. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + comm: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and output initialization + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + + # references + self.refs = dict(post=post) # invisible to ``self.nodes()`` + self.refs['syn'] = syn + self.refs['out'] = out + self.refs['comm'] = comm # unify the access + + def update(self, x): + current = self.comm(x) + self.refs['syn'].add_current(current) # synapse post current + return current + + +class FullProjAlignPostMg(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + + **Code Examples** + + To define an E/I balanced network model. + + .. code-block:: python + + import brainpy as bp + import brainpy.math as bm + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPostMg(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPostMg(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon.desc(size=ni, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPostMg(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=ne, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPostMg(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # synapse and output initialization + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + + # references + self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` + self.refs['syn'] = syn # invisible to ``self.node()`` + self.refs['out'] = out # invisible to ``self.node()`` + # unify the access + self.refs['comm'] = comm + self.refs['delay'] = pre.get_aft_update(delay_identifier) + + def update(self): + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) + current = self.comm(x) + self.refs['syn'].add_current(current) # synapse post current + return current + + +class HalfProjAlignPost(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + + To simulate an E/I balanced network: + + .. code-block:: + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=4000, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=4000, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + comm: DynamicalSystem, + syn: JointType[DynamicalSystem, AlignPost], + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + self.out = out + + # synapse and output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # reference + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['post'] = post + self.refs['syn'] = syn + self.refs['out'] = out + # unify the access + self.refs['comm'] = comm + + def update(self, x): + current = self.comm(x) + g = self.syn(self.comm(x)) + self.refs['out'].bind_cond(g) # synapse post current + return current + + +class FullProjAlignPost(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + + To simulate and define an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + syn: JointType[DynamicalSystem, AlignPost], + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # synapse and output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + # unify the access + self.refs['delay'] = delay_cls + self.refs['comm'] = comm + self.refs['syn'] = syn + + def update(self): + x = self.refs['delay'].at(self.name) + g = self.syn(self.comm(x)) + self.refs['out'].bind_cond(g) # synapse post current + return g diff --git a/brainpy/_src/dyn/projections/align_pre.py b/brainpy/_src/dyn/projections/align_pre.py new file mode 100644 index 000000000..2b609322c --- /dev/null +++ b/brainpy/_src/dyn/projections/align_pre.py @@ -0,0 +1,524 @@ +from typing import Optional, Union + +from brainpy import math as bm, check +from brainpy._src.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return) +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData) +from .base import _get_return + +__all__ = [ + 'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg', + 'FullProjAlignPreSD', 'FullProjAlignPreDS', +] + + +def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None): + _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}' + if not delay_cls.has_bef_update(_syn_id): + # delay + delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name) + # synapse + syn_cls = syn_desc() + # add to "after_updates" + delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) + syn = delay_cls.get_bef_update(_syn_id).syn + return syn + + +class _AlignPreMg(DynamicalSystem): + def __init__(self, access, syn): + super().__init__() + self.access = access + self.syn = syn + + def update(self, *args, **kwargs): + return self.syn(self.access()) + + def reset_state(self, *args, **kwargs): + pass + + +def align_pre1_add_bef_update(syn_desc, pre): + _syn_id = f'{syn_desc.identifier} // Delay' + if not pre.has_aft_update(_syn_id): + # "syn_cls" needs an instance of "ProjAutoDelay" + syn_cls: SupportAutoDelay = syn_desc() + delay_cls = init_delay_by_return(syn_cls.return_info()) + # add to "after_updates" + pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls)) + delay_cls: Delay = pre.get_aft_update(_syn_id).delay + syn = pre.get_aft_update(_syn_id).syn + return delay_cls, syn + + +class _AlignPre(DynamicalSystem): + def __init__(self, syn, delay=None): + super().__init__() + self.syn = syn + self.delay = delay + + def update(self, x): + if self.delay is None: + return x >> self.syn + else: + return x >> self.syn >> self.delay + + def reset_state(self, *args, **kwargs): + pass + + +class FullProjAlignPreSDMg(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + + To simulate an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + syn: The synaptic dynamics. + delay: The synaptic delay. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: DynamicalSystem, + syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], + delay: Union[None, int, float], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, DynamicalSystem) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and delay initialization + delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + self.refs['syn'] = syn_cls + # unify the access + self.refs['comm'] = comm + + def update(self, x=None): + if x is None: + x = self.refs['delay'].at(self.name) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + +class FullProjAlignPreDSMg(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + + To simulate an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + syn: The synaptic dynamics. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + syn: ParamDescriber[DynamicalSystem], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + + # synapse initialization + syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to `self.nodes()` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['syn'] = syn_cls + self.refs['out'] = out + # unify the access + self.refs['comm'] = comm + + def update(self): + x = _get_return(self.refs['syn'].return_info()) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + +class FullProjAlignPreSD(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + + To simulate an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSD(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSD(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSD(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSD(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + syn: The synaptic dynamics. + delay: The synaptic delay. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: DynamicalSystem, + syn: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, DynamicalSystem) + check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # synapse and delay initialization + delay_cls = init_delay_by_return(syn.return_info()) + delay_cls.register_entry(self.name, delay) + pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + self.refs['syn'] = syn + # unify the access + self.refs['comm'] = comm + + def update(self, x=None): + if x is None: + x = self.refs['delay'].at(self.name) + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current + + +class FullProjAlignPreDS(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + + To simulate an E/I balanced network model: + + .. code-block:: python + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + ne, ni = 3200, 800 + self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDS(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDS(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDS(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDS(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + syn: The synaptic dynamics. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + syn: DynamicalSystem, + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + out_label: Optional[str] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(syn, DynamicalSystem) + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + self.syn = syn + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + post.add_inp_fun(self.name, out, label=out_label) + + # references + self.refs = dict() + # invisible to ``self.nodes()`` + self.refs['pre'] = pre + self.refs['post'] = post + self.refs['out'] = out + self.refs['delay'] = delay_cls + # unify the access + self.refs['syn'] = syn + self.refs['comm'] = comm + + def update(self): + spk = self.refs['delay'].at(self.name) + g = self.comm(self.syn(spk)) + self.refs['out'].bind_cond(g) + return g diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py deleted file mode 100644 index 2616e928b..000000000 --- a/brainpy/_src/dyn/projections/aligns.py +++ /dev/null @@ -1,1053 +0,0 @@ -from typing import Optional, Callable, Union - -from brainpy import math as bm, check -from brainpy._src.delay import (Delay, DelayAccess, delay_identifier, - init_delay_by_return, register_delay_by_return) -from brainpy._src.dynsys import DynamicalSystem, Projection -from brainpy._src.mixin import (JointType, ParamDescriber, ReturnInfo, - SupportAutoDelay, BindCondData, AlignPost) - -__all__ = [ - 'VanillaProj', - 'ProjAlignPostMg1', 'ProjAlignPostMg2', - 'ProjAlignPost1', 'ProjAlignPost2', - 'ProjAlignPreMg1', 'ProjAlignPreMg2', - 'ProjAlignPre1', 'ProjAlignPre2', -] - - -def get_post_repr(out_label, syn, out): - return f'{out_label} // {syn.identifier} // {out.identifier}' - - -def add_inp_fun(out_label, proj_name, out, post): - # synapse and output initialization - if out_label is None: - out_name = proj_name - else: - out_name = f'{out_label} // {proj_name}' - post.add_inp_fun(out_name, out) - - -def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name): - # synapse and output initialization - _post_repr = get_post_repr(out_label, syn_desc, out_desc) - if not post.has_bef_update(_post_repr): - syn_cls = syn_desc() - out_cls = out_desc() - - # synapse and output initialization - if out_label is None: - out_name = proj_name - else: - out_name = f'{out_label} // {proj_name}' - post.add_inp_fun(out_name, out_cls) - post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls)) - syn = post.get_bef_update(_post_repr).syn - out = post.get_bef_update(_post_repr).out - return syn, out - - -def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None): - _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}' - if not delay_cls.has_bef_update(_syn_id): - # delay - delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name) - # synapse - syn_cls = syn_desc() - # add to "after_updates" - delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) - syn = delay_cls.get_bef_update(_syn_id).syn - return syn - - -def align_pre1_add_bef_update(syn_desc, pre): - _syn_id = f'{syn_desc.identifier} // Delay' - if not pre.has_aft_update(_syn_id): - # "syn_cls" needs an instance of "ProjAutoDelay" - syn_cls: SupportAutoDelay = syn_desc() - delay_cls = init_delay_by_return(syn_cls.return_info()) - # add to "after_updates" - pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls)) - delay_cls: Delay = pre.get_aft_update(_syn_id).delay - syn = pre.get_aft_update(_syn_id).syn - return delay_cls, syn - - -class _AlignPre(DynamicalSystem): - def __init__(self, syn, delay=None): - super().__init__() - self.syn = syn - self.delay = delay - - def update(self, x): - if self.delay is None: - return x >> self.syn - else: - return x >> self.syn >> self.delay - - def reset_state(self, *args, **kwargs): - pass - - -class _AlignPost(DynamicalSystem): - def __init__(self, - syn: Callable, - out: JointType[DynamicalSystem, BindCondData]): - super().__init__() - self.syn = syn - self.out = out - - def update(self, *args, **kwargs): - self.out.bind_cond(self.syn(*args, **kwargs)) - - def reset_state(self, *args, **kwargs): - pass - - -class _AlignPreMg(DynamicalSystem): - def __init__(self, access, syn): - super().__init__() - self.access = access - self.syn = syn - - def update(self, *args, **kwargs): - return self.syn(self.access()) - - def reset_state(self, *args, **kwargs): - pass - - -def _get_return(return_info): - if isinstance(return_info, bm.Variable): - return return_info.value - elif isinstance(return_info, ReturnInfo): - return return_info.get_data() - else: - raise NotImplementedError - - -class VanillaProj(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group. - - **Code Examples** - - To simulate an E/I balanced network model: - - .. code-block:: - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=3200, tau=5.) - self.syn2 = bp.dyn.Expon(size=800, tau=10.) - self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:3200])) - self.I(self.syn2(spk[3200:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # output initialization - post.add_inp_fun(self.name, out) - - # references - self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` - self.refs['comm'] = comm # unify the access - - def update(self, x): - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPostMg1(Projection): - r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. - - **Code Examples** - - To define an E/I balanced network model. - - .. code-block:: python - - import brainpy as bp - import brainpy.math as bm - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=4000, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPostMg1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=4000, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:3200]) - self.I(spk[3200:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - out_label: str. The prefix of the output function. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - comm: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], - out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and output initialization - syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) - - # references - self.refs = dict(post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = syn - self.refs['out'] = out - self.refs['comm'] = comm # unify the access - - def update(self, x): - current = self.comm(x) - self.refs['syn'].add_current(current) # synapse post current - return current - - -class ProjAlignPostMg2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. - - **Code Examples** - - To define an E/I balanced network model. - - .. code-block:: python - - import brainpy as bp - import brainpy.math as bm - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPostMg2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPostMg2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=ni, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPostMg2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=ne, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPostMg2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], - out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # synapse and output initialization - syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) - - # references - self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = syn # invisible to ``self.node()`` - self.refs['out'] = out # invisible to ``self.node()`` - # unify the access - self.refs['comm'] = comm - self.refs['delay'] = pre.get_aft_update(delay_identifier) - - def update(self): - x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) - current = self.comm(x) - self.refs['syn'].add_current(current) # synapse post current - return current - - -class ProjAlignPost1(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. - - To simulate an E/I balanced network: - - .. code-block:: - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=4000, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=4000, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:3200]) - self.I(spk[3200:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - comm: DynamicalSystem, - syn: JointType[DynamicalSystem, AlignPost], - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - self.out = out - - # synapse and output initialization - add_inp_fun(out_label, self.name, out, post) - - # reference - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['post'] = post - self.refs['syn'] = syn - self.refs['out'] = out - # unify the access - self.refs['comm'] = comm - - def update(self, x): - current = self.comm(x) - g = self.syn(self.comm(x)) - self.refs['out'].bind_cond(g) # synapse post current - return current - - -class ProjAlignPost2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. - - To simulate and define an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - comm: The synaptic communication. - syn: The synaptic dynamics. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - syn: JointType[DynamicalSystem, AlignPost], - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # synapse and output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - # unify the access - self.refs['delay'] = delay_cls - self.refs['comm'] = comm - self.refs['syn'] = syn - - def update(self): - x = self.refs['delay'].at(self.name) - g = self.syn(self.comm(x)) - self.refs['out'].bind_cond(g) # synapse post current - return g - - -class ProjAlignPreMg1(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: DynamicalSystem, - syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], - delay: Union[None, int, float], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and delay initialization - delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) - delay_cls.register_entry(self.name, delay) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - self.refs['syn'] = syn_cls - # unify the access - self.refs['comm'] = comm - - def update(self, x=None): - if x is None: - x = self.refs['delay'].at(self.name) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPreMg2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - syn: ParamDescriber[DynamicalSystem], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, ParamDescriber[DynamicalSystem]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # delay initialization - delay_cls = register_delay_by_return(pre) - - # synapse initialization - syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to `self.nodes()` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['syn'] = syn_cls - self.refs['out'] = out - # unify the access - self.refs['comm'] = comm - - def update(self): - x = _get_return(self.refs['syn'].return_info()) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPre1(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - syn: The synaptic dynamics. - delay: The synaptic delay. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: DynamicalSystem, - syn: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - - # synapse and delay initialization - delay_cls = init_delay_by_return(syn.return_info()) - delay_cls.register_entry(self.name, delay) - pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - self.refs['syn'] = syn - # unify the access - self.refs['comm'] = comm - - def update(self, x=None): - if x is None: - x = self.refs['delay'].at(self.name) - current = self.comm(x) - self.refs['out'].bind_cond(current) - return current - - -class ProjAlignPre2(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. - - To simulate an E/I balanced network model: - - .. code-block:: python - - class EINet(bp.DynSysGroup): - def __init__(self): - super().__init__() - ne, ni = 3200, 800 - self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - - Args: - pre: The pre-synaptic neuron group. - delay: The synaptic delay. - syn: The synaptic dynamics. - comm: The synaptic communication. - out: The synaptic output. - post: The post-synaptic neuron group. - name: str. The projection name. - mode: Mode. The computing mode. - """ - - def __init__( - self, - pre: JointType[DynamicalSystem, SupportAutoDelay], - delay: Union[None, int, float], - syn: DynamicalSystem, - comm: DynamicalSystem, - out: JointType[DynamicalSystem, BindCondData], - post: DynamicalSystem, - out_label: Optional[str] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, mode=mode) - - # synaptic models - check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, DynamicalSystem) - check.is_instance(comm, DynamicalSystem) - check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, DynamicalSystem) - self.comm = comm - self.syn = syn - - # delay initialization - delay_cls = register_delay_by_return(pre) - delay_cls.register_entry(self.name, delay) - - # output initialization - add_inp_fun(out_label, self.name, out, post) - - # references - self.refs = dict() - # invisible to ``self.nodes()`` - self.refs['pre'] = pre - self.refs['post'] = post - self.refs['out'] = out - self.refs['delay'] = delay_cls - # unify the access - self.refs['syn'] = syn - self.refs['comm'] = comm - - def update(self): - spk = self.refs['delay'].at(self.name) - g = self.comm(self.syn(spk)) - self.refs['out'].bind_cond(g) - return g diff --git a/brainpy/_src/dyn/projections/base.py b/brainpy/_src/dyn/projections/base.py new file mode 100644 index 000000000..44a2273a4 --- /dev/null +++ b/brainpy/_src/dyn/projections/base.py @@ -0,0 +1,12 @@ +from brainpy import math as bm +from brainpy._src.mixin import ReturnInfo + + +def _get_return(return_info): + if isinstance(return_info, bm.Variable): + return return_info.value + elif isinstance(return_info, ReturnInfo): + return return_info.get_data() + else: + raise NotImplementedError + diff --git a/brainpy/_src/dyn/projections/delta.py b/brainpy/_src/dyn/projections/delta.py new file mode 100644 index 000000000..616f83df6 --- /dev/null +++ b/brainpy/_src/dyn/projections/delta.py @@ -0,0 +1,203 @@ +from typing import Optional, Union + +from brainpy import math as bm, check +from brainpy._src.delay import (delay_identifier, register_delay_by_return) +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, SupportAutoDelay) + +__all__ = [ + 'HalfProjDelta', 'FullProjDelta', +] + + +class _Delta: + def __init__(self): + self._cond = None + + def bind_cond(self, cond): + self._cond = cond + + def __call__(self, *args, **kwargs): + r = self._cond + return r + + +class HalfProjDelta(Projection): + """Delta synaptic projection. + + **Model Descriptions** + + .. math:: + + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) + + where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + and :math:`D` the transmission delay of chemical synapses. + For simplicity, the rise and decay phases of post-synaptic currents are + omitted in this model. + + + **Code Examples** + + .. code-block:: + + import brainpy as bp + import brainpy.math as bm + + class Net(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.pre = bp.dyn.PoissonGroup(10, 100.) + self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn(self.pre()) + self.post() + return self.post.V.value + + net = Net() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) + + Args: + comm: DynamicalSystem. The synaptic communication. + post: DynamicalSystem. The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + comm: DynamicalSystem, + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # output initialization + out = _Delta() + post.add_inp_fun(self.name, out, category='delta') + + # references + self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + + def update(self, x): + # call the communication + current = self.comm(x) + # bind the output + self.refs['out'].bind_cond(current) + # return the current, if needed + return current + + +class FullProjDelta(Projection): + """Delta synaptic projection. + + **Model Descriptions** + + .. math:: + + I_{syn} (t) = \sum_{j\in C} g_{\mathrm{max}} * \delta(t-t_j-D) + + where :math:`g_{\mathrm{max}}` denotes the chemical synaptic strength, + :math:`t_j` the spiking moment of the presynaptic neuron :math:`j`, + :math:`C` the set of neurons connected to the post-synaptic neuron, + and :math:`D` the transmission delay of chemical synapses. + For simplicity, the rise and decay phases of post-synaptic currents are + omitted in this model. + + + **Code Examples** + + To simulate an E/I balanced network model: + + .. code-block:: + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=3200, tau=5.) + self.syn2 = bp.dyn.Expon(size=800, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:3200])) + self.I(self.syn2(spk[3200:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: DynamicalSystem. The synaptic communication. + post: DynamicalSystem. The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + pre: JointType[DynamicalSystem, SupportAutoDelay], + delay: Union[None, int, float], + comm: DynamicalSystem, + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) + check.is_instance(comm, DynamicalSystem) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # delay initialization + delay_cls = register_delay_by_return(pre) + delay_cls.register_entry(self.name, delay) + + # output initialization + out = _Delta() + post.add_inp_fun(self.name, out, category='delta') + + # references + self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + self.refs['delay'] = pre.get_aft_update(delay_identifier) + + def update(self): + # get delay + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) + # call the communication + current = self.comm(x) + # bind the output + self.refs['out'].bind_cond(current) + # return the current, if needed + return current diff --git a/brainpy/_src/dyn/projections/inputs.py b/brainpy/_src/dyn/projections/inputs.py index f0001988b..dd1e1e3df 100644 --- a/brainpy/_src/dyn/projections/inputs.py +++ b/brainpy/_src/dyn/projections/inputs.py @@ -1,96 +1,167 @@ -from typing import Optional, Any +import numbers +from typing import Any +from typing import Union, Optional -from brainpy import math as bm +from brainpy import check, math as bm +from brainpy._src.context import share from brainpy._src.dynsys import Dynamic +from brainpy._src.dynsys import Projection from brainpy._src.mixin import SupportAutoDelay from brainpy.types import Shape __all__ = [ - 'InputVar', + 'InputVar', + 'PoissonInput', ] class InputVar(Dynamic, SupportAutoDelay): - """Define an input variable. + """Define an input variable. - Example:: + Example:: + + import brainpy as bp - import brainpy as bp - - class Exponential(bp.Projection): - def __init__(self, pre, post, prob, g_max, tau, E=0.): - super().__init__() - self.proj = bp.dyn.ProjAlignPostMg2( - pre=pre, - delay=None, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), - syn=bp.dyn.Expon.desc(post.num, tau=tau), - out=bp.dyn.COBA.desc(E=E), - post=post, - ) - - - class EINet(bp.DynSysGroup): - def __init__(self, num_exc, num_inh, method='exp_auto'): - super(EINet, self).__init__() - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.), method=method) - self.E = bp.dyn.LifRef(num_exc, **pars) - self.I = bp.dyn.LifRef(num_inh, **pars) - - # synapses - w_e = 0.6 # excitatory synaptic weight - w_i = 6.7 # inhibitory synaptic weight - - # Neurons connect to each other randomly with a connection probability of 2% - self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.) - self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.) - self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.) - self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.) - - # define input variables given to E/I populations - self.Ein = bp.dyn.InputVar(self.E.varshape) - self.Iin = bp.dyn.InputVar(self.I.varshape) - self.E.add_inp_fun('', self.Ein) - self.I.add_inp_fun('', self.Iin) - - - net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method - runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)]) - runner.run(100.) - - # visualization - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], - title='Spikes of Excitatory Neurons', show=True) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], - title='Spikes of Inhibitory Neurons', show=True) - - - """ - def __init__( - self, - size: Shape, - keep_size: bool = False, - sharding: Optional[Any] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - method: str = 'exp_auto' - ): - super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.input = self.init_variable(bm.zeros, batch_or_mode) - - def update(self, *args, **kwargs): - return self.input.value - - def return_info(self): - return self.input - - def clear_input(self, *args, **kwargs): - self.reset_state(self.mode) + class Exponential(bp.Projection): + def __init__(self, pre, post, prob, g_max, tau, E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg2( + pre=pre, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), + syn=bp.dyn.Expon.desc(post.num, tau=tau), + out=bp.dyn.COBA.desc(E=E), + post=post, + ) + + + class EINet(bp.DynSysGroup): + def __init__(self, num_exc, num_inh, method='exp_auto'): + super(EINet, self).__init__() + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), method=method) + self.E = bp.dyn.LifRef(num_exc, **pars) + self.I = bp.dyn.LifRef(num_inh, **pars) + + # synapses + w_e = 0.6 # excitatory synaptic weight + w_i = 6.7 # inhibitory synaptic weight + + # Neurons connect to each other randomly with a connection probability of 2% + self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.) + self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.) + self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.) + self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.) + + # define input variables given to E/I populations + self.Ein = bp.dyn.InputVar(self.E.varshape) + self.Iin = bp.dyn.InputVar(self.I.varshape) + self.E.add_inp_fun('', self.Ein) + self.I.add_inp_fun('', self.Iin) + + + net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method + runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)]) + runner.run(100.) + + # visualization + bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], + title='Spikes of Excitatory Neurons', show=True) + bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'], + title='Spikes of Inhibitory Neurons', show=True) + + + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + sharding: Optional[Any] = None, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + method: str = 'exp_auto' + ): + super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method) + + self.reset_state(self.mode) + + def reset_state(self, batch_or_mode=None, **kwargs): + self.input = self.init_variable(bm.zeros, batch_or_mode) + + def update(self, *args, **kwargs): + return self.input.value + + def return_info(self): + return self.input + + def clear_input(self, *args, **kwargs): + self.reset_state(self.mode) + + +class PoissonInput(Projection): + """Poisson Input to the given :py:class:`~.Variable`. + + Adds independent Poisson input to a target variable. For large + numbers of inputs, this is much more efficient than creating a + `PoissonGroup`. The synaptic events are generated randomly during the + simulation and are not preloaded and stored in memory. All the inputs must + target the same variable, have the same frequency and same synaptic weight. + All neurons in the target variable receive independent realizations of + Poisson spike trains. + + Args: + target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. + num_input: The number of inputs. + freq: The frequency of each of the inputs. Must be a scalar. + weight: The synaptic weight. Must be a scalar. + name: The target name. + mode: The computing mode. + """ + + def __init__( + self, + target_var: bm.Variable, + num_input: int, + freq: Union[int, float], + weight: Union[int, float], + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + if not isinstance(target_var, bm.Variable): + raise TypeError(f'"target_var" must be an instance of Variable. ' + f'But we got {type(target_var)}: {target_var}') + self.target_var = target_var + self.num_input = check.is_integer(num_input, min_bound=1) + self.freq = check.is_float(freq, min_bound=0., allow_int=True) + self.weight = check.is_float(weight, allow_int=True) + + def reset_state(self, *args, **kwargs): + pass + + def update(self): + p = self.freq * share['dt'] / 1e3 + a = self.num_input * p + b = self.num_input * (1 - p) + + if isinstance(share['dt'], numbers.Number): # dt is not traced + if (a > 5) and (b > 5): + inp = bm.random.normal(a, b * p, self.target_var.shape) + else: + inp = bm.random.binomial(self.num_input, p, self.target_var.shape) + + else: # dt is traced + inp = bm.cond((a > 5) * (b > 5), + lambda: bm.random.normal(a, b * p, self.target_var.shape), + lambda: bm.random.binomial(self.num_input, p, self.target_var.shape)) + + # inp = bm.sharding.partition(inp, self.target_var.sharding) + self.target_var += inp * self.weight + + def __repr__(self): + return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})' diff --git a/brainpy/_src/dyn/projections/others.py b/brainpy/_src/dyn/projections/others.py deleted file mode 100644 index 72a77298f..000000000 --- a/brainpy/_src/dyn/projections/others.py +++ /dev/null @@ -1,81 +0,0 @@ -import numbers -import warnings -from typing import Union, Optional - -from brainpy import check, math as bm -from brainpy._src.context import share -from brainpy._src.dynsys import Projection - -__all__ = [ - 'PoissonInput', -] - - -class PoissonInput(Projection): - """Poisson Input to the given :py:class:`~.Variable`. - - Adds independent Poisson input to a target variable. For large - numbers of inputs, this is much more efficient than creating a - `PoissonGroup`. The synaptic events are generated randomly during the - simulation and are not preloaded and stored in memory. All the inputs must - target the same variable, have the same frequency and same synaptic weight. - All neurons in the target variable receive independent realizations of - Poisson spike trains. - - Args: - target_var: The variable that is targeted by this input. Should be an instance of :py:class:`~.Variable`. - num_input: The number of inputs. - freq: The frequency of each of the inputs. Must be a scalar. - weight: The synaptic weight. Must be a scalar. - name: The target name. - mode: The computing mode. - """ - - def __init__( - self, - target_var: bm.Variable, - num_input: int, - freq: Union[int, float], - weight: Union[int, float], - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - seed=None - ): - super().__init__(name=name, mode=mode) - - if seed is not None: - warnings.warn('') - - if not isinstance(target_var, bm.Variable): - raise TypeError(f'"target_var" must be an instance of Variable. ' - f'But we got {type(target_var)}: {target_var}') - self.target_var = target_var - self.num_input = check.is_integer(num_input, min_bound=1) - self.freq = check.is_float(freq, min_bound=0., allow_int=True) - self.weight = check.is_float(weight, allow_int=True) - - def reset_state(self, *args, **kwargs): - pass - - def update(self): - p = self.freq * share['dt'] / 1e3 - a = self.num_input * p - b = self.num_input * (1 - p) - - if isinstance(share['dt'], numbers.Number): # dt is not traced - if (a > 5) and (b > 5): - inp = bm.random.normal(a, b * p, self.target_var.shape) - else: - inp = bm.random.binomial(self.num_input, p, self.target_var.shape) - - else: # dt is traced - inp = bm.cond((a > 5) * (b > 5), - lambda: bm.random.normal(a, b * p, self.target_var.shape), - lambda: bm.random.binomial(self.num_input, p, self.target_var.shape), - ()) - - # inp = bm.sharding.partition(inp, self.target_var.sharding) - self.target_var += inp * self.weight - - def __repr__(self): - return f'{self.name}(num_input={self.num_input}, freq={self.freq}, weight={self.weight})' diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 3fb3c1232..598a7496f 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -7,8 +7,9 @@ from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP) from brainpy.types import ArrayType -from .aligns import (_get_return, align_post_add_bef_update, - align_pre2_add_bef_update, add_inp_fun) +from .align_post import (align_post_add_bef_update, ) +from .align_pre import (align_pre2_add_bef_update, ) +from .base import (_get_return, ) __all__ = [ 'STDP_Song2000', @@ -165,7 +166,7 @@ def __init__( else: syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre') out_cls = out() - add_inp_fun(out_label, self.name, out_cls, post) + post.add_inp_fun(self.name, out_cls, label=out_label) # references self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index a4173c7ba..b8884f327 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -86,7 +86,7 @@ def update(self, I_pre, I_post): conductance = self.syn.refs['syn'].g Apre = self.syn.refs['pre_trace'].g Apost = self.syn.refs['post_trace'].g - current = self.post.sum_inputs(self.post.V) + current = self.post.sum_current_inputs(self.post.V) if comm_method == 'dense': w = self.syn.comm.W.flatten() else: diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index 32b072e5a..90500a26f 100644 --- a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -19,7 +19,7 @@ def __init__(self, scale=1., inp=20., delay=None): prob = 80 / (4000 * scale) - self.E2I = bp.dyn.ProjAlignPreMg1( + self.E2I = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=delay, @@ -27,7 +27,7 @@ def __init__(self, scale=1., inp=20., delay=None): out=bp.dyn.COBA(E=0.), post=self.I, ) - self.E2E = bp.dyn.ProjAlignPreMg1( + self.E2E = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=delay, @@ -35,7 +35,7 @@ def __init__(self, scale=1., inp=20., delay=None): out=bp.dyn.COBA(E=0.), post=self.E, ) - self.I2E = bp.dyn.ProjAlignPreMg1( + self.I2E = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=delay, @@ -43,7 +43,7 @@ def __init__(self, scale=1., inp=20., delay=None): out=bp.dyn.COBA(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPreMg1( + self.I2I = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=delay, @@ -90,7 +90,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): prob = 80 / (4000 * scale) - self.E2E = bp.dyn.ProjAlignPostMg2( + self.E2E = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), @@ -98,7 +98,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): out=bp.dyn.COBA.desc(E=0.), post=self.E, ) - self.E2I = bp.dyn.ProjAlignPostMg2( + self.E2I = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), @@ -106,7 +106,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): out=bp.dyn.COBA.desc(E=0.), post=self.I, ) - self.I2E = bp.dyn.ProjAlignPostMg2( + self.I2E = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), @@ -114,7 +114,7 @@ def __init__(self, scale, inp=20., ltc=True, delay=None): out=bp.dyn.COBA.desc(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPostMg2( + self.I2I = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), @@ -163,14 +163,14 @@ def __init__(self, scale=1.): self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), - syn=bp.dyn.Expon(size=num, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), - syn=bp.dyn.Expon(size=num, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), + syn=bp.dyn.Expon(size=num, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), + syn=bp.dyn.Expon(size=num, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) def update(self, input): spk = self.delay.at('I') @@ -198,30 +198,30 @@ def __init__(self, scale, delay=None): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPost2(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -292,30 +292,30 @@ def __init__(self, scale=1., delay=None): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -350,30 +350,30 @@ def __init__(self, scale=1., delay=None): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() diff --git a/brainpy/_src/dyn/projections/tests/test_delta.py b/brainpy/_src/dyn/projections/tests/test_delta.py new file mode 100644 index 000000000..8e16a128a --- /dev/null +++ b/brainpy/_src/dyn/projections/tests/test_delta.py @@ -0,0 +1,51 @@ +import matplotlib.pyplot as plt + +import brainpy as bp +import brainpy.math as bm + + +class NetForHalfProj(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.pre = bp.dyn.PoissonGroup(10, 100.) + self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.HalfProjDelta(bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn(self.pre()) + self.post() + return self.post.V.value + + +def test1(): + net = NetForHalfProj() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) + plt.close('all') + + +class NetForFullProj(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.pre = bp.dyn.PoissonGroup(10, 100.) + self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn() + self.pre() + self.post() + return self.post.V.value + + +def test2(): + net = NetForFullProj() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) + plt.close('all') + + diff --git a/brainpy/_src/dyn/projections/vanilla.py b/brainpy/_src/dyn/projections/vanilla.py new file mode 100644 index 000000000..15773d231 --- /dev/null +++ b/brainpy/_src/dyn/projections/vanilla.py @@ -0,0 +1,83 @@ +from typing import Optional + +from brainpy import math as bm, check +from brainpy._src.dynsys import DynamicalSystem, Projection +from brainpy._src.mixin import (JointType, BindCondData) + +__all__ = [ + 'VanillaProj', +] + + +class VanillaProj(Projection): + """Synaptic projection which defines the synaptic computation with the dimension of pre-synaptic neuron group. + + **Code Examples** + + To simulate an E/I balanced network model: + + .. code-block:: + + class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=3200, tau=5.) + self.syn2 = bp.dyn.Expon(size=800, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:3200])) + self.I(self.syn2(spk[3200:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(1000) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + + Args: + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. + """ + + def __init__( + self, + comm: DynamicalSystem, + out: JointType[DynamicalSystem, BindCondData], + post: DynamicalSystem, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + check.is_instance(comm, DynamicalSystem) + check.is_instance(out, JointType[DynamicalSystem, BindCondData]) + check.is_instance(post, DynamicalSystem) + self.comm = comm + + # output initialization + post.add_inp_fun(self.name, out) + + # references + self.refs = dict(post=post, out=out) # invisible to ``self.nodes()`` + self.refs['comm'] = comm # unify the access + + def update(self, x): + current = self.comm(x) + self.refs['out'].bind_cond(current) + return current diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py index 4a6b9ddb6..5fad9482d 100644 --- a/brainpy/_src/dyn/synapses/abstract_models.py +++ b/brainpy/_src/dyn/synapses/abstract_models.py @@ -10,7 +10,6 @@ from brainpy.types import ArrayType __all__ = [ - 'Delta', 'Expon', 'DualExpon', 'DualExponV2', @@ -21,69 +20,6 @@ ] -class Delta(SynDyn, AlignPost): - r"""Delta synapse model. - - **Model Descriptions** - - The single exponential decay synapse model assumes the release of neurotransmitter, - its diffusion across the cleft, the receptor binding, and channel opening all happen - very quickly, so that the channels instantaneously jump from the closed to the open state. - Therefore, its expression is given by - - .. math:: - - g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau} - - where :math:`\tau_{delay}` is the time constant of the synaptic state decay, - :math:`t_0` is the time of the pre-synaptic spike, - :math:`g_{\mathrm{max}}` is the maximal conductance. - - Accordingly, the differential form of the exponential synapse is given by - - .. math:: - - \begin{aligned} - & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}). - \end{aligned} - - .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. - "The Synapse." Principles of Computational Modelling in Neuroscience. - Cambridge: Cambridge UP, 2011. 172-95. Print. - - """ - - def __init__( - self, - size: Union[int, Sequence[int]], - keep_size: bool = False, - sharding: Optional[Sequence[str]] = None, - name: Optional[str] = None, - mode: Optional[bm.Mode] = None, - ): - super().__init__(name=name, - mode=mode, - size=size, - keep_size=keep_size, - sharding=sharding) - - self.reset_state(self.mode) - - def reset_state(self, batch_or_mode=None, **kwargs): - self.g = self.init_variable(bm.zeros, batch_or_mode) - - def update(self, x=None): - if x is not None: - self.g.value += x - return self.g.value - - def add_current(self, x): - self.g.value += x - - def return_info(self): - return self.g - - class Expon(SynDyn, AlignPost): r"""Exponential decay synapse model. @@ -1030,4 +966,4 @@ def return_info(self): lambda shape: self.u * self.x) -STP.__doc__ = STP.__doc__ % (pneu_doc,) \ No newline at end of file +STP.__doc__ = STP.__doc__ % (pneu_doc,) diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py index a2bc1bdd5..55bac7111 100644 --- a/brainpy/_src/dynold/synapses/base.py +++ b/brainpy/_src/dynold/synapses/base.py @@ -6,7 +6,7 @@ from brainpy import math as bm from brainpy._src.connect import TwoEndConnector, One2One, All2All from brainpy._src.dnn import linear -from brainpy._src.dyn import projections +from brainpy._src.dyn.projections.conn import SynConn from brainpy._src.dyn.base import NeuDyn from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import parameter @@ -29,7 +29,7 @@ class _SynapseComponent(DynamicalSystem): synaptic long-term plasticity, and others. """ '''Master of this component.''' - master: projections.SynConn + master: SynConn def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -50,9 +50,9 @@ def isregistered(self, val: bool): def reset_state(self, batch_size=None): pass - def register_master(self, master: projections.SynConn): - if not isinstance(master, projections.SynConn): - raise TypeError(f'master must be instance of {projections.SynConn.__name__}, but we got {type(master)}') + def register_master(self, master: SynConn): + if not isinstance(master, SynConn): + raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') if self.isregistered: raise ValueError(f'master has been registered, but we got another master going to be registered.') if hasattr(self, 'master') and self.master != master: @@ -90,7 +90,7 @@ def __init__( f'But we got {type(target_var)}') self.target_var: Optional[bm.Variable] = target_var - def register_master(self, master: projections.SynConn): + def register_master(self, master: SynConn): super().register_master(master) # initialize target variable to output @@ -125,7 +125,7 @@ def clone(self): return _NullSynOut() -class TwoEndConn(projections.SynConn): +class TwoEndConn(SynConn): """Base class to model synaptic connections. Parameters diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index ee1fb2b8f..a070a295a 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -91,7 +91,8 @@ def __init__( # Attribute for "SupportInputProj" # each instance of "SupportInputProj" should have a "cur_inputs" attribute - self.cur_inputs = bm.node_dict() + self.current_inputs = bm.node_dict() + self.delta_inputs = bm.node_dict() # the before- / after-updates used for computing # added after the version of 2.4.3 diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 6ac7f3a3d..323fe872c 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -21,7 +21,6 @@ DynamicalSystem = None delay_identifier, init_delay_by_return = None, None - __all__ = [ 'MixIn', 'ParamDesc', @@ -53,7 +52,6 @@ def _get_dynsys(): return DynamicalSystem - class MixIn(object): """Base MixIn object. @@ -378,55 +376,119 @@ def get_delay_var(self, name): class SupportInputProj(MixIn): """The :py:class:`~.MixIn` that receives the input projections. - Note that the subclass should define a ``cur_inputs`` attribute. + Note that the subclass should define a ``cur_inputs`` attribute. Otherwise, + the input function utilities cannot be used. """ - cur_inputs: bm.node_dict + current_inputs: bm.node_dict + delta_inputs: bm.node_dict - def add_inp_fun(self, key: Any, fun: Callable): + def add_inp_fun(self, key: str, fun: Callable, label: Optional[str] = None, category: str = 'current'): """Add an input function. Args: - key: The dict key. - fun: The function to generate inputs. + key: str. The dict key. + fun: Callable. The function to generate inputs. + label: str. The input label. + category: str. The input category, should be ``current`` (the current) or + ``delta`` (the delta synapse, indicating the delta function). """ if not callable(fun): raise TypeError('Must be a function.') - if key in self.cur_inputs: - raise ValueError(f'Key "{key}" has been defined and used.') - self.cur_inputs[key] = fun - def get_inp_fun(self, key): + key = self._input_label_repr(key, label) + if category == 'current': + if key in self.current_inputs: + raise ValueError(f'Key "{key}" has been defined and used.') + self.current_inputs[key] = fun + elif category == 'delta': + if key in self.delta_inputs: + raise ValueError(f'Key "{key}" has been defined and used.') + self.delta_inputs[key] = fun + else: + raise NotImplementedError(f'Unknown category: {category}. Only support "current" and "delta".') + + def get_inp_fun(self, key: str): """Get the input function. Args: - key: The key. + key: str. The key. Returns: The input function which generates currents. """ - return self.cur_inputs.get(key) + if key in self.current_inputs: + return self.current_inputs[key] + elif key in self.delta_inputs: + return self.delta_inputs[key] + else: + raise ValueError(f'Unknown key: {key}') + + def sum_current_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): + """Summarize all current inputs by the defined input functions ``.current_inputs``. + + Args: + *args: The arguments for input functions. + init: The initial input data. + label: str. The input label. + **kwargs: The arguments for input functions. + + Returns: + The total currents. + """ + if label is None: + for key, out in self.current_inputs.items(): + init = init + out(*args, **kwargs) + else: + label_repr = self._input_label_start(label) + for key, out in self.current_inputs.items(): + if key.startswith(label_repr): + init = init + out(*args, **kwargs) + return init - def sum_inputs(self, *args, init=0., label=None, **kwargs): - """Summarize all inputs by the defined input functions ``.cur_inputs``. + def sum_delta_inputs(self, *args, init: Any = 0., label: Optional[str] = None, **kwargs): + """Summarize all delta inputs by the defined input functions ``.delta_inputs``. Args: *args: The arguments for input functions. init: The initial input data. + label: str. The input label. **kwargs: The arguments for input functions. Returns: The total currents. """ if label is None: - for key, out in self.cur_inputs.items(): + for key, out in self.delta_inputs.items(): init = init + out(*args, **kwargs) else: - for key, out in self.cur_inputs.items(): - if key.startswith(label + ' // '): + label_repr = self._input_label_start(label) + for key, out in self.delta_inputs.items(): + if key.startswith(label_repr): init = init + out(*args, **kwargs) return init + @classmethod + def _input_label_start(cls, label: str): + # unify the input label repr. + return f'{label} // ' + + @classmethod + def _input_label_repr(cls, name: str, label: Optional[str] = None): + # unify the input label repr. + return name if label is None else (cls._input_label_start(label) + str(name)) + + # deprecated # + # ---------- # + + @property + def cur_inputs(self): + return self.current_inputs + + def sum_inputs(self, *args, **kwargs): + warnings.warn('Please use ".sum_current_inputs()" instead. ".sum_inputs()" will be removed.', UserWarning) + return self.sum_current_inputs(*args, **kwargs) + class SupportReturnInfo(MixIn): """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`.""" diff --git a/brainpy/dyn/projections.py b/brainpy/dyn/projections.py index b2f4c5304..23e1a7485 100644 --- a/brainpy/dyn/projections.py +++ b/brainpy/dyn/projections.py @@ -1,24 +1,24 @@ - -from brainpy._src.dyn.projections.aligns import ( - VanillaProj, - ProjAlignPostMg1, - ProjAlignPostMg2, - ProjAlignPost1, - ProjAlignPost2, - ProjAlignPreMg1, - ProjAlignPreMg2, - ProjAlignPre1, - ProjAlignPre2, +from brainpy._src.dyn.projections.vanilla import VanillaProj +from brainpy._src.dyn.projections.delta import ( + HalfProjDelta, + FullProjDelta, +) +from brainpy._src.dyn.projections.align_post import ( + HalfProjAlignPostMg, + FullProjAlignPostMg, + HalfProjAlignPost, + FullProjAlignPost, +) +from brainpy._src.dyn.projections.align_pre import ( + FullProjAlignPreSDMg, + FullProjAlignPreDSMg, + FullProjAlignPreSD, + FullProjAlignPreDS, ) - from brainpy._src.dyn.projections.conn import ( SynConn as SynConn, ) - -from brainpy._src.dyn.projections.others import ( - PoissonInput as PoissonInput, -) - from brainpy._src.dyn.projections.inputs import ( InputVar, + PoissonInput, ) diff --git a/brainpy/dyn/synapses.py b/brainpy/dyn/synapses.py index 68be31944..9a097be1a 100644 --- a/brainpy/dyn/synapses.py +++ b/brainpy/dyn/synapses.py @@ -1,6 +1,5 @@ from brainpy._src.dyn.synapses.abstract_models import ( - Delta, Expon, Alpha, DualExpon, diff --git a/docs/apis/brainpy.dyn.projections.rst b/docs/apis/brainpy.dyn.projections.rst index c1f8c1070..0587dcbb8 100644 --- a/docs/apis/brainpy.dyn.projections.rst +++ b/docs/apis/brainpy.dyn.projections.rst @@ -14,14 +14,14 @@ Reduced Projections :nosignatures: :template: classtemplate.rst - ProjAlignPostMg1 - ProjAlignPostMg2 - ProjAlignPost1 - ProjAlignPost2 - ProjAlignPreMg1 - ProjAlignPreMg2 - ProjAlignPre1 - ProjAlignPre2 + HalfProjAlignPostMg + FullProjAlignPostMg + HalfProjAlignPost + FullProjAlignPost + FullProjAlignPreSDMg + FullProjAlignPreDSMg + FullProjAlignPreSD + FullProjAlignPreDS @@ -33,6 +33,8 @@ Projections :nosignatures: :template: classtemplate.rst + HalfProjDelta + FullProjDelta VanillaProj SynConn diff --git a/docs/apis/brainpy.dyn.synapses.rst b/docs/apis/brainpy.dyn.synapses.rst index ea4313c69..bea61ab87 100644 --- a/docs/apis/brainpy.dyn.synapses.rst +++ b/docs/apis/brainpy.dyn.synapses.rst @@ -42,7 +42,6 @@ Phenomenological synapse models :nosignatures: :template: classtemplate.rst - Delta Expon Alpha DualExpon diff --git a/docs/apis/losses.rst b/docs/apis/losses.rst index 8f50c487f..4f4a3d167 100644 --- a/docs/apis/losses.rst +++ b/docs/apis/losses.rst @@ -33,6 +33,14 @@ Comparison log_cosh_loss ctc_loss_with_forward_probs ctc_loss + multi_margin_loss + + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + CrossEntropyLoss NLLLoss L1Loss diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py index af7511e19..60b325657 100644 --- a/examples/dynamics_simulation/COBA.py +++ b/examples/dynamics_simulation/COBA.py @@ -13,7 +13,7 @@ def __init__(self, num_exc, num_inh, inp=20.): self.E = bp.dyn.LifRefLTC(num_exc, **neu_pars) self.I = bp.dyn.LifRefLTC(num_inh, **neu_pars) - self.E2I = bp.dyn.ProjAlignPreMg1( + self.E2I = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=None, @@ -21,7 +21,7 @@ def __init__(self, num_exc, num_inh, inp=20.): out=bp.dyn.COBA(E=0.), post=self.I, ) - self.E2E = bp.dyn.ProjAlignPreMg1( + self.E2E = bp.dyn.FullProjAlignPreSDMg( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), delay=None, @@ -29,7 +29,7 @@ def __init__(self, num_exc, num_inh, inp=20.): out=bp.dyn.COBA(E=0.), post=self.E, ) - self.I2E = bp.dyn.ProjAlignPreMg1( + self.I2E = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=None, @@ -37,7 +37,7 @@ def __init__(self, num_exc, num_inh, inp=20.): out=bp.dyn.COBA(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPreMg1( + self.I2I = bp.dyn.FullProjAlignPreSDMg( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), delay=0., @@ -67,7 +67,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): self.E = bp.dyn.LifRef(num_exc, **neu_pars) self.I = bp.dyn.LifRef(num_inh, **neu_pars) - self.E2E = bp.dyn.ProjAlignPostMg2( + self.E2E = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.E.num), 0.6), @@ -75,7 +75,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): out=bp.dyn.COBA.desc(E=0.), post=self.E, ) - self.E2I = bp.dyn.ProjAlignPostMg2( + self.E2I = bp.dyn.FullProjAlignPostMg( pre=self.E, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.E.num, post=self.I.num), 0.6), @@ -83,7 +83,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): out=bp.dyn.COBA.desc(E=0.), post=self.I, ) - self.I2E = bp.dyn.ProjAlignPostMg2( + self.I2E = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.E.num), 6.7), @@ -91,7 +91,7 @@ def __init__(self, num_exc, num_inh, inp=20., ltc=True): out=bp.dyn.COBA.desc(E=-80.), post=self.E, ) - self.I2I = bp.dyn.ProjAlignPostMg2( + self.I2I = bp.dyn.FullProjAlignPostMg( pre=self.I, delay=None, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=self.I.num, post=self.I.num), 6.7), diff --git a/examples/dynamics_simulation/COBA_parallel.py b/examples/dynamics_simulation/COBA_parallel.py index 45cf81953..954b01734 100644 --- a/examples/dynamics_simulation/COBA_parallel.py +++ b/examples/dynamics_simulation/COBA_parallel.py @@ -11,7 +11,7 @@ class ExpJIT(bp.Projection): def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): super().__init__() - self.proj = bp.dyn.ProjAlignPostMg1( + self.proj = bp.dyn.HalfProjAlignPostMg( comm=bp.dnn.EventJitFPHomoLinear(pre_num, post.num, prob=prob, weight=g_max), syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), out=bp.dyn.COBA.desc(E=E), @@ -40,7 +40,7 @@ def update(self, input): class ExpMasked(bp.Projection): def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): super().__init__() - self.proj = bp.dyn.ProjAlignPostMg1( + self.proj = bp.dyn.HalfProjAlignPostMg( comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max, sharding=[None, bm.sharding.NEU_AXIS]), syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), @@ -111,7 +111,7 @@ def _f(self, indices, indptr, x): class ExpMasked2(bp.Projection): def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): super().__init__() - self.proj = bp.dyn.ProjAlignPostMg1( + self.proj = bp.dyn.HalfProjAlignPostMg( comm=PCSR(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max, num_shard=4), syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), out=bp.dyn.COBA.desc(E=E), diff --git a/examples/dynamics_simulation/decision_making_network.py b/examples/dynamics_simulation/decision_making_network.py index 5351680e6..334f99712 100644 --- a/examples/dynamics_simulation/decision_making_network.py +++ b/examples/dynamics_simulation/decision_making_network.py @@ -18,7 +18,7 @@ def __init__(self, pre, post, conn, delay, g_max, tau, E): raise ValueError syn = bp.dyn.Expon.desc(post.num, tau=tau) out = bp.dyn.COBA.desc(E=E) - self.proj = bp.dyn.ProjAlignPostMg2( + self.proj = bp.dyn.FullProjAlignPostMg( pre=pre, delay=delay, comm=comm, syn=syn, out=out, post=post ) @@ -35,7 +35,7 @@ def __init__(self, pre, post, conn, delay, g_max): raise ValueError syn = bp.dyn.NMDA.desc(pre.num, a=0.5, tau_decay=100., tau_rise=2.) out = bp.dyn.MgBlock(E=0., cc_Mg=1.0) - self.proj = bp.dyn.ProjAlignPreMg2( + self.proj = bp.dyn.FullProjAlignPreDSMg( pre=pre, delay=delay, syn=syn, comm=comm, out=out, post=post ) diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index 2243a9ca1..f98527458 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -9,14 +9,14 @@ def __init__(self): self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=4000, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=4000, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=4000, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=4000, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) def update(self, input): spk = self.delay.at('I') @@ -40,30 +40,30 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -118,30 +118,30 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() @@ -167,30 +167,30 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.)) self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=0.02, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=0.02, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) def update(self, inp): self.E2E() From a11806f95449824f656b554cf44715247fce057f Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 28 Dec 2023 19:54:40 +0800 Subject: [PATCH 2/4] [doc] update doc --- brainpy/_src/dyn/neurons/hh.py | 2 +- docs/tutorial_FAQs/brainpy_ecosystem.ipynb | 29 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py index fca13e8e1..f9145a94b 100644 --- a/brainpy/_src/dyn/neurons/hh.py +++ b/brainpy/_src/dyn/neurons/hh.py @@ -61,7 +61,7 @@ class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode): where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants. .. versionadded:: 2.1.9 - Model the conductance-based neuron model. + Modeling the conductance-based neuron model. Parameters ---------- diff --git a/docs/tutorial_FAQs/brainpy_ecosystem.ipynb b/docs/tutorial_FAQs/brainpy_ecosystem.ipynb index ed88c9596..4b28375b5 100644 --- a/docs/tutorial_FAQs/brainpy_ecosystem.ipynb +++ b/docs/tutorial_FAQs/brainpy_ecosystem.ipynb @@ -51,6 +51,35 @@ "\n", "[brainpy-largescale](https://github.com/NH-NCL/brainpy-largescale) provides one solution for large-scale modeling. It enables multi-device running for BrainPy models.\n" ] + }, + { + "cell_type": "markdown", + "source": [ + "## 《神经计算建模实战》\n", + "\n", + "[《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling) is a book for brain dynamics modeling based on BrainPy. It introduces the basic concepts and methods of brain dynamics modeling, and provides comprehensive examples for brain dynamics modeling with BrainPy. \n" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "markdown", + "source": [ + "## 神经计算建模与编程培训班\n", + "\n", + "There is a series of training courses for brain dynamics modeling based on BrainPy. \n", + "\n", + "- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course) \n", + "\n", + "- [第二届神经计算建模与编程培训班 (Second Training Course on Neural Modeling and Programming)](https://github.com/brainpy/2nd-neural-modeling-and-programming-course)\n", + "\n", + "This course is based on the textbook [《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling), supplemented by BrainPy, and based on the theory of \"theory+practice\" combination of teaching and learning. Through this course, students will master the basic concepts, methods and techniques of neural computation modelling, as well as how to use Python programming language to achieve convenient modelling and efficient simulation of neural systems, laying a solid foundation for future research in the field of neural computation or in the field of brain-like intelligence.\n", + "\n" + ], + "metadata": { + "collapsed": false + } } ], "metadata": { From 117e99731018ae73c330ab44d0880af1f2b41676 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 28 Dec 2023 20:31:10 +0800 Subject: [PATCH 3/4] [fix] fix bug --- brainpy/_src/dyn/neurons/lif.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index d4599ebca..11934d9dc 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -5,12 +5,12 @@ import brainpy.math as bm from brainpy._src.context import share +from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc +from brainpy._src.dyn.neurons.base import GradNeuDyn from brainpy._src.initialize import ZeroInit, OneInit from brainpy._src.integrators import odeint, JointEq from brainpy.check import is_initializer from brainpy.types import Shape, ArrayType, Sharding -from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc -from brainpy._src.dyn.neurons.base import GradNeuDyn __all__ = [ 'IF', @@ -994,6 +994,7 @@ class ExpIFRefLTC(ExpIFLTC): %s """ + def __init__( self, size: Shape, @@ -1221,6 +1222,7 @@ class ExpIFRef(ExpIFRefLTC): %s %s """ + def derivative(self, V, t, I): exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T) dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau @@ -1424,7 +1426,8 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) + self.sum_delta_inputs() + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -1756,7 +1759,8 @@ def update(self, x=None): x = 0. if x is None else x # integrate membrane potential - V, w = self.integral(self.V.value, self.w.value, t, x, dt) + self.sum_delta_inputs() + V, w = self.integral(self.V.value, self.w.value, t, x, dt) + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -2444,7 +2448,6 @@ class QuaIFRef(QuaIFRefLTC): %s """ - def derivative(self, V, t, I): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau return dVdt @@ -2633,7 +2636,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) - V = V + self.sum_delta_inputs() + V += self.sum_delta_inputs() # spike, spiking time, and membrane potential reset if isinstance(self.mode, bm.TrainingMode): @@ -2940,7 +2943,7 @@ def update(self, x=None): # integrate membrane potential V, w = self.integral(self.V.value, self.w.value, t, x, dt) - V += self.sum_delta_inputs() + V += self.sum_delta_inputs() # refractory refractory = (t - self.t_last_spike) <= self.tau_ref @@ -3576,7 +3579,6 @@ class GifRefLTC(GifLTC): %s """ - def __init__( self, size: Shape, @@ -3844,7 +3846,6 @@ class GifRef(GifRefLTC): %s """ - def dV(self, V, t, I1, I2, I): return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau @@ -4495,7 +4496,7 @@ def update(self, x=None): return super().update(x) -Izhikevich.__doc__ = Izhikevich.__doc__ %(pneu_doc, dpneu_doc) -IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ %(pneu_doc, dpneu_doc, ref_doc) -IzhikevichRef.__doc__ = IzhikevichRef.__doc__ %(pneu_doc, dpneu_doc, ref_doc) -IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ %() +Izhikevich.__doc__ = Izhikevich.__doc__ % (pneu_doc, dpneu_doc) +IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +IzhikevichRef.__doc__ = IzhikevichRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc) +IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ % () From 57ec38739e39716da72ab9a3b657fbdf75b232ac Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 28 Dec 2023 21:09:52 +0800 Subject: [PATCH 4/4] [doc] upgrade the documentation of synaptic projections --- brainpy/_src/dyn/projections/align_post.py | 56 +++++++++++++-- brainpy/_src/dyn/projections/align_pre.py | 69 +++++++++++++++++-- brainpy/_src/dyn/projections/delta.py | 63 +++++++++-------- brainpy/_src/dyn/projections/plasticity.py | 2 +- .../_src/dyn/projections/tests/test_delta.py | 4 +- .../dyn/projections/{base.py => utils.py} | 0 docs/apis/brainpy.dyn.projections.rst | 36 ++++++++-- 7 files changed, 183 insertions(+), 47 deletions(-) rename brainpy/_src/dyn/projections/{base.py => utils.py} (100%) diff --git a/brainpy/_src/dyn/projections/align_post.py b/brainpy/_src/dyn/projections/align_post.py index 217045032..b5679dc7d 100644 --- a/brainpy/_src/dyn/projections/align_post.py +++ b/brainpy/_src/dyn/projections/align_post.py @@ -48,7 +48,19 @@ def reset_state(self, *args, **kwargs): class HalfProjAlignPostMg(Projection): - r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + r"""Defining the half part of synaptic projection with the align-post reduction and the automatic synapse merging. + + The ``half-part`` means that the model only needs to provide half information needed for a projection, + including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs + the manual providing of the spiking input. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. **Code Examples** @@ -131,7 +143,22 @@ def update(self, x): class FullProjAlignPostMg(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + """Full-chain synaptic projection with the align-post reduction and the automatic synapse merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. + + Moreover, it's worth noting that ``FullProjAlignPostMg`` has a different updating order with all align-pre + projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. + While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. **Code Examples** @@ -245,7 +272,16 @@ def update(self): class HalfProjAlignPost(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + """Defining the half-part of synaptic projection with the align-post reduction. + + The ``half-part`` means that the model only needs to provide half information needed for a projection, + including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs + the manual providing of the spiking input. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. To simulate an E/I balanced network: @@ -329,7 +365,19 @@ def update(self, x): class FullProjAlignPost(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + """Full-chain synaptic projection with the align-post reduction. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``. + + The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group. + + All align-post projection models prefer to use the event-driven computation mode. This means that the + ``comm`` model should be the event-driven model. + + Moreover, it's worth noting that ``FullProjAlignPost`` has a different updating order with all align-pre + projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``. + While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``. To simulate and define an E/I balanced network model: diff --git a/brainpy/_src/dyn/projections/align_pre.py b/brainpy/_src/dyn/projections/align_pre.py index 2b609322c..356de0a6d 100644 --- a/brainpy/_src/dyn/projections/align_pre.py +++ b/brainpy/_src/dyn/projections/align_pre.py @@ -4,7 +4,7 @@ from brainpy._src.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return) from brainpy._src.dynsys import DynamicalSystem, Projection from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData) -from .base import _get_return +from .utils import _get_return __all__ = [ 'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg', @@ -68,7 +68,22 @@ def reset_state(self, *args, **kwargs): class FullProjAlignPreSDMg(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the + synapse states to the delay model, and finally computes the synaptic current. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg``facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. To simulate an E/I balanced network model: @@ -182,7 +197,24 @@ def update(self, x=None): class FullProjAlignPreDSMg(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``. + Note here, compared to ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` are exchanged. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the + spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. + + The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same + parameters (such like time constants) will also share the same synaptic variables. + + Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + To simulate an E/I balanced network model: @@ -296,7 +328,20 @@ def update(self): class FullProjAlignPreSD(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + """Full-chain synaptic projection with the align-pre reduction and synapse+delay updating. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the + synapse states to the delay model, and finally computes the synaptic current. + + Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS``facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + To simulate an E/I balanced network model: @@ -411,7 +456,21 @@ def update(self, x=None): class FullProjAlignPreDS(Projection): - """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + """Full-chain synaptic projection with the align-pre reduction and delay+synapse updating. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``. + Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged. + + The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group. + + The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the + spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current. + + Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation. + This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather + than the spiking. To facilitate the event-driven computation, please use align post projections. + To simulate an E/I balanced network model: diff --git a/brainpy/_src/dyn/projections/delta.py b/brainpy/_src/dyn/projections/delta.py index 616f83df6..19e4938cb 100644 --- a/brainpy/_src/dyn/projections/delta.py +++ b/brainpy/_src/dyn/projections/delta.py @@ -23,7 +23,13 @@ def __call__(self, *args, **kwargs): class HalfProjDelta(Projection): - """Delta synaptic projection. + """Defining the half-part of the synaptic projection for the Delta synapse model. + + The synaptic projection requires the input is the spiking data, otherwise + the synapse is not the Delta synapse model. + + The ``half-part`` means that the model only includes ``comm`` -> ``syn`` -> ``out`` -> ``post``. + Therefore, the model's ``update`` function needs the manual providing of the spiking input. **Model Descriptions** @@ -103,7 +109,13 @@ def update(self, x): class FullProjDelta(Projection): - """Delta synaptic projection. + """Full-chain of the synaptic projection for the Delta synapse model. + + The synaptic projection requires the input is the spiking data, otherwise + the synapse is not the Delta synapse model. + + The ``full-chain`` means that the model needs to provide all information needed for a projection, + including ``pre`` -> ``delay`` -> ``comm`` -> ``post``. **Model Descriptions** @@ -121,36 +133,31 @@ class FullProjDelta(Projection): **Code Examples** - To simulate an E/I balanced network model: - .. code-block:: - class EINet(bp.DynSysGroup): + import brainpy as bp + import brainpy.math as bm + + + class Net(bp.DynamicalSystem): def __init__(self): super().__init__() - self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=3200, tau=5.) - self.syn2 = bp.dyn.Expon(size=800, tau=10.) - self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:3200])) - self.I(self.syn2(spk[3200:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(1000) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) + + self.pre = bp.dyn.PoissonGroup(10, 100.) + self.post = bp.dyn.LifRef(1) + self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post) + + def update(self): + self.syn() + self.pre() + self.post() + return self.post.V.value + + + net = Net() + indices = bm.arange(1000).to_numpy() + vs = bm.for_loop(net.step_run, indices, progress_bar=True) + bp.visualize.line_plot(indices, vs, show=True) Args: diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 598a7496f..d36074b9c 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -9,7 +9,7 @@ from brainpy.types import ArrayType from .align_post import (align_post_add_bef_update, ) from .align_pre import (align_pre2_add_bef_update, ) -from .base import (_get_return, ) +from .utils import (_get_return, ) __all__ = [ 'STDP_Song2000', diff --git a/brainpy/_src/dyn/projections/tests/test_delta.py b/brainpy/_src/dyn/projections/tests/test_delta.py index 8e16a128a..f4d21b643 100644 --- a/brainpy/_src/dyn/projections/tests/test_delta.py +++ b/brainpy/_src/dyn/projections/tests/test_delta.py @@ -22,7 +22,7 @@ def test1(): net = NetForHalfProj() indices = bm.arange(1000).to_numpy() vs = bm.for_loop(net.step_run, indices, progress_bar=True) - bp.visualize.line_plot(indices, vs, show=True) + bp.visualize.line_plot(indices, vs, show=False) plt.close('all') @@ -45,7 +45,7 @@ def test2(): net = NetForFullProj() indices = bm.arange(1000).to_numpy() vs = bm.for_loop(net.step_run, indices, progress_bar=True) - bp.visualize.line_plot(indices, vs, show=True) + bp.visualize.line_plot(indices, vs, show=False) plt.close('all') diff --git a/brainpy/_src/dyn/projections/base.py b/brainpy/_src/dyn/projections/utils.py similarity index 100% rename from brainpy/_src/dyn/projections/base.py rename to brainpy/_src/dyn/projections/utils.py diff --git a/docs/apis/brainpy.dyn.projections.rst b/docs/apis/brainpy.dyn.projections.rst index 0587dcbb8..5549e6394 100644 --- a/docs/apis/brainpy.dyn.projections.rst +++ b/docs/apis/brainpy.dyn.projections.rst @@ -6,8 +6,8 @@ Synaptic Projections -Reduced Projections -------------------- +Projections for Align-Post Reduction +------------------------------------ .. autosummary:: :toctree: generated/ @@ -18,6 +18,18 @@ Reduced Projections FullProjAlignPostMg HalfProjAlignPost FullProjAlignPost + + + +Projections for Align-Pre Reduction +------------------------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + VanillaProj FullProjAlignPreSDMg FullProjAlignPreDSMg FullProjAlignPreSD @@ -25,8 +37,8 @@ Reduced Projections -Projections ------------ +Projections for Delta synapses +------------------------------ .. autosummary:: :toctree: generated/ @@ -35,8 +47,6 @@ Projections HalfProjDelta FullProjDelta - VanillaProj - SynConn @@ -48,6 +58,18 @@ Inputs :nosignatures: :template: classtemplate.rst - PoissonInput InputVar + + + +Others +------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + SynConn +