1 package ffx.algorithms.dynamics;
2
3 import edu.rit.mp.DoubleBuf;
4 import edu.rit.pj.Comm;
5 import ffx.numerics.Potential;
6 import ffx.potential.MolecularAssembly;
7 import ffx.potential.utils.PotentialsUtils;
8 import ffx.potential.utils.StructureMetrics;
9 import ffx.potential.utils.Superpose;
10 import org.apache.commons.configuration2.CompositeConfiguration;
11 import org.apache.commons.io.FilenameUtils;
12
13 import java.io.File;
14 import java.io.IOException;
15 import java.nio.file.Files;
16 import java.nio.file.Path;
17 import java.util.ArrayList;
18 import java.util.Arrays;
19 import java.util.Iterator;
20 import java.util.List;
21 import java.util.PriorityQueue;
22 import java.util.Random;
23 import java.util.logging.Level;
24 import java.util.logging.Logger;
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48 public class WeightedEnsembleManager {
49 private static final Logger logger = Logger.getLogger(WeightedEnsembleManager.class.getName());
50
51
52 private final int rank, worldSize;
53 private final double[][] weightsBins, weightsBinsCopyRank, globalValues;
54
55 private final Comm world;
56 private final DoubleBuf[] weightsBinsBuf, globalValuesBuf;
57 private final DoubleBuf myWeightBinBuf, myGlobalValueBuf;
58
59 private boolean restart, staticBins, resample;
60 private final boolean[] enteredNewBins;
61 private int numBins, optNumPerBin;
62 private long totalSteps, numStepsPerResample, cycle;
63 private double weight, dt, temp, trajInterval;
64 private double[] refCoords, binBounds, x;
65 private File dynFile, dynTemp, trajFile, trajTemp, currentDir, refStructureFile;
66 private final Random random;
67 private final MolecularDynamics molecularDynamics;
68 private Potential potential;
69 private final OneDimMetric metric;
70
71
72
73 public enum OneDimMetric {RMSD, RESIDUE_DISTANCE, COM_DISTANCE, ATOM_DISTANCE, POTENTIAL, RADIUS_OF_GYRATION}
74
75 public WeightedEnsembleManager(OneDimMetric metric, int optNumPerBin,
76 MolecularDynamics md,
77 File refStructureFile,
78 boolean resample) {
79
80 this.world = Comm.world();
81 this.rank = world.rank();
82 this.worldSize = world.size();
83
84 if (worldSize < 2){
85 logger.severe(" Weighted Ensemble requires at least 2 ranks.");
86 System.exit(1);
87 } else if (worldSize < 10){
88 logger.warning(" Weighted Ensemble is not recommended for small scale parallel simulations.");
89 }
90
91
92 this.refStructureFile = refStructureFile;
93 refCoords = getRefCoords(refStructureFile);
94 if(refCoords == null){
95 logger.severe(" Failed to get reference coordinates.");
96 System.exit(1);
97 }
98 this.dynFile = md.getDynFile();
99 this.molecularDynamics = md;
100 this.potential = null;
101 if(molecularDynamics.molecularAssembly != null){
102 this.potential = molecularDynamics.molecularAssembly[0].getPotentialEnergy();
103 this.x = potential.getCoordinates(new double[potential.getNumberOfVariables()]);
104 } else {
105 logger.severe(" Molecular Assembly not set for Molecular Dynamics.");
106 System.exit(1);
107 }
108
109 logger.info("\n\n ----------------------------- Initializing Weighted Ensemble Run -----------------------------");
110 if (initFilesOrPrepRestart()){
111 if (restart) {
112 logger.info(" Restarting from previous Weighted Ensemble run.");
113 }
114 } else {
115 logger.severe(" Failed to initialize Weighted Ensemble run.");
116 System.exit(1);
117 }
118
119
120 CompositeConfiguration properties = molecularDynamics.molecularAssembly[0].getProperties();
121 this.random = new Random();
122 random.setSeed(44);
123 this.binBounds = getPropertyList(properties, "WE.BinBounds");
124 logger.info(" Bin bounds: " + Arrays.toString(binBounds));
125 this.numBins = binBounds.length + 1;
126 this.staticBins = binBounds.length > 0;
127 if (staticBins){
128 logger.info(" Using static binning with " + numBins + " bins.");
129 } else {
130 logger.info(" Using dynamic binning.");
131 }
132 this.resample = resample;
133 this.numBins = staticBins ? numBins : worldSize / optNumPerBin;
134 this.optNumPerBin = optNumPerBin;
135 this.metric = metric;
136 this.weight = 1.0 / worldSize;
137 this.cycle = 0;
138 this.enteredNewBins = new boolean[worldSize];
139 this.globalValues = new double[worldSize][1];
140 this.weightsBins = new double[worldSize][2];
141 this.weightsBinsCopyRank = new double[worldSize][3];
142 this.weightsBinsBuf = new DoubleBuf[worldSize];
143 this.globalValuesBuf = new DoubleBuf[worldSize];
144 for (int i = 0; i < worldSize; i++){
145 weightsBinsBuf[i] = DoubleBuf.buffer(weightsBins[i]);
146 globalValuesBuf[i] = DoubleBuf.buffer(globalValues[i]);
147 }
148 this.myWeightBinBuf = weightsBinsBuf[rank];
149 this.myGlobalValueBuf = globalValuesBuf[rank];
150 weightsBins[rank][0] = weight;
151
152
153
154 }
155
156 private double[] getPropertyList(CompositeConfiguration properties, String propertyName) {
157 ArrayList<Double> list = new ArrayList<>();
158 String[] split = properties.getString(propertyName, "").trim()
159 .replace("[", "")
160 .replace("]","")
161 .replace(","," ")
162 .split(" ");
163 if (split[0].isEmpty()){
164 return new double[0];
165 }
166 for (String s1 : split) {
167 if (s1.isEmpty()) {
168 continue;
169 }
170 list.add(Double.parseDouble(s1));
171 }
172 return list.stream().sorted().mapToDouble(Double::doubleValue).toArray();
173 }
174
175 private static double[] getRefCoords(File refStructureFile){
176 PotentialsUtils utils = new PotentialsUtils();
177 MolecularAssembly assembly = utils.open(refStructureFile);
178 return assembly.getPotentialEnergy().getCoordinates(new double[0]);
179 }
180
181 private boolean initFilesOrPrepRestart() {
182
183 logger.info("\n Rank " + rank + " structure is based on: " + refStructureFile.getAbsolutePath());
184 File parent = refStructureFile.getParentFile();
185 logger.info(" Rank " + rank + " is using parent directory: " + parent.getAbsolutePath());
186 currentDir = new File(parent + File.separator + rank);
187 logger.info(" Rank " + rank + " is using directory: " + currentDir.getAbsolutePath());
188 trajFile = new File(currentDir + File.separator +
189 FilenameUtils.getBaseName(refStructureFile.getName()) + ".arc");
190 logger.info(" Rank " + rank + " is using trajectory file: " + trajFile.getAbsolutePath());
191 molecularDynamics.setArchiveFiles(new File[]{trajFile});
192 dynFile = new File(currentDir + File.separator +
193 FilenameUtils.getBaseName(refStructureFile.getName()) + ".dyn");
194 logger.info(" Rank " + rank + " is using dyn restart file: " + dynFile.getAbsolutePath());
195 molecularDynamics.setFallbackDynFile(dynFile);
196
197 restart = !currentDir.mkdir() && checkRestartFiles();
198 logger.info(" \n");
199 return currentDir.exists();
200 }
201
202 private boolean checkRestartFiles(){
203
204 Path dynPath = dynFile.toPath();
205 Path trajPath = trajFile.toPath();
206
207 return dynPath.toFile().exists() && trajPath.toFile().exists();
208 }
209
210
211
212
213 public void run(long totalSteps, long numStepsPerResample, double temp, double dt) {
214 this.totalSteps = totalSteps;
215 this.numStepsPerResample = numStepsPerResample;
216 this.temp = temp;
217 this.dt = dt;
218 if (!staticBins) {
219 dynamics(numStepsPerResample * 3L);
220 }
221 if(restart) {
222 long restartFrom = getRestartTime(dynFile);
223 logger.info(" Restarting from " + restartFrom + " femtoseconds.");
224 totalSteps -= restartFrom;
225 logger.info(" Remaining cycles: " + (int) Math.ceil((double) totalSteps / numStepsPerResample));
226 }
227
228
229 logger.info("\n ----------------------------- Start Weighted Ensemble Run -----------------------------");
230 int numCycles = (int) Math.ceil((double) totalSteps / numStepsPerResample);
231
232 for (int i = 0; i < numCycles; i++) {
233 cycle = i;
234 dynamics(numStepsPerResample);
235 calculateMyMetric();
236 comms();
237 binAssignment();
238 if (resample) {
239 resample();
240 }
241 comms();
242 if (resample) {
243 sanityCheckAndFileMigration();
244 logger.info("\n ----------------------------- Resampling cycle #" + (i + 1) +
245 " complete ----------------------------- ");
246 }
247 }
248 logger.info("\n\n\n ----------------------------- End Weighted Ensemble Run -----------------------------");
249 }
250
251 private long getRestartTime(File dynFile) {
252
253
254 return 0;
255 }
256
257 private void dynamics(long numSteps){
258
259 double dynamicTotalTime = (numSteps * dt / 1000.0);
260 molecularDynamics.setCoordinates(x);
261 potential.energy(x, false);
262 molecularDynamics.dynamic(numSteps, dt, dynamicTotalTime/3.0,
263 dynamicTotalTime/3.0, temp, true, dynFile);
264 x = molecularDynamics.getCoordinates();
265 molecularDynamics.writeRestart();
266 }
267
268 private void calculateMyMetric(){
269 switch (metric){
270 case RMSD:
271 globalValues[rank][0] = Superpose.rmsd(refCoords, x, potential.getMass());
272 break;
273 case RESIDUE_DISTANCE:
274 break;
275 case COM_DISTANCE:
276 break;
277 case ATOM_DISTANCE:
278 break;
279 case POTENTIAL:
280 double refEnergy = potential.energy(refCoords, false);
281 double myEnergy = potential.energy(x, false);
282 globalValues[rank][0] = myEnergy - refEnergy;
283 break;
284 case RADIUS_OF_GYRATION:
285 StructureMetrics.radiusOfGyration(x, potential.getMass());
286 break;
287 default:
288 break;
289 }
290 }
291
292 private void binAssignment() {
293 double[] global = new double[worldSize];
294 for (int i = 0; i < worldSize; i++){
295 global[i] = globalValues[i][0];
296 }
297 double[] binBounds = getOneDimBinBounds(global);
298 numBins = binBounds.length + 1;
299
300
301
302 for (int j = 0; j < worldSize; j++) {
303 int oldBin = (int) Math.round(weightsBins[j][1]);
304 for (int i = 0; i < numBins - 2; i++) {
305 if (i == 0 && global[j] < binBounds[i]) {
306 weightsBins[j][1] = i;
307 break;
308 }
309 if (global[j] >= binBounds[i] && global[j] < binBounds[i + 1]) {
310 weightsBins[j][1] = i + 1;
311 break;
312 }
313 if (i == numBins - 3 && global[j] >= binBounds[i + 1]) {
314 weightsBins[j][1] = i + 2;
315 break;
316 }
317 }
318 enteredNewBins[j] = oldBin != weightsBins[j][1];
319 }
320
321 logger.info("\n Bin bounds: " + Arrays.toString(binBounds));
322 logger.info(" Rank global values with metric \"" + metricToString(metric) + "\": " + Arrays.toString(global));
323 logger.info(" Entered new bins: " + Arrays.toString(enteredNewBins));
324 }
325
326 private String metricToString(OneDimMetric metric) {
327 return switch (metric) {
328 case RMSD -> "RMSD";
329 case RESIDUE_DISTANCE -> "Residue Distance";
330 case COM_DISTANCE -> "COM Distance";
331 case ATOM_DISTANCE -> "Atom Distance";
332 case POTENTIAL -> "Potential";
333 default -> "Invalid Metric";
334 };
335 }
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358 private void resample(){
359 logger.info("\n\n ----------------------------- Resampling ----------------------------- ");
360
361 List<List<Integer>> binRank = new ArrayList<>();
362 PriorityQueue<Decision> merges = new PriorityQueue<>();
363 PriorityQueue<Decision> splits = new PriorityQueue<>();
364 for (int i = 0; i < numBins; i++){
365 binRank.add(new ArrayList<>());
366 }
367
368
369 for (int i = 0; i < worldSize; i++){
370 int bin = (int) Math.round(weightsBins[i][1]);
371 binRank.get(bin).add(i);
372 }
373
374
375 int m = 2;
376 for (int i = 0; i < numBins; i++){
377 List<Integer> ranks = binRank.get(i)
378 .stream()
379 .sorted((a, b) -> Double.compare(weightsBins[a][0], weightsBins[b][0]))
380 .toList();
381 if (ranks.isEmpty()){ continue; }
382 double idealWeight = ranks.stream().mapToDouble(a -> weightsBins[a][0]).sum() / optNumPerBin;
383 logger.info("\n Bin #" + i + " has " + ranks.size() + " ranks with ideal weight " + idealWeight + ".");
384
385
386 ArrayList<Integer> splitRanks = new ArrayList<>();
387 for (int rank : ranks) {
388 double weight = weightsBins[rank][0];
389 if (weightsBins[rank][0] > 2.0 * idealWeight || enteredNewBins[rank]) {
390 splits.add(new Decision(new ArrayList<>(List.of(rank)), new ArrayList<>(List.of(weight))));
391 splitRanks.add(rank);
392 }
393 }
394
395
396 Iterator<Integer> ranksIterator = ranks.stream().iterator();
397 int rank = ranksIterator.next();
398 double weight = weightsBins[rank][0];
399 double combinedWeight = 0.0;
400 ArrayList<Decision> binMergeList = new ArrayList<>();
401 binMergeList.add(new Decision(new ArrayList<>(), new ArrayList<>()));
402 while(weight < idealWeight/2){
403 boolean split = false;
404 if (splitRanks.contains(rank) && ranks.size() > optNumPerBin){
405 double choice = random.nextDouble();
406 if (choice < .5) {
407 int finalRank = rank;
408 splits.removeIf(decision -> decision.ranks.contains(finalRank));
409 } else {
410 split = true;
411 }
412 }
413 if (!split){
414 if (combinedWeight + weight < idealWeight) {
415 binMergeList.getLast().ranks.add(rank);
416 binMergeList.getLast().weights.add(weight);
417 combinedWeight += weight;
418 } else if (combinedWeight + weight <= idealWeight * 1.5) {
419 binMergeList.getLast().ranks.add(rank);
420 binMergeList.getLast().weights.add(weight);
421 combinedWeight = 0;
422 binMergeList.add(new Decision(new ArrayList<>(), new ArrayList<>()));
423 } else {
424 binMergeList.add(new Decision(new ArrayList<>(), new ArrayList<>()));
425 combinedWeight = 0;
426 continue;
427 }
428 }
429 if (ranksIterator.hasNext()){
430 rank = ranksIterator.next();
431 weight = weightsBins[rank][0];
432 } else {
433 break;
434 }
435 }
436 for (Decision decision : binMergeList){
437 if (decision.ranks.size() > 1){
438 merges.add(decision);
439 }
440 }
441 }
442
443
444 StringBuilder mergeString = new StringBuilder();
445 for (int i = 0; i < worldSize; i++){
446 weightsBinsCopyRank[i][0] = weightsBins[i][0];
447 weightsBinsCopyRank[i][1] = weightsBins[i][1];
448 weightsBinsCopyRank[i][2] = -1;
449 }
450
451 ArrayList<Integer> freedRanks = new ArrayList<>();
452 int desiredFreeRanks = splits.size()*m - splits.size();
453 while (!merges.isEmpty() && desiredFreeRanks > 0){
454 Decision decision = merges.poll();
455 if(decision.ranks.size()-1 > desiredFreeRanks){
456 decision.ranks.subList(desiredFreeRanks+1, decision.ranks.size()).clear();
457 decision.weights.subList(desiredFreeRanks+1, decision.weights.size()).clear();
458 }
459 desiredFreeRanks -= decision.ranks.size()-1;
460
461
462 ArrayList<Double> weights = new ArrayList<>(decision.weights);
463 double totalWeight = weights.stream().mapToDouble(Double::doubleValue).sum();
464 weights.replaceAll(a -> a /totalWeight);
465 double rand = random.nextDouble();
466 double cumulativeWeight = 0.0;
467 int rankToMergeInto = -1;
468 for (int i = 0; i < weights.size(); i++){
469 cumulativeWeight += weights.get(i);
470 if (rand <= cumulativeWeight){
471 rankToMergeInto = decision.ranks.get(i);
472 break;
473 }
474 }
475 mergeString.append("\t Ranks ").append(decision.ranks).append(" --> ").append(rankToMergeInto).append("\n");
476
477 for (int rank : decision.ranks){
478 if (rank != rankToMergeInto){
479 freedRanks.add(rank);
480 } else{
481 weightsBinsCopyRank[rank][0] = decision.getTotalWeight();
482 }
483 }
484 }
485
486 StringBuilder splitString = new StringBuilder();
487 double[] global = new double[worldSize];
488 for (int i = 0; i < worldSize; i++){
489 global[i] = globalValues[i][0];
490 }
491 while(!splits.isEmpty() && !freedRanks.isEmpty()){
492 if (freedRanks.size() < m-1){
493 m = freedRanks.size() + 1;
494 }
495 Decision decision = splits.poll();
496 int parent = decision.ranks.getFirst();
497 double weightToEach = decision.getTotalWeight() / m;
498 ArrayList<Integer> ranksUsed = new ArrayList<>();
499 ranksUsed.add(parent);
500 weightsBinsCopyRank[parent][0] = weightToEach;
501 for (int i = 0; i < m-1; i++){
502 int childRank = freedRanks.removeFirst();
503 ranksUsed.add(childRank);
504 weightsBinsCopyRank[childRank][0] = weightToEach;
505 weightsBinsCopyRank[childRank][1] = weightsBinsCopyRank[parent][1];
506 weightsBinsCopyRank[childRank][2] = parent;
507 global[childRank] = globalValues[parent][0];
508 }
509 splitString.append("\t Rank ").append(parent).append(" --> ").append(ranksUsed).append("\n");
510 }
511
512
513 weightsBins[rank][0] = weightsBinsCopyRank[rank][0];
514 weightsBins[rank][1] = weightsBinsCopyRank[rank][1];
515 globalValues[rank][0] = global[rank];
516
517
518 logger.info("\n ----------------------------- Resampling Decisions ----------------------------- ");
519 double[] weights = new double[worldSize];
520 int[] bins = new int[worldSize];
521 for (int i = 0; i < worldSize; i++) {
522 weights[i] = weightsBinsCopyRank[i][0];
523 bins[i] = (int) Math.round(weightsBinsCopyRank[i][1]);
524 }
525 logger.info("\n Rank bin numbers: " + Arrays.toString(bins));
526 logger.info(" Bin bounds: " + Arrays.toString(binBounds));
527 logger.info(" Rank weights: " + Arrays.toString(weights));
528 logger.info(" Weight sum: " + Arrays.stream(weights).sum());
529 logger.info(" Merges: \n" + mergeString + "\n");
530 logger.info(" Splits: \n" + splitString + "\n");
531 if ((Arrays.stream(weights).sum()-1) > 1e-6){
532 logger.severe(" Weights do not sum to 1.0.");
533 }
534 }
535
536 private record Decision(ArrayList<Integer> ranks, ArrayList<Double> weights)
537 implements Comparable<Decision> {
538
539 public double getTotalWeight(){
540 return weights.stream().mapToDouble(Double::doubleValue).sum();
541 }
542
543 @Override
544 public int compareTo(Decision decision) {
545 return Double.compare(this.getTotalWeight(), decision.getTotalWeight());
546 }
547 }
548
549 private double[] getOneDimBinBounds(double[] globalValues){
550 if (staticBins){
551 return binBounds;
552 }
553
554 logger.severe(" Automatic binning not implemented yet.");
555 return new double[0];
556 }
557
558 private void comms(){
559 comms(false);
560 }
561
562
563
564
565 private void comms(boolean log){
566 if (log) {
567 double[] weights = new double[worldSize];
568 double[] global = new double[worldSize];
569 for (int i = 0; i < worldSize; i++) {
570 weights[i] = weightsBinsCopyRank[i][0];
571 global[i] = globalValues[i][0];
572 }
573 logger.info(" Rank bin numbers Pre: " + Arrays.toString(weights));
574 logger.info(" Rank global values Pre: " + Arrays.toString(global));
575 }
576 try{
577 world.allGather(myWeightBinBuf, weightsBinsBuf);
578 world.allGather(myGlobalValueBuf, globalValuesBuf);
579 } catch (IOException e) {
580 String message = " WeightedEnsemble allGather for weightsbins failed.";
581 logger.severe(message);
582 }
583 if(log) {
584 double[] weights = new double[worldSize];
585 double[] global = new double[worldSize];
586 for (int i = 0; i < worldSize; i++) {
587 weights[i] = weightsBinsCopyRank[i][0];
588 global[i] = globalValues[i][0];
589 }
590 logger.info(" Rank bin numbers Post: " + Arrays.toString(weights));
591 logger.info(" Rank global values Post: " + Arrays.toString(global));
592 }
593 }
594
595 private void sanityCheckAndFileMigration() {
596
597 for (int i = 0; i < worldSize; i++) {
598 if (weightsBins[i][0] != weightsBinsCopyRank[i][0] ||
599 weightsBins[i][1] != weightsBinsCopyRank[i][1]) {
600 String message = " Rank " + i + " has mismatched weightsBins and weightsBinsCopyRank.";
601 logger.info(" WeightsBins: " + Arrays.deepToString(weightsBins));
602 logger.info(" WeightsBinsCopyRank: " + Arrays.deepToString(weightsBinsCopyRank));
603 logger.severe(message);
604 System.exit(1);
605 }
606 }
607
608
609 if (weightsBinsCopyRank[rank][2] != -1 && weightsBinsCopyRank[rank][2] != rank) {
610 copyOver((int) weightsBinsCopyRank[rank][2]);
611 }
612 comms();
613 if (weightsBinsCopyRank[rank][2] != -1 && weightsBinsCopyRank[rank][2] != rank) {
614 moveOnto();
615 }
616 }
617
618 private void copyOver(int rank){
619
620
621 File dyn = new File(currentDir.getParent() + File.separator + rank +
622 File.separator + dynFile.getName());
623 dynTemp = new File(currentDir.getParent() + File.separator + rank +
624 File.separator + dynFile.getName() + ".temp");
625 File traj = new File(currentDir.getParent() + File.separator + rank +
626 File.separator + trajFile.getName());
627 trajTemp = new File(currentDir.getParent() + File.separator + rank +
628 File.separator + trajFile.getName() + ".temp");
629
630 try {
631 Files.copy(dyn.toPath(), dynTemp.toPath());
632 } catch (IOException e) {
633 String message = " Failed to copy dyn file from rank " + rank + " to rank " + this.rank;
634 logger.log(Level.SEVERE, message, e);
635 }
636 try {
637 Files.copy(traj.toPath(), trajTemp.toPath());
638 } catch (IOException e) {
639 String message = " Failed to copy traj file from rank " + rank + " to rank " + this.rank;
640 logger.log(Level.SEVERE, message, e);
641 }
642 }
643
644 private void moveOnto(){
645 try {
646 Files.move(dynTemp.toPath(), dynFile.toPath(), java.nio.file.StandardCopyOption.REPLACE_EXISTING);
647 } catch (IOException e) {
648 String message = " Failed to move dyn file from rank " + rank + " to rank " + this.rank;
649 logger.log(Level.SEVERE, message, e);
650 }
651 try {
652 Files.move(trajTemp.toPath(), trajFile.toPath(), java.nio.file.StandardCopyOption.REPLACE_EXISTING);
653 } catch (IOException e) {
654 String message = " Failed to move traj file from rank " + rank + " to rank " + this.rank;
655 logger.log(Level.SEVERE, message, e);
656 }
657 }
658 }