/*
 * Decompiled with CFR 0.152.
 */
package org.tigr.microarray.mev.cluster.gui.impl.usc;

import java.awt.Cursor;
import java.awt.Dimension;
import java.awt.Frame;
import java.awt.Toolkit;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Vector;
import javax.swing.BoxLayout;
import javax.swing.JFrame;
import javax.swing.JLabel;
import javax.swing.JPanel;
import javax.swing.JProgressBar;
import javax.swing.SpringLayout;
import org.tigr.microarray.mev.cluster.gui.impl.usc.SpringUtilities;
import org.tigr.microarray.mev.cluster.gui.impl.usc.USCDeltaRhoResult;
import org.tigr.microarray.mev.cluster.gui.impl.usc.USCHyb;
import org.tigr.microarray.mev.cluster.gui.impl.usc.USCHybSet;
import org.tigr.microarray.mev.cluster.gui.impl.usc.USCOrder;
import org.tigr.microarray.mev.cluster.gui.impl.usc.USCRelevanceComparator;
import org.tigr.microarray.mev.cluster.gui.impl.usc.USCResult;

/**
 * Cross-validation driver and classifier core for the USC analysis.
 *
 * <p>For each training fold the classifier computes per-gene and per-class centroids and
 * pooled within-class standard deviations ({@code sis}, plus the median fudge factor
 * {@code s0}). It soft-thresholds the standardized class/overall centroid differences by
 * {@code delta} to get shrunken class centroids, then drops genes flagged as correlated at a
 * threshold derived from rho. Test hybs are scored with a prior-adjusted, standardized squared
 * distance to each shrunken centroid (lower score = better class).
 */
public class USCCrossValidation {
    /** First point of the rho grid, in tenths (rho = 0.5). */
    private static final int RHO_TENTHS_START = 5;
    /** Exclusive upper bound of the rho grid, in tenths (last rho evaluated = 1.0). */
    private static final int RHO_TENTHS_END = 11;
    /** Highest correlation threshold, in tenths, tried first by {@link #doCorrelationTesting}. */
    private static final int CORR_TENTHS_HIGHEST = 9;

    private final int deltaKount;
    private final int deltaMax;
    private final int foldKount;
    private final int xValKount;
    // NOTE(review): the rho range passed to the constructor is stored but the cross-validation
    // grid is the fixed RHO_TENTHS_START..RHO_TENTHS_END grid; confirm whether these should drive it.
    private final double rhoMin;
    private final double rhoMax;
    private final double rhoStep;
    private final double deltaStep;

    /**
     * @param numDeltasP number of delta (shrinkage) values to evaluate
     * @param deltaMaxP  largest delta; delta values are {@code d * deltaMaxP / numDeltasP}
     * @param rhoMinP    stored lower rho bound (see note on the fields)
     * @param rhoMaxP    stored upper rho bound (see note on the fields)
     * @param rhoStepP   stored rho step (see note on the fields)
     * @param numFoldP   number of folds per cross-validation run
     * @param xValKountP number of cross-validation runs
     */
    public USCCrossValidation(int numDeltasP, int deltaMaxP, double rhoMinP, double rhoMaxP, double rhoStepP, int numFoldP, int xValKountP) {
        this.deltaKount = numDeltasP;
        this.deltaMax = deltaMaxP;
        this.rhoMin = rhoMinP;
        this.rhoMax = rhoMaxP;
        this.rhoStep = rhoStepP;
        this.foldKount = numFoldP;
        this.xValKount = xValKountP;
        this.deltaStep = (double) this.deltaMax / (double) this.deltaKount;
    }

