/*------------------------------------------------------
 Maximum likelihood estimation 
 of migration rate and effective population size
 using a Metropolis-Hastings Monte Carlo algorithm                            
 ------------------------------------------------------- 
 AIC model test   R O U T I N E S 

 Peter Beerli 2001, Seattle
 beerli@genetics.washington.edu
Copyright 2001 Peter Beerli and Joseph Felsenstein

$Id: aic.c,v 1.9 2001/07/25 19:27:23 beerli Exp $

-------------------------------------------------------*/
#define SICK_VALUE    -1
#include "migration.h"
#include "tools.h"
#include "broyden.h"
#include "combroyden.h"
#include "options.h"
#include "sort.h"
#include "aic.h"

#ifdef DMALLOC_FUNC_CHECK
#include <dmalloc.h>
#endif

/* Branch-and-bound AIC search: sets entries of options->custm2 to migtype
   one at a time (recursing from the next position) and records every
   legal pattern whose AIC passes the cutoff. */
void aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
		long zero, long which, char *tempattern, double *param0,
		char migtype);

/* Search variant whose definition further down in this file is
   commented out; declared only. */
void aic_score_minus (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
		      long zero, long which, char *tempattern,
		      double *param0, char migtype);

/* TRUE when no diagonal entry of the linearized pattern is '0' and no
   m2mm() group is entirely '0' (see the definition below). */
boolean legal_pattern (char *matrix, long numpop);

/* Greedy AIC search used when options->fast_aic is set. */
void fast_aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
		     long zero, long which, char *tempattern,
		     double *param0, char migtype);


/* Appends one evaluated model to *aicvec when its AIC passes the cutoff;
   otherwise only reports it on stdout/log. */
void add_aiclist (char migtype, long numparam,
		  double aic, char *tempattern, char *custm2,
		  long *aicnum, aic_fmt ** aicvec, nr_fmt * nr);

/* Converts a linearized pattern into a printable square layout.

   origpattern holds numpop diagonal entries followed by the
   numpop*(numpop-1) off-diagonal entries, consumed in row order
   ("ddd mm mm mm" for numpop = 3).  pattern receives numpop rows of
   numpop characters, each row followed by a space; diagonal cells are
   printed as 'x' ("dmm mdm mmd" layout, e.g. "xab cxd efx ").

   The previous contents of pattern are overwritten.  pattern must hold at
   least numpop * (numpop + 1) + 1 chars: the result is explicitly
   NUL-terminated instead of relying on the caller's buffer being zeroed.
   Returns pattern. */
char *
reshuffle (char *pattern, char *origpattern, long numpop)
{
  long row, col;
  long src = numpop;		/* off-diagonal entries follow the diagonal */
  char *out = pattern;

  for (row = 0; row < numpop; row++)
    {
      for (col = 0; col < numpop; col++)
        {
          *out++ = (row == col) ? 'x' : origpattern[src++];
        }
      *out++ = ' ';
    }
  *out = '\0';
  return pattern;
}


/* Writes the model-selection banner (mleaic and the acceptance margin) to
   one stream. */
static void
print_aic_banner_to (FILE * out, double mleaic, double margin)
{
  fprintf (out,
           "\n\n           Selecting the best migration model for this run,\n"
           "           This may take a while!\n"
           "           All parameter "
           "combinations\n"
           "           with  (AIC= -2 Log(L(Param)+n_param)<%f+%f\n",
           mleaic, margin);
}

/* Announces the AIC search and prints the AIC table header.

   The banner goes to stdout when options->progress is set and to
   options->logfile when options->writelog is set.  It reports mleaic and
   the margin aicmod * numpop2.  The section title and column header always
   go to world->outfile; the "Pattern" column is MAX(18, partsize + numpop)
   wide and truncated to partsize + numpop characters. */
