"""
crGraphs - library for python that lets you create graphs with single function calls

Author: Lenka Pitonakova: contact@lenkaspace.net

"""
import pylab;
import matplotlib.pyplot as plt
import numpy;
import itertools
from pyvttbl import DataFrame
import matplotlib.colors as clrs
from matplotlib import font_manager as fm
import math;
import scipy.stats;
#from mpl_toolkits.mplot3d import Axes3D;
from copy import deepcopy;
import crData;
import helpers;


#-- when True, the functions in this module print a "Saved ..." message after writing a file
SHOW_OUTPUT = True

#-- default font sizes: labels / legends / pie group names, and tick labels / pie values
LABEL_FONT_SIZE = "xx-large"
TICK_FONT_SIZE = "large"

#-- fallback colours and plot format strings, picked by data set index when the caller supplies none
DEFAULT_COLORS = ["b", "r", "g", "c", "k"]
DEFAULT_MARKERS = ["b-", "r-", "g-", "c-", "k-"]

#-- resolution (dots per inch) given to pylab.figure() whenever a new figure is created
DPI = 300

#-- sentinel number; the axis limit parameters of the plotting functions default to this same
#-- literal to mean "no limit given" (NOTE(review): presumably the intended link - confirm)
INVALID_VALUE = -999999

def createAnovaTable(data_ , legendLabels_, groupLabels_, valueMultiplier_=1.0, fileName_ = ""):
    """
    Build a text report holding a one-way ANOVA table per group (environment) and either
    save it to a text file or print it.

    Returns: nothing

    Parameters:
    ---------------------------------------
    data_: samples indexed as data_[condition][group]; exactly conditions 0, 1 and 2 are analysed
    legendLabels_: labels of those three conditions, converted with helpers.getControllerDescriptionForAnovaTable
    groupLabels_: one label per group, converted with helpers.getEnvironmentDescriptionFromString for the section heading
    valueMultiplier_: every sample is multiplied by this value before the analysis
    fileName_: if not empty, the report is written to fileName_ + ".txt", otherwise it is printed

    A group is reported as "[no data]" when all samples of the second condition (data_[1])
    are equal to each other.
    """
    conditionIndices = (0, 1, 2)
    reportParts = []

    for groupIndex, groupLabel in enumerate(groupLabels_):
        #-- section heading
        reportParts.append("======================================================================================\n")
        reportParts.append("==================== Environment: " + helpers.getEnvironmentDescriptionFromString(groupLabel) + "\n")
        reportParts.append("======================================================================================\n\n")

        #-- skip groups where the second condition has no variation at all
        referenceSamples = data_[1][groupIndex]
        if all(sample == referenceSamples[0] for sample in referenceSamples):
            reportParts.append("[no data]\n\n")
            print("[no data]")
            continue

        #-- one long column of scaled samples and a matching column of condition names
        frame = DataFrame()
        scaledSamples = []
        for conditionIndex in conditionIndices:
            scaledSamples += [sample * valueMultiplier_ for sample in data_[conditionIndex][groupIndex]]
        frame['data'] = scaledSamples

        conditionColumn = []
        for conditionIndex in conditionIndices:
            conditionName = helpers.getControllerDescriptionForAnovaTable(legendLabels_[conditionIndex])
            conditionColumn += [conditionName] * len(data_[conditionIndex][groupIndex])
        frame['conditions'] = conditionColumn

        #-- run the 1 way analysis of variance, the result is a dict-like object
        anovaResult = frame.anova1way('data', 'conditions')
        reportParts.append(str(anovaResult))

        #-- omega-squared is stored on the result object only; it is computed after the
        #-- result was converted to text above
        numerator = anovaResult['ssbn'] - anovaResult['dfbn'] * anovaResult['mswn']
        denominator = anovaResult['ssbn'] + anovaResult['sswn'] + anovaResult['mswn']
        anovaResult['omega-sq'] = numerator / denominator

        reportParts.append("\n\n\n\n")

    report = "".join(reportParts)

    #-- no file name: print the report and finish
    if len(fileName_) == 0:
        print(report)
        return

    with open(fileName_ + ".txt", "w") as reportFile:
        reportFile.write(report)

    if SHOW_OUTPUT == True:
        print("Saved " + fileName_)


#
def createPieChart(data_=[], groupLabels_=[], groupColors_=[], showPercentageVals_=False, showActualVals_=True, showShadow_=False,
                   groupsFontSize_=LABEL_FONT_SIZE, valsFontSize_=TICK_FONT_SIZE, size_=(8,6),
                   fileName_ = "", holdFigure_=False, figure_=None, subPlot_=111):
    """
    Create a pie chart from a 1D list of values.

    Returns: the matplotlib figure

    Parameters:
    ---------------------------------------
    data_: 1D list or numpy array with one value per wedge
    groupLabels_: wedge labels, passed to pie() as labels
    groupColors_: wedge colours; DEFAULT_COLORS are used when empty
    showPercentageVals_: show each wedge's share in percent
    showActualVals_: show each wedge's value, rounded to the nearest whole number
    showShadow_: passed to pie() as shadow
    groupsFontSize_: font size of the wedge labels
    valsFontSize_: font size of the numbers drawn on the wedges
    size_: figure size, only used when a new figure is created
    fileName_: if not empty, the figure is saved there as PNG, otherwise it is shown
    holdFigure_: if True, the figure is neither saved nor shown
    figure_: existing figure to draw into; a new one is created when None
    subPlot_: subplot specification passed to figure.add_subplot()

    Raises Exception if data_ is not a list or numpy array.
    """
    if (type(data_) != list and type(data_) != numpy.ndarray):
        raise Exception("The data_ must be a 1d list")

    if (figure_ is None):
        fig = pylab.figure(figsize=size_, dpi=DPI)
    else:
        fig = figure_
    ax = fig.add_subplot(subPlot_)

    #-- enlarge the axes by 30% around their centre so that the pie uses more of the figure
    box = ax.get_position()
    ax.set_position([box.x0 - box.width*0.15, box.y0 - box.height*0.15, box.width*1.3, box.height*1.3])

    if (len(groupColors_) == 0):
        groupColors_ = DEFAULT_COLORS

    dataTotal = sum(data_)

    def formatPieceNumber(val_):
        #-- val_ is the wedge's share in percent. The value is recomputed from it, so it must be
        #-- rounded: plain int() truncation turns e.g. 2.9999998 into 2 instead of 3
        actualValue = int(round(val_ * dataTotal / 100.0))
        if (showActualVals_ and showPercentageVals_):
            return '{p:.1f}% ({v:d})'.format(p=val_, v=actualValue)
        if (showActualVals_):
            return '{v:d}'.format(v=actualValue)
        if (showPercentageVals_):
            return '{p:.2f}%'.format(p=val_)
        return ''

    #-- create the graph
    _, texts, autotexts = ax.pie(data_, labels=groupLabels_, autopct=formatPieceNumber, shadow=showShadow_, colors=groupColors_)

    #-- setup fonts, with a separate properties object per size so that neither affects the other
    groupFont = fm.FontProperties()
    groupFont.set_size(groupsFontSize_)
    plt.setp(texts, fontproperties=groupFont)
    valueFont = fm.FontProperties()
    valueFont.set_size(valsFontSize_)
    plt.setp(autotexts, fontproperties=valueFont)

    #-- display / print:
    if (not holdFigure_):
        if (len(fileName_) > 0):
            #-- save this figure explicitly; a figure passed in by the caller need not be the current one
            fig.savefig(fileName_, format='png')
            if (SHOW_OUTPUT == True):
                print("Saved " + fileName_)
        else:
            pylab.show()

    return fig
    
    
