/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.nonbonded;

import ffx.numerics.math.ScalarMath;
import ffx.potential.nonbonded.pme.LambdaMode;
import ffx.potential.parameters.ForceField;
import java.util.Arrays;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.math3.fitting.leastsquares.EvaluationRmsChecker;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresFactory;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresOptimizer;
import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem;
import org.apache.commons.math3.fitting.leastsquares.LevenbergMarquardtOptimizer;
import org.apache.commons.math3.fitting.leastsquares.MultivariateJacobianFunction;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.util.Pair;

public class ScfPredictor {
    private static final Logger logger = Logger.getLogger(ScfPredictor.class.getName());
    private static final double eps = 1.0E-4;
    private final PredictorMode predictorMode;
    private final int predictorOrder;
    protected double[][][] inducedDipole;
    protected double[][][] inducedDipoleCR;
    private int nAtoms;
    private int mode;
    private double[][][][] predictorInducedDipole;
    private double[][][][] predictorInducedDipoleCR;
    private int predictorStartIndex;
    private int predictorCount;
    private LeastSquaresPredictor leastSquaresPredictor;

    public ScfPredictor(PredictorMode mode, int order, ForceField ff) {
        this.predictorMode = mode;
        this.predictorOrder = order;
        this.predictorCount = 0;
        this.predictorStartIndex = 0;
        if (this.predictorMode != PredictorMode.NONE) {
            if (this.predictorMode == PredictorMode.LS) {
                this.leastSquaresPredictor = new LeastSquaresPredictor(this, 1.0E-4);
            }
            if (ff.getBoolean("LAMBDATERM", false)) {
                this.predictorInducedDipole = new double[3][this.predictorOrder][this.nAtoms][3];
                this.predictorInducedDipoleCR = new double[3][this.predictorOrder][this.nAtoms][3];
            } else {
                this.predictorInducedDipole = new double[1][this.predictorOrder][this.nAtoms][3];
                this.predictorInducedDipoleCR = new double[1][this.predictorOrder][this.nAtoms][3];
            }
        }
    }

    public void run(LambdaMode lambdaMode) {
        if (this.predictorMode == PredictorMode.NONE) {
            return;
        }
        switch (lambdaMode) {
            case CONDENSED_NO_LIGAND: {
                this.mode = 1;
                break;
            }
            case VAPOR: {
                this.mode = 2;
                break;
            }
            default: {
                this.mode = 0;
            }
        }
        if (this.predictorMode != PredictorMode.NONE) {
            switch (this.predictorMode.ordinal()) {
                case 3: {
                    this.aspcPredictor();
                    break;
                }
                case 1: {
                    this.leastSquaresPredictor();
                    break;
                }
                case 2: {
                    this.polynomialPredictor();
                    break;
                }
            }
        }
    }

    public void saveMutualInducedDipoles(double[][][] inducedDipole, double[][][] inducedDipoleCR, double[][] directDipole, double[][] directDipoleCR) {
        --this.predictorStartIndex;
        if (this.predictorStartIndex < 0) {
            this.predictorStartIndex = this.predictorOrder - 1;
        }
        if (this.predictorCount < this.predictorOrder) {
            ++this.predictorCount;
        }
        for (int i = 0; i < this.nAtoms; ++i) {
            for (int j = 0; j < 3; ++j) {
                this.predictorInducedDipole[this.mode][this.predictorStartIndex][i][j] = inducedDipole[0][i][j] - directDipole[i][j];
                this.predictorInducedDipoleCR[this.mode][this.predictorStartIndex][i][j] = inducedDipoleCR[0][i][j] - directDipoleCR[i][j];
            }
        }
    }

    public void setInducedDipoleReferences(double[][][] inducedDipole, double[][][] inducedDipoleCR, boolean lambdaTerm) {
        this.inducedDipole = inducedDipole;
        this.inducedDipoleCR = inducedDipoleCR;
        if (lambdaTerm) {
            this.predictorInducedDipole = new double[3][this.predictorOrder][this.nAtoms][3];
            this.predictorInducedDipoleCR = new double[3][this.predictorOrder][this.nAtoms][3];
        } else {
            this.predictorInducedDipole = new double[1][this.predictorOrder][this.nAtoms][3];
            this.predictorInducedDipoleCR = new double[1][this.predictorOrder][this.nAtoms][3];
        }
    }

    public String toString() {
        return this.predictorMode.toString();
    }

