/*
 * Decompiled with CFR 0.152.
 */
package ffx.algorithms.optimize.manybody;

import ffx.algorithms.optimize.RotamerOptimization;
import ffx.algorithms.optimize.manybody.DistanceMatrix;
import ffx.algorithms.optimize.manybody.EnergyExpansion;
import ffx.potential.bonded.MultiResidue;
import ffx.potential.bonded.Residue;
import ffx.potential.bonded.Rotamer;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;

/**
 * Records which rotamers and rotamer pairs have been eliminated (or pruned) during a rotamer
 * optimization, and provides the elimination/pruning operations that update that record.
 *
 * <p>Single eliminations are stored in {@code eliminatedSingles[i][ri]}. Pair eliminations are
 * stored only in the upper triangle, {@code eliminatedPairs[i][ri][j][rj]} with {@code i < j}; the
 * pair accessors below swap their arguments when needed, so callers may pass the two residues in
 * either order (but never the same residue twice).
 */
public class EliminatedRotamers {
    private static final Logger logger = Logger.getLogger(EliminatedRotamers.class.getName());
    /** Owning optimizer; used here for rank-0 logging and energy formatting. */
    private final RotamerOptimization rO;
    /** Distance matrix consulted (via checkPairDistThreshold) when searching for the minimum pair energy. */
    private final DistanceMatrix dM;
    /** All residues; used to translate a residue into its distance-matrix index. */
    private final List<Residue> allResiduesList;
    /** Look-ahead depth used when deciding whether a rotamer/pair is still usable (0 disables filtering). */
    private final int maxRotCheckDepth;
    /** Self-energy clash threshold for ordinary residues. */
    private final double clashThreshold;
    /** Pair-energy clash threshold. */
    private final double pairClashThreshold;
    /** Self-energy clash threshold for MultiResidue instances. */
    private final double multiResClashThreshold;
    /** Threshold scale factor applied to nucleic-acid residues (and to pairs of two nucleic acids). */
    private final double nucleicPruningFactor;
    /** Threshold scale factor applied to pairs containing exactly one nucleic-acid residue. */
    private final double nucleicPairsPruningFactor;
    /** Amount added to the pair-clash threshold for each MultiResidue in a pair. */
    private final double multiResPairClashAddn;
    /** Whether pruneSingleClashes is enabled. */
    private final boolean pruneClashes;
    /** Whether prunePairClashes is enabled. */
    private final boolean prunePairClashes;
    /** Verbose flag passed to the elimination routines. */
    private final boolean print;
    /** eliminatedSingles[i][ri] is true once rotamer ri of residue i is eliminated. */
    public boolean[][] eliminatedSingles;
    /** eliminatedPairs[i][ri][j][rj] (i < j) is true once the pair is eliminated. */
    public boolean[][][][] eliminatedPairs;
    /** Not allocated by this class; the accessors treat null as "nothing pruned". */
    public boolean[][] onlyPrunedSingles;
    /** Not allocated by this class; the accessors treat null as "nothing pruned". */
    public boolean[][][][] onlyPrunedPairs;
    /** Energy expansion supplying self and 2-body energies; must be set before any pruning call. */
    private EnergyExpansion eE;

    /**
     * Creates the elimination record and allocates the elimination arrays for {@code residues}.
     *
     * @param rO owning rotamer optimization (logging and energy formatting).
     * @param dM distance matrix used by pair-clash pruning.
     * @param allResiduesList list of all residues, used to map residues to distance-matrix indices.
     * @param maxRotCheckDepth look-ahead depth for rotamer validity checks.
     * @param clashThreshold self-energy clash threshold.
     * @param pairClashThreshold pair-energy clash threshold.
     * @param multiResClashThreshold self-energy clash threshold for MultiResidues.
     * @param nucleicPruningFactor nucleic-acid threshold scale factor.
     * @param nucleicPairsPruningFactor threshold scale factor for mixed nucleic-acid pairs.
     * @param multiResPairClashAddn per-MultiResidue addition to the pair-clash threshold.
     * @param pruneClashes enable single clash pruning.
     * @param prunePairClashes enable pair clash pruning.
     * @param print verbose flag for elimination logging.
     * @param residues residues whose rotamers are tracked.
     */
    public EliminatedRotamers(RotamerOptimization rO, DistanceMatrix dM, List<Residue> allResiduesList, int maxRotCheckDepth, double clashThreshold, double pairClashThreshold, double multiResClashThreshold, double nucleicPruningFactor, double nucleicPairsPruningFactor, double multiResPairClashAddn, boolean pruneClashes, boolean prunePairClashes, boolean print, Residue[] residues) {
        this.rO = rO;
        this.dM = dM;
        this.allResiduesList = allResiduesList;
        this.maxRotCheckDepth = maxRotCheckDepth;
        this.clashThreshold = clashThreshold;
        this.pairClashThreshold = pairClashThreshold;
        this.multiResClashThreshold = multiResClashThreshold;
        this.nucleicPruningFactor = nucleicPruningFactor;
        this.nucleicPairsPruningFactor = nucleicPairsPruningFactor;
        this.multiResPairClashAddn = multiResPairClashAddn;
        this.pruneClashes = pruneClashes;
        this.prunePairClashes = prunePairClashes;
        this.print = print;
        this.allocateEliminationMemory(residues);
    }

