"""Apache AGE driver, dialect, and result wrappers.
Apache AGE is a PostgreSQL extension that adds graph database capabilities
and supports openCypher queries. Cypher is executed via the ``cypher()``
SQL function:
SELECT * FROM cypher('graph', $$ CYPHER $$ [, params::agtype])
AS (col0 agtype, ...);
This driver uses ``psycopg`` (psycopg3) for the PostgreSQL connection.
Parameters are serialised as a JSON / agtype map and passed as the third
argument to ``cypher()``, which makes them available inside the Cypher
query as ``$param_name`` references (identical to the runic ORM's ``$p0``
convention).
"""
from __future__ import annotations
import json
import logging
import re
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from runic.orm.core.descriptors import FieldInfo
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Agtype data structures
# ---------------------------------------------------------------------------
class _AGEVertexData:
    """Plain record holding the fields of one parsed agtype vertex.

    Attributes are limited to ``id``, ``label`` and ``properties`` via
    ``__slots__``; no validation is performed on the values.
    """

    __slots__ = ("id", "label", "properties")

    def __init__(
        self,
        id: Any,  # noqa: A002
        label: str,
        properties: dict[str, Any],
    ) -> None:
        # Store the three fields exactly as given.
        self.id, self.label, self.properties = id, label, properties
class _AGEEdgeData:
    """Plain record holding the fields of one parsed agtype edge.

    Carries the edge's own ``id`` and ``label``, the ``start_id`` /
    ``end_id`` of its endpoints, and its ``properties`` mapping.  Values are
    stored as given, without validation.
    """

    __slots__ = ("end_id", "id", "label", "properties", "start_id")

    def __init__(
        self,
        id: Any,  # noqa: A002
        label: str,
        start_id: Any,
        end_id: Any,
        properties: dict[str, Any],
    ) -> None:
        # Identity of the edge itself.
        self.id, self.label = id, label
        # Endpoints of the edge.
        self.start_id, self.end_id = start_id, end_id
        self.properties = properties
def _parse_agtype(text: str) -> Any:
    """Parse an AGE agtype string representation into a Python value.

    Supported shapes:

    * ``{...}::vertex`` -> :class:`_AGEVertexData`
    * ``{...}::edge`` -> :class:`_AGEEdgeData`
    * ``[elem, ...]`` -> ``list`` (each element parsed recursively; elements
      may carry their own ``::type`` tag).  A tag on the array itself, such
      as ``[...]::path``, is accepted and the element list is returned.
    * ``<json>::<other tag>`` -> the decoded JSON body (tag ignored)
    * anything else -> decoded as plain JSON

    A ``::tag`` suffix is only recognised when it is a bare identifier at the
    very end of the text, so ``::`` occurring *inside* a JSON string or map
    value (for example ``"a::b"``) is left untouched.

    Raises:
        json.JSONDecodeError: if the (untagged) body is not valid JSON.
    """
    text = text.strip()

    # Peel off a trailing "::identifier" type annotation, if any.  JSON
    # strings end in '"' and maps in '}', so this cannot match "::" that is
    # merely part of a value.
    tagged = re.fullmatch(r"(.*)::([A-Za-z_]\w*)", text, re.DOTALL)
    if tagged:
        body = tagged.group(1).strip()
        type_tag = tagged.group(2)
    else:
        body = text
        type_tag = ""

    # Arrays: elements are split at top-level commas because they may be
    # tagged (``{...}::vertex``) and therefore are not valid JSON as a whole.
    if body.startswith("["):
        inner = (body[1:-1] if body.endswith("]") else body[1:]).strip()
        if not inner:
            return []
        return [
            _parse_agtype(element)
            for element in _split_agtype_array_elements(inner)
            if element.strip()
        ]

    if type_tag == "vertex":
        data: dict[str, Any] = json.loads(body)
        return _AGEVertexData(
            id=data.get("id"),
            label=data.get("label", ""),
            properties=data.get("properties") or {},
        )

    if type_tag == "edge":
        data = json.loads(body)
        return _AGEEdgeData(
            id=data.get("id"),
            label=data.get("label", ""),
            start_id=data.get("start_id"),
            end_id=data.get("end_id"),
            properties=data.get("properties") or {},
        )

    # Untagged value, or an unknown tag: return the decoded JSON body.
    return json.loads(body)
