/*
 * Decompiled with CFR 0.152.
 */
package ffx.algorithms.optimize;

import com.google.common.collect.Lists;
import com.google.common.collect.MinMaxPriorityQueue;
import ffx.numerics.math.HilbertCurveTransforms;
import ffx.potential.AssemblyState;
import ffx.potential.ForceFieldEnergy;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.Atom;
import ffx.potential.bonded.Bond;
import ffx.potential.bonded.Molecule;
import ffx.potential.utils.Superpose;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.logging.Logger;
import org.apache.commons.math3.util.FastMath;

/**
 * Brute-force torsion scan over the rotatable bonds of a single molecule.
 *
 * <p>Each torsional bond is sampled at {@code nTorsionsPerBond} evenly spaced dihedral offsets
 * from its starting geometry. Configurations are enumerated by Hilbert-curve index (through
 * {@link HilbertCurveTransforms#hilbertIndexToCoordinates}). Indices whose coordinates fall
 * outside {@code [0, nTorsionsPerBond)} are skipped. The energy of every evaluated configuration
 * comes from the assembly's {@link ForceFieldEnergy}. A bounded priority queue keeps the
 * lowest-energy states, which are exposed through {@link #getStates()}, {@link #getEnergies()}
 * and {@link #getHilbertIndices()}.
 */
public class TorsionSearch {
    private static final Logger logger = Logger.getLogger(TorsionSearch.class.getName());
    private final MolecularAssembly molecularAssembly;
    private final ForceFieldEnergy forceFieldEnergy;
    /** Scratch coordinate buffer reused for every energy evaluation. */
    private final double[] x;
    /** False when the molecule is not part of the assembly; every search method then returns early. */
    private final boolean run;
    private final List<Bond> torsionalBonds;
    /** For each torsional bond, the atoms on the side that is rotated (the smaller side). */
    private final List<Atom[]> atomGroups;
    private final int nTorsionsPerBond;
    /** Bits per Hilbert coordinate: the smallest b with 2^b >= nTorsionsPerBond. */
    private final int nBits;
    private int nTorsionalBonds;
    private final int returnedStates;
    private long numConfigs;
    private long numIndices;
    private long end;
    private double minEnergy = Double.MAX_VALUE;
    private final List<Double> energies = new ArrayList<Double>();
    private final List<Long> hilbertIndices = new ArrayList<Long>();
    private final List<AssemblyState> states = new ArrayList<AssemblyState>();
    /** Bounded queue of the lowest-energy states seen so far (lazily created). */
    private MinMaxPriorityQueue<StateContainer> queue = null;
    /** First Hilbert index of each slice assigned to this worker. */
    private long[] workerAssignments;
    /** Last (inclusive) Hilbert index of each slice assigned to this worker. */
    private long[] workerAssignmentEnds;
    private long indicesPerAssignment;
    private boolean runWorker = false;

    /**
     * Creates a torsion search over the torsional bonds of one molecule in an assembly.
     *
     * @param molecularAssembly assembly whose potential energy is evaluated.
     * @param molecule molecule whose torsional bonds are scanned. If it is not part of
     *     {@code molecularAssembly}, a warning is logged and every search method returns early.
     * @param nTorsionsPerBond number of evenly spaced dihedral positions sampled per bond.
     * @param returnedStates maximum number of lowest-energy states kept. {@code -1} keeps one slot
     *     per configuration, capped at {@link Integer#MAX_VALUE}.
     */
    public TorsionSearch(MolecularAssembly molecularAssembly, Molecule molecule, int nTorsionsPerBond, int returnedStates) {
        this.molecularAssembly = molecularAssembly;
        this.run = Arrays.stream(molecularAssembly.getMoleculeArray()).anyMatch(mol -> mol == molecule);
        if (!this.run) {
            logger.warning("Molecule is not part of the assembly. Torsion scan will not run.");
        }
        this.forceFieldEnergy = molecularAssembly.getPotentialEnergy();
        this.x = new double[this.forceFieldEnergy.getNumberOfVariables()];
        this.nTorsionsPerBond = nTorsionsPerBond;
        this.nBits = bitsPerTorsion(nTorsionsPerBond);
        this.torsionalBonds = getTorsionalBonds(molecule);
        this.atomGroups = getRotationGroups(this.torsionalBonds);
        this.nTorsionalBonds = this.torsionalBonds.size();
        this.numConfigs = (long) FastMath.pow((double) nTorsionsPerBond, this.nTorsionalBonds);
        this.numIndices = (long) FastMath.pow(2.0, this.nBits * this.nTorsionalBonds);
        // Clamp so a very large configuration count cannot overflow the int queue capacity.
        this.returnedStates = returnedStates != -1
                ? returnedStates
                : (int) Math.min(this.numConfigs, (long) Integer.MAX_VALUE);
        this.end = this.numIndices;
    }

