#!/usr/bin/env python

# Copyright (c) 2011-2014. All Right Reserved, https://chartio.com/
#
# THIS CODE AND INFORMATION ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY
# KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
# PARTICULAR PURPOSE.

'''Establish and maintain an SSH connection providing a remote tunnel to a database.

    Has limited support for restarting the SSH process.

    Basically, runs (remote_port:database_host:local_port)
        ssh -R 12345:localhost:3306

    Designed for use on POSIX systems only (signals, forking, etc.).

'''

import ConfigParser
import logging
import logging.handlers
import optparse
import os
import signal
import subprocess
import sys
import time

from version import __version__

# Messages intended for the user; assigned by logging_init().
console = None
# Legacy name kept for backward compatibility; the code in this file uses
# ``console`` instead.
console_log = None
# Daemon-generated messages; assigned by logging_init().
daemon_log = None
# Global config instance; assigned by config_data_set().
CONFIG_DATA = None
# Default install location
PREFIX_DEFAULT = '~/.chartio.d'
# Distributed version
VERSION = __version__
# Port the remote SSH server listens on (passed to ssh -p)
SSH_LISTEN_PORT = 22

def config_data_set(prefix):
    '''Build the installation layout for *prefix* and publish it globally.

    Arguments
        prefix -- install prefix, or None/'' for the default location

    Side effect
        Rebinds the module-level CONFIG_DATA to a new ConfigData instance.
    '''
    global CONFIG_DATA
    layout = ConfigData(prefix)
    CONFIG_DATA = layout

class ConfigData(object):
    '''Filesystem layout (directories, files, config section names) of an install.

    Needs to be kept in sync with chartio_setup to limit
    PYTHONPATH dependencies.

    Constructing an instance verifies that the prefix, logs, run and sshkey
    directories exist, and exits the process with status 1 if one is missing.
    '''
    def __init__(self, prefix=None):
        '''Arguments
            prefix -- install prefix; falls back to PREFIX_DEFAULT when falsy.
                      A leading ``~`` is expanded in either case.
        '''
        # Directories. Expanding the user-supplied prefix too handles
        # ``--prefix=~/x``, where the shell does not expand the tilde.
        self.PREFIX = os.path.expanduser(prefix or PREFIX_DEFAULT)
        self.LOG_DIRECTORY = os.path.join(self.PREFIX, 'logs')
        self.RUN_DIRECTORY = os.path.join(self.PREFIX, 'run')
        self.SSH_DIRECTORY = os.path.join(self.PREFIX, 'sshkey')
        # Files
        self.CONFIG_FILE = os.path.join(self.PREFIX, 'chartio.cfg')
        self.LOG_FILE = os.path.join(self.LOG_DIRECTORY, 'chartio_connect.log')
        self.PID_FILE = os.path.join(self.RUN_DIRECTORY, 'chartio_connect.pid')
        self.SSH_KEY = os.path.join(self.SSH_DIRECTORY, 'id_rsa')
        self.SSH_KNOWNHOSTS = os.path.join(self.SSH_DIRECTORY, 'known_hosts')
        # Config file sections
        self.SSHTUNNEL_SECTION = 'SSHTunnel'
        # Ensure required directories exist. An explicit loop (instead of
        # map() used for its side effect) guarantees the checks run even
        # where map() is lazy.
        for directory in (self.PREFIX, self.LOG_DIRECTORY,
                          self.RUN_DIRECTORY, self.SSH_DIRECTORY):
            self.dir_exists(directory)

    def dir_exists(self, dir):
        '''Verify that a directory exists. Exit gracefully if it does not.

        Arguments
            dir -- the directory path to check

        Exit
            1 -- the directory does not exist (a hint is written to stderr)

        '''
        if not os.path.isdir(dir):
            sys.stderr.write('The directory %r does not exist.\n'
                             'Please run chartio_setup or rerun chartio_connect and specify --prefix\n'
                             'Exiting\n' % (dir))
            sys.exit(1)


class PidFile(object):
    '''PID file manager.

    The file at ``path`` holds a single integer PID as text; an empty or
    missing file means "no PID recorded".
    '''
    def __init__(self, path):
        self.path = path

    def clear(self):
        '''Truncate, i.e., empty, the pid file.

        Best-effort: if the file cannot be opened for writing (e.g. its
        directory is missing) the error is ignored.
        '''
        try:
            open(self.path, 'w').close()
        except IOError:
            pass

    def read(self):
        '''Extract the PID from the file.

        Return
            int | None -- PID value on success or None if the file is
                          empty or cannot be opened

        Raise
            ValueError -- if the file contains a non-integer value

        '''
        try:
            pid_file = open(self.path)
        except IOError:
            # Missing/unreadable file is treated like an empty one.
            return None
        # ``with`` closes the file even if read() fails.
        with pid_file:
            pid_str = pid_file.read()
        if not pid_str.strip():
            return None
        return int(pid_str)

    def write(self, pid):
        '''Write a PID value into the file.

        Raise
            ValueError -- non-integer-looking pid argument (raised before
                          the file is touched)
            IOError -- unable to open file

        Arguments
            pid -- the PID value to write. note that it is converted via int().

        '''
        pid_str = str(int(pid))
        # ``with`` closes the file on every path and keeps the original
        # traceback if write() fails.
        with open(self.path, 'w') as pid_file:
            pid_file.write(pid_str)


