/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.openmm;

import edu.rit.mp.Buf;
import edu.rit.mp.CharacterBuf;
import edu.rit.pj.Comm;
import ffx.crystal.Crystal;
import ffx.numerics.Potential;
import ffx.potential.FiniteDifferenceUtils;
import ffx.potential.ForceFieldEnergy;
import ffx.potential.MolecularAssembly;
import ffx.potential.Platform;
import ffx.potential.Utilities;
import ffx.potential.bonded.Atom;
import ffx.potential.openmm.OpenMMContext;
import ffx.potential.openmm.OpenMMPotential;
import ffx.potential.openmm.OpenMMState;
import ffx.potential.openmm.OpenMMSystem;
import ffx.potential.parameters.ForceField;
import ffx.potential.utils.EnergyException;
import ffx.potential.utils.PotentialsUtils;
import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.commons.configuration2.CompositeConfiguration;
import org.apache.commons.io.FilenameUtils;

/**
 * A {@link ForceFieldEnergy} whose energy and gradient are evaluated by OpenMM.
 *
 * <p>The OpenMM-backed evaluations are exposed through the standard {@code energy} /
 * {@code energyAndGradient} overrides, while the pure FFX implementations remain reachable through
 * {@link #energyFFX(double[])} and {@link #energyAndGradientFFX(double[], double[])}.
 */
