# Source code for runic.orm.query.builder

"""Fluent query builder for the runic ORM.

See :doc:`/query_builder` for the full API reference and examples.

:class:`QueryBuilder` is the core builder; specialised subclasses
(:class:`~runic.orm.query.specialised.AsyncQueryBuilder`,
:class:`~runic.orm.query.specialised.FulltextQueryBuilder`,
:class:`~runic.orm.query.specialised.VectorQueryBuilder`) live in
:mod:`runic.orm.query.specialised`.
"""

from __future__ import annotations

import logging
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Generic, TypeVar

from runic.orm.core.descriptors import FieldDescriptor, FieldInfo
from runic.orm.core.metadata import metadata as _global_metadata
from runic.orm.query.expressions import (
    AggExpr,
    CompoundExpr,
    Expr,
    FilterExpr,
    NegatedExpr,
    OrderExpr,
)
from runic.orm.query.traversal import TraversalStep

log = logging.getLogger(__name__)

T = TypeVar("T")


# ---------------------------------------------------------------------------
# Internal: compiled match clause
# ---------------------------------------------------------------------------


class _MatchClause:
    """A single compiled pattern clause held by the query builder.

    The stored ``pattern`` is rendered with a ``MATCH`` or ``OPTIONAL MATCH``
    keyword in front of it, unless ``is_call`` is set, in which case the
    pattern text is emitted verbatim.
    """

    def __init__(
        self,
        pattern: str,
        *,
        optional: bool = True,
        is_call: bool = False,
    ) -> None:
        self.is_call = is_call
        self.optional = optional
        self.pattern = pattern

    def to_cypher(self) -> str:
        """Render this clause as a Cypher fragment."""
        if not self.is_call:
            keyword = "MATCH" if not self.optional else "OPTIONAL MATCH"
            return f"{keyword} {self.pattern}"
        # Call-style clauses already carry their own leading keyword.
        return self.pattern


# ---------------------------------------------------------------------------
# QueryBuilder
# ---------------------------------------------------------------------------


class QueryBuilder(Generic[T]):  # noqa: UP046
    """Fluent Cypher query builder for a single root Node class.

    Construct via :meth:`Session.query`::

        q = session.query(User)

    All non-terminal methods return ``self`` so calls can be chained::

        users = session.query(User).where(User.active == True).limit(10).all()

    Parameters
    ----------
    session:
        The :class:`~runic.orm.session.session.Session` (or
        :class:`~runic.orm.session.async_session.AsyncSession`) this builder
        is bound to. ``None`` creates an unbound statement.
    root_cls:
        The root Node subclass to query.
    """

    def __init__(self, session: Any | None, root_cls: type[T]) -> None:
        from runic.orm.core.metadata import MetaData

        self._session: Any = session  # None when unbound (created via select())
        self._root_cls: type[T] = root_cls
        # Fall back to the global metadata registry when the session (or its
        # mapper) is missing or carries no ``meta`` attribute.
        _mapper = getattr(session, "mapper", None)
        self._meta: MetaData = getattr(_mapper, "meta", _global_metadata)

        # Alias tracking -------------------------------------------------
        # alias → ORM class (Node or Edge)
        self._alias_map: dict[str, type] = {}
        # ORM class → list of aliases (inverse lookup)
        self._cls_aliases: dict[type, list[str]] = {}
        # The most recently registered target alias (default RETURN target)
        self._last_alias: str = "n"
        # The root node alias
        self._root_alias: str = "n"
        # Register root
        self._set_alias("n", root_cls)

        # Query parts ----------------------------------------------------
        self._match_clauses: list[_MatchClause] = []
        self._with_vars: list[str] | None = None
        self._where_exprs: list[Expr] = []
        self._order: list[OrderExpr] = []
        self._distinct: bool = False
        self._limit_val: int | None = None
        self._skip_val: int | None = None

        # Return specification -------------------------------------------
        # None → auto (last alias or root alias)
        # list of str → explicit aliases / Cypher expressions to return
        self._return_aliases: list[str] | None = None
        # Edge alias to include in .all_with_edges() output
        self._edge_alias_for_result: str | None = None
        # Aggregation specs
        self._agg_exprs: list[AggExpr] = []
        self._group_by_alias: str | None = None
        # Scalar projection (for .project())
        self._project_fields: list[FieldDescriptor | str] = []

        # Parameter counter ----------------------------------------------
        self._param_counter: int = 0
        self._params: dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Dialect access
    # ------------------------------------------------------------------

    @property
    def _dialect(self) -> Any:
        """Return the bound session's mapper dialect, or ``None``.

        ``None`` is returned for an unbound statement and also when the
        session has no ``mapper`` (or the mapper has no ``dialect``). This
        mirrors the tolerant lookup ``__init__`` performs for metadata.
        """
        mapper = getattr(self._session, "mapper", None)
        return getattr(mapper, "dialect", None)

    # ------------------------------------------------------------------
    # Unbound-statement guard
    # ------------------------------------------------------------------

    def _check_bound(self) -> None:
        """Raise ``RuntimeError`` when no session is attached to this builder."""
        if self._session is None:
            raise RuntimeError(
                "This statement is not bound to a session. "
                "Use session.scalars(stmt), session.scalar(stmt), "
                "session.all_rows(stmt), session.all_with_edges(stmt), "
                "or session.count(stmt) to execute it."
            )

    @contextmanager
    def _bound_to(self, session: Any) -> Generator[QueryBuilder[T]]:
        """Temporarily bind this statement to *session* for execution.

        Used by :class:`~runic.orm.session.session.Session` execution methods
        so that :meth:`build` has access to the dialect and the identity map
        is populated correctly. The previous binding is restored after the
        ``with`` block (also when the block raises), leaving the statement
        reusable.
        """
        old = self._session
        self._session = session
        try:
            yield self
        finally:
            self._session = old

    # ------------------------------------------------------------------
    # Alias management
    # ------------------------------------------------------------------