    /**
     * Checks whether a single rotamer has been eliminated.
     *
     * @param i residue index.
     * @param ri rotamer index.
     * @return true if eliminated; false if eliminated or the array is not allocated.
     */
    public boolean check(int i, int ri) {
        if (this.eliminatedSingles == null) {
            return false;
        }
        return this.eliminatedSingles[i][ri];
    }

    /**
     * Checks whether a rotamer pair has been eliminated. Argument order does not matter, but
     * {@code i} and {@code j} must differ (no storage exists for i == j).
     *
     * @return true if the pair is eliminated; false otherwise or if the array is not allocated.
     */
    public boolean check(int i, int ri, int j, int rj) {
        if (this.eliminatedPairs == null) {
            return false;
        }
        // Only the i < j triangle is stored.
        if (j < i) {
            return this.eliminatedPairs[j][rj][i][ri];
        }
        return this.eliminatedPairs[i][ri][j][rj];
    }

    /**
     * Checks the {@code onlyPrunedPairs} record for a pair (argument order does not matter).
     *
     * @return the stored flag, or false if {@code onlyPrunedPairs} is null.
     */
    public boolean checkPrunedPairs(int i, int ri, int j, int rj) {
        if (this.onlyPrunedPairs == null) {
            return false;
        }
        if (j < i) {
            return this.onlyPrunedPairs[j][rj][i][ri];
        }
        return this.onlyPrunedPairs[i][ri][j][rj];
    }

    /**
     * Checks the {@code onlyPrunedSingles} record for a rotamer.
     *
     * @return the stored flag, or false if {@code onlyPrunedSingles} is null.
     */
    public boolean checkPrunedSingles(int i, int ri) {
        if (this.onlyPrunedSingles == null) {
            return false;
        }
        return this.onlyPrunedSingles[i][ri];
    }

    /** Returns true if (j, rj) is eliminated, or the pair (i, ri)-(j, rj) is eliminated. */
    public boolean checkToJ(int i, int ri, int j, int rj) {
        return this.check(j, rj) || this.check(i, ri, j, rj);
    }

    /** Returns true if (k, rk) is eliminated, or its pair with (i, ri) or (j, rj) is eliminated. */
    public boolean checkToK(int i, int ri, int j, int rj, int k, int rk) {
        return this.check(k, rk) || this.check(i, ri, k, rk) || this.check(j, rj, k, rk);
    }

    /** Returns true if (l, rl) is eliminated, or its pair with (i, ri), (j, rj) or (k, rk) is eliminated. */
    public boolean checkToL(int i, int ri, int j, int rj, int k, int rk, int l, int rl) {
        return this.check(l, rl) || this.check(i, ri, l, rl) || this.check(j, rj, l, rl) || this.check(k, rk, l, rl);
    }

