View Javadoc
1   //******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2024.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  //******************************************************************************
38  package ffx.potential.utils;
39  
40  import static java.lang.String.format;
41  import static org.apache.commons.math3.util.FastMath.sqrt;
42  import static org.apache.commons.math3.util.FastMath.floor;
43  import static org.apache.commons.math3.util.FastMath.random;
44  
45  import com.apporiented.algorithm.clustering.Cluster;
46  import com.apporiented.algorithm.clustering.ClusteringAlgorithm;
47  import com.apporiented.algorithm.clustering.DefaultClusteringAlgorithm;
48  import com.apporiented.algorithm.clustering.SingleLinkageStrategy;
49  import java.util.ArrayList;
50  import java.util.List;
51  import java.util.logging.Level;
52  import java.util.logging.Logger;
53  import org.apache.commons.math3.ml.clustering.CentroidCluster;
54  import org.apache.commons.math3.ml.clustering.Clusterable;
55  import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
56  import org.apache.commons.math3.ml.clustering.MultiKMeansPlusPlusClusterer;
57  import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;
58  import org.apache.commons.math3.ml.distance.EuclideanDistance;
59  import org.apache.commons.math3.random.RandomGenerator;
60  
61  /**
62   * Cluster contains methods utilized in the <code>Cluster.groovy</code> file.
63   *
64   * @author Aaron J. Nessler
65   * @author Michael J. Schnieders
66   */
67  public class Clustering {
68  
69    private static final Logger log = Logger.getLogger(PotentialsUtils.class.getName());
70  
71    /**
72     * Number of iterations for k-means clustering
73     */
74    private static final int NUM_ITERATIONS = 10000;
75  
76    /**
77     * Perform a k-means clustering for a specified number of clusters.
78     *
79     * @param distMatrix Coordinate input serves as the data points.
80     * @param maxClusters Number of clusters to use (k).
81     * @param numTrials Number of trials for the Multi K-Means++ algorithm.
82     * @param seed The seed to use for clustering (-1 uses the current system time).
83     * @return The clusters.
84     */
85    public static List<CentroidCluster<Conformation>> kMeansClustering(List<double[]> distMatrix,
86        int maxClusters, int numTrials, long seed) {
87      // Square distance matrix size (dim x dim).
88      int dim = distMatrix.size();
89      List<Conformation> conformationList = new ArrayList<>();
90      for (int i = 0; i < dim; i++) {
91        double[] row = distMatrix.get(i);
92        // Check that the input data is appropriate.
93        if (row.length != dim) {
94          log.severe(format(" Row %d of the distance matrix (%d x %d) has %d columns.", i, dim, dim,
95              row.length));
96        }
97        conformationList.add(new Conformation(row, i));
98      }
99  
100     // Input the RMSD matrix to the clustering algorithm
101     // Use the org.apache.commons.math3.ml.clustering package.
102     KMeansPlusPlusClusterer<Conformation> kMeansPlusPlusClusterer = new KMeansPlusPlusClusterer<>(
103         maxClusters, NUM_ITERATIONS);
104     // Set the random seed for deterministic clustering.
105     RandomGenerator randomGenerator = kMeansPlusPlusClusterer.getRandomGenerator();
106     randomGenerator.setSeed(seed);
107     // Create a MultiKMeansPlusPlusClusterer
108     MultiKMeansPlusPlusClusterer<Conformation> multiKMeansPlusPlusClusterer = new MultiKMeansPlusPlusClusterer<>(
109         kMeansPlusPlusClusterer, numTrials);
110 
111     // Perform the clustering.
112     return multiKMeansPlusPlusClusterer.cluster(conformationList);
113   }
114 
115   /**
116    * This method performs hierarchical clustering on a distance matrix. If the system isn't headless,
117    * a dendrogram is printed of the clustered results. A PDB file for the centroid of each cluster is
118    * saved.
119    *
120    * @param distanceMatrix A List of double[] entries that holds the distance matrix.
121    * @param threshold the distance used to separate clusters.
122    * @return Return a list of CentroidClusters.
123    */
124   public static List<CentroidCluster<Conformation>> hierarchicalClustering(
125       List<double[]> distanceMatrix, double threshold) {
126 
127     // Convert the distanceMatrix to a double[][] for the clustering algorithm.
128     int dim = distanceMatrix.size();
129 
130     double[][] distMatrixArray = new double[dim][];
131     String[] names = new String[dim];
132     for (int i = 0; i < dim; i++) {
133       distMatrixArray[i] = distanceMatrix.get(i);
134       names[i] = Integer.toString(i);
135     }
136 
137     // Cluster the data.
138     ClusteringAlgorithm clusteringAlgorithm = new DefaultClusteringAlgorithm();
139     List<Cluster> clusterList = clusteringAlgorithm.performFlatClustering(distMatrixArray, names,
140         new SingleLinkageStrategy(), threshold);
141 
142     // Convert the cluster format to CentroidClusters
143     return clustersToCentroidClusters(distMatrixArray, clusterList);
144   }
145 
146   /**
147    * Perform an iterative clustering for a specified number of clusters. Designed and tested by
148    * researchers at Takeda (see method authors).
149    *
150    * @param distMatrix Coordinate input serves as the data points.
151    * @param trials Number of iterations to perform clustering.
152    * @param tolerance RMSD cutoff to divide same values from different.
153    * @return The clusters.
154    * Created by:
155    * @author Yuya, Kinoshita
156    * @author Koki, Nishimura
157    * @author Masatoshi, Karashima
158    * Implemented by:
159    * @author Aaron J. Nessler
160    */
161   public static List<CentroidCluster<Conformation>> iterativeClustering(List<double[]> distMatrix,
162       int trials, double tolerance) {
163     // Square distance matrix size (dim x dim).
164     int dim = distMatrix.size();
165     ArrayList<CentroidCluster<Conformation>> bestClusters = new ArrayList<>();
166     for (int i = 0; i < trials; i++) {
167 
168       ArrayList<Integer> remaining = new ArrayList<>();
169       for (int j = 0; j < dim; j++) {
170         remaining.add(j);
171       }
172       List<CentroidCluster<Conformation>> clusters = new ArrayList<>();
173       while (!remaining.isEmpty()) {
174         int seed = (int) floor(random() * (remaining.size() - 1));
175         int index = remaining.get(seed);
176         CentroidCluster<Conformation> cluster = new CentroidCluster<>(
177             new Conformation(distMatrix.get(index), index));
178         double[] row = distMatrix.get(index);
179         // Check that the input data is complete.
180         if (log.isLoggable(Level.FINER)) {
181           log.finer(format(" Remaining clusters: %3d of %3d", remaining.size(), dim));
182           log.finer(format("  Row: %3d (seed %3d) has %3d entries.", index + 1, seed, row.length));
183         }
184         if (row.length != dim) {
185           log.severe(
186               format(" Row %d of the distance matrix (%d x %d) has %d columns.", index, dim, dim,
187                   row.length));
188         }
189         for (int j = 0; j < dim; j++) {
190           if (row[j] < tolerance && remaining.contains(j)) {
191             cluster.addPoint(new Conformation(distMatrix.get(j), j));
192             if (!remaining.remove((Integer) j)) {
193               log.warning(
194                   format(" Row %3d matched %3d, but could not be removed.", j + 1, index + 1));
195             }
196           }
197         }
198         clusters.add(cluster);
199       }
200       if (bestClusters.isEmpty() || clusters.size() < bestClusters.size()) {
201         if (log.isLoggable(Level.FINE)) {
202           int numStructs = 0;
203           for (CentroidCluster<Conformation> cluster : clusters) {
204             numStructs += cluster.getPoints().size();
205           }
206           log.fine(
207               format(" New Best: Num clusters: %3d Num Structs: %3d ", clusters.size(), numStructs));
208         }
209         bestClusters = new ArrayList<>(clusters);
210       }
211     }
212     return bestClusters;
213   }
214 
215   /**
216    * Analyze a list of CentroidClusters.
217    *
218    * @param clusters The List of CentroidClusters to analyze.
219    * @param repStructs Store a representative conformation for each cluster.
220    * @param verbose If true, use verbose printing.
221    */
222   public static void analyzeClusters(List<CentroidCluster<Conformation>> clusters,
223       List<Integer> repStructs, boolean verbose) {
224     // Number of clusters.
225     int nClusters = clusters.size();
226     double meanClusterRMSD = 0.0;
227 
228     // Loop over clusters
229     for (int i = 0; i < nClusters; i++) {
230       CentroidCluster<Conformation> clusterI = clusters.get(i);
231 
232       List<Conformation> conformations = clusterI.getPoints();
233       int nConformers = conformations.size();
234       StringBuilder sb;
235       if (nConformers == 1) {
236         sb = new StringBuilder(format(" Cluster %d with conformation:", i + 1));
237       } else {
238         sb = new StringBuilder(
239             format(" Cluster %d with %d conformations\n  Conformations:", i + 1, nConformers));
240       }
241 
242       double minRMS = Double.MAX_VALUE;
243       int minID = -1;
244       for (Conformation conformation : conformations) {
245         int row = conformation.index;
246         sb.append(format(" %d", row + 1));
247         if (nConformers > 1) {
248           // Get a row of the RMSD matrix.
249           double[] rmsd = conformation.getPoint();
250           // Calculate the sum of squares.
251           double wrms = 0;
252           for (Conformation conformation2 : conformations) {
253             int col = conformation2.index;
254             if (col == row) {
255               continue;
256             }
257             wrms += rmsd[col] * rmsd[col];
258           }
259           // Calculate the root mean sum of squares distance within the cluster.
260           wrms = sqrt(wrms / (nConformers - 1));
261           if (wrms < minRMS) {
262             minRMS = wrms;
263             minID = conformation.getIndex();
264           }
265         } else {
266           // Only 1 conformer in this cluster.
267           minID = row;
268           minRMS = 0.0;
269         }
270       }
271 
272       // Calculate the RMSD within the cluster.
273       if (nConformers > 1) {
274         double clusterRMSD = clusterRMSD(conformations);
275         meanClusterRMSD += clusterRMSD;
276         sb.append(format("\n  RMSD within the cluster:\t %6.4f A.\n", clusterRMSD));
277 
278         // minID contains the index for the representative conformer.
279         sb.append(format("  Minimum RMSD conformer %d:\t %6.4f A.\n", minID + 1, minRMS));
280       } else {
281         sb.append("\n");
282       }
283 
284       repStructs.add(minID);
285 
286       if (verbose) {
287         log.info(sb.toString());
288       }
289     }
290 
291     if (verbose) {
292       log.info(format(" Mean RMSD within clusters: \t %6.4f A.", meanClusterRMSD / nClusters));
293       double sumOfClusterVariances = sumOfClusterVariances(clusters);
294       log.info(
295           format(" Mean cluster variance:     \t %6.4f A.\n", sumOfClusterVariances / nClusters));
296     }
297 
298   }
299 
300   /**
301    * Compute the RMSD of one cluster.
302    *
303    * @param conformations Conformers for this cluster.
304    * @return The RMSD for the cluster.
305    */
306   private static double clusterRMSD(List<Conformation> conformations) {
307     int nConformers = conformations.size();
308 
309     // If there is only 1 conformer, the RMSD is 0.
310     if (nConformers == 1) {
311       return 0.0;
312     }
313 
314     // Calculate the RMSD within the cluster.
315     double sum = 0.0;
316     int count = 0;
317     for (int j = 0; j < nConformers; j++) {
318       Conformation conformation = conformations.get(j);
319       double[] rmsd = conformation.rmsd;
320       for (int k = j + 1; k < nConformers; k++) {
321         Conformation conformation2 = conformations.get(k);
322         int col = conformation2.index;
323         sum += rmsd[col] * rmsd[col];
324         count++;
325       }
326     }
327 
328     return sqrt(sum / count);
329   }
330 
331   /**
332    * Compute the Sum of Cluster Variances.
333    *
334    * @param clusters The cluster to operate on.
335    * @return The sum of cluster variances.
336    */
337   private static double sumOfClusterVariances(List<CentroidCluster<Conformation>> clusters) {
338     SumOfClusterVariances<Conformation> sumOfClusterVariances = new SumOfClusterVariances<>(
339         new EuclideanDistance());
340     return sumOfClusterVariances.score(clusters);
341   }
342 
343   /**
344    * Convert clusters defined by a List of Strings to Apache Math style CentroidClusters.
345    *
346    * @param distMatrixArray Distance matrix.
347    * @param clusterList Input List of Clusters.
348    * @return Return a List of CentroidClusters.
349    */
350   private static List<CentroidCluster<Conformation>> clustersToCentroidClusters(
351       double[][] distMatrixArray, List<Cluster> clusterList) {
352 
353     List<CentroidCluster<Conformation>> centroidClusters = new ArrayList<>();
354     int dim = distMatrixArray.length;
355 
356     // Loop over clusters defined by lists of Strings.
357     for (Cluster cluster : clusterList) {
358 
359       List<String> names = new ArrayList<>();
360       collectNames(cluster, names);
361 
362       // Collect conformations for this cluster.
363       List<Conformation> conformations = new ArrayList<>();
364       for (String name : names) {
365         int index = Integer.parseInt(name);
366         conformations.add(new Conformation(distMatrixArray[index], index));
367       }
368 
369       // Compute its centroid.
370       Conformation centroid = centroidOf(conformations, dim);
371 
372       // Create a new CentroidCluster and add the conformations.
373       CentroidCluster<Conformation> centroidCluster = new CentroidCluster<>(centroid);
374       centroidClusters.add(centroidCluster);
375       for (Conformation conformation : conformations) {
376         centroidCluster.addPoint(conformation);
377       }
378     }
379     return centroidClusters;
380   }
381 
382   /**
383    * Collect the names of each leaf in Cluster.
384    *
385    * @param cluster The cluster to operate on.
386    * @param names A List of leaf names.
387    */
388   private static void collectNames(Cluster cluster, List<String> names) {
389     if (cluster.isLeaf()) {
390       names.add(cluster.getName());
391     } else {
392       for (Cluster c : cluster.getChildren()) {
393         collectNames(c, names);
394       }
395     }
396   }
397 
398 
399   /**
400    * Computes the centroid for a set of points.
401    *
402    * @param points the set of points
403    * @param dimension the point dimension
404    * @return the computed centroid for the set of points
405    */
406   private static Conformation centroidOf(final List<Conformation> points, final int dimension) {
407     final double[] centroid = new double[dimension];
408     for (final Conformation p : points) {
409       final double[] point = p.getPoint();
410       for (int i = 0; i < centroid.length; i++) {
411         centroid[i] += point[i];
412       }
413     }
414     for (int i = 0; i < centroid.length; i++) {
415       centroid[i] /= points.size();
416     }
417     return new Conformation(centroid, 0);
418   }
419 
420   /**
421    * Class for cluster objects.
422    */
423   public static class Conformation implements Clusterable {
424 
425     private final double[] rmsd;
426     private final int index;
427 
428     Conformation(double[] rmsd, int index) {
429       this.rmsd = rmsd;
430       this.index = index;
431     }
432 
433     public double[] getPoint() {
434       return rmsd;
435     }
436 
437     int getIndex() {
438       return index;
439     }
440   }
441 }