(**************************************************************************)
(*                                                                        *)
(*     The Alt-ergo theorem prover                                        *)
(*     Copyright (C) 2006-2008                                            *)
(*                                                                        *)
(*     Sylvain Conchon                                                    *)
(*     Evelyne Contejean                                                  *)
(*     CNRS-LRI-Universite Paris Sud                                      *)
(*                                                                        *)
(*   This file is distributed under the terms of the CeCILL-C licence     *)
(*                                                                        *)
(**************************************************************************)

open Format
open Options
open Exception

(** Interface of a union-find structure over terms whose equivalence classes
    are represented by semantic values of a theory [R].  The behaviour
    described below is the one implemented by [Make] in this file. *)
module type S = sig
  (** The union-find environment. *)
  type t

  (** The theory supplying the semantic values ([R.r]) used as
      representatives. *)
  module R : Sig.X

  (** The environment in which no term is registered. *)
  val empty :  t

  (** [add env t] registers the term [t]; [env] is returned unchanged when
      [t] is already registered. *)
  val add : t -> Term.t -> t

  (** [mem env t] tells whether [t] is registered in [env]. *)
  val mem : t -> Term.t -> bool

  (** [find env t] is the representative of [t].  For a term that is not
      registered, [Make] answers [R.make t]. *)
  val find : t -> Term.t -> R.r

  (** [union env t1 t2 ex] registers both terms and merges their classes,
      [ex] being the explanation of the equality.  The second component of
      the result holds one triple [(p, touched, v)] per substitution of [p]
      by [v] obtained by solving the equation between the two
      representatives; [touched] pairs each value whose representative was
      rewritten with its new representative.  The list is empty when both
      terms already share a representative.
      Raises [Inconsistent] when the two representatives are recorded as
      distinct or when the equation is unsolvable. *)
  val union : 
    t -> Term.t -> Term.t -> Explanation.t ->
    t * (R.r * ((R.r * R.r) list) * R.r) list

  (** [distinct env t1 t2 ex] registers both terms and records that they are
      distinct, [ex] being the explanation of the disequality.
      Raises [Inconsistent] when both terms have the same representative. *)
  val distinct : t -> Term.t -> Term.t -> Explanation.t -> t

  (** [equal env t1 t2] holds when both terms have the same representative. *)
  val equal : t -> Term.t -> Term.t -> bool

  (** [are_distinct env t1 t2] holds when the representatives of the two
      terms differ and are known, or can be shown by solving, to be
      distinct. *)
  val are_distinct : t -> Term.t -> Term.t -> bool

  (** [class_of env t] lists the terms recorded in the class of the
      representative of [t], or [[t]] when no class is recorded. *)
  val class_of : t -> Term.t -> Term.t list

  (** [explain env t1 t2] is an explanation of the equality of [t1] and
      [t2].  In [Make] it raises [NotCongruent] when the representatives
      differ, and [Not_found] when the terms are syntactically different and
      one of them is not registered. *)
  val explain : t -> Term.t -> Term.t -> Explanation.t

  (** [neq_explain env t1 t2] is, in [Make], the union of the explanations
      attached to the representatives of [t1] and [t2] when these
      representatives differ; it raises [NotCongruent] when they are the
      same. *)
  val neq_explain : t -> Term.t -> Term.t -> Explanation.t

end