    /**
     * Eliminates rotamer {@code ri} of residue {@code i}, along with every pair that contains it.
     * The last remaining valid rotamer of a residue is never eliminated.
     *
     * @param residues residues being optimized.
     * @param i residue index.
     * @param ri rotamer index.
     * @param verbose whether to log the elimination.
     * @return true if the rotamer was newly eliminated; false if it was already eliminated or no
     *     other valid rotamer remains.
     */
    public boolean eliminateRotamer(Residue[] residues, int i, int ri, boolean verbose) {
        if (this.check(i, ri)) {
            return false;
        }
        // rotamerCount returns rotamer indices (filtered when maxRotCheckDepth > 0), so compare the
        // stored values with ri rather than array positions.
        int rotCount = 0;
        for (int rot : this.rotamerCount(residues, i)) {
            if (rot != ri) {
                ++rotCount;
            }
        }
        if (rotCount == 0) {
            return false;
        }
        this.eliminatedSingles[i][ri] = true;
        if (verbose) {
            Rotamer[] rotamers = residues[i].getRotamers();
            this.rO.logIfRank0(String.format(" Rotamer (%8s,%2d) eliminated (%2d left).", residues[i].toString(rotamers[ri]), ri, rotCount));
        }
        // Pairs are always eliminated; only the summary log depends on verbose.
        int nPairs = this.eliminateRotamerPairs(residues, i, ri, verbose);
        if (nPairs > 0 && verbose) {
            this.rO.logIfRank0(String.format("  Eliminated %2d rotamer pairs.", nPairs));
        }
        return true;
    }

    /**
     * Eliminates the rotamer pair (i, ri)-(j, rj); argument order does not matter, but i != j.
     *
     * @return true if the pair was newly eliminated, false if it was already eliminated.
     */
    public boolean eliminateRotamerPair(Residue[] residues, int i, int ri, int j, int rj, boolean verbose) {
        if (i > j) {
            int ii = i;
            int iri = ri;
            i = j;
            ri = rj;
            j = ii;
            rj = iri;
        }
        if (this.check(i, ri, j, rj)) {
            return false;
        }
        this.eliminatedPairs[i][ri][j][rj] = true;
        if (verbose) {
            Rotamer[] rotI = residues[i].getRotamers();
            Rotamer[] rotJ = residues[j].getRotamers();
            this.rO.logIfRank0(String.format("  Rotamer pair eliminated: [(%8s,%2d) (%8s,%2d)]", residues[i].toString(rotI[ri]), ri, residues[j].toString(rotJ[rj]), rj));
        }
        return true;
    }

    /**
     * Eliminates every pair between rotamer (i, ri) and all rotamers of all other residues.
     *
     * @return the number of pairs newly eliminated.
     */
    public int eliminateRotamerPairs(Residue[] residues, int i, int ri, boolean verbose) {
        int eliminatedPairs = 0;
        for (int j = 0; j < residues.length; ++j) {
            if (j == i) {
                continue;
            }
            int lenRj = residues[j].getRotamers().length;
            for (int rj = 0; rj < lenRj; ++rj) {
                if (this.eliminateRotamerPair(residues, i, ri, j, rj, verbose)) {
                    ++eliminatedPairs;
                }
            }
        }
        return eliminatedPairs;
    }

    /**
     * Eliminates any rotamer of residue i (or j) that has no remaining non-eliminated pair with
     * any non-eliminated rotamer of the other residue.
     *
     * @return true if at least one rotamer was eliminated.
     */
    public boolean pairsToSingleElimination(Residue[] residues, int i, int j) {
        boolean pairRemaining;
        assert (i != j);
        assert (i < residues.length);
        assert (j < residues.length);
        Residue residueI = residues[i];
        Residue residueJ = residues[j];
        Rotamer[] rotI = residueI.getRotamers();
        Rotamer[] rotJ = residueJ.getRotamers();
        int lenRi = rotI.length;
        int lenRj = rotJ.length;
        boolean eliminated = false;
        for (int ri = 0; ri < lenRi; ++ri) {
            if (this.check(i, ri)) continue;
            pairRemaining = false;
            for (int rj = 0; rj < lenRj; ++rj) {
                if (this.check(j, rj) || this.check(i, ri, j, rj)) continue;
                pairRemaining = true;
                break;
            }
            if (pairRemaining) continue;
            if (this.eliminateRotamer(residues, i, ri, this.print)) {
                eliminated = true;
                this.rO.logIfRank0(String.format(" Eliminating rotamer %s-%d with no remaining pairs to residue %s.", residueI.toString(rotI[ri]), ri, residueJ));
                continue;
            }
            // eliminateRotamer also returns false when ri is the last valid rotamer of residue i.
            this.rO.logIfRank0(String.format(" Already eliminated rotamer %s-%d with no remaining pairs to residue %s.", residueI.toString(rotI[ri]), ri, residueJ), Level.WARNING);
        }
        for (int rj = 0; rj < lenRj; ++rj) {
            if (this.check(j, rj)) continue;
            pairRemaining = false;
            for (int ri = 0; ri < lenRi; ++ri) {
                if (this.check(i, ri) || this.check(i, ri, j, rj)) continue;
                pairRemaining = true;
                break;
            }
            if (pairRemaining) continue;
            if (this.eliminateRotamer(residues, j, rj, this.print)) {
                eliminated = true;
                this.rO.logIfRank0(String.format(" Eliminating rotamer %s-%d with no remaining pairs to residue %s.", residueJ.toString(rotJ[rj]), rj, residueI));
                continue;
            }
            this.rO.logIfRank0(String.format(" Already eliminated rotamer J %s-%d with no remaining pairs to residue %s.", residueJ.toString(rotJ[rj]), rj, residueI), Level.WARNING);
        }
        return eliminated;
    }

