import constants;

# Override the project-wide constants.MAX_TIME before the project modules below
# are imported (presumably so they see the value at import time — TODO confirm).
constants.MAX_TIME = 14400
# Window handed to the analysis.* calls as startTime_ / timeInterval_.
START_TIME = 0
TIME_INTERVAL = 14400

import crGraphs;
import crData;

import helpers;
import numpy;

import matplotlib;
import analysis;
from copy import deepcopy;
import math;


def getExtraParamString(directory_):
    """Return the extra parameter-name suffix for result files in ``directory_``.

    Matching is by substring, checked in order; the more specific
    'localRec_relEE_persist' marker must be tested before 'localRec', which it
    contains. Directories matching neither marker get the default suffix.
    """
    suffixRules = (
        ('localRec_relEE_persist', '_sigD60_TOmin1020_pS0.1'),
        ('localRec', '_sigD60_TOmin0_pS0.001'),
    )
    for marker, suffix in suffixRules:
        if marker in directory_:
            return suffix
    return '_sigD1000_TOmin0_pS0.001'

def getModelLabel(directory_):
    """Return the human-readable model name used in plot legends.

    The label is chosen by substring match on the experiment directory name,
    most specific marker first:

    * 'PEL_basicModel' -> 'basic model'
    * 'PELNONSOC'      -> 'non-social model' (previously fell through to the
      generic 'PEL' test and was mislabelled 'self-regulating model', which
      gave both curves the same label in the non-social comparison)
    * 'PEL'            -> 'self-regulating model'
    * anything else    -> ''

    Args:
        directory_: experiment directory name, e.g. 'PEL_basicModel'.

    Returns:
        The legend label string.
    """
    # Local named `label` rather than `str` so the builtin is not shadowed.
    if 'PEL_basicModel' in directory_:
        label = 'basic model'
    elif 'PELNONSOC' in directory_:
        label = 'non-social model'
    elif 'PEL' in directory_:
        label = 'self-regulating model'
    else:
        label = ''
    return label

def getMaxMinCollectedResurce(scenarioType_,fP_):
    """Return the [min, max] value range for the resources-collected 4D plot.

    The range is passed to crGraphs.create4Dplot as valueRangeMin_/valueRangeMax_.

    Args:
        scenarioType_: scenario type name; currently unused (the same range is
            applied to every scenario type), kept for signature parity with the
            sibling range helpers.
        fP_: the fP parameter value of the plotted runs.

    Returns:
        A new two-element list ``[min, max]``.

    Raises:
        ValueError: if no range is defined for ``fP_``. Previously the function
            fell through and returned ``None``, which made the caller fail with
            an opaque "'NoneType' object is not subscriptable".
    """
    if fP_ == 100:
        return [300, 400]
    if fP_ == 200:
        return [150, 250]
    raise ValueError("no collected-resource value range defined for fP=" + str(fP_))

def getMaxMinEnergy(scenarioType_,fP_):
    """Return the [min, max] value range for the energy-spent 4D plot.

    Both arguments are accepted for signature parity with the sibling range
    helpers but are ignored: the same range is returned for every input.
    """
    lowerBound, upperBound = 10000, 15000
    return [lowerBound, upperBound]

def getMaxMinResourceEnergyCost(scenarioType_,fP_):
    """Return the [min, max] value range for the energy-per-resource 4D plot.

    The range is passed to crGraphs.create4Dplot as valueRangeMin_/valueRangeMax_.

    Args:
        scenarioType_: scenario type name; currently unused, kept for signature
            parity with the sibling range helpers.
        fP_: the fP parameter value of the plotted runs.

    Returns:
        A new two-element list ``[min, max]``.

    Raises:
        ValueError: if no range is defined for ``fP_`` (previously ``None`` was
            returned implicitly and the caller crashed when indexing it).
    """
    if fP_ == 100:
        return [30, 45]
    if fP_ == 200:
        return [60, 75]
    raise ValueError("no resource-energy-cost value range defined for fP=" + str(fP_))



# ---- Settings pushed into the project modules ----
crGraphs.DPI = 100
constants.NUM_RUNS = 50

# ---- Resource-collection / behaviour analyses ----
# DO_RES_COLLECTED also makes the matrix graph with congestion vs
# non-congestion for the basic model only.
DO_RES_COLLECTED = True
DO_RES_COLLECTED_PER_ROBOT = False
DO_RES_COLLECTED_OVER_TIME = False
DO_RES_COLLECTED_DIFF = False
DO_RES_PER_LOADING = False
DO_LOAD_EVENTS_RECRUITMENT = False
DO_SCOUT_EVENTS_SUCCESS = False
DO_NEIGHB_SEARCH_SUCESS = False
DO_ROBOT_STATES = False
DO_PELLETS = False

