Hotel/env/lib/python3.5/site-packages/django/contrib/postgres/search.py

220 lines
8.1 KiB
Python
Raw Permalink Normal View History

from django.db.models import Field, FloatField
from django.db.models.expressions import CombinedExpression, Func, Value
from django.db.models.functions import Coalesce
from django.db.models.lookups import Lookup
class SearchVectorExact(Lookup):
    """``exact`` lookup rendered with the PostgreSQL ``@@`` match operator.

    A right-hand side that is not already an expression (it has no
    ``resolve_expression`` attribute) is wrapped in a ``SearchQuery`` that
    reuses the left-hand side's ``config`` when the left-hand side has one.
    """

    lookup_name = 'exact'

    def process_rhs(self, qn, connection):
        """Wrap a plain right-hand side in ``SearchQuery`` and compile it.

        Returns the ``(sql, params)`` pair produced by the parent class.
        """
        rhs_is_expression = hasattr(self.rhs, 'resolve_expression')
        if not rhs_is_expression:
            # Fall back to ``None`` when the left-hand side has no config.
            self.rhs = SearchQuery(
                self.rhs,
                config=getattr(self.lhs, 'config', None),
            )
        compiled_sql, compiled_params = super().process_rhs(qn, connection)
        return compiled_sql, compiled_params

    def as_sql(self, qn, connection):
        """Return ``(sql, params)`` for ``<lhs> @@ <rhs> = true``.

        Left-hand parameters come first, followed by right-hand ones, in the
        same order as their placeholders appear in the SQL.
        """
        lhs_sql, lhs_params = self.process_lhs(qn, connection)
        rhs_sql, rhs_params = self.process_rhs(qn, connection)
        sql = '%s @@ %s = true' % (lhs_sql, rhs_sql)
        return sql, lhs_params + rhs_params
class SearchVectorField(Field):
    """Model field whose database column type is ``tsvector``."""

    def db_type(self, connection):
        """Return the column type name; ``connection`` is not consulted."""
        return 'tsvector'
class SearchQueryField(Field):
    """Model field whose database column type is ``tsquery``."""

    def db_type(self, connection):
        """Return the column type name; ``connection`` is not consulted."""
        return 'tsquery'
class SearchVectorCombinable:
    """Mixin that lets search vectors be concatenated with ``||``.

    Classes using this mixin are expected to expose a ``config`` attribute;
    two operands may only be combined when their configs compare equal.
    """

    ADD = '||'

    def _combine(self, other, connector, reversed):
        """Build a ``CombinedSearchVector`` from ``self`` and ``other``.

        Args:
            other: The other operand; must be a ``SearchVectorCombinable``.
            connector: The SQL operator placed between the two operands.
            reversed: When true, ``other`` becomes the left-hand operand.

        Raises:
            TypeError: If ``other`` is not a ``SearchVectorCombinable``, or
                if the two operands have different configs.
        """
        # Report the two failure modes separately so that a config mismatch
        # between two genuine search vectors is not misdiagnosed as a type
        # problem.
        if not isinstance(other, SearchVectorCombinable):
            raise TypeError(
                'SearchVector can only be combined with other SearchVectors, '
                'got {}.'.format(type(other))
            )
        if not self.config == other.config:
            raise TypeError("SearchVector configs don't match.")
        if reversed:
            return CombinedSearchVector(other, connector, self, self.config)
        return CombinedSearchVector(self, connector, other, self.config)
