1from __future__ import annotations
2
3from datetime import datetime
4from typing import TYPE_CHECKING, Any
5
6from plain.models.expressions import Func
7from plain.models.fields import (
8 DateField,
9 DateTimeField,
10 DurationField,
11 Field,
12 IntegerField,
13 TimeField,
14)
15from plain.models.lookups import (
16 Transform,
17 YearExact,
18 YearGt,
19 YearGte,
20 YearLt,
21 YearLte,
22)
23from plain.utils import timezone
24
25if TYPE_CHECKING:
26 from plain.models.backends.base.base import BaseDatabaseWrapper
27 from plain.models.sql.compiler import SQLCompiler
28
29
class TimezoneMixin(Transform):
    """Provide the time zone name a transform should convert to in SQL."""

    tzinfo = None

    def get_tzname(self) -> str | None:
        """
        Return the name of ``self.tzinfo``, or of the currently active time
        zone when no explicit ``tzinfo`` was given.
        """
        # The conversion has to be applied to the input datetime *before*
        # any function runs on it: 2015-12-31 23:00:00 -02:00 is stored as
        # 2016-01-01 01:00:00 +00:00, but results must reflect the input
        # value rather than the stored UTC value.
        explicit_tz = self.tzinfo
        if explicit_tz is not None:
            return timezone._get_timezone_name(explicit_tz)
        return timezone.get_current_timezone_name()