void
print_header_aic (nr_fmt * nr, double mleaic)
{
  world_fmt *w = nr->world;
  FILE *table = w->outfile;
  double margin = w->options->aicmod * w->numpop2;
  int colwidth = (int) MAX (18, nr->partsize + nr->numpop);
  int colprec = (int) (nr->partsize + nr->numpop);

  if (w->options->progress)
    print_aic_banner_to (stdout, mleaic, margin);
  if (w->options->writelog)
    print_aic_banner_to (w->options->logfile, mleaic, margin);

  fprintf (table, "\n\n\n Akaike's Information Criterion  (AIC)\n");
  fprintf (table, "=========================================\n\n");
  fprintf (table, "[Linearized migration matrix, x=diagonal]\n");
  fprintf (table,
           "%-*.*s            AIC     #param   Log(L)   LRT     Prob   Probc\n",
           colwidth, colprec, "Pattern");
}

/* Runs the AIC model-selection step for the current run and writes the
   ranked table to world->outfile.

   The full model is scored as AIC = -2 * mle + 2 * numpop2, where mle is the
   stored parameter likelihood (multilocus: atl[0][loci]; single locus:
   atl[repstop == 1 ? 0 : repstop][0]).  Restricted models are then scored by
   fast_aic_score() ('0' and then 'm') when options->fast_aic is set, and
   otherwise by aic_score() with '0'.  The results are sorted with aiccmp and
   printed.  A separator line follows the first row whose AIC equals the
   full model's.

   world: run state.  options->migration_model is set to MATRIX_ARBITRARY
          and is not restored.  options->custm2 and options->custm are
          restored before returning.
   Gmax:  passed (dereferenced) to create_nr(). */
void
akaike_information (world_fmt * world, long *Gmax)
{
  long i;
  boolean mldone = FALSE;
  nr_fmt *nr;
  char *temppattern;
  /* NOTE: the names are crossed: savecustm keeps a copy of
     options->custm2 and savecustm2 a copy of options->custm. */
  char *savecustm, *savecustm2;
  long kind = world->loci > 1 ? MULTILOCUS : SINGLELOCUS;
  long repstop;
  long repstart;
  double *param0;
  aic_fmt *aicvec;
  long aicnum;
  double mleaic;
  double mle;
  boolean multilocus;

  prepare_broyden (kind, world, &multilocus);
  world->options->migration_model = MATRIX_ARBITRARY;
  savecustm = calloc (world->numpop2, sizeof (char));
  savecustm2 = calloc (world->numpop2, sizeof (char));
  /* reshuffle() writes numpop rows of numpop + 1 characters; the extra
     byte holds the string terminator. */
  temppattern = calloc (world->numpop2 + 1 + world->numpop, sizeof (char));
  param0 = calloc (world->numpop2 + 1, sizeof (double));

  set_replicates (world, world->repkind, world->rep, &repstart, &repstop);

  if (kind == MULTILOCUS)
    mle = world->atl[0][world->loci].param_like;
  else
    mle = world->atl[repstop == 1 ? 0 : repstop][0].param_like;
  mleaic = -2. * mle + 2. * world->numpop2;

  nr = (nr_fmt *) calloc (1, sizeof (nr_fmt));

  create_nr (nr, world, *Gmax, 0, world->loci, world->repkind, world->rep);

  setup_parameter0 (world, nr, world->repkind, repstart, repstop,
                    world->loci, kind, multilocus);

  memcpy (savecustm, world->options->custm2, sizeof (char) * nr->numpop2);
  memcpy (savecustm2, world->options->custm, sizeof (char) * nr->numpop2);
  print_header_aic (nr, mleaic);

  /* parameter values the searches copy into world->param0 before each
     profile */
  if (kind == MULTILOCUS)
    memcpy (param0, nr->world->atl[0][nr->world->loci].param,
            sizeof (double) * nr->numpop2);
  else
    memcpy (param0,
            nr->world->atl[repstop ==
                           1 ? 0 : repstop][nr->world->locus].param,
            sizeof (double) * nr->numpop2);

  /* entry 0 is the full model; every restricted model is compared to it */
  aicnum = 1;
  aicvec = (aic_fmt *) calloc (aicnum, sizeof (aic_fmt));
  aicvec[0].mle = mle;
  aicvec[0].aic = mleaic;
  aicvec[0].lrt = 0.0;
  aicvec[0].prob = 1.0;
  aicvec[0].probcorr = 1.0;
  aicvec[0].numparam = nr->partsize;
  aicvec[0].pattern = (char *) calloc (nr->partsize + 1, sizeof (char));
  memcpy (aicvec[0].pattern, world->options->custm2,
          sizeof (char) * nr->partsize);
  if (world->options->fast_aic)
    {
      fast_aic_score (&aicvec, &aicnum, nr, 0, world->numpop,
                      temppattern, param0, '0');
      /* fast_aic_score leaves its accepted changes in custm2 */
      memcpy (world->options->custm2, savecustm, sizeof (char) * nr->numpop2);
      fast_aic_score (&aicvec, &aicnum, nr, 0, world->numpop,
                      temppattern, param0, 'm');
    }
  else
    {
      /* branch-and-bound search with some parameters set to zero; this
         needs more investigation because of boundary problems.  The
         'm' (averaged M) pass and aic_score_minus are disabled here. */
      aic_score (&aicvec, &aicnum, nr, 0, world->numpop,
                 temppattern, param0, '0');
      memcpy (world->options->custm2, savecustm, sizeof (char) * nr->numpop2);
    }
  qsort ((void *) aicvec, aicnum, sizeof (aic_fmt), aiccmp);
  for (i = 0; i < aicnum; i++)
    {
      fprintf (world->outfile, "%-*.*s %20.5f %4li %f %f %f %f\n",
               (int) MAX (18, nr->partsize + nr->numpop),
               (int) (nr->partsize + nr->numpop),
               reshuffle (temppattern, aicvec[i].pattern, nr->numpop),
               aicvec[i].aic, aicvec[i].numparam, aicvec[i].mle,
               aicvec[i].lrt, aicvec[i].prob, aicvec[i].probcorr);
      /* exact comparison is intended: entry 0 stores mleaic itself */
      if (aicvec[i].aic == mleaic && !mldone)
        {
          mldone = TRUE;
          fprintf (world->outfile, "%-*.*s%21.21s-----\n",
                   (int) MAX (18, nr->partsize + nr->numpop),
                   (int) (nr->partsize + nr->numpop),
                   "--------------------------------------------------------------------------------------------------------------------------------------------------------------",
                   "---------------------");
        }
      free (aicvec[i].pattern);
    }
  memcpy (world->options->custm2, savecustm, sizeof (char) * nr->numpop2);
  memcpy (world->options->custm, savecustm2, sizeof (char) * nr->numpop2);
  free (aicvec);
  fflush (world->outfile);
  free (param0);
  free (temppattern);
  free (savecustm);
  free (savecustm2);
  destroy_nr (nr, world);
}

