{-
    Kaya - My favourite toy language.
    Copyright (C) 2004, 2005 Edwin Brady

    This file is distributed under the terms of the GNU General
    Public Licence. See COPYING for licence.
-}

module CodegenCPP where

import Options
import TAC
import Language
import IO
import Debug.Trace
import Lib

-- | One top-level item of the generated C++ translation unit.
data Output
    = RawOutput String
      -- ^ Verbatim text; rendered followed by a newline.
    | FNOutput (Name, Type, String, [(ArgType, Var)], String)
      -- ^ A function definition: name, type (fed to 'mangling' to form the
      --   C++ name), rendered prologue code, argument list (not used when
      --   rendering), rendered body.
    | ExternOutput (Name, Type)
      -- ^ A function prototype with no body.
    | ExcOutput Name Bool
      -- ^ A @char*@ exception-name variable; the flag is True when it is
      --   defined in this unit, False when it is only declared @extern@.
    | GlobOutput Int
      -- ^ Rendered as @DECLGLOBAL(n);@.
  deriving Show

{-
writeCf :: InputType -> [FilePath] -> [CompileResult] -> FilePath -> IO ()
writeCf domain lds xs out = do mprog <- getmain lds domain
			       header <- getheader lds domain
			       let str = ((writeout header mprog).writecpp) xs
			       writeFile out str
-}

-- | Generate the complete C++ source for a compiled module, write it to the
-- given handle, and close the handle.
--
-- The startup code ('getmain') is loaded before the header ('getheader');
-- the function bodies are rendered from the compiled results.
writeC :: Name -> -- Module name
          InputType -> [FilePath] -> Context ->
          [CompileResult] -> Handle -> Options -> IO ()
writeC mod domain lds ctxt xs out copts = do
    startup <- getmain mod lds ctxt domain copts
    prelude <- getheader mod lds ctxt domain copts
    let items = writecpp xs
    hPutStr out (writeout prelude startup items)
    hClose out

-- | Assemble the whole output file, in this order: the header text, forward
-- declarations for every generated function, the rendered output items,
-- and finally the startup code.
writeout :: String -> String -> [Output] -> String
writeout header mprog items =
    concat [header, writedecls items, writeout' items, mprog]

-- | Render every output item, in order, as C++ source text.
writeout' :: [Output] -> String
writeout' = concatMap render
  where
    render (RawOutput str) = str ++ "\n"
    render (ExternOutput (f, ty)) =
        "void " ++ show f ++ mangling ty ++ "(VMState* vm);\n"
    render (ExcOutput nm defhere)
      | defhere   = "char* " ++ show nm ++ " = \"" ++ showuser nm ++ "\";\n"
      | otherwise = "extern char* " ++ show nm ++ ";\n"
    render (FNOutput (f, ty, pop, _, def)) =
        "void " ++ show f ++ mangling ty ++ "(VMState* vm){\n"
        ++ pop ++ "\n" ++ def ++ "}\n\n"
    render (GlobOutput n) = "DECLGLOBAL(" ++ show n ++ ");\n"

-- | A prototype line for every function definition among the output items;
-- all other items contribute nothing. 'writeout' places this ahead of the
-- definitions themselves.
writedecls :: [Output] -> String
writedecls items =
    concat [ "void " ++ show f ++ mangling ty ++ "(VMState* vm);\n"
           | FNOutput (f, ty, _, _, _) <- items ]

getmain,getheader :: Name -> [FilePath] -> Context -> InputType -> Options -> IO String
-- Programs and shebang scripts get the "startup.vcc" template from the load
-- path, with the %__start and %__panic placeholders replaced (via
-- replaceDefs) by the names ctxtlookup resolves them to. Modules get no
-- startup code. Other input types are not handled.
getmain mod lds ctxt input copts =
    case input of
      Program _ -> startup
      Shebang   -> startup
      Module    -> return ""
  where
    startup = findFile lds "startup.vcc" >>= \tmpl ->
              replaceDefs mod ctxt ["__start","__panic"] tmpl copts
{-
getmain mod lds ctxt Webapp = 
   do str <- findFile lds "startup.vcc"
      str' <- replaceDefs mod ctxt ["__start","__panic"] str 
      return str'
getmain mod lds ctxt Webprog = 
   do str <- findFile lds "startup.vcc"
      str' <- replaceDefs mod ctxt ["__start","__panic"] str 
      return str'
-}

-- getmain mod lds ctxt Program = 
--    do str <- findFile lds "program.vcc"
--       str' <- replaceDefs mod ctxt ["main"] str 
--       return str'
-- getmain mod lds ctxt Shebang = 
--    do str <- findFile lds "program.vcc"
--       str' <- replaceDefs mod ctxt ["main"] str 
--       return str'
-- getmain mod lds ctxt Module = return ""
-- getmain mod lds ctxt Webapp = 
--     do str <- findFile lds "webapp.vcc"
--        str' <- replaceDefs mod ctxt ["PreContent","Default","PostContent",
-- 				   "initWebApp","flush","IllegalHandler"] str 
--        return str'

