/*------------------------------------------------------
Maximum likelihood estimation
of migration rate and effective population size
using a Metropolis-Hastings Monte Carlo algorithm
-------------------------------------------------------
 AIC model test   R O U T I N E S
 
 Peter Beerli 2001, Seattle
 beerli@gs.washington.edu
 
 Copyright 2001-2002 Peter Beerli and Joseph Felsenstein
 
 This software is distributed free of charge for non-commercial use
 and is copyrighted. Of course, we do not guarantee that the software
 works and are not responsible for any damage you may cause or have.
 
$Id: aic.c,v 1.15 2002/06/20 06:25:33 beerli Exp $ 
-------------------------------------------------------*/
#include "migration.h"
#include "tools.h"
#include "broyden.h"
#include "combroyden.h"
#include "options.h"
#include "sort.h"
#include "aic.h"

#ifdef DMALLOC_FUNC_CHECK
#include <dmalloc.h>
#endif


/* Forward declarations for the routines defined in this file. */

/* progress / log banner announcing the start of the model search */
void print_progressheader_aic(boolean progress, FILE *file, double mleaic, double numparam);

/* exhaustive search over patterns, position by position */
void aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
                long zero, long which, char *temppattern, double *param0,
                char migtype);

/* TRUE when a connection pattern passes the checks in legal_pattern() */
boolean legal_pattern (char *matrix, long numpop);

/* level-wise search: evaluates one level, then recurses on the candidates */
void fast_aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
                     long zero, long which, char *temppattern,
                     double *param0, char migtype);

/* decides whether the pattern just evaluated is stored and reports it */
void add_aiclist (char migtype, long numparam, long freeparam, long remainnum, long zero, long m,
                  double *param0, char *temppattern, char *custm2,
                  long *aicnum, aic_fmt ** aicvec, nr_fmt * nr);

/* grows the result vector and fills the new slot */
void add_aicvec(double aic, aic_fmt **aicvec, long *aicnum, nr_fmt *nr, long numparam, long remainnum);

/* prints the result table to world->outfile */
void aic_print_driver(aic_fmt *aicvec, double mleaic, long aicnum, nr_fmt *nr, world_fmt *world);

/* seeds the result vector with the full model and starts the search */
void aic_score_driver(aic_struct *aic, double mle, double mleaic, nr_fmt *nr, world_fmt *world, char *temppattern);

/* entry point of the AIC analysis */
void akaike_information (world_fmt * world, long *Gmax);

/* prints the banners and the table header */
void print_header_aic (nr_fmt * nr, double mleaic);

/* writes one result line to the optional AIC file */
void print_aicfile(aic_fmt **aicvec, long *aicnum, nr_fmt *nr);

/* run-time report of an accepted (add) or rejected pattern */
void print_progress_aic(boolean add
                            , aic_fmt **aicvec, long *aicnum, nr_fmt *nr, long numparam, long freeparam);

// Converts a linearized pattern (numpop leading entries followed by the
// off-diagonal entries, e.g. "ddd mm mm mm") into a row-wise printable
// matrix of the form "xmm mxm mmx ": every diagonal cell becomes 'x',
// the off-diagonal cells are taken in order from origpattern[numpop...],
// and each row is followed by one blank.
// Overwrites the first numpop * (numpop + 1) characters of pattern and
// does not write a terminating '\0' itself; returns pattern.
char *
reshuffle (char *pattern, char *origpattern, long numpop)
{
    char *out = pattern;                          // next cell to write
    const char *offdiag = origpattern + numpop;   // skip the leading entries
    long row, col;

    for (row = 0; row < numpop; row++)
    {
        for (col = 0; col < numpop; col++)
            *out++ = (row == col) ? 'x' : *offdiag++;
        *out++ = ' ';                             // row separator
    }
    return pattern;
}

/* Run-time reporter: when progress is set, writes the banner that
   announces the model search to file, ending with the bound line that
   shows mleaic and numparam. Does nothing when progress is not set. */
void
print_progressheader_aic(boolean progress, FILE *file, double mleaic, double numparam)
{
    if (!progress)
        return;

    fputs ("\n\n", file);
    fputs ("           Selecting the best migration model for this run,\n", file);
    fputs ("           This may take a while!\n", file);
    fputs ("           Checking all parameter combinations\n", file);
    fprintf (file, "           with  (AIC= -2 Log(L(Param)+n_param)<%f+%f\n",
             mleaic, numparam);
}

