#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
@author: at

Opens data generated by other files and puts it
into tables and figures before exporting to output/

Note that not all figs are generated here. The revision
figures are generated within the revisions script and the partial_dependence script.

"""
import configparser
import os
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import json
import seaborn as sns
import re
from dateutil.relativedelta import relativedelta
from matplotlib.ticker import AutoMinorLocator
import importlib
import textwrap
from functools import reduce
import src.generalutils as gutil
import src.fcastutils as futil

# Reload the project helper modules so that the versions currently on disk are
# the ones used below (useful when this file is re-run in a live session).
importlib.reload(futil)
importlib.reload(gutil)
# ---------------------------------------------------------------------------
# Settings
# ---------------------------------------------------------------------------
mpl.rcParams.update(mpl.rcParamsDefault)  # VS Code plots not black
config = configparser.ConfigParser()
# Use option names exactly as written in the file (configparser lower-cases
# them by default).
config.optionxform = str
# Relative path: resolved against the current working directory.
config.read("config.ini")
# The [viz] section is handed to matplotlib as a style mapping.
plt.style.use(config["viz"])
# The values below are stored in config.ini as JSON-encoded strings.
paper_list = json.loads(config.get("papers", "paper_list"))
colWheel = json.loads(config.get("vizcolours", "colourWheel"))
dashWheel = json.loads(config.get("vizdashes", "dashWheel"))


def align_yaxis(ax1, v1, ax2, v2):
    """Shift the y-limits of ax2 so that v2 on ax2 lines up with v1 on ax1."""
    # Vertical display coordinates of the two reference values.
    display_y1 = ax1.transData.transform((0, v1))[1]
    display_y2 = ax2.transData.transform((0, v2))[1]
    # Express the display-space gap in ax2 data units.
    to_data = ax2.transData.inverted()
    origin = to_data.transform((0, 0))
    displaced = to_data.transform((0, display_y1 - display_y2))
    shift = (origin - displaced)[1]
    lower, upper = ax2.get_ylim()
    ax2.set_ylim(lower + shift, upper + shift)


def returnEventDict():
    """Return a dict mapping event labels to their date strings.

    The dates are used to mark events on the swathe plots.
    """
    events = (
        ("Brexit ref", "23 June 2016"),
        ("Scottish ref", "18 September 2014"),
        ("General election 2010", "6 May 2010"),
        ("General election 2015", "7 May 2015"),
        ("Lehman brothers", "15 September 2008"),
        ('Draghi "Whatever it takes"', "26 July 2012"),
        ("US Gov shutdown", "1 October 2013"),
        ("9/11", "11 September 2001"),
        ("Invasion of Iraq", "20 March 2003"),
        ("Northern Rock", "14 September 2007"),
        ("Russian crisis", "17 August 1998"),
        ("Asian crisis mini-crash", "27 October 1997"),
        ("Barings bank", "26 February 1995"),
        ("1997 Labour victory", "01 May 1997"),
    )
    # Events currently switched off:
    #   Greek bailout ref (5 July 2015), Florence speech (22 September 2017),
    #   Chequers summit (06 July 2018), General election 2017 (08 June 2017),
    #   Brexit bill (26 June 2018)
    return dict(events)


def starFunc(numin):
    """Return the significance stars for a p-value.

    "***" below 0.01, "**" below 0.05, "*" below 0.1, otherwise "".
    """
    # Thresholds are checked from strictest to loosest; first match wins.
    for threshold, stars in ((0.01, "***"), (0.05, "**"), (0.1, "*")):
        if numin < threshold:
            return stars
    return ""


def grabRMSEs(specification):
    """
    Load the RMSE summary dataframe for one specification.

    Specifications containing "tf_vs" share a single pickle
    (ALL_SUM_tf_ML.pkl), which is filtered on its "specification" column;
    every other specification has its own ALL_SUM_<specification>.pkl.
    """
    results_dir = config["data"]["results"]
    if "tf_vs" not in specification:
        return pd.read_pickle(
            os.path.join(results_dir, "ALL_SUM_" + specification + ".pkl")
        )
    shared = pd.read_pickle(os.path.join(results_dir, "ALL_SUM_tf_ML.pkl"))
    return shared[shared["specification"] == specification]


def tableDescriptiveStats(save=True):
    """
    Table: descriptive statistics. Expects all three newspapers to be present as uses external info on circulation
    for all three.

    For each paper the table holds circulation, number of unique articles,
    share of all articles, mean articles per month and the first/last article
    dates, followed by a "Total" row. The table is printed and, if ``save``
    is true, written as LaTeX to descStatsTable.txt in the output directory.
    """
    circDict = {"DAIM": 1265, "DMIR": 563, "GRDN": 138}
    art_mth_str = r"$ \langle \text{articles/month} \rangle $"
    # Must be a raw string: in a normal literal "\t" is a TAB character, which
    # turned the LaTeX label into "<TAB>extbf{Total}".
    total_label = r"\textbf{Total}"
    date_cols = ["First article", "Last article"]

    statsTable = pd.DataFrame(index=list(circDict))
    statsTable["Circulation/$10^3$"] = statsTable.index.map(circDict)
    for paper in circDict:
        df = gutil.loadMetricsData(gutil.paperFname(paper))
        statsTable.loc[paper, "Unique articles"] = df.iloc[:, 0].count()
        statsTable.loc[paper, art_mth_str] = (
            df.groupby(pd.Grouper(freq="M")).count().mean().iloc[0]
        )
        statsTable.loc[paper, "First article"] = df.index.min()
        statsTable.loc[paper, "Last article"] = df.index.max()
    statsTable = statsTable.sort_values("Circulation/$10^3$")
    # Circulation figures are given in thousands; convert to absolute numbers.
    statsTable["Circulation/$10^3$"] = statsTable["Circulation/$10^3$"] * 1.0e3
    statsTable = statsTable.rename(columns={"Circulation/$10^3$": "Circulation"})
    statsTable.insert(
        2,
        r"\% of total",
        (
            statsTable["Unique articles"].divide(statsTable["Unique articles"].sum())
            * 100
        ),
    )
    numCols = [x for x in statsTable.columns if x not in date_cols]
    for col in date_cols:
        statsTable[col] = statsTable[col].astype(str)
    for col in numCols:
        statsTable[col] = statsTable[col].astype(np.double)
    # Total row: sum only the numeric columns (dates cannot be summed) and
    # put a dash in the date columns.
    totals = statsTable[numCols].sum().to_dict()
    totals.update({col: "-" for col in date_cols})
    totalRow = pd.DataFrame([totals], index=[total_label], columns=statsTable.columns)
    # DataFrame.append was removed in pandas 2.0; concat is the replacement.
    statsTable = pd.concat([statsTable, totalRow])
    stringFormats = ["{:,.0f}", "{:,.0f}", "{:.1f}", "{:,.0f}", "{}", "{}"]
    stringFormats = dict(zip(statsTable.columns, stringFormats))
    for key, value in stringFormats.items():
        statsTable[key] = statsTable[key].apply(value.format)
    # Copy so the mapping returned by the helper is never mutated.
    dictHere = dict(gutil.nameConvert())
    dictHere.update({total_label: total_label})
    statsTable.index = statsTable.index.map(dictHere)
    print(statsTable)
    if save:
        outPath = os.path.join(config["data"]["output"], "descStatsTable.txt")
        # escape=False so the LaTeX markup in labels/headers is written verbatim.
        statsTable.to_latex(outPath, escape=False)


def swathePlotsCreate(save=True):
    """
    Draw the "swathe" figures from ALL_swathes.csv.

    For every paper, one uncertainty and one sentiment figure is produced
    showing the min-to-max band of the proxies, the proxy mean and the
    text-metric mean, with dated events marked by vertical lines.
    If ``save`` is true each figure is written to the output directory as
    <paper>_<u|s>_swathe.eps. Figures are always shown.
    """
    swathes = pd.read_csv(
        os.path.join(config["data"]["results"], "ALL_swathes.csv"),
        parse_dates=True,
        index_col=0,
        infer_datetime_format=True,
        dayfirst=True,
    )
    # (column tag, display name, lower y-limit, upper y-limit)
    plot_specs = (("u", "Uncertainty", -3, 6), ("s", "Sentiment", -6, 6))
    # The left edge of the x-axis is fixed rather than taken from the data.
    first_date = pd.to_datetime("01-01-2002")
    for paper in swathes["paper"].unique():
        paper_rows = swathes.loc[swathes["paper"] == paper, :].drop("paper", axis=1)
        for tag, display_name, y_low, y_high in plot_specs:
            proxy_prefix = "proxies_" + tag
            # Columns are named <group>_<tag>_<stat>; keep those for this tag.
            tagged_cols = [c for c in paper_rows.columns if c.split("_")[1] == tag]
            frame = paper_rows[tagged_cols].dropna(how="any")
            band_low = frame[proxy_prefix + "_min"]
            band_high = frame[proxy_prefix + "_max"]
            proxy_mean = frame[proxy_prefix + "_mean"]
            metric_mean = frame["txtmetrics_" + tag + "_mean"]

            plt.close("all")
            fig, ax = plt.subplots()
            ax.fill_between(
                band_high.index,
                band_high,
                band_low,
                color="#FFBAD2",
                zorder=1,
                label="Proxy swathe (min to max)",
            )
            ax.plot(
                metric_mean.index,
                metric_mean,
                color=colWheel[2],
                label="Mean of " + display_name.lower(),
                linewidth=2.0,
                zorder=3,
            )
            ax.plot(
                proxy_mean.index,
                proxy_mean,
                lw=2.0,
                color="#672C1B",
                zorder=2,
                label="Mean of proxies",
                dashes=[5, 3],
            )
            ax.legend(
                frameon=True, loc="lower left", ncol=3, handlelength=2, fontsize=14
            )

            last_date = metric_mean.index.max()
            # Mark only the events that fall strictly inside the plotted window.
            for event_label, date_text in returnEventDict().items():
                event_date = pd.to_datetime(date_text)
                if not (first_date < event_date < last_date):
                    continue
                ax.axvline(
                    x=event_date,
                    linewidth=0.5,
                    linestyle="-.",
                    color=colWheel[2],
                )
                # Label sits three months left of the line, in axes-fraction y.
                ax.annotate(
                    event_label,
                    xy=(event_date - relativedelta(months=3), 0.61),
                    xycoords=("data", "axes fraction"),
                    rotation=90,
                    fontsize=11,
                )

            ax.yaxis.major.formatter._useMathText = True
            ax.set_ylim(y_low, y_high)
            ax.set_xlim(first_date, last_date)
            ax.yaxis.tick_right()
            plt.title(
                display_name + " metrics (standard deviations from the mean)",
                loc="right",
            )
            if save:
                plt.tight_layout()
                plt.savefig(
                    os.path.join(
                        config["data"]["output"], paper + "_" + tag + "_swathe.eps"
                    ),
                    dpi=300,
                )
            print(paper)
            plt.show()


def getLocs(indexListNames, index):
    """
    Used by runRegTxtMet

    Return, in ascending order, the position of every name in ``index``
    together with the position immediately after it.
    """
    positions = [index.get_loc(name) for name in indexListNames]
    following = [pos + 1 for pos in positions]
    return sorted(positions + following)


def runRegTxtMet(target, horizon, sentiment, save=True):
    """
    Regression table of best metrics (with paper FE)

    Builds a monthly paper-level panel of text metrics merged with the
    benchmark series, passes one single-metric model per text metric to
    ``futil.MasterRunRegression`` and reorders the rows of the first results
    table so that metrics come first, then controls (and intercept), then
    the paper rows.

    Parameters
    ----------
    target : key of the benchmark series to use as dependent variable.
    horizon : number of months the target is shifted forward.
    sentiment : if truthy use the [txtmetrics_s] metrics, else [txtmetrics_u].
    save : if true, write the table as LaTeX to the output directory.

    Returns
    -------
    The reordered results table.
    """
    # TODO: add just controls column?
    df = gutil.paperConcat()
    metricsToTry = sorted(
        dict(config["txtmetrics_s" if sentiment else "txtmetrics_u"])
    )
    catControls = ["paper"]
    ctrlString = "Lag" + gutil.allTargetsDict()[target]
    controls = [ctrlString]
    bmdata = gutil.BenchmarkData()
    dfBM = bmdata.returnAllSeries("M")
    # Lagged target as control; must be taken before the target is shifted.
    dfBM[ctrlString] = dfBM[target].shift(1)
    dfBM[target] = dfBM[target].shift(-horizon)
    df = df[metricsToTry + catControls].groupby([pd.Grouper(freq="M"), "paper"]).mean()
    df = df.reset_index().set_index("date")
    df = pd.merge(df.reset_index(), dfBM.reset_index(), how="inner", on=["date"])
    df = df.set_index("date")
    df = df[[target] + metricsToTry + controls + catControls]
    df = df.dropna(how="any")
    allDict = gutil.nameConvert()
    allDict.update({"Lag" + allDict[target]: "Lagged " + allDict[target]})
    df["paper"] = df["paper"].map(gutil.nameConvert())
    df.columns = df.columns.map(allDict)
    # Update inputs
    metricsToTry = [allDict[k] for k in metricsToTry]
    controls = [allDict[k] for k in controls]
    catControls = [allDict[k] for k in catControls]
    target = allDict[target]
    # One model per metric.
    regModelsList = [[metric] for metric in metricsToTry]
    results = futil.MasterRunRegression(
        df,
        regModelsList,
        target,
        controls,
        catControls=catControls,
        horizon=horizon,
        baseline=True,
    )
    rdf = results.tables[0]
    # Strip the formula wrappers Q("...") and C(...)[...] from the row labels.
    rdf.index = [re.sub(r'Q\("(.*?)"\)', r"\1", x) for x in rdf.index]
    rdf.index = [re.sub(r"C\(\w+\)\[(.*?)\]", r"\1", x) for x in rdf.index]
    # Get every control position and the number after it
    ctrlLocs = getLocs(controls, rdf.index)
    # Now get all the paper locs
    paperLocs = getLocs(df["Paper"].unique(), rdf.index)
    # Other locs
    metLocs = getLocs(metricsToTry, rdf.index)
    if "Intercept" in rdf.index:
        ctrlLocs = ctrlLocs + getLocs(["Intercept"], rdf.index)
    redoIndex = metLocs + ctrlLocs + paperLocs
    # Add missing rows. Built-in int replaces np.int, which was removed from
    # NumPy and raised AttributeError.
    redoIndex.extend(range(int(np.max(redoIndex)), len(rdf) + 1))
    # Keys keep first-seen order, which is the desired row order.
    mapToNew = dict(zip(redoIndex, range(len(rdf) + 1)))
    rdf = rdf.reset_index()
    depVarStr = "Dep. variable: " + target
    rdf = (
        rdf.reindex(mapToNew).rename(columns={"index": depVarStr}).set_index(depVarStr)
    )
    type_reg = "sent" if sentiment else "uncert"
    if save:
        outFile = os.path.join(
            config["data"]["output"],
            "tab_reg_" + type_reg + "_" + target + str(horizon) + ".txt",
        )
        rdf.to_latex(outFile)
    return rdf


def corrHeatMapPlot(paper, horizon, sentiment, save=True):
    """
    Heat map of text-metric vs proxy correlations for one paper and horizon.

    Reads ALL_corrs.csv, keeps the sentiment ("s") or uncertainty ("u")
    metrics and proxies listed in the config, and plots the metric-by-proxy
    correlation matrix. If ``save`` is true the figure is written to the
    output directory as <paper>_corrPlot_<u|s>_<horizon>.eps.
    """
    corrs = pd.read_csv(os.path.join(config["data"]["results"], "ALL_corrs.csv"))
    metStr = "s" if sentiment else "u"
    metric_keys = list(dict(config["txtmetrics_" + metStr]))
    proxy_keys = list(dict(config["proxies_" + metStr]))
    wanted = (
        corrs["metric"].isin(metric_keys)
        & corrs["proxy"].isin(proxy_keys)
        & (corrs["horizon"] == horizon)
        & (corrs["paper"] == paper)
    )
    corrs = corrs.loc[wanted]
    # Swap internal keys for display names.
    display_names = gutil.nameConvert()
    for column in ("metric", "proxy", "paper"):
        corrs[column] = corrs[column].map(display_names)
    matrix = corrs.pivot(index="metric", columns="proxy", values="correlation")
    # Order proxies by mean correlation, largest first.
    ordering = matrix.mean(axis=0).sort_values()[::-1]
    matrix = matrix[ordering.index]
    # Wrap long proxy names so the tick labels stay readable.
    matrix = matrix.rename(
        columns={name: textwrap.fill(name, width=20) for name in matrix.columns}
    )
    plt.close("all")
    fig, ax = plt.subplots()
    sns.heatmap(
        matrix,
        annot=True,
        fmt="1.1f",
        cbar_kws=dict(ticks=np.linspace(-1, 1, 5)),
        cmap="PRGn_r",
        linewidths=0.2,
        ax=ax,
        center=0,
        vmin=-1,
        vmax=1,
        annot_kws={"size": 15},
    )
    ax.tick_params(axis="x", which="major", pad=5)
    plt.xticks(rotation=35, size=11)
    plt.yticks(size=16)
    ax.set_ylabel("")
    ax.set_xlabel("")
    # Nudge every x tick label 25 points to the left (values are in inches).
    nudge = mpl.transforms.ScaledTranslation(
        -25 / 72.0, 0 / 72.0, fig.dpi_scale_trans
    )
    for tick_label in ax.xaxis.get_majorticklabels():
        tick_label.set_transform(tick_label.get_transform() + nudge)
    plt.tight_layout()
    if save:
        plt.savefig(
            os.path.join(
                config["data"]["output"],
                paper + "_corrPlot_" + metStr + "_" + str(horizon) + ".eps",
            ),
            dpi=300,
        )
    plt.show()


def plottingGunkTS(
    df,
    target,
    metric,
    horizon,
    paper,
    trafo,
    alpha,
    stepSize,
    expanding,
    CV,
    model,
    specification,
    run_type,
    save,
    sentiment,
):
    """
    Plot in-sample and out-of-sample predictions against the data for one run.

    ``df`` is filtered to rows whose "run_type" equals ``run_type`` and must
    provide the columns "IS_prediction", "OOS_prediction" and "target_value".
    The legend reports the RMSE of each prediction series. If ``save`` is
    true the figure is written to the output directory as
    <paper>_<specification>_<u|s>_<metric>_<horizon><run_type>.eps.

    ``trafo``, ``alpha``, ``stepSize``, ``expanding``, ``CV`` and ``model``
    are accepted for signature compatibility but not used here.
    """
    df = df.loc[df["run_type"] == run_type]
    IS_RMSE = futil.rmse(df["IS_prediction"], df["target_value"])
    OOS_RMSE = futil.rmse(df["OOS_prediction"], df["target_value"])
    # column -> (legend label, zorder). Each series gets its own label; the
    # data line previously inherited the out-of-sample label and zorder 1.
    seriesStyles = {
        "IS_prediction": ("In-sample " + "(RMSE = {:03.2f})".format(IS_RMSE), 1),
        "OOS_prediction": (
            "Out-of-sample " + "(RMSE = {:03.2f})".format(OOS_RMSE),
            3,
        ),
        "target_value": ("Data", 2),
    }
    plt.close("all")
    fig, ax = plt.subplots()
    for i, (col, (label, zorderToUse)) in enumerate(seriesStyles.items()):
        ax.plot(
            df.index,
            df[col],
            color=colWheel[(i + 1) % len(colWheel)],
            label=label,
            linewidth=3.0,
            dashes=dashWheel[i % len(dashWheel)],
            zorder=zorderToUse,
        )
    # ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.tick_params(axis="both", which="both", direction="in", length=6)
    # Legend order: data first, then in-sample, then out-of-sample.
    handles, labels = ax.get_legend_handles_labels()
    handles = [handles[2], handles[0], handles[1]]
    labels = [labels[2], labels[0], labels[1]]
    ax.legend(handles, labels, frameon=False, loc="upper left", ncol=2, handlelength=2)
    ax.yaxis.major.formatter._useMathText = True
    ax.yaxis.tick_right()
    # Pad both limits by 30%. A negative minimum must be scaled *up* in
    # magnitude (as plotTSversusBenchmark does), otherwise the data is clipped.
    factor = 1.3
    minY = round(df["target_value"].min() * factor, 0)
    maxY = round(df["target_value"].max() * factor, 0)
    # Always include zero on the axis.
    minY = np.min([0.0, minY])
    ax.set_ylim(minY, maxY)
    if run_type == "benchmark":
        metricName = "benchmark"
    else:
        metricName = gutil.nameConvert()[metric] + " + benchmark"
    titleName = (
        gutil.nameConvert()[target] + ": " + metricName + ", h = " + str(horizon)
    )
    plt.title(titleName, loc="right")
    plt.tight_layout()
    metStr = "s" if sentiment else "u"
    if save:
        # TODO: create filenames without multiplicity
        outFile = os.path.join(
            config["data"]["output"],
            paper
            + "_"
            + specification
            + "_"
            + metStr
            + "_"
            + metric
            + "_"
            + str(horizon)
            + run_type
            + ".eps",
        )
        plt.savefig(outFile, dpi=300)
    plt.show()


def createSingleTSPlot(
    target,
    metric,
    horizon,
    paper,
    trafo,
    alpha,
    stepSize,
    expanding,
    CV,
    model,
    specification,
    run_type,
    save,
    sentiment,
):
    """
    Load the stored results for one run and plot train/test fits vs the data.

    The results pickle is filtered down to the rows matching every run
    setting (names taken from config [runSettings]) and handed to
    plottingGunkTS. A warning is printed if the resulting datetime index
    contains duplicates.
    """
    per_spec_file = "ALL_" + specification + ".pkl"
    # Term-frequency-matrix runs are all stored in one shared file.
    results_file = "ALL_tf_ML_results.pkl" if metric == "tf_matrix" else per_spec_file
    runs = pd.read_pickle(os.path.join(config["data"]["results"], results_file))
    runs.index = pd.to_datetime(runs.index)
    setting_names = json.loads(config["runSettings"]["runSettings"])
    setting_values = (
        target,
        metric,
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        specification,
        run_type,
    )
    # Narrow the results one setting at a time.
    for column, wanted in zip(setting_names, setting_values):
        runs = runs.loc[runs[column] == wanted]
    if runs.index.duplicated().any():
        print("Warning: duplicates in datetime index")
    plottingGunkTS(
        runs,
        target,
        metric,
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        specification,
        run_type,
        save,
        sentiment,
    )


def plotTSversusBenchmark(
    target,
    metric,
    horizon,
    paper,
    trafo,
    alpha,
    stepSize,
    expanding,
    CV,
    model,
    specification,
    save,
    sentiment,
    **kwargs
):
    """
    Plot out-of-sample forecasts of the metric run and the benchmark run
    against the data on one figure.

    NB: reads results from file.

    The results pickle is filtered on every run setting except the run type;
    rows with run_type "benchmark" and "metric" are then plotted together,
    with each forecast's RMSE in the legend. If ``save`` is true the figure
    is written to the output directory (.eps at 300 dpi).

    Passing the keyword ``BU_post`` (any value) switches to an alternative
    presentation: fixed title, axis label and limits, relabelled legend, a
    horizontal line at zero, and a "BU_"-prefixed .png at 900 dpi.
    """
    end_file_name = "ALL_" + specification + ".pkl"
    if metric == "tf_matrix":
        end_file_name = "ALL_tf_ML_results.pkl"
    df = pd.read_pickle(os.path.join(config["data"]["results"], end_file_name))
    df.index = pd.to_datetime(df.index)
    namesVars = json.loads(config["runSettings"]["runSettings"])
    # Every setting except the run type, which is the last entry of namesVars.
    # (The tuple used to end with an undefined name ``run_type``, which raised
    # NameError before anything was plotted.)
    packagedSettings = (
        target,
        metric,
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model,
        specification,
    )
    for name, setting in zip(namesVars[:-1], packagedSettings):
        df = df.loc[df[name] == setting]
    bench_df = df.loc[df["run_type"] == "benchmark"]
    metric_df = df.loc[df["run_type"] == "metric"]
    OOS_RMSE_bench = futil.rmse(bench_df["OOS_prediction"], bench_df["target_value"])
    OOS_RMSE_metric = futil.rmse(metric_df["OOS_prediction"], metric_df["target_value"])
    titleName = gutil.nameConvert()[target] + ", h = " + str(horizon)
    plt.close("all")
    fig, ax = plt.subplots()
    # First plot the benchmark
    ax.plot(
        bench_df.index,
        bench_df["OOS_prediction"],
        color=colWheel[1],
        label="Benchmark " + "(RMSE = {:03.2f})".format((OOS_RMSE_bench)),
        linewidth=3.0,
        dashes=dashWheel[1],
        zorder=1,
    )
    ax.plot(
        metric_df.index,
        metric_df["OOS_prediction"],
        color=colWheel[2],
        label=textwrap.fill(
            gutil.nameConvert()[metric]
            + " (RMSE = {:03.2f})".format((OOS_RMSE_metric)),
            30,
        ),
        linewidth=3.0,
        dashes=dashWheel[2],
        zorder=3,
    )
    ax.plot(
        metric_df.index,
        metric_df["target_value"],
        color=colWheel[3],
        label="Data",
        linewidth=3.0,
        dashes=dashWheel[3],
        zorder=2,
    )
    ax.xaxis.set_minor_locator(AutoMinorLocator(2))
    ax.tick_params(axis="both", which="both", direction="in", length=6)
    # Legend order: data, benchmark, metric.
    handles, labels = ax.get_legend_handles_labels()
    handles = [handles[2], handles[0], handles[1]]
    labels = [labels[2], labels[0], labels[1]]
    ax.legend(
        handles, labels, frameon=False, loc="upper left", ncol=3, handlelength=1.5
    )
    ax.yaxis.major.formatter._useMathText = True
    ax.yaxis.tick_right()
    # Pad both limits by 30% and always include zero.
    factor = 1.3
    minY = round(df["target_value"].min() * factor, 0)
    maxY = round(df["target_value"].max() * factor, 0)
    minY = np.min([0.0, minY])
    ax.set_ylim(minY, maxY)
    if "BU_post" in kwargs:
        titleName = "Forecasts of GDP (made 3 months ahead)"
        ax.set_ylabel("GDP growth, %", rotation=270)
        ax.yaxis.set_label_coords(1.1, 0.5)
        ax.yaxis.set_label_position("right")
        labels[0] = "GDP (actual)"
        labels[2] = "Neural network" + "\n(RMSE = {:03.2f})".format((OOS_RMSE_metric))
        ax.legend(
            handles, labels, frameon=False, loc="upper left", ncol=3, handlelength=1.5
        )
        ax.set_ylim(-8.0, 8)
        ax.axhline(alpha=0.2, linestyle="--", color="k", zorder=10)
    plt.title(titleName, loc="right")
    plt.tight_layout()
    metStr = "u"
    if sentiment:
        metStr = "s"
    dpi = 300
    if save:
        # TODO: create filenames without multiplicity
        outFile = os.path.join(
            config["data"]["output"],
            paper
            + "_"
            + specification
            + "_"
            + metStr
            + "_"
            + metric
            + "_"
            + str(horizon)
            + "metandbench"
            + ".eps",
        )
        if "BU_post" in kwargs:
            outFile = os.path.join(
                config["data"]["output"],
                "BU_"
                + paper
                + "_"
                + specification
                + "_"
                + metStr
                + "_"
                + metric
                + "_"
                + str(horizon)
                + "metandbench"
                + ".png",
            )
            dpi = 900
        plt.savefig(outFile, dpi=dpi)
    plt.show()


def barChartFacet(specification, paper, saveAnalysis=True):
    """
    Facet grid of bar charts: one panel per target, one bar per text metric,
    bar length = RMSE relative to the benchmark RMSE.

    Error bars show the standard deviation over the variables not shown
    (e.g. horizons). ``paper`` selects a single paper, or "ALL" to pool all
    individual papers (the "COMB" rows are excluded). If ``saveAnalysis`` is
    true the figure is written as <paper>_barFacet_<specification>.eps.
    """
    rmse = grabRMSEs(specification)
    if paper != "ALL":
        rmse = rmse[rmse["paper"] == paper]
    else:
        rmse = rmse[rmse["paper"] != "COMB"]
        print("Averaging over papers")
    print("Load " + str(len(rmse)) + " rows of data.")
    rmse = rmse.sort_values("RMSE/RMSE_bch")
    text_metrics = list(dict(config["txtmetrics_s"])) + list(
        dict(config["txtmetrics_u"])
    )
    rmse = rmse[rmse["metric"].isin(text_metrics)]
    # Replace internal keys with display names, then capitalise the headers.
    raw_cols = ["target", "metric", "paper"]
    display_names = gutil.nameConvert()
    display_names.update({"COMB": "Mean"})
    for column in raw_cols:
        rmse[column] = rmse[column].map(display_names)
    rmse = rmse.rename(columns={column: column.capitalize() for column in raw_cols})
    # D Bholat suggested consistent ordering - changed to alphabetical
    metric_order = sorted(pd.Series(rmse["Metric"].unique(), dtype="str"))
    target_order = sorted(rmse["Target"].unique())
    panels_per_row = 3
    plt.figure()
    grid = sns.catplot(
        x="RMSE/RMSE_bch",
        y="Metric",
        col="Target",
        data=rmse,
        kind="bar",
        height=4,
        ci="sd",
        order=metric_order,
        row_order=target_order,
        col_order=sorted(rmse["Target"].unique()),
        # sharex=True,
        sharey=False,
        col_wrap=panels_per_row,
    )
    plt.xlim(0.6, 1.2)
    for position, panel in enumerate(grid.axes):
        # Ratio of one = same accuracy as the benchmark.
        panel.axvline(x=1.0, color="k", ls="--")
        panel.set_xlabel("")
        for text_item in [panel.title, panel.xaxis.label] + panel.get_yticklabels():
            text_item.set_fontsize(15)
        panel.set_xticks(np.arange(0.6, 1.4, step=0.2))
        # Only the first panel of each row keeps its y tick labels.
        if (position % panels_per_row) != 0:
            tick_labels = panel.get_yticklabels()
            panel.set_yticklabels([""] * len(tick_labels))
    plt.tight_layout()
    if specification[-5:] == "fctrs":
        axis_caption = r"$\frac{\mathrm{RMSE}}{\mathrm{RMSE}_{\mathrm{Bench.}}}$"
    else:
        axis_caption = r"$\frac{\mathrm{RMSE}}{\mathrm{RMSE}_{\mathrm{AR(1)}}}$"
    # The suptitle is placed at the bottom to act as a shared x-axis label.
    plt.suptitle(axis_caption, y=0.1)
    plt.subplots_adjust(bottom=0.15, hspace=0.15, wspace=0.2)
    if saveAnalysis:
        plt.savefig(
            os.path.join(
                config["data"]["output"],
                paper + "_barFacet_" + specification + ".eps",
            ),
            dpi=300,
        )
    plt.show()


def ML_bar_chart_facet(specification, saveAnalysis=True):
    """
    Facet grid of bar charts: one panel per target, one bar per model,
    bar length = RMSE relative to the benchmark RMSE.

    Error bars show the standard deviation over the variables not shown
    (e.g. horizons and papers). If ``saveAnalysis`` is true the figure is
    written as MLbarFacet_<specification>.eps.
    """
    rmse = grabRMSEs(specification).sort_values("RMSE/RMSE_bch")
    # Replace internal keys with display names, then capitalise the headers.
    raw_cols = ["target", "metric", "paper"]
    display_names = gutil.nameConvert()
    display_names.update({"COMB": "Mean"})
    for column in raw_cols:
        rmse[column] = rmse[column].map(display_names)
    rmse = rmse.rename(
        columns={column: column.capitalize() for column in raw_cols + ["model"]}
    )
    rmse["Model"] = rmse["Model"].astype("category")
    # D Bholat suggested consistent ordering - changed to alphabetical
    model_order = sorted(rmse["Model"].unique())
    panels_per_row = 3
    plt.close("all")
    plt.figure()
    grid = sns.catplot(
        x="RMSE/RMSE_bch",
        y="Model",
        col="Target",
        data=rmse,
        kind="bar",
        # hue='Metric',
        height=4,
        legend=False,
        sharex=True,
        sharey=False,
        order=model_order,
        row_order=sorted(rmse["Target"].unique()),
        col_order=sorted(rmse["Target"].unique()),
        hue_order=model_order,
        facet_kws={"hue_order": model_order},
        ci="sd",
        col_wrap=panels_per_row,
    )
    plt.xlim(0.2, 2.0)
    for position, panel in enumerate(grid.axes):
        # Ratio of one = same accuracy as the benchmark.
        panel.axvline(x=1.0, color="k", ls="--")
        panel.set_xlabel("")
        for text_item in [panel.title, panel.xaxis.label] + panel.get_yticklabels():
            text_item.set_fontsize(15)
        panel.set_xticks(np.arange(0.2, 2, step=0.4))
        # Only the first panel of each row keeps its y tick labels.
        if (position % panels_per_row) != 0:
            tick_labels = panel.get_yticklabels()
            panel.set_yticklabels([""] * len(tick_labels))
    plt.tight_layout()
    if specification[-7:] == "fctrOLS":
        axis_caption = r"$\frac{\mathrm{RMSE}}{\mathrm{RMSE}_{\mathrm{Bench.}}}$"
    else:
        axis_caption = r"$\frac{\mathrm{RMSE}}{\mathrm{RMSE}_{\mathrm{AR(1)}}}$"
    # The suptitle is placed at the bottom to act as a shared x-axis label.
    plt.suptitle(axis_caption, y=0.1)
    plt.subplots_adjust(bottom=0.15, hspace=0.15, wspace=0.15)
    if saveAnalysis:
        plt.savefig(
            os.path.join(
                config["data"]["output"], "MLbarFacet_" + specification + ".eps"
            ),
            dpi=300,
        )
    plt.show()


def DMTableAllPapers(specification, saveAnalysis=True):
    """
    Table of DM statistics for text-metric runs that beat the benchmark.

    Keeps individual papers only (no "COMB"), the text metrics listed in the
    config, runs with RMSE ratio below 1 and DM p-value below 0.1. Each cell
    is the DM statistic to two decimals followed by significance stars.
    Layout:

               Target -----> MGDP -------------
    Newspaper | Metric | Horizon |   DM Stat
    GRDN      | OPINION|    3    |   16***

    If ``saveAnalysis`` is true the table is written as LaTeX to
    ALL_DMStatTable_<specification>.txt. Returns the table.
    """
    runs = grabRMSEs(specification)
    runs = runs[runs["paper"] != "COMB"]
    text_metrics = list(dict(config["txtmetrics_s"])) + list(
        dict(config["txtmetrics_u"])
    )
    runs = runs[runs["metric"].isin(text_metrics)]
    beats_benchmark = (runs["RMSE/RMSE_bch"] < 1.0) & (runs["DMpval"] < 0.1)
    runs = runs.loc[beats_benchmark, :]
    # From here on the p-value column holds the star string.
    runs["DMpval"] = runs["DMpval"].apply(starFunc)
    raw_cols = ["target", "metric", "paper"]
    display_names = gutil.nameConvert()
    display_names.update({"COMB": "Mean"})
    for column in raw_cols:
        runs[column] = runs[column].map(display_names)
    runs = runs.rename(
        columns={column: column.capitalize() for column in raw_cols + ["horizon"]}
    )
    runs["DM Statistic"] = runs["DMstat"].apply("{:,.2f}".format) + runs["DMpval"]
    # One row per (paper, metric, horizon); targets spread across columns.
    by_run = runs.reset_index().groupby(["Paper", "Metric", "Horizon", "Target"])
    table = by_run.first()["DM Statistic"].unstack().fillna("")
    if saveAnalysis:
        table.to_latex(
            os.path.join(
                config["data"]["output"],
                "ALL_DMStatTable_" + specification + ".txt",
            )
        )
    return table


def combined_DM_table_ML(specification, saveAnalysis=True):
    """
    Table of DM statistics for the significant runs of one specification,
    broken down by paper, model and horizon with one column per target.

    Keeps individual papers only (no "COMB"), runs with RMSE ratio below 1
    and DM p-value below 0.1. Each cell is the DM statistic to two decimals
    followed by significance stars. If ``saveAnalysis`` is true the table is
    written as LaTeX to ALL_DMStatTable_<specification>_ML.txt.
    Returns the table.
    """
    runs = grabRMSEs(specification)
    runs = runs[runs["paper"] != "COMB"]
    beats_benchmark = (runs["RMSE/RMSE_bch"] < 1.0) & (runs["DMpval"] < 0.1)
    runs = runs.loc[beats_benchmark, :]
    # From here on the p-value column holds the star string.
    runs["DMpval"] = runs["DMpval"].apply(starFunc)
    raw_cols = ["target", "metric", "paper"]
    display_names = gutil.nameConvert()
    display_names.update({"COMB": "Mean"})
    for column in raw_cols:
        runs[column] = runs[column].map(display_names)
    header_cols = raw_cols + ["model", "horizon"]
    runs = runs.rename(columns={column: column.capitalize() for column in header_cols})
    runs["DM Statistic"] = runs["DMstat"].apply("{:,.2f}".format) + runs["DMpval"]
    # One row per (paper, model, horizon); targets spread across columns.
    by_run = runs.reset_index().groupby(["Paper", "Model", "Horizon", "Target"])
    table = by_run.first()["DM Statistic"].unstack().fillna("")
    if saveAnalysis:
        table.to_latex(
            os.path.join(
                config["data"]["output"],
                "ALL_DMStatTable_" + specification + "_" + "ML.txt",
            )
        )
    return table


def _filter_by_settings(frame, names, settings):
    """Return a copy of the rows of ``frame`` where every column in ``names``
    equals the value at the same position in ``settings``."""
    selected = frame
    for name, setting in zip(names, settings):
        selected = selected.loc[selected[name] == setting]
    return selected.copy()


def run_breakdown(
    target,
    paper,
    trafo,
    alpha,
    stepSize,
    expanding,
    CV,
    specification,
    save,
    metrics_list=None,
    models_list=None,
):
    """This produces plots of point-by-point squared error differences
    i.e. (error from benchmark)^2_t - (error from model)^2_t
    Above zero indicates periods when model beats benchmark.

    Every combination of configured horizon, entry of ``metrics_list`` and
    entry of ``models_list`` is pulled from the pickled results, the squared
    errors of the "metric" and "benchmark" run types are differenced, and the
    stacked result is drawn with seaborn. The target series itself is
    overlaid on a secondary (right-hand) axis.

    Parameters
    ----------
    target, paper, trafo, alpha, stepSize, expanding, CV, specification
        Run settings used to select rows of the results file. They are
        matched positionally against the names listed in the
        ``runSettings`` config entry.
    save : bool
        If True, write the figure to
        ``brkdown_<specification>_<target>.eps`` in the output directory.
    metrics_list : list of str, optional
        Metrics to show; defaults to ``["stability", "tf_idf_econom"]``.
        If the first entry is ``"tf_matrix"`` the ML results file is read
        and lines are distinguished by model instead of by metric.
    models_list : list of str, optional
        Models to show; defaults to ``["OLS"]``.

    Raises
    ------
    ValueError
        If there is no horizon/metric/model combination to process.

    Notes
    -----
    NOTE(review): ``ci=None`` is passed to ``sns.lineplot``, so no error bars
    are drawn even though ``err_style="bars"`` is set - confirm whether a
    spread over horizons was intended.
    """
    # None sentinels instead of mutable list defaults
    if metrics_list is None:
        metrics_list = ["stability", "tf_idf_econom"]
    if models_list is None:
        models_list = ["OLS"]
    y_name = r"$ \epsilon^2_{\mathrm{Bench.}} -" + r"\epsilon_{\mathrm{Text}}^2$"
    # Retrieve same run for metric and benchmark
    end_file_name = "ALL_" + specification + ".pkl"
    if metrics_list[0] == "tf_matrix":
        end_file_name = "ALL_tf_ML_results.pkl"
        hue_and_style_var = "Model"
    else:
        hue_and_style_var = "Metric"
    df = pd.read_pickle(os.path.join(config["data"]["results"], end_file_name))
    df.index = pd.to_datetime(df.index)
    namesVars = json.loads(config["runSettings"]["runSettings"])
    # Ensure that specification case is aligned
    if specification not in df["specification"].unique():
        specification = str(specification).lower()
    # Everything but the last name (the run type) identifies a run
    selection_names = namesVars[:-1]
    frames = []
    packagedSettings = None
    # Loop over metric, horizon, model
    for horizon in json.loads(config.get("runSettings", "horizons")):
        for metric in metrics_list:
            for model in models_list:
                packagedSettings = (
                    target,
                    metric,
                    horizon,
                    paper,
                    trafo,
                    alpha,
                    stepSize,
                    expanding,
                    CV,
                    model,
                    specification,
                )
                run_df = _filter_by_settings(df, selection_names, packagedSettings)
                run_df["Error"] = np.power(
                    run_df["OOS_prediction"] - run_df["target_value"], 2
                )
                run_df = run_df.pivot(columns="run_type", values="Error")
                run_df[y_name] = -(run_df["metric"] - run_df["benchmark"])
                # Label the rows with their settings. This deliberately comes
                # after y_name is computed: a setting called "metric"
                # overwrites the pivoted "metric" error column.
                for name, setting in zip(namesVars[1:-1], packagedSettings[1:]):
                    run_df[name] = setting
                frames.append(run_df)
    if not frames:
        raise ValueError(
            "run_breakdown needs at least one horizon, metric and model"
        )
    # Most recent combination first, so hue/style order in the plot is kept
    xf = pd.concat(frames[::-1], axis=0)
    # Grab target time series (taken from the last combination processed)
    target_series = _filter_by_settings(df, selection_names, packagedSettings)[
        "target_value"
    ]
    target_series = target_series.loc[~target_series.index.duplicated(keep="first")]
    # Put nice names in
    colsToNicify = ["metric", "paper", "model"]
    nicifyDict = gutil.nameConvert()
    nicifyDict.update({"COMB": "Mean"})
    for col in colsToNicify:
        xf[col] = xf[col].map(nicifyDict)
    convertColNamesD = {col: col.capitalize() for col in colsToNicify}
    xf = xf.rename(columns=convertColNamesD)
    plt.close("all")
    _, ax = plt.subplots()
    sns.lineplot(
        x="date",
        y=y_name,
        hue=hue_and_style_var,
        style=hue_and_style_var,
        err_style="bars",
        err_kws={"elinewidth": 0.4, "errorevery": 1},
        ci=None,
        data=xf.reset_index(),
        ax=ax,
        zorder=2,
    )
    ax.set_xlim(xf.dropna().index.min(), pd.to_datetime("2018-01-01"))
    axlimits_x = ax.get_xlim()
    ax2 = ax.twinx()
    ax2.plot(
        target_series.index,
        target_series,
        label=nicifyDict[target] + " (RHS)",
        color="r",
        linestyle="-.",
        lw=1.5,
        zorder=0,
    )
    # One legend for both axes
    handles, labels = ax.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ncols = 3 if len(metrics_list) == 2 else 2
    ax.legend(
        handles=handles + h2,
        labels=labels + l2,
        frameon=False,
        loc="upper right",
        ncol=ncols,
        handlelength=2,
        title="",
    )
    ax.set_xlabel("")
    ax2.set_ylabel("Growth, %")
    ax2.set_xlim(axlimits_x)
    # Hard-coded y ranges; the zeros of both axes are then lined up
    ax.set_ylim(-10, 30)
    ax2.set_ylim(-4, 12)
    align_yaxis(ax2, 0, ax, 0)
    ax.axhline(0.0, color="grey", zorder=1)
    plt.title(
        "Forecast " + nicifyDict[target] + " squared error differences comparison",
        loc="right",
    )
    if save:
        outPath = os.path.join(
            config["data"]["output"], "brkdown_" + specification + "_" + target + ".eps"
        )
        plt.savefig(outPath, dpi=300)
    plt.show()


def BU_swathe_plot(save=True):
    """
    Plot mean newspaper sentiment against the swathe of indicators.

    Reads ``ALL_swathes.csv`` from the results directory, keeps the rows for
    the paper coded "DAIM" and the columns whose second underscore-separated
    token is "s", and draws
      * the min-to-max band of the proxy indicators,
      * the mean of the proxy indicators, and
      * the mean of the text metrics.
    Events from ``returnEventDict`` that fall inside the plotted date range
    are marked with vertical lines and rotated labels.

    Parameters
    ----------
    save : bool, default True
        If True, write the figure to ``Fig1_BU_swathe.png`` in the output
        directory.
    """
    inPath = os.path.join(config["data"]["results"], "ALL_swathes.csv")
    # NOTE(review): infer_datetime_format is deprecated from pandas 2.0
    # onwards; left in place because the pandas version in use is unknown.
    df = pd.read_csv(
        inPath, parse_dates=True, index_col=0, infer_datetime_format=True, dayfirst=True
    )
    xf = df.loc[df["paper"] == "DAIM", :].drop("paper", axis=1)
    plotType = "s"
    # Slicing (rather than indexing) the split tolerates column names that
    # contain no underscore instead of raising IndexError.
    colsToUse = [col for col in xf.columns if col.split("_")[1:2] == [plotType]]
    xf = xf[colsToUse].dropna(how="any")
    proxMin = xf["proxies_s_min"]
    proxMax = xf["proxies_s_max"]
    proxMean = xf["proxies_s_mean"]
    txtmetMean = xf["txtmetrics_s_mean"]
    plt.close("all")
    _, ax = plt.subplots()
    ax.fill_between(
        proxMax.index,
        proxMax,
        proxMin,
        color="#FFBAD2",
        zorder=1,
        label="Indicator swathe (min to max)",
    )
    ax.plot(
        txtmetMean.index,
        txtmetMean,
        color=colWheel[2],
        label="Mean of sentiment",
        linewidth=2.0,
        zorder=3,
    )
    ax.plot(
        proxMean.index,
        proxMean,
        lw=2.0,
        color="#672C1B",
        zorder=2,
        label="Mean of indicators",
        dashes=[5, 3],
    )
    ax.legend(frameon=True, loc="lower left", ncol=3, handlelength=2, fontsize=13)
    # The x-axis start is hard-coded; the end follows the data
    date_min = pd.to_datetime("01-01-2002")
    date_max = txtmetMean.index.max()
    for event_name, event_date in returnEventDict().items():
        date_of_event = pd.to_datetime(event_date)
        if (date_of_event > date_min) and (date_of_event < date_max):
            ax.axvline(
                x=date_of_event, linewidth=0.5, linestyle="-.", color=colWheel[2]
            )
            # Label sits three months left of the line, at 65% of axes height
            ax.annotate(
                event_name,
                xy=(date_of_event - relativedelta(months=3), 0.65),
                xycoords=("data", "axes fraction"),
                rotation=90,
                fontsize=11,
            )
    ax.yaxis.major.formatter._useMathText = True
    ax.set_ylim(-6, 6)
    ax.set_xlim(date_min, date_max)
    ax.yaxis.tick_right()
    ax.set_ylabel("Standard deviations from the mean", rotation=270)
    ax.yaxis.set_label_coords(1.1, 0.5)
    ax.yaxis.set_label_position("right")
    plt.title("Confidence indicators vs. newspaper sentiment", loc="right")
    if save:
        plt.tight_layout()
        outPath = os.path.join(config["data"]["output"], "Fig1_BU_swathe.png")
        plt.savefig(outPath, dpi=900)
    plt.show()


def table_of_targets_and_their_transforms(save=False):
    """
    Produces a table of all of the target variables and how they are
    transformed before use, and writes it as LaTeX to
    ``target_transforms.txt`` in the output directory.

    NB: this assumes that everything is re-sampled to monthly frequency
    as used in all predict* functions in algfcast and mlfcast
    NB: Transform names are hard-coded here. New transforms will not appear
    in this function.
    NB: when several transform flags are set for one variable, only the
    first of Log, Diff, Y-on-Y growth is reported.

    Parameters
    ----------
    save : bool, default False
        NOTE(review): currently ignored - the file is always written. Left
        unchanged because callers outside this view may rely on it; confirm
        and gate the export on this flag if that is the intent.
    """
    bmdata = gutil.BenchmarkData()
    # Retrieve targets that are being used:
    target_dict = dict(config["targets"])
    # Grab trafo data; copy so the column assignments below act on an
    # independent frame rather than a slice of bmdata.dftransforms
    trafo_df = bmdata.dftransforms[
        ["VariableName", "Log", "Diff", "YoYQGrowth", "Frequency"]
    ]
    trafo_df = trafo_df.loc[list(target_dict)].copy()
    trafo_df["Name"] = [target_dict[code] for code in trafo_df.index]
    yoygrowth = "Y-on-Y growth"
    trafo_df = trafo_df.rename(
        columns={"YoYQGrowth": yoygrowth, "VariableName": "Description"}
    )
    not_monthly = trafo_df["Frequency"] != "M"
    trafo_df.loc[not_monthly, "Frequency"] = (
        trafo_df.loc[not_monthly, "Frequency"] + ", up-sampled to M"
    )
    trafo_cols = ["Log", "Diff", yoygrowth]
    # Add a no transform column: 1 only when no transform flag is set.
    # (Swapping 0 and 1 on the row sum left a sum of 2 or more untouched,
    # which made idxmax label multi-transform rows as "None".)
    trafo_df["None"] = (trafo_df[trafo_cols].sum(axis=1) == 0).astype(int)
    trafo_df["Transform"] = trafo_df[trafo_cols + ["None"]].idxmax(axis=1)
    # Make MGDP consistent with other entries - hard-coded. Guarded so that
    # a phantom MGDP row is not appended when MGDP is not a configured target.
    if "MGDP" in trafo_df.index:
        trafo_df.loc["MGDP", "Transform"] = "3M-on-3M growth"
        trafo_df.loc["MGDP", "Description"] = "Gross Value Added: CVA SA"
    # reset_index exposes the index as the "IDcode" column, dropped here
    trafo_df = (
        trafo_df.reset_index()
        .drop(trafo_cols + ["None", "IDcode"], axis=1)
        .set_index("Name")
    )
    outPath = os.path.join(config["data"]["output"], "target_transforms.txt")
    trafo_df.to_latex(outPath, escape=False)


def produce_all_charts_and_tables():
    """
    Run, in a fixed order, every chart and table function of this module.

    Nothing is returned; each step is called for its side effects (figures
    shown and, where saving is on, files written to the output directory).
    """
    # =============================================================================
    # Outputs of analysis
    # =============================================================================
    saveAnalysis = True
    # saveAnalysis = False

    # Run settings shared by every forecast plot / error breakdown below
    trafo = "none"
    alpha = 36
    stepSize = 1
    expanding = False
    CV = False
    sentiment = True
    shared_settings = (trafo, alpha, stepSize, expanding, CV)

    # Descriptive stats - USED IN PAPER
    tableDescriptiveStats(save=saveAnalysis)
    # Swathe plots - USED IN PAPER
    swathePlotsCreate(save=saveAnalysis)
    # Full time period regressions on text mets
    for flag in (True, False):
        runRegTxtMet("MGDP", 9, flag, save=saveAnalysis)
    # Correlation heatmap - USED IN PAPER
    for flag in (False, True):
        corrHeatMapPlot("COMB", 3, flag, save=saveAnalysis)

    # -----------------------------------------------------------------
    # ALGORITHM RESULTS
    # -----------------------------------------------------------------
    # USED IN PAPER - plot of unemployment rate fcast for text metrics
    plotTSversusBenchmark(
        "LFSURATE",
        "tf_idf_econom",
        6,
        "COMB",
        *shared_settings,
        "OLS",
        "met_vs_ar1",
        saveAnalysis,
        sentiment,
    )
    # USED IN PAPER (std dev over horizons)
    for spec in ("met_vs_AR1", "met_vs_fctrs"):
        barChartFacet(spec, "ALL", saveAnalysis=saveAnalysis)
    # DM tables: first USED IN PAPER, second USED IN APPENDIX
    for spec in ("met_vs_AR1", "met_vs_fctrs"):
        DMTableAllPapers(spec, saveAnalysis=saveAnalysis)
    # ERROR BREAKDOWN
    run_breakdown(
        "MGDP",
        "COMB",
        *shared_settings,
        "met_vs_AR1",
        saveAnalysis,
        metrics_list=["stability", "tf_idf_econom"],
        models_list=["OLS"],
    )

    # -----------------------------------------------------------------
    # ML RESULTS
    # -----------------------------------------------------------------
    ml_paper = "DAIM"
    # Positional arguments of the MGDP tf-vector forecast plot up to and
    # including the specification; used for this plot and the BU post one
    ml_plot_settings = (
        "MGDP",
        "tf_matrix",
        3,
        ml_paper,
        *shared_settings,
        "NN",
        "tf_vs_AR1OLS",
    )
    # USED IN PAPER - plot of MGDP fcast for tf vector
    plotTSversusBenchmark(*ml_plot_settings, saveAnalysis, sentiment)
    # USED IN PAPER - DM stats for ML runs
    for spec in ("tf_vs_AR1", "tf_vs_fctrs", "tf_vs_fctrOLS", "tf_vs_AR1OLS"):
        combined_DM_table_ML(spec, saveAnalysis=saveAnalysis)
    # USED IN PAPER - BAR CHARTS FOR ML RUNS
    for spec in ("tf_vs_AR1", "tf_vs_fctrs", "tf_vs_AR1OLS", "tf_vs_fctrOLS"):
        ML_bar_chart_facet(spec, saveAnalysis=saveAnalysis)
    # ERROR BREAKDOWN ML
    run_breakdown(
        "MGDP",
        ml_paper,
        *shared_settings,
        "tf_vs_AR1OLS",
        saveAnalysis,
        metrics_list=["tf_matrix"],
        models_list=["NN", "Ridge", "Forest"],
    )

    # -----------------------------------------------------------------
    # BU post special plots (saving is always on for these)
    # -----------------------------------------------------------------
    BU_swathe_plot(True)
    # BU post - MGDP fcast for tf vector
    plotTSversusBenchmark(*ml_plot_settings, True, sentiment, BU_post=True)
    print("All charts and tables created!")