class SSHTunnel(object):
    '''Establish and supervise an ``ssh -R`` reverse tunnel.

    The tunnel makes ``remote_port`` on the remote host forward to
    ``database_host:local_port`` as reachable from this machine. When the ssh
    subprocess exits, a SIGCHLD handler restarts it with exponential backoff
    and exits once the retry budget is exhausted. The PID of the (daemon)
    process is kept in CONFIG_DATA.PID_FILE via the ``pid`` property so that
    ``kill`` can stop a running instance.

    Uses the module globals CONFIG_DATA, daemon_log and console, so
    config_data_set() and logging_init() must have been called first.
    '''
    # Number of consecutive times to retry before exiting
    RETRIES = 15
    # Initial number of seconds between retries. Doubles for backoff.
    RETRY_DELAY = 10
    # Seconds between pings to keep tunnel alive
    PING_DELAY = 10

    def __init__(self, remote_user, remote_host, remote_port, database_host, local_port):
        '''Constructor

        Arguments
            remote_user -- remote user name
            remote_host -- abc.chartio.com host
            remote_port -- remote tunnel port
            database_host -- database server ip or hostname
            local_port -- database listening port. (3306, 5432, etc.)

        Typical values:
                host: amz123.chartio.com
                remote_port: 12345 (some big number)
                database_host: 127.0.0.1 or mydb.server.com
                local_port: 3306 (mysql/pg port)

        Raise
            ValueError -- remote_port or local_port is not integer-like

        '''
        self.host = remote_host
        self.user = remote_user
        self.remote_port = int(remote_port)
        self.database_host = database_host
        self.local_port = int(local_port)
        self.pid_file = PidFile(CONFIG_DATA.PID_FILE)

        self._pid = None
        self._ssh_process = None
        self._retries = self.RETRIES
        self._shutdown = False
        self._retry_delay = self.RETRY_DELAY

    # Store PID in a file. Get, set, and delete as property.
    # This lets a parent process set the PID and the child process load it in.
    def _get_pid(self):
        if self._pid is None:
            daemon_log.info('Reading pid file')
            self._pid = self.pid_file.read()
        return self._pid

    def _set_pid(self, pid):
        daemon_log.info('Writing pid %s', pid)
        self.pid_file.write(pid)
        self._pid = pid

    def _del_pid(self):
        daemon_log.info('Clearing pid file')
        self.pid_file.clear()
        # Drop the cached value so a later read reflects the cleared file.
        self._pid = None

    pid = property(_get_pid, _set_pid, _del_pid)

    def _sig_chld(self, signum, frame):
        '''SIGCHLD handler: restart the ssh subprocess after it exits.

        Sleeps for the current backoff delay, doubles it, and reconnects.
        Exits with status 1 once RETRIES restarts have been used without
        _sig_alarm resetting the budget.
        '''
        if self._shutdown:
            return
        proc = self._ssh_process
        if proc is not None:
            # Reap the child so it does not linger as a zombie. SIGCHLD may
            # also be delivered when the child is merely stopped/continued;
            # then poll() returns None and there is nothing to restart.
            returncode = proc.poll()
            if returncode is None:
                return
            daemon_log.warning('SSH process %d exited with status %s.', proc.pid, returncode)
        daemon_log.warning('Subprocess died. Attempting to restart.')
        if self._retries == 0:
            daemon_log.error('Reached retry limit (%d). Exiting.', self.RETRIES)
            # Clear the PID file so a later "stop" does not signal a recycled
            # PID that now belongs to an unrelated process.
            del self.pid
            sys.exit(1)
        daemon_log.info('Waiting %d seconds before connection attempt', self._retry_delay)
        # If the next connection is still up when this alarm fires,
        # _sig_alarm resets the retry counter and backoff delay.
        signal.signal(signal.SIGALRM, self._sig_alarm)
        signal.alarm(self._retry_delay + 120)
        time.sleep(self._retry_delay)
        self._retries -= 1
        self._retry_delay *= 2
        self._make_connection()

    def _sig_alarm(self, signum, frame):
        '''SIGALRM handler. Reaching this means the SSH connection probably succeeded.'''
        # Log through daemon_log (the daemon's log file), not the root logger.
        daemon_log.info('Connection appears to be working. Resetting retry information.')
        self._retries = self.RETRIES
        self._retry_delay = self.RETRY_DELAY

    def _make_connection(self, debug=False):
        '''Create the SSH connection subprocess.

        Arguments
            debug -- if true, do not spawn anything; instead return the
                     command line (with -vvv added) as a single string

        Return
            str when debug is true, otherwise None
        '''
        forward_only = '-N'
        remote_tunnel = '%d:%s:%d' % (self.remote_port, self.database_host, self.local_port)
        user_host = '%s@%s' % (self.user, self.host)
        permit_remote_connect = '-g'
        keep_alive_opt = 'ServerAliveInterval=%s' % (self.PING_DELAY)
        # Make ssh exit after one unanswered keep-alive so SIGCHLD restarts it.
        exit_if_keep_alive_fails_opt = 'ServerAliveCountMax=1'
        known_hosts_opt = 'UserKnownHostsFile=%s' % (CONFIG_DATA.SSH_KNOWNHOSTS)
        args = ['ssh',
                forward_only,
                '-R', remote_tunnel,
                user_host,
                permit_remote_connect,
                '-i', CONFIG_DATA.SSH_KEY,
                '-p', str(SSH_LISTEN_PORT),
                '-o', keep_alive_opt,
                '-o', exit_if_keep_alive_fails_opt,
                '-o', known_hosts_opt,
                ]
        if debug:
            # !!! Early return
            args.insert(1, '-vvv')
            return ' '.join(args)
        daemon_log.info('Making SSH tunnel connection as %r to %r.', self.user, self.host)
        # Install before spawning so an immediate exit is still noticed.
        signal.signal(signal.SIGCHLD, self._sig_chld)
        # ??? write stdout?
        self._ssh_process = subprocess.Popen(args,
                                             stdout=None,
                                             stderr=subprocess.STDOUT)

    def _detach_stdio(self):
        '''Point fds 0, 1 and 2 at /dev/null.

        Leaving them closed would let the next file opened by this process
        (for example a log file reopened on rotation) be assigned fd 0-2,
        after which stray stdout/stderr writes - including output of the ssh
        child, which inherits these fds - would land in that file.
        '''
        try:
            null_fd = os.open(os.devnull, os.O_RDWR)
        except OSError:
            # Fall back to simply closing the standard fds.
            for fd in range(3):
                try:
                    os.close(fd)
                except OSError:
                    pass
            return
        for fd in range(3):
            os.dup2(null_fd, fd)
        if null_fd > 2:
            os.close(null_fd)

    def daemonize(self):
        '''Fork into a daemon process; the parent records the child PID and exits.

        Return
            True for child process

        Exit
            0 -- successfully spawned child process (parent only)
            1 -- failed to spawn child process

        '''
        try:
            pid = os.fork()
        except OSError as exc:
            # os.fork() reports failure by raising, never by returning < 0.
            err_str = exc.strerror or str(exc)
            daemon_log.error('fork() failed: %s', err_str)
            sys.stderr.write('Failed to daemonize. %s\nExiting.\n' % (err_str))
            # !!! Early exit
            sys.exit(1)
        if pid > 0:
            # Parent (calling) process
            console.info('chartio_connect daemonized as process %d', pid)
            daemon_log.info('Daemon process running: %d', pid)
            # Implicitly updates the PID file
            self.pid = pid
            sys.exit(0)
        # Child (daemon) process
        # Set session id to disassociate from tty
        os.setsid()
        self._detach_stdio()
        return True

    def cleanup(self, signum, frame):
        '''SIGTERM/SIGINT handler: stop ssh, clear the PID file, exit with 0.'''
        daemon_log.info('Running cleanup')
        # Disable special SIGCHLD handling so killing ssh does not restart it
        self._shutdown = True
        proc = self._ssh_process
        # poll() also reaps an already-dead child; only signal a live one so
        # a possibly recycled PID is never targeted.
        if proc is not None and proc.poll() is None:
            daemon_log.debug('Killing SSH process %s', proc.pid)
            try:
                proc.kill()
            except OSError:
                # Exited between poll() and kill()
                pass
            # Reap the killed child instead of sleeping and hoping.
            proc.wait()
        del self.pid
        logging.shutdown()
        # !!! Expected exit
        sys.exit(0)

    @classmethod
    def kill(cls, silent=False):
        '''Attempt to stop the daemon process recorded in the PID file.

        Sends SIGINT (handled by ``cleanup`` in the daemon) and waits one
        second. A missing, empty or corrupt PID file, or an already-dead
        process, is reported on the console but is not an error.

        Arguments
            silent -- hide INFO-level console messages while running
        '''
        old_level = console.getEffectiveLevel()
        if silent:
            console.setLevel(logging.WARNING)
        try:
            pid_file = PidFile(CONFIG_DATA.PID_FILE)
            if not os.path.exists(pid_file.path):
                console.info('No PID file found. Unable to kill daemon.')
                return
            console.info('Trying to stop previous running instance.')
            try:
                pid = pid_file.read()
            except ValueError:
                console.info('PID file %s is corrupt; ignoring it.', pid_file.path)
                pid = None
            if pid is None:
                console.info('Daemon does not seem to be running.')
                return
            console.info('Stopping daemon at %s', pid)
            try:
                os.kill(pid, signal.SIGINT)
            except OSError:
                console.info('Previous instance probably is not running anymore.')
            else:
                time.sleep(1)
        finally:
            # Restore the console level even if something above raised.
            if silent:
                console.setLevel(old_level)

    def main(self, daemonize=True):
        '''Entry point for initiating an SSH connection and tunnel.

        Arguments
            daemonize -- fork into the background first (parent exits)
        '''
        mode = 'daemon mode' if daemonize else 'non-daemon mode'
        daemon_log.info('Starting SSH connection and tunnel in %s', mode)
        if daemonize:
            self.daemonize()
        else:
            console.info('Starting in %s. Press CTRL-C to stop.', mode)
        # Install termination handlers before spawning ssh so an early
        # SIGINT/SIGTERM still cleans up (cleanup copes with no process yet).
        signal.signal(signal.SIGTERM, self.cleanup)
        signal.signal(signal.SIGINT, self.cleanup)
        self._make_connection()
        # Wait for next SIGALRM or other terminating signal; termination
        # happens via sys.exit() inside the signal handlers.
        while True:
            signal.pause()


