/* ----------------------------------------------------------------------------

Copyright (C) 2005, 2006, 2008, 2011.

A. Ronald Gallant
Post Office Box 659
Chapel Hill NC 27514-0659
USA

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

-----------------------------------------------------------------------------*/

#include "libgsm.h"
using namespace std;  
using namespace scl;
using namespace libgsm;

namespace {

  void do_nothing(const std::string s) { return; }
  LIB_WARN_HANDLER_PTR nowarn = &do_nothing;
}

namespace libgsm {

  bool near_eq(REAL lhs, REAL rhs, REAL tol)
  {
    tol = tol > 0 ? 0.5*tol : sqrt(EPS);
    if ( fabs(lhs-rhs) > tol*(fabs(lhs)+fabs(rhs)+tol) ) return false;
    return true;
  }
  
  bool near_eq(den_val lhs, den_val rhs, REAL tol)
  {
    if (lhs.positive != rhs.positive) return false;
    return near_eq(lhs.log_den, rhs.log_den, tol);
  }
    
  bool near_eq(const realmat& lhs, const realmat& rhs, REAL tol)
  {
    INTEGER rows = lhs.get_rows();
    if (rows != rhs.get_rows()) return false;
    INTEGER cols = lhs.get_cols();
    if (cols != rhs.get_cols()) return false;
    for (INTEGER i=1; i<=rows*cols; ++i) {
      if (!near_eq(rhs[i],lhs[i],tol)) return false;
    }
    return true;
  }

  bool near_eq(const map_val& lhs, const map_val& rhs, REAL tol) 
  {
    if(lhs.sci_mod_support != rhs.sci_mod_support) return false;
    if(lhs.sci_mod_simulate != rhs.sci_mod_simulate) return false;
    if(!near_eq(lhs.sci_mod_func, rhs.sci_mod_func, tol)) return false;
    if(!near_eq(lhs.sci_mod_prior, rhs.sci_mod_prior, tol)) return false;
    if(!near_eq(lhs.stat_mod_parm, rhs.stat_mod_parm, tol)) return false;
    if(!near_eq(lhs.stat_mod_sub_logl,rhs.stat_mod_sub_logl,tol)) return false;
    if(!near_eq(lhs.stat_mod_logl, rhs.stat_mod_logl, tol)) return false;
    if(lhs.stat_mod_support != rhs.stat_mod_support) return false;
    if(lhs.stat_mod_simulate != rhs.stat_mod_simulate) return false;
    if(!near_eq(lhs.stat_mod_func, rhs.stat_mod_func, tol)) return false;
    if(!near_eq(lhs.stat_mod_prior, rhs.stat_mod_prior, tol)) return false;
    return true;
  }

  bool operator==(const den_val& lhs, const den_val& rhs)
  {
    if (lhs.positive != rhs.positive) return false;
    if (lhs.log_den  != rhs.log_den)  return false;
    return true;
  }
  
  bool operator!=(const den_val& lhs, const den_val& rhs)
  {
    return !(lhs == rhs);
  }
  
  ostream& operator<<(ostream& os, const den_val& dv)
  {
    os << INTEGER(dv.positive) << ' ' << fmt('e',26,17,dv.log_den);
    return os;
  }
  
  istream& operator>>(istream& is, den_val& dv)
  {
    is >> dv.positive;
    is >> dv.log_den;
    return is;
  }

  ostream& operator<<(ostream& os, const sci_val& scv)
  {
    vecwrite(os,scv.sci_mod_parm);
    os << INTEGER(scv.sci_mod_support) << '\n';
    os << INTEGER(scv.sci_mod_simulate) << '\n';
    vecwrite(os,scv.sci_mod_func);
    os << scv.sci_mod_prior << '\n';
    vecwrite(os,scv.stat_mod_parm);
    os << scv.stat_mod_sub_logl << '\n';
    os << scv.stat_mod_logl << '\n';
    os << INTEGER(scv.stat_mod_support) << '\n';
    os << INTEGER(scv.stat_mod_simulate) << '\n';
    vecwrite(os,scv.stat_mod_func);
    os << scv.stat_mod_prior << '\n';
    return os;
  }

