/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.commands;

import ffx.crystal.Crystal;
import ffx.numerics.Potential;
import ffx.potential.ForceFieldEnergy;
import ffx.potential.bonded.Atom;
import ffx.potential.bonded.Residue;
import ffx.potential.cli.AtomSelectionOptions;
import ffx.potential.cli.PotentialCommand;
import ffx.potential.extended.ExtendedSystem;
import ffx.potential.parsers.PDBFilter;
import ffx.potential.parsers.SystemFilter;
import ffx.potential.parsers.XPHFilter;
import ffx.potential.parsers.XYZFilter;
import ffx.potential.utils.StructureMetrics;
import ffx.utilities.FFXBinding;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.math3.util.FastMath;
import picocli.CommandLine;

/**
 * The PhEnergy command computes the force field potential energy of a CpHMD system, i.e. a molecular
 * assembly with an {@link ExtendedSystem} of titrating residues attached.
 *
 * <p>Besides a single-point energy (optionally with gradient, electrostatic moments, radius of
 * gyration and moments of inertia), it can evaluate every snapshot of an XPH/PDB trajectory, save an
 * averaged structure, or (with {@code --mbar}) evaluate each snapshot at a series of evenly spaced
 * titration or tautomer lambda values and append the results to restartable MBAR files.
 */
@CommandLine.Command(description={" Compute the force field potential energy for a CpHMD system."}, name="PhEnergy")
public class PhEnergy extends PotentialCommand {

    /** Selects the active atoms used for moments, radius of gyration and inertia. */
    @CommandLine.Mixin
    private AtomSelectionOptions atomSelectionOptions = new AtomSelectionOptions();
    /** Print electrostatic moments of the active atoms. */
    @CommandLine.Option(names={"-m", "--moments"}, paramLabel="false", defaultValue="false", description={"Print out electrostatic moments."})
    private boolean moments = false;
    /** Print the radius of gyration of the active atoms. */
    @CommandLine.Option(names={"--rg", "--gyrate"}, paramLabel="false", defaultValue="false", description={"Print out the radius of gyration."})
    private boolean gyrate = false;
    /** Print the moments of inertia of the active atoms. */
    @CommandLine.Option(names={"--in", "--inertia"}, paramLabel="false", defaultValue="false", description={"Print out the moments of inertia."})
    private boolean inertia = false;
    /** Also compute and print the per-atom gradient for the initial structure. */
    @CommandLine.Option(names={"-g", "--gradient"}, paramLabel="false", defaultValue="false", description={"Compute the atomic gradient as well as energy."})
    private boolean gradient = false;
    /** Print every energy component for each trajectory snapshot. */
    @CommandLine.Option(names={"-v", "--verbose"}, paramLabel="false", defaultValue="false", description={"Print out all energy components for each snapshot."})
    private boolean verbose = false;
    /** pH passed to the extended system. */
    @CommandLine.Option(names={"--pH", "--constantPH"}, paramLabel="7.4", description={"pH value for the energy evaluation. (Only applies when esvTerm is true)"})
    double pH = 7.4;
    /** Optional trajectory (read with an XPHFilter) evaluated instead of the input file's own frames. */
    @CommandLine.Option(names={"--aFi", "--arcFile"}, paramLabel="traj", description={"A file containing snapshots to evaluate on when using a PDB as a reference to build from. There is currently no default."})
    private String arcFileName = null;
    /** Write restartable MBAR energy files instead of the plain snapshot loop. */
    @CommandLine.Option(names={"--bar", "--mbar"}, paramLabel="false", description={"Run (restartable) energy evaluations for MBAR. Requires an ARC file to be passed in. Set the tautomer flag to true for tautomer parameterization."})
    boolean mbar = false;
    /** Number of lambda states for MBAR; -1 means "count the rung directories" (default layout only). */
    @CommandLine.Option(names={"--numLambda", "--nL", "--nw"}, paramLabel="-1", description={"Required for lambda energy evaluations. Ensure numLambda is consistent with the trajectory lambdas, i.e. gaps between traj can be filled easily. nL >> nTraj is recommended."})
    int numLambda = -1;
    /** Directory for MBAR files; empty selects the default rung-based layout. */
    @CommandLine.Option(names={"--outputDir", "--oD"}, paramLabel="", description={"Where to place MBAR files. Default is ../mbarFiles/energy_(window#).mbar. Will write out a file called energy_0.mbar."})
    String outputDirectory = "";
    /** Also write dU/dL of the first ESV to a derivative MBAR file. */
    @CommandLine.Option(names={"--lambdaDerivative", "--lD"}, paramLabel="false", description={"Perform dU/dL evaluations and save to mbarFiles."})
    boolean derivatives = false;
    /** Perturb the tautomer lambda rather than the titration lambda for MBAR. */
    @CommandLine.Option(names={"--perturbTautomer"}, paramLabel="false", description={"Change tautomer instead of lambda state for MBAR energy evaluations."})
    boolean tautomer = false;
    /** Reset polarization damping of titrating heavy atoms from the end state matching the first ESV lambda. */
    @CommandLine.Option(names={"--testEndStateEnergies"}, paramLabel="false", description={"Test both ESV energy end states as if the polarization damping factor is initialized from the respective protonated or deprotonated state"})
    boolean testEndstateEnergies = false;
    /** Average coordinates over the trajectory and save the averaged structure. */
    @CommandLine.Option(names={"--recomputeAverage"}, paramLabel="false", description={"Recompute average position and spit out said structure from trajectory"})
    boolean recomputeAverage = false;
    /** The input coordinate file. */
    @CommandLine.Parameters(arity="1", paramLabel="file", description={"The atomic coordinate file in PDB or XPH format."})
    private String filename = null;
    /** The most recently computed energy. */
    public double energy = 0.0;
    /** The potential of the active assembly, set by {@link #run()}. */
    public ForceFieldEnergy forceFieldEnergy = null;

