1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
40 import edu.rit.mp.CharacterBuf;
41 import edu.rit.pj.Comm;
42 import ffx.crystal.Crystal;
43 import ffx.potential.FiniteDifferenceUtils;
44 import ffx.potential.ForceFieldEnergy;
45 import ffx.potential.MolecularAssembly;
46 import ffx.potential.Platform;
47 import ffx.potential.Utilities;
48 import ffx.potential.bonded.Atom;
49 import ffx.potential.parameters.ForceField;
50 import ffx.potential.utils.EnergyException;
51 import ffx.potential.utils.PotentialsFunctions;
52 import ffx.potential.utils.PotentialsUtils;
53 import org.apache.commons.configuration2.CompositeConfiguration;
54 import org.apache.commons.io.FilenameUtils;
55
56 import javax.annotation.Nullable;
57 import java.io.File;
58 import java.io.IOException;
59 import java.time.LocalDateTime;
60 import java.time.format.DateTimeFormatter;
61 import java.util.ArrayList;
62 import java.util.Arrays;
63 import java.util.List;
64 import java.util.logging.Level;
65 import java.util.logging.Logger;
66 import java.util.stream.Collectors;
67 import java.util.stream.IntStream;
68
69 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
70 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
71 import static java.lang.Double.isFinite;
72 import static java.lang.String.format;
73
74
75
76
77
78
79
80 public class OpenMMEnergy extends ForceFieldEnergy implements OpenMMPotential {
81
82 private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
83
84
85
86
87 private final Platform platform;
88
89
90
91 private OpenMMContext openMMContext;
92
93
94
95 private OpenMMSystem openMMSystem;
96
97
98
99 private final Atom[] atoms;
100
101
102
103 private final boolean computeDEDL;
104
105
106
107
108
109
110
111
112
113 public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
114 super(molecularAssembly, nThreads);
115
116 Crystal crystal = getCrystal();
117 int symOps = crystal.spaceGroup.getNumberOfSymOps();
118 if (symOps > 1) {
119 logger.severe(" OpenMM does not support symmetry operators.");
120 }
121
122 logger.info("\n Initializing OpenMM");
123
124 ForceField forceField = molecularAssembly.getForceField();
125 atoms = molecularAssembly.getAtomArray();
126
127
128 this.platform = requestedPlatform;
129 ffx.openmm.Platform openMMPlatform = OpenMMContext.loadPlatform(platform, forceField);
130
131
132 openMMSystem = new OpenMMSystem(this);
133 openMMSystem.addForces();
134
135
136 openMMContext = new OpenMMContext(openMMPlatform, openMMSystem, atoms);
137
138 computeDEDL = forceField.getBoolean("OMM_DUDL", false);
139 }
140
141
142
143
144
145
146
147
148
149 public static int getDefaultDevice(CompositeConfiguration props) {
150 String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
151 if (availDeviceProp == null) {
152 int nDevs = props.getInt("numCudaDevices", 1);
153 availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
154 .collect(Collectors.joining(" "));
155 }
156 availDeviceProp = availDeviceProp.trim();
157
158 String[] availDevices = availDeviceProp.split("\\s+");
159 int nDevs = availDevices.length;
160 int[] devs = new int[nDevs];
161 for (int i = 0; i < nDevs; i++) {
162 devs[i] = Integer.parseInt(availDevices[i]);
163 }
164
165 logger.info(format(" Available devices: %d.", nDevs));
166
167
168 if (nDevs == 1) {
169 return devs[0];
170 }
171
172 int index = 0;
173 try {
174 Comm world = Comm.world();
175 if (world != null) {
176 int size = world.size();
177 logger.fine(format(" Number of MPI processes %d exceeds number of available devices %d.", size, nDevs));
178
179
180 int messageLen = 100;
181 String host = world.host();
182
183 host = host.substring(0, Math.min(messageLen, host.length()));
184
185 host = format("%-100s", host);
186
187 logger.fine(format(" Host: %s", host.trim()));
188 char[] messageOut = host.toCharArray();
189 CharacterBuf out = CharacterBuf.buffer(messageOut);
190
191
192 char[][] incoming = new char[size][messageLen];
193 CharacterBuf[] in = new CharacterBuf[size];
194 for (int i = 0; i < size; i++) {
195 in[i] = CharacterBuf.buffer(incoming[i]);
196 }
197
198 try {
199 logger.fine(" AllGather for determining rank.");
200 world.allGather(out, in);
201 logger.fine(" AllGather complete.");
202 } catch (IOException ex) {
203 logger.warning(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
204 }
205 int ownIndex = -1;
206 int rank = world.rank();
207 boolean selfFound = false;
208
209 for (int i = 0; i < size; i++) {
210 String hostI = new String(incoming[i]);
211 if (hostI.equalsIgnoreCase(host)) {
212 ++ownIndex;
213 if (i == rank) {
214 selfFound = true;
215 break;
216 }
217 }
218 }
219 if (!selfFound) {
220 logger.warning(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
221 } else {
222 index = ownIndex % nDevs;
223 }
224 }
225 } catch (IllegalStateException ise) {
226
227 }
228 return devs[index];
229 }
230
231
232
233
234
235
236
237
238
239
240
241
242 @Override
243 public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
244 openMMContext.update(integratorName, timeStep, temperature, forceCreation);
245 }
246
247
248
249
250
251
252
253
254
255 @Override
256 public OpenMMState getOpenMMState(int mask) {
257 return openMMContext.getOpenMMState(mask);
258 }
259
260
261
262
263 @Override
264 public boolean destroy() {
265 boolean ffxFFEDestroy = super.destroy();
266 free();
267 logger.fine(" Destroyed the Context and OpenMMSystem.");
268 return ffxFFEDestroy;
269 }
270
271
272
273
274 @Override
275 public double energy(double[] x) {
276 return energy(x, false);
277 }
278
279
280
281
282 @Override
283 public double energy(double[] x, boolean verbose) {
284
285 if (lambdaBondedTerms) {
286 return 0.0;
287 }
288
289
290 openMMContext.update();
291
292 updateParameters(atoms);
293
294
295 unscaleCoordinates(x);
296
297 setCoordinates(x);
298
299 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
300 double e = openMMState.potentialEnergy;
301 openMMState.destroy();
302
303 if (!isFinite(e)) {
304 String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
305 logger.warning(message);
306 throw new EnergyException(message);
307 }
308
309 if (verbose) {
310 logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
311 }
312
313
314 scaleCoordinates(x);
315
316 return e;
317 }
318
319
320
321
322
323
324
325 public double energyFFX(double[] x) {
326 return super.energy(x, false);
327 }
328
329
330
331
332
333
334
335
336 public double energyFFX(double[] x, boolean verbose) {
337 return super.energy(x, verbose);
338 }
339
340
341
342
343 @Override
344 public double energyAndGradient(double[] x, double[] g) {
345 return energyAndGradient(x, g, false);
346 }
347
348
349
350
351 @Override
352 public double energyAndGradient(double[] x, double[] g, boolean verbose) {
353 if (lambdaBondedTerms) {
354 return 0.0;
355 }
356
357
358 unscaleCoordinates(x);
359
360
361 openMMContext.update();
362
363
364 setCoordinates(x);
365
366
367
368
369 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
370 double e = openMMState.potentialEnergy;
371 g = openMMState.getGradient(g);
372 openMMState.destroy();
373
374
375
376 if (!isFinite(e)) {
377 String message = format(" Energy from OpenMM was a non-finite %8g", e);
378 logger.warning(message);
379 throw new EnergyException(message);
380 }
381
382
383
384
385
386
387
388
389
390
391 if (maxDebugGradient < Double.MAX_VALUE) {
392 boolean extremeGrad = Arrays.stream(g)
393 .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
394 if (extremeGrad) {
395 File origFile = molecularAssembly.getFile();
396 String timeString = LocalDateTime.now()
397 .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
398
399 String filename = format("%s-LARGEGRAD-%s.pdb",
400 FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
401 PotentialsFunctions ef = new PotentialsUtils();
402 filename = ef.versionFile(filename);
403
404 logger.warning(
405 format(" Excessively large gradients detected; printing snapshot to file %s", filename));
406 ef.saveAsPDB(molecularAssembly, new File(filename));
407 molecularAssembly.setFile(origFile);
408 }
409 }
410
411 if (verbose) {
412 logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
413 }
414
415
416 scaleCoordinatesAndGradient(x, g);
417
418 return e;
419 }
420
421
422
423
424
425
426
427
428 public double energyAndGradientFFX(double[] x, double[] g) {
429 return super.energyAndGradient(x, g, false);
430 }
431
432
433
434
435
436
437
438
439
440 public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
441 return super.energyAndGradient(x, g, verbose);
442 }
443
444
445
446
447
448
449 @Override
450 public OpenMMContext getContext() {
451 return openMMContext;
452 }
453
454
455
456
457
458
459 public MolecularAssembly getMolecularAssembly() {
460 return molecularAssembly;
461 }
462
463
464
465
466
467
468 @Override
469 public double[] getGradient(double[] g) {
470 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
471 g = openMMState.getGradient(g);
472 openMMState.destroy();
473 return g;
474 }
475
476
477
478
479 @Override
480 public Platform getPlatform() {
481 return platform;
482 }
483
484
485
486
487
488
489 @Override
490 public OpenMMSystem getSystem() {
491 return openMMSystem;
492 }
493
494
495
496
497 @Override
498 public double getd2EdL2() {
499 return 0.0;
500 }
501
502
503
504
505 @Override
506 public double getdEdL() {
507
508 if (!lambdaTerm || !computeDEDL) {
509 return 0.0;
510 }
511
512 return FiniteDifferenceUtils.computedEdL(this, this, molecularAssembly.getForceField());
513 }
514
515
516
517
518 @Override
519 public void getdEdXdL(double[] gradients) {
520
521 }
522
523
524
525
526 @Override
527 public boolean setActiveAtoms() {
528 return openMMSystem.updateAtomMass();
529 }
530
531
532
533
534
535
536 @Override
537 public void setCoordinates(double[] x) {
538
539 super.setCoordinates(x);
540
541 int n = atoms.length * 3;
542 double[] xall = new double[n];
543 int i = 0;
544 for (Atom atom : atoms) {
545 xall[i] = atom.getX();
546 xall[i + 1] = atom.getY();
547 xall[i + 2] = atom.getZ();
548 i += 3;
549 }
550
551
552 openMMContext.setPositions(xall);
553 }
554
555
556
557
558
559
560 @Override
561 public void setVelocity(double[] v) {
562
563 super.setVelocity(v);
564
565 int n = atoms.length * 3;
566 double[] vall = new double[n];
567 double[] v3 = new double[3];
568 int i = 0;
569 for (Atom atom : atoms) {
570 atom.getVelocity(v3);
571 vall[i] = v3[0];
572 vall[i + 1] = v3[1];
573 vall[i + 2] = v3[2];
574 i += 3;
575 }
576
577
578 openMMContext.setVelocities(vall);
579 }
580
581
582
583
584 @Override
585 public void setCrystal(Crystal crystal) {
586 super.setCrystal(crystal);
587 openMMContext.setPeriodicBoxVectors(crystal);
588 }
589
590
591
592
593 @Override
594 public void setLambda(double lambda) {
595 if (!lambdaTerm) {
596 logger.fine(" Attempting to set lambda for an OpenMMEnergy with lambdaterm false.");
597 return;
598 }
599
600 super.setLambda(lambda);
601
602 if (atoms != null) {
603 List<Atom> atomList = new ArrayList<>();
604 for (Atom atom : atoms) {
605 if (atom.applyLambda()) {
606 atomList.add(atom);
607 }
608 }
609
610 updateParameters(atomList.toArray(new Atom[0]));
611 } else {
612 updateParameters(null);
613 }
614
615 }
616
617
618
619
620
621
622 @Override
623 public void updateParameters(@Nullable Atom[] atoms) {
624 if (atoms == null) {
625 atoms = this.atoms;
626 }
627 if (openMMSystem != null) {
628 openMMSystem.updateParameters(atoms);
629 }
630 }
631
632
633
634
635 private void free() {
636 if (openMMContext != null) {
637 openMMContext.free();
638 openMMContext = null;
639 }
640 if (openMMSystem != null) {
641 openMMSystem.free();
642 openMMSystem = null;
643 }
644 }
645
646 }