import PyIPSDK
import PyIPSDK.IPSDKIPLUtility as util
import PyIPSDK.IPSDKIPLGeometricTransform as gtrans

import PyIPSDK.IPSDKIPLArithmetic as arithm
import PyIPSDK.IPSDKIPLBasicMorphology as morpho
import PyIPSDK.IPSDKIPLFiltering as filtering
import PyIPSDK.IPSDKIPLClassification as classif
import PyIPSDK.IPSDKIPLStats as stats
import PyIPSDK.IPSDKIPLShapeAnalysis as shapeanalysis

import UsefullVariables as vrb
import UsefullWidgets as wgt
import UsefullTexts as txt
import UsefullFunctions as fct
import DatabaseFunction as Dfct

import xml.etree.ElementTree as xmlet

# from sklearn.externals import joblib
import joblib
import numpy as np

import XMLtoInfoSet as xml2IS
from ShapeAnalysisDisplayer import measureIterator,measureIteratorByType

def xmlInfosetToMeasure(xmlElement,greyImage, labelImage, dimension="3D",intensity = False):
    """Run the label analysis described by *xmlElement* and return its measures.

    Python source is assembled from three pieces and executed with ``exec``:
    a unit pixel calibration, the measure info set generated by
    ``XMLtoInfoSet`` from *xmlElement*, and a call to
    ``shapeanalysis.labelAnalysis2d``/``labelAnalysis3d`` on
    (*greyImage*, *labelImage*).

    Args:
        xmlElement: XML description of the measures to compute.
        greyImage: intensity image passed to the label analysis.
        labelImage: label image passed to the label analysis.
        dimension: "2D" selects the 2d calibration/analysis; any other value
            selects the 3d ones.
        intensity: when falsy-equal to False, ``xmlToInfoset`` is used
            (without barycenter); otherwise ``xmlToInfosetIntensity``.

    Returns:
        The ``outMeasure`` object produced by the executed code.
    """
    if dimension == "2D":
        dimTag = '2D'
        infoSetName = 'inMeasureInfoSet2d'
        calibrationCode = "calibration = PyIPSDK.createGeometricCalibration2d(1,1,'px')\n"
        analysisCode = 'outMeasure = shapeanalysis.labelAnalysis2d(greyImage,labelImage,inMeasureInfoSet2d)\n'
    else:
        dimTag = '3D'
        infoSetName = 'inMeasureInfoSet3d'
        calibrationCode = "calibration = PyIPSDK.createGeometricCalibration3d(1,1,1,'px')\n"
        analysisCode = 'outMeasure = shapeanalysis.labelAnalysis3d(greyImage,labelImage,inMeasureInfoSet3d)\n'

    if intensity != False:
        infoSetCode = xml2IS.xmlToInfosetIntensity(xmlElement, infoSetName, dimTag, 'calibration')
    else:
        infoSetCode = xml2IS.xmlToInfoset(xmlElement, infoSetName, dimTag, 'calibration', addBarycenter=False)

    source = calibrationCode + infoSetCode + analysisCode

    # The generated code reads greyImage/labelImage from this function's
    # locals and module-level names (PyIPSDK, shapeanalysis) from globals().
    # NOTE(review): the executed source is derived from xmlElement; if the XML
    # can come from an untrusted origin this is a code-injection risk —
    # confirm XMLtoInfoSet validates/escapes what it emits.
    scope = locals()
    exec(source, globals(), scope)
    return scope['outMeasure']

