/*------------------------------------------------------
Maximum likelihood estimation
of migration rate and effective population size
using a Metropolis-Hastings Monte Carlo algorithm
-------------------------------------------------------
    variation over time routines   R O U T I N E S

    Peter Beerli 2006, Tallahassee
    beerli@scs.fsu.edu

    Copyright 2006 Peter Beerli, Tallahassee

    This software is distributed free of charge for non-commercial use
    and is copyrighted. Of course, we do not guarantee that the software
    works and are not responsible for any damage you may cause or have.


 $Id$
    */
/*! \file skyline.c 

this file contains functions that calculate the expected parameters based on 
individual time intervals. Results will be printed as a histogram over time (similar to the 
skyline plots of Rambaut, Strimmer etc) and also will print a table/file with values for each parameter

*/
#include <stdlib.h>

#include "skyline.h"
#include "random.h"
#include "tools.h"
#include "sighandler.h"
#include "world.h"
#ifdef PRETTY
#include "pretty.h"
#endif
#ifdef MPI
#include "migrate_mpi.h"
#else
extern int myID;
#endif

///
/// Expected Theta of population \a pop for a coalescence event.
///
/// Over all \a numpop populations the term
/// lineages[i]*(lineages[i]-1)/param[i] + mig0list[i]*lineages[i] is summed, the
/// coalescence term of \a pop itself is taken out again, the remainder is scaled
/// by \a inv_mu_rate and subtracted from 1/interval.
/// \return k(k-1) of \a pop divided by that remaining rate, or the sentinel -90
///         when the remaining rate is not larger than EPSILON.
MYREAL waiting_theta(long pop,
                     MYREAL *param,
                     MYREAL * mig0list,
                     long *lineages,
                     MYREAL interval,
                     MYREAL inv_mu_rate,
                     long numpop)
{
  MYREAL total_rate = 0.0;
  MYREAL remaining;
  long pairs;
  long p;

  // total rate of all coalescences and migrations
  for(p = 0; p < numpop; p++)
    {
      total_rate += (lineages[p] * (lineages[p]-1) * (1.0/ param[p]) +  mig0list[p] * lineages[p]);
    }

  // k(k-1) for the population of interest
  pairs = lineages[pop] * (lineages[pop] - 1);
  remaining = 1./interval - (total_rate - pairs * 1.0 / param[pop]) * inv_mu_rate;

  // callers treat a negative return value as "no usable estimate"
  if(!(remaining > EPSILON))
    return -90.;
  return pairs / remaining;
}

///
/// Expected migration parameter for an event between two different populations
/// (calculate_expected_values() calls this when from != to and passes the mm2m()
/// index of the migration parameter as \a pop).
///
/// The total event rate is summed exactly as in waiting_theta(); the term
/// lineages[to]*param[pop] is taken out, the remainder is scaled by
/// \a inv_mu_rate and subtracted from 1/interval.
/// \return that remaining rate divided by lineages[to], or the sentinel -9
///         when the remaining rate is not larger than EPSILON.
MYREAL waiting_M(long pop,long to,
                 MYREAL *param,
                 MYREAL * mig0list,
                 long *lineages,
                 MYREAL interval,
                 MYREAL inv_mu_rate,
                 long numpop)
{
  MYREAL total_rate = 0.0;
  MYREAL remaining;
  long k_to;
  long p;

  // total rate of all coalescences and migrations
  for(p = 0; p < numpop; p++)
    {
      total_rate += (lineages[p] * (lineages[p]-1) * (1.0/ param[p]) +  mig0list[p] * lineages[p]);
    }

  // number of lineages in population 'to'
  k_to = lineages[to];
  remaining = 1./interval - (total_rate - k_to * param[pop]) * inv_mu_rate ;

  // callers treat a negative return value as "no usable estimate"
  if(!(remaining > EPSILON))
    return -9.;
  return remaining / k_to;
}

///
/// Add one weighted observation to the histogram cell bins[idx]:
/// [0] += val*weight, [1] += weight, [2] += se*weight, [3] += weight^2, [4] += 1.
static void skyline_add_observation(tetra *bins, long idx, float val, float se, float weight)
{
  bins[idx][0] += val * weight;      // parameter of interest
  bins[idx][1] += weight;            // its weight per bin
  bins[idx][2] += se * weight;       // squared value, used later for the std of the parameter
  bins[idx][3] += weight * weight;   // squared weight, used later for the std of the weight
  bins[idx][4] += 1.0;               // number of visits
}

