
/*------------------------------------------------------
 Maximum likelihood estimation 
 of migration rate and effective population size
 using a Metropolis-Hastings Monte Carlo algorithm                            
 -------------------------------------------------------                        
 F S T   R O U T I N E S 

 calculates FST

 Peter Beerli 1996, Seattle
 beerli@genetics.washington.edu
 $Id: fst.c,v 1.1.1.1 1998/06/06 06:09:51 beerli Exp $
-------------------------------------------------------*/

#include "migration.h"
#include "tools.h"

#ifdef DMALLOC_FUNC_CHECK
#include "dmalloc.h"
#endif

/* prototypes ------------------------------------------- */
/* Public entry points.
   fst_type() selects which solver calc_fst() will call: 'T' installs
   solveNm_vartheta(), any other character installs solveNm_varm().
   calc_fst() fills world->fstparam[] for every locus and, in the slot
   after the last locus, for the values averaged over all loci. */
void fst_type (char type);
void calc_fst (world_fmt * world);
/* private functions */
/* frequencies(): fills the per-population table f and the pooled table f2
   from the allele codes in data.
   calc_fw()/calc_fb(): sums of products of the tabulated values within
   each population (fw) and between each pair of populations (fb).
   calc_seq_fw()/calc_seq_fb(): the same two quantities computed as
   1 - (mean per-site mismatch) from the sequences in data->yy.
   solveNm_vartheta()/solveNm_varm(): turn fw/fb into the entries of the
   params vector. */
void frequencies (double ***f, double **f2, char *****data, long numpop, long **numind, long loci);
void calc_fw (double ***f, long numpop, long locus, double *fw);
void calc_fb (double ***f, long numpop, long locus, double *fb);
void calc_seq_fw (data_fmt * data, long numpop, long locus, double *fw);
void calc_seq_fb (data_fmt * data, long numpop, long locus, double *fb);
void solveNm_vartheta (double *fw, double *fb, long numpop, double *params);
void solveNm_varm (double *fw, double *fb, long numpop, double *params);

/* File-scope function pointer solveNm: set by fst_type() to either
   solveNm_vartheta() or solveNm_varm() and called by calc_fst().
   It has no initializer, so fst_type() has to run before calc_fst(). */
static void (*solveNm) (double *, double *, long, double *);

/*=======================================================*/
/* Choose the solver that calc_fst() invokes through solveNm:
   type 'T' selects solveNm_vartheta(), every other value selects
   solveNm_varm(). */
void
fst_type (char type)
{
  solveNm = (type == 'T') ? solveNm_vartheta : solveNm_varm;
}

/* Compute the FST-based estimates for every locus and for the average
   over loci.

   For each locus the within-population values fw (one per population)
   and the between-population values fb (one per unordered population
   pair) are computed, either from sequences (datatype 's') or from the
   allele-frequency tables built by frequencies().  The selected solver
   (see fst_type()) writes its result into world->fstparam[locus].
   Afterwards fw and fb are averaged over the loci and solved once more
   into world->fstparam[world->loci].

   All scratch memory allocated here is released before returning. */
