Plain is headed towards 1.0! Subscribe for development updates →

  1from datetime import datetime
  2
  3from plain.models.expressions import Func
  4from plain.models.fields import (
  5    DateField,
  6    DateTimeField,
  7    DurationField,
  8    Field,
  9    IntegerField,
 10    TimeField,
 11)
 12from plain.models.lookups import (
 13    Transform,
 14    YearExact,
 15    YearGt,
 16    YearGte,
 17    YearLt,
 18    YearLte,
 19)
 20from plain.utils import timezone
 21
 22
 23class TimezoneMixin:
 24    tzinfo = None
 25
 26    def get_tzname(self):
 27        # Timezone conversions must happen to the input datetime *before*
 28        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
 29        # database as 2016-01-01 01:00:00 +00:00. Any results should be
 30        # based on the input datetime not the stored datetime.
 31        if self.tzinfo is None:
 32            return timezone.get_current_timezone_name()
 33        else:
 34            return timezone._get_timezone_name(self.tzinfo)
 35
 36
 37class Extract(TimezoneMixin, Transform):
 38    lookup_name = None
 39    output_field = IntegerField()
 40
 41    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
 42        if self.lookup_name is None:
 43            self.lookup_name = lookup_name
 44        if self.lookup_name is None:
 45            raise ValueError("lookup_name must be provided")
 46        self.tzinfo = tzinfo
 47        super().__init__(expression, **extra)
 48
 49    def as_sql(self, compiler, connection):
 50        sql, params = compiler.compile(self.lhs)
 51        lhs_output_field = self.lhs.output_field
 52        if isinstance(lhs_output_field, DateTimeField):
 53            tzname = self.get_tzname()
 54            sql, params = connection.ops.datetime_extract_sql(
 55                self.lookup_name, sql, tuple(params), tzname
 56            )
 57        elif self.tzinfo is not None:
 58            raise ValueError("tzinfo can only be used with DateTimeField.")
 59        elif isinstance(lhs_output_field, DateField):
 60            sql, params = connection.ops.date_extract_sql(
 61                self.lookup_name, sql, tuple(params)
 62            )
 63        elif isinstance(lhs_output_field, TimeField):
 64            sql, params = connection.ops.time_extract_sql(
 65                self.lookup_name, sql, tuple(params)
 66            )
 67        elif isinstance(lhs_output_field, DurationField):
 68            if not connection.features.has_native_duration_field:
 69                raise ValueError(
 70                    "Extract requires native DurationField database support."
 71                )
 72            sql, params = connection.ops.time_extract_sql(
 73                self.lookup_name, sql, tuple(params)
 74            )
 75        else:
 76            # resolve_expression has already validated the output_field so this
 77            # assert should never be hit.
 78            raise ValueError("Tried to Extract from an invalid type.")
 79        return sql, params
 80
 81    def resolve_expression(
 82        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
 83    ):
 84        copy = super().resolve_expression(
 85            query, allow_joins, reuse, summarize, for_save
 86        )
 87        field = getattr(copy.lhs, "output_field", None)
 88        if field is None:
 89            return copy
 90        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
 91            raise ValueError(
 92                "Extract input expression must be DateField, DateTimeField, "
 93                "TimeField, or DurationField."
 94            )
 95        # Passing dates to functions expecting datetimes is most likely a mistake.
 96        if type(field) == DateField and copy.lookup_name in (  # noqa: E721
 97            "hour",
 98            "minute",
 99            "second",
100        ):
101            raise ValueError(
102                f"Cannot extract time component '{copy.lookup_name}' from DateField '{field.name}'."
103            )
104        if isinstance(field, DurationField) and copy.lookup_name in (
105            "year",
106            "iso_year",
107            "month",
108            "week",
109            "week_day",
110            "iso_week_day",
111            "quarter",
112        ):
113            raise ValueError(
114                f"Cannot extract component '{copy.lookup_name}' from DurationField '{field.name}'."
115            )
116        return copy
117
118
class ExtractYear(Extract):
    """Extract with the component fixed to ``"year"``."""

    lookup_name = "year"
121
122
class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year (component ``"iso_year"``)."""

    lookup_name = "iso_year"
127
128
class ExtractMonth(Extract):
    """Extract with the component fixed to ``"month"``."""

    lookup_name = "month"
131
132
class ExtractDay(Extract):
    """Extract with the component fixed to ``"day"``."""

    lookup_name = "day"
135
136
class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number: 1-52, or 53 in long years.

    Under ISO-8601, Monday is the first day of the week.
    """

    lookup_name = "week"
144
145
class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    The Python equivalent is: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"
154
155
class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"
160
161
class ExtractQuarter(Extract):
    """Extract with the component fixed to ``"quarter"``."""

    lookup_name = "quarter"
164
165
class ExtractHour(Extract):
    """Extract with the component fixed to ``"hour"``."""

    lookup_name = "hour"
168
169
class ExtractMinute(Extract):
    """Extract with the component fixed to ``"minute"``."""

    lookup_name = "minute"
172
173
class ExtractSecond(Extract):
    """Extract with the component fixed to ``"second"``."""

    lookup_name = "second"
176
177
# Calendar components are registered on DateField.
# NOTE(review): DateTimeField subclasses DateField, so these are presumably
# available on datetime fields as well; confirm against register_lookup.
for _extract in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_extract)