/* Greedy AIC search (used when options->fast_aic is set).

   First pass: for every off-diagonal position m in [numpop, numpop2) whose
   entry in options->custm2 was not already migtype on entry, the entry is
   temporarily set to migtype.  The parameters are then reset from param0
   and, if the pattern is legal, the model is profiled.  Its AIC is stored
   in best[m] and the model is handed to add_aiclist().

   Second pass: every position whose AIC is below borderaic (full-model AIC
   + aicmod * numpop2) and that is not yet migtype is fixed to migtype.  The
   change is left in custm2 for the caller to restore, and the search
   recurses with zero + 1.

   zero:    number of entries already restricted at this depth.
   which:   unused.
   migtype: '0' or 'm'; 'm' counts one extra free parameter. */
void
fast_aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
                long zero, long which, char *tempattern,
                double *param0, char migtype)
{
  long m, ii;
  double likes = 0;
  double normd = 0;
  double aic;
  double borderaic =
    (*aicvec)[0].aic + nr->world->options->aicmod * nr->numpop2;
  char savecustm2;
  long remainnum = (migtype == 'm') ? 1 : 0;
  char *custm2 = nr->world->options->custm2;
  char *scustm2;
  long numparam = zero;
  long freeparam = (nr->numpop2 - numparam - 1 + remainnum);
  aic_fmt *best;

  (void) which;

  /* snapshot of custm2 on entry: positions already set are skipped */
  scustm2 = (char *) calloc (nr->partsize, sizeof (char));
  memcpy (scustm2, custm2, sizeof (char) * nr->partsize);
  /* best[] is indexed by position m < numpop2 only; both passes use the
     same bound so no zero-filled (aic = 0, numparam = 0) slot is ever
     mistaken for a candidate. */
  best = (aic_fmt *) calloc (nr->numpop2, sizeof (aic_fmt));
  for (m = nr->numpop; m < nr->numpop2; m++)
    {
      best[m].aic = DBL_MAX;
      best[m].numparam = m;
      if (scustm2[m] == migtype)
        continue;
      savecustm2 = custm2[m];
      custm2[m] = migtype;
      memcpy (nr->world->param0, param0, sizeof (double) * nr->numpop2);
      resynchronize_param (nr->world);
      if (legal_pattern (nr->world->options->custm2, nr->numpop))
        {
          do_profiles (nr->world, nr, &likes, &normd, PROFILE,
                       nr->world->rep, nr->world->repkind);
          aic = -2. * nr->llike + 2. * freeparam;
          best[m].aic = aic;
          best[m].numparam = m;
          add_aiclist (migtype, numparam + 1 - remainnum,
                       aic, tempattern, custm2, aicnum, aicvec, nr);
        }
      else
        {                       /* illegal combination of parameters */
          if (nr->world->options->progress)
            fprintf (stdout, "           F   %s %20s\n",
                     reshuffle (tempattern, custm2, nr->numpop), "-----");
          fflush (stdout);
          if (nr->world->options->writelog)
            fprintf (nr->world->options->logfile, "           F   %s %20s\n",
                     reshuffle (tempattern, custm2, nr->numpop), "-----");
        }
      custm2[m] = savecustm2;
    }
  /* the snapshot is only needed by the first pass; free it before the
     recursion (it previously leaked on every call) */
  free (scustm2);

  for (ii = nr->numpop; ii < nr->numpop2; ii++)
    {
      if (best[ii].aic < borderaic && custm2[best[ii].numparam] != migtype)
        {
          custm2[best[ii].numparam] = migtype;
          fast_aic_score (aicvec, aicnum, nr, zero + 1, best[ii].numparam,
                          tempattern, param0, migtype);
        }
    }
  free (best);
}


