1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.IntegerSchedule;
42 import edu.rit.pj.ParallelRegion;
43 import edu.rit.pj.ParallelTeam;
44
45 import javax.annotation.Nullable;
46 import java.util.Arrays;
47 import java.util.Random;
48 import java.util.logging.Level;
49 import java.util.logging.Logger;
50
51 import static java.lang.String.format;
52 import static java.util.Objects.requireNonNullElseGet;
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70 public class Complex3DParallel {
71
/** Logger for this class. */
private static final Logger logger = Logger.getLogger(Complex3DParallel.class.getName());
/** Grid size along X, as passed to the constructor. */
private final int nX;
/** Grid size along Y, as passed to the constructor. */
private final int nY;
/** Grid size along Z, as passed to the constructor. */
private final int nZ;
/**
 * Offset added to the index of a real element of {@link #input} to reach its imaginary part
 * (1 for INTERLEAVED; nX, nX*nY or nX*nY*nZ for the blocked layouts).
 */
private final int im;
/**
 * Stride handed to the 1D Z transforms and used to step between consecutive real elements of a
 * transposed slab (2 for INTERLEAVED, 1 for the blocked layouts).
 */
private final int ii;
/** Index increment in {@link #input} for a step of one along X. */
private final int nextX;
/** Index increment in {@link #input} for a step of one along Y. */
private final int nextY;
/** Index increment in {@link #input} for a step of one along Z. */
private final int nextZ;
/** Index increment in the transposed work layout for a step of one along X. */
private final int trNextX;
/** Index increment in the transposed work layout for a step of one along Y. */
private final int trNextY;
/** Index increment in the transposed work layout for a step of one along Z. */
private final int trNextZ;
/**
 * Multipliers applied during a convolution; length nX*nY*nZ, stored in the reordered
 * (x slowest, then z, then y) order written by {@link #setRecip(double[])}.
 */
private final double[] recip;
/** Number of threads reported by {@link #parallelTeam}. */
private final int threadCount;
/** Team of threads that executes the FFT, inverse FFT and convolution regions. */
private final ParallelTeam parallelTeam;
/** One 2D transform per thread, used for the XY planes. */
private final Complex2D[] fftXY;
/**
 * One 1D transform per thread, used along Z.
 */
private final Complex[] fftZ;
/** Offset from a real element to its imaginary part inside a transposed work array. */
private final int internalImZ;
/** Schedule returned by the FFT loops; {@code IntegerSchedule.fixed()} when none was supplied. */
private final IntegerSchedule schedule;
/** nX - 1, the inclusive upper bound for loops over X. */
private final int nXm1;
/** nY - 1. */
private final int nYm1;
/** nZ - 1, the inclusive upper bound for loops over Z. */
private final int nZm1;
/** Parallel region executed by {@link #fft(double[])}. */
private final FFTRegion fftRegion;
/** Parallel region executed by {@link #ifft(double[])}. */
private final IFFTRegion ifftRegion;
/** Parallel region executed by {@link #convolution(double[])}; also owns the timing data. */
private final ConvolutionRegion convRegion;
/**
 * Array operated on in place by the worker loops; assigned on every call to fft, ifft or
 * convolution.
 */
public double[] input;
/** Flag forwarded to every per-thread transform; initialised from system property fft.useSIMD. */
private boolean useSIMD;
/** Flag forwarded to every per-thread 2D transform; initialised from system property fft.packFFTs. */
private boolean packFFTs;
/**
 * When true each Z loop transposes one X slab at a time into its own work array; when false the
 * whole grid is transposed into {@link #work3D} by separate loops. Read from system property
 * fft.localZTranspose (default true).
 */
private final boolean localZTranspose;
/** Whole-grid transposed work array of 2*nX*nY*nZ doubles; null when localZTranspose is true. */
public final double[] work3D;
188
189
190
191
192
193
194
195
196
197
/**
 * Creates a parallel 3D complex FFT that uses the INTERLEAVED data layout and the default loop
 * schedule.
 *
 * @param nX number of grid points along X.
 * @param nY number of grid points along Y.
 * @param nZ number of grid points along Z.
 * @param parallelTeam team of threads that executes the transforms.
 */
public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
  this(nX, nY, nZ, parallelTeam, DataLayout3D.INTERLEAVED);
}
201
202
203
204
205
206
207
208
209
210
211
/**
 * Creates a parallel 3D complex FFT with the given data layout and the default loop schedule
 * (a null schedule is forwarded to the full constructor).
 *
 * @param nX number of grid points along X.
 * @param nY number of grid points along Y.
 * @param nZ number of grid points along Z.
 * @param parallelTeam team of threads that executes the transforms.
 * @param dataLayout layout of the complex data in the input array.
 */
