/*
 * Decompiled with CFR 0.152.
 */
package ffx.algorithms.dynamics;

import ffx.algorithms.AlgorithmListener;
import ffx.algorithms.dynamics.Barostat;
import ffx.algorithms.dynamics.MolecularDynamics;
import ffx.algorithms.dynamics.integrators.IntegratorEnum;
import ffx.algorithms.dynamics.thermostats.ThermostatEnum;
import ffx.crystal.Crystal;
import ffx.crystal.CrystalPotential;
import ffx.numerics.Potential;
import ffx.potential.MolecularAssembly;
import ffx.potential.UnmodifiableState;
import ffx.potential.bonded.Atom;
import ffx.potential.bonded.LambdaInterface;
import ffx.potential.openmm.OpenMMContext;
import ffx.potential.openmm.OpenMMPotential;
import ffx.potential.openmm.OpenMMState;
import ffx.potential.openmm.OpenMMSystem;
import java.io.File;
import java.util.logging.Logger;
import javax.annotation.Nullable;

/**
 * Molecular dynamics driver that delegates integration to OpenMM.
 *
 * <p>The supplied potential must be both a {@link CrystalPotential} and an {@link OpenMMPotential}.
 * Steps are taken inside the OpenMM context in blocks of {@link #getIntervalSteps()} steps; after
 * each block, energies and (depending on {@link #setObtainVelAcc(boolean)}) coordinates, velocities
 * and accelerations are copied back into the FFX dynamics state for logging and file output.
 */
public class MolecularDynamicsOpenMM extends MolecularDynamics {

    private static final Logger logger = Logger.getLogger(MolecularDynamicsOpenMM.class.getName());

    /** Nanoseconds to seconds, used for timing log messages. */
    private static final double NS2SEC = 1.0e-9;

    // Bit flags selecting which data an OpenMM State carries (values of OpenMM's State::DataType).
    private static final int STATE_POSITIONS = 1;
    private static final int STATE_VELOCITIES = 2;
    private static final int STATE_FORCES = 4;
    private static final int STATE_ENERGY = 8;

    /** State mask for energies only (8). */
    private static final int ENERGY_MASK = STATE_ENERGY;
    /** State mask for energies and positions (9). */
    private static final int ENERGY_AND_POSITIONS_MASK = STATE_ENERGY | STATE_POSITIONS;
    /**
     * State mask for energies, positions, velocities and forces (15). Forces are presumably what
     * backs {@code getActiveAccelerations} — TODO confirm in OpenMMState.
     */
    private static final int ALL_VARIABLES_MASK =
            STATE_ENERGY | STATE_POSITIONS | STATE_VELOCITIES | STATE_FORCES;

    /** Integrator requested at construction; when null, Verlet is used. */
    private final IntegratorEnum integratorType;
    /** Thermostat requested at construction. */
    private final ThermostatEnum thermostatType;
    /** The potential viewed as a CrystalPotential (same object as {@link #openMMPotential}). */
    private final CrystalPotential crystalPotential;
    /** The potential viewed as an OpenMMPotential (same object as {@link #crystalPotential}). */
    private final OpenMMPotential openMMPotential;
    /** Integrator name passed to OpenMM: "VERLET", "LANGEVIN", "MTS" or "LANGEVIN-MTS". */
    private String integratorString;
    /** Steps handed to OpenMM per call; non-positive means "the whole run" (see dynamic). */
    private int intervalSteps;
    /** Set once initial energies are logged; enables per-block logging and file output. */
    private boolean running;
    /** {@code System.nanoTime()} timestamp passed to and updated by {@code logThermoForTime}. */
    private long time;
    /** Whether each update retrieves velocities and accelerations in addition to positions. */
    private boolean getAllVars = true;
    /** Per-block state retrieval strategy, selected by {@link #setObtainVelAcc(boolean)}. */
    private Runnable obtainVariables = this::getAllOpenMMVariables;

