diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 09bf2958d..b635d21f1 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -570,7 +570,7 @@ def __init__( sharding: Optional[Sharding] = None, mode: Optional[bm.Mode] = None, name: Optional[str] = None, - method: str = 'cusparse', + method: str = None, transpose: bool = True, ): super().__init__(name=name, mode=mode, conn=conn, weight=weight, sharding=sharding, transpose=transpose) @@ -580,8 +580,7 @@ def update(self, x): if x.ndim == 1: return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose, - method=self.method) + method=self.method, transpose=self.transpose) elif x.ndim > 1: shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -593,9 +592,7 @@ def update(self, x): def _batch_csrmv(self, x): return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), - transpose=self.transpose, - method=self.method) - + method=self.method, transpose=self.transpose) class EventCSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with event CSR sparse computation. @@ -646,7 +643,6 @@ def _batch_csrmv(self, x): shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) - @numba.njit(nogil=True, fastmath=True, parallel=False) def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): out_w[:] = w @@ -659,7 +655,6 @@ def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - csr_on_pre_update_prim = bm.XLACustomOp(_cpu_csr_on_pre_update) @@ -671,7 +666,6 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] - @numba.njit(nogil=True, fastmath=True, parallel=False) def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): out_w[:] = w @@ -697,6 +691,7 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_m outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + class CSCLinear(Layer): r"""Synaptic matrix multiplication with CSC sparse computation. @@ -1080,7 +1075,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, ): super().__init__(name=name, mode=mode) @@ -1161,7 +1156,7 @@ def __init__( mode: Optional[bm.Mode] = None, name: Optional[str] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, ): super().__init__(name=name, mode=mode) @@ -1239,7 +1234,7 @@ def __init__( seed: Optional[int] = None, sharding: Optional[Sharding] = None, transpose: bool = False, - atomic: bool = False, + atomic: bool = True, mode: Optional[bm.Mode] = None, name: Optional[str] = None, ): diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index da49bdbfe..7fc89526c 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -213,6 +213,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): self.assertTrue(y2.shape == shape + (200,)) bm.clear_buffer_memory() - if __name__ == '__main__': absltest.main() diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 865d682a0..631129558 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,5 +1,4 @@ from ._info_collection import * from ._csr_matvec import * -from ._csr_matvec_taichi import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 9da0cf524..2e7895334 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -10,7 +10,6 @@ """ - from functools import partial from typing import Union, Tuple @@ -22,20 +21,69 @@ from jax.interpreters import ad, xla from jax.lib import xla_client +from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching) -from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv + register_general_batching, + XLACustomOp) +from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv +from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy.errors import GPUOperatorNotFound __all__ = [ 'csrmv' ] +ti = import_taichi() + def csrmv( + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, +) -> jax.Array: + """Product of a sparse CSR matrix and a dense event vector. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + events: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + If ``transpose=True``, the operator will compute based on the + event-driven property of the ``events`` vector. + + Returns + ------- + y : Array + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) + + +### BRAINPYLIB ### + +def csrmv_brainpylib( data: Union[float, jax.Array], indices: jax.Array, indptr: jax.Array, @@ -519,15 +567,15 @@ def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose): return r, 0 -def _event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, shape, transpose): - return csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) +def _event_csr_matvec_jvp_values_brainpylib(values_dot, values, indices, indptr, events, *, shape, transpose): + return normal_csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) -def _event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, shape, transpose): +def _event_csr_matvec_jvp_events_brainpylib(events_dot, values, indices, indptr, events, *, shape, transpose): return normal_csrmv(values, indices, indptr, events_dot, shape=shape, transpose=transpose) -def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, transpose): +def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, *, shape, transpose): if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): raise ValueError("Cannot transpose with respect to sparse indices.") if ad.is_undefined_primal(events): @@ -538,7 +586,7 @@ def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, t ct_values = ad.Zero(values) else: if values.aval.shape[0] == 1: # scalar - ct_values = csrmv(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) + ct_values = csrmv_brainpylib(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) ct_values = jnp.inner(ct, ct_values) else: # heterogeneous values row, col = csr_to_coo(indices, indptr) @@ -551,7 +599,491 @@ def _event_csr_matvec_transpose(ct, values, indices, indptr, events, *, shape, t event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p)) xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation -ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values, None, None, _event_csr_matvec_jvp_events) -ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose +ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, + _event_csr_matvec_jvp_events_brainpylib) +ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib register_general_batching(event_csr_matvec_p) + + # batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule + + +### TAICHI ### + +def csrmv_taichi( + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False +) -> jax.Array: + """Product of a sparse CSR matrix and a dense event vector. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + events: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + If ``transpose=True``, the operator will compute based on the + event-driven property of the ``events`` vector. + + Returns + ------- + y : Array + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + data = as_jax(data) + indices = as_jax(indices) + indptr = as_jax(indptr) + events = as_jax(events) + + # checking + data = jnp.atleast_1d(data) + if np.ndim(data) == 1: + if data.shape[0] not in [1, indices.shape[0]]: + raise ValueError('The size of data should be 1 or be consistent with indices.' + f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + else: + raise ValueError('data should be a scalar or 1D vector. ' + f'But we got {np.ndim(data)}-D array.') + if np.ndim(indices) != 1: + raise ValueError('indices should be a 1D vector with integer type.') + if np.ndim(indptr) != 1: + raise ValueError('indptr should be a 1D vector with integer type.') + if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: + raise ValueError( + 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') + if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: + raise ValueError( + 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') + if np.ndim(events) != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + # if the shape of indices is (0,), then we return a zero vector + if indices.shape[0] == 0: + return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) + + return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] + + +# ------------- +# CPU operators +# ------------- + +# 1. The benchmarking shows that the performance of the following transpose +# kernels is maximized when using serialized mode +# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable +# arguments, we have to define each kernel separately when the +# non-differentiable/non-jittable arguments are different. + + +@ti.kernel +def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + +@ti.kernel +def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + +@ti.kernel +def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + +@ti.kernel +def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + +@ti.kernel +def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + +@ti.kernel +def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + +@ti.kernel +def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += value + out[row_i] = r + + +@ti.kernel +def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += values[j] + out[row_i] = r + + +# ------------- +# GPU operators +# ------------- + +# 1. GPU kernels are different from the CPU ones, since the GPU kernels need +# to use warp-level parallelism to achieve the best performance. + + +@ti.kernel +def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + +@ti.kernel +def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + +# TODO +# It is important to note that the following warp-based kernels +# should be improved, since the atomic_add for each thread is not +# very efficient. Instead, the warp-level reduction primitive +# should be used. +# see ``warp_reduce_sum()`` function in tifunc.py. +# However, currently Taichi does not support general warp-level primitives. + + +@ti.kernel +def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +@ti.kernel +def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +@ti.kernel +def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + +@ti.kernel +def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + +@ti.kernel +def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +@ti.kernel +def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +def raw_csrmv_taichi( + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False +): + if transpose: + if events.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csrmv_transpose_bool_homo_p + else: + prim = _event_csrmv_transpose_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csrmv_transpose_homo_p + else: + prim = _event_csrmv_transpose_heter_p + else: + if events.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csrmv_bool_homo_p + else: + prim = _event_csrmv_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csrmv_homo_p + else: + prim = _event_csrmv_heter_p + + # computing + return prim(data, + indices, + indptr, + events, + outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)], + transpose=transpose, + shape=shape) + + +def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) + + +def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + + +def _event_csr_matvec_transpose_taichi( + ct, values, indices, indptr, events, *, outs, transpose, shape +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(events): + ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] + return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) + else: + if type(ct[0]) is ad.Zero: + ct_values = ad.Zero(values) + else: + if values.aval.shape[0] == 1: # scalar + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = jnp.inner(ct[0], ct_values) + else: # heterogeneous values + row, col = csr_to_coo(indices, indptr) + ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] + return ct_values, indices, indptr, events + + +def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) + prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) + return prim + + +# transpose bool homo +_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, + _event_csr_matvec_transpose_bool_homo_gpu) + +# transpose homo +_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) + +# not transpose bool homo +_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) + +# not transpose homo +_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) + +# transpose bool heter +_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, + _event_csr_matvec_transpose_bool_heter_gpu) + +# transpose heter +_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, + _event_csr_matvec_transpose_heter_gpu) + +# not transpose bool heter +_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) + +# not transpose heter +_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/event/_csr_matvec_taichi.py b/brainpy/_src/math/event/_csr_matvec_taichi.py deleted file mode 100644 index 9be9c49d9..000000000 --- a/brainpy/_src/math/event/_csr_matvec_taichi.py +++ /dev/null @@ -1,487 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Union, Tuple - -import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters import ad - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._csr_mv_taichi import csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse._utils import csr_to_coo - -ti = import_taichi() - -__all__ = [ - 'csrmv_taichi' -] - - -# ------------- -# CPU operators -# ------------- - -# 1. The benchmarking shows that the performance of the following transpose -# kernels is maximized when using serialized mode -# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable -# arguments, we have to define each kernel separately when the -# non-differentiable/non-jittable arguments are different. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - -# 1. GPU kernels are different from the CPU ones, since the GPU kernels need -# to use warp-level parallelism to achieve the best performance. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -# TODO -# It is important to note that the following warp-based kernels -# should be improved, since the atomic_add for each thread is not -# very efficient. Instead, the warp-level reduction primitive -# should be used. -# see ``warp_reduce_sum()`` function in tifunc.py. -# However, currently Taichi does not support general warp-level primitives. - - -@ti.kernel -def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -def _event_csr_matvec_jvp_values(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) - - -def _event_csr_matvec_jvp_events(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) - - -def _event_csr_matvec_transpose( - ct, values, indices, indptr, events, *, outs, transpose, shape -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events - - -def csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - if transpose: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_bool_homo_p - else: - prim = _event_csrmv_transpose_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_homo_p - else: - prim = _event_csrmv_transpose_heter_p - else: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_bool_homo_p - else: - prim = _event_csrmv_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csrmv_homo_p - else: - prim = _event_csrmv_heter_p - - # computing - return prim(data, - indices, - indptr, - events, - outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values, None, None, _event_csr_matvec_jvp_events) - prim.def_transpose_rule(_event_csr_matvec_transpose) - return prim - - -# transpose bool homo -_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) - -# transpose homo -_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) - -# not transpose bool homo -_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) - -# not transpose homo -_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) - -# transpose bool heter -_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) - -# transpose heter -_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) - -# not transpose bool heter -_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) - -# not transpose heter -_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 3ca456b0b..e0f38490f 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -8,13 +8,8 @@ import brainpy as bp import brainpy.math as bm -import platform -import pytest - -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) +seed = 1234 def sum_op(op): @@ -24,127 +19,92 @@ def func(*args, **kwargs): return func +taichi_csr_matvec = bm.event.csrmv -class Test_event_csr_matvec(parameterized.TestCase): +class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec, self).__init__(*args, **kwargs) - bm.set_platform(platform) + super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs) + print() + bm.set_platform(platform) - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - for homo_data in [-1., 0., 1.] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)], + homo_data=[-1., 0., 1.], ) - def test_homo(self, shape, transpose, homo_data): + def test_homo(self, transpose, shape, homo_data): print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data - r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = bm.event.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - r3 = bm.event.csrmv(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r3)) - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r4)) + r1 = (events @ dense) if transpose else (dense @ events) + r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r5 = bm.event.csrmv(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r5)) + assert (bm.allclose(r1, r2)) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)], + homo_data=[-1., 0., 1.], ) def test_homo_vmap(self, shape, transpose, homo_data): print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, + shape=shape, transpose=transpose)) + f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' - f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, + f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, + f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, - method='cusparse')) + f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + vmap_data1 = bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) + f6(vmap_data1, vmap_data2))) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}', - homo_data=homo_data, - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)], + homo_data=[-1., 0., 1.], ) def test_homo_grad(self, shape, transpose, homo_data): print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -152,140 +112,102 @@ def test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(bm.event.csrmv))( + r1 = jax.grad(sum_op(bm.sparse.csrmv))( + homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(taichi_csr_matvec))( homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) - r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else - ((dense_conn * a) @ events))))(homo_data) - self.assertTrue(bm.allclose(r1, r3)) # grad 'events' - r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else - ((dense_conn * homo_data) @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) + self.assertTrue(bm.allclose(r3, r4)) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000), ] ) def test_heter(self, shape, transpose): print(f'test_heter: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 heter_data = bm.as_jax(rng.random(indices.shape)) - r1 = bm.event.csrmv(heter_data, indices, indptr, events, + r1 = bm.sparse.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r3)) + r2 = taichi_csr_matvec(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) - r4 = bm.event.csrmv(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r4)) + assert (bm.allclose(r1, r2)) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict( - testcase_name=f"transpose={transpose}, shape={shape}", - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)] ) def test_heter_vmap(self, shape, transpose): print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, + shape=shape, transpose=transpose)) + f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) # vmap 'events' data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, + f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, + f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, + f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) + f6(vmap_data1, vmap_data2))) bm.clear_buffer_memory() - @parameterized.named_parameters( - dict(testcase_name=f'transpose={transpose},shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] + @parameterized.product( + transpose=[True, False], + shape=[(100, 200), + (200, 200), + (200, 100), + (10, 1000)] ) def test_heter_grad(self, shape, transpose): print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState() + rng = bm.random.RandomState(seed=seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -295,27 +217,24 @@ def test_heter_grad(self, shape, transpose): # grad 'data' data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(bm.event.csrmv))( + r1 = jax.grad(sum_op(bm.sparse.csrmv))( + data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(taichi_csr_matvec))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) - dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) - r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else - (a @ events))))(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r3 = r3[rows, cols] - self.assertTrue(bm.allclose(r1, r3)) - # grad 'events' - r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r3, r4)) + + r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else - (dense_data @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) + self.assertTrue(bm.allclose(r5[0], r6[0])) + self.assertTrue(bm.allclose(r5[1], r6[1])) bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py b/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py deleted file mode 100644 index a5b8df152..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- - - -import jax -import pytest - -import test_event_csrmv - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_csr_matvec_GPU(test_event_csrmv.Test_event_csr_matvec): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py new file mode 100644 index 000000000..31a6527a2 --- /dev/null +++ b/brainpy/_src/math/event/tests/test_event_csrmv_old.py @@ -0,0 +1,324 @@ +# -*- coding: utf-8 -*- + + +from functools import partial + +import jax +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +import platform + +import pytest +pytest.skip('Old implementation.', allow_module_level=True) + +is_manual_test = False +# if platform.system() == 'Windows' and not is_manual_test: +# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + +brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') +taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +class Test_event_csr_matvec(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_csr_matvec, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', + transpose=transpose, + shape=shape, + homo_data=homo_data, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (10000, 2)] + for homo_data in [-1., 0., 1.] + ) + def test_homo(self, shape, transpose, homo_data): + print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + heter_data = bm.ones(indices.shape) * homo_data + + r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r3)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r4 = (events @ dense) if transpose else (dense @ events) + self.assertTrue(bm.allclose(r1, r4)) + + r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r5)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', + transpose=transpose, + shape=shape, + homo_data=homo_data, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + for homo_data in [-1., 0., 1.] + ) + def test_homo_vmap(self, shape, transpose, homo_data): + print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + + # vmap 'data' + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, + shape=shape, transpose=transpose)) + f2 = jax.vmap( + partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), + shape=shape, transpose=transpose)) + vmap_data = bm.as_jax([homo_data] * 10) + self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) + + # vmap 'events' + f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, + shape=shape, transpose=transpose)) + f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, + shape=shape, transpose=transpose)) + vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) + + # vmap 'data' and 'events' + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, + method='cusparse')) + vmap_data1 = bm.as_jax([homo_data] * 10) + vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 + self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), + f6(vmap_data1, vmap_data2.astype(float)))) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}', + homo_data=homo_data, + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + for homo_data in [-1., 0., 1.] + ) + def test_homo_grad(self, shape, transpose, homo_data): + print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) + + # grad 'data' + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( + homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( + homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else + ((dense_conn * a) @ events))))(homo_data) + self.assertTrue(bm.allclose(r1, r3)) + + # grad 'events' + r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else + ((dense_conn * homo_data) @ e))))(events.astype(float)) + self.assertTrue(bm.allclose(r4, r5)) + self.assertTrue(bm.allclose(r4, r6)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f'transpose={transpose}, shape={shape}', + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (10000, 2)] + ) + def test_heter(self, shape, transpose): + print(f'test_heter: shape = {shape}, transpose = {transpose}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + heter_data = bm.as_jax(rng.random(indices.shape)) + + r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) + r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), + shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r3 = (events @ dense) if transpose else (dense @ events) + self.assertTrue(bm.allclose(r1, r3)) + + r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), + shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r4)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=f"transpose={transpose}, shape={shape}", + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + ) + def test_heter_vmap(self, shape, transpose): + print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + + # vmap 'data' + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, + shape=shape, transpose=transpose)) + f2 = jax.vmap( + partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), + shape=shape, transpose=transpose)) + vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) + self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) + + # vmap 'events' + data = bm.as_jax(rng.random(indices.shape)) + f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, + shape=shape, transpose=transpose)) + f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, + shape=shape, transpose=transpose)) + vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 + self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) + + # vmap 'data' and 'events' + f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) + vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 + self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), + f6(vmap_data1, vmap_data2.astype(float)))) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'transpose={transpose},shape={shape}', + shape=shape, + transpose=transpose, + ) + for transpose in [True, False] + for shape in [(100, 200), + (200, 200), + (200, 100), + (10, 1000), + (2, 10000), + (1000, 10), + (100000, 2)] + ) + def test_heter_grad(self, shape, transpose): + print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') + + rng = bm.random.RandomState() + indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) + + # grad 'data' + data = bm.as_jax(rng.random(indices.shape)) + r1 = jax.grad(sum_op(brainpylib_csr_matvec))( + data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) + r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else + (a @ events))))(dense_data) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + r3 = r3[rows, cols] + self.assertTrue(bm.allclose(r1, r3)) + + # grad 'events' + r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( + data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else + (dense_data @ e))))(events.astype(float)) + self.assertTrue(bm.allclose(r4, r5)) + self.assertTrue(bm.allclose(r4, r6)) + + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py deleted file mode 100644 index b759a4789..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py +++ /dev/null @@ -1,246 +0,0 @@ -# -*- coding: utf-8 -*- - - -from functools import partial - -import jax -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -def sum_op2(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - - -class Test_event_csr_matvec_taichi(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - heter_data = bm.ones(indices.shape) * homo_data - - r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - - assert (bm.allclose(r1, r2[0])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], - ) - def test_homo_vmap(self, shape, transpose, homo_data): - print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0])) - - # vmap 'events' - f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmv_taichi, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0])) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.event.csrmv_taichi(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - - vmap_data1 = bm.as_jax([homo_data] * 10) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2)[0])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], - ) - def test_homo_grad(self, shape, transpose, homo_data): - print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - r1 = jax.grad(sum_op(bm.event.csrmv))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'events' - r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000), ] - ) - def test_heter(self, shape, transpose): - print(f'test_heter: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = bm.event.csrmv(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - r2 = bm.event.csrmv_taichi(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - - assert (bm.allclose(r1, r2[0])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] - ) - def test_heter_vmap(self, shape, transpose): - print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0])) - - # vmap 'events' - data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmv_taichi, data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0])) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.event.csrmv_taichi(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2)[0])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] - ) - def test_heter_grad(self, shape, transpose): - print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState(seed=seed) - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(bm.event.csrmv))( - data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))( - data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'events' - r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - r5 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=(0, 3))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index 439324152..a79cdc982 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,5 +1,3 @@ from ._matvec import * -from ._matvec_taichi import * -from ._event_matvec import * -from ._event_matvec_taichi import * +from ._event_matvec import * \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index d739919f7..7971b4a92 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -10,18 +10,29 @@ from jax.interpreters import xla, ad from jax.lib import xla_client -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p, mv_prob_uniform_p, mv_prob_normal_p, mv_prob_homo, mv_prob_uniform, - mv_prob_normal) + mv_prob_normal, + _general_checking, + raw_mv_prob_homo, + raw_mv_prob_uniform, + raw_mv_prob_normal, + _mv_prob_homo_transpose, + _mv_prob_uniform_transpose, + _mv_prob_normal_transpose, + _reverse) from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import register_general_batching +from brainpy._src.math.op_register import register_general_batching, XLACustomOp +from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) from brainpy.errors import GPUOperatorNotFound +ti = import_taichi() + __all__ = [ 'event_mv_prob_homo', 'event_mv_prob_uniform', @@ -38,6 +49,58 @@ def event_mv_prob_homo( shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +### BRAINPYLIB ### + +def event_mv_prob_homo_brainpylib( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, ) -> jax.Array: events = as_jax(events) weight = jnp.atleast_1d(as_jax(weight)) @@ -57,10 +120,10 @@ def event_mv_prob_homo( return r -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ +event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ -def event_mv_prob_uniform( +def event_mv_prob_uniform_brainpylib( events: jax.Array, w_low: float, w_high: float, @@ -90,10 +153,10 @@ def event_mv_prob_uniform( outdim_parallel=outdim_parallel)[0] -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ +event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ -def event_mv_prob_normal( +def event_mv_prob_normal_brainpylib( events: jax.Array, w_mu: float, w_sigma: float, @@ -123,7 +186,7 @@ def event_mv_prob_normal( outdim_parallel=outdim_parallel)[0] -event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ +event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ def _event_matvec_prob_homo_abstract( @@ -665,3 +728,1261 @@ def _event_matvec_prob_normal_transpose( register_general_batching(event_mv_prob_normal_p) ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose + + +### TAICHI ### + +def event_mv_prob_homo_taichi( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + events = as_jax(events) + if isinstance(weight, float): weight = as_jax(weight) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + + +def event_mv_prob_uniform_taichi( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + events = as_jax(events) + if isinstance(w_low, float): w_low = as_jax(w_low) + if isinstance(w_high, float): w_high = as_jax(w_high) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + + +def event_mv_prob_normal_taichi( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + events = as_jax(events) + if isinstance(w_mu, float): w_mu = as_jax(w_mu) + if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + + +# ------------- +# CPU function +# ------------- +# For each non-zero event value, it generates a random key using a +# function lfsr88_key and then uses this key to compute random integers +# and update the out array based on the computed indices and weight. +# +# The function is likely designed to be parallelized. + + +@ti.kernel +def _event_mv_prob_homo_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col]: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +# ------------- +# GPU function +# ------------- +# Contrary to the CPU functions, for each column, +# this function will 32 threads (one warp) to make +# the just-in-time random generation parallelized. + + +@ti.kernel +def _event_mv_prob_homo_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison without if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _reverse(shape): + return shape[::-1] + + +# ------------- +# CPU function +# ------------- +# For each non-zero event value, it generates a random key using a +# function lfsr88_key and then uses this key to compute random integers +# and update the out array based on the computed indices and weight. +# +# The function is likely designed to be parallelized. + + +@ti.kernel +def _event_mv_prob_homo_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col] != 0.: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + +# ------------- +# GPU function +# ------------- +# Contrary to the CPU functions, for each column, +# this function will 32 threads (one warp) to make +# the just-in-time random generation parallelized. + + +@ti.kernel +def _event_mv_prob_homo_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_homo_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison with if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _event_mv_prob_homo_jvp_events( + evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(evt_dot, weight, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_homo_jvp_weight( + w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(events, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + +def raw_event_mv_prob_homo( + events: jax.Array, + weight: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_outdim_parallel_bool_p + else: + prim = _event_mv_prob_homo_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_bool_p + else: + prim = _event_mv_prob_homo_p + + return prim(events, + weight, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_homo_jvp_events, + _event_mv_prob_homo_jvp_weight, + None, + None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + +# outdim_parallel = True, events.dtype = jnp.bool_ +_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu +) + +# outdim_parallel = False, events.dtype = jnp.bool_ +_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_bool_cpu, + gpu_kernel=_event_mv_prob_homo_bool_gpu +) + +# outdim_parallel = True, events.dtype != jnp.bool_ +_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu +) + +# outdim_parallel = False, events.dtype != jnp.bool_ +_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_cpu, + gpu_kernel=_event_mv_prob_homo_gpu +) + + +@ti.kernel +def _event_mv_prob_uniform_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_uniform_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + if events[i_col]: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _event_mv_prob_uniform_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_uniform_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += row_v * events[i_col] # TODO: speed comparison without if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +@ti.kernel +def _event_mv_prob_uniform_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_uniform_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + if events[i_col] != 0.: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + +@ti.kernel +def _event_mv_prob_uniform_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_uniform_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += row_v * events[i_col] # TODO: speed comparison with if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _event_mv_prob_uniform_jvp_events( + evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_uniform_jvp_w_low( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_uniform_jvp_w_high( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def raw_event_mv_prob_uniform( + events: jax.Array, + w_low: jax.Array, # vector with size 1 + w_high: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_outdim_parallel_bool_p + else: + prim = _event_mv_prob_uniform_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_bool_p + else: + prim = _event_mv_prob_uniform_p + + return prim(events, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_uniform_jvp_events, + _event_mv_prob_uniform_jvp_w_low, + _event_mv_prob_uniform_jvp_w_high, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + +# outdim_parallel = True, events.dtype = jnp.bool_ +_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu +) + +# outdim_parallel = False, events.dtype = jnp.bool_ +_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_bool_gpu +) + +# outdim_parallel = True, events.dtype != jnp.bool_ +_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu +) + +# outdim_parallel = False, events.dtype != jnp.bool_ +_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_cpu, + gpu_kernel=_event_mv_prob_uniform_gpu +) + + +@ti.kernel +def _event_mv_prob_normal_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + if events[i_col]: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _event_mv_prob_normal_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += row_v * events[i_col] # TODO: speed comparison without if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +@ti.kernel +def _event_mv_prob_normal_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + if events[i_col] != 0.: + r += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _event_mv_prob_normal_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _event_mv_prob_normal_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += row_v * events[i_col] # TODO: speed comparison with if else + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _event_mv_prob_normal_jvp_events( + evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_normal_jvp_w_mu( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _event_mv_prob_normal_jvp_w_sigma( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def raw_event_mv_prob_normal( + events: jax.Array, + w_mu: jax.Array, # vector with size 1 + w_sigma: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_outdim_parallel_bool_p + else: + prim = _event_mv_prob_normal_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_bool_p + else: + prim = _event_mv_prob_normal_p + + return prim(events, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_normal_jvp_events, + _event_mv_prob_normal_jvp_w_mu, + _event_mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + +# outdim_parallel = True, events.dtype = jnp.bool_ +_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu +) + +# outdim_parallel = False, events.dtype = jnp.bool_ +_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_bool_cpu, + gpu_kernel=_event_mv_prob_normal_bool_gpu +) + +# outdim_parallel = True, events.dtype != jnp.bool_ +_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu +) + +# outdim_parallel = False, events.dtype != jnp.bool_ +_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_cpu, + gpu_kernel=_event_mv_prob_normal_gpu +) diff --git a/brainpy/_src/math/jitconn/_event_matvec_taichi.py b/brainpy/_src/math/jitconn/_event_matvec_taichi.py deleted file mode 100644 index 8346607aa..000000000 --- a/brainpy/_src/math/jitconn/_event_matvec_taichi.py +++ /dev/null @@ -1,1277 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Tuple, Optional - -import jax -import numpy as np -from jax import numpy as jnp - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_uniform, lfsr88_normal, lfsr88_random_integers) -from ._matvec_taichi import (_general_checking, raw_mv_prob_homo, raw_mv_prob_uniform, raw_mv_prob_normal, - _mv_prob_homo_transpose, _mv_prob_uniform_transpose, _mv_prob_normal_transpose, - _reverse) - -ti = import_taichi() - -__all__ = [ - 'event_mv_prob_homo_taichi', - 'event_mv_prob_uniform_taichi', - 'event_mv_prob_normal_taichi', -] - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_homo_taichi( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - events = as_jax(events) - if isinstance(weight, float): weight = as_jax(weight) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu -) - - -@ti.kernel -def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_uniform_taichi( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - events = as_jax(events) - if isinstance(w_low, float): w_low = as_jax(w_low) - if isinstance(w_high, float): w_high = as_jax(w_high) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu -) - - -@ti.kernel -def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: - prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_normal_taichi( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - events = as_jax(events) - if isinstance(w_mu, float): w_mu = as_jax(w_mu) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu -) diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index cad95924d..e33a0ab1e 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -11,12 +11,15 @@ from jax.interpreters import xla, ad from jax.lib import xla_client -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import register_general_batching +from brainpy._src.math.op_register import register_general_batching, XLACustomOp +from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) from brainpy.errors import GPUOperatorNotFound +ti = import_taichi() + __all__ = [ 'mv_prob_homo', 'mv_prob_uniform', @@ -49,6 +52,200 @@ def mv_prob_homo( When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +def mv_prob_uniform( + vector: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +def mv_prob_normal( + vector: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +### BRAINYPLIB ### + +def mv_prob_homo_brainpylib( + vector: Union[Array, jax.Array], + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + .. note:: Note that the just-in-time generated :math:`M` (`transpose=False`) is @@ -100,7 +297,7 @@ def mv_prob_homo( )[0] -def mv_prob_uniform( +def mv_prob_uniform_brainpylib( vector: jax.Array, w_low: float, w_high: float, @@ -180,7 +377,7 @@ def mv_prob_uniform( outdim_parallel=outdim_parallel)[0] -def mv_prob_normal( +def mv_prob_normal_brainpylib( vector: jax.Array, w_mu: float, w_sigma: float, @@ -817,3 +1014,892 @@ def _matvec_prob_normal_transpose( register_general_batching(mv_prob_normal_p) ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose + + +### TAICHI ### +def mv_prob_homo_taichi( + vector: Union[Array, jax.Array], + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of + the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. + + Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same + of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + vector = as_jax(vector) + if isinstance(weight, float): + weight = as_jax(weight, dtype=vector.dtype) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.asarray(seed, dtype=jnp.uint32) + seed = jnp.atleast_1d(seed) + return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + + +def mv_prob_uniform_taichi( + vector: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_low: float + Lower boundary of the output interval. + w_high: float + Upper boundary of the output interval. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + vector = as_jax(vector) + if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) + if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + + +def mv_prob_normal_taichi( + vector: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a normal distribution for its value. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + vector: Array, ndarray + The vector. + w_mu: float + Mean (centre) of the distribution. + w_sigma: float + Standard deviation (spread or “width”) of the distribution. Must be non-negative. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ + vector = as_jax(vector) + if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) + if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] + + +def _reverse(shape): + return shape[::-1] + + +@ti.kernel +def _mv_prob_homo_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + v = vector[i_col] * weight0 + while i_row < num_row: + out[i_row] += v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_homo_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r * weight0 + + +@ti.kernel +def _mv_prob_homo_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_homo_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += weight0 * r # TODO: warp-level reduction + + +def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_homo_transpose( + ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), weight, clen, seed + else: + dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, weight, clen, seed + elif ad.is_undefined_primal(weight): + if type(ct) is ad.Zero: + return vector, ad.Zero(weight), clen, seed + else: + row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + dw = jnp.sum(row * vector, keepdims=True) + return vector, dw, clen, seed + else: + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + +def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + if vector.ndim != 1: + raise ValueError('vector should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + + assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + + for weight in weights: + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + out_shape = (shape[1],) + if vector.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') + shape = _reverse(shape) + else: + if vector.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') + out_shape = (shape[0],) + + return shape, out_shape + + +def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + +def raw_mv_prob_homo( + vector: jax.Array, + weight: jax.Array, # vector with size 1 + clen: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + prim = _mv_prob_homo_outdim_parallel_p + else: + prim = _mv_prob_homo_p + + return prim(vector, + weight, + clen, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + +# outdim_parallel = True +_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) + +# outdim_parallel = False +_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, + gpu_kernel=_mv_prob_homo_gpu) + + +@ti.kernel +def _mv_prob_uniform_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_uniform_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _mv_prob_uniform_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_uniform_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_uniform_transpose( + ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_low, w_high, clen, seed + else: + dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_low, w_high, clen, seed + else: + assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' + assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + +def raw_mv_prob_uniform( + vector: jax.Array, + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + prim = _mv_prob_uniform_outdim_parallel_p + else: + prim = _mv_prob_uniform_p + + return prim(vector, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_uniform_jvp_vector, + _mv_prob_uniform_jvp_wlow, + _mv_prob_uniform_jvp_whigh, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + +# outdim_parallel = True +_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu +) + +# outdim_parallel = False +_mv_prob_uniform_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_cpu, + gpu_kernel=_mv_prob_uniform_gpu +) + + +@ti.kernel +def _mv_prob_normal_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_normal_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + +@ti.kernel +def _mv_prob_normal_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + +@ti.kernel +def _mv_prob_normal_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) +): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + +def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + +def _mv_prob_normal_transpose( + ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_mu, w_sigma, clen, seed + else: + dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_mu, w_sigma, clen, seed + else: + assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' + assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + +def raw_mv_prob_normal( + vector: jax.Array, + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + prim = _mv_prob_normal_outdim_parallel_p + else: + prim = _mv_prob_normal_p + + return prim(vector, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_normal_jvp_vector, + _mv_prob_normal_jvp_w_mu, + _mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + +# outdim_parallel = True +_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_mv_prob_normal_outdim_parallel_gpu +) + +# outdim_parallel = False +_mv_prob_normal_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_cpu, + gpu_kernel=_mv_prob_normal_gpu +) diff --git a/brainpy/_src/math/jitconn/_matvec_taichi.py b/brainpy/_src/math/jitconn/_matvec_taichi.py deleted file mode 100644 index beaf2c383..000000000 --- a/brainpy/_src/math/jitconn/_matvec_taichi.py +++ /dev/null @@ -1,911 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Tuple, Optional, Union - -import jax -import numpy as np -from jax import numpy as jnp -from jax.interpreters import ad - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) - -ti = import_taichi() - -__all__ = [ - 'mv_prob_homo_taichi', - 'mv_prob_uniform_taichi', - 'mv_prob_normal_taichi', -] - - -def _reverse(shape): - return shape[::-1] - - -@ti.kernel -def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - -@ti.kernel -def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction - - -def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed - else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed - else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_homo_taichi( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of - the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. - - Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same - of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - -# outdim_parallel = False -_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - -@ti.kernel -def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_low, w_high, clen, seed - else: - dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_low, w_high, clen, seed - else: - assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' - assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p - else: - prim = _mv_prob_uniform_p - - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_uniform_taichi( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_uniform_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu -) - - -@ti.kernel -def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed - else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p - - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_normal_taichi( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu -) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 556213e89..b10d55d21 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -1,557 +1,520 @@ # -*- coding: utf-8 -*- +from functools import partial import jax import jax.numpy as jnp from absl.testing import parameterized -import platform import brainpy.math as bm -import pytest +shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] +shapes = [(100, 200), (2, 1000), (1000, 2)] -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True) - -shapes = [(100, 200), - # (10, 1000), - (2, 1000), - # (1000, 10), - (1000, 2)] +taichi_mv_prob_homo = bm.jitconn.event_mv_prob_homo +taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform +taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - homo_data=[-1., ], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=None, x64=False): - print(f'_test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = bm.jitconn.event_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = bm.jitconn.event_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - r3 = bm.jitconn.event_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - bool_event=[True, False], - seed=[1234], - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=None, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: bm.jitconn.event_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - ) + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.01, 0.1, 0.5], + homo_data=[-1., ], + bool_event=[True, False], + seed=[1234], ) - r1 = f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.5] - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: bm.jitconn.event_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - ).sum(), - argnums=0 + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): + print(f'_test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + r1 = taichi_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = taichi_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') + # indices = bm.as_jax(indices) + # indptr = bm.as_jax(indptr) + # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, + # shape=shape, transpose=transpose) + # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.01, 0.1, 0.5], + bool_event=[True, False], + seed=[1234], ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3)) - self.assertTrue(jnp.allclose(r1 * 2., r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - bool_event=bool_event, - seed=1234, - x64=x64 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for bool_event in [True, False] - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, - bool_event=True, seed=None, x64=False): - print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = bm.jitconn.event_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = bm.jitconn.event_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - r3 = bm.jitconn.event_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_uniform_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'bool_event={bool_event}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=None, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: bm.jitconn.event_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): + print(f'_test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: taichi_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + transpose=transpose, outdim_parallel=outdim_parallel + )[0] + ) + r1 = f1(events, weights) + r1 = jax.block_until_ready(r1) + r2 = f1(events, weights) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}', + shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1, 0.5] ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - testcase_name=f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: bm.jitconn.event_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.5 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: taichi_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + argnums=0 + ) + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + + r3 = f1(events, 3.) + r3 = jax.block_until_ready(r3) + + self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}', + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_low=w_low, + w_high=w_high, + bool_event=bool_event, + seed=1234, + x64=x64 + ) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1, 0.4] + for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] + for bool_event in [True, False] ) - - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu={w_mu}, ' - f'w_sigma={w_sigma}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, ] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for bool_event in [True, False] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=None, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = bm.jitconn.event_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = bm.jitconn.event_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - - r3 = bm.jitconn.event_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - r3 = jax.block_until_ready(r3) - self.assertTrue(jnp.allclose(r1, r3)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=None, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: bm.jitconn.event_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = taichi_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel, prob=prob, + bool_event=bool_event, + x64=x64, + seed=1234, + testcase_name=f'_test_uniform_vmap: ' + f'shape={shape}, ' + f'transpose={transpose}, ' + f'bool_event={bool_event}, ' + f'outdim_parallel={outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for bool_event in [True, False] + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap( + lambda e: taichi_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + ) + + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + testcase_name=f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_high: taichi_mv_prob_uniform( + e, + w_low=0., + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) + + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + # print(r1) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_mu=w_mu, + w_sigma=w_sigma, + bool_event=bool_event, + x64=x64, + seed=1234, + testcase_name=f'_test_normal: ' + f'shape={shape}, ' + f'transpose={transpose}, ' + f'outdim_parallel={outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu={w_mu}, ' + f'w_sigma={w_sigma}, ' + f'bool_event={bool_event}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1, ] + for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] + for bool_event in [True, False] + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal: shape = {shape}, ' + f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' + f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = taichi_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + bool_event=bool_event, + x64=x64, + seed=1234, + testcase_name=f'_test_normal_vmap: ' + f'shape={shape}, ' + f'transpose={transpose}, ' + f'outdim_parallel={outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event={bool_event}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for bool_event in [True, False] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + x64=x64, + seed=1234, + testcase_name=f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.jit( + jax.grad( + lambda e, w_sigma: taichi_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) + ) + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py deleted file mode 100644 index 778212547..000000000 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_event_matvec - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_matvec_prob_conn_GPU(test_event_matvec.Test_event_matvec_prob_conn): - def __init__(self, *args, **kwargs): - super(Test_event_matvec_prob_conn_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py similarity index 71% rename from brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py rename to brainpy/_src/math/jitconn/tests/test_event_matvec_old.py index e42434e95..b2fa77229 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec_old.py @@ -1,15 +1,31 @@ # -*- coding: utf-8 -*- - +from functools import partial import jax import jax.numpy as jnp from absl.testing import parameterized +import platform import brainpy.math as bm -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - +import pytest +pytest.skip('Old implementation.', allow_module_level=True) +is_manual_test = False +if platform.system() == 'Windows' and not is_manual_test: + pytest.skip('Under windows, brainpy.math package may need manual tests.', allow_module_level=True) + +shapes = [(100, 200), + # (10, 1000), + (2, 1000), + # (1000, 10), + (1000, 2)] + +brainpylib_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='brainpylib') +taichi_mv_prob_homo = partial(bm.jitconn.event_mv_prob_homo, method='taichi') +brainpylib_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='brainpylib') +taichi_mv_prob_uniform = partial(bm.jitconn.event_mv_prob_uniform, method='taichi') +brainpylib_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='brainpylib') +taichi_mv_prob_normal = partial(bm.jitconn.event_mv_prob_normal, method='taichi') class Test_event_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -44,32 +60,32 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_eve if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_homo_taichi(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_homo_taichi(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = brainpylib_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_homo_taichi(events, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = brainpylib_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) @@ -111,10 +127,10 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=Tru weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: bm.jitconn.event_mv_prob_homo_taichi( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, transpose=transpose, outdim_parallel=outdim_parallel - )[0] + ) ) r1 = f1(events, weights) r1 = jax.block_until_ready(r1) @@ -155,9 +171,10 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: bm.jitconn.event_mv_prob_homo_taichi( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + outdim_parallel=outdim_parallel, transpose=transpose + ).sum(), argnums=0 ) r1 = f1(events, 1.) @@ -221,35 +238,35 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_uniform_taichi(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_uniform_taichi(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_uniform_taichi(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) if x64: @@ -292,14 +309,14 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, events = events.astype(float) f1 = jax.vmap( - lambda e: bm.jitconn.event_mv_prob_uniform_taichi(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + lambda e: brainpylib_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) ) r1 = f1(events) @@ -342,7 +359,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = events.astype(float) f1 = jax.grad( - lambda e, w_high: bm.jitconn.event_mv_prob_uniform_taichi( + lambda e, w_high: brainpylib_mv_prob_uniform( e, w_low=0., w_high=w_high, @@ -403,35 +420,35 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, if not bool_event: events = events.astype(float) - r1 = bm.jitconn.event_mv_prob_normal_taichi(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r1 = jax.block_until_ready(r1) - r2 = bm.jitconn.event_mv_prob_normal_taichi(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r2 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) r2 = jax.block_until_ready(r2) self.assertTrue(jnp.allclose(r1, r2)) - r3 = bm.jitconn.event_mv_prob_normal_taichi(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r3 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) r3 = jax.block_until_ready(r3) self.assertTrue(jnp.allclose(r1, r3)) @@ -476,14 +493,14 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, if not bool_event: events = events.astype(float) - f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal_taichi(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) r1 = jax.block_until_ready(r1) r2 = f1(events) @@ -526,7 +543,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x f1 = jax.jit( jax.grad( - lambda e, w_sigma: bm.jitconn.event_mv_prob_normal_taichi( + lambda e, w_sigma: brainpylib_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 91c48fc66..2e6e406cf 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -1,65 +1,61 @@ # -*- coding: utf-8 -*- +from functools import partial import jax import jax.numpy as jnp from absl.testing import parameterized import brainpy.math as bm -import platform -import pytest -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) +shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] +shapes = [(100, 200), (2, 1000), (1000, 2)] -shapes = [(100, 200), - (10, 1000), - (2, 1000), - (1000, 10), - (1000, 2)] +taichi_mv_prob_homo = bm.jitconn.mv_prob_homo +taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform +taichi_mv_prob_normal = bm.jitconn.mv_prob_normal class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - homo_data=homo_data, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for homo_data in [-1., 1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=None, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_homo(vector, + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_homo, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'x64 = {x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + homo_data=homo_data, + seed=1234) + for x64 in [True, False] + for transpose in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for homo_data in [-1., 1.] + ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): + print(f'test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = taichi_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -67,163 +63,152 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_homo(vector, + r2 = taichi_mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2)) - - r2 = bm.jitconn.mv_prob_homo(vector, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: bm.jitconn.mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - ) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: bm.jitconn.mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - ).sum(), - argnums=0 + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: taichi_mv_prob_homo( + event, data, + conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose + )[0] + ) + r1 = f1(events, weights) + r2 = f1(events, weights) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - - self.assertTrue(jnp.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=None, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_uniform(events, + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: taichi_mv_prob_homo( + event, data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum(), + argnums=0 + ) + r1 = f1(events, 1.) + r2 = f1(events, 2.) + + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_uniform, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}' + f'x64 = {x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_low=w_low, + w_high=w_high, + x64=x64, + seed=1234) + for x64 in [True, False] + for transpose in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): + print(f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64 = {x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -232,7 +217,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_uniform(events, + r2 = taichi_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -240,58 +225,45 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - r2 = bm.jitconn.mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}', + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -300,107 +272,107 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform( - e, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - ).sum() + r1 = f1(events) + r2 = f1(events) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64) + for x64 in [True, False] + for transpose in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) - - self.assertTrue(bm.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234 + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + f1 = jax.grad( + lambda e, w_low, w_high: taichi_mv_prob_uniform( + e, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() + ) + + r1 = f1(events, 0., 1.) + r2 = f1(events, 0., 2.) + + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict( + testcase_name=(f'test_normal, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma},' + f'x64={x64}'), + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + w_mu=w_mu, + w_sigma=w_sigma, + seed=1234 + ) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=None, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = bm.jitconn.mv_prob_normal(events, + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): + print(f'_test_normal: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -409,7 +381,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = bm.jitconn.mv_prob_normal(events, + r2 = taichi_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -417,59 +389,46 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}', + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234) + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') - r2 = bm.jitconn.mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e, + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -477,65 +436,66 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64, - testcase_name=f'test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: bm.jitconn.mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - ).sum() + r1 = f1(events) + r2 = f1(events) + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + print(r1 - r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.named_parameters( + dict(shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel, + prob=prob, + seed=1234, + x64=x64, + testcase_name=f'test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'x64={x64}') + for transpose in [True, False] + for x64 in [True, False] + for outdim_parallel in [True, False] + for shape in shapes + for prob in [0.01, 0.1] ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - self.assertTrue(bm.allclose(r1 * 2., r2)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_sigma: taichi_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() + ) + r1 = f1(events, 1.) + r2 = f1(events, 2.) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py b/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py deleted file mode 100644 index f227c0e6a..000000000 --- a/brainpy/_src/math/jitconn/tests/test_matvec_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_matvec - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_matvec_prob_conn_GPU(test_matvec.Test_matvec_prob_conn): - def __init__(self, *args, **kwargs): - super(Test_matvec_prob_conn_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_matvec_old.py similarity index 68% rename from brainpy/_src/math/jitconn/tests/test_matvec_taichi.py rename to brainpy/_src/math/jitconn/tests/test_matvec_old.py index 380db3cf5..360711e7b 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec_old.py @@ -1,15 +1,31 @@ # -*- coding: utf-8 -*- - +from functools import partial import jax import jax.numpy as jnp from absl.testing import parameterized import brainpy.math as bm - -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - +import platform +import pytest + +pytest.skip('Old implementation.', allow_module_level=True) +is_manual_test = False +if platform.system() == 'Windows' and not is_manual_test: + pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + +shapes = [(100, 200), + (10, 1000), + (2, 1000), + (1000, 10), + (1000, 2)] + +brainpylib_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='brainpylib') +taichi_mv_prob_homo = partial(bm.jitconn.mv_prob_homo, method='taichi') +brainpylib_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='brainpylib') +taichi_mv_prob_uniform = partial(bm.jitconn.mv_prob_uniform, method='taichi') +brainpylib_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='brainpylib') +taichi_mv_prob_normal = partial(bm.jitconn.mv_prob_normal, method='taichi') class Test_matvec_prob_conn(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -51,32 +67,34 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_homo_taichi(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = bm.jitconn.mv_prob_homo_taichi(vector, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = brainpylib_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) self.assertTrue(jnp.allclose(r1, r2)) - r2 = bm.jitconn.mv_prob_homo_taichi(vector, - homo_data, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = brainpylib_mv_prob_homo(vector, + homo_data, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) self.assertTrue(jnp.allclose(r1, r2)) + if x64: + bm.disable_x64() bm.clear_buffer_memory() @parameterized.named_parameters( @@ -111,11 +129,11 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64 weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: bm.jitconn.mv_prob_homo_taichi( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0] + ) ) r1 = f1(events, weights) r2 = f1(events, weights) @@ -156,14 +174,14 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: bm.jitconn.mv_prob_homo_taichi( + lambda event, data: brainpylib_mv_prob_homo( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0].sum(), + ).sum(), argnums=0 ) r1 = f1(events, 1.) @@ -213,36 +231,36 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_uniform_taichi(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = bm.jitconn.mv_prob_uniform_taichi(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) self.assertTrue(c) - r2 = bm.jitconn.mv_prob_uniform_taichi(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = brainpylib_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) @@ -281,14 +299,14 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform_taichi(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: brainpylib_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) r2 = f1(events) @@ -330,7 +348,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) f1 = jax.grad( - lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform_taichi( + lambda e, w_low, w_high: brainpylib_mv_prob_uniform( e, w_low=w_low, w_high=w_high, @@ -339,7 +357,7 @@ def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=None, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0].sum() + ).sum() ) r1 = f1(events, 0., 1.) @@ -390,36 +408,36 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se rng = bm.random.RandomState() events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_normal_taichi(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - - r2 = bm.jitconn.mv_prob_normal_taichi(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) + r1 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + + r2 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) self.assertTrue(c) - r2 = bm.jitconn.mv_prob_normal_taichi(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=(shape[1], shape[0]), - seed=seed, - outdim_parallel=outdim_parallel, - transpose=not transpose) + r2 = brainpylib_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=(shape[1], shape[0]), + seed=seed, + outdim_parallel=outdim_parallel, + transpose=not transpose) c = jnp.allclose(r1, r2) if not c: print(r1, r2) @@ -459,20 +477,19 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x rng = bm.random.RandomState() events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal_taichi(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) + f1 = jax.vmap(lambda e: brainpylib_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) r1 = f1(events) r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) + c = jnp.allclose(r1, r2) if not c: print(r1, r2) - print(r1 - r2) self.assertTrue(c) if x64: @@ -512,7 +529,7 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x events = events.astype(float) f1 = jax.grad( - lambda e, w_sigma: bm.jitconn.mv_prob_normal_taichi( + lambda e, w_sigma: brainpylib_mv_prob_normal( e, w_mu=0., w_sigma=w_sigma, @@ -521,10 +538,12 @@ def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose - )[0].sum() + ).sum() ) r1 = f1(events, 1.) r2 = f1(events, 2.) + print('r1:', r1) + print('r2:', r2) self.assertTrue(bm.allclose(r1 * 2., r2)) if x64: diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 878b205cf..96ebabfa7 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -347,7 +347,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): # kernel to code codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform) - source_md5_encode = kernel.__name__ + '/' + encode_md5(codes) + source_md5_encode = os.path.join(kernel.__name__, encode_md5(codes)) # create ins, outs dict from kernel's args in_num = len(ins) @@ -361,7 +361,10 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): try: _build_kernel(source_md5_encode, kernel, ins_dict, outs_dict, platform) except Exception as e: - os.removedirs(os.path.join(kernels_aot_path, source_md5_encode)) + try: + os.removedirs(os.path.join(kernels_aot_path, source_md5_encode)) + except Exception: + raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e # returns diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index cd94d0621..d45f2c80b 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,7 +1,6 @@ from ._coo_mv import * from ._csr_mv import * -from ._csr_mv_taichi import * from ._utils import * from ._bsr_mv import * from ._bsr_mm import * diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index d874ad901..47704af04 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -13,20 +13,78 @@ from jax.lib import xla_client from jaxlib import gpu_sparse -from brainpy._src.dependency_check import import_brainpylib_gpu_ops +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching) + register_general_batching, + XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import GPUOperatorNotFound +ti = import_taichi() + __all__ = [ 'csrmv', ] def csrmv( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, + method: str = None, +): + """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + vector: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + method: str + The method used to compute Matrix-Vector Multiplication. Default is ``taichi``. + The candidate methods are: + + - ``None``: default using Taichi kernel. + - ``cusparse``: using cuSPARSE library. + - ``scalar``: + - ``vector``: + - ``adaptive``: + + Returns + ------- + y : ndarry + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + if method is None: + return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) + else: + return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method) + + +### BRAINPYLIB ### + +def csrmv_brainpylib( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], indptr: Union[jnp.ndarray, Array], @@ -466,3 +524,289 @@ def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, trans partial(_csrmv_jvp_vec, _csrmv_adaptive_p), ) ad.primitive_transposes[_csrmv_adaptive_p] = _csrmv_adaptive_transpose register_general_batching(_csrmv_adaptive_p) + + +### TAICHI ### + +def csrmv_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +) -> jax.Array: + """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + vector: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + + Returns + ------- + y : ndarry + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + + data = jnp.atleast_1d(as_jax(data)) + indices = as_jax(indices) + indptr = as_jax(indptr) + vector = as_jax(vector) + + if vector.dtype == jnp.bool_: + vector = as_jax(vector, dtype=data.dtype) + + if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]: + raise TypeError('Only support float16, float32 or float64 type. ' + f'But we got {data.dtype}.') + if data.dtype != vector.dtype: + raise TypeError('The types of data and vector should be the same. ' + f'But we got {data.dtype} != {vector.dtype}.') + assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + + # if the shape of indices is (0,), then we return a zero vector + if indices.shape[0] == 0: + return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) + + return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] + + +# ------------- +# CPU operators +# ------------- + + +@ti.kernel +def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += value * vector[row_i] + + +@ti.kernel +def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += vector[row_i] * values[j] + + +@ti.kernel +def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += vector[col_indices[j]] + out[row_i] = r * value + + +@ti.kernel +def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + out[row_i] = r + + +# ------------- +# GPU operators +# ------------- + + +@ti.kernel +def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += value * vector[row_i] + j += 32 + + +@ti.kernel +def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += vector[col_indices[j]] + j += 32 + out[row_i] += value * r + + +@ti.kernel +def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += values[j] * vector[row_i] + j += 32 + + +@ti.kernel +def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += values[j] * vector[col_indices[j]] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + +def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) + + +def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) + + +def _sparse_csr_matvec_transpose( + ct, data, indices, indptr, vector, *, outs, transpose, shape, +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(vector): + ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] + ct_data = jnp.inner(ct[0], ct_data) + else: + row, col = csr_to_coo(indices, indptr) + ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] + + return ct_data, indices, indptr, vector + + +def raw_csrmv_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +): + out_shape = shape[1] if transpose else shape[0] + if transpose: + if data.shape[0] == 1: + prim = _csr_matvec_transpose_homo_p + else: + prim = _csr_matvec_transpose_heter_p + else: + if data.shape[0] == 1: + prim = _csr_matvec_homo_p + else: + prim = _csr_matvec_heter_p + + return prim(data, + indices, + indptr, + vector, + outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], + transpose=transpose, + shape=shape) + + +def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) + prim.def_transpose_rule(_sparse_csr_matvec_transpose) + return prim + + +# transpose homo +_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) + +# no transpose homo +_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, + gpu_kernel=_sparse_csr_matvec_homo_gpu) + +# transpose heter +_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) + +# no transpose heter +_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, + gpu_kernel=_sparse_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/sparse/_csr_mv_taichi.py b/brainpy/_src/math/sparse/_csr_mv_taichi.py deleted file mode 100644 index cd09af08e..000000000 --- a/brainpy/_src/math/sparse/_csr_mv_taichi.py +++ /dev/null @@ -1,288 +0,0 @@ -# -*- coding: utf-8 -*- - - -from typing import Union, Tuple - -import jax -from jax import numpy as jnp -from jax.interpreters import ad - -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._utils import csr_to_coo - -ti = import_taichi() - -__all__ = [ - 'csrmv_taichi', -] - - -# ------------- -# CPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += value * vector[row_i] - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += vector[row_i] * values[j] - - -@ti.kernel -def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += vector[col_indices[j]] - out[row_i] = r * value - - -@ti.kernel -def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += value * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += vector[col_indices[j]] - j += 32 - out[row_i] += value * r - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += values[j] * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += values[j] * vector[col_indices[j]] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_transpose( - ct, data, indices, indptr, vector, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_vector = csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] - - return ct_data, indices, indptr, vector - - -def csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -) -> jax.Array: - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - - data = jnp.atleast_1d(as_jax(data)) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) - - if vector.dtype == jnp.bool_: - vector = as_jax(vector, dtype=data.dtype) - - if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]: - raise TypeError('Only support float16, float32 or float64 type. ' - f'But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {vector.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - out_shape = shape[1] if transpose else shape[0] - - if transpose: - if data.shape[0] == 1: - prim = _csr_matvec_transpose_homo_p - else: - prim = _csr_matvec_transpose_heter_p - else: - if data.shape[0] == 1: - prim = _csr_matvec_homo_p - else: - prim = _csr_matvec_heter_p - - return prim(data, - indices, - indptr, - vector, - outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) - prim.def_transpose_rule(_sparse_csr_matvec_transpose) - return prim - - -# transpose homo -_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) - -# no transpose homo -_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, - gpu_kernel=_sparse_csr_matvec_homo_gpu) - -# transpose heter -_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) - -# no transpose heter -_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 16bf43a48..2c75f0901 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -3,24 +3,60 @@ from functools import partial import jax -import pytest from absl.testing import parameterized -import platform + import brainpy as bp import brainpy.math as bm -is_manual_test = False -if platform.system() == 'Windows' and not is_manual_test: - pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) +# bm.set_platform('gpu') + +seed = 1234 + + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + + +def compare_with_nan_tolerance(a, b, tol=1e-8): + """ + Compare two arrays with tolerance for NaN values. + + Parameters: + a (np.array): First array to compare. + b (np.array): Second array to compare. + tol (float): Tolerance for comparing non-NaN elements. + + Returns: + bool: True if arrays are similar within the tolerance, False otherwise. + """ + if a.shape != b.shape: + return False + + # Create masks for NaNs in both arrays + nan_mask_a = bm.isnan(a) + nan_mask_b = bm.isnan(b) + + # Check if NaN positions are the same in both arrays + if not bm.array_equal(nan_mask_a, nan_mask_b): + return False + + # Compare non-NaN elements + a_non_nan = a[~nan_mask_a] + b_non_nan = b[~nan_mask_b] -cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') -scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') -vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') + return bm.allclose(a_non_nan, b_non_nan, atol=tol) -class Test_cusparse_csrmv(parameterized.TestCase): +taichi_csr_matvec = bm.sparse.csrmv + +class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): - super(Test_cusparse_csrmv, self).__init__(*args, **kwargs) + super(Test_csrmv_taichi, self).__init__(*args, **kwargs) print() bm.set_platform(platform) @@ -31,35 +67,36 @@ def __init__(self, *args, platform='cpu', **kwargs): homo_data=[-1., 0., 1.] ) def test_homo(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') + conn = bp.conn.FixedProb(0.3) + # matrix indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) - - heter_data = bm.ones(indices.shape).value * homo_data - + # vector + rng = bm.random.RandomState(seed=seed) vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) + + heter_data = bm.ones(indices.shape).value * homo_data dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r3)) + r1 = (vector @ dense) if transpose else (dense @ vector) + r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], v=[-1., 0., 1.] ) def test_homo_vmap(self, transpose, shape, v): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -71,17 +108,13 @@ def test_homo_vmap(self, transpose, shape, v): homo_data = bm.ones(10).value * v dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) + f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(heter_data) + r1 = jax.vmap(f1)(dense_data) + r2 = jax.vmap(f2)(homo_data) self.assertTrue(bm.allclose(r1, r2)) - r3 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r3)) - bm.clear_buffer_memory() @parameterized.product( @@ -90,8 +123,9 @@ def test_homo_vmap(self, transpose, shape, v): homo_data=[-1., 0., 1.] ) def test_homo_grad(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -103,37 +137,35 @@ def test_homo_grad(self, transpose, shape, homo_data): vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, transpose=transpose).sum(), - argnums=0) + # print('grad data start') + # grad 'data' dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() if transpose else ((dense * a) @ vector).sum()), argnums=0) + r1 = dense_f1(homo_data) + r2 = jax.grad(sum_op(taichi_csr_matvec))( + homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r1 = csr_f1(homo_data) - r2 = dense_f1(homo_data) self.assertTrue(bm.allclose(r1, r2)) - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v, - shape=shape, transpose=transpose).sum()) + # print('grad vector start') + # grad 'vector' dense_data = dense * homo_data dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) + r3 = dense_f2(vector) + r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r3 = csr_f2(vector) - r4 = dense_f2(vector) self.assertTrue(bm.allclose(r3, r4)) - csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v, - shape=shape, transpose=transpose).sum(), - argnums=(0, 1)) dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() if transpose else ((dense * a) @ v).sum()), argnums=(0, 1)) - - r5 = csr_f3(homo_data, vector) - r6 = dense_f3(homo_data, vector) + r5 = dense_f3(homo_data, vector) + r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) @@ -141,26 +173,28 @@ def test_homo_grad(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + shape=[(200, 200), (200, 100), (2, 2000)], ) def test_heter(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + print(f'test_homo: transpose = {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) + heter_data = bm.as_jax(rng.random(indices.shape)) heter_data = bm.as_jax(heter_data) vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector, - shape=shape, transpose=transpose) + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r2 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r2)) + r1 = (vector @ dense) if transpose else (dense @ vector) + r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + + self.assertTrue(compare_with_nan_tolerance(r1, r2)) bm.clear_buffer_memory() @@ -169,8 +203,8 @@ def test_heter(self, transpose, shape): shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] ) def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -183,23 +217,20 @@ def test_heter_vmap(self, transpose, shape): dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) + f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(heter_data) - r2 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() + r1 = jax.vmap(f1)(dense_data) + r2 = jax.vmap(f2)(heter_data) + self.assertTrue(compare_with_nan_tolerance(r1, r2)) @parameterized.product( transpose=[True, False], shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] ) def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) indices, indptr = conn(*shape).require('pre2post') indices = bm.as_jax(indices) @@ -210,141 +241,29 @@ def test_heter_grad(self, transpose, shape): vector = rng.random(shape[0] if transpose else shape[1]) vector = bm.as_jax(vector) - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, + # grad 'data' + dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), + argnums=0) + csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, shape=shape, transpose=transpose).sum(), argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - r1 = csr_f1(heter_data) r2 = dense_f1(dense_data) rows, cols = bm.sparse.csr_to_coo(indices, indptr) r2 = r2[rows, cols] + print(r1.shape, r2.shape) self.assertTrue(bm.allclose(r1, r2)) - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) + # grad 'vector' dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) - r3 = csr_f2(vector) - r4 = dense_f2(vector) + csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v, + shape=shape, + transpose=transpose).sum(), + argnums=0) + r3 = dense_f2(vector) + r4 = csr_f2(vector) self.assertTrue(bm.allclose(r3, r4)) bm.clear_buffer_memory() - - -class Test_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - homo_data=[-1., 0., 0.1, 1.], - shape=[(100, 200), (10, 1000), (2, 2000)], - ) - def test_homo(self, shape, homo_data): - conn = bp.conn.FixedProb(0.1) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(123) - vector = rng.random(shape[1]) - vector = bm.as_jax(vector) - - # csrmv - r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - - heter_data = bm.ones(indices.shape).to_jax() * homo_data - r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r4)) - self.assertTrue(bm.allclose(r1, r5)) - self.assertTrue(bm.allclose(r1, r6)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - rdense = dense @ vector - self.assertTrue(bm.allclose(r1, rdense)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = bm.as_jax(rng.random(indices.shape)) - vector = bm.as_jax(rng.random(shape[1])) - - r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = dense @ vector - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - heter_data = rng.random(indices.shape) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[1]) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - dense_f1 = jax.grad(lambda a: (a @ vector).sum()) - - r1 = csr_f1(heter_data) - r2 = csr_f2(heter_data) - r3 = csr_f3(heter_data) - - d1 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - d1 = d1[rows, cols] - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, d1)) - - # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum()) - # r4 = csr_f4(vector) - # r5 = csr_f5(vector) - # r6 = csr_f6(vector) - # d2 = dense_f2(vector) - # self.assertTrue(bm.allclose(r4, r5)) - # self.assertTrue(bm.allclose(r4, r6)) - # self.assertTrue(bm.allclose(r4, d2)) - - bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py b/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py deleted file mode 100644 index ccf090ec4..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv_gpu.py +++ /dev/null @@ -1,21 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_csrmv - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_cusparse_csrmv_GPU(test_csrmv.Test_cusparse_csrmv): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, platform='gpu') - - -class Test__csrmv_GPU(test_csrmv.Test_csrmv): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs, platform='gpu') - - diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py new file mode 100644 index 000000000..b73217496 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/test_csrmv_old.py @@ -0,0 +1,352 @@ +# -*- coding: utf-8 -*- + +from functools import partial + +import jax +import pytest +from absl.testing import parameterized +import platform +import brainpy as bp +import brainpy.math as bm + +pytest.skip('Old implementation.', allow_module_level=True) + +is_manual_test = False +# if platform.system() == 'Windows' and not is_manual_test: +# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + +cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') +scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') +vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') + + +class Test_cusparse_csrmv(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_cusparse_csrmv, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + homo_data=[-1., 0., 1.] + ) + def test_homo(self, transpose, shape, homo_data): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + + heter_data = bm.ones(indices.shape).value * homo_data + + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r3 = (vector @ dense) if transpose else (dense @ vector) + self.assertTrue(bm.allclose(r1, r3)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + v=[-1., 0., 1.] + ) + def test_homo_vmap(self, transpose, shape, v): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + heter_data = bm.ones((10, indices.shape[0])).value * v + homo_data = bm.ones(10).value * v + dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) + + f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + shape=shape, transpose=transpose) + f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) + + r1 = jax.vmap(f1)(homo_data) + r2 = jax.vmap(f1)(heter_data) + self.assertTrue(bm.allclose(r1, r2)) + + r3 = jax.vmap(f2)(dense_data) + self.assertTrue(bm.allclose(r1, r3)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + homo_data=[-1., 0., 1.] + ) + def test_homo_grad(self, transpose, shape, homo_data): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=shape) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, + shape=shape, transpose=transpose).sum(), + argnums=0) + dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() + if transpose else + ((dense * a) @ vector).sum()), + argnums=0) + + r1 = csr_f1(homo_data) + r2 = dense_f1(homo_data) + self.assertTrue(bm.allclose(r1, r2)) + + csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v, + shape=shape, transpose=transpose).sum()) + dense_data = dense * homo_data + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) + + r3 = csr_f2(vector) + r4 = dense_f2(vector) + self.assertTrue(bm.allclose(r3, r4)) + + csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v, + shape=shape, transpose=transpose).sum(), + argnums=(0, 1)) + dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() + if transpose else + ((dense * a) @ v).sum()), + argnums=(0, 1)) + + r5 = csr_f3(homo_data, vector) + r6 = dense_f3(homo_data, vector) + self.assertTrue(bm.allclose(r5[0], r6[0])) + self.assertTrue(bm.allclose(r5[1], r6[1])) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], + ) + def test_heter(self, transpose, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + + heter_data = rng.random(indices.shape) + heter_data = bm.as_jax(heter_data) + + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector, + shape=shape, transpose=transpose) + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r2 = (vector @ dense) if transpose else (dense @ vector) + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter_vmap(self, transpose, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + heter_data = rng.random((10, indices.shape[0])) + heter_data = bm.as_jax(heter_data) + dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, + shape=shape))(heter_data) + + f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, + shape=shape, transpose=transpose) + f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) + + r1 = jax.vmap(f1)(heter_data) + r2 = jax.vmap(f2)(dense_data) + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter_grad(self, transpose, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + heter_data = rng.random(indices.shape) + heter_data = bm.as_jax(heter_data) + dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, + shape=shape, + transpose=transpose).sum(), + argnums=0) + dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), + argnums=0) + + r1 = csr_f1(heter_data) + r2 = dense_f1(dense_data) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + r2 = r2[rows, cols] + self.assertTrue(bm.allclose(r1, r2)) + + csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, + shape=shape, + transpose=transpose).sum(), + argnums=0) + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), + argnums=0) + r3 = csr_f2(vector) + r4 = dense_f2(vector) + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + +class Test_csrmv(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_csrmv, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + homo_data=[-1., 0., 0.1, 1.], + shape=[(100, 200), (10, 1000), (2, 2000)], + ) + def test_homo(self, shape, homo_data): + conn = bp.conn.FixedProb(0.1) + + # matrix + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # vector + rng = bm.random.RandomState(123) + vector = rng.random(shape[1]) + vector = bm.as_jax(vector) + + # csrmv + r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape) + r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape) + r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, r3)) + + heter_data = bm.ones(indices.shape).to_jax() * homo_data + r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + self.assertTrue(bm.allclose(r1, r4)) + self.assertTrue(bm.allclose(r1, r5)) + self.assertTrue(bm.allclose(r1, r6)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + rdense = dense @ vector + self.assertTrue(bm.allclose(r1, rdense)) + + bm.clear_buffer_memory() + + @parameterized.product( + shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter(self, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + heter_data = bm.as_jax(rng.random(indices.shape)) + vector = bm.as_jax(rng.random(shape[1])) + + r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + r4 = dense @ vector + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, r3)) + self.assertTrue(bm.allclose(r1, r4)) + + bm.clear_buffer_memory() + + @parameterized.product( + shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + ) + def test_heter_grad(self, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + heter_data = rng.random(indices.shape) + dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + vector = rng.random(shape[1]) + + csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) + csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) + csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) + dense_f1 = jax.grad(lambda a: (a @ vector).sum()) + + r1 = csr_f1(heter_data) + r2 = csr_f2(heter_data) + r3 = csr_f3(heter_data) + + d1 = dense_f1(dense_data) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + d1 = d1[rows, cols] + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, r3)) + self.assertTrue(bm.allclose(r1, d1)) + + # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) + # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) + # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) + # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum()) + # r4 = csr_f4(vector) + # r5 = csr_f5(vector) + # r6 = csr_f6(vector) + # d2 = dense_f2(vector) + # self.assertTrue(bm.allclose(r4, r5)) + # self.assertTrue(bm.allclose(r4, r6)) + # self.assertTrue(bm.allclose(r4, d2)) + + bm.clear_buffer_memory() + + diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py deleted file mode 100644 index 2b3d7b5b0..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py +++ /dev/null @@ -1,488 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial - -import jax -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -# bm.set_platform('gpu') - -seed = 1234 - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -def sum_op2(op): - def func(*args, **kwargs): - r = op(*args, **kwargs)[0] - return r.sum() - - return func - - -def compare_with_nan_tolerance(a, b, tol=1e-8): - """ - Compare two arrays with tolerance for NaN values. - - Parameters: - a (np.array): First array to compare. - b (np.array): Second array to compare. - tol (float): Tolerance for comparing non-NaN elements. - - Returns: - bool: True if arrays are similar within the tolerance, False otherwise. - """ - if a.shape != b.shape: - return False - - # Create masks for NaNs in both arrays - nan_mask_a = bm.isnan(a) - nan_mask_b = bm.isnan(b) - - # Check if NaN positions are the same in both arrays - if not bm.array_equal(nan_mask_a, nan_mask_b): - return False - - # Compare non-NaN elements - a_non_nan = a[~nan_mask_a] - b_non_nan = b[~nan_mask_b] - - return bm.allclose(a_non_nan, b_non_nan, atol=tol) - - -vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') - - -### MANUAL TESTS ### -# transposes = [True, False] -# homo_datas = [-1., 0., 0.1, 1.] -# shapes = [(100, 200), (10, 1000), (2, 2000)] -# -# -# def test_homo(transpose, shape, homo_data): -# print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') -# conn = bp.conn.FixedProb(0.1) -# -# # matrix -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# # vector -# rng = bm.random.RandomState(123) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# assert (bm.allclose(r1, r2[0])) -# -# bm.clear_buffer_memory() -# -# -# def test_homo_vmap(transpose, shape, homo_data): -# print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# heter_data = bm.ones((10, indices.shape[0])).value * homo_data -# homo_data = bm.ones(10).value * homo_data -# dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) -# -# f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, -# shape=shape, transpose=transpose) -# f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector, -# shape=shape, transpose=transpose) -# r1 = jax.vmap(f1)(homo_data) -# r2 = jax.vmap(f1)(homo_data) -# assert (bm.allclose(r1, r2[0])) -# -# bm.clear_buffer_memory() -# -# -# def test_homo_grad(transpose, shape, homo_data): -# print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, -# indices, -# indptr, -# shape=shape) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# # print('grad data start') -# # grad 'data' -# r1 = jax.grad(sum_op(vector_csr_matvec))( -# homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( -# homo_data, indices, indptr, vector, shape=shape, transpose=transpose) -# -# # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, -# # shape=shape, transpose=transpose).sum(), -# # argnums=0) -# # csr_f2 = jax.grad(lambda a: bm.sparse.csrmv_taichi(a, indices, indptr, vector, -# # shape=shape, transpose=transpose)[0].sum(), -# # argnums=0) -# # r1 = csr_f1(homo_data) -# # r2 = csr_f2(homo_data) -# assert (bm.allclose(r1, r2)) -# -# # print('grad vector start') -# # grad 'vector' -# r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( -# homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( -# homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# # csr_f3 = jax.grad(lambda v: vector_csr_matvec(homo_data, indices, indptr, v, -# # shape=shape, transpose=transpose).sum()) -# # csr_f4 = jax.grad(lambda v: bm.sparse.csrmv_taichi(homo_data, indices, indptr, v, -# # shape=shape, transpose=transpose)[0].sum()) -# # r3 = csr_f3(vector) -# # r4 = csr_f4(vector) -# assert (bm.allclose(r3, r4)) -# -# # csr_f5 = jax.grad(lambda a, v: vector_csr_matvec(a, indices, indptr, v, -# # shape=shape, transpose=transpose).sum(), -# # argnums=(0, 1)) -# # csr_f6 = jax.grad(lambda a, v: bm.sparse.csrmv_taichi(a, indices, indptr, v, -# # shape=shape, transpose=transpose)[0].sum(), -# # argnums=(0, 1)) -# # r5 = csr_f5(homo_data, vector) -# # r6 = csr_f6(homo_data, vector) -# # assert(bm.allclose(r5[0], r6[0])) -# # assert(bm.allclose(r5[1], r6[1])) -# -# bm.clear_buffer_memory() -# -# -# def test_heter(transpose, shape): -# print(f'test_heter: transpose = {transpose} shape = {shape}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# heter_data = bm.as_jax(rng.random(indices.shape)) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) -# r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape) -# # bm.nan_to_num(r1) -# # bm.nan_to_num(r2[0]) -# # print(r1) -# # print(r1 - r2[0]) -# assert (compare_with_nan_tolerance(r1, r2[0])) -# -# bm.clear_buffer_memory() -# -# -# def test_heter_vmap(transpose, shape): -# print(f'test_heter_vmap: transpose = {transpose} shape = {shape}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# heter_data = rng.random((10, indices.shape[0])) -# heter_data = bm.as_jax(heter_data) -# dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, -# shape=shape))(heter_data) -# -# f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, -# shape=shape, transpose=transpose) -# f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector, -# shape=shape, transpose=transpose) -# r1 = jax.vmap(f1)(heter_data) -# r2 = jax.vmap(f2)(heter_data) -# assert (bm.allclose(r1, r2[0])) -# -# -# def test_heter_grad(transpose, shape): -# print(f'test_heter_grad: transpose = {transpose} shape = {shape}') -# rng = bm.random.RandomState() -# conn = bp.conn.FixedProb(0.1) -# -# indices, indptr = conn(*shape).require('pre2post') -# indices = bm.as_jax(indices) -# indptr = bm.as_jax(indptr) -# heter_data = rng.random(indices.shape) -# heter_data = bm.as_jax(heter_data) -# dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) -# vector = rng.random(shape[0] if transpose else shape[1]) -# vector = bm.as_jax(vector) -# -# # grad 'data' -# r1 = jax.grad(sum_op(vector_csr_matvec))( -# heter_data, indices, indptr, vector, shape=shape, transpose=transpose) -# r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( -# heter_data, indices, indptr, vector, shape=shape, transpose=transpose) -# assert (bm.allclose(r1, r2)) -# -# # grad 'vector' -# r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# assert (bm.allclose(r3, r4)) -# -# r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))( -# heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) -# assert (bm.allclose(r5[0], r6[0])) -# assert (bm.allclose(r5[1], r6[1])) -# -# bm.clear_buffer_memory() -# -# def test_all(): -# # for transpose in transposes: -# # for shape in shapes: -# # for homo_data in homo_datas: -# # test_homo(transpose, shape, homo_data) -# # test_homo_vmap(transpose, shape, homo_data) -# # test_homo_grad(transpose, shape, homo_data) -# -# for transpose in transposes: -# for shape in shapes: -# test_heter(transpose, shape) -# test_heter_vmap(transpose, shape) -# test_heter_grad(transpose, shape) -# test_all() - -# PYTEST -class Test_csrmv_taichi(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv_taichi, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') - conn = bp.conn.FixedProb(0.3) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(seed=seed) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2[0])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], - v=[-1., 0., 1.] - ) - def test_homo_vmap(self, transpose, shape, v): - print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = bm.ones((10, indices.shape[0])).value * v - homo_data = bm.ones(10).value * v - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - - f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(homo_data) - self.assertTrue(bm.allclose(r1, r2[0])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - # print('grad data start') - # grad 'data' - r1 = jax.grad(sum_op(vector_csr_matvec))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - - # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, - # shape=shape, transpose=transpose).sum(), - # argnums=0) - # csr_f2 = jax.grad(lambda a: bm.sparse.csrmv_taichi(a, indices, indptr, vector, - # shape=shape, transpose=transpose)[0].sum(), - # argnums=0) - # r1 = csr_f1(homo_data) - # r2 = csr_f2(homo_data) - self.assertTrue(bm.allclose(r1, r2)) - - # print('grad vector start') - # grad 'vector' - r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (2, 2000)], - ) - def test_heter(self, transpose, shape): - print(f'test_homo: transpose = {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.as_jax(rng.random(indices.shape)) - heter_data = bm.as_jax(heter_data) - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape) - - print(r1) - print(r2[0]) - - self.assertTrue(compare_with_nan_tolerance(r1, r2[0])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = rng.random((10, indices.shape[0])) - heter_data = bm.as_jax(heter_data) - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=shape))(heter_data) - - f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - r1 = jax.vmap(f1)(heter_data) - r2 = jax.vmap(f2)(heter_data) - self.assertTrue(compare_with_nan_tolerance(r1, r2[0])) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - # grad 'data' - r1 = jax.grad(sum_op(vector_csr_matvec))( - heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))( - heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad 'vector' - r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r3, r4)) - - r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 2e9f38039..0a17cae7c 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,6 +1,5 @@ from brainpy._src.math.event import ( csrmv as csrmv, - csrmv_taichi as csrmv_taichi, info as info, ) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 0ade274e6..90a028b7e 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -6,13 +6,5 @@ mv_prob_homo as mv_prob_homo, mv_prob_uniform as mv_prob_uniform, mv_prob_normal as mv_prob_normal, - - event_mv_prob_homo_taichi as event_mv_prob_homo_taichi, - event_mv_prob_uniform_taichi as event_mv_prob_uniform_taichi, - event_mv_prob_normal_taichi as event_mv_prob_normal_taichi, - - mv_prob_homo_taichi as mv_prob_homo_taichi, - mv_prob_uniform_taichi as mv_prob_uniform_taichi, - mv_prob_normal_taichi as mv_prob_normal_taichi ) diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 97c585746..1380a9e9c 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,6 +1,5 @@ from brainpy._src.math.sparse import ( csrmv, - csrmv_taichi, coomv, seg_matmul,