def alias(self, name: str) -> QueryBuilder[T]:
    """Rename the most recently registered Cypher variable to *name*.

    Call immediately after :meth:`Session.query` to name the root variable,
    or after a traversal step has been completed to rename the last
    registered target.

    Only the alias bookkeeping is updated here; the root alias is updated as
    well when the renamed variable is the root.

    Example::

        session.query(User).alias("u").where(User.active == True, on="u")
    """
    previous = self._last_alias
    mapped_cls = self._alias_map.get(previous)

    if mapped_cls is not None:
        # Forget the old name in both the forward and the inverse table.
        self._alias_map.pop(previous, None)
        siblings = self._cls_aliases.get(mapped_cls)
        if siblings is not None and previous in siblings:
            siblings.remove(previous)

    # An unknown previous alias falls back to the root class.
    self._set_alias(name, mapped_cls or self._root_cls)
    self._last_alias = name

    if previous == self._root_alias:
        self._root_alias = name
    return self
# ------------------------------------------------------------------ # Filtering # ------------------------------------------------------------------
def where(
    self,
    expr: Expr,
    *,
    on: str | None = None,
) -> QueryBuilder[T]:
    """Append a WHERE predicate to the query.

    Parameters
    ----------
    expr:
        A :class:`~runic.orm.query.expressions.FilterExpr`,
        :class:`~runic.orm.query.expressions.CompoundExpr`, or
        :class:`~runic.orm.query.expressions.NegatedExpr`, typically created
        via field descriptor operators::

            User.name == "Alice"
            (User.age > 18) & (User.active == True)

    on:
        Cypher variable to bind the predicate to, e.g. when the same Node
        class appears under several aliases or when filtering on edge
        properties::

            .where(Rated.score > 4.0, on="r")

        The override is applied only when *expr* itself is a ``FilterExpr``;
        compound and negated expressions are stored unchanged.

    Notes
    -----
    Predicates from multiple ``.where()`` calls are combined with ``AND``.
    Use the ``|`` operator on the expressions to express ``OR``::

        .where((User.role == "admin") | (User.role == "mod"))
    """
    rebind = on is not None and isinstance(expr, FilterExpr)
    if rebind:
        expr = expr.with_alias(on)
    self._where_exprs.append(expr)
    return self
# ------------------------------------------------------------------ # Traversal # ------------------------------------------------------------------
def traverse(
    self,
    relation_field: FieldDescriptor,
    *,
    edge_alias: str | None = None,
    optional: bool = True,
) -> TraversalStep:
    """Start a single-hop traversal over a declared ``Relation`` field.

    The returned :class:`~runic.orm.query.traversal.TraversalStep` starts at
    the builder's most recent alias; call ``.alias("f")`` on it to name the
    target node and get back to the builder.

    Parameters
    ----------
    relation_field:
        The ``Relation``-backed field descriptor accessed at class level::

            User.friends   # list[User] = Relation(...)
            User.rated     # list[Movie] = Relation(edge_model=Rated)

    edge_alias:
        When given, a named relationship variable is emitted in the
        pattern::

            (u)-[r:RATED]->(m)

        This enables filtering on edge properties via
        ``.where(Rated.score > 4, on="r")`` and retrieving edge instances
        via ``.all_with_edges()``.
    optional:
        ``True`` (default) requests ``OPTIONAL MATCH`` (keeps source nodes
        without such a relationship); ``False`` requests a plain ``MATCH``.

    Returns
    -------
    TraversalStep
        Call ``.alias("name")`` on the return value to complete the step.

    Examples
    --------
    .. code-block:: python

        q = session.query(User).alias("u")
        q = q.traverse(User.rated, edge_alias="r").alias("m")
        q = q.where(Rated.score >= 4.0, on="r")
    """
    # A plain traversal is exactly one hop, so both bounds are pinned to 1.
    single_hop = {"min_hops": 1, "max_hops": 1}
    step = TraversalStep(
        builder=self,
        field_descriptor=relation_field,
        source_alias=self._last_alias,
        optional=optional,
        edge_alias=edge_alias,
        **single_hop,
    )
    return step
