% Script setup: clear the workspace, start the wall-clock timer, make sure the
% dated output folder exists, and optionally switch on the profiler.
clearvars
close all
date_string = '20210529';   % tag of the output sub-folder
t0 = tic();
% Build the output folder portably (hard-coded '\' separators only work on
% Windows). A trailing separator is kept because the result file name is later
% appended directly to fpath.
fpath = [fullfile('.', 'results_theory', date_string), filesep];
if ~exist(fpath, 'dir')
    mkdir(fpath)
end
profiler = 0;   % 1 = profile this run and open the profile viewer at the end
if profiler==1
    profile on;
end

% ============ INITIALIZATION ==============
N = 10000;      % population size
T = 800;        % number of simulated days
m = round(0.001*N);   % number of initially affected individuals (integer: used in randperm(N,m))
gamma = 1/14;   % daily recovery probability
k0 = 10;  % average degree (number of connections)
R0 = 3;   % reproduction number used to calibrate the transmission rate tau
kmin = 5;   % smallest degree of the power-law network
kmax = 50;  % largest degree of the power-law network
flag.exact = 0; % 1=use exact formula to compute tau for the estimation exercise and
                % comparison of N and network topologies; 0=use approximate formula
flag.vacc = 0; % 0=no vaccination; 1=random vaccination; 2=older people first
flag.dist = 0; % 0=no distancing; 1=distancing
choice_network = 3; % 1=ER; 2=power law; 3=multigroup
if choice_network==3
    L = 5; % number of groups
end
nsimul = 1000;  % number of Monte Carlo simulations

switch choice_network
    case 1
        network_string = 'ER';
    case 2
        network_string = 'PL';
    case 3
        network_string = ['multigroup_L', num2str(L)];
    otherwise
        error('choice_network must be 1 (ER), 2 (power law) or 3 (multigroup).')
end
% "Older people first" vaccination needs the group structure (pop_sim,
% csum_sizes), which is only set up for the multigroup network. It therefore
% cannot be combined with the ER (1) or power-law (2) networks.
if choice_network~=3 && flag.vacc==2
    error('Network and vaccination scheme does not match!')
end

% Per-contact transmission rate for the homogeneous (one-group) case.
% The default is the first-order value R0*gamma/k0. With flag.exact==1 it is
% replaced by tau solving 1 - exp(-tau) = R0*gamma/k0*(N-1)/N.
tau_onegroup = R0*gamma/k0;
if flag.exact==1
    tau_onegroup = -log(1 - R0*gamma/k0*(N-1)/N);
