/*
 * sync.go
 *
 * Synchronize data between SQL capable databases.
 *
 * Author:
 *   Mark Taylor (mtaylor@taylor-hq.com)
 *
 * Version:
 *   July 2014
 *
 * See presentation:
 *   http://www.taylor-hq.com/~mtaylor/presentations/dbsync/DBSync.html
 *
 * Color-code this script:
 *   pygmentize -f html -o ~/public_html/presentations/dbsync/db_sync-lib.go.html -O full sync.go
 *
 * NOTES:
 *   Using goroutines (gophers) to parallelize loading data into Postgres:
 *     http://www.acloudtree.com/how-to-shove-data-into-postgres-using-goroutinesgophers-and-golang/
 *
 * TODO:
 *   - Query parameters: "?" versus "$n" placeholders (needed to get the MySQL database/sql driver to work).
 *   - Use channels to communicate between a SELECT goroutine to several INSERT/UPDATE goroutines.
 *
 */

package sync


import (
    "fmt"
    "strings"
    "errors"
    "database/sql"
)


// TableSyncParams describes how a single table is synchronized.
//
// Table_Name is used for both the source and the destination table.
// Unique_ID is the column an UPDATE matches destination rows on.
// Timestamp_Column, when non-empty, limits the copy to source rows whose value
// in that column exceeds the destination's current maximum.
type TableSyncParams struct {
    // Identifiers spliced verbatim into the generated SQL.
    Table_Name, Unique_ID, Timestamp_Column string
    // When true, no UPDATE is attempted; every row is INSERTed.
    Insert_Only bool
}

// TableSyncResults reports the outcome of one Sync call.
type TableSyncResults struct {
    // Failed is true when any step of the sync failed.
    Failed bool
    // Rows written through the prepared INSERT / UPDATE statements. They are
    // counted as each statement succeeds, so they may include rows whose
    // transaction was later rolled back.
    Inserts, Updates int
}

// tableSyncSQLStuff holds the SQL text and prepared statements Sync builds
// for one table, plus the destination transaction they run in.
type tableSyncSQLStuff struct {
    columns                          []string // source result columns, in SELECT * order
    insert_sql, update_sql           string
    transaction                      *sql.Tx
    insert_prepared, update_prepared *sql.Stmt
}

// tableSyncState bundles everything one Sync call threads through its
// per-row helpers.
type tableSyncState struct {
    params             TableSyncParams
    results            TableSyncResults
    sql_stuff          tableSyncSQLStuff
    verbosity          int // >0 progress/SQL, >1 also each row, >2 also statement args
    db_source, db_dest *sql.DB
}


// update_row runs the prepared UPDATE for one row inside the sync transaction.
//
// upd_items holds the values of every non-unique-ID column followed by the
// unique ID, matching the placeholders of state.sql_stuff.update_sql.
//
// Returns:
//   - (true, nil) when exactly one destination row was updated; the row is
//     counted in state.results.Updates.
//   - (false, nil) when no row matched (the caller then INSERTs), and always
//     in Insert_Only mode, where nothing is executed.
//   - (false, err) when Exec or RowsAffected fails (logged), or when more than
//     one row was modified.
func update_row(state *tableSyncState, upd_items []interface{}) (success bool, err error) {
    if state.params.Insert_Only {
        return false, nil
    }

    res, err := state.sql_stuff.update_prepared.Exec(upd_items...)
    if err != nil {
        fmt.Printf("ERROR: Table %s: UPDATE Exec() failure: %s\n", state.params.Table_Name, err)
        return false, err
    }
    // n is not meaningful when RowsAffected fails; report that failure the
    // same way as an Exec failure instead of interpreting the count.
    n, err := res.RowsAffected()
    if err != nil {
        fmt.Printf("ERROR: Table %s: UPDATE RowsAffected() failure: %s\n", state.params.Table_Name, err)
        return false, err
    }
    switch n {
    case 1:
        state.results.Updates++
        return true, nil
    case 0:
        // No destination row yet; the caller falls back to INSERT.
        return false, nil
    default:
        // Unique_ID matched several rows: refuse rather than silently
        // overwrite them all.
        return false, errors.New(fmt.Sprintf("ERROR: Table %s: UPDATE modified %d rows!", state.params.Table_Name, n))
    }
}


