diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index f7d54a6216a7..d04dfbb6c35a 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1572,6 +1572,25 @@ def _mean(self, node: fx.Node) -> relax.Var:
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
 
+    def _median(self, node: fx.Node) -> relax.Var:
+        """Emit ``relax.op.median`` for a median node.
+
+        ``dim`` and ``keepdim`` are taken from the positional arguments when present and
+        otherwise from the keyword arguments, defaulting to ``None`` and ``False``.
+        """
+        operands = self.retrieve_args(node)
+        num_positional = len(node.args)
+        data = operands[0]
+        if num_positional > 1:
+            axis = operands[1]
+        else:
+            axis = node.kwargs.get("dim", None)
+        if num_positional > 2:
+            keep = operands[2]
+        else:
+            keep = node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.median(data, axis, keepdims=keep))
+
     def _norm(self, node: fx.Node) -> relax.Var:
         data = self.env[node.args[0]]
         dtype = data.struct_info.dtype
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index b6b9723c131f..0a97614eb576 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1384,6 +1384,8 @@ def create_convert_map(
             "sum.dim_IntList": self._sum,
             "var.correction": self._var,
             "max.dim": self._max_dim,
+            "median.dim": self._median,
+            "median.default": self._median,
             # search
             "argmax.default": self._argmax_argmin(relax.op.argmax),
             "argmin.default": self._argmax_argmin(relax.op.argmin),
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 19096decd932..c6504d79c9a5 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -119,7 +119,7 @@
 from .search import argmax, argmin, where, bucketize
 from .set import nonzero, unique
 from .sorting import argsort, sort, topk
-from .statistical import cumprod, cumsum, max, mean, min,
prod, std, sum, variance
+from .statistical import cumprod, cumsum, max, mean, median, min, prod, std, sum, variance
 from .ternary import ewise_fma
 from .unary import (
     abs,
diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py
index 502d058ffdf6..f11d31604a05 100644
--- a/python/tvm/relax/op/statistical.py
+++ b/python/tvm/relax/op/statistical.py
@@ -341,3 +341,33 @@ def variance(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bo
     if isinstance(axis, int):
         axis = [axis]
     return _ffi_api.variance(x, axis, keepdims)  # type: ignore
+
+
+def median(x: Expr, axis: Optional[Union[int, List[int]]] = None, keepdims: bool = False) -> Expr:
+    """Computes the median of tensor elements over given axes.
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input data tensor
+
+    axis : Optional[Union[int, List[int]]]
+        Axis along which the median is computed. The default (None) is to compute
+        the median of the entire flattened tensor.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input tensor.
+
+    Returns
+    -------
+    result : relax.Expr
+        A tensor holding the median when ``axis`` is None. When a single axis is given,
+        a tuple ``(values, indices)`` is produced instead, where ``indices`` is an int64
+        tensor with the positions of the medians along that axis. For an even number of
+        elements the lower of the two middle values is taken.
+    """
+    if isinstance(axis, int):
+        axis = [axis]
+    return _ffi_api.median(x, axis, keepdims)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py
index bdb79126f012..0c140187db9a 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -16,7 +16,7 @@
 # under the License.
# pylint: disable=invalid-name
 """Default legalization function for statistical operators."""
-from typing import List
+from typing import List, Union, Tuple
 from tvm import topi, tir, te
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr
@@ -53,6 +53,70 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Ten
     # return _te_mean(x * x, axis, keepdims) - mean * mean
 
 
+def _te_median(
+    x: te.Tensor, axis: List[tir.IntImm], keepdims: bool
+) -> Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]:
+    """Compute the median of ``x`` following the ``torch.median`` conventions.
+
+    Only a full reduction (``axis`` is None or empty) or a reduction along a single
+    axis is supported. With an even number of elements, the lower of the two middle
+    values of the ascending sort is selected.
+
+    Parameters
+    ----------
+    x : te.Tensor
+        The input tensor.
+
+    axis : List[tir.IntImm]
+        The single reduction axis, or None/empty to reduce the flattened tensor.
+
+    keepdims : bool
+        Whether the reduced axes are kept in the result with extent one.
+
+    Returns
+    -------
+    result : Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]
+        The median tensor for a full reduction; otherwise a ``(values, indices)`` pair
+        where ``indices`` holds the int64 positions of the medians along ``axis``.
+    """
+    # todo: support multiple axis ~ same numpy
+    full_reduce = axis is None or len(axis) == 0
+    if not full_reduce and len(axis) > 1:
+        raise ValueError(f"median supports at most one reduction axis, but got {len(axis)} axes")
+
+    shape_prod = _compute_shape_prod(x, axis)
+    mid_index = (shape_prod - 1) // 2
+
+    ndim = len(x.shape)
+    if full_reduce:
+        # NOTE(review): ``.value`` presumes a constant element count; confirm for symbolic shapes.
+        x = topi.reshape(x, [shape_prod.value])
+        ax = -1
+    else:
+        ax = axis[0].value
+    index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
+    x_sorted = topi.gather(x, axis=ax, indices=index_sorted)
+
+    new_shape = list(x.shape)
+    new_shape[ax] = 1
+    indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")
+
+    median_val = topi.gather(x_sorted, axis=ax, indices=indices)
+
+    if full_reduce:
+        if keepdims:
+            # The flattened result has shape [1]; restore the input rank with every extent
+            # one so the output matches the struct info inferred for relax.median.
+            return topi.reshape(median_val, [tir.IntImm("int64", 1)] * ndim)
+        return topi.squeeze(median_val, axis=axis)
+
+    median_idx = topi.gather(index_sorted, axis=ax, indices=indices)
+    if not keepdims:
+        median_val = topi.squeeze(median_val, axis=axis)
+        median_idx = topi.squeeze(median_idx, axis=axis)
+    return median_val, median_idx
+
+
 @register_legalize("relax.mean")
 def _mean(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
@@ -81,6 +145,17 @@ def _variance(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.median")
+def _median(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        _te_median,
+        call.args[0],
+        call.attrs.axis,
+        call.attrs.keepdims,
+        primfunc_name_hint="median",
+    )
+
+
 register_legalize("relax.max", _statistical(topi.max))
 register_legalize("relax.min",
_statistical(topi.min))
 register_legalize("relax.prod", _statistical(topi.prod))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 141361a729c4..354a4d77bac6 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -128,6 +128,7 @@
     max,
     maximum,
     mean,
+    median,
     memory,
     meshgrid,
     min,
@@ -874,6 +875,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "max",
     "maximum",
     "mean",
+    "median",
     "memory",
     "meshgrid",
     "metal",
diff --git a/src/relax/op/tensor/statistical.cc b/src/relax/op/tensor/statistical.cc
index 621c23d36310..771f6ffb133f 100644
--- a/src/relax/op/tensor/statistical.cc
+++ b/src/relax/op/tensor/statistical.cc
@@ -180,6 +180,68 @@ StructInfo InferStructInfoScan(const Call& call, const BlockBuilder& ctx) {
   }
 }
 
+StructInfo InferStructInfoStatisticalExtension(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as();
+
+  std::vector axes;
+  if (!data_sinfo->IsUnknownNdim() && attrs->axis.defined()) {
+    axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axis.value());
+  }
+
+  int out_ndim;
+  if (attrs->keepdims) {
+    out_ndim = data_sinfo->ndim;
+  } else if (!attrs->axis.defined()) {
+    out_ndim = 0;
+  } else if (data_sinfo->IsUnknownNdim()) {
+    out_ndim = kUnknownNDim;
+  } else {
+    out_ndim = data_sinfo->ndim - axes.size();
+    ICHECK_GE(out_ndim, 0);
+  }
+
+  // Output struct info of relax.median when the input shape is a ShapeExpr:
+  // - axis undefined or more than one axis -> a single tensor of the input dtype; reduced
+  //   axes (all axes if undefined) are dropped, or kept with extent 1 when keepdims is true.
+  // - exactly one axis -> a tuple (values, indices) of that shape; indices have dtype int64.
+  // NOTE(review): without a ShapeExpr the branch below returns a bare tensor only when
+  // out_ndim == 0 or (axis undefined, keepdims, known ndim) and a tuple in every other case,
+  // including axis undefined with unknown ndim and keepdims. Confirm this is intended.
+  const auto* data_shape = data_sinfo->shape.as();
+  if (data_shape == nullptr) {
+    if (!attrs->axis.defined() && attrs->keepdims && out_ndim != kUnknownNDim) {
+      return TensorStructInfo(
+          ShapeExpr(ffi::Array(out_ndim, IntImm(DataType::Int(64), /*value=*/1))),
+          data_sinfo->dtype, data_sinfo->vdevice);
+    }
+    if (out_ndim == 0) {
+      return TensorStructInfo(ShapeExpr(ffi::Array()), data_sinfo->dtype,
+                              data_sinfo->vdevice);
+    }
+    return TupleStructInfo({TensorStructInfo(data_sinfo->dtype, out_ndim, data_sinfo->vdevice),
+                            TensorStructInfo(DataType::Int(64), out_ndim, data_sinfo->vdevice)});
+  }
+
+  ffi::Array out_shape;
+  out_shape.reserve(out_ndim);
+  for (int i = 0; i < data_sinfo->ndim; ++i) {
+    if (attrs->axis.defined() && std::find(axes.begin(), axes.end(), i) == axes.end()) {
+      out_shape.push_back(data_shape->values[i]);
+    } else if (attrs->keepdims) {
+      out_shape.push_back(IntImm(DataType::Int(64), /*value=*/1));
+    }
+  }
+  ICHECK_EQ(static_cast(out_shape.size()), out_ndim);
+
+  if (!attrs->axis.defined() || axes.size() > 1)
+    return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+  else
+    return TupleStructInfo(
+        {TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice),
+         TensorStructInfo(ShapeExpr(out_shape), DataType::Int(64), data_sinfo->vdevice)});
+}
+
 /* relax.cumprod */
 Expr cumprod(Expr data, ffi::Optional axis, ffi::Optional dtype,
              Bool exclusive) {
@@ -227,6 +289,26 @@ TVM_REGISTER_OP("relax.cumsum")
     .set_attr("FInferStructInfo", InferStructInfoScan)
     .set_attr("FPurity", Bool(true));
 
+/* relax.median */
+Expr median(Expr data, ffi::Optional> axis, bool keepdims) {
+  ObjectPtr attrs = ffi::make_object();
+  attrs->axis = std::move(axis);
+  attrs->keepdims = keepdims;
+  static const Op& op = Op::Get("relax.median");
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.median", median);
+}
+
+TVM_REGISTER_OP("relax.median")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr("FInferStructInfo", InferStructInfoStatisticalExtension)
+    .set_attr("FPurity", Bool(true));
+
 RELAX_REGISTER_STATISTICAL_OP_INTERFACE(max);
 RELAX_REGISTER_STATISTICAL_OP_INTERFACE(mean);
 RELAX_REGISTER_STATISTICAL_OP_INTERFACE(min);
diff --git a/src/relax/op/tensor/statistical.h b/src/relax/op/tensor/statistical.h
index a80ef728683a..0a4f83687d56 100644
--- a/src/relax/op/tensor/statistical.h
+++ b/src/relax/op/tensor/statistical.h
@@ -119,6 +119,9 @@ Expr cumsum(Expr data, ffi::Optional axis = std::nullopt,
 /*! \brief Computes the variance of tensor elements over given axes. */
 Expr variance(Expr x, ffi::Optional> axis, bool keepdims);
 
+/*! \brief Computes the median over the given axis; a single axis yields (values, indices). */
+Expr median(Expr x, ffi::Optional> axis, bool keepdims);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 9f8842ddcb69..01a24ada1fa0 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4957,6 +4957,77 @@ def main(
     verify_model(MeanWithoutDim(), example_args, {}, Expected3)
 
 
+def test_median():
+    """Check the exported-program translation of ``median`` with and without ``dim``."""
+    class Median(Module):
+        def forward(self, input):
+            return input.median(-1)
+
+    class MedianKeepDim(Module):
+        def forward(self, input):
+            return input.median(-1, keepdim=True)
+
+    class MedianWithoutDim(Module):
+        def forward(self, input):
+            return input.median()
+
+    # Reducing along an axis yields a (values, indices) tuple.
+    @I.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")
+                ) = R.median(inp_0, axis=[-1], keepdims=False)
+                lv1: R.Tensor((256,), dtype="float32") = lv[0]
+                lv2: R.Tensor((256,), dtype="int64") = lv[1]
+                gv: R.Tuple(R.Tensor((256,), dtype="float32"), R.Tensor((256,), dtype="int64")) = (
+                    lv1,
+                    lv2,
+                )
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")):
+            with R.dataflow():
+                lv: R.Tuple(
+                    R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")
+                ) = R.median(inp_0, axis=[-1], keepdims=True)
+                lv1: R.Tensor((256, 1), dtype="float32") = lv[0]
+                lv2: R.Tensor((256, 1), dtype="int64") = lv[1]
+                gv: R.Tuple(
+                    R.Tensor((256, 1), dtype="float32"), R.Tensor((256, 1), dtype="int64")
+                ) = (lv1, lv2)
+                R.output(gv)
+            return gv
+
+    # Without ``dim`` the median of the whole tensor is a single rank-0 tensor.
+    @I.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((256, 256), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.median(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(256, 256, dtype=torch.float32),)
+    verify_model(Median(), example_args, {}, Expected1)
+    verify_model(MedianKeepDim(), example_args, {}, Expected2)
+    verify_model(MedianWithoutDim(), example_args, {}, Expected3)
+
+
 def test_sum():
     class Sum(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_op_statistical.py b/tests/python/relax/test_op_statistical.py
index a0cfc81e55f0..5dccbb33cc90 100644
--- a/tests/python/relax/test_op_statistical.py
+++ b/tests/python/relax/test_op_statistical.py
@@ -33,6 +33,7 @@ def test_op_correctness():
     assert relax.op.std(x).op == Op.get("relax.std")
     assert relax.op.sum(x).op == Op.get("relax.sum")
     assert relax.op.variance(x).op == Op.get("relax.variance")
+    assert relax.op.median(x).op == Op.get("relax.median")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -275,5 +276,150 @@ def test_scan_opinfer_struct_info_wrong_input_type(scan_op: Callable):
         bb.normalize(scan_op(x1, axis=1))
 
 
+def _median_pair_sinfo(make_sinfo, value_dtype="float32"):
+    """Build the ``(values, indices)`` tuple struct info of a single-axis median.
+
+    ``make_sinfo`` maps a dtype string to a ``TensorStructInfo``; it is called once with
+    ``value_dtype`` for the values and once with ``"int64"`` for the indices.
+    """
+    return relax.TupleStructInfo([make_sinfo(value_dtype), make_sinfo("int64")])
+
+
+def test_statistical_ext_infer_struct_info():
+    """Check median struct info for static, rank-only and unknown-rank tensor inputs."""
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    static_f32 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    rank_only = relax.Var("x", R.Tensor("float32", ndim=4))
+    unknown_rank = relax.Var("x", R.Tensor("float32"))
+    no_dtype = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+    on_vdev = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
+
+    cases = [
+        (
+            relax.op.median(static_f32, axis=[1]),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo((2, 4, 5), dt)),
+        ),
+        (
+            relax.op.median(static_f32, axis=[1], keepdims=True),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo((2, 1, 4, 5), dt)),
+        ),
+        (
+            relax.op.median(rank_only, axis=[1]),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo(dtype=dt, ndim=3)),
+        ),
+        (
+            relax.op.median(rank_only, axis=[1], keepdims=True),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo(dtype=dt, ndim=4)),
+        ),
+        (
+            relax.op.median(rank_only, axis=None, keepdims=True),
+            relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+        ),
+        (
+            relax.op.median(unknown_rank, axis=[1]),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo(dtype=dt)),
+        ),
+        (
+            relax.op.median(unknown_rank, axis=[1], keepdims=True),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo(dtype=dt)),
+        ),
+        (relax.op.median(unknown_rank, axis=None), relax.TensorStructInfo((), "float32")),
+        (
+            relax.op.median(no_dtype, axis=[1], keepdims=True),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo((2, 1, 4, 5), dtype=dt), ""),
+        ),
+        (relax.op.median(no_dtype, axis=None), relax.TensorStructInfo((), dtype="")),
+        (
+            relax.op.median(no_dtype, axis=None, keepdims=True),
+            relax.TensorStructInfo((1, 1, 1, 1), dtype=""),
+        ),
+        (
+            relax.op.median(on_vdev, axis=[1]),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo((2, 4, 5), dt, vdev0)),
+        ),
+    ]
+    for call, expected in cases:
+        _check_inference(bb, call, expected)
+
+
+def test_statistical_ext_infer_struct_info_shape_symbolic():
+    """Check that symbolic extents are propagated through median struct info."""
+    bb = relax.BlockBuilder()
+    a, b, c, d = (tir.Var(name, "int64") for name in "abcd")
+    x = relax.Var("x", R.Tensor((a, b, c, d), "float32"))
+
+    cases = [
+        (
+            relax.op.median(x, axis=[1]),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo((a, c, d), dt)),
+        ),
+        (
+            relax.op.median(x, axis=[1], keepdims=True),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo((a, 1, c, d), dt)),
+        ),
+        (relax.op.median(x, axis=None), relax.TensorStructInfo((), "float32")),
+        (
+            relax.op.median(x, axis=None, keepdims=True),
+            relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+        ),
+    ]
+    for call, expected in cases:
+        _check_inference(bb, call, expected)
+
+
+def test_statistical_ext_infer_struct_info_shape_var():
+    """Check median struct info when the input shape is an opaque shape variable."""
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    rank4 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    rankless = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    unranked_pair = _median_pair_sinfo(lambda dt: relax.TensorStructInfo(dtype=dt))
+
+    cases = [
+        (relax.op.median(rank4), relax.TensorStructInfo((), dtype="float32")),
+        (
+            relax.op.median(rank4, keepdims=True),
+            relax.TensorStructInfo((1, 1, 1, 1), dtype="float32"),
+        ),
+        (
+            relax.op.median(rank4, axis=[2]),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo(dtype=dt, ndim=3)),
+        ),
+        (
+            relax.op.median(rank4, axis=[2], keepdims=True),
+            _median_pair_sinfo(lambda dt: relax.TensorStructInfo(dtype=dt, ndim=4)),
+        ),
+        (relax.op.median(rankless), relax.TensorStructInfo((), dtype="float32")),
+        # NOTE(review): axis is None here, yet a tuple is expected -- confirm this is intended.
+        (relax.op.median(rankless, keepdims=True), unranked_pair),
+        (relax.op.median(rankless, axis=[2]), unranked_pair),
+        (relax.op.median(rankless, axis=[2], keepdims=True), unranked_pair),
+    ]
+    for call, expected in cases:
+        _check_inference(bb, call, expected)
+
+
+def test_statistical_ext_infer_struct_info_more_input_dtype():
+    """Check that a full reduction keeps the input dtype."""
+    bb = relax.BlockBuilder()
+    for dtype in ("float16", "int8"):
+        x = relax.Var("x", R.Tensor((2, 3, 4, 5), dtype))
+        _check_inference(bb, relax.op.median(x), relax.TensorStructInfo((), dtype))
+
+
+def test_statistical_ext_infer_struct_info_wrong_input_type():
+    """Check that normalizing median over a non-tensor input raises ``TVMError``."""
+    bb = relax.BlockBuilder()
+    not_tensors = [
+        relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5))),
+        relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4, 5), "float32"))),
+    ]
+    for var in not_tensors:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.median(var))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 7edfff3dfc43..b28451da1b18 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -684,6 +684,86 @@ def mean(var_rxplaceholder: T.handle, var_T_divide: T.handle):
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_median():
+    """Check that LegalizeOps lowers single-axis ``R.median`` to the expected ``call_tir``."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Median:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")):
+            gv: R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")) = R.median(x, axis=[0], keepdims=False)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tuple(R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")):
+            gv = R.call_tir(Expected.median, (x,), out_sinfo=[R.Tensor((3, 4, 5), dtype="float32"), R.Tensor((3, 4, 5), dtype="int64")])
+            return gv
+
+        @T.prim_func(private=True)
+        def median(var_x: T.handle, T_squeeze: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "float32"), T_squeeze_1: T.Buffer((T.int64(3), T.int64(4), T.int64(5)), "int64")):
+            T.func_attr({"tir.noalias": True})
+            data_buf = T.match_buffer(var_x, (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), align=8)
+            # with T.block("root"):
+            T_full = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)), "int64")
+            out_buf = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "int64", align=8)
+            T_gather = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+            T_gather_1 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)))
+            T_gather_2 = T.alloc_buffer((T.int64(1), T.int64(3), T.int64(4), T.int64(5)), "int64")
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_full"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads()
+                    T.writes(T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_full[v_ax0, v_ax1, v_ax2, v_ax3] = 0
+            with T.block("argsort_cpu"):
+                T.reads()
+                T.writes()
+                T.call_packed("tvm.contrib.sort.argsort", T.tvm_stack_make_array(data_buf.data,
+                    T.tvm_stack_make_shape(T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
+                    0, 4, T.float32(0.0), T.int64(0)),
+                    T.tvm_stack_make_array(out_buf.data,
+                    T.tvm_stack_make_shape(T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
+                    0, 4, T.int64(0), T.int64(0)),
+                    0, T.bool(True))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_gather"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], out_buf[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_gather[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_gather[v_ax0, v_ax1, v_ax2, v_ax3] = data_buf[out_buf[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_gather_1"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_gather_1[v_ax0, v_ax1, v_ax2, v_ax3] = T_gather[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_squeeze"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_gather_1[T.int64(0), v_ax0, v_ax1, v_ax2])
+                    T.writes(T_squeeze[v_ax0, v_ax1, v_ax2])
+                    T_squeeze[v_ax0, v_ax1, v_ax2] = T_gather_1[T.int64(0), v_ax0, v_ax1, v_ax2]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_gather_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3], T_full[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_gather_2[v_ax0, v_ax1, v_ax2, v_ax3] = out_buf[T_full[v_ax0, v_ax1, v_ax2, v_ax3], v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_squeeze_1"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_gather_2[T.int64(0), v_ax0, v_ax1, v_ax2])
+                    T.writes(T_squeeze_1[v_ax0, v_ax1, v_ax2])
+                    T_squeeze_1[v_ax0, v_ax1, v_ax2] = T_gather_2[T.int64(0), v_ax0, v_ax1, v_ax2]
+    # fmt: on
+
+    # The legalized module must be structurally equal to the hand-written Expected module.
+    mod = LegalizeOps()(Median)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_std():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_statistical.py b/tests/python/relax/test_tvmscript_parser_op_statistical.py
index 910c08bf1e3a..6ba90c56516a 100644
--- a/tests/python/relax/test_tvmscript_parser_op_statistical.py
+++ b/tests/python/relax/test_tvmscript_parser_op_statistical.py
@@ -95,6 +95,26 @@ def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1, 3), "float32"):
     _check(foo, bb.get()["foo"])
 
 
+def test_median():
+    """Compare a TVMScript function using ``R.median`` with its BlockBuilder-built twin."""
+    @R.function
+    def foo(
+        x: R.Tensor((1, 2, 3, 4), "float32")
+    ) -> R.Tuple(R.Tensor((1, 3, 4), "float32"), R.Tensor((1, 3, 4), "int64")):
+        gv: R.Tuple(R.Tensor((1, 3, 4), "float32"), R.Tensor((1, 3, 4), "int64")) = R.median(
+            x, axis=[1]
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor((1, 2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.median(x, axis=[1]))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_variance():
     @R.function
     def foo(x: R.Tensor((1, 2, 3, 4), "float32")) -> R.Tensor((1,), "float32"):