1from __future__ import annotations
2
3import copy
4from collections.abc import Callable, Iterable, Iterator, Mapping
5from typing import Any, TypeVar
6
# Module-level type variables, named for key (_KT) and value (_VT) types.
_KT, _VT = TypeVar("_KT"), TypeVar("_VT")
9
10
class OrderedSet:
    """
    Set-like collection that remembers the order in which items were added.

    Members are stored as the keys of the internal dict ``self.dict`` (every
    value is ``None``). Because dicts keep insertion order, iteration follows
    first-insertion order, and re-adding an existing member does not move it.
    """

    def __init__(self, iterable: Iterable[Any] | None = None) -> None:
        # Any falsy argument (None or an empty container) gives an empty set.
        initial_items = iterable if iterable else ()
        self.dict: dict[Any, None] = dict.fromkeys(initial_items)

    def add(self, item: Any) -> None:
        """Insert ``item``; an existing member keeps its original position."""
        self.dict[item] = None

    def remove(self, item: Any) -> None:
        """Delete ``item``; raise KeyError if it is not a member."""
        del self.dict[item]

    def discard(self, item: Any) -> None:
        """Delete ``item`` if present; a missing item is silently ignored."""
        try:
            self.remove(item)
        except KeyError:
            return

    def __iter__(self) -> Iterator[Any]:
        return iter(self.dict)

    def __reversed__(self) -> Iterator[Any]:
        return reversed(self.dict)

    def __contains__(self, item: Any) -> bool:
        return item in self.dict

    def __bool__(self) -> bool:
        return len(self.dict) > 0

    def __len__(self) -> int:
        return len(self.dict)

    def __repr__(self) -> str:
        name = self.__class__.__qualname__
        if not self.dict:
            return f"{name}()"
        members = list(self.dict)
        return f"{name}({members!r})"
49
50
class MultiValueDictKeyError(KeyError):
    """KeyError raised by MultiValueDict item access when a key is absent."""
