# -*- coding: utf-8 -*-
"""
Created on Wed Jul  3 15:03:13 2019

@author: Dhirendra Vaidya (dhirendra22121987@gmail.com)
"""

import sys
#from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QLineEdit, QMainWindow, QFileDialog, QStyle
import PyQt5.QtWidgets as QtWidgets
from PyQt5.QtCore import pyqtSlot, Qt
from PyQt5.QtGui import QIcon, QImage
import io

from matplotlib.backends.backend_qt5agg import (FigureCanvas, NavigationToolbar2QT as NavigationToolbar)
from matplotlib.figure import Figure
import numpy as np
import matplotlib.pyplot as plt

from fitting_models import *
from string import ascii_lowercase

def getVLayout(Widgets):
    """Return a new QVBoxLayout containing ``Widgets``, added in iteration order."""
    column = QtWidgets.QVBoxLayout()
    for widget in Widgets:
        column.addWidget(widget)
    return column

def getHLayout(Widgets):
    """Return a new QHBoxLayout containing ``Widgets``, added in iteration order."""
    row = QtWidgets.QHBoxLayout()
    for widget in Widgets:
        row.addWidget(widget)
    return row

class AdvancedPlot(QtWidgets.QMainWindow):
    """Main window showing an ``nRows`` x ``nCols`` grid of matplotlib axes.

    On top of the standard matplotlib navigation toolbar the window adds
    actions to clear the axes, select lines by clicking (Delete removes the
    selected line, double-click deselects it), type text annotations at a
    clicked position, and copy the figure to the clipboard.  A combo box and
    a "fit" button below the canvas fit the selected line with a linear,
    quadratic or exponential model.

    Parameters
    ----------
    nRows, nCols : int
        Shape of the subplot grid; the axes are stored in ``self.ax`` in
        subplot-index order.
    sharex, sharey : bool
        Share the x / y axis of every subplot with the first one.
    sublabels : bool
        Draw "(a)", "(b)", ... panel labels above each subplot.
    """

    # Directory the toolbar icons were originally loaded from.  Kept as a
    # fallback for when the icons are not found next to this module.
    _LEGACY_ICON_DIR = '/Users/dhirendravaidya/Southampton_Work/Work/paper2/Advanced_Plot'

    # (combo-box label, name of the method performing the fit), in display order.
    _FIT_METHODS = (
        ('Linear', 'linearFit'),
        ('Poly2', 'polyFit2'),
        ('Exponential', 'exponentialFit'),
    )

    def __init__(self, nRows, nCols, sharex=False, sharey=False, sublabels=True):
        QtWidgets.QMainWindow.__init__(self)
        self.sharex = sharex
        self.sharey = sharey
        self.sublabels = sublabels

        # Interaction state.  Everything is initialised before any signal can
        # fire so that no handler ever reads a missing attribute.
        self.selectedLine = None
        self._selectedStyle = None   # (line, markeredgecolor, linewidth) saved on selection
        self.selectionFlag = True
        self.text_string = ''
        self.annoted_text = None
        self.active_axis = None
        self.xy_text = None

        # matplotlib connection ids; None means "not connected".
        self.click_action = None
        self.hover_action = None
        self.annotate_action = None
        self.text_action = None
        self.delete_line_action = None

        self.initUI(nRows, nCols, self.sharex, self.sharey, self.sublabels)

        self.click_action = self.canvas.mpl_connect('button_press_event', self.onClick)
        self._connectHover()

        # The canvas needs keyboard focus to receive key events.
        self.canvas.setFocusPolicy(Qt.ClickFocus)
        self.canvas.setFocus()

        # Mutually exclusive toolbar modes.  Pan/Zoom are looked up by their
        # action text; the original positional indices are only a fallback.
        self.toggledActions = {
            'pan': self._findToolbarAction('Pan', 4),
            'zoom': self._findToolbarAction('Zoom', 5),
            'select': self.selectAction,
            'annotate': self.annotateAction,
        }
        for action in self.toggledActions.values():
            # Bind the action as a default argument to avoid the late-binding
            # closure pitfall; ``checked`` absorbs the bool sent by ``toggled``.
            action.toggled.connect(lambda checked, a=action: self.toggleResolve(a))

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _iconPath(fileName):
        """Return the path of icon ``fileName``.

        The directory containing this module is tried first; if the file is
        not there the original absolute location is returned unchanged.
        """
        import os

        moduleFile = globals().get('__file__')
        if moduleFile:
            candidate = os.path.join(os.path.dirname(os.path.abspath(moduleFile)), fileName)
            if os.path.exists(candidate):
                return candidate
        return AdvancedPlot._LEGACY_ICON_DIR + '/' + fileName

    @staticmethod
    def _panelLabel(index):
        """Return the panel label for a 0-based index: a..z, then aa, ab, ..."""
        label = ''
        index += 1
        while index > 0:
            index, remainder = divmod(index - 1, len(ascii_lowercase))
            label = ascii_lowercase[remainder] + label
        return label

    def _disconnect(self, attrName):
        """Disconnect the matplotlib callback id stored in ``attrName`` (if any) and reset it to None."""
        cid = getattr(self, attrName, None)
        if cid is not None:
            self.canvas.mpl_disconnect(cid)
        setattr(self, attrName, None)

    def _connectHover(self):
        """(Re)connect ``onHover`` so that exactly one hover callback is registered."""
        self._disconnect('hover_action')
        self.hover_action = self.canvas.mpl_connect('motion_notify_event', self.onHover)

    def _findToolbarAction(self, text, fallbackIndex):
        """Return the toolbar action whose text is ``text``, else the one at ``fallbackIndex``."""
        actions = self.NT.actions()
        for action in actions:
            if action.text() == text:
                return action
        return actions[fallbackIndex]

    # ------------------------------------------------------------- toolbar modes

    def toggleResolve(self, selectedAction):
        """Keep the entries of ``toggledActions`` mutually exclusive.

        Called whenever one of them is toggled; when ``selectedAction`` has
        just been checked every other checked action is switched off.
        """
        if not selectedAction.isChecked():
            return
        toolbarOwned = [self.toggledActions.get('pan'), self.toggledActions.get('zoom')]
        selectedIsOurs = all(selectedAction is not a for a in toolbarOwned)
        for action in self.toggledActions.values():
            if action is selectedAction or not action.isChecked():
                continue
            if selectedIsOurs and any(action is a for a in toolbarOwned):
                # setChecked() would only change the button state and leave the
                # toolbar's pan/zoom mode active.  trigger() behaves like a user
                # click, so the toolbar leaves the mode as well.
                action.trigger()
            else:
                action.setChecked(False)

    def copyPlotToClipboard(self):
        """Render the figure as a PNG image and place it on the system clipboard."""
        buf = io.BytesIO()
        # Force a raster format: QImage cannot decode e.g. PDF/SVG, which
        # savefig would produce if rcParams['savefig.format'] were changed.
        self.canvas.figure.savefig(buf, format='png')
        QtWidgets.QApplication.clipboard().setImage(QImage.fromData(buf.getvalue()))

    # ----------------------------------------------------------------- UI set-up

    def initUI(self, nRows, nCols, sharex, sharey, sublabels):
        """Build the canvas, the subplot grid, the toolbar and the fit controls."""
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)

        # Passing the parent installs the layout on the central widget.  A
        # QMainWindow manages its own layout, so it must not get a second one.
        self.mainlayout = QtWidgets.QVBoxLayout(self._main)

        canvas = FigureCanvas(Figure(figsize=(15, 9)))
        self.ax = []
        for i in range(nRows * nCols):
            shareKwargs = {}
            if i > 0:
                # Every later subplot shares the requested axes with the first one.
                if sharex:
                    shareKwargs['sharex'] = self.ax[0]
                if sharey:
                    shareKwargs['sharey'] = self.ax[0]
            self.ax.append(canvas.figure.add_subplot(nRows, nCols, i + 1, **shareKwargs))

        for Ai, axx in enumerate(self.ax):
            if sublabels:
                axx.text(0.05, 1.025, '(' + self._panelLabel(Ai) + ')',
                         transform=axx.transAxes, fontsize=14)
            axx.tick_params(labelsize=12)

        canvas.figure.tight_layout()
        self.canvas = canvas
        self.mainlayout.addWidget(canvas)

        clearAction = QtWidgets.QAction(QIcon(self._iconPath('clear.png')), 'clear', self)
        clearAction.triggered.connect(self.clearPlot)

        selectAction = QtWidgets.QAction(QIcon(self._iconPath('select.png')), 'select on', self)
        selectAction.setCheckable(True)
        # Checked before the slot is connected, so selectionOnOff is not
        # invoked here; __init__ connects the click handler itself.
        selectAction.setChecked(True)
        selectAction.toggled.connect(self.selectionOnOff)

        annotateAction = QtWidgets.QAction(QIcon(self._iconPath('annotate.png')), 'annotate', self)
        annotateAction.setCheckable(True)
        annotateAction.toggled.connect(self.annotateOnOff)

        copyAction = QtWidgets.QAction(QIcon(self._iconPath('copy.png')), 'copy', self)
        copyAction.setShortcut('Ctrl+C')
        copyAction.triggered.connect(self.copyPlotToClipboard)

        # Keep references so __init__ does not have to find them by position.
        self.selectAction = selectAction
        self.annotateAction = annotateAction

        self.NT = NavigationToolbar(canvas, self)
        # Move the toolbar's last action behind the custom actions
        # (presumably the x/y coordinate read-out -- confirm against the
        # matplotlib version in use).
        self.xy_vals = self.NT.actions()[-1]
        self.NT.removeAction(self.xy_vals)
        for action in (clearAction, selectAction, annotateAction, copyAction, self.xy_vals):
            self.NT.addAction(action)
        self.addToolBar(self.NT)

        self.cb_fitting_equation = QtWidgets.QComboBox()
        for label, _methodName in self._FIT_METHODS:
            self.cb_fitting_equation.addItem(label)
        self.cb_fitting_equation.activated[str].connect(self.selectEquation)

        self.b_fit = QtWidgets.QPushButton('fit')
        # Connected exactly once; the equation is looked up at click time.
        self.b_fit.clicked.connect(self._runSelectedFit)

        HL_ = getHLayout([self.cb_fitting_equation, self.b_fit])
        self.mainlayout.addLayout(HL_)

    # ------------------------------------------------------------ line selection

    def selectLine(self, l):
        """Make ``l`` the selected line and highlight it.

        A previously selected line is deselected first.  The line's marker
        edge colour and width are remembered so deselectLine can restore them.
        """
        if l is self.selectedLine:
            return
        if self.selectedLine is not None:
            self.deselectLine(self.selectedLine)
        self.selectedLine = l
        self._selectedStyle = (l, l.get_markeredgecolor(), l.get_linewidth())
        # Reconnect instead of stacking one more key handler per selection.
        self._disconnect('delete_line_action')
        self.delete_line_action = self.canvas.mpl_connect('key_release_event', self.onKeyPress)
        l.set_markeredgecolor('k')
        l.set_linewidth(3)
        self.canvas.draw()

    def deselectLine(self, l):
        """Remove the selection highlight from ``l`` and redraw.

        If ``l`` is the selected line its saved style is restored and the
        selection is cleared.  For any other line the historical defaults
        (no marker edge, width 1.5) are applied and the selection is kept.
        """
        isSelected = l is self.selectedLine
        saved = self._selectedStyle
        if isSelected and saved is not None and saved[0] is l:
            edgeColor, width = saved[1], saved[2]
        else:
            edgeColor, width = 'None', 1.5
        if isSelected:
            self.selectedLine = None
            self._selectedStyle = None
        l.set_markeredgecolor(edgeColor)
        l.set_linewidth(width)
        self.canvas.draw()

    def enableSelection(self):
        """Switch to selection mode: clicks on the canvas select lines."""
        self.disableAnnotation()
        self._disconnect('click_action')
        self.click_action = self.canvas.mpl_connect('button_press_event', self.onClick)

    def disableSelection(self):
        """Stop handling clicks as line selections."""
        self._disconnect('click_action')

    def selectionOnOff(self, checkEvent):
        """Slot for the select action's ``toggled(bool)`` signal."""
        if checkEvent:
            self.enableSelection()
        else:
            self.disableSelection()

    # ---------------------------------------------------------------- annotation

    def enableAnnotation(self):
        """Switch to annotation mode: a click sets the position of a new text."""
        self.disableSelection()
        self._disconnect('annotate_action')
        self.annotate_action = self.canvas.mpl_connect('button_press_event', self.annotate)

    def disableAnnotation(self):
        """Leave annotation mode and stop feeding key strokes into the annotation."""
        self._disconnect('annotate_action')
        self._disconnect('text_action')

    def annotateOnOff(self, checkEvent):
        """Slot for the annotate action's ``toggled(bool)`` signal."""
        if checkEvent:
            self.enableAnnotation()
        else:
            self.disableAnnotation()

    def annotate(self, event):
        """Start a new annotation at the clicked data position.

        Clicks outside any axes are ignored because there is no axes to
        place the text in.
        """
        if event.inaxes is None:
            return
        # Exactly one typing handler, however many annotations were started.
        self._disconnect('text_action')
        self.text_action = self.canvas.mpl_connect('key_release_event', self.typing)
        self.event = event
        self.xy_text = [event.xdata, event.ydata]
        self.active_axis = event.inaxes
        self.annoted_text = None
        self.text_string = ''

    def typing(self, event):
        """Update the annotation being typed from a key event.

        Single-character keys are appended and 'backspace' deletes the last
        character.  All other key names (modifiers, 'enter', combinations
        such as 'ctrl+c', or no key at all) are ignored instead of being
        inserted literally.
        """
        key = event.key
        if key is None or self.active_axis is None:
            return
        if key == 'backspace':
            self.text_string = self.text_string[:-1]
        elif len(key) == 1:
            self.text_string = self.text_string + key
        else:
            return

        if self.annoted_text is None:
            self.annoted_text = self.active_axis.text(
                self.xy_text[0], self.xy_text[1], self.text_string, fontsize=12)
        else:
            self.annoted_text.set_text(self.text_string)
        self.canvas.draw()

    # ------------------------------------------------------------ event handlers

    def deleteLine(self):
        """Remove the selected line from its axes; no-op without a selection."""
        line = self.selectedLine
        if line is None:
            return
        # Clear the selection first so a repeated Delete cannot try to remove
        # the same artist twice.
        self.selectedLine = None
        self._selectedStyle = None
        line.remove()
        self._disconnect('delete_line_action')
        self._connectHover()
        self.canvas.draw()

    def onHover(self, event):
        """Mouse-move handler; intentionally a placeholder that does nothing."""
        pass

    def onClick(self, event):
        """Select the line under a single click; deselect it on a double click."""
        ax = event.inaxes
        if ax is None:
            # Click on the figure margin: nothing to hit-test.
            return
        hits = [l for l in ax.get_lines() if l.contains(event)[0]]
        if not hits:
            return
        if event.dblclick:
            if any(l is self.selectedLine for l in hits):
                self.deselectLine(self.selectedLine)
                self._connectHover()
        else:
            # With overlapping lines the last one in the axes' list wins, as before.
            self.selectLine(hits[-1])
            self._disconnect('hover_action')

    def onKeyPress(self, event):
        """Delete the selected line when the Delete key is released."""
        if event.key == 'delete':
            self.deleteLine()

    # ------------------------------------------------------------------- fitting

    def _plotFit(self, buildCurve):
        """Fit the selected line and plot the fitted curve in the same colour.

        ``buildCurve(xdata, ydata)`` must return a callable evaluating the
        fitted model.  The curve is drawn on 100 points spanning the x-range
        of the data, after which the line is deselected.  Nothing happens
        without a selection or for a line without data.
        """
        l = self.selectedLine
        if l is None:
            return
        xdata = l.get_xdata()
        ydata = l.get_ydata()
        if np.size(xdata) == 0:
            return
        curve = buildCurve(xdata, ydata)
        x_ = np.linspace(np.min(xdata), np.max(xdata), 100)
        l.axes.plot(x_, curve(x_), c=l.get_color())
        self.deselectLine(l)   # also redraws the canvas
        self._connectHover()

    def linearFit(self):
        """Least-squares straight-line fit of the selected line."""
        self._plotFit(lambda x, y: np.poly1d(np.polyfit(x, y, 1)))

    def polyFit2(self):
        """Least-squares second-degree polynomial fit of the selected line."""
        self._plotFit(lambda x, y: np.poly1d(np.polyfit(x, y, 2)))

    def exponentialFit(self):
        """Fit the selected line with ``ExponentialFit`` from fitting_models."""
        def buildCurve(xdata, ydata):
            # NOTE(review): reuses one model instance for fit() and
            # analytical(); assumes analytical() depends only on its
            # arguments -- confirm against fitting_models.
            model = ExponentialFit(xdata, ydata)
            p = model.fit()
            return lambda x: model.analytical(x, p[0], p[1])

        self._plotFit(buildCurve)

    def _runSelectedFit(self, checked=False):
        """Slot of the fit button: run the fit currently chosen in the combo box."""
        methodName = dict(self._FIT_METHODS).get(self.cb_fitting_equation.currentText())
        if methodName is not None:
            getattr(self, methodName)()

    def selectEquation(self, eqnName):
        """Choose the equation used by the fit button.

        Invoked by the combo box when the user picks an entry; when called
        programmatically the combo box is synchronised to ``eqnName``.
        Unknown names are ignored.
        """
        index = self.cb_fitting_equation.findText(eqnName)
        if index >= 0 and index != self.cb_fitting_equation.currentIndex():
            self.cb_fitting_equation.setCurrentIndex(index)

    # --------------------------------------------------------------------- misc

    def clearPlot(self):
        """Clear every axes and forget the selection and pending annotation."""
        for ax in self.ax:
            ax.cla()
        # The artists these attributes referred to are gone after cla().
        self.selectedLine = None
        self._selectedStyle = None
        self.annoted_text = None
        self.text_string = ''
        self._disconnect('delete_line_action')
        self._connectHover()
        self.canvas.draw()

    def set_yaxis_color(self, ax, clr, spine='left'):
        """Colour the given y spine, the y tick marks/labels and the y label of ``ax`` with ``clr``."""
        ax.spines[spine].set_color(clr)
        ax.tick_params(axis='y', colors=clr)
        ax.yaxis.label.set_color(clr)
