1from __future__ import annotations
2
3from datetime import datetime
4from typing import TYPE_CHECKING, Any
5
6from plain.models.expressions import Func
7from plain.models.fields import (
8 DateField,
9 DateTimeField,
10 DurationField,
11 Field,
12 IntegerField,
13 TimeField,
14)
15from plain.models.lookups import (
16 Transform,
17 YearExact,
18 YearGt,
19 YearGte,
20 YearLt,
21 YearLte,
22)
23from plain.models.postgres.sql import (
24 date_extract_sql,
25 date_trunc_sql,
26 datetime_cast_date_sql,
27 datetime_cast_time_sql,
28 datetime_extract_sql,
29 datetime_trunc_sql,
30 time_extract_sql,
31 time_trunc_sql,
32)
33from plain.utils import timezone
34
35if TYPE_CHECKING:
36 from plain.models.postgres.wrapper import DatabaseWrapper
37 from plain.models.sql.compiler import SQLCompiler
38
39
class TimezoneMixin(Transform):
    """Mixin for transforms that must shift a datetime into a timezone first.

    A value such as 2015-12-31 23:00:00 -02:00 is stored as
    2016-01-01 01:00:00 +00:00, so any function applied to it has to see the
    value in the intended timezone rather than the stored one.
    """

    tzinfo = None

    def get_tzname(self) -> str | None:
        """Return the timezone name the conversion should use.

        An explicit ``tzinfo`` on the instance takes priority; otherwise the
        currently active timezone name is used.
        """
        if self.tzinfo is not None:
            return timezone._get_timezone_name(self.tzinfo)
        return timezone.get_current_timezone_name()
52
53
class Extract(TimezoneMixin, Transform):
    """Pull a single component (year, month, hour, ...) out of a temporal value.

    The component is named by ``lookup_name``. It is either fixed on a
    subclass or supplied to the constructor. The result is always an integer.
    """

    lookup_name: str | None = None
    output_field = IntegerField()

    def __init__(
        self,
        expression: Any,
        lookup_name: str | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """Build the transform.

        Raises:
            ValueError: if neither the subclass nor the caller supplies a
                ``lookup_name``.
        """
        # A lookup_name fixed on the subclass wins; the argument only fills the gap.
        if self.lookup_name is None:
            self.lookup_name = lookup_name
            if self.lookup_name is None:
                raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile the extraction, choosing the SQL helper from the input field type.

        Raises:
            ValueError: if ``tzinfo`` was given for a non-DateTimeField input,
                or the input field type is not supported.
        """
        # __init__ guarantees this; the assert narrows the type for checkers.
        assert self.lookup_name is not None
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        source_field = self.lhs.output_field
        part = self.lookup_name
        param_tuple = tuple(lhs_params)

        # DateTimeField subclasses DateField, so it has to be tested first.
        if isinstance(source_field, DateTimeField):
            out_sql, out_params = datetime_extract_sql(
                part, lhs_sql, param_tuple, self.get_tzname()
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(source_field, DateField):
            out_sql, out_params = date_extract_sql(part, lhs_sql, param_tuple)
        elif isinstance(source_field, (TimeField, DurationField)):
            # Durations map to PostgreSQL's native interval type, which the
            # time extraction helper handles as well.
            out_sql, out_params = time_extract_sql(part, lhs_sql, param_tuple)
        else:
            # resolve_expression validates the input type, so this raise is
            # not expected to be reached.
            raise ValueError("Tried to Extract from an invalid type.")
        return out_sql, list(out_params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Extract:
        """Resolve the expression and reject nonsensical field/component pairs.

        Raises:
            ValueError: if the input is not a date, datetime, time, or duration
                field. Also raised when a clock component is requested from a
                plain DateField, or a calendar-based component is requested
                from a DurationField.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        source_field = getattr(resolved.lhs, "output_field", None)
        if source_field is None:
            # No output field available yet; nothing to validate.
            return resolved

        if not isinstance(
            source_field, (DateField, DateTimeField, TimeField, DurationField)
        ):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )

        part = resolved.lookup_name
        clock_parts = ("hour", "minute", "second")
        calendar_parts = (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        )

        # Exact type check: a DateTimeField (a DateField subclass) does carry
        # a time of day, so only a bare DateField is rejected here.
        if type(source_field) is DateField and part in clock_parts:
            raise ValueError(
                f"Cannot extract time component '{part}' from DateField '{source_field.name}'."
            )
        if isinstance(source_field, DurationField) and part in calendar_parts:
            raise ValueError(
                f"Cannot extract component '{part}' from DurationField '{source_field.name}'."
            )
        return resolved
146
147
class ExtractYear(Extract):
    """Extract the calendar year."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month number."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day of the month."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number (1-52 or 53); weeks start on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    Python equivalent: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, numbered Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second."""

    lookup_name = "second"
205
206
# Date components are available on DateField (and on subclasses such as
# DateTimeField, provided lookup registration is inherited).
for _transform in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_transform)

# Clock components are registered on TimeField first, then on DateTimeField.
for _field_class in (TimeField, DateTimeField):
    for _transform in (ExtractHour, ExtractMinute, ExtractSecond):
        _field_class.register_lookup(_transform)