public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, DataLayout3D dataLayout) {
  this(nX, nY, nZ, parallelTeam, null, dataLayout);
}
215
216
217
218
219
220
221
222
223
224
225
/**
 * Creates a parallel 3D complex FFT that uses the INTERLEAVED data layout and the given loop
 * schedule.
 *
 * @param nX number of grid points along X.
 * @param nY number of grid points along Y.
 * @param nZ number of grid points along Z.
 * @param parallelTeam team of threads that executes the transforms.
 * @param integerSchedule schedule for the FFT loops; may be null to use the default.
 */
public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
  this(nX, nY, nZ, parallelTeam, integerSchedule, DataLayout3D.INTERLEAVED);
}
229
230
231
232
233
234
235
236
237
238
239
240
241 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam,
242 @Nullable IntegerSchedule integerSchedule, DataLayout3D dataLayout) {
243 this.nX = nX;
244 this.nY = nY;
245 this.nZ = nZ;
246 this.parallelTeam = parallelTeam;
247 recip = new double[nX * nY * nZ];
248
249 DataLayout1D dataLayout1D;
250 DataLayout2D dataLayout2D;
251 switch (dataLayout) {
252 default:
253 case INTERLEAVED:
254
255 im = 1;
256 ii = 2;
257 nextX = 2;
258 nextY = 2 * nX;
259 nextZ = 2 * nX * nY;
260
261 dataLayout1D = DataLayout1D.INTERLEAVED;
262 dataLayout2D = DataLayout2D.INTERLEAVED;
263
264 internalImZ = 1;
265 trNextY = 2;
266 trNextZ = 2 * nY;
267 trNextX = 2 * nY * nZ;
268 break;
269 case BLOCKED_X:
270
271 im = nX;
272 ii = 1;
273 nextX = 1;
274 nextY = 2 * nX;
275 nextZ = 2 * nX * nY;
276
277 dataLayout1D = DataLayout1D.BLOCKED;
278 dataLayout2D = DataLayout2D.BLOCKED_X;
279
280 internalImZ = nZ * nY;
281 trNextY = 1;
282 trNextZ = nY;
283 trNextX = 2 * nY * nZ;
284 break;
285 case BLOCKED_XY:
286
287 im = nX * nY;
288 ii = 1;
289 nextX = 1;
290 nextY = nX;
291 nextZ = 2 * nY * nX;
292
293 dataLayout1D = DataLayout1D.BLOCKED;
294 dataLayout2D = DataLayout2D.BLOCKED_XY;
295
296 internalImZ = nY * nZ;
297 trNextY = 1;
298 trNextZ = nY;
299 trNextX = 2 * nY * nZ;
300 break;
301 case BLOCKED_XYZ:
302
303 im = nX * nY * nZ;
304 ii = 1;
305 nextX = 1;
306 nextY = nX;
307 nextZ = nY * nX;
308
309 dataLayout1D = DataLayout1D.BLOCKED;
310 dataLayout2D = DataLayout2D.BLOCKED_XY;
311
312 internalImZ = nY * nZ;
313 trNextY = 1;
314 trNextZ = nY;
315 trNextX = 2 * nY * nZ;
316 break;
317 }
318
319 nXm1 = this.nX - 1;
320 nYm1 = this.nY - 1;
321 nZm1 = this.nZ - 1;
322 threadCount = parallelTeam.getThreadCount();
323 schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
324
325
326 useSIMD = false;
327 String simd = System.getProperty("fft.useSIMD", Boolean.toString(useSIMD));
328 try {
329 useSIMD = Boolean.parseBoolean(simd);
330 } catch (Exception e) {
331 useSIMD = false;
332 }
333
334 packFFTs = false;
335 String pack = System.getProperty("fft.packFFTs", Boolean.toString(packFFTs));
336 try {
337 packFFTs = Boolean.parseBoolean(pack);
338 } catch (Exception e) {
339 packFFTs = false;
340 }
341
342 String localTranspose = System.getProperty("fft.localZTranspose", "true");
343 boolean local;
344 try {
345 local = Boolean.parseBoolean(localTranspose);
346 } catch (Exception e) {
347 local = true;
348 }
349 localZTranspose = local;
350
351 fftXY = new Complex2D[threadCount];
352 for (int i = 0; i < threadCount; i++) {
353 fftXY[i] = new Complex2D(nX, nY, dataLayout2D, im);
354 fftXY[i].setPackFFTs(packFFTs);
355 fftXY[i].setUseSIMD(useSIMD);
356 }
357
358 fftZ = new Complex[threadCount];
359 for (int i = 0; i < threadCount; i++) {
360 fftZ[i] = new Complex(nZ, dataLayout1D, internalImZ, nY);
361 fftZ[i].setUseSIMD(useSIMD);
362 }
363 if (!localZTranspose) {
364 work3D = new double[2 * nX * nY * nZ];
365 } else {
366 work3D = null;
367 }
368 fftRegion = new FFTRegion();
369 ifftRegion = new IFFTRegion();
370 convRegion = new ConvolutionRegion();
371 }
372
373 public String toString() {
374 return "Complex3DParallel {" +
375 "nX=" + nX +
376 ", nY=" + nY +
377 ", nZ=" + nZ +
378 ", im=" + im +
379 ", ii=" + ii +
380 ", nextX=" + nextX +
381 ", nextY=" + nextY +
382 ", nextZ=" + nextZ +
383 ", threadCount=" + threadCount +
384 ", parallelTeam=" + parallelTeam +
385 ", internalImZ=" + internalImZ +
386 ", schedule=" + schedule +
387 ", nXm1=" + nXm1 +
388 ", nYm1=" + nYm1 +
389 ", nZm1=" + nZm1 +
390 ", fftRegion=" + fftRegion +
391 ", ifftRegion=" + ifftRegion +
392 ", convRegion=" + convRegion +
393 ", useSIMD=" + useSIMD +
394 ", packFFTs=" + packFFTs +
395 '}';
396 }
397
398
399
400
401
402
403 public void setUseSIMD(boolean useSIMD) {
404 this.useSIMD = useSIMD;
405 for (int i = 0; i < threadCount; i++) {
406 fftXY[i].setUseSIMD(useSIMD);
407 fftZ[i].setUseSIMD(useSIMD);
408 }
409 }
410
411
412
413
414
415
416 public void setPackFFTs(boolean packFFTs) {
417 this.packFFTs = packFFTs;
418 for (int i = 0; i < threadCount; i++) {
419 fftXY[i].setPackFFTs(packFFTs);
420 }
421 }
422
423
424
425
426
427
428
429
430 public void convolution(final double[] input) {
431 this.input = input;
432 try {
433 parallelTeam.execute(convRegion);
434 } catch (Exception e) {
435 String message = "Fatal exception evaluating a convolution.\n";
436 logger.log(Level.SEVERE, message, e);
437 }
438 }
439
440
441
442
443
444
445
446 public void fft(final double[] input) {
447 this.input = input;
448 try {
449 parallelTeam.execute(fftRegion);
450 } catch (Exception e) {
451 String message = " Fatal exception evaluating the FFT.\n";
452 logger.log(Level.SEVERE, message, e);
453 }
454 }
455
456
457
458
459
460
/**
 * Returns the per-thread convolution timings accumulated since the last call to
 * {@link #initTiming()}.
 *
 * @return the array maintained by the convolution region, indexed by thread, in nanoseconds.
 */