void
calc_fst (world_fmt * world)
{
  const int is_sequence = (world->options->datatype == 's');
  const long numpop = world->numpop;
  const long loci = world->loci;
  const long npairs = numpop * (numpop - 1) / 2;
  double ***popfreq = NULL;
  double **pooledfreq = NULL;
  double *fw, *fb, *meanfw, *meanfb;
  long p, l;

  fw = calloc (1, sizeof *fw * numpop);
  fb = calloc (1, sizeof *fb * npairs);
  meanfw = calloc (1, sizeof *meanfw * numpop);
  meanfb = calloc (1, sizeof *meanfb * npairs);

  if (!is_sequence)
    {
      /* frequency tables: popfreq[pop][locus][255], pooledfreq[locus][255] */
      popfreq = calloc (1, sizeof *popfreq * numpop);
      pooledfreq = calloc (1, sizeof *pooledfreq * loci);
      for (l = 0; l < loci; l++)
        pooledfreq[l] = calloc (1, sizeof **pooledfreq * 255);
      for (p = 0; p < numpop; p++)
        {
          popfreq[p] = calloc (1, sizeof **popfreq * loci);
          for (l = 0; l < loci; l++)
            popfreq[p][l] = calloc (1, sizeof ***popfreq * 255);
        }
      frequencies (popfreq, pooledfreq, world->data->yy, numpop,
                   world->data->numind, loci);
    }

  for (l = 0; l < loci; l++)
    {
      if (is_sequence)
        {
          calc_seq_fw (world->data, numpop, l, fw);
          calc_seq_fb (world->data, numpop, l, fb);
        }
      else
        {
          calc_fw (popfreq, numpop, l, fw);
          calc_fb (popfreq, numpop, l, fb);
        }

      solveNm (fw, fb, numpop, world->fstparam[l]);

      /* running sums for the across-loci average */
      for (p = 0; p < numpop; p++)
        meanfw[p] += fw[p];
      for (p = 0; p < npairs; p++)
        meanfb[p] += fb[p];
    }

  for (p = 0; p < numpop; p++)
    meanfw[p] /= loci;
  for (p = 0; p < npairs; p++)
    meanfb[p] /= loci;

  /* the slot after the last locus holds the estimate from the averages */
  solveNm (meanfw, meanfb, numpop, world->fstparam[loci]);

  if (!is_sequence)
    {
      for (p = 0; p < numpop; p++)
        {
          for (l = 0; l < loci; l++)
            free (popfreq[p][l]);
          free (popfreq[p]);
        }
      for (l = 0; l < loci; l++)
        free (pooledfreq[l]);
      free (popfreq);
      free (pooledfreq);
    }
  free (fw);
  free (fb);
  free (meanfw);
  free (meanfb);
}


/*=======================================================*/

/* Tabulate allele frequencies per population and pooled over all
   populations.

   f       f[pop][locus][a] receives count(a) / (number of non-missing
           alleles) for population pop at the locus.  If a population has
           no non-missing allele at a locus its row is left untouched.
   f2      f2[locus][a] receives the same ratio computed from the counts
           of all populations together; left untouched for a locus
           without any non-missing allele.
   data    data[pop][ind][locus][0][0] and data[pop][ind][locus][1][0]
           are the two allele codes of an individual; '?' marks a missing
           allele.  A code c is counted in slot (c - '!').
   numpop  number of populations.
   numind  numind[pop][FLOC] is used as the number of individuals of
           population pop for every locus.
   loci    number of loci.

   Every row f[pop][locus] and f2[locus] must provide 255 doubles.  */
void
frequencies (double ***f, double **f2, char *****data, long numpop,
             long **numind, long loci)
{
  long counts[255];             /* allele counts of the current population */
  long pooled[255];             /* allele counts summed over populations */
  const long nslots = (long) (sizeof counts / sizeof counts[0]);
  long total, pooled_total;
  long pop, locus, ind, copy, a, slot;
  char code;

  for (locus = 0; locus < loci; locus++)
    {
      /* The pooled counts belong to one locus only, so they are cleared
         here; otherwise f2[locus] would also contain earlier loci.  */
      memset (pooled, 0, sizeof pooled);
      pooled_total = 0;
      for (pop = 0; pop < numpop; pop++)
        {
          memset (counts, 0, sizeof counts);
          total = 0;
          for (ind = 0; ind < numind[pop][FLOC]; ind++)
            {
              for (copy = 0; copy < 2; copy++)
                {
                  code = data[pop][ind][locus][copy][0];
                  if (code == '?')
                    continue;
                  slot = (long) code - '!';
                  /* A code outside the table cannot be counted; skip it
                     like a missing allele instead of writing past the
                     end of the count arrays.  */
                  if (slot < 0 || slot >= nslots)
                    continue;
                  counts[slot] += 1;
                  pooled[slot] += 1;
                  total += 1;
                  pooled_total += 1;
                }
            }
          /* Convert to frequencies once per population and locus.  */
          if (total > 0)
            {
              for (a = 0; a < nslots; a++)
                f[pop][locus][a] = (double) counts[a] / (double) total;
            }
        }
      /* The pooled denominator covers every population, so it is valid
         for any numpop (including a single population).  */
      if (pooled_total > 0)
        {
          for (a = 0; a < nslots; a++)
            f2[locus][a] = (double) pooled[a] / (double) pooled_total;
        }
    }
}



