/*
    Copyright (C) 2000 Steve Brown
    Copyright (C) 2000,2001,2002 Guillaume Morin, Alcve

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA


    $Id: lib.c,v 1.62 2002/09/06 09:50:09 gmorin Exp $
    $Source: /cvsroot/nss-mysql/nss-mysql/src/lib.c,v $
    $Date: 2002/09/06 09:50:09 $
    $Author: gmorin $
*/

/* needed for vasprintf ... */
#define _GNU_SOURCE 1

#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <syslog.h>
#include <unistd.h>

#ifdef HAVE_CONFIG_H
        #include "config.h"
#endif

#include "lib.h"

/* Helpers private to this file, defined further down. */
static void create_pthread_key(void);
static char * parse_host (char * host, char ** unix_socket, int * port);

/*EOP*/

/* Thread specific data key and the pthread_once() guard that makes sure
 * create_pthread_key() creates it exactly once. */
static pthread_key_t tsd_key;
static pthread_once_t key_once = PTHREAD_ONCE_INIT;

/* Creates tsd_key, the key for the thread specific data.
 * Intended to be run through pthread_once(); no destructor is registered
 * for the values stored under the key. */
void create_pthread_key(void) {
        pthread_key_t * const key = &tsd_key;

        pthread_key_create(key, NULL);
}

/* _nss_mysql_escape_string:
 * Returns a pointer to the escaped version of string, or NULL when the
 * needed memory could not be allocated.
 * The caller's buffer is used when it is large enough, because that is
 * faster than calling malloc() and then free().
 *
 * string: NUL-terminated string to be escaped
 * m: connection information (m->mysql is handed to
 *    mysql_real_escape_string when HAVE_MYSQL_REAL_ESCAPE_STRING is set)
 * buffer: caller supplied buffer, may be NULL
 * len: size of buffer in bytes
 * call_free: set to 1 if the returned pointer was malloc'ed and must be
 *            free()d by the caller, 0 if it points into buffer
 */

char * _nss_mysql_escape_string(const char * string, struct mysql_auth * m, 
                char * buffer, size_t len, int * call_free) {
        char * e_string;
        /* computed once; size_t (not unsigned) so that the size below is
         * not truncated on LP64 platforms, which would undersize the
         * destination buffer */
        size_t str_len = strlen(string);
        /* worst case: every byte escaped into two, plus the final NUL */
        size_t s_len = str_len * 2 + 1;

        if (buffer == NULL || s_len > len) {
                e_string = malloc(s_len);
                *call_free = 1;
        } else {
                e_string = buffer;
                *call_free = 0;
        }
        
        if (! e_string) {
                if (DEBUG)
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_escape_string "
                                        "not enough memory for escaping "
                                        "the buffer");
                return NULL;
        }

        if (DEBUG)
                _nss_mysql_log(LOG_ERR,"_nss_mysql_escape_string: "
                                "escaping \"%s\" (*call_free == %d)",
                                string,*call_free);
#ifndef HAVE_MYSQL_REAL_ESCAPE_STRING
        (void) m; /* only needed by mysql_real_escape_string */
        mysql_escape_string(e_string,string,str_len);
#else
        mysql_real_escape_string(m->mysql,e_string,string,str_len);
#endif
        return e_string;
}


/* _nss_mysql_send_query
 * Sends a query to the MySQL server and stores its result set.
 * auth->mutex (when not NULL) is unlocked before returning, whatever the
 * outcome.
 *
 * auth: connection information
 * query: SQL statement to execute
 * result: receives the MYSQL_RES produced by mysql_store_result()
 * errnop: set to 0 on success, ENOENT on failure
 *
 * Returns NSS_STATUS_SUCCESS, or NSS_STATUS_UNAVAIL when the query or the
 * retrieval of its result failed.
 */

NSS_STATUS _nss_mysql_send_query(struct mysql_auth * auth,char * query,
                MYSQL_RES ** result,int * errnop) {
        NSS_STATUS status = NSS_STATUS_SUCCESS;

        if (DEBUG) _nss_mysql_log(LOG_ERR,"_nss_mysql_send_query:"
                        "called. MYSQL * %p, mutex %p, SQL statement: %s",
                        auth->mysql,auth->mutex,query);

        if (mysql_query(auth->mysql, query) != 0) {
                _nss_mysql_log(LOG_ERR,"_nss_mysql_send_query: "
                                "mysql_query failed: %s",
                                mysql_error(auth->mysql));
                status = NSS_STATUS_UNAVAIL;
        } else {
                *result = mysql_store_result(auth->mysql);
                if (*result == NULL) {
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_send_query: "
                                        "mysql_store_result failed: %s",
                                        mysql_error(auth->mysql));
                        status = NSS_STATUS_UNAVAIL;
                }
        }

        /* single release point for the connection lock */
        if (auth->mutex)
                pthread_mutex_unlock(auth->mutex);

        *errnop = (status == NSS_STATUS_SUCCESS) ? 0 : ENOENT;
        return status;
}