    /**
     * Runs the full delta x rho grid for every fold of every cross-validation run, showing a
     * progress window while it works.
     *
     * <p>NOTE(review): the computation runs synchronously on the calling thread; if that is the
     * Swing event thread, the progress window may not repaint until the work completes.
     *
     * @param fullSet the data set providing per-fold train/test arrays
     * @param frame   owner frame whose cursor is set to "wait" for the duration
     * @return results indexed {@code [run][fold][d * numRhos + rhoIndex]}; every slot is non-null
     */
    public USCDeltaRhoResult[][][] crossValidate(USCHybSet fullSet, Frame frame) {
        frame.setCursor(Cursor.getPredefinedCursor(Cursor.WAIT_CURSOR));
        JFrame progressWindow = null;
        try {
            JProgressBar foldBar = createProgressBar(0, this.foldKount * this.xValKount);
            JProgressBar deltaBar = createProgressBar(0, this.deltaKount);
            JProgressBar rhoBar = createProgressBar(RHO_TENTHS_START, RHO_TENTHS_END);
            JProgressBar corrBar = createProgressBar(0, fullSet.getNumGenes());
            progressWindow = showProgressWindow(foldBar, deltaBar, rhoBar, corrBar);

            // Size each fold's result array by the grid that is actually evaluated below.
            int numRhos = RHO_TENTHS_END - RHO_TENTHS_START;
            int resultKount = this.deltaKount * numRhos;
            USCDeltaRhoResult[][][] xResult = new USCDeltaRhoResult[this.xValKount][this.foldKount][];
            int iProgress = 0;
            for (int m = 0; m < this.xValKount; ++m) {
                for (int f = 0; f < this.foldKount; ++f) {
                    deltaBar.setValue(0);
                    USCDeltaRhoResult[] foldResults = new USCDeltaRhoResult[resultKount];
                    xResult[m][f] = foldResults;
                    int iResult = 0;
                    USCHyb[] subTestArray = fullSet.getTestArray(f);
                    USCHyb[] subTrainArray = fullSet.getTrainArray(f);
                    for (int d = 0; d < this.deltaKount; ++d) {
                        // Multiply rather than accumulate so delta values do not drift.
                        double delta = d * this.deltaStep;
                        rhoBar.setValue(RHO_TENTHS_START);
                        for (int r = RHO_TENTHS_START; r < RHO_TENTHS_END; ++r) {
                            // Same r/10 form that doCorrelationTesting uses for its thresholds.
                            double rho = (double) r / 10.0;
                            USCDeltaRhoResult drResult = this.doDR(subTrainArray, subTestArray, delta, rho, fullSet.getNumGenes(), fullSet.getNumClasses(), fullSet.getUniqueClasses(), corrBar, r);
                            // Always advance so no slot is overwritten or left null.
                            foldResults[iResult++] = drResult != null ? drResult : new USCDeltaRhoResult();
                            rhoBar.setValue(r);
                        }
                        deltaBar.setValue(d + 1);
                    }
                    foldBar.setValue(++iProgress);
                }
            }
            return xResult;
        } finally {
            // Restore UI state even if training throws.
            if (progressWindow != null) {
                progressWindow.dispose();
            }
            frame.setCursor(Cursor.getPredefinedCursor(Cursor.DEFAULT_CURSOR));
        }
    }

    /** Creates a determinate progress bar that paints its percentage string. */
    private static JProgressBar createProgressBar(int min, int max) {
        JProgressBar bar = new JProgressBar(min, max);
        bar.setIndeterminate(false);
        bar.setStringPainted(true);
        return bar;
    }

    /** Builds, positions and shows the "Cross Validating" progress window holding the four bars. */
    private static JFrame showProgressWindow(JProgressBar foldBar, JProgressBar deltaBar, JProgressBar rhoBar, JProgressBar corrBar) {
        JPanel mainPanel = new JPanel();
        mainPanel.setLayout(new SpringLayout());
        JPanel leftPanel = new JPanel();
        leftPanel.add(new JLabel("     "));
        JPanel rightPanel = new JPanel();
        rightPanel.add(new JLabel("     "));
        JPanel midPanel = new JPanel();
        midPanel.setLayout(new BoxLayout(midPanel, BoxLayout.Y_AXIS));
        midPanel.add(new JLabel("Cross Validating... Please Wait"));
        midPanel.add(new JLabel("This will take a few minutes"));
        midPanel.add(new JLabel(" "));
        midPanel.add(new JLabel("Fold/CrossVal runs"));
        midPanel.add(foldBar);
        midPanel.add(new JLabel("Deltas"));
        midPanel.add(deltaBar);
        midPanel.add(new JLabel("Rhos"));
        midPanel.add(rhoBar);
        midPanel.add(new JLabel("Pairwise Genes"));
        midPanel.add(corrBar);
        mainPanel.add(leftPanel);
        mainPanel.add(midPanel);
        mainPanel.add(rightPanel);
        SpringUtilities.makeCompactGrid(mainPanel, 1, 3, 0, 0, 0, 0);
        JFrame jf = new JFrame("Cross Validating");
        // Closing the progress window must not terminate the application (was EXIT_ON_CLOSE).
        jf.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
        jf.getContentPane().add(mainPanel);
        jf.setSize(250, 250);
        Dimension screenSize = Toolkit.getDefaultToolkit().getScreenSize();
        jf.setLocation((screenSize.width - 200) / 2, (screenSize.height - 100) / 2);
        jf.setVisible(true);
        return jf;
    }

