View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2023.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.estimator;
39  
40  import ffx.numerics.OptimizationInterface;
41  import ffx.numerics.optimization.LBFGS;
42  import ffx.numerics.optimization.LineSearch;
43  import ffx.numerics.optimization.OptimizationListener;
44  import ffx.utilities.Constants;
45  import org.apache.commons.math3.linear.MatrixUtils;
46  import org.apache.commons.math3.linear.RealMatrix;
47  import org.apache.commons.math3.linear.SingularValueDecomposition;
48  import org.apache.commons.math3.util.MathArrays;
49  
50  import java.io.BufferedWriter;
51  import java.io.File;
52  import java.io.FileWriter;
53  import java.io.IOException;
54  import java.util.Arrays;
55  import java.util.Random;
56  import java.util.logging.Logger;
57  
58  import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
59  import static ffx.numerics.estimator.Zwanzig.Directionality.BACKWARDS;
60  import static ffx.numerics.estimator.Zwanzig.Directionality.FORWARDS;
61  import static java.lang.System.arraycopy;
62  import static java.util.Arrays.copyOf;
63  import static java.util.Arrays.stream;
64  import static org.apache.commons.lang3.ArrayFill.fill;
65  import static org.apache.commons.math3.util.FastMath.abs;
66  import static org.apache.commons.math3.util.FastMath.exp;
67  import static org.apache.commons.math3.util.FastMath.log;
68  import static org.apache.commons.math3.util.FastMath.sqrt;
69  
70  /**
 * The MultistateBennettAcceptanceRatio class defines a statistical estimator based on a generalization
 * of the Bennett Acceptance Ratio (BAR) method to multiple lambda windows. It requires a K x N array of
 * energies (every snapshot from every window, evaluated at every lambda value). Different numbers of
 * snapshots per window are not supported; this is caught by the filter, but not by the Harmonic
 * Oscillators test case.
76   * <p>
77   * This class implements the method discussed in:
78   * Shirts, M. R. and Chodera, J. D. (2008) Statistically optimal analysis of samples from multiple equilibrium
79   * states. J. Chem. Phys. 129, 124105. doi:10.1063/1.2978177
80   * <p>
81   * This class is based heavily on the pymbar code, which is available at:
82   * https://github.com/choderalab/pymbar/tree/master
83   *
84   * @author Matthew J. Speranza
85   * @since 1.0
86   */
87  public class MultistateBennettAcceptanceRatio extends SequentialEstimator implements BootstrappableEstimator, OptimizationInterface {
  private static final Logger logger = Logger.getLogger(MultistateBennettAcceptanceRatio.class.getName());

  /**
   * Default MBAR convergence tolerance, used by the constructor that does not take a tolerance.
   */
  private static final double DEFAULT_TOLERANCE = 1.0E-7;
  /**
   * Number of free-energy differences between adjacent lambda windows (number of states - 1).
   */
  private final int nFreeEnergyDiffs;
  /**
   * MBAR free-energy differences between adjacent states: mbarEstimates[i] = f[i + 1] - f[i].
   */
  private final double[] mbarEstimates;
  /**
   * MBAR uncertainties of the adjacent-state free-energy differences.
   * NOTE(review): estimateDG computes these from the reduced (dimensionless) free energies and does not
   * multiply them by RT, unlike the free energies themselves — confirm the units callers expect.
   */
  private double[] mbarUncertainties;

  /**
   * Matrix of uncertainties of the free-energy difference between every pair of states i and j.
   */
  private double[][] diffMatrix;
  /**
   * MBAR convergence tolerance: self-consistent iteration stops once every free energy changes by
   * less than this between iterations.
   */
  private final double tolerance;
  /**
   * Random number generator used to draw bootstrap sample indices.
   */
  private final Random random;
  /**
   * Number of lambda states (length of the lambda value array).
   */
  private final int nStates;
  /**
   * MBAR free-energy estimates at each lambda value, shifted so the first state is zero and multiplied
   * by RT at the end of estimateDG.
   */
  double[] mbarFreeEnergies;
  /**
   * Total MBAR free-energy difference estimate (sum of the adjacent-state differences).
   */
  private double totalMBAREstimate;
  /**
   * Total MBAR free-energy difference uncertainty (between the first and last states).
   */
  private double totalMBARUncertainty;
  /**
   * MBAR enthalpy estimates, one per adjacent pair of states.
   * NOTE(review): allocated in the constructor but not written anywhere in the code reviewed here —
   * confirm whether it is populated elsewhere.
   */
  private final double[] mbarEnthalpy;

  /**
   * Reduced potentials: u_kn[k][n] is the energy of snapshot n evaluated at state k, divided by RT of state k.
   */
  private double[][] u_kn;
  /**
   * Number of samples attributed to each state; set to (number of snapshots / number of states),
   * i.e. only equal numbers of samples per state are supported.
   */
  private double[] N_k;
  /**
   * How the MBAR free energies are seeded: with another estimator (BAR, ZWANZIG) or with zeros.
   * Switched to ZWANZIG if the BAR seed fails.
   */
  private SeedType seedType;

  /**
   * Enum of MBAR seed types.
   */
  public enum SeedType {BAR, ZWANZIG, ZEROS}
151 
152   /**
153    * Constructor for MBAR estimator.
154    *
155    * @param lambdaValues array of lambda values
156    * @param energiesAll  array of energies at each lambda value
157    * @param temperature  array of temperatures
158    */
159   public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature) {
160     this(lambdaValues, energiesAll, temperature, DEFAULT_TOLERANCE, SeedType.ZWANZIG);
161   }
162 
163   /**
164    * Constructor for MBAR estimator.
165    *
166    * @param lambdaValues array of lambda values
167    * @param energiesAll  array of energies at each lambda value
168    * @param temperature  array of temperatures
169    * @param tolerance    convergence tolerance
170    * @param seedType     seed type for MBAR
171    */
172   public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature,
173                                           double tolerance, SeedType seedType) {
174     super(lambdaValues, energiesAll, temperature);
175     this.tolerance = tolerance;
176     this.seedType = seedType;
177 
178     // MBAR calculates free energy at each lambda value (only the differences between them have physical significance)
179     nStates = lambdaValues.length;
180     mbarFreeEnergies = new double[nStates];
181 
182     nFreeEnergyDiffs = lambdaValues.length - 1;
183     mbarEstimates = new double[nFreeEnergyDiffs];
184     mbarUncertainties = new double[nFreeEnergyDiffs];
185     mbarEnthalpy = new double[nFreeEnergyDiffs];
186     random = new Random();
187     estimateDG();
188   }
189 
190   /**
191    * Set the MBAR seed energies using BAR, Zwanzig or zeros.
192    */
193   private void seedEnergies() {
194     switch (seedType) {
195       case BAR:
196         try {
197           SequentialEstimator barEstimator = new BennettAcceptanceRatio(lamValues, eLow, eAt, eHigh, temperatures);
198           mbarFreeEnergies[0] = 0.0;
199           double[] barEstimates = barEstimator.getBinEnergies();
200           for (int i = 0; i < nFreeEnergyDiffs; i++) {
201             mbarFreeEnergies[i + 1] = mbarFreeEnergies[i] + barEstimates[i];
202           }
203           break;
204         } catch (IllegalArgumentException e) {
205           logger.warning(" BAR failed to converge. Zwanzig will be used for seed energies.");
206           seedType = SeedType.ZWANZIG;
207           seedEnergies();
208           return;
209         }
210       case ZWANZIG:
211         // Forward Zwanzig instance.
212         Zwanzig forwardsFEP = new Zwanzig(lamValues, eLow, eAt, eHigh, temperatures, FORWARDS);
213         // Backward Zwanzig instance.
214         Zwanzig backwardsFEP = new Zwanzig(lamValues, eLow, eAt, eHigh, temperatures, BACKWARDS);
215         // Forward Zwanzig free-energy difference estimates.
216         double[] forwardZwanzig = forwardsFEP.getBinEnergies();
217         // Backward Zwanzig free-energy difference estimates.
218         double[] backwardZwanzig = backwardsFEP.getBinEnergies();
219         mbarFreeEnergies[0] = 0.0;
220         for (int i = 0; i < nFreeEnergyDiffs; i++) {
221           mbarFreeEnergies[i + 1] = mbarFreeEnergies[i] + .5 * (forwardZwanzig[i] + backwardZwanzig[i]);
222         }
223         break;
224       case SeedType.ZEROS:
225         break;
226       default:
227         throw new IllegalArgumentException("Seed type not supported");
228     }
229   }
230 
231   /**
232    * Get the MBAR free-energy estimates at each lambda value.
233    */
234   @Override
235   public void estimateDG() {
236     estimateDG(false);
237   }
238 
239   /**
240    * Implementation of MBAR solved with self-consistent iteration and L-BFGS optimization.
241    */
242   @Override
243   public void estimateDG(boolean randomSamples) {
244     // Bootstrap needs resetting
245     fill(mbarFreeEnergies, 0.0);
246     seedEnergies();
247 
248     // Throw error if MBAR contains NaNs or Infs
249     if (stream(mbarFreeEnergies).anyMatch(Double::isInfinite) || stream(mbarFreeEnergies).anyMatch(Double::isNaN)) {
250       throw new IllegalArgumentException("MBAR contains NaNs or Infs after seeding.");
251     }
252     double[] prevMBAR;
253 
254     // SCI iterations
255     int iter = 0;
256 
257     // Precompute beta for each state.
258     double[] rtValues = new double[nStates];
259     double[] invRTValues = new double[nStates];
260     for (int i = 0; i < nStates; i++) {
261       rtValues[i] = Constants.R * temperatures[i];
262       invRTValues[i] = 1.0 / rtValues[i];
263     }
264     int numSnaps = eAllFlat[0].length;
265 
266     // Sample random snapshots from each window.
267     int[][] indices = new int[nStates][numSnaps];
268     if (randomSamples) {
269       int[] randomIndices = getBootstrapIndices(numSnaps, random);
270       for (int i = 0; i < nStates; i++) {
271         // Use the same random indices across all lambda values
272         indices[i] = randomIndices;
273       }
274     } else {
275       for (int i = 0; i < numSnaps; i++) {
276         for (int j = 0; j < nStates; j++) {
277           indices[j][i] = i;
278         }
279       }
280     }
281 
282     // Precompute u_kn since it doesn't change
283     u_kn = new double[nStates][numSnaps];
284     N_k = new double[nStates];
285     for (int state = 0; state < nStates; state++) { // For each lambda value
286       for (int n = 0; n < numSnaps; n++) {
287         u_kn[state][n] = eAllFlat[state][indices[state][n]] * invRTValues[state];
288       }
289       N_k[state] = (double) numSnaps / nStates;
290     }
291 
292     // Few SCI iterations used to start optimization of MBAR objective function.
293     // Optimizers can struggle when starting too far from the minimum, but SCI doesn't.
294     double omega = 1.5;
295     for (int i = 0; i < 10; i++) {
296       prevMBAR = copyOf(mbarFreeEnergies, nStates);
297       mbarFreeEnergies = selfConsistentUpdate(u_kn, N_k, mbarFreeEnergies);
298       // Apply SOR
299       for (int j = 0; j < nStates; j++) {
300         mbarFreeEnergies[j] = omega * mbarFreeEnergies[j] + (1 - omega) * prevMBAR[j];
301       }
302       // Throw error if MBAR contains NaNs or Infinities.
303       if (stream(mbarFreeEnergies).anyMatch(Double::isInfinite) || stream(mbarFreeEnergies).anyMatch(Double::isNaN)) {
304         throw new IllegalArgumentException("MBAR contains NaNs or Infs after iteration " + iter);
305       }
306     }
307 
308     try {
309       // L-BFGS optimization for high granularity windows where hessian is expensive
310       if (nStates > 100) {
311         int mCorrections = 5;
312         double[] x = new double[nStates];
313         arraycopy(mbarFreeEnergies, 0, x, 0, nStates);
314         double[] grad = mbarGradient(u_kn, N_k, mbarFreeEnergies);
315         double eps = 1.0E-4;
316         OptimizationListener listener = getOptimizationListener();
317         LBFGS.minimize(nStates, mCorrections, x, mbarObjectiveFunction(u_kn, N_k, mbarFreeEnergies),
318             grad, eps, 1000, this, listener);
319         arraycopy(x, 0, mbarFreeEnergies, 0, nStates);
320       } else { // Newton optimization if hessian inversion isn't too expensive
321         mbarFreeEnergies = newton(mbarFreeEnergies, u_kn, N_k, 1.0, 100, 1.0E-7);
322       }
323     } catch (Exception e) {
324       logger.warning(" L-BFGS/Newton failed to converge. Finishing w/ self-consistent iteration.");
325       logger.warning(e.getMessage());
326     }
327 
328     // Self-consistent iteration is used to finish off optimization of MBAR objective function
329     do {
330       prevMBAR = copyOf(mbarFreeEnergies, nStates);
331       mbarFreeEnergies = selfConsistentUpdate(u_kn, N_k, mbarFreeEnergies);
332       // Apply SOR
333       for (int i = 0; i < nStates; i++) {
334         mbarFreeEnergies[i] = omega * mbarFreeEnergies[i] + (1 - omega) * prevMBAR[i];
335       }
336       // Throw error if MBAR contains NaNs or Infs
337       if (stream(mbarFreeEnergies).anyMatch(Double::isInfinite) || stream(mbarFreeEnergies).anyMatch(Double::isNaN)) {
338         throw new IllegalArgumentException("MBAR contains NaNs or Infs after iteration " + iter);
339       }
340       iter++;
341     } while (!converged(prevMBAR));
342 
343     logger.fine(" MBAR converged after " + iter + " iterations with omega " + omega + ".");
344 
345     // Zero out the first term
346     double f0 = mbarFreeEnergies[0];
347     for (int i = 0; i < nStates; i++) {
348       mbarFreeEnergies[i] -= f0;
349     }
350 
351     // Calculate uncertainties
352     mbarUncertainties = mbarUncertaintyCalc(u_kn, N_k, mbarFreeEnergies);
353     totalMBARUncertainty = mbarTotalUncertaintyCalc(u_kn, N_k, mbarFreeEnergies);
354     diffMatrix = diffMatrixCalculation(u_kn, N_k, mbarFreeEnergies);
355 
356     // Convert to kcal/mol & calculate differences/sums
357     for (int i = 0; i < nStates; i++) {
358       mbarFreeEnergies[i] = mbarFreeEnergies[i] * rtValues[i];
359     }
360 
361     for (int i = 0; i < nFreeEnergyDiffs; i++) {
362       mbarEstimates[i] = mbarFreeEnergies[i + 1] - mbarFreeEnergies[i];
363     }
364 
365     totalMBAREstimate = stream(mbarEstimates).sum();
366   }
367 
368   /**
369    * Checks if the MBAR free energy estimates have converged by comparing the difference
370    * between the previous and current free energies. The tolerance is set by the user.
371    * Default is 1.0E-7.
372    *
373    * @param prevMBAR previous MBAR free energy estimates.
374    * @return true if converged, false otherwise
375    */
376   private boolean converged(double[] prevMBAR) {
377     double[] differences = new double[prevMBAR.length];
378     for (int i = 0; i < prevMBAR.length; i++) {
379       differences[i] = abs(prevMBAR[i] - mbarFreeEnergies[i]);
380     }
381     return stream(differences).allMatch(d -> d < tolerance);
382   }
383 
384   //////// Methods for calculating MBAR variables, vectors, and matrices. ////////
385 
386   /**
387    * MBAR objective function. This is used for L-BFGS optimization.
388    *
389    * @param u_kn energies
390    * @param N_k  number of samples per state
391    * @param f_k  free energies
392    * @return The objective function value.
393    */
394   private static double mbarObjectiveFunction(double[][] u_kn, double[] N_k, double[] f_k) {
395     if (stream(f_k).anyMatch(Double::isInfinite) || stream(f_k).anyMatch(Double::isNaN)) {
396       throw new IllegalArgumentException("MBAR contains NaNs or Infs.");
397     }
398     int nStates = f_k.length;
399     double[] log_denom_n = new double[u_kn[0].length];
400     for (int i = 0; i < u_kn[0].length; i++) {
401       double[] temp = new double[nStates];
402       double maxTemp = Double.NEGATIVE_INFINITY;
403       for (int j = 0; j < nStates; j++) {
404         temp[j] = f_k[j] - u_kn[j][i];
405         if (temp[j] > maxTemp) {
406           maxTemp = temp[j];
407         }
408       }
409       log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
410     }
411     double[] dotNkFk = new double[N_k.length];
412     for (int i = 0; i < N_k.length; i++) {
413       dotNkFk[i] = N_k[i] * f_k[i];
414     }
415     return stream(log_denom_n).sum() - stream(dotNkFk).sum();
416   }
417 
418   /**
419    * Gradient of the MBAR objective function. This is used for L-BFGS optimization.
420    *
421    * @param u_kn energies
422    * @param N_k  number of samples per state
423    * @param f_k  free energies
424    * @return Gradient for the mbar objective function.
425    */
426   private static double[] mbarGradient(double[][] u_kn, double[] N_k, double[] f_k) {
427     int nStates = f_k.length;
428     double[] log_num_k = new double[nStates];
429     double[] log_denom_n = new double[u_kn[0].length];
430     double[][] logDiff = new double[u_kn.length][u_kn[0].length];
431     double maxLogDiff = Double.NEGATIVE_INFINITY;
432     for (int i = 0; i < u_kn[0].length; i++) {
433       double[] temp = new double[nStates];
434       double maxTemp = Double.NEGATIVE_INFINITY;
435       for (int j = 0; j < nStates; j++) {
436         temp[j] = f_k[j] - u_kn[j][i];
437         if (temp[j] > maxTemp) {
438           maxTemp = temp[j];
439         }
440       }
441       log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
442       for (int j = 0; j < nStates; j++) {
443         logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
444         if (logDiff[j][i] > maxLogDiff) {
445           maxLogDiff = logDiff[j][i];
446         }
447       }
448     }
449     for (int i = 0; i < nStates; i++) {
450       log_num_k[i] = logSumExp(logDiff[i], maxLogDiff);
451     }
452     double[] grad = new double[nStates];
453     for (int i = 0; i < nStates; i++) {
454       grad[i] = -1.0 * N_k[i] * (1.0 - exp(f_k[i] + log_num_k[i]));
455     }
456     return grad;
457   }
458 
459   /**
460    * Hessian of the MBAR objective function. This is used for Newton optimization.
461    *
462    * @param u_kn energies
463    * @param N_k  number of samples per state
464    * @param f_k  free energies
465    * @return Hessian for the mbar objective function.
466    */
467   private static double[][] mbarHessian(double[][] u_kn, double[] N_k, double[] f_k) {
468     int nStates = f_k.length;
469     double[][] W = mbarW(u_kn, N_k, f_k);
470     // h = dot(W.T, W) * N_k * N_k[:, newaxis] - diag(W.sum(0) * N_k)
471     double[][] hessian = new double[nStates][nStates];
472     for (int i = 0; i < nStates; i++) {
473       for (int j = 0; j < nStates; j++) {
474         double sum = 0.0;
475         for (int k = 0; k < u_kn[0].length; k++) {
476           sum += W[i][k] * W[j][k];
477         }
478         hessian[i][j] = sum * N_k[i] * N_k[j];
479       }
480       double wSum = 0.0;
481       for (int k = 0; k < W[i].length; k++) {
482         wSum += W[i][k];
483       }
484       hessian[i][i] -= wSum * N_k[i];
485     }
486     // h = -h
487     for (int i = 0; i < nStates; i++) {
488       for (int j = 0; j < nStates; j++) {
489         hessian[i][j] = -hessian[i][j];
490       }
491     }
492     return hessian;
493   }
494 
495   /**
496    * W = exp(f_k - u_kn.T - log_denominator_n[:, newaxis])
497    *
498    * @param u_kn energies
499    * @param N_k  number of samples per state
500    * @param f_k  free energies
501    * @return W matrix.
502    */
503   private static double[][] mbarW(double[][] u_kn, double[] N_k, double[] f_k) {
504     int nStates = f_k.length;
505     double[] log_denom_n = new double[u_kn[0].length];
506     double[][] logDiff = new double[u_kn.length][u_kn[0].length];
507     double maxLogDiff = Double.NEGATIVE_INFINITY;
508     for (int i = 0; i < u_kn[0].length; i++) {
509       double[] temp = new double[nStates];
510       double maxTemp = Double.NEGATIVE_INFINITY;
511       for (int j = 0; j < nStates; j++) {
512         temp[j] = f_k[j] - u_kn[j][i];
513         if (temp[j] > maxTemp) {
514           maxTemp = temp[j];
515         }
516       }
517       log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
518       for (int j = 0; j < nStates; j++) {
519         logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
520         if (logDiff[j][i] > maxLogDiff) {
521           maxLogDiff = logDiff[j][i];
522         }
523       }
524     }
525     // logW = f_k - u_kn.T - log_denominator_n[:, newaxis]
526     double[][] W = new double[nStates][u_kn[0].length];
527     for (int i = 0; i < nStates; i++) {
528       for (int j = 0; j < u_kn[0].length; j++) {
529         W[i][j] = exp(f_k[i] - u_kn[i][j] - log_denom_n[j]);
530       }
531     }
532     return W;
533   }
534 
535   /**
536    * logW = f_k - u_kn.T - log_denominator_n[:, newaxis]
537    *
538    * @param u_kn energies
539    * @param N_k  number of samples per state
540    * @param f_k  free energies
541    * @return logW matrix.
542    */
543   private static double[][] mbarLogW(double[][] u_kn, double[] N_k, double[] f_k) {
544     int nStates = f_k.length;
545     // double[] log_num_k = new double[nStates];
546     double[] log_denom_n = new double[u_kn[0].length];
547     double[][] logDiff = new double[u_kn.length][u_kn[0].length];
548     double maxLogDiff = Double.NEGATIVE_INFINITY;
549     for (int i = 0; i < u_kn[0].length; i++) {
550       double[] temp = new double[nStates];
551       double maxTemp = Double.NEGATIVE_INFINITY;
552       for (int j = 0; j < nStates; j++) {
553         temp[j] = f_k[j] - u_kn[j][i];
554         if (temp[j] > maxTemp) {
555           maxTemp = temp[j];
556         }
557       }
558       log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
559       for (int j = 0; j < nStates; j++) {
560         logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
561         if (logDiff[j][i] > maxLogDiff) {
562           maxLogDiff = logDiff[j][i];
563         }
564       }
565     }
566     // logW = f_k - u_kn.T - log_denominator_n[:, newaxis]
567     double[][] logW = new double[nStates][u_kn[0].length];
568     for (int i = 0; i < nStates; i++) {
569       for (int j = 0; j < u_kn[0].length; j++) {
570         logW[i][j] = f_k[i] - u_kn[i][j] - log_denom_n[j];
571       }
572     }
573     return logW;
574   }
575 
576   /**
577    * Theta = W.T @ (I - W @ diag(N_k) @ W.T)^-1 @ W.
578    * <p>
579    * Requires calculation and inversion of W matrix.
580    * D4 from supp info of MBAR paper used instead.
581    *
582    * @param u_kn energies
583    * @param N_k  number of samples per state
584    * @param f_k  free energies
585    * @return Theta matrix.
586    */
587   private static double[][] mbarTheta(double[][] u_kn, double[] N_k, double[] f_k) {
588     // SVD of W
589     double[][] W = mbarW(u_kn, N_k, f_k);
590     RealMatrix WMatrix = MatrixUtils.createRealMatrix(W).transpose();
591     RealMatrix I = MatrixUtils.createRealIdentityMatrix(f_k.length);
592     RealMatrix NkMatrix = MatrixUtils.createRealDiagonalMatrix(N_k);
593     SingularValueDecomposition svd = new SingularValueDecomposition(WMatrix);
594     RealMatrix V = svd.getV();
595     RealMatrix S = MatrixUtils.createRealDiagonalMatrix(svd.getSingularValues());
596 
597     // W.T @ (I - W @ diag(N_k) @ W.T)^-1 @ W
598     // = V @ S @ (I - S @ V.T @ diag(N_k) @ V @ S)^-1 @ S @ V.T
599     RealMatrix theta = S.multiply(V.transpose());
600     theta = theta.multiply(NkMatrix).multiply(V).multiply(S);
601     theta = I.subtract(theta);
602     theta = new SingularValueDecomposition(theta).getSolver().getInverse(); // pinv equivalent
603     theta = V.multiply(S).multiply(theta).multiply(S).multiply(V.transpose());
604 
605     return theta.getData();
606   }
607 
608   /**
609    * MBAR uncertainty calculation.
610    *
611    * @param u_kn energies
612    * @param N_k  number of samples per state
613    * @param f_k  free energies
614    * @return Uncertainties for the MBAR free energy estimates.
615    */
616   private static double[] mbarUncertaintyCalc(double[][] u_kn, double[] N_k, double[] f_k) {
617     double[][] theta = mbarTheta(u_kn, N_k, f_k);
618     double[] uncertainties = new double[f_k.length - 1];
619     // del(dFij) = Theta[i,i] - 2 * Theta[i,j] + Theta[j,j]
620     for (int i = 0; i < f_k.length - 1; i++) {
621       uncertainties[i] = sqrt(theta[i][i] - 2 * theta[i][i + 1] + theta[i + 1][i + 1]);
622     }
623     return uncertainties;
624   }
625 
626   /**
627    * MBAR total uncertainty calculation.
628    *
629    * @param u_kn energies
630    * @param N_k  number of samples per state
631    * @param f_k  free energies
632    * @return Total uncertainty for the MBAR free energy estimates.
633    */
634   private static double mbarTotalUncertaintyCalc(double[][] u_kn, double[] N_k, double[] f_k) {
635     double[][] theta = mbarTheta(u_kn, N_k, f_k);
636     int nStates = f_k.length;
637     return sqrt(theta[0][0] - 2 * theta[0][nStates - 1] + theta[nStates - 1][nStates - 1]);
638   }
639 
640   /**
641    * MBAR diff Matrix calculation.
642    *
643    * @param u_kn energies
644    * @param N_k  number of samples per state
645    * @param f_k  free energies
646    * @return Diff matrix for the MBAR free energy estimates.
647    */
648   private static double[][] diffMatrixCalculation(double[][] u_kn, double[] N_k, double[] f_k) {
649     double[][] theta = mbarTheta(u_kn, N_k, f_k);
650     double[][] diffMatrix = new double[f_k.length][f_k.length];
651     for (int i = 0; i < f_k.length; i++) {
652       for (int j = 0; j < f_k.length; j++) {
653         diffMatrix[i][j] = sqrt(theta[i][i] - 2 * theta[i][j] + theta[j][j]);
654       }
655     }
656     return diffMatrix;
657   }
658 
659   //////// Methods for solving MBAR with self-consistent iteration, L-BFGS optimization, and Newton-Raphson. ////////
660 
661   /**
662    * Self-consistent iteration to update free energies.
663    *
664    * @param u_kn energies
665    * @param N_k  number of samples per state
666    * @param f_k  free energies
667    * @return updated free energies
668    */
669   private static double[] selfConsistentUpdate(double[][] u_kn, double[] N_k, double[] f_k) {
670     int nStates = f_k.length;
671     double[] updatedF_k = new double[nStates];
672     double[] log_denom_n = new double[u_kn[0].length];
673     double[][] logDiff = new double[u_kn.length][u_kn[0].length];
674     double[] maxLogDiff = new double[nStates];
675     fill(maxLogDiff, Double.NEGATIVE_INFINITY);
676     for (int i = 0; i < u_kn[0].length; i++) {
677       double[] temp = new double[nStates];
678       double maxTemp = Double.NEGATIVE_INFINITY;
679       for (int j = 0; j < nStates; j++) {
680         temp[j] = f_k[j] - u_kn[j][i];
681         if (temp[j] > maxTemp) {
682           maxTemp = temp[j];
683         }
684       }
685       log_denom_n[i] = logSumExp(temp, N_k, maxTemp);
686       for (int j = 0; j < nStates; j++) {
687         logDiff[j][i] = -log_denom_n[i] - u_kn[j][i];
688         if (logDiff[j][i] > maxLogDiff[j]) {
689           maxLogDiff[j] = logDiff[j][i];
690         }
691       }
692     }
693 
694     for (int i = 0; i < nStates; i++) {
695       updatedF_k[i] = -1.0 * logSumExp(logDiff[i], maxLogDiff[i]);
696     }
697 
698     // Constrain f1=0 over the course of iterations to prevent uncontrolled growth in magnitude
699     double norm = updatedF_k[0];
700     updatedF_k[0] = 0.0;
701     for (int i = 1; i < nStates; i++) {
702       updatedF_k[i] = updatedF_k[i] - norm;
703     }
704 
705     return updatedF_k;
706   }
707 
708   /**
709    * Newton-Raphson step for MBAR optimization.
710    *
711    * @param n        current free energies.
712    * @param grad     gradient of the objective function.
713    * @param hessian  hessian of the objective function.
714    * @param stepSize step size for the Newton-Raphson step.
715    * @return updated free energies.
716    */
717   private static double[] newtonStep(double[] n, double[] grad, double[][] hessian, double stepSize) {
718     double[] nPlusOne = new double[n.length];
719     RealMatrix hessianInverse = MatrixUtils.inverse(MatrixUtils.createRealMatrix(hessian));
720     double[] step = hessianInverse.preMultiply(grad);
721     // Zero out the first term of the step
722     double temp = step[0];
723     step[0] = 0.0;
724     for (int i = 1; i < step.length; i++) {
725       step[i] -= temp;
726     }
727     for (int i = 0; i < n.length; i++) {
728       nPlusOne[i] = n[i] - step[i] * stepSize;
729     }
730     return nPlusOne;
731   }
732 
733   /**
734    * Newton-Raphson optimization for MBAR.
735    *
736    * @param f_k       free energies.
737    * @param u_kn      energies.
738    * @param N_k       number of samples per state.
739    * @param stepSize  step size for the Newton-Raphson step.
740    * @param maxIter   maximum number of iterations.
741    * @param tolerance convergence tolerance.
742    * @return updated free energies.
743    */
744   private static double[] newton(double[] f_k, double[][] u_kn, double[] N_k, double stepSize, int maxIter, double tolerance) {
745     double[] grad = mbarGradient(u_kn, N_k, f_k);
746     double[][] hessian = mbarHessian(u_kn, N_k, f_k);
747     double[] f_kPlusOne = newtonStep(f_k, grad, hessian, stepSize);
748     int iter = 1;
749     while (iter < maxIter && MathArrays.distance1(f_k, f_kPlusOne) > tolerance) {
750       f_k = f_kPlusOne;
751       grad = mbarGradient(u_kn, N_k, f_k);
752       hessian = mbarHessian(u_kn, N_k, f_k);
753       f_kPlusOne = newtonStep(f_k, grad, hessian, stepSize);
754       iter++;
755     }
756 
757     logger.fine(" Newton converged after " + iter + " iterations.");
758 
759     return f_kPlusOne;
760   }
761 
762   /**
763    * Calculates the log of the sum of the exponentials of the given values.
764    * <p>
765    * The max value is subtracted from each value in the array before exponentiation to prevent overflow.
766    *
767    * @param values The values to exponentiate and sum.
768    * @param max    The max value is subtracted from each value in the array prior to exponentiation.
769    * @return the sum
770    */
771   private static double logSumExp(double[] values, double max) {
772     double[] b = fill(new double[values.length], 1.0);
773     return logSumExp(values, b, max);
774   }
775 
776   /**
777    * Calculates the log of the sum of the exponentials of the given values.
778    * <p>
779    * The max value is subtracted from each value in the array before exponentiation to prevent overflow.
780    *
781    * @param values The values to exponentiate and sum.
782    * @param max    The max value is subtracted from each value in the array prior to exponentiation.
783    * @param b      Weights for each value in the array.
784    * @return the sum
785    */
786   private static double logSumExp(double[] values, double[] b, double max) {
787     // ChatGPT mostly wrote this and I tweaked it to match more closely with scipy's logsumexp implementation
788     // Find the maximum value in the array.
789     assert values.length == b.length : "values and b must be the same length";
790 
791     // Subtract the maximum value from each value in the array, exponentiate the result, and add up these values.
792     double sum = 0.0;
793     for (int i = 0; i < values.length; i++) {
794       sum += b[i] * exp(values[i] - max);
795     }
796 
797     // Take the natural logarithm of the sum and add the maximum value back in.
798     return max + log(sum);
799   }
800 
801   /**
802    * TODO: Log out the MBAR optimization progress.
803    *
804    * @return
805    */
806   private OptimizationListener getOptimizationListener() {
807     return new OptimizationListener() {
808       @Override
809       public boolean optimizationUpdate(int iter, int nBFGS, int nFunctionEvals, double gradientRMS,
810                                         double coordinateRMS, double f, double df, double angle,
811                                         LineSearch.LineSearchResult info) {
812         return true;
813       }
814     };
815   }
816 
817   /**
818    * MBAR objective function evaluation at a given free energy estimate for L-BFGS optimization.
819    *
820    * @param x Input parameters.
821    * @return The objective function value at the given parameters.
822    */
823   @Override
824   public double energy(double[] x) {
825     // Zero out the first term
826     double tempO = x[0];
827     x[0] = 0.0;
828     for (int i = 1; i < x.length; i++) {
829       x[i] -= tempO;
830     }
831     return mbarObjectiveFunction(u_kn, N_k, x);
832   }
833 
834   /**
835    * MBAR objective function evaluation and gradient at a given free energy estimate for L-BFGS optimization.
836    *
837    * @param x Input parameters.
838    * @param g The gradient with respect to each parameter.
839    * @return The objective function value at the given parameters.
840    */
841   @Override
842   public double energyAndGradient(double[] x, double[] g) {
843     double tempO = x[0];
844     x[0] = 0.0;
845     for (int i = 1; i < x.length; i++) {
846       x[i] -= tempO;
847     }
848     double[] tempG = mbarGradient(u_kn, N_k, x);
849     arraycopy(tempG, 0, g, 0, g.length);
850     return mbarObjectiveFunction(u_kn, N_k, x);
851   }
852 
  /**
   * Not implemented for MBAR: always returns a new empty array.
   *
   * @param parameters ignored.
   * @return an empty array.
   */
  @Override
  public double[] getCoordinates(double[] parameters) {
    return new double[0];
  }

  /**
   * Not implemented for MBAR: always returns 0.
   * <p>
   * NOTE(review): {@code energy(double[])} operates on a non-empty array of free energies, so 0 is not
   * the true dimension; presumably the optimizer is given the dimension directly — confirm.
   *
   * @return 0.
   */
  @Override
  public int getNumberOfVariables() {
    return 0;
  }

  /**
   * No scaling is applied to the MBAR optimization variables.
   *
   * @return always {@code null}.
   */
  @Override
  public double[] getScaling() {
    return null;
  }

  /**
   * No-op: scaling is not supported, so the argument is ignored.
   *
   * @param scaling ignored.
   */
  @Override
  public void setScaling(double[] scaling) {
  }

  /**
   * Not implemented for MBAR: this stub does not track the last objective value.
   *
   * @return 0.
   */
  @Override
  public double getTotalEnergy() {
    return 0;
  }
