python/tvm/relax/frontend/torch/base_fx_graph_translator.py (+7, -0)

@@ -1572,6 +1572,13 @@ def _mean(self, node: fx.Node) -> relax.Var:
        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))

    def _median(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        x = args[0]
        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
        return self.block_builder.emit(relax.op.median(x, dim, keepdims=keepdim))

    def _norm(self, node: fx.Node) -> relax.Var:
        data = self.env[node.args[0]]
        dtype = data.struct_info.dtype
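For context on what the new `_median` translator has to reproduce: `torch.median` without a `dim` returns a single scalar (the lower median of the flattened tensor), while `torch.median(x, dim)` returns a `(values, indices)` pair. A minimal sketch of that behavior (illustrative, not part of this diff):

    import torch

    x = torch.tensor([[3.0, 1.0, 2.0, 4.0],
                      [8.0, 5.0, 7.0, 6.0]])

    # No dim: a scalar, the lower median of the flattened tensor.
    print(torch.median(x))  # tensor(4.)

    # With dim: (values, indices); for an even-length axis PyTorch returns
    # the lower of the two middle elements, i.e. sorted position (n - 1) // 2.
    values, indices = torch.median(x, dim=-1)
    print(values)   # tensor([2., 6.])
    print(indices)  # tensor([2, 3])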
@@ -1384,6 +1384,8 @@ def create_convert_map(
"sum.dim_IntList": self._sum,
"var.correction": self._var,
"max.dim": self._max_dim,
"median.dim": self._median,
"median.default": self._median,
# search
"argmax.default": self._argmax_argmin(relax.op.argmax),
"argmin.default": self._argmax_argmin(relax.op.argmin),
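For reference (not part of this diff), the two new keys correspond to the ATen overloads that torch.export records as node targets; a quick way to inspect them, with `M` as an illustrative module:

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def forward(self, x):
            return x.median(-1)

    ep = export(M(), (torch.randn(4, 4),))
    for node in ep.graph.nodes:
        if node.op == "call_function":
            print(node.target)  # aten.median.dim, matched by the "median.dim" key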
python/tvm/relax/op/__init__.py (+1, -1)

@@ -119,7 +119,7 @@
from .search import argmax, argmin, where, bucketize
from .set import nonzero, unique
from .sorting import argsort, sort, topk
-from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance
+from .statistical import cumprod, cumsum, max, mean, min, prod, std, sum, variance, median
from .ternary import ewise_fma
from .unary import (
    abs,
python/tvm/relax/op/statistical.py (+27, -0)

@@ -341,3 +341,30 @@ def variance(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    if isinstance(axis, int):
        axis = [axis]
    return _ffi_api.variance(x, axis, keepdims)  # type: ignore


def median(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
    """Computes the median of tensor elements over given axes.

    Parameters
    ----------
    x : relax.Expr
        The input data tensor

    axis : Optional[Union[int, List[int]]]
        Axis along which the median is computed. The default (None) is to compute
        the median of the entire flattened tensor.

    keepdims : bool
        If this is set to True, the axes which are reduced are left in the result as
        dimensions with size one. With this option, the result will broadcast
        correctly against the input tensor.

    Returns
    -------
    result : relax.Expr
        The computed result.
    """
    if isinstance(axis, int):
        axis = [axis]
    return _ffi_api.median(x, axis, keepdims)  # type: ignore
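As a usage sketch (not part of this diff), the new wrapper can be emitted through the standard BlockBuilder API; with a single axis the inferred struct info is a (values, indices) tuple, per the C++ inference rule added below:

    from tvm import relax

    bb = relax.BlockBuilder()
    x = relax.Var("x", relax.TensorStructInfo((256, 256), "float32"))
    with bb.function("main", [x]):
        with bb.dataflow():
            # Single axis: struct info is a tuple of values and indices.
            med = bb.emit(relax.op.median(x, axis=1, keepdims=False))
            gv = bb.emit_output(med)
        bb.emit_func_output(gv)

    # Expected: R.Tuple(R.Tensor((256,), "float32"), R.Tensor((256,), "int64"))
    print(med.struct_info)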
python/tvm/relax/transform/legalize_ops/statistical.py (+46, -1)

@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Default legalization function for statistical operators."""
-from typing import List
+from typing import List, Union, Tuple
from tvm import topi, tir, te
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
@@ -53,6 +53,40 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor:
    # return _te_mean(x * x, axis, keepdims) - mean * mean


def _te_median(
    x: te.Tensor, axis: List[tir.IntImm], keepdims: bool
) -> Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]:
    # Currently supports only a single axis or no axis, matching PyTorch.
    # TODO: support multiple axes, matching NumPy.
    shape_prod = _compute_shape_prod(x, axis)
    mid_index = (shape_prod - 1) // 2

    if axis is None or len(axis) == 0:
        x = topi.reshape(x, [shape_prod.value])
        ax = -1
    else:
        ax = axis[0].value
    index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
    x_sorted = topi.gather(x, axis=ax, indices=index_sorted)

    new_shape = list(x.shape)
    new_shape[ax] = 1
    indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")

    median_val = topi.gather(x_sorted, axis=ax, indices=indices)
    median_idx = topi.gather(index_sorted, axis=ax, indices=indices)

    if axis is None or len(axis) == 0:
        return median_val if keepdims else topi.squeeze(median_val, axis=axis)
Review comment from a Contributor on lines +61 to +80 (severity: high):

When axis is None and keepdims is True, the output tensor should have the same rank as the input, with all dimensions of size 1. The current implementation returns a tensor of shape (1,) because the original rank of x is lost after it's reshaped.

To fix this, we should store the original rank of x before any modifications and use it to reshape median_val when keepdims is true and axis is None:
    orig_ndim = len(x.shape)
    shape_prod = _compute_shape_prod(x, axis)
    mid_index = (shape_prod - 1) // 2

    if axis is None or len(axis) == 0:
        x = topi.reshape(x, [shape_prod.value])
        ax = -1
    else:
        ax = axis[0].value
    index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
    x_sorted = topi.gather(x, axis=ax, indices=index_sorted)

    new_shape = list(x.shape)
    new_shape[ax] = 1
    indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")

    median_val = topi.gather(x_sorted, axis=ax, indices=indices)
    median_idx = topi.gather(index_sorted, axis=ax, indices=indices)

    if axis is None or len(axis) == 0:
        if keepdims:
            return topi.reshape(median_val, [1] * orig_ndim)
        return topi.squeeze(median_val, axis=axis)


    val = median_val
    idx = median_idx
    if not keepdims:
        val = topi.squeeze(val, axis=axis)
        idx = topi.squeeze(idx, axis=axis)
    return val, idx
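As a sanity reference (not part of this diff), the same argsort-and-gather selection can be modeled in NumPy; `median_lower` is an illustrative name:

    import numpy as np

    def median_lower(x: np.ndarray, axis: int):
        # Sort indices along the axis, then gather the element at sorted
        # position (n - 1) // 2, i.e. the lower median, as _te_median does.
        order = np.argsort(x, axis=axis)
        x_sorted = np.take_along_axis(x, order, axis=axis)
        take_shape = list(x.shape)
        take_shape[axis] = 1
        idx = np.full(take_shape, (x.shape[axis] - 1) // 2)
        val = np.take_along_axis(x_sorted, idx, axis=axis)
        ind = np.take_along_axis(order, idx, axis=axis)
        return np.squeeze(val, axis=axis), np.squeeze(ind, axis=axis)

    x = np.array([[3.0, 1.0, 2.0, 4.0],
                  [8.0, 5.0, 7.0, 6.0]])
    print(median_lower(x, axis=1))  # (array([2., 6.]), array([2, 3]))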


@register_legalize("relax.mean")
def _mean(bb: BlockBuilder, call: Call) -> Expr:
    return bb.call_te(

@@ -81,6 +115,17 @@ def _variance(bb: BlockBuilder, call: Call) -> Expr:

    )


@register_legalize("relax.median")
def _median(bb: BlockBuilder, call: Call) -> Expr:
return bb.call_te(
_te_median,
call.args[0],
call.attrs.axis,
call.attrs.keepdims,
primfunc_name_hint="median",
)


register_legalize("relax.max", _statistical(topi.max))
register_legalize("relax.min", _statistical(topi.min))
register_legalize("relax.prod", _statistical(topi.prod))
python/tvm/script/ir_builder/relax/ir.py (+2, -0)

@@ -128,6 +128,7 @@
    max,
    maximum,
    mean,
    median,
    memory,
    meshgrid,
    min,
@@ -874,6 +875,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
"max",
"maximum",
"mean",
"median",
"memory",
"meshgrid",
"metal",
src/relax/op/tensor/statistical.cc (+82, -0)

@@ -180,6 +180,68 @@ StructInfo InferStructInfoScan(const Call& call, const BlockBuilder& ctx) {
  }
}

StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuilder& ctx) {
  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
  const auto* attrs = call->attrs.as<StatisticalAttrs>();

  std::vector<int> axes;
  if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
    axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
  }

  int out_ndim;
  if (attrs->keepdims) {
    out_ndim = data_sinfo->ndim;
  } else if (!attrs->axis.defined()) {
    out_ndim = 0;
  } else if (data_sinfo->IsUnknownNdim()) {
    out_ndim = kUnknownNDim;
  } else {
    out_ndim = data_sinfo->ndim - axes.size();
    ICHECK_GE(out_ndim, 0);
  }

  // Shape inference rules for the median operator's output:
  // - axis is None, keepdims is false -> a zero-rank tensor;
  // - axis is None, keepdims is true -> a tensor whose ndim equals the input's,
  //   with every dimension of size 1;
  // - len(axes) == 1, keepdims is false -> the reduced axis is removed from the shape;
  // - len(axes) == 1, keepdims is true -> the shape keeps size 1 at the reduced axis.
  // A single reduced axis additionally yields a (values, indices) tuple; otherwise
  // only the values tensor is returned.
  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
  if (data_shape == nullptr) {
    if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
      return TensorStructInfo(
          ShapeExpr(ffi::Array<PrimExpr>(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
          data_sinfo->dtype, data_sinfo->vdevice);
    }
    if (out_ndim == 0) {
      return TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>()), data_sinfo->dtype,
                              data_sinfo->vdevice);
    }
    return TupleStructInfo({TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice),
                            TensorStructInfo(DataType::Int(64), out_ndim, data_sinfo->vdevice)});
  }

  ffi::Array<PrimExpr> out_shape;
  out_shape.reserve(out_ndim);
  for (int i = 0; i < data_sinfo->ndim; ++i) {
    if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
      out_shape.push_back(data_shape->values[i]);
    } else if (attrs->keepdims) {
      out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1));
    }
  }
  ICHECK_EQ(static_cast<int>(out_shape.size()), out_ndim);

  if (!attrs->axis.defined() || axes.size() > 1)
    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
  else
    return TupleStructInfo(
        {TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice),
         TensorStructInfo(ShapeExpr(out_shape), DataType::Int(64), data_sinfo->vdevice)});
}

/* relax.cumprod */
Expr cumprod(Expr data, ffi::Optional<int64_t> axis, ffi::Optional<DataType> dtype,
             Bool exclusive) {
@@ -227,6 +289,26 @@ TVM_REGISTER_OP("relax.cumsum")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScan)
    .set_attr<Bool>("FPurity", Bool(true));

/* relax.median */
Expr median(Expr data, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims) {
  ObjectPtr<StatisticalAttrs> attrs = ffi::make_object<StatisticalAttrs>();
  attrs->axis = std::move(axis);
  attrs->keepdims = keepdims;
  static const Op& op = Op::Get("relax.median");
  return Call(op, {std::move(data)}, Attrs{attrs}, {});
}

TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("relax.op.median", median);
}

TVM_REGISTER_OP("relax.median")
    .set_num_inputs(1)
    .add_argument("data", "Tensor", "The input tensor.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStatisticalExtension)
    .set_attr<Bool>("FPurity", Bool(true));

RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean);
RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min);
src/relax/op/tensor/statistical.h (+3, -0)

@@ -119,6 +119,9 @@ Expr cumsum(Expr data, ffi::Optional<int64_t> axis = std::nullopt,
/*! \brief Computes the variance of tensor elements over given axes. */
Expr variance(Expr x, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims);

/*! \brief Computes the median of tensor elements over given axes. */
Expr median(Expr x, ffi::Optional<ffi::Array<Integer>> axis, bool keepdims);

} // namespace relax
} // namespace tvm

tests/python/relax/test_frontend_from_exported_program.py (+68, -0)

@@ -4957,6 +4957,74 @@ def main(
    verify_model(MeanWithoutDim(), example_args, {}, Expected3)


def test_median():
    class Median(Module):
        def forward(self, input):
            return input.median(-1)

    class MedianKeepDim(Module):
        def forward(self, input):
            return input.median(-1, keepdim=True)

    class MedianWithoutDim(Module):
        def forward(self, input):
            return input.median()

    @I.ir_module
    class Expected1:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")):
            with R.dataflow():
                lv: R.Tuple(
                    R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")
                ) = R.median(inp_0, axis=[-1], keepdims=False)
                lv1: R.Tensor((256,), dtype="float32") = lv[0]
                lv2: R.Tensor((256,), dtype="int64") = lv[1]
                gv: R.Tuple(R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")) = (
                    lv1,
                    lv2,
                )
                R.output(gv)
            return gv

    @I.ir_module
    class Expected2:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")):
            with R.dataflow():
                lv: R.Tuple(
                    R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")
                ) = R.median(inp_0, axis=[-1], keepdims=True)
                lv1: R.Tensor((256, 1), dtype="float32") = lv[0]
                lv2: R.Tensor((256, 1), dtype="int64") = lv[1]
                gv: R.Tuple(
                    R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")
                ) = (lv1, lv2)
                R.output(gv)
            return gv

    @I.ir_module
    class Expected3:
        @R.function
        def main(
            inp_0: R.Tensor((256, 256), dtype="float32")
        ) -> R.Tuple(R.Tensor((), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((), dtype="float32") = R.median(inp_0, axis=None, keepdims=False)
                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
                R.output(gv)
            return gv

    example_args = (torch.randn(256, 256, dtype=torch.float32),)
    verify_model(Median(), example_args, {}, Expected1)
    verify_model(MedianKeepDim(), example_args, {}, Expected2)
    verify_model(MedianWithoutDim(), example_args, {}, Expected3)
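Beyond these structural checks, a hedged end-to-end sketch (not part of the PR's test suite) of comparing the lowered op against PyTorch's lower-median convention, assuming the usual LegalizeOps / relax.build / VirtualMachine flow; `MedianModule` is an illustrative name:

    import numpy as np
    import torch
    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class MedianModule:  # mirrors Expected1 above, at a smaller shape
        @R.function
        def main(
            x: R.Tensor((5, 7), dtype="float32")
        ) -> R.Tuple(R.Tensor((5,), dtype="float32"), R.Tensor((5,), dtype="int64")):
            with R.dataflow():
                gv = R.median(x, axis=[-1], keepdims=False)
                R.output(gv)
            return gv

    mod = relax.transform.LegalizeOps()(MedianModule)
    vm = relax.VirtualMachine(relax.build(mod, target="llvm"), tvm.cpu())

    x_np = np.random.randn(5, 7).astype("float32")
    res = vm["main"](tvm.nd.array(x_np))
    t_val, t_idx = torch.median(torch.from_numpy(x_np), dim=-1)
    np.testing.assert_allclose(res[0].numpy(), t_val.numpy())
    np.testing.assert_array_equal(res[1].numpy(), t_idx.numpy())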


def test_sum():
    class Sum(Module):
        def forward(self, x):