1from __future__ import annotations
2
3import builtins
4import copy
5from collections.abc import Callable, Iterable, Iterator, Mapping
6from typing import Any, TypeVar
7
# Generic type variables (by their names, intended for key and value types).
# NOTE(review): no definition in the visible part of this module references
# them -- confirm usage elsewhere before removing.
_KT = TypeVar("_KT")
_VT = TypeVar("_VT")
10
11
class OrderedSet:
    """
    Set-like container that remembers the order in which items were added.

    Items are stored as the keys of a plain dict (``self.dict``), so iteration
    follows insertion order.
    """

    def __init__(self, iterable: Iterable[Any] | None = None) -> None:
        # A falsy argument (None or an empty container) gives an empty set.
        initial_items = iterable or ()
        self.dict: dict[Any, None] = dict.fromkeys(initial_items)

    def add(self, item: Any) -> None:
        """Insert ``item``; the stored value is always None."""
        self.dict[item] = None

    def remove(self, item: Any) -> None:
        """Delete ``item``; raise KeyError if it is not present."""
        self.dict.pop(item)

    def discard(self, item: Any) -> None:
        """Delete ``item`` if present; do nothing otherwise."""
        try:
            self.remove(item)
        except KeyError:
            return

    def __iter__(self) -> Iterator[Any]:
        return iter(self.dict.keys())

    def __reversed__(self) -> Iterator[Any]:
        return reversed(self.dict.keys())

    def __contains__(self, item: Any) -> bool:
        return item in self.dict

    def __bool__(self) -> bool:
        return len(self.dict) > 0

    def __len__(self) -> int:
        return len(self.dict)

    def __repr__(self) -> str:
        class_name = self.__class__.__qualname__
        if not self.dict:
            return f"{class_name}()"
        return f"{class_name}({list(self.dict)!r})"
50
51
class MultiValueDictKeyError(KeyError):
    """KeyError subclass raised by MultiValueDict when a key is missing."""
