View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.potential.openmm;
39  
40  import edu.rit.pj.Comm;
41  import edu.uiowa.jopenmm.OpenMMLibrary;
42  import edu.uiowa.jopenmm.OpenMMUtils;
43  import edu.uiowa.jopenmm.OpenMM_Vec3;
44  import ffx.crystal.Crystal;
45  import ffx.numerics.Potential;
46  import ffx.openmm.Context;
47  import ffx.openmm.Integrator;
48  import ffx.openmm.MinimizationReporter;
49  import ffx.openmm.Platform;
50  import ffx.openmm.State;
51  import ffx.openmm.StringArray;
52  import ffx.potential.bonded.Atom;
53  import ffx.potential.parameters.ForceField;
54  
55  import java.util.logging.Level;
56  import java.util.logging.Logger;
57  
58  import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_KcalPerKJ;
59  import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_NmPerAngstrom;
60  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_False;
61  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
62  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_LocalEnergyMinimizer_minimize;
63  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
64  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
65  import static ffx.openmm.Platform.getNumPlatforms;
66  import static ffx.openmm.Platform.getOpenMMVersion;
67  import static ffx.openmm.Platform.getPluginLoadFailures;
68  import static ffx.openmm.Platform.loadPluginsFromDirectory;
69  import static ffx.potential.ForceFieldEnergy.DEFAULT_CONSTRAINT_TOLERANCE;
70  import static ffx.potential.Platform.OMM;
71  import static ffx.potential.Platform.OMM_CUDA;
72  import static ffx.potential.Platform.OMM_OPENCL;
73  import static ffx.potential.openmm.OpenMMEnergy.getDefaultDevice;
74  import static ffx.potential.openmm.OpenMMIntegrator.createIntegrator;
75  import static java.lang.String.format;
76  
77  /**
78   * Creates and manage an OpenMM Context.
79   *
80   * <p>A Context stores the complete state of a simulation. More specifically, it includes: The
81   * current time The position of each particle The velocity of each particle The values of
82   * configurable parameters defined by Force objects in the System
83   *
84   * <p>You can retrieve a snapshot of the current state at any time by calling getState(). This
85   * allows you to record the state of the simulation at various points, either for analysis or for
86   * checkpointing. getState() can also be used to retrieve the current forces on each particle and
87   * the current energy of the System.
88   */
89  public class OpenMMContext extends Context {
90  
91    private static final Logger logger = Logger.getLogger(OpenMMContext.class.getName());
92  
93    /**
94     * OpenMM System.
95     */
96    private final OpenMMSystem openMMSystem;
97    /**
98     * Integrator string (default = VERLET).
99     */
100   private String integratorName = "VERLET";
101   /**
102    * Time step (default = 0.001 psec).
103    */
104   private double timeStep = 0.001;
105   /**
106    * Temperature (default = 298.15).
107    */
108   private double temperature = 298.15;
109   /**
110    *
111    */
112   private final int enforcePBC;
113   /**
114    * Array of atoms.
115    */
116   private final Atom[] atoms;
117 
118   /**
119    * Create an OpenMM Context for a single topology OpenMM system.
120    *
121    * @param platform     OpenMM Platform.
122    * @param openMMSystem OpenMM System.
123    * @param atoms        Array of atoms.
124    */
125   public OpenMMContext(Platform platform, OpenMMSystem openMMSystem, Atom[] atoms) {
126     super(openMMSystem, createIntegrator("VERLET", 0.001, 298.15, openMMSystem), platform);
127     this.openMMSystem = openMMSystem;
128     this.atoms = atoms;
129 
130     ForceField forceField = openMMSystem.getForceField();
131     boolean aperiodic = openMMSystem.getCrystal().aperiodic();
132     boolean pbcEnforced = forceField.getBoolean("ENFORCE_PBC", !aperiodic);
133     enforcePBC = pbcEnforced ? OpenMM_True : OpenMM_False;
134   }
135 
136   /**
137    * Update the Context in which to run a simulation.
138    *
139    * @param integratorName Requested integrator.
140    * @param timeStep       Time step (psec).
141    * @param temperature    Temperature (K).
142    * @param forceCreation  Force creation of a new context, even if the current one matches.
143    */
144   public void update(String integratorName, double timeStep, double temperature, boolean forceCreation) {
145     // Check if the current context is consistent with the requested context.
146     if (hasContextPointer() && !forceCreation) {
147       if (this.temperature == temperature && this.timeStep == timeStep
148           && this.integratorName.equalsIgnoreCase(integratorName)) {
149         // All requested features agree.
150         return;
151       }
152     }
153 
154     this.integratorName = integratorName;
155     this.timeStep = timeStep;
156     this.temperature = temperature;
157 
158     logger.info("\n Updating OpenMM Context");
159 
160     // Set lambda to 1.0 when creating a context to avoid OpenMM compiling out any terms.
161     // TODO: Test on a fixed charge system.
162     // double currentLambda = openMMEnergy.getLambda();
163     // if (openMMEnergy.getLambdaTerm()) {
164     //  openMMEnergy.setLambda(1.0);
165     // }
166 
167     // Update the context.
168     Integrator newIntegrator = createIntegrator(integratorName, timeStep, temperature, openMMSystem);
169     Platform newPlatform = new Platform(platform.getName());
170     updateContext(openMMSystem, newIntegrator, newPlatform);
171 
172     // Revert to the current lambda value.
173     // if (openMMEnergy.getLambdaTerm()) {
174     //   openMMEnergy.setLambda(currentLambda);
175     // }
176 
177     // Get initial positions and velocities for all atoms.
178     int nVar = atoms.length * 3;
179     double[] x = new double[nVar];
180     double[] v = new double[nVar];
181     double[] vel3 = new double[3];
182     int index = 0;
183     for (Atom a : atoms) {
184       a.getVelocity(vel3);
185       // X-axis
186       x[index] = a.getX();
187       v[index++] = vel3[0];
188       // Y-axis
189       x[index] = a.getY();
190       v[index++] = vel3[1];
191       // Z-axis
192       x[index] = a.getZ();
193       v[index++] = vel3[2];
194     }
195 
196     // Load the current periodic box vectors.
197     Crystal crystal = openMMSystem.getCrystal();
198     setPeriodicBoxVectors(crystal);
199 
200     // Load current atomic positions.
201     setPositions(x);
202 
203     // Load current velocities.
204     setVelocities(v);
205 
206     // Apply constraints starting from current atomic positions.
207     applyConstraints(DEFAULT_CONSTRAINT_TOLERANCE);
208 
209     // Application of constraints can change coordinates and velocities.
210     // Retrieve them for consistency.
211     OpenMMState openMMState = getOpenMMState(OpenMM_State_Positions | OpenMM_State_Velocities);
212     Potential energy = openMMSystem.getPotential();
213     energy.setCoordinates(openMMState.getActivePositions(null, atoms));
214     energy.setVelocity(openMMState.getActiveVelocities(null, atoms));
215     openMMState.destroy();
216   }
217 
218   /**
219    * Update the Context if necessary.
220    */
221   public void update() {
222     if (!hasContextPointer()) {
223       logger.info(" Delayed creation of OpenMM Context.");
224       update(integratorName, timeStep, temperature, true);
225     }
226   }
227 
228   /**
229    * Get an OpenMM State from the Context.
230    *
231    * @param mask A mask specifying which information to retrieve.
232    * @return State pointer.
233    */
234   public OpenMMState getOpenMMState(int mask) {
235     State state = getState(mask, enforcePBC);
236     return new OpenMMState(state.getPointer());
237   }
238 
239   /**
240    * Use the Context / Integrator combination to take the requested number of steps.
241    *
242    * @param numSteps Number of steps to take.
243    */
244   public void integrate(int numSteps) {
245     Integrator integrator = getIntegrator();
246     integrator.step(numSteps);
247   }
248 
249   /**
250    * Use the Context to optimize the system to the requested tolerance.
251    *
252    * @param eps           Convergence criteria (kcal/mole/A).
253    * @param maxIterations Maximum number of iterations.
254    */
255   public void optimize(double eps, int maxIterations) {
256     // The "report" method of MinimizationReporter cannot be overridden, so the reporter does nothing.
257     MinimizationReporter reporter = new MinimizationReporter();
258     OpenMM_LocalEnergyMinimizer_minimize(getPointer(), eps / (OpenMM_NmPerAngstrom * OpenMM_KcalPerKJ),
259         maxIterations, reporter.getPointer());
260     reporter.destroy();
261   }
262 
263   /**
264    * The array x should contain atomic coordinates for all atoms in units of Angstroms.
265    *
266    * @param x Atomic coordinate array for all atoms in units of Angstroms.
267    */
268   @Override
269   public void setPositions(double[] x) {
270     long time = -System.nanoTime();
271     int n = x.length;
272     double[] xn = new double[n];
273     for (int i = 0; i < n; i++) {
274       // Convert Angstroms to nanometers.
275       xn[i] = x[i] * OpenMM_NmPerAngstrom;
276     }
277     super.setPositions(xn);
278     time += System.nanoTime();
279     if (logger.isLoggable(Level.FINEST)) {
280       logger.finest(format(" Set OpenMM positions  %9.6f (msec)", time * 1e-6));
281     }
282   }
283 
284   /**
285    * The array v contains velocity values for all atomic coordinates in units of Angstroms/psec.
286    *
287    * @param v Velocity array for all atoms.
288    */
289   @Override
290   public void setVelocities(double[] v) {
291     long time = -System.nanoTime();
292     int n = v.length;
293     double[] vn = new double[n];
294     for (int i = 0; i < n; i++) {
295       // Convert Angstroms to nanometers.
296       vn[i] = v[i] * OpenMM_NmPerAngstrom;
297     }
298     super.setVelocities(vn);
299     time += System.nanoTime();
300     if (logger.isLoggable(Level.FINEST)) {
301       logger.finest(format(" Set OpenMM velocities %9.6f (msec)", time * 1e-6));
302     }
303   }
304 
305   /**
306    * Set the periodic box vectors for a context based on the crystal instance.
307    *
308    * @param crystal The crystal instance.
309    */
310   public void setPeriodicBoxVectors(Crystal crystal) {
311     if (!crystal.aperiodic()) {
312       OpenMM_Vec3 a = new OpenMM_Vec3();
313       OpenMM_Vec3 b = new OpenMM_Vec3();
314       OpenMM_Vec3 c = new OpenMM_Vec3();
315       double[][] Ai = crystal.Ai;
316       a.x = Ai[0][0] * OpenMM_NmPerAngstrom;
317       a.y = Ai[0][1] * OpenMM_NmPerAngstrom;
318       a.z = Ai[0][2] * OpenMM_NmPerAngstrom;
319       b.x = Ai[1][0] * OpenMM_NmPerAngstrom;
320       b.y = Ai[1][1] * OpenMM_NmPerAngstrom;
321       b.z = Ai[1][2] * OpenMM_NmPerAngstrom;
322       c.x = Ai[2][0] * OpenMM_NmPerAngstrom;
323       c.y = Ai[2][1] * OpenMM_NmPerAngstrom;
324       c.z = Ai[2][2] * OpenMM_NmPerAngstrom;
325       setPeriodicBoxVectors(a, b, c);
326     }
327   }
328 
329   /**
330    * {@inheritDoc}
331    */
332   @Override
333   public String toString() {
334     return format(
335         " OpenMM context with integrator %s, timestep %9.3g fsec, temperature %9.3g K",
336         integratorName, timeStep, temperature);
337   }
338 
339   /**
340    * Load an OpenMM Platform
341    *
342    * @param requestedPlatform the requested OpenMM platform.
343    * @param forceField        the ForceField to query for platform flags.
344    * @return the loaded Platform.
345    */
346   public static Platform loadPlatform(ffx.potential.Platform requestedPlatform, ForceField forceField) {
347 
348     OpenMMUtils.init();
349 
350     // Print out the OpenMM library path.
351     logger.log(Level.INFO, " Loaded from:\n {0}", OpenMMLibrary.JNA_NATIVE_LIB.toString());
352 
353     // Print out the OpenMM Version.
354     logger.log(Level.INFO, " Version: {0}", getOpenMMVersion());
355 
356     // Print out the OpenMM lib directory.
357     String libDirectory = OpenMMUtils.getLibDirectory();
358     logger.log(Level.FINE, " Lib Directory:       {0}", libDirectory);
359     // Load platforms and print out their names.
360     StringArray libs = loadPluginsFromDirectory(libDirectory);
361     int numLibs = libs.getSize();
362     logger.log(Level.FINE, " Number of libraries: {0}", numLibs);
363     for (int i = 0; i < numLibs; i++) {
364       logger.log(Level.FINE, "  Library: {0}", libs.get(i));
365     }
366     libs.destroy();
367 
368     // Print out the OpenMM plugin directory.
369     String pluginDirectory = OpenMMUtils.getPluginDirectory();
370     logger.log(Level.INFO, "\n Plugin Directory:  {0}", pluginDirectory);
371     // Load plugins and print out their names.
372     StringArray plugins = loadPluginsFromDirectory(pluginDirectory);
373     int numPlugins = plugins.getSize();
374     logger.log(Level.INFO, " Number of Plugins: {0}", numPlugins);
375     boolean cuda = false;
376     boolean opencl = false;
377     for (int i = 0; i < numPlugins; i++) {
378       String pluginString = plugins.get(i);
379       logger.log(Level.INFO, "  Plugin: {0}", pluginString);
380       if (pluginString != null) {
381         pluginString = pluginString.toUpperCase();
382         boolean amoebaCudaAvailable = pluginString.contains("AMOEBACUDA");
383         if (amoebaCudaAvailable) {
384           cuda = true;
385         }
386         boolean amoebaOpenCLAvailable = pluginString.contains("AMOEBAOPENCL");
387         if (amoebaOpenCLAvailable) {
388           opencl = true;
389         }
390       }
391     }
392     plugins.destroy();
393 
394     int numPlatforms = getNumPlatforms();
395     logger.log(Level.INFO, " Number of Platforms: {0}", numPlatforms);
396 
397     if (requestedPlatform == OMM_CUDA && !cuda) {
398       logger.severe(" The OMM_CUDA platform was requested, but is not available.");
399     }
400 
401     if (requestedPlatform == ffx.potential.Platform.OMM_OPENCL && !opencl) {
402       logger.severe(" The OMM_OPENCL platform was requested, but is not available.");
403     }
404 
405     // Extra logging to print out plugins that failed to load.
406     if (logger.isLoggable(Level.FINE)) {
407       StringArray pluginFailures = getPluginLoadFailures();
408       int numFailures = pluginFailures.getSize();
409       for (int i = 0; i < numFailures; i++) {
410         logger.log(Level.FINE, " Plugin load failure: {0}", pluginFailures.get(i));
411       }
412       pluginFailures.destroy();
413     }
414 
415     String defaultPrecision = "mixed";
416     String precision = forceField.getString("PRECISION", defaultPrecision).toLowerCase();
417     precision = precision.replace("-precision", "");
418     switch (precision) {
419       case "double", "mixed", "single" -> logger.info(format(" Precision level: %s", precision));
420       default -> {
421         logger.info(format(" Could not interpret precision level %s, defaulting to %s", precision, defaultPrecision));
422         precision = defaultPrecision;
423       }
424     }
425 
426 
427     Platform openMMPlatform;
428     if (cuda && (requestedPlatform == OMM_CUDA || requestedPlatform == OMM)) {
429       // CUDA
430       int defaultDevice = getDefaultDevice(forceField.getProperties());
431       openMMPlatform = new Platform("CUDA");
432       // CUDA_DEVICE is deprecated; use DeviceIndex.
433       int deviceID = forceField.getInteger("CUDA_DEVICE", defaultDevice);
434       deviceID = forceField.getInteger("DeviceIndex", deviceID);
435       String deviceIDString = Integer.toString(deviceID);
436       openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
437       openMMPlatform.setPropertyDefaultValue("Precision", precision);
438       String name = openMMPlatform.getName();
439       logger.info(format(" Platform: %s (Device Index %d)", name, deviceID));
440     } else if (opencl && (requestedPlatform == OMM_OPENCL || requestedPlatform == OMM)) {
441       // OpenCL
442       int defaultDevice = getDefaultDevice(forceField.getProperties());
443       openMMPlatform = new Platform("OpenCL");
444       int deviceID = forceField.getInteger("DeviceIndex", defaultDevice);
445       String deviceIDString = Integer.toString(deviceID);
446       openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
447       int openCLPlatformIndex = forceField.getInteger("OpenCLPlatformIndex", 0);
448       String openCLPlatformIndexString = Integer.toString(openCLPlatformIndex);
449       openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
450       openMMPlatform.setPropertyDefaultValue("OpenCLPlatformIndex", openCLPlatformIndexString);
451       openMMPlatform.setPropertyDefaultValue("Precision", precision);
452       String name = openMMPlatform.getName();
453       logger.info(format(" Platform: %s (Platform Index %d, Device Index %d)",
454           name, openCLPlatformIndex, deviceID));
455     } else {
456       // Reference
457       openMMPlatform = new Platform("Reference");
458       String name = openMMPlatform.getName();
459       logger.info(format(" Platform: %s", name));
460     }
461 
462     try {
463       Comm world = Comm.world();
464       if (world != null) {
465         logger.info(format(" Running on host %s, rank %d", world.host(), world.rank()));
466       }
467     } catch (IllegalStateException illegalStateException) {
468       logger.fine(" Could not find the world communicator!");
469     }
470 
471     return openMMPlatform;
472   }
473 
474   /**
475    * Free OpenMM memory for the current Context and Integrator.
476    */
477   public void free() {
478     destroy();
479   }
480 }