// insert_row runs the prepared INSERT for one row inside the sync transaction.
//
// ins_items holds every column value in state.sql_stuff.columns order,
// matching the $1..$n placeholders of state.sql_stuff.insert_sql.
//
// It returns (true, nil) and counts the row in state.results.Inserts when
// exactly one row was inserted. Exec and RowsAffected failures are logged and
// returned unchanged. Any other affected-row count is reported as an error.
func insert_row(state *tableSyncState, ins_items []interface{}) (success bool, err error) {
    res, err := state.sql_stuff.insert_prepared.Exec(ins_items...)
    if err != nil {
        fmt.Printf("ERROR: Table %s: INSERT Exec() failure: %s\n", state.params.Table_Name, err)
        return false, err
    }
    // Check the RowsAffected error first: n is not meaningful when it fails,
    // and the driver's error must not be replaced by a synthetic
    // "modified 0 rows" message.
    n, err := res.RowsAffected()
    if err != nil {
        fmt.Printf("ERROR: Table %s: INSERT RowsAffected() failure: %s\n", state.params.Table_Name, err)
        return false, err
    }
    if n != 1 {
        return false, errors.New(fmt.Sprintf("ERROR: Table %s: INSERT modified %d rows!", state.params.Table_Name, n))
    }
    state.results.Inserts++
    return true, nil
}


// sync_row copies one source row into the destination table.
//
// row maps column name to the value scanned from the source. Argument lists
// are built in state.sql_stuff.columns order so they line up with the $n
// placeholders of the prepared statements. The row is first offered to
// update_row (a no-op in Insert_Only mode). It is INSERTed only when nothing
// was updated and no error occurred. The first error encountered is returned.
func sync_row(state *tableSyncState, row map[string]interface{}) (err error) {
    p := &state.params
    if state.verbosity > 1 {
        fmt.Printf("%q\n", row)
    }

    columns := state.sql_stuff.columns
    ins := make([]interface{}, 0, len(columns))
    upd := make([]interface{}, 0, len(columns))
    for _, col := range columns {
        val := row[col]
        ins = append(ins, val)
        if p.Insert_Only || col == p.Unique_ID {
            continue
        }
        upd = append(upd, val)
    }
    if !p.Insert_Only {
        // The unique ID is bound last: it feeds the UPDATE's WHERE clause.
        upd = append(upd, row[p.Unique_ID])
        if state.verbosity > 2 {
            fmt.Printf("UPDATE: %s\n", upd)
        }
    }
    if state.verbosity > 2 {
        fmt.Printf("INSERT: %s\n", ins)
    }

    updated, err := update_row(state, upd)
    if updated || err != nil {
        return err
    }
    _, err = insert_row(state, ins)
    return err
}


// ColumnMax returns MAX(table_column) over table_name in db.
//
// If either name is empty it returns (nil, nil) without touching db. mx is the
// value the driver produces for the aggregate; an empty table yields SQL NULL,
// which arrives as nil. Query/scan failures are printed and returned.
//
// NOTE: identifiers cannot be bound as query parameters, so both names are
// spliced into the SQL text and must come from trusted configuration.
func ColumnMax(table_name, table_column string, db *sql.DB) (mx interface{}, err error) {
    if table_name == "" || table_column == "" {
        return nil, nil
    }

    stmt := fmt.Sprintf("SELECT MAX(%s) FROM %s", table_column, table_name)
    if err = db.QueryRow(stmt).Scan(&mx); err != nil {
        fmt.Printf("ERROR: Table %s: QUERY-MAX (%s): %s\n", table_name, stmt, err)
    }
    return mx, err
}


