Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from datetime import datetime
  4from typing import TYPE_CHECKING, Any
  5
  6from plain.models.expressions import Func
  7from plain.models.fields import (
  8    DateField,
  9    DateTimeField,
 10    DurationField,
 11    Field,
 12    IntegerField,
 13    TimeField,
 14)
 15from plain.models.lookups import (
 16    Transform,
 17    YearExact,
 18    YearGt,
 19    YearGte,
 20    YearLt,
 21    YearLte,
 22)
 23from plain.utils import timezone
 24
 25if TYPE_CHECKING:
 26    from plain.models.backends.base.base import BaseDatabaseWrapper
 27    from plain.models.sql.compiler import SQLCompiler
 28
 29
 30class TimezoneMixin:
 31    tzinfo = None
 32
 33    def get_tzname(self) -> str | None:
 34        # Timezone conversions must happen to the input datetime *before*
 35        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
 36        # database as 2016-01-01 01:00:00 +00:00. Any results should be
 37        # based on the input datetime not the stored datetime.
 38        if self.tzinfo is None:
 39            return timezone.get_current_timezone_name()
 40        else:
 41            return timezone._get_timezone_name(self.tzinfo)
 42
 43
 44class Extract(TimezoneMixin, Transform):
 45    lookup_name = None
 46    output_field = IntegerField()
 47
 48    def __init__(
 49        self,
 50        expression: Any,
 51        lookup_name: str | None = None,
 52        tzinfo: Any = None,
 53        **extra: Any,
 54    ) -> None:
 55        if self.lookup_name is None:
 56            self.lookup_name = lookup_name
 57        if self.lookup_name is None:
 58            raise ValueError("lookup_name must be provided")
 59        self.tzinfo = tzinfo
 60        super().__init__(expression, **extra)
 61
 62    def as_sql(
 63        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
 64    ) -> tuple[str, tuple[Any, ...]]:
 65        sql, params = compiler.compile(self.lhs)
 66        lhs_output_field = self.lhs.output_field
 67        if isinstance(lhs_output_field, DateTimeField):
 68            tzname = self.get_tzname()
 69            sql, params = connection.ops.datetime_extract_sql(
 70                self.lookup_name, sql, tuple(params), tzname
 71            )
 72        elif self.tzinfo is not None:
 73            raise ValueError("tzinfo can only be used with DateTimeField.")
 74        elif isinstance(lhs_output_field, DateField):
 75            sql, params = connection.ops.date_extract_sql(
 76                self.lookup_name, sql, tuple(params)
 77            )
 78        elif isinstance(lhs_output_field, TimeField):
 79            sql, params = connection.ops.time_extract_sql(
 80                self.lookup_name, sql, tuple(params)
 81            )
 82        elif isinstance(lhs_output_field, DurationField):
 83            if not connection.features.has_native_duration_field:
 84                raise ValueError(
 85                    "Extract requires native DurationField database support."
 86                )
 87            sql, params = connection.ops.time_extract_sql(
 88                self.lookup_name, sql, tuple(params)
 89            )
 90        else:
 91            # resolve_expression has already validated the output_field so this
 92            # assert should never be hit.
 93            raise ValueError("Tried to Extract from an invalid type.")
 94        return sql, tuple(params) if isinstance(params, list) else params
 95
 96    def resolve_expression(
 97        self,
 98        query: Any = None,
 99        allow_joins: bool = True,
100        reuse: Any = None,
101        summarize: bool = False,
102        for_save: bool = False,
103    ) -> Extract:
104        copy = super().resolve_expression(
105            query, allow_joins, reuse, summarize, for_save
106        )
107        field = getattr(copy.lhs, "output_field", None)
108        if field is None:
109            return copy
110        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
111            raise ValueError(
112                "Extract input expression must be DateField, DateTimeField, "
113                "TimeField, or DurationField."
114            )
115        # Passing dates to functions expecting datetimes is most likely a mistake.
116        if type(field) == DateField and copy.lookup_name in (  # noqa: E721
117            "hour",
118            "minute",
119            "second",
120        ):
121            raise ValueError(
122                f"Cannot extract time component '{copy.lookup_name}' from DateField '{field.name}'."
123            )
124        if isinstance(field, DurationField) and copy.lookup_name in (
125            "year",
126            "iso_year",
127            "month",
128            "week",
129            "week_day",
130            "iso_week_day",
131            "quarter",
132        ):
133            raise ValueError(
134                f"Cannot extract component '{copy.lookup_name}' from DurationField '{field.name}'."
135            )
136        return copy
137
138
class ExtractYear(Extract):
    """Extract the year component."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """
    Extract the ISO-8601 week-numbering year.

    Near January 1st this can differ from the calendar year.
    """

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month component."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day component."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number: 1-52, or 53 in long years.

    ISO weeks start on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    Python equivalent: ``(value.isoweekday() % 7) + 1``.
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour component."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute component."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second component."""

    lookup_name = "second"
196
197
# Date components are registered on DateField.
for _lookup in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_lookup)

