View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.estimator;
39  
40  import ffx.numerics.math.SummaryStatistics;
41  
42  import java.util.Random;
43  import java.util.logging.Level;
44  import java.util.logging.Logger;
45  
46  import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
47  import static ffx.numerics.estimator.Zwanzig.Directionality.BACKWARDS;
48  import static ffx.numerics.estimator.Zwanzig.Directionality.FORWARDS;
49  import static ffx.numerics.math.ScalarMath.fermiFunction;
50  import static ffx.utilities.Constants.R;
51  import static java.lang.Double.isInfinite;
52  import static java.lang.Double.isNaN;
53  import static java.lang.String.format;
54  import static java.util.Arrays.copyOf;
55  import static java.util.Arrays.fill;
56  import static java.util.Arrays.stream;
57  import static org.apache.commons.math3.util.FastMath.abs;
58  import static org.apache.commons.math3.util.FastMath.log;
59  import static org.apache.commons.math3.util.FastMath.sqrt;
60  
61  /**
62   * The Bennett Acceptance Ratio class implements the Bennett Acceptance Ratio (BAR) statistical
63   * estimator, based on the Tinker implementation.
64   *
65   * <p>Literature References (from Tinker): C. H. Bennett, "Efficient Estimation of Free Energy
66   * Differences from Monte Carlo Data", Journal of Computational Physics, 22, 245-268 (1976)
67   *
68   * <p>M. A. Wyczalkowski, A. Vitalis and R. V. Pappu, "New Estimators for Calculating Solvation
69   * Entropy and Enthalpy and Comparative Assessments of Their Accuracy and Precision", Journal of
70   * Physical Chemistry B, 114, 8166-8180 (2010) [modified BAR algorithm, non-implemented
71   * entropy/enthalpy]
72   *
73   * <p>K. B. Daly, J. B. Benziger, P. G. Debenedetti and A. Z. Panagiotopoulos, "Massively Parallel
74   * Chemical Potential Calculation on Graphics Processing Units", Computer Physics Communications,
75   * 183, 2054-2062 (2012) [non-implemented NPT modification]
76   *
77   * @author Michael J. Schnieders
78   * @author Jacob M. Litman
79   * @since 1.0
80   */
81  public class BennettAcceptanceRatio extends SequentialEstimator implements BootstrappableEstimator {
82  
83    private static final Logger logger = Logger.getLogger(BennettAcceptanceRatio.class.getName());
84  
85    /**
86     * Default BAR convergence tolerance, used when a constructor is not given one.
87     */
88    private static final double DEFAULT_TOLERANCE = 1.0E-4;
89    /**
90     * Default maximum number of BAR iterations, used when a constructor is not given one.
91     */
92    private static final int DEFAULT_MAX_BAR_ITERATIONS = 1000;
93    /**
94     * Number of windows (pairs of adjacent states); assigned nStates - 1 by the constructor.
95     */
96    private final int nWindows;
97    /**
98     * BAR convergence tolerance: a window is converged once successive estimates differ by less than this.
99     */
100   private final double tolerance;
101   /**
102    * BAR maximum number of iterations; exceeding it without convergence fails a non-bootstrap estimate.
103    */
104   private final int nIterations;
105   /**
106    * Forward Zwanzig instance, used to seed the BAR iteration.
107    */
108   private final Zwanzig forwardsFEP;
109   /**
110    * Backward Zwanzig instance, used to seed the BAR iteration.
111    */
112   private final Zwanzig backwardsFEP;
113   /**
114    * Random number generator used to draw bootstrap sample indices.
115    */
116   private final Random random;
117   /**
118    * Total BAR free-energy difference estimate (sum over the windows that have samples at both ends).
119    */
120   private double totalFreeEnergyDifference;
121   /**
122    * Total BAR free-energy difference uncertainty (per-window uncertainties combined in quadrature).
123    */
124   private double totalFEDifferenceUncertainty;
125   /**
126    * BAR free-energy difference estimates, one per window.
127    */
128   private final double[] freeEnergyDifferences;
129   /**
130    * BAR free-energy difference uncertainties, one per window.
131    */
132   private final double[] freeEnergyDifferenceUncertainties;
133   /**
134    * BAR enthalpy difference estimates, one per window.
135    */
136   private final double[] enthalpyDifferences;
137   /**
138    * Forward Zwanzig free-energy difference estimates, one per window.
139    */
140   private final double[] forwardZwanzigFEDifferences;
141   /**
142    * Backward Zwanzig free-energy difference estimates, one per window.
143    */
144   private final double[] backwardZwanzigFEDifferences;
145 
146   /**
147    * Constructs a BAR estimator with the default convergence tolerance and obtains an initial free energy estimate.
148    *
149    * @param lambdaValues   Value of lambda for each state.
150    * @param eLambdaMinusdL Energies of state L samples at L-dL.
151    * @param eLambda        Energies of state L samples at L.
152    * @param eLambdaPlusdL  Energies of state L samples at L+dL.
153    * @param temperature    Temperature of each state.
154    */
155   public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda,
156                                 double[][] eLambdaPlusdL, double[] temperature) {
157     this(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, DEFAULT_TOLERANCE);
158   }
159 
160   /**
161    * Constructs a BAR estimator with the default iteration limit and obtains an initial free energy estimate.
162    *
163    * @param lambdaValues   Value of lambda for each state.
164    * @param eLambdaMinusdL Energies of state L samples at L-dL.
165    * @param eLambda        Energies of state L samples at L.
166    * @param eLambdaPlusdL  Energies of state L samples at L+dL.
167    * @param temperature    Temperature of each state.
168    * @param tolerance      Convergence criterion in kcal/mol for BAR iteration.
169    */
170   public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda,
171                                 double[][] eLambdaPlusdL, double[] temperature, double tolerance) {
172     this(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, tolerance, DEFAULT_MAX_BAR_ITERATIONS);
173   }
174 
/**
 * Builds a BAR estimator from per-state energy samples and immediately computes a first estimate.
 *
 * <p>Forward and backward Zwanzig estimators are constructed from the same input; their per-window
 * free energy differences seed the BAR iteration.
 *
 * @param lambdaValues   Value of lambda for each state.
 * @param eLambdaMinusdL Energies of state L samples at L-dL.
 * @param eLambda        Energies of state L samples at L.
 * @param eLambdaPlusdL  Energies of state L samples at L+dL.
 * @param temperature    Temperature of each state.
 * @param tolerance      Convergence criterion in kcal/mol for BAR iteration.
 * @param nIterations    Maximum number of iterations for BAR.
 */
