Source code for runic.orm.core.models

"""Node and Edge base classes for graph entities."""

from __future__ import annotations

import logging
import typing
from collections.abc import Callable
from datetime import datetime
from enum import Enum
from typing import Any, ClassVar, dataclass_transform

from runic.orm.core.descriptors import Field, FieldDescriptor, FieldInfo
from runic.orm.core.types import (
    DatetimeConverter,
    EnumConverter,
    GeoLocation,
    GeoLocationConverter,
    Vector,
    VectorConverter,
)

log = logging.getLogger(__name__)

# Marker stored on the base classes so _collect_fields can skip them.
_BASE_CLASS_MARKER = "_is_orm_base"


def _is_collection_annotation(ann: Any) -> bool:
    """Return True if *ann* represents a list/collection type."""
    if isinstance(ann, str):
        return ann.startswith(("list[", "List["))
    return getattr(ann, "__origin__", None) is list


_ABSENT = object()


def _synthesize_plain_annotations(cls: type) -> None:
    """Install Field descriptors for bare ``name: type`` annotations on *cls*.

    Called before ``_collect_fields`` so plain annotations participate in
    dirty-tracking and ``__init__`` generation identically to explicit
    ``Field()`` declarations.  Only annotations declared directly on *cls*
    (not inherited ones) are processed; ancestors handle their own during
    their own ``__init_subclass__`` call.
    """
    for name, ann in cls.__annotations__.items():
        if name.startswith("_"):
            continue
        ann_str = ann if isinstance(ann, str) else repr(ann)
        if "ClassVar" in ann_str:
            continue
        if isinstance(cls.__dict__.get(name), FieldDescriptor):
            continue
        val = cls.__dict__.get(name, _ABSENT)
        if val is _ABSENT:
            is_optional = (
                "| None" in ann_str or "None |" in ann_str or "Optional[" in ann_str
            )
            field = Field(default=None) if is_optional else Field()
        else:
            field = Field(default=val)
        field.__set_name__(cls, name)
        setattr(cls, name, field)


def _collect_fields(cls: type, stop_at: type) -> list[FieldInfo]:
    """Collect Field descriptors from *cls* and its ancestors.

    Traverses the MRO in reverse (most-base first) so that parent fields come
    before child fields in the returned list.  ``stop_at`` (Node or Edge) and
    ``object`` are excluded.  If a name appears in multiple bases the first
    occurrence wins (i.e. parent definition is kept; child cannot shadow).

    ``FieldInfo.is_collection`` is set from the type annotation — ``True`` when
    the annotation is ``list[...]`` or ``List[...]``.
    """
    # Merge annotations across MRO so child overrides are handled correctly.
    annotations: dict[str, Any] = {}
    for base in reversed(cls.__mro__):
        if base is object or base is stop_at:
            continue
        annotations.update(getattr(base, "__annotations__", {}))

    seen: set[str] = set()
    result: list[FieldInfo] = []
    for base in reversed(cls.__mro__):
        if base is object or base is stop_at:
            continue
        for name, val in base.__dict__.items():
            if isinstance(val, FieldDescriptor) and name not in seen:
                ann = annotations.get(name)
                is_coll = _is_collection_annotation(ann)
                result.append(FieldInfo(name=name, field=val, is_collection=is_coll))
                seen.add(name)
    return result


def _make_init(field_infos: list[FieldInfo]) -> Callable[..., None]:
    """Return a ``__init__`` function that accepts all *field_infos* as kwargs.

    Values are stored directly into ``instance.__dict__`` to bypass
    ``Field.__set__``, preventing the initial construction from marking
    the entity dirty.
    """
    init_fields = [fi for fi in field_infos if fi.field.init]
    non_init_fields = [fi for fi in field_infos if not fi.field.init]
    required_names = frozenset(
        fi.name for fi in init_fields if not fi.field.has_default
    )
    init_names = frozenset(fi.name for fi in init_fields)

    def _generated_init(instance: Any, **kwargs: Any) -> None:
        # Reject unknown keyword arguments.
        for key in kwargs:
            if key not in init_names:
                raise TypeError(f"__init__() got unexpected keyword argument '{key!r}'")

        # Enforce required fields.
        for name in required_names:
            if name not in kwargs:
                raise TypeError(
                    f"__init__() missing required keyword argument: '{name}'"
                )

        # Object-state flags — written directly to bypass Field.__set__.
        instance.__dict__["_new"] = True
        instance.__dict__["_dirty"] = False

        # Initialise declared init fields.
        for fi in init_fields:
            if fi.name in kwargs:
                instance.__dict__[fi.name] = kwargs[fi.name]
            else:
                instance.__dict__[fi.name] = fi.field.get_default()

        # Initialise non-init fields from their defaults.
        for fi in non_init_fields:
            if fi.field.has_default:
                instance.__dict__[fi.name] = fi.field.get_default()

    return _generated_init


