Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions brainpy/_src/math/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
# '''Default complex data type.'''
complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64

# register brainpy object as pytree
bp_object_as_pytree = False

if ti is not None:
# '''Default integer data type in Taichi.'''
ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32
Expand All @@ -40,3 +43,4 @@
else:
ti_int = None
ti_float = None

25 changes: 22 additions & 3 deletions brainpy/_src/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def __init__(
float_: type = None,
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -203,6 +204,10 @@ def __init__(
assert isinstance(complex_, type), '"complex_" must a type.'
self.old_complex = get_complex()

if bp_object_as_pytree is not None:
assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.'
self.old_bp_object_as_pytree = defaults.bp_object_as_pytree

self.dt = dt
self.mode = mode
self.membrane_scaling = membrane_scaling
Expand All @@ -211,6 +216,7 @@ def __init__(
self.float_ = float_
self.int_ = int_
self.bool_ = bool_
self.bp_object_as_pytree = bp_object_as_pytree

def __enter__(self) -> 'environment':
if self.dt is not None: set_dt(self.dt)
Expand All @@ -221,6 +227,7 @@ def __enter__(self) -> 'environment':
if self.int_ is not None: set_int(self.int_)
if self.complex_ is not None: set_complex(self.complex_)
if self.bool_ is not None: set_bool(self.bool_)
if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree
return self

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
Expand All @@ -232,6 +239,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.float_ is not None: set_float(self.old_float)
if self.complex_ is not None: set_complex(self.old_complex)
if self.bool_ is not None: set_bool(self.old_bool)
if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree

def clone(self):
return self.__class__(dt=self.dt,
Expand All @@ -241,7 +249,8 @@ def clone(self):
bool_=self.bool_,
complex_=self.complex_,
float_=self.float_,
int_=self.int_)
int_=self.int_,
bp_object_as_pytree=self.bp_object_as_pytree)

def __eq__(self, other):
return id(self) == id(other)
Expand Down Expand Up @@ -269,6 +278,7 @@ def __init__(
bool_: type = None,
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -277,7 +287,8 @@ def __init__(
int_=int_,
bool_=bool_,
membrane_scaling=membrane_scaling,
mode=modes.TrainingMode(batch_size))
mode=modes.TrainingMode(batch_size),
bp_object_as_pytree=bp_object_as_pytree)


class batching_environment(environment):
Expand All @@ -303,6 +314,7 @@ def __init__(
bool_: type = None,
batch_size: int = 1,
membrane_scaling: scales.Scaling = None,
bp_object_as_pytree: bool = None,
):
super().__init__(dt=dt,
x64=x64,
Expand All @@ -311,7 +323,8 @@ def __init__(
int_=int_,
bool_=bool_,
mode=modes.BatchingMode(batch_size),
membrane_scaling=membrane_scaling)
membrane_scaling=membrane_scaling,
bp_object_as_pytree=bp_object_as_pytree)


def set(
Expand All @@ -323,6 +336,7 @@ def set(
float_: type = None,
int_: type = None,
bool_: type = None,
bp_object_as_pytree: bool = None,
):
"""Set the default computation environment.

Expand All @@ -344,6 +358,8 @@ def set(
The integer data type.
bool_
The bool data type.
bp_object_as_pytree: bool
Whether to register brainpy object as pytree.
"""
if dt is not None:
assert isinstance(dt, float), '"dt" must a float.'
Expand Down Expand Up @@ -377,6 +393,9 @@ def set(
assert isinstance(complex_, type), '"complex_" must a type.'
set_complex(complex_)

if bp_object_as_pytree is not None:
defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree


set_environment = set

Expand Down
15 changes: 10 additions & 5 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,24 @@
"""

import numbers
import os
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional

import jax
import numpy as np
from jax._src.tree_util import _registry
from jax.tree_util import register_pytree_node_class

from brainpy import errors
from brainpy._src.math.modes import Mode
from brainpy._src.math.ndarray import (Array, )
from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector)
from brainpy._src.math.object_transform.naming import (get_unique_name,
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict)
from brainpy._src.math.modes import Mode
from brainpy._src.math.sharding import BATCH_AXIS
from brainpy._src.math import defaults

variable_ = None
StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
Expand Down Expand Up @@ -89,6 +90,10 @@ class BrainPyObject(object):
def __init__(self, name=None):
super().__init__()

if defaults.bp_object_as_pytree:
if self.__class__ not in _registry:
register_pytree_node_class(self.__class__)

# check whether the object has a unique name.
self._name = None
self._name = self.unique_name(name=name)
Expand Down Expand Up @@ -217,8 +222,8 @@ def tree_flatten(self):
static_names = []
static_values = []
for k, v in self.__dict__.items():
# if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)):
if isinstance(v, (BrainPyObject, Variable)):
if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)):
# if isinstance(v, (BrainPyObject, Variable)):
dynamic_names.append(k)
dynamic_values.append(v)
else:
Expand Down
19 changes: 19 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,22 @@ def f1():
self.assertTrue(obj.vs['b'] == 12.)
self.assertTrue(bm.allclose(obj.vs['c'], bm.ones(10) * 11.))


class TestRegisterBPObjectAsPyTree(unittest.TestCase):
def test1(self):
bm.set(bp_object_as_pytree=True)

hh = bp.dyn.HH(1)
hh.reset()

tree = jax.tree_structure(hh)
leaves = jax.tree_leaves(hh)

print(tree)
print(leaves)
print(jax.tree_unflatten(tree, leaves))
print()