diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 6d880ab90dc2..4c9480b58748 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -927,6 +927,40 @@ def _mean(self, node: fx.Node) -> relax.Var:
         keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
         return self.block_builder.emit(relax.op.mean(x, dim, keepdims=keepdim))
 
+    def _norm(self, node: fx.Node) -> relax.Var:
+        data = self.env[node.args[0]]
+        dtype = data.struct_info.dtype
+        order = node.args[1] if len(node.args) > 1 else node.kwargs.get("p", 2)
+        axis = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", None)
+        keepdims = node.args[3] if len(node.args) > 3 else node.kwargs.get("keepdim", False)
+
+        if order == float("inf"):
+            return self.block_builder.emit(
+                relax.op.max(relax.op.abs(data), axis=axis, keepdims=keepdims)
+            )
+        elif order == float("-inf"):
+            return self.block_builder.emit(
+                relax.op.min(relax.op.abs(data), axis=axis, keepdims=keepdims)
+            )
+        # Frobenius norm: sqrt(sum(x * x))
+        elif order == "fro":
+            return self.block_builder.emit(
+                relax.op.sqrt(
+                    relax.op.sum(relax.op.multiply(data, data), axis=axis, keepdims=keepdims),
+                )
+            )
+        else:  # general p-norm: (sum(|x| ** p)) ** (1 / p)
+            reci_order = relax.const(1 / order, dtype=dtype)
+            order = relax.const(order, dtype=dtype)
+            return self.block_builder.emit(
+                relax.op.power(
+                    relax.op.sum(
+                        relax.op.power(relax.op.abs(data), order), axis=axis, keepdims=keepdims
+                    ),
+                    reci_order,
+                )
+            )
+
     def _prod(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         x = args[0]
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 594344fef89f..297529e8bf29 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -728,6 +728,7 @@ def create_convert_map(
             "lerp": self._lerp,
             # statistical
             "mean": self._mean,
+            "norm": self._norm,
             "prod": self._prod,
             "std": self._std,
             "sum": self._sum,
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index ee5a5c78c74a..a962de8a3237 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -4513,5 +4513,134 @@ def main(
     verify_model(Narrow(), [([5, 3], "float32")], {}, Expected)
 
 
+def test_norm():
+
+    input_info = [([1, 3, 5, 3], "float32")]
+
+    class Norm(Module):
+        def __init__(self, p, dim=None, keepdim=False):
+            super().__init__()
+            self.p = p
+            self.dim = dim
+            self.keepdim = keepdim
+
+        def forward(self, x):
+            return torch.norm(x, p=self.p, dim=self.dim, keepdim=self.keepdim)
+
+    @tvm.script.ir_module
+    class Expected1:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.max(R.abs(inp_0), axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.min(R.abs(inp_0), axis=None, keepdims=False)
+                gv: R.Tensor((), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(2, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(0.5, "float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected4:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(1.0, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(1.0, "float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected5:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(-4, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(-0.25, "float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected6:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.abs(inp_0)
+                lv1: R.Tensor((1, 3, 5, 3), dtype="float32") = R.power(lv, R.const(0.5, "float32"))
+                lv2: R.Tensor((), dtype="float32") = R.sum(lv1, axis=None, keepdims=False)
+                lv3: R.Tensor((), dtype="float32") = R.power(lv2, R.const(2, "float32"))
+                gv: R.Tensor((), dtype="float32") = lv3
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected7:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 3, 5, 3), dtype="float32"),
+        ) -> R.Tensor((), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 5, 3), dtype="float32") = R.multiply(inp_0, inp_0)
+                lv1: R.Tensor((), dtype="float32") = R.sum(lv, axis=None, keepdims=False)
+                lv2: R.Tensor((), dtype="float32") = R.sqrt(lv1)
+                gv: R.Tensor((), dtype="float32") = lv2
+                R.output(gv)
+            return gv
+
+    norms = [
+        (float("inf"), None, False),
+        (float("-inf"), None, False),
+        (2.0, None, False),
+        (1.0, None, False),
+        (-4.0, None, False),
+        (0.5, None, False),
+        ("fro", None, False),
+    ]
+
+    for norm, expected in zip(
+        norms, [Expected1, Expected2, Expected3, Expected4, Expected5, Expected6, Expected7]
+    ):
+        p, dim, keepdim = norm
+        verify_model(Norm(p, dim=dim, keepdim=keepdim), input_info, {}, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
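
As a reviewer-side sanity check of the lowering rules above (not part of the patch), the snippet below reproduces each branch of _norm with plain NumPy and compares it against torch.norm on a flattened input; it assumes only torch and numpy, and the input range is illustrative, chosen away from zero so negative and fractional orders stay well-behaved.

# Sketch of a numerical cross-check for the _norm decompositions; not part of the patch.
import numpy as np
import torch

x = np.random.uniform(0.5, 1.5, size=(1, 3, 5, 3))  # float64, away from 0
t = torch.from_numpy(x)

# order == +/-inf lowers to max/min of |x|
assert np.isclose(torch.norm(t, p=float("inf")).item(), np.abs(x).max())
assert np.isclose(torch.norm(t, p=float("-inf")).item(), np.abs(x).min())
# order == "fro" lowers to sqrt(sum(x * x))
assert np.isclose(torch.norm(t, p="fro").item(), np.sqrt((x * x).sum()))
# any other order p lowers to (sum(|x| ** p)) ** (1 / p)
for p in (2.0, 1.0, -4.0, 0.5):
    assert np.isclose(torch.norm(t, p=p).item(), (np.abs(x) ** p).sum() ** (1.0 / p))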