Plain is headed towards 1.0! Subscribe for development updates →

  1from datetime import datetime
  2
  3from plain.models.expressions import Func
  4from plain.models.fields import (
  5    DateField,
  6    DateTimeField,
  7    DurationField,
  8    Field,
  9    IntegerField,
 10    TimeField,
 11)
 12from plain.models.lookups import (
 13    Transform,
 14    YearExact,
 15    YearGt,
 16    YearGte,
 17    YearLt,
 18    YearLte,
 19)
 20from plain.runtime import settings
 21from plain.utils import timezone
 22
 23
 24class TimezoneMixin:
 25    tzinfo = None
 26
 27    def get_tzname(self):
 28        # Timezone conversions must happen to the input datetime *before*
 29        # applying a function. 2015-12-31 23:00:00 -02:00 is stored in the
 30        # database as 2016-01-01 01:00:00 +00:00. Any results should be
 31        # based on the input datetime not the stored datetime.
 32        tzname = None
 33        if settings.USE_TZ:
 34            if self.tzinfo is None:
 35                tzname = timezone.get_current_timezone_name()
 36            else:
 37                tzname = timezone._get_timezone_name(self.tzinfo)
 38        return tzname
 39
 40
 41class Extract(TimezoneMixin, Transform):
 42    lookup_name = None
 43    output_field = IntegerField()
 44
 45    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
 46        if self.lookup_name is None:
 47            self.lookup_name = lookup_name
 48        if self.lookup_name is None:
 49            raise ValueError("lookup_name must be provided")
 50        self.tzinfo = tzinfo
 51        super().__init__(expression, **extra)
 52
 53    def as_sql(self, compiler, connection):
 54        sql, params = compiler.compile(self.lhs)
 55        lhs_output_field = self.lhs.output_field
 56        if isinstance(lhs_output_field, DateTimeField):
 57            tzname = self.get_tzname()
 58            sql, params = connection.ops.datetime_extract_sql(
 59                self.lookup_name, sql, tuple(params), tzname
 60            )
 61        elif self.tzinfo is not None:
 62            raise ValueError("tzinfo can only be used with DateTimeField.")
 63        elif isinstance(lhs_output_field, DateField):
 64            sql, params = connection.ops.date_extract_sql(
 65                self.lookup_name, sql, tuple(params)
 66            )
 67        elif isinstance(lhs_output_field, TimeField):
 68            sql, params = connection.ops.time_extract_sql(
 69                self.lookup_name, sql, tuple(params)
 70            )
 71        elif isinstance(lhs_output_field, DurationField):
 72            if not connection.features.has_native_duration_field:
 73                raise ValueError(
 74                    "Extract requires native DurationField database support."
 75                )
 76            sql, params = connection.ops.time_extract_sql(
 77                self.lookup_name, sql, tuple(params)
 78            )
 79        else:
 80            # resolve_expression has already validated the output_field so this
 81            # assert should never be hit.
 82            raise ValueError("Tried to Extract from an invalid type.")
 83        return sql, params
 84
 85    def resolve_expression(
 86        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
 87    ):
 88        copy = super().resolve_expression(
 89            query, allow_joins, reuse, summarize, for_save
 90        )
 91        field = getattr(copy.lhs, "output_field", None)
 92        if field is None:
 93            return copy
 94        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
 95            raise ValueError(
 96                "Extract input expression must be DateField, DateTimeField, "
 97                "TimeField, or DurationField."
 98            )
 99        # Passing dates to functions expecting datetimes is most likely a mistake.
100        if type(field) == DateField and copy.lookup_name in (
101            "hour",
102            "minute",
103            "second",
104        ):
105            raise ValueError(
106                "Cannot extract time component '{}' from DateField '{}'.".format(
107                    copy.lookup_name, field.name
108                )
109            )
110        if isinstance(field, DurationField) and copy.lookup_name in (
111            "year",
112            "iso_year",
113            "month",
114            "week",
115            "week_day",
116            "iso_week_day",
117            "quarter",
118        ):
119            raise ValueError(
120                "Cannot extract component '{}' from DurationField '{}'.".format(
121                    copy.lookup_name, field.name
122                )
123            )
124        return copy
125
126
class ExtractYear(Extract):
    """Extract with ``lookup_name`` fixed to ``"year"``."""

    lookup_name = "year"
129
130
class ExtractIsoYear(Extract):
    """Extract the week-numbering year as defined by ISO-8601."""

    lookup_name = "iso_year"
135
136
class ExtractMonth(Extract):
    """Extract with ``lookup_name`` fixed to ``"month"``."""

    lookup_name = "month"
139
140
class ExtractDay(Extract):
    """Extract with ``lookup_name`` fixed to ``"day"``."""

    lookup_name = "day"
143
144
class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number (1 through 52 or 53); under that
    standard, Monday is the first day of the week.
    """

    lookup_name = "week"
152
153
class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered from Sunday=1 to Saturday=7.

    The Python equivalent is: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"
162
163
class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, from Monday=1 to Sunday=7."""

    lookup_name = "iso_week_day"
168
169
class ExtractQuarter(Extract):
    """Extract with ``lookup_name`` fixed to ``"quarter"``."""

    lookup_name = "quarter"
172
173
class ExtractHour(Extract):
    """Extract with ``lookup_name`` fixed to ``"hour"``."""

    lookup_name = "hour"
176
177
class ExtractMinute(Extract):
    """Extract with ``lookup_name`` fixed to ``"minute"``."""

    lookup_name = "minute"