def repeat(
    self,
    relation_field: FieldDescriptor,
    *,
    min_hops: int = 1,
    max_hops: int | None = None,
    optional: bool = False,
) -> TraversalStep:
    """Start a variable-length traversal (``*min..max``) over a relation.

    Generates a Cypher pattern like::

        (p)-[:PARENT*1..5]->(ancestor:Person)

    Parameters
    ----------
    relation_field:
        The ``Relation`` field to traverse repeatedly.
    min_hops:
        Minimum number of hops (default ``1``).
    max_hops:
        Maximum number of hops. ``None`` means unbounded (``*min..``).
    optional:
        ``False`` (the default for ``repeat``) requests a required
        traversal; ``True`` requests ``OPTIONAL MATCH``.

    Returns
    -------
    TraversalStep
        Call ``.alias("name")`` to complete the step. No edge variable is
        named for repeated traversals.

    Examples
    --------
    .. code-block:: python

        ancestors = (
            session.query(Person)
            .alias("p")
            .where(Person.id == start_id)
            .repeat(Person.parent, min_hops=1, max_hops=5)
            .alias("anc")
            .all()
        )
    """
    hop_bounds = {"min_hops": min_hops, "max_hops": max_hops}
    step = TraversalStep(
        builder=self,
        field_descriptor=relation_field,
        source_alias=self._last_alias,
        optional=optional,
        edge_alias=None,
        **hop_bounds,
    )
    return step
# ------------------------------------------------------------------ # WITH (multi-stage pipelining) # ------------------------------------------------------------------
def with_(self, *aliases: str) -> QueryBuilder[T]:
    """Record a ``WITH`` clause that pipelines results between stages.

    Useful to filter in one stage before continuing a traversal in the
    next::

        (
            session.query(User)
            .alias("u")
            .where(User.active == True)
            .with_("u")  # WITH u
            .traverse(User.posts)
            .alias("p")
            .return_target("p")
            .all()
        )

    Each call replaces the variables recorded by any earlier call.

    Parameters
    ----------
    *aliases:
        Cypher variable names to carry forward (e.g. ``"u"``, ``"f"``).
    """
    carried = [*aliases]
    self._with_vars = carried
    return self
# ------------------------------------------------------------------ # Ordering / pagination # ------------------------------------------------------------------
def order_by(
    self,
    field: FieldDescriptor | str,
    *,
    desc: bool = False,
) -> QueryBuilder[T]:
    """Append one ``ORDER BY`` term.

    Parameters
    ----------
    field:
        A field descriptor (``User.name``) or a raw Cypher expression
        string. Anything that is not a ``FieldDescriptor`` is stringified
        and stored as a raw term.
    desc:
        ``True`` for descending order (default ``False``).

    Examples
    --------
    .. code-block:: python

        q.order_by(User.age)                    # ORDER BY n.age ASC
        q.order_by(User.created_at, desc=True)  # ORDER BY n.created_at DESC
        q.order_by("score ASC")                 # raw string
    """
    if not isinstance(field, FieldDescriptor):
        term = OrderExpr(alias=None, prop=None, raw=str(field), desc=desc)
    else:
        owner = field.owner
        # Descriptors without an owner class are attributed to the root node.
        variable = self._alias_for_cls(owner) if owner else self._root_alias
        term = OrderExpr(alias=variable, prop=field.field_name, desc=desc)
    self._order.append(term)
    return self
def limit(self, n: int) -> QueryBuilder[T]:
    """Record ``LIMIT n`` for the query and return the builder."""
    self._limit_val = n
    return self
def skip(self, n: int) -> QueryBuilder[T]:
    """Record ``SKIP n`` (row offset) for the query and return the builder."""
    self._skip_val = n
    return self
def distinct(self) -> QueryBuilder[T]:
    """Request ``DISTINCT`` in the ``RETURN`` clause and return the builder."""
    self._distinct = True
    return self
# ------------------------------------------------------------------ # Return specification # ------------------------------------------------------------------
def return_target(self, alias: str) -> QueryBuilder[T]:
    """Select the single alias whose nodes form the query result.

    When a traversal is involved, this picks which alias's nodes are
    returned by ``.all()``::

        q.return_target("f")  # returns f-nodes as list[FriendType]

    Any previously declared return aliases are replaced.
    """
    chosen = [alias]
    self._return_aliases = chosen
    return self
def return_nodes(self, *aliases: str) -> QueryBuilder[T]:
    """Declare several node aliases for the ``RETURN`` clause.

    Used together with :meth:`return_edge` and :meth:`all_with_edges` to
    return structured tuples::

        q.return_nodes("u", "m").return_edge("r").all_with_edges()

    Any previously declared return aliases are replaced.
    """
    declared = [*aliases]
    self._return_aliases = declared
    return self
def return_edge(self, alias: str) -> QueryBuilder[T]:
    """Declare the edge alias to include in the ``RETURN`` clause.

    Requires that the traversal was created with ``edge_alias=alias``. The
    edge is decoded via
    :meth:`~runic.orm.mapper.mapper.Mapper.decode_edge` and included in the
    tuples returned by :meth:`all_with_edges`.
    """
    self._edge_alias_for_result = alias
    return self