-- Read "header.vcc" from the load path and append the definition of this
-- module's global table pointer, initialised to NULL.
getheader mod lds _ _ _ = fmap (++ globals) (findFile lds "header.vcc")
  where globals = "\nValue** globaltable" ++ show mod ++ "=NULL;\n\n"

-- | Source text of the startup program for the given input type, prefixed
-- with a "%line 0 %file" directive naming the file it was read from.
-- Programs read "<name>.ks" and shebang scripts read "program.ks". Every
-- other input type yields the empty string.
getStartup :: InputType -> [FilePath] -> IO String
getStartup input lds =
    case input of
      Program s -> annotated (s ++ ".ks")
      Shebang   -> annotated "program.ks"
      -- Webapp ("webapp.ks") and Webprog ("webprog.ks") are disabled.
      _         -> return ""
  where
    annotated file = do
        body <- findFile lds file
        return ("%line 0 %file \"" ++ file ++ "\"\n" ++ body)

-- | Replace each "%name" placeholder in a template with the 'show'n form
-- of the name that 'ctxtlookup' resolves @name@ to. Names are looked up
-- and substituted from the last one in the list to the first.
replaceDefs :: Name -> Context -> [String] -> String -> Options -> IO String
replaceDefs mod ctxt names template copts = go (reverse names) template
  where
    go [] acc = return acc
    go (d:ds) acc = do
        (fname, _) <- ctxtlookup mod (UN d) ctxt Nothing copts
        go ds (replace ("%" ++ d) (show fname) acc)

-- | @replace old new s@ replaces every non-overlapping occurrence of @old@
-- in @s@ with @new@, scanning from left to right.
--
-- An empty @old@ leaves @s@ unchanged. An empty pattern matches at every
-- position without consuming input, so the naive scan would produce an
-- infinite string.
replace :: String -> String -> String -> String
replace "" _ xs = xs
replace old new xs = go xs
  where
    len = length old   -- computed once rather than at every position
    go "" = ""
    go s@(c:cs)
      | take len s == old = new ++ go (drop len s)
      | otherwise         = c : go cs

-- | Read the first readable file at @dir ++ path@, trying each load-path
-- entry @dir@ in order. Entries are concatenated directly with @path@, so
-- presumably each is expected to end in a path separator. Any IO error on
-- a candidate falls through to the next entry. If every entry fails, this
-- fails with "Can't find <path>".
findFile :: [FilePath] -> FilePath -> IO String
findFile dirs path = foldr attempt notFound dirs
  where
    notFound = fail $ "Can't find " ++ path
    attempt dir fallback = catch (readFile (dir ++ path)) (\_ -> fallback)

