/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.utils;

import ffx.numerics.clustering.Cluster;
import ffx.numerics.clustering.DefaultClusteringAlgorithm;
import ffx.numerics.clustering.LinkageStrategy;
import ffx.numerics.clustering.SingleLinkageStrategy;
import ffx.potential.utils.PotentialsUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.clustering.MultiKMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;

/**
 * Static helpers that cluster structures given a pairwise distance matrix.
 *
 * <p>The matrix is passed as a list of rows. Row {@code i} is wrapped in a {@link Conformation}
 * with index {@code i}, and that row is used directly as the conformation's point.
 */
public class Clustering {
    // NOTE(review): the logger is registered under PotentialsUtils rather than Clustering. It is
    // kept as-is so existing logging configuration still applies; confirm whether this is intended.
    private static final Logger log = Logger.getLogger(PotentialsUtils.class.getName());

    /** Maximum number of iterations allowed for a single k-means++ run. */
    private static final int NUM_ITERATIONS = 10000;

    /**
     * Clusters structures with k-means++, using each distance-matrix row as a feature vector.
     *
     * <p>{@code numTrials} independent k-means++ runs are performed by a
     * {@link MultiKMeansPlusPlusClusterer}, which returns the clustering of its best-scoring run.
     *
     * @param distMatrix  square distance matrix, one row per structure. Rows of the wrong length
     *                    are reported at SEVERE level.
     * @param maxClusters number of clusters (k) requested from k-means++.
     * @param numTrials   number of k-means++ runs to perform.
     * @param seed        seed applied to the k-means++ random generator.
     * @return the clusters produced by the best run.
     */
    public static List<CentroidCluster<Conformation>> kMeansClustering(
            List<double[]> distMatrix, int maxClusters, int numTrials, long seed) {
        int dim = distMatrix.size();
        List<Conformation> conformationList = new ArrayList<>(dim);
        for (int i = 0; i < dim; ++i) {
            double[] row = distMatrix.get(i);
            checkRowLength(row, i, dim);
            conformationList.add(new Conformation(row, i));
        }
        KMeansPlusPlusClusterer<Conformation> kMeans =
                new KMeansPlusPlusClusterer<>(maxClusters, NUM_ITERATIONS);
        RandomGenerator randomGenerator = kMeans.getRandomGenerator();
        randomGenerator.setSeed(seed);
        MultiKMeansPlusPlusClusterer<Conformation> multiKMeans =
                new MultiKMeansPlusPlusClusterer<>(kMeans, numTrials);
        return multiKMeans.cluster(conformationList);
    }

    /**
     * Clusters structures by single-linkage hierarchical clustering, cut at a distance threshold.
     *
     * @param distanceMatrix square distance matrix, one row per structure. Rows of the wrong
     *                       length are reported at SEVERE level.
     * @param threshold      threshold passed to the flat-clustering cut.
     * @return one centroid cluster per flat cluster. Each centroid is the mean of its members' rows.
     */
    public static List<CentroidCluster<Conformation>> hierarchicalClustering(
            List<double[]> distanceMatrix, double threshold) {
        int dim = distanceMatrix.size();
        double[][] distMatrixArray = new double[dim][];
        String[] names = new String[dim];
        for (int i = 0; i < dim; ++i) {
            distMatrixArray[i] = distanceMatrix.get(i);
            checkRowLength(distMatrixArray[i], i, dim);
            // Leaf names are the row indices so members can be mapped back to rows afterwards.
            names[i] = Integer.toString(i);
        }
        DefaultClusteringAlgorithm clusteringAlgorithm = new DefaultClusteringAlgorithm();
        LinkageStrategy linkage = new SingleLinkageStrategy();
        List<Cluster> clusterList = clusteringAlgorithm.performFlatClustering(
                distMatrixArray, names, linkage, Double.valueOf(threshold));
        return clustersToCentroidClusters(distMatrixArray, clusterList);
    }

    /**
     * Randomized greedy clustering, repeated {@code trials} times.
     *
     * <p>Each trial repeatedly picks a random unassigned structure as a seed. It groups the seed
     * with every unassigned structure whose distance from it is below {@code tolerance}. The seed
     * itself is always placed in its own cluster, so every trial terminates. The partition with
     * the fewest clusters is kept; on a tie, the earliest such trial wins.
     *
     * <p>Seeds are drawn from {@link FastMath#random()}, so results are not reproducible.
     *
     * @param distMatrix square distance matrix, one row per structure.
     * @param trials     number of random partitions to try. If it is not positive, an empty list
     *                   is returned.
     * @param tolerance  strict upper bound on the distance from the seed for joining its cluster.
     * @return the best partition found. Each cluster's centroid is its seed conformation.
     */
    public static List<CentroidCluster<Conformation>> iterativeClustering(
            List<double[]> distMatrix, int trials, double tolerance) {
        List<CentroidCluster<Conformation>> bestClusters = new ArrayList<>();
        for (int trial = 0; trial < trials; ++trial) {
            List<CentroidCluster<Conformation>> clusters = randomGreedyPartition(distMatrix, tolerance);
            // Only a strictly smaller number of clusters replaces the current best.
            if (!bestClusters.isEmpty() && clusters.size() >= bestClusters.size()) {
                continue;
            }
            if (log.isLoggable(Level.FINE)) {
                int numStructs = 0;
                for (CentroidCluster<Conformation> cluster : clusters) {
                    numStructs += cluster.getPoints().size();
                }
                log.fine(String.format(" New Best: Num clusters: %3d Num Structs: %3d ", clusters.size(), numStructs));
            }
            bestClusters = clusters;
        }
        return bestClusters;
    }

