/*------------------------------------------------------
 Maximum likelihood estimation
 of migration rate and effective population size
 using a Metropolis-Hastings Monte Carlo algorithm
 -------------------------------------------------------                        
 R e p o r t e r   R O U T I N E S 

 Reports progress information when progress is TRUE or verbose
 output is requested.

 Code is being moved here from world.c; the transfer is not yet complete.
                                                                                                               
 Peter Beerli 1999, Seattle
 beerli@genetics.washington.edu
 $Id: reporter.c,v 1.10 2000/04/04 18:39:44 beerli Exp $

-------------------------------------------------------*/

#include "migration.h"
#include "mcmc.h"

#include "fst.h"
#include "random.h"
#include "tools.h"
#include "broyden.h"
#include "combroyden.h"
#include "options.h"

#ifdef DMALLOC_FUNC_CHECK
#include <dmalloc.h>
#endif


/* Helpers for the Gelman-Rubin style convergence diagnostic computed by
 * convergence_check().  All mean / variance / R arrays they exchange share
 * one layout: kt means (numpop entries), km means (numpop entries),
 * p means (numpop entries), then mindex means (numpop2 entries). */
void both_chain_means (double *mc, double *lc, double *tc, long len, long lastn,
		       long n);
void calc_gelmanb (double *gelmanb, double *mc, double *tc, double *lc,
		   long len, long lastn, long n);
void calc_gelmanw (double *gelmanw, world_fmt * world, double *mc, double *tc,
		   long len, long lastn, long n);
void calc_gelmanr (double *gelmanr, double *gelmanw, double *gelmanb, long len,
		   long lastn, long n);
void calc_average_biggest_gelmanr (double *gelmanr, long len, double *meanR,
				   double *bigR);
void print_gelmanr (double average, double biggest);
double calc_s (long this, double *tc, world_fmt * world);
void chain_means (double *thischainmeans, world_fmt * world);


//public function
/* Gelman-Rubin style convergence diagnostic between consecutive chains.
 *
 * Does nothing unless there are at least two chains, the run uses
 * SINGLECHAIN replication, and either `progress` or the gelman option is on.
 * The first call after a (re)start (world->start set) only records the
 * current chain's means and tree count as the reference.  Every later call
 * compares the current chain with the reference, stores the average and the
 * largest R in world->gelmanmeanR / world->gelmanmaxR, and then makes the
 * current chain the new reference.
 * NOTE(review): presumably called once after each finished chain; the caller
 * is not visible here.
 */
void
convergence_check (world_fmt * world, boolean progress)
{
  /* Work arrays: allocated on first use and kept for the rest of the run. */
  static double *lastchainmeans, *chainmeans, *thischainmeans;
  static double *gelmanw, *gelmanb, *gelmanr;

  static boolean done = FALSE;	/* work arrays allocated? */
  static long len = 0;		/* length of every work array */
  static long lastn = 0;	/* number of trees in the reference chain */
  static boolean first = TRUE;	/* next call only records a reference */
  long n = 0;

  if (world->chains < 2 || world->repkind != SINGLECHAIN)
    return;

  if (world->start)
    first = TRUE;
  if (!(progress || world->options->gelman))
    return;

  if (!done)
    {
      /* len is the number of tracked quantities: kt, km and p means
         (numpop each) plus mindex means (numpop2). */
      len = world->numpop2 + world->numpop * 3;
      lastchainmeans = calloc (len, sizeof (double));
      thischainmeans = calloc (len, sizeof (double));
      chainmeans = calloc (len, sizeof (double));
      gelmanw = calloc (len, sizeof (double));
      gelmanb = calloc (len, sizeof (double));
      gelmanr = calloc (len, sizeof (double));
      if (lastchainmeans == NULL || thischainmeans == NULL
          || chainmeans == NULL || gelmanw == NULL || gelmanb == NULL
          || gelmanr == NULL)
        {
          /* Release whatever succeeded (free(NULL) is a no-op) and retry on
             a later call instead of writing through a NULL pointer. */
          free (lastchainmeans);
          free (thischainmeans);
          free (chainmeans);
          free (gelmanw);
          free (gelmanb);
          free (gelmanr);
          lastchainmeans = thischainmeans = chainmeans = NULL;
          gelmanw = gelmanb = gelmanr = NULL;
          fprintf (stderr,
                   "Warning: not enough memory to compute Gelman's R, convergence check skipped\n");
          return;
        }
      done = TRUE;
    }

  n = world->atl[world->rep][world->locus].T;

  if (first)
    {
      first = FALSE;
      /* Clear stale means left over from before a restart:
         chain_means() accumulates into its output array. */
      memset (lastchainmeans, 0, sizeof (double) * len);
      chain_means (lastchainmeans, world);
      /* Remember the reference chain's length; otherwise it would get
         weight 0 in the pooled mean of the first comparison. */
      lastn = n;
      return;
    }

  memset (thischainmeans, 0, sizeof (double) * len);
  chain_means (thischainmeans, world);
  both_chain_means (chainmeans, lastchainmeans, thischainmeans, len,
                    lastn, n);
  calc_gelmanb (gelmanb, chainmeans, thischainmeans, lastchainmeans,
                len, lastn, n);
  calc_gelmanw (gelmanw, world, thischainmeans, lastchainmeans, len,
                lastn, n);
  calc_gelmanr (gelmanr, gelmanw, gelmanb, len, lastn, n);
  calc_average_biggest_gelmanr (gelmanr, len, &world->gelmanmeanR,
                                &world->gelmanmaxR);
  /* The current chain becomes the reference for the next comparison. */
  memcpy (lastchainmeans, thischainmeans, sizeof (double) * len);
  lastn = n;
}

