// start from a clean session: data, matrices, scalars, programs and Mata functions are all cleared
clear all

// prog mata to compute the interval each duration belongs to
//   Inter(t) reads the duration variable t and the Stata row matrix Tcut of cut points
//   (increasing, first element the lower bound, last element possibly missing = open-ended)
//   and creates two double variables:
//     nint = index k of the interval (Tcut[k], Tcut[k+1]] containing t
//            (number of upper bounds strictly below t, plus 1)
//     lint = t - Tcut[k], the time elapsed since the start of that interval
//            (0 when t does not exceed the first cut point)
// drop a previous definition of the Mata function so the block can be re-run
// (program drop would only target an ado program of the same name)
capture mata mata drop Inter()
mata:
	void Inter(string scalar t)
	{
		// view on the duration variable and the cut points
		st_view(vt, ., t)
		Tcut = st_matrix("Tcut")
		K = cols(Tcut)
		lo = Tcut[., 1..(K - 1)]
		hi = Tcut[., 2..K]
		// N x 1 versus 1 x (K-1) colon comparisons broadcast to N x (K-1)
		nint = rowsum(vt :> hi) :+ 1
		// t minus the lower bound, kept only for the interval (lo, hi] that contains t
		lint = rowsum((vt :- lo) :* ((vt :> lo) :* (vt :<= hi)))
		// send results to stata
		st_addvar(("double", "double"), ("nint", "lint"))
		st_store(., "nint", nint)
		st_store(., "lint", lint)
	}
end

// prog that declares model and maximizes likelihood
//   nv(#) : number of latent heterogeneity classes (default 1); must be a positive integer
//   minit : take starting values from the matrix M (ml init M, copy) instead of ml search
// The class count is passed to the evaluator lik through the scalar nv.
cap program drop maxlik
program define maxlik
	syntax [, NV(real 1) MINIT]
	if `nv' < 1 | `nv' != int(`nv') {
		display as error "nv() must be a positive integer"
		exit 198
	}
	// declare model - parameters are ordered as follows: bp gp bz dp by I_p I_z I_y vp vz vy pzy
	// equations delta1-delta5 carry bp gp bz dp by; each duration then gets one constant-only
	// equation per interval beyond the first (colsof(Tcut) - 2 of them), numbered consecutively
	local K_ALL = 0
	foreach vv of varlist P Z Y {
		local K_Tcut = colsof(Tcut_`vv') - 2
		forvalues kk = 1/`K_Tcut' {
			local dkk = 5 + `K_ALL' + `kk'
			local deltaIk_`vv' "`deltaIk_`vv'' (delta`dkk':)"
		}
		local K_ALL = `K_ALL' + `K_Tcut'
	}
	scalar nv = `nv'
	// heterogeneity: nv values each for P, Z and Y plus nv-1 class log-weights
	local nv_all = 4*`nv' - 1
	forvalues hh = 1/`nv_all' {
		local dhh = 5 + `K_ALL' + `hh'
		local deltahet "`deltahet' (delta`dhh':)"
	}

	ml model lf lik (delta1: X*, nocons) (delta2:) (delta3: X*, nocons) (delta4:) (delta5: X*, nocons)    ///
					`deltaIk_P' `deltaIk_Z' `deltaIk_Y' `deltahet', technique(bhhh 20 bfgs 50)
	// starting values
	if "`minit'" != "" {
		ml init M, copy
	}
	else {
		ml search
	}
	// maximize likelihood
	ml max
end

