Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions protarrow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class ProtarrowConfig:
string_type: pa.DataType = pa.string()
binary_type: pa.DataType = pa.binary()
list_array_type: type = pa.ListArray
skip_recursive_messages: bool = False

def __post_init__(self):
assert self.enum_type in SUPPORTED_ENUM_TYPES
Expand Down
87 changes: 74 additions & 13 deletions protarrow/proto_to_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,17 @@ def __len__(self) -> int:
return sum(len(i) for i in self.scalar_map if i)


def _raise_recursion_error(trace: Tuple[Descriptor, ...]):
trace_names = (d.full_name for d in trace)

raise TypeError(
"Recursive structure detected in the protobuf message. "
f"Full trace: ({', '.join(trace_names)})."
" Consider setting 'skip_recursive_messages=True'"
"in ProtarrowConfig."
)


def is_map(field_descriptor: FieldDescriptor) -> bool:
return (
field_descriptor.type == FieldDescriptor.TYPE_MESSAGE
Expand Down Expand Up @@ -249,11 +260,14 @@ def converter(x: int) -> bytes:
def field_descriptor_to_field(
field_descriptor: FieldDescriptor,
config: ProtarrowConfig,
descriptor_trace: Tuple[Descriptor, ...] = (),
) -> pa.Field:
if is_map(field_descriptor):
key_field, value_field = get_map_descriptors(field_descriptor)
key_type = field_descriptor_to_data_type(key_field, config)
value_type = field_descriptor_to_data_type(value_field, config)
key_type = field_descriptor_to_data_type(key_field, config, descriptor_trace)
value_type = field_descriptor_to_data_type(
value_field, config, descriptor_trace
)
return pa.field(
field_descriptor.name,
pa.map_(
Expand All @@ -266,14 +280,18 @@ def field_descriptor_to_field(
elif field_descriptor.label == FieldDescriptor.LABEL_REPEATED:
return pa.field(
field_descriptor.name,
config.list_(field_descriptor_to_data_type(field_descriptor, config)),
config.list_(
field_descriptor_to_data_type(
field_descriptor, config, descriptor_trace
)
),
nullable=config.list_nullable,
metadata=config.field_metadata(field_descriptor.number),
)
else:
return pa.field(
field_descriptor.name,
field_descriptor_to_data_type(field_descriptor, config),
field_descriptor_to_data_type(field_descriptor, config, descriptor_trace),
nullable=field_descriptor.has_presence,
metadata=config.field_metadata(field_descriptor.number),
)
Expand All @@ -282,6 +300,7 @@ def field_descriptor_to_field(
def _message_field_to_data_type(
field_descriptor: FieldDescriptor,
config: ProtarrowConfig,
descriptor_trace: Tuple[Descriptor, ...] = (),
) -> pa.DataType:
try:
return _PROTO_DESCRIPTOR_TO_PYARROW[field_descriptor.message_type]
Expand All @@ -291,9 +310,19 @@ def _message_field_to_data_type(
elif field_descriptor.message_type == StringValue.DESCRIPTOR:
return config.string_type
else:
descriptor = field_descriptor.message_type

if descriptor in descriptor_trace:
if config.skip_recursive_messages:
return pa.struct([])
else:
_raise_recursion_error(descriptor_trace + (descriptor,))

return pa.struct(
[
field_descriptor_to_field(child_field, config)
field_descriptor_to_field(
child_field, config, descriptor_trace + (descriptor,)
)
for child_field in field_descriptor.message_type.fields
]
)
Expand All @@ -302,6 +331,7 @@ def _message_field_to_data_type(
def field_descriptor_to_data_type(
field_descriptor: FieldDescriptor,
config: ProtarrowConfig,
descriptor_trace: Tuple[Descriptor, ...] = (),
) -> pa.DataType:
if field_descriptor.message_type == Timestamp.DESCRIPTOR:
return config.timestamp_type
Expand All @@ -310,7 +340,7 @@ def field_descriptor_to_data_type(
elif field_descriptor.message_type == Duration.DESCRIPTOR:
return config.duration_type
elif field_descriptor.type == FieldDescriptorProto.TYPE_MESSAGE:
return _message_field_to_data_type(field_descriptor, config)
return _message_field_to_data_type(field_descriptor, config, descriptor_trace)
elif field_descriptor.type == FieldDescriptorProto.TYPE_ENUM:
return config.enum_type
elif field_descriptor.type == FieldDescriptorProto.TYPE_STRING:
Expand Down Expand Up @@ -363,6 +393,7 @@ def _proto_field_to_array(
field_descriptor: FieldDescriptor,
validity_mask: Optional[Sequence[bool]],
config: ProtarrowConfig,
descriptor_trace: Tuple[Descriptor, ...] = (),
) -> pa.Array:
converter = _get_converter(field_descriptor, config)

Expand Down Expand Up @@ -392,6 +423,7 @@ def _proto_field_to_array(
field_descriptor.message_type,
validity_mask=validity_mask,
config=config,
descriptor_trace=descriptor_trace,
)


Expand All @@ -414,14 +446,19 @@ def _repeated_proto_to_array(
repeated_values: Iterable[RepeatedScalarFieldContainer],
field_descriptor: FieldDescriptor,
config: ProtarrowConfig,
descriptor_trace: Tuple[Descriptor, ...] = (),
) -> pa.ListArray:
"""
Convert Protobuf embedded lists to a 1-dimensional PyArrow ListArray with offsets
See PyArrow Layout format documentation on how to calculate offsets.
"""
offsets = _get_offsets(repeated_values)
array = _proto_field_to_array(
FlattenedIterable(repeated_values), field_descriptor, None, config
FlattenedIterable(repeated_values),
field_descriptor,
None,
config,
descriptor_trace,
)
return config.list_array_type.from_arrays(
offsets,
Expand All @@ -434,6 +471,7 @@ def _proto_map_to_array(
maps: Iterable[MessageMap],
field_descriptor: FieldDescriptor,
config: ProtarrowConfig = ProtarrowConfig(),
descriptor_trace: Tuple[Descriptor, ...] = (),
) -> pa.MapArray:
"""
Convert Protobuf maps to a 1-dimensional PyArrow MapArray with offsets
Expand All @@ -453,6 +491,7 @@ def _proto_map_to_array(
value_field,
validity_mask=None,
config=config,
descriptor_trace=descriptor_trace,
)
return pa.MapArray.from_arrays(offsets, keys, values).cast(
pa.map_(
Expand Down Expand Up @@ -495,6 +534,7 @@ def _messages_to_array(
descriptor: Descriptor,
validity_mask: Optional[Sequence[bool]],
config: ProtarrowConfig,
descriptor_trace: Tuple[Descriptor, ...] = (),
) -> pa.StructArray:
arrays = []
fields = []
Expand All @@ -511,14 +551,30 @@ def _messages_to_array(
field_values = NestedIterable(
messages, operator.attrgetter(field_descriptor.name)
)

this_trace = descriptor_trace + (descriptor,)
if descriptor in descriptor_trace:
if config.skip_recursive_messages:
continue
else:
_raise_recursion_error(this_trace)

if is_map(field_descriptor):
array = _proto_map_to_array(field_values, field_descriptor, config)
array = _proto_map_to_array(
field_values, field_descriptor, config, this_trace
)
elif field_descriptor.label == FieldDescriptorProto.LABEL_REPEATED:
array = _repeated_proto_to_array(field_values, field_descriptor, config)
array = _repeated_proto_to_array(
field_values, field_descriptor, config, this_trace
)
else:
mask = _proto_field_validity_mask(messages, field_descriptor)
array = _proto_field_to_array(
field_values, field_descriptor, validity_mask=mask, config=config
field_values,
field_descriptor,
validity_mask=mask,
config=config,
descriptor_trace=this_trace,
)

arrays.append(array)
Expand Down Expand Up @@ -574,20 +630,25 @@ def message_type_to_schema(
message_type: Type[Message],
config: ProtarrowConfig = ProtarrowConfig(),
) -> pa.Schema:
descriptor_trace = (message_type.DESCRIPTOR,)

return pa.schema(
[
field_descriptor_to_field(field_descriptor, config)
field_descriptor_to_field(field_descriptor, config, descriptor_trace)
for field_descriptor in message_type.DESCRIPTOR.fields
]
)


def message_type_to_struct_type(
message_type: Type[Message], config: ProtarrowConfig = ProtarrowConfig()
message_type: Type[Message],
config: ProtarrowConfig = ProtarrowConfig(),
) -> pa.StructType:
descriptor_trace = (message_type.DESCRIPTOR,)

return pa.struct(
[
field_descriptor_to_field(field_descriptor, config)
field_descriptor_to_field(field_descriptor, config, descriptor_trace)
for field_descriptor in message_type.DESCRIPTOR.fields
]
)
33 changes: 33 additions & 0 deletions protos/bench.proto
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,36 @@ message SuperNestedExampleMessage {
map<string, ExampleMessage> nested_example_message_string_map = 5;

}

// Recursive message tests: self-reference
message RecursiveSelfReferentialMessage {
RecursiveSelfReferentialMessage next = 1;
int32 depth = 2;
}

// Recursive message tests: nested
message RecursiveNestedMessageLevel1 {
string name = 1;
RecursiveNestedMessageLevel2 next = 2;
}
message RecursiveNestedMessageLevel2 {
string name = 1;
RecursiveNestedMessageLevel3 next = 2;
}
message RecursiveNestedMessageLevel3 {
string name = 1;
RecursiveNestedMessageLevel1 next = 2;
}


// Recursive message tests: repeated self-reference
message RecursiveSelfReferentialRepeatedMessage {
int32 depth = 1;
repeated RecursiveSelfReferentialRepeatedMessage children = 2;
}

// Recursive message tests: map self-reference
message RecursiveSelfReferentialMapMessage {
string name = 1;
map<string, RecursiveSelfReferentialMapMessage> children_map = 2;
}
6 changes: 5 additions & 1 deletion scripts/generate_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ def can_be_optional(self) -> bool:

MAP_KEYS = ["int32", "string"]

# Number of RecursiveNestedMessageLevelN messages the template generates;
# the last level references Level1 again, which closes the cycle.
RECURSION_DEPTH = 3


def generate():
env = Environment(loader=FileSystemLoader(DIR.as_posix()), autoescape=True)
template = env.get_template("template.proto.in")
generated = template.render(types=TYPES, map_keys=MAP_KEYS)
generated = template.render(
types=TYPES, map_keys=MAP_KEYS, recursion_depth=RECURSION_DEPTH
)
with (DIR.parent / "protos" / "bench.proto").open("w") as fp:
fp.write(generated)

Expand Down
27 changes: 27 additions & 0 deletions scripts/template.proto.in
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,30 @@ message SuperNestedExampleMessage {
map<{{map_key}}, ExampleMessage> nested_example_message_{{map_key}}_map = {{ 3 + loop.index}};
{% endfor %}
}

// Recursive message tests: self-reference
message RecursiveSelfReferentialMessage {
RecursiveSelfReferentialMessage next = 1;
int32 depth = 2;
}

// Recursive message tests: nested
{% for i in range(1, recursion_depth + 1) -%}
message RecursiveNestedMessageLevel{{ i }} {
{% set next_level = i + 1 if i < recursion_depth else 1 -%}
string name = 1;
RecursiveNestedMessageLevel{{ next_level }} next = 2;
}
{% endfor %}

// Recursive message tests: repeated self-reference
message RecursiveSelfReferentialRepeatedMessage {
int32 depth = 1;
repeated RecursiveSelfReferentialRepeatedMessage children = 2;
}

// Recursive message tests: map self-reference
message RecursiveSelfReferentialMapMessage {
string name = 1;
map<string, RecursiveSelfReferentialMapMessage> children_map = 2;
}
3 changes: 3 additions & 0 deletions tests/data/RecursiveNestedMessageLevel1.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"name": "M1_L1", "next": {"name": "M1_L2", "next": {"name": "M1_L3", "next": {"name": "M1_L1_CYCLE"}}}}
{"name": "M2_L1", "next": {"name": "M2_L2", "next": {"name": "M2_L3", "next": {"name": "M2_L1_CYCLE"}}}}
{"name": "M3_L1", "next": {"name": "M3_L2", "next": {"name": "M3_L3", "next": {"name": "M3_L1_CYCLE"}}}}
3 changes: 3 additions & 0 deletions tests/data/RecursiveSelfReferentialMapMessage.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"name": "M1_L1", "children_map": {"A": {"name": "M1_L2", "children_map": {"B": {"name": "M1_L3", "children_map": {"C": {"name": "M1_L4"}}}}}}}
{"name": "M2_L1", "children_map": {"D": {"name": "M2_L2", "children_map": {}}, "E": {"name": "M2_L2b", "children_map": {}}}}
{"name": "M3_L1", "children_map": {}}
3 changes: 3 additions & 0 deletions tests/data/RecursiveSelfReferentialMessage.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"depth": 1, "next": {"depth": 2, "next": {"depth": 3, "next": {}}}}
{"depth": 11, "next": {"depth": 12, "next": {"depth": 13, "next": {}}}}
{"depth": 21, "next": {"depth": 22, "next": {"depth": 23, "next": {}}}}
3 changes: 3 additions & 0 deletions tests/data/RecursiveSelfReferentialRepeatedMessage.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{"depth": 1, "children": [{"depth": 2, "children": [{"depth": 3, "children": []}]}]}
{"depth": 11, "children": [{"depth": 12, "children": [{"depth": 13, "children": []}]}, {"depth": 14, "children": []}]}
{"depth": 21, "children": []}
Loading