# Start from an empty global environment so nothing from an earlier session
# leaks into this run.
rm(list = ls())
library(Bolstad)
library(AER)
library(mvtnorm)  # rmvnorm() is used below to draw the instruments
# NOTE(review): presumably defines TStest_hetero(), which is called in the
# simulation loop -- confirm against TSIV_functions.R.
source("TSIV_functions.R")

# Timer reference; the elapsed time is computed on the last line of the script.
ptm <- proc.time()

# ---- Simulation design ----
n1 <- 5000           # number of rows generated for sample 1 (z1, x1, w1, y1)
n2 <- 1000           # number of rows generated for sample 2 (z2, x2, w2)
S <- 10              # number of simulated data sets per design cell
alpha <- 0.05        # nominal level (the loop below hard-codes 0.05 / 0.975)
beta <- c(-2, 0, 2)  # values of the coefficient on w1 used to generate y1
K <- c(1, 5, 10)     # numbers of instruments (columns of z1 / z2)
CONC <- c(1, 4, 16)  # multiplied by k to give lambda, which sets the size of Pi
p <- 1               # number of columns in x1 / x2 (p == 1: a column of ones)
eta <- numeric(p)    # coefficients on x in the equation generating w1 / w2
nu <- numeric(p)     # coefficients on x1 in the equation generating y1
# ---- Helpers: turn a vector of rejection indicators into a confidence set ----

# Locate the maximal runs of consecutive non-rejected grid points.
#   rejected: logical (or 0/1) vector, one entry per grid point, in grid order.
# Returns a list with integer vectors `first` and `last`: the grid indices at
# which each accepted run starts and ends (both empty if every point is
# rejected). Missing indicators are an error, since they cannot be classified.
accepted_runs <- function(rejected) {
  stopifnot(length(rejected) > 0, !anyNA(rejected))
  runs <- rle(!as.logical(rejected))
  last <- cumsum(runs$lengths)
  first <- last - runs$lengths + 1L
  list(first = first[runs$values], last = last[runs$values])
}

# Text description of the confidence set: isolated points are written as the
# value itself, longer runs as "[ lo , hi ]", pieces are joined by " U ", and
# an empty set is reported as "null set".
describe_conf_set <- function(rejected, grid) {
  runs <- accepted_runs(rejected)
  if (length(runs$first) == 0) {
    return("null set")
  }
  # Format one value at a time: format() on a vector would pad all entries to
  # a common width.
  fmt <- function(x) format(x, digits = 2, nsmall = 2)
  pieces <- vapply(seq_along(runs$first), function(i) {
    lo <- fmt(grid[runs$first[i]])
    hi <- fmt(grid[runs$last[i]])
    if (runs$first[i] == runs$last[i]) lo else paste("[", lo, ",", hi, "]")
  }, character(1))
  paste(pieces, collapse = " U ")
}

# Total length of the confidence set: the sum of (hi - lo) over all accepted
# runs. Isolated points contribute 0; an empty set has length 0.
conf_set_length <- function(rejected, grid) {
  runs <- accepted_runs(rejected)
  sum(grid[runs$last] - grid[runs$first])
}

# ---- Grid of candidate values passed to the tests ----
# Spacing is 0.01 on [-0.99, 0.99], 0.1 out to +/-9.9, 1 out to +/-99 and 10
# out to +/-1000. The grid does not depend on the design, so build it once.
grid <- c(
  seq(-1000, -100, 10),
  seq(-99, -10, 1),
  seq(-9.9, -1, 0.1),
  seq(-0.99, 0.99, 0.01),
  seq(1, 9.9, 0.1),
  seq(10, 99, 1),
  seq(100, 1000, 10)
)
points <- length(grid)

