"""Specialised QueryBuilder subclasses: async, fulltext, and vector variants.
These extend :class:`~runic.ogm.query.builder.QueryBuilder` with either an
async execution model (:class:`AsyncQueryBuilder`) or a specialised root
clause that replaces the standard ``MATCH`` (fulltext ``CALL`` and vector KNN).
"""
from __future__ import annotations
import logging
from typing import Any, TypeVar
from runic.ogm.core.descriptors import FieldDescriptor
from runic.ogm.query.builder import QueryBuilder
from runic.ogm.query.expressions import CompoundExpr
log = logging.getLogger(__name__)
T = TypeVar("T")
# ---------------------------------------------------------------------------
# AsyncQueryBuilder
# ---------------------------------------------------------------------------
[docs]
class AsyncQueryBuilder(QueryBuilder[T]): # noqa: UP046
"""Async variant of :class:`QueryBuilder` for use with
:class:`~runic.ogm.session.async_session.AsyncSession`.
All intermediate (chainable) methods are identical to the sync version.
Only the **terminal** methods are replaced with ``async def`` equivalents.
Example
-------
.. code-block:: python
async with AsyncSession(graph) as session:
users = await (
session.query(User)
.where(User.active == True)
.order_by(User.name)
.limit(50)
.all()
)
"""
[docs]
async def all(self) -> list[T]: # type: ignore[override] # ty: ignore[invalid-method-override]
"""Async version of :meth:`~QueryBuilder.all`."""
cypher, params = self.build()
log.debug("AsyncQueryBuilder.all: %s", cypher)
result = await self._session.execute(cypher, params)
return self._decode_node_result(result)
[docs]
async def one(self) -> T | None: # type: ignore[override] # ty: ignore[invalid-method-override]
"""Async version of :meth:`~QueryBuilder.one`."""
self.limit(1)
items = await self.all()
return items[0] if items else None
[docs]
async def all_with_edges(self) -> list[tuple[Any, ...]]: # type: ignore[override] # ty: ignore[invalid-method-override]
"""Async version of :meth:`~QueryBuilder.all_with_edges`."""
cypher, params = self.build()
log.debug("AsyncQueryBuilder.all_with_edges: %s", cypher)
result = await self._session.execute(cypher, params)
return self._decode_edge_result(result)
[docs]
async def all_rows(self) -> list[dict[str, Any]]: # type: ignore[override] # ty: ignore[invalid-method-override]
"""Async version of :meth:`~QueryBuilder.all_rows`."""
cypher, params = self.build()
log.debug("AsyncQueryBuilder.all_rows: %s", cypher)
result = await self._session.execute(cypher, params)
return self._decode_rows_as_dicts(result)
[docs]
async def count(self) -> int: # type: ignore[override] # ty: ignore[invalid-method-override]
"""Async version of :meth:`~QueryBuilder.count`."""
saved_agg = self._agg_exprs
saved_group = self._group_by_alias
saved_return = self._return_aliases
saved_project = self._project_fields
from runic.ogm.query.expressions import count as _count_fn
self._agg_exprs = [_count_fn("*").as_("_count")]
self._group_by_alias = None
self._return_aliases = None
self._project_fields = []
cypher, params = self.build()
log.debug("AsyncQueryBuilder.count: %s", cypher)
result = await self._session.execute(cypher, params)
self._agg_exprs = saved_agg
self._group_by_alias = saved_group
self._return_aliases = saved_return
self._project_fields = saved_project
if result.rows:
return int(result.rows[0][0])
return 0
[docs]
async def scalar(self) -> Any: # type: ignore[override]
"""Async version of :meth:`~QueryBuilder.scalar`."""
result = await self._session.execute(*self.build())
if result.rows and result.rows[0]:
return result.rows[0][0]
return None
[docs]
async def scalars(self) -> list[Any]: # type: ignore[override] # ty: ignore[invalid-method-override]
"""Async version of :meth:`~QueryBuilder.scalars`."""
result = await self._session.execute(*self.build())
return [row[0] for row in result.rows]
# ---------------------------------------------------------------------------
# FulltextQueryBuilder
# ---------------------------------------------------------------------------
[docs]
class FulltextQueryBuilder(QueryBuilder[T]): # noqa: UP046
"""QueryBuilder variant for FalkorDB fulltext search queries.
Constructed via :meth:`~runic.ogm.session.session.Session.fulltext_search`.
The root MATCH is replaced with a ``CALL db.idx.fulltext.queryNodes(...)``
invocation that uses the declared fulltext index.
The fulltext index must have been created for the node's label, e.g.::
class Post(Node, labels=["Post"]):
title: str = Field(index_type="FULLTEXT")
Example
-------
.. code-block:: python
posts = (
session.fulltext_search(Post, query="graph databases", fields=["title"])
.where(Post.published == True)
.order_by(Post.created_at, desc=True)
.limit(20)
.all()
)
Cypher emitted::
CALL db.idx.fulltext.queryNodes('Post', $__fts_query) YIELD node AS n
WHERE n.published = $p0
RETURN n
ORDER BY n.created_at DESC
LIMIT 20
"""
def __init__(
self,
session: Any,
root_cls: type[T],
query: str,
fields: list[str] | None = None,
) -> None:
super().__init__(session, root_cls)
self._fts_query = query
self._fts_fields = fields
[docs]
def build(self) -> tuple[str, dict[str, Any]]:
"""Compile to Cypher, replacing MATCH with CALL fulltext procedure."""
self._param_counter = 0
self._params = {"__fts_query": self._fts_query}
root_meta = self._meta.get_node_meta(self._root_cls)
if root_meta is None:
raise ValueError(
f"Class {self._root_cls.__name__!r} is not a registered Node subclass"
)
alias = self._root_alias
label = root_meta.primary_label
parts: list[str] = [self._dialect.fulltext_call(label, alias, "__fts_query")]
# Extra OPTIONAL MATCHes for traversals (root WHERE + WITH go first)
if self._where_exprs and self._match_clauses:
root_exprs, post_exprs = self._split_where_exprs()
else:
root_exprs = []
post_exprs = self._where_exprs
if root_exprs:
cond = self._compile_expr(
root_exprs[0]
if len(root_exprs) == 1
else CompoundExpr(op="AND", operands=root_exprs)
)
parts.append(f"WHERE {cond}")
if self._with_vars:
parts.append(f"WITH {', '.join(self._with_vars)}")
parts.extend(mc.to_cypher() for mc in self._match_clauses)
if post_exprs:
cond = self._compile_expr(
post_exprs[0]
if len(post_exprs) == 1
else CompoundExpr(op="AND", operands=post_exprs)
)
parts.append(f"WHERE {cond}")
parts.append(self._compile_return())
if self._order:
parts.append(f"ORDER BY {', '.join(o.to_cypher() for o in self._order)}")
if self._skip_val is not None:
parts.append(f"SKIP {self._skip_val}")
if self._limit_val is not None:
parts.append(f"LIMIT {self._limit_val}")
return "\n".join(parts), dict(self._params)
# ---------------------------------------------------------------------------
# VectorQueryBuilder
# ---------------------------------------------------------------------------
[docs]
class VectorQueryBuilder(QueryBuilder[T]): # noqa: UP046
"""QueryBuilder variant for vector KNN search.
Constructed via :meth:`~runic.ogm.session.session.Session.vector_search`.
Appends a KNN distance expression to the ORDER BY and RETURN clauses.
The field must have ``index_type="VECTOR"`` and an HNSW vector index
must be created via :meth:`~runic.migrate.schema.SchemaManager`::
class Document(Node, labels=["Document"]):
embedding: Vector = Field(index_type="VECTOR")
Example
-------
.. code-block:: python
similar = (
session.vector_search(
Document,
field=Document.embedding,
vector=[0.1, 0.2, 0.3],
k=10,
)
.where(Document.active == True)
.all()
)
Cypher emitted (FalkorDB KNN syntax)::
MATCH (n:Document)
WHERE n.active = $p0
RETURN n, vecf32(n.embedding) <-> vecf32($__knn_vec) AS __score
ORDER BY __score ASC
LIMIT 10
"""
def __init__(
self,
session: Any,
root_cls: type[T],
field: FieldDescriptor,
vector: list[float],
k: int,
) -> None:
super().__init__(session, root_cls)
self._knn_field = field
self._knn_vector = vector
self._knn_k = k
[docs]
def build(self) -> tuple[str, dict[str, Any]]:
"""Compile to Cypher with KNN ORDER BY."""
self._param_counter = 0
self._params = {"__knn_vec": list(self._knn_vector)}
root_meta = self._meta.get_node_meta(self._root_cls)
if root_meta is None:
raise ValueError(
f"Class {self._root_cls.__name__!r} is not a registered Node subclass"
)
_lc = getattr(self._dialect, "labels_clause", None)
labels_str = _lc(root_meta.labels) if _lc else ":".join(root_meta.labels)
alias = self._root_alias
field_alias = (
self._alias_for_cls(self._knn_field.owner)
if self._knn_field.owner
else self._root_alias
)
field_name = self._knn_field.field_name
type_name = root_meta.primary_label
self._params["__knn_k"] = self._knn_k
parts: list[str] = [
self._dialect.vector_knn_start(alias, labels_str, type_name, field_name)
]
if self._where_exprs and self._match_clauses:
root_exprs, post_exprs = self._split_where_exprs()
else:
root_exprs = []
post_exprs = self._where_exprs
if root_exprs:
cond = self._compile_expr(
root_exprs[0]
if len(root_exprs) == 1
else CompoundExpr(op="AND", operands=root_exprs)
)
parts.append(f"WHERE {cond}")
if self._with_vars:
parts.append(f"WITH {', '.join(self._with_vars)}")
parts.extend(mc.to_cypher() for mc in self._match_clauses)
if post_exprs:
cond = self._compile_expr(
post_exprs[0]
if len(post_exprs) == 1
else CompoundExpr(op="AND", operands=post_exprs)
)
parts.append(f"WHERE {cond}")
# KNN return includes the distance score
return_part = self._compile_return()
if "RETURN" in return_part and "__score" not in return_part:
score_expr = self._dialect.vector_knn_score_expr(field_alias, field_name)
return_part = return_part + f", {score_expr}"
parts.append(return_part)
# KNN ordering: always by score ASC, then any user orders
knn_order = "ORDER BY __score ASC"
if self._order:
user_order = ", ".join(o.to_cypher() for o in self._order)
parts.append(f"{knn_order}, {user_order}")
else:
parts.append(knn_order)
if self._skip_val is not None:
parts.append(f"SKIP {self._skip_val}")
# k overrides limit if no explicit limit was set
effective_limit = (
self._limit_val if self._limit_val is not None else self._knn_k
)
parts.append(f"LIMIT {effective_limit}")
return "\n".join(parts), dict(self._params)