diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 746538169..62687f218 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -940,6 +940,8 @@ def scan( ): """``scan`` control flow with :py:class:`~.Variable`. + Similar to ``jax.lax.scan``. + .. versionadded:: 2.4.7 All returns in body function will be gathered @@ -999,7 +1001,7 @@ def scan( rets = jax.eval_shape(transform, init, operands) cache_stack(body_fun, dyn_vars) # cache if current_transform_number(): - return rets[1] + return rets[0][1], rets[1] del rets transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py index 658af8c6b..d8ff2282c 100644 --- a/brainpy/_src/math/object_transform/tests/test_controls.py +++ b/brainpy/_src/math/object_transform/tests/test_controls.py @@ -1,14 +1,11 @@ # -*- coding: utf-8 -*- -import sys import tempfile import unittest from functools import partial import jax -from jax import vmap - from absl.testing import parameterized -from jax._src import test_util as jtu +from jax import vmap import brainpy as bp import brainpy.math as bm @@ -147,6 +144,25 @@ def f(carray, x): expected = bm.expand_dims(expected, axis=-1) self.assertTrue(bm.allclose(outs, expected)) + def test2(self): + a = bm.Variable(1) + + def f(carray, x): + carray += x + a.value += 1. + return carray, a + + @bm.jit + def f_outer(carray, x): + carry, outs = bm.scan(f, carray, x, unroll=2) + return carry, outs + + carry, outs = f_outer(bm.zeros(2), bm.arange(10)) + self.assertTrue(bm.allclose(carry, 45.)) + expected = bm.arange(1, 11).astype(outs.dtype) + expected = bm.expand_dims(expected, axis=-1) + self.assertTrue(bm.allclose(outs, expected)) + class TestCond(unittest.TestCase): def test1(self): @@ -234,7 +250,6 @@ def F2(x): self.assertTrue(bm.grad(F2)(9.0) == 18.) self.assertTrue(bm.grad(F2)(11.0) == 1.) - def test_grad2(self): def F3(x): return bm.ifelse(conditions=(x >= 10, x >= 0), @@ -519,6 +534,3 @@ def body(a): file.seek(0) out6 = file.read().strip() self.assertTrue(out5 == out6) - - - diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 5ee94c615..754e0d81d 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -60,6 +60,7 @@ Object-oriented Transformations ifelse for_loop while_loop + scan jit cls_jit to_object