/* For every population p, store in fw[p] the sum over the 255 table
   entries of f[p][locus][i] squared.  */
void
calc_fw (double ***f, long numpop, long locus, double *fw)
{
  long p;

  for (p = 0; p < numpop; p++)
    {
      const double *freq = f[p][locus];
      const double *stop = freq + 255;
      double acc = 0;

      while (freq != stop)
        {
          acc += *freq * *freq;
          freq++;
        }
      fw[p] = acc;
    }
}
/* For every unordered pair of populations (a, b) with a < b, store the
   sum over the 255 table entries of f[a][locus][i] * f[b][locus][i].
   Results are written to fb consecutively in the order
   (0,1), (0,2), ..., (1,2), ...  */
void
calc_fb (double ***f, long numpop, long locus, double *fb)
{
  double *out = fb;
  long a, b, slot;

  for (a = 0; a < numpop; a++)
    for (b = a + 1; b < numpop; b++)
      {
        const double *fa = f[a][locus];
        const double *fbq = f[b][locus];
        double acc = 0.0;

        for (slot = 0; slot < 255; slot++)
          acc += fa[slot] * fbq[slot];
        *out++ = acc;
      }
}

/* Sequence version of the within-population value.

   For each population every pair of distinct individuals is compared
   site by site (first entry of data->yy[pop][ind][locus]); fw[pop] is
   1 minus the number of mismatches divided by
   sites * (n * n - n) / 2, where n = data->numind[pop][locus] and
   sites = data->seq->sites[locus].  If that divisor is not positive
   nothing is accumulated and fw[pop] becomes 1.  */
void
calc_seq_fw (data_fmt * data, long numpop, long locus, double *fw)
{
  long p, a, b, site;

  for (p = 0; p < numpop; p++)
    {
      const long n = data->numind[p][locus];
      const double scale = data->seq->sites[locus] * (n * n - n) / 2.;
      double within = 0;

      for (a = 0; a < n; a++)
        {
          for (b = a + 1; b < n; b++)
            {
              const char *seq_a = data->yy[p][a][locus][0];
              const char *seq_b = data->yy[p][b][locus][0];
              double mismatches = 0.;

              for (site = 0; site < data->seq->sites[locus]; site++)
                mismatches += (seq_a[site] != seq_b[site]);
              if (scale > 0)
                within += mismatches / scale;
            }
        }
      fw[p] = 1. - within;
    }
}

/* Sequence version of the between-population value.

   For each unordered population pair (a, b), a < b, every individual of
   a is compared site by site with every individual of b (first entry of
   data->yy[pop][ind][locus]).  The stored value is 1 minus the number
   of mismatches divided by sites * n_a * n_b, with
   n_x = data->numind[x][locus] and sites = data->seq->sites[locus].
   If that divisor is not positive the stored value is 1.  Results are
   written to fb consecutively in the order (0,1), (0,2), ..., (1,2), ...  */
void
calc_seq_fb (data_fmt * data, long numpop, long locus, double *fb)
{
  double *out = fb;
  long a, b, ia, ib, site;

  for (a = 0; a < numpop; a++)
    {
      for (b = a + 1; b < numpop; b++)
        {
          const long na = data->numind[a][locus];
          const long nb = data->numind[b][locus];
          const double scale = data->seq->sites[locus] * na * nb;
          double between = 0;

          for (ia = 0; ia < na; ia++)
            {
              for (ib = 0; ib < nb; ib++)
                {
                  const char *seq_a = data->yy[a][ia][locus][0];
                  const char *seq_b = data->yy[b][ib][locus][0];
                  double mismatches = 0.;

                  for (site = 0; site < data->seq->sites[locus]; site++)
                    mismatches += (seq_a[site] != seq_b[site]);
                  if (scale > 0)
                    between += mismatches / scale;
                }
            }
          *out++ = 1. - between;
        }
    }
}

