/*------------------------------------------------------
 Maximum likelihood estimation 
 of migration rate and effective population size
 using a Metropolis-Hastings Monte Carlo algorithm                            
 -------------------------------------------------------                        
 P A R A M E T E R   E S T I M A T I O N   R O U T I N E S 

 estimates parameters for each locus
 using a Newton-Raphson maximization
 

 Peter Beerli 1996, Seattle
 beerli@genetics.washington.edu
 $Id: parameter.c,v 1.1.1.1 1998/06/06 06:09:51 beerli Exp $
-------------------------------------------------------*/

#include "migration.h"
#include "world.h"

#ifdef DMALLOC_FUNC_CHECK
#include "dmalloc.h"
#endif

/* prototypes ------------------------------------------- */
/* entry point: dispatches on world->numpop and forwards the
   two-population case to solveParameters() */
void estimateParameter (timearchive_fmt * tyme, long G,
       world_fmt * world, double **dd, long chain, char type, char **plane);
/* returns log(P(G|param)) for a single genealogy summary */
double probG (double *param, tarchive_fmt * tl, long numpop);
/* fills nr->apg, nr->apg_max, nr->PGC and returns nr->llike */
double calc_like (nr_fmt * nr, tarchive_fmt * tyme, long G);
/* rescales nr->d and nr->dd to derivatives with respect to log(parameter) */
void derivatives_to_logderivatives (nr_fmt * nr);
/* rescales dd by the parameters, mirrors it and inverts it in place
   unless is_singular() reports it as singular */
void calc_cov (double **dd, double *d, double *param, long n);
/* private functions */
/* accumulates first (nr->d) and second (nr->dd, lower triangle)
   derivatives over all G genealogies */
void derivatives (long trials, nr_fmt * nr, tarchive_fmt * tl, long G,
		  double *param, boolean forloci);
/* damped Newton-Raphson maximization for the two-population case */
void solveParameters (tarchive_fmt * tyme, long G,
       world_fmt * world, double **dd, long chain, char type, char **plane);
/* zeroes nr->d, nr->dd and the first 2 * numpop2 entries of nr->parts */
void reset_nr (nr_fmt * nr);
/* releases the arrays owned by nr and nr itself (nr->dd is not freed) */
void free_nr (nr_fmt * nr);
/* TRUE if a row of dd sums to exactly zero */
boolean is_singular (double **dd, long n);
/* NOTE(review): no definition of param_adjust is visible in this part
   of the file -- confirm whether it is still used */
void param_adjust (double *value, double oldval, double min, double max);
/* shortens the step from param to nr->param so that nr->param stays
   inside the SMALLEST_xxx / BIGGEST_xxx limits */
void param_all_adjust (nr_fmt * nr, double *param, long gamma_param);
/* calculate and adjust the new parameterset */
void calc_param (nr_fmt * nr, double *param, double lamda);

/* finds the biggest value in the vector */
double vector_max (double *v, long size);
/* calculate the norm sqrt(sum(v*v)) */
double norm (double *d, long size);
/* prints a histogram of the relative log weights nr->apg to stdout */
void print_contribution (nr_fmt * nr, tarchive_fmt * tyme, long G);


/*=======================================================*/

/* Entry point of the parameter estimation.
   Only the two-population case is implemented: it is forwarded to
   solveParameters() using the genealogy archive tyme[0].tl.  Every
   other population count is reported through error(). */
void 
estimateParameter (timearchive_fmt * tyme, long G,
	world_fmt * world, double **dd, long chain, char type, char **plane)
{
  if (world->numpop == 2)
    {
      /* estimate new theta and m values */
      solveParameters (tyme[0].tl, G, world, dd, chain, type, plane);
    }
  else if (world->numpop == 1)
    {
      error ("The 1 world case with migration is not yet implemented!\n");
    }
  else
    {
      error ("Multiworld estimators are not implemented yet! But will come!\n");
    }
}

/*=======================================================*/


