"""Test cases for key components of the code."""
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
import src.fcastutils as futil
import src.generalutils as gutil
import src.algfcast as algfcast
import pytest


def test_fcasting_env():
    """
    Check the rolling-window forecasting environment against a manual fit.

    Builds an 8-row monthly toy frame and runs it through the main
    machinery: ``futil.txtMetricTransformer``, ``futil.prepForModelling``,
    ``futil.forecastTool`` and ``futil.collateRunResults``. It then checks
    that, for the first rolling window (``mu == 1``), the in-sample and
    out-of-sample predictions equal those from fitting ``LinearRegression``
    by hand on the same rows. The comparison is made both in the time of
    the features and in the time of the target variable (as produced by
    ``collateRunResults``).

    Settings: ``horizon`` is the number of steps ahead to forecast,
    ``alpha`` is the size of the window, ``mu`` is the number (starting
    from 1) of the current window, and ``stepSize`` is how far the window
    moves on each iteration.
    """
    # Known toy data: 8 monthly rows labelled t-1 ... t+6; x runs -1..6 and
    # y is a small hand-picked series.
    df = pd.DataFrame(
        {
            "Time": ["t-1", "t", "t+1", "t+2", "t+3", "t+4", "t+5", "t+6"],
            "x": list(range(-1, 7)),
            "y": [-1, 0, 1, 2, 0, 1, 1, 2],
        }
    )
    # NOTE(review): this names the default RangeIndex, which set_index below
    # replaces, so the name has no lasting effect.
    df.index.name = "Time"
    # Month-end dates from 2000-01-31 (t-1) through 2000-08-31 (t+6).
    # NOTE(review): pandas >= 2.2 deprecates freq="M" in favour of "ME".
    df["TimeIndex"] = pd.date_range("31/1/2000", periods=len(df), freq="M")
    df = df.set_index("TimeIndex")
    yvar = "y"
    target = "target"
    metric = "x"
    lagString = yvar + "_lag_1"
    metricList = [metric]
    controls = [lagString]
    horizon = 1
    # One lag of yvar (used as a control); NaN in the first row (t-1).
    df[lagString] = df[yvar].shift(1)
    # Target is yvar `horizon` steps ahead; NaN in the last row (t+6).
    df[target] = df[yvar].shift(-horizon)
    # Full run specification, packed in the order collateRunResults expects
    # after its first three positional arguments:
    model_name = "OLS"
    CV = False
    specification = "test"
    paper, trafo, alpha, stepSize, expanding = "None", "none", 3, 1, False
    packagedRunSettings = (
        yvar,
        metric,
        horizon,
        paper,
        trafo,
        alpha,
        stepSize,
        expanding,
        CV,
        model_name,
        specification,
    )
    run_type = "test"
    allExog = metricList + controls
    # As if real-time transformation of text metrics (trafo is "none" here):
    df[metricList] = futil.txtMetricTransformer(df[metricList], alpha, stepSize, trafo)
    # Drop yvar from the forecasting frame so it cannot leak into features.
    xf = df.drop(yvar, axis=1)
    # Forecasting setup ----------------------------------
    # Get the datetime index that will be used throughout.
    # NOTE(review): presumably prepForModelling drops rows with NaNs (the
    # t-1 lag and the t+6 target), leaving t..t+5 — TODO confirm.
    xf = futil.prepForModelling(xf, allExog + [target])
    inputDataIndex = xf.index
    # These are the bits we are really testing: this returns the forecast
    # variables (target^hat, though it's called target in the dataframe)
    # in the time of the features (as opposed to the time of the
    # target variable). It does so for different 'mu' and both in-sample (IS)
    # and OOS (IS is False in the dataframe).
    allResTxtmet = futil.forecastTool(
        xf.loc[inputDataIndex, [target] + allExog],
        target,
        allExog,
        alpha,
        stepSize,
        expanding,
        CV,
        model_name,
    )
    # Now we get the collated results with all info on what happened. These
    # results are in the time of the target variable so, to get back to the
    # time of the feature variables, one would need to use the horizon column
    # to shift the target data backwards
    summarisedTxtMet = futil.collateRunResults(
        allResTxtmet, df[yvar], inputDataIndex, *packagedRunSettings, "metric"
    )
    # We now test this setup with a simple case: the first window of a rolling
    # window OLS model.
    mu_period = 1  # First window period
    start_pos = (mu_period - 1) * stepSize  # 0 for the first window
    end_pos = mu_period * stepSize + alpha  # 1 + 3 = 4 for the first window,
    # so the slice [start_pos:end_pos] holds alpha + stepSize rows.
    # Run the linear reg:
    reg = LinearRegression().fit(
        xf.loc[inputDataIndex[start_pos:end_pos], allExog],
        xf.loc[inputDataIndex[start_pos:end_pos], [target]],
    )
    # Run an in-sample prediction using the time of the features
    # (so predicting y at y_{t+h})
    y_tph_hat_IS = reg.predict(xf.loc[inputDataIndex[start_pos:end_pos], allExog])
    # Put this into a nicer format
    y_IS = pd.DataFrame(
        data=y_tph_hat_IS,
        index=inputDataIndex[start_pos:end_pos],
        columns=["y_tph_hat_IS"],
    )
    # Shift forward one period to express the prediction in y's own time
    # (horizon is 1 here); the first row becomes NaN.
    y_IS["y_hat_IS"] = y_IS["y_tph_hat_IS"].shift(1)
    # Perform same operations for out-of-sample, which is everything from end
    # of the first rolling window onwards
    y_tph_hat_OS = reg.predict(
        xf.loc[inputDataIndex[mu_period * stepSize + alpha:], allExog]
    )
    y_OS = pd.DataFrame(
        data=y_tph_hat_OS,
        index=inputDataIndex[mu_period * stepSize + alpha:],
        columns=["y_tph_hat_OS"],
    )
    y_OS["y_hat_OS"] = y_OS["y_tph_hat_OS"].shift(1)
    # Now we can compare the IS predictions in feature time:
    # grab the relevant slice of all results
    all_res_is = allResTxtmet[
        ((allResTxtmet["IS"] == True) & (allResTxtmet["mu"] == 1))
    ]
    # Check these are the same.
    # NOTE(review): Series == Series requires identically-labelled indices
    # (pandas raises ValueError otherwise), the comparison is exact float
    # equality, and all() passes vacuously if both sides are empty.
    assert all(all_res_is["target"] == y_IS["y_tph_hat_IS"])
    # Now do a similar process for out of sample and feature time
    all_res_os = allResTxtmet[
        ((allResTxtmet["IS"] == False) & (allResTxtmet["mu"] == 1))
    ]
    assert all(all_res_os["target"] == y_OS["y_tph_hat_OS"])

    # Now let's check against the same series transformed into the time of y,
    # which is what happens with summarisedTxtMet. summarisedTxtMet contains
    # answers from mu=2 period also, so first job is to drop those as we only
    # manually created the mu=1 period results in this test.
    # First, in-sample: mu=1 IS dates that also appear in the summary.
    mu_one_ix_IS = [
        x
        for x in allResTxtmet[
            (allResTxtmet["mu"] == 1) & (allResTxtmet["IS"] == True)
        ].index
        if x in summarisedTxtMet.index
    ]
    sum_res_is = summarisedTxtMet.loc[pd.to_datetime(mu_one_ix_IS), :].dropna(
        subset=["IS_prediction"]
    )
    # Now out of sample
    mu_one_ix_OS = [
        x
        for x in allResTxtmet[
            (allResTxtmet["mu"] == 1) & (allResTxtmet["IS"] == False)
        ].index
        if x in summarisedTxtMet.index
    ]
    sum_res_os = summarisedTxtMet.loc[pd.to_datetime(mu_one_ix_OS), :].dropna(
        subset=["OOS_prediction"]
    )
    # And compare; dropna() removes the leading NaN introduced by shift(1).
    assert all(sum_res_is["IS_prediction"] == y_IS["y_hat_IS"].dropna())
    assert all(sum_res_os["OOS_prediction"] == y_OS["y_hat_OS"].dropna())
    # Final part of testing is the simplest check you can imagine:
    # whether the yhat_OS values are as we would expect for the input data,
    # absent transforms & with only the first rolling window period.
    # Get cut of df to apply regression to (manually, but based on settings
    # above): 2000-02-29 (t) to 2000-05-31 (t+3), inclusive label slicing.
    time_start = pd.to_datetime("2000-02-29")
    time_end = pd.to_datetime("2000-05-31")
    reg_simple = LinearRegression().fit(
        df.loc[time_start:time_end, allExog], df.loc[time_start:time_end, [target]],
    )
    # Get rest of data that isn't nan and use it to predict (NB: we are using
    # feature-space dates here). time_end + MonthEnd() is 2000-06-30, the
    # first date after the window; dropping NaN targets removes t+6.
    y_hat_OS_simple = reg_simple.predict(
        df.dropna(subset=["target"]).loc[time_end + pd.offsets.MonthEnd():, allExog]
    )
    # Check that this matches out of sample prediction obtained by more
    # elaborate means
    assert all(y_hat_OS_simple.flatten() == y_OS["y_tph_hat_OS"].values)