# Time-of-day components are registered on TimeField, then DateTimeField.
for _field in (TimeField, DateTimeField):
    for _lookup in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_lookup)

# The Year* comparison lookups are registered on both year extractors.
for _extract in (ExtractYear, ExtractIsoYear):
    for _lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _extract.register_lookup(_lookup)

# Keep the module namespace free of loop variables.
del _field, _extract, _lookup
226
227
class Now(Func):
    """
    The database's current timestamp, typed as a DateTimeField.

    The generic template is ``CURRENT_TIMESTAMP``. Vendor hooks substitute an
    equivalent expression where the default differs in semantics or
    precision.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # On PostgreSQL, CURRENT_TIMESTAMP is fixed at the start of the
        # transaction; STATEMENT_TIMESTAMP() is used instead so behavior lines
        # up with other databases.
        statement_now = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=statement_now, **extra_context
        )

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Ask MySQL for 6 fractional-second digits (microseconds).
        precise_now = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=precise_now, **extra_context
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Each '%' is escaped twice. Presumably one level is consumed by
        # template interpolation and one by query formatting in the backend
        # (TODO confirm against the SQLite backend).
        strftime_now = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler,
            connection,
            template=strftime_now,
            **extra_context,
        )
267
268
class TruncBase(TimezoneMixin, Transform):
    """
    Base class for truncating a date, time, or datetime expression.

    ``kind`` names the precision to truncate to (e.g. "year", "day",
    "minute"). Subclasses set it as a class attribute; ``Trunc`` takes it as a
    constructor argument. The emitted SQL and the Python type of the result
    follow ``output_field``.
    """

    kind = None
    tzinfo = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: The date, time, or datetime expression to truncate.
            output_field: Field type of the result. When ``None``, it is
                resolved from the input expression.
            tzinfo: Timezone used for DateTimeField inputs; ``None`` means the
                currently active timezone. A non-``None`` value with any other
                input type is rejected when the expression is compiled.
        """
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to the backend's truncation SQL for ``self.kind``.

        The timezone is applied only when the *input* is a DateTimeField. The
        SQL function is chosen by the *output* field type.

        Raises:
            ValueError: ``tzinfo`` was set for a non-DateTimeField input, or
                ``output_field`` is not a date, time, or datetime field.
        """
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField subclasses DateField, so it must be tested first.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        # Normalize list params returned by the backend ops to a tuple.
        return sql, tuple(params) if isinstance(params, list) else params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """
        Resolve the expression and reject input/output combinations that are
        almost certainly mistakes.

        Raises:
            TypeError: The input is not a date, time, or datetime field.
            ValueError: The output field is not a date, time, or datetime
                field; a plain DateField is truncated to a time-of-day
                precision or to a datetime; or a TimeField is truncated to a
                calendar precision or to a datetime.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake. An output_field declared as a Field on the class
        # itself takes precedence over the resolved one.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Name the requested target type in error messages only when it was
        # chosen explicitly; otherwise report "DateTimeField".
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact type check is deliberate: DateTimeField subclasses DateField
        # but may be truncated to time-of-day precisions.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        if isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """
        Convert a raw database value to the Python type of ``output_field``.

        Datetime results are made timezone-aware with ``self.tzinfo``. For
        date or time outputs, a ``datetime`` returned by the database is
        narrowed to its ``date()`` or ``time()``. Anything else is returned
        unchanged.

        Raises:
            ValueError: A datetime truncation returned NULL on a backend whose
                ``has_zoneinfo_database`` feature is false (its timezone
                definitions are likely missing).
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                # Discard any tzinfo the driver attached and reinterpret the
                # wall-clock value in self.tzinfo. Assumes the backend returns
                # local time in the truncation timezone (TODO confirm).
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            # Presumably some backends return a full datetime for date/time
            # truncations; keep only the requested part.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
390
391
class Trunc(TruncBase):
    """
    Truncate an expression to a precision chosen at call time.

    Unlike the fixed-precision subclasses, ``kind`` is a constructor
    argument, e.g. ``Trunc(expression, "month")``.
    """

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # Store kind on the instance (overriding the class default of None)
        # before the base initializer runs.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
403
404
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to the start of the week: midnight on its Monday."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
425
426
class TruncDate(TruncBase):
    """
    Reduce a datetime to its date, evaluated in the expression's timezone.

    Exposed under the lookup name "date"; the result is a DateField.
    """

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile to a timezone-aware cast to date rather than a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        sql, params = connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), tzname
        )
        if isinstance(params, list):
            params = tuple(params)
        return sql, params
440
441
class TruncTime(TruncBase):
    """
    Reduce a datetime to its time of day, evaluated in the expression's
    timezone.

    Exposed under the lookup name "time"; the result is a TimeField.
    """

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """Compile to a timezone-aware cast to time rather than a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        sql, params = connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), tzname
        )
        if isinstance(params, list):
            params = tuple(params)
        return sql, params
455
456
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the start of the second."""

    kind = "second"
467
468
# Register TruncDate and TruncTime (lookup names "date" and "time") as
# transforms on DateTimeField, in that order.
for _transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_transform)
del _transform