/**************************************************************************/
/* File:   cg.cc                                                          */
/* Author: Joachim Schoeberl                                              */
/* Date:   5. Jul. 96                                                     */
/**************************************************************************/

/* 

  Conjugate Gradient Solver (this file also contains the QMR solver)
  
*/ 

#include <la.hpp>
//#include "../florian/typecast.hh"
#include <memory>
#include <vector>

namespace ngla
{
  using namespace ngla;


  /// Absolute value of a real scalar.
  /// Overloaded together with Abs(Complex) so the templated solvers can
  /// take the magnitude of a SCAL value without caring about its type.
  inline double Abs (double v)
  {
    const double magnitude = fabs (v);
    return magnitude;
  }

  /// Modulus |v| of a complex scalar; companion of Abs(double) for the
  /// complex instantiations of the solvers.
  inline double Abs (Complex v)
  {
    const double modulus = std::abs (v);
    return modulus;
  }





  /// Default constructor: no system matrix and no preconditioner yet.
  /// Defaults: precision 1e-10, at most 200 steps, SetInitialize(1)
  /// (the solvers then start from u = 0), no convergence-rate output.
  /// NOTE(review): stop_absolute is not assigned here although
  /// CGSolver::Mult reads it -- confirm it is initialized elsewhere.
  KrylovSpaceSolver :: KrylovSpaceSolver ()
  {
    //      SetSymmetric();

    a = NULL;               // system matrix, set later via SetMatrix
    c = NULL;               // no preconditioner

    SetPrecision (1e-10);
    SetMaxSteps (200);
    SetInitialize (1);

    printrates = 0;         // silent by default
  }
  

  /// Solver for the matrix aa without a preconditioner.
  /// Defaults: precision 1e-10, at most 200 steps, SetInitialize(1),
  /// no convergence-rate output.
  /// NOTE(review): stop_absolute is not assigned here although
  /// CGSolver::Mult reads it -- confirm it is initialized elsewhere.
  KrylovSpaceSolver :: KrylovSpaceSolver (const BaseMatrix & aa)
  {
    //  SetSymmetric();

    SetMatrix (aa);
    c = NULL;               // unpreconditioned

    SetPrecision (1e-10);
    SetMaxSteps (200);
    SetInitialize (1);

    printrates = 0;         // silent by default
  }



  /// Solver for the matrix aa preconditioned with ac.
  /// Defaults: precision 1e-8 (the other constructors use 1e-10 --
  /// NOTE(review): confirm the difference is intended), at most 200 steps,
  /// SetInitialize(1), no convergence-rate output.
  /// NOTE(review): stop_absolute is not assigned here although
  /// CGSolver::Mult reads it -- confirm it is initialized elsewhere.
  KrylovSpaceSolver :: KrylovSpaceSolver (const BaseMatrix & aa, const BaseMatrix & ac)
  {
    //  SetSymmetric();

    SetMatrix (aa);
    SetPrecond (ac);

    SetPrecision (1e-8);
    SetMaxSteps (200);
    SetInitialize (1);

    printrates = 0;         // silent by default
  }

  
  /// Vector creation is not provided at this level: always returns a null
  /// pointer, so callers must not dereference the result.
  /// (Delegation to a->CreateVector() is disabled.)
  BaseVector * KrylovSpaceSolver :: CreateVector () const
  {
    BaseVector * no_vector = NULL;
    return no_vector;
  }




