# -*- coding: utf-8 -*-
#
# BSD licence
#
# Copyright (c) <2008-2011> Pierre Quentel (pierre.quentel@gmail.com)
# Copyright (c) <2014-2015> Bendik Rønning Opstad <bro.devel@gmail.com>.
#
"""
Main differences from :mod:`pydblite.pydblite`:
- pass the connection to the :class:`SQLite db <pydblite.sqlite.Database>` as argument to
:class:`Table <pydblite.sqlite.Table>`
- in :func:`create() <pydblite.sqlite.Table.create>` field definitions must specify a type.
- no `drop_field` (not supported by SQLite)
- the :class:`Table <pydblite.sqlite.Table>` instance has a
:attr:`cursor <pydblite.sqlite.Table.cursor>` attribute, so that raw SQL requests can
be executed.
"""
# Python 2 / Python 3 compatibility shim.
# On Python 2 ``cStringIO`` is importable and text must be encoded to a byte
# ``str``; on Python 3 the import fails and text is already ``str``.
# Both variants of ``to_str`` accept the same arguments so callers can pass an
# encoding without caring which interpreter is running.
try:
    import cStringIO as io

    def to_str(val, encoding="utf-8"):
        """Encode the Unicode string ``val`` to a Python 2 ``str``."""
        return val.encode(encoding)
except ImportError:
    import io

    unicode = str  # used in tests

    def to_str(val, encoding="utf-8"):
        """Return ``val`` unchanged; ``encoding`` is accepted for parity with Python 2."""
        return val
import datetime
import re
import traceback
from .common import ExpressionGroup, Filter
# test if sqlite is installed or raise exception
try:
from sqlite3 import dbapi2 as sqlite
from sqlite3 import OperationalError
except ImportError:
try:
from pysqlite2 import dbapi2 as sqlite
from pysqlite2._sqlite import OperationalError
except ImportError:
print("SQLite is not installed")
raise
# compatibility with Python 2.3
try:
set([])
except NameError:
from sets import Set as set # NOQA
# classes for CURRENT_DATE, CURRENT_TIME, CURRENT_TIMESTAMP
class CurrentDate:
    """Callable stand-in for SQLite's ``CURRENT_DATE`` default."""

    def __call__(self):
        """Return today's date as a ``YYYY-MM-DD`` string."""
        # %m / %d are month and day; the upper-case %M / %D mean "minute" and
        # a platform-specific "mm/dd/yy" respectively.
        return datetime.date.today().strftime('%Y-%m-%d')
class CurrentTime:
    """Callable stand-in for SQLite's ``CURRENT_TIME`` default."""

    def __call__(self):
        """Return the current local time as an ``HH:MM:SS`` string."""
        # Upper-case codes are required: lower-case %h / %m / %s mean
        # abbreviated month, month number and (non-portable) epoch seconds.
        return datetime.datetime.now().strftime('%H:%M:%S')
class CurrentTimestamp:
    """Callable stand-in for SQLite's ``CURRENT_TIMESTAMP`` default."""

    def __call__(self):
        """Return the current local date and time as ``YYYY-MM-DD HH:MM:SS``."""
        # Date part uses %m / %d (month, day); time part uses %H / %M / %S.
        return datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
