#!/usr/bin/env python

# Copyright (c) 2011-2014. All Right Reserved, https://chartio.com/
#
# THIS CODE AND INFORMATION ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY
# KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
# PARTICULAR PURPOSE.

import ConfigParser
import datetime
import errno
import getpass
import optparse
import os
import random
import signal
import socket
import string
import subprocess
import sys
import tempfile
import time
import urllib
import urllib2

try:
    import json
except ImportError:
    try:
        import simplejson as json
    except ImportError:
        print 'Please install simplejson module'
        sys.exit(1)

from version import __version__

# Chartio service base URL. May be overridden through the CHARTIO_BASE_URL
# environment variable; any trailing slash is stripped so that request paths
# can be appended directly.
BASE_URL = os.environ.get('CHARTIO_BASE_URL', 'https://chartio.com').rstrip('/')
# Global configuration data (assigned a ConfigData instance by config_data_set())
CONFIG_DATA = None
# Default install location ('~' is expanded by ConfigData)
PREFIX_DEFAULT = '~/.chartio.d'
# Distributed version, taken from the version module
VERSION = __version__


class ConfigData(object):
    '''
    Needs to be kept in sync with chartio_connect to limit
    PYTHONPATH dependencies.
    '''
    def __init__(self, prefix=None):
        # Directories
        self.PREFIX = prefix or os.path.expanduser(PREFIX_DEFAULT)
        self.LOG_DIRECTORY = os.path.join(self.PREFIX, 'logs')
        self.RUN_DIRECTORY = os.path.join(self.PREFIX, 'run')
        self.SSH_DIRECTORY = os.path.join(self.PREFIX, 'sshkey')
        # Files
        self.CONFIG_FILE = os.path.join(self.PREFIX, 'chartio.cfg')
        self.SSH_KEY = os.path.join(self.SSH_DIRECTORY, 'id_rsa')
        self.SSH_KNOWNHOSTS = os.path.join(self.SSH_DIRECTORY, 'known_hosts')
        # Config file sections
        self.SSHTUNNEL_SECTION = 'SSHTunnel'
        # Ensure the directories exist. Exit if they do not (or cannot be created).
        self.directory_create(self.PREFIX, 0755)
        self.directory_create(self.LOG_DIRECTORY, 0755)
        self.directory_create(self.RUN_DIRECTORY, 0755)
        self.directory_create(self.SSH_DIRECTORY, 0700)

    def directory_create(self, path, mode):
        '''Create a directory if it does not exist.

        Exit if it exists and is not a directory (or symlink) or unable to create it.

        Arguments
            path -- directory path
            mode -- mode to set on directory

        '''
        if os.path.exists(path):
            if os.path.isdir(path):
                try:
                    os.chmod(path, mode)
                except Exception, exc:
                    TermColor.print_error('Failed to change mode of %r. Exiting.' % (path))
                    TermColor.print_error(str(exc))
                    sys.exit(1)
            else:
                TermColor.print_error('The path %r is not a directory. Exiting.' % (path))
                sys.exit(1)
        else:
            try:
                os.makedirs(path, mode)
            except Exception, exc:
                TermColor.print_error('Failed to create %r. Exiting.' % (path))
                TermColor.print_error(str(exc))
                sys.exit(1)


def config_data_set(prefix):
    '''Update the global config data.

    Exits when unable to create configuration directories.

    Arguments
        prefix -- The install prefix

    '''
    global CONFIG_DATA
    CONFIG_DATA = ConfigData(prefix)
    print 'Installing into %r' % (CONFIG_DATA.PREFIX)


class TermColor(object):
    '''Write text to stdout wrapped in terminal color escape sequences.'''
    # Foreground escape sequences by color name; 'bg' is the background
    # sequence emitted with every message.
    CLRS = {
        'white': '\033[37m',
        'green': '\033[92m',
        'red': '\033[91m',
        'yellow': '\033[33m',
        'bg': '\033[40m\033m'
    }

    # Sequence written after every message
    END = '\033[0m'

    @classmethod
    def print_clr(cls, color, txt, newline=True):
        '''Write txt in the named color; unknown colors add no foreground code.

        Arguments
            color -- key into CLRS
            txt -- text to write
            newline -- [optional] append a line break after the text

        '''
        foreground = cls.CLRS.get(color, '')
        background = cls.CLRS['bg']
        sys.stdout.write(foreground + background + txt + cls.END)
        if newline:
            sys.stdout.write('\n')

    @classmethod
    def print_cmd(cls, txt, newline=True):
        '''Write a command in yellow'''
        cls.print_clr('yellow', txt, newline=newline)

    @classmethod
    def print_header(cls, txt, newline=True):
        '''Write a heading or prompt in white'''
        cls.print_clr('white', txt, newline=newline)

    @classmethod
    def print_ok(cls, txt, newline=True):
        '''Write a success or informational message in green'''
        cls.print_clr('green', txt, newline=newline)

    @classmethod
    def print_error(cls, txt, newline=True):
        '''Write txt in red, prefixed with "Error: "'''
        message = 'Error: ' + txt
        cls.print_clr('red', message, newline=newline)

    @classmethod
    def print_delay(cls, txt, newline=True):
        '''Write a red "==> " marker followed by txt in green'''
        cls.print_clr('red', '==> ', newline=False)
        cls.print_ok(txt, newline=newline)


class StateTracker(object):
    '''Pseudo-state machine

    Tracks a position within an ordered list of named states. Each name is
    also exposed as an instance attribute holding the state's index, and each
    state has a phrase which is printed whenever the state changes.

    '''
    def __init__(self, *states):
        '''Constructor

        Arguments
        states -- ordered sequence of names or (name, phrase) pairs. if the former,
            the phrases are generated by lowercasing the names and converting underscores
            to spaces. the unmodified names become instance attributes with index values.

        Raises
        RuntimeError -- if an item is neither a string nor a tuple/list

        '''
        self.state = 0
        self.names = []
        self.phrases = []
        for idx, state in enumerate(states):
            if isinstance(state, basestring):
                name = state
                phrase = name.lower().replace('_', ' ')
            elif isinstance(state, (tuple, list)):
                (name, phrase) = state
            else:
                raise RuntimeError('Invalid item in state sequence: %r' % (state))
            setattr(self, name, idx)
            self.names.append(name)
            self.phrases.append(phrase)

    def advance(self):
        '''Advance the internal state.

        Prints the current state afterward.

        Raises
        RuntimeError -- if the state is already at the end

        '''
        # The last valid index is len(names) - 1. Refuse to step beyond it so
        # that the state always refers to an existing phrase.
        if self.state >= len(self.names) - 1:
            raise RuntimeError('Unable to advance past final state')
        self.state += 1
        TermColor.print_delay(str(self))

    def assign(self, name):
        '''Explicitly set the state.

        Prints the current state afterward.

        Arguments
        name -- name of the state to switch to

        Return
        integer -- the previous state value

        '''
        value = getattr(self, name)
        retval, self.state = self.state, value
        TermColor.print_delay(str(self))
        return retval

    def reset(self):
        '''Reset the internal state

        Prints the current state afterward.

        Return
        integer -- the previous state value

        '''
        retval, self.state = self.state, 0
        TermColor.print_delay(str(self))
        return retval

    def __str__(self):
        '''Stringify current state value'''
        return self.phrases[self.state]

    def __cmp__(self, other):
        '''Order by state index against a state name or another tracker'''
        if isinstance(other, basestring):
            # A name is compared through the index stored under that attribute
            value = getattr(self, other)
        else:
            value = other.state
        return cmp(self.state, value)


def get_choice(question, choices, default=None):
    '''Prompt for a response from a selection of choices.

    Return
        The selected item from choices (value, not an index)

    Raises
        ValueError -- if default is not in choices

    Arguments
        question -- the prompt
        choices -- possible answers
        default -- optional default value

    Example
        choice = get_choice('What fruit do you want?', ['apples', 'oranges'], 'apples')
        print 'choice', choice

    '''
    if default is None:
        default_idx = None
    else:
        try:
            default_idx = choices.index(default)
        except ValueError:
            raise ValueError('Default choice %r is not in choice collection' % (default))

    enum_choices = list(enumerate(choices))
    prompt = ((default_idx is not None and ('[%d]: ' % (default_idx + 1)))
              or ': ')
    while True:
        TermColor.print_header(question)
        for idx, item in enum_choices:
            TermColor.print_ok('    %d.' % ((idx + 1)), newline=False)
            print ' %s' % (item)
        input_raw = raw_input(prompt)
        if is_integer(input_raw):
            choice_idx = int(input_raw) - 1
        elif not input_raw.strip() and default_idx is not None:
            choice_idx = default_idx
        else:
            choice_idx = None

        if choice_idx is None:
            TermColor.print_error('invalid choice value: %s' % (input_raw))
        elif choice_idx < 0 or (len(choices) <= choice_idx):
            TermColor.print_error('choice out of range: %s' % (choice_idx))
        else:
            # !!! Exit loop
            break
    return choices[choice_idx]


