1 package ffx.algorithms.dynamics;
2
3 import edu.rit.mp.DoubleBuf;
4 import edu.rit.pj.Comm;
5 import java.io.File;
6 import java.io.IOException;
7 import java.nio.file.Files;
8 import java.nio.file.Path;
9 import java.util.*;
10 import java.util.logging.Level;
11 import java.util.logging.Logger;
12 import ffx.numerics.Potential;
13 import ffx.potential.MolecularAssembly;
14 import ffx.potential.utils.PotentialsUtils;
15 import ffx.potential.utils.ProgressiveAlignmentOfCrystals;
16 import ffx.potential.utils.StructureMetrics;
17 import ffx.potential.utils.Superpose;
18 import org.apache.commons.configuration2.CompositeConfiguration;
19 import org.apache.commons.io.FilenameUtils;
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44 public class WeightedEnsembleManager {
45 private static final Logger logger = Logger.getLogger(WeightedEnsembleManager.class.getName());
46
47
48 private final int rank, worldSize;
49 private final double[][] weightsBins, weightsBinsCopyRank, globalValues;
50
51 private final Comm world;
52 private final DoubleBuf[] weightsBinsBuf, globalValuesBuf;
53 private final DoubleBuf myWeightBinBuf, myGlobalValueBuf;
54
55 private boolean restart, staticBins, resample;
56 private final boolean[] enteredNewBins;
57 private int numBins, optNumPerBin;
58 private long totalSteps, numStepsPerResample, cycle;
59 private double weight, dt, temp, trajInterval;
60 private double[] refCoords, binBounds, x;
61 private File dynFile, dynTemp, trajFile, trajTemp, currentDir, refStructureFile;
62 private final Random random;
63 private final MolecularDynamics molecularDynamics;
64 private Potential potential;
65 private final OneDimMetric metric;
66
67
68
69 public enum OneDimMetric {RMSD, RESIDUE_DISTANCE, COM_DISTANCE, ATOM_DISTANCE, POTENTIAL, RADIUS_OF_GYRATION}
70
71 public WeightedEnsembleManager(OneDimMetric metric, int optNumPerBin,
72 MolecularDynamics md,
73 File refStructureFile,
74 boolean resample) {
75
76 this.world = Comm.world();
77 this.rank = world.rank();
78 this.worldSize = world.size();
79
80 if (worldSize < 2){
81 logger.severe(" Weighted Ensemble requires at least 2 ranks.");
82 System.exit(1);
83 } else if (worldSize < 10){
84 logger.warning(" Weighted Ensemble is not recommended for small scale parallel simulations.");
85 }
86
87
88 this.refStructureFile = refStructureFile;
89 refCoords = getRefCoords(refStructureFile);
90 if(refCoords == null){
91 logger.severe(" Failed to get reference coordinates.");
92 System.exit(1);
93 }
94 this.dynFile = md.getDynFile();
95 this.molecularDynamics = md;
96 this.potential = null;
97 if(molecularDynamics.molecularAssembly != null){
98 this.potential = molecularDynamics.molecularAssembly[0].getPotentialEnergy();
99 this.x = potential.getCoordinates(new double[potential.getNumberOfVariables()]);
100 } else {
101 logger.severe(" Molecular Assembly not set for Molecular Dynamics.");
102 System.exit(1);
103 }
104
105 logger.info("\n\n ----------------------------- Initializing Weighted Ensemble Run -----------------------------");
106 if (initFilesOrPrepRestart()){
107 if (restart) {
108 logger.info(" Restarting from previous Weighted Ensemble run.");
109 }
110 } else {
111 logger.severe(" Failed to initialize Weighted Ensemble run.");
112 System.exit(1);
113 }
114
115
116 CompositeConfiguration properties = molecularDynamics.molecularAssembly[0].getProperties();
117 this.random = new Random();
118 random.setSeed(44);
119 this.binBounds = getPropertyList(properties, "WE.BinBounds");
120 logger.info(" Bin bounds: " + Arrays.toString(binBounds));
121 this.numBins = binBounds.length + 1;
122 this.staticBins = binBounds.length > 0;
123 if (staticBins){
124 logger.info(" Using static binning with " + numBins + " bins.");
125 } else {
126 logger.info(" Using dynamic binning.");
127 }
128 this.resample = resample;
129 this.numBins = staticBins ? numBins : worldSize / optNumPerBin;
130 this.optNumPerBin = optNumPerBin;
131 this.metric = metric;
132 this.weight = 1.0 / worldSize;
133 this.cycle = 0;
134 this.enteredNewBins = new boolean[worldSize];
135 this.globalValues = new double[worldSize][1];
136 this.weightsBins = new double[worldSize][2];
137 this.weightsBinsCopyRank = new double[worldSize][3];
138 this.weightsBinsBuf = new DoubleBuf[worldSize];
139 this.globalValuesBuf = new DoubleBuf[worldSize];
140 for (int i = 0; i < worldSize; i++){
141 weightsBinsBuf[i] = DoubleBuf.buffer(weightsBins[i]);
142 globalValuesBuf[i] = DoubleBuf.buffer(globalValues[i]);
143 }
144 this.myWeightBinBuf = weightsBinsBuf[rank];
145 this.myGlobalValueBuf = globalValuesBuf[rank];
146 weightsBins[rank][0] = weight;
147
148
149
150 }
151
152 private double[] getPropertyList(CompositeConfiguration properties, String propertyName) {
153 ArrayList<Double> list = new ArrayList<>();
154 String[] split = properties.getString(propertyName, "").trim()
155 .replace("[", "")
156 .replace("]","")
157 .replace(","," ")
158 .split(" ");
159 if (split[0].isEmpty()){
160 return new double[0];
161 }
162 for (String s1 : split) {
163 if (s1.isEmpty()) {
164 continue;
165 }
166 list.add(Double.parseDouble(s1));
167 }
168 return list.stream().sorted().mapToDouble(Double::doubleValue).toArray();
169 }
170
171 private static double[] getRefCoords(File refStructureFile){
172 PotentialsUtils utils = new PotentialsUtils();
173 MolecularAssembly assembly = utils.open(refStructureFile);
174 return assembly.getPotentialEnergy().getCoordinates(new double[0]);
175 }
176
177 private boolean initFilesOrPrepRestart() {
178
179 logger.info("\n Rank " + rank + " structure is based on: " + refStructureFile.getAbsolutePath());
180 File parent = refStructureFile.getParentFile();
181 logger.info(" Rank " + rank + " is using parent directory: " + parent.getAbsolutePath());
182 currentDir = new File(parent + File.separator + rank);
183 logger.info(" Rank " + rank + " is using directory: " + currentDir.getAbsolutePath());
184 trajFile = new File(currentDir + File.separator +
185 FilenameUtils.getBaseName(refStructureFile.getName()) + ".arc");
186 logger.info(" Rank " + rank + " is using trajectory file: " + trajFile.getAbsolutePath());
187 molecularDynamics.setArchiveFiles(new File[]{trajFile});
188 dynFile = new File(currentDir + File.separator +
189 FilenameUtils.getBaseName(refStructureFile.getName()) + ".dyn");
190 logger.info(" Rank " + rank + " is using dyn restart file: " + dynFile.getAbsolutePath());
191 molecularDynamics.setFallbackDynFile(dynFile);
192
193 restart = !currentDir.mkdir() && checkRestartFiles();
194 logger.info(" \n");
195 return currentDir.exists();
196 }
197
198 private boolean checkRestartFiles(){
199
200 Path dynPath = dynFile.toPath();
201 Path trajPath = trajFile.toPath();
202
203 return dynPath.toFile().exists() && trajPath.toFile().exists();
204 }
205
206
207
208
209 public void run(long totalSteps, long numStepsPerResample, double temp, double dt) {
210 this.totalSteps = totalSteps;
211 this.numStepsPerResample = numStepsPerResample;
212 this.temp = temp;
213 this.dt = dt;
214 if (!staticBins) {
215 dynamics(numStepsPerResample * 3L);
216 }
217 if(restart) {
218 long restartFrom = getRestartTime(dynFile);
219 logger.info(" Restarting from " + restartFrom + " femtoseconds.");
220 totalSteps -= restartFrom;
221 logger.info(" Remaining cycles: " + (int) Math.ceil((double) totalSteps / numStepsPerResample));
222 }
223
224
225 logger.info("\n ----------------------------- Start Weighted Ensemble Run -----------------------------");
226 int numCycles = (int) Math.ceil((double) totalSteps / numStepsPerResample);
227
228 for (int i = 0; i < numCycles; i++) {
229 cycle = i;
230 dynamics(numStepsPerResample);
231 calculateMyMetric();
232 comms();
233 binAssignment();
234 if (resample) {
235 resample();
236 }
237 comms();
238 if (resample) {
239 sanityCheckAndFileMigration();
240 logger.info("\n ----------------------------- Resampling cycle #" + (i + 1) +
241 " complete ----------------------------- ");
242 }
243 }
244 logger.info("\n\n\n ----------------------------- End Weighted Ensemble Run -----------------------------");
245 }
246
247 private long getRestartTime(File dynFile) {
248
249
250 return 0;
251 }
252
253 private void dynamics(long numSteps){
254
255 double dynamicTotalTime = (numSteps * dt / 1000.0);
256 molecularDynamics.setCoordinates(x);
257 potential.energy(x, false);
258 molecularDynamics.dynamic(numSteps, dt, dynamicTotalTime/3.0,
259 dynamicTotalTime/3.0, temp, true, dynFile);
260 x = molecularDynamics.getCoordinates();
261 molecularDynamics.writeRestart();
262 }
263
264 private void calculateMyMetric(){
265 switch (metric){
266 case RMSD:
267 globalValues[rank][0] = Superpose.rmsd(refCoords, x, potential.getMass());
268 break;
269 case RESIDUE_DISTANCE:
270 break;
271 case COM_DISTANCE:
272 break;
273 case ATOM_DISTANCE:
274 break;
275 case POTENTIAL:
276 double refEnergy = potential.energy(refCoords, false);
277 double myEnergy = potential.energy(x, false);
278 globalValues[rank][0] = myEnergy - refEnergy;
279 break;
280 case RADIUS_OF_GYRATION:
281 StructureMetrics.radiusOfGyration(x, potential.getMass());
282 break;
283 default:
284 break;
285 }
286 }
287
288 private void binAssignment() {
289 double[] global = new double[worldSize];
290 for (int i = 0; i < worldSize; i++){
291 global[i] = globalValues[i][0];
292 }
293 double[] binBounds = getOneDimBinBounds(global);
294 numBins = binBounds.length + 1;
295
296
297
298 for (int j = 0; j < worldSize; j++) {
299 int oldBin = (int) Math.round(weightsBins[j][1]);
300 for (int i = 0; i < numBins - 2; i++) {
301 if (i == 0 && global[j] < binBounds[i]) {
302 weightsBins[j][1] = i;
303 break;
304 }
305 if (global[j] >= binBounds[i] && global[j] < binBounds[i + 1]) {
306 weightsBins[j][1] = i + 1;
307 break;
308 }
309 if (i == numBins - 3 && global[j] >= binBounds[i + 1]) {
310 weightsBins[j][1] = i + 2;
311 break;
312 }
313 }
314 enteredNewBins[j] = oldBin != weightsBins[j][1];
315 }
316
317 logger.info("\n Bin bounds: " + Arrays.toString(binBounds));
318 logger.info(" Rank global values with metric \"" + metricToString(metric) + "\": " + Arrays.toString(global));
319 logger.info(" Entered new bins: " + Arrays.toString(enteredNewBins));
320 }
321
322 private String metricToString(OneDimMetric metric) {
323 return switch (metric) {
324 case RMSD -> "RMSD";
325 case RESIDUE_DISTANCE -> "Residue Distance";
326 case COM_DISTANCE -> "COM Distance";
327 case ATOM_DISTANCE -> "Atom Distance";
328 case POTENTIAL -> "Potential";
329 default -> "Invalid Metric";
330 };
331 }
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354 private void resample(){
355 logger.info("\n\n ----------------------------- Resampling ----------------------------- ");
356
357 List<List<Integer>> binRank = new ArrayList<>();
358 PriorityQueue<Decision> merges = new PriorityQueue<>();
359 PriorityQueue<Decision> splits = new PriorityQueue<>();
360 for (int i = 0; i < numBins; i++){
361 binRank.add(new ArrayList<>());
362 }
363
364
365 for (int i = 0; i < worldSize; i++){
366 int bin = (int) Math.round(weightsBins[i][1]);
367 binRank.get(bin).add(i);
368 }
369
370
371 int m = 2;
372 for (int i = 0; i < numBins; i++){
373 List<Integer> ranks = binRank.get(i)
374 .stream()
375 .sorted((a, b) -> Double.compare(weightsBins[a][0], weightsBins[b][0]))
376 .toList();
377 if (ranks.isEmpty()){ continue; }
378 double idealWeight = ranks.stream().mapToDouble(a -> weightsBins[a][0]).sum() / optNumPerBin;
379 logger.info("\n Bin #" + i + " has " + ranks.size() + " ranks with ideal weight " + idealWeight + ".");
380
381
382 ArrayList<Integer> splitRanks = new ArrayList<>();
383 for (int rank : ranks) {
384 double weight = weightsBins[rank][0];
385 if (weightsBins[rank][0] > 2.0 * idealWeight || enteredNewBins[rank]) {
386 splits.add(new Decision(new ArrayList<>(List.of(rank)), new ArrayList<>(List.of(weight))));
387 splitRanks.add(rank);
388 }
389 }
390
391
392 Iterator<Integer> ranksIterator = ranks.stream().iterator();
393 int rank = ranksIterator.next();
394 double weight = weightsBins[rank][0];
395 double combinedWeight = 0.0;
396 ArrayList<Decision> binMergeList = new ArrayList<>();
397 binMergeList.add(new Decision(new ArrayList<>(), new ArrayList<>()));
398 while(weight < idealWeight/2){
399 boolean split = false;
400 if (splitRanks.contains(rank) && ranks.size() > optNumPerBin){
401 double choice = random.nextDouble();
402 if (choice < .5) {
403 int finalRank = rank;
404 splits.removeIf(decision -> decision.ranks.contains(finalRank));
405 } else {
406 split = true;
407 }
408 }
409 if (!split){
410 if (combinedWeight + weight < idealWeight) {
411 binMergeList.getLast().ranks.add(rank);
412 binMergeList.getLast().weights.add(weight);
413 combinedWeight += weight;
414 } else if (combinedWeight + weight <= idealWeight * 1.5) {
415 binMergeList.getLast().ranks.add(rank);
416 binMergeList.getLast().weights.add(weight);
417 combinedWeight = 0;
418 binMergeList.add(new Decision(new ArrayList<>(), new ArrayList<>()));
419 } else {
420 binMergeList.add(new Decision(new ArrayList<>(), new ArrayList<>()));
421 combinedWeight = 0;
422 continue;
423 }
424 }
425 if (ranksIterator.hasNext()){
426 rank = ranksIterator.next();
427 weight = weightsBins[rank][0];
428 } else {
429 break;
430 }
431 }
432 for (Decision decision : binMergeList){
433 if (decision.ranks.size() > 1){
434 merges.add(decision);
435 }
436 }
437 }
438
439
440 StringBuilder mergeString = new StringBuilder();
441 for (int i = 0; i < worldSize; i++){
442 weightsBinsCopyRank[i][0] = weightsBins[i][0];
443 weightsBinsCopyRank[i][1] = weightsBins[i][1];
444 weightsBinsCopyRank[i][2] = -1;
445 }
446
447 ArrayList<Integer> freedRanks = new ArrayList<>();
448 int desiredFreeRanks = splits.size()*m - splits.size();
449 while (!merges.isEmpty() && desiredFreeRanks > 0){
450 Decision decision = merges.poll();
451 if(decision.ranks.size()-1 > desiredFreeRanks){
452 decision.ranks.subList(desiredFreeRanks+1, decision.ranks.size()).clear();
453 decision.weights.subList(desiredFreeRanks+1, decision.weights.size()).clear();
454 }
455 desiredFreeRanks -= decision.ranks.size()-1;
456
457
458 ArrayList<Double> weights = new ArrayList<>(decision.weights);
459 double totalWeight = weights.stream().mapToDouble(Double::doubleValue).sum();
460 weights.replaceAll(a -> a /totalWeight);
461 double rand = random.nextDouble();
462 double cumulativeWeight = 0.0;
463 int rankToMergeInto = -1;
464 for (int i = 0; i < weights.size(); i++){
465 cumulativeWeight += weights.get(i);
466 if (rand <= cumulativeWeight){
467 rankToMergeInto = decision.ranks.get(i);
468 break;
469 }
470 }
471 mergeString.append("\t Ranks ").append(decision.ranks).append(" --> ").append(rankToMergeInto).append("\n");
472
473 for (int rank : decision.ranks){
474 if (rank != rankToMergeInto){
475 freedRanks.add(rank);
476 } else{
477 weightsBinsCopyRank[rank][0] = decision.getTotalWeight();
478 }
479 }
480 }
481
482 StringBuilder splitString = new StringBuilder();
483 double[] global = new double[worldSize];
484 for (int i = 0; i < worldSize; i++){
485 global[i] = globalValues[i][0];
486 }
487 while(!splits.isEmpty() && !freedRanks.isEmpty()){
488 if (freedRanks.size() < m-1){
489 m = freedRanks.size() + 1;
490 }
491 Decision decision = splits.poll();
492 int parent = decision.ranks.getFirst();
493 double weightToEach = decision.getTotalWeight() / m;
494 ArrayList<Integer> ranksUsed = new ArrayList<>();
495 ranksUsed.add(parent);
496 weightsBinsCopyRank[parent][0] = weightToEach;
497 for (int i = 0; i < m-1; i++){
498 int childRank = freedRanks.removeFirst();
499 ranksUsed.add(childRank);
500 weightsBinsCopyRank[childRank][0] = weightToEach;
501 weightsBinsCopyRank[childRank][1] = weightsBinsCopyRank[parent][1];
502 weightsBinsCopyRank[childRank][2] = parent;
503 global[childRank] = globalValues[parent][0];
504 }
505 splitString.append("\t Rank ").append(parent).append(" --> ").append(ranksUsed).append("\n");
506 }
507
508
509 weightsBins[rank][0] = weightsBinsCopyRank[rank][0];
510 weightsBins[rank][1] = weightsBinsCopyRank[rank][1];
511 globalValues[rank][0] = global[rank];
512
513
514 logger.info("\n ----------------------------- Resampling Decisions ----------------------------- ");
515 double[] weights = new double[worldSize];
516 int[] bins = new int[worldSize];
517 for (int i = 0; i < worldSize; i++) {
518 weights[i] = weightsBinsCopyRank[i][0];
519 bins[i] = (int) Math.round(weightsBinsCopyRank[i][1]);
520 }
521 logger.info("\n Rank bin numbers: " + Arrays.toString(bins));
522 logger.info(" Bin bounds: " + Arrays.toString(binBounds));
523 logger.info(" Rank weights: " + Arrays.toString(weights));
524 logger.info(" Weight sum: " + Arrays.stream(weights).sum());
525 logger.info(" Merges: \n" + mergeString + "\n");
526 logger.info(" Splits: \n" + splitString + "\n");
527 if ((Arrays.stream(weights).sum()-1) > 1e-6){
528 logger.severe(" Weights do not sum to 1.0.");
529 }
530 }
531
532 private record Decision(ArrayList<Integer> ranks, ArrayList<Double> weights)
533 implements Comparable<Decision> {
534
535 public double getTotalWeight(){
536 return weights.stream().mapToDouble(Double::doubleValue).sum();
537 }
538
539 @Override
540 public int compareTo(Decision decision) {
541 return Double.compare(this.getTotalWeight(), decision.getTotalWeight());
542 }
543 }
544
545 private double[] getOneDimBinBounds(double[] globalValues){
546 if (staticBins){
547 return binBounds;
548 }
549
550 logger.severe(" Automatic binning not implemented yet.");
551 return new double[0];
552 }
553
554 private void comms(){
555 comms(false);
556 }
557
558
559
560
561 private void comms(boolean log){
562 if (log) {
563 double[] weights = new double[worldSize];
564 double[] global = new double[worldSize];
565 for (int i = 0; i < worldSize; i++) {
566 weights[i] = weightsBinsCopyRank[i][0];
567 global[i] = globalValues[i][0];
568 }
569 logger.info(" Rank bin numbers Pre: " + Arrays.toString(weights));
570 logger.info(" Rank global values Pre: " + Arrays.toString(global));
571 }
572 try{
573 world.allGather(myWeightBinBuf, weightsBinsBuf);
574 world.allGather(myGlobalValueBuf, globalValuesBuf);
575 } catch (IOException e) {
576 String message = " WeightedEnsemble allGather for weightsbins failed.";
577 logger.severe(message);
578 }
579 if(log) {
580 double[] weights = new double[worldSize];
581 double[] global = new double[worldSize];
582 for (int i = 0; i < worldSize; i++) {
583 weights[i] = weightsBinsCopyRank[i][0];
584 global[i] = globalValues[i][0];
585 }
586 logger.info(" Rank bin numbers Post: " + Arrays.toString(weights));
587 logger.info(" Rank global values Post: " + Arrays.toString(global));
588 }
589 }
590
591 private void sanityCheckAndFileMigration() {
592
593 for (int i = 0; i < worldSize; i++) {
594 if (weightsBins[i][0] != weightsBinsCopyRank[i][0] ||
595 weightsBins[i][1] != weightsBinsCopyRank[i][1]) {
596 String message = " Rank " + i + " has mismatched weightsBins and weightsBinsCopyRank.";
597 logger.info(" WeightsBins: " + Arrays.deepToString(weightsBins));
598 logger.info(" WeightsBinsCopyRank: " + Arrays.deepToString(weightsBinsCopyRank));
599 logger.severe(message);
600 System.exit(1);
601 }
602 }
603
604
605 if (weightsBinsCopyRank[rank][2] != -1 && weightsBinsCopyRank[rank][2] != rank) {
606 copyOver((int) weightsBinsCopyRank[rank][2]);
607 }
608 comms();
609 if (weightsBinsCopyRank[rank][2] != -1 && weightsBinsCopyRank[rank][2] != rank) {
610 moveOnto();
611 }
612 }
613
614 private void copyOver(int rank){
615
616
617 File dyn = new File(currentDir.getParent() + File.separator + rank +
618 File.separator + dynFile.getName());
619 dynTemp = new File(currentDir.getParent() + File.separator + rank +
620 File.separator + dynFile.getName() + ".temp");
621 File traj = new File(currentDir.getParent() + File.separator + rank +
622 File.separator + trajFile.getName());
623 trajTemp = new File(currentDir.getParent() + File.separator + rank +
624 File.separator + trajFile.getName() + ".temp");
625
626 try {
627 Files.copy(dyn.toPath(), dynTemp.toPath());
628 } catch (IOException e) {
629 String message = " Failed to copy dyn file from rank " + rank + " to rank " + this.rank;
630 logger.log(Level.SEVERE, message, e);
631 }
632 try {
633 Files.copy(traj.toPath(), trajTemp.toPath());
634 } catch (IOException e) {
635 String message = " Failed to copy traj file from rank " + rank + " to rank " + this.rank;
636 logger.log(Level.SEVERE, message, e);
637 }
638 }
639
640 private void moveOnto(){
641 try {
642 Files.move(dynTemp.toPath(), dynFile.toPath(), java.nio.file.StandardCopyOption.REPLACE_EXISTING);
643 } catch (IOException e) {
644 String message = " Failed to move dyn file from rank " + rank + " to rank " + this.rank;
645 logger.log(Level.SEVERE, message, e);
646 }
647 try {
648 Files.move(trajTemp.toPath(), trajFile.toPath(), java.nio.file.StandardCopyOption.REPLACE_EXISTING);
649 } catch (IOException e) {
650 String message = " Failed to move traj file from rank " + rank + " to rank " + this.rank;
651 logger.log(Level.SEVERE, message, e);
652 }
653 }
654 }