// Sync copies rows of params.Table_Name from db_src into db_dest inside a
// single destination transaction.
//
// If params.Timestamp_Column is set, only source rows whose value in that
// column is greater than its current MAX in the destination are selected.
// Otherwise the whole source table is read. Each row is first UPDATEd by
// params.Unique_ID (unless params.Insert_Only) and INSERTed when no row was
// updated. The transaction is committed when every row succeeded and rolled
// back otherwise.
//
// Whenever a non-nil error is returned, the returned TableSyncResults has
// Failed set. Queries use "$n" placeholders (PostgreSQL style). verbosity > 0
// prints progress and the generated SQL; higher levels print more detail.
//
// NOTE: table and column names are spliced into SQL text unescaped, so they
// must come from trusted configuration.
func Sync(params TableSyncParams, db_src *sql.DB, db_dest *sql.DB, verbosity int) (TableSyncResults, error) {
    state := tableSyncState{db_source: db_src, db_dest: db_dest, params: params, verbosity: verbosity}

    // Newest timestamp already present in the destination (nil = copy all):
    timestamp_max, err := ColumnMax(state.params.Table_Name, state.params.Timestamp_Column, state.db_dest)
    if err != nil {
        // Report the failure in the results, as every other error path does.
        state.results.Failed = true
        return state.results, err
    }
    if state.verbosity > 0 {
        fmt.Printf("Table %s: MAX(src:%s): %q\n", state.params.Table_Name, state.params.Timestamp_Column, timestamp_max)
    }

    // Create SQL to query from source database:
    var rows *sql.Rows
    query := fmt.Sprintf("SELECT * FROM %s", state.params.Table_Name)
    if timestamp_max != nil {
        query = fmt.Sprintf("%s WHERE %s > $1", query, state.params.Timestamp_Column)
        rows, err = state.db_source.Query(query, timestamp_max)
    } else {
        rows, err = state.db_source.Query(query)
    }
    if state.verbosity > 0 {
        fmt.Printf("Table %s: SOURCE-QUERY: %s\n", state.params.Table_Name, query)
    }
    if err != nil {
        fmt.Printf("ERROR: Table %s: Failed source QUERY (%s): %s\n", state.params.Table_Name, query, err)
        state.results.Failed = true
        return state.results, err
    }
    // Release the source connection on every early return below. Rows.Close
    // is idempotent, so the explicit Close after the row loop stays safe.
    defer rows.Close()

    // Determine column names in query results:
    cols, err := rows.Columns()
    if err != nil {
        fmt.Printf("ERROR: Table %s: determining source query columns: %s\n", state.params.Table_Name, err)
        state.results.Failed = true
        return state.results, err
    }
    state.sql_stuff.columns = cols
    if state.verbosity > 0 {
        fmt.Printf("Columns(%s): %v\n", state.params.Table_Name, strings.Join(cols, ", "))
    }
    // rows.Scan() fills vals[] through the pointers in vals_p[]:
    vals := make([]interface{}, len(cols))
    vals_p := make([]interface{}, len(cols))
    for i := range vals {
        vals_p[i] = &vals[i]
    }

    if state.verbosity > 0 {
        fmt.Println("Table:", state.params.Table_Name)
        fmt.Printf("Columns: %v\n", state.sql_stuff.columns)
    }
    build_sync_sql(&state)
    if state.verbosity > 0 {
        fmt.Println("UPDATE: ", state.sql_stuff.update_sql)
        fmt.Println("INSERT: ", state.sql_stuff.insert_sql)
    }

    // Create a destination database transaction:
    tx, err := state.db_dest.Begin()
    if err != nil {
        fmt.Printf("Create TRANSACTION failed: %s\n", err)
        state.results.Failed = true
        return state.results, err
    }
    state.sql_stuff.transaction = tx

    if err = prepare_sync_stmts(&state); err != nil {
        // Nothing has been written yet. Roll back so the transaction does not
        // keep holding a destination connection; the Prepare error is the one
        // worth reporting, so a Rollback error is deliberately ignored.
        _ = tx.Rollback()
        state.results.Failed = true
        return state.results, err
    }

    // Dictionary of current result row, reused for every row:
    rowDict := make(map[string]interface{}, len(cols))
    for rows.Next() {
        if err = rows.Scan(vals_p...); err != nil {
            fmt.Printf("ERROR: Table %s: Row Scan() error: %s\n", state.params.Table_Name, err)
            state.results.Failed = true
            break
        }
        for i, col := range cols {
            rowDict[col] = vals[i]
        }
        // Stop at the first row that fails to sync:
        if err = sync_row(&state, rowDict); err != nil {
            state.results.Failed = true
            break
        }
    }
    // Next() also returns false when iteration stops on an error. Without this
    // check a truncated read would be committed as a complete sync.
    if !state.results.Failed {
        if err = rows.Err(); err != nil {
            fmt.Printf("ERROR: Table %s: Row iteration error: %s\n", state.params.Table_Name, err)
            state.results.Failed = true
        }
    }
    rows.Close()

    if err_final := finish_transaction(&state); err_final != nil {
        return state.results, err_final
    }
    return state.results, err
}

