diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 738cd77bfd..06d03e21e1 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1772,7 +1772,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT ) def write_parquet(task: WriteTask) -> DataFile: - file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}' + file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}' fo = io.new_output(file_path) with fo.create(overwrite=True) as fos: with pq.ParquetWriter(fos, schema=arrow_file_schema, **parquet_writer_kwargs) as writer: @@ -1787,7 +1787,7 @@ def write_parquet(task: WriteTask) -> DataFile: content=DataFileContent.DATA, file_path=file_path, file_format=FileFormat.PARQUET, - partition=Record(), + partition=task.partition_key.partition if task.partition_key else Record(), file_size_in_bytes=len(fo), # After this has been fixed: # https://github.com/apache/iceberg-python/issues/271 diff --git a/pyiceberg/manifest.py b/pyiceberg/manifest.py index 5277eed9e6..3b8138b61a 100644 --- a/pyiceberg/manifest.py +++ b/pyiceberg/manifest.py @@ -19,7 +19,6 @@ import math from abc import ABC, abstractmethod from enum import Enum -from functools import singledispatch from types import TracebackType from typing import ( Any, @@ -41,8 +40,6 @@ from pyiceberg.types import ( BinaryType, BooleanType, - DateType, - IcebergType, IntegerType, ListType, LongType, @@ -51,9 +48,6 @@ PrimitiveType, StringType, StructType, - TimestampType, - TimestamptzType, - TimeType, ) UNASSIGNED_SEQ = -1 @@ -283,31 +277,12 @@ def __repr__(self) -> str: } -@singledispatch -def partition_field_to_data_file_partition_field(partition_field_type: IcebergType) -> PrimitiveType: - raise TypeError(f"Unsupported partition field type: {partition_field_type}") - - -@partition_field_to_data_file_partition_field.register(LongType) -@partition_field_to_data_file_partition_field.register(DateType) -@partition_field_to_data_file_partition_field.register(TimeType) -@partition_field_to_data_file_partition_field.register(TimestampType) -@partition_field_to_data_file_partition_field.register(TimestamptzType) -def _(partition_field_type: PrimitiveType) -> IntegerType: - return IntegerType() - - -@partition_field_to_data_file_partition_field.register(PrimitiveType) -def _(partition_field_type: PrimitiveType) -> PrimitiveType: - return partition_field_type - - def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType: data_file_partition_type = StructType(*[ NestedField( field_id=field.field_id, name=field.name, - field_type=partition_field_to_data_file_partition_field(field.field_type), + field_type=field.field_type, required=field.required, ) for field in partition_type.fields diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index 16f158828d..a3cf255341 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -19,7 +19,7 @@ import uuid from abc import ABC, abstractmethod from dataclasses import dataclass -from datetime import date, datetime +from datetime import date, datetime, time from functools import cached_property, singledispatch from typing import ( Any, @@ -62,9 +62,10 @@ StructType, TimestampType, TimestamptzType, + TimeType, UUIDType, ) -from pyiceberg.utils.datetime import date_to_days, datetime_to_micros +from pyiceberg.utils.datetime import date_to_days, datetime_to_micros, time_to_micros INITIAL_PARTITION_SPEC_ID = 0 
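# Illustrative sketch (not part of the patch): the hunk below registers a TimeType handler for
# _to_partition_representation. The conversion conventions it relies on (dates as days since the
# epoch, times/timestamps as microseconds) can be shown with a self-contained singledispatch
# example that dispatches on the Python value type instead of the Iceberg type; names are made up.
from datetime import date, datetime, time, timedelta, timezone
from functools import singledispatch
from typing import Any

_EPOCH_DATE = date(1970, 1, 1)
_EPOCH_NAIVE = datetime(1970, 1, 1)
_EPOCH_UTC = datetime(1970, 1, 1, tzinfo=timezone.utc)


@singledispatch
def to_partition_representation(value: Any) -> Any:
    return value  # ints, strings, bytes, ... are stored as-is


@to_partition_representation.register(date)
def _(value: date) -> int:
    return (value - _EPOCH_DATE).days  # DateType: days since the Unix epoch


@to_partition_representation.register(datetime)
def _(value: datetime) -> int:
    epoch = _EPOCH_UTC if value.tzinfo else _EPOCH_NAIVE
    return (value - epoch) // timedelta(microseconds=1)  # Timestamp(tz)Type: microseconds since the epoch


@to_partition_representation.register(time)
def _(value: time) -> int:
    # TimeType: microseconds from midnight, mirroring time_to_micros in pyiceberg.utils.datetime
    return ((value.hour * 60 + value.minute) * 60 + value.second) * 1_000_000 + value.microsecond


assert to_partition_representation(date(2023, 1, 1)) == 19358
assert to_partition_representation(time(0, 0, 1)) == 1_000_000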
 PARTITION_FIELD_ID_START: int = 1000
@@ -431,6 +432,11 @@ def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
     return date_to_days(value) if value is not None else None
 
 
+@_to_partition_representation.register(TimeType)
+def _(type: IcebergType, value: Optional[time]) -> Optional[int]:
+    return time_to_micros(value) if value is not None else None
+
+
 @_to_partition_representation.register(UUIDType)
 def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
     return str(value) if value is not None else None
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index e183d82775..2dbc32d893 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -16,13 +16,13 @@
 # under the License.
 from __future__ import annotations
 
-import datetime
 import itertools
 import uuid
 import warnings
 from abc import ABC, abstractmethod
 from copy import copy
 from dataclasses import dataclass
+from datetime import datetime
 from enum import Enum
 from functools import cached_property, singledispatch
 from itertools import chain
@@ -79,6 +79,8 @@
     PARTITION_FIELD_ID_START,
     UNPARTITIONED_PARTITION_SPEC,
     PartitionField,
+    PartitionFieldValue,
+    PartitionKey,
     PartitionSpec,
     _PartitionNameGenerator,
     _visit_partition_field,
@@ -373,8 +375,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        if len(self._table.spec().fields) > 0:
-            raise ValueError("Cannot write to partitioned tables")
+        supported_transforms = {IdentityTransform}
+        if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields):
+            raise ValueError(
+                f"All transforms are not supported, expected: {supported_transforms}, but got: {[str(field) for field in self.table_metadata.spec().fields if type(field.transform) not in supported_transforms]}."
+ ) _check_schema_compatible(self._table.schema(), other_schema=df.schema) # cast if the two schemas are compatible but not equal @@ -897,7 +902,7 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl if update.ref_name == MAIN_BRANCH: metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id if "last_updated_ms" not in metadata_updates: - metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) + metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.now().astimezone()) metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [ SnapshotLogEntry( @@ -2646,16 +2651,23 @@ def _add_and_move_fields( class WriteTask: write_uuid: uuid.UUID task_id: int + schema: Schema record_batches: List[pa.RecordBatch] sort_order_id: Optional[int] = None - - # Later to be extended with partition information + partition_key: Optional[PartitionKey] = None def generate_data_file_filename(self, extension: str) -> str: # Mimics the behavior in the Java API: # https://github.com/apache/iceberg/blob/a582968975dd30ff4917fbbe999f1be903efac02/core/src/main/java/org/apache/iceberg/io/OutputFileFactory.java#L92-L101 return f"00000-{self.task_id}-{self.write_uuid}.{extension}" + def generate_data_file_path(self, extension: str) -> str: + if self.partition_key: + file_path = f"{self.partition_key.to_path()}/{self.generate_data_file_filename(extension)}" + return file_path + else: + return self.generate_data_file_filename(extension) + @dataclass(frozen=True) class AddFileTask: @@ -2683,25 +2695,40 @@ def _dataframe_to_data_files( """ from pyiceberg.io.pyarrow import bin_pack_arrow_table, write_file - if len([spec for spec in table_metadata.partition_specs if spec.spec_id != 0]) > 0: - raise ValueError("Cannot write to partitioned tables") - counter = itertools.count(0) write_uuid = write_uuid or uuid.uuid4() - - target_file_size = PropertyUtil.property_as_int( + target_file_size: int = PropertyUtil.property_as_int( # type: ignore # The property is set with non-None value. 
properties=table_metadata.properties, property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, ) - # This is an iter, so we don't have to materialize everything every time - # This will be more relevant when we start doing partitioned writes - yield from write_file( - io=io, - table_metadata=table_metadata, - tasks=iter([WriteTask(write_uuid, next(counter), batches) for batches in bin_pack_arrow_table(df, target_file_size)]), # type: ignore - ) + if len(table_metadata.spec().fields) > 0: + partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) + yield from write_file( + io=io, + table_metadata=table_metadata, + tasks=iter([ + WriteTask( + write_uuid=write_uuid, + task_id=next(counter), + record_batches=batches, + partition_key=partition.partition_key, + schema=table_metadata.schema(), + ) + for partition in partitions + for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size) + ]), + ) + else: + yield from write_file( + io=io, + table_metadata=table_metadata, + tasks=iter([ + WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema()) + for batches in bin_pack_arrow_table(df, target_file_size) + ]), + ) def _parquet_files_to_data_files(table_metadata: TableMetadata, file_paths: List[str], io: FileIO) -> Iterable[DataFile]: @@ -3253,7 +3280,7 @@ def snapshots(self) -> "pa.Table": additional_properties = None snapshots.append({ - 'committed_at': datetime.datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0), + 'committed_at': datetime.utcfromtimestamp(snapshot.timestamp_ms / 1000.0), 'snapshot_id': snapshot.snapshot_id, 'parent_id': snapshot.parent_snapshot_id, 'operation': str(operation), @@ -3388,3 +3415,112 @@ def _readable_metrics_struct(bound_type: PrimitiveType) -> pa.StructType: entries, schema=entries_schema, ) + + +@dataclass(frozen=True) +class TablePartition: + partition_key: PartitionKey + arrow_table_partition: pa.Table + + +def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]: + order = 'ascending' if not reverse else 'descending' + null_placement = 'at_start' if reverse else 'at_end' + return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement} + + +def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table: + """Given a table, sort it by current partition scheme.""" + # only works for identity for now + sort_options = _get_partition_sort_order(partition_columns, reverse=False) + sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement']) + return sorted_arrow_table + + +def get_partition_columns( + spec: PartitionSpec, + schema: Schema, +) -> list[str]: + partition_cols = [] + for partition_field in spec.fields: + column_name = schema.find_column_name(partition_field.source_id) + if not column_name: + raise ValueError(f"{partition_field=} could not be found in {schema}.") + partition_cols.append(column_name) + return partition_cols + + +def _get_table_partitions( + arrow_table: pa.Table, + partition_spec: PartitionSpec, + schema: Schema, + slice_instructions: list[dict[str, Any]], +) -> list[TablePartition]: + sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x['offset']) + + partition_fields = partition_spec.fields + + offsets = [inst["offset"] for inst 
in sorted_slice_instructions]
+    projected_and_filtered = {
+        partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
+        .take(offsets)
+        .to_pylist()
+        for partition_field in partition_fields
+    }
+
+    table_partitions = []
+    for idx, inst in enumerate(sorted_slice_instructions):
+        partition_slice = arrow_table.slice(**inst)
+        fieldvalues = [
+            PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
+            for partition_field in partition_fields
+        ]
+        partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
+        table_partitions.append(TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
+    return table_partitions
+
+
+def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[TablePartition]:
+    """Based on the Iceberg table partition spec, slice the Arrow table into partitions with their keys.
+
+    Example:
+    Input:
+    An arrow table partitioned on ['n_legs', 'year'] with data
+    {'year': [2020, 2022, 2022, 2021, 2022, 2022, 2022, 2019, 2021],
+     'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
+     'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse", "Brittle stars", "Centipede"]}.
+    The algorithm:
+    First, group the rows into partitions by sorting with sort order [('n_legs', 'ascending'), ('year', 'ascending')]
+    and null_placement of "at_end"; for this input the data is already in that order, so the sorted table equals the input.
+    Then compute sort_indices with the reversed order [('n_legs', 'descending'), ('year', 'descending')]
+    and null_placement of "at_start".
+    This gives:
+    [8, 7, 4, 5, 6, 3, 1, 2, 0]
+    Based on these indices we derive the partition groups:
+    [{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
+    We then look up the partition key values at each group's offset
+    and slice the arrow table by each group's offset and length.
+ """ + import pyarrow as pa + + partition_columns = get_partition_columns(spec=spec, schema=schema) + arrow_table = group_by_partition_scheme(arrow_table, partition_columns) + + reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True) + reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist() + + slice_instructions: list[dict[str, Any]] = [] + last = len(reversed_indices) + reversed_indices_size = len(reversed_indices) + ptr = 0 + while ptr < reversed_indices_size: + group_size = last - reversed_indices[ptr] + offset = reversed_indices[ptr] + slice_instructions.append({"offset": offset, "length": group_size}) + last = reversed_indices[ptr] + ptr = ptr + group_size + + table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) + + return table_partitions diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py index 4bed386c77..6ccf9526ba 100644 --- a/pyiceberg/typedef.py +++ b/pyiceberg/typedef.py @@ -202,5 +202,9 @@ def record_fields(self) -> List[str]: """Return values of all the fields of the Record class except those specified in skip_fields.""" return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name] + def __hash__(self) -> int: + """Return hash value of the Record class.""" + return hash(str(self)) + TableVersion: TypeAlias = Literal[1, 2] diff --git a/tests/conftest.py b/tests/conftest.py index aa09517b6a..4a820fedec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ import socket import string import uuid -from datetime import date, datetime +from datetime import date, datetime, timezone from pathlib import Path from random import choice from tempfile import TemporaryDirectory @@ -1999,8 +1999,13 @@ def spark() -> "SparkSession": 'long': [1, None, 9], 'float': [0.0, None, 0.9], 'double': [0.0, None, 0.9], + # 'time': [1_000_000, None, 3_000_000], # Example times: 1s, none, and 3s past midnight #Spark does not support time fields 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], + 'timestamptz': [ + datetime(2023, 1, 1, 19, 25, 00, tzinfo=timezone.utc), + None, + datetime(2023, 3, 1, 19, 25, 00, tzinfo=timezone.utc), + ], 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], # Not supported by Spark # 'time': [time(1, 22, 0), None, time(19, 25, 0)], @@ -2027,6 +2032,8 @@ def pa_schema() -> "pa.Schema": ("long", pa.int64()), ("float", pa.float32()), ("double", pa.float64()), + # Not supported by Spark + # ("time", pa.time64('us')), ("timestamp", pa.timestamp(unit="us")), ("timestamptz", pa.timestamp(unit="us", tz="UTC")), ("date", pa.date32()), @@ -2041,7 +2048,23 @@ def pa_schema() -> "pa.Schema": @pytest.fixture(scope="session") def arrow_table_with_null(pa_schema: "pa.Schema") -> "pa.Table": + """Pyarrow table with all kinds of columns.""" import pyarrow as pa - """Pyarrow table with all kinds of columns.""" return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_without_data(pa_schema: "pa.Schema") -> "pa.Table": + """Pyarrow table without data.""" + import pyarrow as pa + + return pa.Table.from_pylist([], schema=pa_schema) + + +@pytest.fixture(scope="session") +def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table": + """Pyarrow table with only null values.""" + import pyarrow as 
pa + + return pa.Table.from_pylist([{}, {}], schema=pa_schema) diff --git a/tests/integration/test_partitioning_key.py b/tests/integration/test_partitioning_key.py index 12056bac1e..d89ecaf202 100644 --- a/tests/integration/test_partitioning_key.py +++ b/tests/integration/test_partitioning_key.py @@ -749,7 +749,7 @@ def test_partition_key( # key.to_path() generates the hive partitioning part of the to-write parquet file path assert key.to_path() == expected_hive_partition_path_slice - # Justify expected values are not made up but conform to spark behaviors + # Justify expected values are not made up but conforming to spark behaviors if spark_create_table_sql_for_justification is not None and spark_data_insert_sql_for_justification is not None: try: spark.sql(f"drop table {identifier}") diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py new file mode 100644 index 0000000000..d84b9745a7 --- /dev/null +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -0,0 +1,386 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
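# Illustrative sketch (not part of the patch): the key.to_path() assertion above and
# WriteTask.generate_data_file_path earlier in this patch rely on PartitionKey producing both the
# partition Record and the hive-style directory for a data file. A minimal standalone usage; the
# schema, spec, and value below are made up.
from pyiceberg.partitioning import PartitionField, PartitionFieldValue, PartitionKey, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import IntegerType, NestedField

demo_schema = Schema(NestedField(field_id=1, name="n_legs", field_type=IntegerType(), required=False))
demo_spec = PartitionSpec(PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="n_legs"))

key = PartitionKey(
    raw_partition_field_values=[PartitionFieldValue(demo_spec.fields[0], 4)],
    partition_spec=demo_spec,
    schema=demo_schema,
)
print(key.partition)  # Record with the transformed value, stored as DataFile.partition
print(key.to_path())  # e.g. 'n_legs=4', prepended to '00000-<task>-<uuid>.parquet' under '<table-location>/data/'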
+# pylint:disable=redefined-outer-name + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession + +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + HourTransform, + IdentityTransform, + MonthTransform, + TruncateTransform, + YearTransform, +) +from tests.conftest import TEST_DATA_WITH_NULL +from utils import TABLE_SCHEMA, _create_table + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_null_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + # Given + identifier = f"default.arrow_table_v{format_version}_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_null], + partition_spec=partition_spec, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + assert df.count() == 3, f"Expected 3 total rows for {identifier}" + for col in TEST_DATA_WITH_NULL.keys(): + assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_without_data_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_without_data: pa.Table, part_col: str, format_version: int +) -> None: + # Given + identifier = f"default.arrow_table_v{format_version}_without_data_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_without_data], + partition_spec=partition_spec, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", 'timestamp', 'timestamptz', 'binary'] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_only_nulls_partitioned( + session_catalog: Catalog, spark: SparkSession, 
arrow_table_with_only_nulls: pa.Table, part_col: str, format_version: int +) -> None: + # Given + identifier = f"default.arrow_table_v{format_version}_with_only_nulls_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[arrow_table_with_only_nulls], + partition_spec=partition_spec, + ) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for {col}" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_filter_appended_null_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str, format_version: int +) -> None: + # Given + identifier = f"default.arrow_table_v{format_version}_appended_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": str(format_version)}, + data=[], + partition_spec=partition_spec, + ) + # Append with arrow_table_1 with lines [A,B,C] and then arrow_table_2 with lines[A,B,C,A,B,C] + tbl.append(arrow_table_with_null) + tbl.append(pa.concat_tables([arrow_table_with_null, arrow_table_with_null])) + + # Then + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + df = spark.table(identifier) + for col in TEST_DATA_WITH_NULL.keys(): + df = spark.table(identifier) + assert df.where(f"{col} is not null").count() == 6, f"Expected 6 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 3, f"Expected 3 null rows for {col}" + # expecting 6 files: first append with [A], [B], [C], second append with [A, A], [B, B], [C, C] + rows = spark.sql(f"select partition from {identifier}.files").collect() + assert len(rows) == 6 + + +@pytest.mark.integration +@pytest.mark.parametrize( + "part_col", ['int', 'bool', 'string', "string_long", "long", "float", "double", "date", "timestamptz", "timestamp", "binary"] +) +def test_query_filter_v1_v2_append_null( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, part_col: str +) -> None: + # Given + identifier = f"default.arrow_table_v1_v2_appended_with_null_partitioned_on_col_{part_col}" + nested_field = TABLE_SCHEMA.find_field(part_col) + partition_spec = PartitionSpec( + PartitionField(source_id=nested_field.field_id, field_id=1001, transform=IdentityTransform(), name=part_col) + ) + + # When + tbl = _create_table( + session_catalog=session_catalog, + identifier=identifier, + properties={"format-version": "1"}, + data=[], + partition_spec=partition_spec, + ) + 
tbl.append(arrow_table_with_null) + + # Then + assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}" + + # When + with tbl.transaction() as tx: + tx.upgrade_table_version(format_version=2) + + tbl.append(arrow_table_with_null) + + # Then + assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}" + for col in TEST_DATA_WITH_NULL.keys(): # type: ignore + df = spark.table(identifier) + assert df.where(f"{col} is not null").count() == 4, f"Expected 4 non-null rows for {col}" + assert df.where(f"{col} is null").count() == 2, f"Expected 2 null rows for {col}" + + +@pytest.mark.integration +def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None: + identifier = "default.arrow_table_summaries" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")), + properties={'format-version': '2'}, + ) + + tbl.append(arrow_table_with_null) + tbl.append(arrow_table_with_null) + + rows = spark.sql( + f""" + SELECT operation, summary + FROM {identifier}.snapshots + ORDER BY committed_at ASC + """ + ).collect() + + operations = [row.operation for row in rows] + assert operations == ['append', 'append'] + + summaries = [row.summary for row in rows] + assert summaries[0] == { + 'changed-partition-count': '3', + 'added-data-files': '3', + 'added-files-size': '15029', + 'added-records': '3', + 'total-data-files': '3', + 'total-delete-files': '0', + 'total-equality-deletes': '0', + 'total-files-size': '15029', + 'total-position-deletes': '0', + 'total-records': '3', + } + + assert summaries[1] == { + 'changed-partition-count': '3', + 'added-data-files': '3', + 'added-files-size': '15029', + 'added-records': '3', + 'total-data-files': '6', + 'total-delete-files': '0', + 'total-equality-deletes': '0', + 'total-files-size': '30058', + 'total-position-deletes': '0', + 'total-records': '6', + } + + +@pytest.mark.integration +def test_data_files_with_table_partitioned_with_null( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table +) -> None: + identifier = "default.arrow_data_files" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")), + properties={'format-version': '1'}, + ) + + tbl.append(arrow_table_with_null) + tbl.append(arrow_table_with_null) + + # added_data_files_count, existing_data_files_count, deleted_data_files_count + rows = spark.sql( + f""" + SELECT added_data_files_count, existing_data_files_count, deleted_data_files_count + FROM {identifier}.all_manifests + """ + ).collect() + + assert [row.added_data_files_count for row in rows] == [3, 3, 3] + assert [row.existing_data_files_count for row in rows] == [ + 0, + 0, + 0, + ] + assert [row.deleted_data_files_count for row in rows] == [0, 0, 0] + + +@pytest.mark.integration +def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> None: + identifier = "default.arrow_data_files" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table( + identifier=identifier, 
+ schema=TABLE_SCHEMA, + partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")), + properties={'format-version': '1'}, + ) + + with pytest.raises(ValueError, match="Expected PyArrow table, got: not a df"): + tbl.append("not a df") + + +@pytest.mark.integration +@pytest.mark.parametrize( + "spec", + [ + # mixed with non-identity is not supported + ( + PartitionSpec( + PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"), + PartitionField(source_id=1, field_id=1002, transform=IdentityTransform(), name="bool"), + ) + ), + # none of non-identity is supported + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=BucketTransform(2), name="int_bucket"))), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=BucketTransform(2), name="long_bucket"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=BucketTransform(2), name="date_bucket"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=BucketTransform(2), name="timestamp_bucket"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=BucketTransform(2), name="timestamptz_bucket"))), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=BucketTransform(2), name="string_bucket"))), + (PartitionSpec(PartitionField(source_id=12, field_id=1001, transform=BucketTransform(2), name="fixed_bucket"))), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=BucketTransform(2), name="binary_bucket"))), + (PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=TruncateTransform(2), name="int_trunc"))), + (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))), + (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))), + (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=YearTransform(), name="timestamp_year"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=YearTransform(), name="timestamptz_year"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=YearTransform(), name="date_year"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=MonthTransform(), name="timestamp_month"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=MonthTransform(), name="timestamptz_month"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="date_month"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=DayTransform(), name="timestamp_day"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=DayTransform(), name="timestamptz_day"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=DayTransform(), name="date_day"))), + (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=HourTransform(), name="timestamp_hour"))), + (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=HourTransform(), name="timestamptz_hour"))), + (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=HourTransform(), name="date_hour"))), + ], +) +def test_unsupported_transform( + spec: PartitionSpec, spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table +) -> None: + 
identifier = "default.unsupported_transform" + + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + tbl = session_catalog.create_table( + identifier=identifier, + schema=TABLE_SCHEMA, + partition_spec=spec, + properties={'format-version': '1'}, + ) + + with pytest.raises(ValueError, match="All transforms are not supported.*"): + tbl.append(arrow_table_with_null) diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes/test_writes.py similarity index 88% rename from tests/integration/test_writes.py rename to tests/integration/test_writes/test_writes.py index e950fb43b1..62d3bb1172 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -18,10 +18,9 @@ import math import os import time -import uuid from datetime import date, datetime from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict from urllib.parse import urlparse import pyarrow as pa @@ -36,93 +35,9 @@ from pyiceberg.catalog import Catalog from pyiceberg.catalog.sql import SqlCatalog from pyiceberg.exceptions import NoSuchTableError -from pyiceberg.schema import Schema -from pyiceberg.table import Table, TableProperties, _dataframe_to_data_files -from pyiceberg.typedef import Properties -from pyiceberg.types import ( - BinaryType, - BooleanType, - DateType, - DoubleType, - FixedType, - FloatType, - IntegerType, - LongType, - NestedField, - StringType, - TimestampType, - TimestamptzType, -) - -TEST_DATA_WITH_NULL = { - 'bool': [False, None, True], - 'string': ['a', None, 'z'], - # Go over the 16 bytes to kick in truncation - 'string_long': ['a' * 22, None, 'z' * 22], - 'int': [1, None, 9], - 'long': [1, None, 9], - 'float': [0.0, None, 0.9], - 'double': [0.0, None, 0.9], - 'timestamp': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'timestamptz': [datetime(2023, 1, 1, 19, 25, 00), None, datetime(2023, 3, 1, 19, 25, 00)], - 'date': [date(2023, 1, 1), None, date(2023, 3, 1)], - # Not supported by Spark - # 'time': [time(1, 22, 0), None, time(19, 25, 0)], - # Not natively supported by Arrow - # 'uuid': [uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, None, uuid.UUID('11111111-1111-1111-1111-111111111111').bytes], - 'binary': [b'\01', None, b'\22'], - 'fixed': [ - uuid.UUID('00000000-0000-0000-0000-000000000000').bytes, - None, - uuid.UUID('11111111-1111-1111-1111-111111111111').bytes, - ], -} - -TABLE_SCHEMA = Schema( - NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), - NestedField(field_id=2, name="string", field_type=StringType(), required=False), - NestedField(field_id=3, name="string_long", field_type=StringType(), required=False), - NestedField(field_id=4, name="int", field_type=IntegerType(), required=False), - NestedField(field_id=5, name="long", field_type=LongType(), required=False), - NestedField(field_id=6, name="float", field_type=FloatType(), required=False), - NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), - NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), - NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), - NestedField(field_id=10, name="date", field_type=DateType(), required=False), - # NestedField(field_id=11, name="time", field_type=TimeType(), required=False), - # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False), - NestedField(field_id=12, name="binary", 
field_type=BinaryType(), required=False), - NestedField(field_id=13, name="fixed", field_type=FixedType(16), required=False), -) - - -@pytest.fixture(scope="session") -def arrow_table_without_data(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pylist([], schema=pa_schema) - - -@pytest.fixture(scope="session") -def arrow_table_with_only_nulls(pa_schema: pa.Schema) -> pa.Table: - """PyArrow table with all kinds of columns""" - return pa.Table.from_pylist([{}, {}], schema=pa_schema) - - -def _create_table( - session_catalog: Catalog, identifier: str, properties: Properties, data: Optional[List[pa.Table]] = None -) -> Table: - try: - session_catalog.drop_table(identifier=identifier) - except NoSuchTableError: - pass - - tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties=properties) - - if data: - for d in data: - tbl.append(d) - - return tbl +from pyiceberg.table import TableProperties, _dataframe_to_data_files +from tests.conftest import TEST_DATA_WITH_NULL +from utils import _create_table @pytest.fixture(scope="session", autouse=True) @@ -219,7 +134,7 @@ def test_query_filter_without_data(spark: SparkSession, col: str, format_version identifier = f"default.arrow_table_v{format_version}_without_data" df = spark.table(identifier) assert df.where(f"{col} is null").count() == 0, f"Expected 0 row for {col}" - assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" @pytest.mark.integration @@ -228,8 +143,8 @@ def test_query_filter_without_data(spark: SparkSession, col: str, format_version def test_query_filter_only_nulls(spark: SparkSession, col: str, format_version: int) -> None: identifier = f"default.arrow_table_v{format_version}_with_only_nulls" df = spark.table(identifier) - assert df.where(f"{col} is null").count() == 2, f"Expected 2 row for {col}" - assert df.where(f"{col} is not null").count() == 0, f"Expected 0 rows for {col}" + assert df.where(f"{col} is null").count() == 2, f"Expected 2 rows for {col}" + assert df.where(f"{col} is not null").count() == 0, f"Expected 0 row for {col}" @pytest.mark.integration diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py new file mode 100644 index 0000000000..792e25185d --- /dev/null +++ b/tests/integration/test_writes/utils.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
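# Illustrative sketch (not part of the patch): outside the test helpers defined below, the
# user-facing flow for the newly supported identity-partitioned writes looks roughly like this.
# The catalog name, table identifier, and columns are made up, and the catalog is assumed to be
# configured (e.g. via .pyiceberg.yaml).
import pyarrow as pa

from pyiceberg.catalog import load_catalog
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.transforms import IdentityTransform
from pyiceberg.types import LongType, NestedField, StringType

catalog = load_catalog("default")  # hypothetical catalog name
schema = Schema(
    NestedField(field_id=1, name="region", field_type=StringType(), required=False),
    NestedField(field_id=2, name="amount", field_type=LongType(), required=False),
)
spec = PartitionSpec(PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="region"))

tbl = catalog.create_table("default.sales", schema=schema, partition_spec=spec)
df = pa.Table.from_pydict(
    {"region": ["eu", "eu", "us"], "amount": [1, 2, 3]},
    schema=pa.schema([("region", pa.string()), ("amount", pa.int64())]),
)
tbl.append(df)  # expected to write one data file per identity partition value ('eu' and 'us')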
+# pylint:disable=redefined-outer-name +from typing import List, Optional + +import pyarrow as pa + +from pyiceberg.catalog import Catalog +from pyiceberg.exceptions import NoSuchTableError +from pyiceberg.partitioning import PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.table import Table +from pyiceberg.typedef import Properties +from pyiceberg.types import ( + BinaryType, + BooleanType, + DateType, + DoubleType, + FixedType, + FloatType, + IntegerType, + LongType, + NestedField, + StringType, + TimestampType, + TimestamptzType, +) + +TABLE_SCHEMA = Schema( + NestedField(field_id=1, name="bool", field_type=BooleanType(), required=False), + NestedField(field_id=2, name="string", field_type=StringType(), required=False), + NestedField(field_id=3, name="string_long", field_type=StringType(), required=False), + NestedField(field_id=4, name="int", field_type=IntegerType(), required=False), + NestedField(field_id=5, name="long", field_type=LongType(), required=False), + NestedField(field_id=6, name="float", field_type=FloatType(), required=False), + NestedField(field_id=7, name="double", field_type=DoubleType(), required=False), + # NestedField(field_id=8, name="time", field_type=TimeType(), required=False), # Spark does not support time fields + NestedField(field_id=8, name="timestamp", field_type=TimestampType(), required=False), + NestedField(field_id=9, name="timestamptz", field_type=TimestamptzType(), required=False), + NestedField(field_id=10, name="date", field_type=DateType(), required=False), + # NestedField(field_id=11, name="time", field_type=TimeType(), required=False), + # NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False), + NestedField(field_id=11, name="binary", field_type=BinaryType(), required=False), + NestedField(field_id=12, name="fixed", field_type=FixedType(16), required=False), +) + + +def _create_table( + session_catalog: Catalog, + identifier: str, + properties: Properties, + data: Optional[List[pa.Table]] = None, + partition_spec: Optional[PartitionSpec] = None, +) -> Table: + try: + session_catalog.drop_table(identifier=identifier) + except NoSuchTableError: + pass + + if partition_spec: + tbl = session_catalog.create_table( + identifier=identifier, schema=TABLE_SCHEMA, properties=properties, partition_spec=partition_spec + ) + else: + tbl = session_catalog.create_table(identifier=identifier, schema=TABLE_SCHEMA, properties=properties) + + if data: + for d in data: + tbl.append(d) + + return tbl diff --git a/tests/table/test_init.py b/tests/table/test_init.py index f1191295f3..2bc78f3197 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -64,6 +64,7 @@ UpdateSchema, _apply_table_update, _check_schema_compatible, + _determine_partitions, _match_deletes_to_data_file, _TableMetadataUpdateContext, update_table_metadata, @@ -82,7 +83,11 @@ SortField, SortOrder, ) -from pyiceberg.transforms import BucketTransform, IdentityTransform +from pyiceberg.transforms import ( + BucketTransform, + IdentityTransform, +) +from pyiceberg.typedef import Record from pyiceberg.types import ( BinaryType, BooleanType, @@ -1139,3 +1144,85 @@ def test_serialize_commit_table_request() -> None: deserialized_request = CommitTableRequest.model_validate_json(request.model_dump_json()) assert request == deserialized_request + + +def test_partition_for_demo() -> None: + import pyarrow as pa + + test_pa_schema = pa.schema([('year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) + test_schema = Schema( + 
NestedField(field_id=1, name='year', field_type=StringType(), required=False), + NestedField(field_id=2, name='n_legs', field_type=IntegerType(), required=True), + NestedField(field_id=3, name='animal', field_type=StringType(), required=False), + schema_id=1, + ) + test_data = { + 'year': [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021], + 'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100], + 'animal': ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"], + } + arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) + partition_spec = PartitionSpec( + PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), + PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), + ) + result = _determine_partitions(partition_spec, test_schema, arrow_table) + assert {table_partition.partition_key.partition for table_partition in result} == { + Record(n_legs_identity=2, year_identity=2020), + Record(n_legs_identity=100, year_identity=2021), + Record(n_legs_identity=4, year_identity=2021), + Record(n_legs_identity=4, year_identity=2022), + Record(n_legs_identity=2, year_identity=2022), + Record(n_legs_identity=5, year_identity=2019), + } + assert ( + pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]).num_rows == arrow_table.num_rows + ) + + +def test_identity_partition_on_multi_columns() -> None: + import pyarrow as pa + + test_pa_schema = pa.schema([('born_year', pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) + test_schema = Schema( + NestedField(field_id=1, name='born_year', field_type=StringType(), required=False), + NestedField(field_id=2, name='n_legs', field_type=IntegerType(), required=True), + NestedField(field_id=3, name='animal', field_type=StringType(), required=False), + schema_id=1, + ) + # 5 partitions, 6 unique row values, 12 rows + test_rows = [ + (2021, 4, "Dog"), + (2022, 4, "Horse"), + (2022, 4, "Another Horse"), + (2021, 100, "Centipede"), + (None, 4, "Kirin"), + (2021, None, "Fish"), + ] * 2 + expected = {Record(n_legs_identity=test_rows[i][1], year_identity=test_rows[i][0]) for i in range(len(test_rows))} + partition_spec = PartitionSpec( + PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"), + PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"), + ) + import random + + # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all + for _ in range(1000): + random.shuffle(test_rows) + test_data = { + 'born_year': [row[0] for row in test_rows], + 'n_legs': [row[1] for row in test_rows], + 'animal': [row[2] for row in test_rows], + } + arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema) + + result = _determine_partitions(partition_spec, test_schema, arrow_table) + + assert {table_partition.partition_key.partition for table_partition in result} == expected + concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result]) + assert concatenated_arrow_table.num_rows == arrow_table.num_rows + assert concatenated_arrow_table.sort_by([ + ('born_year', 'ascending'), + ('n_legs', 'ascending'), + ('animal', 'ascending'), + ]) == arrow_table.sort_by([('born_year', 'ascending'), ('n_legs', 'ascending'), ('animal', 'ascending')])
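# Illustrative sketch (not part of the patch): the grouping trick exercised by the tests above,
# and implemented in _determine_partitions, can be reproduced with plain PyArrow. Sort the table
# by the partition columns, then take sort_indices of the reversed order; the first index of each
# run is the offset of a partition group in the sorted table. Column names and data are made up.
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    "n_legs": [2, 2, 4, 4, 4, 5, 100],
    "animal": ["Flamingo", "Parrot", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
})
partition_cols = ["n_legs"]

# Group rows that share a partition value by sorting on the partition columns.
sorted_table = table.sort_by([(col, "ascending") for col in partition_cols])

# Reverse-order sort indices: within each group the smallest index equals the group's offset in
# sorted_table, and walking the indices group by group yields the group sizes.
reversed_indices = pc.sort_indices(
    sorted_table,
    sort_keys=[(col, "descending") for col in partition_cols],
    null_placement="at_start",
).to_pylist()

slice_instructions = []
last = len(reversed_indices)
ptr = 0
while ptr < len(reversed_indices):
    offset = reversed_indices[ptr]
    length = last - offset
    slice_instructions.append({"offset": offset, "length": length})
    last = offset
    ptr += length

for inst in slice_instructions:
    print(sorted_table.slice(**inst).to_pydict())
# -> {'n_legs': [100], ...}, {'n_legs': [5], ...}, {'n_legs': [4, 4, 4], ...}, {'n_legs': [2, 2], ...}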