From 68081491641b0d7bada13a18b98ded3e08e127a6 Mon Sep 17 00:00:00 2001 From: Pucheng Yang Date: Mon, 16 Oct 2023 17:41:02 -0700 Subject: [PATCH 1/5] issue --- dev/provision.py | 20 ++++++++++++++++++++ tests/test_integration.py | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/dev/provision.py b/dev/provision.py index b75030f8a3..ca6e5aa6aa 100644 --- a/dev/provision.py +++ b/dev/provision.py @@ -301,3 +301,23 @@ ); """ ) + +spark.sql( + """ +CREATE TABLE default.test_table_sanitized_character ( + `letter/abc` string +) +USING iceberg +TBLPROPERTIES ( + 'format-version'='1' +); +""" +) + +spark.sql( + f""" +INSERT INTO default.test_table_sanitized_character +VALUES + ('123') +""" +) diff --git a/tests/test_integration.py b/tests/test_integration.py index 6e874b68fa..f839878468 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -396,3 +396,13 @@ def test_upgrade_table_version(table_test_table_version: Table) -> None: with table_test_table_version.transaction() as transaction: transaction.upgrade_table_version(format_version=3) assert "Unsupported table format version: 3" in str(e.value) + + +@pytest.fixture() +def table_test_table_sanitized_character(catalog: Catalog) -> Table: + return catalog.load_table("default.test_table_sanitized_character") + + +@pytest.mark.integration +def test_reproduce_issue(table_test_table_sanitized_character: Table) -> None: + table = table_test_table_sanitized_character.scan().to_arrow() \ No newline at end of file From 57891c783416f68d5af1f72ede346f61cbe69261 Mon Sep 17 00:00:00 2001 From: Pucheng Yang Date: Tue, 17 Oct 2023 13:20:11 -0700 Subject: [PATCH 2/5] update --- pyiceberg/io/pyarrow.py | 3 +- pyiceberg/schema.py | 98 +++++++++++++++++++++++++++++ tests/io/test_pyarrow.py | 7 ++- tests/test_integration.py | 15 +++-- tests/test_schema.py | 129 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 244 insertions(+), 8 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 29d4a4b170..00ace88e82 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -113,6 +113,7 @@ pre_order_visit, promote, prune_columns, + sanitize_columns, visit, visit_with_partner, ) @@ -830,7 +831,7 @@ def _task_to_table( bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) pyarrow_filter = expression_to_pyarrow(bound_file_filter) - file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) + file_project_schema = sanitize_columns(prune_columns(file_schema, projected_field_ids, select_full_types=False)) if file_schema is None: raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}") diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index 28101809c7..bcecf970a6 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1273,6 +1273,104 @@ def primitive(self, primitive: PrimitiveType) -> PrimitiveType: return primitive +def make_compatible_name(name: str) -> str: + if not _valid_avro_name(name): + return _sanitize_name(name) + return name + + +def _valid_avro_name(name: str) -> bool: + length = len(name) + assert length > 0, "Empty name" + first = name[0] + if not (first.isalpha() or first == '_'): + return False + + for i in range(1, length): + character = name[i] + if not (character.isalnum() or character == '_'): + return False + return True + + +def _sanitize_name(name: str) -> str: + length = len(name) + sb = [] + first = name[0] + if not (first.isalpha() or first == '_'): + sb.append(_sanitize_char(first)) + else: + sb.append(first) + + for i in range(1, length): + character = name[i] + if not (character.isalnum() or character == '_'): + sb.append(_sanitize_char(character)) + else: + sb.append(character) + return ''.join(sb) + + +def _sanitize_char(character: str) -> str: + if character.isdigit(): + return "_" + character + return "_x" + hex(ord(character))[2:].upper() + + +class _SanitizeColumnsVisitor(SchemaVisitor[Optional[IcebergType]]): + def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]: + return struct_result + + def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]: + return NestedField( + field_id=field.field_id, + name=make_compatible_name(field.name), + field_type=field_result, + doc=field.doc, + required=field.required, + ) + + def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]: + # field_results = [field for field in field_results if field.get() is not None] + return StructType(*field_results) + + def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: + return ListType(element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required) + + def map( + self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType] + ) -> Optional[IcebergType]: + return MapType( + key_id=map_type.key_id, + value_id=map_type.value_id, + key_type=key_result, + value_type=value_result, + value_required=map_type.value_required, + ) + + def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: + return primitive + + +def sanitize_columns(schema: Schema) -> Schema: + """Prunes a column by only selecting a set of field-ids. + + Args: + schema: The schema to be pruned. + selected: The field-ids to be included. + select_full_types: Return the full struct when a subset is recorded + + Returns: + The pruned schema. + """ + result = visit(schema.as_struct(), _SanitizeColumnsVisitor()) + return Schema( + *(result or StructType()).fields, + schema_id=schema.schema_id, + identifier_field_ids=schema.identifier_field_ids, + ) + + def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema: """Prunes a column by only selecting a set of field-ids. diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 8b62212593..7174c91425 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -66,7 +66,7 @@ ) from pyiceberg.manifest import DataFile, DataFileContent, FileFormat from pyiceberg.partitioning import PartitionSpec -from pyiceberg.schema import Schema, visit +from pyiceberg.schema import Schema, make_compatible_name, visit from pyiceberg.table import FileScanTask, Table from pyiceberg.table.metadata import TableMetadataV2 from pyiceberg.types import ( @@ -1542,3 +1542,8 @@ def check_results(location: str, expected_schema: str, expected_netloc: str, exp check_results("/root/foo.txt", "file", "", "/root/foo.txt") check_results("/root/tmp/foo.txt", "file", "", "/root/tmp/foo.txt") + + +def test_make_compatible_name() -> None: + assert make_compatible_name("label/abc") == "label_x2Fabc" + assert make_compatible_name("label?abc") == "label_x3Fabc" diff --git a/tests/test_integration.py b/tests/test_integration.py index f839878468..59db10de58 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -87,6 +87,11 @@ def table_test_all_types(catalog: Catalog) -> Table: def table_test_table_version(catalog: Catalog) -> Table: return catalog.load_table("default.test_table_version") +@pytest.fixture() +def table_test_table_sanitized_character(catalog: Catalog) -> Table: + return catalog.load_table("default.test_table_sanitized_character") + + TABLE_NAME = ("default", "t1") @@ -398,11 +403,9 @@ def test_upgrade_table_version(table_test_table_version: Table) -> None: assert "Unsupported table format version: 3" in str(e.value) -@pytest.fixture() -def table_test_table_sanitized_character(catalog: Catalog) -> Table: - return catalog.load_table("default.test_table_sanitized_character") - - @pytest.mark.integration def test_reproduce_issue(table_test_table_sanitized_character: Table) -> None: - table = table_test_table_sanitized_character.scan().to_arrow() \ No newline at end of file + arrow_table = table_test_table_sanitized_character.scan().to_arrow() + assert len(arrow_table.schema.names), 1 + assert len(table_test_table_sanitized_character.schema().fields), 1 + assert arrow_table.schema.names[0] == table_test_table_sanitized_character.schema().fields[0].name diff --git a/tests/test_schema.py b/tests/test_schema.py index 610298b84a..2f20ba0910 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -28,6 +28,7 @@ build_position_accessors, promote, prune_columns, + sanitize_columns, ) from pyiceberg.typedef import EMPTY_DICT, StructProtocol from pyiceberg.types import ( @@ -431,6 +432,134 @@ def test_deserialize_schema(table_schema_simple: Schema) -> None: assert actual == expected +def test_sanitize() -> None: + before_sanitized = Schema( + NestedField(field_id=1, name="foo_field/bar", field_type=StringType(), required=True), + NestedField( + field_id=2, + name="foo_list/bar", + field_type=ListType(element_id=3, element_type=StringType(), element_required=True), + required=True, + ), + NestedField( + field_id=4, + name="foo_map/bar", + field_type=MapType( + key_id=5, + key_type=StringType(), + value_id=6, + value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), + value_required=True), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=8, + name="foo_struct/bar", + field_type=StructType( + NestedField(field_id=9, name="foo_struct_1/bar", field_type=StringType(), required=False), + NestedField(field_id=10, name="foo_struct_2/bar", field_type=IntegerType(), required=True), + ), + required=False, + ), + NestedField( + field_id=11, + name="foo_list_2/bar", + field_type=ListType( + element_id=12, + element_type=StructType( + NestedField(field_id=13, name="foo_list_2_1/bar", field_type=LongType(), required=True), + NestedField(field_id=14, name="foo_list_2_2/bar", field_type=LongType(), required=True), + ), + element_required=False, + ), + required=False, + ), + NestedField( + field_id=15, + name="foo_map_2/bar", + field_type=MapType( + key_id=16, + value_id=17, + key_type=StructType( + NestedField(field_id=18, name="foo_map_2_1/bar", field_type=StringType(), required=True), + ), + value_type=StructType( + NestedField(field_id=19, name="foo_map_2_2/bar", field_type=FloatType(), required=True), + ), + value_required=True, + ), + required=True, + ), + schema_id=1, + identifier_field_ids=[1], + ) + expected_schema = Schema( + NestedField(field_id=1, name="foo_field_x2Fbar", field_type=StringType(), required=True), + NestedField( + field_id=2, + name="foo_list_x2Fbar", + field_type=ListType(element_id=3, element_type=StringType(), element_required=True), + required=True, + ), + NestedField( + field_id=4, + name="foo_map_x2Fbar", + field_type=MapType( + key_id=5, + key_type=StringType(), + value_id=6, + value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), + value_required=True), + value_required=True, + ), + required=True, + ), + NestedField( + field_id=8, + name="foo_struct_x2Fbar", + field_type=StructType( + NestedField(field_id=9, name="foo_struct_1_x2Fbar", field_type=StringType(), required=False), + NestedField(field_id=10, name="foo_struct_2_x2Fbar", field_type=IntegerType(), required=True), + ), + required=False, + ), + NestedField( + field_id=11, + name="foo_list_2_x2Fbar", + field_type=ListType( + element_id=12, + element_type=StructType( + NestedField(field_id=13, name="foo_list_2_1_x2Fbar", field_type=LongType(), required=True), + NestedField(field_id=14, name="foo_list_2_2_x2Fbar", field_type=LongType(), required=True), + ), + element_required=False, + ), + required=False, + ), + NestedField( + field_id=15, + name="foo_map_2_x2Fbar", + field_type=MapType( + key_id=16, + value_id=17, + key_type=StructType( + NestedField(field_id=18, name="foo_map_2_1_x2Fbar", field_type=StringType(), required=True), + ), + value_type=StructType( + NestedField(field_id=19, name="foo_map_2_2_x2Fbar", field_type=FloatType(), required=True), + ), + value_required=True, + ), + required=True, + ), + schema_id=1, + identifier_field_ids=[1], + ) + assert sanitize_columns(before_sanitized) == expected_schema + + def test_prune_columns_string(table_schema_nested_with_struct_key_map: Schema) -> None: assert prune_columns(table_schema_nested_with_struct_key_map, {1}, False) == Schema( NestedField(field_id=1, name="foo", field_type=StringType(), required=True), schema_id=1, identifier_field_ids=[1] From aa03dd410800d7cbbed6bd90f7b7187806725001 Mon Sep 17 00:00:00 2001 From: Pucheng Yang Date: Tue, 17 Oct 2023 13:33:36 -0700 Subject: [PATCH 3/5] update --- pyiceberg/io/pyarrow.py | 4 ++-- pyiceberg/schema.py | 40 +++++++++++++++++++-------------------- tests/test_integration.py | 2 +- tests/test_schema.py | 10 ++++------ 4 files changed, 26 insertions(+), 30 deletions(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 00ace88e82..a3d1869cf8 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -113,7 +113,7 @@ pre_order_visit, promote, prune_columns, - sanitize_columns, + sanitize_column_names, visit, visit_with_partner, ) @@ -831,7 +831,7 @@ def _task_to_table( bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) pyarrow_filter = expression_to_pyarrow(bound_file_filter) - file_project_schema = sanitize_columns(prune_columns(file_schema, projected_field_ids, select_full_types=False)) + file_project_schema = sanitize_column_names(prune_columns(file_schema, projected_field_ids, select_full_types=False)) if file_schema is None: raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}") diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index bcecf970a6..a6e513f969 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1273,6 +1273,7 @@ def primitive(self, primitive: PrimitiveType) -> PrimitiveType: return primitive +# Implementation copied from Apache Iceberg repo. def make_compatible_name(name: str) -> str: if not _valid_avro_name(name): return _sanitize_name(name) @@ -1317,6 +1318,23 @@ def _sanitize_char(character: str) -> str: return "_x" + hex(ord(character))[2:].upper() +def sanitize_column_names(schema: Schema) -> Schema: + """Sanitize column names to make them compatible with Avro. + + Args: + schema: The schema to be sanitized. + + Returns: + The sanitized schema. + """ + result = visit(schema.as_struct(), _SanitizeColumnsVisitor()) + return Schema( + *(result or StructType()).fields, + schema_id=schema.schema_id, + identifier_field_ids=schema.identifier_field_ids, + ) + + class _SanitizeColumnsVisitor(SchemaVisitor[Optional[IcebergType]]): def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]: return struct_result @@ -1331,8 +1349,7 @@ def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Opti ) def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]: - # field_results = [field for field in field_results if field.get() is not None] - return StructType(*field_results) + return StructType(*[field.get() for field in field_results if field is not None]) def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: return ListType(element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required) @@ -1352,25 +1369,6 @@ def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]: return primitive -def sanitize_columns(schema: Schema) -> Schema: - """Prunes a column by only selecting a set of field-ids. - - Args: - schema: The schema to be pruned. - selected: The field-ids to be included. - select_full_types: Return the full struct when a subset is recorded - - Returns: - The pruned schema. - """ - result = visit(schema.as_struct(), _SanitizeColumnsVisitor()) - return Schema( - *(result or StructType()).fields, - schema_id=schema.schema_id, - identifier_field_ids=schema.identifier_field_ids, - ) - - def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema: """Prunes a column by only selecting a set of field-ids. diff --git a/tests/test_integration.py b/tests/test_integration.py index 59db10de58..4683aa8853 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -87,12 +87,12 @@ def table_test_all_types(catalog: Catalog) -> Table: def table_test_table_version(catalog: Catalog) -> Table: return catalog.load_table("default.test_table_version") + @pytest.fixture() def table_test_table_sanitized_character(catalog: Catalog) -> Table: return catalog.load_table("default.test_table_sanitized_character") - TABLE_NAME = ("default", "t1") diff --git a/tests/test_schema.py b/tests/test_schema.py index 2f20ba0910..c12910a4b4 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -28,7 +28,7 @@ build_position_accessors, promote, prune_columns, - sanitize_columns, + sanitize_column_names, ) from pyiceberg.typedef import EMPTY_DICT, StructProtocol from pyiceberg.types import ( @@ -448,8 +448,7 @@ def test_sanitize() -> None: key_id=5, key_type=StringType(), value_id=6, - value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), - value_required=True), + value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True), value_required=True, ), required=True, @@ -510,8 +509,7 @@ def test_sanitize() -> None: key_id=5, key_type=StringType(), value_id=6, - value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), - value_required=True), + value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True), value_required=True, ), required=True, @@ -557,7 +555,7 @@ def test_sanitize() -> None: schema_id=1, identifier_field_ids=[1], ) - assert sanitize_columns(before_sanitized) == expected_schema + assert sanitize_column_names(before_sanitized) == expected_schema def test_prune_columns_string(table_schema_nested_with_struct_key_map: Schema) -> None: From d17e823bf7981eee9c8cceff31a4379e101203b6 Mon Sep 17 00:00:00 2001 From: Pucheng Yang Date: Tue, 17 Oct 2023 14:02:43 -0700 Subject: [PATCH 4/5] ut fix --- pyiceberg/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index a6e513f969..1528826b19 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1349,7 +1349,7 @@ def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Opti ) def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]: - return StructType(*[field.get() for field in field_results if field is not None]) + return StructType(*[field for field in field_results if field is not None]) def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]: return ListType(element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required) From 012533c4a2e78cbe659e29824ab33e39db95e1e8 Mon Sep 17 00:00:00 2001 From: Pucheng Yang Date: Thu, 19 Oct 2023 10:24:14 -0700 Subject: [PATCH 5/5] address comments --- pyiceberg/schema.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index 1528826b19..7fdd717858 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1282,20 +1282,18 @@ def make_compatible_name(name: str) -> str: def _valid_avro_name(name: str) -> bool: length = len(name) - assert length > 0, "Empty name" + assert length > 0, ValueError("Can not validate empty avro name") first = name[0] if not (first.isalpha() or first == '_'): return False - for i in range(1, length): - character = name[i] + for character in name[1:]: if not (character.isalnum() or character == '_'): return False return True def _sanitize_name(name: str) -> str: - length = len(name) sb = [] first = name[0] if not (first.isalpha() or first == '_'): @@ -1303,8 +1301,7 @@ def _sanitize_name(name: str) -> str: else: sb.append(first) - for i in range(1, length): - character = name[i] + for character in name[1:]: if not (character.isalnum() or character == '_'): sb.append(_sanitize_char(character)) else: @@ -1313,14 +1310,15 @@ def _sanitize_name(name: str) -> str: def _sanitize_char(character: str) -> str: - if character.isdigit(): - return "_" + character - return "_x" + hex(ord(character))[2:].upper() + return "_" + character if character.isdigit() else "_x" + hex(ord(character))[2:].upper() def sanitize_column_names(schema: Schema) -> Schema: """Sanitize column names to make them compatible with Avro. + The column name should be starting with '_' or digit followed by a string only contains '_', digit or alphabet, + otherwise it will be sanitized to conform the avro naming convention. + Args: schema: The schema to be sanitized.