Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5from plain import models
  6from plain.models.connections import DatabaseConnection
  7from plain.models.db import DatabaseError
  8from plain.models.meta import Meta
  9from plain.models.registry import ModelsRegistry
 10from plain.utils.functional import classproperty
 11from plain.utils.timezone import now
 12
 13from .exceptions import MigrationSchemaMissing
 14
 15if TYPE_CHECKING:
 16    from plain.models.backends.base.base import BaseDatabaseWrapper
 17
 18
 19class MigrationRecorder:
 20    """
 21    Deal with storing migration records in the database.
 22
 23    Because this table is actually itself used for dealing with model
 24    creation, it's the one thing we can't do normally via migrations.
 25    We manually handle table creation/schema updating (using schema backend)
 26    and then have a floating model to do queries with.
 27
 28    If a migration is unapplied its row is removed from the table. Having
 29    a row in the table always means a migration is applied.
 30    """
 31
 32    _migration_class: type[models.Model] | None = None
 33
 34    @classproperty
 35    def Migration(cls) -> type[models.Model]:  # type: ignore[misc]
 36        """
 37        Lazy load to avoid PackageRegistryNotReady if installed packages import
 38        MigrationRecorder.
 39        """
 40        if cls._migration_class is None:
 41            _models_registry = ModelsRegistry()
 42            _models_registry.ready = True
 43
 44            class Migration(models.Model):
 45                app = models.CharField(max_length=255)
 46                name = models.CharField(max_length=255)
 47                applied = models.DateTimeField(default=now)
 48
 49                # Use isolated models registry for migrations
 50                _model_meta = Meta(models_registry=_models_registry)
 51
 52                model_options = models.Options(
 53                    package_label="migrations",
 54                    db_table="plainmigrations",
 55                )
 56
 57                def __str__(self) -> str:
 58                    return f"Migration {self.name} for {self.app}"
 59
 60            cls._migration_class = Migration
 61        return cls._migration_class
 62
 63    def __init__(self, connection: BaseDatabaseWrapper | DatabaseConnection) -> None:
 64        self.connection = connection
 65
 66    @property
 67    def migration_qs(self) -> Any:
 68        return self.Migration.query.all()
 69
 70    def has_table(self) -> bool:
 71        """Return True if the plainmigrations table exists."""
 72        with self.connection.cursor() as cursor:
 73            tables = self.connection.introspection.table_names(cursor)
 74        return self.Migration.model_options.db_table in tables
 75
 76    def ensure_schema(self) -> None:
 77        """Ensure the table exists and has the correct schema."""
 78        # If the table's there, that's fine - we've never changed its schema
 79        # in the codebase.
 80        if self.has_table():
 81            return
 82        # Make the table
 83        try:
 84            with self.connection.schema_editor() as editor:
 85                editor.create_model(self.Migration)
 86        except DatabaseError as exc:
 87            raise MigrationSchemaMissing(
 88                f"Unable to create the plainmigrations table ({exc})"
 89            )
 90
 91    def applied_migrations(self) -> dict[tuple[str, str], Any]:
 92        """
 93        Return a dict mapping (package_name, migration_name) to Migration instances
 94        for all applied migrations.
 95        """
 96        if self.has_table():
 97            return {
 98                (migration.app, migration.name): migration
 99                for migration in self.migration_qs
100            }
101        else:
102            # If the plainmigrations table doesn't exist, then no migrations
103            # are applied.
104            return {}
105
106    def record_applied(self, app: str, name: str) -> None:
107        """Record that a migration was applied."""
108        self.ensure_schema()
109        self.migration_qs.create(app=app, name=name)
110
111    def record_unapplied(self, app: str, name: str) -> None:
112        """Record that a migration was unapplied."""
113        self.ensure_schema()
114        self.migration_qs.filter(app=app, name=name).delete()
115
116    def flush(self) -> None:
117        """Delete all migration records. Useful for testing migrations."""
118        self.migration_qs.all().delete()