diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 95bd8eafd..d29b07ebc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -11,6 +11,12 @@ on: branches: - '**' # matches every branch + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + + #on: # push: # branches: [ master ] @@ -27,6 +33,10 @@ jobs: python-version: [ "3.9", "3.10", "3.11"] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -35,16 +45,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | cd brainpy @@ -82,40 +85,6 @@ jobs: pytest _src/ -# test_linux_py37: -# runs-on: ubuntu-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - test_macos: runs-on: macos-latest strategy: @@ -124,6 +93,10 @@ jobs: python-version: ["3.9", "3.10", "3.11"] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -132,16 +105,40 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 + - name: Test with pytest run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + cd brainpy + pytest -n auto --tb=short _src/ + + + test_windows: + strategy: + fail-fast: false + matrix: + os: [ win-2019-16core ] + arch: [ AMD64 ] + python-version: ["3.9", "3.10", "3.11"] + runs-on: ${{ matrix.os }} + + steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements-dev.txt + pip uninstall brainpy -y + python setup.py install - name: Test with pytest run: | cd brainpy @@ -178,104 +175,3 @@ jobs: cd brainpy pytest _src/ -# test_macos_py37: -# runs-on: macos-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: [ "3.7" ] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - - -# test_windows: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.9", "3.10", "3.11"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install -r requirements-dev.txt -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd brainpy -# pytest _src/ - - -# test_windows_py37: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install numpy>=1.21.0 -# python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver -# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz -# python -m pip install -r requirements-dev.txt -# python -m pip install tqdm brainpylib -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index c3936f685..6db945ff2 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -1,130 +1,124 @@ -# -*- coding: utf-8 -*- - -import pytest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm -from brainpy._src.dynold.synapses import abstract_models -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) - - -class Test_Abstract_Synapse(parameterized.TestCase): - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_all2all_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(bp.synapses, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_one2one_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - comp_type=['sparse', 'dense'], - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_sparse_synapse(self, comp_type, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - post_ref_key=[None, 'refractory'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_delta_synapse(self, post_ref_key, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5, ref_var=True) - post_neu = bp.neurons.LIF(3, ref_var=True) - syn = bp.synapses.Delta(pre_neu, post_neu, - conn=bp.conn.All2All(), - post_ref_key=post_ref_key, - stp=stp, ) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - pre_expected_shape = (100, 5) - post_expected_shape = (100, 3) - if isinstance(mode, bm.BatchingMode): - pre_expected_shape = (mode.batch_size,) + pre_expected_shape - post_expected_shape = (mode.batch_size,) + post_expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + + +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dynold.synapses import abstract_models + + +class Test_Abstract_Synapse(parameterized.TestCase): + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_all2all_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(bp.synapses, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, monitors=['pre.V', 'syn.g', 'post.V'], inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_one2one_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + comp_type=['sparse', 'dense'], + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_sparse_synapse(self, comp_type, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + post_ref_key=[None, 'refractory'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_delta_synapse(self, post_ref_key, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5, ref_var=True) + post_neu = bp.neurons.LIF(3, ref_var=True) + syn = bp.synapses.Delta(pre_neu, post_neu, + conn=bp.conn.All2All(), + post_ref_key=post_ref_key, + stp=stp, ) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + pre_expected_shape = (100, 5) + post_expected_shape = (100, 3) + if isinstance(mode, bm.BatchingMode): + pre_expected_shape = (mode.batch_size,) + pre_expected_shape + post_expected_shape = (mode.batch_size,) + post_expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) + bm.clear_buffer_memory() \ No newline at end of file diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 668f837c0..7827dfed3 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -2,6 +2,7 @@ import functools +import gc import inspect import os import re @@ -16,6 +17,7 @@ from . import modes from . import scales from . import defaults +from .object_transform import naming from brainpy._src.dependency_check import import_taichi ti = import_taichi(error_if_not_found=False) @@ -681,7 +683,9 @@ def set_host_device_count(n): def clear_buffer_memory( platform: str = None, array: bool = True, - compilation: bool = False + transform: bool = True, + compilation: bool = False, + object_name: bool = False, ): """Clear all on-device buffers. @@ -698,9 +702,13 @@ def clear_buffer_memory( platform: str The device to clear its memory. array: bool - Clear all buffer array. + Clear all buffer array. Default is True. compilation: bool - Clear compilation cache. + Clear compilation cache. Default is False. + transform: bool + Clear transform cache. Default is True. + object_name: bool + Clear name cache. Default is True. """ if array: @@ -708,6 +716,11 @@ def clear_buffer_memory( buf.delete() if compilation: jax.clear_caches() + if transform: + naming.clear_stack_cache() + if object_name: + naming.clear_name_cache() + gc.collect() def disable_gpu_memory_preallocation(release_memory: bool = True): diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 976b72b96..ac62bbfaf 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -45,9 +45,747 @@ def event_mv_prob_homo( if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +### BRAINPYLIB ### + +def event_mv_prob_homo_brainpylib( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + weight = jnp.atleast_1d(jnp.asarray(weight)) + conn_prob = jnp.atleast_1d(jnp.asarray(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + r = event_mv_prob_homo_p.bind(events, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + return r + + +event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform_brainpylib( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_prob = jnp.atleast_1d(as_jax(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + return event_mv_prob_uniform_p.bind(events, + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal_brainpylib( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_prob = jnp.atleast_1d(as_jax(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + return event_mv_prob_normal_p.bind(events, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ + + +def _event_matvec_prob_homo_abstract( + events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _event_matvec_prob_homo_cpu_translation( + c, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_homo' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type + + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + weight, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(weight), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_homo_gpu_translation( + c, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_homo_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1], ) + + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type + + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, weight, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(weight), + c.get_shape(clen), + c.get_shape(seed)), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_homo_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, weight, clen, seed = primals + event_dot, weight_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_homo_p.bind(events, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + assert type(weight_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + if type(weight_dot) is ad.Zero: + if type(event_dot) is ad.Zero: + raise ValueError + dr = mv_prob_homo_p.bind(event_dot, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + elif type(event_dot) is ad.Zero: + dr = mv_prob_homo_p.bind(events, + weight_dot, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + else: + dr = mv_prob_homo_p.bind(event_dot, + weight_dot, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, dr + + +def _event_matvec_prob_homo_transpose( + ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(weight) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_homo_p.bind(ct[0], + weight, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, weight, clen, seed + + +event_mv_prob_homo_p = Primitive('event_mv_prob_homo') +event_mv_prob_homo_p.multiple_results = True +event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) +event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation +ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp +ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose +register_general_batching(event_mv_prob_homo_p) + + +def _event_matvec_prob_uniform_abstract( + events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_low_dtype = _get_dtype(w_low) + _w_high_dtype = _get_dtype(w_low) + assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' + assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' + assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if w_low.ndim != 1: + raise ValueError('w_low must be a 1D scalar.') + if w_high.ndim != 1: + raise ValueError('w_high must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('clen must be a 1D scalar.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + + if not isinstance(transpose, bool): + raise ValueError('transpose must be a boolean value.') + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be a boolean value.') + assert w_low.dtype == w_high.dtype + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _event_matvec_prob_uniform_cpu_translation( + c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + w_low, + w_high, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_low), + c.get_shape(w_high), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_uniform_gpu_translation( + c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1]) + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, w_low, w_high, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_low), + c.get_shape(w_high), + c.get_shape(clen), + c.get_shape(seed),), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_uniform_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, w_low, w_high, clen, seed = primals + events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_uniform_p.bind(events, + w_low, + w_high, + clen, + seed, + shape=shape, + outdim_parallel=outdim_parallel, + transpose=transpose) + assert type(w_low_dot) is ad.Zero + assert type(w_high_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + r_dot = mv_prob_uniform_p.bind(events_dot, + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, r_dot + + +def _event_matvec_prob_uniform_transpose( + ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(w_low) is not ad.UndefinedPrimal + assert type(w_high) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_uniform_p.bind(ct[0], + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, w_low, w_high, clen, seed + + +event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') +event_mv_prob_uniform_p.multiple_results = True +event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) +event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation +register_general_batching(event_mv_prob_uniform_p) +ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp +ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose + + +def _event_matvec_prob_normal_abstract( + events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_mu_dtype = _get_dtype(w_mu) + _w_sigma_dtype = _get_dtype(w_sigma) + assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' + assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if w_mu.ndim != 1: + raise ValueError('w_mu should be a 1D scalar.') + if w_sigma.ndim != 1: + raise ValueError('w_sigma should be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('clen should be a 1D scalar.') + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + assert w_mu.dtype == w_sigma.dtype + + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be a boolean value.') + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be a boolean value.') + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _get_types(event_shape): + event_type = event_shape.element_type() + if event_type == jnp.bool_: + event_type = b'_bool' + out_dtype = dtypes.canonicalize_dtype(float) + elif event_type == jnp.float32: + event_type = b'_float' + out_dtype = event_shape.element_type() + elif event_type == jnp.float64: + event_type = b'_double' + out_dtype = event_shape.element_type() + else: + raise TypeError + + if out_dtype == jnp.float32: + type_name = b'_float' + elif out_dtype == jnp.float64: + type_name = b'_double' + else: + raise TypeError + + return out_dtype, event_type, type_name + + +def _event_matvec_prob_normal_cpu_translation( + c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_normal' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + w_mu, + w_sigma, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_mu), + c.get_shape(w_sigma), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_normal_gpu_translation( + c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_normal_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1]) + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, w_mu, w_sigma, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_mu), + c.get_shape(w_sigma), + c.get_shape(clen), + c.get_shape(seed)), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_normal_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, w_mu, w_sigma, clen, seed = primals + events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_normal_p.bind(events, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + assert type(w_mu_dot) is ad.Zero + assert type(w_sigma_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + r_dot = mv_prob_normal_p.bind(events_dot, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, r_dot + + +def _event_matvec_prob_normal_transpose( + ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(w_mu) is not ad.UndefinedPrimal + assert type(w_sigma) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_normal_p.bind(ct[0], + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, w_mu, w_sigma, clen, seed + + +event_mv_prob_normal_p = Primitive('event_mv_prob_normal') +event_mv_prob_normal_p.multiple_results = True +event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) +event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation +register_general_batching(event_mv_prob_normal_p) +ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp +ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose + + +### TAICHI ### + +def event_mv_prob_homo_taichi( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ events = as_jax(events) - if isinstance(weight, float): weight = as_jax(weight) - weight = jnp.atleast_1d(as_jax(weight)) + weight = as_jax(weight) + if jnp.ndim(weight) < 1: + weight = jnp.expand_dims(weight, axis=0) conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) if seed is None: diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1c8ca6ef9..6326929c4 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import gc import warnings from brainpy import errors @@ -11,6 +11,7 @@ _name2id = dict() _typed_names = {} +_fun2stack = dict() def check_name_uniqueness(name, obj): @@ -49,9 +50,6 @@ def clear_name_cache(ignore_warn=False): warnings.warn(f'All named models and their ids are cleared.', UserWarning) -_fun2stack = dict() - - def cache_stack(func, stack): _fun2stack[func] = stack @@ -59,6 +57,7 @@ def cache_stack(func, stack): def clear_stack_cache(): for k in tuple(_fun2stack.keys()): del _fun2stack[k] + gc.collect() def get_stack_cache(func):