Plain is headed towards 1.0! Subscribe for development updates →

  1"""
Helpers to manipulate deferred DDL statements that might need to be adjusted or
discarded when executing a migration.
  4"""
  5from copy import deepcopy
  6
  7
  8class Reference:
  9    """Base class that defines the reference interface."""
 10
 11    def references_table(self, table):
 12        """
 13        Return whether or not this instance references the specified table.
 14        """
 15        return False
 16
 17    def references_column(self, table, column):
 18        """
 19        Return whether or not this instance references the specified column.
 20        """
 21        return False
 22
 23    def rename_table_references(self, old_table, new_table):
 24        """
 25        Rename all references to the old_name to the new_table.
 26        """
 27        pass
 28
 29    def rename_column_references(self, table, old_column, new_column):
 30        """
 31        Rename all references to the old_column to the new_column.
 32        """
 33        pass
 34
 35    def __repr__(self):
 36        return f"<{self.__class__.__name__} {str(self)!r}>"
 37
 38    def __str__(self):
 39        raise NotImplementedError(
 40            "Subclasses must define how they should be converted to string."
 41        )
 42
 43
 44class Table(Reference):
 45    """Hold a reference to a table."""
 46
 47    def __init__(self, table, quote_name):
 48        self.table = table
 49        self.quote_name = quote_name
 50
 51    def references_table(self, table):
 52        return self.table == table
 53
 54    def rename_table_references(self, old_table, new_table):
 55        if self.table == old_table:
 56            self.table = new_table
 57
 58    def __str__(self):
 59        return self.quote_name(self.table)
 60
 61
 62class TableColumns(Table):
 63    """Base class for references to multiple columns of a table."""
 64
 65    def __init__(self, table, columns):
 66        self.table = table
 67        self.columns = columns
 68
 69    def references_column(self, table, column):
 70        return self.table == table and column in self.columns
 71
 72    def rename_column_references(self, table, old_column, new_column):
 73        if self.table == table:
 74            for index, column in enumerate(self.columns):
 75                if column == old_column:
 76                    self.columns[index] = new_column
 77
 78
 79class Columns(TableColumns):
 80    """Hold a reference to one or many columns."""
 81
 82    def __init__(self, table, columns, quote_name, col_suffixes=()):
 83        self.quote_name = quote_name
 84        self.col_suffixes = col_suffixes
 85        super().__init__(table, columns)
 86
 87    def __str__(self):
 88        def col_str(column, idx):
 89            col = self.quote_name(column)
 90            try:
 91                suffix = self.col_suffixes[idx]
 92                if suffix:
 93                    col = f"{col} {suffix}"
 94            except IndexError:
 95                pass
 96            return col
 97
 98        return ", ".join(
 99            col_str(column, idx) for idx, column in enumerate(self.columns)
100        )
101
102
class IndexName(TableColumns):
    """
    Reference to an index name built from the table, its columns and a
    suffix by the supplied ``create_index_name`` callable.
    """

    def __init__(self, table, columns, suffix, create_index_name):
        super().__init__(table, columns)
        self.suffix = suffix
        self.create_index_name = create_index_name

    def __str__(self):
        build = self.create_index_name
        return build(self.table, self.columns, self.suffix)
113
114
class IndexColumns(Columns):
    """
    Reference to index columns, each rendered with its operator class.

    Every column is emitted as ``<quoted column> [opclass] [suffix]``; the
    opclass and suffix at the column's position are appended when present
    and non-empty.
    """

    def __init__(self, table, columns, quote_name, col_suffixes=(), opclasses=()):
        self.opclasses = opclasses
        super().__init__(table, columns, quote_name, col_suffixes)

    def __str__(self):
        def col_str(column, idx):
            col = self.quote_name(column)
            # Index.__init__() keeps opclasses the same length as columns, but
            # the ``opclasses=()`` default does not. A missing or empty
            # opclass (or suffix) is skipped instead of raising IndexError or
            # emitting a stray space.
            for extras in (self.opclasses, self.col_suffixes):
                try:
                    extra = extras[idx]
                except IndexError:
                    continue
                if extra:
                    col = f"{col} {extra}"
            return col

        return ", ".join(
            col_str(column, idx) for idx, column in enumerate(self.columns)
        )
