/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.openmm;

import com.sun.jna.ptr.PointerByReference;
import ffx.openmm.State;
import ffx.potential.bonded.Atom;
import ffx.potential.utils.EnergyException;
import java.util.Arrays;
import javax.annotation.Nullable;

/**
 * Wrapper around an OpenMM {@link State} whose accessors convert values from OpenMM units
 * (nanometers, kJ/mol, picoseconds) into FFX units (Angstroms, kcal/mol, picoseconds).
 *
 * <p>Each array accessor first checks that the matching data type was requested when the state was
 * created. If it was not, the caller-supplied array (possibly {@code null}) is returned unchanged.
 * Energies that were not requested are reported as {@code 0.0}.
 */
public class OpenMMState extends State {

    /** Data-type bit flag for positions (OpenMM_State_Positions). */
    private static final int POSITIONS = 1;
    /** Data-type bit flag for velocities (OpenMM_State_Velocities). */
    private static final int VELOCITIES = 2;
    /** Data-type bit flag for forces (OpenMM_State_Forces). */
    private static final int FORCES = 4;
    /** Data-type bit flag for energies (OpenMM_State_Energy). */
    private static final int ENERGY = 8;

    /** Kilocalories per kilojoule (1 / 4.184). */
    private static final double KCAL_PER_KJ = 0.2390057361376673;
    /** Angstroms per nanometer. */
    private static final double ANGSTROMS_PER_NM = 10.0;
    /** Nanometers per Angstrom; kept as a multiplier so gradient arithmetic is unchanged. */
    private static final double NM_PER_ANGSTROM = 0.1;

    /** Potential energy in kcal/mol, or 0.0 if energies were not requested. */
    public final double potentialEnergy;
    /** Kinetic energy in kcal/mol, or 0.0 if energies were not requested. */
    public final double kineticEnergy;
    /** Sum of potential and kinetic energy in kcal/mol, or 0.0 if energies were not requested. */
    public final double totalEnergy;
    /** Bit mask of the data types held by this state, read once from the parent class. */
    private final int dataTypes = super.getDataTypes();

    /**
     * Wraps a native OpenMM state and caches its energies in kcal/mol.
     *
     * @param pointer reference to the native OpenMM state.
     */
    protected OpenMMState(PointerByReference pointer) {
        super(pointer);
        // The dataTypes field initializer has already run at this point (after super()).
        if (stateContains(ENERGY)) {
            this.potentialEnergy = super.getPotentialEnergy() * KCAL_PER_KJ;
            this.kineticEnergy = super.getKineticEnergy() * KCAL_PER_KJ;
            this.totalEnergy = this.potentialEnergy + this.kineticEnergy;
        } else {
            this.potentialEnergy = 0.0;
            this.kineticEnergy = 0.0;
            this.totalEnergy = 0.0;
        }
    }

    /**
     * Computes per-atom accelerations as force / mass, scaled from nm to Angstroms.
     *
     * <p>Assumes atom masses are in amu (g/mol), which yields Angstrom/ps^2. TODO confirm.
     *
     * @param a optional output array; reused if its length equals 3 * atoms.length.
     * @param atoms all atoms of the system, in the same order as the OpenMM force array.
     * @return the accelerations, or {@code a} unchanged if forces were not requested.
     * @throws IllegalArgumentException if {@code atoms} is null/empty or its size does not match
     *     the number of force components.
     */
    public double[] getAccelerations(@Nullable double[] a, Atom[] atoms) {
        if (!stateContains(FORCES)) {
            return a;
        }
        // Validate arguments before pulling forces across the native boundary.
        if (atoms == null || atoms.length == 0) {
            throw new IllegalArgumentException("Atoms array must not be null or empty.");
        }
        double[] forces = getForces();
        int n = forces.length;
        if (atoms.length * 3 != n) {
            String message = String.format(" The number of atoms (%d) does not match the number of degrees of freedom (%d).", atoms.length, n);
            throw new IllegalArgumentException(message);
        }
        if (a == null || a.length != n) {
            a = new double[n];
        }
        int index = 0;
        for (Atom atom : atoms) {
            double mass = atom.getMass();
            a[index] = forces[index] * ANGSTROMS_PER_NM / mass;
            a[index + 1] = forces[index + 1] * ANGSTROMS_PER_NM / mass;
            a[index + 2] = forces[index + 2] * ANGSTROMS_PER_NM / mass;
            index += 3;
        }
        return a;
    }

