"""Repository: typed reads and custom Cypher helpers for one entity type."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from runic.orm.repository.cypher import _SCALAR_TYPES, map_cypher_result
from runic.orm.repository.pagination import Page, Pageable
from runic.orm.repository.protocol import RepositoryProtocol
if TYPE_CHECKING:
from runic.orm.session.session import Session
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
[docs]
class Repository[T](RepositoryProtocol[T]):
"""Typed reads and explicit Cypher helpers for one entity type.
Mutations (``add``, ``delete``) and single-PK lookup (``get``) belong to
the :class:`Session`. All reads here register returned entities in the
session identity map.
Example::
with Session(graph) as session:
repo = Repository(session, Person)
all_people = repo.find_all()
page = repo.find_all_paginated(Pageable(page=0, size=25))
"""
def __init__(self, session: Session, entity_class: type[T]) -> None:
self._session = session
self._cls = entity_class
# ------------------------------------------------------------------
# Standard reads
# ------------------------------------------------------------------
[docs]
def find_all(self, fetch: list[str] | None = None) -> list[T]:
"""Return all entities of this type, with optional eager relationship loading."""
if fetch:
cypher, params, fetch_meta = (
self._session.rel_loader.build_find_all_with_fetch_query(
self._cls, fetch
)
)
result = self._session.execute(cypher, params)
return self._decode_rows_with_fetch(result, fetch_meta)
cypher, params = self._session.mapper.build_find_all_query(self._cls)
result = self._session.execute(cypher, params)
return self._decode_rows(result)
[docs]
def find_all_by_ids(
self, pks: list[Any], fetch: list[str] | None = None
) -> list[T]:
"""Return entities whose primary keys are in *pks*."""
if not pks:
return []
if fetch:
cypher, params, fetch_meta = (
self._session.rel_loader.build_find_all_by_ids_with_fetch_query(
self._cls, pks, fetch
)
)
result = self._session.execute(cypher, params)
return self._decode_rows_with_fetch(result, fetch_meta)
cypher, params = self._session.mapper.build_find_all_by_ids_query(
self._cls, pks
)
result = self._session.execute(cypher, params)
return self._decode_rows(result)
[docs]
def count(self) -> int:
"""Return the total number of entities of this type."""
cypher, params = self._session.mapper.build_count_query(self._cls)
result = self._session.execute(cypher, params)
if result.result_set:
return int(result.result_set[0][0])
return 0
[docs]
def exists(self, pk: Any) -> bool:
"""Return ``True`` if an entity with *pk* exists in the graph."""
cypher, params = self._session.mapper.build_exists_query(self._cls, pk)
result = self._session.execute(cypher, params)
if result.result_set:
return int(result.result_set[0][0]) > 0
return False
[docs]
def find_all_paginated(self, pageable: Pageable) -> Page[T]:
"""Return a single :class:`Page` of results for *pageable*."""
cypher, params = self._session.mapper.build_paginated_query(self._cls, pageable)
result = self._session.execute(cypher, params)
items = self._decode_rows(result)
count_cypher, count_params = self._session.mapper.build_count_query(self._cls)
count_result = self._session.execute(count_cypher, count_params)
total = int(count_result.result_set[0][0]) if count_result.result_set else 0
return Page(
items=items,
page_number=pageable.page,
size=pageable.size,
total_elements=total,
)
# ------------------------------------------------------------------
# Custom Cypher helpers
# ------------------------------------------------------------------
[docs]
def query(self) -> Any:
"""Return a :class:`~runic.orm.query.builder.QueryBuilder` for this repository's entity type.
Shorthand for ``session.query(self._cls)``::
repo = Repository(session, User)
# These are equivalent:
users = repo.query().where(User.active == True).all()
users = session.query(User).where(User.active == True).all()
Returns
-------
QueryBuilder[T]
"""
from runic.orm.query.builder import QueryBuilder
return QueryBuilder(self._session, self._cls)
[docs]
def cypher(
self,
query: str,
params: dict[str, Any] | None = None,
*,
returns: type | None = None,
write: bool = False,
) -> list[Any]:
"""Execute *query* and return a typed list.
``returns`` controls decoding: ``int``, ``str``, ``bool`` → scalar;
``dict`` → column-keyed dicts; any ``Node`` subclass → decoded entities
registered in the session identity map; ``None`` → empty list.
"""
result = self._session.execute(query, params or {}, write=write)
register_fn = (
self._session.register_or_get
if returns is not None
and returns not in _SCALAR_TYPES
and returns is not dict
else None
)
return map_cypher_result(result, returns, self._session.mapper, register_fn)
[docs]
def cypher_one(
self,
query: str,
params: dict[str, Any] | None = None,
*,
returns: type | None = None,
write: bool = False,
) -> Any | None:
"""Execute *query* and return the first mapped value, or ``None``."""
items = self.cypher(query, params, returns=returns, write=write)
return items[0] if items else None
[docs]
def cypher_raw(
self,
query: str,
params: dict[str, Any] | None = None,
*,
write: bool = False,
) -> Any:
"""Execute *query* and return the raw ``QueryResult`` without entity mapping."""
return self._session.execute(query, params or {}, write=write)
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _decode_rows(self, result: Any) -> list[T]:
"""Decode plain MATCH rows (single node per row) and register in identity map."""
entities: list[T] = []
for row in result.result_set:
decoded = self._session.mapper.decode_node(row[0], self._cls)
registered = self._session.register_or_get(decoded)
entities.append(registered)
return entities
def _decode_rows_with_fetch(
self,
result: Any,
fetch_meta: list[tuple[str, Any]],
) -> list[T]:
"""Decode rows that include eager-loaded relationship columns."""
entities: list[T] = []
for row in result.result_set:
decoded = self._session.mapper.decode_node(row[0], self._cls)
registered = self._session.register_or_get(decoded)
related = self._session.rel_loader.decode_eager_columns(
row, registered, fetch_meta
)
for rel_entity in related:
self._session.register_or_get(rel_entity)
entities.append(registered)
return entities