def get_value(name, default=None, validate=None, validate_explanation=None,
              is_password=False):
    '''Ask for a value on the terminal until the response validates.

    An empty response falls back to the default (or to an empty string
    when there is no default) before validation.

    Return
        the accepted value: the stripped response, or the default as given

    Arguments
        name -- value name, used as the prompt
        default -- [optional] default value, shown in brackets after the name
        validate -- [optional] callable to validate the input. Defaults to
            ensuring something was entered.
        validate_explanation -- [optional] error message if validate() fails
        is_password -- [optional] if True, do not echo the response. Defaults to False.

    '''
    suffix = ''
    if default:
        suffix = ' [%s]' % (default)
    prompt = '%s%s: ' % (name, suffix)

    if is_password:
        # The prompt is printed separately, so getpass gets an empty one
        def read_input():
            return getpass.getpass('')
    else:
        read_input = raw_input

    validate_fn = validate or bool
    error_msg = validate_explanation or ('Invalid Input.  Please try again.')

    while True:
        TermColor.print_header(prompt, newline=False)
        entered = read_input().strip()
        candidate = entered or default or ''
        if validate_fn(candidate):
            return candidate
        TermColor.print_error(error_msg)


def is_integer(value):
    '''Report whether int(value) succeeds.

    Return
        bool -- False when the conversion raises TypeError or ValueError

    '''
    try:
        int(value)
    except (TypeError, ValueError):
        return False
    return True


def can_connect_on_host_port(dbhost, port):
    '''Determine whether a service is listening on dbhost:port

    Arguments
    dbhost -- host name or address to connect to
    port -- port number, as an integer or a string holding one

    Return
    bool -- True iff port converts to an integer and a TCP connection to
        that port can be established within five seconds; False otherwise
        (including for ports outside the valid range).

    '''
    try:
        port_number = int(port)
    except (TypeError, ValueError):
        return False
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    # try/finally is nested around try/except to remain valid on
    # interpreters which do not accept the combined form.
    try:
        sock.settimeout(5)
        try:
            sock.connect((dbhost, port_number))
            rc = True
        except (socket.error, OverflowError):
            # OverflowError is raised for a port outside 0-65535
            rc = False
    finally:
        # Release the descriptor whether or not the connection succeeded
        sock.close()
    return rc


def name_generate(db_name, max_length):
    '''Build a "chartio_"-prefixed user name ending in three random digits.

    The name is made of the prefix, the stripped database name and the
    digits. When that would exceed max_length, the database name part is
    cut short so that the digits still fit.

    Return
        string -- the generated name

    Raises
        RuntimeError -- if max_length is too small to hold the prefix and
            the random digits.

    Arguments
        db_name -- database name
        max_length -- the inclusive size limit of the generated name

    '''
    prefix = 'chartio_'
    suffix_length = 3
    minimum_length = len(prefix) + suffix_length
    if max_length < minimum_length:
        raise RuntimeError('Unable to generate a name with fewer than %d characters (%d specified).'
                           % (minimum_length, max_length))
    # Reserve room for the digits with placeholders before truncating
    placeholder = 'X' * suffix_length
    candidate = ('%s%s%s' % (prefix, db_name.strip(), placeholder))[:max_length]
    digits = []
    for _unused in range(suffix_length):
        digits.append(random.choice(string.digits))
    return candidate[:-suffix_length] + ''.join(digits)


class ResultTrue(object):
    '''Result wrapper which always evaluates as true.

    Arguments
        info -- payload describing the outcome, kept on the ``info`` attribute

    '''
    def __init__(self, info):
        self.info = info

    def __nonzero__(self):
        return True

    # Interpreters which look up __bool__ instead of __nonzero__ get the
    # same answer.
    __bool__ = __nonzero__

    def __repr__(self):
        return 'ResultTrue(%r)' % (self.info)


class ResultFalse(object):
    '''Result wrapper which always evaluates as false.

    Arguments
        info -- payload describing the outcome, kept on the ``info`` attribute

    '''
    def __init__(self, info):
        self.info = info

    def __nonzero__(self):
        return False

    # Interpreters which look up __bool__ instead of __nonzero__ would
    # otherwise treat every instance as true.
    __bool__ = __nonzero__

    def __repr__(self):
        return 'ResultFalse(%r)' % (self.info)


def empty_is_ok(_value):
    '''Validator which accepts any input, including an empty string.

    Return
        bool -- always True

    '''
    return True


def is_yes_no(value):
    '''Report whether a response reads as a yes or a no.

    Surrounding whitespace and letter case are ignored; only the first
    character is considered.

    Return
        bool -- True iff the first character is 'y' or 'n'

    '''
    first = value.strip().lower()[:1]
    return first in ('y', 'n')


def bool_from_yes_no(value):
    '''Interpret a yes/no response as a bool.

    Surrounding whitespace and letter case are ignored.

    Return
        bool -- True iff the first character is 'y'; anything else is False

    '''
    return value.strip().lower()[:1] == 'y'


def random_password_generate(length=24):
    '''Generate a random password of a specified length

    The password consists of ASCII letters and digits only and is drawn
    from the operating system's random source.

    Return
        string -- the generated password

    Arguments
        length -- [optional] number of characters. Defaults to 24.

    '''
    # ascii_letters rather than the locale-dependent string.letters, so the
    # alphabet is the same regardless of the configured locale.
    valid_chars = string.ascii_letters + string.digits
    # One generator serves every character
    rng = random.SystemRandom()
    password = ''.join([rng.choice(valid_chars) for _unused in range(length)])
    return password


def database_choices(chartio_api):
    '''Fetch the supported database types and attach local defaults.

    The type list is requested from the service. Types known locally are
    extended with a default host, a default port and, where defined, a
    user name length limit.

    Return
        dict -- database type name mapped to a dict holding 'name', 'id'
            and any local defaults

    Arguments
        chartio_api -- object whose post(path) returns the JSON response text

    '''
    db_settings = {'MySQL': {'default_host': '127.0.0.1', 'default_port': 3306, 'user_name_limit': 16},
                   'Oracle': {'default_host': '127.0.0.1', 'default_port': 1521, 'user_name_limit': 31},
                   'PostgreSQL': {'default_host': '127.0.0.1', 'default_port': 5432, 'user_name_limit': 63},
                   'AmazonRedshift': {'default_host': '127.0.0.1', 'default_port': 5439, 'user_name_limit': 63},
                   'Presto': {'default_host': '127.0.0.1', 'default_port': 8080}}
    payload = json.loads(chartio_api.post('/connectionclient/databasetypes/'))
    db_map = {}
    for entry in payload.get('databasetypes', []):
        type_name = entry['name']
        details = {'name': type_name, 'id': entry['id']}
        # Types without local settings keep just their name and id
        details.update(db_settings.get(type_name) or {})
        db_map[type_name] = details
    return db_map


class Settings(object):
    '''Plain container for the collected connection settings.

    Every field starts out as None.
    '''
    def __init__(self):
        # Database location
        self.database_id = self.database_host = self.database_port = None
        self.database_name = self.schema = None
        # Read-only account
        self.readonly_user = self.readonly_password = self.readonly_host = None
        # For dev
        self.db = None


