/*
 * Decompiled with CFR 0.152.
 */
package ffx.algorithms.commands;

import ffx.algorithms.cli.AlgorithmsCommand;
import ffx.crystal.CrystalPotential;
import ffx.numerics.Potential;
import ffx.numerics.estimator.BennettAcceptanceRatio;
import ffx.numerics.estimator.BootstrappableEstimator;
import ffx.numerics.estimator.EstimateBootstrapper;
import ffx.numerics.estimator.MBARFilter;
import ffx.numerics.estimator.MultistateBennettAcceptanceRatio;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.LambdaInterface;
import ffx.potential.cli.AlchemicalOptions;
import ffx.potential.cli.TopologyOptions;
import ffx.potential.parsers.SystemFilter;
import ffx.potential.utils.PotentialsFunctions;
import ffx.utilities.FFXBinding;
import java.io.File;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.commons.io.FilenameUtils;
import picocli.CommandLine;

/**
 * Command that evaluates a free energy change with the Multistate Bennett Acceptance Ratio (MBAR).
 *
 * <p>It operates in one of two modes, selected by the first positional argument:
 *
 * <ul>
 *   <li>a <b>directory</b> of MBAR/BAR files: the files are read with {@link MBARFilter}, and MBAR
 *       results are logged. Depending on options, optional BAR, bootstrap, and period-comparison
 *       results are also logged.
 *   <li>one or more <b>coordinate files</b> (PDB/XYZ) next to matching {@code .arc} archives: every
 *       snapshot is evaluated at {@code --numLambda} evenly spaced lambda values and the energies
 *       (and optionally dE/dL) are written to {@code energy_<window>.mbar} /
 *       {@code derivatives_<window>.mbar}.
 * </ul>
 */
@CommandLine.Command(description = {" Evaluates a free energy change with the Multistate Bennett Acceptance Ratio algorithm."}, name = "MBAR")
public class MBAR extends AlgorithmsCommand {

    /** Lambda values within this distance of 0 or 1 are treated as end points. */
    private static final double LAMBDA_ENDPOINT_TOLERANCE = 1.0E-6;

    /**
     * Amount by which end-point lambda values are moved toward the interior before evaluation.
     * NOTE(review): presumably avoids evaluating exactly at lambda = 0 or 1; confirm the intent of
     * this particular value.
     */
    private static final double LAMBDA_ENDPOINT_OFFSET = 0.00275;

    /**
     * Third argument passed to {@code MultistateBennettAcceptanceRatio.writeFile}.
     * NOTE(review): presumably the temperature in Kelvin; confirm against writeFile.
     */
    private static final double WRITE_FILE_TEMPERATURE = 298.0;