    /** Training state shared by {@link #doDR} and {@link #testTest}. */
    private static final class TrainedModel {
        /** Genes sorted by relevance, with relevance/correlation flags set. */
        final USCOrder[] order;
        final double[][] shrunkenCentroids;
        final double[] sis;
        final double s0;

        TrainedModel(USCOrder[] order, double[][] shrunkenCentroids, double[] sis, double s0) {
            this.order = order;
            this.shrunkenCentroids = shrunkenCentroids;
            this.sis = sis;
            this.s0 = s0;
        }
    }

    /**
     * Fits the shrunken-centroid model on {@code trainArray}.
     *
     * <p>It thresholds at {@code delta}, ranks genes by their largest absolute shrunken
     * difference, and runs correlation testing when at least one gene is relevant.
     */
    private TrainedModel train(USCHyb[] trainArray, double delta, double rho, int numGenes, int numClasses, String[] uniqueClassLabels, JProgressBar corrBar, int iRho) {
        USCOrder[] order = new USCOrder[numGenes];
        for (int g = 0; g < numGenes; ++g) {
            order[g] = new USCOrder(g);
        }
        double[] geneCentroids = this.computeGeneCentroids(trainArray, numGenes);
        double[][] classCentroids = this.computeClassCentroids(trainArray, uniqueClassLabels, numGenes);
        double[] mks = this.computeMks(trainArray, uniqueClassLabels);
        double[] sis = this.computeSis(trainArray, classCentroids, uniqueClassLabels, numGenes);
        double s0 = this.computeMedian(sis);
        double[][] dik = this.computeRelativeDifferences(classCentroids, geneCentroids, mks, sis, s0);
        double[][] dikShrunk = this.shrinkDiks(delta, dik);
        double[][] shrunkenCentroids = this.computeShrunkenClassCentroid(geneCentroids, mks, sis, s0, dikShrunk);
        // Beta = largest |shrunken d_ik| over classes (NaN values are skipped by the > test).
        for (int g = 0; g < numGenes; ++g) {
            double maxDik = 0.0;
            for (int c = 0; c < numClasses; ++c) {
                double toTest = Math.abs(dikShrunk[c][g]);
                if (toTest > maxDik) {
                    maxDik = toTest;
                }
            }
            order[g].setBeta(maxDik);
        }
        // Must run before sorting: it indexes order[] by original gene index.
        BitSet relevant = this.findRelevantGenes(dikShrunk, order);
        Arrays.sort(order, new USCRelevanceComparator());
        for (int g = 0; g < order.length; ++g) {
            order[g].setIRelevant(g);
        }
        if (!relevant.isEmpty()) {
            this.doCorrelationTesting(order, trainArray, rho, corrBar, geneCentroids, iRho);
        }
        return new TrainedModel(order, shrunkenCentroids, sis, s0);
    }

    /** Counts genes whose {@code USCOrder.use()} is true. */
    private static int countUsedGenes(USCOrder[] order) {
        int kount = 0;
        for (USCOrder o : order) {
            if (o.use()) {
                ++kount;
            }
        }
        return kount;
    }