    /**
     * Eliminates every valid rotamer pair whose 2-body energy is NaN. Requires the energy
     * expansion to be set.
     */
    public void prePrunePairs(Residue[] residues) {
        int nResidues = residues.length;
        for (int i = 0; i < nResidues - 1; ++i) {
            Residue resI = residues[i];
            int ni = resI.getRotamers().length;
            for (int j = i + 1; j < nResidues; ++j) {
                Residue resJ = residues[j];
                int nj = resJ.getRotamers().length;
                for (int ri = 0; ri < ni; ++ri) {
                    if (!this.validRotamer(residues, i, ri)) continue;
                    for (int rj = 0; rj < nj; ++rj) {
                        if (!this.validRotamer(residues, j, rj) || this.check(i, ri, j, rj)) continue;
                        double twoBody = this.eE.get2Body(i, ri, j, rj);
                        if (!Double.isNaN(twoBody)) continue;
                        this.rO.logIfRank0(String.format(" Rotamer Pair (%7s,%2d) (%7s,%2d) 2-body energy %12.4f pre-pruned since energy is NaN.", resI, ri, resJ, rj, twoBody));
                        this.eliminateRotamerPair(residues, i, ri, j, rj, this.print);
                    }
                }
            }
        }
    }

    /**
     * Eliminates every non-eliminated rotamer whose self energy is NaN. Requires the energy
     * expansion to be set.
     */
    public void prePruneSelves(Residue[] residues) {
        for (int i = 0; i < residues.length; ++i) {
            Residue residue = residues[i];
            int nRot = residue.getRotamers().length;
            for (int ri = 0; ri < nRot; ++ri) {
                if (this.check(i, ri)) continue;
                double self = this.eE.getSelf(i, ri);
                if (!Double.isNaN(self)) continue;
                this.rO.logIfRank0(String.format(" Rotamer (%7s,%2d) self-energy %12.4f pre-pruned since energy is NaN.", residue, ri, self));
                this.eliminateRotamer(residues, i, ri, false);
            }
        }
    }

