#ifndef FILE_NGS_BASEMATRIX
#define FILE_NGS_BASEMATRIX


/*********************************************************************/
/* File:   basematrix.hpp                                            */
/* Author: Joachim Schoeberl                                         */
/* Date:   25. Mar. 2000                                             */
/*********************************************************************/

/**
   The base for all matrices in the linalg.
*/
class BaseMatrix
{
public:
  /// constructor
  BaseMatrix ();
  /// destructor
  virtual ~BaseMatrix ();
  
  /// virtual function must be overloaded
  virtual int VHeight() const;

  /// virtual function must be overloaded
  virtual int VWidth() const;

  /// inline function VHeight
  int Height() const
  {
    return VHeight();
  }
  
  /// inline function VWidth
  int Width() const
  {
    return VWidth();
  }

  /// scalar assignment
  BaseMatrix & operator= (double s)
  {
    AsVector().SetScalar(s);
    return *this;
  }

  /// linear access of matrix memory
  virtual BaseVector & AsVector();
  /// linear access of matrix memory
  virtual const BaseVector & AsVector() const;
  
  virtual ostream & Print (ostream & ost) const;
  virtual void MemoryUsage (ARRAY<MemoryUsageStruct*> & mu) const;

  // virtual const void * Data() const;
  // virtual void * Data();

  /// creates matrix of same type
  virtual BaseMatrix * CreateMatrix () const;
  /// creates a compativle vector, size = width
  virtual BaseVector * CreateRowVector () const;
  /// creates a compativle vector, size = height
  virtual BaseVector * CreateColVector () const;
  /// creates a fitting vector (for square matrices)
  virtual BaseVector * CreateVector () const;

  /// y = matrix * x. Multadd should be implemented, instead
  virtual void Mult (const BaseVector & x, BaseVector & y) const;
  /// y += s matrix * x
  virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const;
  /// y += s matrix * x
  virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const;
  
  /// y += s Trans(matrix) * x
  virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const;
  /// y += s Trans(matrix) * x
  virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const;
};



/**
   Fixes the scalar type SCAL of a matrix.
   The generic version only adds a constructor and a destructor on
   top of BaseMatrix (inherited virtually).
*/
template <class SCAL>
class S_BaseMatrix : virtual public BaseMatrix
{
public:
  /// constructor
  S_BaseMatrix ();
  /// virtual destructor
  virtual ~S_BaseMatrix ();
};


/**
   Specialization of S_BaseMatrix for the scalar type Complex.
   The double-scaled products are declared to forward to their
   Complex-scaled counterparts.
*/
template <>
class S_BaseMatrix<Complex> : virtual public BaseMatrix
{
public:
  /// constructor
  S_BaseMatrix ();
  /// virtual destructor
  virtual ~S_BaseMatrix ();

  /// y += s * matrix * x; calls MultAdd (Complex s)
  virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const;
  /// y += s * matrix * x; must be overloaded by derived classes
  virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const;
  
  /// y += s * Trans(matrix) * x; calls MultTransAdd (Complex s)
  virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const;
  /// y += s * Trans(matrix) * x; should be overloaded by derived classes
  virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const;
};







/* *************************** Matrix * Vector ******************** */


/// 
class VMatVecExpr
{
  const BaseMatrix & m;
  const BaseVector & x;
  
public:
  VMatVecExpr (const BaseMatrix & am, const BaseVector & ax) : m(am), x(ax) { ; }

  template <class TS>
  void AssignTo (TS s, BaseVector & v) const
  { 
    m.Mult (x, v);
    v *= s;
  }

  template <class TS>
  void AddTo (TS s, BaseVector & v) const
  { 
    m.MultAdd (s, x, v);
  }
};


/**
   BaseMatrix times BaseVector: builds an expression template.
   No arithmetic is done here. The wrapped VMatVecExpr refers to
   a and b, so both must stay alive until the expression is evaluated.
*/
inline VVecExpr<VMatVecExpr>
operator* (const BaseMatrix & a, const BaseVector & b)
{
  VMatVecExpr product (a, b);
  return VVecExpr<VMatVecExpr> (product);
}


/* ************************** Transpose ************************* */

/**
   Non-owning transposed view of a BaseMatrix.

   Products are forwarded to the wrapped matrix with the roles of
   MultAdd and MultTransAdd exchanged. The size queries and the
   row/column vector factories are swapped as well, so that Height(),
   Width() and the created vectors describe the transpose rather than
   falling back to the BaseMatrix defaults.

   Only a reference is stored: the wrapped matrix must outlive this view.
*/
class Transpose : public BaseMatrix
{
  const BaseMatrix & bm;
public:
  /// wraps abm without copying it
  Transpose (const BaseMatrix & abm) : bm(abm) { ; }

  /// rows of the transpose = columns of the wrapped matrix
  virtual int VHeight() const
  {
    return bm.VWidth();
  }
  /// columns of the transpose = rows of the wrapped matrix
  virtual int VWidth() const
  {
    return bm.VHeight();
  }

  /// vector of size Width(), i.e. the height of the wrapped matrix
  virtual BaseVector * CreateRowVector () const
  {
    return bm.CreateColVector();
  }
  /// vector of size Height(), i.e. the width of the wrapped matrix
  virtual BaseVector * CreateColVector () const
  {
    return bm.CreateRowVector();
  }

  /// y += s * Trans(bm) * x
  virtual void MultAdd (double s, const BaseVector & x, BaseVector & y) const
  {
    bm.MultTransAdd (s, x, y);
  }
  /// y += s * Trans(bm) * x
  virtual void MultAdd (Complex s, const BaseVector & x, BaseVector & y) const 
  {
    bm.MultTransAdd (s, x, y);
  }
  /// y += s * bm * x  (transpose of the transpose)
  virtual void MultTransAdd (double s, const BaseVector & x, BaseVector & y) const
  {
    bm.MultAdd (s, x, y);
  }
  /// y += s * bm * x  (transpose of the transpose)
  virtual void MultTransAdd (Complex s, const BaseVector & x, BaseVector & y) const
  {
    bm.MultAdd (s, x, y);
  }  
};

/* *********************** operator<< ********************** */

/// output operator for matrices: delegates to the virtual Print,
/// so the concrete matrix type decides the output format
inline ostream & operator<< (ostream & out, const BaseMatrix & mat)
{
  ostream & result = mat.Print (out);
  return result;
}


#endif
