1from __future__ import annotations
  2
  3from datetime import datetime
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.expressions import Func
  7from plain.models.fields import (
  8    DateField,
  9    DateTimeField,
 10    DurationField,
 11    Field,
 12    IntegerField,
 13    TimeField,
 14)
 15from plain.models.lookups import (
 16    Transform,
 17    YearExact,
 18    YearGt,
 19    YearGte,
 20    YearLt,
 21    YearLte,
 22)
 23from plain.models.postgres.sql import (
 24    date_extract_sql,
 25    date_trunc_sql,
 26    datetime_cast_date_sql,
 27    datetime_cast_time_sql,
 28    datetime_extract_sql,
 29    datetime_trunc_sql,
 30    time_extract_sql,
 31    time_trunc_sql,
 32)
 33from plain.utils import timezone
 34
 35if TYPE_CHECKING:
 36    from plain.models.postgres.wrapper import DatabaseWrapper
 37    from plain.models.sql.compiler import SQLCompiler
 38
 39
 40class TimezoneMixin(Transform):
 41    tzinfo = None
 42
 43    def get_tzname(self) -> str | None:
 44        # Timezone conversions must happen to the input datetime *before*
 45        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
 46        # database as 2016-01-01 01:00:00 +00:00. Any results should be
 47        # based on the input datetime not the stored datetime.
 48        if self.tzinfo is None:
 49            return timezone.get_current_timezone_name()
 50        else:
 51            return timezone._get_timezone_name(self.tzinfo)
 52
 53
 54class Extract(TimezoneMixin, Transform):
 55    lookup_name: str | None = None
 56    output_field = IntegerField()
 57
 58    def __init__(
 59        self,
 60        expression: Any,
 61        lookup_name: str | None = None,
 62        tzinfo: Any = None,
 63        **extra: Any,
 64    ) -> None:
 65        if self.lookup_name is None:
 66            self.lookup_name = lookup_name
 67        if self.lookup_name is None:
 68            raise ValueError("lookup_name must be provided")
 69        self.tzinfo = tzinfo
 70        super().__init__(expression, **extra)
 71
 72    def as_sql(
 73        self,
 74        compiler: SQLCompiler,
 75        connection: DatabaseWrapper,
 76        function: str | None = None,
 77        template: str | None = None,
 78        arg_joiner: str | None = None,
 79        **extra_context: Any,
 80    ) -> tuple[str, list[Any]]:
 81        # lookup_name is guaranteed to be str after __init__ validation
 82        assert self.lookup_name is not None
 83        sql, params = compiler.compile(self.lhs)
 84        lhs_output_field = self.lhs.output_field
 85        if isinstance(lhs_output_field, DateTimeField):
 86            tzname = self.get_tzname()
 87            sql, params = datetime_extract_sql(
 88                self.lookup_name, sql, tuple(params), tzname
 89            )
 90        elif self.tzinfo is not None:
 91            raise ValueError("tzinfo can only be used with DateTimeField.")
 92        elif isinstance(lhs_output_field, DateField):
 93            sql, params = date_extract_sql(self.lookup_name, sql, tuple(params))
 94        elif isinstance(lhs_output_field, TimeField):
 95            sql, params = time_extract_sql(self.lookup_name, sql, tuple(params))
 96        elif isinstance(lhs_output_field, DurationField):
 97            # PostgreSQL has native duration (interval) type
 98            sql, params = time_extract_sql(self.lookup_name, sql, tuple(params))
 99        else:
100            # resolve_expression has already validated the output_field so this
101            # assert should never be hit.
102            raise ValueError("Tried to Extract from an invalid type.")
103        return sql, list(params)
104
105    def resolve_expression(
106        self,
107        query: Any = None,
108        allow_joins: bool = True,
109        reuse: Any = None,
110        summarize: bool = False,
111        for_save: bool = False,
112    ) -> Extract:
113        copy = super().resolve_expression(
114            query, allow_joins, reuse, summarize, for_save
115        )
116        field = getattr(copy.lhs, "output_field", None)
117        if field is None:
118            return copy
119        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
120            raise ValueError(
121                "Extract input expression must be DateField, DateTimeField, "
122                "TimeField, or DurationField."
123            )
124        # Passing dates to functions expecting datetimes is most likely a mistake.
125        if type(field) == DateField and copy.lookup_name in (  # noqa: E721
126            "hour",
127            "minute",
128            "second",
129        ):
130            raise ValueError(
131                f"Cannot extract time component '{copy.lookup_name}' from DateField '{field.name}'."
132            )
133        if isinstance(field, DurationField) and copy.lookup_name in (
134            "year",
135            "iso_year",
136            "month",
137            "week",
138            "week_day",
139            "iso_week_day",
140            "quarter",
141        ):
142            raise ValueError(
143                f"Cannot extract component '{copy.lookup_name}' from DurationField '{field.name}'."
144            )
145        return copy
146
147
class ExtractYear(Extract):
    """Extract transform fixed to the ``year`` component.

    Extract.resolve_expression rejects this component for DurationField input.
    """

    lookup_name = "year"