-- | Translate compiled results into output items one-for-one. Function
-- prologues and bodies are rendered to C++ text with 'cpp'.
writecpp :: [CompileResult] -> [Output]
writecpp = map toOutput
  where
    toOutput (RawCode str)                         = RawOutput str
    toOutput (ByteCode (n, ty, Code pop args def)) = FNOutput (n, ty, cpp pop, args, cpp def)
    toOutput (ExternDef (n, ty))                   = ExternOutput (n, ty)
    toOutput (ExcCode nm here)                     = ExcOutput nm here
    toOutput (GlobCode n)                          = GlobOutput n

-- | Render an instruction sequence as C++ statements, one per line, each
-- indented by a tab and terminated by ";".
cpp :: [TAC] -> String
cpp = concatMap statement
  where statement i = "\t" ++ instr i ++ ";\n"

-- Spell a binary operator as its C++ token. OpAndBool/BAnd both render as
-- "&&" and OpOrBool/BOr both render as "||".
printOp op = case op of
    Plus      -> "+"
    Minus     -> "-"
    Times     -> "*"
    Divide    -> "/"
    Modulo    -> "%"
    OpLT      -> "<"
    OpGT      -> ">"
    OpLE      -> "<="
    OpGE      -> ">="
    Equal     -> "=="
    NEqual    -> "!="
    OpAnd     -> "&"
    OpOr      -> "|"
    OpAndBool -> "&&"
    OpOrBool  -> "||"
    OpXOR     -> "^"
    OpShLeft  -> "<<"
    OpShRight -> ">>"
    BAnd      -> "&&"
    BOr       -> "||"

-- Spell a unary operator as its C++ token.
printUnOp op = case op of
    Not -> "!"
    Neg -> "-"

-- | Render one TAC instruction as a single C++ statement, without the
-- trailing ";" (which 'cpp' appends). Almost every instruction becomes a
-- call to the VM macro of the same name.
--
-- Instructions with no equation below render as "NOP". So do CONSTCASE
-- instructions whose type is none of Number, Character, Boolean or
-- StringType.
instr :: TAC -> String
instr (DECLARE v) = "DECLARE("++show (fst v)++")"
instr (DECLAREARG v) = "DECLAREARG("++show (fst v)++")"
instr (DECLAREQUICK v) = "DECLAREQUICK("++show (fst v)++")"
instr (HEAPVAL v) = "HEAPVAL("++ show (fst v) ++ ")"
-- NOTE(review): USETMP is emitted as DECLAREARG; confirm this is intended.
instr (USETMP v) = "DECLAREARG("++show (fst v)++")"
instr (TMPINT i) = "TMPINT(t"++show i ++")"
instr (TMPREAL i) = "TMPREAL(t"++show i ++")"
instr (SET var idx val) = "SET("++show (fst var)++","++show idx ++ ","++show (fst val)++")"
instr TOINDEX = "TOINDEX"
instr SETTOP = "SETTOP"
instr ADDTOP = "ADDTOP"
instr SUBTOP = "SUBTOP"
instr MULTOP = "MULTOP"
instr DIVTOP = "DIVTOP"
instr APPENDTOP = "APPENDTOP"
instr APPENDTOPINT = "APPENDTOPINT"
instr (MKARRAY i) = "MKARRAY("++show i++")"
instr (TMPSET t val) = "t"++show t++"="++show val
instr (RTMPSET t val) = "t"++show t++"="++show val
instr (CALL v) = "CALLFUN("++show (fst v)++")"
instr (CALLNAME f) = "CALL("++f++")"
instr CALLTOP = "CALLTOP"
instr (FASTCALL nm args) = nm ++ "(vm, " ++ showlist args ++ ")"
  where
    -- Comma-separated arguments. This must recurse on the tail: applying
    -- 'show' to the tail emitted a bracketed Haskell list (e.g.
    -- "f(vm, 1,[2,3])") for any call with two or more arguments.
    showlist [] = ""
    showlist [x] = show x
    showlist (x:xs) = show x ++ "," ++ showlist xs
instr (TAILCALL v) = "TAILCALLFUN("++show (fst v)++")"
instr (TAILCALLNAME f) = "TAILCALL("++f++")"
instr TAILCALLTOP = "TAILCALLTOP"
instr (CLOSURE n i arity)
    = "CLOSURE("++ n ++","++show i++"," ++ show arity ++")"