/* Parameter calculation with a specified interval sampling; not only
   the changed trees are sampled.

                1          Prob(G|Param)
   L(Param) =  ---  Sum[ ---------------- , G]
                G          Prob(G|Param0)

   The parameters are found with a damped Newton-Raphson procedure;
   the steps are taken on the log scale of the parameters (see
   derivatives_to_logderivatives() and calc_param()).

   tyme   array of G genealogy summaries
   G      number of entries in tyme and in world->likelihood
   world  world->param0 supplies the start values and receives the
          estimates, world->param_like receives the final log
          likelihood; world->options controls reporting and plotting
   dd     matrix used for the second derivatives (nr->dd); it is handed
          to calc_cov() at the end unless world->options->simulation is set
   chain, type   only forwarded to print_menu_chain()
   plane         only forwarded to create_locus_plot()

   NOTE(review): none of the allocations below is checked for failure.
   NOTE(review): invert_matrix(), nrcheck() and calc_cov() receive the
   literal 4 while the other loops use nr->numpop2; the two agree only
   for two populations, which is the only case estimateParameter()
   forwards here.
 */
void 
solveParameters (tarchive_fmt * tyme, long G,
	world_fmt * world, double **dd, long chain, char type, char **plane)
{

  boolean notinverse;
  /* NOTE(review): strllike is allocated and freed but never used, and
     kind is only written (strcpy) and never read in this function */
  char *strllike, kind[20];
  long pop = -99, g, trials = -1, savety_belt;
  double **idd, *nld, lam1, lam2, lamda = 1., nld2, llike = -DBL_MAX, normd;
  nr_fmt *nr;

  /* set up the Newton-Raphson work structure */
  nr = (nr_fmt *) calloc (1, sizeof (nr_fmt) * 1);
  nr->numpop = world->numpop;
  nr->numpop2 = world->numpop * 2;
  nr->partsize = (4 + nr->numpop2 * nr->numpop2);

  strllike = (char *) calloc (1, sizeof (char) * 128);
  nr->parts = (double *) calloc (1, nr->partsize * sizeof (double));
  nr->d = (double *) malloc (nr->numpop2 * sizeof (double));
  nr->param = (double *) malloc (nr->numpop2 * sizeof (double));
  nr->oparam = (double *) malloc (nr->numpop2 * sizeof (double));
  nr->datalike = (double *) malloc (G * sizeof (double));
  nr->apg0 = (double *) malloc (G * sizeof (double));
  nr->apg = (double *) malloc (G * sizeof (double));
  /* nld keeps one step norm per trial; trials runs from 0 to NTRIALS */
  nld = (double *) malloc ((1 + NTRIALS) * sizeof (double));
  /* the caller's matrix is used directly, it is not copied */
  nr->dd = dd;
  /* idd: numpop2 x numpop2 scratch matrix stored in one contiguous
     block, idd[pop] points at the start of row pop */
  idd = (double **) calloc (1, sizeof (double *) * nr->numpop2);
  idd[0] = (double *) calloc (1, sizeof (double) * nr->numpop2 * nr->numpop2);
  for (pop = 1; pop < nr->numpop2; pop++)
    {
      idd[pop] = idd[0] + pop * nr->numpop2;
    }
  memcpy (nr->param, world->param0, sizeof (double) * nr->numpop2);
  memcpy (nr->datalike, world->likelihood, sizeof (double) * G);

  /* Prob(G|Param0) */
  for (g = 0; g < G; g++)
    {
      nr->apg0[g] = probG (world->param0, &tyme[g], nr->numpop);
    }

  /* Newton Raphson loop */
  /* trials starts at -1, so the body sees trials = 0 ... NTRIALS */
  while (trials++ < NTRIALS)
    {
      reset_nr (nr);
      /* the likelihood of the start values is needed only once; later
         iterations reuse the value of the previous step */
      if (trials == 0)
	calc_like (nr, tyme, G);

      derivatives (trials, nr, tyme, G, world->param0, 0);
      derivatives_to_logderivatives (nr);
      normd = norm (nr->d, nr->numpop2);
      /* a gradient of exactly zero ends the search */
      if (normd == 0.0)
	break;
      notinverse = is_singular (nr->dd, nr->numpop2);
      if (!notinverse)
	{
	  /* invert a copy so that nr->dd itself stays untouched */
	  for (pop = 0; pop < nr->numpop2; pop++)
	    memcpy (idd[pop], nr->dd[pop], sizeof (double) * nr->numpop2);
	  invert_matrix (idd, 4);
	  if (nrcheck (nr->dd, idd, nr->d, 4, &lam1, &lam2, TRUE))
	    {
	      /* nrcheck calculates as a sideeffect 
	         the change in the newton case: 
	         d is already idd * gradient */
	      strcpy (kind, "NEWTON: ");
	      lamda = 1.;
	      nld[trials] = 1.0;
	    }
	  else
	    /* presumably nrcheck() has filled lam1 and lam2 for the
	       fallback below -- TODO confirm, they are not set here */
	    notinverse = TRUE;
	}
      else
	{
	  lam1 = lam2 = 1.0;
	}
      /* fallback when no Newton step is taken: the step length lamda
         is derived from lam1/lam2 or from earlier step norms */
      if (notinverse)
	{
	  if (lam2 > 0)
	    {
	      lamda = lam1 / lam2;
	      if (lamda >= 1.0)
		nld[trials] = normd;
	      else
		nld[trials] = normd * (lamda);
	    }
	  else
	    {
	      if (trials == 0)
		{
		  lamda = 1.0;
		  nld[0] = normd;
		}
	      else
		{
		  nld2 = normd;
		  /* NOTE(review): normd == 0.0 already left the loop
		     above, so this branch looks unreachable */
		  if (nld2 == 0.0)
		    {
		      nld[trials] = 0.0;
		      lamda = 0.0;
		      fprintf (stderr, "norm(d) is 0.0 we should stop!");
		    }
		  else
		    {
		      /* largest step norm of the earlier trials
		         relative to the current norm */
		      lamda = vector_max (nld, trials) / nld2;
		      if (lamda >= 1.0)
			nld[trials] = normd;
		      else
			nld[trials] = nld2 * (lamda);
		    }
		}
	    }
	  strcpy (kind, "SLOW:   ");
	}
      /* take the step from world->param0 and evaluate it */
      calc_param (nr, world->param0, lamda);
      nr->ollike = nr->llike;
      calc_like (nr, tyme, G);
      savety_belt = 0;
      memcpy (nr->oparam, nr->param, nr->numpop2 * sizeof (double));
      if (nr->ollike > nr->llike)
	{			/* halving if the new likelihood is worse than the old */
	  /* at most 50 halvings; every attempt restarts from world->param0 */
	  while (nr->llike - nr->ollike < -EPSILON && savety_belt++ < 50)
	    {
	      memcpy (nr->oparam, world->param0, nr->numpop2 * sizeof (double));
	      lamda /= 2.;
	      calc_param (nr, world->param0, lamda);
	      calc_like (nr, tyme, G);
	    }
	  /* savety_belt reaches 51 only when the limit ended the loop */
	  if (savety_belt > 50)
	    {
	      fprintf (stderr, "halfing limit reached!\n");
	    }
	  memcpy (world->param0, nr->param, nr->numpop2 * sizeof (double));

	}
      else
	{			/* doubling if the new likelihood is better than the old */
	  /* before every doubling the current state (parameters,
	     likelihood, PGC) is saved so that it can be restored */
	  while (nr->llike - nr->ollike > EPSILON && savety_belt++ < 50)
	    {
	      memcpy (nr->oparam, nr->param, nr->numpop2 * sizeof (double));
	      lamda *= 2.;
	      calc_param (nr, world->param0, lamda);
	      llike = nr->ollike;
	      nr->ollike = nr->llike;
	      nr->oPGC = nr->PGC;
	      calc_like (nr, tyme, G);
	    }
	  if (savety_belt > 0)
	    {
	      /* at least one doubling was tried: go back to the state
	         saved before the last doubling */
	      memcpy (world->param0, nr->oparam, nr->numpop2 * sizeof (double));
	      memcpy (nr->param, nr->oparam, nr->numpop2 * sizeof (double));
	      nr->llike = nr->ollike;
	      nr->ollike = llike;
	      nr->PGC = nr->oPGC;
	    }
	  else
	    {
	      /* no doubling was tried: accept the step as it is */
	      calc_param (nr, world->param0, lamda);
	      memcpy (world->param0, nr->param, nr->numpop2 * sizeof (double));
	    }
	}
      /* stop once the gradient norm is at most 0.001 or the trial
         limit is reached; the first iteration never stops here */
      if (!((((normd > 0.001) && (trials < NTRIALS))) || trials == 0))
	{
	  break;
	}
    }
  /* publish the result */
  memcpy (world->param0, nr->param, nr->numpop2 * sizeof (double));
  llike = nr->llike;
  world->param_like = nr->llike;
  if (world->options->progress)
    {
      print_menu_chain (type, chain, G, world);
      if (world->options->verbose)
	{
	  print_contribution (nr, tyme, G);
	  fprintf (stdout, "           Maximization steps needed:   %li\n", trials);
	}
    }
  if (world->param_like < world->options->lcepsilon &&
      world->options->plotnow && !world->options->simulation)
    create_locus_plot (world, plane, tyme, nr, G);
  /* calc_cov() rescales and (if possible) inverts nr->dd, i.e. the
     caller's dd, in place */
  if (!world->options->simulation)
    calc_cov (nr->dd, nr->d, world->param0, 4);
  free (strllike);
  free_nr (nr);
  free (nld);
  free (idd[0]);
  free (idd);
}

