/*------------------------------------------------------
Maximum likelihood estimation
of migration rate  and effective population size
using a Metropolis-Hastings Monte Carlo algorithm
-------------------------------------------------------
 Bayesian   R O U T I N E S
 
 Peter Beerli 2003, Seattle
 beerli@csit.fsu.edu
 
 Copyright 2001-2003 Peter Beerli
 
 This software is distributed free of charge for non-commercial use
 and is copyrighted. Of course, we do not guarantee that the software
 works and are not responsible for any damage you may cause or have.
 
 
 $Id: bayes.c,v 1.18 2003/11/04 16:12:08 beerli Exp $
*/

#include "bayes.h"
#include "random.h"
#ifdef BAYESUPDATE
// Defined outside this file; bayes_update() calls it before every evaluation
// of probg_treetimes().
extern void precalc_world(world_fmt *world);

// Acceptance test for a proposed parameter value; heat is 1/temperature.
boolean bayes_accept (double newval, double oldval, double heat);
// Writes the acceptance table for all parameters and the genealogies to file.
void bayes_print_accept(FILE * file,  world_fmt *world);


// Proposal drawn uniformly from [param - delta, param + delta], reflected at
// the boundary that was crossed.
double propose_uni_newparam (double param,
                             double delta, double minparam, double maxparam);
// Proposal drawn as -log(u) * (maxparam - minparam)/2; param and delta are
// not used by the implementation.
double propose_exp_newparam (double param,
                             double delta, double minparam, double maxparam);

// One step of the parameter-update cycle; returns -1 (skipped), 0 or 1.
long bayes_update (world_fmt * world);

// Summed rate term of one time interval (see definition below).
double probWait(vtlist *tlp, world_fmt *world, long numpop);

// calculate prob(g|param) from world->treetimes
//
double probg_treetimes(world_fmt *world)
{
    long i;
    vtlist *tl = world->treetimes->tl;
    double deltatime = tl[0].age;
    double sumprob = 0.;
    double eventprob=0.;
    eventprob = (tl[0].from==tl[0].to) ? (LOG2 - log(world->param0[tl[0].from]))
                : log(world->param0[mm2m(tl[0].from,tl[0].to, world->numpop)]);
    sumprob = -deltatime * probWait(&tl[0], world, world->numpop) + eventprob;

    for(i=1; i<world->treetimes->T;i++)
    {
        deltatime = (tl[i].age - tl[i-1].age);
        eventprob = (tl[i].from==tl[i].to) ? (LOG2 - log(world->param0[tl[i].from]))
                    : log(world->param0[mm2m(tl[i].from,tl[i].to, world->numpop)]);
        sumprob += -deltatime * probWait(&tl[i], world, world->numpop) + eventprob;
    }
    return sumprob;
}

// Summed rate term for one time-list entry.
//
// For every population j with k = tlp->lineages[j] lineages the function adds
//   k * (sum of param0[mstart[j] .. mend[j]-1])   and
//   k * (k - 1) * (1 / param0[j])
// and returns the total over the first numpop populations.
double probWait(vtlist *tlp, world_fmt *world, long numpop)
{
    double theta_part = 0.;
    double mig_part = 0.;
    long pop;

    for (pop = 0; pop < numpop; pop++)
    {
        double k = tlp->lineages[pop];
        double inflow = 0.0;
        long slot = world->mstart[pop];

        while (slot < world->mend[pop])
            inflow += world->param0[slot++];

        mig_part += k * inflow;
        theta_part += k * (k - 1.) * (1. / world->param0[pop]);
    }
    return theta_part + mig_part;
}

// Decide whether a parameter update is accepted.
//
// newval and oldval are compared as a difference scaled by heat
// (heat = 1/temperature). A positive scaled difference is accepted without
// drawing a random number; otherwise the update is accepted when
// log(RANDUM()) is smaller than the scaled difference.
boolean
bayes_accept (double newval, double oldval, double heat)
{
  const double scaled_gain = (newval - oldval) * heat;

  // the || short-circuit keeps RANDUM() from being called for improvements
  if (scaled_gain > 0.0 || log (RANDUM ()) < scaled_gain)
    return TRUE;
  return FALSE;
}