    /**
     * Trains on {@code trainArray} with the given delta/rho and classifies {@code testArray}.
     *
     * @param rhoBar progress bar updated during correlation testing
     * @param iRho   rho in tenths; correlation thresholds go from 0.9 down to {@code iRho / 10}
     * @return counts of correctly and incorrectly classified test hybs plus the number of genes
     *         used; never null
     */
    public USCDeltaRhoResult doDR(USCHyb[] trainArray, USCHyb[] testArray, double delta, double rho, int numGenes, int numClasses, String[] uniqueClassLabels, JProgressBar rhoBar, int iRho) {
        TrainedModel model = this.train(trainArray, delta, rho, numGenes, numClasses, uniqueClassLabels, rhoBar, iRho);
        double[][] discScores = this.computeDiscriminantScores(trainArray, testArray, model.shrunkenCentroids, model.order, model.sis, model.s0, uniqueClassLabels);
        int iGenes = countUsedGenes(model.order);
        int numWrong = 0;
        int numRight = 0;
        for (int h = 0; h < testArray.length; ++h) {
            // Predicted class = lowest discriminant score (first one wins on ties).
            double dLow = Double.POSITIVE_INFINITY;
            int iMin = 0;
            for (int c = 0; c < discScores[h].length; ++c) {
                if (discScores[h][c] < dLow) {
                    dLow = discScores[h][c];
                    iMin = c;
                }
            }
            if (uniqueClassLabels[iMin].equals(testArray[h].getHybLabel())) {
                ++numRight;
            } else {
                ++numWrong;
            }
        }
        return new USCDeltaRhoResult(delta, rho, numWrong, numRight, iGenes);
    }

    /**
     * Trains on {@code trainArray} and returns the raw discriminant scores for {@code testArray}.
     *
     * @return scores, number of used genes, delta, rho and the ranked gene order; never null
     */
    public USCResult testTest(USCHyb[] trainArray, USCHyb[] testArray, double delta, double rho, int numGenes, int numClasses, String[] uniqueClassLabels, JProgressBar corrBar, int iRho) {
        TrainedModel model = this.train(trainArray, delta, rho, numGenes, numClasses, uniqueClassLabels, corrBar, iRho);
        double[][] discScores = this.computeDiscriminantScores(trainArray, testArray, model.shrunkenCentroids, model.order, model.sis, model.s0, uniqueClassLabels);
        return new USCResult(discScores, countUsedGenes(model.order), delta, rho, model.order);
    }

    /**
     * Scores each test hyb against each class.
     *
     * <p>The score is the sum over used genes of {@code (x - centroid)^2 / (s_i + s0)^2}, minus
     * {@code 2 * ln(training class proportion)}. Lower is better.
     *
     * @return {@code [testHyb][class]} scores; never null
     */
    private double[][] computeDiscriminantScores(USCHyb[] trainArray, USCHyb[] testArray, double[][] shrkClsCntrds, USCOrder[] order, double[] sis, double s0, String[] uniqueClasses) {
        double[][] toReturn = new double[testArray.length][uniqueClasses.length];
        // Class priors depend only on the training set: compute once, not per test hyb.
        double fHybKount = trainArray.length;
        double[] classProbs = new double[shrkClsCntrds.length];
        for (int c = 0; c < shrkClsCntrds.length; ++c) {
            double fHybsInClassKount = this.getNumClassHybsInUSCHybArray(trainArray, uniqueClasses[c]);
            classProbs[c] = Math.log(fHybsInClassKount / fHybKount);
        }
        for (int h = 0; h < testArray.length; ++h) {
            for (int c = 0; c < shrkClsCntrds.length; ++c) {
                double classScore = 0.0;
                for (USCOrder o : order) {
                    if (!o.use()) {
                        continue;
                    }
                    int iOrig = o.getIOriginal();
                    double diff = testArray[h].getRatio(iOrig) - shrkClsCntrds[c][iOrig];
                    double scale = sis[iOrig] + s0;
                    classScore += (diff * diff) / (scale * scale);
                }
                toReturn[h][c] = classScore - 2.0 * classProbs[c];
            }
        }
        return toReturn;
    }