def _apply_auto_converters(cls: type, fields: list[FieldInfo]) -> None:
    """Assign TypeConverters automatically for well-known annotation types.

    Handles ``datetime``, ``Enum`` subclasses, ``Vector``, and ``GeoLocation``
    on any field that has no explicit converter and is not a relationship.
    Falls back silently when annotations cannot be resolved (e.g. forward refs).
    """
    try:
        hints = typing.get_type_hints(cls)
    except Exception:
        return

    for fi in fields:
        if fi.field.relationship is not None:
            continue
        if fi.field.converter is not None:
            continue

        ann = hints.get(fi.name)
        if ann is None:
            continue

        # Unwrap Optional[X] / X | None → X
        origin = typing.get_origin(ann)
        args = typing.get_args(ann)
        if origin is typing.Union and len(args) == 2 and type(None) in args:
            ann = next(a for a in args if a is not type(None))

        if ann is datetime:
            fi.field.converter = DatetimeConverter()
        elif ann is Vector:
            fi.field.converter = VectorConverter()
        elif ann is GeoLocation:
            fi.field.converter = GeoLocationConverter()
        elif isinstance(ann, type) and issubclass(ann, Enum) and ann is not Enum:
            fi.field.converter = EnumConverter(ann)


# ---------------------------------------------------------------------------
# Node
# ---------------------------------------------------------------------------


[docs] @dataclass_transform(field_specifiers=(Field,), kw_only_default=True) class Node: """Base class for graph nodes. Subclass with ``labels`` and optional ``primary_label`` class keywords:: class Country( Location, labels=["Location", "Country"], primary_label="Location" ): iso_code: str = Field(unique=True) ``__init__`` is generated automatically from declared ``Field`` descriptors. Setting any field on an existing instance marks it dirty (``_dirty = True``), which the Mapper interprets as a MERGE/SET on the next flush. """ _is_orm_base: ClassVar[bool] = True # Class-level attributes set by __init_subclass__: _labels: ClassVar[list[str]] _primary_label: ClassVar[str] _fields: ClassVar[list[FieldInfo]] # Instance-level state flags (written by _make_init / Mapper): _new: bool _dirty: bool def __init_subclass__( cls, labels: list[str] | None = None, primary_label: str | None = None, **kwargs: Any, ) -> None: super().__init_subclass__(**kwargs) effective_labels = labels if labels is not None else [cls.__name__] effective_primary = ( primary_label if primary_label is not None else effective_labels[0] ) cls._labels = effective_labels cls._primary_label = effective_primary _synthesize_plain_annotations(cls) cls._fields = _collect_fields(cls, Node) _apply_auto_converters(cls, cls._fields) cls.__init__ = _make_init(cls._fields) # ty:ignore[invalid-assignment] from runic.orm.core.metadata import metadata metadata.register_node(cls) log.debug("Node subclass registered: %s", cls.__name__) def __repr__(self) -> str: pk_field = next( ( fi for fi in self.__class__._fields if fi.field.primary_key or fi.name == "id" ), None, ) if pk_field: pk_val = self.__dict__.get(pk_field.name, "?") return f"{type(self).__name__}({pk_field.name}={pk_val!r})" return f"{type(self).__name__}()"
# --------------------------------------------------------------------------- # Edge # ---------------------------------------------------------------------------
[docs] @dataclass_transform(field_specifiers=(Field,), kw_only_default=True) class Edge: """Base class for graph edge property models. Subclass with a ``type`` class keyword:: class InvitationEdge(Edge, type="INVITED_TO"): role: str = Field(required=True) Edge instances carry the same ``_new``/``_dirty`` lifecycle flags as Node. """ _is_orm_base: ClassVar[bool] = True # Class-level attributes set by __init_subclass__: _edge_type: ClassVar[str] _fields: ClassVar[list[FieldInfo]] # Instance-level state flags: _new: bool _dirty: bool def __init_subclass__( cls, type: str | None = None, # noqa: A002 **kwargs: Any, ) -> None: super().__init_subclass__(**kwargs) cls._edge_type = type if type is not None else cls.__name__ _synthesize_plain_annotations(cls) cls._fields = _collect_fields(cls, Edge) _apply_auto_converters(cls, cls._fields) cls.__init__ = _make_init(cls._fields) # ty:ignore[invalid-assignment] from runic.orm.core.metadata import metadata metadata.register_edge(cls) log.debug("Edge subclass registered: %s", cls.__name__) def __repr__(self) -> str: return f"{type(self).__name__}(type={self._edge_type!r})"