74 changes: 74 additions & 0 deletions python/tvm/relax/frontend/common.py
@@ -17,8 +17,10 @@
# pylint: disable=invalid-name
"""Commons for Relax frontend."""
from typing import Dict, List, Tuple
import numpy as _np

import tvm
from tvm import topi


def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
@@ -53,3 +55,75 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.nd.NDArray]]]:
        else:
            detached_mod[gv] = func
    return detached_mod, params_dict


def autopad(
    bb,
    data,
    strides,
    kernel_shape,
    dilations=(1, 1),
    pad_type="constant",
    deconv=False,
    mode="SAME_UPPER",
    pad_value=0.0,
):
    """
    Perform "SAME"-style auto-padding. The padding arithmetic is done in
    numpy, so the spatial dimensions of the input must be statically known.
    """
    # get attributes as constants
    strides = _np.array(strides)
    dilated_kernel_shape = _np.array(
        [(kernel - 1) * dilation + 1 for kernel, dilation in zip(kernel_shape, dilations)]
    )
    # get the static spatial dimensions of the input as a numpy array
    # so that the elementwise arithmetic below works
    ndim = data.struct_info.ndim
    data_shape = list(data.struct_info.shape)
    shape = _np.array([int(dim) for dim in data_shape[2:ndim]])

    # calculate total padding
    mod = shape % strides

    left = _np.maximum(dilated_kernel_shape - strides, 0)
    right = _np.maximum(dilated_kernel_shape - mod, 0)

    total_pad = _np.where(_np.equal(mod, 0), left, right)
    if deconv:
        total_pad = _np.array(kernel_shape) - 1 - total_pad

    # split total padding into before and after
    pad_before = _np.floor_divide(total_pad, 2)
    pad_after = total_pad - pad_before

    # combine; SAME_LOWER places the larger half of an odd total pad first
    if "LOWER" in mode:
        pad = _np.concatenate(
            [_np.reshape(pad_after, [-1, 1]), _np.reshape(pad_before, [-1, 1])], axis=1
        )
    else:
        pad = _np.concatenate(
            [_np.reshape(pad_before, [-1, 1]), _np.reshape(pad_after, [-1, 1])], axis=1
        )

    # pad N and C with zeros
    pad = _np.concatenate([_np.zeros([2, 2], dtype="int64"), pad], axis=0)

    if pad_type not in ["constant", "edge", "reflect"]:
        raise tvm.error.OpAttributeInvalid(
            f'Value {pad_type} in attribute "mode" is invalid for operator Pad.'
        )

    if pad_type == "constant":
        return bb.emit_te(topi.nn.pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), pad_value)
    elif pad_type == "reflect":
        return bb.emit_te(
            topi.nn.mirror_pad, data, pad[:, 0].tolist(), pad[:, 1].tolist(), "REFLECT"
        )
    else:
        # TODO(gigiblender): Support edge mode.
        raise NotImplementedError(f"Pad mode {pad_type} not implemented.")
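For intuition, here is a standalone sketch of the padding arithmetic above, using hypothetical concrete values (a 7x7 spatial input, 3x3 kernel, stride 2, dilation 1); it is not part of the PR, just the same numpy steps traced by hand:

import numpy as np

# Hypothetical values mirroring autopad's arithmetic.
shape = np.array([7, 7])             # spatial dims (H, W)
strides = np.array([2, 2])
dilated_kernel = np.array([3, 3])    # (kernel - 1) * dilation + 1

mod = shape % strides                # [1, 1]
total_pad = np.where(
    mod == 0,
    np.maximum(dilated_kernel - strides, 0),
    np.maximum(dilated_kernel - mod, 0),
)                                    # [2, 2]
pad_before = total_pad // 2          # [1, 1]; SAME_UPPER puts the smaller half first
pad_after = total_pad - pad_before   # [1, 1]
# The padded convolution then yields ceil(7 / 2) = 4 per spatial dim.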
32 changes: 31 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -49,6 +49,8 @@
from tvm.ir.supply import NameSupply
from tvm.tir.generic import cast

from ..common import autopad


def get_type(elem_type: Union[str, int]) -> str:
"""Converts onnx integer datatype to numpy datatype"""
@@ -1208,11 +1210,15 @@ class Conv(OnnxOpConverter):

    @classmethod
    def _impl_v11(cls, bb, inputs, attr, params):
        data = inputs[0]
        if hasattr(inputs[0].struct_info, "ndim"):
            ndim = inputs[0].struct_info.ndim
        else:
            ndim = len(inputs[0].struct_info.shape)

        if "kernel_shape" not in attr:
            attr["kernel_shape"] = inputs[1].struct_info.shape.values[2:]

        if ndim == 3:
            op = relax.op.nn.conv1d
            data_layout = "NCW"
@@ -1228,9 +1234,33 @@ def _impl_v11(cls, bb, inputs, attr, params):
        else:
            raise NotImplementedError("Ndim > 5 not supported for convolution.")

        if "auto_pad" in attr:
            attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
            if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
                data = autopad(
                    bb,
                    inputs[0],
                    attr.get("strides", [1] * (ndim - 2)),
                    attr["kernel_shape"],
                    attr.get("dilations", [1] * (ndim - 2)),
                    mode=attr["auto_pad"],
                    deconv=False,
                )
            elif attr["auto_pad"] == "VALID":
                attr["pads"] = [0 for _ in range(ndim - 2)]
            elif attr["auto_pad"] == "NOTSET":
                pass
            else:
                msg = (
                    f'Value {attr["auto_pad"]} in attribute "auto_pad" of operator Conv '
                    "is invalid."
                )
                raise tvm.error.OpAttributeInvalid(msg)
            attr.pop("auto_pad")

        conv_out = bb.normalize(
            op(
                data=data,
                weight=inputs[1],
                strides=attr.get("strides", 1),
                padding=attr.get("pads", 0),
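For an end-to-end view of the new path, here is a hedged sketch that builds a minimal ONNX model with auto_pad="SAME_UPPER" and imports it through this frontend; the from_onnx call follows the usage in the tests below, and all names and shapes here are illustrative:

import onnx
from onnx import helper, TensorProto
from tvm.relax.frontend.onnx import from_onnx

# A single 3x3 Conv with SAME_UPPER auto-padding and implicit stride 1.
conv_node = helper.make_node(
    "Conv", inputs=["x", "w"], outputs=["y"],
    kernel_shape=[3, 3], auto_pad="SAME_UPPER",
)
graph = helper.make_graph(
    [conv_node],
    "autopad_example",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 32, 32]),
        helper.make_tensor_value_info("w", TensorProto.FLOAT, [8, 3, 3, 3]),
    ],
    # SAME_UPPER with stride 1 preserves the 32x32 spatial shape.
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 8, 32, 32])],
)
model = helper.make_model(graph, producer_name="autopad_example")
mod = from_onnx(model, opset=11)  # Conv._impl_v11 routes SAME_UPPER through autopad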
64 changes: 49 additions & 15 deletions tests/python/relax/test_frontend_onnx.py
@@ -980,23 +980,57 @@ def test_shrink():
@pytest.mark.parametrize("dilation", [1, 2])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("pad", [0, 2])
@pytest.mark.parametrize("auto_pad", ["SAME_UPPER", "SAME_LOWER", "VALID", "NOTSET"])
def test_conv(stride: int, dilation: int, pad: int, bias: bool, auto_pad: str):
    def _verify_conv(input_shape, weight_shape):
        nd = len(weight_shape) - 2
        if auto_pad == "VALID":
            output_shape = [input_shape[0], weight_shape[0]] + [
                (input_shape[i] - dilation * (weight_shape[i] - 1) - 1) // stride + 1
                for i in range(2, len(input_shape))
            ]
            bias_shape = [output_shape[1]]
            conv_node = helper.make_node(
                "Conv",
                inputs=["x", "w"] + (["b"] if bias else []),
                outputs=["y"],
                strides=[stride] * nd,
                dilations=[dilation] * nd,
                auto_pad=auto_pad,
                group=input_shape[1] // weight_shape[1],
            )
        elif auto_pad in ("SAME_UPPER", "SAME_LOWER"):
            if dilation == 2:
                # auto_pad="SAME_*" with dilation != 1 is not supported in ONNX.
                return
            output_shape = [input_shape[0], weight_shape[0]] + [
                (input_shape[i] + stride - 1) // stride for i in range(2, len(input_shape))
            ]
            bias_shape = [output_shape[1]]
            conv_node = helper.make_node(
                "Conv",
                inputs=["x", "w"] + (["b"] if bias else []),
                outputs=["y"],
                strides=[stride] * nd,
                dilations=[dilation] * nd,
                auto_pad=auto_pad,
                group=input_shape[1] // weight_shape[1],
            )
        else:
            output_shape = [input_shape[0], weight_shape[0]] + [
                (input_shape[i] + 2 * pad - dilation * (weight_shape[i] - 1) - 1) // stride + 1
                for i in range(2, len(input_shape))
            ]
            bias_shape = [output_shape[1]]
            conv_node = helper.make_node(
                "Conv",
                inputs=["x", "w"] + (["b"] if bias else []),
                outputs=["y"],
                strides=[stride] * nd,
                dilations=[dilation] * nd,
                pads=[pad] * nd * 2,
                group=input_shape[1] // weight_shape[1],
            )
        graph = helper.make_graph(
            [conv_node],
            "conv_test",
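The SAME_* branch's expected shape uses ONNX's rule that the output spatial size is ceil(input / stride); a standalone sanity check (not part of the PR) that this agrees with the explicit-padding formula once autopad has chosen the total pad:

import math

def same_output_dim(size, stride):
    # ONNX SAME_*: output spatial size is ceil(size / stride).
    return math.ceil(size / stride)

def explicit_output_dim(size, kernel, stride, pad_total):
    return (size + pad_total - kernel) // stride + 1

size, kernel, stride = 32, 3, 2
mod = size % stride
pad_total = max(kernel - stride, 0) if mod == 0 else max(kernel - mod, 0)
assert same_output_dim(size, stride) == explicit_output_dim(size, kernel, stride, pad_total)  # both 16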