## ------------------------------------------------------------------
## Parameter specification
## ------------------------------------------------------------------
## Global settings for the whole script. Of these, only `h_max` and `m` are
## referenced again in this file; `lambda`, `f` and `r` are presumably read by
## the scripts sourced below -- TODO confirm.
## NOTE(review): `lambda` looks like the Nelson-Siegel decay parameter -- confirm.
lambda <- 0.069
h_max  <- 12               ## maximum forecast horizon (steps ahead)
m      <- 17               ## number of maturities (yield columns kept as 1:m)
f      <- 3                ## number of DNS factors
r      <- 2                ## number of macro factors
## Silence all warnings for the duration of the run; the last line of the
## script sets `warn` back to 0.
options(warn = -1)

## -----------------------------------------------------------------------------
## Load external functions, load data, and models
## -----------------------------------------------------------------------------
## NOTE(review): objects used below but not defined in this file (e.g.
## `yield_data`, `Xmf`, `dns_model`, `oos_end`, `oos_stop`, `bc_3m_release`,
## `bc_set`, `initvalues`, `SSMpredict`, `tilt`, `DNScondfore_dist`) are
## presumably created by these three scripts -- confirm. The scripts are
## sourced after the parameters above have been set.
source("01-helper.R")
source("02-load_data.R")
source("03-setup_model.R")

## -----------------------------------------------------------------------------
## Start clusters
## -----------------------------------------------------------------------------
## Two workers are registered as the `%dopar%` backend; Loop 1 below fits two
## models per iteration in parallel.
library(foreach)
library(doParallel) 
cl <- makeCluster(2)
registerDoParallel(cl)

## -----------------------------------------------------------------------------
## Loop 1: Estimate parameters of dns models
## -----------------------------------------------------------------------------
## Expanding-window estimation. For every forecast origin `j` in
## `oos_end:oos_stop` the two models are re-estimated on observations 1..j,
## forecasts for up to `h_max` steps ahead are produced, and forecasts, errors,
## forecast covariances and fit statistics are stored in `dns_model[[kk]]` at
## position `count` (1 for the first origin).
em <- list()
pb <- progress_bar$new(total = length((oos_end):(oos_stop)),
                       format = "Estimation [:bar] :percent eta: :eta",
                       clear = FALSE, width = 60)

for (j in (oos_end):(oos_stop)) {
  pb$tick()
  count <- j - oos_end + 1
  idx <- seq_len(j)

  ## Estimation sample (rows 1..j) and the realised yields to be forecast
  ## (rows j+1 .. j+h_max, truncated at the end of the data).
  YY <- yield_data[idx]
  XXmf <- Xmf[idx, ]
  yield_eff <- yield_data[(j + 1):min(j + h_max, nrow(yield_data)), ]
  irow <- seq_len(nrow(yield_eff))

  if (count == 1) {
    ## First sample -- use `initvalues`
    init <- list(m1 = initvalues(idx, typemodel = "yields only"),
                 m2 = initvalues(idx, typemodel = "macro augmented", X = Xmf))
  } else {
    ## Warm start: initial values come from the previous iteration's fits
    init <- MARSSpar2init(em)
  }

  ## Model 1 sees yields only; model 2 sees yields plus the macro block.
  ## MARSS takes the data with series in rows, hence the transposes.
  datas <- list(t(YY),
                t(cbind(YY, XXmf)))

  models <- list(dns_model[[1]]$model,
                 dns_model[[2]]$model)

  ## Fit both models in parallel, one per worker
  em <- foreach(i = 1:2, .packages = "MARSS") %dopar% MARSS(
    y = datas[[i]],
    model = models[[i]],
    init = init[[i]],
    control = list(maxit = 2000,
                   allow.degen = TRUE,
                   conv.test.slope.tol = 1000),
    silent = TRUE
  )

  kf <- lapply(em, FUN = function(u) MARSSkfas(u, return.kfas.model = TRUE))

  ## Random-walk benchmark. It does not depend on the model, so it is computed
  ## once per origin rather than once per model. The horizon follows `h_max`
  ## (previously hard-coded to 12) so it always covers `irow`. The result is
  ## stored under its own name instead of shadowing the function `rwf`.
  rw_fcast <- apply(yield_data[idx], 2,
                    function(u) rwf(u, h = h_max)$mean)[irow, ]

  ## Update
  for (kk in seq_along(dns_model)) {
    dns_model[[kk]]$est_model[[count]] <- em[[kk]]
    out <- SSMpredict(kf[[kk]]$kfas.model)

    dns_model[[kk]]$actual_yield[count, irow, ] <- yield_eff

    ## Only the first `m` columns (the yields) of the forecast are kept
    dns_model[[kk]]$fcast[count, irow, ] <- out$Yhat[irow, 1:m]
    dns_model[[kk]]$error[count, irow, ] <- yield_eff - out$Yhat[irow, 1:m]

    dns_model[[kk]]$fcast_rw[count, irow, ] <- rw_fcast
    dns_model[[kk]]$error_rw[count, irow, ] <- yield_eff - rw_fcast

    ## Forecast covariance of the yields, one m x m slice per horizon
    for (hh in irow) {
      dns_model[[kk]]$Sigma[count, hh, , ] <- out$Vhat[1:m, 1:m, hh]
    }

    dns_model[[kk]]$AIC[count] <- em[[kk]]$AIC
    dns_model[[kk]]$loglik[count] <- em[[kk]]$logLik
    ## To get the dates back as.yearmon(dns_model[[kk]]$oos_dates[count,])
    dns_model[[kk]]$oos_dates[count, irow] <- index(yield_eff)
  }
}