/* Calculates P(G | theta1, theta2, ..., m1, m2, ...) for one genealogy
   summary and RETURNS THE LOG of it.

   param   parameter vector: entries 0 .. numpop-1 are read as thetas,
           entries numpop .. 2*numpop-1 as migration parameters
   tl      genealogy summary; its arrays p, l, km and kt are read per
           population
   numpop  number of populations

   Per population the sum collects
     p * (LOG2 - log(theta)) + l * log(m) - km * m - kt / theta
 */
double 
probG (double *param, tarchive_fmt * tl, long numpop)
{
  const double *theta = param;
  const double *mig = param + numpop;
  double logprob = 0;
  long pop;

  for (pop = 0; pop < numpop; pop++)
    {
      logprob += tl->p[pop] * (LOG2 - log (theta[pop]))
	+ tl->l[pop] * log (mig[pop])
	- tl->km[pop] * mig[pop]
	- tl->kt[pop] / theta[pop];
    }
  return logprob;
}

/* Cheap screening test used before a matrix is inverted.

   Returns TRUE when any row OR any column of the n x n matrix dd sums
   to exactly 0.0, FALSE otherwise (also for n <= 0).

   The earlier version contained the row-sum loop twice, so columns
   were never looked at; the second pass is now a real column check.
   For symmetric matrices row and column sums coincide, so the result
   for such input is the same as before.

   This is only a heuristic: a zero row or column sum is neither
   necessary nor sufficient for the matrix to be singular. */