class Poster(object):
    '''Class to POST information to Chartio.'''

    def __init__(self):
        self.opener = urllib2.build_opener(urllib2.HTTPCookieProcessor())
        urllib2.install_opener(self.opener)
        self.history = []

    def post(self, url, data_param=None, offer_post=True):
        '''A simple POST request wrapper'''
        if data_param is None:
            encoded_args = []
        else:
            encoded_args = [urllib.urlencode(data_param)]
        if not url.startswith('http'):
            url = BASE_URL + url
        if isinstance(data_param, dict):
            for pw_key in ('passwd', 'password'):
                if pw_key in data_param:
                    data_param[pw_key] = 'XXXXXX'
        log_record = [datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S'),
                      url,
                      data_param]
        try:
            response = self.opener.open(url, *encoded_args)
        except urllib2.URLError, exc:
            if hasattr(exc, 'read'):
                error_msg = exc.read()
            else:
                error_msg = ''
            log_record.extend([str(exc), error_msg])
            self.history.append(log_record)
            TermColor.print_error('Issue communicating with the Chartio service:'
                                  '\n    URL: %s'
                                  '\n         %s'
                                  '\n         %s'
                                  % (url, exc, error_msg))
            if offer_post:
                self.history_post()
            # !!! Early exit
            sys.exit(1)
        retval = response.read()
        log_record.extend(['ok', retval])
        self.history.append(log_record)
        return retval

    def history_content(self):
        return '\n'.join([' '.join(map(str, record)) for record in self.history])

    def history_post(self):
        post_history_to_service = get_value('Would you like to post the error log to Chartio?\n'
                                            'This will open a Chartio support ticket.',
                                            default='yes',
                                            validate=is_yes_no,
                                            validate_explanation='Please enter yes or no')
        if bool_from_yes_no(post_history_to_service):
            self.post('/connectionclient/support/',
                      [('history', self.history_content())],
                      offer_post=False)
            TermColor.print_ok('Successfully posted error log to Chartio support.\n'
                               'Thank you,\n'
                               'The Chartio Team <support@chartio.com>')


def multiple_attempts(attempts, fn):
    '''Call fn up to a number of times, stopping at the first true result.

    Return
        bool -- True as soon as a call returns a true value; False when
            every attempt fails (or no attempt is made)

    Arguments
        attempts -- maximum number of calls
        fn -- callable taking no arguments

    '''
    succeeded = False
    for _attempt in range(attempts):
        succeeded = bool(fn())
        if succeeded:
            break
    return succeeded


def admin_user_password_dec(wrapped_fn):
    '''Decorator which obtains administrator credentials before a method runs.

    When the instance has no administrator configured, the user is asked
    for a name and password, up to three times. The wrapped method runs
    only once credentials were accepted; otherwise the call returns None.

    Requirements
        self.settings.admin_user data attribute

    References
        self.have_admin to determine whether admin is already configured
        self.admin_user_password() to set/validate/check rights

    '''
    def fn(self, *args, **kwargs):
        if not self.have_admin:
            def prompt_and_check():
                name = get_value('Database administrator name')
                TermColor.print_header('Please enter the password for the database'
                                       ' administrator (leave empty for none)')
                # Validate any value to permit a blank password
                secret = get_value('Database administrator password',
                                   is_password=True,
                                   validate=empty_is_ok)
                return self.admin_user_password(name, secret)

            TermColor.print_header('\nThis step requires administrator access.\n'
                                   'The credentials will be used during setup only for\n'
                                   ' - determining the names of local databases [optional]\n'
                                   ' - creating read-only user account [optional]\n'
                                   ' - granting read-only database access to the user account [optional]\n'
                                   'You may instead choose to do any or all of the above steps prior to\n'
                                   ' running chartio_setup. The script will prompt you to enter the\n'
                                   ' relevant information directly as needed.\n')
            if not multiple_attempts(3, prompt_and_check):
                # No valid credentials: skip the wrapped method
                return None
        return wrapped_fn(self, *args, **kwargs)
    return fn


class DbAdmin(object):
    '''Interactive collection of database connection settings.

    The interaction is split into ``*_step`` methods. Each step returns the
    next step to run (a bound method), or a false value when it has nothing
    to continue with. Answers are stored on ``settings`` and on the accessor
    for the chosen database type.

    Arguments
        chartio_poster -- object used to query the Chartio service
        db_accessors -- mapping of lowercased database type name to a
            callable creating the accessor for that type

    '''
    def __init__(self, chartio_poster, db_accessors):
        self.db_accessor = None
        self.db_accessors = db_accessors
        self.db_choice = None
        self.have_admin = False
        self.chartio_poster = chartio_poster
        self.settings = Settings()

    def admin_user_password(self, user, password):
        '''Means of validating and storing admin user and password'''
        self.have_admin = self.db_accessor.admin_user_password(user, password)
        return self.have_admin

    def readonly_user_password(self, database, user, password):
        '''Means of validating and storing readonly user and password'''
        rc = self.db_accessor.readonly_user_password(database, user, password)
        if rc:
            self.settings.readonly_user = user
            self.settings.readonly_password = password
        return rc

    def start_step(self):
        '''Beginning of interaction state machine'''
        return self.database_type_step

    def database_type_step(self):
        '''Get database type'''
        db_types = database_choices(self.chartio_poster)
        # Connecting Oracle or Presto requires CLI-only mode
        db_types.pop('Oracle', None)
        db_types.pop('Presto', None)
        db_type_name = get_choice('What type of database are you connecting?',
                                  sorted(db_types))
        self.db_choice = db_types[db_type_name]
        self.settings.database_id = db_types[db_type_name]['id']
        self.db_accessor = self.db_accessors[db_type_name.lower()]()
        return self.database_host_step

    def database_host_step(self):
        '''Database server'''
        TermColor.print_header('Enter hostname or IP address of your database server')
        default = self.db_choice['default_host']
        dbhost = get_value('Database server', default=default)
        self.db_accessor.dbhost = dbhost
        self.settings.database_host = dbhost
        # The port is asked for right away; the value returned is whichever
        # step the port prompt selects.
        return self.database_listen_port_step(dbhost)

    def database_listen_port_step(self, dbhost):
        '''Database listen port

        Returns to the host prompt when nothing accepts connections on the
        given host and port.

        '''
        TermColor.print_header('Enter database listen port')
        default = self.db_choice['default_port']
        port = get_value('Database listen port',
                         default=default,
                         validate=is_integer,
                         validate_explanation='Port must be a valid integer')
        if can_connect_on_host_port(self.db_accessor.dbhost, port):
            self.db_accessor.port = port
            self.settings.database_port = port
        else:
            TermColor.print_header('Database is not listening on the specified database server and port.'
                                   '\nPlease confirm your database connection information and try again.\n')
            return self.database_host_step
        return self.database_name_step

    def database_name_step(self):
        '''Get database name

        Return
        the next step, or False when no database name could be determined

        '''
        if self.db_choice['name'] == 'AmazonRedshift':
            TermColor.print_header('Enter the database name to connect to Chartio')
            database = get_value('Database name')
        else:
            TermColor.print_header('Enter the database name to connect to Chartio'
                                   '\n    [Leave blank to list available databases]')
            database = (get_value('Database name', validate=empty_is_ok)
                        or self.database_name_from_choice())
        # The selection helper yields None without administrator access and
        # False when the databases cannot be listed; neither is a name.
        if not database:
            retval = False
        else:
            self.settings.database_name = database
            if isinstance(self.db_accessor, MysqlAccessor):
                # MySQL, no schema
                retval = self.readonly_user_step
            else:
                # Postgres, get schema
                retval = self.schema_step
        return retval

    def schema_step(self):
        '''Get the schema name, defaulting to "public"'''
        TermColor.print_header('Enter the schema name to use when connecting'
                               '\n    [Leave blank for "public"]')
        schema = (get_value('Schema', validate=empty_is_ok) or 'public')
        self.db_accessor.schema = schema
        self.settings.schema = schema
        return self.readonly_user_step

    @admin_user_password_dec
    def database_name_from_choice(self):
        '''Select database from automatically generated list (requires admin)

        Return
        the selected name, or False when the databases cannot be listed

        '''
        databases = self.db_accessor.databases_get()
        if databases is None:
            retval = False
        else:
            TermColor.print_header('\nSelect which database to connect:')
            retval = get_choice('Database name', sorted(databases))
        return retval

    def readonly_user_step(self):
        '''Get/create read-only role and password and grant database access

        An existing role is required for a MySQL server which is not local
        and for AmazonRedshift. Otherwise the role name may be left blank to
        have a role created.

        '''
        database = self.settings.database_name
        is_remote = self.db_accessor.dbhost not in ['127.0.0.1', 'localhost']
        is_mysql = isinstance(self.db_accessor, MysqlAccessor)
        role_must_exist = ((is_mysql and is_remote)
                           or self.db_choice['name'] == 'AmazonRedshift')
        if role_must_exist:
            TermColor.print_header('Enter an existing read-only role for Chartio to use')
            user = get_value('Read-only role name')
        else:
            TermColor.print_header('Enter an existing read-only role for Chartio to use'
                                   '\n    [Leave blank to create a new role automatically]')
            user = get_value('Read-only role name', validate=empty_is_ok)
        if user:
            user_password = self.readonly_user_password_get(database, user=user)
            # Ask again, for the role name as well, until the credentials work
            while not user_password:
                user_password = self.readonly_user_password_get(database, user=None)
        else:
            user_password = self.readonly_user_create_and_grant(database)
        retval = (user_password and self.datasource_register_step) or False
        return retval

    def readonly_user_password_get(self, database, user=None):
        '''Prompt for existing read-only database role

        Arguments
        database -- database the role must be able to access
        user -- [optional] role name; prompted for when None

        Return
        (user, password) | False -- a tuple of the user and password iff apparently successful

        '''
        if user is None:
            user = get_value('Read-only role name')
        password = get_value('Read-only role password', is_password=True, validate=empty_is_ok)
        if self.readonly_user_password(database, user, password):
            retval = (user, password)
        else:
            retval = False
        return retval

    @admin_user_password_dec
    def readonly_user_create_and_grant(self, database):
        '''Create a read-only user

        The role name is generated from the database name and the password
        is random.

        Return
        (user, password) | False -- a tuple of the user and password iff apparently successful

        '''
        user = name_generate(database.replace(' ', ''),
                             self.db_choice['user_name_limit'])
        password = random_password_generate()
        TermColor.print_delay('Creating read-only user %r' % (user))
        if (self.db_accessor.readonly_user_create_and_grant(database, user, password)
            and self.readonly_user_password(database, user, password)):
            retval = (user, password)
        else:
            retval = False
        return retval

    def datasource_register_step(self):
        '''Final step; there is no step after it'''
        return None


