/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.commands.test;

import ffx.numerics.Potential;
import ffx.potential.ForceFieldEnergy;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.LambdaInterface;
import ffx.potential.cli.AlchemicalOptions;
import ffx.potential.cli.GradientOptions;
import ffx.potential.cli.PotentialCommand;
import ffx.potential.cli.TopologyOptions;
import ffx.potential.nonbonded.ParticleMeshEwald;
import ffx.potential.nonbonded.pme.AlchemicalParameters;
import ffx.potential.openmm.OpenMMEnergy;
import ffx.utilities.FFXBinding;
import ffx.utilities.StringUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.util.FastMath;
import picocli.CommandLine;

/**
 * Tests the analytic derivatives of an alchemical potential with respect to the state variable
 * lambda (dE/dL, d2E/dL2, d2E/dX/dL) and, optionally, the Cartesian gradient dE/dX, by comparing
 * them against central finite differences.
 *
 * <p>Lambda derivatives are not tested when any topology's PME node uses a non-OST alchemical
 * mode. Second derivatives are also skipped for {@link OpenMMEnergy} potentials.
 */
@CommandLine.Command(description={" Test potential energy derivatives with respect to Lambda."}, name="test.LambdaGradient")
public class LambdaGradient extends PotentialCommand {

    @CommandLine.Mixin
    AlchemicalOptions alchemicalOptions;

    @CommandLine.Mixin
    GradientOptions gradientOptions;

    @CommandLine.Mixin
    TopologyOptions topologyOptions;

    @CommandLine.Option(names={"--ls", "--lambdaScan"}, paramLabel="false", description={"Scan lambda values."})
    boolean lambdaScan = false;

    @CommandLine.Option(names={"--lm", "--lambdaMoveSize"}, paramLabel="0.01", description={"Size of the lambda moves during the test."})
    double lambdaMoveSize = 0.01;

    @CommandLine.Option(names={"--sk2", "--skip2"}, paramLabel="false", description={"Skip 2nd derivatives."})
    boolean skipSecondDerivatives = false;

    @CommandLine.Option(names={"--sdX", "--skipdX"}, paramLabel="false", description={"Skip calculating per-atom dUdX values and only test lambda gradients."})
    boolean skipAtomGradients = false;

    @CommandLine.Parameters(arity="1..*", paramLabel="files", description={"The atomic coordinate files in PDB or XYZ format."})
    List<String> filenames = null;

    /** One molecular assembly per topology. */
    MolecularAssembly[] topologies;

    /** The (possibly multi-topology) potential under test. */
    private Potential potential;

    /** Number of tested lambda values where the analytic dE/dL disagreed with the finite difference. */
    public int ndEdLFailures = 0;

    /** Number of tested atoms whose analytic dE/dX disagreed with the finite difference. */
    public int ndEdXFailures = 0;

    /** Number of tested lambda values where the analytic d2E/dL2 disagreed with the finite difference. */
    public int nd2EdL2Failures = 0;

    /** Total number of (atom, lambda) pairs where the analytic d2E/dX/dL failed, summed over lambdas. */
    public int ndEdXdLFailures = 0;

    /** Potential energy at lambda = 0. */
    public double e0 = 0.0;

    /** Potential energy at lambda = 1. */
    public double e1 = 0.0;

    public LambdaGradient() {
    }

    public LambdaGradient(FFXBinding binding) {
        super(binding);
    }

    public LambdaGradient(String[] args) {
        super(args);
    }