## -----------------------------------------------------------------------------
## Loop 2: Tilt out-of-sample forecast
## -----------------------------------------------------------------------------
## Tilt the unconditional forecasts at horizons 3, 6, 9 and 12 towards the
## matching column of `bc_3m_release` (column hh / 3), and store the tilted
## forecast together with its error against the realised yields.
## NOTE(review): `off` is presumably the row of the first forecast target
## (Loop 3 derives the same quantity from the "Jan 2000" index) -- confirm.
off <- 181
for (kk in 1:2) {
  for (hh in c(3, 6, 9, 12)) {
    ## Number of origins with a realised value at this horizon. `seq_len()`
    ## keeps the loop from running over c(1, 0) when there are none.
    n_origin <- length(na.omit(dns_model[[kk]]$error_rw[, hh, 1]))
    for (jj in seq_len(n_origin)) {
      bcidx <- jj + off - 1
      if (bcidx <= nrow(bc_3m_release)) {
        mu <- bc_3m_release[bcidx, hh / 3]
        dns_model[[kk]]$fcast_tilt[jj, hh, ] <-
          tilt(dns_model[[kk]]$fcast[jj, hh, ],
               dns_model[[kk]]$Sigma[jj, hh, , ],
               mu)
        ## Realised yields `hh` steps after origin `jj`
        dns_model[[kk]]$error_tilt[jj, hh, ] <-
          yield_data[off + jj + hh - 2, ] - dns_model[[kk]]$fcast_tilt[jj, hh, ]
      } else {
        ## No target available for this origin
        dns_model[[kk]]$fcast_tilt[jj, hh, ] <- NA
        dns_model[[kk]]$error_tilt[jj, hh, ] <- NA
      }
    }
  }
}


## Second tilting pass: same as the single-target pass, but the forecast is
## tilted towards two targets at once, taken from `bc_3m_release` and
## `bc_6m_release`. Results go to `fcast_tilt_2` / `error_tilt_2`.
bc_6m_release <- bc_set[[2]][[1]]

## Result arrays have the same shape as `fcast_tilt` and start at zero;
## horizons other than 3, 6, 9 and 12 are never written and stay 0.
dns_model[[1]]$fcast_tilt_2 <- array(0, dim(dns_model[[1]]$fcast_tilt))
dns_model[[2]]$fcast_tilt_2 <- array(0, dim(dns_model[[2]]$fcast_tilt))