///
/// Calculate the expected parameter for one event and add it to the skyline histogram.
///
/// The event ends at time \a age and its time interval has length \a interval;
/// histogram bins have width \a eventinterval. For from == to the value comes
/// from waiting_theta() and is spread over all bins the interval touches; for
/// from != to it comes from waiting_M() and is recorded only in the bin that
/// contains \a age. Nothing is recorded when age > 1000, when the bin index is
/// above 10000, or when the waiting_*() function returns a negative sentinel.
/// The bin array of the affected parameter is enlarged on demand.
///
/// Likelihood weighting of the samples is disabled: every sample counts with
/// weight one, and only the fraction of a bin that is covered by the interval
/// scales the contribution.
void calculate_expected_values(tetra **eventbins,
                               long *eventbinnum,
                               MYREAL eventinterval,
                               MYREAL interval,
                               MYREAL age,
                               long from,
                               long to,
                               long * lineages,
                               long numpop,
                               world_fmt *world)
{
  const float full_weight = 1.0F;  // weight of a bin that is covered completely

  MYREAL mu_rate = world->options->mu_rates[world->locus];
  MYREAL inv_mu_rate = 1./ mu_rate;
  MYREAL inv_eventinterval = 1. / eventinterval;
  long   first = (long) ((age - interval) * inv_eventinterval); // bin holding the start of the interval
  long   last  = (long) (age * inv_eventinterval);              // bin holding the end of the interval
  float  last_start  = eventinterval * last;   // age at the left side of the last bin
  MYREAL first_start = eventinterval * first;  // age at the left side of the first bin

  tetra *bins = NULL;
  float  val;
  float  se1;
  float  w_first;
  float  w_last;
  long   span;
  long   pop;
  long   newsize;
  long   i;
  long   k;

  // very old events and very large bin numbers are silently not recorded
  if(age > 1000.0 || last > 10000)
    return;

  if(first < 0)
    first = 0;

  if(last < 0)
    error("bin is smaller than zero");

  if(from == to)
    {
      // coalescence: the estimate is spread over the whole interval
      span = last - first;
      pop  = to;
      val  = waiting_theta(pop, world->param0, world->mig0list, lineages, interval, inv_mu_rate, numpop);
    }
  else
    {
      // migration: recorded in the bin of the event only
      span = 0;
      pop  = mm2m(from, to, numpop);
      val  = waiting_M(pop, to, world->param0, world->mig0list, lineages, interval, inv_mu_rate, numpop);
    }
  if(val < 0.0)
    return;
  se1 = val * val;

  // enlarge the histogram of this parameter when the event lies beyond its end
  if(last >= eventbinnum[pop])
    {
      newsize        = last + 10;
      eventbins[pop] = myrealloc(eventbins[pop], newsize * sizeof(tetra));
      for(i = eventbinnum[pop]; i < newsize; i++)
        {
          for(k = 0; k < 5; k++)
            eventbins[pop][i][k] = 0.;
        }
      eventbinnum[pop] = newsize;
    }

  if(span < 0)
    error("time difference is negative -- not possible");

  // weights are the contribution of the estimate to the bin
  bins = eventbins[pop];
  if(span == 0)
    {
      // start and stop are in the same bin
      w_last = interval * inv_eventinterval;
      skyline_add_observation(bins, last, val, se1, w_last);
    }
  else if(span == 1)
    {
      // the interval crosses a single bin boundary: older bin first, then newer bin
      w_last = (age - last_start) * inv_eventinterval;
      skyline_add_observation(bins, last, val, se1, w_last);
      w_first = (last_start - age + interval) * inv_eventinterval;
      skyline_add_observation(bins, first, val, se1, w_first);
    }
  else
    {
      // start and stop are more than one boundary apart:
      // partial first bin, fully covered bins in between, partial last bin
      w_first = (-age + interval + first_start + eventinterval) * inv_eventinterval;
      skyline_add_observation(bins, first, val, se1, w_first);
      for(i = first + 1; i < last; i++)
        {
          skyline_add_observation(bins, i, val, se1, full_weight);
        }
      w_last = (age - last_start) * inv_eventinterval;
      skyline_add_observation(bins, last, val, se1, w_last);
    }
}