  istream& operator>>(std::istream& is, sci_val& scv)
  {
    LIB_WARN_HANDLER_PTR previous = set_lib_warn_handler(nowarn);
    vecread(is,scv.sci_mod_parm);
    if (!is.good()) { set_lib_warn_handler(previous); return is; }
    is >> scv.sci_mod_support;
    is >> scv.sci_mod_simulate;
    vecread(is,scv.sci_mod_func);
    is >> scv.sci_mod_prior;
    vecread(is,scv.stat_mod_parm);
    is >> scv.stat_mod_sub_logl;
    is >> scv.stat_mod_logl;
    is >> scv.stat_mod_support;
    is >> scv.stat_mod_simulate;
    vecread(is,scv.stat_mod_func);
    is >> scv.stat_mod_prior;
    set_lib_warn_handler(previous);
    return is;
  } 

  realmat pack_sci_val(const sci_val& sv)
  {
    INTEGER rows = 0;
    rows += 1 + sv.sci_mod_parm.size();    // sci_mod_parm
    rows += 1;                             // sci_mod_support
    rows += 1;                             // sci_mod_simulate
    rows += 1 + sv.sci_mod_func.size();    // sci_mod_func
    rows += 2;                             // sci_mod_prior
    rows += 1 + sv.stat_mod_parm.size();   // stat_mod_parm
    rows += 2;                             // stat_mod_sub_logl
    rows += 2;                             // stat_mod_logl
    rows += 1;                             // stat_mod_support
    rows += 1;                             // stat_mod_simulate
    rows += 1 + sv.sci_mod_func.size();    // stat_mod_func
    rows += 2;                             // stat_mod_prior
    
    realmat rm(rows,1);
    
    INTEGER k = 0;
    rm[++k] = sv.sci_mod_parm.size();                    // sci_mod_parm  
    for (INTEGER i=1; i<=sv.sci_mod_parm.size(); ++i) 
      rm[++k] = sv.sci_mod_parm[i];                      
    rm[++k] = sv.sci_mod_support;                        // sci_mod_support
    rm[++k] = sv.sci_mod_simulate;                       // sci_mod_simulate 
    rm[++k] = sv.sci_mod_func.size();                    // sci_mod_func  
    for (INTEGER i=1; i<=sv.sci_mod_func.size(); ++i) 
      rm[++k] = sv.sci_mod_func[i];
    rm[++k] = sv.sci_mod_prior.positive;                 // sci_mod_prior
    rm[++k] = sv.sci_mod_prior.log_den;
    rm[++k] = sv.stat_mod_parm.size();                   // stat_mod_parm  
    for (INTEGER i=1; i<=sv.stat_mod_parm.size(); ++i) 
      rm[++k] = sv.stat_mod_parm[i];                      
    rm[++k] = sv.stat_mod_sub_logl.positive;             // stat_mod_sub_logl
    rm[++k] = sv.stat_mod_sub_logl.log_den;
    rm[++k] = sv.stat_mod_logl.positive;                 // stat_mod_logl 
    rm[++k] = sv.stat_mod_logl.log_den;
    rm[++k] = sv.stat_mod_support;                       // stat_mod_support
    rm[++k] = sv.stat_mod_simulate;                      // stat_mod_simulate 
    rm[++k] = sv.stat_mod_func.size();                   // stat_mod_func  
    for (INTEGER i=1; i<=sv.stat_mod_func.size(); ++i) 
      rm[++k] = sv.stat_mod_func[i];
    rm[++k] = sv.stat_mod_prior.positive;                // stat_mod_prior
    rm[++k] = sv.stat_mod_prior.log_den;

    return rm;
  }