dns_model[[1]]$error_tilt_2 <- array(0, dim(dns_model[[1]]$fcast_tilt))
dns_model[[2]]$error_tilt_2 <- array(0, dim(dns_model[[2]]$fcast_tilt))

off <- 181
## Both target tables are indexed with the same row, so the bound must hold
## for both of them (previously only `bc_3m_release` was checked).
n_target <- min(nrow(bc_3m_release), nrow(bc_6m_release))
for (kk in 1:2) {
  for (hh in c(3, 6, 9, 12)) {
    ## Number of origins with a realised value at this horizon. `seq_len()`
    ## keeps the loop from running over c(1, 0) when there are none.
    n_origin <- length(na.omit(dns_model[[kk]]$error_rw[, hh, 1]))
    for (jj in seq_len(n_origin)) {
      bcidx <- jj + off - 1
      if (bcidx <= n_target) {
        mu <- c(bc_3m_release[bcidx, hh / 3], bc_6m_release[bcidx, hh / 3])
        dns_model[[kk]]$fcast_tilt_2[jj, hh, ] <-
          tilt(dns_model[[kk]]$fcast[jj, hh, ],
               dns_model[[kk]]$Sigma[jj, hh, , ],
               mu)
        ## Realised yields `hh` steps after origin `jj`
        dns_model[[kk]]$error_tilt_2[jj, hh, ] <-
          yield_data[off + jj + hh - 2, ] - dns_model[[kk]]$fcast_tilt_2[jj, hh, ]
      } else {
        ## No target available for this origin
        dns_model[[kk]]$fcast_tilt_2[jj, hh, ] <- NA
        dns_model[[kk]]$error_tilt_2[jj, hh, ] <- NA
      }
    }
  }
}

# off <- 180
# for (kk in 1:2) {
#   for (hh in c(3, 6, 9, 12)) {
#     tmp <- na.omit(dns_model[[kk]]$error_rw[, hh, 1])
#     for (jj in 1:(length(tmp))) {
#       bcidx <- jj + off - 1
#       if (bcidx <= nrow(bc_3m_release)) {
#         mu <- bc_3m_target[[hh / 3]][off + jj,]
#         dns_model[[kk]]$fcast_tilt[jj, hh, ]  <-
#           tilt(dns_model[[kk]]$fcast[jj, hh,], dns_model[[kk]]$Sigma[jj, hh, ,], mu)
#         dns_model[[kk]]$error_tilt[jj, hh, ]  <-
#           yield_data[off + jj, ] - dns_model[[kk]]$fcast_tilt[jj, hh, ]
#       } else {
#         dns_model[[kk]]$fcast_tilt[jj, hh, ]  <- NA
#         dns_model[[kk]]$error_tilt[jj, hh, ]  <- NA
#       }
#     }
#   }
# }
# 
# 
# bc_6m_target <- bc_set[["m6"]][[2]]
# dns_model[[1]]$fcast_tilt_2 <- array(0, dim(dns_model[[1]]$fcast_tilt))
# dns_model[[2]]$fcast_tilt_2 <- array(0, dim(dns_model[[2]]$fcast_tilt))
# 
# dns_model[[1]]$error_tilt_2 <- array(0, dim(dns_model[[1]]$fcast_tilt))
# dns_model[[2]]$error_tilt_2 <- array(0, dim(dns_model[[2]]$fcast_tilt))
# 
# 
# 
# 
# off <- 180
# for (kk in 1:2) {
#   for (hh in c(3, 6, 9, 12)) {
#     tmp <- na.omit(dns_model[[kk]]$error_rw[, hh, 1])
#     for (jj in 1:(length(tmp))) {
#       bcidx <- jj + off - 1
#       if (bcidx <= nrow(bc_3m_release)) {
#         mu <- c(bc_3m_target[[hh / 3]][off + jj], bc_6m_target[[hh / 3]][off + jj])
#         dns_model[[kk]]$fcast_tilt_2[jj, hh, ]  <-
#           tilt(dns_model[[kk]]$fcast[jj, hh,], dns_model[[kk]]$Sigma[jj, hh, ,], mu)
#         dns_model[[kk]]$error_tilt_2[jj, hh, ]  <-
#           yield_data[off + jj, ] - dns_model[[kk]]$fcast_tilt_2[jj, hh, ]
#       } else {
#         dns_model[[kk]]$fcast_tilt_2[jj, hh, ]  <- NA
#         dns_model[[kk]]$error_tilt_2[jj, hh, ]  <- NA
#       }
#     }
#   }
# }