class SearchVector(SearchVectorCombinable, Func):
    """``to_tsvector(...)`` expression over one or more source expressions.

    Source expressions are joined with a single-space string concatenation.
    The optional ``config`` and ``weight`` values are read from the extra
    keyword arguments passed to the constructor.
    """

    function = 'to_tsvector'
    arg_joiner = " || ' ' || "
    output_field = SearchVectorField()
    config = None

    def __init__(self, *expressions, **extra):
        """Wrap each source expression in ``Coalesce(expr, '')``.

        ``config`` falls back to the class-level default when not supplied.
        A ``weight`` that is not already an expression is wrapped in
        ``Value``.
        """
        super().__init__(*expressions, **extra)
        coalesced = []
        for source in self.source_expressions:
            coalesced.append(Coalesce(source, Value('')))
        self.source_expressions = coalesced
        self.config = self.extra.get('config', self.config)
        weight = self.extra.get('weight')
        needs_wrapping = weight is not None and not hasattr(weight, 'resolve_expression')
        self.weight = Value(weight) if needs_wrapping else weight

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """Resolve this expression and, when set, its ``config`` as well."""
        resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        if not self.config:
            return resolved
        config_expr = self.config
        if not hasattr(config_expr, 'resolve_expression'):
            # A plain config value must become an expression before resolving.
            config_expr = Value(config_expr)
        resolved.config = config_expr.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return resolved

    def as_sql(self, compiler, connection, function=None, template=None):
        """Return ``(sql, params)``, adding config and weight when present.

        Parameters are ordered config, then source expressions, then weight,
        matching where their placeholders appear in the generated SQL.
        """
        config_params = []
        if template is None and self.config:
            config_sql, config_params = compiler.compile(self.config)
            # The template is %-interpolated later, so literal percent signs
            # in the compiled config SQL have to be doubled.
            template = "%(function)s({}::regconfig, %(expressions)s)".format(config_sql.replace('%', '%%'))
        elif template is None:
            template = self.template
        sql, params = super().as_sql(compiler, connection, function=function, template=template)
        weight_params = []
        if self.weight:
            weight_sql, weight_params = compiler.compile(self.weight)
            sql = 'setweight({}, {})'.format(sql, weight_sql)
        return sql, config_params + params + weight_params
class CombinedSearchVector(SearchVectorCombinable, CombinedExpression):
    """Two search vectors joined by a connector, carrying a shared config."""

    def __init__(self, lhs, connector, rhs, config, output_field=None):
        """Store ``config`` and delegate the rest to ``CombinedExpression``."""
        self.config = config
        super().__init__(lhs, connector, rhs, output_field=output_field)
class SearchQueryCombinable:
    """Mixin providing ``&`` and ``|`` combination for search queries.

    Classes using this mixin are expected to expose a ``config`` attribute;
    two operands may only be combined when their configs compare equal.
    """

    BITAND = '&&'
    BITOR = '||'

    def _combine(self, other, connector, reversed):
        """Build a ``CombinedSearchQuery`` from ``self`` and ``other``.

        Args:
            other: The other operand; must be a ``SearchQueryCombinable``.
            connector: The SQL operator placed between the two operands.
            reversed: When true, ``other`` becomes the left-hand operand.

        Raises:
            TypeError: If ``other`` is not a ``SearchQueryCombinable``, or if
                the two operands have different configs.
        """
        if not isinstance(other, SearchQueryCombinable):
            message = (
                'SearchQuery can only be combined with other SearchQuerys, '
                'got {}.'
            )
            raise TypeError(message.format(type(other)))
        if not self.config == other.config:
            raise TypeError("SearchQuery configs don't match.")
        left, right = (other, self) if reversed else (self, other)
        return CombinedSearchQuery(left, connector, right, self.config)

    # Combinable leaves the operators below unimplemented to avoid confusion
    # with Q objects. Here they are deliberately (ab)used for logical
    # combination so usage stays consistent with the rest of Django.
    def __or__(self, other):
        """Return ``self || other``."""
        return self._combine(other, self.BITOR, False)

    def __ror__(self, other):
        """Return ``other || self``."""
        return self._combine(other, self.BITOR, True)

    def __and__(self, other):
        """Return ``self && other``."""
        return self._combine(other, self.BITAND, False)

    def __rand__(self, other):
        """Return ``other && self``."""
        return self._combine(other, self.BITAND, True)
