/*
 * mig-histogram code
 * 
 * 
 * $Id$
 * 
 */

#include "mig-histogram.h"
/* Forward declarations for functions defined later in this file. */

/* Allocates the per-locus migration-history storage inside world when
   the mighist option is switched on. */
void setup_mighist (world_fmt * world, option_fmt * options);
/* Writes the migration-event dump, summary table and ASCII histograms. */
void print_mighist (world_fmt * world);
/* Fills in one plot descriptor and allocates its buffers. */
void
setup_plotfield (plotfield_fmt * plotfield, char thisplotype,
		 long xsize, long ysize, char xlabel[], char ylabel[],
		 char yflabel[], char title[], boolean print);

/* Builds the summary table and histogram input for one locus (loci == 1)
   or for all loci of world pooled together; returns the summed event
   count of the table. */
long calc_migtable (double **migtable, histogram_fmt * histogram,
		    mighistloci_fmt * aa, world_fmt * world, long loci);
/* Weighted mean of vec; also stores a standard error in *se and the sum
   of the weights in *n. */
double average (double *vec, long size, long *weight, double *se, long *n);
/* Value at the requested quantile of vec, every entry counted weight[i]
   times (0.5 yields the median). */
double quantile (double *vec, long size, long *weight, double quantile);

/* Number of time bins print_mighist() uses for its histograms. */
#define NBINS 30
/* Bins the event times of every locus / population pair and prints the
   resulting ASCII histograms to out. */
void
print_histogram_ascii (FILE * out, histogram_fmt ** histogram,
		       plotfield_fmt ** plotfield, long loci, long nmigs,
		       long bins, long *sum, double ***migtable);

/* Copies count times and weights into hist, (re)allocating its arrays. */
void
prepare_hist (histogram_fmt * hist, double *time, long count, long *weight);

/* Reports the smallest and largest time stored in hist. */
void minmax (histogram_fmt * hist, double *tempmin, double *tempmax);


/*
 * Dumps every stored migration event to the stream mighist.
 *
 * Walks all loci of world, every history record of a locus and every
 * event of a record, and writes the three values of each event
 * (indices 0, 1 and 2) on a line of their own, formatted with "%f".
 */
void
print_mighist_file (FILE * mighist, world_fmt * world)
{
  long locus;
  long rec;
  long ev;
  mighistloci_fmt *lh;

  locus = 0;
  while (locus < world->loci)
    {
      lh = &world->mighistloci[locus];
      rec = 0;
      while (rec < lh->mighistnum)
	{
	  ev = 0;
	  while (ev < lh->mighist[rec].migeventsize)
	    {
	      fprintf (mighist, "%f %f %f\n",
		       lh->mighist[rec].migevents[ev][0],
		       lh->mighist[rec].migevents[ev][1],
		       lh->mighist[rec].migevents[ev][2]);
	      ev++;
	    }
	  rec++;
	}
      locus++;
    }
}

/*
 * Fills migtable[] and histogram[] for every locus and, when world holds
 * more than one locus, for one extra entry (index world->loci) that is
 * computed over all loci together.
 *
 * Column 2 of every table row is first reset to NOAVERAGE.  For the
 * per-locus entries the value returned by calc_migtable() is stored in
 * total[locus]; the return value of the all-loci call is not stored.
 */
void
calc_mighistvalues (world_fmt * world, double ***migtable,
		    histogram_fmt ** histogram, long *total)
{
  long locus, row;
  const long entries = (world->loci == 1) ? 1 : world->loci + 1;

  for (locus = 0; locus < entries; locus++)
    {
      /* mark every population pair as "no event seen yet" */
      row = 0;
      while (row < world->numpop2)
	{
	  migtable[locus][row][2] = NOAVERAGE;
	  row++;
	}

      if (locus != world->loci)
	{
	  /* ordinary locus: analyse just this locus' history */
	  total[locus] = calc_migtable (migtable[locus], histogram[locus],
					&world->mighistloci[locus], world, 1);
	  continue;
	}

      /* summary entry: only meaningful with several loci */
      if (world->loci == 1)
	break;
      calc_migtable (migtable[locus], histogram[locus], world->mighistloci,
		     world, world->loci);
    }
}

