1from datetime import datetime
2
3from plain.models.expressions import Func
4from plain.models.fields import (
5 DateField,
6 DateTimeField,
7 DurationField,
8 Field,
9 IntegerField,
10 TimeField,
11)
12from plain.models.lookups import (
13 Transform,
14 YearExact,
15 YearGt,
16 YearGte,
17 YearLt,
18 YearLte,
19)
20from plain.runtime import settings
21from plain.utils import timezone
22
23
class TimezoneMixin:
    """Resolve which time zone name datetime SQL should be converted into.

    ``tzinfo`` may be set by subclasses/instances; when it is left as ``None``
    the currently active time zone is used instead.
    """

    tzinfo = None

    def get_tzname(self):
        """Return the time zone name to convert into, or ``None`` if USE_TZ is off.

        The conversion must be applied to the input datetime *before* any SQL
        function runs on it: 2015-12-31 23:00:00 -02:00 is stored in the
        database as 2016-01-01 01:00:00 +00:00, and results have to be based
        on the former (the input), not the latter (the stored value).
        """
        if not settings.USE_TZ:
            return None
        if self.tzinfo is not None:
            return timezone._get_timezone_name(self.tzinfo)
        return timezone.get_current_timezone_name()
39
40
class Extract(TimezoneMixin, Transform):
    """Extract one component (year, month, hour, ...) of a temporal expression.

    The component is named by ``lookup_name``. Subclasses set it as a class
    attribute; ``Extract`` itself requires it as a constructor argument. The
    result is declared as an ``IntegerField``.
    """

    lookup_name = None
    output_field = IntegerField()

    def __init__(self, expression, lookup_name=None, tzinfo=None, **extra):
        """Create the transform.

        :param expression: expression to extract the component from.
        :param lookup_name: component name; ignored when the class already
            defines ``lookup_name``.
        :param tzinfo: time zone applied to DateTimeField inputs. ``None``
            means the currently active time zone (only used when USE_TZ is on).
        :param extra: forwarded to ``Transform.__init__``.
        :raises ValueError: if neither the class nor the caller supplies a
            lookup name.
        """
        if self.lookup_name is None:
            self.lookup_name = lookup_name
        if self.lookup_name is None:
            raise ValueError("lookup_name must be provided")
        self.tzinfo = tzinfo
        super().__init__(expression, **extra)

    def as_sql(self, compiler, connection):
        """Compile to the backend's extract SQL for the input field type.

        :returns: ``(sql, params)`` as produced by the backend ``ops`` helper.
        :raises ValueError: if ``tzinfo`` is set for a non-DateTimeField input,
            if a DurationField is used on a backend without native duration
            support, or if the input type is not supported.
        """
        sql, params = compiler.compile(self.lhs)
        lhs_output_field = self.lhs.output_field
        # DateTimeField subclasses DateField, so it has to be tested first.
        if isinstance(lhs_output_field, DateTimeField):
            tzname = self.get_tzname()
            sql, params = connection.ops.datetime_extract_sql(
                self.lookup_name, sql, tuple(params), tzname
            )
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        elif isinstance(lhs_output_field, DateField):
            sql, params = connection.ops.date_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, TimeField):
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        elif isinstance(lhs_output_field, DurationField):
            if not connection.features.has_native_duration_field:
                raise ValueError(
                    "Extract requires native DurationField database support."
                )
            # Durations reuse the backend's time-extraction SQL.
            sql, params = connection.ops.time_extract_sql(
                self.lookup_name, sql, tuple(params)
            )
        else:
            # resolve_expression() rejects any other field type whenever the
            # input exposes an output field, so this should be unreachable.
            raise ValueError("Tried to Extract from an invalid type.")
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the expression and validate the input against ``lookup_name``.

        :raises ValueError: if the input is not a Date/DateTime/Time/Duration
            field, if a time component is requested from a plain DateField, or
            if a calendar component is requested from a DurationField.
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = getattr(copy.lhs, "output_field", None)
        if field is None:
            # The input exposes no output field, so there is nothing to check.
            return copy
        if not isinstance(field, DateField | DateTimeField | TimeField | DurationField):
            raise ValueError(
                "Extract input expression must be DateField, DateTimeField, "
                "TimeField, or DurationField."
            )
        # Passing dates to functions expecting datetimes is most likely a
        # mistake. Exact type match on purpose: DateTimeField is a DateField
        # subclass but does carry time components.
        if type(field) is DateField and copy.lookup_name in (
            "hour",
            "minute",
            "second",
        ):
            raise ValueError(
                "Cannot extract time component '{}' from DateField '{}'.".format(
                    copy.lookup_name, field.name
                )
            )
        if isinstance(field, DurationField) and copy.lookup_name in (
            "year",
            "iso_year",
            "month",
            "week",
            "week_day",
            "iso_week_day",
            "quarter",
        ):
            raise ValueError(
                "Cannot extract component '{}' from DurationField '{}'.".format(
                    copy.lookup_name, field.name
                )
            )
        return copy
