/*
 * Decompiled with CFR 0.152.
 */
package ffx.algorithms.commands.test;

import ffx.algorithms.cli.AlgorithmsCommand;
import ffx.numerics.Potential;
import ffx.numerics.estimator.EstimateBootstrapper;
import ffx.numerics.math.RunningStatistics;
import ffx.numerics.math.SummaryStatistics;
import ffx.potential.ForceFieldEnergy;
import ffx.potential.bonded.Residue;
import ffx.potential.extended.ExtendedSystem;
import ffx.potential.parsers.XPHFilter;
import ffx.utilities.FFXBinding;
import java.io.File;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import org.apache.commons.lang3.ArrayUtils;
import picocli.CommandLine;

@CommandLine.Command(description={" Use the Rao-Blackwell estimator to get a free energy difference for residues in a CpHMD system."}, name="test.RaoBlackwellEstimator")
public class RaoBlackwellEstimator
extends AlgorithmsCommand {
    @CommandLine.Option(names={"--aFi", "--arcFile"}, paramLabel="traj", description={"A file containing the the PDB from which to build the ExtendedSystem. There is currently no default."})
    private String arcFileName = null;
    @CommandLine.Option(names={"--numSnaps"}, paramLabel="-1", defaultValue="-1", description={"Number of snapshots to use from an archive file. Default is all."})
    private int numSnaps;
    @CommandLine.Option(names={"--specifiedResidues", "--sR"}, paramLabel="<selection>", defaultValue="", description={"Specified residues to do analysis."})
    private String specified;
    @CommandLine.Option(names={"--startSnap"}, paramLabel="-1", defaultValue="-1", description={"Start energy evaluations at a snap other than 2."})
    private int startSnap;
    @CommandLine.Option(names={"--bootstrapIter"}, paramLabel="100000", defaultValue="100000", description={"Number of bootstrap iterations. Set -1 for no bootstrapping."})
    private int bootstrapIter;
    @CommandLine.Option(names={"--skip"}, paramLabel="-1", defaultValue="-1", description={"Calculate energies on snaps with this interval."})
    private int skip;
    @CommandLine.Option(names={"--writeFrequency"}, paramLabel="100", defaultValue="100", description={"Calculate the RBE and print at this snapshot read frequency."})
    private int writeFrequency;
    @CommandLine.Parameters(arity="1..*", paramLabel="files", description={"PDB input file in the same directory as the ARC file."})
    private String filename;
    private Potential forceFieldEnergy;
    private ArrayList<Double>[][] oneZeroDeltaLists;
    private ArrayList<Double>[][] tautomerOneZeroDeltaList;
    private int numESVs;
    private int numTautomerESVs;

    public RaoBlackwellEstimator() {
    }

    public RaoBlackwellEstimator(FFXBinding binding) {
        super(binding);
    }

    public RaoBlackwellEstimator(String[] args) {
        super(args);
    }

    public RaoBlackwellEstimator run() {
        int i;
        if (!this.init()) {
            return this;
        }
        File arcFile = new File(this.arcFileName);
        if (!arcFile.exists()) {
            logger.severe(String.format(" ARC file %s does not exist.", arcFile));
        } else {
            logger.info(String.format("Using ARC file %s.", arcFile));
        }
        boolean bootstrap = false;
        if (this.bootstrapIter >= 50) {
            bootstrap = true;
        } else if (this.bootstrapIter != -1) {
            logger.severe("Too few bootstrap iterations specified. Must be at least 50.");
        }
        this.activeAssembly = this.getActiveAssembly(this.filename);
        if (this.activeAssembly == null) {
            logger.info(this.helpString());
            return this;
        }
        this.forceFieldEnergy = this.activeAssembly.getPotentialEnergy();
        String filename = this.activeAssembly.getFile().getAbsolutePath();
        ExtendedSystem esvSystem = new ExtendedSystem(this.activeAssembly, 7.0, null);
        Residue specialResidue = null;
        int numberOfStates = 1;
        int[][] states = null;
        if (esvSystem.getSpecialResidueList().size() > 1) {
            logger.severe(" Multiple special residues were identified in the key file. Only one can be specified with this algorithm.");
        } else if (esvSystem.getSpecialResidueList().size() == 1) {
            int specialResidueNumber = ((Double)esvSystem.getSpecialResidueList().get(0)).intValue();
            for (Residue residue : esvSystem.getTitratingResidueList()) {
                if (residue.getResidueNumber() != specialResidueNumber) continue;
                specialResidue = residue;
            }
            if (specialResidue != null) {
                numberOfStates = !esvSystem.isTautomer(specialResidue) ? 3 : 4;
                switch (specialResidue.getName()) {
                    case "ASD": 
                    case "GLD": {
                        states = new int[3][2];
                        states[0][0] = 0;
                        states[0][1] = 0;
                        states[1][0] = 1;
                        states[1][1] = 0;
                        states[2][0] = 1;
                        states[2][1] = 1;
                        break;
                    }
                    case "HIS": {
                        states = new int[3][2];
                        states[0][0] = 0;
                        states[0][1] = 0;
                        states[1][0] = 0;
                        states[1][1] = 1;
                        states[2][0] = 1;
                        states[2][1] = 0;
                        break;
                    }
                    case "LYS": 
                    case "CYS": {
                        states = new int[2][2];
                        states[0][0] = 0;
                        states[0][1] = 0;
                        states[1][0] = 1;
                        states[1][1] = 0;
                    }
                }
            } else {
                logger.severe(" The special residue specified in the key file was not found in the titrating residue list.");
            }
        }
        int[] specifiedResidues = null;
        if (this.specified != null && !this.specified.isEmpty()) {
            String[] specifiedResiduesString = this.specified.split(",");
            specifiedResidues = new int[specifiedResiduesString.length];
            for (int i2 = 0; i2 < specifiedResiduesString.length; ++i2) {
                specifiedResidues[i2] = Integer.parseInt(specifiedResiduesString[i2].trim());
            }
        }
        ArrayList<Residue> onlyResidues = new ArrayList<Residue>();
        ArrayList<Integer> onlyResidueIndices = new ArrayList<Integer>();
        if (specifiedResidues != null) {
            for (i = 0; i < esvSystem.getTitratingResidueList().size(); ++i) {
                Residue residue = (Residue)esvSystem.getTitratingResidueList().get(i);
                if (!ArrayUtils.contains((int[])specifiedResidues, (int)residue.getResidueNumber())) continue;
                onlyResidues.add(residue);
                onlyResidueIndices.add(i);
            }
            if (onlyResidues.size() != specifiedResidues.length) {
                logger.severe("Could not find all residues from --specifiedResidues input.");
            }
        } else {
            for (i = 0; i < esvSystem.getTitratingResidueList().size(); ++i) {
                onlyResidueIndices.add(i);
            }
        }
        this.numESVs = esvSystem.getTitratingResidueList().size();
        this.oneZeroDeltaLists = new ArrayList[this.numESVs][numberOfStates + 1];
        for (i = 0; i < this.numESVs; ++i) {
            for (int j = 0; j < numberOfStates + 1; ++j) {
                this.oneZeroDeltaLists[i][j] = new ArrayList();
            }
        }
        this.numTautomerESVs = esvSystem.getTautomerizingResidueList().size();
        this.tautomerOneZeroDeltaList = new ArrayList[this.numTautomerESVs][numberOfStates + 1];
        for (i = 0; i < this.numTautomerESVs; ++i) {
            for (int j = 0; j < numberOfStates + 1; ++j) {
                this.tautomerOneZeroDeltaList[i][j] = new ArrayList();
            }
        }
        this.activeAssembly.setFile(arcFile);
        XPHFilter xphFilter = new XPHFilter(arcFile, this.activeAssembly, this.activeAssembly.getForceField(), this.activeAssembly.getProperties(), esvSystem);
        xphFilter.readFile();
        esvSystem.setFixedTitrationState(true);
        esvSystem.setFixedTautomerState(true);
        ((ForceFieldEnergy)this.forceFieldEnergy).attachExtendedSystem(esvSystem);
        logger.info(String.format(" Attached extended system with %d residues.", this.numESVs));
        double[] x = new double[this.forceFieldEnergy.getNumberOfVariables()];
        this.forceFieldEnergy.getCoordinates(x);
        this.forceFieldEnergy.energy(x, true);
        double pH = 0.0;
        String[] parts = xphFilter.getRemarkLines()[0].split(" ");
        for (int i3 = 0; i3 < parts.length; ++i3) {
            if (!parts[i3].contains("pH")) continue;
            pH = Double.parseDouble(parts[i3 + 1]);
        }
        logger.info("\n Setting constant pH to " + pH + ".");
        esvSystem.setConstantPh(pH);
        int evals = 0;
        if (this.numSnaps != -1) {
            logger.info(String.format(" Using %d snapshots.", this.numSnaps));
        } else {
            logger.info(String.format(" Using all %d snapshots.", xphFilter.countNumModels()));
        }
        while (xphFilter.readNext()) {
            if (this.startSnap != -1 && this.startSnap > 2 && evals == 0) {
                for (int i4 = 0; i4 < this.startSnap - 2; ++i4) {
                    xphFilter.readNext();
                }
            }
            this.forceFieldEnergy.getCoordinates(x);
            Iterator i4 = onlyResidueIndices.iterator();
            while (i4.hasNext()) {
                int i5 = (Integer)i4.next();
                double titrationState = 0.0;
                double tautomerState = 0.0;
                if (specialResidue != null) {
                    titrationState = esvSystem.getTitrationLambda(specialResidue);
                    tautomerState = esvSystem.getTautomerLambda(specialResidue);
                }
                for (int j = 0; j < numberOfStates; ++j) {
                    if (j != 0) {
                        esvSystem.setTitrationLambda(specialResidue, (double)states[j - 1][0], false);
                        if (numberOfStates != 3) {
                            esvSystem.setTautomerLambda(specialResidue, (double)states[j - 1][1], false);
                        }
                    }
                    ArrayList<Double> results = RaoBlackwellEstimator.getZeroOneDeltas(i5, esvSystem, (ForceFieldEnergy)this.forceFieldEnergy, x);
                    Residue res = (Residue)esvSystem.getTitratingResidueList().get(i5);
                    if (esvSystem.isTautomer(res)) {
                        this.tautomerOneZeroDeltaList[esvSystem.getTautomerizingResidueList().indexOf(res)][j].add(results.get(0));
                        this.oneZeroDeltaLists[i5][j].add(results.get(1));
                        continue;
                    }
                    this.oneZeroDeltaLists[i5][j].add(results.get(0));
                }
                if (specialResidue == esvSystem.getExtendedResidueList().get(i5)) break;
                if (specialResidue == null) continue;
                esvSystem.setTitrationLambda(specialResidue, titrationState, false);
                esvSystem.setTautomerLambda(specialResidue, tautomerState, false);
            }
            if (++evals % this.writeFrequency == 0 || evals == this.numSnaps) {
                int tautomerCount = 0;
                double[][] energyLists = new double[this.numESVs][numberOfStates];
                double[][] energyStdLists = new double[this.numESVs][numberOfStates];
                double[][] tautomerEnergyLists = new double[this.numTautomerESVs][numberOfStates];
                double[][] tautomerEnergyStdLists = new double[this.numTautomerESVs][numberOfStates];
                Iterator iterator = onlyResidueIndices.iterator();
                while (iterator.hasNext()) {
                    int i6 = (Integer)iterator.next();
                    Residue res = (Residue)esvSystem.getExtendedResidueList().get(i6);
                    logger.info("\n Performing Rao-Blackwell Estimator on " + String.valueOf(res.getAminoAcid3()) + ".");
                    if (bootstrap) {
                        logger.info("  Performing bootstrap with " + this.bootstrapIter + " iterations.");
                    } else {
                        logger.info("  Performing RBE without bootstrap. Ignore standard deviation values.");
                    }
                    for (int j = 0; j < numberOfStates; ++j) {
                        double[] bootstrapMeanStd = RaoBlackwellEstimator.RBE(this.oneZeroDeltaLists[i6][j], bootstrap, this.bootstrapIter);
                        energyLists[i6][j] = bootstrapMeanStd[0];
                        if (bootstrap) {
                            energyStdLists[i6][j] = bootstrapMeanStd[1];
                        }
                        if (!esvSystem.getTautomerizingResidueList().contains(res)) continue;
                        bootstrapMeanStd = RaoBlackwellEstimator.RBE(this.tautomerOneZeroDeltaList[esvSystem.getTautomerizingResidueList().indexOf(res)][j], bootstrap, this.bootstrapIter);
                        tautomerEnergyLists[tautomerCount][j] = bootstrapMeanStd[0];
                        if (!bootstrap) continue;
                        tautomerEnergyStdLists[tautomerCount][j] = bootstrapMeanStd[1];
                    }
                    if (esvSystem.isTautomer(res)) {
                        ++tautomerCount;
                    }
                    if (specialResidue != res) continue;
                    break;
                }
                RaoBlackwellEstimator.printResults(specialResidue, esvSystem, energyLists, energyStdLists, tautomerEnergyLists, tautomerEnergyStdLists, states, numberOfStates, this.numESVs, onlyResidueIndices);
            }
            if (this.numSnaps != -1 && evals >= this.numSnaps) break;
            if (this.skip == -1) continue;
            for (int i7 = 0; i7 < this.skip - 1; ++i7) {
                xphFilter.readNext();
            }
        }
        return this;
    }

    private static double[] RBE(ArrayList<Double> deltaUList, boolean bootstrap, int bootstrapIter) {
        double[] dArray;
        ArrayList<Double> deltaU = deltaUList;
        double temperature = 298.0;
        double boltzmann = 0.001985875;
        double beta = 1.0 / (temperature * boltzmann);
        ArrayList<Double> deltaExp = RaoBlackwellEstimator.exp(RaoBlackwellEstimator.mult(-beta, deltaU));
        ArrayList<Double> numerator = RaoBlackwellEstimator.div(RaoBlackwellEstimator.mult(beta, RaoBlackwellEstimator.mult(deltaU, deltaExp)), RaoBlackwellEstimator.subtract(1.0, deltaExp));
        ArrayList<Double> denominator = RaoBlackwellEstimator.div(RaoBlackwellEstimator.mult(beta, deltaU), RaoBlackwellEstimator.subtract(1.0, deltaExp));
        if (bootstrap) {
            dArray = RaoBlackwellEstimator.bootStrap(numerator, denominator, bootstrapIter);
        } else {
            double[] dArray2 = new double[1];
            dArray = dArray2;
            dArray2[0] = -(1.0 / beta) * Math.log(RaoBlackwellEstimator.average(numerator) / RaoBlackwellEstimator.average(denominator));
        }
        double[] deltaGRBE = dArray;
        return deltaGRBE;
    }

    private static double[] bootStrap(ArrayList<Double> numerator, ArrayList<Double> denominator, int iter) {
        RunningStatistics estimates = new RunningStatistics();
        for (int k = 0; k < iter; ++k) {
            Random rng = new Random();
            int[] trial = EstimateBootstrapper.getBootstrapIndices((int)numerator.size(), (Random)rng);
            double estimate = RaoBlackwellEstimator.estimateDg(numerator, denominator, trial);
            estimates.addValue(estimate);
        }
        SummaryStatistics stats = new SummaryStatistics(estimates);
        return new double[]{stats.mean, stats.getSd()};
    }

    private static double estimateDg(ArrayList<Double> num, ArrayList<Double> denom, int[] index) {
        double temperature = 298.0;
        double boltzmann = 0.001985875;
        double beta = 1.0 / (temperature * boltzmann);
        ArrayList<Double> numerator = new ArrayList<Double>();
        numerator.ensureCapacity(index.length);
        ArrayList<Double> denominator = new ArrayList<Double>();
        denominator.ensureCapacity(index.length);
        for (int i = 0; i < index.length; ++i) {
            numerator.add(num.get(index[i]));
            denominator.add(denom.get(index[i]));
        }
        return -(1.0 / beta) * Math.log(RaoBlackwellEstimator.average(numerator) / RaoBlackwellEstimator.average(denominator));
    }

    private static void printResults(Residue specialResidue, ExtendedSystem esvSystem, double[][] energyLists, double[][] energyStdLists, double[][] tautomerEnergyLists, double[][] tautomerStdLists, int[][] states, int numberOfStates, int numESVs, ArrayList<Integer> onlyResidueIndex) {
        logger.info("\n Rao-Blackwell Estimator Results: ");
        ArrayList<String> line = new ArrayList<String>();
        if (specialResidue != null) {
            logger.info(" Special Residue: " + specialResidue.toString());
            if (esvSystem.isTautomer(specialResidue)) {
                logger.info(String.format("  %-10s %-10s %-23s %-28s %-28s %-28s", "Residue", "Tautomer", "DeltaGTitr", "DeltaG-SpecialRes=(" + states[0][0] + "," + states[0][1] + ")", "DeltaG-SpecialRes=(" + states[1][0] + "," + states[1][1] + ")", "DeltaG-SpecialRes=(" + states[2][0] + "," + states[2][1] + ")"));
            } else {
                logger.info(String.format("  %-10s %-10s %-23s %-28s %-28s", "Residue", "Tautomer", "DeltaGTitr", "DeltaG-SpecialRes=(" + states[0][0] + "," + states[0][1] + ")", "DeltaG-SpecialRes=(" + states[1][0] + "," + states[1][1] + ")"));
            }
        } else {
            logger.info(String.format("  %-10s %-10s %-23s", "Residue", "Tautomer", "DeltaGTitr"));
        }
        int tautomerCount = 0;
        for (int i : onlyResidueIndex) {
            int j;
            Residue res = (Residue)esvSystem.getExtendedResidueList().get(i);
            line.add(res.toString());
            line.add("0");
            line.add(Double.toString(energyLists[i][0]));
            line.add(Double.toString(energyStdLists[i][0]));
            for (j = 1; j < numberOfStates; ++j) {
                line.add(Double.toString(energyLists[i][j]));
                line.add(Double.toString(energyStdLists[i][j]));
            }
            if (specialResidue != null && esvSystem.isTautomer(specialResidue)) {
                logger.info(String.format("  %-10s %-10s %-10.5f +/- %-5.3f    %-10.5f +/- %-5.3f         %-10.5f +/- %-5.3f         %-10.5f +/- %-5.3f", line.get(0), line.get(1), Double.parseDouble((String)line.get(2)), Double.parseDouble((String)line.get(3)), Double.parseDouble((String)line.get(4)), Double.parseDouble((String)line.get(5)), Double.parseDouble((String)line.get(6)), Double.parseDouble((String)line.get(7)), Double.parseDouble((String)line.get(8)), Double.parseDouble((String)line.get(9))));
            } else if (specialResidue != null) {
                logger.info(String.format("  %-10s %-10s %-10.5f +/- %-5.3f    %-10.5f +/- %-5.3f         %-10.5f +/- %-5.3f", line.get(0), line.get(1), Double.parseDouble((String)line.get(2)), Double.parseDouble((String)line.get(3)), Double.parseDouble((String)line.get(4)), Double.parseDouble((String)line.get(5)), Double.parseDouble((String)line.get(6)), Double.parseDouble((String)line.get(7))));
            } else {
                logger.info(String.format("  %-10s %-10s %-10.5f +/- %-5.3f", line.get(0), line.get(1), Double.parseDouble((String)line.get(2)), Double.parseDouble((String)line.get(3))));
            }
            line.clear();
            if (!esvSystem.isTautomer(res)) continue;
            line.add(res.toString());
            line.add("1");
            line.add(Double.toString(tautomerEnergyLists[tautomerCount][0]));
            line.add(Double.toString(tautomerStdLists[tautomerCount][0]));
            for (j = 1; j < numberOfStates; ++j) {
                line.add(Double.toString(tautomerEnergyLists[tautomerCount][j]));
                line.add(Double.toString(tautomerStdLists[tautomerCount][j]));
            }
            ++tautomerCount;
            if (specialResidue != null && esvSystem.isTautomer(specialResidue)) {
                logger.info(String.format("  %-10s %-10s %-10.5f +/- %-5.3f    %-10.5f +/- %-5.3f         %-10.5f +/- %-5.3f         %-10.5f +/- %-5.3f", line.get(0), line.get(1), Double.parseDouble((String)line.get(2)), Double.parseDouble((String)line.get(3)), Double.parseDouble((String)line.get(4)), Double.parseDouble((String)line.get(5)), Double.parseDouble((String)line.get(6)), Double.parseDouble((String)line.get(7)), Double.parseDouble((String)line.get(8)), Double.parseDouble((String)line.get(9))));
            } else if (specialResidue != null) {
                logger.info(String.format("  %-10s %-10s %-10.5f +/- %-5.3f    %-10.5f +/- %-5.3f         %-10.5f +/- %-5.3f", line.get(0), line.get(1), Double.parseDouble((String)line.get(2)), Double.parseDouble((String)line.get(3)), Double.parseDouble((String)line.get(4)), Double.parseDouble((String)line.get(5)), Double.parseDouble((String)line.get(6)), Double.parseDouble((String)line.get(7))));
            } else {
                logger.info(String.format("  %-10s %-10s %-10.5f +/- %-5.3f", line.get(0), line.get(1), Double.parseDouble((String)line.get(2)), Double.parseDouble((String)line.get(3))));
            }
            line.clear();
        }
    }

    private static ArrayList<Double> getZeroOneDeltas(int i, ExtendedSystem esv, ForceFieldEnergy forceFieldEnergy, double[] x) {
        double oneEnergy;
        double zeroEnergy;
        ArrayList<Double> deltaU = new ArrayList<Double>();
        Residue res = (Residue)esv.getExtendedResidueList().get(i);
        double titrationState = esv.getTitrationLambda(res);
        double tautomerState = esv.getTautomerLambda(res);
        if (esv.getTautomerizingResidueList().contains(res)) {
            esv.setTautomerLambda(res, 1.0, false);
            esv.setTitrationLambda(res, 0.0, false);
            zeroEnergy = forceFieldEnergy.energy(x, false);
            esv.setTitrationLambda(res, 1.0, false);
            oneEnergy = forceFieldEnergy.energy(x, false);
            esv.setTitrationLambda(res, titrationState, false);
            deltaU.add(oneEnergy - zeroEnergy);
            esv.setTautomerLambda(res, 0.0, false);
        }
        esv.setTitrationLambda(res, 0.0, false);
        zeroEnergy = forceFieldEnergy.energy(x, false);
        esv.setTitrationLambda(res, 1.0, false);
        oneEnergy = forceFieldEnergy.energy(x, false);
        esv.setTitrationLambda(res, titrationState, false);
        deltaU.add(oneEnergy - zeroEnergy);
        if (esv.getTautomerizingResidueList().contains(res)) {
            esv.setTautomerLambda(res, tautomerState, false);
        }
        return deltaU;
    }

    private static double average(ArrayList<Double> list) {
        double sum = 0.0;
        for (Double d : list) {
            sum += d.doubleValue();
        }
        return sum / (double)list.size();
    }

    private static ArrayList<Double> mult(double a, ArrayList<Double> u) {
        ArrayList<Double> result = new ArrayList<Double>();
        for (Double d : u) {
            result.add(a * d);
        }
        return result;
    }

    private static ArrayList<Double> mult(ArrayList<Double> v, ArrayList<Double> u) {
        if (v.size() != u.size()) {
            throw new IllegalArgumentException("Vector sizes must be equal.");
        }
        ArrayList<Double> result = new ArrayList<Double>();
        for (int i = 0; i < v.size(); ++i) {
            result.add(v.get(i) * u.get(i));
        }
        return result;
    }

    private static ArrayList<Double> subtract(double a, ArrayList<Double> u) {
        ArrayList<Double> result = new ArrayList<Double>();
        for (Double d : u) {
            result.add(a - d);
        }
        return result;
    }

    private static ArrayList<Double> exp(ArrayList<Double> u) {
        ArrayList<Double> result = new ArrayList<Double>();
        for (Double d : u) {
            result.add(Math.exp(d));
        }
        return result;
    }

    private static ArrayList<Double> div(ArrayList<Double> a, ArrayList<Double> b) {
        if (a.size() != b.size()) {
            throw new IllegalArgumentException("Vector sizes must be equal.");
        }
        ArrayList<Double> result = new ArrayList<Double>();
        for (int i = 0; i < a.size(); ++i) {
            result.add(a.get(i) / b.get(i));
        }
        return result;
    }
}