/*
 * Prints the "Summary of Migration Events" table to out.
 *
 * One sub-table is written per locus and, when world holds more than one
 * locus, one more for the values over all loci (index world->loci).
 * Each row of migtable[l] is read as
 *   [0],[1]  population indices (printed 1-based under From / To)
 *   [2]      average time, or NOAVERAGE when no event was recorded
 *   [3]      median time
 *   [4]      standard error
 *   [5]      event count, printed as a fraction of total[l]
 * Rows whose two population indices are equal are skipped.
 *
 * total must have one slot per printed sub-table.  With more than one
 * locus the per-locus totals are added into total[world->loci] here.
 */
void
print_mighist_output (FILE * out, world_fmt * world, double ***migtable,
		      long *total)
{
  long l, p1;
  long loci1 = world->loci == 1 ? 1 : world->loci + 1;
  fprintf (out, "\n\nSummary of Migration Events\n");
  fprintf (out, "===============================\n\n");
  /* The summary slot total[world->loci] exists only when an all-loci
     entry is reported; with a single locus the array has exactly one
     element and writing to index 1 would run past its end. */
  if (loci1 > world->loci)
    {
      for (l = 0; l < world->loci; l++)
	total[world->loci] += total[l];
    }
  for (l = 0; l < loci1; l++)
    {				/* Each locus + Summary */
      if (l != world->loci)
	fprintf (out, "Locus %li\n", l + 1);
      else
	fprintf (out, "Over all loci\n");
      fprintf (out,
	       "---------------------------------------------------------\n");
      fprintf (out,
	       "Population   Time                             Frequency\n");
      fprintf (out, "             -----------------------------\n");
      fprintf (out, "From    To   Average    Median     SE\n");
      fprintf (out,
	       "---------------------------------------------------------\n");
      for (p1 = 0; p1 < world->numpop2; p1++)
	{
	  /* identical indices mark an unused (diagonal) row */
	  if (migtable[l][p1][0] == migtable[l][p1][1])
	    continue;
	  if (migtable[l][p1][2] == NOAVERAGE)
	    fprintf (out, "%4li %4li    No migration event encountered\n",
		     (long) migtable[l][p1][0] + 1,
		     (long) migtable[l][p1][1] + 1);
	  else
	    fprintf (out,
		     "%4li %4li    %3.5f    %3.5f    %3.5f    %3.5f\n",
		     (long) migtable[l][p1][0] + 1,
		     (long) migtable[l][p1][1] + 1, migtable[l][p1][2],
		     migtable[l][p1][3], migtable[l][p1][4],
		     migtable[l][p1][5] / total[l]);
	}
      fprintf (out,
	       "---------------------------------------------------------\n");
      fprintf (out, "\n");
    }
}

/*
 * print_mighist() writes the migration-event report of a run.
 *
 * Nothing is done unless world->options->mighist is set.  Otherwise the
 * function
 *   1. dumps the raw events to world->mighistfile,
 *   2. prints a summary table to world->outfile with one row per
 *      population pair:
 *        PopFrom PopTo Average-Time Median-Time SE Frequency
 *   3. prints one ASCII histogram of event times per population pair,
 *      using NBINS time bins.
 * With more than one locus an additional "over all loci" entry follows
 * the per-locus entries.
 *
 * Every work buffer is allocated and released inside this function.
 */
