Plain is headed towards 1.0! Subscribe for development updates →

  1"""
  2Implementations of SQL functions for SQLite.
  3"""
  4import functools
  5import random
  6import statistics
  7import zoneinfo
  8from datetime import timedelta
  9from hashlib import md5, sha1, sha224, sha256, sha384, sha512
 10from math import (
 11    acos,
 12    asin,
 13    atan,
 14    atan2,
 15    ceil,
 16    cos,
 17    degrees,
 18    exp,
 19    floor,
 20    fmod,
 21    log,
 22    pi,
 23    radians,
 24    sin,
 25    sqrt,
 26    tan,
 27)
 28from re import search as re_search
 29
 30from plain.models.backends.utils import (
 31    split_tzname_delta,
 32    typecast_time,
 33    typecast_timestamp,
 34)
 35from plain.utils import timezone
 36from plain.utils.duration import duration_microseconds
 37
 38
 39def register(connection):
 40    create_deterministic_function = functools.partial(
 41        connection.create_function,
 42        deterministic=True,
 43    )
 44    create_deterministic_function("plain_date_extract", 2, _sqlite_datetime_extract)
 45    create_deterministic_function("plain_date_trunc", 4, _sqlite_date_trunc)
 46    create_deterministic_function(
 47        "plain_datetime_cast_date", 3, _sqlite_datetime_cast_date
 48    )
 49    create_deterministic_function(
 50        "plain_datetime_cast_time", 3, _sqlite_datetime_cast_time
 51    )
 52    create_deterministic_function("plain_datetime_extract", 4, _sqlite_datetime_extract)
 53    create_deterministic_function("plain_datetime_trunc", 4, _sqlite_datetime_trunc)
 54    create_deterministic_function("plain_time_extract", 2, _sqlite_time_extract)
 55    create_deterministic_function("plain_time_trunc", 4, _sqlite_time_trunc)
 56    create_deterministic_function("plain_time_diff", 2, _sqlite_time_diff)
 57    create_deterministic_function("plain_timestamp_diff", 2, _sqlite_timestamp_diff)
 58    create_deterministic_function("plain_format_dtdelta", 3, _sqlite_format_dtdelta)
 59    create_deterministic_function("regexp", 2, _sqlite_regexp)
 60    create_deterministic_function("BITXOR", 2, _sqlite_bitxor)
 61    create_deterministic_function("COT", 1, _sqlite_cot)
 62    create_deterministic_function("LPAD", 3, _sqlite_lpad)
 63    create_deterministic_function("MD5", 1, _sqlite_md5)
 64    create_deterministic_function("REPEAT", 2, _sqlite_repeat)
 65    create_deterministic_function("REVERSE", 1, _sqlite_reverse)
 66    create_deterministic_function("RPAD", 3, _sqlite_rpad)
 67    create_deterministic_function("SHA1", 1, _sqlite_sha1)
 68    create_deterministic_function("SHA224", 1, _sqlite_sha224)
 69    create_deterministic_function("SHA256", 1, _sqlite_sha256)
 70    create_deterministic_function("SHA384", 1, _sqlite_sha384)
 71    create_deterministic_function("SHA512", 1, _sqlite_sha512)
 72    create_deterministic_function("SIGN", 1, _sqlite_sign)
 73    # Don't use the built-in RANDOM() function because it returns a value
 74    # in the range [-1 * 2^63, 2^63 - 1] instead of [0, 1).
 75    connection.create_function("RAND", 0, random.random)
 76    connection.create_aggregate("STDDEV_POP", 1, StdDevPop)
 77    connection.create_aggregate("STDDEV_SAMP", 1, StdDevSamp)
 78    connection.create_aggregate("VAR_POP", 1, VarPop)
 79    connection.create_aggregate("VAR_SAMP", 1, VarSamp)
 80    # Some math functions are enabled by default in SQLite 3.35+.
 81    sql = "select sqlite_compileoption_used('ENABLE_MATH_FUNCTIONS')"
 82    if not connection.execute(sql).fetchone()[0]:
 83        create_deterministic_function("ACOS", 1, _sqlite_acos)
 84        create_deterministic_function("ASIN", 1, _sqlite_asin)
 85        create_deterministic_function("ATAN", 1, _sqlite_atan)
 86        create_deterministic_function("ATAN2", 2, _sqlite_atan2)
 87        create_deterministic_function("CEILING", 1, _sqlite_ceiling)
 88        create_deterministic_function("COS", 1, _sqlite_cos)
 89        create_deterministic_function("DEGREES", 1, _sqlite_degrees)
 90        create_deterministic_function("EXP", 1, _sqlite_exp)
 91        create_deterministic_function("FLOOR", 1, _sqlite_floor)
 92        create_deterministic_function("LN", 1, _sqlite_ln)
 93        create_deterministic_function("LOG", 2, _sqlite_log)
 94        create_deterministic_function("MOD", 2, _sqlite_mod)
 95        create_deterministic_function("PI", 0, _sqlite_pi)
 96        create_deterministic_function("POWER", 2, _sqlite_power)
 97        create_deterministic_function("RADIANS", 1, _sqlite_radians)
 98        create_deterministic_function("SIN", 1, _sqlite_sin)
 99        create_deterministic_function("SQRT", 1, _sqlite_sqrt)
