/*------------------------------------------------------
Maximum likelihood estimation
of migration rate and effective population size
using a Metropolis-Hastings Monte Carlo algorithm
-------------------------------------------------------
 Bayesian   R O U T I N E S
 
 Peter Beerli 2001, Seattle
 beerli@gs.washington.edu
 
 Copyright 2001-2002 Peter Beerli
 
 This software is distributed free of charge for non-commercial use
 and is copyrighted. Of course, we do not guarantee that the software
 works and are not responsible for any damage you may cause or have.
 
 
 $Id: bayes.c,v 1.14 2002/11/27 16:17:13 beerli Exp $
*/

#include "bayes.h"
#include "random.h"

#include <stdio.h>
#include <stdlib.h>
#ifdef BAYESUPDATE
// Declared extern; no definition is visible in this file.
// bayes_update() calls it after a proposed parameter value has been accepted.
extern void precalc_world(world_fmt *world);

// Acceptance test for a proposal: compares two log-probabilities,
// with their difference scaled by heat.
boolean bayes_accept (double newval, double oldval, double heat);

// Draws a new value within delta of param, reflected once at
// minparam / maxparam.
double propose_newparam (double param,
                         double delta, double minparam, double maxparam);

// Attempts an update of one parameter of world->param0;
// returns 1 if the proposal was accepted, 0 otherwise.
long bayes_update (world_fmt * world);

// Rate term used by probg_treetimes() for one entry of the time list.
double probWait(vtlist *tlp, world_fmt *world, long numpop);

// Log of prob(g | param), accumulated over the entries of
// world->treetimes->tl using the current values in world->param0.
//
// Each entry contributes
//     -(interval length) * probWait(entry) + log(event term)
// where the interval length is the entry's age minus the age of the
// previous entry (the first entry is measured from zero), and the event
// term is
//     2 / param0[from]                 when from == to
//     param0[mm2m(from, to, numpop)]   otherwise.
// NOTE(review): from == to presumably marks a coalescence and from != to a
// migration event -- confirm against the code that fills the time list.
// The first entry is always read, so the list must not be empty.
double probg_treetimes(world_fmt *world)
{
    vtlist *events = world->treetimes->tl;
    vtlist *cur;
    double interval;
    double logevent;
    double total;
    long k;

    // first entry: its interval starts at time zero
    cur = &events[0];
    interval = cur->age;
    if (cur->from == cur->to)
        logevent = LOG2 - log(world->param0[cur->from]);
    else
        logevent = log(world->param0[mm2m(cur->from, cur->to, world->numpop)]);
    total = -interval * probWait(cur, world, world->numpop) + logevent;

    // remaining entries: interval is the gap to the previous entry
    for (k = 1; k < world->treetimes->T; k++)
    {
        cur = &events[k];
        interval = (cur->age - events[k - 1].age);
        if (cur->from == cur->to)
            logevent = LOG2 - log(world->param0[cur->from]);
        else
            logevent = log(world->param0[mm2m(cur->from, cur->to, world->numpop)]);
        total += -interval * probWait(cur, world, world->numpop) + logevent;
    }
    return total;
}

// Rate term for one time-list entry.
//
// For every population p in [0, numpop) with k = tlp->lineages[p] it adds
//     k * (sum of world->param0[z] for z in [mstart[p], mend[p]))
//     k * (k - 1) / world->param0[p]
// and returns the grand total of both contributions.
// NOTE(review): the first term looks like the total migration rate and the
// second like the coalescence rate -- confirm with the model description.
double probWait(vtlist *tlp, world_fmt *world, long numpop)
{
    double migterm = 0.;
    double coalterm = 0.;
    long pop;

    for (pop = 0; pop < numpop; pop++)
    {
        double k = tlp->lineages[pop];
        double ratesum = 0.0;
        long idx = world->mstart[pop];

        while (idx < world->mend[pop])
        {
            ratesum += world->param0[idx];
            idx++;
        }
        migterm += k * ratesum;
        coalterm += k * (k - 1.) / world->param0[pop];
    }
    return coalterm + migterm;
}

