// Start from an empty session: data, programs, matrices and Mata functions
// are all cleared, so everything below is (re)defined from scratch.
clear all

// Mata function locating each duration in a grid of cut points.
//
// Inter(t)
//   t : name of a Stata variable holding durations.
// Reads the Stata matrix Tcut, a single row of cut points
// (c1, c2, ..., cK+1); interval k is (c_k, c_k+1]. A missing last cut (".")
// makes the last interval open-ended.
// Creates two new double variables:
//   nint : 1 + number of upper bounds c_2..c_K+1 strictly below t
//          (i.e. the index of the interval containing t)
//   lint : t minus the lower bound of the interval containing t
// st_addvar() aborts if nint or lint already exist, so callers rename them.
// (Inter is a Mata function, so it is removed with mata drop, not program drop.)
cap mata: mata drop Inter()
mata:
	void Inter(string scalar t)
	{
	real matrix vt, Tcut, Tcut_inf, Tcut_sup, Ikt
	real colvector ones, nint, lint
	real scalar N

	// import variables and matrices into mata
	st_view(vt, ., t)
	Tcut = st_matrix("Tcut")
	N = rows(vt)
	ones = J(N, 1, 1)
	// upper and lower bounds of every interval, repeated on each of the N rows
	Tcut_sup = ones * Tcut[., 2..cols(Tcut)]
	Tcut_inf = ones * Tcut[., 1..cols(Tcut)-1]
	// interval to which t belongs (count intervals lying strictly below t and add 1)
	nint = rowsum(vt :> Tcut_sup) :+ 1
	// 0/1 indicator of the interval (inf, sup] that contains t
	Ikt = (vt :> Tcut_inf) :* (vt :<= Tcut_sup)
	// t minus the lower bound of the interval t belongs to
	lint = rowsum((vt :- Tcut_inf) :* Ikt)
	// send results to stata
	st_addvar(("double", "double"), ("nint", "lint"))
	st_store(., "nint", nint)
	st_store(., "lint", lint)
	}
end

// prog that declares model and maximizes likelihood
//
// maxlik [, nv(#) minit]
//   nv(#)  number of mass points of the discrete heterogeneity distribution
//          (default 1); must be a positive integer
//   minit  take starting values from the Stata matrix M (ml init M, copy)
//          instead of running ml search
// Stores nv in the scalar nv, which the evaluator lik reads.
cap program drop maxlik
program define maxlik
syntax [, NV(integer 1) MINIT]
	if `nv' < 1 {
		di as error "nv() must be a positive integer"
		exit 198
	}
	// declare model - parameters are ordered as follows: bp gp01 gp13 gp3 bz dp01 dp13 dp3 by I_p I_z I_y vp vz vy pzy
	// Equations delta1..delta9 are fixed; for each of P, Z and Y one
	// constant-only equation per interval hazard (colsof(Tcut_v)-2 of them)
	// follows, numbered consecutively.
	local K_ALL = 0
	foreach vv in P Z Y {
		local K_Tcut = colsof(Tcut_`vv')-2
		forvalues kk = 1/`K_Tcut' {
			local dkk = 9+`K_ALL'+`kk'
			local deltaIk_`vv' "`deltaIk_`vv'' (delta`dkk':)"
		}
		local K_ALL = `K_ALL'+`K_Tcut'
	}
	scalar nv = `nv'
	// heterogeneity: vp, vz, vy for each of the nv mass points plus nv-1
	// probability parameters (the last one is normalized in lik)
	local nv_all = 4*`nv'-1
	forvalues hh = 1/`nv_all' {
		local dhh = 9+`K_ALL'+`hh'
		local deltahet "`deltahet' (delta`dhh':)"
	}
	ml model lf lik (delta1: X*, nocons)  ///
	                (delta2:) (delta3:) (delta4:) (delta5: X*, nocons)  ///
	                (delta6:) (delta7:) (delta8:) (delta9: X*, nocons)  ///
	                `deltaIk_P' `deltaIk_Z' `deltaIk_Y' `deltahet' , technique(bhhh 10 bfgs 100)
	// starting values
	if "`minit'"~=""{
		ml init M, copy
	}
	else{
		ml search
	}
	// maximize likelihood
	ml max, difficult
