/*      Copyright (C) 2001, 2002, 2003, 2004 Stijn van Dongen
 *
 * This file is part of MCL.  You can redistribute and/or modify MCL under the
 * terms of the GNU General Public License; either version 2 of the License or
 * (at your option) any later version.  You should have received a copy of the
 * GPL along with MCL, in the file COPYING.
*/


#include <float.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "impala/compose.h"
#include "impala/matrix.h"
#include "impala/vector.h"
#include "impala/pval.h"
#include "impala/io.h"
#include "impala/iface.h"

#include "util/io.h"
#include "util/err.h"
#include "util/minmax.h"
#include "util/opt.h"
#include "util/types.h"

/* Program name, used as the prefix of diagnostics and in the usage text. */
const char *me = "mcxarray";

/* Help text printed for -h, one line per entry; the list ends with NULL.
 * Every option accepted by the argument parser in main() is listed here.
*/

const char* usagelines[] =
{  "Usage: mcxarray [options] <array data matrix>"
,  ""
,  "Options:"
,  "-h               print this help"
,  "-o <fname>       write to file fname"
,  "-co <num>        remove similarity (output) values smaller than num"
,  "-cutoff <num>    same as -co"
,  "-gq <num>        ignore data (input) values smaller than num"
,  "-lq <num>        ignore data (input) values larger than num"
,  "-t               work with the transpose"
,  "-ctr <num>       add center value (for graph-type input)"
,  "--ctr            add center with default value (for graph-type input)"
,  "--01             remap to [0,1] interval"
,  "--cosine         work with the cosine"
,  "--pearson        work with Pearson correlation score [default]"
,  "-tear <num>      inflate the input columns"
,  "-teartp <num>    inflate the transposed columns"
,  "-pi <num>        inflate the result"
,  "-digits <num>    number of digits used when writing output values"
,  NULL
}  ;