instr (CLOSURELOC v i) = "CLOSURELOC("++ show (fst v) ++","++show i++")"
instr (CLOSURETOP i) = "CLOSURETOP("++show i++")"
instr (FOREIGNCALL n lib ty args) = mkfcall n lib ty args
instr (MKCON t i) = "MKCON("++show t ++ "," ++ show i ++ ")"
instr (MKCONZERO t) = "MKCONZERO("++show t ++ ")"
instr MKEXCEPT = "MKEXCEPT"
instr (MKNEWEXCEPT n a) = "MKNEWEXCEPT(" ++ show n ++ ", " ++ show a++")"
instr EQEXCEPT = "EQEXCEPT"
instr NEEXCEPT = "NEEXCEPT"
instr EQSTRING = "EQSTRING"
instr NESTRING = "NESTRING"
instr (EQSTRINGW v) = "EQSTRINGW(L"++show v++")"
instr (NESTRINGW v) = "NESTRINGW(L"++show v++")"
instr (JEQSTRING l) = "JEQSTRING("++show l++")"
instr (JNESTRING l) = "JNESTRING("++show l++")"
instr (JEQSTRINGW l v) = "JEQSTRINGW("++show l++",L"++show v++")"
instr (JNESTRINGW l v) = "JNESTRINGW("++show l++",L"++show v++")"
instr (JEXNE ex l) = "JEXNE("++show ex++","++show l++")"
instr (GETVAL v) = "GETVAL(t"++show v++")"
instr (GETRVAL v) = "GETRVAL(t"++show v++")"
instr GETINDEX = "GETINDEX"
instr (PROJARG a t) = "PROJARG("++ show a ++ "," ++ show t ++ ")"
instr (EXPROJARG a t) = "EXPROJARG("++ show (fst a) ++ "," ++ show t ++ ")"
instr (INFIX t Divide x y) = "INTDIV(" ++ tmp t ++ "," ++
                             tmp x ++ "," ++ tmp y ++ ")"
instr (INFIX t op x y) = "INTINFIX(" ++ tmp t ++ "," ++ printOp op ++ "," ++
                         tmp x ++ "," ++ tmp y ++ ")"
instr (ADDINPLACE v x) = "ADDINPLACE(" ++ pushitem v ++ "," ++ show x ++ ")"
instr (SUBINPLACE v x) = "SUBINPLACE(" ++ pushitem v ++ "," ++ show x ++ ")"
instr (INTPOWER t x y) = "INTPOWER(" ++ tmp t ++ "," ++
                         tmp x ++ "," ++ tmp y ++ ")"
instr (REALINFIX t Divide x y) = "REALDIV(" ++ tmp t ++ "," ++
                                 tmp x ++ "," ++ tmp y ++ ")"
instr (REALINFIX t op x y) = "REALINFIX(" ++ tmp t ++ "," ++
                             printOp op ++ "," ++
                             tmp x ++ "," ++ tmp y ++ ")"
instr (REALINFIXBOOL op x y) = "REALINFIXBOOL(" ++ printOp op ++ "," ++
                               tmp x ++ "," ++ tmp y ++ ")"
instr (INFIXJFALSE op x y l) = "INFIXJFALSE(" ++ printOp op ++ "," ++
                               tmp x ++ "," ++ tmp y ++ "," ++ show l ++")"
instr (INFIXJTRUE op x y l) = "INFIXJTRUE(" ++ printOp op ++ "," ++
                              tmp x ++ "," ++ tmp y ++ "," ++ show l ++")"
instr (UNARYJFALSE op x l) = "UNARYJFALSE(" ++ printUnOp op ++ "," ++
                             tmp x ++ "," ++ show l ++")"
instr (UNARYJTRUE op x l) = "UNARYJTRUE(" ++ printUnOp op ++ "," ++
                            tmp x ++ "," ++ show l ++")"
instr (REALPOWER t x y) = "REALPOWER(" ++ tmp t ++ "," ++
                          tmp x ++ "," ++ tmp y ++ ")"
