Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import copy
  4from collections.abc import Callable, Iterable, Iterator, Mapping
  5from typing import Any, TypeVar
  6
  7_KT = TypeVar("_KT")
  8_VT = TypeVar("_VT")
  9
 10
 11class OrderedSet:
 12    """
 13    A set which keeps the ordering of the inserted items.
 14    """
 15
 16    def __init__(self, iterable: Iterable[Any] | None = None) -> None:
 17        self.dict: dict[Any, None] = dict.fromkeys(iterable or ())
 18
 19    def add(self, item: Any) -> None:
 20        self.dict[item] = None
 21
 22    def remove(self, item: Any) -> None:
 23        del self.dict[item]
 24
 25    def discard(self, item: Any) -> None:
 26        try:
 27            self.remove(item)
 28        except KeyError:
 29            pass
 30
 31    def __iter__(self) -> Iterator[Any]:
 32        return iter(self.dict)
 33
 34    def __reversed__(self) -> Iterator[Any]:
 35        return reversed(self.dict)
 36
 37    def __contains__(self, item: Any) -> bool:
 38        return item in self.dict
 39
 40    def __bool__(self) -> bool:
 41        return bool(self.dict)
 42
 43    def __len__(self) -> int:
 44        return len(self.dict)
 45
 46    def __repr__(self) -> str:
 47        data = repr(list(self.dict)) if self.dict else ""
 48        return f"{self.__class__.__qualname__}({data})"
 49
 50
 51class MultiValueDictKeyError(KeyError):
 52    pass
 53
 54
 55class MultiValueDict(dict[str, list[Any]]):
 56    """
 57    A subclass of dictionary customized to handle multiple values for the
 58    same key.
 59
 60    >>> d = MultiValueDict({'name': ['Adrian', 'Simon'], 'position': ['Developer']})
 61    >>> d['name']
 62    'Simon'
 63    >>> d.getlist('name')
 64    ['Adrian', 'Simon']
 65    >>> d.getlist('doesnotexist')
 66    []
 67    >>> d.getlist('doesnotexist', ['Adrian', 'Simon'])
 68    ['Adrian', 'Simon']
 69    >>> d.get('lastname', 'nonexistent')
 70    'nonexistent'
 71    >>> d.setlist('lastname', ['Holovaty', 'Willison'])
 72
 73    This class exists to solve the irritating problem raised by cgi.parse_qs,
 74    which returns a list for every key, even though most web forms submit
 75    single name-value pairs.
 76    """
 77
 78    def __init__(
 79        self,
 80        key_to_list_mapping: Mapping[str, list[Any]]
 81        | Iterable[tuple[str, list[Any]]] = (),
 82    ) -> None:
 83        super().__init__(key_to_list_mapping)
 84
 85    def __repr__(self) -> str:
 86        return f"<{self.__class__.__name__}: {super().__repr__()}>"
 87
 88    def __getitem__(self, key: str) -> Any:
 89        """
 90        Return the last data value for this key, or [] if it's an empty list;
 91        raise KeyError if not found.
 92        """
 93        try:
 94            list_ = super().__getitem__(key)
 95        except KeyError:
 96            raise MultiValueDictKeyError(key)
 97        try:
 98            return list_[-1]
 99        except IndexError:
