1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.nonbonded.implicit;
39
40 import static ffx.potential.nonbonded.implicit.BornTanhRescaling.tanhRescalingChainRule;
41 import static ffx.potential.nonbonded.implicit.NeckIntegral.getNeckConstants;
42 import static java.lang.Double.isInfinite;
43 import static java.lang.Double.isNaN;
44 import static java.lang.String.format;
45 import static org.apache.commons.math3.util.FastMath.PI;
46 import static org.apache.commons.math3.util.FastMath.max;
47 import static org.apache.commons.math3.util.FastMath.pow;
48 import static org.apache.commons.math3.util.FastMath.sqrt;
49
50 import edu.rit.pj.IntegerForLoop;
51 import edu.rit.pj.ParallelRegion;
52 import edu.rit.pj.ParallelTeam;
53 import ffx.crystal.Crystal;
54 import ffx.crystal.SymOp;
55 import ffx.numerics.atomic.AtomicDoubleArray;
56 import ffx.numerics.atomic.AtomicDoubleArray3D;
57 import ffx.potential.bonded.Atom;
58 import ffx.potential.utils.EnergyException;
59 import java.util.logging.Level;
60 import java.util.logging.Logger;
61
62
63
64
65
66
67
68 public class BornGradRegion extends ParallelRegion {
69
70 private static final Logger logger = Logger.getLogger(BornRadiiRegion.class.getName());
71 private static final double PI4_3 = 4.0 / 3.0 * PI;
72
73
74
75 private static final double oneThird = 1.0 / 3.0;
76
77 private final BornCRLoop[] bornCRLoop;
78
79
80
81 protected Atom[] atoms;
82
83
84
85 private Crystal crystal;
86
87
88
89 private double[][][] sXYZ;
90
91
92
93 private int[][][] neighborLists;
94
95
96
97 private double[] baseRadius;
98
99
100
101 private double[] descreenRadius;
102
103
104
105
106
107
108
109
110 private double[] overlapScale;
111
112
113
114
115 private double[] neckScale;
116 private double descreenOffset;
117
118
119
120 private final boolean perfectHCTScale;
121
122
123
124 private boolean[] use;
125
126
127
128 private double cut2;
129
130
131
132 private boolean nativeEnvironmentApproximation;
133
134
135
136 private double[] born;
137
138
139
140 private AtomicDoubleArray3D grad;
141
142
143
144 private AtomicDoubleArray sharedBornGrad;
145 private final double factor = -pow(PI, oneThird) * pow(6.0, (2.0 * oneThird)) / 9.0;
146 private double[] term;
147
148
149
150 private final boolean neckCorrection;
151
152
153
154 private final boolean tanhCorrection;
155
156
157
158
159 private double[] unscaledBornIntegral;
160
161
162
163
164
165
166
167
168
169 public BornGradRegion(int nt, boolean neckCorrection,
170 boolean tanhCorrection, boolean perfectHCTScale) {
171 bornCRLoop = new BornCRLoop[nt];
172 for (int i = 0; i < nt; i++) {
173 bornCRLoop[i] = new BornCRLoop();
174 }
175 this.neckCorrection = neckCorrection;
176 this.tanhCorrection = tanhCorrection;
177 this.perfectHCTScale = perfectHCTScale;
178 }
179
180
181
182
183
184
185 public void executeWith(ParallelTeam parallelTeam) {
186 sharedBornGrad.reduce(parallelTeam, 0, atoms.length - 1);
187 try {
188 parallelTeam.execute(this);
189 } catch (Exception e) {
190 String message = " Exception evaluating Born radii chain rule gradient.\n";
191 logger.log(Level.SEVERE, message, e);
192 }
193 }
194
195 public void init(
196 Atom[] atoms,
197 Crystal crystal,
198 double[][][] sXYZ,
199 int[][][] neighborLists,
200 double[] baseRadius,
201 double[] descreenRadius,
202 double[] overlapScale,
203 double[] neckScale,
204 double descreenOffset,
205 double[] unscaledBornIntegral,
206 boolean[] use,
207 double cut2,
208 boolean nativeEnvironmentApproximation,
209 double[] born,
210 AtomicDoubleArray3D grad,
211 AtomicDoubleArray sharedBornGrad) {
212 this.atoms = atoms;
213 this.crystal = crystal;
214 this.sXYZ = sXYZ;
215 this.neighborLists = neighborLists;
216 this.baseRadius = baseRadius;
217 this.descreenRadius = descreenRadius;
218 this.overlapScale = overlapScale;
219 this.neckScale = neckScale;
220 this.descreenOffset = descreenOffset;
221 this.unscaledBornIntegral = unscaledBornIntegral;
222 this.use = use;
223 this.cut2 = cut2;
224 this.nativeEnvironmentApproximation = nativeEnvironmentApproximation;
225 this.born = born;
226 this.grad = grad;
227 this.sharedBornGrad = sharedBornGrad;
228 }
229
230 @Override
231 public void start() {
232 int nAtoms = atoms.length;
233 if (term == null || term.length < nAtoms) {
234 term = new double[nAtoms];
235 }
236
237
238
239 for (int i = 0; i < nAtoms; i++) {
240 double rbi = born[i];
241 term[i] = PI4_3 / (rbi * rbi * rbi);
242 term[i] = factor / pow(term[i], (4.0 * oneThird));
243 if (tanhCorrection) {
244 term[i] =
245 term[i] * tanhRescalingChainRule(unscaledBornIntegral[i], baseRadius[i]);
246 }
247 }
248 }
249
250 @Override
251 public void run() {
252 try {
253 int nAtoms = atoms.length;
254 execute(0, nAtoms - 1, bornCRLoop[getThreadIndex()]);
255 } catch (Exception e) {
256 String message = "Fatal exception computing Born radii chain rule term in thread "
257 + getThreadIndex() + "\n";
258 logger.log(Level.SEVERE, message, e);
259 }
260 }
261
262
263
264
265
266
267 private class BornCRLoop extends IntegerForLoop {
268
/** Scratch vector holding the separation of the current pair; passed to crystal.image(). */
private final double[] dx_local = new double[3];
/** Index of the executing thread, cached by start() and used to address the gradient. */
private int threadID;

/** Creates a loop; the scratch separation vector is allocated by its field initializer. */
BornCRLoop() {
  super();
}
275
276 @Override
277 public void run(int lb, int ub) {
278
279 double[] x = sXYZ[0][0];
280 double[] y = sXYZ[0][1];
281 double[] z = sXYZ[0][2];
282
283 int nSymm = crystal.spaceGroup.symOps.size();
284 for (int iSymOp = 0; iSymOp < nSymm; iSymOp++) {
285 SymOp symOp = crystal.spaceGroup.symOps.get(iSymOp);
286 double[][] transOp = new double[3][3];
287 double[][] xyz = sXYZ[iSymOp];
288 crystal.getTransformationOperator(symOp, transOp);
289 for (int i = lb; i <= ub; i++) {
290 if (!nativeEnvironmentApproximation && !use[i]) {
291 continue;
292 }
293
294
295 double bornGrad = sharedBornGrad.get(i);
296 if (isInfinite(bornGrad) || isNaN(bornGrad)) {
297 throw new EnergyException(format(" %s\n Born radii CR %d %8.3f", atoms[i], i, bornGrad), true);
298 }
299 final double integralStartI = max(baseRadius[i], descreenRadius[i]) + descreenOffset;
300 final double descreenRi = descreenRadius[i];
301 final double xi = x[i];
302 final double yi = y[i];
303 final double zi = z[i];
304 final double rbi = born[i];
305 int[] list = neighborLists[iSymOp][i];
306 for (int k : list) {
307 if (!nativeEnvironmentApproximation && !use[k]) {
308 continue;
309 }
310 final double integralStartK = max(baseRadius[k], descreenRadius[k]) + descreenOffset;
311 final double descreenRk = descreenRadius[k];
312 double mixedNeckScale = 0.5 * (neckScale[i] + neckScale[k]);
313
314 if (k != i) {
315 dx_local[0] = xyz[0][k] - xi;
316 dx_local[1] = xyz[1][k] - yi;
317 dx_local[2] = xyz[2][k] - zi;
318 double r2 = crystal.image(dx_local);
319 if (r2 > cut2) {
320 continue;
321 }
322 final double xr = dx_local[0];
323 final double yr = dx_local[1];
324 final double zr = dx_local[2];
325 final double r = sqrt(r2);
326
327
328 double sk = overlapScale[k];
329 if (sk > 0.0 && rbi < 50.0 && descreenRk > 0.0) {
330 double de = descreenDerivative(r, r2, integralStartI, descreenRk, sk);
331 if (neckCorrection) {
332 de += neckDescreenDerivative(r, integralStartI, descreenRk, mixedNeckScale);
333 }
334 if (isInfinite(de) || isNaN(de)) {
335 logger.warning(
336 format(" Born radii chain rule term is unstable %d %d %16.8f", i, k, de));
337 }
338 double dbr = term[i] * de / r;
339 de = dbr * sharedBornGrad.get(i);
340 incrementGradient(i, k, de, xr, yr, zr, transOp);
341 }
342
343
344 double rbk = born[k];
345 double si = overlapScale[i];
346 if (si > 0.0 && rbk < 50.0 && descreenRi > 0.0) {
347 double de = descreenDerivative(r, r2, integralStartK, descreenRi, si);
348 if (neckCorrection) {
349 de += neckDescreenDerivative(r, integralStartK, descreenRi, mixedNeckScale);
350 }
351 if (isInfinite(de) || isNaN(de)) {
352 logger.warning(
353 format(" Born radii chain rule term is unstable %d %d %16.8f", k, i, de));
354 }
355 double dbr = term[k] * de / r;
356 de = dbr * sharedBornGrad.get(k);
357 incrementGradient(i, k, de, xr, yr, zr, transOp);
358 }
359 } else if (iSymOp > 0 && rbi < 50.0) {
360 dx_local[0] = xyz[0][k] - xi;
361 dx_local[1] = xyz[1][k] - yi;
362 dx_local[2] = xyz[2][k] - zi;
363 double r2 = crystal.image(dx_local);
364 double sk = overlapScale[k];
365 if (sk > 0.0 && r2 < cut2 && descreenRk > 0.0) {
366 final double xr = dx_local[0];
367 final double yr = dx_local[1];
368 final double zr = dx_local[2];
369 final double r = sqrt(r2);
370
371 double de = descreenDerivative(r, r2, integralStartI, descreenRk, sk);
372 if (neckCorrection) {
373 de += neckDescreenDerivative(r, integralStartI, descreenRk, mixedNeckScale);
374 }
375 if (isInfinite(de) || isNaN(de)) {
376 logger.warning(
377 format(" Born radii chain rule term is unstable %d %d %d %16.8f", iSymOp, i,
378 k, de));
379 }
380 double dbr = term[i] * de / r;
381 de = dbr * sharedBornGrad.get(i);
382 incrementGradient(i, k, de, xr, yr, zr, transOp);
383
384 }
385 }
386 }
387 }
388 }
389 }
390
391 @Override
392 public void start() {
393 threadID = getThreadIndex();
394 }
395
396 private double neckDescreenDerivative(double r, double radius, double radiusK, double sneck) {
397 double radiusWater = 1.4;
398
399 if (r > radius + radiusK + 2 * radiusWater) {
400 return 0.0;
401 }
402
403
404 double[] constants = getNeckConstants(radius, radiusK);
405
406
407 double Aij = constants[0];
408 double Bij = constants[1];
409
410
411 double rMinusBij = r - Bij;
412 double rMinusBij3 = rMinusBij * rMinusBij * rMinusBij;
413 double rMinusBij4 = rMinusBij3 * rMinusBij;
414 double radiiMinusr = radius + radiusK + 2.0 * radiusWater - r;
415 double radiiMinusr3 = radiiMinusr * radiiMinusr * radiiMinusr;
416 double radiiMinusr4 = radiiMinusr3 * radiiMinusr;
417
418 return 4.0 * PI4_3 * (sneck * Aij * rMinusBij3 * radiiMinusr4
419 - sneck * Aij * rMinusBij4 * radiiMinusr3);
420 }
421
422 private double descreenDerivative(double r, double r2, double radius, double radiusK,
423 double hctScale) {
424 if (perfectHCTScale) {
425 return perfectHCTIntegralDerivative(r, r2, radius, radiusK, hctScale);
426 } else {
427 return integralDerivative(r, r2, radius, radiusK * hctScale);
428 }
429 }
430
431
432
433
434
435
436
437
438
439
440 private double integralDerivative(double r, double r2, double radius, double scaledRadius) {
441 double de = 0.0;
442
443
444 if (scaledRadius > 0.0 && (radius < r + scaledRadius)) {
445
446 if (radius + r < scaledRadius) {
447 double uik = scaledRadius - r;
448 double uik2 = uik * uik;
449 double uik4 = uik2 * uik2;
450 de = -4.0 * PI / uik4;
451 }
452
453
454 double sk2 = scaledRadius * scaledRadius;
455 if (radius + r < scaledRadius) {
456
457 double lik = scaledRadius - r;
458 double lik2 = lik * lik;
459 double lik4 = lik2 * lik2;
460 de = de + 0.25 * PI * (sk2 - 4.0 * scaledRadius * r + 17.0 * r2) / (r2 * lik4);
461 } else if (r < radius + scaledRadius) {
462
463 double lik = radius;
464 double lik2 = lik * lik;
465 double lik4 = lik2 * lik2;
466 de = de + 0.25 * PI * (2.0 * radius * radius - sk2 - r2) / (r2 * lik4);
467 } else {
468
469 double lik = r - scaledRadius;
470 double lik2 = lik * lik;
471 double lik4 = lik2 * lik2;
472 de = de + 0.25 * PI * (sk2 - 4.0 * scaledRadius * r + r2) / (r2 * lik4);
473 }
474
475
476 double uik = r + scaledRadius;
477 double uik2 = uik * uik;
478 double uik4 = uik2 * uik2;
479 de = de - 0.25 * PI * (sk2 + 4.0 * scaledRadius * r + r2) / (r2 * uik4);
480 }
481 return de;
482 }
483
484
485
486
487
488
489
490
491
492
493
494 private double perfectHCTIntegralDerivative(double r, double r2, double radius, double radiusK,
495 double perfectHCT) {
496 double de = 0.0;
497
498
499 if (radiusK > 0.0 && (radius < r + radiusK)) {
500
501 if (radius + r < radiusK) {
502 double uik = radiusK - r;
503 double uik2 = uik * uik;
504 double uik4 = uik2 * uik2;
505 de = -4.0 * PI / uik4;
506 }
507
508
509 double sk2 = radiusK * radiusK;
510 if (radius + r < radiusK) {
511
512 double lik = radiusK - r;
513 double lik2 = lik * lik;
514 double lik4 = lik2 * lik2;
515 de = de + 0.25 * PI * (sk2 - 4.0 * radiusK * r + 17.0 * r2) / (r2 * lik4);
516 } else if (r < radius + radiusK) {
517
518 double lik = radius;
519 double lik2 = lik * lik;
520 double lik4 = lik2 * lik2;
521 de = de + 0.25 * PI * (2.0 * radius * radius - sk2 - r2) / (r2 * lik4);
522 } else {
523
524 double lik = r - radiusK;
525 double lik2 = lik * lik;
526 double lik4 = lik2 * lik2;
527 de = de + 0.25 * PI * (sk2 - 4.0 * radiusK * r + r2) / (r2 * lik4);
528 }
529
530
531 double uik = r + radiusK;
532 double uik2 = uik * uik;
533 double uik4 = uik2 * uik2;
534 de = de - 0.25 * PI * (sk2 + 4.0 * radiusK * r + r2) / (r2 * uik4);
535 }
536 return perfectHCT * de;
537 }
538
539
540
541
542
543
544
545
546
547
548
549 private void incrementGradient(
550 int i, int k, double dE, double xr, double yr, double zr, double[][] transOp) {
551 double dedx = dE * xr;
552 double dedy = dE * yr;
553 double dedz = dE * zr;
554 grad.add(threadID, i, dedx, dedy, dedz);
555 final double dedxk = dedx * transOp[0][0] + dedy * transOp[1][0] + dedz * transOp[2][0];
556 final double dedyk = dedx * transOp[0][1] + dedy * transOp[1][1] + dedz * transOp[2][1];
557 final double dedzk = dedx * transOp[0][2] + dedy * transOp[1][2] + dedz * transOp[2][2];
558 grad.sub(threadID, k, dedxk, dedyk, dedzk);
559 }
560 }
561 }