// prog to compute hazard rates and add the different bits of the likelihood
//
// LL(bvars, dvars, intp, intz, inty) stores in the existing variable lnvrais the
// log-likelihood contribution of every observation. It reads the Stata matrices
//   Pcut_v / SPcut_v : hazard on each interval / cumulative hazard at interval ends (v = P Z Y)
//   effects          : 1 x 2 (gp, dp), multipliers exp(gp), exp(dp) of the Z and Y hazards after tP
//   probas           : 1 x R class log-weights
//   vhet             : R x 3 class-specific effects for P, Z, Y
// Arguments (space-separated variable names):
//   bvars : the 3 linear indices for P, Z, Y
//   dvars : the 3 event indicators P, Z, Y
//   intp  : interval index and offset of tP                     (nint_P lint_P)
//   intz  : interval indices then offsets of Z at tP and at tZ  (nint_ZtP nint_Z lint_ZtP lint_Z)
//   inty  : the same for Y at tP and at tY                      (nint_YtP nint_Y lint_YtP lint_Y)
// drop previous definitions so the block can be re-run
capture mata mata drop LL()
capture mata mata drop LL_cumhaz()
capture mata mata drop LL_treated()
mata:
// integrated piecewise-constant hazard for interval index k and offset l inside that interval
real colvector LL_cumhaz(real colvector k, real colvector l, real colvector pc, real colvector spc)
{
	// cumulative hazard up to the start of interval k (0 when k == 1) plus the part inside k
	return(spc[k :- 1 :+ (k :== 1), 1] :* (k :> 1) :+ pc[k, 1] :* l)
}

// hazard and integrated hazard at t_v for a duration v whose hazard is multiplied by
// exp(eff) after tP; vars = nint_vtP nint_v lint_vtP lint_v. Returns N x 2 (hazard, integrated hazard)
real matrix LL_treated(string scalar vars, string scalar v, real scalar eff)
{
	st_view(V = ., ., tokens(vars))
	pc = st_matrix("Pcut_" + v)
	spc = st_matrix("SPcut_" + v)
	g = exp(eff)
	// Lambda(tP) * (1 - exp(eff)) + Lambda(t_v) * exp(eff)
	return((pc[V[., 2], 1], LL_cumhaz(V[., 1], V[., 3], pc, spc) :* (1 - g) :+ LL_cumhaz(V[., 2], V[., 4], pc, spc) :* g))
}

void LL(string scalar bvars, string scalar dvars, string scalar intp, string scalar intz, string scalar inty)
{
	me = st_matrix("effects")

	// P: hazard and integrated hazard at tP
	st_view(V = ., ., tokens(intp))
	pc = st_matrix("Pcut_P")
	spc = st_matrix("SPcut_P")
	H = pc[V[., 1], 1]
	HINT = LL_cumhaz(V[., 1], V[., 2], pc, spc)

	// Z and Y: baseline hazard multiplied by exp(gp) (resp. exp(dp)) after tP
	HZ = LL_treated(intz, "Z", me[1, 1])
	HY = LL_treated(inty, "Y", me[1, 2])
	H = H, HZ[., 1], HY[., 1]
	HINT = HINT, HZ[., 2], HY[., 2]

	BETA = st_data(., tokens(bvars))
	D = st_data(., tokens(dvars))
	mp = st_matrix("probas")
	mv = st_matrix("vhet")
	R = cols(mp)
	N = rows(D)

	// class-specific part of the log-likelihood: sum over P, Z, Y of d*v_r - Lambda*exp(beta + v_r)
	L = J(N, R, .)
	for (r = 1; r <= R; r++) {
		L[., r] = rowsum(mv[r, .] :* D - HINT :* exp(BETA :+ mv[r, .]))
	}
	// mixture over classes with weights exp(mp)/sum(exp(mp)); the log-sum-exp is shifted by
	// its maximum so very negative class log-likelihoods do not underflow to log(0) = missing
	A = mp :+ L
	amax = rowmax(A)
	pmax = max(mp)
	L = amax + log(rowsum(exp(A :- amax))) + rowsum(D :* (log(H) + BETA)) :- (pmax + log(rowsum(exp(mp :- pmax))))
	// treatment effects on the log hazards of Z and Y when P == 1
	L = L + D[., 2] :* me[1, 1] :* D[., 1] + D[., 3] :* me[1, 2] :* D[., 1]
	st_store(., "lnvrais", L)
}
end