instr (UNARY t op x) = "INTUNARY("++tmp t++","++
                       printUnOp op ++ "," ++ tmp x ++ ")"
instr (REALUNARY t op x) = "REALUNARY("++tmp t++","++
                           printUnOp op ++ "," ++ tmp x ++ ")"
instr APPEND = "APPEND"
-- Characters and temporaries are both appended through APPENDINT.
instr (APPENDCHAR i) = "APPENDINT("++show i++")"
instr (APPENDTMP t) = "APPENDINT(t"++show t++")"
instr (APPENDSTR str) = "APPENDSTR(L"++show str++")"
instr PRINTINT = "PRINTINT"
instr PRINTSTR = "PRINTSTR"
instr PRINTEXC = "PRINTEXC"
instr NEWLINE = "NEWLINE"
instr (LABEL l) = "LABEL("++show l++")"
instr (JUMP l) = "JUMP("++show l++")"
instr (JFALSE l) = "JFALSE("++show l++")"
instr (JTRUE l) = "JTRUE("++show l++")"
instr (JTFALSE t l) = "JTFALSE("++tmp t++","++show l++")"
instr (JTTRUE t l) = "JTTRUE("++tmp t++","++show l++")"
instr (TRY n) = "TRY("++ show n ++ ")"
instr TRIED = "TRIED"
instr THROW = "THROW"
instr RESTORE = "RESTORE"
instr (PUSH i) = "PUSH("++pushitem i++")"
instr (PUSH2 x y) = "PUSH2("++pushitem x++","++pushitem y++")"
instr (PUSH3 x y z) = "PUSH3("++pushitem x++","++pushitem y++","++
                      pushitem z++")"
instr (PUSH4 x y z w) = "PUSH4("++pushitem x ++ "," ++ pushitem y ++ "," ++
                        pushitem z ++ "," ++ pushitem w ++ ")"
-- STACKTMP reuses the STACKINT macro with a temporary ("t<n>") operand.
instr (STACKINT var v) = "STACKINT("++show (fst var)++","++show v++")"
instr (STACKTMP var v) = "STACKINT("++show (fst var)++",t"++show v++")"
instr (TMPSETTOP t) = "TMPSETTOP(t"++show t++")"
instr (PUSHSETTOP x) = "PUSHSETTOP("++pushitem x++")"
instr (PUSHGETVAL x t) = "PUSHGETVAL("++pushitem x++",t"++show t++")"
instr (PUSHGETRVAL x t) = "PUSHGETRVAL("++pushitem x++",t"++show t++")"
instr (PUSHGETINDEX x) = "PUSHGETINDEX("++pushitem x++")"
instr (PUSHTOINDEX x) = "PUSHTOINDEX("++pushitem x++")"
instr (PUSHGLOBAL x i) = "PUSHGLOBAL(globaltable"++x++"," ++show i ++ ")"
instr (CREATEGLOBAL x i) = "CREATEGLOBAL("++show x ++ "," ++ show i++")"
instr RETURN = "RETURN"
instr (SETVAL v x) = "SETVAL("++show (fst v)++","++show x++")"
instr (PUSHSETINT v x) = "SETINT("++pushitem v++","++tmp x++")"
instr (SETINT v x) = "SETINT("++show (fst v)++","++tmp x++")"
instr (SETVAR v x) = "SETVAR("++show (fst v)++","++show (fst x)++")"
instr GETLENGTH = "GETLENGTH"
instr (POP v) = "POP("++show (fst v)++")"
instr (POPARG v) = "POPARG("++show (fst v)++")"
instr (POPANDCOPYARG v) = "POPANDCOPYARG("++show (fst v)++")"
instr (POPINDEX v) = "POPINDEX("++show (fst v)++")"
instr (NOTEVAR v) = "NOTEVAR("++show (fst v)++")"
instr (COPYARG v) = "COPYARG("++show (fst v)++")"
instr (REMEMBER v) = "REMEMBER("++show (fst v)++")"
instr (ARRAY v) = "ARRAY("++show (fst v)++")"
instr (PROJ v i) = "PROJ("++show (fst v) ++ ","++show i++")"
instr DISCARD = "DISCARD"
instr (CASE alts) = "switch(TAG) {\n" ++ instrCases 0 alts ++ "\t}"
instr (CONSTCASE ty alts def)
   | ty == Number || ty == Character || ty == Boolean
     = "switch(TOPINT) {\n" ++ intCases alts def ++ "\t}"
   | ty == StringType = stringCases True alts def
   -- Exception-typed cases (formerly "constCases EQEXCEPT as def") are
   -- disabled; like any other type they render as NOP.
   | otherwise = "NOP"