def project(self, *fields: FieldDescriptor | str) -> QueryBuilder[T]:
    """Restrict the result to specific property values (scalar projection).

    ``.scalars()`` then returns the first projected value of every row as a
    flat list, and ``.all_rows()`` returns a list of dicts::

        names = session.query(User).project(User.name).scalars()
        rows = session.query(User).project(User.name, User.age).all_rows()

    Any previously declared projection is replaced.
    """
    selected = [*fields]
    self._project_fields = selected
    return self
# ------------------------------------------------------------------ # Aggregation # ------------------------------------------------------------------
def aggregate(
    self,
    *agg_exprs: AggExpr,
    group_by: str | None = None,
) -> QueryBuilder[T]:
    """Set the aggregation expressions for the ``RETURN`` clause.

    Parameters
    ----------
    *agg_exprs:
        One or more :class:`~runic.orm.query.expressions.AggExpr`
        instances, created by helpers such as
        :func:`~runic.orm.query.expressions.count` or
        :func:`~runic.orm.query.expressions.avg`.
    group_by:
        Alias to keep in the ``RETURN`` clause next to the aggregations.
        Cypher grouping is implicit: any non-aggregated return term acts as
        a grouping key::

            .aggregate(count("*").as_("friend_count"), group_by="u")
            # RETURN u, count(*) AS friend_count

    Each call replaces the aggregations and grouping alias set earlier.

    Examples
    --------
    .. code-block:: python

        from runic.orm.query import count, avg

        avg_age = (
            session.query(User).aggregate(avg(User.age).as_("average_age")).scalar()
        )
    """
    self._group_by_alias = group_by
    self._agg_exprs = [*agg_exprs]
    return self
# ------------------------------------------------------------------ # Build (compile to Cypher) # ------------------------------------------------------------------
def build(self) -> tuple[str, dict[str, Any]]:
    """Compile the accumulated builder state to a ``(cypher, params)`` pair.

    This is the core compilation step; all terminal methods call it
    internally. It can also be called directly for debugging or to
    integrate with custom execution logic::

        cypher, params = session.query(User).where(User.active == True).build()
        print(cypher)
        # MATCH (n:User)
        # WHERE n.active = $p0
        # RETURN n

    Returns
    -------
    tuple[str, dict[str, Any]]
        A ``(cypher_string, params_dict)`` pair ready to pass to
        :meth:`~runic.orm.session.session.Session.execute`.

    Raises
    ------
    ValueError
        If no node metadata is registered for the root class.
    """
    # Reset params for each build call so repeated builds start at $p0.
    self._param_counter = 0
    self._params = {}

    # ── Root MATCH ──────────────────────────────────────────────────
    root_meta = self._meta.get_node_meta(self._root_cls)
    if root_meta is None:
        raise ValueError(
            f"Class {self._root_cls.__name__!r} is not a registered Node subclass"
        )

    dialect = self._dialect
    labels_clause = getattr(dialect, "labels_clause", None)
    labels_str = (
        labels_clause(root_meta.labels)
        if labels_clause
        else ":".join(root_meta.labels)
    )
    subtype_where = getattr(dialect, "subtype_where", None)
    subtype_filter = (
        subtype_where(self._root_alias, root_meta.labels) if subtype_where else None
    )

    parts: list[str] = [f"MATCH ({self._root_alias}:{labels_str})"]

    def compile_all(exprs: list[Expr]) -> str:
        """AND-combine *exprs* and compile them into one predicate string."""
        combined = (
            exprs[0]
            if len(exprs) == 1
            else CompoundExpr(op="AND", operands=exprs)
        )
        return self._compile_expr(combined)

    # ── Split predicates into root-level and post-traversal groups ──
    #
    # Cypher ordering when traversals are present:
    #   MATCH (root)
    #   WHERE root conditions          (must precede OPTIONAL MATCH)
    #   [WITH ...]                     (pipelining, precedes traversal)
    #   OPTIONAL MATCH ...
    #   WHERE traversal-target conditions
    #
    # Without this split, WHERE would attach to the OPTIONAL MATCH clause
    # and root filters would only null out the optional part instead of
    # filtering root nodes.
    if self._where_exprs and self._match_clauses:
        root_exprs, post_exprs = self._split_where_exprs()
    elif self._with_vars:
        # No traversal, but a WITH separates the two WHERE positions, so
        # the user predicates stay after the WITH.
        root_exprs, post_exprs = [], list(self._where_exprs)
    else:
        # Nothing separates the root MATCH from the predicates: they all
        # belong to the root WHERE (same position in the output).
        root_exprs, post_exprs = list(self._where_exprs), []

    # ── WHERE (root) ────────────────────────────────────────────────
    # The dialect's subtype filter and the root predicates must share ONE
    # WHERE clause; two consecutive WHERE lines are not valid Cypher.
    root_conditions: list[str] = []
    if subtype_filter:
        root_conditions.append(subtype_filter)
    if root_exprs:
        root_conditions.append(compile_all(root_exprs))
    if len(root_conditions) == 1:
        parts.append(f"WHERE {root_conditions[0]}")
    elif root_conditions:
        joined = " AND ".join(f"({cond})" for cond in root_conditions)
        parts.append(f"WHERE {joined}")

    # ── WITH (pipeline — emitted before traversals) ─────────────────
    if self._with_vars:
        parts.append(f"WITH {', '.join(self._with_vars)}")

    # ── Traversal clauses ───────────────────────────────────────────
    parts.extend(mc.to_cypher() for mc in self._match_clauses)

    # ── WHERE (post-traversal conditions on traversal targets / edges)
    if post_exprs:
        parts.append(f"WHERE {compile_all(post_exprs)}")

    # ── RETURN ──────────────────────────────────────────────────────
    parts.append(self._compile_return())

    # ── ORDER BY ────────────────────────────────────────────────────
    if self._order:
        order_str = ", ".join(o.to_cypher() for o in self._order)
        parts.append(f"ORDER BY {order_str}")

    # ── SKIP / LIMIT ────────────────────────────────────────────────
    if self._skip_val is not None:
        parts.append(f"SKIP {self._skip_val}")
    if self._limit_val is not None:
        parts.append(f"LIMIT {self._limit_val}")

    cypher = "\n".join(parts)
    # Hand out a copy so callers cannot mutate the builder's parameter map.
    return cypher, dict(self._params)