def computeFeaturesClassification(image,imageGrey,xmlElement):
    """Compute the per-label measures used as classification features.

    Single-channel *imageGrey*: every measure listed by ``measureIterator`` is
    computed on it.
    Multichannel *imageGrey*: the "Geometry" measures are computed once, on
    channel 0. The "Intensity" measures are computed for each of the first
    (up to) three channels. Each channel is a plan for 2d images and a volume
    when ``getSizeZ() > 1``.

    Measures that cannot be read back from the analysis result are skipped.

    Args:
        image: label image.
        imageGrey: intensity image (one or more channels).
        xmlElement: XML description of the measures; its "Dimension"
            attribute selects 2d/3d analysis.

    Returns:
        ``np.asarray`` of the collected measure columns (column 0 of each
        measure result), in the order described above.
    """
    dimension = xmlElement.get("Dimension")
    features = []

    if imageGrey.getSizeC() == 1:
        measures = xmlInfosetToMeasure(xmlElement, imageGrey, image, dimension=dimension)
        _appendMeasureResults(measures, measureIterator(xmlElement), features)
    else:
        # At most three channels (R, G, B) are used; a two-channel image no
        # longer fails on a missing third channel.
        nbChannels = min(imageGrey.getSizeC(), 3)
        for c in range(nbChannels):
            # Extract the whole channel: a single plan in 2d, the whole volume
            # otherwise (same rule as addClassicFeatures), so 3d images are not
            # reduced to their first slice.
            if imageGrey.getSizeZ() == 1:
                channel = PyIPSDK.extractPlan(0, c, 0, imageGrey)
            else:
                channel = PyIPSDK.extractVolume(c, 0, imageGrey)
            channel = util.copyImg(channel)

            measures = xmlInfosetToMeasure(xmlElement, channel, image, dimension=dimension, intensity=c != 0)
            if c == 0:
                # Geometry does not depend on the channel: compute it once.
                _appendMeasureResults(measures, measureIteratorByType(xmlElement, "Geometry"), features)
            _appendMeasureResults(measures, measureIteratorByType(xmlElement, "Intensity"), features)

    return np.asarray(features)


def _appendMeasureResults(measures, callNames, features):
    """Append column 0 of each measure named by ``callName[0]`` to *features*, skipping unreadable ones."""
    for callName in callNames:
        try:
            features.append(measures.getMeasure(callName[0]).getMeasureResult().getColl(0))
        except Exception:
            # Best effort, as before: a measure listed in the XML but not
            # available in the analysis result is simply left out.
            continue

def createSeqImage(listImages):
    """Stack 2d images into one Real32 2d sequence image.

    The X/Y size of the sequence is taken from the first image. Image ``k``
    of *listImages* is converted to Real32 when needed and copied into plan
    ``k`` of the sequence.

    Args:
        listImages: non-empty list of 2d images.

    Returns:
        The newly created Real32 sequence image.
    """
    reference = listImages[0]
    seqGeometry = PyIPSDK.geometrySeq2d(PyIPSDK.eIBT_Real32, reference.getSizeX(), reference.getSizeY(), len(listImages))
    imageSeq = PyIPSDK.createImage(seqGeometry)

    for index, source in enumerate(listImages):
        target = PyIPSDK.extractPlan(0, 0, index, imageSeq)
        if source.getBufferType() == PyIPSDK.eIBT_Real32:
            util.copyImg(source, target)
        else:
            util.copyImg(util.convertImg(source, PyIPSDK.eIBT_Real32), target)

    return imageSeq