150
151
class ExtractIsoYear(Extract):
    """Return the ISO-8601 week-numbering year.

    Extract.resolve_expression rejects this component for DurationField input.
    """

    lookup_name = "iso_year"
156
157
class ExtractMonth(Extract):
    """Extract transform fixed to the ``month`` component.

    Extract.resolve_expression rejects this component for DurationField input.
    """

    lookup_name = "month"
160
161
class ExtractDay(Extract):
    """Extract transform fixed to the ``day`` component."""

    lookup_name = "day"
164
165
class ExtractWeek(Extract):
    """
    Return 1-52 or 53, based on ISO-8601, i.e., Monday is the first of the
    week.

    Extract.resolve_expression rejects this component for DurationField input.
    """

    lookup_name = "week"
173
174
class ExtractWeekDay(Extract):
    """
    Return Sunday=1 through Saturday=7.

    To replicate this in Python: (mydatetime.isoweekday() % 7) + 1

    Extract.resolve_expression rejects this component for DurationField input.
    """

    lookup_name = "week_day"
183
184
class ExtractIsoWeekDay(Extract):
    """Return Monday=1 through Sunday=7, based on ISO-8601.

    Extract.resolve_expression rejects this component for DurationField input.
    """

    lookup_name = "iso_week_day"
189
190
class ExtractQuarter(Extract):
    """Extract transform fixed to the ``quarter`` component.

    Extract.resolve_expression rejects this component for DurationField input.
    """

    lookup_name = "quarter"
193
194
class ExtractHour(Extract):
    """Extract transform fixed to the ``hour`` component.

    Extract.resolve_expression rejects this component when the input is
    exactly a DateField.
    """

    lookup_name = "hour"
197
198
class ExtractMinute(Extract):
    """Extract transform fixed to the ``minute`` component.

    Extract.resolve_expression rejects this component when the input is
    exactly a DateField.
    """

    lookup_name = "minute"
201
202
class ExtractSecond(Extract):
    """Extract transform fixed to the ``second`` component.

    Extract.resolve_expression rejects this component when the input is
    exactly a DateField.
    """

    lookup_name = "second"
205
206
# Calendar-component transforms are registered on DateField.
# NOTE(review): DateTimeField presumably picks these up through inheritance
# (a comment in TruncBase.resolve_expression states it subclasses DateField)
# — confirm against the lookup registry.
DateField.register_lookup(ExtractYear)
DateField.register_lookup(ExtractMonth)
DateField.register_lookup(ExtractDay)
DateField.register_lookup(ExtractWeekDay)
DateField.register_lookup(ExtractIsoWeekDay)
DateField.register_lookup(ExtractWeek)
DateField.register_lookup(ExtractIsoYear)
DateField.register_lookup(ExtractQuarter)

# Time-of-day component transforms are registered on TimeField...
TimeField.register_lookup(ExtractHour)
TimeField.register_lookup(ExtractMinute)
TimeField.register_lookup(ExtractSecond)

# ...and separately on DateTimeField.
DateTimeField.register_lookup(ExtractHour)
DateTimeField.register_lookup(ExtractMinute)
DateTimeField.register_lookup(ExtractSecond)

# Year-specific comparison lookups are attached to the year transform.
ExtractYear.register_lookup(YearExact)
ExtractYear.register_lookup(YearGt)
ExtractYear.register_lookup(YearGte)
ExtractYear.register_lookup(YearLt)
ExtractYear.register_lookup(YearLte)

# The same year-specific lookups are attached to the ISO year transform.
ExtractIsoYear.register_lookup(YearExact)
ExtractIsoYear.register_lookup(YearGt)
ExtractIsoYear.register_lookup(YearGte)
ExtractIsoYear.register_lookup(YearLt)
ExtractIsoYear.register_lookup(YearLte)
235
236
class Now(Func):
    """Database-side "now" expression rendered as ``STATEMENT_TIMESTAMP()``.

    The result is exposed through a DateTimeField output field.
    """

    # STATEMENT_TIMESTAMP() returns the time at the start of the current statement,
    # as opposed to CURRENT_TIMESTAMP which returns the time at the start of the
    # transaction.
    template = "STATEMENT_TIMESTAMP()"
    output_field = DateTimeField()
