View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2023.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.estimator;
39  
40  import ffx.numerics.OptimizationInterface;
41  import ffx.numerics.integrate.DataSet;
42  import ffx.numerics.integrate.DoublesDataSet;
43  import ffx.numerics.integrate.Integrate1DNumeric;
44  import ffx.numerics.optimization.LBFGS;
45  import ffx.numerics.optimization.LineSearch;
46  import ffx.numerics.optimization.OptimizationListener;
47  import ffx.utilities.Constants;
48  import org.apache.commons.math3.linear.MatrixUtils;
49  import org.apache.commons.math3.linear.RealMatrix;
50  import org.apache.commons.math3.linear.SingularValueDecomposition;
51  
52  import java.io.BufferedWriter;
53  import java.io.File;
54  import java.io.FileWriter;
55  import java.io.IOException;
56  import java.util.ArrayList;
57  import java.util.Arrays;
58  import java.util.Random;
59  import java.util.logging.Logger;
60  
61  import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
62  import static ffx.numerics.estimator.Zwanzig.Directionality.BACKWARDS;
63  import static ffx.numerics.estimator.Zwanzig.Directionality.FORWARDS;
64  import static java.lang.System.arraycopy;
65  import static java.util.Arrays.copyOf;
66  import static java.util.Arrays.stream;
67  import static org.apache.commons.lang3.ArrayFill.fill;
68  import static org.apache.commons.math3.util.FastMath.abs;
69  import static org.apache.commons.math3.util.FastMath.exp;
70  import static org.apache.commons.math3.util.FastMath.log;
71  import static org.apache.commons.math3.util.FastMath.sqrt;
72  
73  /**
74   * The MultistateBennettAcceptanceRatio class defines a statistical estimator based on a generalization
75   * to the Bennett Acceptance Ratio (BAR) method for multiple lambda windows. It requires an input of
76   * K X N array of energies (every window at every snap at every lambda value). No support for different
77   * number of snapshots at each window. This will be caught by the filter, but not by the Harmonic Oscillators
78   * testcase.
79   * <p>
80   * This class implements the method discussed in:
81   * Shirts, M. R. and Chodera, J. D. (2008) Statistically optimal analysis of snaps from multiple equilibrium
82   * states. J. Chem. Phys. 129, 124105. doi:10.1063/1.2978177
83   * <p>
84   * This class is based heavily on the pymbar code, which is available from
85   * <a href="https://github.com/choderalab/pymbar/tree/master">Github</a>.
86   *
87   * @author Matthew J. Speranza
88   * @since 1.0
89   */
90  public class MultistateBennettAcceptanceRatio extends SequentialEstimator implements BootstrappableEstimator, OptimizationInterface {
91    private static final Logger logger = Logger.getLogger(MultistateBennettAcceptanceRatio.class.getName());
92    /**
93     * Default MBAR convergence tolerance.
94     */
95    private static final double DEFAULT_TOLERANCE = 1.0E-7;
96    /**
97     * Number of free of differences between simulation windows. Calculated at the very end.
98     */
99    private final int nFreeEnergyDiffs;
100   /**
101    * MBAR free-energy difference estimates (nFreeEnergyDiffs values).
102    */
103   private final double[] mbarFEDifferenceEstimates;
104   /**
105    * Number of lamda states (basically nFreeEnergyDiffs + 1).
106    */
107   private final int nLambdaStates;
108   /**
109    * MBAR free-energy estimates at each lambda value (nStates values). The first value is defined as 0 throughout to
110    * promote stability. This estimate is novel to the MBAR method and is not seen in BAR or Zwanzig. Only the differences
111    * between these values have physical significance.
112    */
113   private double[] mbarFEEstimates;
114   /**
115    * MBAR observable ensemble estimates.
116    */
117   private double[] mbarObservableEnsembleAverages;
118   private double[] mbarObservableEnsembleAverageUncertainties;
119   /**
120    * MBAR free-energy difference uncertainties.
121    */
122   private double[] mbarUncertainties;
123   /**
124    * Matrix of free-energy difference uncertainties between all i & j
125    */
126   private double[][] uncertaintyMatrix;
127   /**
128    * MBAR convergence tolerance.
129    */
130   private final double tolerance;
131   /**
132    * Random number generator used for bootstrapping.
133    */
134   private final Random random;
135   /**
136    * Total MBAR free-energy difference estimate.
137    */
138   private double totalMBAREstimate;
139   /**
140    * Total MBAR free-energy difference uncertainty.
141    */
142   private double totalMBARUncertainty;
143   /**
144    * MBAR Enthalpy estimates
145    */
146   private double[] mbarEnthalpy;
147 
148   /**
149    * MBAR Entropy estimates
150    */
151   private double[] mbarEntropy;
152 
153   public double[] rtValues;
154 
155   /**
156    * "Reduced" potential energies. -ln(exp(beta * -U)) or more practically U * (1 / RT).
157    * Has shape (nLambdaStates, numSnaps * nLambdaStates)
158    */
159   private double[][] reducedPotentials;
160 
161   private double[][] oAllFlat;
162   private double[][] biasFlat;
163   /**
164    * Seed MBAR calculation with another free energy estimation (BAR,ZWANZIG) or zeros
165    */
166   private SeedType seedType;
167 
168 
169   /**
170    * Enum of MBAR seed types.
171    */
172   public enum SeedType {BAR, ZWANZIG, ZEROS;}
173 
174   public static boolean FORCE_ZEROS_SEED = false;
175   public static boolean VERBOSE = false;
176 
177   /**
178    * Constructor for MBAR estimator.
179    *
180    * @param lambdaValues array of lambda values
181    * @param energiesAll  array of energies at each lambda value
182    * @param temperature  array of temperatures
183    */
184   public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature) {
185     this(lambdaValues, energiesAll, temperature, DEFAULT_TOLERANCE, SeedType.ZWANZIG);
186   }
187 
188   /**
189    * Constructor for MBAR estimator.
190    *
191    * @param lambdaValues array of lambda values
192    * @param energiesAll  array of energies at each lambda value
193    * @param temperature  array of temperatures
194    * @param tolerance    convergence tolerance
195    * @param seedType     seed type for MBAR
196    */
197   public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature,
198                                           double tolerance, SeedType seedType) {
199     super(lambdaValues, energiesAll, temperature);
200     this.tolerance = tolerance;
201     this.seedType = seedType;
202 
203     // MBAR calculates free energy at each lambda value (only the differences between them have physical significance)
204     nLambdaStates = lambdaValues.length;
205     mbarFEEstimates = new double[nLambdaStates];
206 
207     nFreeEnergyDiffs = lambdaValues.length - 1;
208     mbarFEDifferenceEstimates = new double[nFreeEnergyDiffs];
209     mbarUncertainties = new double[nFreeEnergyDiffs];
210     mbarEnthalpy = new double[nFreeEnergyDiffs];
211     mbarEntropy = new double[nFreeEnergyDiffs];
212     random = new Random();
213     estimateDG();
214   }
215 
216   public MultistateBennettAcceptanceRatio(double[] lambdaValues, int[] snaps, double[][] eAllFlat, double[] temperature,
217                                           double tolerance, SeedType seedType) {
218     super(lambdaValues, snaps, eAllFlat, temperature);
219     this.tolerance = tolerance;
220     this.seedType = seedType;
221 
222     // MBAR calculates free energy at each lambda value (only the differences between them have physical significance)
223     nLambdaStates = lambdaValues.length;
224     mbarFEEstimates = new double[nLambdaStates];
225 
226     nFreeEnergyDiffs = lambdaValues.length - 1;
227     mbarFEDifferenceEstimates = new double[nFreeEnergyDiffs];
228     mbarUncertainties = new double[nFreeEnergyDiffs];
229     mbarEnthalpy = new double[nFreeEnergyDiffs];
230     mbarEntropy = new double[nFreeEnergyDiffs];
231     random = new Random();
232     estimateDG();
233   }
234 
235   /**
236    * Set the MBAR seed energies using BAR, Zwanzig, or zeros.
237    */
238   private void seedEnergies() {
239     switch (seedType) {
240       case BAR:
241         try {
242           if (eLambdaMinusdL == null || eLambda == null || eLambdaPlusdL == null) {
243             seedType = SeedType.ZEROS;
244             seedEnergies();
245             return;
246           }
247           SequentialEstimator barEstimator = new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures);
248           mbarFEEstimates[0] = 0.0;
249           double[] barEstimates = barEstimator.getFreeEnergyDifferences();
250           for (int i = 0; i < nFreeEnergyDiffs; i++) {
251             mbarFEEstimates[i + 1] = mbarFEEstimates[i] + barEstimates[i];
252           }
253           break;
254         } catch (IllegalArgumentException e) {
255           logger.warning(" BAR failed to converge. Zwanzig will be used for seed energies.");
256           seedType = SeedType.ZWANZIG;
257           seedEnergies();
258           return;
259         }
260       case ZWANZIG:
261         try {
262           if (eLambdaMinusdL == null || eLambda == null || eLambdaPlusdL == null) {
263             seedType = SeedType.ZEROS;
264             seedEnergies();
265             return;
266           }
267           Zwanzig forwardsFEP = new Zwanzig(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, FORWARDS);
268           Zwanzig backwardsFEP = new Zwanzig(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures, BACKWARDS);
269           double[] forwardZwanzig = forwardsFEP.getFreeEnergyDifferences();
270           double[] backwardZwanzig = backwardsFEP.getFreeEnergyDifferences();
271           mbarFEEstimates[0] = 0.0;
272           for (int i = 0; i < nFreeEnergyDiffs; i++) {
273             mbarFEEstimates[i + 1] = mbarFEEstimates[i] + .5 * (forwardZwanzig[i] + backwardZwanzig[i]);
274           }
275           if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
276             throw new IllegalArgumentException("MBAR contains NaNs or Infs after seeding.");
277           }
278           break;
279         } catch (IllegalArgumentException e) {
280           logger.warning(" Zwanzig failed to converge. Zeros will be used for seed energies.");
281           seedType = SeedType.ZEROS;
282           seedEnergies();
283           return;
284         }
285       case ZEROS:
286         fill(mbarFEEstimates, 0.0);
287         break;
288       default:
289         throw new IllegalArgumentException("Seed type not supported");
290     }
291   }
292 
293   /**
294    * Get the MBAR free-energy estimates at each lambda value.
295    */
296   @Override
297   public void estimateDG() {
298     estimateDG(false);
299   }
300 
301   /**
302    * MBAR solved with self-consistent iteration and Newton/L-BFGS optimization.
303    */
304   @Override
305   public void estimateDG(boolean randomSamples) {
306     if (MultistateBennettAcceptanceRatio.VERBOSE) {
307       logger.setLevel(java.util.logging.Level.FINE);
308     }
309 
310     // Bootstrap needs resetting to zeros
311     fill(mbarFEEstimates, 0.0);
312     if (FORCE_ZEROS_SEED) {
313       seedType = SeedType.ZEROS;
314     }
315     seedEnergies();
316     if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
317       seedType = SeedType.ZEROS;
318       seedEnergies();
319     }
320     if (MultistateBennettAcceptanceRatio.VERBOSE) {
321       logger.info(" Seed Type: " + seedType);
322       logger.info(" MBAR FE Estimates after seeding: " + Arrays.toString(mbarFEEstimates));
323     }
324 
325     // Precompute beta for each state.
326     rtValues = new double[nLambdaStates];
327     double[] invRTValues = new double[nLambdaStates];
328     for (int i = 0; i < nLambdaStates; i++) {
329       rtValues[i] = Constants.R * temperatures[i];
330       invRTValues[i] = 1.0 / rtValues[i];
331     }
332 
333     // Find repeated snapshots if from continuous lambda
334     int numEvaluations = eAllFlat[0].length;
335 
336     // Sample random snapshots from each window.
337     int[][] indices = new int[nLambdaStates][numEvaluations];
338     if (randomSamples) {
339       // Build random indices vector maintaining snapshot nums!
340       int[] randomIndices = new int[numEvaluations];
341       int sum = 0;
342       for (int snap : nSamples) {
343         System.arraycopy(getBootstrapIndices(snap, random), 0, randomIndices, sum, snap);
344         sum += snap;
345       }
346       for (int i = 0; i < nLambdaStates; i++) {
347         // Use the same random indices across lambda values
348         indices[i] = randomIndices;
349       }
350     } else {
351       for (int i = 0; i < numEvaluations; i++) {
352         for (int j = 0; j < nLambdaStates; j++) {
353           indices[j][i] = i;
354         }
355       }
356     }
357 
358     // Precompute reducedPotentials since it doesn't change
359     reducedPotentials = new double[nLambdaStates][numEvaluations];
360     double minPotential = Double.POSITIVE_INFINITY;
361     for (int state = 0; state < eAllFlat.length; state++) { // For each lambda value
362       for (int n = 0; n < eAllFlat[0].length; n++) {
363         reducedPotentials[state][n] = eAllFlat[state][indices[state][n]] * invRTValues[state];
364         if (reducedPotentials[state][n] < minPotential) {
365           minPotential = reducedPotentials[state][n];
366         }
367       }
368     }
369 
370     // Subtract the minimum potential from all potentials (we are calculating relative free energies anyway)
371     for (int state = 0; state < nLambdaStates; state++) {
372       for (int n = 0; n < numEvaluations; n++) {
373         reducedPotentials[state][n] -= minPotential;
374       }
375     }
376 
377     // Remove reduced potential arrays where snaps are zero since they cause issues for N.R. and SCI
378     // i.e. no trajectories for that lambda were generated/sampled, but other trajectories had potentials evaluated at that lambda
379     ArrayList<Integer> zeroSnapLambdas = new ArrayList<>();
380     ArrayList<Integer> sampledLambdas = new ArrayList<>();
381     for (int i = 0; i < nLambdaStates; i++) {
382       if (nSamples[i] == 0) {
383         zeroSnapLambdas.add(i);
384       } else {
385         sampledLambdas.add(i);
386       }
387     }
388     int nLambdaStatesTemp = nLambdaStates - zeroSnapLambdas.size();
389     double[][] reducedPotentialsTemp = new double[nLambdaStates - zeroSnapLambdas.size()][numEvaluations];
390     double[] mbarFEEstimatesTemp = new double[nLambdaStates - zeroSnapLambdas.size()];
391     int[] snapsTemp = new int[nLambdaStates - zeroSnapLambdas.size()];
392     if (!zeroSnapLambdas.isEmpty()) {
393       int index = 0;
394       for (int i = 0; i < nLambdaStates; i++) {
395         if (!zeroSnapLambdas.contains(i)) {
396           reducedPotentialsTemp[index] = reducedPotentials[i];
397           mbarFEEstimatesTemp[index] = mbarFEEstimates[i];
398           snapsTemp[index] = nSamples[i];
399           index++;
400         }
401       }
402       logger.info(" Sampled Lambdas: " + sampledLambdas);
403       logger.info(" Zero Snap Lambdas: " + zeroSnapLambdas);
404     } else { // If there aren't any zero snap lambdas, just use the original arrays
405       reducedPotentialsTemp = reducedPotentials;
406       mbarFEEstimatesTemp = mbarFEEstimates;
407       snapsTemp = nSamples;
408     }
409 
410     // SCI iterations used to start optimization of MBAR objective function.
411     // Optimizers can struggle when starting too far from the minimum, but SCI doesn't.
412     double[] prevMBAR = copyOf(mbarFEEstimatesTemp, nLambdaStatesTemp);
413     ;
414     double omega = 1.5; // Parameter chosen empirically to work with most systems (> 2 works but not always).
415     for (int i = 0; i < 10; i++) {
416       prevMBAR = copyOf(mbarFEEstimatesTemp, nLambdaStatesTemp);
417       mbarFEEstimatesTemp = mbarSelfConsistentUpdate(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
418       for (int j = 0; j < nLambdaStatesTemp; j++) { // SOR
419         mbarFEEstimatesTemp[j] = omega * mbarFEEstimatesTemp[j] + (1 - omega) * prevMBAR[j];
420       }
421       if (stream(mbarFEEstimatesTemp).anyMatch(Double::isInfinite) || stream(mbarFEEstimatesTemp).anyMatch(Double::isNaN)) {
422         throw new IllegalArgumentException("MBAR contains NaNs or Infs during startup SCI ");
423       }
424       if (converged(prevMBAR)) {
425         break;
426       }
427     }
428     if (MultistateBennettAcceptanceRatio.VERBOSE) {
429       logger.info(" Omega for SCI w/ relaxation: " + omega);
430       logger.info(" MBAR FE Estimates after 10 SCI iterations: " + Arrays.toString(mbarFEEstimatesTemp));
431     }
432 
433     try {
434       if (nLambdaStatesTemp > 100 && !converged(prevMBAR)) { // L-BFGS optimization for high granularity windows where hessian^-1 is expensive
435         if (MultistateBennettAcceptanceRatio.VERBOSE) {
436           logger.info(" L-BFGS optimization started.");
437         }
438         int mCorrections = 5;
439         double[] x = new double[nLambdaStatesTemp];
440         arraycopy(mbarFEEstimatesTemp, 0, x, 0, nLambdaStatesTemp);
441         double[] grad = mbarGradient(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
442         double eps = 1.0E-4; // Gradient tolarance -> chosen since L-BFGS seems unstable with tight tolerances
443         OptimizationListener listener = getOptimizationListener();
444         LBFGS.minimize(nLambdaStatesTemp, mCorrections, x, mbarObjectiveFunction(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp),
445             grad, eps, 1000, this, listener);
446         arraycopy(x, 0, mbarFEEstimatesTemp, 0, nLambdaStatesTemp);
447       } else if (!converged(prevMBAR)) { // Newton optimization if hessian inversion isn't too expensive
448         if (MultistateBennettAcceptanceRatio.VERBOSE) {
449           logger.info(" Newton optimization started.");
450         }
451         mbarFEEstimatesTemp = newton(mbarFEEstimatesTemp, reducedPotentialsTemp, snapsTemp, tolerance);
452       }
453     } catch (Exception e) {
454       logger.warning(" L-BFGS/Newton failed to converge. Finishing w/ self-consistent iteration. Message: " +
455           e.getMessage());
456     }
457     if (MultistateBennettAcceptanceRatio.VERBOSE) {
458       logger.info(" MBAR FE Estimates after gradient optimization: " + Arrays.toString(mbarFEEstimatesTemp));
459     }
460 
461     // Update the FE estimates with the optimized values from derivative-based optimization
462     int count = 0;
463     for (Integer i : sampledLambdas) {
464       if (!Double.isNaN(mbarFEEstimatesTemp[count])) { // Should be !NaN
465         mbarFEEstimates[i] = mbarFEEstimatesTemp[count];
466       }
467       count++;
468     }
469 
470     // Self-consistent iteration is used to finish off optimization of MBAR objective function
471     int sciIter = 0;
472     while (!converged(prevMBAR) && sciIter < 1000) {
473       prevMBAR = copyOf(mbarFEEstimates, nLambdaStates);
474       mbarFEEstimates = mbarSelfConsistentUpdate(reducedPotentials, nSamples, mbarFEEstimates);
475       for (int i = 0; i < nLambdaStates; i++) { // SOR for acceleration
476         mbarFEEstimates[i] = omega * mbarFEEstimates[i] + (1 - omega) * prevMBAR[i];
477       }
478       if (stream(mbarFEEstimates).anyMatch(Double::isInfinite) || stream(mbarFEEstimates).anyMatch(Double::isNaN)) {
479         throw new IllegalArgumentException("MBAR estimate contains NaNs or Infs after iteration " + sciIter);
480       }
481       sciIter++;
482     }
483     if (MultistateBennettAcceptanceRatio.VERBOSE) {
484       logger.info(" SCI iterations (max 1000): " + sciIter);
485     }
486 
487     // Calculate uncertainties
488     double[][] theta = mbarTheta(reducedPotentials, nSamples, mbarFEEstimates); // Quite expensive
489     mbarUncertainties = mbarUncertaintyCalc(theta);
490     totalMBARUncertainty = mbarTotalUncertaintyCalc(theta);
491     uncertaintyMatrix = diffMatrixCalculation(theta);
492     if (!randomSamples && MultistateBennettAcceptanceRatio.VERBOSE) { // Never log for bootstrapping
493       logWeights();
494     }
495 
496     // Convert to kcal/mol & calculate differences/sums
497     for (int i = 0; i < nLambdaStates; i++) {
498       mbarFEEstimates[i] = mbarFEEstimates[i] * rtValues[i];
499     }
500     for (int i = 0; i < nFreeEnergyDiffs; i++) {
501       mbarFEDifferenceEstimates[i] = mbarFEEstimates[i + 1] - mbarFEEstimates[i];
502     }
503 
504     mbarEnthalpy = mbarEnthalpyCalc(eAllFlat, mbarFEEstimates);
505     mbarEntropy = mbarEntropyCalc(mbarEnthalpy, mbarFEEstimates);
506 
507     totalMBAREstimate = stream(mbarFEDifferenceEstimates).sum();
508   }
509 
510   /* /////// Misc. Methods //////////// */
511 
512   /**
513    * Checks if the MBAR free energy estimates have converged by comparing the difference
514    * between the previous and current free energies. The tolerance is set by the user.
515    *
516    * @param prevMBAR previous MBAR free energy estimates.
517    * @return true if converged, false otherwise
518    */
519   private boolean converged(double[] prevMBAR) {
520     double[] differences = new double[prevMBAR.length];
521     for (int i = 0; i < prevMBAR.length; i++) {
522       differences[i] = abs(prevMBAR[i] - mbarFEEstimates[i]);
523     }
524     return stream(differences).allMatch(d -> d < tolerance);
525   }
526 
527   /**
528    * Print out, for each FE expectation, the sum of the weights for each trajectory. This
529    * gives an array of length nLambdaStates, where each element is the sum of the weights
530    * coming from the trajectory sampled at the lambda value corresponding to that index.
531    *
532    * <p>i.e. collapsedW[0][0] is the sum of the weights in W[0] from the trajectory sampled at
533    * lambda 0. The diagonal of this matrix should be larger than all other values if that
534    * window had proper sampling.
535    */
536   private void logWeights() {
537     logger.info(" MBAR Weight Matrix Information Collapsed:");
538     double[][] W = mbarW(reducedPotentials, nSamples, mbarFEEstimates);
539     double[][] collapsedW = new double[W.length][W.length]; // Collapse W trajectory-wise (into K x K)
540     for (int i = 0; i < nSamples.length; i++) {
541       for (int j = 0; j < W.length; j++) {
542         int start = 0;
543         for (int k = 0; k < i; k++) {
544           start += nSamples[k];
545         }
546         for (int k = 0; k < nSamples[i]; k++) {
547           collapsedW[j][i] += W[j][start + k];
548         }
549       }
550     }
551     for (int i = 0; i < W.length; i++) {
552       logger.info("\n Estimation " + i + ": " + Arrays.toString(collapsedW[i]));
553     }
554     double[] rowSum = new double[W.length];
555     for (int i = 0; i < collapsedW[0].length; i++) {
556       for (double[] trajectory : collapsedW) {
557         rowSum[i] += trajectory[i];
558       }
559     }
560     softMax(rowSum);
561     logger.info("\n Softmax of trajectory weight: " + Arrays.toString(rowSum));
562   }
563 
564 
565   /* /////// Methods for calculating MBAR variables, vectors, and matrices. /////// */
566 
567   /**
568    * MBAR objective function. This is used for L-BFGS optimization.
569    *
570    * @param reducedPotentials   -ln(boltzmann weights)
571    * @param snapsPerLambda      number of snaps per state
572    * @param freeEnergyEstimates free energies
573    * @return The objective function value.
574    */
575   private static double mbarObjectiveFunction(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
576     if (stream(freeEnergyEstimates).anyMatch(Double::isInfinite) || stream(freeEnergyEstimates).anyMatch(Double::isNaN)) {
577       throw new IllegalArgumentException("MBAR contains NaNs or Infs.");
578     }
579     int nStates = freeEnergyEstimates.length;
580     double[] log_denom_n = new double[reducedPotentials[0].length];
581     for (int i = 0; i < reducedPotentials[0].length; i++) {
582       double[] temp = new double[nStates];
583       double maxTemp = Double.NEGATIVE_INFINITY;
584       for (int j = 0; j < nStates; j++) {
585         temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
586         if (temp[j] > maxTemp) {
587           maxTemp = temp[j];
588         }
589       }
590       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
591     }
592     double[] dotNkFk = new double[snapsPerLambda.length];
593     for (int i = 0; i < snapsPerLambda.length; i++) {
594       dotNkFk[i] = snapsPerLambda[i] * freeEnergyEstimates[i];
595     }
596     return stream(log_denom_n).sum() - stream(dotNkFk).sum();
597   }
598 
599   /**
600    * Gradient of the MBAR objective function. C6 in Shirts and Chodera 2008.
601    *
602    * @param reducedPotentials   energies
603    * @param snapsPerLambda      number of snaps per state
604    * @param freeEnergyEstimates free energies
605    * @return Gradient for the mbar objective function.
606    */
607   private static double[] mbarGradient(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
608     int nStates = freeEnergyEstimates.length;
609     double[] log_num_k = new double[nStates];
610     double[] log_denom_n = new double[reducedPotentials[0].length];
611     double[][] logDiff = new double[reducedPotentials.length][reducedPotentials[0].length];
612     double[] maxLogDiff = new double[nStates];
613     Arrays.fill(maxLogDiff, Double.NEGATIVE_INFINITY);
614     for (int i = 0; i < reducedPotentials[0].length; i++) {
615       double[] temp = new double[nStates];
616       double maxTemp = Double.NEGATIVE_INFINITY;
617       for (int j = 0; j < nStates; j++) {
618         temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
619         if (temp[j] > maxTemp) {
620           maxTemp = temp[j];
621         }
622       }
623       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
624       for (int j = 0; j < nStates; j++) {
625         logDiff[j][i] = -log_denom_n[i] - reducedPotentials[j][i];
626         if (logDiff[j][i] > maxLogDiff[j]) {
627           maxLogDiff[j] = logDiff[j][i];
628         }
629       }
630     }
631     for (int i = 0; i < nStates; i++) {
632       log_num_k[i] = logSumExp(logDiff[i], maxLogDiff[i]);
633     }
634     double[] grad = new double[nStates];
635     for (int i = 0; i < nStates; i++) {
636       grad[i] = -1.0 * snapsPerLambda[i] * (1.0 - exp(freeEnergyEstimates[i] + log_num_k[i]));
637     }
638     return grad;
639   }
640 
641   /**
642    * Hessian of the MBAR objective function. C9 in Shirts and Chodera 2008.
643    *
644    * @param reducedPotentials   energies
645    * @param snapsPerLambda      number of snaps per state
646    * @param freeEnergyEstimates free energies
647    * @return Hessian for the mbar objective function.
648    */
649   private static double[][] mbarHessian(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
650     int nStates = freeEnergyEstimates.length;
651     double[][] W = mbarW(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
652     // h = dot(W.T, W) * snapsPerLambda * snapsPerLambda[:, newaxis] - diag(W.sum(0) * snapsPerLambda)
653     double[][] hessian = new double[nStates][nStates];
654     for (int i = 0; i < nStates; i++) {
655       for (int j = 0; j < nStates; j++) {
656         double sum = 0.0;
657         for (int k = 0; k < reducedPotentials[0].length; k++) {
658           sum += W[i][k] * W[j][k];
659         }
660         hessian[i][j] = sum * snapsPerLambda[i] * snapsPerLambda[j];
661       }
662       double wSum = 0.0;
663       for (int k = 0; k < W[i].length; k++) {
664         wSum += W[i][k];
665       }
666       hessian[i][i] -= wSum * snapsPerLambda[i];
667     }
668     // h = -h
669     for (int i = 0; i < nStates; i++) {
670       for (int j = 0; j < nStates; j++) {
671         hessian[i][j] = -hessian[i][j];
672       }
673     }
674     return hessian;
675   }
676 
677   /**
678    * W = exp(freeEnergyEstimates - reducedPotentials.T - log_denominator_n[:, newaxis])
679    * Eq. 9 in Shirts and Chodera 2008.
680    *
681    * @param reducedPotentials   energies
682    * @param snapsPerLambda      number of snaps per state
683    * @param freeEnergyEstimates free energies
684    * @return W matrix.
685    */
686   private static double[][] mbarW(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
687     int nStates = freeEnergyEstimates.length;
688     double[] log_denom_n = new double[reducedPotentials[0].length];
689     for (int i = 0; i < reducedPotentials[0].length; i++) {
690       double[] temp = new double[nStates];
691       double maxTemp = Double.NEGATIVE_INFINITY;
692       for (int j = 0; j < nStates; j++) {
693         temp[j] = freeEnergyEstimates[j] - reducedPotentials[j][i];
694         if (temp[j] > maxTemp) {
695           maxTemp = temp[j];
696         }
697       }
698       // log_denom_n = calculates log(sumOverStates(N_k * exp(FE[j] - reducedPotentials[j][i])))
699       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
700     }
701     // logW = freeEnergyEstimates - reducedPotentials.T - log_denominator_n[:, newaxis]
702     // freeEnergyEstimates[i] = log(ck / ci) --> ratio of normalization constants
703     double[][] W = new double[nStates][reducedPotentials[0].length];
704     for (int i = 0; i < nStates; i++) {
705       for (int j = 0; j < reducedPotentials[0].length; j++) {
706         W[i][j] = exp(freeEnergyEstimates[i] - reducedPotentials[i][j] - log_denom_n[j]);
707       }
708     }
709     return W;
710   }
711 
712   private double[] mbarEnthalpyCalc(double[][] reducedPotentials, double[] mbarFEEstimates) {
713     double[] enthalpy = new double[mbarFEEstimates.length - 1];
714     double[] averagePotential = new double[mbarFEEstimates.length];
715     for (int i = 0; i < reducedPotentials.length; i++) {
716       averagePotential[i] = computeExpectations(eAllFlat[i])[i]; // average potential of ith lambda
717     }
718     for (int i = 0; i < enthalpy.length; i++) {
719       enthalpy[i] = averagePotential[i + 1] - averagePotential[i];
720     }
721     return enthalpy;
722   }
723 
724   private double[] mbarEntropyCalc(double[] mbarEnthalpy, double[] mbarFEEstimates) {
725     double[] entropy = new double[mbarFEEstimates.length - 1];
726     for (int i = 0; i < entropy.length; i++) {
727       entropy[i] = mbarEnthalpy[i] - mbarFEDifferenceEstimates[i]; // dG = dH - TdS || TdS = dH - dG
728     }
729     return entropy;
730   }
731 
732   /**
733    * Weight observable by exp(bias/RT) prior to computing expectation when set.
734    *
735    * @param biasAll
736    * @param multiDataObservable
737    */
738   public void setBiasData(double[][][] biasAll, boolean multiDataObservable) {
739     biasFlat = new double[biasAll.length][biasAll.length * biasAll[0][0].length];
740     if (multiDataObservable) { // Flatten data
741       int[] snapsT = new int[biasAll.length];
742       int[] nanCount = new int[biasAll.length];
743       for (int i = 0; i < biasAll.length; i++) {
744         ArrayList<Double> temp = new ArrayList<>();
745         double maxBias = Double.NEGATIVE_INFINITY;
746         for (int j = 0; j < biasAll.length; j++) {
747           int count = 0;
748           int countNaN = 0;
749           for (int k = 0; k < biasAll[j][i].length; k++) {
750             // Don't include NaN values
751             if (!Double.isNaN(biasAll[j][i][k])) {
752               temp.add(biasAll[j][i][k]);
753               if (biasAll[j][i][k] > maxBias) {
754                 maxBias = biasAll[j][i][k];
755               }
756               count++;
757             } else {
758               countNaN++;
759             }
760           }
761           snapsT[j] = count;
762           nanCount[j] = countNaN;
763         }
764         biasFlat[i] = temp.stream().mapToDouble(Double::doubleValue).toArray();
765         // Regularize bias for this lambda
766         for (int j = 0; j < biasFlat[i].length; j++) {
767           biasFlat[i][j] -= maxBias;
768         }
769       }
770     } else { // Put relevant data into the 0th index
771       int count = 0;
772       double maxBias = Double.NEGATIVE_INFINITY;
773       for (int i = 0; i < biasAll.length; i++) {
774         for (int j = 0; j < biasAll[0][0].length; j++) {
775           if (!Double.isNaN(biasAll[i][i][j])) {
776             biasFlat[0][count] = biasAll[i][i][j];
777             if (biasAll[i][i][j] > maxBias) {
778               maxBias = biasAll[i][i][j];
779             }
780             count++;
781           }
782         }
783       }
784       // Regularize bias for this lambda
785       for (int i = 0; i < biasFlat[0].length; i++) {
786         biasFlat[0][i] -= maxBias;
787       }
788     }
789   }
790 
791   public void setBiasData(double[][] biasData) {
792     this.biasFlat = biasData;
793     // Regularize bias for this lambda
794     for (int i = 0; i < biasFlat.length; i++) {
795       double maxBias = Double.NEGATIVE_INFINITY;
796       for (int j = 0; j < biasFlat[i].length; j++) {
797         if (biasFlat[i][j] > maxBias) {
798           maxBias = biasFlat[i][j];
799         }
800       }
801       for (int j = 0; j < biasFlat[i].length; j++) {
802         biasFlat[i][j] -= maxBias;
803       }
804     }
805   }
806 
807   public void setObservableData(double[][][] oAll, boolean multiDataObservable, boolean uncertainties) {
808     oAllFlat = new double[oAll.length][oAll.length * oAll[0][0].length];
809     if (multiDataObservable) { // Flatten data
810       int[] snapsT = new int[oAll.length];
811       int[] nanCount = new int[oAll.length];
812       for (int i = 0; i < oAll.length; i++) {
813         ArrayList<Double> temp = new ArrayList<>();
814         for (int j = 0; j < oAll.length; j++) {
815           int count = 0;
816           int countNaN = 0;
817           for (int k = 0; k < oAll[j][i].length; k++) {
818             // Don't include NaN values
819             if (!Double.isNaN(oAll[j][i][k])) {
820               temp.add(oAll[j][i][k]);
821               count++;
822             } else {
823               countNaN++;
824             }
825           }
826           snapsT[j] = count;
827           nanCount[j] = countNaN;
828         }
829         oAllFlat[i] = temp.stream().mapToDouble(Double::doubleValue).toArray();
830       }
831     } else { // Put relevant data into the 0th index
832       int count = 0;
833       for (int i = 0; i < oAll.length; i++) {
834         for (int j = 0; j < oAll[0][0].length; j++) {
835           if (!Double.isNaN(oAll[i][i][j])) {
836             oAllFlat[0][count] = oAll[i][i][j]; // Note [i][i] indexing
837             count++;
838           }
839         }
840       }
841     }
842     // OST Data
843     if (biasFlat != null) {
844       for (int i = 0; i < oAllFlat.length; i++) {
845         for (int j = 0; j < oAllFlat[i].length; j++) {
846           oAllFlat[i][j] *= exp(biasFlat[i][j] / rtValues[i]);
847         }
848       }
849     }
850     this.fillObservationExpectations(multiDataObservable, uncertainties);
851   }
852 
853   public void setObservableData(double[][] oAll, boolean uncertainties) {
854     oAllFlat = oAll;
855     // OST Data
856     if (biasFlat != null) {
857       if (oAllFlat.length != biasFlat.length || oAllFlat[0].length != biasFlat[0].length) {
858         logger.severe("Observable and bias data are not the same size. Exiting.");
859       }
860       for (int i = 0; i < oAllFlat.length; i++) {
861         for (int j = 0; j < oAllFlat[i].length; j++) {
862           oAllFlat[i][j] *= exp(biasFlat[i][j] / rtValues[i]);
863         }
864       }
865     }
866     this.fillObservationExpectations(oAllFlat.length != 1, uncertainties);
867   }
868 
869   public double getTIIntegral() {
870     DataSet dSet = new DoublesDataSet(Integrate1DNumeric.generateXPoints(0, 1, mbarObservableEnsembleAverages.length, false),
871         mbarObservableEnsembleAverages, false);
872     return Integrate1DNumeric.integrateData(dSet, Integrate1DNumeric.IntegrationSide.LEFT, Integrate1DNumeric.IntegrationType.TRAPEZOIDAL);
873   }
874 
875   /**
876    * Calculate expectation of samples from W matrix. Optionally calculate the uncertainty with
877    * augmented W matrix (incurs a significant computational cost ~10-20x MBAR calculation).
878    *
879    * @return Uncertainty of the observable.
880    */
881   private void fillObservationExpectations(boolean multiData, boolean uncertainties) {
882     if (multiData) {
883       mbarObservableEnsembleAverages = new double[oAllFlat.length];
884       mbarObservableEnsembleAverageUncertainties = new double[oAllFlat.length];
885       for (int i = 0; i < oAllFlat.length; i++) {
886         mbarObservableEnsembleAverages[i] = computeExpectations(oAllFlat[i])[i];
887         if (uncertainties) {
888           mbarObservableEnsembleAverageUncertainties[i] = computeExpectationStd(oAllFlat[i])[i];
889         }
890       }
891     } else {
892       mbarObservableEnsembleAverages = computeExpectations(oAllFlat[0]);
893       if (uncertainties) {
894         mbarObservableEnsembleAverageUncertainties = computeExpectationStd(oAllFlat[0]);
895       }
896     }
897   }
898 
899   /**
900    * Compute the MBAR expectation of a given observable (1xN) for each K. This observable
901    * could be something like x, x^2 (where x is equilibrium for a harmonic oscillator),
902    * or some other function of the configuration X like RMSD from a target conformation.
903    * Additionally, it could be evaluations of some potential at a specific lambda value.
904    * Each trajectory snap should have a corresponding observable value (or evaluation).
905    *
906    * @param samples
907    * @return
908    */
909   private double[] computeExpectations(double[] samples) {
910     double[][] W = mbarW(reducedPotentials, nSamples, mbarFEEstimates);
911     if (W[0].length != samples.length) {
912       logger.severe("Samples and W matrix are not the same length. Exiting.");
913     }
914     double[] expectation = new double[W.length];
915     for (int i = 0; i < W.length; i++) {
916       for (int j = 0; j < W[i].length; j++) {
917         expectation[i] += W[i][j] * samples[j];
918       }
919     }
920     return expectation;
921   }
922 
923   /**
924    * Eq. 13-15 in Shirts and Chodera (2008) for the MBAR observable uncertainty calculation.
925    * Originally implemented as seen in paper, but switched to logsumexp version because of
926    * Inf/NaN issues for large values captured in samples (i.e. potential energies).
927    *
928    * @return WnA matrix.
929    */
930   private double[][] mbarAugmentedW(double[] samples) {
931     int nStates = mbarFEEstimates.length;
932     // Enforce positivity of samples --> from pymbar
933     double minSample = stream(samples).min().getAsDouble() - 3 * java.lang.Math.ulp(1.0); // ulp to avoid zeros
934     if (minSample < 0) {
935       for (int i = 0; i < samples.length; i++) {
936         samples[i] -= minSample;
937       }
938     }
939     // Eq. 14 in Shirts and Chodera (2008)
940     double[][] logCATerms = new double[nStates][reducedPotentials[0].length];
941     double[] maxLogCATerm = new double[reducedPotentials[0].length];
942     Arrays.fill(maxLogCATerm, Double.NEGATIVE_INFINITY);
943     double[] logCA = new double[nStates];
944     double[] log_denom_n = new double[reducedPotentials[0].length];
945     for (int i = 0; i < reducedPotentials[0].length; i++) {
946       double[] temp = new double[nStates];
947       double maxTemp = Double.NEGATIVE_INFINITY;
948       for (int j = 0; j < nStates; j++) {
949         temp[j] = mbarFEEstimates[j] - reducedPotentials[j][i];
950         if (temp[j] > maxTemp) {
951           maxTemp = temp[j];
952         }
953       }
954       log_denom_n[i] = logSumExp(temp, nSamples, maxTemp);
955       for (int j = 0; j < nStates; j++) {
956         logCATerms[j][i] = log(samples[i]) - reducedPotentials[j][i] - log_denom_n[i];
957         if (logCATerms[j][i] > maxLogCATerm[i]) {
958           maxLogCATerm[j] = logCATerms[j][i];
959         }
960       }
961     }
962     for (int i = 0; i < nStates; i++) {
963       logCA[i] = logSumExp(logCATerms[i], maxLogCATerm[i]);
964     }
965     // Eq. 13 in Shirts and Chodera (2008)
966     double[][] WnA = new double[nStates][reducedPotentials[0].length];
967     double[][] Wna = new double[nStates][reducedPotentials[0].length]; // normal W matrix
968     for (int i = 0; i < nStates; i++) {
969       for (int j = 0; j < reducedPotentials[0].length; j++) {
970         WnA[i][j] = samples[j] * exp(-logCA[i] - reducedPotentials[i][j] - log_denom_n[j]);
971         Wna[i][j] = exp(-mbarFEEstimates[i] - reducedPotentials[i][j] - log_denom_n[j]);
972       }
973     }
974     if (minSample < 0) { // reset samples
975       for (int i = 0; i < samples.length; i++) {
976         samples[i] += minSample;
977       }
978     }
979     double[][] augmentedW = new double[nStates * 2][reducedPotentials[0].length];
980     for (int i = 0; i < augmentedW.length; i++) {
981       augmentedW[i] = i < nStates ? Wna[i] : WnA[(i - nStates)];
982     }
983     return augmentedW;
984   }
985 
986   /**
987    * Compute the MBAR uncertainty of an observable. The equations for this are not clear,
988    * but we append an augmented weight matrix (calculated by multiplying the observed values
989    * into the W matrix calculation) to the original W matrix. This is then used to calculate
990    * theta.
991    *
992    * @param samples
993    * @return
994    */
995   private double[] computeExpectationStd(double[] samples) {
996     int[] extendedSnaps = new int[nSamples.length * 2];
997     System.arraycopy(nSamples, 0, extendedSnaps, 0, nSamples.length);
998     RealMatrix theta = MatrixUtils.createRealMatrix(mbarTheta(extendedSnaps, mbarAugmentedW(samples)));
999     double[] expectations = computeExpectations(samples);
1000     double[] diag = new double[expectations.length * 2];
1001     for (int i = 0; i < expectations.length; i++) {
1002       diag[i] = expectations[i];
1003       diag[i + expectations.length] = expectations[i];
1004     }
1005     RealMatrix diagMatrix = MatrixUtils.createRealDiagonalMatrix(diag);
1006     theta = diagMatrix.multiply(theta).multiply(diagMatrix);
1007     RealMatrix ul = theta.getSubMatrix(0, expectations.length - 1, 0, expectations.length - 1);
1008     RealMatrix ur = theta.getSubMatrix(0, expectations.length - 1, expectations.length, expectations.length * 2 - 1);
1009     RealMatrix ll = theta.getSubMatrix(expectations.length, expectations.length * 2 - 1, 0, expectations.length - 1);
1010     RealMatrix lr = theta.getSubMatrix(expectations.length, expectations.length * 2 - 1, expectations.length, expectations.length * 2 - 1);
1011     double[][] covA = ul.add(lr).subtract(ur).subtract(ll).getData(); // Loose precision here
1012     double[] sigma = new double[covA.length];
1013     for (int i = 0; i < covA.length; i++) {
1014       sigma[i] = sqrt(abs(covA[i][i]));
1015     }
1016     return sigma;
1017   }
1018 
1019   /**
1020    * MBAR uncertainty calculation.
1021    *
1022    * @return Uncertainties for the MBAR free energy estimates.
1023    */
1024   private static double[] mbarUncertaintyCalc(double[][] theta) {
1025     double[] uncertainties = new double[theta.length - 1];
1026     // del(dFij) = Theta[i,i] - 2 * Theta[i,j] + Theta[j,j]
1027     for (int i = 0; i < theta.length - 1; i++) {
1028       // TODO: Figure out why negative var is happening (likely due to theta calculation differing from pymbar's)
1029       double variance = theta[i][i] - 2 * theta[i][i + 1] + theta[i + 1][i + 1];
1030       if (variance < 0) {
1031         if (MultistateBennettAcceptanceRatio.VERBOSE) {
1032           logger.warning(" Negative variance detected in MBAR uncertainty calculation. " +
1033               "Multiplying by -1 to get real value. Check diff matrix to see which variances were negative. " +
1034               "They should be NaN.");
1035         }
1036         variance *= -1;
1037       }
1038       uncertainties[i] = sqrt(variance);
1039     }
1040     return uncertainties;
1041   }
1042 
1043   /**
1044    * MBAR total uncertainty calculation. Eq 12 in Shirts and Chodera (2008).
1045    *
1046    * @param theta matrix of covariances
1047    * @return Total uncertainty for the MBAR free energy estimates.
1048    */
1049   private static double mbarTotalUncertaintyCalc(double[][] theta) {
1050     int nStates = theta.length;
1051     return sqrt(abs(theta[0][0] - 2 * theta[0][nStates - 1] + theta[nStates - 1][nStates - 1]));
1052   }
1053 
1054   /**
1055    * Theta = W.T @ (I - W @ diag(snapsPerState) @ W.T)^-1 @ W.
1056    * <p>
1057    * Requires calculation and inversion of W matrix.
1058    * D4 from supp info of MBAR paper used instead to reduce storage and comp. complexity.
1059    *
1060    * @param reducedPotentials energies
1061    * @param snapsPerState     number of snaps per state
1062    * @param freeEnergies      free energies
1063    * @return Theta matrix.
1064    */
1065   private static double[][] mbarTheta(double[][] reducedPotentials, int[] snapsPerState, double[] freeEnergies) {
1066     return mbarTheta(snapsPerState, mbarW(reducedPotentials, snapsPerState, freeEnergies));
1067   }
1068 
1069   /**
1070    * Compute theta with a given W matrix.
1071    *
1072    * @param snapsPerState
1073    * @param W
1074    * @return
1075    */
1076   private static double[][] mbarTheta(int[] snapsPerState, double[][] W) {
1077     RealMatrix WMatrix = MatrixUtils.createRealMatrix(W).transpose();
1078     RealMatrix I = MatrixUtils.createRealIdentityMatrix(snapsPerState.length);
1079     RealMatrix NkMatrix = MatrixUtils.createRealDiagonalMatrix(stream(snapsPerState).mapToDouble(i -> i).toArray());
1080     SingularValueDecomposition svd = new SingularValueDecomposition(WMatrix);
1081     RealMatrix V = svd.getV();
1082     RealMatrix S = MatrixUtils.createRealDiagonalMatrix(svd.getSingularValues());
1083 
1084     // W.T @ (I - W @ diag(snapsPerState) @ W.T)^-1 @ W
1085     // = V @ S @ (I - S @ V.T @ diag(snapsPerState) @ V @ S)^-1 @ S @ V.T
1086     RealMatrix theta = S.multiply(V.transpose());
1087     theta = theta.multiply(NkMatrix).multiply(V).multiply(S);
1088     theta = I.subtract(theta);
1089     theta = MatrixUtils.inverse(theta); // pinv equivalent
1090     theta = V.multiply(S).multiply(theta).multiply(S).multiply(V.transpose());
1091 
1092     return theta.getData();
1093   }
1094 
1095   /**
1096    * MBAR uncertainty matrix calculation. diff[i][j] gives FE uncertainty of moving between
1097    * lambda i-> j.
1098    *
1099    * @param theta matrix of covariances
1100    * @return Diff matrix for the MBAR free energy estimates.
1101    */
1102   private static double[][] diffMatrixCalculation(double[][] theta) {
1103     double[][] diffMatrix = new double[theta.length][theta.length];
1104     for (int i = 0; i < diffMatrix.length; i++) {
1105       for (int j = 0; j < diffMatrix.length; j++) {
1106         diffMatrix[i][j] = sqrt(theta[i][i] - 2 * theta[i][j] + theta[j][j]);
1107       }
1108     }
1109     return diffMatrix;
1110   }
1111 
1112   /* /////// Methods for solving MBAR with self-consistent iteration, L-BFGS optimization, and Newton-Raphson. ////// */
1113 
1114   /**
1115    * Self-consistent iteration to update free energies. Eq. 11 from Shirts and Chodera (2008).
1116    *
1117    * @param reducedPotential    energies
1118    * @param snapsPerLambda      number of snaps per state
1119    * @param freeEnergyEstimates free energies
1120    * @return updated free energies
1121    */
1122   private static double[] mbarSelfConsistentUpdate(double[][] reducedPotential, int[] snapsPerLambda,
1123                                                    double[] freeEnergyEstimates) {
1124     int nStates = freeEnergyEstimates.length;
1125     double[] updatedF_k = new double[nStates];
1126     double[] log_denom_n = new double[reducedPotential[0].length];
1127     double[][] logDiff = new double[reducedPotential.length][reducedPotential[0].length];
1128     double[] maxLogDiff = new double[nStates];
1129     fill(maxLogDiff, Double.NEGATIVE_INFINITY);
1130     for (int i = 0; i < reducedPotential[0].length; i++) {
1131       double[] temp = new double[nStates];
1132       double maxTemp = Double.NEGATIVE_INFINITY;
1133       for (int j = 0; j < nStates; j++) {
1134         temp[j] = freeEnergyEstimates[j] - reducedPotential[j][i];
1135         if (temp[j] > maxTemp) {
1136           maxTemp = temp[j];
1137         }
1138       }
1139       log_denom_n[i] = logSumExp(temp, snapsPerLambda, maxTemp);
1140       for (int j = 0; j < nStates; j++) {
1141         logDiff[j][i] = -log_denom_n[i] - reducedPotential[j][i];
1142         if (logDiff[j][i] > maxLogDiff[j]) {
1143           maxLogDiff[j] = logDiff[j][i];
1144         }
1145       }
1146     }
1147 
1148     for (int i = 0; i < nStates; i++) {
1149       updatedF_k[i] = -1.0 * logSumExp(logDiff[i], maxLogDiff[i]);
1150     }
1151 
1152     // Constrain f1=0 over the course of iterations to prevent uncontrolled growth in magnitude
1153     double norm = updatedF_k[0];
1154     updatedF_k[0] = 0.0;
1155     for (int i = 1; i < nStates; i++) {
1156       updatedF_k[i] = updatedF_k[i] - norm;
1157     }
1158 
1159     return updatedF_k;
1160   }
1161 
1162   /**
1163    * Newton-Raphson step for MBAR optimization. Falls back to the steepest descent if hessian is singular.
1164    * <p>
1165    * The matrix can come back from being singular after several iterations, so it isn't worth moving to L-BFGS.
1166    *
1167    * @param n        current free energies.
1168    * @param grad     gradient of the objective function.
1169    * @param hessian  hessian of the objective function.
1170    * @param stepSize step size for the Newton-Raphson step.
1171    * @return updated free energies.
1172    */
1173   private static double[] newtonStep(double[] n, double[] grad, double[][] hessian, double stepSize) {
1174     double[] nPlusOne = new double[n.length];
1175     double[] step;
1176     try {
1177       RealMatrix hessianInverse = MatrixUtils.inverse(MatrixUtils.createRealMatrix(hessian));
1178       step = hessianInverse.preMultiply(grad);
1179     } catch (IllegalArgumentException e) {
1180       if (MultistateBennettAcceptanceRatio.VERBOSE) {
1181         logger.info(" Singular matrix detected in MBAR Newton-Raphson step. Performing steepest descent step.");
1182       }
1183       step = grad;
1184       stepSize = 1e-5;
1185     }
1186     // Zero out the first term of the step
1187     double temp = step[0];
1188     step[0] = 0.0;
1189     for (int i = 1; i < step.length; i++) {
1190       step[i] -= temp;
1191     }
1192     for (int i = 0; i < n.length; i++) {
1193       nPlusOne[i] = n[i] - step[i] * stepSize;
1194     }
1195     return nPlusOne;
1196   }
1197 
1198   /**
1199    * Newton-Raphson optimization for MBAR.
1200    *
1201    * @param freeEnergyEstimates free energies.
1202    * @param reducedPotentials   energies.
1203    * @param snapsPerLambda      number of snaps per state.
1204    * @param tolerance           convergence tolerance.
1205    * @return updated free energies.
1206    */
1207   private static double[] newton(double[] freeEnergyEstimates, double[][] reducedPotentials,
1208                                  int[] snapsPerLambda, double tolerance) {
1209     double[] grad = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1210     double[][] hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1211     double[] f_kPlusOne = newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
1212     int iter = 1;
1213     while (iter < 15) { // Quadratic convergence is expected, SCI will run anyway
1214       freeEnergyEstimates = f_kPlusOne;
1215       grad = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1216       hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1217       // Catches singular matrices and performs steepest descent
1218       f_kPlusOne = newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
1219       double eps = 0.0;
1220       for (int i = 0; i < freeEnergyEstimates.length; i++) {
1221         eps += abs(grad[i]);
1222       }
1223       if (eps < tolerance) {
1224         break;
1225       }
1226       iter++;
1227     }
1228     if (MultistateBennettAcceptanceRatio.VERBOSE) {
1229       logger.info(" Newton iterations (max 15): " + iter);
1230     }
1231 
1232     return f_kPlusOne;
1233   }
1234 
1235   /**
1236    * Calculates the log of the sum of the exponential of the given values.
1237    * <p>
1238    * The max value is subtracted from each value in the array before exponentiation to prevent overflow.
1239    *
1240    * @param values The values to exponential and sum.
1241    * @param max    The max value is subtracted from each value in the array prior to exponentiation.
1242    * @return the sum
1243    */
1244   private static double logSumExp(double[] values, double max) {
1245     int[] b = fill(new int[values.length], 1);
1246     return logSumExp(values, b, max);
1247   }
1248 
1249   /**
1250    * Calculates the log of the sum of the exponential of the given values.
1251    * <p>
1252    * The max value is subtracted from each value in the array before exponentiation to prevent overflow.
1253    * MBAR calculation is easiest to do in log terms, only exponentiating when required. Prevents zeros
1254    * in the denominator.
1255    *
1256    * @param values The values to exponential and sum.
1257    * @param max    The max value is subtracted from each value in the array prior to exponentiation.
1258    * @param b      Weights for each value in the array.
1259    * @return the sum
1260    */
1261   private static double logSumExp(double[] values, int[] b, double max) {
1262     // ChatGPT mostly wrote this and I tweaked it to match more closely with scipy's log-sum-exp implementation
1263     // Find the maximum value in the array.
1264     assert values.length == b.length : "values and b must be the same length";
1265 
1266     // Subtract the maximum value from each value in the array, exponential the result, and add up these values.
1267     double sum = 0.0;
1268     for (int i = 0; i < values.length; i++) {
1269       sum += b[i] * exp(values[i] - max);
1270     }
1271 
1272     // Take the natural logarithm of the sum and add the maximum value back in.
1273     return max + log(sum);
1274   }
1275 
1276   /**
1277    * Turns vector into probability distribution.
1278    *
1279    * @param values
1280    */
1281   private static void softMax(double[] values) {
1282     double max = stream(values).max().getAsDouble();
1283     double sum = 0.0;
1284     for (int i = 0; i < values.length; i++) {
1285       values[i] = exp(values[i] - max);
1286       sum += values[i];
1287     }
1288     for (int i = 0; i < values.length; i++) {
1289       values[i] /= sum;
1290     }
1291   }
1292 
1293   /**
1294    * TODO: Log out the MBAR optimization progress.
1295    *
1296    * @return
1297    */
1298   private OptimizationListener getOptimizationListener() {
1299     return new OptimizationListener() {
1300       @Override
1301       public boolean optimizationUpdate(int iter, int nBFGS, int nFunctionEvals, double gradientRMS,
1302                                         double coordinateRMS, double f, double df, double angle,
1303                                         LineSearch.LineSearchResult info) {
1304         return true;
1305       }
1306     };
1307   }
1308 
1309   /**
1310    * MBAR objective function evaluation at a given free energy estimate for L-BFGS optimization.
1311    *
1312    * @param x Input parameters.
1313    * @return The objective function value at the given parameters.
1314    */
1315   @Override
1316   public double energy(double[] x) {
1317     // Zero out the first term
1318     double tempO = x[0];
1319     x[0] = 0.0;
1320     for (int i = 1; i < x.length; i++) {
1321       x[i] -= tempO;
1322     }
1323     return mbarObjectiveFunction(reducedPotentials, nSamples, x);
1324   }
1325 
1326   /**
1327    * MBAR objective function evaluation and gradient at a given free energy estimate for L-BFGS optimization.
1328    *
1329    * @param x Input parameters.
1330    * @param g The gradient with respect to each parameter.
1331    * @return The objective function value at the given parameters.
1332    */
1333   @Override
1334   public double energyAndGradient(double[] x, double[] g) {
1335     double tempO = x[0];
1336     x[0] = 0.0;
1337     for (int i = 1; i < x.length; i++) {
1338       x[i] -= tempO;
1339     }
1340     double[] tempG = mbarGradient(reducedPotentials, nSamples, x);
1341     arraycopy(tempG, 0, g, 0, g.length);
1342     return mbarObjectiveFunction(reducedPotentials, nSamples, x);
1343   }
1344 
1345   @Override
1346   public double[] getCoordinates(double[] parameters) {
1347     return new double[0];
1348   }
1349 
1350   @Override
1351   public void setCoordinates(double[] parameters) {
1352     // Do nothing.
1353   }
1354 
1355   @Override
1356   public int getNumberOfVariables() {
1357     return 0;
1358   }
1359 
1360   @Override
1361   public double[] getScaling() {
1362     return null;
1363   }
1364 
1365   @Override
1366   public void setScaling(double[] scaling) {
1367   }
1368 
1369   @Override
1370   public double getTotalEnergy() {
1371     return 0;
1372   }
1373 
1374   /// ///// Getters and setters ////////
1375 
1376   public BennettAcceptanceRatio getBAR() {
1377     return new BennettAcceptanceRatio(lamValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperatures);
1378   }
1379 
1380   @Override
1381   public MultistateBennettAcceptanceRatio copyEstimator() {
1382     return new MultistateBennettAcceptanceRatio(lamValues, eAll, temperatures, tolerance, seedType);
1383   }
1384 
1385   @Override
1386   public double[] getFreeEnergyDifferences() {
1387     return mbarFEDifferenceEstimates;
1388   }
1389 
1390   public double[] getMBARFreeEnergies() {
1391     return mbarFEEstimates;
1392   }
1393 
1394   public double[][] getReducedPotentials() {
1395     return reducedPotentials;
1396   }
1397 
1398   public int[] getSnaps() {
1399     return nSamples;
1400   }
1401 
1402   @Override
1403   public double[] getFEDifferenceUncertainties() {
1404     return mbarUncertainties;
1405   }
1406 
1407   public double[] getObservationEnsembleAverages() {
1408     return mbarObservableEnsembleAverages;
1409   }
1410 
1411   public double[] getObservationEnsembleUncertainties() {
1412     return mbarObservableEnsembleAverageUncertainties;
1413   }
1414 
1415   public double[][] getUncertaintyMatrix() {
1416     return uncertaintyMatrix;
1417   }
1418 
1419   @Override
1420   public double getTotalFreeEnergyDifference() {
1421     return totalMBAREstimate;
1422   }
1423 
1424   @Override
1425   public double getTotalFEDifferenceUncertainty() {
1426     return totalMBARUncertainty;
1427   }
1428 
1429   @Override
1430   public int getNumberOfBins() {
1431     return nFreeEnergyDiffs;
1432   }
1433 
1434   @Override
1435   public double[] getEnthalpyDifferences() {
1436     return mbarEnthalpy;
1437   }
1438 
1439   /**
1440    * {@inheritDoc}
1441    */
1442   @Override
1443   public double getTotalEnthalpyDifference() {
1444     return getTotalEnthalpyDifference(mbarEnthalpy);
1445   }
1446 
1447   public double[] getBinEntropies() {
1448     return mbarEntropy;
1449   }
1450 
1451   public static void writeFile(double[][] energies, File file, double temperature) {
1452     try (FileWriter fw = new FileWriter(file);
1453          BufferedWriter bw = new BufferedWriter(fw)) {
1454       // Write the number of snapshots and the temperature on the first line
1455       bw.write(energies[0].length + " " + temperature);
1456       bw.newLine();
1457 
1458       // Write the energies
1459       StringBuilder sb = new StringBuilder();
1460       for (int i = 0; i < energies[0].length; i++) {
1461         sb.append("     ").append(i).append(" "); // Write the index of the snapshot
1462         for (int j = 0; j < energies.length; j++) {
1463           sb.append("    ").append(energies[j][i]).append(" ");
1464         }
1465         sb.append("\n");
1466         bw.write(sb.toString());
1467         sb = new StringBuilder(); // Very important
1468       }
1469     } catch (IOException e) {
1470       e.printStackTrace();
1471     }
1472   }
1473 
1474   /**
1475    * Test all MBAR methods individually with a simple Harmonic Oscillator test case with an
1476    * excess of samples. "PASS" indicates that the test passed, while "FAIL" followed by the
1477    * method name indicates that the test failed.
1478    * <p>
1479    * Last updated - 06/11/2024
1480    *
1481    * @return array of test results
1482    */
1483   public static String[] testMBARMethods() {
1484     // Set up highly converged test case
1485     double[] O_k = {1, 2, 3, 4};
1486     double[] K_k = {.5, 1.0, 1.5, 2};
1487     int[] N_k = {100000, 100000, 100000, 100000};
1488     double beta = 1.0;
1489     HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);
1490     String setting = "u_kln";
1491     Object[] sampleResult = testCase.sample(N_k, setting, (long) 0);
1492     double[][][] u_kln = (double[][][]) sampleResult[1];
1493     double[] temps = {1 / Constants.R};
1494     MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0E-7, MultistateBennettAcceptanceRatio.SeedType.ZEROS);
1495     MultistateBennettAcceptanceRatio mbarHigherTol = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0, MultistateBennettAcceptanceRatio.SeedType.ZEROS);
1496     String[] results = new String[7];
1497     // Get required information for all methods
1498     double[][] reducedPotentials = mbar.getReducedPotentials();
1499     double[] freeEnergyEstimates = mbar.getMBARFreeEnergies();
1500     double[] highTolFEEstimates = mbarHigherTol.getMBARFreeEnergies();
1501     double[] zeros = new double[freeEnergyEstimates.length];
1502     int[] snapsPerLambda = mbar.getSnaps();
1503 
1504     // getMBARFreeEnergies()
1505     double[] expectedFEEstimates = new double[]{0.0, 0.3474485596619945, 0.5460865684340613, 0.6866650788765148};
1506     boolean pass = normDiff(freeEnergyEstimates, expectedFEEstimates) < 1e-5;
1507     expectedFEEstimates = new double[]{0.0, 0.35798124225733474, 0.44721370511807645, 0.477203739646745};
1508     pass = normDiff(highTolFEEstimates, expectedFEEstimates) < 1e-5 && pass;
1509     results[0] = pass ? "PASS" : "FAIL getMBARFreeEnergies()";
1510 
1511     // mbarObjectiveFunction()
1512     double objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1513     pass = !(abs(objectiveFunction - 4786294.2692739945) > 1e-5);
1514     objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1515     pass = !(abs(objectiveFunction - 4787001.700838844) > 1e-5) && pass;
1516     objectiveFunction = mbarObjectiveFunction(reducedPotentials, snapsPerLambda, zeros);
1517     pass = !(abs(objectiveFunction - 4792767.352152844) > 1e-5) && pass;
1518     results[1] = pass ? "PASS" : "FAIL mbarObjectiveFunction()";
1519 
1520     // mbarGradient()
1521     double[] gradient = mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1522     double[] expected = new double[]{6.067113034191607E-4, -8.777718552011038E-4, 8.210768953631487E-4, -5.500246369471995E-4};
1523     pass = !(normDiff(gradient, expected) > 4e-5);
1524     gradient = mbarGradient(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1525     expected = new double[]{1969.705314577408, 5108.841258429764, -1072.9526887468976, -6005.593884267446};
1526     pass = !(normDiff(gradient, expected) > 4e-5) && pass;
1527     gradient = mbarGradient(reducedPotentials, snapsPerLambda, zeros);
1528     expected = new double[]{22797.82037585665, -3273.72282675803, -8859.999065013779, -10664.098484078011};
1529     pass = !(normDiff(gradient, expected) > 4e-5) && pass;
1530     results[2] = pass ? "PASS" : "FAIL mbarGradient()";
1531 
1532     pass = true;
1533     // mbarHessian()
1534     double[][] hessian = mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1535     double[][] expected2d = new double[][]{{47600.586808418964, -29977.008359691405, -12870.425573135915, -4753.1528755909385},
1536         {-29977.008359691405, 63767.745823769576, -24597.198354108747, -9193.539109971487},
1537         {-12870.425573135915, -24597.198354108747, 64584.87112481013, -27117.247197561417},
1538         {-4753.1528755909385, -9193.539109971487, -27117.247197561417, 41063.93918312612}};
1539     pass = !(normDiff(hessian, expected2d) > 16e-5);
1540     hessian = mbarHessian(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1541     expected2d = new double[][]{{49168.30161780381, -31256.519016487477, -12983.708230229113, -4928.074371082683},
1542         {-31256.519016487477, 66075.94621325849, -25339.462656640117, -9479.964540130917},
1543         {-12983.708230229113, -25339.462656640117, 64308.30940252403, -25985.13851565483},
1544         {-4928.074371082683, -9479.964540130917, -25985.13851565483, 40393.1774268678}};
1545     pass = !(normDiff(hessian, expected2d) > 16e-5) && pass;
1546     hessian = mbarHessian(reducedPotentials, snapsPerLambda, zeros);
1547     expected2d = new double[][]{{56125.271437145464, -33495.87894376072, -15738.011263498352, -6891.381229885624},
1548         {-33495.87894376072, 64613.515110188295, -21970.091845920833, -9147.544320511564},
1549         {-15738.011263498352, -21970.091845920833, 61407.66256511316, -23699.55945569241},
1550         {-6891.381229885624, -9147.544320511564, -23699.55945569241, 39738.48500608951}};
1551     pass = !(normDiff(hessian, expected2d) > 16e-5) && pass;
1552     results[3] = pass ? "PASS" : "FAIL mbarHessian()";
1553 
1554     pass = true;
1555     // mbarTheta() --> Checked by diffMatrix
1556     double[][] theta = mbarTheta(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1557     double[][] diff = diffMatrixCalculation(theta);
1558     expected2d = new double[][]{{0.0, 0.001953125, 0.003400485419234404, 0.004858337095247168},
1559         {0.0020716018980074633, 0.0, 0.002042627017905458, 0.004055968683065466},
1560         {0.003435363105339426, 0.002042627017905458, 0.0, 0.002560568476977909},
1561         {0.0048828125, 0.004055968683065466, 0.0025135815773894045, 0.0}};
1562     pass = !(normDiff(diff, expected2d) > 16e-5);
1563     results[4] = pass ? "PASS" : "FAIL mbarTheta() or diffMatrixCalculation()";
1564 
1565     pass = true;
1566     // selfConsistentUpdate()
1567     double[] updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
1568     expected = new double[]{0.0, 0.3474485745068261, 0.5460865662904055, 0.6866650904438742};
1569     pass = !(normDiff(updatedF_k, expected) > 1e-5);
1570     updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, highTolFEEstimates);
1571     expected = new double[]{0.0, 0.327660608017009, 0.4775067849198251, 0.5586442310038073};
1572     pass = !(normDiff(updatedF_k, expected) > 1e-5) && pass;
1573     updatedF_k = mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, zeros);
1574     expected = new double[]{0.0, 0.23865416150488983, 0.29814247007871764, 0.31813582643116334};
1575     pass = !(normDiff(updatedF_k, expected) > 1e-5) && pass;
1576     results[5] = pass ? "PASS" : "FAIL mbarSelfConsistentUpdate()";
1577 
1578     pass = true;
1579     // newton()
1580     updatedF_k = newton(highTolFEEstimates, reducedPotentials, snapsPerLambda, 1e-7);
1581     pass = !(normDiff(updatedF_k, freeEnergyEstimates) > 1e-5);
1582     updatedF_k = newton(zeros, reducedPotentials, snapsPerLambda, 1e-7);
1583     pass = !(normDiff(updatedF_k, freeEnergyEstimates) > 1e-5) && pass;
1584     results[6] = pass ? "PASS" : "FAIL newton()";
1585 
1586     return results;
1587   }
1588 
1589   private static double normDiff(double[] a, double[] b) {
1590     double sum = 0.0;
1591     for (int i = 0; i < a.length; i++) {
1592       sum += abs(a[i] - b[i]);
1593     }
1594     return sum;
1595   }
1596 
1597   private static double normDiff(double[][] a, double[][] b) {
1598     double sum = 0.0;
1599     for (int i = 0; i < a.length; i++) {
1600       for (int j = 0; j < a[i].length; j++) {
1601         sum += abs(a[i][j] - b[i][j]);
1602       }
1603     }
1604     return sum;
1605   }
1606 
1607   /**
1608    * Example MBAR code usage and comparison with analytic answers for Harmonic Oscillators.
1609    *
1610    * @param args
1611    */
1612   public static void main(String[] args) {
1613     // Generate sample data
1614     double[] equilPositions = {1, 2, 3, 4}; // Equilibrium positions
1615     double[] springConstants = {.5, 1.0, 1.5, 2}; // Spring constants
1616     int[] samples = {100000, 100000, 100000, 100000}; // Samples per state
1617     double beta = 1.0; // 1 / (kB * T) equivalent
1618     HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(equilPositions, springConstants, beta);
1619     String setting = "u_kln";
1620     System.out.print("Generating sample data... ");
1621     Object[] sampleResult = testCase.sample(samples, setting, (long) 0); // Set seed to fixed value for reproducibility
1622     System.out.println("done. \n");
1623     double[] x_n = (double[]) sampleResult[0];
1624     double[][][] u_kln = (double[][][]) sampleResult[1];
1625     double[] temps = {1 / Constants.R}; // To be passed into MBAR to cancel out beta within calculation
1626 
1627     // Write file for comparison with pymbar
1628     // Output to forcefieldx/testing/mbar/data/harmonic_oscillators/mbarFiles/energies_{i}.mbar
1629     // Get absolute path to root of project
1630     String rootPath = new File("").getAbsolutePath();
1631     File outputPath = new File(rootPath + "/testing/mbar/data/harmonic_oscillators/mbarFiles");
1632     if (!outputPath.exists() && !outputPath.mkdirs()) {
1633       throw new RuntimeException("Failed to create directory: " + outputPath);
1634     }
1635 
1636     double[] temperatures = new double[equilPositions.length];
1637     Arrays.fill(temperatures, temps[0]);
1638     for (int i = 0; i < u_kln.length; i++) {
1639       File file = new File(outputPath, "energies_" + i + ".mbar");
1640       writeFile(u_kln[i], file, temperatures[i]);
1641     }
1642 
1643     // Create an instance of MultistateBennettAcceptanceRatio
1644     System.out.print("Creating MBAR instance and .estimateDG(false) with standard tolerance & zeros seeding...");
1645     //MultistateBennettAcceptanceRatio.VERBOSE = true; // Log Newton/SCI iters and other relevant information
1646     MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(equilPositions, u_kln, temps, 1e-7, SeedType.ZEROS);
1647     System.out.println("done! \n\n");
1648     double[] mbarFEEstimates = Arrays.copyOf(mbar.mbarFEEstimates, mbar.mbarFEEstimates.length);
1649     double[] mbarEnthalpyDiff = Arrays.copyOf(mbar.mbarEnthalpy, mbar.mbarEnthalpy.length);
1650     double[] mbarEntropyDiff = Arrays.copyOf(mbar.mbarEntropy, mbar.mbarEntropy.length);
1651     double[] mbarUncertainties = Arrays.copyOf(mbar.mbarUncertainties, mbar.mbarUncertainties.length);
1652     double[][] mbarDiffMatrix = Arrays.copyOf(mbar.uncertaintyMatrix, mbar.uncertaintyMatrix.length);
1653 
1654     // Analytical free energies and entropies
1655     double[] analyticalFreeEnergies = testCase.analyticalFreeEnergies();
1656     double[] error = new double[analyticalFreeEnergies.length];
1657     for (int i = 0; i < error.length; i++) {
1658       error[i] = analyticalFreeEnergies[i] - mbarFEEstimates[i];
1659     }
1660     double[] temp = testCase.analyticalEntropies(0);
1661     double[] analyticEntropyDiff = new double[temp.length - 1];
1662     double[] errorEntropy = new double[temp.length - 1];
1663     for (int i = 0; i < analyticEntropyDiff.length; i++) {
1664       analyticEntropyDiff[i] = temp[i + 1] - temp[i];
1665       errorEntropy[i] = analyticEntropyDiff[i] - mbarEntropyDiff[i];
1666     }
1667 
1668     // Compare the calculated free energy differences with the analytical ones
1669     System.out.println("STANDARD THERMODYNAMIC CALCULATIONS: \n");
1670     System.out.println("Analytical Free Energies: " + Arrays.toString(analyticalFreeEnergies));
1671     System.out.println("MBAR Free Energies:       " + Arrays.toString(mbarFEEstimates));
1672     System.out.println("Free Energy Error:        " + Arrays.toString(error));
1673     System.out.println();
1674     System.out.println("MBAR dG:                  " + Arrays.toString(mbar.mbarFEDifferenceEstimates));
1675     System.out.println("MBAR Uncertainties:       " + Arrays.toString(mbarUncertainties));
1676     System.out.println("MBAR Enthalpy Changes:    " + Arrays.toString(mbarEnthalpyDiff));
1677     System.out.println();
1678     System.out.println("MBAR Entropy Changes:     " + Arrays.toString(mbarEntropyDiff));
1679     System.out.println("Analytic Entropy Changes: " + Arrays.toString(analyticEntropyDiff));
1680     System.out.println("Entropy Error:            " + Arrays.toString(errorEntropy));
1681     System.out.println();
1682     System.out.println("Uncertainty Diff Matrix: ");
1683     for (double[] matrix : mbarDiffMatrix) {
1684       System.out.println(Arrays.toString(matrix));
1685     }
1686     System.out.println("\n\n");
1687 
1688     // Observables
1689     System.out.println("MBAR DERIVED OBSERVABLES: \n");
1690     mbar.setObservableData(u_kln, true, true);
1691     double[] mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages,
1692         mbar.mbarObservableEnsembleAverages.length);
1693     double[] mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties,
1694         mbar.mbarObservableEnsembleAverageUncertainties.length);
1695     System.out.println("Multi-Data Observable Example u_kln:");
1696     System.out.println("MBAR Observable Ensemble Averages (Potential):              " + Arrays.toString(mbarObservableEnsembleAverages));
1697     System.out.println("Analytical Observable Ensemble Averages (Potential):        " + Arrays.toString(testCase.analyticalObservable("potential energy")));
1698     System.out.println("MBAR Observable Ensemble Average Uncertainties (Potential): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
1699     System.out.println();
1700 
1701     // Reads data from xAll[0]
1702     double[][][] xAll = new double[equilPositions.length][equilPositions.length][x_n.length];
1703     for (int i = 0; i < xAll[0].length; i++) {
1704       for (int j = 0; j < xAll[0][0].length; j++) {
1705         // Copy data multiple times into same window
1706         xAll[0][i][j] = x_n[j];
1707       }
1708     }
1709     mbar.setObservableData(xAll, false, true);
1710     mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages,
1711         mbar.mbarObservableEnsembleAverages.length);
1712     mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties,
1713         mbar.mbarObservableEnsembleAverageUncertainties.length);
1714     System.out.println("Single-Data Observable Example x_n:");
1715     System.out.println("MBAR Observable Ensemble Averages (Position):              " + Arrays.toString(mbarObservableEnsembleAverages));
1716     System.out.println("Analytical Observable Ensemble Averages (Position):        " + Arrays.toString(testCase.analyticalMeans()));
1717     System.out.println("MBAR Observable Ensemble Average Uncertainties (Position): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
1718     System.out.println();
1719   }
1720 
1721   /**
1722    * Harmonic oscillators test case generates data for testing the MBAR implementation
1723    */
1724   public static class HarmonicOscillatorsTestCase {
1725     private final double beta;
1726     private final double[] equilPositions;
1727     private final int n_states;
1728     private final double[] springConstants;
1729 
1730     public HarmonicOscillatorsTestCase(double[] O_k, double[] K_k, double beta) {
1731       this.beta = beta;
1732       this.equilPositions = O_k;
1733       this.n_states = O_k.length;
1734       this.springConstants = K_k;
1735 
1736       if (this.springConstants.length != this.n_states) {
1737         throw new IllegalArgumentException("Lengths of K_k and O_k should be equal");
1738       }
1739     }
1740 
1741     public double[] analyticalMeans() {
1742       return equilPositions;
1743     }
1744 
1745     public double[] analyticalStandardDeviations() {
1746       double[] deviations = new double[n_states];
1747       for (int i = 0; i < n_states; i++) {
1748         deviations[i] = Math.sqrt(1.0 / (beta * springConstants[i]));
1749       }
1750       return deviations;
1751     }
1752 
1753     public double[] analyticalObservable(String observable) {
1754       double[] result = new double[n_states];
1755 
1756       switch (observable) {
1757         case "position" -> {
1758           return analyticalMeans();
1759         }
1760         case "potential energy" -> {
1761           for (int i = 0; i < n_states; i++) {
1762             result[i] = 0.5 / beta;
1763           }
1764         }
1765         case "position^2" -> {
1766           for (int i = 0; i < n_states; i++) {
1767             result[i] = 1.0 / (beta * springConstants[i]) + Math.pow(equilPositions[i], 2);
1768           }
1769         }
1770         case "RMS displacement" -> {
1771           return analyticalStandardDeviations();
1772         }
1773       }
1774 
1775       return result;
1776     }
1777 
1778     public double[] analyticalFreeEnergies() {
1779       int subtractComponentIndex = 0;
1780       double[] fe = new double[n_states];
1781       double subtract = 0.0;
1782       for (int i = 0; i < n_states; i++) {
1783         fe[i] = -0.5 * Math.log(2 * Math.PI / (beta * springConstants[i]));
1784         if (i == 0) {
1785           subtract = fe[subtractComponentIndex];
1786         }
1787         fe[i] -= subtract;
1788       }
1789       return fe;
1790     }
1791 
1792     public double[] analyticalEntropies(int subtractComponent) {
1793       double[] entropies = new double[n_states];
1794       double[] potentialEnergy = analyticalObservable("analytical entropy");
1795       double[] freeEnergies = analyticalFreeEnergies();
1796 
1797       for (int i = 0; i < n_states; i++) {
1798         entropies[i] = potentialEnergy[i] - freeEnergies[i];
1799       }
1800 
1801       return entropies;
1802     }
1803 
1804     /**
1805      * Sample from harmonic oscillator with gaussian and standard deviation.
1806      *
1807      * @param N_k  number of snaps per state
1808      * @param mode only u_kn -> return K x N_tot matrix where u_kn[k,n] is reduced potential of sample n evaluated at state k
1809      * @return u_kn[k, n] is reduced potential of sample n evaluated at state k
1810      */
1811     public Object[] sample(int[] N_k, String mode, Long seed) {
1812       Random random = new Random(seed);
1813 
1814       int N_max = 0;
1815       for (int N : N_k) {
1816         if (N > N_max) {
1817           N_max = N;
1818         }
1819       }
1820 
1821       int N_tot = 0;
1822       for (int N : N_k) {
1823         N_tot += N;
1824       }
1825 
1826       double[][] x_kn = new double[n_states][N_max];
1827       double[][] u_kn = new double[n_states][N_tot];
1828       double[][][] u_kln = new double[n_states][n_states][N_max];
1829       double[] x_n = new double[N_tot];
1830       int[] s_n = new int[N_tot];
1831 
1832       // Sample harmonic oscillators
1833       int index = 0;
1834       for (int k = 0; k < n_states; k++) {
1835         double x0 = equilPositions[k];
1836         double sigma = Math.sqrt(1.0 / (beta * springConstants[k]));
1837 
1838         // Number of snaps
1839         for (int n = 0; n < N_k[k]; n++) {
1840           double x = x0 + random.nextGaussian() * sigma;
1841           x_kn[k][n] = x;
1842           x_n[index] = x;
1843           s_n[index] = k;
1844           // Potential energy evaluations
1845           for (int l = 0; l < n_states; l++) {
1846             double u = beta * 0.5 * springConstants[l] * Math.pow(x - equilPositions[l], 2.0);
1847             u_kln[k][l][n] = u;
1848             u_kn[l][index] = u;
1849           }
1850           index++;
1851         }
1852         // Set the rest of the array to NaN
1853         for (int n = N_k[k]; n < N_max; n++) {
1854           for (int l = 0; l < n_states; l++) {
1855             u_kln[k][l][n] = Double.NaN;
1856           }
1857         }
1858       }
1859 
1860       // Setting corrections
1861       if ("u_kn".equals(mode)) {
1862         return new Object[]{x_n, u_kn, N_k, s_n};
1863       } else if ("u_kln".equals(mode)) {
1864         return new Object[]{x_n, u_kln, N_k, s_n, u_kn};
1865       } else {
1866         throw new IllegalArgumentException("Unknown mode: " + mode);
1867       }
1868     }
1869   }
1870 }