// This version allows either or both of the nonlinear (NL) parameters to be industry (sub-category) specific; see sepNLPar1 and sepNLPar2.

new;
startTim = date;    // run start time, used for the elapsed-time report at the end

library maxlik, pgraph;
maxset;
#include optim.sdf

// Printed output is also sent to out.txt ("reset" starts a fresh file each run).
output file=out.txt reset;
outwidth 250;



// ---------------- Run settings ----------------
// NOTE(review): cutoffyear, subcatlev and citeApp are not referenced in this
// file; presumably they are read by datamake.g (included below) - confirm.
cutoffyear = 1975;
subcatlev = 1;  // can be set to 1 (6 subcats) or 2 (31 subcats); the end-of-run summary hard-codes 6 subcats
citeApp = 1; // set to 0 to measure birth and cite by grant year, set to 1 to measure birth and cite by application year
sepNLPar1 = 1; // set to 1 to estimate separate values of nonlinear parameter 1 for each subcategory
sepNLPar2 = 1; // set to 1 to estimate separate values of nonlinear parameter 2 for each subcategory
ageChkYrs = 15; // Number of ages (1..ageChkYrs) used for the summary stats at the end
useCYdums = 1;  // 1 = include cite-year dummies in the linear index
useBYdums = 1;  // 1 = include birth-year dummies in the linear index
useSCdums = 1;  // 1 = include sub-category dummies in the linear index
useNLpars = 1;  // 1 = add the nonlinear age profile (getFfunc) to the linear index

// NOTE(review): datamake.g is expected to define the data used below
// (n, ncites, age, _freq_, cyDumMat/citeYrIdx, byDumMat/birthYrIdx,
// subCatMat/subCatIdx); none of these are defined in this file.
#include datamake.g;

// Linear design: a constant column plus the optional blocks of dummy columns.
// parNames receives one label per column of xLinearDum (used by MAXLIK output).
xLinearDum = ones(n,1);
// NOTE(review): xLinear (index form of the three dummy groups) is filled
// below but is not referenced anywhere else in this file.
xLinear = zeros(n,3);
parNames = "const";

if useCYdums;
    xLinearDum = xLinearDum~cyDumMat;
    xLinear[.,1] = citeYrIdx;
    cyNames = "cite yr";
    // repeat the label once per cite-year dummy column
    parNames = parNames|(cyNames*ones(cols(cyDumMat),1));
endif;
if useBYdums;
    xLinearDum = xLinearDum~byDumMat;
    xLinear[.,2] = birthYrIdx;
    parNames = parnames|("birth yr"*ones(cols(byDumMat),1));
endif;
if useSCdums;
    xLinearDum = xLinearDum~subCatMat;
    xLinear[.,3] = subCatIdx;
    parNames = parnames|("sub cat"*ones(cols(subCatMat),1));
endif;



kParamLinear = cols(xLinearDum);

// Starting values for the linear index: OLS of ncites on the linear design.
// olsqr solves the least-squares problem by QR decomposition instead of
// forming inv(X'X) explicitly, which is better conditioned with many dummies.
thetaLin0 = olsqr(ncites, xLinearDum);

// Starting values for the two nonlinear parameters (on the log scale; getFfunc
// exponentiates them). Alternative start kept for reference:
//   thetaNonLin0 = ln(0.0749|0.1);
thetaNonLin0 = 0|0;

// Number of parameters in each nonlinear term: 1 when pooled, otherwise one
// per sub-category (base category + the columns of subCatMat).
kNL1 = 1 + cols(subCatmat)*sepNLPar1;
kNL2 = 1 + cols(subCatmat)*sepNLPar2;

if useNLpars;
    parnames = parnames|("nlPar1"*ones(kNL1,1));
    parnames = parnames|("nlPar2"*ones(kNL2,1));
endif;

