From 2dc4dbb1d066627969452bb28bc9272dc1c0ecdd Mon Sep 17 00:00:00 2001 From: Jens Geyer Date: Thu, 19 Mar 2026 20:24:46 +0100 Subject: [PATCH] Regression test for THRIFT-4002: Immutable exception deserialization Client: py Patch: Jens Geyer Generated-by: Opencode big-pickle This test verifies that immutable structs (including exceptions, which are immutable by default since Thrift 0.14.0) can be properly deserialized without triggering the __setattr__ TypeError. The bug manifests when: 1. A struct class is marked immutable (has __setattr__ that raises TypeError) 2. Thrift's deserialization tries to set attributes via setattr instead of using the kwargs constructor Test coverage: - Immutable exception creation and hashability - Immutable exception blocks modification/deletion - Round-trip serialization/deserialization with TBinaryProtocol - Round-trip serialization/deserialization with TCompactProtocol - Accelerated protocol tests (C extension) when available Related: THRIFT-4002, THRIFT-5715 --- lib/py/CMakeLists.txt | 1 + lib/py/test/test_immutable_exception.py | 243 ++++++++++++++++++++++++ 2 files changed, 244 insertions(+) create mode 100644 lib/py/test/test_immutable_exception.py diff --git a/lib/py/CMakeLists.txt b/lib/py/CMakeLists.txt index 7c0e3818a6a..aa98818acc9 100644 --- a/lib/py/CMakeLists.txt +++ b/lib/py/CMakeLists.txt @@ -34,4 +34,5 @@ if(BUILD_TESTING) add_test(NAME PythonThriftTZlibTransport COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_TZlibTransport.py) add_test(NAME PythonThriftProtocol COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_TCompactProtocol.py) add_test(NAME PythonThriftTNonblockingServer COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/thrift_TNonblockingServer.py) + add_test(NAME PythonImmutableException COMMAND Python3::Interpreter ${CMAKE_CURRENT_SOURCE_DIR}/test/test_immutable_exception.py) endif() diff --git a/lib/py/test/test_immutable_exception.py b/lib/py/test/test_immutable_exception.py new file mode 100644 index 00000000000..827d3caa2bf --- /dev/null +++ b/lib/py/test/test_immutable_exception.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +""" +Test cases for THRIFT-4002: Immutable exception deserialization. + +This test verifies that immutable structs (including exceptions, which are immutable +by default since Thrift 0.14.0) can be properly deserialized without triggering +the __setattr__ TypeError. + +The bug manifests when: +1. A struct class is marked immutable (has __setattr__ that raises TypeError) +2. Thrift's deserialization tries to set attributes via setattr instead of + using the kwargs constructor + +This test ensures that all deserialization paths (C extension, pure Python, +all protocols) correctly handle immutable structs. +""" + +import unittest +from collections.abc import Hashable + +import glob +import os +import sys + +SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(SCRIPT_DIR))) + +for libpath in glob.glob(os.path.join(ROOT_DIR, 'lib', 'py', 'build', 'lib.*')): + for pattern in ('-%d.%d', '-%d%d'): + postfix = pattern % (sys.version_info[0], sys.version_info[1]) + if libpath.endswith(postfix): + sys.path.insert(0, libpath) + break +else: + src_path = os.path.join(ROOT_DIR, 'lib', 'py', 'src') + if os.path.exists(src_path): + sys.path.insert(0, src_path) +from thrift.Thrift import TException +from thrift.transport import TTransport +from thrift.protocol import TBinaryProtocol, TCompactProtocol + + +class ImmutableException(TException): + """Test exception that mimics generated immutable exception behavior.""" + + thrift_spec = ( + None, # 0 + (1, 11, 'message', 'UTF8', None, ), # 1: string + ) + + def __init__(self, message=None): + super(ImmutableException, self).__init__(message) + + def __setattr__(self, *args): + raise TypeError("can't modify immutable instance") + + def __delattr__(self, *args): + raise TypeError("can't modify immutable instance") + + def __hash__(self): + return hash(self.__class__) ^ hash((self.message,)) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.message == other.message + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('ImmutableException') + if self.message is not None: + oprot.writeFieldBegin('message', 11, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + @classmethod + def read(cls, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None: + return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec]) + return iprot.readStruct(cls, cls.thrift_spec, True) + + +class MutableException(TException): + """Test exception that mimics generated mutable exception behavior.""" + + thrift_spec = ( + None, # 0 + (1, 11, 'message', 'UTF8', None, ), # 1: string + ) + + def __init__(self, message=None): + super(MutableException, self).__init__(message) + + def write(self, oprot): + if oprot._fast_encode is not None and self.thrift_spec is not None: + oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec])) + return + oprot.writeStructBegin('MutableException') + if self.message is not None: + oprot.writeFieldBegin('message', 11, 1) + oprot.writeString(self.message) + oprot.writeFieldEnd() + oprot.writeFieldStop() + oprot.writeStructEnd() + + @classmethod + def read(cls, iprot): + if iprot._fast_decode is not None and isinstance(iprot.trans, TTransport.CReadableTransport) and cls.thrift_spec is not None: + return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec]) + return iprot.readStruct(cls, cls.thrift_spec, False) + + +class TestImmutableExceptionDeserialization(unittest.TestCase): + """Test that immutable exceptions can be properly deserialized.""" + + def _roundtrip(self, exc, protocol_class): + """Serialize and deserialize an exception.""" + otrans = TTransport.TMemoryBuffer() + oproto = protocol_class.getProtocol(otrans) + exc.write(oproto) + itrans = TTransport.TMemoryBuffer(otrans.getvalue()) + iproto = protocol_class.getProtocol(itrans) + return exc.__class__.read(iproto) + + def test_immutable_exception_is_hashable(self): + """Verify that immutable exceptions are hashable (required for caching/logging).""" + exc = ImmutableException(message="test") + self.assertTrue(isinstance(exc, Hashable)) + self.assertEqual(hash(exc), hash(ImmutableException(message="test"))) + + def test_immutable_exception_blocks_modification(self): + """Verify that immutable exceptions raise TypeError on attribute modification.""" + exc = ImmutableException(message="test") + with self.assertRaises(TypeError) as cm: + exc.message = "modified" + self.assertIn("immutable", str(cm.exception)) + + def test_immutable_exception_blocks_deletion(self): + """Verify that immutable exceptions raise TypeError on attribute deletion.""" + exc = ImmutableException(message="test") + with self.assertRaises(TypeError) as cm: + del exc.message + self.assertIn("immutable", str(cm.exception)) + + def test_immutable_exception_binary_protocol(self): + """Test immutable exception deserialization with TBinaryProtocol.""" + exc = ImmutableException(message="test error") + deserialized = self._roundtrip(exc, TBinaryProtocol.TBinaryProtocolFactory()) + self.assertEqual(exc.message, deserialized.message) + self.assertEqual(exc, deserialized) + + def test_immutable_exception_compact_protocol(self): + """Test immutable exception deserialization with TCompactProtocol.""" + exc = ImmutableException(message="test error") + deserialized = self._roundtrip(exc, TCompactProtocol.TCompactProtocolFactory()) + self.assertEqual(exc.message, deserialized.message) + self.assertEqual(exc, deserialized) + + def test_mutable_exception_can_be_modified(self): + """Verify that mutable exceptions can be modified (control test).""" + exc = MutableException(message="original") + exc.message = "modified" + self.assertEqual(exc.message, "modified") + + +class TestImmutableExceptionAccelerated(unittest.TestCase): + """Test immutable exception deserialization with accelerated protocols (C extension).""" + + def setUp(self): + try: + # The import is intentionally unused - it only checks if the C extension + # is available by catching ImportError. The noqa comment documents this. + from thrift.protocol import fastbinary # noqa: F401 + self._has_c_extension = True + except ImportError: + self._has_c_extension = False + + def _roundtrip(self, exc, protocol_class): + """Serialize and deserialize an exception.""" + otrans = TTransport.TMemoryBuffer() + oproto = protocol_class.getProtocol(otrans) + exc.write(oproto) + itrans = TTransport.TMemoryBuffer(otrans.getvalue()) + iproto = protocol_class.getProtocol(itrans) + return exc.__class__.read(iproto) + + def test_immutable_exception_binary_accelerated(self): + """Test immutable exception with TBinaryProtocolAccelerated.""" + if not self._has_c_extension: + self.skipTest("C extension not available") + exc = ImmutableException(message="test error") + deserialized = self._roundtrip( + exc, + TBinaryProtocol.TBinaryProtocolAcceleratedFactory(fallback=False) + ) + self.assertEqual(exc.message, deserialized.message) + self.assertEqual(exc, deserialized) + + def test_immutable_exception_compact_accelerated(self): + """Test immutable exception with TCompactProtocolAccelerated.""" + if not self._has_c_extension: + self.skipTest("C extension not available") + exc = ImmutableException(message="test error") + deserialized = self._roundtrip( + exc, + TCompactProtocol.TCompactProtocolAcceleratedFactory(fallback=False) + ) + self.assertEqual(exc.message, deserialized.message) + self.assertEqual(exc, deserialized) + + +def suite(): + suite = unittest.TestSuite() + loader = unittest.TestLoader() + suite.addTest(loader.loadTestsFromTestCase(TestImmutableExceptionDeserialization)) + suite.addTest(loader.loadTestsFromTestCase(TestImmutableExceptionAccelerated)) + return suite + + +if __name__ == "__main__": + unittest.main(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))