    @CommandLine.Mixin
    private AlchemicalOptions alchemicalOptions;
    @CommandLine.Mixin
    private TopologyOptions topologyOptions;
    @CommandLine.Option(names = {"--bar"}, paramLabel = "true", description = {"Run BAR calculation as well using a subset of the MBAR data."})
    private boolean bar = true;
    @CommandLine.Option(names = {"--convergence"}, paramLabel = "false", description = {"Run MBAR multiple times across different time periods of the data to examine the change in FE over time."})
    private boolean convergence = false;
    @CommandLine.Option(names = {"--numBootstrap", "--nb"}, paramLabel = "0", description = {"Number of bootstrap samples to use."})
    private int numBootstrap = 0;
    @CommandLine.Option(names = {"--numLambda", "--nL", "--nw"}, paramLabel = "-1", description = {"Required for lambda energy evaluations. Ensure numLambda is consistent with the trajectory lambdas, i.e. gaps between traj can be filled easily. nL >> nTraj is recommended."})
    private int numLambda = -1;
    @CommandLine.Option(names = {"--lambdaDerivative", "--lD"}, paramLabel = "false", description = {"Calculate lambda derivatives for each snapshot."})
    private boolean lambdaDerivative = false;
    @CommandLine.Option(names = {"--continuousLambda", "--cL"}, paramLabel = "false", description = {"Data comes from continuous lambda source and only contains mbar file."})
    private boolean continuousLambda = false;
    @CommandLine.Option(names = {"--outputDir", "--oD"}, paramLabel = "", description = {"Where to place MBAR files. Default is ../mbarFiles/energy_(window#).mbar. Will write out a file called energy_0.mbar."})
    private String outputDirectory = "";
    @CommandLine.Option(names = {"--seed"}, paramLabel = "BAR", description = {"Seed MBAR calculation with this: ZEROS, ZWANZIG, BAR. Fallback to ZEROS if input does not or is unlikely to converge."})
    private String seedWith = "BAR";
    @CommandLine.Option(names = {"--tol", "--tolerance"}, paramLabel = "1e-7", description = {"Iteration change tolerance."})
    private double tol = 1.0E-7;
    @CommandLine.Option(names = {"--ss", "--startSnapshot"}, paramLabel = "-1", description = {"Start at this snapshot when reading in tinker BAR files."})
    private int startingSnapshot = -1;
    @CommandLine.Option(names = {"--es", "--endSnapshot"}, paramLabel = "-1", description = {"End at this snapshot when reading in tinker BAR files."})
    private int endingSnapshot = -1;
    @CommandLine.Option(names = {"--verbose"}, paramLabel = "false", description = {"Log weight matrices, iterations, and other details."})
    private boolean verbose = false;
    @CommandLine.Parameters(arity = "1..*", paramLabel = "files", description = {"Path to MBAR/BAR files to analyze or a PDB/XYZ in a directory with archive(s)."})
    private List<String> fileList = null;

    /** The MBAR estimator built by the last successful analysis run, or {@code null}. */
    public MultistateBennettAcceptanceRatio mbar = null;

    /** Number of topologies opened for lambda energy evaluations. */
    private int numTopologies;

    public MBAR() {
    }

    public MBAR(FFXBinding binding) {
        super(binding);
    }

    public MBAR(String[] args) {
        super(args);
    }

    /**
     * Executes the command: either evaluates trajectory energies over lambda, or analyzes a
     * directory of MBAR/BAR files.
     *
     * @return this command (for chaining); on error a SEVERE message is logged and it returns early.
     */
    public MBAR run() {
        if (!init()) {
            return this;
        }
        if (fileList == null || fileList.isEmpty()) {
            logger.severe("No path to MBAR/BAR or trajectory(s) file names specified.");
            return this;
        }
        File[] files = resolveExistingFiles();
        if (files == null) {
            return this;
        }
        // A non-directory first argument means coordinate file(s) whose archives should be evaluated.
        if (!files[0].isDirectory()) {
            evaluateTrajectories(files);
        } else {
            analyzeMBARFiles();
        }
        return this;
    }

    /**
     * Converts every positional argument to an absolute {@link File} and checks that it exists.
     *
     * @return the absolute files, or {@code null} after logging if any file is missing.
     */
    private File[] resolveExistingFiles() {
        File[] files = new File[fileList.size()];
        for (int i = 0; i < files.length; i++) {
            String fileName = fileList.get(i);
            files[i] = new File(fileName).getAbsoluteFile();
            if (!files[i].exists()) {
                logger.severe("File does not exist: " + fileName);
                return null;
            }
        }
        return files;
    }