180
181
class ExtractSecond(Extract):
    """Extract with ``lookup_name`` fixed to ``"second"``."""

    lookup_name = "second"
184
185
# Date-component transforms registered on DateField.
for _transform in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_transform)

# Time-component transforms registered on TimeField, then DateTimeField.
for _field_cls in (TimeField, DateTimeField):
    for _transform in (ExtractHour, ExtractMinute, ExtractSecond):
        _field_cls.register_lookup(_transform)

# Year comparison lookups registered on both year-extracting transforms.
for _year_transform in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _year_transform.register_lookup(_year_lookup)

# Drop the loop helpers so the module namespace is unchanged.
del _transform, _field_cls, _year_transform, _year_lookup
214
215
class Now(Func):
    """Database function for the current timestamp, typed as DateTimeField."""

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        """Compile for PostgreSQL using a statement-level timestamp."""
        # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
        # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
        # other databases.
        pg_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(compiler, connection, template=pg_template, **extra_context)

    def as_mysql(self, compiler, connection, **extra_context):
        """Compile for MySQL using the ``CURRENT_TIMESTAMP(6)`` template."""
        mysql_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        """Compile for SQLite using a ``STRFTIME(..., 'NOW')`` template."""
        sqlite_template = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler, connection, template=sqlite_template, **extra_context
        )
240
241
class TruncBase(TimezoneMixin, Transform):
    """Base transform that truncates a date/time value to ``kind`` precision.

    Subclasses set ``kind``; ``Trunc`` accepts it as a constructor argument.
    """

    kind = None
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        """Store the optional time zone and initialize the transform.

        Args:
            expression: Expression to truncate.
            output_field: Optional explicit output field.
            tzinfo: Optional time zone; rejected at compile time unless the
                input is a DateTimeField.
            **extra: Forwarded unchanged to the parent initializer.
        """
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """Compile the truncation via the backend operation for the output type.

        Returns:
            A ``(sql, params)`` pair produced by ``connection.ops``.

        Raises:
            ValueError: If ``tzinfo`` was given for a non-DateTimeField input
                or the output field is not a date, time or datetime field.
        """
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField must be tested before DateField (it is a subclass).
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the expression and validate input, output and ``kind``.

        Returns:
            The resolved copy of this expression.

        Raises:
            TypeError: If the input is not a date, time or datetime field.
            ValueError: If the output field is not a date, time or datetime
                field, or if the truncation does not make sense for the input
                field type.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Type name reported in the error messages below.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact-type test on purpose: a DateTimeField input must not match here.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(self, value, expression, connection):
        """Convert a database value to match the output field type.

        For a DateTimeField output with USE_TZ enabled, the value's tzinfo is
        replaced and the value made aware in ``self.tzinfo``. For date or time
        outputs, a ``datetime`` value is narrowed to its date or time part.

        Raises:
            ValueError: If USE_TZ is enabled, the value is None and the
                backend reports no zoneinfo database.
        """
        if isinstance(self.output_field, DateTimeField):
            if settings.USE_TZ:
                if value is not None:
                    value = value.replace(tzinfo=None)
                    value = timezone.make_aware(value, self.tzinfo)
                elif not connection.features.has_zoneinfo_database:
                    raise ValueError(
                        "Database returned an invalid datetime value. Are time "
                        "zone definitions for your database installed?"
                    )
        elif isinstance(value, datetime):
            # value cannot be None here, so no None check is needed.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
356
357
class Trunc(TruncBase):
    """Truncation transform whose ``kind`` is supplied per instance."""

    def __init__(self, expression, kind, output_field=None, tzinfo=None, **extra):
        """Set ``kind`` on the instance, then defer to ``TruncBase``."""
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
369
370
class TruncYear(TruncBase):
    """Truncation transform with ``kind`` fixed to ``"year"``."""

    kind = "year"
373
374
class TruncQuarter(TruncBase):
    """Truncation transform with ``kind`` fixed to ``"quarter"``."""

    kind = "quarter"
377
378
class TruncMonth(TruncBase):
    """Truncation transform with ``kind`` fixed to ``"month"``."""

    kind = "month"
381
382
class TruncWeek(TruncBase):
    """Truncate a value to midnight on the Monday of its week."""

    kind = "week"
387
388
class TruncDay(TruncBase):
    """Truncation transform with ``kind`` fixed to ``"day"``."""

    kind = "day"
391
392
class TruncDate(TruncBase):
    """Transform registered as the ``date`` lookup; output is a DateField."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        """Compile as a cast to date via ``datetime_cast_date_sql``."""
        # Cast to date rather than truncate to date.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tz_name = self.get_tzname()
        return connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), tz_name
        )
403
404
class TruncTime(TruncBase):
    """Transform registered as the ``time`` lookup; output is a TimeField."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        """Compile as a cast to time via ``datetime_cast_time_sql``."""
        # Cast to time rather than truncate to time.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tz_name = self.get_tzname()
        return connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), tz_name
        )
415
416
class TruncHour(TruncBase):
    """Truncation transform with ``kind`` fixed to ``"hour"``."""

    kind = "hour"
419
420
class TruncMinute(TruncBase):
    """Truncation transform with ``kind`` fixed to ``"minute"``."""

    kind = "minute"
423
424
class TruncSecond(TruncBase):
    """Truncation transform with ``kind`` fixed to ``"second"``."""

    kind = "second"
427
428
# Register the ``date`` and ``time`` cast transforms on DateTimeField.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
del _cast_transform