    /** Builds one randomized greedy partition of all structures (a single trial of iterativeClustering). */
    private static List<CentroidCluster<Conformation>> randomGreedyPartition(
            List<double[]> distMatrix, double tolerance) {
        int dim = distMatrix.size();
        boolean[] assigned = new boolean[dim];
        List<Integer> remaining = new ArrayList<>(dim);
        for (int j = 0; j < dim; ++j) {
            remaining.add(j);
        }
        List<CentroidCluster<Conformation>> clusters = new ArrayList<>();
        while (!remaining.isEmpty()) {
            // Uniform over [0, remaining.size()). Scaling by (size - 1) would never pick the last entry.
            int seed = (int) (FastMath.random() * remaining.size());
            int index = remaining.get(seed);
            double[] row = distMatrix.get(index);
            if (log.isLoggable(Level.FINER)) {
                log.finer(String.format(" Remaining clusters: %3d of %3d", remaining.size(), dim));
                log.finer(String.format("  Row: %3d (seed %3d) has %3d entries.", index + 1, seed, row.length));
            }
            checkRowLength(row, index, dim);
            CentroidCluster<Conformation> cluster = new CentroidCluster<>(new Conformation(row, index));
            for (int j = 0; j < dim; ++j) {
                if (!assigned[j] && row[j] < tolerance) {
                    cluster.addPoint(new Conformation(distMatrix.get(j), j));
                    assigned[j] = true;
                }
            }
            // Always consume the seed. Otherwise a non-positive tolerance (or a NaN self-distance)
            // would leave it unassigned and the loop would never end.
            if (!assigned[index]) {
                cluster.addPoint(new Conformation(row, index));
                assigned[index] = true;
            }
            remaining.removeIf(k -> assigned[k]);
            clusters.add(cluster);
        }
        return clusters;
    }

    /**
     * Reports cluster membership and statistics, and collects one representative per cluster.
     *
     * <p>For each cluster, the representative is the member whose root-mean-square distance to
     * the other members is smallest; on ties the first such member is kept. A single-member
     * cluster is represented by its only member. An empty cluster contributes -1. The
     * "mean RMSD within clusters" summary averages over all clusters, and single-member clusters
     * contribute 0.
     *
     * @param clusters   clusters whose points carry distance-matrix rows.
     * @param repStructs output list. One representative row index is appended per cluster, in order.
     * @param verbose    when true, per-cluster details and summary statistics are logged at INFO.
     */
    public static void analyzeClusters(
            List<CentroidCluster<Conformation>> clusters, List<Integer> repStructs, boolean verbose) {
        int nClusters = clusters.size();
        double meanClusterRMSD = 0.0;
        for (int i = 0; i < nClusters; ++i) {
            List<Conformation> conformations = clusters.get(i).getPoints();
            int nConformers = conformations.size();
            StringBuilder sb = nConformers == 1
                    ? new StringBuilder(String.format(" Cluster %d with conformation:", i + 1))
                    : new StringBuilder(String.format(" Cluster %d with %d conformations\n  Conformations:", i + 1, nConformers));
            double minRMS = Double.MAX_VALUE;
            int minID = -1;
            for (Conformation conformation : conformations) {
                int row = conformation.index;
                sb.append(String.format(" %d", row + 1));
                if (nConformers == 1) {
                    minID = row;
                    minRMS = 0.0;
                    continue;
                }
                // RMS distance from this conformation to every other member of the cluster.
                double[] rmsd = conformation.getPoint();
                double sumSq = 0.0;
                for (Conformation other : conformations) {
                    int col = other.index;
                    if (col != row) {
                        sumSq += rmsd[col] * rmsd[col];
                    }
                }
                double wrms = FastMath.sqrt(sumSq / (double) (nConformers - 1));
                if (wrms < minRMS) {
                    minRMS = wrms;
                    minID = row;
                }
            }
            if (nConformers > 1) {
                double clusterRMSD = clusterRMSD(conformations);
                meanClusterRMSD += clusterRMSD;
                sb.append(String.format("\n  RMSD within the cluster:\t %6.4f A.\n", clusterRMSD));
                sb.append(String.format("  Minimum RMSD conformer %d:\t %6.4f A.\n", minID + 1, minRMS));
            } else {
                sb.append("\n");
            }
            repStructs.add(minID);
            if (verbose) {
                log.info(sb.toString());
            }
        }
        // Skip the summary for an empty cluster list to avoid dividing by zero (NaN output).
        if (verbose && nClusters > 0) {
            log.info(String.format(" Mean RMSD within clusters: \t %6.4f A.", meanClusterRMSD / (double) nClusters));
            double sumOfClusterVariances = sumOfClusterVariances(clusters);
            log.info(String.format(" Mean cluster variance:     \t %6.4f A.\n", sumOfClusterVariances / (double) nClusters));
        }
    }