boolean 
is_singular (double **dd, long n)
{
  long i, j;
  double rowsum, colsum;

  for (i = 0; i < n; i++)
    {
      rowsum = 0.0;
      colsum = 0.0;
      for (j = 0; j < n; j++)
	{
	  rowsum += dd[i][j];
	  colsum += dd[j][i];
	}
      if (rowsum == 0.0 || colsum == 0.0)
	return TRUE;
    }
  return FALSE;
}


/* Keeps the proposed parameters nr->param inside their allowed ranges.

   nr           nr->param holds the proposed values and is updated
   param        values the step started from
   gamma_param  number of extra entries checked after the first four
                (the limit tables below have room for one)

   When at least one proposed value lies outside its
   [SMALLEST_xxx, BIGGEST_xxx] interval, the largest common factor
   f <= 1 is determined for which the step of every parameter,
   scaled by f, stays within both limits; if f < 1 all proposed values
   are replaced by param + f * (proposed - param). */
void 
param_all_adjust (nr_fmt * nr, double *param, long gamma_param)
{
  const double lower[5] =
  {SMALLEST_THETA, SMALLEST_THETA, SMALLEST_MIGRATION, SMALLEST_MIGRATION, SMALLEST_GAMMA};
  const double upper[5] =
  {BIGGEST_THETA, BIGGEST_THETA, BIGGEST_MIGRATION, BIGGEST_MIGRATION, BIGGEST_GAMMA};
  const long count = 4 + gamma_param;
  double shrink = 1.;
  double candidate, delta;
  long k;

  /* look for the first value that left its interval */
  for (k = 0; k < count; k++)
    {
      if (nr->param[k] < lower[k] || nr->param[k] > upper[k])
	break;
    }
  /* everything within bounds: nothing to adjust */
  if (k >= count)
    return;

  /* smallest admissible fraction of the step over all parameters */
  for (k = 0; k < count; k++)
    {
      delta = nr->param[k] - param[k];
      if (delta == 0)
	candidate = 1.;
      else
	{
	  candidate = MIN (1., fabs ((lower[k] - param[k]) / delta));
	  candidate = MIN (candidate, fabs ((upper[k] - param[k]) / delta));
	}
      if (candidate < shrink)
	shrink = candidate;
    }
  if (shrink < 1.)
    {
      for (k = 0; k < count; k++)
	nr->param[k] = param[k] + shrink * (nr->param[k] - param[k]);
    }
}


