Skip to content

Conversation

@jonghewk
Copy link
Contributor

Issue

Relay for depthwise nn.conv3d causes error.

Cause

depthwise is not fully implemented for conv3d.
The pytorch frontend takes care of the depthwise case and reshapes the weights.
(https://github.com/apache/tvm/blob/main/python/tvm/relay/frontend/pytorch.py#L1247).
However, this is not taken care in Conv3DRel.

Note that Conv2DRel takes care of this case.
(https://github.com/apache/tvm/blob/main/src/relay/op/nn/convolution.cc#L273)

How to reproduce

class Conv3d(nn.Module):
    def __init__(self):
        super(Conv3d, self).__init__()
        self.conv = nn.Conv3d(16, 32, kernel_size=1, groups=16)
    
    def forward(self, x):
        x = self.conv(x)
        return x

from tvm import relay
s = [1, 16, 10, 10, 10]
net = Conv3d()
input_data = torch.randn(s)
scripted_model = torch.jit.trace(net, input_data).eval()
mod = relay.frontend.from_pytorch(scripted_model, [("x", s)]) 

@jonghewk
Copy link
Contributor Author

@masahi Any chance of reviewing this?

@masahi
Copy link
Member

masahi commented Nov 22, 2023

Can you add your PT test as a test case?

@jonghewk
Copy link
Contributor Author

Can you add your PT test as a test case?

Done.

@masahi masahi merged commit 1a2cc18 into apache:main Nov 23, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants