/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.utils;

import edu.rit.mp.Buf;
import edu.rit.mp.DoubleBuf;
import edu.rit.pj.Comm;
import ffx.crystal.SymOp;
import ffx.numerics.math.Double3;
import ffx.numerics.math.RunningStatistics;
import ffx.potential.AssemblyState;
import ffx.potential.ForceFieldEnergy;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.Atom;
import ffx.potential.parsers.DistanceMatrixFilter;
import ffx.potential.parsers.SystemFilter;
import ffx.potential.parsers.XYZFilter;
import ffx.utilities.StringUtils;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;

/**
 * Superposes the models of a base system onto the models of a target system and records the
 * mass-weighted coordinate RMSD of every (base, target) pair.
 *
 * <p>Rows of the resulting distance matrix correspond to models read from {@code baseFilter};
 * columns correspond to models read from {@code targetFilter}. Columns are distributed round-robin
 * across the processes of {@link Comm#world()} and combined with an all-gather after every
 * {@code numProc} columns.
 */
public class Superpose {
    private static final Logger logger = Logger.getLogger(Superpose.class.getName());

    /** Source of the row (base) models. */
    private final SystemFilter baseFilter;
    /** Source of the column (target) models. */
    private final SystemFilter targetFilter;
    /** When true, only the upper triangle is computed and the diagonal is set to zero. */
    private final boolean isSymmetric;
    /** Number of models in the base filter (rows). */
    private final int baseSize;
    /** Number of models in the target filter (columns). */
    private final int targetSize;
    /** Row to resume from, as determined by {@link #readMatrix}. */
    private int restartRow;
    /** Column to resume from, as determined by {@link #readMatrix}. */
    private int restartColumn;
    /** The current row of the distance matrix, filled by {@link #gatherRMSDs}. */
    private final double[] distRow;
    private final Comm world;
    private final int numProc;
    private final int rank;
    /** One single-element slot per process; the all-gather target. */
    private final double[][] distances;
    private final DoubleBuf[] buffers;
    /** This process's slot within {@link #distances}. */
    private final double[] myDistance;
    private final DoubleBuf myBuffer;

    /**
     * Creates a superposition driver for two systems.
     *
     * @param baseFilter   filter providing the row (base) models.
     * @param targetFilter filter providing the column (target) models.
     * @param isSymmetric  whether the matrix is symmetric, so only the upper triangle is computed.
     */
    public Superpose(SystemFilter baseFilter, SystemFilter targetFilter, boolean isSymmetric) {
        this.baseFilter = baseFilter;
        this.targetFilter = targetFilter;
        this.isSymmetric = isSymmetric;
        this.baseSize = baseFilter.countNumModels();
        this.targetSize = targetFilter.countNumModels();
        this.distRow = new double[this.targetSize];
        Arrays.fill(this.distRow, -1.0);
        this.world = Comm.world();
        this.numProc = this.world.size();
        this.rank = this.world.rank();
        if (this.numProc > 1) {
            logger.info(String.format(" Number of MPI Processes:  %d", this.numProc));
            logger.info(String.format(" Rank of this MPI Process: %d", this.rank));
        }
        this.distances = new double[this.numProc][1];
        this.buffers = new DoubleBuf[this.numProc];
        for (int i = 0; i < this.numProc; ++i) {
            Arrays.fill(this.distances[i], -1.0);
            this.buffers[i] = DoubleBuf.buffer(this.distances[i]);
        }
        this.myDistance = this.distances[this.rank];
        this.myBuffer = this.buffers[this.rank];
    }

