#!/usr/bin/env python
from __future__ import print_function
import fnmatch
import optparse
import os
import re
import sys
import textwrap
import time
import warnings
try:
    from paste.deploy import appconfig
except ImportError:
    appconfig = None
import sqlobject
from sqlobject import col
from sqlobject.classregistry import findClass
from sqlobject.declarative import DeclarativeMeta
from sqlobject.util import moduleloader
from sqlobject.compat import PY2, with_metaclass, string_type
# os.tempnam() (Python 2 only) emits a RuntimeWarning calling it a
# potential security risk.  It's not very unsafe to use tempnam the way
# nowarning_tempnam() does, so silence that warning for this module.
# The filter is deliberately not tied to a line number: a hard-coded
# lineno silently stops matching as soon as lines above the call move.
warnings.filterwarnings(
    'ignore', 'tempnam is a potential security risk.*',
    RuntimeWarning, '.*command')
if PY2:
    # Python 2's input() evaluates what the user types; ask() wants the
    # raw string, so use raw_input there.
    # noqa for flake8 and python 3
    input = raw_input  # noqa
def nowarning_tempnam(*args, **kw):
    """
    Return a fresh temporary file *name*; the file itself is not created.

    Accepts the arguments of Python 2's ``os.tempnam([dir[, prefix]])``
    (positionally or as keywords).  ``os.tempnam`` was removed in
    Python 3, where ``tempfile.mktemp`` produces the name instead.

    Both functions are race-prone: another process may create the file
    between naming and opening it.  Prefer ``tempfile.mkstemp`` in new
    code.
    """
    tempnam = getattr(os, 'tempnam', None)
    if tempnam is not None:
        # Python 2: silence tempnam's RuntimeWarning locally instead of
        # relying only on a module-level filter.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', RuntimeWarning)
            return tempnam(*args, **kw)
    import tempfile
    if len(args) > 2:
        raise TypeError(
            'nowarning_tempnam() takes at most 2 positional arguments '
            '(%d given)' % len(args))
    options = dict(zip(('dir', 'prefix'), args))
    options.update(kw)
    # os.tempnam accepts prefix=None; tempfile.mktemp does not.
    if options.get('prefix') is None:
        options.pop('prefix', None)
    return tempfile.mktemp(**options)
class SQLObjectVersionTable(sqlobject.SQLObject):
    """
    This table is used to store information about the database and
    its version (used with record and update commands).
    """

    class sqlmeta:
        # Fixed table name used for the version records.
        table = 'sqlobject_db_version'

    # NOTE: column declaration order is kept as-is; SQLObject may use it
    # for the generated column order.
    # Version identifier, stored as a string.
    version = col.StringCol()
    # Row timestamp; default comes from col.DateTimeCol.now (presumably
    # the current time when the row is created).
    updated = col.DateTimeCol(default=col.DateTimeCol.now)
def db_differences(soClass, conn):
    """
    Compare a SQLObject class with its table as seen through ``conn``.

    Returns a list of human-readable difference messages, or ``[]`` when
    nothing differs.  Only table existence and column names are checked,
    so many kinds of mismatch go unnoticed.  Backends that cannot read
    their schema (no ``columnsFromSchema``) report no column differences.
    """
    # @@: Repeats a lot from CommandStatus.command, but it's hard
    # to actually factor out the display logic.
    table_name = soClass.sqlmeta.table
    if not conn.tableExists(table_name):
        # A class without columns has nothing to be missing.
        return ['Does not exist in database'] if soClass.sqlmeta.columns else []
    try:
        schema_columns = conn.columnsFromSchema(table_name, soClass)
    except AttributeError:
        # Database does not support reading columns
        return []
    in_db_only = {}
    for schema_col in schema_columns:
        bound = schema_col.withClass(soClass)
        in_db_only[bound.dbName] = bound
    not_in_db = {}
    for defined in soClass.sqlmeta.columnList:
        if defined.dbName in in_db_only:
            del in_db_only[defined.dbName]
        else:
            not_in_db[defined.dbName] = defined
    extra_msgs = ['Database has extra column: %s' % c.dbName
                  for c in in_db_only.values()]
    missing_msgs = ['Database missing column: %s' % c.dbName
                    for c in not_in_db.values()]
    return extra_msgs + missing_msgs
class CommandRunner(object):
    """
    Dispatch a command line to the registered command class.

    The first argument not starting with ``-`` names the command, matched
    case-insensitively against registered names and aliases.  All other
    arguments are passed on to the command.
    """

    def __init__(self):
        # canonical command name -> command class
        self.commands = {}
        # alias -> canonical command name
        self.command_aliases = {}

    def run(self, argv):
        """Find the command in ``argv``, instantiate it and run it."""
        invoked_as, args = argv[0], argv[1:]
        position = next((idx for idx, arg in enumerate(args)
                         if not arg.startswith('-')), None)
        if position is None:
            # no command found
            self.invalid('No COMMAND given (try "%s help")'
                         % os.path.basename(invoked_as))
        command = args.pop(position).lower()
        real_command = self.command_aliases.get(command, command)
        if real_command not in self.commands:
            self.invalid('COMMAND %s unknown' % command)
        command_class = self.commands[real_command]
        command_class(invoked_as, command, args, self).run()

    def register(self, command):
        """Record ``command`` under its ``name`` and each of its aliases."""
        canonical = command.name
        self.commands[canonical] = command
        self.command_aliases.update(
            (alias, canonical) for alias in command.aliases)

    def invalid(self, msg, code=2):
        """Print ``msg`` and exit the process with ``code``."""
        print(msg)
        sys.exit(code)
# Module-wide runner.  Command subclasses add themselves to it through
# register() when they are defined (see Command.__classinit__).
the_runner = CommandRunner()
register = the_runner.register
class _FreshListDefaultsParser(optparse.OptionParser):
    """
    OptionParser that gives every parse_args() call its own copy of each
    list-valued default.

    optparse only shallow-copies ``self.defaults``.  Without this,
    ``action='append'`` options declared with ``default=[]`` would append
    into the single list stored on the parser, so values would leak
    between parses of the same (class-level, shared) parser.
    """

    def get_default_values(self):
        # Explicit base call: optparse classes are old-style on Python 2.
        values = optparse.OptionParser.get_default_values(self)
        for dest, default in list(vars(values).items()):
            if isinstance(default, list):
                setattr(values, dest, list(default))
        return values


