# Source code for the SqlTable module (PostgreSQL-backed data tables).

"""
Support for example tables wrapping data stored on a PostgreSQL server.
"""
import contextlib
import functools
import logging
import threading
import warnings
from contextlib import contextmanager
from itertools import islice
from time import strftime

import numpy as np
from import (
    Table, Domain, Value, Instance, filter)
from import filter as sql_filter
from import Backend
from import TableDesc, BackendError

# Tables whose approx_len() exceeds this are sampled (sample_time) before
# statistics, distributions or contingencies are computed.
LARGE_TABLE = 100000
sql_log = logging.getLogger('sql_log')
sql_log.debug(f"Logging started: {strftime('%Y-%m-%d %H:%M:%S')}")

class SqlTable(Table):
    """Proxy exposing a database table or SELECT query as an Orange Table.

    Rows are read through ``self.backend`` on demand. The numpy views
    (``X``, ``Y``, ``metas``, ``W``, ``ids``) are only materialized when first
    accessed (see :meth:`download_data`). Row filters accumulate in
    ``row_filters`` and are rendered into every query built by
    :meth:`_sql_query`.
    """
    table_name = None
    domain = None
    row_filters = ()

    def __new__(cls, *args, **kwargs):
        # We do not (yet) need the magic of the Table.__new__, so we call it
        # with no parameters.
        return super().__new__(cls)

    def __init__(
            self, connection_params, table_or_sql, backend=None,
            type_hints=None, inspect_values=False):
        """
        Create a new proxy for sql table.

        To create a new SqlTable, specify the connection parameters
        for psycopg2 and the name of the table/sql query used to fetch
        the data.

            table = SqlTable('database_name', 'table_name')
            table = SqlTable('database_name', 'SELECT * FROM table')

        For complex configurations, dictionary of connection parameters can
        be used instead of the database name. For documentation about
        connection parameters, see the PostgreSQL libpq documentation on
        connection parameter keywords.

        Data domain is inferred from the columns of the table/query.

        The (very quick) default setting is to treat all numeric columns as
        continuous variables and everything else as strings and placed among
        meta attributes.

        If inspect_values parameter is set to True, all column values are
        inspected and int/string columns with less than 21 values are
        interpreted as discrete features.

        Domains can be constructed by the caller and passed in
        type_hints parameter. Variables from the domain are used for
        the columns with the matching names; for columns without the matching
        name in the domain, types are inferred as described above.

        If ``backend`` is None, each class from
        ``Backend.available_backends()`` is tried in turn and the first one
        that connects without raising ``BackendError`` is used; ValueError is
        raised when none of them connects.
        """
        if isinstance(connection_params, str):
            connection_params = dict(database=connection_params)

        if backend is None:
            for backend in Backend.available_backends():
                try:
                    self.backend = backend(connection_params)
                    break
                except BackendError:
                    pass
            else:
                raise ValueError("No backend could connect to server")
        else:
            self.backend = backend(connection_params)

        if table_or_sql is not None:
            if isinstance(table_or_sql, TableDesc):
                table = table_or_sql.sql
            elif "select" in table_or_sql.lower():
                # Treat the string as a query and wrap it as a subselect.
                table = "(%s) as my_table" % table_or_sql.strip("; ")
            else:
                table = self.backend.quote_identifier(table_or_sql)
            self.table_name = table
            self.domain = self.get_domain(type_hints, inspect_values)
            self.name = table

    @property
    def connection_params(self):
        """Deprecated alias for ``self.backend.connection_params``."""
        warnings.warn("Use backend.connection_params", DeprecationWarning)
        return self.backend.connection_params

    def get_domain(self, type_hints=None, inspect_values=False):
        """Build a Domain from the columns reported by the backend.

        String variables always go to metas. Other variables go to class_vars
        or metas when they appear there in ``type_hints``; otherwise they
        become attributes. With ``inspect_values`` the backend is given the
        table name so it can inspect column values.
        """
        table_name = self.table_name
        if type_hints is None:
            type_hints = Domain([])

        inspect_table = table_name if inspect_values else None

        attrs, class_vars, metas = [], [], []
        for field_name, *field_metadata in self.backend.get_fields(table_name):
            var = self.backend.create_variable(field_name, field_metadata,
                                               type_hints, inspect_table)

            if var.is_string:
                metas.append(var)
            elif var in type_hints.class_vars:
                class_vars.append(var)
            elif var in type_hints.metas:
                metas.append(var)
            else:
                attrs.append(var)

        return Domain(attrs, class_vars, metas)

    def __getitem__(self, key):
        """ Indexing of SqlTable is performed in the following way:

        If a single row is requested, it is fetched from the database and
        returned as an Instance. A (row, column) pair with an integer row
        returns a single Value.

        A new SqlTable with appropriate filters is constructed and returned
        otherwise.
        """
        if isinstance(key, int):
            # one row
            return self._fetch_row(key)

        if not isinstance(key, tuple):
            # row filter
            key = (key, Ellipsis)

        if len(key) != 2:
            raise IndexError("Table indices must be one- or two-dimensional")

        row_idx, col_idx = key
        if isinstance(row_idx, int):
            try:
                col_idx = self.domain.index(col_idx)
                var = self.domain[col_idx]
                return Value(
                    var,
                    next(self._query([var], rows=[row_idx]))[0]
                )
            except TypeError:
                # col_idx does not denote a single column: fall through and
                # build a column-restricted table below.
                pass

        elif not (row_idx is Ellipsis or row_idx == slice(None)):
            # TODO if row_idx specify multiple rows, one of the following must
            # happen
            #  - the new table remembers which rows are selected (implement
            #     table.limit_rows and whatever else is necessary)
            #  - return an ordinary (non-SQL) Table
            #  - raise an exception
            raise NotImplementedError("Row indices must be integers.")

        # multiple rows OR single row but multiple columns:
        # construct a new table
        table = self.copy()
        table.domain = self.domain.select_columns(col_idx)
        # table.limit_rows(row_idx)
        return table

    # NOTE(review): lru_cache on a method keys on (and keeps alive) ``self``,
    # and nothing visible here ever clears it, so cached rows persist for the
    # cache's lifetime.
    @functools.lru_cache(maxsize=128)
    def _fetch_row(self, row_index):
        """Fetch row ``row_index`` as an Instance; raise IndexError if absent."""
        attributes = self.domain.variables + self.domain.metas
        rows = [row_index]
        values = list(self._query(attributes, rows=rows))
        if not values:
            raise IndexError('Could not retrieve row {} from table {}'.format(
                row_index, self.name))
        return Instance(self.domain, values[0])

    def __iter__(self):
        """ Iterating through the rows executes the query using a cursor and
        then yields resulting rows as Instances as they are requested.
        """
        attributes = self.domain.variables + self.domain.metas

        for row in self._query(attributes):
            yield Instance(self.domain, row)

    def _query(self, attributes=None, filters=(), rows=None):
        """Run a SELECT over this table and yield result rows one at a time.

        ``attributes`` selects the columns (each must provide ``to_sql``);
        None selects ``*``. ``filters`` are extra filter objects rendered with
        ``to_sql`` and combined with ``self.row_filters`` by ``_sql_query``.
        ``rows`` is a slice or an iterable of row indices and is translated
        into OFFSET/LIMIT.
        """
        if attributes is not None:
            fields = []
            for attr in attributes:
                assert hasattr(attr, 'to_sql'), \
                    "Cannot use ordinary attributes with sql backend"
                field_str = '(%s) AS "%s"' % (attr.to_sql(), attr.name)
                fields.append(field_str)
            if not fields:
                raise ValueError("No fields selected.")
        else:
            fields = ["*"]

        filters = [f.to_sql() for f in filters]

        offset = limit = None
        if rows is not None:
            if isinstance(rows, slice):
                offset = rows.start or 0
                if rows.stop is not None:
                    limit = rows.stop - offset
            else:
                rows = list(rows)
                offset, stop = min(rows), max(rows)
                limit = stop - offset + 1

        # TODO: this returns all rows between min(rows) and max(rows): fix!
        query = self._sql_query(fields, filters, offset=offset, limit=limit)
        with self.backend.execute_sql_query(query) as cur:
            # fetchone() returns None once the cursor is exhausted.
            yield from iter(cur.fetchone, None)

    def copy(self):
        """Return a copy of the SqlTable"""
        table = SqlTable.__new__(SqlTable)
        table.backend = self.backend
        table.domain = self.domain
        table.row_filters = self.row_filters
        table.table_name = self.table_name
        table.name = self.name
        return table

    def __bool__(self):
        """Return True if the SqlTable is not empty."""
        query = self._sql_query(["1"], limit=1)
        with self.backend.execute_sql_query(query) as cur:
            return cur.fetchone() is not None

    _cached__len__ = None

    def __len__(self):
        """
        Return number of rows in the table. The value is cached so it is
        computed only the first time the length is requested.
        """
        if self._cached__len__ is None:
            return self._count_rows()
        return self._cached__len__

    def _count_rows(self):
        """Run COUNT(*) over the filtered table and cache the result."""
        query = self._sql_query(["COUNT(*)"])
        with self.backend.execute_sql_query(query) as cur:
            self._cached__len__ = cur.fetchone()[0]
        return self._cached__len__

    def approx_len(self, get_exact=False):
        """Return the row count, approximated when an exact one is not cached.

        The backend's ``count_approx`` is tried first. With ``get_exact`` an
        exact count is additionally started on a background thread (``len``
        caches its result). Falls back to ``len(self)`` when the backend
        provides no estimate.
        """
        if self._cached__len__ is not None:
            return self._cached__len__
        approx_len = None
        try:
            query = self._sql_query(["*"])
            approx_len = self.backend.count_approx(query)
            if get_exact:
                threading.Thread(target=len, args=(self,)).start()
        except NotImplementedError:
            pass
        if approx_len is None:
            approx_len = len(self)
        return approx_len

    _X = None
    _Y = None
    _metas = None
    _W = None
    _ids = None

    def download_data(self, limit=None, partial=False):
        """Download SQL data and store it in memory as numpy matrices.

        Raises ValueError when ``limit`` is set, ``partial`` is False and the
        approximate length exceeds ``limit``. At most ``limit`` rows are read.
        """
        if limit and not partial and self.approx_len() > limit:
            raise ValueError("Too many rows to download the data into memory.")
        X = [np.empty((0, len(self.domain.attributes)))]
        Y = [np.empty((0, len(self.domain.class_vars)))]
        metas = [np.empty((0, len(self.domain.metas)))]
        for row in islice(self, limit):
            X.append(row._x)
            Y.append(row._y)
            metas.append(row._metas)
        self._X = np.vstack(X).astype(np.float64)
        self._Y = np.vstack(Y).astype(np.float64)
        self._metas = np.vstack(metas).astype(object)
        self._W = np.empty((self._X.shape[0], 0))
        self._init_ids(self)
        # The downloaded row count is the true length unless a partial
        # download may have been cut off by the limit.
        if not partial or limit and self._X.shape[0] < limit:
            self._cached__len__ = self._X.shape[0]

    # NOTE(review): AUTO_DL_LIMIT is not defined anywhere in the visible
    # module; confirm it is provided at module level.
    @property
    def X(self):
        """Numpy array with attribute values."""
        if self._X is None:
            self.download_data(AUTO_DL_LIMIT)
        return self._X

    @property
    def Y(self):
        """Numpy array with class values."""
        if self._Y is None:
            self.download_data(AUTO_DL_LIMIT)
        return self._Y

    @property
    def metas(self):
        """Numpy array with meta attribute values."""
        if self._metas is None:
            self.download_data(AUTO_DL_LIMIT)
        return self._metas

    @property
    def W(self):
        """Numpy array with instance weights (zero columns; see has_weights)."""
        if self._W is None:
            self.download_data(AUTO_DL_LIMIT)
        return self._W

    @property
    def ids(self):
        """Numpy array with instance ids."""
        if self._ids is None:
            self.download_data(AUTO_DL_LIMIT)
        return self._ids

    @ids.setter
    def ids(self, value):
        self._ids = value

    @ids.deleter
    def ids(self):
        del self._ids

    def has_weights(self):
        """SqlTable never carries instance weights."""
        return False

    def _compute_basic_stats(self, columns=None,
                             include_metas=False, compute_variance=False):
        """Compute per-column statistics via SQL, sampling large tables.

        ``compute_variance`` is unused by this implementation.
        """
        # NOTE(review): DEFAULT_SAMPLE_TIME is not defined anywhere in the
        # visible module; confirm it is provided at module level.
        table = self
        if table.approx_len() > LARGE_TABLE:
            table = table.sample_time(DEFAULT_SAMPLE_TIME)

        if columns is not None:
            columns = [table.domain[col] for col in columns]
        else:
            columns = table.domain.variables
            if include_metas:
                columns += table.domain.metas
        return table._get_stats(columns)

    def _get_stats(self, columns):
        """Return one stats tuple per column from a single aggregate query.

        Continuous columns yield the six values produced by CONTINUOUS_STATS.
        Other columns yield four Nones followed by the two DISCRETE_STATS
        values.
        """
        columns = [(c.to_sql(), c.is_continuous) for c in columns]
        sql_fields = []
        for field_name, continuous in columns:
            stats = self.CONTINUOUS_STATS if continuous else self.DISCRETE_STATS
            sql_fields.append(stats % dict(field_name=field_name))
        query = self._sql_query(sql_fields)
        with self.backend.execute_sql_query(query) as cur:
            results = cur.fetchone()
        stats = []
        i = 0
        for _, continuous in columns:
            if continuous:
                stats.append(results[i:i+6])
                i += 6
            else:
                stats.append((None,) * 4 + results[i:i+2])
                i += 2
        return stats

    def _compute_distributions(self, columns=None):
        """Compute value distributions via SQL, sampling large tables."""
        table = self
        if table.approx_len() > LARGE_TABLE:
            table = table.sample_time(DEFAULT_SAMPLE_TIME)

        if columns is not None:
            columns = [table.domain[col] for col in columns]
        else:
            columns = table.domain.variables
        return table._get_distributions(columns)

    def _get_distributions(self, columns):
        """Return (distribution, []) pairs from GROUP BY/COUNT queries.

        Rows where the column is NULL are excluded. Continuous columns get
        the transposed (value, count) array; other columns get only the
        counts.
        """
        dists = []
        for col in columns:
            field_name = col.to_sql()
            fields = field_name, "COUNT(%s)" % field_name
            query = self._sql_query(fields,
                                    filters=['%s IS NOT NULL' % field_name],
                                    group_by=[field_name],
                                    order_by=[field_name])
            with self.backend.execute_sql_query(query) as cur:
                dist = np.array(cur.fetchall())
                if col.is_continuous:
                    dists.append((dist.T, []))
                else:
                    dists.append((dist[:, 1].T, []))
        return dists

    def _compute_contingency(self, col_vars=None, row_var=None):
        """Compute the contingency of one column against a discrete row variable.

        Only a single column is supported, and ``row_var`` is required.
        Large tables are sampled first.
        """
        table = self
        if table.approx_len() > LARGE_TABLE:
            table = table.sample_time(DEFAULT_SAMPLE_TIME)

        if col_vars is None:
            col_vars = range(len(table.domain.variables))
        if len(col_vars) != 1:
            raise NotImplementedError("Contingency for multiple columns "
                                      "has not yet been implemented.")
        if row_var is None:
            raise NotImplementedError("Defaults have not been implemented yet")

        row = table.domain[row_var]
        if not row.is_discrete:
            raise TypeError("Row variable must be discrete")

        columns = [table.domain[var] for var in col_vars]

        if any(not (var.is_continuous or var.is_discrete)
               for var in columns):
            raise ValueError("contingency can be computed only for discrete "
                             "and continuous values")

        row_field = row.to_sql()

        all_contingencies = [None] * len(columns)
        for i, column in enumerate(columns):
            column_field = column.to_sql()
            fields = [row_field, column_field, "COUNT(%s)" % column_field]
            group_by = [row_field, column_field]
            order_by = [column_field]
            filters = ['%s IS NOT NULL' % f
                       for f in (row_field, column_field)]
            query = table._sql_query(fields, filters=filters,
                                     group_by=group_by, order_by=order_by)
            with table.backend.execute_sql_query(query) as cur:
                data = list(cur.fetchall())
                if column.is_continuous:
                    all_contingencies[i] = \
                        (table._continuous_contingencies(data, row), [], [], 0)
                else:
                    all_contingencies[i] = \
                        (table._discrete_contingencies(data, row, column),
                         [], [], 0)
        return all_contingencies

    def _continuous_contingencies(self, data, row):
        """Convert rows ordered by column value into (values, counts) arrays."""
        values = np.zeros(len(data))
        counts = np.zeros((len(row.values), len(data)))
        last = None
        i = -1
        for row_value, column_value, count in data:
            # ``data`` is ordered by column value, so a new value starts a
            # new column in ``counts``.
            if column_value != last:
                i += 1
                last = column_value
                values[i] = column_value
            counts[row.to_val(row_value), i] += count
        return (values, counts)

    def _discrete_contingencies(self, data, row, column):
        """Convert (row value, column value, count) rows into a 2-D array."""
        conts = np.zeros((len(row.values), len(column.values)))
        for row_value, col_value, count in data:
            row_index = row.to_val(row_value)
            col_index = column.to_val(col_value)
            conts[row_index, col_index] = count
        return conts

    def X_density(self):
        return self.DENSE

    def Y_density(self):
        return self.DENSE

    def metas_density(self):
        return self.DENSE

    # Filters
    def _filter_is_defined(self, columns=None, negate=False):
        """Return a copy keeping rows whose given columns are (not) all defined."""
        if columns is None:
            columns = range(len(self.domain.variables))
        columns = [self.domain[i].to_sql() for i in columns]
        t2 = self.copy()
        t2.row_filters += (sql_filter.IsDefinedSql(columns, negate),)
        return t2

    def _filter_has_class(self, negate=False):
        """Return a copy keeping rows whose class values are (not) defined."""
        columns = [c.to_sql() for c in self.domain.class_vars]
        t2 = self.copy()
        t2.row_filters += (sql_filter.IsDefinedSql(columns, negate),)
        return t2

    def _filter_same_value(self, column, value, negate=False):
        """Return a copy keeping rows where ``column`` equals ``value``."""
        var = self.domain[column]
        if value is not None and var.is_discrete:
            # Discrete values are normalized and quoted as SQL literals;
            # None and non-discrete values are passed through unchanged.
            value = "'%s'" % var.repr_val(var.to_val(value))
        t2 = self.copy()
        t2.row_filters += \
            (sql_filter.SameValueSql(var.to_sql(), value, negate),)
        return t2

    def _filter_values(self, f):
        """Translate an Orange Values filter into SQL filters on a copy.

        Raises ValueError for condition types without an SQL counterpart.
        """
        conditions = []
        for cond in f.conditions:
            var = self.domain[cond.column]
            if isinstance(cond, filter.FilterDiscrete):
                if cond.values is None:
                    values = None
                else:
                    values = ["'%s'" % var.repr_val(var.to_val(v))
                              for v in cond.values]
                new_condition = sql_filter.FilterDiscreteSql(
                    column=var.to_sql(),
                    values=values)
            elif isinstance(cond, filter.FilterContinuous):
                new_condition = sql_filter.FilterContinuousSql(
                    position=var.to_sql(),
                    oper=cond.oper,
                    ref=cond.ref,
                    max=cond.max)
            elif isinstance(cond, filter.FilterString):
                new_condition = sql_filter.FilterString(
                    var.to_sql(),
                    oper=cond.oper,
                    ref=cond.ref,
                    max=cond.max,
                    case_sensitive=cond.case_sensitive,
                )
            elif isinstance(cond, filter.FilterStringList):
                new_condition = sql_filter.FilterStringList(
                    column=var.to_sql(),
                    values=cond.values,
                    case_sensitive=cond.case_sensitive)
            else:
                raise ValueError('Invalid condition %s' % type(cond))
            conditions.append(new_condition)
        t2 = self.copy()
        t2.row_filters += (sql_filter.ValuesSql(conditions=conditions,
                                                conjunction=f.conjunction,
                                                negate=f.negate),)
        return t2

    @classmethod
    def from_table(cls, domain, source, row_indices=...):
        """Return a copy of ``source`` with ``domain``; row selection unsupported."""
        # pylint: disable=unused-argument
        assert row_indices is ...

        table = source.copy()
        table.domain = domain
        return table

    # sql queries
    def _sql_query(self, fields, filters=(),
                   group_by=None, order_by=None, offset=None, limit=None,
                   use_time_sample=None):
        """Build a query via the backend, prepending this table's row filters."""
        row_filters = [f.to_sql() for f in self.row_filters]
        row_filters.extend(filters)
        return self.backend.create_sql_query(
            self.table_name, fields, row_filters, group_by, order_by,
            offset, limit, use_time_sample)

    # Two values per column: number of NULLs, number of non-NULLs.
    DISCRETE_STATS = "SUM(CASE TRUE WHEN %(field_name)s IS NULL THEN 1 " \
                     "ELSE 0 END), " \
                     "SUM(CASE TRUE WHEN %(field_name)s IS NULL THEN 0 " \
                     "ELSE 1 END)"
    # Six values per column: min, max, avg, stddev, then DISCRETE_STATS.
    CONTINUOUS_STATS = "MIN(%(field_name)s)::double precision, " \
                       "MAX(%(field_name)s)::double precision, " \
                       "AVG(%(field_name)s)::double precision, " \
                       "STDDEV(%(field_name)s)::double precision, " \
                       + DISCRETE_STATS

    def sample_percentage(self, percentage, no_cache=False):
        """Return a table sampled with TABLESAMPLE system (self if >= 100%)."""
        if percentage >= 100:
            return self
        return self._sample('system', percentage,
                            no_cache=no_cache)

    def sample_time(self, time_in_seconds, no_cache=False):
        """Return a table sampled with TABLESAMPLE system_time (milliseconds)."""
        return self._sample('system_time', int(time_in_seconds * 1000),
                            no_cache=no_cache)

    def _sample(self, method, parameter, no_cache=False):
        """Materialize (or reuse) a sampled copy of this table.

        The sample is stored in a table named ``__<table>_<method>_<param>``
        and reused if it already exists, unless ``no_cache`` drops it first.
        Returns a copy of this SqlTable that reads from the sample table.
        """
        # the module is optional, but this function is not called if it's not installed
        # Imported for its availability check only; the name is not used below.
        # pylint: disable=import-error
        import psycopg2
        if "," in self.table_name:
            raise NotImplementedError("Sampling of complex queries is not supported")

        parameter = str(parameter)
        suffix = parameter.replace('.', '_').replace('-', '_')

        if "." in self.table_name:
            schema, name = self.table_name.split(".")
            sample_name = '__%s_%s_%s' % (
                self.backend.unquote_identifier(name), method, suffix)
            sample_table_q = ".".join(
                [schema, self.backend.quote_identifier(sample_name)])
        else:
            sample_name = '__%s_%s_%s' % (
                self.backend.unquote_identifier(self.table_name), method,
                suffix)
            sample_table_q = self.backend.quote_identifier(sample_name)

        create = False
        try:
            # Probe for an existing sample table; BackendError means absent.
            query = "SELECT * FROM " + sample_table_q + " LIMIT 0;"
            with self.backend.execute_sql_query(query):
                pass

            if no_cache:
                query = "DROP TABLE " + sample_table_q
                with self.backend.execute_sql_query(query):
                    pass
                create = True

        except BackendError:
            create = True

        if create:
            with self.backend.execute_sql_query(" ".join([
                    "CREATE TABLE", sample_table_q, "AS",
                    "SELECT * FROM", self.table_name,
                    "TABLESAMPLE", method, "(", parameter, ")"])):
                pass

        sampled_table = self.copy()
        sampled_table.table_name = sample_table_q
        # The space keeps ANALYZE separate from an unquoted schema prefix.
        with sampled_table.backend.execute_sql_query('ANALYZE ' + sample_table_q):
            pass
        return sampled_table

    @contextmanager
    def _execute_sql_query(self, query, param=None):
        """Deprecated: use ``self.backend.execute_sql_query`` directly."""
        warnings.warn("Use backend.execute_sql_query", DeprecationWarning)
        with self.backend.execute_sql_query(query, param) as cur:
            yield cur

    def checksum(self, include_metas=True):
        """Checksums are not computed for SQL tables; always NaN."""
        return np.nan

    def __get_nan_frequency(self, columns):
        """Return the fraction of NULL cells in ``columns``.

        Returns 0 when there are no cells (no columns or no rows), and None
        when the backend rejects the query.
        """
        if not columns:
            return 0
        try:
            query = self._sql_query([" + ".join([f"COUNT(*) - COUNT({col.to_sql()})"
                                                 for col in columns])])
            with self.backend.execute_sql_query(query) as cur:
                n_missing = cur.fetchone()[0]
            n_cells = len(self) * len(columns)
            # An empty table has no cells; avoid dividing by zero.
            return n_missing / n_cells if n_cells else 0
        except BackendError:
            return None

    def get_nan_frequency_attribute(self):
        return self.__get_nan_frequency(self.domain.attributes)

    def get_nan_frequency_class(self):
        return self.__get_nan_frequency(self.domain.class_vars)

    def __getstate__(self):
        # avoids locking magic in Table.__getstate__
        return self.__dict__

    def __setstate__(self, state):
        # avoid locking magic in Table.__setstate__
        self.__dict__.update(state)
        # if X is defined then it was already downloaded
        # thus ids exist too, rewrite them
        if self._X is not None:
            self._init_ids(self)

    # pylint: disable=unused-argument
    def _update_locks(self, *args, **kwargs):
        # avoid locking inherited from Table
        return

    # pylint: disable=unused-argument
    def unlocked(self, *parts):
        # avoid locking inherited from Table
        return contextlib.nullcontext()