53
54
class MultiValueDict(dict[str, list[Any]]):
    """
    Dictionary that keeps a list of values under every key.

    Indexing with ``d[key]`` returns only the final element of that key's
    list. The ``getlist``/``setlist``/``appendlist``/``lists`` family works
    with the complete list.

    >>> d = MultiValueDict({'name': ['Adrian', 'Simon'], 'position': ['Developer']})
    >>> d['name']
    'Simon'
    >>> d.getlist('name')
    ['Adrian', 'Simon']
    >>> d.getlist('doesnotexist')
    []
    >>> d.getlist('doesnotexist', ['Adrian', 'Simon'])
    ['Adrian', 'Simon']
    >>> d.get('lastname', 'nonexistent')
    'nonexistent'
    >>> d.setlist('lastname', ['Holovaty', 'Willison'])

    Rationale: query-string parsers (historically ``cgi.parse_qs``, today
    ``urllib.parse.parse_qs``) return a list for every key, even though most
    web forms submit exactly one value per field name.
    """

    def __init__(
        self,
        key_to_list_mapping: Mapping[str, list[Any]]
        | Iterable[tuple[str, list[Any]]] = (),
    ) -> None:
        # The provided lists are stored as-is: no copying, no validation.
        super().__init__(key_to_list_mapping)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        dict_repr = super().__repr__()
        return f"<{class_name}: {dict_repr}>"

    def __getitem__(self, key: str) -> Any:
        """
        Return the final element stored for ``key``.

        An empty stored list yields ``[]``. A missing key raises
        MultiValueDictKeyError, which subclasses KeyError.
        """
        try:
            stored = super().__getitem__(key)
        except KeyError:
            raise MultiValueDictKeyError(key)
        try:
            last = stored[-1]
        except IndexError:
            return []
        return last

    def __setitem__(self, key: str, value: Any) -> None:
        # Plain assignment replaces the whole list with a single-item list.
        single = [value]
        super().__setitem__(key, single)

    def __copy__(self) -> MultiValueDict:
        # Each list is duplicated; the elements inside it are shared.
        cls = self.__class__
        duplicated = [(name, values[:]) for name, values in self.lists()]
        return cls(duplicated)

    def __deepcopy__(self, memo: dict[int, Any]) -> MultiValueDict:
        clone = self.__class__()
        # Register the clone before recursing so that references back to
        # this object resolve to the clone.
        memo[id(self)] = clone
        # Go through dict's own items/__setitem__. This class's overrides
        # expose only last values, and they would wrap values in new lists.
        for name, values in dict.items(self):
            new_name = copy.deepcopy(name, memo)
            new_values = copy.deepcopy(values, memo)
            dict.__setitem__(clone, new_name, new_values)
        return clone

    def __getstate__(self) -> dict[str, Any]:
        # Store the full lists under "_data", next to any instance attributes.
        state = self.__dict__.copy()
        state["_data"] = {name: self._getlist(name) for name in self}
        return state

    def __setstate__(self, obj_dict: dict[str, Any]) -> None:
        saved_lists = obj_dict.pop("_data", {})
        for name, values in saved_lists.items():
            self.setlist(name, values)
        # Whatever remains in obj_dict is ordinary instance-attribute state.
        self.__dict__.update(obj_dict)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Return the final value for ``key``. If the key is missing or its
        list is empty, return ``default`` instead.
        """
        try:
            found = self[key]
        except KeyError:
            return default
        return default if found == [] else found

    def _getlist(
        self, key: str, default: list[Any] | None = None, force_list: bool = False
    ) -> list[Any] | None:
        """
        Internal accessor for the list stored under ``key``.

        A missing key gives ``default``, or ``[]`` when ``default`` is None.
        When ``force_list`` is set, a found (non-None) list is returned as a
        fresh copy.
        """
        try:
            stored = super().__getitem__(key)
        except KeyError:
            return [] if default is None else default
        if force_list and stored is not None:
            return list(stored)
        return stored

    def getlist(self, key: str, default: list[Any] | None = None) -> list[Any]:
        """
        Return a copy of the list for ``key``. If the key is missing, return
        ``default`` (or ``[]`` when ``default`` is None).
        """
        return self._getlist(key, default, force_list=True)

    def setlist(self, key: str, list_: list[Any]) -> None:
        """Store ``list_`` itself (not a copy) as the values for ``key``."""
        super().__setitem__(key, list_)

    def setdefault(self, key: str, default: Any = None) -> Any:
        if key not in self:
            self[key] = default
        # Read the value back rather than returning ``default``: a subclass
        # may override __setitem__ to store something different.
        return self[key]

    def setlistdefault(
        self, key: str, default_list: list[Any] | None = None
    ) -> list[Any]:
        if key not in self:
            self.setlist(key, [] if default_list is None else default_list)
        # Read the list back rather than returning ``default_list``: a
        # subclass may override setlist() to store something different.
        return self._getlist(key)

    def appendlist(self, key: str, value: Any) -> None:
        """Append ``value`` to the list for ``key``, creating the list if needed."""
        values = self.setlistdefault(key)
        values.append(value)

    def items(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(key, self[key])`` pairs, i.e. each key with its last value."""
        for name in self:
            last = self[name]
            yield name, last

    def lists(self) -> Iterator[tuple[str, list[Any]]]:
        """Return an iterator of ``(key, full list)`` pairs."""
        full_pairs = super().items()
        return iter(full_pairs)

    def values(self) -> Iterator[Any]:
        """Yield the last value of every key's list."""
        for name in self:
            yield self[name]

    def copy(self) -> MultiValueDict:
        """Return a shallow copy: lists are duplicated, their elements shared."""
        return copy.copy(self)

    def update(self, *args: Any, **kwargs: Any) -> None:
        """
        Merge new values in, extending existing lists instead of replacing them.

        Accepts at most one positional argument: another MultiValueDict
        (whose whole lists are appended), a Mapping, or an iterable of
        ``(key, value)`` pairs. Keyword arguments are appended afterwards.
        """
        if len(args) > 1:
            raise TypeError(f"update expected at most 1 argument, got {len(args)}")
        if args:
            source = args[0]
            if isinstance(source, MultiValueDict):
                for name, extra_values in source.lists():
                    self.setlistdefault(name).extend(extra_values)
            else:
                pairs = source.items() if isinstance(source, Mapping) else source
                for name, value in pairs:
                    self.setlistdefault(name).append(value)
        for name, value in kwargs.items():
            self.setlistdefault(name).append(value)

    def dict(self) -> dict[str, Any]:
        """Return a plain dict that maps each key to its last value."""
        return {name: self[name] for name in self}
232
233
class ImmutableList(tuple):
    """
    Tuple subclass whose list-style mutators raise AttributeError.

    The text of that AttributeError comes from the keyword-only ``warning``
    argument, which is stored on the instance.

    Example::

        >>> a = ImmutableList(range(5), warning="You cannot mutate this.")
        >>> a[3] = '4'
        Traceback (most recent call last):
            ...
        AttributeError: You cannot mutate this.
    """

    def __new__(
        cls,
        *args: Any,
        warning: str = "ImmutableList object is immutable.",
        **kwargs: Any,
    ) -> ImmutableList:
        instance = tuple.__new__(cls, *args, **kwargs)
        instance.warning = warning
        return instance

    def complain(self, *args: Any, **kwargs: Any) -> None:
        """Reject a mutation attempt by raising AttributeError(self.warning)."""
        message = self.warning
        raise AttributeError(message)

    # Route every list-mutating entry point to complain(). Python 3 never
    # calls __delslice__/__setslice__; they are kept as harmless aliases.
    __delitem__ = __delslice__ = __setitem__ = __setslice__ = complain
    __iadd__ = __imul__ = complain
    append = extend = insert = complain
    pop = remove = complain
    sort = reverse = complain