instr STR2INT = "STR2INT"
instr INT2STR = "INT2STR"
instr REAL2STR = "REAL2STR"
instr BOOL2STR = "BOOL2STR"
instr STR2REAL = "STR2REAL"
instr CHR2STR = "CHR2STR"
instr INT2REAL = "INT2REAL"
instr REAL2INT = "REAL2INT"
instr VMPTR = "VMPTR"
{- CIM 12/7/05 changed to KERROR to avoid clashes with MinGW -}
instr ERROR = "KERROR"
instr (LINENO f l) = "LINENO(L\""++f++"\","++show l++")"
instr (PUSHBT fn f l) = "PUSHBT(L\""++fn++"\",L\""++f++"\","++show l++")"
instr (INLAM f) = "INLAM(\""++f++"\")"
instr POPBT = "POPBT"
-- Caching macros exist only for fewer than 5 values; otherwise emit nothing.
instr (CHECKCACHE n) | n<5 = "CHECKCACHE("++show n++")"
                     | otherwise = ""
instr (STORECACHE xs)
   | length xs < 5 = "STORECACHE"++show (length xs)++"("++showlist (map fst xs)++")"
   | otherwise = ""
  where
    -- Comma-separated list; recurses on the tail (this previously applied
    -- 'show' to the tail, emitting a bracketed Haskell list).
    showlist [] = ""
    showlist [y] = show y
    showlist (y:ys) = show y ++ "," ++ showlist ys
instr _ = "NOP"

-- | Render a pushable item as the C++ expression that constructs the
-- corresponding VM value. A variable is referenced directly by its shown
-- name.
pushitem item = case item of
    NAME n i -> "MKFUN(" ++ n ++ "," ++ show i ++ ")"
    VAL x    -> "MKINT(" ++ show x ++ ")"
    RVAL x   -> "MKREAL(" ++ show x ++ ")"
    STR x    -> "MKSTR(L" ++ show x ++ ")"
    INT t    -> "MKINT(" ++ tmp t ++ ")"
    REAL t   -> "MKREAL(" ++ tmp t ++ ")"
    VAR v    -> show (fst v)
    EMPTYSTR -> "EMPTYSTR"