def standard_parser(connection=True, simulate=True,
                    interactive=False, find_modules=True):
    """
    Build the option parser shared by the commands.

    Always adds -v/--verbose, -f/--config-file and --egg.  The keyword
    flags add, respectively, -c/--connection, -n/--simulate,
    -i/--interactive and the class-selection options
    (-m/--module, -p/--package, --class).
    """
    parser = _FreshListDefaultsParser()
    parser.add_option('-v', '--verbose',
                      help='Be verbose (multiple times for more verbosity)',
                      action='count',
                      dest='verbose',
                      default=0)
    if simulate:
        parser.add_option('-n', '--simulate',
                          help="Don't actually do anything (implies -v)",
                          action='store_true',
                          dest='simulate')
    if connection:
        parser.add_option('-c', '--connection',
                          help="The database connection URI",
                          metavar='URI',
                          dest='connection_uri')
    parser.add_option('-f', '--config-file',
                      help="The Paste config file "
                      "that contains the database URI (in the database key)",
                      metavar="FILE",
                      dest="config_file")
    if find_modules:
        parser.add_option('-m', '--module',
                          help="Module in which to find SQLObject classes",
                          action='append',
                          metavar='MODULE',
                          dest='modules',
                          default=[])
        parser.add_option('-p', '--package',
                          help="Package to search for SQLObject classes",
                          action="append",
                          metavar="PACKAGE",
                          dest="packages",
                          default=[])
        parser.add_option('--class',
                          help="Select only named classes (wildcards allowed)",
                          action="append",
                          metavar="NAME",
                          dest="class_matchers",
                          default=[])
    if interactive:
        parser.add_option('-i', '--interactive',
                          help="Ask before doing anything "
                          "(use twice to be more careful)",
                          action="count",
                          dest="interactive",
                          default=0)
    parser.add_option('--egg',
                      help="Select modules from the given Egg, "
                      "using sqlobject.txt",
                      action="append",
                      metavar="EGG_SPEC",
                      dest="eggs",
                      default=[])
    return parser
