20 changes: 20 additions & 0 deletions dev/provision.py
@@ -301,3 +301,23 @@
);
"""
)

spark.sql(
"""
CREATE TABLE default.test_table_sanitized_character (
`letter/abc` string
)
USING iceberg
TBLPROPERTIES (
'format-version'='1'
);
"""
)

spark.sql(
f"""
INSERT INTO default.test_table_sanitized_character
VALUES
('123')
"""
)
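The column provisioned here, 'letter/abc', deliberately contains a character ('/') that is not allowed by the Avro naming rules this change codifies. As a quick illustration, a minimal sketch of that rule (mirroring the _valid_avro_name check added to pyiceberg/schema.py below; this snippet is illustrative and not part of dev/provision.py):

name = "letter/abc"
# The first character must be a letter or '_', every later character must be alphanumeric or '_'.
is_valid = (name[0].isalpha() or name[0] == "_") and all(c.isalnum() or c == "_" for c in name[1:])
print(is_valid)  # False -> a name like this gets sanitized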
3 changes: 2 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -113,6 +113,7 @@
pre_order_visit,
promote,
prune_columns,
sanitize_column_names,
visit,
visit_with_partner,
)
@@ -830,7 +831,7 @@ def _task_to_table(
bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
pyarrow_filter = expression_to_pyarrow(bound_file_filter)

file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
file_project_schema = sanitize_column_names(prune_columns(file_schema, projected_field_ids, select_full_types=False))

if file_schema is None:
raise ValueError(f"Missing Iceberg schema in Metadata for file: {path}")
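The projected file schema built here drives the PyArrow read, so its field names have to line up with the column names actually stored in the data files; for names with special characters, those stored names are the Avro-sanitized forms (the situation this fix targets). A minimal sketch of what the sanitized projection looks like, assuming the single-column schema provisioned above:

from pyiceberg.schema import Schema, prune_columns, sanitize_column_names
from pyiceberg.types import NestedField, StringType

# Iceberg schema as recorded in the file metadata, with the original column name.
file_schema = Schema(
    NestedField(field_id=1, name="letter/abc", field_type=StringType(), required=True),
    schema_id=1,
    identifier_field_ids=[1],
)

# Prune to the selected field ids, then sanitize the names for the PyArrow projection.
projected = sanitize_column_names(prune_columns(file_schema, {1}, select_full_types=False))
print(projected.fields[0].name)  # 'letter_x2Fabc'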
94 changes: 94 additions & 0 deletions pyiceberg/schema.py
@@ -1273,6 +1273,100 @@ def primitive(self, primitive: PrimitiveType) -> PrimitiveType:
return primitive


# Implementation copied from Apache Iceberg repo.
def make_compatible_name(name: str) -> str:
if not _valid_avro_name(name):
return _sanitize_name(name)
return name


def _valid_avro_name(name: str) -> bool:
length = len(name)
assert length > 0, "Cannot validate an empty Avro name"
first = name[0]
if not (first.isalpha() or first == '_'):
return False

for character in name[1:]:
if not (character.isalnum() or character == '_'):
return False
return True


def _sanitize_name(name: str) -> str:
sb = []
first = name[0]
if not (first.isalpha() or first == '_'):
sb.append(_sanitize_char(first))
else:
sb.append(first)

for character in name[1:]:
if not (character.isalnum() or character == '_'):
sb.append(_sanitize_char(character))
else:
sb.append(character)
return ''.join(sb)


def _sanitize_char(character: str) -> str:
return "_" + character if character.isdigit() else "_x" + hex(ord(character))[2:].upper()


def sanitize_column_names(schema: Schema) -> Schema:
"""Sanitize column names to make them compatible with Avro.

To be Avro-compatible, a column name must start with '_' or a letter and contain only '_', letters, and digits;
otherwise it is sanitized to conform to the Avro naming convention.

Args:
schema: The schema to be sanitized.

Returns:
The sanitized schema.
"""
result = visit(schema.as_struct(), _SanitizeColumnsVisitor())
return Schema(
*(result or StructType()).fields,
schema_id=schema.schema_id,
identifier_field_ids=schema.identifier_field_ids,
)


class _SanitizeColumnsVisitor(SchemaVisitor[Optional[IcebergType]]):
def schema(self, schema: Schema, struct_result: Optional[IcebergType]) -> Optional[IcebergType]:
return struct_result

def field(self, field: NestedField, field_result: Optional[IcebergType]) -> Optional[IcebergType]:
return NestedField(
field_id=field.field_id,
name=make_compatible_name(field.name),
field_type=field_result,
doc=field.doc,
required=field.required,
)

def struct(self, struct: StructType, field_results: List[Optional[IcebergType]]) -> Optional[IcebergType]:
return StructType(*[field for field in field_results if field is not None])

def list(self, list_type: ListType, element_result: Optional[IcebergType]) -> Optional[IcebergType]:
return ListType(element_id=list_type.element_id, element_type=element_result, element_required=list_type.element_required)

def map(
self, map_type: MapType, key_result: Optional[IcebergType], value_result: Optional[IcebergType]
) -> Optional[IcebergType]:
return MapType(
key_id=map_type.key_id,
value_id=map_type.value_id,
key_type=key_result,
value_type=value_result,
value_required=map_type.value_required,
)

def primitive(self, primitive: PrimitiveType) -> Optional[IcebergType]:
return primitive


def prune_columns(schema: Schema, selected: Set[int], select_full_types: bool = True) -> Schema:
"""Prunes a column by only selecting a set of field-ids.

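Taken together, make_compatible_name and _sanitize_char implement the following mapping: a name that already satisfies the Avro rules is returned unchanged, an invalid leading digit is kept but prefixed with '_', and any other invalid character is replaced by '_x' followed by its upper-case hexadecimal code point. A few worked examples derived from the code above (the first input mirrors the provisioned column; the others are illustrative):

from pyiceberg.schema import make_compatible_name

make_compatible_name("letter/abc")  # 'letter_x2Fabc' -- '/' (0x2F) becomes '_x2F'
make_compatible_name("9col")        # '_9col'         -- a leading digit is prefixed with '_'
make_compatible_name("valid_name")  # 'valid_name'    -- already Avro-compatible, returned unchanged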
7 changes: 6 additions & 1 deletion tests/io/test_pyarrow.py
@@ -66,7 +66,7 @@
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionSpec
from pyiceberg.schema import Schema, visit
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, Table
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.types import (
@@ -1542,3 +1542,8 @@ def check_results(location: str, expected_schema: str, expected_netloc: str, exp

check_results("/root/foo.txt", "file", "", "/root/foo.txt")
check_results("/root/tmp/foo.txt", "file", "", "/root/tmp/foo.txt")


def test_make_compatible_name() -> None:
assert make_compatible_name("label/abc") == "label_x2Fabc"
assert make_compatible_name("label?abc") == "label_x3Fabc"
13 changes: 13 additions & 0 deletions tests/test_integration.py
@@ -88,6 +88,11 @@ def table_test_table_version(catalog: Catalog) -> Table:
return catalog.load_table("default.test_table_version")


@pytest.fixture()
def table_test_table_sanitized_character(catalog: Catalog) -> Table:
return catalog.load_table("default.test_table_sanitized_character")


TABLE_NAME = ("default", "t1")


@@ -396,3 +401,11 @@ def test_upgrade_table_version(table_test_table_version: Table) -> None:
with table_test_table_version.transaction() as transaction:
transaction.upgrade_table_version(format_version=3)
assert "Unsupported table format version: 3" in str(e.value)


@pytest.mark.integration
def test_reproduce_issue(table_test_table_sanitized_character: Table) -> None:
arrow_table = table_test_table_sanitized_character.scan().to_arrow()
assert len(arrow_table.schema.names) == 1
assert len(table_test_table_sanitized_character.schema().fields) == 1
assert arrow_table.schema.names[0] == table_test_table_sanitized_character.schema().fields[0].name
127 changes: 127 additions & 0 deletions tests/test_schema.py
@@ -28,6 +28,7 @@
build_position_accessors,
promote,
prune_columns,
sanitize_column_names,
)
from pyiceberg.typedef import EMPTY_DICT, StructProtocol
from pyiceberg.types import (
@@ -431,6 +432,132 @@ def test_deserialize_schema(table_schema_simple: Schema) -> None:
assert actual == expected


def test_sanitize() -> None:
before_sanitized = Schema(
NestedField(field_id=1, name="foo_field/bar", field_type=StringType(), required=True),
NestedField(
field_id=2,
name="foo_list/bar",
field_type=ListType(element_id=3, element_type=StringType(), element_required=True),
required=True,
),
NestedField(
field_id=4,
name="foo_map/bar",
field_type=MapType(
key_id=5,
key_type=StringType(),
value_id=6,
value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True),
value_required=True,
),
required=True,
),
NestedField(
field_id=8,
name="foo_struct/bar",
field_type=StructType(
NestedField(field_id=9, name="foo_struct_1/bar", field_type=StringType(), required=False),
NestedField(field_id=10, name="foo_struct_2/bar", field_type=IntegerType(), required=True),
),
required=False,
),
NestedField(
field_id=11,
name="foo_list_2/bar",
field_type=ListType(
element_id=12,
element_type=StructType(
NestedField(field_id=13, name="foo_list_2_1/bar", field_type=LongType(), required=True),
NestedField(field_id=14, name="foo_list_2_2/bar", field_type=LongType(), required=True),
),
element_required=False,
),
required=False,
),
NestedField(
field_id=15,
name="foo_map_2/bar",
field_type=MapType(
key_id=16,
value_id=17,
key_type=StructType(
NestedField(field_id=18, name="foo_map_2_1/bar", field_type=StringType(), required=True),
),
value_type=StructType(
NestedField(field_id=19, name="foo_map_2_2/bar", field_type=FloatType(), required=True),
),
value_required=True,
),
required=True,
),
schema_id=1,
identifier_field_ids=[1],
)
expected_schema = Schema(
NestedField(field_id=1, name="foo_field_x2Fbar", field_type=StringType(), required=True),
NestedField(
field_id=2,
name="foo_list_x2Fbar",
field_type=ListType(element_id=3, element_type=StringType(), element_required=True),
required=True,
),
NestedField(
field_id=4,
name="foo_map_x2Fbar",
field_type=MapType(
key_id=5,
key_type=StringType(),
value_id=6,
value_type=MapType(key_id=7, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=True),
value_required=True,
),
required=True,
),
NestedField(
field_id=8,
name="foo_struct_x2Fbar",
field_type=StructType(
NestedField(field_id=9, name="foo_struct_1_x2Fbar", field_type=StringType(), required=False),
NestedField(field_id=10, name="foo_struct_2_x2Fbar", field_type=IntegerType(), required=True),
),
required=False,
),
NestedField(
field_id=11,
name="foo_list_2_x2Fbar",
field_type=ListType(
element_id=12,
element_type=StructType(
NestedField(field_id=13, name="foo_list_2_1_x2Fbar", field_type=LongType(), required=True),
NestedField(field_id=14, name="foo_list_2_2_x2Fbar", field_type=LongType(), required=True),
),
element_required=False,
),
required=False,
),
NestedField(
field_id=15,
name="foo_map_2_x2Fbar",
field_type=MapType(
key_id=16,
value_id=17,
key_type=StructType(
NestedField(field_id=18, name="foo_map_2_1_x2Fbar", field_type=StringType(), required=True),
),
value_type=StructType(
NestedField(field_id=19, name="foo_map_2_2_x2Fbar", field_type=FloatType(), required=True),
),
value_required=True,
),
required=True,
),
schema_id=1,
identifier_field_ids=[1],
)
assert sanitize_column_names(before_sanitized) == expected_schema


def test_prune_columns_string(table_schema_nested_with_struct_key_map: Schema) -> None:
assert prune_columns(table_schema_nested_with_struct_key_map, {1}, False) == Schema(
NestedField(field_id=1, name="foo", field_type=StringType(), required=True), schema_id=1, identifier_field_ids=[1]