/* Pool the per-quantity means of two chains into mc.
 *
 * mc[i] = (lc[i] * lastn + tc[i] * n) / (lastn + n): the previous chain's
 * means lc are weighted by its tree count lastn, and the current chain's
 * means tc by n.  If both weights are zero, the plain average of lc and tc
 * is used instead of dividing by zero.
 */
void
both_chain_means (double *mc, double *lc, double *tc, long len, long lastn,
                  long n)
{
  long i;
  const long total = n + lastn;

  for (i = 0; i < len; i++)
    {
      if (total != 0)
        mc[i] = (lc[i] * lastn + tc[i] * n) / total;
      else
        mc[i] = 0.5 * (lc[i] + tc[i]);
    }
}


/* Between-chain variance B for each tracked quantity, with two chains:
 * B[i] = nn * ((lc[i] - mc[i])^2 + (tc[i] - mc[i])^2), where nn is the
 * average chain length (n + lastn) / 2 and mc holds the pooled means.
 * Squares are computed by multiplication rather than pow(): this gives the
 * correctly rounded result without a libm call.
 */
void
calc_gelmanb (double *gelmanb, double *mc, double *tc, double *lc, long len,
              long lastn, long n)
{
  long i;
  const double nn = (n + lastn) / 2.;

  for (i = 0; i < len; i++)
    {
      const double dlast = lc[i] - mc[i];
      const double dthis = tc[i] - mc[i];
      gelmanb[i] = nn * (dlast * dlast + dthis * dthis);
    }
}

/* Within-chain variance W for every tracked quantity k: the mean of the two
 * sample variances that calc_s() returns when the current chain's trees are
 * centred on tc and on mc, respectively.  lastn and n are part of the
 * shared calc_gelman* signature but are not used here. */
void
calc_gelmanw (double *gelmanw, world_fmt * world, double *mc, double *tc,
              long len, long lastn, long n)
{
  long k;

  (void) lastn;
  (void) n;
  for (k = 0; k < len; k++)
    {
      const double var_about_tc = calc_s (k, tc, world);
      const double var_about_mc = calc_s (k, mc, world);
      gelmanw[k] = (var_about_tc + var_about_mc) * 0.5;
    }
}


void
calc_gelmanr (double *gelmanr, double *gelmanw, double *gelmanb, long len,
	      long lastn, long n)
{
  long i;
  double nn = (n + lastn) / 2.;

  for (i = 0; i < len; i++)
    {
      gelmanr[i] =
	sqrt (((nn - 1.) / nn * gelmanw[i] + 1. / nn * gelmanb[i]) /
	      gelmanw[i]);
    }
}

/* Summarise the per-quantity R values: *meanR receives their average and
 * *bigR the largest one.  Entries that are not strictly positive (0 used as
 * "undefined", or NaN from a degenerate variance) are skipped, so a single
 * constant quantity cannot turn the average into NaN.  If no entry is
 * usable (including len <= 0), both outputs are 0.
 */
