from django.core.exceptions import FieldDoesNotExist
from django.db.models.fields import NOT_PROVIDED
from django.utils.functional import cached_property

from .base import Operation
from .utils import is_referenced_by_foreign_key


class FieldOperation(Operation):
    """
    Base class for operations that target one field, identified by the
    name of its model and the name of the field.
    """

    def __init__(self, model_name, name):
        self.model_name = model_name
        self.name = name

    @cached_property
    def model_name_lower(self):
        """Lowercased model name, used for case-insensitive comparisons."""
        return self.model_name.lower()

    @cached_property
    def name_lower(self):
        """Lowercased field name, used for case-insensitive comparisons."""
        return self.name.lower()

    def is_same_model_operation(self, operation):
        """Return True if `operation` targets the same model (ignoring case)."""
        return self.model_name_lower == operation.model_name_lower

    def is_same_field_operation(self, operation):
        """Return True if `operation` targets the same model and field (ignoring case)."""
        return self.is_same_model_operation(operation) and self.name_lower == operation.name_lower

    def references_model(self, name, app_label=None):
        """
        Return True if `name` matches this operation's model name, ignoring
        case. `app_label` is accepted but not consulted.
        """
        return name.lower() == self.model_name_lower

    def references_field(self, model_name, name, app_label=None):
        """
        Return True if `model_name` and `name` match this operation's model
        and field, ignoring case. `app_label` is accepted but not consulted.
        """
        return self.references_model(model_name) and name.lower() == self.name_lower

    def reduce(self, operation, in_between, app_label=None):
        """
        Return the parent class's reduction if it is truthy; otherwise return
        True when `operation` does not reference this field and False when it
        does.
        """
        # NOTE(review): the meaning of a bare True/False result is defined by
        # Operation.reduce and its caller, which are not in this file --
        # presumably True lets the optimizer move past this operation.
        return (
            super().reduce(operation, in_between, app_label=app_label) or
            not operation.references_field(self.model_name, self.name, app_label)
        )


class AddField(FieldOperation):
    """Add a field to a model."""

    def __init__(self, model_name, name, field, preserve_default=True):
        self.field = field
        self.preserve_default = preserve_default
        super().__init__(model_name, name)

    def deconstruct(self):
        """
        Return a `(class name, args, kwargs)` triple. `preserve_default` is
        only included in the kwargs when it is not True.
        """
        kwargs = {
            'model_name': self.model_name,
            'name': self.name,
            'field': self.field,
        }
        if self.preserve_default is not True:
            kwargs['preserve_default'] = self.preserve_default
        return (
            self.__class__.__name__,
            [],
            kwargs
        )

    def state_forwards(self, app_label, state):
        """Append the field to the model's state and reload the model."""
        # If preserve default is off, don't use the default for future state
        if not self.preserve_default:
            field = self.field.clone()
            field.default = NOT_PROVIDED
        else:
            field = self.field
        state.models[app_label, self.model_name_lower].fields.append((self.name, field))
        # Delay rendering of relationships if it's not a relational field
        delay = not field.is_relation
        state.reload_model(app_label, self.model_name_lower, delay=delay)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        """Add the field through `schema_editor` if the model may be migrated."""
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            field = to_model._meta.get_field(self.name)
            # When the default is not preserved, the state field carries no
            # default; lend it the operation's default for the duration of
            # the schema change only.
            if not self.preserve_default:
                field.default = self.field.default
            try:
                schema_editor.add_field(
                    from_model,
                    field,
                )
            finally:
                # Restore even if add_field() raises so the state field is
                # never left carrying the temporary default.
                if not self.preserve_default:
                    field.default = NOT_PROVIDED

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        """Remove the field through `schema_editor` if the model may be migrated."""
        from_model = from_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, from_model):
            schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))

    def describe(self):
        return "Add field %s to %s" % (self.name, self.model_name)

    def reduce(self, operation, in_between, app_label=None):
        """
        Merge with a following operation on the same field:

        * AlterField  -> a single AddField using the altered field.
        * RemoveField -> nothing (both operations are dropped).
        * RenameField -> a single AddField using the new name.

        Anything else is handled by `FieldOperation.reduce()`.
        """
        if isinstance(operation, FieldOperation) and self.is_same_field_operation(operation):
            # NOTE(review): the merged AddField instances below are created
            # with the default preserve_default=True; this operation's own
            # preserve_default is not forwarded -- confirm this is intended.
            if isinstance(operation, AlterField):
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.name,
                        field=operation.field,
                    ),
                ]
            elif isinstance(operation, RemoveField):
                return []
            elif isinstance(operation, RenameField):
                return [
                    AddField(
                        model_name=self.model_name,
                        name=operation.new_name,
                        field=self.field,
                    ),
                ]
        return super().reduce(operation, in_between, app_label=app_label)