100        create_deterministic_function("TAN", 1, _sqlite_tan)
101
102
def _sqlite_datetime_parse(dt, tzname=None, conn_tzname=None):
    """
    Parse a SQLite timestamp value and optionally move it into *tzname*.

    Returns None when *dt* is NULL or typecast_timestamp() rejects it with
    TypeError/ValueError.

    If *conn_tzname* is given, it is attached as the value's tzinfo. If
    *tzname* is given and differs from *conn_tzname*, the value is then
    converted as follows:

    - *tzname* is split by split_tzname_delta() into a zone name, a sign and
      an optional "HH:MM" offset.
    - The offset is added for "+" and subtracted otherwise.
    - The result is passed through timezone.localtime() with that zone.
    """
    if dt is None:
        return None
    try:
        parsed = typecast_timestamp(dt)
    except (TypeError, ValueError):
        return None

    if conn_tzname:
        parsed = parsed.replace(tzinfo=zoneinfo.ZoneInfo(conn_tzname))

    if tzname is None or tzname == conn_tzname:
        return parsed

    zone_name, sign, offset = split_tzname_delta(tzname)
    if offset:
        hours, minutes = offset.split(":")
        shift = timedelta(hours=int(hours), minutes=int(minutes))
        parsed = parsed + shift if sign == "+" else parsed - shift
    return timezone.localtime(parsed, zoneinfo.ZoneInfo(zone_name))
120
121
def _sqlite_date_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate *dt* to the start of a year, quarter, month, week or day.

    Returns a "YYYY-MM-DD" string, or None when *dt* is NULL or unparsable.
    Weeks start on Monday, because ``weekday()`` is 0 for Monday.

    Raises ValueError for any other *lookup_type*.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None

    year, month, day = parsed.year, parsed.month, parsed.day
    if lookup_type == "year":
        month = day = 1
    elif lookup_type == "quarter":
        # First month of the quarter: 1, 4, 7 or 10.
        month, day = month - (month - 1) % 3, 1
    elif lookup_type == "month":
        day = 1
    elif lookup_type == "week":
        # Step back to Monday; this may cross a month or year boundary.
        parsed -= timedelta(days=parsed.weekday())
        year, month, day = parsed.year, parsed.month, parsed.day
    elif lookup_type != "day":
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    return f"{year:04d}-{month:02d}-{day:02d}"
139
140
def _sqlite_time_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate a time (or timestamp) value to the hour, minute or second.

    *dt* is first parsed as a full timestamp. If that fails, it is parsed as a
    bare time with typecast_time(). Returns an "HH:MM:SS" string, or None when
    *dt* is NULL or neither parse succeeds.

    Raises ValueError for any other *lookup_type*.
    """
    if dt is None:
        return None
    value = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if value is None:
        # Not a full timestamp; fall back to a plain time value.
        try:
            value = typecast_time(dt)
        except (ValueError, TypeError):
            return None

    if lookup_type not in ("hour", "minute", "second"):
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    hour = value.hour
    minute = 0 if lookup_type == "hour" else value.minute
    second = value.second if lookup_type == "second" else 0
    return f"{hour:02d}:{minute:02d}:{second:02d}"
159
160
def _sqlite_datetime_cast_date(dt, tzname, conn_tzname):
    """Return the date part of *dt* as an ISO 8601 string, or None if unparsable."""
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    return None if parsed is None else parsed.date().isoformat()
166
167
def _sqlite_datetime_cast_time(dt, tzname, conn_tzname):
    """Return the time part of *dt* as an ISO 8601 string, or None if unparsable."""
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    return None if parsed is None else parsed.time().isoformat()
173
174
def _sqlite_datetime_extract(lookup_type, dt, tzname=None, conn_tzname=None):
    """
    Extract one component of a datetime value.

    Special lookups:

    - "week_day": 1 (Sunday) to 7 (Saturday).
    - "iso_week_day": 1 (Monday) to 7 (Sunday).
    - "week" and "iso_year": taken from the ISO calendar.
    - "quarter": 1-4.

    Any other *lookup_type* is read as an attribute of the parsed value.
    Returns None when *dt* is NULL or unparsable.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None
    if lookup_type in ("week_day", "iso_week_day"):
        iso_day = parsed.isoweekday()
        # Shift ISO numbering (Sunday=7) so Sunday becomes 1.
        return iso_day % 7 + 1 if lookup_type == "week_day" else iso_day
    if lookup_type in ("week", "iso_year"):
        iso = parsed.isocalendar()
        return iso.week if lookup_type == "week" else iso.year
    if lookup_type == "quarter":
        return ceil(parsed.month / 3)
    return getattr(parsed, lookup_type)
