1from __future__ import annotations
  2
  3import subprocess
  4import sys
  5import time
  6from collections import defaultdict
  7
  8import click
  9
 10from plain.cli import register_cli
 11
 12from ..backups.cli import cli as backups_cli
 13from ..db import OperationalError, get_connection
 14from ..dialect import quote_name
 15from ..migrations.recorder import MIGRATION_TABLE_NAME
 16
 17
# Root command group for database operations; ``register_cli("db")``
# presumably attaches it to Plain's top-level CLI as ``db`` (verify in
# plain.cli). Note: the docstring below is used by click as the group's
# help text, so it is user-visible output.
@register_cli("db")
@click.group()
def cli() -> None:
    """Database operations"""


# Mount the backup subcommands (defined in ..backups.cli) under this group.
cli.add_command(backups_cli)
 25
 26
 27@cli.command()
 28@click.argument("parameters", nargs=-1)
 29def shell(parameters: tuple[str, ...]) -> None:
 30    """Open an interactive database shell"""
 31    conn = get_connection()
 32    try:
 33        conn.runshell(list(parameters))
 34    except FileNotFoundError:
 35        # Note that we're assuming the FileNotFoundError relates to the
 36        # command missing. It could be raised for some other reason, in
 37        # which case this error message would be inaccurate. Still, this
 38        # message catches the common case.
 39        click.secho(
 40            f"You appear not to have the {conn.executable_name!r} program installed or on your path.",
 41            fg="red",
 42            err=True,
 43        )
 44        sys.exit(1)
 45    except subprocess.CalledProcessError as e:
 46        click.secho(
 47            '"{}" returned non-zero exit status {}.'.format(
 48                " ".join(e.cmd),
 49                e.returncode,
 50            ),
 51            fg="red",
 52            err=True,
 53        )
 54        sys.exit(e.returncode)
 55
 56
 57@cli.command("drop-unknown-tables")
 58@click.option(
 59    "--yes",
 60    is_flag=True,
 61    help="Skip confirmation prompt (for non-interactive use).",
 62)
 63def drop_unknown_tables(yes: bool) -> None:
 64    """Drop all tables not associated with a Plain model"""
 65    conn = get_connection()
 66    db_tables = set(conn.table_names())
 67    model_tables = set(conn.plain_table_names())
 68    unknown_tables = sorted(db_tables - model_tables - {MIGRATION_TABLE_NAME})
 69
 70    if not unknown_tables:
 71        click.echo("No unknown tables found.")
 72        return
 73
 74    unknown_set = set(unknown_tables)
 75    table_count = len(unknown_tables)
 76    tables_label = f"{table_count} table{'s' if table_count != 1 else ''}"
 77
 78    # Find foreign key constraints from kept tables that reference unknown tables
 79    cascade_warnings: defaultdict[str, list[tuple[str, str]]] = defaultdict(list)
 80    with conn.cursor() as cursor:
 81        for table in unknown_tables:
 82            cursor.execute(
 83                """
 84                SELECT conname, conrelid::regclass
 85                FROM pg_constraint
 86                WHERE confrelid = %s::regclass AND contype = 'f'
 87                """,
 88                [table],
 89            )
 90            for constraint_name, referencing_table in cursor.fetchall():
 91                if str(referencing_table) not in unknown_set:
 92                    cascade_warnings[table].append(
 93                        (constraint_name, str(referencing_table))
 94                    )
 95
 96    click.secho("Unknown tables:", fg="yellow", bold=True)
 97    for table in unknown_tables:
 98        click.echo(f"  - {table}")
 99        for constraint_name, referencing_table in cascade_warnings[table]:
100            click.secho(
101                f"    ⚠ CASCADE will drop constraint {constraint_name} on {referencing_table}",
102                fg="red",
103            )
104    click.echo()
105
106    if not yes:
107        if not click.confirm(f"Drop {tables_label} (CASCADE)? This cannot be undone."):
108            return
109
110    with conn.cursor() as cursor:
111        for table in unknown_tables:
112            click.echo(f"  Dropping {table}...", nl=False)
113            cursor.execute(f"DROP TABLE IF EXISTS {quote_name(table)} CASCADE")
114            click.echo(" OK")
115
116    click.secho(f"✓ Dropped {tables_label}.", fg="green")
117
118
@cli.command()
def wait() -> None:
    """Wait for the database to be ready"""
    # The docstring above doubles as the click help text, so it is kept
    # verbatim.
    #
    # Polls ``ensure_connection()`` every 1.5 seconds until it stops raising
    # OperationalError, then prints a success line. There is no upper bound
    # on the number of attempts.
    attempt = 1
    while True:
        try:
            get_connection().ensure_connection()
        except OperationalError:
            # Stay quiet on the very first failure; report from the second
            # attempt onward.
            if attempt > 1:
                click.secho(
                    f"Waiting for database (attempt {attempt})",
                    fg="yellow",
                )
            time.sleep(1.5)
            attempt += 1
            continue

        click.secho("✔ Database ready", fg="green")
        return