from __future__ import absolute_import, print_function, unicode_literals
from builtins import str
from collections import defaultdict, OrderedDict
from MySQLdb import escape_string
from ..refs import get_join_for
from .serializer import Serializer
from .util import parse_date, parse_datetime
# Operator rewrites applied per field type before a term is rendered.
# For integer fields, text-search operators are mapped onto equality or
# range comparisons.
int_op_map = {'equals': 'equals',
              'contains': 'equals',
              'startswith': 'gte',
              'endswith': 'lte'}
# For string fields, range operators are mapped onto prefix/suffix matches.
str_op_map = {'gt': 'startswith',
              'gte': 'startswith',
              'lt': 'endswith',
              'lte': 'endswith'}
# For date/datetime fields, text-search operators become equality/range
# comparisons.
date_op_map = {'contains': 'equals',
               'startswith': 'gte',
               'endswith': 'lte'}
# SQL template for every supported operator; '{}' receives the value.
# Literal percent signs are written as '%%' because the finished SQL is
# presumably run through DB-API "format"-style substitution (query % args),
# which collapses '%%' into a single '%'.
operator_map = {
    'equals': "='{}'",
    'iequals': " LIKE '{}'",
    'lt': "<'{}'",
    'lte': "<='{}'",
    'gt': ">'{}'",
    'gte': ">='{}'",
    'startswith': " LIKE '{}%%'",
    'endswith': " LIKE '%%{}'",
    'contains': " LIKE '%%{}%%'",
    # REGEXP takes the pattern verbatim: LIKE-style '%' wildcards would be
    # treated by the regex engine as literal percent characters.
    'regex': " REGEXP '{}'"}
def _mkcol(tbl, name, alias):
return tbl+'.'+name+' '+tbl+'_'+alias
def _mk_condition(db_column, operator, data):
    """Render a single comparison ``<db_column><operator template><value>``.

    :param db_column: the column expression to compare.
    :param operator: a key of ``operator_map``.
    :param data: the value to compare against. Ints are formatted into the
        template directly; anything else is escaped with
        ``MySQLdb.escape_string`` first.
    :returns: the condition as a SQL string.
    :raises ValueError: if *operator* has no entry in ``operator_map``.
    """
    op = operator_map.get(operator)
    if not op:
        # ValueError subclasses Exception, so existing callers still catch it.
        raise ValueError('unsupported operator:' + str(operator))

    # I would prefer to use a prepared statement, but collecting arguments
    # and passing them back along the string everywhere would be awful design.
    # (Also, I didn't find any API from Django to generate a prepared statement
    # without already executing it, e.g. django.db.connection.execute())
    if isinstance(data, int):
        return db_column + op.format(data)

    # escape_string returns bytes; decode back to text.
    escaped = escape_string(data).decode('utf-8')
    # The templates spell literal percent signs as '%%', i.e. the finished
    # SQL is expected to pass through DB-API "format" substitution. Double
    # any '%' in the value as well; otherwise a search term like '100%'
    # breaks that substitution (or is misread as a '%s' placeholder).
    return db_column + op.format(escaped.replace('%', '%%'))