///
/// Allocate the skyline histogram containers for every locus.
///
/// Nothing is done unless both the migration-history option and the skyline
/// option are set. Each of the numpop2 parameters of a locus starts with a
/// single zeroed bin; calculate_expected_values() enlarges the arrays later.
/// According to the original note this has to be called AFTER setup_mighist().
/// The \a options argument is not used.
void
setup_expected_events (world_fmt * world, option_fmt * options)
{
    const long startsize = 1;
    MYREAL  binsize = world->options->eventbinsize; // in mutational units (=time scale)
    tetra **bins;
    long   *binnum;
    long    locus;
    long    p;
    long    b;
    long    k;

    if (!(world->options->mighist && world->options->skyline))
        return;

    for (locus = 0; locus < world->loci; locus++)
    {
        bins   = (tetra **) mycalloc (world->numpop2, sizeof (tetra *));
        binnum = (long *) mycalloc (world->numpop2, sizeof (long));
        world->mighistloci[locus].eventbinsize = binsize;
        world->mighistloci[locus].eventbins    = bins;
        world->mighistloci[locus].eventbinnum  = binnum;
        for (p = 0; p < world->numpop2; p++)
        {
            binnum[p] = startsize;
            bins[p]   = (tetra *) mycalloc (startsize, sizeof (tetra));
            for (b = 0; b < startsize; b++)
            {
                for (k = 0; k < 5; k++)
                    bins[p][b][k] = 0.;
            }
        }
    }
}

///
/// Free the skyline histogram containers of the loci 0 .. world->loci-1.
/// Nothing is done unless both the migration-history option and the skyline
/// option are set, mirroring setup_expected_events().
void
destroy_expected_events (world_fmt * world)
{
    long locus;
    long p;

    if (!(world->options->mighist && world->options->skyline))
        return;

    for (locus = 0; locus < world->loci; locus++)
    {
        // first the bins of every parameter, then the per-locus containers
        for (p = 0; p < world->numpop2; p++)
        {
            myfree(world->mighistloci[locus].eventbins[p]);
        }
        myfree(world->mighistloci[locus].eventbins);
        myfree(world->mighistloci[locus].eventbinnum);
    }
}


///
/// Print a human-readable table of the skyline histogram of one locus.
///
/// For each of the numpop*numpop parameters a header is written (Theta_i for
/// the first \a numpop indices, M_(from,to) otherwise), followed by one line per
/// bin with the age at the right side of the bin, the parameter value [0] and
/// the frequency [1]. Bins whose frequency is below SMALL_VALUE are left out; a
/// negative parameter value is reported through error().
void print_expected_values_list(FILE *file, long locus, tetra **eventbins, MYREAL eventbinsize, long *eventbinnum, long numpop)
{
  const long nparam = numpop * numpop;
  long   param;
  long   bin;
  long   source;
  long   target;
  MYREAL age;

  for(param = 0; param < nparam; param++)
    {
      if(param >= numpop)
        {
          m2mm(param,numpop,&source,&target);
          fprintf(file,"\nLocus: %li   Parameter: %s_(%li,%li)\n", locus+1, "M", source+1, target+1);
        }
      else
        {
          fprintf(file,"\nLocus: %li   Parameter: %s_%li\n", locus+1, "Theta",param+1);
        }
      fprintf(file,"Time        Parameter       Frequency of visit\n");
      fprintf(file,"----------------------------------------------\n");

      // the age is accumulated for every bin, printed or not
      age = 0.;
      for(bin = 0; bin < eventbinnum[param]; bin++)
        {
          age += eventbinsize;
          if(!(eventbins[param][bin][1] < SMALL_VALUE))
            {
              if(eventbins[param][bin][0] < 0)
                error("nono");
              fprintf(file,"%10.10f  %10.10f     %10.10f\n", age, eventbins[param][bin][0], eventbins[param][bin][1]);
            }
        }
    }
}

