
Conversation

@locnd182644
Contributor

Summary:

  • Support the median operator: add relax.median and hook it into the exported_program_translator
  • Input: Tensor, Axis, KeepDim
  • Output: (Values, Indices)
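For context, torch.median(x, dim=...) returns the "lower" median for even-length axes: the element at sorted position (n - 1) // 2, together with its index in the original tensor. The helper below is a hypothetical pure-Python sketch of that (values, indices) contract along a single axis, not the actual implementation.

```python
def median_with_index(values):
    """Return (median_value, original_index) using the lower-median rule."""
    order = sorted(range(len(values)), key=lambda i: values[i])  # argsort
    mid = (len(values) - 1) // 2  # lower median for even lengths
    idx = order[mid]              # index back into the input
    return values[idx], idx

print(median_with_index([3.0, 1.0, 2.0]))       # odd length: the true median
print(median_with_index([4.0, 1.0, 3.0, 2.0]))  # even length: the lower median
```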

Expected:

1. Axis = None, KeepDim = False

class MedianWithoutDim(nn.Module):
    def forward(self, x):
        return torch.median(x)
class Module:
    def main(x: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((), dtype="float32") = R.median(x, axis=None, keepdims=False)
            gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
            R.output(gv)
        return gv

2. Axis = 0, KeepDim = False

class MedianDim(nn.Module):
    def forward(self, x):
        return torch.median(x, dim=0)
class Module:
    def main(x: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="int64")):
        with R.dataflow():
            lv: R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="int64")) = R.median(x, axis=[0], keepdims=False)
            lv1: R.Tensor((3, 4), dtype="float32") = lv[0]
            lv2: R.Tensor((3, 4), dtype="int64") = lv[1]
            gv: R.Tuple(R.Tensor((3, 4), dtype="float32"), R.Tensor((3, 4), dtype="int64")) = (lv1, lv2)
            R.output(gv)
        return gv

3. Axis = -1, KeepDim = True

class MedianKeepDim(nn.Module):
    def forward(self, x):
        return torch.median(x, dim=-1, keepdim=True)
class Module:
    def main(x: R.Tensor((2, 3, 4), dtype="float32")) -> R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2, 3, 1), dtype="int64")):
        with R.dataflow():
            lv: R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2, 3, 1), dtype="int64")) = R.median(x, axis=[-1], keepdims=True)
            lv1: R.Tensor((2, 3, 1), dtype="float32") = lv[0]
            lv2: R.Tensor((2, 3, 1), dtype="int64") = lv[1]
            gv: R.Tuple(R.Tensor((2, 3, 1), dtype="float32"), R.Tensor((2, 3, 1), dtype="int64")) = (lv1, lv2)
            R.output(gv)
        return gv
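The output shapes expected in the three cases above can be sanity-checked against NumPy's reduction semantics. Note np.median averages the two middle elements on even-length axes (unlike torch's lower median), so this illustrative sketch compares only shapes, not values.

```python
import numpy as np

x = np.zeros((2, 3, 4), dtype="float32")

print(np.median(x).shape)                          # case 1: scalar, shape ()
print(np.median(x, axis=0).shape)                  # case 2: (3, 4)
print(np.median(x, axis=-1, keepdims=True).shape)  # case 3: (2, 3, 1)
```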

@gemini-code-assist
Contributor

Summary of Changes

Hello @locnd182644, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the TVM Relax framework by adding support for the median statistical operator. This addition improves compatibility with PyTorch models that utilize torch.median, ensuring that such operations can be seamlessly translated and optimized within Relax. The implementation covers the operator's definition, shape inference, and legalization into low-level Tensor Expression, alongside comprehensive testing to validate its correctness and integration.

Highlights

  • Median Operator Support: Introduced the relax.median operator, enabling the computation of median values across tensor elements, with support for specifying an axis and retaining dimensions.
  • PyTorch Integration: Integrated the new median operator into the exported_program_translator, allowing PyTorch models using torch.median to be correctly translated and executed within the Relax framework.
  • Comprehensive Implementation: Provided a full implementation including Python frontend, C++ backend with struct info inference, Tensor Expression (TE) legalization, and TVMScript parser support for the median operator.
  • Output Format: The median operator now returns a tuple containing both the median values and their corresponding indices, aligning with common statistical library behaviors.
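The sort-based lowering described above (argsort, then gather at the middle position) can be modeled in NumPy. This is an illustrative sketch with hypothetical names, not the actual TE legalization code; it mirrors the same argsort + gather scheme for a single reduction axis.

```python
import numpy as np

def median_lowering(x, ax):
    """Model of the sort-based median lowering: (values, indices) along ax."""
    ax = ax % x.ndim                                   # normalize negative axes
    order = np.argsort(x, axis=ax, kind="stable")      # analogous to topi.argsort
    x_sorted = np.take_along_axis(x, order, axis=ax)   # analogous to topi.gather
    mid = (x.shape[ax] - 1) // 2                       # lower-median position
    shape = list(x.shape)
    shape[ax] = 1                                      # reduced axis becomes size 1
    idx = np.full(shape, mid, dtype="int64")
    values = np.take_along_axis(x_sorted, idx, axis=ax)
    indices = np.take_along_axis(order, idx, axis=ax)  # index into the input
    return values, indices

v, i = median_lowering(np.array([[3.0, 1.0, 2.0], [6.0, 5.0, 4.0]]), ax=1)
print(v.tolist(), i.tolist())  # [[2.0], [5.0]] [[2], [1]]
```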



Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the relax.median operator, including its frontend integration for PyTorch, operator definition, legalization to TE, and associated tests. The implementation is comprehensive, but I've identified a critical correctness issue in the struct info inference logic for dynamic shapes, which could lead to incorrect program compilation. I've also found some issues in the tests that validate this buggy behavior, along with a couple of minor opportunities for code simplification and type hint improvements. Addressing these points will ensure the new operator is robust and maintainable.

@locnd182644
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the median operator in Relax, including frontend converters for PyTorch, the operator definition, struct info inference, and legalization to TE. The changes are well-structured and follow existing patterns in the codebase. I've found one correctness issue in the legalization logic for median when axis is None and keepdims is True, for which I've provided a suggestion. The rest of the implementation and the new tests look solid.

Comment on lines +61 to +80
    shape_prod = _compute_shape_prod(x, axis)
    mid_index = (shape_prod - 1) // 2

    if axis is None or len(axis) == 0:
        x = topi.reshape(x, [shape_prod.value])
        ax = -1
    else:
        ax = axis[0].value
    index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
    x_sorted = topi.gather(x, axis=ax, indices=index_sorted)

    new_shape = list(x.shape)
    new_shape[ax] = 1
    indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")

    median_val = topi.gather(x_sorted, axis=ax, indices=indices)
    median_idx = topi.gather(index_sorted, axis=ax, indices=indices)

    if axis is None or len(axis) == 0:
        return median_val if keepdims else topi.squeeze(median_val, axis=axis)
Contributor


Severity: high

When axis is None and keepdims is True, the output tensor should have the same rank as the input, with all dimensions of size 1. The current implementation returns a tensor of shape (1,) because the original rank of x is lost after it's reshaped.

To fix this, we should store the original rank of x before any modifications and use it to reshape median_val when keepdims is true and axis is None.

    orig_ndim = len(x.shape)
    shape_prod = _compute_shape_prod(x, axis)
    mid_index = (shape_prod - 1) // 2

    if axis is None or len(axis) == 0:
        x = topi.reshape(x, [shape_prod.value])
        ax = -1
    else:
        ax = axis[0].value
    index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
    x_sorted = topi.gather(x, axis=ax, indices=index_sorted)

    new_shape = list(x.shape)
    new_shape[ax] = 1
    indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")

    median_val = topi.gather(x_sorted, axis=ax, indices=indices)
    median_idx = topi.gather(index_sorted, axis=ax, indices=indices)

    if axis is None or len(axis) == 0:
        if keepdims:
            return topi.reshape(median_val, [1] * orig_ndim)
        return topi.squeeze(median_val, axis=axis)
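The reviewer's point can be checked against NumPy's reduction convention, which this lowering is presumably meant to match: with axis=None and keepdims=True, the result preserves the input rank as all-size-1 dimensions, not shape (1,). That is exactly what the suggested reshape to [1] * orig_ndim restores.

```python
import numpy as np

x = np.arange(24, dtype="float32").reshape(2, 3, 4)

# Full reduction with keepdims keeps the rank as size-1 dimensions.
out = np.median(x, axis=None, keepdims=True)
print(out.shape)  # (1, 1, 1), not (1,)
```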

Member

@tlopex tlopex left a comment


LGTM! Thanks!

@tlopex tlopex merged commit 899556d into apache:main Jan 2, 2026
10 checks passed