1
0
forked from s434650/CatOrNot
CatOrNot/venv/lib/python3.6/site-packages/wtforms/ext/sqlalchemy/orm.py
2018-12-11 00:32:28 +01:00

305 lines
10 KiB
Python

"""
Tools for generating forms based on SQLAlchemy models.
"""
from __future__ import unicode_literals
import inspect
from wtforms import fields as f
from wtforms import validators
from wtforms.form import Form
from .fields import QuerySelectField, QuerySelectMultipleField
# Names exported by ``from ... import *``: only the two form-generation helpers.
__all__ = (
'model_fields', 'model_form',
)
def converts(*args):
    """
    Decorator factory tagging a function with the type names it converts.

    The given names are stored as a ``frozenset`` on the decorated function's
    ``_converter_for`` attribute; the function itself is returned unchanged.
    """
    handled_names = frozenset(args)

    def _inner(func):
        func._converter_for = handled_names
        return func

    return _inner
class ModelConversionError(Exception):
    """Error type carrying a single message about a failed model conversion."""

    def __init__(self, message):
        super(ModelConversionError, self).__init__(message)
class ModelConverterBase(object):
    """
    Base class that turns SQLAlchemy mapper properties into form fields.

    Converter callables are looked up by name in ``self.converters``. The
    registry is built from the mapping passed to the constructor plus every
    attribute of the instance that carries a ``_converter_for`` attribute.
    """

    def __init__(self, converters, use_mro=True):
        """
        :param converters:
            Optional mapping of type name to converter callable. The mapping
            is copied, so the caller's object is never modified. Converters
            discovered on the instance replace entries with the same name.
        :param use_mro:
            When true, ``convert`` considers every class in the column
            type's MRO; otherwise only the column type's own class.
        """
        self.use_mro = use_mro

        # Copy instead of writing into the argument: the caller may reuse the
        # same mapping for several converter instances.
        registry = dict(converters) if converters else {}

        for name in dir(self):
            obj = getattr(self, name)
            if hasattr(obj, '_converter_for'):
                for classname in obj._converter_for:
                    registry[classname] = obj

        self.converters = registry

    def convert(self, model, mapper, prop, field_args, db_session=None):
        """
        Build a form field for a single mapper property.

        :param model:
            The model class; passed through to the converter.
        :param mapper:
            The model's mapper; passed through to the converter.
        :param prop:
            The property to convert. It is treated as a relationship when it
            has a ``direction`` attribute, and as a column property when it
            has ``columns``.
        :param field_args:
            Optional dict of keyword arguments merged over the generated
            defaults before the converter is called.
        :param db_session:
            Session used to query choices for relationship properties.
        :returns:
            Whatever the selected converter returns, or ``None`` when the
            property has neither ``columns`` nor ``direction``.
        :raises TypeError:
            If a column property spans more than one column.
        :raises ModelConversionError:
            If no converter matches the column type, or a relationship is
            converted without a ``db_session``.
        """
        if not hasattr(prop, 'columns') and not hasattr(prop, 'direction'):
            return
        elif not hasattr(prop, 'direction') and len(prop.columns) != 1:
            raise TypeError(
                'Do not know how to convert multiple-column properties currently'
            )

        kwargs = {
            'validators': [],
            'filters': [],
            'default': None,
        }

        # Stays None for relationship properties.
        column = None

        if not hasattr(prop, 'direction'):
            column = prop.columns[0]
            # Support sqlalchemy.schema.ColumnDefault, so users can benefit
            # from setting defaults for fields, e.g.:
            #   field = Column(DateTimeField, default=datetime.utcnow)
            default = getattr(column, 'default', None)

            if default is not None:
                # Only unwrap the default if it exposes an 'arg' attribute.
                callable_default = getattr(default, 'arg', None)

                if callable_default is not None:
                    # ColumnDefault(val).arg can also be a plain value; a
                    # callable one is invoked right here with ``None``.
                    if callable(callable_default):
                        default = callable_default(None)
                    else:
                        default = callable_default

            kwargs['default'] = default

            if column.nullable:
                kwargs['validators'].append(validators.Optional())
            else:
                kwargs['validators'].append(validators.Required())

            if self.use_mro:
                types = inspect.getmro(type(column.type))
            else:
                types = [type(column.type)]

            # First pass: match on the module-qualified type name, with the
            # leading 11 characters dropped for names starting 'sqlalchemy'.
            for col_type in types:
                type_string = '%s.%s' % (col_type.__module__, col_type.__name__)
                if type_string.startswith('sqlalchemy'):
                    type_string = type_string[11:]

                if type_string in self.converters:
                    converter = self.converters[type_string]
                    break
            else:
                # Second pass: fall back to the bare class name.
                for col_type in types:
                    if col_type.__name__ in self.converters:
                        converter = self.converters[col_type.__name__]
                        break
                else:
                    raise ModelConversionError('Could not find field converter for %s (%r).' % (prop.key, types[0]))
        else:
            # We have a property with a direction.
            if not db_session:
                raise ModelConversionError("Cannot convert field %s, need DB session." % prop.key)

            foreign_model = prop.mapper.class_

            # Blank is allowed only if every local column of the
            # relationship is nullable.
            nullable = True
            for pair in prop.local_remote_pairs:
                if not pair[0].nullable:
                    nullable = False

            kwargs.update({
                'allow_blank': nullable,
                'query_factory': lambda: db_session.query(foreign_model).all()
            })

            converter = self.converters[prop.direction.name]

        # Caller-supplied arguments override the generated ones.
        if field_args:
            kwargs.update(field_args)

        return converter(
            model=model,
            mapper=mapper,
            prop=prop,
            column=column,
            field_args=kwargs
        )
