1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
import edu.rit.pj.Comm;
import edu.uiowa.jopenmm.OpenMMLibrary;
import edu.uiowa.jopenmm.OpenMMUtils;
import edu.uiowa.jopenmm.OpenMM_Vec3;
import ffx.crystal.Crystal;
import ffx.openmm.Context;
import ffx.openmm.Integrator;
import ffx.openmm.Platform;
import ffx.openmm.State;
import ffx.openmm.StringArray;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.Atom;

import java.util.Locale;
import java.util.logging.Level;
import java.util.logging.Logger;

import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_KcalPerKJ;
import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_NmPerAngstrom;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_LocalEnergyMinimizer_minimize;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
import static ffx.openmm.Platform.getNumPlatforms;
import static ffx.openmm.Platform.getOpenMMVersion;
import static ffx.openmm.Platform.getPluginLoadFailures;
import static ffx.openmm.Platform.loadPluginsFromDirectory;
import static ffx.potential.ForceFieldEnergy.DEFAULT_CONSTRAINT_TOLERANCE;
import static ffx.potential.openmm.OpenMMEnergy.getDefaultDevice;
import static ffx.potential.openmm.OpenMMIntegrator.createIntegrator;
import static java.lang.String.format;
69
70
71
72
73
74
75
76
77
78
79
80
81
82 public class OpenMMContext extends Context {
83
84 private static final Logger logger = Logger.getLogger(OpenMMContext.class.getName());
85
86
87
88
89 private final ffx.potential.Platform platform;
90
91
92
93 private Platform openMMPlatform = null;
94
95
96
97 private final OpenMMEnergy openMMEnergy;
98
99
100
101 private Integrator openMMIntegrator;
102
103
104
105 private String integratorName = "VERLET";
106
107
108
109 private double timeStep = 0.001;
110
111
112
113 private double temperature = 298.15;
114
115
116
117 private final int enforcePBC;
118
119
120
121 private final Atom[] atoms;
122
123
124
125
126
127
128
129
130
131 public OpenMMContext(ffx.potential.Platform requestedPlatform, Atom[] atoms, int enforcePBC,
132 OpenMMEnergy openMMEnergy) {
133 platform = requestedPlatform;
134 this.atoms = atoms;
135 this.enforcePBC = enforcePBC;
136 this.openMMEnergy = openMMEnergy;
137 loadPlatform(requestedPlatform);
138 }
139
140
141
142
143
144
145
146
147
148
149 public void update(String integratorName, double timeStep, double temperature,
150 boolean forceCreation, OpenMMEnergy openMMEnergy) {
151
152 if (hasContextPointer() && !forceCreation) {
153 if (this.temperature == temperature && this.timeStep == timeStep
154 && this.integratorName.equalsIgnoreCase(integratorName)) {
155
156 return;
157 }
158 }
159
160 this.integratorName = integratorName;
161 this.timeStep = timeStep;
162 this.temperature = temperature;
163
164 logger.info("\n Updating OpenMM Context");
165
166
167 OpenMMSystem openMMSystem = openMMEnergy.getSystem();
168 if (openMMIntegrator != null) {
169 openMMIntegrator.destroy();
170 }
171 openMMIntegrator = createIntegrator(integratorName, timeStep, temperature, openMMSystem);
172
173
174 double currentLambda = openMMEnergy.getLambda();
175
176 if (openMMEnergy.getLambdaTerm()) {
177 openMMEnergy.setLambda(1.0);
178 }
179
180
181 updateContext(openMMSystem, openMMIntegrator, openMMPlatform);
182
183
184 if (openMMEnergy.getLambdaTerm()) {
185 openMMEnergy.setLambda(currentLambda);
186 }
187
188
189 int nVar = openMMEnergy.getNumberOfVariables();
190 double[] x = new double[nVar];
191 double[] v = new double[nVar];
192 double[] vel3 = new double[3];
193 int index = 0;
194 for (Atom a : atoms) {
195 if (a.isActive()) {
196 a.getVelocity(vel3);
197
198 x[index] = a.getX();
199 v[index++] = vel3[0];
200
201 x[index] = a.getY();
202 v[index++] = vel3[1];
203
204 x[index] = a.getZ();
205 v[index++] = vel3[2];
206 }
207 }
208
209
210 Crystal crystal = openMMEnergy.getCrystal();
211 setPeriodicBoxVectors(crystal);
212
213
214 setPositions(x);
215
216
217 setVelocities(v);
218
219
220 applyConstraints(DEFAULT_CONSTRAINT_TOLERANCE);
221
222
223
224 OpenMMState openMMState = getOpenMMState(OpenMM_State_Positions | OpenMM_State_Velocities);
225 openMMState.getPositions(x);
226 openMMState.getVelocities(v);
227 openMMState.destroy();
228 }
229
230
231
232
233 public void update() {
234 if (!hasContextPointer()) {
235 openMMEnergy.updateContext(integratorName, timeStep, temperature, true);
236 }
237 }
238
239
240
241
242
243
244
245 public OpenMMState getOpenMMState(int mask) {
246 State state = getState(mask, enforcePBC);
247 return new OpenMMState(state.getPointer(), atoms, openMMEnergy.getNumberOfVariables());
248 }
249
250 public ffx.potential.Platform getPlatform() {
251 return platform;
252 }
253
254
255
256
257
258
259 public void integrate(int numSteps) {
260 openMMIntegrator.step(numSteps);
261 }
262
263
264
265
266
267
268
269 public void optimize(double eps, int maxIterations) {
270 OpenMM_LocalEnergyMinimizer_minimize(getPointer(),
271 eps / (OpenMM_NmPerAngstrom * OpenMM_KcalPerKJ), maxIterations);
272 }
273
274
275
276
277
278
279 public void setPositions(double[] x) {
280 double[] allPositions = new double[3 * atoms.length];
281 double[] d = new double[3];
282 int index = 0;
283 for (Atom a : atoms) {
284 if (a.isActive()) {
285 a.moveTo(x[index], x[index + 1], x[index + 2]);
286 }
287 a.getXYZ(d);
288 allPositions[index] = d[0] * OpenMM_NmPerAngstrom;
289 allPositions[index + 1] = d[1] * OpenMM_NmPerAngstrom;
290 allPositions[index + 2] = d[2] * OpenMM_NmPerAngstrom;
291 index += 3;
292 }
293 super.setPositions(allPositions);
294 }
295
296
297
298
299
300
301 public void setVelocities(double[] v) {
302 double[] allVelocities = new double[3 * atoms.length];
303 double[] velocity = new double[3];
304 int index = 0;
305 for (Atom a : atoms) {
306 if (a.isActive()) {
307 a.setVelocity(v[index], v[index + 1], v[index + 2]);
308 } else {
309
310 a.setVelocity(0.0, 0.0, 0.0);
311 }
312 a.getVelocity(velocity);
313 allVelocities[index] = velocity[0] * OpenMM_NmPerAngstrom;
314 allVelocities[index + 1] = velocity[1] * OpenMM_NmPerAngstrom;
315 allVelocities[index + 2] = velocity[2] * OpenMM_NmPerAngstrom;
316 index += 3;
317 }
318 super.setVelocities(allVelocities);
319 }
320
321
322
323
324
325
326 public void setPeriodicBoxVectors(Crystal crystal) {
327 if (!crystal.aperiodic()) {
328 OpenMM_Vec3 a = new OpenMM_Vec3();
329 OpenMM_Vec3 b = new OpenMM_Vec3();
330 OpenMM_Vec3 c = new OpenMM_Vec3();
331 double[][] Ai = crystal.Ai;
332 a.x = Ai[0][0] * OpenMM_NmPerAngstrom;
333 a.y = Ai[0][1] * OpenMM_NmPerAngstrom;
334 a.z = Ai[0][2] * OpenMM_NmPerAngstrom;
335 b.x = Ai[1][0] * OpenMM_NmPerAngstrom;
336 b.y = Ai[1][1] * OpenMM_NmPerAngstrom;
337 b.z = Ai[1][2] * OpenMM_NmPerAngstrom;
338 c.x = Ai[2][0] * OpenMM_NmPerAngstrom;
339 c.y = Ai[2][1] * OpenMM_NmPerAngstrom;
340 c.z = Ai[2][2] * OpenMM_NmPerAngstrom;
341 setPeriodicBoxVectors(a, b, c);
342 }
343 }
344
345
346
347
348 @Override
349 public String toString() {
350 return format(
351 " OpenMM context with integrator %s, timestep %9.3g fsec, temperature %9.3g K",
352 integratorName, timeStep, temperature);
353 }
354
355
356
357
358 private void loadPlatform(ffx.potential.Platform requestedPlatform) {
359
360 OpenMMUtils.init();
361
362 logger.log(Level.INFO, " Loaded from:\n {0}", OpenMMLibrary.JNA_NATIVE_LIB.toString());
363
364
365 logger.log(Level.INFO, " Version: {0}", getOpenMMVersion());
366
367
368 String libDirectory = OpenMMUtils.getLibDirectory();
369 logger.log(Level.FINE, " Lib Directory: {0}", libDirectory);
370
371 StringArray libs = loadPluginsFromDirectory(libDirectory);
372 int numLibs = libs.getSize();
373 logger.log(Level.FINE, " Number of libraries: {0}", numLibs);
374 for (int i = 0; i < numLibs; i++) {
375 logger.log(Level.FINE, " Library: {0}", libs.get(i));
376 }
377 libs.destroy();
378
379
380 String pluginDirectory = OpenMMUtils.getPluginDirectory();
381 logger.log(Level.INFO, "\n Plugin Directory: {0}", pluginDirectory);
382
383 StringArray plugins = loadPluginsFromDirectory(pluginDirectory);
384 int numPlugins = plugins.getSize();
385 logger.log(Level.INFO, " Number of Plugins: {0}", numPlugins);
386 boolean cuda = false;
387 for (int i = 0; i < numPlugins; i++) {
388 String pluginString = plugins.get(i);
389 logger.log(Level.INFO, " Plugin: {0}", pluginString);
390 if (pluginString != null) {
391 pluginString = pluginString.toUpperCase();
392 boolean amoebaCudaAvailable = pluginString.contains("AMOEBACUDA");
393 if (amoebaCudaAvailable) {
394 cuda = true;
395 }
396 }
397 }
398 plugins.destroy();
399
400 int numPlatforms = getNumPlatforms();
401 logger.log(Level.INFO, " Number of Platforms: {0}", numPlatforms);
402
403 if (requestedPlatform == ffx.potential.Platform.OMM_CUDA && !cuda) {
404 logger.severe(" The OMM_CUDA platform was requested, but is not available.");
405 }
406
407
408 if (logger.isLoggable(Level.FINE)) {
409 StringArray pluginFailures = getPluginLoadFailures();
410 int numFailures = pluginFailures.getSize();
411 for (int i = 0; i < numFailures; i++) {
412 logger.log(Level.FINE, " Plugin load failure: {0}", pluginFailures.get(i));
413 }
414 pluginFailures.destroy();
415 }
416
417 String defaultPrecision = "mixed";
418 MolecularAssembly molecularAssembly = openMMEnergy.getMolecularAssembly();
419 String precision = molecularAssembly.getForceField().getString("PRECISION", defaultPrecision).toLowerCase();
420 precision = precision.replace("-precision", "");
421 switch (precision) {
422 case "double", "mixed", "single" -> logger.info(format(" Precision level: %s", precision));
423 default -> {
424 logger.info(format(" Could not interpret precision level %s, defaulting to %s", precision, defaultPrecision));
425 precision = defaultPrecision;
426 }
427 }
428
429 if (cuda && requestedPlatform != ffx.potential.Platform.OMM_REF) {
430 int defaultDevice = getDefaultDevice(molecularAssembly.getProperties());
431 openMMPlatform = new Platform("CUDA");
432 int deviceID = molecularAssembly.getForceField().getInteger("CUDA_DEVICE", defaultDevice);
433 String deviceIDString = Integer.toString(deviceID);
434
435 openMMPlatform.setPropertyDefaultValue("CudaDeviceIndex", deviceIDString);
436 openMMPlatform.setPropertyDefaultValue("Precision", precision);
437 logger.info(format(" Platform: AMOEBA CUDA (Device ID %d)", deviceID));
438 try {
439 Comm world = Comm.world();
440 if (world != null) {
441 logger.info(format(" Running on host %s, rank %d", world.host(), world.rank()));
442 }
443 } catch (IllegalStateException illegalStateException) {
444 logger.fine(" Could not find the world communicator!");
445 }
446 } else {
447 openMMPlatform = new Platform("Reference");
448 logger.info(" Platform: AMOEBA CPU Reference");
449 }
450 }
451
452
453
454
455 public void free() {
456 if (openMMIntegrator != null) {
457 openMMIntegrator.destroy();
458 }
459 destroy();
460 }
461 }