// Decide whether a parameter update is accepted.
//
// newval, oldval: log-probabilities of the proposed and the current state.
// heat:           multiplier applied to their difference (1/temperature).
//
// Returns TRUE at once when the scaled difference is positive; otherwise
// draws one random number and returns TRUE only if its logarithm is
// smaller than the scaled difference.
boolean
bayes_accept (double newval, double oldval, double heat)
{
    const double scaled = (newval - oldval) * heat;

    if (scaled > 0.0)
        return TRUE;
    // RANDUM() is consulted only for proposals that are not uphill
    return (log (RANDUM ()) < scaled) ? TRUE : FALSE;
}

// Propose a new parameter value near param.
//
// A single random number u decides both direction and step size:
//   u >  0.5 : step up   by (2u - 1) * delta
//   u <= 0.5 : step down by  2u      * delta
// A proposal that overshoots maxparam (upward step) or undershoots
// minparam (downward step) is reflected once at that bound.
double
propose_newparam (double param,
                  double delta, double minparam, double maxparam)
{
    const double u = RANDUM ();
    double candidate;

    if (u > 0.5)
    {
        candidate = param + (2. * u - 1.) * delta;
        return (candidate > maxparam) ? 2. * maxparam - candidate : candidate;
    }
    candidate = param - 2. * u * delta;
    return (candidate < minparam) ? 2. * minparam - candidate : candidate;
}

// Try to update the parameter selected by world->bayes->paramnum.
//
// The call counter world->bayes->count is incremented on every call; work
// is only done when (count + 1) is a multiple of options->updateratio.
// Parameters whose custm2 code is 'c' or '0' are skipped without a
// proposal. For the others a new value is proposed, prob(g|param) is
// evaluated before and after the change, and bayes_accept() decides
// whether the new value stays in world->param0. In both outcomes
// world->bayes->oldval records the log-probability of the retained state.
// Afterwards paramnum advances, wrapping to 0 at world->numpop2.
//
// Returns 1 if a proposal was accepted, 0 otherwise.
long
bayes_update (world_fmt * world)
{
    bayes_fmt *bayes = world->bayes;
    const long which = bayes->paramnum;
    long accepted = 0;

    if (((1 + bayes->count++) % ((long) world->options->updateratio)) != 0)
        return 0;

    if (strchr ("c0", world->options->custm2[bayes->paramnum]) == NULL)
    {
        const double current = world->param0[which];
        const double proposal = propose_newparam (current,
                                                  bayes->delta[which],
                                                  bayes->minparam[which],
                                                  bayes->maxparam[which]);
        const double logp_current = probg_treetimes (world);
        double logp_proposal;

        // evaluate the genealogy under the proposed value
        world->param0[which] = proposal;
        logp_proposal = probg_treetimes (world);

        if (!bayes_accept (logp_proposal, logp_current, world->heat))
        {
            // rejected: restore the previous value
            world->param0[which] = current;
            bayes->oldval = logp_current;
            accepted = 0;
        }
        else
        {
            bayes->oldval = logp_proposal;
            precalc_world (world);
            accepted = 1;
        }
    }

    // move on to the next parameter; does not work with gamma deviated
    // mutation rates yet
    bayes->paramnum++;
    if (bayes->paramnum >= world->numpop2)
        bayes->paramnum = 0;    // restart the parameter choosing cycle
    return accepted;
}

// Record the current state as one sample in world->bayes->params.
//
// Row layout: element 0 holds world->bayes->oldval plus
// world->likelihood[world->G]; elements 1..numpop2 hold a copy of
// world->param0. With options->verbose set, the sample is also echoed to
// stdout on a single line.
//
// When the table is full it grows by 1000 rows. Running out of memory
// while growing is reported on stderr and terminates the program, because
// the sampler cannot continue without storage for its samples.
void bayes_save(world_fmt *world)
{
    long i;
    bayes_fmt *bayes = world->bayes;
    long pnum = bayes->numparams;
    long allocparams = bayes->allocparams;
    double *row = bayes->params[pnum];

    row[0] = bayes->oldval + world->likelihood[world->G];
    memcpy(&row[1], world->param0, sizeof(double) * world->numpop2);

    if (world->options->verbose)
    {
        // the prefix is printed once (it used to be emitted twice per sample)
        FPRINTF(stdout, "%8li> %10.2f ", bayes->count, row[0]);
        for (i = 1; i < world->numpop + 1; i++)
            FPRINTF(stdout, "%6.4f ", row[i]);
        for (i = world->numpop + 1; i < world->numpop2 + 1; i++)
            FPRINTF(stdout, "%6.1f ", row[i]);
        FPRINTF(stdout, "\n");
    }

    pnum++;
    if (pnum >= allocparams)
    {
        const long newalloc = allocparams + 1000;
        // keep the old pointer until realloc is known to have succeeded
        double **grown = realloc(bayes->params, sizeof(double *) * newalloc);

        if (grown == NULL)
        {
            fprintf(stderr, "bayes_save: out of memory while growing the sample table\n");
            exit(EXIT_FAILURE);
        }
        bayes->params = grown;
        for (i = pnum; i < newalloc; i++)
        {
            grown[i] = calloc(world->numpop2 + 1, sizeof(double));
            if (grown[i] == NULL)
            {
                fprintf(stderr, "bayes_save: out of memory while allocating a sample row\n");
                exit(EXIT_FAILURE);
            }
        }
        allocparams = newalloc;
    }
    bayes->numparams = pnum;
    bayes->allocparams = allocparams;
}