243
244
class TruncBase(TimezoneMixin, Transform):
    """Base transform that truncates a date, time, or datetime to ``kind``.

    Subclasses pin ``kind`` as a class attribute; ``Trunc`` sets it per
    instance. ``tzinfo`` is only accepted when the input expression is a
    DateTimeField.
    """

    kind: str | None = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """Store ``tzinfo`` and forward everything else to the parent initializer."""
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the input and wrap it in the trunc SQL for the output field type.

        Raises:
            ValueError: if ``tzinfo`` was given for a non-DateTimeField input,
                or the output field is not a DateField, TimeField, or
                DateTimeField.
        """
        # kind is guaranteed to be str in subclasses
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        # A timezone name is only computed for datetime inputs; other inputs
        # pass None through to the SQL helpers.
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField is tested before DateField so the more specific
        # helper is chosen for datetime outputs.
        if isinstance(self.output_field, DateTimeField):
            sql, params = datetime_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, DateField):
            sql, params = date_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, TimeField):
            sql, params = time_trunc_sql(self.kind, sql, tuple(params), tzname)
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """Resolve the expression and validate input type, output type, and kind.

        Raises:
            TypeError: if the input is not a DateField, TimeField, or
                DateTimeField.
            ValueError: if the output field is not one of those types, or if
                the requested truncation does not make sense for the input
                (time kinds on a plain DateField, date kinds on a TimeField,
                or a DateTimeField output for either).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # The error names the real output field class only when one was pinned
        # on the class or differs from the input's class; otherwise it reports
        # "DateTimeField".
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # The exact-type comparison deliberately excludes DateField subclasses.
        if type(field) == DateField and (  # noqa: E721
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: DatabaseWrapper
    ) -> Any:
        """Adapt a database value to the declared output field.

        For a DateTimeField output, a non-None value has its tzinfo dropped
        and is then passed to ``timezone.make_aware`` with ``self.tzinfo``.
        For other outputs, a ``datetime`` value is narrowed to a date or a
        time according to the output field. Anything else, including None,
        is returned unchanged.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
        elif isinstance(value, datetime):
            # isinstance() already excludes None, so no None check is needed.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
362
363
class Trunc(TruncBase):
    """Truncation transform whose ``kind`` is chosen per instance."""

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """Record ``kind`` on the instance, then defer to ``TruncBase``."""
        # Instance attribute shadows the class-level ``kind`` default.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
375
376
class TruncYear(TruncBase):
    """Truncation transform fixed to the ``year`` kind.

    TruncBase.resolve_expression rejects this kind for TimeField input.
    """

    kind = "year"
379
380
class TruncQuarter(TruncBase):
    """Truncation transform fixed to the ``quarter`` kind.

    TruncBase.resolve_expression rejects this kind for TimeField input.
    """

    kind = "quarter"
383
384
class TruncMonth(TruncBase):
    """Truncation transform fixed to the ``month`` kind.

    TruncBase.resolve_expression rejects this kind for TimeField input.
    """

    kind = "month"
387
388
class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week.

    TruncBase.resolve_expression rejects this kind for TimeField input.
    """

    kind = "week"
393
394
class TruncDay(TruncBase):
    """Truncation transform fixed to the ``day`` kind.

    TruncBase.resolve_expression rejects this kind for TimeField input.
    """

    kind = "day"
397
398
class TruncDate(TruncBase):
    """Cast the input expression to a date; exposed as the ``date`` lookup."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the input and wrap it with the date-cast SQL helper."""
        # This is a cast to date, not a truncation, so the generic
        # TruncBase.as_sql dispatch is bypassed.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        cast_sql, cast_params = datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
        return cast_sql, list(cast_params)
418
419
class TruncTime(TruncBase):
    """Cast the input expression to a time; exposed as the ``time`` lookup."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the input and wrap it with the time-cast SQL helper."""
        # This is a cast to time, not a truncation, so the generic
        # TruncBase.as_sql dispatch is bypassed.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        cast_sql, cast_params = datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
        return cast_sql, list(cast_params)
439
440
class TruncHour(TruncBase):
    """Truncation transform fixed to the ``hour`` kind.

    TruncBase.resolve_expression rejects this kind when the input is exactly
    a DateField.
    """

    kind = "hour"
443
444
class TruncMinute(TruncBase):
    """Truncation transform fixed to the ``minute`` kind.

    TruncBase.resolve_expression rejects this kind when the input is exactly
    a DateField.
    """

    kind = "minute"
447
448
class TruncSecond(TruncBase):
    """Truncation transform fixed to the ``second`` kind.

    TruncBase.resolve_expression rejects this kind when the input is exactly
    a DateField.
    """

    kind = "second"
451
452
# Expose the date/time casts as the ``date`` and ``time`` lookups on
# DateTimeField (the names come from each class's ``lookup_name``).
DateTimeField.register_lookup(TruncDate)
DateTimeField.register_lookup(TruncTime)