136
137
class ForeignKeyName(TableColumns):
    """
    Reference to a foreign key constraint name.

    The inherited ``table``/``columns`` describe the referencing side, while
    ``to_reference`` tracks the referenced table and columns so that lookups
    and renames cover both ends of the relation.
    """

    def __init__(
        self,
        from_table,
        from_columns,
        to_table,
        to_columns,
        suffix_template,
        create_fk_name,
    ):
        super().__init__(from_table, from_columns)
        self.to_reference = TableColumns(to_table, to_columns)
        self.suffix_template = suffix_template
        self.create_fk_name = create_fk_name

    def references_table(self, table):
        if super().references_table(table):
            return True
        return self.to_reference.references_table(table)

    def references_column(self, table, column):
        if super().references_column(table, column):
            return True
        return self.to_reference.references_column(table, column)

    def rename_table_references(self, old_table, new_table):
        super().rename_table_references(old_table, new_table)
        target = self.to_reference
        target.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        super().rename_column_references(table, old_column, new_column)
        target = self.to_reference
        target.rename_column_references(table, old_column, new_column)

    def __str__(self):
        target = self.to_reference
        # Only the first referenced column feeds the suffix template.
        substitutions = {
            "to_table": target.table,
            "to_column": target.columns[0],
        }
        suffix = self.suffix_template % substitutions
        return self.create_fk_name(self.table, self.columns, suffix)
182
183
class Statement(Reference):
    """
    Container for a statement template and its formatting parameters.

    Keeping the parts separate, instead of interpolating them immediately,
    lets identifiers be adjusted later when a referenced table or column is
    renamed or removed.
    """

    def __init__(self, template, **parts):
        self.template = template
        self.parts = parts

    def references_table(self, table):
        # Parts that are not references (e.g. plain strings) are skipped.
        for part in self.parts.values():
            if hasattr(part, "references_table") and part.references_table(table):
                return True
        return False

    def references_column(self, table, column):
        for part in self.parts.values():
            if hasattr(part, "references_column") and part.references_column(
                table, column
            ):
                return True
        return False

    def rename_table_references(self, old_table, new_table):
        renamable = (
            p for p in self.parts.values() if hasattr(p, "rename_table_references")
        )
        for part in renamable:
            part.rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        renamable = (
            p for p in self.parts.values() if hasattr(p, "rename_column_references")
        )
        for part in renamable:
            part.rename_column_references(table, old_column, new_column)

    def __str__(self):
        # Each part is interpolated via %-style mapping substitution.
        mapping = self.parts
        return self.template % mapping
221
222
class Expressions(TableColumns):
    """
    Reference to query expressions, tracking the columns they use.

    SQL is produced on demand via ``compiler``, and parameters are inlined
    into it with ``quote_value``.
    """

    def __init__(self, table, expressions, compiler, quote_value):
        self.compiler = compiler
        self.expressions = expressions
        self.quote_value = quote_value
        referenced = []
        for col in compiler.query._gen_cols([expressions]):
            referenced.append(col.target.column)
        super().__init__(table, referenced)

    def rename_table_references(self, old_table, new_table):
        if self.table == old_table:
            self.expressions = self.expressions.relabeled_clone({old_table: new_table})
            super().rename_table_references(old_table, new_table)

    def rename_column_references(self, table, old_column, new_column):
        if self.table == table:
            # Rename on a deep copy and swap it in once every column is seen.
            renamed = deepcopy(self.expressions)
            self.columns = current = []
            for col in self.compiler.query._gen_cols([renamed]):
                target = col.target
                if target.column == old_column:
                    target.column = new_column
                current.append(target.column)
            self.expressions = renamed

    def __str__(self):
        sql, params = self.compiler.compile(self.expressions)
        quoted = tuple(self.quote_value(param) for param in params)
        return sql % quoted