# -*- coding:utf-8 -*-
'''
Django DB backend providing wrapper for two connections: master connection
for read-write operations and slave connection for read-only operations.
'''
from django.db.backends import util

from django.db.backends.mysql.base import DatabaseWrapper as MySQLDatabaseWrapper, django_conversions, \
     Database, DatabaseClient, DatabaseCreation, DatabaseError, DatabaseFeatures, DatabaseIntrospection, \
     DatabaseOperations, DatabaseValidation, CursorWrapper, django_conversions, FIELD_TYPE, FLAG, CLIENT, \
     SafeString, SafeUnicode

from django.db.backends.signals import connection_created
from django.core import signals as core_signals
from django.db.utils import ConnectionDoesNotExist
from django.db.transaction import TransactionManagementError

from exceptions import KeyError

from collections import OrderedDict
import functools
import time

import codecs
# Register a codec search function that resolves the name 'utf8mb4' to
# Python's own 'utf8' codec and declines (returns None) for any other name.
# 'utf8mb4' is the charset DatabaseWrapper passes to the MySQL connection;
# presumably Python does not know that name on its own -- TODO confirm.
codecs.register(lambda name: codecs.lookup('utf8') if name == 'utf8mb4' else None)

class FixedOrderedDict(OrderedDict):
    """
    OrderedDict whose destructor never reports a failure.

    Fixes OrderedDict to not print error messages when global bindings fail in
    the destructor. Ignoring the failure is safe: the object is being
    destroyed anyway, so there is nothing left to clean up.
    """
    def __del__(self):
        try:
            super(FixedOrderedDict, self).__del__()
        except Exception:
            # Deliberately best-effort: any ordinary error raised while tearing
            # the dict down is ignored.  Only Exception subclasses are caught so
            # that SystemExit / KeyboardInterrupt are not swallowed.
            pass

def pre_save_callback(**kwargs):
    """
    Signal receiver for model saves: forces the connection onto the master.

    Any keyword arguments supplied by the signal are accepted and ignored.
    """
    # Imported at call time rather than at module load time
    from django.db import connection as db_connection
    # Every save has to be issued against the master database
    db_connection.use_master()


def request_started_callback(**kwargs):
    """
    Signal receiver for request start: puts the connection back on the slave.

    Any keyword arguments supplied by the signal are accepted and ignored.
    """
    # Imported at call time rather than at module load time
    from django.db import connection as db_connection
    # Each new request starts out reading from a slave again
    db_connection.use_slave()

# Connect the request-start handler, which resets the connection to the slave
# at the beginning of every request by calling connection.use_slave().
# Any attribute access on the connection goes through
# DatabaseWrapper.__getattribute__, which connects the pre_save handler on
# first use.  The pre_save callback is therefore guaranteed to be hooked up
# before any model operation takes place.
# Note: The pre_save callback cannot be connected here first, because the
#       pre_save signal is not available until AFTER the connection class has
#       been initialised.  Attempting to load the pre_save signal at this point
#       results in an infinite recursion loop, since loading the signal tries
#       to load the connection.
core_signals.request_started.connect(request_started_callback)

from django.conf import settings
import random

