/*
 * Wrapper functions for calling lower-level mySQL routines.
 *
 * Copyright (c) 2009-2011 Centro Svizzero di Calcolo Scientifico (CSCS)
 * Licensed under the GPLv2.
 */
#include "basil_mysql.h"

/**
 * validate_stmt_column_count - Check the result-column count of a statement
 * @stmt:	 prepared statement to inspect
 * @query:	 query text (used only in error messages)
 * @expect_cols: number of result columns the caller expects
 * Return true if the statement's result-set metadata reports exactly
 * @expect_cols columns; false (after logging an error) if the metadata
 * is unavailable or the count differs.
 */
static bool validate_stmt_column_count(MYSQL_STMT *stmt, const char *query,
				       unsigned long expect_cols)
{
	MYSQL_RES	*meta;
	unsigned long	actual_cols;
	bool		matches;

	meta = mysql_stmt_result_metadata(stmt);
	if (meta == NULL) {
		/* No result-set metadata could be obtained for this statement */
		error("can not obtain statement meta information for \"%s\": %s",
		      query, mysql_stmt_error(stmt));
		return false;
	}

	actual_cols = mysql_num_fields(meta);
	matches = (actual_cols == expect_cols);
	if (!matches)
		error("expected %lu columns for \"%s\", but got %lu",
		      expect_cols, query, actual_cols);

	/* The metadata result set belongs to us on every path: release it */
	mysql_free_result(meta);
	return matches;
}

/**
 * prepare_stmt - Initialize and prepare a query statement.
 * @handle:	connected handle
 * @query:	query statement string to execute
 * @bind_parm:  values for unbound variables (parameters) in @query
 * @nparams:	length of @bind_parm; must equal the number of '?' markers
 *		in @query
 * @bind_col:	typed array to contain the column results
 *		==> non-NULL 'is_null'/'error' fields are taken to mean
 *		    that NULL values/errors are not acceptable
 *		NOTE(review): within this file only the 'error' flags are
 *		checked (by store_stmt_results()); 'is_null' is not enforced.
 * @ncols:	number of expected columns (length of @bind_col)
 *
 * The statement must yield result-set metadata with exactly @ncols
 * columns; otherwise preparation fails.
 * Return prepared statement handle on success, NULL on error (an empty
 * or NULL @query returns NULL without logging).
 */
MYSQL_STMT *prepare_stmt(MYSQL *handle, const char *query,
			 MYSQL_BIND bind_parm[], unsigned long nparams,
			 MYSQL_BIND bind_col[], unsigned long ncols)
{
	MYSQL_STMT	*stmt;
	unsigned long	param_count;

	if (query == NULL || *query == '\0')
		return NULL;

	/* Initialize statement (fails only if out of memory). */
	stmt = mysql_stmt_init(handle);
	if (stmt == NULL) {
		error("can not allocate handle for \"%s\"", query);
		return NULL;
	}

	if (mysql_stmt_prepare(stmt, query, strlen(query))) {
		error("can not prepare statement \"%s\": %s",
		      query, mysql_stmt_error(stmt));
		goto prepare_failed;
	}

	/*
	 * Verify that the caller supplied exactly one bind buffer per
	 * parameter marker; binding a shorter array would make the client
	 * library read past the end of @bind_parm.
	 */
	param_count = mysql_stmt_param_count(stmt);
	if (param_count != nparams) {
		error("expected %lu parameters for \"%s\" but got %lu",
		      nparams, query, param_count);
		goto prepare_failed;
	}

	if (!validate_stmt_column_count(stmt, query, ncols))
		goto prepare_failed;

	if (nparams && mysql_stmt_bind_param(stmt, bind_parm)) {
		error("can not bind parameter buffers for \"%s\": %s",
		      query, mysql_stmt_error(stmt));
		goto prepare_failed;
	}

	if (mysql_stmt_bind_result(stmt, bind_col)) {
		error("can not bind output buffers for \"%s\": %s", query,
		      mysql_stmt_error(stmt));
		goto prepare_failed;
	}

	return stmt;

prepare_failed:
	(void)mysql_stmt_close(stmt);
	return NULL;
}

/**
 * store_stmt_results - Buffer all results of an executed query on the client
 * @stmt:	executed statement handle
 * @query:	query text (used only in error messages)
 * @bind_col:	output buffers previously bound with prepare_stmt()
 * @ncols:	length of @bind_col
 *
 * Every row is fetched once so that truncated column values can be
 * detected: a column whose 'error' pointer is non-NULL must not be
 * truncated. Afterwards the row cursor is rewound to the first row.
 * Returns -1 on error, number_of_rows >= 0 if ok.
 */
