1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.algorithms.commands.test;
39
40 import ffx.algorithms.cli.AlgorithmsCommand;
41 import ffx.algorithms.cli.BarostatOptions;
42 import ffx.algorithms.thermodynamics.HistogramData;
43 import ffx.crystal.CrystalPotential;
44 import ffx.numerics.estimator.BennettAcceptanceRatio;
45 import ffx.numerics.estimator.EstimateBootstrapper;
46 import ffx.numerics.estimator.Zwanzig;
47 import ffx.potential.MolecularAssembly;
48 import ffx.potential.bonded.LambdaInterface;
49 import ffx.potential.cli.AlchemicalOptions;
50 import ffx.potential.cli.TopologyOptions;
51 import ffx.potential.parsers.SystemFilter;
52 import ffx.utilities.Constants;
53 import ffx.utilities.FFXBinding;
54 import org.apache.commons.configuration2.CompositeConfiguration;
55 import org.apache.commons.configuration2.Configuration;
56 import org.apache.commons.io.FilenameUtils;
57 import picocli.CommandLine.Command;
58 import picocli.CommandLine.Mixin;
59 import picocli.CommandLine.Option;
60 import picocli.CommandLine.Parameters;
61
62 import java.io.File;
63 import java.util.ArrayList;
64 import java.util.Arrays;
65 import java.util.List;
66 import java.util.OptionalDouble;
67 import java.util.logging.Level;
68
69 import static java.lang.String.format;
70 import static java.util.Arrays.fill;
71 import static org.apache.commons.math3.util.FastMath.max;
72 import static org.apache.commons.math3.util.FastMath.min;
73 import static org.apache.commons.math3.util.FastMath.round;
74
75
76
77
78
79
80
81
82
83 @Command(description = " Evaluates free energy of an M-OST run using the BAR estimator.", name = "test.MostBar")
84 public class MostBar extends AlgorithmsCommand {
85
86 @Mixin
87 private AlchemicalOptions alchemicalOptions;
88
89 @Mixin
90 private TopologyOptions topologyOptions;
91
92 @Mixin
93 private BarostatOptions barostat;
94
95
96
97
98 @Option(names = {"-t", "--temperature"}, paramLabel = "298.15", defaultValue = "298.15",
99 description = "Temperature in Kelvins")
100 private double temp;
101
102
103
104
105 @Option(names = {"--his", "--histogram"}, paramLabel = "file.his", defaultValue = "",
106 description = "Manually provided path to a histogram file (otherwise, attempts to autodetect from same directory as input files).")
107 private String histogramName;
108
109
110
111
112 @Option(names = {"--lb", "--lambdaBins"}, paramLabel = "autodetected", defaultValue = "-1",
113 description = "Manually specified number of lambda bins (else auto-detected from histogram")
114 private int lamBins;
115
116
117
118
119 @Option(names = {"-s", "--start"}, paramLabel = "1", defaultValue = "1",
120 description = "First snapshot to evaluate (1-indexed, inclusive).")
121 private int startFrame;
122
123
124
125
126 @Option(names = {"--fi", "--final"}, paramLabel = "-1", defaultValue = "-1",
127 description = "Last snapshot to evaluate (1-indexed, inclusive); leave negative to analyze to end of trajectory.")
128 private int finalFrame;
129
130
131
132
133 @Option(names = {"--st", "--stride"}, paramLabel = "1", defaultValue = "1",
134 description = "First snapshot to evaluate (1-indexed).")
135 private int stride;
136
137
138
139
140 @Option(names = {"--bo", "--bootstrap"}, paramLabel = "AUTO", defaultValue = "-1",
141 description = "Use this many bootstrap trials to estimate dG and uncertainty; default is 200-100000 (depending on number of frames).")
142 private long bootstrap;
143
144
145
146
147 @Option(names = {"--lambdaSorted"}, paramLabel = "false", defaultValue = "false",
148 description = "Input is sorted by lambda rather than simulation progress (sets -s to skip N-1 frames at each lambda value rather than N-1 of all frames).")
149 private boolean lambdaSorted;
150
151
152
153
154 @Option(names = {"-v", "--verbose"}, paramLabel = "false", defaultValue = "false",
155 description = "Print out extra information (e.g. collection of potential energies).")
156 private boolean verbose;
157
158
159
160
161 @Parameters(arity = "1..*", paramLabel = "files",
162 description = "Trajectory files for the first end of the window, followed by trajectories for the other end")
163 private List<String> filenames;
164
165 private MolecularAssembly[] topologies;
166 private SystemFilter[] openers;
167 private CrystalPotential potential;
168 private LambdaInterface linter;
169 private Configuration additionalProperties;
170
171 private List<List<Double>> energiesL;
172 private List<List<Double>> energiesUp;
173 private List<List<Double>> energiesDown;
174
175 private double[] lamPoints;
176 private int[] observations;
177 private double lamSep;
178 private double halfLamSep;
179 private double[] x;
180 private final double[] lastEntries = new double[3];
181 private static final String energyFormat = "%11.4f kcal/mol";
182 private static final String nanFormat = format("%20s", "N/A");
183
184 private int start;
185
186 private int end;
187 private Level standardLogging = Level.FINE;
188
189
190 private static final long BOOTSTRAP_PRINT = 50L;
191
192 private static final long MIN_BOOTSTRAP_TRIALS = 200L;
193 private static final long MAX_BOOTSTRAP_TRIALS = 50000L;
194 private static final long AUTO_BOOTSTRAP_NUMERATOR = 10000000L;
195
196 public void setProperties(CompositeConfiguration addedProperties) {
197 additionalProperties = addedProperties;
198 }
199
200
201
202
203 public MostBar() {
204 super();
205 }
206
207
208
209
210
211
212 public MostBar(FFXBinding binding) {
213 super(binding);
214 }
215
216
217
218
219
220
221 public MostBar(String[] args) {
222 super(args);
223 }
224
225
226
227
228 @Override
229 public MostBar run() {
230
231 if (!init()) {
232 return this;
233 }
234
235
236 int numTopologies = topologyOptions.getNumberOfTopologies(filenames);
237 int threadsPerTopology = topologyOptions.getThreadsPerTopology(numTopologies);
238 topologies = new MolecularAssembly[numTopologies];
239 openers = new SystemFilter[numTopologies];
240
241
242 alchemicalOptions.setAlchemicalProperties();
243 topologyOptions.setAlchemicalProperties(numTopologies);
244
245 standardLogging = verbose ? Level.INFO : Level.FINE;
246
247 logger.info(format(" Initializing %d topologies", numTopologies));
248
249
250 if (filenames == null || filenames.isEmpty()) {
251 activeAssembly = getActiveAssembly(null);
252 if (activeAssembly == null) {
253 logger.info(helpString());
254 return this;
255 }
256 filenames = new ArrayList<>();
257 filenames.add(activeAssembly.getFile().getName());
258 topologies[0] = alchemicalOptions.processFile(topologyOptions, activeAssembly, 0);
259 } else {
260 logger.info(format(" Initializing %d topologies...", numTopologies));
261 for (int i = 0; i < numTopologies; i++) {
262 topologies[i] = alchemicalOptions.openFile(algorithmFunctions,
263 topologyOptions, threadsPerTopology, filenames.get(i), i);
264 openers[i] = algorithmFunctions.getFilter();
265 }
266 }
267
268 StringBuilder sb = new StringBuilder("\n Using BAR to analyze an M-OST free energy change for systems ");
269 potential = (CrystalPotential) topologyOptions.assemblePotential(topologies, sb);
270 potential = barostat.checkNPT(topologies[0], potential);
271 linter = (LambdaInterface) potential;
272 logger.info(sb.toString());
273
274 int nSnapshots = openers[0].countNumModels();
275
276 if (histogramName.isEmpty()) {
277 histogramName = FilenameUtils.removeExtension(filenames.get(0)) + ".his";
278 }
279
280 if (lamBins < 1) {
281 File histogramFile = new File(histogramName);
282 HistogramData histogramData = HistogramData.readHistogram(histogramFile);
283 lamBins = histogramData.getLambdaBins();
284 if (histogramData.wasHistogramRead()) {
285 logger.info(format(" Autodetected %d from histogram file.", lamBins));
286 }
287 }
288
289 energiesL = new ArrayList<>(lamBins);
290 energiesUp = new ArrayList<>(lamBins);
291 energiesDown = new ArrayList<>(lamBins);
292 for (int i = 0; i < lamBins; i++) {
293 energiesL.add(new ArrayList<Double>());
294 energiesUp.add(new ArrayList<Double>());
295 energiesDown.add(new ArrayList<Double>());
296 }
297
298 lamSep = 1.0 / (lamBins - 1);
299 halfLamSep = 0.5 * lamSep;
300 lamPoints = new double[lamBins];
301
302 for (int i = 0; i < (lamBins - 1); i++) {
303 lamPoints[i] = i * lamSep;
304 }
305 lamPoints[lamBins - 1] = 1.0;
306
307 OptionalDouble optLam = openers[0].getLastReadLambda();
308
309 if (!optLam.isPresent()) {
310 throw new IllegalArgumentException(
311 format(" No lambda records found in the first header of archive file %s", filenames.get(0)));
312 }
313
314 start = startFrame - 1;
315 if (finalFrame < 1) {
316 end = nSnapshots;
317 } else {
318 end = min(nSnapshots, finalFrame);
319 }
320 end -= startFrame;
321
322 double lambda = optLam.getAsDouble();
323 int nVar = potential.getNumberOfVariables();
324 x = new double[nVar];
325
326
327 observations = new int[lamBins];
328 if (lambdaSorted) {
329 fill(observations, -startFrame);
330 } else {
331 fill(observations, 0);
332 }
333
334 logger.info(" Reading snapshots.");
335
336 addEntries(lambda, 0);
337
338 for (int i = 1; i < end; i++) {
339 for (int j = 0; j < numTopologies; j++) {
340 openers[j].readNext(false, false);
341 }
342 lambda = openers[0].getLastReadLambda().getAsDouble();
343 addEntries(lambda, i);
344 }
345
346 for (SystemFilter opener : openers) {
347 opener.closeReader();
348 }
349
350 double[][] eLow = new double[lamBins][];
351 double[][] eAt = new double[lamBins][];
352 double[][] eHigh = new double[lamBins][];
353 for (int i = 0; i < lamBins; i++) {
354 eLow[i] = energiesDown.get(i).stream().mapToDouble(Double::doubleValue).toArray();
355 eAt[i] = energiesL.get(i).stream().mapToDouble(Double::doubleValue).toArray();
356 eHigh[i] = energiesUp.get(i).stream().mapToDouble(Double::doubleValue).toArray();
357 }
358
359 logger.info("\n Initial estimate via the iteration method.");
360 BennettAcceptanceRatio bar = new BennettAcceptanceRatio(lamPoints, eLow, eAt, eHigh,
361 new double[]{temp});
362 Zwanzig forwards = bar.getInitialForwardsGuess();
363 Zwanzig backwards = bar.getInitialBackwardsGuess();
364
365 logger.info(format(" Free energy via BAR: %15.9f +/- %.9f kcal/mol.",
366 bar.getTotalFreeEnergyDifference(), bar.getTotalFEDifferenceUncertainty()));
367 logger.info(format(" Free energy via forwards FEP: %15.9f +/- %.9f kcal/mol.",
368 forwards.getTotalFreeEnergyDifference(), forwards.getTotalFEDifferenceUncertainty()));
369 logger.info(format(" Free energy via backwards FEP: %15.9f +/- %.9f kcal/mol.",
370 backwards.getTotalFreeEnergyDifference(), backwards.getTotalFEDifferenceUncertainty()));
371 logger.info(" Note - non-bootstrap FEP uncertainties are currently unreliable.");
372
373 double[] barFE = bar.getFreeEnergyDifferences();
374 double[] barVar = bar.getFEDifferenceUncertainties();
375 double[] forwardsFE = forwards.getFreeEnergyDifferences();
376 double[] forwardsVar = forwards.getFEDifferenceUncertainties();
377 double[] backwardsFE = backwards.getFreeEnergyDifferences();
378 double[] backwardsVar = backwards.getFEDifferenceUncertainties();
379
380 sb = new StringBuilder(
381 "\n Free Energy Profile Per Window\n Min_Lambda Counts Max_Lambda Counts BAR_dG BAR_Var FEP_dG FEP_Var FEP_Back_dG FEP_Back_Var\n");
382 for (int i = 0; i < (lamBins - 1); i++) {
383 sb.append(format(" %-10.8f %6d %-10.8f %6d %15.9f %12.9f %15.9f %12.9f %15.9f %12.9f\n",
384 lamPoints[i], eAt[i].length, lamPoints[i + 1], eAt[i + 1].length, barFE[i],
385 barVar[i], forwardsFE[i], forwardsVar[i], backwardsFE[i], backwardsVar[i]));
386 }
387 logger.info(sb.toString());
388
389 if (bootstrap == -1) {
390 int totalRead = Arrays.stream(observations).min().getAsInt();
391 if (totalRead >= MIN_BOOTSTRAP_TRIALS) {
392 bootstrap = AUTO_BOOTSTRAP_NUMERATOR / totalRead;
393
394 bootstrap = max(MIN_BOOTSTRAP_TRIALS, min(MAX_BOOTSTRAP_TRIALS, bootstrap));
395 } else {
396 logger.info(format(
397 " At least one lambda window had only %d snapshots read; defaulting to %d bootstrap cycles!",
398 totalRead, MIN_BOOTSTRAP_TRIALS));
399 bootstrap = MIN_BOOTSTRAP_TRIALS;
400 }
401 }
402
403 long bootPrint = BOOTSTRAP_PRINT;
404 if (!verbose) {
405 bootPrint *= 10L;
406 }
407
408
409 if (bootstrap > 0) {
410 logger.info(format(" Re-estimate free energy and uncertainty from %d bootstrap trials.", bootstrap));
411
412 EstimateBootstrapper barBS = new EstimateBootstrapper(bar);
413 EstimateBootstrapper forBS = new EstimateBootstrapper(forwards);
414 EstimateBootstrapper backBS = new EstimateBootstrapper(backwards);
415
416 long time = -System.nanoTime();
417 barBS.bootstrap(bootstrap, bootPrint);
418 time += System.nanoTime();
419 logger.info(format(" BAR bootstrapping complete in %.4f sec", time * Constants.NS2SEC));
420
421 time = -System.nanoTime();
422 forBS.bootstrap(bootstrap, bootPrint);
423 time += System.nanoTime();
424 logger.info(format(" Forwards FEP bootstrapping complete in %.4f sec", time * Constants.NS2SEC));
425
426 time = -System.nanoTime();
427 backBS.bootstrap(bootstrap, bootPrint);
428 time += System.nanoTime();
429 logger.info(format(" Reverse FEP bootstrapping complete in %.4f sec", time * Constants.NS2SEC));
430
431 barFE = barBS.getFreeEnergyDifferences();
432 barVar = barBS.getFEDifferenceStdDevs();
433 forwardsFE = forBS.getFreeEnergyDifferences();
434 forwardsVar = forBS.getFEDifferenceStdDevs();
435 backwardsFE = backBS.getFreeEnergyDifferences();
436 backwardsVar = backBS.getFEDifferenceStdDevs();
437
438 double sumFE = barBS.getTotalFreeEnergyDifference(barFE);
439 double varFE = barBS.getTotalFEDifferenceUncertainty(barVar);
440 logger.info(format(" Free energy via BAR: %15.9f +/- %.9f kcal/mol.", sumFE, varFE));
441
442 sumFE = forBS.getTotalFreeEnergyDifference(forwardsFE);
443 varFE = forBS.getTotalFEDifferenceUncertainty();
444 logger.info(format(" Free energy via forwards FEP: %15.9f +/- %.9f kcal/mol.", sumFE, varFE));
445
446 sumFE = backBS.getTotalFreeEnergyDifference(backwardsFE);
447 varFE = backBS.getTotalFEDifferenceUncertainty();
448 logger.info(format(" Free energy via backwards FEP: %15.9f +/- %.9f kcal/mol.", sumFE, varFE));
449
450 sb = new StringBuilder(
451 " Free Energy Profile\n Min_Lambda Counts Max_Lambda Counts BAR_dG BAR_Var FEP_dG FEP_Var FEP_Back_dG FEP_Back_Var\n");
452 for (int i = 0; i < (lamBins - 1); i++) {
453 sb.append(format(" %-10.8f %6d %-10.8f %6d %15.9f %12.9f %15.9f %12.9f %15.9f %12.9f\n",
454 lamPoints[i], eAt[i].length, lamPoints[i + 1], eAt[i + 1].length, barFE[i], barVar[i],
455 forwardsFE[i], forwardsVar[i], backwardsFE[i], backwardsVar[i]));
456 }
457 logger.info(sb.toString());
458 } else {
459 logger.info(" Bootstrap resampling disabled.");
460 }
461
462 return this;
463 }
464
465
466
467
468
469
470
471
472 private void addEntries(double lambda, int index) {
473 int bin = binForLambda(lambda);
474 ++observations[bin];
475
476 int offsetIndex = lambdaSorted ? observations[bin] : index - start;
477 assert offsetIndex <= end;
478
479 boolean inRange = offsetIndex >= 0 && offsetIndex <= end;
480 boolean onStride = (offsetIndex % stride == 0);
481 if (inRange && onStride) {
482 x = potential.getCoordinates(x);
483 lastEntries[0] = addLambdaDown(lambda, bin);
484 lastEntries[1] = addAtLambda(lambda, bin);
485 lastEntries[2] = addLambdaUp(lambda, bin);
486
487 String low = Double.isNaN(lastEntries[0]) ? nanFormat : format(energyFormat, lastEntries[0]);
488 String high = Double.isNaN(lastEntries[2]) ? nanFormat : format(energyFormat, lastEntries[2]);
489
490 logger.log(standardLogging, format(" Energies for snapshot %5d at lambda %.4f: " +
491 "%s, %s, %s", (index + 1), lambda, low, format(energyFormat, lastEntries[1]), high));
492 } else {
493 logger.log(standardLogging, " Skipping frame " + index);
494 }
495 }
496
497
498
499
500
501
502
503
504 private double addAtLambda(double lambda, int bin) {
505 assert lambda >= 0.0 && lambda <= 1.0;
506 linter.setLambda(lambda);
507 double e = potential.energy(x, false);
508
509
510 energiesL.get(bin).add(e);
511 return e;
512 }
513
514
515
516
517
518
519
520
521 private double addLambdaUp(double lambda, int bin) {
522 double modLambda = lambda + lamSep;
523
524 modLambda = min(1.0d, modLambda);
525 if (bin == (lamBins - 1)) {
526 energiesUp.get(bin).add(Double.NaN);
527 return Double.NaN;
528 } else {
529 linter.setLambda(modLambda);
530 double e = potential.energy(x, false);
531
532
533 energiesUp.get(bin).add(e);
534 linter.setLambda(lambda);
535 return e;
536 }
537 }
538
539
540
541
542
543
544
545
546 private double addLambdaDown(double lambda, int bin) {
547 double modLambda = lambda - lamSep;
548
549 modLambda = max(0.0d, modLambda);
550
551 if (bin == 0) {
552 energiesDown.get(0).add(Double.NaN);
553 return Double.NaN;
554 } else {
555 linter.setLambda(modLambda);
556 double e = potential.energy(x, false);
557
558
559 energiesDown.get(bin).add(e);
560 linter.setLambda(lambda);
561 return e;
562 }
563 }
564
565
566
567
568
569
570
571 private int binForLambda(double lambda) {
572
573 return (int) round(lambda / lamSep);
574 }
575 }