Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import builtins
  4import copy
  5from collections.abc import Callable, Iterable, Iterator, Mapping
  6from typing import Any, TypeVar
  7
  8_KT = TypeVar("_KT")
  9_VT = TypeVar("_VT")
 10
 11
 12class OrderedSet:
 13    """
 14    A set which keeps the ordering of the inserted items.
 15    """
 16
 17    def __init__(self, iterable: Iterable[Any] | None = None) -> None:
 18        self.dict: dict[Any, None] = dict.fromkeys(iterable or ())
 19
 20    def add(self, item: Any) -> None:
 21        self.dict[item] = None
 22
 23    def remove(self, item: Any) -> None:
 24        del self.dict[item]
 25
 26    def discard(self, item: Any) -> None:
 27        try:
 28            self.remove(item)
 29        except KeyError:
 30            pass
 31
 32    def __iter__(self) -> Iterator[Any]:
 33        return iter(self.dict)
 34
 35    def __reversed__(self) -> Iterator[Any]:
 36        return reversed(self.dict)
 37
 38    def __contains__(self, item: Any) -> bool:
 39        return item in self.dict
 40
 41    def __bool__(self) -> bool:
 42        return bool(self.dict)
 43
 44    def __len__(self) -> int:
 45        return len(self.dict)
 46
 47    def __repr__(self) -> str:
 48        data = repr(list(self.dict)) if self.dict else ""
 49        return f"{self.__class__.__qualname__}({data})"
 50
 51
 52class MultiValueDictKeyError(KeyError):
 53    pass
 54
 55
 56class MultiValueDict(dict[str, list[Any]]):
 57    """
 58    A subclass of dictionary customized to handle multiple values for the
 59    same key.
 60
 61    >>> d = MultiValueDict({'name': ['Adrian', 'Simon'], 'position': ['Developer']})
 62    >>> d['name']
 63    'Simon'
 64    >>> d.getlist('name')
 65    ['Adrian', 'Simon']
 66    >>> d.getlist('doesnotexist')
 67    []
 68    >>> d.getlist('doesnotexist', ['Adrian', 'Simon'])
 69    ['Adrian', 'Simon']
 70    >>> d.get('lastname', 'nonexistent')
 71    'nonexistent'
 72    >>> d.setlist('lastname', ['Holovaty', 'Willison'])
 73
 74    This class exists to solve the irritating problem raised by cgi.parse_qs,
 75    which returns a list for every key, even though most web forms submit
 76    single name-value pairs.
 77    """
 78
 79    def __init__(
 80        self,
 81        key_to_list_mapping: Mapping[str, list[Any]]
 82        | Iterable[tuple[str, list[Any]]] = (),
 83    ) -> None:
 84        super().__init__(key_to_list_mapping)
 85
 86    def __repr__(self) -> str:
 87        return f"<{self.__class__.__name__}: {super().__repr__()}>"
 88
 89    def __getitem__(self, key: str) -> Any:
 90        """
 91        Return the last data value for this key, or [] if it's an empty list;
 92        raise KeyError if not found.
 93        """
 94        try:
 95            list_ = super().__getitem__(key)
 96        except KeyError:
 97            raise MultiValueDictKeyError(key)
 98        try:
 99            return list_[-1]