end

// prog to compute hazard rates and add the different bits of the likelihood
capture mata: mata drop LL()
capture mata: mata drop _LL_pwcumhaz()
capture mata: mata drop _LL_tvcumhaz()
mata:
// _LL_pwcumhaz(): integrated piecewise-constant hazard at each duration.
//   nint  : interval index of each duration (1-based)
//   lint  : time spent inside that interval
//   pcut  : hazard of each interval (column vector)
//   spcut : cumulative hazard at the end of each interval (column vector)
// Returns spcut[nint-1] + pcut[nint]*lint, the spcut term being 0 when nint==1.
real colvector _LL_pwcumhaz(real colvector nint, real colvector lint, real colvector pcut, real colvector spcut)
{
	return(spcut[nint :- 1 :+ (nint :== 1), 1] :* (nint :> 1) + pcut[nint, 1] :* lint)
}

// _LL_tvcumhaz(): hazard terms for an outcome (Z or Y) whose hazard is
// shifted by P over successive time windows.
//   ints : names of 2m variables - the m interval indices followed by the m
//          within-interval lengths; the last of the m time points is the
//          outcome's own duration
//   cut  : suffix of the Stata matrices Pcut_<cut> and SPcut_<cut>
//   eff  : 1 x (m-1) row of log-hazard shifts, one per window
// Returns an N x 2 matrix:
//   col 1 : baseline hazard at the outcome's own duration
//   col 2 : sum_j Lambda(s_j)*T[j], with c = (0, eff),
//           T[j] = exp(c[j]) - exp(c[j+1]) for j < m and T[m] = exp(c[m])
real matrix _LL_tvcumhaz(string scalar ints, string scalar cut, real rowvector eff)
{
	real matrix V, NINT, LINT, PCUT, SPCUT
	real rowvector T
	real scalar m

	st_view(V = ., ., tokens(ints))
	PCUT = st_matrix("Pcut_" + cut)
	SPCUT = st_matrix("SPcut_" + cut)
	m = cols(V)/2
	NINT = V[., 1..m]
	LINT = V[., m+1..2*m]
	T = exp((0, eff)) * (I(m) + (J(1,m,0) \ (-I(m-1), J(m-1,1,0))))
	// vec() the N x m index/length matrices, integrate, reshape back to N x m
	return((PCUT[NINT[., m], 1], rowshape(_LL_pwcumhaz(colshape(NINT', 1), colshape(LINT', 1), PCUT, SPCUT), m)' * T'))
}