/* _nss_mysql_set_fork_handler 
 *
 * Registers the pthread fork handlers, at most once per *isset flag.
 *
 * isset: if non-zero the fork handlers are already set; set to 1 once
 *        registration has been attempted (even if it failed)
 * mutex: mutex which protects isset
 * prepare,parent,child: see pthread_atfork manpage
 */

void _nss_mysql_set_fork_handler(int * isset,pthread_mutex_t * mutex,
                void (*prepare)(void),
                void (*parent)(void),
                void (*child)(void)) {
        int rc;

        if (DEBUG)
                _nss_mysql_log(LOG_ERR,"_nss_mysql_set_fork_handler: called");
      
        /* we check the value before the mutex because in most cases
         * *isset will be 0, so we just avoid the mutex lock overhead
         */
        if (*isset == 0) {
                pthread_mutex_lock(mutex);
                if (*isset == 0) {
                        if (DEBUG)
                                _nss_mysql_log(LOG_ERR,"_nss_mysql_set_fork_"
                                                "handler: setting the fork "
                                                "handler");
                        /* pthread_atfork() returns the error number; it is
                         * not required to set errno */
                        rc = pthread_atfork(prepare,parent,child);
                        if (rc != 0)
                                _nss_mysql_log(LOG_ERR,"_nss_mysql_set_fork_"
                                                "handler: pthread_atfork "
                                                "failed: %s",strerror(rc));
                        *isset = 1;
                }
                pthread_mutex_unlock(mutex);
        }
        if (DEBUG)
               _nss_mysql_log(LOG_ERR,"_nss_mysql_set_fork_handler: finished");
}

/* _nss_mysql_copy_to_buffer 
 * Copies string, including its NUL terminator, to *buffer and advances
 * *buffer just past the copy.
 *
 * buffer: in/out cursor into the destination buffer
 * buflen: remaining space in the buffer, decremented by the bytes used;
 *         when NULL no bounds check is done
 * string: NUL-terminated string to copy
 *
 * Returns a pointer to the copy inside the buffer, or NULL (with nothing
 * modified) when buflen is given and the remaining space is too small.
 */

char * _nss_mysql_copy_to_buffer(char ** buffer,size_t * buflen, 
                const char * string) {
        size_t len = strlen(string) + 1;
        char * ptr;

        if (DEBUG) {
                /* size_t has no portable C89 conversion: print it as
                 * unsigned long; %p expects a void pointer */
                if (buflen)
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_copy_to_buffer: "
                                        "called for %s to %p(%lu)",string,
                                        (void *) *buffer,
                                        (unsigned long) *buflen);
                else
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_copy_to_buffer: "
                                        "called for %s to %p",string,
                                        (void *) *buffer);
        }

        if (buflen && len > *buflen) {
                if (DEBUG)
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_copy_to_buffer: "
                                       "Not enough space in the buffer.");
                return NULL;
        }
        memcpy(*buffer,string,len);
        if (buflen) 
                *buflen -= len;
        ptr = *buffer;
        (*buffer) += len;
        return ptr;
}

/* _nss_mysql_isempty
 * Tells whether a string is NULL or made only of white space
 * (as classified by isspace()).
 * Returns:
 * 0, the string contains at least one non-space character
 * 1, the string is NULL, empty or blank
 */

int _nss_mysql_isempty(char * str) {
        const unsigned char * cursor;

        if (str == NULL)
                return 1;

        for (cursor = (const unsigned char *) str; *cursor != '\0'; cursor++) {
                if (!isspace(*cursor))
                        return 0;
        }
        return 1;
}

/* _nss_mysql_strtol
 * nss-MySQL strtol version
 * Converts a base-10 ascii string into a long.
 * str: string to convert
 * fallback: value returned when str is NULL, empty or not entirely numeric
 * error: set to 0 on success, to 1 when an error has occured
 *
 * On overflow *error is set to 1 but the value clamped by strtol()
 * (LONG_MAX or LONG_MIN) is returned, not fallback.
 */