static int store_stmt_results(MYSQL_STMT *stmt, const char *query,
			      MYSQL_BIND bind_col[], unsigned long ncols)
{
	my_ulonglong	nrows;
	unsigned long	i;
	int		rc;

	if (stmt == NULL || ncols == 0)
		return -1;

	if (mysql_stmt_store_result(stmt)) {
		error("can not store query result for \"%s\": %s",
		      query, mysql_stmt_error(stmt));
		return -1;
	}

	nrows = mysql_stmt_affected_rows(stmt);
	if (nrows == (my_ulonglong)-1) {
		error("query \"%s\" returned an error: %s",
		      query, mysql_stmt_error(stmt));
		return -1;
	}

	/*
	 * mysql_stmt_fetch() reports truncation by returning
	 * MYSQL_DATA_TRUNCATED rather than 0, so that status must keep the
	 * loop going and be checked; stopping on it would let truncated
	 * values pass silently.
	 */
	while ((rc = mysql_stmt_fetch(stmt)) == 0 ||
	       rc == MYSQL_DATA_TRUNCATED) {
		for (i = 0; i < ncols; i++) {
			if (bind_col[i].error && *bind_col[i].error) {
				error("result value in column %lu truncated: %s",
				      i, mysql_stmt_error(stmt));
				return -1;
			}
		}
	}

	/* Anything other than "no more rows" is a genuine fetch failure */
	if (rc != MYSQL_NO_DATA) {
		error("can not fetch results of \"%s\": %s",
		      query, mysql_stmt_error(stmt));
		return -1;
	}

	/* Seek back to begin of data set */
	mysql_stmt_data_seek(stmt, 0);

	return nrows;
}

/**
 * run_stmt - Execute a prepared statement, optionally buffering its rows
 * @stmt:	prepared statement
 * @query:	query text (used only in error messages)
 * @bind_col:	as in prepare_stmt()
 * @ncols:	as in prepare_stmt()
 * @do_store:	if true, also store and validate the results on the client
 *		via store_stmt_results()
 * Return true on success, false if execution (or, with @do_store, result
 * storage) failed.
 */
bool run_stmt(MYSQL_STMT *stmt, const char *query,
	      MYSQL_BIND bind_col[], unsigned long ncols, bool do_store)
{
	bool ok = (mysql_stmt_execute(stmt) == 0);

	if (!ok)
		error("failed to execute \"%s\": %s",
		      query, mysql_stmt_error(stmt));
	else if (do_store)
		ok = (store_stmt_results(stmt, query, bind_col, ncols) >= 0);

	return ok;
}

/**
 * exec_stmt - Execute, store and validate a prepared statement
 * @stmt:	prepared statement, as in run_stmt()
 * @query:	query text, as in run_stmt()
 * @bind_col:	output buffers, as in run_stmt()
 * @ncols:	length of @bind_col, as in run_stmt()
 * Returns -1 on error, number_of_rows >= 0 if ok.
 */
int exec_stmt(MYSQL_STMT *stmt, const char *query,
	      MYSQL_BIND bind_col[], unsigned long ncols)
{
	/* Execute without storing, so the row count can be returned here */
	return run_stmt(stmt, query, bind_col, ncols, false)
		? store_stmt_results(stmt, query, bind_col, ncols)
		: -1;
}

/**
 * exec_query - Prepare, execute, and validate parameter-less query statement
 * @handle:	connected handle
 * @query:	query statement to execute
 * @columns:	typed array for the column results
 * @ncols:	length of @columns
 * Return ready-to-fetch statement handle (rewound to the first row) on
 * success, NULL on error. On failure after preparation, the statement
 * handle is closed here; on success the caller must close it.
 */
MYSQL_STMT *exec_query(MYSQL *handle, const char *query,
		       MYSQL_BIND columns[], unsigned long ncols)
{
	MYSQL_STMT *stmt;

	stmt = prepare_stmt(handle, query, NULL, 0, columns, ncols);
	if (stmt == NULL)
		return NULL;

	if (exec_stmt(stmt, query, columns, ncols) < 0) {
		/* Do not leak the handle when execution/storage fails */
		mysql_stmt_close(stmt);
		stmt = NULL;
	}
	return stmt;
}

/**
 * exec_boolean_query - Prepare and execute parameter-less Boolean query
 * @handle:	connected handle
 * @query:	statement (question) to execute; must yield exactly one
 *		column, read as a TINY integer
 * Returns -1 on error, 1 if true, 0 if false.
 * A query that yields no row, or a NULL value, is treated as an error.
 */
int exec_boolean_query(MYSQL *handle, const char *query)
{
	MYSQL_BIND	result[1];
	signed char	answer = 0;
	my_bool		is_null = 0;
	my_bool		is_error = 0;
	MYSQL_STMT	*stmt;
	my_ulonglong	nrows;

	memset(result, 0, sizeof(result));
	result[0].buffer_type	= MYSQL_TYPE_TINY;
	result[0].buffer	= (char *)&answer;
	result[0].is_null	= &is_null;
	result[0].error		= &is_error;

	stmt = exec_query(handle, query, result, 1);
	if (stmt == NULL)
		return -1;

	/*
	 * exec_query() has already fetched the rows once while validating
	 * them, so the bound buffers hold the last row fetched. Without any
	 * row those buffers were never written, hence the row-count check.
	 */
	nrows = mysql_stmt_num_rows(stmt);
	mysql_stmt_close(stmt);

	if (nrows == 0) {
		error("boolean query \"%s\" returned no rows", query);
		return -1;
	}
	if (is_null) {
		error("boolean query \"%s\" returned NULL", query);
		return -1;
	}

	/* Normalize: any non-zero value (even -1) means true, never "error" */
	return answer != 0;
}
