Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions brainpy/_src/math/delayvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from brainpy import check
from brainpy.check import is_float, is_integer, jit_error
from brainpy.errors import UnsupportedError
from .compat_numpy import vstack, broadcast_to
from .compat_numpy import broadcast_to, expand_dims, concatenate
from .environment import get_dt, get_float
from .interoperability import as_jax
from .ndarray import ndarray, Array
Expand Down Expand Up @@ -392,6 +392,7 @@ def reset(
dtype=delay_target.dtype),
batch_axis=batch_axis)
else:
self.data.value
self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
dtype=delay_target.dtype)

Expand Down Expand Up @@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None):

elif self.update_method == CONCAT_UPDATE:
if self.num_delay_step >= 2:
self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]])
self.data.value = concatenate([expand_dims(value, 0), self.data[1:]], axis=0)
else:
self.data[:] = value

Expand Down