import contextlib
import copy
from decimal import Decimal

from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.backends.ddl_references import Statement
from django.db.transaction import atomic
from django.db.utils import NotSupportedError


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
    """
    Schema editor for the SQLite backend.

    Field additions, removals and alterations are implemented by rebuilding
    the whole table (see `_remake_table()`). Foreign key constraint checking
    is disabled for the lifetime of the editor (see `__enter__()` and
    `__exit__()`).
    """

    sql_delete_table = "DROP TABLE %(table)s"
    # Foreign keys are only ever emitted inline with the column definition.
    sql_create_fk = None
    sql_create_inline_fk = "REFERENCES %(to_table)s (%(to_column)s) DEFERRABLE INITIALLY DEFERRED"
    sql_create_unique = "CREATE UNIQUE INDEX %(name)s ON %(table)s (%(columns)s)"
    sql_delete_unique = "DROP INDEX %(name)s"

    def __enter__(self):
        """
        Disable foreign key constraint checking and enter the base editor.

        Raise NotSupportedError if constraint checking could not be disabled.
        """
        # Some SQLite schema alterations need foreign key constraints to be
        # disabled. Enforce it here for the duration of the schema edition.
        if not self.connection.disable_constraint_checking():
            raise NotSupportedError(
                'SQLite schema editor cannot be used while foreign key '
                'constraint checks are enabled. Make sure to disable them '
                'before entering a transaction.atomic() context because '
                'SQLite3 does not support disabling them in the middle of '
                'a multi-statement transaction.'
            )
        # Use a managed cursor so it is closed once the PRAGMA has run.
        with self.connection.cursor() as cursor:
            cursor.execute('PRAGMA legacy_alter_table = ON')
        return super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the base editor, then undo the settings applied by `__enter__()`.
        """
        super().__exit__(exc_type, exc_value, traceback)
        with self.connection.cursor() as cursor:
            cursor.execute('PRAGMA legacy_alter_table = OFF')
        self.connection.enable_constraint_checking()

    def quote_value(self, value):
        """
        Return `value` rendered as an SQL literal suitable for inlining into
        a statement.

        Raise ValueError if the (possibly adapted) value is of a type that
        cannot be quoted.
        """
        # The backend "mostly works" without this function and there are use
        # cases for compiling Python without the sqlite3 libraries (e.g.
        # security hardening).
        # The two handlers must stay separate: when the import itself fails
        # the name `sqlite3` is unbound, so it can only be referenced in a
        # clause that is evaluated after ImportError has been ruled out.
        try:
            import sqlite3
            value = sqlite3.adapt(value)
        except ImportError:
            pass
        except sqlite3.ProgrammingError:
            pass
        # Manual emulation of SQLite parameter quoting
        # bool is tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            return str(int(value))
        elif isinstance(value, (Decimal, float, int)):
            return str(value)
        elif isinstance(value, str):
            # Escape embedded single quotes by doubling them.
            return "'%s'" % value.replace("\'", "\'\'")
        elif value is None:
            return "NULL"
        elif isinstance(value, (bytes, bytearray, memoryview)):
            # Bytes are only allowed for BLOB fields, encoded as string
            # literals containing hexadecimal data and preceded by a single "X"
            # character.
            return "X'%s'" % value.hex()
        else:
            raise ValueError("Cannot quote parameter value %r of type %s" % (value, type(value)))

    def _is_referenced_by_fk_constraint(self, table_name, column_name=None, ignore_self=False):
        """
        Return whether or not the provided table name is referenced by another
        one. If `column_name` is specified, only references pointing to that
        column are considered. If `ignore_self` is True, self-referential
        constraints are ignored.
        """
        with self.connection.cursor() as cursor:
            for other_table in self.connection.introspection.get_table_list(cursor):
                if ignore_self and other_table.name == table_name:
                    continue
                constraints = self.connection.introspection._get_foreign_key_constraints(cursor, other_table.name)
                for constraint in constraints.values():
                    constraint_table, constraint_column = constraint['foreign_key']
                    if (constraint_table == table_name and
                            (column_name is None or constraint_column == column_name)):
                        return True
        return False

    def alter_db_table(self, model, old_db_table, new_db_table, disable_constraints=True):
        """
        Rename the table `old_db_table` to `new_db_table`.

        When `disable_constraints` is True and another table references
        `old_db_table` through a foreign key, constraint checking is switched
        back on around the rename. Raise NotSupportedError if that is needed
        while inside an atomic block.
        """
        if disable_constraints and self._is_referenced_by_fk_constraint(old_db_table):
            if self.connection.in_atomic_block:
                raise NotSupportedError((
                    'Renaming the %r table while in a transaction is not '
                    'supported on SQLite because it would break referential '
                    'integrity. Try adding `atomic = False` to the Migration class.'
                ) % old_db_table)
            self.connection.enable_constraint_checking()
            super().alter_db_table(model, old_db_table, new_db_table)
            self.connection.disable_constraint_checking()
        else:
            super().alter_db_table(model, old_db_table, new_db_table)

    def alter_field(self, model, old_field, new_field, strict=False):
        """
        Alter `old_field` into `new_field` on `model`.

        If the field is renamed and its column is referenced by a foreign key
        of another table, the REFERENCES clauses stored in sqlite_master are
        rewritten to point at the new column name. Raise NotSupportedError if
        that is needed while inside an atomic block.
        """
        old_field_name = old_field.name
        table_name = model._meta.db_table
        _, old_column_name = old_field.get_attname_column()
        if (new_field.name != old_field_name and
                self._is_referenced_by_fk_constraint(table_name, old_column_name, ignore_self=True)):
            if self.connection.in_atomic_block:
                raise NotSupportedError((
                    'Renaming the %r.%r column while in a transaction is not '
                    'supported on SQLite because it would break referential '
                    'integrity. Try adding `atomic = False` to the Migration class.'
                ) % (model._meta.db_table, old_field_name))
            with atomic(self.connection.alias):
                super().alter_field(model, old_field, new_field, strict=strict)
                # Follow SQLite's documented procedure for performing changes
                # that don't affect the on-disk content.
                # https://sqlite.org/lang_altertable.html#otheralter
                with self.connection.cursor() as cursor:
                    schema_version = cursor.execute('PRAGMA schema_version').fetchone()[0]
                    cursor.execute('PRAGMA writable_schema = 1')
                    references_template = ' REFERENCES "%s" ("%%s") ' % table_name
                    new_column_name = new_field.get_attname_column()[1]
                    search = references_template % old_column_name
                    replacement = references_template % new_column_name
                    cursor.execute('UPDATE sqlite_master SET sql = replace(sql, %s, %s)', (search, replacement))
                    cursor.execute('PRAGMA schema_version = %d' % (schema_version + 1))
                    cursor.execute('PRAGMA writable_schema = 0')
                    # The integrity check will raise an exception and rollback
                    # the transaction if the sqlite_master updates corrupt the
                    # database.
                    cursor.execute('PRAGMA integrity_check')
            # Perform a VACUUM to refresh the database representation from
            # the sqlite_master table.
            with self.connection.cursor() as cursor:
                cursor.execute('VACUUM')
        else:
            super().alter_field(model, old_field, new_field, strict=strict)

    def _remake_table(self, model, create_field=None, delete_field=None, alter_field=None):
        """
        Shortcut to transform a model from old_model into new_model

        The essential steps are:
          1. rename the model's existing table, e.g. "app_model" to
             "app_model__old"
          2. create a table with the updated definition called "app_model"
          3. copy the data from the old renamed table to the new table
          4. delete the "app_model__old" table

        `create_field` and `delete_field` are single fields to add or remove;
        `alter_field` is an `(old_field, new_field)` pair.
        """
        # Self-referential fields must be recreated rather than copied from
        # the old model to ensure their remote_field.field_name doesn't refer
        # to an altered field.
        def is_self_referential(f):
            return f.is_relation and f.remote_field.model is model
        # Work out the new fields dict / mapping
        body = {
            f.name: f.clone() if is_self_referential(f) else f
            for f in model._meta.local_concrete_fields
        }
        # Since mapping might mix column names and default values,
        # its values must be already quoted.
        mapping = {f.column: self.quote_name(f.column) for f in model._meta.local_concrete_fields}
        # This maps field names (not columns) for things like unique_together
        rename_mapping = {}
        restore_pk_field = None
        try:
            # If any of the new or altered fields is introducing a new PK,
            # remove the old one
            if getattr(create_field, 'primary_key', False) or (
                    alter_field and getattr(alter_field[1], 'primary_key', False)):
                for name, field in list(body.items()):
                    if field.primary_key:
                        field.primary_key = False
                        restore_pk_field = field
                        if field.auto_created:
                            del body[name]
                            del mapping[field.column]
            # Add in any created fields
            if create_field:
                body[create_field.name] = create_field
                # Choose a default and insert it into the copy map
                if not create_field.many_to_many and create_field.concrete:
                    mapping[create_field.column] = self.quote_value(
                        self.effective_default(create_field)
                    )
            # Add in any altered fields
            if alter_field:
                old_field, new_field = alter_field
                body.pop(old_field.name, None)
                mapping.pop(old_field.column, None)
                body[new_field.name] = new_field
                if old_field.null and not new_field.null:
                    # Replace NULLs with the new default while copying, as
                    # the new column no longer accepts them.
                    case_sql = "coalesce(%(col)s, %(default)s)" % {
                        'col': self.quote_name(old_field.column),
                        'default': self.quote_value(self.effective_default(new_field))
                    }
                    mapping[new_field.column] = case_sql
                else:
                    mapping[new_field.column] = self.quote_name(old_field.column)
                rename_mapping[old_field.name] = new_field.name
            # Remove any deleted fields
            if delete_field:
                del body[delete_field.name]
                del mapping[delete_field.column]
                # Remove any implicit M2M tables
                if delete_field.many_to_many and delete_field.remote_field.through._meta.auto_created:
                    return self.delete_model(delete_field.remote_field.through)
            # Work inside a new app registry
            apps = Apps()

            # Provide isolated instances of the fields to the new model body so
            # that the existing model's internals aren't interfered with when
            # the dummy model is constructed.
            body = copy.deepcopy(body)

            # Work out the new value of unique_together, taking renames into
            # account
            unique_together = [
                [rename_mapping.get(n, n) for n in unique]
                for unique in model._meta.unique_together
            ]

            # Work out the new value for index_together, taking renames into
            # account
            index_together = [
                [rename_mapping.get(n, n) for n in index]
                for index in model._meta.index_together
            ]

            indexes = model._meta.indexes
            if delete_field:
                indexes = [
                    index for index in indexes
                    if delete_field.name not in index.fields
                ]

            # Construct a new model for the new state
            meta_contents = {
                'app_label': model._meta.app_label,
                'db_table': model._meta.db_table,
                'unique_together': unique_together,
                'index_together': index_together,
                'indexes': indexes,
                'apps': apps,
            }
            meta = type("Meta", (), meta_contents)
            body['Meta'] = meta
            body['__module__'] = model.__module__

            temp_model = type(model._meta.object_name, model.__bases__, body)

            # We need to modify model._meta.db_table, but everything explodes
            # if the change isn't reversed before the end of this method. This
            # context manager helps us avoid that situation.
            @contextlib.contextmanager
            def altered_table_name(model, temporary_table_name):
                original_table_name = model._meta.db_table
                model._meta.db_table = temporary_table_name
                try:
                    yield
                finally:
                    # Reverse the change even if a statement in the managed
                    # block raised.
                    model._meta.db_table = original_table_name

            with altered_table_name(model, model._meta.db_table + "__old"):
                # Rename the old table to make way for the new
                self.alter_db_table(
                    model, temp_model._meta.db_table, model._meta.db_table,
                    disable_constraints=False,
                )
                # Create a new table with the updated schema.
                self.create_model(temp_model)

                # Copy data from the old table into the new table
                field_maps = list(mapping.items())
                self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
                    self.quote_name(temp_model._meta.db_table),
                    ', '.join(self.quote_name(x) for x, y in field_maps),
                    ', '.join(y for x, y in field_maps),
                    self.quote_name(model._meta.db_table),
                ))

                # Delete the old table
                self.delete_model(model, handle_autom2m=False)

            # Run deferred SQL on correct table
            for sql in self.deferred_sql:
                self.execute(sql)
            self.deferred_sql = []
        finally:
            # Fix any PK-removed field. Fields that are not self-referential
            # are the live model's own field objects, so the flag cleared
            # above must be put back even when the rebuild fails.
            if restore_pk_field:
                restore_pk_field.primary_key = True

    def delete_model(self, model, handle_autom2m=True):
        """
        Delete the table of `model`.

        With `handle_autom2m=True` the deletion is delegated to the base
        editor. With `handle_autom2m=False` only the model's own table is
        dropped and deferred statements referencing it are discarded.
        """
        if handle_autom2m:
            super().delete_model(model)
        else:
            # Delete the table (and only that)
            self.execute(self.sql_delete_table % {
                "table": self.quote_name(model._meta.db_table),
            })
            # Remove all deferred statements referencing the deleted table.
            # Iterate over a copy because the list is mutated in the loop.
            for sql in list(self.deferred_sql):
                if isinstance(sql, Statement) and sql.references_table(model._meta.db_table):
                    self.deferred_sql.remove(sql)

    def add_field(self, model, field):
        """
        Create a field on a model. Usually involves adding a column, but may
        involve adding a table instead (for M2M fields).
        """
        # Special-case implicit M2M tables
        if field.many_to_many and field.remote_field.through._meta.auto_created:
            return self.create_model(field.remote_field.through)
        self._remake_table(model, create_field=field)

    def remove_field(self, model, field):
        """
        Remove a field from a model. Usually involves deleting a column,
        but for M2Ms may involve deleting a table.
        """
        # M2M fields are a special case
        if field.many_to_many:
            # For implicit M2M tables, delete the auto-created table
            if field.remote_field.through._meta.auto_created:
                self.delete_model(field.remote_field.through)
            # For explicit "through" M2M fields, do nothing
        # For everything else, remake.
        else:
            # It might not actually have a column behind it
            if field.db_parameters(connection=self.connection)['type'] is None:
                return
            self._remake_table(model, delete_field=field)

    def _alter_field(self, model, old_field, new_field, old_type, new_type,
                     old_db_params, new_db_params, strict=False):
        """Perform a "physical" (non-ManyToMany) field update."""
        # Alter by remaking table
        self._remake_table(model, alter_field=(old_field, new_field))
        # Rebuild tables with FKs pointing to this field if the PK type changed.
        if old_field.primary_key and new_field.primary_key and old_type != new_type:
            for rel in new_field.model._meta.related_objects:
                if not rel.many_to_many:
                    self._remake_table(rel.related_model)

    def _alter_many_to_many(self, model, old_field, new_field, strict):
        """Alter M2Ms to repoint their to= endpoints."""
        if old_field.remote_field.through._meta.db_table == new_field.remote_field.through._meta.db_table:
            # The field name didn't change, but some options did; we have to propagate this altering.
            self._remake_table(
                old_field.remote_field.through,
                alter_field=(
                    # We need the field that points to the target model, so we can tell alter_field to change it -
                    # this is m2m_reverse_field_name() (as opposed to m2m_field_name, which points to our model)
                    old_field.remote_field.through._meta.get_field(old_field.m2m_reverse_field_name()),
                    new_field.remote_field.through._meta.get_field(new_field.m2m_reverse_field_name()),
                ),
            )
            return

        # Make a new through table
        self.create_model(new_field.remote_field.through)
        # Copy the data across
        self.execute("INSERT INTO %s (%s) SELECT %s FROM %s" % (
            self.quote_name(new_field.remote_field.through._meta.db_table),
            ', '.join([
                "id",
                new_field.m2m_column_name(),
                new_field.m2m_reverse_name(),
            ]),
            ', '.join([
                "id",
                old_field.m2m_column_name(),
                old_field.m2m_reverse_name(),
            ]),
            self.quote_name(old_field.remote_field.through._meta.db_table),
        ))
        # Delete the old through table
        self.delete_model(old_field.remote_field.through)