age1 = age;
age2 = age;
thetaNonLin01 = thetaNonLin0[1];
thetaNonLin02 = thetaNonLin0[2];
if sepNLPar1 + sepNLPar2;
    // Full set of sub-category indicators, built once: first column is 1 when
    // no subCatMat column is set (presumably the omitted base category), then
    // subCatMat itself. sumr gives row sums without transposing an n-row matrix.
    subCatFull = (ones(n,1)-sumr(subCatMat))~subCatMat;
endif;
if sepNLPar1;
    // one parameter and one age-times-indicator column per sub-category
    thetaNonLin01 = thetaNonLin0[1].*.ones(kNL1,1);
    age1 = age1.*subCatFull;
endif;
if sepNLPar2;
    thetaNonLin02 = thetaNonLin0[2].*.ones(kNL2,1);
    age2 = age2.*subCatFull;
endif;
thetaNonLin0 = thetaNonLin01|thetaNonLin02;

// Data passed to likefunc; this column layout must match the slicing there:
// ncites | xLinearDum | age1 (kNL1 cols) | age2 (kNL2 cols) | _freq_
datasetmat = ncites~xLinearDum~age1~age2~_freq_;

theta0 = thetaLin0;
if useNLpars;
    theta0 = thetaLin0|thetaNonLin0;
endif;

// Optional warm start from a previously saved parameter vector:
//load theta0 = theta2;


// ---------------- MAXLIK estimation ----------------
__weight = _freq_;          // observation weights
_max_covpar = 3;            // NOTE(review): per MAXLIK docs 3 selects the QML/sandwich covariance - confirm
__output = 1;
numObs = sumc(_freq_);      // NOTE(review): not referenced elsewhere in this file
_max_parnames = parnames;
//_max_active = active;

// keep MAXLIK's iteration output out of out.txt
output off;
{theta,f,g,cov,rc} =maxlik(datasetmat, 0, &likefunc,theta0);
output on;


// NOTE(review): rescales the covariance by n / (total weight); presumably to
// undo the effect of the frequency weights on MAXLIK's covariance - confirm.
cov2 = cov*n/sumc(_freq_);
// Delta method: report the exponentiated nonlinear parameters, with the
// covariance transformed by the numerical Jacobian of getTrueParam.
thetaTrue = getTrueParam(theta);
dgtp = gradp(&getTrueParam, theta);
covTrue = dgtp*cov2*dgtp';

{thetaTrue,f,g,covTrue,rc} = maxprt(thetaTrue,f,g,covTrue,rc);


// Compute error by age: weighted mean of (ncites - exp(lambda)) at each age,
// evaluated at the MAXLIK estimates.

thetaLin = theta[1:kParamLinear];
// fAge (not f) so the log-likelihood value f returned by maxlik/maxprt
// is not overwritten.
fAge = 0;
if useNLpars;
    thetaNonLin = theta[kParamLinear+1:kParamLinear+kNL1+kNL2];
    fAge = getFfunc(thetaNonLin, age1, age2);
endif;
lambda = xLinearDum*thetaLin + fAge;
expLambda = exp(lambda);
eps = ncites - expLambda;

