1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.estimator;
39
40 import ffx.numerics.OptimizationInterface;
41 import ffx.numerics.integrate.DataSet;
42 import ffx.numerics.integrate.DoublesDataSet;
43 import ffx.numerics.integrate.Integrate1DNumeric;
44 import ffx.numerics.optimization.LBFGS;
45 import ffx.numerics.optimization.LineSearch;
46 import ffx.numerics.optimization.OptimizationListener;
47 import ffx.utilities.Constants;
48 import org.apache.commons.math3.linear.MatrixUtils;
49 import org.apache.commons.math3.linear.RealMatrix;
50 import org.apache.commons.math3.linear.SingularValueDecomposition;
51
52 import java.io.BufferedWriter;
53 import java.io.File;
54 import java.io.FileWriter;
55 import java.io.IOException;
56 import java.util.ArrayList;
57 import java.util.Arrays;
58 import java.util.Random;
59 import java.util.logging.Logger;
60
61 import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
62 import static ffx.numerics.estimator.Zwanzig.Directionality.BACKWARDS;
63 import static ffx.numerics.estimator.Zwanzig.Directionality.FORWARDS;
64 import static java.lang.System.arraycopy;
65 import static java.util.Arrays.copyOf;
66 import static java.util.Arrays.stream;
67 import static org.apache.commons.lang3.ArrayFill.fill;
68 import static org.apache.commons.math3.util.FastMath.abs;
69 import static org.apache.commons.math3.util.FastMath.exp;
70 import static org.apache.commons.math3.util.FastMath.log;
71 import static org.apache.commons.math3.util.FastMath.sqrt;
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90 public class MultistateBennettAcceptanceRatio extends SequentialEstimator implements BootstrappableEstimator, OptimizationInterface {
91 private static final Logger logger = Logger.getLogger(MultistateBennettAcceptanceRatio.class.getName());
92
93
94
95 private static final double DEFAULT_TOLERANCE = 1.0E-7;
96
97
98
99 private final int nFreeEnergyDiffs;
100
101
102
103 private final double[] mbarFEDifferenceEstimates;
104
105
106
107 private final int nLambdaStates;
108
109
110
111
112
113 private double[] mbarFEEstimates;
114
115
116
117 private double[] mbarObservableEnsembleAverages;
118 private double[] mbarObservableEnsembleAverageUncertainties;
119
120
121
122 private double[] mbarUncertainties;
123
124
125
126 private double[][] uncertaintyMatrix;
127
128
129
130 private final double tolerance;
131
132
133
134 private final Random random;
135
136
137
138 private double totalMBAREstimate;
139
140
141
142 private double totalMBARUncertainty;
143
144
145
146 private double[] mbarEnthalpy;
147
148
149
150
151 private double[] mbarEntropy;
152
153 public double[] rtValues;
154
155
156
157
158
159 private double[][] reducedPotentials;
160
161 private double[][] oAllFlat;
162 private double[][] biasFlat;
163
164
165
166 private SeedType seedType;
167
168
169
170
171
172 public enum SeedType {BAR, ZWANZIG, ZEROS;}
173
174 public static boolean FORCE_ZEROS_SEED = false;
175 public static boolean VERBOSE = false;
176
177
178
179
180
181
182
183
/**
 * Builds the estimator with {@code DEFAULT_TOLERANCE} and Zwanzig seeding by delegating to the
 * five-argument constructor.
 *
 * @param lambdaValues lambda value of each state.
 * @param energiesAll  energies, forwarded unchanged to the delegate constructor.
 * @param temperature  temperature of each state.
 */
public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature) {
  this(lambdaValues, energiesAll, temperature, DEFAULT_TOLERANCE, SeedType.ZWANZIG);
}
187
188
189
190
191
192
193
194
195
196
197 public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature,
198 double tolerance, SeedType seedType) {
199 super(lambdaValues, energiesAll, temperature);
200 this.tolerance = tolerance;
201 this.seedType = seedType;
202
203
204 nLambdaStates = lambdaValues.length;
205 mbarFEEstimates = new double[nLambdaStates];
206
207 nFreeEnergyDiffs = lambdaValues.length - 1;
208 mbarFEDifferenceEstimates = new double[nFreeEnergyDiffs];
209 mbarUncertainties = new double[nFreeEnergyDiffs];
210 mbarEnthalpy = new double[nFreeEnergyDiffs];
211 mbarEntropy = new double[nFreeEnergyDiffs];
212 random = new Random();
213 estimateDG();
214 }
215
216 public MultistateBennettAcceptanceRatio(double[] lambdaValues, int[] snaps, double[][] eAllFlat, double[] temperature,
217 double tolerance, SeedType seedType) {
218 super(lambdaValues, snaps, eAllFlat, temperature);
219 this.tolerance = tolerance;
220 this.seedType = seedType;
221
222
223 nLambdaStates = lambdaValues.length;
224 mbarFEEstimates = new double[nLambdaStates];
225
226 nFreeEnergyDiffs = lambdaValues.length - 1;
227 mbarFEDifferenceEstimates = new double[nFreeEnergyDiffs];
228 mbarUncertainties = new double[nFreeEnergyDiffs];
229 mbarEnthalpy = new double[nFreeEnergyDiffs];
230 mbarEntropy = new double[nFreeEnergyDiffs];
231 random = new Random();
232 estimateDG();
233 }
234
235
236
237
238 private void seedEnergies() {
239 switch (seedType) {
240 case BAR:
241 try {
242 if (eLambdaMinusdL == null || eLambda == null || eLambdaPlusdL == null) {
243 seedType = SeedType.ZEROS;
244 seedEnergies();
245 return;
246 }
247 SequentialEstimator barEstimator = new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures);
248 mbarFEEstimates[0] = 0.0;
249 double[] barEstimates = barEstimator.getFreeEnergyDifferences();
250 for (int i = 0; i < nFreeEnergyDiffs; i++) {
251 mbarFEEstimates[i + 1] = mbarFEEstimates[i] + barEstimates[i];
252 }
253 break;
254 } catch (IllegalArgumentException e) {
255 logger.warning(" BAR failed to converge. Zwanzig will be used for seed energies.");
256 seedType = SeedType.ZWANZIG;
257 seedEnergies();
258 return;
259 }
260 case ZWANZIG:
261 try {
262 if (eLambdaMinusdL == null || eLambda == null || eLambdaPlusdL == null) {
263 seedType = SeedType.ZEROS;
264 seedEnergies();
265 return;
266 }
267 Zwanzig forwardsFEP = new Zwanzig(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, FORWARDS);
268 Zwanzig backwardsFEP = new Zwanzig(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, BACKWARDS);
269 double[] forwardZwanzig = forwardsFEP.getFreeEnergyDifferences();
270 double[] backwardZwanzig = backwardsFEP.getFreeEnergyDifferences();
271 mbarFEEstimates[0] = 0.0;
272 for (int i = 0; i < nFreeEnergyDiffs; i++) {
273 mbarFEEstimates[i + 1] = mbarFEEstimates[i] + .5 * (forwardZwanzig[i] + backwardZwanzig[i]);
274 }
275 if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
276 throw new IllegalArgumentException("MBAR contains NaNs or Infs after seeding.");
277 }
278 break;
279 } catch (IllegalArgumentException e) {
280 logger.warning(" Zwanzig failed to converge. Zeros will be used for seed energies.");
281 seedType = SeedType.ZEROS;
282 seedEnergies();
283 return;
284 }
285 case ZEROS:
286 fill(mbarFEEstimates, 0.0);
287 break;
288 default:
289 throw new IllegalArgumentException("Seed type not supported");
290 }
291 }
292
293
294
295
296 @Override
297 public void estimateDG() {
298 estimateDG(false);
299 }
300
301
302
303
304 @Override
305 public void estimateDG(boolean randomSamples) {
306 if (MultistateBennettAcceptanceRatio.VERBOSE) {
307 logger.setLevel(java.util.logging.Level.FINE);
308 }
309
310
311 fill(mbarFEEstimates, 0.0);
312 if (FORCE_ZEROS_SEED) {
313 seedType = SeedType.ZEROS;
314 }
315 seedEnergies();
316 if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
317 seedType = SeedType.ZEROS;
318 seedEnergies();
319 }
320 if (MultistateBennettAcceptanceRatio.VERBOSE) {
321 logger.info(" Seed Type: " + seedType);
322 logger.info(" MBAR FE Estimates after seeding: " + Arrays.toString(mbarFEEstimates));
323 }
324
325
326 rtValues = new double[nLambdaStates];
327 double[] invRTValues = new double[nLambdaStates];
328 for (int i = 0; i < nLambdaStates; i++) {
329 rtValues[i] = Constants.R * temperatures[i];
330 invRTValues[i] = 1.0 / rtValues[i];
331 }
332
333
334 int numEvaluations = eAllFlat[0].length;
335
336
337 int[][] indices = new int[nLambdaStates][numEvaluations];
338 if (randomSamples) {
339
340 int[] randomIndices = new int[numEvaluations];
341 int sum = 0;
342 for (int snap : nSamples) {
343 System.arraycopy(getBootstrapIndices(snap, random), 0, randomIndices, sum, snap);
344 sum += snap;
345 }
346 for (int i = 0; i < nLambdaStates; i++) {
347
348 indices[i] = randomIndices;
349 }
350 } else {
351 for (int i = 0; i < numEvaluations; i++) {
352 for (int j = 0; j < nLambdaStates; j++) {
353 indices[j][i] = i;
354 }
355 }
356 }
357
358
359 reducedPotentials = new double[nLambdaStates][numEvaluations];
360 double minPotential = Double.POSITIVE_INFINITY;
361 for (int state = 0; state < eAllFlat.length; state++) {
362 for (int n = 0; n < eAllFlat[0].length; n++) {
363 reducedPotentials[state][n] = eAllFlat[state][indices[state][n]] * invRTValues[state];
364 if (reducedPotentials[state][n] < minPotential) {
365 minPotential = reducedPotentials[state][n];
366 }
367 }
368 }
369
370
371 for (int state = 0; state < nLambdaStates; state++) {
372 for (int n = 0; n < numEvaluations; n++) {
373 reducedPotentials[state][n] -= minPotential;
374 }
375 }
376
377
378
379 ArrayList<Integer> zeroSnapLambdas = new ArrayList<>();
380 ArrayList<Integer> sampledLambdas = new ArrayList<>();
381 for (int i = 0; i < nLambdaStates; i++) {
382 if (nSamples[i] == 0) {
383 zeroSnapLambdas.add(i);
384 } else {
385 sampledLambdas.add(i);
386 }
387 }
388 int nLambdaStatesTemp = nLambdaStates - zeroSnapLambdas.size();
389 double[][] reducedPotentialsTemp = new double[nLambdaStates - zeroSnapLambdas.size()][numEvaluations];
390 double[] mbarFEEstimatesTemp = new double[nLambdaStates - zeroSnapLambdas.size()];
391 int[] snapsTemp = new int[nLambdaStates - zeroSnapLambdas.size()];
392 if (!zeroSnapLambdas.isEmpty()) {
393 int index = 0;
394 for (int i = 0; i < nLambdaStates; i++) {
395 if (!zeroSnapLambdas.contains(i)) {
396 reducedPotentialsTemp[index] = reducedPotentials[i];
397 mbarFEEstimatesTemp[index] = mbarFEEstimates[i];
398 snapsTemp[index] = nSamples[i];
399 index++;
400 }
401 }
402 logger.info(" Sampled Lambdas: " + sampledLambdas);
403 logger.info(" Zero Snap Lambdas: " + zeroSnapLambdas);
404 } else {
405 reducedPotentialsTemp = reducedPotentials;
406 mbarFEEstimatesTemp = mbarFEEstimates;
407 snapsTemp = nSamples;
408 }
409
410
411
412 double[] prevMBAR = copyOf(mbarFEEstimatesTemp, nLambdaStatesTemp);
413 ;
414 double omega = 1.5;
415 for (int i = 0; i < 10; i++) {
416 prevMBAR = copyOf(mbarFEEstimatesTemp, nLambdaStatesTemp);
417 mbarFEEstimatesTemp = mbarSelfConsistentUpdate(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
418 for (int j = 0; j < nLambdaStatesTemp; j++) {
419 mbarFEEstimatesTemp[j] = omega * mbarFEEstimatesTemp[j] + (1 - omega) * prevMBAR[j];
420 }
421 if (stream(mbarFEEstimatesTemp).anyMatch(Double::isInfinite) || stream(mbarFEEstimatesTemp).anyMatch(Double::isNaN)) {
422 throw new IllegalArgumentException("MBAR contains NaNs or Infs during startup SCI ");
423 }
424 if (converged(prevMBAR)) {
425 break;
426 }
427 }
428 if (MultistateBennettAcceptanceRatio.VERBOSE) {
429 logger.info(" Omega for SCI w/ relaxation: " + omega);
430 logger.info(" MBAR FE Estimates after 10 SCI iterations: " + Arrays.toString(mbarFEEstimatesTemp));
431 }
432
433 try {
434 if (nLambdaStatesTemp > 100 && !converged(prevMBAR)) {
435 if (MultistateBennettAcceptanceRatio.VERBOSE) {
436 logger.info(" L-BFGS optimization started.");
437 }
438 int mCorrections = 5;
439 double[] x = new double[nLambdaStatesTemp];
440 arraycopy(mbarFEEstimatesTemp, 0, x, 0, nLambdaStatesTemp);
441 double[] grad = mbarGradient(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
442 double eps = 1.0E-4;
443 OptimizationListener listener = getOptimizationListener();
444 LBFGS.minimize(nLambdaStatesTemp, mCorrections, x, mbarObjectiveFunction(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp),
445 grad, eps, 1000, this, listener);
446 arraycopy(x, 0, mbarFEEstimatesTemp, 0, nLambdaStatesTemp);
447 } else if (!converged(prevMBAR)) {
448 if (MultistateBennettAcceptanceRatio.VERBOSE) {
449 logger.info(" Newton optimization started.");
450 }
451 mbarFEEstimatesTemp = newton(mbarFEEstimatesTemp, reducedPotentialsTemp, snapsTemp, tolerance);
452 }
453 } catch (Exception e) {
454 logger.warning(" L-BFGS/Newton failed to converge. Finishing w/ self-consistent iteration. Message: " +
455 e.getMessage());
456 }
457 if (MultistateBennettAcceptanceRatio.VERBOSE) {
458 logger.info(" MBAR FE Estimates after gradient optimization: " + Arrays.toString(mbarFEEstimatesTemp));
459 }
460
461
462 int count = 0;
463 for (Integer i : sampledLambdas) {
464 if (!Double.isNaN(mbarFEEstimatesTemp[count])) {
465 mbarFEEstimates[i] = mbarFEEstimatesTemp[count];
466 }
467 count++;
468 }
469
470
471 int sciIter = 0;
472 while (!converged(prevMBAR) && sciIter < 1000) {
473 prevMBAR = copyOf(mbarFEEstimates, nLambdaStates);
474 mbarFEEstimates = mbarSelfConsistentUpdate(reducedPotentials, nSamples, mbarFEEstimates);
475 for (int i = 0; i < nLambdaStates; i++) {
476 mbarFEEstimates[i] = omega * mbarFEEstimates[i] + (1 - omega) * prevMBAR[i];
477 }
478 if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
479 throw new IllegalArgumentException("MBAR estimate contains NaNs or Infs after iteration " + sciIter);
480 }
481 sciIter++;
482 }
483 if (MultistateBennettAcceptanceRatio.VERBOSE) {
484 logger.info(" SCI iterations (max 1000): " + sciIter);
485 }
486
487
488 double[][] theta = mbarTheta(reducedPotentials, nSamples, mbarFEEstimates);
489 mbarUncertainties = mbarUncertaintyCalc(theta);
490 totalMBARUncertainty = mbarTotalUncertaintyCalc(theta);
491 uncertaintyMatrix = diffMatrixCalculation(theta);
492 if (!randomSamples && MultistateBennettAcceptanceRatio.VERBOSE) {
493 logWeights();
494 }
495
496
497 for (int i = 0; i < nLambdaStates; i++) {
498 mbarFEEstimates[i] = mbarFEEstimates[i] * rtValues[i];
499 }
500 for (int i = 0; i < nFreeEnergyDiffs; i++) {
501 mbarFEDifferenceEstimates[i] = mbarFEEstimates[i + 1] - mbarFEEstimates[i];
502 }
503
504 mbarEnthalpy = mbarEnthalpyCalc(eAllFlat, mbarFEEstimates);
505 mbarEntropy = mbarEntropyCalc(mbarEnthalpy, mbarFEEstimates);
506
507 totalMBAREstimate = stream(mbarFEDifferenceEstimates).sum();
508 }
509
510
511
512
513
514
515
516
517
518
519
520 private boolean converged(double[] prevMBAR) {
521 double[] differences = new double[prevMBAR.length];
522 for (int i = 0; i < prevMBAR.length; i++) {
523 differences[i] = abs(prevMBAR[i] - mbarFEEstimates[i]);
524 }
525 return stream(differences).allMatch(d -> d < tolerance);
526 }
527
528
529
530
531
532
533
534
535
536
537 private void logWeights() {
538 logger.info(" MBAR Weight Matrix Information Collapsed:");
539 double[][] W = mbarW(reducedPotentials, nSamples, mbarFEEstimates);
540 double[][] collapsedW = new double[W.length][W.length];
541 for (int i = 0; i < nSamples.length; i++) {
542 for (int j = 0; j < W.length; j++) {
543 int start = 0;
544 for (int k = 0; k < i; k++) {
545 start += nSamples[k];
546 }
547 for (int k = 0; k < nSamples[i]; k++) {
548 collapsedW[j][i] += W[j][start + k];
549 }
550 }
551 }
552 for (int i = 0; i < W.length; i++) {
553 logger.info("\n Estimation " + i + ": " + Arrays.toString(collapsedW[i]));
554 }
555 double[] rowSum = new double[W.length];
556 for (int i = 0; i < collapsedW[0].length; i++) {
557 for (double[] trajectory : collapsedW) {
558 rowSum[i] += trajectory[i];
559 }
560 }
561 softMax(rowSum);
562 logger.info("\n Softmax of trajectory weight: " + Arrays.toString(rowSum));
563 }
564
565
566
567
568
569
570
571
572
573
574
575
576 private static double mbarObjectiveFunction(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
577 if (stream(freeEnergyEstimates).anyMatch(Double::isInfinite) || stream(freeEnergyEstimates).anyMatch(Double::isNaN)) {
578 throw new IllegalArgumentException("MBAR contains NaNs or Infs.");
579 }
580 int nStates = freeEnergyEstimates.length;
581 double[] log_denom_n = new double[reducedPotentials[0].length];
582 for (int i = 0; i < reducedPotentials[0].length; i++) {
583 double[] temp = new double[nStates];
584 double maxTemp = Double.NEGATIVE_INFINITY;
585 for (int j = 0; j < nStates; j++) {
586 temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
587 if (temp[j] > maxTemp) {
588 maxTemp = temp[j];
589 }
590 }
591 log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
592 }
593 double[] dotNkFk = new double[snapsPerLambda.length];
594 for (int i = 0; i < snapsPerLambda.length; i++) {
595 dotNkFk[i] = snapsPerLambda[i] * freeEnergyEstimates[i];
596 }
597 return stream(log_denom_n).sum() - stream(dotNkFk).sum();
598 }
599
600
601
602
603
604
605
606
607
608 private static double[] mbarGradient(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
609 int nStates = freeEnergyEstimates.length;
610 double[] log_num_k = new double[nStates];
611 double[] log_denom_n = new double[reducedPotentials[0].length];
612 double[][] logDiff = new double[reducedPotentials.length][reducedPotentials[0].length];
613 double[] maxLogDiff = new double[nStates];
614 Arrays.fill(maxLogDiff, Double.NEGATIVE_INFINITY);
615 for (int i = 0; i < reducedPotentials[0].length; i++) {
616 double[] temp = new double[nStates];
617 double maxTemp = Double.NEGATIVE_INFINITY;
618 for (int j = 0; j < nStates; j++) {
619 temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
620 if (temp[j] > maxTemp) {
621 maxTemp = temp[j];
622 }
623 }
624 log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
625 for (int j = 0; j < nStates; j++) {
626 logDiff[j][i] = -log_denom_n[i] - reducedPotentials[j][i];
627 if (logDiff[j][i] > maxLogDiff[j]) {
628 maxLogDiff[j] = logDiff[j][i];
629 }
630 }
631 }
632 for (int i = 0; i < nStates; i++) {
633 log_num_k[i] = logSumExp(logDiff[i], maxLogDiff[i]);
634 }
635 double[] grad = new double[nStates];
636 for (int i = 0; i < nStates; i++) {
637 grad[i] = -1.0 * snapsPerLambda[i] * (1.0 - exp(freeEnergyEstimates[i] + log_num_k[i]));
638 }
639 return grad;
640 }
641
642
643
644
645
646
647
648
649
650 private static double[][] mbarHessian(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
651 int nStates = freeEnergyEstimates.length;
652 double[][] W = mbarW(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
653
654 double[][] hessian = new double[nStates][nStates];
655 for (int i = 0; i < nStates; i++) {
656 for (int j = 0; j < nStates; j++) {
657 double sum = 0.0;
658 for (int k = 0; k < reducedPotentials[0].length; k++) {
659 sum += W[i][k] * W[j][k];
660 }
661 hessian[i][j] = sum * snapsPerLambda[i] * snapsPerLambda[j];
662 }
663 double wSum = 0.0;
664 for (int k = 0; k < W[i].length; k++) {
665 wSum += W[i][k];
666 }
667 hessian[i][i] -= wSum * snapsPerLambda[i];
668 }
669
670 for (int i = 0; i < nStates; i++) {
671 for (int j = 0; j < nStates; j++) {
672 hessian[i][j] = -hessian[i][j];
673 }
674 }
675 return hessian;
676 }
677
678
679
680
681
682
683
684
685
686
687 private static double[][] mbarW(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
688 int nStates = freeEnergyEstimates.length;
689 double[] log_denom_n = new double[reducedPotentials[0].length];
690 for (int i = 0; i < reducedPotentials[0].length; i++) {
691 double[] temp = new double[nStates];
692 double maxTemp = Double.NEGATIVE_INFINITY;
693 for (int j = 0; j < nStates; j++) {
694 temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
695 if (temp[j] > maxTemp) {
696 maxTemp = temp[j];
697 }
698 }
699
700 log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
701 }
702
703
704 double[][] W = new double[nStates][reducedPotentials[0].length];
705 for (int i = 0; i < nStates; i++) {
706 for (int j = 0; j < reducedPotentials[0].length; j++) {
707 W[i][j] = exp(freeEnergyEstimates[i] - reducedPotentials[i][j] - log_denom_n[j]);
708 }
709 }
710 return W;
711 }
712
713 private double[] mbarEnthalpyCalc(double[][] reducedPotentials, double[] mbarFEEstimates) {
714 double[] enthalpy = new double[mbarFEEstimates.length - 1];
715 double[] averagePotential = new double[mbarFEEstimates.length];
716 for (int i = 0; i < reducedPotentials.length; i++) {
717 averagePotential[i] = computeExpectations(eAllFlat[i])[i];
718 }
719 for (int i = 0; i < enthalpy.length; i++) {
720 enthalpy[i] = averagePotential[i + 1] - averagePotential[i];
721 }
722 return enthalpy;
723 }
724
725 private double[] mbarEntropyCalc(double[] mbarEnthalpy, double[] mbarFEEstimates) {
726 double[] entropy = new double[mbarFEEstimates.length - 1];
727 for (int i = 0; i < entropy.length; i++) {
728 entropy[i] = mbarEnthalpy[i] - mbarFEDifferenceEstimates[i];
729 }
730 return entropy;
731 }
732
733
734
735
736
737
738
739 public void setBiasData(double[][][] biasAll, boolean multiDataObservable) {
740 biasFlat = new double[biasAll.length][biasAll.length * biasAll[0][0].length];
741 if (multiDataObservable) {
742 int[] snapsT = new int[biasAll.length];
743 int[] nanCount = new int[biasAll.length];
744 for (int i = 0; i < biasAll.length; i++) {
745 ArrayList<Double> temp = new ArrayList<>();
746 double maxBias = Double.NEGATIVE_INFINITY;
747 for (int j = 0; j < biasAll.length; j++) {
748 int count = 0;
749 int countNaN = 0;
750 for (int k = 0; k < biasAll[j][i].length; k++) {
751
752 if (!Double.isNaN(biasAll[j][i][k])) {
753 temp.add(biasAll[j][i][k]);
754 if (biasAll[j][i][k] > maxBias) {
755 maxBias = biasAll[j][i][k];
756 }
757 count++;
758 } else {
759 countNaN++;
760 }
761 }
762 snapsT[j] = count;
763 nanCount[j] = countNaN;
764 }
765 biasFlat[i] = temp.stream().mapToDouble(Double::doubleValue).toArray();
766
767 for (int j = 0; j < biasFlat[i].length; j++) {
768 biasFlat[i][j] -= maxBias;
769 }
770 }
771 } else {
772 int count = 0;
773 double maxBias = Double.NEGATIVE_INFINITY;
774 for (int i = 0; i < biasAll.length; i++) {
775 for (int j = 0; j < biasAll[0][0].length; j++) {
776 if (!Double.isNaN(biasAll[i][i][j])) {
777 biasFlat[0][count] = biasAll[i][i][j];
778 if (biasAll[i][i][j] > maxBias) {
779 maxBias = biasAll[i][i][j];
780 }
781 count++;
782 }
783 }
784 }
785
786 for (int i = 0; i < biasFlat[0].length; i++) {
787 biasFlat[0][i] -= maxBias;
788 }
789 }
790 }
791
792 public void setBiasData(double[][] biasData) {
793 this.biasFlat = biasData;
794
795 for (int i = 0; i < biasFlat.length; i++) {
796 double maxBias = Double.NEGATIVE_INFINITY;
797 for (int j = 0; j < biasFlat[i].length; j++) {
798 if (biasFlat[i][j] > maxBias) {
799 maxBias = biasFlat[i][j];
800 }
801 }
802 for (int j = 0; j < biasFlat[i].length; j++) {
803 biasFlat[i][j] -= maxBias;
804 }
805 }
806 }
807
808 public void setObservableData(double[][][] oAll, boolean multiDataObservable, boolean uncertainties) {
809 oAllFlat = new double[oAll.length][oAll.length * oAll[0][0].length];
810 if (multiDataObservable) {
811 int[] snapsT = new int[oAll.length];
812 int[] nanCount = new int[oAll.length];
813 for (int i = 0; i < oAll.length; i++) {
814 ArrayList<Double> temp = new ArrayList<>();
815 for (int j = 0; j < oAll.length; j++) {
816 int count = 0;
817 int countNaN = 0;
818 for (int k = 0; k < oAll[j][i].length; k++) {
819
820 if (!Double.isNaN(oAll[j][i][k])) {
821 temp.add(oAll[j][i][k]);
822 count++;
823 } else {
824 countNaN++;
825 }
826 }
827 snapsT[j] = count;
828 nanCount[j] = countNaN;
829 }
830 oAllFlat[i] = temp.stream().mapToDouble(Double::doubleValue).toArray();
831 }
832 } else {
833 int count = 0;
834 for (int i = 0; i < oAll.length; i++) {
835 for (int j = 0; j < oAll[0][0].length; j++) {
836 if (!Double.isNaN(oAll[i][i][j])) {
837 oAllFlat[0][count] = oAll[i][i][j];
838 count++;
839 }
840 }
841 }
842 }
843
844 if (biasFlat != null) {
845 for (int i = 0; i < oAllFlat.length; i++) {
846 for (int j = 0; j < oAllFlat[i].length; j++) {
847 oAllFlat[i][j] *= exp(biasFlat[i][j] / rtValues[i]);
848 }
849 }
850 }
851 this.fillObservationExpectations(multiDataObservable, uncertainties);
852 }
853
854 public void setObservableData(double[][] oAll, boolean uncertainties) {
855 oAllFlat = oAll;
856
857 if (biasFlat != null) {
858 if (oAllFlat.length != biasFlat.length || oAllFlat[0].length != biasFlat[0].length) {
859 logger.severe("Observable and bias data are not the same size. Exiting.");
860 }
861 for (int i = 0; i < oAllFlat.length; i++) {
862 for (int j = 0; j < oAllFlat[i].length; j++) {
863 oAllFlat[i][j] *= exp(biasFlat[i][j] / rtValues[i]);
864 }
865 }
866 }
867 this.fillObservationExpectations(oAllFlat.length != 1, uncertainties);
868 }
869
870 public double getTIIntegral() {
871 DataSet dSet = new DoublesDataSet(Integrate1DNumeric.generateXPoints(0, 1, mbarObservableEnsembleAverages.length, false),
872 mbarObservableEnsembleAverages, false);
873 return Integrate1DNumeric.integrateData(dSet, Integrate1DNumeric.IntegrationSide.LEFT, Integrate1DNumeric.IntegrationType.TRAPEZOIDAL);
874 }
875
876
877
878
879
880
881
882 private void fillObservationExpectations(boolean multiData, boolean uncertainties) {
883 if (multiData) {
884 mbarObservableEnsembleAverages = new double[oAllFlat.length];
885 mbarObservableEnsembleAverageUncertainties = new double[oAllFlat.length];
886 for (int i = 0; i < oAllFlat.length; i++) {
887 mbarObservableEnsembleAverages[i] = computeExpectations(oAllFlat[i])[i];
888 if (uncertainties) {
889 mbarObservableEnsembleAverageUncertainties[i] = computeExpectationStd(oAllFlat[i])[i];
890 }
891 }
892 } else {
893 mbarObservableEnsembleAverages = computeExpectations(oAllFlat[0]);
894 if (uncertainties) {
895 mbarObservableEnsembleAverageUncertainties = computeExpectationStd(oAllFlat[0]);
896 }
897 }
898 }
899
900
901
902
903
904
905
906
907
908
909
910 private double[] computeExpectations(double[] samples) {
911 double[][] W = mbarW(reducedPotentials, nSamples, mbarFEEstimates);
912 if (W[0].length != samples.length) {
913 logger.severe("Samples and W matrix are not the same length. Exiting.");
914 }
915 double[] expectation = new double[W.length];
916 for (int i = 0; i < W.length; i++) {
917 for (int j = 0; j < W[i].length; j++) {
918 expectation[i] += W[i][j] * samples[j];
919 }
920 }
921 return expectation;
922 }
923
924
925
926
927
928
929
930
931 private double[][] mbarAugmentedW(double[] samples) {
932 int nStates = mbarFEEstimates.length;
933
934 double minSample = stream(samples).min().getAsDouble() - 3 * java.lang.Math.ulp(1.0);
935 if (minSample < 0) {
936 for (int i = 0; i < samples.length; i++) {
937 samples[i] -= minSample;
938 }
939 }
940
941 double[][] logCATerms = new double[nStates][reducedPotentials[0].length];
942 double[] maxLogCATerm = new double[reducedPotentials[0].length];
943 Arrays.fill(maxLogCATerm, Double.NEGATIVE_INFINITY);
944 double[] logCA = new double[nStates];
945 double[] log_denom_n = new double[reducedPotentials[0].length];
946 for (int i = 0; i < reducedPotentials[0].length; i++) {
947 double[] temp = new double[nStates];
948 double maxTemp = Double.NEGATIVE_INFINITY;
949 for (int j = 0; j < nStates; j++) {
950 temp[j] = mbarFEEstimates[j] - reducedPotentials[j][i];
951 if (temp[j] > maxTemp) {
952 maxTemp = temp[j];
953 }
954 }
955 log_denom_n[i] = logSumExp(temp, nSamples, maxTemp);
956 for (int j = 0; j < nStates; j++) {
957 logCATerms[j][i] = log(samples[i]) - reducedPotentials[j][i] - log_denom_n[i];
958 if (logCATerms[j][i] > maxLogCATerm[i]) {
959 maxLogCATerm[j] = logCATerms[j][i];
960 }
961 }
962 }
963 for (int i = 0; i < nStates; i++) {
964 logCA[i] = logSumExp(logCATerms[i], maxLogCATerm[i]);
965 }
966
967 double[][] WnA = new double[nStates][reducedPotentials[0].length];
968 double[][] Wna = new double[nStates][reducedPotentials[0].length];
969 for (int i = 0; i < nStates; i++) {
970 for (int j = 0; j < reducedPotentials[0].length; j++) {
971 WnA[i][j] = samples[j] * exp(-logCA[i] - reducedPotentials[i][j] - log_denom_n[j]);
972 Wna[i][j] = exp(-mbarFEEstimates[i] - reducedPotentials[i][j] - log_denom_n[j]);
973 }
974 }
975 if (minSample < 0) {
976 for (int i = 0; i < samples.length; i++) {
977 samples[i] += minSample;
978 }
979 }
980 double[][] augmentedW = new double[nStates * 2][reducedPotentials[0].length];
981 for (int i = 0; i < augmentedW.length; i++) {
982 augmentedW[i] = i < nStates ? Wna[i] : WnA[(i - nStates)];
983 }
984 return augmentedW;
985 }
986
987
988
989
990
991
992
993
994
995
996 private double[] computeExpectationStd(double[] samples) {
997 int[] extendedSnaps = new int[nSamples.length * 2];
998 System.arraycopy(nSamples, 0, extendedSnaps, 0, nSamples.length);
999 RealMatrix theta = MatrixUtils.createRealMatrix(mbarTheta(extendedSnaps, mbarAugmentedW(samples)));
1000 double[] expectations = computeExpectations(samples);
1001 double[] diag = new double[expectations.length * 2];
1002 for (int i = 0; i < expectations.length; i++) {
1003 diag[i] = expectations[i];
1004 diag[i + expectations.length] = expectations[i];
1005 }
1006 RealMatrix diagMatrix = MatrixUtils.createRealDiagonalMatrix(diag);
1007 theta = diagMatrix.multiply(theta).multiply(diagMatrix);
1008 RealMatrix ul = theta.getSubMatrix(0, expectations.length - 1, 0, expectations.length - 1);
1009 RealMatrix ur = theta.getSubMatrix(0, expectations.length - 1, expectations.length, expectations.length * 2 - 1);
1010 RealMatrix ll = theta.getSubMatrix(expectations.length, expectations.length * 2 - 1, 0, expectations.length - 1);
1011 RealMatrix lr = theta.getSubMatrix(expectations.length, expectations.length * 2 - 1, expectations.length, expectations.length * 2 - 1);
1012 double[][] covA = ul.add(lr).subtract(ur).subtract(ll).getData();
1013 double[] sigma = new double[covA.length];
1014 for (int i = 0; i < covA.length; i++) {
1015 sigma[i] = sqrt(abs(covA[i][i]));
1016 }
1017 return sigma;
1018 }
1019
1020
1021
1022
1023
1024
1025 private static double[] mbarUncertaintyCalc(double[][] theta) {
1026 double[] uncertainties = new double[theta.length - 1];
1027
1028 for (int i = 0; i < theta.length - 1; i++) {
1029
1030 double variance = theta[i][i] - 2 * theta[i][i + 1] + theta[i + 1][i + 1];
1031 if (variance < 0) {
1032 if (MultistateBennettAcceptanceRatio.VERBOSE) {
1033 logger.warning(" Negative variance detected in MBAR uncertainty calculation. " +
1034 "Multiplying by -1 to get real value. Check diff matrix to see which variances were negative. " +
1035 "They should be NaN.");
1036 }
1037 variance *= -1;
1038 }
1039 uncertainties[i] = sqrt(variance);
1040 }
1041 return uncertainties;
1042 }
1043
1044
1045
1046
1047
1048
1049
1050 private static double mbarTotalUncertaintyCalc(double[][] theta) {
1051 int nStates = theta.length;
1052 return sqrt(abs(theta[0][0] - 2 * theta[0][nStates - 1] + theta[nStates - 1][nStates - 1]));
1053 }
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066 private static double[][] mbarTheta(double[][] reducedPotentials, int[] snapsPerState, double[] freeEnergies) {
1067 return mbarTheta(snapsPerState, mbarW(reducedPotentials, snapsPerState, freeEnergies));
1068 }
1069
1070
1071
1072
1073
1074
1075
1076
1077 private static double[][] mbarTheta(int[] snapsPerState, double[][] W) {
1078 RealMatrix WMatrix = MatrixUtils.createRealMatrix(W).transpose();
1079 RealMatrix I = MatrixUtils.createRealIdentityMatrix(snapsPerState.length);
1080 RealMatrix NkMatrix = MatrixUtils.createRealDiagonalMatrix(stream(snapsPerState).mapToDouble(i -> i).toArray());
1081 SingularValueDecomposition svd = new SingularValueDecomposition(WMatrix);
1082 RealMatrix V = svd.getV();
1083 RealMatrix S = MatrixUtils.createRealDiagonalMatrix(svd.getSingularValues());
1084
1085
1086
1087 RealMatrix theta = S.multiply(V.transpose());
1088 theta = theta.multiply(NkMatrix).multiply(V).multiply(S);
1089 theta = I.subtract(theta);
1090 theta = MatrixUtils.inverse(theta);
1091 theta = V.multiply(S).multiply(theta).multiply(S).multiply(V.transpose());
1092
1093 return theta.getData();
1094 }
1095
1096
1097
1098
1099
1100
1101
1102
1103 private static double[][] diffMatrixCalculation(double[][] theta) {
1104 double[][] diffMatrix = new double[theta.length][theta.length];
1105 for (int i = 0; i < diffMatrix.length; i++) {
1106 for (int j = 0; j < diffMatrix.length; j++) {
1107 diffMatrix[i][j] = sqrt(theta[i][i] - 2 * theta[i][j] + theta[j][j]);
1108 }
1109 }
1110 return diffMatrix;
1111 }
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123 private static double[] mbarSelfConsistentUpdate(double[][] reducedPotential, int[] snapsPerLambda,
1124 double[] freeEnergyEstimates) {
1125 int nStates = freeEnergyEstimates.length;
1126 double[] updatedF_k = new double[nStates];
1127 double[] log_denom_n = new double[reducedPotential[0].length];
1128 double[][] logDiff = new double[reducedPotential.length][reducedPotential[0].length];
1129 double[] maxLogDiff = new double[nStates];
1130 fill(maxLogDiff, Double.NEGATIVE_INFINITY);
1131 for (int i = 0; i < reducedPotential[0].length; i++) {
1132 double[] temp = new double[nStates];
1133 double maxTemp = Double.NEGATIVE_INFINITY;
1134 for (int j = 0; j < nStates; j++) {
1135 temp[j] = freeEnergyEstimates[j] - reducedPotential[j][i];
1136 if (temp[j] > maxTemp) {
1137 maxTemp = temp[j];
1138 }
1139 }
1140 log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
1141 for (int j = 0; j < nStates; j++) {
1142 logDiff[j][i] = -log_denom_n[i] - reducedPotential[j][i];
1143 if (logDiff[j][i] > maxLogDiff[j]) {
1144 maxLogDiff[j] = logDiff[j][i];
1145 }
1146 }
1147 }
1148
1149 for (int i = 0; i < nStates; i++) {
1150 updatedF_k[i] = -1.0 * logSumExp(logDiff[i], maxLogDiff[i]);
1151 }
1152
1153
1154 double norm = updatedF_k[0];
1155 updatedF_k[0] = 0.0;
1156 for (int i = 1; i < nStates; i++) {
1157 updatedF_k[i] = updatedF_k[i] - norm;
1158 }
1159
1160 return updatedF_k;
1161 }
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174 private static double[] newtonStep(double[] n, double[] grad, double[][] hessian, double stepSize) {
1175 double[] nPlusOne = new double[n.length];
1176 double[] step;
1177 try {
1178 RealMatrix hessianInverse = MatrixUtils.inverse(MatrixUtils.createRealMatrix(hessian));
1179 step = hessianInverse.preMultiply(grad);
1180 } catch (IllegalArgumentException e) {
1181 if (MultistateBennettAcceptanceRatio.VERBOSE) {
1182 logger.info(" Singular matrix detected in MBAR Newton-Raphson step. Performing steepest descent step.");
1183 }
1184 step = grad;
1185 stepSize = 1e-5;
1186 }
1187
1188 double temp = step[0];
1189 step[0] = 0.0;
1190 for (int i = 1; i < step.length; i++) {
1191 step[i] -= temp;
1192 }
1193 for (int i = 0; i < n.length; i++) {
1194 nPlusOne[i] = n[i] - step[i] * stepSize;
1195 }
1196 return nPlusOne;
1197 }
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208 private static double[] newton(double[] freeEnergyEstimates, double[][] reducedPotentials,
1209 int[] snapsPerLambda, double tolerance) {
1210 double[] grad = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1211 double[][] hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1212 double[] f_kPlusOne = newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
1213 int iter = 1;
1214 while (iter < 15) {
1215 freeEnergyEstimates = f_kPlusOne;
1216 grad = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1217 hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1218
1219 f_kPlusOne = newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
1220 double eps = 0.0;
1221 for (int i = 0; i < freeEnergyEstimates.length; i++) {
1222 eps += abs(grad[i]);
1223 }
1224 if (eps < tolerance) {
1225 break;
1226 }
1227 iter++;
1228 }
1229 if (MultistateBennettAcceptanceRatio.VERBOSE) {
1230 logger.info(" Newton iterations (max 15): " + iter);
1231 }
1232
1233 return f_kPlusOne;
1234 }
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245 private static double logSumExp(double[] values, double max) {
1246 int[] b = fill(new int[values.length], 1);
1247 return logSumExp(values, b, max);
1248 }
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262 private static double logSumExp(double[] values, int[] b, double max) {
1263
1264
1265 assert values.length == b.length : "values and b must be the same length";
1266
1267
1268 double sum = 0.0;
1269 for (int i = 0; i < values.length; i++) {
1270 sum += b[i] * exp(values[i] - max);
1271 }
1272
1273
1274 return max + log(sum);
1275 }
1276
1277
1278
1279
1280
1281
1282 private static void softMax(double[] values) {
1283 double max = stream(values).max().getAsDouble();
1284 double sum = 0.0;
1285 for (int i = 0; i < values.length; i++) {
1286 values[i] = exp(values[i] - max);
1287 sum += values[i];
1288 }
1289 for (int i = 0; i < values.length; i++) {
1290 values[i] /= sum;
1291 }
1292 }
1293
1294
1295
1296
1297
1298
1299 private OptimizationListener getOptimizationListener() {
1300 return new OptimizationListener() {
1301 @Override
1302 public boolean optimizationUpdate(int iter, int nBFGS, int nFunctionEvals, double gradientRMS,
1303 double coordinateRMS, double f, double df, double angle,
1304 LineSearch.LineSearchResult info) {
1305 return true;
1306 }
1307 };
1308 }
1309
1310
1311
1312
1313
1314
1315
1316 @Override
1317 public double energy(double[] x) {
1318
1319 double tempO = x[0];
1320 x[0] = 0.0;
1321 for (int i = 1; i < x.length; i++) {
1322 x[i] -= tempO;
1323 }
1324 return mbarObjectiveFunction(reducedPotentials, nSamples, x);
1325 }
1326
1327
1328
1329
1330
1331
1332
1333
1334 @Override
1335 public double energyAndGradient(double[] x, double[] g) {
1336 double tempO = x[0];
1337 x[0] = 0.0;
1338 for (int i = 1; i < x.length; i++) {
1339 x[i] -= tempO;
1340 }
1341 double[] tempG = mbarGradient(reducedPotentials, nSamples, x);
1342 arraycopy(tempG, 0, g, 0, g.length);
1343 return mbarObjectiveFunction(reducedPotentials, nSamples, x);
1344 }
1345
1346 @Override
1347 public double[] getCoordinates(double[] parameters) {
1348 return new double[0];
1349 }
1350
1351 @Override
1352 public int getNumberOfVariables() {
1353 return 0;
1354 }
1355
1356 @Override
1357 public double[] getScaling() {
1358 return null;
1359 }
1360
1361 @Override
1362 public void setScaling(double[] scaling) {
1363 }
1364
1365 @Override
1366 public double getTotalEnergy() {
1367 return 0;
1368 }
1369
1370
1371
1372 public BennettAcceptanceRatio getBAR() {
1373 return new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures);
1374 }
1375
1376 @Override
1377 public MultistateBennettAcceptanceRatio copyEstimator() {
1378 return new MultistateBennettAcceptanceRatio(lamValues, eAll, temperatures, tolerance, seedType);
1379 }
1380
1381 @Override
1382 public double[] getFreeEnergyDifferences() {
1383 return mbarFEDifferenceEstimates;
1384 }
1385
1386 public double[] getMBARFreeEnergies() {
1387 return mbarFEEstimates;
1388 }
1389
1390 public double[][] getReducedPotentials() {
1391 return reducedPotentials;
1392 }
1393
1394 public int[] getSnaps() {
1395 return nSamples;
1396 }
1397
1398 @Override
1399 public double[] getFEDifferenceUncertainties() {
1400 return mbarUncertainties;
1401 }
1402
1403 public double[] getObservationEnsembleAverages() {
1404 return mbarObservableEnsembleAverages;
1405 }
1406
1407 public double[] getObservationEnsembleUncertainties() {
1408 return mbarObservableEnsembleAverageUncertainties;
1409 }
1410
1411 public double[][] getUncertaintyMatrix() {
1412 return uncertaintyMatrix;
1413 }
1414
1415 @Override
1416 public double getTotalFreeEnergyDifference() {
1417 return totalMBAREstimate;
1418 }
1419
1420 @Override
1421 public double getTotalFEDifferenceUncertainty() {
1422 return totalMBARUncertainty;
1423 }
1424
1425 @Override
1426 public int getNumberOfBins() {
1427 return nFreeEnergyDiffs;
1428 }
1429
1430 @Override
1431 public double[] getEnthalpyDifferences() {
1432 return mbarEnthalpy;
1433 }
1434
1435
1436
1437
1438 @Override
1439 public double getTotalEnthalpyDifference() {
1440 return getTotalEnthalpyDifference(mbarEnthalpy);
1441 }
1442
1443 public double[] getBinEntropies() {
1444 return mbarEntropy;
1445 }
1446
1447 public static void writeFile(double[][] energies, File file, double temperature) {
1448 try (FileWriter fw = new FileWriter(file);
1449 BufferedWriter bw = new BufferedWriter(fw)) {
1450
1451 bw.write(energies[0].length + " " + temperature);
1452 bw.newLine();
1453
1454
1455 StringBuilder sb = new StringBuilder();
1456 for (int i = 0; i < energies[0].length; i++) {
1457 sb.append(" ").append(i).append(" ");
1458 for (int j = 0; j < energies.length; j++) {
1459 sb.append(" ").append(energies[j][i]).append(" ");
1460 }
1461 sb.append("\n");
1462 bw.write(sb.toString());
1463 sb = new StringBuilder();
1464 }
1465 } catch (IOException e) {
1466 e.printStackTrace();
1467 }
1468 }
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479 public static String[] testMBARMethods() {
1480
1481 double[] O_k = {1, 2, 3, 4};
1482 double[] K_k = {.5, 1.0, 1.5, 2};
1483 int[] N_k = {100000, 100000, 100000, 100000};
1484 double beta = 1.0;
1485 HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);
1486 String setting = "u_kln";
1487 Object[] sampleResult = testCase.sample(N_k, setting, (long) 0);
1488 double[][][] u_kln = (double[][][]) sampleResult[1];
1489 double[] temps = {1 / Constants.R};
1490 MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0E-7, MultistateBennettAcceptanceRatio.SeedType.ZEROS);
1491 MultistateBennettAcceptanceRatio mbarHigherTol = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0, MultistateBennettAcceptanceRatio.SeedType.ZEROS);
1492 String[] results = new String[7];
1493
1494 double[][] reducedPotentials = mbar.getReducedPotentials();
1495 double[] freeEnergyEstimates = mbar.getMBARFreeEnergies();
1496 double[] highTolFEEstimates = mbarHigherTol.getMBARFreeEnergies();
1497 double[] zeros = new double[freeEnergyEstimates.length];
1498 int[] snapsPerLambda = mbar.getSnaps();
1499
1500
1501 double[] expectedFEEstimates = new double[]{0.0, 0.3474485596619945, 0.5460865684340613, 0.6866650788765148};
1502 boolean pass = normDiff(freeEnergyEstimates, expectedFEEstimates) < 1e-5;
1503 expectedFEEstimates = new double[]{0.0, 0.35798124225733474, 0.44721370511807645, 0.477203739646745};
1504 pass = normDiff(highTolFEEstimates, expectedFEEstimates) < 1e-5 && pass;
1505 results[0] = pass ? "PASS" : "FAIL getMBARFreeEnergies()";
1506
1507
1508 double objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1509 pass = !(abs(objectiveFunction - 4786294.2692739945) > 1e-5);
1510 objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1511 pass = !(abs(objectiveFunction - 4787001.700838844) > 1e-5) && pass;
1512 objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, zeros);
1513 pass = !(abs(objectiveFunction - 4792767.352152844) > 1e-5) && pass;
1514 results[1] = pass ? "PASS" : "FAIL mbarObjectiveFunction()";
1515
1516
1517 double[] gradient = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1518 double[] expected = new double[]{6.067113034191607E-4, -8.777718552011038E-4, 8.210768953631487E-4, -5.500246369471995E-4};
1519 pass = !(normDiff(gradient, expected) > 4e-5);
1520 gradient = mbarGradient(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1521 expected = new double[]{1969.705314577408, 5108.841258429764, -1072.9526887468976, -6005.593884267446};
1522 pass = !(normDiff(gradient, expected) > 4e-5) && pass;
1523 gradient = mbarGradient(reducedPotentials, snapsPerLambda, zeros);
1524 expected = new double[]{22797.82037585665, -3273.72282675803, -8859.999065013779, -10664.098484078011};
1525 pass = !(normDiff(gradient, expected) > 4e-5) && pass;
1526 results[2] = pass ? "PASS" : "FAIL mbarGradient()";
1527
1528 pass = true;
1529
1530 double[][] hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1531 double[][] expected2d = new double[][]{{47600.586808418964, -29977.008359691405, -12870.425573135915, -4753.1528755909385},
1532 {-29977.008359691405, 63767.745823769576, -24597.198354108747, -9193.539109971487},
1533 {-12870.425573135915, -24597.198354108747, 64584.87112481013, -27117.247197561417},
1534 {-4753.1528755909385, -9193.539109971487, -27117.247197561417, 41063.93918312612}};
1535 pass = !(normDiff(hessian, expected2d) > 16e-5);
1536 hessian = mbarHessian(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1537 expected2d = new double[][]{{49168.30161780381, -31256.519016487477, -12983.708230229113, -4928.074371082683},
1538 {-31256.519016487477, 66075.94621325849, -25339.462656640117, -9479.964540130917},
1539 {-12983.708230229113, -25339.462656640117, 64308.30940252403, -25985.13851565483},
1540 {-4928.074371082683, -9479.964540130917, -25985.13851565483, 40393.1774268678}};
1541 pass = !(normDiff(hessian, expected2d) > 16e-5) && pass;
1542 hessian = mbarHessian(reducedPotentials, snapsPerLambda, zeros);
1543 expected2d = new double[][]{{56125.271437145464, -33495.87894376072, -15738.011263498352, -6891.381229885624},
1544 {-33495.87894376072, 64613.515110188295, -21970.091845920833, -9147.544320511564},
1545 {-15738.011263498352, -21970.091845920833, 61407.66256511316, -23699.55945569241},
1546 {-6891.381229885624, -9147.544320511564, -23699.55945569241, 39738.48500608951}};
1547 pass = !(normDiff(hessian, expected2d) > 16e-5) && pass;
1548 results[3] = pass ? "PASS" : "FAIL mbarHessian()";
1549
1550 pass = true;
1551
1552 double[][] theta = mbarTheta(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1553 double[][] diff = diffMatrixCalculation(theta);
1554 expected2d = new double[][]{{0.0, 0.001953125, 0.003400485419234404, 0.004858337095247168},
1555 {0.0020716018980074633, 0.0, 0.002042627017905458, 0.004055968683065466},
1556 {0.003435363105339426, 0.002042627017905458, 0.0, 0.002560568476977909},
1557 {0.0048828125, 0.004055968683065466, 0.0025135815773894045, 0.0}};
1558 pass = !(normDiff(diff, expected2d) > 16e-5);
1559 results[4] = pass ? "PASS" : "FAIL mbarTheta() or diffMatrixCalculation()";
1560
1561 pass = true;
1562
1563 double[] updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1564 expected = new double[]{0.0, 0.3474485745068261, 0.5460865662904055, 0.6866650904438742};
1565 pass = !(normDiff(updatedF_k, expected) > 1e-5);
1566 updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1567 expected = new double[]{0.0, 0.327660608017009, 0.4775067849198251, 0.5586442310038073};
1568 pass = !(normDiff(updatedF_k, expected) > 1e-5) && pass;
1569 updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, zeros);
1570 expected = new double[]{0.0, 0.23865416150488983, 0.29814247007871764, 0.31813582643116334};
1571 pass = !(normDiff(updatedF_k, expected) > 1e-5) && pass;
1572 results[5] = pass ? "PASS" : "FAIL mbarSelfConsistentUpdate()";
1573
1574 pass = true;
1575
1576 updatedF_k = newton(highTolFEEstimates, reducedPotentials, snapsPerLambda, 1e-7);
1577 pass = !(normDiff(updatedF_k, freeEnergyEstimates) > 1e-5);
1578 updatedF_k = newton(zeros, reducedPotentials, snapsPerLambda, 1e-7);
1579 pass = !(normDiff(updatedF_k, freeEnergyEstimates) > 1e-5) && pass;
1580 results[6] = pass ? "PASS" : "FAIL newton()";
1581
1582 return results;
1583 }
1584
1585 private static double normDiff(double[] a, double[] b) {
1586 double sum = 0.0;
1587 for (int i = 0; i < a.length; i++) {
1588 sum += abs(a[i] - b[i]);
1589 }
1590 return sum;
1591 }
1592
1593 private static double normDiff(double[][] a, double[][] b) {
1594 double sum = 0.0;
1595 for (int i = 0; i < a.length; i++) {
1596 for (int j = 0; j < a[i].length; j++) {
1597 sum += abs(a[i][j] - b[i][j]);
1598 }
1599 }
1600 return sum;
1601 }
1602
1603
1604
1605
1606
1607
1608 public static void main(String[] args) {
1609
1610 double[] equilPositions = {1, 2, 3, 4};
1611 double[] springConstants = {.5, 1.0, 1.5, 2};
1612 int[] samples = {100000, 100000, 100000, 100000};
1613 double beta = 1.0;
1614 HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(equilPositions, springConstants, beta);
1615 String setting = "u_kln";
1616 System.out.print("Generating sample data... ");
1617 Object[] sampleResult = testCase.sample(samples, setting, (long) 0);
1618 System.out.println("done. \n");
1619 double[] x_n = (double[]) sampleResult[0];
1620 double[][][] u_kln = (double[][][]) sampleResult[1];
1621 double[] temps = {1 / Constants.R};
1622
1623
1624
1625
1626 String rootPath = new File("").getAbsolutePath();
1627 File outputPath = new File(rootPath + "/testing/mbar/data/harmonic_oscillators/mbarFiles");
1628 if (!outputPath.exists() && !outputPath.mkdirs()) {
1629 throw new RuntimeException("Failed to create directory: " + outputPath);
1630 }
1631
1632 double[] temperatures = new double[equilPositions.length];
1633 Arrays.fill(temperatures, temps[0]);
1634 for (int i = 0; i < u_kln.length; i++) {
1635 File file = new File(outputPath, "energies_" + i + ".mbar");
1636 writeFile(u_kln[i], file, temperatures[i]);
1637 }
1638
1639
1640 System.out.print("Creating MBAR instance and .estimateDG(false) with standard tolerance & zeros seeding...");
1641
1642 MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(equilPositions, u_kln, temps, 1e-7, SeedType.ZEROS);
1643 System.out.println("done! \n\n");
1644 double[] mbarFEEstimates = Arrays.copyOf(mbar.mbarFEEstimates, mbar.mbarFEEstimates.length);
1645 double[] mbarEnthalpyDiff = Arrays.copyOf(mbar.mbarEnthalpy, mbar.mbarEnthalpy.length);
1646 double[] mbarEntropyDiff = Arrays.copyOf(mbar.mbarEntropy, mbar.mbarEntropy.length);
1647 double[] mbarUncertainties = Arrays.copyOf(mbar.mbarUncertainties, mbar.mbarUncertainties.length);
1648 double[][] mbarDiffMatrix = Arrays.copyOf(mbar.uncertaintyMatrix, mbar.uncertaintyMatrix.length);
1649
1650
1651 double[] analyticalFreeEnergies = testCase.analyticalFreeEnergies();
1652 double[] error = new double[analyticalFreeEnergies.length];
1653 for (int i = 0; i < error.length; i++) {
1654 error[i] = analyticalFreeEnergies[i] - mbarFEEstimates[i];
1655 }
1656 double[] temp = testCase.analyticalEntropies(0);
1657 double[] analyticEntropyDiff = new double[temp.length - 1];
1658 double[] errorEntropy = new double[temp.length - 1];
1659 for (int i = 0; i < analyticEntropyDiff.length; i++) {
1660 analyticEntropyDiff[i] = temp[i + 1] - temp[i];
1661 errorEntropy[i] = analyticEntropyDiff[i] - mbarEntropyDiff[i];
1662 }
1663
1664
1665 System.out.println("STANDARD THERMODYNAMIC CALCULATIONS: \n");
1666 System.out.println("Analytical Free Energies: " + Arrays.toString(analyticalFreeEnergies));
1667 System.out.println("MBAR Free Energies: " + Arrays.toString(mbarFEEstimates));
1668 System.out.println("Free Energy Error: " + Arrays.toString(error));
1669 System.out.println();
1670 System.out.println("MBAR dG: " + Arrays.toString(mbar.mbarFEDifferenceEstimates));
1671 System.out.println("MBAR Uncertainties: " + Arrays.toString(mbarUncertainties));
1672 System.out.println("MBAR Enthalpy Changes: " + Arrays.toString(mbarEnthalpyDiff));
1673 System.out.println();
1674 System.out.println("MBAR Entropy Changes: " + Arrays.toString(mbarEntropyDiff));
1675 System.out.println("Analytic Entropy Changes: " + Arrays.toString(analyticEntropyDiff));
1676 System.out.println("Entropy Error: " + Arrays.toString(errorEntropy));
1677 System.out.println();
1678 System.out.println("Uncertainty Diff Matrix: ");
1679 for (double[] matrix : mbarDiffMatrix) {
1680 System.out.println(Arrays.toString(matrix));
1681 }
1682 System.out.println("\n\n");
1683
1684
1685 System.out.println("MBAR DERIVED OBSERVABLES: \n");
1686 mbar.setObservableData(u_kln, true, true);
1687 double[] mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages,
1688 mbar.mbarObservableEnsembleAverages.length);
1689 double[] mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties,
1690 mbar.mbarObservableEnsembleAverageUncertainties.length);
1691 System.out.println("Multi-Data Observable Example u_kln:");
1692 System.out.println("MBAR Observable Ensemble Averages (Potential): " + Arrays.toString(mbarObservableEnsembleAverages));
1693 System.out.println("Analytical Observable Ensemble Averages (Potential): " + Arrays.toString(testCase.analyticalObservable("potential energy")));
1694 System.out.println("MBAR Observable Ensemble Average Uncertainties (Potential): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
1695 System.out.println();
1696
1697
1698 double[][][] xAll = new double[equilPositions.length][equilPositions.length][x_n.length];
1699 for (int i = 0; i < xAll[0].length; i++) {
1700 for (int j = 0; j < xAll[0][0].length; j++) {
1701
1702 xAll[0][i][j] = x_n[j];
1703 }
1704 }
1705 mbar.setObservableData(xAll, false, true);
1706 mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages,
1707 mbar.mbarObservableEnsembleAverages.length);
1708 mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties,
1709 mbar.mbarObservableEnsembleAverageUncertainties.length);
1710 System.out.println("Single-Data Observable Example x_n:");
1711 System.out.println("MBAR Observable Ensemble Averages (Position): " + Arrays.toString(mbarObservableEnsembleAverages));
1712 System.out.println("Analytical Observable Ensemble Averages (Position): " + Arrays.toString(testCase.analyticalMeans()));
1713 System.out.println("MBAR Observable Ensemble Average Uncertainties (Position): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
1714 System.out.println();
1715 }
1716
1717
1718
1719
1720 public static class HarmonicOscillatorsTestCase {
1721 private final double beta;
1722 private final double[] equilPositions;
1723 private final int n_states;
1724 private final double[] springConstants;
1725
1726 public HarmonicOscillatorsTestCase(double[] O_k, double[] K_k, double beta) {
1727 this.beta = beta;
1728 this.equilPositions = O_k;
1729 this.n_states = O_k.length;
1730 this.springConstants = K_k;
1731
1732 if (this.springConstants.length != this.n_states) {
1733 throw new IllegalArgumentException("Lengths of K_k and O_k should be equal");
1734 }
1735 }
1736
1737 public double[] analyticalMeans() {
1738 return equilPositions;
1739 }
1740
1741 public double[] analyticalStandardDeviations() {
1742 double[] deviations = new double[n_states];
1743 for (int i = 0; i < n_states; i++) {
1744 deviations[i] = Math.sqrt(1.0 / (beta * springConstants[i]));
1745 }
1746 return deviations;
1747 }
1748
1749 public double[] analyticalObservable(String observable) {
1750 double[] result = new double[n_states];
1751
1752 switch (observable) {
1753 case "position" -> {
1754 return analyticalMeans();
1755 }
1756 case "potential energy" -> {
1757 for (int i = 0; i < n_states; i++) {
1758 result[i] = 0.5 / beta;
1759 }
1760 }
1761 case "position^2" -> {
1762 for (int i = 0; i < n_states; i++) {
1763 result[i] = 1.0 / (beta * springConstants[i]) + Math.pow(equilPositions[i], 2);
1764 }
1765 }
1766 case "RMS displacement" -> {
1767 return analyticalStandardDeviations();
1768 }
1769 }
1770
1771 return result;
1772 }
1773
1774 public double[] analyticalFreeEnergies() {
1775 int subtractComponentIndex = 0;
1776 double[] fe = new double[n_states];
1777 double subtract = 0.0;
1778 for (int i = 0; i < n_states; i++) {
1779 fe[i] = -0.5 * Math.log(2 * Math.PI / (beta * springConstants[i]));
1780 if (i == 0) {
1781 subtract = fe[subtractComponentIndex];
1782 }
1783 fe[i] -= subtract;
1784 }
1785 return fe;
1786 }
1787
1788 public double[] analyticalEntropies(int subtractComponent) {
1789 double[] entropies = new double[n_states];
1790 double[] potentialEnergy = analyticalObservable("analytical entropy");
1791 double[] freeEnergies = analyticalFreeEnergies();
1792
1793 for (int i = 0; i < n_states; i++) {
1794 entropies[i] = potentialEnergy[i] - freeEnergies[i];
1795 }
1796
1797 return entropies;
1798 }
1799
1800
1801
1802
1803
1804
1805
1806
1807 public Object[] sample(int[] N_k, String mode, Long seed) {
1808 Random random = new Random(seed);
1809
1810 int N_max = 0;
1811 for (int N : N_k) {
1812 if (N > N_max) {
1813 N_max = N;
1814 }
1815 }
1816
1817 int N_tot = 0;
1818 for (int N : N_k) {
1819 N_tot += N;
1820 }
1821
1822 double[][] x_kn = new double[n_states][N_max];
1823 double[][] u_kn = new double[n_states][N_tot];
1824 double[][][] u_kln = new double[n_states][n_states][N_max];
1825 double[] x_n = new double[N_tot];
1826 int[] s_n = new int[N_tot];
1827
1828
1829 int index = 0;
1830 for (int k = 0; k < n_states; k++) {
1831 double x0 = equilPositions[k];
1832 double sigma = Math.sqrt(1.0 / (beta * springConstants[k]));
1833
1834
1835 for (int n = 0; n < N_k[k]; n++) {
1836 double x = x0 + random.nextGaussian() * sigma;
1837 x_kn[k][n] = x;
1838 x_n[index] = x;
1839 s_n[index] = k;
1840
1841 for (int l = 0; l < n_states; l++) {
1842 double u = beta * 0.5 * springConstants[l] * Math.pow(x - equilPositions[l], 2.0);
1843 u_kln[k][l][n] = u;
1844 u_kn[l][index] = u;
1845 }
1846 index++;
1847 }
1848
1849 for (int n = N_k[k]; n < N_max; n++) {
1850 for (int l = 0; l < n_states; l++) {
1851 u_kln[k][l][n] = Double.NaN;
1852 }
1853 }
1854 }
1855
1856
1857 if ("u_kn".equals(mode)) {
1858 return new Object[]{x_n, u_kn, N_k, s_n};
1859 } else if ("u_kln".equals(mode)) {
1860 return new Object[]{x_n, u_kln, N_k, s_n, u_kn};
1861 } else {
1862 throw new IllegalArgumentException("Unknown mode: " + mode);
1863 }
1864 }
1865 }
1866 }