import atexit
import inspect
import sys
import os
import threading
import types
try:
    from urlparse import urlparse, parse_qsl
    from urllib import unquote, quote, urlencode
except ImportError:
    from urllib.parse import urlparse, parse_qsl, unquote, quote, urlencode
import warnings
import weakref
from . import classregistry
from . import col
from . import sqlbuilder
from .cache import CacheSet
from .compat import PY2, string_type, unicode_type
from .converters import sqlrepr
from .events import send, CommitSignal, RollbackSignal
from .util.threadinglocal import local as threading_local
warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
def _closeConnection(ref):
    conn = ref()
    if conn is not None:
        conn.close()
[docs]class ConsoleWriter:
    def __init__(self, connection, loglevel):
        # loglevel: None or empty string for stdout; or 'stderr'
        self.loglevel = loglevel or "stdout"
        self.dbEncoding = getattr(connection, "dbEncoding", None) or "ascii"
[docs]    def write(self, text):
        logfile = getattr(sys, self.loglevel)
        if PY2 and isinstance(text, unicode_type):
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                text = repr(text)[2:-1]  # Remove u'...' from the repr
        logfile.write(text + '\n')  
[docs]class LogWriter:
    def __init__(self, connection, logger, loglevel):
        self.logger = logger
        self.loglevel = loglevel
        self.logmethod = getattr(logger, loglevel)
[docs]    def write(self, text):
        self.logmethod(text)  
[docs]def makeDebugWriter(connection, loggerName, loglevel):
    if not loggerName:
        return ConsoleWriter(connection, loglevel)
    import logging
    logger = logging.getLogger(loggerName)
    return LogWriter(connection, logger, loglevel) 
[docs]class Boolean(object):
    """A bool class that also understands some special string keywords
    Understands: yes/no, true/false, on/off, 1/0, case ignored.
    """
    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}
    def __new__(cls, value):
        try:
            return Boolean._keywords[value.lower()]
        except (AttributeError, KeyError):
            return bool(value) 
[docs]class DBConnection:
    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        self.name = name
        self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        self._connectionNumbers = {}
        self._connectionCount = 1
        self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        classregistry.registry(self.registry).addCallback(self.soClassAdded)
        registerConnectionInstance(self)
        atexit.register(_closeConnection, weakref.ref(self))
[docs]    def oldUri(self):
        auth = getattr(self, 'user', '') or ''
        if auth:
            if self.password:
                auth += ':' + self.password
            auth += '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + db 
[docs]    def uri(self):
        auth = getattr(self, 'user', '') or ''
        if auth:
            auth = quote(auth)
            if self.password:
                auth += ':' + quote(self.password)
            auth += '@'
        else:
            assert not getattr(self, 'password', None), (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
            if self.port:
                uri += ':%d' % self.port
        uri += '/'
        db = self.db
        if db.startswith('/'):
            db = db[1:]
        return uri + quote(db) 
[docs]    @classmethod
    def connectionFromOldURI(cls, uri):
        return cls._connectionFromParams(*cls._parseOldURI(uri)) 
[docs]    @classmethod
    def connectionFromURI(cls, uri):
        return cls._connectionFromParams(*cls._parseURI(uri)) 
    @staticmethod
    def _parseOldURI(uri):
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), \
            
"URIs must start with scheme:/ -- " \
            
"you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            host = None
            rest = rest[3:]
        else:
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, "
                                 "got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, "
                                 "got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            arglist = arglist.split('&')
            for single in arglist:
                argname, argvalue = single.split('=', 1)
                argvalue = unquote(argvalue)
                args[argname] = argvalue
        return user, password, host, port, path, args
    @staticmethod
    def _parseURI(uri):
        parsed = urlparse(uri)
        host, path = parsed.hostname, parsed.path
        user, password, port = None, None, None
        if parsed.username:
            user = unquote(parsed.username)
        if parsed.password:
            password = unquote(parsed.password)
        if parsed.port:
            port = int(parsed.port)
        path = unquote(path)
        if (os.name == 'nt') and (len(path) > 2):
            # Preserve backward compatibility with URIs like /C|/path;
            # replace '|' by ':'
            if path[2] == '|':
                path = "%s:%s" % (path[0:2], path[3:])
            # Remove leading slash
            if (path[0] == '/') and (path[2] == ':'):
                path = path[1:]
        query = parsed.query
        # hash-tag / fragment is ignored
        args = {}
        if query:
            for name, value in parse_qsl(query):
                args[name] = value
        return user, password, host, port, path, args
