/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.nonbonded;

import ffx.crystal.Crystal;
import ffx.numerics.tornado.FFXTornado;
import ffx.potential.bonded.Atom;
import ffx.potential.bonded.Bond;
import ffx.potential.nonbonded.VanDerWaals;
import ffx.potential.nonbonded.VanDerWaalsForm;
import ffx.potential.parameters.AtomType;
import ffx.potential.parameters.ForceField;
import ffx.potential.parameters.VDWType;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Reduce;
import uk.ac.manchester.tornado.api.common.TornadoDevice;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;

public class VanDerWaalsTornado
extends VanDerWaals {
    /** Class-wide logger. */
    private static final Logger logger = Logger.getLogger(VanDerWaalsTornado.class.getName());
    // Component offsets 0, 1 and 2.
    // NOTE(review): presumably the x/y/z offsets inside a packed coordinate triple - confirm.
    private static final byte XX = 0;
    private static final byte YY = 1;
    private static final byte ZZ = 2;
    /** Functional-form parameters built from the force field in the constructor (eps, rMin, 1-2/1-3/1-4 scale factors). */
    private final VanDerWaalsForm vdwForm;
    /** Distance at which the switching polynomial starts; the constructor sets it to 0.9 * vdwCutoff. */
    private final double vdwTaper;
    /** Pair cut-off distance supplied to the constructor. */
    private final double vdwCutoff;
    /** Interaction count stored by the most recent call to energy(). */
    private int interactions;
    /** Energy stored by the most recent call to energy(). */
    private double energy;
    /** Packed gradient buffer, three entries per atom; zeroed in energy() when a gradient is requested. */
    private double[] grad;
    /** Crystal supplying the A/Ai matrices and the aperiodic flag read in energy(). */
    private Crystal crystal;
    /** Atoms being evaluated. */
    private Atom[] atoms;
    /** Force field used to look up the VDWINDEX setting and per-atom VDW types. */
    private ForceField forceField;
    /** Number of atoms, kept equal to atoms.length. */
    private int nAtoms;
    /** Per-atom force-field index (atom type or atom class, depending on VDWINDEX) used to index eps and rMin. */
    private int[] atomClass;
    /** Packed atomic coordinates, three entries per atom; refreshed from the atoms at the start of energy(). */
    private double[] coordinates;
    /** Per-atom reduction factor; 0.0 for atoms that are not reduced. */
    private double[] reductionValue;
    /** Per-atom index of the atom the site is reduced toward; the atom's own index when not reduced. */
    private int[] reductionIndex;
    /** Concatenated 1-2, 1-3 and 1-4 partner indices for all atoms, addressed through maskPointer. */
    private int[] mask;
    /**
     * Offsets into mask: for atom i, entries 3i, 3i+1 and 3i+2 are the starts of its 1-2, 1-3 and
     * 1-4 partner lists; the entry that follows each start is the exclusive end of that list.
     */
    private int[] maskPointer;

    /**
     * Creates a TornadoVM-backed van der Waals term.
     *
     * <p>Stores the inputs, builds the {@link VanDerWaalsForm} from the force field, initializes
     * the per-atom arrays and logs a summary of the cut-off settings. The taper start is fixed at
     * 90% of the cut-off.
     *
     * @param atoms atoms to evaluate; must not be null.
     * @param crystal crystal describing the periodic (or aperiodic) system.
     * @param forceField force field providing VDW parameters.
     * @param vdwCutoff pair cut-off distance.
     */
    public VanDerWaalsTornado(Atom[] atoms, Crystal crystal, ForceField forceField, double vdwCutoff) {
        // Plain assignments first; nothing below reads them before the final log call.
        this.vdwCutoff = vdwCutoff;
        vdwTaper = 0.9 * vdwCutoff;

        this.atoms = atoms;
        this.crystal = crystal;
        this.forceField = forceField;
        nAtoms = atoms.length;

        vdwForm = new VanDerWaalsForm(forceField);
        initAtomArrays();

        logger.info(toString());
    }

    /**
     * Kernel that accumulates the pairwise van der Waals energy, the number of interacting pairs
     * and (optionally) the gradient over all unique atom pairs {@code i < k}.
     *
     * <p>The pair energy is {@code ev * t1 * t2} with {@code t1 = (1 + delta)^7 / (rho + delta)^7}
     * and {@code t2 = (1 + gamma) / (rho^7 + gamma) - 2}, multiplied by a fifth-order switching
     * polynomial between the taper distance and the cut-off. Atoms whose {@code rMin} is not
     * positive, pairs beyond the cut-off and pairs whose bonded scale factor is not positive are
     * skipped.
     *
     * <p>All outputs are accumulated into; the caller is responsible for zeroing them first.
     *
     * @param atomClass per-atom index into {@code eps} and {@code rMin}; its length is the atom count.
     * @param eps well depth per class index.
     * @param rMin minimum-energy distance parameter per class index.
     * @param reducedXYZ packed interaction-site coordinates, three entries per atom.
     * @param reductionIndex per-atom index of the atom that receives the {@code (1 - reductionValue)}
     *     share of the atom's gradient.
     * @param reductionValue per-atom share of the gradient that stays on the atom itself.
     * @param bondedScaleFactors {@code {scale12, scale13, scale14}}.
     * @param maskPointers for atom i, entries {@code 3i .. 3i+3} delimit its 1-2, 1-3 and 1-4 ranges
     *     inside {@code masks}.
     * @param masks concatenated partner indices addressed through {@code maskPointers}.
     * @param A nine-element matrix applied to a Cartesian difference before rounding to the
     *     nearest image.
     * @param Ai nine-element matrix applied after rounding to return to Cartesian space.
     * @param cutoffs {@code {aperiodic flag, taper start, cut-off, gradient flag}}; a flag is true
     *     when its value is greater than zero.
     * @param energy single-element accumulator for the energy.
     * @param interactions single-element accumulator for the pair count.
     * @param grad packed gradient accumulator, three entries per atom; only written when the
     *     gradient flag is set.
     */
    private static void tornadoEnergy(int[] atomClass, double[] eps, double[] rMin, double[] reducedXYZ, int[] reductionIndex, double[] reductionValue, double[] bondedScaleFactors, int[] maskPointers, int[] masks, double[] A, double[] Ai, double[] cutoffs, @Reduce double[] energy, @Reduce int[] interactions, @Reduce double[] grad) {
        double A00 = A[0];
        double A01 = A[1];
        double A02 = A[2];
        double A10 = A[3];
        double A11 = A[4];
        double A12 = A[5];
        double A20 = A[6];
        double A21 = A[7];
        double A22 = A[8];
        double Ai00 = Ai[0];
        double Ai01 = Ai[1];
        double Ai02 = Ai[2];
        double Ai10 = Ai[3];
        double Ai11 = Ai[4];
        double Ai12 = Ai[5];
        double Ai20 = Ai[6];
        double Ai21 = Ai[7];
        double Ai22 = Ai[8];
        double scale12 = bondedScaleFactors[0];
        double scale13 = bondedScaleFactors[1];
        double scale14 = bondedScaleFactors[2];

        // Flags travel in the double array; anything greater than zero means "true".
        boolean aperiodic = cutoffs[0] > 0.0;
        double vdwTaper = cutoffs[1];
        double vdwCutoff = cutoffs[2];
        double vdwTaper2 = vdwTaper * vdwTaper;
        double vdwCutoff2 = vdwCutoff * vdwCutoff;
        boolean gradient = cutoffs[3] > 0.0;

        // Coefficients of the fifth-order switching polynomial between a (taper) and b (cut-off).
        double a = vdwTaper;
        double b = vdwCutoff;
        double a2 = a * a;
        double b2 = b * b;
        double ba = b - a;
        double ba2 = ba * ba;
        double denom = ba * ba2 * ba2;
        double c0 = b * b2 * (b2 - 5.0 * a * b + 10.0 * a2) / denom;
        double c1 = -30.0 * a2 * b2 / denom;
        double c2 = 30.0 * b * a * (b + a) / denom;
        double c3 = -10.0 * (a2 + 4.0 * a * b + b2) / denom;
        double c4 = 15.0 * (a + b) / denom;
        double c5 = -6.0 / denom;
        double twoC2 = 2.0 * c2;
        double threeC3 = 3.0 * c3;
        double fourC4 = 4.0 * c4;
        double fiveC5 = 5.0 * c5;

        // Per-atom scale factor applied to pairs with the current atom i; 1.0 unless masked.
        int nAtoms = atomClass.length;
        double[] mask = new double[nAtoms];
        for (int i = 0; i < nAtoms; ++i) {
            mask[i] = 1.0;
        }

        // Buffering constants of the pair function.
        double delta = 0.07;
        double gamma = 0.12;
        double t1n = 1.6057814764784302; // (1 + delta)^7
        double gamma1 = 1.12;            // 1 + gamma

        for (int i = 0; i < nAtoms - 1; ++i) {
            int i3 = i * 3;
            double xi = reducedXYZ[i3 + 0];
            double yi = reducedXYZ[i3 + 1];
            double zi = reducedXYZ[i3 + 2];
            int redi = reductionIndex[i];
            double redv = reductionValue[i];
            double rediv = 1.0 - redv;
            int classI = atomClass[i];
            double ei = eps[classI];
            double sei = TornadoMath.sqrt(ei);
            double ri = rMin[classI];
            if (ri <= 0.0) {
                // Atom i has no van der Waals site; no mask was applied, so nothing to undo.
                continue;
            }
            double gxi = 0.0;
            double gyi = 0.0;
            double gzi = 0.0;
            double gxredi = 0.0;
            double gyredi = 0.0;
            double gzredi = 0.0;

            // Apply the 1-2, 1-3 and 1-4 scale factors for the partners of atom i.
            for (int ii = maskPointers[i3]; ii < maskPointers[i3 + 1]; ++ii) {
                mask[masks[ii]] = scale12;
            }
            for (int ii = maskPointers[i3 + 1]; ii < maskPointers[i3 + 2]; ++ii) {
                mask[masks[ii]] = scale13;
            }
            for (int ii = maskPointers[i3 + 2]; ii < maskPointers[i3 + 3]; ++ii) {
                mask[masks[ii]] = scale14;
            }

            for (int k = i + 1; k < nAtoms; ++k) {
                int k3 = k * 3;
                double x = xi - reducedXYZ[k3 + 0];
                double y = yi - reducedXYZ[k3 + 1];
                double z = zi - reducedXYZ[k3 + 2];
                if (!aperiodic) {
                    // Transform with A, shift each component by the nearest whole number toward
                    // zero, then transform back with Ai.
                    double xf = x * A00 + y * A10 + z * A20;
                    double yf = x * A01 + y * A11 + z * A21;
                    double zf = x * A02 + y * A12 + z * A22;
                    double xfsn = 0.0;
                    if (-xf > 0.0) {
                        xfsn = 1.0;
                    } else if (-xf < 0.0) {
                        xfsn = -1.0;
                    }
                    double yfsn = 0.0;
                    if (-yf > 0.0) {
                        yfsn = 1.0;
                    } else if (-yf < 0.0) {
                        yfsn = -1.0;
                    }
                    double zfsn = 0.0;
                    if (-zf > 0.0) {
                        zfsn = 1.0;
                    } else if (-zf < 0.0) {
                        zfsn = -1.0;
                    }
                    xf = TornadoMath.floor(TornadoMath.abs(xf) + 0.5) * xfsn + xf;
                    yf = TornadoMath.floor(TornadoMath.abs(yf) + 0.5) * yfsn + yf;
                    zf = TornadoMath.floor(TornadoMath.abs(zf) + 0.5) * zfsn + zf;
                    x = xf * Ai00 + yf * Ai10 + zf * Ai20;
                    y = xf * Ai01 + yf * Ai11 + zf * Ai21;
                    z = xf * Ai02 + yf * Ai12 + zf * Ai22;
                }
                double r2 = x * x + y * y + z * z;
                int classK = atomClass[k];
                double rk = rMin[classK];
                // Negated comparisons are kept so that a NaN distance is skipped as well.
                if (!(r2 <= vdwCutoff2) || !(mask[k] > 0.0) || !(rk > 0.0)) {
                    continue;
                }

                // Combined radius and well depth for the pair.
                double ri2 = ri * ri;
                double ri3 = ri * ri2;
                double rk2 = rk * rk;
                double rk3 = rk * rk2;
                double irv = 1.0 / (2.0 * (ri3 + rk3) / (ri2 + rk2));
                double r = TornadoMath.sqrt(r2);
                double ek = eps[classK];
                double sek = TornadoMath.sqrt(ek);
                double ev = mask[k] * 4.0 * (ei * ek) / ((sei + sek) * (sei + sek));

                double rho = r * irv;
                double rho2 = rho * rho;
                double rhoDisp1 = rho2 * rho2 * rho2;
                double rhoDisp = rhoDisp1 * rho;
                double rhoD = rho + delta;
                double rhoD2 = rhoD * rhoD;
                double rhoDelta1 = rhoD2 * rhoD2 * rhoD2;
                double rhoDelta = rhoDelta1 * (rho + delta);
                double rhoDispGamma = rhoDisp + gamma;
                double t1d = 1.0 / rhoDelta;
                double t2d = 1.0 / rhoDispGamma;
                double t1 = t1n * t1d;
                double t2a = gamma1 * t2d;
                double t2 = t2a - 2.0;
                // Untapered pair energy; it must stay untapered for the product rule below.
                double eik = ev * t1 * t2;

                double taper = 1.0;
                double dtaper = 0.0;
                if (r2 > vdwTaper2) {
                    double r3 = r2 * r;
                    double r4 = r2 * r2;
                    double r5 = r2 * r3;
                    taper = c5 * r5 + c4 * r4 + c3 * r3 + c2 * r2 + c1 * r + c0;
                    dtaper = fiveC5 * r4 + fourC4 * r3 + threeC3 * r2 + twoC2 * r + c1;
                }
                energy[0] = energy[0] + eik * taper;
                interactions[0] = interactions[0] + 1;
                if (!gradient) {
                    continue;
                }

                int redk = reductionIndex[k];
                double red = reductionValue[k];
                double redkv = 1.0 - red;
                double dt1d_dr = 7.0 * rhoDelta1 * irv;
                double dt2d_dr = 7.0 * rhoDisp1 * irv;
                double dt1_dr = t1 * dt1d_dr * t1d;
                double dt2_dr = t2a * dt2d_dr * t2d;
                double dedr = -ev * (dt1_dr * t2 + t1 * dt2_dr);
                double ir = 1.0 / r;
                double drdx = x * ir;
                double drdy = y * ir;
                double drdz = z * ir;
                // d(e * S)/dr = e * S' + e' * S, using the untapered energy e.
                double dswitch = eik * dtaper + dedr * taper;
                double dedx = dswitch * drdx;
                double dedy = dswitch * drdy;
                double dedz = dswitch * drdz;

                // Atom i's contributions are summed locally and flushed after the k loop.
                gxi += dedx * redv;
                gyi += dedy * redv;
                gzi += dedz * redv;
                gxredi += dedx * rediv;
                gyredi += dedy * rediv;
                gzredi += dedz * rediv;

                // Split atom k's contribution between k and its reduction partner.
                grad[k3 + 0] = grad[k3 + 0] - red * dedx;
                grad[k3 + 1] = grad[k3 + 1] - red * dedy;
                grad[k3 + 2] = grad[k3 + 2] - red * dedz;
                int redk3 = redk * 3;
                grad[redk3 + 0] = grad[redk3 + 0] - redkv * dedx;
                grad[redk3 + 1] = grad[redk3 + 1] - redkv * dedy;
                grad[redk3 + 2] = grad[redk3 + 2] - redkv * dedz;
            }

            if (gradient) {
                grad[i3 + 0] = grad[i3 + 0] + gxi;
                grad[i3 + 1] = grad[i3 + 1] + gyi;
                grad[i3 + 2] = grad[i3 + 2] + gzi;
                int redi3 = redi * 3;
                grad[redi3 + 0] = grad[redi3 + 0] + gxredi;
                grad[redi3 + 1] = grad[redi3 + 1] + gyredi;
                grad[redi3 + 2] = grad[redi3 + 2] + gzredi;
            }

            // Restore the mask so the next atom starts from all ones.
            for (int ii = maskPointers[i3]; ii < maskPointers[i3 + 1]; ++ii) {
                mask[masks[ii]] = 1.0;
            }
            for (int ii = maskPointers[i3 + 1]; ii < maskPointers[i3 + 2]; ++ii) {
                mask[masks[ii]] = 1.0;
            }
            for (int ii = maskPointers[i3 + 2]; ii < maskPointers[i3 + 3]; ++ii) {
                mask[masks[ii]] = 1.0;
            }
        }
    }

    /**
     * Evaluates the van der Waals energy, and optionally the gradient, for the current atomic
     * coordinates.
     *
     * <p>The kernel is run twice: once directly on the JVM (result logged for reference, then
     * discarded) and once through a TornadoVM task graph on the default device. The device result
     * is what gets stored and returned. A SEVERE message is logged when the configured functional
     * form is not buffered 14-7, but evaluation still proceeds.
     *
     * @param gradient when true, the gradient is computed and added to every atom.
     * @param print not used by this implementation.
     * @return the energy produced by the TornadoVM execution.
     */
    @Override
    public double energy(boolean gradient, boolean print) {
        if (this.vdwForm.vdwType != VDWType.VDW_TYPE.BUFFERED_14_7) {
            logger.severe(" TornadoVM vdW only supports AMOEBA.");
        }

        // Snapshot all coordinates first: a reduced site may reference an atom with a higher index.
        for (int i = 0; i < this.nAtoms; ++i) {
            Atom atom = this.atoms[i];
            double x = atom.getX();
            double y = atom.getY();
            double z = atom.getZ();
            int i3 = i * 3;
            this.coordinates[i3 + 0] = x;
            this.coordinates[i3 + 1] = y;
            this.coordinates[i3 + 2] = z;
        }
        double[] eps = this.vdwForm.getEps();
        double[] rmin = this.vdwForm.getRmin();

        // Move each site toward its reduction partner: site = r * (atom - partner) + partner.
        double[] reducedXYZ = new double[this.nAtoms * 3];
        for (int i = 0; i < this.nAtoms; ++i) {
            int i3 = i * 3;
            double x = this.coordinates[i3 + 0];
            double y = this.coordinates[i3 + 1];
            double z = this.coordinates[i3 + 2];
            int redIndex = this.reductionIndex[i];
            if (redIndex >= 0) {
                int r3 = redIndex * 3;
                double rx = this.coordinates[r3 + 0];
                double ry = this.coordinates[r3 + 1];
                double rz = this.coordinates[r3 + 2];
                double r = this.reductionValue[i];
                reducedXYZ[i3 + 0] = r * (x - rx) + rx;
                reducedXYZ[i3 + 1] = r * (y - ry) + ry;
                reducedXYZ[i3 + 2] = r * (z - rz) + rz;
            } else {
                reducedXYZ[i3 + 0] = x;
                reducedXYZ[i3 + 1] = y;
                reducedXYZ[i3 + 2] = z;
            }
        }

        // Booleans are passed to the kernel as doubles (> 0 means true).
        double doGradient = gradient ? 1.0 : 0.0;
        double aperiodic = this.crystal.aperiodic() ? 1.0 : 0.0;

        Crystal c = this.crystal;
        double[] A = new double[]{c.A00, 0.0, 0.0, c.A10, c.A11, 0.0, c.A20, c.A21, c.A22};
        double[] Ai = new double[]{c.Ai00, 0.0, 0.0, c.Ai10, c.Ai11, 0.0, c.Ai20, c.Ai21, c.Ai22};
        double[] bondedScaleFactors = new double[]{this.vdwForm.scale12, this.vdwForm.scale13, this.vdwForm.scale14};
        double[] cutoffs = new double[]{aperiodic, this.vdwTaper, this.vdwCutoff, doGradient};
        double[] energy = new double[1];
        int[] interactions = new int[1];
        if (gradient) {
            Arrays.fill(this.grad, 0.0);
        }

        // Reference evaluation on the JVM; its result is only logged.
        VanDerWaalsTornado.tornadoEnergy(this.atomClass, eps, rmin, reducedXYZ, this.reductionIndex, this.reductionValue, bondedScaleFactors, this.maskPointer, this.mask, A, Ai, cutoffs, energy, interactions, this.grad);
        logger.info(String.format(" JVM: %16.8f %d", energy[0], interactions[0]));

        // Reset the accumulators before the device run.
        energy[0] = 0.0;
        interactions[0] = 0;
        if (gradient) {
            Arrays.fill(this.grad, 0.0);
        }

        TornadoDevice device = TornadoRuntimeProvider.getTornadoRuntime().getDefaultDevice();
        FFXTornado.logDevice(device);
        // NOTE(review): the literal 1 is the transfer-mode argument as emitted by the decompiler;
        // presumably an inlined DataTransferMode constant - confirm against the TornadoVM version in use.
        TaskGraph graph = new TaskGraph("vdW")
                .transferToDevice(1, new Object[]{this.atomClass, eps, rmin, reducedXYZ, this.reductionIndex, this.reductionValue, bondedScaleFactors, this.maskPointer, this.mask, A, Ai, cutoffs, energy, interactions, this.grad})
                .task("energy", VanDerWaalsTornado::tornadoEnergy, this.atomClass, eps, rmin, reducedXYZ, this.reductionIndex, this.reductionValue, bondedScaleFactors, this.maskPointer, this.mask, A, Ai, cutoffs, energy, interactions, this.grad)
                .transferToHost(1, new Object[]{energy, interactions, this.grad});
        ImmutableTaskGraph itg = graph.snapshot();
        TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(new ImmutableTaskGraph[]{itg});
        executionPlan.withWarmUp().withDevice(device);
        executionPlan.execute();
        logger.info(String.format(" Tornado OpenCL: %16.8f %d", energy[0], interactions[0]));

        if (gradient) {
            // Every atom, including the last one, receives contributions through the kernel's k index.
            for (int i = 0; i < this.nAtoms; ++i) {
                Atom ai = this.atoms[i];
                int i3 = i * 3;
                ai.addToXYZGradient(this.grad[i3 + 0], this.grad[i3 + 1], this.grad[i3 + 2]);
            }
        }
        this.energy = energy[0];
        this.interactions = interactions[0];
        return this.energy;
    }

    /**
     * Returns the energy stored by the most recent {@code energy(...)} evaluation.
     *
     * @return the last stored van der Waals energy.
     */
    @Override
    public double getEnergy() {
        return energy;
    }

    /**
     * Returns the pair count stored by the most recent {@code energy(...)} evaluation.
     *
     * @return the last stored number of interactions.
     */
    @Override
    public int getInteractions() {
        return interactions;
    }

    /**
     * Replaces the atoms being evaluated and rebuilds the per-atom arrays.
     *
     * @param atoms the new atom array; must not be null.
     */
    public void setAtoms(Atom[] atoms) {
        this.atoms = atoms;
        nAtoms = this.atoms.length;
        initAtomArrays();
    }

    /**
     * Stores the crystal used for subsequent evaluations.
     *
     * <p>A SEVERE message is logged when the crystal reports anything other than exactly one
     * symmetry operator; the crystal is stored regardless.
     *
     * @param crystal the new crystal; must not be null.
     */
    @Override
    public void setCrystal(Crystal crystal) {
        this.crystal = crystal;
        if (crystal.getNumSymOps() == 1) {
            return;
        }
        logger.log(Level.SEVERE, " SymOps are not supported by VanDerWaalsTornado.\n");
    }

    /**
     * Builds a short multi-line summary listing the switch-start and cut-off distances.
     *
     * @return the formatted summary.
     */
    @Override
    public String toString() {
        String header = "\n  Van der Waals\n";
        String switchLine = String.format("   Switch Start:                         %6.3f (A)\n", vdwTaper);
        String cutoffLine = String.format("   Cut-Off:                              %6.3f (A)\n", vdwCutoff);
        return header + switchLine + cutoffLine;
    }

    /**
     * (Re)builds every per-atom array from the current {@code atoms}: coordinates, class indices,
     * reduction data and the flattened 1-2 / 1-3 / 1-4 mask lists with their offsets.
     *
     * <p>The per-atom arrays are reallocated whenever the atom count differs from their current
     * length. The kernel derives its atom count from {@code atomClass.length}, so an oversized
     * array left over from a larger system would make it run past the arrays sized for the
     * current system.
     *
     * <p>If an atom has no VDW type in the force field, the problem is logged and the method
     * returns early, leaving the remaining entries unset.
     */
    private void initAtomArrays() {
        if (this.atomClass == null || this.nAtoms != this.atomClass.length) {
            this.atomClass = new int[this.nAtoms];
            this.coordinates = new double[this.nAtoms * 3];
            this.reductionIndex = new int[this.nAtoms];
            this.reductionValue = new double[this.nAtoms];
            this.grad = new double[this.nAtoms * 3];
            this.maskPointer = new int[this.nAtoms * 3 + 1];
        }

        int[][] mask12 = this.getMask12();
        int[][] mask13 = this.getMask13();
        int[][] mask14 = this.getMask14();

        // Size the flat mask from the lists that are actually copied into it.
        int maskSize = 0;
        for (int i = 0; i < this.nAtoms; ++i) {
            maskSize += mask12[i].length + mask13[i].length + mask14[i].length;
        }
        this.mask = new int[maskSize];

        // The indexing mode is a force-field-wide setting; look it up once.
        boolean indexByType = this.forceField.getString("VDWINDEX", "Class").equalsIgnoreCase("Type");

        int index = 0;
        for (int i = 0; i < this.nAtoms; ++i) {
            Atom ai = this.atoms[i];
            assert (i == ai.getXyzIndex() - 1);
            double[] xyz = ai.getXYZ(null);
            int i3 = i * 3;
            this.coordinates[i3 + 0] = xyz[0];
            this.coordinates[i3 + 1] = xyz[1];
            this.coordinates[i3 + 2] = xyz[2];

            AtomType atomType = ai.getAtomType();
            if (atomType == null) {
                // NOTE(review): execution continues after this log call and dereferences
                // atomType below; confirm whether logging at SEVERE is expected to abort.
                logger.severe(ai.toString());
            }
            this.atomClass[i] = indexByType ? atomType.type : atomType.atomClass;
            VDWType type = this.forceField.getVDWType(Integer.toString(this.atomClass[i]));
            if (type == null) {
                logger.info(" No VdW type for atom class " + this.atomClass[i]);
                logger.severe(" No VdW type for atom " + String.valueOf(ai));
                return;
            }
            ai.setVDWType(type);

            // Only singly-bonded atoms with a positive reduction factor are moved toward their partner.
            List<Bond> bonds = ai.getBonds();
            if (type.reductionFactor > 0.0 && bonds.size() == 1) {
                Atom heavyAtom = bonds.get(0).get1_2(ai);
                this.reductionIndex[i] = heavyAtom.getIndex() - 1;
                this.reductionValue[i] = type.reductionFactor;
            } else {
                this.reductionIndex[i] = i;
                this.reductionValue[i] = 0.0;
            }

            // Append this atom's 1-2, 1-3 and 1-4 partners and record where each list starts.
            this.maskPointer[3 * i] = index;
            for (int value : mask12[i]) {
                this.mask[index++] = value;
            }
            this.maskPointer[3 * i + 1] = index;
            for (int value : mask13[i]) {
                this.mask[index++] = value;
            }
            this.maskPointer[3 * i + 2] = index;
            for (int value : mask14[i]) {
                this.mask[index++] = value;
            }
        }
        // Sentinel end offset for the last atom's 1-4 list.
        this.maskPointer[3 * this.nAtoms] = index;
    }

    /**
     * Logs one pair interaction at INFO level.
     *
     * @param i array index of the first atom.
     * @param k array index of the second atom.
     * @param minr value whose reciprocal is printed in the first numeric column.
     * @param r value printed in the second numeric column.
     * @param eij value printed in the third numeric column.
     */
    private void log(int i, int k, double minr, double r, double eij) {
        // Array initializers evaluate left to right, matching the original argument order.
        Object[] columns = {
            atoms[i].getIndex(),
            atoms[i].getAtomType().name,
            atoms[k].getIndex(),
            atoms[k].getAtomType().name,
            1.0 / minr,
            r,
            eij
        };
        logger.info(String.format("VDW %6d-%s %6d-%s %10.4f  %10.4f  %10.4f", columns));
    }
}

