1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.utils;
39
40 import static java.lang.String.format;
41 import static org.apache.commons.math3.util.FastMath.sqrt;
42 import static org.apache.commons.math3.util.FastMath.floor;
43 import static org.apache.commons.math3.util.FastMath.random;
44
45 import com.apporiented.algorithm.clustering.Cluster;
46 import com.apporiented.algorithm.clustering.ClusteringAlgorithm;
47 import com.apporiented.algorithm.clustering.DefaultClusteringAlgorithm;
48 import com.apporiented.algorithm.clustering.SingleLinkageStrategy;
49 import java.util.ArrayList;
50 import java.util.List;
51 import java.util.logging.Level;
52 import java.util.logging.Logger;
53 import org.apache.commons.math3.ml.clustering.CentroidCluster;
54 import org.apache.commons.math3.ml.clustering.Clusterable;
55 import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
56 import org.apache.commons.math3.ml.clustering.MultiKMeansPlusPlusClusterer;
57 import org.apache.commons.math3.ml.clustering.evaluation.SumOfClusterVariances;
58 import org.apache.commons.math3.ml.distance.EuclideanDistance;
59 import org.apache.commons.math3.random.RandomGenerator;
60
61
62
63
64
65
66
67 public class Clustering {
68
69 private static final Logger log = Logger.getLogger(PotentialsUtils.class.getName());
70
71
72
73
74 private static final int NUM_ITERATIONS = 10000;
75
76
77
78
79
80
81
82
83
84
85 public static List<CentroidCluster<Conformation>> kMeansClustering(List<double[]> distMatrix,
86 int maxClusters, int numTrials, long seed) {
87
88 int dim = distMatrix.size();
89 List<Conformation> conformationList = new ArrayList<>();
90 for (int i = 0; i < dim; i++) {
91 double[] row = distMatrix.get(i);
92
93 if (row.length != dim) {
94 log.severe(format(" Row %d of the distance matrix (%d x %d) has %d columns.", i, dim, dim,
95 row.length));
96 }
97 conformationList.add(new Conformation(row, i));
98 }
99
100
101
102 KMeansPlusPlusClusterer<Conformation> kMeansPlusPlusClusterer = new KMeansPlusPlusClusterer<>(
103 maxClusters, NUM_ITERATIONS);
104
105 RandomGenerator randomGenerator = kMeansPlusPlusClusterer.getRandomGenerator();
106 randomGenerator.setSeed(seed);
107
108 MultiKMeansPlusPlusClusterer<Conformation> multiKMeansPlusPlusClusterer = new MultiKMeansPlusPlusClusterer<>(
109 kMeansPlusPlusClusterer, numTrials);
110
111
112 return multiKMeansPlusPlusClusterer.cluster(conformationList);
113 }
114
115
116
117
118
119
120
121
122
123
124 public static List<CentroidCluster<Conformation>> hierarchicalClustering(
125 List<double[]> distanceMatrix, double threshold) {
126
127
128 int dim = distanceMatrix.size();
129
130 double[][] distMatrixArray = new double[dim][];
131 String[] names = new String[dim];
132 for (int i = 0; i < dim; i++) {
133 distMatrixArray[i] = distanceMatrix.get(i);
134 names[i] = Integer.toString(i);
135 }
136
137
138 ClusteringAlgorithm clusteringAlgorithm = new DefaultClusteringAlgorithm();
139 List<Cluster> clusterList = clusteringAlgorithm.performFlatClustering(distMatrixArray, names,
140 new SingleLinkageStrategy(), threshold);
141
142
143 return clustersToCentroidClusters(distMatrixArray, clusterList);
144 }
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161 public static List<CentroidCluster<Conformation>> iterativeClustering(List<double[]> distMatrix,
162 int trials, double tolerance) {
163
164 int dim = distMatrix.size();
165 ArrayList<CentroidCluster<Conformation>> bestClusters = new ArrayList<>();
166 for (int i = 0; i < trials; i++) {
167
168 ArrayList<Integer> remaining = new ArrayList<>();
169 for (int j = 0; j < dim; j++) {
170 remaining.add(j);
171 }
172 List<CentroidCluster<Conformation>> clusters = new ArrayList<>();
173 while (!remaining.isEmpty()) {
174 int seed = (int) floor(random() * (remaining.size() - 1));
175 int index = remaining.get(seed);
176 CentroidCluster<Conformation> cluster = new CentroidCluster<>(
177 new Conformation(distMatrix.get(index), index));
178 double[] row = distMatrix.get(index);
179
180 if (log.isLoggable(Level.FINER)) {
181 log.finer(format(" Remaining clusters: %3d of %3d", remaining.size(), dim));
182 log.finer(format(" Row: %3d (seed %3d) has %3d entries.", index + 1, seed, row.length));
183 }
184 if (row.length != dim) {
185 log.severe(
186 format(" Row %d of the distance matrix (%d x %d) has %d columns.", index, dim, dim,
187 row.length));
188 }
189 for (int j = 0; j < dim; j++) {
190 if (row[j] < tolerance && remaining.contains(j)) {
191 cluster.addPoint(new Conformation(distMatrix.get(j), j));
192 if (!remaining.remove((Integer) j)) {
193 log.warning(
194 format(" Row %3d matched %3d, but could not be removed.", j + 1, index + 1));
195 }
196 }
197 }
198 clusters.add(cluster);
199 }
200 if (bestClusters.isEmpty() || clusters.size() < bestClusters.size()) {
201 if (log.isLoggable(Level.FINE)) {
202 int numStructs = 0;
203 for (CentroidCluster<Conformation> cluster : clusters) {
204 numStructs += cluster.getPoints().size();
205 }
206 log.fine(
207 format(" New Best: Num clusters: %3d Num Structs: %3d ", clusters.size(), numStructs));
208 }
209 bestClusters = new ArrayList<>(clusters);
210 }
211 }
212 return bestClusters;
213 }
214
215
216
217
218
219
220
221
222 public static void analyzeClusters(List<CentroidCluster<Conformation>> clusters,
223 List<Integer> repStructs, boolean verbose) {
224
225 int nClusters = clusters.size();
226 double meanClusterRMSD = 0.0;
227
228
229 for (int i = 0; i < nClusters; i++) {
230 CentroidCluster<Conformation> clusterI = clusters.get(i);
231
232 List<Conformation> conformations = clusterI.getPoints();
233 int nConformers = conformations.size();
234 StringBuilder sb;
235 if (nConformers == 1) {
236 sb = new StringBuilder(format(" Cluster %d with conformation:", i + 1));
237 } else {
238 sb = new StringBuilder(
239 format(" Cluster %d with %d conformations\n Conformations:", i + 1, nConformers));
240 }
241
242 double minRMS = Double.MAX_VALUE;
243 int minID = -1;
244 for (Conformation conformation : conformations) {
245 int row = conformation.index;
246 sb.append(format(" %d", row + 1));
247 if (nConformers > 1) {
248
249 double[] rmsd = conformation.getPoint();
250
251 double wrms = 0;
252 for (Conformation conformation2 : conformations) {
253 int col = conformation2.index;
254 if (col == row) {
255 continue;
256 }
257 wrms += rmsd[col] * rmsd[col];
258 }
259
260 wrms = sqrt(wrms / (nConformers - 1));
261 if (wrms < minRMS) {
262 minRMS = wrms;
263 minID = conformation.getIndex();
264 }
265 } else {
266
267 minID = row;
268 minRMS = 0.0;
269 }
270 }
271
272
273 if (nConformers > 1) {
274 double clusterRMSD = clusterRMSD(conformations);
275 meanClusterRMSD += clusterRMSD;
276 sb.append(format("\n RMSD within the cluster:\t %6.4f A.\n", clusterRMSD));
277
278
279 sb.append(format(" Minimum RMSD conformer %d:\t %6.4f A.\n", minID + 1, minRMS));
280 } else {
281 sb.append("\n");
282 }
283
284 repStructs.add(minID);
285
286 if (verbose) {
287 log.info(sb.toString());
288 }
289 }
290
291 if (verbose) {
292 log.info(format(" Mean RMSD within clusters: \t %6.4f A.", meanClusterRMSD / nClusters));
293 double sumOfClusterVariances = sumOfClusterVariances(clusters);
294 log.info(
295 format(" Mean cluster variance: \t %6.4f A.\n", sumOfClusterVariances / nClusters));
296 }
297
298 }
299
300
301
302
303
304
305
306 private static double clusterRMSD(List<Conformation> conformations) {
307 int nConformers = conformations.size();
308
309
310 if (nConformers == 1) {
311 return 0.0;
312 }
313
314
315 double sum = 0.0;
316 int count = 0;
317 for (int j = 0; j < nConformers; j++) {
318 Conformation conformation = conformations.get(j);
319 double[] rmsd = conformation.rmsd;
320 for (int k = j + 1; k < nConformers; k++) {
321 Conformation conformation2 = conformations.get(k);
322 int col = conformation2.index;
323 sum += rmsd[col] * rmsd[col];
324 count++;
325 }
326 }
327
328 return sqrt(sum / count);
329 }
330
331
332
333
334
335
336
337 private static double sumOfClusterVariances(List<CentroidCluster<Conformation>> clusters) {
338 SumOfClusterVariances<Conformation> sumOfClusterVariances = new SumOfClusterVariances<>(
339 new EuclideanDistance());
340 return sumOfClusterVariances.score(clusters);
341 }
342
343
344
345
346
347
348
349
350 private static List<CentroidCluster<Conformation>> clustersToCentroidClusters(
351 double[][] distMatrixArray, List<Cluster> clusterList) {
352
353 List<CentroidCluster<Conformation>> centroidClusters = new ArrayList<>();
354 int dim = distMatrixArray.length;
355
356
357 for (Cluster cluster : clusterList) {
358
359 List<String> names = new ArrayList<>();
360 collectNames(cluster, names);
361
362
363 List<Conformation> conformations = new ArrayList<>();
364 for (String name : names) {
365 int index = Integer.parseInt(name);
366 conformations.add(new Conformation(distMatrixArray[index], index));
367 }
368
369
370 Conformation centroid = centroidOf(conformations, dim);
371
372
373 CentroidCluster<Conformation> centroidCluster = new CentroidCluster<>(centroid);
374 centroidClusters.add(centroidCluster);
375 for (Conformation conformation : conformations) {
376 centroidCluster.addPoint(conformation);
377 }
378 }
379 return centroidClusters;
380 }
381
382
383
384
385
386
387
388 private static void collectNames(Cluster cluster, List<String> names) {
389 if (cluster.isLeaf()) {
390 names.add(cluster.getName());
391 } else {
392 for (Cluster c : cluster.getChildren()) {
393 collectNames(c, names);
394 }
395 }
396 }
397
398
399
400
401
402
403
404
405
406 private static Conformation centroidOf(final List<Conformation> points, final int dimension) {
407 final double[] centroid = new double[dimension];
408 for (final Conformation p : points) {
409 final double[] point = p.getPoint();
410 for (int i = 0; i < centroid.length; i++) {
411 centroid[i] += point[i];
412 }
413 }
414 for (int i = 0; i < centroid.length; i++) {
415 centroid[i] /= points.size();
416 }
417 return new Conformation(centroid, 0);
418 }
419
420
421
422
423 public static class Conformation implements Clusterable {
424
425 private final double[] rmsd;
426 private final int index;
427
428 Conformation(double[] rmsd, int index) {
429 this.rmsd = rmsd;
430 this.index = index;
431 }
432
433 public double[] getPoint() {
434 return rmsd;
435 }
436
437 int getIndex() {
438 return index;
439 }
440 }
441 }