    /**
     * For each residue pair, finds the lowest pair energy (2-body plus both self energies) among
     * non-eliminated rotamer pairs that checkPairDistThreshold does not flag. Every remaining pair
     * whose energy exceeds that minimum plus the pair-clash threshold is then eliminated, and
     * rotamers left without partners are eliminated via pairsToSingleElimination. Does nothing
     * unless pair-clash pruning is enabled.
     */
    public void prunePairClashes(Residue[] residues) {
        if (!this.prunePairClashes) {
            return;
        }
        int nResidues = residues.length;
        for (int i = 0; i < nResidues - 1; ++i) {
            Residue residueI = residues[i];
            Rotamer[] rotI = residueI.getRotamers();
            int lenRi = rotI.length;
            int indI = this.allResiduesList.indexOf(residueI);
            for (int j = i + 1; j < nResidues; ++j) {
                Residue residueJ = residues[j];
                Rotamer[] rotJ = residueJ.getRotamers();
                int lenRj = rotJ.length;
                int indJ = this.allResiduesList.indexOf(residueJ);
                double minPair = Double.MAX_VALUE;
                int minRI = -1;
                int minRJ = -1;
                // Stays true when every remaining pair is flagged by the distance check.
                boolean cutoffPair = true;
                for (int ri = 0; ri < lenRi; ++ri) {
                    if (this.check(i, ri)) continue;
                    for (int rj = 0; rj < lenRj; ++rj) {
                        if (this.check(j, rj) || this.check(i, ri, j, rj) || this.dM.checkPairDistThreshold(indI, ri, indJ, rj)) continue;
                        cutoffPair = false;
                        double pairEnergy = this.eE.get2Body(i, ri, j, rj) + this.eE.getSelf(i, ri) + this.eE.getSelf(j, rj);
                        assert (Double.isFinite(pairEnergy));
                        if (!(pairEnergy < minPair)) continue;
                        minPair = pairEnergy;
                        minRI = ri;
                        minRJ = rj;
                    }
                }
                if (cutoffPair) continue;
                assert (minRI >= 0 && minRJ >= 0);
                double threshold = this.computePairPruningThreshold(residueI, residueJ);
                double toEliminate = threshold + minPair;
                for (int ri = 0; ri < lenRi; ++ri) {
                    if (this.check(i, ri)) continue;
                    for (int rj = 0; rj < lenRj; ++rj) {
                        if (this.check(j, rj) || this.check(i, ri, j, rj)) continue;
                        double pairEnergy = this.eE.get2Body(i, ri, j, rj) + this.eE.getSelf(i, ri) + this.eE.getSelf(j, rj);
                        assert (Double.isFinite(pairEnergy));
                        if (!(pairEnergy > toEliminate)) continue;
                        this.rO.logIfRank0(String.format(" Pruning pair %s-%d %s-%d by %s-%d %s-%d; energy %s > %s + %s", residueI.toString(rotI[ri]), ri, residueJ.toString(rotJ[rj]), rj, residueI.toString(rotI[minRI]), minRI, residueJ.toString(rotJ[minRJ]), minRJ, this.rO.formatEnergy(pairEnergy), this.rO.formatEnergy(threshold), this.rO.formatEnergy(minPair)));
                        // Actually remove the clashing pair; previously it was only logged.
                        this.eliminateRotamerPair(residues, i, ri, j, rj, this.print);
                    }
                }
                this.pairsToSingleElimination(residues, i, j);
            }
        }
    }

    /**
     * Computes the pair-clash threshold for a residue pair. It is the base pair threshold, plus
     * {@code multiResPairClashAddn} for each MultiResidue in the pair, then scaled by the
     * nucleic-acid factor matching the number of nucleic-acid residues in the pair.
     */
    private double computePairPruningThreshold(Residue residueI, Residue residueJ) {
        double threshold = this.pairClashThreshold;
        if (residueI instanceof MultiResidue) {
            threshold += this.multiResPairClashAddn;
        }
        if (residueJ instanceof MultiResidue) {
            threshold += this.multiResPairClashAddn;
        }
        int numNARes = (residueI.getResidueType() == Residue.ResidueType.NA ? 1 : 0) + (residueJ.getResidueType() == Residue.ResidueType.NA ? 1 : 0);
        switch (numNARes) {
            case 0:
                break;
            case 1:
                threshold *= this.nucleicPairsPruningFactor;
                break;
            case 2:
                threshold *= this.nucleicPruningFactor;
                break;
            default:
                throw new ArithmeticException(" RotamerOptimization.prunePairClashes() has somehow found less than zero or more than two nucleic acid residues in a pair of residues. This result should be impossible.");
        }
        return threshold;
    }

    /**
     * For each residue, eliminates every rotamer whose self energy exceeds the residue's lowest
     * self energy by more than the clash threshold. That threshold is the MultiResidue threshold
     * for MultiResidues and is scaled for nucleic acids. Does nothing unless clash pruning is
     * enabled.
     */
    public void pruneSingleClashes(Residue[] residues) {
        if (!this.pruneClashes) {
            return;
        }
        for (int i = 0; i < residues.length; ++i) {
            Residue residue = residues[i];
            Rotamer[] rotamers = residue.getRotamers();
            int nRot = rotamers.length;
            double minEnergy = Double.MAX_VALUE;
            int minRot = -1;
            for (int ri = 0; ri < nRot; ++ri) {
                if (this.check(i, ri) || !(this.eE.getSelf(i, ri) < minEnergy)) continue;
                minEnergy = this.eE.getSelf(i, ri);
                minRot = ri;
            }
            // No reference rotamer (no rotamers, or no self energy below Double.MAX_VALUE):
            // nothing to prune against, and rotamers[minRot] would be out of bounds.
            if (minRot < 0) {
                continue;
            }
            double energyToPrune = residue instanceof MultiResidue ? this.multiResClashThreshold : this.clashThreshold;
            if (residue.getResidueType() == Residue.ResidueType.NA) {
                energyToPrune *= this.nucleicPruningFactor;
            }
            energyToPrune += minEnergy;
            for (int ri = 0; ri < nRot; ++ri) {
                if (this.check(i, ri) || !(this.eE.getSelf(i, ri) > energyToPrune) || !this.eliminateRotamer(residues, i, ri, this.print)) continue;
                this.rO.logIfRank0(String.format("  Rotamer (%7s,%2d) self-energy %s pruned by (%7s,%2d) %s.", residue.toString(rotamers[ri]), ri, this.rO.formatEnergy(this.eE.getSelf(i, ri)), residue.toString(rotamers[minRot]), minRot, this.rO.formatEnergy(minEnergy)));
            }
        }
    }