    /** Scans the full Hilbert index range, from 0 to {@link #getEnd()}. */
    public void spinTorsions() {
        // The two-argument overload prints the scan header, so it is not repeated here.
        this.spinTorsions(0L, this.end);
    }

    /**
     * Evaluates every valid configuration whose Hilbert index lies in {@code [start, end]}.
     *
     * <p>The end index is inclusive and is clamped to the last valid index. The scan stops early
     * once {@code numConfigs} configurations have been evaluated. Afterwards the molecule is
     * rotated back to its starting torsions and the saved state lists are refreshed.
     *
     * @param start first Hilbert index to consider.
     * @param end last Hilbert index to consider (inclusive).
     */
    public void spinTorsions(long start, long end) {
        logger.info("\n ----------------- Starting torsion scan -----------------");
        logger.info(String.format(" Number of configurations: %d", this.numConfigs));
        logger.info(String.format(" Number of indices: %d", this.numIndices));
        this.spinTorsions(start, end, true);
    }

    /**
     * Core scan loop shared by the public entry points and the worker.
     *
     * @param start first Hilbert index to consider.
     * @param end last Hilbert index to consider (inclusive; clamped to numIndices - 1).
     * @param updateLists whether to drain the queue into the public state lists afterwards.
     */
    private void spinTorsions(long start, long end, boolean updateLists) {
        if (!this.run) {
            logger.warning("Torsion spin returning early since molecule not part of molecular assembly.");
            return;
        }
        this.ensureQueue();
        // Valid Hilbert indices are 0 .. numIndices - 1; never evaluate past the last one.
        long last = Math.min(end, this.numIndices - 1L);
        long[] currentState = new long[this.nTorsionalBonds];
        long progress = 0L;
        for (long index = start; index <= last && progress < this.numConfigs; ++index) {
            long[] newState = HilbertCurveTransforms.hilbertIndexToCoordinates(this.nTorsionalBonds, this.nBits, index);
            // Indices that map outside [0, nTorsionsPerBond) are padding of the 2^nBits grid.
            if (!this.isValidState(newState)) {
                continue;
            }
            changeState(currentState, newState, this.nTorsionsPerBond, this.torsionalBonds, this.atomGroups);
            this.forceFieldEnergy.getCoordinates(this.x);
            double energy = this.forceFieldEnergy.energy(this.x);
            if (energy < this.minEnergy) {
                this.minEnergy = energy;
                logger.info(String.format("\n New minimum energy: %12.5f", this.minEnergy));
                logger.info(String.format(" Hilbert Index: %d; Coordinate State: %s", index, Arrays.toString(newState)));
            }
            this.queue.add(new StateContainer(new AssemblyState(this.molecularAssembly), energy, index));
            currentState = newState;
            ++progress;
        }
        // Return every bond to its starting torsion (state 0).
        changeState(currentState, new long[this.nTorsionalBonds], this.nTorsionsPerBond, this.torsionalBonds, this.atomGroups);
        if (progress == this.numConfigs) {
            logger.info("\n Completed all configurations before end index.");
        }
        if (updateLists) {
            this.updateInfoLists();
        }
    }