# ------------------------------------------------------------------ # Terminal methods (sync) # ------------------------------------------------------------------
def all(self) -> list[T]:
    """Execute the query and return every matching Node instance.

    Results are decoded as the root class, or as the class registered for
    the alias chosen via :meth:`return_target`, and are passed through the
    session's ``register_or_get``.

    Returns
    -------
    list[T]
        Decoded Node instances.

    Raises
    ------
    RuntimeError
        If the statement is not bound to a session.
    """
    self._check_bound()
    query_text, bindings = self.build()
    log.debug("QueryBuilder.all: %s", query_text)
    raw_result = self._session.execute(query_text, bindings)
    return self._decode_node_result(raw_result)
def one(self) -> T | None:
    """Execute and return the first matching Node instance, or ``None``.

    The query runs with ``LIMIT 1``. The limit configured on the builder
    before the call is restored afterwards (also when execution raises), so
    the statement stays reusable for a later ``.all()``.
    """
    saved_limit = self._limit_val
    self.limit(1)
    try:
        items = self.all()
    finally:
        # Do not leak the temporary LIMIT 1 into later executions.
        self._limit_val = saved_limit
    return items[0] if items else None
def all_with_edges(self) -> list[tuple[Any, ...]]:
    """Execute and return tuples of ``(NodeA, EdgeModel, NodeB)``.

    Requires :meth:`return_nodes` to specify node aliases and
    :meth:`return_edge` to specify the edge alias. The edge is decoded via
    :meth:`~runic.orm.mapper.mapper.Mapper.decode_edge`.

    Returns
    -------
    list[tuple]
        One tuple per result row, in the column order of the ``RETURN``
        clause.

    Raises
    ------
    RuntimeError
        If the statement is not bound to a session.

    Example
    -------
    .. code-block:: python

        rows = (
            session.query(User)
            .alias("u")
            .traverse(User.rated, edge_alias="r")
            .alias("m")
            .return_nodes("u", "m")
            .return_edge("r")
            .all_with_edges()
        )
        for user, rated_edge, movie in rows:
            print(f"{user.name} rated {movie.title} with {rated_edge.score}")
    """
    self._check_bound()
    query_text, bindings = self.build()
    log.debug("QueryBuilder.all_with_edges: %s", query_text)
    raw_result = self._session.execute(query_text, bindings)
    return self._decode_edge_result(raw_result)
def all_rows(self) -> list[dict[str, Any]]:
    """Execute and return one column-keyed dict per result row.

    Useful for multi-alias returns, aggregations, or scalar projections
    where mixed types are in the result set::

        rows = q.aggregate(count("*").as_("n"), group_by="u").all_rows()
        # [{"u": <User>, "n": 5}, ...]

    Raises
    ------
    RuntimeError
        If the statement is not bound to a session.
    """
    self._check_bound()
    query_text, bindings = self.build()
    log.debug("QueryBuilder.all_rows: %s", query_text)
    raw_result = self._session.execute(query_text, bindings)
    return self._decode_rows_as_dicts(raw_result)