// Set up an empty bayes_fmt.
//
// Allocates zero-filled arrays of `size` doubles for delta, minparam and
// maxparam, and a sample table with a single row of size + 1 doubles.
// The counters allocparams, numparams and paramnum are set to 1, 0 and 0.
// bayes->oldval is not initialized here.
void bayes_init(bayes_fmt *bayes, long size)
{
    double **table;

    // per-parameter proposal settings, filled in later by bayes_fill()
    bayes->delta = calloc(size, sizeof(double));
    bayes->minparam = calloc(size, sizeof(double));
    bayes->maxparam = calloc(size, sizeof(double));

    // sample table starts with exactly one row
    table = calloc(1, sizeof(double *));
    table[0] = calloc(size + 1, sizeof(double));
    bayes->params = table;

    bayes->allocparams = 1;
    bayes->numparams = 0;
    bayes->paramnum = 0;
}

// Fill the proposal settings with built-in defaults.
//
// Indices [0, numpop) get delta 0.01 and the bounds
// SMALLEST_THETA / BIGGEST_THETA; indices [numpop, numpop2) get delta 10
// and the bounds SMALLEST_MIGRATION / BIGGEST_MIGRATION.
// The options argument is not consulted: the code that copied these
// settings from options->bayes is disabled.
void bayes_fill(world_fmt *world, option_fmt *options)
{
    bayes_fmt *settings = world->bayes;
    const long ntheta = world->numpop;
    const long nparam = world->numpop2;
    long k;

    (void) options;

    k = 0;
    while (k < ntheta)
    {
        settings->delta[k] = 0.01;
        settings->minparam[k] = SMALLEST_THETA;
        settings->maxparam[k] = BIGGEST_THETA;
        k++;
    }

    k = ntheta;
    while (k < nparam)
    {
        settings->delta[k] = 10.;
        settings->minparam[k] = SMALLEST_MIGRATION;
        settings->maxparam[k] = BIGGEST_MIGRATION;
        k++;
    }
}

// Release everything bayes_init() and bayes_save() allocated inside
// world->bayes. The bayes_fmt structure itself is not freed here.
//
// The sample table owns rows 0 .. allocparams-1: bayes_init() creates
// row 0 and bayes_save() adds rows up to the new allocparams each time it
// grows the table. All of them are released, and the pointers and counters
// are reset so that a stale table cannot be used or freed a second time.
void bayes_free(world_fmt *world)
{
    long i;
    bayes_fmt *bayes = world->bayes;

    free(bayes->delta);
    free(bayes->minparam);
    free(bayes->maxparam);
    bayes->delta = NULL;
    bayes->minparam = NULL;
    bayes->maxparam = NULL;

    if (bayes->params != NULL)
    {
        // valid row indices are 0 .. allocparams-1
        for (i = 0; i < bayes->allocparams; i++)
            free(bayes->params[i]);
        free(bayes->params);
        bayes->params = NULL;
    }
    bayes->allocparams = 0;
    bayes->numparams = 0;
}