long _nss_mysql_strtol(char * str, long fallback, int * error) {
        char * endptr;
        long toreturn;
        
        /* sanity checks */
        if (!str) {
                _nss_mysql_log(LOG_ERR,"_nss_mysql_strtol: string pointer "
                                "is NULL.");
                *error = 1;
                return fallback;
        }

        if (*str == '\0') {
                _nss_mysql_log(LOG_ERR,"_nss_mysql_strtol: string is empty.");
                *error = 1;
                return fallback;
        }
        
        /* strtol() only writes errno on failure: clear it first, otherwise
         * a stale ERANGE from an earlier call is taken for an overflow */
        errno = 0;
        toreturn = strtol(str,&endptr,10);
        
        if (endptr == str) {
                _nss_mysql_log(LOG_ERR,"_nss_mysql_strtol: can't convert %s",
                                str);
                *error = 1;
                return fallback;
        }

        if (*endptr != '\0') {
                _nss_mysql_log(LOG_ERR,"_nss_mysql_strtol: incomplete "
                                "conversion of %s to %ld. Falling back "
                                "to %ld.",str,toreturn,fallback);
                *error = 1;
                return fallback;
        }

        if (errno != ERANGE) {
                *error = 0;
                return toreturn;
        }

        _nss_mysql_log(LOG_ERR,"_nss_mysql_strtol: overflow when converting %s. "
                        "Fix your database.",str);
        *error = 1;
        return toreturn;
}

/* _nss_mysql_log
 * Writes a printf-style message to syslog with priority err.
 * The log is opened lazily (ident "nss-mysql", LOG_PID, facility
 * LOG_AUTH); when DEBUG is set it is closed again after every message.
 */
void _nss_mysql_log(int err, const char *format, ...) {
        static int log_is_open = 0;
        va_list ap;

        /* this is not thread safe, but it does not matter here */
        if (log_is_open == 0) {
                log_is_open++;
                openlog("nss-mysql", LOG_PID, LOG_AUTH);
        }

        va_start(ap, format);
        vsyslog(err, format, ap);
        va_end(ap);

        /* according to its manpage calling closelog is optional */
        if (DEBUG) {
                closelog();
                --log_is_open;
        }
}

/* Close a database connection
 * Argument: address of the MYSQL pointer to close.  The pointer is reset
 * to NULL afterwards; a NULL pointer is left alone (and only logged in
 * DEBUG mode).
 */

void _nss_mysql_db_close (MYSQL ** mysql_auth) {
        MYSQL * conn = *mysql_auth;

        if (conn != NULL) {
                mysql_close(conn);
                /* Trust no one: do not leave a dangling handle behind */
                *mysql_auth = NULL;
                return;
        }

        /* closed already */
        if (DEBUG) _nss_mysql_log(LOG_ERR,"_nss_mysql_db_close: called "
                        "with a NULL pointer");
}

/* MySQL access functions */

/* _nss_mysql_db_connect
 * opens a MySQL connection, first using opt->host and, if that fails or
 * opt->host is empty, using opt->backup_host.  For the backup attempt the
 * backup_dbuser / backup_dbpasswd / backup_database options replace the
 * primary ones when they are not empty (see get_param).
 *
 * m: receives the connection (m->mysql) and the current pid (m->pid).
 *    An already allocated m->mysql handle is reused instead of calling
 *    mysql_init().
 * opt: primary and backup connection options
 *
 * returns 1 on success, and 0 on failure
 */

/* get_param(param): opt->backup_<param> when it is not empty, otherwise
 * opt-><param>.  Needs a `struct connection_options * opt` in scope. */