void
print_mighist (world_fmt * world)
{
  long l, i, z;
  char plotype;
  char xlabel[255];
  char ylabel[255];
  char yflabel[255];
  char title[255];
  long xsize;
  long ysize;
  long to;
  long from;
  FILE *out = world->outfile;
  long loci1 = world->loci == 1 ? 1 : world->loci + 1;
  long *total;
  double ***migtable;		//loci x numpop2 x {from, to, mean, median, se, n}
  plotfield_fmt **plotfield;
  histogram_fmt **histogram;	//loci x numpop2

  if (!world->options->mighist)
    return;

  total = (long *) calloc (loci1, sizeof (long));
  plotfield = (plotfield_fmt **) calloc (loci1, sizeof (plotfield_fmt *));
  migtable = (double ***) calloc (loci1, sizeof (double **));
  histogram = (histogram_fmt **) calloc (loci1, sizeof (histogram_fmt *));
  for (l = 0; l < loci1; ++l)
    {
      plotfield[l] =
	(plotfield_fmt *) calloc (world->numpop2, sizeof (plotfield_fmt));
      migtable[l] = (double **) calloc (world->numpop2, sizeof (double *));
      histogram[l] =
	(histogram_fmt *) calloc (world->numpop2, sizeof (histogram_fmt));
      for (i = 0; i < world->numpop2; ++i)
	{
	  /* prepare_hist() decides between calloc and realloc on these */
	  histogram[l][i].time = NULL;
	  histogram[l][i].weight = NULL;
	  migtable[l][i] = (double *) calloc (6, sizeof (double));
	}
    }

  //setup histogram
  plotype = 'a';
  xsize = NBINS;
  ysize = MIGHIST_YSIZE;
  strcpy (xlabel, "Time");
  strcpy (yflabel, "Frequency");
  strcpy (ylabel, "Count");
  for (l = 0; l < loci1; l++)
    {
      z = 0;
      for (to = 0; to < world->numpop; ++to)
	{
	  for (from = 0; from < world->numpop; ++from)
	    {
	      snprintf (title, sizeof (title),
			"Migrations from population %li to %li",
			from + 1, to + 1);
	      /* a population paired with itself is never plotted */
	      setup_plotfield (&plotfield[l][z], plotype, xsize, ysize,
			       xlabel, ylabel, yflabel, title, from != to);
	      z++;
	    }
	}
    }
  print_mighist_file (world->mighistfile, world);
  calc_mighistvalues (world, migtable, histogram, total);
  print_mighist_output (out, world, migtable, total);

  print_histogram_ascii (out, histogram, plotfield, loci1, world->numpop2,
			 NBINS, total, migtable);
  fflush (out);
  free (total);
  for (l = 0; l < loci1; l++)
    {
      for (i = 0; i < world->numpop2; ++i)
	{
	  free (migtable[l][i]);
	  free (plotfield[l][i].data[0]);
	  free (plotfield[l][i].data);
	  free (plotfield[l][i].y);
	  free (plotfield[l][i].yfreq);
	  /* filled by prepare_hist(); still NULL for pairs that were
	     never prepared, which free() accepts */
	  free (histogram[l][i].time);
	  free (histogram[l][i].weight);
	}
      free (migtable[l]);
      free (plotfield[l]);
      free (histogram[l]);
    }
  free (migtable);
  free (plotfield);
  free (histogram);
}

/*
 * Initializes one plot descriptor and allocates its buffers.
 *
 *   plotfield    descriptor to fill in
 *   thisplotype  stored in plotfield->type
 *   xsize        number of rows of the character plot (one per bin)
 *   ysize        maximum bar length in characters
 *   xlabel, ylabel, yflabel, title
 *                copied into the descriptor, truncated to 253 characters
 *                and always NUL-terminated
 *   print        stored in plotfield->print
 *
 * Buffers owned by the descriptor afterwards (the caller frees them):
 *   data     xsize row pointers into one block, data[0]; every row has
 *            room for ysize characters plus the terminating NUL
 *   y, yfreq one value per row; sized for the larger of xsize and ysize
 *            so indexing by either dimension stays in bounds
 */
void
setup_plotfield (plotfield_fmt * plotfield, char thisplotype,
		 long xsize, long ysize, char xlabel[], char ylabel[],
		 char yflabel[], char title[], boolean print)
{
  long i;
  long rowlen = ysize + 1;	/* full-height bar + '\0' */
  long nvals = xsize > ysize ? xsize : ysize;

  plotfield->print = print;
  plotfield->type = thisplotype;
  plotfield->xsize = xsize;
  plotfield->ysize = ysize;
  /* same 254-byte bound as before, but termination is guaranteed */
  snprintf (plotfield->xaxis, 254, "%s", xlabel);
  snprintf (plotfield->yaxis, 254, "%s", ylabel);
  snprintf (plotfield->yfaxis, 254, "%s", yflabel);
  snprintf (plotfield->title, 254, "%s", title);
  plotfield->yfreq = (double *) calloc (nvals, sizeof (double));
  plotfield->y = (long *) calloc (nvals, sizeof (long));
  plotfield->data = (char **) malloc (sizeof (char *) * xsize);
  /* zero-filled so a row that is never drawn reads as an empty string */
  plotfield->data[0] = (char *) calloc (xsize * rowlen, sizeof (char));
  for (i = 1; i < xsize; i++)
    {
      plotfield->data[i] = plotfield->data[0] + i * rowlen;
    }
}