ageMax = maxc(age); ageMin = minc(age);
offset = 1-ageMin;      // maps age i to row i+offset of ageeps
ageeps = zeros(ageMax-ageMin+1,1);
for i(ageMin,ageMax,1);
    sel = age .== i;
    if sumc(sel) == 0;
        // no observations at this age: selif would return a missing scalar,
        // so report the bin as missing explicitly
        ageeps[i+offset] = miss(0,0);
    else;
        epsX = selif(eps,sel);
        freqX = selif(_freq_,sel);
        ageeps[i+offset] = (epsX'freqX)/sumc(freqX);
    endif;
endfor;

print "Average prediction error by age";
seqa(ageMin,1,ageMax-ageMin+1)~ageeps;


if useNLPars;
    // Summarize the estimated nonlinear age profile f(age) at ages 1..ageChkYrs:
    // location of the peak, modal age (via getAgePeak) and mean citation age.
    thetaNonLin = theta[kParamLinear+1:kParamLinear+kNL1+kNL2];

    // This summary is written for subcatlev = 1: six sub-categories (base
    // category + the columns of subCatMat), labelled by indNames.
    // NOTE(review): subcatlev = 2 (31 sub-categories) is not supported here.
    // NOTE(review): if indNames is a character matrix, labels longer than
    // 8 characters ("Electrical") may be truncated when printed - confirm.
    nSub = 6;
    indNames = "Chemical"|"Computer"|"Drugs"|"Electrical"|"Manufact"|"Other";

    ageCheck = seqa(1,1,ageChkYrs);
    ageCheck1 = ageCheck; ageCheck2 = ageCheck;
    if sepNLPar1 + sepNLPar2;
        // One row per (age, sub-category) pair, age-major. A separate
        // parameter needs one indicator column per sub-category (eye);
        // a pooled parameter needs a single age column (ones).
        if sepNLPar1;
            ageCheck1 = ageCheck.*.eye(nSub);
        else;
            ageCheck1 = ageCheck.*.ones(nSub,1);
        endif;
        if sepNLPar2;
            ageCheck2 = ageCheck.*.eye(nSub);
        else;
            ageCheck2 = ageCheck.*.ones(nSub,1);
        endif;
    endif;

    fFunc = getFfunc(thetaNonLin, ageCheck1, ageCheck2);
    print "ffunc: " rows(ffunc)~cols(ffunc);
    fFuncSt = fFunc;
    if sepNLPar1 + sepNLPar2;
        // Row-wise reshape of the age-major stack: row = age, column = sub-category.
        fFunc = reshape(fFunc, ageChkYrs, nSub);
    endif;

    print "Peak is at:";
    print maxindc(fFunc);

    // Structures used by getAgePeak/getfderiv.
    struct PV ageVal;
    struct DS nlPars;

    if sepNLPar1 + sepNLPar2 == 0;
        agePeak = getAgePeak(thetaNonLin[1], thetaNonLin[2]);
    else;
        agePeak = zeros(nSub,1);
        for i(1,nSub,1);
            // index of sub-category i's parameter in each block (1 if pooled)
            i1 = maxc(1|sepNLPar1*i);
            i2 = maxc(1|sepNLPar2*i);
            // the second block starts right after the kNL1 first-block parameters
            agePeak[i] = getAgePeak(thetaNonLin[i1],thetaNonLin[kNL1+i2]);
        endfor;
    endif;

    print;

    // Mean age of citation: weights proportional to exp(f(age)), normalized
    // within each column.
    proportions = exp(fFunc);
    proportions = proportions./sumc(proportions)';
    meanAge = ageCheck'proportions;

    if sepNLPar1 + sepNLPar2;
        sorter = sortc((indNames~agePeak),2);
        print "Industry        modal age        mean age    Sorted";
        d = printfmt(indNames~agePeak~meanAge'~sorter[.,1], 0~1~1~0);
    else;
        // Pooled profile: one modal and one mean age, no per-industry table
        // (indNames~agePeak would not be conformable here).
        print "Modal age: " agePeak;
        print "Mean age:  " meanAge;
    endif;

    xy(seqa(1,1,ageChkYrs),fFunc);
endif;


// Report total wall-clock run time, then stop writing to the output file.
finTime = date;
elapTime = etstr(ethsec(startTim, finTime));
print;
print "Elapsed time during program run: " elapTime;

output off;


// Per-observation Poisson log-likelihood kernel, ncites*lambda - exp(lambda)
// (the ln(ncites!) constant is omitted), with log-mean
//   lambda = xLinearDum*thetaLin [+ getFfunc(thetaNonLin, age1, age2)].
// Column layout of datasetmat: ncites | xLinearDum | age1 | age2 | _freq_
// Reads the globals kParamLinear, kNL1, kNL2 and useNLpars.
proc likefunc(theta, datasetmat);
    local y, xMat, a1, a2, eta;

    y    = datasetmat[.,1];
    xMat = datasetmat[.,2:kParamLinear+1];
    a1   = datasetmat[.,kParamLinear+2:kParamLinear+kNL1+1];
    a2   = datasetmat[.,kParamLinear+kNL1+2:kParamLinear+kNL1+kNL2+1];

    eta = xMat*theta[1:kParamLinear];
    if useNLpars;
        eta = eta + getFfunc(theta[kParamLinear+1:kParamLinear+kNL1+kNL2], a1, a2);
    endif;

    retp(y.*eta - exp(eta));
endp;


// Nonlinear age profile added to the linear index:
//   f = -age1*exp(theta1) + ln(1 - exp(-age2*exp(theta2)))
// theta holds kNL1 log-parameters for the first term followed by kNL2 for the
// second; age1/age2 must have kNL1/kNL2 columns respectively.
proc getFfunc(theta,age1,age2);
    local rate1, rate2;
    // parameters are carried on the log scale; exponentiate so both are positive
    rate1 = exp(theta[1:kNL1]);
    rate2 = exp(theta[kNL1+1:kNL1+kNL2]);
    //f = exp(age1*-thetaNL1).*ln(1-exp(age2*-thetaNL2));
    retp(age1*(-rate1) + ln(1 - exp(age2*(-rate2))));
endp;

// Map the estimation parameters to reported ones: the linear part is returned
// unchanged, and (when useNLpars is set) the kNL1+kNL2 nonlinear parameters,
// which sit contiguously after the linear block, are exponentiated.
proc getTrueParam(theta);
    local first, last;
    if useNLpars;
        first = kParamLinear + 1;
        last = kParamLinear + kNL1 + kNL2;
        theta[first:last] = exp(theta[first:last]);
    endif;
    retp(theta);
endp;



// ******** Find the modal (peak) age of the age profile f(age) on [1, 15] by
// bisection on df/dage, where the derivative is computed numerically by gradMT
// applied to getfderiv. Bisection assumes df/dage > 0 at age 1 and < 0 at
// age 15; if the sign never changes, the result converges to the
// corresponding end of the search interval.
// thetaNL1, thetaNL2: the two nonlinear parameters on the log scale.
proc  getAgePeak(thetaNL1, thetaNL2);
    local lo, hi, mid, midVal;
    // Local structure instances, so the search neither depends on nor
    // overwrites any top-level ageVal/nlPars.
    struct PV ageVal;
    struct DS nlPars;

    nlPars = dsCreate;
    nlPars.dataMatrix = thetaNL1|thetaNL2;

    lo = 1;
    hi = 15;    // upper search bound (same value as ageChkYrs at the top of the script)

    for i(1,500,1);
        if hi - lo < 1e-3; break; endif;
        mid = (lo+hi)/2;
        ageVal = pvCreate;
        ageVal = pvPack(ageVal, mid, "age");
        midVal = gradMT(&getfderiv, ageVal, nlPars);
        if midVal > 0;
            lo = mid;       // still rising: peak lies to the right
        elseif midVal < 0;
            hi = mid;       // already falling: peak lies to the left
        else;
            hi = mid; lo = mid;
        endif;
    endfor;

    retp((hi+lo)/2);
endp;

// ****** Evaluates the age profile f at a scalar age for getAgePeak, which
// differentiates it with gradMT. Despite the name, this returns f itself,
// not its derivative. nlPars.dataMatrix holds the two log-parameters. *******
proc getfderiv(struct PV ageVal, struct DS nlPars);
    local a, rate1, rate2;
    a = pvUnpack(ageVal, "age");
    rate1 = exp(nlPars.dataMatrix[1]);
    rate2 = exp(nlPars.dataMatrix[2]);
    //f = exp(age*-thetaNL1).*ln(1-exp(age*-thetaNL2));
    retp(a*(-rate1) + ln(1 - exp(a*(-rate2))));
endp;