(*
  FreeTDS interface for OCamlDBI.
  Copyright (C) 2004 Kenneth Knowles
  
  This library is free software; you can redistribute it and/or modify
  it under the terms of the GNU Lesser General Public License as published by
  the Free Software Foundation; either version 2 of the License, or
  (at your option) any later version.
  
  This library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU Lesser General Public License for more details.
  
  You should have received a copy of the GNU Lesser General Public License
  along with this library; if not, write to the Free Software
  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*)

open Printf


let dbi_raise_messages conn =
    (* Collect the pending client-side and then server-side messages of
       [conn] (severities are dropped), one per line, and raise them as a
       single [Dbi.SQL_error].  Never returns normally. *)
    let render fmt msgs =
        List.map (fun (_severity, text) -> sprintf fmt text) msgs
    in
    let client = render "Client: %s" (Ct.get_messages conn [`Client]) in
    let server = render "Server: %s" (Ct.get_messages conn [`Server]) in
    let report = String.concat "\n" (client @ server) in
    raise (Dbi.SQL_error report)
        
(* NOTE(review): not raised or caught anywhere in the code visible here;
   possibly vestigial — confirm against the rest of the library before
   removing, since it may be part of the public interface. *)
exception Got_rowset

(* One fully materialised result set.  TDS allows only one active query
   per connection, so every row is fetched eagerly instead of using a more
   complex deallocation scheme that keeps the cursor open. *)
type recordset = {
    rs_names : string array;          (* column names, in result order *)
    rs_rows : Dbi.sql_t array array;  (* rows; each holds one value per column *)
}

(* Convert a value decoded from a Ct column buffer into the generic
   [Dbi.sql_t] representation. *)
