1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.nonbonded.implicit;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.ParallelRegion;
42 import edu.rit.pj.ParallelTeam;
43 import ffx.crystal.Crystal;
44 import ffx.crystal.SymOp;
45 import ffx.numerics.atomic.AtomicDoubleArray;
46 import ffx.numerics.atomic.AtomicDoubleArray3D;
47 import ffx.potential.bonded.Atom;
48 import ffx.potential.utils.EnergyException;
49
50 import java.util.logging.Level;
51 import java.util.logging.Logger;
52
53 import static ffx.potential.nonbonded.implicit.BornTanhRescaling.tanhRescalingChainRule;
54 import static ffx.potential.nonbonded.implicit.NeckIntegral.getNeckConstants;
55 import static java.lang.Double.isInfinite;
56 import static java.lang.Double.isNaN;
57 import static java.lang.String.format;
58 import static org.apache.commons.math3.util.FastMath.PI;
59 import static org.apache.commons.math3.util.FastMath.max;
60 import static org.apache.commons.math3.util.FastMath.pow;
61 import static org.apache.commons.math3.util.FastMath.sqrt;
62
63
64
65
66
67
68
69 public class BornGradRegion extends ParallelRegion {
70
71 private static final Logger logger = Logger.getLogger(BornRadiiRegion.class.getName());
72 private static final double PI4_3 = 4.0 / 3.0 * PI;
73
74
75
76 private static final double oneThird = 1.0 / 3.0;
77
78 private final BornCRLoop[] bornCRLoop;
79
80
81
82 protected Atom[] atoms;
83
84
85
86 private Crystal crystal;
87
88
89
90 private double[][][] sXYZ;
91
92
93
94 private int[][][] neighborLists;
95
96
97
98 private double[] baseRadius;
99
100
101
102 private double[] descreenRadius;
103
104
105
106
107
108
109
110
111 private double[] overlapScale;
112
113
114
115
116 private double[] neckScale;
117 private double descreenOffset;
118
119
120
121 private final boolean perfectHCTScale;
122
123
124
125 private boolean[] use;
126
127
128
129 private double cut2;
130
131
132
133 private boolean nativeEnvironmentApproximation;
134
135
136
137 private double[] born;
138
139
140
141 private AtomicDoubleArray3D grad;
142
143
144
145 private AtomicDoubleArray sharedBornGrad;
146 private final double factor = -pow(PI, oneThird) * pow(6.0, (2.0 * oneThird)) / 9.0;
147 private double[] term;
148
149
150
151 private final boolean neckCorrection;
152
153
154
155 private final boolean tanhCorrection;
156
157
158
159
160 private double[] unscaledBornIntegral;
161
162
163
164
165
166
167
168
169
170 public BornGradRegion(int nt, boolean neckCorrection,
171 boolean tanhCorrection, boolean perfectHCTScale) {
172 bornCRLoop = new BornCRLoop[nt];
173 for (int i = 0; i < nt; i++) {
174 bornCRLoop[i] = new BornCRLoop();
175 }
176 this.neckCorrection = neckCorrection;
177 this.tanhCorrection = tanhCorrection;
178 this.perfectHCTScale = perfectHCTScale;
179 }
180
181
182
183
184
185
186 public void executeWith(ParallelTeam parallelTeam) {
187 sharedBornGrad.reduce(parallelTeam, 0, atoms.length - 1);
188 try {
189 parallelTeam.execute(this);
190 } catch (Exception e) {
191 String message = " Exception evaluating Born radii chain rule gradient.\n";
192 logger.log(Level.SEVERE, message, e);
193 }
194 }
195
196 public void init(
197 Atom[] atoms,
198 Crystal crystal,
199 double[][][] sXYZ,
200 int[][][] neighborLists,
201 double[] baseRadius,
202 double[] descreenRadius,
203 double[] overlapScale,
204 double[] neckScale,
205 double descreenOffset,
206 double[] unscaledBornIntegral,
207 boolean[] use,
208 double cut2,
209 boolean nativeEnvironmentApproximation,
210 double[] born,
211 AtomicDoubleArray3D grad,
212 AtomicDoubleArray sharedBornGrad) {
213 this.atoms = atoms;
214 this.crystal = crystal;
215 this.sXYZ = sXYZ;
216 this.neighborLists = neighborLists;
217 this.baseRadius = baseRadius;
218 this.descreenRadius = descreenRadius;
219 this.overlapScale = overlapScale;
220 this.neckScale = neckScale;
221 this.descreenOffset = descreenOffset;
222 this.unscaledBornIntegral = unscaledBornIntegral;
223 this.use = use;
224 this.cut2 = cut2;
225 this.nativeEnvironmentApproximation = nativeEnvironmentApproximation;
226 this.born = born;
227 this.grad = grad;
228 this.sharedBornGrad = sharedBornGrad;
229 }
230
231 @Override
232 public void start() {
233 int nAtoms = atoms.length;
234 if (term == null || term.length < nAtoms) {
235 term = new double[nAtoms];
236 }
237
238
239
240 for (int i = 0; i < nAtoms; i++) {
241 double rbi = born[i];
242 term[i] = PI4_3 / (rbi * rbi * rbi);
243 term[i] = factor / pow(term[i], (4.0 * oneThird));
244 if (tanhCorrection) {
245 term[i] =
246 term[i] * tanhRescalingChainRule(unscaledBornIntegral[i], baseRadius[i]);
247 }
248 }
249 }
250
251 @Override
252 public void run() {
253 try {
254 int nAtoms = atoms.length;
255 execute(0, nAtoms - 1, bornCRLoop[getThreadIndex()]);
256 } catch (Exception e) {
257 String message = "Fatal exception computing Born radii chain rule term in thread "
258 + getThreadIndex() + "\n";
259 logger.log(Level.SEVERE, message, e);
260 }
261 }
262
263
264
265
266
267
268 private class BornCRLoop extends IntegerForLoop {
269
270 private final double[] dx_local;
271 private int threadID;
272
273 BornCRLoop() {
274 dx_local = new double[3];
275 }
276
277 @Override
278 public void run(int lb, int ub) {
279
280 double[] x = sXYZ[0][0];
281 double[] y = sXYZ[0][1];
282 double[] z = sXYZ[0][2];
283
284 int nSymm = crystal.spaceGroup.symOps.size();
285 for (int iSymOp = 0; iSymOp < nSymm; iSymOp++) {
286 SymOp symOp = crystal.spaceGroup.symOps.get(iSymOp);
287 double[][] transOp = new double[3][3];
288 double[][] xyz = sXYZ[iSymOp];
289 crystal.getTransformationOperator(symOp, transOp);
290 for (int i = lb; i <= ub; i++) {
291 if (!nativeEnvironmentApproximation && !use[i]) {
292 continue;
293 }
294
295
296 double bornGrad = sharedBornGrad.get(i);
297 if (isInfinite(bornGrad) || isNaN(bornGrad)) {
298 throw new EnergyException(format(" %s\n Born radii CR %d %8.3f", atoms[i], i, bornGrad), true);
299 }
300 final double integralStartI = max(baseRadius[i], descreenRadius[i]) + descreenOffset;
301 final double descreenRi = descreenRadius[i];
302 final double xi = x[i];
303 final double yi = y[i];
304 final double zi = z[i];
305 final double rbi = born[i];
306 int[] list = neighborLists[iSymOp][i];
307 for (int k : list) {
308 if (!nativeEnvironmentApproximation && !use[k]) {
309 continue;
310 }
311 final double integralStartK = max(baseRadius[k], descreenRadius[k]) + descreenOffset;
312 final double descreenRk = descreenRadius[k];
313 double mixedNeckScale = 0.5 * (neckScale[i] + neckScale[k]);
314
315 if (k != i) {
316 dx_local[0] = xyz[0][k] - xi;
317 dx_local[1] = xyz[1][k] - yi;
318 dx_local[2] = xyz[2][k] - zi;
319 double r2 = crystal.image(dx_local);
320 if (r2 > cut2) {
321 continue;
322 }
323 final double xr = dx_local[0];
324 final double yr = dx_local[1];
325 final double zr = dx_local[2];
326 final double r = sqrt(r2);
327
328
329 double sk = overlapScale[k];
330 if (sk > 0.0 && rbi < 50.0 && descreenRk > 0.0) {
331 double de = descreenDerivative(r, r2, integralStartI, descreenRk, sk);
332 if (neckCorrection) {
333 de += neckDescreenDerivative(r, integralStartI, descreenRk, mixedNeckScale);
334 }
335 if (isInfinite(de) || isNaN(de)) {
336 logger.warning(
337 format(" Born radii chain rule term is unstable %d %d %16.8f", i, k, de));
338 }
339 double dbr = term[i] * de / r;
340 de = dbr * sharedBornGrad.get(i);
341 incrementGradient(i, k, de, xr, yr, zr, transOp);
342 }
343
344
345 double rbk = born[k];
346 double si = overlapScale[i];
347 if (si > 0.0 && rbk < 50.0 && descreenRi > 0.0) {
348 double de = descreenDerivative(r, r2, integralStartK, descreenRi, si);
349 if (neckCorrection) {
350 de += neckDescreenDerivative(r, integralStartK, descreenRi, mixedNeckScale);
351 }
352 if (isInfinite(de) || isNaN(de)) {
353 logger.warning(
354 format(" Born radii chain rule term is unstable %d %d %16.8f", k, i, de));
355 }
356 double dbr = term[k] * de / r;
357 de = dbr * sharedBornGrad.get(k);
358 incrementGradient(i, k, de, xr, yr, zr, transOp);
359 }
360 } else if (iSymOp > 0 && rbi < 50.0) {
361 dx_local[0] = xyz[0][k] - xi;
362 dx_local[1] = xyz[1][k] - yi;
363 dx_local[2] = xyz[2][k] - zi;
364 double r2 = crystal.image(dx_local);
365 double sk = overlapScale[k];
366 if (sk > 0.0 && r2 < cut2 && descreenRk > 0.0) {
367 final double xr = dx_local[0];
368 final double yr = dx_local[1];
369 final double zr = dx_local[2];
370 final double r = sqrt(r2);
371
372 double de = descreenDerivative(r, r2, integralStartI, descreenRk, sk);
373 if (neckCorrection) {
374 de += neckDescreenDerivative(r, integralStartI, descreenRk, mixedNeckScale);
375 }
376 if (isInfinite(de) || isNaN(de)) {
377 logger.warning(
378 format(" Born radii chain rule term is unstable %d %d %d %16.8f", iSymOp, i,
379 k, de));
380 }
381 double dbr = term[i] * de / r;
382 de = dbr * sharedBornGrad.get(i);
383 incrementGradient(i, k, de, xr, yr, zr, transOp);
384
385 }
386 }
387 }
388 }
389 }
390 }
391
392 @Override
393 public void start() {
394 threadID = getThreadIndex();
395 }
396
397 private double neckDescreenDerivative(double r, double radius, double radiusK, double sneck) {
398 double radiusWater = 1.4;
399
400 if (r > radius + radiusK + 2 * radiusWater) {
401 return 0.0;
402 }
403
404
405 double[] constants = getNeckConstants(radius, radiusK);
406
407
408 double Aij = constants[0];
409 double Bij = constants[1];
410
411
412 double rMinusBij = r - Bij;
413 double rMinusBij3 = rMinusBij * rMinusBij * rMinusBij;
414 double rMinusBij4 = rMinusBij3 * rMinusBij;
415 double radiiMinusr = radius + radiusK + 2.0 * radiusWater - r;
416 double radiiMinusr3 = radiiMinusr * radiiMinusr * radiiMinusr;
417 double radiiMinusr4 = radiiMinusr3 * radiiMinusr;
418
419 return 4.0 * PI4_3 * (sneck * Aij * rMinusBij3 * radiiMinusr4
420 - sneck * Aij * rMinusBij4 * radiiMinusr3);
421 }
422
423 private double descreenDerivative(double r, double r2, double radius, double radiusK,
424 double hctScale) {
425 if (perfectHCTScale) {
426 return perfectHCTIntegralDerivative(r, r2, radius, radiusK, hctScale);
427 } else {
428 return integralDerivative(r, r2, radius, radiusK * hctScale);
429 }
430 }
431
432
433
434
435
436
437
438
439
440
441 private double integralDerivative(double r, double r2, double radius, double scaledRadius) {
442 double de = 0.0;
443
444
445 if (scaledRadius > 0.0 && (radius < r + scaledRadius)) {
446
447 if (radius + r < scaledRadius) {
448 double uik = scaledRadius - r;
449 double uik2 = uik * uik;
450 double uik4 = uik2 * uik2;
451 de = -4.0 * PI / uik4;
452 }
453
454
455 double sk2 = scaledRadius * scaledRadius;
456 if (radius + r < scaledRadius) {
457
458 double lik = scaledRadius - r;
459 double lik2 = lik * lik;
460 double lik4 = lik2 * lik2;
461 de = de + 0.25 * PI * (sk2 - 4.0 * scaledRadius * r + 17.0 * r2) / (r2 * lik4);
462 } else if (r < radius + scaledRadius) {
463
464 double lik = radius;
465 double lik2 = lik * lik;
466 double lik4 = lik2 * lik2;
467 de = de + 0.25 * PI * (2.0 * radius * radius - sk2 - r2) / (r2 * lik4);
468 } else {
469
470 double lik = r - scaledRadius;
471 double lik2 = lik * lik;
472 double lik4 = lik2 * lik2;
473 de = de + 0.25 * PI * (sk2 - 4.0 * scaledRadius * r + r2) / (r2 * lik4);
474 }
475
476
477 double uik = r + scaledRadius;
478 double uik2 = uik * uik;
479 double uik4 = uik2 * uik2;
480 de = de - 0.25 * PI * (sk2 + 4.0 * scaledRadius * r + r2) / (r2 * uik4);
481 }
482 return de;
483 }
484
485
486
487
488
489
490
491
492
493
494
495 private double perfectHCTIntegralDerivative(double r, double r2, double radius, double radiusK,
496 double perfectHCT) {
497 double de = 0.0;
498
499
500 if (radiusK > 0.0 && (radius < r + radiusK)) {
501
502 if (radius + r < radiusK) {
503 double uik = radiusK - r;
504 double uik2 = uik * uik;
505 double uik4 = uik2 * uik2;
506 de = -4.0 * PI / uik4;
507 }
508
509
510 double sk2 = radiusK * radiusK;
511 if (radius + r < radiusK) {
512
513 double lik = radiusK - r;
514 double lik2 = lik * lik;
515 double lik4 = lik2 * lik2;
516 de = de + 0.25 * PI * (sk2 - 4.0 * radiusK * r + 17.0 * r2) / (r2 * lik4);
517 } else if (r < radius + radiusK) {
518
519 double lik = radius;
520 double lik2 = lik * lik;
521 double lik4 = lik2 * lik2;
522 de = de + 0.25 * PI * (2.0 * radius * radius - sk2 - r2) / (r2 * lik4);
523 } else {
524
525 double lik = r - radiusK;
526 double lik2 = lik * lik;
527 double lik4 = lik2 * lik2;
528 de = de + 0.25 * PI * (sk2 - 4.0 * radiusK * r + r2) / (r2 * lik4);
529 }
530
531
532 double uik = r + radiusK;
533 double uik2 = uik * uik;
534 double uik4 = uik2 * uik2;
535 de = de - 0.25 * PI * (sk2 + 4.0 * radiusK * r + r2) / (r2 * uik4);
536 }
537 return perfectHCT * de;
538 }
539
540
541
542
543
544
545
546
547
548
549
550 private void incrementGradient(
551 int i, int k, double dE, double xr, double yr, double zr, double[][] transOp) {
552 double dedx = dE * xr;
553 double dedy = dE * yr;
554 double dedz = dE * zr;
555 grad.add(threadID, i, dedx, dedy, dedz);
556 final double dedxk = dedx * transOp[0][0] + dedy * transOp[1][0] + dedz * transOp[2][0];
557 final double dedyk = dedx * transOp[0][1] + dedy * transOp[1][1] + dedz * transOp[2][1];
558 final double dedzk = dedx * transOp[0][2] + dedy * transOp[1][2] + dedz * transOp[2][2];
559 grad.sub(threadID, k, dedxk, dedyk, dedzk);
560 }
561 }
562 }