/* Records or reports one model evaluated by fast_aic_score().

   lrt = -2 * (llike - full-model mle) is computed from nr->llike.  When
   aic is below (*aicvec)[0].aic + aicmod * numpop2 (and, for migtype 'm',
   only when options->mmn > 1), the model is appended to *aicvec.  The new
   entry stores a copy of the first partsize entries of custm2,
   numparam = numpop2 - numparam, and the lrt with its probchi() /
   probchiboundary() values.  *aicnum is then incremented and a "+" line is
   reported.  Otherwise a "-" line is reported on stdout (options->progress)
   and/or the log file (options->writelog).

   numparam: count supplied by the caller; it is used as the first argument
             of probchi() and both count arguments of probchiboundary().
   On allocation failure the list is left unchanged and a message is
   written to stderr. */
void
add_aiclist (char migtype, long numparam,
             double aic, char *tempattern, char *custm2,
             long *aicnum, aic_fmt ** aicvec, nr_fmt * nr)
{
  /* needed in both branches; the reject branch used to read it from
     (*aicvec)[*aicnum], one element past the end of the vector */
  double lrt = -2. * (nr->llike - (*aicvec)[0].mle);
  aic_fmt *grown;
  aic_fmt *entry;

  if (aic < (*aicvec)[0].aic + nr->world->options->aicmod * nr->numpop2)
    {
      if (migtype == 'm' && nr->world->options->mmn <= 1)
        return;

      grown = (aic_fmt *) realloc (*aicvec, sizeof (aic_fmt) * (*aicnum + 1));
      if (grown == NULL)
        {
          /* keep the existing list valid instead of dereferencing NULL */
          fprintf (stderr, "add_aiclist: out of memory, model not recorded\n");
          return;
        }
      *aicvec = grown;
      entry = &(*aicvec)[*aicnum];
      entry->pattern = (char *) calloc (nr->partsize + 1, sizeof (char));
      entry->aic = aic;
      entry->mle = nr->llike;
      entry->numparam = nr->numpop2 - numparam;
      memcpy (entry->pattern, custm2, sizeof (char) * nr->partsize);
      entry->lrt = lrt;
      /* always set: these fields are printed by the caller, and were left
         uninitialized unless AICTEST was defined */
      entry->prob = probchi (numparam, lrt);
      entry->probcorr = probchiboundary (lrt, numparam, numparam);

      if (nr->world->options->progress)
        fprintf (stdout, "           +   %s %20.5f %f\n",
                 reshuffle (tempattern, custm2, nr->numpop), aic,
                 entry->lrt);
      if (nr->world->options->writelog)
        fprintf (nr->world->options->logfile,
                 "           +   %s %20.5f %f\n",
                 reshuffle (tempattern, custm2, nr->numpop), aic,
                 entry->lrt);
      fflush (stdout);
      (*aicnum)++;
    }
  else
    {
      if (nr->world->options->progress)
        {
          fprintf (stdout, "           -   %s %20.5f %f\n",
                   reshuffle (tempattern, custm2, nr->numpop), aic, lrt);
          fflush (stdout);
        }
      if (nr->world->options->writelog)
        fprintf (nr->world->options->logfile,
                 "           -   %s %20.5f %f\n",
                 reshuffle (tempattern, custm2, nr->numpop), aic, lrt);
    }
}