    /**
     * Runs the lambda-derivative test and, unless disabled, the Cartesian gradient test.
     *
     * @return this command; the public failure counters and end-state energies are populated.
     */
    public LambdaGradient run() {
        if (!init()) {
            return this;
        }
        if (!openTopologies()) {
            return this;
        }

        StringBuilder sb = new StringBuilder("\n Testing lambda derivatives for ");
        potential = topologyOptions.assemblePotential(topologies, sb);
        logger.info(sb.toString());

        // Only the OST alchemical mode supports the lambda derivatives tested below.
        boolean skipLambdaDerivatives = findAlchemicalMode() == AlchemicalParameters.AlchemicalMode.SCALE;
        if (potential instanceof OpenMMEnergy || skipLambdaDerivatives) {
            skipSecondDerivatives = true;
        }

        LambdaInterface lambdaInterface = (LambdaInterface) potential;
        int n = potential.getNumberOfVariables();
        double[] x = new double[n];
        double[] gradient = new double[n];
        double[] lambdaGrad = new double[n];
        // Scratch gradients for the +step / -step evaluations of the finite differences.
        double[][] lambdaGradFD = new double[2][n];
        assert n % 3 == 0;
        int nAtoms = n / 3;

        // End-state energies, with an optional scan over intermediate lambda values.
        lambdaInterface.setLambda(0.0);
        potential.getCoordinates(x);
        e0 = potential.energyAndGradient(x, gradient);
        logLambdaEnergy(lambdaInterface, 0.0, e0, skipLambdaDerivatives);
        if (lambdaScan) {
            for (int i = 1; i <= 9; i++) {
                double lambda = i * 0.1;
                lambdaInterface.setLambda(lambda);
                double e = potential.energyAndGradient(x, gradient);
                logLambdaEnergy(lambdaInterface, lambda, e, skipLambdaDerivatives);
            }
        }
        lambdaInterface.setLambda(1.0);
        e1 = potential.energyAndGradient(x, gradient);
        logLambdaEnergy(lambdaInterface, 1.0, e1, skipLambdaDerivatives);
        logger.info(String.format(" E(1)-E(0): %12.6f.\n", e1 - e0));

        double step = gradientOptions.getDx();
        logger.info(" Finite-difference step size:\t" + step);
        boolean print = gradientOptions.getVerbose();
        logger.info(" Verbose printing:\t\t" + print + "\n");
        double errTol = gradientOptions.getTolerance();

        if (!skipLambdaDerivatives) {
            testLambdaDerivatives(lambdaInterface, x, gradient, lambdaGrad, lambdaGradFD, nAtoms, step, errTol);
        }

        // Reference analytic gradient at the initial lambda for the Cartesian test.
        lambdaInterface.setLambda(alchemicalOptions.getInitialLambda());
        potential.getCoordinates(x);
        potential.energyAndGradient(x, gradient, print);

        if (skipAtomGradients) {
            logger.info(" Skipping atomic dU/dX gradients.");
        } else {
            testAtomicGradients(x, gradient, lambdaGradFD, nAtoms, step, errTol, print);
        }
        return this;
    }

    /**
     * Opens (or reuses the active) molecular assemblies for every topology.
     *
     * @return false if no files were given and there is no active assembly (help is logged).
     */
    private boolean openTopologies() {
        int numTopologies = topologyOptions.getNumberOfTopologies(filenames);
        int threadsPerTopology = topologyOptions.getThreadsPerTopology(numTopologies);
        topologies = new MolecularAssembly[numTopologies];
        alchemicalOptions.setAlchemicalProperties();
        topologyOptions.setAlchemicalProperties(numTopologies);
        if (filenames == null || filenames.isEmpty()) {
            activeAssembly = getActiveAssembly(filenames);
            if (activeAssembly == null) {
                logger.info(helpString());
                return false;
            }
            filenames = new ArrayList<>();
            filenames.add(activeAssembly.getFile().getName());
            topologies[0] = alchemicalOptions.processFile(topologyOptions, activeAssembly, 0);
        } else {
            logger.info(String.format(" Initializing %d topologies...", numTopologies));
            for (int i = 0; i < numTopologies; i++) {
                topologies[i] = alchemicalOptions.openFile(potentialFunctions, topologyOptions,
                    threadsPerTopology, filenames.get(i), i);
            }
        }
        return true;
    }

    /**
     * Returns SCALE if any topology's PME node uses a non-OST alchemical mode, otherwise OST.
     */
    private AlchemicalParameters.AlchemicalMode findAlchemicalMode() {
        for (MolecularAssembly assembly : topologies) {
            ForceFieldEnergy energy = assembly.getPotentialEnergy();
            ParticleMeshEwald pme = energy.getPmeNode();
            if (pme != null
                && pme.getAlchemicalParameters().mode != AlchemicalParameters.AlchemicalMode.OST) {
                return AlchemicalParameters.AlchemicalMode.SCALE;
            }
        }
        return AlchemicalParameters.AlchemicalMode.OST;
    }

    /** Logs the energy at one lambda value, with dE/dL unless lambda derivatives are skipped. */
    private void logLambdaEnergy(LambdaInterface lambdaInterface, double lambda, double energy,
        boolean skipLambdaDerivatives) {
        if (skipLambdaDerivatives) {
            logger.info(String.format(" L=%4.2f E=%12.6f", lambda, energy));
        } else {
            double dEdL = lambdaInterface.getdEdL();
            logger.info(String.format(" L=%4.2f E=%12.6f dE/dL=%12.6f", lambda, energy, dEdL));
        }
    }

