1from __future__ import annotations
2
3import builtins
4import copy
5from collections.abc import Callable, Iterable, Iterator, Mapping
6from typing import Any, TypeVar
7
# Generic key and value type variables. No definition visible in this section
# references them; presumably they are used elsewhere -- verify before removing.
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
10
11
class OrderedSet:
    """
    Set-like container that remembers the order in which items were added.

    Items are kept as the keys of a plain dict (every value is None), so both
    uniqueness and insertion order come from the dict itself.
    """

    def __init__(self, iterable: Iterable[Any] | None = None) -> None:
        # A missing or falsy iterable produces an empty set.
        initial = iterable if iterable else ()
        self.dict: dict[Any, None] = dict.fromkeys(initial)

    def add(self, item: Any) -> None:
        """Insert item; re-adding an existing item keeps its original position."""
        self.dict[item] = None

    def remove(self, item: Any) -> None:
        """Delete item, raising KeyError when it is not present."""
        self.dict.pop(item)

    def discard(self, item: Any) -> None:
        """Delete item if present; do nothing otherwise."""
        try:
            self.remove(item)
        except KeyError:
            return

    def __iter__(self) -> Iterator[Any]:
        return iter(self.dict.keys())

    def __reversed__(self) -> Iterator[Any]:
        return reversed(self.dict.keys())

    def __contains__(self, item: Any) -> bool:
        return item in self.dict

    def __bool__(self) -> bool:
        return True if self.dict else False

    def __len__(self) -> int:
        return len(self.dict)

    def __repr__(self) -> str:
        name = self.__class__.__qualname__
        # An empty set is rendered without an inner list.
        if not self.dict:
            return f"{name}()"
        return f"{name}({list(self.dict)!r})"
50
51
class MultiValueDictKeyError(KeyError):
    """KeyError subclass raised by MultiValueDict item access for a missing key."""
