1from datetime import datetime
2
3from plain.models.expressions import Func
4from plain.models.fields import (
5 DateField,
6 DateTimeField,
7 DurationField,
8 Field,
9 IntegerField,
10 TimeField,
11)
12from plain.models.lookups import (
13 Transform,
14 YearExact,
15 YearGt,
16 YearGte,
17 YearLt,
18 YearLte,
19)
20from plain.utils import timezone
21
22
class TimezoneMixin:
    """
    Mixin that supplies the time zone name an expression should use when it
    operates on a datetime value.
    """

    tzinfo = None

    def get_tzname(self):
        """
        Return the name of ``self.tzinfo``, falling back to
        ``timezone.get_current_timezone_name()`` when no explicit tzinfo is set.

        The conversion has to be applied to the input datetime *before* any
        SQL function runs. For example, 2015-12-31 23:00:00 -02:00 is stored as
        2016-01-01 01:00:00 +00:00, and results must reflect the former value,
        not the stored one.
        """
        explicit_tz = self.tzinfo
        return (
            timezone.get_current_timezone_name()
            if explicit_tz is None
            else timezone._get_timezone_name(explicit_tz)
        )
35
36
class Extract(TimezoneMixin, Transform):
    """
    Pull a single integer component, named by ``lookup_name``, out of a date,
    datetime, time, or duration expression.

    Subclasses fix ``lookup_name`` at class level. When this class is used
    directly, the component must be passed as ``lookup_name=``. ``tzinfo`` is
    only accepted for DateTimeField inputs.
    """

    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        # A class-level lookup_name (set by a subclass) takes precedence over
        # the constructor argument.
        if self.lookup_name is None:
            if lookup_name is None:
                raise ValueError("lookup_name must be provided")
            self.lookup_name = lookup_name
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile to the backend's ``*_extract_sql`` operation for the input's
        field type.

        Raises ValueError when ``tzinfo`` is given for a non-datetime input,
        when the backend lacks native duration support for a DurationField
        input, or when the input type is not extractable.
        """
        sql, params = compiler.compile(self.lhs)
        source_field = self.lhs.output_field
        component = self.lookup_name

        # DateTimeField must be tested first: it is a DateField subclass, and
        # it is the only input type that takes a time zone.
        if isinstance(source_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                component, sql, tuple(params), tzname
            )
            return sql, params

        if self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")

        if isinstance(source_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                component, sql, tuple(params)
            )
            return sql, params

        if isinstance(source_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                component, sql, tuple(params)
            )
            return sql, params

        if isinstance(source_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            # Durations reuse the time extraction SQL.
            sql, params = connection.ops.time_extract_sql(
                component, sql, tuple(params)
            )
            return sql, params

        # Unreachable in practice: resolve_expression() already rejects
        # unsupported input types.
        raise ValueError("Tried to Extract from an invalid type.")

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression, then validate that the requested component
        makes sense for the input's field type.

        Validation is skipped when the input has no ``output_field`` yet.
        Raises ValueError for unsupported input types, for time components of
        a plain DateField, and for calendar components of a DurationField.
        """
        resolved = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        source_field = getattr(resolved.lhs, "output_field", None)
        if source_field is None:
            return resolved

        if not isinstance(
            source_field, DateField | DateTimeField | TimeField | DurationField
        ):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )

        component = resolved.lookup_name
        # Exact type check on purpose: a DateTimeField (a DateField subclass)
        # does carry a time part. Asking a date for hours is most likely a
        # mistake.
        is_plain_date = type(source_field) == DateField  # noqa: E721
        if is_plain_date and component in ("hour", "minute", "second"):
            raise ValueError(
                f"Cannot extract time component '{component}' from DateField '{source_field.name}'."
            )

        calendar_components = (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        )
        if isinstance(source_field, DurationField) and component in calendar_components:
            raise ValueError(
                f"Cannot extract component '{component}' from DurationField '{source_field.name}'."
            )
        return resolved
117
118
class ExtractYear(Extract):
    """Extract the calendar year."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month number."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day of the month."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number (1-52, or 53). Weeks start on Monday.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week, numbered Sunday=1 through Saturday=7.

    Python equivalent: (mydatetime.isoweekday() % 7) + 1
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week, numbered Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second."""

    lookup_name = "second"
176
177
# Register each Extract transform on the field types it applies to, in the
# same order as before.
for _field_class, _transforms in (
    (
        DateField,
        (
            ExtractYear,
            ExtractMonth,
            ExtractDay,
            ExtractWeekDay,
            ExtractIsoWeekDay,
            ExtractWeek,
            ExtractIsoYear,
            ExtractQuarter,
        ),
    ),
    (TimeField, (ExtractHour, ExtractMinute, ExtractSecond)),
    (DateTimeField, (ExtractHour, ExtractMinute, ExtractSecond)),
):
    for _transform in _transforms:
        _field_class.register_lookup(_transform)

# Attach the Year* comparison lookups to both year-valued extracts.
for _year_extract in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _year_extract.register_lookup(_year_lookup)