def count(self) -> int:
    """Execute a ``count(*)`` variant of the query and return the count.

    The RETURN specification (aggregations, grouping alias, explicit return
    aliases, projection) is temporarily replaced by ``count(*) AS _count``.
    :meth:`limit`, :meth:`skip` and :meth:`order_by` settings are ignored
    for this execution: paging would truncate or drop the single count row,
    and ordering terms refer to variables that are out of scope after the
    aggregation. All overridden state is restored afterwards, also when
    execution raises.

    Returns
    -------
    int
        The first column of the first row, or ``0`` when no row came back.

    Raises
    ------
    RuntimeError
        If the statement is not bound to a session.
    """
    self._check_bound()

    from runic.orm.query.expressions import count as _count_fn

    saved_agg = self._agg_exprs
    saved_group = self._group_by_alias
    saved_return = self._return_aliases
    saved_project = self._project_fields
    saved_order = self._order
    saved_skip = self._skip_val
    saved_limit = self._limit_val

    self._agg_exprs = [_count_fn("*").as_("_count")]
    self._group_by_alias = None
    self._return_aliases = None
    self._project_fields = []
    self._order = []
    self._skip_val = None
    self._limit_val = None
    try:
        cypher, params = self.build()
        log.debug("QueryBuilder.count: %s", cypher)
        result = self._session.execute(cypher, params)
    finally:
        # Restore the caller's configuration even when execution fails.
        self._agg_exprs = saved_agg
        self._group_by_alias = saved_group
        self._return_aliases = saved_return
        self._project_fields = saved_project
        self._order = saved_order
        self._skip_val = saved_skip
        self._limit_val = saved_limit

    if result.rows:
        return int(result.rows[0][0])
    return 0
def scalar(self) -> Any:
    """Execute and return the first column of the first row, or ``None``.

    ``None`` is returned when there are no rows or the first row is empty.

    Raises
    ------
    RuntimeError
        If the statement is not bound to a session.
    """
    self._check_bound()
    query_text, bindings = self.build()
    rows = self._session.execute(query_text, bindings).rows
    if not rows:
        return None
    first_row = rows[0]
    return first_row[0] if first_row else None
def scalars(self) -> list[Any]:
    """Execute and return the first column of every row as a flat list.

    Raises
    ------
    RuntimeError
        If the statement is not bound to a session.
    """
    self._check_bound()
    query_text, bindings = self.build()
    raw_result = self._session.execute(query_text, bindings)
    first_column: list[Any] = []
    for row in raw_result.rows:
        first_column.append(row[0])
    return first_column
# ------------------------------------------------------------------ # Internal: traversal registration (called by TraversalStep.alias) # ------------------------------------------------------------------
def register_traversal(
    self,
    fd: FieldDescriptor,
    source_alias: str,
    target_alias: str,
    *,
    optional: bool,
    edge_alias: str | None,
    min_hops: int,
    max_hops: int | None,
) -> QueryBuilder[T]:
    """Append the MATCH clause for one traversal step and register aliases.

    Called by :meth:`TraversalStep.alias` to complete a traversal step.
    The target alias becomes the builder's most recent alias.
    """
    # ── Target class and label ──────────────────────────────────────
    raw_target = fd.target
    if isinstance(raw_target, str):
        # String targets are resolved through the metadata registry.
        target_cls = self._meta.resolve_target(raw_target)
    else:
        target_cls = raw_target

    if target_cls is not None:
        node_meta = self._meta.get_node_meta(target_cls)
        target_label = node_meta.primary_label if node_meta else target_cls.__name__
    else:
        target_label = str(raw_target) if raw_target else "Node"

    # ── Relationship part of the pattern ────────────────────────────
    rel_type = fd.relationship or "REL"
    direction = fd.direction or "OUTGOING"

    exactly_one_hop = min_hops == 1 and max_hops == 1
    if exactly_one_hop:
        quantifier = ""
    elif max_hops is None:
        quantifier = f"*{min_hops}.."
    else:
        quantifier = f"*{min_hops}..{max_hops}"

    rel_var = edge_alias if edge_alias else ""
    rel_part = f"[{rel_var}:{rel_type}{quantifier}]"
    target_part = f"({target_alias}:{target_label})"

    # Any direction other than OUTGOING / INCOMING yields an undirected
    # pattern.
    if direction == "OUTGOING":
        left, right = "-", "->"
    elif direction == "INCOMING":
        left, right = "<-", "-"
    else:
        left, right = "-", "-"
    pattern = f"({source_alias}){left}{rel_part}{right}{target_part}"

    self._match_clauses.append(_MatchClause(pattern, optional=optional))

    # ── Alias bookkeeping ───────────────────────────────────────────
    if target_cls is not None:
        self._set_alias(target_alias, target_cls)

    if edge_alias is not None:
        edge_cls = fd.edge_model
        if isinstance(edge_cls, str):
            edge_cls = self._meta.resolve_target(edge_cls)
        if edge_cls is not None:
            self._set_alias(edge_alias, edge_cls)

    self._last_alias = target_alias
    return self
# ------------------------------------------------------------------
# Internal: Cypher expression compilation
# ------------------------------------------------------------------