void
calc_average_biggest_gelmanr (double *gelmanr, long len, double *meanR,
                              double *bigR)
{
  long i;
  long used = 0;
  double sum = 0.;
  double biggest = 0.;

  for (i = 0; i < len; i++)
    {
      /* !(x > 0) is also true for NaN, so NaN entries are skipped too. */
      if (!(gelmanr[i] > 0.))
        continue;
      if (biggest < gelmanr[i])
        biggest = gelmanr[i];
      sum += gelmanr[i];
      used++;
    }
  *meanR = (used > 0) ? sum / used : 0.;
  *bigR = biggest;
}

/* Write the two-line Gelman's R summary (average and largest value) to
 * standard output. */
void
print_gelmanr (double average, double biggest)
{
  printf ("           Average Gelman's R = %f\n", average);
  printf ("           Largest Gelman's R = %f [Value < 1.2 show convergence]\n",
          biggest);
}


/* Sample variance (divisor T - 1) of one tracked quantity over the trees of
 * the current chain, world->atl[world->rep][world->locus], taken around the
 * supplied mean tc[this].
 *
 * `this` indexes the combined layout shared with chain_means():
 *   [0, numpop)              kt[this]
 *   [numpop, 2*numpop)       km[this - numpop]
 *   [2*numpop, 3*numpop)     p[this - 2*numpop]
 *   [3*numpop, ...)          mindex[this - 3*numpop]
 * The mean for that quantity is therefore always tc[this].  The earlier code
 * used tc[offset] for the km/p/mindex branches, which compared them with the
 * kt mean.
 * Returns 0 when fewer than two trees are stored, because the sample
 * variance is undefined there.
 */
double
calc_s (long this, double *tc, world_fmt * world)
{
  const long rep = world->rep;
  const long locus = world->locus;
  const long ntrees = world->atl[rep][locus].T;
  const long startkm = world->numpop;
  const long startp = startkm + world->numpop;
  const long startl = startp + world->numpop;
  const double mean = tc[this];
  double s = 0.;
  double diff;
  long j;

  if (ntrees < 2)
    return 0.;

  for (j = 0; j < ntrees; j++)
    {
      if (this < startkm)
        diff = world->atl[rep][locus].tl[j].kt[this] - mean;
      else if (this < startp)
        diff = world->atl[rep][locus].tl[j].km[this - startkm] - mean;
      else if (this < startl)
        diff = world->atl[rep][locus].tl[j].p[this - startp] - mean;
      else
        diff = world->atl[rep][locus].tl[j].mindex[this - startl] - mean;
      s += diff * diff;
    }
  return s / (ntrees - 1.);
}



/* Per-quantity means over all trees of the current chain,
 * world->atl[world->rep][world->locus], written to thischainmeans in the
 * layout [kt | km | p] (numpop entries each) followed by mindex (numpop2
 * entries).
 *
 * The output range is cleared first, so callers need not pre-zero it.
 * The earlier code accumulated into whatever the array held.  If the
 * chain holds no trees, the means are left at 0 instead of dividing by zero.
 * The offsets are recomputed on every call rather than cached in statics,
 * so they can never go stale.
 */
void
chain_means (double *thischainmeans, world_fmt * world)
{
  const long rep = world->rep;
  const long locus = world->locus;
  const long ntrees = world->atl[rep][locus].T;
  const long startkm = world->numpop;
  const long startp = startkm + world->numpop;
  const long startl = startp + world->numpop;
  const long total = startl + world->numpop2;
  long i, j;

  for (i = 0; i < total; i++)
    thischainmeans[i] = 0.;
  if (ntrees <= 0)
    return;

  for (j = 0; j < ntrees; j++)
    {
      for (i = 0; i < world->numpop; i++)
        {
          thischainmeans[i] += world->atl[rep][locus].tl[j].kt[i];
          thischainmeans[i + startkm] += world->atl[rep][locus].tl[j].km[i];
          thischainmeans[i + startp] += world->atl[rep][locus].tl[j].p[i];
        }
      for (i = 0; i < world->numpop2; i++)
        thischainmeans[i + startl] += world->atl[rep][locus].tl[j].mindex[i];
    }

  /* The four sub-ranges together cover [0, total), so divide them all. */
  for (i = 0; i < total; i++)
    thischainmeans[i] /= ntrees;
}