    /**
     * Evaluates the archives next to {@code files} at {@code numLambda} lambda values and writes the
     * energies (and optionally dE/dL) as MBAR files.
     *
     * @param files absolute coordinate files, one per topology.
     */
    private void evaluateTrajectories(File[] files) {
        if (numLambda == -1) {
            logger.severe("numLambda must be specified for lambda energy evaluations.");
            return;
        }
        if (numLambda < 2) {
            // Lambda spacing is 1/(numLambda - 1); fewer than two points is undefined.
            logger.severe("numLambda must be at least 2 for lambda energy evaluations.");
            return;
        }
        File parent = files[0].getParentFile();
        File outputDir;
        int window;
        if (outputDirectory.isEmpty()) {
            // Default layout: the coordinate file lives in a directory named after its window number.
            try {
                window = Integer.parseInt(parent.getName());
            } catch (NumberFormatException e) {
                logger.severe("Could not determine the window number from directory name: " + parent.getName()
                        + ". Use --outputDir to specify where MBAR files are written.");
                return;
            }
            outputDir = new File(parent.getParentFile(), "mbarFiles");
        } else {
            outputDir = new File(outputDirectory);
            window = 0;
        }
        if (!outputDir.exists() && !outputDir.mkdirs()) {
            logger.severe("Could not create output directory: " + outputDir);
            return;
        }

        double[][][] energiesAndDerivatives = getEnergyForLambdas(files, numLambda);
        File outputFile = new File(outputDir, "energy_" + window + ".mbar");
        MultistateBennettAcceptanceRatio.writeFile(energiesAndDerivatives[0], outputFile, WRITE_FILE_TEMPERATURE);
        if (lambdaDerivative) {
            File outputDerivFile = new File(outputDir, "derivatives_" + window + ".mbar");
            MultistateBennettAcceptanceRatio.writeFile(energiesAndDerivatives[1], outputDerivFile, WRITE_FILE_TEMPERATURE);
        }
    }

    /** Reads MBAR/BAR files from the first positional argument and logs all requested estimates. */
    private void analyzeMBARFiles() {
        File path = new File(fileList.get(0));
        if (!path.exists()) {
            logger.severe("Path to MBAR/BAR fileNames does not exist: " + String.valueOf(path));
            return;
        }
        if (!(path.isDirectory() || (path.isFile() && path.canRead()))) {
            logger.severe("Path to MBAR/BAR fileNames is not accessible: " + String.valueOf(path));
            return;
        }
        MBARFilter filter = new MBARFilter(path, continuousLambda);
        if (startingSnapshot >= 0) {
            filter.setStartSnapshot(startingSnapshot);
            logger.info("Starting with snapshot index: " + startingSnapshot);
        }
        if (endingSnapshot >= 0) {
            filter.setEndSnapshot(endingSnapshot);
            logger.info("Ending with snapshot index: " + endingSnapshot);
        }
        MultistateBennettAcceptanceRatio.SeedType seed = parseSeed();
        if (seed == null) {
            return;
        }
        MultistateBennettAcceptanceRatio.VERBOSE = verbose;
        mbar = filter.getMBAR(seed, tol);
        if (mbar == null) {
            logger.severe("Could not create MBAR object.");
            return;
        }

        logMBARResults();
        logObservableData(filter);
        logger.info("\n");
        if (bar) {
            logBARResults();
        }
        logger.info("\n");
        if (numBootstrap > 0) {
            logBootstrapResults();
        }
        if (convergence) {
            logPeriodComparison(filter, seed);
        }
    }

    /**
     * Parses {@code --seed} (case-insensitively) into a seed type.
     *
     * @return the seed type, or {@code null} after logging if the name is not a valid constant.
     */
    private MultistateBennettAcceptanceRatio.SeedType parseSeed() {
        seedWith = seedWith.toUpperCase(Locale.ROOT);
        try {
            // Enum.valueOf throws (never returns null) for unknown names.
            return MultistateBennettAcceptanceRatio.SeedType.valueOf(seedWith);
        } catch (IllegalArgumentException e) {
            logger.severe("Invalid seed type: " + seedWith);
            return null;
        }
    }