54
55
class MultiValueDict(dict[str, list[Any]]):
    """
    A subclass of dictionary customized to handle multiple values for the
    same key.

    >>> d = MultiValueDict({'name': ['Adrian', 'Simon'], 'position': ['Developer']})
    >>> d['name']
    'Simon'
    >>> d.getlist('name')
    ['Adrian', 'Simon']
    >>> d.getlist('doesnotexist')
    []
    >>> d.getlist('doesnotexist', ['Adrian', 'Simon'])
    ['Adrian', 'Simon']
    >>> d.get('lastname', 'nonexistent')
    'nonexistent'
    >>> d.setlist('lastname', ['Holovaty', 'Willison'])

    This class exists to solve the irritating problem raised by query-string
    parsers such as urllib.parse.parse_qs (formerly cgi.parse_qs), which
    return a list for every key, even though most web forms submit single
    name-value pairs.
    """

    def __init__(
        self,
        key_to_list_mapping: Mapping[str, list[Any]]
        | Iterable[tuple[str, list[Any]]] = (),
    ) -> None:
        """Store the given key -> list-of-values pairs as-is."""
        super().__init__(key_to_list_mapping)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {super().__repr__()}>"

    def __getitem__(self, key: str) -> Any:
        """
        Return the last data value for this key, or [] if it's an empty list;
        raise MultiValueDictKeyError (a KeyError) if not found.
        """
        try:
            values = super().__getitem__(key)
        except KeyError:
            # The inner KeyError carries the same key, so suppress it to keep
            # tracebacks free of a redundant chained exception.
            raise MultiValueDictKeyError(key) from None
        try:
            return values[-1]
        except IndexError:
            return []

    def __setitem__(self, key: str, value: Any) -> None:
        """Replace the values for ``key`` with the single-item list [value]."""
        super().__setitem__(key, [value])

    def __copy__(self) -> MultiValueDict:
        """Return a new instance whose value lists are shallow copies."""
        return self.__class__([(k, v[:]) for k, v in self.lists()])

    def __deepcopy__(self, memo: builtins.dict[int, Any]) -> MultiValueDict:
        """Return a new instance with deep-copied keys and value lists."""
        result = self.__class__()
        # Register the copy before recursing so self-references resolve to it.
        memo[id(self)] = result
        for key, value in dict.items(self):
            # Bypass the overridden __setitem__, which would wrap the list.
            dict.__setitem__(
                result, copy.deepcopy(key, memo), copy.deepcopy(value, memo)
            )
        return result

    def __getstate__(self) -> builtins.dict[str, Any]:
        """Return instance attributes plus the full value lists under "_data"."""
        return {**self.__dict__, "_data": {k: self._getlist(k) for k in self}}

    def __setstate__(self, obj_dict: builtins.dict[str, Any]) -> None:
        """
        Restore value lists from ``obj_dict["_data"]`` and the remaining
        entries as instance attributes. ``obj_dict`` itself is left unmodified
        so the same state can be applied more than once.
        """
        for key, values in obj_dict.get("_data", {}).items():
            self.setlist(key, values)
        self.__dict__.update(
            {name: value for name, value in obj_dict.items() if name != "_data"}
        )

    def get(self, key: str, default: Any = None) -> Any:
        """
        Return the last data value for the passed key. If key doesn't exist
        or value is an empty list, return `default`.
        """
        try:
            val = self[key]
        except KeyError:
            return default
        if val == []:
            return default
        return val

    def _getlist(
        self, key: str, default: list[Any] | None = None, force_list: bool = False
    ) -> list[Any] | None:
        """
        Return a list of values for the key.

        Used internally to manipulate values list. If the key is missing,
        return `default`, or a new empty list when `default` is None. If
        force_list is True, return a new copy of the stored values.
        """
        try:
            values = super().__getitem__(key)
        except KeyError:
            if default is None:
                return []
            return default
        else:
            if force_list:
                values = list(values) if values is not None else None
        return values

    def getlist(self, key: str, default: list[Any] | None = None) -> list[Any]:
        """
        Return a copy of the list of values for the key. If key doesn't exist,
        return the default value (an empty list when `default` is None).
        """
        return self._getlist(key, default, force_list=True)  # type: ignore

    def setlist(self, key: str, list_: list[Any]) -> None:
        """Store ``list_`` itself (not a copy) as the values for ``key``."""
        super().__setitem__(key, list_)

    def setdefault(self, key: str, default: Any = None) -> Any:
        """Set ``key`` to ``default`` if absent; return the value for ``key``."""
        if key not in self:
            self[key] = default
            # Do not return default here because __setitem__() may store
            # another value -- QueryDict.__setitem__() does. Look it up.
        return self[key]

    def setlistdefault(
        self, key: str, default_list: list[Any] | None = None
    ) -> list[Any]:
        """
        Set the list for ``key`` to ``default_list`` (or a new empty list) if
        absent; return the stored list for ``key``.
        """
        if key not in self:
            if default_list is None:
                default_list = []
            self.setlist(key, default_list)
            # Do not return default_list here because setlist() may store
            # another value -- QueryDict.setlist() does. Look it up.
        return self._getlist(key)  # type: ignore[return-value]

    def appendlist(self, key: str, value: Any) -> None:
        """Append an item to the internal list associated with key."""
        self.setlistdefault(key).append(value)

    def items(self) -> Iterator[tuple[str, Any]]:  # type: ignore[override]
        """
        Yield (key, value) pairs, where value is the last item in the list
        associated with the key.
        """
        for key in self:
            yield key, self[key]

    def lists(self) -> Iterator[tuple[str, list[Any]]]:
        """Yield (key, list) pairs."""
        return iter(super().items())

    def values(self) -> Iterator[Any]:  # type: ignore[override]
        """Yield the last value on every key list."""
        for key in self:
            yield self[key]

    def copy(self) -> MultiValueDict:
        """Return a shallow copy of this object."""
        return copy.copy(self)

    def update(self, *args: Any, **kwargs: Any) -> None:
        """
        Extend rather than replace existing key lists.

        Accepts at most one positional argument (a MultiValueDict, a Mapping,
        or an iterable of (key, value) pairs) plus keyword arguments; raises
        TypeError when more than one positional argument is given.
        """
        if len(args) > 1:
            raise TypeError(f"update expected at most 1 argument, got {len(args)}")
        if args:
            arg = args[0]
            if isinstance(arg, MultiValueDict):
                # Merge every value of each list, not just the last one.
                for key, value_list in arg.lists():
                    self.setlistdefault(key).extend(value_list)
            else:
                if isinstance(arg, Mapping):
                    arg = arg.items()
                for key, value in arg:
                    self.setlistdefault(key).append(value)
        for key, value in kwargs.items():
            self.setlistdefault(key).append(value)

    def dict(self) -> builtins.dict[str, Any]:
        """Return current object as a dict with singular values."""
        return {key: self[key] for key in self}