# The placeholder classes that guess_default_fmt() may return instead of a
# column's literal default when that default looks like a date, a time or a
# timestamp.
DEFAULT_CLASSES = [CurrentDate, CurrentTime, CurrentTimestamp]
# functions to convert a value returned by a SQLite SELECT
# CURRENT_TIME format is HH:MM:SS
# CURRENT_DATE : YYYY-MM-DD
# CURRENT_TIMESTAMP : YYYY-MM-DD HH:MM:SS
# Raw strings keep ``\d`` a regex escape rather than an (invalid) string escape.
# The timestamp pattern is deliberately not anchored at the end, so trailing
# text after the seconds is tolerated.
c_time_fmt = re.compile(r'^(\d{2}):(\d{2}):(\d{2})$')
c_date_fmt = re.compile(r'^(\d{4})-(\d{2})-(\d{2})$')
c_tmsp_fmt = re.compile(r'^(\d{4})-(\d{2})-(\d{2}) (\d{2}):(\d{2}):(\d{2})')
# DATE : convert YYYY-MM-DD to datetime.date instance
def to_date(date):
    """Convert a ``YYYY-MM-DD`` string to a :class:`datetime.date`.

    ``None`` is passed through unchanged.

    Raises:
       - ValueError: if ``date`` does not match the ``YYYY-MM-DD`` pattern.
    """
    if date is None:
        return None
    matched = c_date_fmt.match(date)
    if matched is None:
        raise ValueError("Bad value %s for DATE format" % date)
    # Groups are (year, month, day), in that order.
    return datetime.date(*map(int, matched.groups()))
# TIME : convert HH:MM:SS to datetime.time instance
def to_time(_time):
    """Convert an ``HH:MM:SS`` string to a :class:`datetime.time`.

    ``None`` is passed through unchanged.

    Raises:
       - ValueError: if ``_time`` does not match the ``HH:MM:SS`` pattern.
    """
    if _time is None:
        return None
    matched = c_time_fmt.match(_time)
    if matched is None:
        raise ValueError("Bad value %s for TIME format" % _time)
    # Groups are (hour, minute, second), in that order.
    return datetime.time(*map(int, matched.groups()))
# DATETIME or TIMESTAMP : convert YYYY-MM-DD HH:MM:SS
# to datetime.datetime instance
def to_datetime(timestamp):
    """Convert a ``YYYY-MM-DD HH:MM:SS`` string to a :class:`datetime.datetime`.

    ``None`` is passed through unchanged.

    Raises:
       - ValueError: if ``timestamp`` is not a text string or does not start
         with the expected pattern.
    """
    if timestamp is None:
        return None
    # Values that are not text are rejected with the same message as values
    # that fail the pattern.
    matched = c_tmsp_fmt.match(timestamp) if isinstance(timestamp, unicode) else None
    if matched is None:
        raise ValueError("Bad value %s for TIMESTAMP format" % timestamp)
    parts = [int(group) for group in matched.groups()]
    return datetime.datetime(*parts)
# if default value is CURRENT_DATE etc. SQLite doesn't
# give the information, default is the value of the
# variable as a string. We have to guess...
#
def guess_default_fmt(value):
    """Guess whether a column default came from a ``CURRENT_*`` keyword.

    Args:
       - value (str): The default value as reported for the column.

    Returns:
       - ``CurrentTime``, ``CurrentDate`` or ``CurrentTimestamp`` (the class
         itself) when ``value`` looks like a valid time, date or timestamp;
         otherwise ``value`` unchanged.
    """
    mo = c_time_fmt.match(value)
    if mo:
        h, m, s = [int(x) for x in mo.groups()]
        if (0 <= h <= 23) and (0 <= m <= 59) and (0 <= s <= 59):
            return CurrentTime
    mo = c_date_fmt.match(value)
    if mo:
        y, m, d = [int(x) for x in mo.groups()]
        try:
            datetime.date(y, m, d)
        except ValueError:
            # Right shape, but not a real calendar date: keep looking.
            pass
        else:
            return CurrentDate
    mo = c_tmsp_fmt.match(value)
    if mo:
        y, mth, d, h, mn, s = [int(x) for x in mo.groups()]
        try:
            datetime.datetime(y, mth, d, h, mn, s)
        except ValueError:
            # Right shape, but not a real date/time: fall through.
            pass
        else:
            return CurrentTimestamp
    return value
class SQLiteError(Exception):
    """Error raised by this module for invalid field definitions and unsupported operations."""
