diff --git a/setup.cfg b/setup.cfg index 6921460..7450f41 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,7 +50,7 @@ python_requires = >=3.10 install_requires = importlib-metadata; python_version<"3.8" pydantic>=2.0.0 - datamodel-code-generator>=0.51.0 + datamodel-code-generator>=0.51.0,<0.55.0 typing_extensions pyld rdflib diff --git a/src/oold/backend/document_store.py b/src/oold/backend/document_store.py index 6982f0a..b2d78d0 100644 --- a/src/oold/backend/document_store.py +++ b/src/oold/backend/document_store.py @@ -1,15 +1,27 @@ import json import sqlite3 from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Set, Union -from oold.backend.interface import Backend, StoreResult +from oold.backend.interface import ( + Backend, + Condition, + LinkedDataFormat, + Query, + QueryParam, + ResolveParam, + ResolveResult, + StoreResult, + apply_operator, +) class SimpleDictDocumentStore(Backend): _store: Optional[Dict[str, dict]] = None + format: LinkedDataFormat = LinkedDataFormat.JSON - def __init__(self): + def __init__(self, **kwargs): + super().__init__(**kwargs) self._store = {} def resolve_iris(self, iris: List[str]) -> Dict[str, Dict]: @@ -18,13 +30,65 @@ def resolve_iris(self, iris: List[str]) -> Dict[str, Dict]: jsonld_dicts[iri] = self._store.get(iri, None) return jsonld_dicts - def store_jsonld_dicts(self, jsonld_dicts: Dict[str, Dict]) -> StoreResult: - for iri, jsonld_dict in jsonld_dicts.items(): - self._store[iri] = jsonld_dict + def store_json_dicts(self, json_dicts: Dict[str, Dict]) -> StoreResult: + for iri, json_dict in json_dicts.items(): + self._store[iri] = json_dict return StoreResult(success=True) - def query(): - pass + def _filter( + self, + key: str, + operator: str, + value: Any, + context: Optional[Dict[str, Dict]] = None, + data: Optional[Dict[str, Dict]] = None, + ) -> Set[str]: + if data is None: + data = self._store + # retrieve property mapping from context + # ToDo: use a 
jsonld expand here + # if context is not None and key in context: + # key = context[key] + matched_entities = set() + for iri, jsonld_dict in data.items(): + if key in jsonld_dict: + if apply_operator(operator, jsonld_dict[key], value): + matched_entities.add(iri) + return matched_entities + + def _query( + self, + query: Union[Query, Condition], + context: Dict = None, + data: Optional[Dict[str, Dict]] = None, + ) -> Set[str]: + print("QUERY", query) + if data is None: + data = self._store + if isinstance(query, Condition): + return self._filter(query.field, query.operator, query.value, context, data) + elif isinstance(query, Query): + c1_res = self._query(query.op1, context, data) + c2_res = self._query(query.op2, context, data) + if query.operator == "and": + # intersect the results + return c1_res & c2_res + elif query.operator == "or": + # union the results + return c1_res | c2_res + else: + raise NotImplementedError(f"Operator {query.operator} not implemented") + else: + raise ValueError("Invalid query type") + + def query(self, param: QueryParam) -> ResolveResult: + context = None + # if param.model_cls is not None: + # context = _get_schema(param.model_cls).get("@context", None) + # elif self.model_cls is not None: + # context = _get_schema(self.model_cls).get("@context", None) + iris = self._query(param.query, context) + return self.resolve(ResolveParam(iris=list(iris), model_cls=param.model_cls)) class SqliteDocumentStore(Backend): diff --git a/src/oold/backend/interface.py b/src/oold/backend/interface.py index 26c7424..6861861 100644 --- a/src/oold/backend/interface.py +++ b/src/oold/backend/interface.py @@ -1,4 +1,6 @@ +import operator as _op from abc import abstractmethod +from enum import Enum from typing import Dict, List, Optional, Type, Union from pydantic import BaseModel @@ -6,6 +8,29 @@ from oold.static import GenericLinkedBaseModel +class ComparisonOperator(str, Enum): + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + 
+_COMPARISON_FNS = { + ComparisonOperator.EQ: _op.eq, + ComparisonOperator.NE: _op.ne, + ComparisonOperator.LT: _op.lt, + ComparisonOperator.LE: _op.le, + ComparisonOperator.GT: _op.gt, + ComparisonOperator.GE: _op.ge, +} + + +def apply_operator(operator: ComparisonOperator, a, b) -> bool: + return _COMPARISON_FNS[operator](a, b) + + class SetResolverParam(BaseModel): iri: str resolver: "Resolver" @@ -31,8 +56,45 @@ class ResolveResult(BaseModel): nodes: Dict[str, Union[None, GenericLinkedBaseModel]] +class Query(BaseModel): + op1: Union["Query", "Condition"] + operator: str + op2: Union["Query", "Condition"] + + # override the & operator + def __and__(self, other): + return Query(op1=self, operator="and", op2=other) + + +class Condition(BaseModel): + field: str + operator: Optional[ComparisonOperator] = None + value: Optional[Union[str, int, float]] = None + + # override the == operator + def __eq__(self, other): + self.operator = "eq" + self.value = other + return self + + # override the & operator + def __and__(self, other): + return Query(op1=self, operator="and", op2=other) + + +class QueryParam(BaseModel): + query: Union[Query, Condition] + model_cls: Optional[Type[GenericLinkedBaseModel]] = None + + +class LinkedDataFormat(str, Enum): + JSON_LD = "JSON-LD" + JSON = "JSON" + + class Resolver(BaseModel): model_cls: Optional[Type[GenericLinkedBaseModel]] = None + format: Optional[LinkedDataFormat] = LinkedDataFormat.JSON_LD @abstractmethod def resolve_iris(self, iris: List[str]) -> Dict[str, Dict]: @@ -53,11 +115,20 @@ def resolve(self, request: ResolveParam): if jsonld_dict is None: nodes[iri] = None else: - node = model_cls.from_jsonld(jsonld_dict) + if self.format == LinkedDataFormat.JSON_LD: + node = model_cls.from_jsonld(jsonld_dict) + elif self.format == LinkedDataFormat.JSON: + node = model_cls.from_json(jsonld_dict) + else: + raise ValueError(f"Unsupported format {self.format}") nodes[iri] = node return ResolveResult(nodes=nodes) + def query(self, 
param: QueryParam) -> ResolveResult: + """Query the backend and return a ResolveResult.""" + raise NotImplementedError("Query method not implemented in Resolver subclass") + global _resolvers _resolvers = {} @@ -100,10 +171,6 @@ class StoreResult(BaseModel): success: bool -class Query(BaseModel): - pass - - class Backend(Resolver): def store(self, param: StoreParam) -> StoreResult: jsonld_dicts = {} @@ -111,17 +178,26 @@ def store(self, param: StoreParam) -> StoreResult: if node is None: jsonld_dicts[iri] = None else: - jsonld_dicts[iri] = node.to_jsonld() - return self.store_jsonld_dicts(jsonld_dicts) + if self.format == LinkedDataFormat.JSON_LD: + jsonld_dicts[iri] = node.to_jsonld() + elif self.format == LinkedDataFormat.JSON: + jsonld_dicts[iri] = node.to_json() + else: + raise ValueError(f"Unsupported format {self.format}") + if self.format == LinkedDataFormat.JSON: + return self.store_json_dicts(jsonld_dicts) + else: + return self.store_jsonld_dicts(jsonld_dicts) - @abstractmethod def store_jsonld_dicts(self, jsonld_dicts: Dict[str, Dict]) -> StoreResult: - pass + raise NotImplementedError( + "store_jsonld_dicts method not implemented in Backend subclass" + ) - @abstractmethod - def query(self, query: Query) -> ResolveResult: - """Query the backend and return a ResolveResult.""" - pass + def store_json_dicts(self, json_dicts: Dict[str, Dict]) -> StoreResult: + raise NotImplementedError( + "store_json_dicts method not implemented in Backend subclass" + ) global _backends @@ -129,6 +205,7 @@ def query(self, query: Query) -> ResolveResult: def set_backend(param: SetBackendParam) -> None: + _resolvers[param.iri] = param.backend _backends[param.iri] = param.backend diff --git a/src/oold/model/__init__.py b/src/oold/model/__init__.py index d32228a..bcca8a8 100644 --- a/src/oold/model/__init__.py +++ b/src/oold/model/__init__.py @@ -1,15 +1,37 @@ import json -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, overload +from typing import ( 
+ TYPE_CHECKING, + Any, + Dict, + Generic, + List, + Literal, + Optional, + TypeVar, + Union, + overload, +) import pydantic -from pydantic import BaseModel -from typing_extensions import Self +# monkey patching pydantic FieldInfo +import pydantic.fields +from pydantic import BaseModel, GetCoreSchemaHandler +from pydantic.fields import FieldInfo +from pydantic_core import core_schema +from typing_extensions import Self, get_args + +from oold.backend import interface from oold.backend.interface import ( + Condition, GetBackendParam, GetResolverParam, + Query, + QueryParam, ResolveParam, + Resolver, StoreParam, + apply_operator, get_backend, get_resolver, ) @@ -20,14 +42,61 @@ import_jsonld, ) + +class OOFieldInfo(FieldInfo): + """Extension of pydantic FieldInfo to support query + construction via operators like ==, <, >, etc.""" + + name: Optional[str] = None + parent: Optional["LinkedBaseModel"] = None + + def __init__(self, *args, **kwargs): + # print("OOFieldInfo init") + super().__init__(*args, **kwargs) + + def __eq__(self, other): + return Condition(field=self.name, operator="eq", value=other) + + def __ne__(self, other): + return Condition(field=self.name, operator="ne", value=other) + + def __lt__(self, other): + return Condition(field=self.name, operator="lt", value=other) + + def __le__(self, other): + return Condition(field=self.name, operator="le", value=other) + + def __gt__(self, other): + return Condition(field=self.name, operator="gt", value=other) + + def __ge__(self, other): + return Condition(field=self.name, operator="ge", value=other) + + +pydantic.fields.FieldInfo = OOFieldInfo + # pydantic v2 _types: Dict[str, pydantic.main._model_construction.ModelMetaclass] = {} +M = TypeVar("M", bound="LinkedBaseModel") + + # pydantic v2 class LinkedBaseModelMetaClass(pydantic.main._model_construction.ModelMetaclass): + _constructing: bool = False + """Guards against __getattribute__ intercepting field access during class + construction. 
Pydantic checks ``getattr(base, field_name, None)`` in its + metaclass __new__ to detect shadowed BaseModel attributes. Without this + flag our __getattribute__ override would return a truthy FieldInfo instead + of the default None, causing false-positive field-name collision errors.""" + def __new__(mcs, name, bases, namespace): - cls = super().__new__(mcs, name, bases, namespace) + LinkedBaseModelMetaClass._constructing = True + try: + cls = super().__new__(mcs, name, bases, namespace) + finally: + LinkedBaseModelMetaClass._constructing = False if hasattr(cls, "get_cls_iri"): iri = cls.get_cls_iri() @@ -41,37 +110,250 @@ def __new__(mcs, name, bases, namespace): # override operators, see https://docs.python.org/3/library/operator.html + if not TYPE_CHECKING: + + def __getattribute__(self, name): + # print(f"Accessing attribute {name}") + if type(self)._constructing: + return super().__getattribute__(name) + if name not in [ + "__bases__", + "model_fields", + "__pydantic_fields__", + "__dict__", + ]: + # if name not in ["model_fields"]: + # check if attribute is in fields + if ( + name not in self.__dict__ # prevent shadowing if default value + and hasattr(self, "model_fields") + and name in self.model_fields + ): + # private_attributes = self.__dict__.get('__private_attributes__') + # if private_attributes and name in private_attributes: + # return super().__getattribute__(name) + # print(f"Attribute {name} is in model fields") + # ToDo: lookup the fields property if available + # return Condition(field=name) + field_info = self.model_fields[name] + field_info.name = name + field_info.parent = self + return field_info + # f = super().__getattribute__(name) + # return f + else: + return super().__getattribute__(name) + return super().__getattribute__(name) + @overload - def __getitem__(cls: "LinkedBaseModel", item: str) -> Self: + def __getitem__(cls: type[M], item: str) -> M: + """Get a class instance by its IRI.""" ... 
@overload - def __getitem__(cls: "LinkedBaseModel", item: List[str]) -> List[Self]: + def __getitem__( + cls: type[M], item: List[str] + ) -> Union[M, "LinkedBaseModelList[M]"]: + """Get multiple class instances by their IRIs.""" + # note: type M is to blend in M attributes + # in the signature of LinkedBaseModelList[M] ... + @overload def __getitem__( - cls: "LinkedBaseModel", item: Union[str, List[str]] - ) -> Union[Self, List[Self]]: - """Allow access to the class by its IRI.""" - result = cls._resolve(item if isinstance(item, list) else [item]) - return result[item] if isinstance(item, str) else [result[i] for i in item] + cls: type[M], item: Union[Query, Condition, bool] + ) -> Union[M, "LinkedBaseModelList[M]"]: + """Get class instances matching a query.""" + # note: (Entity.name == "test") is interpreted as bool + ... + def __getitem__( + cls: type[M], item: Union[str, List[str], Query, Condition, bool] + ) -> Union[M, "LinkedBaseModelList[M]", Optional["LinkedBaseModelList[M]"]]: + """Select instances of the class by IRI or by query.""" + return cls.oold_query(item) -# the following switch ensures that autocomplete works in IDEs like VSCode -if TYPE_CHECKING: + def __setitem__(cls: type[M], key, value: type[M]): + value._store() - class _LinkedBaseModel(BaseModel, GenericLinkedBaseModel): - pass -else: +T = TypeVar("T") - class _LinkedBaseModel( - BaseModel, GenericLinkedBaseModel, metaclass=LinkedBaseModelMetaClass + +class LinkedBaseModelList(Generic[T], List[Optional[T]]): + """Extension of list that tracks changes to the list. 
+ by syncing every modification with the __iri__ field of the parent model.""" + + def __init__( + self, *args: Optional[T], _synced_iri_list: Optional[List[str]] = None ): - pass + super().__init__(*args) + self._synced_iri_list = ( + _synced_iri_list # if _synced_iri_list is not None else [] + ) + # self._synced_iri_list.extend( + # item.get_iri() for item in self if item is not None + # ) + # initialize the synced_iri_list with the IRIs of the initial items in the list + if self._synced_iri_list is not None: + self._synced_list = args[0] + self._synced_iri_list.extend( + item.get_iri() + for item in self + if item is not None and item.get_iri() not in self._synced_iri_list + ) + + # def _set_synced_iri_list(self, iri_list: List[str]) -> None: + # """Set the list of IRIs that are synced with the linked data store.""" + # self._synced_iri_list = iri_list + + @classmethod + def __get_pydantic_core_schema__( + cls, source: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + instance_schema = core_schema.is_instance_schema(cls) + + args = get_args(source) + if args: + # replace the type and rely on Pydantic to generate the right schema + # for `Sequence` + sequence_t_schema = handler.generate_schema(List[args[0]]) + else: + sequence_t_schema = handler.generate_schema(List) + + non_instance_schema = core_schema.no_info_after_validator_function( + LinkedBaseModelList, sequence_t_schema + ) + return core_schema.union_schema([instance_schema, non_instance_schema]) + + def append(self, item: Optional[T]) -> None: + if self._synced_iri_list is not None: + self._synced_iri_list.append(item.get_iri()) + self._synced_list.append(item) + super().append(item) + + def remove(self, item: Optional[T]) -> None: + if self._synced_iri_list is not None: + self._synced_iri_list.remove(item.get_iri()) + self._synced_list.remove(item) + super().remove(item) + + def extend(self, iterable): + if self._synced_iri_list is not None: + self._synced_iri_list.extend( + 
item.get_iri() for item in iterable if item is not None + ) + self._synced_list.extend(iterable) + return super().extend(iterable) + + def get_item_type(self): + # Returns the actual type argument, e.g. Entity + if hasattr(self, "__orig_class__"): + return get_args(self.__orig_class__)[0] + return None + + # override [] operator to also support string indices + + @overload + def __getitem__(self, index: str) -> Optional[Union[T, "LinkedBaseModelList[T]"]]: + ... + + @overload + def __getitem__(self, index: bool) -> Optional[Union[T, "LinkedBaseModelList[T]"]]: + ... + + @overload + def __getitem__(self, index: int) -> Optional[T]: + ... + + # allow pandas-style queries, e.g. l[Entity.name=='John'] + @overload + def __getitem__( + self, index: Union[Query, Condition, bool] + ) -> Optional[Union[T, "LinkedBaseModelList[T]"]]: + ... + + def __getitem__(self, index): + if isinstance(index, str): + if index.startswith("@"): + # query, e.g. "@name=='John'" + key = index[1:].split("==")[0].strip() + value = index.split("==")[1].strip("'\"") + return LinkedBaseModelList[self.get_item_type()]( + [ + item + for item in self + if item and getattr(item, key, None) == value + ], + _synced_iri_list=self._synced_iri_list, + ) + + else: + # IRI lookup + for item in self: + if item and item.get_iri() == index: + return item + raise KeyError(f"No item with IRI {index} found") + elif isinstance(index, Condition): + key = index.field + op = index.operator + value = index.value + return LinkedBaseModelList[self.get_item_type()]( + [ + item + for item in self + if item and apply_operator(op, getattr(item, key, None), value) + ], + _synced_iri_list=self._synced_iri_list, + ) + elif isinstance(index, Query): + raise NotImplementedError("Query-based indexing not implemented yet") + else: + return super().__getitem__(index) + + def __getattribute__(self, name): + if not name == "__orig_class__": + # if name == "links": + # print(typing.get_args(self)) + if self is not None and hasattr(self, 
"__orig_class__"): + _type = get_args(self.__orig_class__)[0] + if name in _type.model_fields.keys(): + # print(f"Attribute {name} is in type {_type}") + # build a new LinkedBaseModelList with all + # the values of this attribute + # if attribute is List + result_list = LinkedBaseModelList[_type]([], _synced_iri_list=None) + for item in self: + if item is not None and hasattr(item, name): + value = getattr(item, name) + if isinstance(value, list): + result_list.extend(value) + else: + result_list.append(value) + return result_list + + # else: + return super().__getattribute__(name) + + +# the following switch ensures that autocomplete works in IDEs like VSCode +# if TYPE_CHECKING: + +# class _LinkedBaseModel(BaseModel, GenericLinkedBaseModel): +# pass + +# else: +# class _LinkedBaseModel( +# BaseModel, GenericLinkedBaseModel, metaclass=LinkedBaseModelMetaClass +# ): +# pass -class LinkedBaseModel(_LinkedBaseModel): + +# class LinkedBaseModel(_LinkedBaseModel): +class LinkedBaseModel( + BaseModel, GenericLinkedBaseModel, metaclass=LinkedBaseModelMetaClass +): """LinkedBaseModel for pydantic v2""" __iris__: Optional[Dict[str, Union[str, List[str]]]] = {} @@ -310,6 +592,9 @@ def __getattribute__(self, name): if name in ["__dict__", "__pydantic_private__", "__iris__"]: return BaseModel.__getattribute__(self, name) # prevent loop + if name == "model_fields": + return type(self).model_fields + else: if hasattr(self, "__iris__"): if name in self.__iris__ and len(self.__iris__[name]) > 0: @@ -334,7 +619,12 @@ def __getattribute__(self, name): if node: self.__setattr__(name, node, True) - return BaseModel.__getattribute__(self, name) + result = BaseModel.__getattribute__(self, name) + if isinstance(result, list) and name in self.__iris__: + result = LinkedBaseModelList[type(self)]( + result, _synced_iri_list=self.__iris__[name] + ) + return result def model_dump(self, **kwargs): # extent BaseClass export function # print("dict") @@ -364,6 +654,81 @@ def store_jsonld(self): 
"""Store the model instance in a backend matching its IRI.""" self._store() + @classmethod + def _oold_query( + cls, query: Union[str, List[str], Query, Condition] + ) -> "LinkedBaseModelList[Self]": + # get all resolvers + # ToDo: filter resolvers that support this class + resolvers: List[Resolver] = interface._resolvers.values() + node_list = [] + for r in resolvers: + try: + if isinstance(query, (str, list)): + _node_list = r.resolve( + ResolveParam( + iris=[query] if isinstance(query, str) else query, + model_cls=cls, + ) + ).nodes.values() + else: + _node_list = r.query( + QueryParam(query=query, model_cls=cls) + ).nodes.values() + node_list.extend(_node_list) + except NotImplementedError: + # resolver does not support query + continue + + if isinstance(query, str): + return node_list[0] if len(node_list) > 0 else None + else: + return ( + LinkedBaseModelList[Self](node_list, _synced_iri_list=None) + if len(node_list) > 0 + else None + ) + + @overload + @classmethod + def oold_query(cls, item: str) -> Self: + ... + + @overload + @classmethod + def oold_query(cls, item: List[str]) -> "LinkedBaseModelList[Self]": + ... + + # note: (Entity.name == "test") is interpreted as bool + @overload + @classmethod + def oold_query( + cls, item: Union[Query, Condition, bool] + ) -> Optional["LinkedBaseModelList[Self]"]: + ... 
+ + @classmethod + def oold_query( + cls, item: Union[str, List[str], Query, bool] + ) -> Union[ + Self, "LinkedBaseModelList[Self]", Optional["LinkedBaseModelList[Self]"] + ]: + """Allow access to the class by its IRI.""" + return cls._oold_query(item) + # if isinstance(item, Query): + # # resolve all instances of this class + # #print(f"Select all {cls.__name__} that match {index}") + # #return cls._oold_query(item) + # return cls(id="ex:test", name="test") + # else: + # result = cls._resolve(item if isinstance(item, list) else [item]) + # return ( + # result[item] if isinstance(item, str) + # else LinkedBaseModelList[Self]( + # [result[i] for i in item] + # ) + # ) + # pydantic v2 def model_dump_json( self, @@ -448,3 +813,13 @@ def to_json(self) -> Dict: def from_json(cls, data: Dict) -> "LinkedBaseModel": """Constructs a model instance from a JSON representation.""" return import_json(BaseModel, LinkedBaseModel, cls, data, _types) + + # @classmethod + # def model_json_schema( + # cls, by_alias=True, ref_template=..., + # schema_generator=..., mode='validation', + # ) -> dict[str, Any]: + # # return super().model_json_schema( + # # by_alias, ref_template, schema_generator, mode + # # ) + # return cls.export_schema(cls) diff --git a/src/oold/model/v1/__init__.py b/src/oold/model/v1/__init__.py index dd27f4c..08fc27f 100644 --- a/src/oold/model/v1/__init__.py +++ b/src/oold/model/v1/__init__.py @@ -1,15 +1,32 @@ import json -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + List, + Optional, + TypeVar, + Union, + overload, +) import pydantic from pydantic.v1 import BaseModel, PrivateAttr -from typing_extensions import Self +from typing_extensions import Self, get_args +from oold.backend import interface from oold.backend.interface import ( + Condition, GetBackendParam, GetResolverParam, + Query, + QueryParam, ResolveParam, + Resolver, StoreParam, + 
apply_operator, get_backend, get_resolver, ) @@ -23,15 +40,66 @@ if TYPE_CHECKING: from pydantic.v1.typing import AbstractSetIntStr, MappingIntStrAny +# monkey patching pydantic v1 FieldInfo +import pydantic.v1.fields +from pydantic.v1.fields import FieldInfo + + +class OOFieldInfo(FieldInfo): + """Extension of pydantic v1 FieldInfo to support query + construction via operators like ==, <, >, etc.""" + + __slots__ = ("name", "parent") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.name = None + self.parent = None + + def __eq__(self, other): + return Condition(field=self.name, operator="eq", value=other) + + def __ne__(self, other): + return Condition(field=self.name, operator="ne", value=other) + + def __lt__(self, other): + return Condition(field=self.name, operator="lt", value=other) + + def __le__(self, other): + return Condition(field=self.name, operator="le", value=other) + + def __gt__(self, other): + return Condition(field=self.name, operator="gt", value=other) + + def __ge__(self, other): + return Condition(field=self.name, operator="ge", value=other) + + +pydantic.v1.fields.FieldInfo = OOFieldInfo + # pydantic v1 _types: Dict[str, pydantic.v1.main.ModelMetaclass] = {} +M = TypeVar("M", bound="LinkedBaseModel") + + # pydantic v1 class LinkedBaseModelMetaClass(pydantic.v1.main.ModelMetaclass): + _constructing: bool = False + """Guards against __getattribute__ intercepting field access during class + construction. Pydantic checks ``getattr(base, field_name, None)`` in its + metaclass __new__ to detect shadowed BaseModel attributes. 
Without this + flag our __getattribute__ override would return a truthy FieldInfo instead + of the default None, causing false-positive field-name collision errors.""" + def __new__(mcs, name, bases, namespace): - cls = super().__new__(mcs, name, bases, namespace) + LinkedBaseModelMetaClass._constructing = True + try: + cls = super().__new__(mcs, name, bases, namespace) + finally: + LinkedBaseModelMetaClass._constructing = False if hasattr(cls, "get_cls_iri"): iri = cls.get_cls_iri() @@ -45,20 +113,172 @@ def __new__(mcs, name, bases, namespace): # override operators, see https://docs.python.org/3/library/operator.html + if not TYPE_CHECKING: + + def __getattribute__(self, name): + if type(self)._constructing: + return super().__getattribute__(name) + if name not in ["__bases__", "__fields__", "__dict__"]: + if ( + name not in self.__dict__ + and hasattr(self, "__fields__") + and name in self.__fields__ + ): + field_info = self.__fields__[name].field_info + field_info.name = name + field_info.parent = self + return field_info + else: + return super().__getattribute__(name) + return super().__getattribute__(name) + + @overload + def __getitem__(cls: type[M], item: str) -> M: + """Get a class instance by its IRI.""" + ... + @overload - def __getitem__(cls: "LinkedBaseModel", item: str) -> Self: + def __getitem__( + cls: type[M], item: List[str] + ) -> Union[M, "LinkedBaseModelList[M]"]: + """Get multiple class instances by their IRIs.""" ... @overload - def __getitem__(cls: "LinkedBaseModel", item: List[str]) -> List[Self]: + def __getitem__( + cls: type[M], item: Union[Query, Condition, bool] + ) -> Union[M, "LinkedBaseModelList[M]"]: + """Get class instances matching a query.""" ... 
def __getitem__( - cls: "LinkedBaseModel", item: Union[str, List[str]] - ) -> Union[Self, List[Self]]: - """Allow access to the class by its IRI.""" - result = cls._resolve(item if isinstance(item, list) else [item]) - return result[item] if isinstance(item, str) else [result[i] for i in item] + cls: type[M], item: Union[str, List[str], Query, Condition, bool] + ) -> Union[M, "LinkedBaseModelList[M]", Optional["LinkedBaseModelList[M]"]]: + """Select instances of the class by IRI or by query.""" + return cls.oold_query(item) + + def __setitem__(cls: type[M], key, value: type[M]): + value._store() + + +T = TypeVar("T") + + +class LinkedBaseModelList(Generic[T], List[Optional[T]]): + """Extension of list that tracks changes to the list + by syncing every modification with the __iri__ field of the parent model.""" + + def __init__( + self, *args: Optional[T], _synced_iri_list: Optional[List[str]] = None + ): + super().__init__(*args) + self._synced_iri_list = _synced_iri_list + if self._synced_iri_list is not None: + self._synced_list = args[0] + self._synced_iri_list.extend( + item.get_iri() + for item in self + if item is not None and item.get_iri() not in self._synced_iri_list + ) + + def append(self, item: Optional[T]) -> None: + if self._synced_iri_list is not None: + self._synced_iri_list.append(item.get_iri()) + self._synced_list.append(item) + super().append(item) + + def remove(self, item: Optional[T]) -> None: + if self._synced_iri_list is not None: + self._synced_iri_list.remove(item.get_iri()) + self._synced_list.remove(item) + super().remove(item) + + def extend(self, iterable): + if self._synced_iri_list is not None: + self._synced_iri_list.extend( + item.get_iri() for item in iterable if item is not None + ) + self._synced_list.extend(iterable) + return super().extend(iterable) + + def get_item_type(self): + if hasattr(self, "__orig_class__"): + return get_args(self.__orig_class__)[0] + return None + + # override [] operator to also support string 
indices + + @overload + def __getitem__(self, index: str) -> Optional[Union[T, "LinkedBaseModelList[T]"]]: + ... + + @overload + def __getitem__(self, index: bool) -> Optional[Union[T, "LinkedBaseModelList[T]"]]: + ... + + @overload + def __getitem__(self, index: int) -> Optional[T]: + ... + + # allow pandas-style queries, e.g. l[Entity.name=='John'] + @overload + def __getitem__( + self, index: Union[Query, Condition, bool] + ) -> Optional[Union[T, "LinkedBaseModelList[T]"]]: + ... + + def __getitem__(self, index): + if isinstance(index, str): + if index.startswith("@"): + # query, e.g. "@name=='John'" + key = index[1:].split("==")[0].strip() + value = index.split("==")[1].strip("'\"") + return LinkedBaseModelList[self.get_item_type()]( + [ + item + for item in self + if item and getattr(item, key, None) == value + ], + _synced_iri_list=self._synced_iri_list, + ) + else: + # IRI lookup + for item in self: + if item and item.get_iri() == index: + return item + raise KeyError(f"No item with IRI {index} found") + elif isinstance(index, Condition): + key = index.field + op = index.operator + value = index.value + return LinkedBaseModelList[self.get_item_type()]( + [ + item + for item in self + if item and apply_operator(op, getattr(item, key, None), value) + ], + _synced_iri_list=self._synced_iri_list, + ) + elif isinstance(index, Query): + raise NotImplementedError("Query-based indexing not implemented yet") + else: + return super().__getitem__(index) + + def __getattribute__(self, name): + if not name == "__orig_class__": + if self is not None and hasattr(self, "__orig_class__"): + _type = get_args(self.__orig_class__)[0] + if name in _type.__fields__.keys(): + result_list = LinkedBaseModelList[_type]([], _synced_iri_list=None) + for item in self: + if item is not None and hasattr(item, name): + value = getattr(item, name) + if isinstance(value, list): + result_list.extend(value) + else: + result_list.append(value) + return result_list + return 
super().__getattribute__(name) # the following switch ensures that autocomplete works in IDEs like VSCode @@ -315,7 +535,12 @@ def __getattribute__(self, name): if node: self.__setattr__(name, node, True) - return BaseModel.__getattribute__(self, name) + result = BaseModel.__getattribute__(self, name) + if isinstance(result, list) and name in self.__iris__: + result = LinkedBaseModelList[type(self)]( + result, _synced_iri_list=self.__iris__[name] + ) + return result @staticmethod def _resolve(iris): @@ -333,6 +558,65 @@ def store_jsonld(self): """Store the model instance in a backend matching its IRI.""" self._store() + @classmethod + def _oold_query( + cls, query: Union[str, List[str], Query, Condition] + ) -> "LinkedBaseModelList[Self]": + # get all resolvers + resolvers: List[Resolver] = interface._resolvers.values() + node_list = [] + for r in resolvers: + try: + if isinstance(query, (str, list)): + _node_list = r.resolve( + ResolveParam( + iris=[query] if isinstance(query, str) else query, + model_cls=cls, + ) + ).nodes.values() + else: + _node_list = r.query( + QueryParam(query=query, model_cls=cls) + ).nodes.values() + node_list.extend(_node_list) + except NotImplementedError: + continue + + if isinstance(query, str): + return node_list[0] if len(node_list) > 0 else None + else: + return ( + LinkedBaseModelList[Self](node_list, _synced_iri_list=None) + if len(node_list) > 0 + else None + ) + + @overload + @classmethod + def oold_query(cls, item: str) -> Self: + ... + + @overload + @classmethod + def oold_query(cls, item: List[str]) -> "LinkedBaseModelList[Self]": + ... + + @overload + @classmethod + def oold_query( + cls, item: Union[Query, Condition, bool] + ) -> Optional["LinkedBaseModelList[Self]"]: + ... 
import time
from typing import List, Optional

import pytest

from oold.backend.document_store import SimpleDictDocumentStore
from oold.backend.interface import (
    ComparisonOperator,
    Condition,
    Query,
    SetBackendParam,
    set_backend,
)


def _define_entity(pydantic_version):
    """Define the Entity class for the given pydantic version.

    Returns a ``(Entity, LinkedBaseModelList)`` pair built against either the
    pydantic v1 compatibility layer (``pydantic.v1`` / ``oold.model.v1``) or
    native pydantic v2 (``pydantic`` / ``oold.model``), so every test below
    can be parametrized over both model implementations.
    """
    if pydantic_version == "v1":
        from pydantic.v1 import Field

        from oold.model.v1 import LinkedBaseModel, LinkedBaseModelList

        class Entity(LinkedBaseModel):
            """A simple Entity schema"""

            class Config:
                schema_extra = {
                    "@context": {
                        "ex": "http://example.org/",
                        "id": "@id",
                        "type": "@type",
                        "name": "ex:name",
                        "links": {
                            "@id": "ex:links",
                            "@type": "@id",
                            "@container": "@set",
                        },
                    },
                    "iri": "ex:Entity",
                }

            id: str
            """The IRI of the entity."""
            name: str
            """The name of the entity."""
            type: Optional[str] = "ex:Entity"
            """The type of the entity."""
            links: Optional[List["Entity"]] = Field(
                None,
                range="ex:Entity",
            )
            """links to other entities"""

    else:
        from pydantic import Field

        from oold.model import LinkedBaseModel, LinkedBaseModelList

        class Entity(LinkedBaseModel):
            """A simple Entity schema"""

            model_config = {
                "json_schema_extra": {
                    "@context": {
                        "ex": "http://example.org/",
                        "id": "@id",
                        "type": "@type",
                        "name": "ex:name",
                        "links": {
                            "@id": "ex:links",
                            "@type": "@id",
                            "@container": "@set",
                        },
                    },
                    "iri": "ex:Entity",
                }
            }

            id: str
            """The IRI of the entity."""
            name: str
            """The name of the entity."""
            type: Optional[str] = "ex:Entity"
            """The type of the entity."""
            links: Optional[List["Entity"]] = Field(
                None,
                json_schema_extra={
                    "range": "ex:Entity",
                },
            )
            """links to other entities"""

    return Entity, LinkedBaseModelList


def _run_queries(pydantic_version):
    """Exercise class-level queries (``Entity[...]``) against a backend store."""
    Entity, LinkedBaseModelList = _define_entity(pydantic_version)

    backend = SimpleDictDocumentStore()
    set_backend(SetBackendParam(iri="ex", backend=backend))

    e1 = Entity(id="ex:e1", name="Entity 1")
    e2 = Entity(id="ex:e2", name="Entity 1")
    e1.store_jsonld()
    e2.store_jsonld()

    # combining two Conditions with ``&`` must produce a Query object
    q = (Entity.name == "test") & (Entity.id == "ex:e1")
    assert isinstance(q, Query)

    r1 = Entity[Entity.name == "Entity 1"]
    assert len(r1) == 2

    r2 = Entity[Entity.id == "ex:e1"]
    assert len(r2) == 1
    assert r2[0].id == "ex:e1"

    r3 = Entity[(Entity.name == "Entity 1") & (Entity.id == "ex:e2")]
    assert len(r3) == 1
    assert r3[0].id == "ex:e2"


def _run_linked_base_model_list(pydantic_version):
    """Exercise LinkedBaseModelList IRI synchronization and chained lookups."""
    Entity, LinkedBaseModelList = _define_entity(pydantic_version)

    # test LinkedBaseModelList IRI synchronization: every list mutation must
    # be mirrored into the external IRI list
    synced_iri_list = []
    el = LinkedBaseModelList[Entity](
        [Entity(id="ex:e1", name="Entity 1"), Entity(id="ex:e2", name="Entity 2")],
        _synced_iri_list=synced_iri_list,
    )
    assert synced_iri_list == ["ex:e1", "ex:e2"]
    el.append(Entity(id="ex:e3", name="Entity 3"))
    assert synced_iri_list == ["ex:e1", "ex:e2", "ex:e3"]
    el.remove(Entity(id="ex:e2", name="Entity 2"))
    assert synced_iri_list == ["ex:e1", "ex:e3"]
    el.extend(
        [Entity(id="ex:e4", name="Entity 4"), Entity(id="ex:e5", name="Entity 5")]
    )
    assert synced_iri_list == ["ex:e1", "ex:e3", "ex:e4", "ex:e5"]

    # index access and direct IRI access
    assert el[0].id == "ex:e1"
    assert el["ex:e3"].name == "Entity 3"

    # test string queries
    result = el["@name=='Entity 3'"]
    assert result[0].id == "ex:e3"

    # test Condition-based queries
    assert el[Entity.name == "Entity 3"][0].id == "ex:e3"

    # test linked entities with IRI sync on attribute access
    e1 = Entity(name="Entity 1", id="ex:e1")
    e2 = Entity(name="Entity 2", id="ex:e2", links=[e1])
    e3 = Entity(name="Entity 3", id="ex:e3", links=[e1, e2])

    assert e2.__iris__["links"] == ["ex:e1"]
    e2.links.append(e3)
    assert e2.links == [e1, e3]
    assert e2.__iris__["links"] == ["ex:e1", "ex:e3"]
    e2.links.remove(e1)
    assert e2.__iris__["links"] == ["ex:e3"]
    e2.links.extend([e1])
    assert e2.__iris__["links"] == ["ex:e3", "ex:e1"]

    assert e3.links[0].id == "ex:e1"
    assert e3.links["@name=='Entity 2'"][0].id == "ex:e2"
    assert e3.links[(Entity.name == "Entity 1")] is not None
    assert e3.links[Entity.name == "Entity 1"][0].id == "ex:e1"

    # plain iteration must also work
    assert [e for e in e3.links if e.name == "Entity 1"][0].id == "ex:e1"

    # test multi chain
    assert (
        e3.links[Entity.name == "Entity 2"][0].links[Entity.name == "Entity 1"][0].id
        == "ex:e1"
    )
    assert (
        e3.links[Entity.name == "Entity 2"].links[Entity.name == "Entity 1"][0].id
        == "ex:e1"
    )

    # attribute access on an un-indexed chained query result must not raise
    _ = e3.links[Entity.name == "Entity 2"].links[Entity.name == "Entity 1"].name

    res = e3.links[Entity.name == "Entity 2"].links[Entity.name == "Entity 1"]
    assert res[0].id == "ex:e1"


def _run_operators(pydantic_version):
    """Exercise every comparison operator on OOFieldInfo and list filtering."""
    Entity, LinkedBaseModelList = _define_entity(pydantic_version)

    el = LinkedBaseModelList[Entity](
        [
            Entity(id="ex:a", name="A"),
            Entity(id="ex:b", name="B"),
            Entity(id="ex:c", name="C"),
            Entity(id="ex:d", name="D"),
        ]
    )

    # test that OOFieldInfo dunder methods produce correct Condition objects
    cond = Entity.name == "B"
    assert isinstance(cond, Condition) and cond.operator == ComparisonOperator.EQ

    cond = Entity.name != "B"
    assert isinstance(cond, Condition) and cond.operator == ComparisonOperator.NE

    cond = Entity.name < "B"
    assert isinstance(cond, Condition) and cond.operator == ComparisonOperator.LT

    cond = Entity.name <= "B"
    assert isinstance(cond, Condition) and cond.operator == ComparisonOperator.LE

    cond = Entity.name > "B"
    assert isinstance(cond, Condition) and cond.operator == ComparisonOperator.GT

    cond = Entity.name >= "B"
    assert isinstance(cond, Condition) and cond.operator == ComparisonOperator.GE

    # test eq: names equal to "B"
    r = el[Entity.name == "B"]
    assert len(r) == 1 and r[0].id == "ex:b"

    # test ne: names not equal to "B"
    r = el[Entity.name != "B"]
    assert len(r) == 3
    assert {e.id for e in r} == {"ex:a", "ex:c", "ex:d"}

    # test lt: names less than "C" (lexicographic: A, B)
    r = el[Entity.name < "C"]
    assert len(r) == 2
    assert {e.id for e in r} == {"ex:a", "ex:b"}

    # test le: names less than or equal to "B"
    r = el[Entity.name <= "B"]
    assert len(r) == 2
    assert {e.id for e in r} == {"ex:a", "ex:b"}

    # test gt: names greater than "B"
    r = el[Entity.name > "B"]
    assert len(r) == 2
    assert {e.id for e in r} == {"ex:c", "ex:d"}

    # test ge: names greater than or equal to "C"
    r = el[Entity.name >= "C"]
    assert len(r) == 2
    assert {e.id for e in r} == {"ex:c", "ex:d"}


def _run_performance(pydantic_version):
    """Build a dense 3-layer linked graph and time a chained link lookup."""
    Entity, LinkedBaseModelList = _define_entity(pydantic_version)

    # create 3 layers of entities with 333 entities each
    # connect each node on a layer with all nodes on the next layer
    layers = 3
    entities_per_layer = 333
    all_entities = []

    start_time = time.time()
    for layer in range(layers):
        layer_entities = []
        for i in range(entities_per_layer):
            # NOTE: the id is layer-qualified; a bare f"ex:e{i}" would reuse
            # the same IRIs on every layer, making IRI-based link resolution
            # ambiguous (an entity could even link to its own IRI)
            e = Entity(name=f"Entity {layer}-{i}", id=f"ex:e{layer}-{i}")
            layer_entities.append(e)
        all_entities.append(layer_entities)
        if layer > 0:
            for parent in all_entities[layer - 1]:
                parent.links = layer_entities
    end_time = time.time()
    total_links = sum(
        len(e.links) if e.links else 0 for lyr in all_entities for e in lyr
    )
    print(
        f"[{pydantic_version}] Created"
        f" {layers * entities_per_layer} entities with"
        f" {total_links} links"
        f" in {end_time - start_time:.2f} seconds"
    )

    layer1 = LinkedBaseModelList[Entity](all_entities[0])

    start_time = time.time()
    res = (
        layer1[Entity.name == "Entity 0-50"]
        .links[Entity.name == "Entity 1-50"]
        .links[Entity.name == "Entity 2-50"]
    )
    end_time = time.time()
    assert res[0].name == "Entity 2-50"
    elapsed = end_time - start_time
    print(f"[{pydantic_version}] Accessed a specific link in {elapsed:.6f} seconds")


@pytest.mark.parametrize("pydantic_version", ["v1", "v2"])
def test_queries(pydantic_version):
    _run_queries(pydantic_version)


@pytest.mark.parametrize("pydantic_version", ["v1", "v2"])
def test_linked_base_model_list(pydantic_version):
    _run_linked_base_model_list(pydantic_version)


@pytest.mark.parametrize("pydantic_version", ["v1", "v2"])
def test_operators(pydantic_version):
    _run_operators(pydantic_version)


@pytest.mark.parametrize("pydantic_version", ["v1", "v2"])
def test_performance_large_linked_structure(pydantic_version):
    _run_performance(pydantic_version)


if __name__ == "__main__":
    _run_queries("v1")
    _run_queries("v2")
    _run_linked_base_model_list("v1")
    _run_linked_base_model_list("v2")
    _run_operators("v1")
    _run_operators("v2")
    _run_performance("v1")
    _run_performance("v2")