  sci_val unpack_sci_val(const realmat& rm)
  {
    sci_val sv;

    INTEGER k = 0;

    if (rm[++k] > 0.0) {                                 // sci_mod_parm  
      sv.sci_mod_parm.resize(INTEGER(rm[k]),1); 
      for (INTEGER i=1; i<=sv.sci_mod_parm.size(); ++i)
        sv.sci_mod_parm[i] = rm[++k];                    
    }
    sv.sci_mod_support = bool(rm[++k]);                  // sci_mod_support
    sv.sci_mod_simulate = bool(rm[++k]);                 // sci_mod_simulate 
    if (rm[++k] > 0.0) {                                 // sci_mod_func  
      sv.sci_mod_func.resize(INTEGER(rm[k]),1);
      for (INTEGER i=1; i<=sv.sci_mod_func.size(); ++i) 
        sv.sci_mod_func[i] = rm[++k];
    }
    sv.sci_mod_prior.positive = bool(rm[++k]);           // sci_mod_prior
    sv.sci_mod_prior.log_den = rm[++k];
    if (rm[++k] > 0.0) {                                 // stat_mod_parm  
      sv.stat_mod_parm.resize(INTEGER(rm[k]),1); 
      for (INTEGER i=1; i<=sv.stat_mod_parm.size(); ++i) 
         sv.stat_mod_parm[i] = rm[++k];                      
    }
    sv.stat_mod_sub_logl.positive = bool(rm[++k]);       // stat_mod_sub_logl
    sv.stat_mod_sub_logl.log_den = rm[++k];
    sv.stat_mod_logl.positive = bool(rm[++k]);           // stat_mod_logl 
    sv.stat_mod_logl.log_den = rm[++k];
    sv.stat_mod_support = bool(rm[++k]);                 // stat_mod_support
    sv.stat_mod_simulate = bool(rm[++k]);                // stat_mod_simulate 
    if (rm[++k] > 0.0) {                                 // stat_mod_func  
      sv.stat_mod_func.resize(INTEGER(rm[k]),1);
      for (INTEGER i=1; i<=sv.stat_mod_func.size(); ++i) 
        sv.stat_mod_func[i] = rm[++k];
    }
    sv.stat_mod_prior.positive = bool(rm[++k]);          //stat_mod_prior
    sv.stat_mod_prior.log_den = rm[++k];

    return sv;
  }

  ostream& operator<<(ostream& os, const stat_val& stv)
  {
    vecwrite(os,stv.stat_mod_parm);
    os << stv.stat_mod_logl << '\n';
    os << INTEGER(stv.stat_mod_support) << '\n';
    os << INTEGER(stv.stat_mod_simulate) << '\n';
    vecwrite(os,stv.stat_mod_func);
    os << stv.stat_mod_prior << '\n';
    os << stv.assess_prior << '\n';
    return os;
  }

  istream& operator>>(std::istream& is, stat_val& stv)
  {
    LIB_WARN_HANDLER_PTR previous = set_lib_warn_handler(nowarn);
    vecread(is,stv.stat_mod_parm);
    if (!is.good()) { set_lib_warn_handler(previous); return is; }
    is >> stv.stat_mod_logl;
    is >> stv.stat_mod_support;
    is >> stv.stat_mod_simulate;
    vecread(is,stv.stat_mod_func);
    is >> stv.stat_mod_prior;
    is >> stv.assess_prior;
    set_lib_warn_handler(previous);
    return is;
  } 

  bool operator==(const map_val& lhs, const map_val& rhs) 
  {
    if (lhs.sci_mod_support != rhs.sci_mod_support) return false;
    if (lhs.sci_mod_simulate != rhs.sci_mod_simulate) return false;
    if (lhs.sci_mod_func != rhs.sci_mod_func) return false;
    if (lhs.sci_mod_prior != rhs.sci_mod_prior) return false;
    if (lhs.stat_mod_parm != rhs.stat_mod_parm) return false;
    if (lhs.stat_mod_sub_logl != rhs.stat_mod_sub_logl) return false;
    if (lhs.stat_mod_logl != rhs.stat_mod_logl) return false;
    if (lhs.stat_mod_support != rhs.stat_mod_support) return false;
    if (lhs.stat_mod_simulate != rhs.stat_mod_simulate) return false;
    if (lhs.stat_mod_func != rhs.stat_mod_func) return false;
    if (lhs.stat_mod_prior != rhs.stat_mod_prior) return false;
    return true;
  }
  
  bool operator!=(const map_val& lhs, const map_val& rhs)
  {
    return !(lhs == rhs);
  }
  
  
  sci_val::sci_val()
  : sci_mod_parm(),
    sci_mod_support(),
    sci_mod_simulate(),
    sci_mod_func(),
    sci_mod_prior(),
    stat_mod_parm(),
    stat_mod_sub_logl(),
    stat_mod_logl(),
    stat_mod_support(),
    stat_mod_simulate(),
    stat_mod_func(),
    stat_mod_prior()
  { }
  
  sci_val::sci_val(const realmat& parm)
  : sci_mod_parm(parm),
    sci_mod_support(false),
    sci_mod_simulate(false),
    sci_mod_func(),
    sci_mod_prior(false,-REAL_MAX),
    stat_mod_parm(),
    stat_mod_sub_logl(false,-REAL_MAX),
    stat_mod_logl(false,-REAL_MAX),
    stat_mod_support(false),
    stat_mod_simulate(false),
    stat_mod_func(),
    stat_mod_prior(false,-REAL_MAX)
  { }

