View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2024.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.potential.openmm;
39  
import edu.rit.pj.Comm;
import edu.uiowa.jopenmm.OpenMMLibrary;
import edu.uiowa.jopenmm.OpenMMUtils;
import edu.uiowa.jopenmm.OpenMM_Vec3;
import ffx.crystal.Crystal;
import ffx.openmm.Context;
import ffx.openmm.Integrator;
import ffx.openmm.Platform;
import ffx.openmm.State;
import ffx.openmm.StringArray;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.Atom;

import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;

import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_KcalPerKJ;
import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_NmPerAngstrom;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_LocalEnergyMinimizer_minimize;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
import static ffx.openmm.Platform.getNumPlatforms;
import static ffx.openmm.Platform.getOpenMMVersion;
import static ffx.openmm.Platform.getPluginLoadFailures;
import static ffx.openmm.Platform.loadPluginsFromDirectory;
import static ffx.potential.ForceFieldEnergy.DEFAULT_CONSTRAINT_TOLERANCE;
import static ffx.potential.openmm.OpenMMEnergy.getDefaultDevice;
import static ffx.potential.openmm.OpenMMIntegrator.createIntegrator;
import static java.lang.String.format;
69  
70  /**
71   * Creates and manage an OpenMM Context.
72   *
73   * <p>A Context stores the complete state of a simulation. More specifically, it includes: The
74   * current time The position of each particle The velocity of each particle The values of
75   * configurable parameters defined by Force objects in the System
76   *
77   * <p>You can retrieve a snapshot of the current state at any time by calling getState(). This
78   * allows you to record the state of the simulation at various points, either for analysis or for
79   * checkpointing. getState() can also be used to retrieve the current forces on each particle and
80   * the current energy of the System.
81   */
82  public class OpenMMContext extends Context {
83  
84    private static final Logger logger = Logger.getLogger(OpenMMContext.class.getName());
85  
86    /**
87     * Requested Platform (i.e. Java or an OpenMM platform).
88     */
89    private final ffx.potential.Platform platform;
90    /**
91     * OpenMM Platform pointer.
92     */
93    private Platform openMMPlatform = null;
94    /**
95     * ForceFieldEnergyOpenMM instance.
96     */
97    private final OpenMMEnergy openMMEnergy;
98    /**
99     * Instance of the OpenMM Integrator class.
100    */
101   private Integrator openMMIntegrator;
102   /**
103    * Integrator string (default = VERLET).
104    */
105   private String integratorName = "VERLET";
106   /**
107    * Time step (default = 0.001 psec).
108    */
109   private double timeStep = 0.001;
110   /**
111    * Temperature (default = 298.15).
112    */
113   private double temperature = 298.15;
114   /**
115    *
116    */
117   private final int enforcePBC;
118   /**
119    * Array of atoms.
120    */
121   private final Atom[] atoms;
122 
123   /**
124    * Create an OpenMM Context.
125    *
126    * @param requestedPlatform Platform requested.
127    * @param atoms             Array of atoms.
128    * @param enforcePBC        Enforce periodic boundary conditions.
129    * @param openMMEnergy      ForceFieldEnergyOpenMM instance.
130    */
131   public OpenMMContext(ffx.potential.Platform requestedPlatform, Atom[] atoms, int enforcePBC,
132                        OpenMMEnergy openMMEnergy) {
133     platform = requestedPlatform;
134     this.atoms = atoms;
135     this.enforcePBC = enforcePBC;
136     this.openMMEnergy = openMMEnergy;
137     loadPlatform(requestedPlatform);
138   }
139 
140   /**
141    * Update the Context in which to run a simulation.
142    *
143    * @param integratorName Requested integrator.
144    * @param timeStep       Time step (psec).
145    * @param temperature    Temperature (K).
146    * @param forceCreation  Force creation of a new context, even if the current one matches.
147    * @param openMMEnergy   ForceFieldEnergyOpenMM instance.
148    */
149   public void update(String integratorName, double timeStep, double temperature,
150                      boolean forceCreation, OpenMMEnergy openMMEnergy) {
151     // Check if the current context is consistent with the requested context.
152     if (hasContextPointer() && !forceCreation) {
153       if (this.temperature == temperature && this.timeStep == timeStep
154           && this.integratorName.equalsIgnoreCase(integratorName)) {
155         // All requested features agree.
156         return;
157       }
158     }
159 
160     this.integratorName = integratorName;
161     this.timeStep = timeStep;
162     this.temperature = temperature;
163 
164     logger.info("\n Updating OpenMM Context");
165 
166     // Update the integrator.
167     OpenMMSystem openMMSystem = openMMEnergy.getSystem();
168     if (openMMIntegrator != null) {
169       openMMIntegrator.destroy();
170     }
171     openMMIntegrator = createIntegrator(integratorName, timeStep, temperature, openMMSystem);
172 
173     // Set lambda to 1.0 when creating a context to avoid OpenMM compiling out any terms.
174     double currentLambda = openMMEnergy.getLambda();
175 
176     if (openMMEnergy.getLambdaTerm()) {
177       openMMEnergy.setLambda(1.0);
178     }
179 
180     // Update the context.
181     updateContext(openMMSystem, openMMIntegrator, openMMPlatform);
182 
183     // Revert to the current lambda value.
184     if (openMMEnergy.getLambdaTerm()) {
185       openMMEnergy.setLambda(currentLambda);
186     }
187 
188     // Get initial positions and velocities for active atoms.
189     int nVar = openMMEnergy.getNumberOfVariables();
190     double[] x = new double[nVar];
191     double[] v = new double[nVar];
192     double[] vel3 = new double[3];
193     int index = 0;
194     for (Atom a : atoms) {
195       if (a.isActive()) {
196         a.getVelocity(vel3);
197         // X-axis
198         x[index] = a.getX();
199         v[index++] = vel3[0];
200         // Y-axis
201         x[index] = a.getY();
202         v[index++] = vel3[1];
203         // Z-axis
204         x[index] = a.getZ();
205         v[index++] = vel3[2];
206       }
207     }
208 
209     // Load the current periodic box vectors.
210     Crystal crystal = openMMEnergy.getCrystal();
211     setPeriodicBoxVectors(crystal);
212 
213     // Load current atomic positions.
214     setPositions(x);
215 
216     // Load current velocities.
217     setVelocities(v);
218 
219     // Apply constraints starting from current atomic positions.
220     applyConstraints(DEFAULT_CONSTRAINT_TOLERANCE);
221 
222     // Application of constraints can change coordinates and velocities.
223     // Retrieve them for consistency.
224     OpenMMState openMMState = getOpenMMState(OpenMM_State_Positions | OpenMM_State_Velocities);
225     openMMState.getPositions(x);
226     openMMState.getVelocities(v);
227     openMMState.destroy();
228   }
229 
230   /**
231    * Update the Context if necessary.
232    */
233   public void update() {
234     if (!hasContextPointer()) {
235       openMMEnergy.updateContext(integratorName, timeStep, temperature, true);
236     }
237   }
238 
239   /**
240    * Get an OpenMM State from the Context.
241    *
242    * @param mask A mask specifying which information to retrieve.
243    * @return State pointer.
244    */
245   public OpenMMState getOpenMMState(int mask) {
246     State state = getState(mask, enforcePBC);
247     return new OpenMMState(state.getPointer(), atoms, openMMEnergy.getNumberOfVariables());
248   }
249 
250   public ffx.potential.Platform getPlatform() {
251     return platform;
252   }
253 
254   /**
255    * Use the Context / Integrator combination to take the requested number of steps.
256    *
257    * @param numSteps Number of steps to take.
258    */
259   public void integrate(int numSteps) {
260     openMMIntegrator.step(numSteps);
261   }
262 
263   /**
264    * Use the Context to optimize the system to the requested tolerance.
265    *
266    * @param eps           Convergence criteria (kcal/mole/A).
267    * @param maxIterations Maximum number of iterations.
268    */
269   public void optimize(double eps, int maxIterations) {
270     OpenMM_LocalEnergyMinimizer_minimize(getPointer(),
271         eps / (OpenMM_NmPerAngstrom * OpenMM_KcalPerKJ), maxIterations);
272   }
273 
274   /**
275    * The array x contains atomic coordinates only for active atoms.
276    *
277    * @param x Atomic coordinate array for only active atoms.
278    */
279   public void setPositions(double[] x) {
280     double[] allPositions = new double[3 * atoms.length];
281     double[] d = new double[3];
282     int index = 0;
283     for (Atom a : atoms) {
284       if (a.isActive()) {
285         a.moveTo(x[index], x[index + 1], x[index + 2]);
286       }
287       a.getXYZ(d);
288       allPositions[index] = d[0] * OpenMM_NmPerAngstrom;
289       allPositions[index + 1] = d[1] * OpenMM_NmPerAngstrom;
290       allPositions[index + 2] = d[2] * OpenMM_NmPerAngstrom;
291       index += 3;
292     }
293     super.setPositions(allPositions);
294   }
295 
296   /**
297    * The array v contains velocity values for active atomic coordinates.
298    *
299    * @param v Velocity array for active atoms.
300    */
301   public void setVelocities(double[] v) {
302     double[] allVelocities = new double[3 * atoms.length];
303     double[] velocity = new double[3];
304     int index = 0;
305     for (Atom a : atoms) {
306       if (a.isActive()) {
307         a.setVelocity(v[index], v[index + 1], v[index + 2]);
308       } else {
309         // OpenMM requires velocities for even "inactive" atoms with mass of zero.
310         a.setVelocity(0.0, 0.0, 0.0);
311       }
312       a.getVelocity(velocity);
313       allVelocities[index] = velocity[0] * OpenMM_NmPerAngstrom;
314       allVelocities[index + 1] = velocity[1] * OpenMM_NmPerAngstrom;
315       allVelocities[index + 2] = velocity[2] * OpenMM_NmPerAngstrom;
316       index += 3;
317     }
318     super.setVelocities(allVelocities);
319   }
320 
321   /**
322    * Set the periodic box vectors for a context based on the crystal instance.
323    *
324    * @param crystal The crystal instance.
325    */
326   public void setPeriodicBoxVectors(Crystal crystal) {
327     if (!crystal.aperiodic()) {
328       OpenMM_Vec3 a = new OpenMM_Vec3();
329       OpenMM_Vec3 b = new OpenMM_Vec3();
330       OpenMM_Vec3 c = new OpenMM_Vec3();
331       double[][] Ai = crystal.Ai;
332       a.x = Ai[0][0] * OpenMM_NmPerAngstrom;
333       a.y = Ai[0][1] * OpenMM_NmPerAngstrom;
334       a.z = Ai[0][2] * OpenMM_NmPerAngstrom;
335       b.x = Ai[1][0] * OpenMM_NmPerAngstrom;
336       b.y = Ai[1][1] * OpenMM_NmPerAngstrom;
337       b.z = Ai[1][2] * OpenMM_NmPerAngstrom;
338       c.x = Ai[2][0] * OpenMM_NmPerAngstrom;
339       c.y = Ai[2][1] * OpenMM_NmPerAngstrom;
340       c.z = Ai[2][2] * OpenMM_NmPerAngstrom;
341       setPeriodicBoxVectors(a, b, c);
342     }
343   }
344 
345   /**
346    * {@inheritDoc}
347    */
348   @Override
349   public String toString() {
350     return format(
351         " OpenMM context with integrator %s, timestep %9.3g fsec, temperature %9.3g K",
352         integratorName, timeStep, temperature);
353   }
354 
355   /**
356    * Load an OpenMM Platform
357    */
358   private void loadPlatform(ffx.potential.Platform requestedPlatform) {
359 
360     OpenMMUtils.init();
361 
362     logger.log(Level.INFO, " Loaded from:\n {0}", OpenMMLibrary.JNA_NATIVE_LIB.toString());
363 
364     // Print out the OpenMM Version.
365     logger.log(Level.INFO, " Version: {0}", getOpenMMVersion());
366 
367     // Print out the OpenMM lib directory.
368     String libDirectory = OpenMMUtils.getLibDirectory();
369     logger.log(Level.FINE, " Lib Directory:       {0}", libDirectory);
370     // Load platforms and print out their names.
371     StringArray libs = loadPluginsFromDirectory(libDirectory);
372     int numLibs = libs.getSize();
373     logger.log(Level.FINE, " Number of libraries: {0}", numLibs);
374     for (int i = 0; i < numLibs; i++) {
375       logger.log(Level.FINE, "  Library: {0}", libs.get(i));
376     }
377     libs.destroy();
378 
379     // Print out the OpenMM plugin directory.
380     String pluginDirectory = OpenMMUtils.getPluginDirectory();
381     logger.log(Level.INFO, "\n Plugin Directory:  {0}", pluginDirectory);
382     // Load plugins and print out their names.
383     StringArray plugins = loadPluginsFromDirectory(pluginDirectory);
384     int numPlugins = plugins.getSize();
385     logger.log(Level.INFO, " Number of Plugins: {0}", numPlugins);
386     boolean cuda = false;
387     for (int i = 0; i < numPlugins; i++) {
388       String pluginString = plugins.get(i);
389       logger.log(Level.INFO, "  Plugin: {0}", pluginString);
390       if (pluginString != null) {
391         pluginString = pluginString.toUpperCase();
392         boolean amoebaCudaAvailable = pluginString.contains("AMOEBACUDA");
393         if (amoebaCudaAvailable) {
394           cuda = true;
395         }
396       }
397     }
398     plugins.destroy();
399 
400     int numPlatforms = getNumPlatforms();
401     logger.log(Level.INFO, " Number of Platforms: {0}", numPlatforms);
402 
403     if (requestedPlatform == ffx.potential.Platform.OMM_CUDA && !cuda) {
404       logger.severe(" The OMM_CUDA platform was requested, but is not available.");
405     }
406 
407     // Extra logging to print out plugins that failed to load.
408     if (logger.isLoggable(Level.FINE)) {
409       StringArray pluginFailures = getPluginLoadFailures();
410       int numFailures = pluginFailures.getSize();
411       for (int i = 0; i < numFailures; i++) {
412         logger.log(Level.FINE, " Plugin load failure: {0}", pluginFailures.get(i));
413       }
414       pluginFailures.destroy();
415     }
416 
417     String defaultPrecision = "mixed";
418     MolecularAssembly molecularAssembly = openMMEnergy.getMolecularAssembly();
419     String precision = molecularAssembly.getForceField().getString("PRECISION", defaultPrecision).toLowerCase();
420     precision = precision.replace("-precision", "");
421     switch (precision) {
422       case "double", "mixed", "single" -> logger.info(format(" Precision level: %s", precision));
423       default -> {
424         logger.info(format(" Could not interpret precision level %s, defaulting to %s", precision, defaultPrecision));
425         precision = defaultPrecision;
426       }
427     }
428 
429     if (cuda && requestedPlatform != ffx.potential.Platform.OMM_REF) {
430       int defaultDevice = getDefaultDevice(molecularAssembly.getProperties());
431       openMMPlatform = new Platform("CUDA");
432       int deviceID = molecularAssembly.getForceField().getInteger("CUDA_DEVICE", defaultDevice);
433       String deviceIDString = Integer.toString(deviceID);
434 
435       openMMPlatform.setPropertyDefaultValue("CudaDeviceIndex", deviceIDString);
436       openMMPlatform.setPropertyDefaultValue("Precision", precision);
437       logger.info(format(" Platform: AMOEBA CUDA (Device ID %d)", deviceID));
438       try {
439         Comm world = Comm.world();
440         if (world != null) {
441           logger.info(format(" Running on host %s, rank %d", world.host(), world.rank()));
442         }
443       } catch (IllegalStateException illegalStateException) {
444         logger.fine(" Could not find the world communicator!");
445       }
446     } else {
447       openMMPlatform = new Platform("Reference");
448       logger.info(" Platform: AMOEBA CPU Reference");
449     }
450   }
451 
452   /**
453    * Free OpenMM memory for the current Context and Integrator.
454    */
455   public void free() {
456     if (openMMIntegrator != null) {
457       openMMIntegrator.destroy();
458     }
459     destroy();
460   }
461 }