1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.nonbonded;
39
40 import ffx.crystal.Crystal;
41 import ffx.numerics.tornado.FFXTornado;
42 import ffx.potential.bonded.Atom;
43 import ffx.potential.bonded.Bond;
44 import ffx.potential.parameters.AtomType;
45 import ffx.potential.parameters.ForceField;
46 import ffx.potential.parameters.VDWType;
47 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
48 import uk.ac.manchester.tornado.api.TaskGraph;
49 import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
50 import uk.ac.manchester.tornado.api.annotations.Parallel;
51 import uk.ac.manchester.tornado.api.annotations.Reduce;
52 import uk.ac.manchester.tornado.api.common.TornadoDevice;
53 import uk.ac.manchester.tornado.api.runtime.TornadoRuntime;
54
55 import java.util.List;
56 import java.util.logging.Level;
57 import java.util.logging.Logger;
58
59 import static java.lang.String.format;
60 import static java.util.Arrays.fill;
61 import static uk.ac.manchester.tornado.api.math.TornadoMath.abs;
62 import static uk.ac.manchester.tornado.api.math.TornadoMath.floor;
63 import static uk.ac.manchester.tornado.api.math.TornadoMath.sqrt;
64 import static uk.ac.manchester.tornado.api.enums.DataTransferMode.EVERY_EXECUTION;
65
66
67
68
69
70
71
72
73
74
75
76 public class VanDerWaalsTornado extends VanDerWaals {
77
78 private static final Logger logger = Logger.getLogger(VanDerWaalsTornado.class.getName());
79 private static final byte XX = 0;
80 private static final byte YY = 1;
81 private static final byte ZZ = 2;
82
83 private final VanDerWaalsForm vdwForm;
84
85 private final double vdwTaper;
86
87 private final double vdwCutoff;
88
89
90 private int interactions;
91 private double energy;
92 private double[] grad;
93 private Crystal crystal;
94
95 private Atom[] atoms;
96
97 private ForceField forceField;
98
99 private int nAtoms;
100
101 private int[] atomClass;
102
103 private double[] coordinates;
104
105
106
107
108 private double[] reductionValue;
109
110
111
112
113 private int[] reductionIndex;
114
115 private int[] mask;
116
117 private int[] maskPointer;
118
119
120
121
122
123
124
125
126
127
128 public VanDerWaalsTornado(
129 Atom[] atoms, Crystal crystal, ForceField forceField, double vdwCutoff) {
130 this.atoms = atoms;
131 this.crystal = crystal;
132 this.forceField = forceField;
133 nAtoms = atoms.length;
134 vdwForm = new VanDerWaalsForm(forceField);
135
136
137 initAtomArrays();
138
139
140
141
142
143
144 this.vdwCutoff = vdwCutoff;
145 this.vdwTaper = 0.9 * vdwCutoff;
146 logger.info(toString());
147 }
148
149
150 private static void tornadoEnergy(
151 int[] atomClass,
152 double[] eps,
153 double[] rMin,
154 double[] reducedXYZ,
155 int[] reductionIndex,
156 double[] reductionValue,
157 double[] bondedScaleFactors,
158 int[] maskPointers,
159 int[] masks,
160 double[] A,
161 double[] Ai,
162 double[] cutoffs,
163 @Reduce double[] energy,
164 @Reduce int[] interactions,
165 @Reduce double[] grad) {
166
167
168
169 double A00 = A[0];
170 double A01 = A[1];
171 double A02 = A[2];
172 double A10 = A[3];
173 double A11 = A[4];
174 double A12 = A[5];
175 double A20 = A[6];
176 double A21 = A[7];
177 double A22 = A[8];
178
179 double Ai00 = Ai[0];
180 double Ai01 = Ai[1];
181 double Ai02 = Ai[2];
182
183 double Ai10 = Ai[3];
184 double Ai11 = Ai[4];
185 double Ai12 = Ai[5];
186
187 double Ai20 = Ai[6];
188 double Ai21 = Ai[7];
189 double Ai22 = Ai[8];
190
191
192 double scale12 = bondedScaleFactors[0];
193 double scale13 = bondedScaleFactors[1];
194 double scale14 = bondedScaleFactors[2];
195
196
197 boolean aperiodic = false;
198 if (cutoffs[0] > 0) {
199 aperiodic = true;
200 }
201
202 double vdwTaper = cutoffs[1];
203 double vdwCutoff = cutoffs[2];
204 double vdwTaper2 = vdwTaper * vdwTaper;
205 double vdwCutoff2 = vdwCutoff * vdwCutoff;
206 boolean gradient = false;
207 if (cutoffs[3] > 0) {
208 gradient = true;
209 }
210
211
212 double a = vdwTaper;
213 double b = vdwCutoff;
214 double a2 = a * a;
215 double b2 = b * b;
216 double ba = b - a;
217 double ba2 = ba * ba;
218 double denom = ba * ba2 * ba2;
219 double c0 = b * b2 * (b2 - 5.0 * a * b + 10.0 * a2) / denom;
220 double c1 = -30.0 * a2 * b2 / denom;
221 double c2 = 30.0 * b * a * (b + a) / denom;
222 double c3 = -10.0 * (a2 + 4.0 * a * b + b2) / denom;
223 double c4 = 15.0 * (a + b) / denom;
224 double c5 = -6.0 / denom;
225 double twoC2 = 2.0 * c2;
226 double threeC3 = 3.0 * c3;
227 double fourC4 = 4.0 * c4;
228 double fiveC5 = 5.0 * c5;
229
230
231 final int nAtoms = atomClass.length;
232 double[] mask = new double[nAtoms];
233 for (int i = 0; i < nAtoms; i++) {
234 mask[i] = 1.0;
235 }
236
237
238 final double delta = 0.07;
239 final double gamma = 0.12;
240 final double delta1 = 1.0 + delta;
241 final double d2 = delta1 * delta1;
242 final double d4 = d2 * d2;
243 final double t1n = delta1 * d2 * d4;
244 final double gamma1 = 1.0 + gamma;
245
246 final int XX = 0;
247 final int YY = 1;
248 final int ZZ = 2;
249
250 for (@Parallel int i = 0; i < nAtoms - 1; i++) {
251 final int i3 = i * 3;
252 final double xi = reducedXYZ[i3 + XX];
253 final double yi = reducedXYZ[i3 + YY];
254 final double zi = reducedXYZ[i3 + ZZ];
255 final int redi = reductionIndex[i];
256 final double redv = reductionValue[i];
257 final double rediv = 1.0 - redv;
258 final int classI = atomClass[i];
259 final double ei = eps[classI];
260 final double sei = sqrt(ei);
261 final double ri = rMin[classI];
262 if (ri <= 0.0) {
263 continue;
264 }
265 double gxi = 0.0;
266 double gyi = 0.0;
267 double gzi = 0.0;
268 double gxredi = 0.0;
269 double gyredi = 0.0;
270 double gzredi = 0.0;
271
272
273 for (int ii = maskPointers[i3]; ii < maskPointers[i3 + 1]; ii++) {
274 mask[masks[ii]] = scale12;
275 }
276 for (int ii = maskPointers[i3 + 1]; ii < maskPointers[i3 + 2]; ii++) {
277 mask[masks[ii]] = scale13;
278 }
279 for (int ii = maskPointers[i3 + 2]; ii < maskPointers[i3 + 3]; ii++) {
280 mask[masks[ii]] = scale14;
281 }
282
283
284 for (int k = i + 1; k < nAtoms; k++) {
285 final int k3 = k * 3;
286 final double xk = reducedXYZ[k3 + XX];
287 final double yk = reducedXYZ[k3 + YY];
288 final double zk = reducedXYZ[k3 + ZZ];
289 final double[] dx = new double[3];
290 dx[0] = xi - xk;
291 dx[1] = yi - yk;
292 dx[2] = zi - zk;
293 double x = dx[0];
294 double y = dx[1];
295 double z = dx[2];
296 double r2;
297 if (!aperiodic) {
298 double xf = x * A00 + y * A10 + z * A20;
299 double yf = x * A01 + y * A11 + z * A21;
300 double zf = x * A02 + y * A12 + z * A22;
301
302
303
304
305
306 double xfsn = 0.0;
307 if (-xf > 0.0) {
308 xfsn = 1.0;
309 } else if (-xf < 0.0) {
310 xfsn = -1.0;
311 }
312
313 double yfsn = 0.0;
314 if (-yf > 0.0) {
315 yfsn = 1.0;
316 } else if (-yf < 0.0) {
317 yfsn = -1.0;
318 }
319
320 double zfsn = 0.0;
321 if (-zf > 0.0) {
322 zfsn = 1.0;
323 } else if (-zf < 0.0) {
324 zfsn = -1.0;
325 }
326 xf = floor(abs(xf) + 0.5) * xfsn + xf;
327 yf = floor(abs(yf) + 0.5) * yfsn + yf;
328 zf = floor(abs(zf) + 0.5) * zfsn + zf;
329 x = xf * Ai00 + yf * Ai10 + zf * Ai20;
330 y = xf * Ai01 + yf * Ai11 + zf * Ai21;
331 z = xf * Ai02 + yf * Ai12 + zf * Ai22;
332 dx[0] = x;
333 dx[1] = y;
334 dx[2] = z;
335 }
336 r2 = x * x + y * y + z * z;
337 final int classK = atomClass[k];
338 final double rk = rMin[classK];
339 if (r2 <= vdwCutoff2 && mask[k] > 0 && rk > 0) {
340 double ri2 = ri * ri;
341 double ri3 = ri * ri2;
342 double rk2 = rk * rk;
343 double rk3 = rk * rk2;
344 double irv = 1.0 / (2.0 * (ri3 + rk3) / (ri2 + rk2));
345 final double r = sqrt(r2);
346
347
348
349
350
351
352
353 double ek = eps[classK];
354 double sek = sqrt(ek);
355 double ev = mask[k] * 4.0 * (ei * ek) / ((sei + sek) * (sei + sek));
356 final double rho = r * irv;
357 final double rho2 = rho * rho;
358 final double rhoDisp1 = rho2 * rho2 * rho2;
359 final double rhoDisp = rhoDisp1 * rho;
360 final double rhoD = rho + delta;
361 final double rhoD2 = rhoD * rhoD;
362 final double rhoDelta1 = rhoD2 * rhoD2 * rhoD2;
363 final double rhoDelta = rhoDelta1 * (rho + delta);
364 final double rhoDispGamma = rhoDisp + gamma;
365 final double t1d = 1.0 / rhoDelta;
366 final double t2d = 1.0 / rhoDispGamma;
367 final double t1 = t1n * t1d;
368 final double t2a = gamma1 * t2d;
369 final double t2 = t2a - 2.0;
370 double eik = ev * t1 * t2;
371
372
373
374
375 double taper = 1.0;
376 double dtaper = 0.0;
377 if (r2 > vdwTaper2) {
378 final double r3 = r2 * r;
379 final double r4 = r2 * r2;
380 final double r5 = r2 * r3;
381 taper = c5 * r5 + c4 * r4 + c3 * r3 + c2 * r2 + c1 * r + c0;
382 dtaper = fiveC5 * r4 + fourC4 * r3 + threeC3 * r2 + twoC2 * r + c1;
383 }
384 eik *= taper;
385 energy[0] += eik;
386 interactions[0] += 1;
387 if (!gradient) {
388 continue;
389 }
390 final int redk = reductionIndex[k];
391 final double red = reductionValue[k];
392 final double redkv = 1.0 - red;
393 final double dt1d_dr = 7.0 * rhoDelta1 * irv;
394 final double dt2d_dr = 7.0 * rhoDisp1 * irv;
395 final double dt1_dr = t1 * dt1d_dr * t1d;
396 final double dt2_dr = t2a * dt2d_dr * t2d;
397 final double dedr = -ev * (dt1_dr * t2 + t1 * dt2_dr);
398 final double ir = 1.0 / r;
399 final double drdx = dx[0] * ir;
400 final double drdy = dx[1] * ir;
401 final double drdz = dx[2] * ir;
402 final double dswitch = (eik * dtaper + dedr * taper);
403 final double dedx = dswitch * drdx;
404 final double dedy = dswitch * drdy;
405 final double dedz = dswitch * drdz;
406 gxi += dedx * redv;
407 gyi += dedy * redv;
408 gzi += dedz * redv;
409 gxredi += dedx * rediv;
410 gyredi += dedy * rediv;
411 gzredi += dedz * rediv;
412
413 grad[k3 + XX] -= red * dedx;
414 grad[k3 + YY] -= red * dedy;
415 grad[k3 + ZZ] -= red * dedz;
416
417 int r3 = redk * 3;
418 grad[r3 + XX] -= redkv * dedx;
419 grad[r3 + YY] -= redkv * dedy;
420 grad[r3 + ZZ] -= redkv * dedz;
421 }
422 }
423 if (gradient) {
424
425 grad[i3 + XX] += gxi;
426 grad[i3 + YY] += gyi;
427 grad[i3 + ZZ] += gzi;
428
429 int r3 = redi * 3;
430 grad[r3 + XX] += gxredi;
431 grad[r3 + YY] += gyredi;
432 grad[r3 + ZZ] += gzredi;
433 }
434
435
436 for (int ii = maskPointers[i3]; ii < maskPointers[i3 + 1]; ii++) {
437 mask[masks[ii]] = 1.0;
438 }
439 for (int ii = maskPointers[i3 + 1]; ii < maskPointers[i3 + 2]; ii++) {
440 mask[masks[ii]] = 1.0;
441 }
442 for (int ii = maskPointers[i3 + 2]; ii < maskPointers[i3 + 3]; ii++) {
443 mask[masks[ii]] = 1.0;
444 }
445 }
446 }
447
448
449
450
451
452
453
454
455
456 public double energy(boolean gradient, boolean print) {
457
458 if (vdwForm.vdwType != VanDerWaalsForm.VDW_TYPE.BUFFERED_14_7) {
459 logger.severe((" TornadoVM vdW only supports AMOEBA."));
460 }
461
462
463 for (int i = 0; i < nAtoms; i++) {
464 Atom atom = atoms[i];
465
466 double x = atom.getX();
467 double y = atom.getY();
468 double z = atom.getZ();
469 int i3 = i * 3;
470 coordinates[i3 + XX] = x;
471 coordinates[i3 + YY] = y;
472 coordinates[i3 + ZZ] = z;
473 }
474
475 double[] eps = vdwForm.getEps();
476 double[] rmin = vdwForm.getRmin();
477
478
479 double[] reducedXYZ = new double[nAtoms * 3];
480 final byte XX = 0;
481 final byte YY = 1;
482 final byte ZZ = 2;
483 for (int i = 0; i < nAtoms; i++) {
484 int i3 = i * 3;
485 double x = coordinates[i3 + XX];
486 double y = coordinates[i3 + YY];
487 double z = coordinates[i3 + ZZ];
488 int redIndex = reductionIndex[i];
489 if (redIndex >= 0) {
490 int r3 = redIndex * 3;
491 double rx = coordinates[r3 + XX];
492 double ry = coordinates[r3 + YY];
493 double rz = coordinates[r3 + ZZ];
494 double r = reductionValue[i];
495 reducedXYZ[i3 + XX] = r * (x - rx) + rx;
496 reducedXYZ[i3 + YY] = r * (y - ry) + ry;
497 reducedXYZ[i3 + ZZ] = r * (z - rz) + rz;
498 } else {
499 reducedXYZ[i3 + XX] = x;
500 reducedXYZ[i3 + YY] = y;
501 reducedXYZ[i3 + ZZ] = z;
502 }
503 }
504
505
506 double doGradient = 0.0;
507 if (gradient) {
508 doGradient = 1.0;
509 }
510
511
512 Crystal c = crystal;
513 double[] A = {c.A00, c.A01, c.A02, c.A10, c.A11, c.A12, c.A20, c.A21, c.A22};
514 double[] Ai = {c.Ai00, c.Ai01, c.Ai02, c.Ai10, c.Ai11, c.Ai12, c.Ai20, c.Ai21, c.Ai22};
515 double[] bondedScaleFactors = {vdwForm.scale12, vdwForm.scale13, vdwForm.scale14};
516 double aperiodic = 0.0;
517 if (crystal.aperiodic()) {
518 aperiodic = 1.0;
519 }
520
521
522 double[] cutoffs = {aperiodic, vdwTaper, vdwCutoff, doGradient};
523
524
525 double[] energy = new double[1];
526 int[] interactions = new int[1];
527 if (gradient) {
528 fill(grad, 0.0);
529 }
530
531
532 tornadoEnergy(
533 atomClass,
534 eps,
535 rmin,
536 reducedXYZ,
537 reductionIndex,
538 reductionValue,
539 bondedScaleFactors,
540 maskPointer,
541 mask,
542 A,
543 Ai,
544 cutoffs,
545 energy,
546 interactions,
547 grad);
548
549 logger.info(format(" JVM: %16.8f %d", energy[0], interactions[0]));
550
551 energy[0] = 0.0;
552 interactions[0] = 0;
553 if (gradient) {
554 fill(grad, 0.0);
555 }
556
557 TornadoDevice device = TornadoRuntime.getTornadoRuntime().getDefaultDevice();
558 FFXTornado.logDevice(device);
559 TaskGraph graph =
560 new TaskGraph("vdW").transferToDevice(EVERY_EXECUTION,
561 atomClass, eps, rmin, reducedXYZ, reductionIndex, reductionValue,
562 bondedScaleFactors, maskPointer, mask,
563 A, Ai, cutoffs, energy, interactions, grad)
564 .task("energy", VanDerWaalsTornado::tornadoEnergy,
565 atomClass, eps, rmin, reducedXYZ, reductionIndex, reductionValue,
566 bondedScaleFactors, maskPointer, mask,
567 A, Ai, cutoffs, energy, interactions, grad)
568 .transferToHost(EVERY_EXECUTION, energy, interactions, grad);
569
570 ImmutableTaskGraph itg = graph.snapshot();
571 TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(itg);
572 executionPlan.withWarmUp().withDevice(device);
573 executionPlan.execute();
574
575 logger.info(format(" Tornado OpenCL: %16.8f %d", energy[0], interactions[0]));
576
577
578 if (gradient) {
579 for (int i = 0; i < nAtoms - 1; i++) {
580 Atom ai = atoms[i];
581 int i3 = i * 3;
582 ai.addToXYZGradient(grad[i3 + XX], grad[i3 + YY], grad[i3 + ZZ]);
583 }
584 }
585
586 this.energy = energy[0];
587 this.interactions = interactions[0];
588 return this.energy;
589 }
590
591
592
593
594
595
596
597 public double getEnergy() {
598 return energy;
599 }
600
601
602
603
604
605
606
607 public int getInteractions() {
608 return interactions;
609 }
610
611
612
613
614
615
616 public void setAtoms(Atom[] atoms) {
617 this.atoms = atoms;
618 this.nAtoms = atoms.length;
619 initAtomArrays();
620 }
621
622
623
624
625
626
627
628
629 public void setCrystal(Crystal crystal) {
630 this.crystal = crystal;
631 int newNSymm = crystal.getNumSymOps();
632 if (newNSymm != 1) {
633 String message = " SymOps are not supported by VanDerWaalsTornado.\n";
634 logger.log(Level.SEVERE, message);
635 }
636 }
637
638
639 @Override
640 public String toString() {
641 StringBuffer sb = new StringBuffer("\n Van der Waals\n");
642 sb.append(format(" Switch Start: %6.3f (A)\n", vdwTaper));
643 sb.append(format(" Cut-Off: %6.3f (A)\n", vdwCutoff));
644 return sb.toString();
645 }
646
647
648 private void initAtomArrays() {
649 if (atomClass == null || nAtoms > atomClass.length) {
650 atomClass = new int[nAtoms];
651 coordinates = new double[nAtoms * 3];
652 reductionIndex = new int[nAtoms];
653 reductionValue = new double[nAtoms];
654 grad = new double[nAtoms * 3];
655 maskPointer = new int[nAtoms * 3 + 1];
656 }
657
658 int numBonds = 0;
659 int numAngles = 0;
660 int numTorsions = 0;
661 for (int i = 0; i < nAtoms; i++) {
662 Atom ai = atoms[i];
663 numBonds += ai.getNumBonds();
664 numAngles += ai.getNumAngles();
665 numTorsions += ai.getNumDihedrals();
666 }
667 mask = new int[numBonds + numAngles + numTorsions];
668
669 int[][] mask12 = getMask12();
670 int[][] mask13 = getMask13();
671 int[][] mask14 = getMask14();
672
673 int index = 0;
674 for (int i = 0; i < nAtoms; i++) {
675 Atom ai = atoms[i];
676 assert (i == ai.getXyzIndex() - 1);
677 double[] xyz = ai.getXYZ(null);
678 int i3 = i * 3;
679 coordinates[i3 + XX] = xyz[XX];
680 coordinates[i3 + YY] = xyz[YY];
681 coordinates[i3 + ZZ] = xyz[ZZ];
682 AtomType atomType = ai.getAtomType();
683 if (atomType == null) {
684 logger.severe(ai.toString());
685 }
686 String vdwIndex = forceField.getString("VDWINDEX", "Class");
687 if (vdwIndex.equalsIgnoreCase("Type")) {
688 atomClass[i] = atomType.type;
689 } else {
690 atomClass[i] = atomType.atomClass;
691 }
692 VDWType type = forceField.getVDWType(Integer.toString(atomClass[i]));
693 if (type == null) {
694 logger.info(" No VdW type for atom class " + atomClass[i]);
695 logger.severe(" No VdW type for atom " + ai);
696 return;
697 }
698 ai.setVDWType(type);
699 List<Bond> bonds = ai.getBonds();
700 numBonds = bonds.size();
701 if (type.reductionFactor > 0.0 && numBonds == 1) {
702 Bond bond = bonds.get(0);
703 Atom heavyAtom = bond.get1_2(ai);
704
705 reductionIndex[i] = heavyAtom.getIndex() - 1;
706 reductionValue[i] = type.reductionFactor;
707 } else {
708 reductionIndex[i] = i;
709 reductionValue[i] = 0.0;
710 }
711
712 maskPointer[3 * i] = index;
713 for (int value : mask12[i]) {
714 mask[index++] = value;
715 }
716 maskPointer[3 * i + 1] = index;
717 for (int value : mask13[i]) {
718 mask[index++] = value;
719 }
720 maskPointer[3 * i + 2] = index;
721 for (int value : mask14[i]) {
722 mask[index++] = value;
723 }
724 }
725 maskPointer[3 * nAtoms] = index;
726 }
727
728
729
730
731
732
733
734
735
736
737
738 private void log(int i, int k, double minr, double r, double eij) {
739 logger.info(
740 format(
741 "VDW %6d-%s %6d-%s %10.4f %10.4f %10.4f",
742 atoms[i].getIndex(),
743 atoms[i].getAtomType().name,
744 atoms[k].getIndex(),
745 atoms[k].getAtomType().name,
746 1.0 / minr,
747 r,
748 eij));
749 }
750 }