class RemoveField(FieldOperation):
    """Remove a field from a model."""

    def deconstruct(self):
        """Return a `(class name, args, kwargs)` triple for this operation."""
        kwargs = {
            'model_name': self.model_name,
            'name': self.name,
        }
        return (
            self.__class__.__name__,
            [],
            kwargs
        )

    def state_forwards(self, app_label, state):
        """
        Drop the field from the model's state and reload the model.

        Raise FieldDoesNotExist if the model state has no field with this
        operation's name.
        """
        model_state = state.models[app_label, self.model_name_lower]
        new_fields = []
        old_field = None
        for name, instance in model_state.fields:
            if name != self.name:
                new_fields.append((name, instance))
            else:
                old_field = instance
        if old_field is None:
            # Fail with an explicit error (same wording RenameField uses)
            # rather than an AttributeError on None below, and do so before
            # the state is modified.
            raise FieldDoesNotExist(
                "%s.%s has no field named '%s'" % (app_label, self.model_name, self.name)
            )
        model_state.fields = new_fields
        # Delay rendering of relationships if it's not a relational field
        delay = not old_field.is_relation
        state.reload_model(app_label, self.model_name_lower, delay=delay)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        """Remove the field through `schema_editor` if the model may be migrated."""
        from_model = from_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, from_model):
            schema_editor.remove_field(from_model, from_model._meta.get_field(self.name))

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        """Re-add the field through `schema_editor` if the model may be migrated."""
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            schema_editor.add_field(from_model, to_model._meta.get_field(self.name))

    def describe(self):
        return "Remove field %s from %s" % (self.name, self.model_name)


class AlterField(FieldOperation):
    """
    Alter a field's database column (e.g. null, max_length) to the provided
    new field.
    """

    def __init__(self, model_name, name, field, preserve_default=True):
        self.field = field
        self.preserve_default = preserve_default
        super().__init__(model_name, name)

    def deconstruct(self):
        """
        Return a `(class name, args, kwargs)` triple. `preserve_default` is
        only included in the kwargs when it is not True.
        """
        kwargs = {
            'model_name': self.model_name,
            'name': self.name,
            'field': self.field,
        }
        if self.preserve_default is not True:
            kwargs['preserve_default'] = self.preserve_default
        return (
            self.__class__.__name__,
            [],
            kwargs
        )

    def state_forwards(self, app_label, state):
        """Replace the field in the model's state and reload the model."""
        # If preserve default is off, don't use the default for future state
        if not self.preserve_default:
            field = self.field.clone()
            field.default = NOT_PROVIDED
        else:
            field = self.field
        model_state = state.models[app_label, self.model_name_lower]
        model_state.fields = [
            (n, field if n == self.name else f)
            for n, f in model_state.fields
        ]
        # TODO: investigate if old relational fields must be reloaded or if it's
        # sufficient if the new field is (#27737).
        # Delay rendering of relationships if it's not a relational field and
        # not referenced by a foreign key.
        delay = (
            not field.is_relation and
            not is_referenced_by_foreign_key(state, self.model_name_lower, self.field, self.name)
        )
        state.reload_model(app_label, self.model_name_lower, delay=delay)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        """Alter the field through `schema_editor` if the model may be migrated."""
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            from_field = from_model._meta.get_field(self.name)
            to_field = to_model._meta.get_field(self.name)
            # Lend the target field the operation's default for the duration
            # of the schema change only.
            if not self.preserve_default:
                to_field.default = self.field.default
            try:
                schema_editor.alter_field(from_model, from_field, to_field)
            finally:
                # Restore even if alter_field() raises so the state field is
                # never left carrying the temporary default.
                if not self.preserve_default:
                    to_field.default = NOT_PROVIDED

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        # Altering is symmetric: the caller supplies the states in the order
        # appropriate for the direction, so the forwards logic is reused.
        self.database_forwards(app_label, schema_editor, from_state, to_state)

    def describe(self):
        return "Alter field %s on %s" % (self.name, self.model_name)

    def reduce(self, operation, in_between, app_label=None):
        """
        Merge with a following operation on the same field:

        * RemoveField -> only the RemoveField is kept.
        * RenameField -> the RenameField followed by an AlterField that
          targets the new name.

        Anything else is handled by `FieldOperation.reduce()`.
        """
        if isinstance(operation, RemoveField) and self.is_same_field_operation(operation):
            return [operation]
        elif isinstance(operation, RenameField) and self.is_same_field_operation(operation):
            return [
                operation,
                AlterField(
                    model_name=self.model_name,
                    name=operation.new_name,
                    field=self.field,
                ),
            ]
        return super().reduce(operation, in_between, app_label=app_label)