def opt_args_gather(usage):
    '''Parse command-line options and arguments from sys.argv.

    Arguments
        usage -- optparse usage string (may contain %prog)

    Return
     (optparse options object, argument list)

    Exit
        via optparse on --help, --version or invalid options

    '''
    # optparse substitutes %prog itself; report the real distributed version
    # instead of an unsubstituted placeholder.
    parser = optparse.OptionParser(usage, version='%%prog %s' % (VERSION,))
    parser.add_option('-d', '--daemonize', dest='daemonize', action='store_true', default=False,
                      help='Disassociate from the terminal and run in the background')
    parser.add_option('--prefix',
                      help=('Prefix argument (if any) used during chartio_setup configuration.'
                            ' Specifies where to find/place run-time data.'
                            ' Defaults to %r' % (PREFIX_DEFAULT)))
    options, args = parser.parse_args()
    return options, args


def logging_init(backup_count=3):
    '''Configure and begin logging.

    Updates the daemon_log and console global variables:
        daemon_log -- 'ssh_tunnel' logger, DEBUG and above, written to
                      CONFIG_DATA.LOG_FILE and rotated at 1 MB
        console -- 'user' logger, DEBUG and above, bare messages on stdout

    Arguments
        backup_count -- number of rotated log files to keep; must be > 0
                        for rotation to happen at all

    '''
    global daemon_log
    global console
    daemon_log = logging.getLogger('ssh_tunnel')
    # Rotate log at 1 MB. RotatingFileHandler never rolls over when
    # backupCount is 0, so keep at least one backup or the file grows forever.
    file_handler = logging.handlers.RotatingFileHandler(CONFIG_DATA.LOG_FILE,
                                                        maxBytes=(2**20),
                                                        backupCount=max(1, int(backup_count)))
    verbose_fmt = logging.Formatter('%(levelname)s @ %(asctime)s %(filename)s:%(lineno)d %(message)s',
                                    datefmt='%Y-%m-%d %H:%M:%S')
    file_handler.setFormatter(verbose_fmt)
    daemon_log.addHandler(file_handler)
    daemon_log.setLevel(logging.DEBUG)

    console = logging.getLogger('user')
    stdout_handler = logging.StreamHandler(sys.stdout)
    plain_fmt = logging.Formatter('%(message)s')
    stdout_handler.setFormatter(plain_fmt)
    console.addHandler(stdout_handler)
    console.setLevel(logging.DEBUG)


