1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.IntegerSchedule;
42 import edu.rit.pj.ParallelRegion;
43 import edu.rit.pj.ParallelTeam;
44
45 import javax.annotation.Nullable;
46 import java.util.Arrays;
47 import java.util.Random;
48 import java.util.logging.Level;
49 import java.util.logging.Logger;
50
51 import jdk.incubator.vector.DoubleVector;
52 import jdk.incubator.vector.VectorShuffle;
53 import jdk.incubator.vector.VectorSpecies;
54
55 import static java.lang.String.format;
56 import static java.lang.System.arraycopy;
57 import static java.util.Objects.requireNonNullElseGet;
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75 public class Complex3DParallel {
76
77 private static final Logger logger = Logger.getLogger(Complex3DParallel.class.getName());
78
79
80
81 private final int nX;
82
83
84
85 private final int nY;
86
87
88
89 private final int nZ;
90
91
92
93 private final int im;
94
95
96
97 private final int ii;
98
99
100
101 private final int nextX;
102
103
104
105 private final int nextY;
106
107
108
109 private final int nextZ;
110
111
112
113 private final int trNextX;
114
115
116
117 private final int trNextY;
118
119
120
121 private final int trNextZ;
122
123
124
125 private final double[] recip;
126
127
128
129 private final int threadCount;
130
131
132
133 private final ParallelTeam parallelTeam;
134
135
136
137 private final Complex2D[] fftXY;
138
139
140
141
142 private final Complex[] fftZ;
143
144
145
146 private final int internalImZ;
147
148
149
150 private final IntegerSchedule schedule;
151
152
153
154 private final int nXm1;
155
156
157
158 private final int nYm1;
159
160
161
162 private final int nZm1;
163
164
165
166 private final FFTRegion fftRegion;
167
168
169
170 private final IFFTRegion ifftRegion;
171
172
173
174 private final ConvolutionRegion convRegion;
175
176
177
178
179 public double[] input;
180
181
182
183 private boolean useSIMD;
184 private final VectorSpecies<Double> species = DoubleVector.SPECIES_PREFERRED;
185 private final int vectorSize = species.length();
186 private final int[] shuffle = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
187 private final VectorShuffle<Double> expandFirstHalf = VectorShuffle.fromArray(species, shuffle, 0);
188 private final VectorShuffle<Double> expandSecondHalf = VectorShuffle.fromArray(species, shuffle, vectorSize);
189
190
191
192
193 private boolean packFFTs;
194 private final boolean localZTranspose;
195
196
197
198 public final double[] work3D;
199
200
201
202
203
204
205
206
207
208
209 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
210 this(nX, nY, nZ, parallelTeam, DataLayout3D.INTERLEAVED);
211 }
212
213
214
215
216
217
218
219
220
221
222
223 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, DataLayout3D dataLayout) {
224 this(nX, nY, nZ, parallelTeam, null, dataLayout);
225 }
226
227
228
229
230
231
232
233
234
235
236
237 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
238 this(nX, nY, nZ, parallelTeam, integerSchedule, DataLayout3D.INTERLEAVED);
239 }
240
241
242
243
244
245
246
247
248
249
250
251
252 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam,
253 @Nullable IntegerSchedule integerSchedule, DataLayout3D dataLayout) {
254 this.nX = nX;
255 this.nY = nY;
256 this.nZ = nZ;
257 this.parallelTeam = parallelTeam;
258 recip = new double[nX * nY * nZ];
259
260 DataLayout1D dataLayout1D;
261 DataLayout2D dataLayout2D;
262 switch (dataLayout) {
263 default:
264 case INTERLEAVED:
265
266 im = 1;
267 ii = 2;
268 nextX = 2;
269 nextY = 2 * nX;
270 nextZ = 2 * nX * nY;
271
272 dataLayout1D = DataLayout1D.INTERLEAVED;
273 dataLayout2D = DataLayout2D.INTERLEAVED;
274
275 internalImZ = 1;
276 trNextY = 2;
277 trNextZ = 2 * nY;
278 trNextX = 2 * nY * nZ;
279 break;
280 case BLOCKED_X:
281
282 im = nX;
283 ii = 1;
284 nextX = 1;
285 nextY = 2 * nX;
286 nextZ = 2 * nX * nY;
287
288 dataLayout1D = DataLayout1D.BLOCKED;
289 dataLayout2D = DataLayout2D.BLOCKED_X;
290
291 internalImZ = nY * nZ;
292 trNextY = 1;
293 trNextZ = nY;
294 trNextX = 2 * nY * nZ;
295 break;
296 case BLOCKED_XY:
297
298 im = nX * nY;
299 ii = 1;
300 nextX = 1;
301 nextY = nX;
302 nextZ = 2 * nY * nX;
303
304 dataLayout1D = DataLayout1D.BLOCKED;
305 dataLayout2D = DataLayout2D.BLOCKED_XY;
306
307 internalImZ = nY * nZ;
308 trNextY = 1;
309 trNextZ = nY;
310 trNextX = 2 * nY * nZ;
311 break;
312 case BLOCKED_XYZ:
313
314 im = nX * nY * nZ;
315 ii = 1;
316 nextX = 1;
317 nextY = nX;
318 nextZ = nY * nX;
319
320 dataLayout1D = DataLayout1D.BLOCKED;
321 dataLayout2D = DataLayout2D.BLOCKED_XY;
322
323 internalImZ = nY * nZ;
324 trNextY = 1;
325 trNextZ = nY;
326 trNextX = 2 * nY * nZ;
327 break;
328 }
329
330 nXm1 = this.nX - 1;
331 nYm1 = this.nY - 1;
332 nZm1 = this.nZ - 1;
333 threadCount = parallelTeam.getThreadCount();
334 schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
335
336
337 useSIMD = true;
338 String simd = System.getProperty("fft.simd", Boolean.toString(useSIMD));
339 try {
340 useSIMD = Boolean.parseBoolean(simd);
341 } catch (Exception e) {
342 useSIMD = false;
343 }
344
345
346 packFFTs = true;
347 String pack = System.getProperty("fft.pack", Boolean.toString(packFFTs));
348 try {
349 packFFTs = Boolean.parseBoolean(pack);
350 } catch (Exception e) {
351 packFFTs = false;
352 }
353
354
355 String localTranspose = System.getProperty("fft.localZTranspose", "false");
356 boolean local;
357 try {
358 local = Boolean.parseBoolean(localTranspose);
359 } catch (Exception e) {
360 local = true;
361 }
362 localZTranspose = local;
363
364 fftXY = new Complex2D[threadCount];
365 for (int i = 0; i < threadCount; i++) {
366 fftXY[i] = new Complex2D(nX, nY, dataLayout2D, im);
367 fftXY[i].setPackFFTs(packFFTs);
368 fftXY[i].setUseSIMD(useSIMD);
369 }
370
371 fftZ = new Complex[threadCount];
372 for (int i = 0; i < threadCount; i++) {
373 fftZ[i] = new Complex(nZ, dataLayout1D, internalImZ, nY);
374 fftZ[i].setUseSIMD(useSIMD);
375 }
376
377 if (localZTranspose) {
378 work3D = null;
379 } else {
380 work3D = new double[2 * nX * nY * nZ];
381 }
382
383 fftRegion = new FFTRegion();
384 ifftRegion = new IFFTRegion();
385 convRegion = new ConvolutionRegion();
386 }
387
388 public String toString() {
389 return "Complex3DParallel {" +
390 "nX=" + nX +
391 ", nY=" + nY +
392 ", nZ=" + nZ +
393 ", im=" + im +
394 ", ii=" + ii +
395 ", nextX=" + nextX +
396 ", nextY=" + nextY +
397 ", nextZ=" + nextZ +
398 ", threadCount=" + threadCount +
399 ", parallelTeam=" + parallelTeam +
400 ", internalImZ=" + internalImZ +
401 ", schedule=" + schedule +
402 ", nXm1=" + nXm1 +
403 ", nYm1=" + nYm1 +
404 ", nZm1=" + nZm1 +
405 ", fftRegion=" + fftRegion +
406 ", ifftRegion=" + ifftRegion +
407 ", convRegion=" + convRegion +
408 ", useSIMD=" + useSIMD +
409 ", packFFTs=" + packFFTs +
410 '}';
411 }
412
413
414
415
416
417
418 public void setUseSIMD(boolean useSIMD) {
419 this.useSIMD = useSIMD;
420 for (int i = 0; i < threadCount; i++) {
421 fftXY[i].setUseSIMD(useSIMD);
422 fftZ[i].setUseSIMD(useSIMD);
423 }
424 }
425
426
427
428
429
430
431 public void setPackFFTs(boolean packFFTs) {
432 this.packFFTs = packFFTs;
433 for (int i = 0; i < threadCount; i++) {
434 fftXY[i].setPackFFTs(packFFTs);
435 }
436 }
437
438
439
440
441
442
443
444
445 public void convolution(final double[] input) {
446 this.input = input;
447 try {
448 parallelTeam.execute(convRegion);
449 } catch (Exception e) {
450 String message = "Fatal exception evaluating a convolution.\n";
451 logger.log(Level.SEVERE, message, e);
452 }
453 }
454
455
456
457
458
459
460
461 public void fft(final double[] input) {
462 this.input = input;
463 try {
464 parallelTeam.execute(fftRegion);
465 } catch (Exception e) {
466 String message = " Fatal exception evaluating the FFT.\n";
467 logger.log(Level.SEVERE, message, e);
468 }
469 }
470
471
472
473
474
475
476 public long[] getTiming() {
477 return convRegion.getTiming();
478 }
479
480
481
482
483
484
485
486 public void ifft(final double[] input) {
487 this.input = input;
488 try {
489 parallelTeam.execute(ifftRegion);
490 } catch (Exception e) {
491 String message = "Fatal exception evaluating the inverse FFT.\n";
492 logger.log(Level.SEVERE, message, e);
493 System.exit(-1);
494 }
495 }
496
497
498
499
500 public void initTiming() {
501 convRegion.initTiming();
502 }
503
504
505
506
507
508
509 public String timingString() {
510 return convRegion.timingString();
511 }
512
513
514
515
516
517
518 public void setRecip(double[] recip) {
519
520
521
522
523
524
525 int recipNextY = nX;
526 int recipNextZ = nY * nX;
527 int index = 0;
528 for (int x = 0; x < nX; x++) {
529 int dx = x;
530 for (int z = 0; z < nZ; z++) {
531 int dz = dx + z * recipNextZ;
532 for (int y = 0; y < nY; y++) {
533 int conv = y * recipNextY + dz;
534 this.recip[index] = recip[conv];
535 index++;
536 }
537 }
538 }
539 }
540
541
542
543
544
545
546
547
548
549
550
551
552 private class FFTRegion extends ParallelRegion {
553
554 private final FFTXYLoop[] fftXYLoop;
555 private final FFTZLoop[] fftZLoop;
556 private final TransposeLoop[] transposeLoop;
557 private final UnTransposeLoop[] unTransposeLoop;
558
559 private FFTRegion() {
560 fftXYLoop = new FFTXYLoop[threadCount];
561 fftZLoop = new FFTZLoop[threadCount];
562 for (int i = 0; i < threadCount; i++) {
563 fftXYLoop[i] = new FFTXYLoop();
564 fftZLoop[i] = new FFTZLoop();
565 }
566 if (!localZTranspose) {
567 transposeLoop = new TransposeLoop[threadCount];
568 unTransposeLoop = new UnTransposeLoop[threadCount];
569 for (int i = 0; i < threadCount; i++) {
570 transposeLoop[i] = new TransposeLoop();
571 unTransposeLoop[i] = new UnTransposeLoop();
572 }
573 } else {
574 transposeLoop = null;
575 unTransposeLoop = null;
576 }
577 }
578
579 @Override
580 public void run() {
581 int threadIndex = getThreadIndex();
582 try {
583 if (localZTranspose) {
584 execute(0, nZm1, fftXYLoop[threadIndex]);
585 execute(0, nXm1, fftZLoop[threadIndex]);
586 } else {
587 execute(0, nZm1, fftXYLoop[threadIndex]);
588 execute(0, nXm1, transposeLoop[threadIndex]);
589 execute(0, nXm1, fftZLoop[threadIndex]);
590 execute(0, nZm1, unTransposeLoop[threadIndex]);
591 }
592 } catch (Exception e) {
593 logger.severe(e.toString());
594 }
595 }
596 }
597
598
599
600
601
602
603
604
605
606
607
608
609 private class IFFTRegion extends ParallelRegion {
610
611 private final IFFTXYLoop[] ifftXYLoop;
612 private final IFFTZLoop[] ifftZLoop;
613 private final TransposeLoop[] transposeLoop;
614 private final UnTransposeLoop[] unTransposeLoop;
615
616 private IFFTRegion() {
617 ifftXYLoop = new IFFTXYLoop[threadCount];
618 ifftZLoop = new IFFTZLoop[threadCount];
619 for (int i = 0; i < threadCount; i++) {
620 ifftXYLoop[i] = new IFFTXYLoop();
621 ifftZLoop[i] = new IFFTZLoop();
622 }
623 if (!localZTranspose) {
624 transposeLoop = new TransposeLoop[threadCount];
625 unTransposeLoop = new UnTransposeLoop[threadCount];
626 for (int i = 0; i < threadCount; i++) {
627 transposeLoop[i] = new TransposeLoop();
628 unTransposeLoop[i] = new UnTransposeLoop();
629 }
630 } else {
631 transposeLoop = null;
632 unTransposeLoop = null;
633 }
634 }
635
636 @Override
637 public void run() {
638 int threadIndex = getThreadIndex();
639 try {
640 if (localZTranspose) {
641 execute(0, nXm1, ifftZLoop[threadIndex]);
642 execute(0, nZm1, ifftXYLoop[threadIndex]);
643 } else {
644 execute(0, nXm1, transposeLoop[threadIndex]);
645 execute(0, nXm1, ifftZLoop[threadIndex]);
646 execute(0, nZm1, unTransposeLoop[threadIndex]);
647 execute(0, nZm1, ifftXYLoop[threadIndex]);
648 }
649 } catch (Exception e) {
650 logger.severe(e.toString());
651 }
652 }
653 }
654
655
656
657
658
659
660
661
662
663
664
665
666
667 private class ConvolutionRegion extends ParallelRegion {
668
669 private final FFTXYLoop[] fftXYLoop;
670 private final TransposeLoop[] transposeLoop;
671 private final FFTZIZLoop[] fftZIZLoop;
672 private final UnTransposeLoop[] unTransposeLoop;
673 private final IFFTXYLoop[] ifftXYLoop;
674 private final long[] convTime;
675
676 private ConvolutionRegion() {
677 fftXYLoop = new FFTXYLoop[threadCount];
678 fftZIZLoop = new FFTZIZLoop[threadCount];
679 ifftXYLoop = new IFFTXYLoop[threadCount];
680 convTime = new long[threadCount];
681 for (int i = 0; i < threadCount; i++) {
682 fftXYLoop[i] = new FFTXYLoop();
683 fftZIZLoop[i] = new FFTZIZLoop();
684 ifftXYLoop[i] = new IFFTXYLoop();
685 }
686 if (!localZTranspose) {
687 transposeLoop = new TransposeLoop[threadCount];
688 unTransposeLoop = new UnTransposeLoop[threadCount];
689 for (int i = 0; i < threadCount; i++) {
690 transposeLoop[i] = new TransposeLoop();
691 unTransposeLoop[i] = new UnTransposeLoop();
692 }
693 } else {
694 transposeLoop = null;
695 unTransposeLoop = null;
696 }
697 }
698
699 public void initTiming() {
700 for (int i = 0; i < threadCount; i++) {
701 fftXYLoop[i].time = 0;
702 fftZIZLoop[i].time = 0;
703 ifftXYLoop[i].time = 0;
704 }
705 if (!localZTranspose) {
706 for (int i = 0; i < threadCount; i++) {
707 transposeLoop[i].time = 0;
708 unTransposeLoop[i].time = 0;
709 }
710 }
711 }
712
713 public long[] getTiming() {
714 if (localZTranspose) {
715 for (int i = 0; i < threadCount; i++) {
716 convTime[i] = convRegion.fftXYLoop[i].time
717 + convRegion.fftZIZLoop[i].time
718 + convRegion.ifftXYLoop[i].time;
719 }
720 } else {
721 for (int i = 0; i < threadCount; i++) {
722 convTime[i] = convRegion.fftXYLoop[i].time
723 + convRegion.transposeLoop[i].time
724 + convRegion.fftZIZLoop[i].time
725 + convRegion.unTransposeLoop[i].time
726 + convRegion.ifftXYLoop[i].time;
727 }
728 }
729 return convTime;
730 }
731
732 public String timingString() {
733 StringBuilder sb = new StringBuilder();
734 if (localZTranspose) {
735 double xysum = 0.0;
736 double zizsum = 0.0;
737 double ixysum = 0.0;
738 for (int i = 0; i < threadCount; i++) {
739 double fftxy = fftXYLoop[i].getTime() * 1e-9;
740 double ziz = fftZIZLoop[i].getTime() * 1e-9;
741 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
742 String s = format(" Thread %3d: FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n", i, fftxy, ziz, ifftxy);
743 sb.append(s);
744 xysum += fftxy;
745 zizsum += ziz;
746 ixysum += ifftxy;
747 }
748 String s = format(" Sum : FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n", xysum, zizsum, ixysum);
749 sb.append(s);
750 } else {
751 double xysum = 0.0;
752 double transsum = 0.0;
753 double zizsum = 0.0;
754 double untranssum = 0.0;
755 double ixysum = 0.0;
756 for (int i = 0; i < threadCount; i++) {
757 double fftxy = fftXYLoop[i].getTime() * 1e-9;
758 double trans = transposeLoop[i].getTime() * 1e-9;
759 double ziz = fftZIZLoop[i].getTime() * 1e-9;
760 double untrans = unTransposeLoop[i].getTime() * 1e-9;
761 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
762 String s = format(" Thread %3d: FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
763 i, fftxy, trans, ziz, untrans, ifftxy);
764 sb.append(s);
765 xysum += fftxy;
766 transsum += trans;
767 zizsum += ziz;
768 untranssum += untrans;
769 ixysum += ifftxy;
770 }
771 String s = format(" Sum : FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
772 xysum, transsum, zizsum, untranssum, ixysum);
773 sb.append(s);
774 }
775
776 return sb.toString();
777 }
778
779 @Override
780 public void run() {
781 int threadIndex = getThreadIndex();
782 try {
783 if (localZTranspose) {
784 execute(0, nZm1, fftXYLoop[threadIndex]);
785 execute(0, nXm1, fftZIZLoop[threadIndex]);
786 execute(0, nZm1, ifftXYLoop[threadIndex]);
787 } else {
788 execute(0, nZm1, fftXYLoop[threadIndex]);
789 execute(0, nXm1, transposeLoop[threadIndex]);
790 execute(0, nXm1, fftZIZLoop[threadIndex]);
791 execute(0, nZm1, unTransposeLoop[threadIndex]);
792 execute(0, nZm1, ifftXYLoop[threadIndex]);
793 }
794 } catch (Exception e) {
795 logger.severe(e.toString());
796 }
797 }
798 }
799
800 private class FFTXYLoop extends IntegerForLoop {
801
802 private Complex2D localFFTXY;
803 private long time;
804
805 @Override
806 public void run(final int lb, final int ub) {
807 for (int z = lb; z <= ub; z++) {
808 int offset = z * nextZ;
809 localFFTXY.fft(input, offset);
810 }
811 }
812
813 @Override
814 public IntegerSchedule schedule() {
815 return schedule;
816 }
817
818 public long getTime() {
819 return time;
820 }
821
822 public void finish() {
823 time += System.nanoTime();
824 }
825
826 @Override
827 public void start() {
828 time -= System.nanoTime();
829 localFFTXY = fftXY[getThreadIndex()];
830 }
831 }
832
833 private class IFFTXYLoop extends IntegerForLoop {
834
835 private Complex2D localFFTXY;
836 private long time;
837
838 @Override
839 public void run(final int lb, final int ub) {
840 for (int z = lb; z <= ub; z++) {
841 int offset = z * nextZ;
842 localFFTXY.ifft(input, offset);
843 }
844 }
845
846 @Override
847 public IntegerSchedule schedule() {
848 return schedule;
849 }
850
851 public long getTime() {
852 return time;
853 }
854
855 public void finish() {
856 time += System.nanoTime();
857 }
858
859 @Override
860 public void start() {
861 time -= System.nanoTime();
862 localFFTXY = fftXY[getThreadIndex()];
863 }
864 }
865
866 private class FFTZLoop extends IntegerForLoop {
867
868 private Complex localFFTZ;
869 private long time;
870 private final double[] work;
871
872 private FFTZLoop() {
873 if (localZTranspose) {
874 work = new double[2 * nY * nZ];
875 } else {
876 work = null;
877 }
878 }
879
880 @Override
881 public void run(final int lb, final int ub) {
882 if (localZTranspose) {
883 for (int x = lb; x <= ub; x++) {
884 int inputOffset = x * nextX;
885 transpose(input, inputOffset, work);
886 localFFTZ.fft(work, 0, ii);
887 unTranspose(input, inputOffset, work);
888 }
889 } else {
890 for (int x = lb; x <= ub; x++) {
891 int offset = x * nY * nZ * ii;
892 localFFTZ.fft(work3D, offset, ii);
893 }
894 }
895 }
896
897 @Override
898 public IntegerSchedule schedule() {
899 return schedule;
900 }
901
902 public long getTime() {
903 return time;
904 }
905
906 @Override
907 public void finish() {
908 time += System.nanoTime();
909 }
910
911 @Override
912 public void start() {
913 time -= System.nanoTime();
914 int threadID = getThreadIndex();
915 localFFTZ = fftZ[threadID];
916 }
917 }
918
919 private class IFFTZLoop extends IntegerForLoop {
920
921 private Complex localFFTZ;
922 private long time;
923 private final double[] work;
924
925 private IFFTZLoop() {
926 if (localZTranspose) {
927 work = new double[2 * nY * nZ];
928 } else {
929 work = null;
930 }
931 }
932
933 @Override
934 public void run(final int lb, final int ub) {
935 if (localZTranspose) {
936 for (int x = lb; x <= ub; x++) {
937 int inputOffset = x * nextX;
938 transpose(input, inputOffset, work);
939 localFFTZ.ifft(work, 0, ii);
940 unTranspose(input, inputOffset, work);
941 }
942 } else {
943 for (int x = lb; x <= ub; x++) {
944 int offset = x * nY * nZ * ii;
945 localFFTZ.ifft(work3D, offset, ii);
946 }
947 }
948 }
949
950 @Override
951 public IntegerSchedule schedule() {
952 return schedule;
953 }
954
955 public long getTime() {
956 return time;
957 }
958
959 @Override
960 public void finish() {
961 time += System.nanoTime();
962 }
963
964 @Override
965 public void start() {
966 time -= System.nanoTime();
967 int threadID = getThreadIndex();
968 localFFTZ = fftZ[threadID];
969 }
970 }
971
972 private class FFTZIZLoop extends IntegerForLoop {
973
974 private Complex localFFTZ;
975 private long time;
976 private final double[] work;
977
978 private FFTZIZLoop() {
979 if (localZTranspose) {
980 work = new double[2 * nY * nZ];
981 } else {
982 work = null;
983 }
984 }
985
986 @Override
987 public void run(final int lb, final int ub) {
988 if (localZTranspose) {
989 for (int x = lb; x <= ub; x++) {
990 int inputOffset = x * nextX;
991 transpose(input, inputOffset, work);
992 localFFTZ.fft(work, 0, ii);
993 int recipOffset = x * nY * nZ;
994 recipConv(recipOffset, work, 0);
995 localFFTZ.ifft(work, 0, ii);
996 unTranspose(input, inputOffset, work);
997 }
998 } else {
999 for (int x = lb; x <= ub; x++) {
1000
1001 int offset = x * nY * nZ * ii;
1002 localFFTZ.fft(work3D, offset, ii);
1003 int recipOffset = x * nY * nZ;
1004 recipConv(recipOffset, work3D, offset);
1005 localFFTZ.ifft(work3D, offset, ii);
1006 }
1007 }
1008 }
1009
1010 @Override
1011 public IntegerSchedule schedule() {
1012 return schedule;
1013 }
1014
1015 public long getTime() {
1016 return time;
1017 }
1018
1019 public void finish() {
1020 time += System.nanoTime();
1021 }
1022
1023 @Override
1024 public void start() {
1025 time -= System.nanoTime();
1026 int threadID = getThreadIndex();
1027 localFFTZ = fftZ[threadID];
1028 }
1029 }
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039 private class TransposeLoop extends IntegerForLoop {
1040
1041 private long time;
1042
1043 @Override
1044 public void run(final int lb, final int ub) {
1045 if (internalImZ == 1) {
1046 for (int z = 0; z < nZ; z++) {
1047 for (int x = lb; x <= ub; x++) {
1048 int y = 0;
1049 int i = 0;
1050 int iZX = x * nextX + z * nextZ;
1051 int trZX = x * trNextX + z * trNextZ;
1052 for (; y < nY - 3; y += 4, i += 8) {
1053 int i1 = iZX + y * nextY;
1054 int i2 = i1 + nextY;
1055 int i3 = i2 + nextY;
1056 int i4 = i3 + nextY;
1057 int dest = trZX + y * trNextY;
1058 work3D[dest] = input[i1];
1059 work3D[dest + 1] = input[i1 + im];
1060 work3D[dest + 2] = input[i2];
1061 work3D[dest + 3] = input[i2 + im];
1062 work3D[dest + 4] = input[i3];
1063 work3D[dest + 5] = input[i3 + im];
1064 work3D[dest + 6] = input[i4];
1065 work3D[dest + 7] = input[i4 + im];
1066 }
1067 for (; y < nY; y++, i += 2) {
1068 int i1 = iZX + y * nextY;
1069 int dest = trZX + y * trNextY;
1070 work3D[dest] = input[i1];
1071 work3D[dest + 1] = input[i1 + im];
1072 }
1073 }
1074 }
1075 } else {
1076 for (int z = 0; z < nZ; z++) {
1077 for (int x = lb; x <= ub; x++) {
1078 int y = 0;
1079 int i = 0;
1080 int iZX = x * nextX + z * nextZ;
1081 int trZX = x * trNextX + z * trNextZ;
1082 for (; y < nY - 3; y += 4, i += 4) {
1083 int i1 = iZX + y * nextY;
1084 int i2 = i1 + nextY;
1085 int i3 = i2 + nextY;
1086 int i4 = i3 + nextY;
1087
1088 int destPos = trZX + y * trNextY;
1089 work3D[destPos] = input[i1];
1090 work3D[destPos + 1] = input[i2];
1091 work3D[destPos + 2] = input[i3];
1092 work3D[destPos + 3] = input[i4];
1093
1094 work3D[destPos + internalImZ] = input[i1 + im];
1095 work3D[destPos + 1 + internalImZ] = input[i2 + im];
1096 work3D[destPos + 2 + internalImZ] = input[i3 + im];
1097 work3D[destPos + 3 + internalImZ] = input[i4 + im];
1098 }
1099 for (; y < nY; y++, i++) {
1100 int i1 = iZX + y * nextY;
1101 int destPos = trZX + y * trNextY;
1102 work3D[destPos] = input[i1];
1103 work3D[destPos + internalImZ] = input[i1 + im];
1104 }
1105 }
1106 }
1107 }
1108 }
1109
1110
1111
1112
1113
1114
1115 @Override
1116 public IntegerSchedule schedule() {
1117 return IntegerSchedule.fixed();
1118 }
1119
1120 public long getTime() {
1121 return time;
1122 }
1123
1124 @Override
1125 public void finish() {
1126 time += System.nanoTime();
1127 }
1128
1129 @Override
1130 public void start() {
1131 time -= System.nanoTime();
1132 }
1133 }
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144 private class UnTransposeLoop extends IntegerForLoop {
1145
1146 private long time;
1147
1148 @Override
1149 public void run(final int lb, final int ub) {
1150
1151 for (int z = lb; z <= ub; z++) {
1152 int trZ = z * trNextZ;
1153 int iZ = z * nextZ;
1154 for (int x = 0; x < nX; x++) {
1155 int trZX = trZ + x * trNextX;
1156 int iZX = iZ + x * nextX;
1157 int y = 0;
1158 for (; y < nY - 3; y += 4) {
1159 int w1 = y * trNextY + trZX;
1160 int w2 = w1 + trNextY;
1161 int w3 = w2 + trNextY;
1162 int w4 = w3 + trNextY;
1163 int i1 = y * nextY + iZX;
1164 int i2 = i1 + nextY;
1165 int i3 = i2 + nextY;
1166 int i4 = i3 + nextY;
1167 input[i1] = work3D[w1];
1168 input[i1 + im] = work3D[w1 + internalImZ];
1169 input[i2] = work3D[w2];
1170 input[i2 + im] = work3D[w2 + internalImZ];
1171 input[i3] = work3D[w3];
1172 input[i3 + im] = work3D[w3 + internalImZ];
1173 input[i4] = work3D[w4];
1174 input[i4 + im] = work3D[w4 + internalImZ];
1175 }
1176 for (; y < nY; y++) {
1177 int workIndex = y * trNextY + trZX;
1178 int inputIndex = y * nextY + iZX;
1179 input[inputIndex] = work3D[workIndex];
1180 input[inputIndex + im] = work3D[workIndex + internalImZ];
1181 }
1182 }
1183 }
1184 }
1185
1186
1187
1188
1189
1190
1191 @Override
1192 public IntegerSchedule schedule() {
1193 return IntegerSchedule.fixed();
1194 }
1195
1196 public long getTime() {
1197 return time;
1198 }
1199
1200 @Override
1201 public void finish() {
1202 time += System.nanoTime();
1203 }
1204
1205 @Override
1206 public void start() {
1207 time -= System.nanoTime();
1208 }
1209 }
1210
1211
1212
1213
1214
1215
1216
1217
1218 private void recipConv(int recipOffset, double[] work, int workOffset) {
1219 if (useSIMD && internalImZ == 1) {
1220
1221 recipConvSIMD(recipOffset, work, workOffset);
1222
1223 } else {
1224 recipConvScalar(recipOffset, work, workOffset);
1225 }
1226 }
1227
1228
1229
1230
1231
1232
1233
1234
1235 private void recipConvScalar(int recipOffset, double[] work, int workOffset) {
1236 int index = workOffset;
1237 int rindex = recipOffset;
1238 for (int i = 0; i < nY * nZ; i++) {
1239 double r = recip[rindex++];
1240 work[index] *= r;
1241 work[index + internalImZ] *= r;
1242 index += ii;
1243 }
1244 }
1245
1246
1247
1248
1249
1250
1251
1252
1253 private void recipConvSIMD(int recipOffset, double[] work, int workOffset) {
1254
1255
1256 if (internalImZ != 1) {
1257 logger.severe(" Real and imaginary parts must be interleaved.");
1258 }
1259
1260
1261 int length = nY * nZ * 2;
1262
1263 int vectorSize2 = vectorSize * 2;
1264 int vectorizedLength = (length / vectorSize2) * vectorSize2;
1265
1266
1267 int i = 0;
1268 for (; i < vectorizedLength; i += vectorSize2) {
1269
1270 DoubleVector recipVector = DoubleVector.fromArray(species, recip, recipOffset + i / 2);
1271
1272
1273 DoubleVector complexVector = DoubleVector.fromArray(species, work, workOffset + i);
1274 DoubleVector firstHalf = recipVector.rearrange(expandFirstHalf);
1275 complexVector = complexVector.mul(firstHalf);
1276 complexVector.intoArray(work, workOffset + i);
1277
1278
1279 complexVector = DoubleVector.fromArray(species, work, workOffset + vectorSize + i);
1280 DoubleVector secondHalf = recipVector.rearrange(expandSecondHalf);
1281 complexVector = complexVector.mul(secondHalf);
1282 complexVector.intoArray(work, workOffset + vectorSize + i);
1283 }
1284
1285
1286 for (; i < length; i += 2) {
1287 double r = recip[recipOffset + i / 2];
1288 work[workOffset + i] *= r;
1289 work[workOffset + i + internalImZ] *= r;
1290 }
1291 }
1292
1293
1294
1295
1296
1297
1298
1299
1300 private void transpose(double[] input, int inputOffset, double[] output) {
1301
1302
1303 for (int z = 0; z < nZ; z++) {
1304 for (int y = 0; y < nY; y++) {
1305 double real = input[inputOffset + y * nextY + z * nextZ];
1306 double imag = input[inputOffset + y * nextY + z * nextZ + im];
1307 output[y * trNextY + z * trNextZ] = real;
1308 output[y * trNextY + z * trNextZ + internalImZ] = imag;
1309 }
1310 }
1311 }
1312
1313
1314
1315
1316
1317
1318
1319
1320 private void unTranspose(double[] input, int inputOffset, double[] output) {
1321
1322
1323 for (int z = 0; z < nZ; z++) {
1324 for (int y = 0; y < nY; y++) {
1325 double real = output[y * trNextY + z * trNextZ];
1326 double imag = output[y * trNextY + z * trNextZ + internalImZ];
1327 input[inputOffset + y * nextY + z * nextZ] = real;
1328 input[inputOffset + y * nextY + z * nextZ + im] = imag;
1329 }
1330 }
1331 }
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341 public static double[] initRandomData(int dim, ParallelTeam parallelTeam) {
1342 int n = dim * dim * dim;
1343 double[] data = new double[2 * n];
1344 try {
1345 parallelTeam.execute(
1346 new ParallelRegion() {
1347 @Override
1348 public void run() {
1349 try {
1350 execute(
1351 0,
1352 dim - 1,
1353 new IntegerForLoop() {
1354 @Override
1355 public void run(final int lb, final int ub) {
1356 Random randomNumberGenerator = new Random(1);
1357 int index = dim * dim * lb * 2;
1358 for (int i = lb; i <= ub; i++) {
1359 for (int j = 0; j < dim; j++) {
1360 for (int k = 0; k < dim; k++) {
1361 double randomNumber = randomNumberGenerator.nextDouble();
1362 data[index] = randomNumber;
1363 index += 2;
1364 }
1365 }
1366 }
1367 }
1368 });
1369 } catch (Exception e) {
1370 System.out.println(e.getMessage());
1371 System.exit(-1);
1372 }
1373 }
1374 });
1375 } catch (Exception e) {
1376 System.out.println(e.getMessage());
1377 System.exit(-1);
1378 }
1379 return data;
1380 }
1381
1382
1383
1384
1385
1386
1387
1388
1389 public static void main(String[] args) throws Exception {
1390 int dimNotFinal = 128;
1391 int nCPU = ParallelTeam.getDefaultThreadCount();
1392 int reps = 5;
1393 boolean blocked = false;
1394 try {
1395 dimNotFinal = Integer.parseInt(args[0]);
1396 if (dimNotFinal < 1) {
1397 dimNotFinal = 100;
1398 }
1399 nCPU = Integer.parseInt(args[1]);
1400 if (nCPU < 1) {
1401 nCPU = ParallelTeam.getDefaultThreadCount();
1402 }
1403 reps = Integer.parseInt(args[2]);
1404 if (reps < 1) {
1405 reps = 5;
1406 }
1407 blocked = Boolean.parseBoolean(args[3]);
1408 } catch (Exception e) {
1409
1410 }
1411 final int dim = dimNotFinal;
1412 System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
1413 + "The best timing out of %d repetitions will be used.%n",
1414 dim, nCPU, reps);
1415
1416 Complex3DParallel complex3D;
1417 Complex3DParallel complex3DParallel;
1418 ParallelTeam parallelTeam = new ParallelTeam(nCPU);
1419 ParallelTeam parallelTeam1 = new ParallelTeam(1);
1420 if (blocked) {
1421 complex3D = new Complex3DParallel(dim, dim, dim, parallelTeam1, DataLayout3D.BLOCKED_X);
1422 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.BLOCKED_X);
1423 } else {
1424 complex3D = new Complex3DParallel(dim, dim, dim, parallelTeam1, DataLayout3D.INTERLEAVED);
1425 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.INTERLEAVED);
1426 }
1427 final int dimCubed = dim * dim * dim;
1428 final double[] data = initRandomData(dim, parallelTeam);
1429 final double[] work = new double[dimCubed];
1430 Arrays.fill(work, 1.0);
1431
1432 double toSeconds = 0.000000001;
1433 long seqTime = Long.MAX_VALUE;
1434 long parTime = Long.MAX_VALUE;
1435 long seqTimeConv = Long.MAX_VALUE;
1436 long parTimeConv = Long.MAX_VALUE;
1437
1438 complex3D.setRecip(work);
1439 complex3DParallel.setRecip(work);
1440
1441
1442 System.out.println("Warm Up Sequential FFT");
1443 complex3D.fft(data);
1444 System.out.println("Warm Up Sequential IFFT");
1445 complex3D.ifft(data);
1446 System.out.println("Warm Up Sequential Convolution");
1447 complex3D.convolution(data);
1448
1449
1450 for (int i = 0; i < reps; i++) {
1451 System.out.printf(" Iteration %d%n", i + 1);
1452 long time = System.nanoTime();
1453 complex3D.fft(data);
1454 complex3D.ifft(data);
1455 time = (System.nanoTime() - time);
1456 System.out.printf(" Sequential FFT: %9.6f (sec)%n", toSeconds * time);
1457 if (time < seqTime) {
1458 seqTime = time;
1459 }
1460 time = System.nanoTime();
1461 complex3D.convolution(data);
1462 time = (System.nanoTime() - time);
1463 System.out.printf(" Sequential Conv: %9.6f (sec)%n", toSeconds * time);
1464 if (time < seqTimeConv) {
1465 seqTimeConv = time;
1466 }
1467 }
1468
1469
1470 System.out.println("Warm up Parallel FFT");
1471 complex3DParallel.fft(data);
1472 System.out.println("Warm up Parallel IFFT");
1473 complex3DParallel.ifft(data);
1474 System.out.println("Warm up Parallel Convolution");
1475 complex3DParallel.convolution(data);
1476 complex3DParallel.initTiming();
1477
1478 for (int i = 0; i < reps; i++) {
1479
1480 if (i == reps / 2) {
1481 complex3DParallel.initTiming();
1482 }
1483
1484 System.out.printf(" Iteration %d%n", i + 1);
1485 long time = System.nanoTime();
1486 complex3DParallel.fft(data);
1487 complex3DParallel.ifft(data);
1488 time = (System.nanoTime() - time);
1489 System.out.printf(" Parallel FFT: %9.6f (sec)%n", toSeconds * time);
1490 if (time < parTime) {
1491 parTime = time;
1492 }
1493
1494 time = System.nanoTime();
1495 complex3DParallel.convolution(data);
1496 time = (System.nanoTime() - time);
1497 System.out.printf(" Parallel Conv: %9.6f (sec)%n", toSeconds * time);
1498 if (time < parTimeConv) {
1499 parTimeConv = time;
1500 }
1501
1502 }
1503
1504 System.out.printf(" Best Sequential FFT Time: %9.6f (sec)%n", toSeconds * seqTime);
1505 System.out.printf(" Best Sequential Conv. Time: %9.6f (sec)%n", toSeconds * seqTimeConv);
1506 System.out.printf(" Best Parallel FFT Time: %9.6f (sec)%n", toSeconds * parTime);
1507 System.out.printf(" Best Parallel Conv. Time: %9.6f (sec)%n", toSeconds * parTimeConv);
1508 System.out.printf(" 3D FFT Speedup: %9.6f X%n", (double) seqTime / parTime);
1509 System.out.printf(" 3D Conv Speedup: %9.6f X%n", (double) seqTimeConv / parTimeConv);
1510
1511 System.out.printf(" Parallel Convolution Timings:\n" + complex3DParallel.timingString());
1512
1513 parallelTeam.shutdown();
1514 parallelTeam1.shutdown();
1515 }
1516 }