// LL(): writes each observation's log-likelihood into the existing
// variable lnvrais.
//   bvars : linear indices for P, Z, Y
//   dvars : event indicators for P, Z, Y
//   intp  : nint/lint variables of tP on the P grid
//   intz  : nint/lint variables of tP, tP1, tP3, tZ on the Z grid
//   inty  : nint/lint variables of tP, tP1, tP3, tY on the Y grid
//   mehaz : variables holding the P effects on the Z and Y hazards at the
//           event time
// Also reads the Stata matrices effects, probas, vhet and Pcut_*/SPcut_*.
function LL(string bvars, string dvars, string intp, string intz, string inty, string mehaz)
{
	me = st_matrix("effects")
	st_view(MEHAZ = ., ., tokens(mehaz))

	// P: plain piecewise-constant hazard
	st_view(NINTP = ., ., tokens(intp))
	PCUT = st_matrix("Pcut_P")
	SPCUT = st_matrix("SPcut_P")
	H = PCUT[NINTP[.,1],1]
	HINT = _LL_pwcumhaz(NINTP[.,1], NINTP[.,2], PCUT, SPCUT)

	// Z and Y: hazards shifted by the P effects (gp01 gp13 gp3 / dp01 dp13 dp3)
	HZ = _LL_tvcumhaz(intz, "Z", me[1,1..3])
	HY = _LL_tvcumhaz(inty, "Y", me[1,4..6])
	H = H, HZ[.,1], HY[.,1]
	HINT = HINT, HZ[.,2], HY[.,2]

	st_view(BETA = ., ., tokens(bvars))
	st_view(D = ., ., tokens(dvars))
	// replace the views by in-memory copies (presumably for speed; they are reused below)
	BETA = BETA
	D = D
	mp = st_matrix("probas")
	mv = st_matrix("vhet")
	R = cols(mp)
	N = rows(D)

	// N x R: log-likelihood of each observation conditional on each mass point
	// (without the log-hazard term, which does not depend on the mass point)
	L = rowshape(rowsum((mv#J(N,1,1)) :* (J(R,1,1)#D) - (J(R,1,1)#HINT) :* exp(J(R,1,1)#BETA + mv#J(N,1,1))), R)'
	// Mix over mass points with weights exp(mp)/sum(exp(mp)). Both log-sum-exps
	// are shifted by their maximum so exp() cannot underflow to 0 (log(0) is
	// missing). Missing entries still contribute 0 to the sums, as before.
	A = mp :+ L
	Amax = rowmax(editmissing(A, -maxdouble()))
	pmax = max(mp)
	L = log(rowsum(exp(A :- Amax))) + Amax + rowsum(D:*(log(H)+BETA)) :- (log(rowsum(exp(mp :- pmax))) + pmax)
	// P effects on the Z and Y hazards at the event time, only if P occurred
	L = L + D[.,2]:*MEHAZ[.,1]:*D[.,1] + D[.,3]:*MEHAZ[.,2]:*D[.,1]
	st_store(.,"lnvrais",L)
}
end

// likelihood
// lik: method-lf evaluator used by "ml model lf lik" in maxlik.
// Reads the scalar nv and the matrices Tcut_* and LTcut_*. Builds the
// interval hazards Pcut_*, the cumulative hazards SPcut_* and the matrices
// effects, probas and vhet. Then calls the Mata function LL, which writes the
// observation log-likelihoods into the scratch variable lnvrais.
cap program drop lik
program define lik

	local nv = nv
	// argument names I2_<v> ... IK_<v> of the interval-hazard equations
	foreach vv in P Z Y {
		local K_`vv' = colsof(Tcut_`vv') - 1
		local I_`vv' "I2_`vv'"
		forvalues kk = 3/`K_`vv'' {
			local I_`vv' "`I_`vv'' I`kk'_`vv'"
		}
	}
	// heterogeneity argument names; the last mass point has no probability
	// parameter (it is normalized to 0 below)
	forvalues bb = 1/`nv' {
		local vphet "`vphet' vp`bb'"
		local vzhet "`vzhet' vz`bb'"
		local vyhet "`vyhet' vy`bb'"
		if `bb' ~= `nv' {
			local pzyhet "`pzyhet' pzy`bb'"
		}
	}

	args lnf bp gp01 gp13 gp3 bz dp01 dp13 dp3 by `I_P' `I_Z' `I_Y' `vphet' `vzhet' `vyhet' `pzyhet'
	tempvar gp dp

	// Pcut_<v>: hazard of each interval; the first is fixed at .0001
	// (normalization), the others are invlogit(-I).
	// SPcut_<v>: cumulative hazard at the end of each interval. LTcut_<v>
	// only holds the K-1 bounded interval lengths, so the loop stops at K-1.
	// The last (open-ended) entry is set to missing; LL only reads SPcut[nint-1].
	foreach vv in P Z Y {
		local K = `K_`vv''
		matrix Pcut_`vv' = .0001
		forvalues kk = 2/`K' {
			matrix Pcut_`vv' = Pcut_`vv' \ (exp(-`I`kk'_`vv''[1]) / (1+exp(-`I`kk'_`vv''[1])))
		}
		matrix SPcut_`vv' = (.0001 * LTcut_`vv'[1,1]) \ J(`K'-1,1,0)
		forvalues kk = 2/`=`K'-1' {
			matrix SPcut_`vv'[`kk',1] = SPcut_`vv'[`kk'-1,1] + Pcut_`vv'[`kk',1]*LTcut_`vv'[1,`kk']
		}
		matrix SPcut_`vv'[`K',1] = .
	}
	gen double `gp' = `gp01'*(tP<tZ)*(tZ<=tP1) + `gp13'*(tP1<tZ)*(tZ<=tP3) + `gp3'*(tP3<tZ)
	gen double `dp' = `dp01'*(tP<tY)*(tY<=tP1) + `dp13'*(tP1<tY)*(tY<=tP3) + `dp3'*(tP3<tY)

	matrix effects = (`gp01'[1],`gp13'[1],`gp3'[1],`dp01'[1],`dp13'[1],`dp3'[1])
	// probas: 1 x nv row of probability parameters (last normalized to 0);
	// vhet: nv x 3 matrix of (vp, vz, vy) mass points
	matrix probas = 0
	matrix vhet = `vp`nv''[1], `vz`nv''[1], `vy`nv''[1]
	if `nv'>1{
		local nv1 = `nv'-1
		forvalues rr = `nv1'(-1)1 {
			matrix probas = `pzy`rr''[1], probas
			matrix vhet = (`vp`rr''[1], `vz`rr''[1], `vy`rr''[1]) \ vhet
		}
	}
	// lnvrais is a scratch variable; drop any copy left by an interrupted run
	capture drop lnvrais
	gen double lnvrais = 0
	mata: LL("`bp' `bz' `by'", "P Z Y", "nint_P lint_P",  ///
	         "nint_ZtP nint_ZtP1 nint_ZtP3 nint_Z lint_ZtP lint_ZtP1 lint_ZtP3 lint_Z",  ///
	         "nint_YtP nint_YtP1 nint_YtP3 nint_Y lint_YtP lint_YtP1 lint_YtP3 lint_Y", "`gp' `dp'")
	qui replace `lnf' = lnvrais
	drop lnvrais

end

use table_75

drop *A1 *A2 *A3 *A4 *A5 *B1 *B2 *B3 *B4 *B5

// Cut points of the piecewise-constant hazard of each outcome: 0, the 10
// quantiles splitting observed (uncensored) durations into 11 groups, and
// "." so that the last interval is open-ended.
foreach out of varlist P Z Y {
	summarize t`out' if `out'==1, detail
	_pctile t`out' if `out'==1, nq(11)
	local cuts "0"
	forvalues q = 1/10 {
		local cuts "`cuts', r(r`q')"
	}
	matrix Tcut_`out' = (`cuts', .)
}
// Lengths of the bounded intervals (the open-ended last one is left out);
// they are used to build the integrated hazards.
qui foreach out of varlist P Z Y {
	local K_`out' = colsof(Tcut_`out')-2
	matrix LTcut_`out' = Tcut_`out'[1,2]
	forvalues c = 2/`K_`out'' {
		matrix LTcut_`out' = LTcut_`out', Tcut_`out'[1,`c'+1]-Tcut_`out'[1,`c']
	}
	noisily matrix list Tcut_`out'
	noisily matrix list LTcut_`out'
}