class SQLSerializer(Serializer):
    """Serializer that renders a query DOM as MySQL SQL text.

    In ``'SELECT'`` mode the root group produces a complete
    ``SELECT DISTINCT ... FROM ... LEFT JOIN ... WHERE (...)`` statement; in
    ``'WHERE'`` mode only the parenthesized condition is produced. The
    ``logical_*`` and ``*_term`` methods each return a SQL fragment; an empty
    string means "no constraint" and is dropped by the enclosing group.
    """

    def __init__(self, model, mode='SELECT', fullnames=None, extra_model=None):
        """
        :param model: the model the query starts from; term names are
            resolved via ``model.get_class_from_fullname`` and joins via
            ``model.get_object_vector_for``.
        :param mode: ``'SELECT'`` for a full statement, ``'WHERE'`` for the
            condition only.
        :param fullnames: optional list of full names to select; when empty
            or omitted, the term names found in the query DOM are used.
        :param extra_model: optional model added to the set of models the
            generated joins must reach.
        :raises AttributeError: if *mode* is not supported.
        """
        modes = 'SELECT', 'WHERE'
        if mode not in modes:
            raise AttributeError('invalid mode: {}. Must be one of {}'.format(mode, modes))
        Serializer.__init__(self)
        self.model = model
        self.mode = mode
        self.fullnames = fullnames
        self.extra_model = extra_model

    def _create_db_column_list(self, dom):
        """Resolve every selected full name to ``(model, table, column)``.

        Uses ``self.fullnames`` when it is non-empty, otherwise
        ``dom.get_term_names()``.
        """
        fullnames = self.fullnames if self.fullnames else dom.get_term_names()
        result = []
        for fullname in fullnames:
            model, alias = self.model.get_class_from_fullname(fullname)
            selector = model.get_selector_from_alias(alias)
            target_model, field = model.get_field_from_selector(selector)
            result.append((target_model, target_model._meta.db_table, field.column))
        return result

    def _create_select(self, fields):
        """Build the ``SELECT DISTINCT ... FROM ... LEFT JOIN ...`` prefix.

        :param fields: sequence of ``(model, table, column)`` or
            ``(model, table, column, alias)`` tuples. Three-element entries
            get the column name as alias, suffixed with ``__<n>`` when the
            same ``table.column`` occurs more than once.
        :returns: the SQL text up to, but not including, ``WHERE``.
        """
        # Number repeated table.column pairs so their aliases stay unique.
        col_numbers = defaultdict(int)
        fullfields = []
        for field in fields:
            table, column = field[1:3]
            key = "%s.%s" % (table, column)
            col_number = col_numbers[key]
            col_numbers[key] += 1
            if len(field) == 3:
                alias = column if col_number == 0 else "%s__%d" % (column, col_number)
                field = (field[0], table, column, alias)
            fullfields.append(field)

        select = 'SELECT DISTINCT ' + _mkcol(fullfields[0][1], fullfields[0][2], fullfields[0][3])
        for _target_model, table, column, alias in fullfields[1:]:
            select += ', ' + _mkcol(table, column, alias)

        # Find the best way to join the tables of all selected models (plus
        # the optional extra model).
        target_models = [f[0] for f in fullfields]
        if self.extra_model:
            target_models.append(self.extra_model)
        vector = self.model.get_object_vector_for(target_models)
        join_path = get_join_for(vector)

        # Create the "table1 LEFT JOIN table2 ON table2.left=right" part.
        select += ' FROM ' + join_path[0][0]
        for table, left, right in join_path[1:]:
            select += ' LEFT JOIN {} ON {}={}'.format(table,
                                                      table + '.' + left,
                                                      right)
        return select

    def logical_root_group(self, root_group, terms):
        """Render the top-level group.

        :returns: ``(sql, [])``; the empty list is presumably the DB-API
            parameter list for the generated SQL.
        """
        fields = self._create_db_column_list(root_group)

        # Create the SELECT part of the query.
        if self.mode == 'SELECT':
            select = self._create_select(fields) + ' WHERE '
        else:
            select = ''

        # Drop empty fragments (e.g. an unparseable date yields ''); joining
        # them would produce invalid SQL such as "()" or "( AND x)".
        terms = [t for t in (terms or []) if t]
        where = ' AND '.join(terms) if terms else '1'
        # NOTE(review): a multi-term join such as "(a) AND (b)" also passes
        # this check and is left unwrapped; it is still valid SQL.
        if where.startswith('(') and where.endswith(')'):
            select += where
        else:
            select += '(' + where + ')'
        return select, []

    def logical_group(self, terms):
        """AND the non-empty *terms* together without parentheses."""
        terms = [t for t in terms if t]
        if not terms:
            return ''
        return ' AND '.join(terms)

    def logical_and(self, terms):
        """AND the non-empty *terms* together, wrapped in parentheses.

        Returns '' (no constraint) when nothing remains; the former '()'
        was not valid SQL.
        """
        terms = [t for t in terms if t]
        if not terms:
            return ''
        return '(' + self.logical_group(terms) + ')'

    def logical_or(self, terms):
        """OR the non-empty *terms* together, wrapped in parentheses."""
        terms = [t for t in terms if t]
        if not terms:
            return ''
        return '(' + ' OR '.join(terms) + ')'

    def logical_not(self, terms):
        """Negate the conjunction of the non-empty *terms*.

        Empty fragments are dropped first so that e.g. ``['']`` does not turn
        into the invalid ``NOT()``.
        """
        terms = [t for t in (terms or []) if t]
        if not terms:
            return ''
        if len(terms) == 1:
            return 'NOT(' + terms[0] + ')'
        return 'NOT ' + self.logical_and(terms)

    def boolean_term(self, db_column, operator, data):
        """Render a boolean term; *data* 'true' (any case) means true."""
        # Compare against 1/0 rather than the strings 'TRUE'/'FALSE':
        # _mk_condition quotes its value, and MySQL (BOOLEAN == TINYINT(1))
        # coerces the string 'TRUE' to 0 when comparing it with a number,
        # so "col='TRUE'" would have matched the *false* rows.
        value = 1 if data.lower() == 'true' else 0
        return _mk_condition(db_column, operator, value)

    def int_term(self, db_column, operator, data):
        """Render an integer term; a non-numeric value matches everything."""
        try:
            value = int(data)
        except (TypeError, ValueError):
            # Unparseable input (including None) places no constraint.
            return '1'
        operator = int_op_map.get(operator, operator)
        return _mk_condition(db_column, operator, value)

    def str_term(self, db_column, operator, data):
        """Render a string term; range operators become prefix/suffix matches."""
        operator = str_op_map.get(operator, operator)
        return _mk_condition(db_column, operator, data)

    def lcstr_term(self, db_column, operator, data):
        """Render a lower-cased string term.

        Equality is rewritten to ``iequals`` (a LIKE comparison) and *data*
        is lower-cased.
        """
        operator = str_op_map.get(operator, operator)
        if operator == 'equals':
            operator = 'iequals'
        return _mk_condition(db_column, operator, data.lower())

    def date_datetime_common(self, db_column, operator, thedatetime):
        """Render a parsed date/datetime; a falsy value yields '' (no constraint)."""
        if not thedatetime:
            return ''
        operator = date_op_map.get(operator, operator)
        return _mk_condition(db_column, operator, thedatetime.isoformat())

    def date_term(self, db_column, operator, data):
        """Parse *data* with ``parse_date`` and render it as a date term."""
        thedate = parse_date(data)
        return self.date_datetime_common(db_column, operator, thedate)

    def datetime_term(self, db_column, operator, data):
        """Parse *data* with ``parse_datetime`` and render it as a datetime term."""
        thedatetime = parse_datetime(data)
        return self.date_datetime_common(db_column, operator, thedatetime)

    def term(self, term_name, operator, data):
        """Render a single search term as a SQL condition.

        :param term_name: full field name, resolved through
            ``self.model.get_class_from_fullname``.
        :param operator: comparison operator; ``'any'`` matches everything.
        :param data: raw search value, passed through the field handler's
            ``prepare()`` before being rendered.
        :raises TypeError: if the handler's ``db_type`` has no renderer.
        """
        if operator == 'any':
            return '1'
        model, alias = self.model.get_class_from_fullname(term_name)
        selector = model.get_selector_from_alias(alias)
        target_model, field = model.get_field_from_selector(selector)
        db_column = target_model._meta.db_table + '.' + field.column
        handler = model.get_field_handler_from_alias(alias)
        type_map = {'BOOL': self.boolean_term,
                    'INT': self.int_term,
                    'STR': self.str_term,
                    'LCSTR': self.lcstr_term,
                    'DATE': self.date_term,
                    'DATETIME': self.datetime_term}
        func = type_map.get(handler.db_type)
        if not func:
            # Report the offending type; this used to reference an undefined
            # name and raised NameError instead of the intended TypeError.
            raise TypeError('unsupported field type: ' + repr(handler.db_type))
        return func(db_column, operator, handler.prepare(data))