def argv_run(argv, input_str=None):
    '''Execute an argv with optional input

    Return
    ResultFalse|ResultTrue

    '''
    if input_str is None:
        stdin = None
    else:
        stdin = subprocess.PIPE
    try:
        proc = subprocess.Popen(argv,
                                stdin=stdin,
                                stderr=subprocess.PIPE,
                                stdout=subprocess.PIPE)
        info, stderr = proc.communicate(input_str)
        proc_exited_cleanly = (0 == proc.returncode)
    except IOError:
        info = None
        proc_exited_cleanly = False
    except OSError, e:
        if e.errno == errno.ENOENT:
            command = argv[0]
            result = subprocess.call(['which', command])
            if result == 0:
                raise e

            error_msg = "'%s' could not be found in your path." % command
            TermColor.print_error(error_msg)
            sys.exit(1)
        else:
            raise e

    if stdin is not None:
        proc.stdin.close()
    proc.stderr.close()
    if not proc_exited_cleanly and stderr:
        if info:
            info = '\n'.join([info, stderr])
        else:
            info = stderr

    if proc_exited_cleanly:
        return ResultTrue(info)
    else:
        return ResultFalse(info)


def argv_check(what, argv_result, check_fn):
    '''Judge the outcome of a command and report a failure.

    Return
    the value of check_fn for a successful command, False for a failed one

    Arguments
    what -- description of the attempted action, used in the error message
    argv_result -- command result whose truth value tells whether the
        command succeeded and whose ``info`` holds its output
    check_fn -- callable deciding whether the output (stripped when it is
        a string) is acceptable

    '''
    rc = False
    if argv_result:
        info = argv_result.info
        if isinstance(info, basestring):
            info = info.strip()
        rc = check_fn(info)
    if rc:
        return rc
    TermColor.print_error('Failed to ' + what + '.')
    # Show whatever the command produced, if anything
    if argv_result.info:
        TermColor.print_error(argv_result.info)
    return rc


def replacer_make(source, dest):
    '''Create a function which substitutes text within its keyword arguments.

    The returned function takes keyword arguments and returns them as a
    dict in which every occurrence of source within a string value is
    replaced by dest. Values which are not strings are passed through.

    The use here is for escaping special characters from SQL arguments.

    '''
    def replacer(**kwargs):
        escaped = {}
        for key, value in kwargs.items():
            if isinstance(value, basestring):
                value = value.replace(source, dest)
            escaped[key] = value
        return escaped
    return replacer


class MysqlAccessor(object):
    '''Intermediary for accessing a MySQL database

    All statements are run through the ``mysql`` command line client.
    ``dbhost`` and ``port`` must be assigned before any statement is run.

    '''

    # Escape quoting characters by doubling them
    squote_escape = staticmethod(replacer_make("'", "''"))
    backtick_escape = staticmethod(replacer_make("`", "``"))

    def __init__(self):
        self.dbhost = None
        self.port = None
        self.admin_user = None
        self.admin_password = None
        self.ro_user = None
        self.ro_password = None

    def _sql_cmd_argv(self, user, password, database=None):
        '''Build the mysql client command line for a user

        Arguments
            user -- account to connect as
            password -- account password; omitted from the command when empty
            database -- [optional] database to select

        '''
        argv = ['mysql',
                '--silent',
                '-h%s' % self.dbhost,
                '-u%s' % user,
                '--port=%s' % self.port,
                '--protocol=tcp']

        if password:
            # NOTE(review): the password is passed as a command argument
            argv.extend(['-p%s' % (password)])
        if database is not None:
            argv.append(database)
        return argv

    def sql_commands_print(self):
        '''Print the kinds of SQL statements which are issued'''
        cmds = ('''SHOW DATABASES;''',
                '''SELECT 'SUCCESS';''',
                '''SELECT 'SUCCESS' FROM mysql.user WHERE user = '<USER>' AND host='127.0.0.1';''',
                ('''GRANT SELECT, SHOW VIEW ON <DATABASE>.* TO `<USER>`@`127.0.0.1`'''
                 ''' IDENTIFIED BY '<PASSWORD>';'''))
        for cmd in cmds:
            TermColor.print_cmd(cmd)

    def admin_user_password(self, user, password):
        '''Confirm and set admin user/password values'''
        rc = self.user_password_is_valid(user, password, database='mysql')
        if rc:
            self.admin_user = user
            self.admin_password = password
        return rc

    def readonly_user_password(self, database, user, password):
        '''Confirm and set readonly user/password values'''
        rc = self.user_password_is_valid(user, password, database)
        if rc:
            self.ro_user = user
            self.ro_password = password
        return rc

    def databases_get(self):
        '''Fetch a sequence of databases

        Return
            list of names, or None when the statement fails

        '''
        argv = self._sql_cmd_argv(self.admin_user, self.admin_password, database=None)
        sql = 'SHOW DATABASES;'
        sql_retval = argv_run(argv, sql)
        sql_rc = argv_check('list available databases', sql_retval, empty_is_ok)
        if sql_rc:
            databases = [db.strip() for db in sql_retval.info.split()]
        else:
            databases = None
        return databases

    def user_password_is_valid(self, user, password, database):
        '''Determine whether a user/password combination is valid for access to a given database'''
        argv = self._sql_cmd_argv(user, password, database)
        sql = '''SELECT 'SUCCESS';'''
        sql_retval = argv_run(argv, sql)
        rc = argv_check('validate user/password for %r' % (user),
                        sql_retval,
                        lambda x: 'SUCCESS' == x)
        return rc

    def user_has_role_create_access(self, user, password):
        '''Determine whether a user holds the grant privilege'''
        def count_is_positive(output):
            # The statement yields a count, so look for a positive number
            return output.isdigit() and int(output) > 0

        argv = self._sql_cmd_argv(user, password, database='mysql')
        sql = ('''SELECT COUNT(1) FROM user WHERE User='%(user)s' AND grant_priv='y';'''
               % self.squote_escape(user=user))
        sql_retval = argv_run(argv, sql)
        rc = argv_check('verify user %r has role create access' % (user),
                        sql_retval,
                        count_is_positive)
        return rc

    def readonly_user_create_and_grant(self, database, ro_user, ro_password):
        '''Create a read-only role with access to a database

        Nothing is granted when the role already exists or when its
        existence cannot be determined.

        '''
        admin_user, admin_password = self.admin_user, self.admin_password
        user_exists = self.readonly_user_exists(admin_user, admin_password, database, ro_user)
        if user_exists or (user_exists is None):
            rc = False
        else:
            rc = self.readonly_user_access_grant(admin_user, admin_password, database, ro_user, ro_password)
        return rc

    def readonly_user_exists(self, user, password, database, ro_user):
        '''Determine whether a role exists for host 127.0.0.1

        Return
            True | False | None -- None when the check itself fails

        '''
        argv = self._sql_cmd_argv(user, password, database)
        sql = ('''SELECT 'SUCCESS' FROM mysql.user WHERE User='%(user)s' AND Host='127.0.0.1';'''
               % (self.squote_escape(user=ro_user)))
        sql_retval = argv_run(argv, sql)
        if sql_retval:
            rc = ('SUCCESS' == sql_retval.info.strip())
            if rc:
                TermColor.print_error('Role %r already exists.' % (ro_user))
                TermColor.print_header('''You may reset the password with\n'''
                                       '''    SET PASSWORD FOR '%s'@'127.0.0.1' = 'New-Password';\n''' % (ro_user))
                TermColor.print_header('''Or re-run chartio_setup to generate a different role\n''')
        else:
            TermColor.print_error('Error checking whether role %r exists.' % (ro_user))
            TermColor.print_error(sql_retval.info)
            rc = None
        return rc

    def readonly_user_access_grant(self, user, password, database, ro_user, ro_password):
        '''Grant a role read access to a database, identified by a password

        Arguments
            user, password -- administrator credentials used for the grant
            database -- database to grant access to
            ro_user, ro_password -- the role and its password

        '''
        argv = self._sql_cmd_argv(user, password, database)
        params = self.squote_escape(role=ro_user, password=ro_password)
        params.update(self.backtick_escape(database=database))
        # Role and password are single-quote escaped, so both are written
        # as single-quoted literals.
        sql = ('''GRANT SELECT, SHOW VIEW ON `%(database)s`.* TO '%(role)s'@`127.0.0.1`'''
               ''' IDENTIFIED BY '%(password)s';'''
               % params)
        sql_retval = argv_run(argv, sql)
        rc = argv_check('create/grant access to database %r for role %r' % (database, ro_user),
                        sql_retval,
                        lambda x: '' == x)
        return rc