100        except IndexError:
101            return []
102
103    def __setitem__(self, key: str, value: Any) -> None:
104        super().__setitem__(key, [value])
105
106    def __copy__(self) -> MultiValueDict:
107        return self.__class__([(k, v[:]) for k, v in self.lists()])
108
109    def __deepcopy__(self, memo: builtins.dict[int, Any]) -> MultiValueDict:
110        result = self.__class__()
111        memo[id(self)] = result
112        for key, value in dict.items(self):
113            dict.__setitem__(
114                result, copy.deepcopy(key, memo), copy.deepcopy(value, memo)
115            )
116        return result
117
118    def __getstate__(self) -> builtins.dict[str, Any]:
119        return {**self.__dict__, "_data": {k: self._getlist(k) for k in self}}
120
121    def __setstate__(self, obj_dict: builtins.dict[str, Any]) -> None:
122        data = obj_dict.pop("_data", {})
123        for k, v in data.items():
124            self.setlist(k, v)
125        self.__dict__.update(obj_dict)
126
127    def get(self, key: str, default: Any = None) -> Any:
128        """
129        Return the last data value for the passed key. If key doesn't exist
130        or value is an empty list, return `default`.
131        """
132        try:
133            val = self[key]
134        except KeyError:
135            return default
136        if val == []:
137            return default
138        return val
139
140    def _getlist(
141        self, key: str, default: list[Any] | None = None, force_list: bool = False
142    ) -> list[Any] | None:
143        """
144        Return a list of values for the key.
145
146        Used internally to manipulate values list. If force_list is True,
147        return a new copy of values.
148        """
149        try:
150            values = super().__getitem__(key)
151        except KeyError:
152            if default is None:
153                return []
154            return default
155        else:
156            if force_list:
157                values = list(values) if values is not None else None
158            return values
159
160    def getlist(self, key: str, default: list[Any] | None = None) -> list[Any]:
161        """
162        Return the list of values for the key. If key doesn't exist, return a
163        default value.
164        """
165        return self._getlist(key, default, force_list=True)  # type: ignore
166
167    def setlist(self, key: str, list_: list[Any]) -> None:
168        super().__setitem__(key, list_)
169
170    def setdefault(self, key: str, default: Any = None) -> Any:
171        if key not in self:
172            self[key] = default
173            # Do not return default here because __setitem__() may store
174            # another value -- QueryDict.__setitem__() does. Look it up.
175        return self[key]
176
177    def setlistdefault(
178        self, key: str, default_list: list[Any] | None = None
179    ) -> list[Any]:
180        if key not in self:
181            if default_list is None:
182                default_list = []
183            self.setlist(key, default_list)
184            # Do not return default_list here because setlist() may store
185            # another value -- QueryDict.setlist() does. Look it up.
186        return self._getlist(key)  # type: ignore[return-value]
187
188    def appendlist(self, key: str, value: Any) -> None:
189        """Append an item to the internal list associated with key."""
190        self.setlistdefault(key).append(value)
191
192    def items(self) -> Iterator[tuple[str, Any]]:  # type: ignore[override]
193        """
194        Yield (key, value) pairs, where value is the last item in the list
195        associated with the key.
196        """
197        for key in self:
198            yield key, self[key]
199
200    def lists(self) -> Iterator[tuple[str, list[Any]]]:
201        """Yield (key, list) pairs."""
202        return iter(super().items())
203
204    def values(self) -> Iterator[Any]:  # type: ignore[override]
205        """Yield the last value on every key list."""
206        for key in self:
207            yield self[key]
208
209    def copy(self) -> MultiValueDict:
210        """Return a shallow copy of this object."""
211        return copy.copy(self)
212
213    def update(self, *args: Any, **kwargs: Any) -> None:
214        """Extend rather than replace existing key lists."""
215        if len(args) > 1:
216            raise TypeError(f"update expected at most 1 argument, got {len(args)}")
217        if args:
218            arg = args[0]
219            if isinstance(arg, MultiValueDict):
220                for key, value_list in arg.lists():
221                    self.setlistdefault(key).extend(value_list)
222            else:
223                if isinstance(arg, Mapping):
224                    arg = arg.items()
225                for key, value in arg:
226                    self.setlistdefault(key).append(value)
227        for key, value in kwargs.items():
228            self.setlistdefault(key).append(value)
229
230    def dict(self) -> builtins.dict[str, Any]:
231        """Return current object as a dict with singular values."""
232        return {key: self[key] for key in self}
233
234
class ImmutableList(tuple):
    """
    A tuple subclass that answers list-style mutation attempts with a
    descriptive error.

    Example::

        >>> a = ImmutableList(range(5), warning="You cannot mutate this.")
        >>> a[3] = '4'
        Traceback (most recent call last):
            ...
        AttributeError: You cannot mutate this.
    """

    warning: str  # Assigned on the instance by __new__

    def __new__(
        cls,
        *args: Any,
        warning: str = "ImmutableList object is immutable.",
        **kwargs: Any,
    ) -> ImmutableList:
        # Everything except ``warning`` is forwarded to tuple's constructor.
        instance = tuple.__new__(cls, *args, **kwargs)
        instance.warning = warning
        return instance

    def complain(self, *args: Any, **kwargs: Any) -> None:
        """Raise AttributeError carrying this instance's warning message."""
        raise AttributeError(self.warning)

    # Every list-style mutator is routed to complain().
    # Item and slice assignment / deletion:
    __setitem__ = __delitem__ = complain
    __setslice__ = __delslice__ = complain
    # In-place operators:
    __iadd__ = __imul__ = complain
    # Mutating list methods:
    append = extend = insert = complain
    pop = remove = complain
    sort = reverse = complain