end
%============================================
% Network-specific parameters: the transmission rate tau used in the dynamics
% plus whatever the daily graph generator needs for the chosen topology.
switch choice_network
    case 1   % Erdos-Renyi: edge probability giving mean degree k0
        p = k0/(N-1);
        tau = tau_onegroup;
    case 2   % power law: exponent whose truncated (kmin..kmax) mean degree equals k0
        options = optimset('Display','off');
        kgrid = kmin:kmax;
        fun = @(a)   (1/sum(kgrid.^(-a))) * sum(kgrid.^(1-a)) - k0;
        exponent = fsolve(fun,1.5,options);  % exponent of PDF
        tau = tau_onegroup;
    case 3   % multigroup: contact structure loaded from disk
        % Portable path (the former '.\data_temp\' literal only worked on Windows).
        fname_contact = fullfile('.', 'data_temp', ['Germany_contact_',num2str(L),'_groups']);
        % Loads (at least) tau_ratio, Pmtx, w and pop_sim into the workspace.
        load(fname_contact)
        % =================================================================
        % The individuals are laid out as consecutive blocks of sizes pop_sim, so
        % the blocks must cover exactly the N simulated individuals. Otherwise
        % tauN below silently grows past N or leaves individuals with zero rate.
        if sum(pop_sim) ~= N
            error('Group sizes in %s sum to %d, but N = %d.', fname_contact, sum(pop_sim), N);
        end
        % Scale the relative group rates so the implied reproduction number is R0
        % (NOTE(review): assumes w'*Q1*w is the corresponding normalizer -- confirm).
        Q1 = tau_ratio.*Pmtx;
        tau_basis = R0*gamma/N/(w'*Q1*w);
        tauL = tau_basis.*tau_ratio;
        % First/last individual index of each group block.
        csum_sizes = cumsum(pop_sim);
        idx_blk_beg = csum_sizes - pop_sim + 1;
        idx_blk_end = idx_blk_beg + pop_sim - 1;
        % Expand the group rates to one rate per individual.
        tauN = zeros(N,1);
        for Li=1:L
            tauN(idx_blk_beg(Li):idx_blk_end(Li)) = tauL(Li);
        end
        tau = tauN;

        % Per-group cumulative / recovered / infected counts: group x day x simulation.
        ClT = zeros(L,T,nsimul);
        RlT = zeros(L,T,nsimul);
        IlT = zeros(L,T,nsimul);
end

%% social distancing
% Daily mean-degree path kT (at least T entries) and the implied reproduction
% number TR. The phases are:
%   1. k0 for weeks_R0 weeks;
%   2. a linear fall to the target level;
%   3. a plateau at the target;
%   4. a linear rise to the relaxed level;
%   5. the relaxed level until the end.
if flag.dist~=0
    %-----set parameters---------
    TR_target = 0.9;        % reproduction number aimed for while distancing
    TR_relax = 1.5;         % reproduction number after relaxation
    weeks_R0 = 2;           % weeks before distancing starts
    weeks_dist_fall = 3;    % weeks to ramp down
    weeks_dist_keep = 8;    % weeks held at the target
    weeks_dist_rise = 3;    % weeks to ramp up
    %----------------------------
    k_target = TR_target/R0 * k0;
    k_relax = TR_relax/R0 * k0;
    k_fall = linspace(k0, k_target, 7*weeks_dist_fall).';
    k_rise = linspace(k_target, k_relax, 7*weeks_dist_rise).';
    k_vec = [repmat(k0, 7*weeks_R0, 1); ...
             k_fall; ...
             repmat(k_target, 7*weeks_dist_keep, 1); ...
             k_rise];
    kT = [k_vec; repmat(k_relax, T-length(k_vec), 1)];
    TR = kT./k0.*R0;
end
%% vaccination
% Vaccination campaign. Starting on day tbeg_vacc, N_vacc_per_day more people
% per day (taken in order from a per-simulation list) get their shock
% multiplied by mu_vacc. On day tend_vacc all N_vacc people are covered.
if flag.vacc~=0
    %-----set parameters---------
    eff_vacc = 0.95;        % vaccine efficacy
    frac_vacc = 0.75;       % share of the population to vaccinate
    weeks_vacc = 12;        % campaign length in weeks
    if flag.dist==0 % without distancing
        wbeg_vacc = 4;
    else % with distancing: the week after the first half of the distancing plateau
        wbeg_vacc = weeks_R0 + weeks_dist_fall + weeks_dist_keep/2 + 0*weeks_dist_rise + 1;
    end
    tbeg_vacc = (wbeg_vacc-1)*7 + 1;
    days_vacc = weeks_vacc*7;
    % Must be an integer: it is used as an array size and as a randperm count.
    N_vacc = round(frac_vacc*N);
    N_vacc_per_day = floor(N_vacc/days_vacc);
    tend_vacc = tbeg_vacc + days_vacc -1;
    mu_vacc = 1/(1-eff_vacc);
    N_vacc_record = zeros(nsimul,N_vacc);
    %----------------------------
    if flag.vacc==2  % vaccinate older people first
        % Groups are filled from the last one backwards (assumes groups are
        % ordered by age, youngest first -- confirm against the contact data).
        reverse_csum_sizes = cumsum(pop_sim, 'reverse');
        % Number of groups touched: all fully covered groups plus one partially
        % covered group. This is capped at L for the case where everyone is
        % vaccinated.
        num_groups_vacc = min(L, sum(reverse_csum_sizes<=N_vacc)+1);
        % People in the num_groups_vacc-1 fully vaccinated groups. The appended 0
        % covers the case with no fully vaccinated group, where the old code
        % indexed reverse_csum_sizes(L+1).
        covered_full = [reverse_csum_sizes(:); 0];
        num_last_group_vacc = N_vacc - covered_full(L - num_groups_vacc + 2); % number of people to be vaccinated in the last group
    end
end

%% ============================================
% Aggregate paths per simulation (rows = simulations, columns = days):
% CT = cumulative cases, IT = currently infected, RT = recovered.
% Day 1 holds the m initial infections.
CT = zeros(nsimul,T);
CT(:,1) = m;
IT = CT;                     % only day 1 is non-zero so far
RT = zeros(nsimul,T);

Tend = zeros(nsimul,1);      % day the epidemic died out (stays 0 if it never did)
days_recovery = zeros(N, nsimul);

% Monte Carlo loop. Each iteration simulates one epidemic path of up to T days
% and draws a fresh contact network every day.
%
% Network seeds: the former formula (21000+20*t)*iter is not one-to-one. For
% example, (iter=3,t=2) and (iter=2,t=528) both give 63120, so different
% simulations silently reused identical graphs.
%   - seed_net_base + 2*((iter-1)*T + t) is unique for every (iter,t).
%   - The odd "+1" slot is a separate stream for the power-law graph wiring.
%   - The base lies far above the 20*iter-1 .. 20*iter+1 seeds given to rng()
%     below, so the two seed families never coincide.
seed_net_base = 1e6;
for iter = 1:nsimul
    if mod(iter,50)==0
        disp(iter)
    end
    % ========== initialization for each iter ============
    % xNT(i,t)=1 once individual i has been infected by day t; start with m
    % randomly chosen infections on day 1.
    xNT = zeros(N,T);
    rng(20*iter-1);
    x1_idx = randperm(N,m);
    xNT(x1_idx,1) = 1;

    yNT = zeros(N,T);       % yNT(i,t)=1 once individual i has recovered by day t
    xstarNT = zeros(N,T);   % latent index: a never-infected individual is infected on day t when it is > 0
    if choice_network==3
        % Day-1 case counts per group.
        for Li = 1:L
            x1_idx_Li = intersect(x1_idx,  idx_blk_beg(Li):idx_blk_end(Li));
            ClT(Li,1,iter) = length(x1_idx_Li);
            IlT(Li,1,iter) = ClT(Li,1,iter);
        end
    end

    % For each iter, draw the vaccination order before the dynamic process starts.
    switch flag.vacc
        case 1 % random vaccination
            idx_vacc = randperm(N, N_vacc)';
            idx_vacc_t = [];
        case 2 % vaccinate older people first
            % First all members of the num_groups_vacc-1 highest-index groups (in
            % random order), then num_last_group_vacc random members of the next
            % group down. A group's index offset is idx_blk_beg-1. This equals
            % csum_sizes(g-1) but is also valid for the first group.
            idx_vacc = [];
            for groupj_vacc=1:num_groups_vacc-1
                g_vacc = L-groupj_vacc+1;
                idx_vacc_temp = (idx_blk_beg(g_vacc)-1) + randperm(pop_sim(g_vacc))';
                idx_vacc = [idx_vacc; idx_vacc_temp];
            end
            g_vacc = L-num_groups_vacc+1;
            idx_vacc_temp = (idx_blk_beg(g_vacc)-1) + randperm(pop_sim(g_vacc), num_last_group_vacc)';
            idx_vacc = [idx_vacc; idx_vacc_temp];
            idx_vacc_t = [];
    end

    history = zeros(N,T);   % days since infection, counting the infection day (0 = never infected)
    history(xNT(:,1)==1,1) = 1;
    days_recovery_mtx = zeros(N,T);  % non-zero only on the recovery day: the illness duration
    muN = ones(N,1);        % multiplier on each individual's shock (mu_vacc once vaccinated)

    % =========== random draws ================
    % Unit-exponential shocks, rescaled so that each day's column sums to N.
    rng(20*iter);
    kxiNT = -log(1-rand(N,T));
    kxiNT = kxiNT./sum(kxiNT).*N;

    % Daily recovery events, each with probability gamma.
    rng(20*iter+1);
    recover_draw = rand(N,T)<gamma;

    % ========= start the dynamic process ============
    for t=2:T
        idx_ill = history(:,t-1)>0;          % infected at some point up to day t-1
        idx_healthy = ones(N,1) - idx_ill;    % never infected so far
        % Infected, not-yet-recovered individuals recover with probability gamma.
        yNT(:,t) = yNT(:,t-1) + (1-yNT(:,t-1)).*xNT(:,t-1).*recover_draw(:,t);

        seed_net = seed_net_base + 2*((iter-1)*T + t);
        switch choice_network
            case 1  % ER random graph
                if flag.dist==0
                    p_t = p;
                else
                    p_t = kT(t)/k0*p;
                end
                seed_D = seed_net;
                [D,nEdge] = fn_gen_ER_graph(N, p_t, seed_D);
            case 2  % power law
                seed_PowLaw_kseq = seed_net;
                seed_PowLaw_graph = seed_net + 1;
                [kseq,kmean_target] = fn_gen_PowLaw(N,exponent,kmin,kmax,seed_PowLaw_kseq);
                [G,D,count_droplink] = fn_gen_Config_graph(kseq,seed_PowLaw_graph);
            case 3 % multigroup
                if flag.dist==0
                    Pmtx_t = Pmtx;
                else
                    Pmtx_t = kT(t)/k0.*Pmtx;
                end
                seed_SBM = seed_net;
                D = fn_gen_SBM(pop_sim, Pmtx_t, seed_SBM);
        end

        if flag.vacc~=0
            % Grow the vaccinated set by N_vacc_per_day people per day during the
            % campaign. On its last day everyone in idx_vacc is covered.
            if t>=tbeg_vacc && t<tend_vacc
                idx_vacc_t = idx_vacc(1:N_vacc_per_day*(t-tbeg_vacc+1));
            elseif t==tend_vacc
                idx_vacc_t = idx_vacc;
            end
            muN(idx_vacc_t) = mu_vacc;
        end
        % Pressure from contacts who were infectious (infected, not recovered) on
        % day t-1, net of the individual's scaled shock.
        xstarNT(:,t) = tau.* (D * (xNT(:,t-1).*(1-yNT(:,t-1)))) - muN.*kxiNT(:,t);

        idx_infected = double(xstarNT(:,t)>0);
        idx_newill = idx_healthy.*idx_infected;
        xNT(:,t) = idx_newill + idx_ill;
        history(:,t) = history(:,t-1) + idx_ill + idx_newill;
        days_recovery_mtx(:,t) = history(:,t-1).*yNT(:,t).*(1-yNT(:,t-1));

        CT(iter,t) = sum(xNT(:,t));
        IT(iter,t) = sum(xNT(:,t).*(ones(N,1)-yNT(:,t)));
        RT(iter,t) = CT(iter,t) - IT(iter,t);

        if choice_network==3  % multigroup
            for Li = 1:L
                ClT(Li,t,iter) = sum( xNT(idx_blk_beg(Li):idx_blk_end(Li) ,t) );
                IlT(Li,t,iter) = sum(xNT(idx_blk_beg(Li):idx_blk_end(Li),t) .* ...
                    (ones(pop_sim(Li),1)-yNT(idx_blk_beg(Li):idx_blk_end(Li),t)));
                RlT(Li,t,iter) =  ClT(Li,t,iter) - IlT(Li,t,iter);
            end
        end

        if IT(iter,t)==0  % epidemic over: carry final counts forward, skip to next iteration
            Tend(iter) = t;
            CT(iter,t+1:T) = CT(iter,t);
            RT(iter,t+1:T) = RT(iter,t);
            if choice_network==3
                ClT(:,t+1:T,iter) = repmat(ClT(:,t,iter),1,T-t,1);
                RlT(:,t+1:T,iter) = repmat(RlT(:,t,iter),1,T-t,1);
            end
            break
        end
    end    % end t loop
    days_recovery(:,iter) = sum(days_recovery_mtx,2);

    if flag.vacc~=0
        N_vacc_record(iter,:) = idx_vacc';
    end

end  % end iter loop

% Population shares and their summaries across simulations (dimension 1).
% The reduction dimension is always explicit. With nsimul==1 a bare
% mean/median of the 1-by-(T-1) dcT would collapse it to a scalar instead of
% a per-day path.
cT = CT./N; rT = RT./N; iT = IT./N;
cT_avg = mean(cT,1); iT_avg = mean(iT,1); rT_avg = mean(rT,1);
cT_med = median(cT,1); iT_med = median(iT,1); rT_med = median(rT,1);
dcT = diff(cT,1,2);     % daily new cases as a share of N
dcT_avg = mean(dcT,1); dcT_med = median(dcT,1);

results = fn_summary_onegroup(cT, iT, Tend)

% Plot horizon in days, longer when distancing is active (not used elsewhere
% in this script; presumably for interactive plotting).
if flag.dist~=0
    Tplot = 300;
else
    Tplot = 150;
end

% Multigroup: per-group shares (counts divided by each group's size) and their
% mean/median across simulations, i.e. along the 3rd dimension.
if choice_network==3
    pop_LTS = repmat(pop_sim,1,T,nsimul);   % group sizes expanded to group x day x simulation
    clT = ClT./pop_LTS;
    rlT = RlT./pop_LTS;
    ilT = IlT./pop_LTS;
    clear pop_LTS

    clT_avg = mean(clT,3);
    rlT_avg = mean(rlT,3);
    ilT_avg = mean(ilT,3);
    clT_med = median(clT,3);
    rlT_med = median(rlT,3);
    ilT_med = median(ilT,3);

    dclT = diff(clT,1,2);   % daily new cases per group
    dclT_avg = mean(dclT,3);
    dclT_med = median(dclT,3);

    resultsL = fn_summary_multigroup(clT, ilT)
end

% Echo the run time in minutes; if profiling was enabled, stop it and show the report.
elapsedMin = toc(t0) / 60
if profiler == 1
    profile('off');
    profile('viewer');
end

%% ----------save mat results--------------
% Results are written only for full-size runs (nsimul==1000). The file name is
% the network tag followed by one suffix per active scenario option.
if nsimul==1000
    name_parts = {fpath, network_string};
    if flag.dist~=0
        name_parts{end+1} = ['_dist_TR_target',num2str(TR_target*10), ...
            '_relax',num2str(TR_relax*10), ...
            '_fall',num2str(weeks_dist_fall), ...
            'W_keep',num2str(weeks_dist_keep), ...
            'W_rise',num2str(weeks_dist_rise),'W'];
    end
    if flag.vacc~=0
        name_parts{end+1} = ['_vacc_pct',num2str(frac_vacc*100), ...
            '_eff',num2str(eff_vacc*100), ...
            '_Wbeg', num2str(wbeg_vacc),'_', num2str(weeks_vacc),'W'];
        if flag.vacc==2
            name_parts{end+1} = '_oldfirst';
        end
    end
    if N~=10000
        name_parts{end+1} = ['_N',num2str(N)];
    end
    if flag.exact==1
        name_parts{end+1} = '_exact';
    end
    fname = [name_parts{:}];

    % Aggregate results are always saved; the remaining variables depend on the scenario.
    save_varlist = {'cT','rT','iT', 'cT_avg','rT_avg','iT_avg', ...
        'cT_med','rT_med','iT_med', 'dcT','dcT_avg','dcT_med', ...
        'results', 'days_recovery','Tend'};
    if choice_network==3
        save_varlist = [save_varlist, {'clT','rlT','ilT', ...
            'clT_avg','rlT_avg','ilT_avg', ...
            'clT_med','rlT_med','ilT_med', ...
            'dclT','dclT_avg','dclT_med', ...
            'resultsL'}];
    end
    if flag.dist~=0
        save_varlist{end+1} = 'TR';
    end
    if flag.vacc~=0
        save_varlist{end+1} = 'N_vacc_record';
    end
    save(fname,save_varlist{:});
end