/* Prints the run-time banner and the header of the AIC table.
     nr      supplies the world, partsize and numpop used for layout
     mleaic  AIC of the full model, shown in the banner
   The banner goes to stdout when options->progress is set and to
   options->logfile when options->writelog is set; the table header
   always goes to world->outfile.
*/
void
print_header_aic (nr_fmt * nr, double mleaic)
{
    world_fmt *world = nr->world;
    double bound = world->options->aicmod * world->numpop2;

    print_progressheader_aic (world->options->progress, stdout, mleaic, bound);
    /* the log copy belongs in the log file, not on stdout a second time */
    print_progressheader_aic (world->options->writelog, world->options->logfile,
                              mleaic, bound);

    fprintf (world->outfile, "\n\n\n Akaike's Information Criterion  (AIC)\n");
    fprintf (world->outfile, "=========================================\n\n");
    fprintf (world->outfile, "[Linearized migration matrix, x=diagonal]\n");
    /* the pattern column is at least 18 wide and shows at most
       partsize + numpop characters */
    fprintf (world->outfile,
             "%-*.*s            AIC     #param   Log(L)   LRT      Prob     Probc\n",
             (int) MAX (18, nr->partsize + nr->numpop),
             (int) (nr->partsize + nr->numpop), "Pattern");
}

// Builds the list of AIC scores.
// Entry 0 of aic->aicvec is the full model (mle, mleaic, lrt 0, both
// probabilities 1, current custm2 as pattern). The search is then run
// for options->aictype[0] and, if present, for options->aictype[1],
// using fast_aic_score() when options->fast_aic is set and aic_score()
// (which sets some parameters to zero and may need more investigation
// because of boundary problems) otherwise.
// options->custm2 and options->custm are restored to their entry values
// before returning.
void aic_score_driver(aic_struct *aic, double mle, double mleaic, nr_fmt *nr, world_fmt *world, char *temppattern)
{
    const size_t patbytes = sizeof (char) * nr->numpop2;
    char *keep_custm2 = calloc (world->numpop2, sizeof (char));
    char *keep_custm = calloc (world->numpop2, sizeof (char));
    aic_fmt *full;
    void (*search) (aic_fmt **, long *, nr_fmt *, long, long, char *,
                    double *, char);

    memcpy (keep_custm2, world->options->custm2, patbytes);
    memcpy (keep_custm, world->options->custm, patbytes);

    // the full model always occupies slot 0
    aic->aicnum = 1;
    aic->aicvec = calloc (aic->aicnum, sizeof (aic_fmt));
    full = &aic->aicvec[0];
    full->mle = mle;
    full->aic = mleaic;
    full->lrt = 0.0;
    full->prob = 1.0;
    full->probcorr = 1.0;
    full->numparam = nr->partsize;
    full->pattern = calloc (nr->partsize + 1, sizeof (char));
    memcpy (full->pattern, world->options->custm2, sizeof (char) * nr->partsize);

    search = world->options->fast_aic ? fast_aic_score : aic_score;

    search (&aic->aicvec, &aic->aicnum, nr, 0, world->numpop,
            temppattern, aic->param0, world->options->aictype[0]);
    // the second pass starts again from the original pattern
    memcpy (world->options->custm2, keep_custm2, patbytes);
    if (world->options->aictype[1] != '\0')
    {
        search (&aic->aicvec, &aic->aicnum, nr, 0, world->numpop,
                temppattern, aic->param0, world->options->aictype[1]);
    }

    memcpy (world->options->custm2, keep_custm2, patbytes);
    memcpy (world->options->custm, keep_custm, patbytes);
    free (keep_custm2);
    free (keep_custm);
}


// Writes the AIC table to world->outfile; the entries are printed in
// the order given (the caller sorts them beforehand).
// One line per entry:
//   pattern as migration matrix, AIC, number of parameters, Log(L),
//   LRT, probability of the standard LRT, probability of the weighted
//   chi-square
// A dashed separator follows the first entry whose AIC equals mleaic.
// The pattern string of every entry is freed here.
void aic_print_driver(aic_fmt *aicvec, double mleaic, long aicnum, nr_fmt *nr, world_fmt *world)
{
    static const char rule[] =
        "--------------------------------------------------------------------------------------------------------------------------------------------------------------";
    const int width = (int) MAX (18, nr->partsize + nr->numpop);
    const int precision = (int) (nr->partsize + nr->numpop);
    boolean separator_done = FALSE;
    char *matrixtext = calloc (world->numpop2 + 1 + world->numpop, sizeof (char));
    long k;

    for (k = 0; k < aicnum; k++)
    {
        aic_fmt *entry = &aicvec[k];

        fprintf (world->outfile, "%-*.*s %20.5f %4li % f % f %f %f\n",
                 width, precision,
                 reshuffle (matrixtext, entry->pattern, nr->numpop),
                 entry->aic, entry->numparam, entry->mle,
                 entry->lrt, entry->prob, entry->probcorr);
        if (!separator_done && entry->aic == mleaic)
        {
            separator_done = TRUE;
            fprintf (world->outfile, "%-*.*s%21.21s-----\n",
                     width, precision, rule, "---------------------");
        }
        free (entry->pattern);
    }
    free (matrixtext);
}

