Skip to content

Conversation

@Lunderberg
Copy link
Contributor

Prior to this commit, calling FuseOpsByPattern with annotate_codegen=True would cause an error when encountering a lambda function. This was caused by the CompositeFunctionAnnotator asserting that all relax::Function encountered must have the kComposite attribute. While this is true for all lambda functions produced by FuseOpsByPattern, the user may have defined other lambda functions as well.

This commit updates CompositeFunctionAnnotator to ignore lambda functions that do not have a kComposite attribute.

Prior to this commit, calling `FuseOpsByPattern` with
`annotate_codegen=True` would cause an error when encountering a
lambda function.  This was caused by the `CompositeFunctionAnnotator`
asserting that all `relax::Function` encountered must have the
`kComposite` attribute.  While this is true for all lambda functions
produced by `FuseOpsByPattern`, the user may have defined other lambda
functions as well.

This commit updates `CompositeFunctionAnnotator` to ignore lambda
functions that do not have a `kComposite` attribute.
Function f_inner = Downcast<Function>(ExprMutator::VisitExpr_(func_node));
auto composite_name = func_node->GetAttr<String>(attr::kComposite);

if (!func_node->GetAttr<String>(attr::kComposite)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are non-composite functions visited?

auto new_func = Downcast<Function>(VisitExpr(func));
here we only visit composite functions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PatternBasedPartitioner only visits non-composite functions, produces a composite function for each pattern match, and updates the non-composite function to call the newly-generated composite function. Afterwards, the call to CompositeFunctionAnnotator is called. This visits only non-composite functions, finds any relax-to-relax function calls, and asserts that the callee is composite.

The callee will be composite for every function call generated by PatternBasedPartitioner, but that doesn't guarantee that all relax-to-relax function calls have a composite callee. If the IRModule contains a relax-to-relax call prior to PatternBasedPartitioner, that callee may be non-composite. This IRModule would be entirely legal, but would trigger the assert in CompositeFunctionAnnotator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, the problem isn't with calls to inner functions as on line 1224, but with calls to other functions within the IRModule.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the problem is if the callee is not a global var, the callee function will still be visited, so the fix makes sense to me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whoops, you're right on that one. It's if there is a inner function in the input IRModule. (Apologies, trying to track too many PRs at one time.)

@vinx13 vinx13 merged commit d91fe45 into apache:main Feb 21, 2024
@Lunderberg Lunderberg deleted the transform_ignore_non_composite_in_fuse_ops branch February 21, 2024 04:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants