1from __future__ import annotations
2
3from datetime import datetime
4from typing import TYPE_CHECKING, Any
5
6from plain.models.expressions import Func
7from plain.models.fields import (
8 DateField,
9 DateTimeField,
10 DurationField,
11 Field,
12 IntegerField,
13 TimeField,
14)
15from plain.models.lookups import (
16 Transform,
17 YearExact,
18 YearGt,
19 YearGte,
20 YearLt,
21 YearLte,
22)
23from plain.utils import timezone
24
25if TYPE_CHECKING:
26 from plain.models.backends.base.base import BaseDatabaseWrapper
27 from plain.models.sql.compiler import SQLCompiler
28
29
class TimezoneMixin:
    """
    Mixin supplying the time zone name used for database-side datetime
    conversion.

    ``tzinfo`` may be set per instance; while it is ``None`` the currently
    active time zone is used instead.
    """

    tzinfo = None

    def get_tzname(self) -> str | None:
        """
        Return the name of the time zone that input datetimes are converted to.

        The conversion has to be applied to the input datetime *before* any
        SQL function runs on it: 2015-12-31 23:00:00 -02:00 is stored in the
        database as 2016-01-01 01:00:00 +00:00, and results must be based on
        the value the caller supplied, not on the stored one.
        """
        if self.tzinfo is not None:
            return timezone._get_timezone_name(self.tzinfo)
        return timezone.get_current_timezone_name()
42
43
class Extract(TimezoneMixin, Transform):
    """
    Extract one component (year, month, hour, ...) of a date, time, datetime,
    or duration expression as an integer.

    ``lookup_name`` selects the component. Subclasses pin it as a class
    attribute; otherwise it must be passed to the constructor. ``tzinfo`` is
    only accepted for DateTimeField inputs (enforced in ``as_sql``).
    """

    lookup_name = None
    output_field = IntegerField()

    def __init__(
        self,
        expression: Any,
        lookup_name: str | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # A class-level lookup_name (set by subclasses) wins over the argument.
        if self.lookup_name is None:
            self.lookup_name = lookup_name
            if self.lookup_name is None:
                raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to the backend's extract SQL for the input field's type.

        Raises:
            ValueError: if ``tzinfo`` is set for a non-DateTimeField input, a
                DurationField is used on a backend without native duration
                support, or the input type is not supported.
        """
        sql, params = compiler.compile(self.lhs)
        source = self.lhs.output_field
        component = self.lookup_name
        ops = connection.ops

        # DateTimeField has to be matched before DateField, its base class.
        if isinstance(source, DateTimeField):
            tzname = self.get_tzname()
            sql, params = ops.datetime_extract_sql(
                component, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(source, DateField):
            sql, params = ops.date_extract_sql(component, sql, tuple(params))
        elif isinstance(source, TimeField):
            sql, params = ops.time_extract_sql(component, sql, tuple(params))
        elif isinstance(source, DurationField):
            # Durations reuse the time-extraction SQL, which requires the
            # backend to support DurationField natively.
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            sql, params = ops.time_extract_sql(component, sql, tuple(params))
        else:
            # resolve_expression() normally rejects other input types, so this
            # is a defensive error rather than an expected path.
            raise ValueError("Tried to Extract from an invalid type.")

        if isinstance(params, list):
            params = tuple(params)
        return sql, params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> Extract:
        """
        Resolve the expression, then check that the requested component makes
        sense for the input field type.

        Validation is skipped when the resolved input has no ``output_field``.

        Raises:
            ValueError: for an unsupported input type, a time component of a
                plain DateField, or a calendar component of a DurationField.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        source = getattr(resolved.lhs, "output_field", None)
        if source is None:
            return resolved
        if not isinstance(
            source, DateField | DateTimeField | TimeField | DurationField
        ):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )

        component = resolved.lookup_name
        # A plain date carries no time of day, so asking for one is most likely
        # a mistake. The exact type check lets DateTimeField through.
        if type(source) is DateField and component in ("hour", "minute", "second"):
            raise ValueError(
                f"Cannot extract time component '{component}' from DateField '{source.name}'."
            )
        # Calendar-based components are rejected for durations.
        calendar_components = (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        )
        if isinstance(source, DurationField) and component in calendar_components:
            raise ValueError(
                f"Cannot extract component '{component}' from DurationField '{source.name}'."
            )
        return resolved
137
138
class ExtractYear(Extract):
    """Extract the calendar year."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day of the month."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number (1-52, or 53); weeks start on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week numbered Sunday=1 through Saturday=7.

    Python equivalent: ``(value.isoweekday() % 7) + 1``.
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second."""

    lookup_name = "second"
196
197
# Register the Extract transforms on the field classes (and the Year*
# comparison lookups on the year transforms), in the same order as before.
# DateTimeField subclasses DateField, so it also receives the DateField set.
for _target, _lookups in (
    (
        DateField,
        (
            ExtractYear,
            ExtractMonth,
            ExtractDay,
            ExtractWeekDay,
            ExtractIsoWeekDay,
            ExtractWeek,
            ExtractIsoYear,
            ExtractQuarter,
        ),
    ),
    (TimeField, (ExtractHour, ExtractMinute, ExtractSecond)),
    (DateTimeField, (ExtractHour, ExtractMinute, ExtractSecond)),
    (ExtractYear, (YearExact, YearGt, YearGte, YearLt, YearLte)),
    (ExtractIsoYear, (YearExact, YearGt, YearGte, YearLt, YearLte)),
):
    for _lookup in _lookups:
        _target.register_lookup(_lookup)
# Keep the loop variables out of the module namespace.
del _target, _lookups, _lookup
226
227
class Now(Func):
    """
    The database's current timestamp, typed as a DateTimeField.

    The generic SQL is ``CURRENT_TIMESTAMP``; the backend-specific methods
    below substitute a different template.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # PostgreSQL's CURRENT_TIMESTAMP is frozen at the start of the
        # transaction; STATEMENT_TIMESTAMP() follows the current statement,
        # which keeps the behavior consistent with other databases.
        statement_now = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=statement_now, **extra_context
        )

    def as_mysql(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Ask MySQL for a fractional-seconds precision of 6 digits.
        precise_now = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=precise_now, **extra_context
        )

    def as_sqlite(
        self,
        compiler: SQLCompiler,
        connection: BaseDatabaseWrapper,
        **extra_context: Any,
    ) -> tuple[str, tuple[Any, ...]]:
        # Format the timestamp with STRFTIME so fractional seconds (%f) are
        # kept. Each '%' is quadrupled, presumably because the template passes
        # through two rounds of %-interpolation before reaching SQLite.
        formatted_now = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler,
            connection,
            template=formatted_now,
            **extra_context,
        )
267
268
class TruncBase(TimezoneMixin, Transform):
    """
    Base class for transforms that truncate a date, time, or datetime value.

    ``kind`` names the unit to truncate to (e.g. "year", "hour") and is passed
    to the backend's ``*_trunc_sql`` operation. Subclasses pin it as a class
    attribute; ``Trunc`` takes it as a constructor argument.
    """

    kind = None
    tzinfo = None

    def __init__(
        self,
        expression: Any,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        """
        Args:
            expression: The date/time/datetime expression to truncate.
            output_field: Field type of the result. When omitted, it is
                resolved from the input expression.
            tzinfo: Time zone for DateTimeField inputs. ``None`` means the
                current time zone (see ``get_tzname``).
        """
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        """
        Compile to the backend's truncation SQL for the output field type.

        Raises:
            ValueError: if ``tzinfo`` is set for a non-DateTimeField input, or
                the output field is not a date, time, or datetime field.
        """
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField must be tested before DateField, its base class.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, tuple(params) if isinstance(params, list) else params

    def resolve_expression(
        self,
        query: Any = None,
        allow_joins: bool = True,
        reuse: Any = None,
        summarize: bool = False,
        for_save: bool = False,
    ) -> TruncBase:
        """
        Resolve the expression and validate the input/output field pairing.

        Raises:
            TypeError: if the input is not a date, time, or datetime field.
            ValueError: if the output field is not a date, time, or datetime
                field, or the truncation mixes incompatible types (a plain
                DateField truncated to a time unit or to a datetime, or a
                TimeField truncated to a date unit or to a datetime).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Type name reported as the truncation target in the errors below.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact type check: DateTimeField inputs may use time units.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        if isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(
        self, value: Any, expression: Any, connection: BaseDatabaseWrapper
    ) -> Any:
        """
        Post-process a value returned by the database.

        DateTimeField results have their tzinfo stripped and are then passed
        through ``timezone.make_aware`` with ``self.tzinfo``. A ``None``
        datetime result raises ValueError on backends that report no zoneinfo
        database. For date/time output fields, a ``datetime`` returned by the
        backend is narrowed with ``.date()`` / ``.time()``.
        """
        output_field = self.output_field
        if isinstance(output_field, DateTimeField):
            if value is not None:
                value = timezone.make_aware(value.replace(tzinfo=None), self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            # value is a datetime here, so it can never be None; the former
            # "value is None" check in this branch was unreachable.
            if isinstance(output_field, DateField):
                value = value.date()
            elif isinstance(output_field, TimeField):
                value = value.time()
        return value
390
391
class Trunc(TruncBase):
    """
    Truncate an expression to a unit chosen when the expression is built.

    Unlike the fixed-unit subclasses (TruncYear, TruncHour, ...), ``kind``
    is supplied as a constructor argument.
    """

    def __init__(
        self,
        expression: Any,
        kind: str,
        output_field: Field | None = None,
        tzinfo: Any = None,
        **extra: Any,
    ) -> None:
        # Stored per instance, shadowing the class-level ``kind = None``.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
403
404
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday of the week."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
426
class TruncDate(TruncBase):
    """
    Reduce a value to its date part (lookup name ``date``).

    Implemented as a cast via ``datetime_cast_date_sql`` in the resolved time
    zone rather than as a truncation.
    """

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        sql, params = connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
        if isinstance(params, list):
            params = tuple(params)
        return sql, params
440
441
class TruncTime(TruncBase):
    """
    Reduce a value to its time-of-day part (lookup name ``time``).

    Implemented as a cast via ``datetime_cast_time_sql`` in the resolved time
    zone rather than as a truncation.
    """

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(
        self, compiler: SQLCompiler, connection: BaseDatabaseWrapper
    ) -> tuple[str, tuple[Any, ...]]:
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        sql, params = connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
        if isinstance(params, list):
            params = tuple(params)
        return sql, params
455
456
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the start of the second."""

    kind = "second"
467
468
# Register TruncDate and TruncTime on DateTimeField under their lookup names
# ("date" and "time"), in that order.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
# Keep the loop variable out of the module namespace.
del _cast_transform