diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py index 64c7280d0..bfffd88f5 100644 --- a/brainpy/_src/math/surrogate/_one_input_new.py +++ b/brainpy/_src/math/surrogate/_one_input_new.py @@ -90,7 +90,30 @@ def _as_jax(x): class Surrogate(object): - """The base surrograte gradient function.""" + """The base surrograte gradient function. + + To customize a surrogate gradient function, you can inherit this class and + implement the `surrogate_fun` and `surrogate_grad` methods. + + Examples + -------- + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import jax.numpy as jnp + + >>> class MySurrogate(bm.Surrogate): + ... def __init__(self, alpha=1.): + ... super().__init__() + ... self.alpha = alpha + ... + ... def surrogate_fun(self, x): + ... return jnp.sin(x) * self.alpha + ... + ... def surrogate_grad(self, x): + ... return jnp.cos(x) * self.alpha + + """ def __call__(self, x): x = _as_jax(x) @@ -123,7 +146,7 @@ def __init__(self, alpha: float = 4.): self.alpha = alpha def surrogate_fun(self, x): - return sci.special.expit(x) + return sci.special.expit(self.alpha * x) def surrogate_grad(self, x): sgax = sci.special.expit(x * self.alpha) diff --git a/brainpy/math/surrogate.py b/brainpy/math/surrogate.py index 0121bddec..bf7897435 100644 --- a/brainpy/math/surrogate.py +++ b/brainpy/math/surrogate.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- - from brainpy._src.math.surrogate._one_input_new import ( + Surrogate, + Sigmoid, sigmoid as sigmoid,