  sci_val::sci_val(const realmat& parm, const map_val& mv)
  : sci_mod_parm(parm),
    sci_mod_support(mv.sci_mod_support),
    sci_mod_simulate(mv.sci_mod_simulate),
    sci_mod_func(mv.sci_mod_func),
    sci_mod_prior(mv.sci_mod_prior),
    stat_mod_parm(mv.stat_mod_parm),
    stat_mod_sub_logl(mv.stat_mod_sub_logl),
    stat_mod_logl(mv.stat_mod_logl),
    stat_mod_support(mv.stat_mod_support),
    stat_mod_simulate(mv.stat_mod_simulate),
    stat_mod_func(mv.stat_mod_func),
    stat_mod_prior(mv.stat_mod_prior)
  { }

  string sci_val::annotated_sci_val() const
  {
    sci_val scv = *this;
    stringstream os;
    os << "sci_mod_parm (rows, cols, values) " << '\n'; 
    vecwrite(os,scv.sci_mod_parm);
    os << "sci_mod_support\n" << INTEGER(scv.sci_mod_support) << '\n';
    os << "sci_mod_simulate\n" << INTEGER(scv.sci_mod_simulate) << '\n';
    os << "sci_mod_func (rows, cols, values)\n"; 
    vecwrite(os,scv.sci_mod_func);
    os << "sci_mod_prior (den_val)\n" << scv.sci_mod_prior << '\n';
    os << "stat_mod_parm (rows, cols, values)\n"; 
    vecwrite(os,scv.stat_mod_parm);
    os << "stat_mod_sub_logl (den_val)\n" << scv.stat_mod_sub_logl << '\n';
    os << "stat_mod_logl (den_val)\n" << scv.stat_mod_logl << '\n';
    os << "stat_mod_support\n" << INTEGER(scv.stat_mod_support) << '\n';
    os << "stat_mod_simulate\n" << INTEGER(scv.stat_mod_simulate) << '\n';
    os << "stat_mod_func (rows, cols, values)\n"; 
    vecwrite(os,scv.stat_mod_func);
    os << "stat_mod_prior (den_val)\n" << scv.stat_mod_prior;
    return os.str();
  }

  map_val::map_val()
  : sci_mod_support(),
    sci_mod_simulate(),
    sci_mod_func(),
    sci_mod_prior(),
    stat_mod_parm(),
    stat_mod_sub_logl(),
    stat_mod_logl(),
    stat_mod_support(),
    stat_mod_simulate(),
    stat_mod_func(),
    stat_mod_prior()
  { }
  
  map_val::map_val(const sci_val& sv)
  : sci_mod_support(sv.sci_mod_support),
    sci_mod_simulate(sv.sci_mod_simulate),
    sci_mod_func(sv.sci_mod_func),
    sci_mod_prior(sv.sci_mod_prior),
    stat_mod_parm(sv.stat_mod_parm),
    stat_mod_sub_logl(sv.stat_mod_sub_logl),
    stat_mod_logl(sv.stat_mod_logl),
    stat_mod_support(sv.stat_mod_support),
    stat_mod_simulate(sv.stat_mod_simulate),
    stat_mod_func(sv.stat_mod_func),
    stat_mod_prior(sv.stat_mod_prior)
  { }

  INTEGER stat_mod_base::num_obs()
  {
    string msg = "Error, stat_mod_base, num_polish_iter postive but\n";
    msg += "inherited method num_obs() not coded";
    error(msg);
    return 0;
  }

  den_val stat_mod_base::loglikelihood(realmat& dlogl)
  {
    string msg = "Error, stat_mod_base, num_polish_iter postive but\n";
    msg += "inherited method den_val logl(realmat& dlogl) not coded";
    error(msg);
    return den_val();
  }

  REAL stat_mod_base::penalty(realmat& dpenalty)
  { 
    dpenalty.resize(1,len_parm(),0.0);
    string msg = "Error, stat_mod_base, num_polish_iter postive but\n";
    msg += "inherited method penalty(realmat& dpenalty) not coded";
    error(msg);
    return 0.0;
  }
    