// likelihood
// ml lf evaluator: turns the current parameter values into the matrices read by the Mata
// function LL (Pcut_v, SPcut_v, effects, probas, vhet) and copies its per-observation
// log-likelihood (variable lnvrais) into the lnf argument.
// Arguments follow the equation order declared in maxlik:
//   bp gp bz dp by, I2..IK for P then Z then Y, vp1..vpnv, vz1..vznv, vy1..vynv, pzy1..pzy(nv-1)
cap program drop lik
program define lik

	// number of latent classes, stored by maxlik in the scalar nv
	// (scalar() so that a variable named nv cannot shadow it)
	local nv = scalar(nv)
	foreach vv in P Z Y {
		// K_vv intervals: interval 1 has a fixed hazard, intervals 2..K_vv have parameters I2..IK
		local K_`vv' = colsof(Tcut_`vv') - 1
		local I_`vv' "I2_`vv'"
		forvalues kk = 3/`K_`vv'' {
			local I_`vv' "`I_`vv'' I`kk'_`vv'"
		}
	}
	forvalues bb = 1/`nv' {
		local vphet "`vphet' vp`bb'"
		local vzhet "`vzhet' vz`bb'"
		local vyhet "`vyhet' vy`bb'"
		if `bb' < `nv' {
			local pzyhet "`pzyhet' pzy`bb'"
		}
	}
	// the last class is the reference: its log-weight is fixed at 0
	local pzy`nv' = 0

	args lnf bp gp bz dp by `I_P' `I_Z' `I_Y' `vphet' `vzhet' `vyhet' `pzyhet'

	// piecewise-constant hazards: .0001 on interval 1, invlogit(-I_k) on interval k;
	// SPcut holds the cumulative hazard at the end of each interval.
	// NOTE(review): LTcut has no entry for the open-ended last interval, so the last SPcut
	// element is not a real cumulative value (presumably missing); LL never indexes it
	foreach vv in P Z Y {
		matrix Pcut_`vv' = .0001
		matrix SPcut_`vv' = (.0001 * LTcut_`vv'[1,1]) \ J(`K_`vv''-1,1,0)
		forvalues kk = 2/`K_`vv'' {
			matrix Pcut_`vv' = Pcut_`vv' \ invlogit(-`I`kk'_`vv''[1])
			matrix SPcut_`vv'[`kk',1] = SPcut_`vv'[`kk'-1,1] + Pcut_`vv'[`kk',1]*LTcut_`vv'[1,`kk']
		}
	}
	matrix effects = (`gp'[1],`dp'[1])
	// classes are stacked from 1 to nv: prepend classes nv-1, nv-2, ..., 1
	matrix probas = `pzy`nv''
	matrix vhet = `vp`nv''[1], `vz`nv''[1], `vy`nv''[1]
	if `nv' > 1 {
		local nv1 = `nv' - 1
		forvalues rr = `nv1'(-1)1 {
			matrix probas = `pzy`rr''[1], probas
			matrix vhet = (`vp`rr''[1], `vz`rr''[1], `vy`rr''[1]) \ vhet
		}
	}
	// scratch variable filled by LL; remove any copy left by an aborted earlier call
	capture drop lnvrais
	gen double lnvrais = 0
	mata: LL("`bp' `bz' `by'", "P Z Y", "nint_P lint_P", "nint_ZtP nint_Z lint_ZtP lint_Z", "nint_YtP nint_Y lint_YtP lint_Y")
	qui replace `lnf' = lnvrais
	drop lnvrais

end

use table_75

drop *A1 *A2 *A3 *A4 *A5 *B1 *B2 *B3 *B4 *B5

// create intervals for piecewise constant hazards:
// cut points are 0, the nq-quantiles of the durations with an observed event, and missing
// (open-ended last interval), which gives nq intervals per duration
local nq = 11
foreach vv of varlist P Z Y {
	su t`vv' if `vv'==1, det
	_pctile t`vv' if `vv'==1, nq(`nq')
	matrix Tcut_`vv' = 0
	forvalues q = 1/`=`nq'-1' {
		matrix Tcut_`vv' = Tcut_`vv', r(r`q')
	}
	matrix Tcut_`vv' = Tcut_`vv', .
}
// length of each bounded interval (used later to compute integrated hazard rates);
// the open-ended last interval gets no length
qui foreach vv of varlist P Z Y {
	local K_`vv' = colsof(Tcut_`vv')-2
	local Kp1 = `K_`vv'' + 1
	matrix LTcut_`vv' = Tcut_`vv'[1, 2..`Kp1'] - Tcut_`vv'[1, 1..`K_`vv'']
	noi matrix list Tcut_`vv'
	noi matrix list LTcut_`vv'
}

// Y is censored by Z: when Z occurs first, Y is recorded as censored at tZ
// NOTE(review): only this direction is applied here (Z is not censored by Y) - confirm intent
local zfirst "tZ < tY"
replace Y = 0 if `zfirst'
replace tY = tZ if `zfirst'

// for P==1, Z and then Y are censored at tP + cut
local cut = 31
foreach vv in Z Y {
	local late "P==1 & t`vv'>tP+`cut'"
	replace `vv' = 0 if `late'
	replace t`vv' = tP+`cut' if `late'
}