class ModelConverter(ModelConverterBase):
    """
    Default converter set: each ``@converts`` method below builds the form
    field for the type names or relationship directions it is registered for.
    """

    def __init__(self, extra_converters=None, use_mro=True):
        """Forward ``extra_converters`` and ``use_mro`` to the base class."""
        super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro)

    @classmethod
    def _string_common(cls, column, field_args, **extra):
        """Add a maximum-length validator when the column type has a length."""
        if column.type.length:
            length_check = validators.Length(max=column.type.length)
            field_args['validators'].append(length_check)

    @converts('String', 'Unicode')
    def conv_String(self, field_args, **extra):
        """Single-line text field, length-limited where applicable."""
        self._string_common(field_args=field_args, **extra)
        return f.TextField(**field_args)

    @converts('types.Text', 'UnicodeText', 'types.LargeBinary', 'types.Binary', 'sql.sqltypes.Text')
    def conv_Text(self, field_args, **extra):
        """Multi-line text area, length-limited where applicable."""
        self._string_common(field_args=field_args, **extra)
        return f.TextAreaField(**field_args)

    @converts('Boolean')
    def conv_Boolean(self, field_args, **extra):
        """Boolean field."""
        return f.BooleanField(**field_args)

    @converts('Date')
    def conv_Date(self, field_args, **extra):
        """Date field."""
        return f.DateField(**field_args)

    @converts('DateTime')
    def conv_DateTime(self, field_args, **extra):
        """Date-time field."""
        return f.DateTimeField(**field_args)

    @converts('Enum')
    def conv_Enum(self, column, field_args, **extra):
        """Select field; choices default to the column type's ``enums``."""
        if 'choices' not in field_args:
            field_args['choices'] = [(e, e) for e in column.type.enums]
        return f.SelectField(**field_args)

    @converts('Integer', 'SmallInteger')
    def handle_integer_types(self, column, field_args, **extra):
        """Integer field; types flagged ``unsigned`` also get a ``min=0`` check."""
        if getattr(column.type, 'unsigned', False):
            field_args['validators'].append(validators.NumberRange(min=0))
        return f.IntegerField(**field_args)

    @converts('Numeric', 'Float')
    def handle_decimal_types(self, column, field_args, **extra):
        """Decimal field; ``places`` comes from the type's ``scale`` (default 2)."""
        scale = getattr(column.type, 'scale', 2)
        if scale is not None:
            field_args['places'] = scale
        return f.DecimalField(**field_args)

    @converts('databases.mysql.MSYear', 'dialects.mysql.base.YEAR')
    def conv_MSYear(self, field_args, **extra):
        """Text field restricted to the range 1901-2155."""
        # NOTE(review): NumberRange is attached to a text field here; confirm
        # the range comparison behaves as intended on string input.
        year_range = validators.NumberRange(min=1901, max=2155)
        field_args['validators'].append(year_range)
        return f.TextField(**field_args)

    @converts('databases.postgres.PGInet', 'dialects.postgresql.base.INET')
    def conv_PGInet(self, field_args, **extra):
        """Text field validated as an IP address."""
        field_args.setdefault('label', 'IP Address')
        field_args['validators'].append(validators.IPAddress())
        return f.TextField(**field_args)

    @converts('dialects.postgresql.base.MACADDR')
    def conv_PGMacaddr(self, field_args, **extra):
        """Text field validated as a MAC address."""
        field_args.setdefault('label', 'MAC Address')
        field_args['validators'].append(validators.MacAddress())
        return f.TextField(**field_args)

    @converts('dialects.postgresql.base.UUID')
    def conv_PGUuid(self, field_args, **extra):
        """Text field validated as a UUID."""
        field_args.setdefault('label', 'UUID')
        field_args['validators'].append(validators.UUID())
        return f.TextField(**field_args)

    @converts('MANYTOONE')
    def conv_ManyToOne(self, field_args, **extra):
        """Single-choice query-backed select field."""
        return QuerySelectField(**field_args)

    @converts('MANYTOMANY', 'ONETOMANY')
    def conv_ManyToMany(self, field_args, **extra):
        """Multiple-choice query-backed select field."""
        return QuerySelectMultipleField(**field_args)
