/*
 * Decompiled with CFR 0.152.
 */
package ffx.potential.openmm;

import com.sun.jna.ptr.PointerByReference;
import edu.rit.pj.Comm;
import edu.uiowa.jopenmm.OpenMMLibrary;
import edu.uiowa.jopenmm.OpenMMUtils;
import edu.uiowa.jopenmm.OpenMM_Vec3;
import ffx.crystal.Crystal;
import ffx.numerics.Potential;
import ffx.openmm.Context;
import ffx.openmm.Integrator;
import ffx.openmm.MinimizationReporter;
import ffx.openmm.State;
import ffx.openmm.StringArray;
import ffx.openmm.System;
import ffx.potential.Platform;
import ffx.potential.bonded.Atom;
import ffx.potential.openmm.OpenMMEnergy;
import ffx.potential.openmm.OpenMMIntegrator;
import ffx.potential.openmm.OpenMMState;
import ffx.potential.openmm.OpenMMSystem;
import ffx.potential.parameters.ForceField;
import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;

public class OpenMMContext
extends Context {
    private static final Logger logger = Logger.getLogger(OpenMMContext.class.getName());
    private final OpenMMSystem openMMSystem;
    private String integratorName = "VERLET";
    private double timeStep = 0.001;
    private double temperature = 298.15;
    private final int enforcePBC;
    private final Atom[] atoms;

    public OpenMMContext(ffx.openmm.Platform platform, OpenMMSystem openMMSystem, Atom[] atoms) {
        super((System)openMMSystem, OpenMMIntegrator.createIntegrator("VERLET", 0.001, 298.15, openMMSystem), platform);
        this.openMMSystem = openMMSystem;
        this.atoms = atoms;
        ForceField forceField = openMMSystem.getForceField();
        boolean aperiodic = openMMSystem.getCrystal().aperiodic();
        boolean pbcEnforced = forceField.getBoolean("ENFORCE_PBC", !aperiodic);
        this.enforcePBC = pbcEnforced ? 1 : 0;
    }

    public void update(String integratorName, double timeStep, double temperature, boolean forceCreation) {
        if (this.hasContextPointer() && !forceCreation && this.temperature == temperature && this.timeStep == timeStep && this.integratorName.equalsIgnoreCase(integratorName)) {
            return;
        }
        this.integratorName = integratorName;
        this.timeStep = timeStep;
        this.temperature = temperature;
        logger.info("\n Updating OpenMM Context");
        Integrator newIntegrator = OpenMMIntegrator.createIntegrator(integratorName, timeStep, temperature, this.openMMSystem);
        ffx.openmm.Platform newPlatform = new ffx.openmm.Platform(this.platform.getName());
        this.updateContext(this.openMMSystem, newIntegrator, newPlatform);
        int nVar = this.atoms.length * 3;
        double[] x = new double[nVar];
        double[] v = new double[nVar];
        double[] vel3 = new double[3];
        int index = 0;
        for (Atom a : this.atoms) {
            a.getVelocity(vel3);
            x[index] = a.getX();
            v[index++] = vel3[0];
            x[index] = a.getY();
            v[index++] = vel3[1];
            x[index] = a.getZ();
            v[index++] = vel3[2];
        }
        Crystal crystal = this.openMMSystem.getCrystal();
        this.setPeriodicBoxVectors(crystal);
        this.setPositions(x);
        this.setVelocities(v);
        this.applyConstraints(1.0E-4);
        OpenMMState openMMState = this.getOpenMMState(3);
        Potential energy = this.openMMSystem.getPotential();
        energy.setCoordinates(openMMState.getActivePositions(null, this.atoms));
        energy.setVelocity(openMMState.getActiveVelocities(null, this.atoms));
        openMMState.destroy();
    }

    public void update() {
        if (!this.hasContextPointer()) {
            logger.info(" Delayed creation of OpenMM Context.");
            this.update(this.integratorName, this.timeStep, this.temperature, true);
        }
    }

    public OpenMMState getOpenMMState(int mask) {
        State state = this.getState(mask, this.enforcePBC);
        return new OpenMMState(state.getPointer());
    }

    public void integrate(int numSteps) {
        Integrator integrator = this.getIntegrator();
        integrator.step(numSteps);
    }

    public void optimize(double eps, int maxIterations) {
        MinimizationReporter reporter = new MinimizationReporter();
        OpenMMLibrary.OpenMM_LocalEnergyMinimizer_minimize((PointerByReference)this.getPointer(), (double)(eps / 0.02390057361376673), (int)maxIterations, (PointerByReference)reporter.getPointer());
        reporter.destroy();
    }

    public void setPositions(double[] x) {
        long time = -java.lang.System.nanoTime();
        int n = x.length;
        double[] xn = new double[n];
        for (int i = 0; i < n; ++i) {
            xn[i] = x[i] * 0.1;
        }
        super.setPositions(xn);
        time += java.lang.System.nanoTime();
        if (logger.isLoggable(Level.FINEST)) {
            logger.finest(String.format(" Set OpenMM positions  %9.6f (msec)", (double)time * 1.0E-6));
        }
    }

    public void setVelocities(double[] v) {
        long time = -java.lang.System.nanoTime();
        int n = v.length;
        double[] vn = new double[n];
        for (int i = 0; i < n; ++i) {
            vn[i] = v[i] * 0.1;
        }
        super.setVelocities(vn);
        time += java.lang.System.nanoTime();
        if (logger.isLoggable(Level.FINEST)) {
            logger.finest(String.format(" Set OpenMM velocities %9.6f (msec)", (double)time * 1.0E-6));
        }
    }

    public void setPeriodicBoxVectors(Crystal crystal) {
        if (!crystal.aperiodic()) {
            OpenMM_Vec3 a = new OpenMM_Vec3();
            OpenMM_Vec3 b = new OpenMM_Vec3();
            OpenMM_Vec3 c = new OpenMM_Vec3();
            double[][] Ai = crystal.Ai;
            a.x = Ai[0][0] * 0.1;
            a.y = Ai[0][1] * 0.1;
            a.z = Ai[0][2] * 0.1;
            b.x = Ai[1][0] * 0.1;
            b.y = Ai[1][1] * 0.1;
            b.z = Ai[1][2] * 0.1;
            c.x = Ai[2][0] * 0.1;
            c.y = Ai[2][1] * 0.1;
            c.z = Ai[2][2] * 0.1;
            this.setPeriodicBoxVectors(a, b, c);
        }
    }

    public String toString() {
        return String.format(" OpenMM context with integrator %s, timestep %9.3g fsec, temperature %9.3g K", this.integratorName, this.timeStep, this.temperature);
    }

    public static ffx.openmm.Platform loadPlatform(Platform requestedPlatform, ForceField forceField) {
        ffx.openmm.Platform openMMPlatform;
        OpenMMUtils.init();
        logger.log(Level.INFO, " Loaded from:\n {0}", OpenMMLibrary.JNA_NATIVE_LIB.toString());
        logger.log(Level.INFO, " Version: {0}", ffx.openmm.Platform.getOpenMMVersion());
        String libDirectory = OpenMMUtils.getLibDirectory();
        logger.log(Level.FINE, " Lib Directory:       {0}", libDirectory);
        StringArray libs = ffx.openmm.Platform.loadPluginsFromDirectory((String)libDirectory);
        int numLibs = libs.getSize();
        logger.log(Level.FINE, " Number of libraries: {0}", numLibs);
        for (int i = 0; i < numLibs; ++i) {
            logger.log(Level.FINE, "  Library: {0}", libs.get(i));
        }
        libs.destroy();
        String pluginDirectory = OpenMMUtils.getPluginDirectory();
        logger.log(Level.INFO, "\n Plugin Directory:  {0}", pluginDirectory);
        StringArray plugins = ffx.openmm.Platform.loadPluginsFromDirectory((String)pluginDirectory);
        int numPlugins = plugins.getSize();
        logger.log(Level.INFO, " Number of Plugins: {0}", numPlugins);
        boolean cuda = false;
        boolean opencl = false;
        for (int i = 0; i < numPlugins; ++i) {
            boolean amoebaOpenCLAvailable;
            String pluginString = plugins.get(i);
            logger.log(Level.INFO, "  Plugin: {0}", pluginString);
            if (pluginString == null) continue;
            boolean amoebaCudaAvailable = (pluginString = pluginString.toUpperCase()).contains("AMOEBACUDA");
            if (amoebaCudaAvailable) {
                cuda = true;
            }
            if (!(amoebaOpenCLAvailable = pluginString.contains("AMOEBAOPENCL"))) continue;
            opencl = true;
        }
        plugins.destroy();
        int numPlatforms = ffx.openmm.Platform.getNumPlatforms();
        logger.log(Level.INFO, " Number of Platforms: {0}", numPlatforms);
        if (requestedPlatform == Platform.OMM_CUDA && !cuda) {
            logger.severe(" The OMM_CUDA platform was requested, but is not available.");
        }
        if (requestedPlatform == Platform.OMM_OPENCL && !opencl) {
            logger.severe(" The OMM_OPENCL platform was requested, but is not available.");
        }
        if (logger.isLoggable(Level.FINE)) {
            StringArray pluginFailures = ffx.openmm.Platform.getPluginLoadFailures();
            int numFailures = pluginFailures.getSize();
            for (int i = 0; i < numFailures; ++i) {
                logger.log(Level.FINE, " Plugin load failure: {0}", pluginFailures.get(i));
            }
            pluginFailures.destroy();
        }
        String defaultPrecision = "mixed";
        String precision = forceField.getString("PRECISION", defaultPrecision).toLowerCase();
        switch (precision = precision.replace("-precision", "")) {
            case "double": 
            case "mixed": 
            case "single": {
                logger.info(String.format(" Precision level: %s", precision));
                break;
            }
            default: {
                logger.info(String.format(" Could not interpret precision level %s, defaulting to %s", precision, defaultPrecision));
                precision = defaultPrecision;
            }
        }
        if (cuda && (requestedPlatform == Platform.OMM_CUDA || requestedPlatform == Platform.OMM)) {
            defaultDevice = OpenMMEnergy.getDefaultDevice(forceField.getProperties());
            openMMPlatform = new ffx.openmm.Platform("CUDA");
            int deviceID = forceField.getInteger("CUDA_DEVICE", defaultDevice);
            deviceID = forceField.getInteger("DeviceIndex", deviceID);
            String deviceIDString = Integer.toString(deviceID);
            openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
            openMMPlatform.setPropertyDefaultValue("Precision", precision);
            String name = openMMPlatform.getName();
            logger.info(String.format(" Platform: %s (Device Index %d)", name, deviceID));
        } else if (opencl && (requestedPlatform == Platform.OMM_OPENCL || requestedPlatform == Platform.OMM)) {
            defaultDevice = OpenMMEnergy.getDefaultDevice(forceField.getProperties());
            openMMPlatform = new ffx.openmm.Platform("OpenCL");
            int deviceID = forceField.getInteger("DeviceIndex", defaultDevice);
            String deviceIDString = Integer.toString(deviceID);
            openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
            int openCLPlatformIndex = forceField.getInteger("OpenCLPlatformIndex", 0);
            String openCLPlatformIndexString = Integer.toString(openCLPlatformIndex);
            openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
            openMMPlatform.setPropertyDefaultValue("OpenCLPlatformIndex", openCLPlatformIndexString);
            openMMPlatform.setPropertyDefaultValue("Precision", precision);
            String name = openMMPlatform.getName();
            logger.info(String.format(" Platform: %s (Platform Index %d, Device Index %d)", name, openCLPlatformIndex, deviceID));
        } else {
            openMMPlatform = new ffx.openmm.Platform("Reference");
            String name = openMMPlatform.getName();
            logger.info(String.format(" Platform: %s", name));
        }
        try {
            Comm world = Comm.world();
            if (world != null) {
                logger.info(String.format(" Running on host %s, rank %d", world.host(), world.rank()));
            }
        }
        catch (IllegalStateException illegalStateException) {
            logger.fine(" Could not find the world communicator!");
        }
        return openMMPlatform;
    }

    public void free() {
        this.destroy();
    }
}

