% Empirical post-processing of stored joint beta/MF estimation results.
% For each country in country_list the loop below reloads the stored MF,
% beta and simulation files from path_write, builds TR = beta/gamma and Re,
% writes CSV/MAT outputs and saves figures to figpath.
clearvars
close all;

date_string = '20210618';   % vintage (folder name) of the raw data used for estimation
country_list = {'Austria','France','Germany','Italy','Spain','UK'};
% country_list = {'US'};
vacc_pct = 10; % stopping criterion: selects the date<vacc_pct>pct column of vacc_dates.csv

N = 50000;          % appears in the results folder and figure names (presumably network size)
nsimul = 500;       % number of simulations, part of the stored file names
roll_window = 14;   % window length in days (file names report it in weeks, roll_window/7)
MF_window = roll_window;   % spacing in days between MF estimation dates
ini_days = 7;       % leading betaT entries dropped from the first stored part
MA_window = 7;      % moving-average window in days
gamma = 1/14;       % TR = beta/gamma (presumably the recovery rate)
network_string = 'ER';
Euro_country_list = {'Austria','France','Germany','Italy','Spain','UK'};
cstar = 0.01;       % case-share threshold that starts the joint estimation
MF_ini = 5;         % initial MF guess, used before the first MF estimate

% Build paths with fullfile/filesep so the script also runs outside Windows.
% Each path keeps a trailing separator because file names are appended to it
% by plain concatenation.
path_read = [fullfile('.','data_raw',date_string), filesep];
path_read_recent = [fullfile('.','data_raw','20210805'), filesep];
path_write = [fullfile('.','results_empirical',[date_string,'_MF',num2str(MF_window/7),...
    'W_guess',num2str(MF_ini),'_N',num2str(N),'_par']), filesep];
figpath = [fullfile('.','paper','figs'), filesep];


linS = repmat({'-','--','-.',':'},1,2);   % line styles for the cross-country figures
figs2keep = [5,6]; % figures to keep (cross-country comparisons built across iterations)
days_downto1 = zeros(length(country_list),1);
% Per-country post-processing. For every country in country_list:
%   1. load the realized data, build moving averages and locate the sample start;
%   2. derive the MF estimation dates, stopping at the vaccination date;
%   3. reload the stored MF point estimates (interpolated daily), betaT
%      pieces and simulated cT pieces from path_write;
%   4. compute TR = beta/gamma and Re = (1 - MF*c_t)*TR, plot and save them;
%   5. write CSVs for the R plots and add this country's line to the
%      cross-country figures 5 and 6 (kept open across iterations).

% The vaccination stopping dates are shared by all countries: read and parse
% them once instead of re-reading the file on every iteration.
data_vacc = readtable(fullfile('.','data_raw','vacc_dates.csv'));
vacc_col = ['date',num2str(vacc_pct),'pct'];
if ~ismember(vacc_col, data_vacc.Properties.VariableNames)
    error('vacc_dates.csv has no column "%s" for vacc_pct = %g.', vacc_col, vacc_pct);
end
data_vacc.(vacc_col) = datetime(data_vacc.(vacc_col),'InputFormat','ddMMMyyyy');

for ci = 1:length(country_list)
    % Close every figure except the cross-country comparison figures listed
    % in figs2keep. Filter on the figure Number rather than mixing figure
    % objects and doubles in setdiff.
    all_figs = findobj(0, 'type', 'figure');
    for fig = reshape(all_figs, 1, [])
        if ~ismember(fig.Number, figs2keep)
            close(fig);
        end
    end

    country = char(country_list(ci));
    % Cycle through the available line styles so more countries than styles
    % do not index past the end of linS.
    line_style = linS{mod(ci-1, numel(linS)) + 1};

    %% load realized data
    fname_read = [path_read, country,'.csv'];
    data = readtable(fname_read);
    CT = data.C;
    pop = data.Pop(1);
    data.Date = datetime(data.Date,'InputFormat','ddMMMyyyy');
    IT = fn_est_IT(CT,gamma);

    % Moving averages of CT and IT, also expressed as population shares.
    out_MA = fn_MA([CT,IT],MA_window);
    CT_MA = out_MA(:,1);
    IT_MA = out_MA(:,2);
    cT_MA = CT_MA./pop;
    iT_MA = IT_MA./pop;

    % Sample start: first index where MA daily new cases reach 1 per 100,000.
    dCT_MA = diff(CT_MA);
    dcT_MAp = dCT_MA./pop.*100000;
    idx_cross1 = find(dcT_MAp>=1,1,'first');
    if isempty(idx_cross1)
        error('%s: MA daily new cases never reach 1 per 100,000.', country);
    end
    idx_beg = idx_cross1 + (MA_window-1) + 1;
    date_beg_betahat = data.Date(idx_beg+1);

    idx_beg_MA = idx_beg - MA_window + 1;
    cT_MA_use = cT_MA(max(idx_beg_MA-roll_window,0)+1:end);
    iT_MA_use = iT_MA(max(idx_beg_MA-roll_window,0)+1:end);
    T = length(cT_MA_use);
    % cT_MA_use is treated as right-aligned with data.Date (last entry = data.Date(end)).
    date_beg_MA_use = data.Date(end - length(cT_MA_use) + 1);

    % Joint estimation starts once the MA case share exceeds cstar.
    idx_beg_joint = find(cT_MA_use>cstar,1,'first');
    if isempty(idx_beg_joint)
        error('%s: MA case share never exceeds cstar = %g.', country, cstar);
    end
    date_beg_joint = data.Date(end - (length(cT_MA_use) - idx_beg_joint + 1) + 1);

    %% dates of joint estimation
    % One estimation date every MF_window days from the day before
    % date_beg_joint, truncated at this country's vaccination stopping date.
    t_est_set = idx_beg_joint-1: MF_window: T;
    date_est_set =  date_beg_joint-1: MF_window: data.Date(end);
    date_vacc_end = data_vacc.(vacc_col)(strcmp(data_vacc.country,country));
    if numel(date_vacc_end) ~= 1
        error('%s: expected exactly one row in vacc_dates.csv, found %d.', ...
            country, numel(date_vacc_end));
    end
    date_est_set = date_est_set(date_est_set <= date_vacc_end);
    if isempty(date_est_set)
        error('%s: no estimation date on or before the vaccination stopping date.', country);
    end
    n_est = length(date_est_set);

    length_cT_use = days(date_est_set(end) - date_beg_MA_use);
    cT_MA_use = cT_MA_use(1:length_cT_use);
    iT_MA_use = iT_MA_use(1:length_cT_use);
    % ===================================================
    %% read stored MFs and interpolate
    MF_pt = ones(1,n_est); % store point estimates
    for j=1:n_est
        fname_MF = [path_write, country, '_MF_', num2str(roll_window/7),'W_nsimul',num2str(nsimul),'_part',num2str(j)];
        data_MF_past = load(fname_MF,'MF');
        MF_pt(j) = data_MF_past.MF;
    end
    % Daily MF series between the first and last estimation dates.
    MF = interp1(date_est_set, MF_pt, date_est_set(1):date_est_set(end))';
    h_MF = fn_plot_MF(date_est_set, MF_pt, MF_window);
    if ismember(country, Euro_country_list) || strcmp(country,'US')
        ylim(gca, [1, 10])
    end

    %% compute TR
    % Concatenate the stored betaT pieces. The first piece drops its first
    % ini_days entries; later pieces drop their first entry, which overlaps
    % with the end of the previous piece.
    beta_parts = cell(n_est,1);
    for j=1:n_est
        fname_beta = [path_write, country, '_beta_', num2str(roll_window/7),'W_nsimul',num2str(nsimul),'_part',num2str(j)];
        data_beta = load(fname_beta,'betaT');
        if j==1
            beta_parts{j} = data_beta.betaT(ini_days+1:end);
        else
            beta_parts{j} = data_beta.betaT(2:end); % one point overlapping
        end
    end
    beta_full = vertcat(beta_parts{:});
    TR  = beta_full./gamma;
    tmax = date_est_set(end);

    %% compute Re
    % Align the realized series with the end of beta_full. Before the first
    % MF estimate, MF is set to the initial guess MF_ini.
    cT_real_ref = cT_MA_use(end-length(beta_full)+1:end);
    iT_real_ref = iT_MA_use(end-length(beta_full)+1:end);
    MF_full = [MF_ini*ones(length(beta_full)- length(MF),1);MF]; % include the initial guess estimate
    Re = (1 - MF_full.*cT_real_ref).* TR;

    %% plot TR together with Re
    h_TR_Re = fn_plot_empirical_TR_Re(tmax, TR, Re);
    if ismember(country, Euro_country_list)
        [h_Re, results_Re] = fn_plot_empirical_Re_lockdown(tmax, Re, country);
        days_downto1(ci) = results_Re.days_downto1;
    else
        h_Re = fn_plot_empirical_Re(tmax, Re);
    end

    %% save to csv for R plot
    % Concatenate the simulated cT pieces along columns; later pieces drop
    % their first column, which overlaps with the previous piece.
    cT_parts = cell(1,n_est);
    for j=1:n_est
        fname_sim = [path_write, country, '_sim_',num2str(roll_window/7),'W_nsimul',num2str(nsimul),...
            '_ER_inidays',num2str(ini_days),'_part',num2str(j)];
        data_sim = load(fname_sim,'results_sim');
        if j==1
            cT_parts{j} = data_sim.results_sim.cT;
        else
            cT_parts{j} = data_sim.results_sim.cT(:,2:end); % one point overlapping
        end
    end
    cT_cal = horzcat(cT_parts{:});
    dcT_cal = diff(cT_cal,1,2);
    dcT_real = diff(cT_MA_use(end-size(cT_cal,2)+1:end));
    dcT_real_MF = dcT_real.*[MF_ini*ones(length(dcT_real)- length(MF),1);MF];
    date_beg_dc = date_est_set(end) - size(dcT_cal,2) +1;

    dlmwrite([path_write,country,'_ER_dcT_cal.csv'],dcT_cal.*100,'precision','%4.16f');
    dlmwrite([path_write,country,'_ER_dcT_real.csv'],dcT_real_MF.*100,'precision','%4.16f');
    writematrix(date_beg_dc,[path_write,country,'_ER_date.csv']);

    %% compare total cases (using more recent data)
    fname_read = [path_read_recent, country,'.csv'];
    data = readtable(fname_read);
    data.Date = datetime(data.Date,'InputFormat','ddMMMyyyy');
    CT_MA = fn_MA(data.C,MA_window);
    cT_MA = CT_MA./data.Pop(1);

    date_beg_cT_real = data.Date(1)+MA_window-1;
    date_end_cT_real = data.Date(end);
    cT_real = cT_MA(1: days(date_end_cT_real-date_beg_cT_real)+1);
    dcT_real = diff(cT_real);
    % Scale daily increments by MF. Use MF_ini before the first estimation
    % date and hold the last MF after the last estimation date.
    dcT_real_MF = dcT_real.*[MF_ini*ones(days(date_est_set(1)-date_beg_cT_real-1),1);...
        MF; MF(end).*ones(days(date_end_cT_real-date_est_set(end)),1)];
    cT_real_MF = cumsum([cT_real(1).*MF_ini; dcT_real_MF]);
    cT_real_plot = cT_real.*100;
    cT_real_MF_plot = cT_real_MF.*100;

    % Common x-range of the cross-country figures: earliest start and latest
    % end over all countries processed so far.
    if ci==1
        date_beg_plot = date_beg_cT_real;
        date_end_plot = date_end_cT_real;
    else
        date_beg_plot = min(date_beg_plot, date_beg_cT_real);
        date_end_plot = max(date_end_plot, date_end_cT_real);
    end

    % Figures 5 and 6 accumulate one line per country (they are in figs2keep)
    % and are formatted once the last country has been added.
    h_cT = figure(5);
    plot(date_beg_cT_real:date_end_cT_real, cT_real_plot,'Linewidth',1.5, 'LineStyle',line_style);
    hold on
    if ci==length(country_list)
        ylabel('Proportion total cases (c_t, per cent)');
        legend(country_list, 'NumColumns',2);
        hl = legend('Location','NorthWest');
        set(hl, 'Fontsize',12)
        legend boxoff
        box('off')
        ax = gca;
        ax.LineWidth = 1.2;
        ax.FontSize = 12;
        ax.TickDir = 'out';
        xlim([date_beg_plot,date_end_plot])
        xticks(date_beg_plot:21:date_end_plot)
        xtickformat('MMM dd')
        xtickangle(90)
        hold off
    end

    h_cT_MF = figure(6);
    plot(date_beg_cT_real:date_end_cT_real, cT_real_MF_plot, 'Linewidth',1.5, 'LineStyle',line_style);
    hold on;
    if ci==length(country_list)
        ylabel('Proportion total cases (c_t, per cent)');
        legend(country_list, 'NumColumns',2);
        hl = legend('Location','NorthWest');
        set(hl, 'Fontsize',12)
        legend boxoff
        box('off')
        ax = gca;
        ax.LineWidth = 1.2;
        ax.FontSize = 12;
        ax.TickDir = 'out';
        xlim([date_beg_plot,date_end_plot])
        xticks(date_beg_plot:21:date_end_plot)
        xtickformat('MMM dd')
        xtickangle(90)
        ax.YAxis.Exponent = 0;
    end

    % Convert each datetime to char explicitly so each fills its own %s.
    fprintf('   %s: %s %s %3.2f %3.2f %3.2f\n', country, char(date_beg_dc), char(date_est_set(end)), ...
        MF(1), MF(end), cT_real_MF_plot(end)/cT_real_plot(end))

    %% save estimates and figures
    figname = [country,'_',network_string,'_N',num2str(N),'_guess',num2str(MF_ini),...
        '_',num2str(MF_window/7),'W'];
    figname_MF = [figname,'_MF'];
    figname_Re = [figname,'_Re'];
    figname_TR_Re = [figname,'_TR_Re'];
    if ci==length(country_list)
        figname_cT = ['cmp_',network_string,'_N',num2str(N),'_guess',num2str(MF_ini),...
            '_',num2str(MF_window/7),'W_cT'];
        figname_cT_MF = ['cmp_',network_string,'_N',num2str(N),'_guess',num2str(MF_ini),...
            '_',num2str(MF_window/7),'W_cT_MF'];
        saveas(h_cT,[figpath, figname_cT,'.png'])
        saveas(h_cT_MF,[figpath, figname_cT_MF,'.png'])
    end

    saveas(h_MF,[figpath, figname_MF,'.png'])
    saveas(h_Re,[figpath, figname_Re,'.png'])
    saveas(h_TR_Re,[figpath, figname_TR_Re,'.png'])

    tmin_TR = tmax - length(TR) +1;
    save([path_write,country,'_ER_results'],'TR', 'Re', 'MF', 'MF_pt', 'tmax', 'tmin_TR', 'date_est_set')

end
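% Uncomment to inspect, per country, the number of days needed to bring Re
% down to 1 from the start of the first lockdown. Only countries in
% Euro_country_list are filled; the others stay 0. The second line gives
% the mean across all countries.
% days_downto1
% mean(days_downto1)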