class Command(with_metaclass(DeclarativeMeta, object)):
    """
    Base class for the sub-commands dispatched by ``CommandRunner``.

    Concrete subclasses are registered with the module-level runner when
    they are defined (see ``__classinit__``).  They provide ``name``,
    ``summary``, ``parser`` (an optparse parser) and implement
    ``command()``.  They may also override ``aliases``, ``help``,
    ``description``, the argument limits and ``required_args``.
    """

    # Positional-argument limits enforced by run(); None disables a check.
    min_args = 0
    min_args_error = 'You must provide at least %(min_args)s arguments'
    max_args = 0
    max_args_error = 'You must provide no more than %(max_args)s arguments'
    # Extra names under which the runner accepts this command.
    aliases = ()
    # (options attribute, option name) pairs that run() insists are set.
    required_args = []
    # Copied to parser.description by run() when set.
    description = None
    # Free-form help; run() wraps it and appends it to the usage text.
    help = ''

    def orderClassesByDependencyLevel(self, classes):
        """
        Return classes ordered by their depth in the class dependency
        tree (this is *not* the inheritance tree), from the
        top level (independent) classes to the deepest level.
        The dependency tree is defined by the foreign key relations.

        Classes on the same level keep their relative input order.  If a
        circular reference is detected, a warning is printed and
        ``classes`` is returned unchanged.
        """
        # @@: written as a self-contained function for now, to prevent
        # having to modify any core SQLObject component and namespace
        # contamination.
        # yemartin - 2006-08-08
        class SQLObjectCircularReferenceError(Exception):
            pass

        def findReverseDependencies(cls):
            """
            Return a list of classes that cls depends on. Note that
            "depends on" here mean "has a foreign key pointing to".
            """
            depended = []
            for _col in cls.sqlmeta.columnList:
                if _col.foreignKey:
                    other = findClass(_col.foreignKey,
                                      _col.soClass.sqlmeta.registry)
                    # Self-references do not affect the ordering.
                    if (other is not cls) and (other not in depended):
                        depended.append(other)
            return depended

        # Cache to save already calculated dependency levels.
        dependency_levels = {}

        def calculateDependencyLevel(cls, dependency_stack):
            """
            Recursively calculate the dependency level of cls.

            ``dependency_stack`` is the chain of classes currently being
            resolved (root first).  Meeting ``cls`` on it again means a
            circular reference.
            """
            # Return value from the cache if already calculated
            if cls in dependency_levels:
                return dependency_levels[cls]
            if cls in dependency_stack:
                path = dependency_stack + [cls]
                raise SQLObjectCircularReferenceError(
                    "Found a circular reference: %s " %
                    (' --> '.join([x.__name__ for x in path])))
            dependency_stack.append(cls)
            try:
                depended = findReverseDependencies(cls)
                if depended:
                    level = max([calculateDependencyLevel(x, dependency_stack)
                                 for x in depended]) + 1
                else:
                    level = 0
            finally:
                # Pop on the way out so the stack (and any error message)
                # reflects only the current resolution path.
                dependency_stack.pop()
            dependency_levels[cls] = level
            return level

        # Now simply calculate and sort by dependency levels:
        try:
            sorter = [(calculateDependencyLevel(cls, []), cls)
                      for cls in classes]
        except SQLObjectCircularReferenceError as msg:
            # Failsafe: return the classes as-is if a circular reference
            # prevented the dependency levels to be calculated.
            print("Warning: a circular reference was detected in the "
                  "model. Unable to sort the classes by dependency: they "
                  "will be treated in alphabetic order. This may or may "
                  "not work depending on your database backend. "
                  "The error was:\n%s" % msg)
            return classes
        # list.sort is stable: classes on one level keep their input order.
        sorter.sort(key=lambda x: x[0])
        return [cls for _, cls in sorter]

    def __classinit__(cls, new_args):
        # Presumably invoked by DeclarativeMeta when a subclass is created;
        # registers every concrete command with the module-level runner.
        if cls.__bases__ == (object,):
            # This abstract base class
            return
        register(cls)

    def __init__(self, invoked_as, command_name, args, runner):
        self.invoked_as = invoked_as
        self.command_name = command_name
        self.raw_args = args
        self.runner = runner

    def run(self):
        """
        Parse ``self.raw_args``, validate them, apply config-file and egg
        settings, then call ``self.command()``.

        Invalid input is reported through ``self.runner.invalid()``.
        """
        self.parser.usage = "%%prog [options]\n%s" % self.summary
        if self.help:
            try:
                columns = int(os.environ.get('COLUMNS', 80))
            except ValueError:
                # Ignore a malformed $COLUMNS instead of crashing.
                columns = 80
            help_text = textwrap.fill(self.help, columns - 4)
            self.parser.usage += '\n' + help_text
        self.parser.prog = '%s %s' % (
            os.path.basename(self.invoked_as),
            self.command_name)
        if self.description:
            self.parser.description = self.description
        self.options, self.args = self.parser.parse_args(self.raw_args)
        # --simulate implies at least -v.
        if (getattr(self.options, 'simulate', False)
                and not self.options.verbose):
            self.options.verbose = 1
        if self.min_args is not None and len(self.args) < self.min_args:
            self.runner.invalid(
                self.min_args_error % {'min_args': self.min_args,
                                       'actual_args': len(self.args)})
        if self.max_args is not None and len(self.args) > self.max_args:
            self.runner.invalid(
                self.max_args_error % {'max_args': self.max_args,
                                       'actual_args': len(self.args)})
        for var_name, option_name in self.required_args:
            if not getattr(self.options, var_name, None):
                self.runner.invalid(
                    'You must provide the option %s' % option_name)
        conf = self.config()
        if conf and conf.get('sys_path'):
            update_sys_path(conf['sys_path'], self.options.verbose)
        if conf and conf.get('database'):
            conn = sqlobject.connectionForURI(conf['database'])
            sqlobject.sqlhub.processConnection = conn
        for egg_spec in getattr(self.options, 'eggs', []):
            self.load_options_from_egg(egg_spec)
        self.command()

    def classes(self, require_connection=True,
                require_some=False):
        """
        Collect the SQLObject classes selected by --module, --package,
        --egg and --class.  If a connection was given, bind it to every
        class.  Return the classes ordered by dependency level.

        :param require_connection: exit via ``runner.invalid`` when no
            connection was given and some class has none.
        :param require_some: print diagnostics and exit with status 1
            when no class was found.
        """
        found = []
        for module_name in self.options.modules:
            found.extend(self.classes_from_module(
                moduleloader.load_module(module_name)))
        for package_name in self.options.packages:
            found.extend(self.classes_from_package(package_name))
        for egg_spec in self.options.eggs:
            found.extend(self.classes_from_egg(egg_spec))
        if self.options.class_matchers:
            found = [soClass for soClass in found
                     if any(fnmatch.fnmatch(soClass.__name__, matcher)
                            for matcher in self.options.class_matchers)]
        conn = self.connection()
        if conn:
            for soClass in found:
                soClass._connection = conn
        else:
            missing = []
            for soClass in found:
                try:
                    if not soClass._connection:
                        missing.append(soClass)
                except AttributeError:
                    missing.append(soClass)
            if missing and require_connection:
                self.runner.invalid(
                    'These classes do not have connections set:\n  * %s\n'
                    'You must indicate --connection=URI'
                    % '\n  * '.join([soClass.__name__
                                     for soClass in missing]))
        if require_some and not found:
            print('No classes found!')
            if self.options.modules:
                print('Looked in modules: %s' %
                      ', '.join(self.options.modules))
            else:
                print('No modules specified')
            if self.options.packages:
                print('Looked in packages: %s' %
                      ', '.join(self.options.packages))
            else:
                print('No packages specified')
            if self.options.class_matchers:
                # Was options.class_matches (AttributeError).
                print('Matching class pattern: %s' %
                      ', '.join(self.options.class_matchers))
            if self.options.eggs:
                print('Looked in eggs: %s' % ', '.join(self.options.eggs))
            else:
                print('No eggs specified')
            sys.exit(1)
        return self.orderClassesByDependencyLevel(found)

    def classes_from_module(self, module):
        """
        Return the SQLObject classes of ``module``.

        If the module defines ``soClasses`` (classes or attribute names),
        exactly those are returned.  Otherwise every SQLObject subclass
        defined in the module itself (not merely imported into it) is
        returned.
        """
        found = []
        if hasattr(module, 'soClasses'):
            for name_or_class in module.soClasses:
                # string_type (sqlobject.compat) presumably also accepts
                # unicode names on Python 2.
                if isinstance(name_or_class, string_type):
                    name_or_class = getattr(module, name_or_class)
                found.append(name_or_class)
        else:
            for name in dir(module):
                value = getattr(module, name)
                if (isinstance(value, type)
                        and issubclass(value, sqlobject.SQLObject)
                        and value.__module__ == module.__name__):
                    found.append(value)
        return found

    def connection(self):
        """
        Return a connection built from the config file's ``database`` key
        or from --connection, or None if neither is given.
        """
        config = self.config()
        if config is not None:
            assert config.get('database'), (
                "No database variable found in config file %s"
                % self.options.config_file)
            return sqlobject.connectionForURI(config['database'])
        elif getattr(self.options, 'connection_uri', None):
            return sqlobject.connectionForURI(self.options.connection_uri)
        else:
            return None

    def config(self):
        """
        Return the configuration named by --config-file, or None.

        Uses Paste Deploy's ``appconfig`` when it is importable (plain
        paths get a ``config:`` prefix); otherwise ``ini_config``.
        """
        if not getattr(self.options, 'config_file', None):
            return None
        config_file = self.options.config_file
        if appconfig:
            if (not config_file.startswith('egg:')
                    and not config_file.startswith('config:')):
                config_file = 'config:' + config_file
            return appconfig(config_file,
                             relative_to=os.getcwd())
        else:
            return self.ini_config(config_file)

    def ini_config(self, conf_fn):
        """
        Read one section of an INI file and return its options as a dict.

        The section is ``main`` unless given as ``file#section``.  A
        section matches when its lowercased name equals the requested
        name, or is ``app:NAME`` / ``application:NAME``.  A ``here``
        default holding the file's directory is added.

        :raises OSError: if the file is missing, or if zero or several
            sections match.
        """
        conf_section = 'main'
        if '#' in conf_fn:
            conf_fn, conf_section = conf_fn.split('#', 1)
        try:
            from ConfigParser import ConfigParser
        except ImportError:
            from configparser import ConfigParser
        p = ConfigParser()
        # Case-sensitive:
        p.optionxform = str
        if not os.path.exists(conf_fn):
            # RawConfigParser doesn't give an error for non-existent files:
            raise OSError(
                "Config file %s does not exist" % self.options.config_file)
        p.read([conf_fn])
        # defaults() returns the parser's live defaults dict.
        p.defaults().setdefault(
            'here', os.path.dirname(os.path.abspath(conf_fn)))
        possible_sections = []
        for section in p.sections():
            name = section.strip().lower()
            if (conf_section == name
                or (conf_section == name.split(':')[-1]
                    and name.split(':')[0] in ('app', 'application'))):
                possible_sections.append(section)
        if not possible_sections:
            raise OSError(
                "Config file %s does not have a section [%s] or [*:%s]"
                % (conf_fn, conf_section, conf_section))
        if len(possible_sections) > 1:
            raise OSError(
                "Config file %s has multiple sections matching %s: %s"
                % (conf_fn, conf_section, ', '.join(possible_sections)))
        config = {}
        for op in p.options(possible_sections[0]):
            config[op] = p.get(possible_sections[0], op)
        return config

    def classes_from_package(self, package_name):
        """
        Import every ``.py`` module below ``package_name`` (except
        ``__init__.py``) and return the SQLObject classes found in them.
        Modules that fail to import are skipped and reported with -v.
        """
        found = []
        package = moduleloader.load_module(package_name)
        package_dir = os.path.dirname(package.__file__)
        for dir_name, dir_names, file_names in os.walk(package_dir):
            # Don't descend into VCS/hidden directories (.svn, .git, ...);
            # their names are not valid Python package names anyway.
            dir_names[:] = [d for d in dir_names if not d.startswith('.')]
            for fname in file_names:
                if not fname.endswith('.py') or fname == '__init__.py':
                    continue
                # Derive the dotted name from the path *relative to the
                # package*, so dotted package names and directories that
                # repeat the package name resolve correctly.
                rel_path = os.path.relpath(os.path.join(dir_name, fname),
                                           package_dir)
                module_name = '%s.%s' % (
                    package_name, rel_path[:-3].replace(os.path.sep, '.'))
                try:
                    module = moduleloader.load_module(module_name)
                except ImportError as err:
                    if self.options.verbose:
                        print('Could not import module "%s". '
                              'Error was : "%s"' % (module_name, err))
                    continue
                except Exception as exc:
                    if self.options.verbose:
                        print('Unknown exception while processing module '
                              '"%s" : "%s"' % (module_name, exc))
                    continue
                found.extend(self.classes_from_module(module))
        return found

    def classes_from_egg(self, egg_spec):
        """
        Return the classes of the modules listed (comma-separated) under
        ``db_module`` in the egg's sqlobject.txt.
        """
        modules = []
        dist, conf = self.config_from_egg(egg_spec, warn_no_sqlobject=True)
        for mod in conf.get('db_module', '').split(','):
            mod = mod.strip()
            if not mod:
                continue
            if self.options.verbose:
                print('Looking in module %s' % mod)
            modules.extend(self.classes_from_module(
                moduleloader.load_module(mod)))
        return modules

    def load_options_from_egg(self, egg_spec):
        """
        Fill an unset ``output_dir`` option from the egg's ``history_dir``,
        replacing ``$base`` with the distribution location.
        """
        dist, conf = self.config_from_egg(egg_spec)
        if (hasattr(self.options, 'output_dir')
                and not self.options.output_dir and conf.get('history_dir')):
            history_dir = conf['history_dir']
            history_dir = history_dir.replace('$base', dist.location)
            self.options.output_dir = history_dir

    def config_from_egg(self, egg_spec, warn_no_sqlobject=True):
        """
        Return ``(dist, settings)`` parsed from the egg's sqlobject.txt
        ``name = value`` lines (names lowercased).  Returns ``(None, {})``
        when the egg has no sqlobject.txt.
        """
        import pkg_resources
        dist = pkg_resources.get_distribution(egg_spec)
        if not dist.has_metadata('sqlobject.txt'):
            if warn_no_sqlobject:
                print('No sqlobject.txt in %s egg info' % egg_spec)
            return None, {}
        result = {}
        for line in dist.get_metadata_lines('sqlobject.txt'):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if '=' not in line:
                # Skip malformed lines instead of crashing on unpacking.
                print('Warning: ignoring malformed line in sqlobject.txt: '
                      '%r' % line)
                continue
            name, value = line.split('=', 1)
            name = name.strip().lower()
            if name in result:
                print('Warning: %s appears more than once '
                      'in sqlobject.txt' % name)
            result[name] = value.strip()
        return dist, result

    def command(self):
        """Do the actual work of the command; subclasses must override."""
        raise NotImplementedError

    def _get_prog_name(self):
        return os.path.basename(self.invoked_as)
    prog_name = property(_get_prog_name)

    def ask(self, prompt, safe=False, default=True):
        """
        Ask a yes/no question and return the answer as a bool.

        An empty answer returns ``default``.  With -i given twice,
        ``default`` is replaced by ``safe``.
        """
        if self.options.interactive >= 2:
            default = safe
        if default:
            prompt += ' [Y/n]? '
        else:
            prompt += ' [y/N]? '
        while 1:
            response = input(prompt).strip()
            if not response:
                return default
            if response[0].lower() in ('y', 'n'):
                return response[0].lower() == 'y'
            print('Y or N please')

    def shorten_filename(self, fn):
        """
        Shortens a filename to make it relative to the current
        directory (if it can).  For display purposes.
        """
        cwd_prefix = os.getcwd() + os.sep
        if fn.startswith(cwd_prefix):
            fn = fn[len(cwd_prefix):]
        return fn

    def open_editor(self, pretext, breaker=None, extension='.txt'):
        """
        Open $EDITOR on a temporary file holding ``pretext``.

        Returns the edited text, or None if it is unchanged or blank.  If
        ``breaker`` is given, everything after it is ignored in both texts.
        The temporary file is removed afterwards.
        """
        import tempfile
        try:
            from shlex import quote
        except ImportError:  # Python 2
            from pipes import quote
        # mkstemp creates the file atomically (unlike tempnam), with the
        # requested extension so editors can pick a mode from it.
        fd, fn = tempfile.mkstemp(suffix=extension)
        try:
            with os.fdopen(fd, 'w') as f:
                f.write(pretext)
            print('$EDITOR %s' % fn)
            os.system('$EDITOR %s' % quote(fn))
            with open(fn, 'r') as f:
                content = f.read()
        finally:
            try:
                os.remove(fn)
            except OSError:
                pass
        if breaker:
            content = content.split(breaker)[0]
            pretext = pretext.split(breaker)[0]
        if content == pretext or not content.strip():
            return None
        return content