277
278
class DictWrapper(dict[str, Any]):
    """
    A dict whose lookups can route values through a transformation function.

    When the requested key starts with the configured prefix, the prefix is
    stripped, the remaining key is looked up, and the stored value is passed
    through ``func`` before being returned. Keys without the prefix return the
    stored value untouched.

    Used by the SQL construction code to ensure that values are correctly
    quoted before being used.
    """

    def __init__(
        self, data: dict[str, Any], func: Callable[[Any], Any], prefix: str
    ) -> None:
        super().__init__(data)
        self.func = func
        self.prefix = prefix

    def __getitem__(self, key: str) -> Any:
        """
        Look up ``key``, ignoring a leading prefix if there is one.

        A prefixed key returns ``self.func(value)``; an unprefixed key
        returns the raw stored value.
        """
        if not key.startswith(self.prefix):
            return super().__getitem__(key)
        real_key = key.removeprefix(self.prefix)
        return self.func(super().__getitem__(real_key))
308
309
class CaseInsensitiveMapping(Mapping[str, Any]):
    """
    Mapping allowing case-insensitive key lookups. Original case of keys is
    preserved for iteration and string representation.

    Example::

        >>> ci_map = CaseInsensitiveMapping({'name': 'Jane'})
        >>> ci_map['Name']
        'Jane'
        >>> ci_map['NAME']
        'Jane'
        >>> ci_map['name']
        'Jane'
        >>> ci_map  # original case preserved
        {'name': 'Jane'}
    """

    def __init__(self, data: Mapping[str, Any] | Iterable[tuple[str, Any]]) -> None:
        # Index by lower-cased key; keep the original key next to the value
        # so iteration and repr() can show the caller's spelling.
        self._store: dict[str, tuple[str, Any]] = {
            k.lower(): (k, v) for k, v in self._unpack_items(data)
        }

    def __getitem__(self, key: str) -> Any:
        """Return the value for ``key``, compared case-insensitively."""
        return self._store[key.lower()][1]

    def __len__(self) -> int:
        return len(self._store)

    def __eq__(self, other: object) -> bool:
        """
        Compare against another mapping, ignoring the case of string keys.

        Returns False for non-mappings and for any mapping that contains a
        non-string key: such a key cannot be case-folded, so it has no
        counterpart to match on this side. Silently skipping those keys
        would make mappings with extra entries compare equal.
        """
        if not isinstance(other, Mapping):
            return False
        other_folded: dict[str, Any] = {}
        for k, v in other.items():
            if not isinstance(k, str):
                return False
            other_folded[k.lower()] = v
        return {k.lower(): v for k, v in self.items()} == other_folded

    def __iter__(self) -> Iterator[str]:
        """Iterate over the keys in their original case."""
        return (original_key for original_key, value in self._store.values())

    def __repr__(self) -> str:
        # _store values are (original_key, value) pairs, i.e. dict items.
        return repr(dict(self._store.values()))

    def copy(self) -> CaseInsensitiveMapping:
        """Return this same object; no new mapping is created."""
        return self

    @staticmethod
    def _unpack_items(
        data: Mapping[str, Any] | Iterable[tuple[str, Any]],
    ) -> Iterator[tuple[str, Any]]:
        """
        Yield (key, value) pairs from a mapping or an iterable of pairs.

        Mapping items are yielded without validation. Elements of any other
        iterable must have length 2 and a string first element; otherwise
        ValueError is raised.
        """
        # Explicitly test for dict first as the common case for performance,
        # avoiding abc's __instancecheck__ and _abc_instancecheck for the
        # general Mapping case.
        if isinstance(data, dict):
            yield from data.items()
            return
        if isinstance(data, Mapping):
            yield from data.items()
            return
        for i, elem in enumerate(data):
            if len(elem) != 2:
                raise ValueError(
                    f"dictionary update sequence element #{i} has length {len(elem)}; "
                    "2 is required."
                )
            if not isinstance(elem[0], str):
                raise ValueError(
                    f"Element key {elem[0]!r} invalid, only strings are allowed"
                )
            yield elem