View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.potential.openmm;
39  
40  import edu.rit.pj.Comm;
41  import edu.uiowa.jopenmm.OpenMMLibrary;
42  import edu.uiowa.jopenmm.OpenMMUtils;
43  import edu.uiowa.jopenmm.OpenMM_Vec3;
44  import ffx.crystal.Crystal;
45  import ffx.openmm.Context;
46  import ffx.openmm.Integrator;
47  import ffx.openmm.MinimizationReporter;
48  import ffx.openmm.Platform;
49  import ffx.openmm.State;
50  import ffx.openmm.StringArray;
51  import ffx.potential.bonded.Atom;
52  import ffx.potential.parameters.ForceField;
53  
import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;
56  
57  import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_KcalPerKJ;
58  import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_NmPerAngstrom;
59  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_False;
60  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
61  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_LocalEnergyMinimizer_minimize;
62  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
63  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
64  import static ffx.openmm.Platform.getNumPlatforms;
65  import static ffx.openmm.Platform.getOpenMMVersion;
66  import static ffx.openmm.Platform.getPluginLoadFailures;
67  import static ffx.openmm.Platform.loadPluginsFromDirectory;
68  import static ffx.potential.ForceFieldEnergy.DEFAULT_CONSTRAINT_TOLERANCE;
69  import static ffx.potential.Platform.OMM;
70  import static ffx.potential.Platform.OMM_CUDA;
71  import static ffx.potential.Platform.OMM_OPENCL;
72  import static ffx.potential.openmm.OpenMMEnergy.getDefaultDevice;
73  import static ffx.potential.openmm.OpenMMIntegrator.createIntegrator;
74  import static java.lang.String.format;
75  
/**
 * Creates and manages an OpenMM Context.
 *
 * <p>A Context stores the complete state of a simulation. More specifically, it includes:
 * <ul>
 *   <li>The current time
 *   <li>The position of each particle
 *   <li>The velocity of each particle
 *   <li>The values of configurable parameters defined by Force objects in the System
 * </ul>
 *
 * <p>You can retrieve a snapshot of the current state at any time by calling getState(). This
 * allows you to record the state of the simulation at various points, either for analysis or for
 * checkpointing. getState() can also be used to retrieve the current forces on each particle and
 * the current energy of the System.
 */
88  public class OpenMMContext extends Context {
89  
90    private static final Logger logger = Logger.getLogger(OpenMMContext.class.getName());
91  
92    /**
93     * OpenMM Platform.
94     */
95    private final Platform openMMPlatform;
96    /**
97     * OpenMM System.
98     */
99    private final OpenMMSystem openMMSystem;
100   /**
101    * Instance of the OpenMM Integrator class.
102    */
103   private Integrator openMMIntegrator;
104   /**
105    * Integrator string (default = VERLET).
106    */
107   private String integratorName = "VERLET";
108   /**
109    * Time step (default = 0.001 psec).
110    */
111   private double timeStep = 0.001;
112   /**
113    * Temperature (default = 298.15).
114    */
115   private double temperature = 298.15;
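  /**
   * OpenMM boolean flag (OpenMM_True or OpenMM_False) passed to getState to request that periodic
   * boundary conditions be enforced. Set from the ENFORCE_PBC property, which defaults to true for
   * periodic systems.
   */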
119   private final int enforcePBC;
120   /**
121    * Array of atoms.
122    */
123   private final Atom[] atoms;
124 
125   /**
126    * Create an OpenMM Context.
127    *
128    * @param platform     OpenMM Platform.
129    * @param openMMSystem OpenMM System.
130    * @param atoms        Array of atoms.
131    */
132   public OpenMMContext(Platform platform, OpenMMSystem openMMSystem, Atom[] atoms) {
133     this.openMMPlatform = platform;
134     this.openMMSystem = openMMSystem;
135 
136     ForceField forceField = openMMSystem.getForceField();
137     boolean aperiodic = openMMSystem.getCrystal().aperiodic();
138     boolean pbcEnforced = forceField.getBoolean("ENFORCE_PBC", !aperiodic);
139     enforcePBC = pbcEnforced ? OpenMM_True : OpenMM_False;
140 
141     this.atoms = atoms;
142     update(integratorName, timeStep, temperature, true);
143   }
144 
145   /**
146    * Update the Context in which to run a simulation.
147    *
148    * @param integratorName Requested integrator.
149    * @param timeStep       Time step (psec).
150    * @param temperature    Temperature (K).
151    * @param forceCreation  Force creation of a new context, even if the current one matches.
152    */
153   public void update(String integratorName, double timeStep, double temperature, boolean forceCreation) {
154     // Check if the current context is consistent with the requested context.
155     if (hasContextPointer() && !forceCreation) {
156       if (this.temperature == temperature && this.timeStep == timeStep
157           && this.integratorName.equalsIgnoreCase(integratorName)) {
158         // All requested features agree.
159         return;
160       }
161     }
162 
163     this.integratorName = integratorName;
164     this.timeStep = timeStep;
165     this.temperature = temperature;
166 
167     logger.info("\n Updating OpenMM Context");
168 
169     // Update the integrator.
170     if (openMMIntegrator != null) {
171       openMMIntegrator.destroy();
172     }
173     openMMIntegrator = createIntegrator(integratorName, timeStep, temperature, openMMSystem);
174 
175     // Set lambda to 1.0 when creating a context to avoid OpenMM compiling out any terms.
176     // TODO: Test on a fixed charge system.
177     // double currentLambda = openMMEnergy.getLambda();
178     // if (openMMEnergy.getLambdaTerm()) {
179     //  openMMEnergy.setLambda(1.0);
180     // }
181 
182     // Update the context.
183     updateContext(openMMSystem, openMMIntegrator, openMMPlatform);
184 
185     // Revert to the current lambda value.
186     // if (openMMEnergy.getLambdaTerm()) {
187     //   openMMEnergy.setLambda(currentLambda);
188     // }
189 
190     // Get initial positions and velocities for active atoms.
191     int nVar = openMMSystem.getNumberOfVariables();
192     double[] x = new double[nVar];
193     double[] v = new double[nVar];
194     double[] vel3 = new double[3];
195     int index = 0;
196     for (Atom a : atoms) {
197       if (a.isActive()) {
198         a.getVelocity(vel3);
199         // X-axis
200         x[index] = a.getX();
201         v[index++] = vel3[0];
202         // Y-axis
203         x[index] = a.getY();
204         v[index++] = vel3[1];
205         // Z-axis
206         x[index] = a.getZ();
207         v[index++] = vel3[2];
208       }
209     }
210 
211     // Load the current periodic box vectors.
212     Crystal crystal = openMMSystem.getCrystal();
213     setPeriodicBoxVectors(crystal);
214 
215     // Load current atomic positions.
216     setPositions(x);
217 
218     // Load current velocities.
219     setVelocities(v);
220 
221     // Apply constraints starting from current atomic positions.
222     applyConstraints(DEFAULT_CONSTRAINT_TOLERANCE);
223 
224     // Application of constraints can change coordinates and velocities.
225     // Retrieve them for consistency.
226     OpenMMState openMMState = getOpenMMState(OpenMM_State_Positions | OpenMM_State_Velocities);
227     openMMState.getPositions(x);
228     openMMState.getVelocities(v);
229     openMMState.destroy();
230   }
231 
232   /**
233    * Update the Context if necessary.
234    */
235   public void update() {
236     if (!hasContextPointer()) {
237       logger.info(" Delayed creation of OpenMM Context.");
238       update(integratorName, timeStep, temperature, true);
239     }
240   }
241 
242   /**
243    * Get an OpenMM State from the Context.
244    *
245    * @param mask A mask specifying which information to retrieve.
246    * @return State pointer.
247    */
248   public OpenMMState getOpenMMState(int mask) {
249     State state = getState(mask, enforcePBC);
250     return new OpenMMState(state.getPointer(), atoms, openMMSystem.getNumberOfVariables());
251   }
252 
253   /**
254    * Use the Context / Integrator combination to take the requested number of steps.
255    *
256    * @param numSteps Number of steps to take.
257    */
258   public void integrate(int numSteps) {
259     openMMIntegrator.step(numSteps);
260   }
261 
262   /**
263    * Use the Context to optimize the system to the requested tolerance.
264    *
265    * @param eps           Convergence criteria (kcal/mole/A).
266    * @param maxIterations Maximum number of iterations.
267    */
268   public void optimize(double eps, int maxIterations) {
269     // The "report" method of MinimizationReporter cannot be overridden, so the reporter does nothing.
270     MinimizationReporter reporter = new MinimizationReporter();
271     OpenMM_LocalEnergyMinimizer_minimize(getPointer(), eps / (OpenMM_NmPerAngstrom * OpenMM_KcalPerKJ),
272         maxIterations, reporter.getPointer());
273     reporter.destroy();
274   }
275 
276   /**
277    * The array x contains atomic coordinates only for active atoms.
278    *
279    * @param x Atomic coordinate array for only active atoms.
280    */
281   public void setPositions(double[] x) {
282     double[] allPositions = new double[3 * atoms.length];
283     double[] d = new double[3];
284     int index = 0;
285     for (Atom a : atoms) {
286       if (a.isActive()) {
287         a.moveTo(x[index], x[index + 1], x[index + 2]);
288       }
289       a.getXYZ(d);
290       allPositions[index] = d[0] * OpenMM_NmPerAngstrom;
291       allPositions[index + 1] = d[1] * OpenMM_NmPerAngstrom;
292       allPositions[index + 2] = d[2] * OpenMM_NmPerAngstrom;
293       index += 3;
294     }
295     super.setPositions(allPositions);
296   }
297 
298   /**
299    * The array v contains velocity values for active atomic coordinates.
300    *
301    * @param v Velocity array for active atoms.
302    */
303   public void setVelocities(double[] v) {
304     double[] allVelocities = new double[3 * atoms.length];
305     double[] velocity = new double[3];
306     int index = 0;
307     for (Atom a : atoms) {
308       if (a.isActive()) {
309         a.setVelocity(v[index], v[index + 1], v[index + 2]);
310       } else {
311         // OpenMM requires velocities for even "inactive" atoms with mass of zero.
312         a.setVelocity(0.0, 0.0, 0.0);
313       }
314       a.getVelocity(velocity);
315       allVelocities[index] = velocity[0] * OpenMM_NmPerAngstrom;
316       allVelocities[index + 1] = velocity[1] * OpenMM_NmPerAngstrom;
317       allVelocities[index + 2] = velocity[2] * OpenMM_NmPerAngstrom;
318       index += 3;
319     }
320     super.setVelocities(allVelocities);
321   }
322 
323   /**
324    * Set the periodic box vectors for a context based on the crystal instance.
325    *
326    * @param crystal The crystal instance.
327    */
328   public void setPeriodicBoxVectors(Crystal crystal) {
329     if (!crystal.aperiodic()) {
330       OpenMM_Vec3 a = new OpenMM_Vec3();
331       OpenMM_Vec3 b = new OpenMM_Vec3();
332       OpenMM_Vec3 c = new OpenMM_Vec3();
333       double[][] Ai = crystal.Ai;
334       a.x = Ai[0][0] * OpenMM_NmPerAngstrom;
335       a.y = Ai[0][1] * OpenMM_NmPerAngstrom;
336       a.z = Ai[0][2] * OpenMM_NmPerAngstrom;
337       b.x = Ai[1][0] * OpenMM_NmPerAngstrom;
338       b.y = Ai[1][1] * OpenMM_NmPerAngstrom;
339       b.z = Ai[1][2] * OpenMM_NmPerAngstrom;
340       c.x = Ai[2][0] * OpenMM_NmPerAngstrom;
341       c.y = Ai[2][1] * OpenMM_NmPerAngstrom;
342       c.z = Ai[2][2] * OpenMM_NmPerAngstrom;
343       setPeriodicBoxVectors(a, b, c);
344     }
345   }
346 
347   /**
348    * {@inheritDoc}
349    */
350   @Override
351   public String toString() {
352     return format(
353         " OpenMM context with integrator %s, timestep %9.3g fsec, temperature %9.3g K",
354         integratorName, timeStep, temperature);
355   }
356 
357   /**
358    * Load an OpenMM Platform
359    *
360    * @param requestedPlatform the requested OpenMM platform.
361    * @param forceField        the ForceField to query for platform flags.
362    * @return the loaded Platform.
363    */
364   public static Platform loadPlatform(ffx.potential.Platform requestedPlatform, ForceField forceField) {
365 
366     OpenMMUtils.init();
367 
368     logger.log(Level.INFO, " Loaded from:\n {0}", OpenMMLibrary.JNA_NATIVE_LIB.toString());
369 
370     // Print out the OpenMM Version.
371     logger.log(Level.INFO, " Version: {0}", getOpenMMVersion());
372 
373     // Print out the OpenMM lib directory.
374     String libDirectory = OpenMMUtils.getLibDirectory();
375     logger.log(Level.FINE, " Lib Directory:       {0}", libDirectory);
376     // Load platforms and print out their names.
377     StringArray libs = loadPluginsFromDirectory(libDirectory);
378     int numLibs = libs.getSize();
379     logger.log(Level.FINE, " Number of libraries: {0}", numLibs);
380     for (int i = 0; i < numLibs; i++) {
381       logger.log(Level.FINE, "  Library: {0}", libs.get(i));
382     }
383     libs.destroy();
384 
385     // Print out the OpenMM plugin directory.
386     String pluginDirectory = OpenMMUtils.getPluginDirectory();
387     logger.log(Level.INFO, "\n Plugin Directory:  {0}", pluginDirectory);
388     // Load plugins and print out their names.
389     StringArray plugins = loadPluginsFromDirectory(pluginDirectory);
390     int numPlugins = plugins.getSize();
391     logger.log(Level.INFO, " Number of Plugins: {0}", numPlugins);
392     boolean cuda = false;
393     boolean opencl = false;
394     for (int i = 0; i < numPlugins; i++) {
395       String pluginString = plugins.get(i);
396       logger.log(Level.INFO, "  Plugin: {0}", pluginString);
397       if (pluginString != null) {
398         pluginString = pluginString.toUpperCase();
399         boolean amoebaCudaAvailable = pluginString.contains("AMOEBACUDA");
400         if (amoebaCudaAvailable) {
401           cuda = true;
402         }
403         boolean amoebaOpenCLAvailable = pluginString.contains("AMOEBAOPENCL");
404         if (amoebaOpenCLAvailable) {
405           opencl = true;
406         }
407       }
408     }
409     plugins.destroy();
410 
411     int numPlatforms = getNumPlatforms();
412     logger.log(Level.INFO, " Number of Platforms: {0}", numPlatforms);
413 
414     if (requestedPlatform == OMM_CUDA && !cuda) {
415       logger.severe(" The OMM_CUDA platform was requested, but is not available.");
416     }
417 
418     if (requestedPlatform == ffx.potential.Platform.OMM_OPENCL && !opencl) {
419       logger.severe(" The OMM_OPENCL platform was requested, but is not available.");
420     }
421 
422     // Extra logging to print out plugins that failed to load.
423     if (logger.isLoggable(Level.FINE)) {
424       StringArray pluginFailures = getPluginLoadFailures();
425       int numFailures = pluginFailures.getSize();
426       for (int i = 0; i < numFailures; i++) {
427         logger.log(Level.FINE, " Plugin load failure: {0}", pluginFailures.get(i));
428       }
429       pluginFailures.destroy();
430     }
431 
432     String defaultPrecision = "mixed";
433     String precision = forceField.getString("PRECISION", defaultPrecision).toLowerCase();
434     precision = precision.replace("-precision", "");
435     switch (precision) {
436       case "double", "mixed", "single" -> logger.info(format(" Precision level: %s", precision));
437       default -> {
438         logger.info(format(" Could not interpret precision level %s, defaulting to %s", precision, defaultPrecision));
439         precision = defaultPrecision;
440       }
441     }
442 
443 
444     Platform openMMPlatform;
445     if (cuda && (requestedPlatform == OMM_CUDA || requestedPlatform == OMM)) {
446       // CUDA
447       int defaultDevice = getDefaultDevice(forceField.getProperties());
448       openMMPlatform = new Platform("CUDA");
449       // CUDA_DEVICE is deprecated; use DeviceIndex.
450       int deviceID = forceField.getInteger("CUDA_DEVICE", defaultDevice);
451       deviceID = forceField.getInteger("DeviceIndex", deviceID);
452       String deviceIDString = Integer.toString(deviceID);
453       openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
454       openMMPlatform.setPropertyDefaultValue("Precision", precision);
455       String name = openMMPlatform.getName();
456       logger.info(format(" Platform: %s (Device Index %d)", name, deviceID));
457     } else if (opencl && (requestedPlatform == OMM_OPENCL || requestedPlatform == OMM)) {
458       // OpenCL
459       int defaultDevice = getDefaultDevice(forceField.getProperties());
460       openMMPlatform = new Platform("OpenCL");
461       int deviceID = forceField.getInteger("DeviceIndex", defaultDevice);
462       String deviceIDString = Integer.toString(deviceID);
463       openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
464       int openCLPlatformIndex = forceField.getInteger("OpenCLPlatformIndex", 0);
465       String openCLPlatformIndexString = Integer.toString(openCLPlatformIndex);
466       openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
467       openMMPlatform.setPropertyDefaultValue("OpenCLPlatformIndex", openCLPlatformIndexString);
468       openMMPlatform.setPropertyDefaultValue("Precision", precision);
469       String name = openMMPlatform.getName();
470       logger.info(format(" Platform: %s (Platform Index %d, Device Index %d)",
471           name, openCLPlatformIndex, deviceID));
472     } else {
473       // Reference
474       openMMPlatform = new Platform("Reference");
475       String name = openMMPlatform.getName();
476       logger.info(format(" Platform: %s", name));
477     }
478 
479     try {
480       Comm world = Comm.world();
481       if (world != null) {
482         logger.info(format(" Running on host %s, rank %d", world.host(), world.rank()));
483       }
484     } catch (IllegalStateException illegalStateException) {
485       logger.fine(" Could not find the world communicator!");
486     }
487 
488     return openMMPlatform;
489   }
490 
491   /**
492    * Free OpenMM memory for the current Context and Integrator.
493    */
494   public void free() {
495     if (openMMIntegrator != null) {
496       openMMIntegrator.destroy();
497     }
498     destroy();
499   }
500 }