    /** Matches the first run of digits in a directory name (the rung number). */
    private static final Pattern FIRST_INTEGER = Pattern.compile("(\\d+)");
    /** Matches directory names consisting only of digits (one directory per lambda rung). */
    private static final Pattern RUNG_DIRECTORY = Pattern.compile("\\d+");

    /** Creates the command with default settings. */
    public PhEnergy() {
    }

    /**
     * Creates the command with the given binding.
     *
     * @param binding the binding passed to the superclass.
     */
    public PhEnergy(FFXBinding binding) {
        super(binding);
    }

    /**
     * Creates the command from command line arguments.
     *
     * @param args the command line arguments passed to the superclass.
     */
    public PhEnergy(String[] args) {
        super(args);
    }

    /**
     * Executes the command.
     *
     * <p>Loads the structure, builds an {@link ExtendedSystem} at {@link #pH} (with initial lambdas from
     * a sibling {@code .esv} file when one exists), attaches it to the force field energy and evaluates
     * the energy. Depending on the options it then writes MBAR files, evaluates every trajectory
     * snapshot, and/or saves an averaged structure.
     *
     * @return this command; {@link #energy} holds the most recently computed energy.
     */
    public PhEnergy run() {
        if (!init()) {
            return this;
        }
        if (mbar) {
            // NOTE(review): presumably unlocks the ESV states so each MBAR lambda can be set
            // explicitly; the property is consumed outside this file — confirm.
            System.setProperty("lock.esv.states", "false");
        }
        activeAssembly = getActiveAssembly(filename);
        if (activeAssembly == null) {
            logger.info(helpString());
            return this;
        }
        filename = activeAssembly.getFile().getAbsolutePath();
        logger.info("\n Running Energy on " + filename);
        forceFieldEnergy = activeAssembly.getPotentialEnergy();

        // Initial ESV lambdas come from a sibling "<basename>.esv" file when one exists.
        File esv = new File(FilenameUtils.removeExtension(filename) + ".esv");
        if (!esv.exists()) {
            esv = null;
        }
        ExtendedSystem esvSystem = new ExtendedSystem(activeAssembly, pH, esv);
        if (testEndstateEnergies) {
            // Only the first ESV's lambda selects which end state (0 or 1) the damping comes from.
            BigDecimal firstLambda = BigDecimal.valueOf(esvSystem.getExtendedLambdas()[0]);
            if (firstLambda.compareTo(BigDecimal.ZERO) == 0) {
                applyEndStatePolarDamping(esvSystem, 0.0);
            } else if (firstLambda.compareTo(BigDecimal.ONE) == 0) {
                applyEndStatePolarDamping(esvSystem, 1.0);
            }
        }
        esvSystem.setConstantPh(pH);
        int numESVs = esvSystem.getExtendedResidueList().size();
        forceFieldEnergy.attachExtendedSystem(esvSystem);
        logger.info(String.format(" Attached extended system with %d residues.", numESVs));
        atomSelectionOptions.setActiveAtoms(activeAssembly);

        int nVars = forceFieldEnergy.getNumberOfVariables();
        double[] x = new double[nVars];
        forceFieldEnergy.getCoordinates(x);
        // Running coordinate sum (starting with the loaded structure) used by --recomputeAverage.
        double[] averageCoordinates = Arrays.copyOf(x, x.length);
        if (gradient) {
            double[] g = new double[nVars];
            energy = forceFieldEnergy.energyAndGradient(x, g, true);
            logger.info("    Atom       X, Y and Z Gradient Components (kcal/mol/A)");
            int nAtoms = nVars / 3;
            for (int i = 0; i < nAtoms; i++) {
                int i3 = 3 * i;
                logger.info(String.format(" %7d %16.8f %16.8f %16.8f", i + 1, g[i3], g[i3 + 1], g[i3 + 2]));
            }
        } else {
            energy = forceFieldEnergy.energy(x, true);
        }
        if (moments) {
            forceFieldEnergy.getPmeNode().computeMoments(activeAssembly.getActiveAtomArray(), false);
        }
        if (gyrate) {
            double rg = StructureMetrics.radiusOfGyration(activeAssembly.getActiveAtomArray());
            logger.info(String.format(" Radius of gyration:           %10.5f A", rg));
        }
        if (inertia) {
            StructureMetrics.momentsOfInertia(activeAssembly.getActiveAtomArray(), false, true, true);
        }

        SystemFilter systemFilter;
        if (arcFileName != null) {
            systemFilter = new XPHFilter(new File(arcFileName), activeAssembly,
                    activeAssembly.getForceField(), activeAssembly.getProperties(), esvSystem);
        } else {
            systemFilter = potentialFunctions.getFilter();
            if (systemFilter instanceof XYZFilter) {
                // Re-read an XYZ-style input as XPH so the ESV lambdas stored in it are loaded.
                systemFilter = new XPHFilter(activeAssembly.getFile(), activeAssembly,
                        activeAssembly.getForceField(), activeAssembly.getProperties(), esvSystem);
                systemFilter.readFile();
                logger.info("Reading ESV lambdas from XPH file");
                forceFieldEnergy.getCoordinates(x);
                // NOTE(review): this energy is not stored in the energy field — confirm intended.
                forceFieldEnergy.energy(x, true);
            }
        }

        if (mbar) {
            computeESVEnergiesAndWriteFile(systemFilter, esvSystem);
            return this;
        }

        if (systemFilter instanceof XPHFilter || systemFilter instanceof PDBFilter) {
            // Snapshot 1 is the structure already loaded; each readNext() advances one snapshot.
            int index = 1;
            while (systemFilter.readNext()) {
                index++;
                Crystal crystal = activeAssembly.getCrystal();
                forceFieldEnergy.setCrystal(crystal);
                forceFieldEnergy.getCoordinates(x);
                if (recomputeAverage) {
                    for (int i = 0; i < x.length; i++) {
                        averageCoordinates[i] += x[i];
                    }
                }
                if (verbose) {
                    logger.info(String.format(" Snapshot %4d", index));
                    if (!crystal.aperiodic()) {
                        logger.info(String.format("\n Density:                                %6.3f (g/cc)",
                                crystal.getDensity(activeAssembly.getMass())));
                    }
                    energy = forceFieldEnergy.energy(x, true);
                } else {
                    energy = forceFieldEnergy.energy(x, false);
                    logger.info(String.format(" Snapshot %4d: %16.8f (kcal/mol)", index, energy));
                }
            }
            if (recomputeAverage) {
                for (int i = 0; i < x.length; i++) {
                    x[i] = averageCoordinates[i] / index;
                }
                forceFieldEnergy.setCoordinates(x);
            }
        }
        if (recomputeAverage) {
            saveByOriginalExtension(activeAssembly, filename);
        }
        return this;
    }