class CommandSQL(Command):
    """Print the CREATE TABLE statements and constraints for the classes."""

    name = 'sql'
    summary = 'Show SQL CREATE statements'
    parser = standard_parser(simulate=False)

    def command(self):
        """Print each table's CREATE statement, then all constraints."""
        pending = []
        for so_class in self.classes():
            if self.options.verbose >= 1:
                print('-- %s from %s' % (so_class.__name__,
                                         so_class.__module__))
            create_sql, table_constraints = so_class.createTableSQL()
            print(create_sql.strip() + ';\n')
            pending.append(table_constraints)
        # Constraints are emitted only after every CREATE statement
        # (presumably so foreign keys may reference any listed table).
        for statement in (c for group in pending if group for c in group):
            if statement:
                print(statement.strip() + ';\n')
class CommandList(Command):
    """List the SQLObject classes found; no connection is required."""

    name = 'list'
    summary = 'Show all SQLObject classes found'
    parser = standard_parser(simulate=False, connection=False)

    def command(self):
        """Print ``module.Class`` for each class (and its table with -v)."""
        verbose = self.options.verbose >= 1
        if verbose:
            print('Classes found:')
        for so_class in self.classes(require_connection=False):
            print('%s.%s' % (so_class.__module__, so_class.__name__))
            if verbose:
                print('  Table: %s' % so_class.sqlmeta.table)