def _serialize_param(val: Any) -> Any:
    """Fallback serialiser for values ``json.dumps()`` cannot encode natively.

    Intended to be passed as the ``default=`` hook of ``json.dumps``.

    * ``datetime`` / ``date`` / ``time`` -> ISO-8601 string (``isoformat()``)
    * ``Enum`` members -> their ``.value``

    Raises:
        TypeError: for any other type, naming the offending type.
    """
    from datetime import date, time
    from enum import Enum

    # ``datetime`` is a subclass of ``date``, so it is covered here too.
    if isinstance(val, (date, time)):
        return val.isoformat()
    if isinstance(val, Enum):
        return val.value
    raise TypeError(f"Cannot serialise {type(val).__name__!r} to agtype")
# ---------------------------------------------------------------------------
# GraphNode / GraphEdge wrappers
# ---------------------------------------------------------------------------
class AGENode:
    """Adapter exposing an :class:`_AGEVertexData` through the ``GraphNode`` interface."""

    __slots__ = ("_raw",)

    def __init__(self, raw: _AGEVertexData) -> None:
        self._raw = raw

    @property
    def element_id(self) -> Any:
        """The id of the wrapped vertex."""
        return self._raw.id

    @property
    def labels(self) -> list[str]:
        """Labels of the vertex.

        When the vertex carries a ``_labels`` property that is a list, that
        list is returned; otherwise the single vertex label is returned.
        """
        emulated = self._raw.properties.get("_labels")
        if isinstance(emulated, list):
            return emulated
        return [self._raw.label]

    @property
    def properties(self) -> dict[str, Any]:
        """A shallow copy of the vertex properties."""
        return {**self._raw.properties}
class AGEEdge:
    """Adapter exposing an :class:`_AGEEdgeData` through the ``GraphEdge`` interface."""

    __slots__ = ("_raw",)

    def __init__(self, raw: _AGEEdgeData) -> None:
        self._raw = raw

    @property
    def properties(self) -> dict[str, Any]:
        """A shallow copy of the edge properties."""
        return {**self._raw.properties}

    @property
    def type(self) -> str:
        """The relationship type, i.e. the label of the wrapped edge."""
        return self._raw.label
# ---------------------------------------------------------------------------
# GraphResult wrapper
# ---------------------------------------------------------------------------
class AGEResult:
    """Fully materialised AGE query result conforming to ``GraphResult``.

    Holds the row lists and the column names exactly as passed in; both are
    exposed without copying.
    """

    __slots__ = ("_columns", "_rows")

    def __init__(self, rows: list[list[Any]], columns: list[str]) -> None:
        self._columns = columns
        self._rows = rows

    @property
    def columns(self) -> list[str]:
        """Column names of the result."""
        return self._columns

    @property
    def rows(self) -> list[list[Any]]:
        """Result rows, one list of values per row."""
        return self._rows
# ---------------------------------------------------------------------------
# RETURN-clause parser (builds the AGE AS (...) column list)
# ---------------------------------------------------------------------------
def _split_at_top_level_commas(expr: str) -> list[str]:
    """Split *expr* at commas that are not nested inside brackets or quotes.

    Commas inside ``()``, ``[]`` or ``{}`` and commas inside single-,
    double- or backtick-quoted text are kept as part of the current item, so
    list literals (``[a, b]``), map literals (``{k: 1, j: 2}``) and string
    literals (``'x,y'``) are not broken apart.

    Items are returned unstripped.  A trailing empty item (for example after
    a final comma) is dropped; an empty *expr* yields ``[]``.
    """
    parts: list[str] = []
    current: list[str] = []
    depth = 0
    quote = ""  # the quote character currently open, or "" when outside quotes
    escaped = False  # True when the previous quoted character was a backslash

    for ch in expr:
        if quote:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = ""
        elif ch in "'\"`":
            quote = ch
        elif ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        elif ch == "," and depth == 0:
            # Top-level separator: close the current item, drop the comma.
            parts.append("".join(current))
            current = []
            continue
        current.append(ch)

    if current:
        parts.append("".join(current))
    return parts