# ---- Energy flags ----
DO_ENERGY_SPENT = False
DO_RES_ENERGY_COST = False
DO_RES_COLLECTED_ENERGY_CAPPED = False

# ---- Load-event / waggle-dance flags ----
DO_LOAD_EVENTS_QUALITY = False
DO_LOAD_EVENT_GROUPS = False
DO_LOAD_EVENTS = False
DO_WAGGLE_DANCE_GROUPS = False
DO_LOAD_EVENT_GROUPS_OVER_TIME = False

# ---- Timing flags ----
DO_TIME_OF_FIRST_LOADS = False
DO_UNLOAD_TIME_OVER_TIME = False
DO_TIME_BETWEEN_LOAD_UNLOAD = False
DO_UNLOAD_TIME = False

# ---- Info-transfer flags ----
DO_INFO_TRANSFER = False
DO_INFO_TRANSFER_REGIONS_GRID = False

# ---- Which experiment family the __main__ section analyses ----
DO_BASIC_MODEL = True
DO_NON_SOCIAL_MODEL = False
DO_PARAMETER_SWEEP = False

# ---- Comparisons against other models ----
DO_COMPARISON_WITH_A = False
DO_COMPARISON_WITH_NON_SOCIAL_MODEL = False
DO_ENERGY_CAPPED_COMPARISON = False

DO_COMPARISON_WITH_B = False

# ---- Which parameter is swept (TU or p(W)) ----
DO_TU = True
DO_PW = False


def getDistanceAxisLabel(distances_):
    """Build the x-axis label that brackets each distance group of scenarios.

    One bracket of the form ``|----  D=<d>  ----|`` is emitted per entry of
    ``distances_``, in order, separated by spaces. The dash/gap widths
    approximate the two hand-tuned labels this script used to hard-code
    (three groups: 19 dashes and a 7-space gap; otherwise 17 dashes and an
    11-space gap).

    Args:
        distances_: sequence of distance values, e.g. [5, 7, 9].

    Returns:
        The label string ("" for an empty sequence).
    """
    dashCount, gapWidth = (19, 7) if len(distances_) >= 3 else (17, 11)
    dashes = "-" * dashCount
    return (" " * gapWidth).join(
        "|" + dashes + "  D=" + str(d) + "  " + dashes + "|" for d in distances_
    )


