1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.IntegerSchedule;
42 import edu.rit.pj.ParallelRegion;
43 import edu.rit.pj.ParallelTeam;
44
45 import javax.annotation.Nullable;
46 import java.util.Arrays;
47 import java.util.Random;
48 import java.util.logging.Level;
49 import java.util.logging.Logger;
50
51 import static java.lang.String.format;
52 import static java.util.Objects.requireNonNullElseGet;
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70 public class Complex3DParallel {
71
72 private static final Logger logger = Logger.getLogger(Complex3DParallel.class.getName());
73
74
75
76 private final int nX;
77
78
79
80 private final int nY;
81
82
83
84 private final int nZ;
85
86
87
88 private final int im;
89
90
91
92 private final int ii;
93
94
95
96 private final int nextX;
97
98
99
100 private final int nextY;
101
102
103
104 private final int nextZ;
105
106
107
108 private final int trNextX;
109
110
111
112 private final int trNextY;
113
114
115
116 private final int trNextZ;
117
118
119
120 private final double[] recip;
121
122
123
124 private final int threadCount;
125
126
127
128 private final ParallelTeam parallelTeam;
129
130
131
132 private final Complex2D[] fftXY;
133
134
135
136
137 private final Complex[] fftZ;
138
139
140
141 private final int internalImZ;
142
143
144
145 private final IntegerSchedule schedule;
146
147
148
149 private final int nXm1;
150
151
152
153 private final int nYm1;
154
155
156
157 private final int nZm1;
158
159
160
161 private final FFTRegion fftRegion;
162
163
164
165 private final IFFTRegion ifftRegion;
166
167
168
169 private final ConvolutionRegion convRegion;
170
171
172
173
174 public double[] input;
175
176
177
178 private boolean useSIMD;
179
180
181
182 private boolean packFFTs;
183 private final boolean localZTranspose;
184
185
186
187 public final double[] work3D;
188
189
190
191
192
193
194
195
196
197
198 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
199 this(nX, nY, nZ, parallelTeam, DataLayout3D.INTERLEAVED);
200 }
201
202
203
204
205
206
207
208
209
210
211
212 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, DataLayout3D dataLayout) {
213 this(nX, nY, nZ, parallelTeam, null, dataLayout);
214 }
215
216
217
218
219
220
221
222
223
224
225
226 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
227 this(nX, nY, nZ, parallelTeam, integerSchedule, DataLayout3D.INTERLEAVED);
228 }
229
230
231
232
233
234
235
236
237
238
239
240
241 public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam,
242 @Nullable IntegerSchedule integerSchedule, DataLayout3D dataLayout) {
243 this.nX = nX;
244 this.nY = nY;
245 this.nZ = nZ;
246 this.parallelTeam = parallelTeam;
247 recip = new double[nX * nY * nZ];
248
249 DataLayout1D dataLayout1D;
250 DataLayout2D dataLayout2D;
251 switch (dataLayout) {
252 default:
253 case INTERLEAVED:
254
255 im = 1;
256 ii = 2;
257 nextX = 2;
258 nextY = 2 * nX;
259 nextZ = 2 * nX * nY;
260
261 dataLayout1D = DataLayout1D.INTERLEAVED;
262 dataLayout2D = DataLayout2D.INTERLEAVED;
263
264 internalImZ = 1;
265 trNextY = 2;
266 trNextZ = 2 * nY;
267 trNextX = 2 * nY * nZ;
268 break;
269 case BLOCKED_X:
270
271 im = nX;
272 ii = 1;
273 nextX = 1;
274 nextY = 2 * nX;
275 nextZ = 2 * nX * nY;
276
277 dataLayout1D = DataLayout1D.BLOCKED;
278 dataLayout2D = DataLayout2D.BLOCKED_X;
279
280 internalImZ = nZ * nY;
281 trNextY = 1;
282 trNextZ = nY;
283 trNextX = 2 * nY * nZ;
284 break;
285 case BLOCKED_XY:
286
287 im = nX * nY;
288 ii = 1;
289 nextX = 1;
290 nextY = nX;
291 nextZ = 2 * nY * nX;
292
293 dataLayout1D = DataLayout1D.BLOCKED;
294 dataLayout2D = DataLayout2D.BLOCKED_XY;
295
296 internalImZ = nY * nZ;
297 trNextY = 1;
298 trNextZ = nY;
299 trNextX = 2 * nY * nZ;
300 break;
301 case BLOCKED_XYZ:
302
303 im = nX * nY * nZ;
304 ii = 1;
305 nextX = 1;
306 nextY = nX;
307 nextZ = nY * nX;
308
309 dataLayout1D = DataLayout1D.BLOCKED;
310 dataLayout2D = DataLayout2D.BLOCKED_XY;
311
312 internalImZ = nY * nZ;
313 trNextY = 1;
314 trNextZ = nY;
315 trNextX = 2 * nY * nZ;
316 break;
317 }
318
319 nXm1 = this.nX - 1;
320 nYm1 = this.nY - 1;
321 nZm1 = this.nZ - 1;
322 threadCount = parallelTeam.getThreadCount();
323 schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
324
325
326 useSIMD = false;
327 String simd = System.getProperty("fft.useSIMD", Boolean.toString(useSIMD));
328 try {
329 useSIMD = Boolean.parseBoolean(simd);
330 } catch (Exception e) {
331 useSIMD = false;
332 }
333
334
335 packFFTs = false;
336 String pack = System.getProperty("fft.packFFTs", Boolean.toString(packFFTs));
337 try {
338 packFFTs = Boolean.parseBoolean(pack);
339 } catch (Exception e) {
340 packFFTs = false;
341 }
342
343 String localTranspose = System.getProperty("fft.localZTranspose", "true");
344 boolean local;
345 try {
346 local = Boolean.parseBoolean(localTranspose);
347 } catch (Exception e) {
348 local = true;
349 }
350 localZTranspose = local;
351
352 fftXY = new Complex2D[threadCount];
353 for (int i = 0; i < threadCount; i++) {
354 fftXY[i] = new Complex2D(nX, nY, dataLayout2D, im);
355 fftXY[i].setPackFFTs(packFFTs);
356 fftXY[i].setUseSIMD(useSIMD);
357 }
358
359 fftZ = new Complex[threadCount];
360 for (int i = 0; i < threadCount; i++) {
361 fftZ[i] = new Complex(nZ, dataLayout1D, internalImZ, nY);
362 fftZ[i].setUseSIMD(useSIMD);
363 }
364
365 if (localZTranspose) {
366 work3D = null;
367 } else {
368 work3D = new double[2 * nX * nY * nZ];
369 }
370
371 fftRegion = new FFTRegion();
372 ifftRegion = new IFFTRegion();
373 convRegion = new ConvolutionRegion();
374 }
375
376 public String toString() {
377 return "Complex3DParallel {" +
378 "nX=" + nX +
379 ", nY=" + nY +
380 ", nZ=" + nZ +
381 ", im=" + im +
382 ", ii=" + ii +
383 ", nextX=" + nextX +
384 ", nextY=" + nextY +
385 ", nextZ=" + nextZ +
386 ", threadCount=" + threadCount +
387 ", parallelTeam=" + parallelTeam +
388 ", internalImZ=" + internalImZ +
389 ", schedule=" + schedule +
390 ", nXm1=" + nXm1 +
391 ", nYm1=" + nYm1 +
392 ", nZm1=" + nZm1 +
393 ", fftRegion=" + fftRegion +
394 ", ifftRegion=" + ifftRegion +
395 ", convRegion=" + convRegion +
396 ", useSIMD=" + useSIMD +
397 ", packFFTs=" + packFFTs +
398 '}';
399 }
400
401
402
403
404
405
406 public void setUseSIMD(boolean useSIMD) {
407 this.useSIMD = useSIMD;
408 for (int i = 0; i < threadCount; i++) {
409 fftXY[i].setUseSIMD(useSIMD);
410 fftZ[i].setUseSIMD(useSIMD);
411 }
412 }
413
414
415
416
417
418
419 public void setPackFFTs(boolean packFFTs) {
420 this.packFFTs = packFFTs;
421 for (int i = 0; i < threadCount; i++) {
422 fftXY[i].setPackFFTs(packFFTs);
423 }
424 }
425
426
427
428
429
430
431
432
433 public void convolution(final double[] input) {
434 this.input = input;
435 try {
436 parallelTeam.execute(convRegion);
437 } catch (Exception e) {
438 String message = "Fatal exception evaluating a convolution.\n";
439 logger.log(Level.SEVERE, message, e);
440 }
441 }
442
443
444
445
446
447
448
449 public void fft(final double[] input) {
450 this.input = input;
451 try {
452 parallelTeam.execute(fftRegion);
453 } catch (Exception e) {
454 String message = " Fatal exception evaluating the FFT.\n";
455 logger.log(Level.SEVERE, message, e);
456 }
457 }
458
459
460
461
462
463
464 public long[] getTiming() {
465 return convRegion.getTiming();
466 }
467
468
469
470
471
472
473
474 public void ifft(final double[] input) {
475 this.input = input;
476 try {
477 parallelTeam.execute(ifftRegion);
478 } catch (Exception e) {
479 String message = "Fatal exception evaluating the inverse FFT.\n";
480 logger.log(Level.SEVERE, message, e);
481 System.exit(-1);
482 }
483 }
484
485
486
487
488 public void initTiming() {
489 convRegion.initTiming();
490 }
491
492
493
494
495
496
497 public String timingString() {
498 return convRegion.timingString();
499 }
500
501
502
503
504
505
506 public void setRecip(double[] recip) {
507
508
509
510
511
512
513 int recipNextY = nX;
514 int recipNextZ = nY * nX;
515 int index = 0;
516 for (int x = 0; x < nX; x++) {
517 int dx = x;
518 for (int z = 0; z < nZ; z++) {
519 int dz = dx + z * recipNextZ;
520 for (int y = 0; y < nY; y++) {
521 int conv = y * recipNextY + dz;
522 this.recip[index] = recip[conv];
523 index++;
524 }
525 }
526 }
527 }
528
529
530
531
532
533
534
535
536
537
538
539
540 private class FFTRegion extends ParallelRegion {
541
542 private final FFTXYLoop[] fftXYLoop;
543 private final FFTZLoop[] fftZLoop;
544 private final TransposeLoop[] transposeLoop;
545 private final UnTransposeLoop[] unTransposeLoop;
546
547 private FFTRegion() {
548 fftXYLoop = new FFTXYLoop[threadCount];
549 fftZLoop = new FFTZLoop[threadCount];
550 for (int i = 0; i < threadCount; i++) {
551 fftXYLoop[i] = new FFTXYLoop();
552 fftZLoop[i] = new FFTZLoop();
553 }
554 if (!localZTranspose) {
555 transposeLoop = new TransposeLoop[threadCount];
556 unTransposeLoop = new UnTransposeLoop[threadCount];
557 for (int i = 0; i < threadCount; i++) {
558 transposeLoop[i] = new TransposeLoop();
559 unTransposeLoop[i] = new UnTransposeLoop();
560 }
561 } else {
562 transposeLoop = null;
563 unTransposeLoop = null;
564 }
565 }
566
567 @Override
568 public void run() {
569 int threadIndex = getThreadIndex();
570 try {
571 if (localZTranspose) {
572 execute(0, nZm1, fftXYLoop[threadIndex]);
573 execute(0, nXm1, fftZLoop[threadIndex]);
574 } else {
575 execute(0, nZm1, fftXYLoop[threadIndex]);
576 execute(0, nXm1, transposeLoop[threadIndex]);
577 execute(0, nXm1, fftZLoop[threadIndex]);
578 execute(0, nZm1, unTransposeLoop[threadIndex]);
579 }
580 } catch (Exception e) {
581 logger.severe(e.toString());
582 }
583 }
584 }
585
586
587
588
589
590
591
592
593
594
595
596
597 private class IFFTRegion extends ParallelRegion {
598
599 private final IFFTXYLoop[] ifftXYLoop;
600 private final IFFTZLoop[] ifftZLoop;
601 private final TransposeLoop[] transposeLoop;
602 private final UnTransposeLoop[] unTransposeLoop;
603
604 private IFFTRegion() {
605 ifftXYLoop = new IFFTXYLoop[threadCount];
606 ifftZLoop = new IFFTZLoop[threadCount];
607 for (int i = 0; i < threadCount; i++) {
608 ifftXYLoop[i] = new IFFTXYLoop();
609 ifftZLoop[i] = new IFFTZLoop();
610 }
611 if (!localZTranspose) {
612 transposeLoop = new TransposeLoop[threadCount];
613 unTransposeLoop = new UnTransposeLoop[threadCount];
614 for (int i = 0; i < threadCount; i++) {
615 transposeLoop[i] = new TransposeLoop();
616 unTransposeLoop[i] = new UnTransposeLoop();
617 }
618 } else {
619 transposeLoop = null;
620 unTransposeLoop = null;
621 }
622 }
623
624 @Override
625 public void run() {
626 int threadIndex = getThreadIndex();
627 try {
628 if (localZTranspose) {
629 execute(0, nXm1, ifftZLoop[threadIndex]);
630 execute(0, nZm1, ifftXYLoop[threadIndex]);
631 } else {
632 execute(0, nXm1, transposeLoop[threadIndex]);
633 execute(0, nXm1, ifftZLoop[threadIndex]);
634 execute(0, nZm1, unTransposeLoop[threadIndex]);
635 execute(0, nZm1, ifftXYLoop[threadIndex]);
636 }
637 } catch (Exception e) {
638 logger.severe(e.toString());
639 }
640 }
641 }
642
643
644
645
646
647
648
649
650
651
652
653
654
655 private class ConvolutionRegion extends ParallelRegion {
656
657 private final FFTXYLoop[] fftXYLoop;
658 private final TransposeLoop[] transposeLoop;
659 private final FFTZIZLoop[] fftZIZLoop;
660 private final UnTransposeLoop[] unTransposeLoop;
661 private final IFFTXYLoop[] ifftXYLoop;
662 private final long[] convTime;
663
664 private ConvolutionRegion() {
665 fftXYLoop = new FFTXYLoop[threadCount];
666 fftZIZLoop = new FFTZIZLoop[threadCount];
667 ifftXYLoop = new IFFTXYLoop[threadCount];
668 convTime = new long[threadCount];
669 for (int i = 0; i < threadCount; i++) {
670 fftXYLoop[i] = new FFTXYLoop();
671 fftZIZLoop[i] = new FFTZIZLoop();
672 ifftXYLoop[i] = new IFFTXYLoop();
673 }
674 if (!localZTranspose) {
675 transposeLoop = new TransposeLoop[threadCount];
676 unTransposeLoop = new UnTransposeLoop[threadCount];
677 for (int i = 0; i < threadCount; i++) {
678 transposeLoop[i] = new TransposeLoop();
679 unTransposeLoop[i] = new UnTransposeLoop();
680 }
681 } else {
682 transposeLoop = null;
683 unTransposeLoop = null;
684 }
685 }
686
687 public void initTiming() {
688 for (int i = 0; i < threadCount; i++) {
689 fftXYLoop[i].time = 0;
690 fftZIZLoop[i].time = 0;
691 ifftXYLoop[i].time = 0;
692 }
693 if (!localZTranspose) {
694 for (int i = 0; i < threadCount; i++) {
695 transposeLoop[i].time = 0;
696 unTransposeLoop[i].time = 0;
697 }
698 }
699 }
700
701 public long[] getTiming() {
702 if (localZTranspose) {
703 for (int i = 0; i < threadCount; i++) {
704 convTime[i] = convRegion.fftXYLoop[i].time
705 + convRegion.fftZIZLoop[i].time
706 + convRegion.ifftXYLoop[i].time;
707 }
708 } else {
709 for (int i = 0; i < threadCount; i++) {
710 convTime[i] = convRegion.fftXYLoop[i].time
711 + convRegion.transposeLoop[i].time
712 + convRegion.fftZIZLoop[i].time
713 + convRegion.unTransposeLoop[i].time
714 + convRegion.ifftXYLoop[i].time;
715 }
716 }
717 return convTime;
718 }
719
720 public String timingString() {
721 StringBuilder sb = new StringBuilder();
722 if (localZTranspose) {
723 double xysum = 0.0;
724 double zizsum = 0.0;
725 double ixysum = 0.0;
726 for (int i = 0; i < threadCount; i++) {
727 double fftxy = fftXYLoop[i].getTime() * 1e-9;
728 double ziz = fftZIZLoop[i].getTime() * 1e-9;
729 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
730 String s = format(" Thread %3d: FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
731 i, fftxy, ziz, ifftxy);
732 sb.append(s);
733 xysum += fftxy;
734 zizsum += ziz;
735 ixysum += ifftxy;
736 }
737 String s = format(" Sum : FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
738 xysum, zizsum, ixysum);
739 sb.append(s);
740 } else {
741 double xysum = 0.0;
742 double transsum = 0.0;
743 double zizsum = 0.0;
744 double untranssum = 0.0;
745 double ixysum = 0.0;
746 for (int i = 0; i < threadCount; i++) {
747 double fftxy = fftXYLoop[i].getTime() * 1e-9;
748 double trans = transposeLoop[i].getTime() * 1e-9;
749 double ziz = fftZIZLoop[i].getTime() * 1e-9;
750 double untrans = unTransposeLoop[i].getTime() * 1e-9;
751 double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
752 String s = format(" Thread %3d: FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
753 i, fftxy, trans, ziz, untrans, ifftxy);
754 sb.append(s);
755 xysum += fftxy;
756 transsum += trans;
757 zizsum += ziz;
758 untranssum += untrans;
759 ixysum += ifftxy;
760 }
761 String s = format(" Sum : FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
762 xysum, transsum, zizsum, untranssum, ixysum);
763 sb.append(s);
764 }
765
766 return sb.toString();
767 }
768
769 @Override
770 public void run() {
771 int threadIndex = getThreadIndex();
772 try {
773 if (localZTranspose) {
774 execute(0, nZm1, fftXYLoop[threadIndex]);
775 execute(0, nXm1, fftZIZLoop[threadIndex]);
776 execute(0, nZm1, ifftXYLoop[threadIndex]);
777 } else {
778 execute(0, nZm1, fftXYLoop[threadIndex]);
779 execute(0, nXm1, transposeLoop[threadIndex]);
780 execute(0, nXm1, fftZIZLoop[threadIndex]);
781 execute(0, nZm1, unTransposeLoop[threadIndex]);
782 execute(0, nZm1, ifftXYLoop[threadIndex]);
783 }
784 } catch (Exception e) {
785 logger.severe(e.toString());
786 }
787 }
788 }
789
790 private class FFTXYLoop extends IntegerForLoop {
791
792 private Complex2D localFFTXY;
793 private long time;
794
795 @Override
796 public void run(final int lb, final int ub) {
797 for (int z = lb; z <= ub; z++) {
798 int offset = z * nextZ;
799 localFFTXY.fft(input, offset);
800 }
801 }
802
803 @Override
804 public IntegerSchedule schedule() {
805 return schedule;
806 }
807
808 public long getTime() {
809 return time;
810 }
811
812 public void finish() {
813 time += System.nanoTime();
814 }
815
816 @Override
817 public void start() {
818 time -= System.nanoTime();
819 localFFTXY = fftXY[getThreadIndex()];
820 }
821 }
822
823 private class IFFTXYLoop extends IntegerForLoop {
824
825 private Complex2D localFFTXY;
826 private long time;
827
828 @Override
829 public void run(final int lb, final int ub) {
830 for (int z = lb; z <= ub; z++) {
831 int offset = z * nextZ;
832 localFFTXY.ifft(input, offset);
833 }
834 }
835
836 @Override
837 public IntegerSchedule schedule() {
838 return schedule;
839 }
840
841 public long getTime() {
842 return time;
843 }
844
845 public void finish() {
846 time += System.nanoTime();
847 }
848
849 @Override
850 public void start() {
851 time -= System.nanoTime();
852 localFFTXY = fftXY[getThreadIndex()];
853 }
854 }
855
856 private class FFTZLoop extends IntegerForLoop {
857
858 private Complex localFFTZ;
859 private long time;
860 private final double[] work;
861
862 private FFTZLoop() {
863 if (localZTranspose) {
864 work = new double[2 * nY * nZ];
865 } else {
866 work = null;
867 }
868 }
869
870 @Override
871 public void run(final int lb, final int ub) {
872 if (localZTranspose) {
873 for (int x = lb; x <= ub; x++) {
874 int inputOffset = x * nextX;
875 transpose(input, inputOffset, work, 0);
876 localFFTZ.fft(work, 0, ii);
877 unTranspose(input, inputOffset, work, 0);
878 }
879 } else {
880 for (int x = lb; x <= ub; x++) {
881 int offset = x * nY * nZ * ii;
882 localFFTZ.fft(work3D, offset, ii);
883 }
884 }
885 }
886
887 @Override
888 public IntegerSchedule schedule() {
889 return schedule;
890 }
891
892 public long getTime() {
893 return time;
894 }
895
896 @Override
897 public void finish() {
898 time += System.nanoTime();
899 }
900
901 @Override
902 public void start() {
903 time -= System.nanoTime();
904 int threadID = getThreadIndex();
905 localFFTZ = fftZ[threadID];
906 }
907 }
908
909 private class IFFTZLoop extends IntegerForLoop {
910
911 private Complex localFFTZ;
912 private long time;
913 private final double[] work;
914
915 private IFFTZLoop() {
916 if (localZTranspose) {
917 work = new double[2 * nY * nZ];
918 } else {
919 work = null;
920 }
921 }
922
923 @Override
924 public void run(final int lb, final int ub) {
925 if (localZTranspose) {
926 for (int x = lb; x <= ub; x++) {
927 int inputOffset = x * nextX;
928 transpose(input, inputOffset, work, 0);
929 localFFTZ.ifft(work, 0, ii);
930 unTranspose(input, inputOffset, work, 0);
931 }
932 } else {
933 for (int x = lb; x <= ub; x++) {
934 int offset = x * nY * nZ * ii;
935 localFFTZ.ifft(work3D, offset, ii);
936 }
937 }
938 }
939
940 @Override
941 public IntegerSchedule schedule() {
942 return schedule;
943 }
944
945 public long getTime() {
946 return time;
947 }
948
949 @Override
950 public void finish() {
951 time += System.nanoTime();
952 }
953
954 @Override
955 public void start() {
956 time -= System.nanoTime();
957 int threadID = getThreadIndex();
958 localFFTZ = fftZ[threadID];
959 }
960 }
961
962 private class FFTZIZLoop extends IntegerForLoop {
963
964 private Complex localFFTZ;
965 private long time;
966 private final double[] work;
967
968 private FFTZIZLoop() {
969 if (localZTranspose) {
970 work = new double[2 * nY * nZ];
971 } else {
972 work = null;
973 }
974 }
975
976 @Override
977 public void run(final int lb, final int ub) {
978 if (localZTranspose) {
979 for (int x = lb; x <= ub; x++) {
980 int inputOffset = x * nextX;
981 transpose(input, inputOffset, work, 0);
982 localFFTZ.fft(work, 0, ii);
983 int recipOffset = x * nY * nZ;
984 recipConv(recipOffset, work, 0);
985 localFFTZ.ifft(work, 0, ii);
986 unTranspose(input, inputOffset, work, 0);
987 }
988 } else {
989 for (int x = lb; x <= ub; x++) {
990
991 int offset = x * nY * nZ * ii;
992 localFFTZ.fft(work3D, offset, ii);
993 int recipOffset = x * nY * nZ;
994 recipConv(recipOffset, work3D, offset);
995 localFFTZ.ifft(work3D, offset, ii);
996 }
997 }
998 }
999
1000 @Override
1001 public IntegerSchedule schedule() {
1002 return schedule;
1003 }
1004
1005 public long getTime() {
1006 return time;
1007 }
1008
1009 public void finish() {
1010 time += System.nanoTime();
1011 }
1012
1013 @Override
1014 public void start() {
1015 time -= System.nanoTime();
1016 int threadID = getThreadIndex();
1017 localFFTZ = fftZ[threadID];
1018 }
1019 }
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029 private class TransposeLoop extends IntegerForLoop {
1030
1031 private long time;
1032
1033 @Override
1034 public void run(final int lb, final int ub) {
1035 for (int x = lb; x <= ub; x++) {
1036 for (int z = 0; z < nZ; z++) {
1037 int inputOffset = x * nextX + z * nextZ;
1038 int workOffset = x * trNextX + z * trNextZ;
1039 for (int y = 0; y < nY; y++) {
1040 int inputIndex = inputOffset + y * nextY;
1041 double real = input[inputIndex];
1042 double imag = input[inputIndex + im];
1043 int workIndex = workOffset + y * trNextY;
1044 work3D[workIndex] = real;
1045 work3D[workIndex + internalImZ] = imag;
1046 }
1047 }
1048 }
1049 }
1050
1051
1052
1053
1054
1055
1056 @Override
1057 public IntegerSchedule schedule() {
1058 return IntegerSchedule.fixed();
1059 }
1060
1061 public long getTime() {
1062 return time;
1063 }
1064
1065 @Override
1066 public void finish() {
1067 time += System.nanoTime();
1068 }
1069
1070 @Override
1071 public void start() {
1072 time -= System.nanoTime();
1073 }
1074 }
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085 private class UnTransposeLoop extends IntegerForLoop {
1086
1087 private long time;
1088
1089 @Override
1090 public void run(final int lb, final int ub) {
1091 for (int x = 0; x < nX; x++) {
1092 for (int y = 0; y < nY; y++) {
1093 int workOffset = x * trNextX + y * trNextY;
1094 int inputOffset = x * nextX + y * nextY;
1095 for (int z = lb; z <= ub; z++) {
1096 int workIndex = workOffset + z * trNextZ;
1097 double real = work3D[workIndex];
1098 double imag = work3D[workIndex + internalImZ];
1099 int inputIndex = inputOffset + z * nextZ;
1100 input[inputIndex] = real;
1101 input[inputIndex + im] = imag;
1102 }
1103 }
1104 }
1105 }
1106
1107
1108
1109
1110
1111
1112 @Override
1113 public IntegerSchedule schedule() {
1114 return IntegerSchedule.fixed();
1115 }
1116
1117 public long getTime() {
1118 return time;
1119 }
1120
1121 @Override
1122 public void finish() {
1123 time += System.nanoTime();
1124 }
1125
1126 @Override
1127 public void start() {
1128 time -= System.nanoTime();
1129 }
1130 }
1131
1132
1133
1134
1135
1136
1137
1138
1139 private void recipConv(int recipOffset, double[] work, int workOffset) {
1140 int index = workOffset;
1141 int rindex = recipOffset;
1142 for (int i = 0; i < nY * nZ; i++) {
1143 double r = recip[rindex++];
1144 work[index] *= r;
1145 work[index + internalImZ] *= r;
1146 index += ii;
1147 }
1148 }
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158 private void transpose(double[] input, int inputOffset, double[] output, int outputOffset) {
1159
1160
1161 for (int z = 0; z < nZ; z++) {
1162 for (int y = 0; y < nY; y++) {
1163 double real = input[inputOffset + y * nextY + z * nextZ];
1164 double imag = input[inputOffset + y * nextY + z * nextZ + im];
1165 output[outputOffset + y * trNextY + z * trNextZ] = real;
1166 output[outputOffset + y * trNextY + z * trNextZ + internalImZ] = imag;
1167 }
1168 }
1169 }
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179 private void unTranspose(double[] input, int inputOffset, double[] output, int outputOffset) {
1180
1181
1182 for (int z = 0; z < nZ; z++) {
1183 for (int y = 0; y < nY; y++) {
1184 double real = output[outputOffset + y * trNextY + z * trNextZ];
1185 double imag = output[outputOffset + y * trNextY + z * trNextZ + internalImZ];
1186 input[inputOffset + y * nextY + z * nextZ] = real;
1187 input[inputOffset + y * nextY + z * nextZ + im] = imag;
1188 }
1189 }
1190 }
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200 public static double[] initRandomData(int dim, ParallelTeam parallelTeam) {
1201 int n = dim * dim * dim;
1202 double[] data = new double[2 * n];
1203 try {
1204 parallelTeam.execute(
1205 new ParallelRegion() {
1206 @Override
1207 public void run() {
1208 try {
1209 execute(
1210 0,
1211 dim - 1,
1212 new IntegerForLoop() {
1213 @Override
1214 public void run(final int lb, final int ub) {
1215 Random randomNumberGenerator = new Random(1);
1216 int index = dim * dim * lb * 2;
1217 for (int i = lb; i <= ub; i++) {
1218 for (int j = 0; j < dim; j++) {
1219 for (int k = 0; k < dim; k++) {
1220 double randomNumber = randomNumberGenerator.nextDouble();
1221 data[index] = randomNumber;
1222 index += 2;
1223 }
1224 }
1225 }
1226 }
1227 });
1228 } catch (Exception e) {
1229 System.out.println(e.getMessage());
1230 System.exit(-1);
1231 }
1232 }
1233 });
1234 } catch (Exception e) {
1235 System.out.println(e.getMessage());
1236 System.exit(-1);
1237 }
1238 return data;
1239 }
1240
1241
1242
1243
1244
1245
1246
1247
1248 public static void main(String[] args) throws Exception {
1249 int dimNotFinal = 128;
1250 int nCPU = ParallelTeam.getDefaultThreadCount();
1251 int reps = 5;
1252 boolean blocked = false;
1253 try {
1254 dimNotFinal = Integer.parseInt(args[0]);
1255 if (dimNotFinal < 1) {
1256 dimNotFinal = 100;
1257 }
1258 nCPU = Integer.parseInt(args[1]);
1259 if (nCPU < 1) {
1260 nCPU = ParallelTeam.getDefaultThreadCount();
1261 }
1262 reps = Integer.parseInt(args[2]);
1263 if (reps < 1) {
1264 reps = 5;
1265 }
1266 blocked = Boolean.parseBoolean(args[3]);
1267 } catch (Exception e) {
1268
1269 }
1270 final int dim = dimNotFinal;
1271 System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
1272 + "The best timing out of %d repetitions will be used.%n",
1273 dim, nCPU, reps);
1274
1275 Complex3D complex3D;
1276 Complex3DParallel complex3DParallel;
1277 ParallelTeam parallelTeam = new ParallelTeam(nCPU);
1278 if (blocked) {
1279 complex3D = new Complex3D(dim, dim, dim, DataLayout3D.BLOCKED_X);
1280 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.BLOCKED_X);
1281 } else {
1282 complex3D = new Complex3D(dim, dim, dim, DataLayout3D.INTERLEAVED);
1283 complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.INTERLEAVED);
1284 }
1285 final int dimCubed = dim * dim * dim;
1286 final double[] data = initRandomData(dim, parallelTeam);
1287 final double[] work = new double[dimCubed];
1288 Arrays.fill(work, 1.0);
1289
1290 double toSeconds = 0.000000001;
1291 long seqTime = Long.MAX_VALUE;
1292 long parTime = Long.MAX_VALUE;
1293 long seqTimeConv = Long.MAX_VALUE;
1294 long parTimeConv = Long.MAX_VALUE;
1295
1296 complex3D.setRecip(work);
1297 complex3DParallel.setRecip(work);
1298
1299
1300 System.out.println("Warm Up Sequential FFT");
1301 complex3D.fft(data);
1302 System.out.println("Warm Up Sequential IFFT");
1303 complex3D.ifft(data);
1304 System.out.println("Warm Up Sequential Convolution");
1305 complex3D.convolution(data);
1306 for (int i = 0; i < reps; i++) {
1307 System.out.printf(" Iteration %d%n", i + 1);
1308 long time = System.nanoTime();
1309 complex3D.fft(data);
1310 complex3D.ifft(data);
1311 time = (System.nanoTime() - time);
1312 System.out.printf(" Sequential FFT: %9.6f (sec)%n", toSeconds * time);
1313 if (time < seqTime) {
1314 seqTime = time;
1315 }
1316 time = System.nanoTime();
1317 complex3D.convolution(data);
1318 time = (System.nanoTime() - time);
1319 System.out.printf(" Sequential Conv: %9.6f (sec)%n", toSeconds * time);
1320 if (time < seqTimeConv) {
1321 seqTimeConv = time;
1322 }
1323 }
1324
1325
1326 System.out.println("Warm up Parallel FFT");
1327 complex3DParallel.fft(data);
1328 System.out.println("Warm up Parallel IFFT");
1329 complex3DParallel.ifft(data);
1330 System.out.println("Warm up Parallel Convolution");
1331 complex3DParallel.convolution(data);
1332 complex3DParallel.initTiming();
1333
1334 for (int i = 0; i < reps; i++) {
1335 System.out.printf(" Iteration %d%n", i + 1);
1336 long time = System.nanoTime();
1337 complex3DParallel.fft(data);
1338 complex3DParallel.ifft(data);
1339 time = (System.nanoTime() - time);
1340 System.out.printf(" Parallel FFT: %9.6f (sec)%n", toSeconds * time);
1341 if (time < parTime) {
1342 parTime = time;
1343 }
1344
1345 time = System.nanoTime();
1346 complex3DParallel.convolution(data);
1347 time = (System.nanoTime() - time);
1348 System.out.printf(" Parallel Conv: %9.6f (sec)%n", toSeconds * time);
1349 if (time < parTimeConv) {
1350 parTimeConv = time;
1351 }
1352 }
1353
1354 System.out.printf(" Best Sequential FFT Time: %9.6f (sec)%n", toSeconds * seqTime);
1355 System.out.printf(" Best Sequential Conv. Time: %9.6f (sec)%n", toSeconds * seqTimeConv);
1356 System.out.printf(" Best Parallel FFT Time: %9.6f (sec)%n", toSeconds * parTime);
1357 System.out.printf(" Best Parallel Conv. Time: %9.6f (sec)%n", toSeconds * parTimeConv);
1358 System.out.printf(" 3D FFT Speedup: %9.6f X%n", (double) seqTime / parTime);
1359 System.out.printf(" 3D Conv Speedup: %9.6f X%n", (double) seqTimeConv / parTimeConv);
1360
1361 System.out.printf(" Parallel Convolution Timings:\n" + complex3DParallel.timingString());
1362
1363 parallelTeam.shutdown();
1364 }
1365 }