def model_fields(model, db_session=None, only=None, exclude=None,
                 field_args=None, converter=None, exclude_pk=False,
                 exclude_fk=False):
    """
    Generate a dictionary of fields for a given SQLAlchemy model.

    See `model_form` docstring for description of parameters.
    """
    mapper = model._sa_class_manager.mapper
    converter = converter or ModelConverter()
    field_args = field_args or {}

    def _is_excluded_key(prop):
        # Only the first column of a column-backed property is inspected.
        if not getattr(prop, 'columns', None):
            return False
        return bool(
            (exclude_fk and prop.columns[0].foreign_keys)
            or (exclude_pk and prop.columns[0].primary_key)
        )

    candidates = [
        (prop.key, prop)
        for prop in mapper.iterate_properties
        if not _is_excluded_key(prop)
    ]

    # ``only`` wins over ``exclude`` when both are given.
    if only:
        candidates = [item for item in candidates if item[0] in only]
    elif exclude:
        candidates = [item for item in candidates if item[0] not in exclude]

    field_dict = {}
    for name, prop in candidates:
        field = converter.convert(
            model, mapper, prop, field_args.get(name), db_session
        )
        # Converters signal "no field for this property" by returning None.
        if field is None:
            continue
        field_dict[name] = field

    return field_dict
def model_form(model, db_session=None, base_class=Form, only=None,
               exclude=None, field_args=None, converter=None, exclude_pk=True,
               exclude_fk=True, type_name=None):
    """
    Create a wtforms Form for a given SQLAlchemy model class::

        from wtforms.ext.sqlalchemy.orm import model_form
        from myapp.models import User
        UserForm = model_form(User)

    :param model:
        A SQLAlchemy mapped model class.
    :param db_session:
        An optional SQLAlchemy Session.
    :param base_class:
        Base form class to extend from. Must be a ``wtforms.Form`` subclass.
    :param only:
        An optional iterable with the property names that should be included in
        the form. Only these properties will have fields.
    :param exclude:
        An optional iterable with the property names that should be excluded
        from the form. All other properties will have fields.
    :param field_args:
        An optional dictionary of field names mapping to keyword arguments used
        to construct each field object.
    :param converter:
        A converter to generate the fields based on the model properties. If
        not set, ``ModelConverter`` is used.
    :param exclude_pk:
        An optional boolean to force primary key exclusion.
    :param exclude_fk:
        An optional boolean to force foreign keys exclusion.
    :param type_name:
        An optional string to set returned type name.
    :raises TypeError:
        If ``model`` has no ``_sa_class_manager`` attribute.
    """
    if not hasattr(model, '_sa_class_manager'):
        raise TypeError('model must be a sqlalchemy mapped model')

    # Default class name is derived from the model, e.g. ``UserForm``.
    if not type_name:
        type_name = str(model.__name__ + 'Form')

    form_attrs = model_fields(
        model,
        db_session=db_session,
        only=only,
        exclude=exclude,
        field_args=field_args,
        converter=converter,
        exclude_pk=exclude_pk,
        exclude_fk=exclude_fk,
    )
    return type(type_name, (base_class,), form_attrs)