class RenameField(FieldOperation):
    """Rename a field on the model. Might affect db_column too."""

    def __init__(self, model_name, old_name, new_name):
        self.old_name = old_name
        self.new_name = new_name
        # The inherited `name` attribute is the *old* field name.
        super().__init__(model_name, old_name)

    @cached_property
    def old_name_lower(self):
        """Lowercased old field name."""
        return self.old_name.lower()

    @cached_property
    def new_name_lower(self):
        """Lowercased new field name."""
        return self.new_name.lower()

    def deconstruct(self):
        """Return a `(class name, args, kwargs)` triple for this operation."""
        kwargs = {
            'model_name': self.model_name,
            'old_name': self.old_name,
            'new_name': self.new_name,
        }
        return (
            self.__class__.__name__,
            [],
            kwargs
        )

    def state_forwards(self, app_label, state):
        """
        Rename the field in the model's state and update the references to
        the old name held in `from_fields`, `index_together`,
        `unique_together`, `field_name` and `to_fields`, then reload the
        model.

        Raise FieldDoesNotExist if the model state has no field named
        `old_name`.
        """
        model_state = state.models[app_label, self.model_name_lower]
        # Rename the field
        fields = model_state.fields
        found = False
        delay = True
        for index, (name, field) in enumerate(fields):
            if not found and name == self.old_name:
                fields[index] = (self.new_name, field)
                found = True
            # Fix from_fields to refer to the new field.
            from_fields = getattr(field, 'from_fields', None)
            if from_fields:
                field.from_fields = tuple(
                    self.new_name if from_field_name == self.old_name else from_field_name
                    for from_field_name in from_fields
                )
            # Delay rendering of relationships if it's not a relational
            # field and not referenced by a foreign key.
            delay = delay and (
                not field.is_relation and
                not is_referenced_by_foreign_key(state, self.model_name_lower, field, self.name)
            )
        if not found:
            raise FieldDoesNotExist(
                "%s.%s has no field named '%s'" % (app_label, self.model_name, self.old_name)
            )
        # Fix index/unique_together to refer to the new field
        options = model_state.options
        for option in ('index_together', 'unique_together'):
            if option in options:
                options[option] = [
                    [self.new_name if n == self.old_name else n for n in together]
                    for together in options[option]
                ]
        # Fix to_fields to refer to the new field.
        model_tuple = app_label, self.model_name_lower
        for (model_app_label, model_name), other_model_state in state.models.items():
            for _, field in other_model_state.fields:
                remote_field = field.remote_field
                if not remote_field:
                    continue
                # NOTE(review): _get_model_tuple() is not defined in this
                # file; presumably it is inherited from Operation and
                # resolves the relation target to an (app_label, model_name)
                # pair -- confirm.
                remote_model_tuple = self._get_model_tuple(
                    remote_field.model, model_app_label, model_name
                )
                if remote_model_tuple != model_tuple:
                    continue
                if getattr(remote_field, 'field_name', None) == self.old_name:
                    remote_field.field_name = self.new_name
                to_fields = getattr(field, 'to_fields', None)
                if to_fields:
                    field.to_fields = tuple(
                        self.new_name if to_field_name == self.old_name else to_field_name
                        for to_field_name in to_fields
                    )
        state.reload_model(app_label, self.model_name_lower, delay=delay)

    def database_forwards(self, app_label, schema_editor, from_state, to_state):
        """Alter the old field into the new one if the model may be migrated."""
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            schema_editor.alter_field(
                from_model,
                from_model._meta.get_field(self.old_name),
                to_model._meta.get_field(self.new_name),
            )

    def database_backwards(self, app_label, schema_editor, from_state, to_state):
        """Alter the new field back into the old one if the model may be migrated."""
        to_model = to_state.apps.get_model(app_label, self.model_name)
        if self.allow_migrate_model(schema_editor.connection.alias, to_model):
            from_model = from_state.apps.get_model(app_label, self.model_name)
            schema_editor.alter_field(
                from_model,
                from_model._meta.get_field(self.new_name),
                to_model._meta.get_field(self.old_name),
            )

    def describe(self):
        return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)

    def references_field(self, model_name, name, app_label=None):
        """
        Return True if `model_name` matches this operation's model and `name`
        matches either the old or the new field name, ignoring case.
        """
        return self.references_model(model_name) and (
            name.lower() == self.old_name_lower or
            name.lower() == self.new_name_lower
        )

    def reduce(self, operation, in_between, app_label=None):
        """
        Collapse two chained renames on the same model (a -> b, b -> c) into
        a single rename (a -> c). Otherwise return the reduction from the
        class after FieldOperation in the MRO if it is truthy, else whether
        `operation` leaves the *new* field name unreferenced.
        """
        if (isinstance(operation, RenameField) and
                self.is_same_model_operation(operation) and
                self.new_name_lower == operation.old_name_lower):
            return [
                RenameField(
                    self.model_name,
                    self.old_name,
                    operation.new_name,
                ),
            ]
        # Skip `FieldOperation.reduce` as we want to run `references_field`
        # against self.new_name.
        return (
            super(FieldOperation, self).reduce(operation, in_between, app_label=app_label) or
            not operation.references_field(self.model_name, self.new_name, app_label)
        )