1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5from plain import postgres
6from plain.postgres.db import DatabaseError
7from plain.postgres.meta import Meta
8from plain.postgres.registry import ModelsRegistry
9from plain.utils.functional import classproperty
10from plain.utils.timezone import now
11
12from .exceptions import MigrationSchemaMissing
13
# Name of the database table in which MigrationRecorder stores one row per
# applied migration.
MIGRATION_TABLE_NAME = "plainmigrations"
15
16if TYPE_CHECKING:
17 from plain.postgres.connection import DatabaseConnection
18
19
class MigrationRecorder:
    """
    Deal with storing migration records in the database.

    Because this table is actually itself used for dealing with model
    creation, it's the one thing we can't do normally via migrations.
    We manually handle table creation/schema updating (using schema backend)
    and then have a floating model to do queries with.

    If a migration is unapplied its row is removed from the table. Having
    a row in the table always means a migration is applied.
    """

    # Cache for the lazily-built Migration model class (see ``Migration``).
    _migration_class: type[postgres.Model] | None = None

    @classproperty  # type: ignore[invalid-argument-type]
    def Migration(cls) -> type[postgres.Model]:
        """
        Return the model class used to query the migrations table.

        The class is built lazily on first access and cached on
        ``_migration_class``. This avoids PackageRegistryNotReady if
        installed packages import MigrationRecorder. The model is attached
        to its own ``ModelsRegistry`` instance, which is marked ready.
        """
        if cls._migration_class is None:
            _models_registry = ModelsRegistry()
            _models_registry.ready = True

            class Migration(postgres.Model):
                app = postgres.CharField(max_length=255)
                name = postgres.CharField(max_length=255)
                applied = postgres.DateTimeField(default=now)

                # Use isolated models registry for migrations
                _model_meta = Meta(models_registry=_models_registry)

                model_options = postgres.Options(
                    package_label="migrations",
                    db_table=MIGRATION_TABLE_NAME,
                )

                def __str__(self) -> str:
                    return f"Migration {self.name} for {self.app}"

            cls._migration_class = Migration
        return cls._migration_class

    def __init__(self, connection: DatabaseConnection) -> None:
        """
        Args:
            connection: Connection used to list tables (``has_table``) and
                to create the migrations table (``ensure_schema``).
        """
        # NOTE(review): queries built from ``migration_qs`` go through
        # ``Migration.query`` and never reference ``self.connection``
        # directly. Confirm they target the same database.
        self.connection = connection

    @property
    def migration_qs(self) -> Any:
        """Return a queryset over every row of the migrations table."""
        return self.Migration.query.all()

    def has_table(self) -> bool:
        """Return True if the plainmigrations table exists."""
        with self.connection.cursor() as cursor:
            tables = self.connection.table_names(cursor)
        return self.Migration.model_options.db_table in tables

    def ensure_schema(self) -> None:
        """
        Ensure the table exists and has the correct schema.

        Raises:
            MigrationSchemaMissing: if creating the table fails with a
                DatabaseError (the original error is chained as the cause).
        """
        # If the table's there, that's fine - we've never changed its schema
        # in the codebase.
        if self.has_table():
            return
        # Make the table
        try:
            with self.connection.schema_editor() as editor:
                editor.create_model(self.Migration)
        except DatabaseError as exc:
            raise MigrationSchemaMissing(
                f"Unable to create the plainmigrations table ({exc})"
            ) from exc

    def applied_migrations(self) -> dict[tuple[str, str], Any]:
        """
        Return a dict mapping (package_name, migration_name) to Migration instances
        for all applied migrations.

        Returns an empty dict when the plainmigrations table does not exist,
        since no migrations can have been applied in that case.
        """
        if not self.has_table():
            return {}
        return {
            (migration.app, migration.name): migration
            for migration in self.migration_qs
        }

    def record_applied(self, app: str, name: str) -> None:
        """
        Record that a migration was applied.

        Creates the migrations table first if it is missing.
        """
        self.ensure_schema()
        self.migration_qs.create(app=app, name=name)

    def record_unapplied(self, app: str, name: str) -> None:
        """
        Record that a migration was unapplied by deleting its row(s).

        Creates the migrations table first if it is missing.
        """
        self.ensure_schema()
        self.migration_qs.filter(app=app, name=name).delete()

    def flush(self) -> None:
        """
        Delete all migration records. Useful for testing migrations.

        This is a no-op when the plainmigrations table does not exist. That
        matches ``applied_migrations``, which treats a missing table as
        having no records.
        """
        if not self.has_table():
            return
        # ``migration_qs`` already returns ``.all()``; no extra clone needed.
        self.migration_qs.delete()