def _split_agtype_array_elements(text: str) -> list[str]:
    """Split the inside of an agtype array at its top-level commas.

    Nesting through ``{}``, ``[]`` and ``()`` is tracked, as are
    double-quoted strings (including backslash escapes), so commas inside
    nested JSON objects or string values never act as separators.

    Elements are returned unstripped; a trailing empty element is dropped.
    """
    elements: list[str] = []
    buffer: list[str] = []
    nesting = 0
    inside_quotes = False
    skip_escape = False

    for char in text:
        if skip_escape:
            # Character following a backslash inside a string: taken literally.
            skip_escape = False
        elif inside_quotes:
            if char == "\\":
                skip_escape = True
            elif char == '"':
                inside_quotes = False
        elif char == '"':
            inside_quotes = True
        elif char in "{[(":
            nesting += 1
        elif char in "}])":
            nesting -= 1
        elif char == "," and nesting == 0:
            # Separator between elements: flush the buffer, drop the comma.
            elements.append("".join(buffer))
            buffer = []
            continue
        buffer.append(char)

    if buffer:
        elements.append("".join(buffer))
    return elements
def _parse_return_columns(cypher: str) -> list[str]:
    """Derive SQL column names from the last Cypher RETURN clause.

    The names are used to build the ``AS (col0 agtype, ...)`` column
    definition list of the AGE ``cypher()`` call.

    Handles: simple alias (``RETURN n``), multi-alias (``RETURN n, m``),
    property projections (``RETURN n.name``), explicit ``AS`` aliases
    (``RETURN count(*) AS cnt``), ``DISTINCT``, RETURN on the same line as
    the preceding clause, and RETURN item lists that span several lines.

    Naming rules per item, in order: explicit ``AS`` alias; property name for
    ``var.prop``; the identifier itself; otherwise positional ``col<i>``.
    A name that repeats an earlier one (compared case-insensitively) gets
    ``_<i>`` appended so the generated column list has no duplicates.

    Returns:
        The column names, or ``["result"]`` when no RETURN clause (or no
        RETURN items) can be found.
    """
    # Take everything after the LAST "RETURN" keyword in the whole query.
    # Searching the full text (not line by line) keeps items that continue
    # on the following lines.
    return_starts = list(re.finditer(r"\bRETURN\s+", cypher, re.IGNORECASE))
    if not return_starts:
        return ["result"]
    return_expr = cypher[return_starts[-1].end():].strip()
    if not return_expr:
        return ["result"]

    # Strip trailing ORDER BY / SKIP / LIMIT / UNION clauses before parsing columns.
    return_expr = re.split(
        r"\bORDER\s+BY\b|\bSKIP\b|\bLIMIT\b|\bUNION\b",
        return_expr,
        flags=re.IGNORECASE,
    )[0].rstrip()
    # Strip DISTINCT keyword
    return_expr = re.sub(r"^DISTINCT\s+", "", return_expr, flags=re.IGNORECASE)

    cols: list[str] = []
    taken: set[str] = set()  # lower-cased names already emitted
    for i, raw_item in enumerate(_split_at_top_level_commas(return_expr)):
        item = raw_item.strip()

        as_m = re.search(r"\bAS\s+(\w+)\s*$", item, re.IGNORECASE)
        dot_m = re.match(r"^\w+\.(\w+)$", item)
        id_m = re.match(r"^(\w+)$", item)
        if as_m:
            # Explicit AS alias: "expr AS alias"
            name = as_m.group(1)
        elif dot_m:
            # Property access: "n.prop"
            name = dot_m.group(1)
        elif id_m:
            # Simple identifier: "n"
            name = id_m.group(1)
        else:
            # Fallback: positional name
            name = f"col{i}"

        # Make the name unique within the column list.
        if name.lower() in taken:
            name = f"{name}_{i}"
            while name.lower() in taken:
                name += "_"
        taken.add(name.lower())
        cols.append(name)

    return cols or ["result"]
