# The simple and weighted RRP ATT estimators
#
# DM : the distance matrix, of class "dist" or a matrix (proximity is taken as 1 - DM)
# treated : group indicator, 1 = treated, 0 = control
# ocome : observed outcome for treated and controls
# pred.ocome : (optional) predicted outcome used to adjust for the difference in covariates
# data : (optional) a data matrix used to calculate the imbalance measure \delta in Gu and Rosenbaum (1993)
# Note: ocome and pred.ocome must have the same length

rrp.att <- function(DM, treated, ocome, pred.ocome, data) {
  # Computes the simple and weighted RRP-ATT estimators (optionally adjusted
  # by pred.ocome) and, when `data` is supplied, the per-column imbalance
  # measure delta.
  #
  # Returns a list in the same positional order as before; the elements are
  # also named:
  #   [[1]] att        simple RRP-ATT
  #   [[2]] adj.att    adjusted simple RRP-ATT (= att if pred.ocome is missing)
  #   [[3]] w.att      weighted RRP-ATT
  #   [[4]] w.adj.att  adjusted weighted RRP-ATT (= w.att if pred.ocome missing)
  #   [[5]] delta      per-column imbalance (NA if data is missing)
  treat <- which(treated == 1)
  ctrl <- which(treated == 0)
  nt <- length(treat)

  # Work on the proximity matrix, restricted to treated (rows) x controls (cols).
  P <- 1 - as.matrix(DM)
  P.tc <- P[treat, ctrl, drop = FALSE]

  # Total proximity of each treated unit to the controls. A treated unit with
  # lenc == 0 (or NA) has no usable control and gets an NA counterfactual.
  lenc <- rowSums(P.tc)
  has.ctrl <- !is.na(lenc) & lenc != 0

  # f_{ij} weights of the simple RRP-ATT estimator: proximities normalised to
  # sum to one within each treated row (rows without controls are left
  # unscaled; they are never used).
  f.wht <- P.tc / ifelse(has.ctrl, lenc, 1)

  # Counterfactual of each treated unit: f-weighted mean of y over controls.
  cfactual <- function(y) {
    y.ctrl <- y[ctrl]
    vapply(seq_len(nt), function(i) {
      if (has.ctrl[i]) {
        weighted.mean(y.ctrl, w = f.wht[i, ], na.rm = TRUE)
      } else {
        NA_real_
      }
    }, numeric(1))
  }

  # Weights of the weighted estimator: each treated unit's maximal proximity
  # to any control, normalised to sum to one.
  p.max <- vapply(seq_len(nt), function(i) max(P.tc[i, ]), numeric(1))
  wht <- p.max / sum(p.max)

  # Simple and weighted RRP-ATT estimators.
  te.diff <- ocome[treat] - cfactual(ocome)
  att <- mean(te.diff, na.rm = TRUE)
  w.att <- weighted.mean(te.diff, w = wht, na.rm = TRUE)

  if (missing(pred.ocome)) {
    adj.att <- att
    w.adj.att <- w.att
  } else {
    # Adjusted estimators: the same construction applied to the residuals
    # ocome - pred.ocome.
    res.ocome <- ocome - pred.ocome
    adj.diff <- res.ocome[treat] - cfactual(res.ocome)
    adj.att <- mean(adj.diff, na.rm = TRUE)
    w.adj.att <- weighted.mean(adj.diff, w = wht, na.rm = TRUE)
  }

  delta <- NA
  if (!missing(data)) {
    X <- as.matrix(data)
    dt <- matrix(NA_real_, nt, ncol(X), dimnames = list(NULL, colnames(X)))
    for (i in seq_len(nt)) {
      # Every control with positive proximity to treated unit i (not only the
      # first one); drop = FALSE keeps a matrix even for one control/column.
      idc <- ctrl[which(P.tc[i, ] > 0)]
      if (has.ctrl[i] && length(idc) > 0) {
        dt[i, ] <- X[treat[i], ] -
          colMeans(X[idc, , drop = FALSE], na.rm = TRUE)
      }
    }
    delta <- colMeans(dt, na.rm = TRUE)
  }

  list(att = att, adj.att = adj.att, w.att = w.att,
       w.adj.att = w.adj.att, delta = delta)
}