def test_target_series_properly_retrieved_during_fcast():
    """
    When making a forecast with a particular target series, ensure that the
    target series does not get mangled when the forecast happens.

    The ``target_value`` column returned by ``algfcast.predictvsAR1`` is
    inner-joined on the index with the raw target series from
    ``gutil.getTimeSeries``. The join must be non-empty, and every
    overlapping value must match exactly.
    """
    target, metric, horizon, paper = 'CPIall', 'stability', 3, 'GRDN'
    trafo, alpha, stepSize, expanding = 'none', 36, 1, False
    df = gutil.getTimeSeries(metric, target, paper, freq='M')
    fcastSettings = (target, metric, horizon, paper,
                     trafo, alpha, stepSize, expanding)
    # AR1 benchmark run
    res = algfcast.predictvsAR1(*fcastSettings)
    # pd.merge defaults to an inner join, so only shared dates survive.
    combo = pd.merge(res["target_value"], df[target],
                     left_index=True, right_index=True)
    # Without this, a join with no overlapping dates would make the all()
    # below pass vacuously.
    assert not combo.empty
    assert all(combo["target_value"] == combo[target])


def _assert_is_oos_never_overlap(slice_insample, slice_oosample):
    """
    Check that in-sample and out-of-sample slices never share a window period.

    The helper builds a 7-row toy frame and slices it with
    ``slice_insample`` and ``slice_oosample`` for every window period
    ``mu``. It labels each slice's rows with ``'IS_y_hat <mu>'`` or
    ``'OOS_y_hat <mu>'`` and summarises them with
    ``futil.script_ISOOS_summary``. That summary is assumed to expose
    ``IS_prediction`` and ``OOS_prediction`` columns holding those labels.

    For each summary row, the window number behind the in-sample label must
    differ from the one behind the out-of-sample label. Rows where either
    label is missing pass, since NaN compares unequal to everything.

    Parameters
    ----------
    slice_insample, slice_oosample : callable
        ``f(X_feat, mu, alpha, step)`` returning the in-sample /
        out-of-sample rows of ``X_feat`` for window period ``mu``.
    """
    xf = pd.DataFrame(np.array([['t-1', 't', 't+1', 't+2', 't+3', 't+4', 't+5'],
                                ['a', 'b', 'c', 'd', 'e', 'f', 'g'],
                                [0, 1, 2, 3, 4, 5, 6]]).T,
                      columns=['Time', 'y', 'x'])
    xf.index.name = 'time'
    lag = 1
    xf['ylag'] = xf['y'].shift(lag)
    X_feat = xf[['x', 'ylag']]
    alpha = 1
    step = 1
    # Floor division instead of np.int(...), which NumPy 1.24 removed.
    numSteps = (len(xf) - step - alpha) // step
    frames = []
    for mu in range(1, numSteps + 1):
        # .copy() so label columns go onto an independent frame rather than
        # a possible view of X_feat (avoids SettingWithCopyWarning).
        xf_IS = slice_insample(X_feat, mu, alpha, step).copy()
        xf_IS['mu'] = mu
        xf_IS['IS'] = True
        xf_IS['target'] = f'IS_y_hat {mu}'
        xf_OOS = slice_oosample(X_feat, mu, alpha, step).copy()
        xf_OOS['mu'] = mu
        xf_OOS['IS'] = False
        xf_OOS['target'] = f'OOS_y_hat {mu}'
        frames.extend([xf_IS, xf_OOS])
    # One concat instead of repeatedly growing an initially empty frame.
    allResultsDf = pd.concat(frames, sort=True, axis=0)
    final_results = futil.script_ISOOS_summary(allResultsDf)
    for col in ('IS_prediction', 'OOS_prediction'):
        # Labels look like 'IS_y_hat 3'; keep only the window number. Rows
        # without a label become NaN when the result is aligned back.
        final_results[col + ' mu'] = (
            final_results[col].dropna().str.split(' ').str[1])
    assert all(final_results['IS_prediction mu']
               != final_results['OOS_prediction mu'])


def test_that_IS_and_OOS_never_overlap_r():
    """
    We never want the in-sample and out-of-sample slices of the time series
    data to overlap for the same rolling window period. (It's fine if ones
    from different rolling window periods overlap). This test checks whether
    there is any overlap.
    """
    _assert_is_oos_never_overlap(futil.slice_insample_r,
                                 futil.slice_oosample_r)


def test_that_IS_and_OOS_never_overlap_e():
    """
    We never want the in-sample and out-of-sample slices of the time series
    data to overlap for the same expanding window period. (It's fine if ones
    from different expanding window periods overlap). This test checks
    whether there is any overlap.
    """
    _assert_is_oos_never_overlap(futil.slice_insample_e,
                                 futil.slice_oosample_e)