def _compile_expr(self, expr: Expr) -> str:
    """Recursively compile an Expr tree to a Cypher predicate string.

    Raises
    ------
    TypeError
        If *expr* is not a filter, compound, or negated expression.
    """
    if isinstance(expr, FilterExpr):
        return self._compile_filter(expr)
    if isinstance(expr, CompoundExpr):
        # Every operand is parenthesised so nested AND/OR keep their grouping.
        parts = [f"({self._compile_expr(op)})" for op in expr.operands]
        return f" {expr.op} ".join(parts)
    if isinstance(expr, NegatedExpr):
        return f"NOT ({self._compile_expr(expr.operand)})"
    raise TypeError(f"Unsupported expression type: {type(expr)!r}")


def _compile_filter(self, expr: FilterExpr) -> str:
    """Compile a single FilterExpr to a Cypher condition string.

    The comparison value is stored as a query parameter; only the
    ``$pN`` reference appears in the returned string.
    """
    alias = expr.alias or self._alias_for_cls(expr.cls)

    # Null checks have no parameter
    if expr.op == "IS NULL":
        return f"{alias}.{expr.prop} IS NULL"
    if expr.op == "IS NOT NULL":
        return f"{alias}.{expr.prop} IS NOT NULL"

    # Look up converter for this field
    fi = self._find_field_info(expr.cls, expr.prop)
    converter = fi.field.converter if fi is not None else None

    # Convert value to graph representation
    # NOTE(review): for IN / NOT IN the whole collection is handed to
    # converter.to_graph — confirm converters accept sequences.
    param_value = expr.value
    if converter is not None and param_value is not None:
        param_value = converter.to_graph(param_value)

    param_name = self._next_param(param_value)

    # Wrap param ref with cypher_fn if needed (dialect-aware)
    _d = self._dialect
    cypher_fn = (
        _d.cypher_fn_for_field(fi) if (fi is not None and _d is not None) else None
    )
    param_ref = f"{cypher_fn}(${param_name})" if cypher_fn else f"${param_name}"

    if expr.op in ("IN", "NOT IN"):
        # Membership tests use the bare parameter (no cypher_fn wrapper).
        prefix = "NOT " if (expr.negate or expr.op == "NOT IN") else ""
        return f"{prefix}{alias}.{expr.prop} IN ${param_name}"

    if expr.negate:
        return f"NOT ({alias}.{expr.prop} {expr.op} {param_ref})"

    return f"{alias}.{expr.prop} {expr.op} {param_ref}"


def _compile_return(self) -> str:
    """Compile the RETURN clause.

    Precedence: aggregation, then scalar projection, then explicit return
    aliases (with the edge alias inserted at index 1 when it is not listed
    explicitly), then the most recent alias.
    """
    distinct_kw = "DISTINCT " if self._distinct else ""

    # Aggregation mode
    if self._agg_exprs:
        cls_to_alias: dict[type, str] = {
            cls: aliases[0] for cls, aliases in self._cls_aliases.items() if aliases
        }
        agg_parts = [e.to_cypher(cls_to_alias) for e in self._agg_exprs]
        if self._group_by_alias:
            return f"RETURN {distinct_kw}{self._group_by_alias}, {', '.join(agg_parts)}"
        return f"RETURN {distinct_kw}{', '.join(agg_parts)}"

    # Scalar projection
    if self._project_fields:
        proj_parts: list[str] = []
        for f in self._project_fields:
            if isinstance(f, FieldDescriptor):
                alias = self._alias_for_cls(f.owner) if f.owner else self._root_alias
                proj_parts.append(f"{alias}.{f.field_name}")
            else:
                proj_parts.append(str(f))
        return f"RETURN {distinct_kw}{', '.join(proj_parts)}"

    # Explicit return aliases
    if self._return_aliases:
        all_parts = list(self._return_aliases)
        if (
            self._edge_alias_for_result
            and self._edge_alias_for_result not in all_parts
        ):
            # Insert edge between the two node aliases
            all_parts.insert(1, self._edge_alias_for_result)
        return f"RETURN {distinct_kw}{', '.join(all_parts)}"

    # Default: return last alias
    return f"RETURN {distinct_kw}{self._last_alias}"


# ------------------------------------------------------------------
# Internal: result decoding
# ------------------------------------------------------------------

def _decode_node_result(self, result: Any) -> list[T]:
    """Decode a single-column node result into ORM entities.

    Rows whose first column is ``None`` are skipped.
    """
    mapper = self._session.mapper
    register = self._session.register_or_get

    # Determine the target class for decoding
    return_alias = (
        self._return_aliases[0] if self._return_aliases else self._last_alias
    )
    target_cls = self._alias_map.get(return_alias, self._root_cls)

    entities: list[T] = []
    for row in result.rows:
        val = row[0]
        if val is None:
            continue
        decoded = mapper.decode_node(val, target_cls)
        entities.append(register(decoded))
    return entities