100            return []
101
102    def __setitem__(self, key: str, value: Any) -> None:
103        super().__setitem__(key, [value])
104
105    def __copy__(self) -> MultiValueDict:
106        return self.__class__([(k, v[:]) for k, v in self.lists()])
107
108    def __deepcopy__(self, memo: dict[int, Any]) -> MultiValueDict:
109        result = self.__class__()
110        memo[id(self)] = result
111        for key, value in dict.items(self):
112            dict.__setitem__(
113                result, copy.deepcopy(key, memo), copy.deepcopy(value, memo)
114            )
115        return result
116
117    def __getstate__(self) -> dict[str, Any]:
118        return {**self.__dict__, "_data": {k: self._getlist(k) for k in self}}
119
120    def __setstate__(self, obj_dict: dict[str, Any]) -> None:
121        data = obj_dict.pop("_data", {})
122        for k, v in data.items():
123            self.setlist(k, v)
124        self.__dict__.update(obj_dict)
125
126    def get(self, key: str, default: Any = None) -> Any:
127        """
128        Return the last data value for the passed key. If key doesn't exist
129        or value is an empty list, return `default`.
130        """
131        try:
132            val = self[key]
133        except KeyError:
134            return default
135        if val == []:
136            return default
137        return val
138
139    def _getlist(
140        self, key: str, default: list[Any] | None = None, force_list: bool = False
141    ) -> list[Any] | None:
142        """
143        Return a list of values for the key.
144
145        Used internally to manipulate values list. If force_list is True,
146        return a new copy of values.
147        """
148        try:
149            values = super().__getitem__(key)
150        except KeyError:
151            if default is None:
152                return []
153            return default
154        else:
155            if force_list:
156                values = list(values) if values is not None else None
157            return values
158
159    def getlist(self, key: str, default: list[Any] | None = None) -> list[Any]:
160        """
161        Return the list of values for the key. If key doesn't exist, return a
162        default value.
163        """
164        return self._getlist(key, default, force_list=True)
165
166    def setlist(self, key: str, list_: list[Any]) -> None:
167        super().__setitem__(key, list_)
168
169    def setdefault(self, key: str, default: Any = None) -> Any:
170        if key not in self:
171            self[key] = default
172            # Do not return default here because __setitem__() may store
173            # another value -- QueryDict.__setitem__() does. Look it up.
174        return self[key]
175
176    def setlistdefault(
177        self, key: str, default_list: list[Any] | None = None
178    ) -> list[Any]:
179        if key not in self:
180            if default_list is None:
181                default_list = []
182            self.setlist(key, default_list)
183            # Do not return default_list here because setlist() may store
184            # another value -- QueryDict.setlist() does. Look it up.
185        return self._getlist(key)
186
187    def appendlist(self, key: str, value: Any) -> None:
188        """Append an item to the internal list associated with key."""
189        self.setlistdefault(key).append(value)
190
191    def items(self) -> Iterator[tuple[str, Any]]:
192        """
193        Yield (key, value) pairs, where value is the last item in the list
194        associated with the key.
195        """
196        for key in self:
197            yield key, self[key]
198
199    def lists(self) -> Iterator[tuple[str, list[Any]]]:
200        """Yield (key, list) pairs."""
201        return iter(super().items())
202
203    def values(self) -> Iterator[Any]:
204        """Yield the last value on every key list."""
205        for key in self:
206            yield self[key]
207
208    def copy(self) -> MultiValueDict:
209        """Return a shallow copy of this object."""
210        return copy.copy(self)
211
212    def update(self, *args: Any, **kwargs: Any) -> None:
213        """Extend rather than replace existing key lists."""
214        if len(args) > 1:
215            raise TypeError(f"update expected at most 1 argument, got {len(args)}")
216        if args:
217            arg = args[0]
218            if isinstance(arg, MultiValueDict):
219                for key, value_list in arg.lists():
220                    self.setlistdefault(key).extend(value_list)
221            else:
222                if isinstance(arg, Mapping):
223                    arg = arg.items()
224                for key, value in arg:
225                    self.setlistdefault(key).append(value)
226        for key, value in kwargs.items():
227            self.setlistdefault(key).append(value)
228
229    def dict(self) -> dict[str, Any]:
230        """Return current object as a dict with singular values."""
231        return {key: self[key] for key in self}
232
233
class ImmutableList(tuple):
    """
    Tuple subclass whose list-style mutators raise a descriptive error.

    Every mutating method raises AttributeError carrying the ``warning``
    message supplied at construction time.

    Example::

        >>> a = ImmutableList(range(5), warning="You cannot mutate this.")
        >>> a[3] = '4'
        Traceback (most recent call last):
            ...
        AttributeError: You cannot mutate this.
    """

    def __new__(
        cls,
        *args: Any,
        warning: str = "ImmutableList object is immutable.",
        **kwargs: Any,
    ) -> ImmutableList:
        # Tuples are built in __new__, so the message is attached here.
        instance = tuple.__new__(cls, *args, **kwargs)
        instance.warning = warning
        return instance

    def complain(self, *args: Any, **kwargs: Any) -> None:
        """Raise AttributeError with this instance's warning message."""
        raise AttributeError(self.warning)

    # Item and slice assignment/deletion.
    __setitem__ = complain
    __delitem__ = complain
    __setslice__ = complain
    __delslice__ = complain

    # In-place operators.
    __iadd__ = complain
    __imul__ = complain

    # list-style mutating methods.
    append = complain
    extend = complain
    insert = complain
    pop = complain
    remove = complain
    reverse = complain
    sort = complain