# Keep the loop temporaries out of the module namespace.
del _field_class, _transforms, _transform, _year_extract, _year_lookup
206
207
class Now(Func):
    """
    The current timestamp as a DateTimeField.

    The default SQL is ``CURRENT_TIMESTAMP``. PostgreSQL, MySQL, and SQLite
    each substitute their own template.
    """

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # On PostgreSQL, CURRENT_TIMESTAMP is the time at the start of the
        # transaction. STATEMENT_TIMESTAMP() matches the other backends.
        statement_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=statement_template, **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # The (6) argument asks for fractional-second precision.
        precise_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=precise_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # Format 'NOW' with STRFTIME. %f gives seconds with a fractional part.
        # NOTE(review): the quadrupled '%' presumably survives more than one
        # round of %-interpolation; confirm before changing it.
        strftime_template = "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        return self.as_sql(
            compiler,
            connection,
            template=strftime_template,
            **extra_context,
        )
232
233
class TruncBase(TimezoneMixin, Transform):
    """
    Base class for expressions that truncate a date, time, or datetime value to
    the precision named by ``kind``.

    Subclasses set ``kind`` and may fix ``output_field`` at class level. If no
    ``output_field`` is given, it is resolved from the input expression.
    ``tzinfo`` is only accepted for DateTimeField inputs.
    """

    kind = None
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """
        Compile to the backend's ``*_trunc_sql`` operation, selected by
        ``self.output_field``.

        Raises ValueError when ``tzinfo`` is given for a non-datetime input, or
        when the output field is not a date, time, or datetime field.
        """
        sql, params = compiler.compile(self.lhs)
        tzname = None
        # Only datetime inputs carry a time zone; it is handed to the backend
        # so truncation happens on the converted input value.
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField is checked before DateField because it subclasses it.
        if isinstance(self.output_field, DateTimeField):
            sql, params = connection.ops.datetime_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, DateField):
            sql, params = connection.ops.date_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        elif isinstance(self.output_field, TimeField):
            sql, params = connection.ops.time_trunc_sql(
                self.kind, sql, tuple(params), tzname
            )
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """
        Resolve the expression, then validate the input and output field types
        against ``kind``.

        Raises TypeError when the input is not a date/time field, and
        ValueError when the output field is unsupported or the truncation would
        invent or discard the wrong part (for example, truncating a DateField
        to an hour, or a TimeField to a day).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake. A class-level Field (e.g. TruncDate's DateField)
        # takes precedence over the resolved output field.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        # The output type is "explicit" when the class fixes it, or when it
        # differs from the input's type. Error messages name that type;
        # otherwise they report DateTimeField.
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact type check: a DateTimeField input is allowed time-level kinds.
        if type(field) == DateField and (  # noqa: E721
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        elif isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(self, value, expression, connection):
        """
        Convert a database value to the Python type of ``self.output_field``.

        Datetime results have any attached tzinfo stripped and are re-attached
        with ``timezone.make_aware(value, self.tzinfo)``. If the output is a
        date or time field and the backend returned a datetime, it is narrowed
        with ``.date()`` or ``.time()``. Any other value is returned unchanged.

        Raises ValueError when a datetime result is NULL and the connection
        reports no time zone database.
        """
        if isinstance(self.output_field, DateTimeField):
            if value is not None:
                value = value.replace(tzinfo=None)
                value = timezone.make_aware(value, self.tzinfo)
            elif not connection.features.has_zoneinfo_database:
                # The error text treats a NULL here as a sign that the
                # database lacks time zone definitions.
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
        elif isinstance(value, datetime):
            # value is a datetime here, so it is never None. The old
            # "value is None" branch was unreachable and has been removed.
            if isinstance(self.output_field, DateField):
                value = value.date()
            elif isinstance(self.output_field, TimeField):
                value = value.time()
        return value
346
347
class Trunc(TruncBase):
    """
    Truncate ``expression`` to a ``kind`` chosen at construction time
    (e.g. "year", "month", "hour"), instead of one fixed by a subclass.
    """

    def __init__(
        self,
        expression,
        kind,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        # Set before the parent constructor runs, as in the fixed-kind
        # subclasses where kind is a class attribute.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
359
360
class TruncYear(TruncBase):
    """Truncate to the start of the year."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to the start of the quarter."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to the start of the month."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to the start of the week: Monday at midnight."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to the start of the day."""

    kind = "day"
382
class TruncDate(TruncBase):
    """Reduce a datetime to its date part, producing a DateField value."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        # This is a cast, not a truncation: the backend's
        # datetime_cast_date_sql is given the time zone name to apply.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        return connection.ops.datetime_cast_date_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
393
394
class TruncTime(TruncBase):
    """Reduce a datetime to its time-of-day part, producing a TimeField value."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        # This is a cast, not a truncation: the backend's
        # datetime_cast_time_sql is given the time zone name to apply.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        zone_name = self.get_tzname()
        return connection.ops.datetime_cast_time_sql(
            lhs_sql, tuple(lhs_params), zone_name
        )
405
406
class TruncHour(TruncBase):
    """Truncate to the start of the hour."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to the start of the minute."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to the whole second."""

    kind = "second"
417
418
# Make the date/time cast transforms available on datetime fields, registering
# TruncDate first and then TruncTime.
for _cast_transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_cast_transform)
del _cast_transform