    /** Logs MBAR free energy differences, enthalpy/entropy decomposition, and the uncertainty matrix. */
    private void logMBARResults() {
        logger.info("\n MBAR Results:");
        logger.info(String.format(" Total dG = %10.4f +/- %10.4f kcal/mol\n",
                mbar.getTotalFreeEnergyDifference(), mbar.getTotalFEDifferenceUncertainty()));
        logPairs("   dG %3d = %10.4f +/- %10.4f kcal/mol",
                mbar.getFreeEnergyDifferences(), mbar.getFEDifferenceUncertainties());

        logger.info("\n MBAR Enthalpy & Entropy Results:");
        double[] enthalpies = mbar.getEnthalpyDifferences();
        double[] entropies = mbar.getBinEntropies();
        logger.info(String.format(" Total dG = %10.4f (dH) - %10.4f (TdS) kcal/mol\n", sum(enthalpies), sum(entropies)));
        logPairs("   dG %3d = %10.4f (dH) - %10.4f (TdS) kcal/mol", enthalpies, entropies);

        logger.info("\n MBAR uncertainty between all i & j: ");
        for (double[] row : mbar.getUncertaintyMatrix()) {
            StringBuilder sb = new StringBuilder("    [");
            for (double value : row) {
                sb.append(String.format(" %6.5f ", value));
            }
            sb.append("]");
            logger.info(sb.toString());
        }
    }

    /** Reads optional observable/bias data through {@code filter} and logs observable averages. */
    private void logObservableData(MBARFilter filter) {
        boolean observableData = filter.readObservableData(true, false, true);
        if (observableData) {
            logger.info("\n Observable data read in.");
        }
        boolean biasData = filter.readObservableData(true, true, false);
        if (biasData) {
            logger.info(" Bias data read in.");
        }
        if (observableData) {
            logger.info("\n MBAR Observable Data: ");
            double[] observableValues = mbar.getObservationEnsembleAverages();
            for (int i = 0; i < observableValues.length; i++) {
                logger.info(String.format("     %3d = %10.4f ", i, observableValues[i]));
            }
            logger.info(" Integral:    " + mbar.getTIIntegral());
        }
    }

    /** Logs BAR estimates derived from the MBAR data; failures are logged as warnings. */
    private void logBARResults() {
        try {
            logger.info("\n BAR Results:");
            BennettAcceptanceRatio barEstimate = mbar.getBAR();
            logger.info(String.format(" Total dG = %10.4f +/- %10.4f kcal/mol\n",
                    barEstimate.getTotalFreeEnergyDifference(), barEstimate.getTotalFEDifferenceUncertainty()));
            logPairs("   dG %3d = %10.4f +/- %10.4f kcal/mol",
                    barEstimate.getFreeEnergyDifferences(), barEstimate.getFEDifferenceUncertainties());
            double[] enthalpies = barEstimate.getEnthalpyDifferences();
            logger.info(String.format("\n Total dH = %10.4f kcal/mol\n", sum(enthalpies)));
            for (int i = 0; i < enthalpies.length; i++) {
                logger.info(String.format("   dH %3d = %10.4f kcal/mol", i, enthalpies[i]));
            }
        } catch (Exception e) {
            logger.warning(" BAR calculation failed to converge: " + e);
        }
    }

    /** Bootstraps the MBAR (and, if requested, BAR) estimates {@code numBootstrap} times and logs them. */
    private void logBootstrapResults() {
        EstimateBootstrapper bootstrapper = new EstimateBootstrapper((BootstrappableEstimator) mbar);
        bootstrapper.bootstrap((long) numBootstrap);
        logger.info("\n MBAR Bootstrap Results from " + numBootstrap + " Samples:");
        logger.info(String.format(" Total dG = %10.4f +/- %10.4f kcal/mol",
                bootstrapper.getTotalFreeEnergyDifference(), bootstrapper.getTotalFEDifferenceUncertainty()));
        logPairs("    dG %3d = %10.4f +/- %10.4f kcal/mol",
                bootstrapper.getFreeEnergyDifferences(), bootstrapper.getFEDifferenceStdDevs());
        logger.info("\n");
        if (!bar) {
            return;
        }
        try {
            logger.info("\n BAR Bootstrap Results:");
            EstimateBootstrapper barBootstrapper = new EstimateBootstrapper((BootstrappableEstimator) mbar.getBAR());
            barBootstrapper.bootstrap((long) numBootstrap);
            logger.info(String.format(" Total dG = %10.4f +/- %10.4f kcal/mol",
                    barBootstrapper.getTotalFreeEnergyDifference(), barBootstrapper.getTotalFEDifferenceUncertainty()));
            logPairs("    dG %3d = %10.4f +/- %10.4f kcal/mol",
                    barBootstrapper.getFreeEnergyDifferences(), barBootstrapper.getFEDifferenceStdDevs());
        } catch (Exception e) {
            logger.warning(" BAR calculation failed to converge: " + e);
        }
    }