class PostgresqlAccessor(object):
    '''Intermediary for accessing a PostgreSQL database

    Each operation builds a psql command line with _sql_cmd_argv and hands
    it, together with a SQL statement, to argv_run.  Values interpolated
    inside double quotes (identifiers) are escaped with dquote_escape;
    values interpolated inside single quotes (string literals) are escaped
    with squote_escape.

    '''

    dquote_escape = staticmethod(replacer_make('"', '""'))
    squote_escape = staticmethod(replacer_make("'", "''"))

    def __init__(self):
        # Connection settings; presumably filled in by the setup wizard
        # before any query is run -- verify against the caller.
        self.dbhost = None
        self.port = None
        self.database = None
        self.schema = None
        self.admin_user = None
        self.admin_password = None
        self.ro_user = None
        self.ro_password = None

    @staticmethod
    def _output_rows(text):
        '''Split single-column psql output into one stripped value per row.

        Splitting on lines rather than on any whitespace keeps names that
        contain spaces intact.  Blank lines are dropped.

        '''
        return [row.strip() for row in text.splitlines() if row.strip()]

    @staticmethod
    def _server_major_version(version_text):
        '''Extract the major version from the output of SELECT version().

        Return
        int | None -- the leading number of the version token when the text
            starts with the word PostgreSQL; None otherwise.

        '''
        words = version_text.split()
        if len(words) < 2 or 'PostgreSQL' != words[0]:
            return None
        digits = ''
        # Take the leading digits only, so "9.3.4", "14.2" and "10beta1" all parse.
        for char in words[1]:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            return None
        return int(digits)

    def _sql_cmd_argv(self, user, password, database=None):
        '''Build the psql argument list for the given user and database.

        NOTE(review): the password is passed via "env PGPASSWORD=..." in
        argv, so it is visible in the process list while psql runs.

        '''
        argv = []
        if password:
            argv.extend(['env', 'PGPASSWORD=%s' % (password)])
        argv.extend(['psql',
                     '-t',
                     '-U', user,
                     '-v', 'ON_ERROR_STOP=1'])
        if self.port is not None:
            # Host and port are only passed when a port was configured;
            # otherwise psql uses its own defaults.
            if self.dbhost:
                argv.extend(['-h', self.dbhost])
            argv.extend(['-p', str(self.port)])
        if database is None:
            # If a database is not supplied, postgres will default to the current user.
            database = 'postgres'
        argv.append(database)
        return argv

    def sql_commands_print(self):
        '''Print templates of the SQL statements the wizard runs.'''
        assert self.schema is not None
        cmds = ('''SELECT 'SUCCESS';''',
                ('''SELECT 'SUCCESS' FROM pg_roles WHERE rolname='<ADMIN-USER>' '''
                 '''AND rolcreaterole='t';'''),
                '''SELECT 'SUCCESS' FROM pg_roles WHERE rolname='<READONLY-USER>';''',
                ('''CREATE USER "<READONLY-USER>" PASSWORD '<PASSWORD>' '''
                 '''NOSUPERUSER NOCREATEDB NOCREATEROLE NOINHERIT;'''),
                '''SELECT datname FROM pg_database;''',
                ('''SELECT relname'''
                 ''' FROM pg_class JOIN pg_namespace ON pg_namespace.oid = pg_class.relnamespace'''
                 ''' WHERE nspname = '%(schema)s' AND relkind IN ('r','v')'''
                 ''' ORDER BY relname ASC;''' % self.squote_escape(schema=self.schema)),
                '''GRANT SELECT ON TABLE "<TABLE>" TO "<READONLY-USER>";''')
        for cmd in cmds:
            TermColor.print_cmd(cmd)

    def admin_user_password(self, user, password):
        '''Confirm and set admin user/password values'''
        rc = (self.user_password_is_valid(user, password, database=None)
              and self.user_has_role_create_access(user, password))
        if rc:
            self.admin_user = user
            self.admin_password = password
        return rc

    def readonly_user_password(self, database, user, password):
        '''Confirm and set readonly user/password values'''
        rc = (self.user_password_is_valid(user, password, database)
              and self.user_has_readonly_access(user, password, database))
        if rc:
            self.ro_user = user
            self.ro_password = password
        return rc

    def user_password_is_valid(self, user, password, database):
        '''Return bool'''
        argv = self._sql_cmd_argv(user, password, database)
        sql = '''SELECT 'SUCCESS';'''
        sql_retval = argv_run(argv, sql)
        rc = argv_check('validate user/password for %r' % (user),
                        sql_retval,
                        lambda x: 'SUCCESS' == x)
        return rc

    def user_has_role_create_access(self, user, password):
        '''Check that pg_roles marks the user as allowed to create roles.'''
        argv = self._sql_cmd_argv(user, password, database=None)
        # rolname is compared against a single-quoted literal, so single
        # quotes (not double quotes) are what must be escaped.
        sql = ('''SELECT 'SUCCESS' FROM pg_roles WHERE rolname='%(rolname)s' '''
               ''' AND rolcreaterole='t';'''
               % self.squote_escape(rolname=user))
        sql_retval = argv_run(argv, sql)
        rc = argv_check('verify user %r has role create access' % (user),
                        sql_retval,
                        lambda x: 'SUCCESS' == x)
        return rc

    def user_has_readonly_access(self, user, password, database):
        '''Check that the user can run a trivial query against the database.'''
        argv = self._sql_cmd_argv(user, password, database)
        sql = '''SELECT 'SUCCESS';'''
        sql_retval = argv_run(argv, sql)
        rc = argv_check('verify role %r has readonly access to %r' % (user, database),
                        sql_retval,
                        lambda x: 'SUCCESS' == x)
        return rc

    def readonly_user_create_and_grant(self, database, ro_user, ro_password):
        '''Create the read-only role and grant it access, unless it exists.

        Returns False when the role already exists or the existence check
        failed; otherwise the combined result of creation and grant.

        '''
        admin_user, admin_password = self.admin_user, self.admin_password
        user_exists = self.readonly_user_exists(admin_user, admin_password, database, ro_user)
        if user_exists or (user_exists is None):
            rc = False
        else:
            rc = (self.readonly_user_create(admin_user, admin_password,
                                            database,
                                            ro_user, ro_password)
                  and self.readonly_user_access_grant(admin_user, admin_password, database, ro_user))
        return rc

    def readonly_user_exists(self, user, password, database, ro_user):
        '''Report whether a role named ro_user is listed in pg_roles.

        Return
        bool | None -- True/False according to the lookup; None if the
            lookup itself failed.

        '''
        argv = self._sql_cmd_argv(user, password, database)
        sql = ('''SELECT 'SUCCESS' FROM pg_roles WHERE rolname='%(rolname)s';'''
               % self.squote_escape(rolname=ro_user))
        sql_retval = argv_run(argv, sql)
        if sql_retval:
            rc = ('SUCCESS' == sql_retval.info.strip())
            if rc:
                TermColor.print_error('Role %r already exists.' % (ro_user))
                # A role name is an identifier, so the hint quotes it with
                # double quotes; single quotes would be a syntax error.
                TermColor.print_header('''You may reset the password with\n'''
                                       '''    ALTER ROLE "%s" PASSWORD 'New-Password';\n''' % (ro_user))
        else:
            TermColor.print_error('Error checking whether role %r exists.' % (ro_user))
            TermColor.print_error(sql_retval.info)
            rc = None
        return rc

    def readonly_user_create(self, user, password, database, ro_user, ro_password):
        '''Create a database-specific Postgresql read-only user'''
        argv = self._sql_cmd_argv(user, password, database)
        # The role is a double-quoted identifier and the password a
        # single-quoted literal; each needs the matching escape.
        params = self.dquote_escape(role=ro_user)
        params.update(self.squote_escape(password=ro_password))
        sql = ('''CREATE USER "%(role)s" PASSWORD '%(password)s' '''
               ''' NOSUPERUSER NOCREATEDB NOCREATEROLE NOINHERIT;'''
               % params)
        sql_retval = argv_run(argv, sql)
        argv_check('create read-only user %r' % (ro_user),
                   sql_retval,
                   lambda x: 'CREATE ROLE' == x)
        return sql_retval

    def databases_get(self):
        '''List available databases

        Return
        seq | None -- a sequence of database names iff successful; None otherwise.

        '''
        argv = self._sql_cmd_argv(self.admin_user, self.admin_password, database=None)
        sql = '''SELECT datname FROM pg_database ORDER BY datname ASC;'''
        sql_retval = argv_run(argv, sql)
        sql_rc = argv_check('list available databases', sql_retval, empty_is_ok)
        if sql_rc:
            databases = self._output_rows(sql_retval.info)
        else:
            databases = None
        return databases

    def readonly_user_access_grant(self, user, password, database, ro_user):
        '''Grant ro_user SELECT on every table/view of the configured schema.

        On PostgreSQL 9 and later, default privileges are also altered so
        that tables created later are readable too.

        Return
        bool -- True iff the table list was obtained and every per-table
            grant succeeded.  The outcome of the default-privileges
            statement does not affect the result.

        '''
        tables_retval = self.tables_get(user, password, database)
        if tables_retval:
            rc = True
            for table in self._output_rows(tables_retval.info):
                if not self.readonly_user_table_access_grant(user, password, database, ro_user, table):
                    rc = False
                    break
        else:
            rc = False
        # If version 9 and greater, apply to future tables as well
        argv = self._sql_cmd_argv(user, password, database)
        version_sql = 'SELECT version();'
        sql_retval = argv_run(argv, version_sql)
        major_version = None
        if sql_retval:
            major_version = self._server_major_version(sql_retval.info)
        if major_version is not None and major_version >= 9:
            priv_sql = ('''ALTER DEFAULT PRIVILEGES GRANT SELECT ON TABLES TO "%(role)s";'''
                        % self.dquote_escape(role=ro_user))
            sql_retval = argv_run(argv, priv_sql)
            if sql_retval:
                TermColor.print_delay('Granting %r read-only access to future tables' % ro_user)
        return rc

    def tables_get(self, user, password, database):
        '''List the tables and views of the configured schema.

        Returns the argv_run result; the names are in its info attribute,
        one per line.

        '''
        argv = self._sql_cmd_argv(user, password, database)
        assert self.schema is not None
        sql = ('''SELECT relname'''
               ''' FROM pg_class JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace'''
               ''' WHERE nspname='%(schema)s' AND relkind IN ('r','v')'''
               ''' ORDER BY relname ASC;''' % self.squote_escape(schema=self.schema))
        sql_retval = argv_run(argv, sql)
        # Name the database actually queried; self.database may not be set.
        argv_check('determine tables in %r' % database,
                   sql_retval,
                   lambda x: True)
        return sql_retval

    def readonly_user_table_access_grant(self, user, password, database, ro_user, table):
        '''Grant ro_user SELECT on a single table; returns the argv_run result.'''
        argv = self._sql_cmd_argv(user, password, database)
        sql = ('''GRANT SELECT ON TABLE "%(table)s" TO "%(role)s";'''
               % self.dquote_escape(table=table, role=ro_user))
        TermColor.print_delay('Granting %r read-only access to table %r' % (ro_user, table))
        sql_retval = argv_run(argv, sql)
        argv_check('grant table select access on %s to "%s"' % (table, ro_user),
                   sql_retval,
                   lambda x: 'GRANT' == x)
        return sql_retval