125
126
class ExtractYear(Extract):
    """Extract the calendar year."""

    lookup_name = "year"


class ExtractIsoYear(Extract):
    """Extract the ISO-8601 week-numbering year."""

    lookup_name = "iso_year"


class ExtractMonth(Extract):
    """Extract the month."""

    lookup_name = "month"


class ExtractDay(Extract):
    """Extract the day of the month."""

    lookup_name = "day"


class ExtractWeek(Extract):
    """
    Extract the ISO-8601 week number: 1-52, or 53 in some years.

    Weeks start on Monday, as ISO-8601 specifies.
    """

    lookup_name = "week"


class ExtractWeekDay(Extract):
    """
    Extract the day of the week numbered Sunday=1 through Saturday=7.

    Python equivalent: ``(mydatetime.isoweekday() % 7) + 1``.
    """

    lookup_name = "week_day"


class ExtractIsoWeekDay(Extract):
    """Extract the ISO-8601 day of the week: Monday=1 through Sunday=7."""

    lookup_name = "iso_week_day"


class ExtractQuarter(Extract):
    """Extract the quarter of the year."""

    lookup_name = "quarter"


class ExtractHour(Extract):
    """Extract the hour."""

    lookup_name = "hour"


class ExtractMinute(Extract):
    """Extract the minute."""

    lookup_name = "minute"


class ExtractSecond(Extract):
    """Extract the second."""

    lookup_name = "second"
184
185
# Date components usable as lookups on DateField.
for _transform in (
    ExtractYear,
    ExtractMonth,
    ExtractDay,
    ExtractWeekDay,
    ExtractIsoWeekDay,
    ExtractWeek,
    ExtractIsoYear,
    ExtractQuarter,
):
    DateField.register_lookup(_transform)

# Time components usable as lookups on TimeField, then on DateTimeField.
for _field in (TimeField, DateTimeField):
    for _transform in (ExtractHour, ExtractMinute, ExtractSecond):
        _field.register_lookup(_transform)

# Year comparison lookups on both year-extracting transforms.
for _transform in (ExtractYear, ExtractIsoYear):
    for _year_lookup in (YearExact, YearGt, YearGte, YearLt, YearLte):
        _transform.register_lookup(_year_lookup)