    /**
     * Computes the RMSD between every base model and every target model after superposition.
     *
     * @param usedIndices   indices of the atoms used for superposition and RMSD.
     * @param dRMSD         also compute and log the distance-matrix RMSD of each pair.
     * @param verbose       log the per-pair RMSDs.
     * @param restart       resume from an existing ".dst" file next to the base file; otherwise
     *                      any existing ".dst" file is deleted.
     * @param write         on rank 0, append each completed row to the ".dst" file.
     * @param saveSnapshots write superposed coordinates to "*_superposed.arc" files (only when
     *                      running on a single process).
     * @param printSym      log the symmetry operators that map one structure onto the other.
     */
    public void calculateRMSDs(int[] usedIndices, boolean dRMSD, boolean verbose, boolean restart, boolean write, boolean saveSnapshots, boolean printSym) {
        String filename = this.baseFilter.getFile().getAbsolutePath();
        String targetFilename = this.targetFilter.getFile().getAbsolutePath();
        // All output files (snapshots and the distance matrix) are placed next to the base file.
        String baseDirectory = FilenameUtils.getFullPath(filename);

        File baseOutputFile = null;
        File targetOutputFile = null;
        XYZFilter baseOutputFilter = null;
        XYZFilter targetOutputFilter = null;
        if (saveSnapshots) {
            String baseOutputName = FilenameUtils.concat(baseDirectory, FilenameUtils.getBaseName(filename) + "_superposed.arc");
            String targetOutputName = FilenameUtils.concat(baseDirectory, FilenameUtils.getBaseName(targetFilename) + "_superposed.arc");
            baseOutputFile = SystemFilter.version(new File(baseOutputName));
            targetOutputFile = SystemFilter.version(new File(targetOutputName));
            MolecularAssembly baseAssembly = this.baseFilter.getActiveMolecularSystem();
            MolecularAssembly targetAssembly = this.targetFilter.getActiveMolecularSystem();
            baseOutputFilter = new XYZFilter(baseOutputFile, baseAssembly, baseAssembly.getForceField(), baseAssembly.getProperties());
            targetOutputFilter = new XYZFilter(targetOutputFile, targetAssembly, targetAssembly.getForceField(), targetAssembly.getProperties());
        }

        String matrixFilename = FilenameUtils.concat(baseDirectory, FilenameUtils.getBaseName(filename) + ".dst");
        RunningStatistics runningStatistics;
        if (restart) {
            runningStatistics = this.readMatrix(matrixFilename, this.isSymmetric, this.baseSize, this.targetSize);
            if (runningStatistics == null) {
                runningStatistics = new RunningStatistics();
            }
        } else {
            runningStatistics = new RunningStatistics();
            File file = new File(matrixFilename);
            if (file.exists() && file.delete()) {
                logger.info(String.format(" RMSD file (%s) was deleted.", matrixFilename));
                logger.info(" To restart from a previous run, use the '-r' flag.");
            }
        }

        int nUsed = usedIndices.length;
        int nUsedVars = nUsed * 3;
        MolecularAssembly baseMolecularAssembly = this.baseFilter.getActiveMolecularSystem();
        ForceFieldEnergy baseForceFieldEnergy = baseMolecularAssembly.getPotentialEnergy();
        int nVars = baseForceFieldEnergy.getNumberOfVariables();
        double[] baseCoords = new double[nVars];
        baseForceFieldEnergy.getCoordinates(baseCoords);
        double[] baseUsedCoords = new double[nUsedVars];
        Atom[] atoms = baseMolecularAssembly.getAtomArray();
        // Base-system masses are used to weight both structures.
        double[] mass = new double[nUsed];
        for (int i = 0; i < nUsed; ++i) {
            mass[i] = atoms[usedIndices[i]].getMass();
        }
        MolecularAssembly targetMolecularAssembly = this.targetFilter.getActiveMolecularSystem();
        ForceFieldEnergy targetForceFieldEnergy = targetMolecularAssembly.getPotentialEnergy();
        double[] targetCoords = new double[nVars];
        double[] targetUsedCoords = new double[nUsedVars];

        // Advance the base filter to the restart row.
        for (int row = 0; row < this.restartRow; ++row) {
            this.baseFilter.readNext(false, false);
        }

        // Pad the column count to a multiple of numProc so every process joins every all-gather.
        int paddedTargetSize = this.targetSize;
        int extra = this.targetSize % this.numProc;
        if (extra != 0) {
            paddedTargetSize = this.targetSize - extra + this.numProc;
            logger.fine(String.format(" Target size %d vs. Padded size %d", this.targetSize, paddedTargetSize));
        }

        for (int row = this.restartRow; row < this.baseSize; ++row) {
            this.myDistance[0] = -1.0;
            if (row == this.restartRow) {
                if (dRMSD) {
                    logger.info("\n Coordinate RMSD\n Snapshots       Original   After Translation   After Rotation     dRMSD");
                } else if (verbose) {
                    logger.info("\n Coordinate RMSD\n Snapshots       Original   After Translation   After Rotation");
                }
            }
            for (int column = this.restartColumn; column < paddedTargetSize; ++column) {
                if (column < this.targetSize) {
                    int targetRank = column % this.numProc;
                    if (targetRank == this.rank) {
                        if (this.isSymmetric && row == column) {
                            this.myDistance[0] = 0.0;
                            if (verbose) {
                                logger.info(String.format(" %6d  %6d  %s                             %8.5f", row + 1, column + 1, "Diagonal", this.myDistance[0]));
                            }
                        } else if (this.isSymmetric && row > column) {
                            // Lower triangle of a symmetric matrix is not computed.
                            this.myDistance[0] = -1.0;
                        } else {
                            baseForceFieldEnergy.getCoordinates(baseCoords);
                            targetForceFieldEnergy.getCoordinates(targetCoords);
                            Superpose.extractCoordinates(usedIndices, baseCoords, baseUsedCoords);
                            Superpose.extractCoordinates(usedIndices, targetCoords, targetUsedCoords);
                            double origRMSD = Superpose.rmsd(baseUsedCoords, targetUsedCoords, mass);

                            // Move both structures' (used-atom) centers of mass to the origin.
                            double[] baseTranslation = Superpose.calculateTranslation(baseUsedCoords, mass);
                            Superpose.applyTranslation(baseCoords, baseTranslation);
                            double[] targetTranslation = Superpose.calculateTranslation(targetUsedCoords, mass);
                            Superpose.applyTranslation(targetCoords, targetTranslation);
                            Superpose.extractCoordinates(usedIndices, baseCoords, baseUsedCoords);
                            Superpose.extractCoordinates(usedIndices, targetCoords, targetUsedCoords);
                            double translatedRMSD = Superpose.rmsd(baseUsedCoords, targetUsedCoords, mass);

                            // Rotate the target onto the base.
                            double[][] rotation = Superpose.calculateRotation(baseUsedCoords, targetUsedCoords, mass);
                            Superpose.applyRotation(targetCoords, rotation);
                            Superpose.extractCoordinates(usedIndices, targetCoords, targetUsedCoords);
                            double rotatedRMSD = Superpose.rmsd(baseUsedCoords, targetUsedCoords, mass);

                            if (dRMSD) {
                                double disRMSD = Superpose.calcDRMSD(baseUsedCoords, targetUsedCoords, nUsed * 3);
                                logger.info(String.format(" %6d  %6d  %8.5f            %8.5f         %8.5f  %8.5f", row + 1, column + 1, origRMSD, translatedRMSD, rotatedRMSD, disRMSD));
                            } else if (verbose) {
                                logger.info(String.format(" %6d  %6d  %8.5f            %8.5f         %8.5f", row + 1, column + 1, origRMSD, translatedRMSD, rotatedRMSD));
                            }
                            this.myDistance[0] = rotatedRMSD;

                            if (printSym) {
                                Superpose.logSymOps(filename, targetFilename, baseTranslation, targetTranslation, rotation, baseMolecularAssembly, targetMolecularAssembly);
                            }
                            if (saveSnapshots && this.numProc == 1) {
                                // The output filters wrap the same assemblies returned by
                                // getActiveMolecularSystem(), so setCoordinates() below changes
                                // the live base and target assemblies. Both must be restored.
                                // Otherwise later columns of this row would start from a
                                // translated base structure.
                                AssemblyState origStateA = new AssemblyState(baseMolecularAssembly);
                                AssemblyState origStateB = new AssemblyState(targetMolecularAssembly);
                                MolecularAssembly molecularAssembly = targetOutputFilter.getActiveMolecularSystem();
                                molecularAssembly.getPotentialEnergy().setCoordinates(targetCoords);
                                targetOutputFilter.writeFile(targetOutputFile, true);
                                molecularAssembly = baseOutputFilter.getActiveMolecularSystem();
                                molecularAssembly.getPotentialEnergy().setCoordinates(baseCoords);
                                baseOutputFilter.writeFile(baseOutputFile, true);
                                origStateB.revertState();
                                origStateA.revertState();
                            }
                        }
                    }
                    this.targetFilter.readNext(false, false);
                }
                // Exchange results once every process has handled one column of this group.
                if ((column + 1) % this.numProc == 0) {
                    this.gatherRMSDs(row, column, runningStatistics);
                }
            }
            this.restartColumn = 0;
            this.targetFilter.readNext(true, false);
            this.baseFilter.readNext(false, false);
            if (this.rank == 0 && write) {
                int firstColumn = this.isSymmetric ? row : 0;
                DistanceMatrixFilter.writeDistanceMatrixRow(matrixFilename, this.distRow, firstColumn);
            }
        }
        this.baseFilter.closeReader();
        this.targetFilter.closeReader();
        logger.info(String.format(" RMSD Minimum:  %8.6f", runningStatistics.getMin()));
        logger.info(String.format(" RMSD Maximum:  %8.6f", runningStatistics.getMax()));
        logger.info(String.format(" RMSD Mean:     %8.6f", runningStatistics.getMean()));
        double variance = runningStatistics.getVariance();
        if (!Double.isNaN(variance)) {
            logger.info(String.format(" RMSD Variance: %8.6f", variance));
        }
    }