// Y is censored by Z: when Z occurs before Y, Y is set to censored and tY is
// truncated at tZ (afterwards tY <= tZ). Z is not censored by Y here.
// NOTE(review): an earlier comment described the censoring as symmetric;
// only the Y-by-Z direction is implemented - confirm this is intended.
replace Y = 0 if tZ < tY
replace tY = tZ if tZ < tY

// tP1 = tP+30 and tP3 = tP+90 delimit the windows in which the effect of P on
// the Z and Y hazards may differ (see gp01/gp13/gp3 in lik). Both are capped at tZ.
// NOTE(review): the same tZ-capped boundaries are used for the Y grid
// (nint_YtP1, nint_YtP3) although tY can be below tP1/tP3 - confirm whether
// they should be capped at tY for the integrated Y hazard.
foreach tt of numlist 1 3{
	gen tP`tt' = tP + 30*`tt'
	replace tP`tt' = tZ if tZ<=tP`tt'
}

// find which interval each duration belongs to
// Each job is "<grid> <duration variable> <suffix>". The duration is located
// on the cut points Tcut_<grid>, and the nint/lint variables created by Inter
// are renamed nint_<suffix>/lint_<suffix>. The Z and Y grids are also
// evaluated at tP, tP1 and tP3.
foreach job in "P tP P" "Z tZ Z" "Y tY Y" ///
               "Z tP ZtP" "Z tP1 ZtP1" "Z tP3 ZtP3" ///
               "Y tP YtP" "Y tP1 YtP1" "Y tP3 YtP3" {
	local grid : word 1 of `job'
	local dur  : word 2 of `job'
	local tag  : word 3 of `job'
	matrix Tcut = Tcut_`grid'
	mata: Inter("`dur'")
	rename nint nint_`tag'
	rename lint lint_`tag'
}