191
192
def _sqlite_datetime_trunc(lookup_type, dt, tzname, conn_tzname):
    """
    Truncate *dt* to the given precision.

    Supported lookups: year, quarter, month, week, day, hour, minute or
    second. Returns a "YYYY-MM-DD HH:MM:SS" string, or None when *dt* is NULL
    or unparsable. Weeks start on Monday.

    Raises ValueError for any other *lookup_type*.
    """
    parsed = _sqlite_datetime_parse(dt, tzname, conn_tzname)
    if parsed is None:
        return None

    year, month, day = parsed.year, parsed.month, parsed.day
    hour = minute = second = 0
    if lookup_type == "year":
        month = day = 1
    elif lookup_type == "quarter":
        month, day = month - (month - 1) % 3, 1
    elif lookup_type == "month":
        day = 1
    elif lookup_type == "week":
        # Step back to Monday; this may cross a month or year boundary.
        parsed -= timedelta(days=parsed.weekday())
        year, month, day = parsed.year, parsed.month, parsed.day
    elif lookup_type == "day":
        pass  # Date parts already hold the day; time stays at midnight.
    elif lookup_type == "hour":
        hour = parsed.hour
    elif lookup_type == "minute":
        hour, minute = parsed.hour, parsed.minute
    elif lookup_type == "second":
        hour, minute, second = parsed.hour, parsed.minute, parsed.second
    else:
        raise ValueError(f"Unsupported lookup type: {lookup_type!r}")
    return f"{year:04d}-{month:02d}-{day:02d} {hour:02d}:{minute:02d}:{second:02d}"
222
223
def _sqlite_time_extract(lookup_type, dt):
    """
    Return attribute *lookup_type* (e.g. "hour") of a time value.

    Returns None when *dt* is NULL or typecast_time() rejects it.
    """
    if dt is None:
        return None
    try:
        parsed = typecast_time(dt)
    except (TypeError, ValueError):
        return None
    else:
        return getattr(parsed, lookup_type)
232
233
def _sqlite_prepare_dtdelta_param(conn, param):
    """
    Convert one operand of a date/duration expression.

    For "+" and "-", an int is treated as a count of microseconds and turned
    into a timedelta; anything else is parsed with typecast_timestamp(). For
    other connectors the operand is returned unchanged.
    """
    if conn not in ("+", "-"):
        return param
    if isinstance(param, int):
        return timedelta(microseconds=param)
    return typecast_timestamp(param)
241
242
def _sqlite_format_dtdelta(connector, lhs, rhs):
    """
    LHS and RHS can be either:
    - An integer number of microseconds
    - A string representing a datetime
    - A scalar value, e.g. float

    Applies *connector* ("+", "-", "*", anything else means "/") to the two
    operands. Sums and differences are returned as strings; products and
    quotients are returned as-is. Returns None for NULL arguments or operands
    that fail to convert.
    """
    if any(value is None for value in (connector, lhs, rhs)):
        return None
    op = connector.strip()
    try:
        left = _sqlite_prepare_dtdelta_param(op, lhs)
        right = _sqlite_prepare_dtdelta_param(op, rhs)
    except (ValueError, TypeError):
        return None

    if op == "+":
        # typecast_timestamp() returns a date or a datetime without timezone.
        # It will be formatted as "%Y-%m-%d" or "%Y-%m-%d %H:%M:%S[.%f]"
        return str(left + right)
    if op == "-":
        return str(left - right)
    if op == "*":
        return left * right
    return left / right
