Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import Any
  4
  5from plain.models.expressions import Func
  6from plain.models.fields import Field, FloatField, IntegerField
  7
  8__all__ = [
  9    "CumeDist",
 10    "DenseRank",
 11    "FirstValue",
 12    "Lag",
 13    "LastValue",
 14    "Lead",
 15    "NthValue",
 16    "Ntile",
 17    "PercentRank",
 18    "Rank",
 19    "RowNumber",
 20]
 21
 22
 23class CumeDist(Func):
 24    function = "CUME_DIST"
 25    output_field = FloatField()
 26    window_compatible = True
 27
 28
 29class DenseRank(Func):
 30    function = "DENSE_RANK"
 31    output_field = IntegerField()
 32    window_compatible = True
 33
 34
 35class FirstValue(Func):
 36    arity = 1
 37    function = "FIRST_VALUE"
 38    window_compatible = True
 39
 40
 41class LagLeadFunction(Func):
 42    window_compatible = True
 43
 44    def __init__(
 45        self, expression: Any, offset: int = 1, default: Any = None, **extra: Any
 46    ) -> None:
 47        if expression is None:
 48            raise ValueError(
 49                f"{self.__class__.__name__} requires a non-null source expression."
 50            )
 51        if offset is None or offset <= 0:
 52            raise ValueError(
 53                f"{self.__class__.__name__} requires a positive integer for the offset."
 54            )
 55        args = (expression, offset)
 56        if default is not None:
 57            args += (default,)
 58        super().__init__(*args, **extra)
 59
 60    def _resolve_output_field(self) -> Field:
 61        sources = self.get_source_expressions()
 62        return sources[0].output_field
 63
 64
 65class Lag(LagLeadFunction):
 66    function = "LAG"
 67
 68
 69class LastValue(Func):
 70    arity = 1
 71    function = "LAST_VALUE"
 72    window_compatible = True
 73
 74
 75class Lead(LagLeadFunction):
 76    function = "LEAD"
 77
 78
 79class NthValue(Func):
 80    function = "NTH_VALUE"
 81    window_compatible = True
 82
 83    def __init__(self, expression: Any, nth: int = 1, **extra: Any) -> None:
 84        if expression is None:
 85            raise ValueError(
 86                f"{self.__class__.__name__} requires a non-null source expression."
 87            )
 88        if nth is None or nth <= 0:
 89            raise ValueError(
 90                f"{self.__class__.__name__} requires a positive integer as for nth."
 91            )
 92        super().__init__(expression, nth, **extra)
 93
 94    def _resolve_output_field(self) -> Field:
 95        sources = self.get_source_expressions()
 96        return sources[0].output_field
 97
 98
 99class Ntile(Func):
100    function = "NTILE"
101    output_field = IntegerField()
102    window_compatible = True
103
104    def __init__(self, num_buckets: int = 1, **extra: Any) -> None:
105        if num_buckets <= 0:
106            raise ValueError("num_buckets must be greater than 0.")
107        super().__init__(num_buckets, **extra)
108
109
class PercentRank(Func):
    """SQL ``PERCENT_RANK`` window function; its result is typed as a ``FloatField``."""

    window_compatible = True
    function = "PERCENT_RANK"
    output_field = FloatField()
114
115
class Rank(Func):
    """SQL ``RANK`` window function; its result is typed as an ``IntegerField``."""

    window_compatible = True
    function = "RANK"
    output_field = IntegerField()
120
121
class RowNumber(Func):
    """SQL ``ROW_NUMBER`` window function; its result is typed as an ``IntegerField``."""

    window_compatible = True
    function = "ROW_NUMBER"
    output_field = IntegerField()