/* Branch-and-bound AIC search over restricted migration patterns.

   For each position m from `which` to numpop2 - 1, options->custm2[m] is
   set to migtype.  The parameters are reset from param0 and, if the pattern
   is legal, the model is profiled and scored as
       AIC = -2 * llike + 2 * freeparam,
       freeparam = numpop2 - numparam - 1 + remainnum,
   where numparam is options->zeron for '0' and options->mmn for 'm', and
   remainnum is 1 for 'm'.  A model whose AIC is below
   (*aicvec)[0].aic + aicmod * freeparam is appended to *aicvec (and written
   to options->aicfile when set), and the search recurses from m + 1.
   Otherwise the model is only reported.  custm2[m] is restored after each
   position.  For 'm' the function returns at once when custm2[which] is
   already 'm'.

   NOTE(review): add_aiclist() uses aicmod * numpop2 as its cutoff, while
   this function uses aicmod * freeparam; confirm which is intended.
   On allocation failure the model (and its sub-tree) is skipped and a
   message is written to stderr. */
void
aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
           long zero, long which, char *tempattern, double *param0,
           char migtype)
{
  long m;
  long i;
  double likes = 0;
  double normd = 0;
  double aic;
  double lrt;
  double prob = 1.0;
  double probc = 1.0;
  char savecustm2;
  long remainnum = 0;
  long numparam = 0;
  long freeparam;
  char *custm2 = nr->world->options->custm2;
  aic_fmt *grown;
  aic_fmt *entry;

  /* Nothing left to try.  This also keeps the 'm' check below from reading
     custm2[numpop2] when the recursion passes m + 1 == numpop2. */
  if (which >= nr->numpop2)
    return;

  switch (migtype)
    {
    case '0':
      numparam = nr->world->options->zeron;
      remainnum = 0;
      break;
    case 'm':
      numparam = nr->world->options->mmn;
      remainnum = 1;
      if (custm2[which] == 'm')
        return;
      break;
    default:
      break;
    }
  freeparam = (nr->numpop2 - numparam - 1 + remainnum);
  for (m = which; m < nr->numpop2; m++)
    {
      savecustm2 = custm2[m];
      custm2[m] = migtype;
      memcpy (nr->world->param0, param0, sizeof (double) * nr->numpop2);
      resynchronize_param (nr->world);

      if (legal_pattern (nr->world->options->custm2, nr->numpop))
        {
          do_profiles (nr->world, nr, &likes, &normd, PROFILE,
                       nr->world->rep, nr->world->repkind);
          aic = -2. * nr->llike + 2. * freeparam;
          lrt = -2. * (nr->llike - (*aicvec)[0].mle);
          if (aic < (*aicvec)[0].aic + nr->world->options->aicmod * freeparam)
            {
              grown = (aic_fmt *) realloc (*aicvec,
                                           sizeof (aic_fmt) * (*aicnum + 1));
              if (grown == NULL)
                {
                  /* keep the existing list valid instead of
                     dereferencing NULL; this branch is not explored */
                  fprintf (stderr,
                           "aic_score: out of memory, model not recorded\n");
                }
              else
                {
                  *aicvec = grown;
                  /* only used before the recursion, which may move the
                     vector again */
                  entry = &(*aicvec)[*aicnum];
                  entry->pattern =
                    (char *) calloc (nr->partsize + 1, sizeof (char));
                  entry->aic = aic;
                  entry->mle = nr->llike;
                  entry->numparam = freeparam;
                  memcpy (entry->pattern, custm2,
                          sizeof (char) * nr->partsize);
                  entry->lrt = lrt;
                  entry->prob = probchi (numparam, entry->lrt);
                  entry->probcorr =
                    probchiboundary (entry->lrt, numparam, numparam);

                  if (nr->world->options->aicfile)
                    {
                      fprintf (nr->world->options->aicfile,
                               "%f %f %li %f  %f ",
                               aic, entry->lrt, entry->numparam,
                               entry->prob, entry->probcorr);
                      for (i = 0; i < nr->partsize; i++)
                        fprintf (nr->world->options->aicfile, "%f ",
                                 nr->world->param0[i]);
                      fprintf (nr->world->options->aicfile, "\n");
                    }

                  if (nr->world->options->progress)
                    fprintf (stdout,
                             "           +   %s %20.5f %3li %8.4f %8.4f %6.4f %6.4f\n",
                             reshuffle (tempattern, custm2, nr->numpop), aic,
                             freeparam, entry->mle, entry->lrt, entry->prob,
                             entry->probcorr);
                  if (nr->world->options->writelog)
                    fprintf (nr->world->options->logfile,
                             "           +   %s %20.5f %3li %8.4f %8.4f %6.4f %6.4f\n",
                             reshuffle (tempattern, custm2, nr->numpop), aic,
                             freeparam, entry->mle, entry->lrt, entry->prob,
                             entry->probcorr);
                  fflush (stdout);

                  (*aicnum)++;

                  aic_score (aicvec, aicnum, nr, zero + 1, m + 1,
                             tempattern, param0, migtype);
                }
            }
          else
            {
              if (nr->world->options->progress)
                {
                  prob = probchi (numparam, lrt);
                  probc = probchiboundary (lrt, numparam, numparam);
                  fprintf (stdout,
                           "           -   %s %20.5f %3li %8.4f %8.4f %6.4f %6.4f\n",
                           reshuffle (tempattern, custm2, nr->numpop), aic,
                           freeparam, nr->llike, lrt, prob, probc);
                }
              if (nr->world->options->writelog)
                {
                  if (!nr->world->options->progress)
                    {
                      prob = probchi (numparam, lrt);
                      probc = probchiboundary (lrt, numparam, numparam);
                    }
                  fprintf (nr->world->options->logfile,
                           "           -   %s %20.5f %3li %8.4f %8.4f %6.4f %6.4f\n",
                           reshuffle (tempattern, custm2, nr->numpop), aic,
                           freeparam, nr->llike, lrt, prob, probc);

                  fflush (stdout);
                }
            }
        }
      else
        {
          if (nr->world->options->progress)
            {
              fprintf (stdout, "           F   %s %20s\n",
                       reshuffle (tempattern, custm2, nr->numpop), "-----");
              fflush (stdout);
            }
          if (nr->world->options->writelog)
            fprintf (nr->world->options->logfile, "           F   %s %20s\n",
                     reshuffle (tempattern, custm2, nr->numpop), "-----");
        }
      custm2[m] = savecustm2;
    }
}