class SearchQuery(SearchQueryCombinable, Value):
    """``plainto_tsquery(...)`` expression built from a raw search value.

    Supports an optional text-search ``config`` and negation via ``~``.
    """

    output_field = SearchQueryField()

    def __init__(self, value, output_field=None, *, config=None, invert=False):
        """Record ``config`` and ``invert`` before initialising the value."""
        self.config = config
        self.invert = invert
        super().__init__(value, output_field=output_field)

    def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
        """Resolve this expression and, when set, its ``config`` as well."""
        resolved = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)
        if not self.config:
            return resolved
        config_expr = self.config
        if not hasattr(config_expr, 'resolve_expression'):
            # A plain config value must become an expression before resolving.
            config_expr = Value(config_expr)
        resolved.config = config_expr.resolve_expression(query, allow_joins, reuse, summarize, for_save)
        return resolved

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` for the query.

        With a config, its parameters precede the search value. An inverted
        query is wrapped in ``!!(...)``.
        """
        if self.config:
            config_sql, config_params = compiler.compile(self.config)
            sql = 'plainto_tsquery({}::regconfig, %s)'.format(config_sql)
            params = config_params + [self.value]
        else:
            sql = 'plainto_tsquery(%s)'
            params = [self.value]
        if self.invert:
            sql = '!!({})'.format(sql)
        return sql, params

    def _combine(self, other, connector, reversed):
        """Combine as the mixin does, forcing a ``SearchQueryField`` output."""
        result = super()._combine(other, connector, reversed)
        result.output_field = SearchQueryField()
        return result

    def __invert__(self):
        """Return a copy with the same value and config, negation toggled."""
        flipped = not self.invert
        return type(self)(self.value, config=self.config, invert=flipped)
class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
    """Two search queries joined by a connector, carrying a shared config."""

    def __init__(self, lhs, connector, rhs, config, output_field=None):
        """Store ``config`` and delegate the rest to ``CombinedExpression``."""
        self.config = config
        super().__init__(lhs, connector, rhs, output_field=output_field)
class SearchRank(Func):
    """``ts_rank(...)`` expression ranking a search vector against a query.

    An optional ``weights`` keyword argument is rendered as the first
    argument of the function call.
    """

    function = 'ts_rank'
    output_field = FloatField()

    def __init__(self, vector, query, **extra):
        """Coerce plain arguments into expressions before initialising.

        A non-expression ``vector`` becomes a ``SearchVector``, a
        non-expression ``query`` becomes a ``SearchQuery``, and a
        non-expression ``weights`` value is wrapped in ``Value``.
        """
        if not hasattr(vector, 'resolve_expression'):
            vector = SearchVector(vector)
        if not hasattr(query, 'resolve_expression'):
            query = SearchQuery(query)
        raw_weights = extra.get('weights')
        is_plain = raw_weights is not None and not hasattr(raw_weights, 'resolve_expression')
        self.weights = Value(raw_weights) if is_plain else raw_weights
        super().__init__(vector, query, **extra)

    def as_sql(self, compiler, connection, function=None, template=None):
        """Return ``(sql, params)``, with weights first when supplied.

        The weights template is only used when the caller passed no
        ``template`` of its own.
        """
        weight_params = []
        context = {}
        use_weights = template is None and self.extra.get('weights') and self.weights
        if use_weights:
            template = '%(function)s(%(weights)s, %(expressions)s)'
            weight_sql, weight_params = compiler.compile(self.weights)
            context['weights'] = weight_sql
        sql, params = super().as_sql(
            compiler, connection,
            function=function, template=template, **context
        )
        # The weights placeholder comes before the expressions in the SQL.
        return sql, weight_params + params
# Register SearchVectorExact (lookup_name 'exact') on SearchVectorField so
# that exact lookups on the field are rendered with the ``@@`` match operator.
SearchVectorField.register_lookup(SearchVectorExact)
class TrigramBase(Func):
    """Base for two-argument trigram functions that return a float."""

    output_field = FloatField()

    def __init__(self, expression, string, **extra):
        """Wrap a non-expression ``string`` in ``Value`` before initialising."""
        string_is_expression = hasattr(string, 'resolve_expression')
        comparand = string if string_is_expression else Value(string)
        super().__init__(expression, comparand, **extra)
class TrigramSimilarity(TrigramBase):
    """Trigram expression rendered with the ``SIMILARITY`` SQL function."""

    function = 'SIMILARITY'
class TrigramDistance(TrigramBase):
    """Trigram expression joining its two arguments with ``<->``.

    The function name is empty, so only the joined arguments are rendered
    (presumably inside the parentheses of the inherited template).
    """

    arg_joiner = ' <-> '
    function = ''