/* Calculates the new parameter set nr->param from the start values.

   Every parameter is multiplied by exp(-lamda * nr->d[i]), where the
   exponent is limited to the range [-100, 100]; afterwards
   param_all_adjust() pulls the result back into the allowed ranges.

   nr     nr->d is read, nr->param is written
   param  values the step starts from
   lamda  step length */
void 
calc_param (nr_fmt * nr, double *param, double lamda)
{
  double exponent;
  long k;

  for (k = 0; k < nr->numpop2; k++)
    {
      exponent = MIN (-lamda * nr->d[k], 100);
      exponent = MAX (-100, exponent);
      nr->param[k] = param[k] * exp (exponent);
    }
  param_all_adjust (nr, param, 0);
}


/* Log likelihood of nr->param relative to the start parameters.

   First pass:  nr->apg[g] = probG(nr->param, g) - nr->apg0[g], and
                nr->apg_max is set to the largest of these values.
   Second pass: nr->apg[g] is shifted by -nr->apg_max before it is
                exponentiated, nr->PGC = Sum[copies * exp(apg[g])],
                and the copies are added up.
   Result:      nr->llike = apg_max + log(PGC) - log(sum of copies),
                which is also returned.

   NOTE(review): the loop counter is an int while G is a long. */
double 
calc_like (nr_fmt * nr, tarchive_fmt * atl, long G)
{
  int idx;
  double total_copies = 0;

  nr->PGC = 0.0;
  nr->apg_max = -DBL_MAX;
  for (idx = 0; idx < G; idx++)
    {
      nr->apg[idx] = probG (nr->param, atl + idx, nr->numpop) - nr->apg0[idx];
      if (nr->apg[idx] > nr->apg_max)
	nr->apg_max = nr->apg[idx];
    }
  for (idx = 0; idx < G; idx++)
    {
      tarchive_fmt *cur = atl + idx;

      total_copies += cur->copies;
      nr->apg[idx] -= nr->apg_max;
      nr->PGC += cur->copies * exp (nr->apg[idx]);
    }
  nr->llike = nr->apg_max + log (nr->PGC) - log (total_copies);
  return nr->llike;
}

double 
norm (double *d, long size)
{
  int i;
  double sum = 0.;
  for (i = 0; i < (int) size; i++)
    {
      sum += d[i] * d[i];
    }
  return sqrt (sum);
}

/* Returns the biggest value among the first size entries of v.
   The vector is scanned from the last entry to the first; for
   size <= 0 nothing is read and -DBL_MAX is returned. */