54
55
class MultiValueDict(dict[str, list[Any]]):
    """
    Dictionary subclass in which every key maps to a list of values.

    Plain item access (``d[key]``, ``get``, ``items``, ``values``) exposes only
    the last value of each list, while the ``*list`` methods work with the
    whole list.

    >>> d = MultiValueDict({'name': ['Adrian', 'Simon'], 'position': ['Developer']})
    >>> d['name']
    'Simon'
    >>> d.getlist('name')
    ['Adrian', 'Simon']
    >>> d.getlist('doesnotexist')
    []
    >>> d.getlist('doesnotexist', ['Adrian', 'Simon'])
    ['Adrian', 'Simon']
    >>> d.get('lastname', 'nonexistent')
    'nonexistent'
    >>> d.setlist('lastname', ['Holovaty', 'Willison'])

    This class exists to solve the irritating problem raised by query-string
    parsers (historically cgi.parse_qs), which return a list for every key
    even though most web forms submit single name-value pairs.
    """

    def __init__(
        self,
        key_to_list_mapping: Mapping[str, list[Any]]
        | Iterable[tuple[str, list[Any]]] = (),
    ) -> None:
        super().__init__(key_to_list_mapping)

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        contents = super().__repr__()
        return f"<{cls_name}: {contents}>"

    def __getitem__(self, key: str) -> Any:
        """
        Return the last value stored for key.

        An existing key whose list is empty yields []; a missing key raises
        MultiValueDictKeyError (a KeyError subclass).
        """
        try:
            values = super().__getitem__(key)
        except KeyError:
            raise MultiValueDictKeyError(key)
        try:
            last = values[-1]
        except IndexError:
            return []
        return last

    def __setitem__(self, key: str, value: Any) -> None:
        # Assigning a single value replaces the whole list for the key.
        super().__setitem__(key, [value])

    def __copy__(self) -> MultiValueDict:
        # Each value list is sliced so the copy owns independent lists.
        cls = self.__class__
        duplicated = [(key, values[:]) for key, values in self.lists()]
        return cls(duplicated)

    def __deepcopy__(self, memo: builtins.dict[int, Any]) -> MultiValueDict:
        clone = self.__class__()
        # Register the clone before descending so self-references resolve to it.
        memo[id(self)] = clone
        for key, values in dict.items(self):
            key_copy = copy.deepcopy(key, memo)
            values_copy = copy.deepcopy(values, memo)
            dict.__setitem__(clone, key_copy, values_copy)
        return clone

    def __getstate__(self) -> builtins.dict[str, Any]:
        # The full value lists travel under "_data" next to instance attributes.
        state = {**self.__dict__}
        state["_data"] = {key: self._getlist(key) for key in self}
        return state

    def __setstate__(self, obj_dict: builtins.dict[str, Any]) -> None:
        lists_by_key = obj_dict.pop("_data", {})
        for key, values in lists_by_key.items():
            self.setlist(key, values)
        self.__dict__.update(obj_dict)

    def get(self, key: str, default: Any = None) -> Any:
        """
        Return the last value for key, or `default` when the key is missing
        or its list of values is empty.
        """
        try:
            value = self[key]
        except KeyError:
            return default
        return default if value == [] else value

    def _getlist(
        self, key: str, default: list[Any] | None = None, force_list: bool = False
    ) -> list[Any] | None:
        """
        Return the list of values stored for key.

        Internal helper. For a missing key, return `default`, or a new empty
        list when `default` is None. With force_list=True the stored list is
        copied before being returned; otherwise the stored object itself is
        returned.
        """
        try:
            stored = super().__getitem__(key)
        except KeyError:
            return [] if default is None else default
        if not force_list:
            return stored
        return None if stored is None else list(stored)

    def getlist(self, key: str, default: list[Any] | None = None) -> list[Any]:
        """
        Return a copy of the list of values for key, or the default value
        when the key does not exist.
        """
        return self._getlist(key, default, force_list=True)  # type: ignore

    def setlist(self, key: str, list_: list[Any]) -> None:
        """Store list_ as the complete list of values for key."""
        super().__setitem__(key, list_)

    def setdefault(self, key: str, default: Any = None) -> Any:
        if key in self:
            return self[key]
        self[key] = default
        # Look the value up again instead of returning default:
        # __setitem__() may store another value -- QueryDict.__setitem__()
        # does.
        return self[key]

    def setlistdefault(
        self, key: str, default_list: list[Any] | None = None
    ) -> list[Any]:
        if key not in self:
            self.setlist(key, [] if default_list is None else default_list)
        # Look the list up again instead of returning default_list:
        # setlist() may store another value -- QueryDict.setlist() does.
        current = self._getlist(key)
        assert current is not None
        return current

    def appendlist(self, key: str, value: Any) -> None:
        """Append value to the list stored for key, creating the list if needed."""
        target = self.setlistdefault(key)
        target.append(value)

    def items(self) -> Iterator[tuple[str, Any]]:  # ty: ignore[invalid-method-override]
        """
        Yield (key, value) pairs in which value is the last item of the
        list stored for the key.
        """
        for name in self:
            yield name, self[name]

    def lists(self) -> Iterator[tuple[str, list[Any]]]:
        """Yield (key, list) pairs."""
        pairs = super().items()
        return iter(pairs)

    def values(self) -> Iterator[Any]:  # ty: ignore[invalid-method-override]
        """Yield the last value of every key's list."""
        for name in self:
            yield self[name]

    def copy(self) -> MultiValueDict:
        """Return a shallow copy of this object."""
        return copy.copy(self)

    def update(self, *args: Any, **kwargs: Any) -> None:
        """Extend rather than replace existing key lists."""
        if len(args) > 1:
            raise TypeError(f"update expected at most 1 argument, got {len(args)}")
        if args:
            (source,) = args
            if isinstance(source, MultiValueDict):
                # Merge whole lists so no value of the other dict is lost.
                for key, value_list in source.lists():
                    self.setlistdefault(key).extend(value_list)
            else:
                pairs = source.items() if isinstance(source, Mapping) else source
                for key, value in pairs:
                    self.setlistdefault(key).append(value)
        for key, value in kwargs.items():
            self.setlistdefault(key).append(value)

    def dict(self) -> builtins.dict[str, Any]:
        """Return current object as a dict with singular values."""
        flattened = {}
        for key in self:
            flattened[key] = self[key]
        return flattened
235
236
class ImmutableList(tuple):
    """
    A tuple-like object that raises useful errors when it is asked to mutate.

    Every list-style mutation method is mapped to ``complain``, which raises
    AttributeError carrying the instance's ``warning`` message.

    Example::

        >>> a = ImmutableList(range(5), warning="You cannot mutate this.")
        >>> a[3] = '4'
        Traceback (most recent call last):
            ...
        AttributeError: You cannot mutate this.
    """

    warning: str  # Set in __new__

    def __new__(
        cls,
        *args: Any,
        warning: str = "ImmutableList object is immutable.",
        **kwargs: Any,
    ) -> ImmutableList:
        """
        Build the tuple from the positional/keyword arguments (forwarded to
        ``tuple.__new__``) and attach `warning` as the message used by
        ``complain``.
        """
        self = tuple.__new__(cls, *args, **kwargs)
        self.warning = warning
        return self

    def complain(self, *args: Any, **kwargs: Any) -> None:
        """Reject a mutation attempt: always raise AttributeError(self.warning)."""
        raise AttributeError(self.warning)

    # All list mutation functions complain.
    __delitem__ = complain
    __delslice__ = complain
    __iadd__ = complain
    __imul__ = complain
    __setitem__ = complain
    __setslice__ = complain
    append = complain
    # list.clear() was previously missing from this table, so calling it
    # produced a generic "no attribute" error instead of the warning.
    clear = complain
    extend = complain
    insert = complain
    pop = complain
    remove = complain
    sort = complain
    reverse = complain