#
def createMatrixPlot(data_=[[],[]], xLabel_ = "", yLabel_ = "", xTickLabels_ = [], yTickLabels_ = [], colorBarLabel_ = "",
                     colorMap_ = None, vmin_ = 0, vmax_ = 1,
                     labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(8,6),
                     fileName_ = "", holdFigure_=False,
                     annotateValues_=False, annotationStringAfter_="", annotationValues_=[[],[]],
                     roundAnnotatedValues_=False):
    """
    Create 2D matrix plot where color gradient represents value on a 3rd dimension.

    Returns: the matplotlib figure

    Parameters:
    ---------------------------------------
    data_: 2D list or numpy array, indexed as data_[row][column]; row 0 is drawn at the bottom
    xLabel_, yLabel_, colorBarLabel_: axis and colour bar labels
    xTickLabels_, yTickLabels_: optional 1D lists of labels, one per column / row, starting at index 0
    colorMap_: matplotlib colour map; the "summer" map is used when None
    vmin_, vmax_: data values mapped to the two ends of the colour map
    labelFontSize_, tickFontSize_: font sizes of labels and tick labels
    legendFontSize_: not used by this function
    size_: figure size
    fileName_: if not empty, the figure is saved there as PNG, otherwise it is shown
    holdFigure_: if True, the figure is neither saved nor shown
    annotateValues_: if True, a text is written into the middle of each cell
    annotationStringAfter_: string appended to each annotation
    annotationValues_: 2D list of values to write instead of data_; data_ is used when it is empty
    roundAnnotatedValues_: if True, annotated values are rounded UP to 2 decimal places

    Raises Exception if data_ or the tick labels are not lists or numpy arrays.
    """

    if (type(data_) != list and type(data_) != numpy.ndarray):
        raise Exception("The data_ must be a 2d list")
    if (type(data_[0]) != list and type(data_[0]) != numpy.ndarray):
        raise Exception("The data_ must be a 2d list")
    if (type(xTickLabels_) != list and type(xTickLabels_) != numpy.ndarray):
        raise Exception("The xTickLabels_ must be a 1d list")
    if (type(yTickLabels_) != list and type(yTickLabels_) != numpy.ndarray):
        raise Exception("The yTickLabels_ must be a 1d list")

    #-- annotate with the data itself unless separate values were given (values are only read)
    if (len(annotationValues_) == 0 or len(annotationValues_[0]) == 0):
        annotationValues_ = data_

    numOfRows = len(data_)
    numOfCols = len(data_[0])

    #-- create the axes first, so that every label below is set on the axes that are plotted on
    fig = pylab.figure(figsize=size_, dpi=DPI)
    ax = fig.add_subplot(111)
    ax.set_xlabel(xLabel_, size=labelFontSize_)
    ax.set_ylabel(yLabel_, size=labelFontSize_)

    #-- decide on colors
    if (colorMap_ is None):
        cmap = plt.get_cmap("summer")
    else:
        cmap = colorMap_

    #-- make the plot
    cax = ax.matshow(data_, cmap=cmap, origin='lower', vmin=vmin_, vmax=vmax_)
    ax.tick_params(axis='both', labelsize=tickFontSize_)

    #-- set tick labels: fix the tick positions to the cell indices first, tick labels are
    #-- only reliable on fixed positions. Labels beyond the matrix size are ignored.
    if (len(xTickLabels_) > 0):
        numOfXTicks = min(len(xTickLabels_), numOfCols)
        ax.set_xticks(list(range(numOfXTicks)))
        ax.set_xticklabels(list(xTickLabels_)[:numOfXTicks])
    if (len(yTickLabels_) > 0):
        numOfYTicks = min(len(yTickLabels_), numOfRows)
        ax.set_yticks(list(range(numOfYTicks)))
        ax.set_yticklabels(list(yTickLabels_)[:numOfYTicks])

    #-- make a colorbar for the image returned by the matshow call
    cbar = fig.colorbar(cax)
    cbar.ax.set_ylabel(colorBarLabel_, size=labelFontSize_)
    for t in cbar.ax.get_yticklabels():
        t.set_fontsize(tickFontSize_)

    #-- annotations
    if (annotateValues_):
        #-- positions are given as fractions of the axes, with [0;0] in the bottom left corner;
        #-- each text goes into the middle of its cell
        for y in range(numOfRows):
            for x in range(numOfCols):
                value = annotationValues_[y][x]
                if (roundAnnotatedValues_):
                    value = math.ceil(value * 100) / 100.0
                ax.annotate(str(value) + annotationStringAfter_,
                            xy=((x + 0.5) / numOfCols, (y + 0.5) / numOfRows), xycoords='axes fraction',
                            horizontalalignment='center', verticalalignment='center')

    #-- if file name empty, show the graph
    if (not holdFigure_):
        if (len(fileName_) > 0):
            fig.savefig(fileName_, format='png')
            if (SHOW_OUTPUT == True):
                print("Saved " + fileName_)
        else:
            pylab.show()

    return fig
    
    
#
def createContourPlot(xValues_=[], yValues_=[], zValues_=[[],[]], xLabel_ = "", yLabel_ = "", colorBarLabel_ = "",
                      showContours_ = True, levels_ = [], colors_ = None, colorMap_ = None,
                      labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(12,6),
                      fileName_ = "", holdFigure_=False):
    """
    Create a contour plot with different colour levels for zValues_ and contour lines around levels.

    Returns: the matplotlib figure

    Parameters:
    ---------------------------------------
    xValues_, yValues_: 1D lists or numpy arrays with the coordinates along the two axes
    zValues_: 2D list or numpy array of values on the meshgrid of xValues_ and yValues_
    xLabel_, yLabel_, colorBarLabel_: axis and colour bar labels
    showContours_: if True, red contour lines are drawn at every second level and added to the colour bar
    levels_: explicit level boundaries; 10 automatic levels are requested when empty
    colors_: explicit level colours; when given, no colour map is used
    colorMap_: matplotlib colour map, used when colors_ is None; defaults to the "summer" map
    labelFontSize_, tickFontSize_: font sizes of labels and tick labels
    legendFontSize_: not used by this function
    size_: figure size
    fileName_: if not empty, the figure is saved there as PNG, otherwise it is shown
    holdFigure_: if True, the figure is neither saved nor shown

    Raises Exception if the values are not lists or numpy arrays.
    """

    #-- check data
    if (type(xValues_) != list and type(xValues_) != numpy.ndarray):
        raise Exception("The xValues_ must be a 1d list")
    if (type(yValues_) != list and type(yValues_) != numpy.ndarray):
        raise Exception("The yValues_ must be a 1d list")
    if (type(zValues_) != list and type(zValues_) != numpy.ndarray):
        raise Exception("The zValues_ must be a 2d list")
    if (type(zValues_[0]) != list and type(zValues_[0]) != numpy.ndarray):
        raise Exception("The zValues_ must be a 2d list")

    #-- create graph
    fig = pylab.figure(figsize=size_, dpi=DPI)
    pylab.xlabel(xLabel_, size=labelFontSize_)
    pylab.ylabel(yLabel_, size=labelFontSize_)
    pylab.xticks(size=tickFontSize_)
    pylab.yticks(size=tickFontSize_)

    #-- setup the meshgrid
    X, Y = numpy.meshgrid(xValues_, yValues_)

    #-- decide on colors: explicit colours and a colour map are mutually exclusive.
    #-- "is None" is used because colors_ may be array-like, where "== None" is ambiguous
    origin = 'lower'
    cmap = None
    if (colors_ is None):
        if (colorMap_ is None):
            cmap = plt.get_cmap("summer")
        else:
            cmap = colorMap_

    #-- create basic contour plot
    if (len(levels_) > 0):
        #-- draw with levels specified; no level count is passed in addition to them
        CS = plt.contourf(X, Y, zValues_, levels=levels_, cmap=cmap, origin=origin, colors=colors_)
    else:
        #-- automatic levels
        CS = plt.contourf(X, Y, zValues_, 10, cmap=cmap, origin=origin, colors=colors_)

    #-- use it to draw boundaries on top (drawn into the same axes, no "hold" argument needed)
    CS2 = None
    if (showContours_):
        CS2 = plt.contour(CS, levels=CS.levels[::2], colors='r', origin=origin)

    #-- make a colorbar for the ContourSet returned by the contourf call.
    cbar = plt.colorbar(CS)
    cbar.ax.set_ylabel(colorBarLabel_, size=labelFontSize_)
    for t in cbar.ax.get_yticklabels():
        t.set_fontsize(tickFontSize_)

    #-- add the contour line levels to the colorbar
    if (CS2 is not None):
        cbar.add_lines(CS2)

    #-- if file name empty, show the graph
    if (not holdFigure_):
        if (len(fileName_) > 0):
            fig.savefig(fileName_, format='png')
            if (SHOW_OUTPUT == True):
                print("Saved " + fileName_)
        else:
            pylab.show()

    return fig