## -----------------------------------------------------------------------------
## Loop 3: Produce conditional forecasts
## -----------------------------------------------------------------------------

## For every forecast origin (all but the last one) and each model, rebuild
## the KFAS model from the stored MARSS fit and draw conditional forecasts in
## which the first series is pinned to the `bc_3m_release` values at horizons
## 3, 6, 9 and 12. Results are collected in `out[[kk]][[origin]]$Yhat`.
off <- which(index(bc_3m_release) == "Jan 2000")

## NOTE(review): this progress bar is created but never ticked, because the
## work below runs inside `foreach`; kept so `pb` is still defined afterwards.
pb <- progress_bar$new(total = (length((oos_end):(oos_stop - 1))),
                       format = "Estimation [:bar] :percent eta: :eta",
                       clear = FALSE, width = 60)

## Shut down the cluster started for Loop 1 before replacing it; overwriting
## `cl` without stopping it would leave its worker processes running.
stopCluster(cl)
cl <- makeCluster(6)
registerDoParallel(cl)

out <- list()
tryCatch({
  for (kk in 1:2) {
    out[[kk]] <- foreach(j = (oos_end):(oos_stop - 1),
                         .packages = c("MARSS", "KFAS")) %dopar% {
      ## Position of origin `j` in `est_model`.
      ## NOTE(review): assumes `off - 2 == oos_end - 1`, i.e. the same
      ## numbering as `count` in Loop 1 -- confirm.
      joff <- j - (off - 2)
      kf <- MARSSkfas(dns_model[[kk]]$est_model[[joff]],
                      return.kfas.model = TRUE)

      ## Conditioning matrix: NA everywhere except the first series at
      ## horizons 3, 6, 9 and 12
      Ycond <- array(NA, c(h_max, attr(kf$kfas.model, "p")))
      mu <- bc_3m_release[j + 1, ]
      Ycond[c(3, 6, 9, 12), 1] <- as.numeric(mu)

      tmp <- DNScondfore_dist(kf$kfas.model, Ycond = Ycond, nsim = 100)
      list(Yhat = tmp$Yhat)
    }
  }
}, finally = {
  ## Always release the workers, even if a fit fails, and fall back to the
  ## sequential backend so a later `%dopar%` does not hit a dead cluster.
  stopCluster(cl)
  registerDoSEQ()
})

## Copy the conditional forecasts into `dns_model` and compute their errors
## against the realised yields, then write everything to disk.
for (kk in seq_along(out)) {
  for (j in seq_along(out[[kk]])) {
    ## Horizon offsets 1..h_max; `j` is the position of the origin in
    ## `out[[kk]]`, not a row of `yield_data`.
    irow <- ((j + 1):min(j + h_max, nrow(yield_data))) - j
    Yhat <- out[[kk]][[j]][[1]]
    dns_model[[kk]]$fcast_cond[j, , ] <- Yhat[, 1:m]
    ## NOTE(review): 180 is presumably `oos_end` (the row of the first
    ## forecast origin) -- confirm before changing the sample split.
    dns_model[[kk]]$error_cond[j, , ] <-
      unclass(yield_data[180 + j - 1 + irow, ]) - Yhat[, 1:m]
  }
}

## Make sure the output directory exists, so the run does not fail at the
## very last step after all the estimation work has been done.
dir.create("results", showWarnings = FALSE)
save(dns_model, file = "results/dns_kalman.Rda")
options(warn = 0)