269
270
def _sqlite_time_diff(lhs, rhs):
    """
    Return ``lhs - rhs`` in microseconds, for two time-of-day values.

    Returns None if either argument is NULL.
    """
    if lhs is None or rhs is None:
        return None

    def microseconds_since_midnight(t):
        return ((t.hour * 60 + t.minute) * 60 + t.second) * 1000000 + t.microsecond

    left, right = typecast_time(lhs), typecast_time(rhs)
    return microseconds_since_midnight(left) - microseconds_since_midnight(right)
286
287
def _sqlite_timestamp_diff(lhs, rhs):
    """Return ``lhs - rhs`` in microseconds for two timestamps; None on NULL."""
    if lhs is None or rhs is None:
        return None
    delta = typecast_timestamp(lhs) - typecast_timestamp(rhs)
    return duration_microseconds(delta)
294
295
def _sqlite_regexp(pattern, string):
    """
    Implement SQLite's ``X REGEXP Y`` operator.

    Returns True if *pattern* matches anywhere in *string* (non-strings are
    converted with str()), False otherwise, and None if either is NULL.
    """
    if pattern is None or string is None:
        return None
    text = string if isinstance(string, str) else str(string)
    return re_search(pattern, text) is not None
302
303
def _sqlite_acos(x):
    """Return the arc cosine of *x* in radians, or None for NULL."""
    return None if x is None else acos(x)
308
309
def _sqlite_asin(x):
    """Return the arc sine of *x* in radians, or None for NULL."""
    return None if x is None else asin(x)
314
315
def _sqlite_atan(x):
    """Return the arc tangent of *x* in radians, or None for NULL."""
    return None if x is None else atan(x)
320
321
def _sqlite_atan2(y, x):
    """Return ``atan(y / x)`` using both signs for the quadrant; None on NULL."""
    if x is None or y is None:
        return None
    return atan2(y, x)
326
327
def _sqlite_bitxor(x, y):
    """Return the bitwise XOR of two integers, or None if either is NULL."""
    return None if x is None or y is None else x ^ y
332
333
def _sqlite_ceiling(x):
    """Return the smallest integer >= *x*, or None for NULL."""
    return None if x is None else ceil(x)
338
339
def _sqlite_cos(x):
    """Return the cosine of *x* (radians), or None for NULL."""
    return None if x is None else cos(x)
344
345
def _sqlite_cot(x):
    """Return the cotangent ``1 / tan(x)`` of *x* (radians), or None for NULL."""
    if x is None:
        return None
    tangent = tan(x)
    return 1 / tangent
350
351
def _sqlite_degrees(x):
    """Convert *x* from radians to degrees, or return None for NULL."""
    return None if x is None else degrees(x)
356
357
def _sqlite_exp(x):
    """Return e raised to the power *x*, or None for NULL."""
    return None if x is None else exp(x)
362
363
def _sqlite_floor(x):
    """Return the largest integer <= *x*, or None for NULL."""
    return None if x is None else floor(x)
368
369
def _sqlite_ln(x):
    """Return the natural logarithm of *x*, or None for NULL."""
    return None if x is None else log(x)
374
375
def _sqlite_log(base, x):
    """
    Return the logarithm of *x* to the given *base*, or None on NULL.

    The SQL argument order is ``LOG(base, x)``, the reverse of math.log().
    """
    if x is None or base is None:
        return None
    return log(x, base)
381
382
def _sqlite_lpad(text, length, fill_text):
    """
    Left-pad *text* with repetitions of *fill_text* to *length* characters.

    Text longer than *length* is truncated to its first *length* characters.
    A zero or negative *length* yields an empty string, matching PostgreSQL's
    lpad(). An empty *fill_text* leaves short text unpadded. Returns None if
    any argument is NULL.
    """
    if text is None or length is None or fill_text is None:
        return None
    # Clamp so a negative length cannot become a from-the-end slice.
    length = max(length, 0)
    delta = length - len(text)
    if delta <= 0:
        return text[:length]
    return (fill_text * delta)[:delta] + text
