1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
import edu.rit.pj.Comm;
import edu.uiowa.jopenmm.OpenMMLibrary;
import edu.uiowa.jopenmm.OpenMMUtils;
import edu.uiowa.jopenmm.OpenMM_Vec3;
import ffx.crystal.Crystal;
import ffx.openmm.Context;
import ffx.openmm.Integrator;
import ffx.openmm.MinimizationReporter;
import ffx.openmm.Platform;
import ffx.openmm.State;
import ffx.openmm.StringArray;
import ffx.potential.bonded.Atom;
import ffx.potential.parameters.ForceField;

import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;

import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_KcalPerKJ;
import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_NmPerAngstrom;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_False;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_LocalEnergyMinimizer_minimize;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
import static ffx.openmm.Platform.getNumPlatforms;
import static ffx.openmm.Platform.getOpenMMVersion;
import static ffx.openmm.Platform.getPluginLoadFailures;
import static ffx.openmm.Platform.loadPluginsFromDirectory;
import static ffx.potential.ForceFieldEnergy.DEFAULT_CONSTRAINT_TOLERANCE;
import static ffx.potential.Platform.OMM;
import static ffx.potential.Platform.OMM_CUDA;
import static ffx.potential.Platform.OMM_OPENCL;
import static ffx.potential.openmm.OpenMMEnergy.getDefaultDevice;
import static ffx.potential.openmm.OpenMMIntegrator.createIntegrator;
import static java.lang.String.format;
75
76
77
78
79
80
81
82
83
84
85
86
87
88 public class OpenMMContext extends Context {
89
90 private static final Logger logger = Logger.getLogger(OpenMMContext.class.getName());
91
92
93
94
95 private final Platform openMMPlatform;
96
97
98
99 private final OpenMMSystem openMMSystem;
100
101
102
103 private Integrator openMMIntegrator;
104
105
106
107 private String integratorName = "VERLET";
108
109
110
111 private double timeStep = 0.001;
112
113
114
115 private double temperature = 298.15;
116
117
118
119 private final int enforcePBC;
120
121
122
123 private final Atom[] atoms;
124
125
126
127
128
129
130
131
132 public OpenMMContext(Platform platform, OpenMMSystem openMMSystem, Atom[] atoms) {
133 this.openMMPlatform = platform;
134 this.openMMSystem = openMMSystem;
135
136 ForceField forceField = openMMSystem.getForceField();
137 boolean aperiodic = openMMSystem.getCrystal().aperiodic();
138 boolean pbcEnforced = forceField.getBoolean("ENFORCE_PBC", !aperiodic);
139 enforcePBC = pbcEnforced ? OpenMM_True : OpenMM_False;
140
141 this.atoms = atoms;
142 update(integratorName, timeStep, temperature, true);
143 }
144
145
146
147
148
149
150
151
152
153 public void update(String integratorName, double timeStep, double temperature, boolean forceCreation) {
154
155 if (hasContextPointer() && !forceCreation) {
156 if (this.temperature == temperature && this.timeStep == timeStep
157 && this.integratorName.equalsIgnoreCase(integratorName)) {
158
159 return;
160 }
161 }
162
163 this.integratorName = integratorName;
164 this.timeStep = timeStep;
165 this.temperature = temperature;
166
167 logger.info("\n Updating OpenMM Context");
168
169
170 if (openMMIntegrator != null) {
171 openMMIntegrator.destroy();
172 }
173 openMMIntegrator = createIntegrator(integratorName, timeStep, temperature, openMMSystem);
174
175
176
177
178
179
180
181
182
183 updateContext(openMMSystem, openMMIntegrator, openMMPlatform);
184
185
186
187
188
189
190
191 int nVar = openMMSystem.getNumberOfVariables();
192 double[] x = new double[nVar];
193 double[] v = new double[nVar];
194 double[] vel3 = new double[3];
195 int index = 0;
196 for (Atom a : atoms) {
197 if (a.isActive()) {
198 a.getVelocity(vel3);
199
200 x[index] = a.getX();
201 v[index++] = vel3[0];
202
203 x[index] = a.getY();
204 v[index++] = vel3[1];
205
206 x[index] = a.getZ();
207 v[index++] = vel3[2];
208 }
209 }
210
211
212 Crystal crystal = openMMSystem.getCrystal();
213 setPeriodicBoxVectors(crystal);
214
215
216 setPositions(x);
217
218
219 setVelocities(v);
220
221
222 applyConstraints(DEFAULT_CONSTRAINT_TOLERANCE);
223
224
225
226 OpenMMState openMMState = getOpenMMState(OpenMM_State_Positions | OpenMM_State_Velocities);
227 openMMState.getPositions(x);
228 openMMState.getVelocities(v);
229 openMMState.destroy();
230 }
231
232
233
234
235 public void update() {
236 if (!hasContextPointer()) {
237 logger.info(" Delayed creation of OpenMM Context.");
238 update(integratorName, timeStep, temperature, true);
239 }
240 }
241
242
243
244
245
246
247
248 public OpenMMState getOpenMMState(int mask) {
249 State state = getState(mask, enforcePBC);
250 return new OpenMMState(state.getPointer(), atoms, openMMSystem.getNumberOfVariables());
251 }
252
253
254
255
256
257
258 public void integrate(int numSteps) {
259 openMMIntegrator.step(numSteps);
260 }
261
262
263
264
265
266
267
268 public void optimize(double eps, int maxIterations) {
269
270 MinimizationReporter reporter = new MinimizationReporter();
271 OpenMM_LocalEnergyMinimizer_minimize(getPointer(), eps / (OpenMM_NmPerAngstrom * OpenMM_KcalPerKJ),
272 maxIterations, reporter.getPointer());
273 reporter.destroy();
274 }
275
276
277
278
279
280
281 public void setPositions(double[] x) {
282 double[] allPositions = new double[3 * atoms.length];
283 double[] d = new double[3];
284 int index = 0;
285 for (Atom a : atoms) {
286 if (a.isActive()) {
287 a.moveTo(x[index], x[index + 1], x[index + 2]);
288 }
289 a.getXYZ(d);
290 allPositions[index] = d[0] * OpenMM_NmPerAngstrom;
291 allPositions[index + 1] = d[1] * OpenMM_NmPerAngstrom;
292 allPositions[index + 2] = d[2] * OpenMM_NmPerAngstrom;
293 index += 3;
294 }
295 super.setPositions(allPositions);
296 }
297
298
299
300
301
302
303 public void setVelocities(double[] v) {
304 double[] allVelocities = new double[3 * atoms.length];
305 double[] velocity = new double[3];
306 int index = 0;
307 for (Atom a : atoms) {
308 if (a.isActive()) {
309 a.setVelocity(v[index], v[index + 1], v[index + 2]);
310 } else {
311
312 a.setVelocity(0.0, 0.0, 0.0);
313 }
314 a.getVelocity(velocity);
315 allVelocities[index] = velocity[0] * OpenMM_NmPerAngstrom;
316 allVelocities[index + 1] = velocity[1] * OpenMM_NmPerAngstrom;
317 allVelocities[index + 2] = velocity[2] * OpenMM_NmPerAngstrom;
318 index += 3;
319 }
320 super.setVelocities(allVelocities);
321 }
322
323
324
325
326
327
328 public void setPeriodicBoxVectors(Crystal crystal) {
329 if (!crystal.aperiodic()) {
330 OpenMM_Vec3 a = new OpenMM_Vec3();
331 OpenMM_Vec3 b = new OpenMM_Vec3();
332 OpenMM_Vec3 c = new OpenMM_Vec3();
333 double[][] Ai = crystal.Ai;
334 a.x = Ai[0][0] * OpenMM_NmPerAngstrom;
335 a.y = Ai[0][1] * OpenMM_NmPerAngstrom;
336 a.z = Ai[0][2] * OpenMM_NmPerAngstrom;
337 b.x = Ai[1][0] * OpenMM_NmPerAngstrom;
338 b.y = Ai[1][1] * OpenMM_NmPerAngstrom;
339 b.z = Ai[1][2] * OpenMM_NmPerAngstrom;
340 c.x = Ai[2][0] * OpenMM_NmPerAngstrom;
341 c.y = Ai[2][1] * OpenMM_NmPerAngstrom;
342 c.z = Ai[2][2] * OpenMM_NmPerAngstrom;
343 setPeriodicBoxVectors(a, b, c);
344 }
345 }
346
347
348
349
350 @Override
351 public String toString() {
352 return format(
353 " OpenMM context with integrator %s, timestep %9.3g fsec, temperature %9.3g K",
354 integratorName, timeStep, temperature);
355 }
356
357
358
359
360
361
362
363
364 public static Platform loadPlatform(ffx.potential.Platform requestedPlatform, ForceField forceField) {
365
366 OpenMMUtils.init();
367
368 logger.log(Level.INFO, " Loaded from:\n {0}", OpenMMLibrary.JNA_NATIVE_LIB.toString());
369
370
371 logger.log(Level.INFO, " Version: {0}", getOpenMMVersion());
372
373
374 String libDirectory = OpenMMUtils.getLibDirectory();
375 logger.log(Level.FINE, " Lib Directory: {0}", libDirectory);
376
377 StringArray libs = loadPluginsFromDirectory(libDirectory);
378 int numLibs = libs.getSize();
379 logger.log(Level.FINE, " Number of libraries: {0}", numLibs);
380 for (int i = 0; i < numLibs; i++) {
381 logger.log(Level.FINE, " Library: {0}", libs.get(i));
382 }
383 libs.destroy();
384
385
386 String pluginDirectory = OpenMMUtils.getPluginDirectory();
387 logger.log(Level.INFO, "\n Plugin Directory: {0}", pluginDirectory);
388
389 StringArray plugins = loadPluginsFromDirectory(pluginDirectory);
390 int numPlugins = plugins.getSize();
391 logger.log(Level.INFO, " Number of Plugins: {0}", numPlugins);
392 boolean cuda = false;
393 boolean opencl = false;
394 for (int i = 0; i < numPlugins; i++) {
395 String pluginString = plugins.get(i);
396 logger.log(Level.INFO, " Plugin: {0}", pluginString);
397 if (pluginString != null) {
398 pluginString = pluginString.toUpperCase();
399 boolean amoebaCudaAvailable = pluginString.contains("AMOEBACUDA");
400 if (amoebaCudaAvailable) {
401 cuda = true;
402 }
403 boolean amoebaOpenCLAvailable = pluginString.contains("AMOEBAOPENCL");
404 if (amoebaOpenCLAvailable) {
405 opencl = true;
406 }
407 }
408 }
409 plugins.destroy();
410
411 int numPlatforms = getNumPlatforms();
412 logger.log(Level.INFO, " Number of Platforms: {0}", numPlatforms);
413
414 if (requestedPlatform == OMM_CUDA && !cuda) {
415 logger.severe(" The OMM_CUDA platform was requested, but is not available.");
416 }
417
418 if (requestedPlatform == ffx.potential.Platform.OMM_OPENCL && !opencl) {
419 logger.severe(" The OMM_OPENCL platform was requested, but is not available.");
420 }
421
422
423 if (logger.isLoggable(Level.FINE)) {
424 StringArray pluginFailures = getPluginLoadFailures();
425 int numFailures = pluginFailures.getSize();
426 for (int i = 0; i < numFailures; i++) {
427 logger.log(Level.FINE, " Plugin load failure: {0}", pluginFailures.get(i));
428 }
429 pluginFailures.destroy();
430 }
431
432 String defaultPrecision = "mixed";
433 String precision = forceField.getString("PRECISION", defaultPrecision).toLowerCase();
434 precision = precision.replace("-precision", "");
435 switch (precision) {
436 case "double", "mixed", "single" -> logger.info(format(" Precision level: %s", precision));
437 default -> {
438 logger.info(format(" Could not interpret precision level %s, defaulting to %s", precision, defaultPrecision));
439 precision = defaultPrecision;
440 }
441 }
442
443
444 Platform openMMPlatform;
445 if (cuda && (requestedPlatform == OMM_CUDA || requestedPlatform == OMM)) {
446
447 int defaultDevice = getDefaultDevice(forceField.getProperties());
448 openMMPlatform = new Platform("CUDA");
449
450 int deviceID = forceField.getInteger("CUDA_DEVICE", defaultDevice);
451 deviceID = forceField.getInteger("DeviceIndex", deviceID);
452 String deviceIDString = Integer.toString(deviceID);
453 openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
454 openMMPlatform.setPropertyDefaultValue("Precision", precision);
455 String name = openMMPlatform.getName();
456 logger.info(format(" Platform: %s (Device Index %d)", name, deviceID));
457 } else if (opencl && (requestedPlatform == OMM_OPENCL || requestedPlatform == OMM)) {
458
459 int defaultDevice = getDefaultDevice(forceField.getProperties());
460 openMMPlatform = new Platform("OpenCL");
461 int deviceID = forceField.getInteger("DeviceIndex", defaultDevice);
462 String deviceIDString = Integer.toString(deviceID);
463 openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
464 int openCLPlatformIndex = forceField.getInteger("OpenCLPlatformIndex", 0);
465 String openCLPlatformIndexString = Integer.toString(openCLPlatformIndex);
466 openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
467 openMMPlatform.setPropertyDefaultValue("OpenCLPlatformIndex", openCLPlatformIndexString);
468 openMMPlatform.setPropertyDefaultValue("Precision", precision);
469 String name = openMMPlatform.getName();
470 logger.info(format(" Platform: %s (Platform Index %d, Device Index %d)",
471 name, openCLPlatformIndex, deviceID));
472 } else {
473
474 openMMPlatform = new Platform("Reference");
475 String name = openMMPlatform.getName();
476 logger.info(format(" Platform: %s", name));
477 }
478
479 try {
480 Comm world = Comm.world();
481 if (world != null) {
482 logger.info(format(" Running on host %s, rank %d", world.host(), world.rank()));
483 }
484 } catch (IllegalStateException illegalStateException) {
485 logger.fine(" Could not find the world communicator!");
486 }
487
488 return openMMPlatform;
489 }
490
491
492
493
494 public void free() {
495 if (openMMIntegrator != null) {
496 openMMIntegrator.destroy();
497 }
498 destroy();
499 }
500 }