for (b in seq_along(beta)) {
  # Grid point closest to the coefficient used to generate y1. Coverage must be
  # read at this point (the previous code read it at grid[1] = -1000).
  # NOTE(review): assumes the last argument of TStest_hetero() is the
  # hypothesised value of that coefficient -- confirm in TSIV_functions.R.
  true_idx <- which.min(abs(grid - beta[b]))

  # Result tables: rows index K, columns index CONC.
  # Share of replications in which the test does not reject at true_idx.
  ci_true_AR_hetero <- matrix(NA, nrow = length(K), ncol = length(CONC))
  ci_true_CLR_hetero <- matrix(NA, nrow = length(K), ncol = length(CONC))
  ci_true_K_hetero <- matrix(NA, nrow = length(K), ncol = length(CONC))

  # Number of replications with a bounded confidence set.
  b_AR <- matrix(NA, nrow = length(K), ncol = length(CONC))
  b_CLR <- matrix(NA, nrow = length(K), ncol = length(CONC))
  b_K <- matrix(NA, nrow = length(K), ncol = length(CONC))

  # Mean length of the bounded confidence sets.
  ci_length_AR <- matrix(NA, nrow = length(K), ncol = length(CONC))
  ci_length_CLR <- matrix(NA, nrow = length(K), ncol = length(CONC))
  ci_length_K <- matrix(NA, nrow = length(K), ncol = length(CONC))

  # Median of the non-zero lengths (stored but not printed below).
  ci_length_AR_median <- matrix(NA, nrow = length(K), ncol = length(CONC))
  ci_length_CLR_median <- matrix(NA, nrow = length(K), ncol = length(CONC))
  ci_length_K_median <- matrix(NA, nrow = length(K), ncol = length(CONC))

  for (kind in seq_along(K)) {
    k <- K[kind]
    # The seed is reset once per number of instruments, before the CONC loop.
    set.seed(123)
    for (concind in seq_along(CONC)) {
      conc <- CONC[concind]
      print(c(k, conc))
      lambda <- conc * k
      # Equal first-stage coefficients on all k instruments.
      Pi <- rep(sqrt(lambda / (n2 * k)), k)

      # Per-replication storage.
      AR_cset <- rep("", S)
      CLR_cset <- rep("", S)
      K_cset <- rep("", S)

      AR_rej_true <- rep(NA, S)
      CLR_rej_true <- rep(NA, S)
      K_rej_true <- rep(NA, S)

      bounded_AR <- rep(NA, S)
      bounded_CLR <- rep(NA, S)
      bounded_K <- rep(NA, S)

      cilength_AR <- rep(NA, S)
      cilength_CLR <- rep(NA, S)
      cilength_K <- rep(NA, S)

      for (s in seq_len(S)) {
        ## Generate the two samples (the order of the draws is kept fixed so
        ## the random-number stream is unchanged).
        z1 <- rmvnorm(n1, sigma = diag(k))
        z2 <- rmvnorm(n2, sigma = diag(k))
        if (p == 1) {
          x1 <- matrix(1, nrow = n1, ncol = p)
          x2 <- matrix(1, nrow = n2, ncol = p)
        } else {
          # Builds exactly two columns, so this branch only conforms with
          # eta / nu when p == 2.
          x1 <- cbind(rep(1, n1), rnorm(n1))
          x2 <- cbind(rep(1, n2), rnorm(n2))
        }
        v1 <- rnorm(n1)
        v2 <- rnorm(n2)
        w1 <- z1 %*% Pi + x1 %*% eta + v1
        w2 <- z2 %*% Pi + x2 %*% eta + v2
        # Error whose scale depends on the first-stage index, rescaled to unit
        # standard deviation.
        e <- rnorm(n1) * exp(-(z1 %*% Pi + x1 %*% eta)^2)
        e <- e / sd(e)
        # u1 shares the component v1 with w1.
        u1 <- v1 * 0.1 + e * sqrt(0.99)
        y1 <- beta[b] * w1 + x1 %*% nu + u1

        ## 95% Confidence Interval - Classical t-test
        ## (`ci` is not used by anything further down in this script.)
        first_stage <- lm(w2 ~ z2 + x2 - 1)
        Pihat <- first_stage$coefficients
        v2sigmasq <- sum(residuals(first_stage)^2) / (n2 - k - p)
        # Fitted values of w in sample 1 from the sample-2 first stage; these
        # are also handed to TStest_hetero() below.
        w1hat <- cbind(z1, x1) %*% Pihat
        second_stage <- lm(y1 ~ w1hat + x1 - 1)
        betahat <- second_stage$coefficients[1]
        # This regression has n1 observations and 1 + p coefficients, so the
        # residual variance uses n1 - 1 - p (it was n2 - 1 - p before).
        e1sigmasq <- sum(residuals(second_stage)^2) / (n1 - 1 - p)
        vbetahat <- (summary(second_stage)$coefficients[1, 2])^2 *
          (1 + n1 / n2 * betahat^2 * v2sigmasq / e1sigmasq)
        ci <- c(
          betahat - qt(0.975, n1 - k - 1) * sqrt(vbetahat),
          betahat + qt(0.975, n1 - k - 1) * sqrt(vbetahat)
        )

        ## Proposed two-sample AR, K and CLR tests at every grid point.
        ARstat <- rep(NA_real_, points)
        Kstat <- rep(NA_real_, points)
        CLRstat <- rep(NA_real_, points)
        CLRpvalue <- rep(NA_real_, points)
        for (g in seq_len(points)) {
          results <- TStest_hetero(y1, w1hat, w2, z1, z2, x1, x2, grid[g])
          ARstat[g] <- results$ARstat
          Kstat[g] <- results$Kstat
          CLRstat[g] <- results$CLRstat
          CLRpvalue[g] <- results$CLRpvalue
        }

        ## Rejection indicators at level alpha.
        AR_r <- ARstat > qchisq(1 - alpha, k)
        K_r <- Kstat > qchisq(1 - alpha, 1)
        CLR_r <- CLRpvalue < alpha

        ## Confidence sets: the grid points that are not rejected.
        AR_cset[s] <- describe_conf_set(AR_r, grid)
        CLR_cset[s] <- describe_conf_set(CLR_r, grid)
        K_cset[s] <- describe_conf_set(K_r, grid)

        ## Rejection at the grid point closest to the true coefficient.
        AR_rej_true[s] <- AR_r[true_idx]
        CLR_rej_true[s] <- CLR_r[true_idx]
        K_rej_true[s] <- K_r[true_idx]

        ## A set is treated as bounded when both ends of the grid are rejected.
        bounded_AR[s] <- AR_r[1] && AR_r[points]
        bounded_CLR[s] <- CLR_r[1] && CLR_r[points]
        bounded_K[s] <- K_r[1] && K_r[points]

        ## Length of the set, summed over all of its disjoint pieces; sets
        ## that are not bounded are recorded as 0.
        cilength_AR[s] <- if (bounded_AR[s]) conf_set_length(AR_r, grid) else 0
        cilength_CLR[s] <- if (bounded_CLR[s]) conf_set_length(CLR_r, grid) else 0
        cilength_K[s] <- if (bounded_K[s]) conf_set_length(K_r, grid) else 0
      }

      ## Summaries for this (k, conc) cell.
      ci_true_AR_hetero[kind, concind] <- 1 - mean(AR_rej_true)
      ci_true_CLR_hetero[kind, concind] <- 1 - mean(CLR_rej_true)
      ci_true_K_hetero[kind, concind] <- 1 - mean(K_rej_true)

      b_AR[kind, concind] <- sum(bounded_AR)
      b_CLR[kind, concind] <- sum(bounded_CLR)
      b_K[kind, concind] <- sum(bounded_K)

      # Mean over bounded sets only; this is NaN when no set was bounded.
      ci_length_AR[kind, concind] <- sum(cilength_AR) / b_AR[kind, concind]
      ci_length_CLR[kind, concind] <- sum(cilength_CLR) / b_CLR[kind, concind]
      ci_length_K[kind, concind] <- sum(cilength_K) / b_K[kind, concind]

      ci_length_AR_median[kind, concind] <- median(cilength_AR[cilength_AR != 0])
      ci_length_CLR_median[kind, concind] <- median(cilength_CLR[cilength_CLR != 0])
      ci_length_K_median[kind, concind] <- median(cilength_K[cilength_K != 0])
    }
  }

  print(ci_true_AR_hetero, digits = 3)
  print(ci_true_CLR_hetero, digits = 3)
  print(ci_true_K_hetero, digits = 3)

  print(b_AR, digits = 3)
  print(b_CLR, digits = 3)
  print(b_K, digits = 3)

  print(ci_length_AR, digits = 3)
  print(ci_length_CLR, digits = 3)
  print(ci_length_K, digits = 3)
}
proc.time() - ptm