1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.IntegerSchedule;
42 import edu.rit.pj.ParallelRegion;
43 import edu.rit.pj.ParallelTeam;
44
45 import javax.annotation.Nullable;
46 import java.util.Arrays;
47 import java.util.Random;
48 import java.util.logging.Level;
49 import java.util.logging.Logger;
50
51 import jdk.incubator.vector.DoubleVector;
52 import jdk.incubator.vector.VectorShuffle;
53 import jdk.incubator.vector.VectorSpecies;
54
55 import static java.lang.String.format;
56 import static java.util.Objects.requireNonNullElseGet;
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74 public class Complex3DParallel {
75
76 private static final Logger logger = Logger.getLogger(Complex3DParallel.class.getName());
77
78
79
80 private final int nX;
81
82
83
84 private final int nY;
85
86
87
88 private final int nZ;
89
90
91
92 private final int im;
93
94
95
96 private final int ii;
97
98
99
100 private final int nextX;
101
102
103
104 private final int nextY;
105
106
107
108 private final int nextZ;
109
110
111
112 private final int trNextX;
113
114
115
116 private final int trNextY;
117
118
119
120 private final int trNextZ;
121
122
123
124 private final double[] recip;
125
126
127
128 private final int threadCount;
129
130
131
132 private final ParallelTeam parallelTeam;
133
134
135
136 private final Complex2D[] fftXY;
137
138
139
140
141 private final Complex[] fftZ;
142
143
144
145 private final int internalImZ;
146
147
148
149 private final IntegerSchedule schedule;
150
151
152
153 private final int nXm1;
154
155
156
157 private final int nYm1;
158
159
160
161 private final int nZm1;
162
163
164
165 private final FFTRegion fftRegion;
166
167
168
169 private final IFFTRegion ifftRegion;
170
171
172
173 private final ConvolutionRegion convRegion;
174
175
176
177
178 public double[] input;
179
180
181
182 private boolean useSIMD;
183 private final VectorSpecies<Double> species = DoubleVector.SPECIES_PREFERRED;
184 private final int vectorSize = species.length();
185 private final int[] shuffle = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
186 private final VectorShuffle<Double> expandFirstHalf = VectorShuffle.fromArray(species, shuffle, 0);
187 private final VectorShuffle<Double> expandSecondHalf = VectorShuffle.fromArray(species, shuffle, vectorSize);
188
189
190
191
192 private boolean packFFTs;
193 private final boolean localZTranspose;
194
195
196
197 public final double[] work3D;
198
199
200
201
202
203
204
205
206
207
/**
 * Create a parallel 3D transform that uses the interleaved data layout.
 *
 * @param nX number of grid points along X.
 * @param nY number of grid points along Y.
 * @param nZ number of grid points along Z.
 * @param parallelTeam team of threads used to execute the transform.
 */
public Complex3DParallel(
    int nX,
    int nY,
    int nZ,
    ParallelTeam parallelTeam) {
  this(nX, nY, nZ, parallelTeam, DataLayout3D.INTERLEAVED);
}
211
212
213
214
215
216
217
218
219
220
221
/**
 * Create a parallel 3D transform for the given data layout with no explicit schedule.
 *
 * @param nX number of grid points along X.
 * @param nY number of grid points along Y.
 * @param nZ number of grid points along Z.
 * @param parallelTeam team of threads used to execute the transform.
 * @param dataLayout layout of the complex data in the input array.
 */
public Complex3DParallel(
    int nX,
    int nY,
    int nZ,
    ParallelTeam parallelTeam,
    DataLayout3D dataLayout) {
  this(nX, nY, nZ, parallelTeam, null, dataLayout);
}
225
226
227
228
229
230
231
232
233
234
235
/**
 * Create a parallel 3D transform that uses the interleaved data layout and the given schedule.
 *
 * @param nX number of grid points along X.
 * @param nY number of grid points along Y.
 * @param nZ number of grid points along Z.
 * @param parallelTeam team of threads used to execute the transform.
 * @param integerSchedule schedule for the FFT loops; may be null.
 */
public Complex3DParallel(
    int nX,
    int nY,
    int nZ,
    ParallelTeam parallelTeam,
    @Nullable IntegerSchedule integerSchedule) {
  this(nX, nY, nZ, parallelTeam, integerSchedule, DataLayout3D.INTERLEAVED);
}
239
240
241
242
243
244
245
246
247
248
249
250
251 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam,
252 @Nullable IntegerSchedule integerSchedule, DataLayout3D dataLayout) {
253 this.nX = nX;
254 this.nY = nY;
255 this.nZ = nZ;
256 this.parallelTeam = parallelTeam;
257 recip = new double[nX * nY * nZ];
258
259 DataLayout1D dataLayout1D;
260 DataLayout2D dataLayout2D;
261 switch (dataLayout) {
262 default:
263 case INTERLEAVED:
264
265 im = 1;
266 ii = 2;
267 nextX = 2;
268 nextY = 2 * nX;
269 nextZ = 2 * nX * nY;
270
271 dataLayout1D = DataLayout1D.INTERLEAVED;
272 dataLayout2D = DataLayout2D.INTERLEAVED;
273
274 internalImZ = 1;
275 trNextY = 2;
276 trNextZ = 2 * nY;
277 trNextX = 2 * nY * nZ;
278 break;
279 case BLOCKED_X:
280
281 im = nX;
282 ii = 1;
283 nextX = 1;
284 nextY = 2 * nX;
285 nextZ = 2 * nX * nY;
286
287 dataLayout1D = DataLayout1D.BLOCKED;
288 dataLayout2D = DataLayout2D.BLOCKED_X;
289
290 internalImZ = nZ * nY;
291 trNextY = 1;
292 trNextZ = nY;
293 trNextX = 2 * nY * nZ;
294 break;
295 case BLOCKED_XY:
296
297 im = nX * nY;
298 ii = 1;
299 nextX = 1;
300 nextY = nX;
301 nextZ = 2 * nY * nX;
302
303 dataLayout1D = DataLayout1D.BLOCKED;
304 dataLayout2D = DataLayout2D.BLOCKED_XY;
305
306 internalImZ = nY * nZ;
307 trNextY = 1;
308 trNextZ = nY;
309 trNextX = 2 * nY * nZ;
310 break;
311 case BLOCKED_XYZ:
312
313 im = nX * nY * nZ;
314 ii = 1;
315 nextX = 1;
316 nextY = nX;
317 nextZ = nY * nX;
318
319 dataLayout1D = DataLayout1D.BLOCKED;
320 dataLayout2D = DataLayout2D.BLOCKED_XY;
321
322 internalImZ = nY * nZ;
323 trNextY = 1;
324 trNextZ = nY;
325 trNextX = 2 * nY * nZ;
326 break;
327 }
328
329 nXm1 = this.nX - 1;
330 nYm1 = this.nY - 1;
331 nZm1 = this.nZ - 1;
332 threadCount = parallelTeam.getThreadCount();
333 schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
334
335
336 useSIMD = false;
337 String simd = System.getProperty("fft.useSIMD", Boolean.toString(useSIMD));
338 try {
339 useSIMD = Boolean.parseBoolean(simd);
340 } catch (Exception e) {
341 useSIMD = false;
342 }
343
344
345 packFFTs = false;
346 String pack = System.getProperty("fft.packFFTs", Boolean.toString(packFFTs));
347 try {
348 packFFTs = Boolean.parseBoolean(pack);
349 } catch (Exception e) {
350 packFFTs = false;
351 }
352
353 String localTranspose = System.getProperty("fft.localZTranspose", "true");
354 boolean local;
355 try {
356 local = Boolean.parseBoolean(localTranspose);
357 } catch (Exception e) {
358 local = true;
359 }
360 localZTranspose = local;
361
362 fftXY = new Complex2D[threadCount];
363 for (int i = 0; i < threadCount; i++) {
364 fftXY[i] = new Complex2D(nX, nY, dataLayout2D, im);
365 fftXY[i].setPackFFTs(packFFTs);
366 fftXY[i].setUseSIMD(useSIMD);
367 }
368
369 fftZ = new Complex[threadCount];
370 for (int i = 0; i < threadCount; i++) {
371 fftZ[i] = new Complex(nZ, dataLayout1D, internalImZ, nY);
372 fftZ[i].setUseSIMD(useSIMD);
373 }
374
375 if (localZTranspose) {
376 work3D = null;
377 } else {
378 work3D = new double[2 * nX * nY * nZ];
379 }
380
381 fftRegion = new FFTRegion();
382 ifftRegion = new IFFTRegion();
383 convRegion = new ConvolutionRegion();
384 }
385
386 public String toString() {
387 return "Complex3DParallel {" +
388 "nX=" + nX +
389 ", nY=" + nY +
390 ", nZ=" + nZ +
391 ", im=" + im +
392 ", ii=" + ii +
393 ", nextX=" + nextX +
394 ", nextY=" + nextY +
395 ", nextZ=" + nextZ +
396 ", threadCount=" + threadCount +
397 ", parallelTeam=" + parallelTeam +
398 ", internalImZ=" + internalImZ +
399 ", schedule=" + schedule +
400 ", nXm1=" + nXm1 +
401 ", nYm1=" + nYm1 +
402 ", nZm1=" + nZm1 +
403 ", fftRegion=" + fftRegion +
404 ", ifftRegion=" + ifftRegion +
405 ", convRegion=" + convRegion +
406 ", useSIMD=" + useSIMD +
407 ", packFFTs=" + packFFTs +
408 '}';
409 }
410
411
412
413
414
415
416 public void setUseSIMD(boolean useSIMD) {
417 this.useSIMD = useSIMD;
418 for (int i = 0; i < threadCount; i++) {
419 fftXY[i].setUseSIMD(useSIMD);
420 fftZ[i].setUseSIMD(useSIMD);
421 }
422 }
423
424
425
426
427
428
429 public void setPackFFTs(boolean packFFTs) {
430 this.packFFTs = packFFTs;
431 for (int i = 0; i < threadCount; i++) {
432 fftXY[i].setPackFFTs(packFFTs);
433 }
434 }
435
436
437
438
439
440
441
442
443 public void convolution(final double[] input) {
444 this.input = input;
445 try {
446 parallelTeam.execute(convRegion);
447 } catch (Exception e) {
448 String message = "Fatal exception evaluating a convolution.\n";
449 logger.log(Level.SEVERE, message, e);
450 }
451 }
452
453
454
455
456
457
458
459 public void fft(final double[] input) {
460 this.input = input;
461 try {
462 parallelTeam.execute(fftRegion);
463 } catch (Exception e) {
464 String message = " Fatal exception evaluating the FFT.\n";
465 logger.log(Level.SEVERE, message, e);
466 }
467 }
468
469
470
471
472
473
474 public long[] getTiming() {
475 return convRegion.getTiming();
476 }
477
478
479
480
481
482
483
484 public void ifft(final double[] input) {
485 this.input = input;
486 try {
487 parallelTeam.execute(ifftRegion);
488 } catch (Exception e) {
489 String message = "Fatal exception evaluating the inverse FFT.\n";
490 logger.log(Level.SEVERE, message, e);
491 System.exit(-1);
492 }
493 }
494
495
496
497
498 public void initTiming() {
499 convRegion.initTiming();
500 }
501
502
503
504
505
506
507 public String timingString() {
508 return convRegion.timingString();
509 }
510
511
512
513
514
515
516 public void setRecip(double[] recip) {
517
518
519
520
521
522
523 int recipNextY = nX;
524 int recipNextZ = nY * nX;
525 int index = 0;
526 for (int x = 0; x < nX; x++) {
527 int dx = x;
528 for (int z = 0; z < nZ; z++) {
529 int dz = dx + z * recipNextZ;
530 for (int y = 0; y < nY; y++) {
531 int conv = y * recipNextY + dz;
532 this.recip[index] = recip[conv];
533 index++;
534 }
535 }
536 }
537 }
538
539
540
541
542
543
544
545
546
547
548
549
550 private class FFTRegion extends ParallelRegion {
551
552 private final FFTXYLoop[] fftXYLoop;
553 private final FFTZLoop[] fftZLoop;
554 private final TransposeLoop[] transposeLoop;
555 private final UnTransposeLoop[] unTransposeLoop;
556
557 private FFTRegion() {
558 fftXYLoop = new FFTXYLoop[threadCount];
559 fftZLoop = new FFTZLoop[threadCount];
560 for (int i = 0; i < threadCount; i++) {
561 fftXYLoop[i] = new FFTXYLoop();
562 fftZLoop[i] = new FFTZLoop();
563 }
564 if (!localZTranspose) {
565 transposeLoop = new TransposeLoop[threadCount];
566 unTransposeLoop = new UnTransposeLoop[threadCount];
567 for (int i = 0; i < threadCount; i++) {
568 transposeLoop[i] = new TransposeLoop();
569 unTransposeLoop[i] = new UnTransposeLoop();
570 }
571 } else {
572 transposeLoop = null;
573 unTransposeLoop = null;
574 }
575 }
576
577 @Override
578 public void run() {
579 int threadIndex = getThreadIndex();
580 try {
581 if (localZTranspose) {
582 execute(0, nZm1, fftXYLoop[threadIndex]);
583 execute(0, nXm1, fftZLoop[threadIndex]);
584 } else {
585 execute(0, nZm1, fftXYLoop[threadIndex]);
586 execute(0, nXm1, transposeLoop[threadIndex]);
587 execute(0, nXm1, fftZLoop[threadIndex]);
588 execute(0, nZm1, unTransposeLoop[threadIndex]);
589 }
590 } catch (Exception e) {
591 logger.severe(e.toString());
592 }
593 }
594 }
595
596
597
598
599
600
601
602
603
604
605
606
607 private class IFFTRegion extends ParallelRegion {
608
609 private final IFFTXYLoop[] ifftXYLoop;
610 private final IFFTZLoop[] ifftZLoop;
611 private final TransposeLoop[] transposeLoop;
612 private final UnTransposeLoop[] unTransposeLoop;
613
614 private IFFTRegion() {
615 ifftXYLoop = new IFFTXYLoop[threadCount];
616 ifftZLoop = new IFFTZLoop[threadCount];
617 for (int i = 0; i < threadCount; i++) {
618 ifftXYLoop[i] = new IFFTXYLoop();
619 ifftZLoop[i] = new IFFTZLoop();
620 }
621 if (!localZTranspose) {
622 transposeLoop = new TransposeLoop[threadCount];
623 unTransposeLoop = new UnTransposeLoop[threadCount];
624 for (int i = 0; i < threadCount; i++) {
625 transposeLoop[i] = new TransposeLoop();
626 unTransposeLoop[i] = new UnTransposeLoop();
627 }
628 } else {
629 transposeLoop = null;
630 unTransposeLoop = null;
631 }
632 }
633
634 @Override
635 public void run() {
636 int threadIndex = getThreadIndex();
637 try {
638 if (localZTranspose) {
639 execute(0, nXm1, ifftZLoop[threadIndex]);
640 execute(0, nZm1, ifftXYLoop[threadIndex]);
641 } else {
642 execute(0, nXm1, transposeLoop[threadIndex]);
643 execute(0, nXm1, ifftZLoop[threadIndex]);
644 execute(0, nZm1, unTransposeLoop[threadIndex]);
645 execute(0, nZm1, ifftXYLoop[threadIndex]);
646 }
647 } catch (Exception e) {
648 logger.severe(e.toString());
649 }
650 }
651 }
652
653
654
655
656
657
658
659
660
661
662
663
664
665 private class ConvolutionRegion extends ParallelRegion {
666
667 private final FFTXYLoop[] fftXYLoop;
668 private final TransposeLoop[] transposeLoop;
669 private final FFTZIZLoop[] fftZIZLoop;
670 private final UnTransposeLoop[] unTransposeLoop;
671 private final IFFTXYLoop[] ifftXYLoop;
672 private final long[] convTime;
673
674 private ConvolutionRegion() {
675 fftXYLoop = new FFTXYLoop[threadCount];
676 fftZIZLoop = new FFTZIZLoop[threadCount];
677 ifftXYLoop = new IFFTXYLoop[threadCount];
678 convTime = new long[threadCount];
679 for (int i = 0; i < threadCount; i++) {
680 fftXYLoop[i] = new FFTXYLoop();
681 fftZIZLoop[i] = new FFTZIZLoop();
682 ifftXYLoop[i] = new IFFTXYLoop();
683 }
684 if (!localZTranspose) {
685 transposeLoop = new TransposeLoop[threadCount];
686 unTransposeLoop = new UnTransposeLoop[threadCount];
687 for (int i = 0; i < threadCount; i++) {
688 transposeLoop[i] = new TransposeLoop();
689 unTransposeLoop[i] = new UnTransposeLoop();
690 }
691 } else {
692 transposeLoop = null;
693 unTransposeLoop = null;
694 }
695 }
696
697 public void initTiming() {
698 for (int i = 0; i < threadCount; i++) {
699 fftXYLoop[i].time = 0;
700 fftZIZLoop[i].time = 0;
701 ifftXYLoop[i].time = 0;
702 }
703 if (!localZTranspose) {
704 for (int i = 0; i < threadCount; i++) {
705 transposeLoop[i].time = 0;
706 unTransposeLoop[i].time = 0;
707 }
708 }
709 }
710
711 public long[] getTiming() {
712 if (localZTranspose) {
713 for (int i = 0; i < threadCount; i++) {
714 convTime[i] = convRegion.fftXYLoop[i].time
715 + convRegion.fftZIZLoop[i].time
716 + convRegion.ifftXYLoop[i].time;
717 }
718 } else {
719 for (int i = 0; i < threadCount; i++) {
720 convTime[i] = convRegion.fftXYLoop[i].time
721 + convRegion.transposeLoop[i].time
722 + convRegion.fftZIZLoop[i].time
723 + convRegion.unTransposeLoop[i].time
724 + convRegion.ifftXYLoop[i].time;
725 }
726 }
727 return convTime;
728 }
729
730 public String timingString() {
731 StringBuilder sb = new StringBuilder();
732 if (localZTranspose) {
733 double xysum = 0.0;
734 double zizsum = 0.0;
735 double ixysum = 0.0;
736 for (int i = 0; i < threadCount; i++) {
737 double fftxy = fftXYLoop[i].getTime() * 1e-9;
738 double ziz = fftZIZLoop[i].getTime() * 1e-9;
739 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
740 String s = format(" Thread %3d: FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
741 i, fftxy, ziz, ifftxy);
742 sb.append(s);
743 xysum += fftxy;
744 zizsum += ziz;
745 ixysum += ifftxy;
746 }
747 String s = format(" Sum : FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
748 xysum, zizsum, ixysum);
749 sb.append(s);
750 } else {
751 double xysum = 0.0;
752 double transsum = 0.0;
753 double zizsum = 0.0;
754 double untranssum = 0.0;
755 double ixysum = 0.0;
756 for (int i = 0; i < threadCount; i++) {
757 double fftxy = fftXYLoop[i].getTime() * 1e-9;
758 double trans = transposeLoop[i].getTime() * 1e-9;
759 double ziz = fftZIZLoop[i].getTime() * 1e-9;
760 double untrans = unTransposeLoop[i].getTime() * 1e-9;
761 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
762 String s = format(" Thread %3d: FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
763 i, fftxy, trans, ziz, untrans, ifftxy);
764 sb.append(s);
765 xysum += fftxy;
766 transsum += trans;
767 zizsum += ziz;
768 untranssum += untrans;
769 ixysum += ifftxy;
770 }
771 String s = format(" Sum : FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
772 xysum, transsum, zizsum, untranssum, ixysum);
773 sb.append(s);
774 }
775
776 return sb.toString();
777 }
778
779 @Override
780 public void run() {
781 int threadIndex = getThreadIndex();
782 try {
783 if (localZTranspose) {
784 execute(0, nZm1, fftXYLoop[threadIndex]);
785 execute(0, nXm1, fftZIZLoop[threadIndex]);
786 execute(0, nZm1, ifftXYLoop[threadIndex]);
787 } else {
788 execute(0, nZm1, fftXYLoop[threadIndex]);
789 execute(0, nXm1, transposeLoop[threadIndex]);
790 execute(0, nXm1, fftZIZLoop[threadIndex]);
791 execute(0, nZm1, unTransposeLoop[threadIndex]);
792 execute(0, nZm1, ifftXYLoop[threadIndex]);
793 }
794 } catch (Exception e) {
795 logger.severe(e.toString());
796 }
797 }
798 }
799
800 private class FFTXYLoop extends IntegerForLoop {
801
802 private Complex2D localFFTXY;
803 private long time;
804
805 @Override
806 public void run(final int lb, final int ub) {
807 for (int z = lb; z <= ub; z++) {
808 int offset = z * nextZ;
809 localFFTXY.fft(input, offset);
810 }
811 }
812
813 @Override
814 public IntegerSchedule schedule() {
815 return schedule;
816 }
817
818 public long getTime() {
819 return time;
820 }
821
822 public void finish() {
823 time += System.nanoTime();
824 }
825
826 @Override
827 public void start() {
828 time -= System.nanoTime();
829 localFFTXY = fftXY[getThreadIndex()];
830 }
831 }
832
833 private class IFFTXYLoop extends IntegerForLoop {
834
835 private Complex2D localFFTXY;
836 private long time;
837
838 @Override
839 public void run(final int lb, final int ub) {
840 for (int z = lb; z <= ub; z++) {
841 int offset = z * nextZ;
842 localFFTXY.ifft(input, offset);
843 }
844 }
845
846 @Override
847 public IntegerSchedule schedule() {
848 return schedule;
849 }
850
851 public long getTime() {
852 return time;
853 }
854
855 public void finish() {
856 time += System.nanoTime();
857 }
858
859 @Override
860 public void start() {
861 time -= System.nanoTime();
862 localFFTXY = fftXY[getThreadIndex()];
863 }
864 }
865
866 private class FFTZLoop extends IntegerForLoop {
867
868 private Complex localFFTZ;
869 private long time;
870 private final double[] work;
871
872 private FFTZLoop() {
873 if (localZTranspose) {
874 work = new double[2 * nY * nZ];
875 } else {
876 work = null;
877 }
878 }
879
880 @Override
881 public void run(final int lb, final int ub) {
882 if (localZTranspose) {
883 for (int x = lb; x <= ub; x++) {
884 int inputOffset = x * nextX;
885 transpose(input, inputOffset, work, 0);
886 localFFTZ.fft(work, 0, ii);
887 unTranspose(input, inputOffset, work, 0);
888 }
889 } else {
890 for (int x = lb; x <= ub; x++) {
891 int offset = x * nY * nZ * ii;
892 localFFTZ.fft(work3D, offset, ii);
893 }
894 }
895 }
896
897 @Override
898 public IntegerSchedule schedule() {
899 return schedule;
900 }
901
902 public long getTime() {
903 return time;
904 }
905
906 @Override
907 public void finish() {
908 time += System.nanoTime();
909 }
910
911 @Override
912 public void start() {
913 time -= System.nanoTime();
914 int threadID = getThreadIndex();
915 localFFTZ = fftZ[threadID];
916 }
917 }
918
919 private class IFFTZLoop extends IntegerForLoop {
920
921 private Complex localFFTZ;
922 private long time;
923 private final double[] work;
924
925 private IFFTZLoop() {
926 if (localZTranspose) {
927 work = new double[2 * nY * nZ];
928 } else {
929 work = null;
930 }
931 }
932
933 @Override
934 public void run(final int lb, final int ub) {
935 if (localZTranspose) {
936 for (int x = lb; x <= ub; x++) {
937 int inputOffset = x * nextX;
938 transpose(input, inputOffset, work, 0);
939 localFFTZ.ifft(work, 0, ii);
940 unTranspose(input, inputOffset, work, 0);
941 }
942 } else {
943 for (int x = lb; x <= ub; x++) {
944 int offset = x * nY * nZ * ii;
945 localFFTZ.ifft(work3D, offset, ii);
946 }
947 }
948 }
949
950 @Override
951 public IntegerSchedule schedule() {
952 return schedule;
953 }
954
955 public long getTime() {
956 return time;
957 }
958
959 @Override
960 public void finish() {
961 time += System.nanoTime();
962 }
963
964 @Override
965 public void start() {
966 time -= System.nanoTime();
967 int threadID = getThreadIndex();
968 localFFTZ = fftZ[threadID];
969 }
970 }
971
972 private class FFTZIZLoop extends IntegerForLoop {
973
974 private Complex localFFTZ;
975 private long time;
976 private final double[] work;
977
978 private FFTZIZLoop() {
979 if (localZTranspose) {
980 work = new double[2 * nY * nZ];
981 } else {
982 work = null;
983 }
984 }
985
986 @Override
987 public void run(final int lb, final int ub) {
988 if (localZTranspose) {
989 for (int x = lb; x <= ub; x++) {
990 int inputOffset = x * nextX;
991 transpose(input, inputOffset, work, 0);
992 localFFTZ.fft(work, 0, ii);
993 int recipOffset = x * nY * nZ;
994 recipConv(recipOffset, work, 0);
995 localFFTZ.ifft(work, 0, ii);
996 unTranspose(input, inputOffset, work, 0);
997 }
998 } else {
999 for (int x = lb; x <= ub; x++) {
1000
1001 int offset = x * nY * nZ * ii;
1002 localFFTZ.fft(work3D, offset, ii);
1003 int recipOffset = x * nY * nZ;
1004 recipConv(recipOffset, work3D, offset);
1005 localFFTZ.ifft(work3D, offset, ii);
1006 }
1007 }
1008 }
1009
1010 @Override
1011 public IntegerSchedule schedule() {
1012 return schedule;
1013 }
1014
1015 public long getTime() {
1016 return time;
1017 }
1018
1019 public void finish() {
1020 time += System.nanoTime();
1021 }
1022
1023 @Override
1024 public void start() {
1025 time -= System.nanoTime();
1026 int threadID = getThreadIndex();
1027 localFFTZ = fftZ[threadID];
1028 }
1029 }
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039 private class TransposeLoop extends IntegerForLoop {
1040
1041 private long time;
1042
1043 @Override
1044 public void run(final int lb, final int ub) {
1045 for (int x = lb; x <= ub; x++) {
1046 for (int z = 0; z < nZ; z++) {
1047 int inputOffset = x * nextX + z * nextZ;
1048 int workOffset = x * trNextX + z * trNextZ;
1049 for (int y = 0; y < nY; y++) {
1050 int inputIndex = inputOffset + y * nextY;
1051 double real = input[inputIndex];
1052 double imag = input[inputIndex + im];
1053 int workIndex = workOffset + y * trNextY;
1054 work3D[workIndex] = real;
1055 work3D[workIndex + internalImZ] = imag;
1056 }
1057 }
1058 }
1059 }
1060
1061
1062
1063
1064
1065
1066 @Override
1067 public IntegerSchedule schedule() {
1068 return IntegerSchedule.fixed();
1069 }
1070
1071 public long getTime() {
1072 return time;
1073 }
1074
1075 @Override
1076 public void finish() {
1077 time += System.nanoTime();
1078 }
1079
1080 @Override
1081 public void start() {
1082 time -= System.nanoTime();
1083 }
1084 }
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095 private class UnTransposeLoop extends IntegerForLoop {
1096
1097 private long time;
1098
1099 @Override
1100 public void run(final int lb, final int ub) {
1101 for (int x = 0; x < nX; x++) {
1102 for (int y = 0; y < nY; y++) {
1103 int workOffset = x * trNextX + y * trNextY;
1104 int inputOffset = x * nextX + y * nextY;
1105 for (int z = lb; z <= ub; z++) {
1106 int workIndex = workOffset + z * trNextZ;
1107 double real = work3D[workIndex];
1108 double imag = work3D[workIndex + internalImZ];
1109 int inputIndex = inputOffset + z * nextZ;
1110 input[inputIndex] = real;
1111 input[inputIndex + im] = imag;
1112 }
1113 }
1114 }
1115 }
1116
1117
1118
1119
1120
1121
1122 @Override
1123 public IntegerSchedule schedule() {
1124 return IntegerSchedule.fixed();
1125 }
1126
1127 public long getTime() {
1128 return time;
1129 }
1130
1131 @Override
1132 public void finish() {
1133 time += System.nanoTime();
1134 }
1135
1136 @Override
1137 public void start() {
1138 time -= System.nanoTime();
1139 }
1140 }
1141
1142
1143
1144
1145
1146
1147
1148
1149 private void recipConv(int recipOffset, double[] work, int workOffset) {
1150 if (useSIMD && internalImZ == 1) {
1151
1152 recipConvSIMD(recipOffset, work, workOffset);
1153
1154 } else {
1155 recipConvScalar(recipOffset, work, workOffset);
1156 }
1157 }
1158
1159
1160
1161
1162
1163
1164
1165
1166 private void recipConvScalar(int recipOffset, double[] work, int workOffset) {
1167 int index = workOffset;
1168 int rindex = recipOffset;
1169 for (int i = 0; i < nY * nZ; i++) {
1170 double r = recip[rindex++];
1171 work[index] *= r;
1172 work[index + internalImZ] *= r;
1173 index += ii;
1174 }
1175 }
1176
1177
1178
1179
1180
1181
1182
1183
1184 private void recipConvSIMD(int recipOffset, double[] work, int workOffset) {
1185
1186
1187 if (internalImZ != 1) {
1188 logger.severe(" Real and imaginary parts must be interleaved.");
1189 }
1190
1191
1192 int length = nY * nZ * 2;
1193
1194 int vectorSize2 = vectorSize * 2;
1195 int vectorizedLength = (length / vectorSize2) * vectorSize2;
1196
1197
1198 int i = 0;
1199 for (; i < vectorizedLength; i += vectorSize2) {
1200
1201 DoubleVector recipVector = DoubleVector.fromArray(species, recip, recipOffset + i / 2);
1202
1203
1204 DoubleVector complexVector = DoubleVector.fromArray(species, work, workOffset + i);
1205 DoubleVector firstHalf = recipVector.rearrange(expandFirstHalf);
1206 complexVector = complexVector.mul(firstHalf);
1207 complexVector.intoArray(work, workOffset + i);
1208
1209
1210 complexVector = DoubleVector.fromArray(species, work, workOffset + vectorSize + i);
1211 DoubleVector secondHalf = recipVector.rearrange(expandSecondHalf);
1212 complexVector = complexVector.mul(secondHalf);
1213 complexVector.intoArray(work, workOffset + vectorSize + i);
1214 }
1215
1216
1217 for (; i < length; i+=2) {
1218 double r = recip[recipOffset + i / 2];
1219 work[workOffset + i] *= r;
1220 work[workOffset + i + internalImZ] *= r;
1221 }
1222 }
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232 private void transpose(double[] input, int inputOffset, double[] output, int outputOffset) {
1233
1234
1235 for (int z = 0; z < nZ; z++) {
1236 for (int y = 0; y < nY; y++) {
1237 double real = input[inputOffset + y * nextY + z * nextZ];
1238 double imag = input[inputOffset + y * nextY + z * nextZ + im];
1239 output[outputOffset + y * trNextY + z * trNextZ] = real;
1240 output[outputOffset + y * trNextY + z * trNextZ + internalImZ] = imag;
1241 }
1242 }
1243 }
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253 private void unTranspose(double[] input, int inputOffset, double[] output, int outputOffset) {
1254
1255
1256 for (int z = 0; z < nZ; z++) {
1257 for (int y = 0; y < nY; y++) {
1258 double real = output[outputOffset + y * trNextY + z * trNextZ];
1259 double imag = output[outputOffset + y * trNextY + z * trNextZ + internalImZ];
1260 input[inputOffset + y * nextY + z * nextZ] = real;
1261 input[inputOffset + y * nextY + z * nextZ + im] = imag;
1262 }
1263 }
1264 }
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274 public static double[] initRandomData(int dim, ParallelTeam parallelTeam) {
1275 int n = dim * dim * dim;
1276 double[] data = new double[2 * n];
1277 try {
1278 parallelTeam.execute(
1279 new ParallelRegion() {
1280 @Override
1281 public void run() {
1282 try {
1283 execute(
1284 0,
1285 dim - 1,
1286 new IntegerForLoop() {
1287 @Override
1288 public void run(final int lb, final int ub) {
1289 Random randomNumberGenerator = new Random(1);
1290 int index = dim * dim * lb * 2;
1291 for (int i = lb; i <= ub; i++) {
1292 for (int j = 0; j < dim; j++) {
1293 for (int k = 0; k < dim; k++) {
1294 double randomNumber = randomNumberGenerator.nextDouble();
1295 data[index] = randomNumber;
1296 index += 2;
1297 }
1298 }
1299 }
1300 }
1301 });
1302 } catch (Exception e) {
1303 System.out.println(e.getMessage());
1304 System.exit(-1);
1305 }
1306 }
1307 });
1308 } catch (Exception e) {
1309 System.out.println(e.getMessage());
1310 System.exit(-1);
1311 }
1312 return data;
1313 }
1314
1315
1316
1317
1318
1319
1320
1321
1322 public static void main(String[] args) throws Exception {
1323 int dimNotFinal = 128;
1324 int nCPU = ParallelTeam.getDefaultThreadCount();
1325 int reps = 5;
1326 boolean blocked = false;
1327 try {
1328 dimNotFinal = Integer.parseInt(args[0]);
1329 if (dimNotFinal < 1) {
1330 dimNotFinal = 100;
1331 }
1332 nCPU = Integer.parseInt(args[1]);
1333 if (nCPU < 1) {
1334 nCPU = ParallelTeam.getDefaultThreadCount();
1335 }
1336 reps = Integer.parseInt(args[2]);
1337 if (reps < 1) {
1338 reps = 5;
1339 }
1340 blocked = Boolean.parseBoolean(args[3]);
1341 } catch (Exception e) {
1342
1343 }
1344 final int dim = dimNotFinal;
1345 System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
1346 + "The best timing out of %d repetitions will be used.%n",
1347 dim, nCPU, reps);
1348
1349 Complex3D complex3D;
1350 Complex3DParallel complex3DParallel;
1351 ParallelTeam parallelTeam = new ParallelTeam(nCPU);
1352 if (blocked) {
1353 complex3D = new Complex3D(dim, dim, dim, DataLayout3D.BLOCKED_X);
1354 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.BLOCKED_X);
1355 } else {
1356 complex3D = new Complex3D(dim, dim, dim, DataLayout3D.INTERLEAVED);
1357 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.INTERLEAVED);
1358 }
1359 final int dimCubed = dim * dim * dim;
1360 final double[] data = initRandomData(dim, parallelTeam);
1361 final double[] work = new double[dimCubed];
1362 Arrays.fill(work, 1.0);
1363
1364 double toSeconds = 0.000000001;
1365 long seqTime = Long.MAX_VALUE;
1366 long parTime = Long.MAX_VALUE;
1367 long seqTimeConv = Long.MAX_VALUE;
1368 long parTimeConv = Long.MAX_VALUE;
1369
1370 complex3D.setRecip(work);
1371 complex3DParallel.setRecip(work);
1372
1373
1374 System.out.println("Warm Up Sequential FFT");
1375 complex3D.fft(data);
1376 System.out.println("Warm Up Sequential IFFT");
1377 complex3D.ifft(data);
1378 System.out.println("Warm Up Sequential Convolution");
1379 complex3D.convolution(data);
1380
1381
1382
1383 for (int i = 0; i < reps; i++) {
1384 System.out.printf(" Iteration %d%n", i + 1);
1385 long time = System.nanoTime();
1386 complex3D.fft(data);
1387 complex3D.ifft(data);
1388 time = (System.nanoTime() - time);
1389 System.out.printf(" Sequential FFT: %9.6f (sec)%n", toSeconds * time);
1390 if (time < seqTime) {
1391 seqTime = time;
1392 }
1393 time = System.nanoTime();
1394 complex3D.convolution(data);
1395 time = (System.nanoTime() - time);
1396 System.out.printf(" Sequential Conv: %9.6f (sec)%n", toSeconds * time);
1397 if (time < seqTimeConv) {
1398 seqTimeConv = time;
1399 }
1400 }
1401
1402
1403 System.out.println("Warm up Parallel FFT");
1404 complex3DParallel.fft(data);
1405 System.out.println("Warm up Parallel IFFT");
1406 complex3DParallel.ifft(data);
1407 System.out.println("Warm up Parallel Convolution");
1408 complex3DParallel.convolution(data);
1409 complex3DParallel.initTiming();
1410
1411 for (int i = 0; i < reps; i++) {
1412
1413 if (i == reps / 2) {
1414 complex3DParallel.initTiming();
1415 }
1416
1417 System.out.printf(" Iteration %d%n", i + 1);
1418 long time = System.nanoTime();
1419 complex3DParallel.fft(data);
1420 complex3DParallel.ifft(data);
1421 time = (System.nanoTime() - time);
1422 System.out.printf(" Parallel FFT: %9.6f (sec)%n", toSeconds * time);
1423 if (time < parTime) {
1424 parTime = time;
1425 }
1426
1427 time = System.nanoTime();
1428 complex3DParallel.convolution(data);
1429 time = (System.nanoTime() - time);
1430 System.out.printf(" Parallel Conv: %9.6f (sec)%n", toSeconds * time);
1431 if (time < parTimeConv) {
1432 parTimeConv = time;
1433 }
1434
1435 }
1436
1437 System.out.printf(" Best Sequential FFT Time: %9.6f (sec)%n", toSeconds * seqTime);
1438 System.out.printf(" Best Sequential Conv. Time: %9.6f (sec)%n", toSeconds * seqTimeConv);
1439 System.out.printf(" Best Parallel FFT Time: %9.6f (sec)%n", toSeconds * parTime);
1440 System.out.printf(" Best Parallel Conv. Time: %9.6f (sec)%n", toSeconds * parTimeConv);
1441 System.out.printf(" 3D FFT Speedup: %9.6f X%n", (double) seqTime / parTime);
1442 System.out.printf(" 3D Conv Speedup: %9.6f X%n", (double) seqTimeConv / parTimeConv);
1443
1444 System.out.printf(" Parallel Convolution Timings:\n" + complex3DParallel.timingString());
1445
1446 parallelTeam.shutdown();
1447 }
1448 }