    /**
     * Returns accelerations for active atoms only (see {@link #getAccelerations}).
     *
     * @param a optional output array; reused if it can hold 3 * (number of active atoms) values.
     * @param atoms all atoms of the system.
     * @return active-atom accelerations, or {@code a} unchanged if forces were not requested.
     */
    public double[] getActiveAccelerations(@Nullable double[] a, Atom[] atoms) {
        if (!stateContains(FORCES)) {
            return a;
        }
        return filterToActive(getAccelerations(null, atoms), a, atoms);
    }

    /**
     * Returns the energy gradient (negated force) converted from kJ/mol/nm to kcal/mol/Angstrom.
     *
     * @param g optional output array; reused if its length matches the force array.
     * @return the gradient, or {@code g} unchanged if forces were not requested.
     * @throws EnergyException if any gradient component is NaN or infinite.
     */
    public double[] getGradient(@Nullable double[] g) {
        if (!stateContains(FORCES)) {
            return g;
        }
        double[] forces = getForces();
        int n = forces.length;
        if (g == null || g.length != n) {
            g = new double[n];
        }
        for (int i = 0; i < n; ++i) {
            // Gradient = -force; nm -> Angstrom (x 0.1 per Angstrom), kJ -> kcal.
            double gi = -forces[i] * NM_PER_ANGSTROM * KCAL_PER_KJ;
            if (Double.isNaN(gi) || Double.isInfinite(gi)) {
                throw new EnergyException(String.format(" The gradient of degree of freedom %d is %8.3f.", i, gi));
            }
            g[i] = gi;
        }
        return g;
    }

    /**
     * Returns the gradient for active atoms only (see {@link #getGradient(double[])}).
     *
     * @param g optional output array; reused if it can hold 3 * (number of active atoms) values.
     * @param atoms all atoms of the system.
     * @return active-atom gradient, or {@code g} unchanged if forces were not requested.
     */
    public double[] getActiveGradient(@Nullable double[] g, Atom[] atoms) {
        if (!stateContains(FORCES)) {
            return g;
        }
        return filterToActive(getGradient(null), g, atoms);
    }

