diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py
index 6f2dbd4f2..01f77dbca 100644
--- a/brainpy/_src/math/op_register/__init__.py
+++ b/brainpy/_src/math/op_register/__init__.py
@@ -2,5 +2,6 @@
 from .numba_approach import (CustomOpByNumba,
                              register_op_with_numba,
                              compile_cpu_signature_with_numba)
+from .taichi_aot_based import clean_caches, check_kernels_count
 from .base import XLACustomOp
 from .utils import register_general_batching
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
index cb05ece81..bc5f4c15a 100644
--- a/brainpy/_src/math/op_register/base.py
+++ b/brainpy/_src/math/op_register/base.py
@@ -14,7 +14,8 @@
 # from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
 from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
 from .taichi_aot_based import (register_taichi_cpu_translation_rule,
-                               register_taichi_gpu_translation_rule,)
+                               register_taichi_gpu_translation_rule,
+                               clean_caches)
 from .utils import register_general_batching
 from brainpy._src.math.op_register.ad_support import defjvp
 
@@ -138,6 +139,7 @@ def __init__(
     if transpose_translation is not None:
       ad.primitive_transposes[self.primitive] = transpose_translation
 
+
   def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
     if outs is None:
       outs = self.outs
diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py
index ab7b98011..878b205cf 100644
--- a/brainpy/_src/math/op_register/taichi_aot_based.py
+++ b/brainpy/_src/math/op_register/taichi_aot_based.py
@@ -4,6 +4,7 @@
 import pathlib
 import platform
 import re
+import shutil
 from functools import partial, reduce
 from typing import Any, Sequence
 
@@ -36,6 +37,34 @@ def encode_md5(source: str) -> str:
   return md5.hexdigest()
 
 
+# check kernels count
+def check_kernels_count() -> int:
+  if not os.path.exists(kernels_aot_path):
+    return 0
+  kernels_count = 0
+  dir1 = os.listdir(kernels_aot_path)
+  for i in dir1:
+    dir2 = os.listdir(os.path.join(kernels_aot_path, i))
+    kernels_count += len(dir2)
+  return kernels_count
+
+# clean caches
+def clean_caches(kernels_name: list[str]=None):
+  if kernels_name is None:
+    if not os.path.exists(kernels_aot_path):
+      raise FileNotFoundError("The kernels cache folder does not exist. \
+                              Please define a kernel using `taichi.kernel` \
+                              and customize the operator using `bm.XLACustomOp` \
+                              before calling the operator.")
+    shutil.rmtree(kernels_aot_path)
+    print('Clean all kernel\'s cache successfully')
+    return
+  for kernel_name in kernels_name:
+    try:
+      shutil.rmtree(os.path.join(kernels_aot_path, kernel_name))
+    except FileNotFoundError:
+      raise FileNotFoundError(f'Kernel {kernel_name} does not exist.')
+  print('Clean kernel\'s cache successfully')
 
 # TODO
 # not a very good way
@@ -151,6 +180,9 @@ def _build_kernel(
   if ti.lang.impl.current_cfg().arch != arch:
     raise RuntimeError(f"Arch {arch} is not available")
 
+  # get kernel name
+  kernel_name = kernel.__name__
+
   # replace the name of the func
   kernel.__name__ = f'taichi_kernel_{device}'
 
@@ -170,6 +202,9 @@
   mod.add_kernel(kernel, template_args=template_args_dict)
   mod.save(kernel_path)
 
+  # restore the original kernel name
+  kernel.__name__ = kernel_name
+
 
 ### KERNEL CALL PREPROCESS ###
 
@@ -246,7 +281,7 @@ def _preprocess_kernel_call_cpu(
 
   return in_out_info
 
-def preprocess_kernel_call_gpu(
+def _preprocess_kernel_call_gpu(
     source_md5_encode: str,
     ins: dict,
     outs: dict,
@@ -312,7 +347,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
 
   # kernel to code
   codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
-  source_md5_encode = encode_md5(codes)
+  source_md5_encode = kernel.__name__ + '/' + encode_md5(codes)
 
   # create ins, outs dict from kernel's args
   in_num = len(ins)
@@ -332,7 +367,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
   # returns
   if platform in ['gpu', 'cuda']:
     import_brainpylib_gpu_ops()
-    opaque = preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
+    opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
     return opaque
   elif platform == 'cpu':
     import_brainpylib_cpu_ops()
diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
new file mode 100644
index 000000000..1bebcdafe
--- /dev/null
+++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
@@ -0,0 +1,54 @@
+import brainpy.math as bm
+import jax
+import jax.numpy as jnp
+import platform
+import pytest
+import taichi
+
+if not platform.platform().startswith('Windows'):
+  pytest.skip(allow_module_level=True)
+
+@taichi.func
+def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
+  return weight[0]
+
+
+@taichi.func
+def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
+  out[index] += weight_val
+
+@taichi.kernel
+def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
+                  vector: taichi.types.ndarray(ndim=1),
+                  weight: taichi.types.ndarray(ndim=1),
+                  out: taichi.types.ndarray(ndim=1)):
+  weight_val = get_weight(weight)
+  num_rows, num_cols = indices.shape
+  taichi.loop_config(serialize=True)
+  for i in range(num_rows):
+    if vector[i]:
+      for j in range(num_cols):
+        update_output(out, indices[i, j], weight_val)
+
+prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
+
+def test_taichi_clean_cache():
+  s = 1000
+  indices = bm.random.randint(0, s, (s, 1000))
+  vector = bm.random.rand(s) < 0.1
+  weight = bm.array([1.0])
+
+  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
+
+  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
+
+  print(out)
+  bm.clear_buffer_memory()
+
+  print('kernels: ', bm.check_kernels_count())
+
+  bm.clean_caches()
+
+  print('kernels: ', bm.check_kernels_count())
+
+# test_taichi_clean_cache()
\ No newline at end of file
diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py
index 014a54e6f..a48268ef4 100644
--- a/brainpy/math/op_register.py
+++ b/brainpy/math/op_register.py
@@ -4,6 +4,8 @@
 from brainpy._src.math.op_register import (
   CustomOpByNumba,
   compile_cpu_signature_with_numba,
+  clean_caches,
+  check_kernels_count,
 )
 
 from brainpy._src.math.op_register.base import XLACustomOp
diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
index 0443aed9d..c08cfdb2b 100644
--- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
@@ -9,12 +9,12 @@
   },
   {
    "cell_type": "markdown",
-   "source": [
-    "This functionality is only available for ``brainpylib>=0.2.0``. "
-   ],
    "metadata": {
     "collapsed": false
-   }
+   },
+   "source": [
+    "This functionality is only available for ``brainpylib>=0.2.0``. "
+   ]
   },
   {
    "cell_type": "markdown",
@@ -182,26 +182,6 @@
     "    # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
     "    for i in range(n):\n",
     "        val[i] = i\n",
-    "```\n",
-    "\n",
-    "#### `ti.grouped`\n",
-    "Groups the indices in the iterator returned by ndrange() into a 1-D vector.\n",
-    "This is often used when you want to iterate over all indices returned by ndrange() in one for loop and a single index.\n",
-    "\n",
-    "Example:\n",
-    "\n",
-    "```python\n",
-    "# without ti.grouped\n",
-    "for I in ti.ndrange(2, 3):\n",
-    "    print(I)\n",
-    "prints 0, 1, 2, 3, 4, 5\n",
-    "```\n",
-    "\n",
-    "```python\n",
-    "# with ti.grouped\n",
-    "for I in ti.grouped(ndrange(2, 3)):\n",
-    "    print(I)\n",
-    "prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]\n",
     "```"
    ]
   },
@@ -251,11 +231,12 @@
     "                  vector: ti.types.ndarray(ndim=1), \n",
     "                  weight: ti.types.ndarray(ndim=1), \n",
     "                  out: ti.types.ndarray(ndim=1)):\n",
-    "  weight_0 = weight[0]\n",
-    "  ti.loop_config(block_dim=64)\n",
-    "  for ij in ti.grouped(indices):\n",
-    "    if vector[ij[0]]:\n",
-    "      out[ij[1]] += weight_0\n",
+    "  weight_val = get_weight(weight)\n",
+    "  num_rows, num_cols = indices.shape\n",
+    "  for i in range(num_rows):\n",
+    "    if vector[i]:\n",
+    "      for j in range(num_cols):\n",
+    "        update_output(out, indices[i, j], weight_val)\n",
     "\n",
     "prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)\n",
     "\n",
@@ -276,6 +257,32 @@
     "```"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### More Examples\n",
+    "For more examples, please refer to: \n",
+    "- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec_taichi.py)\n",
+    "- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv_taichi.py)\n",
+    "- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec_taichi.py)\n",
+    "- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec_taichi.py)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Clean the cache of Taichi kernels\n",
+    "Because BrainPy fuses Taichi and JAX through the Taichi AOT method, the compiled Taichi kernels are cached on the system. If you want to clean the cache, you can use the following code:\n",
+    "\n",
+    "```python\n",
+    "import brainpy.math as bm\n",
+    "\n",
+    "bm.clean_caches()\n",
+    "```"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
@@ -442,28 +449,7 @@
     "    # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
     "    for i in range(n):\n",
     "        val[i] = i\n",
-    "```\n",
-    "\n",
-    "#### `ti.grouped`\n",
-    "\n",
-    "将由`ndrange()`返回的迭代器中的索引组合成一个一维向量。\n",
-    "这通常在你想要在一个 for 循环中迭代 ndrange() 返回的所有索引时使用,并且只使用一个索引。\n",
-    "\n",
-    "示例:\n",
-    "\n",
-    "```python\n",
-    "# without ti.grouped\n",
-    "for I in ti.ndrange(2, 3):\n",
-    "    print(I)\n",
-    "prints 0, 1, 2, 3, 4, 5\n",
-    "```\n",
-    "\n",
-    "```python\n",
-    "# with ti.grouped\n",
-    "for I in ti.grouped(ndrange(2, 3)):\n",
-    "    print(I)\n",
-    "prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]\n",
-    "```"
+    "```\n"
    ]
   },
   {
@@ -536,6 +522,32 @@
     "test_taichi_op_register()\n",
     "```"
    ]
-  }
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 更多示例\n",
+    "对于更多示例, 请参考: \n",
+    "- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec_taichi.py)\n",
+    "- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv_taichi.py)\n",
+    "- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec_taichi.py)\n",
+    "- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec_taichi.py)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 清除Taichi kernel的缓存\n",
+    "因为brainpy使用taichi的AOT方法来融合taichi和JAX,所以taichi的kernel会被缓存到系统中。如果你想清除缓存,可以使用以下代码:\n",
+    "\n",
+    "```python\n",
+    "import brainpy.math as bm\n",
+    "\n",
+    "bm.clean_caches()\n",
+    "```"
+   ]
+  }
  ],
  "metadata": {
@@ -554,7 +566,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython2",
-   "version": "2.7.6"
+   "version": "3.10.13"
   }
  },
 "nbformat": 4,