    /**
     * Resets {@code pdamp} of every titrating heavy atom to the sixth root of the polarizability
     * returned by the titration utilities for the requested end state.
     *
     * @param esvSystem the extended system whose atoms are updated.
     * @param titrationLambda the end state (0.0 or 1.0) passed as the first lambda argument; the
     *     second lambda argument is fixed at 0.0 (presumably the tautomer lambda — confirm).
     */
    private static void applyEndStatePolarDamping(ExtendedSystem esvSystem, double titrationLambda) {
        for (Atom atom : esvSystem.getExtendedAtoms()) {
            if (!esvSystem.isTitratingHeavy(atom.getArrayIndex())) {
                continue;
            }
            double endStatePolar = esvSystem.getTitrationUtils().getPolarizability(
                    atom, titrationLambda, 0.0, atom.getPolarizeType().polarizability);
            atom.getPolarizeType().pdamp = FastMath.pow(endStatePolar, 1.0 / 6.0);
        }
    }

    /**
     * Returns the potentials evaluated by this command.
     *
     * @return an empty list before {@link #run()} sets {@link #forceFieldEnergy}, otherwise a
     *     singleton list containing it.
     */
    @Override
    public List<Potential> getPotentials() {
        return forceFieldEnergy == null
                ? Collections.emptyList()
                : Collections.singletonList(forceFieldEnergy);
    }

