#!/usr/bin/env python
""" This script reconstructs the results presented in Tables 2.x.
"""
import warnings
warnings.filterwarnings("ignore")

import pickle as pkl
import numpy as np
import copy
import sys
import os

np.random.seed(123)
import respy

from respy.python.shared.shared_auxiliary import transform_disturbances
from respy.python.shared.shared_auxiliary import dist_class_attributes
from respy.python.shared.shared_auxiliary import dist_model_paras
from respy.python.shared.shared_auxiliary import get_total_value
from respy.python.shared.shared_auxiliary import create_draws

# module wide variables
# Directory that contains this script, with symbolic links resolved by realpath().
PROJECT_DIR = os.path.dirname(os.path.realpath(__file__))
# Strip the script's location inside the project from the path, presumably leaving the project
# root. The substring is written with forward slashes, so this assumes POSIX-style paths and is a
# no-op if the script lives anywhere else -- TODO confirm.
PROJECT_DIR = PROJECT_DIR.replace('/recomputation/correct_choices', '')
# Put the '_modules' directory first on the module search path so that the auxiliary_* modules
# imported right after this can be found.
sys.path.insert(0, PROJECT_DIR + '/_modules')

from auxiliary_shared import process_command_line
from auxiliary_shared import send_notification
from auxiliary_shared import enter_results_dir
from auxiliary_correct import write_correct
from auxiliary_shared import EXACT_DIR
from auxiliary_shared import SPEC_DIR
from auxiliary_shared import cleanup
from auxiliary_shared import mkdir_p


def _update_state(current_state, choice):
    """ Advance an agent's state in place according to the chosen alternative.

    The state is ordered as (exp_a, exp_b, edu, edu_lagged). Choices 0, 1, and 2 increment the
    state element with the same index, any other choice leaves these three elements untouched.
    The last element is set to one if alternative 2 was chosen and to zero otherwise.
    """
    if choice in (0, 1, 2):
        current_state[choice] += 1
    current_state[3] = 1 if choice == 2 else 0


def _compare_decisions(respy_obj, respy_obj_exact):
    """ Simulate agents and record where the approximate and exact solution agree.

    Arguments:
        respy_obj: the simulated object that holds the approximate solution.
        respy_obj_exact: the object that holds the exact solution.

    Returns:
        A tuple (success_indicators, num_agents_sim, num_periods). success_indicators is an array
        of shape (num_agents_sim, num_periods) whose entries are one where both solutions imply
        the same choice and zero otherwise.
    """
    # Snapshot of the approximate solution.
    respy_obj_inte = copy.deepcopy(respy_obj)

    # Distribute all class attributes that are independent of the solution method.
    periods_payoffs_systematic, mapping_state_idx, model_paras, num_periods, num_agents_sim, \
        states_all, edu_start, seed_sim, edu_max, delta = dist_class_attributes(
            respy_obj, 'periods_payoffs_systematic', 'mapping_state_idx', 'model_paras',
            'num_periods', 'num_agents_sim', 'states_all', 'edu_start', 'seed_sim', 'edu_max',
            'delta')

    # Auxiliary objects
    shocks_cholesky = dist_model_paras(model_paras, True)[4]

    # Extract the expected future values from the exact and approximated solution.
    periods_emax_exact = respy_obj_exact.get_attr('periods_emax')
    periods_emax_inter = respy_obj_inte.get_attr('periods_emax')

    # Draw draws for the simulation.
    periods_draws_sims = create_draws(num_periods, num_agents_sim, seed_sim, True)

    # Standard deviates transformed to the distributions relevant for the agents actual decision
    # making as traversing the tree.
    dimension = (num_periods, num_agents_sim, 4)
    periods_draws_sims_transformed = np.tile(np.nan, dimension)

    for period in range(num_periods):
        periods_draws_sims_transformed[period, :, :] = transform_disturbances(
            periods_draws_sims[period, :, :], shocks_cholesky)

    # Simulate a synthetic agent population and compare the implied decisions based on the exact
    # and approximate decision at each of the decision nodes.
    success_indicators = np.tile(np.nan, (num_agents_sim, num_periods))

    for i in range(num_agents_sim):
        # Every agent starts from the same initial state, the copy protects states_all from the
        # in-place updates below.
        current_state = states_all[0, 0, :].copy()

        # Iterate over each period for the agent
        for period in range(num_periods):

            # Distribute state space
            exp_a, exp_b, edu, edu_lagged = current_state
            k = mapping_state_idx[period, exp_a, exp_b, edu, edu_lagged]

            # Select relevant subset
            payoffs_systematic = periods_payoffs_systematic[period, k, :]
            draws = periods_draws_sims_transformed[period, i, :]

            # Get total value of admissible states
            total_payoffs_exact = get_total_value(
                period, num_periods, delta, payoffs_systematic, draws, edu_max, edu_start,
                mapping_state_idx, periods_emax_exact, k, states_all)

            total_payoffs_inter = get_total_value(
                period, num_periods, delta, payoffs_systematic, draws, edu_max, edu_start,
                mapping_state_idx, periods_emax_inter, k, states_all)

            # Determine optimal choices and record whether the implications agree between the
            # exact and approximate solutions.
            max_idx_exact = np.argmax(total_payoffs_exact)
            max_idx_inter = np.argmax(total_payoffs_inter)
            success_indicators[i, period] = (max_idx_exact == max_idx_inter)

            # Update work experiences, level of education, and lagged education according to
            # exact solution.
            _update_state(current_state, max_idx_exact)

    return success_indicators, num_agents_sim, num_periods


