Plain is headed towards 1.0! Subscribe for development updates →

  1import ipaddress
  2from functools import lru_cache
  3
  4try:
  5    from psycopg import ClientCursor, IsolationLevel, adapt, adapters, errors, sql
  6    from psycopg.postgres import types
  7    from psycopg.types.datetime import TimestamptzLoader
  8    from psycopg.types.json import Jsonb
  9    from psycopg.types.range import Range, RangeDumper
 10    from psycopg.types.string import TextLoader
 11
 12    Inet = ipaddress.ip_address
 13
 14    DateRange = DateTimeRange = DateTimeTZRange = NumericRange = Range
 15    RANGE_TYPES = (Range,)
 16
 17    TSRANGE_OID = types["tsrange"].oid
 18    TSTZRANGE_OID = types["tstzrange"].oid
 19
 20    def mogrify(sql, params, connection):
 21        return ClientCursor(connection.connection).mogrify(sql, params)
 22
 23    # Adapters.
 24    class BaseTzLoader(TimestamptzLoader):
 25        """
 26        Load a PostgreSQL timestamptz using the a specific timezone.
 27        The timezone can be None too, in which case it will be chopped.
 28        """
 29
 30        timezone = None
 31
 32        def load(self, data):
 33            res = super().load(data)
 34            return res.replace(tzinfo=self.timezone)
 35
 36    def register_tzloader(tz, context):
 37        class SpecificTzLoader(BaseTzLoader):
 38            timezone = tz
 39
 40        context.adapters.register_loader("timestamptz", SpecificTzLoader)
 41
 42    class PlainRangeDumper(RangeDumper):
 43        """A Range dumper customized for Plain."""
 44
 45        def upgrade(self, obj, format):
 46            # Dump ranges containing naive datetimes as tstzrange, because
 47            # Plain doesn't use tz-aware ones.
 48            dumper = super().upgrade(obj, format)
 49            if dumper is not self and dumper.oid == TSRANGE_OID:
 50                dumper.oid = TSTZRANGE_OID
 51            return dumper
 52
 53    @lru_cache
 54    def get_adapters_template(use_tz, timezone):
 55        # Create at adapters map extending the base one.
 56        ctx = adapt.AdaptersMap(adapters)
 57        # Register a no-op dumper to avoid a round trip from psycopg version 3
 58        # decode to json.dumps() to json.loads(), when using a custom decoder
 59        # in JSONField.
 60        ctx.register_loader("jsonb", TextLoader)
 61        # Don't convert automatically from PostgreSQL network types to Python
 62        # ipaddress.
 63        ctx.register_loader("inet", TextLoader)
 64        ctx.register_loader("cidr", TextLoader)
 65        ctx.register_dumper(Range, PlainRangeDumper)
 66        # Register a timestamptz loader configured on self.timezone.
 67        # This, however, can be overridden by create_cursor.
 68        register_tzloader(timezone, ctx)
 69        return ctx
 70
    # Reached only if every psycopg 3 import above succeeded.
    is_psycopg3 = True
 72
 73except ImportError:
 74    from enum import IntEnum
 75
 76    from psycopg2 import errors, extensions, sql  # NOQA
 77    from psycopg2.extras import (  # NOQA  # NOQA
 78        DateRange,
 79        DateTimeRange,
 80        DateTimeTZRange,
 81        Inet,
 82        NumericRange,
 83        Range,
 84    )
 85    from psycopg2.extras import Json as Jsonb  # NOQA
 86
    # psycopg2 has a distinct class per range flavour (unlike psycopg 3's single
    # Range), so all of them are listed; presumably used for isinstance() checks.
    RANGE_TYPES = (DateRange, DateTimeRange, DateTimeTZRange, NumericRange)
 88
 89    class IsolationLevel(IntEnum):
 90        READ_UNCOMMITTED = extensions.ISOLATION_LEVEL_READ_UNCOMMITTED
 91        READ_COMMITTED = extensions.ISOLATION_LEVEL_READ_COMMITTED
 92        REPEATABLE_READ = extensions.ISOLATION_LEVEL_REPEATABLE_READ
 93        SERIALIZABLE = extensions.ISOLATION_LEVEL_SERIALIZABLE
 94
 95    def _quote(value, connection=None):
 96        adapted = extensions.adapt(value)
 97        if hasattr(adapted, "encoding"):
 98            adapted.encoding = "utf8"
 99        # getquoted() returns a quoted bytestring of the adapted value.
100        return adapted.getquoted().decode()
101
    # Attach a ``quote`` function to psycopg2's ``sql`` module, presumably so
    # code written against psycopg 3's ``sql.quote`` works with either driver.
    # NOTE: this mutates the third-party module globally at import time.
    sql.quote = _quote
103
104    def mogrify(sql, params, connection):
105        with connection.cursor() as cursor:
106            return cursor.mogrify(sql, params).decode()
107
    # Reached only when psycopg 3 failed to import and the psycopg2 fallback is used.
    is_psycopg3 = False