def can_query_oracle(role, password, dbhost, port, database):
    '''Determine whether the specified role can query the Oracle database

    Runs "SELECT 1 AS OKAY FROM DUAL;" through sqlplus using an
    easy-connect string built from the arguments.

    Return
    bool -- True iff sqlplus exits with status 0, its output contains the
        OKAY column header and does not start with ERROR.  False is also
        returned when sqlplus cannot be started at all.

    Arguments
    role -- database user to connect as
    password -- password for that user
    dbhost -- host name or address of the database server
    port -- listener port
    database -- service/database name

    '''
    cmd = ['sqlplus',
           '-R', '3', '-S',
           ('%s/%s@%s:%s/%s' % (role, password, dbhost, port, database))]
    pipe = subprocess.PIPE
    try:
        popen = subprocess.Popen(cmd, stdin=pipe, stdout=pipe, stderr=pipe)
    except OSError:
        # sqlplus is missing or not executable: report it as "cannot query"
        # rather than aborting the wizard with a traceback.
        sys.stderr.write('Unable to run sqlplus. Please confirm it is installed and on your PATH.\n')
        return False
    # communicate() waits for the process to finish, so returncode is set.
    out_err = popen.communicate('SELECT 1 AS OKAY FROM DUAL;')
    rc = ((0 == popen.returncode)
          and ('OKAY\n--' in out_err[0])
          and (not out_err[0].startswith('ERROR')))
    return rc


def create_ssh_conf():
    '''Attempt to create the config file.

    Complains and exits if the attempt fails. Good luck testing.

    '''
    try:
        conf_file = open(CONFIG_DATA.CONFIG_FILE, 'a')
    except IOError, exc:
        # !!! Early exit
        sys.stderr.write('Unable to write to config file %r\n' % (CONFIG_DATA.CONFIG_FILE))
        sys.stderr.write('    %s\n' % (exc))
        sys.stderr.write('Exiting.\n')
        sys.exit(1)
    else:
        conf_file.close()


def write_ssh_conf(key, value):
    '''Write a config key and value to a config file SSHTunnel section.

    Argumnents
    key -- the storage key
    value -- the storage value

    '''
    conf = ConfigParser.ConfigParser()
    if os.path.exists(CONFIG_DATA.CONFIG_FILE):
        conf.read(CONFIG_DATA.CONFIG_FILE)
    section = CONFIG_DATA.SSHTUNNEL_SECTION
    if section not in conf.sections():
        conf.add_section(section)
    conf.set(section, key, value)
    try:
        f = open(CONFIG_DATA.CONFIG_FILE, 'w')
    except IOError, exc:
        sys.stderr.write('Unable to open config file for writing: %s\n' % (exc))
        sys.stderr.write('Exiting.\n')
        sys.exit(1)
    conf.write(f)
    f.close()


def write_ssh_known_hosts(hostname, pubkey):
    '''Fetch the server's RSA host key and store it as the known_hosts file.

    The key is obtained with ssh-keyscan for both the host name and its
    resolved address.  The process exits with status 1 if the scan fails,
    if pubkey is not contained in the scanned output, or if the
    known_hosts file cannot be opened for writing.

    Arguments
    hostname -- name of the remote ssh server
    pubkey -- public key text the scanned output must contain

    '''
    # Since subprocess.check_output is only on Python 2.7+, check_output is implemented for <=2.6
    try:
        check_output = subprocess.check_output
    except AttributeError:
        def check_output(*popenargs, **kwargs):
            if 'stdout' in kwargs:
                raise ValueError('stdout argument not allowed, it will be overridden.')
            process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs)
            output, unused_err = process.communicate()
            retcode = process.poll()
            if retcode:
                cmd = kwargs.get("args")
                if cmd is None:
                    cmd = popenargs[0]
                # CalledProcessError only accepts an output argument from
                # 2.7 on, which is exactly when this fallback is not used;
                # attach the output as an attribute instead.
                error = subprocess.CalledProcessError(retcode, cmd)
                error.output = output
                raise error
            return output
    null = open(os.devnull, 'w')
    try:
        try:
            key = check_output([
                'ssh-keyscan',
                '-t', 'rsa',
                '%s,%s' % (hostname, socket.gethostbyname(hostname))
            ], stderr=null)
        except (OSError, socket.error, subprocess.CalledProcessError), exc:
            # Name resolution failed, ssh-keyscan is missing, or it exited
            # with a non-zero status.
            TermColor.print_error('Unable to retrieve the ssh host key for %s: %s' % (hostname, exc))
            sys.exit(1)
    finally:
        null.close()
    if pubkey not in key:
        TermColor.print_error('Server ssh key mismatch. Please contact support@chartio.com')
        sys.exit(1)
    try:
        f = open(CONFIG_DATA.SSH_KNOWNHOSTS, 'w')
    except IOError, exc:
        sys.stderr.write('Unable to open ssh known_hosts file for writing: %s\n' % (exc))
        sys.exit(1)
    try:
        f.write(key)
    finally:
        f.close()