    /**
     * Rotates each torsional bond on its own and removes bonds whose rotation is too costly.
     *
     * <p>Each bond is rotated through its positions while all other bonds stay at their starting
     * torsion. A bond is flagged when any of its positions raises the energy by more than
     * {@code |eliminationThreshold|} above the starting energy. States at or below the threshold
     * are added to the result queue with Hilbert index {@code -1}. Flagged indices are sorted in
     * descending order, and at most the first {@code numRemove} are removed from the search. The
     * configuration and index counts are then recomputed.
     *
     * <p>NOTE(review): because of the descending sort, the bonds with the highest indices are
     * removed, not those with the largest energy increase. Confirm this is intended.
     *
     * @param numRemove maximum number of bonds to remove.
     * @param eliminationThreshold energy increase (its absolute value is used) above which a bond
     *     is flagged.
     */
    public void staticAnalysis(int numRemove, double eliminationThreshold) {
        if (!this.run) {
            logger.warning("Static analysis returning early since molecule not part of molecular assembly.");
            return;
        }
        this.ensureQueue();
        double threshold = FastMath.abs(eliminationThreshold);
        this.forceFieldEnergy.getCoordinates(this.x);
        double initialE = this.forceFieldEnergy.energy(this.x);
        List<Integer> remove = new ArrayList<Integer>();
        long[] state = new long[this.nTorsionalBonds];
        long[] startingState = new long[this.nTorsionalBonds];
        for (int i = 0; i < this.nTorsionalBonds; ++i) {
            // Position 0 (the starting geometry) is evaluated once only, together with bond 0.
            for (int j = (i == 0) ? 0 : 1; j < this.nTorsionsPerBond; ++j) {
                state[i] = j;
                changeState(startingState, state, this.nTorsionsPerBond, this.torsionalBonds, this.atomGroups);
                this.forceFieldEnergy.getCoordinates(this.x);
                double newEnergy = this.forceFieldEnergy.energy(this.x);
                if (newEnergy - initialE > threshold) {
                    // Flag the bond once; high-energy states are never queued as candidates.
                    if (!remove.contains(i)) {
                        remove.add(i);
                    }
                } else {
                    this.queue.add(new StateContainer(new AssemblyState(this.molecularAssembly), newEnergy, -1L));
                }
                changeState(state, startingState, this.nTorsionsPerBond, this.torsionalBonds, this.atomGroups);
            }
            state[i] = 0L;
        }
        remove.sort(Collections.reverseOrder());
        logger.info("\n " + remove.size() + " bonds that cause large energy increase: " + String.valueOf(remove));
        if (remove.size() > numRemove) {
            remove = Lists.newArrayList(remove.subList(0, numRemove));
        }
        // Null out entries first so the indices stay valid, then compact both lists together.
        for (int i = remove.size() - 1; i >= 0; --i) {
            int bondIndex = remove.get(i);
            logger.info(" Removing bond: " + String.valueOf(this.torsionalBonds.get(bondIndex)));
            logger.info(" Bond index: " + String.valueOf(bondIndex));
            this.torsionalBonds.set(bondIndex, null);
            this.atomGroups.set(bondIndex, null);
        }
        this.torsionalBonds.removeAll(Collections.singleton(null));
        this.atomGroups.removeAll(Collections.singleton(null));
        this.nTorsionalBonds = this.torsionalBonds.size();
        this.end = (long) FastMath.pow(2.0, this.nBits * this.nTorsionalBonds);
        this.numConfigs = (long) FastMath.pow((double) this.nTorsionsPerBond, this.nTorsionalBonds);
        this.numIndices = this.end;
        logger.info(" Finished static analysis.");
        logger.info(String.format(" Number of configurations after elimination: %d", this.numConfigs));
        logger.info(String.format(" Number of indices after elimination: %d", this.numIndices));
        this.updateInfoLists();
    }

