9 changes: 9 additions & 0 deletions mkdocs/docs/configuration.md
@@ -232,6 +232,15 @@ catalog:
s3.secret-access-key: password
```

When using Hive 2.x, make sure to set the compatibility flag:

```yaml
catalog:
default:
...
hive.hive2-compatible: true
```
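
The same flag can also be set programmatically when constructing the catalog, as the new test in this PR does. A minimal sketch; the catalog name and metastore URI below are placeholders, not values from this change:

```python
from pyiceberg.catalog.hive import HiveCatalog

# "thrift://localhost:9083" is an assumed metastore URI for illustration.
# Property values are strings; they are parsed with PropertyUtil.property_as_bool.
catalog = HiveCatalog(
    "hive",
    uri="thrift://localhost:9083",
    **{"hive.hive2-compatible": "true"},
)
```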

## Glue Catalog

Your AWS credentials can be passed directly through the Python API.
5 changes: 3 additions & 2 deletions pyiceberg/catalog/glue.py
@@ -65,6 +65,7 @@
from pyiceberg.table import (
CommitTableRequest,
CommitTableResponse,
PropertyUtil,
Table,
update_table_metadata,
)
@@ -162,7 +163,7 @@ def primitive(self, primitive: PrimitiveType) -> str:
if isinstance(primitive, DecimalType):
return f"decimal({primitive.precision},{primitive.scale})"
if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
return str(primitive_type.root)
return str(primitive)
return GLUE_PRIMITIVE_TYPES[primitive_type]


@@ -344,7 +345,7 @@ def _update_glue_table(self, database_name: str, table_name: str, table_input: T
self.glue.update_table(
DatabaseName=database_name,
TableInput=table_input,
SkipArchive=self.properties.get(GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
SkipArchive=PropertyUtil.property_as_bool(self.properties, GLUE_SKIP_ARCHIVE, GLUE_SKIP_ARCHIVE_DEFAULT),
VersionId=version_id,
)
except self.glue.exceptions.EntityNotFoundException as e:
45 changes: 25 additions & 20 deletions pyiceberg/catalog/hive.py
@@ -74,7 +74,7 @@
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.serializers import FromInputFile
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, TableProperties, update_table_metadata
from pyiceberg.table import CommitTableRequest, CommitTableResponse, PropertyUtil, Table, TableProperties, update_table_metadata
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
@@ -95,6 +95,7 @@
StringType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)
@@ -103,25 +104,13 @@
import pyarrow as pa


# Replace by visitor
hive_types = {
BooleanType: "boolean",
IntegerType: "int",
LongType: "bigint",
FloatType: "float",
DoubleType: "double",
DateType: "date",
TimeType: "string",
TimestampType: "timestamp",
StringType: "string",
UUIDType: "string",
BinaryType: "binary",
FixedType: "binary",
}

COMMENT = "comment"
OWNER = "owner"

# If set to true, HiveCatalog will operate in Hive2 compatibility mode
HIVE2_COMPATIBLE = "hive.hive2-compatible"
HIVE2_COMPATIBLE_DEFAULT = False


class _HiveClient:
"""Helper class to nicely open and close the transport."""
@@ -151,10 +140,15 @@ def __exit__(
self._transport.close()


def _construct_hive_storage_descriptor(schema: Schema, location: Optional[str]) -> StorageDescriptor:
def _construct_hive_storage_descriptor(
schema: Schema, location: Optional[str], hive2_compatible: bool = False
) -> StorageDescriptor:
ser_de_info = SerDeInfo(serializationLib="org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")
return StorageDescriptor(
[FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter()), field.doc) for field in schema.fields],
[
FieldSchema(field.name, visit(field.field_type, SchemaToHiveConverter(hive2_compatible)), field.doc)
for field in schema.fields
],
location,
"org.apache.hadoop.mapred.FileInputFormat",
"org.apache.hadoop.mapred.FileOutputFormat",
@@ -199,6 +193,7 @@ def _annotate_namespace(database: HiveDatabase, properties: Properties) -> HiveD
DateType: "date",
TimeType: "string",
TimestampType: "timestamp",
TimestamptzType: "timestamp with local time zone",
StringType: "string",
UUIDType: "string",
BinaryType: "binary",
@@ -207,6 +202,11 @@ def _annotate_namespace(database: HiveDatabase, properties: Properties) -> HiveD


class SchemaToHiveConverter(SchemaVisitor[str]):
hive2_compatible: bool

def __init__(self, hive2_compatible: bool):
self.hive2_compatible = hive2_compatible

def schema(self, schema: Schema, struct_result: str) -> str:
return struct_result

@@ -226,6 +226,9 @@ def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
def primitive(self, primitive: PrimitiveType) -> str:
if isinstance(primitive, DecimalType):
return f"decimal({primitive.precision},{primitive.scale})"
elif self.hive2_compatible and isinstance(primitive, TimestamptzType):
# Hive2 doesn't support timestamp with local time zone
return "timestamp"
else:
return HIVE_PRIMITIVE_TYPES[type(primitive)]

@@ -314,7 +317,9 @@ def create_table(
owner=properties[OWNER] if properties and OWNER in properties else getpass.getuser(),
createTime=current_time_millis // 1000,
lastAccessTime=current_time_millis // 1000,
sd=_construct_hive_storage_descriptor(schema, location),
sd=_construct_hive_storage_descriptor(
schema, location, PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT)
),
tableType=EXTERNAL_TABLE,
parameters=_construct_parameters(metadata_location),
)
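To illustrate the converter change in `hive.py` above: `SchemaToHiveConverter` now takes a `hive2_compatible` flag and downgrades `timestamptz` when it is set. A small sketch, assuming the `visit` import shown in the diff; this is not part of the PR itself:

```python
from pyiceberg.catalog.hive import SchemaToHiveConverter
from pyiceberg.schema import visit
from pyiceberg.types import TimestamptzType

# With compatibility enabled, timestamptz degrades to plain "timestamp";
# otherwise it maps to Hive 3's "timestamp with local time zone".
visit(TimestamptzType(), SchemaToHiveConverter(hive2_compatible=True))   # "timestamp"
visit(TimestamptzType(), SchemaToHiveConverter(hive2_compatible=False))  # "timestamp with local time zone"
```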
6 changes: 6 additions & 0 deletions pyiceberg/table/__init__.py
@@ -251,6 +251,12 @@ def property_as_int(properties: Dict[str, str], property_name: str, default: Opt
else:
return default

@staticmethod
def property_as_bool(properties: Dict[str, str], property_name: str, default: bool) -> bool:
if value := properties.get(property_name):
return value.lower() == "true"
return default


class Transaction:
_table: Table
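For reference, the new `PropertyUtil.property_as_bool` helper added above parses string-valued catalog properties; a quick sketch of its behaviour (the property name is just an example):

```python
from pyiceberg.table import PropertyUtil

# Only a case-insensitive "true" enables the flag; anything else is False,
# and a missing key falls back to the supplied default.
PropertyUtil.property_as_bool({"hive.hive2-compatible": "True"}, "hive.hive2-compatible", False)  # True
PropertyUtil.property_as_bool({"hive.hive2-compatible": "no"}, "hive.hive2-compatible", False)    # False
PropertyUtil.property_as_bool({}, "hive.hive2-compatible", False)                                 # False (default)
```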
88 changes: 79 additions & 9 deletions tests/catalog/test_hive.py
@@ -61,11 +61,24 @@
from pyiceberg.transforms import BucketTransform, IdentityTransform
from pyiceberg.typedef import UTF8
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
StringType,
StructType,
TimestampType,
TimestamptzType,
TimeType,
UUIDType,
)

HIVE_CATALOG_NAME = "hive"
@@ -181,15 +194,20 @@ def test_check_number_of_namespaces(table_schema_simple: Schema) -> None:
catalog.create_table("table", schema=table_schema_simple)


@pytest.mark.parametrize("hive2_compatible", [True, False])
@patch("time.time", MagicMock(return_value=12345))
def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
def test_create_table(
table_schema_with_all_types: Schema, hive_database: HiveDatabase, hive_table: HiveTable, hive2_compatible: bool
) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
if hive2_compatible:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL, **{"hive.hive2-compatible": "true"})

catalog._client = MagicMock()
catalog._client.__enter__().create_table.return_value = None
catalog._client.__enter__().get_table.return_value = hive_table
catalog._client.__enter__().get_database.return_value = hive_database
catalog.create_table(("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg"})
catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})

called_hive_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
# This one is generated within the function itself, so we need to extract
@@ -207,9 +225,27 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
retention=None,
sd=StorageDescriptor(
cols=[
FieldSchema(name="foo", type="string", comment=None),
FieldSchema(name="bar", type="int", comment=None),
FieldSchema(name="baz", type="boolean", comment=None),
FieldSchema(name='boolean', type='boolean', comment=None),
FieldSchema(name='integer', type='int', comment=None),
FieldSchema(name='long', type='bigint', comment=None),
FieldSchema(name='float', type='float', comment=None),
FieldSchema(name='double', type='double', comment=None),
FieldSchema(name='decimal', type='decimal(32,3)', comment=None),
FieldSchema(name='date', type='date', comment=None),
FieldSchema(name='time', type='string', comment=None),
FieldSchema(name='timestamp', type='timestamp', comment=None),
FieldSchema(
name='timestamptz',
type='timestamp' if hive2_compatible else 'timestamp with local time zone',
comment=None,
),
FieldSchema(name='string', type='string', comment=None),
FieldSchema(name='uuid', type='string', comment=None),
FieldSchema(name='fixed', type='binary', comment=None),
FieldSchema(name='binary', type='binary', comment=None),
FieldSchema(name='list', type='array<string>', comment=None),
FieldSchema(name='map', type='map<string,int>', comment=None),
FieldSchema(name='struct', type='struct<inner_string:string,inner_int:int>', comment=None),
],
location=f"{hive_database.locationUri}/table",
inputFormat="org.apache.hadoop.mapred.FileInputFormat",
@@ -266,12 +302,46 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
location=metadata.location,
table_uuid=metadata.table_uuid,
last_updated_ms=metadata.last_updated_ms,
last_column_id=3,
last_column_id=22,
schemas=[
Schema(
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
NestedField(field_id=1, name='boolean', field_type=BooleanType(), required=True),
NestedField(field_id=2, name='integer', field_type=IntegerType(), required=True),
NestedField(field_id=3, name='long', field_type=LongType(), required=True),
NestedField(field_id=4, name='float', field_type=FloatType(), required=True),
NestedField(field_id=5, name='double', field_type=DoubleType(), required=True),
NestedField(field_id=6, name='decimal', field_type=DecimalType(precision=32, scale=3), required=True),
NestedField(field_id=7, name='date', field_type=DateType(), required=True),
NestedField(field_id=8, name='time', field_type=TimeType(), required=True),
NestedField(field_id=9, name='timestamp', field_type=TimestampType(), required=True),
NestedField(field_id=10, name='timestamptz', field_type=TimestamptzType(), required=True),
NestedField(field_id=11, name='string', field_type=StringType(), required=True),
NestedField(field_id=12, name='uuid', field_type=UUIDType(), required=True),
NestedField(field_id=13, name='fixed', field_type=FixedType(length=12), required=True),
NestedField(field_id=14, name='binary', field_type=BinaryType(), required=True),
NestedField(
field_id=15,
name='list',
field_type=ListType(type='list', element_id=18, element_type=StringType(), element_required=True),
required=True,
),
NestedField(
field_id=16,
name='map',
field_type=MapType(
type='map', key_id=19, key_type=StringType(), value_id=20, value_type=IntegerType(), value_required=True
),
required=True,
),
NestedField(
field_id=17,
name='struct',
field_type=StructType(
NestedField(field_id=21, name='inner_string', field_type=StringType(), required=False),
NestedField(field_id=22, name='inner_int', field_type=IntegerType(), required=True),
),
required=True,
),
schema_id=0,
identifier_field_ids=[2],
)