class CommandCreate(Command):
    """Create missing tables (and, with --create-db, the database)."""

    name = 'create'
    summary = 'Create tables'
    parser = standard_parser(interactive=True)
    parser.add_option('--create-db',
                      action='store_true',
                      dest='create_db',
                      help="Create the database")

    def command(self):
        """
        Create every missing table without constraints, then apply the
        collected constraints per connection.  Honors --simulate and
        --interactive.
        """
        verbosity = self.options.verbose
        created = 0
        existing = 0
        dbs_created = []
        # connection -> list of constraint lists returned by createTable
        constraints = {}
        for soClass in self.classes(require_some=True):
            conn = soClass._connection
            if self.options.create_db and conn not in dbs_created:
                if self.options.simulate:
                    print('(simulating; cannot create database)')
                else:
                    try:
                        conn.createEmptyDatabase()
                    except conn.module.ProgrammingError as e:
                        if 'already exists' not in str(e):
                            raise
                        print('Database already exists')
                dbs_created.append(conn)
            pending = constraints.setdefault(conn, [])
            exists = conn.tableExists(soClass.sqlmeta.table)
            if verbosity >= 1:
                if exists:
                    existing += 1
                    print('%s already exists.' % soClass.__name__)
                else:
                    print('Creating %s' % soClass.__name__)
            if verbosity >= 2:
                sql, _extra = soClass.createTableSQL()
                print(sql)
            if self.options.simulate or exists:
                continue
            if (self.options.interactive
                    and not self.ask('Create %s' % soClass.__name__)):
                print('Cancelled')
                continue
            created += 1
            table_constraints = soClass.createTable(applyConstraints=False)
            if table_constraints:
                pending.append(table_constraints)
        for conn, constraint_lists in constraints.items():
            if verbosity >= 2:
                print('Creating constraints')
            for constraint_list in constraint_lists:
                for constraint in constraint_list:
                    if constraint:
                        conn.query(constraint)
        if verbosity >= 1:
            print('%i tables created (%i already exist)' % (
                  created, existing))
class CommandDrop(Command):
    """Drop the tables of the selected classes."""

    name = 'drop'
    summary = 'Drop tables'
    parser = standard_parser(interactive=True)

    def command(self):
        """
        Drop existing tables, honoring --simulate and --interactive.

        Classes are visited in reverse dependency order, so the deepest
        dependents come first.
        """
        verbosity = self.options.verbose
        dropped = 0
        not_existing = 0
        for soClass in reversed(self.classes()):
            exists = soClass._connection.tableExists(soClass.sqlmeta.table)
            if verbosity >= 1:
                if not exists:
                    not_existing += 1
                    print('%s does not exist.' % soClass.__name__)
                else:
                    print('Dropping %s' % soClass.__name__)
            if self.options.simulate or not exists:
                continue
            if (self.options.interactive
                    and not self.ask('Drop %s' % soClass.__name__)):
                print('Cancelled')
                continue
            dropped += 1
            soClass.dropTable()
        if verbosity >= 1:
            print('%i tables dropped (%i didn\'t exist)' % (
                  dropped, not_existing))
class CommandStatus(Command):
    """Report which classes are out of sync with their database tables."""

    name = 'status'
    summary = 'Show status of classes vs. database'
    help = ('This command checks the SQLObject definition and checks if '
            'the tables in the database match.  It can always test for '
            'missing tables, and on some databases can test for the '
            'existance of other tables.  Column types are not currently '
            'checked.')
    parser = standard_parser(simulate=False)

    def print_class(self, soClass):
        """Print the 'Checking ...' header, at most once per class."""
        if not self.printed:
            self.printed = True
            print('Checking %s...' % soClass.__name__)

    def command(self):
        """Compare every class with its table and print the differences."""
        good = bad = missing_tables = 0
        warned_no_schema = False
        for soClass in self.classes(require_some=True):
            conn = soClass._connection
            table = soClass.sqlmeta.table
            # Reset the per-class header flag used by print_class().
            self.printed = False
            if self.options.verbose:
                self.print_class(soClass)
            if not conn.tableExists(table):
                self.print_class(soClass)
                print('  Does not exist in database')
                missing_tables += 1
                continue
            try:
                schema_columns = conn.columnsFromSchema(table, soClass)
            except AttributeError:
                # Backend cannot introspect columns: count as in sync,
                # and warn only once.
                if not warned_no_schema:
                    print('Database does not support reading columns')
                    warned_no_schema = True
                good += 1
                continue
            except AssertionError as e:
                print('Cannot read db table %s: %s' % (table, e))
                continue
            extra = {}
            for schema_col in schema_columns:
                schema_col = schema_col.withClass(soClass)
                extra[schema_col.dbName] = schema_col
            missing = {}
            for defined in soClass.sqlmeta.columnList:
                if defined.dbName in extra:
                    del extra[defined.dbName]
                else:
                    missing[defined.dbName] = defined
            if extra:
                self.print_class(soClass)
                for schema_col in extra.values():
                    print('  Database has extra column: %s'
                          % schema_col.dbName)
            if missing:
                self.print_class(soClass)
                for defined in missing.values():
                    print('  Database missing column: %s' % defined.dbName)
            if extra or missing:
                bad += 1
            else:
                good += 1
        if self.options.verbose:
            print('%i in sync; %i out of sync; %i not in database' % (
                  good, bad, missing_tables))