public class OpenMMEnergy extends ForceFieldEnergy implements OpenMMPotential {
    private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());

    /** OpenMM State::DataType bit requesting forces. */
    private static final int STATE_FORCES = 4;

    /** OpenMM State::DataType bit requesting the potential energy. */
    private static final int STATE_ENERGY = 8;

    /** Number of characters used to exchange host names between MPI ranks. */
    private static final int HOST_MESSAGE_LENGTH = 100;

    /** Timestamp format for large-gradient debug snapshots (DateTimeFormatter is thread-safe). */
    private static final DateTimeFormatter SNAPSHOT_TIME_FORMAT =
            DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss");

    private final Platform platform;
    private OpenMMContext openMMContext;
    private OpenMMSystem openMMSystem;
    private final Atom[] atoms;
    private final boolean computeDEDL;

    /**
     * Build the FFX force field energy, then create the OpenMM system and context for it.
     *
     * <p>If the crystal has more than one symmetry operator, a SEVERE message is logged. This
     * constructor itself does not throw in that case.
     *
     * @param molecularAssembly the system to evaluate.
     * @param requestedPlatform the OpenMM platform to load.
     * @param nThreads number of threads for the FFX parent implementation.
     */
    public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
        super(molecularAssembly, nThreads);
        Crystal crystal = getCrystal();
        int symOps = crystal.spaceGroup.getNumberOfSymOps();
        if (symOps > 1) {
            logger.severe(" OpenMM does not support symmetry operators.");
        }
        logger.info("\n Initializing OpenMM");
        ForceField forceField = molecularAssembly.getForceField();
        this.atoms = molecularAssembly.getAtomArray();
        this.platform = requestedPlatform;
        ffx.openmm.Platform openMMPlatform = OpenMMContext.loadPlatform(this.platform, forceField);
        this.openMMSystem = new OpenMMSystem(this);
        this.openMMSystem.addForces();
        this.openMMContext = new OpenMMContext(openMMPlatform, this.openMMSystem, this.atoms);
        this.computeDEDL = forceField.getBoolean("OMM_DUDL", false);
    }

    /**
     * Choose the device this process should use.
     *
     * <p>Devices are read from the "availableDevices" property, falling back to "CUDA_DEVICES".
     * If neither is set, devices {@code 0 .. numCudaDevices-1} are used, with "numCudaDevices"
     * defaulting to 1.
     *
     * <p>When several devices are listed and an MPI communicator is available, ranks sharing a host
     * name are spread round-robin over the devices.
     *
     * @param props configuration properties.
     * @return the selected device id. If no devices are specified, 0 is returned.
     * @throws NumberFormatException if a listed device is not an integer.
     */
    public static int getDefaultDevice(CompositeConfiguration props) {
        String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
        if (availDeviceProp == null) {
            int nDevs = props.getInt("numCudaDevices", 1);
            availDeviceProp = IntStream.range(0, nDevs)
                    .mapToObj(Integer::toString)
                    .collect(Collectors.joining(" "));
        }
        availDeviceProp = availDeviceProp.trim();
        if (availDeviceProp.isEmpty()) {
            // "".split("\\s+") yields [""], which Integer.parseInt rejects; fall back to device 0.
            logger.warning(" No available devices were specified; defaulting to device 0.");
            return 0;
        }

        int[] devs = Arrays.stream(availDeviceProp.split("\\s+")).mapToInt(Integer::parseInt).toArray();
        int nDevs = devs.length;
        logger.info(String.format(" Available devices: %d.", nDevs));
        if (nDevs == 1) {
            // Only one choice, so skip the collective host-name exchange entirely.
            return devs[0];
        }

        int index = 0;
        try {
            Comm world = Comm.world();
            if (world != null) {
                int hostIndex = hostLocalIndex(world, nDevs);
                if (hostIndex >= 0) {
                    index = hostIndex % nDevs;
                }
            }
        } catch (IllegalStateException ignored) {
            // Best effort: without a usable communicator, keep the first listed device.
            // NOTE(review): presumably thrown when Comm has not been initialized; confirm.
        }
        return devs[index];
    }

    /**
     * Find this rank's position among all ranks that report the same host name.
     *
     * <p>This performs an allGather on {@code world}, so every rank must call it.
     *
     * @param world the communicator.
     * @param nDevs number of available devices, used only for logging.
     * @return zero-based index of this rank among ranks on the same host, or -1 if it is not found.
     */
    private static int hostLocalIndex(Comm world, int nDevs) {
        int size = world.size();
        if (size > nDevs) {
            logger.fine(String.format(" Number of MPI processes %d exceeds number of available devices %d.", size, nDevs));
        }

        // Truncate or pad the host name to a fixed width so every rank sends the same length.
        String host = world.host();
        host = host.substring(0, Math.min(HOST_MESSAGE_LENGTH, host.length()));
        host = String.format("%-" + HOST_MESSAGE_LENGTH + "s", host);
        logger.fine(String.format(" Host: %s", host.trim()));

        CharacterBuf out = CharacterBuf.buffer(host.toCharArray());
        char[][] incoming = new char[size][HOST_MESSAGE_LENGTH];
        CharacterBuf[] in = new CharacterBuf[size];
        for (int i = 0; i < size; ++i) {
            in[i] = CharacterBuf.buffer(incoming[i]);
        }
        try {
            logger.fine(" AllGather for determining rank.");
            world.allGather(out, in);
            logger.fine(" AllGather complete.");
        } catch (IOException ex) {
            // On failure, incoming stays zero-filled, so no host matches and the caller keeps index 0.
            logger.warning(String.format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
        }

        int rank = world.rank();
        int ownIndex = -1;
        for (int i = 0; i < size; ++i) {
            if (!new String(incoming[i]).equalsIgnoreCase(host)) {
                continue;
            }
            ++ownIndex;
            if (i == rank) {
                return ownIndex;
            }
        }
        logger.warning(String.format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
        return -1;
    }

    /** {@inheritDoc} Delegates to {@link OpenMMContext#update}. */
    @Override
    public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
        openMMContext.update(integratorName, timeStep, temperature, forceCreation);
    }

    /** {@inheritDoc} Delegates to the current OpenMM context. */
    @Override
    public OpenMMState getOpenMMState(int mask) {
        return openMMContext.getOpenMMState(mask);
    }

    /**
     * Destroy the FFX parent, then free the OpenMM context and system.
     *
     * @return the result of the parent's destroy.
     */
    @Override
    public boolean destroy() {
        boolean ffxFFEDestroy = super.destroy();
        free();
        logger.fine(" Destroyed the Context and OpenMMSystem.");
        return ffxFFEDestroy;
    }

    /** {@inheritDoc} Non-verbose OpenMM energy evaluation. */
    @Override
    public double energy(double[] x) {
        return energy(x, false);
    }

    /**
     * Compute the potential energy with OpenMM.
     *
     * <p>Coordinates in {@code x} are temporarily unscaled. They are rescaled before returning,
     * including when an {@link EnergyException} is thrown.
     *
     * @param x (possibly scaled) coordinates.
     * @param verbose if true, log the energy.
     * @return the potential energy, or 0.0 when lambdaBondedTerms is set.
     * @throws EnergyException if OpenMM returns a non-finite energy.
     */
    @Override
    public double energy(double[] x, boolean verbose) {
        if (lambdaBondedTerms) {
            return 0.0;
        }
        openMMContext.update();
        updateParameters(atoms);
        unscaleCoordinates(x);
        setCoordinates(x);

        OpenMMState openMMState = openMMContext.getOpenMMState(STATE_ENERGY);
        double e = openMMState.potentialEnergy;
        openMMState.destroy();

        requireFiniteEnergy(e, x);
        if (verbose) {
            logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
        }
        scaleCoordinates(x);
        return e;
    }

    /**
     * Compute the energy with the FFX (non-OpenMM) implementation.
     *
     * @param x coordinates.
     * @return the FFX potential energy.
     */
    public double energyFFX(double[] x) {
        return super.energy(x, false);
    }

    /**
     * Compute the energy with the FFX (non-OpenMM) implementation.
     *
     * @param x coordinates.
     * @param verbose if true, log energy terms.
     * @return the FFX potential energy.
     */
    public double energyFFX(double[] x, boolean verbose) {
        return super.energy(x, verbose);
    }

    /** {@inheritDoc} Non-verbose OpenMM energy and gradient evaluation. */
    @Override
    public double energyAndGradient(double[] x, double[] g) {
        return energyAndGradient(x, g, false);
    }

    /**
     * Compute the potential energy and gradient with OpenMM.
     *
     * <p>Coordinates are temporarily unscaled. On success, coordinates and gradient are rescaled.
     * If an {@link EnergyException} is thrown, only {@code x} is restored.
     *
     * @param x (possibly scaled) coordinates.
     * @param g array that receives the gradient.
     * @param verbose if true, log the energy.
     * @return the potential energy, or 0.0 when lambdaBondedTerms is set.
     * @throws EnergyException if OpenMM returns a non-finite energy.
     */
    @Override
    public double energyAndGradient(double[] x, double[] g, boolean verbose) {
        if (lambdaBondedTerms) {
            return 0.0;
        }
        unscaleCoordinates(x);
        openMMContext.update();
        setCoordinates(x);

        OpenMMState openMMState = openMMContext.getOpenMMState(STATE_ENERGY | STATE_FORCES);
        double e;
        try {
            e = openMMState.potentialEnergy;
            g = openMMState.getGradient(g);
        } finally {
            // Always release the native state, even if the gradient copy fails.
            openMMState.destroy();
        }

        requireFiniteEnergy(e, x);
        snapshotIfGradientExtreme(g);
        if (verbose) {
            logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
        }
        scaleCoordinatesAndGradient(x, g);
        return e;
    }

    /**
     * Compute the energy and gradient with the FFX (non-OpenMM) implementation.
     *
     * @param x coordinates.
     * @param g array that receives the gradient.
     * @return the FFX potential energy.
     */
    public double energyAndGradientFFX(double[] x, double[] g) {
        return super.energyAndGradient(x, g, false);
    }

    /**
     * Compute the energy and gradient with the FFX (non-OpenMM) implementation.
     *
     * @param x coordinates.
     * @param g array that receives the gradient.
     * @param verbose if true, log energy terms.
     * @return the FFX potential energy.
     */
    public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
        return super.energyAndGradient(x, g, verbose);
    }

    /**
     * Throw if the energy is not finite, first restoring the scaled coordinates.
     *
     * <p>Callers invoke this after unscaleCoordinates(x). Rescaling here means a caller that catches
     * the exception does not receive unscaled coordinates.
     */
    private void requireFiniteEnergy(double e, double[] x) {
        if (!Double.isFinite(e)) {
            scaleCoordinates(x);
            String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
            logger.warning(message);
            throw new EnergyException(message);
        }
    }

    /** Write a PDB snapshot when any gradient component exceeds maxDebugGradient in magnitude. */
    private void snapshotIfGradientExtreme(double[] g) {
        double maxGrad = maxDebugGradient;
        if (maxGrad >= Double.MAX_VALUE) {
            return;
        }
        boolean extremeGrad = Arrays.stream(g).anyMatch(gi -> gi > maxGrad || gi < -maxGrad);
        if (!extremeGrad) {
            return;
        }
        File origFile = molecularAssembly.getFile();
        String timeString = LocalDateTime.now().format(SNAPSHOT_TIME_FORMAT);
        String filename = String.format("%s-LARGEGRAD-%s.pdb", FilenameUtils.removeExtension(origFile.getName()), timeString);
        PotentialsUtils ef = new PotentialsUtils();
        filename = ef.versionFile(filename);
        logger.warning(String.format(" Excessively large gradients detected; printing snapshot to file %s", filename));
        ef.saveAsPDB(molecularAssembly, new File(filename));
        // Restore the assembly's file after saving the snapshot.
        molecularAssembly.setFile(origFile);
    }

    /** {@inheritDoc} */
    @Override
    public OpenMMContext getContext() {
        return openMMContext;
    }

    /** {@inheritDoc} */
    @Override
    public MolecularAssembly getMolecularAssembly() {
        return molecularAssembly;
    }

    /**
     * Fetch the gradient of the current OpenMM state.
     *
     * @param g array that receives the gradient.
     * @return the gradient array returned by OpenMMState#getGradient.
     */
    @Override
    public double[] getGradient(double[] g) {
        OpenMMState openMMState = openMMContext.getOpenMMState(STATE_FORCES);
        try {
            return openMMState.getGradient(g);
        } finally {
            openMMState.destroy();
        }
    }

    /** {@inheritDoc} */
    @Override
    public Platform getPlatform() {
        return platform;
    }

    /** {@inheritDoc} */
    @Override
    public OpenMMSystem getSystem() {
        return openMMSystem;
    }

    /** Second lambda derivative is not computed by this implementation; always 0.0. */
    @Override
    public double getd2EdL2() {
        return 0.0;
    }

    /**
     * Return dE/dL.
     *
     * <p>The value is computed by finite differences, and only when lambdaTerm is set and the
     * "OMM_DUDL" property is true. Otherwise 0.0 is returned.
     */
    @Override
    public double getdEdL() {
        if (!lambdaTerm || !computeDEDL) {
            return 0.0;
        }
        return FiniteDifferenceUtils.computedEdL((Potential) this, this, molecularAssembly.getForceField());
    }

    /** Not implemented for OpenMM: {@code gradients} is left unchanged. */
    @Override
    public void getdEdXdL(double[] gradients) {
        // Intentionally a no-op.
    }

    /** Delegates to OpenMMSystem#updateAtomMass. */
    @Override
    public boolean setActiveAtoms() {
        return openMMSystem.updateAtomMass();
    }

    /**
     * Set coordinates in the FFX parent, then copy every atom's position into the OpenMM context.
     *
     * @param x coordinates passed to the parent implementation.
     */
    @Override
    public void setCoordinates(double[] x) {
        super.setCoordinates(x);
        double[] xall = new double[atoms.length * 3];
        int i = 0;
        for (Atom atom : atoms) {
            xall[i++] = atom.getX();
            xall[i++] = atom.getY();
            xall[i++] = atom.getZ();
        }
        openMMContext.setPositions(xall);
    }

    /**
     * Set velocities in the FFX parent, then copy every atom's velocity into the OpenMM context.
     *
     * @param v velocities passed to the parent implementation.
     */
    @Override
    public void setVelocity(double[] v) {
        super.setVelocity(v);
        double[] vall = new double[atoms.length * 3];
        double[] v3 = new double[3];
        int i = 0;
        for (Atom atom : atoms) {
            atom.getVelocity(v3);
            vall[i++] = v3[0];
            vall[i++] = v3[1];
            vall[i++] = v3[2];
        }
        openMMContext.setVelocities(vall);
    }

    /** Set the crystal in the FFX parent and update the OpenMM periodic box vectors. */
    @Override
    public void setCrystal(Crystal crystal) {
        super.setCrystal(crystal);
        openMMContext.setPeriodicBoxVectors(crystal);
    }

    /**
     * Set lambda and push updated parameters for the atoms that apply lambda.
     *
     * <p>This is ignored, with a fine-level log message, when lambdaTerm is false.
     */
    @Override
    public void setLambda(double lambda) {
        if (!lambdaTerm) {
            logger.fine(" Attempting to set lambda for an OpenMMEnergy with lambdaterm false.");
            return;
        }
        super.setLambda(lambda);
        // atoms may still be null if the superclass constructor calls this override.
        if (atoms != null) {
            updateParameters(Arrays.stream(atoms).filter(Atom::applyLambda).toArray(Atom[]::new));
        } else {
            updateParameters(null);
        }
    }

    /**
     * Push force-field parameters for the given atoms to the OpenMM system.
     *
     * @param atoms atoms to update; if null, all atoms of this energy are used.
     */
    @Override
    public void updateParameters(@Nullable Atom[] atoms) {
        if (atoms == null) {
            atoms = this.atoms;
        }
        // openMMSystem is null before construction completes and after free().
        if (openMMSystem != null) {
            openMMSystem.updateParameters(atoms);
        }
    }

    /** Free the OpenMM context and system, if present, and clear the references. */
    private void free() {
        if (openMMContext != null) {
            openMMContext.free();
            openMMContext = null;
        }
        if (openMMSystem != null) {
            openMMSystem.free();
            openMMSystem = null;
        }
    }
}

