1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.estimator;
39
40 import ffx.numerics.OptimizationInterface;
41 import ffx.numerics.optimization.LBFGS;
42 import ffx.numerics.optimization.LineSearch;
43 import ffx.numerics.optimization.OptimizationListener;
44 import ffx.utilities.Constants;
45 import org.apache.commons.math3.linear.MatrixUtils;
46 import org.apache.commons.math3.linear.RealMatrix;
47 import org.apache.commons.math3.linear.SingularValueDecomposition;
48 import org.apache.commons.math3.util.MathArrays;
49
50 import java.io.BufferedWriter;
51 import java.io.File;
52 import java.io.FileWriter;
53 import java.io.IOException;
54 import java.util.Arrays;
55 import java.util.Random;
56 import java.util.logging.Logger;
57
58 import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
59 import static ffx.numerics.estimator.Zwanzig.Directionality.BACKWARDS;
60 import static ffx.numerics.estimator.Zwanzig.Directionality.FORWARDS;
61 import static java.lang.System.arraycopy;
62 import static java.util.Arrays.copyOf;
63 import static java.util.Arrays.stream;
64 import static org.apache.commons.lang3.ArrayFill.fill;
65 import static org.apache.commons.math3.util.FastMath.abs;
66 import static org.apache.commons.math3.util.FastMath.exp;
67 import static org.apache.commons.math3.util.FastMath.log;
68 import static org.apache.commons.math3.util.FastMath.sqrt;
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87 public class MultistateBennettAcceptanceRatio extends SequentialEstimator implements BootstrappableEstimator, OptimizationInterface {
88 private static final Logger logger = Logger.getLogger(MultistateBennettAcceptanceRatio.class.getName());
89
90
91
92
93 private static final double DEFAULT_TOLERANCE = 1.0E-7;
94
95
96
97 private final int nFreeEnergyDiffs;
98
99
100
101 private final double[] mbarEstimates;
102
103
104
105 private double[] mbarUncertainties;
106
107
108
109
110 private double[][] diffMatrix;
111
112
113
114 private final double tolerance;
115 private final Random random;
116 private final int nStates;
117
118
119
120 double[] mbarFreeEnergies;
121
122
123
124 private double totalMBAREstimate;
125
126
127
128 private double totalMBARUncertainty;
129
130
131
132 private final double[] mbarEnthalpy;
133
134
135
136
137 private double[][] u_kn;
138
139
140
141 private double[] N_k;
142
143
144
145 private SeedType seedType;
146
147
148
149
150 public enum SeedType {BAR, ZWANZIG, ZEROS}
151
152
153
154
155
156
157
158
/**
 * Creates an MBAR estimator with {@link #DEFAULT_TOLERANCE} and Zwanzig seeding, then solves it
 * immediately via {@link #estimateDG()}.
 *
 * @param lambdaValues lambda value of each thermodynamic state.
 * @param energiesAll  energies passed through to SequentialEstimator. Presumably indexed
 *                     [sampledState][evaluatedState][snapshot], as in this class's main — confirm.
 * @param temperature  temperatures passed through to SequentialEstimator.
 */
public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature) {
  this(lambdaValues, energiesAll, temperature, DEFAULT_TOLERANCE, SeedType.ZWANZIG);
}

/**
 * Creates an MBAR estimator and solves it immediately via {@link #estimateDG()}.
 *
 * @param lambdaValues lambda value of each thermodynamic state.
 * @param energiesAll  energies passed through to SequentialEstimator (see the three-argument
 *                     constructor).
 * @param temperature  temperatures passed through to SequentialEstimator.
 * @param tolerance    convergence tolerance for the self-consistent iteration.
 * @param seedType     strategy for the initial free energy guess.
 */
