diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2f94df77a..84aa028e3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -156,7 +156,7 @@ jobs: # strategy: # fail-fast: false # matrix: -# python-version: ["3.8", "3.9", "3.10", "3.11"] +# python-version: ["3.9", "3.10", "3.11"] # # steps: # - uses: actions/checkout@v4 @@ -168,11 +168,7 @@ jobs: # run: | # python -m pip install --upgrade pip # python -m pip install flake8 pytest -# python -m pip install numpy>=1.21.0 -# python -m pip install "jaxlib==0.4.11" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver -# python -m pip install jax==0.4.11 # python -m pip install -r requirements-dev.txt -# python -m pip install tqdm brainpylib # pip uninstall brainpy -y # python setup.py install # - name: Lint with flake8 diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index b635d21f1..539214d3b 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -1,1275 +1,1423 @@ -# -*- coding: utf-8 -*- - - -import numbers -from typing import Dict, Optional, Union, Callable - -import jax -import jax.numpy as jnp -import numba -import numpy as np - -from brainpy import math as bm -from brainpy._src import connect, initialize as init -from brainpy._src.context import share -from brainpy._src.dnn.base import Layer -from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy.check import is_initializer -from brainpy.connect import csr2csc -from brainpy.errors import MathError -from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter -from brainpy.types import ArrayType, Sharding - -__all__ = [ - 'Dense', 'Linear', - 'Identity', - 'AllToAll', - 'OneToOne', - 'MaskedLinear', - 'CSRLinear', 'EventCSRLinear', - 'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear', - 'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear', -] - - -class Dense(Layer, 
SupportSTDP, SupportOnline, SupportOffline): - r"""A linear transformation applied over the last dimension of the input. - - Mathematically, this node can be defined as: - - .. math:: - - y = x \cdot weight + b - - Parameters - ---------- - num_in: int - The number of the input feature. A positive integer. - num_out: int - The number of the output features. A positive integer. - W_initializer: optional, Initializer - The weight initialization. - b_initializer: optional, Initializer - The bias initialization. - mode: Mode - Enable training this node or not. (default True) - """ - - def __init__( - self, - num_in: int, - num_out: int, - W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(), - b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(), - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super(Dense, self).__init__(mode=mode, name=name) - - # shape - self.num_in = num_in - self.num_out = num_out - if num_in < 0: - raise ValueError(f'Received an invalid value for `num_out`, expected ' - f'a positive integer. Received: num_in={num_in}') - if num_out < 0: - raise ValueError(f'Received an invalid value for `num_out`, expected ' - f'a positive integer. 
Received: num_out={num_out}') - - # weight initializer - self.W_initializer = W_initializer - self.bias_initializer = b_initializer - is_initializer(W_initializer, 'weight_initializer') - is_initializer(b_initializer, 'bias_initializer', allow_none=True) - - # parameter initialization - W = parameter(self.W_initializer, (num_in, self.num_out)) - b = parameter(self.bias_initializer, (self.num_out,)) - if isinstance(self.mode, bm.TrainingMode): - W = bm.TrainVar(W) - b = None if (b is None) else bm.TrainVar(b) - self.W = W - self.b = b - - # fitting parameters - self.online_fit_by = None # support online training - self.offline_fit_by = None # support offline training - self.fit_record = dict() - - def __repr__(self): - return (f'{self.__class__.__name__}(name={self.name}, ' - f'num_in={self.num_in}, ' - f'num_out={self.num_out}, ' - f'mode={self.mode})') - - def update(self, x): - x = bm.as_jax(x) - res = x @ self.W - if self.b is not None: - res += self.b - - # online fitting data - if share.load('fit', False) and self.online_fit_by is not None: - self.fit_record['input'] = x - self.fit_record['output'] = res - - # offline fitting data - if share.load('fit', False) and self.offline_fit_by is not None: - self.fit_record['input'] = x - self.fit_record['output'] = res - return res - - def online_init(self): - if self.b is None: - num_input = self.num_in - else: - num_input = self.num_in + 1 - self.online_fit_by.register_target(feature_in=num_input, identifier=self.name) - - def online_fit(self, - target: ArrayType, - fit_record: Dict[str, ArrayType]): - if not isinstance(target, (bm.ndarray, jnp.ndarray)): - raise MathError(f'"target" must be a tensor, but got {type(target)}') - x = fit_record['input'] - y = fit_record['output'] - if x.ndim != 2: - raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, ' - f'num_feature), but we got {x.shape}') - if target.ndim != 2: - raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, ' - 
f'num_feature), but we got {target.shape}') - if x.shape[0] != target.shape[0]: - raise ValueError(f'Batch size of the input and target data should be ' - f'the same, while we got {x.shape[0]} != {target.shape[0]}.') - if target.shape[1] != y.shape[1]: - raise MathError(f'The output dimension of output and target data should be ' - f'the same, while we got {target.shape[1]} != {y.shape[1]}') - - # data - if self.b is not None: - x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1) - - # fitting - dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name) - - # assign trained weights - if self.b is None: - self.W += dW - else: - db, dW = jnp.split(dW, [1]) - self.b += db[0] - self.W += dW - - def offline_fit(self, - target: ArrayType, - fit_record: Dict[str, ArrayType]): - """The offline training interface for the Dense node.""" - # data checking - if not isinstance(target, (bm.ndarray, jnp.ndarray)): - raise MathError(f'"targets" must be a tensor, but got {type(target)}') - xs = fit_record['input'] - ys = fit_record['output'] - if xs.ndim != 3: - raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, ' - f'num_feature), but we got {xs.shape}') - if target.ndim != 3: - raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, ' - f'num_feature), but we got {target.shape}') - if ys.shape != target.shape: - raise ValueError(f'The shapes of output and target data should be ' - f'the same, while we got {ys.shape} != {target.shape}.') - if xs.shape[0] != target.shape[0]: - raise ValueError(f'Batch size of the input and target data should be ' - f'the same, while we got {xs.shape[0]} != {target.shape[0]}.') - if xs.shape[1] != target.shape[1]: - raise MathError(f'The time dimension of input and target data should be ' - f'the same, while we got {xs.shape[1]} != {target.shape[1]}') - - # get input and target training data - if self.b is not None: - xs = 
jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input) - - # solve weights by offline training methods - weights = self.offline_fit_by(target, xs, ys) - - # assign trained weights - if self.b is None: - self.W.value = weights - else: - bias, Wff = jnp.split(weights, [1]) - self.W.value = Wff - self.b.value = bias[0] - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.W, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.W, bm.Variable): - self.tracing_variable('W', self.W, self.W.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max) - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max) - - -Linear = Dense - - -class Identity(Layer): - r"""A placeholder identity operator that is argument-insensitive. 
- """ - - def __init__(self, *args, **kwargs) -> None: - super(Identity, self).__init__(*args, **kwargs) - - def update(self, x): - return x - - -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): - out_w[:] = weight - for i in numba.prange(spike.shape[0]): - if spike[i]: - out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) - - -dense_on_pre_prim = bm.XLACustomOp(_cpu_dense_on_pre) - - -def dense_on_pre(weight, spike, trace, w_min, w_max): - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - return dense_on_pre_prim(weight, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] - - -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): - out_w[:] = weight - for i in numba.prange(spike.shape[0]): - if spike[i]: - out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) - - -dense_on_post_prim = bm.XLACustomOp(_cpu_dense_on_post) - - -def dense_on_post(weight, spike, trace, w_min, w_max): - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - return dense_on_post_prim(weight, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] - - -class AllToAll(Layer, SupportSTDP): - """Synaptic matrix multiplication with All2All connections. - - Args: - num_pre: int. The number of neurons in the presynaptic neuron group. - num_post: int. The number of neurons in the postsynaptic neuron group. - weight: The synaptic weights. - sharding: The sharding strategy. - include_self: bool. Whether connect the neuron with at the same position. - mode: Mode. The computing mode. - name: str. The object name. 
- """ - - def __init__( - self, - num_pre: int, - num_post: int, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - include_self: bool = True, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(mode=mode, name=name) - - self.num_pre = num_pre - self.num_post = num_post - self.include_self = include_self - self.sharding = sharding - - weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, pre_val): - if bm.ndim(self.weight) == 0: # weight is a scalar - if isinstance(self.mode, bm.BatchingMode): - assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.' - post_val = bm.sum(pre_val, keepdims=True, axis=1) - else: - assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.' - post_val = bm.sum(pre_val) - if not self.include_self: - if self.num_pre == self.num_post: - post_val = post_val - pre_val - elif self.num_pre > self.num_post: - val = pre_val[:self.num_post] - post_val = post_val - val - else: - val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)]) - post_val = post_val - val - post_val = self.weight * post_val - - else: # weight is a matrix - assert self.weight.ndim == 2, '"weight" must be a 2D matrix.' 
- if not self.include_self: - post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False) - else: - post_val = pre_val @ self.weight - return post_val - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) - - -class OneToOne(Layer, SupportSTDP): - """Synaptic matrix multiplication with One2One connection. - - Args: - num: int. The number of neurons. - weight: The synaptic weight. - sharding: The sharding strategy. - mode: The computing mode. - name: The object name. 
- - """ - - def __init__( - self, - num: int, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(mode=mode, name=name) - - self.num = num - self.sharding = sharding - - weight = init.parameter(weight, (self.num,), sharding=sharding) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, pre_val): - return pre_val * self.weight - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value += spike * trace - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value += spike * trace - - -class MaskedLinear(Layer, SupportSTDP): - r"""Synaptic matrix multiplication with masked dense computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a dense matrix. - - >>> import brainpy as bp - >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100), - >>> weight=0.1) - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - mask_fun: Masking function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - mask_fun: Callable = Identity(), - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding - self.mask_fun = mask_fun - - # weight - weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - # connection - self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding) - - def update(self, x): - return x @ self.mask_fun(self.weight * self.mask) - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) - if on_post is not None: - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) - - -class _CSRLayer(Layer, SupportSTDP): - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - assert sharding is None, 'Currently this model does not support sharding.' 
- self.conn = conn - self.sharding = sharding - self.transpose = transpose - - # connection - self.indices, self.indptr = self.conn.require('csr') - - # weight - weight = init.parameter(weight, (self.indices.size,)) - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def stdp_update( - self, - on_pre: Dict = None, - on_post: Dict = None, - w_min: numbers.Number = None, - w_max: numbers.Number = None - ): - if bm.isscalar(self.weight): - raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') - if self.weight.shape != self.indices.shape: - raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: # update on presynaptic spike - spike = on_pre['spike'] - trace = on_pre['trace'] - self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) - if on_post is not None: # update on postsynaptic spike - if not hasattr(self, '_pre_ids'): - with jax.ensure_compile_time_eval(): - self._pre_ids, self._post_indptr, self.w_indices = csr2csc( - [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) - ) - spike = on_post['spike'] - trace = on_post['trace'] - self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr, - self.w_indices, spike, trace, w_min, w_max) - - -class CSRLinear(_CSRLayer): - r"""Synaptic matrix multiplication with CSR sparse computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a CSR sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. 
- sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - method: str = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - self.method = method - - def update(self, x): - if x.ndim == 1: - return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) - elif x.ndim > 1: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_csrmv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_csrmv(self, x): - return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) - -class EventCSRLinear(_CSRLayer): - r"""Synaptic matrix multiplication with event CSR sparse computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weight using a CSR sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = True, - ): - super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) - - def update(self, x): - if x.ndim == 1: - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - elif x.ndim > 1: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_csrmv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_csrmv(self, x): - return bm.event.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose) - -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # pre id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): # synapse id - j = indices[k] # post id - # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) - out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - -csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update) - - -def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] - -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # post id - if 
spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = post_ids[k] # pre id - l = w_ids[k] # syn id - out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) - - -csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) - - -def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): - if w_min is None: - w_min = -np.inf - if w_max is None: - w_max = np.inf - return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] - - - -class CSCLinear(Layer): - r"""Synaptic matrix multiplication with CSC sparse computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a CSC sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding - - -class BcsrMM(Layer): - r"""Synaptic matrix multiplication with BCSR sparse computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a BCSR sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding - - -class BcscMM(Layer): - r"""Synaptic matrix multiplication with BCSC sparse computation. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, - :math:`M` the synaptic weight using a BCSC sparse matrix. - - Args: - conn: TwoEndConnector. The connection. - weight: Synaptic weights. Can be a scalar, array, or callable function. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - conn: connect.TwoEndConnector, - weight: Union[float, ArrayType, Callable], - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - assert isinstance(conn, connect.TwoEndConnector) - self.conn = conn - self.sharding = sharding - - -class JitFPHomoLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is the same :math:`weight`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. 
- seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - weight: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = False, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class JitFPUniformLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. 
math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_low: float, - w_high: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = False, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_low = w_low - self.w_high = w_high - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class JitFPNormalLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. 
- prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_mu: float, - w_sigma: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - transpose: bool = False, - atomic: bool = False, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_mu = w_mu - self.w_sigma = w_sigma - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class EventJitFPHomoLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. 
math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is the same :math:`weight`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - weight: float. The synaptic value at each position. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - weight: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = True, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 1000000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - if isinstance(self.mode, bm.TrainingMode): - weight = bm.TrainVar(weight) - self.weight = weight - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = 
jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class EventJitFPUniformLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. A positive integer. - prob: float. The connectivity probability. - w_low: float. The lowest value of the uniform distribution. - w_high: float. The highest value of the uniform distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. 
- """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_low: float, - w_high: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - transpose: bool = False, - atomic: bool = True, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_low = w_low - self.w_high = w_high - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - - -class EventJitFPNormalLinear(Layer): - r"""Synaptic matrix multiplication with the just-in-time connectivity. - - It performs the computation of: - - .. math:: - - y = x @ M - - where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, - :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. - Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, - and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. - - Args: - num_in: int. The number of the input feature. A positive integer. - num_out: int. The number of the input feature. 
A positive integer. - prob: float. The connectivity probability. - w_mu: float. The center of the normal distribution. - w_sigma: float. The standard variance of the normal distribution. - seed: int. The random seed used to keep the reproducibility of the connectivity. - transpose: bool. Transpose the JIT matrix or not. Default False. - atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. - May be changed in the future. - sharding: The sharding strategy. - mode: The synaptic computing mode. - name: The synapse model name. - """ - - def __init__( - self, - num_in: int, - num_out: int, - prob: float, - w_mu: float, - w_sigma: float, - seed: Optional[int] = None, - sharding: Optional[Sharding] = None, - transpose: bool = False, - atomic: bool = True, - mode: Optional[bm.Mode] = None, - name: Optional[str] = None, - ): - super().__init__(name=name, mode=mode) - - self.prob = prob - self.sharding = sharding - self.transpose = transpose - self.seed = np.random.randint(0, 100000) if seed is None else seed - self.atomic = atomic - self.num_in = num_in - self.num_out = num_out - - # weight - self.w_mu = w_mu - self.w_sigma = w_sigma - - def update(self, x): - if x.ndim == 1: - return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) - elif x.ndim == 2: - return jax.vmap(self._batch_mv)(x) - elif x.ndim > 2: - shapes = x.shape[:-1] - x = bm.flatten(x, end_dim=-2) - y = jax.vmap(self._batch_mv)(x) - return bm.reshape(y, shapes + (y.shape[-1],)) - else: - raise ValueError - - def _batch_mv(self, x): - return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, - shape=(self.num_out, self.num_in), - transpose=self.transpose, - outdim_parallel=not self.atomic) +# -*- coding: utf-8 -*- + + +import numbers +from typing import Dict, Optional, Union, Callable + +import jax +import jax.numpy 
as jnp
+import numba
+import numpy as np
+
+from brainpy import math as bm
+from brainpy._src import connect, initialize as init
+from brainpy._src.context import share
+from brainpy._src.dnn.base import Layer
+from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
+from brainpy._src.dependency_check import import_taichi
+from brainpy.check import is_initializer
+from brainpy.connect import csr2csc
+from brainpy.errors import MathError
+from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
+from brainpy.types import ArrayType, Sharding
+
+ti = import_taichi()
+
+__all__ = [
+  'Dense', 'Linear',
+  'Identity',
+  'AllToAll',
+  'OneToOne',
+  'MaskedLinear',
+  'CSRLinear', 'EventCSRLinear',
+  'JitFPHomoLinear', 'JitFPUniformLinear', 'JitFPNormalLinear',
+  'EventJitFPHomoLinear', 'EventJitFPNormalLinear', 'EventJitFPUniformLinear',
+]
+
+
+class Dense(Layer, SupportSTDP, SupportOnline, SupportOffline):
+  r"""A linear transformation applied over the last dimension of the input.
+
+  Mathematically, this node can be defined as:
+
+  .. math::
+
+     y = x \cdot weight + b
+
+  Parameters
+  ----------
+  num_in: int
+    The number of the input feature. A positive integer.
+  num_out: int
+    The number of the output features. A positive integer.
+  W_initializer: optional, Initializer
+    The weight initialization.
+  b_initializer: optional, Initializer
+    The bias initialization.
+  mode: Mode
+    Enable training this node or not. (default True)
+  """
+
+  def __init__(
+      self,
+      num_in: int,
+      num_out: int,
+      W_initializer: Union[Initializer, Callable, ArrayType] = XavierNormal(),
+      b_initializer: Optional[Union[Initializer, Callable, ArrayType]] = ZeroInit(),
+      mode: Optional[bm.Mode] = None,
+      name: Optional[str] = None,
+  ):
+    super(Dense, self).__init__(mode=mode, name=name)
+
+    # shape
+    self.num_in = num_in
+    self.num_out = num_out
+    if num_in < 0:
+      raise ValueError(f'Received an invalid value for `num_in`, expected '
+                       f'a positive integer. 
Received: num_in={num_in}') + if num_out < 0: + raise ValueError(f'Received an invalid value for `num_out`, expected ' + f'a positive integer. Received: num_out={num_out}') + + # weight initializer + self.W_initializer = W_initializer + self.bias_initializer = b_initializer + is_initializer(W_initializer, 'weight_initializer') + is_initializer(b_initializer, 'bias_initializer', allow_none=True) + + # parameter initialization + W = parameter(self.W_initializer, (num_in, self.num_out)) + b = parameter(self.bias_initializer, (self.num_out,)) + if isinstance(self.mode, bm.TrainingMode): + W = bm.TrainVar(W) + b = None if (b is None) else bm.TrainVar(b) + self.W = W + self.b = b + + # fitting parameters + self.online_fit_by = None # support online training + self.offline_fit_by = None # support offline training + self.fit_record = dict() + + def __repr__(self): + return (f'{self.__class__.__name__}(name={self.name}, ' + f'num_in={self.num_in}, ' + f'num_out={self.num_out}, ' + f'mode={self.mode})') + + def update(self, x): + x = bm.as_jax(x) + res = x @ self.W + if self.b is not None: + res += self.b + + # online fitting data + if share.load('fit', False) and self.online_fit_by is not None: + self.fit_record['input'] = x + self.fit_record['output'] = res + + # offline fitting data + if share.load('fit', False) and self.offline_fit_by is not None: + self.fit_record['input'] = x + self.fit_record['output'] = res + return res + + def online_init(self): + if self.b is None: + num_input = self.num_in + else: + num_input = self.num_in + 1 + self.online_fit_by.register_target(feature_in=num_input, identifier=self.name) + + def online_fit(self, + target: ArrayType, + fit_record: Dict[str, ArrayType]): + if not isinstance(target, (bm.ndarray, jnp.ndarray)): + raise MathError(f'"target" must be a tensor, but got {type(target)}') + x = fit_record['input'] + y = fit_record['output'] + if x.ndim != 2: + raise ValueError(f'"ff" must be a 2D tensor with shape of (num_sample, ' + 
f'num_feature), but we got {x.shape}') + if target.ndim != 2: + raise ValueError(f'"target" must be a 2D tensor with shape of (num_sample, ' + f'num_feature), but we got {target.shape}') + if x.shape[0] != target.shape[0]: + raise ValueError(f'Batch size of the input and target data should be ' + f'the same, while we got {x.shape[0]} != {target.shape[0]}.') + if target.shape[1] != y.shape[1]: + raise MathError(f'The output dimension of output and target data should be ' + f'the same, while we got {target.shape[1]} != {y.shape[1]}') + + # data + if self.b is not None: + x = jnp.concatenate([jnp.ones((x.shape[0], 1)), x], axis=-1) + + # fitting + dW = self.online_fit_by.call(target=target, input=x, output=y, identifier=self.name) + + # assign trained weights + if self.b is None: + self.W += dW + else: + db, dW = jnp.split(dW, [1]) + self.b += db[0] + self.W += dW + + def offline_fit(self, + target: ArrayType, + fit_record: Dict[str, ArrayType]): + """The offline training interface for the Dense node.""" + # data checking + if not isinstance(target, (bm.ndarray, jnp.ndarray)): + raise MathError(f'"targets" must be a tensor, but got {type(target)}') + xs = fit_record['input'] + ys = fit_record['output'] + if xs.ndim != 3: + raise ValueError(f'"ffs" must be a 3D tensor with shape of (num_sample, num_time, ' + f'num_feature), but we got {xs.shape}') + if target.ndim != 3: + raise ValueError(f'"targets" must be a 3D tensor with shape of (num_sample, num_time, ' + f'num_feature), but we got {target.shape}') + if ys.shape != target.shape: + raise ValueError(f'The shapes of output and target data should be ' + f'the same, while we got {ys.shape} != {target.shape}.') + if xs.shape[0] != target.shape[0]: + raise ValueError(f'Batch size of the input and target data should be ' + f'the same, while we got {xs.shape[0]} != {target.shape[0]}.') + if xs.shape[1] != target.shape[1]: + raise MathError(f'The time dimension of input and target data should be ' + f'the same, while we got 
{xs.shape[1]} != {target.shape[1]}') + + # get input and target training data + if self.b is not None: + xs = jnp.concatenate([jnp.ones(xs.shape[:2] + (1,)), xs], axis=-1) # (..., 1 + num_ff_input) + + # solve weights by offline training methods + weights = self.offline_fit_by(target, xs, ys) + + # assign trained weights + if self.b is None: + self.W.value = weights + else: + bias, Wff = jnp.split(weights, [1]) + self.W.value = Wff + self.b.value = bias[0] + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.W, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.W, bm.Variable): + self.tracing_variable('W', self.W, self.W.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.W.value = dense_on_pre(self.W.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.W.value = dense_on_post(self.W.value, spike, trace, w_min, w_max) + + +Linear = Dense + + +class Identity(Layer): + r"""A placeholder identity operator that is argument-insensitive. 
+ """ + + def __init__(self, *args, **kwargs) -> None: + super(Identity, self).__init__(*args, **kwargs) + + def update(self, x): + return x + + +# @numba.njit(nogil=True, fastmath=True, parallel=False) +# def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): +# out_w[:] = weight +# for i in numba.prange(spike.shape[0]): +# if spike[i]: +# out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) + +@ti.kernel +def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[1]): + new_value = out_w[i, j] + trace0 + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + + +@ti.kernel +def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[1]): + new_value = out_w[i, j] + trace0 + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + + +dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, + gpu_kernel=_gpu_dense_on_pre) + + +def dense_on_pre(weight, spike, trace, w_min, w_max): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + trace = jnp.atleast_1d(trace) + w_min = jnp.atleast_1d(w_min) + 
w_max = jnp.atleast_1d(w_max) + return dense_on_pre_prim(weight, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] + + +# @numba.njit(nogil=True, fastmath=True, parallel=False) +# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): +# out_w[:] = weight +# for i in numba.prange(spike.shape[0]): +# if spike[i]: +# out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) + +@ti.kernel +def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[0]): + new_value = out_w[j, i] + trace0 + if new_value < w_min0: + out_w[j, i] = w_min0 + elif new_value > w_max0: + out_w[j, i] = w_max0 + else: + out_w[j, i] = new_value + +@ti.kernel +def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[0]): + new_value = out_w[j, i] + trace0 + if new_value < w_min0: + out_w[j, i] = w_min0 + elif new_value > w_max0: + out_w[j, i] = w_max0 + else: + out_w[j, i] = new_value + +dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, + gpu_kernel=_gpu_dense_on_post) + + +def dense_on_post(weight, spike, trace, w_min, w_max): + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + trace = jnp.atleast_1d(trace) + w_min = 
jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return dense_on_post_prim(weight, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] + + +class AllToAll(Layer, SupportSTDP): + """Synaptic matrix multiplication with All2All connections. + + Args: + num_pre: int. The number of neurons in the presynaptic neuron group. + num_post: int. The number of neurons in the postsynaptic neuron group. + weight: The synaptic weights. + sharding: The sharding strategy. + include_self: bool. Whether connect the neuron with at the same position. + mode: Mode. The computing mode. + name: str. The object name. + """ + + def __init__( + self, + num_pre: int, + num_post: int, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + include_self: bool = True, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(mode=mode, name=name) + + self.num_pre = num_pre + self.num_post = num_post + self.include_self = include_self + self.sharding = sharding + + weight = init.parameter(weight, (self.num_pre, self.num_post), sharding=sharding) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, pre_val): + if bm.ndim(self.weight) == 0: # weight is a scalar + if isinstance(self.mode, bm.BatchingMode): + assert pre_val.ndim == 2, 'Under the batching mode, the input should be a 2D array.' + post_val = bm.sum(pre_val, keepdims=True, axis=1) + else: + assert pre_val.ndim == 1, 'Under the NonBatching mode, the input should be a 1D array.' 
+ post_val = bm.sum(pre_val) + if not self.include_self: + if self.num_pre == self.num_post: + post_val = post_val - pre_val + elif self.num_pre > self.num_post: + val = pre_val[:self.num_post] + post_val = post_val - val + else: + val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)]) + post_val = post_val - val + post_val = self.weight * post_val + + else: # weight is a matrix + assert self.weight.ndim == 2, '"weight" must be a 2D matrix.' + if not self.include_self: + post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False) + else: + post_val = pre_val @ self.weight + return post_val + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) + + +class OneToOne(Layer, SupportSTDP): + """Synaptic matrix multiplication with One2One connection. + + Args: + num: int. The number of neurons. + weight: The synaptic weight. + sharding: The sharding strategy. + mode: The computing mode. + name: The object name. 
+ + """ + + def __init__( + self, + num: int, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(mode=mode, name=name) + + self.num = num + self.sharding = sharding + + weight = init.parameter(weight, (self.num,), sharding=sharding) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, pre_val): + return pre_val * self.weight + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value += spike * trace + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value += spike * trace + + +class MaskedLinear(Layer, SupportSTDP): + r"""Synaptic matrix multiplication with masked dense computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a dense matrix. + + >>> import brainpy as bp + >>> l = bp.dnn.MaskedLinear(bp.conn.FixedProb(0.1, pre=100, post=100), + >>> weight=0.1) + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + mask_fun: Masking function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + mask_fun: Callable = Identity(), + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding + self.mask_fun = mask_fun + + # weight + weight = init.parameter(weight, (conn.pre_num, conn.post_num), sharding=sharding) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + # connection + self.mask = bm.sharding.partition(self.conn.require('conn_mat'), sharding=sharding) + + def update(self, x): + return x @ self.mask_fun(self.weight * self.mask) + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if isinstance(self.weight, float): + raise ValueError(f'Cannot update the weight of a constant node.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = dense_on_pre(self.weight.value, spike, trace, w_min, w_max) + if on_post is not None: + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = dense_on_post(self.weight.value, spike, trace, w_min, w_max) + + +class _CSRLayer(Layer, SupportSTDP): + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + assert sharding is None, 'Currently this model does not support sharding.' 
+ self.conn = conn + self.sharding = sharding + self.transpose = transpose + + # connection + self.indices, self.indptr = self.conn.require('csr') + + # weight + weight = init.parameter(weight, (self.indices.size,)) + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def stdp_update( + self, + on_pre: Dict = None, + on_post: Dict = None, + w_min: numbers.Number = None, + w_max: numbers.Number = None + ): + if bm.isscalar(self.weight): + raise ValueError(f'When using STDP to update synaptic weights, the weight cannot be a scalar.') + if self.weight.shape != self.indices.shape: + raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + if on_pre is not None: # update on presynaptic spike + spike = on_pre['spike'] + trace = on_pre['trace'] + self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) + if on_post is not None: # update on postsynaptic spike + if not hasattr(self, '_pre_ids'): + with jax.ensure_compile_time_eval(): + self._pre_ids, self._post_indptr, self.w_indices = csr2csc( + [self.indices, self.indptr], self.conn.post_num, data=np.arange(self.weight.size) + ) + spike = on_post['spike'] + trace = on_post['trace'] + self.weight.value = csc_on_post_update(self.weight.value, self._pre_ids, self._post_indptr, + self.w_indices, spike, trace, w_min, w_max) + + +class CSRLinear(_CSRLayer): + r"""Synaptic matrix multiplication with CSR sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. 
+ sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + method: str = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + self.method = method + + def update(self, x): + if x.ndim == 1: + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + method=self.method, transpose=self.transpose) + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + method=self.method, transpose=self.transpose) + +class EventCSRLinear(_CSRLayer): + r"""Synaptic matrix multiplication with event CSR sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = True, + ): + super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) + + def update(self, x): + if x.ndim == 1: + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) + elif x.ndim > 1: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_csrmv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_csrmv(self, x): + return bm.event.csrmv(self.weight, self.indices, self.indptr, x, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=self.transpose) + +# @numba.njit(nogil=True, fastmath=True, parallel=False) +# def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): +# out_w[:] = w +# w_min = w_min[()] +# w_max = w_max[()] +# for i in numba.prange(spike.shape[0]): # pre id +# if spike[i]: +# for k in range(indptr[i], indptr[i + 1]): # synapse id +# j = indices[k] # post id +# # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) +# out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) + + +@ti.kernel +def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i in range(out_w.shape[0]): + out_w[i] = w[i] + for i in range(spike.shape[0]): + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): + j = indices[k] + out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) 
+@ti.kernel
+def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1),
+                           indices: ti.types.ndarray(ndim=1),
+                           indptr: ti.types.ndarray(ndim=1),
+                           spike: ti.types.ndarray(ndim=1),
+                           trace: ti.types.ndarray(ndim=1),
+                           w_min: ti.types.ndarray(ndim=1),
+                           w_max: ti.types.ndarray(ndim=1),
+                           out_w: ti.types.ndarray(ndim=1)):
+  trace0 = trace[0]
+  w_min0 = w_min[0]
+  w_max0 = w_max[0]
+  for i in range(out_w.shape[0]):
+    out_w[i] = w[i]
+  for i in range(spike.shape[0]):
+    if spike[i]:
+      for k in range(indptr[i], indptr[i + 1]):
+        j = indices[k]
+        out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0)
+
+
+csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update,
+                                        gpu_kernel=_gpu_csr_on_pre_update)
+
+
+def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None):
+  if w_min is None:
+    w_min = -np.inf
+  if w_max is None:
+    w_max = np.inf
+  trace = jnp.atleast_1d(trace)
+  w_min = jnp.atleast_1d(w_min)
+  w_max = jnp.atleast_1d(w_max)
+  return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max,
+                                outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+
+@numba.njit(nogil=True, fastmath=True, parallel=False)
+def _cpu_csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w):
+  out_w[:] = w
+  w_min = w_min[()]
+  w_max = w_max[()]
+  for i in numba.prange(spike.shape[0]):  # post id
+    if spike[i]:
+      for k in range(indptr[i], indptr[i + 1]):
+        j = post_ids[k]  # pre id
+        l = w_ids[k]  # syn id
+        out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max)
+
+
+csc_on_post_update_prim = bm.XLACustomOp(_cpu_csc_on_post_update)
+
+
+def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None):
+  if w_min is None:
+    w_min = -np.inf
+  if w_max is None:
+    w_max = np.inf
+  return csc_on_post_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max,
+                                 outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0]
+
+
+
+class CSCLinear(Layer):
+  r"""Synaptic matrix multiplication with CSC 
sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSC sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding + + +class BcsrMM(Layer): + r"""Synaptic matrix multiplication with BCSR sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a BCSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding + + +class BcscMM(Layer): + r"""Synaptic matrix multiplication with BCSC sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a BCSC sparse matrix. 
+ + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.sharding = sharding + + +class JitFPHomoLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPUniformLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a uniform distribution :math:`U(w_{low}, w_{high})`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. 
A positive integer. + prob: float. The connectivity probability. + w_low: float. The lowest value of the uniform distribution. + w_high: float. The highest value of the uniform distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = False, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class JitFPNormalLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. 
+
+ It performs the computation of:
+
+ .. math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic variable,
+ :math:`M` the synaptic weights which have the fixed sparse connectivity and weights.
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+ and at each connection, the synaptic value is sampled from a normal distribution :math:`N(\mu, \sigma)`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+ num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ w_mu: float. The center of the normal distribution.
+ w_sigma: float. The standard deviation of the normal distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default False.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPHomoLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is the same :math:`weight`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. A positive integer. + prob: float. 
The connectivity probability. + weight: float. The synaptic value at each position. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + weight: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = True, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 1000000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + if isinstance(self.mode, bm.TrainingMode): + weight = bm.TrainVar(weight) + self.weight = weight + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_homo(x, self.weight, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPUniformLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. 
 math::
+
+ y = x @ M
+
+ where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes,
+ :math:`M` the synaptic weights which have the fixed sparse connectivity and weights.
+ Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`,
+ and at each connection, the synaptic value is sampled from a uniform distribution :math:`U(w_{low}, w_{high})`.
+
+ Args:
+ num_in: int. The number of the input feature. A positive integer.
+ num_out: int. The number of the output feature. A positive integer.
+ prob: float. The connectivity probability.
+ w_low: float. The lowest value of the uniform distribution.
+ w_high: float. The highest value of the uniform distribution.
+ seed: int. The random seed used to keep the reproducibility of the connectivity.
+ transpose: bool. Transpose the JIT matrix or not. Default False.
+ atomic: bool. Compute the post-synaptic value with the atomic summation. Default True.
+ May be changed in the future.
+ sharding: The sharding strategy.
+ mode: The synaptic computing mode.
+ name: The synapse model name.
+ """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_low: float, + w_high: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + transpose: bool = False, + atomic: bool = True, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_low = w_low + self.w_high = w_high + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_uniform(x, self.w_low, self.w_high, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + + +class EventJitFPNormalLinear(Layer): + r"""Synaptic matrix multiplication with the just-in-time connectivity. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic spikes, + :math:`M` the synaptic weights which has the fixed sparse connectivity and weights. + Particularly, the connectivity in :math:`M` is sampled from a fixed probability :math:`prob`, + and at each connection, the synaptic value is sample from a normal distribution :math:`N(\mu, \sigma)`. + + Args: + num_in: int. The number of the input feature. A positive integer. + num_out: int. The number of the input feature. 
A positive integer. + prob: float. The connectivity probability. + w_mu: float. The center of the normal distribution. + w_sigma: float. The standard variance of the normal distribution. + seed: int. The random seed used to keep the reproducibility of the connectivity. + transpose: bool. Transpose the JIT matrix or not. Default False. + atomic: bool. Compute the post-synaptic value with the atomic summation. Default False. + May be changed in the future. + sharding: The sharding strategy. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + num_in: int, + num_out: int, + prob: float, + w_mu: float, + w_sigma: float, + seed: Optional[int] = None, + sharding: Optional[Sharding] = None, + transpose: bool = False, + atomic: bool = True, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + self.prob = prob + self.sharding = sharding + self.transpose = transpose + self.seed = np.random.randint(0, 100000) if seed is None else seed + self.atomic = atomic + self.num_in = num_in + self.num_out = num_out + + # weight + self.w_mu = w_mu + self.w_sigma = w_sigma + + def update(self, x): + if x.ndim == 1: + return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) + elif x.ndim == 2: + return jax.vmap(self._batch_mv)(x) + elif x.ndim > 2: + shapes = x.shape[:-1] + x = bm.flatten(x, end_dim=-2) + y = jax.vmap(self._batch_mv)(x) + return bm.reshape(y, shapes + (y.shape[-1],)) + else: + raise ValueError + + def _batch_mv(self, x): + return bm.jitconn.event_mv_prob_normal(x, self.w_mu, self.w_sigma, self.prob, self.seed, + shape=(self.num_out, self.num_in), + transpose=self.transpose, + outdim_parallel=not self.atomic) diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py index 
5f6acbb09..7bb043e3e 100644 --- a/brainpy/_src/math/event/_info_collection.py +++ b/brainpy/_src/math/event/_info_collection.py @@ -140,6 +140,7 @@ def _event_info_gpu_translation(c, events): batch_event_info_p = XLACustomOp( name='batched_event_info', cpu_kernel=_batch_event_info_taichi, + gpu_kernel=_batch_event_info_taichi, outs=_batch_event_info_abstract, ) batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule) @@ -154,7 +155,7 @@ def _event_info_abstract(events, **kwargs): # TODO: first parallel evaluate the sub-sections, then serially event the sub-results. -@numba.jit(fastmath=True) +@numba.njit(fastmath=True) def _event_info(outs, ins): event_ids, event_num = outs event_num.fill(0) @@ -190,6 +191,7 @@ def _event_info_batching_rule(args, axes): event_info_p = XLACustomOp( name='event_info', cpu_kernel=_event_info_taichi, + gpu_kernel=_event_info_taichi, outs=_event_info_abstract, # gpu_func_translation=_event_info_gpu_translation, ) diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py index 5a9343642..24f010a12 100644 --- a/brainpy/_src/math/op_register/tests/test_ad_support.py +++ b/brainpy/_src/math/op_register/tests/test_ad_support.py @@ -9,6 +9,8 @@ import brainpy as bp import brainpy.math as bm +bm.set_platform('cpu') + def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ): data = jnp.atleast_1d(bm.as_jax(data)) diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index dd0a38dbf..968155ef9 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,31 +1,32 @@ -import jax.core -import brainpy.math as bm -import numba - - -@numba.njit(fastmath=True) -def numba_event_csrmv(weight, indices, vector, outs): - outs.fill(0) - weight = weight[()] # 0d - for row_i in 
range(vector.shape[0]): - if vector[row_i]: - for j in indices[row_i]: - outs[j] += weight - - -prim = bm.XLACustomOp(numba_event_csrmv) - - -def call(s=100): - indices = bm.random.randint(0, s, (s, 80)) - vector = bm.random.rand(s) < 0.1 - out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)]) - print(out[0].shape) - - -def test_event_ELL(): - call(1000) - call(100) - bm.clear_buffer_memory() - - +import jax.core +import brainpy.math as bm +import numba + +bm.set_platform('cpu') + +@numba.njit(fastmath=True) +def numba_event_csrmv(weight, indices, vector, outs): + outs.fill(0) + weight = weight[()] # 0d + for row_i in range(vector.shape[0]): + if vector[row_i]: + for j in indices[row_i]: + outs[j] += weight + + +prim = bm.XLACustomOp(numba_event_csrmv) + + +def call(s=100): + indices = bm.random.randint(0, s, (s, 80)) + vector = bm.random.rand(s) < 0.1 + out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)]) + print(out[0].shape) + + +def test_event_ELL(): + call(1000) + call(100) + bm.clear_buffer_memory() + + diff --git a/brainpy/_src/measure/tests/test_correlation.py b/brainpy/_src/measure/tests/test_correlation.py index 950dbce1f..dd19ca8aa 100644 --- a/brainpy/_src/measure/tests/test_correlation.py +++ b/brainpy/_src/measure/tests/test_correlation.py @@ -1,110 +1,111 @@ -# -*- coding: utf-8 -*- - - -import unittest -from functools import partial - -from jax import jit - -import brainpy as bp -import brainpy.math as bm - - -class TestCrossCorrelation(unittest.TestCase): - def test_c(self): - bm.random.seed() - spikes = bm.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T - cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.) 
- f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.)) - cc2 = f_cc(spikes) - print(cc1, cc2) - self.assertTrue(cc1 == cc2) - bm.clear_buffer_memory() - - def test_cc(self): - bm.random.seed() - spikes = bm.ones((1000, 10)) - cc1 = bp.measure.cross_correlation(spikes, 1.) - self.assertTrue(cc1 == 1.) - - spikes = bm.zeros((1000, 10)) - cc2 = bp.measure.cross_correlation(spikes, 1.) - self.assertTrue(cc2 == 0.) - - bm.clear_buffer_memory() - - def test_cc2(self): - bm.random.seed() - spikes = bm.random.randint(0, 2, (1000, 10)) - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - bm.clear_buffer_memory() - - def test_cc3(self): - bm.random.seed() - spikes = bm.random.random((1000, 100)) < 0.8 - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - bm.clear_buffer_memory() - - def test_cc4(self): - bm.random.seed() - spikes = bm.random.random((1000, 100)) < 0.2 - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - bm.clear_buffer_memory() - - def test_cc5(self): - bm.random.seed() - spikes = bm.random.random((1000, 100)) < 0.05 - print(bp.measure.cross_correlation(spikes, 1.)) - print(bp.measure.cross_correlation(spikes, 0.5)) - bm.clear_buffer_memory() - - -class TestVoltageFluctuation(unittest.TestCase): - def test_vf1(self): - bm.random.seed() - voltages = bm.random.normal(0, 10, size=(100, 10)) - print(bp.measure.voltage_fluctuation(voltages)) - - bm.enable_x64() - voltages = bm.ones((100, 10)) - r1 = bp.measure.voltage_fluctuation(voltages) - - jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False)) - jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False)) - r2 = jit_f(voltages) - print(r1, r2) # TODO: JIT results are different? 
- # self.assertTrue(r1 == r2) - - bm.disable_x64() - bm.clear_buffer_memory() - - -class TestFunctionalConnectivity(unittest.TestCase): - def test_cf1(self): - bm.random.seed() - act = bm.random.random((10000, 3)) - r1 = bp.measure.functional_connectivity(act) - - jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False)) - r2 = jit_f(act) - - self.assertTrue(bm.allclose(r1, r2)) - bm.clear_buffer_memory() - - -class TestMatrixCorrelation(unittest.TestCase): - def test_mc(self): - bm.random.seed() - A = bm.random.random((100, 100)) - B = bm.random.random((100, 100)) - r1 = (bp.measure.matrix_correlation(A, B)) - - jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False)) - r2 = jit_f(A, B) - self.assertTrue(bm.allclose(r1, r2)) - bm.clear_buffer_memory() - - +# -*- coding: utf-8 -*- + + +import unittest +from functools import partial + +from jax import jit + +import brainpy as bp +import brainpy.math as bm + +bm.set_platform('cpu') + +class TestCrossCorrelation(unittest.TestCase): + def test_c(self): + bm.random.seed() + spikes = bm.asarray([[1, 0, 1, 0, 1, 0, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 0, 0]]).T + cc1 = bp.measure.cross_correlation(spikes, 1., dt=1.) + f_cc = jit(partial(bp.measure.cross_correlation, numpy=False, bin=1, dt=1.)) + cc2 = f_cc(spikes) + print(cc1, cc2) + self.assertTrue(cc1 == cc2) + bm.clear_buffer_memory() + + def test_cc(self): + bm.random.seed() + spikes = bm.ones((1000, 10)) + cc1 = bp.measure.cross_correlation(spikes, 1.) + self.assertTrue(cc1 == 1.) + + spikes = bm.zeros((1000, 10)) + cc2 = bp.measure.cross_correlation(spikes, 1.) + self.assertTrue(cc2 == 0.) 
+ + bm.clear_buffer_memory() + + def test_cc2(self): + bm.random.seed() + spikes = bm.random.randint(0, 2, (1000, 10)) + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) + bm.clear_buffer_memory() + + def test_cc3(self): + bm.random.seed() + spikes = bm.random.random((1000, 100)) < 0.8 + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) + bm.clear_buffer_memory() + + def test_cc4(self): + bm.random.seed() + spikes = bm.random.random((1000, 100)) < 0.2 + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) + bm.clear_buffer_memory() + + def test_cc5(self): + bm.random.seed() + spikes = bm.random.random((1000, 100)) < 0.05 + print(bp.measure.cross_correlation(spikes, 1.)) + print(bp.measure.cross_correlation(spikes, 0.5)) + bm.clear_buffer_memory() + + +class TestVoltageFluctuation(unittest.TestCase): + def test_vf1(self): + bm.random.seed() + voltages = bm.random.normal(0, 10, size=(100, 10)) + print(bp.measure.voltage_fluctuation(voltages)) + + bm.enable_x64() + voltages = bm.ones((100, 10)) + r1 = bp.measure.voltage_fluctuation(voltages) + + jit_f = jit(partial(bp.measure.voltage_fluctuation, numpy=False)) + jit_f = jit(lambda a: bp.measure.voltage_fluctuation(a, numpy=False)) + r2 = jit_f(voltages) + print(r1, r2) # TODO: JIT results are different? 
+ # self.assertTrue(r1 == r2) + + bm.disable_x64() + bm.clear_buffer_memory() + + +class TestFunctionalConnectivity(unittest.TestCase): + def test_cf1(self): + bm.random.seed() + act = bm.random.random((10000, 3)) + r1 = bp.measure.functional_connectivity(act) + + jit_f = jit(partial(bp.measure.functional_connectivity, numpy=False)) + r2 = jit_f(act) + + self.assertTrue(bm.allclose(r1, r2)) + bm.clear_buffer_memory() + + +class TestMatrixCorrelation(unittest.TestCase): + def test_mc(self): + bm.random.seed() + A = bm.random.random((100, 100)) + B = bm.random.random((100, 100)) + r1 = (bp.measure.matrix_correlation(A, B)) + + jit_f = jit(partial(bp.measure.matrix_correlation, numpy=False)) + r2 = jit_f(A, B) + self.assertTrue(bm.allclose(r1, r2)) + bm.clear_buffer_memory() + + diff --git a/brainpy/_src/optimizers/tests/test_ModifyLr.py b/brainpy/_src/optimizers/tests/test_ModifyLr.py index 6e3cbf8c0..01e51016e 100644 --- a/brainpy/_src/optimizers/tests/test_ModifyLr.py +++ b/brainpy/_src/optimizers/tests/test_ModifyLr.py @@ -1,7 +1,8 @@ +from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm -from absl.testing import parameterized -from absl.testing import absltest dt = 0.04 num_step = int(1.0 / dt) @@ -33,15 +34,10 @@ def __init__(self, num_in, num_hidden): def update(self, x): return self.out(self.rnn(x)) - -with bm.training_environment(): - model = RNN(1, 100) - - -def loss(predictions, targets, l2_reg=2e-4): - mse = bp.losses.mean_squared_error(predictions, targets) - l2 = l2_reg * bp.losses.l2_norm(model.train_vars().unique().dict()) ** 2 - return mse + l2 + def loss(self, predictions, targets, l2_reg=2e-4): + mse = bp.losses.mean_squared_error(predictions, targets) + l2 = l2_reg * bp.losses.l2_norm(self.train_vars().unique().dict()) ** 2 + return mse + l2 class test_ModifyLr(parameterized.TestCase): @@ -54,22 +50,28 @@ class test_ModifyLr(parameterized.TestCase): ] ) def 
test_NewScheduler(self, LearningRate): + with bm.training_environment(): + model = RNN(1, 100) + opt = bp.optim.Adam(lr=LearningRate, eps=1e-1) - trainer = bp.BPTT(model, loss_fun=loss, optimizer=opt) + trainer = bp.BPTT(model, loss_fun=model.loss, optimizer=opt) bm.clear_buffer_memory() def test_modifylr(self): + with bm.training_environment(): + model = RNN(1, 100) + Scheduler_lr = bp.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975) opt1 = bp.optim.Adam(lr=Scheduler_lr, eps=1e-1) opt1.lr.lr = 0.01 - trainer1 = bp.BPTT(model, loss_fun=loss, optimizer=opt1) + trainer1 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt1) bm.clear_buffer_memory() opt2 = bp.optim.SGD(lr=Scheduler_lr) opt2.lr.set_value(0.01) - trainer2 = bp.BPTT(model, loss_fun=loss, optimizer=opt2) + trainer2 = bp.BPTT(model, loss_fun=model.loss, optimizer=opt2) bm.clear_buffer_memory() diff --git a/requirements-dev.txt b/requirements-dev.txt index fc074a0fd..0e475e83d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,8 +1,8 @@ numpy numba brainpylib -jax < 0.4.24 -jaxlib < 0.4.24 +jax +jaxlib matplotlib msgpack tqdm diff --git a/requirements-doc.txt b/requirements-doc.txt index a4d94bdc4..8b0a5a6a4 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -1,6 +1,6 @@ tqdm -jax<0.4.24 -jaxlib<0.4.24 +jax +jaxlib matplotlib numpy scipy diff --git a/requirements.txt b/requirements.txt index 6af11954c..02fdebe83 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ numpy -jax < 0.4.24 +jax tqdm numba +taichi==1.7.0 diff --git a/setup.py b/setup.py index 21b2f713c..d7fd45e38 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'numba'], + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'numba', 'taichi==1.7.0'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": 
"https://github.com/brainpy/BrainPy/issues", @@ -69,7 +69,7 @@ ], extras_require={ 'cpu': ['jaxlib>=0.4.13', 'brainpylib'], - 'cuda': ['jax[cuda]', 'brainpylib-cu11x'], + 'cuda': ['jax[cuda]', 'brainpylib-cu12x'], 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'], 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'], 'tpu': ['jax[tpu]'],