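% G2SLS estimation of a two-equation simultaneous-equations (SEM) network
% system with network (group) fixed effects. The string variable `model`
% ('M1'-'M5') selects one of the specifications listed below and is expected
% to be set before this script runs. With CE = 1, each equation additionally
% includes the contextual effects W*X*gamma.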

% M1
% y1 = lam11*Wy1 + X1*beta1 + u1
% y2 = lam22*Wy2 + X2*beta2 + u2

% M2
% y1 = phi21*y2 + X1*beta1 + u1
% y2 = phi12*y1 + X2*beta2 + u2

% M3
% y1 = lam11*Wy1 + lam21*Wy2 + X1*beta1 + u1
% y2 = lam22*Wy2 + lam12*Wy1 + X2*beta2 + u2

% M4
% y1 = lam11*Wy1 + phi21*y2 + X1*beta1 + u1
% y2 = lam22*Wy2 + phi12*y1 + X2*beta2 + u2

% M5
% y1 = lam11*Wy1 + lam21*Wy2 + phi21*y2 + X1*beta1 + u1
% y2 = lam22*Wy2 + lam12*Wy1 + phi12*y1 + X2*beta2 + u2

%clear
%model = 'M5';

CE = 1; % contextual effects (CE): 1 = add W*X regressors, 0 = omit them
% Load the estimation data. Variables such as Xn, Y1, Y2, IV1, IV2, Wn and
% gsize are used below without being defined in this script, so they are
% presumably stored in this file (TODO confirm).
% fullfile builds the path with the platform's separator; the former
% 'input\data_input.mat' only resolved on Windows.
load(fullfile('input','data_input.mat'))
%%
% Build the within-network demeaning operator
%   Jn = I - blkdiag(ones(m_1)/m_1, ..., ones(m_ng)/m_ng),
% so Jn*v subtracts each network's mean from v. The blocks are contiguous
% and follow the order of gsize, so observations must be sorted by network.
np = size(Xn,1);
ng = numel(gsize);   % numel, not size(...,1): works for row or column gsize
if sum(gsize) ~= np
    error('G2SLS:gsize', ...
        'sum(gsize) = %d does not match the sample size np = %d.', ...
        sum(gsize), np);
end

nz = sum(gsize(:).^2);   % nonzeros of the block-diagonal averaging matrix
Ji = zeros(nz,1);
Jj = zeros(nz,1);
Jv = zeros(nz,1);

n1 = 0;   % triplets written so far
n2 = 0;   % observations (rows) covered so far
for i = 1:ng
    mr = gsize(i);
    idx = n1+1:n1+mr^2;
    % all (row, col) pairs of the mr-by-mr block, shifted to its diagonal position
    Ji(idx) = kron(ones(mr,1),(1:mr)')+n2;
    Jj(idx) = kron((1:mr)',ones(mr,1))+n2;
    Jv(idx) = 1/mr;
    n1 = n1+mr^2;
    n2 = n2+mr;
end
Jn = speye(np)-sparse(Ji,Jj,Jv,np,np);
%%
% Right-hand-side endogenous variables of each equation, demeaned within
% networks:
%   spatial lags W*y (M1, M3), the other equation's outcome (M2), or both (M4, M5).
% Any model name other than M1-M4 falls through to the M5 specification.
if strcmp(model,'M1')
    Yn1 = Jn*Wn*Y1;
    Yn2 = Jn*Wn*Y2;
elseif strcmp(model,'M2')
    Yn1 = Jn*Y2;
    Yn2 = Jn*Y1;
elseif strcmp(model,'M3')
    Yn1 = Jn*[Wn*Y1,Wn*Y2];
    Yn2 = Jn*[Wn*Y2,Wn*Y1];
elseif strcmp(model,'M4')
    Yn1 = Jn*[Wn*Y1,Y2];
    Yn2 = Jn*[Wn*Y2,Y1];
else
    Yn1 = Jn*[Wn*Y1,Wn*Y2,Y2];
    Yn2 = Jn*[Wn*Y2,Wn*Y1,Y1];
end

% Exogenous regressors. In M1, M3 and M4 both equations use the full set
% [Xn,IV1,IV2]. In the other models each equation keeps only its own IV
% block; the other block serves as an excluded instrument further down.
if any(strcmp(model,{'M1','M3','M4'}))
    X1 = [Xn,IV1,IV2];
    X2 = X1;
else
    X1 = [Xn,IV1];
    X2 = [Xn,IV2];
end
q1 = size(X1,2);   % number of (non-contextual) exogenous regressors, eq. 1
q2 = size(X2,2);   % same for eq. 2

% Demean within networks, optionally appending the contextual terms W*X.
if CE == 0
    Xn1 = Jn*X1;
    Xn2 = Jn*X2;
else
    Xn1 = Jn*[X1,Wn*X1];
    Xn2 = Jn*[X2,Wn*X2];
end

% Excluded instruments for each equation, demeaned within networks.
% XX1/XX2 are not defined in this script; presumably they come from the
% loaded data file (TODO confirm). They are not needed for M2.
if ~strcmp(model,'M2')
    WX1 = Wn*XX1;
    WX2 = Wn*XX2;
end

if any(strcmp(model,{'M1','M3','M4'}))
    % Higher-order spatial lags of XX serve as instruments.
    if CE == 0
        Zn1 = Jn*WX1;
        Zn2 = Jn*WX2;
    else
        Zn1 = Jn*(Wn*WX1);
        Zn2 = Jn*(Wn*WX2);
    end
elseif strcmp(model,'M2')
    % y2 in eq. 1 is instrumented by IV2 (and W*IV2), y1 in eq. 2 by IV1.
    if CE == 0
        Zn1 = Jn*IV2;
        Zn2 = Jn*IV1;
    else
        Zn1 = Jn*[IV2,Wn*IV2];
        Zn2 = Jn*[IV1,Wn*IV1];
    end
else
    % M5 (and any other name): cross-equation IVs plus spatial lags.
    if CE == 0
        WnX = Wn*[IV1,IV2];
        Zn1 = Jn*[IV2,WnX];
        Zn2 = Jn*[IV1,WnX];
    else
        Zn1 = Jn*[IV2,Wn*IV2,Wn*WX1];
        Zn2 = Jn*[IV1,Wn*IV1,Wn*WX2];
    end
end

% Equation 1: regressors Z1 = [endogenous, exogenous], full instrument set Q1
% and its projection matrix P1 = Q1*(Q1'Q1)^(-1)*Q1'.
Z1 = [Yn1,Xn1];
Q1 = [Xn1,Zn1];
P1 = Q1*((Q1'*Q1)\Q1');

% Equation 2: same construction.
Z2 = [Yn2,Xn2];
Q2 = [Xn2,Zn2];
P2 = Q2*((Q2'*Q2)\Q2');

% Dependent variables, demeaned within networks (the raw Y1/Y2 are replaced).
Y1 = Jn*Y1;
Y2 = Jn*Y2;

%{
temp = [Y1,Yn1,Xn1,Zn1];
name = cell(1,size(temp,2));
name{1} = 'y';
for j = 1:size(Yn1,2)
    name{j+1} = ['w' num2str(j)];
end
for j = 1:size(Xn1,2)
    name{j+1+size(Yn1,2)} = ['x' num2str(j)];
end
for j = 1:size(Zn1,2)
    name{j+1+size([Yn1,Xn1],2)} = ['z' num2str(j)];
end

data1 = [name;num2cell(temp)];
delete('output\stata_data1.xlsx');
xlswrite('output\stata_data1.xlsx',data1);
%}
%% 2SLS of equation 1
% 2SLS coefficients b1 = (Z1'P1Z1)^(-1) Z1'P1 Y1. The factor ZZ is kept
% because it is also the "bread" of the sandwich covariance below.
ZZ = (Z1'*P1*Z1)\(Z1'*P1);
b1 = ZZ*Y1;

% Heteroskedasticity-robust covariance ZZ*diag(e1.^2)*ZZ' and its s.e.
e1 = Y1-Z1*b1;
VV = spdiags(e1.^2,0,np,np);
V1 = ZZ*VV*ZZ';
s1 = sqrt(spdiags(V1,0));

% Over-identification (OIR) test: robust Hansen-type statistic
% (Q1'e1)'(Q1'*VV*Q1)^(-1)(Q1'e1), chi-square with
% (#instruments - #regressors) degrees of freedom.
Qe = Q1'*e1;
QQ = Q1'*VV*Q1;
p1 = 1-chi2cdf(Qe'*(QQ\Qe),size(Q1,2)-size(Z1,2));

% First-stage strength (Stock and Yogo 2005): minimum eigenvalue of the
% standardized concentration matrix. Vn is the first-stage residual
% covariance (d.o.f. net of the ng fixed effects). Mn partials out the
% included exogenous regressors.
Vn = Yn1'*(speye(np)-P1)*Yn1/(np-ng-size(Q1,2));
Mn = speye(np)-Xn1*((Xn1'*Xn1)\Xn1');
YM = Mn*Yn1;
ZM = Mn*Zn1;
PZ = ZM*((ZM'*ZM)\ZM');
YV = YM/sqrtm(Vn);
SY = full(YV'*PZ*YV);
f1 = min(eig(SY/size(Zn1,2)));
%% 2SLS of equation 2
% 2SLS coefficients b2 = (Z2'P2Z2)^(-1) Z2'P2 Y2. ZZ is reused for the
% sandwich covariance.
ZZ = (Z2'*P2*Z2)\(Z2'*P2);
b2 = ZZ*Y2;

% Heteroskedasticity-robust covariance ZZ*diag(e2.^2)*ZZ' and its s.e.
e2 = Y2-Z2*b2;
VV = spdiags(e2.^2,0,np,np);
V2 = ZZ*VV*ZZ';
s2 = sqrt(spdiags(V2,0));

% Robust over-identification (OIR) test, chi-square with
% (#instruments - #regressors) degrees of freedom.
Qe = Q2'*e2;
QQ = Q2'*VV*Q2;
p2 = 1-chi2cdf(Qe'*(QQ\Qe),size(Q2,2)-size(Z2,2));

% First-stage minimum-eigenvalue statistic (Stock and Yogo 2005), after
% partialling out the included exogenous regressors Xn2.
Vn = Yn2'*(speye(np)-P2)*Yn2/(np-ng-size(Q2,2));
Mn = speye(np)-Xn2*((Xn2'*Xn2)\Xn2');
YM = Mn*Yn2;
ZM = Mn*Zn2;
PZ = ZM*((ZM'*ZM)\ZM');
YV = YM/sqrtm(Vn);
SY = full(YV'*PZ*YV);
f2 = min(eig(SY/size(Zn2,2)));
%%
% Append a plain-text report for this run to output/<model>_output.txt.
% fullfile keeps the path portable; the former 'output\' only worked on Windows.
outfile = fullfile('output',[model '_output.txt']);
fid = fopen(outfile,'a');
if fid < 0
    error('G2SLS:fopen','Cannot open %s for appending.',outfile);
end
fprintf(fid,'%s \n',date);

% Significance stars for |t| using two-sided normal critical values:
% * 10% (1.645), ** 5% (1.96), *** 1% (2.576). The former 2.326 cutoff is
% the one-sided 1% (two-sided 2%) value, which is inconsistent with the
% other two tiers.
sigstars = @(t) repmat('*',1,(t>=1.645)+(t>=1.96)+(t>=2.576));
% Print ' estimate<stars>' and ' (s.e.)' on two lines. The stars are spliced
% into the format string rather than passed through %s, because fprintf
% silently skips empty arguments.
printcoef = @(b,s) fprintf(fid,[' %7.4f' sigstars(abs(b/s)) '\n (%6.4f)\n'],b,s);

if CE == 0
    fprintf(fid,'without contextual effect\n');
else
    fprintf(fid,'with contextual effect\n');
end
fprintf(fid,'sample size = %4.0f\n',np);
fprintf(fid,'# of networks = %4.0f\n',ng);
fprintf(fid,'gsize min  = %5.0f, max = %5.0f\n',min(gsize),max(gsize));
fprintf(fid,'symm = %4.0f\n',symm);
fprintf(fid,'lnTV = %4.0f\n',lnTV);

% NOTE(review): only the first q1 (resp. q2) exogenous coefficients are
% printed, so the contextual (W*X) coefficients added when CE = 1 are not
% reported. Confirm this is intended.

% ---- equation 1 ----
fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
fprintf(fid,'2SLS for equation 1 \n');
fprintf(fid,'OIR test p-value is %7.3f\n',p1);
fprintf(fid,'First stage F test statistic is %7.3f\n',f1);
for ip = 1:length(X1name)
    fprintf(fid,'%s, ',X1name{ip});
end
fprintf(fid,'\n~~~~~~~~~~~~~~~~~~~\n');
ny1 = size(Yn1,2);
for ip = 1:ny1                % endogenous (lambda/phi) coefficients
    printcoef(b1(ip),s1(ip));
end
fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
for ip = ny1+1:ny1+q1         % exogenous (beta) coefficients
    %{
    switch model
        case {'M1','M3','M4'}
            fprintf(fid,'%s \n',names{ip-size(Yn1,2)});
        otherwise
            fprintf(fid,'%s \n',name1{ip-size(Yn1,2)});
    end
    %}
    printcoef(b1(ip),s1(ip));
end

% ---- equation 2 ----
fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
fprintf(fid,'2SLS for equation 2 \n');
fprintf(fid,'OIR test p-value is %7.3f\n',p2);
fprintf(fid,'First stage F test statistic is %7.3f\n',f2);
for ip = 1:length(X2name)
    fprintf(fid,'%s, ',X2name{ip});
end
fprintf(fid,'\n~~~~~~~~~~~~~~~~~~~\n');
ny2 = size(Yn2,2);
for ip = 1:ny2
    printcoef(b2(ip),s2(ip));
end
fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
for ip = ny2+1:ny2+q2
    %{
    switch model
        case {'M1','M3','M4'}
            fprintf(fid,'%s \n',names{ip-size(Yn2,2)});
        otherwise
            fprintf(fid,'%s \n',name2{ip-size(Yn2,2)});
    end
    %}
    printcoef(b2(ip),s2(ip));
end
% Direct and indirect marginal effects of each regressor in [Xn,IV1,IV2]
% for model M5, with delta-method standard errors, then close the report.
kq = size([Xn,IV1,IV2],2);
kx = size(Xn,2);
k1 = size(IV1,2);
k2 = size(IV2,2);
if strcmp(model,'M5')
    % --- Common parameter layout ------------------------------------------
    % Both equations' estimates are scattered into
    %   [lam_own; lam_cross; phi; beta(Xn,IV1,IV2); gamma(Xn,IV1,IV2)]
    % (length 3+2*kq). Regressors absent from an equation get zeros: IV2 in
    % eq. 1, IV1 in eq. 2, and every gamma when CE == 0. The covariance
    % matrices are scattered with the same index map, so all estimated
    % cross-covariances (e.g. lambda-gamma, beta_h-gamma_h) are kept. A
    % block-diagonal reshuffle would zero them and bias the delta-method s.e.
    nb = 3+2*kq;
    pos1 = [1:3+kx+k1, 3+kq+(1:kx+k1)];
    pos2 = [1:3+kx, 3+kx+k1+(1:k2), 3+kq+(1:kx), 3+kq+kx+k1+(1:k2)];
    if CE == 0
        % no contextual (W*X) coefficients were estimated
        pos1 = pos1(1:3+kx+k1);
        pos2 = pos2(1:3+kx+k2);
    end
    if numel(b1) ~= numel(pos1) || numel(b2) ~= numel(pos2)
        error('G2SLS:layout','Unexpected number of M5 coefficients.');
    end
    bfull = zeros(nb,1);
    bfull(pos1) = b1;
    b1 = bfull;
    bfull = zeros(nb,1);
    bfull(pos2) = b2;
    b2 = bfull;
    Vfull = zeros(nb);
    Vfull(pos1,pos1) = V1;
    V1 = Vfull;
    Vfull = zeros(nb);
    Vfull(pos2,pos2) = V2;
    V2 = Vfull;

    lam11 = b1(1);
    lam21 = b1(2);
    phi21 = b1(3);
    beta1 = b1(4:3+kq);
    gama1 = b1(kq+4:end);

    lam22 = b2(1);
    lam12 = b2(2);
    phi12 = b2(3);
    beta2 = b2(4:3+kq);
    gama2 = b2(kq+4:end);

    % --- Reduced form -----------------------------------------------------
    % S0 = D^(-1), where D is the "determinant" of the 2x2 operator system.
    % Slam*/Sphi* start as dD/dtheta and become dS0/dtheta = -S0*dD*S0.
    In = speye(np);
    W2 = Wn*Wn;
    S0 = (1-phi12*phi21)*In-(lam11+lam22+phi21*lam12+phi12*lam21)*Wn+(lam11*lam22-lam12*lam21)*W2;
    S0 = S0\In;
    Slam11 = -Wn+lam22*W2;
    Slam22 = -Wn+lam11*W2;
    Slam12 = -phi21*Wn-lam21*W2;
    Slam21 = -phi12*Wn-lam12*W2;
    Sphi12 = -phi21*In-lam21*Wn;
    Sphi21 = -phi12*In-lam12*Wn;

    Slam11 = -S0*Slam11*S0;
    Slam22 = -S0*Slam22*S0;
    Slam12 = -S0*Slam12*S0;
    Slam21 = -S0*Slam21*S0;
    Sphi12 = -S0*Sphi12*S0;
    Sphi21 = -S0*Sphi21*S0;

    % Derivatives of the effect matrices w.r.t. beta_h / gamma_h do not
    % depend on h, so they are formed once outside the loop.
    M1_beta1 = S0*(In-lam22*Wn);
    M1_gama1 = S0*(Wn-lam22*W2);
    M1_beta2 = S0*(phi21*In+lam21*Wn);
    M1_gama2 = S0*(phi21*Wn+lam21*W2);
    M2_beta1 = S0*(phi12*In+lam12*Wn);
    M2_gama1 = S0*(phi12*Wn+lam12*W2);
    M2_beta2 = S0*(In-lam11*Wn);
    M2_gama2 = S0*(Wn-lam11*W2);

    m1_d = zeros(kq,1);
    m1_i = zeros(kq,1);
    m2_d = zeros(kq,1);
    m2_i = zeros(kq,1);

    s1_d = zeros(kq,1);
    s1_i = zeros(kq,1);
    s2_d = zeros(kq,1);
    s2_i = zeros(kq,1);
    for h = 1:kq
        % covariance of (lam_own, lam_cross, phi, beta_h, gamma_h)
        sel = [1:3, 3+h, 3+kq+h];
        V1h = V1(sel,sel);
        V2h = V2(sel,sel);

        % Effect of regressor h on y1 (ME1) and y2 (ME2): S0*F*h.
        % Renamed from f1/f2 and M1/M2 so the first-stage statistics f1/f2
        % printed above are not overwritten in the workspace.
        F1h = (phi21*beta2(h)+beta1(h))*In+(lam21*beta2(h)-lam22*beta1(h)+phi21*gama2(h)+gama1(h))*Wn+(lam21*gama2(h)-lam22*gama1(h))*W2;
        F2h = (phi12*beta1(h)+beta2(h))*In+(lam12*beta1(h)-lam11*beta2(h)+phi12*gama1(h)+gama2(h))*Wn+(lam12*gama1(h)-lam11*gama2(h))*W2;

        ME1 = S0*F1h;
        ME2 = S0*F2h;

        % direct = average diagonal; indirect = average row sum minus direct
        m1_d(h) = trace(ME1)/np;
        m1_i(h) = sum(sum(ME1))/np - trace(ME1)/np;

        m2_d(h) = trace(ME2)/np;
        m2_i(h) = sum(sum(ME2))/np - trace(ME2)/np;

        % Gradients of ME1 w.r.t. eq.-1 and eq.-2 parameters (same order as V1h/V2h)
        M1_lam11 = Slam11*F1h;
        M1_lam21 = Slam21*F1h + S0*(beta2(h)*Wn+gama2(h)*W2);
        M1_phi21 = Sphi21*F1h + S0*(beta2(h)*In+gama2(h)*Wn);

        M1_lam22 = Slam22*F1h - S0*(beta1(h)*Wn+gama1(h)*W2);
        M1_lam12 = Slam12*F1h;
        M1_phi12 = Sphi12*F1h;

        M1b1t = [trace(M1_lam11);trace(M1_lam21);trace(M1_phi21);trace(M1_beta1);trace(M1_gama1)]/np;
        M1b2t = [trace(M1_lam22);trace(M1_lam12);trace(M1_phi12);trace(M1_beta2);trace(M1_gama2)]/np;

        % NOTE(review): the two equations' estimates are treated as
        % uncorrelated (no cross-equation covariance term). Confirm.
        s1_d(h) = sqrt(M1b1t'*V1h*M1b1t + M1b2t'*V2h*M1b2t);

        M1b1s = [sum(sum(M1_lam11));sum(sum(M1_lam21));sum(sum(M1_phi21));sum(sum(M1_beta1));sum(sum(M1_gama1))]/np - M1b1t;
        M1b2s = [sum(sum(M1_lam22));sum(sum(M1_lam12));sum(sum(M1_phi12));sum(sum(M1_beta2));sum(sum(M1_gama2))]/np - M1b2t;

        s1_i(h) = sqrt(M1b1s'*V1h*M1b1s + M1b2s'*V2h*M1b2s);

        % Gradients of ME2
        M2_lam11 = Slam11*F2h - S0*(beta2(h)*Wn+gama2(h)*W2);
        M2_lam21 = Slam21*F2h;
        M2_phi21 = Sphi21*F2h;

        M2_lam22 = Slam22*F2h;
        M2_lam12 = Slam12*F2h + S0*(beta1(h)*Wn+gama1(h)*W2);
        M2_phi12 = Sphi12*F2h + S0*(beta1(h)*In+gama1(h)*Wn);

        M2b1t = [trace(M2_lam11);trace(M2_lam21);trace(M2_phi21);trace(M2_beta1);trace(M2_gama1)]/np;
        M2b2t = [trace(M2_lam22);trace(M2_lam12);trace(M2_phi12);trace(M2_beta2);trace(M2_gama2)]/np;

        s2_d(h) = sqrt(M2b1t'*V1h*M2b1t + M2b2t'*V2h*M2b2t);

        M2b1s = [sum(sum(M2_lam11));sum(sum(M2_lam21));sum(sum(M2_phi21));sum(sum(M2_beta1));sum(sum(M2_gama1))]/np - M2b1t;
        M2b2s = [sum(sum(M2_lam22));sum(sum(M2_lam12));sum(sum(M2_phi12));sum(sum(M2_beta2));sum(sum(M2_gama2))]/np - M2b2t;

        s2_i(h) = sqrt(M2b1s'*V1h*M2b1s + M2b2s'*V2h*M2b2s);
    end

    % --- Report -------------------------------------------------------------
    % Stars use two-sided normal cutoffs: * 10% (1.645), ** 5% (1.96),
    % *** 1% (2.576). They are spliced into the format because fprintf skips
    % empty arguments.
    sigstars = @(t) repmat('*',1,(t>=1.645)+(t>=1.96)+(t>=2.576));
    printcoef = @(b,s) fprintf(fid,[' %7.4f' sigstars(abs(b/s)) '\n (%6.4f)\n'],b,s);

    fprintf(fid,'*******************\n');
    fprintf(fid,'marginal effects for equation 1 \n');
    fprintf(fid,'direct effects \n');
    for h = 1:kq
        printcoef(m1_d(h),s1_d(h));
    end
    fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
    fprintf(fid,'indirect effects \n');
    for h = 1:kq
        printcoef(m1_i(h),s1_i(h));
    end
    fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
    fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
    fprintf(fid,'marginal effects for equation 2 \n');
    fprintf(fid,'direct effects \n');
    for h = 1:kq
        printcoef(m2_d(h),s2_d(h));
    end
    fprintf(fid,'~~~~~~~~~~~~~~~~~~~\n');
    fprintf(fid,'indirect effects \n');
    for h = 1:kq
        printcoef(m2_i(h),s2_i(h));
    end
end
fprintf(fid,'*******************\n\n');
fclose(fid);