View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2024.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.algorithms.dynamics;
39  
40  import ffx.algorithms.AlgorithmListener;
41  import ffx.algorithms.dynamics.integrators.IntegratorEnum;
42  import ffx.algorithms.dynamics.thermostats.ThermostatEnum;
43  import ffx.crystal.Crystal;
44  import ffx.numerics.Potential;
45  import ffx.potential.openmm.OpenMMEnergy;
46  import ffx.potential.MolecularAssembly;
47  import ffx.potential.openmm.OpenMMContext;
48  import ffx.potential.openmm.OpenMMState;
49  import ffx.potential.openmm.OpenMMSystem;
50  import ffx.potential.UnmodifiableState;
51  
52  import java.io.File;
53  import java.util.ArrayList;
54  import java.util.List;
55  import java.util.logging.Logger;
56  
57  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
58  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
59  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
60  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
61  import static ffx.utilities.Constants.NS2SEC;
62  import static java.lang.String.format;
63  
64  /**
65   * Runs Molecular Dynamics using OpenMM implementation
66   *
67   * @author Michael J. Schnieders
68   */
69  public class MolecularDynamicsOpenMM extends MolecularDynamics {
70  
71    private static final Logger logger = Logger.getLogger(MolecularDynamicsOpenMM.class.getName());
72    /**
73     * Integrator Type.
74     */
75    private final IntegratorEnum integratorType;
76    /**
77     * Thermostat Type.
78     */
79    private final ThermostatEnum thermostatType;
80    /**
81     * OpenMM ForceFieldEnergy.
82     */
83    private final OpenMMEnergy openMMEnergy;
84    /**
85     * Integrator String.
86     */
87    private String integratorString;
88    /**
89     * Number of OpenMM MD steps per iteration.
90     */
91    private int intervalSteps;
92    /**
93     * Flag to indicate OpenMM MD interactions are running.
94     */
95    private boolean running;
96    /**
97     * Run time.
98     */
99    private long time;
100   /**
101    * Obtain all variables with each update (i.e., include velocities, gradients).
102    */
103   private boolean getAllVars = true;
104   /**
105    * Method to run on update for obtaining variables. Will either grab everything (default) or
106    * energies + positions (MC-OST).
107    */
108   private Runnable obtainVariables = this::getAllOpenMMVariables;
109 
110   /**
111    * Constructs an MolecularDynamicsOpenMM object, to perform molecular dynamics using native OpenMM
112    * routines, avoiding the cost of communicating coordinates, gradients, and energies back and forth
113    * across the PCI bus.
114    *
115    * @param assembly   MolecularAssembly to operate on
116    * @param potential  Either a ForceFieldEnergyOpenMM, or a Barostat.
117    * @param listener   a {@link ffx.algorithms.AlgorithmListener} object.
118    * @param thermostat May have to be slightly modified for native OpenMM routines
119    * @param integrator May have to be slightly modified for native OpenMM routines
120    */
121   public MolecularDynamicsOpenMM(MolecularAssembly assembly, Potential potential,
122                                  AlgorithmListener listener, ThermostatEnum thermostat, IntegratorEnum integrator) {
123     super(assembly, potential, listener, thermostat, integrator);
124 
125     logger.info("\n Initializing OpenMM molecular dynamics.");
126 
127     // Initialization specific to MolecularDynamicsOpenMM
128     running = false;
129     List<Potential> potentialStack = new ArrayList<>(potential.getUnderlyingPotentials());
130     potentialStack.add(potential);
131 
132     List<OpenMMEnergy> energyList = potentialStack.stream()
133         .filter((Potential p) -> p instanceof OpenMMEnergy)
134         .map((Potential p) -> (OpenMMEnergy) p).toList();
135     if (energyList.size() != 1) {
136       logger.severe(format(
137           " Attempted to create a MolecularDynamicsOpenMM with %d ForceFieldEnergyOpenMM instances.",
138           energyList.size()));
139     }
140     openMMEnergy = energyList.get(0);
141 
142     List<Barostat> barostatList = potentialStack.stream()
143         .filter((Potential p) -> p instanceof Barostat).map((Potential p) -> (Barostat) p).toList();
144     if (barostatList.isEmpty()) {
145       constantPressure = false;
146     } else if (barostatList.size() > 1) {
147       logger.severe(
148           format(" Attempting to create a MolecularDynamicsOpenMM with more than 1 barostat (%d).",
149               barostatList.size()));
150     } else {
151       barostat = barostatList.get(0);
152       barostat.setActive(false);
153     }
154 
155     // Update the set of active and inactive atoms.
156     openMMEnergy.setActiveAtoms();
157 
158     thermostatType = thermostat;
159     integratorType = integrator;
160     integratorToString(integratorType);
161   }
162 
163   /**
164    * {@inheritDoc}
165    *
166    * <p>Execute <code>numSteps</code> of dynamics using the provided <code>timeStep</code> and
167    * <code>temperature</code>. The <code>printInterval</code> and <code>saveInterval</code>
168    * control logging the state of the system to the console and writing a restart file, respectively.
169    * If the <code>dyn</code> File is not null, the simulation will be initialized from the contents.
170    * If the <code>iniVelocities</code> is true, the velocities will be initialized from a Maxwell
171    * Boltzmann distribution.
172    */
173   @Override
174   public void dynamic(long numSteps, double timeStep, double printInterval, double saveInterval,
175                       double temperature, boolean initVelocities, File dyn) {
176     // Return if already running and a second thread calls the dynamic method.
177     if (!done) {
178       logger.warning(" Programming error - a thread invoked dynamic when it was already running.");
179       return;
180     }
181 
182     // Call the init method.
183     init(numSteps, timeStep, printInterval, saveInterval, fileType, restartInterval, temperature,
184         initVelocities, dyn);
185 
186     if (intervalSteps == 0 || intervalSteps > numSteps) {
187       // Safe cast: if intervalSteps > numSteps, then numSteps must be less than Integer.MAX_VALUE.
188       intervalSteps = (int) numSteps;
189     }
190 
191     // Initialization, including reading a dyn file or initialization of velocity.
192     preRunOps();
193 
194     // Send coordinates, velocities, etc. to OpenMM.
195     setOpenMMState();
196 
197     // Retrieve starting energy values.
198     getOpenMMEnergies();
199 
200     // Store the initial state.
201     initialState = new UnmodifiableState(state);
202 
203     // Check that our context is using correct Integrator, time step, and target temperature.
204     openMMEnergy.updateContext(integratorString, dt, targetTemperature, false);
205 
206     // Pre-run operations (mostly logging) that require knowledge of system energy.
207     postInitEnergies();
208 
209     // Run the MD steps.
210     mainLoop(numSteps);
211 
212     // Post-run cleanup operations.
213     postRun();
214   }
215 
216   /**
217    * {@inheritDoc}
218    */
219   @Override
220   public int getIntervalSteps() {
221     return intervalSteps;
222   }
223 
224   /**
225    * Setter for the field <code>intervalSteps</code>.
226    *
227    * @param intervalSteps The number of interval steps.
228    */
229   @Override
230   public void setIntervalSteps(int intervalSteps) {
231     this.intervalSteps = intervalSteps;
232   }
233 
234   /**
235    * {@inheritDoc}
236    */
237   @Override
238   public double getTimeStep() {
239     return dt;
240   }
241 
242   /**
243    * {@inheritDoc}
244    */
245   @Override
246   public void init(long numSteps, double timeStep, double loggingInterval, double trajectoryInterval,
247                    String fileType, double restartInterval, double temperature, boolean initVelocities,
248                    File dyn) {
249 
250     super.init(numSteps, timeStep, loggingInterval, trajectoryInterval, fileType, restartInterval,
251         temperature, initVelocities, dyn);
252 
253     boolean isLangevin = IntegratorEnum.isStochastic(integratorType);
254 
255     OpenMMSystem openMMSystem = openMMEnergy.getSystem();
256     if (!isLangevin && !thermostatType.equals(ThermostatEnum.ADIABATIC)) {
257       // Add Andersen thermostat, or if already present update its target temperature.
258       openMMSystem.addAndersenThermostatForce(targetTemperature);
259     }
260 
261     if (constantPressure) {
262       // Add an isotropic Monte Carlo barostat.
263       // If it is already present, update its target temperature, pressure and frequency.
264       double pressure = barostat.getPressure();
265       int frequency = barostat.getMeanBarostatInterval();
266       openMMSystem.addMonteCarloBarostatForce(pressure, targetTemperature, frequency);
267     }
268 
269     // For Langevin/Stochastic dynamics, center of mass motion will not be removed.
270     if (!isLangevin) {
271       // No action is taken if a COMMRemover is already present.
272       openMMSystem.addCOMMRemoverForce();
273     }
274 
275     // Set the current value of lambda.
276     openMMEnergy.setLambda(openMMEnergy.getLambda());
277   }
278 
279   @Override
280   public void revertState() throws Exception {
281     super.revertState();
282     setOpenMMState();
283   }
284 
285   /**
286    * {@inheritDoc}
287    */
288   @Override
289   public void setFileType(String fileType) {
290     this.fileType = fileType;
291   }
292 
293   /**
294    * Sets whether to obtain all variables (velocities, gradients) from OpenMM, or just positions and
295    * energies.
296    *
297    * @param obtainVA If true, obtain all variables from OpenMM each update.
298    */
299   @Override
300   public void setObtainVelAcc(boolean obtainVA) {
301     // TODO: Make this more generic by letting it obtain any weird combination of variables.
302     getAllVars = obtainVA;
303     obtainVariables = obtainVA ? this::getAllOpenMMVariables : this::getOpenMMEnergiesAndPositions;
304   }
305 
306   @Override
307   public void writeRestart() {
308     if (!getAllVars) {
309       // If !getAllVars, need to ensure all variables are synced before writing the restart.
310       getAllOpenMMVariables();
311     }
312     super.writeRestart();
313   }
314 
315   @Override
316   protected void appendSnapshot(String[] extraLines) {
317     if (!getAllVars) {
318       // If !getAllVars, need to ensure coordinates are synced before writing a snapshot.
319       getOpenMMEnergiesAndPositions();
320     }
321     super.appendSnapshot(extraLines);
322   }
323 
324   /**
325    * Integrate the simulation using the defined Context and Integrator.
326    *
327    * @param intervalSteps Number of MD steps to take.
328    */
329   private void takeOpenMMSteps(int intervalSteps) {
330     OpenMMContext openMMContext = openMMEnergy.getContext();
331     openMMContext.integrate(intervalSteps);
332   }
333 
334   /**
335    * Load coordinates, box vectors and velocities.
336    */
337   private void setOpenMMState() {
338     OpenMMContext openMMContext = openMMEnergy.getContext();
339     openMMContext.setPositions(state.x());
340     openMMContext.setPeriodicBoxVectors(openMMEnergy.getCrystal());
341     openMMContext.setVelocities(state.v());
342   }
343 
344   /**
345    * Get OpenMM Energies.
346    */
347   private void getOpenMMEnergies() {
348     OpenMMState openMMState = openMMEnergy.getOpenMMState(OpenMM_State_Energy);
349     state.setKineticEnergy(openMMState.kineticEnergy);
350     state.setPotentialEnergy(openMMState.potentialEnergy);
351     state.setTemperature(openMMEnergy.getSystem().getTemperature(openMMState.kineticEnergy));
352     openMMState.destroy();
353   }
354 
355   /**
356    * Do some logging of the beginning energy values.
357    */
358   void postInitEnergies() {
359     super.postInitEnergies();
360     running = true;
361   }
362 
363   private void mainLoop(long numSteps) {
364     long i = 0;
365     time = System.nanoTime();
366 
367     while (i < numSteps) {
368 
369       // Take MD steps in OpenMM.
370       long takeStepsTime = -System.nanoTime();
371       takeOpenMMSteps(intervalSteps);
372       takeStepsTime += System.nanoTime();
373       logger.fine(String.format("\n Took steps in %6.3f", takeStepsTime * NS2SEC));
374       totalSimTime += intervalSteps * dt;
375 
376       // Update the total step count.
377       i += intervalSteps;
378 
379       long secondUpdateTime = -System.nanoTime();
380       updateFromOpenMM(i, running);
381       secondUpdateTime += System.nanoTime();
382 
383       logger.fine(String.format("\n Update finished in %6.3f", secondUpdateTime * NS2SEC));
384     }
385   }
386 
387   /**
388    * Get OpenMM Energies and Positions.
389    */
390   private void getOpenMMEnergiesAndPositions() {
391     int mask = OpenMM_State_Energy | OpenMM_State_Positions;
392     OpenMMState openMMState = openMMEnergy.getOpenMMState(mask);
393     state.setPotentialEnergy(openMMState.potentialEnergy);
394     state.setKineticEnergy(openMMState.kineticEnergy);
395     state.setTemperature(openMMEnergy.getSystem().getTemperature(openMMState.kineticEnergy));
396     openMMState.getPositions(state.x());
397     Crystal crystal = openMMEnergy.getCrystal();
398     if (!crystal.aperiodic()) {
399       double[][] cellVectors = openMMState.getPeriodicBoxVectors();
400       crystal.setCellVectors(cellVectors);
401       openMMEnergy.setCrystal(crystal);
402     }
403     openMMState.destroy();
404   }
405 
406   /**
407    * Get OpenMM energies, positions, velocities, and accelerations.
408    */
409   private void getAllOpenMMVariables() {
410     int mask = OpenMM_State_Energy | OpenMM_State_Positions | OpenMM_State_Velocities | OpenMM_State_Forces;
411     OpenMMState openMMState = openMMEnergy.getOpenMMState(mask);
412     state.setPotentialEnergy(openMMState.potentialEnergy);
413     state.setKineticEnergy(openMMState.kineticEnergy);
414     state.setTemperature(openMMEnergy.getSystem().getTemperature(openMMState.kineticEnergy));
415     openMMState.getPositions(state.x());
416     Crystal crystal = openMMEnergy.getCrystal();
417     if (!crystal.aperiodic()) {
418       double[][] cellVectors = openMMState.getPeriodicBoxVectors();
419       crystal.setCellVectors(cellVectors);
420       openMMEnergy.setCrystal(crystal);
421     }
422     openMMState.getVelocities(state.v());
423     openMMState.getAccelerations(state.a());
424     openMMState.destroy();
425   }
426 
427   /**
428    * updateFromOpenMM obtains the state of the simulation from OpenMM, completes some logging, and
429    * saves restart files.
430    *
431    * @param i       Number of OpenMM MD rounds.
432    * @param running True if OpenMM MD rounds have begun running.
433    */
434   private void updateFromOpenMM(long i, boolean running) {
435 
436     double priorPE = state.getPotentialEnergy();
437 
438     obtainVariables.run();
439 
440     if (running) {
441       if (i == 0) {
442         logger.log(basicLogging,
443             format("\n  %8s %12s %12s %12s %8s %8s", "Time", "Kinetic", "Potential", "Total", "Temp",
444                 "CPU"));
445         logger.log(basicLogging,
446             format("  %8s %12s %12s %12s %8s %8s", "psec", "kcal/mol", "kcal/mol", "kcal/mol", "K",
447                 "sec"));
448         logger.log(basicLogging,
449             format("  %8s %12.4f %12.4f %12.4f %8.2f", "", state.getKineticEnergy(),
450                 state.getPotentialEnergy(), state.getTotalEnergy(), state.getTemperature()));
451       }
452       time = logThermoForTime(i, time);
453 
454       if (automaticWriteouts) {
455         writeFilesForStep(i, true, true);
456       }
457     }
458   }
459 
460   /**
461    * integratorToString.
462    *
463    * @param integrator a {@link ffx.algorithms.dynamics.integrators.IntegratorEnum} object.
464    */
465   private void integratorToString(IntegratorEnum integrator) {
466     if (integrator == null) {
467       integratorString = "VERLET";
468       logger.info(" An integrator was not specified. Verlet will be used.");
469     } else {
470       switch (integratorType) {
471         default -> integratorString = "VERLET";
472         case STOCHASTIC, LANGEVIN -> integratorString = "LANGEVIN";
473         case RESPA, MTS -> integratorString = "MTS";
474         case STOCHASTIC_MTS, LANGEVIN_MTS -> integratorString = "LANGEVIN-MTS";
475       }
476     }
477   }
478 }