//============================================================
// Uniform (flat) proposal; coding based on Rasmus Nielsen, verified with
// Mathematica. The correlation among successive values depends on delta --
// what is a good delta?
//
// One random draw decides both direction and step length: the new value lies
// within delta above or below param. A value above maxparam (upward step) or
// below minparam (downward step) is reflected at that boundary.
double
propose_uni_newparam (double param,
                  double delta, double minparam, double maxparam)
{
    const double draw = 2. * RANDUM ();
    double candidate;

    if (draw > 1.0)
      {
        // upper half of the draw: step upwards by (draw - 1) * delta
        candidate = param + (draw - 1.) * delta;
        return (candidate > maxparam) ? 2. * maxparam - candidate : candidate;
      }

    // lower half of the draw: step downwards by draw * delta
    candidate = param - draw * delta;
    return (candidate < minparam) ? 2. * minparam - candidate : candidate;
}

// Exponential proposal: returns -log(RANDUM()) scaled by half the width of
// [minparam, maxparam]. The current value (param) and delta do not enter the
// calculation, and the result is not restricted to the interval.
double
propose_exp_newparam (double param,
                  double delta, double minparam, double maxparam)
{
    const double half_range = (maxparam - minparam)/2.;

    (void) param;   // part of the common proposal signature, unused here
    (void) delta;
    return -log(RANDUM()) * half_range;
}

// Print one trace line for a proposal to stdout: the marker, up to the first
// four parameters in brackets, the log-probability difference, the new and
// old log-probabilities, world->likelihood[world->G] and the parameter index.
static void
bayes_update_trace (const char *marker, world_fmt * world,
                    double newval, double oldval, long which)
{
    long k;
    // param0 holds numpop2 parameters; never read past its end
    long shown = (world->numpop2 < 4) ? world->numpop2 : 4;

    printf ("%s[", marker);
    for (k = 0; k < shown; k++)
    {
        if (k > 0)
            putchar (' ');
        printf ("%f", world->param0[k]);
    }
    printf ("] (%f) %f %f %f %li\n", newval - oldval, newval, oldval,
            world->likelihood[world->G], which);
}

// One step of the parameter-update cycle.
//
// Only every updateratio-th call does any work; the other calls return -1.
// On a working call the parameter with index bayes->paramnum is updated
// unless its custm2 code is 'c' or '0': a uniform proposal is drawn,
// probg_treetimes() is evaluated for the old and the new value and
// bayes_accept() decides. The trial counter of the parameter is incremented
// and paramnum advances cyclically through 0 .. numpop2-1 in either case.
//
// Returns -1 when the call was skipped, 1 when the proposal was accepted and
// 0 when it was rejected or the parameter is not updated.
long
bayes_update (world_fmt * world)
{
    long ba = 0;
    long i = world->bayes->paramnum;
    long ratio = (long) world->options->updateratio;
    double oldparam;
    double oldval;
    double newval;
    double newparam;

    if (ratio < 1)
        ratio = 1;              // avoid a modulo by zero for ratios below 1
    if (((1 + world->bayes->count++) % ratio) != 0)
        return -1L;

    if (!strchr ("c0", world->options->custm2[i]))
    {
        newparam = propose_uni_newparam (world->param0[i],
                                         world->bayes->delta[i],
                                         world->bayes->minparam[i],
                                         world->bayes->maxparam[i]);
        precalc_world (world);
        oldval = probg_treetimes (world);
        oldparam = world->param0[i];

        world->param0[i] = newparam;
        precalc_world (world);
        newval = probg_treetimes (world);

        if (bayes_accept (newval, oldval, world->heat))
        {
            if (myID == MASTER && world->heat == 1.0)
                bayes_update_trace ("***", world, newval, oldval, i);
            world->bayes->oldval = newval;
            ba = 1;
            world->bayes->accept[i] += 1;
        }
        else
        {
            if (myID == MASTER && world->heat == 1.0)
                bayes_update_trace ("   ", world, newval, oldval, i);
            world->param0[i] = oldparam;
            world->bayes->oldval = oldval;
            ba = 0;
        }
        // Re-run the precalculation for the parameter set that is kept.
        // NOTE(review): presumably precalc_world() caches quantities derived
        // from param0; without this call a rejected proposal would leave
        // them computed from the discarded value -- confirm.
        precalc_world (world);
    }
    world->bayes->paramnum++;
    world->bayes->trials[i] += 1;
    if (world->bayes->paramnum >= world->numpop2)   // does not work with gamma deviated mutation rates yet
        world->bayes->paramnum = 0;                 // reset the parameter choosing cycle
    return ba;
}

