Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from datetime import datetime
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.expressions import Func
  7from plain.models.fields import (
  8    DateField,
  9    DateTimeField,
 10    DurationField,
 11    Field,
 12    IntegerField,
 13    TimeField,
 14)
 15from plain.models.lookups import (
 16    Transform,
 17    YearExact,
 18    YearGt,
 19    YearGte,
 20    YearLt,
 21    YearLte,
 22)
 23from plain.utils import timezone
 24
 25if TYPE_CHECKING:
 26    from plain.models.backends.base.base import BaseDatabaseWrapper
 27    from plain.models.sql.compiler import SQLCompiler
 28
 29
 30class TimezoneMixin(Transform):
 31    tzinfo = None
 32
 33    def get_tzname(self) -> str | None:
 34        # Timezone conversions must happen to the input datetime *before*
 35        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
 36        # database as 2016-01-01 01:00:00 +00:00. Any results should be
 37        # based on the input datetime not the stored datetime.
 38        if self.tzinfo is None:
 39            return timezone.get_current_timezone_name()
 40        else:
 41            return timezone._get_timezone_name(self.tzinfo)
 42
 43
 44class Extract(TimezoneMixin, Transform):
 45    lookup_name: str | None = None
 46    output_field = IntegerField()
 47
 48    def __init__(
 49        self,
 50        expression: Any,
 51        lookup_name: str | None = None,
 52        tzinfo: Any = None,
 53        **extra: Any,
 54    ) -> None:
 55        if self.lookup_name is None:
 56            self.lookup_name = lookup_name
 57        if self.lookup_name is None:
 58            raise ValueError("lookup_name must be provided")
 59        self.tzinfo = tzinfo
 60        super().__init__(expression, **extra)
 61
 62    def as_sql(
 63        self,
 64        compiler: SQLCompiler,
 65        connection: BaseDatabaseWrapper,
 66        function: str | None = None,
 67        template: str | None = None,
 68        arg_joiner: str | None = None,
 69        **extra_context: Any,
 70    ) -> tuple[str, list[Any]]:
 71        # lookup_name is guaranteed to be str after __init__ validation
 72        assert self.lookup_name is not None
 73        sql, params = compiler.compile(self.lhs)
 74        lhs_output_field = self.lhs.output_field
 75        if isinstance(lhs_output_field, DateTimeField):
 76            tzname = self.get_tzname()
 77            sql, params = connection.ops.datetime_extract_sql(
 78                self.lookup_name, sql, tuple(params), tzname
 79            )
 80        elif self.tzinfo is not None:
 81            raise ValueError("tzinfo can only be used with DateTimeField.")
 82        elif isinstance(lhs_output_field, DateField):
 83            sql, params = connection.ops.date_extract_sql(
 84                self.lookup_name, sql, tuple(params)
 85            )
 86        elif isinstance(lhs_output_field, TimeField):
 87            sql, params = connection.ops.time_extract_sql(
 88                self.lookup_name, sql, tuple(params)
 89            )
 90        elif isinstance(lhs_output_field, DurationField):
 91            if not connection.features.has_native_duration_field:
 92                raise ValueError(
 93                    "Extract requires native DurationField database support."
 94                )
 95            sql, params = connection.ops.time_extract_sql(
 96                self.lookup_name, sql, tuple(params)
 97            )
 98        else:
 99            # resolve_expression has already validated the output_field so this
100            # assert should never be hit.
101            raise ValueError("Tried to Extract from an invalid type.")
102        return sql, list(params)
103
104    def resolve_expression(
105        self,
106        query: Any = None,
107        allow_joins: bool = True,
108        reuse: Any = None,
109        summarize: bool = False,
110        for_save: bool = False,
111    ) -> Extract:
112        copy = super().resolve_expression(
113            query, allow_joins, reuse, summarize, for_save
114        )
115        field = getattr(copy.lhs, "output_field", None)
116        if field is None:
117            return copy
118        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
119            raise ValueError(
120                "Extract input expression must be DateField, DateTimeField, "
121                "TimeField, or DurationField."
122            )
123        # Passing dates to functions expecting datetimes is most likely a mistake.
124        if type(field) == DateField and copy.lookup_name in (  # noqa: E721
125            "hour",
126            "minute",
127            "second",
128        ):
129            raise ValueError(
130                f"Cannot extract time component '{copy.lookup_name}' from DateField '{field.name}'."
131            )
132        if isinstance(field, DurationField) and copy.lookup_name in (
133            "year",
134            "iso_year",
135            "month",
136            "week",
137            "week_day",
138            "iso_week_day",
139            "quarter",
140        ):
141            raise ValueError(
142                f"Cannot extract component '{copy.lookup_name}' from DurationField '{field.name}'."
143            )
144        return copy
145
146
class ExtractYear(Extract):
    """Extract the calendar year."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month of the year."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day of the month."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number, 1-52 or 53.

    Weeks start on Monday, per ISO-8601.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    Python equivalent: ``(mydatetime.isoweekday() % 7) + 1``.
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, numbered Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second."""

    lookup_name = "second"
204
205
# Calendar components are registered on DateField.
for _extract_cls in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_extract_cls)

# Time-of-day components are registered on TimeField first, then DateTimeField.
for _field_cls in (TimeField, DateTimeField):
    for _extract_cls in (ExtractHour, ExtractMinute, ExtractSecond):
        _field_cls.register_lookup(_extract_cls)