#define get_param(param) \
        (! _nss_mysql_isempty(opt->backup_##param) ? \
                opt->backup_##param : opt->param)

int _nss_mysql_db_connect (struct mysql_auth *m,
                struct connection_options * opt) {
        MYSQL * tmp;
        int i;
        char * user,*passwd,*database,*unix_socket,*hostname,*host;
        /* int, not unsigned: parse_host() writes the port through an int * */
        int port;

        if (DEBUG)
                _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect: called");
        
        pthread_once(&key_once,create_pthread_key);
        
        if (m->mysql != NULL && DEBUG)
                _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect: called with a "
                                "non NULL MySQL connexion");

        if (! m->mysql) {
                if (DEBUG)
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect: calling "
                                "mysql_init(NULL)");
                tmp = mysql_init(NULL);
        
                if (tmp == NULL) {
                        if (DEBUG) 
                                _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect: "
                                                "not enough memory to allocate "
                                                "a mysql object");
                        m->mysql = NULL;
                        return 0;
                }
        } else
                tmp = m->mysql;
        

        if (DEBUG)
                _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect: mysql_init() "
                                "suceeded");
        
        /* i == 0: primary server, i == 1: backup server */
        for (i = 0 ; i < 2 ; ++i) {
                port = 3306;
                unix_socket = NULL;
                if (!i ) {
                        /* first time */
                        user = opt->dbuser;
                        passwd = opt->dbpasswd;
                        database = opt->database;
                        host = opt->host;
                } else {
                        user = get_param(dbuser);
                        passwd = get_param(dbpasswd);
                        database = get_param(database);
                        host = opt->backup_host;
                }
                if (_nss_mysql_isempty(host)) continue;

                hostname = parse_host(host,&unix_socket,&port);
                
                if (! (hostname || unix_socket)) {
                        /* an error has occured in parse_host: release the
                         * MySQL handle instead of leaking it */
                        mysql_close(tmp);
                        m->mysql = NULL;
                        m->pid = 0;
                        return 0;
                }
        
                /* NOTE(review): these debug messages write the database
                 * password to syslog */
                if (DEBUG && unix_socket) 
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect: " 
                                        "connection with user=%s,passwd=%s,"
                                        "database=%s, unix_socket=%s",
                                        user,passwd,database,
                                        unix_socket);
                if (DEBUG && ! unix_socket) 
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect: "
                                        "connection with host=%s,user=%s,"
                                        "passwd=%s,database=%s,port=%d",
                                        hostname,user,passwd,
                                        database,port);

                m->mysql = mysql_real_connect (tmp,
                                         hostname,
                                         user,
                                         passwd,
                                         database, (unsigned int) port,
                                         unix_socket, 0);

                free(hostname);
                if (m->mysql) break;

                /* connection has failed */
                _nss_mysql_log(LOG_INFO, "_nss_mysql_db_connect: connection "
                                "failed: %s", mysql_error(tmp));
        } 

        if (m->mysql == NULL) {
                mysql_close(tmp);
                m->pid = 0;
                return 0;
        }
        
        m->pid = getpid();
        return 1;
}

/* _nss_mysql_sqlprintf: is called like printf() and returns a pointer to a
 * malloc'ed string holding the formatted result, or NULL when memory could
 * not be allocated.  You have to free() it after use!
 */

#ifdef HAVE_VASPRINTF

char * _nss_mysql_sqlprintf(const char * format, ...) {
 
        va_list args;
        char * str = NULL;

        va_start(args, format);

        if (vasprintf(&str, format, args) == -1)
                /* in case of error, str value is undefined, 
                   but could be NULL */
                str = NULL; 

        if (DEBUG && str == NULL) 
                _nss_mysql_log(LOG_ERR,"_nss_mysql_sqlprintf: not enough "
                                "memory");
        va_end(args);

        return str;
}

#else /* ! HAVE_VASPRINTF */

#define BUFFER_SIZE 1024

char * _nss_mysql_sqlprintf(const char * format, ...) {
        va_list args;
        char * str;
        char * tmp;
        int len;
        int buffersize = BUFFER_SIZE;

        str = malloc(buffersize);
        if (!str)
                goto out_nomem;

        for (;;) {
                /* a va_list must not be reused once vsnprintf() has
                 * consumed it: start a fresh one for every attempt */
                va_start(args, format);
                len = vsnprintf(str, buffersize, format, args);
                va_end(args);

                if (len >= 0 && len < buffersize)
                        break;

                /* Output has been truncated.  A C99 vsnprintf() returns the
                 * length it needed, so allocate exactly that; older
                 * implementations only return -1, so grow geometrically.
                 * Refuse to go past INT_MAX, vsnprintf's size limit. */
                if (len >= 0 && len < INT_MAX) {
                        buffersize = len + 1;
                } else if (len < 0 && buffersize <= INT_MAX / 2) {
                        buffersize *= 2;
                } else {
                        free(str);
                        goto out_nomem;
                }

                tmp = realloc(str, buffersize);
                if (! tmp) {
                        free(str);
                        goto out_nomem;
                }
                str = tmp;
        }

        if (DEBUG) _nss_mysql_log(LOG_ERR, "_nss_mysql_sqlprintf(): buffersize="
                        "%d, len=%d", buffersize, len);
        return str;

out_nomem:
        if (DEBUG) _nss_mysql_log(LOG_ERR,"_nss_mysql_sqlprintf: not "
                        "enough memory");
        return NULL;
}

#endif /* ! HAVE_VASPRINTF */

/* _nss_mysql_check_connection
 * checks if a connection has been opened and if it is still alive
 * tries to open a connection if not
 */