class CommandHelp(Command):
    """List all commands, or show the help of one named command."""

    name = 'help'
    summary = 'Show help'
    parser = optparse.OptionParser()
    max_args = 1

    def command(self):
        """Print the command table, or delegate to ``COMMAND -h``."""
        if self.args:
            the_runner.run([self.invoked_as, self.args[0], '-h'])
            return
        print('Available commands:')
        print('  (use "%s help COMMAND" or "%s COMMAND -h" ' % (
              self.prog_name, self.prog_name))
        print('  for more information)')
        entries = sorted(the_runner.commands.items())
        width = max(len(command_name) for command_name, _ in entries)
        for command_name, command_class in entries:
            padding = ' ' * (width - len(command_name))
            print('%s:%s %s' % (command_name, padding, command_class.summary))
            if command_class.aliases:
                print('%s (Aliases: %s)' % (
                    ' ' * width, ', '.join(command_class.aliases)))
class CommandExecute(Command):
    """Run SQL statements given on the command line (and/or stdin).

    For each statement that returns rows, a tab-separated header of column
    names is written to stdout, followed by one tab-separated line of
    ``repr`` values per row.
    """

    name = 'execute'
    summary = 'Execute SQL statements'
    help = ('Runs SQL statements directly in the database, with no '
            'intervention.  Useful when used with a configuration file.  '
            'Each argument is executed as an individual statement.')
    parser = standard_parser(find_modules=False)
    parser.add_option('--stdin',
                      help="Read SQL from stdin "
                      "(normally takes SQL from the command line)",
                      dest="use_stdin",
                      action="store_true")
    max_args = None

    def command(self):
        """Execute each positional argument, then stdin if ``--stdin``."""
        # Work on a copy so that reading stdin does not mutate self.args.
        args = list(self.args)
        if self.options.use_stdin:
            if self.options.verbose:
                print("Reading additional SQL from stdin "
                      "(Ctrl-D or Ctrl-Z to finish)...")
            args.append(sys.stdin.read())
        self.conn = self.connection().getConnection()
        self.cursor = self.conn.cursor()
        for sql in args:
            self.execute_sql(sql)

    def execute_sql(self, sql):
        """Execute one statement on ``self.cursor`` and print its result.

        An exception from ``execute`` is reported and swallowed, so the
        caller can go on with the remaining statements.
        """
        if self.options.verbose:
            print(sql)
        try:
            self.cursor.execute(sql)
        except Exception as e:
            if not self.options.verbose:
                print(sql)
            print("****Error:")
            print('    ', e)
            return
        desc = self.cursor.description
        # PEP 249: description is None when the statement produced no
        # result set, and fetchall() is then allowed to raise (psycopg2,
        # for one, raises "no results to fetch" after INSERT/CREATE).
        rows = self.cursor.fetchall() if desc else []
        if self.options.verbose:
            if not self.cursor.rowcount:
                print("No rows accessed")
            else:
                print("%i rows accessed" % self.cursor.rowcount)
        if desc:
            # Item 0 of each PEP 249 column description is the column name.
            for column in desc:
                sys.stdout.write("%s\t" % column[0])
            sys.stdout.write("\n")
        for row in rows:
            for value in row:
                sys.stdout.write("%r\t" % value)
            sys.stdout.write("\n")
        print()
