/*
 * Decompiled with CFR 0.152.
 */
package org.tigr.microarray.mev.cluster.algorithm.impl.gsea;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;
import java.util.Vector;
import javax.swing.JOptionPane;
import org.tigr.microarray.mev.cluster.algorithm.AbstractAlgorithm;
import org.tigr.microarray.mev.cluster.algorithm.AlgorithmData;
import org.tigr.microarray.mev.cluster.algorithm.AlgorithmException;
import org.tigr.microarray.mev.cluster.algorithm.impl.gsea.ProcessGroupAssignments;
import org.tigr.microarray.mev.cluster.gui.impl.util.MatrixFunctions;
import org.tigr.util.FloatMatrix;

public class GSEA
extends AbstractAlgorithm {
    private Random aRandom = new Random();
    private LinkedHashMap overEnrichedPVals = new LinkedHashMap();
    private LinkedHashMap underEnrichedPVals = new LinkedHashMap();
    private boolean stop = false;

    /**
     * Requests cancellation of the analysis by raising the {@code stop} flag.
     * The flag is only read by the code that runs the analysis; this method
     * itself does nothing else.
     */
    public void abort() {
        stop = true;
    }

    /**
     * Runs the permutation-based gene set analysis on {@code data}.
     *
     * <p>The number of permutations is read from the {@code "permutations"}
     * parameter of {@code data}. Results are written back into {@code data}
     * by the analysis itself.
     *
     * @param data algorithm input; must carry a {@code "permutations"} parameter
     * @return {@code data} after the analysis has run, or {@code null} when
     *         {@link #abort()} was called before the analysis started
     * @throws AlgorithmException if the permutation count is missing, not a
     *         positive integer, or if the analysis fails
     */
    public AlgorithmData execute(AlgorithmData data) throws AlgorithmException {
        if (this.stop) {
            return null;
        }
        String rawPermutations = data.getParams().getString("permutations");
        int num_perms;
        try {
            // parseInt(null) also raises NumberFormatException, so a missing parameter ends up here.
            num_perms = Integer.parseInt(rawPermutations);
        }
        catch (NumberFormatException e) {
            throw new AlgorithmException("GSEA: 'permutations' is not a valid integer: " + rawPermutations);
        }
        if (num_perms < 1) {
            throw new AlgorithmException("GSEA: 'permutations' must be at least 1 but was " + num_perms);
        }
        try {
            this.gsealmPerm(num_perms, data, true);
        }
        catch (AlgorithmException e) {
            throw e;
        }
        catch (Exception e) {
            // Report the failure to the caller instead of printing it and returning
            // a result object that silently lacks the analysis output.
            AlgorithmException failure = new AlgorithmException("GSEA analysis failed: " + e);
            failure.initCause(e);
            throw failure;
        }
        return data;
    }

    /**
     * Fits the same linear model to every row of the expression matrix and
     * returns the coefficient estimates together with their variances.
     *
     * <p>With {@code X = factor_matrix} and {@code E} the expression matrix
     * (optionally reduced by {@link #removeUnassignedSamples}), the method
     * computes {@code (X'X)^-1}, the matrix {@code I - X (X'X)^-1 X'}, the
     * coefficients as the solution of {@code (X'X) b = (E X)'}, and a per-row
     * variance as the row sums of {@code E .* (E (I - H))} divided by
     * {@code nSamp - ncol(X)}.
     *
     * @param aData         source of {@code "gene-data-matrix"} and {@code "unassigned-samples"}
     * @param factor_matrix design matrix {@code X}
     * @param removeNA      when true, unassigned sample columns are dropped first
     * @return a table holding {@code "lmPerGene-coefficients"} and
     *         {@code "lmPerGene-coefvar"}; {@code null} when {@code X'X} cannot
     *         be inverted; a table missing those entries when a later step
     *         throws (the stack trace is printed)
     */
    private Hashtable<String, FloatMatrix> lmPerGene(AlgorithmData aData, FloatMatrix factor_matrix, boolean removeNA) {
        FloatMatrix expression = aData.getGeneMatrix("gene-data-matrix");
        Hashtable<String, FloatMatrix> fit = new Hashtable<String, FloatMatrix>();
        if (removeNA) {
            expression = this.removeUnassignedSamples(aData.getVector("unassigned-samples"), expression);
        }
        int sampleCount = expression.getColumnDimension();
        FloatMatrix designT = factor_matrix.transpose();
        FloatMatrix gram = designT.times(factor_matrix);
        FloatMatrix eye = FloatMatrix.identity(gram.getRowDimension(), gram.getColumnDimension());

        FloatMatrix gramInverse;
        try {
            gramInverse = gram.solve(eye);
        }
        catch (Exception notInvertible) {
            // Callers treat null as "this design could not be fitted".
            return null;
        }

        try {
            FloatMatrix hat = factor_matrix.times(gramInverse).times(designT);
            FloatMatrix residualMaker = this.createDiagonalMatrix(sampleCount, sampleCount, 1).minus(hat);
            int coefficientCount = factor_matrix.getColumnDimension();
            FloatMatrix crossProduct = expression.times(factor_matrix);
            FloatMatrix residuals = expression.times(residualMaker);
            FloatMatrix beta = gram.solve(crossProduct.transpose());
            FloatMatrix weightedResiduals = expression.arrayTimes(residuals);
            float[] rowVariance = this.getRowSumandSome(weightedResiduals, sampleCount - coefficientCount);
            Vector gramInverseDiagonal = this.returnDiagonal(gramInverse);
            FloatMatrix unscaledVariance = this.createMatrix(gramInverseDiagonal, expression.getRowDimension(), coefficientCount);
            FloatMatrix scaledVariance = this.apply(unscaledVariance, new int[]{2}, rowVariance, "apply-default");
            fit.put("lmPerGene-coefficients", beta);
            fit.put("lmPerGene-coefvar", scaledVariance.transpose());
        }
        catch (Exception e) {
            e.printStackTrace();
        }
        return fit;
    }

    /**
     * Builds a {@code row x col} matrix whose column {@code j} is filled with
     * a single value taken from {@code elements}.
     *
     * <p>Column {@code j} receives {@code elements.get(j % elements.size())},
     * i.e. the values are recycled when there are more columns than elements.
     * The previous implementation reset its index one step too late and threw
     * an index-out-of-bounds error as soon as recycling was needed.
     *
     * @param elements values to spread over the columns; every entry must be a {@code Float}
     * @param row      number of rows of the result
     * @param col      number of columns of the result
     * @return the filled matrix
     * @throws IllegalArgumentException if there are cells to fill but {@code elements} is empty
     */
    public FloatMatrix createMatrix(Vector elements, int row, int col) {
        FloatMatrix customMatrix = new FloatMatrix(row, col);
        if (row <= 0 || col <= 0) {
            return customMatrix;
        }
        int available = elements.size();
        if (available == 0) {
            throw new IllegalArgumentException("createMatrix: no elements supplied for a " + row + "x" + col + " matrix");
        }
        for (int j = 0; j < col; ++j) {
            float columnValue = ((Float)elements.get(j % available)).floatValue();
            for (int i = 0; i < row; ++i) {
                customMatrix.set(i, j, columnValue);
            }
        }
        return customMatrix;
    }

    /**
     * Collects the main-diagonal entries of {@code fm}.
     *
     * @param fm matrix to read; it does not have to be square
     * @return a vector of {@code Float} holding {@code fm(k, k)} for every
     *         {@code k} smaller than both dimensions, in increasing {@code k}
     */
    public Vector returnDiagonal(FloatMatrix fm) {
        Vector<Float> diagonal = new Vector<Float>();
        int diagonalLength = Math.min(fm.getRowDimension(), fm.getColumnDimension());
        for (int k = 0; k < diagonalLength; ++k) {
            diagonal.add(Float.valueOf(fm.get(k, k)));
        }
        return diagonal;
    }

    /**
     * Creates a matrix with {@code val} on the main diagonal and zero elsewhere.
     *
     * <p>A {@code val} of 0 is replaced by 1. When exactly one of {@code row}
     * / {@code col} is 0 the other dimension is used for both, giving a square
     * matrix. Previously that square matrix was returned filled with the value
     * in every cell because the fill loops ran over the zero-sized request
     * instead of the matrix that was actually allocated.
     *
     * @param row requested number of rows (0 means "same as {@code col}")
     * @param col requested number of columns (0 means "same as {@code row}")
     * @param val diagonal value; 0 is treated as 1
     * @return the diagonal matrix
     */
    public FloatMatrix createDiagonalMatrix(int row, int col, int val) {
        float diagonalValue = val == 0 ? 1.0f : (float)val;
        int rows = row == 0 ? col : row;
        int cols = col == 0 ? row : col;
        FloatMatrix diagonalMatrix = new FloatMatrix(rows, cols);
        // Every cell is written explicitly so the result does not depend on
        // how the constructor initialises its storage.
        for (int r = 0; r < rows; ++r) {
            for (int c = 0; c < cols; ++c) {
                diagonalMatrix.set(r, c, r == c ? diagonalValue : 0.0f);
            }
        }
        return diagonalMatrix;
    }

    /**
     * Returns a copy of {@code geneExpression} without the columns whose index
     * is listed in {@code unassigned}.
     *
     * <p>The width of the result is derived from the columns that are actually
     * kept rather than from {@code unassigned.size()}, so duplicate or
     * out-of-range entries in {@code unassigned} no longer overflow the target
     * matrix. A {@code null} list is treated as "nothing to remove".
     *
     * @param unassigned     column indices (as {@code Integer}) to drop; may be {@code null}
     * @param geneExpression source matrix; it is not modified
     * @return a new matrix holding the remaining columns in their original order
     */
    public FloatMatrix removeUnassignedSamples(Vector unassigned, FloatMatrix geneExpression) {
        int totalColumns = geneExpression.getColumnDimension();
        int lastRow = geneExpression.getRowDimension() - 1;

        boolean[] dropped = new boolean[totalColumns];
        int keptColumns = totalColumns;
        if (unassigned != null) {
            for (int column = 0; column < totalColumns; ++column) {
                if (unassigned.contains(column)) {
                    dropped[column] = true;
                    --keptColumns;
                }
            }
        }

        FloatMatrix withoutNA = new FloatMatrix(geneExpression.getRowDimension(), keptColumns);
        int target = 0;
        for (int column = 0; column < totalColumns; ++column) {
            if (dropped[column]) {
                continue;
            }
            withoutNA.setMatrix(0, lastRow, target, target, geneExpression.getMatrix(0, lastRow, column, column));
            ++target;
        }
        return withoutNA;
    }

    /**
     * Sums every row of {@code matrix}, skipping NaN cells.
     *
     * @param matrix matrix to reduce
     * @return a vector of {@code Float}, one entry per row in row order; a row
     *         made up only of NaN cells (or with no columns) sums to 0
     */
    public Vector getRowSums(FloatMatrix matrix) {
        int rowCount = matrix.getRowDimension();
        int columnCount = matrix.getColumnDimension();
        Vector<Float> totals = new Vector<Float>();
        for (int r = 0; r < rowCount; ++r) {
            float running = 0.0f;
            for (int c = 0; c < columnCount; ++c) {
                float cell = matrix.get(r, c);
                if (!Float.isNaN(cell)) {
                    running += cell;
                }
            }
            totals.add(Float.valueOf(running));
        }
        return totals;
    }

    /**
     * Computes the NaN-skipping row sums of {@code matrix} (via
     * {@link #getRowSums}) and divides each of them by {@code val}.
     *
     * @param matrix matrix to reduce
     * @param val    divisor applied to every row sum; the division is done in
     *               {@code float}, so a divisor of 0 yields infinities or NaN
     *               rather than an exception
     * @return one scaled sum per row of {@code matrix}
     */
    public float[] getRowSumandSome(FloatMatrix matrix, int val) {
        float[] scaled = new float[matrix.getRowDimension()];
        Vector sums = this.getRowSums(matrix);
        float divisor = (float)val;
        int position = 0;
        for (Object sum : sums) {
            scaled[position++] = ((Float)sum).floatValue() / divisor;
        }
        return scaled;
    }

    /**
     * Multiplies every cell of {@code matrix} by the entry of {@code varr}
     * that belongs to the cell's row.
     *
     * <p>Only one combination is implemented: {@code margin} equal to
     * {@code {2}} together with the function name {@code "apply-default"}
     * (case-insensitive). For any other combination the freshly allocated
     * result matrix is returned without being written to.
     *
     * @param matrix   input matrix
     * @param margin   must be the single-element array {@code {2}} for the scaling to happen
     * @param varr     per-row factors, indexed by row
     * @param function must be {@code "apply-default"} for the scaling to happen
     * @return a new matrix with the same dimensions as {@code matrix}
     */
    public FloatMatrix apply(FloatMatrix matrix, int[] margin, float[] varr, String function) {
        int rowCount = matrix.getRowDimension();
        int columnCount = matrix.getColumnDimension();
        FloatMatrix scaled = new FloatMatrix(rowCount, columnCount);
        boolean supported = margin.length == 1 && margin[0] == 2 && function.equalsIgnoreCase("apply-default");
        if (!supported) {
            return scaled;
        }
        for (int c = 0; c < columnCount; ++c) {
            for (int r = 0; r < rowCount; ++r) {
                scaled.set(r, c, matrix.get(r, c) * varr[r]);
            }
        }
        return scaled;
    }

    /**
     * Divides every row of {@code matrix} by the matching entry of {@code stats}.
     *
     * <p>Only the function name {@code "divide_function"} (case-insensitive)
     * is implemented; for any other name the freshly allocated result matrix
     * is returned without being written to. {@code margin} is accepted for
     * signature compatibility but is not used: the statistics are always
     * applied per row. Earlier versions computed unused values from
     * {@code margin} and could therefore fail on an empty or null array.
     *
     * @param matrix   input matrix
     * @param margin   ignored
     * @param stats    per-row divisors as {@code Float}, indexed by row
     * @param function operation name; only {@code "divide_function"} has an effect
     * @return a new matrix with the same dimensions as {@code matrix}
     */
    public FloatMatrix Sweep(FloatMatrix matrix, int[] margin, Vector stats, String function) {
        int rowCount = matrix.getRowDimension();
        int columnCount = matrix.getColumnDimension();
        FloatMatrix swept = new FloatMatrix(rowCount, columnCount);
        if (columnCount == 0 || !"divide_function".equalsIgnoreCase(function)) {
            return swept;
        }
        for (int row = 0; row < rowCount; ++row) {
            // float division: a zero divisor produces infinities/NaN, not an exception.
            float divisor = ((Float)stats.get(row)).floatValue();
            for (int col = 0; col < columnCount; ++col) {
                swept.set(row, col, matrix.get(row, col) / divisor);
            }
        }
        return swept;
    }

    /**
     * Aggregates per-row statistics into per-set statistics and normalises them.
     *
     * <p>Two combinations are implemented, both requiring {@code GSEAfun} =
     * {@code "cross_prod"} and {@code func1} = {@code "divide_function"}
     * (all names case-insensitive):
     * <ul>
     *   <li>{@code func2} = {@code "sqrt"}: computes {@code incidence * dataSet}
     *       and divides each row by the square root of the matching row sum
     *       of {@code incidence};</li>
     *   <li>{@code func2} = {@code "identity"}: computes
     *       {@code dataSet' * incidence} and divides each row by that row's
     *       own sum.</li>
     * </ul>
     * If the column count of {@code incidence} differs from the row count of
     * {@code dataSet} an error dialog is shown; processing then continues.
     *
     * @param dataSet     statistics matrix
     * @param incidence   set-membership matrix
     * @param GSEAfun     aggregation name
     * @param func1       normalisation operation name
     * @param func2       normalisation statistic name
     * @param removeShift not used
     * @param removeStat  not used
     * @return the normalised matrix, or {@code null} when the combination of
     *         names is not one of the two implemented ones
     */
    public FloatMatrix GSNormalize(FloatMatrix dataSet, FloatMatrix incidence, String GSEAfun, String func1, String func2, boolean removeShift, String removeStat) {
        if (incidence.getColumnDimension() != dataSet.getRowDimension()) {
            JOptionPane.showMessageDialog(null, "<html>GSNormalize: Nonconforming matrices</html>", "Error", JOptionPane.ERROR_MESSAGE);
        }

        boolean crossProdDivide = GSEAfun.equalsIgnoreCase("cross_prod") && func1.equalsIgnoreCase("divide_function");
        if (!crossProdDivide) {
            return null;
        }

        if (func2.equalsIgnoreCase("sqrt")) {
            FloatMatrix aggregated = incidence.times(dataSet);
            Vector<Float> setSizes = new MatrixFunctions().getRowSums(incidence);
            Vector<Float> rootSizes = new Vector<Float>();
            for (int i = 0; i < setSizes.size(); ++i) {
                float size = ((Float)setSizes.get(i)).floatValue();
                rootSizes.add(Float.valueOf((float)Math.sqrt(size)));
            }
            return this.Sweep(aggregated, new int[]{1}, rootSizes, func1);
        }

        if (func2.equalsIgnoreCase("identity")) {
            FloatMatrix aggregated = dataSet.transpose().times(incidence);
            Vector<Float> ownSums = this.getRowSums(aggregated);
            return this.Sweep(aggregated, new int[]{1}, ownSums, func1);
        }

        return null;
    }

    /**
     * Runs the permutation analysis: fits the per-gene linear model on the
     * observed factor assignments, aggregates the resulting per-gene statistics
     * per gene set, repeats the fit {@code num_perms} times on permuted
     * assignments and finally derives p-values from the permuted statistics.
     *
     * <p>Side effects on {@code adata}: stores {@code "unassigned-samples"},
     * {@code "lmPerGene-coefficients"} and {@code "lmPerGene-coefvar"}, and
     * (through {@link #pValFromPermMat}) the p-value mappings.
     *
     * <p>The {@code stop} flag is checked before every permutation; when it is
     * set the method returns immediately without storing p-values. Permutations
     * whose design cannot be fitted ({@code lmPerGene} returns {@code null})
     * are redrawn and do not count toward {@code num_perms}.
     *
     * @param num_perms number of successful permutations to collect
     * @param adata     input data and result sink
     * @param removeNA  forwarded to {@code lmPerGene}: drop unassigned samples before fitting
     * @throws IllegalStateException if the observed data cannot be fitted
     * @throws Exception             any failure raised by the underlying matrix or
     *                               group-assignment code
     */
    private void gsealmPerm(int num_perms, AlgorithmData adata, boolean removeNA) throws Exception {
        String[] factorNames = adata.getStringArray("factor-names");
        int[] factorlevels = adata.getIntArray("factor-levels");
        int[][] factorAssignments = adata.getIntMatrix("factor-assignments");
        FloatMatrix expression = adata.getGeneMatrix("gene-data-matrix");
        int nSamp = expression.getColumnDimension();

        ProcessGroupAssignments pg = new ProcessGroupAssignments(factorNames, factorlevels, factorAssignments, true, nSamp);
        pg.findUnassignedSamples(factorNames, factorlevels, factorAssignments);
        adata.addVector("unassigned-samples", pg.getUnassignedColumns());
        FloatMatrix factor_matrix = pg.generateFactorMatrix(factorNames, factorlevels, factorAssignments);

        // Observed statistics.
        Hashtable<String, FloatMatrix> observedFit = this.lmPerGene(adata, factor_matrix, removeNA);
        if (observedFit == null) {
            throw new IllegalStateException("GSEA: the linear model could not be fitted to the observed factor assignments");
        }
        FloatMatrix coefficients = observedFit.get("lmPerGene-coefficients");
        FloatMatrix coefVar = observedFit.get("lmPerGene-coefvar");
        FloatMatrix observedT = this.perGeneTStatistics(coefficients, coefVar);
        FloatMatrix amat = adata.getGeneMatrix("association-matrix");
        adata.addGeneMatrix("lmPerGene-coefficients", coefficients);
        adata.addGeneMatrix("lmPerGene-coefvar", coefVar);
        FloatMatrix observedStats = this.GSNormalize(observedT.transpose(), amat, "cross_prod", "divide_function", "sqrt", false, null);

        // Permuted statistics, one column per successful permutation.
        FloatMatrix permMat = new FloatMatrix(expression.getRowDimension(), num_perms);
        int lastGeneRow = permMat.getRowDimension() - 1;
        int completed = 0;
        while (completed < num_perms) {
            if (this.stop) {
                return;
            }
            Hashtable resultHash = this.findOriginalClassOrder(factorNames, factorlevels, factorAssignments, nSamp);
            Hashtable permutedFactorHash = this.generatePermutedFactorHash((Hashtable)resultHash.get("original-order"), (int[])resultHash.get("permuted-order"), factorNames, factorlevels, factorAssignments, nSamp);
            int[][] permutedFactorAssignments = (int[][])permutedFactorHash.get("permuted-factor-assignments");
            ProcessGroupAssignments pga = new ProcessGroupAssignments(factorNames, factorlevels, permutedFactorAssignments, true, nSamp);
            // NOTE(review): the unassigned samples are looked up with the original
            // (unpermuted) assignments, as before - confirm this is intended.
            pga.findUnassignedSamples(factorNames, factorlevels, factorAssignments);
            FloatMatrix factor_matrix_new = pga.generateFactorMatrix(factorNames, factorlevels, permutedFactorAssignments);
            Hashtable<String, FloatMatrix> permutedFit = this.lmPerGene(adata, factor_matrix_new, removeNA);
            if (permutedFit == null) {
                // Design not invertible for this draw: try another permutation.
                continue;
            }
            FloatMatrix permutedT = this.perGeneTStatistics(permutedFit.get("lmPerGene-coefficients"), permutedFit.get("lmPerGene-coefvar"));
            permMat.setMatrix(0, lastGeneRow, completed, completed, permutedT.transpose());
            ++completed;
        }

        FloatMatrix perms = this.GSNormalize(permMat, amat, "cross_prod", "divide_function", "sqrt", false, null);
        Vector geneSetNames = adata.getVector("gene-set-names");
        this.pValFromPermMat(observedStats, perms, geneSetNames, adata);
    }

    /**
     * Divides row 1 of {@code coefficients} element-wise by the square root of
     * row 1 of {@code coefVar}.
     *
     * @param coefficients coefficient matrix produced by {@code lmPerGene}
     * @param coefVar      coefficient-variance matrix produced by {@code lmPerGene}
     * @return a one-row matrix with one statistic per column of the inputs
     * @throws IllegalStateException if either input is {@code null}
     */
    private FloatMatrix perGeneTStatistics(FloatMatrix coefficients, FloatMatrix coefVar) {
        if (coefficients == null || coefVar == null) {
            throw new IllegalStateException("GSEA: the linear model fit produced no coefficients");
        }
        FloatMatrix estimate = coefficients.getMatrix(1, 1, 0, coefficients.getColumnDimension() - 1);
        FloatMatrix variance = coefVar.getMatrix(1, 1, 0, coefVar.getColumnDimension() - 1);
        FloatMatrix standardError = new FloatMatrix(variance.getRowDimension(), variance.getColumnDimension());
        for (int col = 0; col < variance.getColumnDimension(); ++col) {
            standardError.set(0, col, (float)Math.sqrt(variance.get(0, col)));
        }
        return estimate.arrayRightDivide(standardError);
    }

    /**
     * Converts permuted statistics into empirical lower- and upper-tail
     * p-values and stores them, sorted ascending, in {@code data}.
     *
     * <p>For every row {@code r} the upper-tail value is the fraction of
     * columns of {@code permMat} with {@code perm >= obs} and the lower-tail
     * value the fraction with {@code perm <= obs}, where {@code obs} is
     * {@code obsStats(r, 0)}. The two tests are independent, so a permuted
     * value that ties with the observed one counts toward both tails
     * (previously ties were counted for the upper tail only). NaN values
     * satisfy neither comparison. If {@code permMat} has no columns the
     * fractions are {@code 0/0}, i.e. NaN.
     *
     * <p>The upper-tail map is stored under {@code "over-enriched"} and the
     * lower-tail map under {@code "under-enriched"}; both are also kept in the
     * corresponding fields of this object.
     *
     * @param obsStats     observed statistics, one row per gene set, read from column 0
     * @param permMat      permuted statistics, one row per gene set and one column per permutation
     * @param geneSetNames name of the gene set for every row; duplicate names overwrite each other
     * @param data         receives the two sorted p-value mappings
     */
    public void pValFromPermMat(FloatMatrix obsStats, FloatMatrix permMat, Vector<String> geneSetNames, AlgorithmData data) {
        int geneSetCount = permMat.getRowDimension();
        int nCols = permMat.getColumnDimension();
        HashMap<String, Float> upperTail = new HashMap<String, Float>();
        HashMap<String, Float> lowerTail = new HashMap<String, Float>();

        for (int row = 0; row < geneSetCount; ++row) {
            float observed = obsStats.get(row, 0);
            int atOrAbove = 0;
            int atOrBelow = 0;
            for (int column = 0; column < nCols; ++column) {
                float permuted = permMat.get(row, column);
                if (permuted >= observed) {
                    ++atOrAbove;
                }
                if (permuted <= observed) {
                    ++atOrBelow;
                }
            }
            String name = geneSetNames.get(row);
            upperTail.put(name, Float.valueOf((float)atOrAbove / (float)nCols));
            lowerTail.put(name, Float.valueOf((float)atOrBelow / (float)nCols));
        }

        this.overEnrichedPVals = this.sortHashMapByValues(upperTail);
        this.underEnrichedPVals = this.sortHashMapByValues(lowerTail);
        data.addMappings("over-enriched", this.overEnrichedPVals);
        data.addMappings("under-enriched", this.underEnrichedPVals);
    }

    /**
     * Returns the entries of {@code passedMap} ordered by ascending value;
     * entries with equal values are ordered by ascending key.
     *
     * <p>The input map is left untouched (earlier versions removed every
     * entry from it while sorting) and the sort runs in O(n log n).
     *
     * @param passedMap map from {@code String} keys to {@code Float} values;
     *                  {@code null} is treated as empty
     * @return a new insertion-ordered map holding the sorted entries
     * @throws ClassCastException if a key is not a {@code String} or a value not a {@code Float}
     */
    public LinkedHashMap sortHashMapByValues(HashMap passedMap) {
        LinkedHashMap<String, Float> sortedMap = new LinkedHashMap<String, Float>();
        if (passedMap == null || passedMap.isEmpty()) {
            return sortedMap;
        }
        // The parameter is a raw type for backward compatibility; the element
        // types are enforced by the casts the comparator performs implicitly.
        @SuppressWarnings("unchecked")
        ArrayList<Map.Entry<String, Float>> entries = new ArrayList<Map.Entry<String, Float>>(passedMap.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Float>>() {
            public int compare(Map.Entry<String, Float> left, Map.Entry<String, Float> right) {
                int byValue = left.getValue().compareTo(right.getValue());
                return byValue != 0 ? byValue : left.getKey().compareTo(right.getKey());
            }
        });
        for (Map.Entry<String, Float> entry : entries) {
            sortedMap.put(entry.getKey(), entry.getValue());
        }
        return sortedMap;
    }

    /**
     * Groups the samples into strata and draws a random permutation of the
     * sample indices within each stratum.
     *
     * <p>With a single factor all samples form one stratum named
     * {@code "one-factor-class"}. With two or more factors a sample's stratum
     * label is built from its levels in factors 1..n-1 (factor 0 is not part
     * of the label): the label starts as {@code "CLASS-"}, every non-zero
     * level is appended, and a level of 0 resets the label to
     * {@code "CLASS-unknown"}. Levels are separated by {@code '-'} so that
     * e.g. levels (1, 12) and (11, 2) no longer collapse into the same label;
     * labels built from a single level are unchanged.
     *
     * @param factorNames       factor names; only the count is used
     * @param factorLevels      not used
     * @param factorAssignments level of every sample for every factor, indexed {@code [factor][sample]}
     * @param num_Samples       number of samples
     * @return a table with {@code "original-order"} (sample index to stratum
     *         label) and {@code "permuted-order"} (an {@code int[]} giving, for
     *         every sample index, the index drawn for it from the same stratum)
     */
    public Hashtable findOriginalClassOrder(String[] factorNames, int[] factorLevels, int[][] factorAssignments, int num_Samples) {
        Hashtable<Integer, String> origOrder = new Hashtable<Integer, String>();
        Hashtable<String, Vector<Integer>> strata = new Hashtable<String, Vector<Integer>>();

        if (factorNames.length == 1) {
            Vector<Integer> everySample = new Vector<Integer>();
            for (int sample = 0; sample < num_Samples; ++sample) {
                everySample.add(sample);
                origOrder.put(sample, "one-factor-class");
            }
            strata.put("one-factor-class", everySample);
        } else if (factorNames.length >= 2) {
            for (int sample = 0; sample < num_Samples; ++sample) {
                StringBuilder label = new StringBuilder("CLASS-");
                boolean needsSeparator = false;
                for (int factor = 1; factor < factorNames.length; ++factor) {
                    int level = factorAssignments[factor][sample];
                    if (level == 0) {
                        // An unassigned factor discards what was collected so far.
                        label.setLength(0);
                        label.append("CLASS-unknown");
                    } else {
                        if (needsSeparator) {
                            label.append('-');
                        }
                        label.append(level);
                    }
                    needsSeparator = true;
                }
                String stratum = label.toString().trim();
                Vector<Integer> members = strata.get(stratum);
                if (members == null) {
                    members = new Vector<Integer>();
                    strata.put(stratum, members);
                }
                members.add(sample);
                origOrder.put(sample, stratum);
            }
        }

        int[] permutedSampleAssignment = new int[num_Samples];
        for (Vector<Integer> members : strata.values()) {
            ArrayList shuffled = this.getPermutedValues(members);
            for (int k = 0; k < members.size(); ++k) {
                permutedSampleAssignment[members.get(k).intValue()] = ((Integer)shuffled.get(k)).intValue();
            }
        }

        Hashtable<String, Object> resultHash = new Hashtable<String, Object>();
        resultHash.put("original-order", origOrder);
        resultHash.put("permuted-order", permutedSampleAssignment);
        return resultHash;
    }

    /**
     * Returns a uniformly random permutation of {@code values}.
     *
     * <p>The shuffle is delegated to {@link Collections#shuffle(java.util.List, Random)}
     * using this object's random generator. The previous hand-written loop
     * drew the swap partner with {@code nextInt(i - 1)}, which never allows an
     * element to stay in place, so only a subset of all permutations could be
     * produced (for two elements: always the swap).
     *
     * @param values {@code Integer} values to permute; the vector itself is not modified
     * @return a new list holding the same values in random order
     * @throws ClassCastException if an element is not an {@code Integer}
     */
    public ArrayList getPermutedValues(Vector values) {
        ArrayList<Integer> permuted = new ArrayList<Integer>(values.size());
        for (Object value : values) {
            permuted.add((Integer)value);
        }
        Collections.shuffle(permuted, this.aRandom);
        return permuted;
    }

    /**
     * Builds factor assignments in which the first factor has been permuted.
     *
     * <p>Row 0 of the result is a new array with
     * {@code result[0][i] = factorAssignments[0][permOrder[i]]}. Rows 1 and up
     * are the very same array objects as in {@code factorAssignments} (they
     * are shared, not copied). With no factors the result has no rows.
     *
     * @param origOrder         not used
     * @param permOrder         for every sample index, the index to take the first factor's level from
     * @param factorNames       factor names; only the count is used
     * @param factorLevels      not used
     * @param factorAssignments level of every sample for every factor, indexed {@code [factor][sample]}
     * @param num_samples       number of samples
     * @return a table with the single key {@code "permuted-factor-assignments"}
     *         mapping to the resulting {@code int[][]}
     */
    public Hashtable generatePermutedFactorHash(Hashtable origOrder, int[] permOrder, String[] factorNames, int[] factorLevels, int[][] factorAssignments, int num_samples) {
        int factorCount = factorNames.length;
        int[][] permuted = new int[factorCount][num_samples];
        if (factorCount > 0) {
            int[] firstFactor = factorAssignments[0];
            // Copies the row references for the remaining factors.
            System.arraycopy(factorAssignments, 1, permuted, 1, factorCount - 1);
            int[] shuffledFirstFactor = permuted[0];
            for (int sample = 0; sample < num_samples; ++sample) {
                shuffledFirstFactor[sample] = firstFactor[permOrder[sample]];
            }
        }
        Hashtable<String, int[][]> wrapper = new Hashtable<String, int[][]>();
        wrapper.put("permuted-factor-assignments", permuted);
        return wrapper;
    }

    /**
     * Small manual demo: builds a 4x3 matrix holding 1..12 in column-major
     * order and a 2x4 0/1 gene-set matrix, prints both, runs
     * {@link #GSNormalize} on them with the {@code "sqrt"} normalisation and
     * prints the result.
     *
     * @param args not used
     */
    public static void main(String[] args) {
        float[][] membership = new float[][]{{1.0f, 1.0f, 1.0f, 0.0f}, {1.0f, 0.0f, 1.0f, 0.0f}};
        FloatMatrix m = new FloatMatrix(4, 3);
        FloatMatrix geneset = new FloatMatrix(2, 4);
        for (int i = 0; i < membership.length; ++i) {
            for (int j = 0; j < membership[i].length; ++j) {
                geneset.set(i, j, membership[i][j]);
            }
        }
        for (int col = 0; col < 3; ++col) {
            for (int row = 0; row < 4; ++row) {
                // 1, 2, 3, ... walking down each column in turn
                m.set(row, col, (float)(col * 4 + row + 1));
            }
        }

        System.out.println("matrix :");
        GSEA.printTabSeparated(m);
        System.out.println("matrix Ends:");

        System.out.println("printing geneset");
        GSEA.printTabSeparated(geneset);
        System.out.println("printing gene set ends");

        FloatMatrix _gSNormRes = new GSEA().GSNormalize(m, geneset, "cross_prod", "divide_function", "sqrt", false, null);
        System.out.println("printing GSNormalize results");
        GSEA.printTabSeparated(_gSNormRes);
    }

    /** Prints {@code matrix} to standard output, one row per line, every cell followed by a tab. */
    private static void printTabSeparated(FloatMatrix matrix) {
        for (int r = 0; r < matrix.getRowDimension(); ++r) {
            for (int c = 0; c < matrix.getColumnDimension(); ++c) {
                System.out.print(matrix.get(r, c));
                System.out.print('\t');
            }
            System.out.println();
        }
    }
}

