/*
 * Decompiled with CFR 0.152.
 */
package ffx.numerics.estimator;

import ffx.numerics.OptimizationInterface;
import ffx.numerics.estimator.BennettAcceptanceRatio;
import ffx.numerics.estimator.BootstrappableEstimator;
import ffx.numerics.estimator.EstimateBootstrapper;
import ffx.numerics.estimator.SequentialEstimator;
import ffx.numerics.estimator.Zwanzig;
import ffx.numerics.integrate.DoublesDataSet;
import ffx.numerics.integrate.Integrate1DNumeric;
import ffx.numerics.optimization.LBFGS;
import ffx.numerics.optimization.LineSearch;
import ffx.numerics.optimization.OptimizationListener;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Objects;
import java.util.Random;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.lang3.ArrayFill;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.util.FastMath;

public class MultistateBennettAcceptanceRatio
extends SequentialEstimator
implements BootstrappableEstimator,
OptimizationInterface {
    /** Class logger. */
    private static final Logger logger = Logger.getLogger(MultistateBennettAcceptanceRatio.class.getName());
    /** Default convergence tolerance (1.0E-7), the value the three-argument constructor passes. */
    private static final double DEFAULT_TOLERANCE = 1.0E-7;
    /** Number of adjacent-state free energy differences: {@code lambdaValues.length - 1}. */
    private final int nFreeEnergyDiffs;
    /** {@code mbarFEEstimates[i + 1] - mbarFEEstimates[i]}, computed after the estimates are multiplied by RT. */
    private final double[] mbarFEDifferenceEstimates;
    /** Number of lambda states: {@code lambdaValues.length}. */
    private final int nLambdaStates;
    /**
     * Free energy estimate of every lambda state. Dimensionless (reduced) while estimateDG iterates;
     * multiplied by {@link #rtValues} at the end of estimateDG.
     */
    private double[] mbarFEEstimates;
    /** Per-state MBAR expectation of the observable data, filled by fillObservationExpectations. */
    private double[] mbarObservableEnsembleAverages;
    /** Per-state uncertainty of the observable expectations (only filled when requested). */
    private double[] mbarObservableEnsembleAverageUncertainties;
    /** Result of mbarUncertaintyCalc(theta) from the last estimateDG call. */
    private double[] mbarUncertainties;
    /** Result of diffMatrixCalculation(theta) from the last estimateDG call. */
    private double[][] uncertaintyMatrix;
    /** Convergence threshold on the absolute change of every estimate between iterations. */
    private final double tolerance;
    /** Source of randomness for bootstrap resampling in estimateDG(true). */
    private final Random random;
    /** Sum of {@link #mbarFEDifferenceEstimates}. */
    private double totalMBAREstimate;
    /** Result of mbarTotalUncertaintyCalc(theta) from the last estimateDG call. */
    private double totalMBARUncertainty;
    /** Differences of the MBAR average potential energy between adjacent states. */
    private double[] mbarEnthalpy;
    /** {@code mbarEnthalpy[i] - mbarFEDifferenceEstimates[i]} for each adjacent pair of states. */
    private double[] mbarEntropy;
    /** RT of every state: 0.0019872042586408316 times the state's temperature (set in estimateDG). */
    public double[] rtValues;
    /** Energies multiplied by 1/RT (optionally bootstrap-resampled) and shifted by their global minimum. */
    private double[][] reducedPotentials;
    /** Flattened observable samples; multiplied by exp(bias / RT) when bias data has been set. */
    private double[][] oAllFlat;
    /** Flattened bias samples, shifted by their maximum (see the setBiasData overloads). */
    private double[][] biasFlat;
    /** Requested seed strategy; seedEnergies may overwrite it when it falls back to another strategy. */
    private SeedType seedType;
    /** Global switch: when true, estimateDG always seeds with zeros. */
    public static boolean FORCE_ZEROS_SEED = false;
    /** Global switch: when true, estimateDG logs extra diagnostics and sets the logger level to FINE. */
    public static boolean VERBOSE = false;

    /**
     * Builds the estimator with the default tolerance and Zwanzig seeding, then runs the MBAR
     * estimate immediately.
     *
     * @param lambdaValues lambda value of each state (its length is the number of states).
     * @param energiesAll  energy data handed to the {@link SequentialEstimator} constructor.
     * @param temperature  temperature data handed to the {@link SequentialEstimator} constructor.
     */
    public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature) {
        this(lambdaValues, energiesAll, temperature, DEFAULT_TOLERANCE, SeedType.ZWANZIG);
    }

    /**
     * Builds the estimator from per-state energy data and runs the MBAR estimate immediately.
     *
     * @param lambdaValues lambda value of each state (its length is the number of states).
     * @param energiesAll  energy data handed to the {@link SequentialEstimator} constructor.
     * @param temperature  temperature data handed to the {@link SequentialEstimator} constructor.
     * @param tolerance    convergence threshold on the change of each estimate between iterations.
     * @param seedType     strategy used to seed the free energy estimates.
     */
    public MultistateBennettAcceptanceRatio(double[] lambdaValues, double[][][] energiesAll, double[] temperature, double tolerance, SeedType seedType) {
        super(lambdaValues, energiesAll, temperature);
        this.seedType = seedType;
        this.tolerance = tolerance;
        this.nLambdaStates = lambdaValues.length;
        this.nFreeEnergyDiffs = this.nLambdaStates - 1;
        this.mbarFEEstimates = new double[this.nLambdaStates];
        this.mbarFEDifferenceEstimates = new double[this.nFreeEnergyDiffs];
        this.mbarUncertainties = new double[this.nFreeEnergyDiffs];
        this.mbarEnthalpy = new double[this.nFreeEnergyDiffs];
        this.mbarEntropy = new double[this.nFreeEnergyDiffs];
        this.random = new Random();
        // Note: the estimate is computed eagerly while the object is still being constructed.
        estimateDG();
    }

    /**
     * Builds the estimator from flattened energy data and runs the MBAR estimate immediately.
     *
     * @param lambdaValues lambda value of each state (its length is the number of states).
     * @param snaps        per-state sample counts handed to the {@link SequentialEstimator} constructor.
     * @param eAllFlat     flattened energies handed to the {@link SequentialEstimator} constructor.
     * @param temperature  temperature data handed to the {@link SequentialEstimator} constructor.
     * @param tolerance    convergence threshold on the change of each estimate between iterations.
     * @param seedType     strategy used to seed the free energy estimates.
     */
    public MultistateBennettAcceptanceRatio(double[] lambdaValues, int[] snaps, double[][] eAllFlat, double[] temperature, double tolerance, SeedType seedType) {
        super(lambdaValues, snaps, eAllFlat, temperature);
        this.seedType = seedType;
        this.tolerance = tolerance;
        this.nLambdaStates = lambdaValues.length;
        this.nFreeEnergyDiffs = this.nLambdaStates - 1;
        this.mbarFEEstimates = new double[this.nLambdaStates];
        this.mbarFEDifferenceEstimates = new double[this.nFreeEnergyDiffs];
        this.mbarUncertainties = new double[this.nFreeEnergyDiffs];
        this.mbarEnthalpy = new double[this.nFreeEnergyDiffs];
        this.mbarEntropy = new double[this.nFreeEnergyDiffs];
        this.random = new Random();
        // Note: the estimate is computed eagerly while the object is still being constructed.
        estimateDG();
    }

    /**
     * Fills {@link #mbarFEEstimates} with starting values for the MBAR iteration, as selected by
     * {@link #seedType}.
     *
     * <p>Fallbacks (each one overwrites {@link #seedType}):
     * <ul>
     *   <li>BAR or Zwanzig with any of {@code eLambdaMinusdL}, {@code eLambda}, {@code eLambdaPlusdL}
     *       missing: use zeros.</li>
     *   <li>BAR throwing {@link IllegalArgumentException}: use Zwanzig.</li>
     *   <li>Zwanzig throwing {@link IllegalArgumentException} or yielding a NaN/infinite estimate:
     *       use zeros.</li>
     * </ul>
     *
     * @throws IllegalArgumentException if the seed type is not handled.
     */
    private void seedEnergies() {
        // Dispatch on the enum constant itself rather than on ordinal(), which silently breaks if
        // the enum's declaration order ever changes.
        switch (this.seedType) {
            case BAR:
                this.seedFromBar();
                break;
            case ZWANZIG:
                this.seedFromZwanzig();
                break;
            case ZEROS:
                ArrayFill.fill(this.mbarFEEstimates, 0.0);
                break;
            default:
                throw new IllegalArgumentException("Seed type not supported");
        }
    }

    /** Seeds with cumulative BAR free energy differences, anchored at 0 for the first state. */
    private void seedFromBar() {
        if (this.eLambdaMinusdL == null || this.eLambda == null || this.eLambdaPlusdL == null) {
            this.seedType = SeedType.ZEROS;
            this.seedEnergies();
            return;
        }
        try {
            BennettAcceptanceRatio barEstimator = new BennettAcceptanceRatio(this.lamValues, this.eLambdaMinusdL, this.eLambda, this.eLambdaPlusdL, this.temperatures);
            this.mbarFEEstimates[0] = 0.0;
            double[] barEstimates = barEstimator.getFreeEnergyDifferences();
            for (int i = 0; i < this.nFreeEnergyDiffs; ++i) {
                this.mbarFEEstimates[i + 1] = this.mbarFEEstimates[i] + barEstimates[i];
            }
        } catch (IllegalArgumentException e) {
            logger.warning(" BAR failed to converge. Zwanzig will be used for seed energies.");
            this.seedType = SeedType.ZWANZIG;
            this.seedEnergies();
        }
    }

    /**
     * Seeds with cumulative averages of forward and backward Zwanzig differences, anchored at 0 for
     * the first state.
     */
    private void seedFromZwanzig() {
        if (this.eLambdaMinusdL == null || this.eLambda == null || this.eLambdaPlusdL == null) {
            this.seedType = SeedType.ZEROS;
            this.seedEnergies();
            return;
        }
        boolean failed;
        try {
            Zwanzig forwardsFEP = new Zwanzig(this.lamValues, this.eLambdaMinusdL, this.eLambda, this.eLambdaPlusdL, this.temperatures, Zwanzig.Directionality.FORWARDS);
            Zwanzig backwardsFEP = new Zwanzig(this.lamValues, this.eLambdaMinusdL, this.eLambda, this.eLambdaPlusdL, this.temperatures, Zwanzig.Directionality.BACKWARDS);
            double[] forwardZwanzig = forwardsFEP.getFreeEnergyDifferences();
            double[] backwardZwanzig = backwardsFEP.getFreeEnergyDifferences();
            this.mbarFEEstimates[0] = 0.0;
            for (int i = 0; i < this.nFreeEnergyDiffs; ++i) {
                this.mbarFEEstimates[i + 1] = this.mbarFEEstimates[i] + 0.5 * (forwardZwanzig[i] + backwardZwanzig[i]);
            }
            // A NaN/Inf seed is treated the same as a Zwanzig failure.
            failed = !Arrays.stream(this.mbarFEEstimates).allMatch(Double::isFinite);
        } catch (IllegalArgumentException e) {
            failed = true;
        }
        if (failed) {
            logger.warning(" Zwanzig failed to converge. Zeros will be used for seed energies.");
            this.seedType = SeedType.ZEROS;
            this.seedEnergies();
        }
    }

    /**
     * Runs the MBAR estimate on the samples in their stored order, without bootstrap resampling.
     */
    @Override
    public void estimateDG() {
        final boolean bootstrapResample = false;
        estimateDG(bootstrapResample);
    }

    /**
     * Solves the MBAR equations for the free energy of every lambda state. From the solution it
     * derives the adjacent-state differences, their uncertainties and the enthalpy/entropy split.
     *
     * <p>Stages:
     * <ol>
     *   <li>Seed the estimates via {@link #seedEnergies()}; non-finite seeds are replaced by zeros.</li>
     *   <li>Multiply energies by 1/RT and shift them by their global minimum.</li>
     *   <li>Run up to 10 over-relaxed self-consistent iterations (SCI) on the states that have samples.</li>
     *   <li>If those have not converged, refine with L-BFGS (more than 100 sampled states) or Newton's
     *       method. A failure here is logged and tolerated.</li>
     *   <li>Finish with over-relaxed SCI on all states (at most 1000 iterations).</li>
     *   <li>Compute uncertainties from the reduced estimates, then multiply the estimates by
     *       {@link #rtValues}.</li>
     * </ol>
     *
     * @param randomSamples if true, each state's block of samples is resampled with replacement
     *                      (bootstrap) instead of being used in order.
     * @throws IllegalArgumentException if an SCI step produces a NaN or infinite estimate.
     */
    @Override
    public void estimateDG(boolean randomSamples) {
        if (VERBOSE) {
            logger.setLevel(Level.FINE);
        }

        // Seed the estimates; a non-finite seed is replaced by zeros.
        ArrayFill.fill(this.mbarFEEstimates, 0.0);
        if (FORCE_ZEROS_SEED) {
            this.seedType = SeedType.ZEROS;
        }
        this.seedEnergies();
        if (Arrays.stream(this.mbarFEEstimates).anyMatch(Double::isInfinite) || Arrays.stream(this.mbarFEEstimates).anyMatch(Double::isNaN)) {
            this.seedType = SeedType.ZEROS;
            this.seedEnergies();
        }
        if (VERBOSE) {
            logger.info(" Seed Type: " + String.valueOf((Object) this.seedType));
            logger.info(" MBAR FE Estimates after seeding: " + Arrays.toString(this.mbarFEEstimates));
        }

        // RT of every state; 0.0019872042586408316 matches the molar gas constant in kcal/(mol K).
        this.rtValues = new double[this.nLambdaStates];
        double[] invRTValues = new double[this.nLambdaStates];
        for (int i = 0; i < this.nLambdaStates; ++i) {
            this.rtValues[i] = 0.0019872042586408316 * this.temperatures[i];
            invRTValues[i] = 1.0 / this.rtValues[i];
        }

        // Index of the sample used for each (state, evaluation) pair.
        int numEvaluations = this.eAllFlat[0].length;
        int[][] indices = new int[this.nLambdaStates][numEvaluations];
        if (randomSamples) {
            int[] randomIndices = new int[numEvaluations];
            int offset = 0;
            for (int snap : this.nSamples) {
                int[] blockIndices = EstimateBootstrapper.getBootstrapIndices(snap, this.random);
                // getBootstrapIndices only receives the block size, so its indices are local to the
                // block (assumed to lie in [0, snap)). Shift them to the block's position in the
                // flattened arrays so each state's block is resampled from its own samples.
                for (int k = 0; k < snap; ++k) {
                    randomIndices[offset + k] = blockIndices[k] + offset;
                }
                offset += snap;
            }
            // All states share the same resampled sample set.
            for (int state = 0; state < this.nLambdaStates; ++state) {
                indices[state] = randomIndices;
            }
        } else {
            for (int state = 0; state < this.nLambdaStates; ++state) {
                for (int n = 0; n < numEvaluations; ++n) {
                    indices[state][n] = n;
                }
            }
        }

        // Reduced potentials u = E / RT, shifted so that the smallest one is zero.
        this.reducedPotentials = new double[this.nLambdaStates][numEvaluations];
        double minPotential = Double.POSITIVE_INFINITY;
        for (int state = 0; state < this.eAllFlat.length; ++state) {
            for (int n = 0; n < numEvaluations; ++n) {
                double reduced = this.eAllFlat[state][indices[state][n]] * invRTValues[state];
                this.reducedPotentials[state][n] = reduced;
                if (reduced < minPotential) {
                    minPotential = reduced;
                }
            }
        }
        for (int state = 0; state < this.nLambdaStates; ++state) {
            for (int n = 0; n < numEvaluations; ++n) {
                this.reducedPotentials[state][n] -= minPotential;
            }
        }

        // States without samples are excluded from the startup SCI and the gradient optimization.
        ArrayList<Integer> zeroSnapLambdas = new ArrayList<>();
        ArrayList<Integer> sampledLambdas = new ArrayList<>();
        for (int state = 0; state < this.nLambdaStates; ++state) {
            if (this.nSamples[state] == 0) {
                zeroSnapLambdas.add(state);
            } else {
                sampledLambdas.add(state);
            }
        }
        int nLambdaStatesTemp = sampledLambdas.size();
        double[][] reducedPotentialsTemp;
        double[] mbarFEEstimatesTemp;
        int[] snapsTemp;
        if (zeroSnapLambdas.isEmpty()) {
            reducedPotentialsTemp = this.reducedPotentials;
            mbarFEEstimatesTemp = this.mbarFEEstimates;
            snapsTemp = this.nSamples;
        } else {
            reducedPotentialsTemp = new double[nLambdaStatesTemp][];
            mbarFEEstimatesTemp = new double[nLambdaStatesTemp];
            snapsTemp = new int[nLambdaStatesTemp];
            int index = 0;
            for (int state : sampledLambdas) {
                reducedPotentialsTemp[index] = this.reducedPotentials[state];
                mbarFEEstimatesTemp[index] = this.mbarFEEstimates[state];
                snapsTemp[index] = this.nSamples[state];
                ++index;
            }
            logger.info(" Sampled Lambdas: " + String.valueOf(sampledLambdas));
            logger.info(" Zero Snap Lambdas: " + String.valueOf(zeroSnapLambdas));
        }

        // Startup SCI with over-relaxation (omega > 1) to approach the solution before optimizing.
        double omega = 1.5;
        boolean startupConverged = false;
        for (int iter = 0; iter < 10 && !startupConverged; ++iter) {
            double[] prevMBAR = Arrays.copyOf(mbarFEEstimatesTemp, nLambdaStatesTemp);
            mbarFEEstimatesTemp = mbarSelfConsistentUpdate(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
            for (int j = 0; j < nLambdaStatesTemp; ++j) {
                mbarFEEstimatesTemp[j] = omega * mbarFEEstimatesTemp[j] + (1.0 - omega) * prevMBAR[j];
            }
            if (Arrays.stream(mbarFEEstimatesTemp).anyMatch(Double::isInfinite) || Arrays.stream(mbarFEEstimatesTemp).anyMatch(Double::isNaN)) {
                throw new IllegalArgumentException("MBAR contains NaNs or Infs during startup SCI ");
            }
            // Compare the iterate with its predecessor. this.mbarFEEstimates is not updated by this
            // loop, and with zero-sample states it is indexed differently from the temp arrays.
            startupConverged = this.converged(prevMBAR, mbarFEEstimatesTemp);
        }
        if (VERBOSE) {
            logger.info(" Omega for SCI w/ relaxation: " + omega);
            logger.info(" MBAR FE Estimates after 10 SCI iterations: " + Arrays.toString(mbarFEEstimatesTemp));
        }

        // Gradient-based refinement; best effort, since the final SCI below still runs.
        try {
            if (!startupConverged) {
                if (nLambdaStatesTemp > 100) {
                    if (VERBOSE) {
                        logger.info(" L-BFGS optimization started.");
                    }
                    int mCorrections = 5;
                    double[] x = new double[nLambdaStatesTemp];
                    System.arraycopy(mbarFEEstimatesTemp, 0, x, 0, nLambdaStatesTemp);
                    double[] grad = mbarGradient(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp);
                    double eps = 1.0E-4;
                    OptimizationListener listener = this.getOptimizationListener();
                    LBFGS.minimize(nLambdaStatesTemp, mCorrections, x, mbarObjectiveFunction(reducedPotentialsTemp, snapsTemp, mbarFEEstimatesTemp), grad, eps, 1000, this, listener);
                    System.arraycopy(x, 0, mbarFEEstimatesTemp, 0, nLambdaStatesTemp);
                } else {
                    if (VERBOSE) {
                        logger.info(" Newton optimization started.");
                    }
                    mbarFEEstimatesTemp = newton(mbarFEEstimatesTemp, reducedPotentialsTemp, snapsTemp, this.tolerance);
                }
            }
        } catch (Exception e) {
            logger.warning(" L-BFGS/Newton failed to converge. Finishing w/ self-consistent iteration. Message: " + e.getMessage());
        }
        if (VERBOSE) {
            logger.info(" MBAR FE Estimates after gradient optimization: " + Arrays.toString(mbarFEEstimatesTemp));
        }

        // Scatter the sampled-state estimates back; NaN results keep the previous value.
        int count = 0;
        for (int state : sampledLambdas) {
            if (!Double.isNaN(mbarFEEstimatesTemp[count])) {
                this.mbarFEEstimates[state] = mbarFEEstimatesTemp[count];
            }
            ++count;
        }

        // Final over-relaxed SCI over all states. At least one step is taken so that convergence is
        // always judged against a genuine SCI update of the current full-length estimates.
        int sciIter = 0;
        boolean sciConverged = false;
        while (!sciConverged && sciIter < 1000) {
            double[] previous = Arrays.copyOf(this.mbarFEEstimates, this.nLambdaStates);
            this.mbarFEEstimates = mbarSelfConsistentUpdate(this.reducedPotentials, this.nSamples, this.mbarFEEstimates);
            for (int i = 0; i < this.nLambdaStates; ++i) {
                this.mbarFEEstimates[i] = omega * this.mbarFEEstimates[i] + (1.0 - omega) * previous[i];
            }
            if (Arrays.stream(this.mbarFEEstimates).anyMatch(Double::isInfinite) || Arrays.stream(this.mbarFEEstimates).anyMatch(Double::isNaN)) {
                throw new IllegalArgumentException("MBAR estimate contains NaNs or Infs after iteration " + sciIter);
            }
            sciConverged = this.converged(previous, this.mbarFEEstimates);
            ++sciIter;
        }
        if (VERBOSE) {
            logger.info(" SCI iterations (max 1000): " + sciIter);
        }

        // Uncertainties and weight diagnostics use the reduced (dimensionless) estimates.
        double[][] theta = mbarTheta(this.reducedPotentials, this.nSamples, this.mbarFEEstimates);
        this.mbarUncertainties = mbarUncertaintyCalc(theta);
        this.totalMBARUncertainty = mbarTotalUncertaintyCalc(theta);
        this.uncertaintyMatrix = diffMatrixCalculation(theta);
        if (!randomSamples && VERBOSE) {
            this.logWeights();
        }

        // Convert the reduced estimates by RT, then derive differences, enthalpy, entropy and total.
        for (int i = 0; i < this.nLambdaStates; ++i) {
            this.mbarFEEstimates[i] = this.mbarFEEstimates[i] * this.rtValues[i];
        }
        for (int i = 0; i < this.nFreeEnergyDiffs; ++i) {
            this.mbarFEDifferenceEstimates[i] = this.mbarFEEstimates[i + 1] - this.mbarFEEstimates[i];
        }
        this.mbarEnthalpy = this.mbarEnthalpyCalc(this.eAllFlat, this.mbarFEEstimates);
        this.mbarEntropy = this.mbarEntropyCalc(this.mbarEnthalpy, this.mbarFEEstimates);
        this.totalMBAREstimate = Arrays.stream(this.mbarFEDifferenceEstimates).sum();
    }

    /**
     * Returns true when every entry of {@code prevMBAR} is within {@link #tolerance} of the matching
     * entry of {@link #mbarFEEstimates}.
     *
     * @param prevMBAR estimates from the previous iteration.
     * @return whether the current estimates are converged relative to {@code prevMBAR}.
     */
    private boolean converged(double[] prevMBAR) {
        return this.converged(prevMBAR, this.mbarFEEstimates);
    }

    /**
     * Returns true when {@code |previous[i] - current[i]| < tolerance} for every index of
     * {@code previous}. A NaN difference counts as not converged.
     *
     * @param previous estimates before an update.
     * @param current  estimates after the update (at least as long as {@code previous}).
     * @return whether the update changed every estimate by less than {@link #tolerance}.
     */
    private boolean converged(double[] previous, double[] current) {
        for (int i = 0; i < previous.length; ++i) {
            if (!(FastMath.abs(previous[i] - current[i]) < this.tolerance)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Logs the MBAR weight matrix collapsed by sampled state. Entry [row][traj] is the total weight
     * that estimation state {@code row} assigns to the samples drawn from state {@code traj}. The
     * column totals are then passed through softMax and logged.
     * NOTE(review): softMax's return value is not used, so it is presumably applied in place — confirm.
     */
    private void logWeights() {
        logger.info(" MBAR Weight Matrix Information Collapsed:");
        double[][] weights = mbarW(this.reducedPotentials, this.nSamples, this.mbarFEEstimates);
        int nRows = weights.length;
        double[][] collapsed = new double[nRows][nRows];

        // Samples are stored contiguously per state, so walk the blocks with a running offset.
        int blockStart = 0;
        for (int traj = 0; traj < this.nSamples.length; ++traj) {
            int blockSize = this.nSamples[traj];
            for (int row = 0; row < nRows; ++row) {
                double[] target = collapsed[row];
                for (int k = 0; k < blockSize; ++k) {
                    target[traj] += weights[row][blockStart + k];
                }
            }
            blockStart += blockSize;
        }

        for (int row = 0; row < nRows; ++row) {
            logger.info("\n Estimation " + row + ": " + Arrays.toString(collapsed[row]));
        }

        double[] columnTotals = new double[nRows];
        int nColumns = collapsed[0].length;
        for (int col = 0; col < nColumns; ++col) {
            for (double[] estimation : collapsed) {
                columnTotals[col] += estimation[col];
            }
        }
        softMax(columnTotals);
        logger.info("\n Softmax of trajectory weight: " + Arrays.toString(columnTotals));
    }

    /**
     * Evaluates the MBAR objective
     * {@code sum_n log(sum_k N_k exp(f_k - u_kn)) - sum_k N_k f_k}.
     *
     * @param reducedPotentials     reduced potential of every sample (column) in every state (row).
     * @param snapsPerLambda        number of samples N_k drawn from each state.
     * @param freeEnergyEstimates   current reduced free energies f_k.
     * @return the objective value.
     * @throws IllegalArgumentException if any free energy is NaN or infinite.
     */
    private static double mbarObjectiveFunction(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
        boolean allFinite = Arrays.stream(freeEnergyEstimates).allMatch(Double::isFinite);
        if (!allFinite) {
            throw new IllegalArgumentException("MBAR contains NaNs or Infs.");
        }
        int nStates = freeEnergyEstimates.length;
        int nSamplesTotal = reducedPotentials[0].length;

        double[] logDenominators = new double[nSamplesTotal];
        for (int n = 0; n < nSamplesTotal; ++n) {
            double[] exponents = new double[nStates];
            double largest = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < nStates; ++k) {
                double exponent = freeEnergyEstimates[k] - reducedPotentials[k][n];
                exponents[k] = exponent;
                if (exponent > largest) {
                    largest = exponent;
                }
            }
            logDenominators[n] = logSumExp(exponents, snapsPerLambda, largest);
        }

        double[] weightedFreeEnergies = new double[snapsPerLambda.length];
        for (int k = 0; k < snapsPerLambda.length; ++k) {
            weightedFreeEnergies[k] = (double) snapsPerLambda[k] * freeEnergyEstimates[k];
        }
        // DoubleStream.sum() is kept deliberately: it uses compensated summation.
        double denominatorTerm = Arrays.stream(logDenominators).sum();
        double linearTerm = Arrays.stream(weightedFreeEnergies).sum();
        return denominatorTerm - linearTerm;
    }

    /**
     * Evaluates the gradient of the MBAR objective with respect to each reduced free energy:
     * {@code g_k = -N_k * (1 - exp(f_k + log sum_n exp(-log D_n - u_kn)))}, where
     * {@code D_n = sum_j N_j exp(f_j - u_jn)}.
     *
     * @param reducedPotentials   reduced potential of every sample (column) in every state (row).
     * @param snapsPerLambda      number of samples N_k drawn from each state.
     * @param freeEnergyEstimates current reduced free energies f_k.
     * @return gradient, one entry per state.
     */
    private static double[] mbarGradient(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
        final int nStates = freeEnergyEstimates.length;
        final int nSamplesTotal = reducedPotentials[0].length;
        double[][] logTerms = new double[reducedPotentials.length][nSamplesTotal];
        double[] maxLogTerm = new double[nStates];
        Arrays.fill(maxLogTerm, Double.NEGATIVE_INFINITY);

        for (int n = 0; n < nSamplesTotal; ++n) {
            double[] exponents = new double[nStates];
            double largest = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < nStates; ++k) {
                exponents[k] = freeEnergyEstimates[k] - reducedPotentials[k][n];
                if (exponents[k] > largest) {
                    largest = exponents[k];
                }
            }
            double logDenominator = logSumExp(exponents, snapsPerLambda, largest);
            for (int k = 0; k < nStates; ++k) {
                double term = -logDenominator - reducedPotentials[k][n];
                logTerms[k][n] = term;
                if (term > maxLogTerm[k]) {
                    maxLogTerm[k] = term;
                }
            }
        }

        double[] gradient = new double[nStates];
        for (int k = 0; k < nStates; ++k) {
            double logNumerator = logSumExp(logTerms[k], maxLogTerm[k]);
            gradient[k] = -1.0 * (double) snapsPerLambda[k] * (1.0 - FastMath.exp(freeEnergyEstimates[k] + logNumerator));
        }
        return gradient;
    }

    /**
     * Builds the Hessian of the MBAR objective from the weight matrix W:
     * {@code H_ij = -(N_i N_j sum_n W_in W_jn - delta_ij N_i sum_n W_in)}.
     *
     * @param reducedPotentials   reduced potential of every sample (column) in every state (row).
     * @param snapsPerLambda      number of samples N_k drawn from each state.
     * @param freeEnergyEstimates current reduced free energies.
     * @return nStates x nStates Hessian.
     */
    private static double[][] mbarHessian(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
        final int nStates = freeEnergyEstimates.length;
        double[][] weights = mbarW(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        final int nSamplesTotal = reducedPotentials[0].length;
        double[][] hessian = new double[nStates][nStates];

        for (int a = 0; a < nStates; ++a) {
            double[] rowA = weights[a];
            for (int b = 0; b < nStates; ++b) {
                double[] rowB = weights[b];
                double overlap = 0.0;
                for (int n = 0; n < nSamplesTotal; ++n) {
                    overlap += rowA[n] * rowB[n];
                }
                hessian[a][b] = overlap * (double) snapsPerLambda[a] * (double) snapsPerLambda[b];
            }
            double rowTotal = 0.0;
            for (double w : rowA) {
                rowTotal += w;
            }
            hessian[a][a] -= rowTotal * (double) snapsPerLambda[a];
        }

        // Flip the sign of every element.
        for (double[] row : hessian) {
            for (int b = 0; b < nStates; ++b) {
                row[b] = -row[b];
            }
        }
        return hessian;
    }

    /**
     * Computes the MBAR weight matrix
     * {@code W[k][n] = exp(f_k - u_kn - log sum_j N_j exp(f_j - u_jn))}.
     *
     * @param reducedPotentials   reduced potential of every sample (column) in every state (row).
     * @param snapsPerLambda      number of samples N_j drawn from each state.
     * @param freeEnergyEstimates reduced free energies f_k.
     * @return weights with one row per state and one column per sample.
     */
    private static double[][] mbarW(double[][] reducedPotentials, int[] snapsPerLambda, double[] freeEnergyEstimates) {
        final int nStates = freeEnergyEstimates.length;
        final int nSamplesTotal = reducedPotentials[0].length;

        double[] logDenominators = new double[nSamplesTotal];
        for (int n = 0; n < nSamplesTotal; ++n) {
            double[] exponents = new double[nStates];
            double largest = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < nStates; ++k) {
                double exponent = freeEnergyEstimates[k] - reducedPotentials[k][n];
                exponents[k] = exponent;
                if (exponent > largest) {
                    largest = exponent;
                }
            }
            logDenominators[n] = logSumExp(exponents, snapsPerLambda, largest);
        }

        double[][] weights = new double[nStates][nSamplesTotal];
        for (int k = 0; k < nStates; ++k) {
            double[] row = weights[k];
            for (int n = 0; n < nSamplesTotal; ++n) {
                row[n] = FastMath.exp(freeEnergyEstimates[k] - reducedPotentials[k][n] - logDenominators[n]);
            }
        }
        return weights;
    }

    /**
     * Computes, for each adjacent pair of states, the difference of the MBAR average potential
     * energy: {@code <E_{i+1}>_{i+1} - <E_i>_i}, with energies taken from {@code eAllFlat}.
     *
     * @param reducedPotentials only its row count is used (estimateDG passes {@code eAllFlat}).
     * @param mbarFEEstimates   only its length is used, to size the arrays.
     * @return one enthalpy difference per adjacent pair of states.
     */
    private double[] mbarEnthalpyCalc(double[][] reducedPotentials, double[] mbarFEEstimates) {
        int nStates = mbarFEEstimates.length;
        double[] enthalpy = new double[nStates - 1];
        double[] averagePotential = new double[nStates];
        int nRows = reducedPotentials.length;
        for (int state = 0; state < nRows; ++state) {
            double[] expectations = this.computeExpectations(this.eAllFlat[state]);
            averagePotential[state] = expectations[state];
        }
        for (int k = 0; k < enthalpy.length; ++k) {
            enthalpy[k] = averagePotential[k + 1] - averagePotential[k];
        }
        return enthalpy;
    }

    /**
     * Computes {@code mbarEnthalpy[k] - mbarFEDifferenceEstimates[k]} for each adjacent pair of states.
     *
     * @param mbarEnthalpy    enthalpy difference of each adjacent pair.
     * @param mbarFEEstimates only its length is used, to size the result.
     * @return one value per adjacent pair of states.
     */
    private double[] mbarEntropyCalc(double[] mbarEnthalpy, double[] mbarFEEstimates) {
        int nDiffs = mbarFEEstimates.length - 1;
        double[] entropy = new double[nDiffs];
        int k = 0;
        while (k < nDiffs) {
            entropy[k] = mbarEnthalpy[k] - this.mbarFEDifferenceEstimates[k];
            k++;
        }
        return entropy;
    }

    /**
     * Flattens bias data into {@link #biasFlat}, skipping NaN entries, and shifts values so the
     * maximum kept value is zero.
     *
     * <p>With {@code multiDataObservable}, row {@code i} collects {@code biasAll[j][i][*]} over all
     * {@code j}, and each row is shifted by its own maximum. Otherwise only row 0 is filled, from
     * {@code biasAll[i][i][*]}; the whole preallocated row 0, including unfilled trailing slots,
     * is shifted.
     *
     * @param biasAll             bias values, indexed as described above.
     * @param multiDataObservable whether to build one row per state.
     */
    public void setBiasData(double[][][] biasAll, boolean multiDataObservable) {
        int nStates = biasAll.length;
        int perState = biasAll[0][0].length;
        this.biasFlat = new double[nStates][nStates * perState];
        if (!multiDataObservable) {
            double[] flat = this.biasFlat[0];
            int filled = 0;
            double maxBias = Double.NEGATIVE_INFINITY;
            for (int state = 0; state < nStates; ++state) {
                for (int k = 0; k < perState; ++k) {
                    double value = biasAll[state][state][k];
                    if (Double.isNaN(value)) {
                        continue;
                    }
                    flat[filled++] = value;
                    if (value > maxBias) {
                        maxBias = value;
                    }
                }
            }
            for (int n = 0; n < flat.length; ++n) {
                flat[n] -= maxBias;
            }
            return;
        }
        for (int target = 0; target < nStates; ++target) {
            ArrayList<Double> kept = new ArrayList<>();
            double maxBias = Double.NEGATIVE_INFINITY;
            for (int source = 0; source < nStates; ++source) {
                for (double value : biasAll[source][target]) {
                    if (Double.isNaN(value)) {
                        continue;
                    }
                    kept.add(value);
                    if (value > maxBias) {
                        maxBias = value;
                    }
                }
            }
            double[] row = kept.stream().mapToDouble(Double::doubleValue).toArray();
            for (int n = 0; n < row.length; ++n) {
                row[n] -= maxBias;
            }
            this.biasFlat[target] = row;
        }
    }

    /**
     * Stores per-state bias data, with each row shifted so that its maximum value is zero.
     *
     * <p>The rows are copied before shifting. Previously the caller's array was stored by
     * reference and modified in place, which silently changed the caller's data and made repeated
     * calls operate on already-shifted values.
     *
     * @param biasData bias values, one row per state (rows may differ in length).
     */
    public void setBiasData(double[][] biasData) {
        double[][] shifted = new double[biasData.length][];
        for (int i = 0; i < biasData.length; ++i) {
            double[] row = Arrays.copyOf(biasData[i], biasData[i].length);
            double maxBias = Double.NEGATIVE_INFINITY;
            for (double value : row) {
                if (value > maxBias) {
                    maxBias = value;
                }
            }
            for (int j = 0; j < row.length; ++j) {
                row[j] -= maxBias;
            }
            shifted[i] = row;
        }
        this.biasFlat = shifted;
    }

    /**
     * Flattens observable data into {@link #oAllFlat}, skipping NaN entries, and multiplies it by
     * {@code exp(bias / RT)} when bias data has been set. It then recomputes the per-state
     * observable expectations.
     *
     * <p>With {@code multiDataObservable}, row {@code i} collects {@code oAll[j][i][*]} over all
     * {@code j}. Otherwise only row 0 is filled, from {@code oAll[i][i][*]}.
     *
     * @param oAll                observable values, indexed as described above.
     * @param multiDataObservable whether to build one row per state.
     * @param uncertainties       whether to also compute expectation uncertainties.
     */
    public void setObservableData(double[][][] oAll, boolean multiDataObservable, boolean uncertainties) {
        int nStates = oAll.length;
        int perState = oAll[0][0].length;
        this.oAllFlat = new double[nStates][nStates * perState];
        if (multiDataObservable) {
            for (int target = 0; target < nStates; ++target) {
                ArrayList<Double> kept = new ArrayList<>();
                for (int source = 0; source < nStates; ++source) {
                    for (double value : oAll[source][target]) {
                        if (!Double.isNaN(value)) {
                            kept.add(value);
                        }
                    }
                }
                this.oAllFlat[target] = kept.stream().mapToDouble(Double::doubleValue).toArray();
            }
        } else {
            double[] flat = this.oAllFlat[0];
            int filled = 0;
            for (int state = 0; state < nStates; ++state) {
                for (int k = 0; k < perState; ++k) {
                    double value = oAll[state][state][k];
                    if (!Double.isNaN(value)) {
                        flat[filled++] = value;
                    }
                }
            }
        }
        if (this.biasFlat != null) {
            // Reweight each observable sample by exp(bias / RT) of its row's state.
            for (int i = 0; i < this.oAllFlat.length; ++i) {
                double[] row = this.oAllFlat[i];
                for (int j = 0; j < row.length; ++j) {
                    row[j] *= FastMath.exp(this.biasFlat[i][j] / this.rtValues[i]);
                }
            }
        }
        this.fillObservationExpectations(multiDataObservable, uncertainties);
    }

    /**
     * Stores observable data, multiplies it by {@code exp(bias / RT)} when bias data has been set,
     * and recomputes the per-state observable expectations.
     *
     * <p>The rows are copied first. Previously the caller's array was stored by reference and
     * reweighted in place, so the caller's data changed and calling this twice with the same array
     * applied the bias factor twice.
     *
     * @param oAll          observable values; a single row is treated as single-data, more rows as
     *                      one row per state.
     * @param uncertainties whether to also compute expectation uncertainties.
     */
    public void setObservableData(double[][] oAll, boolean uncertainties) {
        double[][] copy = new double[oAll.length][];
        for (int i = 0; i < oAll.length; ++i) {
            copy[i] = Arrays.copyOf(oAll[i], oAll[i].length);
        }
        this.oAllFlat = copy;
        if (this.biasFlat != null) {
            if (this.oAllFlat.length != this.biasFlat.length || this.oAllFlat[0].length != this.biasFlat[0].length) {
                logger.severe("Observable and bias data are not the same size. Exiting.");
            }
            for (int i = 0; i < this.oAllFlat.length; ++i) {
                double[] row = this.oAllFlat[i];
                for (int j = 0; j < row.length; ++j) {
                    row[j] *= FastMath.exp(this.biasFlat[i][j] / this.rtValues[i]);
                }
            }
        }
        this.fillObservationExpectations(this.oAllFlat.length != 1, uncertainties);
    }

    /**
     * Integrates the per-state observable averages over lambda in [0, 1] with the trapezoidal rule
     * (left integration side). The x points come from
     * {@code Integrate1DNumeric.generateXPoints(0.0, 1.0, n, false)} — presumably evenly spaced; confirm.
     *
     * @return the integral of {@link #mbarObservableEnsembleAverages}.
     */
    public double getTIIntegral() {
        double[] averages = this.mbarObservableEnsembleAverages;
        double[] lambdaPoints = Integrate1DNumeric.generateXPoints(0.0, 1.0, averages.length, false);
        DoublesDataSet dataSet = new DoublesDataSet(lambdaPoints, averages, false);
        return Integrate1DNumeric.integrateData(dataSet, Integrate1DNumeric.IntegrationSide.LEFT, Integrate1DNumeric.IntegrationType.TRAPEZOIDAL);
    }

    /**
     * Computes the per-state MBAR expectations of {@link #oAllFlat}, and optionally their
     * uncertainties.
     *
     * <p>In multi-data mode, row {@code i} supplies the observable for state {@code i}. Otherwise
     * row 0 is used for every state.
     *
     * @param multiData     whether {@link #oAllFlat} holds one row per state.
     * @param uncertainties whether to compute uncertainties as well.
     */
    private void fillObservationExpectations(boolean multiData, boolean uncertainties) {
        if (!multiData) {
            this.mbarObservableEnsembleAverages = this.computeExpectations(this.oAllFlat[0]);
            if (uncertainties) {
                this.mbarObservableEnsembleAverageUncertainties = this.computeExpectationStd(this.oAllFlat[0]);
            }
            return;
        }
        int nRows = this.oAllFlat.length;
        this.mbarObservableEnsembleAverages = new double[nRows];
        this.mbarObservableEnsembleAverageUncertainties = new double[nRows];
        for (int state = 0; state < nRows; ++state) {
            this.mbarObservableEnsembleAverages[state] = this.computeExpectations(this.oAllFlat[state])[state];
            if (uncertainties) {
                this.mbarObservableEnsembleAverageUncertainties[state] = this.computeExpectationStd(this.oAllFlat[state])[state];
            }
        }
    }

    /**
     * Computes the MBAR estimate of the expectation of {@code samples} in every lambda state:
     * {@code E_i = sum_n W[i][n] * samples[n]}.
     *
     * <p>The weights combine {@link #reducedPotentials} (energies divided by RT) with the free
     * energies, so both must be in the same dimensionless units. This method assumes it is called
     * after estimateDG has multiplied {@link #mbarFEEstimates} by {@link #rtValues}, as
     * mbarEnthalpyCalc and fillObservationExpectations are. The estimates are therefore divided by
     * RT again here. Previously the RT-scaled values were used directly, which gave unnormalized
     * weights.
     *
     * @param samples one value per sample, in the same order as the columns of reducedPotentials.
     * @return expectation of {@code samples} in each state.
     */
    private double[] computeExpectations(double[] samples) {
        double[] reducedFreeEnergies = new double[this.mbarFEEstimates.length];
        for (int i = 0; i < reducedFreeEnergies.length; ++i) {
            reducedFreeEnergies[i] = this.mbarFEEstimates[i] / this.rtValues[i];
        }
        double[][] W = mbarW(this.reducedPotentials, this.nSamples, reducedFreeEnergies);
        if (W[0].length != samples.length) {
            logger.severe("Samples and W matrix are not the same length. Exiting.");
        }
        double[] expectation = new double[W.length];
        for (int i = 0; i < W.length; ++i) {
            double sum = 0.0;
            for (int j = 0; j < W[i].length; ++j) {
                sum += W[i][j] * samples[j];
            }
            expectation[i] = sum;
        }
        return expectation;
    }

    /**
     * Builds the augmented MBAR weight matrix used to estimate the uncertainty of an observable.
     *
     * <p>Rows {@code 0..K-1} hold the MBAR weights of each state,
     * {@code W_nk = exp(f_k - u_k(x_n) - logDenom_n)}. Rows {@code K..2K-1} hold the
     * observable-weighted weights {@code A_n exp(-log c_A,k - u_k(x_n) - logDenom_n)}, where
     * {@code logDenom_n = log sum_j N_j exp(f_j - u_j(x_n))}.
     *
     * <p>Because {@code log(A)} is taken, the observable is shifted on a private copy so that every
     * value is strictly positive. The shift is {@code min(A) - 3 ulp(1)}, applied only when that is
     * negative. {@code samples} itself is never modified.
     *
     * @param samples observable values, one per pooled sample. Not modified.
     * @return a {@code 2K x N} matrix: state weights followed by observable-weighted weights.
     */
    private double[][] mbarAugmentedW(double[] samples) {
        int nStates = this.mbarFEEstimates.length;
        int nTotal = this.reducedPotentials[0].length;

        // Shift a copy rather than the caller's array: shifting and un-shifting in place does not
        // round-trip exactly in floating point.
        double minSample = Arrays.stream(samples).min().getAsDouble() - 3.0 * Math.ulp(1.0);
        double[] shifted = samples.clone();
        if (minSample < 0.0) {
            for (int n = 0; n < shifted.length; n++) {
                shifted[n] -= minSample;
            }
        }

        double[][] logCATerms = new double[nStates][nTotal];
        // One running maximum per state (row of logCATerms); it is the logSumExp shift for that row.
        double[] maxLogCATerm = new double[nStates];
        Arrays.fill(maxLogCATerm, Double.NEGATIVE_INFINITY);
        double[] logDenom = new double[nTotal];
        for (int n = 0; n < nTotal; n++) {
            double[] temp = new double[nStates];
            double maxTemp = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < nStates; k++) {
                temp[k] = this.mbarFEEstimates[k] - this.reducedPotentials[k][n];
                if (temp[k] > maxTemp) {
                    maxTemp = temp[k];
                }
            }
            logDenom[n] = MultistateBennettAcceptanceRatio.logSumExp(temp, this.nSamples, maxTemp);
            double logA = FastMath.log(shifted[n]);
            for (int k = 0; k < nStates; k++) {
                logCATerms[k][n] = logA - this.reducedPotentials[k][n] - logDenom[n];
                if (logCATerms[k][n] > maxLogCATerm[k]) {
                    maxLogCATerm[k] = logCATerms[k][n];
                }
            }
        }
        double[] logCA = new double[nStates];
        for (int k = 0; k < nStates; k++) {
            logCA[k] = MultistateBennettAcceptanceRatio.logSumExp(logCATerms[k], maxLogCATerm[k]);
        }

        double[][] augmentedW = new double[2 * nStates][nTotal];
        for (int k = 0; k < nStates; k++) {
            for (int n = 0; n < nTotal; n++) {
                // +f_k: at the self-consistent solution (see mbarSelfConsistentUpdate),
                // sum_n exp(-u_k(x_n) - logDenom_n) = exp(-f_k), so each state column sums to one.
                augmentedW[k][n] = FastMath.exp(this.mbarFEEstimates[k] - this.reducedPotentials[k][n] - logDenom[n]);
                augmentedW[nStates + k][n] = shifted[n] * FastMath.exp(-logCA[k] - this.reducedPotentials[k][n] - logDenom[n]);
            }
        }
        return augmentedW;
    }

    /**
     * Estimates the standard error of the MBAR expectation of {@code samples} in every state.
     *
     * <p>Steps:
     * <ol>
     *   <li>Compute the asymptotic covariance Theta of the augmented weight matrix from
     *       {@code mbarAugmentedW}. The augmented states get zero sample counts.</li>
     *   <li>Scale Theta on both sides by {@code diag(<A'>, <A'>)}, where {@code A'} is the
     *       positivity-shifted observable that the augmented weights were built from.</li>
     *   <li>Return {@code sqrt|diag(Theta_ul + Theta_lr - Theta_ur - Theta_ll)|}.</li>
     * </ol>
     *
     * @param samples observable values, one per pooled sample. Not modified by this method.
     * @return one standard error per state.
     */
    private double[] computeExpectationStd(double[] samples) {
        // Counts N_k for the real states followed by zeros for the observable-weighted states.
        int[] extendedSnaps = Arrays.copyOf(this.nSamples, this.nSamples.length * 2);
        RealMatrix theta = MatrixUtils.createRealMatrix(
            MultistateBennettAcceptanceRatio.mbarTheta(extendedSnaps, this.mbarAugmentedW(samples)));

        // mbarAugmentedW weights the observable shifted by (min - 3 ulp) whenever that is negative.
        // The delta-method scaling must use expectations of that same shifted observable. A constant
        // shift does not change the uncertainty itself.
        double minSample = Arrays.stream(samples).min().getAsDouble() - 3.0 * Math.ulp(1.0);
        double[] shiftedSamples = samples.clone();
        if (minSample < 0.0) {
            for (int n = 0; n < shiftedSamples.length; n++) {
                shiftedSamples[n] -= minSample;
            }
        }
        double[] expectations = this.computeExpectations(shiftedSamples);
        int nExp = expectations.length;

        double[] diag = new double[2 * nExp];
        for (int i = 0; i < nExp; i++) {
            diag[i] = expectations[i];
            diag[i + nExp] = expectations[i];
        }
        RealMatrix diagMatrix = MatrixUtils.createRealDiagonalMatrix(diag);
        theta = diagMatrix.multiply(theta).multiply(diagMatrix);

        RealMatrix ul = theta.getSubMatrix(0, nExp - 1, 0, nExp - 1);
        RealMatrix ur = theta.getSubMatrix(0, nExp - 1, nExp, 2 * nExp - 1);
        RealMatrix ll = theta.getSubMatrix(nExp, 2 * nExp - 1, 0, nExp - 1);
        RealMatrix lr = theta.getSubMatrix(nExp, 2 * nExp - 1, nExp, 2 * nExp - 1);
        double[][] covA = ul.add(lr).subtract(ur).subtract(ll).getData();

        double[] sigma = new double[covA.length];
        for (int i = 0; i < covA.length; i++) {
            sigma[i] = FastMath.sqrt(FastMath.abs(covA[i][i]));
        }
        return sigma;
    }

    /**
     * Converts a free-energy covariance matrix into uncertainties of consecutive differences.
     *
     * <p>For each pair {@code (i, i+1)} the variance is
     * {@code theta[i][i] - 2 theta[i][i+1] + theta[i+1][i+1]}. A negative variance is replaced by its
     * magnitude, with a warning logged when {@code VERBOSE} is set.
     *
     * @param theta square covariance matrix.
     * @return {@code theta.length - 1} standard deviations.
     */
    private static double[] mbarUncertaintyCalc(double[][] theta) {
        int nDiffs = theta.length - 1;
        double[] uncertainties = new double[nDiffs];
        for (int i = 0; i < nDiffs; i++) {
            int next = i + 1;
            double variance = theta[i][i] - 2.0 * theta[i][next] + theta[next][next];
            if (variance < 0.0) {
                if (VERBOSE) {
                    logger.warning(" Negative variance detected in MBAR uncertainty calculation. Multiplying by -1 to get real value. Check diff matrix to see which variances were negative. They should be NaN.");
                }
                variance = -variance;
            }
            uncertainties[i] = FastMath.sqrt(variance);
        }
        return uncertainties;
    }

    /**
     * Returns the uncertainty of the free-energy difference between the first and last states.
     *
     * @param theta square covariance matrix.
     * @return {@code sqrt|theta[0][0] - 2 theta[0][K-1] + theta[K-1][K-1]|}.
     */
    private static double mbarTotalUncertaintyCalc(double[][] theta) {
        int last = theta.length - 1;
        double variance = theta[0][0] - 2.0 * theta[0][last] + theta[last][last];
        return FastMath.sqrt(FastMath.abs(variance));
    }

    /**
     * Computes the MBAR asymptotic covariance matrix from reduced potentials and free energies.
     *
     * <p>Builds the weight matrix with {@code mbarW} and delegates to
     * {@link #mbarTheta(int[], double[][])}.
     *
     * @param reducedPotentials reduced potentials passed to {@code mbarW}.
     * @param snapsPerState number of samples per state.
     * @param freeEnergies free-energy estimates passed to {@code mbarW}.
     * @return the covariance matrix Theta.
     */
    private static double[][] mbarTheta(double[][] reducedPotentials, int[] snapsPerState, double[] freeEnergies) {
        double[][] weights = MultistateBennettAcceptanceRatio.mbarW(reducedPotentials, snapsPerState, freeEnergies);
        return MultistateBennettAcceptanceRatio.mbarTheta(snapsPerState, weights);
    }

    /**
     * Computes the asymptotic covariance matrix Theta from a weight matrix using its SVD.
     *
     * <p>With {@code W^T = U S V^T} and {@code N = diag(snapsPerState)}, this returns
     * {@code Theta = V S (I - S V^T N V S)^-1 S V^T}.
     *
     * @param snapsPerState sample count per state (row of {@code W}).
     * @param W weight matrix with one row per state.
     * @return Theta as a dense array.
     * @throws IllegalArgumentException (Commons Math) if {@code I - S V^T N V S} is singular.
     */
    private static double[][] mbarTheta(int[] snapsPerState, double[][] W) {
        int nStates = snapsPerState.length;
        double[] counts = new double[nStates];
        for (int k = 0; k < nStates; k++) {
            counts[k] = snapsPerState[k];
        }
        RealMatrix weightsT = MatrixUtils.createRealMatrix(W).transpose();
        RealMatrix identity = MatrixUtils.createRealIdentityMatrix(nStates);
        RealMatrix countMatrix = MatrixUtils.createRealDiagonalMatrix(counts);
        SingularValueDecomposition svd = new SingularValueDecomposition(weightsT);
        RealMatrix v = svd.getV();
        RealMatrix vT = v.transpose();
        RealMatrix s = MatrixUtils.createRealDiagonalMatrix(svd.getSingularValues());
        // Evaluated left to right, matching ((((S V^T) N) V) S).
        RealMatrix inner = s.multiply(vT).multiply(countMatrix).multiply(v).multiply(s);
        RealMatrix inverse = MatrixUtils.inverse(identity.subtract(inner));
        return v.multiply(s).multiply(inverse).multiply(s).multiply(vT).getData();
    }

    /**
     * Builds the matrix of pairwise free-energy-difference uncertainties.
     *
     * <p>Entry {@code (i, j)} is {@code sqrt(theta[i][i] - 2 theta[i][j] + theta[j][j])}. A negative
     * variance yields NaN, because the square root is not guarded.
     *
     * @param theta square covariance matrix.
     * @return a square matrix of the same size.
     */
    private static double[][] diffMatrixCalculation(double[][] theta) {
        int size = theta.length;
        double[][] diffMatrix = new double[size][size];
        for (int row = 0; row < size; row++) {
            double[] out = diffMatrix[row];
            for (int col = 0; col < size; col++) {
                double variance = theta[row][row] - 2.0 * theta[row][col] + theta[col][col];
                out[col] = FastMath.sqrt(variance);
            }
        }
        return diffMatrix;
    }

    /**
     * Performs one MBAR self-consistent iteration.
     *
     * <p>Computes {@code f_i = -log sum_n exp(-u_i(x_n) - logDenom_n)}, where
     * {@code logDenom_n = log sum_k N_k exp(f_k - u_k(x_n))}. The result is then re-referenced so
     * that {@code f_0 = 0}.
     *
     * @param reducedPotential reduced potentials indexed {@code [state][sample]}.
     * @param snapsPerLambda sample count per state.
     * @param freeEnergyEstimates current free-energy estimates.
     * @return updated free-energy estimates with the first entry equal to zero.
     */
    private static double[] mbarSelfConsistentUpdate(double[][] reducedPotential, int[] snapsPerLambda, double[] freeEnergyEstimates) {
        int nStates = freeEnergyEstimates.length;
        int nTotal = reducedPotential[0].length;
        double[] logDenom = new double[nTotal];
        double[][] logDiff = new double[reducedPotential.length][nTotal];
        double[] maxLogDiff = new double[nStates];
        Arrays.fill(maxLogDiff, Double.NEGATIVE_INFINITY);

        for (int n = 0; n < nTotal; n++) {
            double[] exponents = new double[nStates];
            double maxExponent = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < nStates; k++) {
                exponents[k] = freeEnergyEstimates[k] - reducedPotential[k][n];
                if (exponents[k] > maxExponent) {
                    maxExponent = exponents[k];
                }
            }
            logDenom[n] = MultistateBennettAcceptanceRatio.logSumExp(exponents, snapsPerLambda, maxExponent);
            for (int k = 0; k < nStates; k++) {
                logDiff[k][n] = -logDenom[n] - reducedPotential[k][n];
                if (logDiff[k][n] > maxLogDiff[k]) {
                    maxLogDiff[k] = logDiff[k][n];
                }
            }
        }

        double[] updated = new double[nStates];
        for (int k = 0; k < nStates; k++) {
            updated[k] = -MultistateBennettAcceptanceRatio.logSumExp(logDiff[k], maxLogDiff[k]);
        }
        // Re-reference every estimate to the first state.
        double reference = updated[0];
        updated[0] = 0.0;
        for (int k = 1; k < nStates; k++) {
            updated[k] -= reference;
        }
        return updated;
    }

    /**
     * Takes one Newton-Raphson step on the MBAR free energies.
     *
     * <p>The step is {@code grad^T H^-1}, which equals {@code H^-1 grad} for a symmetric Hessian. If
     * the Hessian cannot be inverted, the step falls back to {@code grad} with a step size of 1.0E-5.
     * The step is re-referenced so that its first component is zero, and
     * {@code n - stepSize * step} is returned.
     *
     * @param n current free-energy estimates (not modified).
     * @param grad gradient at {@code n} (not modified).
     * @param hessian Hessian at {@code n}.
     * @param stepSize step scaling; replaced by 1.0E-5 in the steepest-descent fallback.
     * @return the updated estimates.
     */
    private static double[] newtonStep(double[] n, double[] grad, double[][] hessian, double stepSize) {
        double[] step;
        try {
            RealMatrix hessianInverse = MatrixUtils.inverse(MatrixUtils.createRealMatrix(hessian));
            step = hessianInverse.preMultiply(grad);
        } catch (IllegalArgumentException e) {
            if (VERBOSE) {
                logger.info(" Singular matrix detected in MBAR Newton-Raphson step. Performing steepest descent step.");
            }
            // Copy: the step is re-referenced in place below, which must not clobber the caller's gradient.
            step = grad.clone();
            stepSize = 1.0E-5;
        }
        double reference = step[0];
        step[0] = 0.0;
        for (int i = 1; i < step.length; i++) {
            step[i] -= reference;
        }
        double[] nPlusOne = new double[n.length];
        for (int i = 0; i < n.length; i++) {
            nPlusOne[i] = n[i] - step[i] * stepSize;
        }
        return nPlusOne;
    }

    /**
     * Solves the MBAR equations with up to 15 Newton-Raphson steps.
     *
     * <p>Convergence is declared when the L1 norm of the gradient at the current estimates drops
     * below {@code tolerance}. The estimate returned is the one after the step taken from that
     * point.
     *
     * @param freeEnergyEstimates starting free-energy estimates.
     * @param reducedPotentials reduced potentials indexed {@code [state][sample]}.
     * @param snapsPerLambda sample count per state.
     * @param tolerance L1 gradient-norm convergence threshold.
     * @return the final free-energy estimates.
     */
    private static double[] newton(double[] freeEnergyEstimates, double[][] reducedPotentials, int[] snapsPerLambda, double tolerance) {
        final int maxIterations = 15;
        double[] grad = MultistateBennettAcceptanceRatio.mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        double[][] hessian = MultistateBennettAcceptanceRatio.mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        double[] f_kPlusOne = MultistateBennettAcceptanceRatio.newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
        int iter;
        for (iter = 1; iter < maxIterations; ++iter) {
            freeEnergyEstimates = f_kPlusOne;
            grad = MultistateBennettAcceptanceRatio.mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
            hessian = MultistateBennettAcceptanceRatio.mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
            // Measure convergence on the freshly computed gradient before it is handed to newtonStep,
            // so the check cannot depend on what newtonStep does with the array.
            double eps = 0.0;
            for (int i = 0; i < freeEnergyEstimates.length; ++i) {
                eps += FastMath.abs(grad[i]);
            }
            f_kPlusOne = MultistateBennettAcceptanceRatio.newtonStep(freeEnergyEstimates, grad, hessian, 1.0);
            if (eps < tolerance) {
                break;
            }
        }
        if (VERBOSE) {
            logger.info(" Newton iterations (max 15): " + iter);
        }
        return f_kPlusOne;
    }

    /**
     * Computes {@code log(sum_i exp(values[i]))}, shifting by {@code max} for numerical stability.
     *
     * @param values exponents.
     * @param max shift value, normally the maximum of {@code values}.
     * @return the log-sum-exp with unit multiplicities.
     */
    private static double logSumExp(double[] values, double max) {
        int[] unitWeights = new int[values.length];
        Arrays.fill(unitWeights, 1);
        return MultistateBennettAcceptanceRatio.logSumExp(values, unitWeights, max);
    }

    /**
     * Computes {@code log(sum_i b[i] * exp(values[i]))}, shifting by {@code max} for numerical stability.
     *
     * @param values exponents.
     * @param b integer multiplicities (e.g. per-state sample counts), same length as {@code values}.
     * @param max shift value; expected to be the maximum of {@code values}.
     * @return the weighted log-sum-exp, or negative infinity if {@code max} is negative infinity.
     */
    private static double logSumExp(double[] values, int[] b, double max) {
        assert (values.length == b.length) : "values and b must be the same length";
        if (max == Double.NEGATIVE_INFINITY) {
            // If the largest exponent is -inf, every term is exp(-inf) = 0 and the log-sum is -inf.
            // Without this guard, values[i] - max evaluates to -inf - (-inf) = NaN.
            return Double.NEGATIVE_INFINITY;
        }
        double sum = 0.0;
        for (int i = 0; i < values.length; ++i) {
            sum += (double) b[i] * FastMath.exp(values[i] - max);
        }
        return max + FastMath.log(sum);
    }

    /**
     * Replaces {@code values} in place with their softmax, {@code exp(v_i - max) / sum_j exp(v_j - max)}.
     *
     * @param values array to normalize in place; must be non-empty (the max is taken with
     *     {@code getAsDouble()}).
     */
    private static void softMax(double[] values) {
        double max = Arrays.stream(values).max().getAsDouble();
        double total = 0.0;
        for (int i = 0; i < values.length; i++) {
            double shiftedExp = FastMath.exp(values[i] - max);
            values[i] = shiftedExp;
            total += shiftedExp;
        }
        for (int i = 0; i < values.length; i++) {
            values[i] /= total;
        }
    }

    /**
     * Creates an optimization listener whose update callback always returns {@code true}.
     *
     * <p>NOTE(review): presumably {@code true} means "continue optimizing" — confirm against
     * OptimizationListener.
     *
     * @return a new listener instance.
     */
    private OptimizationListener getOptimizationListener() {
        // Plain anonymous implementation. The callback uses no state from the enclosing estimator.
        return new OptimizationListener() {
            @Override
            public boolean optimizationUpdate(int iter, int nBFGS, int nFunctionEvals, double gradientRMS,
                    double coordinateRMS, double f, double df, double angle, LineSearch.LineSearchResult info) {
                return true;
            }
        };
    }

    /**
     * Evaluates the MBAR objective function at {@code x}.
     *
     * <p>Side effect: {@code x} is re-referenced in place so that {@code x[0] = 0} and every other
     * entry is relative to the old {@code x[0]}.
     *
     * @param x free energies; modified in place.
     * @return the MBAR objective function value.
     */
    @Override
    public double energy(double[] x) {
        double offset = x[0];
        for (int k = 0; k < x.length; k++) {
            x[k] = (k == 0) ? 0.0 : x[k] - offset;
        }
        return MultistateBennettAcceptanceRatio.mbarObjectiveFunction(this.reducedPotentials, this.nSamples, x);
    }

    /**
     * Evaluates the MBAR objective function and its gradient at {@code x}.
     *
     * <p>Side effect: {@code x} is re-referenced in place so that {@code x[0] = 0}.
     *
     * @param x free energies; modified in place.
     * @param g receives the gradient ({@code g.length} entries are copied).
     * @return the MBAR objective function value.
     */
    @Override
    public double energyAndGradient(double[] x, double[] g) {
        double offset = x[0];
        x[0] = 0.0;
        for (int k = 1; k < x.length; k++) {
            x[k] -= offset;
        }
        double[] gradient = MultistateBennettAcceptanceRatio.mbarGradient(this.reducedPotentials, this.nSamples, x);
        System.arraycopy(gradient, 0, g, 0, g.length);
        return MultistateBennettAcceptanceRatio.mbarObjectiveFunction(this.reducedPotentials, this.nSamples, x);
    }

    /**
     * Not supported: always returns a new empty array; {@code parameters} is ignored.
     *
     * @param parameters ignored.
     * @return an empty array.
     */
    @Override
    public double[] getCoordinates(double[] parameters) {
        double[] none = {};
        return none;
    }

    /**
     * Intentionally a no-op.
     *
     * @param parameters ignored.
     */
    @Override
    public void setCoordinates(double[] parameters) {
        // Nothing to do.
    }

    /**
     * Always reports zero variables.
     *
     * @return 0.
     */
    @Override
    public int getNumberOfVariables() {
        final int nVariables = 0;
        return nVariables;
    }

    /**
     * No scaling is supplied.
     *
     * @return {@code null}.
     */
    @Override
    public double[] getScaling() {
        double[] scaling = null;
        return scaling;
    }

    /**
     * Intentionally a no-op.
     *
     * @param scaling ignored.
     */
    @Override
    public void setScaling(double[] scaling) {
        // Nothing to do.
    }

    /**
     * Always returns zero.
     *
     * @return 0.0.
     */
    @Override
    public double getTotalEnergy() {
        final double totalEnergy = 0.0;
        return totalEnergy;
    }

    /**
     * Creates a Bennett Acceptance Ratio estimator over this estimator's lambda values, energies and
     * temperatures.
     *
     * @return a new BennettAcceptanceRatio instance.
     */
    public BennettAcceptanceRatio getBAR() {
        BennettAcceptanceRatio bar = new BennettAcceptanceRatio(
            this.lamValues, this.eLambdaMinusdL, this.eLambda, this.eLambdaPlusdL, this.temperatures);
        return bar;
    }

    /**
     * Creates a new estimator over the same lambda values, energies, temperatures, tolerance and seed type.
     *
     * @return a freshly constructed MultistateBennettAcceptanceRatio.
     */
    @Override
    public MultistateBennettAcceptanceRatio copyEstimator() {
        MultistateBennettAcceptanceRatio copy = new MultistateBennettAcceptanceRatio(
            this.lamValues, this.eAll, this.temperatures, this.tolerance, this.seedType);
        return copy;
    }

    /**
     * Returns the MBAR free-energy difference estimates.
     *
     * @return the backing array (not a copy).
     */
    @Override
    public double[] getFreeEnergyDifferences() {
        double[] differences = this.mbarFEDifferenceEstimates;
        return differences;
    }

    /**
     * Returns the MBAR free-energy estimates, one per state.
     *
     * @return the backing array (not a copy).
     */
    public double[] getMBARFreeEnergies() {
        double[] freeEnergies = this.mbarFEEstimates;
        return freeEnergies;
    }

    /**
     * Returns the reduced potentials used by this estimator.
     *
     * @return the backing array (not a copy).
     */
    public double[][] getReducedPotentials() {
        double[][] potentials = this.reducedPotentials;
        return potentials;
    }

    /**
     * Returns the number of samples per state.
     *
     * @return the backing array (not a copy).
     */
    public int[] getSnaps() {
        int[] snaps = this.nSamples;
        return snaps;
    }

    /**
     * Returns the uncertainties of the MBAR free-energy differences.
     *
     * @return the backing array (not a copy).
     */
    @Override
    public double[] getFEDifferenceUncertainties() {
        double[] uncertainties = this.mbarUncertainties;
        return uncertainties;
    }

    /**
     * Returns the MBAR observable ensemble averages from the most recent observable calculation.
     *
     * @return the backing array (not a copy); may be {@code null} if none was computed.
     */
    public double[] getObservationEnsembleAverages() {
        double[] averages = this.mbarObservableEnsembleAverages;
        return averages;
    }

    /**
     * Returns the uncertainties of the MBAR observable ensemble averages.
     *
     * @return the backing array (not a copy); may be {@code null} if none was computed.
     */
    public double[] getObservationEnsembleUncertainties() {
        double[] uncertainties = this.mbarObservableEnsembleAverageUncertainties;
        return uncertainties;
    }

    /**
     * Returns the pairwise free-energy-difference uncertainty matrix.
     *
     * @return the backing array (not a copy).
     */
    public double[][] getUncertaintyMatrix() {
        double[][] matrix = this.uncertaintyMatrix;
        return matrix;
    }

    /**
     * Returns the total MBAR free-energy difference.
     *
     * @return the stored total estimate.
     */
    @Override
    public double getTotalFreeEnergyDifference() {
        double total = this.totalMBAREstimate;
        return total;
    }

    /**
     * Returns the uncertainty of the total MBAR free-energy difference.
     *
     * @return the stored total uncertainty.
     */
    @Override
    public double getTotalFEDifferenceUncertainty() {
        double uncertainty = this.totalMBARUncertainty;
        return uncertainty;
    }

    /**
     * Returns the number of free-energy differences (bins).
     *
     * @return {@code nFreeEnergyDiffs}.
     */
    @Override
    public int getNumberOfBins() {
        int bins = this.nFreeEnergyDiffs;
        return bins;
    }

    /**
     * Returns the MBAR enthalpy differences.
     *
     * @return the backing array (not a copy).
     */
    @Override
    public double[] getEnthalpyDifferences() {
        double[] enthalpies = this.mbarEnthalpy;
        return enthalpies;
    }

    /**
     * Returns the total enthalpy difference, delegating to the overload that takes the MBAR enthalpy
     * differences.
     *
     * @return the total enthalpy difference.
     */
    @Override
    public double getTotalEnthalpyDifference() {
        double[] enthalpies = this.mbarEnthalpy;
        return this.getTotalEnthalpyDifference(enthalpies);
    }

    /**
     * Returns the MBAR entropy values per bin.
     *
     * @return the backing array (not a copy).
     */
    public double[] getBinEntropies() {
        double[] entropies = this.mbarEntropy;
        return entropies;
    }

    /**
     * Writes a matrix of energies to {@code file} as whitespace-separated text.
     *
     * <p>The first line holds {@code energies[0].length} and the temperature. Each following line
     * holds a column index {@code i} followed by {@code energies[j][i]} for every row {@code j}. An
     * {@link IOException} is reported through its stack trace and not rethrown.
     *
     * @param energies energy matrix; columns are written as lines.
     * @param file destination file (overwritten).
     * @param temperature value written on the header line.
     */
    public static void writeFile(double[][] energies, File file, double temperature) {
        try (FileWriter fileWriter = new FileWriter(file);
             BufferedWriter writer = new BufferedWriter(fileWriter)) {
            int nColumns = energies[0].length;
            writer.write(nColumns + " " + temperature);
            writer.newLine();
            for (int column = 0; column < nColumns; column++) {
                StringBuilder line = new StringBuilder();
                line.append("     ").append(column).append(" ");
                for (double[] row : energies) {
                    line.append("    ").append(row[column]).append(" ");
                }
                line.append("\n");
                writer.write(line.toString());
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Regression check of the MBAR helper methods on a synthetic harmonic-oscillator data set.
     *
     * <p>Samples four harmonic oscillators (100000 samples each, seed 0L, beta = 1). It then compares
     * the following against hard-coded reference values:
     * <ul>
     *   <li>free energies</li>
     *   <li>objective function</li>
     *   <li>gradient and Hessian</li>
     *   <li>covariance/diff matrix</li>
     *   <li>self-consistent update</li>
     *   <li>Newton solver output</li>
     * </ul>
     *
     * @return seven entries, each "PASS" or a "FAIL ..." message naming the method(s) checked.
     */
    public static String[] testMBARMethods() {
        // Equilibrium positions (O_k), spring constants (K_k) and sample counts (N_k) per state.
        double[] O_k = new double[]{1.0, 2.0, 3.0, 4.0};
        double[] K_k = new double[]{0.5, 1.0, 1.5, 2.0};
        int[] N_k = new int[]{100000, 100000, 100000, 100000};
        double beta = 1.0;
        HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(O_k, K_k, beta);
        String setting = "u_kln";
        Object[] sampleResult = testCase.sample(N_k, setting, 0L);
        double[][][] u_kln = (double[][][])sampleResult[1];
        // NOTE(review): presumably the temperature at which 1/(k_B T) = 1 in kcal/mol units, matching beta = 1 — confirm.
        double[] temps = new double[]{503.21953349876577};
        // Same data, two tolerances: 1.0E-7 (converged) and 1.0 (loosely converged).
        MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0E-7, SeedType.ZEROS);
        MultistateBennettAcceptanceRatio mbarHigherTol = new MultistateBennettAcceptanceRatio(O_k, u_kln, temps, 1.0, SeedType.ZEROS);
        String[] results = new String[7];
        double[][] reducedPotentials = mbar.getReducedPotentials();
        double[] freeEnergyEstimates = mbar.getMBARFreeEnergies();
        double[] highTolFEEstimates = mbarHigherTol.getMBARFreeEnergies();
        double[] zeros = new double[freeEnergyEstimates.length];
        int[] snapsPerLambda = mbar.getSnaps();
        // results[0]: converged and loosely converged free energies.
        double[] expectedFEEstimates = new double[]{0.0, 0.3474485596619945, 0.5460865684340613, 0.6866650788765148};
        boolean pass = MultistateBennettAcceptanceRatio.normDiff(freeEnergyEstimates, expectedFEEstimates) < 1.0E-5;
        expectedFEEstimates = new double[]{0.0, 0.35798124225733474, 0.44721370511807645, 0.477203739646745};
        pass = MultistateBennettAcceptanceRatio.normDiff(highTolFEEstimates, expectedFEEstimates) < 1.0E-5 && pass;
        results[0] = pass ? "PASS" : "FAIL getMBARFreeEnergies()";
        // results[1]: objective function at the converged, loose and all-zero free energies.
        double objectiveFunction = MultistateBennettAcceptanceRatio.mbarObjectiveFunction(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        pass = !(FastMath.abs((double)(objectiveFunction - 4786294.2692739945)) > 1.0E-5);
        objectiveFunction = MultistateBennettAcceptanceRatio.mbarObjectiveFunction(reducedPotentials, snapsPerLambda, highTolFEEstimates);
        pass = !(FastMath.abs((double)(objectiveFunction - 4787001.700838844)) > 1.0E-5) && pass;
        objectiveFunction = MultistateBennettAcceptanceRatio.mbarObjectiveFunction(reducedPotentials, snapsPerLambda, zeros);
        pass = !(FastMath.abs((double)(objectiveFunction - 4792767.352152844)) > 1.0E-5) && pass;
        results[1] = pass ? "PASS" : "FAIL mbarObjectiveFunction()";
        // results[2]: gradient (L1 distance tolerance 4.0E-5) at the same three points.
        double[] gradient = MultistateBennettAcceptanceRatio.mbarGradient(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        double[] expected = new double[]{6.067113034191607E-4, -8.777718552011038E-4, 8.210768953631487E-4, -5.500246369471995E-4};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(gradient, expected) > 4.0E-5);
        gradient = MultistateBennettAcceptanceRatio.mbarGradient(reducedPotentials, snapsPerLambda, highTolFEEstimates);
        expected = new double[]{1969.705314577408, 5108.841258429764, -1072.9526887468976, -6005.593884267446};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(gradient, expected) > 4.0E-5) && pass;
        gradient = MultistateBennettAcceptanceRatio.mbarGradient(reducedPotentials, snapsPerLambda, zeros);
        expected = new double[]{22797.82037585665, -3273.72282675803, -8859.999065013779, -10664.098484078011};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(gradient, expected) > 4.0E-5) && pass;
        results[2] = pass ? "PASS" : "FAIL mbarGradient()";
        // NOTE(review): this and the later `pass = true;` lines are overwritten by the next assignment to pass.
        pass = true;
        // results[3]: Hessian (L1 distance tolerance 1.6E-4) at the same three points.
        double[][] hessian = MultistateBennettAcceptanceRatio.mbarHessian(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        double[][] expected2d = new double[][]{{47600.586808418964, -29977.008359691405, -12870.425573135915, -4753.1528755909385}, {-29977.008359691405, 63767.745823769576, -24597.198354108747, -9193.539109971487}, {-12870.425573135915, -24597.198354108747, 64584.87112481013, -27117.247197561417}, {-4753.1528755909385, -9193.539109971487, -27117.247197561417, 41063.93918312612}};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(hessian, expected2d) > 1.6E-4);
        hessian = MultistateBennettAcceptanceRatio.mbarHessian(reducedPotentials, snapsPerLambda, highTolFEEstimates);
        expected2d = new double[][]{{49168.30161780381, -31256.519016487477, -12983.708230229113, -4928.074371082683}, {-31256.519016487477, 66075.94621325849, -25339.462656640117, -9479.964540130917}, {-12983.708230229113, -25339.462656640117, 64308.30940252403, -25985.13851565483}, {-4928.074371082683, -9479.964540130917, -25985.13851565483, 40393.1774268678}};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(hessian, expected2d) > 1.6E-4) && pass;
        hessian = MultistateBennettAcceptanceRatio.mbarHessian(reducedPotentials, snapsPerLambda, zeros);
        expected2d = new double[][]{{56125.271437145464, -33495.87894376072, -15738.011263498352, -6891.381229885624}, {-33495.87894376072, 64613.515110188295, -21970.091845920833, -9147.544320511564}, {-15738.011263498352, -21970.091845920833, 61407.66256511316, -23699.55945569241}, {-6891.381229885624, -9147.544320511564, -23699.55945569241, 39738.48500608951}};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(hessian, expected2d) > 1.6E-4) && pass;
        results[3] = pass ? "PASS" : "FAIL mbarHessian()";
        pass = true;
        // results[4]: pairwise uncertainty matrix derived from Theta at the converged free energies.
        double[][] theta = MultistateBennettAcceptanceRatio.mbarTheta(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        double[][] diff = MultistateBennettAcceptanceRatio.diffMatrixCalculation(theta);
        expected2d = new double[][]{{0.0, 0.001953125, 0.003400485419234404, 0.004858337095247168}, {0.0020716018980074633, 0.0, 0.002042627017905458, 0.004055968683065466}, {0.003435363105339426, 0.002042627017905458, 0.0, 0.002560568476977909}, {0.0048828125, 0.004055968683065466, 0.0025135815773894045, 0.0}};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(diff, expected2d) > 1.6E-4);
        results[4] = pass ? "PASS" : "FAIL mbarTheta() or diffMatrixCalculation()";
        pass = true;
        // results[5]: one self-consistent iteration from each of the three starting points.
        double[] updatedF_k = MultistateBennettAcceptanceRatio.mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, freeEnergyEstimates);
        expected = new double[]{0.0, 0.3474485745068261, 0.5460865662904055, 0.6866650904438742};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(updatedF_k, expected) > 1.0E-5);
        updatedF_k = MultistateBennettAcceptanceRatio.mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, highTolFEEstimates);
        expected = new double[]{0.0, 0.327660608017009, 0.4775067849198251, 0.5586442310038073};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(updatedF_k, expected) > 1.0E-5) && pass;
        updatedF_k = MultistateBennettAcceptanceRatio.mbarSelfConsistentUpdate(reducedPotentials, snapsPerLambda, zeros);
        expected = new double[]{0.0, 0.23865416150488983, 0.29814247007871764, 0.31813582643116334};
        pass = !(MultistateBennettAcceptanceRatio.normDiff(updatedF_k, expected) > 1.0E-5) && pass;
        results[5] = pass ? "PASS" : "FAIL mbarSelfConsistentUpdate()";
        pass = true;
        // results[6]: Newton from the loose and zero starting points must reach the converged estimates.
        updatedF_k = MultistateBennettAcceptanceRatio.newton(highTolFEEstimates, reducedPotentials, snapsPerLambda, 1.0E-7);
        pass = !(MultistateBennettAcceptanceRatio.normDiff(updatedF_k, freeEnergyEstimates) > 1.0E-5);
        updatedF_k = MultistateBennettAcceptanceRatio.newton(zeros, reducedPotentials, snapsPerLambda, 1.0E-7);
        pass = !(MultistateBennettAcceptanceRatio.normDiff(updatedF_k, freeEnergyEstimates) > 1.0E-5) && pass;
        results[6] = pass ? "PASS" : "FAIL newton()";
        return results;
    }

    /**
     * Returns the L1 distance {@code sum_i |a[i] - b[i]|}, iterating over the indices of {@code a}.
     *
     * @param a first vector.
     * @param b second vector; must be at least as long as {@code a}.
     * @return the summed absolute differences.
     */
    private static double normDiff(double[] a, double[] b) {
        double total = 0.0;
        int i = 0;
        while (i < a.length) {
            total += FastMath.abs(a[i] - b[i]);
            i++;
        }
        return total;
    }

    /**
     * Returns the entry-wise L1 distance {@code sum_ij |a[i][j] - b[i][j]|}, iterating over the shape of {@code a}.
     *
     * @param a first matrix.
     * @param b second matrix; must cover every index of {@code a}.
     * @return the summed absolute differences.
     */
    private static double normDiff(double[][] a, double[][] b) {
        double total = 0.0;
        for (int row = 0; row < a.length; row++) {
            double[] aRow = a[row];
            for (int col = 0; col < aRow.length; col++) {
                total += FastMath.abs(aRow[col] - b[row][col]);
            }
        }
        return total;
    }

    /**
     * Demonstration driver for MBAR on a synthetic harmonic-oscillator data set.
     *
     * <p>The driver:
     * <ul>
     *   <li>Samples four harmonic oscillators.</li>
     *   <li>Writes each state's energy matrix to
     *       {@code testing/mbar/data/harmonic_oscillators/mbarFiles/energies_<i>.mbar} under the
     *       working directory.</li>
     *   <li>Runs MBAR with tolerance 1.0E-7 and zero seeding.</li>
     *   <li>Prints free energies, uncertainties, enthalpy and entropy changes, and observable averages
     *       next to the analytical values.</li>
     * </ul>
     *
     * @param args ignored.
     */
    public static void main(String[] args) {
        double[] equilPositions = new double[]{1.0, 2.0, 3.0, 4.0};
        double[] springConstants = new double[]{0.5, 1.0, 1.5, 2.0};
        int[] samples = new int[]{100000, 100000, 100000, 100000};
        double beta = 1.0;
        HarmonicOscillatorsTestCase testCase = new HarmonicOscillatorsTestCase(equilPositions, springConstants, beta);
        String setting = "u_kln";
        System.out.print("Generating sample data... ");
        // sampleResult[0] holds the sampled positions (x_n); sampleResult[1] holds the u_kln energies.
        Object[] sampleResult = testCase.sample(samples, setting, 0L);
        System.out.println("done. \n");
        double[] x_n = (double[])sampleResult[0];
        double[][][] u_kln = (double[][][])sampleResult[1];
        // NOTE(review): presumably the temperature at which 1/(k_B T) = 1 in kcal/mol units, matching beta = 1 — confirm.
        double[] temps = new double[]{503.21953349876577};
        String rootPath = new File("").getAbsolutePath();
        File outputPath = new File(rootPath + "/testing/mbar/data/harmonic_oscillators/mbarFiles");
        if (!outputPath.exists() && !outputPath.mkdirs()) {
            throw new RuntimeException("Failed to create directory: " + String.valueOf(outputPath));
        }
        double[] temperatures = new double[equilPositions.length];
        Arrays.fill(temperatures, temps[0]);
        // One energy file per state.
        for (int i = 0; i < u_kln.length; ++i) {
            File file = new File(outputPath, "energies_" + i + ".mbar");
            MultistateBennettAcceptanceRatio.writeFile(u_kln[i], file, temperatures[i]);
        }
        System.out.print("Creating MBAR instance and .estimateDG(false) with standard tolerance & zeros seeding...");
        MultistateBennettAcceptanceRatio mbar = new MultistateBennettAcceptanceRatio(equilPositions, u_kln, temps, 1.0E-7, SeedType.ZEROS);
        System.out.println("done! \n\n");
        double[] mbarFEEstimates = Arrays.copyOf(mbar.mbarFEEstimates, mbar.mbarFEEstimates.length);
        double[] mbarEnthalpyDiff = Arrays.copyOf(mbar.mbarEnthalpy, mbar.mbarEnthalpy.length);
        double[] mbarEntropyDiff = Arrays.copyOf(mbar.mbarEntropy, mbar.mbarEntropy.length);
        double[] mbarUncertainties = Arrays.copyOf(mbar.mbarUncertainties, mbar.mbarUncertainties.length);
        // Shallow copy: the rows are shared with mbar.uncertaintyMatrix.
        double[][] mbarDiffMatrix = (double[][])Arrays.copyOf(mbar.uncertaintyMatrix, mbar.uncertaintyMatrix.length);
        double[] analyticalFreeEnergies = testCase.analyticalFreeEnergies();
        double[] error = new double[analyticalFreeEnergies.length];
        for (int i = 0; i < error.length; ++i) {
            error[i] = analyticalFreeEnergies[i] - mbarFEEstimates[i];
        }
        // Analytical entropy changes between consecutive states, compared with the MBAR values.
        double[] temp = testCase.analyticalEntropies(0);
        double[] analyticEntropyDiff = new double[temp.length - 1];
        double[] errorEntropy = new double[temp.length - 1];
        for (int i = 0; i < analyticEntropyDiff.length; ++i) {
            analyticEntropyDiff[i] = temp[i + 1] - temp[i];
            errorEntropy[i] = analyticEntropyDiff[i] - mbarEntropyDiff[i];
        }
        System.out.println("STANDARD THERMODYNAMIC CALCULATIONS: \n");
        System.out.println("Analytical Free Energies: " + Arrays.toString(analyticalFreeEnergies));
        System.out.println("MBAR Free Energies:       " + Arrays.toString(mbarFEEstimates));
        System.out.println("Free Energy Error:        " + Arrays.toString(error));
        System.out.println();
        System.out.println("MBAR dG:                  " + Arrays.toString(mbar.mbarFEDifferenceEstimates));
        System.out.println("MBAR Uncertainties:       " + Arrays.toString(mbarUncertainties));
        System.out.println("MBAR Enthalpy Changes:    " + Arrays.toString(mbarEnthalpyDiff));
        System.out.println();
        System.out.println("MBAR Entropy Changes:     " + Arrays.toString(mbarEntropyDiff));
        System.out.println("Analytic Entropy Changes: " + Arrays.toString(analyticEntropyDiff));
        System.out.println("Entropy Error:            " + Arrays.toString(errorEntropy));
        System.out.println();
        System.out.println("Uncertainty Diff Matrix: ");
        for (double[] matrix : mbarDiffMatrix) {
            System.out.println(Arrays.toString(matrix));
        }
        System.out.println("\n\n");
        System.out.println("MBAR DERIVED OBSERVABLES: \n");
        // Multi-data observable: the u_kln energies themselves, with uncertainties.
        mbar.setObservableData(u_kln, true, true);
        double[] mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages, mbar.mbarObservableEnsembleAverages.length);
        double[] mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties, mbar.mbarObservableEnsembleAverageUncertainties.length);
        System.out.println("Multi-Data Observable Example u_kln:");
        System.out.println("MBAR Observable Ensemble Averages (Potential):              " + Arrays.toString(mbarObservableEnsembleAverages));
        System.out.println("Analytical Observable Ensemble Averages (Potential):        " + Arrays.toString(testCase.analyticalObservable("potential energy")));
        System.out.println("MBAR Observable Ensemble Average Uncertainties (Potential): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
        System.out.println();
        // Single-data observable: positions. Only xAll[0] is filled (each row is x_n); the other
        // blocks stay zero. NOTE(review): presumably setObservableData(..., false, ...) reads only the first block — confirm.
        double[][][] xAll = new double[equilPositions.length][equilPositions.length][x_n.length];
        for (int i = 0; i < xAll[0].length; ++i) {
            for (int j = 0; j < xAll[0][0].length; ++j) {
                xAll[0][i][j] = x_n[j];
            }
        }
        mbar.setObservableData(xAll, false, true);
        mbarObservableEnsembleAverages = Arrays.copyOf(mbar.mbarObservableEnsembleAverages, mbar.mbarObservableEnsembleAverages.length);
        mbarObservableEnsembleAverageUncertainties = Arrays.copyOf(mbar.mbarObservableEnsembleAverageUncertainties, mbar.mbarObservableEnsembleAverageUncertainties.length);
        System.out.println("Single-Data Observable Example x_n:");
        System.out.println("MBAR Observable Ensemble Averages (Position):              " + Arrays.toString(mbarObservableEnsembleAverages));
        System.out.println("Analytical Observable Ensemble Averages (Position):        " + Arrays.toString(testCase.analyticalMeans()));
        System.out.println("MBAR Observable Ensemble Average Uncertainties (Position): " + Arrays.toString(mbarObservableEnsembleAverageUncertainties));
        System.out.println();
    }

    /**
     * Strategies for seeding the initial MBAR free-energy estimates.
     *
     * <p>NOTE(review): the seeding logic is in the constructor, outside this view. The names suggest
     * BAR estimates, Zwanzig estimates, or all zeros respectively — confirm.
     */
    public enum SeedType {
        /** Presumably seeds from Bennett Acceptance Ratio estimates. */
        BAR,
        /** Presumably seeds from Zwanzig (free-energy perturbation) estimates. */
        ZWANZIG,
        /** Presumably seeds every free energy with zero. */
        ZEROS
    }

    /**
     * Analytically solvable test system made of independent one-dimensional harmonic oscillators,
     * one per thermodynamic state k, with equilibrium position {@code O_k[k]} and spring constant
     * {@code K_k[k]} at inverse temperature {@code beta}. Samples from {@link #sample} can be
     * compared against the closed-form observables and free energies provided here.
     */
    public static class HarmonicOscillatorsTestCase {
        private final double beta;
        private final double[] equilPositions;
        private final int n_states;
        private final double[] springConstants;

        /**
         * Creates the test case.
         *
         * @param O_k  equilibrium position of each state (copied).
         * @param K_k  spring constant of each state (copied); must have the same length as O_k.
         * @param beta inverse temperature used for sampling and analytical results.
         * @throws NullPointerException     if either array is null.
         * @throws IllegalArgumentException if the array lengths differ.
         */
        public HarmonicOscillatorsTestCase(double[] O_k, double[] K_k, double beta) {
            Objects.requireNonNull(O_k, "O_k");
            Objects.requireNonNull(K_k, "K_k");
            if (K_k.length != O_k.length) {
                throw new IllegalArgumentException("Lengths of K_k and O_k should be equal");
            }
            this.beta = beta;
            // Defensive copies so later caller mutations cannot change the analytical answers.
            this.equilPositions = O_k.clone();
            this.n_states = O_k.length;
            this.springConstants = K_k.clone();
        }

        /**
         * Returns the analytical mean position of each state (its equilibrium position).
         *
         * @return a fresh copy of the equilibrium positions.
         */
        public double[] analyticalMeans() {
            return equilPositions.clone();
        }

        /**
         * Returns the analytical standard deviation of the position in each state,
         * {@code sqrt(1 / (beta * K_k))}.
         *
         * @return per-state standard deviations.
         */
        public double[] analyticalStandardDeviations() {
            double[] deviations = new double[n_states];
            for (int i = 0; i < n_states; ++i) {
                deviations[i] = Math.sqrt(1.0 / (beta * springConstants[i]));
            }
            return deviations;
        }

        /**
         * Returns the analytical ensemble average of a named observable for every state.
         *
         * @param observable one of {@code "position"}, {@code "potential energy"},
         *                   {@code "position^2"} or {@code "RMS displacement"}.
         * @return per-state analytical averages.
         * @throws IllegalArgumentException if the observable name is not recognized (previously a
         *                                  silent all-zero array, which hid caller typos).
         */
        public double[] analyticalObservable(String observable) {
            double[] result = new double[n_states];
            switch (observable) {
                case "position":
                    return analyticalMeans();
                case "potential energy":
                    // Equipartition: <U> = kT / 2 = 0.5 / beta for each 1D oscillator.
                    Arrays.fill(result, 0.5 / beta);
                    return result;
                case "position^2":
                    // <x^2> = Var(x) + <x>^2.
                    for (int i = 0; i < n_states; ++i) {
                        result[i] = 1.0 / (beta * springConstants[i]) + Math.pow(equilPositions[i], 2.0);
                    }
                    return result;
                case "RMS displacement":
                    return analyticalStandardDeviations();
                default:
                    throw new IllegalArgumentException("Unknown observable: " + observable);
            }
        }

        /**
         * Returns the analytical dimensionless free energies relative to state 0.
         *
         * @return per-state free energies {@code f_k - f_0}.
         */
        public double[] analyticalFreeEnergies() {
            return analyticalFreeEnergies(0);
        }

        /**
         * Returns the analytical dimensionless free energies
         * {@code f_k = -0.5 * ln(2 * pi / (beta * K_k))}, shifted so that the given reference state
         * has a free energy of zero.
         *
         * @param subtractComponent index of the reference state.
         * @return per-state free energies {@code f_k - f_ref}; empty if there are no states.
         * @throws IllegalArgumentException if subtractComponent is out of range for a non-empty system.
         */
        public double[] analyticalFreeEnergies(int subtractComponent) {
            double[] fe = new double[n_states];
            if (n_states == 0) {
                return fe;
            }
            if (subtractComponent < 0 || subtractComponent >= n_states) {
                throw new IllegalArgumentException("subtractComponent " + subtractComponent
                        + " is out of range for " + n_states + " states");
            }
            for (int i = 0; i < n_states; ++i) {
                fe[i] = -0.5 * Math.log(2.0 * Math.PI / (beta * springConstants[i]));
            }
            double reference = fe[subtractComponent];
            for (int i = 0; i < n_states; ++i) {
                fe[i] -= reference;
            }
            return fe;
        }

        /**
         * Returns the analytical entropies as {@code <U>_k - f_k}, with free energies taken relative
         * to {@code subtractComponent}.
         *
         * <p>NOTE(review): {@code <U>_k} is {@code 0.5 / beta} (energy units) while {@code f_k} is
         * dimensionless, so the two are only commensurate when {@code beta == 1}; confirm the
         * intended convention.
         *
         * @param subtractComponent index of the reference state for the free energies.
         * @return per-state entropies.
         */
        public double[] analyticalEntropies(int subtractComponent) {
            // Previously queried the nonexistent observable "analytical entropy" (all zeros) and
            // ignored subtractComponent.
            double[] potentialEnergy = analyticalObservable("potential energy");
            double[] freeEnergies = analyticalFreeEnergies(subtractComponent);
            double[] entropies = new double[n_states];
            for (int i = 0; i < n_states; ++i) {
                entropies[i] = potentialEnergy[i] - freeEnergies[i];
            }
            return entropies;
        }

        /**
         * Draws independent Gaussian samples from each state. The position in state k is
         * normal with mean {@code O_k[k]} and standard deviation {@code sqrt(1 / (beta * K_k[k]))}.
         * Every sample is then evaluated in every state l using the reduced potential
         * {@code beta * 0.5 * K_k[l] * (x - O_k[l])^2}.
         *
         * @param N_k  number of samples to draw from each state; length must equal the number of states.
         * @param mode {@code "u_kn"} returns {x_n, u_kn, N_k, s_n}; {@code "u_kln"} returns
         *             {x_n, u_kln, N_k, s_n, u_kn}, where u_kln is padded with NaN past N_k[k].
         * @param seed random seed; {@code null} uses an unseeded generator.
         * @return the sampled data in the layout selected by {@code mode}.
         * @throws IllegalArgumentException if mode is unknown, N_k has the wrong length, or a count is negative.
         */
        public Object[] sample(int[] N_k, String mode, Long seed) {
            Objects.requireNonNull(N_k, "N_k");
            // Validate up front rather than after all the sampling work is done.
            if (!"u_kn".equals(mode) && !"u_kln".equals(mode)) {
                throw new IllegalArgumentException("Unknown mode: " + mode);
            }
            if (N_k.length != n_states) {
                throw new IllegalArgumentException("Length of N_k (" + N_k.length
                        + ") should equal the number of states (" + n_states + ")");
            }
            int N_max = 0;
            int N_tot = 0;
            for (int N : N_k) {
                if (N < 0) {
                    throw new IllegalArgumentException("Sample counts must be non-negative: " + N);
                }
                N_max = Math.max(N_max, N);
                N_tot += N;
            }

            Random random = (seed != null) ? new Random(seed) : new Random();
            double[][] x_kn = new double[n_states][N_max];
            double[][] u_kn = new double[n_states][N_tot];
            double[][][] u_kln = new double[n_states][n_states][N_max];
            double[] x_n = new double[N_tot];
            int[] s_n = new int[N_tot];

            int index = 0;
            for (int k = 0; k < n_states; ++k) {
                double x0 = equilPositions[k];
                double sigma = Math.sqrt(1.0 / (beta * springConstants[k]));
                for (int n = 0; n < N_k[k]; ++n) {
                    double x = x0 + random.nextGaussian() * sigma;
                    x_kn[k][n] = x;
                    x_n[index] = x;
                    s_n[index] = k;
                    for (int l = 0; l < n_states; ++l) {
                        double u = beta * 0.5 * springConstants[l] * Math.pow(x - equilPositions[l], 2.0);
                        u_kln[k][l][n] = u;
                        u_kn[l][index] = u;
                    }
                    ++index;
                }
                // Pad the ragged tail of the u_kln layout so unused slots are clearly invalid.
                for (int l = 0; l < n_states; ++l) {
                    Arrays.fill(u_kln[k][l], N_k[k], N_max, Double.NaN);
                }
            }

            if ("u_kn".equals(mode)) {
                return new Object[]{x_n, u_kn, N_k, s_n};
            }
            return new Object[]{x_n, u_kln, N_k, s_n, u_kn};
        }
    }
}

