View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.potential.openmm;
39  
40  import edu.rit.mp.CharacterBuf;
41  import edu.rit.pj.Comm;
42  import ffx.crystal.Crystal;
43  import ffx.potential.FiniteDifferenceUtils;
44  import ffx.potential.ForceFieldEnergy;
45  import ffx.potential.MolecularAssembly;
46  import ffx.potential.Platform;
47  import ffx.potential.Utilities;
48  import ffx.potential.bonded.Atom;
49  import ffx.potential.parameters.ForceField;
50  import ffx.potential.utils.EnergyException;
51  import ffx.potential.utils.PotentialsFunctions;
52  import ffx.potential.utils.PotentialsUtils;
53  import org.apache.commons.configuration2.CompositeConfiguration;
54  import org.apache.commons.io.FilenameUtils;
55  
56  import javax.annotation.Nullable;
57  import java.io.File;
58  import java.io.IOException;
59  import java.time.LocalDateTime;
60  import java.time.format.DateTimeFormatter;
61  import java.util.ArrayList;
62  import java.util.Arrays;
63  import java.util.List;
64  import java.util.logging.Level;
65  import java.util.logging.Logger;
66  import java.util.stream.Collectors;
67  import java.util.stream.IntStream;
68  
69  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
70  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
71  import static java.lang.Double.isFinite;
72  import static java.lang.String.format;
73  
74  /**
75   * Compute the potential energy and derivatives using OpenMM.
76   *
77   * @author Michael J. Schnieders
78   * @since 1.0
79   */
80  @SuppressWarnings("deprecation")
81  public class OpenMMEnergy extends ForceFieldEnergy {
82  
83    private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
84  
85    /**
86     * FFX Platform.
87     */
88    private final Platform platform;
89    /**
90     * OpenMM Context.
91     */
92    private OpenMMContext openMMContext;
93    /**
94     * OpenMM System.
95     */
96    private OpenMMSystem openMMSystem;
97    /**
98     * The atoms this ForceFieldEnergyOpenMM operates on.
99     */
100   private final Atom[] atoms;
101   /**
102    * If true, compute dUdL.
103    */
104   private final boolean computeDEDL;
105 
106   /**
107    * ForceFieldEnergyOpenMM constructor; offloads heavy-duty computation to an OpenMM Platform while
108    * keeping track of information locally.
109    *
110    * @param molecularAssembly Assembly to construct energy for.
111    * @param requestedPlatform requested OpenMM platform to be used.
112    * @param nThreads          Number of threads to use in the super class ForceFieldEnergy instance.
113    */
114   public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
115     super(molecularAssembly, nThreads);
116 
117     Crystal crystal = getCrystal();
118     int symOps = crystal.spaceGroup.getNumberOfSymOps();
119     if (symOps > 1) {
120       logger.info("");
121       logger.severe(" OpenMM does not support symmetry operators.");
122     }
123 
124     logger.info("\n Initializing OpenMM");
125 
126     ForceField forceField = molecularAssembly.getForceField();
127     atoms = molecularAssembly.getAtomArray();
128 
129     // Load the OpenMM plugins
130     this.platform = requestedPlatform;
131     ffx.openmm.Platform openMMPlatform = OpenMMContext.loadPlatform(platform, forceField);
132 
133     // Create the OpenMM System.
134     openMMSystem = new OpenMMSystem(this);
135     openMMSystem.addForces();
136 
137     // Create the Context.
138     openMMContext = new OpenMMContext(openMMPlatform, openMMSystem, atoms);
139 
140     computeDEDL = forceField.getBoolean("OMM_DUDL", false);
141   }
142 
143   /**
144    * Gets the default coprocessor device, ignoring any CUDA_DEVICE over-ride. This is either
145    * determined by process rank and the availableDevices/CUDA_DEVICES property, or just 0 if neither
146    * property is sets.
147    *
148    * @param props Properties in use.
149    * @return Pre-override device index.
150    */
151   public static int getDefaultDevice(CompositeConfiguration props) {
152     String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
153     if (availDeviceProp == null) {
154       int nDevs = props.getInt("numCudaDevices", 1);
155       availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
156           .collect(Collectors.joining(" "));
157     }
158     availDeviceProp = availDeviceProp.trim();
159 
160     String[] availDevices = availDeviceProp.split("\\s+");
161     int nDevs = availDevices.length;
162     int[] devs = new int[nDevs];
163     for (int i = 0; i < nDevs; i++) {
164       devs[i] = Integer.parseInt(availDevices[i]);
165     }
166 
167     logger.info(format(" Available devices: %d.", nDevs));
168 
169     // If only one device is available, return it.
170     if (nDevs == 1) {
171       return devs[0];
172     }
173 
174     int index = 0;
175     try {
176       Comm world = Comm.world();
177       if (world != null) {
178         int size = world.size();
179         logger.fine(format(" Number of MPI processes %d exceeds number of available devices %d.", size, nDevs));
180 
181         // Format the host as a CharacterBuf of length 100.
182         int messageLen = 100;
183         String host = world.host();
184         // Truncate to max 100 characters.
185         host = host.substring(0, Math.min(messageLen, host.length()));
186         // Pad to 100 characters.
187         host = format("%-100s", host);
188 
189         logger.fine(format(" Host: %s", host.trim()));
190         char[] messageOut = host.toCharArray();
191         CharacterBuf out = CharacterBuf.buffer(messageOut);
192 
193         // Now create CharacterBuf array for all incoming messages.
194         char[][] incoming = new char[size][messageLen];
195         CharacterBuf[] in = new CharacterBuf[size];
196         for (int i = 0; i < size; i++) {
197           in[i] = CharacterBuf.buffer(incoming[i]);
198         }
199 
200         try {
201           logger.fine(" AllGather for determining rank.");
202           world.allGather(out, in);
203           logger.fine(" AllGather complete.");
204         } catch (IOException ex) {
205           logger.warning(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
206         }
207         int ownIndex = -1;
208         int rank = world.rank();
209         boolean selfFound = false;
210 
211         for (int i = 0; i < size; i++) {
212           String hostI = new String(incoming[i]);
213           if (hostI.equalsIgnoreCase(host)) {
214             ++ownIndex;
215             if (i == rank) {
216               selfFound = true;
217               break;
218             }
219           }
220         }
221         if (!selfFound) {
222           logger.warning(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
223         } else {
224           index = ownIndex % nDevs;
225         }
226       }
227     } catch (IllegalStateException ise) {
228       // Behavior is just to keep index = 0.
229     }
230     return devs[index];
231   }
232 
233   /**
234    * Create an OpenMM Context.
235    *
236    * <p>Context.free() must be called to free OpenMM memory.
237    *
238    * @param integratorName Integrator to use.
239    * @param timeStep       Time step.
240    * @param temperature    Temperature (K).
241    * @param forceCreation  Force a new Context to be created, even if the existing one matches the
242    *                       request.
243    */
244   public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
245     openMMContext.update(integratorName, timeStep, temperature, forceCreation);
246   }
247 
248   /**
249    * Create an immutable OpenMM State.
250    *
251    * <p>State.free() must be called to free OpenMM memory.
252    *
253    * @param mask The State mask.
254    * @return Returns the State.
255    */
256   public OpenMMState getOpenMMState(int mask) {
257     return openMMContext.getOpenMMState(mask);
258   }
259 
260   /**
261    * {@inheritDoc}
262    */
263   @Override
264   public boolean destroy() {
265     boolean ffxFFEDestroy = super.destroy();
266     free();
267     logger.fine(" Destroyed the Context, Integrator, and OpenMMSystem.");
268     return ffxFFEDestroy;
269   }
270 
271   /**
272    * {@inheritDoc}
273    */
274   @Override
275   public double energy(double[] x) {
276     return energy(x, false);
277   }
278 
279   /**
280    * {@inheritDoc}
281    */
282   @Override
283   public double energy(double[] x, boolean verbose) {
284 
285     if (lambdaBondedTerms) {
286       return 0.0;
287     }
288 
289     // Make sure the context has been created.
290     openMMContext.update();
291 
292     updateParameters(atoms);
293 
294     // Unscale the coordinates.
295     unscaleCoordinates(x);
296 
297     setCoordinates(x);
298 
299     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
300     double e = openMMState.potentialEnergy;
301     openMMState.destroy();
302 
303     if (!isFinite(e)) {
304       String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
305       logger.warning(message);
306       throw new EnergyException(message);
307     }
308 
309     if (verbose) {
310       logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
311     }
312 
313     // Rescale the coordinates.
314     scaleCoordinates(x);
315 
316     return e;
317   }
318 
319   /**
320    * {@inheritDoc}
321    */
322   @Override
323   public double energyAndGradient(double[] x, double[] g) {
324     return energyAndGradient(x, g, false);
325   }
326 
327   /**
328    * {@inheritDoc}
329    */
330   @Override
331   public double energyAndGradient(double[] x, double[] g, boolean verbose) {
332     if (lambdaBondedTerms) {
333       return 0.0;
334     }
335 
336     // Un-scale the coordinates.
337     unscaleCoordinates(x);
338 
339     // Make sure a context has been created.
340     openMMContext.update();
341 
342     // long time = -System.nanoTime();
343     setCoordinates(x);
344     // time += System.nanoTime();
345     // logger.info(format(" Load coordinates time %10.6f (sec)", time * 1.0e-9));
346 
347     // time = -System.nanoTime();
348     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
349     double e = openMMState.potentialEnergy;
350     g = openMMState.getGradient(g);
351     openMMState.destroy();
352     // time += System.nanoTime();
353     // logger.info(format(" Calculate energy time %10.6f (sec)", time * 1.0e-9));
354 
355     if (!isFinite(e)) {
356       String message = format(" Energy from OpenMM was a non-finite %8g", e);
357       logger.warning(message);
358       throw new EnergyException(message);
359     }
360 
361     // if (vdwLambdaTerm) {
362     //    PointerByReference parameterArray = OpenMM_State_getEnergyParameterDerivatives(state);
363     //    int numDerives = OpenMM_ParameterArray_getSize(parameterArray);
364     //    if (numDerives > 0) {
365     //        double vdwdUdL = OpenMM_ParameterArray_get(parameterArray,
366     // pointerForString("vdw_lambda")) / OpenMM_KJPerKcal;
367     //    }
368     // }
369 
370     if (maxDebugGradient < Double.MAX_VALUE) {
371       boolean extremeGrad = Arrays.stream(g)
372           .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
373       if (extremeGrad) {
374         File origFile = molecularAssembly.getFile();
375         String timeString = LocalDateTime.now()
376             .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
377 
378         String filename = format("%s-LARGEGRAD-%s.pdb",
379             FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
380         PotentialsFunctions ef = new PotentialsUtils();
381         filename = ef.versionFile(filename);
382 
383         logger.warning(
384             format(" Excessively large gradients detected; printing snapshot to file %s", filename));
385         ef.saveAsPDB(molecularAssembly, new File(filename));
386         molecularAssembly.setFile(origFile);
387       }
388     }
389 
390     if (verbose) {
391       logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
392     }
393 
394     // Scale the coordinates and gradients.
395     scaleCoordinatesAndGradient(x, g);
396 
397     return e;
398   }
399 
400   /**
401    * Compute the energy and gradient using the pure Java code path.
402    *
403    * @param x Input atomic coordinates
404    * @param g Storage for the gradient vector.
405    * @return The energy (kcal/mol)
406    */
407   public double energyAndGradientFFX(double[] x, double[] g) {
408     return super.energyAndGradient(x, g, false);
409   }
410 
411   /**
412    * Compute the energy and gradient using the pure Java code path.
413    *
414    * @param x       Input atomic coordinates
415    * @param g       Storage for the gradient vector.
416    * @param verbose Use verbose logging.
417    * @return The energy (kcal/mol)
418    */
419   public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
420     return super.energyAndGradient(x, g, verbose);
421   }
422 
423   /**
424    * Compute the energy using the pure Java code path.
425    *
426    * @param x Atomic coordinates.
427    * @return The energy (kcal/mol)
428    */
429   public double energyFFX(double[] x) {
430     return super.energy(x, false);
431   }
432 
433   /**
434    * Compute the energy using the pure Java code path.
435    *
436    * @param x       Input atomic coordinates
437    * @param verbose Use verbose logging.
438    * @return The energy (kcal/mol)
439    */
440   public double energyFFX(double[] x, boolean verbose) {
441     return super.energy(x, verbose);
442   }
443 
444   /**
445    * Returns the Context instance.
446    *
447    * @return context
448    */
449   public OpenMMContext getContext() {
450     return openMMContext;
451   }
452 
453   /**
454    * Returns the MolecularAssembly instance.
455    *
456    * @return molecularAssembly
457    */
458   public MolecularAssembly getMolecularAssembly() {
459     return molecularAssembly;
460   }
461 
462   /**
463    * Set the lambdaTerm flag.
464    *
465    * @param lambdaTerm The value to set.
466    */
467   public void setLambdaTerm(boolean lambdaTerm) {
468     this.lambdaTerm = lambdaTerm;
469   }
470 
471   /**
472    * Get the lambdaTerm flag.
473    *
474    * @return lambdaTerm.
475    */
476   public boolean getLambdaTerm() {
477     return lambdaTerm;
478   }
479 
480   /**
481    * Re-compute the gradient using OpenMM and return it.
482    *
483    * @param g Gradient array.
484    */
485   @Override
486   public double[] getGradient(double[] g) {
487     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
488     g = openMMState.getGradient(g);
489     openMMState.destroy();
490     return g;
491   }
492 
493   /**
494    * {@inheritDoc}
495    */
496   @Override
497   public Platform getPlatform() {
498     return platform;
499   }
500 
501   /**
502    * Get a reference to the System instance.
503    *
504    * @return Java wrapper to an OpenMM system.
505    */
506   public OpenMMSystem getSystem() {
507     return openMMSystem;
508   }
509 
510   /**
511    * {@inheritDoc}
512    */
513   @Override
514   public double getd2EdL2() {
515     return 0.0;
516   }
517 
518   /**
519    * {@inheritDoc}
520    */
521   @Override
522   public double getdEdL() {
523     // No lambda dependence.
524     if (!lambdaTerm || !computeDEDL) {
525       return 0.0;
526     }
527 
528     return FiniteDifferenceUtils.computedEdL(this, this, molecularAssembly.getForceField());
529   }
530 
531   /**
532    * {@inheritDoc}
533    */
534   @Override
535   public void getdEdXdL(double[] gradients) {
536     // Note for ForceFieldEnergyOpenMM this method is not implemented.
537   }
538 
539   /**
540    * Update active atoms.
541    */
542   public void setActiveAtoms() {
543     openMMSystem.updateAtomMass();
544     // Tests show reinitialization of the OpenMM Context is not necessary to pick up mass changes.
545     // context.reinitContext();
546   }
547 
548   /**
549    * Set FFX and OpenMM coordinates for active atoms.
550    *
551    * @param x Atomic coordinates.
552    */
553   @Override
554   public void setCoordinates(double[] x) {
555     // Set both OpenMM and FFX coordinates to x.
556     openMMContext.setPositions(x);
557   }
558 
559   /**
560    * {@inheritDoc}
561    */
562   @Override
563   public void setCrystal(Crystal crystal) {
564     super.setCrystal(crystal);
565     openMMContext.setPeriodicBoxVectors(crystal);
566   }
567 
568   /**
569    * {@inheritDoc}
570    */
571   @Override
572   public void setLambda(double lambda) {
573     if (!lambdaTerm) {
574       logger.fine(" Attempting to set lambda for an OpenMMEnergy with lambdaterm false.");
575       return;
576     }
577 
578     super.setLambda(lambda);
579 
580     if (atoms != null) {
581       List<Atom> atomList = new ArrayList<>();
582       for (Atom atom : atoms) {
583         if (atom.applyLambda()) {
584           atomList.add(atom);
585         }
586       }
587       // Update force field parameters based on defined lambda values.
588       updateParameters(atomList.toArray(new Atom[0]));
589     } else {
590       updateParameters(null);
591     }
592 
593   }
594 
595   /**
596    * Update parameters if the Use flags and/or Lambda value has changed.
597    *
598    * @param atoms Atoms in this list are considered.
599    */
600   public void updateParameters(@Nullable Atom[] atoms) {
601     if (atoms == null) {
602       atoms = this.atoms;
603     }
604     if (openMMSystem != null) {
605       openMMSystem.updateParameters(atoms);
606     }
607   }
608 
609   /**
610    * Free OpenMM memory for the Context, Integrator and System.
611    */
612   private void free() {
613     if (openMMContext != null) {
614       openMMContext.free();
615       openMMContext = null;
616     }
617     if (openMMSystem != null) {
618       openMMSystem.free();
619       openMMSystem = null;
620     }
621   }
622 
623 }