    /**
     * Runs MBAR separately over the periods produced by {@code filter.getPeriodComparisonMBAR} and
     * logs each dG per period plus per-period totals.
     */
    private void logPeriodComparison(MBARFilter filter, MultistateBennettAcceptanceRatio.SeedType seed) {
        double[][] dGPeriod;
        // FORCE_ZEROS_SEED is global; restore it so later MBAR runs in this JVM are unaffected.
        boolean previousForceZeros = MultistateBennettAcceptanceRatio.FORCE_ZEROS_SEED;
        try {
            MultistateBennettAcceptanceRatio.FORCE_ZEROS_SEED = true;
            MultistateBennettAcceptanceRatio[] periods = filter.getPeriodComparisonMBAR(seed, tol);
            dGPeriod = new double[periods.length][];
            for (int p = 0; p < periods.length; p++) {
                dGPeriod[p] = periods[p].getFreeEnergyDifferences();
            }
        } finally {
            MultistateBennettAcceptanceRatio.FORCE_ZEROS_SEED = previousForceZeros;
        }
        if (dGPeriod.length == 0) {
            logger.warning(" No period comparison data available.");
            return;
        }

        logger.info("\n MBAR Period Comparison Results:");
        logger.info(String.format("     %10d%%%10d%%%10d%%%10d%%%10d%%%10d%%%10d%%%10d%%%10d%%%10d%% ", 10, 20, 30, 40, 50, 60, 70, 80, 90, 100));
        // One total per period (column), summed over all dG rows.
        double[] totals = new double[dGPeriod.length];
        for (int i = 0; i < dGPeriod[0].length; i++) {
            StringBuilder sb = new StringBuilder(" dG_").append(i).append(": ");
            for (int p = 0; p < dGPeriod.length; p++) {
                sb.append(String.format("%10.4f ", dGPeriod[p][i]));
                totals[p] += dGPeriod[p][i];
            }
            logger.info(sb.toString());
        }
        StringBuilder totalsSB = new StringBuilder();
        for (double total : totals) {
            totalsSB.append(String.format("%10.4f ", total));
        }
        logger.info("");
        logger.info("  Tot: " + totalsSB);
    }

    /**
     * Logs one line per index using {@code format}, which receives (index, first[i], second[i]).
     * The number of lines is {@code first.length}.
     */
    private void logPairs(String format, double[] first, double[] second) {
        for (int i = 0; i < first.length; i++) {
            logger.info(String.format(format, i, first[i], second[i]));
        }
    }

    private static double sum(double[] values) {
        double sum = 0.0;
        for (double value : values) {
            sum += value;
        }
        return sum;
    }

    /**
     * Moves a lambda value that lies at an end point (0 or 1, within tolerance) inward by
     * {@link #LAMBDA_ENDPOINT_OFFSET}; interior values are returned unchanged.
     */
    private static double offsetEndpoint(double lambda) {
        if (lambda <= LAMBDA_ENDPOINT_TOLERANCE) {
            return lambda + LAMBDA_ENDPOINT_OFFSET;
        }
        if (1.0 - lambda < LAMBDA_ENDPOINT_TOLERANCE) {
            return lambda - LAMBDA_ENDPOINT_OFFSET;
        }
        return lambda;
    }

