/*-----------------------------------------------------------------------------

Copyright (C) 2014.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA   

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

// This code is from gsm_develop/svfx.  
// Errors discovered here should be fixed there also.

#include <cerrno>
#include "snp_stat_mod.h"

using namespace scl;
using namespace libsnp;
using namespace snp;
using namespace std;

namespace {

  // Returns true when every element of the simulated draw y is usable,
  // false as soon as one is not.
  //
  // With GNU_GPP_COMPILER defined the test is isfinite().  Otherwise an
  // element is rejected unless |y[i]| <= REAL_MAX/4.  The comparison is
  // written as !(a <= b) rather than a > b so that NaN is also rejected:
  // every ordered comparison involving NaN is false, so the former test
  // fabs(y[i]) > quarter_max silently accepted NaN draws.
  bool good_sim(const realmat& y)
  {
    #if defined GNU_GPP_COMPILER
      for (INTEGER i=1; i<=y.size(); ++i) {
        // isfinite replaces the deprecated finite().
        if (!isfinite(y[i])) return false;
      }
    #else
      const REAL quarter_max = REAL_MAX/4.0;
      for (INTEGER i=1; i<=y.size(); ++i) {
        if (!(fabs(y[i]) <= quarter_max)) return false;
      }
    #endif
    return true;
  }

}

mle::snp_stat_mod::snp_stat_mod
  (const realmat* dat_ptr, const vector<string>& pfvec, 
   const vector<string>& alvec, ostream& detail)
{
  // Builds the model from the data, the parmfile lines (pfvec) and the
  // ancillary lines (alvec).  Steps:
  //   1. parse the parmfile into pf and check the data against (M,n);
  //   2. fit the transform tr to the data, normalize Y and build the
  //      (optionally squashed) copy X;
  //   3. keep the first `drop` columns of Y and X in Ylags/Xlags, which
  //      snp_simulate uses to start its recursions;
  //   4. build the likelihood ll on X and Y;
  //   5. read lsim and spin from alvec and set lrho and lstats.
  // Diagnostics go to `detail` when the parmfile's print option is set.

  // A parse failure usually means gsmusr.h binds the model typedef to a
  // class other than this one.
  if (!pf.set_parms(pfvec, detail)) {
     detail.flush();
     string msg("Error, snp_stat_mod, cannot read parmfile.  Make sure\n");
     msg += "the typedef in gsmusr.h is bound to the correct class.";
     error(msg);
  }

  Y = *dat_ptr;

  const INTEGER M = pf.get_datparms().M;
  const INTEGER n_obs = pf.get_datparms().n;
  const INTEGER drop = pf.get_datparms().drop;
  const bool print_detail = pf.get_optparms().print;

  if (Y.get_rows()!=M || Y.get_cols()!=n_obs) {
    error("Error, snp_stat_mod, dim data disagrees with (M,n) in pfvec");
  }

  // Show at most 12 columns from each end, so short data sets do not
  // request columns outside Y.
  const INTEGER nshow = n_obs < 12 ? n_obs : 12;

  if (print_detail) {
    detail << starbox("/First 12 observations//");
    detail << Y("",seq(1,nshow));
    detail << starbox("/Last 12 observations//");
    detail << Y("",seq(n_obs-nshow+1, n_obs));
  }

  // Fit the transform to the raw data and store the resulting
  // transformation parameters back into pf.
  trnfrm tr_tmp(pf.get_datparms(), pf.get_tranparms(), Y);
  tr = tr_tmp;

  pf.set_tranparms(tr.get_tranparms());

  if (print_detail) {
    detail << starbox("/Mean and variance of data//");
    if (pf.get_tranparms().diag) {
      detail << "(Variance has been diagonalized.)\n";
    }
    detail << pf.get_tranparms().mean << pf.get_tranparms().variance;
  }

  // Y is the normalized data.  X is the normalized data, passed through
  // spline (squash == 1) or logistic (squash == 2) when squashing is on.
  X = Y;

  tr.normalize(Y);
  tr.normalize(X);

  if (pf.get_tranparms().squash == 1) {
    tr.spline(X);
  } else if (pf.get_tranparms().squash == 2) {
    tr.logistic(X);
  }

  if (print_detail) {
    detail << starbox("/First 12 normalized observations//");
    detail << Y("",seq(1,nshow));
    detail << starbox("/Last 12 normalized observations//");
    detail << Y("",seq(n_obs-nshow+1, n_obs));
    if (pf.get_tranparms().squash > 0) {
      // The squashed series lives in X, not Y.
      detail << starbox("/First 12 transformed observations//");
      detail << X("",seq(1,nshow));
      detail << starbox("/Last 12 transformed observations//");
      detail << X("",seq(n_obs-nshow+1, n_obs));
    }
  }

  // Keep the first `drop` columns of Y and X.  snp_simulate replays them
  // to bring the recursions up to date before it starts drawing.
  Ylags.resize(M, drop);
  Xlags.resize(M, drop);

  for (INTEGER t=1; t<=drop; ++t) {
    for (INTEGER i=1; i<=M; ++i) {
      Ylags(i,t) = Y(i,t);
      Xlags(i,t) = X(i,t);
    }
  }

  snpll ll_tmp(pf.get_datparms(),
    pf.get_snpden(), pf.get_afunc(), pf.get_ufunc(),
    pf.get_rfunc(), pf.get_afunc_mask(), pf.get_ufunc_mask(),
    pf.get_rfunc_mask());

  ll = ll_tmp;
  ll.set_XY(&X,&Y);

  // alvec[0] is not read here (presumably a header line -- confirm).
  // alvec[1] and alvec[2] carry lsim and spin in their first 12 characters.
  if (alvec.size() < 3) {
    error("Error, snp_stat_mod, alvec must contain lsim and spin");
  }
  lsim = atoi(alvec[1].substr(0,12).c_str());
  spin = atoi(alvec[2].substr(0,12).c_str());

  // gen_sim produces M means followed by the M*M entries of the
  // second-moment matrix.
  lrho = ll.get_rho().size();
  lstats = M + M*M;

  if (print_detail) {
    detail << starbox("/Additional settings//");
    detail << "\t lsim = " << lsim << '\n';
    detail << "\t spin = " << spin << '\n';
    detail << "\t lrho = " << lrho << '\n';
    detail << "\t lstats = " << lstats << '\n';
  }
 
  if (print_detail) detail.flush();
}


bool mle::snp_stat_mod::snp_simulate(INT_32BIT& simseed, realmat& sim)
{
  // Simulates from the SNP model at the parameters currently held by ll.
  //
  // simseed  seed passed by reference to snpden::sampy for every draw.
  // sim      output; sized M x lsim0 and filled column by column with
  //          unnormalized draws.
  // Returns false as soon as a draw fails good_sim (sim is then only
  // partly filled), true otherwise.
  //
  // Phases:
  //   1. columns 3..drop of Ylags/Xlags are replayed through uf/rf/af and f;
  //   2. draws for t = drop+1..spin0 are generated and discarded;
  //   3. lsim0 further draws are unnormalized and stored in sim.

  datparms dpm = ll.get_datparms(); 
  tranparms tpm = tr.get_tranparms();

  snpden f = ll.get_snpden();
  afunc af = ll.get_afunc();
  ufunc uf = ll.get_ufunc();
  rfunc rf = ll.get_rfunc();

  // Number of retained draws: lsim, but never fewer than drop + M*(M+1).
  // sim must be sized to lsim0 (not lsim): the retention loop below writes
  // lsim0 columns, so sizing to lsim wrote past the end of sim whenever
  // lsim < drop + M*(M+1).
  INTEGER min_lsim = dpm.drop + dpm.M*(dpm.M+1);
  INTEGER lsim0 = min_lsim > lsim ? min_lsim : lsim;

  if (sim.get_rows() != dpm.M || sim.get_cols() != lsim0) {
    sim.resize(dpm.M,lsim0);
  }

  realmat y(dpm.M,1);
  realmat x(dpm.M,1);
  realmat u(dpm.M,1);
  realmat y_lag(dpm.M,1);
  realmat x_lag(dpm.M,1);
  realmat u_lag(dpm.M,1);

  af.initialize_state();
  uf.initialize_state();
  rf.initialize_state();

  for (INTEGER i=1; i<=dpm.M; ++i) {
    x[i] = Xlags(i,1);
  }

  u = uf(x);

  // Phase 1: replay the stored in-sample lags.
  for (INTEGER t=3; t<=dpm.drop; ++t) {

    for (INTEGER i=1; i<=dpm.M; ++i) {
      y[i] = Ylags(i,t);
      x[i] = Xlags(i,t);
      y_lag[i] = Ylags(i,t-1);
      x_lag[i] = Xlags(i,t-1);
    }

    u_lag = u;

    u = uf(x_lag);

    f.set_R(rf(y_lag,u_lag,x_lag));
    f.set_a(af(x_lag));
    f.set_u(u);
  }

  // Phase 2: discarded draws; at least M of them.
  INTEGER min_spin = dpm.drop + dpm.M;
  INTEGER spin0 = min_spin > spin ? min_spin : spin;

  for (INTEGER t=dpm.drop+1; t<=spin0; ++t) {

    for (INTEGER i=1; i<=dpm.M; ++i) {
      y_lag[i] = y[i];
      x_lag[i] = x[i];
    }

    u_lag = u;

    u = uf(x_lag);

    f.set_R(rf(y_lag,u_lag,x_lag));
    f.set_a(af(x_lag));
    f.set_u(u);

    y = f.sampy(simseed);
    if (!good_sim(y)) return false;
    x = y;
    if (tpm.squash == 1) {
      tr.spline(x);
    }
    else if (tpm.squash == 2) {
      tr.logistic(x);
    }
  }

  // Phase 3: retained draws, mapped back to the scale of the data.
  for (INTEGER t=1; t<=lsim0; ++t) {

    for (INTEGER i=1; i<=dpm.M; ++i) {
      y_lag[i] = y[i];
      x_lag[i] = x[i];
    }

    u_lag = u;

    u = uf(x_lag);

    f.set_R(rf(y_lag,u_lag,x_lag));
    f.set_a(af(x_lag));
    f.set_u(u);

    y = f.sampy(simseed);
    if (!good_sim(y)) return false;
    x = y;
    if (tpm.squash == 1) {
      tr.spline(x);
    }
    else if (tpm.squash == 2) {
      tr.logistic(x);
    }

    realmat v = y;
    tr.unnormalize(v);

    for (INTEGER i=1; i<=v.size(); ++i) {
      sim(i,t) = v[i];
    }
  }

  return true;
}

bool mle::snp_stat_mod::gen_sim(realmat& sim, realmat& stats) 
{ 
  // Simulates with a fixed seed into sim and summarizes the draws in
  // stats: the M column means followed by the M*M entries (in realmat
  // linear order) of the average outer product of deviations from the mean.
  // Returns false if the simulation fails; stats has already been resized
  // to lstats zeros by then.
  stats.resize(lstats, 1, 0.0);

  // The same seed on every call, so every call uses the same draws.
  INT_32BIT seed = 770116;
  if (!snp_simulate(seed, sim)) {
    return false;
  }

  const INTEGER n_rows = sim.get_rows();
  const INTEGER n_cols = sim.get_cols();

  if (lstats != n_rows + n_rows*n_rows) {
    error("Error, snp_stat_mod constructor and gen_sim lstats values differ");
  }

  // Empty index vector used as the row selector, so sim(every_row,t)
  // yields column t (mean below is n_rows x 1).
  intvec every_row;

  realmat mean(n_rows,1,0.0);
  for (INTEGER t=1; t<=n_cols; ++t) {
    const realmat column = sim(every_row,t);
    mean += column;
  }
  mean = mean/n_cols;

  realmat var(n_rows,n_rows,0.0);
  for (INTEGER t=1; t<=n_cols; ++t) {
    const realmat dev = sim(every_row,t) - mean;
    var += dev*T(dev);
  }
  var = var/n_cols;

  for (INTEGER k=1; k<=lstats; ++k) {
    stats[k] = (k <= n_rows) ? mean[k] : var[k-n_rows];
  }

  return true;
}

bool mle::snp_stat_mod::write_parmfile(const char* filename)
{
  string ctrl = "Generated by snp_stat_mod";
  pf.set_afunc(ll.get_afunc());
  pf.set_ufunc(ll.get_ufunc());
  pf.set_rfunc(ll.get_rfunc());
  return pf.write_parms(filename, ctrl, ll);
}

void mle::snp_stat_mod::get_parm(realmat& rho) 
{
  // Copies the likelihood's current rho into the caller's matrix.
  rho = this->ll.get_rho();
}

void mle::snp_stat_mod::set_parm(const realmat& rho) 
{ 
  ll.set_rho(rho);
}

bool mle::snp_stat_mod::support(const realmat& rho) 
{
  realmat test = ll.get_theta()(ll.get_srvec(),1);
  for (INTEGER i=1; i<=rho.size(); ++i) if (!isfinite(rho[i])) return false;
  for (INTEGER i=1; i<=test.size(); ++i) if (test[i] <= 0.0) return false;
  if (ll.get_uf_stability() >= 1.0) return false;
  if (ll.get_rf_stability() >= 1.0) return false;
  return true; 
}

den_val mle::snp_stat_mod::prior(const realmat& rho, const realmat& stats) 
{
  // Constant prior: always flagged true with value 0.0 (presumably a log
  // density, which makes this a flat prior); neither argument is read.
  den_val flat(true, 0.0);
  return flat;
}
 

void mle::snp_stat_mod::set_data_ptr(const realmat* dat_ptr)
{
  Y = *dat_ptr;

  tr.normalize(Y);

  if (tr.get_tranparms().squash == 1) {
    X = Y;
    tr.spline(X);
    ll.set_XY(&X,&Y);
  } else if (tr.get_tranparms().squash == 2) {
    X = Y;
    tr.logistic(X);
    ll.set_XY(&X,&Y);
  }
  else {
    ll.set_XY(&Y,&Y);
  }

  ll.set_n(Y.get_cols());
}

den_val mle::snp_stat_mod::loglikelihood()
{
  // Log likelihood at ll's current parameters, always flagged true.
  const REAL value = ll.log_likehood();
  return den_val(true, value);
}

den_val mle::snp_stat_mod::loglikelihood(realmat& dlogl)
{
  // As loglikelihood(), but dlogl also receives whatever
  // ll.log_likehood(dlogl) writes into it (presumably the gradient).
  const REAL value = ll.log_likehood(dlogl);
  return den_val(true, value);
}

den_val mle::snp_stat_mod::loglikelihood(realmat& dlogl,realmat& infmat)
{
  // As loglikelihood(dlogl), and infmat is filled by ll as well
  // (presumably an information-matrix estimate -- see snpll).
  const REAL value = ll.log_likehood(dlogl,infmat);
  return den_val(true, value);
}

den_val mle::snp_stat_mod::loglikelihood
  (realmat& dlogl,realmat& infmat, realmat& scores)
{
  // As loglikelihood(dlogl,infmat), and scores is filled by ll as well.
  return den_val(true, ll.log_likehood(dlogl,infmat,scores));
}

REAL mle::snp_stat_mod::penalty(realmat& dpenalty) 
{
  // Smooth penalty on the elements of theta selected by srvec:
  //   sum = sum over j in srvec of smooth.max0(-theta[j]).
  // Presumably max0 is a smoothed max(0,.), so the penalty grows as a
  // selected element goes negative -- TODO confirm against smomax.
  // dpenalty (1 x len_parm()) receives the per-element derivative terms,
  // mapped from theta indices to rho indices through ll.get_rt().
  smomax smooth;
  intvec ivec = ll.get_srvec();
  realmat theta = ll.get_theta();
  REAL sum = 0.0;
  realmat d_sum_wrt_theta(1,theta.size(),0.0);
  for (INTEGER i=1; i<=ivec.size(); ++i) {
    const INTEGER j = ivec[i];
    sum += smooth.max0(-theta[j]);
    // NOTE(review): by the chain rule d/dtheta max0(-theta) equals
    // -max0'(-theta).  If max1 is max0's first derivative, this term has
    // the opposite sign; verify smomax::max1 before relying on dpenalty.
    d_sum_wrt_theta[j] += smooth.max1(-theta[j]);
  }

  // Give resize an explicit fill value, so any rho element not listed in rt
  // gets a zero derivative rather than whatever the resize left there.
  dpenalty.resize(1,len_parm(),0.0);
  vector< pair<INTEGER,INTEGER> > rt = ll.get_rt(); 
  vector< pair<INTEGER,INTEGER> >::const_iterator rtptr;
  for (rtptr = rt.begin(); rtptr != rt.end(); ++rtptr) {
    // first = rho index, second = theta index.
    dpenalty[rtptr->first] = d_sum_wrt_theta[rtptr->second];
  }
  return sum; 
}
