View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2024.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.potential.openmm;
39  
40  import edu.rit.mp.CharacterBuf;
41  import edu.rit.pj.Comm;
42  import ffx.crystal.Crystal;
43  import ffx.potential.ForceFieldEnergy;
44  import ffx.potential.MolecularAssembly;
45  import ffx.potential.Platform;
46  import ffx.potential.Utilities;
47  import ffx.potential.bonded.Atom;
48  import ffx.potential.parameters.ForceField;
49  import ffx.potential.utils.EnergyException;
50  import ffx.potential.utils.PotentialsFunctions;
51  import ffx.potential.utils.PotentialsUtils;
52  import org.apache.commons.configuration2.CompositeConfiguration;
53  import org.apache.commons.io.FilenameUtils;
54  
55  import javax.annotation.Nullable;
56  import java.io.File;
57  import java.io.IOException;
58  import java.time.LocalDateTime;
59  import java.time.format.DateTimeFormatter;
60  import java.util.ArrayList;
61  import java.util.Arrays;
62  import java.util.List;
63  import java.util.logging.Level;
64  import java.util.logging.Logger;
65  import java.util.stream.Collectors;
66  import java.util.stream.IntStream;
67  
68  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_False;
69  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
70  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
71  import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
72  import static java.lang.Double.isFinite;
73  import static java.lang.String.format;
74  
75  /**
76   * Compute the potential energy and derivatives using OpenMM.
77   *
78   * @author Michael J. Schnieders
79   * @since 1.0
80   */
81  @SuppressWarnings("deprecation")
82  public class OpenMMEnergy extends ForceFieldEnergy {
83  
84    private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
85  
86    /**
87     * OpenMM Context.
88     */
89    private OpenMMContext openMMContext;
90    /**
91     * OpenMM System.
92     */
93    private OpenMMSystem openMMSystem;
94    /**
   * The atoms this OpenMMEnergy instance operates on.
96     */
97    private final Atom[] atoms;
98  
99    /**
100    * Truncate the normal OpenMM Lambda Path from 0 ... 1 to Lambda_Start ... 1. This is useful for
101    * conformational optimization if full removal of vdW interactions is not desired (i.e. lambdaStart
102    * = ~0.2).
103    */
104   private double lambdaStart;
105   /**
106    * Use two-sided finite difference dU/dL.
107    */
108   private boolean twoSidedFiniteDifference = true;
109   /**
110    * Lambda step size for finite difference dU/dL.
111    */
112   private final double finiteDifferenceStepSize;
113 
114   /**
115    * ForceFieldEnergyOpenMM constructor; offloads heavy-duty computation to an OpenMM Platform while
116    * keeping track of information locally.
117    *
118    * @param molecularAssembly Assembly to construct energy for.
119    * @param requestedPlatform requested OpenMM platform to be used.
120    * @param nThreads          Number of threads to use in the super class ForceFieldEnergy instance.
121    */
122   public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
123     super(molecularAssembly, nThreads);
124 
125     Crystal crystal = getCrystal();
126     int symOps = crystal.spaceGroup.getNumberOfSymOps();
127     if (symOps > 1) {
128       logger.info("");
129       logger.severe(" OpenMM does not support symmetry operators.");
130     }
131 
132     logger.info("\n Initializing OpenMM");
133 
134     ForceField forceField = molecularAssembly.getForceField();
135     atoms = molecularAssembly.getAtomArray();
136     boolean aperiodic = super.getCrystal().aperiodic();
137     boolean pbcEnforced = forceField.getBoolean("ENFORCE_PBC", !aperiodic);
138     int enforcePBC = pbcEnforced ? OpenMM_True : OpenMM_False;
139 
140     openMMContext = new OpenMMContext(requestedPlatform, atoms, enforcePBC, this);
141     openMMSystem = new OpenMMSystem(this);
142     openMMSystem.addForces();
143 
144     // Expand the path [lambda-start .. 1.0] to the interval [0.0 .. 1.0].
145     lambdaStart = forceField.getDouble("LAMBDA_START", 0.0);
146     if (lambdaStart > 1.0) {
147       lambdaStart = 1.0;
148     } else if (lambdaStart < 0.0) {
149       lambdaStart = 0.0;
150     }
151 
152     finiteDifferenceStepSize = forceField.getDouble("FD_DLAMBDA", 0.001);
153     twoSidedFiniteDifference = forceField.getBoolean("FD_TWO_SIDED", twoSidedFiniteDifference);
154   }
155 
156   /**
157    * Gets the default coprocessor device, ignoring any CUDA_DEVICE over-ride. This is either
158    * determined by process rank and the availableDevices/CUDA_DEVICES property, or just 0 if neither
159    * property is sets.
160    *
161    * @param props Properties in use.
162    * @return Pre-override device index.
163    */
164   public static int getDefaultDevice(CompositeConfiguration props) {
165     String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
166     if (availDeviceProp == null) {
167       int nDevs = props.getInt("numCudaDevices", 1);
168       availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
169           .collect(Collectors.joining(" "));
170     }
171     availDeviceProp = availDeviceProp.trim();
172 
173     String[] availDevices = availDeviceProp.split("\\s+");
174     int nDevs = availDevices.length;
175     int[] devs = new int[nDevs];
176     for (int i = 0; i < nDevs; i++) {
177       devs[i] = Integer.parseInt(availDevices[i]);
178     }
179 
180     logger.info(format(" Number of CUDA devices: %d.", nDevs));
181 
182     int index = 0;
183     try {
184       Comm world = Comm.world();
185       if (world != null) {
186         int size = world.size();
187 
188         // Format the host as a CharacterBuf of length 100.
189         int messageLen = 100;
190         String host = world.host();
191         // Truncate to max 100 characters.
192         host = host.substring(0, Math.min(messageLen, host.length()));
193         // Pad to 100 characters.
194         host = format("%-100s", host);
195         char[] messageOut = host.toCharArray();
196         CharacterBuf out = CharacterBuf.buffer(messageOut);
197 
198         // Now create CharacterBuf array for all incoming messages.
199         char[][] incoming = new char[size][messageLen];
200         CharacterBuf[] in = new CharacterBuf[size];
201         for (int i = 0; i < size; i++) {
202           in[i] = CharacterBuf.buffer(incoming[i]);
203         }
204 
205         try {
206           world.allGather(out, in);
207         } catch (IOException ex) {
208           logger.severe(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
209         }
210         int ownIndex = -1;
211         int rank = world.rank();
212         boolean selfFound = false;
213 
214         for (int i = 0; i < size; i++) {
215           String hostI = new String(incoming[i]);
216           if (hostI.equalsIgnoreCase(host)) {
217             ++ownIndex;
218             if (i == rank) {
219               selfFound = true;
220               break;
221             }
222           }
223         }
224         if (!selfFound) {
225           logger.severe(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
226         } else {
227           index = ownIndex % nDevs;
228         }
229       }
230     } catch (IllegalStateException ise) {
231       // Behavior is just to keep index = 0.
232     }
233     return devs[index];
234   }
235 
236   /**
237    * Create an OpenMM Context.
238    *
239    * <p>Context.free() must be called to free OpenMM memory.
240    *
241    * @param integratorName Integrator to use.
242    * @param timeStep       Time step.
243    * @param temperature    Temperature (K).
244    * @param forceCreation  Force a new Context to be created, even if the existing one matches the
245    *                       request.
246    */
247   public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
248     openMMContext.update(integratorName, timeStep, temperature, forceCreation, this);
249   }
250 
251   /**
252    * Create an immutable OpenMM State.
253    *
254    * <p>State.free() must be called to free OpenMM memory.
255    *
256    * @param mask The State mask.
257    * @return Returns the State.
258    */
259   public OpenMMState getOpenMMState(int mask) {
260     return openMMContext.getOpenMMState(mask);
261   }
262 
263   /**
264    * {@inheritDoc}
265    */
266   @Override
267   public boolean destroy() {
268     boolean ffxFFEDestroy = super.destroy();
269     free();
270     logger.fine(" Destroyed the Context, Integrator, and OpenMMSystem.");
271     return ffxFFEDestroy;
272   }
273 
274   /**
275    * {@inheritDoc}
276    */
277   @Override
278   public double energy(double[] x) {
279     return energy(x, false);
280   }
281 
282   /**
283    * {@inheritDoc}
284    */
285   @Override
286   public double energy(double[] x, boolean verbose) {
287 
288     if (lambdaBondedTerms) {
289       return 0.0;
290     }
291 
292     // Make sure the context has been created.
293     openMMContext.update();
294 
295     updateParameters(atoms);
296 
297     // Unscale the coordinates.
298     unscaleCoordinates(x);
299 
300     setCoordinates(x);
301 
302     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
303     double e = openMMState.potentialEnergy;
304     openMMState.destroy();
305 
306     if (!isFinite(e)) {
307       String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
308       logger.warning(message);
309       if (lambdaTerm) {
310         openMMSystem.printLambdaValues();
311       }
312       throw new EnergyException(message);
313     }
314 
315     if (verbose) {
316       logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
317     }
318 
319     // Rescale the coordinates.
320     scaleCoordinates(x);
321 
322     return e;
323   }
324 
325   /**
326    * {@inheritDoc}
327    */
328   @Override
329   public double energyAndGradient(double[] x, double[] g) {
330     return energyAndGradient(x, g, false);
331   }
332 
333   /**
334    * {@inheritDoc}
335    */
336   @Override
337   public double energyAndGradient(double[] x, double[] g, boolean verbose) {
338     if (lambdaBondedTerms) {
339       return 0.0;
340     }
341 
342     // ZE BUG: updateParameters only gets called for energy(), not energyAndGradient().
343 
344     // Un-scale the coordinates.
345     unscaleCoordinates(x);
346 
347     // Make sure a context has been created.
348     openMMContext.update();
349 
350     setCoordinates(x);
351 
352     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
353     double e = openMMState.potentialEnergy;
354     g = openMMState.getGradient(g);
355     openMMState.destroy();
356 
357     if (!isFinite(e)) {
358       String message = format(" Energy from OpenMM was a non-finite %8g", e);
359       logger.warning(message);
360       if (lambdaTerm) {
361         openMMSystem.printLambdaValues();
362       }
363       throw new EnergyException(message);
364     }
365 
366     // if (vdwLambdaTerm) {
367     //    PointerByReference parameterArray = OpenMM_State_getEnergyParameterDerivatives(state);
368     //    int numDerives = OpenMM_ParameterArray_getSize(parameterArray);
369     //    if (numDerives > 0) {
370     //        double vdwdUdL = OpenMM_ParameterArray_get(parameterArray,
371     // pointerForString("vdw_lambda")) / OpenMM_KJPerKcal;
372     //    }
373     // }
374 
375     if (maxDebugGradient < Double.MAX_VALUE) {
376       boolean extremeGrad = Arrays.stream(g)
377           .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
378       if (extremeGrad) {
379         File origFile = molecularAssembly.getFile();
380         String timeString = LocalDateTime.now()
381             .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
382 
383         String filename = format("%s-LARGEGRAD-%s.pdb",
384             FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
385         PotentialsFunctions ef = new PotentialsUtils();
386         filename = ef.versionFile(filename);
387 
388         logger.warning(
389             format(" Excessively large gradients detected; printing snapshot to file %s", filename));
390         ef.saveAsPDB(molecularAssembly, new File(filename));
391         molecularAssembly.setFile(origFile);
392       }
393     }
394 
395     if (verbose) {
396       logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
397     }
398 
399     // Scale the coordinates and gradients.
400     scaleCoordinatesAndGradient(x, g);
401 
402     return e;
403   }
404 
405   /**
406    * Compute the energy and gradient using the pure Java code path.
407    *
408    * @param x Input atomic coordinates
409    * @param g Storage for the gradient vector.
410    * @return The energy (kcal/mol)
411    */
412   public double energyAndGradientFFX(double[] x, double[] g) {
413     return super.energyAndGradient(x, g, false);
414   }
415 
416   /**
417    * Compute the energy and gradient using the pure Java code path.
418    *
419    * @param x       Input atomic coordinates
420    * @param g       Storage for the gradient vector.
421    * @param verbose Use verbose logging.
422    * @return The energy (kcal/mol)
423    */
424   public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
425     return super.energyAndGradient(x, g, verbose);
426   }
427 
428   /**
429    * Compute the energy using the pure Java code path.
430    *
431    * @param x Atomic coordinates.
432    * @return The energy (kcal/mol)
433    */
434   public double energyFFX(double[] x) {
435     return super.energy(x, false);
436   }
437 
438   /**
439    * Compute the energy using the pure Java code path.
440    *
441    * @param x       Input atomic coordinates
442    * @param verbose Use verbose logging.
443    * @return The energy (kcal/mol)
444    */
445   public double energyFFX(double[] x, boolean verbose) {
446     return super.energy(x, verbose);
447   }
448 
449   /**
450    * Returns the Context instance.
451    *
452    * @return context
453    */
454   public OpenMMContext getContext() {
455     return openMMContext;
456   }
457 
458   /**
459    * Returns the MolecularAssembly instance.
460    *
461    * @return molecularAssembly
462    */
463   public MolecularAssembly getMolecularAssembly() {
464     return molecularAssembly;
465   }
466 
467   /**
468    * Set the lambdaTerm flag.
469    *
470    * @param lambdaTerm The value to set.
471    */
472   public void setLambdaTerm(boolean lambdaTerm) {
473     this.lambdaTerm = lambdaTerm;
474   }
475 
476   /**
477    * Get the lambdaTerm flag.
478    *
479    * @return lambdaTerm.
480    */
481   public boolean getLambdaTerm() {
482     return lambdaTerm;
483   }
484 
485   /**
486    * Re-compute the gradient using OpenMM and return it.
487    *
488    * @param g Gradient array.
489    */
490   @Override
491   public double[] getGradient(double[] g) {
492     OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
493     g = openMMState.getGradient(g);
494     openMMState.destroy();
495     return g;
496   }
497 
498   /**
499    * {@inheritDoc}
500    */
501   @Override
502   public Platform getPlatform() {
503     return openMMContext.getPlatform();
504   }
505 
506   /**
507    * Get a reference to the System instance.
508    *
509    * @return Java wrapper to an OpenMM system.
510    */
511   public OpenMMSystem getSystem() {
512     return openMMSystem;
513   }
514 
515   /**
516    * {@inheritDoc}
517    */
518   @Override
519   public double getd2EdL2() {
520     return 0.0;
521   }
522 
523   /**
524    * {@inheritDoc}
525    */
526   @Override
527   public double getdEdL() {
528     // No lambda dependence.
529     if (!lambdaTerm) {
530       return 0.0;
531     }
532 
533     // Small optimization to only create the x array once.
534     double[] x = new double[getNumberOfVariables()];
535     getCoordinates(x);
536 
537     double currentLambda = getLambda();
538     double width = finiteDifferenceStepSize;
539     double ePlus;
540     double eMinus;
541 
542     if (twoSidedFiniteDifference) {
543       if (currentLambda + finiteDifferenceStepSize > 1.0) {
544         setLambda(currentLambda - finiteDifferenceStepSize);
545         eMinus = energy(x);
546         setLambda(currentLambda);
547         ePlus = energy(x);
548       } else if (currentLambda - finiteDifferenceStepSize < 0.0) {
549         setLambda(currentLambda + finiteDifferenceStepSize);
550         ePlus = energy(x);
551         setLambda(currentLambda);
552         eMinus = energy(x);
553       } else {
554         // Two sided finite difference estimate of dE/dL.
555         setLambda(currentLambda + finiteDifferenceStepSize);
556         ePlus = energy(x);
557         setLambda(currentLambda - finiteDifferenceStepSize);
558         eMinus = energy(x);
559         width *= 2.0;
560         setLambda(currentLambda);
561       }
562     } else {
563       // One-sided finite difference estimates of dE/dL
564       if (currentLambda + finiteDifferenceStepSize > 1.0) {
565         setLambda(currentLambda - finiteDifferenceStepSize);
566         eMinus = energy(x);
567         setLambda(currentLambda);
568         ePlus = energy(x);
569       } else {
570         setLambda(currentLambda + finiteDifferenceStepSize);
571         ePlus = energy(x);
572         setLambda(currentLambda);
573         eMinus = energy(x);
574       }
575     }
576 
577     // Compute the finite difference derivative.
578     double dEdL = (ePlus - eMinus) / width;
579 
580     // logger.info(format(" getdEdL currentLambda: CL=%8.6f L=%8.6f dEdL=%12.6f", currentLambda,
581     // lambda, dEdL));
582     return dEdL;
583   }
584 
585   /**
586    * {@inheritDoc}
587    */
588   @Override
589   public void getdEdXdL(double[] gradients) {
590     // Note for ForceFieldEnergyOpenMM this method is not implemented.
591   }
592 
593   /**
594    * Update active atoms.
595    */
596   public void setActiveAtoms() {
597     openMMSystem.updateAtomMass();
598     // Tests show reinitialization of the OpenMM Context is not necessary to pick up mass changes.
599     // context.reinitContext();
600   }
601 
602   /**
603    * Set FFX and OpenMM coordinates for active atoms.
604    *
605    * @param x Atomic coordinates.
606    */
607   @Override
608   public void setCoordinates(double[] x) {
609     // Set both OpenMM and FFX coordinates to x.
610     openMMContext.setPositions(x);
611   }
612 
613   /**
614    * {@inheritDoc}
615    */
616   @Override
617   public void setCrystal(Crystal crystal) {
618     super.setCrystal(crystal);
619     openMMContext.setPeriodicBoxVectors(crystal);
620   }
621 
622   public void setLambdaStart(double lambdaStart) {
623     this.lambdaStart = lambdaStart;
624   }
625 
626   public double getLambdaStart() {
627     return lambdaStart;
628   }
629 
630   public void setTwoSidedFiniteDifference(boolean twoSidedFiniteDifference) {
631     this.twoSidedFiniteDifference = twoSidedFiniteDifference;
632   }
633 
634   /**
635    * {@inheritDoc}
636    */
637   @Override
638   public void setLambda(double lambda) {
639 
640     if (!lambdaTerm) {
641       logger.fine(" Attempting to set lambda for a ForceFieldEnergyOpenMM with lambdaterm false.");
642       return;
643     }
644 
645     // Check for lambda outside the range [0 .. 1].
646     if (lambda < 0.0 || lambda > 1.0) {
647       String message = format(" Lambda value %8.3f is not in the range [0..1].", lambda);
648       logger.warning(message);
649       return;
650     }
651 
652     super.setLambda(lambda);
653 
654     // Remove the beginning of the normal Lambda path.
655     double mappedLambda = lambda;
656     if (lambdaStart > 0) {
657       double windowSize = 1.0 - lambdaStart;
658       mappedLambda = lambdaStart + lambda * windowSize;
659     }
660 
661     if (openMMSystem != null) {
662       openMMSystem.setLambda(mappedLambda);
663       if (atoms != null) {
664         List<Atom> atomList = new ArrayList<>();
665         for (Atom atom : atoms) {
666           if (atom.applyLambda()) {
667             atomList.add(atom);
668           }
669         }
670         // Update force field parameters based on defined lambda values.
671         updateParameters(atomList.toArray(new Atom[0]));
672       } else {
673         updateParameters(null);
674       }
675     }
676   }
677 
678   /**
679    * Update parameters if the Use flags and/or Lambda value has changed.
680    *
681    * @param atoms Atoms in this list are considered.
682    */
683   public void updateParameters(@Nullable Atom[] atoms) {
684     if (atoms == null) {
685       atoms = this.atoms;
686     }
687     openMMSystem.updateParameters(atoms);
688   }
689 
690   /**
691    * Free OpenMM memory for the Context, Integrator and System.
692    */
693   private void free() {
694     if (openMMContext != null) {
695       openMMContext.free();
696       openMMContext = null;
697     }
698     if (openMMSystem != null) {
699       openMMSystem.free();
700       openMMSystem = null;
701     }
702   }
703 
704 }