    /**
     * Root-mean-square of all distinct pairwise distances within a cluster.
     *
     * @param conformations cluster members. Each row is indexed by the other members' indices.
     * @return sqrt(mean of squared pairwise distances), or 0 for a single member.
     */
    private static double clusterRMSD(List<Conformation> conformations) {
        int nConformers = conformations.size();
        if (nConformers == 1) {
            return 0.0;
        }
        double sum = 0.0;
        int count = 0;
        for (int j = 0; j < nConformers; ++j) {
            double[] rmsd = conformations.get(j).rmsd;
            // Visit each unordered pair (j, k) exactly once.
            for (int k = j + 1; k < nConformers; ++k) {
                int col = conformations.get(k).index;
                sum += rmsd[col] * rmsd[col];
                ++count;
            }
        }
        return FastMath.sqrt(sum / (double) count);
    }

    /** Sum of cluster variances, using Euclidean distance between conformation points. */
    private static double sumOfClusterVariances(List<CentroidCluster<Conformation>> clusters) {
        SumOfClusterVariances<Conformation> evaluator =
                new SumOfClusterVariances<>(new EuclideanDistance());
        return evaluator.score(clusters);
    }

    /**
     * Converts flat hierarchical clusters into centroid clusters.
     *
     * @param distMatrixArray distance matrix rows, indexed by the leaf names parsed as integers.
     * @param clusterList     flat clusters whose leaves are named by row index.
     * @return one centroid cluster per input cluster. The centroid is the mean of the member rows.
     */
    private static List<CentroidCluster<Conformation>> clustersToCentroidClusters(
            double[][] distMatrixArray, List<Cluster> clusterList) {
        List<CentroidCluster<Conformation>> centroidClusters = new ArrayList<>();
        int dim = distMatrixArray.length;
        for (Cluster cluster : clusterList) {
            List<String> names = new ArrayList<>();
            collectNames(cluster, names);
            List<Conformation> conformations = new ArrayList<>(names.size());
            for (String name : names) {
                int index = Integer.parseInt(name);
                conformations.add(new Conformation(distMatrixArray[index], index));
            }
            CentroidCluster<Conformation> centroidCluster =
                    new CentroidCluster<>(centroidOf(conformations, dim));
            for (Conformation conformation : conformations) {
                centroidCluster.addPoint(conformation);
            }
            centroidClusters.add(centroidCluster);
        }
        return centroidClusters;
    }

    /** Recursively appends the names of all leaves under {@code cluster} to {@code names}. */
    private static void collectNames(Cluster cluster, List<String> names) {
        if (cluster.isLeaf()) {
            names.add(cluster.getName());
        } else {
            for (Cluster c : cluster.getChildren()) {
                collectNames(c, names);
            }
        }
    }

    /**
     * Component-wise mean of the points' rows.
     *
     * @param points    conformations to average. An empty list yields NaN components.
     * @param dimension number of components to average.
     * @return a conformation holding the mean row, with index 0.
     */
    private static Conformation centroidOf(List<Conformation> points, int dimension) {
        double[] centroid = new double[dimension];
        for (Conformation p : points) {
            double[] point = p.getPoint();
            for (int i = 0; i < dimension; ++i) {
                centroid[i] += point[i];
            }
        }
        double n = points.size();
        for (int i = 0; i < dimension; ++i) {
            centroid[i] /= n;
        }
        return new Conformation(centroid, 0);
    }

    /** Logs a SEVERE message when {@code row} does not have {@code dim} columns. */
    private static void checkRowLength(double[] row, int rowIndex, int dim) {
        if (row.length != dim) {
            log.severe(String.format(" Row %d of the distance matrix (%d x %d) has %d columns.", rowIndex, dim, dim, row.length));
        }
    }

    /**
     * A clusterable structure: its distance-matrix row (used as the point) and its row index.
     * The row array is stored and returned without copying.
     */
    public static class Conformation
    implements Clusterable {
        /** Distances from this structure to every structure; also the clustering point. */
        private final double[] rmsd;
        /** Row index of this structure in the distance matrix. */
        private final int index;

        Conformation(double[] rmsd, int index) {
            this.rmsd = rmsd;
            this.index = index;
        }

        /** Returns the distance-matrix row backing this conformation (not a copy). */
        @Override
        public double[] getPoint() {
            return this.rmsd;
        }

        /** Returns the row index of this structure in the distance matrix. */
        int getIndex() {
            return this.index;
        }
    }
}