// Starting values, copied into ml by "maxlik, minit" (ml init M, copy).
// The 103 entries must follow the parameter order declared in maxlik.
// NOTE(review): presumably 20 X* covariates per covariate equation,
// 10 interval hazards per outcome and nv=2 heterogeneity parameters - verify
// whenever the covariates, the cut points or nv change.
matrix M = (  -.2503427,  1.095747, -1.111194, -.3466542,   .228139, -.1779796, -.0357615,  .0979616,  .2797555,  .4380528,  .2091984,  .1826208,  .2344999,  ///
               .0536777,  .2325145,  .2208252,  .2079253,  .2800609, -5.631953, -.2577436,  3.477755,  3.198145,  2.476576, -.1565191,  1.782457,  -1.87697,  ///
              -.0309763,  .0336657,  .3694039,  .0349636,  .0097894, -.3737974, -.3741546, -.2537781, -.2419749,  -.034299, -.6489743,  .0762391, -.1636345,  ///
               .0062899,  .1709136,  .4970911,  .1287766, -.6154377, -.4768891, -.2334397,  .0458767, -1.304751,  1.165482, -.2954503,  .2222098, -.2208908,  ///
              -.0205214,  .0176795,  .0795176,  .0794567,  -.032139,  -.201985, -.0506264,  .2122047,  .1196189,  .1405139,  .0175493, -.0066169,  2.772966,  ///
               -.733962,  9.537329,  9.637847,  9.843151,  10.38039,  10.62291,  10.65529,  10.28282,  10.64147,  10.69868,  11.21815,  9.186808,  9.012974,  ///
               8.958628,  8.911931,  8.923253,  9.087379,  9.181146,  9.399104,  9.536073,  10.63704,  8.435178,  9.039169,  9.181936,  9.303206,  9.210694,  ///
               9.150933,  9.291051,  9.394037,  9.326082,  9.133421,  1.902381,  2.876954,  1.567064, -.6170112,  2.714978,  3.621703, -1.050106  )
				   
// Fit the model with two mass points, starting from M.
maxlik, nv(2) minit
matrix list e(b)
matrix eb2 = e(b)
matrix eV2 = e(V)

// Delta-method standard errors of the products g*d of the P effects on the Z
// and Y hazards, for the three windows after tP:
//   0 : delta2 (gp01) x delta6 (dp01)
//   1 : delta3 (gp13) x delta7 (dp13)
//   3 : delta4 (gp3)  x delta8 (dp3)
// Coefficients are found by equation name rather than by fixed position, so
// this does not depend on how many X* covariates enter delta1/delta5.
// For f(g,d) = g*d the gradient is (d, g), so Var(f) ~ (d,g) V (d,g)'.
forvalues k = 1/3 {
	local w : word `k' of 0 1 3
	local g = `k' + 1
	local d = `k' + 5
	local ig = colnumb(eb2, "delta`g':_cons")
	local id = colnumb(eb2, "delta`d':_cons")
	if missing(`ig') | missing(`id') {
		di as error "delta`g':_cons or delta`d':_cons not found in e(b)"
		exit 111
	}
	matrix bb`w' = eb2[1,`ig'], eb2[1,`id']
	matrix VV`w' = (eV2[`ig',`ig'], eV2[`ig',`id']) \ (eV2[`id',`ig'], eV2[`id',`id'])
	matrix sd`w' = (eb2[1,`id'], eb2[1,`ig']) * VV`w' * (eb2[1,`id'], eb2[1,`ig'])'
}
matrix sd = sqrt(sd0[1,1]), sqrt(sd1[1,1]), sqrt(sd3[1,1])
matrix list sd