    /**
     * Flags less-relevant genes that are strongly correlated with a more-relevant, still-used gene.
     *
     * <p>Thresholds go from 0.9 down to {@code iRho / 10}; no pass runs when {@code iRho > 9}.
     * On each pass, every gene j after a used gene o has its correlated flag set to
     * {@code |corr(o, j)| >= threshold}. Genes already excluded ({@code use()} false) are skipped.
     *
     * @param rho unused; thresholds come from {@code iRho}
     */
    private void doCorrelationTesting(USCOrder[] order, USCHyb[] trainArray, double rho, JProgressBar corrBar, double[] geneCentroids, int iRho) {
        for (int r = CORR_TENTHS_HIGHEST; r >= iRho; --r) {
            double dRho = (double) r / 10.0;
            for (int o = 0; o < order.length; ++o) {
                if (order[o].use()) {
                    int iFirstGene = order[o].getIOriginal();
                    for (int j = o + 1; j < order.length; ++j) {
                        if (!order[j].use()) {
                            continue;
                        }
                        int iSecondGene = order[j].getIOriginal();
                        if (iFirstGene == iSecondGene) {
                            continue;
                        }
                        double correlation = Math.abs(this.computeCorrelation(trainArray, iFirstGene, iSecondGene, geneCentroids));
                        // NaN (zero variance) compares false, so such pairs are not flagged.
                        order[j].setCorrelated(correlation >= dRho);
                    }
                }
                corrBar.setValue(o);
                corrBar.setStringPainted(true);
            }
        }
    }

    /**
     * Computes the Pearson correlation of two genes across {@code trainArray}, using the
     * precomputed gene means. It returns NaN when either gene has zero variance.
     */
    private double computeCorrelation(USCHyb[] trainArray, int iGeneX, int iGeneY, double[] geneCentroids) {
        double xMean = geneCentroids[iGeneX];
        double yMean = geneCentroids[iGeneY];
        double numSum = 0.0;
        double xSum = 0.0;
        double ySum = 0.0;
        for (USCHyb hyb : trainArray) {
            double dx = hyb.getRatio(iGeneX) - xMean;
            double dy = hyb.getRatio(iGeneY) - yMean;
            numSum += dx * dy;
            xSum += dx * dx;
            ySum += dy * dy;
        }
        return numSum / Math.sqrt(xSum * ySum);
    }

    /** Computes shrunken centroids {@code geneCentroid + m_k * d'_ik * (s_i + s0)}, indexed {@code [class][gene]}. */
    private double[][] computeShrunkenClassCentroid(double[] geneCentroids, double[] mks, double[] sis, double s0, double[][] dikShrunk) {
        double[][] toReturn = new double[mks.length][geneCentroids.length];
        for (int c = 0; c < mks.length; ++c) {
            for (int g = 0; g < geneCentroids.length; ++g) {
                toReturn[c][g] = geneCentroids[g] + mks[c] * dikShrunk[c][g] * (sis[g] + s0);
            }
        }
        return toReturn;
    }

    /** Returns, per gene (column), the largest positive value over classes (rows), or 0. Currently unused. */
    private double[] findBeta(double[][] dikSig) {
        double[] toReturn = new double[dikSig[0].length];
        for (int i = 0; i < toReturn.length; ++i) {
            double currentHigh = 0.0;
            for (int j = 0; j < dikSig.length; ++j) {
                if (dikSig[j][i] > currentHigh) {
                    currentHigh = dikSig[j][i];
                }
            }
            toReturn[i] = currentHigh;
        }
        return toReturn;
    }

    /** Returns a copy of {@code dikShrunk} without the gene columns listed in {@code vRemove}. Currently unused. */
    private double[][] removeInsignificantGenes(double[][] dikShrunk, Vector<Integer> vRemove) {
        int numGenes = dikShrunk[0].length;
        double[][] toReturn = new double[dikShrunk.length][numGenes - vRemove.size()];
        int index = 0;
        for (int i = 0; i < numGenes; ++i) {
            if (vRemove.contains(i)) {
                continue;
            }
            for (int j = 0; j < dikShrunk.length; ++j) {
                toReturn[j][index] = dikShrunk[j][i];
            }
            ++index;
        }
        return toReturn;
    }