def get_ssh_conf_value(key):
    '''Look up a key in the SSHTunnel section of the config file.

    Return
    string or None -- the stored value; None when the config file does
        not exist or the section or key is missing.

    Arguments
    key -- the lookup key

    '''
    config_path = CONFIG_DATA.CONFIG_FILE
    if not os.path.exists(config_path):
        return None
    parser = ConfigParser.ConfigParser()
    parser.read(config_path)
    try:
        return parser.get(CONFIG_DATA.SSHTUNNEL_SECTION, key)
    except (ConfigParser.NoOptionError, ConfigParser.NoSectionError):
        return None


def chartio_connect_start():
    '''Launch the chartio_connect daemon for the configured prefix.

    Exits with status 1 when the launch command returns non-zero;
    otherwise pauses ten seconds before reporting that it is running.

    '''
    TermColor.print_delay('Launching chartio_connect')
    launch_cmd = ['chartio_connect',
                  '-d',
                  '--prefix=%s' % (CONFIG_DATA.PREFIX)]
    if 0 != subprocess.call(launch_cmd):
        TermColor.print_error('Failed to launch chartio_connect. Exiting.')
        sys.exit(1)
    # Wait for connection to establish
    time.sleep(10)
    TermColor.print_delay('chartio_connect running')


def _exit(*args):
    '''Print a farewell and exit with status 0.

    Accepts and ignores any positional arguments so it can be installed
    as a signal handler (main registers it for SIGINT).

    '''
    # Emit a newline first so the message starts on a line of its own.
    print ''
    TermColor.print_ok('Exiting')
    sys.exit(0)


def opt_args_gather():
    '''Parse the command line.

    Exits with status 1 when both --oracle and --presto are given.

    Return
    tuple -- the (options, args) pair produced by optparse.

    '''
    parser = optparse.OptionParser(version='%prog @VERSION@')
    # Note the space before "and": adjacent string literals are joined as-is.
    parser.add_option('--prefix',
                      help=('installation prefix for configuration'
                            ' and runtime information. Defaults to %r' % (PREFIX_DEFAULT)))
    parser.add_option('-H', '--database-host',
                      type='string',
                      default=None,
                      help='name or IP address of database server')
    parser.add_option('-d', '--database-name',
                      type='string',
                      default=None,
                      help='name of local database (Oracle-only)')
    parser.add_option('-p', '--port',
                      type='int',
                      default=None,
                      help=('database listen port.'
                            ' Defaults based on database type.'
                            ' MySQL: 3306, Oracle: 1521, PostgreSQL: 5432'))
    parser.add_option('-r', '--role',
                      type='string',
                      default=None,
                      help='name of existing read-only role/user (Oracle-only)')
    parser.add_option('-s', '--schema',
                      type='string',
                      default='',
                      help='name of database schema (only for Oracle or Presto)')
    parser.add_option('-c', '--catalog',
                      type='string',
                      default='',
                      help='name of database catalog (Presto-only)')
    parser.add_option('--oracle',
                      action='store_true',
                      help='setup connection for an Oracle database')
    parser.add_option('--presto',
                      action='store_true',
                      help='setup connection for a Presto database')
    opt_args = parser.parse_args()
    # The two command line modes are mutually exclusive.
    if opt_args[0].oracle and opt_args[0].presto:
        TermColor.print_error('You may not select both Oracle and Presto.')
        TermColor.print_ok('Exiting')
        sys.exit(1)

    return opt_args


def crontab_update(crontab_entry):
    '''Ask and either add or tell how to add a crontab entry

    Arguments
    crontab_entry -- the entry to add

    '''
    # 2.5 always deletes NamedTemporaryFile, so just use it to find a temp directory
    tmpdir = tempfile.gettempdir()
    crontab_filename = os.path.join(tmpdir, 'chartio.cron')
    state = StateTracker('READING_CURRENT_CRONTAB',
                         'CHECKING_CURRENT_CRONTAB',
                         ('WRITING_TEMPORARY_FILE',
                          'writing temporary file as %s' % (crontab_filename)),
                         ('INSTALLING_NEW_CRONTAB',
                          'installing %s as new crontab' % (crontab_filename)),
                         ('REMOVING_TEMPORARY_FILE',
                          'removing %s temporary file' % (crontab_filename)))
    TermColor.print_header('Would you like a crontab entry added to reconnect'
                           ' to Chartio on reboot?')
    add_crontab = get_value('[y/n]', default=None,
                            validate=is_yes_no,
                            validate_explanation='Please enter yes or no')
    if bool_from_yes_no(add_crontab):
        try:
            proc = subprocess.Popen(['crontab', '-l'],
                                    stdout=subprocess.PIPE,
                                    stderr=subprocess.PIPE,
                                    close_fds=True)
            (contents, stderr) = proc.communicate()
            if stderr.startswith('no crontab for'):
                state.advance()
            else:
                if 0 != proc.returncode:
                    proc = None
                    raise RuntimeError(stderr)
                state.advance()
            if 0 <= contents.find(crontab_entry):
                raise RuntimeError('crontab appears to already run chartio_connect')
            state.advance()
            crontab = file(crontab_filename, 'w')
            crontab.write(contents)
            crontab.write(crontab_entry + '\n')
            crontab.close()
            state.advance()
            proc_returncode = subprocess.call(['crontab', crontab_filename])
            if 0 == proc_returncode:
                state.advance()
                os.unlink(crontab_filename)
        except Exception, e:
            TermColor.print_error('Failed while %s:\n'
                                  '    %s' % (str(state), str(e)))
            if state == 'INSTALLING_NEW_CRONTAB':
                TermColor.print_header('Modified crontab may be found in')
                TermColor.print_header('    %s' % (crontab_filename))
    if state < 'INSTALLING_NEW_CRONTAB':
        TermColor.print_ok('\nTo reconnect automatically to Chartio after a reboot,'
                           ' edit your crontab by typing')
        TermColor.print_header('  crontab -e')
        TermColor.print_ok('and adding this entry:')
        TermColor.print_header(crontab_entry)