# ---------------------------------------------------------------------------
# AGE connection setup helpers
# ---------------------------------------------------------------------------
def _setup_age_connection(conn: Any, graph_name: str) -> None:
    """Prepare *conn* for AGE use and make sure *graph_name* exists.

    Steps, all on a single cursor:

    1. ``LOAD 'age'`` and put ``ag_catalog`` first on the ``search_path``.
    2. Look up the OID of the ``agtype`` type and, when found, register a
       psycopg loader that decodes agtype values with :func:`_parse_agtype`.
    3. Create the graph when ``ag_graph`` has no row for *graph_name*.

    The connection is committed at the end.
    """
    from psycopg.adapt import Loader

    class _AgtypeLoader(Loader):
        """psycopg loader turning raw agtype bytes into Python values."""

        def load(self, data: bytes | bytearray | memoryview) -> Any:
            # bytes() accepts all three buffer kinds, so one decode path suffices.
            return _parse_agtype(bytes(data).decode("utf-8"))

    with conn.cursor() as cur:
        cur.execute("LOAD 'age'")
        cur.execute('SET search_path = ag_catalog, "$user", public')

        # Register the agtype loader, keyed by the type's OID in this database.
        cur.execute("SELECT oid FROM pg_type WHERE typname = 'agtype'")
        oid_row = cur.fetchone()
        if oid_row:
            conn.adapters.register_loader(oid_row[0], _AgtypeLoader)

        # Create the graph only when ag_graph does not list it yet.
        cur.execute(
            "SELECT count(*) FROM ag_graph WHERE name = %s",
            (graph_name,),
        )
        count_row = cur.fetchone()
        graph_missing = bool(count_row) and count_row[0] == 0
        if graph_missing:
            cur.execute(
                "SELECT * FROM create_graph(%s)",
                (graph_name,),
            )
            log.info("AGEDriver: created graph %r", graph_name)

    conn.commit()
# ---------------------------------------------------------------------------
# Dialect
# ---------------------------------------------------------------------------
class AGEDialect:
    """Strategy for Apache AGE-specific Cypher generation.

    Key differences from FalkorDB:

    - No ``toInteger()`` cast for ``id()``-based lookups
    - No ``vecf32()`` or ``intern()`` wrappers (raw Python values stored as-is)
    - Fulltext search: not supported natively (raises ``NotImplementedError``)
    - Vector KNN: not supported natively (raises ``NotImplementedError``)
    - Multi-label emulation: extra labels stored as ``_labels`` property array
    """

    def generated_id_where(self, alias: str, param: str) -> str:
        """Return a WHERE clause matching *alias* by its ``id()`` against ``$param``."""
        return f"WHERE id({alias}) = ${param}"

    def labels_clause(self, labels: list[str]) -> str:
        """AGE only supports one label per vertex — use the primary label."""
        primary = labels[0]
        return primary

    def subtype_where(self, alias: str, labels: list[str]) -> str | None:
        """Return a WHERE condition filtering by emulated subtype labels.

        Every label after the first must be contained in ``alias._labels``;
        with a single label (or none) there is nothing to filter and ``None``
        is returned.
        """
        extra_labels = labels[1:]
        if not extra_labels:
            return None
        conditions = [f'"{lbl}" IN {alias}._labels' for lbl in extra_labels]
        return " AND ".join(conditions)

    def needs_labels_property(self) -> bool:
        """Signal to the mapper to inject ``_labels`` on CREATE for subtypes."""
        return True

    def cypher_fn_for_field(self, fi: FieldInfo) -> str | None:  # noqa: ARG002
        """No field is wrapped in a Cypher function for AGE; always ``None``."""
        return None

    def fulltext_call(self, label: str, alias: str, query_param: str) -> str:  # noqa: ARG002
        """Always raises: fulltext search is not available through AGE Cypher."""
        raise NotImplementedError(
            "Apache AGE does not support native Cypher fulltext search. "
            "Use PostgreSQL full-text search on the underlying tables instead."
        )

    def vector_knn_start(
        self,
        alias: str,  # noqa: ARG002
        labels_str: str,  # noqa: ARG002
        type_name: str,  # noqa: ARG002
        field_name: str,  # noqa: ARG002
    ) -> str:
        """Always raises: vector KNN search is not available through AGE Cypher."""
        raise NotImplementedError(
            "Apache AGE does not support native Cypher vector KNN search. "
            "Use pgvector on the underlying PostgreSQL tables instead."
        )

    def vector_knn_score_expr(self, alias: str, field_name: str) -> str:  # noqa: ARG002
        """Always raises: vector KNN search is not available through AGE Cypher."""
        raise NotImplementedError(
            "Apache AGE does not support native Cypher vector KNN search."
        )

    def wrap_node(self, raw: Any) -> AGENode:
        """Wrap a raw vertex in an :class:`AGENode`."""
        return AGENode(raw)

    def wrap_edge(self, raw: Any) -> AGEEdge:
        """Wrap a raw edge in an :class:`AGEEdge`."""
        return AGEEdge(raw)