def main():
    '''Command-line entry point: parse options, load config, dispatch the action.

    Actions: start/restart (stop any previous instance, then run), stop,
    debug (print the ssh command line), help.
    '''
    cmd_args = ('start', 'restart', 'stop', 'debug')
    usage = 'Usage: %%prog [options] [%s]' % (' | '.join(cmd_args))
    options, args = opt_args_gather(usage)
    config_data_set(options.prefix)

    logging_init()

    # Get configuration
    config_file = CONFIG_DATA.CONFIG_FILE
    # isfile() is False for missing paths too, so no separate exists() check.
    if not os.path.isfile(config_file):
        sys.stderr.write('Config file not found: %s.\n' % (config_file))
        sys.stderr.write('Please re-run installation of chartio or specify a different prefix.\n')
        sys.stderr.write('Exiting.\n')
        sys.exit(1)
    config = ConfigParser.ConfigParser()
    config.read(config_file)
    try:
        conf = dict(config.items(CONFIG_DATA.SSHTUNNEL_SECTION))
    except ConfigParser.NoSectionError:
        sys.stderr.write('Config file appears empty. Please run chartio_setup\n')
        sys.exit(1)
    # Report missing settings clearly instead of crashing with a KeyError.
    required = ('remoteuser', 'remotehost', 'remoteport', 'localport')
    missing = [key for key in required if key not in conf]
    if missing:
        sys.stderr.write('Config file %s is missing required setting(s) in [%s]: %s.\n'
                         'Please run chartio_setup\n'
                         % (config_file, CONFIG_DATA.SSHTUNNEL_SECTION, ', '.join(missing)))
        sys.exit(1)
    dbhost = conf.get('dbhost', '127.0.0.1')

    tunnel = SSHTunnel(conf['remoteuser'], conf['remotehost'],
                       conf['remoteport'], dbhost, conf['localport'])
    action = args[0] if args else 'start'
    if action in ('start', 'restart'):
        tunnel.kill(silent=True)
        tunnel.main(options.daemonize)
    elif action == 'debug':
        console.info('This program would have run the following command:')
        console.info(tunnel._make_connection(debug=True))
    elif action == 'stop':
        tunnel.kill(silent=False)
    elif action == 'help':
        # Useful kluge: let optparse print the help text and exit
        sys.argv.append('--help')
        opt_args_gather(usage)
    else:
        sys.stderr.write('Invalid action: %r.\n'
                         'If specified, it must be one of %s\n' % (action, ', '.join(cmd_args)))
        sys.exit(1)

if __name__ == '__main__':
    # Run the command-line interface only when executed as a script,
    # not when this module is imported.
    main()