def main():
    '''Run the interactive Chartio setup wizard.

    Steps, in order: parse options and prepare the configuration
    directory, confirm chartio_connect is installed, compare versions with
    the server, log in, generate an SSH key, record the remote tunnel host
    and its host key, pick an organization, collect the database settings
    (Oracle or Presto from the command line, otherwise interactively),
    create the datasource, start chartio_connect and offer a crontab
    entry.  Most failures print a message and exit.

    '''
    # Handle control-c
    signal.signal(signal.SIGINT, _exit)
    TermColor.print_ok('Welcome to the Chartio setup wizard.')
    # This exits on failure
    (options, _args) = opt_args_gather()
    # This exits on failure
    config_data_set(options.prefix)
    # This exits on failure
    create_ssh_conf()
    # Confirm things have been installed
    proc = subprocess.Popen(['which', 'chartio_connect'],
                            stderr=subprocess.STDOUT,
                            stdout=subprocess.PIPE)
    (conn_location, _which_err) = proc.communicate()
    if 0 != proc.returncode:
        TermColor.print_error('Chartio does not appear installed. Please run\n'
                              '  easy_install chartio\n'
                              'or\n'
                              '  python setup.py install')
        # Without chartio_connect the output of `which` is an error message,
        # not a path, and it would end up in the crontab entry below.
        sys.exit(1)
    conn_location = os.path.abspath(conn_location).strip()

    # Instantiate API poster
    chartio_api = Poster()

    # Version check
    TermColor.print_header('Checking current Chartio version')
    latest_version = chartio_api.post('/connectionclient/version/')
    if latest_version.strip() != VERSION:
        TermColor.print_error('The current Chartio version is %s\n'
                              'This setup client is version %s.\n'
                              'Please consider upgrading chartio before continuing.'
                              % (latest_version.strip(), VERSION))
        TermColor.print_header('Would you like to halt setup in order to upgrade to the current version?')
        version_continue = get_value('[y/n]', default=None,
                                     validate=is_yes_no,
                                     validate_explanation='Please enter yes or no')
        if not bool_from_yes_no(version_continue):
            TermColor.print_header('Continuing despite version mismatch.')
        else:
            TermColor.print_ok('Exiting')
            sys.exit(0)

    if os.path.exists(CONFIG_DATA.SSH_KEY):
        TermColor.print_error('Existing configuration detected.\n'
            'Please delete %s directory contents or specify unique prefix by passing the --prefix argument' % CONFIG_DATA.PREFIX)
        sys.exit(0)

    # Login
    LOGIN_ATTEMPT_LIMIT = 3
    for _login_attempt in range(LOGIN_ATTEMPT_LIMIT):
        TermColor.print_ok('Enter your Chartio login information:')
        email = get_value('e-mail',
                          validate=lambda x: '@' in x,
                          validate_explanation='This is not a valid email')
        password = get_value('password', is_password=True)

        # Login user
        response = chartio_api.post('/connectionclient/login/',
                                    {'email': email,
                                     'password': password},
                                    offer_post=False)
        if response == 'success':
            TermColor.print_delay('Username and password confirmed')
            break
        else:
            TermColor.print_error(response)
    else:
        TermColor.print_error('Login tries exceeded.')
        sys.exit(1)

    TermColor.print_delay('Checking for existing SSH keys')
    if os.path.exists(CONFIG_DATA.SSH_KEY):
        TermColor.print_delay('SSH key found. Using the existing SSH key.')
    else:
        TermColor.print_delay('Generating keys for SSH tunneling')
        ret = subprocess.call([
            'ssh-keygen',
            '-q',  # shhh!
            '-N', '',  # No passphrase
            '-C', 'chartio.com ssh tunneling',
            '-t', 'rsa',
            '-f', CONFIG_DATA.SSH_KEY,
        ])

        if ret != 0:
            TermColor.print_error('Failed to generate SSH key. Please confirm you have'
                                  ' ssh-keygen installed.')
            sys.exit(1)
        TermColor.print_delay('Generated SSH keys.')

    # Get the host name of the remote server that will connect to the db
    TermColor.print_delay('Creating tunnel connection')
    response = json.loads(chartio_api.post('/connectionclient/remotehost/'))
    write_ssh_conf('remotehost', response['remotehost'])
    write_ssh_known_hosts(response['remotehost'], response['ssh_host_pub'])

    # Confirm we can connect on port 22 to remotehost
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
    sock.settimeout(10)
    try:
        sock.connect((get_ssh_conf_value('remotehost'), 22))
    except socket.error, exc:
        # Argument order matches the placeholders: host first, then the error.
        TermColor.print_error('Unable to connect to Chartio server %s on port 22.\n'
                              '    %s\n'
                              ' Cannot continue.'
                              % (get_ssh_conf_value('remotehost'), exc))
        sys.exit(1)
    else:
        sock.close()

    # Get the organizations
    organizations = json.loads(chartio_api.post('/connectionclient/organizations/')).get('organizations')
    if not organizations:
        print ('\nUnable to find organizations for your account. You must create or join an'
               ' organization through the Chartio web interface before running this.')
        sys.exit(1)
    elif len(organizations) == 1:
        org = organizations[0]
    else:
        org_map = dict([(p['name'], p) for p in organizations])
        org_name = get_choice('\nYou have multiple organizations.'
                              ' To which organization would you like to attach this database',
                              sorted(org_map.keys()))
        org = org_map[org_name]

    if options.oracle:
        # Cli-mode
        choices = database_choices(chartio_api)
        db_choice = choices.get('Oracle')
        if db_choice is None:
            print 'Chartio does not yet support connecting an Oracle datasource.'
            sys.exit(1)
        if not options.database_name:
            print 'Must provide a database name (--database-name) when connecting an Oracle database'
            sys.exit(1)
        if not options.role:
            print 'Must provide an existing read-only role (--role) when connecting an Oracle database'
            sys.exit(1)
        dbhost = (options.database_host
                  or db_choice['default_host'])
        local_port = (options.port
                      or db_choice['default_port'])
        if not can_connect_on_host_port(dbhost, local_port):
            print ('Database is not listening on %s and port %s' % (dbhost, local_port))
            sys.exit(1)
        password = get_value('Read-only role password', is_password=True, validate=empty_is_ok)
        # can_query_oracle takes (role, password, dbhost, port, database).
        if not can_query_oracle(options.role, password, dbhost, local_port, options.database_name):
            print ('Unable to connect to Oracle database %r as %r'
                   % (options.database_name, options.role))
            sys.exit(1)
        # Use the resolved host (command line value or the default) so a
        # missing --database-host does not get stored as None.
        ds_kwargs = {'db_type_id': db_choice['id'],
                     'organization_id': org['id'],
                     'remotehost': get_ssh_conf_value('remotehost'),
                     'dbhost': dbhost,
                     'name': options.database_name,
                     'user': options.role,
                     'passwd': password,
                     'schema': options.schema,
                     }
        write_ssh_conf('dbhost', dbhost)
        write_ssh_conf('localport', local_port)
    elif options.presto:
        # Cli-mode
        choices = database_choices(chartio_api)
        db_choice = choices.get('Presto')
        if db_choice is None:
            print 'Chartio does not yet support connecting a Presto datasource.'
            sys.exit(1)
        dbhost = (options.database_host
                  or db_choice['default_host'])
        local_port = (options.port
                      or db_choice['default_port'])
        if not can_connect_on_host_port(dbhost, local_port):
            print ('Database is not listening on %s and port %s' % (dbhost, local_port))
            sys.exit(1)
        ds_kwargs = {'db_type_id': db_choice['id'],
                     'organization_id': org['id'],
                     'remotehost': get_ssh_conf_value('remotehost'),
                     'dbhost': dbhost,
                     'port': local_port,
                     'catalog': options.catalog,
                     'schema': options.schema,
                     }
        write_ssh_conf('dbhost', dbhost)
        write_ssh_conf('localport', local_port)

    else:
        # Interactive setup
        db_accessors = {'mysql': MysqlAccessor,
                        'postgresql': PostgresqlAccessor,
                        'amazonredshift': PostgresqlAccessor}
        db_admin = DbAdmin(chartio_api, db_accessors)
        # Each step returns the next step to run; the loop ends on the
        # first falsy result, where None signals success.
        db_step = db_admin.start_step()
        while db_step:
            db_step = db_step()
        if db_step is None:
            TermColor.print_delay('Finished configuring database information')
        elif not db_step:
            TermColor.print_error('Exiting.')
            sys.exit(1)
        ds_kwargs = {'db_type_id': db_admin.settings.database_id,
                     'organization_id': org['id'],
                     'remotehost': get_ssh_conf_value('remotehost'),
                     'dbhost': db_admin.settings.database_host,
                     'name': db_admin.settings.database_name,
                     'user': db_admin.settings.readonly_user,
                     'schema': db_admin.settings.schema,
                     'passwd': db_admin.settings.readonly_password}
        write_ssh_conf('dbhost', db_admin.settings.database_host)
        write_ssh_conf('localport', db_admin.settings.database_port)

    # Create the datasource
    TermColor.print_delay('''Creating the datasource and the tunnel account on Chartio's server.'''
                          ''' This will take a moment.''')
    pub_key_file = open('%s.pub' % CONFIG_DATA.SSH_KEY)
    try:
        ssh_key = pub_key_file.read()
    finally:
        pub_key_file.close()
    ds_kwargs.update({'customer_ssh_pub_key': ssh_key,
                      'version': VERSION})
    response = chartio_api.post('/connectionclient/create/', ds_kwargs)
    response = json.loads(response)
    write_ssh_conf('remoteuser', response['connection']['server_username'])
    write_ssh_conf('remoteport', response['connection']['port'])
    datasource_id = response['datasource_id']
    ds_kwargs.update({'datasource_id': datasource_id, 'remoteport': response['connection']['port']})
    TermColor.print_delay('Datasource and Tunnel account created')

    chartio_connect_start()

    TermColor.print_delay('Datasource registered. chartio_connect is running.\n')
    crontab_entry = '@reboot %s -d --prefix=%s' % (conn_location, CONFIG_DATA.PREFIX)
    crontab_update(crontab_entry)
    TermColor.print_ok('\nYour datasource schema is currently being refreshed. '
                       '\nPlease check the status on the datasources settings '
                       'page on https://chartio.com')

# Run the wizard only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