/* Looks up the restriction settings for migtype.
   '0': *numparam = options->zeron, *remainnum = 0; returns FALSE.
   'm': *numparam = options->mmn, *remainnum = 1; returns TRUE only when
        options->custm2[which] is already 'm'.
   Any other migtype leaves *numparam and *remainnum untouched and returns
   FALSE. */
boolean
check_numparam (long which, long migtype, worldoption_fmt * options,
                long *numparam, long *remainnum)
{
  if (migtype == '0')
    {
      *numparam = options->zeron;
      *remainnum = 0;
      return FALSE;
    }
  if (migtype == 'm')
    {
      *numparam = options->mmn;
      *remainnum = 1;
      return (options->custm2[which] == 'm') ? TRUE : FALSE;
    }
  return FALSE;
}

/*
void  aic_score_minus(aic_fmt **aicvec, long *aicnum, nr_fmt *nr, 
		      long zero, long which, char *tempattern, 
		      double *param0,
		      char migtype)
{
  long m;
  long kind = nr->world->loci>1 ? MULTILOCUS : SINGLELOCUS;
  long repstop = !nr->world->options->replicate ? 0 : 
    (nr->world->options->replicatenum == 0 ? 
     nr->world->options->lchains : nr->world->options->replicatenum); 
  double likes=0;
  double normd=0;
  double aic;
  char savecustm2;
  boolean legal;
  long df;
not_set_yet
  char *custm2 = nr->world->options->custm2;
  for(m = which; m >= nr->world->numpop; m--)
    {
      savecustm2 = custm2[m];
      custm2[m] = '0';
      if(kind==MULTILOCUS)
	memcpy(nr->world->param0,nr->world->atl[0][nr->world->loci].param,
	       sizeof(double) * nr->numpop2);
      else
	memcpy(nr->world->param0,nr->world->atl[repstop][nr->world->locus].param,
	       sizeof(double) * nr->numpop2);
      resynchronize_param (nr->world);
      if((legal=legal_pattern(nr->world->options->custm2, 
				      nr->numpop)))
	{
	  do_profiles(nr->world, nr, &likes, &normd, PROFILE, 
		      nr->world->rep, nr->world->repkind);
	  aic = -2. * nr->llike + df;
	    //memo	  df = (nr->numpop2 - numparam - 1 + remainnum);
	  if(aic < (*aicvec)[0].aic +  nr->world->options->aicmod * nr->numpop2)
	    {
	      *aicvec = (aic_fmt *) realloc(*aicvec, sizeof(aic_fmt) * (*aicnum + 1));
	      (*aicvec)[*aicnum].pattern = (char *) calloc(nr->partsize+1,sizeof(char));
	      (*aicvec)[*aicnum].aic = aic;
	      memcpy((*aicvec)[*aicnum].pattern, custm2, sizeof(char)*nr->partsize);
	      if(nr->world->options->progress)
		fprintf(stdout, "           +   %s %20.5f %f\n", reshuffle(tempattern,custm2, nr->numpop),aic,(*aicvec)[*aicnum].lrt);  fflush(stdout);
	      (*aicnum)++;
	      
	      aic_score_minus(aicvec,aicnum, nr, zero + 1, m - 1, tempattern, param0, migtype);
	    }
	  else
	    {
	      if(nr->world->options->progress)
		fprintf(stdout,"           -   %s %20.5f %f\n", 
			reshuffle(tempattern, custm2, nr->numpop),
			aic, (*aicvec)[*aicnum].lrt); fflush(stdout);
	    }
	}
      else
	{
	  if(nr->world->options->progress)
	    fprintf(stdout,"           F   %s %s\n", reshuffle(tempattern, custm2, nr->numpop), "-----");fflush(stdout);
	}
      custm2[m] = savecustm2;
    }
}
*/

