Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions brainpy/_src/math/op_register/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .numba_approach import (CustomOpByNumba,
register_op_with_numba,
compile_cpu_signature_with_numba)
from .taichi_aot_based import clean_caches, check_kernels_count
from .base import XLACustomOp
from .utils import register_general_batching
4 changes: 3 additions & 1 deletion brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,)
register_taichi_gpu_translation_rule,
clean_caches)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

Expand Down Expand Up @@ -138,6 +139,7 @@ def __init__(
if transpose_translation is not None:
ad.primitive_transposes[self.primitive] = transpose_translation


def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
if outs is None:
outs = self.outs
Expand Down
41 changes: 38 additions & 3 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pathlib
import platform
import re
import shutil
from functools import partial, reduce
from typing import Any, Sequence

Expand Down Expand Up @@ -36,6 +37,34 @@ def encode_md5(source: str) -> str:

return md5.hexdigest()

# check kernels count
def check_kernels_count() -> int:
    """Count the cached Taichi AOT kernels stored under ``kernels_aot_path``.

    The cache is walked two levels deep: every directory directly under the
    cache root is opened and the entries it contains are counted.

    Returns:
      The total number of entries found inside the first-level
      sub-directories, or ``0`` when the cache folder does not exist.
    """
    if not os.path.exists(kernels_aot_path):
        return 0
    kernels_count = 0
    with os.scandir(kernels_aot_path) as entries:
        for entry in entries:
            # A stray regular file in the cache root holds no kernels, and
            # calling ``os.listdir`` on it would raise ``NotADirectoryError``.
            if entry.is_dir():
                kernels_count += len(os.listdir(entry.path))
    return kernels_count

# clean caches
def clean_caches(kernels_name: list[str] = None):
    """Delete cached Taichi AOT kernels from ``kernels_aot_path``.

    Args:
      kernels_name: Names of the kernel cache folders (relative to the cache
        root) to delete. A single string is treated as one name. When
        ``None`` (the default) the whole cache folder is removed.

    Raises:
      FileNotFoundError: If the cache folder does not exist when cleaning
        everything, or if one of the requested kernels has no cache folder.
      ValueError: If a name does not point strictly inside the cache folder
        (for example ``''``, ``'..'`` or an absolute path).
    """
    if kernels_name is None:
        if not os.path.exists(kernels_aot_path):
            # Adjacent literals instead of backslash continuations, so the
            # message text does not depend on how this code is indented.
            raise FileNotFoundError(
                "The kernels cache folder does not exist. "
                "Please define a kernel using `taichi.kernel` "
                "and customize the operator using `bm.XLACustomOp` "
                "before calling the operator."
            )
        shutil.rmtree(kernels_aot_path)
        print('Clean all kernel\'s cache successfully')
        return

    # A bare string would otherwise be iterated character by character.
    if isinstance(kernels_name, str):
        kernels_name = [kernels_name]

    cache_root = os.path.abspath(kernels_aot_path)
    for kernel_name in kernels_name:
        target = os.path.abspath(os.path.join(cache_root, kernel_name))
        # ``rmtree`` is destructive: refuse names that resolve to the cache
        # root itself or to a location outside of it.
        if target == cache_root or os.path.commonpath([cache_root, target]) != cache_root:
            raise ValueError(f'Invalid kernel name: {kernel_name!r}')
        try:
            shutil.rmtree(target)
        except FileNotFoundError as err:
            raise FileNotFoundError(f'Kernel {kernel_name} does not exist.') from err
    print('Clean kernel\'s cache successfully')

# TODO
# not a very good way
Expand Down Expand Up @@ -151,6 +180,9 @@ def _build_kernel(
if ti.lang.impl.current_cfg().arch != arch:
raise RuntimeError(f"Arch {arch} is not available")

# get kernel name
kernel_name = kernel.__name__

# replace the name of the func
kernel.__name__ = f'taichi_kernel_{device}'

Expand All @@ -170,6 +202,9 @@ def _build_kernel(
mod.add_kernel(kernel, template_args=template_args_dict)
mod.save(kernel_path)

# rename kernel name
kernel.__name__ = kernel_name


### KERNEL CALL PREPROCESS ###

Expand Down Expand Up @@ -246,7 +281,7 @@ def _preprocess_kernel_call_cpu(
return in_out_info


def preprocess_kernel_call_gpu(
def _preprocess_kernel_call_gpu(
source_md5_encode: str,
ins: dict,
outs: dict,
Expand Down Expand Up @@ -312,7 +347,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):

# kernel to code
codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
source_md5_encode = encode_md5(codes)
source_md5_encode = kernel.__name__ + '/' + encode_md5(codes)

# create ins, outs dict from kernel's args
in_num = len(ins)
Expand All @@ -332,7 +367,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs):
# returns
if platform in ['gpu', 'cuda']:
import_brainpylib_gpu_ops()
opaque = preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
return opaque
elif platform == 'cpu':
import_brainpylib_cpu_ops()
Expand Down
54 changes: 54 additions & 0 deletions brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import brainpy.math as bm
import jax
import jax.numpy as jnp
import platform
import pytest
import taichi

# Module-level guard: skip every test in this file unless the platform string
# starts with 'Windows'.
# NOTE(review): as written the tests run only on Windows — confirm that this is
# the intended direction of the check.
if not platform.platform().startswith('Windows'):
    pytest.skip(allow_module_level=True)

# Taichi helper: returns element 0 of the 1-D `weight` array as an f32.
@taichi.func
def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
    return weight[0]


# Taichi helper: adds `weight_val` in place to the entry of `out` at `index`.
@taichi.func
def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
    out[index] += weight_val

# CPU kernel: for every row `i` whose `vector[i]` is truthy, add the scalar
# `weight[0]` to `out` at each of the positions listed in `indices[i, :]`.
@taichi.kernel
def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
                  vector: taichi.types.ndarray(ndim=1),
                  weight: taichi.types.ndarray(ndim=1),
                  out: taichi.types.ndarray(ndim=1)):
    # Read the shared scalar weight once, before the loops.
    weight_val = get_weight(weight)
    num_rows, num_cols = indices.shape
    # Configures the loop that follows; `serialize=True` requests serial
    # execution of it (see the Taichi `loop_config` documentation).
    taichi.loop_config(serialize=True)
    for i in range(num_rows):
        # Only rows flagged in `vector` contribute to the output.
        if vector[i]:
            for j in range(num_cols):
                update_output(out, indices[i, j], weight_val)

# Custom operator under test; only a CPU kernel is registered here.
prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)

def test_taichi_clean_cache():
    """Run the custom op so a kernel gets cached, then check the cache can be wiped."""
    s = 1000
    indices = bm.random.randint(0, s, (s, 1000))
    vector = bm.random.rand(s) < 0.1
    weight = bm.array([1.0])

    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

    # Same call again (presumably to exercise the already-built kernel — confirm).
    out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

    print(out)
    bm.clear_buffer_memory()

    # Running the operator must have left at least one kernel in the cache.
    kernels_before = bm.check_kernels_count()
    print('kernels: ', kernels_before)
    assert kernels_before > 0

    bm.clean_caches()

    # `check_kernels_count` returns 0 once the cache folder has been removed.
    kernels_after = bm.check_kernels_count()
    print('kernels: ', kernels_after)
    assert kernels_after == 0

# test_taichi_clean_cache()
2 changes: 2 additions & 0 deletions brainpy/math/op_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from brainpy._src.math.op_register import (
CustomOpByNumba,
compile_cpu_signature_with_numba,
clean_caches,
check_kernels_count,
)

from brainpy._src.math.op_register.base import XLACustomOp
Expand Down
116 changes: 64 additions & 52 deletions docs/tutorial_advanced/operator_custom_with_taichi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
},
{
"cell_type": "markdown",
"source": [
"This functionality is only available for ``brainpylib>=0.2.0``. "
],
"metadata": {
"collapsed": false
}
},
"source": [
"This functionality is only available for ``brainpylib>=0.2.0``. "
]
},
{
"cell_type": "markdown",
Expand Down Expand Up @@ -182,26 +182,6 @@
" # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
" for i in range(n):\n",
" val[i] = i\n",
"```\n",
"\n",
"#### `ti.grouped`\n",
"Groups the indices in the iterator returned by ndrange() into a 1-D vector.\n",
"This is often used when you want to iterate over all indices returned by ndrange() in one for loop and a single index.\n",
"\n",
"Example:\n",
"\n",
"```python\n",
"# without ti.grouped\n",
"for I in ti.ndrange(2, 3):\n",
" print(I)\n",
"prints 0, 1, 2, 3, 4, 5\n",
"```\n",
"\n",
"```python\n",
"# with ti.grouped\n",
"for I in ti.grouped(ndrange(2, 3)):\n",
" print(I)\n",
"prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]\n",
"```"
]
},
Expand Down Expand Up @@ -251,11 +231,12 @@
" vector: ti.types.ndarray(ndim=1), \n",
" weight: ti.types.ndarray(ndim=1), \n",
" out: ti.types.ndarray(ndim=1)):\n",
" weight_0 = weight[0]\n",
" ti.loop_config(block_dim=64)\n",
" for ij in ti.grouped(indices):\n",
" if vector[ij[0]]:\n",
" out[ij[1]] += weight_0\n",
" weight_val = get_weight(weight)\n",
" num_rows, num_cols = indices.shape\n",
" for i in range(num_rows):\n",
" if vector[i]:\n",
" for j in range(num_cols):\n",
" update_output(out, indices[i, j], weight_val)\n",
"\n",
"prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)\n",
"\n",
Expand All @@ -276,6 +257,32 @@
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### More Examples\n",
"For more examples, please refer to: \n",
"- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec_taichi.py)\n",
"- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv_taichi.py)\n",
"- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec_taichi.py)\n",
"- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec_taichi.py)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Clean the cache of taichi kernels\n",
"Because BrainPy fuses Taichi and JAX through Taichi's AOT (ahead-of-time) method, the Taichi kernels are cached on the system. If you want to clean the cache, you can use the following code:\n",
"\n",
"```python\n",
"import brainpy.math as bm\n",
"\n",
"bm.clean_caches()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -442,28 +449,7 @@
" # If the kernel is run on the CUDA backend, each block will have 16 threads.\n",
" for i in range(n):\n",
" val[i] = i\n",
"```\n",
"\n",
"#### `ti.grouped`\n",
"\n",
"将由`ndrange()`返回的迭代器中的索引组合成一个一维向量。\n",
"这通常在你想要在一个 for 循环中迭代 ndrange() 返回的所有索引时使用,并且只使用一个索引。\n",
"\n",
"示例:\n",
"\n",
"```python\n",
"# without ti.grouped\n",
"for I in ti.ndrange(2, 3):\n",
" print(I)\n",
"prints 0, 1, 2, 3, 4, 5\n",
"```\n",
"\n",
"```python\n",
"# with ti.grouped\n",
"for I in ti.grouped(ndrange(2, 3)):\n",
" print(I)\n",
"prints [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]\n",
"```"
"```\n"
]
},
{
Expand Down Expand Up @@ -536,6 +522,32 @@
"test_taichi_op_register()\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 更多示例\n",
"对于更多示例, 请参考: \n",
"- [event/_csr_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/event/_csr_matvec_taichi.py)\n",
"- [sparse/_csr_mv_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/sparse/_csr_mv_taichi.py)\n",
"- [jitconn/_event_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_event_matvec_taichi.py)\n",
"- [jitconn/_matvec_taichi.py](https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/math/jitconn/_matvec_taichi.py)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 清除Taichi kernel的缓存\n",
"因为brainpy使用taichi的AOT方法来融合taichi和JAX,所以taichi的kernel会被缓存到系统中。如果你想清除缓存,可以使用以下代码:\n",
"\n",
"```python\n",
"import brainpy.math as bm\n",
"\n",
"bm.clean_caches()\n",
"```"
]
}
],
"metadata": {
Expand All @@ -554,7 +566,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down