1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.openmm;
39
40 import com.sun.jna.ptr.PointerByReference;
41 import ffx.openmm.State;
42 import ffx.potential.bonded.Atom;
43 import ffx.potential.utils.EnergyException;
44
45 import javax.annotation.Nullable;
46 import java.util.Arrays;
47
48 import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_AngstromsPerNm;
49 import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_KcalPerKJ;
50 import static edu.uiowa.jopenmm.OpenMMAmoebaLibrary.OpenMM_NmPerAngstrom;
51 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Energy;
52 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Forces;
53 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Positions;
54 import static edu.uiowa.jopenmm.OpenMMLibrary.OpenMM_State_DataType.OpenMM_State_Velocities;
55 import static java.lang.Double.isInfinite;
56 import static java.lang.Double.isNaN;
57 import static java.lang.String.format;
58
59
60
61
62 public class OpenMMState extends State {
63
64
65
66
67 public final double potentialEnergy;
68
69
70
71 public final double kineticEnergy;
72
73
74
75 public final double totalEnergy;
76
77
78
79 private final int dataTypes;
80
81
82
83
84
85
86 protected OpenMMState(PointerByReference pointer) {
87 super(pointer);
88
89
90 this.dataTypes = super.getDataTypes();
91 if (stateContains(OpenMM_State_Energy)) {
92
93 potentialEnergy = super.getPotentialEnergy() * OpenMM_KcalPerKJ;
94 kineticEnergy = super.getKineticEnergy() * OpenMM_KcalPerKJ;
95 totalEnergy = potentialEnergy + kineticEnergy;
96 } else {
97 potentialEnergy = 0.0;
98 kineticEnergy = 0.0;
99 totalEnergy = 0.0;
100 }
101 }
102
103
104
105
106
107
108
109
110
111 public double[] getAccelerations(@Nullable double[] a, Atom[] atoms) {
112
113 if (!stateContains(OpenMM_State_Forces)) {
114 return a;
115 }
116 double[] forces = getForces();
117 int n = forces.length;
118
119
120 if (atoms == null || atoms.length == 0) {
121 throw new IllegalArgumentException("Atoms array must not be null or empty.");
122 }
123
124 if (atoms.length * 3 != n) {
125 String message = format(" The number of atoms (%d) does not match the number of degrees of freedom (%d).", atoms.length, n);
126 throw new IllegalArgumentException(message);
127 }
128 if (a == null || a.length != n) {
129 a = new double[n];
130 }
131
132 int index = 0;
133 for (Atom atom : atoms) {
134 double mass = atom.getMass();
135 double xx = forces[index] * OpenMM_AngstromsPerNm / mass;
136 double yy = forces[index + 1] * OpenMM_AngstromsPerNm / mass;
137 double zz = forces[index + 2] * OpenMM_AngstromsPerNm / mass;
138 a[index] = xx;
139 a[index + 1] = yy;
140 a[index + 2] = zz;
141 index += 3;
142 }
143 return a;
144 }
145
146
147
148
149
150
151
152
153
154 public double[] getActiveAccelerations(@Nullable double[] a, Atom[] atoms) {
155 if (!stateContains(OpenMM_State_Forces)) {
156 return a;
157 }
158 return filterToActive(getAccelerations(null, atoms), a, atoms);
159 }
160
161
162
163
164
165
166
167 public double[] getGradient(@Nullable double[] g) {
168
169 if (!stateContains(OpenMM_State_Forces)) {
170 return g;
171 }
172 double[] forces = getForces();
173 int n = forces.length;
174
175
176 if (g == null || g.length != n) {
177 g = new double[n];
178 }
179
180 for (int i = 0; i < n; i++) {
181 double xx = -forces[i] * OpenMM_NmPerAngstrom * OpenMM_KcalPerKJ;
182 if (isNaN(xx) || isInfinite(xx)) {
183 throw new EnergyException(
184 format(" The gradient of degree of freedom %d is %8.3f.", i, xx));
185 }
186 g[i] = xx;
187 }
188 return g;
189 }
190
191
192
193
194
195
196
197 public double[] getActiveGradient(@Nullable double[] g, Atom[] atoms) {
198 if (!stateContains(OpenMM_State_Forces)) {
199 return g;
200 }
201 return filterToActive(getGradient(null), g, atoms);
202 }
203
204
205
206
207
208
209 public double[][] getPeriodicBoxVectors() {
210 if (!stateContains(OpenMM_State_Positions)) {
211 return null;
212 }
213
214 double[][] latticeVectors = super.getPeriodicBoxVectors();
215 latticeVectors[0][0] *= OpenMM_AngstromsPerNm;
216 latticeVectors[0][1] *= OpenMM_AngstromsPerNm;
217 latticeVectors[0][2] *= OpenMM_AngstromsPerNm;
218 latticeVectors[1][0] *= OpenMM_AngstromsPerNm;
219 latticeVectors[1][1] *= OpenMM_AngstromsPerNm;
220 latticeVectors[1][2] *= OpenMM_AngstromsPerNm;
221 latticeVectors[2][0] *= OpenMM_AngstromsPerNm;
222 latticeVectors[2][1] *= OpenMM_AngstromsPerNm;
223 latticeVectors[2][2] *= OpenMM_AngstromsPerNm;
224 return latticeVectors;
225 }
226
227
228
229
230
231
232
233
234 public double[] getPositions(@Nullable double[] x) {
235
236 if (!stateContains(OpenMM_State_Positions)) {
237 return x;
238 }
239
240 double[] pos = getPositions();
241 int n = pos.length;
242
243
244 if (x == null || x.length != n) {
245 x = new double[n];
246 }
247
248 for (int i = 0; i < n; i++) {
249 x[i] = pos[i] * OpenMM_AngstromsPerNm;
250 }
251
252 return x;
253 }
254
255
256
257
258
259
260
261
262 public double[] getActivePositions(@Nullable double[] x, Atom[] atoms) {
263 if (!stateContains(OpenMM_State_Positions)) {
264 return x;
265 }
266 return filterToActive(getPositions(null), x, atoms);
267 }
268
269
270
271
272
273
274
275
276 public double[] getVelocities(@Nullable double[] v) {
277 if (!stateContains(OpenMM_State_Velocities)) {
278 return v;
279 }
280
281 double[] vel = getVelocities();
282 int n = vel.length;
283
284
285 if (v == null || v.length != n) {
286 v = new double[n];
287 }
288
289 for (int i = 0; i < n; i++) {
290 v[i] = vel[i] * OpenMM_AngstromsPerNm;
291 }
292
293 return v;
294 }
295
296
297
298
299
300
301
302
303 public double[] getActiveVelocities(@Nullable double[] v, Atom[] atoms) {
304 if (!stateContains(OpenMM_State_Velocities)) {
305 return v;
306 }
307 return filterToActive(getVelocities(null), v, atoms);
308 }
309
310
311
312
313
314
315 @Override
316 public double getPeriodicBoxVolume() {
317 return super.getPeriodicBoxVolume()
318 * OpenMM_AngstromsPerNm * OpenMM_AngstromsPerNm * OpenMM_AngstromsPerNm;
319 }
320
321
322
323
324
325
326 @Override
327 public double getPotentialEnergy() {
328 return potentialEnergy;
329 }
330
331
332
333
334
335
336 @Override
337 public double getKineticEnergy() {
338 return kineticEnergy;
339 }
340
341
342
343
344
345
346 public double getTotalEnergy() {
347 return totalEnergy;
348 }
349
350
351
352
353
354
355 @Override
356 public int getDataTypes() {
357 return dataTypes;
358 }
359
360
361
362
363
364
365
366 private boolean stateContains(int dataType) {
367 return (dataTypes & dataType) == dataType;
368 }
369
370
371
372
373
374
375
376
377 private static double[] filterToActive(double[] source, @Nullable double[] target, Atom[] atoms) {
378 if (source == null || atoms == null) {
379 throw new IllegalArgumentException("The arrays must be non-null.");
380 }
381
382
383 if (source.length != atoms.length * 3) {
384 throw new IllegalArgumentException("Source array length must be three times the number of atoms.");
385 }
386
387
388 int count = (int) Arrays.stream(atoms).filter(Atom::isActive).count();
389
390
391 if (target == null || target.length < count * 3) {
392 target = new double[count * 3];
393 }
394
395
396 int sourceIndedx = 0;
397 int targetIndex = 0;
398 for (Atom atom : atoms) {
399 if (atom.isActive()) {
400 target[targetIndex] = source[sourceIndedx];
401 target[targetIndex + 1] = source[sourceIndedx + 1];
402 target[targetIndex + 2] = source[sourceIndedx + 2];
403 targetIndex += 3;
404 }
405
406 sourceIndedx += 3;
407 }
408 return target;
409 }
410 }