class Database(dict):
    """Dictionary of :class:`Table` objects, keyed by table name, sharing one SQLite connection."""

    def __init__(self, filename, **kw):
        """
        To create an in-memory database provide ':memory:' as filename

        Args:
           - filename (str): The name of the database file, or ':memory:'
           - kw (dict): Arguments forwarded to sqlite3.connect
        """
        dict.__init__(self)
        self.conn = sqlite.connect(filename, **kw)
        """The SQLite connection"""
        self.cursor = self.conn.cursor()
        """The SQLite connections cursor"""
        # Register every table that already exists in the database.
        for table_name in self._tables():
            self[table_name] = Table(table_name, self)

    def _tables(self):
        """Return the list of table names in the database"""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        # 'sqlite_sequence' is left out of the result.
        return [row[0] for row in self.cursor.fetchall()
                if row[0] != 'sqlite_sequence']

    def create(self, table_name, *fields, **kw):
        """Create a table, register it under ``table_name`` and return it.

        Args:
           - table_name (str): The name of the new table.
           - fields, kw: Forwarded unchanged to :func:`Table.create`.
        """
        self[table_name] = Table(table_name, self).create(*fields, **kw)
        return self[table_name]

    def commit(self):
        """Save any changes to the database"""
        self.conn.commit()

    def close(self):
        """Closes the database"""
        self.conn.close()

    def __delitem__(self, table):
        """Drop a table, given its name or its :class:`Table` instance."""
        if isinstance(table, Table):
            table = table.name
        self.cursor.execute('DROP TABLE %s' % table)
        dict.__delitem__(self, table)

    # The instance can be used as a context manager, to make sure that it is
    # closed even if an exception is raised during operations
    def __enter__(self):
        """Enter 'with' statement"""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit 'with' statement"""
        self.close()
        # False when an exception is in flight, so it propagates to the caller.
        return exc_type is None
class Table(object):
    """One SQLite table, exposed as records (dicts) that carry their rowid under ``'__id__'``."""

    def __init__(self, table_name, db):
        """
        Args:
           - table_name (str): The name of the SQLite table.
           - db (:class:`Database <pydblite.sqlite.Database>`): The database.
        """
        self.name = table_name
        self.db = db
        self.cursor = db.cursor
        """The SQLite connections cursor"""
        self.conv_func = {}
        self.mode = "open"
        self._get_table_info()

    def create(self, *fields, **kw):
        """
        Create a new table.

        Args:
           - fields (list of tuples): The fields names/types to create.
             For each field, a 2-element tuple must be provided:

             - the field name
             - a string with additional information like field type +
               other information using the SQLite syntax
               eg  ('name', 'TEXT NOT NULL'), ('date', 'BLOB DEFAULT CURRENT_DATE')

             A third element, if present and not None, is used as the
             DEFAULT clause of the field.

           - mode (str): The mode used when creating the database.
             mode is only used if a database file already exists.

             - if mode = 'open' : open the existing base, ignore the fields
             - if mode = 'override' : erase the existing base and create a
               new one with the specified fields

        Returns:
           - the database (self).

        Raises:
           - IOError: if the table already exists and mode is neither
             'open' nor 'override'.
        """
        self.mode = mode = kw.get("mode", None)
        if self._table_exists():
            if mode == "override":
                self.cursor.execute("DROP TABLE %s" % self.name)
            elif mode == "open":
                return self.open()
            else:
                raise IOError("Base '%s' already exists" % self.name)
        columns = [self._validate_field(field) for field in fields]
        sql = "CREATE TABLE %s (%s)" % (self.name, ",".join(columns))
        self.cursor.execute(sql)
        self._get_table_info()
        return self

    def open(self):
        """Open an existing database."""
        return self

    def commit(self):
        """Save any changes to the database"""
        self.db.commit()

    def _table_exists(self):
        """Return True if a table with this name is registered in the owning database."""
        return self.name in self.db

    def _get_table_info(self):
        """Inspect the base to get field names."""
        self.fields = []
        self.field_info = {}
        self.cursor.execute('PRAGMA table_info (%s)' % self.name)
        for field_info in self.cursor.fetchall():
            # Row items used here: [1] name, [2] type, [3] not-null flag,
            # [4] default value.
            fname = to_str(field_info[1])
            self.fields.append(fname)
            ftype = to_str(field_info[2])
            info = {'type': ftype}
            # can be null ?
            info['NOT NULL'] = field_info[3] != 0
            # default value
            default = field_info[4]
            if isinstance(default, unicode):
                default = guess_default_fmt(default)
            info['DEFAULT'] = default
            self.field_info[fname] = info
        self.fields_with_id = ['__id__'] + self.fields

    def info(self):
        """Return a list of (field name, field info dict) pairs, in field order."""
        return [(field, self.field_info[field]) for field in self.fields]

    def _validate_field(self, field):
        """Build the SQL column definition for a field tuple.

        Args:
           - field (tuple): (name, type/info) or (name, type/info, default).

        Raises:
           - SQLiteError: if the tuple has neither 2 nor 3 elements.
        """
        if len(field) != 2 and len(field) != 3:
            msg = "Error in field definition %s" % field
            msg += ": should be a tuple with field_name, field_info, and optionally a default value"
            raise SQLiteError(msg)
        field_sql = '%s %s' % (field[0], field[1])
        if len(field) == 3 and field[2] is not None:
            field_sql += " DEFAULT {0}".format(field[2])
        return field_sql

    def conv(self, field_name, conv_func):
        """When a record is returned by a SELECT, ask conversion of
        specified field value with the specified function.

        Raises:
           - NameError: if ``field_name`` is not a field of the table.
        """
        if field_name not in self.fields:
            raise NameError("Unknown field %s" % field_name)
        self.conv_func[field_name] = conv_func

    def is_date(self, field_name):
        """Ask conversion of field to an instance of datetime.date"""
        self.conv(field_name, to_date)

    def is_time(self, field_name):
        """Ask conversion of field to an instance of datetime.time"""
        self.conv(field_name, to_time)

    def is_datetime(self, field_name):
        """Ask conversion of field to an instance of datetime.datetime"""
        self.conv(field_name, to_datetime)

    def insert(self, *args, **kw):
        """Insert a record in the database.

        Parameters can be positional or keyword arguments. If positional
        they must be in the same order as in the :func:`create` method.
        If the first positional argument is a list or tuple, it is treated
        as a collection of records and handed to :func:`_insert_many`.

        Returns:
           - The record identifier
        """
        if args:
            if isinstance(args[0], (list, tuple)):
                return self._insert_many(args[0])
            # Positional values are paired with the fields in table order.
            kw = dict(zip(self.fields, args))
        names = list(kw)
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
            self.name, ",".join(names), ",".join(["?"] * len(names)))
        self.cursor.execute(sql, [kw[name] for name in names])
        return self.cursor.lastrowid

    def _insert_many(self, args):
        """Insert a list or tuple of records

        Records are either dictionaries (keys of the first record give the
        columns) or sequences of values in table field order.

        Returns:
           - The last row id
        """
        if not args:
            # Nothing to insert; avoid indexing into an empty sequence.
            return self.cursor.lastrowid
        sql = "INSERT INTO %s" % self.name
        sql += "(%s) VALUES (%s)"
        if isinstance(args[0], dict):
            ks = list(args[0])
            sql = sql % (', '.join(ks), ','.join(['?' for k in ks]))
            args = [[arg[k] for k in ks] for arg in args]
        else:
            sql = sql % (', '.join(self.fields),
                         ','.join(['?' for f in self.fields]))
        try:
            self.cursor.executemany(sql, args)
        except Exception:
            # Re-raise with the SQL, the arguments and the traceback attached.
            raise Exception(self._err_msg(sql, args))
        # return last row id
        return self.cursor.lastrowid

    def delete(self, removed):
        """Remove a single record, or the records in an iterable.

        Records are identified by their ``'__id__'`` value. The change is
        committed.

        Returns:
           - int: the number of items passed in for deletion
        """
        sql = "DELETE FROM %s " % self.name
        if isinstance(removed, dict):
            # remove a single record
            _id = removed['__id__']
            sql += "WHERE rowid = ?"
            args = (_id,)
            removed = [removed]
        else:
            # convert iterable into a list
            removed = list(removed)
            if not removed:
                return 0
            args = [r['__id__'] for r in removed]
            sql += "WHERE rowid IN (%s)" % (','.join(['?'] * len(args)))
        self.cursor.execute(sql, args)
        self.db.commit()
        return len(removed)

    def update(self, record, **kw):
        """Update the record with new keys and values."""
        vals = self._make_sql_params(kw)
        sql = "UPDATE %s SET %s WHERE rowid=?" % (self.name,
                                                  ",".join(vals))
        self.cursor.execute(sql, list(kw.values()) + [record['__id__']])
        self.db.commit()

    def _make_sql_params(self, kw):
        """Make a list of strings to pass to an SQL statement
        from the dictionary kw with Python types."""
        return ['%s=?' % k for k in kw.keys()]

    def _make_record(self, row, fields=None):
        """Make a record dictionary from the result of a fetch"""
        if fields is None:
            fields = self.fields_with_id
        res = dict(zip(fields, row))
        # Apply the conversions registered through conv().
        for field_name in self.conv_func:
            res[field_name] = self.conv_func[field_name](res[field_name])
        return res

    def add_field(self, name, column_type="TEXT", default=None):
        """Add a new column to the table.

        Args:
           - name (string): The name of the field
           - column_type (string): The data type of the column (Defaults to TEXT)
           - default (datatype): The default value for this field (if any)
        """
        sql = "ALTER TABLE %s ADD " % self.name
        sql += self._validate_field((name, column_type, default))
        self.cursor.execute(sql)
        self.db.commit()
        self._get_table_info()

    def drop_field(self, field):
        """Always raises SQLiteError: dropping fields is not supported."""
        raise SQLiteError("Dropping fields is not supported by SQLite")

    def __call__(self, *args, **kw):
        """
        Selection by field values.

        db(key=value) returns the list of records where r[key] = value

        Args:
           - args (list): A field to filter on.
           - kw (dict): pairs of field and value to filter on.

        Returns:
           - When a field name is supplied in args, return a
             :class:`Filter <pydblite.common.Filter>`
             object that filters on the specified field.
           - When a Filter or ExpressionGroup is supplied in args, return
             the records matching it.
           - When kw supplied, return all the records where field values matches
             the key/values in kw.
           - With no argument, return all the records.
        """
        if args and kw:
            raise SyntaxError("Can't specify positional AND keyword arguments")
        if args:
            if len(args) > 1:
                raise SyntaxError("Only one field can be specified")
            selector = args[0]
            if isinstance(selector, (ExpressionGroup, Filter)):
                # The expression is rendered as the WHERE clause via '%s'.
                sql = "SELECT rowid,* FROM %s WHERE %s" % (self.name, selector)
                self.cursor.execute(sql)
                return [self._make_record(row) for row in self.cursor.fetchall()]
            if selector not in self.fields:
                raise ValueError("%s is not a field" % selector)
            return self.filter(key=selector)
        if kw:
            undef = set(kw) - set(self.fields)
            if undef:
                raise ValueError("Fields %s not in the database" % undef)
            vals = self._make_sql_params(kw)
            sql = "SELECT rowid,* FROM %s WHERE %s" % (self.name, " AND ".join(vals))
            self.cursor.execute(sql, list(kw.values()))
        else:
            self.cursor.execute("SELECT rowid,* FROM %s" % self.name)
        records = self.cursor.fetchall()
        return [self._make_record(row) for row in records]

    def __getitem__(self, record_id):
        """Direct access by record id.

        Raises:
           - IndexError: if no record has this id.
        """
        # The id is bound as a parameter, never interpolated into the SQL text.
        sql = "SELECT rowid,* FROM %s WHERE rowid=?" % self.name
        self.cursor.execute(sql, (record_id,))
        res = self.cursor.fetchone()
        if res is None:
            raise IndexError("No record at index %s" % record_id)
        return self._make_record(res)

    def filter(self, key=None):
        """Return a :class:`Filter <pydblite.common.Filter>` built from this table and ``key``."""
        return Filter(self, key)

    def _len(self, db_filter=None):
        """Return number of matching entries"""
        if db_filter is not None and db_filter.is_filtered():
            sql = "SELECT COUNT(*) AS count FROM %s WHERE %s" % (self.name, db_filter)
        else:
            sql = "SELECT COUNT(*) AS count FROM %s;" % self.name
        self.cursor.execute(sql)
        res = self.cursor.fetchone()
        return res[0]

    def __len__(self):
        """Return the number of records in the table."""
        return self._len()

    def __delitem__(self, record_id):
        """Delete by record id"""
        self.delete(self[record_id])

    def __iter__(self):
        """Iteration on the records"""
        self.cursor.execute("SELECT rowid,* FROM %s" % self.name)
        # All rows are fetched before iteration starts.
        results = [self._make_record(r) for r in self.cursor.fetchall()]
        return iter(results)

    def _err_msg(self, sql, args=None):
        """Build an error message with the SQL request, its arguments and the current traceback."""
        msg = "Exception for table %s.%s\n" % (self.db, self.name)
        msg += 'SQL request %s\n' % sql
        if args:
            import pprint
            msg += 'Arguments : %s\n' % pprint.saferepr(args)
        out = io.StringIO()
        traceback.print_exc(file=out)
        msg += out.getvalue()
        return msg

    def get_group_count(self, group_by, db_filter=None):
        """Return the grouped by count of the values of a column"""
        if db_filter is not None and db_filter.is_filtered():
            sql = "SELECT %s, COUNT(*) FROM %s WHERE %s GROUP BY %s " % (group_by, self.name,
                                                                         db_filter, group_by)
        else:
            sql = "SELECT %s, COUNT(*) FROM %s GROUP BY %s;" % (group_by, self.name, group_by)
        self.cursor.execute(sql)
        return self.cursor.fetchall()

    def get_unique_ids(self, unique_id, db_filter=None):
        """Return all the unique values of a column"""
        sql = "SELECT rowid,%s FROM %s" % (unique_id, self.name)
        if db_filter is not None and db_filter.is_filtered():
            sql += " WHERE %s" % db_filter
        self.cursor.execute(sql)
        records = self.cursor.fetchall()
        return set([row[1] for row in records])

    def create_index(self, *index_columns):
        """Create one index named ``index_<column>`` per given column, then commit."""
        for ic in index_columns:
            sql = "CREATE INDEX index_%s on %s (%s);" % (ic, self.name, ic)
            self.cursor.execute(sql)
        self.db.commit()

    def delete_index(self, *index_columns):
        """Drop the ``index_<column>`` index of each given column, then commit."""
        for ic in index_columns:
            sql = "DROP INDEX index_%s;" % (ic)
            self.cursor.execute(sql)
        self.db.commit()

    def get_indices(self):
        """Return the index names found in ``sqlite_master``, minus their first len("index_") characters.

        An empty list is returned if the query raises OperationalError.
        """
        # NOTE(review): the query is not restricted to this table and the
        # prefix is stripped from every name, whether or not it starts with
        # "index_" -- confirm this is intended.
        indices = []
        sql = "SELECT * FROM sqlite_master WHERE type = 'index';"
        try:
            self.cursor.execute(sql)
        except OperationalError:
            return indices
        records = self.cursor.fetchall()
        for r in records:
            indices.append(r[1][len("index_"):])
        return indices


Base = Table  # compatibility with previous versions