/* Solver installed by fst_type() for every type other than 'T'.

   Only fw[0], fw[1] and fb[0] enter the formulas, so the estimates are
   those of a two-population system.  Layout of params, with
   nmig = numpop * (numpop - 1):
     [0, numpop)                        one common value
                                        (2 - fw0 - fw1) / (2 fb0 + fw0 + fw1)
     [numpop], [numpop + 1]             two separate values, one built
                                        around fw0 and one around fw1
     [numpop + nmig, +numpop)           copy of fw
     [2 numpop + nmig, +numpop(numpop-1)/2)  copy of fb
   Every entry in [0, numpop + nmig) that is negative is replaced by the
   marker -999.
   NOTE(review): judging by the function names the first group looks like
   the population-size parameters and the second like migration
   parameters -- confirm against the caller.  */
void
solveNm_varm (double *fw, double *fb, long numpop, double *params)
{
  const double fw0 = fw[0], fw1 = fw[1], fb0 = fb[0];
  const long second = numpop;
  const long fw_start = numpop + numpop * (numpop - 1);
  const long fb_start = fw_start + numpop;
  const double common = (2. - fw0 - fw1) / (2. * fb0 + fw0 + fw1);
  long npairs = 0;
  long k;

  for (k = 0; k < numpop; k++)
    {
      npairs += k;              /* ends up as numpop * (numpop - 1) / 2 */
      params[k] = common;
      params[fw_start + k] = fw[k];
    }

  params[second] = (2. * fb0 - fw0 - 2. * fb0 * fw0 + fw1) /
    ((fb0 - fw0) * (-2. + fw0 + fw1));
  params[second + 1] = (2. * fb0 - fw1 - 2. * fb0 * fw1 + fw0) /
    ((fb0 - fw1) * (-2. + fw0 + fw1));

  /* negative estimates are flagged with the marker value */
  for (k = 0; k < fw_start; k++)
    if (params[k] < 0.)
      params[k] = -999;

  for (k = 0; k < npairs; k++)
    params[fb_start + k] = fb[k];
}

/* Solver installed by fst_type('T').

   Only fw[0], fw[1] and fb[0] enter nom and denom.  Layout of params,
   with nmig = numpop * (numpop - 1):
     [0, numpop)                        one value per population i:
                                        nom * (1 - fw[i]) / (denom + fw[i]^2)
     [numpop, numpop + nmig)            one shared value 2 fb0 / nom,
                                        or -999 when nom is exactly 0
     [numpop + nmig, +numpop)           copy of fw
     [2 numpop + nmig, +numpop(numpop-1)/2)  copy of fb
   where nom = -2 fb0 + fw0 + fw1 and denom = -2 fb0^2 + fw0 fw1.
   Every entry in [0, numpop + nmig) that is negative is replaced by the
   marker -999.
   NOTE(review): judging by the function names the first group looks like
   the population-size parameters and the second like migration
   parameters -- confirm against the caller.  */
void
solveNm_vartheta (double *fw, double *fb, long numpop, double *params)
{
  const double fw0 = fw[0], fw1 = fw[1], fb0 = fb[0];
  const long nmig = numpop * (numpop - 1);
  const long fw_start = numpop + nmig;
  const long fb_start = fw_start + numpop;
  const double nom = (-2. * fb0 + fw0 + fw1);
  const double denom = -2. * fb0 * fb0 + fw0 * fw1;
  long npairs = 0;
  long k;

  for (k = 0; k < numpop; k++)
    {
      npairs += k;              /* ends up as numpop * (numpop - 1) / 2 */
      params[fw_start + k] = fw[k];
      params[k] = (nom * (1. - fw[k])) / (denom + fw[k] * fw[k]);
    }

  /* one shared value, replicated over the whole second group */
  params[numpop] = (nom == 0.0) ? -999. : 2. * fb0 / nom;
  for (k = 1; k < nmig; k++)
    params[numpop + k] = params[numpop];

  /* negative estimates are flagged with the marker value */
  for (k = 0; k < fw_start; k++)
    if (params[k] < 0.)
      params[k] = -999;

  for (k = 0; k < npairs; k++)
    params[fb_start + k] = fb[k];
}