    /**
     * Creates an OpenMM dynamics driver.
     *
     * @param assembly   the molecular assembly to simulate.
     * @param potential  the potential; must implement both CrystalPotential and OpenMMPotential,
     *                   otherwise the casts below throw ClassCastException.
     * @param listener   algorithm listener forwarded to the superclass.
     * @param thermostat requested thermostat.
     * @param integrator requested integrator; null selects Verlet.
     */
    public MolecularDynamicsOpenMM(MolecularAssembly assembly, Potential potential,
            AlgorithmListener listener, ThermostatEnum thermostat, IntegratorEnum integrator) {
        super(assembly, potential, listener, thermostat, integrator);
        logger.info("\n Initializing OpenMM molecular dynamics.");
        running = false;
        crystalPotential = (CrystalPotential) potential;
        openMMPotential = (OpenMMPotential) potential;
        thermostatType = thermostat;
        integratorType = integrator;
        integratorToString(integratorType);
    }

    /**
     * Sets (or clears) the barostat used for constant-pressure simulation.
     *
     * <p>Pressure coupling is performed by an OpenMM Monte Carlo barostat force added in
     * {@link #init}, which reads its pressure and interval from this barostat; the Java-side
     * barostat itself is therefore deactivated.
     *
     * @param barostat the barostat, or null to disable constant pressure.
     */
    public void setBarostat(@Nullable Barostat barostat) {
        this.barostat = barostat;
        constantPressure = barostat != null;
        if (barostat != null) {
            barostat.setActive(false);
        }
    }

    /**
     * Runs {@code numSteps} steps of OpenMM dynamics.
     *
     * <p>Returns immediately with a warning if {@code done} is false, i.e. a run is already in
     * progress.
     */
    @Override
    public void dynamic(long numSteps, double timeStep, double printInterval, double saveInterval,
            double temperature, boolean initVelocities, File dyn) {
        if (!done) {
            logger.warning(" Programming error - a thread invoked dynamic when it was already running.");
            return;
        }
        init(numSteps, timeStep, printInterval, saveInterval, fileType, restartInterval, temperature,
                initVelocities, dyn);

        // Unset, invalid (negative would otherwise loop forever) or too-large intervals collapse
        // to a single block covering the whole run; clamp so the int cast cannot overflow.
        if (intervalSteps <= 0 || intervalSteps > numSteps) {
            intervalSteps = (int) Math.min(numSteps, Integer.MAX_VALUE);
        }

        preRunOps();
        setOpenMMState();
        getOpenMMEnergies();
        initialState = new UnmodifiableState(state);

        boolean forceCreation = openMMPotential.setActiveAtoms();
        openMMPotential.updateContext(integratorString, dt, targetTemperature, forceCreation);

        postInitEnergies();
        mainLoop(numSteps);
        postRun();
    }

    /** @return the number of steps handed to OpenMM per call. */
    @Override
    public int getIntervalSteps() {
        return intervalSteps;
    }

    /**
     * @param intervalSteps steps handed to OpenMM per call; non-positive values are replaced by the
     *                      full run length when {@link #dynamic} starts.
     */
    @Override
    public void setIntervalSteps(int intervalSteps) {
        this.intervalSteps = intervalSteps;
    }

    /** @return the time step {@code dt}. */
    @Override
    public double getTimeStep() {
        return dt;
    }

    /**
     * Initializes the superclass, then adds OpenMM forces for thermostatting, pressure control and
     * center-of-mass motion removal as appropriate.
     */
    @Override
    public void init(long numSteps, double timeStep, double loggingInterval, double trajectoryInterval,
            String fileType, double restartInterval, double temperature, boolean initVelocities,
            File dyn) {
        super.init(numSteps, timeStep, loggingInterval, trajectoryInterval, fileType, restartInterval,
                temperature, initVelocities, dyn);

        boolean isLangevin = IntegratorEnum.isStochastic(integratorType);
        OpenMMSystem openMMSystem = openMMPotential.getSystem();

        // Stochastic integrators thermostat themselves; otherwise add an Andersen thermostat unless
        // adiabatic dynamics was requested.
        if (!isLangevin && thermostatType != ThermostatEnum.ADIABATIC) {
            openMMSystem.addAndersenThermostatForce(targetTemperature);
        }

        if (constantPressure) {
            double pressure = barostat.getPressure();
            int frequency = barostat.getMeanBarostatInterval();
            openMMSystem.addMonteCarloBarostatForce(pressure, targetTemperature, frequency);
        }

        if (!isLangevin) {
            openMMSystem.addCOMMRemoverForce();
        }

        // Re-applying the current lambda presumably pushes lambda-dependent parameters into the
        // (re)configured OpenMM system — TODO confirm against LambdaInterface implementations.
        if (crystalPotential instanceof LambdaInterface lambdaInterface) {
            lambdaInterface.setLambda(lambdaInterface.getLambda());
        }
    }