    /**
     * Opens the given topologies, reads every snapshot of their {@code .arc} archives, and evaluates
     * the energy (and optionally dE/dL) at {@code nLambda} evenly spaced lambda values in [0, 1].
     *
     * @param files absolute coordinate files, one per topology; archives share their base names.
     * @param nLambda number of lambda values (must be at least 2).
     * @return {@code {energy, lambdaDerivatives}}, each indexed {@code [lambda][snapshot]};
     *     derivatives are zero unless {@code --lambdaDerivative} is set.
     */
    private double[][][] getEnergyForLambdas(File[] files, int nLambda) {
        numTopologies = files.length;
        int threadsPerTopology = topologyOptions.getThreadsPerTopology(numTopologies);
        MolecularAssembly[] molecularAssemblies = new MolecularAssembly[numTopologies];
        SystemFilter[] openers = new SystemFilter[numTopologies];
        alchemicalOptions.setAlchemicalProperties();
        topologyOptions.setAlchemicalProperties(numTopologies);
        if (numTopologies == 2) {
            logger.info(" Initializing two topologies for each window.");
        } else {
            logger.info(" Initializing a single topology for each window.");
        }
        for (int i = 0; i < numTopologies; i++) {
            // Pass the full path so the file is found regardless of the working directory.
            molecularAssemblies[i] = alchemicalOptions.openFile((PotentialsFunctions) algorithmFunctions,
                    topologyOptions, threadsPerTopology, files[i].getAbsolutePath(), i);
            openers[i] = algorithmFunctions.getFilter();
        }
        StringBuilder sb = new StringBuilder(String.format("\n Performing FEP evaluations for: %s\n ", Arrays.toString(files)));
        CrystalPotential potential = (CrystalPotential) topologyOptions.assemblePotential(molecularAssemblies, sb);
        for (int j = 0; j < numTopologies; j++) {
            File archiveFile = new File(FilenameUtils.removeExtension(files[j].getAbsolutePath()) + ".arc");
            openers[j].setFile(archiveFile);
            molecularAssemblies[j].setFile(archiveFile);
        }

        int nSnapshots = openers[0].countNumModels();
        double[] x = new double[potential.getNumberOfVariables()];
        // The gradient is discarded; one buffer is reused for every evaluation.
        double[] gradient = new double[x.length];
        double[] lambdaValues = new double[nLambda];
        double[][] energy = new double[nLambda][nSnapshots];
        double[][] lambdaDerivatives = new double[nLambda][nSnapshots];
        for (int k = 0; k < nLambda; k++) {
            lambdaValues[k] = (double) k / (double) (nLambda - 1);
        }
        LambdaInterface lambdaInterface = (LambdaInterface) potential;
        logger.info(String.format("\n\n Performing energy evaluations for %d snapshots.", nSnapshots));
        logger.info(String.format(" Using %d lambda values.", nLambda));
        logger.info(String.format(" Using %d topologies.", numTopologies));
        logger.info(" Lambda values: " + Arrays.toString(lambdaValues));
        logger.info("");

        for (int i = 0; i < nSnapshots; i++) {
            // Rewind the archives before the first snapshot only.
            boolean resetPosition = (i == 0);
            for (SystemFilter opener : openers) {
                opener.readNext(resetPosition, false);
            }
            x = potential.getCoordinates(x);
            StringBuilder energyLog = new StringBuilder().append("Snapshot ").append(i).append(" Energy Evaluations: ");
            StringBuilder derivativeLog = new StringBuilder().append("Snapshot ").append(i).append(" Lambda Derivatives: ");
            for (int k = 0; k < nLambda; k++) {
                lambdaInterface.setLambda(offsetEndpoint(lambdaValues[k]));
                energy[k][i] = potential.energyAndGradient(x, gradient);
                if (lambdaDerivative) {
                    lambdaDerivatives[k][i] = lambdaInterface.getdEdL();
                    derivativeLog.append(" ").append(lambdaDerivatives[k][i]);
                }
                energyLog.append(" ").append(energy[k][i]);
            }
            logger.info(energyLog.append("\n").toString());
            if (lambdaDerivative) {
                logger.info(derivativeLog.append("\n").toString());
            }
        }
        return new double[][][] {energy, lambdaDerivatives};
    }

    @Override
    public List<Potential> getPotentials() {
        return Collections.emptyList();
    }
}