    /**
     * Compares analytic dE/dL, d2E/dL2 and d2E/dX/dL with central differences in lambda.
     *
     * <p>The test is run at the initial lambda and at one move size on either side of it. Lambda
     * values whose finite-difference stencil would leave [0, 1] are skipped.
     */
    private void testLambdaDerivatives(LambdaInterface lambdaInterface, double[] x, double[] gradient,
        double[] lambdaGrad, double[][] lambdaGradFD, int nAtoms, double step, double errTol) {
        double width = 2.0 * step;
        for (int j = 0; j < 3; j++) {
            double lambda = alchemicalOptions.getInitialLambda() - lambdaMoveSize + lambdaMoveSize * j;
            if (lambda - step < 0.0 || lambda + step > 1.0) {
                continue;
            }
            logger.info(String.format(" Current lambda value %6.4f", lambda));

            // Analytic derivatives at lambda.
            lambdaInterface.setLambda(lambda);
            potential.energyAndGradient(x, gradient);
            double dEdL = lambdaInterface.getdEdL();
            double d2EdL2 = lambdaInterface.getd2EdL2();
            Arrays.fill(lambdaGrad, 0.0);
            lambdaInterface.getdEdXdL(lambdaGrad);

            // Energies, dE/dL and Cartesian gradients at lambda +/- step.
            lambdaInterface.setLambda(lambda + step);
            double ePlus = potential.energyAndGradient(x, lambdaGradFD[0]);
            double dEdLPlus = lambdaInterface.getdEdL();
            lambdaInterface.setLambda(lambda - step);
            double eMinus = potential.energyAndGradient(x, lambdaGradFD[1]);
            double dEdLMinus = lambdaInterface.getdEdL();

            double dEdLFD = (ePlus - eMinus) / width;
            double d2EdL2FD = (dEdLPlus - dEdLMinus) / width;

            double err = FastMath.abs(dEdLFD - dEdL);
            if (err < errTol) {
                logger.info(String.format(" dE/dL passed:   %10.6f", err));
            } else {
                logger.info(String.format(" dE/dL failed: %10.6f", err));
                ndEdLFailures++;
            }
            logger.info(String.format(" Numeric:   %15.8f", dEdLFD));
            logger.info(String.format(" Analytic:  %15.8f", dEdL));

            if (skipSecondDerivatives) {
                continue;
            }

            err = FastMath.abs(d2EdL2FD - d2EdL2);
            if (err < errTol) {
                logger.info(String.format(" d2E/dL2 passed: %10.6f", err));
            } else {
                logger.info(String.format(" d2E/dL2 failed: %10.6f", err));
                nd2EdL2Failures++;
            }
            logger.info(String.format(" Numeric:   %15.8f", d2EdL2FD));
            logger.info(String.format(" Analytic:  %15.8f", d2EdL2));

            testdEdXdL(lambdaGrad, lambdaGradFD, nAtoms, width, errTol);
            logger.info("");
        }
    }

    /**
     * Compares the analytic d2E/dX/dL of every atom with the central difference of the Cartesian
     * gradients evaluated at lambda +/- step.
     */
    private void testdEdXdL(double[] lambdaGrad, double[][] lambdaGradFD, int nAtoms, double width,
        double errTol) {
        double rmsError = 0.0;
        // Failures at this lambda only; ndEdXdLFailures accumulates over all lambdas.
        int lambdaFailures = 0;
        for (int i = 0; i < nAtoms; i++) {
            int i3 = 3 * i;
            double dX = (lambdaGradFD[0][i3] - lambdaGradFD[1][i3]) / width;
            double dY = (lambdaGradFD[0][i3 + 1] - lambdaGradFD[1][i3 + 1]) / width;
            double dZ = (lambdaGradFD[0][i3 + 2] - lambdaGradFD[1][i3 + 2]) / width;
            double dXa = lambdaGrad[i3];
            double dYa = lambdaGrad[i3 + 1];
            double dZa = lambdaGrad[i3 + 2];
            double eX = dX - dXa;
            double eY = dY - dYa;
            double eZ = dZ - dZa;
            double error2 = eX * eX + eY * eY + eZ * eZ;
            rmsError += error2;
            double error = FastMath.sqrt(error2);
            if (error < errTol) {
                logger.fine(String.format(" dE/dX/dL for degree of freedom %d passed: %10.6f", i + 1, error));
            } else {
                logger.info(String.format(" dE/dX/dL for degree of freedom %d failed: %10.6f", i + 1, error));
                logger.info(String.format(" Analytic: (%15.8f, %15.8f, %15.8f)", dXa, dYa, dZa));
                logger.info(String.format(" Numeric:  (%15.8f, %15.8f, %15.8f)", dX, dY, dZ));
                ndEdXdLFailures++;
                lambdaFailures++;
            }
        }
        rmsError = FastMath.sqrt(rmsError / nAtoms);
        if (lambdaFailures == 0) {
            logger.info(String.format(" dE/dX/dL passed for all degrees of freedom: RMS error %15.8f", rmsError));
        } else {
            logger.info(String.format(" dE/dX/dL failed for %d of %d atoms: RMS error %15.8f",
                lambdaFailures, nAtoms, rmsError));
        }
    }