// Append the current state as one sample to world->bayes->params.
//
// Layout of a sample row: element 0 holds bayes->oldval plus
// world->likelihood[world->G]; elements 1 .. numpop2 hold a copy of param0.
// When the last allocated row has been used the table grows by 1000 rows.
// Running out of memory while growing terminates the program, because the
// next call would otherwise write past the end of the table.
void bayes_save(world_fmt *world)
{
    long i;
    long pnum = world->bayes->numparams;
    long allocparams = world->bayes->allocparams;

    world->bayes->params[pnum][0] = world->bayes->oldval + world->likelihood[world->G];
    memcpy(&world->bayes->params[pnum][1], world->param0, sizeof(double) * world->numpop2);

    pnum++;
    if (pnum >= allocparams)
    {
        double **grown;

        allocparams += 1000;
        // keep the old pointer until realloc has succeeded
        grown = (double **) realloc(world->bayes->params, sizeof(double *) * allocparams);
        if (grown == NULL)
        {
            fprintf(stderr, "bayes_save: out of memory while growing the sample table\n");
            exit(EXIT_FAILURE);
        }
        world->bayes->params = grown;
        for (i = pnum; i < allocparams; i++)
        {
            world->bayes->params[i] = (double *) calloc(world->numpop2 + 1, sizeof(double));
            if (world->bayes->params[i] == NULL)
            {
                fprintf(stderr, "bayes_save: out of memory while allocating sample rows\n");
                exit(EXIT_FAILURE);
            }
        }
    }
    world->bayes->numparams = pnum;
    world->bayes->allocparams = allocparams;
}

// calloc() wrapper for bayes_init(): terminates the program when a non-empty
// request cannot be satisfied, so the caller never stores a NULL array.
static void *
bayes_init_alloc(size_t count, size_t elsize)
{
    void *mem = calloc(count, elsize);

    // calloc may legitimately return NULL for a zero-sized request
    if (mem == NULL && count != 0 && elsize != 0)
    {
        fprintf(stderr, "bayes_init: out of memory\n");
        exit(EXIT_FAILURE);
    }
    return mem;
}

// Initialize a bayes_fmt record for `size` parameters.
//
// Resets the sample and parameter counters, allocates zeroed delta, minparam
// and maxparam arrays of `size` elements, zeroed accept and trials arrays of
// `size`+1 elements, and a sample table with a single row of `size`+1
// elements (allocparams = 1).
void bayes_init(bayes_fmt *bayes, long size)
{
//    bayes->oldval = -DBL_MAX;
    bayes->allocparams = 1;
    bayes->numparams = 0;
    bayes->paramnum = 0;
    bayes->delta = bayes_init_alloc(size, sizeof(double));
    bayes->minparam = bayes_init_alloc(size, sizeof(double));
    bayes->maxparam = bayes_init_alloc(size, sizeof(double));
    bayes->params = bayes_init_alloc(1, sizeof(double *));
    bayes->params[0] = bayes_init_alloc(size + 1, sizeof(double));

    bayes->accept = bayes_init_alloc(size + 1, sizeof(double));
    bayes->trials = bayes_init_alloc(size + 1, sizeof(double));
}

// Set proposal width and bounds for every parameter from the start values
// in options.
//
// Parameters 0 .. numpop-1 take options->thetag[i] as delta, 1e-8 as lower
// bound and three times thetag[i] as upper bound. Parameters
// numpop .. numpop2-1 take options->mg[i - numpop] as delta, 0.0001 as lower
// bound and three times that mg value as upper bound.
void bayes_fill(world_fmt *world, option_fmt *options)
{
    bayes_fmt *bayes = world->bayes;
    const long npop = world->numpop;
    const long nmig = world->numpop2 - npop;
    long pop;
    long mig;

    for (pop = 0; pop < npop; pop++)
    {
        bayes->minparam[pop] = 0.00000001;                  // debug was SMALLEST_THETA
        bayes->delta[pop] = options->thetag[pop];           // debug was 0.01
        bayes->maxparam[pop] = 3. * options->thetag[pop];   // debug was BIGGEST_THETA
    }

    for (mig = 0; mig < nmig; mig++)
    {
        const long slot = npop + mig;

        bayes->minparam[slot] = 0.0001;                     // debug was SMALLEST_MIGRATION
        bayes->delta[slot] = options->mg[mig];
        bayes->maxparam[slot] = 3. * options->mg[mig];      // debug was BIGGEST_MIGRATION
    }
}

