View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.potential.openmm;
39  
40  import edu.rit.mp.CharacterBuf;
41  import edu.rit.pj.Comm;
42  import ffx.crystal.Crystal;
43  import ffx.potential.FiniteDifferenceUtils;
44  import ffx.potential.ForceFieldEnergy;
45  import ffx.potential.MolecularAssembly;
46  import ffx.potential.Platform;
47  import ffx.potential.Utilities;
48  import ffx.potential.bonded.Atom;
49  import ffx.potential.parameters.ForceField;
50  import ffx.potential.utils.EnergyException;
51  import ffx.potential.utils.PotentialsFunctions;
52  import ffx.potential.utils.PotentialsUtils;
53  import org.apache.commons.configuration2.CompositeConfiguration;
54  import org.apache.commons.io.FilenameUtils;
55  
56  import javax.annotation.Nullable;
57  import java.io.File;
58  import java.io.IOException;
59  import java.time.LocalDateTime;
60  import java.time.format.DateTimeFormatter;
61  import java.util.ArrayList;
62  import java.util.Arrays;
63  import java.util.List;
64  import java.util.logging.Level;
65  import java.util.logging.Logger;
66  import java.util.stream.Collectors;
67  import java.util.stream.IntStream;
68  
69  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
70  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
71  import static java.lang.Double.isFinite;
72  import static java.lang.String.format;
73  
74  /**
75   * Compute the potential energy and derivatives using OpenMM.
76   *
77   * @author Michael J. Schnieders
78   * @since 1.0
79   */
80  public class OpenMMEnergy extends ForceFieldEnergy implements OpenMMPotential {
81  
82    private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
83  
84    /**
85     * FFX Platform.
86     */
87    private final Platform platform;
88    /**
89     * OpenMM Context.
90     */
91    private OpenMMContext openMMContext;
92    /**
93     * OpenMM System.
94     */
95    private OpenMMSystem openMMSystem;
96    /**
97     * The atoms this OpenMMEnergy operates on.
98     */
99    private final Atom[] atoms;
100   /**
101    * If true, compute dUdL.
102    */
103   private final boolean computeDEDL;
104 
105   /**
106    * ForceFieldEnergyOpenMM constructor; offloads heavy-duty computation to an OpenMM Platform while
107    * keeping track of information locally.
108    *
109    * @param molecularAssembly Assembly to construct energy for.
110    * @param requestedPlatform requested OpenMM platform to be used.
111    * @param nThreads          Number of threads to use in the super class ForceFieldEnergy instance.
112    */
113   public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
114     super(molecularAssembly, nThreads);
115 
116     Crystal crystal = getCrystal();
117     int symOps = crystal.spaceGroup.getNumberOfSymOps();
118     if (symOps > 1) {
119       logger.severe(" OpenMM does not support symmetry operators.");
120     }
121 
122     logger.info("\n Initializing OpenMM");
123 
124     ForceField forceField = molecularAssembly.getForceField();
125     atoms = molecularAssembly.getAtomArray();
126 
127     // Load the OpenMM plugins
128     this.platform = requestedPlatform;
129     ffx.openmm.Platform openMMPlatform = OpenMMContext.loadPlatform(platform, forceField);
130 
131     // Create the OpenMM System.
132     openMMSystem = new OpenMMSystem(this);
133     openMMSystem.addForces();
134 
135     // Create the Context.
136     openMMContext = new OpenMMContext(openMMPlatform, openMMSystem, atoms);
137 
138     computeDEDL = forceField.getBoolean("OMM_DUDL", false);
139   }
140 
141   /**
142    * Gets the default coprocessor device, ignoring any CUDA_DEVICE over-ride. This is either
143    * determined by process rank and the availableDevices/CUDA_DEVICES property, or just 0 if neither
144    * property is sets.
145    *
146    * @param props Properties in use.
147    * @return Pre-override device index.
148    */
149   public static int getDefaultDevice(CompositeConfiguration props) {
150     String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
151     if (availDeviceProp == null) {
152       int nDevs = props.getInt("numCudaDevices", 1);
153       availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
154           .collect(Collectors.joining(" "));
155     }
156     availDeviceProp = availDeviceProp.trim();
157 
158     String[] availDevices = availDeviceProp.split("\\s+");
159     int nDevs = availDevices.length;
160     int[] devs = new int[nDevs];
161     for (int i = 0; i < nDevs; i++) {
162       devs[i] = Integer.parseInt(availDevices[i]);
163     }
164 
165     logger.info(format(" Available devices: %d.", nDevs));
166 
167     // If only one device is available, return it.
168     if (nDevs == 1) {
169       return devs[0];
170     }
171 
172     int index = 0;
173     try {
174       Comm world = Comm.world();
175       if (world != null) {
176         int size = world.size();
177         logger.fine(format(" Number of MPI processes %d exceeds number of available devices %d.", size, nDevs));
178 
179         // Format the host as a CharacterBuf of length 100.
180         int messageLen = 100;
181         String host = world.host();
182         // Truncate to max 100 characters.
183         host = host.substring(0, Math.min(messageLen, host.length()));
184         // Pad to 100 characters.
185         host = format("%-100s", host);
186 
187         logger.fine(format(" Host: %s", host.trim()));
188         char[] messageOut = host.toCharArray();
189         CharacterBuf out = CharacterBuf.buffer(messageOut);
190 
191         // Now create CharacterBuf array for all incoming messages.
192         char[][] incoming = new char[size][messageLen];
193         CharacterBuf[] in = new CharacterBuf[size];
194         for (int i = 0; i < size; i++) {
195           in[i] = CharacterBuf.buffer(incoming[i]);
196         }
197 
198         try {
199           logger.fine(" AllGather for determining rank.");
200           world.allGather(out, in);
201           logger.fine(" AllGather complete.");
202         } catch (IOException ex) {
203           logger.warning(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
204         }
205         int ownIndex = -1;
206         int rank = world.rank();
207         boolean selfFound = false;
208 
209         for (int i = 0; i < size; i++) {
210           String hostI = new String(incoming[i]);
211           if (hostI.equalsIgnoreCase(host)) {
212             ++ownIndex;
213             if (i == rank) {
214               selfFound = true;
215               break;
216             }
217           }
218         }
219         if (!selfFound) {
220           logger.warning(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
221         } else {
222           index = ownIndex % nDevs;
223         }
224       }
225     } catch (IllegalStateException ise) {
226       // Behavior is just to keep index = 0.
227     }
228     return devs[index];
229   }
230 
231   /**
232    * Create an OpenMM Context.
233    *
234    * <p>Context.free() must be called to free OpenMM memory.
235    *
236    * @param integratorName Integrator to use.
237    * @param timeStep       Time step.
238    * @param temperature    Temperature (K).
239    * @param forceCreation  Force a new Context to be created, even if the existing one matches the
240    *                       request.
241    */
242   @Override
243   public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
244     openMMContext.update(integratorName, timeStep, temperature, forceCreation);
245   }
246 
247   /**
248    * Create an immutable OpenMM State.
249    *
250    * <p>State.free() must be called to free OpenMM memory.
251    *
252    * @param mask The State mask.
253    * @return Returns the State.
254    */
255   @Override
256   public OpenMMState getOpenMMState(int mask) {
257     return openMMContext.getOpenMMState(mask);
258   }
259 
260   /**
261    * {@inheritDoc}
262    */
263   @Override
264   public boolean destroy() {
265     boolean ffxFFEDestroy = super.destroy();
266     free();
267     logger.fine(" Destroyed the Context and OpenMMSystem.");
268     return ffxFFEDestroy;
269   }
270 
271   /**
272    * {@inheritDoc}
273    */
274   @Override
275   public double energy(double[] x) {
276     return energy(x, false);
277   }
278 
279   /**
280    * {@inheritDoc}
281    */
282   @Override
283   public double energy(double[] x, boolean verbose) {
284 
285     if (lambdaBondedTerms) {
286       return 0.0;
287     }
288 
289     // Make sure the context has been created.
290     openMMContext.update();
291 
292     updateParameters(atoms);
293 
294     // Unscale the coordinates.
295     unscaleCoordinates(x);
296 
297     setCoordinates(x);
298 
299     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
300     double e = openMMState.potentialEnergy;
301     openMMState.destroy();
302 
303     if (!isFinite(e)) {
304       String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
305       logger.warning(message);
306       throw new EnergyException(message);
307     }
308 
309     if (verbose) {
310       logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
311     }
312 
313     // Rescale the coordinates.
314     scaleCoordinates(x);
315 
316     return e;
317   }
318 
319   /**
320    * Compute the energy using the pure Java code path.
321    *
322    * @param x Atomic coordinates.
323    * @return The energy (kcal/mol)
324    */
325   public double energyFFX(double[] x) {
326     return super.energy(x, false);
327   }
328 
329   /**
330    * Compute the energy using the pure Java code path.
331    *
332    * @param x       Input atomic coordinates
333    * @param verbose Use verbose logging.
334    * @return The energy (kcal/mol)
335    */
336   public double energyFFX(double[] x, boolean verbose) {
337     return super.energy(x, verbose);
338   }
339 
340   /**
341    * {@inheritDoc}
342    */
343   @Override
344   public double energyAndGradient(double[] x, double[] g) {
345     return energyAndGradient(x, g, false);
346   }
347 
348   /**
349    * {@inheritDoc}
350    */
351   @Override
352   public double energyAndGradient(double[] x, double[] g, boolean verbose) {
353     if (lambdaBondedTerms) {
354       return 0.0;
355     }
356 
357     // Un-scale the coordinates.
358     unscaleCoordinates(x);
359 
360     // Make sure a context has been created.
361     openMMContext.update();
362 
363     // long time = -System.nanoTime();
364     setCoordinates(x);
365     // time += System.nanoTime();
366     // logger.info(format(" Load coordinates time %10.6f (sec)", time * 1.0e-9));
367 
368     // time = -System.nanoTime();
369     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
370     double e = openMMState.potentialEnergy;
371     g = openMMState.getGradient(g);
372     openMMState.destroy();
373     // time += System.nanoTime();
374     // logger.info(format(" Calculate energy time %10.6f (sec)", time * 1.0e-9));
375 
376     if (!isFinite(e)) {
377       String message = format(" Energy from OpenMM was a non-finite %8g", e);
378       logger.warning(message);
379       throw new EnergyException(message);
380     }
381 
382     // if (vdwLambdaTerm) {
383     //    PointerByReference parameterArray = OpenMM_State_getEnergyParameterDerivatives(state);
384     //    int numDerives = OpenMM_ParameterArray_getSize(parameterArray);
385     //    if (numDerives > 0) {
386     //        double vdwdUdL = OpenMM_ParameterArray_get(parameterArray,
387     // pointerForString("vdw_lambda")) / OpenMM_KJPerKcal;
388     //    }
389     // }
390 
391     if (maxDebugGradient < Double.MAX_VALUE) {
392       boolean extremeGrad = Arrays.stream(g)
393           .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
394       if (extremeGrad) {
395         File origFile = molecularAssembly.getFile();
396         String timeString = LocalDateTime.now()
397             .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
398 
399         String filename = format("%s-LARGEGRAD-%s.pdb",
400             FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
401         PotentialsFunctions ef = new PotentialsUtils();
402         filename = ef.versionFile(filename);
403 
404         logger.warning(
405             format(" Excessively large gradients detected; printing snapshot to file %s", filename));
406         ef.saveAsPDB(molecularAssembly, new File(filename));
407         molecularAssembly.setFile(origFile);
408       }
409     }
410 
411     if (verbose) {
412       logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
413     }
414 
415     // Scale the coordinates and gradients.
416     scaleCoordinatesAndGradient(x, g);
417 
418     return e;
419   }
420 
421   /**
422    * Compute the energy and gradient using the pure Java code path.
423    *
424    * @param x Input atomic coordinates
425    * @param g Storage for the gradient vector.
426    * @return The energy (kcal/mol)
427    */
428   public double energyAndGradientFFX(double[] x, double[] g) {
429     return super.energyAndGradient(x, g, false);
430   }
431 
432   /**
433    * Compute the energy and gradient using the pure Java code path.
434    *
435    * @param x       Input atomic coordinates
436    * @param g       Storage for the gradient vector.
437    * @param verbose Use verbose logging.
438    * @return The energy (kcal/mol)
439    */
440   public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
441     return super.energyAndGradient(x, g, verbose);
442   }
443 
444   /**
445    * Returns the Context instance.
446    *
447    * @return context
448    */
449   @Override
450   public OpenMMContext getContext() {
451     return openMMContext;
452   }
453 
454   /**
455    * Returns the MolecularAssembly instance.
456    *
457    * @return molecularAssembly
458    */
459   public MolecularAssembly getMolecularAssembly() {
460     return molecularAssembly;
461   }
462 
463   /**
464    * Re-compute the gradient using OpenMM and return it.
465    *
466    * @param g Gradient array.
467    */
468   @Override
469   public double[] getGradient(double[] g) {
470     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
471     g = openMMState.getGradient(g);
472     openMMState.destroy();
473     return g;
474   }
475 
476   /**
477    * {@inheritDoc}
478    */
479   @Override
480   public Platform getPlatform() {
481     return platform;
482   }
483 
484   /**
485    * Get a reference to the System instance.
486    *
487    * @return Java wrapper to an OpenMM system.
488    */
489   @Override
490   public OpenMMSystem getSystem() {
491     return openMMSystem;
492   }
493 
494   /**
495    * {@inheritDoc}
496    */
497   @Override
498   public double getd2EdL2() {
499     return 0.0;
500   }
501 
502   /**
503    * {@inheritDoc}
504    */
505   @Override
506   public double getdEdL() {
507     // No lambda dependence.
508     if (!lambdaTerm || !computeDEDL) {
509       return 0.0;
510     }
511 
512     return FiniteDifferenceUtils.computedEdL(this, this, molecularAssembly.getForceField());
513   }
514 
515   /**
516    * {@inheritDoc}
517    */
518   @Override
519   public void getdEdXdL(double[] gradients) {
520     // Note for ForceFieldEnergyOpenMM this method is not implemented.
521   }
522 
523   /**
524    * Update active atoms.
525    */
526   @Override
527   public boolean setActiveAtoms() {
528     return openMMSystem.updateAtomMass();
529   }
530 
531   /**
532    * Coordinates for active atoms in units of Angstroms.
533    *
534    * @param x Atomic coordinates active atoms.
535    */
536   @Override
537   public void setCoordinates(double[] x) {
538     // Load the coordinates for active atoms.
539     super.setCoordinates(x);
540 
541     int n = atoms.length * 3;
542     double[] xall = new double[n];
543     int i = 0;
544     for (Atom atom : atoms) {
545       xall[i] = atom.getX();
546       xall[i + 1] = atom.getY();
547       xall[i + 2] = atom.getZ();
548       i += 3;
549     }
550 
551     // Load OpenMM coordinates for all atoms.
552     openMMContext.setPositions(xall);
553   }
554 
555   /**
556    * Velocities for active atoms in units of Angstroms.
557    *
558    * @param v Velocities for active atoms.
559    */
560   @Override
561   public void setVelocity(double[] v) {
562     // Load the velocity for active atoms.
563     super.setVelocity(v);
564 
565     int n = atoms.length * 3;
566     double[] vall = new double[n];
567     double[] v3 = new double[3];
568     int i = 0;
569     for (Atom atom : atoms) {
570       atom.getVelocity(v3);
571       vall[i] = v3[0];
572       vall[i + 1] = v3[1];
573       vall[i + 2] = v3[2];
574       i += 3;
575     }
576 
577     // Load OpenMM velocities for all atoms.
578     openMMContext.setVelocities(vall);
579   }
580 
581   /**
582    * {@inheritDoc}
583    */
584   @Override
585   public void setCrystal(Crystal crystal) {
586     super.setCrystal(crystal);
587     openMMContext.setPeriodicBoxVectors(crystal);
588   }
589 
590   /**
591    * {@inheritDoc}
592    */
593   @Override
594   public void setLambda(double lambda) {
595     if (!lambdaTerm) {
596       logger.fine(" Attempting to set lambda for an OpenMMEnergy with lambdaterm false.");
597       return;
598     }
599 
600     super.setLambda(lambda);
601 
602     if (atoms != null) {
603       List<Atom> atomList = new ArrayList<>();
604       for (Atom atom : atoms) {
605         if (atom.applyLambda()) {
606           atomList.add(atom);
607         }
608       }
609       // Update force field parameters based on defined lambda values.
610       updateParameters(atomList.toArray(new Atom[0]));
611     } else {
612       updateParameters(null);
613     }
614 
615   }
616 
617   /**
618    * Update parameters if the Use flags and/or Lambda value has changed.
619    *
620    * @param atoms Atoms in this list are considered.
621    */
622   @Override
623   public void updateParameters(@Nullable Atom[] atoms) {
624     if (atoms == null) {
625       atoms = this.atoms;
626     }
627     if (openMMSystem != null) {
628       openMMSystem.updateParameters(atoms);
629     }
630   }
631 
632   /**
633    * Free OpenMM memory for the Context and System.
634    */
635   private void free() {
636     if (openMMContext != null) {
637       openMMContext.free();
638       openMMContext = null;
639     }
640     if (openMMSystem != null) {
641       openMMSystem.free();
642       openMMSystem = null;
643     }
644   }
645 
646 }