diff --git a/dev/provision.py b/dev/provision.py
index b75030f8a3..ca6e5aa6aa 100644
--- a/dev/provision.py
+++ b/dev/provision.py
@@ -301,3 +301,23 @@
 );
 """
 )
+
+spark.sql(
+    """
+CREATE TABLE default.test_table_sanitized_character (
+    `letter/abc` string
+)
+USING iceberg
+TBLPROPERTIES (
+    'format-version'='1'
+);
+"""
+)
+
+spark.sql(
+    """
+INSERT INTO default.test_table_sanitized_character
+VALUES
+    ('123')
+"""
+)
diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 29d4a4b170..a3d1869cf8 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -113,6 +113,7 @@
     pre_order_visit,
     promote,
     prune_columns,
+    sanitize_column_names,
     visit,
     visit_with_partner,
 )
@@ -830,7 +831,7 @@ def _task_to_table(
             bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
             pyarrow_filter = expression_to_pyarrow(bound_file_filter)
 
-        file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
+        file_project_schema = sanitize_column_names(prune_columns(file_schema, projected_field_ids, select_full_types=False))
 
         if file_schema is None:
             raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py
index 28101809c7..7fdd717858 100644
--- a/pyiceberg/schema.py
+++ b/pyiceberg/schema.py
@@ -1273,6 +1273,116 @@ def primitive(self, primitive: PrimitiveType) -> PrimitiveType:
         return primitive
 
 
+# Implementation copied from Apache Iceberg repo.
+def make_compatible_name(name: str) -> str:
+    """Return the name as-is when it is a valid Avro name, otherwise return its sanitized form.
+
+    Args:
+        name: The column name to check.
+
+    Returns:
+        The unchanged name, or the name with every invalid character escaped.
+
+    Raises:
+        ValueError: If the name is empty.
+    """
+    if not _valid_avro_name(name):
+        return _sanitize_name(name)
+    return name
+
+
+def _valid_avro_name(name: str) -> bool:
+    """Check that the name starts with a letter or '_' and continues with letters, digits or '_'."""
+    # Raise instead of asserting: asserts are stripped under `python -O`, turning an empty name into an IndexError.
+    if not name:
+        raise ValueError("Can not validate empty avro name")
+    first = name[0]
+    if not (first.isalpha() or first == '_'):
+        return False
+
+    return all(character.isalnum() or character == '_' for character in name[1:])
+
+
+def _sanitize_name(name: str) -> str:
+    """Escape each character of a non-empty name that is not allowed at its position."""
+    sb = []
+    first = name[0]
+    if not (first.isalpha() or first == '_'):
+        sb.append(_sanitize_char(first))
+    else:
+        sb.append(first)
+
+    for character in name[1:]:
+        if not (character.isalnum() or character == '_'):
+            sb.append(_sanitize_char(character))
+        else:
+            sb.append(character)
+    return ''.join(sb)
+
+
+def _sanitize_char(character: str) -> str:
+    """Escape one character: a digit gets a '_' prefix, anything else becomes '_x' plus its hex code point."""
+    if character.isdigit():
+        return "_" + character
+    return "_x" + hex(ord(character))[2:].upper()
+
+
+def sanitize_column_names(schema: Schema) -> Schema:
+    """Sanitize column names to make them compatible with Avro.
+
+    A column name should start with '_' or a letter, followed only by '_', digits or letters;
+    otherwise it is sanitized to conform to the Avro naming convention.
+
+    Args:
+        schema: The schema to be sanitized.
+
+    Returns:
+        The sanitized schema.
+    """
+    result = visit(schema.as_struct(), _SanitizeColumnsVisitor())
+    return Schema(
+        *(result or StructType()).fields,
+        schema_id=schema.schema_id,
+        identifier_field_ids=schema.identifier_field_ids,
+    )
+
+
+class _SanitizeColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
+    """Rebuild a schema, passing every field name through make_compatible_name."""
+
+    def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
+        return struct_result
+
+    def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
+        return NestedField(
+            field_id=field.field_id,
+            name=make_compatible_name(field.name),
+            field_type=field_result,
+            doc=field.doc,
+            required=field.required,
+        )
+
+    def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
+        return StructType(*[field for field in field_results if field is not None])
+
+    def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
+        return ListType(element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required)
+
+    def map(
+        self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
+    ) -> Optional[IcebergType]:
+        return MapType(
+            key_id=map_type.key_id,
+            value_id=map_type.value_id,
+            key_type=key_result,
+            value_type=value_result,
+            value_required=map_type.value_required,
+        )
+
+    def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
+        return primitive
+
+
 def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
     """Prunes a column by only selecting a set of field-ids.
 
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index 8b62212593..7174c91425 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -66,7 +66,7 @@
 )
 from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
 from pyiceberg.partitioning import PartitionSpec