class DatabaseWrapper(MySQLDatabaseWrapper):
    '''
    Main wrapper class used to access database for all database
    operations in Django.

    Only one underlying connection is held at a time.  Whenever a cursor is
    requested the connection is re-opened against the master host or against
    a randomly chosen slave host if the requested mode (see use_master() and
    use_slave()) differs from the mode of the current connection.

    The wrapper also keeps a queue of callbacks registered through
    on_commit() / on_rollback(); they are fired from commit(), rollback(),
    savepoint_commit() and savepoint_rollback().
    '''

    def __init__(self, *args, **kwargs):
        '''
        Sets up the master/slave state and the callback queues, then defers
        to the MySQL wrapper's constructor with the unchanged arguments.

        The connection alias is read from the second positional argument or,
        when it is not passed positionally, from the ``alias`` keyword,
        falling back to Django's default alias.
        '''
        # Local import: keeps this class free of a new file-level dependency
        from django.db.utils import DEFAULT_DB_ALIAS

        # NOTE: only attribute *writes* may happen before __signals_connected
        # is set, because every attribute read goes through __getattribute__
        # below, which itself reads that flag.
        self.__use_master = False
        self.__is_master = False
        self.__signals_connected = False
        self.__callbacks = []
        self.__savepoints = FixedOrderedDict()
        if len(args) > 1:
            self.__db_key = args[1]
        else:
            # The alias may legitimately be passed by keyword or omitted
            self.__db_key = kwargs.get('alias', DEFAULT_DB_ALIAS)
        super(DatabaseWrapper, self).__init__(*args, **kwargs)

    def __getattribute__(self, name):
        '''
        Connects the pre_save callback on the first attribute access, then
        performs the normal attribute lookup.
        '''
        # The pre_save callback cannot be connected at module level (loading
        # the model signals from there recurses into loading the connection),
        # so it is connected lazily here instead.  The flag itself is excluded
        # from the check, otherwise reading it would recurse forever.
        if name != '_DatabaseWrapper__signals_connected' and not self.__signals_connected:
            # This piece of code is re-entrant (signal connect is idempotent)
            from django.db.models import signals as models_signals
            models_signals.pre_save.connect(pre_save_callback)
            self.__signals_connected = True

        return super(DatabaseWrapper, self).__getattribute__(name)

    def switch_key(self, new_key):
        """
        Switches key used when creating a connection to a db based on its name
        new_key specifies a key from settings.DATABASES

        Any open connection is closed and the wrapper is reset to slave mode.
        Raises ConnectionDoesNotExist if new_key is not in settings.DATABASES.
        """
        if new_key not in settings.DATABASES:
            raise ConnectionDoesNotExist("The connection %s doesn't exist" % new_key)

        self.__close_connection()

        self.__use_master = False
        self.__is_master = False
        self.__db_key = new_key
        self.settings_dict = settings.DATABASES[new_key]

    def switch_key_from_parameters(self, **kwargs):
        """
        Function used with sharding
        Switches key used when creating a connection to a db: the keyword
        arguments are passed to settings.LOOKUP_FUNCTION and its result is
        mapped to a settings.DATABASES key through settings.LOOKUP_DICTIONARY.

        Raises KeyError if the lookup result has no (non-empty) entry in the
        lookup dictionary.
        """
        lookup_dictionary = settings.LOOKUP_DICTIONARY
        lookup_function = settings.LOOKUP_FUNCTION

        lookup_value = lookup_function(**kwargs)
        db_key = lookup_dictionary.get(lookup_value, None)
        if not db_key:
            raise KeyError("Lookup function returned a value not specified in lookup dictionary: %s" % lookup_value)
        self.switch_key(db_key)

    def is_master(self):
        """
        Return whether or not we are using master connection
        (i.e. whether the connection that was opened last targets the master)
        """
        return self.__is_master

    def use_master(self):
        '''
        Use the connection to use the master DB
        Returns True if the connection type was changed

        Only the requested mode is recorded here; the connection itself is
        switched the next time a cursor is created.
        '''
        if self.__use_master:
            return False
        self.__use_master = True
        return True

    def use_slave(self):
        '''
        Use the connection to use the slave DB
        Returns True if the connection type was changed

        Only the requested mode is recorded here; the connection itself is
        switched the next time a cursor is created.
        '''
        if not self.__use_master:
            return False
        self.__use_master = False
        return True

    def _cursor(self):
        '''
        Makes sure the connection matches the requested mode and returns a
        wrapped cursor on it.
        '''
        # NOTE(review): changing mode closes the current connection; confirm
        # callers never switch between master and slave inside a transaction.
        if self.__use_master:
            # If we are forcing master, set the existing connection to the master db
            self.__use_master_connection()
        else:
            # Otherwise, use one of the slaves
            self.__use_slave_connection()

        return CursorWrapper(self.connection.cursor())

    # Overriding the original leave_transaction management method because it is broken.
    # See https://code.djangoproject.com/ticket/2227
    def leave_transaction_management(self):
        """
        Leaves transaction management for a running thread. A dirty flag is carried
        over to the surrounding block, as a commit will commit all changes, even
        those from outside. (Commits are on connection level.)

        Raises TransactionManagementError if no transaction management block
        is active, or if the outermost (or an unmanaged) block is left while
        changes are still pending; in the latter case the pending changes
        are rolled back first.
        """
        self._leave_transaction_management(self.is_managed())
        if self.transaction_state:
            del self.transaction_state[-1]
        else:
            raise TransactionManagementError("This code isn't under transaction management")
        if not self.transaction_state or not self.is_managed():
            if self._dirty:
                self.rollback()
                raise TransactionManagementError("Transaction managed block ended with pending COMMIT/ROLLBACK")
            self._dirty = None

    def commit(self):
        """
        Commits the current transaction and notifies any pending callbacks
        """
        super(DatabaseWrapper, self).commit()
        self.__notify(commit=True)

    def rollback(self):
        """
        Rolls back the current transaction and notifies any pending callbacks
        """
        super(DatabaseWrapper, self).rollback()
        self.__notify(commit=False)

    def _commit(self):
        """
        Extend _commit to log debug information
        """
        return self.__run_logged(super(DatabaseWrapper, self)._commit, 'commit')

    def _rollback(self):
        """
        Extend _rollback to log debug information
        """
        return self.__run_logged(super(DatabaseWrapper, self)._rollback, 'rollback')

    def savepoint(self):
        """
        Creates a database savepoint and adds a matching savepoint to the pending callback queue

        Returns the savepoint id handed back by the parent class.
        """
        sid = super(DatabaseWrapper, self).savepoint()
        # Remember how many callbacks existed when the savepoint was taken
        self.__savepoints[sid] = len(self.__callbacks)
        return sid

    def savepoint_commit(self, sid):
        """
        Commits the current transaction since the given savepoint and notifies any pending callbacks that were added since the savepoint
        """
        super(DatabaseWrapper, self).savepoint_commit(sid)
        self.__notify(commit=True, savepoint=sid)

    def savepoint_rollback(self, sid):
        """
        Rolls back the current transaction since the given savepoint and notifies any pending callbacks that were added since the savepoint
        """
        super(DatabaseWrapper, self).savepoint_rollback(sid)
        self.__notify(commit=False, savepoint=sid)

    def on_commit(self, func, *args, **kwargs):
        """
        Adds a callback to be notified when (and if) the current transaction is committed

        func is called as func(*args, **kwargs).  Outside of managed mode it
        is called immediately.
        """
        self.__defer(func=func, args=args, kwargs=kwargs, on_commit=True)

    def on_rollback(self, func, *args, **kwargs):
        """
        Adds a callback to be notified when (and if) the current transaction is rolled back

        func is called as func(*args, **kwargs).  Outside of managed mode the
        callback is discarded without being called.
        """
        self.__defer(func=func, args=args, kwargs=kwargs, on_commit=False)

    def reset_callbacks(self):
        """
        Removes all pending callbacks
        """
        # Emptied in place so the same list object stays in use
        del self.__callbacks[:]
        self.__savepoints.clear()

    def __defer(self, func, args, kwargs, on_commit):
        """
        Adds a function callback if we are in managed mode and otherwise calls the function now

        Outside of managed mode only commit callbacks are called; rollback
        callbacks are dropped.  An identical callback (same kind, function
        and arguments) is queued only once.
        """
        if not self.is_managed():
            if on_commit:
                func(*args, **kwargs)
            return
        item = (on_commit, func, args, kwargs)
        if item not in self.__callbacks:
            self.__callbacks.append(item)

    def __notify(self, commit, savepoint=None):
        """
        Notifies all interested callbacks of the commit or rollback

        Only callbacks queued since the given savepoint are considered (all
        of them when no savepoint is given).  They are removed from the queue
        and those whose kind matches the event are called, newest first; the
        others are discarded.
        """
        idx = self.__pop_index(savepoint)

        # Take the affected callbacks off the queue before running any of them
        pending = self.__callbacks[idx:]
        self.__callbacks = self.__callbacks[:idx]

        for on_commit, func, args, kwargs in reversed(pending):
            if on_commit == commit:
                func(*args, **kwargs)

    def __pop_index(self, savepoint):
        """
        Removes the given savepoint and all later savepoints from the queue.

        returns the associated index into __callbacks, or 0 if the savepoint
        is not in the queue (in which case the whole queue has been emptied)
        """
        while self.__savepoints:
            # popitem() yields the most recently added savepoint first
            key, value = self.__savepoints.popitem()
            if key == savepoint:
                return value
        return 0

    def __debug(self):
        """
        returns whether we should debug the current connection
        """
        return self.use_debug_cursor or (self.use_debug_cursor is None and settings.DEBUG)

    def __log(self, start, end, sql):
        """
        logs a sql query and the time it took to execute
        """
        duration = end - start
        self.queries.append({
            'time': "%.3f" % duration,
            'sql': sql
        })

    def __run_logged(self, operation, label):
        """
        Calls operation() and, when debugging, logs how long it took under the given label

        The debug flag is evaluated once up front so that the start time is
        always available when the log entry is written.  The entry is written
        even if the operation raises.
        """
        if not self.__debug():
            return operation()
        start = time.time()
        try:
            return operation()
        finally:
            self.__log(start, time.time(), label)

    def __close_connection(self):
        """
        Closes and forgets the current connection, if there is one
        """
        if self.connection is not None:
            self.connection.close()
            self.connection = None

    def __open_connection(self, kwargs, is_master):
        """
        Opens a new connection with the given options and records whether it targets the master
        """
        self.connection = Database.connect(**kwargs)
        connection_created.send(sender=self.__class__, connection=self)
        self.__is_master = is_master

    def __use_master_connection(self):
        """
        Helper function to set the connection to the master db connection

        Raises LookupError if MASTER_DATABASE_HOST is missing from the
        connection settings.
        """
        # Only switch the connection if it is not the master db or we have no connection
        if self.__is_master and self._valid_connection():
            return

        # Close the old connection if it exists (i.e. we have a valid connection, but it's a slave connection)
        self.__close_connection()

        # Get the base db options common to master and slave
        kwargs = self.__get_base_db_options()

        # Check for misconfigured settings
        if 'MASTER_DATABASE_HOST' not in self.settings_dict:
            raise LookupError("MASTER_DATABASE_HOST is not defined")

        master_db_host = self.settings_dict['MASTER_DATABASE_HOST']
        if master_db_host.startswith('/'):
            # A leading slash denotes a unix socket path rather than a host name
            kwargs['unix_socket'] = master_db_host
        elif master_db_host:
            kwargs['host'] = master_db_host

        self.__open_connection(kwargs, is_master=True)

    def __use_slave_connection(self):
        """
        Helper function to set the connection to the slave db connection

        Falls back to the master connection when no slave hosts are
        configured under SLAVE_DATABASE_HOSTS.
        """
        # If we have no slaves, just use the master
        slave_db_hosts = self.settings_dict.get('SLAVE_DATABASE_HOSTS', None)
        if not slave_db_hosts:
            return self.__use_master_connection()

        # Only switch the connection if it is not the slave db or we have no connection
        if not self.__is_master and self._valid_connection():
            return

        # Close the old connection if it exists (i.e. we have a valid connection, but it's a master connection)
        self.__close_connection()

        kwargs = self.__get_base_db_options()

        # Pick a slave database
        slave_database_host = random.choice(slave_db_hosts)
        if slave_database_host.startswith('/'):
            # A leading slash denotes a unix socket path rather than a host name
            kwargs['unix_socket'] = slave_database_host
        else:
            kwargs['host'] = slave_database_host

        self.__open_connection(kwargs, is_master=False)

    def __get_base_db_options(self):
        """
        Get the base db options common to master and slave

        Returns a dict of keyword arguments for Database.connect(); the host
        (or unix socket) is left for the caller to fill in.
        """
        kwargs = {
            'conv': django_conversions,
            'charset': 'utf8mb4',
            'use_unicode': False,
            }
        # Empty settings are left out so the driver's own defaults apply
        if self.settings_dict['USER']:
            kwargs['user'] = self.settings_dict['USER']
        if self.settings_dict['NAME']:
            kwargs['db'] = self.settings_dict['NAME']
        if self.settings_dict['PASSWORD']:
            kwargs['passwd'] = self.settings_dict['PASSWORD']
        if self.settings_dict['PORT']:
            # Settings commonly hold the port as a string; convert it to an
            # integer as Django's own MySQL backend does
            kwargs['port'] = int(self.settings_dict['PORT'])
        if self.settings_dict['OPTIONS']:
            # User supplied options take precedence over everything above
            kwargs.update(self.settings_dict['OPTIONS'])

        return kwargs

