#!/usr/bin/python
#
# db_sync.py
#
# Synchronize table data from a master database to a slave database
# (this implementation uses psycopg2, i.e. PostgreSQL).
#
# Author:
#   Mark Taylor (mtaylor@taylor-hq.com)
#
# Version:
#   July 2014
#
# See presentation:
#   http://www.taylor-hq.com/~mtaylor/presentations/dbsync/DBSync.html
#
# Color-code this script:
#   pygmentize -f html -o ~/public_html/presentations/dbsync/db_sync.py.html -O full db_sync.py
#


import os, random

# Use Psycopg2 because it has server-side and dict cursors:
import psycopg2, psycopg2.extras


def get_dict_cursor(sql, db_conn):
    """Create a named (server-side) psycopg2 cursor that yields dict-like rows.

    ``sql`` is accepted for interface compatibility only and is not used
    here; callers execute their statement on the returned cursor.

    The cursor name joins the process id with a random integer in
    [1, 999999], making a clash between open named cursors unlikely.
    """
    cursor_name = 'dbsync-%d-%d' % (os.getpid(), random.randint(1, 999999))
    return db_conn.cursor(cursor_factory=psycopg2.extras.DictCursor,
                          name=cursor_name)


class DBSyncTable(object):
    """Database Sync Table object.

    Copies rows of one table from a master connection to a slave connection.

    - Unless ``insert_only``, each master row is first tried as an UPDATE
      keyed on ``unique_id``. If that matches no slave row, the row is
      INSERTed.
    - When ``timestamp`` names a column, only master rows whose value
      exceeds the slave's current MAX of that column are fetched.

    NOTE: the table name, ``unique_id``, ``timestamp`` and the column names
    are interpolated directly into SQL text. They must come from trusted
    configuration, never from user input.
    """
    def __init__(self, table, unique_id, timestamp, insert_only):
        # Parameter variables:
        self.table       = table        # table name, used on master and slave
        self.unique_id   = unique_id    # key column for UPDATE's WHERE, or None
        self.timestamp   = timestamp    # column for incremental sync, or None
        self.insert_only = insert_only  # True: never attempt an UPDATE
        # Generated variables:
        self.slave_sql_insert = None
        self.slave_sql_update = None
        self.columns = []
        # Running totals; these accumulate across perform_sync() calls.
        self.inserts, self.updates = 0, 0

    def _get_slave_max(self, db_slave):
        """Return MAX(<timestamp column>) of this table on the slave.

        Returns None when no timestamp column is configured. It also returns
        None when the slave table is empty, because MAX over zero rows is NULL.
        """
        if self.timestamp is None:
            return None
        sql = 'SELECT MAX(%s) AS %s FROM %s' % (self.timestamp, self.timestamp, self.table)
        cur = get_dict_cursor(sql, db_slave)
        try:
            cur.execute(sql, {})
            return cur.fetchone()[self.timestamp]
        finally:
            # Close even when the query fails, so the named cursor is not leaked.
            cur.close()

    def _gen_master_sql(self, timestamp_max):
        """Return SQL to use against the master database to retrieve the data
        to sync.

        Selects every row, or only rows newer than ``timestamp_max`` when it
        is not None. The bound value is supplied as the ``timestamp_max``
        query parameter.
        """
        sql = 'SELECT * FROM %s' % (self.table)
        if timestamp_max is not None:
            sql += ' WHERE %s > %%(timestamp_max)s' % (self.timestamp)
        return sql

    def _gen_insert_sql(self):
        """Return SQL to use against the slave database to INSERT a data row.

        Uses one named placeholder per column, e.g. ``%(col)s``, so a row
        dict can be passed directly as the parameters.
        """
        sql = 'INSERT INTO %s (%s) VALUES (' % (self.table, ','.join(self.columns))
        sql += ','.join(['%%(%s)s' % c for c in self.columns])
        sql += ')'
        return sql

    def _gen_update_sql(self):
        """Return SQL to use against the slave database to UPDATE a data row.

        Every column except ``unique_id`` is SET; ``unique_id`` selects the row.

        Raises:
            Exception: if ``unique_id`` is not set, or if there is no column
                besides ``unique_id`` to SET (that would be invalid SQL).
        """
        if self.unique_id is None:
            raise Exception('For non-insert_only (UPDATE), must specify unique_id field')
        assignments = ['%s=%%(%s)s' % (c, c) for c in self.columns if c != self.unique_id]
        if not assignments:
            raise Exception('UPDATE on %s has no columns to SET besides %s'
                            % (self.table, self.unique_id))
        sql = 'UPDATE %s SET ' % (self.table)
        sql += ','.join(assignments)
        sql += ' WHERE %s = %%(%s)s' % (self.unique_id, self.unique_id)
        return sql

    def _extract_columns(self, row_dict, verbosity=0):
        """Populate the list of columns needed for syncing.

        Also generates, once, the object's slave INSERT SQL and, unless
        ``insert_only``, its UPDATE SQL. Column names are printed when
        ``verbosity > 0``. Newly generated SQL is printed unless ``verbosity``
        is negative.
        """
        # Materialize as a list. On Python 3, keys() may be a view or a
        # one-shot iterator, and the column list is iterated several times.
        self.columns = list(row_dict.keys())
        if verbosity > 0:
            print(self.columns)
        if self.slave_sql_insert is None:
            self.slave_sql_insert = self._gen_insert_sql()
            if verbosity >= 0:
                print(self.slave_sql_insert)
        if self.slave_sql_update is None and not self.insert_only:
            self.slave_sql_update = self._gen_update_sql()
            if verbosity >= 0:
                print(self.slave_sql_update)

    def _sync_row(self, master_row, cur_slave, verbosity=0):
        """Apply one master row to the slave.

        Unless ``insert_only``, tries an UPDATE first. If it matched no row,
        INSERTs the row instead. Raises if an UPDATE matched more than one
        row; the caller is responsible for rolling back.
        """
        if not self.insert_only:
            if verbosity > 0:
                print(cur_slave.mogrify(self.slave_sql_update, master_row))
            cur_slave.execute(self.slave_sql_update, master_row)
            if cur_slave.rowcount == 1:
                self.updates += 1
                return
            if cur_slave.rowcount > 1:
                raise Exception('UPDATE statement on %s modified more than one row' % (self.table))
        # UPDATE was skipped or modified no rows, so INSERT:
        if verbosity > 0:
            print(cur_slave.mogrify(self.slave_sql_insert, master_row))
        cur_slave.execute(self.slave_sql_insert, master_row)
        self.inserts += 1

    def _sync_rows(self, cur_master, cur_slave, verbosity=0):
        """Apply every row of the already-executed master cursor to the slave."""
        columns_known = False
        for master_row in cur_master:
            # On the first row, derive the column list and the slave SQL:
            if not columns_known:
                self._extract_columns(master_row, verbosity)
                columns_known = True
            self._sync_row(master_row, cur_slave, verbosity)

    def perform_sync(self, db_master, db_slave, verbosity=0):
        """Perform database synchronization between the specified master
        database connection and the specified slave database connection.

        Slave changes are committed once every master row has been applied.
        If anything fails while applying rows, the slave is rolled back and
        the original exception is re-raised with its type and traceback
        intact. This includes an UPDATE that matches more than one row.
        Cursors are closed on every path.

        Returns:
            [inserts, updates]: this object's running totals, which
            accumulate across calls.
        """
        # Get (if needed) the slave's latest timestamp:
        slave_timestamp_max = self._get_slave_max(db_slave)
        # Generate and execute the SQL to retrieve data from the master:
        master_sql = self._gen_master_sql(slave_timestamp_max)
        cur_master = get_dict_cursor(master_sql, db_master)
        cur_slave = None
        try:
            cur_master.execute(master_sql, {'timestamp_max': slave_timestamp_max})
            cur_slave = db_slave.cursor()
            try:
                self._sync_rows(cur_master, cur_slave, verbosity)
            except Exception:
                # Undo this run's partial slave changes, then propagate the
                # original error unchanged.
                db_slave.rollback()
                raise
            db_slave.commit()
        finally:
            cur_master.close()
            if cur_slave is not None:
                cur_slave.close()
        return [self.inserts, self.updates]