def createBarChart(valueData_ = [[],[]], stdData_ = [[],[]], xTickLabels_=[], groupLabels_ = [], xLabel_="", yLabel_ = "", groupColors_ = [],
                   legendCols_ = 4, barWidth_ = 0.35,
                   labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(12,6),
                   fileName_ = "", holdFigure_=False):
    """
    Create bars for groups next to each other, grouped by group labels.

    Returns: the matplotlib figure

    Parameters:
    ---------------------------------------
    valueData_: 2D list or numpy array, 1st dimension is groups and 2nd dimension is the bar values
    stdData_: 2D list or numpy array of the same layout with the error bar sizes; a group whose
              entry is missing or of a different length is drawn without error bars
    xTickLabels_: labels along the X axis, one per position
    groupLabels_: legend labels, one per group; no legend is drawn when empty
    xLabel_, yLabel_: axis labels
    groupColors_: bar colour per group; DEFAULT_COLORS are cycled through for groups without one
    legendCols_: number of legend columns
    barWidth_: width of a single bar in X axis units
    labelFontSize_, tickFontSize_, legendFontSize_: font sizes
    size_: figure size
    fileName_: if not empty, the figure is saved there as PNG, otherwise it is shown
    holdFigure_: if True, the figure is neither saved nor shown

    Raises Exception if valueData_ or stdData_ is not a 2D list / numpy array.
    """

    #-- check data
    allowedTypes = (list, numpy.ndarray)
    if (type(valueData_) not in allowedTypes or type(valueData_[0]) not in allowedTypes):
        raise Exception("The valueData_ must be a 2d list where the first dimension is groups and 2nd dimension is values")
    if (type(stdData_) not in allowedTypes or type(stdData_[0]) not in allowedTypes):
        raise Exception("The stdData_ must be a 2d list where the first dimension is groups and 2nd dimension is values")

    numOfGroups = len(valueData_)
    N = len(valueData_[0])
    ind = numpy.arange(N)  # the x locations for the groups

    fig = pylab.figure(figsize=size_, dpi=DPI)
    ax = fig.add_subplot(111)
    ax.set_ylabel(yLabel_, size=labelFontSize_)
    ax.set_xlabel(xLabel_, size=labelFontSize_)
    #-- one tick in the middle of each cluster of bars (equals ind+barWidth_ for 2 groups)
    ax.set_xticks(ind + numOfGroups*barWidth_/2.0)
    ax.set_xticklabels(list(xTickLabels_), size=tickFontSize_)
    ax.tick_params(axis='y', labelsize=tickFontSize_)

    #-- plot bars next to each other
    plots = []
    for i in range(numOfGroups):
        #-- pick a color, cycling through the defaults when none was given for this group
        if (len(groupColors_) > i):
            colorCode = groupColors_[i]
        else:
            colorCode = DEFAULT_COLORS[i % len(DEFAULT_COLORS)]

        #-- only use error bars when there is one per bar
        yerr = None
        if (len(stdData_) > i and len(stdData_[i]) == N):
            yerr = stdData_[i]

        #-- bars start at their x position; stated explicitly as the default alignment
        #-- differs between matplotlib versions
        plot = ax.bar(ind + i*barWidth_, valueData_[i], barWidth_, color=colorCode, yerr=yerr, align='edge')
        plots.append(plot)

    #-- plot line for y=0
    ax.plot(numpy.linspace(-0.1, N-1+0.8, 3), [0,0,0], 'k-')

    #-- setup legend below the graph, placed the same way as in createStackedBar
    if (len(groupLabels_) > 0):
        legendItems = [plot[0] for plot in plots if len(plot) > 0]
        if (len(legendItems) > 0):
            box = ax.get_position()
            ax.set_position([box.x0 - box.width*0.05, box.y0 + box.height*0.19, box.width*1.14, box.height*0.9])
            legend = ax.legend(legendItems, list(groupLabels_), loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=legendCols_)
            for t in legend.get_texts():
                t.set_fontsize(legendFontSize_)

    #-- if file name empty, show the graph
    if (not holdFigure_):
        if (len(fileName_) > 0):
            fig.savefig(fileName_, format='png')
            if (SHOW_OUTPUT == True):
                print("Saved " + fileName_)
        else:
            pylab.show()

    return fig
    
    
    
