1from __future__ import annotations
2
3import datetime
4import re
5from collections.abc import Generator
6from typing import TYPE_CHECKING, Any, NamedTuple
7
8from plain.postgres.fields.related import (
9 RECURSIVE_RELATIONSHIP_CONSTANT,
10 RelatedField,
11)
12
13if TYPE_CHECKING:
14 from plain.postgres.fields import Field
15 from plain.postgres.fields.reverse_related import ForeignObjectRel
16
17
class FieldReference(NamedTuple):
    """
    Result of a positive `field_references` check.

    Holds the direct relation match and the intermediate ("through") model
    match; either member is None when that kind of reference was not found.
    """

    # (remote_field, referenced column names) when the relation's target
    # model matches the reference context, otherwise None.
    to: tuple[ForeignObjectRel, list[str]] | None
    # (remote_field, through_fields) when the relation's through model
    # matches the reference context, otherwise None.
    through: tuple[ForeignObjectRel, tuple[str, ...] | None] | None
23
24
# Type of compiled regular expression objects; re.Pattern is the same class
# that re.compile() instances belong to.
COMPILED_REGEX_TYPE = re.Pattern
26
27
class RegexObject:
    """
    Value wrapper for a compiled regular expression.

    Only the wrapped object's ``pattern`` and ``flags`` are kept. Two wrappers
    compare equal, and hash equal, exactly when both attributes match.
    """

    def __init__(self, obj: Any) -> None:
        self.pattern = obj.pattern
        self.flags = obj.flags

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, RegexObject):
            return NotImplemented
        return self.pattern == other.pattern and self.flags == other.flags

    # Defining __eq__ on its own makes Python set __hash__ to None, which
    # leaves instances unhashable. Hash the same attributes that __eq__
    # compares so the two stay consistent.
    def __hash__(self) -> int:
        return hash((self.pattern, self.flags))

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(pattern={self.pattern!r}, flags={self.flags!r})"
        )
37
38
def get_migration_name_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMM``."""
    now = datetime.datetime.now()
    return f"{now:%Y%m%d_%H%M}"
41
42
43def resolve_relation(
44 model: str | Any, package_label: str | None = None, model_name: str | None = None
45) -> tuple[str, str]:
46 """
47 Turn a model class or model reference string and return a model tuple.
48
49 package_label and model_name are used to resolve the scope of recursive and
50 unscoped model relationship.
51 """
52 if isinstance(model, str):
53 if model == RECURSIVE_RELATIONSHIP_CONSTANT:
54 if package_label is None or model_name is None:
55 raise TypeError(
56 "package_label and model_name must be provided to resolve "
57 "recursive relationships."
58 )
59 return package_label, model_name
60 if "." in model:
61 package_label, model_name = model.split(".", 1)
62 return package_label, model_name.lower()
63 if package_label is None:
64 raise TypeError(
65 "package_label must be provided to resolve unscoped model relationships."
66 )
67 return package_label, model.lower()
68 return model.model_options.package_label, model.model_options.model_name
69
70
def field_references(
    model_tuple: tuple[str, str],
    field: Field,
    reference_model_tuple: tuple[str, str],
    reference_field_name: str | None = None,
    reference_field: Field | None = None,
) -> FieldReference | bool:
    """
    Describe how `field`, declared on `model_tuple`, references the given
    context, or return False when it does not.

    A match is reported as a FieldReference:
    - ``to`` is filled in when the relation's target model resolves to
      `reference_model_tuple`.
    - ``through`` is filled in when the relation's ``through`` model
      resolves to it.

    If `reference_field_name` is given without `reference_field`, the
    primary-key status of the referenced field cannot be checked, so false
    positives are possible. That is acceptable when the result only decides
    whether an optimization can take place.
    """
    # Only RelatedField instances carry a remote_field to inspect.
    if not isinstance(field, RelatedField):
        return False
    rel = field.remote_field
    if not rel:
        return False

    to_ref = None
    points_at_model = (
        resolve_relation(rel.model, *model_tuple) == reference_model_tuple
    )
    # Related fields always target "id" on the remote model. The direct
    # reference therefore counts unless a named, non-primary-key field was
    # requested.
    if points_at_model and (
        reference_field_name is None
        or reference_field_name == "id"
        or reference_field is None
        or reference_field.primary_key
    ):
        to_ref = (rel, ["id"])

    through_ref = None
    through_model = getattr(rel, "through", None)
    if (
        through_model
        and resolve_relation(through_model, *model_tuple) == reference_model_tuple
    ):
        through_fields = getattr(rel, "through_fields", None)
        # With no field name requested or no explicit through_fields, any use
        # of the through model counts as a reference.
        if (
            reference_field_name is None
            or through_fields is None
            or reference_field_name in through_fields
        ):
            through_ref = (rel, through_fields)

    if to_ref is None and through_ref is None:
        return False
    return FieldReference(to_ref, through_ref)
119
120
def get_references(
    state: Any, model_tuple: tuple[str, str], field_tuple: tuple[Any, ...] = ()
) -> Generator[tuple[Any, str, Field, FieldReference]]:
    """
    Lazily yield ``(model_state, name, field, reference)`` for every field in
    `state` that references `model_tuple`.

    A non-empty `field_tuple` is forwarded to `field_references` as extra
    arguments. This narrows the search to references of that particular
    field of `model_tuple`.
    """
    for state_key, model_state in state.models.items():
        for field_name, candidate in model_state.fields.items():
            ref = field_references(state_key, candidate, model_tuple, *field_tuple)
            if not ref:
                continue
            yield model_state, field_name, candidate, ref
141
142
def field_is_referenced(
    state: Any, model_tuple: tuple[str, str], field_tuple: tuple[Any, ...]
) -> bool:
    """
    Report whether any model in `state` references `field_tuple` on
    `model_tuple`.

    Stops at the first reference found rather than collecting all of them.
    """
    return any(True for _ in get_references(state, model_tuple, field_tuple))