  INTEGER stat_mod_base::mle(realmat& stat_parm_mle, realmat& stat_parm_V)
  { 
    stat_parm_mle.resize(len_parm(),1);
    stat_parm_V.resize(len_parm(),len_parm());
    string msg = "Error, stat_mod_base, analytic_mle == 1 but inherited\n";
    msg += "method mle(realmat& stat_parm, realmat& stat_parm_V) not coded\n"; 
    error(msg);
    return 0;  // Returns number of d.f. in stat_parm_V
  }
    
  bool stat_mod_eqns::get_f(const realmat& parm, realmat& f)
  {
    f.resize(1,1);
    stat_mod.set_parm(parm);
    if (!stat_mod.support(parm)) return false;
    den_val dv = stat_mod.loglikelihood();
    realmat dpenalty;
    f[1] = -dv.log_den/stat_mod.num_obs() + stat_mod.penalty(dpenalty);
    return dv.positive;
  }
  
  bool stat_mod_eqns::get_F(const realmat& parm, realmat& f, realmat& F)
  {
    f.resize(1,1);
    stat_mod.set_parm(parm);
    if (!stat_mod.support(parm)) return false;
    realmat dlogl;
    realmat dpenalty;
    den_val dv = stat_mod.loglikelihood(dlogl);
    f[1] = -dv.log_den/stat_mod.num_obs() + stat_mod.penalty(dpenalty);
    F = -dlogl/stat_mod.num_obs() + dpenalty;
    return dv.positive;
  }

  realmat pack_stat_val(const stat_val& sv)
  {
    INTEGER rows = 0;
    rows += 1 + sv.stat_mod_parm.size();   // stat_mod_parm
    rows += 2;                             // stat_mod_logl
    rows += 1;                             // stat_mod_support
    rows += 1;                             // stat_mod_simulate
    rows += 1 + sv.stat_mod_func.size();   // stat_mod_func
    rows += 2;                             // stat_mod_prior
    rows += 2;                             // assess_prior
    
    realmat rm(rows,1);
    
    INTEGER k = 0;
    rm[++k] = sv.stat_mod_parm.size();                   // stat_mod_parm  
    for (INTEGER i=1; i<=sv.stat_mod_parm.size(); ++i) 
      rm[++k] = sv.stat_mod_parm[i];                      
    rm[++k] = sv.stat_mod_logl.positive;                 // stat_mod_logl 
    rm[++k] = sv.stat_mod_logl.log_den;
    rm[++k] = sv.stat_mod_support;                       // stat_mod_support
    rm[++k] = sv.stat_mod_simulate;                      // stat_mod_simulate 
    rm[++k] = sv.stat_mod_func.size();                   // stat_mod_func  
    for (INTEGER i=1; i<=sv.stat_mod_func.size(); ++i) 
      rm[++k] = sv.stat_mod_func[i];
    rm[++k] = sv.stat_mod_prior.positive;                // stat_mod_prior
    rm[++k] = sv.stat_mod_prior.log_den;
    rm[++k] = sv.assess_prior.positive;                  // assess_prior
    rm[++k] = sv.assess_prior.log_den;

    return rm;
  }

  stat_val unpack_stat_val(const realmat& rm)
  {
    stat_val sv;

    INTEGER k = 0;

    if (rm[++k] > 0.0) {                                 // stat_mod_parm  
      sv.stat_mod_parm.resize(INTEGER(rm[k]),1); 
      for (INTEGER i=1; i<=sv.stat_mod_parm.size(); ++i)
        sv.stat_mod_parm[i] = rm[++k];                    
    }
    sv.stat_mod_logl.positive = bool(rm[++k]);           // stat_mod_logl 
    sv.stat_mod_logl.log_den = rm[++k];
    sv.stat_mod_support = bool(rm[++k]);                 // stat_mod_support
    sv.stat_mod_simulate = bool(rm[++k]);                // stat_mod_simulate 
    if (rm[++k] > 0.0) {                                 // stat_mod_func  
      sv.stat_mod_func.resize(INTEGER(rm[k]),1);
      for (INTEGER i=1; i<=sv.stat_mod_func.size(); ++i) 
        sv.stat_mod_func[i] = rm[++k];
    }
    sv.stat_mod_prior.positive = bool(rm[++k]);          // stat_mod_prior
    sv.stat_mod_prior.log_den = rm[++k];
    sv.assess_prior.positive = bool(rm[++k]);            //assess_prior
    sv.assess_prior.log_den = rm[++k];

    return sv;
  }
}