if __name__ == "__main__":
    # pyplot is used below (close("all")); `import matplotlib` alone does not
    # load the pyplot submodule, so import it explicitly.
    import matplotlib.pyplot

    NR = 50
    # Earlier runs used e.g. [1,50,100,200], [50,100,200], [50,200] or [100,200].
    fPVals = [1, 50, 200]
    # Derived from fPVals so the legend cannot drift out of sync with the values.
    fPValLegend = ["fP=" + str(v) for v in fPVals]

    # ---- Reference model ("model A") for comparisons ----
    if DO_COMPARISON_WITH_NON_SOCIAL_MODEL:
        experimentDirModelA = "PELNONSOC"
        modelALabel = getModelLabel(experimentDirModelA)
        extraParamNameStrModelA = '_pW0.0001_tDsigD90_TU600'
        DO_COMPARISON_WITH_A = True
    else:
        experimentDirModelA = "PEL_basicModel"
        modelALabel = getModelLabel(experimentDirModelA)
        extraParamNameStrModelA = ''

    # ---- Parameter values for the TU sweep or the p(W) sweep ----
    if DO_TU:
        tDSigDVals = [90]
        pWVals = [0.0001]
        TUVals = [200, 400, 600, 800, 1000, 1200]
        if DO_NON_SOCIAL_MODEL:
            TUVals = [400, 600, 800, 1000, 1200]
            pWVals = [0.01, 0.001, 0.0001]
    elif DO_PW:
        tDSigDVals = [90]
        pWVals = [0.01, 0.001, 0.0001, 0.0]
        TUVals = [800]

    energyCapVals = [2000, 4000, 6000, 8000, 10000, 12000, 15000]
    # Caps of 14400 or more are labelled "unlimited".
    # NOTE(review): 14400 equals the constants.MAX_TIME value this script sets;
    # confirm whether the threshold is meant to track it.
    energyCapLegend = [str(cap) if cap < 14400 else "unlimited" for cap in energyCapVals]

    selectedParameterSetStr = "pW0.0001_tDsigD90_TU800"
    if DO_NON_SOCIAL_MODEL:
        selectedParameterSetStr = "pW0.0001_tDsigD90_TU600"

    distances = [5, 7, 9]

    # ---- Experiment directory and scenario selection ----
    if DO_BASIC_MODEL:
        EXPERIMENT_DIR = 'PEL_basicModel'
        scenarioTypes = ['Heap1', 'Heap4', 'Scatter10', 'Scatter25']
    elif DO_NON_SOCIAL_MODEL:
        EXPERIMENT_DIR = 'PELNONSOC'
        scenarioTypes = ['Heap2', 'Heap4', 'Scatter25']
    else:
        EXPERIMENT_DIR = 'PEL'
        scenarioTypes = ['Heap2', 'Heap4', 'Scatter25']

    if DO_COMPARISON_WITH_A:
        scenarioTypes = ['Heap1', 'Heap4', 'Scatter10', 'Scatter25']
        if DO_ENERGY_CAPPED_COMPARISON:
            scenarioTypes = ['Heap1', 'Scatter25']
            distances = [5, 9]

    if DO_LOAD_EVENT_GROUPS_OVER_TIME:
        scenarioTypes = ['Heap2']
        distances = [9]

    if DO_ROBOT_STATES or DO_PELLETS:
        scenarioTypes = ['Heap1', 'Heap2', 'Scatter25']
        distances = [7]
        pWVals = [0.0001]
        tDSigDVals = [90]
        TUVals = [600] if DO_NON_SOCIAL_MODEL else [800]

    examinedModelLabel = getModelLabel(EXPERIMENT_DIR)
    extraParamNameStr = ''

    IMAGE_OUTPUT_DIR = EXPERIMENT_DIR + "/" + EXPERIMENT_DIR

    scenarioInfo = helpers.getScenariosBasedOnTypes(scenarioTypes, distances)
    scenarioLegends = scenarioInfo[0]
    scenarios = scenarioInfo[1]
    # Distance-major order: every scenario type for the first distance, then the next, ...
    scenarioLegendsWithDistances = [
        scType + ", D=" + str(d) for d in distances for scType in scenarioTypes
    ]
    print(scenarioLegendsWithDistances)

    # Derived from `distances`; it used to be hard-coded for D=5/D=9 only,
    # which mislabelled plots made with the default three distances.
    xLabel = getDistanceAxisLabel(distances)

    matplotlib.pyplot.close("all")

    if DO_BASIC_MODEL:
        # NOTE(review): this call passes an extra positional `False` before the
        # numeric argument (2000), whereas the DO_TU/DO_PW calls below pass 400
        # directly after DO_INFO_TRANSFER_REGIONS_GRID. Check both against the
        # analysis.analyse signature.
        analysis.analyse(
            "fP", fPVals, fPValLegend, scenarios, scenarioLegends, [""], EXPERIMENT_DIR, IMAGE_OUTPUT_DIR,
            DO_RES_COLLECTED, DO_RES_COLLECTED_PER_ROBOT, DO_LOAD_EVENTS_RECRUITMENT, DO_SCOUT_EVENTS_SUCCESS,
            DO_NEIGHB_SEARCH_SUCESS, DO_ROBOT_STATES, DO_PELLETS, DO_LOAD_EVENT_GROUPS, DO_UNLOAD_TIME, DO_ENERGY_SPENT, DO_RES_ENERGY_COST,
            DO_WAGGLE_DANCE_GROUPS, DO_LOAD_EVENT_GROUPS_OVER_TIME,
            DO_INFO_TRANSFER, DO_INFO_TRANSFER_REGIONS_GRID, False,
            2000, xLabel_=xLabel, timeInterval_=TIME_INTERVAL, startTime_=START_TIME,
            afterParamNameStr_="_NR" + str(NR),
            scenarioTypes_=scenarioTypes, distances_=distances,
            size_=(20, 6))

    elif DO_PARAMETER_SWEEP:
        # Axis values; each pointData entry below is [TU index, pW index, tDsigD index]
        # into xTickVals, yTickVals and zTickVals respectively.
        xTickVals = [400, 600, 800, 1000, 1200]
        xTickLabels = [400, 600, 800]
        xLabel = "unl. time threshold TU"
        yTickVals = [0.1, 0.01, 0.001, 0.0001, 0.0]
        yTickLabels = [0.1, 0.01, 0.001, 0.0001]
        yLabel = "waking probability p(W)"
        zTickVals = [60, 90, 120, 1000]
        zTickLabels = [60, 90, 120, 1000]
        zLabel = "TD signal range"

        pointData = [
            #-- pW=0.01
            [0, 1, 0], [1, 1, 0], [2, 1, 0],
            [0, 1, 1], [1, 1, 1], [2, 1, 1],
            [0, 1, 2], [1, 1, 2], [2, 1, 2],
            #-- pW=0.001
            [0, 2, 0], [1, 2, 0], [2, 2, 0],
            [0, 2, 1], [1, 2, 1], [2, 2, 1],
            [0, 2, 2], [1, 2, 2], [2, 2, 2],
            #-- TDsigD=1000
            [0, 0, 3], [2, 0, 3],
            [0, 1, 3], [2, 1, 3],
            [0, 2, 3], [2, 2, 3],
            #-- pW=0.0001
            [0, 3, 1], [1, 3, 1], [2, 3, 1],
            [0, 3, 2], [1, 3, 2], [2, 3, 2],
            # Not yet included: pW=0.0 -> [2,4,1];
            # high TU for sigD=90, pW=0.0001 -> [3,3,1], [4,3,1].
        ]

        for fPVal in fPVals:
            for scenarioType in scenarioTypes:
                # Include only one scenario type, but all distances, in the measure.
                scenariosToInclude = helpers.getScenariosBasedOnTypes([scenarioType], distances)[1]
                print("-------- scenarios: " + str(scenariosToInclude))

                graphDataResCollected = deepcopy(pointData)
                graphDataEnergy = deepcopy(pointData)
                graphDataResEnergyCost = deepcopy(pointData)

                for pointIndex, (tuIndex, pWIndex, sigDIndex) in enumerate(pointData):
                    paramStr = ("fP" + str(fPVal) + "_NR" + str(NR)
                                + "_pW" + str(yTickVals[pWIndex])
                                + "_tDsigD" + str(zTickVals[sigDIndex])
                                + "_TU" + str(xTickVals[tuIndex]))
                    # Data for this parameter point, for all included scenarios.
                    resCollectedData = helpers.getResourceCollectedForParameter(paramStr, [""], scenariosToInclude, EXPERIMENT_DIR)
                    energyData = helpers.getEnergySpent(paramStr, [''], scenariosToInclude, EXPERIMENT_DIR, NR)

                    # Pool the per-run values of every included scenario.
                    allResCollectedValues = []
                    allEnergyValues = []
                    allResEnergyCostValues = []
                    for sc in range(len(scenariosToInclude)):
                        resRuns = resCollectedData[0][sc][0]
                        energyRuns = energyData[sc][0]
                        allResCollectedValues.extend(resRuns)
                        allEnergyValues.extend(energyRuns)
                        # Energy divided by resources collected, run by run.
                        # NOTE(review): raises ZeroDivisionError if a run collected nothing.
                        allResEnergyCostValues.extend(
                            energyRuns[r] / resRuns[r] for r in range(len(energyRuns)))

                    # One number per parameter point summarising all included scenarios
                    # (appended after the three index coordinates, i.e. at index 3).
                    graphDataResCollected[pointIndex].append(crData.getAverage(allResCollectedValues))
                    graphDataEnergy[pointIndex].append(crData.getAverage(allEnergyValues))
                    graphDataResEnergyCost[pointIndex].append(crData.getAverage(allResEnergyCostValues))
                    print(str(pointIndex) + "/" + str(len(pointData))
                          + ": res: " + str(graphDataResCollected[pointIndex][3])
                          + "   energy " + str(graphDataEnergy[pointIndex][3])
                          + "  cost " + str(graphDataResEnergyCost[pointIndex][3]))

                filePrefix = constants.BASE_FILE_PATH + "/" + IMAGE_OUTPUT_DIR
                fileSuffix = "_fP" + str(fPVal) + "_" + scenarioType + "_NR" + str(NR) + ".png"

                resRange = getMaxMinCollectedResurce(scenarioType, fPVal)
                crGraphs.create4Dplot(graphDataResCollected, xLabel, yLabel, zLabel, xTickLabels, yTickLabels, zTickLabels,
                                      valueRangeMin_=resRange[0], valueRangeMax_=resRange[1],
                                      fileName_=filePrefix + "_resAll" + fileSuffix, size_=(16, 12))

                energyRange = getMaxMinEnergy(scenarioType, fPVal)
                crGraphs.create4Dplot(graphDataEnergy, xLabel, yLabel, zLabel, xTickLabels, yTickLabels, zTickLabels,
                                      valueRangeMin_=energyRange[0], valueRangeMax_=energyRange[1],
                                      fileName_=filePrefix + "_energyAll" + fileSuffix, size_=(16, 12))

                costRange = getMaxMinResourceEnergyCost(scenarioType, fPVal)
                crGraphs.create4Dplot(graphDataResEnergyCost, xLabel, yLabel, zLabel, xTickLabels, yTickLabels, zTickLabels,
                                      valueRangeMin_=costRange[0], valueRangeMax_=costRange[1],
                                      fileName_=filePrefix + "_resEnergyCostAll" + fileSuffix, size_=(16, 12))

    elif DO_COMPARISON_WITH_A:
        if DO_ENERGY_CAPPED_COMPARISON:
            # One marker / line style per legend entry, in scenarioLegendsWithDistances
            # order (with the settings above: Heap1 D=5, Scatter25 D=5, Heap1 D=9, Scatter25 D=9).
            markers = ['go-', 'go-', 'rs-', 'rs-']
            lineStyles = ['-', '--', '-', '--']
            graphWidth = 12

            # (NR, fP) combinations to compare.
            for NR, fP in [(25, 50), (50, 200)]:
                paramPrefix = "fP" + str(fP) + "_NR" + str(NR)

                # graphDataByScenario[c][sc] holds the per-run % improvements for
                # energy cap c and scenario sc.
                graphDataByScenario = []
                for energyCap in energyCapVals:
                    data = helpers.getResourceCollectedForEnergy(paramPrefix + "_" + selectedParameterSetStr, [""], scenarios, EXPERIMENT_DIR, energyCap)
                    dataNoTD = helpers.getResourceCollectedForEnergy(paramPrefix + extraParamNameStrModelA, [""], scenarios, experimentDirModelA, energyCap)

                    # Percentage improvement of resources collected over model A, run by run.
                    # Built from the runs actually loaded: it was previously zero-padded to
                    # constants.NUM_RUNS, so missing runs silently counted as 0% improvement.
                    # NOTE(review): raises ZeroDivisionError if a model-A run collected nothing.
                    graphDataPartial = []
                    for sc in range(len(data[0])):
                        examinedRuns = data[0][sc][0]
                        baselineRuns = dataNoTD[0][sc][0]
                        graphDataPartial.append(
                            [((examinedRuns[run] / baselineRuns[run]) - 1.0) * 100
                             for run in range(len(examinedRuns))])
                    graphDataByScenario.append(graphDataPartial)

                # Transpose so each legend entry (scenario) holds one series across energy caps.
                graphDataByECap = [[[] for c in range(len(energyCapVals))]
                                   for sc in range(len(scenarioLegendsWithDistances))]
                for sc in range(len(scenarios)):
                    print(scenarios[sc])
                    for c in range(len(energyCapVals)):
                        graphDataByECap[sc][c] = graphDataByScenario[c][sc]

                fileName = (constants.BASE_FILE_PATH + "/" + IMAGE_OUTPUT_DIR + paramPrefix + selectedParameterSetStr
                            + "_resECapVS_" + experimentDirModelA + "_" + extraParamNameStrModelA + ".png")
                crGraphs.createPlot(range(len(energyCapLegend)), graphDataByECap, "Total swarm energy limitation",
                                    "% improvement of res. collected", scenarioLegendsWithDistances,
                                    xTickLabels_=energyCapLegend, confidenceIntervals_=True, markers_=markers,
                                    markerSize_=15, lineWidth_=2, lineStyles_=lineStyles, fileName_=fileName,
                                    size_=(graphWidth, 6), yLimMin_=-100, yLimMax_=100)

        else:
            for fP in fPVals:
                paramPrefix = "fP" + str(fP) + "_NR" + str(NR)
                analysis.compareWithA(
                    selectedParameterSetStr, scenarios, scenarioLegends, NR, fP, "-",
                    paramPrefix, experimentDirModelA,
                    "", "",
                    EXPERIMENT_DIR, IMAGE_OUTPUT_DIR + paramPrefix + selectedParameterSetStr,
                    DO_RES_COLLECTED, DO_RES_COLLECTED_OVER_TIME, DO_RES_COLLECTED_PER_ROBOT, DO_RES_COLLECTED_DIFF, DO_RES_PER_LOADING,
                    DO_LOAD_EVENTS_RECRUITMENT, DO_SCOUT_EVENTS_SUCCESS,
                    DO_NEIGHB_SEARCH_SUCESS, DO_ROBOT_STATES, DO_PELLETS, DO_LOAD_EVENT_GROUPS, DO_UNLOAD_TIME, DO_ENERGY_SPENT, DO_RES_ENERGY_COST,
                    DO_LOAD_EVENTS, DO_LOAD_EVENT_GROUPS_OVER_TIME, DO_TIME_OF_FIRST_LOADS, DO_UNLOAD_TIME_OVER_TIME, DO_TIME_BETWEEN_LOAD_UNLOAD, DO_RES_COLLECTED_ENERGY_CAPPED,
                    DO_INFO_TRANSFER,
                    DO_COMPARISON_WITH_B,
                    800, 100, energyCapPerRobot_=12000, xLabel_=xLabel, size_=(20, 6),
                    scenarioTypes_=scenarioTypes, distances_=distances, timeInterval_=TIME_INTERVAL, startTime_=START_TIME,
                    afterParamNameStr_=extraParamNameStr, afterParamNameStrModelA_=extraParamNameStrModelA,
                    examinedModelLabel_=examinedModelLabel, modelALabel_=modelALabel)

    elif DO_TU:
        # Built here, after every override of TUVals, so it always matches the values.
        TULegendVals = ["TU=" + str(v) for v in TUVals]
        for fP in fPVals:
            for pW in pWVals:
                for tDSigD in tDSigDVals:
                    runPrefix = "fP" + str(fP) + "_NR" + str(NR) + "_pW" + str(pW) + "_tDsigD" + str(tDSigD)
                    analysis.analyse(
                        runPrefix + "_TU", TUVals, TULegendVals, scenarios, scenarioLegends, [""], EXPERIMENT_DIR,
                        IMAGE_OUTPUT_DIR + runPrefix,
                        DO_RES_COLLECTED, DO_RES_COLLECTED_PER_ROBOT, DO_LOAD_EVENTS_RECRUITMENT, DO_SCOUT_EVENTS_SUCCESS,
                        DO_NEIGHB_SEARCH_SUCESS, DO_ROBOT_STATES, DO_PELLETS, DO_LOAD_EVENT_GROUPS, DO_UNLOAD_TIME, DO_ENERGY_SPENT, DO_RES_ENERGY_COST,
                        DO_WAGGLE_DANCE_GROUPS, DO_LOAD_EVENT_GROUPS_OVER_TIME,
                        DO_INFO_TRANSFER, DO_INFO_TRANSFER_REGIONS_GRID,
                        400, xLabel_=xLabel, timeInterval_=TIME_INTERVAL, startTime_=START_TIME,
                        scenarioTypes_=scenarioTypes, distances_=distances,
                        size_=(20, 6))

    elif DO_PW:
        # Built here, after every override of pWVals, so it always matches the values.
        pWLegendVals = ["p(W) = " + str(v) for v in pWVals]
        for fP in fPVals:
            for tDSigD in tDSigDVals:
                for TU in TUVals:
                    afterParamStr = "_tDsigD" + str(tDSigD) + "_TU" + str(TU)
                    analysis.analyse(
                        "fP" + str(fP) + "_NR" + str(NR) + "_pW", pWVals, pWLegendVals, scenarios, scenarioLegends, [""], EXPERIMENT_DIR,
                        IMAGE_OUTPUT_DIR + "fP" + str(fP) + "_NR" + str(NR) + afterParamStr,
                        DO_RES_COLLECTED, DO_RES_COLLECTED_PER_ROBOT, DO_LOAD_EVENTS_RECRUITMENT, DO_SCOUT_EVENTS_SUCCESS,
                        DO_NEIGHB_SEARCH_SUCESS, DO_ROBOT_STATES, DO_PELLETS, DO_LOAD_EVENT_GROUPS, DO_UNLOAD_TIME, DO_ENERGY_SPENT, DO_RES_ENERGY_COST,
                        DO_WAGGLE_DANCE_GROUPS, DO_LOAD_EVENT_GROUPS_OVER_TIME,
                        DO_INFO_TRANSFER, DO_INFO_TRANSFER_REGIONS_GRID,
                        400, xLabel_=xLabel, timeInterval_=TIME_INTERVAL, startTime_=START_TIME,
                        afterParamNameStr_=afterParamStr,
                        scenarioTypes_=scenarioTypes, distances_=distances,
                        size_=(20, 6))