def run(is_debug, task, num_procs):
    """ Run a single request.

    The model is solved and simulated with the requested approximation settings inside a
    request-specific subdirectory. The choices implied by this solution are then compared to
    those implied by the exact solution stored under EXACT_DIR, and the agreement indicators are
    handed to write_correct().

    Arguments:
        is_debug: if true, the requested settings are replaced by a small interpolated setup
            (10 draws, 11 interpolation points). The directory name and the arguments passed to
            write_correct() still use the values from the task.
        task: tuple (which, num_draws_emax, num_points). which selects the specification and
            exact solution 'data_<which>', num_draws_emax is stored as the 'num_draws_emax'
            attribute, and num_points is either the value for 'num_points_interp' or the string
            'all' to turn interpolation off.
        num_procs: number of processes, more than one switches on the 'is_parallel' attribute.

    The working directory is restored to where it was on entry, even if the request fails.
    """
    # Distribute task
    which, num_draws_emax, num_points = task
    is_all = (num_points == 'all')

    dir_ = 'data_' + which + '/' + '%03.4d' % num_draws_emax

    if not is_all:
        dir_ += '_' + '%03.4d' % num_points
    else:
        dir_ += '_all'

    # Remember where we started so that we can return there no matter what happens below.
    start_dir = os.getcwd()

    mkdir_p(dir_)
    os.chdir(dir_)

    try:
        # Read the baseline specification.
        respy_obj = respy.RespyCls(SPEC_DIR + '/data_' + which + '.ini')

        # Ensure a speedy execution (if possible).
        respy_obj.unlock()
        respy_obj.set_attr('num_procs', num_procs)
        respy_obj.set_attr('is_parallel', (num_procs > 1))
        respy_obj.lock()

        # Get the solutions to the exact solution from previous results. NOTE: unpickling is only
        # safe for trusted files, this one is expected to stem from the project's own earlier run.
        fname = EXACT_DIR + '/data_' + which + '/solution.respy.pkl'
        with open(fname, 'rb') as infile:
            respy_obj_exact = pkl.load(infile)

        selected_periods = [1, 10, 20, 30, 40]
        selected_ranges = [(1, 10), (11, 35), (36, 38), (39, 39), (40, 40)]

        # Solve the dynamic programming model for only a subset of states.
        respy_obj.attr['num_draws_emax'] = num_draws_emax

        if not is_all:
            respy_obj.attr['num_points_interp'] = num_points
            respy_obj.attr['is_interpolated'] = True
        else:
            respy_obj.attr['is_interpolated'] = False

        # Debugging setup
        if is_debug:
            respy_obj.attr['is_interpolated'] = True
            respy_obj.attr['num_draws_emax'] = 10
            respy_obj.attr['num_points_interp'] = 11

        # If I am running the first specification, I need to make sure that the limited
        # interpolation model is used. This is signaled by the presence of an (empty) file.
        if which == 'one':
            open('.structRecomputation.tmp', 'a').close()

        respy_obj.write_out()
        respy_obj = respy.simulate(respy_obj)

        success_indicators, num_agents_sim, num_periods = _compare_decisions(
            respy_obj, respy_obj_exact)
    finally:
        # Return to request directories.
        os.chdir(start_dir)

    # Write out results for all requests to the corresponding table.
    args = tuple()
    args += (num_points, num_draws_emax, success_indicators)
    args += (selected_periods, which, num_agents_sim, num_periods)
    args += (selected_ranges,)
    write_correct(*args)

''' Execution of module as script.
'''

if __name__ == '__main__':

    # Parse the command line arguments.
    is_debug, num_procs = process_command_line('Assess correct decisions.')

    # Switch to RSLT_DIR. This separates the results from the source files and eases the updating
    # from the compute servers.
    source_dir = enter_results_dir('correct_choices')

    # Each request is a pair (num_draws_emax, num_points) as unpacked by run(). The debugging
    # setup replaces the full grid by a single small request for a single specification.
    if is_debug:
        data = ['one']
        requests = [(20, 200)]
    else:
        data = ['one', 'two', 'three']
        requests = [(2000, 'all'), (1000, 'all'), (250, 'all'), (2000, 2000), (2000, 500)]

    # Build the task grid, every specification is combined with every request.
    tasks = [(which, ) + request for which in data for request in requests]

    cleanup()

    # The tasks are processed one after another, num_procs is passed on to each request.
    for task in tasks:
        run(is_debug, task, num_procs)

    send_notification('correct')

    os.chdir(source_dir)