1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2023.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.estimator;
39  
40  import ffx.numerics.OptimizationInterface;
41  import ffx.numerics.integrate.DataSet;
42  import ffx.numerics.integrate.DoublesDataSet;
43  import ffx.numerics.integrate.Integrate1DNumeric;
44  import ffx.numerics.optimization.LBFGS;
45  import ffx.numerics.optimization.LineSearch;
46  import ffx.numerics.optimization.OptimizationListener;
47  import ffx.utilities.Constants;
48  import org.apache.commons.math3.linear.MatrixUtils;
49  import org.apache.commons.math3.linear.RealMatrix;
50  import org.apache.commons.math3.linear.SingularValueDecomposition;
51  
52  import java.io.BufferedWriter;
53  import java.io.File;
54  import java.io.FileWriter;
55  import java.io.IOException;
56  import java.util.ArrayList;
57  import java.util.Arrays;
58  import java.util.Random;
59  import java.util.logging.Logger;
60  
61  import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
62  import static ffx.numerics.estimator.Zwanzig.Directionality.BACKWARDS;
63  import static ffx.numerics.estimator.Zwanzig.Directionality.FORWARDS;
64  import static java.lang.System.arraycopy;
65  import static java.util.Arrays.copyOf;
66  import static java.util.Arrays.stream;
67  import static org.apache.commons.lang3.ArrayFill.fill;
68  import static org.apache.commons.math3.util.FastMath.abs;
69  import static org.apache.commons.math3.util.FastMath.exp;
70  import static org.apache.commons.math3.util.FastMath.log;
71  import static org.apache.commons.math3.util.FastMath.sqrt;
72  
73  /**
74   * The MultistateBennettAcceptanceRatio class defines a statistical estimator based on a generalization
75   * to the Bennett Acceptance Ratio (BAR) method for multiple lambda windows. It requires as input a
76   * K x N array of energies (every snapshot from every window evaluated at every lambda value). Different
77   * numbers of snapshots per window are not supported; this case is caught by the filter, but not by the
78   * Harmonic Oscillators test case.
79   * <p>
80   * This class implements the method discussed in:
81   * Shirts, M. R. and Chodera, J. D. (2008) Statistically optimal analysis of samples from multiple equilibrium
82   * states. J. Chem. Phys. 129, 124105. doi:10.1063/1.2978177
83   * <p>
84   * This class is based heavily on the pymbar code, which is available from
85   * <a href="https://github.com/choderalab/pymbar/tree/master">Github</a>.
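     * <p>
     * A minimal construction sketch (illustrative values only; {@code readEnergies()} is a hypothetical
     * loader returning the energy of every snapshot from every window evaluated at every lambda).
     * The constructor runs {@link #estimateDG()} itself:
     * <pre>{@code
     * double[] lambda = {0.0, 0.5, 1.0};
     * double[] temperature = {298.15, 298.15, 298.15};
     * double[][][] energies = readEnergies(); // hypothetical loader
     * MultistateBennettAcceptanceRatio mbar =
     *     new MultistateBennettAcceptanceRatio(lambda, energies, temperature);
     * }</pre>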
86   *
87   * @author Matthew J. Speranza
88   * @since 1.0
89   */
90  public class MultistateBennettAcceptanceRatio extends SequentialEstimator implements BootstrappableEstimator, OptimizationInterface {
91    private static final Logger logger = Logger.getLogger(MultistateBennettAcceptanceRatio.class.getName());
92    /**
93     * Default MBAR convergence tolerance.
94     */
95    private static final double DEFAULT_TOLERANCE = 1.0E-7;
96    /**
97    * Number of free-energy differences between simulation windows. Calculated at the very end.
98     */
99    private final int nFreeEnergyDiffs;
100   /**
101    * MBAR free-energy difference estimates (nFreeEnergyDiffs values).
102    */
103   private final double[] mbarFEDifferenceEstimates;
104   /**
105    * Number of lambda states (nFreeEnergyDiffs + 1).
106    */
107   private final int nLambdaStates;
108   /**
109    * MBAR free-energy estimates at each lambda value (nStates values). The first value is defined as 0 throughout to
110    * promote stability. This estimate is novel to the MBAR method and is not seen in BAR or Zwanzig. Only the differences
111    * between these values have physical significance.
112    */
113   private double[] mbarFEEstimates;
114   /**
115    * MBAR observable ensemble estimates.
116    */
117   private double[] mbarObservableEnsembleAverages;
118   private double[] mbarObservableEnsembleAverageUncertainties;
119   /**
120    * MBAR free-energy difference uncertainties.
121    */
122   private double[] mbarUncertainties;
123   /**
124    * Matrix of free-energy difference uncertainties between all i & j
125    */
126   private double[][] uncertaintyMatrix;
127   /**
128    * MBAR convergence tolerance.
129    */
130   private final double tolerance;
131   /**
132    * Random number generator used for bootstrapping.
133    */
134   private final Random random;
135   /**
136    * Total MBAR free-energy difference estimate.
137    */
138   private double totalMBAREstimate;
139   /**
140    * Total MBAR free-energy difference uncertainty.
141    */
142   private double totalMBARUncertainty;
143   /**
144    * MBAR Enthalpy estimates
145    */
146   private double[] mbarEnthalpy;
147 
148   /**
149    * MBAR Entropy estimates
150    */
151   private double[] mbarEntropy;
152 
153   public double[] rtValues;
154 
155   /**
156    * "Reduced" potential energies. -ln(exp(beta * -U)) or more practically U * (1 / RT).
157    * Has shape (nLambdaStates, numSnaps * nLambdaStates)
158    */
159   private double[][] reducedPotentials;
160 
161   private double[][] oAllFlat;
162   private double[][] biasFlat;
163   /**
164    * Seed the MBAR calculation with another free-energy estimate (BAR, Zwanzig) or zeros.
165    */
166   private SeedType seedType;
167 
168 
169   /**
170    * Enum of MBAR seed types.
171    */
172   public enum SeedType {BAR, ZWANZIG, ZEROS;}
173 
174   public static boolean FORCE_ZEROS_SEED = false;
175   public static boolean VERBOSE = false;
176 
177   /**
178    * Constructor for MBAR estimator.
179    *
180    * @param lambdaValues array of lambda values
181    * @param energiesAll  array of energies at each lambda value
182    * @param temperature  array of temperatures
183    */
184   public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature) {
185     this(lambdaValues, energiesAll, temperature, DEFAULT_TOLERANCE, SeedType.ZWANZIG);
186   }
187 
188   /**
189    * Constructor for MBAR estimator.
190    *
191    * @param lambdaValues array of lambda values
192    * @param energiesAll  array of energies at each lambda value
193    * @param temperature  array of temperatures
194    * @param tolerance    convergence tolerance
195    * @param seedType     seed type for MBAR
196    */
197   public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature,
198                                           double tolerance, SeedType seedType) {
199     super(lambdaValues, energiesAll, temperature);
200     this.tolerance = tolerance;
201     this.seedType = seedType;
202 
203     // MBAR calculates free energy at each lambda value (only the differences between them have physical significance)
204     nLambdaStates = lambdaValues.length;
205     mbarFEEstimates = new double[nLambdaStates];
206 
207     nFreeEnergyDiffs = lambdaValues.length - 1;
208     mbarFEDifferenceEstimates = new double[nFreeEnergyDiffs];
209     mbarUncertainties = new double[nFreeEnergyDiffs];
210     mbarEnthalpy = new double[nFreeEnergyDiffs];
211     mbarEntropy = new double[nFreeEnergyDiffs];
212     random = new Random();
213     estimateDG();
214   }
215 
216   public MultistateBennettAcceptanceRatio(double[] lambdaValues, int[] snaps, double[][] eAllFlat, double[] temperature,
217                                           double tolerance, SeedType seedType) {
218     super(lambdaValues, snaps, eAllFlat, temperature);
219     this.tolerance = tolerance;
220     this.seedType = seedType;
221 
222     // MBAR calculates free energy at each lambda value (only the differences between them have physical significance)
223     nLambdaStates = lambdaValues.length;
224     mbarFEEstimates = new double[nLambdaStates];
225 
226     nFreeEnergyDiffs = lambdaValues.length - 1;
227     mbarFEDifferenceEstimates = new double[nFreeEnergyDiffs];
228     mbarUncertainties = new double[nFreeEnergyDiffs];
229     mbarEnthalpy = new double[nFreeEnergyDiffs];
230     mbarEntropy = new double[nFreeEnergyDiffs];
231     random = new Random();
232     estimateDG();
233   }
234 
235   /**
236    * Set the MBAR seed energies using BAR, Zwanzig, or zeros.
237    */
238   private void seedEnergies() {
239     switch (seedType) {
240       case BAR:
241         try {
242           if (eLambdaMinusdL == null || eLambda == null || eLambdaPlusdL == null) {
243             seedType = SeedType.ZEROS;
244             seedEnergies();
245             return;
246           }
247           SequentialEstimator barEstimator = new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures);
248           mbarFEEstimates[0] = 0.0;
249           double[] barEstimates = barEstimator.getFreeEnergyDifferences();
250           for (int i = 0; i < nFreeEnergyDiffs; i++) {
251             mbarFEEstimates[i + 1] = mbarFEEstimates[i] + barEstimates[i];
252           }
253           break;
254         } catch (IllegalArgumentException e) {
255           logger.warning(" BAR failed to converge. Zwanzig will be used for seed energies.");
256           seedType = SeedType.ZWANZIG;
257           seedEnergies();
258           return;
259         }
260       case ZWANZIG:
261         try {
262           if (eLambdaMinusdL == null || eLambda == null || eLambdaPlusdL == null) {
263             seedType = SeedType.ZEROS;
264             seedEnergies();
265             return;
266           }
267           Zwanzig forwardsFEP = new Zwanzig(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, FORWARDS);
268           Zwanzig backwardsFEP = new Zwanzig(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, BACKWARDS);
269           double[] forwardZwanzig = forwardsFEP.getFreeEnergyDifferences();
270           double[] backwardZwanzig = backwardsFEP.getFreeEnergyDifferences();
271           mbarFEEstimates[0] = 0.0;
272           for (int i = 0; i < nFreeEnergyDiffs; i++) {
273             mbarFEEstimates[i + 1] = mbarFEEstimates[i] + .5 * (forwardZwanzig[i] + backwardZwanzig[i]);
274           }
275           if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
276             throw new IllegalArgumentException("MBAR contains NaNs or Infs after seeding.");
277           }
278           break;
279         } catch (IllegalArgumentException e) {
280           logger.warning(" Zwanzig failed to converge. Zeros will be used for seed energies.");
281           seedType = SeedType.ZEROS;
282           seedEnergies();
283           return;
284         }
285       case ZEROS:
286         fill(mbarFEEstimates, 0.0);
287         break;
288       default:
289         throw new IllegalArgumentException("Seed type not supported");
290     }
291   }
292 
293   /**
294    * Get the MBAR free-energy estimates at each lambda value.
295    */
296   @Override
297   public void estimateDG() {
298     estimateDG(false);
299   }
300 
301   /**
302    * MBAR solved with self-consistent iteration and Newton/L-BFGS optimization.
303    */
304   @Override
305   public void estimateDG(boolean randomSamples) {
306     if (MultistateBennettAcceptanceRatio.VERBOSE) {
307       logger.setLevel(java.util.logging.Level.FINE);
308     }
309 
310     // Bootstrap needs resetting to zeros
311     fill(mbarFEEstimates, 0.0);
312     if (FORCE_ZEROS_SEED) {
313       seedType = SeedType.ZEROS;
314     }
315     seedEnergies();
316     if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
317       seedType = SeedType.ZEROS;
318       seedEnergies();
319     }
320     if (MultistateBennettAcceptanceRatio.VERBOSE) {
321       logger.info(" Seed Type: " + seedType);
322       logger.info(" MBAR FE Estimates after seeding: " + Arrays.toString(mbarFEEstimates));
323     }
324 
325     // Precompute beta for each state.
326     rtValues = new double[nLambdaStates];
327     double[] invRTValues = new double[nLambdaStates];
328     for (int i = 0; i < nLambdaStates; i++) {
329       rtValues[i] = Constants.R * temperatures[i];
330       invRTValues[i] = 1.0 / rtValues[i];
331     }
332 
333     // Find repeated snapshots if from continuous lambda
334     int numEvaluations = eAllFlat[0].length;
335 
336     // Sample random snapshots from each window.
337     int[][] indices = new int[nLambdaStates][numEvaluations];
338     if (randomSamples) {
339       // Build random indices vector maintaining snapshot nums!
340       int[] randomIndices = new int[numEvaluations];
341       int sum = 0;
342       for (int snap : nSamples) {
343         System.arraycopy(getBootstrapIndices(snap, random), 0, randomIndices, sum, snap);
344         sum += snap;
345       }
346       for (int i = 0; i < nLambdaStates; i++) {
347         // Use the same random indices across lambda values
348         indices[i] = randomIndices;
349       }
350     } else {
351       for (int i = 0; i < numEvaluations; i++) {
352         for (int j = 0; j < nLambdaStates; j++) {
353           indices[j][i] = i;
354         }
355       }
356     }
357 
358     // Precompute reducedPotentials since it doesn't change
359     reducedPotentials = new double[nLambdaStates][numEvaluations];
360     double minPotential = Double.POSITIVE_INFINITY;
361     for (int state = 0; state < eAllFlat.length; state++) { // For each lambda value
362       for (int n = 0; n < eAllFlat[0].length; n++) {
363         reducedPotentials[state][n] = eAllFlat[state][indices[state][n]] * invRTValues[state];
364         if (reducedPotentials[state][n] < minPotential) {
365           minPotential = reducedPotentials[state][n];
366         }
367       }
368     }
369 
370     // Subtract the minimum potential from all potentials (we are calculating relative free energies anyway)
371     for (int state = 0; state < nLambdaStates; state++) {
372       for (int n = 0; n < numEvaluations; n++) {
373         reducedPotentials[state][n] -= minPotential;
374       }
375     }
376 
377     // Remove reduced potential arrays where snaps are zero since they cause issues for N.R. and SCI
378     // i.e. no trajectories for that lambda were generated/sampled, but other trajectories had potentials evaluated at that lambda
379     ArrayList<Integer> zeroSnapLambdas = new ArrayList<>();
380     ArrayList<Integer> sampledLambdas = new ArrayList<>();
381     for (int i = 0; i < nLambdaStates; i++) {
382       if (nSamples[i] == 0) {
383         zeroSnapLambdas.add(i);
384       } else {
385         sampledLambdas.add(i);
386       }
387     }
388     int nLambdaStatesTemp = nLambdaStates - zeroSnapLambdas.size();
389     double[][] reducedPotentialsTemp = new double[nLambdaStates - zeroSnapLambdas.size()][numEvaluations];
390     double[] mbarFEEstimatesTemp = new double[nLambdaStates - zeroSnapLambdas.size()];
391     int[] snapsTemp = new int[nLambdaStates - zeroSnapLambdas.size()];
392     if (!zeroSnapLambdas.isEmpty()) {
393       int index = 0;
394       for (int i = 0; i < nLambdaStates; i++) {
395         if (!zeroSnapLambdas.contains(i)) {
396           reducedPotentialsTemp[index] = reducedPotentials[i];
397           mbarFEEstimatesTemp[index] = mbarFEEstimates[i];
398           snapsTemp[index] = nSamples[i];
399           index++;
400         }
401       }
402       logger.info(" Sampled Lambdas: " + sampledLambdas);
403       logger.info(" Zero Snap Lambdas: " + zeroSnapLambdas);
404     } else { // If there aren't any zero snap lambdas, just use the original arrays
405       reducedPotentialsTemp = reducedPotentials;
406       mbarFEEstimatesTemp = mbarFEEstimates;
407       snapsTemp = nSamples;
408     }
409 
410     // SCI iterations used to start optimization of MBAR objective function.
411     // Optimizers can struggle when starting too far from the minimum, but SCI doesn't.
412     double[] prevMBAR = copyOf(mbarFEEstimatesTemp, nLambdaStatesTemp);
414     double omega = 1.5; // Parameter chosen empirically to work with most systems (> 2 works but not always).
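        // Each pass applies the SCI update and then mixes it with the previous estimate:
        // f_new = omega * f_SCI + (1 - omega) * f_old (over-relaxation when omega > 1).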
415     for (int i = 0; i < 10; i++) {
416       prevMBAR = copyOf(mbarFEEstimatesTemp, nLambdaStatesTemp);
417       mbarFEEstimatesTemp = mbarSelfConsistentUpdate(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
418       for (int j = 0; j < nLambdaStatesTemp; j++) { // SOR
419         mbarFEEstimatesTemp[j] = omega * mbarFEEstimatesTemp[j] + (1 - omega) * prevMBAR[j];
420       }
421       if (stream(mbarFEEstimatesTemp).anyMatch(Double::isInfinite) || stream(mbarFEEstimatesTemp).anyMatch(Double::isNaN)) {
422         throw new IllegalArgumentException("MBAR contains NaNs or Infs during startup SCI ");
423       }
424       if (converged(prevMBAR)) {
425         break;
426       }
427     }
428     if (MultistateBennettAcceptanceRatio.VERBOSE) {
429       logger.info(" Omega for SCI w/ relaxation: " + omega);
430       logger.info(" MBAR FE Estimates after 10 SCI iterations: " + Arrays.toString(mbarFEEstimatesTemp));
431     }
432 
433     try {
434       if (nLambdaStatesTemp > 100 && !converged(prevMBAR)) { // L-BFGS optimization for high granularity windows where hessian^-1 is expensive
435         if (MultistateBennettAcceptanceRatio.VERBOSE) {
436           logger.info(" L-BFGS optimization started.");
437         }
438         int mCorrections = 5;
439         double[] x = new double[nLambdaStatesTemp];
440         arraycopy(mbarFEEstimatesTemp, 0, x, 0, nLambdaStatesTemp);
441         double[] grad = mbarGradient(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
442         double eps = 1.0E-4; // Gradient tolerance -> chosen since L-BFGS seems unstable with tight tolerances
443         OptimizationListener listener = getOptimizationListener();
444         LBFGS.minimize(nLambdaStatesTemp, mCorrections, x, mbarObjectiveFunction(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp),
445             grad, eps, 1000, this, listener);
446         arraycopy(x, 0, mbarFEEstimatesTemp, 0, nLambdaStatesTemp);
447       } else if (!converged(prevMBAR)) { // Newton optimization if hessian inversion isn't too expensive
448         if (MultistateBennettAcceptanceRatio.VERBOSE) {
449           logger.info(" Newton optimization started.");
450         }
451         mbarFEEstimatesTemp = newton(mbarFEEstimatesTemp, reducedPotentialsTemp, snapsTemp, tolerance);
452       }
453     } catch (Exception e) {
454       logger.warning(" L-BFGS/Newton failed to converge. Finishing w/ self-consistent iteration. Message: " +
455           e.getMessage());
456     }
457     if (MultistateBennettAcceptanceRatio.VERBOSE) {
458       logger.info(" MBAR FE Estimates after gradient optimization: " + Arrays.toString(mbarFEEstimatesTemp));
459     }
460 
461     // Update the FE estimates with the optimized values from derivative-based optimization
462     int count = 0;
463     for (Integer i : sampledLambdas) {
464       if (!Double.isNaN(mbarFEEstimatesTemp[count])) { // Only copy back finite (non-NaN) estimates
465         mbarFEEstimates[i] = mbarFEEstimatesTemp[count];
466       }
467       count++;
468     }
469 
470     // Self-consistent iteration is used to finish off optimization of MBAR objective function
471     int sciIter = 0;
472     while (!converged(prevMBAR) && sciIter < 1000) {
473       prevMBAR = copyOf(mbarFEEstimates, nLambdaStates);
474       mbarFEEstimates = mbarSelfConsistentUpdate(reducedPotentials, nSamples, mbarFEEstimates);
475       for (int i = 0; i < nLambdaStates; i++) { // SOR for acceleration
476         mbarFEEstimates[i] = omega * mbarFEEstimates[i] + (1 - omega) * prevMBAR[i];
477       }
478       if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
479         throw new IllegalArgumentException("MBAR estimate contains NaNs or Infs after iteration " + sciIter);
480       }
481       sciIter++;
482     }
483     if (MultistateBennettAcceptanceRatio.VERBOSE) {
484       logger.info(" SCI iterations (max 1000): " + sciIter);
485     }
486 
487     // Calculate uncertainties
488     double[][] theta = mbarTheta(reducedPotentials, nSamples, mbarFEEstimates); // Quite expensive
489     mbarUncertainties = mbarUncertaintyCalc(theta);
490     totalMBARUncertainty = mbarTotalUncertaintyCalc(theta);
491     uncertaintyMatrix = diffMatrixCalculation(theta);
492     if (!randomSamples && MultistateBennettAcceptanceRatio.VERBOSE) { // Never log for bootstrapping
493       logWeights();
494     }
495 
496     // Convert to kcal/mol & calculate differences/sums
497     for (int i = 0; i < nLambdaStates; i++) {
498       mbarFEEstimates[i] = mbarFEEstimates[i] * rtValues[i];
499     }
500     for (int i = 0; i < nFreeEnergyDiffs; i++) {
501       mbarFEDifferenceEstimates[i] = mbarFEEstimates[i + 1] - mbarFEEstimates[i];
502     }
503 
504     mbarEnthalpy = mbarEnthalpyCalc(eAllFlat, mbarFEEstimates);
505     mbarEntropy = mbarEntropyCalc(mbarEnthalpy, mbarFEEstimates);
506 
507     totalMBAREstimate = stream(mbarFEDifferenceEstimates).sum();
508   }
509 
510 
511   //////// Misc. Methods ////////////
512 
513   /**
514    * Checks if the MBAR free energy estimates have converged by comparing the difference
515    * between the previous and current free energies. The tolerance is set by the user.
516    *
517    * @param prevMBAR previous MBAR free energy estimates.
518    * @return true if converged, false otherwise
519    */
520   private boolean converged(double[] prevMBAR) {
521     double[] differences = new double[prevMBAR.length];
522     for (int i = 0; i < prevMBAR.length; i++) {
523       differences[i] = abs(prevMBAR[i] - mbarFEEstimates[i]);
524     }
525     return stream(differences).allMatch(d -> d < tolerance);
526   }
527 
528   /**
529    * Print out, for each FE expectation, the sum of the weights for each trajectory. This
530    * gives an array of length nLambdaStates, where each element is the sum of the weights
531    * coming from the trajectory sampled at the lambda value corresponding to that index.
532    *
533    * <p>i.e. collapsedW[0][0] is the sum of the weights in W[0] from the trajectory sampled at
534    * lambda 0. The diagonal of this matrix should be larger than all other values if that
535    * window had proper sampling.
536    */
537   private void logWeights() {
538     logger.info(" MBAR Weight Matrix Information Collapsed:");
539     double[][] W = mbarW(reducedPotentials, nSamples, mbarFEEstimates);
540     double[][] collapsedW = new double[W.length][W.length]; // Collapse W trajectory-wise (into K x K)
541     for (int i = 0; i < nSamples.length; i++) {
542       for (int j = 0; j < W.length; j++) {
543         int start = 0;
544         for (int k = 0; k < i; k++) {
545           start += nSamples[k];
546         }
547         for (int k = 0; k < nSamples[i]; k++) {
548           collapsedW[j][i] += W[j][start + k];
549         }
550       }
551     }
552     for (int i = 0; i < W.length; i++) {
553       logger.info("\n Estimation " + i + ": " + Arrays.toString(collapsedW[i]));
554     }
555     double[] rowSum = new double[W.length];
556     for (int i = 0; i < collapsedW[0].length; i++) {
557       for (double[] trajectory : collapsedW) {
558         rowSum[i] += trajectory[i];
559       }
560     }
561     softMax(rowSum);
562     logger.info("\n Softmax of trajectory weight: " + Arrays.toString(rowSum));
563   }
564 
565 
566   //////// Methods for calculating MBAR variables, vectors, and matrices. ////////
567 
568   /**
569    * MBAR objective function. This is used for L-BFGS optimization.
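       * As implemented below, the value is sum_n log(sum_k N_k * exp(f_k - u_k(x_n))) - sum_k N_k * f_k,
       * where N_k is the number of snapshots from state k, f_k the current free-energy estimate, and
       * u_k(x_n) the reduced potential of snapshot n evaluated at state k.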
570    *
571    * @param reducedPotentials   -ln(boltzmann weights)
572    * @param snapsPerLambda      number of snaps per state
573    * @param freeEnergyEstimates free energies
574    * @return The objective function value.
575    */
576   private static double mbarObjectiveFunction(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
577     if (stream(freeEnergyEstimates).anyMatch(Double::isInfinite) || stream(freeEnergyEstimates).anyMatch(Double::isNaN)) {
578       throw new IllegalArgumentException("MBAR contains NaNs or Infs.");
579     }
580     int nStates = freeEnergyEstimates.length;
581     double[] log_denom_n = new double[reducedPotentials[0].length];
582     for (int i = 0; i < reducedPotentials[0].length; i++) {
583       double[] temp = new double[nStates];
584       double maxTemp = Double.NEGATIVE_INFINITY;
585       for (int j = 0; j < nStates; j++) {
586         temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
587         if (temp[j] > maxTemp) {
588           maxTemp = temp[j];
589         }
590       }
591       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
592     }
593     double[] dotNkFk = new double[snapsPerLambda.length];
594     for (int i = 0; i < snapsPerLambda.length; i++) {
595       dotNkFk[i] = snapsPerLambda[i] * freeEnergyEstimates[i];
596     }
597     return stream(log_denom_n).sum() - stream(dotNkFk).sum();
598   }
599 
600   /**
601    * Gradient of the MBAR objective function. C6 in Shirts and Chodera 2008.
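       * As implemented below, grad_k = -N_k * (1 - exp(f_k + log sum_n exp(-u_k(x_n) - log_denom_n))),
       * where log_denom_n = log sum_j N_j * exp(f_j - u_j(x_n)).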
602    *
603    * @param reducedPotentials   energies
604    * @param snapsPerLambda      number of snaps per state
605    * @param freeEnergyEstimates free energies
606    * @return Gradient for the mbar objective function.
607    */
608   private static double[] mbarGradient(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
609     int nStates = freeEnergyEstimates.length;
610     double[] log_num_k = new double[nStates];
611     double[] log_denom_n = new double[reducedPotentials[0].length];
612     double[][] logDiff = new double[reducedPotentials.length][reducedPotentials[0].length];
613     double[] maxLogDiff = new double[nStates];
614     Arrays.fill(maxLogDiff, Double.NEGATIVE_INFINITY);
615     for (int i = 0; i < reducedPotentials[0].length; i++) {
616       double[] temp = new double[nStates];
617       double maxTemp = Double.NEGATIVE_INFINITY;
618       for (int j = 0; j < nStates; j++) {
619         temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
620         if (temp[j] > maxTemp) {
621           maxTemp = temp[j];
622         }
623       }
624       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
625       for (int j = 0; j < nStates; j++) {
626         logDiff[j][i] = -log_denom_n[i] - reducedPotentials[j][i];
627         if (logDiff[j][i] > maxLogDiff[j]) {
628           maxLogDiff[j] = logDiff[j][i];
629         }
630       }
631     }
632     for (int i = 0; i < nStates; i++) {
633       log_num_k[i] = logSumExp(logDiff[i], maxLogDiff[i]);
634     }
635     double[] grad = new double[nStates];
636     for (int i = 0; i < nStates; i++) {
637       grad[i] = -1.0 * snapsPerLambda[i] * (1.0 - exp(freeEnergyEstimates[i] + log_num_k[i]));
638     }
639     return grad;
640   }
641 
642   /**
643    * Hessian of the MBAR objective function. C9 in Shirts and Chodera 2008.
644    *
645    * @param reducedPotentials   energies
646    * @param snapsPerLambda      number of snaps per state
647    * @param freeEnergyEstimates free energies
648    * @return Hessian for the mbar objective function.
649    */
650   private static double[][] mbarHessian(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
651     int nStates = freeEnergyEstimates.length;
652     double[][] W = mbarW(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
653     // h = dot(W.T, W) * snapsPerLambda * snapsPerLambda[:, newaxis] - diag(W.sum(0) * snapsPerLambda)
654     double[][] hessian = new double[nStates][nStates];
655     for (int i = 0; i < nStates; i++) {
656       for (int j = 0; j < nStates; j++) {
657         double sum = 0.0;
658         for (int k = 0; k < reducedPotentials[0].length; k++) {
659           sum += W[i][k] * W[j][k];
660         }
661         hessian[i][j] = sum * snapsPerLambda[i] * snapsPerLambda[j];
662       }
663       double wSum = 0.0;
664       for (int k = 0; k < W[i].length; k++) {
665         wSum += W[i][k];
666       }
667       hessian[i][i] -= wSum * snapsPerLambda[i];
668     }
669     // h = -h
670     for (int i = 0; i < nStates; i++) {
671       for (int j = 0; j < nStates; j++) {
672         hessian[i][j] = -hessian[i][j];
673       }
674     }
675     return hessian;
676   }
677 
678   /**
679    * W = exp(freeEnergyEstimates - reducedPotentials.T - log_denominator_n[:, newaxis])
680    * Eq. 9 in Shirts and Chodera 2008.
681    *
682    * @param reducedPotentials   energies
683    * @param snapsPerLambda      number of snaps per state
684    * @param freeEnergyEstimates free energies
685    * @return W matrix.
686    */
687   private static double[][] mbarW(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
688     int nStates = freeEnergyEstimates.length;
689     double[] log_denom_n = new double[reducedPotentials[0].length];
690     for (int i = 0; i < reducedPotentials[0].length; i++) {
691       double[] temp = new double[nStates];
692       double maxTemp = Double.NEGATIVE_INFINITY;
693       for (int j = 0; j < nStates; j++) {
694         temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
695         if (temp[j] > maxTemp) {
696           maxTemp = temp[j];
697         }
698       }
699       // log_denom_n = calculates log(sumOverStates(N_k * exp(FE[j] - reducedPotentials[j][i])))
700       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
701     }
702     // logW = freeEnergyEstimates - reducedPotentials.T - log_denominator_n[:, newaxis]
703     // freeEnergyEstimates[i] = log(ck / ci) --> ratio of normalization constants
704     double[][] W = new double[nStates][reducedPotentials[0].length];
705     for (int i = 0; i < nStates; i++) {
706       for (int j = 0; j < reducedPotentials[0].length; j++) {
707         W[i][j] = exp(freeEnergyEstimates[i] - reducedPotentials[i][j] - log_denom_n[j]);
708       }
709     }
710     return W;
711   }
712 
713   private double[] mbarEnthalpyCalc(double[][] reducedPotentials, double[] mbarFEEstimates) {
714     double[] enthalpy = new double[mbarFEEstimates.length - 1];
715     double[] averagePotential = new double[mbarFEEstimates.length];
716     for (int i = 0; i < reducedPotentials.length; i++) {
717       averagePotential[i] = computeExpectations(eAllFlat[i])[i]; // average potential of ith lambda
718     }
719     for (int i = 0; i < enthalpy.length; i++) {
720       enthalpy[i] = averagePotential[i + 1] - averagePotential[i];
721     }
722     return enthalpy;
723   }
724 
725   private double[] mbarEntropyCalc(double[] mbarEnthalpy, double[] mbarFEEstimates) {
726     double[] entropy = new double[mbarFEEstimates.length - 1];
727     for (int i = 0; i < entropy.length; i++) {
728       entropy[i] = mbarEnthalpy[i] - mbarFEDifferenceEstimates[i]; // dG = dH - TdS || TdS = dH - dG
729     }
730     return entropy;
731   }
732 
733   /**
734    * Weight observable by exp(bias/RT) prior to computing expectation when set.
735    *
736    * @param biasAll             bias energies for each trajectory, evaluated at each lambda, for every snapshot.
737    * @param multiDataObservable true if bias data exist for every lambda state; false if only the sampled-state data (stored in index 0) are used.
738    */
739   public void setBiasData(double[][][] biasAll, boolean multiDataObservable) {
740     biasFlat = new double[biasAll.length][biasAll.length * biasAll[0][0].length];
741     if (multiDataObservable) { // Flatten data
742       int[] snapsT = new int[biasAll.length];
743       int[] nanCount = new int[biasAll.length];
744       for (int i = 0; i < biasAll.length; i++) {
745         ArrayList<Double> temp = new ArrayList<>();
746         double maxBias = Double.NEGATIVE_INFINITY;
747         for (int j = 0; j < biasAll.length; j++) {
748           int count = 0;
749           int countNaN = 0;
750           for (int k = 0; k < biasAll[j][i].length; k++) {
751             // Don't include NaN values
752             if (!Double.isNaN(biasAll[j][i][k])) {
753               temp.add(biasAll[j][i][k]);
754               if (biasAll[j][i][k] > maxBias) {
755                 maxBias = biasAll[j][i][k];
756               }
757               count++;
758             } else {
759               countNaN++;
760             }
761           }
762           snapsT[j] = count;
763           nanCount[j] = countNaN;
764         }
765         biasFlat[i] = temp.stream().mapToDouble(Double::doubleValue).toArray();
766         // Regularize bias for this lambda
767         for (int j = 0; j < biasFlat[i].length; j++) {
768           biasFlat[i][j] -= maxBias;
769         }
770       }
771     } else { // Put relevant data into the 0th index
772       int count = 0;
773       double maxBias = Double.NEGATIVE_INFINITY;
774       for (int i = 0; i < biasAll.length; i++) {
775         for (int j = 0; j < biasAll[0][0].length; j++) {
776           if (!Double.isNaN(biasAll[i][i][j])) {
777             biasFlat[0][count] = biasAll[i][i][j];
778             if (biasAll[i][i][j] > maxBias) {
779               maxBias = biasAll[i][i][j];
780             }
781             count++;
782           }
783         }
784       }
785       // Regularize bias for this lambda
786       for (int i = 0; i < biasFlat[0].length; i++) {
787         biasFlat[0][i] -= maxBias;
788       }
789     }
790   }
791 
792   public void setBiasData(double[][] biasData) {
793     this.biasFlat = biasData;
794     // Regularize bias for this lambda
795     for (int i = 0; i < biasFlat.length; i++) {
796       double maxBias = Double.NEGATIVE_INFINITY;
797       for (int j = 0; j < biasFlat[i].length; j++) {
798         if (biasFlat[i][j] > maxBias) {
799           maxBias = biasFlat[i][j];
800         }
801       }
802       for (int j = 0; j < biasFlat[i].length; j++) {
803         biasFlat[i][j] -= maxBias;
804       }
805     }
806   }
807 
808   public void setObservableData(double[][][] oAll, boolean multiDataObservable, boolean uncertainties) {
809     oAllFlat = new double[oAll.length][oAll.length * oAll[0][0].length];
810     if (multiDataObservable) { // Flatten data
811       int[] snapsT = new int[oAll.length];
812       int[] nanCount = new int[oAll.length];
813       for (int i = 0; i < oAll.length; i++) {
814         ArrayList<Double> temp = new ArrayList<>();
815         for (int j = 0; j < oAll.length; j++) {
816           int count = 0;
817           int countNaN = 0;
818           for (int k = 0; k < oAll[j][i].length; k++) {
819             // Don't include NaN values
820             if (!Double.isNaN(oAll[j][i][k])) {
821               temp.add(oAll[j][i][k]);
822               count++;
823             } else {
824               countNaN++;
825             }
826           }
827           snapsT[j] = count;
828           nanCount[j] = countNaN;
829         }
830         oAllFlat[i] = temp.stream().mapToDouble(Double::doubleValue).toArray();
831       }
832     } else { // Put relevant data into the 0th index
833       int count = 0;
834       for (int i = 0; i < oAll.length; i++) {
835         for (int j = 0; j < oAll[0][0].length; j++) {
836           if (!Double.isNaN(oAll[i][i][j])) {
837             oAllFlat[0][count] = oAll[i][i][j]; // Note [i][i] indexing
838             count++;
839           }
840         }
841       }
842     }
843     // OST Data
844     if (biasFlat != null) {
845       for (int i = 0; i < oAllFlat.length; i++) {
846         for (int j = 0; j < oAllFlat[i].length; j++) {
847           oAllFlat[i][j] *= exp(biasFlat[i][j] / rtValues[i]);
848         }
849       }
850     }
851     this.fillObservationExpectations(multiDataObservable, uncertainties);
852   }
853 
854   public void setObservableData(double[][] oAll, boolean uncertainties) {
855     oAllFlat = oAll;
856     // OST Data
857     if (biasFlat != null) {
858       if (oAllFlat.length != biasFlat.length || oAllFlat[0].length != biasFlat[0].length) {
859         logger.severe("Observable and bias data are not the same size. Exiting.");
860       }
861       for (int i = 0; i < oAllFlat.length; i++) {
862         for (int j = 0; j < oAllFlat[i].length; j++) {
863           oAllFlat[i][j] *= exp(biasFlat[i][j] / rtValues[i]);
864         }
865       }
866     }
867     this.fillObservationExpectations(oAllFlat.length != 1, uncertainties);
868   }
869 
870   public double getTIIntegral() {
871     DataSet dSet = new DoublesDataSet(Integrate1DNumeric.generateXPoints(0, 1, mbarObservableEnsembleAverages.length, false),
872         mbarObservableEnsembleAverages, false);
873     return Integrate1DNumeric.integrateData(dSet, Integrate1DNumeric.IntegrationSide.LEFT, Integrate1DNumeric.IntegrationType.TRAPEZOIDAL);
874   }
875 
876   /**
877    * Calculate expectation of samples from W matrix. Optionally calculate the uncertainty with
878    * augmented W matrix (incurs a significant computational cost ~10-20x MBAR calculation).
879    *
880    * Results are stored in mbarObservableEnsembleAverages and mbarObservableEnsembleAverageUncertainties.
881    */
882   private void fillObservationExpectations(boolean multiData, boolean uncertainties) {
883     if (multiData) {
884       mbarObservableEnsembleAverages = new double[oAllFlat.length];
885       mbarObservableEnsembleAverageUncertainties = new double[oAllFlat.length];
886       for (int i = 0; i < oAllFlat.length; i++) {
887         mbarObservableEnsembleAverages[i] = computeExpectations(oAllFlat[i])[i];
888         if (uncertainties) {
889           mbarObservableEnsembleAverageUncertainties[i] = computeExpectationStd(oAllFlat[i])[i];
890         }
891       }
892     } else {
893       mbarObservableEnsembleAverages = computeExpectations(oAllFlat[0]);
894       if (uncertainties) {
895         mbarObservableEnsembleAverageUncertainties = computeExpectationStd(oAllFlat[0]);
896       }
897     }
898   }
899 
900   /**
901    * Compute the MBAR expectation of a given observable (1xN) for each K. This observable
902    * could be something like x, x^2 (where x is equilibrium for a harmonic oscillator),
903    * or some other function of the configuration X like RMSD from a target conformation.
904    * Additionally, it could be evaluations of some potential at a specific lambda value.
905    * Each trajectory snap should have a corresponding observable value (or evaluation).
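       * As computed below, the expectation at state k is sum_n W[k][n] * samples[n].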
906    *
907    * @param samples observable values, one per snapshot (length matches the columns of the W matrix).
908    * @return expectation of the observable at each lambda state.
909    */
910   private double[] computeExpectations(double[] samples) {
911     double[][] W = mbarW(reducedPotentials, nSamples, mbarFEEstimates);
912     if (W[0].length != samples.length) {
913       logger.severe("Samples and W matrix are not the same length. Exiting.");
914     }
915     double[] expectation = new double[W.length];
916     for (int i = 0; i < W.length; i++) {
917       for (int j = 0; j < W[i].length; j++) {
918         expectation[i] += W[i][j] * samples[j];
919       }
920     }
921     return expectation;
922   }
923 
924   /**
925    * Eq. 13-15 in Shirts and Chodera (2008) for the MBAR observable uncertainty calculation.
926    * Originally implemented as seen in paper, but switched to logsumexp version because of
927    * Inf/NaN issues for large values captured in samples (i.e. potential energies).
928    *
929    * @return WnA matrix.
930    */
931   private double[][] mbarAugmentedW(double[] samples) {
932     int nStates = mbarFEEstimates.length;
933     // Enforce positivity of samples --> from pymbar
934     double minSample = stream(samples).min().getAsDouble() - 3 * java.lang.Math.ulp(1.0); // ulp to avoid zeros
935     if (minSample < 0) {
936       for (int i = 0; i < samples.length; i++) {
937         samples[i] -= minSample;
938       }
939     }
940     // Eq. 14 in Shirts and Chodera (2008)
941     double[][] logCATerms = new double[nStates][reducedPotentials[0].length];
942     double[] maxLogCATerm = new double[nStates];
943     Arrays.fill(maxLogCATerm, Double.NEGATIVE_INFINITY);
944     double[] logCA = new double[nStates];
945     double[] log_denom_n = new double[reducedPotentials[0].length];
946     for (int i = 0; i < reducedPotentials[0].length; i++) {
947       double[] temp = new double[nStates];
948       double maxTemp = Double.NEGATIVE_INFINITY;
949       for (int j = 0; j < nStates; j++) {
950         temp[j] = mbarFEEstimates[j] - reducedPotentials[j][i];
951         if (temp[j] > maxTemp) {
952           maxTemp = temp[j];
953         }
954       }
955       log_denom_n[i] = logSumExp(temp, nSamples, maxTemp);
956       for (int j = 0; j < nStates; j++) {
957         logCATerms[j][i] = log(samples[i]) - reducedPotentials[j][i] - log_denom_n[i];
958         if (logCATerms[j][i] > maxLogCATerm[j]) {
959           maxLogCATerm[j] = logCATerms[j][i];
960         }
961       }
962     }
963     for (int i = 0; i < nStates; i++) {
964       logCA[i] = logSumExp(logCATerms[i], maxLogCATerm[i]);
965     }
966     // Eq. 13 in Shirts and Chodera (2008)
967     double[][] WnA = new double[nStates][reducedPotentials[0].length];
968     double[][] Wna = new double[nStates][reducedPotentials[0].length]; // normal W matrix
969     for (int i = 0; i < nStates; i++) {
970       for (int j = 0; j < reducedPotentials[0].length; j++) {
971         WnA[i][j] = samples[j] * exp(-logCA[i] - reducedPotentials[i][j] - log_denom_n[j]);
972         Wna[i][j] = exp(-mbarFEEstimates[i] - reducedPotentials[i][j] - log_denom_n[j]);
973       }
974     }
975     if (minSample < 0) { // reset samples
976       for (int i = 0; i < samples.length; i++) {
977         samples[i] += minSample;
978       }
979     }
980     double[][] augmentedW = new double[nStates * 2][reducedPotentials[0].length];
981     for (int i = 0; i < augmentedW.length; i++) {
982       augmentedW[i] = i < nStates ? Wna[i] : WnA[(i - nStates)];
983     }
984     return augmentedW;
985   }
986 
987   /**
988    * Compute the MBAR uncertainty of an observable. The equations for this are not clear,
989    * but we append an augmented weight matrix (calculated by multiplying the observed values
990    * into the W matrix calculation) to the original W matrix. This is then used to calculate
991    * theta.
992    *
993    * @param samples observable values, one per snapshot.
994    * @return standard deviation of the observable expectation at each lambda state.
995    */
996   private double[] computeExpectationStd(double[] samples) {
997     int[] extendedSnaps = new int[nSamples.length * 2];
998     System.arraycopy(nSamples, 0, extendedSnaps, 0, nSamples.length);
999     RealMatrix theta = MatrixUtils.createRealMatrix(mbarTheta(extendedSnaps, mbarAugmentedW(samples)));
1000     double[] expectations = computeExpectations(samples);
1001     double[] diag = new double[expectations.length * 2];
1002     for (int i = 0; i < expectations.length; i++) {
1003       diag[i] = expectations[i];
1004       diag[i + expectations.length] = expectations[i];
1005     }
1006     RealMatrix diagMatrix = MatrixUtils.createRealDiagonalMatrix(diag);
1007     theta = diagMatrix.multiply(theta).multiply(diagMatrix);
1008     RealMatrix ul = theta.getSubMatrix(0, expectations.length - 1, 0, expectations.length - 1);
1009     RealMatrix ur = theta.getSubMatrix(0, expectations.length - 1, expectations.length, expectations.length * 2 - 1);
1010     RealMatrix ll = theta.getSubMatrix(expectations.length, expectations.length * 2 - 1, 0, expectations.length - 1);
1011     RealMatrix lr = theta.getSubMatrix(expectations.length, expectations.length * 2 - 1, expectations.length, expectations.length * 2 - 1);
1012     double[][] covA = ul.add(lr).subtract(ur).subtract(ll).getData(); // Lose precision here
1013     double[] sigma = new double[covA.length];
1014     for (int i = 0; i < covA.length; i++) {
1015       sigma[i] = sqrt(abs(covA[i][i]));
1016     }
1017     return sigma;
1018   }
1019 
1020   /**
1021    * MBAR uncertainty calculation.
1022    *
1023    * @return Uncertainties for the MBAR free energy estimates.
1024    */
1025   private static double[] mbarUncertaintyCalc(double[][] theta) {
1026     double[] uncertainties = new double[theta.length - 1];
1027     // del(dFij) = Theta[i,i] - 2 * Theta[i,j] + Theta[j,j]
1028     for (int i = 0; i < theta.length - 1; i++) {
1029       // TODO: Figure out why negative var is happening (likely due to theta calculation differing from pymbar's)
1030       double variance = theta[i][i] - 2 * theta[i][i + 1] + theta[i + 1][i + 1];
1031       if (variance < 0) {
1032         if (MultistateBennettAcceptanceRatio.VERBOSE) {
1033           logger.warning(" Negative variance detected in MBAR uncertainty calculation. " +
1034               "Multiplying by -1 to get real value. Check diff matrix to see which variances were negative. " +
1035               "They should be NaN.");
1036         }
1037         variance *= -1;
1038       }
1039       uncertainties[i] = sqrt(variance);
1040     }
1041     return uncertainties;
1042   }
1043 
1044   /**
1045    * MBAR total uncertainty calculation. Eq 12 in Shirts and Chodera (2008).
1046    *
1047    * @param theta matrix of covariances
1048    * @return Total uncertainty for the MBAR free energy estimates.
1049    */
1050   private static double mbarTotalUncertaintyCalc(double[][] theta) {
1051     int nStates = theta.length;
1052     return sqrt(abs(theta[0][0] - 2 * theta[0][nStates - 1] + theta[nStates - 1][nStates - 1]));
1053   }
1054 
1055   /**
1056    * Theta = W.T @ (I - W @ diag(snapsPerState) @ W.T)^-1 @ W.
1057    * <p>
1058    * Requires calculation and inversion of W matrix.
1059    * D4 from supp info of MBAR paper used instead to reduce storage and comp. complexity.
1060    *
1061    * @param reducedPotentials energies
1062    * @param snapsPerState     number of snaps per state
1063    * @param freeEnergies      free energies
1064    * @return Theta matrix.
1065    */
1066   private static double[][] mbarTheta(double[][] reducedPotentials, int[] snapsPerState, double[] freeEnergies) {
1067     return mbarTheta(snapsPerState, mbarW(reducedPotentials, snapsPerState, freeEnergies));
1068   }
1069 
1070   /**
1071    * Compute theta with a given W matrix.
1072    *
1073    * @param snapsPerState
1074    * @param W
1075    * @return
1076    */
1077   private static double[][] mbarTheta(int[] snapsPerState, double[][] W) {
1078     RealMatrix WMatrix = MatrixUtils.createRealMatrix(W).transpose();
1079     RealMatrix I = MatrixUtils.createRealIdentityMatrix(snapsPerState.length);
1080     RealMatrix NkMatrix = MatrixUtils.createRealDiagonalMatrix(stream(snapsPerState).mapToDouble(i -> i).toArray());
1081     SingularValueDecomposition svd = new SingularValueDecomposition(WMatrix);
1082     RealMatrix V = svd.getV();
1083     RealMatrix S = MatrixUtils.createRealDiagonalMatrix(svd.getSingularValues());
1084 
1085     // W.T @ (I - W @ diag(snapsPerState) @ W.T)^-1 @ W
1086     // = V @ S @ (I - S @ V.T @ diag(snapsPerState) @ V @ S)^-1 @ S @ V.T
1087     RealMatrix theta = S.multiply(V.transpose());
1088     theta = theta.multiply(NkMatrix).multiply(V).multiply(S);
1089     theta = I.subtract(theta);
1090     theta = MatrixUtils.inverse(theta); // pinv equivalent
1091     theta = V.multiply(S).multiply(theta).multiply(S).multiply(V.transpose());
1092 
1093     return theta.getData();
1094   }
1095 
1096   /**
1097    * MBAR uncertainty matrix calculation. diff[i][j] gives FE uncertainty of moving between
1098    * lambda i-> j.
1099    *
1100    * @param theta matrix of covariances
1101    * @return Diff matrix for the MBAR free energy estimates.
1102    */
1103   private static double[][] diffMatrixCalculation(double[][] theta) {
1104     double[][] diffMatrix = new double[theta.length][theta.length];
1105     for (int i = 0; i < diffMatrix.length; i++) {
1106       for (int j = 0; j < diffMatrix.length; j++) {
1107         diffMatrix[i][j] = sqrt(theta[i][i] - 2 * theta[i][j] + theta[j][j]);
1108       }
1109     }
1110     return diffMatrix;
1111   }
1112 
1113   //////// Methods for solving MBAR with self-consistent iteration, L-BFGS optimization, and Newton-Raphson. ////////
1114 
1115   /**
1116    * Self-consistent iteration to update free energies. Eq. 11 from Shirts and Chodera (2008).
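       * As implemented below, f_k = -log sum_n [ exp(-u_k(x_n)) / sum_j N_j * exp(f_j - u_j(x_n)) ],
       * followed by a shift so that the first free energy is pinned to zero.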
1117    *
1118    * @param reducedPotential    energies
1119    * @param snapsPerLambda      number of snaps per state
1120    * @param freeEnergyEstimates free energies
1121    * @return updated free energies
1122    */
1123   private static double[] mbarSelfConsistentUpdate(double[][] reducedPotential, int[] snapsPerLambda,
1124                                                    double[] freeEnergyEstimates) {
1125     int nStates = freeEnergyEstimates.length;
1126     double[] updatedF_k = new double[nStates];
1127     double[] log_denom_n = new double[reducedPotential[0].length];
1128     double[][] logDiff = new double[reducedPotential.length][reducedPotential[0].length];
1129     double[] maxLogDiff = new double[nStates];
1130     fill(maxLogDiff, Double.NEGATIVE_INFINITY);
1131     for (int i = 0; i < reducedPotential[0].length; i++) {
1132       double[] temp = new double[nStates];
1133       double maxTemp = Double.NEGATIVE_INFINITY;
1134       for (int j = 0; j < nStates; j++) {
1135         temp[j] = freeEnergyEstimates[j] - reducedPotential[j][i];
1136         if (temp[j] > maxTemp) {
1137           maxTemp = temp[j];
1138         }
1139       }
1140       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
1141       for (int j = 0; j < nStates; j++) {
1142         logDiff[j][i] = -log_denom_n[i] - reducedPotential[j][i];
1143         if (logDiff[j][i] > maxLogDiff[j]) {
1144           maxLogDiff[j] = logDiff[j][i];
1145         }
1146       }
1147     }
1148 
1149     for (int i = 0; i < nStates; i++) {
1150       updatedF_k[i] = -1.0 * logSumExp(logDiff[i], maxLogDiff[i]);
1151     }
1152 
1153     // Constrain f1=0 over the course of iterations to prevent uncontrolled growth in magnitude
1154     double norm = updatedF_k[0];
1155     updatedF_k[0] = 0.0;
1156     for (int i = 1; i < nStates; i++) {
1157       updatedF_k[i] = updatedF_k[i] - norm;
1158     }
1159 
1160     return updatedF_k;
1161   }
1162 
1163   /**
1164    * Newton-Raphson step for MBAR optimization. Falls back to the steepest descent if hessian is singular.
1165    * <p>
1166    * The matrix can come back from being singular after several iterations, so it isn't worth moving to L-BFGS.
1167    *
1168    * @param n        current free energies.
1169    * @param grad     gradient of the objective function.
1170    * @param hessian  hessian of the objective function.
1171    * @param stepSize step size for the Newton-Raphson step.
1172    * @return updated free energies.
1173    */
1174   private static double[] newtonStep(double[] n, double[] grad, double[][] hessian, double stepSize) {
1175     double[] nPlusOne = new double[n.length];
1176     double[] step;
1177     try {
1178       RealMatrix hessianInverse = MatrixUtils.inverse(MatrixUtils.createRealMatrix(hessian));
1179       step = hessianInverse.preMultiply(grad);
1180     } catch (IllegalArgumentException e) {
1181       if (MultistateBennettAcceptanceRatio.VERBOSE) {
1182         logger.info(" Singular matrix detected in MBAR Newton-Raphson step. Performing steepest descent step.");
1183       }
1184       step = grad;
1185       stepSize = 1e-5;
1186     }
1187     // Zero out the first term of the step
1188     double temp = step[0];
1189     step[0] = 0.0;
1190     for (int i = 1; i < step.length; i++) {
1191       step[i] -= temp;
1192     }
1193     for (int i = 0; i < n.length; i++) {
1194       nPlusOne[i] = n[i] - step[i] * stepSize;
1195     }
1196     return nPlusOne;
1197   }
1198 
1199   /**
1200    * Newton-Raphson optimization for MBAR.
1201    *
1202    * @param freeEnergyEstimates free energies.
1203    * @param reducedPotentials   energies.
1204    * @param snapsPerLambda      number of snaps per state.
1205    * @param tolerance           convergence tolerance.
1206    * @return updated free energies.
1207    */
1208   private static double[] newton(double[] freeEnergyEstimates, double[][] reducedPotentials,
1209                                  int[] snapsPerLambda, double tolerance) {
1210     double[] grad = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1211     double[][] hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1212     double[] f_kPlusOne = newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
1213     int iter = 1;
1214     while (iter < 15) { // Quadratic convergence is expected, SCI will run anyway
1215       freeEnergyEstimates = f_kPlusOne;
1216       grad = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1217       hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1218       // Catches singular matrices and performs steepest descent
1219       f_kPlusOne = newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
1220       double eps = 0.0;
1221       for (int i = 0; i < freeEnergyEstimates.length; i++) {
1222         eps += abs(grad[i]);
1223       }
1224       if (eps < tolerance) {
1225         break;
1226       }
1227       iter++;
1228     }
1229     if (MultistateBennettAcceptanceRatio.VERBOSE) {
1230       logger.info(" Newton iterations (max 15): " + iter);
1231     }
1232 
1233     return f_kPlusOne;
1234   }
1235 
1236   /**
1237    * Calculates the log of the sum of the exponential of the given values.
1238    * <p>
1239    * The max value is subtracted from each value in the array before exponentiation to prevent overflow.
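       * For example, logSumExp({log(1), log(2)}, log(2)) = log(2) + log(0.5 + 1.0) = log(3).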
1240    *
1241    * @param values The values to exponentiate and sum.
1242    * @param max    The max value is subtracted from each value in the array prior to exponentiation.
1243    * @return the sum
1244    */
1245   private static double logSumExp(double[] values, double max) {
1246     int[] b = fill(new int[values.length], 1);
1247     return logSumExp(values, b, max);
1248   }
1249 
1250   /**
1251    * Calculates the log of the sum of the exponential of the given values.
1252    * <p>
1253    * The max value is subtracted from each value in the array before exponentiation to prevent overflow.
1254    * MBAR calculation is easiest to do in log terms, only exponentiating when required. Prevents zeros
1255    * in the denominator.
1256    *
1257    * @param values The values to exponentiate and sum.
1258    * @param max    The max value is subtracted from each value in the array prior to exponentiation.
1259    * @param b      Weights for each value in the array.
1260    * @return the sum
1261    */
1262   private static double logSumExp(double[] values, int[] b, double max) {
1263     // ChatGPT mostly wrote this and I tweaked it to match more closely with scipy's log-sum-exp implementation
1264     // The maximum value is passed in as a parameter (not recomputed here).
1265     assert values.length == b.length : "values and b must be the same length";
1266 
1267     // Subtract the maximum value from each value, exponentiate the result, and add up these values.
1268     double sum = 0.0;
1269     for (int i = 0; i < values.length; i++) {
1270       sum += b[i] * exp(values[i] - max);
1271     }
1272 
1273     // Take the natural logarithm of the sum and add the maximum value back in.
1274     return max + log(sum);
1275   }
1276 
1277   /**
1278    * Converts a vector in place into a probability distribution (numerically stable softmax).
1279    *
1280    * @param values the vector to normalize in place.
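        * <p>
        * Worked example (illustrative values): {0, ln 3} becomes {0.25, 0.75}.
        * <pre>{@code
        * double[] v = {0.0, Math.log(3.0)};
        * softMax(v); // v is now {0.25, 0.75}
        * }</pre>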
1281    */
1282   private static void softMax(double[] values) {
1283     double max = stream(values).max().getAsDouble();
1284     double sum = 0.0;
1285     for (int i = 0; i < values.length; i++) {
1286       values[i] = exp(values[i] - max);
1287       sum += values[i];
1288     }
1289     for (int i = 0; i < values.length; i++) {
1290       values[i] /= sum;
1291     }
1292   }
1293 
1294   /**
1295    * TODO: Log the MBAR optimization progress.
1296    *
1297    * @return a listener that currently accepts every L-BFGS update without logging anything.
1298    */
1299   private OptimizationListener getOptimizationListener() {
1300     return new OptimizationListener() {
1301       @Override
1302       public boolean optimizationUpdate(int iter, int nBFGS, int nFunctionEvals, double gradientRMS,
1303                                         double coordinateRMS, double f, double df, double angle,
1304                                         LineSearch.LineSearchResult info) {
1305         return true;
1306       }
1307     };
1308   }
1309 
1310   /**
1311    * MBAR objective function evaluation at a given free energy estimate for L-BFGS optimization.
1312    *
1313    * @param x Input parameters.
1314    * @return The objective function value at the given parameters.
1315    */
1316   @Override
1317   public double energy(double[] x) {
1318     // Pin the first free energy at zero; free energies are defined only up to an additive constant.
1319     double tempO = x[0];
1320     x[0] = 0.0;
1321     for (int i = 1; i < x.length; i++) {
1322       x[i] -= tempO;
1323     }
1324     return mbarObjectiveFunction(reducedPotentials, nSamples, x);
1325   }
1326 
1327   /**
1328    * MBAR objective function evaluation and gradient at a given free energy estimate for L-BFGS optimization.
1329    *
1330    * @param x Input parameters.
1331    * @param g The gradient with respect to each parameter.
1332    * @return The objective function value at the given parameters.
1333    */
1334   @Override
1335   public double energyAndGradient(double[] x, double[] g) {
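         // Pin the first free energy at zero, as in energy(); free energies are defined only up to an additive constant.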
1336     double tempO = x[0];
1337     x[0] = 0.0;
1338     for (int i = 1; i < x.length; i++) {
1339       x[i] -= tempO;
1340     }
1341     double[] tempG = mbarGradient(reducedPotentials, nSamples, x);
1342     arraycopy(tempG, 0, g, 0, g.length);
1343     return mbarObjectiveFunction(reducedPotentials, nSamples, x);
1344   }
1345 
1346   @Override
1347   public double[] getCoordinates(double[] parameters) {
1348     return new double[0];
1349   }
1350 
1351   @Override
1352   public int getNumberOfVariables() {
1353     return 0;
1354   }
1355 
1356   @Override
1357   public double[] getScaling() {
1358     return null;
1359   }
1360 
1361   @Override
1362   public void setScaling(double[] scaling) {
1363   }
1364 
1365   @Override
1366   public double getTotalEnergy() {
1367     return 0;
1368   }
1369 
1370   //////// Getters and setters ////////
1371 
1372   public BennettAcceptanceRatio getBAR() {
1373     return new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures);
1374   }
1375 
1376   @Override
1377   public MultistateBennettAcceptanceRatio copyEstimator() {
1378     return new MultistateBennettAcceptanceRatio(lamValues, eAll, temperatures, tolerance, seedType);
1379   }
1380 
1381   @Override
1382   public double[] getFreeEnergyDifferences() {
1383     return mbarFEDifferenceEstimates;
1384   }
1385 
1386   public double[] getMBARFreeEnergies() {
1387     return mbarFEEstimates;
1388   }
1389 
1390   public double[][] getReducedPotentials() {
1391     return reducedPotentials;
1392   }
1393 
1394   public int[] getSnaps() {
1395     return nSamples;
1396   }
1397 
1398   @Override
1399   public double[] getFEDifferenceUncertainties() {
1400     return mbarUncertainties;
1401   }
1402 
1403   public double[] getObservationEnsembleAverages() {
1404     return mbarObservableEnsembleAverages;
1405   }
1406 
1407   public double[] getObservationEnsembleUncertainties() {
1408     return mbarObservableEnsembleAverageUncertainties;
1409   }
1410 
1411   public double[][] getUncertaintyMatrix() {
1412     return uncertaintyMatrix;
1413   }
1414 
1415   @Override
1416   public double getTotalFreeEnergyDifference() {
1417     return totalMBAREstimate;
1418   }
1419 
1420   @Override
1421   public double getTotalFEDifferenceUncertainty() {
1422     return totalMBARUncertainty;
1423   }
1424 
1425   @Override
1426   public int getNumberOfBins() {
1427     return nFreeEnergyDiffs;
1428   }
1429 
1430   @Override
1431   public double[] getEnthalpyDifferences() {
1432     return mbarEnthalpy;
1433   }
1434 
1435   /**
1436    * {@inheritDoc}
1437    */
1438   @Override
1439   public double getTotalEnthalpyDifference() {
1440     return getTotalEnthalpyDifference(mbarEnthalpy);
1441   }
1442 
1443   public double[] getBinEntropies() {
1444     return mbarEntropy;
1445   }
1446 
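       /**
        * Write a window's energies to a simple text file for comparison with external codes (e.g. pymbar).
        * The first line holds the number of snapshots and the temperature; each following row holds the
        * snapshot index followed by that snapshot's energy evaluated at every window.
        *
        * @param energies    energies[k][n] is the energy of snapshot n evaluated at window k.
        * @param file        destination file.
        * @param temperature temperature written on the header line.
        */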
1447   public static void writeFile(double[][] energies, File file, double temperature) {
1448     try (FileWriter fw = new FileWriter(file);
1449          BufferedWriter bw = new BufferedWriter(fw)) {
1450       // Write the number of snapshots and the temperature on the first line
1451       bw.write(energies[0].length + " " + temperature);
1452       bw.newLine();
1453 
1454       // Write the energies
1455       StringBuilder sb = new StringBuilder();
1456       for (int i = 0; i < energies[0].length; i++) {
1457         sb.append("     ").append(i).append(" "); // Write the index of the snapshot
1458         for (int j = 0; j < energies.length; j++) {
1459           sb.append("    ").append(energies[j][i]).append(" ");
1460         }
1461         sb.append("\n");
1462         bw.write(sb.toString());
1463         sb = new StringBuilder(); // Reset the builder so rows are not written more than once
1464       }
1465     } catch (IOException e) {
1466       e.printStackTrace();
1467     }
1468   }
1469 
1470   /**
1471    * Test all MBAR methods individually with a simple Harmonic Oscillator test case with an
1472    * excess of samples. "PASS" indicates that the test passed, while "FAIL" followed by the
1473    * method name indicates that the test failed.
1474    * <p>
1475    * Last updated - 06/11/2024
1476    *
1477    * @return array of test results
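        * <p>
        * Intended usage (a quick self-check, not a substitute for the unit tests):
        * <pre>{@code
        * for (String result : MultistateBennettAcceptanceRatio.testMBARMethods()) {
        *   System.out.println(result);
        * }
        * }</pre>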
1478    */
1479   public static String[] testMBARMethods() {
1480     // Set up highly converged test case
1481     double[] O_k = {1, 2, 3, 4};
1482     double[] K_k = {.5, 1.0, 1.5, 2};
1483     int[] N_k = {100000, 100000, 100000, 100000};
1484     double beta = 1.0;
1485     HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);
1486     String setting = "u_kln";
1487     Object[] sampleResult = testCase.sample(N_k, setting, (long) 0);
1488     double[][][] u_kln = (double[][][]) sampleResult[1];
1489     double[] temps = {1 / Constants.R};
1490     MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0E-7, MultistateBennettAcceptanceRatio.SeedType.ZEROS);
1491     MultistateBennettAcceptanceRatio mbarHigherTol = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0, MultistateBennettAcceptanceRatio.SeedType.ZEROS);
1492     String[] results = new String[7];
1493     // Get required information for all methods
1494     double[][] reducedPotentials = mbar.getReducedPotentials();
1495     double[] freeEnergyEstimates = mbar.getMBARFreeEnergies();
1496     double[] highTolFEEstimates = mbarHigherTol.getMBARFreeEnergies();
1497     double[] zeros = new double[freeEnergyEstimates.length];
1498     int[] snapsPerLambda = mbar.getSnaps();
1499 
1500     // getMBARFreeEnergies()
1501     double[] expectedFEEstimates = new double[]{0.0, 0.3474485596619945, 0.5460865684340613, 0.6866650788765148};
1502     boolean pass = normDiff(freeEnergyEstimates, expectedFEEstimates) < 1e-5;
1503     expectedFEEstimates = new double[]{0.0, 0.35798124225733474, 0.44721370511807645, 0.477203739646745};
1504     pass = normDiff(highTolFEEstimates, expectedFEEstimates) < 1e-5 && pass;
1505     results[0] = pass ? "PASS" : "FAIL getMBARFreeEnergies()";
1506 
1507     // mbarObjectiveFunction()
1508     double objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1509     pass = !(abs(objectiveFunction - 4786294.2692739945) > 1e-5);
1510     objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1511     pass = !(abs(objectiveFunction - 4787001.700838844) > 1e-5) && pass;
1512     objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, zeros);
1513     pass = !(abs(objectiveFunction - 4792767.352152844) > 1e-5) && pass;
1514     results[1] = pass ? "PASS" : "FAIL mbarObjectiveFunction()";
1515 
1516     // mbarGradient()
1517     double[] gradient = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1518     double[] expected = new double[]{6.067113034191607E-4, -8.777718552011038E-4, 8.210768953631487E-4, -5.500246369471995E-4};
1519     pass = !(normDiff(gradient, expected) > 4e-5);
1520     gradient = mbarGradient(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1521     expected = new double[]{1969.705314577408, 5108.841258429764, -1072.9526887468976, -6005.593884267446};
1522     pass = !(normDiff(gradient, expected) > 4e-5) && pass;
1523     gradient = mbarGradient(reducedPotentials, snapsPerLambda, zeros);
1524     expected = new double[]{22797.82037585665, -3273.72282675803, -8859.999065013779, -10664.098484078011};
1525     pass = !(normDiff(gradient, expected) > 4e-5) && pass;
1526     results[2] = pass ? "PASS" : "FAIL mbarGradient()";
1527 
1528     pass = true;
1529     // mbarHessian()
1530     double[][] hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1531     double[][] expected2d = new double[][]{{47600.586808418964, -29977.008359691405, -12870.425573135915, -4753.1528755909385},
1532         {-29977.008359691405, 63767.745823769576, -24597.198354108747, -9193.539109971487},
1533         {-12870.425573135915, -24597.198354108747, 64584.87112481013, -27117.247197561417},
1534         {-4753.1528755909385, -9193.539109971487, -27117.247197561417, 41063.93918312612}};
1535     pass = !(normDiff(hessian, expected2d) > 16e-5);
1536     hessian = mbarHessian(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1537     expected2d = new double[][]{{49168.30161780381, -31256.519016487477, -12983.708230229113, -4928.074371082683},
1538         {-31256.519016487477, 66075.94621325849, -25339.462656640117, -9479.964540130917},
1539         {-12983.708230229113, -25339.462656640117, 64308.30940252403, -25985.13851565483},
1540         {-4928.074371082683, -9479.964540130917, -25985.13851565483, 40393.1774268678}};
1541     pass = !(normDiff(hessian, expected2d) > 16e-5) && pass;
1542     hessian = mbarHessian(reducedPotentials, snapsPerLambda, zeros);
1543     expected2d = new double[][]{{56125.271437145464, -33495.87894376072, -15738.011263498352, -6891.381229885624},
1544         {-33495.87894376072, 64613.515110188295, -21970.091845920833, -9147.544320511564},
1545         {-15738.011263498352, -21970.091845920833, 61407.66256511316, -23699.55945569241},
1546         {-6891.381229885624, -9147.544320511564, -23699.55945569241, 39738.48500608951}};
1547     pass = !(normDiff(hessian, expected2d) > 16e-5) && pass;
1548     results[3] = pass ? "PASS" : "FAIL mbarHessian()";
1549 
1550     pass = true;
1551     // mbarTheta() --> Checked by diffMatrix
1552     double[][] theta = mbarTheta(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1553     double[][] diff = diffMatrixCalculation(theta);
1554     expected2d = new double[][]{{0.0, 0.001953125, 0.003400485419234404, 0.004858337095247168},
1555         {0.0020716018980074633, 0.0, 0.002042627017905458, 0.004055968683065466},
1556         {0.003435363105339426, 0.002042627017905458, 0.0, 0.002560568476977909},
1557         {0.0048828125, 0.004055968683065466, 0.0025135815773894045, 0.0}};
1558     pass = !(normDiff(diff, expected2d) > 16e-5);
1559     results[4] = pass ? "PASS" : "FAIL mbarTheta() or diffMatrixCalculation()";
1560 
1561     pass = true;
1562     // selfConsistentUpdate()
1563     double[] updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1564     expected = new double[]{0.0, 0.3474485745068261, 0.5460865662904055, 0.6866650904438742};
1565     pass = !(normDiff(updatedF_k, expected) > 1e-5);
1566     updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1567     expected = new double[]{0.0, 0.327660608017009, 0.4775067849198251, 0.5586442310038073};
1568     pass = !(normDiff(updatedF_k, expected) > 1e-5) && pass;
1569     updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, zeros);
1570     expected = new double[]{0.0, 0.23865416150488983, 0.29814247007871764, 0.31813582643116334};
1571     pass = !(normDiff(updatedF_k, expected) > 1e-5) && pass;
1572     results[5] = pass ? "PASS" : "FAIL mbarSelfConsistentUpdate()";
1573 
1574     pass = true;
1575     // newton()
1576     updatedF_k = newton(highTolFEEstimates, reducedPotentials, snapsPerLambda, 1e-7);
1577     pass = !(normDiff(updatedF_k, freeEnergyEstimates) > 1e-5);
1578     updatedF_k = newton(zeros, reducedPotentials, snapsPerLambda, 1e-7);
1579     pass = !(normDiff(updatedF_k, freeEnergyEstimates) > 1e-5) && pass;
1580     results[6] = pass ? "PASS" : "FAIL newton()";
1581 
1582     return results;
1583   }
1584 
1585   private static double normDiff(double[] a, double[] b) {
1586     double sum = 0.0;
1587     for (int i = 0; i < a.length; i++) {
1588       sum += abs(a[i] - b[i]);
1589     }
1590     return sum;
1591   }
1592 
1593   private static double normDiff(double[][] a, double[][] b) {
1594     double sum = 0.0;
1595     for (int i = 0; i < a.length; i++) {
1596       for (int j = 0; j < a[i].length; j++) {
1597         sum += abs(a[i][j] - b[i][j]);
1598       }
1599     }
1600     return sum;
1601   }
1602 
1603   /**
1604    * Example MBAR usage and comparison with analytical answers for harmonic oscillators.
1605    *
1606    * @param args command line arguments (ignored).
1607    */
1608   public static void main(String[] args) {
1609     // Generate sample data
1610     double[] equilPositions = {1, 2, 3, 4}; // Equilibrium positions
1611     double[] springConstants = {.5, 1.0, 1.5, 2}; // Spring constants
1612     int[] samples = {100000, 100000, 100000, 100000}; // Samples per state
1613     double beta = 1.0; // 1 / (kB * T) equivalent
1614     HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(equilPositions, springConstants, beta);
1615     String setting = "u_kln";
1616     System.out.print("Generating sample data... ");
1617     Object[] sampleResult = testCase.sample(samples, setting, (long) 0); // Set seed to fixed value for reproducibility
1618     System.out.println("done. \n");
1619     double[] x_n = (double[]) sampleResult[0];
1620     double[][][] u_kln = (double[][][]) sampleResult[1];
1621     double[] temps = {1 / Constants.R}; // To be passed into MBAR to cancel out beta within calculation
1622 
1623     // Write file for comparison with pymbar
1624     // Output to forcefieldx/testing/mbar/data/harmonic_oscillators/mbarFiles/energies_{i}.mbar
1625     // Get absolute path to root of project
1626     String rootPath = new File("").getAbsolutePath();
1627     File outputPath = new File(rootPath + "/testing/mbar/data/harmonic_oscillators/mbarFiles");
1628     if (!outputPath.exists() && !outputPath.mkdirs()) {
1629       throw new RuntimeException("Failed to create directory: " + outputPath);
1630     }
1631 
1632     double[] temperatures = new double[equilPositions.length];
1633     Arrays.fill(temperatures, temps[0]);
1634     for (int i = 0; i < u_kln.length; i++) {
1635       File file = new File(outputPath, "energies_" + i + ".mbar");
1636       writeFile(u_kln[i], file, temperatures[i]);
1637     }
1638 
1639     // Create an instance of MultistateBennettAcceptanceRatio
1640     System.out.print("Creating MBAR instance and .estimateDG(false) with standard tolerance & zeros seeding...");
1641     //MultistateBennettAcceptanceRatio.VERBOSE = true; // Log Newton/SCI iters and other relevant information
1642     MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(equilPositions, u_kln, temps, 1e-7, SeedType.ZEROS);
1643     System.out.println("done! \n\n");
1644     double[] mbarFEEstimates = Arrays.copyOf(mbar.mbarFEEstimates, mbar.mbarFEEstimates.length);
1645     double[] mbarEnthalpyDiff = Arrays.copyOf(mbar.mbarEnthalpy, mbar.mbarEnthalpy.length);
1646     double[] mbarEntropyDiff = Arrays.copyOf(mbar.mbarEntropy, mbar.mbarEntropy.length);
1647     double[] mbarUncertainties = Arrays.copyOf(mbar.mbarUncertainties, mbar.mbarUncertainties.length);
1648     double[][] mbarDiffMatrix = Arrays.copyOf(mbar.uncertaintyMatrix, mbar.uncertaintyMatrix.length);
1649 
1650     // Analytical free energies and entropies
1651     double[] analyticalFreeEnergies = testCase.analyticalFreeEnergies();
1652     double[] error = new double[analyticalFreeEnergies.length];
1653     for (int i = 0; i < error.length; i++) {
1654       error[i] = analyticalFreeEnergies[i] - mbarFEEstimates[i];
1655     }
1656     double[] temp = testCase.analyticalEntropies(0);
1657     double[] analyticEntropyDiff = new double[temp.length - 1];
1658     double[] errorEntropy = new double[temp.length - 1];
1659     for (int i = 0; i < analyticEntropyDiff.length; i++) {
1660       analyticEntropyDiff[i] = temp[i + 1] - temp[i];
1661       errorEntropy[i] = analyticEntropyDiff[i] - mbarEntropyDiff[i];
1662     }
1663 
1664     // Compare the calculated free energy differences with the analytical ones
1665     System.out.println("STANDARD THERMODYNAMIC CALCULATIONS: \n");
1666     System.out.println("Analytical Free Energies: " + Arrays.toString(analyticalFreeEnergies));
1667     System.out.println("MBAR Free Energies:       " + Arrays.toString(mbarFEEstimates));
1668     System.out.println("Free Energy Error:        " + Arrays.toString(error));
1669     System.out.println();
1670     System.out.println("MBAR dG:                  " + Arrays.toString(mbar.mbarFEDifferenceEstimates));
1671     System.out.println("MBAR Uncertainties:       " + Arrays.toString(mbarUncertainties));
1672     System.out.println("MBAR Enthalpy Changes:    " + Arrays.toString(mbarEnthalpyDiff));
1673     System.out.println();
1674     System.out.println("MBAR Entropy Changes:     " + Arrays.toString(mbarEntropyDiff));
1675     System.out.println("Analytic Entropy Changes: " + Arrays.toString(analyticEntropyDiff));
1676     System.out.println("Entropy Error:            " + Arrays.toString(errorEntropy));
1677     System.out.println();
1678     System.out.println("Uncertainty Diff Matrix: ");
1679     for (double[] matrix : mbarDiffMatrix) {
1680       System.out.println(Arrays.toString(matrix));
1681     }
1682     System.out.println("\n\n");
1683 
1684     // Observables
1685     System.out.println("MBAR DERIVED OBSERVABLES: \n");
1686     mbar.setObservableData(u_kln, true, true);
1687     double[] mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages,
1688         mbar.mbarObservableEnsembleAverages.length);
1689     double[] mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties,
1690         mbar.mbarObservableEnsembleAverageUncertainties.length);
1691     System.out.println("Multi-Data Observable Example u_kln:");
1692     System.out.println("MBAR Observable Ensemble Averages (Potential):              " + Arrays.toString(mbarObservableEnsembleAverages));
1693     System.out.println("Analytical Observable Ensemble Averages (Potential):        " + Arrays.toString(testCase.analyticalObservable("potential energy")));
1694     System.out.println("MBAR Observable Ensemble Average Uncertainties (Potential): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
1695     System.out.println();
1696 
1697     // setObservableData reads single-data observables from xAll[0]; copy x_n into each window
1698     double[][][] xAll = new double[equilPositions.length][equilPositions.length][x_n.length];
1699     for (int i = 0; i < xAll[0].length; i++) {
1700       for (int j = 0; j < xAll[0][0].length; j++) {
1701         // Copy data multiple times into same window
1702         xAll[0][i][j] = x_n[j];
1703       }
1704     }
1705     mbar.setObservableData(xAll, false, true);
1706     mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages,
1707         mbar.mbarObservableEnsembleAverages.length);
1708     mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties,
1709         mbar.mbarObservableEnsembleAverageUncertainties.length);
1710     System.out.println("Single-Data Observable Example x_n:");
1711     System.out.println("MBAR Observable Ensemble Averages (Position):              " + Arrays.toString(mbarObservableEnsembleAverages));
1712     System.out.println("Analytical Observable Ensemble Averages (Position):        " + Arrays.toString(testCase.analyticalMeans()));
1713     System.out.println("MBAR Observable Ensemble Average Uncertainties (Position): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
1714     System.out.println();
1715   }
1716 
1717   /**
1718    * Harmonic oscillators test case that generates sample data for testing the MBAR implementation.
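        * <p>
        * Typical usage, mirroring the example in {@code main(String[])} (constants here are illustrative):
        * <pre>{@code
        * HarmonicOscillatorsTestCase testCase =
        *     new HarmonicOscillatorsTestCase(new double[]{1, 2, 3, 4}, new double[]{0.5, 1.0, 1.5, 2.0}, 1.0);
        * Object[] sampleResult = testCase.sample(new int[]{100000, 100000, 100000, 100000}, "u_kln", 0L);
        * double[][][] u_kln = (double[][][]) sampleResult[1];
        * }</pre>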
1719    */
1720   public static class HarmonicOscillatorsTestCase {
1721     private final double beta;
1722     private final double[] equilPositions;
1723     private final int n_states;
1724     private final double[] springConstants;
1725 
1726     public HarmonicOscillatorsTestCase(double[] O_k, double[] K_k, double beta) {
1727       this.beta = beta;
1728       this.equilPositions = O_k;
1729       this.n_states = O_k.length;
1730       this.springConstants = K_k;
1731 
1732       if (this.springConstants.length != this.n_states) {
1733         throw new IllegalArgumentException("Lengths of K_k and O_k should be equal");
1734       }
1735     }
1736 
1737     public double[] analyticalMeans() {
1738       return equilPositions;
1739     }
1740 
1741     public double[] analyticalStandardDeviations() {
1742       double[] deviations = new double[n_states];
1743       for (int i = 0; i < n_states; i++) {
1744         deviations[i] = Math.sqrt(1.0 / (beta * springConstants[i]));
1745       }
1746       return deviations;
1747     }
1748 
1749     public double[] analyticalObservable(String observable) {
1750       double[] result = new double[n_states];
1751 
1752       switch (observable) {
1753         case "position" -> {
1754           return analyticalMeans();
1755         }
1756         case "potential energy" -> {
1757           for (int i = 0; i < n_states; i++) {
1758             result[i] = 0.5 / beta;
1759           }
1760         }
1761         case "position^2" -> {
1762           for (int i = 0; i < n_states; i++) {
1763             result[i] = 1.0 / (beta * springConstants[i]) + Math.pow(equilPositions[i], 2);
1764           }
1765         }
1766         case "RMS displacement" -> {
1767           return analyticalStandardDeviations();
1768         }
1769       }
1770 
1771       return result;
1772     }
1773 
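         /**
          * Reduced analytical free energies, -0.5 * ln(2 * pi / (beta * K_i)), shifted so the first state is zero.
          */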
1774     public double[] analyticalFreeEnergies() {
1775       int subtractComponentIndex = 0;
1776       double[] fe = new double[n_states];
1777       double subtract = 0.0;
1778       for (int i = 0; i < n_states; i++) {
1779         fe[i] = -0.5 * Math.log(2 * Math.PI / (beta * springConstants[i]));
1780         if (i == 0) {
1781           subtract = fe[subtractComponentIndex];
1782         }
1783         fe[i] -= subtract;
1784       }
1785       return fe;
1786     }
1787 
1788     public double[] analyticalEntropies(int subtractComponent) {
1789       double[] entropies = new double[n_states];
1790       double[] potentialEnergy = analyticalObservable("potential energy");
1791       double[] freeEnergies = analyticalFreeEnergies();
1792 
1793       for (int i = 0; i < n_states; i++) {
1794         entropies[i] = potentialEnergy[i] - freeEnergies[i];
1795       }
1796 
1797       return entropies;
1798     }
1799 
1800     /**
1801      * Draw Gaussian samples from each harmonic oscillator using its analytical standard deviation.
1802      *
1803      * @param N_k  number of snaps to draw per state
1804      * @param mode "u_kn" to return the K x N_tot reduced potential matrix u_kn; "u_kln" to return the K x K x N_max array u_kln
          * @param seed random number generator seed for reproducible sampling
1805      * @return an Object[] of {x_n, u_kn, N_k, s_n} for "u_kn", or {x_n, u_kln, N_k, s_n, u_kn} for "u_kln", where u_kln[k][l][n] is the reduced potential of sample n from state k evaluated at state l
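          * <p>
          * Example of unpacking the result in "u_kln" mode (testCase and N_k as in the surrounding example):
          * <pre>{@code
          * Object[] sampleResult = testCase.sample(N_k, "u_kln", 0L);
          * double[] x_n = (double[]) sampleResult[0];
          * double[][][] u_kln = (double[][][]) sampleResult[1];
          * }</pre>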
1806      */
1807     public Object[] sample(int[] N_k, String mode, Long seed) {
1808       Random random = new Random(seed);
1809 
1810       int N_max = 0;
1811       for (int N : N_k) {
1812         if (N > N_max) {
1813           N_max = N;
1814         }
1815       }
1816 
1817       int N_tot = 0;
1818       for (int N : N_k) {
1819         N_tot += N;
1820       }
1821 
1822       double[][] x_kn = new double[n_states][N_max];
1823       double[][] u_kn = new double[n_states][N_tot];
1824       double[][][] u_kln = new double[n_states][n_states][N_max];
1825       double[] x_n = new double[N_tot];
1826       int[] s_n = new int[N_tot];
1827 
1828       // Sample harmonic oscillators
1829       int index = 0;
1830       for (int k = 0; k < n_states; k++) {
1831         double x0 = equilPositions[k];
1832         double sigma = Math.sqrt(1.0 / (beta * springConstants[k]));
1833 
1834         // Number of snaps
1835         for (int n = 0; n < N_k[k]; n++) {
1836           double x = x0 + random.nextGaussian() * sigma;
1837           x_kn[k][n] = x;
1838           x_n[index] = x;
1839           s_n[index] = k;
1840           // Potential energy evaluations
1841           for (int l = 0; l < n_states; l++) {
1842             double u = beta * 0.5 * springConstants[l] * Math.pow(x - equilPositions[l], 2.0);
1843             u_kln[k][l][n] = u;
1844             u_kn[l][index] = u;
1845           }
1846           index++;
1847         }
1848         // Set the rest of the array to NaN
1849         for (int n = N_k[k]; n < N_max; n++) {
1850           for (int l = 0; l < n_states; l++) {
1851             u_kln[k][l][n] = Double.NaN;
1852           }
1853         }
1854       }
1855 
1856       // Return the sampled data in the format requested by mode
1857       if ("u_kn".equals(mode)) {
1858         return new Object[]{x_n, u_kn, N_k, s_n};
1859       } else if ("u_kln".equals(mode)) {
1860         return new Object[]{x_n, u_kln, N_k, s_n, u_kn};
1861       } else {
1862         throw new IllegalArgumentException("Unknown mode: " + mode);
1863       }
1864     }
1865   }
1866 }