#      
def createStackedBar(valueData_ = [[],[]], stdData_ = [[],[]], groupLabels_ = [], barLabels_ = [],  xLabel_="", yLabel_ = "", groupColors_ = [],
                     groupPatterns_=[], legendCols_ = 4, showLegend_ = True, barWidth_ = 0.35,
                     yLimMin_=-999999,yLimMax_=-999999,
                     labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(12,5),
                     fileName_ = "", holdFigure_=False):
    """
    Create bars for groups stacked on top of each other by bar labels.

    Returns: the matplotlib figure

    Parameters:
    ---------------------------------------
    valueData_: 2D list or numpy array, 1st dimension is groups (stack layers) and 2nd dimension is the value for each bar
    stdData_: 2D list or numpy array of the same layout with the error bar sizes; a group is drawn
              with error bars only when it has at least one entry per bar
    groupLabels_: legend labels, one per group
    barLabels_: labels along the X axis, one per bar
    xLabel_, yLabel_: axis labels
    groupColors_: colour per group; DEFAULT_COLORS are cycled through for groups without one
    groupPatterns_: hatch pattern per group; no pattern for groups without one
    legendCols_: number of legend columns
    showLegend_: if True and groupLabels_ is not empty, a legend is drawn below the graph
    barWidth_: width of a bar in X axis units
    yLimMin_, yLimMax_: Y axis limits, applied only when both are greater than -999999
    labelFontSize_, tickFontSize_, legendFontSize_: font sizes
    size_: figure size
    fileName_: if not empty, the figure is saved there as PNG, otherwise it is shown
    holdFigure_: if True, the figure is neither saved nor shown

    Raises Exception if valueData_ or stdData_ is not a 2D list / numpy array.
    """

    #-- check data
    allowedTypes = (list, numpy.ndarray)
    if (type(valueData_) not in allowedTypes or type(valueData_[0]) not in allowedTypes):
        raise Exception("The valueData_ must be a 2d list where the first dimension is groups and 2nd dimension is values for each bar")
    if (type(stdData_) not in allowedTypes or type(stdData_[0]) not in allowedTypes):
        raise Exception("The stdData_ must be a 2d list where the first dimension is groups and 2nd dimension is values for each bar")

    numOfBars = len(valueData_[0])
    ind = numpy.arange(numOfBars)      # the x locations for the groups

    fig = pylab.figure(figsize=size_, dpi=DPI)
    ax = fig.add_subplot(111)
    ax.set_ylabel(yLabel_, size=labelFontSize_)
    ax.set_xlabel(xLabel_, size=labelFontSize_)
    plt.xticks(ind + barWidth_/2., barLabels_, size=tickFontSize_)
    plt.yticks(size=tickFontSize_)

    #-- go through each group and display it on top of the running total of the groups before it
    plots = []
    bottom = numpy.zeros(numOfBars)
    for g in range(len(valueData_)):
        #-- pick a color, cycling through the defaults when none was given for this group
        if (len(groupColors_) > g):
            colorCode = groupColors_[g]
        else:
            colorCode = DEFAULT_COLORS[g % len(DEFAULT_COLORS)]

        #-- pick a pattern
        if (g < len(groupPatterns_)):
            hatch = groupPatterns_[g]
        else:
            hatch = ' '

        #-- error bars are used when there is one for every bar; extra entries are ignored
        yerr = None
        if (len(stdData_) > g and numOfBars > 0 and len(stdData_[g]) >= numOfBars):
            yerr = stdData_[g][:numOfBars]

        #-- bars start at their x position (the ticks above are half a bar width to the right);
        #-- stated explicitly as the default alignment differs between matplotlib versions
        plot = ax.bar(ind, valueData_[g], barWidth_, color=colorCode, yerr=yerr, bottom=bottom.copy(), hatch=hatch, align='edge')
        plots.append(plot)

        bottom = bottom + numpy.asarray(valueData_[g], dtype=float)

    #-- setup limits
    if (yLimMin_ > -999999 and yLimMax_ > -999999):
        ax.set_ylim(yLimMin_, yLimMax_)

    #-- setup graph
    box = ax.get_position()
    if (showLegend_ and len(groupLabels_) > 0):
        #-- make room under the graph and put the legend there
        ax.set_position([box.x0 - box.width*0.05, box.y0 + box.height*0.19, box.width*1.14, box.height*0.9])
        legendItems = [plot[0] for plot in plots]
        legend = ax.legend(flip(legendItems,legendCols_), flip(groupLabels_,legendCols_), loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=legendCols_)
        for t in legend.get_texts():
            t.set_fontsize(legendFontSize_)
    elif (showLegend_):
        #-- legend requested but there is nothing to put in it: use the space for the graph
        ax.set_position([box.x0 - box.width*0.05, box.y0, box.width*1.14, box.height*1.07])
    else:
        ax.set_position([box.x0 - box.width*0.05, box.y0 + box.height*0.05, box.width*1.14, box.height*1])

    #-- if file name empty, show the graph
    if (not holdFigure_):
        if (len(fileName_) > 0):
            fig.savefig(fileName_, format='png')
            if (SHOW_OUTPUT == True):
                print("Saved " + fileName_)
        else:
            pylab.show()

    return fig
    
#
def createStackedPlot(xData_ = [], yData_ = [], xLabel_ = "", yLabel_ = "", legendLabels_ = [], colors_ = [], xTickLabels_=[],
                      xLimMin_=-999999,xLimMax_=-999999, yLimMin_=-999999,yLimMax_=-999999,
                      labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(12,6), 
                      legendCols_ = 4, fileName_ = "", holdFigure_=False, figure_=None):
    """
    Create a stacked plot from a 2D list of data, where 1st dimension is the stacks and 2nd dimension
    is the values across the X axis.

    Returns: the matplotlib figure

    Parameters:
    ---------------------------------------
    xData_: 1D list of X axis values
    yData_: 2D list, yData_[stack][x]; each stack is drawn on top of the sum of the stacks before it
    xLabel_, yLabel_: axis labels
    legendLabels_: legend labels, one per stack; no legend is drawn when empty
    colors_: fill colour per stack; DEFAULT_COLORS are used for stacks without one, then 'b'
    xTickLabels_: NOT IMPLEMENTED; must be empty or of the same length as xData_
    xLimMin_, xLimMax_, yLimMin_, yLimMax_: axis limits, applied per axis only when both of its values are greater than -999999
    labelFontSize_, tickFontSize_, legendFontSize_: font sizes
    size_: figure size, only used when a new figure is created
    legendCols_: number of legend columns
    fileName_: if not empty, the figure is saved there as PNG, otherwise it is shown
    holdFigure_: if True, the figure is neither saved nor shown
    figure_: existing figure to draw into; a new one is created when None

    Raises Exception if xTickLabels_ is given with a different length than xData_.
    """
    #-- fill_between returns a collection that older legends cannot show directly, so
    #-- rectangles of the same colour stand in for the stacks in the legend
    from matplotlib.patches import Rectangle

    xTickLabels_ = list(xTickLabels_)

    #-- create graph; everything is set on the axes object so that a figure passed in by the
    #-- caller is used even when it is not the current figure
    if (figure_ is None):
        fig = pylab.figure(figsize=size_, dpi=DPI)
    else:
        fig = figure_
    ax = fig.add_subplot(111)
    ax.set_xlabel(xLabel_, size=labelFontSize_)
    ax.set_ylabel(yLabel_, size=labelFontSize_)
    ax.tick_params(axis='both', labelsize=tickFontSize_)

    #-- prepare data: row s holds the upper edge of stack s
    dataStack = numpy.cumsum(yData_, axis=0)

    #-- plot data, filling in legend data as well:
    legendItems = []
    numOfStacks = len(dataStack)
    for stack in range(numOfStacks):
        #-- pick a color
        colorCode = 'b'
        if (len(colors_) > stack):
            colorCode = colors_[stack]
        elif (len(DEFAULT_COLORS) > stack):
            colorCode = DEFAULT_COLORS[stack]

        #-- the top-most stack gets no outline
        lineWidth = 1.0
        if (stack == numOfStacks-1):
            lineWidth = 0.0

        #-- plot between the upper edge of the previous stack (or 0) and this stack's upper edge
        if (stack == 0):
            lowerEdge = 0
        else:
            lowerEdge = dataStack[stack-1,:]
        ax.fill_between(xData_, lowerEdge, dataStack[stack,:], color=colorCode, edgecolor='k', linewidth=lineWidth)

        legendItems.append(Rectangle((0, 0), 1, 1, fc=colorCode))

    #-- apply custom x tick labels
    if (len(xTickLabels_) > 0):
        if (len(xTickLabels_) == len(xData_)):
            print("!!! xTickLabels_ NOT IMPLEMENTED")
        else:
            raise Exception("xTickLabels_ and x axis data must be of the same length")

    #-- set limits on axes
    if (xLimMin_ > -999999 and xLimMax_ > -999999):
        ax.set_xlim(xLimMin_, xLimMax_)

    if (yLimMin_ > -999999 and yLimMax_ > -999999):
        ax.set_ylim(yLimMin_, yLimMax_)

    #-- setup graph
    box = ax.get_position()
    if (len(legendLabels_) > 0):
        #-- make room under the graph and put the legend there
        ax.set_position([box.x0 - box.width*0.05, box.y0 + box.height*0.19, box.width*1.14, box.height*0.9])
        legend = ax.legend(flip(legendItems,legendCols_), flip(legendLabels_,legendCols_), loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=legendCols_)
        for t in legend.get_texts():
            t.set_fontsize(legendFontSize_)
    else:
        ax.set_position([box.x0 - box.width*0.05, box.y0 + box.height*0.05, box.width*1.14, box.height*1])

    #-- if file name empty, show the graph
    if (not holdFigure_):
        if (len(fileName_) > 0):
            fig.savefig(fileName_, format='png')
            if (SHOW_OUTPUT == True):
                print("Saved " + fileName_ + " with dpi " + str(DPI))
        else:
            pylab.show()
    return fig
     
#
def createPlot(xData_ = [], yData_ = [], xLabel_ = "", yLabel_ = "", legendLabels_ = [], markers_ = [], colors_ = [], xTickLabels_=[], yTicksStep_ = 0, yTicksStepMultiplier_ = 1,
               boxPlots_=False, boxPlotWidth_=-1, confidenceIntervals_=False, lineWidth_ = 1, lineStyles_ = [], markerSize_=5, xAxisGroupSize_ = 0,
               extraMarkers_= [] , extraMarkersStep_ = 999999,
               xLimMin_=-999999,xLimMax_=-999999, yLimMin_=-999999,yLimMax_=-999999,
               labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(12,6), 
               legendCols_ = 4, fileName_ = "", holdFigure_=False, figure_=None,
               doWilcoxon_=False, doAverages_=False):
    """
    Create 2D plot based on a number of data sets.
    If boxPlots=true, data in the same group and on the same x position is displayed as box plot.
    
    Specify markers as strings or additionally specify colors as a list of rgb touples.
    
    Specify an array of extraMarkers to show e.g. crosses over lines where density of points is too
    high. Specify extraMarkerStep in x-axis units to be large enough for markers to be apart.

    Returns: the matplotlib figure
    
    Parameters:
    ---------------------------------------
    doWilcoxon_: Set to true to put asterix when the difference of 2 means is significant, based on the Wilcoxon signed-rank test (http://en.wikipedia.org/wiki/Wilcoxon_signed-rank_test)
    yTicksStep_ : If > 0, the yticks will be changed so that each yTicksStep_ value is shown
    yTicksStepMultiplier_: if yTicksStep_>0, additionally this parameter specified how the yTicksStep_ is multiplied when shown on the graph.
    xAxisGroupSize_: if > 0, specifies how many points along the X axis will be linked by a line. E.g if xAxisGroupSize_ = 3, there will be a gap in the line between each 3rd,6th,9th,.. and 4th,7th,10th,.. point.
    
    Example: box plots next to each other:
    ---------------------------------------
    If we have a parameter value we vary and a list of values for each parameter (e.g. from different runs), we
    can plot box plots next to each other and show the parameter value along the X axis:
    groups = [1,2,3,4,5];
    yData = [[dataVals1,dataVals2,dataVals3,dataVals4,dataVals5]];
    createPlot(groups,yData);
    
    Example multiple box plots on top of each other:
    ------------------------------------------------
    For example for box plots, if we have 3 groups (plotted along x axis, where group is just a parameter) 
    and 2 parameter values that we want to compare, we want to plot for each group
    2 box plots one above another. Then the box plots are connected between groups.
    
    groups = [10,30,50];
    yData = [   [data_10_vals1, data_30_vals1,data_50_vals1],
                [data_10_vals2, data_30_vals2,data_50_vals2] ];
    createPlot(groups,yData,boxPlots_=True);
    
    """

    xTickLabels_ = list(xTickLabels_);

    #-- check parameters:
    if (doWilcoxon_):
        if (len(yData_) != 2):
            doWilcoxon_ = False;
            print ("!!! Cannot perform Wilcoxon signed-rank test: Y data must be of size 2");
        else:
            if (type(yData_[0]) != list):
                doWilcoxon_ = False;
                print ("!!! Cannot perform Wilcoxon signed-rank test: each element of Y data must be a list");
        #-- if no xtick labels given, make them the x axis values
        if (len(xTickLabels_) == 0):
            xTickLabels_ = [];
            for i in range(len(xData_)):
                xTickLabels_.append(str(xData_[i]));

    
    #-- create graph
    if (figure_ == None):
        fig = pylab.figure(figsize=size_, dpi=DPI);
    else:
        fig = figure_;
    pylab.xlabel(xLabel_, size=labelFontSize_);
    pylab.ylabel(yLabel_, size=labelFontSize_);
    pylab.xticks(size=tickFontSize_);
    pylab.yticks(size=tickFontSize_);


    #-- prepare x data
    xData = [];
    if (type(yData_[0]) == list):
        #-- more plots on y, look if there are also corresponding multiple arrays for x data
        if (type(xData_[0]) == list):
            #-- both are 2d arrays, leave it
            xData = xData_;
        else:
            #-- more plots on y but only 1 list of x values, copy them over repeatedly:
            for i in range(len(yData_)):
                xData.append(xData_);
    else:
        #-- yData is a simple 1d array, plotting 1 graph only
        if (type(xData_[0]) == list):
            #-- xData is however 2d array, take only the first set
            xData = xData_[0];
        else:
            #-- dimensions match
            xData = xData_
    
    #-- prepare extra marker data
    extraYData = [];
    extraXData = [];
    if (len(extraMarkers_) > 0):
        if (type(yData_[0]) == list):
            for i in range(len(yData_)):
                extraYData.append([]);
                for t in range(len(yData_[i])):
                    if (t%extraMarkersStep_ == 0):
                        extraYData[i].append(yData_[i][t]);
                        if (i==0):
                            extraXData.append(t);
        else:
            for t in range(len(yData_)): 
                if (t%extraMarkersStep_ == 0):
                    extraYData[i].append(yData_[i][t]);
                    extraXData.append(t);
               
    #-- prepare box plot width
    if ((boxPlots_ and boxPlotWidth_ <= 0) or confidenceIntervals_):
        boxPlotWidth_ = abs(xData[0][-1] - xData[0][0]) / 20.0;

    plots = [];
    #-- check Y data and plot
    if (type(yData_[0]) == list):

        #-- plot individual graphs, yData is a 2d array
        for i in range(len(yData_)):
            legendLabel = " ";
            #-- choose a legend label
            if (len(legendLabels_) > i):
                legendLabel = legendLabels_[i];
            #-- choose a marker
            marker = DEFAULT_MARKERS[0];
            if (len(markers_) > i):
                marker = markers_[i];
            elif (len(DEFAULT_MARKERS) > i):
                marker = DEFAULT_MARKERS[i];
            #-- choose a color, default to marker color
            color = marker[0:1];
            if (len(colors_) > i):
                color = colors_[i];
            lineStyle = '-';
            if (len(lineStyles_) > i):
                lineStyle = lineStyles_[i];
            elif (lineWidth_ == 0):
                lineStyle = '';
             
            #-- apply custom x tick labels
            if (len(xTickLabels_) > 0):
                if (len(xTickLabels_) == len(xData[i])):
                    plt.xticks(xData[i],xTickLabels_);
                else:
                    raise Exception("xTickLabels_ and x axis data must be of the same length");



            #-- plot
            if (type(yData_[i]) == list):
                #-- only plot median of data that is a list. Box plots can be added later if set
                #-- get medians one by one, as numpy can't deal with lists of different lengths
                medians = [];
                for q in range(len(yData_[i])):
                   # print(yData_[i]);
                    #print(crData.getMedian(yData_[i][q]));
                    if (doAverages_):
                        medians.append(crData.getAverage(yData_[i][q]));
                    else:
                        medians.append(numpy.median(yData_[i][q]));
                    if (doWilcoxon_ and i == 1):
                        #-- do the Wilcoxon test on individual samples (that together form a median) and compare them to runs of previous data set:
                        #print("Do Wilxocon test for column " + str(q) + " medians " + str(numpy.median(yData_[i][q])) + ", " + str(numpy.median(yData_[0][q])));
                        pVal = scipy.stats.wilcoxon(yData_[i][q],yData_[0][q])[1];
                        #print("  p=" + str(pVal));
                        if (pVal < 0.01):
                            xTickLabels_[q] = str(xTickLabels_[q]) + "**";
                        elif (pVal < 0.05):
                            xTickLabels_[q] = str(xTickLabels_[q]) + "*";

                if (confidenceIntervals_):
                    dataDof = [(len(yData_[i][q])-1) for q in range(len(medians))]; #degrees of freedom is sample size -1
                    dataStd = [numpy.std(yData_[i][q]) for q in range(len(medians))];
                    (_, caps, _) = plt.errorbar(xData[i], medians, yerr=scipy.stats.t.ppf(0.95, dataDof)*dataStd, color=color, linewidth=0, elinewidth=lineWidth_, capsize=markerSize_-2, linestyle=lineStyle);
                    for cap in caps:
                        if (lineWidth_ == 0):
                            cap.set_markeredgewidth(3);
                        else:
                            cap.set_markeredgewidth(lineWidth_);

                #-- draw, in line segments
                numOfSegments = 0;
                lineSegmentLength = xAxisGroupSize_;
                if (xAxisGroupSize_ > 0):
                    numOfSegments = int(math.ceil(len(xData[i]) / xAxisGroupSize_));

                else:
                    numOfSegments = 1;
                    lineSegmentLength = len(xData[i]);

                for seg in range(numOfSegments):
                    segStart = seg * lineSegmentLength;
                    segEnd = segStart + lineSegmentLength;
                    if (segEnd > len(xData[i])):
                        segEnd = len(xData[i]-1);

                    plot = pylab.plot(xData[i][segStart:segEnd], medians[segStart:segEnd], marker, color=color, label = legendLabel, linewidth=lineWidth_, linestyle=lineStyle, markersize=markerSize_);

            else:
                if (confidenceIntervals_):
                    dataDof = [(len(yData_[i])-1) for x in range(len(yData_))]; #degrees of freedom is sample size -1
                    dataStd = [numpy.std(yData_[i]) for x in range(len(yData_))];
                    (_, caps, _) = plt.errorbar(xData[i], medians, yerr=scipy.stats.t.ppf(0.95, dataDof)*dataStd, color=color, linewidth=0, elinewidth=lineWidth_, capsize=markerSize_-2, linestyle=lineStyle, markersize=markerSize_);
                    for cap in caps:
                        if (lineWidth_ == 0):
                            cap.set_markeredgewidth(3);
                        else:
                            cap.set_markeredgewidth(lineWidth_);

                #-- draw, in line segments
                numOfSegments = 0;
                lineSegmentLength = xAxisGroupSize_;
                if (xAxisGroupSize_ > 0):
                    numOfSegments = int(math.ceil(len(xData[i]) / xAxisGroupSize_));

                else:
                    numOfSegments = 1;
                    lineSegmentLength = len(xData[i]);

                for seg in range(numOfSegments):
                    segStart = seg * lineSegmentLength;
                    segEnd = segStart + lineSegmentLength;
                    if (segEnd > len(xData[i])):
                        segEnd = len(xData[i]-1);

                    plot = pylab.plot(xData[i], yData_[i], marker, color=color, label = legendLabel, linewidth=lineWidth_, linestyle=lineStyle, markersize=markerSize_);

            
            #-- do extra markers
            if (len(extraYData) > 0):
                plot = pylab.plot(extraXData,extraYData[i],extraMarkers_[i]);    
                

            plots.append(plot);
            #-- do box plots
            if (boxPlots_):
                boxPlot = pylab.boxplot(yData_[i],positions=xData[i],widths=boxPlotWidth_);
                pylab.setp(boxPlot['boxes'], color=color);
                pylab.setp(boxPlot['whiskers'], color=color);
                pylab.setp(boxPlot['medians'], color=color);
                pylab.setp(boxPlot['fliers'], color=color);
                pylab.setp(boxPlot['caps'], color=color);

                boxPlotLineWidth = min(lineWidth_,2);
                if (boxPlotLineWidth <= 0):
                    boxPlotLineWidth = 1;
                boxPlotLineWidth = 1;
                for box in boxPlot['boxes']:
                    box.set(linewidth=boxPlotLineWidth)
                for median in boxPlot['medians']:
                    median.set(linewidth=boxPlotLineWidth)
                for cap in boxPlot['caps']:
                    cap.set(linewidth=boxPlotLineWidth)
                for cap in boxPlot['whiskers']:
                    cap.set(linewidth=boxPlotLineWidth)


           
    else:
        #-- plot 1 graph, yData is a simple 1d array
        legendLabel = " ";

        marker = DEFAULT_MARKERS[0];
        if (len(legendLabels_) > 0):
            legendLabel = legendLabels_[0];
        if (len(markers_) > 0):
            marker = markers_[0];

        #-- apply custom x tick labels
        if (len(xTickLabels_) > 0):
            if (len(xTickLabels_) == len(xData)):
                plt.xticks(xData,xTickLabels_);
            else:
                raise Exception("xTickLabels_ and x axis data must be of the same length");
        #-- plot
        lineStyle = '-';

        if (len(lineStyles_) > 0):
            lineStyle = lineStyles_[0];

        #-- draw, in line segments
        numOfSegments = 0;
        lineSegmentLength = xAxisGroupSize_;
        if (xAxisGroupSize_ > 0):
            numOfSegments = int(math.ceil(len(xData) / xAxisGroupSize_));

        else:
            numOfSegments = 1;
            lineSegmentLength = len(xData);

        for seg in range(numOfSegments):
            segStart = seg * lineSegmentLength;
            segEnd = segStart + lineSegmentLength;
            if (segEnd > len(xData)):
                segEnd = len(xData-1);

            if (0 < len(colors_)):
                plot = pylab.plot(xData[segStart:segEnd], yData_[segStart:segEnd], marker, color=colors_[0], label = legendLabel, linewidth=lineWidth_, linestyle=lineStyle, markersize=markerSize_);
            else:
                plot = pylab.plot(xData[segStart:segEnd], yData_[segStart:segEnd], marker, label = legendLabel, linewidth=lineWidth_, linestyle=lineStyle, markersize=markerSize_);

            #-- do extra markers
            if (len(extraYData) > 0):
                plot = pylab.plot(extraXData[segStart:segEnd],extraYData[segStart:segEnd],extraMarkers_[0]);
        
        plots.append(plot);

    #-- output graph
    ax = fig.add_subplot(111);
    box = ax.get_position();
    if (figure_ == None):    
        ax.set_position([box.x0 - box.width *0.05, box.y0 + box.height * 0.19, box.width*1.14, box.height * 0.9]);

    #-- adjust x axis zoom because box plots could break it:
    if (boxPlots_ or confidenceIntervals_):
        ax.set_xlim(xData[0][0]-2*boxPlotWidth_/3.0,xData[0][-1]+2*boxPlotWidth_/3.0);

    if (xLimMin_ > -999999 and xLimMax_ > -999999):
        ax.set_xlim(xLimMin_, xLimMax_);

    if (yLimMin_ > -999999 and yLimMax_ > -999999):
        ax.set_ylim(yLimMin_, yLimMax_);

    #-- apply custom y ticks
    ticks = [];
    ticksLabels = [];
    if (yTicksStep_ > 0):
        start, stop = ax.get_ylim();
        if (yLimMax_ != -999999):
            stop = yLimMax_;
        #else:
            #yLimMax_ = stop;
            #print("YES")
        if (yLimMin_ != -999999):
            start = yLimMin_;
        if (yLimMin_ == - yLimMax_*0.05):
            start = 0;

        ticks = numpy.arange(start, stop + yTicksStep_, yTicksStep_);
        ax.set_yticks(ticks);
        if (yTicksStepMultiplier_ != 1):
            for t in range(len(ticks)):
                ticksLabels.append(ticks[t] * yTicksStepMultiplier_);
            #print(ticksLabels);
            ax.set_yticklabels(ticksLabels);
            ticks = ticksLabels;

        ticksLabels = [];
        for t in range(len(ticks)):
            intVal = int(ticks[t]);
            if (intVal >= 1000):
                ticksLabels.append("{}k".format( intVal / 1000));
            else:
                ticksLabels.append(ticks[t]);
        ax.set_yticklabels(ticksLabels);

    #-- modify yTicks so that if numbers are too high, should show 1000s as 'k'
    if (False): #this is unreliable as the yticks are not alwas correct
        for plot in plots:
            yTickLabels = [item.get_text() for item in ax.get_yticklabels()];
            print(yTickLabels)
            shouldConvertHighVals = False;
            maxYValueFound = -999999;

            for y in range(len(yTickLabels)):
                try:
                    if (int(yTickLabels[y]) > maxYValueFound):
                        maxYValueFound = int(yTickLabels[y]);
                except ValueError:
                    pass;

            if (maxYValueFound >= 5000):
                shouldConvertHighVals = True;

            if (shouldConvertHighVals):
                for y in range(len(yTickLabels)):
                    try:
                        intVal = int(yTickLabels[y]);
                        if (intVal >= 1000):
                            yTickLabels[y] = "{}k".format( intVal / 1000);
                    except ValueError:
                        pass;
                print(yTickLabels)
                ax.set_yticklabels(yTickLabels);



    
    legendItems = [];
    for g in range(len(plots)):
        legendItems.append(plots[g][0]);
   
        
    #legend = plt.legend(flip(legendItems,legendCols_), flip(groupLabels_,legendCols_), loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=legendCols_);

    legend = ax.legend(flip(legendItems,legendCols_), flip(legendLabels_,legendCols_),loc='upper center', bbox_to_anchor=(0.5, -0.12), ncol=legendCols_)
    for t in legend.get_texts():
        if (type(legendFontSize_) == str):
            t.set_fontsize(legendFontSize_)
        else:
            font = math.QFont(t.font());
            font.setPointSize(legendFontSize_);
            t.setFont(font);
    if (len(legendLabels_) <= 0):
        legend.set_visible(False);
        if (figure_ == None): 
            ax.set_position([box.x0 - box.width *0.05, box.y0+box.height*0.05, box.width*1.14, box.height * 1.05]);
    
    #-- if file name empty, show the graph

    if (not holdFigure_):
        if (len(fileName_) > 0):
            pylab.savefig(fileName_, format='png')
            if (SHOW_OUTPUT == True):
                print("Saved " + fileName_ + " with dpi " + str(DPI)); 
        else:
            pylab.show();
    return fig;

def create4Dplot(plotData_, xLabel_ = "", yLabel_ = "", zLabel_ = "", xTickLabels_=[], yTickLabels_=[], zTickLabels_=[],
                valueRangeMin_=0,valueRangeMax_=1,colorMap_=plt.cm.get_cmap("OrRd"),
                markerSize_=500,
                labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(14,12),
                fileName_ = "", holdFigure_=False, figure_=None):
    """
    Create a 3D scatter plot where the 4th dimension is encoded into the color of the points.

    plotData_ is a sequence of points, each indexed as [x, y, z, value].
    The displayed scatter plot has the z axis as depth and the y axis as height.

    Parameters:
        plotData_ -- points to draw; point[3] is mapped onto colorMap_
        xLabel_, yLabel_, zLabel_ -- axis labels
        xTickLabels_, yTickLabels_, zTickLabels_ -- when non-empty, the matching axis is limited to
            0..len(labels)-1 and gets one labelled tick per integer position
        valueRangeMin_, valueRangeMax_ -- value range that is mapped onto the color map and shown
            on the color bar (10 evenly spaced levels)
        colorMap_ -- color map used for the points and for the color bar
        markerSize_ -- marker size handed to scatter
        labelFontSize_ -- font size of the axis labels
        tickFontSize_ -- font size of the custom tick labels
        legendFontSize_ -- not used by this function
        size_ -- figure size, only used when a new figure is created
        fileName_ -- when non-empty the graph is saved as png under this name, otherwise it is shown
        holdFigure_ -- when True the graph is neither saved nor shown
        figure_ -- figure to draw into; a new figure is created when None

    Returns the figure that was drawn into.
    Raises ValueError when valueRangeMin_ equals valueRangeMax_.

    NOTE(review): the dummy contour plot, plt.clf() and plt.savefig() act on pyplot's current
    figure, so a passed figure_ is presumably expected to be the current figure - confirm with callers.
    """
    #-- Axes3D is not imported at module level, so bring it into scope here
    from mpl_toolkits.mplot3d import Axes3D

    #-- an empty value range cannot be scaled onto the color map (it would divide by zero)
    valueRange = valueRangeMax_ - valueRangeMin_
    if (valueRange == 0):
        raise ValueError("valueRangeMin_ and valueRangeMax_ must differ: both are {}".format(valueRangeMin_))

    #-- create graph
    if (figure_ is None):
        fig = pylab.figure(figsize=size_, dpi=DPI)
    else:
        fig = figure_

    #-- create dummy contour plot so that color bar can be shown
    Z = [[0,0],[0,0.8]]
    levels = numpy.linspace(valueRangeMin_,valueRangeMax_,10)
    CS3 = plt.contourf(Z, levels, cmap=colorMap_)
    #-- hide the contour plot
    plt.clf()

    ax = Axes3D(fig)
    #-- newer matplotlib releases no longer attach a directly constructed Axes3D to its figure
    if (ax not in fig.axes):
        fig.add_axes(ax)

    #-- points in space, colored by their 4th component
    colorScale = 1.0/valueRange
    for point in plotData_:
        color = colorMap_((point[3]-valueRangeMin_)*colorScale)
        ax.scatter(point[0], point[1], point[2], s=markerSize_, zdir='y', c=color)

    #-- join the points to walls of the graph if they are not on a wall
    #-- (done after all points are drawn, so the drawing order of the artists stays the same)
    maxZVal = len(zTickLabels_)-1
    for point in plotData_:
        if (point[0] > 0 and point[1] > 0 and point[2] < maxZVal):
            ax.plot([0,point[0]], [point[1],point[1]],'k--', zs=point[2], zdir='y')
            ax.plot([point[0],point[0]], [0,point[1]],'k--', zs=point[2], zdir='y')
            ax.plot([point[0],point[0]], [maxZVal,point[2]],'k--', zs=point[1], zdir='z')

    #-- define labels; data y is drawn on the plot's z axis (height) and data z on the plot's y axis (depth)
    ax.set_xlabel(xLabel_,size=labelFontSize_)
    ax.set_ylabel(zLabel_,size=labelFontSize_)
    ax.set_zlabel(yLabel_,size=labelFontSize_)

    #-- custom ticks: positions are set first so that the labels are attached to the final ticks
    if (len(xTickLabels_) > 0):
        ax.set_xlim3d(0, len(xTickLabels_)-1)
        ax.set_xticks(range(len(xTickLabels_)))
        ax.set_xticklabels(xTickLabels_,size=tickFontSize_)
    if (len(yTickLabels_) > 0):
        ax.set_zlim3d(0, len(yTickLabels_)-1)
        ax.set_zticks(range(len(yTickLabels_)))
        ax.set_zticklabels(yTickLabels_,size=tickFontSize_)
    if (len(zTickLabels_) > 0):
        ax.set_ylim3d(0, len(zTickLabels_)-1)
        ax.set_yticks(range(len(zTickLabels_)))
        ax.set_yticklabels(zTickLabels_,size=tickFontSize_)

    #-- create color bar; the contour plot's own axes was cleared, so name the parent axes explicitly
    levelLabels = ['{0:.2}'.format(level) for level in levels]
    colorBar = plt.colorbar(CS3, ax=ax)
    colorBar.ax.set_yticklabels(levelLabels)
    box = colorBar.ax.get_position()
    #-- NOTE(review): the 3rd element is passed as box.x0+box.width, which looks like a right edge rather than a width - confirm
    colorBar.ax.set_position([box.x0, box.y0 + box.height*0.1, box.x0+box.width, box.height * 0.8])

    #-- if file name empty, show the graph
    if (not holdFigure_):
        if (len(fileName_) > 0):
            plt.savefig(fileName_, format='png')
            if (SHOW_OUTPUT):
                print("Saved " + fileName_ + " with dpi " + str(DPI))
        else:
            plt.show()
    return fig