let dbi_of_ct = function
    | `Null -> `Null
    | `Bit flag -> `Bool flag
    (* NOTE(review): [Int32.to_int] can truncate on 32-bit OCaml builds. *)
    | `Int n32 -> `Int (Int32.to_int n32)
    | `Tinyint n | `Smallint n -> `Int n
    | `Float x -> `Float x
    | `Decimal digits -> `Decimal (Dbi.Decimal.of_string digits)
    | `String str | `Text str -> `String str
    | `Binary bytes -> `Binary bytes
    (* TODO: parse the date/time with some library; for now it is passed
       through as the server's textual form. *)
    | `Datetime stamp -> `String stamp
          

(* Fetch every remaining row of the current result set of [cmd], reading
   each bound column of [cols] through [dbi_of_ct].  [li] accumulates rows
   in reverse; the result is the rows in fetch order. *)
let rec get_rows cmd cols li =
    (* [Ct.fetch] signals the end of the result set by raising
       [Ct.End_data].  Only the fetch sits inside the handler, so the
       recursive call below remains a tail call. *)
    let exhausted =
        try ignore (Ct.fetch cmd); false
        with Ct.End_data -> true
    in
    if exhausted then
        Array.of_list (List.rev li)
    else begin
        let read_column col =
            dbi_of_ct (Ct.buffer_contents col.Ct.col_buffer)
        in
        get_rows cmd cols (Array.map read_column cols :: li)
    end
;;

(* Advance [cmd] to its next row-producing result and materialise it as a
   [recordset].  Non-row results (status, success, done) are skipped.
   Raises [Ct.Cmd_fail] on a failed command, [Failure] on parameter
   results, and lets [Ct.End_results] from [Ct.results] escape once no
   results remain. *)
let rec get_one_recordset cmd =
    match Ct.results cmd with
        | `Row ->
              let ncols = Ct.res_info cmd `Numdata in
              (* Columns are kept in an array; [Ct.bind] is called with
                 1-based column numbers (hence [i + 1]). *)
              let cols = Array.init ncols (fun i -> Ct.bind cmd (i + 1)) in
              let rows = get_rows cmd cols [] in
              { rs_names = Array.map (fun c -> c.Ct.col_name) cols;
                rs_rows = rows }
        | `Cmd_fail -> raise Ct.Cmd_fail
        | `Param -> failwith "Don't handle param results"
        | `Status | `Cmd_succeed | `Cmd_done -> get_one_recordset cmd

(* A prepared TDS statement.  [execute] sends [cmd] and eagerly reads every
   result set it produces; [fetch1] and [names] then operate on the last
   of those result sets. *)
class statement
    dbh
    cmd
    =
object(self)
    inherit Dbi.statement (dbh :> Dbi.connection)

    (* Every result set produced by the latest [execute], in server order. *)
    val mutable recordsets = []
    (* Index of the row most recently returned by [fetch1] in [curr_rs];
       -1 means no row has been returned yet. *)
    val mutable curr_row = -1
    (* The result set read by [fetch1] / [names]; [None] if the latest
       [execute] produced no rows-bearing result. *)
    val mutable curr_rs = None

    (* The current result set; raises [Not_found] when there is none. *)
    method private recordset =
        match curr_rs with
            | None -> raise Not_found
            | Some s -> s

    method connection = (dbh :> Dbi.connection)

    (* An analogous set of methods for recordsets as for rows *)
    method rs_fetch_nth i = List.nth recordsets i
    method rs_fetch_all = recordsets
    method rs_count = List.length recordsets

    (* Send the command and fetch all of its result sets.
       NOTE(review): [args] is ignored — placeholder values are not bound,
       so the query text is sent as-is. *)
    method execute args =
        (* Drop any state left by a previous execution so re-executing does
           not accumulate stale result sets. *)
        recordsets <- [];
        curr_rs <- None;
        curr_row <- -1;
        ( try
              Ct.send cmd;
              if dbh # debug_level >= 2 then printf "Sent command\n";
              while true do
                  let rs = get_one_recordset cmd in
                  recordsets <- recordsets @ [rs];
                  curr_rs <- Some rs;
                  (* -1, not 0: [fetch1] pre-increments, so starting at 0
                     would silently skip the first row. *)
                  curr_row <- -1
              done
          with
              | Ct.End_results -> ()
              | Ct.Cmd_fail ->
                    if dbh # debug_level > 0 then printf "Command failed\n%!";
                    dbi_raise_messages dbh # raw_connection
              | Failure s ->
                    if dbh # debug_level > 0 then printf "Failure: %s\n%!" s;
                    dbi_raise_messages dbh # raw_connection
        )

    (* Return the next row of the current result set.  Raises [Not_found]
       when there is no result set or all of its rows have been returned. *)
    method fetch1 () =
        let rs = self # recordset in
        let next = curr_row + 1 in
        if next >= Array.length rs.rs_rows then raise Not_found;
        curr_row <- next;
        Array.to_list rs.rs_rows.(next)

    (* Column names of the current result set; [Not_found] if none. *)
    method names = Array.to_list (self # recordset).rs_names

    (* TDS has a serial, but callers must fetch it themselves with a query. *)
    method serial = failwith "Not implemented."
end


(* A FreeTDS connection.  Construction allocates a CT-Lib context and
   connection, applies [user]/[password] when given, connects to [host]
   (default "localhost") and, when [database] is non-empty, switches to it
   with a USE command. *)
and connection ?host ?port ?user ?password database = 
    let context = Ct.ctx_create () in
    let conn = Ct.con_alloc context in
    let maybe_setstring conn var value =
        match value with
            | Some s -> Ct.con_setstring conn var s
            | None -> ()
    in
object(self)
    inherit Dbi.connection ?host ?port ?user ?password database as super

    method raw_connection = conn

    (* Debug levels (ints for lack of a better idea):
       0 - no debugging
       1 - connection / command creation debugging
       2 - debug each result
       3 - debug each fetch *)
    val mutable debug_level = 0

    method set_debug b = 
        debug_level <- (if b 
                        then ( printf "Dbi_ct.connection # set_debug %B\n" b; 1 )
                        else 0)

    method set_debug_level i =
        debug_level <- i;
        if i > 0 then printf "Dbi_ct.connection # set_debug_level %i\n" i

    method debug = debug_level > 0
    method debug_level = debug_level

    initializer
        maybe_setstring conn `Username user;
        maybe_setstring conn `Password password;
        (* NOTE(review): [port] is not passed to Ct here; the host string
           alone selects the server. *)
        (match host with
            | Some h -> Ct.connect conn h
            | None -> Ct.connect conn "localhost");

        (* NOTE(review): [database] is spliced into the USE statement
           unquoted, so names needing identifier quoting will fail. *)
        if database <> "" then
            let cmd = Ct.cmd_alloc conn in
            try
                Ct.command cmd `Lang ("USE " ^ database);
                Ct.send cmd;
                (* [Ct.results] reports a failed command as the [`Cmd_fail]
                   result value, so it must be checked explicitly; ignoring
                   it would leave us silently in the wrong database. *)
                while true do
                    match Ct.results cmd with
                        | `Cmd_fail -> raise Ct.Cmd_fail
                        | _ -> ()
                done
            with
                | Ct.End_results -> ()
                | Failure _
                | Ct.Cmd_fail -> dbi_raise_messages conn

    method host = host
    method port = port
    method user = user
    method password = password
    method database = database

    method database_type = "tds"

    method prepare query =
        if self # debug then
            eprintf "Dbi_ct.connection %d # prepare %S\n%!" self#id query;

        if self # closed then
            failwith "Dbi_ct: prepare called on closed database handle.";

        let cmd = Ct.cmd_alloc conn in
        Ct.command cmd `Lang query;
        ( new statement
              (self :> connection)
              cmd
          :> Dbi.statement)

    (* Prepare and immediately execute [sql], returning the driver-specific
       statement so the caller can reach every result set via rs_*. *)
    method ex_multi sql (params : Dbi.sql_t list) =
        if self # debug then
            eprintf "Dbi_ct.connection %d # ex_multi %S\n%!" self#id sql;

        if self # closed then
            failwith "Dbi_ct: ex_multi called on closed database handle.";

        let cmd = Ct.cmd_alloc conn in
        Ct.command cmd `Lang sql;
        let sth = new statement
                      (self :> connection)
                      cmd
        in
        sth # execute params;
        sth

    (* Close the CT-Lib connection only once; a repeated close must not
       hand an already-closed handle back to Ct. *)
    method close =
        if not (self # closed) then Ct.close conn;
        super # close
end