# Single shared dialect instance handed out by the driver.
_AGE_DIALECT = AGEDialect()
# ---------------------------------------------------------------------------
# Driver
# ---------------------------------------------------------------------------
class AGEDriver:
    """Sync driver for Apache AGE (PostgreSQL graph extension).

    Cypher queries are wrapped in the AGE ``cypher()`` SQL function and
    executed via a ``psycopg`` (psycopg3) connection. Parameters are
    serialised as an agtype JSON map and passed as the third argument to
    ``cypher()``, making them accessible inside Cypher as ``$param_name``.

    Supports explicit ACID transactions via
    :class:`~runic.orm.driver.TransactionalGraphDriver`. psycopg3 starts an
    implicit ``BEGIN`` on the first statement after each commit/rollback
    (``autocommit=False`` default); this driver's ``commit()`` / ``rollback()``
    map directly to ``conn.commit()`` / ``conn.rollback()``. ``begin()`` is a
    documented no-op because psycopg3 manages the implicit transaction start
    automatically.

    The ORM :class:`~runic.orm.session.session.Session` drives this lifecycle:
    the first query in a Session opens a transaction implicitly; ``commit()`` /
    ``rollback()`` finalise it.

    AGE stores each vertex label as a separate PostgreSQL table; a vertex
    belongs to exactly one label fixed at creation time. Multi-label
    operations (``SET n:New REMOVE n:Old``) are therefore not supported.

    Example
    -------
    ::

        driver = create_age_driver(
            host="localhost",
            port=5432,
            database="postgres",
            graph="my_graph",
            username="postgres",
            password="secret",
        )
        with Session(driver) as session:
            ...
    """

    supports_multi_label: bool = False

    def __init__(self, conn: Any, graph_name: str) -> None:
        self._conn = conn
        self._graph_name = graph_name

    @property
    def dialect(self) -> AGEDialect:
        """The shared :class:`AGEDialect` instance."""
        return _AGE_DIALECT

    # ------------------------------------------------------------------
    # Transaction support (TransactionalGraphDriver)
    # ------------------------------------------------------------------

    def begin(self) -> None:
        """No-op: psycopg3 starts an implicit BEGIN on the first statement.

        Exists to satisfy the :class:`~runic.orm.driver.TransactionalGraphDriver`
        protocol so the ORM Session can detect transaction support via
        ``isinstance`` checks.
        """

    def commit(self) -> None:
        """Commit the active PostgreSQL transaction."""
        self._conn.commit()
        log.debug("AGEDriver: transaction committed on graph %r", self._graph_name)

    def rollback(self) -> None:
        """Roll back the active PostgreSQL transaction."""
        self._conn.rollback()
        log.debug("AGEDriver: transaction rolled back on graph %r", self._graph_name)

    # ------------------------------------------------------------------
    # Execute
    # ------------------------------------------------------------------

    def _build_sql(self, cypher: str, as_clause: str, *, with_params: bool) -> str:
        """Build the ``SELECT * FROM cypher(...) AS (...)`` statement text.

        The graph name and the Cypher text are embedded as SQL literals:

        * single quotes in the graph name are doubled so it stays one literal;
        * the dollar-quote tag is extended until it does not occur in
          *cypher*, so the Cypher text cannot terminate its own quoting;
        * when *with_params* is true the statement is sent through psycopg's
          placeholder parsing, so every literal ``%`` is doubled and a single
          ``%s::agtype`` placeholder is appended as the third argument.
        """
        graph_literal = self._graph_name.replace("'", "''")

        tag = "age_q"
        while f"${tag}$" in cypher:
            tag += "_"

        call = f"cypher('{graph_literal}', ${tag}$ {cypher} ${tag}$"
        if with_params:
            # Escape BEFORE appending the real placeholder.
            call = call.replace("%", "%%") + ", %s::agtype"
        return f"SELECT * FROM {call}) AS ({as_clause})"  # noqa: S608

    def execute(self, cypher: str, params: dict[str, Any]) -> AGEResult:
        """Run *cypher* on the configured graph and collect all rows.

        Args:
            cypher: The Cypher query text.
            params: Query parameters; when non-empty they are JSON-encoded
                (using :func:`_serialize_param` for non-native types) and
                bound as the agtype parameter map of ``cypher()``.

        Returns:
            An :class:`AGEResult` with every fetched row and the column names
            reported by the cursor.
        """
        cols = _parse_return_columns(cypher)
        as_clause = ", ".join(f"{c} agtype" for c in cols)
        sql = self._build_sql(cypher, as_clause, with_params=bool(params))

        with self._conn.cursor() as cur:
            if params:
                params_json = json.dumps(params, default=_serialize_param)
                cur.execute(sql, (params_json,))
            else:
                # No bound parameters: psycopg sends the text untouched, so
                # "%" must NOT be escaped on this path.
                cur.execute(sql)
            columns = [d.name for d in (cur.description or [])]
            rows = [list(r) for r in cur.fetchall()]

        log.debug(
            "AGEDriver executed Cypher on %r; %d row(s)", self._graph_name, len(rows)
        )
        return AGEResult(rows, columns)

    def close(self) -> None:
        """Close the underlying PostgreSQL connection."""
        self._conn.close()
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_age_driver(
    host: str,
    port: int,
    database: str,
    graph: str,
    username: str,
    password: str,
) -> AGEDriver:
    """Create an :class:`AGEDriver` connected to a PostgreSQL+AGE instance.

    Opens a psycopg connection, prepares it with
    :func:`_setup_age_connection`, and wraps it in an :class:`AGEDriver`.
    If the preparation step fails, the freshly opened connection is closed
    before the error is re-raised, so no connection is leaked.

    Parameters
    ----------
    host:
        PostgreSQL host name or IP address.
    port:
        PostgreSQL port (PostgreSQL's conventional port is 5432; this
        function has no default and the value must be passed).
    database:
        PostgreSQL database name.
    graph:
        AGE graph name within the database.
    username:
        PostgreSQL user name.
    password:
        PostgreSQL password.
    """
    import psycopg

    conn = psycopg.connect(
        host=host,
        port=port,
        dbname=database,
        user=username,
        password=password,
    )
    try:
        _setup_age_connection(conn, graph)
    except BaseException:
        # Nobody else holds this connection yet; release it, then propagate.
        conn.close()
        raise
    return AGEDriver(conn, graph)