1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.estimator;
39
40 import ffx.numerics.math.SummaryStatistics;
41
42 import java.util.Random;
43 import java.util.logging.Level;
44 import java.util.logging.Logger;
45
46 import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
47 import static ffx.numerics.estimator.Zwanzig.Directionality.BACKWARDS;
48 import static ffx.numerics.estimator.Zwanzig.Directionality.FORWARDS;
49 import static ffx.numerics.math.ScalarMath.fermiFunction;
50 import static ffx.utilities.Constants.R;
51 import static java.lang.Double.isInfinite;
52 import static java.lang.Double.isNaN;
53 import static java.lang.String.format;
54 import static java.util.Arrays.copyOf;
55 import static java.util.Arrays.fill;
56 import static java.util.Arrays.stream;
57 import static org.apache.commons.math3.util.FastMath.abs;
58 import static org.apache.commons.math3.util.FastMath.log;
59 import static org.apache.commons.math3.util.FastMath.sqrt;
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81 public class BennettAcceptanceRatio extends SequentialEstimator implements BootstrappableEstimator {
82
83 private static final Logger logger = Logger.getLogger(BennettAcceptanceRatio.class.getName());
84
85
86
87
88 private static final double DEFAULT_TOLERANCE = 1.0E-4;
89
90
91
92 private static final int DEFAULT_MAX_BAR_ITERATIONS = 1000;
93
94
95
96 private final int nWindows;
97
98
99
100 private final double tolerance;
101
102
103
104 private final int nIterations;
105
106
107
108 private final Zwanzig forwardsFEP;
109
110
111
112 private final Zwanzig backwardsFEP;
113
114
115
116 private final Random random;
117
118
119
120 private double totalFreeEnergyDifference;
121
122
123
124 private double totalFEDifferenceUncertainty;
125
126
127
128 private final double[] freeEnergyDifferences;
129
130
131
132 private final double[] freeEnergyDifferenceUncertainties;
133
134
135
136 private final double[] enthalpyDifferences;
137
138
139
140 private final double[] forwardZwanzigFEDifferences;
141
142
143
144 private final double[] backwardZwanzigFEDifferences;
145
146
147
148
149
150
151
152
153
154
155 public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda,
156 double[][] eLambdaPlusdL, double[] temperature) {
157 this(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, DEFAULT_TOLERANCE);
158 }
159
160
161
162
163
164
165
166
167
168
169
170 public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda,
171 double[][] eLambdaPlusdL, double[] temperature, double tolerance) {
172 this(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, tolerance, DEFAULT_MAX_BAR_ITERATIONS);
173 }
174
175
176
177
178
179
180
181
182
183
184
185
186 public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda,
187 double[][] eLambdaPlusdL, double[] temperature, double tolerance, int nIterations) {
188
189 super(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature);
190
191
192 forwardsFEP = new Zwanzig(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, FORWARDS);
193 backwardsFEP = new Zwanzig(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, BACKWARDS);
194
195 nWindows = nStates - 1;
196 forwardZwanzigFEDifferences = forwardsFEP.getFreeEnergyDifferences();
197 backwardZwanzigFEDifferences = backwardsFEP.getFreeEnergyDifferences();
198
199 freeEnergyDifferences = new double[nWindows];
200 freeEnergyDifferenceUncertainties = new double[nWindows];
201 enthalpyDifferences = new double[nWindows];
202 this.tolerance = tolerance;
203 this.nIterations = nIterations;
204 random = new Random();
205
206 estimateDG();
207 }
208
209
210
211
212
213
214
215
216
217
218
219
220
221 private static void fermiDiffIterative(double[] e0, double[] e1, double[] fermiDiffs, int len,
222 double c, double invRT) {
223 for (int i = 0; i < len; i++) {
224 fermiDiffs[i] = fermiFunction(invRT * (e0[i] - e1[i] + c));
225 }
226 if (stream(fermiDiffs).sum() == 0) {
227 logger.warning(format(" Input Fermi with length %3d should not be permitted: c: %9.4f invRT: %9.4f Fermi output: %9.4f", len, c, invRT, stream(fermiDiffs).sum()));
228 }
229 }
230
231
232
233
234
235
236
237
238
239
240
241 private void calcAlphaForward(double[] e0, double[] e1, int len, double c,
242 double invRT, double[] ret) {
243 double fsum = 0;
244 double fvsum = 0;
245 double fbvsum = 0;
246 double vsum = 0;
247 double fbsum = 0;
248 for (int i = 0; i < len; i++) {
249 double fore = fermiFunction(invRT * (e1[i] - e0[i] - c));
250 double back = fermiFunction(invRT * (e0[i] - e1[i] + c));
251 fsum += fore;
252 fvsum += fore * e0[i];
253 fbvsum += fore * back * (e1[i] - e0[i]);
254 vsum += e0[i];
255 fbsum += fore * back;
256 }
257 double alpha = fvsum - (fsum * (vsum / len)) + fbvsum;
258 ret[0] = alpha;
259 ret[1] = fbsum;
260 }
261
262
263
264
265
266
267
268
269
270
271
272 private void calcAlphaBackward(double[] e0, double[] e1, int len, double c,
273 double invRT, double[] ret) {
274 double bsum = 0;
275 double bvsum = 0;
276 double fbvsum = 0;
277 double vsum = 0;
278 double fbsum = 0;
279 for (int i = 0; i < len; i++) {
280 double fore = fermiFunction(invRT * (e1[i] - e0[i] - c));
281 double back = fermiFunction(invRT * (e0[i] - e1[i] + c));
282 bsum += back;
283 bvsum += back * e1[i];
284 fbvsum += fore * back * (e1[i] - e0[i]);
285 vsum += e1[i];
286 fbsum += fore * back;
287 }
288 double alpha = bvsum - (bsum * (vsum / len)) - fbvsum;
289 ret[0] = alpha;
290 ret[1] = fbsum;
291 }
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307 private static void fermiDiffBootstrap(double[] e0, double[] e1, double[] fermiDiffs,
308 int len, double c, double invRT, int[] bootstrapSamples) {
309 for (int indexI = 0; indexI < len; indexI++) {
310 int i = bootstrapSamples[indexI];
311 fermiDiffs[indexI] = fermiFunction(invRT * (e0[i] - e1[i] + c));
312 }
313 }
314
315
316
317
318
319
320
321
322
323 private static double uncertaintyCalculation(double meanFermi, double meanSqFermi, int len) {
324 double sqMeanFermi = meanFermi * meanFermi;
325 return ((meanSqFermi - sqMeanFermi) / len) / sqMeanFermi;
326 }
327
328
329
330
331
332
333 public Zwanzig getInitialBackwardsGuess() {
334 return backwardsFEP;
335 }
336
337
338
339
340
341
342 public Zwanzig getInitialForwardsGuess() {
343 return forwardsFEP;
344 }
345
346
347
348
349 @Override
350 public BennettAcceptanceRatio copyEstimator() {
351 return new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, tolerance, nIterations);
352 }
353
354
355
356
357
358
359
360
361
362 @Override
363 public final void estimateDG(final boolean randomSamples) {
364 double cumDG = 0;
365 fill(freeEnergyDifferences, 0);
366 fill(freeEnergyDifferenceUncertainties, 0);
367 fill(enthalpyDifferences, 0);
368
369
370 Level warningLevel = randomSamples ? Level.FINE : Level.WARNING;
371
372 for (int i = 0; i < nWindows; i++) {
373
374 if (isNaN(forwardZwanzigFEDifferences[i]) || isInfinite(forwardZwanzigFEDifferences[i])
375 || isNaN(backwardZwanzigFEDifferences[i]) || isInfinite(backwardZwanzigFEDifferences[i])) {
376 logger.warning(format(" Window %3d bin energies produced unreasonable value(s) for forward Zwanzig (%8.4f) and/or backward Zwanzig (%8.4f)", i, forwardZwanzigFEDifferences[i], backwardZwanzigFEDifferences[i]));
377 }
378 double c = 0.5 * (forwardZwanzigFEDifferences[i] + backwardZwanzigFEDifferences[i]);
379
380
381 if (!randomSamples) {
382 logger.fine(format(" BAR Iteration Seed: %12.4f Kcal/mol", c));
383 }
384
385 double cold = c;
386 int len0 = eLambda[i].length;
387 int len1 = eLambda[i + 1].length;
388
389 if (len0 == 0 || len1 == 0) {
390 freeEnergyDifferences[i] = c;
391 logger.log(warningLevel, format(" Window %d has no snapshots at one end (%d, %d)!", i, len0, len1));
392 continue;
393 }
394
395
396 double sampleRatio = ((double) len0) / ((double) len1);
397
398
399 double[] fermi0 = new double[len0];
400 double[] fermi1 = new double[len1];
401 double[] ret = new double[2];
402
403
404 double rta = R * temperatures[i];
405 double rtb = R * temperatures[i + 1];
406 double rtMean = 0.5 * (rta + rtb);
407 double invRTA = 1.0 / rta;
408 double invRTB = 1.0 / rtb;
409
410
411 SummaryStatistics s1 = null;
412
413 SummaryStatistics s0 = null;
414
415
416 int[] bootstrapSamples0 = null;
417 int[] bootstrapSamples1 = null;
418
419 if (randomSamples) {
420 bootstrapSamples0 = getBootstrapIndices(len0, random);
421 bootstrapSamples1 = getBootstrapIndices(len1, random);
422 }
423
424 int cycleCounter = 0;
425 boolean converged = false;
426 while (!converged) {
427 if (randomSamples) {
428 fermiDiffBootstrap(eLambdaPlusdL[i], eLambda[i], fermi0, len0, -c, invRTA, bootstrapSamples0);
429 fermiDiffBootstrap(eLambdaMinusdL[i + 1], eLambda[i + 1], fermi1, len1, c, invRTB, bootstrapSamples1);
430 } else {
431 fermiDiffIterative(eLambdaPlusdL[i], eLambda[i], fermi0, len0, -c, invRTA);
432 fermiDiffIterative(eLambdaMinusdL[i + 1], eLambda[i + 1], fermi1, len1, c, invRTB);
433 }
434
435 s0 = new SummaryStatistics(fermi0);
436 s1 = new SummaryStatistics(fermi1);
437 double ratio = s1.sum / s0.sum;
438 c += rtMean * log(sampleRatio * ratio);
439
440 cycleCounter++;
441 converged = (abs(c - cold) < tolerance);
442
443 if (!randomSamples && !converged && cycleCounter > nIterations) {
444 throw new IllegalArgumentException(
445 format(" BAR required too many iterations (%d) to converge! (%9.8f > %9.8f)", cycleCounter, abs(c - cold), tolerance));
446 }
447
448 if (!randomSamples) {
449 logger.fine(format(" BAR Iteration %2d: %12.4f Kcal/mol", cycleCounter, c));
450 }
451 cold = c;
452 }
453
454 freeEnergyDifferences[i] = c;
455 cumDG += c;
456 double sqFermiMean0 = new SummaryStatistics(stream(fermi0).map((double d) -> d * d).toArray()).mean;
457 double sqFermiMean1 = new SummaryStatistics(stream(fermi1).map((double d) -> d * d).toArray()).mean;
458 freeEnergyDifferenceUncertainties[i] = sqrt(uncertaintyCalculation(s0.mean, sqFermiMean0, len0)
459 + uncertaintyCalculation(s1.mean, sqFermiMean1, len1));
460
461 calcAlphaForward(eLambda[i], eLambdaPlusdL[i], len0, c, invRTA, ret);
462 double alpha0 = ret[0];
463 double fbsum0 = ret[1];
464
465 calcAlphaBackward(eLambdaMinusdL[i + 1], eLambda[i + 1], len1, c, invRTB, ret);
466 double alpha1 = ret[0];
467 double fbsum1 = ret[1];
468
469 double hBar = (alpha0 - alpha1) / (fbsum0 + fbsum1);
470 enthalpyDifferences[i] = hBar;
471 }
472
473 totalFreeEnergyDifference = cumDG;
474 totalFEDifferenceUncertainty = sqrt(stream(freeEnergyDifferenceUncertainties).map((double d) -> d * d).sum());
475 }
476
477
478
479
480 @Override
481 public double[] getFreeEnergyDifferences() {
482 return copyOf(freeEnergyDifferences, nWindows);
483 }
484
485
486
487
488 @Override
489 public double[] getFEDifferenceUncertainties() {
490 return copyOf(freeEnergyDifferenceUncertainties, nWindows);
491 }
492
493
494
495
496 @Override
497 public double getTotalFreeEnergyDifference() {
498 return totalFreeEnergyDifference;
499 }
500
501
502
503
504 @Override
505 public double getTotalFEDifferenceUncertainty() {
506 return totalFEDifferenceUncertainty;
507 }
508
509
510
511
512 @Override
513 public int getNumberOfBins() {
514 return nWindows;
515 }
516
517
518
519
520 @Override
521 public double getTotalEnthalpyDifference() {
522 return getTotalEnthalpyDifference(enthalpyDifferences);
523 }
524
525
526
527
528 @Override
529 public double[] getEnthalpyDifferences() {
530 return enthalpyDifferences;
531 }
532 }