# Clock components are registered on both TimeField and DateTimeField.
for _field in (TimeField, DateTimeField):
    for _extract in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_extract)

# The year transforms get the dedicated Year* comparison lookups.
for _extract in (ExtractYear, ExtractIsoYear):
    for _lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _extract.register_lookup(_lookup)

# Remove the loop variables so the module namespace is unchanged.
del _extract, _field, _lookup
206
207
class Now(Func):
    """
    SQL expression for the current timestamp, typed as a DateTimeField.

    The default template is ``CURRENT_TIMESTAMP``. Each ``as_<vendor>``
    method renders through ``as_sql`` with a backend-specific template.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        """Render for PostgreSQL using the statement start time."""
        # On PostgreSQL, CURRENT_TIMESTAMP is the time at the start of the
        # transaction. STATEMENT_TIMESTAMP keeps the result consistent with
        # the other databases.
        pg_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=pg_template, **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        """Render for MySQL with an explicit precision argument of 6."""
        mysql_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Render for SQLite by formatting 'NOW' with STRFTIME."""
        # NOTE(review): the percent signs are escaped twice, presumably
        # because the template goes through two rounds of %-interpolation
        # before reaching the database; confirm against Func.as_sql.
        sqlite_template = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler,
            connection,
            template=sqlite_template,
            **extra_context,
        )
232
233
class TruncBase(TimezoneMixin, Transform):
    """
    Base transform for truncating a date, time or datetime expression.

    ``kind`` names the precision to truncate to and is set by subclasses
    (or per instance by ``Trunc``). ``tzinfo`` optionally overrides the
    timezone used for datetime inputs.
    """

    kind = None
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        """Store the optional timezone and forward the rest to Transform."""
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile the wrapped expression and wrap it in the backend's trunc SQL.

        The timezone name is only computed for datetime inputs. The backend
        operation is chosen from the type of ``self.output_field``.

        Raises ValueError if ``tzinfo`` was given for a non-datetime input or
        if the output field is not a date, time or datetime field.
        """
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField must be tested before DateField because it is a
        # DateField subclass.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression and validate input field, output field and kind.

        Raises TypeError if the input is not a date, time or datetime field.
        Raises ValueError if the output field is not one of those types, if a
        plain DateField is truncated to a datetime or a time-based kind, or if
        a TimeField is truncated to a datetime or a date-based kind.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        # Only a Field instance declared on the class itself counts as a
        # class-level output field; anything else is treated as absent.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        # The output field is "explicit" when the class declares one or when
        # its type differs from the input field's type. This decides which
        # type name appears in the error messages below.
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Exact type comparison on purpose: DateTimeField inputs must not be
        # caught by the DateField rule.
        if type(field) == DateField and (  # noqa: E721
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                "Cannot truncate DateField '{}' to {}.".format(
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                "Cannot truncate TimeField '{}' to {}.".format(
                    field.name,
                    output_field.__class__.__name__
                    if has_explicit_output_field
                    else "DateTimeField",
                )
            )
        return copy

    def convert_value(self, value, expression, connection):
        """
        Adapt the value returned by the database to the output field's type.

        For a DateTimeField output, a non-null value has its tzinfo dropped
        and is then made aware with ``self.tzinfo``. For other outputs, a
        ``datetime`` value is narrowed to a ``date`` or ``time`` to match the
        output field. Any other value is returned unchanged.

        Raises ValueError when a DateTimeField output yields a null value and
        the backend reports no zoneinfo database.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            # value is a datetime here and therefore never None, so no null
            # check is needed in this branch.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
346
347
class Trunc(TruncBase):
    """Truncation whose ``kind`` is chosen per instance instead of per class."""

    def __init__(
        self,
        expression,
        kind,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        """Record ``kind`` on the instance, then defer to TruncBase."""
        # Set before the parent initialiser runs, as in the class-level case
        # where kind is already available at that point.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
359
360
class TruncYear(TruncBase):
    """Truncation with ``kind`` fixed to ``"year"``."""

    kind = "year"
363
364
class TruncQuarter(TruncBase):
    """Truncation with ``kind`` fixed to ``"quarter"``."""

    kind = "quarter"
367
368
class TruncMonth(TruncBase):
    """Truncation with ``kind`` fixed to ``"month"``."""

    kind = "month"
371
372
class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the containing week."""

    kind = "week"
377
378
class TruncDay(TruncBase):
    """Truncation with ``kind`` fixed to ``"day"``."""

    kind = "day"
381
382
class TruncDate(TruncBase):
    """Cast a datetime expression to its date part; exposed as the ``date`` lookup."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        """Render a cast to date instead of a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        return connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
393
394
class TruncTime(TruncBase):
    """Cast a datetime expression to its time part; exposed as the ``time`` lookup."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        """Render a cast to time instead of a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        return connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
405
406
class TruncHour(TruncBase):
    """Truncation with ``kind`` fixed to ``"hour"``."""

    kind = "hour"
409
410
class TruncMinute(TruncBase):
    """Truncation with ``kind`` fixed to ``"minute"``."""

    kind = "minute"
413
414
class TruncSecond(TruncBase):
    """Truncation with ``kind`` fixed to ``"second"``."""

    kind = "second"
417
418
# Register the date and time casts as lookups on DateTimeField.
for _trunc in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_trunc)

# Remove the loop variable so the module namespace is unchanged.
del _trunc