    /**
     * Prepares this instance to scan its share of the Hilbert index range in a parallel run.
     *
     * <p>The index range is cut into {@code worldSize} blocks. Each rank takes one contiguous slice
     * of every block (presumably to balance load across regions of the curve). The last block
     * absorbs the remainder of {@code numIndices / worldSize}. Within each block, the last rank
     * absorbs the remainder of {@code jobsPerWorker / worldSize}, so every index is covered exactly
     * once across all ranks.
     *
     * @param rank this worker's rank, in {@code [0, worldSize)}.
     * @param worldSize total number of workers.
     * @return {@code true} if the worker was built; {@code false} if the molecule is not part of
     *     the assembly or the rank is invalid.
     */
    public boolean buildWorker(int rank, int worldSize) {
        if (!this.run) {
            logger.warning("Build worker returning early since molecule not part of molecular assembly.");
            return false;
        }
        if (rank >= worldSize) {
            logger.warning(" Rank is greater than world size.");
            return false;
        }
        if (rank < 0) {
            logger.warning(" Rank must be non-negative.");
            return false;
        }
        this.runWorker = true;
        this.workerAssignments = new long[worldSize];
        this.workerAssignmentEnds = new long[worldSize];
        long jobsPerWorker = this.numIndices / (long) worldSize;
        this.indicesPerAssignment = jobsPerWorker / (long) worldSize;
        logger.info("\n Jobs per worker: " + jobsPerWorker);
        logger.info(" Jobs per worker split: " + this.indicesPerAssignment);
        boolean lastRank = rank == worldSize - 1;
        for (int i = 0; i < worldSize; ++i) {
            long blockStart = (long) i * jobsPerWorker;
            // Exclusive block end; the final block runs to the end of the index range.
            long blockEnd = (i == worldSize - 1) ? this.numIndices : blockStart + jobsPerWorker;
            this.workerAssignments[i] = blockStart + (long) rank * this.indicesPerAssignment;
            this.workerAssignmentEnds[i] = lastRank
                    ? blockEnd - 1L
                    : this.workerAssignments[i] + this.indicesPerAssignment - 1L;
        }
        logger.info(" Worker " + rank + " assigned indices: " + Arrays.toString(this.workerAssignments));
        return true;
    }

    /**
     * Scans every slice assigned by {@link #buildWorker(int, int)}. The saved state lists are
     * refreshed only after the final slice.
     */
    public void runWorker() {
        if (!this.run) {
            logger.warning("Worker returning early since molecule not part of molecular assembly.");
            return;
        }
        if (!this.runWorker) {
            logger.warning("Worker returning early since worker not built or is invalid.");
            return;
        }
        logger.info("\n ----------------- Starting torsion scan -----------------");
        logger.info(String.format("\n Number of configurations before worker starts: %d", this.numConfigs));
        logger.info(String.format(" Number of indices before worker starts: %d", this.numIndices));
        int nAssignments = this.workerAssignments.length;
        for (int i = 0; i < nAssignments; ++i) {
            long first = this.workerAssignments[i];
            long last = this.workerAssignmentEnds[i];
            logger.info(String.format(" Worker torsion assignment %3d of %3d: %12d to %12d.", i + 1, nAssignments, first, last));
            this.spinTorsions(first, last, i == nAssignments - 1);
        }
    }

    /** @return energies of the saved states, in ascending order as drained from the queue. */
    public List<Double> getEnergies() {
        return this.energies;
    }

    /** @return Hilbert index of each saved state ({@code -1} for states from static analysis). */
    public List<Long> getHilbertIndices() {
        return this.hilbertIndices;
    }

    /** @return saved assembly states, parallel to {@link #getEnergies()}. */
    public List<AssemblyState> getStates() {
        return this.states;
    }

    /** @return number of Hilbert indices in the current search space. */
    public long getEnd() {
        return this.end;
    }

    /** Lazily creates the bounded result queue, capped at {@code returnedStates} entries. */
    private void ensureQueue() {
        if (this.queue == null) {
            this.queue = MinMaxPriorityQueue.maximumSize(this.returnedStates).create();
        }
    }

    /** Returns true when every coordinate is a real torsion position (below nTorsionsPerBond). */
    private boolean isValidState(long[] state) {
        for (long torsion : state) {
            if (torsion >= (long) this.nTorsionsPerBond) {
                return false;
            }
        }
        return true;
    }

