diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md
index 93b198c328..1ca071f009 100644
--- a/mkdocs/docs/configuration.md
+++ b/mkdocs/docs/configuration.md
@@ -232,6 +232,15 @@ catalog:
     s3.secret-access-key: password
 ```
 
+When using Hive 2.x, make sure to set the compatibility flag:
+
+```yaml
+catalog:
+  default:
+...
+    hive.hive2-compatible: true
+```
+
 ## Glue Catalog
 
 Your AWS credentials can be passed directly through the Python API.
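The same flag can also be set programmatically when loading the catalog. A minimal sketch (the catalog name, URI, and credentials below are illustrative, not part of this patch; only `hive.hive2-compatible` is):

```python
from pyiceberg.catalog import load_catalog

catalog = load_catalog(
    "default",
    **{
        "type": "hive",
        "uri": "thrift://localhost:9083",  # illustrative metastore address
        # Downgrade timestamptz columns to plain "timestamp" in the Hive
        # storage descriptor; Hive 2.x has no "timestamp with local time zone".
        "hive.hive2-compatible": "true",
    },
)
```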
diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index e7532677aa..c3c2fdafc6 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -65,6 +65,7 @@
 from pyiceberg.table import (
     CommitTableRequest,
     CommitTableResponse,
+    PropertyUtil,
     Table,
     update_table_metadata,
 )
@@ -162,7 +163,7 @@ def primitive(self, primitive: PrimitiveType) -> str:
         if isinstance(primitive, DecimalType):
             return f"decimal({primitive.precision},{primitive.scale})"
         if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
-            return str(primitive_type.root)
+            return str(primitive)
         return GLUE_PRIMITIVE_TYPES[primitive_type]
 
 
@@ -344,7 +345,7 @@ def _update_glue_table(self, database_name: str, table_name: str, table_input: T
             self.glue.update_table(
                 DatabaseName=database_name,
                 TableInput=table_input,
-                SkipArchive=self.properties.get(GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
+                SkipArchive=PropertyUtil.property_as_bool(self.properties, GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
                 VersionId=version_id,
             )
         except self.glue.exceptions.EntityNotFoundException as e:
diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py
index 359bdef595..b504da9a73 100644
--- a/pyiceberg/catalog/hive.py
+++ b/pyiceberg/catalog/hive.py
@@ -74,7 +74,7 @@
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
 from pyiceberg.schema import Schema, SchemaVisitor, visit
 from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, TableProperties, update_table_metadata
+from pyiceberg.table import CommitTableRequest, CommitTableResponse, PropertyUtil, Table, TableProperties, update_table_metadata
 from pyiceberg.table.metadata import new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -95,6 +95,7 @@
     StringType,
     StructType,
     TimestampType,
+    TimestamptzType,
     TimeType,
     UUIDType,
 )
@@ -103,25 +104,13 @@
     import pyarrow as pa
 
-# Replace by visitor
-hive_types = {
-    BooleanType: "boolean",
-    IntegerType: "int",
-    LongType: "bigint",
-    FloatType: "float",
-    DoubleType: "double",
-    DateType: "date",
-    TimeType: "string",
-    TimestampType: "timestamp",
-    StringType: "string",
-    UUIDType: "string",
-    BinaryType: "binary",
-    FixedType: "binary",
-}
-
 COMMENT = "comment"
 OWNER = "owner"
 
+# If set to true, HiveCatalog will operate in Hive2 compatibility mode
+HIVE2_COMPATIBLE = "hive.hive2-compatible"
+HIVE2_COMPATIBLE_DEFAULT = False
+
 
 class _HiveClient:
     """Helper class to nicely open and close the transport."""
@@ -151,10 +140,15 @@ def __exit__(
         self._transport.close()
 
 
-def _construct_hive_storage_descriptor(schema: Schema, location: Optional[str]) -> StorageDescriptor:
+def _construct_hive_storage_descriptor(
+    schema: Schema, location: Optional[str], hive2_compatible: bool = False
+) -> StorageDescriptor:
     ser_de_info = SerDeInfo(serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
     return StorageDescriptor(
-        [FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter()), field.doc) for field in schema.fields],
+        [
+            FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter(hive2_compatible)), field.doc)
+            for field in schema.fields
+        ],
         location,
         "org.apache.hadoop.mapred.FileInputFormat",
         "org.apache.hadoop.mapred.FileOutputFormat",
@@ -199,6 +193,7 @@ def _annotate_namespace(database: HiveDatabase, properties: Properties) -> HiveD
     DateType: "date",
     TimeType: "string",
     TimestampType: "timestamp",
+    TimestamptzType: "timestamp with local time zone",
     StringType: "string",
     UUIDType: "string",
     BinaryType: "binary",
@@ -207,6 +202,11 @@
 
 
 class SchemaToHiveConverter(SchemaVisitor[str]):
+    hive2_compatible: bool
+
+    def __init__(self, hive2_compatible: bool):
+        self.hive2_compatible = hive2_compatible
+
     def schema(self, schema: Schema, struct_result: str) -> str:
         return struct_result
 
@@ -226,6 +226,9 @@ def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
 
     def primitive(self, primitive: PrimitiveType) -> str:
         if isinstance(primitive, DecimalType):
            return f"decimal({primitive.precision},{primitive.scale})"
+        elif self.hive2_compatible and isinstance(primitive, TimestamptzType):
+            # Hive2 doesn't support timestamp with local time zone
+            return "timestamp"
         else:
             return HIVE_PRIMITIVE_TYPES[type(primitive)]
@@ -314,7 +317,9 @@ def create_table(
             owner=properties[OWNER] if properties and OWNER in properties else getpass.getuser(),
             createTime=current_time_millis // 1000,
             lastAccessTime=current_time_millis // 1000,
-            sd=_construct_hive_storage_descriptor(schema, location),
+            sd=_construct_hive_storage_descriptor(
+                schema, location, PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT)
+            ),
             tableType=EXTERNAL_TABLE,
             parameters=_construct_parameters(metadata_location),
         )
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 2dbc32d893..ac19c1a538 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -251,6 +251,12 @@ def property_as_int(properties: Dict[str, str], property_name: str, default: Opt
         else:
             return default
 
+    @staticmethod
+    def property_as_bool(properties: Dict[str, str], property_name: str, default: bool) -> bool:
+        if value := properties.get(property_name):
+            return value.lower() == "true"
+        return default
+
 
 class Transaction:
     _table: Table
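`PropertyUtil.property_as_bool` is now the single parsing point for both `hive.hive2-compatible` and Glue's `SkipArchive`. Here is a standalone mirror of the helper to illustrate its semantics (re-declared locally so the snippet runs without the patched package installed):

```python
from typing import Dict


def property_as_bool(properties: Dict[str, str], property_name: str, default: bool) -> bool:
    # Same logic as PropertyUtil.property_as_bool above: "true" matches
    # case-insensitively, any other non-empty string is False, and a missing
    # key falls back to the default.
    if value := properties.get(property_name):
        return value.lower() == "true"
    return default


assert property_as_bool({"hive.hive2-compatible": "TRUE"}, "hive.hive2-compatible", False) is True
assert property_as_bool({"hive.hive2-compatible": "no"}, "hive.hive2-compatible", True) is False
assert property_as_bool({}, "hive.hive2-compatible", False) is False  # missing key -> default
```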
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index e59b7599bc..a8c904d646 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -61,11 +61,24 @@
 from pyiceberg.transforms import BucketTransform, IdentityTransform
 from pyiceberg.typedef import UTF8
 from pyiceberg.types import (
+    BinaryType,
     BooleanType,
+    DateType,
+    DecimalType,
+    DoubleType,
+    FixedType,
+    FloatType,
     IntegerType,
+    ListType,
     LongType,
+    MapType,
     NestedField,
     StringType,
+    StructType,
+    TimestampType,
+    TimestamptzType,
+    TimeType,
+    UUIDType,
 )
 
 HIVE_CATALOG_NAME = "hive"
@@ -181,15 +194,20 @@ def test_check_number_of_namespaces(table_schema_simple: Schema) -> None:
         catalog.create_table("table", schema=table_schema_simple)
 
 
+@pytest.mark.parametrize("hive2_compatible", [True, False])
 @patch("time.time", MagicMock(return_value=12345))
-def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
+def test_create_table(
+    table_schema_with_all_types: Schema, hive_database: HiveDatabase, hive_table: HiveTable, hive2_compatible: bool
+) -> None:
     catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
+    if hive2_compatible:
+        catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, **{"hive.hive2-compatible": "true"})
 
     catalog._client = MagicMock()
     catalog._client.__enter__().create_table.return_value = None
     catalog._client.__enter__().get_table.return_value = hive_table
     catalog._client.__enter__().get_database.return_value = hive_database
-    catalog.create_table(("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg"})
+    catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})
 
     called_hive_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
     # This one is generated within the function itself, so we need to extract
@@ -207,9 +225,27 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
         retention=None,
         sd=StorageDescriptor(
             cols=[
-                FieldSchema(name="foo", type="string", comment=None),
-                FieldSchema(name="bar", type="int", comment=None),
-                FieldSchema(name="baz", type="boolean", comment=None),
+                FieldSchema(name='boolean', type='boolean', comment=None),
+                FieldSchema(name='integer', type='int', comment=None),
+                FieldSchema(name='long', type='bigint', comment=None),
+                FieldSchema(name='float', type='float', comment=None),
+                FieldSchema(name='double', type='double', comment=None),
+                FieldSchema(name='decimal', type='decimal(32,3)', comment=None),
+                FieldSchema(name='date', type='date', comment=None),
+                FieldSchema(name='time', type='string', comment=None),
+                FieldSchema(name='timestamp', type='timestamp', comment=None),
+                FieldSchema(
+                    name='timestamptz',
+                    type='timestamp' if hive2_compatible else 'timestamp with local time zone',
+                    comment=None,
+                ),
+                FieldSchema(name='string', type='string', comment=None),
+                FieldSchema(name='uuid', type='string', comment=None),
+                FieldSchema(name='fixed', type='binary', comment=None),
+                FieldSchema(name='binary', type='binary', comment=None),
+                FieldSchema(name='list', type='array<string>', comment=None),
+                FieldSchema(name='map', type='map<string,int>', comment=None),
+                FieldSchema(name='struct', type='struct<inner_string:string, inner_int:int>', comment=None),
             ],
             location=f"{hive_database.locationUri}/table",
             inputFormat="org.apache.hadoop.mapred.FileInputFormat",
@@ -266,12 +302,46 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
         location=metadata.location,
         table_uuid=metadata.table_uuid,
         last_updated_ms=metadata.last_updated_ms,
-        last_column_id=3,
+        last_column_id=22,
         schemas=[
             Schema(
-                NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
-                NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
-                NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+                NestedField(field_id=1, name='boolean', field_type=BooleanType(), required=True),
+                NestedField(field_id=2, name='integer', field_type=IntegerType(), required=True),
+                NestedField(field_id=3, name='long', field_type=LongType(), required=True),
+                NestedField(field_id=4, name='float', field_type=FloatType(), required=True),
+                NestedField(field_id=5, name='double', field_type=DoubleType(), required=True),
+                NestedField(field_id=6, name='decimal', field_type=DecimalType(precision=32, scale=3), required=True),
+                NestedField(field_id=7, name='date', field_type=DateType(), required=True),
+                NestedField(field_id=8, name='time', field_type=TimeType(), required=True),
+                NestedField(field_id=9, name='timestamp', field_type=TimestampType(), required=True),
+                NestedField(field_id=10, name='timestamptz', field_type=TimestamptzType(), required=True),
+                NestedField(field_id=11, name='string', field_type=StringType(), required=True),
+                NestedField(field_id=12, name='uuid', field_type=UUIDType(), required=True),
+                NestedField(field_id=13, name='fixed', field_type=FixedType(length=12), required=True),
+                NestedField(field_id=14, name='binary', field_type=BinaryType(), required=True),
+                NestedField(
+                    field_id=15,
+                    name='list',
+                    field_type=ListType(type='list', element_id=18, element_type=StringType(), element_required=True),
+                    required=True,
+                ),
+                NestedField(
+                    field_id=16,
+                    name='map',
+                    field_type=MapType(
+                        type='map', key_id=19, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True
+                    ),
+                    required=True,
+                ),
+                NestedField(
+                    field_id=17,
+                    name='struct',
+                    field_type=StructType(
+                        NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False),
+                        NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True),
+                    ),
+                    required=True,
+                ),
                 schema_id=0,
                 identifier_field_ids=[2],
             )
diff --git a/tests/conftest.py b/tests/conftest.py
index 4a820fedec..7da0a0a85a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -69,7 +69,9 @@
     BinaryType,
     BooleanType,
     DateType,
+    DecimalType,
     DoubleType,
+    FixedType,
     FloatType,
     IntegerType,
     ListType,
@@ -78,6 +80,9 @@
     NestedField,
     StringType,
     StructType,
+    TimestampType,
+    TimestamptzType,
+    TimeType,
     UUIDType,
 )
 from pyiceberg.utils.datetime import datetime_to_millis
@@ -266,6 +271,54 @@ def table_schema_nested_with_struct_key_map() -> Schema:
     )
 
 
+@pytest.fixture(scope="session")
+def table_schema_with_all_types() -> Schema:
+    return schema.Schema(
+        NestedField(field_id=1, name="boolean", field_type=BooleanType(), required=True),
+        NestedField(field_id=2, name="integer", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="long", field_type=LongType(), required=True),
+        NestedField(field_id=4, name="float", field_type=FloatType(), required=True),
+        NestedField(field_id=5, name="double", field_type=DoubleType(), required=True),
+        NestedField(field_id=6, name="decimal", field_type=DecimalType(32, 3), required=True),
+        NestedField(field_id=7, name="date", field_type=DateType(), required=True),
+        NestedField(field_id=8, name="time", field_type=TimeType(), required=True),
+        NestedField(field_id=9, name="timestamp", field_type=TimestampType(), required=True),
+        NestedField(field_id=10, name="timestamptz", field_type=TimestamptzType(), required=True),
+        NestedField(field_id=11, name="string", field_type=StringType(), required=True),
+        NestedField(field_id=12, name="uuid", field_type=UUIDType(), required=True),
+        NestedField(field_id=14, name="fixed", field_type=FixedType(12), required=True),
+        NestedField(field_id=13, name="binary", field_type=BinaryType(), required=True),
+        NestedField(
+            field_id=15,
+            name="list",
+            field_type=ListType(element_id=16, element_type=StringType(), element_required=True),
+            required=True,
+        ),
+        NestedField(
+            field_id=17,
+            name="map",
+            field_type=MapType(
+                key_id=18,
+                key_type=StringType(),
+                value_id=19,
+                value_type=IntegerType(),
+                value_required=True,
+            ),
+            required=True,
+        ),
+        NestedField(
+            field_id=20,
+            name="struct",
+            field_type=StructType(
+                NestedField(field_id=21, name="inner_string", field_type=StringType(), required=False),
+                NestedField(field_id=22, name="inner_int", field_type=IntegerType(), required=True),
+            ),
+        ),
+        schema_id=1,
+        identifier_field_ids=[2],
+    )
+
+
 @pytest.fixture(scope="session")
 def pyarrow_schema_simple_without_ids() -> "pa.Schema":
     import pyarrow as pa
@@ -1953,6 +2006,20 @@ def session_catalog() -> Catalog:
     )
 
 
+@pytest.fixture(scope="session")
+def session_catalog_hive() -> Catalog:
+    return load_catalog(
+        "local",
+        **{
+            "type": "hive",
+            "uri": "http://localhost:9083",
+            "s3.endpoint": "http://localhost:9000",
+            "s3.access-key-id": "admin",
+            "s3.secret-access-key": "password",
+        },
+    )
+
+
 @pytest.fixture(scope="session")
 def spark() -> "SparkSession":
     import importlib.metadata
@@ -1984,6 +2051,13 @@ def spark() -> "SparkSession":
         .config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000")
         .config("spark.sql.catalog.integration.s3.path-style-access", "true")
         .config("spark.sql.defaultCatalog", "integration")
+        .config("spark.sql.catalog.hive", "org.apache.iceberg.spark.SparkCatalog")
+        .config("spark.sql.catalog.hive.type", "hive")
+        .config("spark.sql.catalog.hive.uri", "http://localhost:9083")
+        .config("spark.sql.catalog.hive.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
+        .config("spark.sql.catalog.hive.warehouse", "s3://warehouse/hive/")
+        .config("spark.sql.catalog.hive.s3.endpoint", "http://localhost:9000")
+        .config("spark.sql.catalog.hive.s3.path-style-access", "true")
         .getOrCreate()
     )
diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py
index 62d3bb1172..e1526d2a5e 100644
--- a/tests/integration/test_writes/test_writes.py
+++ b/tests/integration/test_writes/test_writes.py
@@ -33,6 +33,7 @@
 from pytest_mock.plugin import MockerFixture
 
 from pyiceberg.catalog import Catalog
+from pyiceberg.catalog.hive import HiveCatalog
 from pyiceberg.catalog.sql import SqlCatalog
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.table import TableProperties, _dataframe_to_data_files
@@ -747,3 +748,22 @@ def get_metadata_entries_count(identifier: str) -> int:
     tbl.transaction().set_properties({"test": "2"}).commit_transaction()
     tbl.append(arrow_table_with_null)
     assert get_metadata_entries_count(identifier) == 4
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_hive_catalog_storage_descriptor(
+    session_catalog_hive: HiveCatalog,
+    pa_schema: pa.Schema,
+    arrow_table_with_null: pa.Table,
+    spark: SparkSession,
+    format_version: int,
+) -> None:
+    tbl = _create_table(
+        session_catalog_hive, "default.test_storage_descriptor", {"format-version": format_version}, [arrow_table_with_null]
+    )
+
+    # check if pyiceberg can read the table
+    assert len(tbl.scan().to_arrow()) == 3
+    # check if spark can read the table
+    assert spark.sql("SELECT * FROM hive.default.test_storage_descriptor").count() == 3
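To see the Hive 2 downgrade in isolation, the converter can be driven directly, the same way `_construct_hive_storage_descriptor` drives it. A minimal sketch, assuming a pyiceberg checkout with this patch applied:

```python
from pyiceberg.catalog.hive import SchemaToHiveConverter
from pyiceberg.schema import visit
from pyiceberg.types import TimestamptzType

# Default mode keeps the zone-aware Hive type for an Iceberg timestamptz column...
assert visit(TimestamptzType(), SchemaToHiveConverter(hive2_compatible=False)) == "timestamp with local time zone"
# ...while Hive2 compatibility mode falls back to plain "timestamp",
# which Hive 2.x understands.
assert visit(TimestamptzType(), SchemaToHiveConverter(hive2_compatible=True)) == "timestamp"
```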