# Keep the loop variables out of the module namespace.
del _transform, _field, _year_lookup
214
215
class Now(Func):
    """SQL expression for the current timestamp, with per-backend templates."""

    template = "CURRENT_TIMESTAMP"
    output_field = DateTimeField()

    def as_postgresql(self, compiler, connection, **extra_context):
        # On PostgreSQL, CURRENT_TIMESTAMP is the time the transaction started.
        # STATEMENT_TIMESTAMP() is used instead to stay consistent with the
        # other databases.
        statement_template = "STATEMENT_TIMESTAMP()"
        return self.as_sql(
            compiler, connection, template=statement_template, **extra_context
        )

    def as_mysql(self, compiler, connection, **extra_context):
        # The (6) argument asks MySQL for 6 fractional-second digits.
        precise_template = "CURRENT_TIMESTAMP(6)"
        return self.as_sql(
            compiler, connection, template=precise_template, **extra_context
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # NOTE(review): each '%' is written as '%%%%', presumably so that it
        # survives two rounds of %-interpolation (template, then final SQL).
        strftime_template = (
            "STRFTIME('%%%%Y-%%%%m-%%%%d %%%%H:%%%%M:%%%%f', 'NOW')"
        )
        return self.as_sql(
            compiler,
            connection,
            template=strftime_template,
            **extra_context,
        )
240
241
class TruncBase(TimezoneMixin, Transform):
    """Base transform truncating a date/time expression to ``kind`` precision.

    Subclasses set ``kind`` (e.g. "year", "hour"). The SQL generated depends
    on ``output_field``. If ``output_field`` is not given, it is resolved from
    the input expression's output field.
    """

    kind = None
    tzinfo = None

    def __init__(
        self,
        expression,
        output_field=None,
        tzinfo=None,
        **extra,
    ):
        """
        :param expression: date/time expression to truncate.
        :param output_field: field type of the result; ``None`` resolves it
            from the input expression.
        :param tzinfo: time zone for DateTimeField inputs (``None`` means the
            currently active time zone when USE_TZ is on).
        :param extra: forwarded to ``Transform.__init__``.
        """
        self.tzinfo = tzinfo
        super().__init__(expression, output_field=output_field, **extra)

    def as_sql(self, compiler, connection):
        """Compile to the backend's truncation SQL for ``output_field``.

        :raises ValueError: if ``tzinfo`` is set for a non-DateTimeField input,
            or if ``output_field`` is not a Date/DateTime/Time field.
        """
        sql, params = compiler.compile(self.lhs)
        tzname = None
        if isinstance(self.lhs.output_field, DateTimeField):
            tzname = self.get_tzname()
        elif self.tzinfo is not None:
            raise ValueError("tzinfo can only be used with DateTimeField.")
        # DateTimeField subclasses DateField, so it has to be tested first.
        if isinstance(self.output_field, DateTimeField):
            trunc_sql = connection.ops.datetime_trunc_sql
        elif isinstance(self.output_field, DateField):
            trunc_sql = connection.ops.date_trunc_sql
        elif isinstance(self.output_field, TimeField):
            trunc_sql = connection.ops.time_trunc_sql
        else:
            raise ValueError(
                "Trunc only valid on DateField, TimeField, or DateTimeField."
            )
        sql, params = trunc_sql(self.kind, sql, tuple(params), tzname)
        return sql, params

    def resolve_expression(
        self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False
    ):
        """Resolve the expression and validate input/output field compatibility.

        :raises TypeError: if the input is not a Date/DateTime/Time field.
        :raises ValueError: if ``output_field`` is not a Date/DateTime/Time
            field, or if the requested truncation does not make sense for the
            input type (e.g. a time kind on a DateField).
        """
        copy = super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        field = copy.lhs.output_field
        # DateTimeField is a subclass of DateField so this works for both.
        if not isinstance(field, DateField | TimeField):
            raise TypeError(
                f"{field.name!r} isn't a DateField, TimeField, or DateTimeField."
            )
        # If self.output_field was None, then accessing the field will trigger
        # the resolver to assign it to self.lhs.output_field.
        if not isinstance(copy.output_field, DateField | DateTimeField | TimeField):
            raise ValueError(
                "output_field must be either DateField, TimeField, or DateTimeField"
            )
        # Passing dates or times to functions expecting datetimes is most
        # likely a mistake.
        class_output_field = (
            self.__class__.output_field
            if isinstance(self.__class__.output_field, Field)
            else None
        )
        output_field = class_output_field or copy.output_field
        has_explicit_output_field = (
            class_output_field or field.__class__ is not copy.output_field.__class__
        )
        # Target type named in the error messages below.
        target_name = (
            output_field.__class__.__name__
            if has_explicit_output_field
            else "DateTimeField"
        )
        # Exact type match on purpose: DateTimeField is a DateField subclass.
        if type(field) is DateField and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("hour", "minute", "second", "time")
        ):
            raise ValueError(
                f"Cannot truncate DateField '{field.name}' to {target_name}."
            )
        if isinstance(field, TimeField) and (
            isinstance(output_field, DateTimeField)
            or copy.kind in ("year", "quarter", "month", "week", "day", "date")
        ):
            raise ValueError(
                f"Cannot truncate TimeField '{field.name}' to {target_name}."
            )
        return copy

    def convert_value(self, value, expression, connection):
        """Convert a database value to the Python type of ``output_field``.

        For DateTimeField output with USE_TZ enabled, the value's tzinfo is
        dropped and it is made aware with ``self.tzinfo``. For Date/Time
        output, a ``datetime`` returned by the backend is narrowed to its
        ``date()`` / ``time()`` part. Anything else is returned unchanged.

        :raises ValueError: if USE_TZ is on, the value is ``None`` and the
            backend reports no time zone database.
        """
        if isinstance(self.output_field, DateTimeField):
            if not settings.USE_TZ:
                return value
            if value is not None:
                value = value.replace(tzinfo=None)
                return timezone.make_aware(value, self.tzinfo)
            if not connection.features.has_zoneinfo_database:
                raise ValueError(
                    "Database returned an invalid datetime value. Are time "
                    "zone definitions for your database installed?"
                )
            return value
        if isinstance(value, datetime):
            if isinstance(self.output_field, DateField):
                return value.date()
            if isinstance(self.output_field, TimeField):
                return value.time()
        return value
