open Options
open Format
module L = List
module T = Term
module S = Symbols
module HS= Hstring
module ST = T.Set
module SA = Literal.Set
(* Entry of the use table: the set of terms and the set of literals that use
   a given value ("is used by" in [G.print]). *)
type elt = ST.t * SA.t
  
module Make (X : Sig.X) = struct

  (*= MODULE UTIL ============================================================*)
  module Util = struct

    (* Separator line printed under each AC symbol by [Huf.print]. *)
    let hr = "  -------------------------------------- "

    (* Print the term underlying the leaf [lf].  Calling it on a non-leaf is a
       programming error, hence the assertion; only [Exception.Not_a_leaf] is
       turned into that assertion failure, so exceptions raised while
       printing are no longer masked. *)
    let pr_leaf fmt lf =
      let t =
        try X.term_of_leaf lf with Exception.Not_a_leaf -> assert false
      in
      fprintf fmt "%a" T.print t

    (* Print a list of leaves separated by ", ". *)
    let pr_leafl fmt = function
      | [] -> ()
      | a :: l ->
          fprintf fmt "%a" pr_leaf a;
          L.iter (fun r -> fprintf fmt ", %a" pr_leaf r) l

    (* Component-wise intersection of two (terms, literals) pairs. *)
    let inter_tpl (x1,y1) (x2,y2) = ST.inter x1 x2, SA.inter y1 y2

    (* Component-wise union of two (terms, literals) pairs. *)
    let union_tpl (x1,y1) (x2,y2) = ST.union x1 x2, SA.union y1 y2

    (* Leaves of [r].  A value without leaves gets the single placeholder
       leaf built from the constant "@bottom" (of type int), so the result is
       never empty.  The placeholder is only built when it is needed. *)
    let leaves r =
      match X.leaves r with
      | [] -> [X.make (T.make (Symbols.name "@bottom") [] Ty.Tint)]
      | l -> l

    (* Semantic values of a list of terms. *)
    let make_list l = L.map X.make l

    (* Name of an AC symbol.  Pre-condition: [S.is_ac sym = true]. *)
    let ac_hs sym =
      match sym with
      | S.Name (hs, true) -> hs
      | _ -> assert false

    (* Whether [r] is a leaf, i.e. [X.term_of_leaf] accepts it. *)
    let is_a_leaf r =
      try ignore (X.term_of_leaf r); true
      with Exception.Not_a_leaf -> false

    (* Whether [rx] is a leaf whose head symbol is an AC name. *)
    let ac_leaf rx =
      try
        match (T.view (X.term_of_leaf rx)).T.f with
        | S.Name (_, true) -> true
        | _ -> false
      with Exception.Not_a_leaf -> false

    (* AC symbol name and argument terms of the leaf [rx].
       Pre-condition: [ac_leaf rx = true]. *)
    let sm_ac_info rx =
      let t =
        try X.term_of_leaf rx with Exception.Not_a_leaf -> assert false
      in
      let {T.f = f; xs = xs } = T.view t in
      match f with
      | S.Name (hs, true) -> hs, xs
      | _ -> assert false

    (* [complement l1 l2] is the multiset difference [l2 \ l1] when [l1] is
       included in [l2]; it raises [Exit] otherwise.
       Both lists must be sorted w.r.t. [X.compare]. *)
    let complement l1 l2 =
      let rec f_aux acc = function
        | [], r -> L.rev_append acc r
        | _, [] -> raise Exit
        | a :: r1, b :: r2 ->
            let c = X.compare a b in
            if c = 0 then f_aux acc (r1, r2)
            (* [b] is smaller than everything left in [l1]: keep it *)
            else if c > 0 then f_aux (b :: acc) (a :: r1, r2)
            (* [a] is smaller than everything left in [l2]: not included *)
            else raise Exit
      in
      f_aux [] (l1, l2)

    (* Sort a list of values w.r.t. [X.compare]. *)
    let sort_lr = L.fast_sort X.compare

  end

  open Util
    
  (* A critical pair between two lists of leaves, as built by [G.pc_aux]:
     [ops] holds the common leaves, [du] and [dv] the leaves proper to each
     list.  All three lists are expected to be sorted with [sort_lr]
     ([Spc.compare] asserts it). *)
  type pc = {ops : X.r list; du  : X.r list; dv  : X.r list}

  (* Sets of critical pairs.  Pairs are ordered on [ops], then [du], then
     [dv]; each list is compared by length first, then element-wise with
     [X.compare]. *)
  module Spc =
    Set.Make
      (struct
        type t = pc

        (* Compare two lists: shorter lists first, lists of equal length
           element-wise with [X.compare]. *)
        let cmpl l1 l2 =
          let n1 = L.length l1 and n2 = L.length l2 in
          if n1 <> n2 then n1 - n2
          else
            let rec aux l1 l2 =
              match l1, l2 with
              | a :: r1, b :: r2 ->
                  let c = X.compare a b in
                  if c <> 0 then c else aux r1 r2
              | _ -> 0
            in
            aux l1 l2

        let compare pc1 pc2 =
          let {ops = o1; du = u1; dv = v1} = pc1 in
          let {ops = o2; du = u2; dv = v2} = pc2 in
          (* every list of a critical pair must be sorted *)
          assert (o1 = sort_lr o1);
          assert (o2 = sort_lr o2);
          assert (u1 = sort_lr u1);
          assert (u2 = sort_lr u2);
          assert (v1 = sort_lr v1);
          assert (v2 = sort_lr v2);
          let c = cmpl o1 o2 in
          if c <> 0 then c
          else
            let c = cmpl u1 u2 in
            (* [dv] must take part in the order too: otherwise pairs that
               only differ by [dv] are considered equal and one is lost *)
            if c <> 0 then c else cmpl v1 v2
      end)

  (* Sets of values, ordered by [X.compare]. *)
  module Sr = Set.Make (struct
    type t = X.r
    let compare = X.compare
  end)

  (* Maps indexed by [HS.t]; used with AC symbol names (see [ac_hs]). *)
  module Mhs = Map.Make (struct
    type t = HS.t
    let compare = HS.compare
  end)

  (* Maps indexed by values, ordered by [X.compare]. *)
  module Mr = Map.Make (struct
    type t = X.r
    let compare = X.compare
  end)

  (*= MODULE G ===============================================================*)
  module G = struct
    include Mr

    (* [find k m]: the terms and literals using [k]; two empty sets when [k]
       is not bound, instead of raising [Not_found]. *)
    let find k m = try find k m with Not_found -> (ST.empty,SA.empty)

    type ty = elt t

    (* Record that term [t] uses [k]. *)
    let add_term k t mp =
      let g_t,g_a = find k mp in add k (ST.add t g_t,g_a) mp

    (* Make sure [rt] is bound (to empty uses if new), then record that [t]
       uses every leaf of [lvs]. *)
    let up_add g t rt lvs =
      let g = if mem rt g then g else add rt (ST.empty, SA.empty) g in
      L.fold_left (fun g x -> add_term x t g) g lvs

    (* Combine with [fc] the term sets using each leaf of [lvs]; empty for an
       empty list. *)
    let congr_add_aux fc g lvs =
      match lvs with
      | []    -> ST.empty
      | x::ls ->
          L.fold_left (fun acc y -> fc (fst(find y g)) acc)
            (fst(find x g)) ls

    (* Terms using all the leaves of [lvs]. *)
    let congr_add g lvs = congr_add_aux ST.inter g lvs

    (* Terms using at least one leaf of [lvs]. *)
    let ac_congr_add g lvs = congr_add_aux ST.union g lvs

    (* Add the uses of [p] to the uses of every element of [lvs]. *)
    let up_close_up g p lvs =
      let g_p = find p g in
      L.fold_left (fun gg q -> add q (union_tpl g_p (find q g)) gg) g lvs

    (* The uses of [p], extended for each pair [(_, v)] of [touched] with the
       uses shared by all the leaves of [v]. *)
    let congr_close_up g p touched =
      let inter = function
        | [] -> (ST.empty, SA.empty)
        | rx::l ->
            L.fold_left (fun acc x -> inter_tpl acc (find x g)) (find rx g) l
      in
      L.fold_left (fun (st,sa) (_,v) -> union_tpl (st,sa) (inter (leaves v)))
        (find p g) touched


    (*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*)
    (* Values of a list of argument terms. *)
    let concat_leaves xs = make_list xs

    (* Critical pair between [lvs1] and [lvs2]: [ops] = elements of [lvs1]
       occurring in [lvs2], [dv] = the other elements of [lvs1], [du] =
       elements of [lvs2] not in [ops]; each list is sorted.  Membership is
       decided with [X.compare], the order used to sort these lists and to
       match them later in [complement], rather than polymorphic equality. *)
    let pc_aux lvs1 lvs2 =
      let mem_x e l = L.exists (fun e' -> X.compare e e' = 0) l in
      let olps, diff1 = L.partition (fun e -> mem_x e lvs2) lvs1 in
      let diff2 = L.filter (fun e -> not (mem_x e olps)) lvs2 in
      { ops = sort_lr olps ; du = sort_lr diff2 ; dv = sort_lr diff1 }

    (* For each leaf [x] of [lvs] and each term using [x] whose head is the AC
       symbol named [hs], add to [tmp_pc] the critical pair between [lvs] and
       the values of that term's arguments.  The second argument is unused. *)
    let compute_pc g _t lvs hs tmp_pc =
      L.fold_left
        (fun acc x ->
           ST.fold
             (fun t2 ac ->
                let {T.f = sym2; xs = xs2 } = T.view t2 in
                if S.is_ac sym2 && HS.equal hs (ac_hs sym2) then
                  Spc.add (pc_aux lvs (concat_leaves xs2)) ac
                else ac
             ) (fst(find x g)) acc
        ) tmp_pc lvs

    (*++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++*)

    (* Print, for every bound value, the terms and literals using it. *)
    let print g =
      let sterms fmt = ST.iter (fprintf fmt "%a " T.print) in
      let satoms fmt = SA.iter (fprintf fmt "%a " Literal.print) in
      fprintf fmt "@{<C.Bold>[use]@} gamma :\n";
      iter
        (fun t (st,sa) ->
           fprintf fmt "%a is used by {%a} and {%a}\n"
             X.print t sterms st satoms sa
        ) g

  end
    
  (*= MODULE UF =============================================================*)
  module UF = struct

    (* Wrapper marking a value used as a class representative. *)
    type leader = L of X.r

    let content (L r) = r
    let cmp_lead (L u) (L v) = X.compare u v

    module Repr = Map.Make(struct type t = X.r let compare = X.compare end)
    module Elts = Map.Make(struct type t = leader let compare = cmp_lead end)

    type t = {
      (* maps each element to its representative (leader); an element with
         no entry is its own leader *)
      repr: leader Repr.t;

      (* maps each representative (leader) to its class; a leader with no
         entry has a singleton class *)
      elts: Sr.t Elts.t }

    let empty = { repr = Repr.empty; elts = Elts.empty }

    (* Leader of [sv].  [X.term_of_leaf] checks the invariant that only
       leaves are stored (it raises [Exception.Not_a_leaf] otherwise). *)
    let leader uf sv =
      ignore (X.term_of_leaf sv);
      try Repr.find sv uf.repr with Not_found -> L sv

    (* Merge the classes of the leaders [ld1] and [ld2]; the smaller class is
       attached to the leader of the larger one.  Leaders are compared with
       [X.compare] (the order of the maps), not polymorphic equality. *)
    let union uf ld1 ld2 =
      let L r1 = ld1 in
      let L r2 = ld2 in
      (* invariant check: both leaders are leaves *)
      ignore (X.term_of_leaf r1);
      ignore (X.term_of_leaf r2);
      if X.compare r1 r2 = 0 then uf
      else
        let cls1 = try Elts.find ld1 uf.elts with Not_found -> Sr.singleton r1 in
        let cls2 = try Elts.find ld2 uf.elts with Not_found -> Sr.singleton r2 in
        if Sr.cardinal cls1 < Sr.cardinal cls2 then
          {repr = Sr.fold (fun k rep -> Repr.add k (L r2) rep) cls1 uf.repr;
           elts = Elts.add ld2 (Sr.union cls1 cls2) (Elts.remove ld1 uf.elts)}
        else
          {repr = Sr.fold (fun k rep -> Repr.add k (L r1) rep) cls2 uf.repr;
           elts = Elts.add ld1 (Sr.union cls1 cls2) (Elts.remove ld2 uf.elts)}

    (* Class of the given leader. *)
    let class_of ld uf =
      (* invariant check: the leader is a leaf *)
      ignore (X.term_of_leaf (content ld));
      try Elts.find ld uf.elts
      with Not_found -> let L r = ld in Sr.singleton r

    (* printing uf *)
    let pr_set fmt st = pr_leafl fmt (Sr.elements st)

    let pr_classes fmt mp =
      fprintf fmt "    Classes map:\n";
      Elts.iter
        (fun (L k) s -> fprintf fmt "\t%a -> { %a }\n" pr_leaf k pr_set s) mp

    let pr_repr fmt mp =
      fprintf fmt "    Representatives map:\n";
      Repr.iter (fun e (L r) -> fprintf fmt "\t%a -> %a\n" pr_leaf e pr_leaf r) mp

    (* Print on [fmt], the formatter [Huf.print] uses around this call, so the
       union-find dump is not sent to a different channel than its header. *)
    let print uf = fprintf fmt "%a\n%a\n" pr_repr uf.repr pr_classes uf.elts

  end
    
  (*= MODULE Huf =============================================================*)
  module Huf = struct
    include Mhs

    (* Union-find attached to the AC symbol [hs]; empty if none yet. *)
    let find hs muf = try find hs muf with Not_found -> UF.empty

    type ty = UF.t t

    (* Put all the elements of [lvs] in one class of [uf]. *)
    let uf_chain uf lvs =
      match lvs with
      | [] | [_] -> uf
      | first :: rest ->
          let rec merge uf ld = function
            | [] -> uf
            | b :: tl ->
                let uf' = UF.union uf ld (UF.leader uf b) in
                merge uf' (UF.leader uf' b) tl
          in
          merge uf (UF.leader uf first) rest

    (* Merge [v] with the values of the arguments [xs] in the union-find of
       the AC symbol [hs]. *)
    let up_uf muf (hs, xs) v =
      add hs (uf_chain (find hs muf) (v :: make_list xs)) muf

    (* Open question from the original author, for f(a,b) = v:
       - put a, b and v in the same class, or
       - put a, b and leaves(v) in the same class?
       Only the case where [v] is a leaf is handled; otherwise [huf] is
       returned unchanged. *)
    let up_close_up huf p v =
      if not (is_a_leaf v) then huf
      else
        match ac_leaf p, ac_leaf v with
        (* c |-> d: neither hd(c) nor hd(d) is AC; merge in every table *)
        | false, false -> map (fun ufs -> uf_chain ufs [p; v]) huf
        (* f(a,b) |-> c: f is AC, hd(c) is not *)
        | true, false -> up_uf huf (sm_ac_info p) v
        (* c |-> f(a,b): f is AC, hd(c) is not *)
        | false, true -> up_uf huf (sm_ac_info v) p
        (* f(a,b) |-> g(c,d): both f and g are AC *)
        | true, true -> huf

    (* Print the union-find of every AC symbol. *)
    let print mp =
      fprintf fmt "\n@{<C.Bold>[use]@} AC-UF :\n";
      let print_symbol hs uf =
        fprintf fmt "  %s symbol :\n%s\n" (HS.view hs) hr;
        UF.print uf
      in
      iter print_symbol mp

    (* Union of the classes of the elements of [ls] in the union-find of the
       AC symbol [hs]. *)
    let classes muf hs ls =
      let uf = find hs muf in
      let class_of_elt e = UF.class_of (UF.leader uf e) uf in
      L.fold_left (fun acc e -> Sr.union acc (class_of_elt e)) Sr.empty ls

  end

   
  (*= MODULE PC =============================================================*)
  module PC = struct

    include Mhs

    (* Critical pairs and critical leaves of the AC symbol [hs]; both empty
       when nothing is recorded. *)
    let find hs mp = try find hs mp with Not_found -> Spc.empty, Sr.empty

    type ty = (Spc.t * Sr.t) t

    (* Add the pairs [npc] to the symbol [hs], and the common leaves ([ops])
       of each new pair to the critical leaves of [hs]. *)
    let up_add mp npc hs =
      let old_pcs, old_lvs = find hs mp in
      let add_ops pc lvs = L.fold_left (fun s e -> Sr.add e s) lvs pc.ops in
      let pcs = Spc.union npc old_pcs in
      let lvs = Spc.fold add_ops npc old_lvs in
      add hs (pcs, lvs) mp

    (* Critical leaves of the AC symbol [hs]. *)
    let critical_lvs mp hs =
      let _, lvs = find hs mp in
      lvs

    (* Instantiate [pc] on the argument lists [xs1] and [xs2]; raises [Exit]
       (from [complement]) when the pair does not apply. *)
    let apply_pc pc xs1 xs2 =
      let side1 = (complement pc.du xs1, pc.ops @ pc.dv, pc.du) in
      let side2 = (complement pc.dv xs2, pc.ops @ pc.du, pc.dv) in
      side1, side2

    (* All instances of the critical pairs of [hs] applicable to [xs1] and
       [xs2], trying each pair in both directions. *)
    let applicable_pc mp hs xs1 xs2 =
      let attempt pc acc =
        try apply_pc pc xs1 xs2 :: acc
        with Exit ->
          (try apply_pc pc xs2 xs1 :: acc with Exit -> acc)
      in
      let pcs, _ = find hs mp in
      Spc.fold attempt pcs []

    (* Print the critical pairs of every AC symbol. *)
    let print mp =
      fprintf fmt "\n";
      let pr_pc s pc =
        fprintf fmt "[ %s(%a)=Y ; %s(%a)=X ] -> %s(%a,X) = %s(%a,Y)\n"
          s pr_leafl (pc.ops @ pc.du) s pr_leafl (pc.ops @ pc.dv)
          s pr_leafl pc.du s pr_leafl pc.dv
      in
      iter (fun hs (pcs, _) -> Spc.iter (pr_pc (HS.view hs)) pcs) mp

  end
    
  (*= CE MODULE ==============================================================*)
  (* The use table:
     - [g]: for each value, the terms and literals using it;
     - [pc]: critical pairs and critical leaves, per AC symbol;
     - [uf]: union-find of leaves, per AC symbol;
     - [tmp_pc]: critical pairs computed by [up_add] and kept pending until
       [update_pc] moves them into [pc]. *)
  type t = { g : G.ty ; pc : PC.ty ; uf : Huf.ty ; tmp_pc : Spc.t }

  (* The empty use table. *)
  let empty = {g=G.empty ; pc=PC.empty ; uf=Huf.empty ; tmp_pc=Spc.empty} 
    
  (* Terms and literals using [k]; both sets are empty if [k] is unknown. *)
  let find k use =
    let gamma = use.g in
    G.find k gamma

  (* Bind [k] to the uses [v]. *)
  let add k v use =
    let gamma = G.add k v use.g in
    { use with g = gamma }

  (* Whether [k] has an entry in the table. *)
  let mem k use =
    let gamma = use.g in
    G.mem k gamma
    
  (* Dump the whole use table: uses, critical pairs and AC union-finds.
     The banners go to [fmt], the formatter [G.print], [PC.print] and
     [Huf.print] write to, instead of stdout, so that they cannot end up on
     a different channel than the contents they surround. *)
  let print use =
    fprintf fmt "\n============== use table : begin ==============@.";
    G.print use.g;
    PC.print use.pc;
    Huf.print use.uf;
    fprintf fmt "============== use table : end ================@.\n"

  (** **)
  (* Leaves to consider modulo AC for the symbol [hs]: the union-find
     classes of [lvs] plus the critical leaves of [hs], as a sorted list
     without duplicates. *)
  let ac_leaves use hs lvs =
    let critical = PC.critical_lvs use.pc hs in
    let classes = Huf.classes use.uf hs lvs in
    Sr.elements (Sr.union classes critical)
    
  (** **) 
  (* Record that the term [t], of head symbol [sym], uses the leaves [lvs],
     making sure [rt] has an entry.  When [sym] is AC, the new critical pairs
     are also computed (kept in [tmp_pc] until [update_pc]) and [t] is
     recorded as using the AC leaves of [lvs] as well. *)
  let up_add use t rt sym lvs =
    if not (S.is_ac sym) then { use with g = G.up_add use.g t rt lvs }
    else begin
      let hs = ac_hs sym in
      (* critical pairs *)
      let pcs = G.compute_pc use.g t lvs hs use.tmp_pc in
      (* plain update, then update modulo AC *)
      let plain = G.up_add use.g t rt lvs in
      let modulo_ac = G.up_add plain t rt (ac_leaves use hs lvs) in
      { use with g = modulo_ac ; tmp_pc = pcs }
    end
	
  (** **) 
  (* Terms using all the leaves [lvs] of a term of head [sym]; when [sym] is
     AC, the terms using at least one of its AC leaves instead.  The leaves
     are printed when [debug_tmp] is set. *)
  let congr_add use sym lvs =
    if not (S.is_ac sym) then G.congr_add use.g lvs
    else begin
      let hs = ac_hs sym in
      let acl = ac_leaves use hs lvs in
      if debug_tmp then begin
        let pr_lvs l = L.iter (fun r -> fprintf fmt "%a - " pr_leaf r) l in
        fprintf fmt "leaves :";
        pr_lvs lvs;
        fprintf fmt "\n";
        fprintf fmt "AC-leaves :";
        pr_lvs acl;
        fprintf fmt "\n@."
      end;
      G.ac_congr_add use.g acl
    end

  (** **) 
  (* Add the uses of [p] to the leaves of [v]; then, for each pair [(q, w)]
     of [touched] whose first component is a leaf, update the AC union-finds
     with [q] and [w]. *)
  let up_close_up use p v touched =
    let g = G.up_close_up use.g p (leaves v) in
    let update_uf acc (q, w) =
      if is_a_leaf q then { acc with uf = Huf.up_close_up acc.uf q w }
      else acc
    in
    L.fold_left update_uf { use with g = g } touched

  (** **) 
  (* Uses of [p], extended for each touched pair [(_, v)] with the uses
     shared by all the leaves of [v]. *)
  let congr_close_up use p touched =
    let gamma = use.g in
    G.congr_close_up gamma p touched

  (** **)
  (* When the head symbol of [t] is AC, move the pending critical pairs
     ([tmp_pc]) into the table of that symbol and clear [tmp_pc]. *)
  let update_pc use t =
    let f = (T.view t).T.f in
    if not (S.is_ac f) then use
    else
      let pc = PC.up_add use.pc use.tmp_pc (ac_hs f) in
      { use with pc = pc ; tmp_pc = Spc.empty }

  (** **)
  (* Instances of the critical pairs of the AC symbol [hs] applicable to the
     argument lists [xsr1] and [xsr2] (see [PC.applicable_pc]).  When
     [debug_tmp] is set, each instance is printed on [fmt]; nothing at all is
     printed otherwise (previously a newline was emitted on every call). *)
  let rewrite use hs xsr1 xsr2 =
    let pcl = PC.applicable_pc use.pc hs xsr1 xsr2 in
    if debug_tmp then begin
      let s = HS.view hs in
      L.iter
        (fun ((u, v, w), (x, y, z)) ->
           fprintf fmt
             "\n%s(%a) = %s(%a) AND %s(%a) = %s(%a) -> %s(%a) = %s(%a)"
             s pr_leafl u s pr_leafl v s pr_leafl x s pr_leafl y
             s pr_leafl (v@w) s pr_leafl (y@z))
        pcl;
      fprintf fmt "\n"
    end;
    pcl

end