int main
(  int                  argc
,  const char*          argv[]
)
   {  int a = 1, c, d
   ;  int digits = MCLXIO_VALUE_GETENV
   ;  double cutoff = -1.0, tear = 0.0, teartp = 0.0, pi = 0.0
   ;  double lq = DBL_MAX, gq = -DBL_MAX, ctr = 0.0
   ;  mclx* tbl, *res
   ;  mcxIO* xfin, *xfout
   ;  mclv* ssqs, *sums, *scratch
   ;  mcxbool transpose = FALSE, to01 = FALSE
   ;  const char* out = "out.array"
   ;  mcxbool mode = 'p'
   ;  int n_mod

#define ARGCOK if (a++ + 1 >= argc) goto arg_missing

   ;  while (a<argc)
      {  if (!strcmp(argv[a], "-h"))
         {  mcxUsage(stdout, me, usagelines)
         ;  return 0
      ;  }
         else if (!strcmp(argv[a], "-t"))
         transpose = TRUE
      ;  else if (!strcmp(argv[a], "--pearson"))
         mode = 'p'
      ;  else if (!strcmp(argv[a], "--cosine"))
         mode = 'c'
      ;  else if (!strcmp(argv[a], "--01"))
         to01 = TRUE
      ;  else if (!strcmp(argv[a], "--ctr"))
         ctr = 1.0
      ;  else if (!strcmp(argv[a], "-cutoff") || !strcmp(argv[a], "-co"))
         {  ARGCOK
         ;  cutoff = atof(argv[a])
      ;  }
         else if (!strcmp(argv[a], "-ctr"))
         {  ARGCOK
         ;  ctr = atof(argv[a])
      ;  }
         else if (!strcmp(argv[a], "-lq"))
         {  ARGCOK
         ;  lq = atof(argv[a])
      ;  }
         else if (!strcmp(argv[a], "-gq"))
         {  ARGCOK
         ;  gq = atof(argv[a])
      ;  }
         else if (!strcmp(argv[a], "-o"))
         {  ARGCOK
         ;  out = argv[a]
      ;  }
         else if (!strcmp(argv[a], "-pi"))
         {  ARGCOK
         ;  pi = atof(argv[a])
      ;  }
         else if (!strcmp(argv[a], "-tear"))
         {  ARGCOK
         ;  tear = atof(argv[a])
      ;  }
         else if (!strcmp(argv[a], "-teartp"))
         {  ARGCOK
         ;  teartp = atof(argv[a])
      ;  }
         else if (!strcmp(argv[a], "-digits"))
         {  ARGCOK
         ;  digits = strtol(argv[a], NULL, 10)
      ;  }
         else if (0)
         {  arg_missing
         :  mcxTell(me, "flag <%s> needs argument; see help (-h)", argv[argc-1])
         ;  mcxExit(1)
      ;  }
         else if (a == argc-1)
         break
      ;  else
         {  mcxErr(me, "not an option: <%s>", argv[a])
         ;  return 1
      ;  }
         a++
   ;  }

      xfin = mcxIOnew(argv[argc-1], "r")
   ;  xfout = mcxIOnew(out, "w")
   ;  mcxIOopen(xfin, EXIT_ON_FAIL)
   ;  mcxIOopen(xfout, EXIT_ON_FAIL)
   ;  tbl = mclxRead(xfin, EXIT_ON_FAIL)

   ;  if (lq < DBL_MAX)
      {  double mass = mclxMass(tbl)
      ;  double kept = mclxSelectValues(tbl, NULL, &lq, MCLX_LQ)
      ;  fprintf(stderr, "orig %.2f kept %.2f\n", mass, kept)
   ;  }

      if (gq > -DBL_MAX)
      {  double mass = mclxMass(tbl)
      ;  double kept = mclxSelectValues(tbl, &gq, NULL, MCLX_GQ)
      ;  fprintf(stderr, "orig %.2f kept %.2f\n", mass, kept)
   ;  }

      if (ctr && MCLD_EQUAL(tbl->dom_cols, tbl->dom_rows))
      mcxDie(1, me, "-ctr option disabled (mclxCenter needs inspection)")

   ;  if (tear)
      mclxInflate(tbl, tear)

   ;  if (transpose)
      {  mclx* tblt = mclxTranspose(tbl)
      ;  mclxFree(&tbl)
      ;  tbl = tblt
      ;  if (teartp)
         mclxInflate(tbl, teartp)
   ;  }

      ssqs = mclvCopy(NULL, tbl->dom_cols)
   ;  sums = mclvCopy(NULL, tbl->dom_cols)
   ;  scratch = mclvCopy(NULL, tbl->dom_cols)

   ;  for (c=0;c<N_COLS(tbl);c++)
      {  double sumsq = mclvPowSum(tbl->cols+c, 2.0)
      ;  double sum = mclvSum(tbl->cols+c)
      ;  ssqs->ivps[c].val = sumsq
      ;  sums->ivps[c].val = sum
   ;  }

      res   =
      mclxAllocZero
      (  mclvCopy(NULL, tbl->dom_cols)
      ,  mclvCopy(NULL, tbl->dom_cols)
      )

   ;  n_mod =  MAX(1+(N_COLS(tbl)-1)/40, 1)

   ;  {  double N  = MAX(N_ROWS(tbl), 1)

      ;  for (c=0;c<N_COLS(tbl);c++)
         {  mclvZeroValues(scratch)
         ;  for (d=c;d<N_COLS(tbl);d++)
            {  double ip = mclvIn(tbl->cols+c, tbl->cols+d)
            ;  double score = 0.0
            ;  if (mode == 'c')
               {  double nom = sqrt(ssqs->ivps[c].val  * ssqs->ivps[d].val)
               ;  score = nom ? ip / nom : 0.0
            ;  }
               else if (mode == 'p')
               {  double s1 = sums->ivps[c].val
               ;  double sq1= ssqs->ivps[c].val
               ;  double s2 = sums->ivps[d].val
               ;  double sq2= ssqs->ivps[d].val
               ;  double nom= sqrt((sq1 - s1*s1/N) * (sq2 - s2*s2/N))

               ;  double num= ip - s1*s2/N
               ;  double f1 = sq1 - s1*s1/N
               ;  double f2 = sq2 - s2*s2/N
               ;  score = nom ? (num / nom) : 0.0
;if (0) fprintf(stderr, "--%.2f %.2f %.2f %.2f %.2f %.2f %.2f %.2f\n", s1, sq1, s2, sq2, f1, f2, num, nom)
            ;  }
               if (score >= cutoff)
               scratch->ivps[d].val = score
         ;  }
            mclvAdd(scratch, res->cols+c, res->cols+c)
         ;  if ((c+1) % n_mod == 0)
               fputc('.', stderr)
            ,  fflush(NULL)
      ;  }
      }

      mclxAddTranspose(res, 0.5)

   ;  if (to01)
      {  mclx* halves
         =  mclxCartesian
            (  mclvCopy(NULL, res->dom_cols)
            ,  mclvCopy(NULL, res->dom_rows)
            ,  0.5
            )
      ;  mclxScale(res, 0.5)
      ;  mclxMerge(res, halves, fltAdd)
      ;  mclxFree(&halves)
   ;  }

      if (pi)
      mclxInflate(res, pi)

   ;  mclxWrite(res, xfout, digits, EXIT_ON_FAIL)
   ;  return 0
;  }

