#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

General utilities for analysis.
(forecasting utilities are elsewhere)

Can return all series (ALL) or the inner join of text metrics (COMB)

"""
import pandas as pd
import json
import configparser
import numpy as np
import os
import statsmodels.tsa.stattools as smts
from pandas.tseries.offsets import MonthEnd, QuarterEnd

# Module-level configuration, loaded once when this module is imported.
config = configparser.ConfigParser()
# Keep option names exactly as written in the file (ConfigParser would
# otherwise lower-case them).
config.optionxform = str
# Relative path, so it is looked up from the current working directory.
config.read("config.ini")


# =============================================================================
# Misc
# =============================================================================
def doubleListCheck(XX):
    """
    Make sure the argument is nested at least one level deep.

    If any element of XX is itself a list, XX is returned untouched;
    otherwise XX is wrapped in an outer list.
    """
    for element in XX:
        if isinstance(element, list):
            # Already holds a list: nothing to do
            return XX
    return [XX]


def singleListCheck(X):
    """
    If X is a string, return it wrapped in a one-element list;
    anything else is returned unchanged.

    isinstance is used so that str subclasses are treated as strings
    too, rather than being passed through as if they were lists.
    """
    if isinstance(X, str):
        return [X]
    return X


# =============================================================================
# IO and naming functions
# =============================================================================
def loadMetricsData(fileName):
    """
    Loads previously created text metrics by paper filename

    The first CSV column is used as the index and parsed as day-first
    dates; the returned dataframe is sorted by that index.

    Args:
        fileName: path (or file-like object) handed to pandas.read_csv
    Returns:
        Pandas dataframe sorted on its date index
    """
    # infer_datetime_format is deliberately not passed: it is deprecated
    # since pandas 2.0, where format inference became the default.
    df = pd.read_csv(
        fileName,
        parse_dates=True,
        index_col=0,
        dayfirst=True,
    )
    return df.sort_index()


def paperFname(paper):
    """
    Build the path of the text metrics file for one paper:
    the [data] intermed directory joined with the paper name plus
    the [fEnds] met_pa file ending.
    """
    intermedDir = config["data"]["intermed"]
    fileEnding = config["fEnds"]["met_pa"]
    return os.path.join(intermedDir, paper + fileEnding)


def nameConvert():
    """
    Assemble one dictionary mapping internal identifiers to display names.

    Layers are applied in this order, later layers overriding earlier
    ones on a key clash: papers -> text metrics -> proxies -> fixed
    column labels -> [targets] -> [high_dim] -> ML model names (each
    mapped to itself).
    """
    papers = json.loads(config.get("papers", "paper_list"))
    paperLabels = json.loads(config.get("papers", "paper_real_names"))
    metricNames = allMetricsDict()
    proxyNames = allProxiesDict()
    fixedLabels = {
        "paper": "Paper",
        "model": "Model",
        "metric": "Metric",
        "target": "Target",
        "horizon": "Horizon",
        "OLS": "OLS",
    }
    names = dict(zip(papers, paperLabels))
    layers = (
        metricNames,
        proxyNames,
        fixedLabels,
        dict(config["targets"]),
        dict(config["high_dim"]),
    )
    for layer in layers:
        names.update(layer)
    mlModels = json.loads(config.get("runSettings", "MLmodels"))
    names.update({model: model for model in mlModels})
    return names


def allProxiesDict():
    """
    Merge the [proxies_u] and [proxies_s] config sections into one dict
    (entries from proxies_s win on a key clash).
    """
    merged = {}
    for section in ("proxies_u", "proxies_s"):
        merged.update(config[section])
    return merged


def allMetricsDict():
    """
    Merge the [txtmetrics_u] and [txtmetrics_s] config sections into one
    dict (entries from txtmetrics_s win on a key clash).
    """
    combined = {}
    for sectionName in ("txtmetrics_u", "txtmetrics_s"):
        combined.update(config[sectionName])
    return combined


def allTargetsDict():
    """
    Return the [targets] config section as a plain dict
    """
    targets = {}
    targets.update(config["targets"])
    return targets


def paperConcat():
    """
    Puts metrics from different sources into dataframe in tidy form

    Each paper listed in [papers] paper_list is loaded with
    loadMetricsData, tagged with a "paper" column holding its name, and
    all papers are stacked row-wise.

    Returns:
        Pandas dataframe of all papers; empty if the paper list is empty
    """
    paper_list = json.loads(config.get("papers", "paper_list"))
    frames = []
    for paper in paper_list:
        dfHere = loadMetricsData(paperFname(paper))
        dfHere["paper"] = paper
        frames.append(dfHere)
    # Concatenate once rather than growing a dataframe inside the loop.
    # The leading empty frame mirrors the original accumulator: it keeps
    # the column sorting the same and makes an empty paper list yield an
    # empty dataframe instead of an error.
    return pd.concat([pd.DataFrame(), *frames], sort=True)


class BenchmarkData:
    """
    Holds the benchmark data series, one dataframe per sampling frequency.

    On construction this reads the transform table from
    config["benchmark"]["trafo"] and one CSV per frequency from
    config["benchmark"]["data"] + <sheet name> + ".csv", then optionally
    applies the per-series transforms via seriesTransformer.

    Args:
        globalTrafos (bool): passed on to seriesTransformer, where it
            enables whole-sample normalisation for series flagged
            "Normalise".
        localTrafos (bool): apply the local/backwards only transforms
            listed in the transform table.
        If both flags are False the raw series are kept.
    Returns:
        Object
    """

    def __init__(self, globalTrafos=False, localTrafos=True):
        self.freqSheetnameDict = {
            "Q": "Quarterly",
            "D": "Daily",
            "WD": "Weekdaily",
            "M": "Monthly",
        }
        bmdataPath = config["benchmark"]["trafo"]
        self.dftransforms = pd.read_csv(bmdataPath, header=0, index_col=0)
        colsToIntType = ["Log", "Diff", "Normalise", "YoYQGrowth"]
        # Builtin int: the np.int alias was removed in NumPy 1.24
        self.dftransforms[colsToIntType] = self.dftransforms[colsToIntType].astype(
            int
        )
        # Same column a second read of the file would give
        self.dfFreqs = self.dftransforms["Frequency"].copy()
        self.dfBD_dict = {}
        for freqKey, sheetName in self.freqSheetnameDict.items():
            # header=1: column names are on the second line of the file
            dfFreq = pd.read_csv(
                config["benchmark"]["data"] + sheetName + ".csv",
                header=1,
            )
            dfFreq = dfFreq.rename(columns={"Unnamed: 0": "date"})
            dfFreq["date"] = pd.to_datetime(dfFreq["date"], format="%d/%m/%Y")
            self.dfBD_dict[freqKey] = dfFreq.set_index("date")
        # Move monthly/quarterly stamps back to the preceding period end
        self.dfBD_dict["M"].index = self.dfBD_dict["M"].index + MonthEnd(-1)
        self.dfBD_dict["Q"].index = self.dfBD_dict["Q"].index + QuarterEnd(-1)
        if globalTrafos or localTrafos:
            for dictKey, dfFreq in self.dfBD_dict.items():
                # Transform rows belonging to this frequency only
                dfTrans = self.dftransforms.loc[
                    self.dftransforms["Frequency"] == dictKey, :
                ]
                for col in dfFreq.columns:
                    dfFreq[col] = seriesTransformer(
                        dfTrans, dfFreq[col], col, dictKey, globalTrafos=globalTrafos
                    )
        else:
            print("No transforms applied to BM data")

    def returnSeries(self, frequency="Q"):
        """
        Returns the benchmark data held for one frequency.
        Args:
            frequency (str):  Takes the following values:
                'Q':'Quarterly',
                'D':'Daily',
                'WD':'Weekdaily',
                'M':'Monthly'
        Returns:
            Pandas dataframe of benchmark data sheet
        """
        return self.dfBD_dict[frequency]

    def returnFrequency(self, seriesName):
        """
        Returns the "Frequency" entry of the transform table for seriesName
        """
        return self.dfFreqs[seriesName]

    def returnGlobalTransforms(self):
        """
        Returns the "Normalise" column of the transform table
        """
        return self.dftransforms.loc[:, "Normalise"]

    def returnAllSeries(self, freq):
        """
        Returns monthly, quarterly and weekdaily series side by side.

        The monthly frame is used as is, the quarterly frame is resampled
        to freq and the weekdaily frame is averaged within each freq bin.
        The daily ('D') frame is not included.

        Important bugfix kept here:
        vanilla interpolation fills NaNs indiscriminately; limit_area
        "inside" only allows this for NaNs between valid values.
        Interpolation is time based.
        NOTE that this only allows filling forwards
        of Q data to avoid data leakage
        Args:
            freq (str): pandas frequency string for the output
        Returns:
            Pandas dataframe with index named "date"
        """
        dfBM = pd.concat(
            [
                self.returnSeries(frequency="M"),
                (
                    self.returnSeries(frequency="Q")
                    .resample(freq)
                    .interpolate(
                        method="time", limit_direction="forward", limit_area="inside"
                    )
                ),
                self.returnSeries(frequency="WD").groupby(pd.Grouper(freq=freq)).mean(),
            ],
            axis=1,
        )
        dfBM.index.name = "date"
        return dfBM


def seriesTransformer(dfTrans, series, colname, freq, globalTrafos=False):
    """
    Applies the transforms flagged for colname in dfTrans to series.

    Flags are read from the dfTrans row labelled colname and applied in
    this order: "YoYQGrowth" (percentage change over one year of
    periods), "Log" (non-positive values become NaN), "Diff" (first
    difference), a x100 scaling when both Log and Diff are set, and
    "Normalise" (whole-sample z-score, only when globalTrafos is True).

    fill_method=None in pct_change is an important bug fix: it makes sure
    NaNs are not filled before the change is computed, which would
    create spurious data points.

    Args:
        dfTrans (pandas dataframe): transform flags, indexed by series name
        series (pandas series): data to transform
        colname (str): row label of series in dfTrans
        freq (str): 'Q' uses 4 periods per year, 'WD' 365, anything else 12
        globalTrafos (bool): allow the whole-sample normalisation
    Returns:
        Transformed pandas series
    """

    def flagSet(flag):
        return dfTrans.loc[colname, flag] == 1

    def logOrNan(x):
        return np.log(x) if x > 0 else np.nan

    # NOTE(review): 'D' falls through to the default of 12 - confirm
    # that this is intended for daily data.
    periodsPerYear = {"Q": 4, "WD": 365}
    numPeriods = periodsPerYear.get(freq, 12)

    if flagSet("YoYQGrowth"):
        series = series.pct_change(numPeriods, fill_method=None) * 1.0e2
    takeLog = flagSet("Log")
    if takeLog:
        series = series.apply(logOrNan)
    takeDiff = flagSet("Diff")
    if takeDiff:
        series = series.diff()
    if takeLog and takeDiff:
        # Log difference x100 is roughly a percentage change
        series = series * 1.0e2
    if flagSet("Normalise") and globalTrafos:
        centred = series - series.mean()
        series = centred.divide(series.std())
    return series


# =============================================================================
# Return time series funcs
# =============================================================================
def seriesNormalise(xf, seriesName):
    """
    Z-scores the listed columns of xf (subtract mean, divide by std).

    The columns are overwritten in xf itself, which is also returned.
    Args:
        xf (pandas dataframe): data to normalise
        seriesName (iterable): column names to normalise
    Returns:
        The same dataframe object, xf
    """
    for name in seriesName:
        column = xf[name]
        centred = column - column.mean()
        xf[name] = centred.divide(column.std())
    return xf


def combinePaperTimeSeries(freqIn="M"):
    """
    Creates mean series combining papers
     uses the mean of each paper in time,
     then the mean of those across papers
    Args:
        freqIn (str): pandas frequency string used to bin each paper
    Returns:
        df (dataframe of means using pooled data)
    """
    df = paperConcat()
    # numeric_only=True keeps the means restricted to numeric columns.
    # From pandas 2.0 the default includes the string "paper" column in
    # the second mean below, which raises a TypeError.
    xfb = df.groupby(["paper", pd.Grouper(freq=freqIn)]).mean(numeric_only=True)
    xfb = xfb.reset_index().set_index("date")
    # NOTE(review): "paper" is a column at this point and never equals 0,
    # so this all-zero row filter looks like a no-op - confirm the intent
    # before changing it, as that would alter the returned means.
    xfb = xfb.loc[~(xfb == 0).all(axis=1), :]
    xfbmean = xfb.reset_index().groupby("date").mean(numeric_only=True)
    return xfbmean


def getTFTimeSeries(
    nonmetrics, paper, norm=False, roll=False, bmLocalTrafo=True, freq="M"
):
    """
    Loads the tf matrix file of a paper (the [fEnds] tf_m file ending),
    drops columns that are entirely zero, averages within each freq bin
    and merges the result with the benchmark series in nonmetrics.

    Returns:
        Pandas dataframe as produced by combineWithBMData
    """
    tfPath = os.path.join(config["data"]["intermed"], paper + config["fEnds"]["tf_m"])
    tfFrame = loadMetricsData(tfPath)
    # Keep only the columns holding at least one non-zero entry
    hasNonZero = (tfFrame != 0).any(axis=0)
    tfFrame = tfFrame.loc[:, hasNonZero]
    binned = tfFrame.groupby(pd.Grouper(freq=freq)).mean()
    binned = binned.reset_index().set_index("date")
    return combineWithBMData(
        binned,
        list(binned.columns),
        nonmetrics,
        norm=norm,
        roll=roll,
        bmLocalTrafo=bmLocalTrafo,
        freq=freq,
    )


def getTimeSeries(
    metrics, nonmetrics, paper, norm=False, roll=False, bmLocalTrafo=True, freq="M"
):
    """
    Returns the requested text metrics merged with benchmark series.

    Args:
        metrics (str or list): text metric column(s) to keep
        nonmetrics (str or list): benchmark series to merge in
        paper (str): a name from [papers] paper_list, or "COMB" for the
            combination of all papers
        norm, roll, bmLocalTrafo, freq: passed on to combineWithBMData
    Raises:
        ValueError: if paper is neither a listed paper nor "COMB"
    Returns:
        Pandas dataframe as produced by combineWithBMData
    """
    # Ensure nonmetrics is a list
    nonmetrics = singleListCheck(nonmetrics)
    knownPapers = json.loads(config.get("papers", "paper_list"))
    if paper in knownPapers:
        # Case of a single paper
        source = loadMetricsData(paperFname(paper))
    elif paper == "COMB":
        # Case of combination of papers
        source = combinePaperTimeSeries(freqIn=freq)
    else:
        raise ValueError("Anomalous paper requested")
    binned = source.groupby(pd.Grouper(freq=freq)).mean()
    selected = binned[metrics].reset_index().set_index("date")
    return combineWithBMData(
        selected,
        metrics,
        nonmetrics,
        norm=norm,
        roll=roll,
        bmLocalTrafo=bmLocalTrafo,
        freq=freq,
    )


def combineWithBMData(
    xfb, metrics, nonmetrics, norm=False, roll=False, bmLocalTrafo=True, freq="M"
):
    """
    Inner-joins xfb with the benchmark data on "date" and returns the
    requested columns with incomplete rows removed.

    Args:
        xfb (pandas dataframe): text metrics indexed by "date"
        metrics (str or list): metric column(s) to keep; when a single
            string is given, rows where that metric equals 0 are dropped
        nonmetrics (str or list): benchmark column(s) to keep
        norm (bool): z-score every returned column
        roll (bool): apply a 3-period rolling mean
        bmLocalTrafo (bool): passed to BenchmarkData as localTrafos
        freq (str): pandas frequency string for the benchmark data
    Returns:
        Pandas dataframe indexed by "date"
    """
    bmdata = BenchmarkData(localTrafos=bmLocalTrafo)
    dfBM = bmdata.returnAllSeries(freq)
    xfb = pd.merge(xfb.reset_index(), dfBM.reset_index(), how="inner", on=["date"])
    xfb = xfb.set_index("date")
    if isinstance(metrics, str):
        xfb = xfb.loc[~(xfb[metrics] == 0), :]
    # Make sure both are lists
    nonmetrics = singleListCheck(nonmetrics)
    metrics = singleListCheck(metrics)
    xfb = xfb[metrics + nonmetrics]
    # This needs to be a punishing 'any'
    # in order to prevent ragged ends of series
    # causing problems.
    xfb = xfb.dropna(how="any")
    if norm:
        # Bind the result explicitly instead of relying on
        # seriesNormalise mutating its argument in place.
        xfb = seriesNormalise(xfb, xfb.columns)
    if roll:
        rollNum = 3
        xfb = xfb.rolling(rollNum).mean()
    return xfb


# =============================================================================
# Analysis functions
# =============================================================================
def augmentedDickeyFullerTestTable(df, metrics):
    """Runs an ADF test (constant only) on each column named in metrics
        WARNING: the percentages which are starred are:
            1%: ***, 5%: **, 10%: *
    Args:
        df (pandas dataframe):  must contain metrics as columns
        metrics (list): names of the series to test; NaNs are dropped
            from each series before testing
    Returns:
        Pandas dataframe with one column per series and the rows
        "ADF statistic" (formatted statistic plus stars) and "No. obs."
    """
    # Checked from the strictest level down; the first hit wins
    starLevels = (("1%", "***"), ("5%", "**"), ("10%", "*"))
    table = {}
    for name in metrics:
        adf = smts.adfuller(df[name].dropna(), regression="c", regresults=True)
        statistic, critValues, resultStore = adf[0], adf[2], adf[-1]
        stars = ""
        for level, marker in starLevels:
            if statistic < critValues[level]:
                stars = marker
                break
        table[name] = {
            "ADF statistic": "{:06.2f}".format(statistic) + stars,
            "No. obs.": resultStore.nobs,
        }
    return pd.DataFrame.from_dict(table)
