diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 715c50ba7..19603f94c 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -828,7 +828,9 @@ def truncated_normal(self, lower, upper, size=None, loc=0., scale=1., dtype=floa # Uniformly fill tensor with values from [l, u], then translate to # [2l-1, 2u-1]. key = self.split_key() if key is None else _formalize_key(key) - out = jr.uniform(key, size, dtype, minval=2 * l - 1, maxval=2 * u - 1) + out = jr.uniform(key, size, dtype, + minval=lax.nextafter(2 * l - 1, np.array(np.inf, dtype=dtype)), + maxval=lax.nextafter(2 * u - 1, np.array(-np.inf, dtype=dtype))) # Use inverse cdf transform for normal distribution to get truncated # standard normal