1from __future__ import annotations
2
3from base64 import b64decode
4from collections.abc import Callable, Sequence
5from typing import TYPE_CHECKING, Any
6
7import psycopg
8
9from plain.validators import MaxLengthValidator
10
11from .base import ColumnField
12
13if TYPE_CHECKING:
14 from plain.postgres.connection import DatabaseConnection
15 from plain.postgres.sql.compiler import SQLCompiler
16
17
class BinaryField(ColumnField[bytes | memoryview]):
    """Field for raw binary payloads, persisted in a PostgreSQL ``bytea`` column.

    Passing ``max_length`` adds a ``MaxLengthValidator`` for that limit to the
    field's validators. Both ``None`` and ``b""`` count as empty values.
    """

    db_type_sql = "bytea"
    empty_values = [None, b""]
    _default_empty_value = b""

    def __init__(
        self,
        *,
        max_length: int | None = None,
        required: bool = True,
        allow_null: bool = False,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        # No `default` keyword on purpose: a str default would not match the
        # bytes value type this field stores.

        # Set ahead of super().__init__() so the attribute already exists
        # while the base initializer runs.
        self.max_length = max_length
        super().__init__(
            validators=validators,
            allow_null=allow_null,
            required=required,
        )
        # Read back from the instance (not the argument) after base init.
        limit = self.max_length
        if limit is not None:
            self.validators.append(MaxLengthValidator(limit))

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Return the base class deconstruction, adding ``max_length`` when set."""
        name, path, args, kwargs = super().deconstruct()
        limit = self.max_length
        if limit is not None:
            kwargs["max_length"] = limit
        return name, path, args, kwargs

    def get_placeholder(
        self, value: Any, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> Any:
        """Return a bare ``%s`` placeholder, whatever the value.

        The placeholder carries no cast; the parameter itself is wrapped
        separately by ``get_db_prep_value``.
        """
        placeholder = "%s"
        return placeholder

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Prepare ``value`` via the base class, then mark it as binary.

        A non-``None`` result of the base preparation is wrapped in
        ``psycopg.Binary``; ``None`` passes through unchanged.
        """
        prepped = super().get_db_prep_value(value, connection, prepared)
        if prepped is None:
            return None
        return psycopg.Binary(prepped)

    def to_python(self, value: Any) -> bytes | memoryview | None:
        """Convert ``value`` into the field's Python representation.

        A ``str`` is treated as base64 text (it must be ASCII) and decoded into
        a ``memoryview`` over the raw bytes; anything else is returned as-is.
        """
        if not isinstance(value, str):
            return value
        decoded = b64decode(value.encode("ascii"))
        return memoryview(decoded)