/*
 * Builds the migration summary for one locus or for all loci pooled.
 *
 *   migtable   numpop*numpop rows of 6 doubles, filled for every pair of
 *              distinct populations:
 *                [0],[1]  the two population indices of the pair
 *                [2]      weighted average event time (or NOAVERAGE)
 *                [3]      weighted median event time
 *                [4]      standard error reported by average()
 *                [5]      summed weight (number of events)
 *   histogram  numpop*numpop entries; receives a copy of the times and
 *              weights of every pair via prepare_hist()
 *   aa         history of the locus to analyse, or the first element of
 *              the whole per-locus array when pooling
 *   world      supplies numpop, numpop2 and loci
 *   loci       1 analyses aa[0] only; any other value pools
 *              aa[0 .. world->loci - 1]
 *
 * Event field 0 is used as the time, fields 1 and 2 as population
 * indices; the weight of an event is the copies value of its record.
 *
 * Returns the sum of column 5 over all numpop2 rows of migtable.
 */
long
calc_migtable (double **migtable, histogram_fmt * histogram,
	       mighistloci_fmt * aa, world_fmt * world, long loci)
{
  long p1, p2, pa, pb, i, j;
  long maxloci, locus;
  long copies;
  long n, total = 0, maxsize;
  double se;
  double ***migtime;		/* [pop][pop][event] event times   */
  long ***gencount;		/* [pop][pop][event] event weights */
  long **migcount;		/* [pop][pop] events stored so far */
  long **size;			/* [pop][pop] capacity of the two arrays above */
  migtime = (double ***) calloc (world->numpop, sizeof (double **));
  gencount = (long ***) calloc (world->numpop, sizeof (long **));
  migcount = (long **) calloc (world->numpop, sizeof (long *));
  size = (long **) calloc (world->numpop, sizeof (long *));
  maxsize = 1;
  if (loci == 1)
    maxloci = 1;
  else
    maxloci = world->loci;
  /* initial capacity: the largest single record that will be read */
  for (locus = 0; locus < maxloci; locus++)
    {
      for (j = 0; j < aa[locus].mighistnum; j++)
	{
	  if (maxsize < aa[locus].mighist[j].migeventsize)
	    maxsize = aa[locus].mighist[j].migeventsize;
	}
    }
  for (p1 = 0; p1 < world->numpop; ++p1)
    {
      migtime[p1] = (double **) calloc (world->numpop, sizeof (double *));
      gencount[p1] = (long **) calloc (world->numpop, sizeof (long *));
      migcount[p1] = (long *) calloc (world->numpop, sizeof (long));
      size[p1] = (long *) calloc (world->numpop, sizeof (long));
      for (p2 = 0; p2 < world->numpop; ++p2)
	{
	  migtime[p1][p2] = (double *) calloc (maxsize, sizeof (double));
	  gencount[p1][p2] = (long *) calloc (maxsize, sizeof (long));
	  size[p1][p2] = maxsize;
	}
    }
  /* sort every event into the list of its population pair */
  for (locus = 0; locus < maxloci; locus++)
    {
      for (j = 0; j < aa[locus].mighistnum; j++)
	{
	  copies = aa[locus].mighist[j].copies;
	  for (i = 0; i < aa[locus].mighist[j].migeventsize; i++)
	    {
	      p1 = (long) aa[locus].mighist[j].migevents[i][1];
	      p2 = (long) aa[locus].mighist[j].migevents[i][2];
	      if (migcount[p1][p2] >= size[p1][p2])
		{
		  /* grow by 10 slots and clear the new tail, because the
		     entries are accumulated with += below */
		  size[p1][p2] += 10;
		  gencount[p1][p2] =
		    (long *) realloc (gencount[p1][p2],
				      sizeof (long) * size[p1][p2]);
		  memset (gencount[p1][p2] + migcount[p1][p2], 0,
			  sizeof (long) * 10);
		  migtime[p1][p2] =
		    (double *) realloc (migtime[p1][p2],
					sizeof (double) * size[p1][p2]);
		  memset (migtime[p1][p2] + migcount[p1][p2], 0,
			  sizeof (double) * 10);
		}
	      gencount[p1][p2][migcount[p1][p2]] += copies;
	      migtime[p1][p2][migcount[p1][p2]] +=
		aa[locus].mighist[j].migevents[i][0];
	      migcount[p1][p2] += 1;
	    }
	}
    }
  /* summarize both directions of every unordered population pair */
  for (p1 = 0; p1 < world->numpop; p1++)
    {
      for (p2 = 0; p2 < p1; p2++)
	{
	  pa = p2 * world->numpop + p1;
	  pb = p1 * world->numpop + p2;
	  migtable[pa][0] = p1;
	  migtable[pb][0] = p2;
	  migtable[pa][1] = p2;
	  migtable[pb][1] = p1;
	  migtable[pa][2] = average (migtime[p1][p2],
				     migcount[p1][p2],
				     gencount[p1][p2], &se, &n);
	  migtable[pa][4] = se;
	  migtable[pa][5] = n;
	  migtable[pb][2] = average (migtime[p2][p1],
				     migcount[p2][p1],
				     gencount[p2][p1], &se, &n);
	  migtable[pb][4] = se;
	  migtable[pb][5] = n;
	  migtable[pa][3] = quantile (migtime[p1][p2],
				      migcount[p1][p2],
				      gencount[p1][p2], 0.5);
	  migtable[pb][3] = quantile (migtime[p2][p1],
				      migcount[p2][p1],
				      gencount[p2][p1], 0.5);
	  prepare_hist (&histogram[pa], migtime[p1][p2], migcount[p1][p2],
			gencount[p1][p2]);
	  prepare_hist (&histogram[pb], migtime[p2][p1], migcount[p2][p1],
			gencount[p2][p1]);

	}
    }
  for (p1 = 0; p1 < world->numpop2; p1++)
    {
      total += migtable[p1][5];
    }
  for (p1 = 0; p1 < world->numpop; ++p1)
    {
      free (migcount[p1]);
      free (size[p1]);
      for (p2 = 0; p2 < world->numpop; ++p2)
	{
	  free (migtime[p1][p2]);
	  free (gencount[p1][p2]);
	}
      free (migtime[p1]);
      free (gencount[p1]);
    }
  free (migtime);
  free (gencount);
  /* the two top-level index arrays were previously leaked on every call */
  free (migcount);
  free (size);
  return total;
}