274
275
class DictWrapper(dict[str, Any]):
    """
    dict whose item lookups can pass values through a function.

    A key that begins with ``prefix`` is looked up with the prefix removed,
    and the stored value is passed through ``func`` before being returned.
    Any other key returns its stored value unchanged.

    Used by the SQL construction code to ensure that values are correctly
    quoted before being used.
    """

    def __init__(
        self, data: dict[str, Any], func: Callable[[Any], Any], prefix: str
    ) -> None:
        super().__init__(data)
        self.prefix = prefix
        self.func = func

    def __getitem__(self, key: str) -> Any:
        """
        Return the value for ``key`` after stripping ``self.prefix``. The
        value is transformed by ``self.func`` only if the prefix was present.
        """
        apply_func = key.startswith(self.prefix)
        raw = super().__getitem__(key.removeprefix(self.prefix))
        return self.func(raw) if apply_func else raw
305
306
class CaseInsensitiveMapping(Mapping[str, Any]):
    """
    Mapping allowing case-insensitive key lookups. Original case of keys is
    preserved for iteration and string representation.

    Keys are folded with ``str.lower()``. If two input keys fold to the same
    value, the later one wins, both for the stored value and for the
    preserved original spelling.

    Example::

        >>> ci_map = CaseInsensitiveMapping({'name': 'Jane'})
        >>> ci_map['Name']
        'Jane'
        >>> ci_map['NAME']
        'Jane'
        >>> ci_map['name']
        'Jane'
        >>> ci_map  # original case preserved
        {'name': 'Jane'}
    """

    def __init__(self, data: Mapping[str, Any] | Iterable[tuple[str, Any]]) -> None:
        """
        Build the mapping from a Mapping or an iterable of ``(key, value)``
        pairs.

        Raises ValueError if a pair element does not have exactly two items
        or its key is not a str.
        """
        # lowercased key -> (original key, value)
        self._store: dict[str, tuple[str, Any]] = {
            k.lower(): (k, v) for k, v in self._unpack_items(data)
        }

    def __getitem__(self, key: str) -> Any:
        return self._store[key.lower()][1]

    def __len__(self) -> int:
        return len(self._store)

    def __eq__(self, other: object) -> bool:
        """
        Compare case-insensitively against any Mapping.

        Returns False for non-Mappings, and for Mappings that contain any
        non-str key, because such a key can never match one of ours.
        """
        if not isinstance(other, Mapping):
            return False
        folded_other: dict[str, Any] = {}
        for k, v in other.items():
            if not isinstance(k, str):
                # Previously such keys were silently skipped, which made
                # {'a': 1} compare equal to {'a': 1, 2: 3}.
                return False
            folded_other[k.lower()] = v
        return {k.lower(): v for k, v in self.items()} == folded_other

    def __iter__(self) -> Iterator[str]:
        # Iterate original-case keys, in insertion order of _store.
        return (original_key for original_key, value in self._store.values())

    def __repr__(self) -> str:
        return repr(dict(self._store.values()))

    def copy(self) -> CaseInsensitiveMapping:
        # No public method mutates the instance, so returning it unchanged
        # stands in for a copy.
        return self

    @staticmethod
    def _unpack_items(
        data: Mapping[str, Any] | Iterable[tuple[str, Any]],
    ) -> Iterator[tuple[str, Any]]:
        """
        Yield ``(key, value)`` pairs from ``data``.

        Elements of a non-Mapping iterable are validated: each must have
        length 2 and a str key, otherwise ValueError is raised.
        """
        # Explicitly test for dict first as the common case for performance,
        # avoiding abc's __instancecheck__ and _abc_instancecheck for the
        # general Mapping case.
        if isinstance(data, dict):
            yield from data.items()
            return
        if isinstance(data, Mapping):
            yield from data.items()
            return
        for i, elem in enumerate(data):
            if len(elem) != 2:
                raise ValueError(
                    f"dictionary update sequence element #{i} has length {len(elem)}; "
                    "2 is required."
                )
            if not isinstance(elem[0], str):
                raise ValueError(
                    f"Element key {elem[0]!r} invalid, only strings are allowed"
                )
            yield elem