    /**
     * Evaluates every remaining trajectory snapshot at each lambda state and appends one line per
     * snapshot to the MBAR energy file (and, with {@link #derivatives}, to a dU/dL file).
     *
     * <p>With an empty {@link #outputDirectory}, files go to
     * {@code <grandparent>/mbarFiles/energy_<rung>.mbar}. Here the rung is the first integer in the
     * name of the structure's directory. {@link #numLambda} then defaults to the number of
     * all-digit sibling directories. Otherwise the files are {@code <outputDirectory>/energy_0.mbar}
     * and {@code derivative_0.mbar}, and {@link #numLambda} must be set. An existing energy file is
     * treated as a restart: its header line plus one line per completed snapshot tells how many
     * snapshots to skip.
     *
     * @param systemFilter the trajectory reader; only XPH and PDB filters are supported.
     * @param esvSystem the extended system whose lambda (or tautomer) is scanned.
     */
    void computeESVEnergiesAndWriteFile(SystemFilter systemFilter, ExtendedSystem esvSystem) {
        if (!(systemFilter instanceof XPHFilter) && !(systemFilter instanceof PDBFilter)) {
            logger.warning(" MBAR energy evaluations require an XPH or PDB trajectory; nothing was written.");
            return;
        }

        File mbarFile;
        File mbarGradFile;
        if (outputDirectory.isEmpty()) {
            File rungDir = new File(filename).getParentFile();
            File parentDir = rungDir == null ? null : rungDir.getParentFile();
            if (parentDir == null) {
                logger.severe(" Could not locate the directory above the rung directory of " + filename);
                return;
            }
            Matcher matcher = FIRST_INTEGER.matcher(rungDir.getName());
            if (!matcher.find()) {
                // Previously only guarded by an assert, so with assertions off "energy_-1.mbar" was written.
                logger.severe(" Could not determine the rung number from the directory name.");
                return;
            }
            int thisRung = Integer.parseInt(matcher.group(1));
            String mbarDir = parentDir.getAbsolutePath() + File.separator + "mbarFiles" + File.separator;
            mbarFile = new File(mbarDir + "energy_" + thisRung + ".mbar");
            mbarGradFile = new File(mbarDir + "derivative_" + thisRung + ".mbar");
            if (numLambda == -1) {
                numLambda = countRungDirectories(parentDir);
            }
        } else {
            // Validate before sizing any array (a -1 size would throw NegativeArraySizeException).
            if (numLambda == -1) {
                logger.severe("numLambda must be set when outputDirectory is set.");
                return;
            }
            mbarFile = new File(outputDirectory + File.separator + "energy_0.mbar");
            mbarGradFile = new File(outputDirectory + File.separator + "derivative_0.mbar");
        }
        if (numLambda < 1) {
            logger.severe(" At least one lambda state is required for MBAR energy evaluations (numLambda = "
                    + numLambda + ").");
            return;
        }
        double[] lambdas = evenlySpacedLambdas(numLambda);
        File mbarDirectory = mbarFile.getParentFile();
        if (mbarDirectory != null) {
            mbarDirectory.mkdirs();
        }
        logger.info(" Computing energies for each lambda state for generation of mbar file.");
        logger.info(" MBAR File: " + mbarFile);
        logger.info(" Lambda States: " + Arrays.toString(lambdas));

        // Restart: the energy file holds one header line followed by one line per completed snapshot.
        int completedSnapshots = 0;
        boolean writeEnergyHeader = true;
        if (mbarFile.exists()) {
            int lineCount;
            try (var lines = Files.lines(mbarFile.toPath())) {
                lineCount = (int) lines.count();
            } catch (IOException | java.io.UncheckedIOException e) {
                // Abort rather than guess, so an existing file is never appended to inconsistently.
                logger.severe(" Error reading MBAR file for restart: " + e);
                return;
            }
            writeEnergyHeader = lineCount == 0;
            completedSnapshots = Math.max(0, lineCount - 1);
            for (int i = 0; i < completedSnapshots; i++) {
                systemFilter.readNext();
            }
            logger.info("\n Restarting MBAR file at snapshot " + (completedSnapshots + 1));
        }
        boolean writeGradHeader = derivatives && (!mbarGradFile.exists() || mbarGradFile.length() == 0);

        int index = completedSnapshots + 1;
        double[] x = new double[forceFieldEnergy.getNumberOfVariables()];
        // One gradient buffer of the same length as x, reused for every evaluation.
        double[] g = new double[x.length];
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(mbarFile, mbarFile.exists()));
             BufferedWriter writerGrad = new BufferedWriter(new FileWriter(mbarGradFile, mbarGradFile.exists()))) {
            logger.info(" MBAR file temp is hardcoded to 298.0 K. Please change if necessary.");
            if (writeEnergyHeader || writeGradHeader) {
                String header = systemFilter.countNumModels() + "\t298.0\t"
                        + FilenameUtils.getBaseName(filename) + "\n";
                if (writeEnergyHeader) {
                    writer.write(header);
                    writer.flush();
                    logger.info(" Header: " + header);
                }
                if (writeGradHeader) {
                    writerGrad.write(header);
                    writerGrad.flush();
                    logger.info(" Header: " + header);
                }
            }
            while (systemFilter.readNext()) {
                StringBuilder energyLine = new StringBuilder("\t" + index + "\t");
                StringBuilder gradLine = new StringBuilder("\t" + index + "\t");
                index++;
                forceFieldEnergy.setCrystal(activeAssembly.getCrystal());
                for (double lambda : lambdas) {
                    if (tautomer) {
                        setESVTautomer(lambda, esvSystem);
                    } else {
                        setESVLambda(lambda, esvSystem);
                    }
                    forceFieldEnergy.getCoordinates(x);
                    if (derivatives) {
                        energy = forceFieldEnergy.energyAndGradient(x, g);
                        double grad = esvSystem.getDerivatives()[0];
                        gradLine.append(grad).append(" ");
                    } else {
                        energy = forceFieldEnergy.energy(x, false);
                    }
                    energyLine.append(energy).append(" ");
                }
                energyLine.append("\n");
                writer.write(energyLine.toString());
                writer.flush();
                logger.info(energyLine.toString());
                if (derivatives) {
                    gradLine.append("\n");
                    writerGrad.write(gradLine.toString());
                    writerGrad.flush();
                    logger.info(gradLine.toString());
                }
            }
        } catch (IOException e) {
            logger.severe(" Error writing to MBAR file: " + e);
        }
    }

    /**
     * Counts the sub-directories of {@code parentDir} whose names consist only of digits.
     *
     * @param parentDir the directory to scan.
     * @return the number of rung directories, or 0 if the directory cannot be listed.
     */
    private static int countRungDirectories(File parentDir) {
        File[] children = parentDir.listFiles();
        if (children == null) {
            return 0;
        }
        int count = 0;
        for (File child : children) {
            if (child.isDirectory() && RUNG_DIRECTORY.matcher(child.getName()).matches()) {
                count++;
            }
        }
        return count;
    }

    /**
     * Returns {@code n} evenly spaced lambda values from 0 to 1 inclusive.
     *
     * @param n the number of lambda states.
     * @return {@code {0.0}} for n == 1 (the naive formula would yield NaN), an empty array for n &lt;= 0.
     */
    private static double[] evenlySpacedLambdas(int n) {
        if (n <= 0) {
            return new double[0];
        }
        double[] lambdas = new double[n];
        if (n == 1) {
            return lambdas;
        }
        double dL = 1.0 / (double) (n - 1);
        for (int i = 0; i < n; i++) {
            lambdas[i] = (double) i * dL;
        }
        return lambdas;
    }

    /**
     * Sets the titration lambda of the single lambda path used for MBAR evaluations.
     *
     * @param lambda the titration lambda.
     * @param extendedSystem the extended system; logs SEVERE if it has no residues or more than one path.
     */
    public static void setESVLambda(double lambda, ExtendedSystem extendedSystem) {
        Residue residue = singleLambdaPathResidue(extendedSystem);
        if (residue != null) {
            extendedSystem.setTitrationLambda(residue, lambda, false);
        }
    }

    /**
     * Sets the tautomer lambda of the single lambda path used for MBAR evaluations.
     *
     * @param tautomer the tautomer lambda.
     * @param extendedSystem the extended system; logs SEVERE if it has no residues or more than one path.
     */
    public static void setESVTautomer(double tautomer, ExtendedSystem extendedSystem) {
        Residue residue = singleLambdaPathResidue(extendedSystem);
        if (residue != null) {
            extendedSystem.setTautomerLambda(residue, tautomer, false);
        }
    }

    /**
     * Returns the first extended residue when the system describes a single lambda path. That is
     * either exactly one residue, or two entries whose first is reported as a tautomer.
     * NOTE(review): presumably that pair is one residue's titration and tautomer ESVs — confirm.
     *
     * @param extendedSystem the extended system to inspect.
     * @return the residue to perturb, or null after logging SEVERE.
     */
    private static Residue singleLambdaPathResidue(ExtendedSystem extendedSystem) {
        List<Residue> residueList = extendedSystem.getExtendedResidueList();
        if (residueList.size() == 1
                || (residueList.size() == 2 && extendedSystem.isTautomer(residueList.getFirst()))) {
            return residueList.getFirst();
        }
        if (residueList.isEmpty()) {
            logger.severe(" No residues found in the extended system.");
        } else {
            logger.severe(" Only one lambda path is allowed for MBAR energy evaluations.");
        }
        return null;
    }
}

