Skip to content

Standardizing and generalizing object-oriented transformations#628

Merged
chaoming0625 merged 16 commits intomasterfrom
oo-transform
Feb 22, 2024
Merged

Standardizing and generalizing object-oriented transformations#628
chaoming0625 merged 16 commits intomasterfrom
oo-transform

Conversation

@chaoming0625
Copy link
Copy Markdown
Member

@chaoming0625 chaoming0625 commented Feb 21, 2024

This PR standardizes the customization of object-oriented transformations. The key is using brainpy.math.VariableStack and brainpy.math.eval_shape.

One OO transformation involves two steps. The first step is using brainpy.math.eval_shape to evaluate all Variables used in the target function. The second step is the actual compilation phase, to compile the model on the given target device.

For example, to customize an object-oriented JIT compilation interface, we can use:

import jax
import brainpy.math as bm


def jit(fun):
  stack: bm.VariableStack = None
  jit_fun = None

  def new_fun(vars, *args, **kwargs):
    for k, v in vars.items():
        stack[k].value = v
    ret = fun(*args, **kwargs)
    new_vars = stack.dict_data()
    return ret, new_vars

  def wrapper(*args, **kwargs):
    global stack, jit_fun

    # [first step]: find all the variables
    if stack is None:
      with bm.VariableStack() as stack:
        ret = bm.eval_shape(fun, *args, **kwargs)
        jit_fun = jax.jit(new_fun)
      if not stack.is_first_stack():
        return ret

    # [second step]: jit compilation
    ret, new_vars = jit_fun(stack.dict_data(), *args, **kwargs)
    stack.assign(new_vars)
    return ret

  return wrapper


@chaoming0625 chaoming0625 marked this pull request as ready for review February 22, 2024 05:06
@chaoming0625 chaoming0625 merged commit 4d74816 into master Feb 22, 2024
@chaoming0625 chaoming0625 deleted the oo-transform branch February 22, 2024 05:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant