import sys

from PyQt5.QtCore import pyqtSignal,QPointF,Qt,QCoreApplication
from PyQt5 import QtGui
import PyQt5.QtWidgets as qt
from PyQt5 import QtCore

import PyIPSDK
import PyIPSDK.IPSDKIPLBinarization as bin
import PyIPSDK.IPSDKIPLGlobalMeasure as glbmsr
import PyIPSDK.IPSDKIPLShapeAnalysis as shapeanalysis
import PyIPSDK.IPSDKUI as ui
import PyIPSDK.IPSDKFunctionsMachineLearning as fctML

import time

try:
    import matplotlib
    matplotlib.use('QT5Agg')
    import matplotlib.pylab as plt
    from matplotlib.backends.qt_compat import QtCore, QtWidgets#, is_pyqt5
    from matplotlib.backends.backend_qt5agg import FigureCanvas, NavigationToolbar2QT as NavigationToolbar
    from matplotlib.figure import Figure
except:
    pass

import UsefullWidgets as wgt
from RangeSlider import RangeSlider

import DatabaseFunction as Dfct

import UsefullFunctions as fct

import PyIPSDK
import UsefullVariables as vrb

import UsefullDisplay as display

import numpy as np

class HistogramLabelWidget(qt.QWidget):
    """
    Widget plotting one intensity-histogram curve per label of an IPSDK
    label-analysis result (the "HistogramMsr" measure).

    Each curve is named (and colored) after its smart-segmentation phase when
    that information can be read from the label image, otherwise "Phase <n>";
    its legend entry also shows the percentage of the image covered by the label.

    Check boxes:
        Normalize -- divide every label bin by the matching bin of the global
                     grey-level histogram (``histogramGrey``); y axis = ratio.
        Log       -- logarithmic y axis (mutually exclusive with Normalize).
        All       -- overlay the bin-wise sum of all label curves (dashed).

    The measure object passed in must carry extra attributes set by the caller:
    ``imageLabel`` (always read) and ``histogramGrey`` (read by Normalize).
    """

    def __init__(self, ipsdkLabelHistogram, xmlElement=None):
        """
        :param ipsdkLabelHistogram: IPSDK label-analysis result containing a
            "HistogramMsr" measure, with ``imageLabel`` (and, for Normalize,
            ``histogramGrey``) attached as attributes.
        :param xmlElement: unused; kept for interface compatibility.
        """
        # NOTE(review): ScalarConstraint is not used in this class; the import is
        # kept in case importing WidgetTypes is relied upon elsewhere.
        from WidgetTypes import ScalarConstraint

        qt.QWidget.__init__(self)
        self.ipsdkLabelHistogram = ipsdkLabelHistogram

        # per-label IPSDK histograms (None for absent labels), refreshed on each redraw
        self.valuesHistogram = []
        # normalized curves in plotting order (public attribute, rebuilt by displayNormalize)
        self.labelHistogram = []
        # normalized curves keyed by label index, used by displayAll
        self._normalizedByLabel = {}

        self.fig = Figure()

        self.plotWidget = FigureCanvas(self.fig)
        self.axes = self.fig.add_subplot(111)
        self.toolbar = wgt.NavigationToolbarWidget(self.plotWidget, self)
        self.checkBoxNormalize = qt.QCheckBox("Normalize")
        self.checkBoxLog = qt.QCheckBox("Log")
        self.checkBoxAll = qt.QCheckBox("All")

        self.groupBoxToolbar = qt.QGroupBox()

        self.upLayout = qt.QGridLayout()
        self.upLayout.addWidget(self.toolbar, 0, 0)
        self.upLayout.addWidget(self.checkBoxNormalize, 0, 1)
        self.upLayout.addWidget(self.checkBoxLog, 0, 2)
        self.upLayout.addWidget(self.checkBoxAll, 0, 3)

        self.groupBoxToolbar.setLayout(self.upLayout)
        self.groupBoxToolbar.setStyleSheet('QGroupBox {border: 0px transparent; }')
        self.groupBoxToolbar.setFixedHeight(40 * vrb.ratio)

        self.layout = qt.QGridLayout()
        self.layout.addWidget(self.groupBoxToolbar, 0, 0)
        self.layout.addWidget(self.plotWidget, 1, 0)

        self.layout.setContentsMargins(10 * vrb.ratio, 10 * vrb.ratio, 10 * vrb.ratio, 10 * vrb.ratio)
        self.layout.setSizeConstraint(qt.QLayout.SetNoConstraint)

        self.setLayout(self.layout)

        self.checkBoxLog.stateChanged.connect(self.displayLog)
        self.checkBoxAll.stateChanged.connect(self.displayAll)
        self.checkBoxNormalize.stateChanged.connect(self.displayNormalize)

        style = fct.getStyleSheet()
        self.setStyleSheet(style)

        self.resize(1200 * vrb.ratio, 600 * vrb.ratio)

    # ------------------------------------------------------------------ helpers

    def _readLabelHistograms(self):
        """Return the per-label histogram collection of the "HistogramMsr" measure."""
        return self.ipsdkLabelHistogram.getMeasure("HistogramMsr").getMeasureResult().getColl(0)

    @staticmethod
    def _readSegmentationInfo(imageLabel):
        """Return the smart-segmentation info of *imageLabel*, or {} when it cannot be read."""
        try:
            info = fctML.readSmartSegmentation(imageLabel)
        except Exception:
            return {}
        return info if info else {}

    @staticmethod
    def _labelInfo(info, numLabel):
        """Return the segmentation entry of label *numLabel*, or None if there is none."""
        if not info:
            return None
        try:
            return info[numLabel]
        except (KeyError, IndexError):
            return None

    @staticmethod
    def _computePixelsPerLabel(imageLabel):
        """Return the per-label pixel counts of *imageLabel* (2d or 3d NbPixels measure)."""
        if imageLabel.getVolumeGeometryType() == PyIPSDK.eVGT_2d:
            measureInfoSet = PyIPSDK.createMeasureInfoSet2d()
            measureName = "NbPixels2dMsr"
            labelAnalysis = shapeanalysis.labelAnalysis2d
        else:
            measureInfoSet = PyIPSDK.createMeasureInfoSet3d()
            measureName = "NbPixels3dMsr"
            labelAnalysis = shapeanalysis.labelAnalysis3d
        PyIPSDK.createMeasureInfo(measureInfoSet, measureName)
        analysis = labelAnalysis(imageLabel, imageLabel, measureInfoSet)
        return analysis.getMeasure(measureName).getMeasureResult().getColl(0)

    @staticmethod
    def _countImagePixels(imageLabel):
        """Return the total number of pixels of *imageLabel* (X * Y * Z, valid in 2d and 3d)."""
        return imageLabel.getSizeX() * imageLabel.getSizeY() * imageLabel.getSizeZ()

    @staticmethod
    def _formatFraction(nbPixels, totalNumberPixels):
        """Return the percentage of the image covered by a label, formatted by fct.numberCalibration."""
        return fct.numberCalibration(100 * nbPixels / totalNumberPixels)

    @staticmethod
    def _binPositions(dictHisto):
        """Return evenly spaced x positions from Min to Max, one per histogram bin."""
        return np.linspace(dictHisto["Min"], dictHisto["Max"], len(dictHisto["Frequencies"]))

    @staticmethod
    def _normalizeFrequencies(labelFrequencies, greyFrequencies):
        """
        Divide each label bin by the matching bin of the grey histogram.

        Bins where the grey histogram is empty repeat the previous ratio (0 before
        the first non-empty bin). Returns ``(ratios, minRatio, maxRatio)``; min/max
        only consider non-empty grey bins and stay at ``(sys.maxsize, 0)`` if none.
        """
        ratios = []
        lastValue = 0
        minVal = sys.maxsize
        maxVal = 0
        for i, labelCount in enumerate(labelFrequencies):
            greyCount = greyFrequencies[i]
            if greyCount != 0:
                lastValue = labelCount / greyCount
                minVal = min(minVal, lastValue)
                maxVal = max(maxVal, lastValue)
            ratios.append(lastValue)
        return ratios, minVal, maxVal

    def _plotLabelCurve(self, info, numLabel, xValues, frequencies, fraction):
        """Plot one label curve, named/colored from the segmentation info when available."""
        suffix = "     " + str(fraction) + " %"
        entry = self._labelInfo(info, numLabel)
        if entry is None:
            self.axes.plot(xValues, frequencies, label="Phase " + str(numLabel) + suffix, linewidth=3)
            return
        # "Color" is stored as an "R,G,B" string with 0-255 components
        red, green, blue = (int(component) / 255 for component in entry["Color"].split(",")[:3])
        self.axes.plot(xValues, frequencies, color=(red, green, blue),
                       label=str(entry["Name"]) + suffix, linewidth=3)

    def _finishPlot(self, yLabel, hasCurves):
        """Set axis labels and legend (only when curves exist), then redraw the canvas once."""
        if hasCurves:
            self.axes.set_xlabel('Intensity', size=25)
            self.axes.set_ylabel(yLabel, size=25)
            self.axes.legend(loc=1, prop={'size': 20})
        self.plotWidget.draw()

    # ------------------------------------------------------------------ display

    def initValues(self):
        """Plot the raw per-label histograms (y axis: number of pixels)."""
        self.axes.clear()
        self.valuesHistogram = self._readLabelHistograms()

        imageLabel = self.ipsdkLabelHistogram.imageLabel
        info = self._readSegmentationInfo(imageLabel)
        pixelsLabelValues = self._computePixelsPerLabel(imageLabel)
        totalNumberPixels = self._countImagePixels(imageLabel)

        hasCurves = False
        # index loop: entries of the IPSDK collections are addressed by label index
        for numLabel in range(len(self.valuesHistogram)):
            valueHisto = self.valuesHistogram[numLabel]
            if valueHisto is None:
                continue
            dictHisto = PyIPSDK.toPyDict(valueHisto)
            fraction = self._formatFraction(pixelsLabelValues[numLabel], totalNumberPixels)
            self._plotLabelCurve(info, numLabel, self._binPositions(dictHisto),
                                 dictHisto["Frequencies"], fraction)
            hasCurves = True

        self._finishPlot('Number of Pixels', hasCurves)

    def stopDisplay(self):
        """Close the widget."""
        self.close()

    def display(self, xmlElement=None, boolShow=True):
        """Set the window title from *xmlElement*, redraw the raw histograms and optionally show the window."""
        self.setWindowTitle(Dfct.childText(xmlElement, 'Name'))
        self.initValues()
        if boolShow:
            if self.isMaximized():
                self.showMaximized()
            else:
                self.showNormal()
            self.window().raise_()
            self.window().activateWindow()

    def displayLog(self):
        """Switch the y axis between log and linear scale; Normalize is disabled while Log is on."""
        useLog = self.checkBoxLog.isChecked()
        if useLog:
            self.checkBoxNormalize.setChecked(False)
        self.checkBoxNormalize.setDisabled(useLog)
        self.axes.set_yscale("log" if useLog else "linear")
        self.plotWidget.draw()

    def displayNormalize(self):
        """
        Toggle the normalized view: each label bin divided by the grey-histogram bin.

        While Normalize is checked, Log is unchecked and disabled. Unchecking
        restores the raw view (and the "All" curve if All is checked).
        """
        if not self.checkBoxNormalize.isChecked():
            self.checkBoxLog.setDisabled(False)
            if self.checkBoxLog.isChecked():
                self.displayLog()
            else:
                self.initValues()
                # initValues clears the axes: re-add the sum curve if requested
                if self.checkBoxAll.isChecked():
                    self.displayAll()
            self.plotWidget.draw()
            return

        self.checkBoxLog.setChecked(False)
        self.checkBoxLog.setDisabled(True)
        self.axes.clear()

        self.valuesHistogram = self._readLabelHistograms()

        imageLabel = self.ipsdkLabelHistogram.imageLabel
        info = self._readSegmentationInfo(imageLabel)
        greyFrequencies = PyIPSDK.toPyDict(self.ipsdkLabelHistogram.histogramGrey)["Frequencies"]
        pixelsLabelValues = self._computePixelsPerLabel(imageLabel)
        totalNumberPixels = self._countImagePixels(imageLabel)

        # rebuilt from scratch so repeated toggles do not accumulate stale curves
        self.labelHistogram = []
        self._normalizedByLabel = {}

        hasCurves = False
        for numLabel in range(len(self.valuesHistogram)):
            valueHisto = self.valuesHistogram[numLabel]
            if valueHisto is None:
                continue
            dictHisto = PyIPSDK.toPyDict(valueHisto)
            fraction = self._formatFraction(pixelsLabelValues[numLabel], totalNumberPixels)

            newHistogram, self.minVal, self.maxVal = self._normalizeFrequencies(
                dictHisto["Frequencies"], greyFrequencies)
            self.labelHistogram.append(newHistogram)
            self._normalizedByLabel[numLabel] = newHistogram

            self._plotLabelCurve(info, numLabel, self._binPositions(dictHisto), newHistogram, fraction)
            hasCurves = True

        self._finishPlot('Ratio', hasCurves)

        if self.checkBoxAll.isChecked():
            self.displayAll()

    def displayAll(self):
        """
        Toggle the dashed "All" curve: the bin-wise sum of every label curve
        (raw frequencies, or normalized ratios when Normalize is checked).
        Unchecking redraws the current view without the sum.
        """
        if not self.checkBoxAll.isChecked():
            self.axes.clear()
            if self.checkBoxNormalize.isChecked():
                self.displayNormalize()
            else:
                self.initValues()
            if self.checkBoxLog.isChecked():
                self.displayLog()
            self.plotWidget.draw()
            return

        useNormalized = self.checkBoxNormalize.isChecked()
        total = None
        minIntensity = None
        maxIntensity = None

        for numLabel in range(len(self.valuesHistogram)):
            valueHisto = self.valuesHistogram[numLabel]
            if valueHisto is None:
                continue
            dictHisto = PyIPSDK.toPyDict(valueHisto)

            if useNormalized:
                frequencies = np.asarray(self._normalizedByLabel[numLabel])
            else:
                frequencies = np.asarray(dictHisto["Frequencies"])
            # np.add returns a new array, so the source frequencies are never modified
            total = frequencies if total is None else np.add(total, frequencies)

            # overall intensity range covered by the label histograms
            if minIntensity is None:
                minIntensity, maxIntensity = dictHisto["Min"], dictHisto["Max"]
            else:
                minIntensity = min(minIntensity, dictHisto["Min"])
                maxIntensity = max(maxIntensity, dictHisto["Max"])

        if total is None:
            # no label histogram to sum
            return

        xValues = np.linspace(minIntensity, maxIntensity, len(total))
        self.axes.plot(xValues, total, '--', color='black', label="All", linewidth=5)

        self.axes.legend(loc=1, prop={'size': 20})
        self.plotWidget.draw()



class HistogramLabelTableWidget(qt.QWidget):
    """
    Widget displaying the per-label histograms of an IPSDK label-analysis
    result (the "HistogramMsr" measure) as a table.

    The first column ("Intensity") holds the bin intensities. Each following
    column holds the bin frequencies of one label and is named after its
    smart-segmentation phase when that can be read from ``imageLabel``,
    otherwise "Phase <n>".
    """

    def __init__(self, ipsdkLabelHistogram):
        """
        :param ipsdkLabelHistogram: IPSDK label-analysis result containing a
            "HistogramMsr" measure; an optional ``imageLabel`` attribute is used
            to read the segmentation phase names.
        """
        # NOTE(review): ScalarConstraint is not used in this class; the import is
        # kept in case importing WidgetTypes is relied upon elsewhere.
        from WidgetTypes import ScalarConstraint
        qt.QWidget.__init__(self)

        self.ipsdkLabelHistogram = ipsdkLabelHistogram

        # column name -> column values, consumed by display.TableWidget
        self.ddict = {}

        self.valuesHistogram = self.ipsdkLabelHistogram.getMeasure("HistogramMsr").getMeasureResult().getColl(0)
        info = self._readSegmentationInfo()

        nbLabels = 0
        # index loop: entries of the IPSDK collection are addressed by label index
        for numLabel in range(len(self.valuesHistogram)):
            valueHisto = self.valuesHistogram[numLabel]
            if valueHisto is None:
                continue
            nbLabels += 1
            dictHisto = PyIPSDK.toPyDict(valueHisto)

            if nbLabels == 1:
                # first label histogram defines the intensity column: bin i -> Min + i * BinWidth
                self.ddict["Intensity"] = [dictHisto["Min"] + i * dictHisto["BinWidth"]
                                           for i in range(len(dictHisto["Frequencies"]))]

            # fall back to "Phase <n>" so every label gets a column, even without segmentation info
            entry = self._labelInfo(info, numLabel)
            columnName = "Phase " + str(numLabel) if entry is None else entry["Name"]
            self.ddict[columnName] = dictHisto["Frequencies"]

        self.tableWidget = display.TableWidget(self.ddict)
        self.tableWidget.loadDictionary()

        style = fct.getStyleSheet()
        self.setStyleSheet(style)

        # NOTE(review): not used in this class; kept as a public attribute.
        self.barColors = [QtGui.QColor(66, 134, 244)]

        self.buttonSave = wgt.PushButtonImage(vrb.folderImages + "/Save.png", margins=2)
        self.buttonSave.setFixedSize(30 * vrb.ratio, 30 * vrb.ratio)

        self.layout = qt.QGridLayout()
        self.layout.addWidget(self.buttonSave, 0, 0, Qt.AlignLeft)
        self.layout.addWidget(self.tableWidget, 1, 0)

        self.layout.setContentsMargins(15, 15, 15, 15)
        self.layout.setSizeConstraint(qt.QLayout.SetNoConstraint)

        self.setLayout(self.layout)

        self.buttonSave.clicked.connect(self.saveFileResult)

        # about 95 px per column: the intensity column plus one per label
        self.resize((nbLabels + 1) * 95 * vrb.ratio, 400 * vrb.ratio)
        self.setWindowTitle("Histogram Label Table")

    def _readSegmentationInfo(self):
        """Return the smart-segmentation info of ``imageLabel``, or {} when it cannot be read."""
        try:
            info = fctML.readSmartSegmentation(self.ipsdkLabelHistogram.imageLabel)
        except Exception:
            return {}
        return info if info else {}

    @staticmethod
    def _labelInfo(info, numLabel):
        """Return the segmentation entry of label *numLabel*, or None if there is none."""
        if not info:
            return None
        try:
            return info[numLabel]
        except (KeyError, IndexError):
            return None

    def saveFileResult(self):
        """Ask for a destination file and export the table as .xls or .csv, per the selected filter."""
        fileName, selectedFilter = qt.QFileDialog.getSaveFileName(self, "Save file as", "", "(*.xls);;(*.csv)")
        if not fileName:
            # dialog cancelled
            return
        if selectedFilter == "(*.xls)":
            fct.convertToXlsFile(self.tableWidget, fileName, title="Table", header=True)
        elif selectedFilter == "(*.csv)":
            fct.convertToCsvFile(self.tableWidget, fileName)

    def stopDisplay(self):
        """Close the widget."""
        self.close()

    def display(self, xmlElement=None, boolShow=True):
        """Set the window title from *xmlElement* and optionally show and raise the window."""
        self.setWindowTitle(Dfct.childText(xmlElement, 'Name'))
        if boolShow:
            if self.isMaximized():
                self.showMaximized()
            else:
                self.showNormal()
            self.window().raise_()
            self.window().activateWindow()


if __name__ == '__main__':
    # Manual test harness: builds a 2d label histogram from a label image and a
    # grey image, then shows it in HistogramLabelTableWidget.
    # Usage: python <this file> [labelImage.tif [greyImage.tif]]
    # Without arguments the original developer's local test images are used.

    app = QCoreApplication.instance()
    if app is None:
        app = qt.QApplication([])

    # Print uncaught exceptions, chain to the previous hook, then exit with status 1.
    sys._excepthook = sys.excepthook

    def exception_hook(exctype, value, traceback):
        print(exctype, value, traceback)
        sys._excepthook(exctype, value, traceback)
        sys.exit(1)

    sys.excepthook = exception_hook

    cliArgs = sys.argv[1:]
    labelImagePath = cliArgs[0] if len(cliArgs) > 0 else "C:/Users/ae/Documents/labImgHistoLab.tif"
    greyImagePath = cliArgs[1] if len(cliArgs) > 1 else "C:/Users/ae/Documents/greyImgHistoLab.tif"

    statCategory = PyIPSDK.loadTiffImageFile(labelImagePath)
    AnisotropicDiffusion = PyIPSDK.loadTiffImageFile(greyImagePath)

    # histogram with bins of width 10 over the [0, 255] intensity range
    minValue = 0
    maxValue = 255
    inMeasureInfoSet2d = PyIPSDK.createMeasureInfoSet2d()
    PyIPSDK.createMeasureInfo(inMeasureInfoSet2d, 'HistogramMsr',
                              shapeanalysis.createHistogramMsrParamsBinWidth(10, minValue, maxValue))
    HistogramLabel_0 = shapeanalysis.labelAnalysis2d(AnisotropicDiffusion, statCategory, inMeasureInfoSet2d)

    # attribute read by the widgets to look up the smart-segmentation phases
    HistogramLabel_0.imageLabel = statCategory

    # To try the plot widget instead, use HistogramLabelWidget(HistogramLabel_0);
    # its Normalize option also needs HistogramLabel_0.histogramGrey to be set.
    foo = HistogramLabelTableWidget(HistogramLabel_0)

    foo.display()

    app.exec_()