// this drives the aic calculation and is called from main.c
//
// world  the run whose maximum likelihood result is the reference model
// Gmax   passed (dereferenced) to create_nr()
//
// Steps: set up the optimizer state, compute the AIC of the full model,
// print the table header, collect the scores of the reduced models,
// sort them with aiccmp and print the table to world->outfile.
// Side effect: world->options->migration_model is set to
// MATRIX_ARBITRARY and is not restored in this function.
void
akaike_information (world_fmt * world, long *Gmax)
{
    aic_struct aic;
    nr_fmt *nr;
    // more than one locus selects the multilocus summary
    long kind = world->loci > 1 ? MULTILOCUS : SINGLELOCUS;
    long repstop;
    long repstart;

    double mleaic;
    double mle;
    boolean multilocus;

    char *temppattern;

    prepare_broyden (kind, world, &multilocus);
    world->options->migration_model = MATRIX_ARBITRARY;

    // scratch buffer for reshuffle(): numpop2 cells + numpop blanks + '\0'
    temppattern = calloc (world->numpop2 + 1 + world->numpop, sizeof (char));
    aic.param0 = calloc (world->numpop2 + 1, sizeof (double));

    set_replicates (world, world->repkind, world->rep, &repstart, &repstop);

    // AIC of the full model: -2 Log(L) + 2 * numpop2 (same formula in
    // both branches, only the source of the likelihood differs)
    if (kind == MULTILOCUS)
    {
        mle = world->atl[0][world->loci].param_like;
        mleaic = -2. * mle + 2. * world->numpop2;
    }
    else
    {
        // NOTE(review): the likelihood is read at locus index 0 here,
        // while the parameters below are read at nr->world->locus --
        // confirm that both refer to the same locus
        mle = world->atl[repstop == 1 ? 0 : repstop][0].param_like;
        mleaic = -2. * mle + 2. * world->numpop2;
    }
    nr = (nr_fmt *) calloc (1, sizeof (nr_fmt));

    create_nr (nr, world, *Gmax, 0, world->loci, world->repkind, world->rep);

    SETUPPARAM0 (world, nr, world->repkind, repstart, repstop,
                 world->loci, kind, multilocus);

    print_header_aic (nr, mleaic);

    // starting parameters for every evaluated pattern: the stored
    // parameter values of the full model
    if (kind == MULTILOCUS)
        memcpy (aic.param0, nr->world->atl[0][nr->world->loci].param, sizeof (double) * nr->numpop2);
    else
        memcpy (aic.param0, nr->world->atl[repstop ==1 ? 0 : repstop][nr->world->locus].param, sizeof (double) * nr->numpop2);

    aic_score_driver(&aic, mle, mleaic, nr, world, temppattern);

    qsort ((void *) aic.aicvec, aic.aicnum, sizeof (aic_fmt), aiccmp);

    // prints the table and frees the pattern string of every entry
    aic_print_driver(aic.aicvec, mleaic, aic.aicnum, nr, world);

    free (aic.aicvec);
    fflush (world->outfile);
    free (aic.param0);
    free (temppattern);
    // NOTE(review): nr was allocated with calloc() above and is not
    // freed here -- presumably destroy_nr() releases it; verify
    destroy_nr (nr, world);
}