/* Decides whether a linearized pattern may be profiled.

   matrix holds numpop diagonal entries followed by the off-diagonal
   entries; '0' marks a parameter fixed at zero.  The pattern is rejected
   (FALSE) when any diagonal entry is '0'.  It is also rejected when, for
   some run of consecutive off-diagonal indices that share the same `to`
   from m2mm(), every entry matrix[i] and every mirrored entry
   matrix[mm2m(to, from)] is '0'.  Presumably such a population would have
   no migration in either direction; confirm against m2mm/mm2m.
   Returns TRUE otherwise (including numpop == 1 with a non-'0' diagonal). */
boolean
legal_pattern (char *matrix, long numpop)
{
  long idx;
  long from, to;
  long current_to = -1;         /* `to` of the run being counted */
  long links = -1;              /* -1 until the first run starts */

  for (idx = 0; idx < numpop; idx++)
    {
      if (matrix[idx] == '0')
        return FALSE;
    }

  for (idx = numpop; idx < numpop * numpop; idx++)
    {
      m2mm (idx, numpop, &from, &to);
      if (to != current_to)
        {
          /* previous run closed without a single non-'0' link */
          if (links == 0)
            return FALSE;
          current_to = to;
          links = 0;
        }
      if (matrix[idx] != '0')
        links++;
      if (matrix[mm2m (to, from, numpop)] != '0')
        links++;
    }
  return (links == 0) ? FALSE : TRUE;
}
