diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py index 8669e824d5a..e398e9da1cb 100644 --- a/python/pyarrow/serialization.py +++ b/python/pyarrow/serialization.py @@ -136,7 +136,7 @@ def register_torch_serialization_handlers(serialization_context): import torch def _serialize_torch_tensor(obj): - return obj.numpy() + return obj.detach().numpy() def _deserialize_torch_tensor(data): return torch.from_numpy(data) diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py index e484ebb3aff..6cc391af4e5 100644 --- a/python/pyarrow/tests/test_serialization.py +++ b/python/pyarrow/tests/test_serialization.py @@ -364,6 +364,10 @@ def test_torch_serialization(large_buffer): serialization_roundtrip(obj, large_buffer, context=serialization_context) + tensor_requiring_grad = torch.randn(10, 10, requires_grad=True) + serialization_roundtrip(tensor_requiring_grad, large_buffer, + context=serialization_context) + def test_numpy_immutable(large_buffer): obj = np.zeros([10])