1from __future__ import annotations
 2
 3from base64 import b64decode
 4from collections.abc import Callable, Sequence
 5from typing import TYPE_CHECKING, Any
 6
 7import psycopg
 8
 9from plain.validators import MaxLengthValidator
10
11from .base import ColumnField
12
13if TYPE_CHECKING:
14    from plain.postgres.connection import DatabaseConnection
15    from plain.postgres.sql.compiler import SQLCompiler
16
17
class BinaryField(ColumnField[bytes | memoryview]):
    """Column field holding raw binary data, declared with SQL type ``bytea``.

    Values are handed to the database wrapped in ``psycopg.Binary``. String
    input given to ``to_python`` is treated as base64 text and decoded.
    """

    db_type_sql = "bytea"
    empty_values = [None, b""]
    _default_empty_value = b""

    def __init__(
        self,
        *,
        max_length: int | None = None,
        required: bool = True,
        allow_null: bool = False,
        validators: Sequence[Callable[..., Any]] = (),
    ):
        """Configure the field.

        Args:
            max_length: Optional upper bound; when given, a
                ``MaxLengthValidator`` built from it is appended to the
                field's validators.
            required: Forwarded unchanged to the parent initializer.
            allow_null: Forwarded unchanged to the parent initializer.
            validators: Forwarded unchanged to the parent initializer.
        """
        # There is deliberately no `default` parameter: a str default on a
        # bytes field would be a type mismatch.
        self.max_length = max_length
        super().__init__(
            required=required,
            allow_null=allow_null,
            validators=validators,
        )
        limit = self.max_length
        if limit is None:
            return
        self.validators.append(MaxLengthValidator(limit))

    def deconstruct(self) -> tuple[str | None, str, list[Any], dict[str, Any]]:
        """Return the parent's deconstruction, adding ``max_length`` if set."""
        field_name, import_path, positional, keywords = super().deconstruct()
        limit = self.max_length
        if limit is not None:
            keywords["max_length"] = limit
        return field_name, import_path, positional, keywords

    def get_placeholder(
        self, value: Any, compiler: SQLCompiler, connection: DatabaseConnection
    ) -> Any:
        """Return the plain ``%s`` placeholder; the arguments are not consulted."""
        return "%s"

    def get_db_prep_value(
        self, value: Any, connection: DatabaseConnection, prepared: bool = False
    ) -> Any:
        """Prepare ``value`` via the parent, then wrap it for the driver.

        ``None`` coming back from the parent is passed through untouched;
        anything else is wrapped in ``psycopg.Binary``.
        """
        prepped = super().get_db_prep_value(value, connection, prepared)
        if prepped is None:
            return prepped
        return psycopg.Binary(prepped)

    def to_python(self, value: Any) -> bytes | memoryview | None:
        """Convert ``value`` to its Python form.

        A ``str`` is expected to be base64-encoded data: it is ASCII-encoded,
        base64-decoded, and returned as a ``memoryview``. Any other value is
        returned as-is.
        """
        if not isinstance(value, str):
            return value
        decoded = b64decode(value.encode("ascii"))
        return memoryview(decoded)