1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
40 import edu.rit.mp.CharacterBuf;
41 import edu.rit.pj.Comm;
42 import ffx.crystal.Crystal;
43 import ffx.potential.FiniteDifferenceUtils;
44 import ffx.potential.ForceFieldEnergy;
45 import ffx.potential.MolecularAssembly;
46 import ffx.potential.Platform;
47 import ffx.potential.Utilities;
48 import ffx.potential.bonded.Atom;
49 import ffx.potential.parameters.ForceField;
50 import ffx.potential.utils.EnergyException;
51 import ffx.potential.utils.PotentialsFunctions;
52 import ffx.potential.utils.PotentialsUtils;
53 import org.apache.commons.configuration2.CompositeConfiguration;
54 import org.apache.commons.io.FilenameUtils;
55
56 import javax.annotation.Nullable;
57 import java.io.File;
58 import java.io.IOException;
59 import java.time.LocalDateTime;
60 import java.time.format.DateTimeFormatter;
61 import java.util.ArrayList;
62 import java.util.Arrays;
63 import java.util.List;
64 import java.util.logging.Level;
65 import java.util.logging.Logger;
66 import java.util.stream.Collectors;
67 import java.util.stream.IntStream;
68
69 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
70 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
71 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
72 import static java.lang.Double.isFinite;
73 import static java.lang.String.format;
74
75
76
77
78
79
80
81 public class OpenMMEnergy extends ForceFieldEnergy implements OpenMMPotential {
82
83 private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
84
85
86
87
88 private final Platform platform;
89
90
91
92 private OpenMMContext openMMContext;
93
94
95
96 private OpenMMSystem openMMSystem;
97
98
99
100 private final Atom[] atoms;
101
102
103
104 private final boolean computeDEDL;
105
106
107
108
109
110
111
112
113
114 public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
115 super(molecularAssembly, nThreads);
116
117 Crystal crystal = getCrystal();
118 int symOps = crystal.spaceGroup.getNumberOfSymOps();
119 if (symOps > 1) {
120 logger.severe(" OpenMM does not support symmetry operators.");
121 }
122
123 logger.info("\n Initializing OpenMM");
124
125 ForceField forceField = molecularAssembly.getForceField();
126 atoms = molecularAssembly.getAtomArray();
127
128
129 this.platform = requestedPlatform;
130 ffx.openmm.Platform openMMPlatform = OpenMMContext.loadPlatform(platform, forceField);
131
132
133 openMMSystem = new OpenMMSystem(this);
134 openMMSystem.addForces();
135
136
137 openMMContext = new OpenMMContext(openMMPlatform, openMMSystem, atoms);
138
139 computeDEDL = forceField.getBoolean("OMM_DUDL", false);
140 }
141
142
143
144
145
146
147
148
149
150 public static int getDefaultDevice(CompositeConfiguration props) {
151 String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
152 if (availDeviceProp == null) {
153 int nDevs = props.getInt("numCudaDevices", 1);
154 availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
155 .collect(Collectors.joining(" "));
156 }
157 availDeviceProp = availDeviceProp.trim();
158
159 String[] availDevices = availDeviceProp.split("\\s+");
160 int nDevs = availDevices.length;
161 int[] devs = new int[nDevs];
162 for (int i = 0; i < nDevs; i++) {
163 devs[i] = Integer.parseInt(availDevices[i]);
164 }
165
166 logger.info(format(" Available devices: %d.", nDevs));
167
168
169 if (nDevs == 1) {
170 return devs[0];
171 }
172
173 int index = 0;
174 try {
175 Comm world = Comm.world();
176 if (world != null) {
177 int size = world.size();
178 logger.fine(format(" Number of MPI processes %d exceeds number of available devices %d.", size, nDevs));
179
180
181 int messageLen = 100;
182 String host = world.host();
183
184 host = host.substring(0, Math.min(messageLen, host.length()));
185
186 host = format("%-100s", host);
187
188 logger.fine(format(" Host: %s", host.trim()));
189 char[] messageOut = host.toCharArray();
190 CharacterBuf out = CharacterBuf.buffer(messageOut);
191
192
193 char[][] incoming = new char[size][messageLen];
194 CharacterBuf[] in = new CharacterBuf[size];
195 for (int i = 0; i < size; i++) {
196 in[i] = CharacterBuf.buffer(incoming[i]);
197 }
198
199 try {
200 logger.fine(" AllGather for determining rank.");
201 world.allGather(out, in);
202 logger.fine(" AllGather complete.");
203 } catch (IOException ex) {
204 logger.warning(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
205 }
206 int ownIndex = -1;
207 int rank = world.rank();
208 boolean selfFound = false;
209
210 for (int i = 0; i < size; i++) {
211 String hostI = new String(incoming[i]);
212 if (hostI.equalsIgnoreCase(host)) {
213 ++ownIndex;
214 if (i == rank) {
215 selfFound = true;
216 break;
217 }
218 }
219 }
220 if (!selfFound) {
221 logger.warning(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
222 } else {
223 index = ownIndex % nDevs;
224 }
225 }
226 } catch (IllegalStateException ise) {
227
228 }
229 return devs[index];
230 }
231
232
233
234
235
236
237
238
239
240
241
242
243 @Override
244 public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
245 openMMContext.update(integratorName, timeStep, temperature, forceCreation);
246 }
247
248
249
250
251
252
253
254
255
256 @Override
257 public OpenMMState getOpenMMState(int mask) {
258 return openMMContext.getOpenMMState(mask);
259 }
260
261
262
263
264 @Override
265 public boolean destroy() {
266 boolean ffxFFEDestroy = super.destroy();
267 free();
268 logger.fine(" Destroyed the Context and OpenMMSystem.");
269 return ffxFFEDestroy;
270 }
271
272
273
274
275 @Override
276 public double energy(double[] x) {
277 return energy(x, false);
278 }
279
280
281
282
283 @Override
284 public double energy(double[] x, boolean verbose) {
285
286 if (lambdaBondedTerms) {
287 return 0.0;
288 }
289
290
291 openMMContext.update();
292
293 updateParameters(atoms);
294
295
296 unscaleCoordinates(x);
297
298 setCoordinates(x);
299
300 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
301 double e = openMMState.potentialEnergy;
302 openMMState.destroy();
303
304 if (!isFinite(e)) {
305 String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
306 logger.warning(message);
307 throw new EnergyException(message);
308 }
309
310 if (verbose) {
311 logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
312 }
313
314
315 scaleCoordinates(x);
316
317 return e;
318 }
319
320
321
322
323
324
325
326 public double energyFFX(double[] x) {
327 return super.energy(x, false);
328 }
329
330
331
332
333
334
335
336
337 public double energyFFX(double[] x, boolean verbose) {
338 return super.energy(x, verbose);
339 }
340
341
342
343
344 @Override
345 public double energyAndGradient(double[] x, double[] g) {
346 return energyAndGradient(x, g, false);
347 }
348
349
350
351
352 @Override
353 public double energyAndGradient(double[] x, double[] g, boolean verbose) {
354 if (lambdaBondedTerms) {
355 return 0.0;
356 }
357
358
359 unscaleCoordinates(x);
360
361
362 openMMContext.update();
363
364
365 setCoordinates(x);
366
367
368
369
370 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
371 double e = openMMState.potentialEnergy;
372 g = openMMState.getGradient(g);
373 openMMState.destroy();
374
375
376
377 if (!isFinite(e)) {
378 String message = format(" Energy from OpenMM was a non-finite %8g", e);
379 logger.warning(message);
380 throw new EnergyException(message);
381 }
382
383
384
385
386
387
388
389
390
391
392 if (maxDebugGradient < Double.MAX_VALUE) {
393 boolean extremeGrad = Arrays.stream(g)
394 .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
395 if (extremeGrad) {
396 File origFile = molecularAssembly.getFile();
397 String timeString = LocalDateTime.now()
398 .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
399
400 String filename = format("%s-LARGEGRAD-%s.pdb",
401 FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
402 PotentialsFunctions ef = new PotentialsUtils();
403 filename = ef.versionFile(filename);
404
405 logger.warning(
406 format(" Excessively large gradients detected; printing snapshot to file %s", filename));
407 ef.saveAsPDB(molecularAssembly, new File(filename));
408 molecularAssembly.setFile(origFile);
409 }
410 }
411
412 if (verbose) {
413 logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
414 }
415
416
417 scaleCoordinatesAndGradient(x, g);
418
419 return e;
420 }
421
422
423
424
425
426
427
428
429 public double energyAndGradientFFX(double[] x, double[] g) {
430 return super.energyAndGradient(x, g, false);
431 }
432
433
434
435
436
437
438
439
440
441 public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
442 return super.energyAndGradient(x, g, verbose);
443 }
444
445
446
447
448
449
450 @Override
451 public OpenMMContext getContext() {
452 return openMMContext;
453 }
454
455
456
457
458
459
460 public MolecularAssembly getMolecularAssembly() {
461 return molecularAssembly;
462 }
463
464
465
466
467
468
469 @Override
470 public double[] getGradient(double[] g) {
471 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
472 g = openMMState.getGradient(g);
473 openMMState.destroy();
474 return g;
475 }
476
477
478
479
480 @Override
481 public Platform getPlatform() {
482 return platform;
483 }
484
485
486
487
488
489
490 @Override
491 public OpenMMSystem getSystem() {
492 return openMMSystem;
493 }
494
495
496
497
498 @Override
499 public double getd2EdL2() {
500 return 0.0;
501 }
502
503
504
505
506 @Override
507 public double getdEdL() {
508
509 if (!lambdaTerm || !computeDEDL) {
510 return 0.0;
511 }
512
513 return FiniteDifferenceUtils.computedEdL(this, this, molecularAssembly.getForceField());
514 }
515
516
517
518
519 @Override
520 public void getdEdXdL(double[] gradients) {
521
522 }
523
524
525
526
527 @Override
528 public boolean setActiveAtoms() {
529 return openMMSystem.updateAtomMass();
530 }
531
532
533
534
535
536
537 @Override
538 public void setCoordinates(double[] x) {
539
540 super.setCoordinates(x);
541
542 int n = atoms.length * 3;
543 double[] xall = new double[n];
544 int i = 0;
545 for (Atom atom : atoms) {
546 xall[i] = atom.getX();
547 xall[i + 1] = atom.getY();
548 xall[i + 2] = atom.getZ();
549 i += 3;
550 }
551
552
553 openMMContext.setPositions(xall);
554 }
555
556
557
558
559
560
561 @Override
562 public void setVelocity(double[] v) {
563
564 super.setVelocity(v);
565
566 int n = atoms.length * 3;
567 double[] vall = new double[n];
568 double[] v3 = new double[3];
569 int i = 0;
570 for (Atom atom : atoms) {
571 atom.getVelocity(v3);
572 vall[i] = v3[0];
573 vall[i + 1] = v3[1];
574 vall[i + 2] = v3[2];
575 i += 3;
576 }
577
578
579 openMMContext.setVelocities(vall);
580 }
581
582
583
584
585 @Override
586 public void setCrystal(Crystal crystal) {
587 super.setCrystal(crystal);
588 openMMContext.setPeriodicBoxVectors(crystal);
589 }
590
591
592
593
594 @Override
595 public void setLambda(double lambda) {
596 if (!lambdaTerm) {
597 logger.fine(" Attempting to set lambda for an OpenMMEnergy with lambdaterm false.");
598 return;
599 }
600
601 super.setLambda(lambda);
602
603 if (atoms != null) {
604 List<Atom> atomList = new ArrayList<>();
605 for (Atom atom : atoms) {
606 if (atom.applyLambda()) {
607 atomList.add(atom);
608 }
609 }
610
611 updateParameters(atomList.toArray(new Atom[0]));
612 } else {
613 updateParameters(null);
614 }
615
616 }
617
618
619
620
621
622
623 @Override
624 public void updateParameters(@Nullable Atom[] atoms) {
625 if (atoms == null) {
626 atoms = this.atoms;
627 }
628 if (openMMSystem != null) {
629 openMMSystem.updateParameters(atoms);
630 }
631 }
632
633
634
635
636 private void free() {
637 if (openMMContext != null) {
638 openMMContext.free();
639 openMMContext = null;
640 }
641 if (openMMSystem != null) {
642 openMMSystem.free();
643 openMMSystem = null;
644 }
645 }
646
647 }