    /**
     * Logs the symmetry operator that moves the target onto the base, and its inverse.
     *
     * @param filename          base file path (for the log message).
     * @param targetFilename    target file path (for the log message).
     * @param baseTranslation   translation applied to the base structure.
     * @param targetTranslation translation applied to the target structure.
     * @param rotation          rotation applied to the translated target structure.
     * @param baseAssembly      base assembly, used to list its active atoms.
     * @param targetAssembly    target assembly, used to list its active atoms.
     */
    private static void logSymOps(String filename, String targetFilename, double[] baseTranslation, double[] targetTranslation, double[][] rotation, MolecularAssembly baseAssembly, MolecularAssembly targetAssembly) {
        SymOp bestBaseSymOp = new SymOp(SymOp.ZERO_ROTATION, SymOp.Tr_0_0_0).append(new SymOp(SymOp.ZERO_ROTATION, baseTranslation).append(SymOp.invertSymOp(new SymOp(rotation, targetTranslation))));
        double[] inverseBaseTranslation = new double[]{-baseTranslation[0], -baseTranslation[1], -baseTranslation[2]};
        SymOp bestTargetSymOp = new SymOp(SymOp.ZERO_ROTATION, SymOp.Tr_0_0_0).append(new SymOp(SymOp.ZERO_ROTATION, targetTranslation).append(new SymOp(rotation, inverseBaseTranslation)));
        int[] mol1arr = activeAtomIndices(baseAssembly.getAtomArray());
        int[] mol2arr = activeAtomIndices(targetAssembly.getAtomArray());
        String atomRanges = String.format("    %s     %s", StringUtils.writeAtomRanges(mol2arr), StringUtils.writeAtomRanges(mol1arr));
        StringBuilder sbSO = new StringBuilder(String.format("\n Sym Op to move %s onto %s:\nsymop ", targetFilename, filename));
        StringBuilder sbInv = new StringBuilder(String.format("\n Inverted Sym Op to move %s onto %s:\nsymop ", filename, targetFilename));
        sbSO.append(atomRanges).append(SymOp.asMatrixString(bestBaseSymOp));
        sbInv.append(atomRanges).append(SymOp.asMatrixString(bestTargetSymOp));
        logger.info(String.format(" %s\n %s", sbSO, sbInv));
    }