    /** Sets the energy expansion used by the pre-pruning and clash-pruning methods. */
    public void setEnergyExpansion(EnergyExpansion eE) {
        this.eE = eE;
    }

    /** Summarizes how many rotamers and rotamer pairs (i < j) are eliminated. */
    @Override
    public String toString() {
        int rotamerCount = 0;
        int pairCount = 0;
        int singles = 0;
        int pairs = 0;
        int nRes = this.eliminatedSingles.length;
        for (int i = 0; i < nRes; ++i) {
            int nRotI = this.eliminatedSingles[i].length;
            rotamerCount += nRotI;
            for (int ri = 0; ri < nRotI; ++ri) {
                if (this.eliminatedSingles[i][ri]) {
                    ++singles;
                }
                for (int j = i + 1; j < nRes; ++j) {
                    boolean[] row = this.eliminatedPairs[i][ri][j];
                    pairCount += row.length;
                    for (boolean eliminated : row) {
                        if (eliminated) {
                            ++pairs;
                        }
                    }
                }
            }
        }
        return String.format(" %d out of %d rotamers eliminated.\n", singles, rotamerCount) + String.format(" %d out of %d rotamer pairs eliminated.", pairs, pairCount);
    }

    /**
     * Sanity-checks the elimination state. Logs a SEVERE message for any residue with every
     * rotamer eliminated, and for any residue pair with every rotamer pair eliminated.
     *
     * @return always true; problems are reported only through the logger.
     */
    public boolean validateDEE(Residue[] residues) {
        int nRes = this.eliminatedSingles.length;
        for (int i = 0; i < nRes; ++i) {
            Residue residueI = residues[i];
            int ni = this.eliminatedSingles[i].length;
            boolean valid = false;
            for (int ri = 0; ri < ni && !valid; ++ri) {
                valid = !this.check(i, ri);
            }
            if (!valid) {
                logger.severe(String.format(" Coding error: all %d rotamers for residue %s eliminated.", ni, residueI));
            }
        }
        for (int i = 0; i < nRes; ++i) {
            Residue residueI = residues[i];
            int ni = residueI.getRotamers().length;
            for (int j = i + 1; j < nRes; ++j) {
                Residue residueJ = residues[j];
                int nj = residueJ.getRotamers().length;
                boolean valid = false;
                for (int ri = 0; ri < ni && !valid; ++ri) {
                    for (int rj = 0; rj < nj && !valid; ++rj) {
                        valid = !this.check(i, ri, j, rj);
                    }
                }
                if (!valid) {
                    logger.severe(String.format(" Coding error: all pairs for %s with residue %s eliminated.", residueI.toFormattedString(false, true), residueJ));
                }
            }
        }
        return true;
    }

    /**
     * Allocates eliminatedSingles and the upper-triangular eliminatedPairs, and logs the rotamer
     * count of each residue. Java zero-initializes boolean arrays, so every entry starts false.
     */
    private void allocateEliminationMemory(Residue[] residues) {
        int nRes = residues.length;
        this.eliminatedSingles = new boolean[nRes][];
        this.eliminatedPairs = new boolean[nRes][][][];
        this.rO.logIfRank0("\n     Residue  Nrot");
        for (int i = 0; i < nRes; ++i) {
            Residue residueI = residues[i];
            int lenRi = residueI.getRotamers().length;
            this.rO.logIfRank0(String.format(" %3d %8s %4d", i + 1, residueI.toFormattedString(false, true), lenRi));
            this.eliminatedSingles[i] = new boolean[lenRi];
            this.eliminatedPairs[i] = new boolean[lenRi][][];
            for (int ri = 0; ri < lenRi; ++ri) {
                // Entries for j <= i stay null; pairs are only stored for i < j.
                this.eliminatedPairs[i][ri] = new boolean[nRes][];
                for (int j = i + 1; j < nRes; ++j) {
                    this.eliminatedPairs[i][ri][j] = new boolean[residues[j].getRotamers().length];
                }
            }
        }
    }