876 
877   //////// Getters and setters ////////
878   public BennettAcceptanceRatio getBAR() {
879     return new BennettAcceptanceRatio(lamValues, eLow, eAt, eHigh, temperatures);
880   }
881 
882   @Override
883   public MultistateBennettAcceptanceRatio copyEstimator() {
884     return new MultistateBennettAcceptanceRatio(lamValues, eAll, temperatures, tolerance, seedType);
885   }
886 
887   @Override
888   public double[] getBinEnergies() {
889     return mbarEstimates;
890   }
891 
892   public double[] getMBARFreeEnergies() {
893     return mbarFreeEnergies;
894   }
895 
896   @Override
897   public double[] getBinUncertainties() {
898     return mbarUncertainties;
899   }
900 
901   public double[][] getDiffMatrix() {
902     return diffMatrix;
903   }
904 
905   @Override
906   public double getFreeEnergy() {
907     return totalMBAREstimate;
908   }
909 
910   @Override
911   public double getUncertainty() {
912     return totalMBARUncertainty;
913   }
914 
915   @Override
916   public int numberOfBins() {
917     return nFreeEnergyDiffs;
918   }
919 
920   @Override
921   public double[] getBinEnthalpies() {
922     return mbarEnthalpy;
923   }
924 
925   /**
926    * Harmonic oscillators test case generates data for testing the MBAR implementation
927    */
928   public static class HarmonicOscillatorsTestCase {
929 
930     /**
931      * Inverse temperature.
932      */
933     private final double beta;
934     /**
935      * Equilibrium positions.
936      */
937     private final double[] O_k;
938     /**
939      * Number of states.
940      */
941     private final int n_states;
942     /**
943      * Spring constants.
944      */
945     private final double[] K_k;
946 
947     /**
948      * Constructor for HarmonicOscillatorsTestCase
949      *
950      * @param O_k  array of equilibrium positions
951      * @param K_k  array of spring constants
952      * @param beta inverse temperature
953      */
954     public HarmonicOscillatorsTestCase(double[] O_k, double[] K_k, double beta) {
955       this.beta = beta;
956       this.O_k = O_k;
957       this.n_states = O_k.length;
958       this.K_k = K_k;
959 
960       if (this.K_k.length != this.n_states) {
961         throw new IllegalArgumentException("Lengths of K_k and O_k should be equal");
962       }
963     }
964 
965     public double[] analyticalMeans() {
966       return O_k;
967     }
968 
969     public double[] analyticalVariances() {
970       double[] variances = new double[n_states];
971       for (int i = 0; i < n_states; i++) {
972         variances[i] = 1.0 / (beta * K_k[i]);
973       }
974       return variances;
975     }
976 
977     public double[] analyticalStandardDeviations() {
978       double[] deviations = new double[n_states];
979       for (int i = 0; i < n_states; i++) {
980         deviations[i] = Math.sqrt(1.0 / (beta * K_k[i]));
981       }
982       return deviations;
983     }
984 
985     public double[] analyticalObservable(String observable) {
986       double[] result = new double[n_states];
987 
988       switch (observable) {
989         case "position" -> {
990           return analyticalMeans();
991         }
992         case "potential energy" -> {
993           for (int i = 0; i < n_states; i++) {
994             result[i] = 0.5 / beta;
995           }
996         }
997         case "position^2" -> {
998           for (int i = 0; i < n_states; i++) {
999             result[i] = 1.0 / (beta * K_k[i]) + Math.pow(O_k[i], 2);
1000           }
1001         }
1002         case "RMS displacement" -> {
1003           return analyticalStandardDeviations();
1004         }
1005       }
1006 
1007       return result;
1008     }
1009 
1010     public double[] analyticalFreeEnergies() {
1011       int subtractComponentIndex = 0;
1012       double[] fe = new double[n_states];
1013       double subtract = 0.0;
1014       for (int i = 0; i < n_states; i++) {
1015         fe[i] = -0.5 * Math.log(2 * Math.PI / (beta * K_k[i]));
1016         if (i == 0) {
1017           subtract = fe[subtractComponentIndex];
1018         }
1019         fe[i] -= subtract;
1020       }
1021       return fe;
1022     }
1023 
1024     public double[] analyticalEntropies(int subtractComponent) {
1025       double[] entropies = new double[n_states];
1026       double[] potentialEnergy = analyticalObservable("analytical entropy");
1027       double[] freeEnergies = analyticalFreeEnergies();
1028 
1029       for (int i = 0; i < n_states; i++) {
1030         entropies[i] = potentialEnergy[i] - freeEnergies[i];
1031       }
1032 
1033       return entropies;
1034     }
1035 
1036     /**
1037      * Sample from harmonic oscillator w/ gaussian & std
1038      *
1039      * @param N_k  number of samples per state
1040      * @param mode only u_kn -> return K x N_tot matrix where u_kn[k,n] is reduced potential of sample n evaluated at state k
1041      * @return u_kn[k, n] is reduced potential of sample n evaluated at state k
1042      */
1043     public Object[] sample(int[] N_k, String mode, Long seed) {
1044       Random random = new Random(seed);
1045 
1046       int N_max = 0;
1047       for (int N : N_k) {
1048         if (N > N_max) {
1049           N_max = N;
1050         }
1051       }
1052 
1053       int N_tot = 0;
1054       for (int N : N_k) {
1055         N_tot += N;
1056       }
1057 
1058       double[][] x_kn = new double[n_states][N_max];
1059       double[][] u_kn = new double[n_states][N_tot];
1060       double[][][] u_kln = new double[n_states][n_states][N_max];
1061       double[] x_n = new double[N_tot];
1062       int[] s_n = new int[N_tot];
1063 
1064       // Sample harmonic oscillators
1065       int index = 0;
1066       for (int k = 0; k < n_states; k++) {
1067         double x0 = O_k[k];
1068         double sigma = Math.sqrt(1.0 / (beta * K_k[k]));
1069 
1070         // Number of samples
1071         for (int n = 0; n < N_k[k]; n++) {
1072           double x = x0 + random.nextGaussian() * sigma;
1073 
1074           x_kn[k][n] = x;
1075           x_n[index] = x;
1076           s_n[index] = k;
1077 
1078           // Potential energy evaluations
1079           for (int l = 0; l < n_states; l++) {
1080             double u = beta * 0.5 * K_k[l] * Math.pow(x - O_k[l], 2.0);
1081             u_kln[k][l][n] = u;
1082             u_kn[l][index] = u;
1083           }
1084 
1085           index++;
1086         }
1087       }
1088 
1089       // Setting corrections
1090       if ("u_kn".equals(mode)) {
1091         return new Object[]{x_n, u_kn, N_k, s_n};
1092       } else if ("u_kln".equals(mode)) {
1093         return new Object[]{x_n, u_kln, N_k, s_n};
1094       } else {
1095         throw new IllegalArgumentException("Unknown mode: " + mode);
1096       }
1097     }
1098 
1099     public static Object[] evenlySpacedOscillators(
1100         int n_states, int n_samplesPerState, double lower_O_k, double upper_O_k,
1101         double lower_K_k, double upper_K_k, Long seed) {
1102       // Random random = new Random(seed);
1103 
1104       double[] O_k = new double[n_states];
1105       double[] K_k = new double[n_states];
1106       int[] N_k = new int[n_states];
1107 
1108       double stepO_k = (upper_O_k - lower_O_k) / (n_states - 1);
1109       double stepK_k = (upper_K_k - lower_K_k) / (n_states - 1);
1110 
1111       for (int i = 0; i < n_states; i++) {
1112         O_k[i] = lower_O_k + i * stepO_k;
1113         K_k[i] = lower_K_k + i * stepK_k;
1114         N_k[i] = n_samplesPerState;
1115       }
1116 
1117       HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, 1.0);
1118       Object[] result = testCase.sample(N_k, "u_kn", System.currentTimeMillis());
1119 
1120       return new Object[]{testCase, result[0], result[1], result[2], result[3]};
1121     }
1122 
1123     public static void main(String[] args) {
1124       // Example parameters
1125       double[] O_k = {0, 1, 2, 3, 4};
1126       double[] K_k = {1, 2, 4, 8, 16};
1127       double beta = 1.0;
1128       System.out.println("Beta: " + beta);
1129 
1130       // Create an instance of HarmonicOscillatorsTestCase
1131       HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);
1132 
1133       // Print results of various functions
1134       System.out.println("Analytical Means: " + Arrays.toString(testCase.analyticalMeans()));
1135       System.out.println("Analytical Variances: " + Arrays.toString(testCase.analyticalVariances()));
1136       System.out.println("Analytical Standard Deviations: " + Arrays.toString(testCase.analyticalStandardDeviations()));
1137       System.out.println("Analytical Free Energies: " + Arrays.toString(testCase.analyticalFreeEnergies()));
1138 
1139       // Example usage of sample function with u_kn mode
1140       int[] N_k = {10, 20, 30, 40, 50};
1141       String setting = "u_kln";
1142       Object[] sampleResult = testCase.sample(N_k, setting, System.currentTimeMillis());
1143 
1144       System.out.println("Sample x_n: " + Arrays.toString((double[]) sampleResult[0]));
1145       if ("u_kn".equals(setting)) {
1146         System.out.println("Sample u_kn: " + Arrays.deepToString((double[][]) sampleResult[1]));
1147       } else {
1148         System.out.println("Sample u_kln: " + Arrays.deepToString((double[][][]) sampleResult[1]));
1149       }
1150       System.out.println("Sample N_k: " + Arrays.toString((int[]) sampleResult[2]));
1151       System.out.println("Sample s_n: " + Arrays.toString((int[]) sampleResult[3]));
1152     }
1153   }
1154 
1155   public static void writeFile(double[][] energies, File file, double temperature) {
1156     try (FileWriter fw = new FileWriter(file);
1157          BufferedWriter bw = new BufferedWriter(fw)) {
1158       // Write the number of snapshots and the temperature on the first line
1159       bw.write(energies[0].length + " " + temperature);
1160       bw.newLine();
1161 
1162       // Write the energies
1163       StringBuilder sb = new StringBuilder();
1164       for (int i = 0; i < energies[0].length; i++) {
1165         sb.append("     ").append(i).append(" "); // Write the index of the snapshot
1166         for (int j = 0; j < energies.length; j++) {
1167           sb.append("    ").append(energies[j][i]).append(" ");
1168         }
1169         sb.append("\n");
1170         bw.write(sb.toString());
1171         sb = new StringBuilder(); // Very important
1172       }
1173     } catch (IOException e) {
1174       e.printStackTrace();
1175     }
1176   }
1177 
  /**
   * Demonstration driver comparing MBAR against analytical harmonic-oscillator free energies.
   * <p>
   * Samples four harmonic oscillators with a fixed seed and runs MBAR on the resulting u_kln
   * energies. It prints the MBAR free energies, uncertainties and difference matrix next to the
   * analytical values, then repeats the comparison using an {@code EstimateBootstrapper}.
   *
   * @param args unused.
   */
  public static void main(String[] args) {
    double[] O_k = {1, 2, 3, 4}; // Equilibrium positions
    double[] K_k = {.5, 1.0, 1.5, 2}; // Spring constants
    int[] N_k = {10000, 10000, 10000, 10000}; // No support for different number of snapshots
    double beta = 1.0;

    // Create an instance of HarmonicOscillatorsTestCase
    HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);

    // Generate sample data
    String setting = "u_kln";
    System.out.print("Generating sample data... ");
    Object[] sampleResult = testCase.sample(N_k, setting, (long) 0); // Set seed to fixed value for reproducibility
    System.out.println("done. \n");
    double[][][] u_kln = (double[][][]) sampleResult[1];
    // Presumably T = 1/R so that 1/(R*T) equals the beta = 1 used for the sampled reduced
    // potentials — confirm the units of Constants.R.
    double[] temps = {1 / Constants.R};

    // Write file for comparison with pymbar
    // Output to forcefieldx/testing/mbar/data/harmonic_oscillators/mbarFiles/energies_{i}.mbar
    // Get absolute path to root of project

    /*
    String rootPath = new File("").getAbsolutePath();
    File outputPath = new File(rootPath + "/testing/mbar/data/harmonic_oscillators/mbarFiles");
    if (!outputPath.exists() && !outputPath.mkdirs()) {
      throw new RuntimeException("Failed to create directory: " + outputPath);
    }

    double[] temperatures = new double[O_k.length];
    Arrays.fill(temperatures, temps[0]);
    for (int i = 0; i < u_kln.length; i++) {
      File file = new File(outputPath, "energies_" + i + ".mbar");
      writeFile(u_kln[i], file, temperatures[i]);
    } */

    // Create an instance of MultistateBennettAcceptanceRatio.
    // O_k is passed as the per-state lambda values; per the progress message, the constructor also
    // runs estimateDG(), so the result fields below are read straight after construction.
    System.out.print("Creating MBAR instance and estimateDG() with standard tol & Zwanzig seeding.");
    MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0E-7, SeedType.ZWANZIG);
    double[] mbarFEEstimates = Arrays.copyOf(mbar.mbarFreeEnergies, mbar.mbarFreeEnergies.length);
    double[] mbarUncertainties = Arrays.copyOf(mbar.mbarUncertainties, mbar.mbarUncertainties.length);
    // Note: Arrays.copyOf on a 2D array copies only the row references (shallow copy).
    double[][] mbarDiffMatrix = Arrays.copyOf(mbar.diffMatrix, mbar.diffMatrix.length);

    // 50 is presumably the number of bootstrap trials — confirm against EstimateBootstrapper.
    EstimateBootstrapper bootstrapper = new EstimateBootstrapper(mbar);
    bootstrapper.bootstrap(50);
    System.out.println("done. \n");

    // Get the analytical free energy differences
    double[] analyticalFreeEnergies = testCase.analyticalFreeEnergies();
    // Calculate the error
    double[] error = new double[analyticalFreeEnergies.length];
    for (int i = 0; i < error.length; i++) {
      error[i] = -mbarFEEstimates[i] + analyticalFreeEnergies[i];
    }

    // Compare the calculated free energy differences with the analytical ones
    System.out.println("MBAR Free Energies:       " + Arrays.toString(mbarFEEstimates));
    System.out.println("Analytical Free Energies: " + Arrays.toString(analyticalFreeEnergies));
    System.out.println("MBAR Uncertainties:       " + Arrays.toString(mbarUncertainties));
    System.out.println("Free Energy Error:        " + Arrays.toString(error));
    System.out.println();
    System.out.println("Diff Matrix: ");
    for (double[] matrix : mbarDiffMatrix) {
      System.out.println(Arrays.toString(matrix));
    }
    System.out.println("\n\n");

    // Get the calculated free energy differences.
    // The loop below treats getFE() as differences between consecutive states and accumulates them
    // into free energies relative to state 0 (mbarBootstrappedFE[0] stays 0).
    double[] mbarBootstrappedEstimates = bootstrapper.getFE();
    double[] mbarBootstrappedFE = new double[mbarBootstrappedEstimates.length + 1];
    for (int i = 0; i < mbarBootstrappedEstimates.length; i++) {
      mbarBootstrappedFE[i + 1] = mbarBootstrappedEstimates[i] + mbarBootstrappedFE[i];
    }
    mbarUncertainties = bootstrapper.getUncertainty();
    // Calculate the error
    double[] errors = new double[mbarBootstrappedFE.length];
    for (int i = 0; i < errors.length; i++) {
      errors[i] = -mbarBootstrappedFE[i] + analyticalFreeEnergies[i];
    }

    System.out.println("MBAR Bootstrapped Estimates:  " + Arrays.toString(mbarBootstrappedFE));
    System.out.println("Analytical Estimates:         " + Arrays.toString(analyticalFreeEnergies));
    System.out.println("MBAR Bootstrap Uncertainties: " + Arrays.toString(mbarUncertainties));
    System.out.println("Bootstrap Free Energy Error:  " + Arrays.toString(errors));
  }
1262 }