/*
 * Stores a private copy of count event times and their weights in hist.
 *
 * hist->time and hist->weight are allocated on first use (when they are
 * NULL) and resized on later calls; hist->count is set to count.  The
 * caller's arrays are only read.
 */
void
prepare_hist (histogram_fmt * hist, double *time, long count, long *weight)
{
  hist->count = count;

  /* time values */
  if (hist->time != NULL)
    hist->time = realloc (hist->time, sizeof (double) * count);
  else
    hist->time = (double *) calloc (count, sizeof (double));

  /* weights */
  if (hist->weight != NULL)
    hist->weight = realloc (hist->weight, sizeof (long) * count);
  else
    hist->weight = (long *) calloc (count, sizeof (long));

  memcpy (hist->time, time, sizeof (double) * count);
  memcpy (hist->weight, weight, sizeof (long) * count);
}

/*
 * Weighted mean of vec[0 .. size-1], where weight[i] says how many times
 * vec[i] was observed.
 *
 *   se  receives the standard error of the mean, computed from the
 *       sample variance of the weight-expanded data:
 *         var = (sum(w x^2) - n mean^2) / (n - 1),  se = sqrt (var / n)
 *       or DBL_MAX when fewer than two observations are available
 *   n   receives the sum of the weights
 *
 * Returns the mean, or NOAVERAGE when the weights sum to zero.
 */
double
average (double *vec, long size, long *weight, double *se, long *n)
{
  long i;
  double mean, var;
  double sum = 0., sum2 = 0.;
  long sumweight = 0;

  for (i = 0; i < size; ++i)
    {
      sumweight += weight[i];
      sum += vec[i] * weight[i];
      sum2 += vec[i] * vec[i] * weight[i];
    }
  *n = sumweight;
  *se = DBL_MAX;
  if (sumweight == 0)
    return NOAVERAGE;

  mean = sum / sumweight;
  if (sumweight > 1)
    {
      var = (sum2 - sumweight * mean * mean) / (sumweight - 1.);
      /* rounding can push a near-zero variance slightly negative */
      if (var < 0.)
	var = 0.;
      *se = sqrt (var / sumweight);
    }
  return mean;
}