    /**
     * Marks genes with a non-zero shrunken difference in any class as relevant.
     *
     * @param order gene records indexed by ORIGINAL gene index (call before sorting)
     * @return bit set of relevant gene indices
     */
    private BitSet findRelevantGenes(double[][] dikShrunk, USCOrder[] order) {
        BitSet toReturn = new BitSet(dikShrunk[0].length);
        for (int i = 0; i < dikShrunk[0].length; ++i) {
            for (int j = 0; j < dikShrunk.length; ++j) {
                if (Math.abs(dikShrunk[j][i]) > 0.0) {
                    toReturn.set(i);
                    order[i].setRelevant(true);
                    break;
                }
            }
        }
        return toReturn;
    }

    /** Mean ratio of each gene over all training hybs. */
    private double[] computeGeneCentroids(USCHyb[] trainArray, int numGenes) {
        double[] toReturn = new double[numGenes];
        for (int i = 0; i < numGenes; ++i) {
            double ratioTotal = 0.0;
            for (int j = 0; j < trainArray.length; ++j) {
                ratioTotal += trainArray[j].getRatio(i);
            }
            toReturn[i] = ratioTotal / (double) trainArray.length;
        }
        return toReturn;
    }

    /**
     * Builds {@code mask[c][k]}, true when hyb k's label equals {@code classLabels[c]} ignoring
     * case. It is computed once so per-gene loops do not repeat the string comparisons.
     */
    private static boolean[][] classMembership(USCHyb[] trainArray, String[] classLabels) {
        boolean[][] mask = new boolean[classLabels.length][trainArray.length];
        for (int c = 0; c < classLabels.length; ++c) {
            for (int k = 0; k < trainArray.length; ++k) {
                mask[c][k] = trainArray[k].getHybLabel().equalsIgnoreCase(classLabels[c]);
            }
        }
        return mask;
    }

    /**
     * Computes the mean ratio per class and gene, indexed {@code [class][gene]}. A class with no
     * training hybs yields NaN.
     */
    private double[][] computeClassCentroids(USCHyb[] trainArray, String[] classLabels, int numGenes) {
        boolean[][] member = classMembership(trainArray, classLabels);
        double[][] toReturn = new double[classLabels.length][numGenes];
        for (int i = 0; i < numGenes; ++i) {
            for (int j = 0; j < classLabels.length; ++j) {
                double total = 0.0;
                int kount = 0;
                for (int k = 0; k < trainArray.length; ++k) {
                    if (!member[j][k]) {
                        continue;
                    }
                    total += trainArray[k].getRatio(i);
                    ++kount;
                }
                toReturn[j][i] = total / (double) kount;
            }
        }
        return toReturn;
    }

    /** Computes {@code d_ik = (classCentroid - geneCentroid) / (m_k * (s_i + s0))}, indexed {@code [class][gene]}. */
    private double[][] computeRelativeDifferences(double[][] classCentroids, double[] geneCentroids, double[] mks, double[] sis, double s0) {
        double[][] toReturn = new double[classCentroids.length][classCentroids[0].length];
        for (int c = 0; c < classCentroids.length; ++c) {
            for (int g = 0; g < classCentroids[0].length; ++g) {
                toReturn[c][g] = (classCentroids[c][g] - geneCentroids[g]) / (mks[c] * (sis[g] + s0));
            }
        }
        return toReturn;
    }

    /** Pooled within-class standard deviation per gene: {@code sqrt(SS_within / (n - K))}. */
    private double[] computeSis(USCHyb[] trainArray, double[][] classCentroids, String[] classLabels, int numGenes) {
        double firstTerm = 1.0 / ((double) trainArray.length - (double) classLabels.length);
        boolean[][] member = classMembership(trainArray, classLabels);
        double[] toReturn = new double[numGenes];
        for (int i = 0; i < numGenes; ++i) {
            double geneSum = 0.0;
            for (int j = 0; j < classLabels.length; ++j) {
                double classDiffSquareSum = 0.0;
                for (int k = 0; k < trainArray.length; ++k) {
                    if (!member[j][k]) {
                        continue;
                    }
                    double difference = trainArray[k].getRatio(i) - classCentroids[j][i];
                    classDiffSquareSum += difference * difference;
                }
                geneSum += classDiffSquareSum;
            }
            toReturn[i] = Math.sqrt(firstTerm * geneSum);
        }
        return toReturn;
    }

