View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.potential.openmm;
39  
40  import edu.rit.mp.CharacterBuf;
41  import edu.rit.pj.Comm;
42  import ffx.crystal.Crystal;
43  import ffx.potential.FiniteDifferenceUtils;
44  import ffx.potential.ForceFieldEnergy;
45  import ffx.potential.MolecularAssembly;
46  import ffx.potential.Platform;
47  import ffx.potential.Utilities;
48  import ffx.potential.bonded.Atom;
49  import ffx.potential.parameters.ForceField;
50  import ffx.potential.utils.EnergyException;
51  import ffx.potential.utils.PotentialsFunctions;
52  import ffx.potential.utils.PotentialsUtils;
53  import org.apache.commons.configuration2.CompositeConfiguration;
54  import org.apache.commons.io.FilenameUtils;
55  
56  import javax.annotation.Nullable;
57  import java.io.File;
58  import java.io.IOException;
59  import java.time.LocalDateTime;
60  import java.time.format.DateTimeFormatter;
61  import java.util.ArrayList;
62  import java.util.Arrays;
63  import java.util.List;
64  import java.util.logging.Level;
65  import java.util.logging.Logger;
66  import java.util.stream.Collectors;
67  import java.util.stream.IntStream;
68  
69  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
70  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
71  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
72  import static java.lang.Double.isFinite;
73  import static java.lang.String.format;
74  
75  /**
76   * Compute the potential energy and derivatives using OpenMM.
77   *
78   * @author Michael J. Schnieders
79   * @since 1.0
80   */
81  public class OpenMMEnergy extends ForceFieldEnergy implements OpenMMPotential {
82  
83    private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
84  
85    /**
86     * FFX Platform.
87     */
88    private final Platform platform;
89    /**
90     * OpenMM Context.
91     */
92    private OpenMMContext openMMContext;
93    /**
94     * OpenMM System.
95     */
96    private OpenMMSystem openMMSystem;
97    /**
98     * The atoms this OpenMMEnergy operates on.
99     */
100   private final Atom[] atoms;
101   /**
102    * If true, compute dUdL.
103    */
104   private final boolean computeDEDL;
105 
106   /**
107    * ForceFieldEnergyOpenMM constructor; offloads heavy-duty computation to an OpenMM Platform while
108    * keeping track of information locally.
109    *
110    * @param molecularAssembly Assembly to construct energy for.
111    * @param requestedPlatform requested OpenMM platform to be used.
112    * @param nThreads          Number of threads to use in the super class ForceFieldEnergy instance.
113    */
114   public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
115     super(molecularAssembly, nThreads);
116 
117     Crystal crystal = getCrystal();
118     int symOps = crystal.spaceGroup.getNumberOfSymOps();
119     if (symOps > 1) {
120       logger.severe(" OpenMM does not support symmetry operators.");
121     }
122 
123     logger.info("\n Initializing OpenMM");
124 
125     ForceField forceField = molecularAssembly.getForceField();
126     atoms = molecularAssembly.getAtomArray();
127 
128     // Load the OpenMM plugins
129     this.platform = requestedPlatform;
130     ffx.openmm.Platform openMMPlatform = OpenMMContext.loadPlatform(platform, forceField);
131 
132     // Create the OpenMM System.
133     openMMSystem = new OpenMMSystem(this);
134     openMMSystem.addForces();
135 
136     // Create the Context.
137     openMMContext = new OpenMMContext(openMMPlatform, openMMSystem, atoms);
138 
139     computeDEDL = forceField.getBoolean("OMM_DUDL", false);
140   }
141 
142   /**
143    * Gets the default coprocessor device, ignoring any CUDA_DEVICE over-ride. This is either
144    * determined by process rank and the availableDevices/CUDA_DEVICES property, or just 0 if neither
145    * property is sets.
146    *
147    * @param props Properties in use.
148    * @return Pre-override device index.
149    */
150   public static int getDefaultDevice(CompositeConfiguration props) {
151     String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
152     if (availDeviceProp == null) {
153       int nDevs = props.getInt("numCudaDevices", 1);
154       availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
155           .collect(Collectors.joining(" "));
156     }
157     availDeviceProp = availDeviceProp.trim();
158 
159     String[] availDevices = availDeviceProp.split("\\s+");
160     int nDevs = availDevices.length;
161     int[] devs = new int[nDevs];
162     for (int i = 0; i < nDevs; i++) {
163       devs[i] = Integer.parseInt(availDevices[i]);
164     }
165 
166     logger.info(format(" Available devices: %d.", nDevs));
167 
168     // If only one device is available, return it.
169     if (nDevs == 1) {
170       return devs[0];
171     }
172 
173     int index = 0;
174     try {
175       Comm world = Comm.world();
176       if (world != null) {
177         int size = world.size();
178         logger.fine(format(" Number of MPI processes %d exceeds number of available devices %d.", size, nDevs));
179 
180         // Format the host as a CharacterBuf of length 100.
181         int messageLen = 100;
182         String host = world.host();
183         // Truncate to max 100 characters.
184         host = host.substring(0, Math.min(messageLen, host.length()));
185         // Pad to 100 characters.
186         host = format("%-100s", host);
187 
188         logger.fine(format(" Host: %s", host.trim()));
189         char[] messageOut = host.toCharArray();
190         CharacterBuf out = CharacterBuf.buffer(messageOut);
191 
192         // Now create CharacterBuf array for all incoming messages.
193         char[][] incoming = new char[size][messageLen];
194         CharacterBuf[] in = new CharacterBuf[size];
195         for (int i = 0; i < size; i++) {
196           in[i] = CharacterBuf.buffer(incoming[i]);
197         }
198 
199         try {
200           logger.fine(" AllGather for determining rank.");
201           world.allGather(out, in);
202           logger.fine(" AllGather complete.");
203         } catch (IOException ex) {
204           logger.warning(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
205         }
206         int ownIndex = -1;
207         int rank = world.rank();
208         boolean selfFound = false;
209 
210         for (int i = 0; i < size; i++) {
211           String hostI = new String(incoming[i]);
212           if (hostI.equalsIgnoreCase(host)) {
213             ++ownIndex;
214             if (i == rank) {
215               selfFound = true;
216               break;
217             }
218           }
219         }
220         if (!selfFound) {
221           logger.warning(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
222         } else {
223           index = ownIndex % nDevs;
224         }
225       }
226     } catch (IllegalStateException ise) {
227       // Behavior is just to keep index = 0.
228     }
229     return devs[index];
230   }
231 
232   /**
233    * Create an OpenMM Context.
234    *
235    * <p>Context.free() must be called to free OpenMM memory.
236    *
237    * @param integratorName Integrator to use.
238    * @param timeStep       Time step.
239    * @param temperature    Temperature (K).
240    * @param forceCreation  Force a new Context to be created, even if the existing one matches the
241    *                       request.
242    */
243   @Override
244   public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
245     openMMContext.update(integratorName, timeStep, temperature, forceCreation);
246   }
247 
248   /**
249    * Create an immutable OpenMM State.
250    *
251    * <p>State.free() must be called to free OpenMM memory.
252    *
253    * @param mask The State mask.
254    * @return Returns the State.
255    */
256   @Override
257   public OpenMMState getOpenMMState(int mask) {
258     return openMMContext.getOpenMMState(mask);
259   }
260 
261   /**
262    * {@inheritDoc}
263    */
264   @Override
265   public boolean destroy() {
266     boolean ffxFFEDestroy = super.destroy();
267     free();
268     logger.fine(" Destroyed the Context and OpenMMSystem.");
269     return ffxFFEDestroy;
270   }
271 
272   /**
273    * {@inheritDoc}
274    */
275   @Override
276   public double energy(double[] x) {
277     return energy(x, false);
278   }
279 
280   /**
281    * {@inheritDoc}
282    */
283   @Override
284   public double energy(double[] x, boolean verbose) {
285 
286     if (lambdaBondedTerms) {
287       return 0.0;
288     }
289 
290     // Make sure the context has been created.
291     openMMContext.update();
292 
293     updateParameters(atoms);
294 
295     // Unscale the coordinates.
296     unscaleCoordinates(x);
297 
298     setCoordinates(x);
299 
300     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
301     double e = openMMState.potentialEnergy;
302     openMMState.destroy();
303 
304     if (!isFinite(e)) {
305       String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
306       logger.warning(message);
307       throw new EnergyException(message);
308     }
309 
310     if (verbose) {
311       logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
312     }
313 
314     // Rescale the coordinates.
315     scaleCoordinates(x);
316 
317     return e;
318   }
319 
320   /**
321    * Compute the energy using the pure Java code path.
322    *
323    * @param x Atomic coordinates.
324    * @return The energy (kcal/mol)
325    */
326   public double energyFFX(double[] x) {
327     return super.energy(x, false);
328   }
329 
330   /**
331    * Compute the energy using the pure Java code path.
332    *
333    * @param x       Input atomic coordinates
334    * @param verbose Use verbose logging.
335    * @return The energy (kcal/mol)
336    */
337   public double energyFFX(double[] x, boolean verbose) {
338     return super.energy(x, verbose);
339   }
340 
341   /**
342    * {@inheritDoc}
343    */
344   @Override
345   public double energyAndGradient(double[] x, double[] g) {
346     return energyAndGradient(x, g, false);
347   }
348 
349   /**
350    * {@inheritDoc}
351    */
352   @Override
353   public double energyAndGradient(double[] x, double[] g, boolean verbose) {
354     if (lambdaBondedTerms) {
355       return 0.0;
356     }
357 
358     // Un-scale the coordinates.
359     unscaleCoordinates(x);
360 
361     // Make sure a context has been created.
362     openMMContext.update();
363 
364     // long time = -System.nanoTime();
365     setCoordinates(x);
366     // time += System.nanoTime();
367     // logger.info(format(" Load coordinates time %10.6f (sec)", time * 1.0e-9));
368 
369     // time = -System.nanoTime();
370     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
371     double e = openMMState.potentialEnergy;
372     g = openMMState.getGradient(g);
373     openMMState.destroy();
374     // time += System.nanoTime();
375     // logger.info(format(" Calculate energy time %10.6f (sec)", time * 1.0e-9));
376 
377     if (!isFinite(e)) {
378       String message = format(" Energy from OpenMM was a non-finite %8g", e);
379       logger.warning(message);
380       throw new EnergyException(message);
381     }
382 
383     // if (vdwLambdaTerm) {
384     //    PointerByReference parameterArray = OpenMM_State_getEnergyParameterDerivatives(state);
385     //    int numDerives = OpenMM_ParameterArray_getSize(parameterArray);
386     //    if (numDerives > 0) {
387     //        double vdwdUdL = OpenMM_ParameterArray_get(parameterArray,
388     // pointerForString("vdw_lambda")) / OpenMM_KJPerKcal;
389     //    }
390     // }
391 
392     if (maxDebugGradient < Double.MAX_VALUE) {
393       boolean extremeGrad = Arrays.stream(g)
394           .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
395       if (extremeGrad) {
396         File origFile = molecularAssembly.getFile();
397         String timeString = LocalDateTime.now()
398             .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
399 
400         String filename = format("%s-LARGEGRAD-%s.pdb",
401             FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
402         PotentialsFunctions ef = new PotentialsUtils();
403         filename = ef.versionFile(filename);
404 
405         logger.warning(
406             format(" Excessively large gradients detected; printing snapshot to file %s", filename));
407         ef.saveAsPDB(molecularAssembly, new File(filename));
408         molecularAssembly.setFile(origFile);
409       }
410     }
411 
412     if (verbose) {
413       logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
414     }
415 
416     // Scale the coordinates and gradients.
417     scaleCoordinatesAndGradient(x, g);
418 
419     return e;
420   }
421 
422   /**
423    * Compute the energy and gradient using the pure Java code path.
424    *
425    * @param x Input atomic coordinates
426    * @param g Storage for the gradient vector.
427    * @return The energy (kcal/mol)
428    */
429   public double energyAndGradientFFX(double[] x, double[] g) {
430     return super.energyAndGradient(x, g, false);
431   }
432 
433   /**
434    * Compute the energy and gradient using the pure Java code path.
435    *
436    * @param x       Input atomic coordinates
437    * @param g       Storage for the gradient vector.
438    * @param verbose Use verbose logging.
439    * @return The energy (kcal/mol)
440    */
441   public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
442     return super.energyAndGradient(x, g, verbose);
443   }
444 
445   /**
446    * Returns the Context instance.
447    *
448    * @return context
449    */
450   @Override
451   public OpenMMContext getContext() {
452     return openMMContext;
453   }
454 
455   /**
456    * Returns the MolecularAssembly instance.
457    *
458    * @return molecularAssembly
459    */
460   public MolecularAssembly getMolecularAssembly() {
461     return molecularAssembly;
462   }
463 
464   /**
465    * Re-compute the gradient using OpenMM and return it.
466    *
467    * @param g Gradient array.
468    */
469   @Override
470   public double[] getGradient(double[] g) {
471     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
472     g = openMMState.getGradient(g);
473     openMMState.destroy();
474     return g;
475   }
476 
477   /**
478    * {@inheritDoc}
479    */
480   @Override
481   public Platform getPlatform() {
482     return platform;
483   }
484 
485   /**
486    * Get a reference to the System instance.
487    *
488    * @return Java wrapper to an OpenMM system.
489    */
490   @Override
491   public OpenMMSystem getSystem() {
492     return openMMSystem;
493   }
494 
495   /**
496    * {@inheritDoc}
497    */
498   @Override
499   public double getd2EdL2() {
500     return 0.0;
501   }
502 
503   /**
504    * {@inheritDoc}
505    */
506   @Override
507   public double getdEdL() {
508     // No lambda dependence.
509     if (!lambdaTerm || !computeDEDL) {
510       return 0.0;
511     }
512 
513     return FiniteDifferenceUtils.computedEdL(this, this, molecularAssembly.getForceField());
514   }
515 
516   /**
517    * {@inheritDoc}
518    */
519   @Override
520   public void getdEdXdL(double[] gradients) {
521     // Note for ForceFieldEnergyOpenMM this method is not implemented.
522   }
523 
524   /**
525    * Update active atoms.
526    */
527   @Override
528   public boolean setActiveAtoms() {
529     return openMMSystem.updateAtomMass();
530   }
531 
532   /**
533    * Coordinates for active atoms in units of Angstroms.
534    *
535    * @param x Atomic coordinates active atoms.
536    */
537   @Override
538   public void setCoordinates(double[] x) {
539     // Load the coordinates for active atoms.
540     super.setCoordinates(x);
541 
542     int n = atoms.length * 3;
543     double[] xall = new double[n];
544     int i = 0;
545     for (Atom atom : atoms) {
546       xall[i] = atom.getX();
547       xall[i + 1] = atom.getY();
548       xall[i + 2] = atom.getZ();
549       i += 3;
550     }
551 
552     // Load OpenMM coordinates for all atoms.
553     openMMContext.setPositions(xall);
554   }
555 
556   /**
557    * Velocities for active atoms in units of Angstroms.
558    *
559    * @param v Velocities for active atoms.
560    */
561   @Override
562   public void setVelocity(double[] v) {
563     // Load the velocity for active atoms.
564     super.setVelocity(v);
565 
566     int n = atoms.length * 3;
567     double[] vall = new double[n];
568     double[] v3 = new double[3];
569     int i = 0;
570     for (Atom atom : atoms) {
571       atom.getVelocity(v3);
572       vall[i] = v3[0];
573       vall[i + 1] = v3[1];
574       vall[i + 2] = v3[2];
575       i += 3;
576     }
577 
578     // Load OpenMM velocities for all atoms.
579     openMMContext.setVelocities(vall);
580   }
581 
582   /**
583    * {@inheritDoc}
584    */
585   @Override
586   public void setCrystal(Crystal crystal) {
587     super.setCrystal(crystal);
588     openMMContext.setPeriodicBoxVectors(crystal);
589   }
590 
591   /**
592    * {@inheritDoc}
593    */
594   @Override
595   public void setLambda(double lambda) {
596     if (!lambdaTerm) {
597       logger.fine(" Attempting to set lambda for an OpenMMEnergy with lambdaterm false.");
598       return;
599     }
600 
601     super.setLambda(lambda);
602 
603     if (atoms != null) {
604       List<Atom> atomList = new ArrayList<>();
605       for (Atom atom : atoms) {
606         if (atom.applyLambda()) {
607           atomList.add(atom);
608         }
609       }
610       // Update force field parameters based on defined lambda values.
611       updateParameters(atomList.toArray(new Atom[0]));
612     } else {
613       updateParameters(null);
614     }
615 
616   }
617 
618   /**
619    * Update parameters if the Use flags and/or Lambda value has changed.
620    *
621    * @param atoms Atoms in this list are considered.
622    */
623   @Override
624   public void updateParameters(@Nullable Atom[] atoms) {
625     if (atoms == null) {
626       atoms = this.atoms;
627     }
628     if (openMMSystem != null) {
629       openMMSystem.updateParameters(atoms);
630     }
631   }
632 
633   /**
634    * Free OpenMM memory for the Context and System.
635    */
636   private void free() {
637     if (openMMContext != null) {
638       openMMContext.free();
639       openMMContext = null;
640     }
641     if (openMMSystem != null) {
642       openMMSystem.free();
643       openMMSystem = null;
644     }
645   }
646 
647 }