[docs]    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self)) 
[docs]    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()  
[docs]class ConnWrapper(object):
    """
    This represents a SQLObject class that is bound to a specific
    connection (instances have a connection instance variable, but
    classes are global, so this is binds the connection variable
    lazily when a class method is accessed)
    """
    # @@: methods that take connection arguments should be explicitly
    # marked up instead of the implicit use of a connection argument
    # and inspect.getargspec()
    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection
    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)
    def __getattr__(self, attr):
        meth = getattr(self._soClass, attr)
        if not isinstance(meth, types.MethodType):
            # We don't need to wrap non-methods
            return meth
        try:
            takes_conn = meth.takes_connection
        except AttributeError:
            args, varargs, varkw, defaults = inspect.getargspec(meth)
            assert not varkw and not varargs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes **kw: %r"
                % meth)
            takes_conn = 'connection' in args
            meth.__func__.takes_connection = takes_conn
        if not takes_conn:
            return meth
        return ConnMethodWrapper(meth, self._connection) 
[docs]class ConnMethodWrapper(object):
    def __init__(self, method, connection):
        self._method = method
        self._connection = connection
    def __getattr__(self, attr):
        return getattr(self._method, attr)
    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._method(*args, **kw)
    def __repr__(self):
        return '<Wrapped %r with connection %r>' % (
            self._method, self._connection) 
