1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.nonbonded;
39
40 import ffx.crystal.Crystal;
41 import ffx.numerics.tornado.FFXTornado;
42 import ffx.potential.bonded.Atom;
43 import ffx.potential.bonded.Bond;
44 import ffx.potential.parameters.AtomType;
45 import ffx.potential.parameters.ForceField;
46 import ffx.potential.parameters.VDWType;
47 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
48 import uk.ac.manchester.tornado.api.TaskGraph;
49 import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
50 import uk.ac.manchester.tornado.api.annotations.Parallel;
51 import uk.ac.manchester.tornado.api.annotations.Reduce;
52 import uk.ac.manchester.tornado.api.common.TornadoDevice;
53 import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
54
55 import java.util.List;
56 import java.util.logging.Level;
57 import java.util.logging.Logger;
58
59 import static java.lang.String.format;
60 import static java.util.Arrays.fill;
61 import static uk.ac.manchester.tornado.api.enums.DataTransferMode.EVERY_EXECUTION;
62 import static uk.ac.manchester.tornado.api.math.TornadoMath.abs;
63 import static uk.ac.manchester.tornado.api.math.TornadoMath.floor;
64 import static uk.ac.manchester.tornado.api.math.TornadoMath.sqrt;
65
66
67
68
69
70
71
72
73
74
75
76 public class VanDerWaalsTornado extends VanDerWaals {
77
78 private static final Logger logger = Logger.getLogger(VanDerWaalsTornado.class.getName());
79 private static final byte XX = 0;
80 private static final byte YY = 1;
81 private static final byte ZZ = 2;
82
83
84
85 private final VanDerWaalsForm vdwForm;
86
87
88
89 private final double vdwTaper;
90
91 private final double vdwCutoff;
92
93
94 private int interactions;
95 private double energy;
96 private double[] grad;
97 private Crystal crystal;
98
99
100
101 private Atom[] atoms;
102
103
104
105 private ForceField forceField;
106
107
108
109 private int nAtoms;
110
111
112
113 private int[] atomClass;
114
115
116
117 private double[] coordinates;
118
119
120
121
122 private double[] reductionValue;
123
124
125
126
127 private int[] reductionIndex;
128
129
130
131 private int[] mask;
132
133
134
135 private int[] maskPointer;
136
137
138
139
140
141
142
143
144
145
/**
 * Creates the TornadoVM van der Waals term.
 *
 * @param atoms the atoms to evaluate.
 * @param crystal the crystal describing periodicity.
 * @param forceField the force field supplying van der Waals parameters.
 * @param vdwCutoff the pair cutoff distance; the switch begins at 0.9 times this value.
 */
public VanDerWaalsTornado(
    Atom[] atoms, Crystal crystal, ForceField forceField, double vdwCutoff) {
  // Record the system description.
  this.atoms = atoms;
  this.nAtoms = atoms.length;
  this.crystal = crystal;
  this.forceField = forceField;
  this.vdwForm = new VanDerWaalsForm(forceField);

  // Allocate the per-atom arrays and fill classes, reduction data and mask tables.
  initAtomArrays();

  // The switching window spans the last 10% of the cutoff distance.
  this.vdwCutoff = vdwCutoff;
  this.vdwTaper = 0.9 * vdwCutoff;

  // Report the switch start and cutoff.
  logger.info(this.toString());
}
166
167
168
169
170 private static void tornadoEnergy(
171 int[] atomClass,
172 double[] eps,
173 double[] rMin,
174 double[] reducedXYZ,
175 int[] reductionIndex,
176 double[] reductionValue,
177 double[] bondedScaleFactors,
178 int[] maskPointers,
179 int[] masks,
180 double[] A,
181 double[] Ai,
182 double[] cutoffs,
183 @Reduce double[] energy,
184 @Reduce int[] interactions,
185 @Reduce double[] grad) {
186
187
188
189 double A00 = A[0];
190 double A01 = A[1];
191 double A02 = A[2];
192 double A10 = A[3];
193 double A11 = A[4];
194 double A12 = A[5];
195 double A20 = A[6];
196 double A21 = A[7];
197 double A22 = A[8];
198
199 double Ai00 = Ai[0];
200 double Ai01 = Ai[1];
201 double Ai02 = Ai[2];
202
203 double Ai10 = Ai[3];
204 double Ai11 = Ai[4];
205 double Ai12 = Ai[5];
206
207 double Ai20 = Ai[6];
208 double Ai21 = Ai[7];
209 double Ai22 = Ai[8];
210
211
212 double scale12 = bondedScaleFactors[0];
213 double scale13 = bondedScaleFactors[1];
214 double scale14 = bondedScaleFactors[2];
215
216
217 boolean aperiodic = false;
218 if (cutoffs[0] > 0) {
219 aperiodic = true;
220 }
221
222 double vdwTaper = cutoffs[1];
223 double vdwCutoff = cutoffs[2];
224 double vdwTaper2 = vdwTaper * vdwTaper;
225 double vdwCutoff2 = vdwCutoff * vdwCutoff;
226 boolean gradient = false;
227 if (cutoffs[3] > 0) {
228 gradient = true;
229 }
230
231
232 double a = vdwTaper;
233 double b = vdwCutoff;
234 double a2 = a * a;
235 double b2 = b * b;
236 double ba = b - a;
237 double ba2 = ba * ba;
238 double denom = ba * ba2 * ba2;
239 double c0 = b * b2 * (b2 - 5.0 * a * b + 10.0 * a2) / denom;
240 double c1 = -30.0 * a2 * b2 / denom;
241 double c2 = 30.0 * b * a * (b + a) / denom;
242 double c3 = -10.0 * (a2 + 4.0 * a * b + b2) / denom;
243 double c4 = 15.0 * (a + b) / denom;
244 double c5 = -6.0 / denom;
245 double twoC2 = 2.0 * c2;
246 double threeC3 = 3.0 * c3;
247 double fourC4 = 4.0 * c4;
248 double fiveC5 = 5.0 * c5;
249
250
251 final int nAtoms = atomClass.length;
252 double[] mask = new double[nAtoms];
253 for (int i = 0; i < nAtoms; i++) {
254 mask[i] = 1.0;
255 }
256
257
258 final double delta = 0.07;
259 final double gamma = 0.12;
260 final double delta1 = 1.0 + delta;
261 final double d2 = delta1 * delta1;
262 final double d4 = d2 * d2;
263 final double t1n = delta1 * d2 * d4;
264 final double gamma1 = 1.0 + gamma;
265
266 final int XX = 0;
267 final int YY = 1;
268 final int ZZ = 2;
269
270 for (@Parallel int i = 0; i < nAtoms - 1; i++) {
271 final int i3 = i * 3;
272 final double xi = reducedXYZ[i3 + XX];
273 final double yi = reducedXYZ[i3 + YY];
274 final double zi = reducedXYZ[i3 + ZZ];
275 final int redi = reductionIndex[i];
276 final double redv = reductionValue[i];
277 final double rediv = 1.0 - redv;
278 final int classI = atomClass[i];
279 final double ei = eps[classI];
280 final double sei = sqrt(ei);
281 final double ri = rMin[classI];
282 if (ri <= 0.0) {
283 continue;
284 }
285 double gxi = 0.0;
286 double gyi = 0.0;
287 double gzi = 0.0;
288 double gxredi = 0.0;
289 double gyredi = 0.0;
290 double gzredi = 0.0;
291
292
293 for (int ii = maskPointers[i3]; ii < maskPointers[i3 + 1]; ii++) {
294 mask[masks[ii]] = scale12;
295 }
296 for (int ii = maskPointers[i3 + 1]; ii < maskPointers[i3 + 2]; ii++) {
297 mask[masks[ii]] = scale13;
298 }
299 for (int ii = maskPointers[i3 + 2]; ii < maskPointers[i3 + 3]; ii++) {
300 mask[masks[ii]] = scale14;
301 }
302
303
304 for (int k = i + 1; k < nAtoms; k++) {
305 final int k3 = k * 3;
306 final double xk = reducedXYZ[k3 + XX];
307 final double yk = reducedXYZ[k3 + YY];
308 final double zk = reducedXYZ[k3 + ZZ];
309 final double[] dx = new double[3];
310 dx[0] = xi - xk;
311 dx[1] = yi - yk;
312 dx[2] = zi - zk;
313 double x = dx[0];
314 double y = dx[1];
315 double z = dx[2];
316 double r2;
317 if (!aperiodic) {
318 double xf = x * A00 + y * A10 + z * A20;
319 double yf = x * A01 + y * A11 + z * A21;
320 double zf = x * A02 + y * A12 + z * A22;
321
322
323
324
325
326 double xfsn = 0.0;
327 if (-xf > 0.0) {
328 xfsn = 1.0;
329 } else if (-xf < 0.0) {
330 xfsn = -1.0;
331 }
332
333 double yfsn = 0.0;
334 if (-yf > 0.0) {
335 yfsn = 1.0;
336 } else if (-yf < 0.0) {
337 yfsn = -1.0;
338 }
339
340 double zfsn = 0.0;
341 if (-zf > 0.0) {
342 zfsn = 1.0;
343 } else if (-zf < 0.0) {
344 zfsn = -1.0;
345 }
346 xf = floor(abs(xf) + 0.5) * xfsn + xf;
347 yf = floor(abs(yf) + 0.5) * yfsn + yf;
348 zf = floor(abs(zf) + 0.5) * zfsn + zf;
349 x = xf * Ai00 + yf * Ai10 + zf * Ai20;
350 y = xf * Ai01 + yf * Ai11 + zf * Ai21;
351 z = xf * Ai02 + yf * Ai12 + zf * Ai22;
352 dx[0] = x;
353 dx[1] = y;
354 dx[2] = z;
355 }
356 r2 = x * x + y * y + z * z;
357 final int classK = atomClass[k];
358 final double rk = rMin[classK];
359 if (r2 <= vdwCutoff2 && mask[k] > 0 && rk > 0) {
360 double ri2 = ri * ri;
361 double ri3 = ri * ri2;
362 double rk2 = rk * rk;
363 double rk3 = rk * rk2;
364 double irv = 1.0 / (2.0 * (ri3 + rk3) / (ri2 + rk2));
365 final double r = sqrt(r2);
366
367
368
369
370
371
372
373 double ek = eps[classK];
374 double sek = sqrt(ek);
375 double ev = mask[k] * 4.0 * (ei * ek) / ((sei + sek) * (sei + sek));
376 final double rho = r * irv;
377 final double rho2 = rho * rho;
378 final double rhoDisp1 = rho2 * rho2 * rho2;
379 final double rhoDisp = rhoDisp1 * rho;
380 final double rhoD = rho + delta;
381 final double rhoD2 = rhoD * rhoD;
382 final double rhoDelta1 = rhoD2 * rhoD2 * rhoD2;
383 final double rhoDelta = rhoDelta1 * (rho + delta);
384 final double rhoDispGamma = rhoDisp + gamma;
385 final double t1d = 1.0 / rhoDelta;
386 final double t2d = 1.0 / rhoDispGamma;
387 final double t1 = t1n * t1d;
388 final double t2a = gamma1 * t2d;
389 final double t2 = t2a - 2.0;
390 double eik = ev * t1 * t2;
391
392
393
394
395 double taper = 1.0;
396 double dtaper = 0.0;
397 if (r2 > vdwTaper2) {
398 final double r3 = r2 * r;
399 final double r4 = r2 * r2;
400 final double r5 = r2 * r3;
401 taper = c5 * r5 + c4 * r4 + c3 * r3 + c2 * r2 + c1 * r + c0;
402 dtaper = fiveC5 * r4 + fourC4 * r3 + threeC3 * r2 + twoC2 * r + c1;
403 }
404 eik *= taper;
405 energy[0] += eik;
406 interactions[0] += 1;
407 if (!gradient) {
408 continue;
409 }
410 final int redk = reductionIndex[k];
411 final double red = reductionValue[k];
412 final double redkv = 1.0 - red;
413 final double dt1d_dr = 7.0 * rhoDelta1 * irv;
414 final double dt2d_dr = 7.0 * rhoDisp1 * irv;
415 final double dt1_dr = t1 * dt1d_dr * t1d;
416 final double dt2_dr = t2a * dt2d_dr * t2d;
417 final double dedr = -ev * (dt1_dr * t2 + t1 * dt2_dr);
418 final double ir = 1.0 / r;
419 final double drdx = dx[0] * ir;
420 final double drdy = dx[1] * ir;
421 final double drdz = dx[2] * ir;
422 final double dswitch = (eik * dtaper + dedr * taper);
423 final double dedx = dswitch * drdx;
424 final double dedy = dswitch * drdy;
425 final double dedz = dswitch * drdz;
426 gxi += dedx * redv;
427 gyi += dedy * redv;
428 gzi += dedz * redv;
429 gxredi += dedx * rediv;
430 gyredi += dedy * rediv;
431 gzredi += dedz * rediv;
432
433 grad[k3 + XX] -= red * dedx;
434 grad[k3 + YY] -= red * dedy;
435 grad[k3 + ZZ] -= red * dedz;
436
437 int r3 = redk * 3;
438 grad[r3 + XX] -= redkv * dedx;
439 grad[r3 + YY] -= redkv * dedy;
440 grad[r3 + ZZ] -= redkv * dedz;
441 }
442 }
443 if (gradient) {
444
445 grad[i3 + XX] += gxi;
446 grad[i3 + YY] += gyi;
447 grad[i3 + ZZ] += gzi;
448
449 int r3 = redi * 3;
450 grad[r3 + XX] += gxredi;
451 grad[r3 + YY] += gyredi;
452 grad[r3 + ZZ] += gzredi;
453 }
454
455
456 for (int ii = maskPointers[i3]; ii < maskPointers[i3 + 1]; ii++) {
457 mask[masks[ii]] = 1.0;
458 }
459 for (int ii = maskPointers[i3 + 1]; ii < maskPointers[i3 + 2]; ii++) {
460 mask[masks[ii]] = 1.0;
461 }
462 for (int ii = maskPointers[i3 + 2]; ii < maskPointers[i3 + 3]; ii++) {
463 mask[masks[ii]] = 1.0;
464 }
465 }
466 }
467
468
469
470
471
472
473
474
475
476 public double energy(boolean gradient, boolean print) {
477
478 if (vdwForm.vdwType != VDWType.VDW_TYPE.BUFFERED_14_7) {
479 logger.severe((" TornadoVM vdW only supports AMOEBA."));
480 }
481
482
483 for (int i = 0; i < nAtoms; i++) {
484 Atom atom = atoms[i];
485
486 double x = atom.getX();
487 double y = atom.getY();
488 double z = atom.getZ();
489 int i3 = i * 3;
490 coordinates[i3 + XX] = x;
491 coordinates[i3 + YY] = y;
492 coordinates[i3 + ZZ] = z;
493 }
494
495 double[] eps = vdwForm.getEps();
496 double[] rmin = vdwForm.getRmin();
497
498
499 double[] reducedXYZ = new double[nAtoms * 3];
500 final byte XX = 0;
501 final byte YY = 1;
502 final byte ZZ = 2;
503 for (int i = 0; i < nAtoms; i++) {
504 int i3 = i * 3;
505 double x = coordinates[i3 + XX];
506 double y = coordinates[i3 + YY];
507 double z = coordinates[i3 + ZZ];
508 int redIndex = reductionIndex[i];
509 if (redIndex >= 0) {
510 int r3 = redIndex * 3;
511 double rx = coordinates[r3 + XX];
512 double ry = coordinates[r3 + YY];
513 double rz = coordinates[r3 + ZZ];
514 double r = reductionValue[i];
515 reducedXYZ[i3 + XX] = r * (x - rx) + rx;
516 reducedXYZ[i3 + YY] = r * (y - ry) + ry;
517 reducedXYZ[i3 + ZZ] = r * (z - rz) + rz;
518 } else {
519 reducedXYZ[i3 + XX] = x;
520 reducedXYZ[i3 + YY] = y;
521 reducedXYZ[i3 + ZZ] = z;
522 }
523 }
524
525
526 double doGradient = 0.0;
527 if (gradient) {
528 doGradient = 1.0;
529 }
530
531
532 Crystal c = crystal;
533 double[] A = {c.A00, c.A01, c.A02, c.A10, c.A11, c.A12, c.A20, c.A21, c.A22};
534 double[] Ai = {c.Ai00, c.Ai01, c.Ai02, c.Ai10, c.Ai11, c.Ai12, c.Ai20, c.Ai21, c.Ai22};
535 double[] bondedScaleFactors = {vdwForm.scale12, vdwForm.scale13, vdwForm.scale14};
536 double aperiodic = 0.0;
537 if (crystal.aperiodic()) {
538 aperiodic = 1.0;
539 }
540
541
542 double[] cutoffs = {aperiodic, vdwTaper, vdwCutoff, doGradient};
543
544
545 double[] energy = new double[1];
546 int[] interactions = new int[1];
547 if (gradient) {
548 fill(grad, 0.0);
549 }
550
551
552 tornadoEnergy(
553 atomClass,
554 eps,
555 rmin,
556 reducedXYZ,
557 reductionIndex,
558 reductionValue,
559 bondedScaleFactors,
560 maskPointer,
561 mask,
562 A,
563 Ai,
564 cutoffs,
565 energy,
566 interactions,
567 grad);
568
569 logger.info(format(" JVM: %16.8f %d", energy[0], interactions[0]));
570
571 energy[0] = 0.0;
572 interactions[0] = 0;
573 if (gradient) {
574 fill(grad, 0.0);
575 }
576
577 TornadoDevice device = TornadoRuntimeProvider.getTornadoRuntime().getDefaultDevice();
578 FFXTornado.logDevice(device);
579 TaskGraph graph =
580 new TaskGraph("vdW").transferToDevice(EVERY_EXECUTION,
581 atomClass, eps, rmin, reducedXYZ, reductionIndex, reductionValue,
582 bondedScaleFactors, maskPointer, mask,
583 A, Ai, cutoffs, energy, interactions, grad)
584 .task("energy", VanDerWaalsTornado::tornadoEnergy,
585 atomClass, eps, rmin, reducedXYZ, reductionIndex, reductionValue,
586 bondedScaleFactors, maskPointer, mask,
587 A, Ai, cutoffs, energy, interactions, grad)
588 .transferToHost(EVERY_EXECUTION, energy, interactions, grad);
589
590 ImmutableTaskGraph itg = graph.snapshot();
591 TornadoExecutionPlan executionPlan = new TornadoExecutionPlan(itg);
592 executionPlan.withWarmUp().withDevice(device);
593 executionPlan.execute();
594
595 logger.info(format(" Tornado OpenCL: %16.8f %d", energy[0], interactions[0]));
596
597
598 if (gradient) {
599 for (int i = 0; i < nAtoms - 1; i++) {
600 Atom ai = atoms[i];
601 int i3 = i * 3;
602 ai.addToXYZGradient(grad[i3 + XX], grad[i3 + YY], grad[i3 + ZZ]);
603 }
604 }
605
606 this.energy = energy[0];
607 this.interactions = interactions[0];
608 return this.energy;
609 }
610
611
612
613
614
615
616
617 public double getEnergy() {
618 return energy;
619 }
620
621
622
623
624
625
626
627 public int getInteractions() {
628 return interactions;
629 }
630
631
632
633
634
635
636 public void setAtoms(Atom[] atoms) {
637 this.atoms = atoms;
638 this.nAtoms = atoms.length;
639 initAtomArrays();
640 }
641
642
643
644
645
646
647
648
649 public void setCrystal(Crystal crystal) {
650 this.crystal = crystal;
651 int newNSymm = crystal.getNumSymOps();
652 if (newNSymm != 1) {
653 String message = " SymOps are not supported by VanDerWaalsTornado.\n";
654 logger.log(Level.SEVERE, message);
655 }
656 }
657
658
659
660
661 @Override
662 public String toString() {
663 StringBuffer sb = new StringBuffer("\n Van der Waals\n");
664 sb.append(format(" Switch Start: %6.3f (A)\n", vdwTaper));
665 sb.append(format(" Cut-Off: %6.3f (A)\n", vdwCutoff));
666 return sb.toString();
667 }
668
669
670
671
672 private void initAtomArrays() {
673 if (atomClass == null || nAtoms > atomClass.length) {
674 atomClass = new int[nAtoms];
675 coordinates = new double[nAtoms * 3];
676 reductionIndex = new int[nAtoms];
677 reductionValue = new double[nAtoms];
678 grad = new double[nAtoms * 3];
679 maskPointer = new int[nAtoms * 3 + 1];
680 }
681
682 int numBonds = 0;
683 int numAngles = 0;
684 int numTorsions = 0;
685 for (int i = 0; i < nAtoms; i++) {
686 Atom ai = atoms[i];
687 numBonds += ai.getNumBonds();
688 numAngles += ai.getNumAngles();
689 numTorsions += ai.getNumDihedrals();
690 }
691 mask = new int[numBonds + numAngles + numTorsions];
692
693 int[][] mask12 = getMask12();
694 int[][] mask13 = getMask13();
695 int[][] mask14 = getMask14();
696
697 int index = 0;
698 for (int i = 0; i < nAtoms; i++) {
699 Atom ai = atoms[i];
700 assert (i == ai.getXyzIndex() - 1);
701 double[] xyz = ai.getXYZ(null);
702 int i3 = i * 3;
703 coordinates[i3 + XX] = xyz[XX];
704 coordinates[i3 + YY] = xyz[YY];
705 coordinates[i3 + ZZ] = xyz[ZZ];
706 AtomType atomType = ai.getAtomType();
707 if (atomType == null) {
708 logger.severe(ai.toString());
709 }
710 String vdwIndex = forceField.getString("VDWINDEX", "Class");
711 if (vdwIndex.equalsIgnoreCase("Type")) {
712 atomClass[i] = atomType.type;
713 } else {
714 atomClass[i] = atomType.atomClass;
715 }
716 VDWType type = forceField.getVDWType(Integer.toString(atomClass[i]));
717 if (type == null) {
718 logger.info(" No VdW type for atom class " + atomClass[i]);
719 logger.severe(" No VdW type for atom " + ai);
720 return;
721 }
722 ai.setVDWType(type);
723 List<Bond> bonds = ai.getBonds();
724 numBonds = bonds.size();
725 if (type.reductionFactor > 0.0 && numBonds == 1) {
726 Bond bond = bonds.get(0);
727 Atom heavyAtom = bond.get1_2(ai);
728
729 reductionIndex[i] = heavyAtom.getIndex() - 1;
730 reductionValue[i] = type.reductionFactor;
731 } else {
732 reductionIndex[i] = i;
733 reductionValue[i] = 0.0;
734 }
735
736 maskPointer[3 * i] = index;
737 for (int value : mask12[i]) {
738 mask[index++] = value;
739 }
740 maskPointer[3 * i + 1] = index;
741 for (int value : mask13[i]) {
742 mask[index++] = value;
743 }
744 maskPointer[3 * i + 2] = index;
745 for (int value : mask14[i]) {
746 mask[index++] = value;
747 }
748 }
749 maskPointer[3 * nAtoms] = index;
750 }
751
752
753
754
755
756
757
758
759
760
761
762 private void log(int i, int k, double minr, double r, double eij) {
763 logger.info(
764 format(
765 "VDW %6d-%s %6d-%s %10.4f %10.4f %10.4f",
766 atoms[i].getIndex(),
767 atoms[i].getAtomType().name,
768 atoms[k].getIndex(),
769 atoms[k].getAtomType().name,
770 1.0 / minr,
771 r,
772 eij));
773 }
774 }