1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
import edu.rit.pj.Comm;
import edu.uiowa.jopenmm.OpenMMLibrary;
import edu.uiowa.jopenmm.OpenMMUtils;
import edu.uiowa.jopenmm.OpenMM_Vec3;
import ffx.crystal.Crystal;
import ffx.numerics.Potential;
import ffx.openmm.Context;
import ffx.openmm.Integrator;
import ffx.openmm.MinimizationReporter;
import ffx.openmm.Platform;
import ffx.openmm.State;
import ffx.openmm.StringArray;
import ffx.potential.bonded.Atom;
import ffx.potential.parameters.ForceField;

import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;

import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_KcalPerKJ;
import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_NmPerAngstrom;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_False;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_Boolean.OpenMM_True;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_LocalEnergyMinimizer_minimize;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
import static ffx.openmm.Platform.getNumPlatforms;
import static ffx.openmm.Platform.getOpenMMVersion;
import static ffx.openmm.Platform.getPluginLoadFailures;
import static ffx.openmm.Platform.loadPluginsFromDirectory;
import static ffx.potential.ForceFieldEnergy.DEFAULT_CONSTRAINT_TOLERANCE;
import static ffx.potential.Platform.OMM;
import static ffx.potential.Platform.OMM_CUDA;
import static ffx.potential.Platform.OMM_OPENCL;
import static ffx.potential.openmm.OpenMMEnergy.getDefaultDevice;
import static ffx.potential.openmm.OpenMMIntegrator.createIntegrator;
import static java.lang.String.format;
76
77
78
79
80
81
82
83
84
85
86
87
88
89 public class OpenMMContext extends Context {
90
91 private static final Logger logger = Logger.getLogger(OpenMMContext.class.getName());
92
93
94
95
96 private final OpenMMSystem openMMSystem;
97
98
99
100 private String integratorName = "VERLET";
101
102
103
104 private double timeStep = 0.001;
105
106
107
108 private double temperature = 298.15;
109
110
111
112 private final int enforcePBC;
113
114
115
116 private final Atom[] atoms;
117
118
119
120
121
122
123
124
125 public OpenMMContext(Platform platform, OpenMMSystem openMMSystem, Atom[] atoms) {
126 super(openMMSystem, createIntegrator("VERLET", 0.001, 298.15, openMMSystem), platform);
127 this.openMMSystem = openMMSystem;
128 this.atoms = atoms;
129
130 ForceField forceField = openMMSystem.getForceField();
131 boolean aperiodic = openMMSystem.getCrystal().aperiodic();
132 boolean pbcEnforced = forceField.getBoolean("ENFORCE_PBC", !aperiodic);
133 enforcePBC = pbcEnforced ? OpenMM_True : OpenMM_False;
134 }
135
136
137
138
139
140
141
142
143
144 public void update(String integratorName, double timeStep, double temperature, boolean forceCreation) {
145
146 if (hasContextPointer() && !forceCreation) {
147 if (this.temperature == temperature && this.timeStep == timeStep
148 && this.integratorName.equalsIgnoreCase(integratorName)) {
149
150 return;
151 }
152 }
153
154 this.integratorName = integratorName;
155 this.timeStep = timeStep;
156 this.temperature = temperature;
157
158 logger.info("\n Updating OpenMM Context");
159
160
161
162
163
164
165
166
167
168 Integrator newIntegrator = createIntegrator(integratorName, timeStep, temperature, openMMSystem);
169 Platform newPlatform = new Platform(platform.getName());
170 updateContext(openMMSystem, newIntegrator, newPlatform);
171
172
173
174
175
176
177
178 int nVar = atoms.length * 3;
179 double[] x = new double[nVar];
180 double[] v = new double[nVar];
181 double[] vel3 = new double[3];
182 int index = 0;
183 for (Atom a : atoms) {
184 a.getVelocity(vel3);
185
186 x[index] = a.getX();
187 v[index++] = vel3[0];
188
189 x[index] = a.getY();
190 v[index++] = vel3[1];
191
192 x[index] = a.getZ();
193 v[index++] = vel3[2];
194 }
195
196
197 Crystal crystal = openMMSystem.getCrystal();
198 setPeriodicBoxVectors(crystal);
199
200
201 setPositions(x);
202
203
204 setVelocities(v);
205
206
207 applyConstraints(DEFAULT_CONSTRAINT_TOLERANCE);
208
209
210
211 OpenMMState openMMState = getOpenMMState(OpenMM_State_Positions | OpenMM_State_Velocities);
212 Potential energy = openMMSystem.getPotential();
213 energy.setCoordinates(openMMState.getActivePositions(null, atoms));
214 energy.setVelocity(openMMState.getActiveVelocities(null, atoms));
215 openMMState.destroy();
216 }
217
218
219
220
221 public void update() {
222 if (!hasContextPointer()) {
223 logger.info(" Delayed creation of OpenMM Context.");
224 update(integratorName, timeStep, temperature, true);
225 }
226 }
227
228
229
230
231
232
233
234 public OpenMMState getOpenMMState(int mask) {
235 State state = getState(mask, enforcePBC);
236 return new OpenMMState(state.getPointer());
237 }
238
239
240
241
242
243
244 public void integrate(int numSteps) {
245 Integrator integrator = getIntegrator();
246 integrator.step(numSteps);
247 }
248
249
250
251
252
253
254
255 public void optimize(double eps, int maxIterations) {
256
257 MinimizationReporter reporter = new MinimizationReporter();
258 OpenMM_LocalEnergyMinimizer_minimize(getPointer(), eps / (OpenMM_NmPerAngstrom * OpenMM_KcalPerKJ),
259 maxIterations, reporter.getPointer());
260 reporter.destroy();
261 }
262
263
264
265
266
267
268 @Override
269 public void setPositions(double[] x) {
270 long time = -System.nanoTime();
271 int n = x.length;
272 double[] xn = new double[n];
273 for (int i = 0; i < n; i++) {
274
275 xn[i] = x[i] * OpenMM_NmPerAngstrom;
276 }
277 super.setPositions(xn);
278 time += System.nanoTime();
279 if (logger.isLoggable(Level.FINEST)) {
280 logger.finest(format(" Set OpenMM positions %9.6f (msec)", time * 1e-6));
281 }
282 }
283
284
285
286
287
288
289 @Override
290 public void setVelocities(double[] v) {
291 long time = -System.nanoTime();
292 int n = v.length;
293 double[] vn = new double[n];
294 for (int i = 0; i < n; i++) {
295
296 vn[i] = v[i] * OpenMM_NmPerAngstrom;
297 }
298 super.setVelocities(vn);
299 time += System.nanoTime();
300 if (logger.isLoggable(Level.FINEST)) {
301 logger.finest(format(" Set OpenMM velocities %9.6f (msec)", time * 1e-6));
302 }
303 }
304
305
306
307
308
309
310 public void setPeriodicBoxVectors(Crystal crystal) {
311 if (!crystal.aperiodic()) {
312 OpenMM_Vec3 a = new OpenMM_Vec3();
313 OpenMM_Vec3 b = new OpenMM_Vec3();
314 OpenMM_Vec3 c = new OpenMM_Vec3();
315 double[][] Ai = crystal.Ai;
316 a.x = Ai[0][0] * OpenMM_NmPerAngstrom;
317 a.y = Ai[0][1] * OpenMM_NmPerAngstrom;
318 a.z = Ai[0][2] * OpenMM_NmPerAngstrom;
319 b.x = Ai[1][0] * OpenMM_NmPerAngstrom;
320 b.y = Ai[1][1] * OpenMM_NmPerAngstrom;
321 b.z = Ai[1][2] * OpenMM_NmPerAngstrom;
322 c.x = Ai[2][0] * OpenMM_NmPerAngstrom;
323 c.y = Ai[2][1] * OpenMM_NmPerAngstrom;
324 c.z = Ai[2][2] * OpenMM_NmPerAngstrom;
325 setPeriodicBoxVectors(a, b, c);
326 }
327 }
328
329
330
331
332 @Override
333 public String toString() {
334 return format(
335 " OpenMM context with integrator %s, timestep %9.3g fsec, temperature %9.3g K",
336 integratorName, timeStep, temperature);
337 }
338
339
340
341
342
343
344
345
346 public static Platform loadPlatform(ffx.potential.Platform requestedPlatform, ForceField forceField) {
347
348 OpenMMUtils.init();
349
350
351 logger.log(Level.INFO, " Loaded from:\n {0}", OpenMMLibrary.JNA_NATIVE_LIB.toString());
352
353
354 logger.log(Level.INFO, " Version: {0}", getOpenMMVersion());
355
356
357 String libDirectory = OpenMMUtils.getLibDirectory();
358 logger.log(Level.FINE, " Lib Directory: {0}", libDirectory);
359
360 StringArray libs = loadPluginsFromDirectory(libDirectory);
361 int numLibs = libs.getSize();
362 logger.log(Level.FINE, " Number of libraries: {0}", numLibs);
363 for (int i = 0; i < numLibs; i++) {
364 logger.log(Level.FINE, " Library: {0}", libs.get(i));
365 }
366 libs.destroy();
367
368
369 String pluginDirectory = OpenMMUtils.getPluginDirectory();
370 logger.log(Level.INFO, "\n Plugin Directory: {0}", pluginDirectory);
371
372 StringArray plugins = loadPluginsFromDirectory(pluginDirectory);
373 int numPlugins = plugins.getSize();
374 logger.log(Level.INFO, " Number of Plugins: {0}", numPlugins);
375 boolean cuda = false;
376 boolean opencl = false;
377 for (int i = 0; i < numPlugins; i++) {
378 String pluginString = plugins.get(i);
379 logger.log(Level.INFO, " Plugin: {0}", pluginString);
380 if (pluginString != null) {
381 pluginString = pluginString.toUpperCase();
382 boolean amoebaCudaAvailable = pluginString.contains("AMOEBACUDA");
383 if (amoebaCudaAvailable) {
384 cuda = true;
385 }
386 boolean amoebaOpenCLAvailable = pluginString.contains("AMOEBAOPENCL");
387 if (amoebaOpenCLAvailable) {
388 opencl = true;
389 }
390 }
391 }
392 plugins.destroy();
393
394 int numPlatforms = getNumPlatforms();
395 logger.log(Level.INFO, " Number of Platforms: {0}", numPlatforms);
396
397 if (requestedPlatform == OMM_CUDA && !cuda) {
398 logger.severe(" The OMM_CUDA platform was requested, but is not available.");
399 }
400
401 if (requestedPlatform == ffx.potential.Platform.OMM_OPENCL && !opencl) {
402 logger.severe(" The OMM_OPENCL platform was requested, but is not available.");
403 }
404
405
406 if (logger.isLoggable(Level.FINE)) {
407 StringArray pluginFailures = getPluginLoadFailures();
408 int numFailures = pluginFailures.getSize();
409 for (int i = 0; i < numFailures; i++) {
410 logger.log(Level.FINE, " Plugin load failure: {0}", pluginFailures.get(i));
411 }
412 pluginFailures.destroy();
413 }
414
415 String defaultPrecision = "mixed";
416 String precision = forceField.getString("PRECISION", defaultPrecision).toLowerCase();
417 precision = precision.replace("-precision", "");
418 switch (precision) {
419 case "double", "mixed", "single" -> logger.info(format(" Precision level: %s", precision));
420 default -> {
421 logger.info(format(" Could not interpret precision level %s, defaulting to %s", precision, defaultPrecision));
422 precision = defaultPrecision;
423 }
424 }
425
426
427 Platform openMMPlatform;
428 if (cuda && (requestedPlatform == OMM_CUDA || requestedPlatform == OMM)) {
429
430 int defaultDevice = getDefaultDevice(forceField.getProperties());
431 openMMPlatform = new Platform("CUDA");
432
433 int deviceID = forceField.getInteger("CUDA_DEVICE", defaultDevice);
434 deviceID = forceField.getInteger("DeviceIndex", deviceID);
435 String deviceIDString = Integer.toString(deviceID);
436 openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
437 openMMPlatform.setPropertyDefaultValue("Precision", precision);
438 String name = openMMPlatform.getName();
439 logger.info(format(" Platform: %s (Device Index %d)", name, deviceID));
440 } else if (opencl && (requestedPlatform == OMM_OPENCL || requestedPlatform == OMM)) {
441
442 int defaultDevice = getDefaultDevice(forceField.getProperties());
443 openMMPlatform = new Platform("OpenCL");
444 int deviceID = forceField.getInteger("DeviceIndex", defaultDevice);
445 String deviceIDString = Integer.toString(deviceID);
446 openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
447 int openCLPlatformIndex = forceField.getInteger("OpenCLPlatformIndex", 0);
448 String openCLPlatformIndexString = Integer.toString(openCLPlatformIndex);
449 openMMPlatform.setPropertyDefaultValue("DeviceIndex", deviceIDString);
450 openMMPlatform.setPropertyDefaultValue("OpenCLPlatformIndex", openCLPlatformIndexString);
451 openMMPlatform.setPropertyDefaultValue("Precision", precision);
452 String name = openMMPlatform.getName();
453 logger.info(format(" Platform: %s (Platform Index %d, Device Index %d)",
454 name, openCLPlatformIndex, deviceID));
455 } else {
456
457 openMMPlatform = new Platform("Reference");
458 String name = openMMPlatform.getName();
459 logger.info(format(" Platform: %s", name));
460 }
461
462 try {
463 Comm world = Comm.world();
464 if (world != null) {
465 logger.info(format(" Running on host %s, rank %d", world.host(), world.rank()));
466 }
467 } catch (IllegalStateException illegalStateException) {
468 logger.fine(" Could not find the world communicator!");
469 }
470
471 return openMMPlatform;
472 }
473
474
475
476
477 public void free() {
478 destroy();
479 }
480 }