    /** Reverts the FFX state via the superclass, then pushes it back into OpenMM. */
    @Override
    public void revertState() throws Exception {
        super.revertState();
        setOpenMMState();
    }

    /** @param fileType the output file type. */
    @Override
    public void setFileType(String fileType) {
        this.fileType = fileType;
    }

    /**
     * Chooses how much state is fetched from OpenMM after each block of steps.
     *
     * @param obtainVA true to fetch positions, velocities and accelerations; false to fetch only
     *                 energies and positions.
     */
    @Override
    public void setObtainVelAcc(boolean obtainVA) {
        getAllVars = obtainVA;
        obtainVariables = obtainVA ? this::getAllOpenMMVariables : this::getOpenMMEnergiesAndPositions;
    }

    /** Writes a restart file, first fetching velocities/accelerations if they are not kept current. */
    @Override
    public void writeRestart() {
        if (!getAllVars) {
            getAllOpenMMVariables();
        }
        super.writeRestart();
    }

    /** Appends a trajectory snapshot, refreshing energies and positions when only partial state is kept. */
    @Override
    protected void appendSnapshot(String[] extraLines) {
        if (!getAllVars) {
            getOpenMMEnergiesAndPositions();
        }
        super.appendSnapshot(extraLines);
    }

    /** Advances the OpenMM context by {@code steps} integration steps. */
    private void takeOpenMMSteps(int steps) {
        OpenMMContext openMMContext = openMMPotential.getContext();
        openMMContext.integrate(steps);
    }

    /**
     * Pushes the FFX box vectors, coordinates and velocities toward OpenMM. Box vectors go directly
     * to the context; coordinates and velocities go through the potential's setters (presumably
     * forwarded to the context — TODO confirm).
     */
    private void setOpenMMState() {
        OpenMMContext openMMContext = openMMPotential.getContext();
        openMMContext.setPeriodicBoxVectors(crystalPotential.getCrystal());
        crystalPotential.setCoordinates(state.x());
        crystalPotential.setVelocity(state.v());
    }

    /** Marks the run as started once initial energies are logged by the superclass. */
    @Override
    void postInitEnergies() {
        super.postInitEnergies();
        running = true;
    }

    /**
     * Integrates {@code numSteps} steps in blocks of at most {@code intervalSteps}, updating the FFX
     * state after each block. The final block is shortened so the run never exceeds numSteps.
     */
    private void mainLoop(long numSteps) {
        long step = 0L;
        time = System.nanoTime();
        while (step < numSteps) {
            int blockSteps = (int) Math.min(intervalSteps, numSteps - step);

            long stepsStart = System.nanoTime();
            takeOpenMMSteps(blockSteps);
            long stepsElapsed = System.nanoTime() - stepsStart;
            logger.finest(() -> String.format("\n Took steps in %6.3f", stepsElapsed * NS2SEC));

            totalSimTime += blockSteps * dt;
            step += blockSteps;

            long updateStart = System.nanoTime();
            updateFromOpenMM(step, running);
            long updateElapsed = System.nanoTime() - updateStart;
            logger.finest(() -> String.format("\n Update finished in %6.3f", updateElapsed * NS2SEC));
        }
    }