module Make ( R : Sig.X ) = struct

  (* Short aliases for the modules used throughout the functor. *)
  module Ex = Explanation
  (* Re-exports the functor argument as a sub-module; presumably needed so
     that the result matches the [module R] component of [S]. *)
  module R = R
  module S = Symbols
  module T = Term
  module F = Formula
  module MapT = Term.Map
  module SetT = Term.Set
  module SetF = Formula.Set
  (* Maps and sets over semantic values, ordered by [R.compare]. *)
  module MapR = Map.Make(struct type t = R.r let compare = R.compare end)
  module SetR = Set.Make(struct type t = R.r let compare = R.compare end)

  (* The state of the union-find structure. *)
  type t = { 
    make : R.r MapT.t; 
    (* term -> the semantic value it was registered with (see [Env.init]) *)
    repr : (R.r * Ex.t) MapR.t; 
    (* representative table: semantic value -> its current representative,
       paired with the explanation accumulated for that link *)
    classes : SetT.t MapR.t;  
    (* representative -> class (set of terms) *)
    gamma : SetR.t MapR.t; 
    (* associates each value x with the set of semantic values whose
       representative has x among its leaves *)
    neqs: Ex.t MapR.t MapR.t; 
    (* the disequations map: r1 -> (r2 -> explanation) for every value r2
       recorded as distinct from r1 *)
  }
      
  (* The initial environment: no registered term, no representative, no
     class, no leaf index and no disequation. *)
  let empty = {
    neqs = MapR.empty;
    gamma = MapR.empty;
    classes = MapR.empty;
    repr = MapR.empty;
    make = MapT.empty;
  }

  (* Debugging printers for the components of the environment. *)
  module Print = struct

    (* Prints every semantic value of a set, each followed by a space. *)
    let rs_print fmt values =
      SetR.iter (fun r -> fprintf fmt "%a " R.print r) values

    (* Prints every key of a map over semantic values, each followed by a
       space; the bound data is ignored. *)
    let rm_print fmt table =
      MapR.iter (fun key _ -> fprintf fmt "%a " R.print key) table

    (* Prints every term of a set, each followed by a space. *)
    let t_print fmt terms =
      SetT.iter (fun t -> fprintf fmt "%a " T.print t) terms

    (* Prints the [make] table, one binding per line. *)
    let pmake fmt m =
      let binding t r = fprintf fmt "%a -> %a\n" T.print t R.print r in
      fprintf fmt "[.] map:\n";
      MapT.iter binding m

    (* Prints the [repr] table, one binding per line, without the
       explanations. *)
    let prepr fmt m =
      let binding r (r', _) = fprintf fmt "%a -> %a\n" R.print r R.print r' in
      fprintf fmt "Representatives map:\n";
      MapR.iter binding m

    (* Prints the [classes] table, one class per line. *)
    let pclasses fmt m =
      let binding k s =
        fprintf fmt "%a -> %a\n" R.print k Term.print_list (SetT.elements s)
      in
      fprintf fmt "Classes map:\n";
      MapR.iter binding m

    (* Prints the [gamma] table, one binding per line. *)
    let pgamma fmt m =
      let binding k s = fprintf fmt "%a -> %a\n" R.print k rs_print s in
      fprintf fmt "Gamma map:\n";
      MapR.iter binding m

    (* Prints the [neqs] table, one binding per line, without the
       explanations. *)
    let pneqs fmt m =
      let binding k s = fprintf fmt "%a -> %a\n" R.print k rm_print s in
      fprintf fmt "Disequations map:\n";
      MapR.iter binding m

    (* Dumps the [make], [repr], [neqs] and [classes] tables on the standard
       formatter, between two banner lines. *)
    let all env =
      printf "---------- UF environment --------@.";
      printf "%a\n%a\n%a\n%a\n"
        pmake env.make
        prepr env.repr
        pneqs env.neqs
        pclasses env.classes;
      printf "---------- FIN UF environment --------@."
  end

  (* [mem env t] tells whether the term [t] is registered in [env]. *)
  let mem env t =
    let registered = env.make in
    MapT.mem t registered
      
  (* [find env t] is the representative of [t] together with the explanation
     stored with it.  When [t], or its semantic value, is not recorded, the
     answer is [R.make t] with the empty explanation. *)
  let find env t =
    try
      let value = MapT.find t env.make in
      MapR.find value env.repr
    with Not_found -> (R.make t, Ex.empty)
    

  module Env = struct

    (* [add_to_classes t r classes] inserts the term [t] into the class
       bound to [r], creating that class when it does not exist yet. *)
    let add_to_classes t r classes =
      let members = try MapR.find r classes with Not_found -> SetT.empty in
      MapR.add r (SetT.add t members) classes

    (* [update_classes c nc classes] binds [nc] to the union of the classes
       of [c] and [nc]; a missing class counts as empty.  The binding of [c]
       is left in place. *)
    let update_classes c nc classes =
      let members r = try MapR.find r classes with Not_found -> SetT.empty in
      let from_old = members c in
      let from_new = members nc in
      MapR.add nc (SetT.union from_old from_new) classes

    (* [add_to_gamma r c gamma] adds [r] to the set bound to each leaf of
       [c] (as given by [R.leaves]), creating the set when needed. *)
    let add_to_gamma r c gamma =
      let register acc leaf =
        let users = try MapR.find leaf acc with Not_found -> SetR.empty in
        MapR.add leaf (SetR.add r users) acc
      in
      List.fold_left register gamma (R.leaves c)
	
    (* [merge r1 m1 r2 m2 dep neqs] transfers the disequations [m1] of [r1]
       to [r2], whose own disequations are [m2].  For every [k] in [m1]:
       - [r1] is removed from the disequations of [k];
       - when [k] is not already in [m2], the disequation between [k] and
         [r2] is recorded on both sides with the explanation of [k] in [m1]
         extended by [dep].
       The result is the updated table, in which [r2] is bound to the merged
       disequations.  Raises [Not_found] if some [k] has no entry in the
       table. *)
    let merge r1 m1 r2 m2 dep neqs =
      let step k ex1 (acc, table) =
        if not (MapR.mem k m2) then begin
          let ex = Ex.union ex1 dep in
          let row = MapR.remove r1 (MapR.find k table) in
          (MapR.add k ex acc, MapR.add k (MapR.add r2 ex row) table)
        end else
          (acc, MapR.add k (MapR.remove r1 (MapR.find k table)) table)
      in
      let merged, table = MapR.fold step m1 (m2, neqs) in
      MapR.add r2 merged table


    (* [update_neqs r1 r2 dep env] computes the disequation table obtained
       when the representative [r1] is replaced by [r2], [dep] being the
       explanation of that replacement: the values recorded as distinct from
       [r1] are transferred to [r2] by [merge].  A missing entry for [r2]
       counts as empty.
       Raises [Inconsistent] if [r1] and [r2] are recorded as distinct from
       each other, and [Not_found] if [r1] (or, through [merge], one of the
       values distinct from it) has no entry in [env.neqs]. *)
    let update_neqs r1 r2 dep env = 
      let m1 = MapR.find r1 env.neqs in
      let m2 = try MapR.find r2 env.neqs with Not_found -> MapR.empty in
      (* [||] is used instead of the deprecated [or] operator; both are the
         same short-circuit disjunction. *)
      if MapR.mem r2 m1 || MapR.mem r1 m2 then raise Inconsistent;
      merge r1 m1 r2 m2 dep env.neqs

    (* [canon (r, dep) env] replaces, in [r], every leaf of [r] by the
       representative recorded for it in [env.repr] (a leaf without an entry
       is kept as is), and adds to [dep] the explanations of the
       representatives that were looked up. *)
    let canon (r, dep) env =
      let rewrite (acc, acc_dep) leaf =
        let leaf_repr, leaf_ex =
          try MapR.find leaf env.repr with Not_found -> (leaf, Ex.empty)
        in
        (R.subst leaf leaf_repr acc, Ex.union acc_dep leaf_ex)
      in
      List.fold_left rewrite (r, dep) (R.leaves r)

    (* [init env t r] registers the term [t] with semantic value [r]:
       - [t] is bound to [r] in [make];
       - [r] is bound to its canonical form [rr] in [repr], unless [r]
         already has a representative;
       - [t] joins the class of [rr];
       - [r] is indexed under every leaf of [rr] in [gamma];
       - [rr] receives an empty disequation map if it has none. *)
    let init env t r =
      let rr, dep = canon (r, Ex.empty) env in
      let new_make = MapT.add t r env.make in
      let new_repr =
        if MapR.mem r env.repr then env.repr
        else MapR.add r (rr, dep) env.repr
      in
      let new_classes = add_to_classes t rr env.classes in
      let new_gamma = add_to_gamma r rr env.gamma in
      let new_neqs =
        if MapR.mem rr env.neqs then env.neqs
        else MapR.add rr MapR.empty env.neqs
      in
      {
        make = new_make;
        repr = new_repr;
        classes = new_classes;
        gamma = new_gamma;
        neqs = new_neqs;
      }

    (* [update env p v dep] substitutes [v] for the pivot [p] in the
       representative of every value indexed under [p] in [env.gamma], and
       updates the [repr], [classes], [gamma] and [neqs] tables accordingly,
       [dep] being the explanation of the substitution.
       A [Not_found] escaping from the traversal is treated as a broken
       invariant ([assert false]); [Inconsistent] raised by [update_neqs] is
       propagated.
       Note: stale entries of the maps are not cleaned up here. *)
    let update env p v dep =
      (* the representative [c] of [r] contains the pivot [p] *)
      let rewrite r acc =
        let c, ex = MapR.find r acc.repr in
        let nc = R.subst p v c in
        let new_repr = MapR.add r (nc, Ex.union ex dep) acc.repr in
        let new_classes = update_classes c nc acc.classes in
        let new_gamma = add_to_gamma r nc acc.gamma in
        let new_neqs = update_neqs c nc dep acc in
        { acc with
            repr = new_repr;
            classes = new_classes;
            gamma = new_gamma;
            neqs = new_neqs;
        }
      in
      try SetR.fold rewrite (MapR.find p env.gamma) env
      with Not_found -> assert false
  end
    
  (* [add env t] registers [t] with the semantic value [R.make t]; [env] is
     returned as is when [t] is already registered. *)
  let add env t =
    if not (MapT.mem t env.make) then Env.init env t (R.make t) else env

  (* [union env t1 t2 dep] registers [t1] and [t2] and makes them equal,
     [dep] being the explanation of the equality.
     When both terms already have the same representative the environment is
     returned with an empty list.  Otherwise the equation between the two
     representatives is solved and every resulting substitution of [p] by
     [v] is applied with [Env.update]; the result lists, for each of them,
     the triple [(p, touched, v)] where [touched] pairs every value indexed
     under [p] with its new representative.
     Raises [Inconsistent] when the representatives are recorded as distinct
     or when [Unsolvable] is raised while solving or updating. *)
  let union env t1 t2 dep =
    if debug_uf then
      printf "@{<C.Bold>[uf]@} union %a = %a@." T.print t1 T.print t2;
    let env = add (add env t1) t2 in
    let r1 = fst (find env t1) in
    let r2 = fst (find env t2) in
    (* Applies one substitution and records the values it touched. *)
    let pivot (current, acc) (p, v) =
      if debug_uf then
        printf "@{<C.Bold>[uf]@} on pivote sur %a@." R.print p;
      let next = Env.update current p v dep in
      let collect r touched = (r, fst (MapR.find r next.repr)) :: touched in
      let touched = SetR.fold collect (MapR.find p current.gamma) [] in
      (next, (p, touched, v) :: acc)
    in
    if R.compare r1 r2 <> 0 then
      try
        if MapR.mem r2 (MapR.find r1 env.neqs)
          || MapR.mem r1 (MapR.find r2 env.neqs)
        then raise Inconsistent;
        let env, res = List.fold_left pivot (env, []) (R.solve r1 r2) in
        if debug_uf then Print.all env;
        (env, res)
      with Unsolvable -> raise Inconsistent
    else (env, [])
	
  (* [make_distinct env r1 r2 dep] records, in both directions, that [r1]
     and [r2] are distinct with explanation [dep].  Nothing is recorded when
     [r2] is already listed among the values distinct from [r1]. *)
  let make_distinct env r1 r2 dep =
    let distinct_from r =
      try MapR.find r env.neqs with Not_found -> MapR.empty
    in
    let d1 = distinct_from r1 in
    let d2 = distinct_from r2 in
    let updated =
      if not (MapR.mem r2 d1) then begin
        let with_r2 = MapR.add r2 (MapR.add r1 dep d2) env.neqs in
        MapR.add r1 (MapR.add r2 dep d1) with_r2
      end else env.neqs
    in
    { env with neqs = updated }

  (* [distinct env t1 t2 dep] registers [t1] and [t2] and records that their
     representatives are distinct, [dep] being the explanation of the
     disequality.
     When both terms are applications of the same symbol to a single
     argument and each term is its own representative (it equals its
     [R.empty_embedding]), the arguments are recursively made distinct as
     well.  Otherwise, if solving the equation between the representatives
     yields exactly one pair, that pair is recorded as distinct too.
     Raises [Inconsistent] when both terms have the same representative. *)
  let rec distinct env t1 t2 dep =
    if debug_uf then
      printf "@{<C.Bold>[uf]@} distinct %a <> %a@." T.print t1 T.print t2;
    let env = add (add env t1) t2 in
    let r1, ex1 = find env t1 in
    let r2, ex2 = find env t2 in
    let dep' = Ex.union ex1 (Ex.union ex2 dep) in
    if R.compare r1 r2 = 0 then raise Inconsistent;
    let env = make_distinct env r1 r2 dep' in
    (* Fallback: exploit the solver when it reduces the equation to a single
       pair; an unsolvable equation contributes nothing. *)
    let via_solver () =
      match (try R.solve r1 r2 with Unsolvable -> []) with
        | [a, b] -> make_distinct env a b dep'
        | _ -> env
    in
    match Term.view t1, Term.view t2 with
      | {Term.f = f1; xs = [a]}, {Term.f = f2; xs = [b]} ->
          if Symbols.equal f1 f2
            && R.compare (R.empty_embedding t1) r1 = 0
            && R.compare (R.empty_embedding t2) r2 = 0
          then distinct env a b dep
          else via_solver ()
      | _, _ -> via_solver ()
      
  (* [equal env t1 t2] holds when [t1] and [t2] have the same
     representative according to [R.compare]. *)
  let equal env t1 t2 =
    let r1 = fst (find env t1) in
    let r2 = fst (find env t2) in
    match R.compare r1 r2 with
      | 0 -> true
      | _ -> false

  (* [are_in_neqs env r1 r2] holds when the disequation between [r1] and
     [r2] is recorded in [env.neqs] in at least one direction; a value
     without entry has no disequation. *)
  let are_in_neqs env r1 r2 =
    let listed a b =
      try MapR.mem a (MapR.find b env.neqs) with Not_found -> false
    in
    listed r1 r2 || listed r2 r1

  (* [are_distinct env t1 t2] tells whether [t1] and [t2] are known to be
     distinct.  The representatives are computed in a copy of [env] in which
     both terms are registered; [env] itself is not extended.
     The answer is [false] when the representatives coincide.  Otherwise it
     is [true] when the representatives are recorded as distinct, when one
     of the pairs obtained by solving their equation is recorded as
     distinct, or when that equation is unsolvable.
     (An earlier implementation, formerly kept here as commented-out code,
     compared integer constants syntactically and looked the terms up in a
     term-set based [neqs] table.) *)
  let are_distinct env t1 t2 =
    let extended = add (add env t1) t2 in
    let r1 = fst (find extended t1) in
    let r2 = fst (find extended t2) in
    let answer =
      if R.compare r1 r2 = 0 then false
      else if are_in_neqs env r1 r2 then true
      else
        (* True because r1=r2 <-> /\_{(a,b)in(R.solve r1 r2)}  a=b *)
        try List.exists (fun (a, b) -> are_in_neqs env a b) (R.solve r1 r2)
        with Unsolvable -> true
    in
    if debug_uf then
      printf "@{<C.Bold>[uf]@} are_distinct %a<>%a ? %b@."
        T.print t1 T.print t2 answer;
    answer

  (* [explain env t1 t2] is the explanation of the equality of [t1] and
     [t2]: empty for syntactically equal terms, otherwise the union of the
     explanations stored with their common representative.
     Raises [NotCongruent] when the representatives differ and [Not_found]
     when one of the terms is not recorded in [env]. *)
  let explain env t1 t2 =
    let lookup t = MapR.find (MapT.find t env.make) env.repr in
    if Term.equal t1 t2 then Ex.empty
    else begin
      let r1, ex1 = lookup t1 in
      let r2, ex2 = lookup t2 in
      if R.compare r1 r2 <> 0 then raise NotCongruent
      else Ex.union ex1 ex2
    end

  (* [neq_explain env t1 t2] is the union of the explanations stored with
     the representatives of [t1] and [t2] when these representatives differ.
     Raises [NotCongruent] when both terms have the same representative. *)
  let neq_explain env t1 t2 =
    let r1, ex1 = find env t1 in
    let r2, ex2 = find env t2 in
    match R.compare r1 r2 with
      | 0 -> raise NotCongruent
      | _ -> Ex.union ex1 ex2
    
  (* From this point on [find] only returns the representative of the term;
     the explanation computed by the previous definition is dropped. *)
  let find env t =
    let representative, _ = find env t in
    representative

  (* [class_of env t] lists the terms of the class bound to the
     representative of [t], or [[t]] alone when no class is recorded for
     that representative. *)
  let class_of env t =
    try
      let members = MapR.find (find env t) env.classes in
      SetT.elements members
    with Not_found -> [t]


end