    /** Per-class factor {@code m_k = sqrt(1/n_k + 1/n)}, where n_k is the number of training hybs in class k. */
    private double[] computeMks(USCHyb[] trainArray, String[] classLabels) {
        boolean[][] member = classMembership(trainArray, classLabels);
        double[] toReturn = new double[classLabels.length];
        for (int i = 0; i < classLabels.length; ++i) {
            int kount = 0;
            for (int j = 0; j < trainArray.length; ++j) {
                if (member[i][j]) {
                    ++kount;
                }
            }
            toReturn[i] = Math.sqrt(1.0 / (double) kount + 1.0 / (double) trainArray.length);
        }
        return toReturn;
    }

    /** Arithmetic mean of {@code array}. Currently unused. */
    private double computeMean(double[] array) {
        double toReturn = 0.0;
        for (int i = 0; i < array.length; ++i) {
            toReturn += array[i];
        }
        return toReturn / (double) array.length;
    }

    /**
     * Median of {@code array}, which is not modified. For an even length this is the average of
     * the two middle values.
     *
     * @throws IllegalArgumentException if {@code array} is empty
     */
    private double computeMedian(double[] array) {
        if (array.length == 0) {
            throw new IllegalArgumentException("cannot take the median of an empty array");
        }
        double[] copy = array.clone();
        Arrays.sort(copy);
        int half = copy.length / 2;
        if (copy.length % 2 == 0) {
            return (copy[half - 1] + copy[half]) / 2.0;
        }
        return copy[half];
    }

    /** Applies {@link #shrinkDik} element-wise. */
    private double[][] shrinkDiks(double delta, double[][] diks) {
        double[][] toReturn = new double[diks.length][diks[0].length];
        for (int i = 0; i < diks.length; ++i) {
            for (int j = 0; j < diks[0].length; ++j) {
                toReturn[i][j] = this.shrinkDik(delta, diks[i][j]);
            }
        }
        return toReturn;
    }

    /** Soft threshold: {@code sign(dik) * max(|dik| - delta, 0)}. */
    private double shrinkDik(double delta, double dik) {
        double magnitude = Math.abs(dik) - delta;
        if (magnitude < 0.0) {
            return 0.0;
        }
        return dik < 0.0 ? -magnitude : magnitude;
    }

    /** Indices into {@code trainArray} of hybs whose label equals the {@code classIndex}-th unique class of {@code hybSet}. */
    public int[] findHybIndicesForClass(USCHyb[] trainArray, int classIndex, USCHybSet hybSet) {
        String classLabel = hybSet.getUniqueClass(classIndex);
        Vector<Integer> vHybInClass = new Vector<Integer>();
        for (int i = 0; i < trainArray.length; ++i) {
            if (trainArray[i].getHybLabel().equals(classLabel)) {
                vHybInClass.add(i);
            }
        }
        int[] toReturn = new int[vHybInClass.size()];
        for (int i = 0; i < toReturn.length; ++i) {
            toReturn[i] = vHybInClass.elementAt(i);
        }
        return toReturn;
    }

    /** Number of hybs whose label equals {@code label} (case-sensitive). */
    private int getNumClassHybsInUSCHybArray(USCHyb[] hybs, String label) {
        int kount = 0;
        for (USCHyb hyb : hybs) {
            if (hyb.getHybLabel().equals(label)) {
                ++kount;
            }
        }
        return kount;
    }

    /** Hybs whose label equals {@code label} (case-sensitive), in input order. */
    private USCHyb[] getClassHybsInUSCHybArray(USCHyb[] hybs, String label) {
        USCHyb[] toReturn = new USCHyb[this.getNumClassHybsInUSCHybArray(hybs, label)];
        int kount = 0;
        // Scan every input hyb (previously bounded by the output length, which missed matches).
        for (USCHyb hyb : hybs) {
            if (hyb.getHybLabel().equals(label)) {
                toReturn[kount++] = hyb;
            }
        }
        return toReturn;
    }

    /** Base-10 logarithm. Currently unused. */
    private double computeCommonLog(double x) {
        return Math.log(x) / Math.log(10.0);
    }
}