    private void leastSquaresPredictor() {
        if (this.predictorCount < 2) {
            return;
        }
        try {
            this.leastSquaresPredictor.updateJacobianAndTarget();
            int maxIter = 1000;
            int maxEvals = 1000;
            LeastSquaresOptimizer.Optimum optimum = this.leastSquaresPredictor.predict(maxEvals, maxIter);
            double[] optimalValues = optimum.getPoint().toArray();
            if (logger.isLoggable(Level.FINEST)) {
                logger.finest(String.format("\n LS RMS:            %10.6f", optimum.getRMS()));
                logger.finest(String.format(" LS Iterations:     %10d", optimum.getIterations()));
                logger.finest(String.format(" Jacobian Evals:    %10d", optimum.getEvaluations()));
                logger.finest(String.format(" Root Mean Square:  %10.6f", optimum.getRMS()));
                logger.finest(" LS Coefficients");
                for (int i = 0; i < this.predictorOrder - 1; ++i) {
                    logger.finest(String.format(" %2d  %10.6f", i + 1, optimalValues[i]));
                }
            }
            int index = this.predictorStartIndex;
            for (int k = 0; k < this.predictorOrder - 1; ++k) {
                double c = optimalValues[k];
                for (int i = 0; i < this.nAtoms; ++i) {
                    for (int j = 0; j < 3; ++j) {
                        double[] dArray = this.inducedDipole[0][i];
                        int n = j;
                        dArray[n] = dArray[n] + c * this.predictorInducedDipole[this.mode][index][i][j];
                        double[] dArray2 = this.inducedDipoleCR[0][i];
                        int n2 = j;
                        dArray2[n2] = dArray2[n2] + c * this.predictorInducedDipoleCR[this.mode][index][i][j];
                    }
                }
                if (++index < this.predictorOrder) continue;
                index = 0;
            }
        }
        catch (Exception e) {
            logger.log(Level.WARNING, " Exception computing predictor coefficients", e);
        }
    }

    private void aspcPredictor() {
        if (this.predictorCount < 6) {
            return;
        }
        double[] aspc = new double[]{3.142857142857143, -3.9285714285714284, 2.619047619047619, -1.0476190476190477, 0.23809523809523808, -0.023809523809523808};
        int index = this.predictorStartIndex;
        for (int k = 0; k < 6; ++k) {
            double c = aspc[k];
            for (int i = 0; i < this.nAtoms; ++i) {
                for (int j = 0; j < 3; ++j) {
                    double[] dArray = this.inducedDipole[0][i];
                    int n = j;
                    dArray[n] = dArray[n] + c * this.predictorInducedDipole[this.mode][index][i][j];
                    double[] dArray2 = this.inducedDipoleCR[0][i];
                    int n2 = j;
                    dArray2[n2] = dArray2[n2] + c * this.predictorInducedDipoleCR[this.mode][index][i][j];
                }
            }
            if (++index < this.predictorOrder) continue;
            index = 0;
        }
    }

    private void polynomialPredictor() {
        if (this.predictorCount == 0) {
            return;
        }
        int n = this.predictorOrder;
        if (this.predictorCount < this.predictorOrder) {
            n = this.predictorCount;
        }
        int index = this.predictorStartIndex;
        double sign = -1.0;
        for (int k = 0; k < n; ++k) {
            double c = (sign *= -1.0) * (double)ScalarMath.binomial((long)n, (long)k);
            for (int i = 0; i < this.nAtoms; ++i) {
                for (int j = 0; j < 3; ++j) {
                    double[] dArray = this.inducedDipole[0][i];
                    int n2 = j;
                    dArray[n2] = dArray[n2] + c * this.predictorInducedDipole[this.mode][index][i][j];
                    double[] dArray2 = this.inducedDipoleCR[0][i];
                    int n3 = j;
                    dArray2[n3] = dArray2[n3] + c * this.predictorInducedDipoleCR[this.mode][index][i][j];
                }
            }
            if (++index < this.predictorOrder) continue;
            index = 0;
        }
    }

    public static enum PredictorMode {
        NONE,
        LS,
        POLY,
        ASPC;

    }

    private class LeastSquaresPredictor {
        double[] weights;
        double[] target;
        double[][] jacobian;
        double[] initialSolution;
        double tolerance;
        RealVector valuesVector;
        RealVector targetVector;
        RealMatrix jacobianMatrix;
        LeastSquaresOptimizer optimizer;
        ConvergenceChecker<LeastSquaresProblem.Evaluation> checker;
        Pair<RealVector, RealMatrix> test;
        MultivariateJacobianFunction function;
        final /* synthetic */ ScfPredictor this$0;

        public LeastSquaresPredictor(ScfPredictor scfPredictor, double eps) {
            ScfPredictor scfPredictor2 = scfPredictor;
            Objects.requireNonNull(scfPredictor2);
            this.this$0 = scfPredictor2;
            this.test = new Pair((Object)this.targetVector, (Object)this.jacobianMatrix);
            this.function = new MultivariateJacobianFunction(this){
                final /* synthetic */ LeastSquaresPredictor this$1;
                {
                    LeastSquaresPredictor leastSquaresPredictor = this$1;
                    Objects.requireNonNull(leastSquaresPredictor);
                    this.this$1 = leastSquaresPredictor;
                }

                public Pair<RealVector, RealMatrix> value(RealVector point) {
                    return new Pair((Object)this.this$1.targetVector, (Object)this.this$1.jacobianMatrix);
                }
            };
            this.tolerance = eps;
            this.weights = new double[2 * scfPredictor.nAtoms * 3];
            this.target = new double[2 * scfPredictor.nAtoms * 3];
            this.jacobian = new double[2 * scfPredictor.nAtoms * 3][scfPredictor.predictorOrder - 1];
            this.initialSolution = new double[scfPredictor.predictorOrder - 1];
            Arrays.fill(this.weights, 1.0);
            Arrays.fill(this.initialSolution, 0.0);
            this.initialSolution[0] = 1.0;
            this.optimizer = new LevenbergMarquardtOptimizer().withParameterRelativeTolerance(eps);
            this.checker = new EvaluationRmsChecker(eps);
        }

