1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
40 import edu.rit.mp.CharacterBuf;
41 import edu.rit.pj.Comm;
42 import ffx.crystal.Crystal;
43 import ffx.potential.ForceFieldEnergy;
44 import ffx.potential.MolecularAssembly;
45 import ffx.potential.Platform;
46 import ffx.potential.Utilities;
47 import ffx.potential.bonded.Atom;
48 import ffx.potential.parameters.ForceField;
49 import ffx.potential.utils.EnergyException;
50 import ffx.potential.utils.PotentialsFunctions;
51 import ffx.potential.utils.PotentialsUtils;
52 import org.apache.commons.configuration2.CompositeConfiguration;
53 import org.apache.commons.io.FilenameUtils;
54
55 import javax.annotation.Nullable;
56 import java.io.File;
57 import java.io.IOException;
58 import java.time.LocalDateTime;
59 import java.time.format.DateTimeFormatter;
60 import java.util.ArrayList;
61 import java.util.Arrays;
62 import java.util.List;
63 import java.util.logging.Level;
64 import java.util.logging.Logger;
65 import java.util.stream.Collectors;
66 import java.util.stream.IntStream;
67
68 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_False;
69 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
70 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
71 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
72 import static java.lang.Double.isFinite;
73 import static java.lang.String.format;
74
75
76
77
78
79
80
81 @SuppressWarnings("deprecation")
82 public class OpenMMEnergy extends ForceFieldEnergy {
83
84 private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
85
86
87
88
89 private OpenMMContext openMMContext;
90
91
92
93 private OpenMMSystem openMMSystem;
94
95
96
97 private final Atom[] atoms;
98
99
100
101
102
103
104 private double lambdaStart;
105
106
107
108 private boolean twoSidedFiniteDifference = true;
109
110
111
112 private final double finiteDifferenceStepSize;
113
114
115
116
117
118
119
120
121
122 public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
123 super(molecularAssembly, nThreads);
124
125 Crystal crystal = getCrystal();
126 int symOps = crystal.spaceGroup.getNumberOfSymOps();
127 if (symOps > 1) {
128 logger.info("");
129 logger.severe(" OpenMM does not support symmetry operators.");
130 }
131
132 logger.info("\n Initializing OpenMM");
133
134 ForceField forceField = molecularAssembly.getForceField();
135 atoms = molecularAssembly.getAtomArray();
136 boolean aperiodic = super.getCrystal().aperiodic();
137 boolean pbcEnforced = forceField.getBoolean("ENFORCE_PBC", !aperiodic);
138 int enforcePBC = pbcEnforced ? OpenMM_True : OpenMM_False;
139
140 openMMContext = new OpenMMContext(requestedPlatform, atoms, enforcePBC, this);
141 openMMSystem = new OpenMMSystem(this);
142 openMMSystem.addForces();
143
144
145 lambdaStart = forceField.getDouble("LAMBDA_START", 0.0);
146 if (lambdaStart > 1.0) {
147 lambdaStart = 1.0;
148 } else if (lambdaStart < 0.0) {
149 lambdaStart = 0.0;
150 }
151
152 finiteDifferenceStepSize = forceField.getDouble("FD_DLAMBDA", 0.001);
153 twoSidedFiniteDifference = forceField.getBoolean("FD_TWO_SIDED", twoSidedFiniteDifference);
154 }
155
156
157
158
159
160
161
162
163
164 public static int getDefaultDevice(CompositeConfiguration props) {
165 String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
166 if (availDeviceProp == null) {
167 int nDevs = props.getInt("numCudaDevices", 1);
168 availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
169 .collect(Collectors.joining(" "));
170 }
171 availDeviceProp = availDeviceProp.trim();
172
173 String[] availDevices = availDeviceProp.split("\\s+");
174 int nDevs = availDevices.length;
175 int[] devs = new int[nDevs];
176 for (int i = 0; i < nDevs; i++) {
177 devs[i] = Integer.parseInt(availDevices[i]);
178 }
179
180 logger.info(format(" Number of CUDA devices: %d.", nDevs));
181
182 int index = 0;
183 try {
184 Comm world = Comm.world();
185 if (world != null) {
186 int size = world.size();
187
188
189 int messageLen = 100;
190 String host = world.host();
191
192 host = host.substring(0, Math.min(messageLen, host.length()));
193
194 host = format("%-100s", host);
195 char[] messageOut = host.toCharArray();
196 CharacterBuf out = CharacterBuf.buffer(messageOut);
197
198
199 char[][] incoming = new char[size][messageLen];
200 CharacterBuf[] in = new CharacterBuf[size];
201 for (int i = 0; i < size; i++) {
202 in[i] = CharacterBuf.buffer(incoming[i]);
203 }
204
205 try {
206 world.allGather(out, in);
207 } catch (IOException ex) {
208 logger.severe(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
209 }
210 int ownIndex = -1;
211 int rank = world.rank();
212 boolean selfFound = false;
213
214 for (int i = 0; i < size; i++) {
215 String hostI = new String(incoming[i]);
216 if (hostI.equalsIgnoreCase(host)) {
217 ++ownIndex;
218 if (i == rank) {
219 selfFound = true;
220 break;
221 }
222 }
223 }
224 if (!selfFound) {
225 logger.severe(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
226 } else {
227 index = ownIndex % nDevs;
228 }
229 }
230 } catch (IllegalStateException ise) {
231
232 }
233 return devs[index];
234 }
235
236
237
238
239
240
241
242
243
244
245
246
247 public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
248 openMMContext.update(integratorName, timeStep, temperature, forceCreation, this);
249 }
250
251
252
253
254
255
256
257
258
259 public OpenMMState getOpenMMState(int mask) {
260 return openMMContext.getOpenMMState(mask);
261 }
262
263
264
265
266 @Override
267 public boolean destroy() {
268 boolean ffxFFEDestroy = super.destroy();
269 free();
270 logger.fine(" Destroyed the Context, Integrator, and OpenMMSystem.");
271 return ffxFFEDestroy;
272 }
273
274
275
276
277 @Override
278 public double energy(double[] x) {
279 return energy(x, false);
280 }
281
282
283
284
285 @Override
286 public double energy(double[] x, boolean verbose) {
287
288 if (lambdaBondedTerms) {
289 return 0.0;
290 }
291
292
293 openMMContext.update();
294
295 updateParameters(atoms);
296
297
298 unscaleCoordinates(x);
299
300 setCoordinates(x);
301
302 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
303 double e = openMMState.potentialEnergy;
304 openMMState.destroy();
305
306 if (!isFinite(e)) {
307 String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
308 logger.warning(message);
309 if (lambdaTerm) {
310 openMMSystem.printLambdaValues();
311 }
312 throw new EnergyException(message);
313 }
314
315 if (verbose) {
316 logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
317 }
318
319
320 scaleCoordinates(x);
321
322 return e;
323 }
324
325
326
327
328 @Override
329 public double energyAndGradient(double[] x, double[] g) {
330 return energyAndGradient(x, g, false);
331 }
332
333
334
335
336 @Override
337 public double energyAndGradient(double[] x, double[] g, boolean verbose) {
338 if (lambdaBondedTerms) {
339 return 0.0;
340 }
341
342
343
344
345 unscaleCoordinates(x);
346
347
348 openMMContext.update();
349
350 setCoordinates(x);
351
352 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
353 double e = openMMState.potentialEnergy;
354 g = openMMState.getGradient(g);
355 openMMState.destroy();
356
357 if (!isFinite(e)) {
358 String message = format(" Energy from OpenMM was a non-finite %8g", e);
359 logger.warning(message);
360 if (lambdaTerm) {
361 openMMSystem.printLambdaValues();
362 }
363 throw new EnergyException(message);
364 }
365
366
367
368
369
370
371
372
373
374
375 if (maxDebugGradient < Double.MAX_VALUE) {
376 boolean extremeGrad = Arrays.stream(g)
377 .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
378 if (extremeGrad) {
379 File origFile = molecularAssembly.getFile();
380 String timeString = LocalDateTime.now()
381 .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
382
383 String filename = format("%s-LARGEGRAD-%s.pdb",
384 FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
385 PotentialsFunctions ef = new PotentialsUtils();
386 filename = ef.versionFile(filename);
387
388 logger.warning(
389 format(" Excessively large gradients detected; printing snapshot to file %s", filename));
390 ef.saveAsPDB(molecularAssembly, new File(filename));
391 molecularAssembly.setFile(origFile);
392 }
393 }
394
395 if (verbose) {
396 logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
397 }
398
399
400 scaleCoordinatesAndGradient(x, g);
401
402 return e;
403 }
404
405
406
407
408
409
410
411
412 public double energyAndGradientFFX(double[] x, double[] g) {
413 return super.energyAndGradient(x, g, false);
414 }
415
416
417
418
419
420
421
422
423
424 public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
425 return super.energyAndGradient(x, g, verbose);
426 }
427
428
429
430
431
432
433
434 public double energyFFX(double[] x) {
435 return super.energy(x, false);
436 }
437
438
439
440
441
442
443
444
445 public double energyFFX(double[] x, boolean verbose) {
446 return super.energy(x, verbose);
447 }
448
449
450
451
452
453
454 public OpenMMContext getContext() {
455 return openMMContext;
456 }
457
458
459
460
461
462
463 public MolecularAssembly getMolecularAssembly() {
464 return molecularAssembly;
465 }
466
467
468
469
470
471
472 public void setLambdaTerm(boolean lambdaTerm) {
473 this.lambdaTerm = lambdaTerm;
474 }
475
476
477
478
479
480
481 public boolean getLambdaTerm() {
482 return lambdaTerm;
483 }
484
485
486
487
488
489
490 @Override
491 public double[] getGradient(double[] g) {
492 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
493 g = openMMState.getGradient(g);
494 openMMState.destroy();
495 return g;
496 }
497
498
499
500
501 @Override
502 public Platform getPlatform() {
503 return openMMContext.getPlatform();
504 }
505
506
507
508
509
510
511 public OpenMMSystem getSystem() {
512 return openMMSystem;
513 }
514
515
516
517
518 @Override
519 public double getd2EdL2() {
520 return 0.0;
521 }
522
523
524
525
526 @Override
527 public double getdEdL() {
528
529 if (!lambdaTerm) {
530 return 0.0;
531 }
532
533
534 double[] x = new double[getNumberOfVariables()];
535 getCoordinates(x);
536
537 double currentLambda = getLambda();
538 double width = finiteDifferenceStepSize;
539 double ePlus;
540 double eMinus;
541
542 if (twoSidedFiniteDifference) {
543 if (currentLambda + finiteDifferenceStepSize > 1.0) {
544 setLambda(currentLambda - finiteDifferenceStepSize);
545 eMinus = energy(x);
546 setLambda(currentLambda);
547 ePlus = energy(x);
548 } else if (currentLambda - finiteDifferenceStepSize < 0.0) {
549 setLambda(currentLambda + finiteDifferenceStepSize);
550 ePlus = energy(x);
551 setLambda(currentLambda);
552 eMinus = energy(x);
553 } else {
554
555 setLambda(currentLambda + finiteDifferenceStepSize);
556 ePlus = energy(x);
557 setLambda(currentLambda - finiteDifferenceStepSize);
558 eMinus = energy(x);
559 width *= 2.0;
560 setLambda(currentLambda);
561 }
562 } else {
563
564 if (currentLambda + finiteDifferenceStepSize > 1.0) {
565 setLambda(currentLambda - finiteDifferenceStepSize);
566 eMinus = energy(x);
567 setLambda(currentLambda);
568 ePlus = energy(x);
569 } else {
570 setLambda(currentLambda + finiteDifferenceStepSize);
571 ePlus = energy(x);
572 setLambda(currentLambda);
573 eMinus = energy(x);
574 }
575 }
576
577
578 double dEdL = (ePlus - eMinus) / width;
579
580
581
582 return dEdL;
583 }
584
585
586
587
588 @Override
589 public void getdEdXdL(double[] gradients) {
590
591 }
592
593
594
595
596 public void setActiveAtoms() {
597 openMMSystem.updateAtomMass();
598
599
600 }
601
602
603
604
605
606
607 @Override
608 public void setCoordinates(double[] x) {
609
610 openMMContext.setPositions(x);
611 }
612
613
614
615
616 @Override
617 public void setCrystal(Crystal crystal) {
618 super.setCrystal(crystal);
619 openMMContext.setPeriodicBoxVectors(crystal);
620 }
621
622 public void setLambdaStart(double lambdaStart) {
623 this.lambdaStart = lambdaStart;
624 }
625
626 public double getLambdaStart() {
627 return lambdaStart;
628 }
629
630 public void setTwoSidedFiniteDifference(boolean twoSidedFiniteDifference) {
631 this.twoSidedFiniteDifference = twoSidedFiniteDifference;
632 }
633
634
635
636
637 @Override
638 public void setLambda(double lambda) {
639
640 if (!lambdaTerm) {
641 logger.fine(" Attempting to set lambda for a ForceFieldEnergyOpenMM with lambdaterm false.");
642 return;
643 }
644
645
646 if (lambda < 0.0 || lambda > 1.0) {
647 String message = format(" Lambda value %8.3f is not in the range [0..1].", lambda);
648 logger.warning(message);
649 return;
650 }
651
652 super.setLambda(lambda);
653
654
655 double mappedLambda = lambda;
656 if (lambdaStart > 0) {
657 double windowSize = 1.0 - lambdaStart;
658 mappedLambda = lambdaStart + lambda * windowSize;
659 }
660
661 if (openMMSystem != null) {
662 openMMSystem.setLambda(mappedLambda);
663 if (atoms != null) {
664 List<Atom> atomList = new ArrayList<>();
665 for (Atom atom : atoms) {
666 if (atom.applyLambda()) {
667 atomList.add(atom);
668 }
669 }
670
671 updateParameters(atomList.toArray(new Atom[0]));
672 } else {
673 updateParameters(null);
674 }
675 }
676 }
677
678
679
680
681
682
683 public void updateParameters(@Nullable Atom[] atoms) {
684 if (atoms == null) {
685 atoms = this.atoms;
686 }
687 openMMSystem.updateParameters(atoms);
688 }
689
690
691
692
693 private void free() {
694 if (openMMContext != null) {
695 openMMContext.free();
696 openMMContext = null;
697 }
698 if (openMMSystem != null) {
699 openMMSystem.free();
700 openMMSystem = null;
701 }
702 }
703
704 }