if __name__ == '__main__':
    # Sample tables; each dict supplies DBSyncTable's constructor keywords
    # (table, unique_id, timestamp, insert_only):
    sync_tables = [{'table':'gauges',
                        'unique_id':'id', 'timestamp':None,   'insert_only':False},
                   {'table':'gauges_daily',
                        'unique_id':None, 'timestamp':'date', 'insert_only':True },
                  ]

    # Sample databases:
    db_master_name = 'tva'
    db_master_dsn = 'dbname=%s user=postgres' % (db_master_name)
    db_slave_name = 'tva2'
    db_slave_dsn = 'dbname=%s user=postgres' % (db_slave_name)

    # Connect to databases. psycopg2.connect() raises on failure instead of
    # returning a false value, so catch its error to report which database
    # could not be reached.
    try:
        db_master_conn = psycopg2.connect(db_master_dsn)
    except psycopg2.Error as e:
        raise Exception('Cannot connect to source database "%s": %s' % (db_master_name, e))
    try:
        try:
            db_slave_conn = psycopg2.connect(db_slave_dsn)
        except psycopg2.Error as e:
            raise Exception('Cannot connect to destination database "%s": %s' % (db_slave_name, e))
        try:
            # Walk the list of tables to sync:
            for st in sync_tables:
                dbst = DBSyncTable(**st)
                ins, upd = dbst.perform_sync(db_master_conn, db_slave_conn)
                print('%s: INSERTS: %d;  UPDATES: %d' % (st['table'], ins, upd))
        finally:
            db_slave_conn.close()
    finally:
        # Always release the master connection, even if a sync fails.
        db_master_conn.close()