        public LeastSquaresOptimizer.Optimum predict(int maxEval, int maxIter) {
            ArrayRealVector start = new ArrayRealVector(this.initialSolution);
            LeastSquaresProblem lsp = LeastSquaresFactory.create((MultivariateJacobianFunction)this.function, (RealVector)this.targetVector, (RealVector)start, this.checker, (int)maxEval, (int)maxIter);
            LeastSquaresOptimizer.Optimum optimum = this.optimizer.optimize(lsp);
            logger.info(String.format(" LS Optimization parameters:  %s %s\n  %s %s\n  %d %d", this.function, this.targetVector.toString(), start, this.checker.toString(), maxIter, maxEval));
            return optimum;
        }

        public void updateJacobianAndTarget() {
            int index = 0;
            for (int i = 0; i < this.this$0.nAtoms; ++i) {
                this.target[index++] = this.this$0.predictorInducedDipole[this.this$0.mode][this.this$0.predictorStartIndex][i][0];
                this.target[index++] = this.this$0.predictorInducedDipole[this.this$0.mode][this.this$0.predictorStartIndex][i][1];
                this.target[index++] = this.this$0.predictorInducedDipole[this.this$0.mode][this.this$0.predictorStartIndex][i][2];
                this.target[index++] = this.this$0.predictorInducedDipoleCR[this.this$0.mode][this.this$0.predictorStartIndex][i][0];
                this.target[index++] = this.this$0.predictorInducedDipoleCR[this.this$0.mode][this.this$0.predictorStartIndex][i][1];
                this.target[index++] = this.this$0.predictorInducedDipoleCR[this.this$0.mode][this.this$0.predictorStartIndex][i][2];
            }
            this.targetVector = new ArrayRealVector(this.target);
            index = this.this$0.predictorStartIndex + 1;
            if (index >= this.this$0.predictorOrder) {
                index = 0;
            }
            for (int j = 0; j < this.this$0.predictorOrder - 1; ++j) {
                int ji = 0;
                for (int i = 0; i < this.this$0.nAtoms; ++i) {
                    this.jacobian[ji++][j] = this.this$0.predictorInducedDipole[this.this$0.mode][index][i][0];
                    this.jacobian[ji++][j] = this.this$0.predictorInducedDipole[this.this$0.mode][index][i][1];
                    this.jacobian[ji++][j] = this.this$0.predictorInducedDipole[this.this$0.mode][index][i][2];
                    this.jacobian[ji++][j] = this.this$0.predictorInducedDipoleCR[this.this$0.mode][index][i][0];
                    this.jacobian[ji++][j] = this.this$0.predictorInducedDipoleCR[this.this$0.mode][index][i][1];
                    this.jacobian[ji++][j] = this.this$0.predictorInducedDipoleCR[this.this$0.mode][index][i][2];
                }
                if (++index < this.this$0.predictorOrder) continue;
                index = 0;
            }
            this.jacobianMatrix = new Array2DRowRealMatrix(this.jacobian);
        }

        private RealVector value(double[] variables) {
            double[] values = new double[2 * this.this$0.nAtoms * 3];
            for (int i = 0; i < this.this$0.nAtoms; ++i) {
                int index = 6 * i;
                values[index] = 0.0;
                values[index + 1] = 0.0;
                values[index + 2] = 0.0;
                values[index + 3] = 0.0;
                values[index + 4] = 0.0;
                values[index + 5] = 0.0;
                int pi = this.this$0.predictorStartIndex + 1;
                if (pi >= this.this$0.predictorOrder) {
                    pi = 0;
                }
                for (int j = 0; j < this.this$0.predictorOrder - 1; ++j) {
                    int n = index;
                    values[n] = values[n] + variables[j] * this.this$0.predictorInducedDipole[this.this$0.mode][pi][i][0];
                    int n2 = index + 1;
                    values[n2] = values[n2] + variables[j] * this.this$0.predictorInducedDipole[this.this$0.mode][pi][i][1];
                    int n3 = index + 2;
                    values[n3] = values[n3] + variables[j] * this.this$0.predictorInducedDipole[this.this$0.mode][pi][i][2];
                    int n4 = index + 3;
                    values[n4] = values[n4] + variables[j] * this.this$0.predictorInducedDipoleCR[this.this$0.mode][pi][i][0];
                    int n5 = index + 4;
                    values[n5] = values[n5] + variables[j] * this.this$0.predictorInducedDipoleCR[this.this$0.mode][pi][i][1];
                    int n6 = index + 5;
                    values[n6] = values[n6] + variables[j] * this.this$0.predictorInducedDipoleCR[this.this$0.mode][pi][i][2];
                    if (++pi < this.this$0.predictorOrder) continue;
                    pi = 0;
                }
            }
            return new ArrayRealVector(values);
        }
    }
}