def createModelIPSDK(xmlElement,model,image,nbLabels = None, dictConversion = None):
    """Convert a trained random forest into an IPSDK random forest model.

    The pixel features are read from the "Features" child of *xmlElement*:
    "Dimension", "All_Sizes" and "MultiRes" (comma-separated lists), plus one
    tag per feature type. Each feature tag holds a comma-separated "True"/"False"
    flag per entry of "All_Sizes". For every resolution in "MultiRes", one IPSDK
    feature is added per enabled (feature, size) pair. The largest enabled size
    becomes the feature list's max shift.

    Args:
        xmlElement: classification settings element.
        model: fitted forest exposing ``estimators_``, each with a ``tree_``
            (scikit-learn ``RandomForestClassifier`` layout).
        image: image the model targets; only ``getSizeC()`` is used.
        nbLabels: forwarded unchanged to ``forest.setNbLabels``.
        dictConversion: optional mapping from a leaf's predicted class index
            (argmax of its value row) to the value stored in the IPSDK leaf.
            Indices absent from the mapping are stored unchanged.

    Returns:
        The populated IPSDK random forest model.
    """
    # None instead of a shared mutable {} default.
    if dictConversion is None:
        dictConversion = {}

    modelElement = PyIPSDK.createRandomForestModel()

    featureList = modelElement.featureList
    featureList.setNbChannels(image.getSizeC())

    featuresElement = Dfct.SubElement(xmlElement,"Features")

    dimension = Dfct.childText(featuresElement,"Dimension")
    if dimension == "2D":
        featureList.setDimension(PyIPSDK.eRandomForestFeaturesDimension.eRFFD_2d)
    else:
        featureList.setDimension(PyIPSDK.eRandomForestFeaturesDimension.eRFFD_3d)

    allSizes = Dfct.childText(featuresElement,"All_Sizes").split(',')
    multiRes = Dfct.childText(featuresElement, "MultiRes").split(',')

    # (supported feature name, per-size flags), in the order in which the
    # features must be added for each resolution. The tags are read once here
    # instead of once per resolution and per size.
    enabledFeatures = []
    featuresName = ["Gaussian", "Mean", "Laplacian_Difference", "High_Pass", "Variance","Erosion", "Dilation", "Opening", "Closing"]
    for name in featuresName:
        if Dfct.containsTag(featuresElement, name):
            enabledFeatures.append((name, Dfct.childText(featuresElement, name).split(",")))
        else:
            print("[INFO] The model does not contain the feature " + name)

    # Obsolete features management: tags written by older versions are mapped
    # to their supported equivalent and added after the current features.
    obsoleteFeatureList = [("Dilatation", "Dilation")]
    for nameObsolete, nameSupported in obsoleteFeatureList:
        if Dfct.containsTag(featuresElement, nameObsolete):
            enabledFeatures.append((nameSupported, Dfct.childText(featuresElement, nameObsolete).split(",")))

    maxShift = 0
    for res in multiRes:
        for name, flags in enabledFeatures:
            maxShift = _addEnabledFeatures(featureList, name, flags, allSizes, res, maxShift)

    featureList.setMaxShift(maxShift)

    forest = modelElement.randomForest
    forest.setNbLabels(nbLabels)

    for estimator in model.estimators_:
        tree = estimator.tree_
        treeElement = PyIPSDK.createRandomForestTree()
        forest.addTree(treeElement)

        for i in range(len(tree.children_left)):
            if tree.children_left[i] != -1:
                # Split node: stored with value -1.
                nodeType = PyIPSDK.eRFTNT_Node
                value = -1
            else:
                # Leaf: index of the largest entry of the node's value row,
                # optionally remapped through dictConversion.
                nodeType = PyIPSDK.eRFTNT_Leaf
                value = int(np.argmax(tree.value[i][0]))
                if value in dictConversion:
                    value = dictConversion[value]

            node = PyIPSDK.createRandomForestTreeNode(i,nodeType,int(tree.feature[i]),float(tree.threshold[i]),int(tree.children_left[i]),int(tree.children_right[i]),value)
            treeElement.addTreeNode(node)

    return modelElement


def _addEnabledFeatures(featureList, name, flags, allSizes, res, maxShift):
    """Add feature *name* at resolution *res* for each size whose flag is "True"; return the updated max shift."""
    # zip: a flag list shorter than All_Sizes means "not enabled" for the
    # missing sizes instead of raising IndexError.
    for size, flag in zip(allSizes, flags):
        if flag == "True":
            sizeValue = int(size)
            feature = PyIPSDK.createRandomForestFeature(converterFeatureName(name), sizeValue, sizeValue, sizeValue, int(res))
            featureList.addFeature(feature)
            maxShift = max(maxShift, sizeValue)
    return maxShift

def converterFeatureName(nameFeature):
    """Map a feature name from the XML settings to its IPSDK random-forest feature type.

    Both "Dilation" and the obsolete "Dilatation" map to ``eRFFT_Dilatation``.

    Args:
        nameFeature: feature name as stored in the XML settings.

    Returns:
        The matching ``PyIPSDK.eRandomForestFeaturesType`` member, or None
        (after printing a message) when the name is not supported.
    """
    # Attribute names are resolved lazily with getattr, so only the requested
    # enum member is ever looked up on PyIPSDK.
    attributeByName = {
        "Gaussian": "eRFFT_GaussianSmoothing",
        "Mean": "eRFFT_MeanSmoothing",
        "Laplacian_Difference": "eRFFT_Laplacian_Difference",
        "High_Pass": "eRFFT_High_Pass",
        "Variance": "eRFFT_Variance",
        "Erosion": "eRFFT_Erosion",
        "Dilation": "eRFFT_Dilatation",
        "Dilatation": "eRFFT_Dilatation",
        "Opening": "eRFFT_Opening",
        "Closing": "eRFFT_Closing",
    }

    attributeName = attributeByName.get(nameFeature)
    if attributeName is None:
        print("Unsupported feature " + nameFeature)
        return None
    return getattr(PyIPSDK.eRandomForestFeaturesType, attributeName)