    /**
     * Returns the array indices of the active atoms, in ascending order.
     *
     * @param atoms atoms to scan.
     * @return indices {@code i} for which {@code atoms[i].isActive()} is true.
     */
    private static int[] activeAtomIndices(Atom[] atoms) {
        ArrayList<Integer> active = new ArrayList<>();
        for (int i = 0; i < atoms.length; ++i) {
            if (atoms[i].isActive()) {
                active.add(i);
            }
        }
        return active.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * All-gathers the most recent distance from every process into {@link #distRow} and adds the
     * new values to the running statistics.
     *
     * @param row               current row.
     * @param column            last column of the just-completed group of {@code numProc} columns.
     * @param runningStatistics statistics to update (for symmetric matrices, only {@code c > row}).
     */
    private void gatherRMSDs(int row, int column, RunningStatistics runningStatistics) {
        try {
            logger.finer(" Receiving results.");
            this.world.allGather(this.myBuffer, this.buffers);
            for (int i = 0; i < this.numProc; ++i) {
                int c = column + 1 - this.numProc + i;
                // Skip padding columns beyond the real target size.
                if (c >= this.targetSize) {
                    continue;
                }
                this.distRow[c] = this.distances[i][0];
                if (!this.isSymmetric || c > row) {
                    runningStatistics.addValue(this.distRow[c]);
                }
                logger.finer(String.format(" %d %d %16.8f", row, c, this.distances[i][0]));
            }
        } catch (Exception e) {
            // Best effort: keep going, but preserve the cause and stack trace in the log.
            logger.log(Level.SEVERE, " Exception collecting distance values.", e);
        }
    }

    /**
     * Reads a previously written distance matrix to determine where to restart.
     *
     * @param filename        distance matrix file.
     * @param isSymmetric     whether the matrix is symmetric.
     * @param expectedRows    expected number of rows.
     * @param expectedColumns expected number of columns.
     * @return statistics of the values read, or whatever the reader returned (possibly null).
     */
    private RunningStatistics readMatrix(String filename, boolean isSymmetric, int expectedRows, int expectedColumns) {
        this.restartRow = 0;
        this.restartColumn = 0;
        DistanceMatrixFilter distanceMatrixFilter = new DistanceMatrixFilter();
        RunningStatistics runningStatistics = distanceMatrixFilter.readDistanceMatrix(filename, expectedRows, expectedColumns);
        if (runningStatistics != null && runningStatistics.getCount() > 0L) {
            this.restartRow = distanceMatrixFilter.getRestartRow();
            this.restartColumn = distanceMatrixFilter.getRestartColumn();
            if (isSymmetric) {
                if (this.restartRow == expectedRows && this.restartColumn == 1) {
                    logger.info(String.format(" Complete symmetric distance matrix found (%d x %d).", this.restartRow, this.restartRow));
                } else {
                    this.restartColumn = 0;
                    logger.info(String.format(" Incomplete symmetric distance matrix found.\n Restarting at row %d, column %d.", this.restartRow + 1, this.restartColumn + 1));
                }
            } else if (this.restartRow == expectedRows && this.restartColumn == expectedColumns) {
                logger.info(String.format(" Complete distance matrix found (%d x %d).", this.restartRow, this.restartColumn));
            } else {
                this.restartColumn = 0;
                logger.info(String.format(" Incomplete distance matrix found.\n Restarting at row %d, column %d.", this.restartRow + 1, this.restartColumn + 1));
            }
        }
        return runningStatistics;
    }

    /**
     * Copies the xyz triplets of the selected atoms into a packed array.
     *
     * @param usedIndices atom indices to copy.
     * @param x           all coordinates (x, y, z per atom).
     * @param xUsed       output; must hold at least {@code 3 * usedIndices.length} values.
     */
    public static void extractCoordinates(int[] usedIndices, double[] x, double[] xUsed) {
        int nUsed = usedIndices.length;
        for (int u = 0; u < nUsed; ++u) {
            System.arraycopy(x, 3 * usedIndices[u], xUsed, 3 * u, 3);
        }
    }

    /**
     * Computes the RMS difference between the intra-structure pair distances of two structures.
     *
     * <p>With fewer than two atoms there are no pairs, and the result is NaN (0/0).
     *
     * @param xUsed  packed coordinates of the first structure.
     * @param x2Used packed coordinates of the second structure.
     * @param nUsed  number of array entries to use (three per atom), despite the parameter name.
     * @return the distance RMSD.
     */
    public static double calcDRMSD(double[] xUsed, double[] x2Used, int nUsed) {
        double sumSquares = 0.0;
        int counter = 0;
        for (int i = 0; i < nUsed; i += 3) {
            double xi = xUsed[i];
            double yi = xUsed[i + 1];
            double zi = xUsed[i + 2];
            double x2i = x2Used[i];
            double y2i = x2Used[i + 1];
            double z2i = x2Used[i + 2];
            for (int j = i + 3; j < nUsed; j += 3) {
                // Primitive arithmetic avoids allocating temporary vectors in this O(n^2) loop.
                double dx1 = xi - xUsed[j];
                double dy1 = yi - xUsed[j + 1];
                double dz1 = zi - xUsed[j + 2];
                double dx2 = x2i - x2Used[j];
                double dy2 = y2i - x2Used[j + 1];
                double dz2 = z2i - x2Used[j + 2];
                double dis1 = FastMath.sqrt(dx1 * dx1 + dy1 * dy1 + dz1 * dz1);
                double dis2 = FastMath.sqrt(dx2 * dx2 + dy2 * dy2 + dz2 * dz2);
                double diff = dis1 - dis2;
                sumSquares += diff * diff;
                ++counter;
            }
        }
        return FastMath.sqrt(sumSquares / counter);
    }

    /**
     * Rotates packed coordinates in place: each xyz triplet v becomes {@code rot · v}
     * (rows of {@code rot} dotted with v).
     *
     * @param x2  coordinates to rotate in place.
     * @param rot 3x3 rotation matrix.
     */
    public static void applyRotation(double[] x2, double[][] rot) {
        int n = x2.length / 3;
        for (int i = 0; i < n; ++i) {
            int k = i * 3;
            double xrot = x2[k] * rot[0][0] + x2[k + 1] * rot[0][1] + x2[k + 2] * rot[0][2];
            double yrot = x2[k] * rot[1][0] + x2[k + 1] * rot[1][1] + x2[k + 2] * rot[1][2];
            double zrot = x2[k] * rot[2][0] + x2[k + 1] * rot[2][1] + x2[k + 2] * rot[2][2];
            x2[k] = xrot;
            x2[k + 1] = yrot;
            x2[k + 2] = zrot;
        }
    }

    /**
     * Adds a translation vector to every xyz triplet in place.
     *
     * @param x           coordinates to translate in place.
     * @param translation translation (x, y, z).
     */
    public static void applyTranslation(double[] x, double[] translation) {
        int n = x.length / 3;
        for (int i = 0; i < n; ++i) {
            int k = i * 3;
            for (int j = 0; j < 3; ++j) {
                x[k + j] += translation[j];
            }
        }
    }

    /**
     * Computes the rotation that superposes {@code x2} onto {@code x1} using a quaternion
     * formulation: a symmetric 4x4 matrix is built from mass-weighted cross terms, and the rotation
     * is derived from one of its eigenvectors.
     *
     * <p>Both coordinate sets are assumed to be already centered (see {@link #calculateTranslation}).
     *
     * @param x1   reference coordinates.
     * @param x2   coordinates to be rotated.
     * @param mass per-atom weights.
     * @return a 3x3 matrix suitable for {@link #applyRotation}.
     */
    public static double[][] calculateRotation(double[] x1, double[] x2, double[] mass) {
        double xxyx = 0.0;
        double xxyy = 0.0;
        double xxyz = 0.0;
        double xyyx = 0.0;
        double xyyy = 0.0;
        double xyyz = 0.0;
        double xzyx = 0.0;
        double xzyy = 0.0;
        double xzyz = 0.0;
        int n = x1.length / 3;
        for (int i = 0; i < n; ++i) {
            int k = i * 3;
            double weigh = mass[i];
            xxyx += weigh * x1[k] * x2[k];
            xxyy += weigh * x1[k + 1] * x2[k];
            xxyz += weigh * x1[k + 2] * x2[k];
            xyyx += weigh * x1[k] * x2[k + 1];
            xyyy += weigh * x1[k + 1] * x2[k + 1];
            xyyz += weigh * x1[k + 2] * x2[k + 1];
            xzyx += weigh * x1[k] * x2[k + 2];
            xzyy += weigh * x1[k + 1] * x2[k + 2];
            xzyz += weigh * x1[k + 2] * x2[k + 2];
        }
        double[][] c = new double[4][4];
        c[0][0] = xxyx + xyyy + xzyz;
        c[0][1] = xzyy - xyyz;
        c[1][0] = c[0][1];
        c[1][1] = xxyx - xyyy - xzyz;
        c[0][2] = xxyz - xzyx;
        c[2][0] = c[0][2];
        c[1][2] = xxyy + xyyx;
        c[2][1] = c[1][2];
        c[2][2] = xyyy - xzyz - xxyx;
        c[0][3] = xyyx - xxyy;
        c[3][0] = c[0][3];
        c[1][3] = xzyx + xxyz;
        c[3][1] = c[1][3];
        c[2][3] = xyyz + xzyy;
        c[3][2] = c[2][3];
        c[3][3] = xzyz - xxyx - xyyy;
        RealMatrix cMatrix = new Array2DRowRealMatrix(c, false);
        EigenDecomposition eigenDecomposition = new EigenDecomposition(cMatrix);
        // NOTE(review): relies on eigenvector 0 corresponding to the largest eigenvalue of the
        // symmetric matrix c, which is the optimal-rotation quaternion.
        double[] q = eigenDecomposition.getEigenvector(0).toArray();
        double q02 = q[0] * q[0];
        double q12 = q[1] * q[1];
        double q22 = q[2] * q[2];
        double q32 = q[3] * q[3];
        double[][] rot = new double[3][3];
        rot[0][0] = q02 + q12 - q22 - q32;
        rot[1][0] = 2.0 * (q[1] * q[2] - q[0] * q[3]);
        rot[2][0] = 2.0 * (q[1] * q[3] + q[0] * q[2]);
        rot[0][1] = 2.0 * (q[1] * q[2] + q[0] * q[3]);
        rot[1][1] = q02 - q12 + q22 - q32;
        rot[2][1] = 2.0 * (q[2] * q[3] - q[0] * q[1]);
        rot[0][2] = 2.0 * (q[1] * q[3] - q[0] * q[2]);
        rot[1][2] = 2.0 * (q[2] * q[3] + q[0] * q[1]);
        rot[2][2] = q02 - q12 - q22 + q32;
        return rot;
    }

    /**
     * Returns the translation that moves the mass-weighted center of {@code x} to the origin,
     * that is, the negated weighted centroid.
     *
     * @param x    packed coordinates.
     * @param mass per-atom weights.
     * @return translation (x, y, z).
     */
    public static double[] calculateTranslation(double[] x, double[] mass) {
        double xmid = 0.0;
        double ymid = 0.0;
        double zmid = 0.0;
        double norm = 0.0;
        int n = x.length / 3;
        for (int i = 0; i < n; ++i) {
            int k = 3 * i;
            double weigh = mass[i];
            xmid += x[k] * weigh;
            ymid += x[k + 1] * weigh;
            zmid += x[k + 2] * weigh;
            norm += weigh;
        }
        return new double[]{-xmid / norm, -ymid / norm, -zmid / norm};
    }

    /**
     * Computes the mass-weighted RMSD between two packed coordinate sets without superposing them.
     *
     * @param x1   first coordinates.
     * @param x2   second coordinates.
     * @param mass per-atom weights.
     * @return sqrt(sum(m * |d|^2) / sum(m)).
     */
    public static double rmsd(double[] x1, double[] x2, double[] mass) {
        double rmsfit = 0.0;
        double norm = 0.0;
        int n = x1.length / 3;
        for (int i = 0; i < n; ++i) {
            int k = 3 * i;
            double weigh = mass[i];
            double xr = x1[k] - x2[k];
            double yr = x1[k + 1] - x2[k + 1];
            double zr = x1[k + 2] - x2[k + 2];
            double dist2 = xr * xr + yr * yr + zr * zr;
            norm += weigh;
            rmsfit += dist2 * weigh;
        }
        return FastMath.sqrt(rmsfit / norm);
    }

    /**
     * Rotates {@code x2} in place onto {@code x1}.
     *
     * @param x1   reference coordinates.
     * @param x2   coordinates rotated in place.
     * @param mass per-atom weights.
     */
    public static void rotate(double[] x1, double[] x2, double[] mass) {
        double[][] rotation = Superpose.calculateRotation(x1, x2, mass);
        Superpose.applyRotation(x2, rotation);
    }

    /**
     * Centers two coordinate sets in place, each by its own weights.
     *
     * @param x1    first coordinates.
     * @param mass1 weights for the first set.
     * @param x2    second coordinates.
     * @param mass2 weights for the second set.
     */
    public static void translate(double[] x1, double[] mass1, double[] x2, double[] mass2) {
        Superpose.translate(x1, mass1);
        Superpose.translate(x2, mass2);
    }

    /**
     * Moves the mass-weighted center of {@code x} to the origin, in place.
     *
     * @param x    coordinates translated in place.
     * @param mass per-atom weights.
     */
    public static void translate(double[] x, double[] mass) {
        double[] translation = Superpose.calculateTranslation(x, mass);
        Superpose.applyTranslation(x, translation);
    }

    /**
     * Centers both coordinate sets, rotates {@code x2} onto {@code x1} (both modified in place),
     * and returns the resulting RMSD.
     *
     * @param x1   reference coordinates (centered in place).
     * @param x2   coordinates centered and rotated in place.
     * @param mass per-atom weights.
     * @return RMSD after superposition.
     */
    public static double superpose(double[] x1, double[] x2, double[] mass) {
        Superpose.translate(x1, mass);
        Superpose.translate(x2, mass);
        Superpose.rotate(x1, x2, mass);
        return Superpose.rmsd(x1, x2, mass);
    }
}

