Skip to content

Commit 54a08f3

Browse files
authored
Replace type() calls with isinstance() (#188)
* WIP
* fix bad ifs that snuck in
* missing this one
* couple more cases
* almost all of them
* undo change
* add a comment and explain why we are doing the explicit type() call
* standardize naming
* lint
1 parent 56ce5d8 commit 54a08f3

File tree

5 files changed

+47
-62
lines changed

5 files changed

+47
-62
lines changed

pyiceberg/expressions/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -459,7 +459,7 @@ def as_bound(self) -> Type[BoundNotNull[L]]:
459459
class BoundIsNaN(BoundUnaryPredicate[L]):
460460
def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221
461461
bound_type = term.ref().field.field_type
462-
if type(bound_type) in {FloatType, DoubleType}:
462+
if isinstance(bound_type, (FloatType, DoubleType)):
463463
return super().__new__(cls)
464464
return AlwaysFalse()
465465

@@ -475,7 +475,7 @@ def as_unbound(self) -> Type[IsNaN]:
475475
class BoundNotNaN(BoundUnaryPredicate[L]):
476476
def __new__(cls, term: BoundTerm[L]) -> BooleanExpression: # type: ignore # pylint: disable=W0221
477477
bound_type = term.ref().field.field_type
478-
if type(bound_type) in {FloatType, DoubleType}:
478+
if isinstance(bound_type, (FloatType, DoubleType)):
479479
return super().__new__(cls)
480480
return AlwaysTrue()
481481

pyiceberg/expressions/visitors.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -620,7 +620,7 @@ def visit_not_null(self, term: BoundTerm[L]) -> bool:
620620
# lowerBound is null if all partition values are null
621621
all_null = self.partition_fields[pos].contains_null is True and self.partition_fields[pos].lower_bound is None
622622

623-
if all_null and type(term.ref().field.field_type) in {DoubleType, FloatType}:
623+
if all_null and isinstance(term.ref().field.field_type, (DoubleType, FloatType)):
624624
# floating point types may include NaN values, which we check separately.
625625
# In case bounds don't include NaN value, contains_nan needs to be checked against.
626626
all_null = self.partition_fields[pos].contains_nan is False

pyiceberg/table/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -150,8 +150,9 @@ def _append_updates(self, *new_updates: TableUpdate) -> Transaction:
150150
Transaction object with the new updates appended.
151151
"""
152152
for new_update in new_updates:
153+
# explicitly get type of new_update as new_update is an instantiated class
153154
type_new_update = type(new_update)
154-
if any(type(update) == type_new_update for update in self._updates):
155+
if any(isinstance(update, type_new_update) for update in self._updates):
155156
raise ValueError(f"Updates in a single commit need to be unique, duplicate: {type_new_update}")
156157
self._updates = self._updates + new_updates
157158
return self
@@ -168,9 +169,10 @@ def _append_requirements(self, *new_requirements: TableRequirement) -> Transacti
168169
Returns:
169170
Transaction object with the new requirements appended.
170171
"""
171-
for requirement in new_requirements:
172-
type_new_requirement = type(requirement)
173-
if any(type(requirement) == type_new_requirement for update in self._requirements):
172+
for new_requirement in new_requirements:
173+
# explicitly get type of new_requirement as new_requirement is an instantiated class
174+
type_new_requirement = type(new_requirement)
175+
if any(isinstance(requirement, type_new_requirement) for requirement in self._requirements):
174176
raise ValueError(f"Requirements in a single commit need to be unique, duplicate: {type_new_requirement}")
175177
self._requirements = self._requirements + new_requirements
176178
return self

pyiceberg/table/sorting.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ def set_null_order(cls, values: Dict[str, Any]) -> Dict[str, Any]:
114114

115115
def __str__(self) -> str:
116116
"""Return the string representation of the SortField class."""
117-
if type(self.transform) == IdentityTransform:
117+
if isinstance(self.transform, IdentityTransform):
118118
# In the case of an identity transform, we can omit the transform
119119
return f"{self.source_id} {self.direction} {self.null_order}"
120120
else:

pyiceberg/transforms.py

Lines changed: 37 additions & 54 deletions
Original file line number | Diff line number | Diff line change
@@ -220,38 +220,40 @@ def project(self, name: str, pred: BoundPredicate[L]) -> Optional[UnboundPredica
220220
return None
221221

222222
def can_transform(self, source: IcebergType) -> bool:
223-
return type(source) in {
224-
IntegerType,
225-
DateType,
226-
LongType,
227-
TimeType,
228-
TimestampType,
229-
TimestamptzType,
230-
DecimalType,
231-
StringType,
232-
FixedType,
233-
BinaryType,
234-
UUIDType,
235-
}
223+
return isinstance(
224+
source,
225+
(
226+
IntegerType,
227+
DateType,
228+
LongType,
229+
TimeType,
230+
TimestampType,
231+
TimestamptzType,
232+
DecimalType,
233+
StringType,
234+
FixedType,
235+
BinaryType,
236+
UUIDType,
237+
),
238+
)
236239

237240
def transform(self, source: IcebergType, bucket: bool = True) -> Callable[[Optional[Any]], Optional[int]]:
238-
source_type = type(source)
239-
if source_type in {IntegerType, LongType, DateType, TimeType, TimestampType, TimestamptzType}:
241+
if isinstance(source, (IntegerType, LongType, DateType, TimeType, TimestampType, TimestamptzType)):
240242

241243
def hash_func(v: Any) -> int:
242244
return mmh3.hash(struct.pack("<q", v))
243245

244-
elif source_type == DecimalType:
246+
elif isinstance(source, DecimalType):
245247

246248
def hash_func(v: Any) -> int:
247249
return mmh3.hash(decimal_to_bytes(v))
248250

249-
elif source_type in {StringType, FixedType, BinaryType}:
251+
elif isinstance(source, (StringType, FixedType, BinaryType)):
250252

251253
def hash_func(v: Any) -> int:
252254
return mmh3.hash(v)
253255

254-
elif source_type == UUIDType:
256+
elif isinstance(source, UUIDType):
255257

256258
def hash_func(v: Any) -> int:
257259
if isinstance(v, UUID):
@@ -330,13 +332,12 @@ class YearTransform(TimeTransform[S]):
330332
root: LiteralType["year"] = Field(default="year") # noqa: F821
331333

332334
def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
333-
source_type = type(source)
334-
if source_type == DateType:
335+
if isinstance(source, DateType):
335336

336337
def year_func(v: Any) -> int:
337338
return datetime.days_to_years(v)
338339

339-
elif source_type in {TimestampType, TimestamptzType}:
340+
elif isinstance(source, (TimestampType, TimestamptzType)):
340341

341342
def year_func(v: Any) -> int:
342343
return datetime.micros_to_years(v)
@@ -347,11 +348,7 @@ def year_func(v: Any) -> int:
347348
return lambda v: year_func(v) if v is not None else None
348349

349350
def can_transform(self, source: IcebergType) -> bool:
350-
return type(source) in {
351-
DateType,
352-
TimestampType,
353-
TimestamptzType,
354-
}
351+
return isinstance(source, (DateType, TimestampType, TimestamptzType))
355352

356353
@property
357354
def granularity(self) -> TimeResolution:
@@ -377,13 +374,12 @@ class MonthTransform(TimeTransform[S]):
377374
root: LiteralType["month"] = Field(default="month") # noqa: F821
378375

379376
def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
380-
source_type = type(source)
381-
if source_type == DateType:
377+
if isinstance(source, DateType):
382378

383379
def month_func(v: Any) -> int:
384380
return datetime.days_to_months(v)
385381

386-
elif source_type in {TimestampType, TimestamptzType}:
382+
elif isinstance(source, (TimestampType, TimestamptzType)):
387383

388384
def month_func(v: Any) -> int:
389385
return datetime.micros_to_months(v)
@@ -394,11 +390,7 @@ def month_func(v: Any) -> int:
394390
return lambda v: month_func(v) if v else None
395391

396392
def can_transform(self, source: IcebergType) -> bool:
397-
return type(source) in {
398-
DateType,
399-
TimestampType,
400-
TimestamptzType,
401-
}
393+
return isinstance(source, (DateType, TimestampType, TimestamptzType))
402394

403395
@property
404396
def granularity(self) -> TimeResolution:
@@ -424,13 +416,12 @@ class DayTransform(TimeTransform[S]):
424416
root: LiteralType["day"] = Field(default="day") # noqa: F821
425417

426418
def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
427-
source_type = type(source)
428-
if source_type == DateType:
419+
if isinstance(source, DateType):
429420

430421
def day_func(v: Any) -> int:
431422
return v
432423

433-
elif source_type in {TimestampType, TimestamptzType}:
424+
elif isinstance(source, (TimestampType, TimestamptzType)):
434425

435426
def day_func(v: Any) -> int:
436427
return datetime.micros_to_days(v)
@@ -441,11 +432,7 @@ def day_func(v: Any) -> int:
441432
return lambda v: day_func(v) if v else None
442433

443434
def can_transform(self, source: IcebergType) -> bool:
444-
return type(source) in {
445-
DateType,
446-
TimestampType,
447-
TimestamptzType,
448-
}
435+
return isinstance(source, (DateType, TimestampType, TimestamptzType))
449436

450437
def result_type(self, source: IcebergType) -> IcebergType:
451438
return DateType()
@@ -474,7 +461,7 @@ class HourTransform(TimeTransform[S]):
474461
root: LiteralType["hour"] = Field(default="hour") # noqa: F821
475462

476463
def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[int]]:
477-
if type(source) in {TimestampType, TimestamptzType}:
464+
if isinstance(source, (TimestampType, TimestamptzType)):
478465

479466
def hour_func(v: Any) -> int:
480467
return datetime.micros_to_hours(v)
@@ -485,10 +472,7 @@ def hour_func(v: Any) -> int:
485472
return lambda v: hour_func(v) if v else None
486473

487474
def can_transform(self, source: IcebergType) -> bool:
488-
return type(source) in {
489-
TimestampType,
490-
TimestamptzType,
491-
}
475+
return isinstance(source, (TimestampType, TimestamptzType))
492476

493477
@property
494478
def granularity(self) -> TimeResolution:
@@ -580,7 +564,7 @@ def __init__(self, width: int, **data: Any):
580564
self._width = width
581565

582566
def can_transform(self, source: IcebergType) -> bool:
583-
return type(source) in {IntegerType, LongType, StringType, BinaryType, DecimalType}
567+
return isinstance(source, (IntegerType, LongType, StringType, BinaryType, DecimalType))
584568

585569
def result_type(self, source: IcebergType) -> IcebergType:
586570
return source
@@ -616,18 +600,17 @@ def width(self) -> int:
616600
return self._width
617601

618602
def transform(self, source: IcebergType) -> Callable[[Optional[S]], Optional[S]]:
619-
source_type = type(source)
620-
if source_type in {IntegerType, LongType}:
603+
if isinstance(source, (IntegerType, LongType)):
621604

622605
def truncate_func(v: Any) -> Any:
623606
return v - v % self._width
624607

625-
elif source_type in {StringType, BinaryType}:
608+
elif isinstance(source, (StringType, BinaryType)):
626609

627610
def truncate_func(v: Any) -> Any:
628611
return v[0 : min(self._width, len(v))]
629612

630-
elif source_type == DecimalType:
613+
elif isinstance(source, DecimalType):
631614

632615
def truncate_func(v: Any) -> Any:
633616
return truncate_decimal(v, self._width)
@@ -788,9 +771,9 @@ def _truncate_array(
788771
) -> Optional[UnboundPredicate[Any]]:
789772
boundary = pred.literal
790773

791-
if type(pred) in {BoundLessThan, BoundLessThanOrEqual}:
774+
if isinstance(pred, (BoundLessThan, BoundLessThanOrEqual)):
792775
return LessThanOrEqual(Reference(name), _transform_literal(transform, boundary))
793-
elif type(pred) in {BoundGreaterThan, BoundGreaterThanOrEqual}:
776+
elif isinstance(pred, (BoundGreaterThan, BoundGreaterThanOrEqual)):
794777
return GreaterThanOrEqual(Reference(name), _transform_literal(transform, boundary))
795778
if isinstance(pred, BoundEqualTo):
796779
return EqualTo(Reference(name), _transform_literal(transform, boundary))

0 commit comments

Comments
 (0)