1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain import postgres
  6from plain.postgres.db import DatabaseError
  7from plain.postgres.meta import Meta
  8from plain.postgres.registry import ModelsRegistry
  9from plain.utils.functional import classproperty
 10from plain.utils.timezone import now
 11
 12from .exceptions import MigrationSchemaMissing
 13
# Name of the database table holding applied-migration records; used as the
# ``db_table`` of MigrationRecorder's Migration model.
MIGRATION_TABLE_NAME = "plainmigrations"
 15
 16if TYPE_CHECKING:
 17    from plain.postgres.connection import DatabaseConnection
 18
 19
 20class MigrationRecorder:
 21    """
 22    Deal with storing migration records in the database.
 23
 24    Because this table is actually itself used for dealing with model
 25    creation, it's the one thing we can't do normally via migrations.
 26    We manually handle table creation/schema updating (using schema backend)
 27    and then have a floating model to do queries with.
 28
 29    If a migration is unapplied its row is removed from the table. Having
 30    a row in the table always means a migration is applied.
 31    """
 32
 33    _migration_class: type[postgres.Model] | None = None
 34
 35    @classproperty  # type: ignore[invalid-argument-type]
 36    def Migration(cls) -> type[postgres.Model]:
 37        """
 38        Lazy load to avoid PackageRegistryNotReady if installed packages import
 39        MigrationRecorder.
 40        """
 41        if cls._migration_class is None:
 42            _models_registry = ModelsRegistry()
 43            _models_registry.ready = True
 44
 45            class Migration(postgres.Model):
 46                app = postgres.CharField(max_length=255)
 47                name = postgres.CharField(max_length=255)
 48                applied = postgres.DateTimeField(default=now)
 49
 50                # Use isolated models registry for migrations
 51                _model_meta = Meta(models_registry=_models_registry)
 52
 53                model_options = postgres.Options(
 54                    package_label="migrations",
 55                    db_table=MIGRATION_TABLE_NAME,
 56                )
 57
 58                def __str__(self) -> str:
 59                    return f"Migration {self.name} for {self.app}"
 60
 61            cls._migration_class = Migration
 62        return cls._migration_class
 63
 64    def __init__(self, connection: DatabaseConnection) -> None:
 65        self.connection = connection
 66
 67    @property
 68    def migration_qs(self) -> Any:
 69        return self.Migration.query.all()
 70
 71    def has_table(self) -> bool:
 72        """Return True if the plainmigrations table exists."""
 73        with self.connection.cursor() as cursor:
 74            tables = self.connection.table_names(cursor)
 75        return self.Migration.model_options.db_table in tables
 76
 77    def ensure_schema(self) -> None:
 78        """Ensure the table exists and has the correct schema."""
 79        # If the table's there, that's fine - we've never changed its schema
 80        # in the codebase.
 81        if self.has_table():
 82            return
 83        # Make the table
 84        try:
 85            with self.connection.schema_editor() as editor:
 86                editor.create_model(self.Migration)
 87        except DatabaseError as exc:
 88            raise MigrationSchemaMissing(
 89                f"Unable to create the plainmigrations table ({exc})"
 90            )
 91
 92    def applied_migrations(self) -> dict[tuple[str, str], Any]:
 93        """
 94        Return a dict mapping (package_name, migration_name) to Migration instances
 95        for all applied migrations.
 96        """
 97        if self.has_table():
 98            return {
 99                (migration.app, migration.name): migration
100                for migration in self.migration_qs
101            }
102        else:
103            # If the plainmigrations table doesn't exist, then no migrations
104            # are applied.
105            return {}
106
107    def record_applied(self, app: str, name: str) -> None:
108        """Record that a migration was applied."""
109        self.ensure_schema()
110        self.migration_qs.create(app=app, name=name)
111
112    def record_unapplied(self, app: str, name: str) -> None:
113        """Record that a migration was unapplied."""
114        self.ensure_schema()
115        self.migration_qs.filter(app=app, name=name).delete()
116
117    def flush(self) -> None:
118        """Delete all migration records. Useful for testing migrations."""
119        self.migration_qs.all().delete()