/*
 * Decompiled with CFR 0.152.
 */
package ffx.numerics.estimator;

import ffx.numerics.estimator.BootstrappableEstimator;
import ffx.numerics.estimator.EstimateBootstrapper;
import ffx.numerics.estimator.SequentialEstimator;
import ffx.numerics.estimator.Zwanzig;
import ffx.numerics.math.ScalarMath;
import ffx.numerics.math.SummaryStatistics;
import java.lang.invoke.LambdaMetafactory;
import java.util.Arrays;
import java.util.Random;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.util.FastMath;

public class BennettAcceptanceRatio
extends SequentialEstimator
implements BootstrappableEstimator {
    private static final Logger logger = Logger.getLogger(BennettAcceptanceRatio.class.getName());
    private static final double DEFAULT_TOLERANCE = 1.0E-4;
    private static final int DEFAULT_MAX_BAR_ITERATIONS = 1000;
    private final int nWindows;
    private final double tolerance;
    private final int nIterations;
    private final Zwanzig forwardsFEP;
    private final Zwanzig backwardsFEP;
    private final Random random;
    private double totalFreeEnergyDifference;
    private double totalFEDifferenceUncertainty;
    private final double[] freeEnergyDifferences;
    private final double[] freeEnergyDifferenceUncertainties;
    private final double[] enthalpyDifferences;
    private final double[] forwardZwanzigFEDifferences;
    private final double[] backwardZwanzigFEDifferences;

    public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda, double[][] eLambdaPlusdL, double[] temperature) {
        this(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, 1.0E-4);
    }

    public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda, double[][] eLambdaPlusdL, double[] temperature, double tolerance) {
        this(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, tolerance, 1000);
    }

    public BennettAcceptanceRatio(double[] lambdaValues, double[][] eLambdaMinusdL, double[][] eLambda, double[][] eLambdaPlusdL, double[] temperature, double tolerance, int nIterations) {
        super(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature);
        this.forwardsFEP = new Zwanzig(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, Zwanzig.Directionality.FORWARDS);
        this.backwardsFEP = new Zwanzig(lambdaValues, eLambdaMinusdL, eLambda, eLambdaPlusdL, temperature, Zwanzig.Directionality.BACKWARDS);
        this.nWindows = this.nStates - 1;
        this.forwardZwanzigFEDifferences = this.forwardsFEP.getFreeEnergyDifferences();
        this.backwardZwanzigFEDifferences = this.backwardsFEP.getFreeEnergyDifferences();
        this.freeEnergyDifferences = new double[this.nWindows];
        this.freeEnergyDifferenceUncertainties = new double[this.nWindows];
        this.enthalpyDifferences = new double[this.nWindows];
        this.tolerance = tolerance;
        this.nIterations = nIterations;
        this.random = new Random();
        this.estimateDG();
    }

    private static void fermiDiffIterative(double[] e0, double[] e1, double[] fermiDiffs, int len, double c, double invRT) {
        for (int i = 0; i < len; ++i) {
            fermiDiffs[i] = ScalarMath.fermiFunction(invRT * (e0[i] - e1[i] + c));
        }
        if (Arrays.stream(fermiDiffs).sum() == 0.0) {
            logger.warning(String.format(" Input Fermi with length %3d should not be permitted: c: %9.4f invRT: %9.4f Fermi output: %9.4f", len, c, invRT, Arrays.stream(fermiDiffs).sum()));
        }
    }

    private void calcAlphaForward(double[] e0, double[] e1, int len, double c, double invRT, double[] ret) {
        double alpha;
        double fsum = 0.0;
        double fvsum = 0.0;
        double fbvsum = 0.0;
        double vsum = 0.0;
        double fbsum = 0.0;
        for (int i = 0; i < len; ++i) {
            double fore = ScalarMath.fermiFunction(invRT * (e1[i] - e0[i] - c));
            double back = ScalarMath.fermiFunction(invRT * (e0[i] - e1[i] + c));
            fsum += fore;
            fvsum += fore * e0[i];
            fbvsum += fore * back * (e1[i] - e0[i]);
            vsum += e0[i];
            fbsum += fore * back;
        }
        ret[0] = alpha = fvsum - fsum * (vsum / (double)len) + fbvsum;
        ret[1] = fbsum;
    }

    private void calcAlphaBackward(double[] e0, double[] e1, int len, double c, double invRT, double[] ret) {
        double alpha;
        double bsum = 0.0;
        double bvsum = 0.0;
        double fbvsum = 0.0;
        double vsum = 0.0;
        double fbsum = 0.0;
        for (int i = 0; i < len; ++i) {
            double fore = ScalarMath.fermiFunction(invRT * (e1[i] - e0[i] - c));
            double back = ScalarMath.fermiFunction(invRT * (e0[i] - e1[i] + c));
            bsum += back;
            bvsum += back * e1[i];
            fbvsum += fore * back * (e1[i] - e0[i]);
            vsum += e1[i];
            fbsum += fore * back;
        }
        ret[0] = alpha = bvsum - bsum * (vsum / (double)len) - fbvsum;
        ret[1] = fbsum;
    }

    private static void fermiDiffBootstrap(double[] e0, double[] e1, double[] fermiDiffs, int len, double c, double invRT, int[] bootstrapSamples) {
        for (int indexI = 0; indexI < len; ++indexI) {
            int i = bootstrapSamples[indexI];
            fermiDiffs[indexI] = ScalarMath.fermiFunction(invRT * (e0[i] - e1[i] + c));
        }
    }

    private static double uncertaintyCalculation(double meanFermi, double meanSqFermi, int len) {
        double sqMeanFermi = meanFermi * meanFermi;
        return (meanSqFermi - sqMeanFermi) / (double)len / sqMeanFermi;
    }

    public Zwanzig getInitialBackwardsGuess() {
        return this.backwardsFEP;
    }

    public Zwanzig getInitialForwardsGuess() {
        return this.forwardsFEP;
    }

    @Override
    public BennettAcceptanceRatio copyEstimator() {
        return new BennettAcceptanceRatio(this.lamValues, this.eLambdaMinusdL, this.eLambda, this.eLambdaPlusdL, this.temperatures, this.tolerance, this.nIterations);
    }

    @Override
    public final void estimateDG(boolean randomSamples) {
        double cumDG = 0.0;
        Arrays.fill(this.freeEnergyDifferences, 0.0);
        Arrays.fill(this.freeEnergyDifferenceUncertainties, 0.0);
        Arrays.fill(this.enthalpyDifferences, 0.0);
        Level warningLevel = randomSamples ? Level.FINE : Level.WARNING;
        for (int i = 0; i < this.nWindows; ++i) {
            double hBar;
            if (Double.isNaN(this.forwardZwanzigFEDifferences[i]) || Double.isInfinite(this.forwardZwanzigFEDifferences[i]) || Double.isNaN(this.backwardZwanzigFEDifferences[i]) || Double.isInfinite(this.backwardZwanzigFEDifferences[i])) {
                logger.warning(String.format(" Window %3d bin energies produced unreasonable value(s) for forward Zwanzig (%8.4f) and/or backward Zwanzig (%8.4f)", i, this.forwardZwanzigFEDifferences[i], this.backwardZwanzigFEDifferences[i]));
            }
            double c = 0.5 * (this.forwardZwanzigFEDifferences[i] + this.backwardZwanzigFEDifferences[i]);
            if (!randomSamples) {
                logger.fine(String.format(" BAR Iteration Seed: %12.4f Kcal/mol", c));
            }
            double cold = c;
            int len0 = this.eLambda[i].length;
            int len1 = this.eLambda[i + 1].length;
            if (len0 == 0 || len1 == 0) {
                this.freeEnergyDifferences[i] = c;
                logger.log(warningLevel, String.format(" Window %d has no snapshots at one end (%d, %d)!", i, len0, len1));
                continue;
            }
            double sampleRatio = (double)len0 / (double)len1;
            double[] fermi0 = new double[len0];
            double[] fermi1 = new double[len1];
            double[] ret = new double[2];
            double rta = 0.0019872042586408316 * this.temperatures[i];
            double rtb = 0.0019872042586408316 * this.temperatures[i + 1];
            double rtMean = 0.5 * (rta + rtb);
            double invRTA = 1.0 / rta;
            double invRTB = 1.0 / rtb;
            SummaryStatistics s1 = null;
            SummaryStatistics s0 = null;
            int[] bootstrapSamples0 = null;
            int[] bootstrapSamples1 = null;
            if (randomSamples) {
                bootstrapSamples0 = EstimateBootstrapper.getBootstrapIndices(len0, this.random);
                bootstrapSamples1 = EstimateBootstrapper.getBootstrapIndices(len1, this.random);
            }
            int cycleCounter = 0;
            boolean converged = false;
            while (!converged) {
                if (randomSamples) {
                    BennettAcceptanceRatio.fermiDiffBootstrap(this.eLambdaPlusdL[i], this.eLambda[i], fermi0, len0, -c, invRTA, bootstrapSamples0);
                    BennettAcceptanceRatio.fermiDiffBootstrap(this.eLambdaMinusdL[i + 1], this.eLambda[i + 1], fermi1, len1, c, invRTB, bootstrapSamples1);
                } else {
                    BennettAcceptanceRatio.fermiDiffIterative(this.eLambdaPlusdL[i], this.eLambda[i], fermi0, len0, -c, invRTA);
                    BennettAcceptanceRatio.fermiDiffIterative(this.eLambdaMinusdL[i + 1], this.eLambda[i + 1], fermi1, len1, c, invRTB);
                }
                s0 = new SummaryStatistics(fermi0);
                s1 = new SummaryStatistics(fermi1);
                double ratio = s1.sum / s0.sum;
                boolean bl = converged = FastMath.abs((double)((c += rtMean * FastMath.log((double)(sampleRatio * ratio))) - cold)) < this.tolerance;
                if (!randomSamples && !converged && ++cycleCounter > this.nIterations) {
                    throw new IllegalArgumentException(String.format(" BAR required too many iterations (%d) to converge! (%9.8f > %9.8f)", cycleCounter, FastMath.abs((double)(c - cold)), this.tolerance));
                }
                if (!randomSamples) {
                    logger.fine(String.format(" BAR Iteration   %2d: %12.4f Kcal/mol", cycleCounter, c));
                }
                cold = c;
            }
            this.freeEnergyDifferences[i] = c;
            cumDG += c;
            double sqFermiMean0 = new SummaryStatistics((double[])Arrays.stream((double[])fermi0).map((DoubleUnaryOperator)(DoubleUnaryOperator)LambdaMetafactory.metafactory(null, null, null, (D)D, lambda$estimateDG$0(double ), (D)D)()).toArray()).mean;
            double sqFermiMean1 = new SummaryStatistics((double[])Arrays.stream((double[])fermi1).map((DoubleUnaryOperator)(DoubleUnaryOperator)LambdaMetafactory.metafactory(null, null, null, (D)D, lambda$estimateDG$1(double ), (D)D)()).toArray()).mean;
            this.freeEnergyDifferenceUncertainties[i] = FastMath.sqrt((double)(BennettAcceptanceRatio.uncertaintyCalculation(s0.mean, sqFermiMean0, len0) + BennettAcceptanceRatio.uncertaintyCalculation(s1.mean, sqFermiMean1, len1)));
            this.calcAlphaForward(this.eLambda[i], this.eLambdaPlusdL[i], len0, c, invRTA, ret);
            double alpha0 = ret[0];
            double fbsum0 = ret[1];
            this.calcAlphaBackward(this.eLambdaMinusdL[i + 1], this.eLambda[i + 1], len1, c, invRTB, ret);
            double alpha1 = ret[0];
            double fbsum1 = ret[1];
            this.enthalpyDifferences[i] = hBar = (alpha0 - alpha1) / (fbsum0 + fbsum1);
        }
        this.totalFreeEnergyDifference = cumDG;
        this.totalFEDifferenceUncertainty = FastMath.sqrt((double)Arrays.stream(this.freeEnergyDifferenceUncertainties).map(d -> d * d).sum());
    }

    @Override
    public double[] getFreeEnergyDifferences() {
        return Arrays.copyOf(this.freeEnergyDifferences, this.nWindows);
    }

    @Override
    public double[] getFEDifferenceUncertainties() {
        return Arrays.copyOf(this.freeEnergyDifferenceUncertainties, this.nWindows);
    }

    @Override
    public double getTotalFreeEnergyDifference() {
        return this.totalFreeEnergyDifference;
    }

    @Override
    public double getTotalFEDifferenceUncertainty() {
        return this.totalFEDifferenceUncertainty;
    }

    @Override
    public int getNumberOfBins() {
        return this.nWindows;
    }

    @Override
    public double getTotalEnthalpyDifference() {
        return this.getTotalEnthalpyDifference(this.enthalpyDifferences);
    }

    @Override
    public double[] getEnthalpyDifferences() {
        return this.enthalpyDifferences;
    }

    private static /* synthetic */ double lambda$estimateDG$1(double d) {
        return d * d;
    }

    private static /* synthetic */ double lambda$estimateDG$0(double d) {
        return d * d;
    }
}

