1"""
2Implementations of SQL functions for SQLite.
3"""
4import functools
5import random
6import statistics
7import zoneinfo
8from datetime import timedelta
9from hashlib import md5, sha1, sha224, sha256, sha384, sha512
10from math import (
11 acos,
12 asin,
13 atan,
14 atan2,
15 ceil,
16 cos,
17 degrees,
18 exp,
19 floor,
20 fmod,
21 log,
22 pi,
23 radians,
24 sin,
25 sqrt,
26 tan,
27)
28from re import search as re_search
29
30from plain.models.backends.utils import (
31 split_tzname_delta,
32 typecast_time,
33 typecast_timestamp,
34)
35from plain.utils import timezone
36from plain.utils.duration import duration_microseconds
37
38
def register(connection):
    """Install the Python-implemented SQL functions on a SQLite connection.

    Registers the ``plain_*`` date/time helpers, string/hash/math helpers,
    a ``RAND`` function and the statistical aggregates. The math helpers are
    only registered when the SQLite library was not compiled with the
    ``ENABLE_MATH_FUNCTIONS`` option (which supplies built-in versions).
    """
    add_deterministic = functools.partial(
        connection.create_function,
        deterministic=True,
    )

    for name, num_params, func in (
        ("plain_date_extract", 2, _sqlite_datetime_extract),
        ("plain_date_trunc", 4, _sqlite_date_trunc),
        ("plain_datetime_cast_date", 3, _sqlite_datetime_cast_date),
        ("plain_datetime_cast_time", 3, _sqlite_datetime_cast_time),
        ("plain_datetime_extract", 4, _sqlite_datetime_extract),
        ("plain_datetime_trunc", 4, _sqlite_datetime_trunc),
        ("plain_time_extract", 2, _sqlite_time_extract),
        ("plain_time_trunc", 4, _sqlite_time_trunc),
        ("plain_time_diff", 2, _sqlite_time_diff),
        ("plain_timestamp_diff", 2, _sqlite_timestamp_diff),
        ("plain_format_dtdelta", 3, _sqlite_format_dtdelta),
        ("regexp", 2, _sqlite_regexp),
        ("BITXOR", 2, _sqlite_bitxor),
        ("COT", 1, _sqlite_cot),
        ("LPAD", 3, _sqlite_lpad),
        ("MD5", 1, _sqlite_md5),
        ("REPEAT", 2, _sqlite_repeat),
        ("REVERSE", 1, _sqlite_reverse),
        ("RPAD", 3, _sqlite_rpad),
        ("SHA1", 1, _sqlite_sha1),
        ("SHA224", 1, _sqlite_sha224),
        ("SHA256", 1, _sqlite_sha256),
        ("SHA384", 1, _sqlite_sha384),
        ("SHA512", 1, _sqlite_sha512),
        ("SIGN", 1, _sqlite_sign),
    ):
        add_deterministic(name, num_params, func)

    # SQLite's own RANDOM() yields an integer in [-2^63, 2^63 - 1] rather
    # than a float in [0, 1), so RAND is backed by random.random. It is
    # registered without the deterministic flag on purpose.
    connection.create_function("RAND", 0, random.random)

    for name, aggregate_class in (
        ("STDDEV_POP", StdDevPop),
        ("STDDEV_SAMP", StdDevSamp),
        ("VAR_POP", VarPop),
        ("VAR_SAMP", VarSamp),
    ):
        connection.create_aggregate(name, 1, aggregate_class)

    # SQLite 3.35+ builds compiled with ENABLE_MATH_FUNCTIONS already ship
    # these; only fill them in when the library lacks them.
    math_builtin = connection.execute(
        "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')"
    ).fetchone()[0]
    if math_builtin:
        return
    for name, num_params, func in (
        ("ACOS", 1, _sqlite_acos),
        ("ASIN", 1, _sqlite_asin),
        ("ATAN", 1, _sqlite_atan),
        ("ATAN2", 2, _sqlite_atan2),
        ("CEILING", 1, _sqlite_ceiling),
        ("COS", 1, _sqlite_cos),
        ("DEGREES", 1, _sqlite_degrees),
        ("EXP", 1, _sqlite_exp),
        ("FLOOR", 1, _sqlite_floor),
        ("LN", 1, _sqlite_ln),
        ("LOG", 2, _sqlite_log),
        ("MOD", 2, _sqlite_mod),
        ("PI", 0, _sqlite_pi),
        ("POWER", 2, _sqlite_power),
        ("RADIANS", 1, _sqlite_radians),
        ("SIN", 1, _sqlite_sin),
        ("SQRT", 1, _sqlite_sqrt),
        ("TAN", 1, _sqlite_tan),
    ):
        add_deterministic(name, num_params, func)