233
234
class ImmutableList(tuple):
    """
    A tuple-like object that raises useful errors when it is asked to mutate.

    Every list-style mutation method raises AttributeError carrying the
    ``warning`` message supplied at construction time.

    Example::

        >>> a = ImmutableList(range(5), warning="You cannot mutate this.")
        >>> a[3] = '4'
        Traceback (most recent call last):
            ...
        AttributeError: You cannot mutate this.
    """

    warning: str  # Set in __new__

    def __new__(
        cls,
        *args: Any,
        warning: str = "ImmutableList object is immutable.",
        **kwargs: Any,
    ) -> ImmutableList:
        """
        Build the tuple from ``*args``/``**kwargs`` and attach ``warning``,
        the message raised on any mutation attempt.
        """
        self = tuple.__new__(cls, *args, **kwargs)
        self.warning = warning
        return self

    def complain(self, *args: Any, **kwargs: Any) -> None:
        """Raise AttributeError with this instance's warning message."""
        raise AttributeError(self.warning)

    # All list mutation functions complain.
    __delitem__ = complain
    __delslice__ = complain  # Python 2 slice hook, kept for compatibility.
    __iadd__ = complain
    __imul__ = complain
    __setitem__ = complain
    __setslice__ = complain  # Python 2 slice hook, kept for compatibility.
    append = complain
    # list.clear() is a mutator too; without this entry the call would fail
    # with a generic "no attribute" error instead of the configured warning.
    clear = complain
    extend = complain
    insert = complain
    pop = complain
    remove = complain
    sort = complain
    reverse = complain
277
278
class DictWrapper(dict[str, Any]):
    """
    Dictionary whose lookups treat keys starting with a configured prefix
    specially: the prefix is stripped, the remaining key is looked up, and the
    stored value is passed through a function before being returned. Keys
    without the prefix return the stored value untouched.

    Used by the SQL construction code to ensure that values are correctly
    quoted before being used.
    """

    def __init__(
        self, data: dict[str, Any], func: Callable[[Any], Any], prefix: str
    ) -> None:
        self.prefix = prefix
        self.func = func
        super().__init__(data)

    def __getitem__(self, key: str) -> Any:
        """
        Look up ``key``. When it begins with the prefix, strip the prefix,
        fetch the value stored under the remainder and return it transformed
        by self.func; otherwise return the stored value as-is.
        """
        marker = self.prefix
        if not key.startswith(marker):
            return super().__getitem__(key)
        stored = super().__getitem__(key.removeprefix(marker))
        return self.func(stored)
308
309
class CaseInsensitiveMapping(Mapping[str, Any]):
    """
    Mapping allowing case-insensitive key lookups. Original case of keys is
    preserved for iteration and string representation.

    Example::

        >>> ci_map = CaseInsensitiveMapping({'name': 'Jane'})
        >>> ci_map['Name']
        'Jane'
        >>> ci_map['NAME']
        'Jane'
        >>> ci_map['name']
        'Jane'
        >>> ci_map  # original case preserved
        {'name': 'Jane'}
    """

    def __init__(self, data: Mapping[str, Any] | Iterable[tuple[str, Any]]) -> None:
        """
        Build the mapping from a Mapping or an iterable of (key, value) pairs.

        Raises ValueError when an iterable element is not a 2-item pair or
        has a non-string key.
        """
        # Lower-cased key -> (original key, value).
        self._store: dict[str, tuple[str, Any]] = {
            k.lower(): (k, v) for k, v in self._unpack_items(data)
        }

    def __getitem__(self, key: str) -> Any:
        return self._store[key.lower()][1]

    def __len__(self) -> int:
        return len(self._store)

    def __eq__(self, other: object) -> bool:
        """
        Compare against another Mapping, ignoring the case of keys.

        Return False for non-Mapping objects and for mappings that contain a
        non-string key, since such a key can never be present in this mapping.
        """
        if not isinstance(other, Mapping):
            return False
        other_lowered: dict[str, Any] = {}
        for key, value in other.items():
            if not isinstance(key, str):
                # Dropping the key instead would report a false match.
                return False
            other_lowered[key.lower()] = value
        return {k.lower(): v for k, v in self.items()} == other_lowered

    def __iter__(self) -> Iterator[str]:
        return (original_key for original_key, value in self._store.values())

    def __repr__(self) -> str:
        return repr(dict(self._store.values()))

    def copy(self) -> CaseInsensitiveMapping:
        """Return this same instance; no new object is created."""
        return self

    @staticmethod
    def _unpack_items(
        data: Mapping[str, Any] | Iterable[tuple[str, Any]],
    ) -> Iterator[tuple[str, Any]]:
        """
        Yield (key, value) pairs from ``data``.

        Mappings are yielded via ``.items()`` without validation; any other
        iterable is validated element by element, raising ValueError for
        elements that are not of length 2 or whose key is not a string.
        """
        # Explicitly test for dict first as the common case for performance,
        # avoiding abc's __instancecheck__ and _abc_instancecheck for the
        # general Mapping case.
        if isinstance(data, dict):
            yield from data.items()
            return
        if isinstance(data, Mapping):
            yield from data.items()
            return
        for i, elem in enumerate(data):
            if len(elem) != 2:
                raise ValueError(
                    f"dictionary update sequence element #{i} has length {len(elem)}; "
                    "2 is required."
                )
            if not isinstance(elem[0], str):
                raise ValueError(
                    f"Element key {elem[0]!r} invalid, only strings are allowed"
                )
            yield elem