public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda,
                              double[][] eLambdaPlusdL, double[] temperature, double tolerance, int nIterations) {
  super(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature);

  // Iteration controls and the bootstrap random number source.
  this.tolerance = tolerance;
  this.nIterations = nIterations;
  this.random = new Random();

  // One window per pair of adjacent states.
  this.nWindows = nStates - 1;
  this.freeEnergyDifferences = new double[nWindows];
  this.freeEnergyDifferenceUncertainties = new double[nWindows];
  this.enthalpyDifferences = new double[nWindows];

  // Zwanzig (FEP) estimates in both directions provide the initial guess for each window.
  this.forwardsFEP = new Zwanzig(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, FORWARDS);
  this.backwardsFEP = new Zwanzig(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, BACKWARDS);
  this.forwardZwanzigFEDifferences = forwardsFEP.getFreeEnergyDifferences();
  this.backwardZwanzigFEDifferences = backwardsFEP.getFreeEnergyDifferences();

  estimateDG();
}
208 
209   /**
210    * Calculates the Fermi function for the differences used in estimating c.
211    *
212    * <p>f(x) = 1 / (1 + exp(x)) x = (e1 - e0 + c) * invRT
213    *
214    * @param e0         Perturbed energy (to be added; evaluated at L +/- dL).
215    * @param e1         Unperturbed energy (to be subtracted; evaluated at L).
216    * @param fermiDiffs Array to be filled with Fermi differences.
217    * @param len        Number of energies.
218    * @param c          Prior best estimate of the BAR offset/free energy.
219    * @param invRT      1.0 / ideal gas constant * temperature.
220    */
221   private static void fermiDiffIterative(double[] e0, double[] e1, double[] fermiDiffs, int len,
222                                          double c, double invRT) {
223     for (int i = 0; i < len; i++) {
224       fermiDiffs[i] = fermiFunction(invRT * (e0[i] - e1[i] + c));
225     }
226     if (stream(fermiDiffs).sum() == 0) {
227       logger.warning(format(" Input Fermi with length %3d should not be permitted: c: %9.4f invRT: %9.4f Fermi output: %9.4f", len, c, invRT, stream(fermiDiffs).sum()));
228     }
229   }
230 
231   /**
232    * Calculates forward alpha and fbsum for BAR Enthalpy estimation.
233    *
234    * @param e0    Perturbed energy (to be added; evaluated at L +/- dL).
235    * @param e1    Unperturbed energy (to be subtracted; evaluated at L).
236    * @param len   Number of energies.
237    * @param c     Prior best estimate of the BAR offset/free energy.
238    * @param invRT 1.0 / ideal gas constant * temperature.
239    * @param ret   Return alpha and fbsum.
240    */
241   private void calcAlphaForward(double[] e0, double[] e1, int len, double c,
242                                 double invRT, double[] ret) {
243     double fsum = 0;
244     double fvsum = 0;
245     double fbvsum = 0;
246     double vsum = 0;
247     double fbsum = 0;
248     for (int i = 0; i < len; i++) {
249       double fore = fermiFunction(invRT * (e1[i] - e0[i] - c));
250       double back = fermiFunction(invRT * (e0[i] - e1[i] + c));
251       fsum += fore;
252       fvsum += fore * e0[i];
253       fbvsum += fore * back * (e1[i] - e0[i]);
254       vsum += e0[i];
255       fbsum += fore * back;
256     }
257     double alpha = fvsum - (fsum * (vsum / len)) + fbvsum;
258     ret[0] = alpha;
259     ret[1] = fbsum;
260   }
261 
262   /**
263    * Calculates backward alpha and fbsum for BAR Enthalpy estimation.
264    *
265    * @param e0    Perturbed energy (to be added; evaluated at L +/- dL).
266    * @param e1    Unperturbed energy (to be subtracted; evaluated at L).
267    * @param len   Number of energies.
268    * @param c     Prior best estimate of the BAR offset/free energy.
269    * @param invRT 1.0 / ideal gas constant * temperature.
270    * @param ret   Return alpha and fbsum.
271    */
272   private void calcAlphaBackward(double[] e0, double[] e1, int len, double c,
273                                  double invRT, double[] ret) {
274     double bsum = 0;
275     double bvsum = 0;
276     double fbvsum = 0;
277     double vsum = 0;
278     double fbsum = 0;
279     for (int i = 0; i < len; i++) {
280       double fore = fermiFunction(invRT * (e1[i] - e0[i] - c));
281       double back = fermiFunction(invRT * (e0[i] - e1[i] + c));
282       bsum += back;
283       bvsum += back * e1[i];
284       fbvsum += fore * back * (e1[i] - e0[i]);
285       vsum += e1[i];
286       fbsum += fore * back;
287     }
288     double alpha = bvsum - (bsum * (vsum / len)) - fbvsum;
289     ret[0] = alpha;
290     ret[1] = fbsum;
291   }
292 
293 
294   /**
295    * Calculates the Fermi function for the differences used in estimating c, using bootstrap sampling
296    * (choosing random indices with replacement rather than scanning through them all).
297    *
298    * <p>f(x) = 1 / (1 + exp(x)) x = (e1 - e0 + c) * invRT
299    *
300    * @param e0         Perturbed energy (to be added; evaluated at L +/- dL).
301    * @param e1         Unperturbed energy (to be subtracted; evaluated at L).
302    * @param fermiDiffs Array to be filled with Fermi differences.
303    * @param len        Number of energies.
304    * @param c          Prior best estimate of the BAR offset/free energy.
305    * @param invRT      1.0 / ideal gas constant * temperature.
306    */
307   private static void fermiDiffBootstrap(double[] e0, double[] e1, double[] fermiDiffs,
308                                          int len, double c, double invRT, int[] bootstrapSamples) {
309     for (int indexI = 0; indexI < len; indexI++) {
310       int i = bootstrapSamples[indexI];
311       fermiDiffs[indexI] = fermiFunction(invRT * (e0[i] - e1[i] + c));
312     }
313   }
314 
315   /**
316    * Computes one half of the BAR variance.
317    *
318    * @param meanFermi   Mean Fermi value for either state 0 or state 1.
319    * @param meanSqFermi Mean squared Fermi value for either state 0 or state 1.
320    * @param len         Number of values.
321    * @return One half of BAR variance.
322    */
323   private static double uncertaintyCalculation(double meanFermi, double meanSqFermi, int len) {
324     double sqMeanFermi = meanFermi * meanFermi;
325     return ((meanSqFermi - sqMeanFermi) / len) / sqMeanFermi;
326   }
327 
328   /**
329    * Returns the backwards Zwanzig estimator used to seed BAR.
       * The instance held by this estimator is returned directly, not a copy.
330    *
331    * @return The backwards Zwanzig estimator built by the constructor.
332    */
333   public Zwanzig getInitialBackwardsGuess() {
334     return backwardsFEP;
335   }
336 
337   /**
338    * Returns the forwards Zwanzig estimator used to seed BAR.
       * The instance held by this estimator is returned directly, not a copy.
339    *
340    * @return The forwards Zwanzig estimator built by the constructor.
341    */
342   public Zwanzig getInitialForwardsGuess() {
343     return forwardsFEP;
344   }
345 
346   /**
347    * {@inheritDoc}
       *
       * <p>The new estimator is constructed from this estimator's lambda values, energies,
       * temperatures, tolerance and iteration limit. Because the constructor runs estimateDG, the
       * copy already holds an estimate when it is returned.
348    */
349   @Override
350   public BennettAcceptanceRatio copyEstimator() {
351     return new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, tolerance, nIterations);
352   }
353 
354   /**
355    * Main driver for estimation of BAR free energy differences.
356    * <p>
357    * Based on Tinker implementation, which uses the substitution proposed in
358    * Wyczalkowski, Vitalis and Pappu 2010.
359    *
360    * @param randomSamples Whether to use random sampling (for bootstrap analysis).
361    */
362   @Override
363   public final void estimateDG(final boolean randomSamples) {
364     double cumDG = 0;
365     fill(freeEnergyDifferences, 0);
366     fill(freeEnergyDifferenceUncertainties, 0);
367     fill(enthalpyDifferences, 0);
368 
369     // Avoid duplicate warnings when bootstrapping.
370     Level warningLevel = randomSamples ? Level.FINE : Level.WARNING;
371 
372     for (int i = 0; i < nWindows; i++) {
373       // Free energy estimate/shift constant.
374       if (isNaN(forwardZwanzigFEDifferences[i]) || isInfinite(forwardZwanzigFEDifferences[i])
375           || isNaN(backwardZwanzigFEDifferences[i]) || isInfinite(backwardZwanzigFEDifferences[i])) {
376         logger.warning(format(" Window %3d bin energies produced unreasonable value(s) for forward Zwanzig (%8.4f) and/or backward Zwanzig (%8.4f)", i, forwardZwanzigFEDifferences[i], backwardZwanzigFEDifferences[i]));
377       }
378       double c = 0.5 * (forwardZwanzigFEDifferences[i] + backwardZwanzigFEDifferences[i]);
379       // double c = 0.0;
380 
381       if (!randomSamples) {
382         logger.fine(format(" BAR Iteration Seed: %12.4f Kcal/mol", c));
383       }
384 
385       double cold = c;
386       int len0 = eLambda[i].length;
387       int len1 = eLambda[i + 1].length;
388 
389       if (len0 == 0 || len1 == 0) {
390         freeEnergyDifferences[i] = c;
391         logger.log(warningLevel, format(" Window %d has no snapshots at one end (%d, %d)!", i, len0, len1));
392         continue;
393       }
394 
395       // Ratio of the number of snaps: Tinker equivalent: rfrm
396       double sampleRatio = ((double) len0) / ((double) len1);
397 
398       // Fermi differences.
399       double[] fermi0 = new double[len0];
400       double[] fermi1 = new double[len1];
401       double[] ret = new double[2];
402 
403       // Ideal gas constant * temperature, or its inverse.
404       double rta = R * temperatures[i];
405       double rtb = R * temperatures[i + 1];
406       double rtMean = 0.5 * (rta + rtb);
407       double invRTA = 1.0 / rta;
408       double invRTB = 1.0 / rtb;
409 
410       // Summary statistics for Fermi differences for the upper half.
411       SummaryStatistics s1 = null;
412       // Summary statistics for Fermi differences for the lower half.
413       SummaryStatistics s0 = null;
414 
415       // Each BAR convergence cycle needs to operate on the same set of indices.
416       int[] bootstrapSamples0 = null;
417       int[] bootstrapSamples1 = null;
418 
419       if (randomSamples) {
420         bootstrapSamples0 = getBootstrapIndices(len0, random);
421         bootstrapSamples1 = getBootstrapIndices(len1, random);
422       }
423 
424       int cycleCounter = 0;
425       boolean converged = false;
426       while (!converged) {
427         if (randomSamples) {
428           fermiDiffBootstrap(eLambdaPlusdL[i], eLambda[i], fermi0, len0, -c, invRTA, bootstrapSamples0);
429           fermiDiffBootstrap(eLambdaMinusdL[i + 1], eLambda[i + 1], fermi1, len1, c, invRTB, bootstrapSamples1);
430         } else {
431           fermiDiffIterative(eLambdaPlusdL[i], eLambda[i], fermi0, len0, -c, invRTA);
432           fermiDiffIterative(eLambdaMinusdL[i + 1], eLambda[i + 1], fermi1, len1, c, invRTB);
433         }
434 
435         s0 = new SummaryStatistics(fermi0);
436         s1 = new SummaryStatistics(fermi1);
437         double ratio = s1.sum / s0.sum;
438         c += rtMean * log(sampleRatio * ratio);
439 
440         cycleCounter++;
441         converged = (abs(c - cold) < tolerance);
442 
443         if (!randomSamples && !converged && cycleCounter > nIterations) {
444           throw new IllegalArgumentException(
445               format(" BAR required too many iterations (%d) to converge! (%9.8f > %9.8f)", cycleCounter, abs(c - cold), tolerance));
446         }
447 
448         if (!randomSamples) {
449           logger.fine(format(" BAR Iteration   %2d: %12.4f Kcal/mol", cycleCounter, c));
450         }
451         cold = c;
452       }
453 
454       freeEnergyDifferences[i] = c;
455       cumDG += c;
456       double sqFermiMean0 = new SummaryStatistics(stream(fermi0).map((double d) -> d * d).toArray()).mean;
457       double sqFermiMean1 = new SummaryStatistics(stream(fermi1).map((double d) -> d * d).toArray()).mean;
458       freeEnergyDifferenceUncertainties[i] = sqrt(uncertaintyCalculation(s0.mean, sqFermiMean0, len0)
459           + uncertaintyCalculation(s1.mean, sqFermiMean1, len1));
460 
461       calcAlphaForward(eLambda[i], eLambdaPlusdL[i], len0, c, invRTA, ret);
462       double alpha0 = ret[0];
463       double fbsum0 = ret[1];
464 
465       calcAlphaBackward(eLambdaMinusdL[i + 1], eLambda[i + 1], len1, c, invRTB, ret);
466       double alpha1 = ret[0];
467       double fbsum1 = ret[1];
468 
469       double hBar = (alpha0 - alpha1) / (fbsum0 + fbsum1);
470       enthalpyDifferences[i] = hBar;
471     }
472 
473     totalFreeEnergyDifference = cumDG;
474     totalFEDifferenceUncertainty = sqrt(stream(freeEnergyDifferenceUncertainties).map((double d) -> d * d).sum());
475   }
476 
477   /**
478    * {@inheritDoc}
       *
       * <p>Returns a copy of the per-window BAR free energy differences, so modifying the result
       * does not affect this estimator.
479    */
480   @Override
481   public double[] getFreeEnergyDifferences() {
482     return copyOf(freeEnergyDifferences, nWindows);
483   }
484 
485   /**
486    * {@inheritDoc}
       *
       * <p>Returns a copy of the per-window BAR uncertainties, so modifying the result does not
       * affect this estimator.
487    */
488   @Override
489   public double[] getFEDifferenceUncertainties() {
490     return copyOf(freeEnergyDifferenceUncertainties, nWindows);
491   }
492 
493   /**
494    * {@inheritDoc}
       *
       * <p>The value is the sum of the per-window BAR estimates from the most recent call to
       * estimateDG; windows lacking samples at one end are not included in the sum.
495    */
496   @Override
497   public double getTotalFreeEnergyDifference() {
498     return totalFreeEnergyDifference;
499   }
500 
501   /**
502    * {@inheritDoc}
       *
       * <p>The value is the square root of the sum of squared per-window uncertainties from the
       * most recent call to estimateDG.
503    */
504   @Override
505   public double getTotalFEDifferenceUncertainty() {
506     return totalFEDifferenceUncertainty;
507   }
508 
509   /**
510    * {@inheritDoc}
       *
       * <p>For BAR this is the number of windows, i.e. one less than the number of states.
511    */
512   @Override
513   public int getNumberOfBins() {
514     return nWindows;
515   }
516 
517   /**
518    * {@inheritDoc}
       *
       * <p>Delegates to the array-taking overload, passing the per-window BAR enthalpy estimates.
519    */
520   @Override
521   public double getTotalEnthalpyDifference() {
522     return getTotalEnthalpyDifference(enthalpyDifferences);
523   }
524 
525   /**
526    * {@inheritDoc}
527    */
528   @Override
529   public double[] getEnthalpyDifferences() {
530     return enthalpyDifferences;
531   }
532 }