101
102
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
    """Parse a stored timestamp string, optionally converting its time zone.

    Returns None for a NULL input or for a value typecast_timestamp() rejects
    with TypeError/ValueError. When *conn_tzname* is set, it is attached as
    the value's tzinfo. When *tzname* is given and differs from
    *conn_tzname*, the value is shifted by any signed "HH:MM" offset that
    split_tzname_delta() extracts, then converted to that zone.
    """
    if dt is None:
        return None
    try:
        parsed = typecast_timestamp(dt)
    except (TypeError, ValueError):
        return None
    if conn_tzname:
        parsed = parsed.replace(tzinfo=zoneinfo.ZoneInfo(conn_tzname))
    if tzname is None or tzname == conn_tzname:
        return parsed
    target_zone, sign, offset = split_tzname_delta(tzname)
    if offset:
        hours, minutes = offset.split(":")
        shift = timedelta(hours=int(hours), minutes=int(minutes))
        parsed = parsed + shift if sign == "+" else parsed - shift
    return timezone.localtime(parsed, zoneinfo.ZoneInfo(target_zone))
120
121
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
    """Truncate a datetime to year/quarter/month/week/day as "YYYY-MM-DD".

    Returns None when *dt* cannot be parsed. Raises ValueError for any other
    lookup type. "week" truncates back to the Monday of that week.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None
    if lookup_type not in ("year", "quarter", "month", "week", "day"):
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    if lookup_type == "week":
        # weekday() is 0 on Monday, so this steps back to that week's Monday.
        parsed -= timedelta(days=parsed.weekday())
    year, month, day = parsed.year, parsed.month, parsed.day
    if lookup_type == "year":
        month = day = 1
    elif lookup_type == "quarter":
        # Snap to the first month of the quarter: 1, 4, 7 or 10.
        month -= (month - 1) % 3
        day = 1
    elif lookup_type == "month":
        day = 1
    return f"{year:04d}-{month:02d}-{day:02d}"
139
140
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
    """Truncate a time or datetime value to hour/minute/second as "HH:MM:SS".

    A value that does not parse as a timestamp is retried as a plain time.
    If that also fails (ValueError/TypeError), None is returned. Raises
    ValueError for any other lookup type.
    """
    if dt is None:
        return None
    value = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if value is None:
        try:
            value = typecast_time(dt)
        except (ValueError, TypeError):
            return None
    if lookup_type == "hour":
        fields = (value.hour, 0, 0)
    elif lookup_type == "minute":
        fields = (value.hour, value.minute, 0)
    elif lookup_type == "second":
        fields = (value.hour, value.minute, value.second)
    else:
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    return ":".join(f"{field:02d}" for field in fields)
159
160
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
    """Return the date part of *dt* as an ISO "YYYY-MM-DD" string, or None."""
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    return None if parsed is None else parsed.date().isoformat()
166
167
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
    """Return the time part of *dt* as an ISO time string, or None."""
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    return None if parsed is None else parsed.time().isoformat()
173
174
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
    """Extract one component (year, week_day, quarter, ...) from *dt*.

    Returns None when *dt* cannot be parsed. Lookup types without special
    handling are read as an attribute of the parsed value.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None
    if lookup_type == "week_day":
        # isoweekday() is Monday=1..Sunday=7; remap to Sunday=1..Saturday=7.
        return parsed.isoweekday() % 7 + 1
    if lookup_type == "iso_week_day":
        return parsed.isoweekday()
    if lookup_type in ("week", "iso_year"):
        iso = parsed.isocalendar()
        return iso.week if lookup_type == "week" else iso.year
    if lookup_type == "quarter":
        # Months 1-3 -> 1, 4-6 -> 2, 7-9 -> 3, 10-12 -> 4.
        return (parsed.month + 2) // 3
    return getattr(parsed, lookup_type)