// Release everything bayes_init() and bayes_save() allocated inside
// world->bayes. The bayes record itself is not freed.
//
// The sample table has rows 0 .. allocparams-1: bayes_init() allocates row 0
// and bayes_save() allocates every row it adds. Pointers are cleared and the
// counters reset afterwards so that a repeated call is harmless.
void bayes_free(world_fmt *world)
{
    long i;
    bayes_fmt *bayes = world->bayes;

    free(bayes->delta);
    free(bayes->accept);
    free(bayes->trials);
    free(bayes->minparam);
    free(bayes->maxparam);
    bayes->delta = NULL;
    bayes->accept = NULL;
    bayes->trials = NULL;
    bayes->minparam = NULL;
    bayes->maxparam = NULL;

    if (bayes->params != NULL)
    {
        for (i = 0; i < bayes->allocparams; i++)
            free(bayes->params[i]);
        free(bayes->params);
        bayes->params = NULL;
    }
    bayes->allocparams = 0;
    bayes->numparams = 0;
    //   free(world->bayes);
}

// Write all saved samples to world->bayesfile and a summary table to
// world->outfile, then print the acceptance ratios.
//
// bayesfile receives one line per sample: the stored log-probability value
// followed by the numpop2 parameters. The summary table lists, for each
// parameter, the 2.5%, 50% and 97.5% quantiles, the value in the sample with
// the highest stored log-probability ("mode"), the mean and the sample
// standard deviation.
//
// Only the numpop2 parameters that bayes_save() stores are reported; a sample
// row has numpop2+1 elements and holds nothing for additional parameters
// (gamma shape, fluctuation), so reading further columns would run past the
// end of the row.
void bayes_stat(world_fmt *world)
{
  long frompop = 0;
  long topop = 0;
  long i, j;
  bayes_fmt *bayes = world->bayes;
  long nsamples = bayes->numparams;
  long size = world->numpop2;
  long mode = 0;
  int have_table = (nsamples > 0 && size > 0);
  double temp;
  double mean;
  double dev;
  double sumdev;
  double stddev;
  double modeval;
  double **rawstat = NULL;
  double *averages = NULL;
  char stemp[64];

  if (have_table)
    {
      averages = calloc((size_t) size, sizeof(double));
      rawstat = calloc((size_t) size, sizeof(double *));
      if (rawstat != NULL)
        rawstat[0] = calloc((size_t) size * (size_t) nsamples, sizeof(double));
      if (averages == NULL || rawstat == NULL || rawstat[0] == NULL)
        {
          fprintf(stderr, "bayes_stat: out of memory, summary table skipped\n");
          if (rawstat != NULL)
            free(rawstat[0]);
          free(rawstat);
          free(averages);
          rawstat = NULL;
          averages = NULL;
          have_table = 0;
        }
      else
        {
          // one contiguous column of nsamples values per parameter
          for (j = 1; j < size; j++)
            rawstat[j] = rawstat[0] + j * nsamples;
        }
    }

  // index of the sample with the highest stored log-probability
  for (i = 1; i < nsamples; i++)
    {
      if (bayes->params[i][0] > bayes->params[mode][0])
        mode = i;
    }

  // dump the raw samples and collect the columns for the summary
  for (i = 0; i < nsamples; i++)
    {
      FPRINTF(world->bayesfile, "%f ", bayes->params[i][0]);
      for (j = 0; j < size; j++)
        {
          temp = bayes->params[i][j+1];
          FPRINTF(world->bayesfile, "%f ", temp);
          if (have_table)
            {
              averages[j] += temp;
              rawstat[j][i] = temp;
            }
        }
      FPRINTF(world->bayesfile, "\n");
    }

  FPRINTF(world->outfile,"\n\n\nBayesian estimates\n");
  FPRINTF(world->outfile,"==================\n\n");
  FPRINTF(world->outfile,"Parameter  2.5%%     median    97.5%%     mode     mean      std\n");
  FPRINTF(world->outfile,"--------------------------------------------------------------\n");

  for (j = 0; have_table && j < size; j++)
    {
      mean = averages[j] / nsamples;

      // sample standard deviation from the squared deviations about the mean
      stddev = 0.;
      if (nsamples > 1)
        {
          sumdev = 0.;
          for (i = 0; i < nsamples; i++)
            {
              dev = rawstat[j][i] - mean;
              sumdev += dev * dev;
            }
          stddev = sqrt(sumdev / (nsamples - 1));
        }

      // `mode` indexes the samples in their original order, so the value is
      // taken from the sample table and not from the sorted column
      modeval = bayes->params[mode][j+1];

      qsort(rawstat[j], (size_t) nsamples, sizeof(double), numcmp);
      if (j < world->numpop)
        {
          FPRINTF(world->outfile,"Theta_%-3li",j+1);
          FPRINTF(world->outfile, "%8.5f %8.5f %8.5f %8.5f %8.5f %8.5f\n",
                  rawstat[j][(long)(nsamples * 0.025)],
                  rawstat[j][(long)(nsamples * 0.5)],
                  rawstat[j][(long)(nsamples * 0.975)],
                  modeval,
                  mean,
                  stddev);
        }
      else
        {
          topop = (long) (j-world->numpop)/(world->numpop-1);
          frompop = j - world->numpop - topop * (world->numpop - 1);
          if(frompop>=topop)
            frompop += 1;
          // bounded write: the label grows with the population numbers
          snprintf(stemp, sizeof stemp, "M_%li->%li", frompop+1, topop+1);
          FPRINTF(world->outfile, "%-9.9s%8.1f %8.1f %8.1f %8.1f %8.1f %8.1f\n",
                  stemp, rawstat[j][(long)(nsamples * 0.025)],
                  rawstat[j][(long)(nsamples * 0.5)],
                  rawstat[j][(long)(nsamples * 0.975)],
                  modeval,
                  mean,
                  stddev);
        }
    }

  // print out the acceptance ratios for every parameter and the tree
  if(world->options->progress)
    {
      bayes_print_accept(stdout,world);
    }
  bayes_print_accept(world->outfile,world);

  if (rawstat != NULL)
    free(rawstat[0]);
  free(rawstat);
  free(averages);
}