# Both year extractions accept the specialized year comparison lookups.
for _year_transform in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _year_transform.register_lookup(_year_lookup)

# Keep the loop variables out of the module namespace.
del _extract_cls, _field_cls, _year_transform, _year_lookup
234
235
class Now(Func):
    """
    The database's current timestamp, as a DateTimeField.

    The generic template is ``CURRENT_TIMESTAMP``; the backend-specific
    methods below substitute a different SQL template.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def _compile_with_template(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        sql_template: str,
        extra_context: dict[str, Any],
    ) -> tuple[str, list[Any]]:
        """Compile this expression using ``sql_template`` instead of the default."""
        return self.as_sql(
            compiler, connection, template=sql_template, **extra_context
        )

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # On PostgreSQL, CURRENT_TIMESTAMP is the time at which the transaction
        # started. STATEMENT_TIMESTAMP matches what other databases return.
        return self._compile_with_template(
            compiler, connection, "STATEMENT_TIMESTAMP()", extra_context
        )

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # Ask for fractional-second precision (6 digits).
        return self._compile_with_template(
            compiler, connection, "CURRENT_TIMESTAMP(6)", extra_context
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        # NOTE(review): the percent signs are quadrupled, presumably because
        # the template passes through two rounds of %-formatting before
        # reaching SQLite — confirm before touching.
        return self._compile_with_template(
            compiler,
            connection,
            "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')",
            extra_context,
        )
275
276
class TruncBase(TimezoneMixin, Transform):
    """
    Base class for truncating a date, time, or datetime to a given precision.

    Subclasses set ``kind`` (e.g. ``"month"`` or ``"hour"``) to pick the
    precision. The result type follows ``output_field``. When that is not
    given, it is resolved from the input expression's field. ``tzinfo`` is
    only valid for a DateTimeField input.
    """

    kind: str | None = None
    tzinfo = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: Expression to truncate.
            output_field: Field type of the result; inferred from the input
                when ``None``.
            tzinfo: Time zone in which a DateTimeField input is truncated.
            **extra: Forwarded to the parent ``__init__``.
        """
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to the backend's truncation SQL for the result's field type.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        accepted for signature compatibility but are not used.

        Raises:
            ValueError: If ``tzinfo`` is set for a non-DateTimeField input, or
                if ``output_field`` is not a date/time/datetime field.
        """
        # kind is guaranteed to be str in subclasses.
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        # Only datetime inputs are shifted into a time zone before truncation.
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField is checked before DateField because it subclasses it.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """
        Resolve the expression and reject impossible truncations.

        Raises:
            TypeError: If the input is not a DateField, TimeField, or
                DateTimeField.
            ValueError: If ``output_field`` is not a date/time/datetime field,
                if a plain DateField is truncated to a datetime or to a
                time-of-day precision, or if a TimeField is truncated to a
                datetime or to a calendar precision.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake. A Field instance fixed at class level (e.g.
        # TruncDate.output_field) takes precedence over the resolved one.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Result type named in the error messages below.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact-type check: DateTimeField subclasses DateField but may
        # legitimately be truncated to time-of-day precisions.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        if isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """
        Normalize a value returned by the database to the result field type.

        For a DateTimeField result, any attached offset is dropped and the
        wall-clock value is made aware with ``timezone.make_aware`` using
        ``self.tzinfo``. For date/time results, a ``datetime`` value is
        narrowed with ``.date()`` / ``.time()``. Other values pass through.

        Raises:
            ValueError: If a DateTimeField result is ``None`` on a backend
                whose ``has_zoneinfo_database`` feature is false.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = timezone.make_aware(value.replace(tzinfo=None), self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                # NOTE(review): a genuinely NULL column value also lands here
                # on such backends and is reported as a tz-definition problem.
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            # A datetime came back for a date/time result; narrow it to the
            # declared type. (value cannot be None inside this branch.)
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
406
407
class Trunc(TruncBase):
    """
    Truncate ``expression`` to the precision named by ``kind``.

    Unlike TruncYear, TruncHour and the other fixed subclasses, the precision
    is chosen per instance at construction time.
    """

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # Instance attribute shadows the class-level TruncBase.kind (None).
        self.kind = kind
        init_kwargs: dict[str, Any] = {
            "output_field": output_field,
            "tzinfo": tzinfo,
            **extra,
        }
        super().__init__(expression, **init_kwargs)
419
420
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
441
442
class TruncDate(TruncBase):
    """Cast a datetime to its date part; the result is a DateField value."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Emit a date cast (in the resolved time zone) rather than a truncation.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        accepted for signature compatibility but are not used.
        """
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        cast_sql, cast_params = connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), self.get_tzname()
        )
        return cast_sql, list(cast_params)
462
463
class TruncTime(TruncBase):
    """Cast a datetime to its time part; the result is a TimeField value."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Emit a time cast (in the resolved time zone) rather than a truncation.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        accepted for signature compatibility but are not used.
        """
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        cast_sql, cast_params = connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), self.get_tzname()
        )
        return cast_sql, list(cast_params)
483
484
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the start of the second."""

    kind = "second"
495
496
# Register the date/time cast transforms (lookup names "date" and "time")
# on DateTimeField, in that order.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
del _cast_transform