double 
vector_max (double *v, long size)
{
  double best = -DBL_MAX;
  long k;

  for (k = size - 1; k >= 0; k--)
    {
      if (v[k] > best)
	best = v[k];
    }
  return best;
}

/* Clears the accumulators before a new round of derivatives:
   the first 2 * numpop2 entries of nr->parts, every row of the
   numpop2 x numpop2 matrix nr->dd, and the vector nr->d. */
void 
reset_nr (nr_fmt * nr)
{
  const size_t rowbytes = sizeof (double) * nr->numpop2;
  long row;

  memset (nr->parts, 0, sizeof (double) * 2 * nr->numpop2);
  for (row = 0; row < nr->numpop2; row++)
    {
      memset (nr->dd[row], 0, rowbytes);
    }
  memset (nr->d, 0, rowbytes);
}

/* First and second derivatives for the two-population case
   (4 parameters: theta1, theta2, m1, m2).

   nr       nr->apg and nr->PGC are read (they are filled by
            calc_like()); nr->parts is used as scratch space; nr->d and
            the lower triangle (including the diagonal) of nr->dd
            receive the result.  Both are expected to be zero on entry
            because the values are accumulated.
   tl, G    genealogy summaries and their number
   param    parameters at which the derivatives are taken
   trials, forloci   not used in this function

   With s = first and h = second (x_i, x_i) derivative of
   log P(G|param) and w = copies * exp(apg):
     d[i]     = -Sum[w * s_i] / PGC
     dd[i][j] = -Sum[w * (s_i * s_j + (i == j ? h_i : 0))] / PGC
                + d[i] * d[j]                          (j <= i)
   Genealogies whose apg is not above -100 are left out. */
void 
derivatives (long trials, nr_fmt * nr, tarchive_fmt * tl, long G,
	     double *param, boolean forloci)
{
  double *theta = param;
  double *mig = param + nr->numpop;
  double theta0_sq = param[0] * param[0];
  double theta1_sq = param[1] * param[1];
  double weight;
  tarchive_fmt *cur;
  long tree, row, col;

  for (tree = 0; tree < G; tree++)
    {
      if (!(nr->apg[tree] > -100))
	continue;
      cur = &tl[tree];

      /* 1st derivatives */
      nr->parts[0] = (-cur->p[0] + cur->kt[0] / theta[0]) / (theta[0]);
      nr->parts[1] = (-cur->p[1] + cur->kt[1] / theta[1]) / (theta[1]);
      nr->parts[2] = cur->l[0] / mig[0] - cur->km[0];
      nr->parts[3] = cur->l[1] / mig[1] - cur->km[1];
      /* 2nd derivatives for x_i, x_i */
      nr->parts[4] = (cur->p[0] - 2. * cur->kt[0] / theta[0]) / (theta0_sq);
      nr->parts[5] = (cur->p[1] - 2. * cur->kt[1] / theta[1]) / (theta1_sq);
      nr->parts[6] = -cur->l[0] / (mig[0] * mig[0]);
      nr->parts[7] = -cur->l[1] / (mig[1] * mig[1]);

      weight = cur->copies * exp (nr->apg[tree]);
      for (row = 0; row < 4; row++)
	{
	  nr->d[row] += weight * nr->parts[row];
	  for (col = 0; col < row; col++)
	    nr->dd[row][col] += weight * (nr->parts[row] * nr->parts[col]);
	  nr->dd[row][row] += weight * (nr->parts[row] * nr->parts[row] + nr->parts[4 + row]);
	}
    }

  /* normalize the gradient first, the matrix below needs it */
  for (row = 0; row < 4; row++)
    nr->d[row] /= nr->PGC;

  /* diagonal and off-diagonal entries share the same formula here */
  for (row = 0; row < 4; row++)
    {
      for (col = 0; col <= row; col++)
	nr->dd[row][col] = -(nr->dd[row][col]) / nr->PGC + nr->d[row] * nr->d[col];
    }

  /* finally flip the sign of the gradient */
  for (row = 0; row < 4; row++)
    nr->d[row] = -nr->d[row];
}



