1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
40 import edu.rit.mp.CharacterBuf;
41 import edu.rit.pj.Comm;
42 import ffx.crystal.Crystal;
43 import ffx.potential.FiniteDifferenceUtils;
44 import ffx.potential.ForceFieldEnergy;
45 import ffx.potential.MolecularAssembly;
46 import ffx.potential.Platform;
47 import ffx.potential.Utilities;
48 import ffx.potential.bonded.Atom;
49 import ffx.potential.parameters.ForceField;
50 import ffx.potential.utils.EnergyException;
51 import ffx.potential.utils.PotentialsFunctions;
52 import ffx.potential.utils.PotentialsUtils;
53 import org.apache.commons.configuration2.CompositeConfiguration;
54 import org.apache.commons.io.FilenameUtils;
55
56 import javax.annotation.Nullable;
57 import java.io.File;
58 import java.io.IOException;
59 import java.time.LocalDateTime;
60 import java.time.format.DateTimeFormatter;
61 import java.util.ArrayList;
62 import java.util.Arrays;
63 import java.util.List;
64 import java.util.logging.Level;
65 import java.util.logging.Logger;
66 import java.util.stream.Collectors;
67 import java.util.stream.IntStream;
68
69 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
70 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
71 import static java.lang.Double.isFinite;
72 import static java.lang.String.format;
73
74
75
76
77
78
79
80 @SuppressWarnings("deprecation")
81 public class OpenMMEnergy extends ForceFieldEnergy {
82
83 private static final Logger logger = Logger.getLogger(OpenMMEnergy.class.getName());
84
85
86
87
88 private final Platform platform;
89
90
91
92 private OpenMMContext openMMContext;
93
94
95
96 private OpenMMSystem openMMSystem;
97
98
99
100 private final Atom[] atoms;
101
102
103
104 private final boolean computeDEDL;
105
106
107
108
109
110
111
112
113
114 public OpenMMEnergy(MolecularAssembly molecularAssembly, Platform requestedPlatform, int nThreads) {
115 super(molecularAssembly, nThreads);
116
117 Crystal crystal = getCrystal();
118 int symOps = crystal.spaceGroup.getNumberOfSymOps();
119 if (symOps > 1) {
120 logger.info("");
121 logger.severe(" OpenMM does not support symmetry operators.");
122 }
123
124 logger.info("\n Initializing OpenMM");
125
126 ForceField forceField = molecularAssembly.getForceField();
127 atoms = molecularAssembly.getAtomArray();
128
129
130 this.platform = requestedPlatform;
131 ffx.openmm.Platform openMMPlatform = OpenMMContext.loadPlatform(platform, forceField);
132
133
134 openMMSystem = new OpenMMSystem(this);
135 openMMSystem.addForces();
136
137
138 openMMContext = new OpenMMContext(openMMPlatform, openMMSystem, atoms);
139
140 computeDEDL = forceField.getBoolean("OMM_DUDL", false);
141 }
142
143
144
145
146
147
148
149
150
151 public static int getDefaultDevice(CompositeConfiguration props) {
152 String availDeviceProp = props.getString("availableDevices", props.getString("CUDA_DEVICES"));
153 if (availDeviceProp == null) {
154 int nDevs = props.getInt("numCudaDevices", 1);
155 availDeviceProp = IntStream.range(0, nDevs).mapToObj(Integer::toString)
156 .collect(Collectors.joining(" "));
157 }
158 availDeviceProp = availDeviceProp.trim();
159
160 String[] availDevices = availDeviceProp.split("\\s+");
161 int nDevs = availDevices.length;
162 int[] devs = new int[nDevs];
163 for (int i = 0; i < nDevs; i++) {
164 devs[i] = Integer.parseInt(availDevices[i]);
165 }
166
167 logger.info(format(" Available devices: %d.", nDevs));
168
169
170 if (nDevs == 1) {
171 return devs[0];
172 }
173
174 int index = 0;
175 try {
176 Comm world = Comm.world();
177 if (world != null) {
178 int size = world.size();
179 logger.fine(format(" Number of MPI processes %d exceeds number of available devices %d.", size, nDevs));
180
181
182 int messageLen = 100;
183 String host = world.host();
184
185 host = host.substring(0, Math.min(messageLen, host.length()));
186
187 host = format("%-100s", host);
188
189 logger.fine(format(" Host: %s", host.trim()));
190 char[] messageOut = host.toCharArray();
191 CharacterBuf out = CharacterBuf.buffer(messageOut);
192
193
194 char[][] incoming = new char[size][messageLen];
195 CharacterBuf[] in = new CharacterBuf[size];
196 for (int i = 0; i < size; i++) {
197 in[i] = CharacterBuf.buffer(incoming[i]);
198 }
199
200 try {
201 logger.fine(" AllGather for determining rank.");
202 world.allGather(out, in);
203 logger.fine(" AllGather complete.");
204 } catch (IOException ex) {
205 logger.warning(format(" Failure at the allGather step for determining rank: %s\n%s", ex, Utilities.stackTraceToString(ex)));
206 }
207 int ownIndex = -1;
208 int rank = world.rank();
209 boolean selfFound = false;
210
211 for (int i = 0; i < size; i++) {
212 String hostI = new String(incoming[i]);
213 if (hostI.equalsIgnoreCase(host)) {
214 ++ownIndex;
215 if (i == rank) {
216 selfFound = true;
217 break;
218 }
219 }
220 }
221 if (!selfFound) {
222 logger.warning(format(" Rank %d: Could not find any incoming host messages matching self %s!", rank, host.trim()));
223 } else {
224 index = ownIndex % nDevs;
225 }
226 }
227 } catch (IllegalStateException ise) {
228
229 }
230 return devs[index];
231 }
232
233
234
235
236
237
238
239
240
241
242
243
244 public void updateContext(String integratorName, double timeStep, double temperature, boolean forceCreation) {
245 openMMContext.update(integratorName, timeStep, temperature, forceCreation);
246 }
247
248
249
250
251
252
253
254
255
256 public OpenMMState getOpenMMState(int mask) {
257 return openMMContext.getOpenMMState(mask);
258 }
259
260
261
262
263 @Override
264 public boolean destroy() {
265 boolean ffxFFEDestroy = super.destroy();
266 free();
267 logger.fine(" Destroyed the Context, Integrator, and OpenMMSystem.");
268 return ffxFFEDestroy;
269 }
270
271
272
273
274 @Override
275 public double energy(double[] x) {
276 return energy(x, false);
277 }
278
279
280
281
282 @Override
283 public double energy(double[] x, boolean verbose) {
284
285 if (lambdaBondedTerms) {
286 return 0.0;
287 }
288
289
290 openMMContext.update();
291
292 updateParameters(atoms);
293
294
295 unscaleCoordinates(x);
296
297 setCoordinates(x);
298
299 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy);
300 double e = openMMState.potentialEnergy;
301 openMMState.destroy();
302
303 if (!isFinite(e)) {
304 String message = String.format(" Energy from OpenMM was a non-finite %8g", e);
305 logger.warning(message);
306 throw new EnergyException(message);
307 }
308
309 if (verbose) {
310 logger.log(Level.INFO, String.format("\n OpenMM Energy: %14.10g", e));
311 }
312
313
314 scaleCoordinates(x);
315
316 return e;
317 }
318
319
320
321
322 @Override
323 public double energyAndGradient(double[] x, double[] g) {
324 return energyAndGradient(x, g, false);
325 }
326
327
328
329
330 @Override
331 public double energyAndGradient(double[] x, double[] g, boolean verbose) {
332 if (lambdaBondedTerms) {
333 return 0.0;
334 }
335
336
337 unscaleCoordinates(x);
338
339
340 openMMContext.update();
341
342
343 setCoordinates(x);
344
345
346
347
348 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Energy | OpenMM_State_Forces);
349 double e = openMMState.potentialEnergy;
350 g = openMMState.getGradient(g);
351 openMMState.destroy();
352
353
354
355 if (!isFinite(e)) {
356 String message = format(" Energy from OpenMM was a non-finite %8g", e);
357 logger.warning(message);
358 throw new EnergyException(message);
359 }
360
361
362
363
364
365
366
367
368
369
370 if (maxDebugGradient < Double.MAX_VALUE) {
371 boolean extremeGrad = Arrays.stream(g)
372 .anyMatch((double gi) -> (gi > maxDebugGradient || gi < -maxDebugGradient));
373 if (extremeGrad) {
374 File origFile = molecularAssembly.getFile();
375 String timeString = LocalDateTime.now()
376 .format(DateTimeFormatter.ofPattern("yyyy_MM_dd-HH_mm_ss"));
377
378 String filename = format("%s-LARGEGRAD-%s.pdb",
379 FilenameUtils.removeExtension(molecularAssembly.getFile().getName()), timeString);
380 PotentialsFunctions ef = new PotentialsUtils();
381 filename = ef.versionFile(filename);
382
383 logger.warning(
384 format(" Excessively large gradients detected; printing snapshot to file %s", filename));
385 ef.saveAsPDB(molecularAssembly, new File(filename));
386 molecularAssembly.setFile(origFile);
387 }
388 }
389
390 if (verbose) {
391 logger.log(Level.INFO, format("\n OpenMM Energy: %14.10g", e));
392 }
393
394
395 scaleCoordinatesAndGradient(x, g);
396
397 return e;
398 }
399
400
401
402
403
404
405
406
407 public double energyAndGradientFFX(double[] x, double[] g) {
408 return super.energyAndGradient(x, g, false);
409 }
410
411
412
413
414
415
416
417
418
419 public double energyAndGradientFFX(double[] x, double[] g, boolean verbose) {
420 return super.energyAndGradient(x, g, verbose);
421 }
422
423
424
425
426
427
428
429 public double energyFFX(double[] x) {
430 return super.energy(x, false);
431 }
432
433
434
435
436
437
438
439
440 public double energyFFX(double[] x, boolean verbose) {
441 return super.energy(x, verbose);
442 }
443
444
445
446
447
448
449 public OpenMMContext getContext() {
450 return openMMContext;
451 }
452
453
454
455
456
457
458 public MolecularAssembly getMolecularAssembly() {
459 return molecularAssembly;
460 }
461
462
463
464
465
466
467 public void setLambdaTerm(boolean lambdaTerm) {
468 this.lambdaTerm = lambdaTerm;
469 }
470
471
472
473
474
475
476 public boolean getLambdaTerm() {
477 return lambdaTerm;
478 }
479
480
481
482
483
484
485 @Override
486 public double[] getGradient(double[] g) {
487 OpenMMState openMMState = openMMContext.getOpenMMState(OpenMM_State_Forces);
488 g = openMMState.getGradient(g);
489 openMMState.destroy();
490 return g;
491 }
492
493
494
495
496 @Override
497 public Platform getPlatform() {
498 return platform;
499 }
500
501
502
503
504
505
506 public OpenMMSystem getSystem() {
507 return openMMSystem;
508 }
509
510
511
512
513 @Override
514 public double getd2EdL2() {
515 return 0.0;
516 }
517
518
519
520
521 @Override
522 public double getdEdL() {
523
524 if (!lambdaTerm || !computeDEDL) {
525 return 0.0;
526 }
527
528 return FiniteDifferenceUtils.computedEdL(this, this, molecularAssembly.getForceField());
529 }
530
531
532
533
534 @Override
535 public void getdEdXdL(double[] gradients) {
536
537 }
538
539
540
541
542 public void setActiveAtoms() {
543 openMMSystem.updateAtomMass();
544
545
546 }
547
548
549
550
551
552
553 @Override
554 public void setCoordinates(double[] x) {
555
556 openMMContext.setPositions(x);
557 }
558
559
560
561
562 @Override
563 public void setCrystal(Crystal crystal) {
564 super.setCrystal(crystal);
565 openMMContext.setPeriodicBoxVectors(crystal);
566 }
567
568
569
570
571 @Override
572 public void setLambda(double lambda) {
573 if (!lambdaTerm) {
574 logger.fine(" Attempting to set lambda for an OpenMMEnergy with lambdaterm false.");
575 return;
576 }
577
578 super.setLambda(lambda);
579
580 if (atoms != null) {
581 List<Atom> atomList = new ArrayList<>();
582 for (Atom atom : atoms) {
583 if (atom.applyLambda()) {
584 atomList.add(atom);
585 }
586 }
587
588 updateParameters(atomList.toArray(new Atom[0]));
589 } else {
590 updateParameters(null);
591 }
592
593 }
594
595
596
597
598
599
600 public void updateParameters(@Nullable Atom[] atoms) {
601 if (atoms == null) {
602 atoms = this.atoms;
603 }
604 if (openMMSystem != null) {
605 openMMSystem.updateParameters(atoms);
606 }
607 }
608
609
610
611
612 private void free() {
613 if (openMMContext != null) {
614 openMMContext.free();
615 openMMContext = null;
616 }
617 if (openMMSystem != null) {
618 openMMSystem.free();
619 openMMSystem = null;
620 }
621 }
622
623 }