274
275
class DictWrapper(dict[str, Any]):
    """
    Dictionary whose item lookups can route values through a function.

    A key that starts with ``prefix`` has the prefix removed before the real
    lookup, and the value found is passed through ``func``. Keys without the
    prefix return the stored value untouched.

    Used by the SQL construction code to ensure that values are correctly
    quoted before being used.
    """

    def __init__(
        self, data: dict[str, Any], func: Callable[[Any], Any], prefix: str
    ) -> None:
        super().__init__(data)
        self.prefix = prefix
        self.func = func

    def __getitem__(self, key: str) -> Any:
        """
        Look up ``key`` with the prefix (if present) stripped.

        Return ``self.func(value)`` when ``key`` carried the prefix, and the
        raw value otherwise.
        """
        prefix = self.prefix
        had_prefix = key.startswith(prefix)
        raw_value = super().__getitem__(key.removeprefix(prefix))
        return self.func(raw_value) if had_prefix else raw_value
305
306
class CaseInsensitiveMapping(Mapping[str, Any]):
    """
    Read-only mapping with case-insensitive key lookups. The original case of
    keys is preserved for iteration and string representation.

    Example::

        >>> ci_map = CaseInsensitiveMapping({'name': 'Jane'})
        >>> ci_map['Name']
        'Jane'
        >>> ci_map['NAME']
        'Jane'
        >>> ci_map['name']
        'Jane'
        >>> ci_map  # original case preserved
        {'name': 'Jane'}
    """

    def __init__(self, data: Mapping[str, Any] | Iterable[tuple[str, Any]]) -> None:
        # Keyed by the lower-cased key; each entry keeps (original_key, value)
        # so iteration and repr can show the original spelling.
        self._store: dict[str, tuple[str, Any]] = {
            k.lower(): (k, v) for k, v in self._unpack_items(data)
        }

    def __getitem__(self, key: str) -> Any:
        """Return the value for ``key``, ignoring case; KeyError if absent."""
        return self._store[key.lower()][1]

    def __len__(self) -> int:
        return len(self._store)

    def __eq__(self, other: object) -> bool:
        """
        Compare against another mapping, ignoring key case.

        Return False for non-Mapping objects and for any mapping that holds
        a non-string key. Such a key has no case-insensitive form, so it
        must make the mappings unequal rather than being silently skipped.
        """
        if not isinstance(other, Mapping):
            return False
        other_folded: dict[str, Any] = {}
        for key, value in other.items():
            if not isinstance(key, str):
                return False
            other_folded[key.lower()] = value
        return {k.lower(): v for k, v in self.items()} == other_folded

    def __iter__(self) -> Iterator[str]:
        # Yield keys in their original case.
        return (original_key for original_key, value in self._store.values())

    def __repr__(self) -> str:
        # The stored (original_key, value) pairs rebuild a plain dict.
        return repr(dict(self._store.values()))

    def copy(self) -> CaseInsensitiveMapping:
        """Return ``self``; the class exposes no mutating methods."""
        return self

    @staticmethod
    def _unpack_items(
        data: Mapping[str, Any] | Iterable[tuple[str, Any]],
    ) -> Iterator[tuple[str, Any]]:
        """
        Yield ``(key, value)`` pairs from a mapping or an iterable of pairs.

        Elements of a non-mapping iterable are validated: each must have
        length 2 and a ``str`` key, otherwise ValueError is raised. Items
        coming from a mapping are yielded without validation.
        """
        # Explicitly test for dict first as the common case for performance,
        # avoiding abc's __instancecheck__ and _abc_instancecheck for the
        # general Mapping case.
        if isinstance(data, dict):
            yield from data.items()
            return
        if isinstance(data, Mapping):
            yield from data.items()
            return
        for i, elem in enumerate(data):
            if len(elem) != 2:
                raise ValueError(
                    f"dictionary update sequence element #{i} has length {len(elem)}; "
                    "2 is required."
                )
            if not isinstance(elem[0], str):
                raise ValueError(
                    f"Element key {elem[0]!r} invalid, only strings are allowed"
                )
            yield elem