# Year-style comparisons on the extracted year values.
for _transform in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _transform.register_lookup(_year_lookup)

# Drop the loop temporaries so they do not leak into the module namespace.
del _transform, _field_class, _year_lookup
235
236
class Now(Func):
    """The current timestamp, taken at the start of the running statement.

    STATEMENT_TIMESTAMP() is used instead of CURRENT_TIMESTAMP because the
    latter reports the start time of the enclosing transaction.
    """

    output_field = DateTimeField()
    template = "STATEMENT_TIMESTAMP()"
243
244
class TruncBase(TimezoneMixin, Transform):
    """Base class for transforms that truncate a temporal value to a precision.

    Subclasses set ``kind`` (e.g. "year", "hour"). That value is passed to the
    SQL truncation helper selected by ``output_field``. When no
    ``output_field`` is given, it is resolved from the input expression.
    """

    kind: str | None = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile SQL that truncates the input to ``self.kind``.

        A timezone name is passed to the helper only when the input is a
        DateTimeField.

        Raises:
            ValueError: if ``tzinfo`` was given for a non-DateTimeField input,
                or if ``output_field`` is not a date, time, or datetime field.
        """
        # Subclasses always set kind; the assert narrows the type for checkers.
        assert self.kind is not None
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField subclasses DateField, so it has to be tested first.
        if isinstance(self.output_field, DateTimeField):
            sql, params = datetime_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, DateField):
            sql, params = date_trunc_sql(self.kind, sql, tuple(params), tzname)
        elif isinstance(self.output_field, TimeField):
            sql, params = time_trunc_sql(self.kind, sql, tuple(params), tzname)
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, list(params)

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """Resolve the expression and reject truncations that lose meaning.

        Raises:
            TypeError: if the input is not a date, time, or datetime field.
            ValueError: if ``output_field`` is not a date, time, or datetime
                field. Also raised when a bare DateField is truncated to a
                clock precision or a datetime, or when a TimeField is truncated
                to a calendar precision or a datetime.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake. A Field set on the class (e.g. TruncDate) takes
        # priority over the resolved one; on classes without one, the class
        # attribute is not a Field instance and is ignored.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Name used in error messages: the explicit output type when there is
        # one, otherwise the implied DateTimeField.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact type check: DateTimeField inputs carry a time of day and are fine.
        if type(field) == DateField and (  # noqa: E721
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        if isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: DatabaseWrapper
    ) -> Any:
        """Adapt a database result to the Python type of ``output_field``.

        For a DateTimeField output, a non-None value has its tzinfo dropped.
        It is then passed to ``timezone.make_aware`` with ``self.tzinfo``.
        For a date or time output, a ``datetime`` result is narrowed with
        ``.date()`` or ``.time()``. All other values, including None, are
        returned unchanged.
        """
        output_field = self.output_field
        if isinstance(output_field, DateTimeField):
            if value is not None:
                # Reinterpret the wall-clock value in the target timezone.
                # NOTE(review): when self.tzinfo is None this relies on
                # make_aware's default timezone — confirm it matches
                # get_tzname()'s current-timezone fallback.
                value = timezone.make_aware(value.replace(tzinfo=None), self.tzinfo)
        elif isinstance(value, datetime):
            # A datetime instance is never None, so no None check is needed here.
            if isinstance(output_field, DateField):
                value = value.date()
            elif isinstance(output_field, TimeField):
                value = value.time()
        return value
362
363
class Trunc(TruncBase):
    """Truncate a temporal value to a precision chosen at construction time.

    ``kind`` names the precision (e.g. "month", "hour") and is stored on the
    instance instead of being fixed on a subclass.
    """

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **options: Any,
    ) -> None:
        # Store kind before delegating so it is in place before
        # TruncBase.__init__ runs.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **options,
        )
375
376
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight of the week's Monday."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
397
398
class TruncDate(TruncBase):
    """Reduce a datetime to its date, evaluated in the resolved timezone.

    Registered as the ``date`` lookup on DateTimeField.
    """

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile to a timezone-aware cast to date instead of a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        cast_sql, cast_params = datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), self.get_tzname()
        )
        return cast_sql, list(cast_params)
418
419
class TruncTime(TruncBase):
    """Reduce a datetime to its time of day, evaluated in the resolved timezone.

    Registered as the ``time`` lookup on DateTimeField.
    """

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self,
        compiler: SQLCompiler,
        connection: DatabaseWrapper,
        function: str | None = None,
        template: str | None = None,
        arg_joiner: str | None = None,
        **extra_context: Any,
    ) -> tuple[str, list[Any]]:
        """Compile to a timezone-aware cast to time instead of a truncation."""
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        cast_sql, cast_params = datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), self.get_tzname()
        )
        return cast_sql, list(cast_params)
439
440
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the whole second."""

    kind = "second"
451
452
# Expose __date and __time lookups on datetime columns, in this order.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
# Drop the loop temporary so it does not leak into the module namespace.
del _cast_transform