-- | Render a foreign call to the C function named @n@ as C++ text.
--
-- The output has three parts:
--
--  1. One POPARG per element of @args@, into temporaries 0, 1, 2, ...
--  2. A call to @n@ with each temporary converted according to its
--     declared type.
--  3. A wrapper chosen by the return type @ty@: KVOID(...) for void,
--     PUSH(MK...(...)) otherwise.
--
-- The @lib@ argument is not used here.
mkfcall n lib ty args = popvals 0 args ++
                        conv ty ++ "(" ++ n ++ "(" ++ stackconv 0 args ++ ")))"
  where
    -- One POPARG per argument. This used to count down with an n+k pattern
    -- ("popvals n (a+1)"), which Haskell 2010 removed and GHC rejects
    -- without NPlusKPatterns. Walking the argument list needs no such pattern.
    popvals _ [] = ""
    popvals k (_:rest) =
        -- "HEAPVAL("++show (tmpval k)++"); " ++
        "POPARG(" ++ show (tmpval k) ++ "); " ++ popvals (k+1) rest

    -- Opening of the wrapper around the call's result; the call text and
    -- the closing parentheses are appended in the main expression above.
    {- CIM 12/7/05 changed to KVOID to avoid clashes with MinGW -}
    conv (Prim Void) = "KVOID("
    conv (Prim Number) = "PUSH(MKINT"
    conv (Prim RealNum) = "PUSH(MKREAL"
    conv (Prim Boolean) = "PUSH(MKINT"
    conv (Prim Character) = "PUSH(MKCHAR"
    conv (Prim StringType) = "PUSH(MKSTR"
    conv (Prim File) = "PUSH(MKINT"
    conv (Prim Pointer) = "PUSH(MKINT"
    conv (TyVar _) = "PUSH(" -- Enough rope to hang yourself with!
    conv (Array _) = "PUSH(MKARRAYVAL"
    conv _ = "PUSH(" -- error $ "Can't deal with that type in foreign calls" ++ show t

    -- Comma-separated C arguments; temporary k is converted by its type.
    stackconv _ [] = ""
    stackconv k [x] = stackconv' k x
    stackconv k (x:xs) = stackconv' k x ++ "," ++ stackconv (k+1) xs
    stackconv' k (Prim Number) = show (tmpval k) ++ "->getInt()"
    stackconv' k (Prim RealNum) = show (tmpval k) ++ "->getReal()"
    stackconv' k (Prim Boolean) = show (tmpval k) ++ "->getInt()"
    stackconv' k (Prim Character) = show (tmpval k) ++ "->getInt()"
    stackconv' k (Prim StringType)
        = show (tmpval k) ++ "->getString()->getVal()"
    stackconv' k (Prim File)
        = "(FILE*)("++show (tmpval k) ++ "->getRaw())"
    stackconv' k (Prim Pointer) = show (tmpval k) ++ "->getRaw()"
    stackconv' k (Array _) = show (tmpval k) ++ "->getArray()"
    stackconv' k (TyVar _) = show (tmpval k)
    stackconv' k _ = show (tmpval k) -- error $ "Can't deal with that type (" ++ show t ++ ") in foreign calls"


-- | Render the arms of a tag switch. The i-th code sequence becomes
-- "case (v + i):" and is terminated by "break;". No default arm is
-- emitted.
instrCases :: Int -> [[TAC]] -> String
instrCases v alts = concat (zipWith arm [v ..] alts)
  where
    arm tag code = "\tcase " ++ show tag ++ ":\n" ++ cpp code ++ "\tbreak;\n"

-- | Render the arms of an integer switch: one "case" per constant, each
-- terminated by "break;", followed by a "default" arm for @def@.
intCases :: [(Const,[TAC])] -> [TAC] -> String
intCases alts def =
    concatMap arm alts ++ "\tdefault:\n" ++ cpp def ++ "\tbreak;\n"
  where
    arm (c, code) = "\tcase " ++ showcase c ++ ":\n" ++ cpp code ++ "\tbreak;\n"
								  
-- | Render a constant for use in generated C++. Numbers render as-is,
-- characters by their code point, and booleans as 1/0. Strings use
-- Haskell's 'show' (quoted and escaped).
showcase :: Const -> String
showcase c = case c of
    Num x   -> show x
    Ch ch   -> show (fromEnum ch)
    Bo b    -> if b then "1" else "0"
    Str str -> show str

-- | Render a string switch as an if / else-if chain. Each branch tests
-- its constant with TOPSTREQ, and the chain ends in an unconditional block
-- holding @def@. When @isfst@ is False the first branch is itself
-- preceded by "else ".
stringCases :: Bool -> [(Const,[TAC])] -> [TAC] -> String
stringCases isfst alts def = concat (zipWith (++) separators branches)
  where
    separators = (if isfst then "" else "else ") : repeat "else "
    branches   = map branch alts ++ [fallback]
    branch (c, code) =
        "if (TOPSTREQ(L" ++ showcase c ++ ")) {\n " ++ cpp code ++ "\t}"
    fallback = "{\n " ++ cpp def ++ "\t}"