public long[] getTiming() {
  return convRegion.getTiming();
}
464
465
466
467
468
469
470
471 public void ifft(final double[] input) {
472 this.input = input;
473 try {
474 parallelTeam.execute(ifftRegion);
475 } catch (Exception e) {
476 String message = "Fatal exception evaluating the inverse FFT.\n";
477 logger.log(Level.SEVERE, message, e);
478 System.exit(-1);
479 }
480 }
481
482
483
484
/** Resets the accumulated timings of all loops owned by the convolution region to zero. */
public void initTiming() {
  convRegion.initTiming();
}
488
489
490
491
492
/**
 * Formats the per-thread timings of the convolution loops.
 *
 * @return one line per thread followed by a line of sums, as built by the convolution region.
 */
public String timingString() {
  return convRegion.timingString();
}
496
497
498
499
500
501
502 public void setRecip(double[] recip) {
503
504
505
506
507
508
509 int recipNextY = nX;
510 int recipNextZ = nY * nX;
511 int index = 0;
512 for (int x = 0; x < nX; x++) {
513 int dx = x;
514 for (int z = 0; z < nZ; z++) {
515 int dz = dx + z * recipNextZ;
516 for (int y = 0; y < nY; y++) {
517 int conv = y * recipNextY + dz;
518 this.recip[index] = recip[conv];
519 index++;
520 }
521 }
522 }
523 }
524
525
526
527
528
529
530
531
532
533
534
535
536 private class FFTRegion extends ParallelRegion {
537
538 private final FFTXYLoop[] fftXYLoop;
539 private final FFTZLoop[] fftZLoop;
540 private final TransposeLoop[] transposeLoop;
541 private final UnTransposeLoop[] unTransposeLoop;
542
543 private FFTRegion() {
544 fftXYLoop = new FFTXYLoop[threadCount];
545 fftZLoop = new FFTZLoop[threadCount];
546 for (int i = 0; i < threadCount; i++) {
547 fftXYLoop[i] = new FFTXYLoop();
548 fftZLoop[i] = new FFTZLoop();
549 }
550 if (!localZTranspose) {
551 transposeLoop = new TransposeLoop[threadCount];
552 unTransposeLoop = new UnTransposeLoop[threadCount];
553 for (int i = 0; i < threadCount; i++) {
554 transposeLoop[i] = new TransposeLoop();
555 unTransposeLoop[i] = new UnTransposeLoop();
556 }
557 } else {
558 transposeLoop = null;
559 unTransposeLoop = null;
560 }
561 }
562
563 @Override
564 public void run() {
565 int threadIndex = getThreadIndex();
566 try {
567 if (localZTranspose) {
568 execute(0, nZm1, fftXYLoop[threadIndex]);
569 execute(0, nXm1, fftZLoop[threadIndex]);
570 } else {
571 execute(0, nZm1, fftXYLoop[threadIndex]);
572 execute(0, nXm1, transposeLoop[threadIndex]);
573 execute(0, nXm1, fftZLoop[threadIndex]);
574 execute(0, nZm1, unTransposeLoop[threadIndex]);
575 }
576 } catch (Exception e) {
577 logger.severe(e.toString());
578 }
579 }
580 }
581
582
583
584
585
586
587
588
589
590
591
592
593 private class IFFTRegion extends ParallelRegion {
594
595 private final IFFTXYLoop[] ifftXYLoop;
596 private final IFFTZLoop[] ifftZLoop;
597 private final TransposeLoop[] transposeLoop;
598 private final UnTransposeLoop[] unTransposeLoop;
599
600 private IFFTRegion() {
601 ifftXYLoop = new IFFTXYLoop[threadCount];
602 ifftZLoop = new IFFTZLoop[threadCount];
603 for (int i = 0; i < threadCount; i++) {
604 ifftXYLoop[i] = new IFFTXYLoop();
605 ifftZLoop[i] = new IFFTZLoop();
606 }
607 if (!localZTranspose) {
608 transposeLoop = new TransposeLoop[threadCount];
609 unTransposeLoop = new UnTransposeLoop[threadCount];
610 for (int i = 0; i < threadCount; i++) {
611 transposeLoop[i] = new TransposeLoop();
612 unTransposeLoop[i] = new UnTransposeLoop();
613 }
614 } else {
615 transposeLoop = null;
616 unTransposeLoop = null;
617 }
618 }
619
620 @Override
621 public void run() {
622 int threadIndex = getThreadIndex();
623 try {
624 if (localZTranspose) {
625 execute(0, nXm1, ifftZLoop[threadIndex]);
626 execute(0, nZm1, ifftXYLoop[threadIndex]);
627 } else {
628 execute(0, nXm1, transposeLoop[threadIndex]);
629 execute(0, nXm1, ifftZLoop[threadIndex]);
630 execute(0, nZm1, unTransposeLoop[threadIndex]);
631 execute(0, nZm1, ifftXYLoop[threadIndex]);
632 }
633 } catch (Exception e) {
634 logger.severe(e.toString());
635 }
636 }
637 }
638
639
640
641
642
643
644
645
646
647
648
649
650
651 private class ConvolutionRegion extends ParallelRegion {
652
653 private final FFTXYLoop[] fftXYLoop;
654 private final TransposeLoop[] transposeLoop;
655 private final FFTZIZLoop[] fftZIZLoop;
656 private final UnTransposeLoop[] unTransposeLoop;
657 private final IFFTXYLoop[] ifftXYLoop;
658 private final long[] convTime;
659
660 private ConvolutionRegion() {
661 fftXYLoop = new FFTXYLoop[threadCount];
662 fftZIZLoop = new FFTZIZLoop[threadCount];
663 ifftXYLoop = new IFFTXYLoop[threadCount];
664 convTime = new long[threadCount];
665 for (int i = 0; i < threadCount; i++) {
666 fftXYLoop[i] = new FFTXYLoop();
667 fftZIZLoop[i] = new FFTZIZLoop();
668 ifftXYLoop[i] = new IFFTXYLoop();
669 }
670 if (!localZTranspose) {
671 transposeLoop = new TransposeLoop[threadCount];
672 unTransposeLoop = new UnTransposeLoop[threadCount];
673 for (int i = 0; i < threadCount; i++) {
674 transposeLoop[i] = new TransposeLoop();
675 unTransposeLoop[i] = new UnTransposeLoop();
676 }
677 } else {
678 transposeLoop = null;
679 unTransposeLoop = null;
680 }
681 }
682
683 public void initTiming() {
684 for (int i = 0; i < threadCount; i++) {
685 fftXYLoop[i].time = 0;
686 fftZIZLoop[i].time = 0;
687 ifftXYLoop[i].time = 0;
688 }
689 if (!localZTranspose) {
690 for (int i = 0; i < threadCount; i++) {
691 transposeLoop[i].time = 0;
692 unTransposeLoop[i].time = 0;
693 }
694 }
695 }
696
697 public long[] getTiming() {
698 if (localZTranspose) {
699 for (int i = 0; i < threadCount; i++) {
700 convTime[i] = convRegion.fftXYLoop[i].time
701 + convRegion.fftZIZLoop[i].time
702 + convRegion.ifftXYLoop[i].time;
703 }
704 } else {
705 for (int i = 0; i < threadCount; i++) {
706 convTime[i] = convRegion.fftXYLoop[i].time
707 + convRegion.transposeLoop[i].time
708 + convRegion.fftZIZLoop[i].time
709 + convRegion.unTransposeLoop[i].time
710 + convRegion.ifftXYLoop[i].time;
711 }
712 }
713 return convTime;
714 }
715
716 public String timingString() {
717 StringBuilder sb = new StringBuilder();
718 if (localZTranspose) {
719 double xysum = 0.0;
720 double zizsum = 0.0;
721 double ixysum = 0.0;
722 for (int i = 0; i < threadCount; i++) {
723 double fftxy = fftXYLoop[i].getTime() * 1e-9;
724 double ziz = fftZIZLoop[i].getTime() * 1e-9;
725 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
726 String s = format(" Thread %3d: FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
727 i, fftxy, ziz, ifftxy);
728 sb.append(s);
729 xysum += fftxy;
730 zizsum += ziz;
731 ixysum += ifftxy;
732 }
733 String s = format(" Sum : FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
734 xysum, zizsum, ixysum);
735 sb.append(s);
736 } else {
737 double xysum = 0.0;
738 double transsum = 0.0;
739 double zizsum = 0.0;
740 double untranssum = 0.0;
741 double ixysum = 0.0;
742 for (int i = 0; i < threadCount; i++) {
743 double fftxy = fftXYLoop[i].getTime() * 1e-9;
744 double trans = transposeLoop[i].getTime() * 1e-9;
745 double ziz = fftZIZLoop[i].getTime() * 1e-9;
746 double untrans = unTransposeLoop[i].getTime() * 1e-9;
747 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
748 String s = format(" Thread %3d: FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
749 i, fftxy, trans, ziz, untrans, ifftxy);
750 sb.append(s);
751 xysum += fftxy;
752 transsum += trans;
753 zizsum += ziz;
754 untranssum += untrans;
755 ixysum += ifftxy;
756 }
757 String s = format(" Sum : FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
758 xysum, transsum, zizsum, untranssum, ixysum);
759 sb.append(s);
760 }
761
762 return sb.toString();
763 }
764
765 @Override
766 public void run() {
767 int threadIndex = getThreadIndex();
768 try {
769 if (localZTranspose) {
770 execute(0, nZm1, fftXYLoop[threadIndex]);
771 execute(0, nXm1, fftZIZLoop[threadIndex]);
772 execute(0, nZm1, ifftXYLoop[threadIndex]);
773 } else {
774 execute(0, nZm1, fftXYLoop[threadIndex]);
775 execute(0, nXm1, transposeLoop[threadIndex]);
776 execute(0, nXm1, fftZIZLoop[threadIndex]);
777 execute(0, nZm1, unTransposeLoop[threadIndex]);
778 execute(0, nZm1, ifftXYLoop[threadIndex]);
779 }
780 } catch (Exception e) {
781 logger.severe(e.toString());
782 }
783 }
784 }
785
786 private class FFTXYLoop extends IntegerForLoop {
787
788 private Complex2D localFFTXY;
789 private long time;
790
791 @Override
792 public void run(final int lb, final int ub) {
793 for (int z = lb; z <= ub; z++) {
794 int offset = z * nextZ;
795 localFFTXY.fft(input, offset);
796 }
797 }
798
799 @Override
800 public IntegerSchedule schedule() {
801 return schedule;
802 }
803
804 public long getTime() {
805 return time;
806 }
807
808 public void finish() {
809 time += System.nanoTime();
810 }
811
812 @Override
813 public void start() {
814 time -= System.nanoTime();
815 localFFTXY = fftXY[getThreadIndex()];
816 }
817 }
818
819 private class IFFTXYLoop extends IntegerForLoop {
820
821 private Complex2D localFFTXY;
822 private long time;
823
824 @Override
825 public void run(final int lb, final int ub) {
826 for (int z = lb; z <= ub; z++) {
827 int offset = z * nextZ;
828 localFFTXY.ifft(input, offset);
829 }
830 }
831
832 @Override
833 public IntegerSchedule schedule() {
834 return schedule;
835 }
836
837 public long getTime() {
838 return time;
839 }
840
841 public void finish() {
842 time += System.nanoTime();
843 }
844
845 @Override
846 public void start() {
847 time -= System.nanoTime();
848 localFFTXY = fftXY[getThreadIndex()];
849 }
850 }
851
852 private class FFTZLoop extends IntegerForLoop {
853
854 private Complex localFFTZ;
855 private long time;
856 private final double[] work;
857
858 private FFTZLoop() {
859 if (localZTranspose) {
860 work = new double[2 * nY * nZ];
861 } else {
862 work = null;
863 }
864 }
865
866 @Override
867 public void run(final int lb, final int ub) {
868 if (localZTranspose) {
869 for (int x = lb; x <= ub; x++) {
870 int inputOffset = x * nextX;
871 transpose(input, inputOffset, work, 0);
872 localFFTZ.fft(work, 0, ii);
873 unTranspose(input, inputOffset, work, 0);
874 }
875 } else {
876 for (int x = lb; x <= ub; x++) {
877 int offset = x * nY * nZ * ii;
878 localFFTZ.fft(work3D, offset, ii);
879 }
880 }
881 }
882
883 @Override
884 public IntegerSchedule schedule() {
885 return schedule;
886 }
887
888 public long getTime() {
889 return time;
890 }
891
892 @Override
893 public void finish() {
894 time += System.nanoTime();
895 }
896
897 @Override
898 public void start() {
899 time -= System.nanoTime();
900 int threadID = getThreadIndex();
901 localFFTZ = fftZ[threadID];
902 }
903 }
904
905 private class IFFTZLoop extends IntegerForLoop {
906
907 private Complex localFFTZ;
908 private long time;
909 private final double[] work;
910
911 private IFFTZLoop() {
912 if (localZTranspose) {
913 work = new double[2 * nY * nZ];
914 } else {
915 work = null;
916 }
917 }
918
919 @Override
920 public void run(final int lb, final int ub) {
921 if (localZTranspose) {
922 for (int x = lb; x <= ub; x++) {
923 int inputOffset = x * nextX;
924 transpose(input, inputOffset, work, 0);
925 localFFTZ.ifft(work, 0, ii);
926 unTranspose(input, inputOffset, work, 0);
927 }
928 } else {
929 for (int x = lb; x <= ub; x++) {
930 int offset = x * nY * nZ * ii;
931 localFFTZ.ifft(work3D, offset, ii);
932 }
933 }
934 }
935
936 @Override
937 public IntegerSchedule schedule() {
938 return schedule;
939 }
940
941 public long getTime() {
942 return time;
943 }
944
945 @Override
946 public void finish() {
947 time += System.nanoTime();
948 }
949
950 @Override
951 public void start() {
952 time -= System.nanoTime();
953 int threadID = getThreadIndex();
954 localFFTZ = fftZ[threadID];
955 }
956 }
957
958 private class FFTZIZLoop extends IntegerForLoop {
959
960 private Complex localFFTZ;
961 private long time;
962 private final double[] work;
963
964 private FFTZIZLoop() {
965 if (localZTranspose) {
966 work = new double[2 * nY * nZ];
967 } else {
968 work = null;
969 }
970 }
971
972 @Override
973 public void run(final int lb, final int ub) {
974 if (localZTranspose) {
975 for (int x = lb; x <= ub; x++) {
976 int inputOffset = x * nextX;
977 transpose(input, inputOffset, work, 0);
978 localFFTZ.fft(work, 0, ii);
979 int recipOffset = x * nY * nZ;
980 recipConv(recipOffset, work, 0);
981 localFFTZ.ifft(work, 0, ii);
982 unTranspose(input, inputOffset, work, 0);
983 }
984 } else {
985 for (int x = lb; x <= ub; x++) {
986
987 int offset = x * nY * nZ * ii;
988 localFFTZ.fft(work3D, offset, ii);
989 int recipOffset = x * nY * nZ;
990 recipConv(recipOffset, work3D, offset);
991 localFFTZ.ifft(work3D, offset, ii);
992 }
993 }
994 }
995
996 @Override
997 public IntegerSchedule schedule() {
998 return schedule;
999 }
1000
1001 public long getTime() {
1002 return time;
1003 }
1004
1005 public void finish() {
1006 time += System.nanoTime();
1007 }
1008
1009 @Override
1010 public void start() {
1011 time -= System.nanoTime();
1012 int threadID = getThreadIndex();
1013 localFFTZ = fftZ[threadID];
1014 }
1015 }
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025 private class TransposeLoop extends IntegerForLoop {
1026
1027 private long time;
1028
1029 @Override
1030 public void run(final int lb, final int ub) {
1031 for (int x = lb; x <= ub; x++) {
1032 for (int z = 0; z < nZ; z++) {
1033 for (int y = 0; y < nY; y++) {
1034 double real = input[x * nextX + y * nextY + z * nextZ];
1035 double imag = input[x * nextX + y * nextY + z * nextZ + im];
1036 work3D[y * trNextY + z * trNextZ + x * trNextX] = real;
1037 work3D[y * trNextY + z * trNextZ + x * trNextX + internalImZ] = imag;
1038 }
1039 }
1040 }
1041 }
1042
1043
1044
1045
1046
1047
1048 @Override
1049 public IntegerSchedule schedule() {
1050 return IntegerSchedule.fixed();
1051 }
1052
1053 public long getTime() {
1054 return time;
1055 }
1056
1057 @Override
1058 public void finish() {
1059 time += System.nanoTime();
1060 }
1061
1062 @Override
1063 public void start() {
1064 time -= System.nanoTime();
1065 }
1066 }
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077 private class UnTransposeLoop extends IntegerForLoop {
1078
1079 private long time;
1080
1081 @Override
1082 public void run(final int lb, final int ub) {
1083 for (int x = 0; x < nX; x++) {
1084 for (int y = 0; y < nY; y++) {
1085 for (int z = lb; z <= ub; z++) {
1086 double real = work3D[y * trNextY + z * trNextZ + x * trNextX];
1087 double imag = work3D[y * trNextY + z * trNextZ + x * trNextX + internalImZ];
1088 input[x * nextX + y * nextY + z * nextZ] = real;
1089 input[x * nextX + y * nextY + z * nextZ + im] = imag;
1090 }
1091 }
1092 }
1093 }
1094
1095
1096
1097
1098
1099
1100 @Override
1101 public IntegerSchedule schedule() {
1102 return IntegerSchedule.fixed();
1103 }
1104
1105 public long getTime() {
1106 return time;
1107 }
1108
1109 @Override
1110 public void finish() {
1111 time += System.nanoTime();
1112 }
1113
1114 @Override
1115 public void start() {
1116 time -= System.nanoTime();
1117 }
1118 }
1119
1120
1121
1122
1123
1124
1125
1126
1127 private void recipConv(int recipOffset, double[] work, int workOffset) {
1128 int index = workOffset;
1129 int rindex = recipOffset;
1130 for (int i = 0; i < nY * nZ; i++) {
1131 double r = recip[rindex++];
1132 work[index] *= r;
1133 work[index + internalImZ] *= r;
1134 index += ii;
1135 }
1136 }
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146 private void transpose(double[] input, int inputOffset, double[] output, int outputOffset) {
1147
1148
1149 for (int z = 0; z < nZ; z++) {
1150 for (int y = 0; y < nY; y++) {
1151 double real = input[inputOffset + y * nextY + z * nextZ];
1152 double imag = input[inputOffset + y * nextY + z * nextZ + im];
1153 output[outputOffset + y * trNextY + z * trNextZ] = real;
1154 output[outputOffset + y * trNextY + z * trNextZ + internalImZ] = imag;
1155 }
1156 }
1157 }
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167 private void unTranspose(double[] input, int inputOffset, double[] output, int outputOffset) {
1168
1169
1170 for (int z = 0; z < nZ; z++) {
1171 for (int y = 0; y < nY; y++) {
1172 double real = output[outputOffset + y * trNextY + z * trNextZ];
1173 double imag = output[outputOffset + y * trNextY + z * trNextZ + internalImZ];
1174 input[inputOffset + y * nextY + z * nextZ] = real;
1175 input[inputOffset + y * nextY + z * nextZ + im] = imag;
1176 }
1177 }
1178 }
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188 public static double[] initRandomData(int dim, ParallelTeam parallelTeam) {
1189 int n = dim * dim * dim;
1190 double[] data = new double[2 * n];
1191 try {
1192 parallelTeam.execute(
1193 new ParallelRegion() {
1194 @Override
1195 public void run() {
1196 try {
1197 execute(
1198 0,
1199 dim - 1,
1200 new IntegerForLoop() {
1201 @Override
1202 public void run(final int lb, final int ub) {
1203 Random randomNumberGenerator = new Random(1);
1204 int index = dim * dim * lb * 2;
1205 for (int i = lb; i <= ub; i++) {
1206 for (int j = 0; j < dim; j++) {
1207 for (int k = 0; k < dim; k++) {
1208 double randomNumber = randomNumberGenerator.nextDouble();
1209 data[index] = randomNumber;
1210 index += 2;
1211 }
1212 }
1213 }
1214 }
1215 });
1216 } catch (Exception e) {
1217 System.out.println(e.getMessage());
1218 System.exit(-1);
1219 }
1220 }
1221 });
1222 } catch (Exception e) {
1223 System.out.println(e.getMessage());
1224 System.exit(-1);
1225 }
1226 return data;
1227 }
1228
1229
1230
1231
1232
1233
1234
1235
1236 public static void main(String[] args) throws Exception {
1237 int dimNotFinal = 128;
1238 int nCPU = ParallelTeam.getDefaultThreadCount();
1239 int reps = 5;
1240 boolean blocked = false;
1241 try {
1242 dimNotFinal = Integer.parseInt(args[0]);
1243 if (dimNotFinal < 1) {
1244 dimNotFinal = 100;
1245 }
1246 nCPU = Integer.parseInt(args[1]);
1247 if (nCPU < 1) {
1248 nCPU = ParallelTeam.getDefaultThreadCount();
1249 }
1250 reps = Integer.parseInt(args[2]);
1251 if (reps < 1) {
1252 reps = 5;
1253 }
1254 blocked = Boolean.parseBoolean(args[3]);
1255 } catch (Exception e) {
1256
1257 }
1258 final int dim = dimNotFinal;
1259 System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
1260 + "The best timing out of %d repetitions will be used.%n",
1261 dim, nCPU, reps);
1262
1263 Complex3D complex3D;
1264 Complex3DParallel complex3DParallel;
1265 ParallelTeam parallelTeam = new ParallelTeam(nCPU);
1266 if (blocked) {
1267 complex3D = new Complex3D(dim, dim, dim, DataLayout3D.BLOCKED_X);
1268 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.BLOCKED_X);
1269 } else {
1270 complex3D = new Complex3D(dim, dim, dim, DataLayout3D.INTERLEAVED);
1271 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.INTERLEAVED);
1272 }
1273 final int dimCubed = dim * dim * dim;
1274 final double[] data = initRandomData(dim, parallelTeam);
1275 final double[] work = new double[dimCubed];
1276 Arrays.fill(work, 1.0);
1277
1278 double toSeconds = 0.000000001;
1279 long seqTime = Long.MAX_VALUE;
1280 long parTime = Long.MAX_VALUE;
1281 long seqTimeConv = Long.MAX_VALUE;
1282 long parTimeConv = Long.MAX_VALUE;
1283
1284 complex3D.setRecip(work);
1285 complex3DParallel.setRecip(work);
1286
1287
1288 System.out.println("Warm Up Sequential FFT");
1289 complex3D.fft(data);
1290 System.out.println("Warm Up Sequential IFFT");
1291 complex3D.ifft(data);
1292 System.out.println("Warm Up Sequential Convolution");
1293 complex3D.convolution(data);
1294 for (int i = 0; i < reps; i++) {
1295 System.out.printf(" Iteration %d%n", i + 1);
1296 long time = System.nanoTime();
1297 complex3D.fft(data);
1298 complex3D.ifft(data);
1299 time = (System.nanoTime() - time);
1300 System.out.printf(" Sequential FFT: %9.6f (sec)%n", toSeconds * time);
1301 if (time < seqTime) {
1302 seqTime = time;
1303 }
1304 time = System.nanoTime();
1305 complex3D.convolution(data);
1306 time = (System.nanoTime() - time);
1307 System.out.printf(" Sequential Conv: %9.6f (sec)%n", toSeconds * time);
1308 if (time < seqTimeConv) {
1309 seqTimeConv = time;
1310 }
1311 }
1312
1313
1314 System.out.println("Warm up Parallel FFT");
1315 complex3DParallel.fft(data);
1316 System.out.println("Warm up Parallel IFFT");
1317 complex3DParallel.ifft(data);
1318 System.out.println("Warm up Parallel Convolution");
1319 complex3DParallel.convolution(data);
1320 complex3DParallel.initTiming();
1321
1322 for (int i = 0; i < reps; i++) {
1323 System.out.printf(" Iteration %d%n", i + 1);
1324 long time = System.nanoTime();
1325 complex3DParallel.fft(data);
1326 complex3DParallel.ifft(data);
1327 time = (System.nanoTime() - time);
1328 System.out.printf(" Parallel FFT: %9.6f (sec)%n", toSeconds * time);
1329 if (time < parTime) {
1330 parTime = time;
1331 }
1332
1333 time = System.nanoTime();
1334 complex3DParallel.convolution(data);
1335 time = (System.nanoTime() - time);
1336 System.out.printf(" Parallel Conv: %9.6f (sec)%n", toSeconds * time);
1337 if (time < parTimeConv) {
1338 parTimeConv = time;
1339 }
1340 }
1341
1342 System.out.printf(" Best Sequential FFT Time: %9.6f (sec)%n", toSeconds * seqTime);
1343 System.out.printf(" Best Sequential Conv. Time: %9.6f (sec)%n", toSeconds * seqTimeConv);
1344 System.out.printf(" Best Parallel FFT Time: %9.6f (sec)%n", toSeconds * parTime);
1345 System.out.printf(" Best Parallel Conv. Time: %9.6f (sec)%n", toSeconds * parTimeConv);
1346 System.out.printf(" 3D FFT Speedup: %9.6f X%n", (double) seqTime / parTime);
1347 System.out.printf(" 3D Conv Speedup: %9.6f X%n", (double) seqTimeConv / parTimeConv);
1348
1349 System.out.printf(" Parallel Convolution Timings:\n" + complex3DParallel.timingString());
1350
1351 parallelTeam.shutdown();
1352 }
1353 }