/*
 * Weighted quantile of vec[0 .. size-1].
 *
 * Every vec[i] is counted weight[i] times; the expanded sample is sorted
 * with numcmp and the element at position (long) (n * quantile) is
 * returned, n being the sum of the weights.  quantile = 0.5 gives the
 * median.
 *
 * The position is clamped to the valid range, so quantile >= 1 yields
 * the largest and quantile <= 0 the smallest value.  When the weights
 * sum to zero there is no data and 0.0 is returned.
 */
double
quantile (double *vec, long size, long *weight, double quantile)
{
  long i, j, z = 0;
  long pos;
  double *tmp1;
  double val;
  long sumweight = 0;
  for (i = 0; i < size; ++i)
    sumweight += weight[i];

  /* nothing to sort: avoid reading from a zero-sized allocation */
  if (sumweight <= 0)
    return 0.;

  tmp1 = (double *) calloc (sumweight, sizeof (double));

  for (i = 0; i < size; ++i)
    {
      for (j = 0; j < weight[i]; ++j)
	tmp1[z++] = vec[i];
    }
  qsort ((void *) tmp1, sumweight, sizeof (double), numcmp);
  pos = (long) (sumweight * quantile);
  if (pos < 0)
    pos = 0;
  if (pos > sumweight - 1)
    pos = sumweight - 1;	/* quantile == 1.0 would index one past the end */
  val = tmp1[pos];
  free (tmp1);
  return val;
}

/*
 * Allocates the migration-history storage of world.
 *
 * Does nothing unless world->options->mighist is set.  Otherwise one
 * zeroed mighistloci_fmt is created per locus, each with
 * options->lsteps + 1 zeroed history records; the first options->lsteps
 * records additionally get a one-element event array.
 */
void
setup_mighist (world_fmt * world, option_fmt * options)
{
  long locus, step;
  mighistloci_fmt *slot;

  if (!world->options->mighist)
    return;

  world->mighistloci = (mighistloci_fmt *)
    calloc (world->loci, sizeof (mighistloci_fmt));
  world->mighistlocinum = 0;

  for (locus = 0; locus < world->loci; locus++)
    {
      slot = &world->mighistloci[locus];
      slot->mighist =
	(mighist_fmt *) calloc (1 + options->lsteps, sizeof (mighist_fmt));
      slot->mighistnum = 0;
      step = 0;
      while (step < options->lsteps)
	{
	  slot->mighist[step].migevents =
	    (migevent_fmt *) calloc (1, sizeof (migevent_fmt));
	  step++;
	}
      /* allocation will be in archive_timelist() */
    }
}

/*
 * Finds the smallest and the largest of the hist->count values in
 * hist->time and stores them in *tempmin and *tempmax.
 *
 * For an empty histogram *tempmin is DBL_MAX and *tempmax is -DBL_MAX.
 */
void
minmax (histogram_fmt * hist, double *tempmin, double *tempmax)
{
  long k;
  double lowest = DBL_MAX;
  double highest = -DBL_MAX;
  double t;

  k = 0;
  while (k < hist->count)
    {
      t = hist->time[k];
      if (t < lowest)
	lowest = t;
      if (t > highest)
	highest = t;
      k++;
    }
  *tempmax = highest;
  *tempmin = lowest;
}


/*
 * Bins the migration-event times and prints one ASCII histogram per
 * locus and population pair to out.
 *
 *   histogram  [loci][nmigs] event times and weights
 *   plotfield  [loci][nmigs] plot descriptors; their y, yfreq and data
 *              buffers are filled here and print is cleared for pairs
 *              without events
 *   loci       number of entries to print; when it is larger than 1 the
 *              last entry is headed "Over all loci"
 *   nmigs      number of population pairs per entry
 *   bins       number of time bins (nothing is printed for bins < 1)
 *   sum        [loci] event totals used to turn counts into frequencies
 *   migtable   [loci][nmigs][..]; column 2 == NOAVERAGE marks a pair
 *              without events
 *
 * All histograms share one time axis that runs from the smallest to the
 * largest event time found in any plotted pair.  The value printed for a
 * bin is its centre; an event is counted in the first bin whose centre
 * is not below its time, and events later than the last centre are
 * counted in the last bin.
 */