public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature,
                                        double tolerance, SeedType seedType) {
  super(lambdaValues, energiesAll, temperature);
  this.tolerance = tolerance;
  this.seedType = seedType;

  // One free energy per state, one difference per adjacent pair of states.
  this.nStates = lambdaValues.length;
  this.nFreeEnergyDiffs = this.nStates - 1;
  this.mbarFreeEnergies = new double[this.nStates];
  this.mbarEstimates = new double[this.nFreeEnergyDiffs];
  this.mbarUncertainties = new double[this.nFreeEnergyDiffs];
  this.mbarEnthalpy = new double[this.nFreeEnergyDiffs];
  this.random = new Random();

  // Every field estimateDG reads is assigned above.
  estimateDG();
}
189
190
191
192
193 private void seedEnergies() {
194 switch (seedType) {
195 case BAR:
196 try {
197 SequentialEstimator barEstimator = new BennettAcceptanceRatio(lamValues, eLow, eAt, eHigh, temperatures);
198 mbarFreeEnergies[0] = 0.0;
199 double[] barEstimates = barEstimator.getBinEnergies();
200 for (int i = 0; i < nFreeEnergyDiffs; i++) {
201 mbarFreeEnergies[i + 1] = mbarFreeEnergies[i] + barEstimates[i];
202 }
203 break;
204 } catch (IllegalArgumentException e) {
205 logger.warning(" BAR failed to converge. Zwanzig will be used for seed energies.");
206 seedType = SeedType.ZWANZIG;
207 seedEnergies();
208 return;
209 }
210 case ZWANZIG:
211
212 Zwanzig forwardsFEP = new Zwanzig(lamValues, eLow, eAt, eHigh, temperatures, FORWARDS);
213
214 Zwanzig backwardsFEP = new Zwanzig(lamValues, eLow, eAt, eHigh, temperatures, BACKWARDS);
215
216 double[] forwardZwanzig = forwardsFEP.getBinEnergies();
217
218 double[] backwardZwanzig = backwardsFEP.getBinEnergies();
219 mbarFreeEnergies[0] = 0.0;
220 for (int i = 0; i < nFreeEnergyDiffs; i++) {
221 mbarFreeEnergies[i + 1] = mbarFreeEnergies[i] + .5 * (forwardZwanzig[i] + backwardZwanzig[i]);
222 }
223 break;
224 case SeedType.ZEROS:
225 break;
226 default:
227 throw new IllegalArgumentException("Seed type not supported");
228 }
229 }
230
231
232
233
234 @Override
235 public void estimateDG() {
236 estimateDG(false);
237 }
238
239
240
241
242 @Override
243 public void estimateDG(boolean randomSamples) {
244
245 fill(mbarFreeEnergies, 0.0);
246 seedEnergies();
247
248
249 if (stream(mbarFreeEnergies).anyMatch(Double::isInfinite) || stream(mbarFreeEnergies).anyMatch(Double::isNaN)) {
250 throw new IllegalArgumentException("MBAR contains NaNs or Infs after seeding.");
251 }
252 double[] prevMBAR;
253
254
255 int iter = 0;
256
257
258 double[] rtValues = new double[nStates];
259 double[] invRTValues = new double[nStates];
260 for (int i = 0; i < nStates; i++) {
261 rtValues[i] = Constants.R * temperatures[i];
262 invRTValues[i] = 1.0 / rtValues[i];
263 }
264 int numSnaps = eAllFlat[0].length;
265
266
267 int[][] indices = new int[nStates][numSnaps];
268 if (randomSamples) {
269 int[] randomIndices = getBootstrapIndices(numSnaps, random);
270 for (int i = 0; i < nStates; i++) {
271
272 indices[i] = randomIndices;
273 }
274 } else {
275 for (int i = 0; i < numSnaps; i++) {
276 for (int j = 0; j < nStates; j++) {
277 indices[j][i] = i;
278 }
279 }
280 }
281
282
283 u_kn = new double[nStates][numSnaps];
284 N_k = new double[nStates];
285 for (int state = 0; state < nStates; state++) {
286 for (int n = 0; n < numSnaps; n++) {
287 u_kn[state][n] = eAllFlat[state][indices[state][n]] * invRTValues[state];
288 }
289 N_k[state] = (double) numSnaps / nStates;
290 }
291
292
293
294 double omega = 1.5;
295 for (int i = 0; i < 10; i++) {
296 prevMBAR = copyOf(mbarFreeEnergies, nStates);
297 mbarFreeEnergies = selfConsistentUpdate(u_kn, N_k, mbarFreeEnergies);
298
299 for (int j = 0; j < nStates; j++) {
300 mbarFreeEnergies[j] = omega * mbarFreeEnergies[j] + (1 - omega) * prevMBAR[j];
301 }
302
303 if (stream(mbarFreeEnergies).anyMatch(Double::isInfinite) || stream(mbarFreeEnergies).anyMatch(Double::isNaN)) {
304 throw new IllegalArgumentException("MBAR contains NaNs or Infs after iteration " + iter);
305 }
306 }
307
308 try {
309
310 if (nStates > 100) {
311 int mCorrections = 5;
312 double[] x = new double[nStates];
313 arraycopy(mbarFreeEnergies, 0, x, 0, nStates);
314 double[] grad = mbarGradient(u_kn, N_k, mbarFreeEnergies);
315 double eps = 1.0E-4;
316 OptimizationListener listener = getOptimizationListener();
317 LBFGS.minimize(nStates, mCorrections, x, mbarObjectiveFunction(u_kn, N_k, mbarFreeEnergies),
318 grad, eps, 1000, this, listener);
319 arraycopy(x, 0, mbarFreeEnergies, 0, nStates);
320 } else {
321 mbarFreeEnergies = newton(mbarFreeEnergies, u_kn, N_k, 1.0, 100, 1.0E-7);
322 }
323 } catch (Exception e) {
324 logger.warning(" L-BFGS/Newton failed to converge. Finishing w/ self-consistent iteration.");
325 logger.warning(e.getMessage());
326 }
327
328
329 do {
330 prevMBAR = copyOf(mbarFreeEnergies, nStates);
331 mbarFreeEnergies = selfConsistentUpdate(u_kn, N_k, mbarFreeEnergies);
332
333 for (int i = 0; i < nStates; i++) {
334 mbarFreeEnergies[i] = omega * mbarFreeEnergies[i] + (1 - omega) * prevMBAR[i];
335 }
336
337 if (stream(mbarFreeEnergies).anyMatch(Double::isInfinite) || stream(mbarFreeEnergies).anyMatch(Double::isNaN)) {
338 throw new IllegalArgumentException("MBAR contains NaNs or Infs after iteration " + iter);
339 }
340 iter++;
341 } while (!converged(prevMBAR));
342
343 logger.fine(" MBAR converged after " + iter + " iterations with omega " + omega + ".");
344
345
346 double f0 = mbarFreeEnergies[0];
347 for (int i = 0; i < nStates; i++) {
348 mbarFreeEnergies[i] -= f0;
349 }
350
351
352 mbarUncertainties = mbarUncertaintyCalc(u_kn, N_k, mbarFreeEnergies);
353 totalMBARUncertainty = mbarTotalUncertaintyCalc(u_kn, N_k, mbarFreeEnergies);
354 diffMatrix = diffMatrixCalculation(u_kn, N_k, mbarFreeEnergies);
355
356
357 for (int i = 0; i < nStates; i++) {
358 mbarFreeEnergies[i] = mbarFreeEnergies[i] * rtValues[i];
359 }
360
361 for (int i = 0; i < nFreeEnergyDiffs; i++) {
362 mbarEstimates[i] = mbarFreeEnergies[i + 1] - mbarFreeEnergies[i];
363 }
364
365 totalMBAREstimate = stream(mbarEstimates).sum();
366 }
367
368
369
370
371
372
373
374
375
376 private boolean converged(double[] prevMBAR) {
377 double[] differences = new double[prevMBAR.length];
378 for (int i = 0; i < prevMBAR.length; i++) {
379 differences[i] = abs(prevMBAR[i] - mbarFreeEnergies[i]);
380 }
381 return stream(differences).allMatch(d -> d < tolerance);
382 }
383
384
385
386
387
388
389
390
391
392
393
394 private static double mbarObjectiveFunction(double[][] u_kn, double[] N_k, double[] f_k) {
395 if (stream(f_k).anyMatch(Double::isInfinite) || stream(f_k).anyMatch(Double::isNaN)) {
396 throw new IllegalArgumentException("MBAR contains NaNs or Infs.");
397 }
398 int nStates = f_k.length;
399 double[] log_denom_n = new double[u_kn[0].length];
400 for (int i = 0; i < u_kn[0].length; i++) {
401 double[] temp = new double[nStates];
402 double maxTemp = Double.NEGATIVE_INFINITY;
403 for (int j = 0; j < nStates; j++) {
404 temp[j] = f_k[j] - u_kn[j][i];
405 if (temp[j] > maxTemp) {
406 maxTemp = temp[j];
407 }
408 }
409 log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
410 }
411 double[] dotNkFk = new double[N_k.length];
412 for (int i = 0; i < N_k.length; i++) {
413 dotNkFk[i] = N_k[i] * f_k[i];
414 }
415 return stream(log_denom_n).sum() - stream(dotNkFk).sum();
416 }
417
418
419
420
421
422
423
424
425
426 private static double[] mbarGradient(double[][] u_kn, double[] N_k, double[] f_k) {
427 int nStates = f_k.length;
428 double[] log_num_k = new double[nStates];
429 double[] log_denom_n = new double[u_kn[0].length];
430 double[][] logDiff = new double[u_kn.length][u_kn[0].length];
431 double maxLogDiff = Double.NEGATIVE_INFINITY;
432 for (int i = 0; i < u_kn[0].length; i++) {
433 double[] temp = new double[nStates];
434 double maxTemp = Double.NEGATIVE_INFINITY;
435 for (int j = 0; j < nStates; j++) {
436 temp[j] = f_k[j] - u_kn[j][i];
437 if (temp[j] > maxTemp) {
438 maxTemp = temp[j];
439 }
440 }
441 log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
442 for (int j = 0; j < nStates; j++) {
443 logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
444 if (logDiff[j][i] > maxLogDiff) {
445 maxLogDiff = logDiff[j][i];
446 }
447 }
448 }
449 for (int i = 0; i < nStates; i++) {
450 log_num_k[i] = logSumExp(logDiff[i], maxLogDiff);
451 }
452 double[] grad = new double[nStates];
453 for (int i = 0; i < nStates; i++) {
454 grad[i] = -1.0 * N_k[i] * (1.0 - exp(f_k[i] + log_num_k[i]));
455 }
456 return grad;
457 }
458
459
460
461
462
463
464
465
466
467 private static double[][] mbarHessian(double[][] u_kn, double[] N_k, double[] f_k) {
468 int nStates = f_k.length;
469 double[][] W = mbarW(u_kn, N_k, f_k);
470
471 double[][] hessian = new double[nStates][nStates];
472 for (int i = 0; i < nStates; i++) {
473 for (int j = 0; j < nStates; j++) {
474 double sum = 0.0;
475 for (int k = 0; k < u_kn[0].length; k++) {
476 sum += W[i][k] * W[j][k];
477 }
478 hessian[i][j] = sum * N_k[i] * N_k[j];
479 }
480 double wSum = 0.0;
481 for (int k = 0; k < W[i].length; k++) {
482 wSum += W[i][k];
483 }
484 hessian[i][i] -= wSum * N_k[i];
485 }
486
487 for (int i = 0; i < nStates; i++) {
488 for (int j = 0; j < nStates; j++) {
489 hessian[i][j] = -hessian[i][j];
490 }
491 }
492 return hessian;
493 }
494
495
496
497
498
499
500
501
502
503 private static double[][] mbarW(double[][] u_kn, double[] N_k, double[] f_k) {
504 int nStates = f_k.length;
505 double[] log_denom_n = new double[u_kn[0].length];
506 double[][] logDiff = new double[u_kn.length][u_kn[0].length];
507 double maxLogDiff = Double.NEGATIVE_INFINITY;
508 for (int i = 0; i < u_kn[0].length; i++) {
509 double[] temp = new double[nStates];
510 double maxTemp = Double.NEGATIVE_INFINITY;
511 for (int j = 0; j < nStates; j++) {
512 temp[j] = f_k[j] - u_kn[j][i];
513 if (temp[j] > maxTemp) {
514 maxTemp = temp[j];
515 }
516 }
517 log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
518 for (int j = 0; j < nStates; j++) {
519 logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
520 if (logDiff[j][i] > maxLogDiff) {
521 maxLogDiff = logDiff[j][i];
522 }
523 }
524 }
525
526 double[][] W = new double[nStates][u_kn[0].length];
527 for (int i = 0; i < nStates; i++) {
528 for (int j = 0; j < u_kn[0].length; j++) {
529 W[i][j] = exp(f_k[i] - u_kn[i][j] - log_denom_n[j]);
530 }
531 }
532 return W;
533 }
534
535
536
537
538
539
540
541
542
543 private static double[][] mbarLogW(double[][] u_kn, double[] N_k, double[] f_k) {
544 int nStates = f_k.length;
545
546 double[] log_denom_n = new double[u_kn[0].length];
547 double[][] logDiff = new double[u_kn.length][u_kn[0].length];
548 double maxLogDiff = Double.NEGATIVE_INFINITY;
549 for (int i = 0; i < u_kn[0].length; i++) {
550 double[] temp = new double[nStates];
551 double maxTemp = Double.NEGATIVE_INFINITY;
552 for (int j = 0; j < nStates; j++) {
553 temp[j] = f_k[j] - u_kn[j][i];
554 if (temp[j] > maxTemp) {
555 maxTemp = temp[j];
556 }
557 }
558 log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
559 for (int j = 0; j < nStates; j++) {
560 logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
561 if (logDiff[j][i] > maxLogDiff) {
562 maxLogDiff = logDiff[j][i];
563 }
564 }
565 }
566
567 double[][] logW = new double[nStates][u_kn[0].length];
568 for (int i = 0; i < nStates; i++) {
569 for (int j = 0; j < u_kn[0].length; j++) {
570 logW[i][j] = f_k[i] - u_kn[i][j] - log_denom_n[j];
571 }
572 }
573 return logW;
574 }
575
576
577
578
579
580
581
582
583
584
585
586
587 private static double[][] mbarTheta(double[][] u_kn, double[] N_k, double[] f_k) {
588
589 double[][] W = mbarW(u_kn, N_k, f_k);
590 RealMatrix WMatrix = MatrixUtils.createRealMatrix(W).transpose();
591 RealMatrix I = MatrixUtils.createRealIdentityMatrix(f_k.length);
592 RealMatrix NkMatrix = MatrixUtils.createRealDiagonalMatrix(N_k);
593 SingularValueDecomposition svd = new SingularValueDecomposition(WMatrix);
594 RealMatrix V = svd.getV();
595 RealMatrix S = MatrixUtils.createRealDiagonalMatrix(svd.getSingularValues());
596
597
598
599 RealMatrix theta = S.multiply(V.transpose());
600 theta = theta.multiply(NkMatrix).multiply(V).multiply(S);
601 theta = I.subtract(theta);
602 theta = new SingularValueDecomposition(theta).getSolver().getInverse();
603 theta = V.multiply(S).multiply(theta).multiply(S).multiply(V.transpose());
604
605 return theta.getData();
606 }
607
608
609
610
611
612
613
614
615
616 private static double[] mbarUncertaintyCalc(double[][] u_kn, double[] N_k, double[] f_k) {
617 double[][] theta = mbarTheta(u_kn, N_k, f_k);
618 double[] uncertainties = new double[f_k.length - 1];
619
620 for (int i = 0; i < f_k.length - 1; i++) {
621 uncertainties[i] = sqrt(theta[i][i] - 2 * theta[i][i + 1] + theta[i + 1][i + 1]);
622 }
623 return uncertainties;
624 }
625
626
627
628
629
630
631
632
633
634 private static double mbarTotalUncertaintyCalc(double[][] u_kn, double[] N_k, double[] f_k) {
635 double[][] theta = mbarTheta(u_kn, N_k, f_k);
636 int nStates = f_k.length;
637 return sqrt(theta[0][0] - 2 * theta[0][nStates - 1] + theta[nStates - 1][nStates - 1]);
638 }
639
640
641
642
643
644
645
646
647
648 private static double[][] diffMatrixCalculation(double[][] u_kn, double[] N_k, double[] f_k) {
649 double[][] theta = mbarTheta(u_kn, N_k, f_k);
650 double[][] diffMatrix = new double[f_k.length][f_k.length];
651 for (int i = 0; i < f_k.length; i++) {
652 for (int j = 0; j < f_k.length; j++) {
653 diffMatrix[i][j] = sqrt(theta[i][i] - 2 * theta[i][j] + theta[j][j]);
654 }
655 }
656 return diffMatrix;
657 }
658
659
660
661
662
663
664
665
666
667
668
669 private static double[] selfConsistentUpdate(double[][] u_kn, double[] N_k, double[] f_k) {
670 int nStates = f_k.length;
671 double[] updatedF_k = new double[nStates];
672 double[] log_denom_n = new double[u_kn[0].length];
673 double[][] logDiff = new double[u_kn.length][u_kn[0].length];
674 double[] maxLogDiff = new double[nStates];
675 fill(maxLogDiff, Double.NEGATIVE_INFINITY);
676 for (int i = 0; i < u_kn[0].length; i++) {
677 double[] temp = new double[nStates];
678 double maxTemp = Double.NEGATIVE_INFINITY;
679 for (int j = 0; j < nStates; j++) {
680 temp[j] = f_k[j] - u_kn[j][i];
681 if (temp[j] > maxTemp) {
682 maxTemp = temp[j];
683 }
684 }
685 log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
686 for (int j = 0; j < nStates; j++) {
687 logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
688 if (logDiff[j][i] > maxLogDiff[j]) {
689 maxLogDiff[j] = logDiff[j][i];
690 }
691 }
692 }
693
694 for (int i = 0; i < nStates; i++) {
695 updatedF_k[i] = -1.0 * logSumExp(logDiff[i], maxLogDiff[i]);
696 }
697
698
699 double norm = updatedF_k[0];
700 updatedF_k[0] = 0.0;
701 for (int i = 1; i < nStates; i++) {
702 updatedF_k[i] = updatedF_k[i] - norm;
703 }
704
705 return updatedF_k;
706 }
707
708
709
710
711
712
713
714
715
716
717 private static double[] newtonStep(double[] n, double[] grad, double[][] hessian, double stepSize) {
718 double[] nPlusOne = new double[n.length];
719 RealMatrix hessianInverse = MatrixUtils.inverse(MatrixUtils.createRealMatrix(hessian));
720 double[] step = hessianInverse.preMultiply(grad);
721
722 double temp = step[0];
723 step[0] = 0.0;
724 for (int i = 1; i < step.length; i++) {
725 step[i] -= temp;
726 }
727 for (int i = 0; i < n.length; i++) {
728 nPlusOne[i] = n[i] - step[i] * stepSize;
729 }
730 return nPlusOne;
731 }
732
733
734
735
736
737
738
739
740
741
742
743
744 private static double[] newton(double[] f_k, double[][] u_kn, double[] N_k, double stepSize, int maxIter, double tolerance) {
745 double[] grad = mbarGradient(u_kn, N_k, f_k);
746 double[][] hessian = mbarHessian(u_kn, N_k, f_k);
747 double[] f_kPlusOne = newtonStep(f_k, grad, hessian, stepSize);
748 int iter = 1;
749 while (iter < maxIter && MathArrays.distance1(f_k, f_kPlusOne) > tolerance) {
750 f_k = f_kPlusOne;
751 grad = mbarGradient(u_kn, N_k, f_k);
752 hessian = mbarHessian(u_kn, N_k, f_k);
753 f_kPlusOne = newtonStep(f_k, grad, hessian, stepSize);
754 iter++;
755 }
756
757 logger.fine(" Newton converged after " + iter + " iterations.");
758
759 return f_kPlusOne;
760 }
761
762
763
764
765
766
767
768
769
770
771 private static double logSumExp(double[] values, double max) {
772 double[] b = fill(new double[values.length], 1.0);
773 return logSumExp(values, b, max);
774 }
775
776
777
778
779
780
781
782
783
784
785
786 private static double logSumExp(double[] values, double[] b, double max) {
787
788
789 assert values.length == b.length : "values and b must be the same length";
790
791
792 double sum = 0.0;
793 for (int i = 0; i < values.length; i++) {
794 sum += b[i] * exp(values[i] - max);
795 }
796
797
798 return max + log(sum);
799 }
800
801
802
803
804
805
806 private OptimizationListener getOptimizationListener() {
807 return new OptimizationListener() {
808 @Override
809 public boolean optimizationUpdate(int iter, int nBFGS, int nFunctionEvals, double gradientRMS,
810 double coordinateRMS, double f, double df, double angle,
811 LineSearch.LineSearchResult info) {
812 return true;
813 }
814 };
815 }
816
817
818
819
820
821
822
823 @Override
824 public double energy(double[] x) {
825
826 double tempO = x[0];
827 x[0] = 0.0;
828 for (int i = 1; i < x.length; i++) {
829 x[i] -= tempO;
830 }
831 return mbarObjectiveFunction(u_kn, N_k, x);
832 }
833
834
835
836
837
838
839
840
841 @Override
842 public double energyAndGradient(double[] x, double[] g) {
843 double tempO = x[0];
844 x[0] = 0.0;
845 for (int i = 1; i < x.length; i++) {
846 x[i] -= tempO;
847 }
848 double[] tempG = mbarGradient(u_kn, N_k, x);
849 arraycopy(tempG, 0, g, 0, g.length);
850 return mbarObjectiveFunction(u_kn, N_k, x);
851 }
852
853 @Override
854 public double[] getCoordinates(double[] parameters) {
855 return new double[0];
856 }
857
858 @Override
859 public int getNumberOfVariables() {
860 return 0;
861 }
862
863 @Override
864 public double[] getScaling() {
865 return null;
866 }
867
868 @Override
869 public void setScaling(double[] scaling) {
870 }
871
872 @Override
873 public double getTotalEnergy() {
874 return 0;
875 }
876
877
878 public BennettAcceptanceRatio getBAR() {
879 return new BennettAcceptanceRatio(lamValues, eLow, eAt, eHigh, temperatures);
880 }
881
882 @Override
883 public MultistateBennettAcceptanceRatio copyEstimator() {
884 return new MultistateBennettAcceptanceRatio(lamValues, eAll, temperatures, tolerance, seedType);
885 }
886
887 @Override
888 public double[] getBinEnergies() {
889 return mbarEstimates;
890 }
891
892 public double[] getMBARFreeEnergies() {
893 return mbarFreeEnergies;
894 }
895
896 @Override
897 public double[] getBinUncertainties() {
898 return mbarUncertainties;
899 }
900
901 public double[][] getDiffMatrix() {
902 return diffMatrix;
903 }
904
905 @Override
906 public double getFreeEnergy() {
907 return totalMBAREstimate;
908 }
909
910 @Override
911 public double getUncertainty() {
912 return totalMBARUncertainty;
913 }
914
915 @Override
916 public int numberOfBins() {
917 return nFreeEnergyDiffs;
918 }
919
920 @Override
921 public double[] getBinEnthalpies() {
922 return mbarEnthalpy;
923 }
924
925
926
927
928 public static class HarmonicOscillatorsTestCase {
929
930
931
932
933 private final double beta;
934
935
936
937 private final double[] O_k;
938
939
940
941 private final int n_states;
942
943
944
945 private final double[] K_k;
946
947
948
949
950
951
952
953
954 public HarmonicOscillatorsTestCase(double[] O_k, double[] K_k, double beta) {
955 this.beta = beta;
956 this.O_k = O_k;
957 this.n_states = O_k.length;
958 this.K_k = K_k;
959
960 if (this.K_k.length != this.n_states) {
961 throw new IllegalArgumentException("Lengths of K_k and O_k should be equal");
962 }
963 }
964
965 public double[] analyticalMeans() {
966 return O_k;
967 }
968
969 public double[] analyticalVariances() {
970 double[] variances = new double[n_states];
971 for (int i = 0; i < n_states; i++) {
972 variances[i] = 1.0 / (beta * K_k[i]);
973 }
974 return variances;
975 }
976
977 public double[] analyticalStandardDeviations() {
978 double[] deviations = new double[n_states];
979 for (int i = 0; i < n_states; i++) {
980 deviations[i] = Math.sqrt(1.0 / (beta * K_k[i]));
981 }
982 return deviations;
983 }
984
985 public double[] analyticalObservable(String observable) {
986 double[] result = new double[n_states];
987
988 switch (observable) {
989 case "position" -> {
990 return analyticalMeans();
991 }
992 case "potential energy" -> {
993 for (int i = 0; i < n_states; i++) {
994 result[i] = 0.5 / beta;
995 }
996 }
997 case "position^2" -> {
998 for (int i = 0; i < n_states; i++) {
999 result[i] = 1.0 / (beta * K_k[i]) + Math.pow(O_k[i], 2);
1000 }
1001 }
1002 case "RMS displacement" -> {
1003 return analyticalStandardDeviations();
1004 }
1005 }
1006
1007 return result;
1008 }
1009
1010 public double[] analyticalFreeEnergies() {
1011 int subtractComponentIndex = 0;
1012 double[] fe = new double[n_states];
1013 double subtract = 0.0;
1014 for (int i = 0; i < n_states; i++) {
1015 fe[i] = -0.5 * Math.log(2 * Math.PI / (beta * K_k[i]));
1016 if (i == 0) {
1017 subtract = fe[subtractComponentIndex];
1018 }
1019 fe[i] -= subtract;
1020 }
1021 return fe;
1022 }
1023
1024 public double[] analyticalEntropies(int subtractComponent) {
1025 double[] entropies = new double[n_states];
1026 double[] potentialEnergy = analyticalObservable("analytical entropy");
1027 double[] freeEnergies = analyticalFreeEnergies();
1028
1029 for (int i = 0; i < n_states; i++) {
1030 entropies[i] = potentialEnergy[i] - freeEnergies[i];
1031 }
1032
1033 return entropies;
1034 }
1035
1036
1037
1038
1039
1040
1041
1042
1043 public Object[] sample(int[] N_k, String mode, Long seed) {
1044 Random random = new Random(seed);
1045
1046 int N_max = 0;
1047 for (int N : N_k) {
1048 if (N > N_max) {
1049 N_max = N;
1050 }
1051 }
1052
1053 int N_tot = 0;
1054 for (int N : N_k) {
1055 N_tot += N;
1056 }
1057
1058 double[][] x_kn = new double[n_states][N_max];
1059 double[][] u_kn = new double[n_states][N_tot];
1060 double[][][] u_kln = new double[n_states][n_states][N_max];
1061 double[] x_n = new double[N_tot];
1062 int[] s_n = new int[N_tot];
1063
1064
1065 int index = 0;
1066 for (int k = 0; k < n_states; k++) {
1067 double x0 = O_k[k];
1068 double sigma = Math.sqrt(1.0 / (beta * K_k[k]));
1069
1070
1071 for (int n = 0; n < N_k[k]; n++) {
1072 double x = x0 + random.nextGaussian() * sigma;
1073
1074 x_kn[k][n] = x;
1075 x_n[index] = x;
1076 s_n[index] = k;
1077
1078
1079 for (int l = 0; l < n_states; l++) {
1080 double u = beta * 0.5 * K_k[l] * Math.pow(x - O_k[l], 2.0);
1081 u_kln[k][l][n] = u;
1082 u_kn[l][index] = u;
1083 }
1084
1085 index++;
1086 }
1087 }
1088
1089
1090 if ("u_kn".equals(mode)) {
1091 return new Object[]{x_n, u_kn, N_k, s_n};
1092 } else if ("u_kln".equals(mode)) {
1093 return new Object[]{x_n, u_kln, N_k, s_n};
1094 } else {
1095 throw new IllegalArgumentException("Unknown mode: " + mode);
1096 }
1097 }
1098
1099 public static Object[] evenlySpacedOscillators(
1100 int n_states, int n_samplesPerState, double lower_O_k, double upper_O_k,
1101 double lower_K_k, double upper_K_k, Long seed) {
1102
1103
1104 double[] O_k = new double[n_states];
1105 double[] K_k = new double[n_states];
1106 int[] N_k = new int[n_states];
1107
1108 double stepO_k = (upper_O_k - lower_O_k) / (n_states - 1);
1109 double stepK_k = (upper_K_k - lower_K_k) / (n_states - 1);
1110
1111 for (int i = 0; i < n_states; i++) {
1112 O_k[i] = lower_O_k + i * stepO_k;
1113 K_k[i] = lower_K_k + i * stepK_k;
1114 N_k[i] = n_samplesPerState;
1115 }
1116
1117 HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, 1.0);
1118 Object[] result = testCase.sample(N_k, "u_kn", System.currentTimeMillis());
1119
1120 return new Object[]{testCase, result[0], result[1], result[2], result[3]};
1121 }
1122
1123 public static void main(String[] args) {
1124
1125 double[] O_k = {0, 1, 2, 3, 4};
1126 double[] K_k = {1, 2, 4, 8, 16};
1127 double beta = 1.0;
1128 System.out.println("Beta: " + beta);
1129
1130
1131 HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);
1132
1133
1134 System.out.println("Analytical Means: " + Arrays.toString(testCase.analyticalMeans()));
1135 System.out.println("Analytical Variances: " + Arrays.toString(testCase.analyticalVariances()));
1136 System.out.println("Analytical Standard Deviations: " + Arrays.toString(testCase.analyticalStandardDeviations()));
1137 System.out.println("Analytical Free Energies: " + Arrays.toString(testCase.analyticalFreeEnergies()));
1138
1139
1140 int[] N_k = {10, 20, 30, 40, 50};
1141 String setting = "u_kln";
1142 Object[] sampleResult = testCase.sample(N_k, setting, System.currentTimeMillis());
1143
1144 System.out.println("Sample x_n: " + Arrays.toString((double[]) sampleResult[0]));
1145 if ("u_kn".equals(setting)) {
1146 System.out.println("Sample u_kn: " + Arrays.deepToString((double[][]) sampleResult[1]));
1147 } else {
1148 System.out.println("Sample u_kln: " + Arrays.deepToString((double[][][]) sampleResult[1]));
1149 }
1150 System.out.println("Sample N_k: " + Arrays.toString((int[]) sampleResult[2]));
1151 System.out.println("Sample s_n: " + Arrays.toString((int[]) sampleResult[3]));
1152 }
1153 }
1154
1155 public static void writeFile(double[][] energies, File file, double temperature) {
1156 try (FileWriter fw = new FileWriter(file);
1157 BufferedWriter bw = new BufferedWriter(fw)) {
1158
1159 bw.write(energies[0].length + " " + temperature);
1160 bw.newLine();
1161
1162
1163 StringBuilder sb = new StringBuilder();
1164 for (int i = 0; i < energies[0].length; i++) {
1165 sb.append(" ").append(i).append(" ");
1166 for (int j = 0; j < energies.length; j++) {
1167 sb.append(" ").append(energies[j][i]).append(" ");
1168 }
1169 sb.append("\n");
1170 bw.write(sb.toString());
1171 sb = new StringBuilder();
1172 }
1173 } catch (IOException e) {
1174 e.printStackTrace();
1175 }
1176 }
1177
1178 public static void main(String[] args) {
1179 double[] O_k = {1, 2, 3, 4};
1180 double[] K_k = {.5, 1.0, 1.5, 2};
1181 int[] N_k = {10000, 10000, 10000, 10000};
1182 double beta = 1.0;
1183
1184
1185 HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);
1186
1187
1188 String setting = "u_kln";
1189 System.out.print("Generating sample data... ");
1190 Object[] sampleResult = testCase.sample(N_k, setting, (long) 0);
1191 System.out.println("done. \n");
1192 double[][][] u_kln = (double[][][]) sampleResult[1];
1193 double[] temps = {1 / Constants.R};
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214 System.out.print("Creating MBAR instance and estimateDG() with standard tol & Zwanzig seeding.");
1215 MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0E-7, SeedType.ZWANZIG);
1216 double[] mbarFEEstimates = Arrays.copyOf(mbar.mbarFreeEnergies, mbar.mbarFreeEnergies.length);
1217 double[] mbarUncertainties = Arrays.copyOf(mbar.mbarUncertainties, mbar.mbarUncertainties.length);
1218 double[][] mbarDiffMatrix = Arrays.copyOf(mbar.diffMatrix, mbar.diffMatrix.length);
1219
1220 EstimateBootstrapper bootstrapper = new EstimateBootstrapper(mbar);
1221 bootstrapper.bootstrap(50);
1222 System.out.println("done. \n");
1223
1224
1225 double[] analyticalFreeEnergies = testCase.analyticalFreeEnergies();
1226
1227 double[] error = new double[analyticalFreeEnergies.length];
1228 for (int i = 0; i < error.length; i++) {
1229 error[i] = -mbarFEEstimates[i] + analyticalFreeEnergies[i];
1230 }
1231
1232
1233 System.out.println("MBAR Free Energies: " + Arrays.toString(mbarFEEstimates));
1234 System.out.println("Analytical Free Energies: " + Arrays.toString(analyticalFreeEnergies));
1235 System.out.println("MBAR Uncertainties: " + Arrays.toString(mbarUncertainties));
1236 System.out.println("Free Energy Error: " + Arrays.toString(error));
1237 System.out.println();
1238 System.out.println("Diff Matrix: ");
1239 for (double[] matrix : mbarDiffMatrix) {
1240 System.out.println(Arrays.toString(matrix));
1241 }
1242 System.out.println("\n\n");
1243
1244
1245 double[] mbarBootstrappedEstimates = bootstrapper.getFE();
1246 double[] mbarBootstrappedFE = new double[mbarBootstrappedEstimates.length + 1];
1247 for (int i = 0; i < mbarBootstrappedEstimates.length; i++) {
1248 mbarBootstrappedFE[i + 1] = mbarBootstrappedEstimates[i] + mbarBootstrappedFE[i];
1249 }
1250 mbarUncertainties = bootstrapper.getUncertainty();
1251
1252 double[] errors = new double[mbarBootstrappedFE.length];
1253 for (int i = 0; i < errors.length; i++) {
1254 errors[i] = -mbarBootstrappedFE[i] + analyticalFreeEnergies[i];
1255 }
1256
1257 System.out.println("MBAR Bootstrapped Estimates: " + Arrays.toString(mbarBootstrappedFE));
1258 System.out.println("Analytical Estimates: " + Arrays.toString(analyticalFreeEnergies));
1259 System.out.println("MBAR Bootstrap Uncertainties: " + Arrays.toString(mbarUncertainties));
1260 System.out.println("Bootstrap Free Energy Error: " + Arrays.toString(errors));
1261 }
1262 }