// find which interval each duration belongs to:
// P, Z and Y on their own durations (tP, tZ, tY), then the Z and Y cut points evaluated at tP
// (job XtP means: cut points of X, duration tP)
foreach job in P Z Y ZtP YtP {
	local vv = substr("`job'", 1, 1)
	local tt = cond("`job'" == "`vv'", "t`vv'", "tP")
	matrix Tcut = Tcut_`vv'
	mata: Inter("`tt'")
	rename nint nint_`job'
	rename lint lint_`job'
}

// starting values for maxlik, nv(2) minit, loaded by position with ml init M, copy.
// Layout of the 99 entries, following the equation order declared in maxlik:
//   bp (X* coefficients), gp, bz (X* coefficients), dp, by (X* coefficients)  -> first 62
//     (consistent with 20 X* regressors per equation - TODO confirm)
//   interval parameters I2..I11 for P, then Z, then Y                          -> next 30
//   heterogeneity parameters vp1 vp2 vz1 vz2 vy1 vy2 pzy1                      -> last 7
matrix M = (  -.26414375,  1.0943123, -1.1104551, -.33229034,  .21793678, -.18244728, -.03602972,  .05865271,  .30127646,   .4762183,  .21625415,  ///
               .18237786,  .22315086,   .0546282,  .18559054,  .21778545,  .17055954,  .29467053, -5.7501487, -.27505398,  3.9644872, -.20807506,  ///
               1.8689557, -1.9673002, -.07334174,  .05372056,   .3481237,  .00786719,  .04840993, -.33668841, -.34468082, -.26608427, -.24996483,  ///
                -.072606, -.62111135,  .10881638, -.13901036, -.01902242,  .15577402,  1.2327585,  .17030191, -.65256666,  .04923025, -1.3239817,  ///
               1.1912436, -.31365885,  .23700974, -.21997201, -.01976574,  .00713186,  .07686699,  .08732716, -.02289074, -.19615896, -.04954885,  ///
                .2193809,  .12858792,  .14826893,   .0113071, -.01600422,  2.9474561, -.74085051,  9.5629788,  9.6785842,  9.8577576,  10.394821,  ///
               10.638467,  10.697222,  10.312558,  10.663783,  10.689606,  11.166506,  9.1705146,  9.0220925,  8.9360722,  8.9297182,   8.973197,  ///
               9.1836837,  9.2373702,  9.4382793,  9.6642841,  10.865181,  8.4385872,  9.0436091,  9.1827572,  9.2935145,   9.207675,  9.1382738,  ///
               9.2753384,  9.3733534,  9.2919141,  9.0754739,   1.808617,  2.9,  1.72, -1.6,  2.6,  3.6, -1  )

// estimate with 2 latent classes starting from M, then report gp and dp (the treatment
// effects on the Z and Y hazards) and the delta-method standard error of their product
maxlik, nv(2) minit
matrix list e(b)
matrix eb2 = e(b)
matrix eV2 = e(V)
// locate gp and dp by equation name instead of hard-coded positions (21 and 42),
// which would shift with the number of X* regressors
local igp = colnumb(eb2, "delta2:_cons")
local idp = colnumb(eb2, "delta4:_cons")
if missing(`igp', `idp') {
	display as error "delta2:_cons or delta4:_cons not found in e(b)"
	exit 111
}
matrix bb = eb2[1,`igp'], eb2[1,`idp']
matrix VV = (eV2[`igp',`igp'], eV2[`igp',`idp']) \ (eV2[`idp',`igp'], eV2[`idp',`idp'])
// gradient of gp*dp with respect to (gp, dp) is (dp, gp)
matrix sd = (eb2[1,`idp'], eb2[1,`igp']) * VV * (eb2[1,`idp'], eb2[1,`igp'])'
matrix sd = sqrt(sd[1,1])
matrix list bb
matrix list sd