// Write the acceptance table to file: one line per parameter with accepted
// changes, trials and their ratio, followed by one line for the genealogies.
//
// The counters are converted explicitly to long for the %li conversions and
// to double for the ratio, so the call is well-defined whichever arithmetic
// type the accept/trials arrays have. A ratio with zero trials is printed
// as 0.
void
bayes_print_accept(FILE * file,  world_fmt *world)
{
  long j, topop, frompop;
  char stemp[64];
  long tc = world->numpop2; //this ignores alpha among multiple loci
  bayes_fmt *bayes = world->bayes;
  double accepted;
  double trials;
  double ratio;

  FPRINTF(file,"\n\n\nAcceptance ratios for all parameters and the genealogies\n");
  FPRINTF(file,"--------------------------------------------------------\n\n");
  FPRINTF(file,"Parameter           Accepted changes            Ratio\n");
  for(j=0; j < world->numpop2; j++)
    {
      accepted = (double) bayes->accept[j];
      trials = (double) bayes->trials[j];
      ratio = (trials > 0.) ? accepted / trials : 0.;
      if(j < world->numpop)
        {
          FPRINTF(file,"Theta_%-3li",j+1);
          FPRINTF(file, "            %8li/%-8li         %8.5f\n",
                  (long) bayes->accept[j], (long) bayes->trials[j], ratio);
        }
      else
        {
          topop = (long) (j-world->numpop)/(world->numpop-1);
          frompop = j - world->numpop - topop * (world->numpop - 1);
          if(frompop>=topop)
            frompop += 1;
          // bounded write into a local buffer; only 9 characters are printed
          snprintf(stemp, sizeof stemp, "M_%li->%li", frompop+1, topop+1);
          FPRINTF(file, "%-9.9s            %8li/%-8li         %8.5f\n", stemp,
                  (long) bayes->accept[j], (long) bayes->trials[j], ratio);
        }
    }

  // accepted trees: the number of genealogy trials is derived from the run
  // length and the update ratio
  accepted = (double) bayes->accept[tc];
  trials = (world->options->lincr * world->options->lsteps)
           * (1. - 1./world->options->updateratio);
  ratio = (trials > 0.) ? accepted / trials : 0.;
  FPRINTF(file,"Genealogies");
  FPRINTF(file, "          %8li/%-8li         %8.5f\n",
          (long) bayes->accept[tc], (long) trials, ratio);
}

#endif /*bayesupdate*/