42
43
class Extract(TimezoneMixin, Transform):
    """
    Extract one component (named by ``lookup_name``) of a date, time,
    datetime, or duration expression as an integer.

    Concrete subclasses such as ``ExtractYear`` fix ``lookup_name`` as a
    class attribute; plain ``Extract`` takes it as a constructor argument.
    """

    lookup_name: str | None = None
    output_field = IntegerField()

    def __init__(
        self,
        expression: Any,
        lookup_name: str | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: Expression to extract from.
            lookup_name: Component to extract. Only used when the class does
                not already define ``lookup_name``.
            tzinfo: Time zone for DateTimeField input. ``as_sql`` raises
                ValueError if it is set for any other input type.
            **extra: Forwarded to ``Transform``.

        Raises:
            ValueError: If neither the class nor the caller supplies a
                lookup name.
        """
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to the backend's extract SQL for the input's field type.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        accepted for signature compatibility and are not used.

        Raises:
            ValueError: If ``tzinfo`` is set for non-DateTimeField input, if
                a DurationField is used on a backend without native duration
                support, or if the input type is not supported at all.
        """
        # lookup_name is guaranteed to be str after __init__ validation.
        assert self.lookup_name is not None
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        # DateTimeField must be tested before DateField, its base class.
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                self.lookup_name, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(lhs_output_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        else:
            # resolve_expression already rejects any other input type, so
            # this branch should be unreachable.
            raise ValueError("Tried to Extract from an invalid type.")
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Extract:
        """
        Resolve the expression and validate the input type against the
        requested component.

        Raises:
            ValueError: If the input is not a DateField, DateTimeField,
                TimeField, or DurationField; if a time component is requested
                from a plain DateField; or if a calendar component is
                requested from a DurationField.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = getattr(copy.lhs, "output_field", None)
        if field is None:
            # Input type unknown at this point; nothing to validate.
            return copy
        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )
        # Passing dates to functions expecting datetimes is most likely a
        # mistake. The exact-type check is deliberate: DateTimeField
        # subclasses DateField but does carry a time component.
        if type(field) is DateField and copy.lookup_name in (
            "hour",
            "minute",
            "second",
        ):
            raise ValueError(
                f"Cannot extract time component '{copy.lookup_name}' from DateField '{field.name}'."
            )
        if isinstance(field, DurationField) and copy.lookup_name in (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        ):
            raise ValueError(
                f"Cannot extract component '{copy.lookup_name}' from DurationField '{field.name}'."
            )
        return copy
145
146
class ExtractYear(Extract):
    """Extract the calendar year."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Extract the year as numbered by ISO-8601 weeks."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month number."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day of the month."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number: 1 through 52 (or 53), with weeks
    starting on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    Python equivalent: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, numbered Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second."""

    lookup_name = "second"
204
205
# Calendar components become lookups on DateField.
for _lookup in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_lookup)

# Clock components are registered on TimeField first, then DateTimeField.
for _field in (TimeField, DateTimeField):
    for _lookup in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_lookup)

# Year-based extracts get the specialised year comparison lookups.
for _extract in (ExtractYear, ExtractIsoYear):
    for _lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _extract.register_lookup(_lookup)

# Keep the module namespace free of loop temporaries.
del _field, _extract, _lookup
234
235
class Now(Func):
    """
    The database's current timestamp as a DateTimeField expression.

    Renders as ``CURRENT_TIMESTAMP`` by default; the vendor hooks below
    substitute a backend-specific template.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render with STATEMENT_TIMESTAMP() on PostgreSQL."""
        # On PostgreSQL, CURRENT_TIMESTAMP is fixed at the start of the
        # transaction. STATEMENT_TIMESTAMP() behaves like the other
        # databases' current-time functions.
        pg_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=pg_template, **extra_context
        )

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render with CURRENT_TIMESTAMP(6) on MySQL."""
        # The (6) argument requests microsecond fractional-second precision.
        mysql_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=mysql_template, **extra_context
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Render as a STRFTIME('...', 'NOW') call on SQLite."""
        # NOTE(review): each '%' is quadrupled, presumably because the text
        # goes through two rounds of %-interpolation (template rendering,
        # then parameter binding) — confirm.
        sqlite_template = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler, connection, template=sqlite_template, **extra_context
        )
275
276
class TruncBase(TimezoneMixin, Transform):
    """
    Truncate a date, time, or datetime expression to the precision named
    by ``kind`` (e.g. ``"year"``, ``"hour"``).

    Concrete subclasses set ``kind`` as a class attribute; ``Trunc`` takes
    it as a constructor argument.
    """

    kind: str | None = None
    tzinfo = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: Expression to truncate.
            output_field: Optional explicit result type. When None, the
                result type is resolved later from the input expression.
            tzinfo: Time zone for DateTimeField input. ``as_sql`` raises
                ValueError if it is set for any other input type.
            **extra: Forwarded to ``Transform``.
        """
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """
        Compile to the backend's truncation SQL for the output field type.

        ``function``, ``template``, ``arg_joiner`` and ``extra_context`` are
        accepted for signature compatibility and are not used.

        Raises:
            ValueError: If ``tzinfo`` is set for non-DateTimeField input, or
                if the output field is not a date, time, or datetime field.
        """
        # kind is guaranteed to be str in subclasses.
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField must be tested before DateField, its base class.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """
        Resolve the expression and reject input/output combinations that
        cannot be truncated meaningfully.

        Raises:
            TypeError: If the input is not a DateField, TimeField, or
                DateTimeField.
            ValueError: If the output field has an unsupported type, or a
                date is truncated to a time precision (or to a datetime), or
                a time is truncated to a date precision (or to a datetime).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Target type named in the error messages below: the explicitly
        # requested output type if there is one, otherwise DateTimeField.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact-type check: DateTimeField subclasses DateField but does
        # carry a time component.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        if isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """
        Convert a raw database value to the Python type of ``output_field``.

        Returns:
            For DateTimeField output, a non-None value made aware via
            ``timezone.make_aware`` with ``self.tzinfo``; None passes through.
            For DateField/TimeField output, datetime values are narrowed
            with ``.date()`` / ``.time()``. Anything else is returned as-is.

        Raises:
            ValueError: If a DateTimeField result is None and the backend's
                features report no time zone database.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                # Drop any tzinfo the driver attached and re-attach
                # self.tzinfo. NOTE(review): assumes make_aware(value, None)
                # falls back to the current zone, matching get_tzname().
                return timezone.make_aware(value.replace(tzinfo=None), self.tzinfo)
            if not connection.features.has_zoneinfo_database:
                # Without zone definitions the database likely could not
                # convert the value and returned NULL instead.
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
            return value
        # isinstance() already excludes None, so no separate None check is
        # needed here.
        if isinstance(value, datetime):
            if isinstance(self.output_field, DateField):
                return value.date()
            if isinstance(self.output_field, TimeField):
                return value.time()
        return value
406
407
class Trunc(TruncBase):
    """Truncation whose precision (``kind``) is chosen by the caller."""

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: Expression to truncate.
            kind: Truncation precision, e.g. ``"month"`` or ``"hour"``.
            output_field: Optional explicit result type.
            tzinfo: Optional time zone for DateTimeField input.
            **extra: Forwarded to ``TruncBase``.
        """
        # Store the precision on the instance, shadowing the class-level
        # default of None, before the base initializer runs.
        self.kind = kind
        super().__init__(
            expression,
            tzinfo=tzinfo,
            output_field=output_field,
            **extra,
        )
419
420
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight of the week's Monday."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
441
442
class TruncDate(TruncBase):
    """
    Reduce a datetime to its date part.

    Compiled as a cast (``datetime_cast_date_sql``) rather than a
    truncation; the result is a DateField value.
    """

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Wrap the compiled input in the backend's datetime-to-date cast."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        cast_sql, cast_params = connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
        return cast_sql, list(cast_params)
462
463
class TruncTime(TruncBase):
    """
    Reduce a datetime to its time-of-day part.

    Compiled as a cast (``datetime_cast_time_sql``) rather than a
    truncation; the result is a TimeField value.
    """

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Wrap the compiled input in the backend's datetime-to-time cast."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        cast_sql, cast_params = connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
        return cast_sql, list(cast_params)
483
484
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the start of the second."""

    kind = "second"
495
496
# Expose the date and time casts as `date` / `time` lookups on DateTimeField.
for _trunc_lookup in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_trunc_lookup)
del _trunc_lookup