From d08360ba5e81f57bfed1aaf1940730325d91b956 Mon Sep 17 00:00:00 2001 From: David Sierra-Gonzalez Date: Tue, 21 Oct 2025 15:30:40 +0000 Subject: [PATCH 1/5] add: Add detection and pruning of cyclical plain messages. --- protarrow/common.py | 1 + protarrow/proto_to_arrow.py | 60 ++++++- protos/bench.proto | 33 ++++ pyproject.toml | 6 + scripts/template.proto.in | 28 ++++ tests/data/CyclicalDirectMessage.jsonl | 1 + .../data/CyclicalIndirectMessageLevel1.jsonl | 1 + tests/data/CyclicalMapMessage.jsonl | 1 + tests/data/CyclicalRepeatedMessage.jsonl | 1 + tests/test_conversion_cyclical.py | 149 ++++++++++++++++++ 10 files changed, 276 insertions(+), 5 deletions(-) create mode 100644 tests/data/CyclicalDirectMessage.jsonl create mode 100644 tests/data/CyclicalIndirectMessageLevel1.jsonl create mode 100644 tests/data/CyclicalMapMessage.jsonl create mode 100644 tests/data/CyclicalRepeatedMessage.jsonl create mode 100644 tests/test_conversion_cyclical.py diff --git a/protarrow/common.py b/protarrow/common.py index 50433b1..a9cfce9 100644 --- a/protarrow/common.py +++ b/protarrow/common.py @@ -39,6 +39,7 @@ class ProtarrowConfig: string_type: pa.DataType = pa.string() binary_type: pa.DataType = pa.binary() list_array_type: type = pa.ListArray + purge_cyclical_messages: bool = False def __post_init__(self): assert self.enum_type in SUPPORTED_ENUM_TYPES diff --git a/protarrow/proto_to_arrow.py b/protarrow/proto_to_arrow.py index fbd933c..26b0f73 100644 --- a/protarrow/proto_to_arrow.py +++ b/protarrow/proto_to_arrow.py @@ -67,6 +67,12 @@ } +class ProtarrowCycleError(Exception): + """Raised when a cycle is found and cannot be safely processed.""" + + pass + + def _time_of_day_to_nanos(time_of_day: TimeOfDay) -> int: return ( (time_of_day.hours * 60 + time_of_day.minutes) * 60 + time_of_day.seconds @@ -363,6 +369,7 @@ def _proto_field_to_array( field_descriptor: FieldDescriptor, validity_mask: Optional[Sequence[bool]], config: ProtarrowConfig, + descriptor_trace: Optional[List[str]] = None, ) -> pa.Array: converter = _get_converter(field_descriptor, config) @@ -392,6 +399,7 @@ def _proto_field_to_array( field_descriptor.message_type, validity_mask=validity_mask, config=config, + descriptor_trace=descriptor_trace, ) @@ -414,6 +422,7 @@ def _repeated_proto_to_array( repeated_values: Iterable[RepeatedScalarFieldContainer], field_descriptor: FieldDescriptor, config: ProtarrowConfig, + descriptor_trace: Optional[List[str]] = None, ) -> pa.ListArray: """ Convert Protobuf embedded lists to a 1-dimensional PyArrow ListArray with offsets @@ -421,7 +430,11 @@ def _repeated_proto_to_array( """ offsets = _get_offsets(repeated_values) array = _proto_field_to_array( - FlattenedIterable(repeated_values), field_descriptor, None, config + FlattenedIterable(repeated_values), + field_descriptor, + None, + config, + descriptor_trace, ) return config.list_array_type.from_arrays( offsets, @@ -434,6 +447,7 @@ def _proto_map_to_array( maps: Iterable[MessageMap], field_descriptor: FieldDescriptor, config: ProtarrowConfig = ProtarrowConfig(), + descriptor_trace: Optional[List[str]] = None, ) -> pa.MapArray: """ Convert Protobuf maps to a 1-dimensional PyArrow MapArray with offsets @@ -453,6 +467,7 @@ def _proto_map_to_array( value_field, validity_mask=None, config=config, + descriptor_trace=descriptor_trace, ) return pa.MapArray.from_arrays(offsets, keys, values).cast( pa.map_( @@ -490,15 +505,28 @@ def _proto_field_validity_mask( return mask +def _raise_recursion_error(descriptor_name: str, trace: List[str]): + raise ProtarrowCycleError( + "Cyclical structure detected in protobuf message " + f"{descriptor_name}, with trace: [{', '.join(trace)}]." + " Consider setting 'purge_cyclical_messages=True'" + "in ProtarrowConfig." + ) + + def _messages_to_array( messages: Iterable[Message], descriptor: Descriptor, validity_mask: Optional[Sequence[bool]], config: ProtarrowConfig, + descriptor_trace: Optional[List[str]] = None, ) -> pa.StructArray: arrays = [] fields = [] + if descriptor_trace is None: + descriptor_trace = [] + for field_descriptor in descriptor.fields: if ( field_descriptor.type == FieldDescriptor.TYPE_MESSAGE @@ -511,14 +539,36 @@ def _messages_to_array( field_values = NestedIterable( messages, operator.attrgetter(field_descriptor.name) ) + + is_cycle = descriptor.name in descriptor_trace + is_repeated = field_descriptor.label == FieldDescriptorProto.LABEL_REPEATED + this_trace = descriptor_trace + [descriptor.name] + + if is_cycle and (is_map(field_descriptor) or is_repeated): + _raise_recursion_error(descriptor.name, this_trace) + if is_map(field_descriptor): - array = _proto_map_to_array(field_values, field_descriptor, config) - elif field_descriptor.label == FieldDescriptorProto.LABEL_REPEATED: - array = _repeated_proto_to_array(field_values, field_descriptor, config) + array = _proto_map_to_array( + field_values, field_descriptor, config, this_trace + ) + elif is_repeated: + array = _repeated_proto_to_array( + field_values, field_descriptor, config, this_trace + ) else: + if is_cycle: + if config.purge_cyclical_messages: + continue + else: + _raise_recursion_error(descriptor.name, this_trace) + mask = _proto_field_validity_mask(messages, field_descriptor) array = _proto_field_to_array( - field_values, field_descriptor, validity_mask=mask, config=config + field_values, + field_descriptor, + validity_mask=mask, + config=config, + descriptor_trace=this_trace, ) arrays.append(array) diff --git a/protos/bench.proto b/protos/bench.proto index 5c98502..4677647 100644 --- a/protos/bench.proto +++ b/protos/bench.proto @@ -187,3 +187,36 @@ message SuperNestedExampleMessage { map nested_example_message_string_map = 5; } + +// Recursion tests: self-reference +message CyclicalDirectMessage { + CyclicalDirectMessage next = 1; + int32 depth = 2; +} + +// Recursion tests: indirect cyclical +message CyclicalIndirectMessageLevel1 { + CyclicalIndirectMessageLevel2 next = 1; + string name = 2; +} +message CyclicalIndirectMessageLevel2 { + CyclicalIndirectMessageLevel3 next = 1; + string name = 2; +} +message CyclicalIndirectMessageLevel3 { + CyclicalIndirectMessageLevel1 next = 1; + string name = 2; +} + + +// Recursion tests: repeated self-reference +message CyclicalRepeatedMessage { + repeated CyclicalRepeatedMessage children = 1; + int32 depth = 2; +} + +// Recursion tests: map self-reference +message CyclicalMapMessage { + map children_map = 1; + string name = 2; +} diff --git a/pyproject.toml b/pyproject.toml index 1086293..b1121e8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,9 @@ files = ["*/__init__.py"] folders = [{path = "protarrow"}] [tool.ruff] +ignore = [ + "I001" # suppresses noisy "Import block is un-sorted or un-formatted" +] line-length = 88 [tool.ruff.lint] @@ -85,3 +88,6 @@ select = ["E", "F", "C", "I", "PERF"] [tool.ruff.lint.isort] known-first-party = ["protarrow", "protarrow_protos"] + +[tool.ruff.mccabe] +max-complexity = 11 diff --git a/scripts/template.proto.in b/scripts/template.proto.in index abefd88..0b2f40f 100644 --- a/scripts/template.proto.in +++ b/scripts/template.proto.in @@ -70,3 +70,31 @@ message SuperNestedExampleMessage { map<{{map_key}}, ExampleMessage> nested_example_message_{{map_key}}_map = {{ 3 + loop.index}}; {% endfor %} } + +// Recursion tests: self-reference +message CyclicalDirectMessage { + CyclicalDirectMessage next = 1; + int32 depth = 2; +} + +// Recursion tests: indirect cyclical +{% set RECURSION_DEPTH = 3 -%} +{% for i in range(1, RECURSION_DEPTH + 1) -%} +message CyclicalIndirectMessageLevel{{ i }} { + {% set next_level = i + 1 if i < RECURSION_DEPTH else 1 -%} + CyclicalIndirectMessageLevel{{ next_level }} next = 1; + string name = 2; +} +{% endfor %} + +// Recursion tests: repeated self-reference +message CyclicalRepeatedMessage { + repeated CyclicalRepeatedMessage children = 1; + int32 depth = 2; +} + +// Recursion tests: map self-reference +message CyclicalMapMessage { + map children_map = 1; + string name = 2; +} diff --git a/tests/data/CyclicalDirectMessage.jsonl b/tests/data/CyclicalDirectMessage.jsonl new file mode 100644 index 0000000..4023708 --- /dev/null +++ b/tests/data/CyclicalDirectMessage.jsonl @@ -0,0 +1 @@ +{"depth": 1, "next": {"depth": 2, "next": {"depth": 3, "next": {"depth": 4, "next": {"depth": 5, "next": {"depth": 6, "next": {}}}}}}} diff --git a/tests/data/CyclicalIndirectMessageLevel1.jsonl b/tests/data/CyclicalIndirectMessageLevel1.jsonl new file mode 100644 index 0000000..e3e7ffc --- /dev/null +++ b/tests/data/CyclicalIndirectMessageLevel1.jsonl @@ -0,0 +1 @@ +{"name": "L1", "next": {"name": "L2", "next": {"name": "L3", "next": {"name": "L4_CYCLE"}}}} diff --git a/tests/data/CyclicalMapMessage.jsonl b/tests/data/CyclicalMapMessage.jsonl new file mode 100644 index 0000000..114ff16 --- /dev/null +++ b/tests/data/CyclicalMapMessage.jsonl @@ -0,0 +1 @@ +{"name": "L1", "children_map": {"A": {"name": "L2", "children_map": {"B": {"name": "L3", "children_map": {"C": {"name": "L4"}}}}}}} diff --git a/tests/data/CyclicalRepeatedMessage.jsonl b/tests/data/CyclicalRepeatedMessage.jsonl new file mode 100644 index 0000000..820c8b6 --- /dev/null +++ b/tests/data/CyclicalRepeatedMessage.jsonl @@ -0,0 +1 @@ +{"depth":1,"children":[{"depth":2,"children":[{"depth":3,"children":[]}]}]} diff --git a/tests/test_conversion_cyclical.py b/tests/test_conversion_cyclical.py new file mode 100644 index 0000000..b0f5c0f --- /dev/null +++ b/tests/test_conversion_cyclical.py @@ -0,0 +1,149 @@ +# Imports sorted alphabetically +import pathlib +import pytest + +# 'from' imports, sorted alphabetically by module +from google.protobuf.json_format import Parse +from google.protobuf.message import Message +from protarrow.common import M, ProtarrowConfig +from protarrow.proto_to_arrow import ( + messages_to_record_batch, + ProtarrowCycleError, + messages_to_table, +) +from protarrow_protos.bench_pb2 import ( + CyclicalDirectMessage, + CyclicalIndirectMessageLevel1, + CyclicalMapMessage, + CyclicalRepeatedMessage, +) +from typing import List, Type + +CONFIGS = [ + ProtarrowConfig(purge_cyclical_messages=False), + ProtarrowConfig(purge_cyclical_messages=True), +] +DIR = pathlib.Path(__file__).parent + + +def read_proto_jsonl(path: pathlib.Path, message_type: Type[M]) -> List[M]: + with path.open() as fp: + return [ + Parse(line.strip(), message_type()) + for line in fp + if line.strip() and not line.startswith("#") + ] + + +def _load_data(filename: str, message_type: Type[Message]) -> List[Message]: + """Loads messages from the specific test data file.""" + source_file = DIR / "data" / filename + source_messages = read_proto_jsonl(source_file, message_type) + if not source_messages: + pytest.skip(f"Test data file {filename} is empty or missing.") + return source_messages + + +# ==================================================================== +# DIRECT SELF-REFERENCE +# X X +# A - Y => A +# A +# ==================================================================== +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_direct_message_handling(config: ProtarrowConfig): + messages = _load_data("CyclicalDirectMessage.jsonl", CyclicalDirectMessage) + + if not config.purge_cyclical_messages: + with pytest.raises(ProtarrowCycleError): + messages_to_record_batch(messages, CyclicalDirectMessage, config) + + with pytest.raises(ProtarrowCycleError): + messages_to_table(messages, CyclicalDirectMessage, config) + + else: + rb = messages_to_record_batch(messages, CyclicalDirectMessage, config) + assert len(rb) == len(messages) + assert rb.num_columns == 2 + assert rb["next"].type.num_fields == 0 + + +# ==================================================================== +# INDIRECT CYCLE +# L1 L1 +# A - L2 => A - L2 +# B - L3 B - L3 +# C - L4 C +# A +# ==================================================================== +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_indirect_message_handling(config: ProtarrowConfig): + messages = _load_data( + "CyclicalIndirectMessageLevel1.jsonl", CyclicalIndirectMessageLevel1 + ) + + if not config.purge_cyclical_messages: + with pytest.raises(ProtarrowCycleError): + messages_to_record_batch(messages, CyclicalIndirectMessageLevel1, config) + + with pytest.raises(ProtarrowCycleError): + messages_to_table(messages, CyclicalIndirectMessageLevel1, config) + + else: + rb = messages_to_record_batch(messages, CyclicalIndirectMessageLevel1, config) + assert len(rb) == len(messages) + assert rb.num_columns == 2 + assert rb.schema.names == ["next", "name"] + + datadict = rb.to_pylist()[0] + # Levels 1 to 3 + for i, level_name in enumerate(["L1", "L2", "L3"]): + assert datadict["name"] == level_name + datadict = datadict["next"] + + # Level 4 should have been pruned due to its type being + assert not datadict + + +# ==================================================================== +# CYCLICAL REPEATED MESSAGE +# L1 L1 +# - - +# A A +# A - L2 => A +# A - A +# - A - +# A +# A +# - +# ==================================================================== +# We only support cycle detection and exception raising here +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_repeated_message_handling(config: ProtarrowConfig): + messages = _load_data("CyclicalRepeatedMessage.jsonl", CyclicalRepeatedMessage) + + with pytest.raises(ProtarrowCycleError): + messages_to_record_batch(messages, CyclicalRepeatedMessage, config) + + with pytest.raises(ProtarrowCycleError): + messages_to_table(messages, CyclicalRepeatedMessage, config) + + +# ==================================================================== +# CYCLICAL MAP MESSAGE +# L1 k1 L1 k1 +# | +# {L2 k2} => +# | +# {L3 k3} +# ==================================================================== +# We only support cycle detection and exception raising here +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_map_message_handling(config: ProtarrowConfig): + messages = _load_data("CyclicalMapMessage.jsonl", CyclicalMapMessage) + + with pytest.raises(ProtarrowCycleError): + messages_to_record_batch(messages, CyclicalMapMessage, config) + + with pytest.raises(ProtarrowCycleError): + messages_to_table(messages, CyclicalMapMessage, config) From ab792a88e1a9f57f3ea6e7083f3e25848f7f6c46 Mon Sep 17 00:00:00 2001 From: David Sierra-Gonzalez Date: Fri, 24 Oct 2025 13:41:44 +0000 Subject: [PATCH 2/5] Addresses review comments and revamps conversion tests for recursive messages. --- protarrow/common.py | 2 +- protarrow/proto_to_arrow.py | 47 +-- protos/bench.proto | 42 +-- pyproject.toml | 8 - scripts/generate_proto.py | 6 +- scripts/template.proto.in | 35 +- tests/data/CyclicalDirectMessage.jsonl | 1 - .../data/CyclicalIndirectMessageLevel1.jsonl | 1 - tests/data/CyclicalMapMessage.jsonl | 1 - tests/data/CyclicalRepeatedMessage.jsonl | 1 - tests/data/RecursiveNestedMessageLevel1.jsonl | 3 + .../RecursiveSelfReferentialMapMessage.jsonl | 3 + .../RecursiveSelfReferentialMessage.jsonl | 3 + ...ursiveSelfReferentialRepeatedMessage.jsonl | 3 + tests/test_conversion_cyclical.py | 149 -------- tests/test_conversion_recursive_messages.py | 321 ++++++++++++++++++ 16 files changed, 394 insertions(+), 232 deletions(-) delete mode 100644 tests/data/CyclicalDirectMessage.jsonl delete mode 100644 tests/data/CyclicalIndirectMessageLevel1.jsonl delete mode 100644 tests/data/CyclicalMapMessage.jsonl delete mode 100644 tests/data/CyclicalRepeatedMessage.jsonl create mode 100644 tests/data/RecursiveNestedMessageLevel1.jsonl create mode 100644 tests/data/RecursiveSelfReferentialMapMessage.jsonl create mode 100644 tests/data/RecursiveSelfReferentialMessage.jsonl create mode 100644 tests/data/RecursiveSelfReferentialRepeatedMessage.jsonl delete mode 100644 tests/test_conversion_cyclical.py create mode 100644 tests/test_conversion_recursive_messages.py diff --git a/protarrow/common.py b/protarrow/common.py index a9cfce9..7c5d463 100644 --- a/protarrow/common.py +++ b/protarrow/common.py @@ -39,7 +39,7 @@ class ProtarrowConfig: string_type: pa.DataType = pa.string() binary_type: pa.DataType = pa.binary() list_array_type: type = pa.ListArray - purge_cyclical_messages: bool = False + skip_recursive_messages: bool = False def __post_init__(self): assert self.enum_type in SUPPORTED_ENUM_TYPES diff --git a/protarrow/proto_to_arrow.py b/protarrow/proto_to_arrow.py index 26b0f73..5c093f2 100644 --- a/protarrow/proto_to_arrow.py +++ b/protarrow/proto_to_arrow.py @@ -67,12 +67,6 @@ } -class ProtarrowCycleError(Exception): - """Raised when a cycle is found and cannot be safely processed.""" - - pass - - def _time_of_day_to_nanos(time_of_day: TimeOfDay) -> int: return ( (time_of_day.hours * 60 + time_of_day.minutes) * 60 + time_of_day.seconds @@ -422,7 +416,7 @@ def _repeated_proto_to_array( repeated_values: Iterable[RepeatedScalarFieldContainer], field_descriptor: FieldDescriptor, config: ProtarrowConfig, - descriptor_trace: Optional[List[str]] = None, + descriptor_trace: Tuple[Descriptor, ...] = (), ) -> pa.ListArray: """ Convert Protobuf embedded lists to a 1-dimensional PyArrow ListArray with offsets @@ -447,7 +441,7 @@ def _proto_map_to_array( maps: Iterable[MessageMap], field_descriptor: FieldDescriptor, config: ProtarrowConfig = ProtarrowConfig(), - descriptor_trace: Optional[List[str]] = None, + descriptor_trace: Tuple[Descriptor, ...] = (), ) -> pa.MapArray: """ Convert Protobuf maps to a 1-dimensional PyArrow MapArray with offsets @@ -505,11 +499,13 @@ def _proto_field_validity_mask( return mask -def _raise_recursion_error(descriptor_name: str, trace: List[str]): - raise ProtarrowCycleError( - "Cyclical structure detected in protobuf message " - f"{descriptor_name}, with trace: [{', '.join(trace)}]." - " Consider setting 'purge_cyclical_messages=True'" +def _raise_recursion_error(trace: Tuple[Descriptor, ...]): + trace_names = (d.full_name for d in trace) + + raise TypeError( + "Cyclical structure detected in the protobuf message. " + f"Full trace: ({', '.join(trace_names)})." + " Consider setting 'skip_recursive_messages=True'" "in ProtarrowConfig." ) @@ -519,14 +515,11 @@ def _messages_to_array( descriptor: Descriptor, validity_mask: Optional[Sequence[bool]], config: ProtarrowConfig, - descriptor_trace: Optional[List[str]] = None, + descriptor_trace: Tuple[Descriptor, ...] = (), ) -> pa.StructArray: arrays = [] fields = [] - if descriptor_trace is None: - descriptor_trace = [] - for field_descriptor in descriptor.fields: if ( field_descriptor.type == FieldDescriptor.TYPE_MESSAGE @@ -540,28 +533,22 @@ def _messages_to_array( messages, operator.attrgetter(field_descriptor.name) ) - is_cycle = descriptor.name in descriptor_trace - is_repeated = field_descriptor.label == FieldDescriptorProto.LABEL_REPEATED - this_trace = descriptor_trace + [descriptor.name] - - if is_cycle and (is_map(field_descriptor) or is_repeated): - _raise_recursion_error(descriptor.name, this_trace) + this_trace = descriptor_trace + (descriptor,) + if descriptor in descriptor_trace: + if config.skip_recursive_messages: + continue + else: + _raise_recursion_error(this_trace) if is_map(field_descriptor): array = _proto_map_to_array( field_values, field_descriptor, config, this_trace ) - elif is_repeated: + elif field_descriptor.label == FieldDescriptorProto.LABEL_REPEATED: array = _repeated_proto_to_array( field_values, field_descriptor, config, this_trace ) else: - if is_cycle: - if config.purge_cyclical_messages: - continue - else: - _raise_recursion_error(descriptor.name, this_trace) - mask = _proto_field_validity_mask(messages, field_descriptor) array = _proto_field_to_array( field_values, diff --git a/protos/bench.proto b/protos/bench.proto index 4677647..8de65f3 100644 --- a/protos/bench.proto +++ b/protos/bench.proto @@ -188,35 +188,35 @@ message SuperNestedExampleMessage { } -// Recursion tests: self-reference -message CyclicalDirectMessage { - CyclicalDirectMessage next = 1; +// Recursive message tests: self-reference +message RecursiveSelfReferentialMessage { + RecursiveSelfReferentialMessage next = 1; int32 depth = 2; } -// Recursion tests: indirect cyclical -message CyclicalIndirectMessageLevel1 { - CyclicalIndirectMessageLevel2 next = 1; - string name = 2; +// Recursive message tests: nested +message RecursiveNestedMessageLevel1 { + string name = 1; + RecursiveNestedMessageLevel2 next = 2; } -message CyclicalIndirectMessageLevel2 { - CyclicalIndirectMessageLevel3 next = 1; - string name = 2; +message RecursiveNestedMessageLevel2 { + string name = 1; + RecursiveNestedMessageLevel3 next = 2; } -message CyclicalIndirectMessageLevel3 { - CyclicalIndirectMessageLevel1 next = 1; - string name = 2; +message RecursiveNestedMessageLevel3 { + string name = 1; + RecursiveNestedMessageLevel1 next = 2; } -// Recursion tests: repeated self-reference -message CyclicalRepeatedMessage { - repeated CyclicalRepeatedMessage children = 1; - int32 depth = 2; +// Recursive message tests: repeated self-reference +message RecursiveSelfReferentialRepeatedMessage { + int32 depth = 1; + repeated RecursiveSelfReferentialRepeatedMessage children = 2; } -// Recursion tests: map self-reference -message CyclicalMapMessage { - map children_map = 1; - string name = 2; +// Recursive message tests: map self-reference +message RecursiveSelfReferentialMapMessage { + string name = 1; + map children_map = 2; } diff --git a/pyproject.toml b/pyproject.toml index b1121e8..74bd200 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,11 +76,6 @@ enable = true [tool.poetry-dynamic-versioning.substitution] files = ["*/__init__.py"] folders = [{path = "protarrow"}] - -[tool.ruff] -ignore = [ - "I001" # suppresses noisy "Import block is un-sorted or un-formatted" -] line-length = 88 [tool.ruff.lint] @@ -88,6 +83,3 @@ select = ["E", "F", "C", "I", "PERF"] [tool.ruff.lint.isort] known-first-party = ["protarrow", "protarrow_protos"] - -[tool.ruff.mccabe] -max-complexity = 11 diff --git a/scripts/generate_proto.py b/scripts/generate_proto.py index a8ea8d6..fb6fcbf 100644 --- a/scripts/generate_proto.py +++ b/scripts/generate_proto.py @@ -71,11 +71,15 @@ def can_be_optional(self) -> bool: MAP_KEYS = ["int32", "string"] +RECURSION_DEPTH = 3 # Level at which recursion occurs on a nested recursive message + def generate(): env = Environment(loader=FileSystemLoader(DIR.as_posix()), autoescape=True) template = env.get_template("template.proto.in") - generated = template.render(types=TYPES, map_keys=MAP_KEYS) + generated = template.render( + types=TYPES, map_keys=MAP_KEYS, recursion_depth=RECURSION_DEPTH + ) with (DIR.parent / "protos" / "bench.proto").open("w") as fp: fp.write(generated) diff --git a/scripts/template.proto.in b/scripts/template.proto.in index 0b2f40f..6b84f3e 100644 --- a/scripts/template.proto.in +++ b/scripts/template.proto.in @@ -71,30 +71,29 @@ message SuperNestedExampleMessage { {% endfor %} } -// Recursion tests: self-reference -message CyclicalDirectMessage { - CyclicalDirectMessage next = 1; +// Recursive message tests: self-reference +message RecursiveSelfReferentialMessage { + RecursiveSelfReferentialMessage next = 1; int32 depth = 2; } -// Recursion tests: indirect cyclical -{% set RECURSION_DEPTH = 3 -%} -{% for i in range(1, RECURSION_DEPTH + 1) -%} -message CyclicalIndirectMessageLevel{{ i }} { - {% set next_level = i + 1 if i < RECURSION_DEPTH else 1 -%} - CyclicalIndirectMessageLevel{{ next_level }} next = 1; - string name = 2; +// Recursive message tests: nested +{% for i in range(1, recursion_depth + 1) -%} +message RecursiveNestedMessageLevel{{ i }} { + {% set next_level = i + 1 if i < recursion_depth else 1 -%} + string name = 1; + RecursiveNestedMessageLevel{{ next_level }} next = 2; } {% endfor %} -// Recursion tests: repeated self-reference -message CyclicalRepeatedMessage { - repeated CyclicalRepeatedMessage children = 1; - int32 depth = 2; +// Recursive message tests: repeated self-reference +message RecursiveSelfReferentialRepeatedMessage { + int32 depth = 1; + repeated RecursiveSelfReferentialRepeatedMessage children = 2; } -// Recursion tests: map self-reference -message CyclicalMapMessage { - map children_map = 1; - string name = 2; +// Recursive message tests: map self-reference +message RecursiveSelfReferentialMapMessage { + string name = 1; + map children_map = 2; } diff --git a/tests/data/CyclicalDirectMessage.jsonl b/tests/data/CyclicalDirectMessage.jsonl deleted file mode 100644 index 4023708..0000000 --- a/tests/data/CyclicalDirectMessage.jsonl +++ /dev/null @@ -1 +0,0 @@ -{"depth": 1, "next": {"depth": 2, "next": {"depth": 3, "next": {"depth": 4, "next": {"depth": 5, "next": {"depth": 6, "next": {}}}}}}} diff --git a/tests/data/CyclicalIndirectMessageLevel1.jsonl b/tests/data/CyclicalIndirectMessageLevel1.jsonl deleted file mode 100644 index e3e7ffc..0000000 --- a/tests/data/CyclicalIndirectMessageLevel1.jsonl +++ /dev/null @@ -1 +0,0 @@ -{"name": "L1", "next": {"name": "L2", "next": {"name": "L3", "next": {"name": "L4_CYCLE"}}}} diff --git a/tests/data/CyclicalMapMessage.jsonl b/tests/data/CyclicalMapMessage.jsonl deleted file mode 100644 index 114ff16..0000000 --- a/tests/data/CyclicalMapMessage.jsonl +++ /dev/null @@ -1 +0,0 @@ -{"name": "L1", "children_map": {"A": {"name": "L2", "children_map": {"B": {"name": "L3", "children_map": {"C": {"name": "L4"}}}}}}} diff --git a/tests/data/CyclicalRepeatedMessage.jsonl b/tests/data/CyclicalRepeatedMessage.jsonl deleted file mode 100644 index 820c8b6..0000000 --- a/tests/data/CyclicalRepeatedMessage.jsonl +++ /dev/null @@ -1 +0,0 @@ -{"depth":1,"children":[{"depth":2,"children":[{"depth":3,"children":[]}]}]} diff --git a/tests/data/RecursiveNestedMessageLevel1.jsonl b/tests/data/RecursiveNestedMessageLevel1.jsonl new file mode 100644 index 0000000..e98b941 --- /dev/null +++ b/tests/data/RecursiveNestedMessageLevel1.jsonl @@ -0,0 +1,3 @@ +{"name": "M1_L1", "next": {"name": "M1_L2", "next": {"name": "M1_L3", "next": {"name": "M1_L1_CYCLE"}}}} +{"name": "M2_L1", "next": {"name": "M2_L2", "next": {"name": "M2_L3", "next": {"name": "M2_L1_CYCLE"}}}} +{"name": "M3_L1", "next": {"name": "M3_L2", "next": {"name": "M3_L3", "next": {"name": "M3_L1_CYCLE"}}}} diff --git a/tests/data/RecursiveSelfReferentialMapMessage.jsonl b/tests/data/RecursiveSelfReferentialMapMessage.jsonl new file mode 100644 index 0000000..63ca79b --- /dev/null +++ b/tests/data/RecursiveSelfReferentialMapMessage.jsonl @@ -0,0 +1,3 @@ +{"name": "M1_L1", "children_map": {"A": {"name": "M1_L2", "children_map": {"B": {"name": "M1_L3", "children_map": {"C": {"name": "M1_L4"}}}}}}} +{"name": "M2_L1", "children_map": {"D": {"name": "M2_L2", "children_map": {}}, "E": {"name": "M2_L2b", "children_map": {}}}} +{"name": "M3_L1", "children_map": {}} diff --git a/tests/data/RecursiveSelfReferentialMessage.jsonl b/tests/data/RecursiveSelfReferentialMessage.jsonl new file mode 100644 index 0000000..26e2e00 --- /dev/null +++ b/tests/data/RecursiveSelfReferentialMessage.jsonl @@ -0,0 +1,3 @@ +{"depth": 1, "next": {"depth": 2, "next": {"depth": 3, "next": {}}}} +{"depth": 11, "next": {"depth": 12, "next": {"depth": 13, "next": {}}}} +{"depth": 21, "next": {"depth": 22, "next": {"depth": 23, "next": {}}}} diff --git a/tests/data/RecursiveSelfReferentialRepeatedMessage.jsonl b/tests/data/RecursiveSelfReferentialRepeatedMessage.jsonl new file mode 100644 index 0000000..43ccc2f --- /dev/null +++ b/tests/data/RecursiveSelfReferentialRepeatedMessage.jsonl @@ -0,0 +1,3 @@ +{"depth": 1, "children": [{"depth": 2, "children": [{"depth": 3, "children": []}]}]} +{"depth": 11, "children": [{"depth": 12, "children": [{"depth": 13, "children": []}]}, {"depth": 14, "children": []}]} +{"depth": 21, "children": []} diff --git a/tests/test_conversion_cyclical.py b/tests/test_conversion_cyclical.py deleted file mode 100644 index b0f5c0f..0000000 --- a/tests/test_conversion_cyclical.py +++ /dev/null @@ -1,149 +0,0 @@ -# Imports sorted alphabetically -import pathlib -import pytest - -# 'from' imports, sorted alphabetically by module -from google.protobuf.json_format import Parse -from google.protobuf.message import Message -from protarrow.common import M, ProtarrowConfig -from protarrow.proto_to_arrow import ( - messages_to_record_batch, - ProtarrowCycleError, - messages_to_table, -) -from protarrow_protos.bench_pb2 import ( - CyclicalDirectMessage, - CyclicalIndirectMessageLevel1, - CyclicalMapMessage, - CyclicalRepeatedMessage, -) -from typing import List, Type - -CONFIGS = [ - ProtarrowConfig(purge_cyclical_messages=False), - ProtarrowConfig(purge_cyclical_messages=True), -] -DIR = pathlib.Path(__file__).parent - - -def read_proto_jsonl(path: pathlib.Path, message_type: Type[M]) -> List[M]: - with path.open() as fp: - return [ - Parse(line.strip(), message_type()) - for line in fp - if line.strip() and not line.startswith("#") - ] - - -def _load_data(filename: str, message_type: Type[Message]) -> List[Message]: - """Loads messages from the specific test data file.""" - source_file = DIR / "data" / filename - source_messages = read_proto_jsonl(source_file, message_type) - if not source_messages: - pytest.skip(f"Test data file {filename} is empty or missing.") - return source_messages - - -# ==================================================================== -# DIRECT SELF-REFERENCE -# X X -# A - Y => A -# A -# ==================================================================== -@pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_direct_message_handling(config: ProtarrowConfig): - messages = _load_data("CyclicalDirectMessage.jsonl", CyclicalDirectMessage) - - if not config.purge_cyclical_messages: - with pytest.raises(ProtarrowCycleError): - messages_to_record_batch(messages, CyclicalDirectMessage, config) - - with pytest.raises(ProtarrowCycleError): - messages_to_table(messages, CyclicalDirectMessage, config) - - else: - rb = messages_to_record_batch(messages, CyclicalDirectMessage, config) - assert len(rb) == len(messages) - assert rb.num_columns == 2 - assert rb["next"].type.num_fields == 0 - - -# ==================================================================== -# INDIRECT CYCLE -# L1 L1 -# A - L2 => A - L2 -# B - L3 B - L3 -# C - L4 C -# A -# ==================================================================== -@pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_indirect_message_handling(config: ProtarrowConfig): - messages = _load_data( - "CyclicalIndirectMessageLevel1.jsonl", CyclicalIndirectMessageLevel1 - ) - - if not config.purge_cyclical_messages: - with pytest.raises(ProtarrowCycleError): - messages_to_record_batch(messages, CyclicalIndirectMessageLevel1, config) - - with pytest.raises(ProtarrowCycleError): - messages_to_table(messages, CyclicalIndirectMessageLevel1, config) - - else: - rb = messages_to_record_batch(messages, CyclicalIndirectMessageLevel1, config) - assert len(rb) == len(messages) - assert rb.num_columns == 2 - assert rb.schema.names == ["next", "name"] - - datadict = rb.to_pylist()[0] - # Levels 1 to 3 - for i, level_name in enumerate(["L1", "L2", "L3"]): - assert datadict["name"] == level_name - datadict = datadict["next"] - - # Level 4 should have been pruned due to its type being - assert not datadict - - -# ==================================================================== -# CYCLICAL REPEATED MESSAGE -# L1 L1 -# - - -# A A -# A - L2 => A -# A - A -# - A - -# A -# A -# - -# ==================================================================== -# We only support cycle detection and exception raising here -@pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_repeated_message_handling(config: ProtarrowConfig): - messages = _load_data("CyclicalRepeatedMessage.jsonl", CyclicalRepeatedMessage) - - with pytest.raises(ProtarrowCycleError): - messages_to_record_batch(messages, CyclicalRepeatedMessage, config) - - with pytest.raises(ProtarrowCycleError): - messages_to_table(messages, CyclicalRepeatedMessage, config) - - -# ==================================================================== -# CYCLICAL MAP MESSAGE -# L1 k1 L1 k1 -# | -# {L2 k2} => -# | -# {L3 k3} -# ==================================================================== -# We only support cycle detection and exception raising here -@pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_map_message_handling(config: ProtarrowConfig): - messages = _load_data("CyclicalMapMessage.jsonl", CyclicalMapMessage) - - with pytest.raises(ProtarrowCycleError): - messages_to_record_batch(messages, CyclicalMapMessage, config) - - with pytest.raises(ProtarrowCycleError): - messages_to_table(messages, CyclicalMapMessage, config) diff --git a/tests/test_conversion_recursive_messages.py b/tests/test_conversion_recursive_messages.py new file mode 100644 index 0000000..8e30308 --- /dev/null +++ b/tests/test_conversion_recursive_messages.py @@ -0,0 +1,321 @@ +import pathlib +from typing import List, Type + +import pyarrow as pa +import pytest +from google.protobuf.message import Message + +from protarrow.common import ProtarrowConfig +from protarrow.proto_to_arrow import ( + messages_to_record_batch, + messages_to_table, +) +from protarrow_protos.bench_pb2 import ( + RecursiveNestedMessageLevel1, + RecursiveSelfReferentialMapMessage, + RecursiveSelfReferentialMessage, + RecursiveSelfReferentialRepeatedMessage, +) + +from .test_conversion import read_proto_jsonl + +CONFIGS = [ + ProtarrowConfig(skip_recursive_messages=False), + ProtarrowConfig(skip_recursive_messages=True), +] +DIR = pathlib.Path(__file__).parent + + +def _load_data(filename: str, message_type: Type[Message]) -> List[Message]: + """Loads messages from the specific test data file.""" + source_file = DIR / "data" / filename + source_messages = read_proto_jsonl(source_file, message_type) + if not source_messages: + raise ValueError(f"Found empty test file: {source_file}") + return source_messages + + +# ==================================================================== +# DIRECT SELF-REFERENCE +# mes A: mes A: +# mes A: => (ES) +# +# (ES): empty struct +# ==================================================================== +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_direct_message_handling(config: ProtarrowConfig): + messages = _load_data( + "RecursiveSelfReferentialMessage.jsonl", RecursiveSelfReferentialMessage + ) + + if not config.skip_recursive_messages: + fqn = "protarrow.protos.RecursiveSelfReferentialMessage".replace(".", r"\.") + regex_pattern = r"(.*" + f"{fqn}, {fqn}" + r".*)" + + with pytest.raises(TypeError, match=regex_pattern): + messages_to_record_batch(messages, RecursiveSelfReferentialMessage, config) + + with pytest.raises(TypeError, match=regex_pattern): + messages_to_table(messages, RecursiveSelfReferentialMessage, config) + + else: + rb = messages_to_record_batch(messages, RecursiveSelfReferentialMessage, config) + + # Check schema + expected_schema = pa.schema( + [ + pa.field("next", pa.struct([])), + pa.field("depth", pa.int32(), nullable=False), + ] + ) + assert rb.schema == expected_schema + + # Check values + expected_depth_array = pa.array([1, 11, 21], type=pa.int32()) + expected_next_array = pa.StructArray.from_arrays( + arrays=[], + fields=[], + mask=pa.array([False] * len(expected_depth_array), pa.bool_()), + ) + expected_table = pa.Table.from_arrays( + [expected_next_array, expected_depth_array], schema=expected_schema + ) + expected_table = pa.Table.from_arrays( + [expected_next_array, expected_depth_array], schema=expected_schema + ) + + actual_table = pa.Table.from_batches([rb]) + assert actual_table.equals(expected_table) + + +# ==================================================================== +# INDIRECT RECURSIVE CYCLE +# mes A: mes A: +# mes B: => mes B: +# mes A: (ES) +# +# (ES): empty struct +# ==================================================================== +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_indirect_message_handling(config: ProtarrowConfig): + messages = _load_data( + "RecursiveNestedMessageLevel1.jsonl", RecursiveNestedMessageLevel1 + ) + + if not config.skip_recursive_messages: + fqn1 = "protarrow.protos.RecursiveNestedMessageLevel1".replace(".", r"\.") + fqn2 = "protarrow.protos.RecursiveNestedMessageLevel2".replace(".", r"\.") + fqn3 = "protarrow.protos.RecursiveNestedMessageLevel3".replace(".", r"\.") + expected_trace_string = f"{fqn1}, {fqn2}, {fqn3}, {fqn1}" + regex_pattern = r"(.*" + f"{expected_trace_string}" + r".*)" + + with pytest.raises(TypeError, match=regex_pattern): + messages_to_record_batch(messages, RecursiveNestedMessageLevel1, config) + + with pytest.raises(TypeError, match=regex_pattern): + messages_to_table(messages, RecursiveNestedMessageLevel1, config) + + else: + rb = messages_to_record_batch(messages, RecursiveNestedMessageLevel1, config) + + # Check schema + pruned_struct = pa.struct([]) + level3_struct = pa.struct( + [ + pa.field("name", pa.string(), nullable=False), + pa.field("next", pruned_struct), + ] + ) + level2_struct = pa.struct( + [ + pa.field("name", pa.string(), nullable=False), + pa.field("next", level3_struct), + ] + ) + expected_schema = pa.schema( + [ + pa.field("name", pa.string(), nullable=False), + pa.field("next", level2_struct), + ] + ) + assert rb.schema == expected_schema + + # Check values + num_rows = 3 + + level3_name_array = pa.array( + [f"M{i}_L3" for i in range(1, num_rows + 1)], pa.string() + ) + level3_pruned_array = pa.StructArray.from_arrays( + arrays=[], + fields=[], + mask=pa.array([False] * len(level3_name_array), pa.bool_()), + ) + level3_array = pa.StructArray.from_arrays( + arrays=[level3_name_array, level3_pruned_array], + fields=[level3_struct.field("name"), level3_struct.field("next")], + ) + + level2_name_array = pa.array( + [f"M{i}_L2" for i in range(1, num_rows + 1)], pa.string() + ) + level2_array = pa.StructArray.from_arrays( + arrays=[level2_name_array, level3_array], + fields=[level2_struct.field("name"), level2_struct.field("next")], + ) + + level1_name_array = pa.array( + [f"M{i}_L1" for i in range(1, num_rows + 1)], pa.string() + ) + expected_table = pa.Table.from_arrays( + [level1_name_array, level2_array], + schema=expected_schema, + ) + + actual_table = pa.Table.from_batches([rb]) + assert actual_table.equals(expected_table) + + +# ==================================================================== +# CYCLICAL REPEATED MESSAGE +# mes A: mes A: +# [mes A, mes A, mes A] => [(ES), (ES), (ES)] +# +# (ES): empty struct +# ==================================================================== +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_repeated_message_handling(config: ProtarrowConfig): + messages = _load_data( + "RecursiveSelfReferentialRepeatedMessage.jsonl", + RecursiveSelfReferentialRepeatedMessage, + ) + + fqn = "protarrow.protos.RecursiveSelfReferentialRepeatedMessage".replace(".", r"\.") + # The expected trace is the recursive field name repeated twice + regex_pattern = r".*" + f"({fqn}, {fqn})" + r".*" + + if not config.skip_recursive_messages: + with pytest.raises(TypeError, match=regex_pattern): + messages_to_record_batch( + messages, RecursiveSelfReferentialRepeatedMessage, config + ) + with pytest.raises(TypeError, match=regex_pattern): + messages_to_table(messages, RecursiveSelfReferentialRepeatedMessage, config) + + else: + rb = messages_to_record_batch( + messages, RecursiveSelfReferentialRepeatedMessage, config + ) + + # Check schema + pruned_item_field = pa.field( + name=config.list_value_name, + type=pa.struct([]), + nullable=config.list_value_nullable, + ) + expected_children_list_type = pa.list_(pruned_item_field) + + expected_schema = pa.schema( + [ + pa.field("depth", pa.int32(), nullable=False), + pa.field( + "children", + expected_children_list_type, + nullable=config.list_value_nullable, + ), + ] + ) + assert rb.schema == expected_schema + + # Check values + expected_depth_array = pa.array([1, 11, 21], pa.int32()) + child_struct_array = pa.StructArray.from_arrays( + arrays=[], + fields=[], + mask=pa.array([False] * len(expected_depth_array), pa.bool_()), + ) + + list_offsets = pa.array([0, 1, 3, 3], pa.int32()) + expected_children_list_array = pa.ListArray.from_arrays( + offsets=list_offsets, + values=child_struct_array, + type=expected_children_list_type, + ) + + expected_table = pa.Table.from_arrays( + [expected_depth_array, expected_children_list_array], schema=expected_schema + ) + actual_table = pa.Table.from_batches([rb]) + + assert actual_table.equals(expected_table) + + +# ==================================================================== +# CYCLICAL MAP MESSAGE +# mes A: mes A: +# map<*, mes A> => map<*, (ES)> +# ==================================================================== +@pytest.mark.parametrize("config", CONFIGS) +def test_cyclical_map_message_handling(config: ProtarrowConfig): + messages = _load_data( + "RecursiveSelfReferentialMapMessage.jsonl", RecursiveSelfReferentialMapMessage + ) + + fqn = "protarrow.protos.RecursiveSelfReferentialMapMessage".replace(".", r"\.") + regex_pattern = r".*" + f"({fqn}, {fqn})" + r".*" + + if not config.skip_recursive_messages: + with pytest.raises(TypeError, match=regex_pattern): + messages_to_record_batch( + messages, RecursiveSelfReferentialMapMessage, config + ) + + with pytest.raises(TypeError, match=regex_pattern): + messages_to_table(messages, RecursiveSelfReferentialMapMessage, config) + + else: + rb = messages_to_record_batch( + messages, RecursiveSelfReferentialMapMessage, config + ) + + # Check schema + pruned_value_struct = pa.struct([]) + key_type = pa.string() + value_field = pa.field( + name=config.map_value_name, + type=pruned_value_struct, + nullable=config.map_value_nullable, + ) + children_map_type = pa.map_(key_type, value_field) + + expected_schema = pa.schema( + [ + pa.field("name", pa.string(), nullable=False), + pa.field("children_map", children_map_type, nullable=False), + ] + ) + assert rb.schema == expected_schema, ( + "Schema mismatch for map self-reference pruning." + ) + + # Check values + key_array = pa.array(["A", "D", "E"], pa.string()) + pruned_value_array = pa.StructArray.from_arrays( + arrays=[], fields=[], mask=pa.array([False] * len(key_array)) + ) + expected_name_array = pa.array(["M1_L1", "M2_L1", "M3_L1"], pa.string()) + # Offsets for 1 item (M1), 2 items (M2), 0 items (M3) + list_offsets = pa.array([0, 1, 3, 3], pa.int32()) + + expected_children_map_array = pa.MapArray.from_arrays( + offsets=list_offsets, + keys=key_array, + items=pruned_value_array, + type=children_map_type, + ) + expected_table = pa.Table.from_arrays( + [expected_name_array, expected_children_map_array], schema=expected_schema + ) + + actual_table = pa.Table.from_batches([rb]) + assert actual_table.equals(expected_table) From 2d35f2750aa8ac54925c203cb7026b8f8abcbf58 Mon Sep 17 00:00:00 2001 From: David Sierra-Gonzalez Date: Fri, 24 Oct 2025 18:32:10 +0000 Subject: [PATCH 3/5] Allow all public functions in proto_to_arrow to handle recursive messages. --- protarrow/proto_to_arrow.py | 64 +++++++++++------ tests/test_conversion_recursive_messages.py | 80 ++++++++++++++++++--- 2 files changed, 116 insertions(+), 28 deletions(-) diff --git a/protarrow/proto_to_arrow.py b/protarrow/proto_to_arrow.py index 5c093f2..3f4938a 100644 --- a/protarrow/proto_to_arrow.py +++ b/protarrow/proto_to_arrow.py @@ -198,6 +198,17 @@ def __len__(self) -> int: return sum(len(i) for i in self.scalar_map if i) +def _raise_recursion_error(trace: Tuple[Descriptor, ...]): + trace_names = (d.full_name for d in trace) + + raise TypeError( + "Cyclical structure detected in the protobuf message. " + f"Full trace: ({', '.join(trace_names)})." + " Consider setting 'skip_recursive_messages=True'" + "in ProtarrowConfig." + ) + + def is_map(field_descriptor: FieldDescriptor) -> bool: return ( field_descriptor.type == FieldDescriptor.TYPE_MESSAGE @@ -249,11 +260,14 @@ def converter(x: int) -> bytes: def field_descriptor_to_field( field_descriptor: FieldDescriptor, config: ProtarrowConfig, + descriptor_trace: Tuple[Descriptor, ...] = (), ) -> pa.Field: if is_map(field_descriptor): key_field, value_field = get_map_descriptors(field_descriptor) - key_type = field_descriptor_to_data_type(key_field, config) - value_type = field_descriptor_to_data_type(value_field, config) + key_type = field_descriptor_to_data_type(key_field, config, descriptor_trace) + value_type = field_descriptor_to_data_type( + value_field, config, descriptor_trace + ) return pa.field( field_descriptor.name, pa.map_( @@ -266,14 +280,18 @@ def field_descriptor_to_field( elif field_descriptor.label == FieldDescriptor.LABEL_REPEATED: return pa.field( field_descriptor.name, - config.list_(field_descriptor_to_data_type(field_descriptor, config)), + config.list_( + field_descriptor_to_data_type( + field_descriptor, config, descriptor_trace + ) + ), nullable=config.list_nullable, metadata=config.field_metadata(field_descriptor.number), ) else: return pa.field( field_descriptor.name, - field_descriptor_to_data_type(field_descriptor, config), + field_descriptor_to_data_type(field_descriptor, config, descriptor_trace), nullable=field_descriptor.has_presence, metadata=config.field_metadata(field_descriptor.number), ) @@ -282,6 +300,7 @@ def field_descriptor_to_field( def _message_field_to_data_type( field_descriptor: FieldDescriptor, config: ProtarrowConfig, + descriptor_trace: Tuple[Descriptor, ...] = (), ) -> pa.DataType: try: return _PROTO_DESCRIPTOR_TO_PYARROW[field_descriptor.message_type] @@ -291,9 +310,19 @@ def _message_field_to_data_type( elif field_descriptor.message_type == StringValue.DESCRIPTOR: return config.string_type else: + descriptor = field_descriptor.message_type + + if descriptor in descriptor_trace: + if config.skip_recursive_messages: + return pa.struct([]) + else: + _raise_recursion_error(descriptor_trace + (descriptor,)) + return pa.struct( [ - field_descriptor_to_field(child_field, config) + field_descriptor_to_field( + child_field, config, descriptor_trace + (descriptor,) + ) for child_field in field_descriptor.message_type.fields ] ) @@ -302,6 +331,7 @@ def _message_field_to_data_type( def field_descriptor_to_data_type( field_descriptor: FieldDescriptor, config: ProtarrowConfig, + descriptor_trace: Tuple[Descriptor, ...] = (), ) -> pa.DataType: if field_descriptor.message_type == Timestamp.DESCRIPTOR: return config.timestamp_type @@ -310,7 +340,7 @@ def field_descriptor_to_data_type( elif field_descriptor.message_type == Duration.DESCRIPTOR: return config.duration_type elif field_descriptor.type == FieldDescriptorProto.TYPE_MESSAGE: - return _message_field_to_data_type(field_descriptor, config) + return _message_field_to_data_type(field_descriptor, config, descriptor_trace) elif field_descriptor.type == FieldDescriptorProto.TYPE_ENUM: return config.enum_type elif field_descriptor.type == FieldDescriptorProto.TYPE_STRING: @@ -499,17 +529,6 @@ def _proto_field_validity_mask( return mask -def _raise_recursion_error(trace: Tuple[Descriptor, ...]): - trace_names = (d.full_name for d in trace) - - raise TypeError( - "Cyclical structure detected in the protobuf message. " - f"Full trace: ({', '.join(trace_names)})." - " Consider setting 'skip_recursive_messages=True'" - "in ProtarrowConfig." - ) - - def _messages_to_array( messages: Iterable[Message], descriptor: Descriptor, @@ -611,20 +630,25 @@ def message_type_to_schema( message_type: Type[Message], config: ProtarrowConfig = ProtarrowConfig(), ) -> pa.Schema: + descriptor_trace = (message_type.DESCRIPTOR,) + return pa.schema( [ - field_descriptor_to_field(field_descriptor, config) + field_descriptor_to_field(field_descriptor, config, descriptor_trace) for field_descriptor in message_type.DESCRIPTOR.fields ] ) def message_type_to_struct_type( - message_type: Type[Message], config: ProtarrowConfig = ProtarrowConfig() + message_type: Type[Message], + config: ProtarrowConfig = ProtarrowConfig(), ) -> pa.StructType: + descriptor_trace = (message_type.DESCRIPTOR,) + return pa.struct( [ - field_descriptor_to_field(field_descriptor, config) + field_descriptor_to_field(field_descriptor, config, descriptor_trace) for field_descriptor in message_type.DESCRIPTOR.fields ] ) diff --git a/tests/test_conversion_recursive_messages.py b/tests/test_conversion_recursive_messages.py index 8e30308..e5e8c1f 100644 --- a/tests/test_conversion_recursive_messages.py +++ b/tests/test_conversion_recursive_messages.py @@ -7,6 +7,8 @@ from protarrow.common import ProtarrowConfig from protarrow.proto_to_arrow import ( + message_type_to_schema, + message_type_to_struct_type, messages_to_record_batch, messages_to_table, ) @@ -36,14 +38,14 @@ def _load_data(filename: str, message_type: Type[Message]) -> List[Message]: # ==================================================================== -# DIRECT SELF-REFERENCE +# RECURSIVE SELF-REFERENTIAL MESSAGES: # mes A: mes A: # mes A: => (ES) # # (ES): empty struct # ==================================================================== @pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_direct_message_handling(config: ProtarrowConfig): +def test_recursive_self_referential_message_handling(config: ProtarrowConfig): messages = _load_data( "RecursiveSelfReferentialMessage.jsonl", RecursiveSelfReferentialMessage ) @@ -58,8 +60,20 @@ def test_cyclical_direct_message_handling(config: ProtarrowConfig): with pytest.raises(TypeError, match=regex_pattern): messages_to_table(messages, RecursiveSelfReferentialMessage, config) + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_schema(RecursiveSelfReferentialMessage, config) + + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_struct_type(RecursiveSelfReferentialMessage, config) + else: rb = messages_to_record_batch(messages, RecursiveSelfReferentialMessage, config) + inferred_schema = message_type_to_schema( + RecursiveSelfReferentialMessage, config + ) + inferred_type = message_type_to_struct_type( + RecursiveSelfReferentialMessage, config + ) # Check schema expected_schema = pa.schema( @@ -68,7 +82,11 @@ def test_cyclical_direct_message_handling(config: ProtarrowConfig): pa.field("depth", pa.int32(), nullable=False), ] ) + expected_type = pa.struct(expected_schema) + assert rb.schema == expected_schema + assert inferred_schema == expected_schema + assert inferred_type == expected_type # Check values expected_depth_array = pa.array([1, 11, 21], type=pa.int32()) @@ -89,7 +107,7 @@ def test_cyclical_direct_message_handling(config: ProtarrowConfig): # ==================================================================== -# INDIRECT RECURSIVE CYCLE +# NESTED RECURSIVE MESSAGES: # mes A: mes A: # mes B: => mes B: # mes A: (ES) @@ -97,7 +115,7 @@ def test_cyclical_direct_message_handling(config: ProtarrowConfig): # (ES): empty struct # ==================================================================== @pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_indirect_message_handling(config: ProtarrowConfig): +def test_recursive_nested_message_handling(config: ProtarrowConfig): messages = _load_data( "RecursiveNestedMessageLevel1.jsonl", RecursiveNestedMessageLevel1 ) @@ -115,8 +133,18 @@ def test_cyclical_indirect_message_handling(config: ProtarrowConfig): with pytest.raises(TypeError, match=regex_pattern): messages_to_table(messages, RecursiveNestedMessageLevel1, config) + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_schema(RecursiveNestedMessageLevel1, config) + + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_struct_type(RecursiveNestedMessageLevel1, config) + else: rb = messages_to_record_batch(messages, RecursiveNestedMessageLevel1, config) + inferred_schema = message_type_to_schema(RecursiveNestedMessageLevel1, config) + inferred_type = message_type_to_struct_type( + RecursiveNestedMessageLevel1, config + ) # Check schema pruned_struct = pa.struct([]) @@ -138,7 +166,11 @@ def test_cyclical_indirect_message_handling(config: ProtarrowConfig): pa.field("next", level2_struct), ] ) + expected_type = pa.struct(expected_schema) + assert rb.schema == expected_schema + assert inferred_schema == expected_schema + assert inferred_type == expected_type # Check values num_rows = 3 @@ -177,14 +209,14 @@ def test_cyclical_indirect_message_handling(config: ProtarrowConfig): # ==================================================================== -# CYCLICAL REPEATED MESSAGE +# RECURSIVE SELF-REFERENTIAL REPEATED MESSAGES # mes A: mes A: # [mes A, mes A, mes A] => [(ES), (ES), (ES)] # # (ES): empty struct # ==================================================================== @pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_repeated_message_handling(config: ProtarrowConfig): +def test_recursive_self_referential_repeated_message_handling(config: ProtarrowConfig): messages = _load_data( "RecursiveSelfReferentialRepeatedMessage.jsonl", RecursiveSelfReferentialRepeatedMessage, @@ -202,10 +234,22 @@ def test_cyclical_repeated_message_handling(config: ProtarrowConfig): with pytest.raises(TypeError, match=regex_pattern): messages_to_table(messages, RecursiveSelfReferentialRepeatedMessage, config) + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_schema(RecursiveSelfReferentialRepeatedMessage, config) + + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_struct_type(RecursiveSelfReferentialRepeatedMessage, config) + else: rb = messages_to_record_batch( messages, RecursiveSelfReferentialRepeatedMessage, config ) + inferred_schema = message_type_to_schema( + RecursiveSelfReferentialRepeatedMessage, config + ) + inferred_type = message_type_to_struct_type( + RecursiveSelfReferentialRepeatedMessage, config + ) # Check schema pruned_item_field = pa.field( @@ -225,7 +269,11 @@ def test_cyclical_repeated_message_handling(config: ProtarrowConfig): ), ] ) + expected_type = pa.struct(expected_schema) + assert rb.schema == expected_schema + assert inferred_schema == expected_schema + assert inferred_type == expected_type # Check values expected_depth_array = pa.array([1, 11, 21], pa.int32()) @@ -251,12 +299,12 @@ def test_cyclical_repeated_message_handling(config: ProtarrowConfig): # ==================================================================== -# CYCLICAL MAP MESSAGE +# RECURSIVE SELF-REFERENTIAL MAP MESSAGES # mes A: mes A: # map<*, mes A> => map<*, (ES)> # ==================================================================== @pytest.mark.parametrize("config", CONFIGS) -def test_cyclical_map_message_handling(config: ProtarrowConfig): +def test_recursive_self_referential_map_message_handling(config: ProtarrowConfig): messages = _load_data( "RecursiveSelfReferentialMapMessage.jsonl", RecursiveSelfReferentialMapMessage ) @@ -273,10 +321,22 @@ def test_cyclical_map_message_handling(config: ProtarrowConfig): with pytest.raises(TypeError, match=regex_pattern): messages_to_table(messages, RecursiveSelfReferentialMapMessage, config) + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_schema(RecursiveSelfReferentialMapMessage, config) + + with pytest.raises(TypeError, match=regex_pattern): + message_type_to_struct_type(RecursiveSelfReferentialMapMessage, config) + else: rb = messages_to_record_batch( messages, RecursiveSelfReferentialMapMessage, config ) + inferred_schema = message_type_to_schema( + RecursiveSelfReferentialMapMessage, config + ) + inferred_type = message_type_to_struct_type( + RecursiveSelfReferentialMapMessage, config + ) # Check schema pruned_value_struct = pa.struct([]) @@ -294,9 +354,13 @@ def test_cyclical_map_message_handling(config: ProtarrowConfig): pa.field("children_map", children_map_type, nullable=False), ] ) + expected_type = pa.struct(expected_schema) + assert rb.schema == expected_schema, ( "Schema mismatch for map self-reference pruning." ) + assert inferred_schema == expected_schema + assert inferred_type == expected_type # Check values key_array = pa.array(["A", "D", "E"], pa.string()) From 068067c94a10a0fe57a60242f8700085b8fc5835 Mon Sep 17 00:00:00 2001 From: David Sierra-Gonzalez Date: Mon, 27 Oct 2025 12:10:03 +0000 Subject: [PATCH 4/5] Restores pyproject.toml and fixes typos. --- protarrow/proto_to_arrow.py | 4 ++-- pyproject.toml | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/protarrow/proto_to_arrow.py b/protarrow/proto_to_arrow.py index 3f4938a..e0c2362 100644 --- a/protarrow/proto_to_arrow.py +++ b/protarrow/proto_to_arrow.py @@ -202,7 +202,7 @@ def _raise_recursion_error(trace: Tuple[Descriptor, ...]): trace_names = (d.full_name for d in trace) raise TypeError( - "Cyclical structure detected in the protobuf message. " + "Recursive structure detected in the protobuf message. " f"Full trace: ({', '.join(trace_names)})." " Consider setting 'skip_recursive_messages=True'" "in ProtarrowConfig." @@ -393,7 +393,7 @@ def _proto_field_to_array( field_descriptor: FieldDescriptor, validity_mask: Optional[Sequence[bool]], config: ProtarrowConfig, - descriptor_trace: Optional[List[str]] = None, + descriptor_trace: Tuple[Descriptor, ...] = (), ) -> pa.Array: converter = _get_converter(field_descriptor, config) diff --git a/pyproject.toml b/pyproject.toml index 74bd200..1086293 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,8 @@ enable = true [tool.poetry-dynamic-versioning.substitution] files = ["*/__init__.py"] folders = [{path = "protarrow"}] + +[tool.ruff] line-length = 88 [tool.ruff.lint] From ddacc279a7dfb562ca59677da9c302b53b355da8 Mon Sep 17 00:00:00 2001 From: David Sierra-Gonzalez Date: Mon, 27 Oct 2025 13:54:41 +0000 Subject: [PATCH 5/5] Fixes flaky test caused by the unsorted nature of map/dict. --- tests/test_conversion_recursive_messages.py | 28 +++++++++++++++------ 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/tests/test_conversion_recursive_messages.py b/tests/test_conversion_recursive_messages.py index e5e8c1f..a7615da 100644 --- a/tests/test_conversion_recursive_messages.py +++ b/tests/test_conversion_recursive_messages.py @@ -362,7 +362,7 @@ def test_recursive_self_referential_map_message_handling(config: ProtarrowConfig assert inferred_schema == expected_schema assert inferred_type == expected_type - # Check values + # Check values: map key order is arbitrary, so don't compare at the table level key_array = pa.array(["A", "D", "E"], pa.string()) pruned_value_array = pa.StructArray.from_arrays( arrays=[], fields=[], mask=pa.array([False] * len(key_array)) @@ -370,16 +370,30 @@ def test_recursive_self_referential_map_message_handling(config: ProtarrowConfig expected_name_array = pa.array(["M1_L1", "M2_L1", "M3_L1"], pa.string()) # Offsets for 1 item (M1), 2 items (M2), 0 items (M3) list_offsets = pa.array([0, 1, 3, 3], pa.int32()) - expected_children_map_array = pa.MapArray.from_arrays( offsets=list_offsets, keys=key_array, items=pruned_value_array, type=children_map_type, ) - expected_table = pa.Table.from_arrays( - [expected_name_array, expected_children_map_array], schema=expected_schema - ) - actual_table = pa.Table.from_batches([rb]) - assert actual_table.equals(expected_table) + inferred_table = pa.Table.from_batches([rb]) + assert inferred_table.column("name").num_chunks == 1 + inferred_name_array = inferred_table.column("name").chunk(0) + assert inferred_name_array.equals(expected_name_array) + + expected_children_list = expected_children_map_array.to_pylist() + inferred_children_list = inferred_table.column("children_map").to_pylist() + + sorted_expected_children_list = [ + sorted(row, key=lambda x: x[0]) for row in expected_children_list + ] + sorted_inferred_children_list = [ + sorted(row, key=lambda x: x[0]) for row in inferred_children_list + ] + assert sorted_expected_children_list == sorted_inferred_children_list + + # Finally, assert all nested messages in the recursive map were skipped + for row in sorted_inferred_children_list: + for key, value in row: + assert not value