// build_sync_sql generates state.sql_stuff.insert_sql and, unless Insert_Only,
// state.sql_stuff.update_sql from state.sql_stuff.columns.
// The placeholder numbering matches the argument order sync_row builds:
//   - INSERT binds every column in order.
//   - UPDATE binds the non-unique-ID columns in order, then the unique ID.
func build_sync_sql(state *tableSyncState) {
    columns := state.sql_stuff.columns
    cols_k := make([]string, 0, len(columns))
    // Capacity len(columns), not len(columns)-1: a negative capacity panics
    // when the source result has no columns.
    cols_kv := make([]string, 0, len(columns))
    for i, k := range columns {
        cols_k = append(cols_k, fmt.Sprintf("$%d", i+1))
        if k != state.params.Unique_ID {
            cols_kv = append(cols_kv, fmt.Sprintf("%s=$%d", k, len(cols_kv)+1))
        }
    }
    if !state.params.Insert_Only {
        // The WHERE placeholder follows the SET placeholders. This equals
        // len(columns) when the unique ID appears exactly once.
        state.sql_stuff.update_sql = fmt.Sprintf("UPDATE %s SET %s WHERE %s = $%d",
            state.params.Table_Name,
            strings.Join(cols_kv, ","),
            state.params.Unique_ID, len(cols_kv)+1)
    }
    state.sql_stuff.insert_sql = fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
        state.params.Table_Name,
        strings.Join(columns, ","),
        strings.Join(cols_k, ","))
}

// prepare_sync_stmts prepares the UPDATE (unless Insert_Only) and INSERT
// statements inside state.sql_stuff.transaction and stores them in state.
// database/sql closes Tx-prepared statements when the Tx commits or rolls back.
func prepare_sync_stmts(state *tableSyncState) error {
    tx := state.sql_stuff.transaction
    if !state.params.Insert_Only {
        stmt, err := tx.Prepare(state.sql_stuff.update_sql)
        if err != nil {
            fmt.Printf("Prepare UPDATE failed (%s): %s\n", state.sql_stuff.update_sql, err)
            return err
        }
        state.sql_stuff.update_prepared = stmt
    }
    stmt, err := tx.Prepare(state.sql_stuff.insert_sql)
    if err != nil {
        fmt.Printf("Prepare INSERT failed (%s): %s\n", state.sql_stuff.insert_sql, err)
        return err
    }
    state.sql_stuff.insert_prepared = stmt
    return nil
}

// finish_transaction commits state's destination transaction when the sync has
// not failed, and rolls it back otherwise. A Commit/Rollback error marks the
// sync as failed and is returned.
func finish_transaction(state *tableSyncState) error {
    tx := state.sql_stuff.transaction
    op, finish := "Commit", tx.Commit
    if state.results.Failed {
        op, finish = "Rollback", tx.Rollback
    }
    if state.verbosity > 1 {
        fmt.Printf("DONE: Table %s: Calling %s()\n", state.params.Table_Name, op)
    }
    if err := finish(); err != nil {
        // Report the Commit/Rollback error itself (previously the unrelated
        // loop error was printed here).
        fmt.Printf("ERROR: Table %s: Final %s() error: %s\n", state.params.Table_Name, op, err)
        state.results.Failed = true
        return err
    }
    return nil
}


// Sync synchronizes the table described by p from db_src into db_dest. It is
// the method form of the package-level sync.Sync function.
func (p TableSyncParams) Sync(db_src, db_dest *sql.DB, verbosity int) (TableSyncResults, error) {
    return Sync(p, db_src, db_dest, verbosity)
}