390
391
def _sqlite_md5(text):
    """Return the hex MD5 digest of *text* encoded with str.encode(), or None."""
    return None if text is None else md5(text.encode()).hexdigest()
396
397
def _sqlite_mod(x, y):
    """
    Return the float remainder of ``x / y`` via math.fmod, or None on NULL.

    The result has the sign of *x*.
    """
    return None if x is None or y is None else fmod(x, y)
402
403
def _sqlite_pi():
    """Return the float constant pi (SQL ``PI()``)."""
    return pi
406
407
def _sqlite_power(x, y):
    """Return *x* raised to the power *y*, or None if either is NULL."""
    if x is None or y is None:
        return None
    return pow(x, y)
412
413
def _sqlite_radians(x):
    """Convert *x* from degrees to radians, or return None for NULL."""
    return None if x is None else radians(x)
418
419
def _sqlite_repeat(text, count):
    """Return *text* repeated *count* times, or None if either is NULL."""
    return None if text is None or count is None else text * count
424
425
def _sqlite_reverse(text):
    """Return *text* with its characters in reverse order, or None for NULL."""
    return None if text is None else text[::-1]
430
431
def _sqlite_rpad(text, length, fill_text):
    """
    Right-pad *text* with repetitions of *fill_text* to *length* characters.

    Text longer than *length* is truncated to its first *length* characters.
    A zero or negative *length* yields an empty string, matching PostgreSQL's
    rpad(). Returns None if any argument is NULL.
    """
    if text is None or length is None or fill_text is None:
        return None
    # Clamp so a negative length cannot become a from-the-end slice.
    length = max(length, 0)
    return (text + fill_text * length)[:length]
436
437
def _sqlite_sha1(text):
    """Return the hex SHA-1 digest of *text* encoded with str.encode(), or None."""
    return None if text is None else sha1(text.encode()).hexdigest()
442
443
def _sqlite_sha224(text):
    """Return the hex SHA-224 digest of *text* encoded with str.encode(), or None."""
    return None if text is None else sha224(text.encode()).hexdigest()
448
449
def _sqlite_sha256(text):
    """Return the hex SHA-256 digest of *text* encoded with str.encode(), or None."""
    return None if text is None else sha256(text.encode()).hexdigest()
454
455
def _sqlite_sha384(text):
    """Return the hex SHA-384 digest of *text* encoded with str.encode(), or None."""
    return None if text is None else sha384(text.encode()).hexdigest()
460
461
def _sqlite_sha512(text):
    """Return the hex SHA-512 digest of *text* encoded with str.encode(), or None."""
    return None if text is None else sha512(text.encode()).hexdigest()
466
467
def _sqlite_sign(x):
    """Return 1, -1 or 0 according to the sign of *x*, or None for NULL."""
    if x is None:
        return None
    if x > 0:
        return 1
    return -1 if x < 0 else 0
472
473
def _sqlite_sin(x):
    """Return the sine of *x* (radians), or None for NULL."""
    return None if x is None else sin(x)
478
479
def _sqlite_sqrt(x):
    """Return the square root of *x*, or None for NULL."""
    return None if x is None else sqrt(x)
484
485
def _sqlite_tan(x):
    """Return the tangent of *x* (radians), or None for NULL."""
    return None if x is None else tan(x)
490
491
class ListAggregate(list):
    """
    Base class for sqlite3 aggregates that collect values into a list.

    step() receives one input value per row and finalize() produces the
    result. NULL inputs (passed as None) are skipped, as SQL aggregates do.
    When fewer than ``min_values`` non-NULL values were collected, finalize()
    returns None (SQL NULL) instead of letting the statistics function raise.

    Subclasses set ``aggregate`` to a callable taking the collected list.
    """

    # Fewest values ``aggregate`` accepts; below this the result is NULL.
    min_values = 1

    def step(self, value):
        if value is not None:
            self.append(value)

    def finalize(self):
        if len(self) < self.min_values:
            return None
        return self.aggregate(self)


class StdDevPop(ListAggregate):
    aggregate = staticmethod(statistics.pstdev)


class StdDevSamp(ListAggregate):
    # Sample statistics need at least two data points.
    min_values = 2
    aggregate = staticmethod(statistics.stdev)


class VarPop(ListAggregate):
    aggregate = staticmethod(statistics.pvariance)


class VarSamp(ListAggregate):
    # Sample statistics need at least two data points.
    min_values = 2
    aggregate = staticmethod(statistics.variance)