Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainpy/_src/dnn/interoperation_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def initialize_carry(self, rng, batch_dims, size=None, init_fn=None):
raise NotImplementedError

_state_vars = self.model.vars().unique().not_subset(bm.TrainVar)
self.model.reset_state(batch_size=batch_dims)
self.model.reset(batch_size=batch_dims)
return [_state_vars.dict(), 0, 0.]

def setup(self):
Expand Down
17 changes: 17 additions & 0 deletions brainpy/_src/dyn/projections/align_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def update(self, x):
self.refs['syn'].add_current(current) # synapse post current
return current

syn = property(lambda self: self.refs['syn'])
out = property(lambda self: self.refs['out'])
post = property(lambda self: self.refs['post'])


class FullProjAlignPostMg(Projection):
"""Full-chain synaptic projection with the align-post reduction and the automatic synapse merging.
Expand Down Expand Up @@ -270,6 +274,12 @@ def update(self):
self.refs['syn'].add_current(current) # synapse post current
return current

syn = property(lambda self: self.refs['syn'])
out = property(lambda self: self.refs['out'])
delay = property(lambda self: self.refs['delay'])
pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])


class HalfProjAlignPost(Projection):
"""Defining the half-part of synaptic projection with the align-post reduction.
Expand Down Expand Up @@ -363,6 +373,8 @@ def update(self, x):
self.refs['out'].bind_cond(g) # synapse post current
return current

post = property(lambda self: self.refs['post'])


class FullProjAlignPost(Projection):
"""Full-chain synaptic projection with the align-post reduction.
Expand Down Expand Up @@ -488,3 +500,8 @@ def update(self):
g = self.syn(self.comm(x))
self.refs['out'].bind_cond(g) # synapse post current
return g

delay = property(lambda self: self.refs['delay'])
pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
out = property(lambda self: self.refs['out'])
23 changes: 23 additions & 0 deletions brainpy/_src/dyn/projections/align_pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,12 @@ def update(self, x=None):
self.refs['out'].bind_cond(current)
return current

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])


class FullProjAlignPreDSMg(Projection):
"""Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging.
Expand Down Expand Up @@ -326,6 +332,11 @@ def update(self):
self.refs['out'].bind_cond(current)
return current

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
out = property(lambda self: self.refs['out'])


class FullProjAlignPreSD(Projection):
"""Full-chain synaptic projection with the align-pre reduction and synapse+delay updating.
Expand Down Expand Up @@ -454,6 +465,12 @@ def update(self, x=None):
self.refs['out'].bind_cond(current)
return current

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])


class FullProjAlignPreDS(Projection):
"""Full-chain synaptic projection with the align-pre reduction and delay+synapse updating.
Expand Down Expand Up @@ -581,3 +598,9 @@ def update(self):
g = self.comm(self.syn(spk))
self.refs['out'].bind_cond(g)
return g

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])

6 changes: 6 additions & 0 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,12 @@ def __init__(
self.A1 = A1
self.A2 = A2

pre = property(lambda self: self.refs['pre'])
post = property(lambda self: self.refs['post'])
syn = property(lambda self: self.refs['syn'])
delay = property(lambda self: self.refs['delay'])
out = property(lambda self: self.refs['out'])

def update(self):
# pre-synaptic spikes
pre_spike = self.refs['delay'].at(self.name) # spike
Expand Down
Loading