    /** Copies potential energy, kinetic energy and temperature from OpenMM into the FFX state. */
    private void getOpenMMEnergies() {
        OpenMMState openMMState = openMMPotential.getOpenMMState(ENERGY_MASK);
        try {
            copyEnergies(openMMState);
        } finally {
            openMMState.destroy();
        }
    }

    /** Copies energies, box vectors (periodic systems only) and active-atom positions from OpenMM. */
    private void getOpenMMEnergiesAndPositions() {
        OpenMMState openMMState = openMMPotential.getOpenMMState(ENERGY_AND_POSITIONS_MASK);
        try {
            copyEnergies(openMMState);
            copyBoxVectors(openMMState);
            Atom[] atoms = openMMPotential.getSystem().getAtoms();
            openMMState.getActivePositions(state.x(), atoms);
        } finally {
            openMMState.destroy();
        }
    }

    /**
     * Copies energies, box vectors (periodic systems only) and active-atom positions, velocities and
     * accelerations from OpenMM.
     */
    private void getAllOpenMMVariables() {
        OpenMMState openMMState = openMMPotential.getOpenMMState(ALL_VARIABLES_MASK);
        try {
            copyEnergies(openMMState);
            copyBoxVectors(openMMState);
            Atom[] atoms = openMMPotential.getSystem().getAtoms();
            openMMState.getActivePositions(state.x(), atoms);
            openMMState.getActiveVelocities(state.v(), atoms);
            openMMState.getActiveAccelerations(state.a(), atoms);
        } finally {
            openMMState.destroy();
        }
    }

    /** Sets potential energy, kinetic energy and the kinetic-energy-derived temperature. */
    private void copyEnergies(OpenMMState openMMState) {
        state.setPotentialEnergy(openMMState.potentialEnergy);
        state.setKineticEnergy(openMMState.kineticEnergy);
        state.setTemperature(openMMPotential.getSystem().getTemperature(openMMState.kineticEnergy));
    }

    /** Updates the crystal's cell vectors from OpenMM; skipped for aperiodic systems. */
    private void copyBoxVectors(OpenMMState openMMState) {
        Crystal crystal = crystalPotential.getCrystal();
        if (!crystal.aperiodic()) {
            double[][] cellVectors = openMMState.getPeriodicBoxVectors();
            crystal.setCellVectors(cellVectors);
            crystalPotential.setCrystal(crystal);
        }
    }

    /**
     * Refreshes the FFX state from OpenMM and, once running, logs thermodynamics and writes files.
     *
     * @param step      total steps completed so far.
     * @param isRunning whether initial energies have been logged.
     */
    private void updateFromOpenMM(long step, boolean isRunning) {
        obtainVariables.run();
        if (!isRunning) {
            return;
        }
        if (step == 0L) {
            logger.log(basicLogging, String.format("\n  %8s %12s %12s %12s %8s %8s",
                    "Time", "Kinetic", "Potential", "Total", "Temp", "CPU"));
            logger.log(basicLogging, String.format("  %8s %12s %12s %12s %8s %8s",
                    "psec", "kcal/mol", "kcal/mol", "kcal/mol", "K", "sec"));
            logger.log(basicLogging, String.format("  %8s %12.4f %12.4f %12.4f %8.2f", "",
                    state.getKineticEnergy(), state.getPotentialEnergy(), state.getTotalEnergy(),
                    state.getTemperature()));
        }
        time = logThermoForTime(step, time);
        if (automaticWriteouts) {
            writeFilesForStep(step, true, true);
        }
    }

    /**
     * Maps the requested integrator onto the OpenMM integrator name used by updateContext.
     *
     * @param integrator the requested integrator; null selects Verlet (with an info message).
     */
    private void integratorToString(IntegratorEnum integrator) {
        if (integrator == null) {
            integratorString = "VERLET";
            logger.info(" An integrator was not specified. Verlet will be used.");
            return;
        }
        integratorString = switch (integrator) {
            case STOCHASTIC, LANGEVIN -> "LANGEVIN";
            case RESPA, MTS -> "MTS";
            case STOCHASTIC_MTS, LANGEVIN_MTS -> "LANGEVIN-MTS";
            default -> "VERLET";
        };
    }
}