int _nss_mysql_check_connection(struct mysql_auth * m, 
                struct connection_options * con) {
        int force_reconnect = 0;
        if (DEBUG)
                _nss_mysql_log(LOG_ERR,"_nss_mysql_check_connection called");

        /* Is the server still alive ? */
        pthread_mutex_lock(m->mutex);
        if (m->mysql != NULL) {
                int ping = 0;
                if (DEBUG)
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_check_connection: "
                                        "found a saved MYSQL * (%p)",
                                        m->mysql);

                if (m->pid == 0) {
                        force_reconnect = 1;
                        if (DEBUG)
                                _nss_mysql_log(LOG_ERR,"_nss_mysql_check_"
                                                "connection: fork() has been "
                                                "called (threaded app). "
                                                "Reconnecting");
                } else if (m->pid != getpid()) {
                        if (pthread_getspecific(tsd_key) == (void *) 1) {
                                force_reconnect = 1;
                                if (DEBUG)
                                        _nss_mysql_log(LOG_ERR,"_nss_mysql_"
                                                  "check_connection: the found "
                                                  "MYSQL * is used by another "
                                                  "process (%d). Reconnecting",
                                                  m->pid);
                        } else {
                                ping = 1;
                                if (DEBUG)
                                        _nss_mysql_log(LOG_ERR,"_nss_mysql_"
                                                  "check_connection: the found "
                                                  "MYSQL * has been saved by "
                                                  "another thread (%d). "
                                                  "Pinging",m->pid);
                        }
                } else
                        ping = 1;

                if (ping) {
                        if (DEBUG)
                                _nss_mysql_log(LOG_ERR,"_nss_mysql_check_"
                                                "connection: pinging %p",
                                                m->mysql);
                        my_thread_init();
                        if (mysql_ping(m->mysql)) {
                                if (DEBUG) 
                                        _nss_mysql_log(LOG_ERR,
                                                "_nss_mysql_check_connection: "
                                                "can't sustain connection : %s",
                                                mysql_error(m->mysql));
                                _nss_mysql_db_close(&m->mysql);
                                m->mysql = NULL; 
                        }
                }
        }

        /* DB connection */
        if (m->mysql == NULL || force_reconnect) {
                pthread_setspecific(tsd_key,(void *) 1);
                if (DEBUG)
                        _nss_mysql_log(LOG_ERR,"_nss_mysql_check_connection: "
                                        "Initiating a new connection");
                if  (! _nss_mysql_db_connect(m,con)) {
                        pthread_mutex_unlock(m->mutex);
                        return 0;
                }
        }

        if (DEBUG)
                _nss_mysql_log(LOG_ERR,"_nss_mysql_check_connection: "
                                "sucessfully exiting");
        return 1;
}

/* parse_host
 * Splits a host specification.  Accepted syntaxes:
 *   unix:/path/to/socket
 *   inet:host:port, host:port, inet:host, host
 *
 * host: specification to parse (not modified)
 * unix_socket: set to the text following "unix:"; untouched otherwise
 * port: set to the number following ':' when one is given; untouched
 *       otherwise
 *
 * Returns a malloc'ed copy of the host name, without any ":port" suffix,
 * that the caller must free().  Returns NULL for a UNIX socket and when
 * memory could not be allocated; a caller that presets *unix_socket to
 * NULL can tell both cases apart.
 */
static char * parse_host (char * host, char ** unix_socket, int * port) {
        char * hostname;
        char * colon;
        size_t host_len;

        if (strncmp(host, "unix:", 5) == 0) {
                /* we use an UNIX socket */
                *unix_socket = host + 5;
                return NULL;
        }

        if (strncmp(host, "inet:", 5) == 0)
                host += 5;

        colon = strchr(host, ':');
        if (colon == NULL || colon[1] == '\0') {
                /* no port */
                if (DEBUG) _nss_mysql_log(LOG_ERR,"_nss_mysql_db_connect"
                                ": using default port");
        }

        /* copy only what precedes ':' so that "host:" yields "host" */
        host_len = (colon != NULL) ? (size_t) (colon - host) : strlen(host);
        hostname = malloc(host_len + 1);
        if (hostname == NULL) {
                if (DEBUG) _nss_mysql_log(LOG_ERR,"_nss_mysql_"
                                "db_connect not enough memory "
                                "to parse hostname");
                return NULL;
        }
        memcpy(hostname, host, host_len);
        hostname[host_len] = '\0';

        if (colon != NULL && colon[1] != '\0')
                *port = (int) strtol(colon + 1, NULL, 10);

        return hostname;
}
