
Conversation

@CharlieFRuan (Member) commented Dec 27, 2023

This PR allows dynamic shapes in some nn.Module components (Linear and Embedding), motivated by the dynamic vocab size discussed in mlc-ai/mlc-llm#1417.

UX-wise
We allow nn.Parameter("vocab_size", 4096), which is equivalent to nn.Parameter(tir.Var("vocab_size", "int64"), 4096).
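The shorthand can be understood as a normalization pass over the shape tuple: string entries become symbolic variables, integer entries stay concrete. A minimal pure-Python sketch (the `Var` class and `normalize_shape` helper are hypothetical stand-ins, not the actual TVM types):

```python
class Var:
    """Stand-in for tir.Var: a named symbolic dimension."""
    def __init__(self, name, dtype="int64"):
        self.name = name
        self.dtype = dtype

def normalize_shape(shape):
    """Turn each string entry into a symbolic Var; keep ints as-is."""
    return tuple(Var(dim) if isinstance(dim, str) else dim for dim in shape)

# nn.Parameter("vocab_size", 4096) ~ a parameter of shape ("vocab_size", 4096)
shape = normalize_shape(("vocab_size", 4096))
```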

Symbolic relationship across Parameters
We make sure that the same symbolic shape name used by different nn.Parameters maps to a single variable, rather than appearing as two distinct variables within the same method. For instance, the following would not happen:

```python
@R.function
def prefill(...):
    vocab_size = T.int64()   # for `embed_tokens`
    vocab_size1 = T.int64()  # for `lm_head`
```

This is prevented by str2var_params in Exporter, so that embed_tokens.shape[0] and lm_head.shape[0] share the same tir.Var.
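The deduplication amounts to interning variables by name: one shared mapping is consulted every time a symbolic dimension is created, so repeated names yield the same object. A hedged sketch of the idea (the `Var` class and `lookup_var` helper are illustrative, not the actual Exporter code):

```python
class Var:
    """Stand-in for tir.Var: a named symbolic dimension."""
    def __init__(self, name, dtype="int64"):
        self.name = name
        self.dtype = dtype

# Shared across all parameters of one exported module, like str2var_params.
str2var_params = {}

def lookup_var(name):
    """Return one shared Var per symbolic name, creating it on first use."""
    if name not in str2var_params:
        str2var_params[name] = Var(name)
    return str2var_params[name]

embed_dim = lookup_var("vocab_size")  # for embed_tokens.shape[0]
head_dim = lookup_var("vocab_size")   # for lm_head.shape[0]
```

Because both lookups return the identical object, the exported function sees a single `vocab_size` rather than `vocab_size` and `vocab_size1`.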

Symbolic vars shared across functions
We also prevent two relax.Functions from sharing the same tir.Var for a symbolic shape; e.g. in the example below, TIR variables A and B must not be the same underlying object in memory.

```python
@I.ir_module
class Module:

    @R.function
    def prefill(...):
        vocab_size = T.int64()  # <== A
        ...

    @R.function
    def decode(...):
        vocab_size = T.int64()  # <== B
        ...
```
We prevent this by reinitializing params = _params() and the name-to-var mapping before each call to _emit_method(), so that packed_params is different for each function.
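The per-function isolation is the mirror image of the per-method sharing above: the interning map is rebuilt for each function, so the same name produces distinct variables across functions. A pure-Python sketch (the `_params` factory here is a hypothetical simplification of the Exporter's behavior):

```python
class Var:
    """Stand-in for tir.Var: a named symbolic dimension."""
    def __init__(self, name):
        self.name = name

def _params():
    """Return a lookup over a fresh name->Var mapping, rebuilt per function."""
    mapping = {}
    def lookup(name):
        if name not in mapping:
            mapping[name] = Var(name)
        return mapping[name]
    return lookup

# One fresh mapping per relax.Function being emitted:
prefill_lookup = _params()
decode_lookup = _params()

a = prefill_lookup("vocab_size")  # var A in `prefill`
b = decode_lookup("vocab_size")   # var B in `decode`
```

Within one function the lookup still deduplicates, but across functions the two `vocab_size` vars are independent objects, matching the A/B distinction above.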

@CharlieFRuan force-pushed the pr-1226-symbolic-shape-param branch from 9614856 to 33f7dd7 on January 8, 2024
@tqchen tqchen merged commit 07d8e02 into apache:unity Jan 12, 2024