-from pyiceberg.schema import Schema, visit
+from pyiceberg.schema import Schema, make_compatible_name, visit
 from pyiceberg.table import FileScanTask, Table
 from pyiceberg.table.metadata import TableMetadataV2
 from pyiceberg.types import (
@@ -1542,3 +1542,11 @@ def check_results(location: str, expected_schema: str, expected_netloc: str, exp
 
     check_results("/root/foo.txt", "file", "", "/root/foo.txt")
     check_results("/root/tmp/foo.txt", "file", "", "/root/tmp/foo.txt")
+
+
+def test_make_compatible_name() -> None:
+    """Invalid characters are escaped; names that are already valid are returned unchanged."""
+    assert make_compatible_name("label/abc") == "label_x2Fabc"
+    assert make_compatible_name("label?abc") == "label_x3Fabc"
+    assert make_compatible_name("9abc") == "_9abc"
+    assert make_compatible_name("valid_name1") == "valid_name1"
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 6e874b68fa..4683aa8853 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -88,6 +88,11 @@ def table_test_table_version(catalog: Catalog) -> Table:
     return catalog.load_table("default.test_table_version")
 
 
+@pytest.fixture()
+def table_test_table_sanitized_character(catalog: Catalog) -> Table:
+    return catalog.load_table("default.test_table_sanitized_character")
+
+
 TABLE_NAME = ("default", "t1")
 
 
@@ -396,3 +401,12 @@ def test_upgrade_table_version(table_test_table_version: Table) -> None:
         with table_test_table_version.transaction() as transaction:
             transaction.upgrade_table_version(format_version=3)
     assert "Unsupported table format version: 3" in str(e.value)
+
+
+@pytest.mark.integration
+def test_reproduce_issue(table_test_table_sanitized_character: Table) -> None:
+    """The scanned Arrow table exposes the single column under the name recorded in the table schema."""
+    arrow_table = table_test_table_sanitized_character.scan().to_arrow()
+    assert len(arrow_table.schema.names) == 1
+    assert len(table_test_table_sanitized_character.schema().fields) == 1
+    assert arrow_table.schema.names[0] == table_test_table_sanitized_character.schema().fields[0].name
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 610298b84a..c12910a4b4 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -28,6 +28,7 @@
     build_position_accessors,
     promote,
     prune_columns,
+    sanitize_column_names,
 )
 from pyiceberg.typedef import EMPTY_DICT, StructProtocol
 from pyiceberg.types import (
@@ -431,6 +432,133 @@ def test_deserialize_schema(table_schema_simple: Schema) -> None:
     assert actual == expected
 
 
+def test_sanitize() -> None:
+    """Every '/' in a field name, at any nesting depth, is escaped as '_x2F'; ids and types are kept."""
+    before_sanitized = Schema(
+        NestedField(field_id=1, name="foo_field/bar", field_type=StringType(), required=True),
+        NestedField(
+            field_id=2,
+            name="foo_list/bar",
+            field_type=ListType(element_id=3, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        NestedField(
+            field_id=4,
+            name="foo_map/bar",
+            field_type=MapType(
+                key_id=5,
+                key_type=StringType(),
+                value_id=6,
+                value_type=MapType(key_id=7, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=8,
+            name="foo_struct/bar",
+            field_type=StructType(
+                NestedField(field_id=9, name="foo_struct_1/bar", field_type=StringType(), required=False),
+                NestedField(field_id=10, name="foo_struct_2/bar", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+        NestedField(
+            field_id=11,
+            name="foo_list_2/bar",
+            field_type=ListType(
+                element_id=12,
+                element_type=StructType(
+                    NestedField(field_id=13, name="foo_list_2_1/bar", field_type=LongType(), required=True),
+                    NestedField(field_id=14, name="foo_list_2_2/bar", field_type=LongType(), required=True),
+                ),
+                element_required=False,
+            ),
+            required=False,
+        ),
+        NestedField(
+            field_id=15,
+            name="foo_map_2/bar",
+            field_type=MapType(
+                key_id=16,
+                value_id=17,
+                key_type=StructType(
+                    NestedField(field_id=18, name="foo_map_2_1/bar", field_type=StringType(), required=True),
+                ),
+                value_type=StructType(
+                    NestedField(field_id=19, name="foo_map_2_2/bar", field_type=FloatType(), required=True),
+                ),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+    expected_schema = Schema(
+        NestedField(field_id=1, name="foo_field_x2Fbar", field_type=StringType(), required=True),
+        NestedField(
+            field_id=2,
+            name="foo_list_x2Fbar",
+            field_type=ListType(element_id=3, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        NestedField(
+            field_id=4,
+            name="foo_map_x2Fbar",
+            field_type=MapType(
+                key_id=5,
+                key_type=StringType(),
+                value_id=6,
+                value_type=MapType(key_id=7, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=8,
+            name="foo_struct_x2Fbar",
+            field_type=StructType(
+                NestedField(field_id=9, name="foo_struct_1_x2Fbar", field_type=StringType(), required=False),
+                NestedField(field_id=10, name="foo_struct_2_x2Fbar", field_type=IntegerType(), required=True),
+            ),
+            required=False,
+        ),
+        NestedField(
+            field_id=11,
+            name="foo_list_2_x2Fbar",
+            field_type=ListType(
+                element_id=12,
+                element_type=StructType(
+                    NestedField(field_id=13, name="foo_list_2_1_x2Fbar", field_type=LongType(), required=True),
+                    NestedField(field_id=14, name="foo_list_2_2_x2Fbar", field_type=LongType(), required=True),
+                ),
+                element_required=False,
+            ),
+            required=False,
+        ),
+        NestedField(
+            field_id=15,
+            name="foo_map_2_x2Fbar",
+            field_type=MapType(
+                key_id=16,
+                value_id=17,
+                key_type=StructType(
+                    NestedField(field_id=18, name="foo_map_2_1_x2Fbar", field_type=StringType(), required=True),
+                ),
+                value_type=StructType(
+                    NestedField(field_id=19, name="foo_map_2_2_x2Fbar", field_type=FloatType(), required=True),
+                ),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        schema_id=1,
+        identifier_field_ids=[1],
+    )
+    assert sanitize_column_names(before_sanitized) == expected_schema
+
+
 def test_prune_columns_string(table_schema_nested_with_struct_key_map: Schema) -> None:
     assert prune_columns(table_schema_nested_with_struct_key_map, {1}, False) == Schema(
         NestedField(field_id=1, name="foo", field_type=StringType(), required=True), schema_id=1, identifier_field_ids=[1]