// Write the recorded samples and a summary table.
//
// world->bayesfile receives one line per sample: the stored
// log-probability followed by the numpop2 parameter values.
// world->outfile receives, for every parameter, the 2.5%, 50% and 97.5%
// quantiles of the sampled values, the value in the sample with the
// highest stored log-probability ("mode"), the mean and the sample
// standard deviation.
//
// Only the numpop2 values that bayes_save() stores per sample are
// summarized; each row holds numpop2 + 1 doubles, so nothing beyond that
// may be read. Nothing is written when no sample has been recorded, and
// the standard deviation is reported as 0 for a single sample.
void bayes_stat(world_fmt *world)
{
  bayes_fmt *bayes = world->bayes;
  const long numparams = bayes->numparams;
  const long size = world->numpop2;
  long frompop = 0;
  long topop = 0;
  long i, j;
  long mode = 0;
  double temp;
  double mean;
  double variance;
  double stddev;
  double **rawstat;
  double *averages;
  char stemp[64];

  if (numparams < 1 || size < 1)
    return;

  averages = calloc((size_t) size, sizeof(double));
  rawstat = calloc((size_t) size, sizeof(double *));
  if (rawstat != NULL)
    rawstat[0] = calloc((size_t) size * (size_t) numparams, sizeof(double));
  if (averages == NULL || rawstat == NULL || rawstat[0] == NULL)
    {
      fprintf(stderr, "bayes_stat: out of memory, Bayesian summary skipped\n");
      if (rawstat != NULL)
        free(rawstat[0]);
      free(rawstat);
      free(averages);
      return;
    }
  // one contiguous buffer, one column of numparams values per parameter
  for (j = 1; j < size; j++)
    rawstat[j] = rawstat[0] + j * numparams;

  // index of the sample with the highest stored log-probability
  for (i = 1; i < numparams; i++)
    {
      if (bayes->params[i][0] > bayes->params[mode][0])
        mode = i;
    }

  for (i = 0; i < numparams; i++)
    {
      FPRINTF(world->bayesfile, "%f ", bayes->params[i][0]);
      for (j = 0; j < size; j++)
        {
          temp = bayes->params[i][j + 1];
          FPRINTF(world->bayesfile, "%f ", temp);
          averages[j] += temp;
          rawstat[j][i] = temp;
        }
      FPRINTF(world->bayesfile, "\n");
    }

  FPRINTF(world->outfile,"\n\n\nBayesian estimates\n");
  FPRINTF(world->outfile,"==================\n\n");
  FPRINTF(world->outfile,"Parameter  2.5%%     median    97.5%%     mode     mean      std\n");
  FPRINTF(world->outfile,"--------------------------------------------------------------\n");
  for (j = 0; j < size; j++)
    {
      qsort(rawstat[j], (size_t) numparams, sizeof(double), numcmp);

      // sample standard deviation around the mean (n - 1 denominator)
      mean = averages[j] / numparams;
      variance = 0.0;
      if (numparams > 1)
        {
          for (i = 0; i < numparams; i++)
            {
              temp = rawstat[j][i] - mean;
              variance += temp * temp;
            }
          variance /= (numparams - 1);
        }
      stddev = sqrt(variance);

      // rawstat[j] is sorted now, so the mode value has to come from the
      // original sample row rather than from rawstat[j][mode]
      if (j < world->numpop)
        {
          FPRINTF(world->outfile,"Theta_%-3li",j+1);
          FPRINTF(world->outfile, "%8.5f %8.5f %8.5f %8.5f %8.5f %8.5f\n",
                  rawstat[j][(long)(numparams * 0.025)],
                  rawstat[j][(long)(numparams * 0.5)],
                  rawstat[j][(long)(numparams * 0.975)],
                  bayes->params[mode][j + 1],
                  mean,
                  stddev);
        }
      else
        {
          topop = (long) (j - world->numpop) / (world->numpop - 1);
          frompop = j - world->numpop - topop * (world->numpop - 1);
          if (frompop >= topop)
            frompop += 1;
          // bounded write; the label is cut to 9 characters when printed
          snprintf(stemp, sizeof stemp, "M_%li->%li", frompop + 1, topop + 1);
          FPRINTF(world->outfile, "%-9.9s%8.1f %8.1f %8.1f %8.1f %8.1f %8.1f\n",
                  stemp, rawstat[j][(long)(numparams * 0.025)],
                  rawstat[j][(long)(numparams * 0.5)],
                  rawstat[j][(long)(numparams * 0.975)],
                  bayes->params[mode][j + 1],
                  mean,
                  stddev);
        }
    }
  free(rawstat[0]);
  free(rawstat);
  free(averages);
}


#endif /*bayesupdate*/