    /**
     * Compares the analytic Cartesian gradient with central differences for the selected atoms.
     *
     * <p>RMS statistics are averaged over the atoms actually tested, not over all atoms.
     */
    private void testAtomicGradients(double[] x, double[] gradient, double[][] scratch, int nAtoms,
        double step, double errTol, boolean print) {
        String gradientAtoms = gradientOptions.getGradientAtoms();
        if (gradientAtoms.equalsIgnoreCase("NONE")) {
            logger.info(" The gradient of no atoms will be evaluated.");
            return;
        }
        List<Integer> degreesOfFreedomToTest;
        if (gradientAtoms.equalsIgnoreCase("ALL")) {
            logger.info(" Checking gradient for all active atoms.\n");
            degreesOfFreedomToTest = new ArrayList<>(nAtoms);
            for (int i = 0; i < nAtoms; i++) {
                degreesOfFreedomToTest.add(i);
            }
        } else {
            degreesOfFreedomToTest = StringUtils.parseAtomRanges(" Gradient atoms", gradientAtoms, nAtoms);
            logger.info(" Checking gradient for active atoms in the range: " + gradientAtoms + "\n");
        }

        // Threshold applied to the squared gradient norm of a single atom and to the RMS gradient.
        double expGrad = 1000.0;
        double[] numeric = new double[3];
        double avLen = 0.0;
        double avGrad = 0.0;
        for (int i : degreesOfFreedomToTest) {
            int i0 = 3 * i;
            int i1 = i0 + 1;
            int i2 = i0 + 2;
            numeric[0] = centralDifference(x, i0, step, scratch, print);
            numeric[1] = centralDifference(x, i1, step, scratch, print);
            numeric[2] = centralDifference(x, i2, step, scratch, print);

            double dx = gradient[i0] - numeric[0];
            double dy = gradient[i1] - numeric[1];
            double dz = gradient[i2] - numeric[2];
            double len2 = dx * dx + dy * dy + dz * dz;
            avLen += len2;
            double len = Math.sqrt(len2);
            double grad2 = gradient[i0] * gradient[i0] + gradient[i1] * gradient[i1] + gradient[i2] * gradient[i2];
            avGrad += grad2;

            if (len > errTol) {
                logger.info(String.format(" Degree of freedom %d failed: %10.6f.", i + 1, len)
                    + String.format("\n Analytic: (%12.4f, %12.4f, %12.4f)\n", gradient[i0], gradient[i1], gradient[i2])
                    + String.format(" Numeric:  (%12.4f, %12.4f, %12.4f)\n", numeric[0], numeric[1], numeric[2]));
                ndEdXFailures++;
            } else {
                logger.info(String.format(" Degree of freedom %d passed: %10.6f.", i + 1, len)
                    + String.format("\n Analytic: (%12.4f, %12.4f, %12.4f)\n", gradient[i0], gradient[i1], gradient[i2])
                    + String.format(" Numeric:  (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
            }
            if (grad2 > expGrad) {
                logger.info(String.format(" Degree of freedom %d has an unusually large gradient: %10.6f", i + 1, grad2));
            }
            logger.info("\n");
        }

        // Average over the atoms actually tested; guard against an empty selection.
        int nTested = degreesOfFreedomToTest.size();
        if (nTested > 0) {
            avLen /= nTested;
            avGrad /= nTested;
        }
        avLen = FastMath.sqrt(avLen);
        if (avLen > errTol) {
            logger.info(String.format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
        } else {
            logger.info(String.format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
        }
        logger.info(String.format(" Number of atoms failing gradient test: %d", ndEdXFailures));
        avGrad = FastMath.sqrt(avGrad);
        if (avGrad > expGrad) {
            logger.info(String.format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
        } else {
            logger.info(String.format(" RMS gradient: %10.6f", avGrad));
        }
    }

    /**
     * Central finite-difference estimate of dE/dx[index].
     *
     * <p>The energy is evaluated at x[index] + step and then at x[index] - step. The coordinate is
     * restored to its original value before returning. {@code scratch} receives the throw-away
     * gradients of the two evaluations.
     */
    private double centralDifference(double[] x, int index, double step, double[][] scratch, boolean print) {
        double orig = x[index];
        x[index] = orig + step;
        double ePlus = potential.energyAndGradient(x, scratch[0], print);
        x[index] = orig - step;
        double eMinus = potential.energyAndGradient(x, scratch[1], print);
        x[index] = orig;
        return (ePlus - eMinus) / (2.0 * step);
    }

    @Override
    public List<Potential> getPotentials() {
        if (potential == null) {
            return Collections.emptyList();
        }
        return Collections.singletonList(potential);
    }
}

