# Source code for runic.orm.session.session

"""Session: unit-of-work manager for graph writes."""

from __future__ import annotations

import logging
import weakref
from types import TracebackType
from typing import TYPE_CHECKING, Any, TypeVar

from runic.orm.core.descriptors import _NOT_LOADED, FieldDescriptor
from runic.orm.exceptions import EntityNotFoundError
from runic.orm.mapper.mapper import Mapper
from runic.orm.session._base import _SessionBase

if TYPE_CHECKING:
    from runic.orm.driver import GraphDriver, GraphResult
    from runic.orm.query.builder import QueryBuilder

# Module-level logger, named after this module; Session writes its debug and
# warning messages through it.
log = logging.getLogger(__name__)

# Type variable for the entity type carried by ``QueryBuilder[_T]`` in the
# statement-based API (``scalars`` / ``scalar``).
_T = TypeVar("_T")


class Session(_SessionBase):
    """Sync unit-of-work manager.

    Owns all mutations (``add``, ``delete``), single-entity lookup (``get``),
    identity map, and flush/commit lifecycle.  Repositories hold a session
    reference and delegate writes and PK lookups to it.

    The backend-agnostic bookkeeping lives in
    :class:`~runic.orm.session._base._SessionBase`.

    **Transaction model** — determined by the injected driver:

    - **FalkorDB** (no native multi-query transactions): each ``GRAPH.QUERY``
      is individually atomic.  ``commit()`` flushes pending writes;
      ``rollback()`` discards un-flushed state only — it cannot undo writes
      already sent to the graph.
    - **Bolt-based drivers** (Neo4j, Memgraph, ArcadeDB): full ACID
      transactions via the Bolt protocol.  The first query lazily opens a
      Bolt transaction; ``commit()`` / ``rollback()`` commit or discard all
      changes as a single atomic unit.
    - **Apache AGE** (psycopg3): full PostgreSQL ACID transactions.  psycopg3
      starts an implicit ``BEGIN`` on the first SQL statement; ``commit()`` /
      ``rollback()`` map to ``conn.commit()`` / ``conn.rollback()``.

    Drivers that support explicit transactions implement the
    :class:`~runic.orm.driver.TransactionalGraphDriver` protocol.  The Session
    detects this via ``isinstance`` and wires commit/rollback accordingly.
    """

    def __init__(
        self,
        driver: GraphDriver,
        mapper: Mapper | None = None,
        *,
        log_cypher: bool = False,
    ) -> None:
        """Bind the session to *driver* and initialise the tracking state.

        Parameters
        ----------
        driver:
            Graph driver used to execute every query issued by this session.
        mapper:
            Optional mapper, forwarded unchanged to ``_init_state``.
        log_cypher:
            When true, every Cypher statement and its parameters are logged
            at DEBUG level before execution.
        """
        # Imported locally rather than at module level (the module only
        # imports ``runic.orm.driver`` under TYPE_CHECKING).
        from runic.orm.driver import TransactionalGraphDriver

        self._init_state(driver, mapper, log_cypher=log_cypher)
        # True when a driver-level transaction is open (lazy-begin on first query)
        self._in_transaction: bool = False
        self._is_transactional: bool = isinstance(driver, TransactionalGraphDriver)

    # ------------------------------------------------------------------
    # Internal query runner
    # ------------------------------------------------------------------

    def _run_query(self, cypher: str, params: dict[str, Any]) -> GraphResult:
        """Execute *cypher* through the driver, lazily opening a transaction.

        For transactional drivers the first query after construction, commit
        or rollback calls ``driver.begin()``.  The flag is only set after
        ``begin()`` returns, so a failed ``begin()`` leaves the session marked
        as not in a transaction.
        """
        if self._log_cypher:
            log.debug("Cypher: %s | params: %s", cypher, params)
        if self._is_transactional and not self._in_transaction:
            self._driver.begin()
            self._in_transaction = True
        return self._driver.execute(cypher, params)

    # ------------------------------------------------------------------
    # Lookup
    # ------------------------------------------------------------------

    def get(self, cls: type, pk: Any, fetch: list[str] | None = None) -> Any | None:
        """Return entity from identity map or query graph; ``None`` if not found.

        Pass ``fetch=["rel_name", ...]`` to eager-load relationship fields in
        the same Cypher query using ``OPTIONAL MATCH``.

        An identity-map hit is returned as-is (after a reload when the entity
        is flagged ``_expired``); *fetch* is only honoured when the entity has
        to be loaded from the graph.

        Raises
        ------
        EntityNotFoundError
            If a cached entity is expired and no longer exists in the graph
            (raised by ``_reload``).
        """
        key = (cls, pk)
        if key in self._identity_map:
            entity = self._identity_map[key]
            if entity.__dict__.get("_expired"):
                self._reload(entity, cls, pk)
            return entity

        if fetch:
            return self._get_with_fetch(cls, pk, fetch)

        cypher, params = self._mapper.build_get_query(cls, pk)
        result = self._run_query(cypher, params)
        if not result.rows:
            return None

        raw_node = result.rows[0][0]
        entity = self._mapper.decode_node(raw_node, cls)
        # Register under the PK read back from the decoded entity rather than
        # the caller-supplied value.
        actual_pk = self._mapper.get_pk_value(entity)
        self._register_entity(entity, cls, actual_pk)
        log.debug("Loaded %s pk=%r from graph", cls.__name__, actual_pk)
        return entity

    def load_relationship(self, entity: Any, field_name: str) -> Any:
        """Load a lazy relationship field and cache the result on the entity.

        Called by ``Field._trigger_lazy_load`` when a ``_NOT_LOADED`` sentinel
        is accessed on an entity that is attached to this session.  Writes
        directly to ``entity.__dict__`` to bypass the dirty-tracking
        descriptor.

        Returns ``None`` without querying when *field_name* is unknown on the
        entity's class or is not a relationship field.
        """
        cls = type(entity)
        node_meta = self._mapper.require_node_meta(cls)
        fi = next((f for f in node_meta.fields if f.name == field_name), None)
        if fi is None or fi.field.relationship is None:
            return None

        cypher, params = self._rel_loader.build_lazy_load_query(entity, fi)
        result = self._run_query(cypher, params)
        decoded = self._rel_loader.decode_lazy_result(result, fi)
        entity.__dict__[field_name] = decoded
        self._inject_session_into(decoded)
        log.debug("Lazy-loaded %r.%s", entity, field_name)
        return decoded

    # ------------------------------------------------------------------
    # Flush / Commit / Rollback
    # ------------------------------------------------------------------

    def flush(self) -> None:
        """Execute all pending/dirty/deleted entities against the graph.

        Does **not** clear the identity map.  Each entity write is a separate
        ``graph.query()`` call.  Entities with ``generated=True`` IDs are
        handled individually so the returned ID can be assigned before
        continuing.

        Order is creates, then updates, then deletes.
        """
        self._flush_pending()
        self._flush_dirty()
        self._flush_deleted()

    def commit(self) -> None:
        """``flush()`` then clear the pending/deleted tracking sets.

        For transactional drivers (Bolt, AGE), also commits the active
        database transaction so all flushed writes become durable and visible.

        If ``flush()`` or the driver commit raises, ``_in_transaction`` is
        left untouched so a following ``rollback()`` / ``close()`` can still
        roll the driver transaction back.
        """
        self.flush()
        self._pending.clear()
        self._deleted.clear()
        if self._in_transaction:
            self._driver.commit()
            self._in_transaction = False
        log.debug("Session committed")

    def rollback(self) -> None:
        """Discard un-flushed pending/deleted sets; expire all persistent entities.

        For transactional drivers (Bolt, AGE), also rolls back the active
        database transaction — writes already flushed but not yet committed
        are discarded atomically.  For FalkorDB (no native transactions), only
        un-flushed in-memory state is cleared; writes already sent to the
        graph cannot be undone.
        """
        self._pending.clear()
        self._deleted.clear()
        # Flag every tracked entity so that ``get()`` reloads it on next
        # lookup, and drop any un-flushed modifications.
        for entity in self._identity_map.values():
            entity.__dict__["_expired"] = True
            entity.__dict__["_dirty"] = False
        if self._in_transaction:
            self._driver.rollback()
            self._in_transaction = False
        log.debug(
            "Session rolled back (pending/deleted cleared; persistent entities expired)"
        )

    # ------------------------------------------------------------------
    # Refresh
    # ------------------------------------------------------------------

    def refresh(self, entity: Any) -> None:
        """Immediately re-query the entity from the graph and update in-place.

        Raises
        ------
        EntityNotFoundError
            If the entity no longer exists in the graph (raised by
            ``_reload``).
        """
        cls = type(entity)
        pk = self._mapper.get_pk_value(entity)
        self._reload(entity, cls, pk)

    # ------------------------------------------------------------------
    # Relationship mutations
    # ------------------------------------------------------------------

    def relate(
        self,
        source: Any,
        field_name: str | FieldDescriptor,
        target: Any,
        edge: Any | None = None,
    ) -> None:
        """Create or update a relationship between *source* and *target*.

        Uses ``MERGE`` semantics: if the relationship already exists its edge
        properties are updated; if not, it is created.  Pass an ``Edge`` model
        instance as *edge* to write properties on the relationship itself.

        *field_name* may be a plain string **or** the class-level descriptor
        attribute (e.g. ``User.invited_trips``) for type-safe call sites.

        The cached value of the relation field on *source* is invalidated
        after the write so the next access re-fetches fresh data from the
        graph.  The query is executed immediately, not deferred to ``flush()``.
        """
        fi = self._resolve_relation_fi(source, field_name)
        cypher, params = self._rel_writer.build_relate_query(source, fi, target, edge)
        self._run_query(cypher, params)
        # Written through __dict__ so the reset is not routed through the
        # field descriptor.
        source.__dict__[fi.name] = _NOT_LOADED
        log.debug("Related %r -[%s]-> %r", source, fi.field.relationship, target)

    def unrelate(
        self,
        source: Any,
        field_name: str | FieldDescriptor,
        target: Any,
    ) -> None:
        """Delete the relationship between *source* and *target*.

        *field_name* may be a plain string **or** the class-level descriptor
        attribute (e.g. ``User.invited_trips``) for type-safe call sites.

        The cached value of the relation field on *source* is invalidated
        after the write so the next access re-fetches fresh data from the
        graph.  The query is executed immediately, not deferred to ``flush()``.
        """
        fi = self._resolve_relation_fi(source, field_name)
        cypher, params = self._rel_writer.build_unrelate_query(source, fi, target)
        self._run_query(cypher, params)
        source.__dict__[fi.name] = _NOT_LOADED
        log.debug("Unrelated %r -[%s]-x %r", source, fi.field.relationship, target)

    # ------------------------------------------------------------------
    # Raw Cypher
    # ------------------------------------------------------------------

    def execute(
        self,
        cypher: str,
        params: dict[str, Any] | None = None,
        write: bool = False,  # noqa: ARG002 (reserved for future routing)
    ) -> Any:
        """Execute raw Cypher; returns ``QueryResult``; no entity mapping.

        ``params=None`` is passed to the driver as an empty dict.  *write* is
        currently ignored.
        """
        return self._run_query(cypher, params or {})

    # ------------------------------------------------------------------
    # Statement-based execution (select() pattern)
    # ------------------------------------------------------------------

    def scalars(self, stmt: QueryBuilder[_T]) -> list[_T]:
        """Execute a :func:`~runic.orm.query.select` statement; return decoded entities.

        Type-safe: ``session.scalars(select(User).where(...))`` infers
        ``list[User]``.

        Parameters
        ----------
        stmt:
            An unbound :class:`~runic.orm.query.builder.QueryBuilder` created
            via :func:`~runic.orm.query.select`.
        """
        self._require_query_builder(stmt, "scalars")
        with stmt._bound_to(self) as bound:  # noqa: SLF001
            cypher, params = bound.build()
            result = self._run_query(cypher, params)
            return bound._decode_node_result(result)  # type: ignore[return-value]  # noqa: SLF001

    def scalar(self, stmt: QueryBuilder[_T]) -> _T | None:
        """Execute a :func:`~runic.orm.query.select` statement; return first entity or ``None``.

        Adds ``LIMIT 1`` internally without permanently modifying the
        statement.  Type-safe: ``session.scalar(select(User).where(...))``
        infers ``User | None``.

        Parameters
        ----------
        stmt:
            An unbound :class:`~runic.orm.query.builder.QueryBuilder` created
            via :func:`~runic.orm.query.select`.
        """
        self._require_query_builder(stmt, "scalar")
        # Temporarily force LIMIT 1; the caller's own limit is restored in
        # ``finally`` even if building or executing the query raises.
        old_limit = stmt._limit_val  # noqa: SLF001
        stmt._limit_val = 1  # noqa: SLF001
        try:
            with stmt._bound_to(self) as bound:  # noqa: SLF001
                cypher, params = bound.build()
                result = self._run_query(cypher, params)
                entities = bound._decode_node_result(result)  # noqa: SLF001
                return entities[0] if entities else None  # type: ignore[return-value]
        finally:
            stmt._limit_val = old_limit  # noqa: SLF001

    def all_rows(self, stmt: QueryBuilder[Any]) -> list[dict[str, Any]]:
        """Execute a :func:`~runic.orm.query.select` statement; return column-keyed dicts.

        Parameters
        ----------
        stmt:
            An unbound :class:`~runic.orm.query.builder.QueryBuilder`.
        """
        self._require_query_builder(stmt, "all_rows")
        with stmt._bound_to(self) as bound:  # noqa: SLF001
            cypher, params = bound.build()
            result = self._run_query(cypher, params)
            return bound._decode_rows_as_dicts(result)  # noqa: SLF001

    def all_with_edges(self, stmt: QueryBuilder[Any]) -> list[tuple[Any, ...]]:
        """Execute a :func:`~runic.orm.query.select` statement; return ``(NodeA, Edge, NodeB)`` tuples.

        Parameters
        ----------
        stmt:
            An unbound :class:`~runic.orm.query.builder.QueryBuilder` with
            ``return_nodes()`` and ``return_edge()`` configured.
        """
        self._require_query_builder(stmt, "all_with_edges")
        with stmt._bound_to(self) as bound:  # noqa: SLF001
            cypher, params = bound.build()
            result = self._run_query(cypher, params)
            return bound._decode_edge_result(result)  # noqa: SLF001

    def count(self, stmt: QueryBuilder[Any]) -> int:
        """Execute a :func:`~runic.orm.query.select` statement; return the row count.

        Parameters
        ----------
        stmt:
            An unbound :class:`~runic.orm.query.builder.QueryBuilder`.
        """
        self._require_query_builder(stmt, "count")
        with stmt._bound_to(self) as bound:  # noqa: SLF001
            return bound.count()

    # ------------------------------------------------------------------
    # Query builder entry points
    # ------------------------------------------------------------------

    def query(self, cls: type[Any]) -> Any:
        """Return a :class:`~runic.orm.query.builder.QueryBuilder` for *cls*.

        This is the primary entry point for the fluent query builder API::

            users = (
                session.query(User)
                .where(User.active == True)
                .order_by(User.name)
                .limit(20)
                .all()
            )

        Parameters
        ----------
        cls:
            A registered :class:`~runic.orm.core.models.Node` subclass.

        Returns
        -------
        QueryBuilder[cls]
        """
        # Imported locally; at module level QueryBuilder is only imported
        # under TYPE_CHECKING.
        from runic.orm.query.builder import QueryBuilder

        return QueryBuilder(self, cls)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self) -> None:
        """Expunge all tracked entities; roll back any orphaned transaction.

        If ``close()`` is called without a prior ``commit()`` or
        ``rollback()`` (e.g. the session was not used as a context manager),
        any active driver-level transaction is rolled back to release the
        connection cleanly.

        ``close()`` is best-effort: a failing driver rollback is logged (with
        traceback) and swallowed, and the session is still expunged.
        """
        if self._in_transaction:
            try:
                self._driver.rollback()
            except Exception:
                # Deliberately not re-raised: close() must always finish
                # detaching entities.  The traceback is attached so the
                # underlying driver failure can be diagnosed.
                log.warning(
                    "Session.close(): driver rollback failed; connection may leak",
                    exc_info=True,
                )
            self._in_transaction = False
        self.expunge_all()

    def __enter__(self) -> Session:
        """Return the session itself for use in a ``with`` block."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Commit on clean exit, roll back on exception, and always close.

        Returns ``None``, so an exception raised inside the ``with`` body is
        never suppressed.
        """
        try:
            if exc_type is None:
                self.commit()
            else:
                self.rollback()
        finally:
            # Must run even when commit()/rollback() raises; otherwise the
            # driver transaction stays open and entities remain attached.
            self.close()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _get_with_fetch(self, cls: type, pk: Any, fetch: list[str]) -> Any | None:
        """Load entity and eager-fetch named relationship fields in one Cypher query.

        Returns ``None`` when the query yields no rows.
        """
        cypher, params, fetch_meta = self._rel_loader.build_get_with_fetch_query(
            cls, pk, fetch
        )
        result = self._run_query(cypher, params)
        if not result.rows:
            return None

        row = result.rows[0]
        raw_node = row[0]
        entity = self._mapper.decode_node(raw_node, cls)
        related = self._rel_loader.decode_eager_columns(row, entity, fetch_meta)
        actual_pk = self._mapper.get_pk_value(entity)
        self._register_entity(entity, cls, actual_pk)
        self._inject_session_into(related)
        log.debug("Loaded %s pk=%r with fetch=%r", cls.__name__, actual_pk, fetch)
        return entity

    def _flush_pending(self) -> None:
        """CREATE all entities in the pending list.

        Each created entity is refreshed from the returned node (when the
        query returns one), marked neither new nor dirty, attached to this
        session through a weak reference, and stored in the identity map.
        """
        # Iterate over a snapshot; the pending collection is cleared at the end.
        for entity in list(self._pending):
            cypher, params = self._mapper.build_create_query(entity)
            result = self._run_query(cypher, params)
            raw_node = result.rows[0][0] if result.rows else None
            if raw_node is not None:
                self._mapper.update_entity_from_node(entity, raw_node)
            entity.__dict__["_new"] = False
            entity.__dict__["_dirty"] = False
            # Read the PK only after the entity was updated from the returned
            # node, so a graph-assigned id is picked up.
            pk = self._mapper.get_pk_value(entity)
            entity.__dict__["_session"] = weakref.ref(self)
            self._identity_map[(type(entity), pk)] = entity
            log.debug("Created %r pk=%r", entity, pk)
        self._pending.clear()

    def _flush_dirty(self) -> None:
        """MERGE/SET all dirty persistent entities.

        Entities flagged ``_new`` are skipped.  An entity whose update query
        comes back empty from the mapper is simply marked clean.
        """
        # Snapshot the values so the loop is unaffected by any identity-map
        # change made while updating.
        for entity in list(self._identity_map.values()):
            state = entity.__dict__
            if not state.get("_dirty", False):
                continue
            if state.get("_new", False):
                continue

            cypher, params = self._mapper.build_update_query(entity)
            if not cypher:
                state["_dirty"] = False
                continue

            result = self._run_query(cypher, params)
            if result.rows:
                self._mapper.update_entity_from_node(entity, result.rows[0][0])
            # Clear the flag on every successful update (same as
            # _flush_pending) so the entity is not re-sent on the next flush.
            state["_dirty"] = False
            log.debug("Updated %s", type(entity).__name__)

    def _flush_deleted(self) -> None:
        """DETACH DELETE all entities in the deleted list.

        Each deleted entity is removed from the identity map and detached
        from the session.
        """
        for entity in list(self._deleted):
            cypher, params = self._mapper.build_delete_query(entity)
            self._run_query(cypher, params)
            cls = type(entity)
            pk = self._mapper.get_pk_value(entity)
            self._identity_map.pop((cls, pk), None)
            entity.__dict__.pop("_session", None)
            log.debug("Deleted %s pk=%r", cls.__name__, pk)
        self._deleted.clear()

    def _reload(self, entity: Any, cls: type, pk: Any) -> None:
        """Re-query a single entity from the graph and update it in-place.

        Raises
        ------
        EntityNotFoundError
            If the query returns no rows.
        """
        cypher, params = self._mapper.build_get_query(cls, pk)
        result = self._run_query(cypher, params)
        if not result.rows:
            raise EntityNotFoundError(
                f"{cls.__name__} pk={pk!r} no longer exists in the graph"
            )
        self._mapper.update_entity_from_node(entity, result.rows[0][0])