    /**
     * A rotamer is valid when it is not eliminated. When maxRotCheckDepth > 1, it must also
     * keep at least one counted partner rotamer in every other residue (see rotamerPairCount).
     */
    private boolean validRotamer(Residue[] residues, int i, int ri) {
        if (this.check(i, ri)) {
            return false;
        }
        if (this.maxRotCheckDepth > 1) {
            int n = residues.length;
            for (int j = 0; j < n; ++j) {
                if (j != i && this.rotamerPairCount(residues, i, ri, j) == 0) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns the rotamer indices of residue i that are still usable. All indices are returned
     * when maxRotCheckDepth == 0; otherwise only those passing validRotamer.
     */
    private int[] rotamerCount(Residue[] residues, int i) {
        int ni = residues[i].getRotamers().length;
        if (this.maxRotCheckDepth == 0) {
            return IntStream.range(0, ni).toArray();
        }
        return IntStream.range(0, ni).filter(ri -> this.validRotamer(residues, i, ri)).toArray();
    }

    /**
     * A rotamer pair is valid when i != j, both rotamers are valid and the pair is not
     * eliminated. When maxRotCheckDepth > 1, every third residue must also retain at least one
     * compatible rotamer. Not currently called within this class.
     */
    private boolean validRotamerPair(Residue[] residues, int i, int ri, int j, int rj) {
        if (i == j) {
            return false;
        }
        if (!this.validRotamer(residues, i, ri) || !this.validRotamer(residues, j, rj)) {
            return false;
        }
        if (this.check(i, ri, j, rj)) {
            return false;
        }
        if (this.maxRotCheckDepth > 1) {
            int n = residues.length;
            for (int k = 0; k < n; ++k) {
                if (k != i && k != j && this.rotamerTripleCount(residues, i, ri, j, rj, k) == 0) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Counts the rotamers rj of residue j such that neither (j, rj) nor the pair is eliminated.
     * When maxRotCheckDepth > 2, every third residue must also keep a compatible rotamer.
     * Returns 0 when i == j or (i, ri) is eliminated.
     */
    private int rotamerPairCount(Residue[] residues, int i, int ri, int j) {
        if (i == j || this.check(i, ri)) {
            return 0;
        }
        int nRes = residues.length;
        int nj = residues[j].getRotamers().length;
        int pairCount = 0;
        for (int rj = 0; rj < nj; ++rj) {
            if (this.check(j, rj) || this.check(i, ri, j, rj)) continue;
            boolean valid = true;
            if (this.maxRotCheckDepth > 2) {
                for (int k = 0; k < nRes; ++k) {
                    if (k != i && k != j && this.rotamerTripleCount(residues, i, ri, j, rj, k) == 0) {
                        valid = false;
                        break;
                    }
                }
            }
            if (valid) {
                ++pairCount;
            }
        }
        return pairCount;
    }

    /**
     * Counts rotamers rk of residue k compatible with both (i, ri) and (j, rj): none of (k, rk),
     * (i, ri)-(k, rk) or (j, rj)-(k, rk) is eliminated. Returns 0 when the indices are not
     * distinct, or when (i, ri), (j, rj) or their pair is eliminated.
     */
    private int rotamerTripleCount(Residue[] residues, int i, int ri, int j, int rj, int k) {
        if (i == j || i == k || j == k) {
            return 0;
        }
        if (this.check(i, ri) || this.check(j, rj) || this.check(i, ri, j, rj)) {
            return 0;
        }
        int nk = residues[k].getRotamers().length;
        int tripleCount = 0;
        for (int rk = 0; rk < nk; ++rk) {
            if (!this.checkToK(i, ri, j, rj, k, rk)) {
                ++tripleCount;
            }
        }
        return tripleCount;
    }
}

