Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,20 @@ def _rsub(self, node: fx.Node) -> relax.Var:

return self.block_builder.emit(relax.op.subtract(rhs, lhs))

def _isin(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
elements = args[0]
test_elements = args[1]

expanded_elements = relax.op.expand_dims(elements, axis=-1)
flattened_test_elements = relax.op.reshape(test_elements, (-1,))

comparison = relax.op.equal(expanded_elements, flattened_test_elements)
summed = relax.op.sum(comparison, axis=-1)
result = relax.op.greater(summed, relax.const(0, dtype=elements.struct_info.dtype))

return self.block_builder.emit(result)

########## Neural Network ##########

def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def create_convert_map(
"hardtanh_.default": self._hardtanh,
"isfinite.default": self._unary_op(relax.op.isfinite),
"isinf.default": self._unary_op(relax.op.isinf),
"isin.Tensor_Tensor": self._isin,
"isnan.default": self._unary_op(relax.op.isnan),
"leaky_relu.default": self._leakyrelu,
"leaky_relu_.default": self._leakyrelu,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,7 @@ def create_convert_map(
"hardtanh": self._hardtanh,
"isfinite": self._unary_op(relax.op.isfinite),
"isinf": self._unary_op(relax.op.isinf),
"isin": self._isin,
"isnan": self._unary_op(relax.op.isnan),
"leaky_relu": self._leakyrelu,
"log": self._unary_op(relax.op.log),
Expand Down
31 changes: 31 additions & 0 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,37 @@ def main(
verify_model(RSub2(), example_args2, {}, expected_rsub2)


# IsIn


def test_isin():
class IsInModel(torch.nn.Module):
def forward(self, x, test_elements):
return torch.isin(x, test_elements)

@tvm.script.ir_module
class expected:
@R.function
def main(
x: R.Tensor((10, 10), dtype="float32"), test_elements: R.Tensor((8,), dtype="float32")
) -> R.Tuple(R.Tensor((10, 10), dtype="bool")):
with R.dataflow():
lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(x, axis=[-1])
lv1: R.Tensor((8,), dtype="float32") = R.reshape(test_elements, R.shape([8]))
lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
gv: R.Tuple(R.Tensor((10, 10), dtype="bool")) = (lv4,)
R.output(gv)
return gv

example_args = (
torch.randn(10, 10, dtype=torch.float32),
torch.randn(8, dtype=torch.float32),
)
verify_model(IsInModel(), example_args, {}, expected)


def test_batchnorm2d():
class BatchNorm2d(Module):
def __init__(self):
Expand Down
29 changes: 29 additions & 0 deletions tests/python/relax/test_frontend_from_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,35 @@ def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="fl
verify_model(RSub2(), input_info2, {}, expected_rsub2)


# IsIn


def test_isin():
input_info = [([10, 10], "float32"), ([8], "float32")]

class IsInModel(torch.nn.Module):
def forward(self, x, test_elements):
return torch.isin(x, test_elements)

@tvm.script.ir_module
class expected:
@R.function
def main(
inp_0: R.Tensor((10, 10), dtype="float32"), inp_1: R.Tensor((8,), dtype="float32")
) -> R.Tensor((10, 10), dtype="bool"):
with R.dataflow():
lv: R.Tensor((10, 10, 1), dtype="float32") = R.expand_dims(inp_0, axis=[-1])
lv1: R.Tensor((8,), dtype="float32") = R.reshape(inp_1, R.shape([8]))
lv2: R.Tensor((10, 10, 8), dtype="bool") = R.equal(lv, lv1)
lv3: R.Tensor((10, 10), dtype="bool") = R.sum(lv2, axis=[-1], keepdims=False)
lv4: R.Tensor((10, 10), dtype="bool") = R.greater(lv3, R.const(0.0, "float32"))
gv: R.Tensor((10, 10), dtype="bool") = lv4
R.output(gv)
return gv

verify_model(IsInModel(), input_info, {}, expected)


def test_size():
input_info = [([1, 3, 10, 10], "float32")]

Expand Down