void
print_histogram_ascii (FILE * out, histogram_fmt ** histogram,
		       plotfield_fmt ** plotfield, long loci, long nmigs,
		       long bins, long *sum, double ***migtable)
{
  long l, i, j, z, zz;
  long barlen;
  double biggest = 0.;
  double *binning;
  double *binvec;
  double tempmin = DBL_MAX;
  double tempmax = -DBL_MAX;
  double begin = DBL_MAX;
  double end = -DBL_MAX;
  double delta;
  double time;
  long weight;

  if (bins < 1)
    return;
  binning = (double *) calloc (bins, sizeof (double));
  binvec = (double *) calloc (bins, sizeof (double));

  /* pass 1: common time range of everything that will be plotted */
  for (l = 0; l < loci; l++)
    {
      for (i = 0; i < nmigs; i++)
	{
	  if (migtable[l][i][2] == NOAVERAGE)
	    {
	      plotfield[l][i].print = FALSE;
	      continue;		// no event for this migration i->j
	    }
	  minmax (&histogram[l][i], &tempmin, &tempmax);
	  if (tempmin < begin)
	    begin = tempmin;
	  if (tempmax > end)
	    end = tempmax;
	}
    }
  delta = (end - begin) / bins;
  binning[0] = begin + 0.5 * delta;
  for (i = 1; i < bins; i++)
    binning[i] = delta + binning[i - 1];

  /* pass 2: count, normalize and draw the bars */
  for (l = 0; l < loci; l++)
    {
      for (i = 0; i < nmigs; i++)
	{
	  if (migtable[l][i][2] == NOAVERAGE)
	    continue;		// no event for this migration i->j
	  memset (binvec, 0, sizeof (double) * bins);
	  for (j = 0; j < histogram[l][i].count; j++)
	    {
	      time = histogram[l][i].time[j];
	      weight = histogram[l][i].weight[j];
	      z = 0;
	      /* Stop at the last bin: the largest times lie above the
	         last bin centre, and stepping past it would read
	         binning[bins] and write binvec[bins]. */
	      while (z < bins - 1 && time > binning[z])
		z++;
	      binvec[z] += weight;
	    }
	  biggest = 0.;
	  for (j = 0; j < bins; j++)
	    {
	      plotfield[l][i].y[j] = (long) binvec[j];
	      plotfield[l][i].yfreq[j] = binvec[j] = binvec[j] / sum[l];
	      if (biggest < binvec[j])
		biggest = binvec[j];
	    }
	  for (j = 0; j < bins; j++)
	    {
	      /* bar length relative to the tallest bin; no bar at all
	         when every bin is empty */
	      barlen = 0;
	      if (biggest > 0.)
		barlen =
		  (long) (binvec[j] * plotfield[l][i].ysize / biggest);
	      /* NOTE(review): the tallest bar writes ysize '+' plus the
	         NUL, so each data row must hold ysize + 1 chars --
	         confirm against the row size set up for plotfield. */
	      for (zz = 0; zz < barlen; zz++)
		plotfield[l][i].data[j][zz] = '+';
	      plotfield[l][i].data[j][zz] = '\0';
	    }
	}
    }

  /* pass 3: output */
  for (l = 0; l < loci; l++)
    {
      if (l == loci - 1 && loci > 1)
	fprintf (out,
		 "\nOver all loci\n------------------------------------------------------------------\n");
      else
	fprintf (out,
		 "\nLocus %li\n------------------------------------------------------------------\n",
		 l + 1);

      for (i = 0; i < nmigs; i++)
	{
	  if (!plotfield[l][i].print)
	    continue;
	  fprintf (out, "%s\n\n%10.10s %10.10s %10.10s\n",
		   plotfield[l][i].title, plotfield[l][i].xaxis,
		   plotfield[l][i].yaxis, plotfield[l][i].yfaxis);
	  for (j = 0; j < bins; j++)
	    {
	      fprintf (out,
		       "%10.6f %10li %10.6f %s\n", binning[j],
		       plotfield[l][i].y[j], plotfield[l][i].yfreq[j],
		       plotfield[l][i].data[j]);
	    }
	  fprintf (out, " \n");
	}
    }

  /* the two work arrays were previously leaked */
  free (binning);
  free (binvec);
}