/* Rescales the n x n matrix dd by the parameters and inverts it.

   Off-diagonal: dd[i][j] / (param[i] * param[j]), taken from the lower
                 triangle and mirrored into the upper one.
   Diagonal:     (dd[i][i] - param[i] * d[i]) / (param[i] * param[i]).
   The rescaled matrix is inverted in place by invert_matrix() unless
   is_singular() reports it as singular.

   NOTE(review): this looks like the inverse of the change of variables
   in derivatives_to_logderivatives(); the diagonal only matches that
   if d is on the untransformed scale -- confirm what the caller
   passes. */
void 
calc_cov (double **dd, double *d, double *param, long n)
{
  long row, col;
  double scale;

  for (row = 0; row < n; row++)
    {
      for (col = 0; col < row; col++)
	{
	  scale = param[row] * param[col];
	  dd[row][col] = dd[row][col] / scale;
	  dd[col][row] = dd[row][col];
	}
      scale = param[row] * param[row];
      dd[row][row] = (dd[row][row] - param[row] * d[row]) / scale;
    }
  if (is_singular (dd, n))
    return;
  invert_matrix (dd, n);
}


/* Releases the work arrays that belong to nr and then nr itself.
   nr->dd is not touched here. */
void 
free_nr (nr_fmt * nr)
{
  /* per-genealogy arrays */
  free (nr->apg);
  free (nr->apg0);
  free (nr->datalike);
  /* per-parameter arrays */
  free (nr->oparam);
  free (nr->param);
  free (nr->d);
  free (nr->parts);
  /* the structure itself goes last */
  free (nr);
}

/* Change of variables: from parameters to log(parameters).

   With p = nr->param, the derivatives stored in nr->d and nr->dd are
   replaced by
     dd[i][j] = p[i] * p[j] * dd[i][j]        (j < i, mirrored to [j][i])
     dd[i][i] = p[i] * d[i] + p[i] * p[i] * dd[i][i]
     d[i]     = p[i] * d[i]
   The diagonal needs the untransformed d, therefore d is rescaled
   only after the matrix is done. */
void 
derivatives_to_logderivatives (nr_fmt * nr)
{
  double *p = nr->param;
  double diag;
  long row, col;

  for (row = 0; row < nr->numpop2; row++)
    {
      for (col = 0; col < row; col++)
	{
	  nr->dd[row][col] *= p[row] * p[col];
	  nr->dd[col][row] = nr->dd[row][col];
	}
      diag = p[row] * p[row] * nr->dd[row][row];
      nr->dd[row][row] = p[row] * nr->d[row] + diag;
    }
  for (row = 0; row < nr->numpop2; row++)
    nr->d[row] *= p[row];
}

/* Prints to stdout how the genealogies are distributed over their
   relative log weights nr->apg.

   Ten classes of width 2 cover the weights above -20: class
   9 - (long) (fabs(apg) / 2) receives the copies of the genealogy,
   so the rightmost printed class holds the weights closest to 0.
   An eleventh counter holds the copies of all G genealogies and is
   printed under "All". */
void 
print_contribution (nr_fmt * nr, tarchive_fmt * tyme, long G)
{
  long bins[11] = {0};
  long k;

  for (k = 0; k < G; k++)
    {
      bins[10] += tyme[k].copies;
      if (nr->apg[k] > -20)
	bins[9 - (long) (fabs (nr->apg[k]) / 2)] += tyme[k].copies;
    }

  /* header line: class labels -18, -16, ..., 0 */
  fprintf (stdout, "           log(P(g|Param))  -20 to ");
  for (k = 0; k < 10; k++)
    fprintf (stdout, "%4li ", 2 * k - 18);
  fprintf (stdout, "  All\n");

  /* count line */
  fprintf (stdout, "           Counts                  ");
  for (k = 0; k < 10; k++)
    fprintf (stdout, "%4li ", bins[k]);
  fprintf (stdout, "%5li\n", bins[10]);
}
