1from __future__ import annotations
  2
  3import datetime
  4import re
  5from collections.abc import Generator
  6from typing import TYPE_CHECKING, Any, NamedTuple
  7
  8from plain.postgres.fields.related import (
  9    RECURSIVE_RELATIONSHIP_CONSTANT,
 10    RelatedField,
 11)
 12
 13if TYPE_CHECKING:
 14    from plain.postgres.fields import Field
 15    from plain.postgres.fields.reverse_related import ForeignObjectRel
 16
 17
 18class FieldReference(NamedTuple):
 19    """Reference to a field in migrations, tracking direct and through relationships."""
 20
 21    to: tuple[ForeignObjectRel, list[str]] | None
 22    through: tuple[ForeignObjectRel, tuple[str, ...] | None] | None
 23
 24
 25COMPILED_REGEX_TYPE = type(re.compile(""))
 26
 27
 28class RegexObject:
 29    def __init__(self, obj: Any) -> None:
 30        self.pattern = obj.pattern
 31        self.flags = obj.flags
 32
 33    def __eq__(self, other: Any) -> bool:
 34        if not isinstance(other, RegexObject):
 35            return NotImplemented
 36        return self.pattern == other.pattern and self.flags == other.flags
 37
 38
 39def get_migration_name_timestamp() -> str:
 40    return datetime.datetime.now().strftime("%Y%m%d_%H%M")
 41
 42
 43def resolve_relation(
 44    model: str | Any, package_label: str | None = None, model_name: str | None = None
 45) -> tuple[str, str]:
 46    """
 47    Turn a model class or model reference string and return a model tuple.
 48
 49    package_label and model_name are used to resolve the scope of recursive and
 50    unscoped model relationship.
 51    """
 52    if isinstance(model, str):
 53        if model == RECURSIVE_RELATIONSHIP_CONSTANT:
 54            if package_label is None or model_name is None:
 55                raise TypeError(
 56                    "package_label and model_name must be provided to resolve "
 57                    "recursive relationships."
 58                )
 59            return package_label, model_name
 60        if "." in model:
 61            package_label, model_name = model.split(".", 1)
 62            return package_label, model_name.lower()
 63        if package_label is None:
 64            raise TypeError(
 65                "package_label must be provided to resolve unscoped model relationships."
 66            )
 67        return package_label, model.lower()
 68    return model.model_options.package_label, model.model_options.model_name
 69
 70
 71def field_references(
 72    model_tuple: tuple[str, str],
 73    field: Field,
 74    reference_model_tuple: tuple[str, str],
 75    reference_field_name: str | None = None,
 76    reference_field: Field | None = None,
 77) -> FieldReference | bool:
 78    """
 79    Return either False or a FieldReference if `field` references provided
 80    context.
 81
 82    False positives can be returned if `reference_field_name` is provided
 83    without `reference_field` because of the introspection limitation it
 84    incurs. This should not be an issue when this function is used to determine
 85    whether or not an optimization can take place.
 86    """
 87    # Only RelatedFields have remote_field attribute
 88    if not isinstance(field, RelatedField):
 89        return False
 90    remote_field = field.remote_field
 91    if not remote_field:
 92        return False
 93    references_to = None
 94    references_through = None
 95    if resolve_relation(remote_field.model, *model_tuple) == reference_model_tuple:
 96        # ForeignObject always references 'id'
 97        if (
 98            reference_field_name is None
 99            or reference_field_name == "id"
100            or (reference_field is None or reference_field.primary_key)
101        ):
102            references_to = (remote_field, ["id"])
103    through = getattr(remote_field, "through", None)
104    if through and resolve_relation(through, *model_tuple) == reference_model_tuple:
105        through_fields = getattr(remote_field, "through_fields", None)
106        if (
107            reference_field_name is None
108            or
109            # Unspecified through_fields.
110            through_fields is None
111            or
112            # Reference to field.
113            reference_field_name in through_fields
114        ):
115            references_through = (remote_field, through_fields)
116    if not (references_to or references_through):
117        return False
118    return FieldReference(references_to, references_through)
119
120
def get_references(
    state: Any, model_tuple: tuple[str, str], field_tuple: tuple[Any, ...] = ()
) -> Generator[tuple[Any, str, Field, FieldReference]]:
    """
    Lazily yield ``(model_state, name, field, reference)`` for every field in
    `state` that references `model_tuple`.

    A non-empty `field_tuple` is passed to `field_references` as its trailing
    ``reference_field_name`` / ``reference_field`` arguments. That narrows
    the results to references to that particular field of `model_tuple`.
    """
    for owner_tuple, owner_state in state.models.items():
        for field_name, candidate in owner_state.fields.items():
            if match := field_references(
                owner_tuple, candidate, model_tuple, *field_tuple
            ):
                yield owner_state, field_name, candidate, match
141
142
def field_is_referenced(
    state: Any, model_tuple: tuple[str, str], field_tuple: tuple[Any, ...]
) -> bool:
    """
    Report whether any model in `state` has a field referencing `field_tuple`
    on `model_tuple`.

    Stops at the first match instead of collecting every reference.
    """
    return any(True for _ in get_references(state, model_tuple, field_tuple))