///
/// Write the raw skyline histogram of all loci and parameters to \a file.
///
/// The output starts with a commented header (bin width, program version, the
/// numbering of the parameters and the column legend). Then one line per bin is
/// written for every locus that is not skipped: locus, parameter number, bin
/// number, age at the center of the bin, and the bin fields [0], [1], [2] and
/// [4]. With more than one locus, the slot world->loci (the combined
/// histogram) is written as an additional locus.
void print_expected_values_tofile(FILE *file,  world_fmt *world)
{
  const long numpop  = world->numpop;
  const long numpop2 = world->numpop2;
  const long extra   = (world->loci > 1) ? 1 : 0; // one more "locus" for the sum over loci
  tetra **bins;
  long   *nbins;
  MYREAL  width;
  MYREAL  age;
  long    locus;
  long    param;
  long    b;
  long    source;
  long    target;

  // general header
  fprintf(file,"# Raw record of the skyline histogram for all parameters and all loci\n");
  fprintf(file,"# The time interval is set to %f\n", world->options->eventbinsize);
  fprintf(file,"# produced by the program %s (http://popgen.csit.fsu.edu/migrate.hml)\n",
          MIGRATEVERSION);
  fprintf(file,"# written by Peter Beerli 2006, Tallahassee,\n");
  fprintf(file,"# if you have problems with this file please email to beerli@scs.fsu.edu\n");
  fprintf(file,"#\n");

  // numbering of the parameters
  fprintf(file,"# Order of the parameters:\n");
  fprintf(file,"# Parameter-number Parameter\n");
  for(param = 0; param < numpop2; param++)
    {
      if(param >= numpop)
        {
          m2mm(param,numpop,&source,&target);
          fprintf(file,"# %6li    %s_(%li,%li)\n", param+1, (world->options->usem ? "M" : "xNm"), source+1, target+1);
        }
      else
        {
          fprintf(file,"# %6li    %s_%li\n", param+1, "Theta",param+1);
        }
    }

  // column legend
  fprintf(file,"#\n#----------------------------------------------------------------------------\n");
  fprintf(file,"# Locus Parameter-number Bin Age Parameter-value Parameter-Frequency \n");
  fprintf(file,"#        Standard-deviation Counts-per-bin\n");
  fprintf(file,"#----------------------------------------------------------------------------\n");
  fprintf(file,"# (*) values with -1 were NEVER visited\n");
  if(extra)
    {
      fprintf(file,"# Locus %li is sum over all loci, when there are more than 1 locus\n", world->loci+1);
    }

  // data lines
  for(locus = 0; locus < world->loci + extra; locus++)
    {
      if(world->data->skiploci[locus])
        continue;

      bins  = world->mighistloci[locus].eventbins;
      nbins = world->mighistloci[locus].eventbinnum;
      width = world->mighistloci[locus].eventbinsize;
      for(param = 0; param < numpop2; param++)
        {
          age = width / 2.;
          for(b = 0; b < nbins[param]; b++)
            {
              fprintf(file,"%li %li %li %10.10f %10.10f %10.10f %10.10f %10.0f\n",
                      locus+1, param+1, b+1, age, bins[param][b][0], bins[param][b][1],bins[param][b][2],bins[param][b][4]);
              age += width;
            }
        }
    }
}