def addClassicFeatures(nameFunction, size, label, dictFeatures, res, dimension):
    """Add the *nameFunction* feature of every channel of ``label.image`` to *label*.

    For each channel ``c``, the feature image is taken from
    ``label.dictFeatures[size][nameFunction][res][c]`` when that entry exists
    and is not None. Otherwise it is computed with ``featureFromName`` on the
    channel: the image itself when single-channel, a plan in 2d, a volume when
    ``getSizeZ() > 1``. The feature image is then:

    - stored in ``dictFeatures[size][nameFunction][res][c]`` (intermediate
      dicts are created as needed);
    - appended to ``label.listFeatures``;
    - given a display name appended to ``label.listNamesFeatures``.

    Args:
        nameFunction: feature name passed to ``featureFromName``.
        size: filter size; the display name shows ``(2*size+1)x(2*size+1)``.
        label: object with ``image``, ``listFeatures`` and
            ``listNamesFeatures``, and optionally a ``dictFeatures`` cache.
        dictFeatures: output cache, updated in place.
        res: downscaling factor (1 = full resolution).
        dimension: "2D" or "3D", forwarded to ``featureFromName``.
    """
    colors = ["Red", "Green", "Blue"]

    # dictFeatures[size][nameFunction][res][channel] -> feature image
    channelFeatures = dictFeatures.setdefault(size, {}).setdefault(nameFunction, {}).setdefault(res, {})

    image = label.image
    sizeC = image.getSizeC()

    baseName = nameFunction + " " + str(2*size+1) + "x" + str(2*size+1)
    if image.getSizeZ() > 1:
        baseName += " " + dimension
    if res != 1:
        baseName += " (x 1/" + str(res) + ")"

    for c in range(sizeC):
        newFeature = _cachedFeature(label, size, nameFunction, res, c)
        if newFeature is None:
            newFeature = featureFromName(_channelImage(image, c, sizeC), nameFunction, size, res, dimension)
        channelFeatures[c] = newFeature
        label.listFeatures.append(newFeature)

        if sizeC == 1:
            textName = baseName
        elif sizeC == 3:
            # NOTE: "( Red)" (no space before the parenthesis) is kept as is so
            # existing feature names are unchanged.
            textName = baseName + "( " + colors[c] + ")"
        else:
            textName = baseName + " (channel " + str(c) + ")"
        label.listNamesFeatures.append(textName)


def _cachedFeature(label, size, nameFunction, res, channel):
    """Return ``label.dictFeatures[size][nameFunction][res][channel]``, or None when it is not available."""
    try:
        return label.dictFeatures[size][nameFunction][res][channel]
    except (AttributeError, KeyError, IndexError, TypeError):
        # No cache attribute, missing key, or a None level in the cache.
        return None


def _channelImage(image, channel, sizeC):
    """Return the single-channel image to compute a feature on (the image itself when it has one channel)."""
    if sizeC == 1:
        return image
    if image.getSizeZ() == 1:
        return PyIPSDK.extractPlan(0, channel, 0, image)
    return PyIPSDK.extractVolume(channel, 0, image)