class CommandRecord(Command):
    """Record the current table definitions as a new schema version.

    Each run writes one ``<Class>_<dbName>.sql`` file per class into a new
    date-named directory below the history directory (``--output-dir``,
    else the ``sqlobject_history_dir`` config value, else ``.``).  That
    directory's name is the version name.  Unless ``--no-db-record`` is
    given, the version is also stored in each database's version table
    when the database matches the recorded schema.
    """

    name = 'record'
    summary = 'Record historical information about the database status'
    help = ('Record state of table definitions.  The state of each '
            'table is written out to a separate file in a directory, '
            'and that directory forms a "version".  A table is also '
            'added to your database (%s) that reflects the version the '
            'database is currently at.  Use the upgrade command to '
            'sync databases with code.'
            % SQLObjectVersionTable.sqlmeta.table)
    parser = standard_parser()
    parser.add_option('--output-dir',
                      help="Base directory for recorded definitions",
                      dest="output_dir",
                      metavar="DIR",
                      default=None)
    parser.add_option('--no-db-record',
                      help="Don't record version to database",
                      dest="db_record",
                      action="store_false",
                      default=True)
    parser.add_option('--force-create',
                      help="Create a new version even if appears to be "
                      "identical to the last version",
                      action="store_true",
                      dest="force_create")
    parser.add_option('--name',
                      help="The name to append to the version.  The "
                      "version should sort after previous versions (so "
                      "any versions from the same day should come "
                      "alphabetically before this version).",
                      dest="version_name",
                      metavar="NAME")
    parser.add_option('--force-db-version',
                      help="Update the database version, and include no "
                      "database information.  This is for databases that "
                      "were developed without any interaction with "
                      "this tool, to create a 'beginning' revision.",
                      metavar="VERSION_NAME",
                      dest="force_db_version")
    parser.add_option('--edit',
                      help="Open an editor for the upgrader in the last "
                      "version (using $EDITOR).",
                      action="store_true",
                      dest="open_editor")
    # Version directory names start with a YYYY-MM-DD date.
    version_regex = re.compile(r'^\d\d\d\d-\d\d-\d\d')

    def command(self):
        """Write a new version directory and record it in the database.

        If the generated definitions are identical (ignoring comments) to
        the newest existing version, nothing is written unless
        ``--force-create`` is given.
        """
        if self.options.force_db_version:
            self.command_force_db_version()
            return
        v = self.options.verbose
        sim = self.options.simulate
        classes = self.classes()
        if not classes:
            print("No classes found!")
            return
        output_dir = self.find_output_dir()
        version = os.path.basename(output_dir)
        print("Creating version %s" % version)
        conns, files, db_name = self._table_definitions(classes, sim)
        last_version_dir = self.find_last_version()
        if last_version_dir and not self.options.force_create:
            if self._matches_version(files, last_version_dir):
                print("Current status matches version %s"
                      % os.path.basename(last_version_dir))
                return
            if v:
                print("Current state does not match %s"
                      % os.path.basename(last_version_dir))
        if v > 1 and not last_version_dir:
            print("No last version to check")
        if not sim:
            os.mkdir(output_dir)
        if v:
            print('Making directory %s' % self.shorten_filename(output_dir))
        for fn, content in sorted(files.items()):
            if v:
                print('  Writing %s' % self.shorten_filename(fn))
            if not sim:
                with open(os.path.join(output_dir, fn), 'w') as f:
                    f.write(content)
        all_diffs = []
        if self.options.db_record:
            all_diffs = self._database_differences(classes, conns)
            if all_diffs:
                # The database is not at this version yet.  Stamping it
                # anyway would make "upgrade" believe it is up to date and
                # skip the upgrader from the previous version.
                print('Database does not match schema:')
                print('\n'.join(all_diffs))
                print('Not recording version %s in the database' % version)
            else:
                for conn in conns:
                    self.update_db(version, conn)
        if self.options.open_editor:
            self._edit_upgrader(last_version_dir, db_name, version,
                                all_diffs)

    def _table_definitions(self, classes, sim):
        """Build the ``.sql`` definition file for each class.

        Returns ``(conns, files, db_name)``:
        - ``conns``: the distinct connections of *classes*, in first-seen
          order.
        - ``files``: maps file name to content; it is left empty when
          *sim* is true.
        - ``db_name``: the ``dbName`` of the last class.
        """
        conns = []
        files = {}
        db_name = None
        for cls in classes:
            db_name = cls._connection.dbName
            if cls._connection not in conns:
                conns.append(cls._connection)
            if sim:
                continue
            create, constraints = cls.createTableSQL()
            if constraints:
                constraints = '\n-- Constraints:\n%s\n' % (
                    '\n'.join(constraints))
            else:
                constraints = ''
            fn = cls.__name__ + '_' + db_name + '.sql'
            files[fn] = ''.join([
                '-- Exported definition from %s\n'
                % time.strftime('%Y-%m-%dT%H:%M:%S'),
                '-- Class %s.%s\n'
                % (cls.__module__, cls.__name__),
                '-- Database: %s\n'
                % db_name,
                create.strip(),
                '\n',
                constraints])
        return conns, files, db_name

    def _matches_version(self, files, version_dir):
        """Return True if *files* equal the ``.sql`` files in *version_dir*.

        Contents are compared after ``strip_comments``.  With verbose > 1
        the reason for a mismatch is printed.
        """
        v = self.options.verbose
        if v > 1:
            print("Checking %s to see if it is current" % version_dir)
        remaining = files.copy()
        for fn in os.listdir(version_dir):
            if not fn.endswith('.sql'):
                continue
            if fn not in remaining:
                if v > 1:
                    print("Missing file %s" % fn)
                return False
            with open(os.path.join(version_dir, fn), 'r') as f:
                content = f.read()
            if (self.strip_comments(remaining[fn])
                    != self.strip_comments(content)):
                if v > 1:
                    print("Content does not match: %s" % fn)
                return False
            del remaining[fn]
        if not remaining:
            return True
        if v > 1:
            print("Extra files: %s" % ', '.join(remaining.keys()))
        return False

    def _database_differences(self, classes, conns):
        """Return one formatted line per difference between *classes* and
        the tables found on each of *conns* (see ``db_differences``)."""
        all_diffs = []
        for cls in classes:
            for conn in conns:
                for diff in db_differences(cls, conn):
                    if len(conns) > 1:
                        diff = '  (%s).%s: %s' % (
                            conn.uri(), cls.sqlmeta.table, diff)
                    else:
                        diff = '  %s: %s' % (cls.sqlmeta.table, diff)
                    all_diffs.append(diff)
        return all_diffs

    def _edit_upgrader(self, last_version_dir, db_name, version, all_diffs):
        """Open an editor for the upgrader from *last_version_dir* to
        *version*.

        The schema differences are placed below a breaker line in the
        editor text.  Non-None editor output is saved as
        ``<last_version_dir>/upgrade_<db_name>_<version>.sql``.
        """
        if not last_version_dir:
            print("Cannot edit upgrader because there is no "
                  "previous version")
            return
        breaker = ('-' * 20 + ' lines below this will be ignored '
                   + '-' * 20)
        pre_text = breaker + '\n' + '\n'.join(all_diffs)
        text = self.open_editor('\n\n' + pre_text, breaker=breaker,
                                extension='.sql')
        if text is None:
            return
        fn = os.path.join(last_version_dir,
                          'upgrade_%s_%s.sql' % (db_name, version))
        with open(fn, 'w') as f:
            f.write(text)
        print('Wrote to %s' % fn)

    def update_db(self, version, conn):
        """Make *version* the only row of the version table on *conn*.

        The table is created first if it does not exist.  Nothing is
        written to the database when simulating.
        """
        v = self.options.verbose
        table = SQLObjectVersionTable.sqlmeta.table
        if not conn.tableExists(table):
            if v:
                print('Creating table %s' % table)
            if v > 1:
                # createTableSQL returns (create_sql, constraints); print
                # the SQL itself rather than the tuple's repr.
                create, constraints = SQLObjectVersionTable.createTableSQL(
                    connection=conn)
                print(create)
                for constraint in constraints:
                    print(constraint)
            if not self.options.simulate:
                SQLObjectVersionTable.createTable(connection=conn)
        if not self.options.simulate:
            SQLObjectVersionTable.clearTable(connection=conn)
            SQLObjectVersionTable(
                version=version,
                connection=conn)

    def base_dir(self):
        """Return the history directory, creating it unless simulating.

        Taken from ``--output-dir``, else the ``sqlobject_history_dir``
        config value, else the current directory.
        """
        base = self.options.output_dir
        if base is None:
            config = self.config()
            if config is not None:
                base = config.get('sqlobject_history_dir', '.')
            else:
                base = '.'
        if not os.path.exists(base):
            print('Creating history directory %s' %
                  self.shorten_filename(base))
            if not self.options.simulate:
                os.makedirs(base)
        return base

    def find_output_dir(self):
        """Return a path that does not exist yet for the new version.

        With ``--name``, the path is ``<base>/<today>-<name>``, and the
        command exits if it already exists.  Otherwise it tries
        ``<base>/<today>``, then ``<today>a``, ``<today>b``, ... until it
        finds an unused name.
        """
        today = time.strftime('%Y-%m-%d', time.localtime())
        base = self.base_dir()
        if self.options.version_name:
            path = os.path.join(base, today + '-'
                                + self.options.version_name)
            if os.path.exists(path):
                print("Error, directory already exists: %s"
                      % path)
                sys.exit(1)
            return path
        extra = ''
        while True:
            path = os.path.join(base, today + extra)
            if not os.path.exists(path):
                return path
            if not extra:
                extra = 'a'
            else:
                extra = chr(ord(extra) + 1)

    def find_last_version(self):
        """Return the path of the newest (highest-sorting) version
        directory, or None if there is none."""
        base = self.base_dir()
        # base_dir() does not create a missing history directory when
        # simulating, so it may legitimately be absent here.
        if not os.path.isdir(base):
            return None
        names = sorted(fn for fn in os.listdir(base)
                       if self.version_regex.search(fn))
        if not names:
            return None
        return os.path.join(base, names[-1])

    def command_force_db_version(self):
        """Ensure the ``--force-db-version`` directory exists and record
        that version in the database, without writing any definitions."""
        v = self.options.verbose
        sim = self.options.simulate
        version = self.options.force_db_version
        if not self.version_regex.search(version):
            print("Versions must be in the format YYYY-MM-DD...")
            print("Your version %s does not fit this" % version)
            return
        version_dir = os.path.join(self.base_dir(), version)
        if not os.path.exists(version_dir):
            if v:
                print('Creating %s' % self.shorten_filename(version_dir))
            if not sim:
                os.mkdir(version_dir)
        elif v:
            print('Directory %s exists'
                  % self.shorten_filename(version_dir))
        if self.options.db_record:
            self.update_db(version, self.connection())
