1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  package ffx.algorithms.thermodynamics;
39  
40  import edu.rit.mp.LongBuf;
41  import edu.rit.pj.Comm;
42  import ffx.algorithms.cli.DynamicsOptions;
43  import ffx.algorithms.cli.OSTOptions;
44  import ffx.algorithms.dynamics.MDWriteAction;
45  import ffx.algorithms.dynamics.MolecularDynamics;
46  import ffx.algorithms.mc.BoltzmannMC;
47  import ffx.potential.MolecularAssembly;
48  import ffx.potential.cli.WriteoutOptions;
49  import ffx.utilities.Constants;
50  import org.apache.commons.configuration2.CompositeConfiguration;
51  import org.apache.commons.io.FilenameUtils;
52  
53  import java.io.File;
54  import java.io.IOException;
55  import java.util.Arrays;
56  import java.util.EnumSet;
57  import java.util.Optional;
58  import java.util.Random;
59  import java.util.concurrent.ThreadLocalRandom;
60  import java.util.function.LongConsumer;
61  import java.util.logging.Level;
62  import java.util.logging.Logger;
63  import java.util.stream.IntStream;
64  
65  import static java.lang.String.format;
66  import static java.util.Arrays.fill;
67  
68  
69  
70  
71  
72  
73  
74  
75  public class RepExOST {
76  
77    private static final Logger logger = Logger.getLogger(RepExOST.class.getName());
78    private static final int mainLoopTag = 2020;
79    private final OrthogonalSpaceTempering orthogonalSpaceTempering;
80    private final OrthogonalSpaceTempering.Histogram[] allHistograms;
81    private final SendSynchronous[] sendSynchronous;
82    private final LongConsumer algoRun;
83    private final MolecularDynamics molecularDynamics;
84    private final DynamicsOptions dynamicsOptions;
85    private final String fileType;
86    private final MonteCarloOST monteCarloOST;
87    private final long stepsBetweenExchanges;
88    private final Comm world;
89    private final int rank;
90    private final int numPairs;
91    private final int[] rankToHisto;
92    private final int[] histoToRank;
93    private final boolean isMC;
94    private final Random random;
95    private final double invKT;
96    private final String basePath;
97    private final String[] allFilenames;
98    private final File dynFile;
99    private final String extension;
100   private final long[] totalSwaps;
101   private final long[] acceptedSwaps;
102   private boolean reinitVelocities = true;
103   private int currentHistoIndex;
104   private double currentLambda;
105 
106   
107 
108 
109 
110 
111 
112 
113 
114 
115 
116 
117 
118   private RepExOST(OrthogonalSpaceTempering orthogonalSpaceTempering, MonteCarloOST monteCarloOST,
119       MolecularDynamics molecularDynamics, OstType ostType, DynamicsOptions dynamicsOptions,
120       OSTOptions ostOptions, CompositeConfiguration properties, String fileType,
121       double repexInterval) throws IOException {
122     this.orthogonalSpaceTempering = orthogonalSpaceTempering;
123     switch (ostType) {
124       case MD -> {
125         algoRun = this::runMD;
126         isMC = false;
127       }
128       case MC_ONESTEP -> {
129         algoRun = this::runMCOneStep;
130         isMC = true;
131       }
132       case MC_TWOSTEP -> {
133         algoRun = this::runMCTwoStep;
134         isMC = true;
135       }
136       default -> throw new IllegalArgumentException(
137           " Could not recognize whether this is supposed to be MD, MC 1-step, or MC 2-step!");
138     }
139     this.molecularDynamics = molecularDynamics;
140     this.molecularDynamics.setAutomaticWriteouts(false);
141     this.dynamicsOptions = dynamicsOptions;
142     this.fileType = fileType;
143     this.monteCarloOST = monteCarloOST;
144     if (monteCarloOST != null) {
145       monteCarloOST.setAutomaticWriteouts(false);
146     }
147     this.extension = WriteoutOptions.toArchiveExtension(fileType);
148 
149     this.world = Comm.world();
150     this.rank = world.rank();
151     int size = world.size();
152 
153     MolecularAssembly[] allAssemblies = this.molecularDynamics.getAssemblyArray();
154     allFilenames = Arrays.stream(allAssemblies).map(MolecularAssembly::getFile).map(File::getName)
155         .map(FilenameUtils::getBaseName).toArray(String[]::new);
156 
157     File firstFile = allAssemblies[0].getFile();
158     basePath = FilenameUtils.getFullPath(firstFile.getAbsolutePath()) + File.separator;
159     String baseFileName = FilenameUtils.getBaseName(firstFile.getAbsolutePath());
160     dynFile = new File(format("%s%d%s%s.dyn", basePath, rank, File.separator, baseFileName));
161     this.molecularDynamics.setFallbackDynFile(dynFile);
162 
163     currentHistoIndex = orthogonalSpaceTempering.getHistogram().ld.histogramIndex;
164 
165     allHistograms = orthogonalSpaceTempering.getAllHistograms();
166     this.numPairs = size - 1;
167     this.invKT = -1.0 / (Constants.R * dynamicsOptions.getTemperature());
168 
169     long seed;
170     
171     
172     LongBuf seedBuf = LongBuf.buffer(0L);
173     if (rank == 0) {
174       seed = properties.getLong("randomseed", ThreadLocalRandom.current().nextLong());
175       seedBuf.put(0, seed);
176       world.broadcast(0, seedBuf);
177     } else {
178       world.broadcast(0, seedBuf);
179       seed = seedBuf.get(0);
180     }
181     this.random = new Random(seed);
182 
183     double timestep = dynamicsOptions.getDt() * Constants.FSEC_TO_PSEC;
184     stepsBetweenExchanges = Math.max(1, (int) (repexInterval / timestep));
185 
186     sendSynchronous = Arrays.stream(allHistograms)
187         .map(OrthogonalSpaceTempering.Histogram::getSynchronousSend).map(Optional::get)
188         .toArray(SendSynchronous[]::new);
189     if (sendSynchronous.length < 1) {
190       throw new IllegalArgumentException(" No SynchronousSend objects were found!");
191     }
192 
193     rankToHisto = IntStream.range(0, size).toArray();
194     
195     histoToRank = Arrays.copyOf(rankToHisto, size);
196 
197     Arrays.stream(sendSynchronous)
198         .forEach((SendSynchronous ss) -> ss.setHistograms(allHistograms, rankToHisto));
199 
200     totalSwaps = new long[numPairs];
201     acceptedSwaps = new long[numPairs];
202     fill(totalSwaps, 0);
203     fill(acceptedSwaps, 0);
204 
205     setFiles();
206     setHistogram(rank);
207   }
208 
209   
210 
211 
212 
213 
214 
215 
216 
217 
218 
219 
220 
221 
222 
223   public static RepExOST repexMC(OrthogonalSpaceTempering orthogonalSpaceTempering,
224       MonteCarloOST monteCarloOST, DynamicsOptions dynamicsOptions, OSTOptions ostOptions,
225       CompositeConfiguration compositeConfiguration, String fileType, boolean twoStep,
226       double repexInterval) throws IOException {
227     MolecularDynamics md = monteCarloOST.getMD();
228     OstType type = twoStep ? OstType.MC_TWOSTEP : OstType.MC_ONESTEP;
229     return new RepExOST(orthogonalSpaceTempering, monteCarloOST, md, type, dynamicsOptions,
230         ostOptions, compositeConfiguration, fileType, repexInterval);
231   }
232 
233   
234 
235 
236 
237 
238 
239 
240 
241 
242 
243 
244 
245 
246   public static RepExOST repexMD(OrthogonalSpaceTempering orthogonalSpaceTempering,
247       MolecularDynamics molecularDynamics, DynamicsOptions dynamicsOptions, OSTOptions ostOptions,
248       CompositeConfiguration compositeConfiguration, String fileType, double repexInterval)
249       throws IOException {
250     return new RepExOST(orthogonalSpaceTempering, null, molecularDynamics, OstType.MD,
251         dynamicsOptions, ostOptions, compositeConfiguration, fileType, repexInterval);
252   }
253 
254   public OrthogonalSpaceTempering getOST() {
255     return orthogonalSpaceTempering;
256   }
257 
258   
259 
260 
261 
262 
263 
264 
265   public void mainLoop(long numTimesteps, boolean equilibrate) throws IOException {
266     if (isMC) {
267       monteCarloOST.setEquilibration(equilibrate);
268     }
269     currentLambda = orthogonalSpaceTempering.getLambda();
270 
271     fill(totalSwaps, 0);
272     fill(acceptedSwaps, 0);
273 
274     if (equilibrate) {
275       logger.info(
276           format(" Equilibrating RepEx OST without exchanges on histogram %d.", currentHistoIndex));
277       algoRun.accept(numTimesteps);
278       reinitVelocities = false;
279     } else {
280       long numExchanges = numTimesteps / stepsBetweenExchanges;
281       for (int i = 0; i < numExchanges; i++) {
282         logger.info(format(" Beginning of RepEx loop %d of %d, operating on histogram %d", (i + 1),
283             numExchanges, currentHistoIndex));
284         world.barrier(mainLoopTag);
285         algoRun.accept(stepsBetweenExchanges);
286         orthogonalSpaceTempering.logOutputFiles(currentHistoIndex);
287         world.barrier(mainLoopTag);
288         proposeSwaps((i % 2), 2);
289         setFiles();
290 
291         long mdMoveNum = i * stepsBetweenExchanges;
292         currentLambda = orthogonalSpaceTempering.getLambda();
293         boolean trySnapshot = currentLambda >= orthogonalSpaceTempering.getLambdaWriteOut();
294         
295         EnumSet<MDWriteAction> written = molecularDynamics.writeFilesForStep(mdMoveNum, trySnapshot,
296             true);
297         if (written.contains(MDWriteAction.RESTART)) {
298           orthogonalSpaceTempering.writeAdditionalRestartInfo(false);
299         }
300 
301         reinitVelocities = false;
302       }
303     }
304 
305     logger.info(" Final rank-to-histogram mapping: " + Arrays.toString(rankToHisto));
306   }
307 
308   
309 
310   private void setHistogram(int index) {
311     currentHistoIndex = index;
312     orthogonalSpaceTempering.switchHistogram(index);
313   }
314 
315   
316 
317 
318 
319 
320   private void logIfMaster(String message) {
321     logIfMaster(Level.INFO, message);
322   }
323 
324   
325 
326 
327 
328 
329 
330   private void logIfMaster(Level level, String message) {
331     if (rank == 0) {
332       logger.log(level, message);
333     }
334   }
335 
336   
337 
338 
339 
340 
341 
342   private void logIfSwapping(String message) {
343     logIfMaster(message);
344   }
345 
346   private void setFiles() {
347     File[] trajFiles = Arrays.stream(allFilenames).map(
348         (String fn) -> format("%s%d%s%s.%s", basePath, currentHistoIndex, File.separator, fn,
349             extension)).map(File::new).toArray(File[]::new);
350     molecularDynamics.setArchiveFiles(trajFiles);
351   }
352 
353   
354 
355 
356 
357 
358 
359 
360 
361   private void proposeSwaps(final int offset, final int stride) {
362     for (int i = offset; i < numPairs; i += stride) {
363       int rankLow = histoToRank[i];
364       int rankHigh = histoToRank[i + 1];
365       OrthogonalSpaceTempering.Histogram histoLow = allHistograms[i];
366       OrthogonalSpaceTempering.Histogram histoHigh = allHistograms[i + 1];
367 
368       double lamLow = histoLow.getLastReceivedLambda();
369       double dUdLLow = histoLow.getLastReceivedDUDL();
370       double lamHigh = histoHigh.getLastReceivedLambda();
371       double dUdLHigh = histoHigh.getLastReceivedDUDL();
372 
373       double eii = histoLow.computeBiasEnergy(lamLow, dUdLLow);
374       double eij = histoLow.computeBiasEnergy(lamHigh, dUdLHigh);
375       double eji = histoHigh.computeBiasEnergy(lamLow, dUdLLow);
376       double ejj = histoHigh.computeBiasEnergy(lamHigh, dUdLHigh);
377 
378       logIfSwapping(format(
379           "\n Proposing exchange between histograms %d (rank %d) and %d (rank %d).\n"
380               + " Li: %.6f dU/dLi: %.6f Lj: %.6f dU/dLj: %.6f", i, rankLow, i + 1, rankHigh, lamLow,
381           dUdLLow, lamHigh, dUdLHigh));
382 
383       double e1 = eii + ejj;
384       double e2 = eji + eij;
385       boolean accept = BoltzmannMC.evaluateMove(random, invKT, e1, e2);
386       double acceptChance = BoltzmannMC.acceptChance(invKT, e1, e2);
387 
388       String desc = accept ? "Accepted" : "Rejected";
389       logIfSwapping(format(
390           " %s exchange with probability %.5f based on Eii %.6f, Ejj %.6f, Eij %.6f, Eji %.6f kcal/mol",
391           desc, acceptChance, eii, ejj, eij, eji));
392 
393       ++totalSwaps[i];
394       if (accept) {
395         ++acceptedSwaps[i];
396         switchHistos(rankLow, rankHigh, i);
397       }
398 
399       double acceptRate = ((double) acceptedSwaps[i]) / ((double) totalSwaps[i]);
400       logIfSwapping(format(" Replica exchange acceptance rate for pair %d-%d is %.3f%%", i, (i + 1),
401           acceptRate * 100));
402     }
403   }
404 
405   private void switchHistos(int rankLow, int rankHigh, int histoLow) {
406     int histoHigh = histoLow + 1;
407     rankToHisto[rankLow] = histoHigh;
408     rankToHisto[rankHigh] = histoLow;
409     histoToRank[histoLow] = rankHigh;
410     histoToRank[histoHigh] = rankLow;
411     setHistogram(rankToHisto[rank]);
412 
413     orthogonalSpaceTempering.setLambda(currentLambda);
414     
415 
416 
417     for (SendSynchronous send : sendSynchronous) {
418       send.updateRanks(rankToHisto);
419     }
420   }
421 
422   
423 
424 
425 
426 
427   private void runMCOneStep(long numSteps) {
428     monteCarloOST.setTotalSteps(numSteps);
429     monteCarloOST.sampleOneStep();
430   }
431 
432   
433 
434 
435 
436 
437   private void runMCTwoStep(long numSteps) {
438     monteCarloOST.setTotalSteps(numSteps);
439     monteCarloOST.sampleTwoStep();
440   }
441 
442   
443 
444 
445 
446 
447   private void runMD(long numSteps) {
448     molecularDynamics.dynamic(numSteps, dynamicsOptions.getDt(), dynamicsOptions.getReport(),
449         dynamicsOptions.getSnapshotInterval(), dynamicsOptions.getTemperature(), reinitVelocities,
450         fileType, dynamicsOptions.getCheckpoint(), dynFile);
451   }
452 
453   private enum OstType {
454     MD, MC_ONESTEP, MC_TWOSTEP
455   }
456 }