def featureFromName(image, nameFunction, size, res, dimension):
    """Compute one pixel-classification feature image.

    The input is converted to Real32. When *res* != 1 it is linearly zoomed
    down by *res* (X/Y, and Z in 3d) before filtering. The feature image is
    then zoomed back to the geometry of the converted input.

    Args:
        image: single-channel IPSDK image.
        nameFunction: one of "Gaussian", "Mean", "Laplacian Difference",
            "High Pass", "Variance", "Erosion", "Dilation", "Opening",
            "Closing". Also accepted: the underscore spellings
            "Laplacian_Difference" and "High_Pass" (as used by the XML feature
            tags) and the obsolete "Dilatation".
        size: size parameter of the filter / structuring element.
        res: integer downscaling factor; 1 means full resolution.
        dimension: "2D" selects the 2d operators; any other value the 3d ones.

    Returns:
        The Real32 feature image.

    Raises:
        ValueError: if *nameFunction* is not a supported feature. Previously
            this surfaced as UnboundLocalError.
    """
    aliases = {
        "Laplacian_Difference": "Laplacian Difference",
        "High_Pass": "High Pass",
        "Dilatation": "Dilation",
    }
    nameFunction = aliases.get(nameFunction, nameFunction)
    is2d = dimension == "2D"

    # Lambdas defer every PyIPSDK lookup until the selected operation runs.
    if is2d:
        operations = {
            "Gaussian": lambda img: filtering.gaussianSmoothing2dImg(img, size),
            "Mean": lambda img: filtering.meanSmoothing2dImg(img, size, size),
            "Laplacian Difference": lambda img: filtering.laplacianDoG2dImg(img, size),
            "High Pass": lambda img: filtering.highPass2dImg(img, size),
            "Variance": lambda img: stats.variance2dImg(img, size, size),
            "Erosion": lambda img: morpho.erode2dImg(img, PyIPSDK.circularSEXYInfo(size)),
            "Dilation": lambda img: morpho.dilate2dImg(img, PyIPSDK.circularSEXYInfo(size)),
            "Opening": lambda img: morpho.opening2dImg(img, PyIPSDK.circularSEXYInfo(size), PyIPSDK.eBEP_Disable),
            "Closing": lambda img: morpho.closing2dImg(img, PyIPSDK.circularSEXYInfo(size), PyIPSDK.eBEP_Disable),
        }
    else:
        operations = {
            "Gaussian": lambda img: filtering.gaussianSmoothing3dImg(img, size),
            "Mean": lambda img: filtering.meanSmoothing3dImg(img, size, size, size),
            "Laplacian Difference": lambda img: filtering.laplacianDoG3dImg(img, size),
            "High Pass": lambda img: filtering.highPass3dImg(img, size),
            "Variance": lambda img: stats.variance3dImg(img, size, size, size),
            "Erosion": lambda img: morpho.erode3dImg(img, PyIPSDK.sphericalSEXYZInfo(size)),
            "Dilation": lambda img: morpho.dilate3dImg(img, PyIPSDK.sphericalSEXYZInfo(size)),
            "Opening": lambda img: morpho.opening3dImg(img, PyIPSDK.sphericalSEXYZInfo(size), PyIPSDK.eBEP_Disable),
            "Closing": lambda img: morpho.closing3dImg(img, PyIPSDK.sphericalSEXYZInfo(size), PyIPSDK.eBEP_Disable),
        }

    operation = operations.get(nameFunction)
    if operation is None:
        # Fail fast, before any conversion/zoom work is done.
        raise ValueError("Unsupported feature " + str(nameFunction))

    image = util.convertImg(image, PyIPSDK.eIBT_Real32)
    imageResized = image if res == 1 else _downscaleForFeature(image, res, is2d)

    imageFeature = operation(imageResized)
    if res == 1:
        return imageFeature

    # Bring the feature back to the geometry of the (converted) input image.
    outImage = PyIPSDK.createImage(image)
    if is2d:
        gtrans.zoom2dImg(imageFeature, PyIPSDK.eZoomInterpolationMethod.eZIM_Linear, outImage)
    else:
        gtrans.zoom3dImg(imageFeature, PyIPSDK.eZoomInterpolationMethod.eZIM_Linear, outImage)
    return outImage


def _downscaleForFeature(image, res, is2d):
    """Return a Real32 linear zoom of *image* with X/Y (and Z when not 2d) divided by *res*, rounded."""
    sizeX = int((image.getSizeX() / res) + 0.5)
    sizeY = int((image.getSizeY() / res) + 0.5)
    if is2d:
        if image.getSizeZ() == 1:
            imageResized = PyIPSDK.createImage(PyIPSDK.eIBT_Real32, sizeX, sizeY)
        else:
            # 2d features on a stack: only X/Y are reduced, Z size is kept.
            imageResized = PyIPSDK.createImage(PyIPSDK.eIBT_Real32, sizeX, sizeY, image.getSizeZ())
        gtrans.zoom2dImg(image, PyIPSDK.eZoomInterpolationMethod.eZIM_Linear, imageResized)
    else:
        sizeZ = int((image.getSizeZ() / res) + 0.5)
        imageResized = PyIPSDK.createImage(PyIPSDK.eIBT_Real32, sizeX, sizeY, sizeZ)
        gtrans.zoom3dImg(image, PyIPSDK.eZoomInterpolationMethod.eZIM_Linear, imageResized)
    return imageResized