///
/// Turn the raw skyline accumulators into per-bin summaries and, when there is
/// more than one locus, build the combined histogram in slot world->loci.
///
/// For every locus that is not skipped and every parameter, a bin with positive
/// weight is converted in place to
///   [0]  weighted average of the parameter: sum(val*weight)/sum(weight)
///   [1]  frequency: weight of the bin / weight of all bins of the parameter
///   [2]  sqrt(|sum(val^2*weight)/sum(weight) - average^2|)
///   [3]  sqrt(|(sum(weight^2) - (frequency/count)^2)/(count-1)|),
///        or 0 when the bin was visited only once
///   [4]  number of visits (unchanged)
/// and a bin without weight is reset to all zeros.
///
/// The combined histogram holds, per bin, the frequency-weighted averages of
/// [0], [2] and [3] over the loci, the summed frequencies in [1], and the
/// visits of the bin divided by the visits of all bins of the parameter in [4].
///
/// The conversion is done in place, so calling this function a second time on
/// the same data would normalize the values again.
void prepare_expected_values(world_fmt *world)
{
  MYREAL sum;
  long locus;
  long pop;
  long i;
  long numpop2 = world->numpop2;
  tetra **eventbins = NULL;
  tetra **eventbins_all;
  long *eventbinnum;
  long eventbinnum_allmax=0;
  float *count;
  float weight_average;
  count = (float *) mycalloc(world->numpop2,sizeof(float));

  for(locus=0; locus < world->loci; locus++)
    {
      if(!world->data->skiploci[locus])
	{
	  eventbins =  world->mighistloci[locus].eventbins;
	  eventbinnum = world->mighistloci[locus].eventbinnum;
	  for(pop = 0; pop < numpop2; pop++)
	    {  
	      sum = (MYREAL) 0.0;
	      
	      // the combined histogram needs room for the longest per-locus histogram
	      if(eventbinnum[pop] > eventbinnum_allmax)
		eventbinnum_allmax = eventbinnum[pop];
	      
	      for(i = 0; i < eventbinnum[pop]; i++)
		{
		  if(eventbins[pop][i][1] <= 0.0)
		    {
		      // never visited: make sure all fields are clean zeros
		      eventbins[pop][i][0] = 0.0;
		      eventbins[pop][i][1] = 0.0;
		      eventbins[pop][i][2] = 0.0;
		      eventbins[pop][i][3] = 0.0;
		      eventbins[pop][i][4] = 0.0;
		    }
		  else
		    {
		      eventbins[pop][i][0] /= eventbins[pop][i][1]; //average: sum(val * weight)/sum(weight)
		      // balancing the weights sum(val^2 weights)/ sum(weights)
		      eventbins[pop][i][2] /= eventbins[pop][i][1]; // ---> val^2
		      // E(val^2) - (E(val)^2)
		      eventbins[pop][i][2] -= (eventbins[pop][i][0] * eventbins[pop][i][0]);
		      // fabs() protects against small negative values caused by rounding
		      eventbins[pop][i][2] = sqrt(fabs(eventbins[pop][i][2]));
		    }		  
		  sum += eventbins[pop][i][1];
		}
	      for(i = 0; i < eventbinnum[pop]; i++)
		{
		  if(eventbins[pop][i][1] <= 0.0)
		    continue;
		  eventbins[pop][i][1] /= sum; // calculate frequency: sum(weight_i)/sum(sum(weight_i)_j)
		  // sum(weight)/counts
		  weight_average = eventbins[pop][i][1] / eventbins[pop][i][4];
		  // sum(weight^2) - waverage^2
		  eventbins[pop][i][3] -= weight_average * weight_average;
		  if(eventbins[pop][i][4] > 1.0)
		    {
		      eventbins[pop][i][3] = sqrt(fabs(eventbins[pop][i][3]/(eventbins[pop][i][4]-1.)));
		    }
		  else
		    {
		      // a single visit has no spread; dividing by (count - 1) = 0 would
		      // store inf/NaN and contaminate the combined histogram below
		      eventbins[pop][i][3] = 0.0;
		    }
		}
	    } 
	}
    }
  if(world->loci>1)
    {  
      // containers of the combined histogram are created on first use
      if(world->mighistloci[world->loci].eventbins == NULL)
	world->mighistloci[world->loci].eventbins = (tetra **) mycalloc(world->numpop2,sizeof(tetra *));
      if(world->mighistloci[world->loci].eventbinnum == NULL)
	world->mighistloci[world->loci].eventbinnum = (long *) mycalloc(world->numpop2,sizeof(long));
      world->mighistloci[world->loci].eventbinsize = world->mighistloci[0].eventbinsize;
      for(pop=0; pop< world->numpop2 ; pop++)    
	{
	  world->mighistloci[world->loci].eventbins[pop] = (tetra *) mycalloc(eventbinnum_allmax,sizeof(tetra));
	  world->mighistloci[world->loci].eventbinnum[pop] = eventbinnum_allmax;  
	}
      eventbins_all = world->mighistloci[world->loci].eventbins;
      // accumulate the per-locus summaries, weighted by the per-locus frequency
      for (locus = 0; locus < world->loci; locus++)
	{
	  if(!world->data->skiploci[locus])
	    {
	      eventbinnum = world->mighistloci[locus].eventbinnum;
	      eventbins = world->mighistloci[locus].eventbins;
	      for(pop=0; pop< world->numpop2 ; pop++)
		{
		  for(i=0 ; i < eventbinnum[pop]; i++)
		    {
		      if(eventbins[pop][i][1] > 0.0)
			{
			  eventbins_all[pop][i][0] += eventbins[pop][i][0] * eventbins[pop][i][1];
			  eventbins_all[pop][i][1] += eventbins[pop][i][1];
			  eventbins_all[pop][i][2] += eventbins[pop][i][2] * eventbins[pop][i][1];
			  eventbins_all[pop][i][3] += eventbins[pop][i][3] * eventbins[pop][i][1];
			  eventbins_all[pop][i][4] += eventbins[pop][i][4];
			}
		      // total number of visits of this parameter over all loci and bins
		      count[pop] += eventbins[pop][i][4];
		    }
		}
	    }
	}
      // turn the weighted sums into weighted averages
      for(pop=0; pop< world->numpop2 ; pop++)
	{
	  for(i=0 ; i < eventbinnum_allmax; i++)
	    {
	      if(eventbins_all[pop][i][1] > 0.0)
		{
		  eventbins_all[pop][i][0] /= eventbins_all[pop][i][1];//average
		  eventbins_all[pop][i][2] /= eventbins_all[pop][i][1];//average standard deviation
		  eventbins_all[pop][i][3] /= eventbins_all[pop][i][1];//average weight
		  eventbins_all[pop][i][4] /= count[pop];
		}
	    }
	}  
    } 
  myfree(count);
}
  