// Level-wise model search: evaluates every pattern that differs from
// the current one by setting a single additional migration entry to
// migtype, records the AIC of each, and then recurses on every entry
// whose AIC is below the border value. This is different from the
// "branch-bound" algorithm in aic_score().
//
// aicvec, aicnum  result vector and its length (entry 0 = full model)
// nr              optimizer state; nr->world->options->custm2 is the
//                 pattern that is modified in place
// zero            recursion depth, used as number of constrained
//                 parameters on this level
// which           not used by this function
// temppattern     scratch buffer for reshuffle()
// param0          start parameters copied into world->param0 before
//                 each evaluation
// migtype         pattern character to test ('m' adds one remaining
//                 parameter to the free-parameter count)
//
// Entries selected for recursion stay set to migtype in custm2 when the
// function returns.
// NOTE(review): best is sized with nr->partsize but filled for indices
// up to nr->numpop2 - 1 and scanned up to nr->partsize - 1 -- confirm
// the relation between partsize and numpop2.
void
fast_aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
                long zero, long which, char *temppattern,
                double *param0, char migtype)
{
    long m, ii;
    double likes = 0;
    double normd = 0;
    double aic;
    double borderaic = (*aicvec)[0].aic + nr->world->options->aicmod * nr->numpop2;
    char savecustm2;
    long remainnum = 0;
    char *custm2 = nr->world->options->custm2;
    char *scustm2;
    long numparam;
    long freeparam;
    aic_fmt *best;

    // snapshot of the pattern on entry: entries that already carry
    // migtype are skipped on this level
    scustm2 = (char *) calloc (nr->partsize, sizeof (char));
    memcpy (scustm2, custm2, sizeof (char) * nr->partsize);
    if (migtype == 'm')
        remainnum = 1;
    numparam = zero;
    freeparam = (nr->numpop2 - numparam - 1 + remainnum);
    best = (aic_fmt *) calloc (nr->partsize, sizeof (aic_fmt));
    for (m = nr->numpop; m < nr->numpop2; m++)
    {
        // DBL_MAX marks "not evaluated" so the entry is never selected
        best[m].aic = DBL_MAX;
        best[m].numparam = m;
        if (scustm2[m] == migtype)
            continue;
        savecustm2 = custm2[m];
        custm2[m] = migtype;
        memcpy (nr->world->param0, param0, sizeof (double) * nr->numpop2);
        resynchronize_param (nr->world);
        if (legal_pattern (nr->world->options->custm2, nr->numpop))
        {
            do_profiles (nr->world, nr, &likes, &normd, PROFILE,
                         nr->world->rep, nr->world->repkind);
            aic = -2. * nr->llike + 2. * freeparam;
            best[m].aic = aic;
            best[m].numparam = m;
            add_aiclist(migtype, numparam, freeparam, remainnum, zero, m,
                        param0, temppattern, custm2, aicnum, aicvec, nr);
        }
        else
        {   // illegal combination of parameters
            if (nr->world->options->progress)
                fprintf (stdout, "           F   %s %20s\n",
                         reshuffle (temppattern, custm2, nr->numpop), "-----");
            fflush (stdout);
            if (nr->world->options->writelog)
                fprintf (nr->world->options->logfile, "           F   %s %20s\n",
                         reshuffle (temppattern, custm2, nr->numpop), "-----");
        }
        custm2[m] = savecustm2;
    }
    // descend into every candidate that stayed below the border
    for (ii = nr->numpop; ii < nr->partsize; ii++)
    {
        if (best[ii].aic < borderaic && custm2[best[ii].numparam] != migtype)
        {
            custm2[best[ii].numparam] = migtype;
            fast_aic_score (aicvec, aicnum, nr, zero + 1, best[ii].numparam,
                            temppattern, param0, migtype);
        }
    }
    free (best);
    // the snapshot was leaked on every (recursive) call before
    free (scustm2);
}

// Appends one result to *aicvec at index *aicnum (the vector is grown
// by realloc; *aicnum itself is not changed here).
// The new entry stores aic, the current nr->llike, the number of free
// parameters (numpop2 - numparam - 1 + remainnum), a copy of the current
// custm2 pattern, the likelihood ratio statistic against entry 0 and
// its two probabilities.
void add_aicvec(double aic, aic_fmt **aicvec, long *aicnum, nr_fmt *nr, long numparam, long remainnum)
{
    // computed first, while entry 0 is read through the old pointer
    const double ratio = -2. * (nr->llike - (*aicvec)[0].mle);
    aic_fmt *slot;

    *aicvec = realloc (*aicvec, sizeof (aic_fmt) * (*aicnum + 1));
    slot = *aicvec + *aicnum;

    slot->pattern = calloc (nr->partsize + 1, sizeof (char));
    slot->aic = aic;
    slot->mle = nr->llike;
    slot->numparam = nr->numpop2 - numparam - 1 + remainnum;
    memcpy (slot->pattern, nr->world->options->custm2, sizeof (char) * nr->partsize);

    slot->lrt = ratio;
    slot->prob = probchi (numparam, slot->lrt);
    slot->probcorr = probchiboundary (slot->lrt, numparam, numparam);
}