[docs]class DBAPI(DBConnection):
    """
    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object.
    ``queryInsertID`` must also be defined.
    """
    dbName = None
    def __init__(self, **kw):
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        self._binaryType = type(self.module.Binary(b''))
    def _runWithConnection(self, meth, *args):
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val
[docs]    def getConnection(self):
        self._poolLock.acquire()
        try:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join(
                        [str(self._connectionNumbers[id(v)])
                            for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn
        finally:
            self._poolLock.release() 
[docs]    def releaseConnection(self, conn, explicit=False):
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join(
                    [str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                raise Exception('Object used outside of a transaction; '
                                'implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        if self._pool is not None:
            if conn not in self._pool:
                # @@: We can get duplicate releasing of connections with
                # the __del__ in Iteration (unfortunately, not sure why
                # it happens)
                self._pool.insert(0, conn)
        else:
            conn.close() 
[docs]    def printDebug(self, conn, s, name, type='query'):
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
            s = repr(s)
        n = self._connectionNumbers[id(conn)]
        spaces = ' ' * (8 - len(name))
        if self.debugThreading:
            threadName = threading.currentThread().getName()
            threadName = (':' + threadName + ' ' * (8 - len(threadName)))
        else:
            threadName = ''
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
        self.debugWriter.write(msg) 
    def _executeRetry(self, conn, cursor, query):
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)
    def _query(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'Query')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        c.close()
[docs]    def query(self, s):
        return self._runWithConnection(self._query, s) 
    def _queryAll(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        c.close()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value
[docs]    def queryAll(self, s):
        return self._runWithConnection(self._queryAll, s) 
    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        c.close()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return c.description, value
[docs]    def queryAllDescription(self, s):
        return self._runWithConnection(self._queryAllDescription, s) 
    def _queryOne(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        c.close()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value
[docs]    def queryOne(self, s):
        return self._runWithConnection(self._queryOne, s) 
    def _insertSQL(self, table, names, values):
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))
[docs]    def transaction(self):
        return Transaction(self) 
[docs]    def queryInsertID(self, soInstance, id, names, values):
        return self._runWithConnection(self._queryInsertID, soInstance, id,
                                       names, values) 
[docs]    def iterSelect(self, select):
        return select.IterationClass(self, self.getConnection(),
                                     select, keepConnection=False) 
[docs]    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.
        """
        q = select.queryForSelect().newItems(expressions).\
            
unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val 
[docs]    def queryForSelect(self, select):
        return self.sqlrepr(select.queryForSelect()) 
    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))
    def _SO_createJoinTableSQL(self, join):
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))
    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)
    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))
[docs]    def createIndexSQL(self, soClass, index):
        assert 0, 'Implement in subclasses' 
[docs]    def createTable(self, soClass):
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)
        return constraints 
[docs]    def createReferenceConstraints(self, soClass):
        refConstraints = [self.createReferenceConstraint(soClass, column)
                          for column in soClass.sqlmeta.columnList
                          if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint for constraint in refConstraints
                             if constraint]
        return refConstraintDefs 
[docs]    def createSQL(self, soClass):
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            assert isinstance(tableCreateSQLs, (str, list, dict, tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(
                    soClass._connection.dbName, [])
            if isinstance(tableCreateSQLs, str):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs, list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or [] 
[docs]    def createTableSQL(self, soClass):
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                     (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL 
[docs]    def createColumns(self, soClass):
        columnDefs = [self.createIDColumn(soClass)] \
            
+ [self.createColumn(soClass, col)
               for col in soClass.sqlmeta.columnList]
        return ",\n".join(["    %s" % c for c in columnDefs]) 
[docs]    def createReferenceConstraint(self, soClass, col):
        assert 0, "Implement in subclasses" 
[docs]    def createColumn(self, soClass, col):
        assert 0, "Implement in subclasses" 
[docs]    def dropTable(self, tableName, cascade=False):
        self.query("DROP TABLE %s" % tableName) 
[docs]    def clearTable(self, tableName):
        # 3-03 @@: Should this have a WHERE 1 = 1 or similar
        # clause?  In some configurations without the WHERE clause
        # the query won't go through, but maybe we shouldn't override
        # that.
        self.query("DELETE FROM %s" % tableName) 
[docs]    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        # Default is Binary() function from the connection driver.
        return self.module.Binary(value) 
    # The _SO_* series of methods are sorts of "friend" methods
    # with SQLObject.  They grab values from the SQLObject instances
    # or classes freely, but keep the SQLObject class from accessing
    # the database directly.  This way no SQL is actually created
    # in the SQLObject class.
    def _SO_update(self, so, values):
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))
    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id == so.id)
    def _SO_selectOneAlt(self, so, columnNames, condition):
        if columnNames:
            columns = [isinstance(x, string_type)
                       and sqlbuilder.SQLConstant(x)
                       or x for x in columnNames]
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(
            columns, staticTables=[so.sqlmeta.table], clause=condition)))
    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))
    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))
    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))
    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))
    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))
    def _SO_columnClause(self, soClass, kw):
        from . import main
        data = []
        if 'id' in kw:
            data.append((soClass.sqlmeta.idName, kw.pop('id')))
        for soColumn in soClass.sqlmeta.columnList:
            key = soColumn.name
            if key in kw:
                val = kw.pop(key)
                if soColumn.from_python:
                    val = soColumn.from_python(
                        val,
                        sqlbuilder.SQLObjectState(soClass, connection=self))
                data.append((soColumn.dbName, val))
            elif soColumn.foreignName in kw:
                obj = kw.pop(soColumn.foreignName)
                if isinstance(obj, main.SQLObject):
                    data.append((soColumn.dbName, obj.id))
                else:
                    data.append((soColumn.dbName, obj))
        if kw:
            # pick the first key from kw to use to raise the error,
            raise TypeError("got an unexpected keyword argument(s): "
                            "%r" % kw.keys())
        if not data:
            return None
        return ' AND '.join(
            ['%s %s %s' %
             (dbName, "IS" if value is None else "=", self.sqlrepr(value))
             for dbName, value
             in data])
[docs]    def sqlrepr(self, v):
        return sqlrepr(v, self.dbName) 
    def __del__(self):
        self.close()
[docs]    def close(self):
        if not hasattr(self, '_pool'):
            # Probably there was an exception while creating this
            # instance, so it is incomplete.
            return
        if not self._pool:
            return
        self._poolLock.acquire()
        try:
            if not self._pool:  # _pool could be filled in a different thread
                return
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
            del conn
            del conns
        finally:
            self._poolLock.release() 
[docs]    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError 
[docs]    def make_odbc_conn_str(self, odb_source, db, host=None, port=None,
                           user=None, password=None):
        odbc_conn_parts = ['Driver={%s}' % odb_source]
        for odbc_keyword, value in \
                
zip(self.odbc_keywords, (host, port, user, password, db)):
            if value is not None:
                odbc_conn_parts.append('%s=%s' % (odbc_keyword, value))
        self.odbc_conn_str = ';'.join(odbc_conn_parts)  
[docs]class Iteration(object):
    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        self.dbconn = dbconn
        self.rawconn = rawconn
        self.select = select
        self.keepConnection = keepConnection
        self.cursor = rawconn.cursor()
        self.query = self.dbconn.queryForSelect(select)
        if dbconn.debug:
            dbconn.printDebug(rawconn, self.query, 'Select')
        self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)
    def __iter__(self):
        return self
    def __next__(self):
        return self.next()
[docs]    def next(self):
        result = self.cursor.fetchone()
        if result is None:
            self._cleanup()
            raise StopIteration
        if result[0] is None:
            return None
        if self.select.ops.get('lazyColumns', 0):
            obj = self.select.sourceClass.get(result[0],
                                              connection=self.dbconn)
            return obj
        else:
            obj = self.select.sourceClass.get(result[0],
                                              selectResults=result[1:],
                                              connection=self.dbconn)
            return obj 
    def _cleanup(self):
        if getattr(self, 'query', None) is None:
            # already cleaned up
            return
        self.cursor.close()
        if not self.keepConnection:
            self.dbconn.releaseConnection(self.rawconn)
        self.query = self.dbconn = self.rawconn = \
            
self.select = self.cursor = None
    def __del__(self):
        self._cleanup() 
[docs]class Transaction(object):
    def __init__(self, dbConnection):
        # this is to skip __del__ in case of an exception in this __init__
        self._obsolete = True
        self._dbConnection = dbConnection
        self._connection = dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        self._deletedCache = {}
        self._obsolete = False
[docs]    def assertActive(self):
        assert not self._obsolete, \
            
"This transaction has already gone through ROLLBACK; " \
            
"begin another transaction" 
[docs]    def query(self, s):
        self.assertActive()
        return self._dbConnection._query(self._connection, s) 
[docs]    def queryAll(self, s):
        self.assertActive()
        return self._dbConnection._queryAll(self._connection, s) 
[docs]    def queryOne(self, s):
        self.assertActive()
        return self._dbConnection._queryOne(self._connection, s) 
[docs]    def queryInsertID(self, soInstance, id, names, values):
        self.assertActive()
        return self._dbConnection._queryInsertID(
            self._connection, soInstance, id, names, values) 
[docs]    def iterSelect(self, select):
        self.assertActive()
        # We can't keep the cursor open with results in a transaction,
        # because we might want to use the connection while we're
        # still iterating through the results.
        # @@: But would it be okay for psycopg, with threadsafety
        # level 2?
        return iter(list(select.IterationClass(self, self._connection,
                                               select, keepConnection=True))) 
    def _SO_delete(self, inst):
        cls = inst.__class__.__name__
        if cls not in self._deletedCache:
            self._deletedCache[cls] = []
        self._deletedCache[cls].append(inst.id)
        if PY2:
            meth = types.MethodType(self._dbConnection._SO_delete.__func__,
                                    self, self.__class__)
        else:
            meth = types.MethodType(self._dbConnection._SO_delete.__func__,
                                    self)
        return meth(inst)
[docs]    def commit(self, close=False):
        if self._obsolete:
            # @@: is it okay to get extraneous commits?
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'COMMIT')
        self._send_event(CommitSignal)
        self._connection.commit()
        subCaches = [(sub[0], sub[1].allIDs())
                     for sub in self.cache.allSubCachesByClassNames().items()]
        subCaches.extend([(x[0], x[1]) for x in self._deletedCache.items()])
        for cls, ids in subCaches:
            for id in list(ids):
                inst = self._dbConnection.cache.tryGetByName(id, cls)
                if inst is not None:
                    inst.expire()
        if close:
            self._makeObsolete() 
[docs]    def rollback(self):
        if self._obsolete:
            # @@: is it okay to get extraneous rollbacks?
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
        subCaches = [(sub, sub.allIDs()) for sub in self.cache.allSubCaches()]
        self._send_event(RollbackSignal)
        self._connection.rollback()
        for subCache, ids in subCaches:
            for id in list(ids):
                inst = subCache.tryGet(id)
                if inst is not None:
                    inst.expire()
        self._makeObsolete() 
    def _send_event(self, signal):
        """
        Pushes a list of class_names and related ids in cache.
        :param signal: Type of event signal to use
        """
        cached_classes_and_ids = [
            (class_name, cache.allIDs()) for class_name, cache in
            self.cache.allSubCachesByClassNames().items()
        ]
        if cached_classes_and_ids:
            from .main import sqlmeta  # Import here to avoid circular import
            send(signal, sqlmeta, cached_classes_and_ids)
    def __getattr__(self, attr):
        """
        If nothing else works, let the parent connection handle it.
        Except with this transaction as 'self'.  Poor man's
        acquisition?  Bad programming?  Okay, maybe.
        """
        self.assertActive()
        attr = getattr(self._dbConnection, attr)
        try:
            func = attr.__func__
        except AttributeError:
            if isinstance(attr, ConnWrapper):
                return ConnWrapper(attr._soClass, self)
            else:
                return attr
        else:
            if PY2:
                meth = types.MethodType(func, self, self.__class__)
            else:
                meth = types.MethodType(func, self)
            return meth
    def _makeObsolete(self):
        self._obsolete = True
        if self._dbConnection.autoCommit:
            self._dbConnection._setAutoCommit(self._connection, 1)
        self._dbConnection.releaseConnection(self._connection,
                                             explicit=True)
        self._connection = None
        self._deletedCache = {}
[docs]    def begin(self):
        # @@: Should we do this, or should begin() be a no-op when we're
        # not already obsolete?
        assert self._obsolete, \
            
"You cannot begin a new transaction session " \
            
"without rolling back this one"
        self._obsolete = False
        self._connection = self._dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0) 
    def __del__(self):
        if self._obsolete:
            return
        self.rollback()
[docs]    def close(self):
        raise TypeError('You cannot just close transaction - '
                        'you should either call rollback(), commit() '
                        'or commit(close=True) '
                        'to close the underlying connection.')  
[docs]class ConnectionHub(object):
    """
    This object serves as a hub for connections, so that you can pass
    in a ConnectionHub to a SQLObject subclass as though it was a
    connection, but actually bind a real database connection later.
    You can also bind connections on a per-thread basis.
    You must hang onto the original ConnectionHub instance, as you
    cannot retrieve it again from the class or instance.
    To use the hub, do something like::
        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub
        hub.threadConnection = connectionFromURI('...')
    """
    def __init__(self):
        self.threadingLocal = threading_local()
    def __get__(self, obj, type=None):
        # I'm a little surprised we have to do this, but apparently
        # the object's private dictionary of attributes doesn't
        # override this descriptor.
        if (obj is not None) and '_connection' in obj.__dict__:
            return obj.__dict__['_connection']
        return self.getConnection()
    def __set__(self, obj, value):
        obj.__dict__['_connection'] = value
[docs]    def getConnection(self):
        try:
            return self.threadingLocal.connection
        except AttributeError:
            try:
                return self.processConnection
            except AttributeError:
                raise AttributeError(
                    "No connection has been defined for this thread "
                    "or process") 
[docs]    def doInTransaction(self, func, *args, **kw):
        """
        This routine can be used to run a function in a transaction,
        rolling the transaction back if any exception is raised from
        that function, and committing otherwise.
        Use like::
            sqlhub.doInTransaction(process_request, os.environ)
        This will run ``process_request(os.environ)``.  The return
        value will be preserved.
        """
        # @@: In Python 2.5, something usable with with: should also
        # be added.
        try:
            old_conn = self.threadingLocal.connection
            old_conn_is_threading = True
        except AttributeError:
            old_conn = self.processConnection
            old_conn_is_threading = False
        conn = old_conn.transaction()
        if old_conn_is_threading:
            self.threadConnection = conn
        else:
            self.processConnection = conn
        try:
            try:
                value = func(*args, **kw)
            except Exception:
                conn.rollback()
                raise
            else:
                conn.commit(close=True)
                return value
        finally:
            if old_conn_is_threading:
                self.threadConnection = old_conn
            else:
                self.processConnection = old_conn 
    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value
    def _get_threadConnection(self):
        return self.threadingLocal.connection
    def _del_threadConnection(self):
        del self.threadingLocal.connection
    threadConnection = property(_get_threadConnection,
                                _set_threadConnection,
                                _del_threadConnection) 
[docs]class ConnectionURIOpener(object):
    def __init__(self):
        self.schemeBuilders = {}
        self.instanceNames = {}
        self.cachedURIs = {}
[docs]    def registerConnection(self, schemes, builder):
        for uriScheme in schemes:
            assert uriScheme not in self.schemeBuilders \
                
or self.schemeBuilders[uriScheme] is builder, \
                
"A driver has already been registered " \
                
"for the URI scheme %s" % uriScheme
            self.schemeBuilders[uriScheme] = builder 
[docs]    def registerConnectionInstance(self, inst):
        if inst.name:
            assert (inst.name not in self.instanceNames
                    or self.instanceNames[inst.name] is cls  # noqa
                    ), ("A instance has already been registered "
                        "with the name %s" % inst.name)
            assert inst.name.find(':') == -1, \
                
"You cannot include ':' " \
                
"in your class names (%r)" % cls.name  # noqa
            self.instanceNames[inst.name] = inst 
[docs]    def connectionForURI(self, uri, oldUri=False, **args):
        if args:
            if '?' not in uri:
                uri += '?' + urlencode(args)
            else:
                uri += '&' + urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if uri.find(':') != -1:
            scheme, rest = uri.split(':', 1)
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # We just have a name, not a URI
            assert uri in self.instanceNames, \
                
"No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]
        # @@: Do we care if we clobber another connection?
        self.cachedURIs[uri] = conn
        return conn 
[docs]    def dbConnectionForScheme(self, scheme):
        assert scheme in self.schemeBuilders, (
            "No SQLObject driver exists for %s (only %s)" % (
                scheme,
                ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()  
TheURIOpener = ConnectionURIOpener()
registerConnection = TheURIOpener.registerConnection
registerConnectionInstance = TheURIOpener.registerConnectionInstance
connectionForURI = TheURIOpener.connectionForURI
dbConnectionForScheme = TheURIOpener.dbConnectionForScheme
# Register DB URI schemas -- do import for side effects
# noqa is a directive for flake8 to ignore seemingly unused imports
from . import firebird  # noqa
from . import maxdb  # noqa
from . import mssql  # noqa
from . import mysql  # noqa
from . import postgres  # noqa
from . import sqlite  # noqa
from . import sybase  # noqa