///
/// Print the section title for the skyline results to \a file, pointing the
/// reader to the PDF and skyline files. Nothing is printed when \a progress is
/// false.
void print_expected_values_title(FILE *file, boolean progress)
{
  if(!progress)
    return;

  fprintf(file,"\n\nParameter changes over time\n");
  fprintf(file,"---------------------------\nSEE IN PDF FILE AND SKYLINE FILE\n");
}

///
/// Produce all skyline output; does nothing unless the skyline option is set.
///
/// The histograms are summarized with prepare_expected_values(), written to the
/// skyline file, and a title is printed to the screen (only with the progress
/// option) and to the ASCII output file. The per-locus listing through
/// print_expected_values_list() is currently disabled. When compiled with
/// PRETTY, the PDF skyline histogram is drawn as well, and a second time with
/// the last argument TRUE for Bayesian inference.
void print_expected_values(world_fmt * world)
{
  if(!world->options->skyline)
    return;

  // prepare skyline histogram for printing
  prepare_expected_values(world);

  // raw values go to the skyline file
  print_expected_values_tofile(world->skylinefile, world);

  // title on screen and in the ascii-outfile
  print_expected_values_title(stdout, world->options->progress);
  print_expected_values_title(world->outfile, TRUE);

#ifdef PRETTY
  pdf_skyline_histogram(world->loci, world->numpop2,  world, FALSE);
  if(world->options->bayes_infer)
    pdf_skyline_histogram(world->loci, world->numpop2,  world, TRUE);
#endif
}


///
/// Debugging aid: dump the skyline histogram of all loci that are not skipped
/// to stdout, framed by two marker lines that carry myID and \a text.
///
/// Each line shows locus, parameter number, bin number, the age at the center
/// of the bin, the bin fields [0] and [1], and their ratio [0]/[1]. The ratio is
/// computed without checking for an empty bin.
void debug_skyline(world_fmt *world, char text[])
{
  FILE   *out = stdout;
  const long numpop2 = world->numpop2;
  tetra **bins;
  long   *nbins;
  MYREAL  width;
  MYREAL  age;
  long    locus;
  long    param;
  long    b;

  fprintf(out,"#%i -------- %s ----------\n",myID,text);
  fprintf(out,"#\n#----------------------------------------------------------------------------\n");
  fprintf(out,"# Locus Parameter-number Bin Age Parameter-value(*) Parameter-Frequency(*)\n");
  fprintf(out,"#----------------------------------------------------------------------------\n");
  fprintf(out,"# (*) values with -1 were NEVER visited\n");

  for(locus = 0; locus < world->loci; locus++)
    {
      if(world->data->skiploci[locus])
        continue;

      bins  = world->mighistloci[locus].eventbins;
      nbins = world->mighistloci[locus].eventbinnum;
      width = world->mighistloci[locus].eventbinsize;
      for(param = 0; param < numpop2; param++)
        {
          age = width / 2.;
          for(b = 0; b < nbins[param]; b++)
            {
              fprintf(out,"%li %li %li %10.10f %10.10f %10.10f %10.10f\n",
                      locus+1, param+1, b+1, age, bins[param][b][0], bins[param][b][1],bins[param][b][0]/ bins[param][b][1]);
              age += width;
            }
        }
    }
  fprintf(out,"#%i >>>>>>>>>>>>>>> %s <<<<<<<<<<<<<<<<<END\n",myID,text);
}