  /*
    CGSolver::Mult -- preconditioned conjugate gradient method for A u = f.

    f : right-hand side
    u : start vector on entry (used only if initialize == 0, otherwise it
        is set to zero), approximate solution on exit

    Reads the members a (matrix), c (preconditioner, may be null), prec,
    stop_absolute, maxsteps, initialize and printrates.
    Stores the number of iterations performed in 'steps'.

    With d the residual and wdn = <C d, d>, the iteration stops when
      |wdn| <= prec^2              (stop_absolute)
      |wdn| <= prec^2 * |wdn_0|    (otherwise)
    or after maxsteps iterations, or when <A s, s> == 0 (breakdown).

    std::exception-derived errors are re-thrown as Exception; Exception
    objects get a note appended and are re-thrown unchanged in type.
  */
  template <class SCAL>
  void CGSolver<SCAL> :: Mult (const BaseVector & f, BaseVector & u) const
  {
    try
      {
	// Solve A u = f

	// The work vectors are owned by unique_ptr so they are released on
	// every exit path -- the former manual deletes were skipped when an
	// exception propagated.
	std::unique_ptr<BaseVector> d_owner (f.CreateVector());
	std::unique_ptr<BaseVector> w_owner (f.CreateVector());
	std::unique_ptr<BaseVector> s_owner (f.CreateVector());
	BaseVector & d = *d_owner;     // residual f - A u
	BaseVector & w = *w_owner;     // C d, and A s inside the loop
	BaseVector & s = *s_owner;     // search direction

	int n = 0;
	SCAL al, be, wd, wdn, kss;
	double err;

	if (initialize)
	  {
	    u = 0.0;
	    d = f;
	  }
	else
	  d = f - (*a) * u;

	if (c)
	  w = (*c) * d;
	else
	  w = d;

	s = w;

	wdn = S_InnerProduct<SCAL> (w,d);
	if (printrates) cout << "0 " << sqrt(Abs(wdn)) << endl;
	// A vanishing initial <C d, d> is replaced by 1, so the relative
	// tolerance becomes prec^2.  If d == 0 then s == 0 and the first
	// pass stops at the kss == 0 check below.
	if (wdn == 0.0) wdn = 1;

	if (stop_absolute)
	  err = prec * prec;
	else
	  err = prec * prec * Abs (wdn);

	// n counts the iterations actually started.  It is incremented at
	// the top of the body (not in the loop condition), so 'steps' is
	// exact after convergence and never exceeds maxsteps.
	while (n < maxsteps && Abs(wdn) > err)
	  {
	    n++;

	    w = (*a) * s;

	    wd = wdn;

	    kss = S_InnerProduct<SCAL> (w, s);
	    if (kss == 0.0) break;     // breakdown: <A s, s> == 0

	    al = wd / kss;

	    u += al * s;
	    d -= al * w;

	    if (c)
	      w = (*c) * d;
	    else
	      w = d;

	    wdn = S_InnerProduct<SCAL> (w, d);
	    be = wdn / wd;

	    // s = w + be * s
	    s *= be;
	    s += w;

	    if (printrates) cout << n << " " << sqrt (Abs (wdn)) << endl;
	  }

	const_cast<int&> (steps) = n;
      }

    // The more specific handler must come first: if Exception derives
    // from std::exception it would otherwise never be reached.
    catch (Exception & e)
      {
	e.Append ("in caught in CGSolver::Mult\n");
	throw;
      }
    catch (const exception & e)
      {
	throw Exception(e.what() +
			string ("\ncaught in CGSolver::Mult\n"));
      }
  }















//*****************************************************************
// Iterative template routine -- QMR
//
// QMRSolver::Mult solves the unsymmetric linear system A x = b using
// the Quasi-Minimal Residual method, following the algorithm as
// described on p. 24 of the SIAM Templates book.
//
// The function returns void; the outcome is stored in the member
// 'status':
//   -------------------------------------------------------------
//   status           indicates
//   ------------     ---------------------
//        0           convergence within maxsteps iterations
//        1           no convergence after maxsteps iterations
//                    breakdown in:
//        2             rho
//        3             beta
//        4             gamma
//        5             delta
//        6             ep
//        7             xi
//   -------------------------------------------------------------
//
// On return:
//
//        x  --  approximate solution to A x = b
//    steps  --  iteration counter of the QMR loop
//
// Convergence is declared when ||r||_2 / ||b||_2 <= prec, where
// ||b||_2 is replaced by 1 if b is zero.
//
//*****************************************************************



// QMRSolver::Mult -- solve A x = b with the Quasi-Minimal Residual method
// (SIAM Templates book, p. 24), optionally two-sided preconditioned.
//
// b : right-hand side
// x : start vector on entry (set to zero if initialize != 0),
//     approximate solution on exit
//
// Reads the members a, c and c2 (preconditioners, each may be null),
// prec, maxsteps, initialize and printrates.  In the Templates notation,
// c is applied as M1^{-1} and c2 as M2^{-1}.
// Writes 'steps' (iteration counter, 0 if the start vector already
// satisfies the criterion) and 'status': 0 converged, 1 no convergence,
// breakdown in 2 rho, 3 beta, 4 gamma, 5 delta, 6 ep, 7 xi.
template <class SCAL>
void QMRSolver<SCAL> :: Mult (const BaseVector & b, BaseVector & x) const
{
  try
    {
      cout << "QMR called" << endl;
      SCAL rho, rho_1, xi, gamma, gamma_1, theta, theta_1, eta, delta, ep, beta;

      // All work vectors are owned by 'pool' and released on every exit
      // path (convergence, breakdown, no convergence, exception).
      // Previously they were never deleted.
      std::vector<std::unique_ptr<BaseVector> > pool;
      pool.reserve (14);
      auto NewVector = [&b, &pool] () -> BaseVector &
	{
	  pool.push_back (std::unique_ptr<BaseVector> (b.CreateVector()));
	  return *pool.back();
	};

      BaseVector & r = NewVector();
      BaseVector & v_tld = NewVector();
      BaseVector & y = NewVector();
      BaseVector & w_tld = NewVector();
      BaseVector & z = NewVector();
      BaseVector & v = NewVector();
      BaseVector & w = NewVector();
      BaseVector & y_tld = NewVector();
      BaseVector & z_tld = NewVector();
      BaseVector & p = NewVector();
      BaseVector & q = NewVector();
      BaseVector & p_tld = NewVector();
      BaseVector & d = NewVector();
      BaseVector & s = NewVector();

      double normb = b.L2Norm();

      if (initialize)
	x = 0;

      r = b - (*a) * x;

      // relative criterion ||r|| / ||b||; plain ||r|| when b == 0
      if (normb == 0.0)
	normb = 1;

      // NOTE(review): this changes the precision of the global cout
      cout.precision(12);

      const double tol = prec;
      const int max_iter = maxsteps;

      // reset before the early exit so a previous solve's count is not
      // reported
      const_cast<int&> (steps) = 0;

      if (r.L2Norm() / normb <= tol)
	{
	  const_cast<int&> (status) = 0;
	  return;
	}

      v_tld = r;

      // use preconditioner c (M1^{-1})
      if (c)
	y = (*c) * v_tld;
      else
	y = v_tld;

      rho = y.L2Norm();

      w_tld = r;

      // transpose of preconditioner c2 (M2^{-T})
      if (c2)
	z = Transpose (*c2) * w_tld;
      else
	z = w_tld;

      xi = z.L2Norm();

      gamma = 1.0;
      eta = -1.0;
      theta = 0.0;

      for (int i = 1; i <= max_iter; i++)
	{
	  const_cast<int&> (steps) = i;

	  if (rho == 0.0)
	    {
	      (*testout) << "QMR: breakdown in rho" << endl;
	      const_cast<int&> (status) = 2;
	      return;                        // return on breakdown
	    }

	  if (xi == 0.0)
	    {
	      (*testout) << "QMR: breakdown in xi" << endl;
	      const_cast<int&> (status) = 7;
	      return;                        // return on breakdown
	    }

	  v = (1.0/rho) * v_tld;
	  y /= rho;

	  w = (1.0/xi) * w_tld;
	  z /= xi;

	  delta = S_InnerProduct<SCAL> (z, y);
	  if (delta == 0.0)
	    {
	      (*testout) << "QMR: breakdown in delta" << endl;
	      const_cast<int&> (status) = 5;
	      return;                        // return on breakdown
	    }

	  if (c2)
	    y_tld = (*c2) * y;
	  else
	    y_tld = y;

	  // transpose of preconditioner c (M1^{-T})
	  if (c)
	    z_tld = Transpose (*c) * z;
	  else
	    z_tld = z;

	  if (i > 1)
	    {
	      // p = y_tld - (xi * delta / ep) * p
	      // q = z_tld - (rho * delta / ep) * q     (ep of the previous step)
	      p *= (-xi * delta / ep);
	      p += y_tld;
	      q *= (-rho * delta / ep);
	      q += z_tld;
	    }
	  else
	    {
	      p = y_tld;
	      q = z_tld;
	    }

	  p_tld = (*a) * p;
	  ep = S_InnerProduct<SCAL> (q, p_tld);

	  if (ep == 0.0)
	    {
	      (*testout) << "QMR: breakdown in ep" << endl;
	      const_cast<int&> (status) = 6;
	      return;                        // return on breakdown
	    }

	  beta = ep / delta;
	  if (beta == 0.0)
	    {
	      (*testout) << "QMR: breakdown in beta" << endl;
	      const_cast<int&> (status) = 3;
	      return;                        // return on breakdown
	    }

	  v_tld = p_tld;
	  v_tld -= beta * v;

	  if (c)
	    y = (*c) * v_tld;
	  else
	    y = v_tld;

	  rho_1 = rho;
	  rho = y.L2Norm();

	  w_tld = Transpose(*a) * q;
	  w_tld -= beta * w;

	  if (c2)
	    z = Transpose (*c2) * w_tld;
	  else
	    z = w_tld;

	  xi = z.L2Norm();

	  gamma_1 = gamma;
	  theta_1 = theta;

	  // |beta| as in the Templates algorithm
	  theta = rho / (gamma_1 * Abs(beta));
	  gamma = 1.0 / sqrt(1.0 + theta * theta);

	  if (gamma == 0.0)
	    {
	      (*testout) << "QMR: breakdown in gamma" << endl;
	      const_cast<int&> (status) = 4;
	      return;                        // return on breakdown
	    }

	  eta = -eta * rho_1 * gamma * gamma /
	    (beta * gamma_1 * gamma_1);

	  if (i > 1)
	    {
	      // d = eta * p     + (theta_1 * gamma)^2 * d
	      // s = eta * p_tld + (theta_1 * gamma)^2 * s
	      d *= (theta_1 * theta_1 * gamma * gamma);
	      d += eta * p;
	      s *= (theta_1 * theta_1 * gamma * gamma);
	      s += eta * p_tld;
	    }
	  else
	    {
	      d = eta * p;
	      s = eta * p_tld;
	    }

	  x += d;
	  r -= s;

	  // compute the residual norm once for both output and the test
	  const double rnorm = r.L2Norm();
	  if (printrates) cout << i << " " << rnorm << endl;

	  if (rnorm / normb <= tol)
	    {
	      const_cast<int&> (status) = 0;
	      return;
	    }
	}

      // no convergence within max_iter iterations
      const_cast<int&> (status) = 1;
    }

  // The more specific handler must come first: if Exception derives
  // from std::exception it would otherwise never be reached.
  catch (Exception & e)
    {
      e.Append ("in caught in QMRSolver::Mult\n");
      throw;
    }
  catch (const exception & e)
    {
      throw Exception(e.what() +
		      string ("\ncaught in QMRSolver::Mult\n"));
    }
}
  

  
  // Explicit instantiations, so that the Mult definitions above are
  // emitted for real and complex scalars.
  template class CGSolver<double>;
  template class QMRSolver<double>;

  template class CGSolver<Complex>;
  template class QMRSolver<Complex>;


}