def createScatterPlot(xData_, yData_, xLabel_ = "", yLabel_ = "", xTickLabels_=[], yTickLabels_=[], dataLabels_=[],
                xLimMin_=-999999,xLimMax_=-999999, yLimMin_=-999999, yLimMax_=-999999,
                colors_ = [], sizes_ = 200, marker_ = 'o', showValueLabels_ = True, valueLabelsDecimalPlaces_ = 2, graphId_ = 0,
                labelFontSize_ = LABEL_FONT_SIZE, tickFontSize_ = TICK_FONT_SIZE, legendFontSize_ = LABEL_FONT_SIZE, size_=(12,12),
                fileName_ = "", holdFigure_=False, figure_=None):
    """
    Create a 2D scatter plot.

    Parameters:
        xData_, yData_ -- coordinates of the points
        xLabel_, yLabel_ -- axis labels
        xTickLabels_, yTickLabels_ -- not used by this function
        dataLabels_ -- optional texts, annotated next to the points in order
        xLimMin_, xLimMax_ -- x axis limits; applied (and checked against the data) only when
            both are greater than -999999
        yLimMin_, yLimMax_ -- same for the y axis
        colors_ -- colors of the points; every point is blue when empty
        sizes_ -- marker size, either a scalar or a list
        marker_ -- marker style
        showValueLabels_ -- when True each point is annotated with its "[x,y]" coordinates
        valueLabelsDecimalPlaces_ -- number of decimal places in the value labels
        graphId_ -- start value of the counter that scales the annotation offsets, so that
            annotations of different calls can be placed at different distances
        labelFontSize_ -- font size of the axis labels
        tickFontSize_ -- font size of the tick labels
        legendFontSize_ -- not used by this function
        size_ -- figure size, only used when a new figure is created
        fileName_ -- when non-empty the graph is saved as png under this name, otherwise it is shown
        holdFigure_ -- when True the graph is neither saved nor shown
        figure_ -- figure to draw into; a new figure is created when None

    Returns the figure that was drawn into.
    Raises ValueError when axis limits are given and the data does not fit into them.
    """
    unsetValue = -999999
    useXLimits = (xLimMin_ > unsetValue and xLimMax_ > unsetValue)
    useYLimits = (yLimMin_ > unsetValue and yLimMax_ > unsetValue)

    #-- check inputs; without data there is nothing that could fall outside the limits
    if (useXLimits and len(xData_) > 0):
        if (max(xData_) > xLimMax_):
            raise ValueError("Data does not fit X axis: xMax={} xDataMax={}".format(xLimMax_,max(xData_)))
        if (min(xData_) < xLimMin_):
            raise ValueError("Data does not fit X axis: xMin={} xDataMin={}".format(xLimMin_,min(xData_)))
    if (useYLimits and len(yData_) > 0):
        if (max(yData_) > yLimMax_):
            raise ValueError("Data does not fit Y axis: yMax={} yDataMax={}".format(yLimMax_,max(yData_)))
        if (min(yData_) < yLimMin_):
            raise ValueError("Data does not fit Y axis: yMin={} yDataMin={}".format(yLimMin_,min(yData_)))

    if (len(colors_) == 0):
        colors_ = ['b'] * len(xData_)

    #-- create graph
    if (figure_ is None):
        fig = pylab.figure(figsize=size_, dpi=DPI)
    else:
        fig = figure_

    #-- draw through the axes object, so that a passed figure_ is used even when it is not pyplot's current figure
    ax = fig.add_subplot(111)
    ax.scatter(xData_, yData_, c=colors_, s=sizes_, marker=marker_)

    #-- define labels and ticks
    ax.set_xlabel(xLabel_, size=labelFontSize_)
    ax.set_ylabel(yLabel_, size=labelFontSize_)
    ax.tick_params(axis='both', labelsize=tickFontSize_)

    #-- setup limits
    if (useXLimits):
        ax.set_xlim(xLimMin_, xLimMax_)
    if (useYLimits):
        ax.set_ylim(yLimMin_, yLimMax_)

    #-- data labels: the n-th label is placed further away from its point than the previous one
    for step, (label, x, y) in enumerate(zip(dataLabels_, xData_, yData_), graphId_ + 1):
        ax.annotate(label,
            xy = (x, y), xytext = (20*step, 20*step),
            textcoords = 'offset points', ha = 'right', va = 'bottom',
            bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),
            arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))

    #-- value labels, placed below the points
    if (showValueLabels_):
        decimals = str(valueLabelsDecimalPlaces_)
        formatString = "[{0:." + decimals + "f},{1:." + decimals + "f}]"
        for step, (x, y) in enumerate(zip(xData_, yData_), graphId_ + 1):
            ax.annotate(formatString.format(x, y),
                xy = (x, y), xytext = (-20, -20*step),
                textcoords = 'offset points', ha = 'left', va = 'bottom',
                bbox = dict(boxstyle = 'round,pad=0.5', fc = '#DDDDDD', alpha = 0.5),
                arrowprops = dict(arrowstyle = '-', connectionstyle = 'arc3,rad=0'))

    #-- if file name empty, show the graph
    if (not holdFigure_):
        if (len(fileName_) > 0):
            fig.savefig(fileName_, format='png')
            if (SHOW_OUTPUT):
                print("Saved " + fileName_ + " with dpi " + str(DPI))
        else:
            plt.show()
    return fig

        
def stringToFile(string_, fileName_):
    """
    Write a string into a text file. The file is opened in 'w' mode, so existing content is replaced.

    string_ -- text to write
    fileName_ -- path of the file to write into

    Prints a confirmation when SHOW_OUTPUT is set.
    """
    #-- the with block closes the file even when the write fails
    with open(fileName_, 'w') as myFile:
        myFile.write(string_)
    if (SHOW_OUTPUT == True):
        print("Saved " + fileName_)
        
def flip(items, ncol):
    """
    Reorder items by taking every ncol-th element, starting at offset 0, then at offset 1
    and so on up to offset ncol-1.

    Returns an iterator over the reordered items.
    """
    strides = []
    for offset in range(ncol):
        strides.append(items[offset::ncol])
    return itertools.chain.from_iterable(strides)