    /**
     * Returns the periodic box vectors in Angstroms.
     *
     * <p>NOTE(review): scales the array returned by the parent class in place. This assumes the
     * parent returns a fresh array on each call; confirm.
     *
     * @return a 3x3 array of box vectors, or {@code null} if positions were not requested (this
     *     accessor is gated on the positions flag).
     */
    public double[][] getPeriodicBoxVectors() {
        if (!stateContains(POSITIONS)) {
            return null;
        }
        double[][] latticeVectors = super.getPeriodicBoxVectors();
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < 3; j++) {
                latticeVectors[i][j] *= ANGSTROMS_PER_NM;
            }
        }
        return latticeVectors;
    }

    /**
     * Returns atomic coordinates converted from nm to Angstroms.
     *
     * @param x optional output array; reused if its length matches the position array.
     * @return the coordinates, or {@code x} unchanged if positions were not requested.
     */
    public double[] getPositions(@Nullable double[] x) {
        if (!stateContains(POSITIONS)) {
            return x;
        }
        double[] pos = getPositions();
        int n = pos.length;
        if (x == null || x.length != n) {
            x = new double[n];
        }
        for (int i = 0; i < n; ++i) {
            x[i] = pos[i] * ANGSTROMS_PER_NM;
        }
        return x;
    }

    /**
     * Returns coordinates for active atoms only (see {@link #getPositions(double[])}).
     *
     * @param x optional output array; reused if it can hold 3 * (number of active atoms) values.
     * @param atoms all atoms of the system.
     * @return active-atom coordinates, or {@code x} unchanged if positions were not requested.
     */
    public double[] getActivePositions(@Nullable double[] x, Atom[] atoms) {
        if (!stateContains(POSITIONS)) {
            return x;
        }
        return filterToActive(getPositions(null), x, atoms);
    }

    /**
     * Returns velocities converted from nm/ps to Angstrom/ps.
     *
     * @param v optional output array; reused if its length matches the velocity array.
     * @return the velocities, or {@code v} unchanged if velocities were not requested.
     */
    public double[] getVelocities(@Nullable double[] v) {
        if (!stateContains(VELOCITIES)) {
            return v;
        }
        double[] vel = getVelocities();
        int n = vel.length;
        if (v == null || v.length != n) {
            v = new double[n];
        }
        for (int i = 0; i < n; ++i) {
            v[i] = vel[i] * ANGSTROMS_PER_NM;
        }
        return v;
    }

    /**
     * Returns velocities for active atoms only (see {@link #getVelocities(double[])}).
     *
     * @param v optional output array; reused if it can hold 3 * (number of active atoms) values.
     * @param atoms all atoms of the system.
     * @return active-atom velocities, or {@code v} unchanged if velocities were not requested.
     */
    public double[] getActiveVelocities(@Nullable double[] v, Atom[] atoms) {
        if (!stateContains(VELOCITIES)) {
            return v;
        }
        return filterToActive(getVelocities(null), v, atoms);
    }

    /**
     * Returns the periodic box volume converted from nm^3 to Angstrom^3.
     *
     * <p>Unlike the other accessors, this is not gated on a data-type flag.
     *
     * @return the box volume in Angstrom^3.
     */
    public double getPeriodicBoxVolume() {
        return super.getPeriodicBoxVolume() * ANGSTROMS_PER_NM * ANGSTROMS_PER_NM * ANGSTROMS_PER_NM;
    }

    /** @return the cached potential energy in kcal/mol (0.0 if energies were not requested). */
    public double getPotentialEnergy() {
        return this.potentialEnergy;
    }

    /** @return the cached kinetic energy in kcal/mol (0.0 if energies were not requested). */
    public double getKineticEnergy() {
        return this.kineticEnergy;
    }

    /** @return the cached total energy in kcal/mol (0.0 if energies were not requested). */
    public double getTotalEnergy() {
        return this.totalEnergy;
    }

    /** @return the data-type bit mask cached when this state was constructed. */
    public int getDataTypes() {
        return this.dataTypes;
    }

    /**
     * Tests whether every bit of {@code dataType} is set in this state's data-type mask.
     *
     * @param dataType one or more OpenMM data-type bit flags.
     * @return true if all requested bits are present.
     */
    private boolean stateContains(int dataType) {
        return (this.dataTypes & dataType) == dataType;
    }

    /**
     * Copies the xyz triples of active atoms from {@code source} into a compact array.
     *
     * <p>If {@code target} is at least 3 * (active count) long, it is reused as-is. Entries past
     * the copied region are left untouched.
     *
     * @param source per-atom xyz values for all atoms (length 3 * atoms.length).
     * @param target optional output array.
     * @param atoms all atoms; {@link Atom#isActive()} selects which triples are kept.
     * @return the array holding the active-atom values.
     * @throws IllegalArgumentException if {@code source} or {@code atoms} is null, or if their
     *     lengths are inconsistent.
     */
    private static double[] filterToActive(double[] source, @Nullable double[] target, Atom[] atoms) {
        if (source == null || atoms == null) {
            throw new IllegalArgumentException("The arrays must be non-null.");
        }
        if (source.length != atoms.length * 3) {
            throw new IllegalArgumentException("Source array length must be three times the number of atoms.");
        }
        int count = (int) Arrays.stream(atoms).filter(Atom::isActive).count();
        if (target == null || target.length < count * 3) {
            target = new double[count * 3];
        }
        int sourceIndex = 0;
        int targetIndex = 0;
        for (Atom atom : atoms) {
            if (atom.isActive()) {
                target[targetIndex] = source[sourceIndex];
                target[targetIndex + 1] = source[sourceIndex + 1];
                target[targetIndex + 2] = source[sourceIndex + 2];
                targetIndex += 3;
            }
            sourceIndex += 3;
        }
        return target;
    }
}