279
280
class DictWrapper(dict[str, Any]):
    """
    Dictionary whose item access treats keys starting with a given prefix
    specially: the prefix is stripped, the remaining key is looked up, and the
    stored value is passed through a function before being returned. Keys
    without the prefix are looked up and returned unchanged.

    Used by the SQL construction code to ensure that values are correctly
    quoted before being used.
    """

    def __init__(
        self, data: dict[str, Any], func: Callable[[Any], Any], prefix: str
    ) -> None:
        super().__init__(data)
        self.prefix = prefix
        self.func = func

    def __getitem__(self, key: str) -> Any:
        """
        Look up key, honouring the prefix convention.

        A key that begins with self.prefix is looked up without the prefix and
        its value is returned as self.func(value); any other key returns the
        raw stored value.
        """
        prefix = self.prefix
        if key.startswith(prefix):
            # Prefixed access: strip the prefix, then transform the value.
            raw = super().__getitem__(key.removeprefix(prefix))
            return self.func(raw)
        return super().__getitem__(key)
310
311
class CaseInsensitiveMapping(Mapping[str, Any]):
    """
    Mapping allowing case-insensitive key lookups. Original case of keys is
    preserved for iteration and string representation.

    Example::

        >>> ci_map = CaseInsensitiveMapping({'name': 'Jane'})
        >>> ci_map['Name']
        'Jane'
        >>> ci_map['NAME']
        'Jane'
        >>> ci_map['name']
        'Jane'
        >>> ci_map  # original case preserved
        {'name': 'Jane'}
    """

    def __init__(self, data: Mapping[str, Any] | Iterable[tuple[str, Any]]) -> None:
        # Entries are keyed by the lower-cased key; each one keeps the original
        # key next to the value so the original case survives iteration/repr.
        self._store: dict[str, tuple[str, Any]] = {
            k.lower(): (k, v) for k, v in self._unpack_items(data)
        }

    def __getitem__(self, key: str) -> Any:
        return self._store[key.lower()][1]

    def __len__(self) -> int:
        return len(self._store)

    def __eq__(self, other: object) -> bool:
        """
        Compare against another mapping, ignoring the case of keys.

        Return False for non-mappings and for mappings that contain any
        non-string key.
        """
        if not isinstance(other, Mapping):
            return False
        other_lowered = {}
        for key, value in other.items():
            if not isinstance(key, str):
                # A non-string key cannot be matched case-insensitively;
                # skipping it would let unequal mappings compare equal.
                return False
            other_lowered[key.lower()] = value
        return {k.lower(): v for k, v in self.items()} == other_lowered

    def __iter__(self) -> Iterator[str]:
        # Iterate over keys in their original case.
        return (original_key for original_key, value in self._store.values())

    def __repr__(self) -> str:
        return repr(dict(self._store.values()))

    def copy(self) -> CaseInsensitiveMapping:
        """Return this same instance (no new object is created)."""
        return self

    @staticmethod
    def _unpack_items(
        data: Mapping[str, Any] | Iterable[tuple[str, Any]],
    ) -> Iterator[tuple[str, Any]]:
        """
        Yield (key, value) pairs from a mapping or an iterable of pairs.

        For a plain iterable, raise ValueError when an element does not have
        exactly two items or when its key is not a string.
        """
        # Explicitly test for dict first as the common case for performance,
        # avoiding abc's __instancecheck__ and _abc_instancecheck for the
        # general Mapping case.
        if isinstance(data, dict):
            yield from data.items()  # ty: ignore[invalid-yield]
            return
        if isinstance(data, Mapping):
            yield from data.items()  # ty: ignore[invalid-yield]
            return
        for i, elem in enumerate(data):
            if len(elem) != 2:
                raise ValueError(
                    f"dictionary update sequence element #{i} has length {len(elem)}; "
                    "2 is required."
                )
            if not isinstance(elem[0], str):
                raise ValueError(
                    f"Element key {elem[0]!r} invalid, only strings are allowed"
                )
            yield elem