def _decode_edge_result(self, result: Any) -> list[tuple[Any, ...]]:
    """Decode a multi-column result into (NodeA, EdgeModel, NodeB) tuples.

    The column layout is derived exactly the way ``_compile_return`` lays
    out explicit return aliases: the edge alias is inserted at index 1
    unless it is already listed. ``None`` cells are kept as ``None``.
    """
    mapper = self._session.mapper
    register = self._session.register_or_get
    edge_alias = self._edge_alias_for_result

    # Build ordered column list matching the RETURN clause
    ordered = list(self._return_aliases or [])
    if ordered:
        if edge_alias and edge_alias not in ordered:
            ordered.insert(1, edge_alias)
    else:
        ordered = [self._last_alias]
    columns: list[tuple[str, bool]] = [
        (name, bool(edge_alias) and name == edge_alias) for name in ordered
    ]

    tuples: list[tuple[Any, ...]] = []
    for row in result.rows:
        decoded_row: list[Any] = []
        for col_idx, (col_alias, is_edge) in enumerate(columns):
            # Short rows are padded with None rather than raising.
            val = row[col_idx] if col_idx < len(row) else None
            if val is None:
                decoded_row.append(None)
                continue
            if is_edge:
                edge_cls = self._alias_map.get(col_alias)
                decoded_row.append(mapper.decode_edge(val, edge_cls))
            else:
                node_cls = self._alias_map.get(col_alias, self._root_cls)
                decoded = mapper.decode_node(val, node_cls)
                decoded_row.append(register(decoded))
        tuples.append(tuple(decoded_row))
    return tuples


def _decode_rows_as_dicts(self, result: Any) -> list[dict[str, Any]]:
    """Decode a multi-column result into column-keyed dicts.

    A value is decoded only when its column name is a registered alias;
    every other value is passed through unchanged.
    """
    mapper = self._session.mapper
    register = self._session.register_or_get
    header = result.columns

    rows: list[dict[str, Any]] = []
    for row in result.rows:
        d: dict[str, Any] = {}
        for i, val in enumerate(row):
            # Columns beyond the header are keyed by their position.
            col_name = header[i] if i < len(header) else str(i)
            cls = self._alias_map.get(col_name)
            if cls is not None and val is not None:
                # Check if this is a Node class (has NodeMeta)
                node_meta = self._meta.get_node_meta(cls)
                edge_meta = self._meta.get_edge_meta(cls)
                if node_meta is not None:
                    val = register(mapper.decode_node(val, cls))
                elif edge_meta is not None:
                    val = mapper.decode_edge(val, cls)
            d[col_name] = val
        rows.append(d)
    return rows


# ------------------------------------------------------------------
# Internal: helpers
# ------------------------------------------------------------------

def _set_alias(self, alias: str, cls: type) -> None:
    """Register *alias* for *cls* in both the forward and inverse tables."""
    self._alias_map[alias] = cls
    aliases = self._cls_aliases.setdefault(cls, [])
    if alias not in aliases:
        aliases.append(alias)


def _alias_for_cls(self, cls: type) -> str:
    """Return the first registered Cypher alias for *cls*.

    Falls back to the root alias when *cls* is the root class, and to the
    most recent alias otherwise.
    """
    aliases = self._cls_aliases.get(cls)
    if aliases:
        return aliases[0]
    # Fallback: if cls is the root, return root alias
    if cls is self._root_cls:
        return self._root_alias
    return self._last_alias


def _next_param(self, value: Any) -> str:
    """Allocate a new positional parameter, store value, return name."""
    name = f"p{self._param_counter}"
    self._param_counter += 1
    self._params[name] = value
    return name


def _find_field_info(self, cls: type, prop: str) -> FieldInfo | None:
    """Look up a FieldInfo by class and property name.

    Node metadata is consulted first, then edge metadata; ``None`` is
    returned when neither knows the property.
    """
    node_meta = self._meta.get_node_meta(cls)
    if node_meta:
        return next((fi for fi in node_meta.fields if fi.name == prop), None)
    edge_meta = self._meta.get_edge_meta(cls)
    if edge_meta:
        return next((fi for fi in edge_meta.fields if fi.name == prop), None)
    return None


def _split_where_exprs(self) -> tuple[list[Expr], list[Expr]]:
    """Split WHERE expressions into root-targeting and post-traversal groups.

    Root expressions reference only the root alias (or its class) and are
    safe to emit between the root MATCH and any OPTIONAL MATCH clauses.
    Post-traversal expressions reference traversal targets or edges and
    must come after all MATCH/OPTIONAL MATCH clauses.
    """
    root: list[Expr] = []
    post: list[Expr] = []
    for expr in self._where_exprs:
        (root if self._expr_targets_root_only(expr) else post).append(expr)
    return root, post


def _expr_targets_root_only(self, expr: Expr) -> bool:
    """Return True if *expr* references only the root Cypher alias."""
    if isinstance(expr, FilterExpr):
        if expr.alias is not None:
            return expr.alias == self._root_alias
        # No explicit alias: resolve via class lookup
        return self._alias_for_cls(expr.cls) == self._root_alias
    if isinstance(expr, CompoundExpr):
        return all(self._expr_targets_root_only(op) for op in expr.operands)
    if isinstance(expr, NegatedExpr):
        return self._expr_targets_root_only(expr.operand)
    # Unknown expression kinds are treated as post-traversal.
    return False