    /**
     * Exact ceil(log2(nTorsionsPerBond)), computed with integer arithmetic so powers of two are not
     * rounded up by floating-point error. Returns 0 for values of 1 or less.
     */
    private static int bitsPerTorsion(int nTorsionsPerBond) {
        if (nTorsionsPerBond <= 1) {
            return 0;
        }
        return Integer.SIZE - Integer.numberOfLeadingZeros(nTorsionsPerBond - 1);
    }

    /**
     * Merges previously saved states back into the bounded queue, then drains the queue in
     * ascending energy order into the public lists. The lists therefore hold the best states
     * across all scans so far.
     */
    private void updateInfoLists() {
        for (int i = 0; i < this.states.size(); ++i) {
            this.queue.add(new StateContainer(this.states.get(i), this.energies.get(i), this.hilbertIndices.get(i)));
        }
        this.states.clear();
        this.energies.clear();
        this.hilbertIndices.clear();
        while (!this.queue.isEmpty()) {
            StateContainer best = this.queue.removeFirst();
            this.states.add(best.getState());
            this.energies.add(best.getEnergy());
            this.hilbertIndices.add(best.getIndex());
        }
    }

    /**
     * Selects the molecule's bonds that are scanned. Excluded are bonds with a terminal atom
     * (one bond or fewer), bonds to a carbon carrying three hydrogens, and bonds for which
     * {@code isRing} is true.
     */
    private static List<Bond> getTorsionalBonds(Molecule molecule) {
        List<Bond> torsionalBonds = new ArrayList<Bond>();
        for (Bond bond : molecule.getBondList()) {
            Atom a1 = bond.getAtom(0);
            Atom a2 = bond.getAtom(1);
            boolean skip = a1.getBonds().size() <= 1
                    || a2.getBonds().size() <= 1
                    || (a1.getAtomicNumber() == 6 && a1.getNumberOfBondedHydrogen() == 3)
                    || (a2.getAtomicNumber() == 6 && a2.getNumberOfBondedHydrogen() == 3)
                    || a1.isRing(a2);
            if (skip) {
                continue;
            }
            torsionalBonds.add(bond);
            logger.info(" Bond " + String.valueOf(bond) + " is a torsional bond.");
        }
        return torsionalBonds;
    }

    /**
     * For each bond, collects the atoms reachable from each end without crossing the bond and
     * keeps the smaller side. Ties keep the side of atom 0. The seed bond atom is element 0 of
     * the returned array.
     */
    private static List<Atom[]> getRotationGroups(List<Bond> bonds) {
        List<Atom[]> rotationGroups = new ArrayList<Atom[]>();
        for (Bond bond : bonds) {
            Atom a1 = bond.getAtom(0);
            Atom a2 = bond.getAtom(1);
            List<Atom> a1Side = new ArrayList<Atom>();
            List<Atom> a2Side = new ArrayList<Atom>();
            searchTorsions(a1, a1Side, a2);
            searchTorsions(a2, a2Side, a1);
            List<Atom> smaller = a1Side.size() > a2Side.size() ? a2Side : a1Side;
            rotationGroups.add(smaller.toArray(new Atom[0]));
        }
        return rotationGroups;
    }