191
192
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
    """Truncate a datetime to *lookup_type* as "YYYY-MM-DD HH:MM:SS".

    Fields finer than the lookup are reset (date fields to 1, time fields to
    0), and "week" truncates back to that week's Monday. Returns None when
    *dt* cannot be parsed. Raises ValueError for unsupported lookup types.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None
    if lookup_type not in (
        "year",
        "quarter",
        "month",
        "week",
        "day",
        "hour",
        "minute",
        "second",
    ):
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    if lookup_type == "week":
        # weekday() is 0 on Monday, so this steps back to that week's Monday.
        parsed -= timedelta(days=parsed.weekday())
    year, month, day = parsed.year, parsed.month, parsed.day
    hour = minute = second = 0
    if lookup_type == "year":
        month = day = 1
    elif lookup_type == "quarter":
        # Snap to the first month of the quarter: 1, 4, 7 or 10.
        month -= (month - 1) % 3
        day = 1
    elif lookup_type == "month":
        day = 1
    elif lookup_type in ("hour", "minute", "second"):
        hour = parsed.hour
        if lookup_type != "hour":
            minute = parsed.minute
        if lookup_type == "second":
            second = parsed.second
    return f"{year:04d}-{month:02d}-{day:02d} {hour:02d}:{minute:02d}:{second:02d}"
222
223
def _sqlite_time_extract(lookup_type, dt):
    """Return attribute *lookup_type* (e.g. "hour") of time value *dt*.

    Returns None for NULL input or when typecast_time() rejects the value.
    """
    if dt is None:
        return None
    try:
        parsed_time = typecast_time(dt)
    except (ValueError, TypeError):
        return None
    return getattr(parsed_time, lookup_type)
232
233
def _sqlite_prepare_dtdelta_param(conn, param):
    """Convert an operand of a duration expression into a Python value.

    For "+" and "-", an int is read as microseconds and becomes a timedelta.
    Anything else is parsed with typecast_timestamp(). For other connectors,
    the operand is returned unchanged.
    """
    if conn not in ("+", "-"):
        return param
    if isinstance(param, int):
        return timedelta(microseconds=param)
    return typecast_timestamp(param)
241
242
def _sqlite_format_dtdelta(connector, lhs, rhs):
    """Evaluate ``lhs <connector> rhs`` for duration/datetime arithmetic.

    Each operand may be an integer number of microseconds, a datetime
    string, or a plain scalar such as a float. Returns None if any argument
    is NULL or an operand cannot be converted. "+" and "-" results are
    returned as strings; "*" and any other connector (division) return the
    raw result.
    """
    if any(arg is None for arg in (connector, lhs, rhs)):
        return None
    op = connector.strip()
    try:
        left = _sqlite_prepare_dtdelta_param(op, lhs)
        right = _sqlite_prepare_dtdelta_param(op, rhs)
    except (ValueError, TypeError):
        return None
    # typecast_timestamp() produces naive dates/datetimes, whose str() form
    # is "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]".
    if op == "+":
        return str(left + right)
    if op == "-":
        return str(left - right)
    if op == "*":
        return left * right
    return left / right
269
270
def _sqlite_time_diff(lhs, rhs):
    """Return ``lhs - rhs`` in microseconds for two time values, or None."""
    if lhs is None or rhs is None:
        return None

    def _as_microseconds(value):
        # Microseconds elapsed since midnight for the parsed time.
        t = typecast_time(value)
        return ((t.hour * 60 + t.minute) * 60 + t.second) * 1000000 + t.microsecond

    return _as_microseconds(lhs) - _as_microseconds(rhs)
286
287
def _sqlite_timestamp_diff(lhs, rhs):
    """Return ``lhs - rhs`` in microseconds for two timestamps, or None."""
    if lhs is None or rhs is None:
        return None
    elapsed = typecast_timestamp(lhs) - typecast_timestamp(rhs)
    return duration_microseconds(elapsed)
294
295
def _sqlite_regexp(pattern, string):
    """Implement ``string REGEXP pattern``: True if *pattern* matches anywhere.

    Non-string subjects are converted with str(). Returns None if either
    argument is NULL.
    """
    if pattern is None or string is None:
        return None
    subject = string if isinstance(string, str) else str(string)
    return re_search(pattern, subject) is not None
302
303
def _sqlite_acos(x):
    """Return the arc cosine of *x* in radians, or None for SQL NULL."""
    return None if x is None else acos(x)
308
309
def _sqlite_asin(x):
    """Return the arc sine of *x* in radians, or None for SQL NULL."""
    return None if x is None else asin(x)
314
315
def _sqlite_atan(x):
    """Return the arc tangent of *x* in radians, or None for SQL NULL."""
    return None if x is None else atan(x)
320
321
def _sqlite_atan2(y, x):
    """Return atan(y / x) in radians using both signs; None if either is NULL."""
    return None if y is None or x is None else atan2(y, x)
326
327
def _sqlite_bitxor(x, y):
    """Return the bitwise XOR of *x* and *y*, or None if either is NULL."""
    return None if x is None or y is None else x ^ y
332
333
def _sqlite_ceiling(x):
    """Return the smallest integer >= *x*, or None for SQL NULL."""
    return None if x is None else ceil(x)
338
339
def _sqlite_cos(x):
    """Return the cosine of *x* (radians), or None for SQL NULL."""
    return None if x is None else cos(x)
344
345
def _sqlite_cot(x):
    """Return the cotangent of *x* (radians), or None for SQL NULL.

    Note: a zero tangent (e.g. x == 0) raises ZeroDivisionError.
    """
    return None if x is None else 1 / tan(x)
350
351
def _sqlite_degrees(x):
    """Convert *x* from radians to degrees, or return None for SQL NULL."""
    return None if x is None else degrees(x)
356
357
def _sqlite_exp(x):
    """Return e raised to the power *x*, or None for SQL NULL."""
    return None if x is None else exp(x)
362
363
def _sqlite_floor(x):
    """Return the largest integer <= *x*, or None for SQL NULL."""
    return None if x is None else floor(x)
368
369
def _sqlite_ln(x):
    """Return the natural logarithm of *x*, or None for SQL NULL."""
    return None if x is None else log(x)
374
375
def _sqlite_log(base, x):
    """Return the logarithm of *x* to *base*, or None if either is NULL.

    SQL's LOG(base, x) takes the base first, the reverse of math.log's
    argument order.
    """
    return None if base is None or x is None else log(x, base)
381
382
def _sqlite_lpad(text, length, fill_text):
    """Left-pad *text* with repeats of *fill_text* to *length* characters.

    Text longer than *length* is truncated to its first *length* characters.
    A negative *length* is treated as zero, so the result is "". Returns
    None if any argument is NULL. An empty *fill_text* leaves short text
    unpadded.
    """
    if text is None or length is None or fill_text is None:
        return None
    # Without this clamp, text[:length] with a negative length would silently
    # drop characters from the end (e.g. LPAD('abc', -1, 'x') -> 'ab').
    length = max(length, 0)
    delta = length - len(text)
    if delta <= 0:
        return text[:length]
    return (fill_text * length)[:delta] + text
390
391
def _sqlite_md5(text):
    """Return the hex MD5 digest of *text* encoded as UTF-8, or None."""
    return None if text is None else md5(text.encode("utf-8")).hexdigest()
396
397
def _sqlite_mod(x, y):
    """Return x modulo y via math.fmod (sign follows x), or None on NULL."""
    return None if x is None or y is None else fmod(x, y)
402
403
def _sqlite_pi():
    """Return the constant pi for SQL ``PI()``."""
    constant = pi
    return constant
406
407
def _sqlite_power(x, y):
    """Return *x* raised to the power *y*, or None if either is NULL."""
    return None if x is None or y is None else x**y
412
413
def _sqlite_radians(x):
    """Convert *x* from degrees to radians, or return None for SQL NULL."""
    return None if x is None else radians(x)
418
419
def _sqlite_repeat(text, count):
    """Return *text* repeated *count* times, or None if either is NULL."""
    return None if text is None or count is None else text * count
424
425
def _sqlite_reverse(text):
    """Return *text* with its characters in reverse order, or None."""
    return None if text is None else text[::-1]
430
431
def _sqlite_rpad(text, length, fill_text):
    """Right-pad *text* with repeats of *fill_text* to *length* characters.

    Text longer than *length* is truncated to its first *length* characters.
    A negative *length* is treated as zero, so the result is "". Returns
    None if any argument is NULL.
    """
    if text is None or length is None or fill_text is None:
        return None
    # Without this clamp, slicing with a negative length would silently drop
    # characters from the end (e.g. RPAD('abc', -1, 'x') -> 'ab').
    length = max(length, 0)
    return (text + fill_text * length)[:length]
436
437
def _sqlite_sha1(text):
    """Return the hex SHA-1 digest of *text* encoded as UTF-8, or None."""
    return None if text is None else sha1(text.encode("utf-8")).hexdigest()
442
443
def _sqlite_sha224(text):
    """Return the hex SHA-224 digest of *text* encoded as UTF-8, or None."""
    return None if text is None else sha224(text.encode("utf-8")).hexdigest()
448
449
def _sqlite_sha256(text):
    """Return the hex SHA-256 digest of *text* encoded as UTF-8, or None."""
    return None if text is None else sha256(text.encode("utf-8")).hexdigest()
454
455
def _sqlite_sha384(text):
    """Return the hex SHA-384 digest of *text* encoded as UTF-8, or None."""
    return None if text is None else sha384(text.encode("utf-8")).hexdigest()
460
461
def _sqlite_sha512(text):
    """Return the hex SHA-512 digest of *text* encoded as UTF-8, or None."""
    return None if text is None else sha512(text.encode("utf-8")).hexdigest()
466
467
def _sqlite_sign(x):
    """Return 1, -1 or 0 according to the sign of *x*, or None for NULL.

    Values that are neither > 0 nor < 0 (zero, NaN) yield 0.
    """
    if x is None:
        return None
    if x > 0:
        return 1
    if x < 0:
        return -1
    return 0
472
473
def _sqlite_sin(x):
    """Return the sine of *x* (radians), or None for SQL NULL."""
    return None if x is None else sin(x)
478
479
def _sqlite_sqrt(x):
    """Return the square root of *x*, or None for SQL NULL."""
    return None if x is None else sqrt(x)
484
485
def _sqlite_tan(x):
    """Return the tangent of *x* (radians), or None for SQL NULL."""
    return None if x is None else tan(x)
490
491
class ListAggregate(list):
    """Base class for SQLite aggregates that buffer values, then summarise.

    SQLite calls ``step()`` once per row and ``finalize()`` once at the end.
    Subclasses set ``statistic`` to a callable that receives the collected
    list of non-NULL values.
    """

    # Callable applied to the collected values by finalize(). Subclasses wrap
    # it in staticmethod() so it is not bound as a method of the instance.
    statistic = None

    def step(self, value):
        # Like built-in SQL aggregates, ignore NULLs. The statistics functions
        # would otherwise raise TypeError when they meet None.
        if value is not None:
            self.append(value)

    def finalize(self):
        try:
            return self.statistic(self)
        except statistics.StatisticsError:
            # Too few non-NULL values (none at all, or a single value for a
            # sample statistic): yield SQL NULL instead of failing the query.
            return None


class StdDevPop(ListAggregate):
    """STDDEV_POP: population standard deviation of the non-NULL values."""

    statistic = staticmethod(statistics.pstdev)


class StdDevSamp(ListAggregate):
    """STDDEV_SAMP: sample standard deviation of the non-NULL values."""

    statistic = staticmethod(statistics.stdev)


class VarPop(ListAggregate):
    """VAR_POP: population variance of the non-NULL values."""

    statistic = staticmethod(statistics.pvariance)


class VarSamp(ListAggregate):
    """VAR_SAMP: sample variance of the non-NULL values."""

    statistic = staticmethod(statistics.variance)