// Writes entry *aicnum of *aicvec to options->aicfile, if that file is
// set: AIC, LRT, number of parameters and both probabilities, followed
// by the first partsize values of world->param0 and a newline.
void
print_aicfile(aic_fmt **aicvec, long *aicnum, nr_fmt *nr)
{
    FILE *out = nr->world->options->aicfile;
    const aic_fmt *entry;
    long k;

    if (!out)
        return;

    entry = &(*aicvec)[*aicnum];
    fprintf (out, "%f %f %li %f  %f ",
             entry->aic, entry->lrt, entry->numparam,
             entry->prob, entry->probcorr);

    for (k = 0; k < nr->partsize; k++)
        fprintf (out, "%f ", nr->world->param0[k]);
    fprintf (out, "\n");
}

// Run-time report for one evaluated pattern.
// add        TRUE: the pattern was stored, the values of entry *aicnum
//            are shown and the line is tagged '+';
//            FALSE: the pattern was rejected, the values are computed
//            from nr->llike and the line is tagged '-'
// aicvec     result vector (entry 0 is the reference for the LRT)
// aicnum     index of the entry just stored
// nr         optimizer state, supplies llike, custm2 and the options
// numparam   degrees of freedom for the probabilities (rejected case)
// freeparam  number of free parameters shown in the line
// The line goes to stdout when options->progress is set and to
// options->logfile when options->writelog is set; both copies carry the
// same tag.
void
print_progress_aic(boolean add
                       , aic_fmt **aicvec, long *aicnum, nr_fmt *nr, long numparam, long freeparam)
{
    double mle, aic, lrt, prob, probcorr;
    char tag;
    char *custm2 = nr->world->options->custm2;
    char *temppattern = calloc (nr->world->numpop2 + 1 + nr->world->numpop, sizeof (char));

    // without the scratch buffer there is nothing that can be printed
    if (temppattern == NULL)
        return;

    if (add)
    {
        const aic_fmt *entry = &(*aicvec)[*aicnum];

        tag = '+';
        aic = entry->aic;
        mle = entry->mle;
        lrt = entry->lrt;
        prob = entry->prob;
        probcorr = entry->probcorr;
    }
    else
    {
        tag = '-';
        mle = nr->llike;
        aic = -2. * mle + 2. * freeparam;
        lrt = -2. * (mle - (*aicvec)[0].mle);
        prob = probchi (numparam, lrt);
        probcorr = probchiboundary (lrt, numparam, numparam);
    }

    reshuffle (temppattern, custm2, nr->numpop);
    if (nr->world->options->progress)
        fprintf (stdout,
                 "           %c   %s %20.5f %3li %8.4f %8.4f %6.4f %6.4f\n",
                 tag, temppattern, aic, freeparam, mle, lrt, prob, probcorr);
    // the log used to tag rejected patterns with '+'; it now matches stdout
    if (nr->world->options->writelog)
        fprintf (nr->world->options->logfile,
                 "           %c   %s %20.5f %3li %8.4f %8.4f %6.4f %6.4f\n",
                 tag, temppattern, aic, freeparam, mle, lrt, prob, probcorr);

    fflush (stdout);
    free (temppattern);
}

// Decides what happens with the pattern that was just evaluated.
// Its AIC is -2 * nr->llike + 2 * freeparam. When that value is not
// below the AIC of entry 0 plus aicmod * freeparam, the pattern is only
// reported as rejected. Otherwise it is stored, written to the AIC file,
// reported as accepted, and the search continues with aic_score() at
// position m + 1 -- except for migtype 'm' with options->mmn not above
// 1, where nothing is stored or reported.
void
add_aiclist (char migtype, long numparam, long freeparam, long remainnum, long zero, long m,
             double *param0, char *temppattern, char *custm2,
             long *aicnum, aic_fmt ** aicvec, nr_fmt * nr)
{
    const double score = -2. * nr->llike + 2. * freeparam;

    if (!(score < (*aicvec)[0].aic + nr->world->options->aicmod * freeparam))
    {
        print_progress_aic (FALSE, aicvec, aicnum, nr, numparam, freeparam);
        return;
    }

    if (migtype == 'm' && !(nr->world->options->mmn > 1))
        return;

    add_aicvec (score, aicvec, aicnum, nr, numparam, remainnum);
    print_aicfile (aicvec, aicnum, nr);
    print_progress_aic (TRUE, aicvec, aicnum, nr, numparam, freeparam);
    (*aicnum)++;
    // descend: try further entries on top of the accepted pattern
    aic_score (aicvec, aicnum, nr, zero + 1, m + 1,
               temppattern, param0, migtype);
}