356
357
class Trunc(TruncBase):
    """Truncate an expression to a precision (``kind``) given at construction.

    The fixed-precision subclasses set ``kind`` on the class instead.
    """

    def __init__(self, expression, kind, output_field=None, tzinfo=None, **extra):
        # kind is stored on the instance before the base initializer runs.
        self.kind = kind
        super().__init__(
            expression,
            output_field=output_field,
            tzinfo=tzinfo,
            **extra,
        )
369
370
class TruncYear(TruncBase):
    """Truncate to year precision."""

    kind = "year"


class TruncQuarter(TruncBase):
    """Truncate to quarter precision."""

    kind = "quarter"


class TruncMonth(TruncBase):
    """Truncate to month precision."""

    kind = "month"


class TruncWeek(TruncBase):
    """Truncate to midnight on the Monday that starts the week."""

    kind = "week"


class TruncDay(TruncBase):
    """Truncate to day precision."""

    kind = "day"
392
class TruncDate(TruncBase):
    """Date part of a datetime expression, registered as the "date" lookup."""

    kind = "date"
    lookup_name = "date"
    output_field = DateField()

    def as_sql(self, compiler, connection):
        # The date component is produced by a cast, not by truncation.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_to_date = connection.ops.datetime_cast_date_sql
        return cast_to_date(lhs_sql, tuple(lhs_params), tzname)
403
404
class TruncTime(TruncBase):
    """Time part of a datetime expression, registered as the "time" lookup."""

    kind = "time"
    lookup_name = "time"
    output_field = TimeField()

    def as_sql(self, compiler, connection):
        # The time component is produced by a cast, not by truncation.
        lhs_sql, lhs_params = compiler.compile(self.lhs)
        tzname = self.get_tzname()
        cast_to_time = connection.ops.datetime_cast_time_sql
        return cast_to_time(lhs_sql, tuple(lhs_params), tzname)
415
416
class TruncHour(TruncBase):
    """Truncate to hour precision."""

    kind = "hour"


class TruncMinute(TruncBase):
    """Truncate to minute precision."""

    kind = "minute"


class TruncSecond(TruncBase):
    """Truncate to second precision."""

    kind = "second"
427
428
# Expose the date/time casts as lookups on DateTimeField.
for _transform in (TruncDate, TruncTime):
    DateTimeField.register_lookup(_transform)
del _transform