1from __future__ import annotations
2
3from typing import Any
4
5from plain.models.expressions import Func
6from plain.models.fields import Field, FloatField, IntegerField
7
# Public API of this module: the window-function expressions callers import.
# ``LagLeadFunction`` (the shared base of ``Lag`` and ``Lead``) is not listed,
# so it is not re-exported by ``from ... import *``.
__all__ = [
    "CumeDist",
    "DenseRank",
    "FirstValue",
    "Lag",
    "LastValue",
    "Lead",
    "NthValue",
    "Ntile",
    "PercentRank",
    "Rank",
    "RowNumber",
]
21
22
class CumeDist(Func):
    """Expression for the SQL ``CUME_DIST()`` window function.

    Flagged ``window_compatible``; the result is typed as a float
    (``FloatField``).
    """

    window_compatible = True
    function = "CUME_DIST"
    output_field = FloatField()
27
28
class DenseRank(Func):
    """Expression for the SQL ``DENSE_RANK()`` window function.

    Flagged ``window_compatible``; the result is typed as an integer
    (``IntegerField``).
    """

    window_compatible = True
    function = "DENSE_RANK"
    output_field = IntegerField()
33
34
class FirstValue(Func):
    """Expression for the SQL ``FIRST_VALUE(expression)`` window function.

    Takes exactly one argument (``arity = 1``). No ``output_field`` is
    declared here, so its resolution is left to ``Func``.
    """

    window_compatible = True
    function = "FIRST_VALUE"
    arity = 1
39
40
class LagLeadFunction(Func):
    """Common base for the ``Lag`` and ``Lead`` window-function expressions.

    Subclasses only supply ``function``. This base validates the arguments,
    emits ``(expression, offset)`` or ``(expression, offset, default)`` as
    the SQL arguments, and types the result like ``expression``.
    """

    window_compatible = True

    def __init__(
        self, expression: Any, offset: int = 1, default: Any = None, **extra: Any
    ) -> None:
        """Validate and store the source expression, offset, and default.

        Raises:
            ValueError: if ``expression`` is ``None``, or ``offset`` is
                ``None`` or not greater than 0.
        """
        name = self.__class__.__name__
        if expression is None:
            raise ValueError(f"{name} requires a non-null source expression.")
        if offset is None or offset <= 0:
            raise ValueError(f"{name} requires a positive integer for the offset.")
        # ``default`` is only passed through when the caller supplied one, so
        # the SQL call gets either two or three arguments.
        trailing: tuple[Any, ...] = () if default is None else (default,)
        super().__init__(expression, offset, *trailing, **extra)

    def _resolve_output_field(self) -> Field:
        # The first source expression is ``expression``; the result shares
        # its output field.
        return self.get_source_expressions()[0].output_field
63
64
class Lag(LagLeadFunction):
    """Expression for the SQL ``LAG`` window function.

    Argument validation and output-field resolution come from
    ``LagLeadFunction``.
    """

    function = "LAG"
67
68
class LastValue(Func):
    """Expression for the SQL ``LAST_VALUE(expression)`` window function.

    Takes exactly one argument (``arity = 1``). No ``output_field`` is
    declared here, so its resolution is left to ``Func``.
    """

    window_compatible = True
    function = "LAST_VALUE"
    arity = 1
73
74
class Lead(LagLeadFunction):
    """Expression for the SQL ``LEAD`` window function.

    Argument validation and output-field resolution come from
    ``LagLeadFunction``.
    """

    function = "LEAD"
77
78
class NthValue(Func):
    """Expression for the SQL ``NTH_VALUE(expression, nth)`` window function.

    Flagged ``window_compatible``; the result is typed like ``expression``.
    """

    function = "NTH_VALUE"
    window_compatible = True

    def __init__(self, expression: Any, nth: int = 1, **extra: Any) -> None:
        """Validate and store the source expression and position.

        Args:
            expression: The value expression; must not be ``None``.
            nth: 1-based position passed as the second SQL argument; must be
                greater than 0.
            **extra: Forwarded unchanged to ``Func``.

        Raises:
            ValueError: if ``expression`` is ``None``, or ``nth`` is ``None``
                or not greater than 0.
        """
        if expression is None:
            raise ValueError(
                f"{self.__class__.__name__} requires a non-null source expression."
            )
        if nth is None or nth <= 0:
            # Message wording parallels LagLeadFunction's offset check
            # (previously the garbled "... as for nth.").
            raise ValueError(
                f"{self.__class__.__name__} requires a positive integer for nth."
            )
        super().__init__(expression, nth, **extra)

    def _resolve_output_field(self) -> Field:
        # The first source expression is ``expression``; the result shares
        # its output field.
        sources = self.get_source_expressions()
        return sources[0].output_field
97
98
class Ntile(Func):
    """Expression for the SQL ``NTILE(num_buckets)`` window function.

    Flagged ``window_compatible``; the result is typed as an integer
    (``IntegerField``).
    """

    function = "NTILE"
    output_field = IntegerField()
    window_compatible = True

    def __init__(self, num_buckets: int = 1, **extra: Any) -> None:
        """Validate and store the bucket count.

        Args:
            num_buckets: Number of buckets; must be greater than 0.
            **extra: Forwarded unchanged to ``Func``.

        Raises:
            ValueError: if ``num_buckets`` is ``None`` or not greater than 0.
        """
        # Check for None explicitly, as the sibling offset/nth checks do, so
        # callers get the intended ValueError instead of a TypeError from
        # comparing None with 0.
        if num_buckets is None or num_buckets <= 0:
            raise ValueError("num_buckets must be greater than 0.")
        super().__init__(num_buckets, **extra)
108
109
class PercentRank(Func):
    """Expression for the SQL ``PERCENT_RANK()`` window function.

    Flagged ``window_compatible``; the result is typed as a float
    (``FloatField``).
    """

    window_compatible = True
    function = "PERCENT_RANK"
    output_field = FloatField()
114
115
class Rank(Func):
    """Expression for the SQL ``RANK()`` window function.

    Flagged ``window_compatible``; the result is typed as an integer
    (``IntegerField``).
    """

    window_compatible = True
    function = "RANK"
    output_field = IntegerField()
120
121
class RowNumber(Func):
    """Expression for the SQL ``ROW_NUMBER()`` window function.

    Flagged ``window_compatible``; the result is typed as an integer
    (``IntegerField``).
    """

    window_compatible = True
    function = "ROW_NUMBER"
    output_field = IntegerField()