// Tries migtype at every position from which up to numpop2 - 1 of the
// custm2 pattern, one position at a time. For each position the start
// parameters are reset from param0, the pattern is checked with
// legal_pattern(), legal patterns are evaluated with do_profiles() and
// handed to add_aiclist() (which may recurse into this function), and
// illegal ones are reported with an 'F' line. The position is restored
// to its previous character afterwards.
// For migtype 'm' the function returns immediately when custm2[which]
// is already 'm'.
void
aic_score (aic_fmt ** aicvec, long *aicnum, nr_fmt * nr,
           long zero, long which, char *temppattern, double *param0,
           char migtype)
{
    world_fmt *world = nr->world;
    char *custm2 = world->options->custm2;
    double likes = 0;
    double normd = 0;
    long numparam = 0;
    long remainnum = 0;
    long freeparam;
    long pos;
    char previous;

    if (migtype == '0')
    {
        numparam = world->options->zeron;
    }
    else if (migtype == 'm')
    {
        if (custm2[which] == 'm')
            return;
        numparam = world->options->mmn;
        remainnum = 1;
    }
    freeparam = nr->numpop2 - numparam - 1 + remainnum;

    for (pos = which; pos < nr->numpop2; pos++)
    {
        previous = custm2[pos];
        custm2[pos] = migtype;
        memcpy (world->param0, param0, sizeof (double) * nr->numpop2);
        resynchronize_param (world);

        if (legal_pattern (custm2, nr->numpop))
        {
            do_profiles (world, nr, &likes, &normd, PROFILE,
                         world->rep, world->repkind);

            add_aiclist (migtype, numparam, freeparam, remainnum, zero, pos,
                         param0, temppattern, custm2, aicnum, aicvec, nr);
        }
        else
        {
            // illegal combination of parameters
            if (world->options->progress)
            {
                fprintf (stdout, "           F   %s %20s\n",
                         reshuffle (temppattern, custm2, nr->numpop), "-----");
                fflush (stdout);
            }
            if (world->options->writelog)
                fprintf (world->options->logfile, "           F   %s %20s\n",
                         reshuffle (temppattern, custm2, nr->numpop), "-----");
        }
        custm2[pos] = previous;
    }
}

// Looks up the parameter counts that belong to migtype.
//   '0': *numparam = options->zeron, *remainnum = 0
//   'm': *numparam = options->mmn,   *remainnum = 1
// Any other migtype leaves both outputs untouched.
// Returns TRUE only for migtype 'm' when options->custm2[which] is
// already 'm', FALSE in all other cases.
boolean
check_numparam (long which, long migtype, worldoption_fmt * options,
                long *numparam, long *remainnum)
{
    if (migtype == '0')
    {
        *numparam = options->zeron;
        *remainnum = 0;
        return FALSE;
    }
    if (migtype == 'm')
    {
        *numparam = options->mmn;
        *remainnum = 1;
        return (options->custm2[which] == 'm') ? TRUE : FALSE;
    }
    return FALSE;
}

// Checks a linearized connection pattern.
// Returns FALSE when one of the first numpop entries is '0', or when,
// for a run of consecutive off-diagonal positions that m2mm() maps to
// the same "to" index, every entry and its mirrored entry
// (mm2m (to, from, numpop)) is '0'. Returns TRUE otherwise.
boolean
legal_pattern (char *matrix, long numpop)
{
    const long cells = numpop * numpop;
    long from, to, pos;
    long group = -1;    // "to" index of the run being counted
    long links = -1;    // non-'0' entries seen in the run; -1 = no run yet

    for (pos = 0; pos < numpop; pos++)
        if (matrix[pos] == '0')
            return FALSE;

    for (pos = numpop; pos < cells; pos++)
    {
        m2mm (pos, numpop, &from, &to);
        if (group != to)
        {
            // a new run starts: the finished one needs at least one link
            if (links == 0)
                return FALSE;
            group = to;
            links = 0;
        }
        if (matrix[pos] != '0')
            links++;
        if (matrix[mm2m (to, from, numpop)] != '0')
            links++;
    }
    // the last run is checked after the loop
    return (links == 0) ? FALSE : TRUE;
}