class CommandUpgrade(CommandRecord):
    """Bring a database up to a recorded version by running upgraders.

    Upgrade scripts live in the directory of the version being upgraded
    FROM and are named ``upgrade_<dbName>_<targetVersion>.sql``.
    """

    name = 'upgrade'
    summary = 'Update the database to a new version (as created by record)'
    help = ('This command runs scripts (that you write by hand) to '
            'upgrade a database.  The database\'s current version is in '
            'the %s table (use record --force-db-version '
            'if a database does not have a %s table), '
            'and upgrade scripts are in the version directory you are '
            'upgrading FROM, named upgrade_DBNAME_VERSION.sql, like '
            '"upgrade_mysql_2004-12-01b.sql".'
            % ((SQLObjectVersionTable.sqlmeta.table,) * 2))
    parser = standard_parser(find_modules=False)
    parser.add_option('--upgrade-to',
                      help="Upgrade to the given version "
                      "(default: newest version)",
                      dest="upgrade_to",
                      metavar="VERSION")
    parser.add_option('--output-dir',
                      help="Base directory for recorded definitions",
                      dest="output_dir",
                      metavar="DIR",
                      default=None)
    # Groups: (1) database name, (2) target version.
    upgrade_regex = re.compile(r'^upgrade_([a-z]*)_([^.]*)\.sql$', re.I)

    def command(self):
        """Run each upgrader on the plan and record every version reached.

        The target is ``--upgrade-to`` or, by default, the newest recorded
        version.
        """
        v = self.options.verbose
        sim = self.options.simulate
        if self.options.upgrade_to:
            version_to = self.options.upgrade_to
        else:
            fname = self.find_last_version()
            if fname is None:
                print("No version exists, use 'record' command to create one")
                return
            version_to = os.path.basename(fname)
        current = self.current_version()
        if v:
            print('Current version: %s' % current)
        version_list = self.make_plan(current, version_to)
        if not version_list:
            print('Database up to date')
            return
        if v:
            print('Plan:')
            for next_version, upgrader in version_list:
                print('  Use %s to upgrade to %s' % (
                      self.shorten_filename(upgrader), next_version))
        conn = self.connection()
        for next_version, upgrader in version_list:
            with open(upgrader) as f:
                sql = f.read()
            if v:
                print("Running:")
                print(sql)
                print('-' * 60)
            if not sim:
                try:
                    conn.query(sql)
                except Exception:
                    print("Error in script: %s" % upgrader)
                    raise
            self.update_db(next_version, conn)
        print('Done.')

    def current_version(self):
        """Return the version stored in the database's version table.

        Exits if the table is missing, empty, or has more than one row.
        """
        conn = self.connection()
        table = SQLObjectVersionTable.sqlmeta.table
        if not conn.tableExists(table):
            print('No %s table!' % table)
            sys.exit(1)
        versions = list(SQLObjectVersionTable.select(connection=conn))
        if not versions:
            print('No rows in %s!' % table)
            sys.exit(1)
        if len(versions) > 1:
            print('Ambiguous %s table' % table)
            sys.exit(1)
        return versions[0].version

    def make_plan(self, current, dest):
        """Return the ``(version, upgrader_path)`` steps from *current* to
        *dest*; an empty list when they are equal.

        Each step uses ``best_upgrade``.  Exits if some step has no usable
        upgrader.
        """
        if current == dest:
            return []
        dbname = self.connection().dbName
        plan = []
        while current != dest:
            next_version, upgrader = self.best_upgrade(current, dest, dbname)
            if not upgrader:
                print('No way to upgrade from %s to %s' % (current, dest))
                print('(you need a %s/upgrade_%s_%s.sql script)'
                      % (current, dbname, dest))
                sys.exit(1)
            plan.append((next_version, upgrader))
            current = next_version
        return plan

    def best_upgrade(self, current, dest, target_dbname):
        """Pick the upgrader in *current*'s directory for *target_dbname*.

        Only upgraders whose target version is newer than *current* and
        not newer than *dest* qualify.  Of those, the one with the highest
        version wins.  Returns ``(version, path)``, or ``(None, None)`` if
        none qualifies.
        """
        current_dir = os.path.join(self.base_dir(), current)
        if self.options.verbose > 1:
            print('Looking in %s for upgraders'
                  % self.shorten_filename(current_dir))
        upgraders = []
        for fn in os.listdir(current_dir):
            match = self.upgrade_regex.search(fn)
            if not match:
                if self.options.verbose > 1:
                    print('Not an upgrade script: %s' % fn)
                continue
            dbname = match.group(1)
            version = match.group(2)
            if dbname != target_dbname:
                if self.options.verbose > 1:
                    print('Not for this database: %s (want %s)' % (
                          dbname, target_dbname))
                continue
            if version > dest:
                if self.options.verbose > 1:
                    print('Version too new: %s (only want %s)' % (
                          version, dest))
                continue
            if version <= current:
                # Would not move forward; following it could loop forever.
                if self.options.verbose > 1:
                    print('Version not newer: %s (already at %s)' % (
                          version, current))
                continue
            upgraders.append((version, os.path.join(current_dir, fn)))
        if not upgraders:
            if self.options.verbose > 1:
                print('No upgraders found in %s' % current_dir)
            return None, None
        return max(upgraders)
def update_sys_path(paths, verbose):
    """Prepend each of *paths*, made absolute, to ``sys.path`` if absent.

    *paths* may be a single string or an iterable of strings.  When
    ``verbose > 1``, each path that is actually added is reported.
    """
    path_list = [paths] if isinstance(paths, string_type) else paths
    for raw_path in path_list:
        absolute = os.path.abspath(raw_path)
        if absolute in sys.path:
            continue
        if verbose > 1:
            print('Adding %s to path' % absolute)
        sys.path.insert(0, absolute)
if __name__ == '__main__':
    # Script entry point: hand the whole command line to the runner.
    argv = sys.argv
    the_runner.run(argv)