    /**
     * Rotates one atom by {@code theta} degrees about the unit axis {@code u}, which passes
     * through the origin.
     */
    private static void rotateAbout(double[] u, Atom atom, double theta) {
        double halfAngle = FastMath.toRadians(theta) / 2.0;
        double sinHalf = FastMath.sin(halfAngle);
        double[] q = new double[]{FastMath.cos(halfAngle), u[0] * sinHalf, u[1] * sinHalf, u[2] * sinHalf};
        double inverseNorm = 1.0 / Math.sqrt(q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]);
        for (int i = 0; i < 4; ++i) {
            q[i] *= inverseNorm;
        }
        double q1q1 = q[1] * q[1];
        double q2q2 = q[2] * q[2];
        double q3q3 = q[3] * q[3];
        double q0q1 = q[0] * q[1];
        double q0q2 = q[0] * q[2];
        double q0q3 = q[0] * q[3];
        double q1q2 = q[1] * q[2];
        double q1q3 = q[1] * q[3];
        double q2q3 = q[2] * q[3];
        // Rotation matrix of the unit quaternion q.
        double[][] rotation = new double[3][3];
        rotation[0][0] = 1.0 - 2.0 * (q2q2 + q3q3);
        rotation[0][1] = 2.0 * (q1q2 - q0q3);
        rotation[0][2] = 2.0 * (q1q3 + q0q2);
        rotation[1][0] = 2.0 * (q1q2 + q0q3);
        rotation[1][1] = 1.0 - 2.0 * (q1q1 + q3q3);
        rotation[1][2] = 2.0 * (q2q3 - q0q1);
        rotation[2][0] = 2.0 * (q1q3 - q0q2);
        rotation[2][1] = 2.0 * (q2q3 + q0q1);
        rotation[2][2] = 1.0 - 2.0 * (q1q1 + q2q2);
        double[] xyz = new double[3];
        atom.getXYZ(xyz);
        Superpose.applyRotation(xyz, rotation);
        atom.setXYZ(xyz);
    }

    /** Depth-first collection of all atoms reachable from {@code seed} without visiting {@code notAtom}. */
    private static void searchTorsions(Atom seed, List<Atom> atoms, Atom notAtom) {
        if (seed == null) {
            return;
        }
        atoms.add(seed);
        for (Bond b : seed.getBonds()) {
            Atom nextAtom = b.get1_2(seed);
            if (nextAtom == notAtom || atoms.contains(nextAtom)) {
                continue;
            }
            searchTorsions(nextAtom, atoms, notAtom);
        }
    }

    /**
     * Rotates each bond whose torsion index differs between {@code oldState} and
     * {@code newState}. Each step is {@code 360 / nTorsions} degrees, computed in floating point
     * so that step counts which do not divide 360 are still evenly spaced.
     */
    private static void changeState(long[] oldState, long[] newState, int nTorsions, List<Bond> bonds, List<Atom[]> atoms) {
        double stepDegrees = 360.0 / nTorsions;
        for (int i = 0; i < oldState.length; ++i) {
            if (oldState[i] == newState[i]) {
                continue;
            }
            double turnDegrees = (newState[i] - oldState[i]) * stepDegrees;
            rotateGroup(bonds.get(i), atoms.get(i), turnDegrees);
        }
    }

    /**
     * Rotates a rotation group about its bond. The pivot is the bond atom that seeded the group
     * (element 0). The axis points from the other bond atom to the pivot. Each atom is moved so
     * the pivot sits at the origin, rotated, and moved back.
     */
    private static void rotateGroup(Bond bond, Atom[] group, double turnDegrees) {
        Atom first = bond.getAtom(0);
        Atom second = bond.getAtom(1);
        boolean firstIsPivot = group[0] == first;
        double[] pivot = (firstIsPivot ? first : second).getXYZ(new double[3]);
        double[] other = (firstIsPivot ? second : first).getXYZ(new double[3]);
        double[] u = new double[3];
        double[] toOrigin = new double[3];
        for (int k = 0; k < 3; ++k) {
            u[k] = pivot[k] - other[k];
            toOrigin[k] = -pivot[k];
        }
        double inverseLength = 1.0 / Math.sqrt(u[0] * u[0] + u[1] * u[1] + u[2] * u[2]);
        for (int k = 0; k < 3; ++k) {
            u[k] *= inverseLength;
        }
        for (Atom atom : group) {
            atom.move(toOrigin);
            rotateAbout(u, atom, turnDegrees);
            atom.move(pivot);
        }
    }

    /** An evaluated configuration, ordered by energy (lowest first). */
    private record StateContainer(AssemblyState state, double e, long hilbertIndex) implements Comparable<StateContainer>
    {
        AssemblyState getState() {
            return this.state;
        }

        double getEnergy() {
            return this.e;
        }

        long getIndex() {
            return this.hilbertIndex;
        }

        @Override
        public int compareTo(StateContainer o) {
            return Double.compare(this.e, o.getEnergy());
        }
    }
}

