1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  package ffx.numerics.fft;
39  
40  import edu.rit.pj.IntegerForLoop;
41  import edu.rit.pj.IntegerSchedule;
42  import edu.rit.pj.ParallelRegion;
43  import edu.rit.pj.ParallelTeam;
44  
45  import javax.annotation.Nullable;
46  import java.util.Arrays;
47  import java.util.Random;
48  import java.util.logging.Level;
49  import java.util.logging.Logger;
50  
51  import jdk.incubator.vector.DoubleVector;
52  import jdk.incubator.vector.VectorShuffle;
53  import jdk.incubator.vector.VectorSpecies;
54  
55  import static java.lang.String.format;
56  import static java.util.Objects.requireNonNullElseGet;
57  
58  
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  
72  
73  
74  public class Complex3DParallel {
75  
76    private static final Logger logger = Logger.getLogger(Complex3DParallel.class.getName());
77    
78  
79  
80    private final int nX;
81    
82  
83  
84    private final int nY;
85    
86  
87  
88    private final int nZ;
89    
90  
91  
92    private final int im;
93    
94  
95  
96    private final int ii;
97    
98  
99  
100   private final int nextX;
101   
102 
103 
104   private final int nextY;
105   
106 
107 
108   private final int nextZ;
109   
110 
111 
112   private final int trNextX;
113   
114 
115 
116   private final int trNextY;
117   
118 
119 
120   private final int trNextZ;
121   
122 
123 
124   private final double[] recip;
125   
126 
127 
128   private final int threadCount;
129   
130 
131 
132   private final ParallelTeam parallelTeam;
133   
134 
135 
136   private final Complex2D[] fftXY;
137   
138 
139 
140 
141   private final Complex[] fftZ;
142   
143 
144 
145   private final int internalImZ;
146   
147 
148 
149   private final IntegerSchedule schedule;
150   
151 
152 
153   private final int nXm1;
154   
155 
156 
157   private final int nYm1;
158   
159 
160 
161   private final int nZm1;
162   
163 
164 
165   private final FFTRegion fftRegion;
166   
167 
168 
169   private final IFFTRegion ifftRegion;
170   
171 
172 
173   private final ConvolutionRegion convRegion;
174   
175 
176 
177 
178   public double[] input;
179   
180 
181 
182   private boolean useSIMD;
183   private final VectorSpecies<Double> species = DoubleVector.SPECIES_PREFERRED;
184   private final int vectorSize = species.length();
185   private final int[] shuffle = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
186   private final VectorShuffle<Double> expandFirstHalf = VectorShuffle.fromArray(species, shuffle, 0);
187   private final VectorShuffle<Double> expandSecondHalf = VectorShuffle.fromArray(species, shuffle, vectorSize);
188 
189   
190 
191 
192   private boolean packFFTs;
193   private final boolean localZTranspose;
194   
195 
196 
197   public final double[] work3D;
198 
199   
200 
201 
202 
203 
204 
205 
206 
207 
208   public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
209     this(nX, nY, nZ, parallelTeam, DataLayout3D.INTERLEAVED);
210   }
211 
212   
213 
214 
215 
216 
217 
218 
219 
220 
221 
222   public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, DataLayout3D dataLayout) {
223     this(nX, nY, nZ, parallelTeam, null, dataLayout);
224   }
225 
226   
227 
228 
229 
230 
231 
232 
233 
234 
235 
236   public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
237     this(nX, nY, nZ, parallelTeam, integerSchedule, DataLayout3D.INTERLEAVED);
238   }
239 
240   
241 
242 
243 
244 
245 
246 
247 
248 
249 
250 
251   public Complex3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam,
252                            @Nullable IntegerSchedule integerSchedule, DataLayout3D dataLayout) {
253     this.nX = nX;
254     this.nY = nY;
255     this.nZ = nZ;
256     this.parallelTeam = parallelTeam;
257     recip = new double[nX * nY * nZ];
258 
259     DataLayout1D dataLayout1D;
260     DataLayout2D dataLayout2D;
261     switch (dataLayout) {
262       default:
263       case INTERLEAVED:
264         
265         im = 1;
266         ii = 2;
267         nextX = 2;
268         nextY = 2 * nX;
269         nextZ = 2 * nX * nY;
270         
271         dataLayout1D = DataLayout1D.INTERLEAVED;
272         dataLayout2D = DataLayout2D.INTERLEAVED;
273         
274         internalImZ = 1;
275         trNextY = 2;
276         trNextZ = 2 * nY;
277         trNextX = 2 * nY * nZ;
278         break;
279       case BLOCKED_X:
280         
281         im = nX;
282         ii = 1;
283         nextX = 1;
284         nextY = 2 * nX;
285         nextZ = 2 * nX * nY;
286         
287         dataLayout1D = DataLayout1D.BLOCKED;
288         dataLayout2D = DataLayout2D.BLOCKED_X;
289         
290         internalImZ = nZ * nY;
291         trNextY = 1;
292         trNextZ = nY;
293         trNextX = 2 * nY * nZ;
294         break;
295       case BLOCKED_XY:
296         
297         im = nX * nY;
298         ii = 1;
299         nextX = 1;
300         nextY = nX;
301         nextZ = 2 * nY * nX;
302         
303         dataLayout1D = DataLayout1D.BLOCKED;
304         dataLayout2D = DataLayout2D.BLOCKED_XY;
305         
306         internalImZ = nY * nZ;
307         trNextY = 1;
308         trNextZ = nY;
309         trNextX = 2 * nY * nZ;
310         break;
311       case BLOCKED_XYZ:
312         
313         im = nX * nY * nZ;
314         ii = 1;
315         nextX = 1;
316         nextY = nX;
317         nextZ = nY * nX;
318         
319         dataLayout1D = DataLayout1D.BLOCKED;
320         dataLayout2D = DataLayout2D.BLOCKED_XY;
321         
322         internalImZ = nY * nZ;
323         trNextY = 1;
324         trNextZ = nY;
325         trNextX = 2 * nY * nZ;
326         break;
327     }
328 
329     nXm1 = this.nX - 1;
330     nYm1 = this.nY - 1;
331     nZm1 = this.nZ - 1;
332     threadCount = parallelTeam.getThreadCount();
333     schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
334 
335     
336     useSIMD = false;
337     String simd = System.getProperty("fft.useSIMD", Boolean.toString(useSIMD));
338     try {
339       useSIMD = Boolean.parseBoolean(simd);
340     } catch (Exception e) {
341       useSIMD = false;
342     }
343 
344     
345     packFFTs = false;
346     String pack = System.getProperty("fft.packFFTs", Boolean.toString(packFFTs));
347     try {
348       packFFTs = Boolean.parseBoolean(pack);
349     } catch (Exception e) {
350       packFFTs = false;
351     }
352 
353     String localTranspose = System.getProperty("fft.localZTranspose", "true");
354     boolean local;
355     try {
356       local = Boolean.parseBoolean(localTranspose);
357     } catch (Exception e) {
358       local = true;
359     }
360     localZTranspose = local;
361 
362     fftXY = new Complex2D[threadCount];
363     for (int i = 0; i < threadCount; i++) {
364       fftXY[i] = new Complex2D(nX, nY, dataLayout2D, im);
365       fftXY[i].setPackFFTs(packFFTs);
366       fftXY[i].setUseSIMD(useSIMD);
367     }
368 
369     fftZ = new Complex[threadCount];
370     for (int i = 0; i < threadCount; i++) {
371       fftZ[i] = new Complex(nZ, dataLayout1D, internalImZ, nY);
372       fftZ[i].setUseSIMD(useSIMD);
373     }
374 
375     if (localZTranspose) {
376       work3D = null;
377     } else {
378       work3D = new double[2 * nX * nY * nZ];
379     }
380 
381     fftRegion = new FFTRegion();
382     ifftRegion = new IFFTRegion();
383     convRegion = new ConvolutionRegion();
384   }
385 
386   public String toString() {
387     return "Complex3DParallel {" +
388         "nX=" + nX +
389         ", nY=" + nY +
390         ", nZ=" + nZ +
391         ", im=" + im +
392         ", ii=" + ii +
393         ", nextX=" + nextX +
394         ", nextY=" + nextY +
395         ", nextZ=" + nextZ +
396         ", threadCount=" + threadCount +
397         ", parallelTeam=" + parallelTeam +
398         ", internalImZ=" + internalImZ +
399         ", schedule=" + schedule +
400         ", nXm1=" + nXm1 +
401         ", nYm1=" + nYm1 +
402         ", nZm1=" + nZm1 +
403         ", fftRegion=" + fftRegion +
404         ", ifftRegion=" + ifftRegion +
405         ", convRegion=" + convRegion +
406         ", useSIMD=" + useSIMD +
407         ", packFFTs=" + packFFTs +
408         '}';
409   }
410 
411   
412 
413 
414 
415 
416   public void setUseSIMD(boolean useSIMD) {
417     this.useSIMD = useSIMD;
418     for (int i = 0; i < threadCount; i++) {
419       fftXY[i].setUseSIMD(useSIMD);
420       fftZ[i].setUseSIMD(useSIMD);
421     }
422   }
423 
424   
425 
426 
427 
428 
429   public void setPackFFTs(boolean packFFTs) {
430     this.packFFTs = packFFTs;
431     for (int i = 0; i < threadCount; i++) {
432       fftXY[i].setPackFFTs(packFFTs);
433     }
434   }
435 
436   
437 
438 
439 
440 
441 
442 
443   public void convolution(final double[] input) {
444     this.input = input;
445     try {
446       parallelTeam.execute(convRegion);
447     } catch (Exception e) {
448       String message = "Fatal exception evaluating a convolution.\n";
449       logger.log(Level.SEVERE, message, e);
450     }
451   }
452 
453   
454 
455 
456 
457 
458 
459   public void fft(final double[] input) {
460     this.input = input;
461     try {
462       parallelTeam.execute(fftRegion);
463     } catch (Exception e) {
464       String message = " Fatal exception evaluating the FFT.\n";
465       logger.log(Level.SEVERE, message, e);
466     }
467   }
468 
469   
470 
471 
472 
473 
474   public long[] getTiming() {
475     return convRegion.getTiming();
476   }
477 
478   
479 
480 
481 
482 
483 
484   public void ifft(final double[] input) {
485     this.input = input;
486     try {
487       parallelTeam.execute(ifftRegion);
488     } catch (Exception e) {
489       String message = "Fatal exception evaluating the inverse FFT.\n";
490       logger.log(Level.SEVERE, message, e);
491       System.exit(-1);
492     }
493   }
494 
495   
496 
497 
498   public void initTiming() {
499     convRegion.initTiming();
500   }
501 
502   
503 
504 
505 
506 
507   public String timingString() {
508     return convRegion.timingString();
509   }
510 
511   
512 
513 
514 
515 
516   public void setRecip(double[] recip) {
517     
518     
519     
520     
521     
522     
523     int recipNextY = nX;
524     int recipNextZ = nY * nX;
525     int index = 0;
526     for (int x = 0; x < nX; x++) {
527       int dx = x;
528       for (int z = 0; z < nZ; z++) {
529         int dz = dx + z * recipNextZ;
530         for (int y = 0; y < nY; y++) {
531           int conv = y * recipNextY + dz;
532           this.recip[index] = recip[conv];
533           index++;
534         }
535       }
536     }
537   }
538 
539   
540 
541 
542 
543 
544 
545 
546 
547 
548 
549 
550   private class FFTRegion extends ParallelRegion {
551 
552     private final FFTXYLoop[] fftXYLoop;
553     private final FFTZLoop[] fftZLoop;
554     private final TransposeLoop[] transposeLoop;
555     private final UnTransposeLoop[] unTransposeLoop;
556 
557     private FFTRegion() {
558       fftXYLoop = new FFTXYLoop[threadCount];
559       fftZLoop = new FFTZLoop[threadCount];
560       for (int i = 0; i < threadCount; i++) {
561         fftXYLoop[i] = new FFTXYLoop();
562         fftZLoop[i] = new FFTZLoop();
563       }
564       if (!localZTranspose) {
565         transposeLoop = new TransposeLoop[threadCount];
566         unTransposeLoop = new UnTransposeLoop[threadCount];
567         for (int i = 0; i < threadCount; i++) {
568           transposeLoop[i] = new TransposeLoop();
569           unTransposeLoop[i] = new UnTransposeLoop();
570         }
571       } else {
572         transposeLoop = null;
573         unTransposeLoop = null;
574       }
575     }
576 
577     @Override
578     public void run() {
579       int threadIndex = getThreadIndex();
580       try {
581         if (localZTranspose) {
582           execute(0, nZm1, fftXYLoop[threadIndex]);
583           execute(0, nXm1, fftZLoop[threadIndex]);
584         } else {
585           execute(0, nZm1, fftXYLoop[threadIndex]);
586           execute(0, nXm1, transposeLoop[threadIndex]);
587           execute(0, nXm1, fftZLoop[threadIndex]);
588           execute(0, nZm1, unTransposeLoop[threadIndex]);
589         }
590       } catch (Exception e) {
591         logger.severe(e.toString());
592       }
593     }
594   }
595 
596   
597 
598 
599 
600 
601 
602 
603 
604 
605 
606 
607   private class IFFTRegion extends ParallelRegion {
608 
609     private final IFFTXYLoop[] ifftXYLoop;
610     private final IFFTZLoop[] ifftZLoop;
611     private final TransposeLoop[] transposeLoop;
612     private final UnTransposeLoop[] unTransposeLoop;
613 
614     private IFFTRegion() {
615       ifftXYLoop = new IFFTXYLoop[threadCount];
616       ifftZLoop = new IFFTZLoop[threadCount];
617       for (int i = 0; i < threadCount; i++) {
618         ifftXYLoop[i] = new IFFTXYLoop();
619         ifftZLoop[i] = new IFFTZLoop();
620       }
621       if (!localZTranspose) {
622         transposeLoop = new TransposeLoop[threadCount];
623         unTransposeLoop = new UnTransposeLoop[threadCount];
624         for (int i = 0; i < threadCount; i++) {
625           transposeLoop[i] = new TransposeLoop();
626           unTransposeLoop[i] = new UnTransposeLoop();
627         }
628       } else {
629         transposeLoop = null;
630         unTransposeLoop = null;
631       }
632     }
633 
634     @Override
635     public void run() {
636       int threadIndex = getThreadIndex();
637       try {
638         if (localZTranspose) {
639           execute(0, nXm1, ifftZLoop[threadIndex]);
640           execute(0, nZm1, ifftXYLoop[threadIndex]);
641         } else {
642           execute(0, nXm1, transposeLoop[threadIndex]);
643           execute(0, nXm1, ifftZLoop[threadIndex]);
644           execute(0, nZm1, unTransposeLoop[threadIndex]);
645           execute(0, nZm1, ifftXYLoop[threadIndex]);
646         }
647       } catch (Exception e) {
648         logger.severe(e.toString());
649       }
650     }
651   }
652 
653   
654 
655 
656 
657 
658 
659 
660 
661 
662 
663 
664 
665   private class ConvolutionRegion extends ParallelRegion {
666 
667     private final FFTXYLoop[] fftXYLoop;
668     private final TransposeLoop[] transposeLoop;
669     private final FFTZIZLoop[] fftZIZLoop;
670     private final UnTransposeLoop[] unTransposeLoop;
671     private final IFFTXYLoop[] ifftXYLoop;
672     private final long[] convTime;
673 
674     private ConvolutionRegion() {
675       fftXYLoop = new FFTXYLoop[threadCount];
676       fftZIZLoop = new FFTZIZLoop[threadCount];
677       ifftXYLoop = new IFFTXYLoop[threadCount];
678       convTime = new long[threadCount];
679       for (int i = 0; i < threadCount; i++) {
680         fftXYLoop[i] = new FFTXYLoop();
681         fftZIZLoop[i] = new FFTZIZLoop();
682         ifftXYLoop[i] = new IFFTXYLoop();
683       }
684       if (!localZTranspose) {
685         transposeLoop = new TransposeLoop[threadCount];
686         unTransposeLoop = new UnTransposeLoop[threadCount];
687         for (int i = 0; i < threadCount; i++) {
688           transposeLoop[i] = new TransposeLoop();
689           unTransposeLoop[i] = new UnTransposeLoop();
690         }
691       } else {
692         transposeLoop = null;
693         unTransposeLoop = null;
694       }
695     }
696 
697     public void initTiming() {
698       for (int i = 0; i < threadCount; i++) {
699         fftXYLoop[i].time = 0;
700         fftZIZLoop[i].time = 0;
701         ifftXYLoop[i].time = 0;
702       }
703       if (!localZTranspose) {
704         for (int i = 0; i < threadCount; i++) {
705           transposeLoop[i].time = 0;
706           unTransposeLoop[i].time = 0;
707         }
708       }
709     }
710 
711     public long[] getTiming() {
712       if (localZTranspose) {
713         for (int i = 0; i < threadCount; i++) {
714           convTime[i] = convRegion.fftXYLoop[i].time
715               + convRegion.fftZIZLoop[i].time
716               + convRegion.ifftXYLoop[i].time;
717         }
718       } else {
719         for (int i = 0; i < threadCount; i++) {
720           convTime[i] = convRegion.fftXYLoop[i].time
721               + convRegion.transposeLoop[i].time
722               + convRegion.fftZIZLoop[i].time
723               + convRegion.unTransposeLoop[i].time
724               + convRegion.ifftXYLoop[i].time;
725         }
726       }
727       return convTime;
728     }
729 
730     public String timingString() {
731       StringBuilder sb = new StringBuilder();
732       if (localZTranspose) {
733         double xysum = 0.0;
734         double zizsum = 0.0;
735         double ixysum = 0.0;
736         for (int i = 0; i < threadCount; i++) {
737           double fftxy = fftXYLoop[i].getTime() * 1e-9;
738           double ziz = fftZIZLoop[i].getTime() * 1e-9;
739           double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
740           String s = format("  Thread %3d: FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
741               i, fftxy, ziz, ifftxy);
742           sb.append(s);
743           xysum += fftxy;
744           zizsum += ziz;
745           ixysum += ifftxy;
746         }
747         String s = format("  Sum       : FFTXY=%8.6f, FFTZIZ=%8.6f, IFFTXY=%8.6f\n",
748             xysum, zizsum, ixysum);
749         sb.append(s);
750       } else {
751         double xysum = 0.0;
752         double transsum = 0.0;
753         double zizsum = 0.0;
754         double untranssum = 0.0;
755         double ixysum = 0.0;
756         for (int i = 0; i < threadCount; i++) {
757           double fftxy = fftXYLoop[i].getTime() * 1e-9;
758           double trans = transposeLoop[i].getTime() * 1e-9;
759           double ziz = fftZIZLoop[i].getTime() * 1e-9;
760           double untrans = unTransposeLoop[i].getTime() * 1e-9;
761           double ifftxy = ifftXYLoop[i].getTime() * 1e-9;
762           String s = format("  Thread %3d: FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
763               i, fftxy, trans, ziz, untrans, ifftxy);
764           sb.append(s);
765           xysum += fftxy;
766           transsum += trans;
767           zizsum += ziz;
768           untranssum += untrans;
769           ixysum += ifftxy;
770         }
771         String s = format("  Sum       : FFTXY=%8.6f, Trans=%8.6f, FFTZIZ=%8.6f, UnTrans=%8.6f, IFFTXY=%8.6f\n",
772             xysum, transsum, zizsum, untranssum, ixysum);
773         sb.append(s);
774       }
775 
776       return sb.toString();
777     }
778 
779     @Override
780     public void run() {
781       int threadIndex = getThreadIndex();
782       try {
783         if (localZTranspose) {
784           execute(0, nZm1, fftXYLoop[threadIndex]);
785           execute(0, nXm1, fftZIZLoop[threadIndex]);
786           execute(0, nZm1, ifftXYLoop[threadIndex]);
787         } else {
788           execute(0, nZm1, fftXYLoop[threadIndex]);
789           execute(0, nXm1, transposeLoop[threadIndex]);
790           execute(0, nXm1, fftZIZLoop[threadIndex]);
791           execute(0, nZm1, unTransposeLoop[threadIndex]);
792           execute(0, nZm1, ifftXYLoop[threadIndex]);
793         }
794       } catch (Exception e) {
795         logger.severe(e.toString());
796       }
797     }
798   }
799 
800   private class FFTXYLoop extends IntegerForLoop {
801 
802     private Complex2D localFFTXY;
803     private long time;
804 
805     @Override
806     public void run(final int lb, final int ub) {
807       for (int z = lb; z <= ub; z++) {
808         int offset = z * nextZ;
809         localFFTXY.fft(input, offset);
810       }
811     }
812 
813     @Override
814     public IntegerSchedule schedule() {
815       return schedule;
816     }
817 
818     public long getTime() {
819       return time;
820     }
821 
822     public void finish() {
823       time += System.nanoTime();
824     }
825 
826     @Override
827     public void start() {
828       time -= System.nanoTime();
829       localFFTXY = fftXY[getThreadIndex()];
830     }
831   }
832 
833   private class IFFTXYLoop extends IntegerForLoop {
834 
835     private Complex2D localFFTXY;
836     private long time;
837 
838     @Override
839     public void run(final int lb, final int ub) {
840       for (int z = lb; z <= ub; z++) {
841         int offset = z * nextZ;
842         localFFTXY.ifft(input, offset);
843       }
844     }
845 
846     @Override
847     public IntegerSchedule schedule() {
848       return schedule;
849     }
850 
851     public long getTime() {
852       return time;
853     }
854 
855     public void finish() {
856       time += System.nanoTime();
857     }
858 
859     @Override
860     public void start() {
861       time -= System.nanoTime();
862       localFFTXY = fftXY[getThreadIndex()];
863     }
864   }
865 
866   private class FFTZLoop extends IntegerForLoop {
867 
868     private Complex localFFTZ;
869     private long time;
870     private final double[] work;
871 
872     private FFTZLoop() {
873       if (localZTranspose) {
874         work = new double[2 * nY * nZ];
875       } else {
876         work = null;
877       }
878     }
879 
880     @Override
881     public void run(final int lb, final int ub) {
882       if (localZTranspose) {
883         for (int x = lb; x <= ub; x++) {
884           int inputOffset = x * nextX;
885           transpose(input, inputOffset, work, 0);
886           localFFTZ.fft(work, 0, ii);
887           unTranspose(input, inputOffset, work, 0);
888         }
889       } else {
890         for (int x = lb; x <= ub; x++) {
891           int offset = x * nY * nZ * ii;
892           localFFTZ.fft(work3D, offset, ii);
893         }
894       }
895     }
896 
897     @Override
898     public IntegerSchedule schedule() {
899       return schedule;
900     }
901 
902     public long getTime() {
903       return time;
904     }
905 
906     @Override
907     public void finish() {
908       time += System.nanoTime();
909     }
910 
911     @Override
912     public void start() {
913       time -= System.nanoTime();
914       int threadID = getThreadIndex();
915       localFFTZ = fftZ[threadID];
916     }
917   }
918 
919   private class IFFTZLoop extends IntegerForLoop {
920 
921     private Complex localFFTZ;
922     private long time;
923     private final double[] work;
924 
925     private IFFTZLoop() {
926       if (localZTranspose) {
927         work = new double[2 * nY * nZ];
928       } else {
929         work = null;
930       }
931     }
932 
933     @Override
934     public void run(final int lb, final int ub) {
935       if (localZTranspose) {
936         for (int x = lb; x <= ub; x++) {
937           int inputOffset = x * nextX;
938           transpose(input, inputOffset, work, 0);
939           localFFTZ.ifft(work, 0, ii);
940           unTranspose(input, inputOffset, work, 0);
941         }
942       } else {
943         for (int x = lb; x <= ub; x++) {
944           int offset = x * nY * nZ * ii;
945           localFFTZ.ifft(work3D, offset, ii);
946         }
947       }
948     }
949 
950     @Override
951     public IntegerSchedule schedule() {
952       return schedule;
953     }
954 
955     public long getTime() {
956       return time;
957     }
958 
959     @Override
960     public void finish() {
961       time += System.nanoTime();
962     }
963 
964     @Override
965     public void start() {
966       time -= System.nanoTime();
967       int threadID = getThreadIndex();
968       localFFTZ = fftZ[threadID];
969     }
970   }
971 
972   private class FFTZIZLoop extends IntegerForLoop {
973 
974     private Complex localFFTZ;
975     private long time;
976     private final double[] work;
977 
978     private FFTZIZLoop() {
979       if (localZTranspose) {
980         work = new double[2 * nY * nZ];
981       } else {
982         work = null;
983       }
984     }
985 
986     @Override
987     public void run(final int lb, final int ub) {
988       if (localZTranspose) {
989         for (int x = lb; x <= ub; x++) {
990           int inputOffset = x * nextX;
991           transpose(input, inputOffset, work, 0);
992           localFFTZ.fft(work, 0, ii);
993           int recipOffset = x * nY * nZ;
994           recipConv(recipOffset, work, 0);
995           localFFTZ.ifft(work, 0, ii);
996           unTranspose(input, inputOffset, work, 0);
997         }
998       } else {
999         for (int x = lb; x <= ub; x++) {
1000           
1001           int offset = x * nY * nZ * ii;
1002           localFFTZ.fft(work3D, offset, ii);
1003           int recipOffset = x * nY * nZ;
1004           recipConv(recipOffset, work3D, offset);
1005           localFFTZ.ifft(work3D, offset, ii);
1006         }
1007       }
1008     }
1009 
1010     @Override
1011     public IntegerSchedule schedule() {
1012       return schedule;
1013     }
1014 
1015     public long getTime() {
1016       return time;
1017     }
1018 
1019     public void finish() {
1020       time += System.nanoTime();
1021     }
1022 
1023     @Override
1024     public void start() {
1025       time -= System.nanoTime();
1026       int threadID = getThreadIndex();
1027       localFFTZ = fftZ[threadID];
1028     }
1029   }
1030 
1031   
1032 
1033 
1034 
1035 
1036 
1037 
1038 
1039   private class TransposeLoop extends IntegerForLoop {
1040 
1041     private long time;
1042 
1043     @Override
1044     public void run(final int lb, final int ub) {
1045       for (int x = lb; x <= ub; x++) {
1046         for (int z = 0; z < nZ; z++) {
1047           int inputOffset = x * nextX + z * nextZ;
1048           int workOffset = x * trNextX + z * trNextZ;
1049           for (int y = 0; y < nY; y++) {
1050             int inputIndex = inputOffset + y * nextY;
1051             double real = input[inputIndex];
1052             double imag = input[inputIndex + im];
1053             int workIndex = workOffset + y * trNextY;
1054             work3D[workIndex] = real;
1055             work3D[workIndex + internalImZ] = imag;
1056           }
1057         }
1058       }
1059     }
1060 
1061     
1062 
1063 
1064 
1065 
1066     @Override
1067     public IntegerSchedule schedule() {
1068       return IntegerSchedule.fixed();
1069     }
1070 
1071     public long getTime() {
1072       return time;
1073     }
1074 
1075     @Override
1076     public void finish() {
1077       time += System.nanoTime();
1078     }
1079 
1080     @Override
1081     public void start() {
1082       time -= System.nanoTime();
1083     }
1084   }
1085 
1086   
1087 
1088 
1089 
1090 
1091 
1092 
1093 
1094 
1095   private class UnTransposeLoop extends IntegerForLoop {
1096 
1097     private long time;
1098 
1099     @Override
1100     public void run(final int lb, final int ub) {
1101       for (int x = 0; x < nX; x++) {
1102         for (int y = 0; y < nY; y++) {
1103           int workOffset = x * trNextX + y * trNextY;
1104           int inputOffset = x * nextX + y * nextY;
1105           for (int z = lb; z <= ub; z++) {
1106             int workIndex = workOffset + z * trNextZ;
1107             double real = work3D[workIndex];
1108             double imag = work3D[workIndex + internalImZ];
1109             int inputIndex = inputOffset + z * nextZ;
1110             input[inputIndex] = real;
1111             input[inputIndex + im] = imag;
1112           }
1113         }
1114       }
1115     }
1116 
1117     
1118 
1119 
1120 
1121 
1122     @Override
1123     public IntegerSchedule schedule() {
1124       return IntegerSchedule.fixed();
1125     }
1126 
1127     public long getTime() {
1128       return time;
1129     }
1130 
1131     @Override
1132     public void finish() {
1133       time += System.nanoTime();
1134     }
1135 
1136     @Override
1137     public void start() {
1138       time -= System.nanoTime();
1139     }
1140   }
1141 
1142   
1143 
1144 
1145 
1146 
1147 
1148 
1149   private void recipConv(int recipOffset, double[] work, int workOffset) {
1150     if (useSIMD && internalImZ == 1) {
1151       
1152       recipConvSIMD(recipOffset, work, workOffset);
1153       
1154     } else {
1155       recipConvScalar(recipOffset, work, workOffset);
1156     }
1157   }
1158 
1159   
1160 
1161 
1162 
1163 
1164 
1165 
1166   private void recipConvScalar(int recipOffset, double[] work, int workOffset) {
1167     int index = workOffset;
1168     int rindex = recipOffset;
1169     for (int i = 0; i < nY * nZ; i++) {
1170       double r = recip[rindex++];
1171       work[index] *= r;
1172       work[index + internalImZ] *= r;
1173       index += ii;
1174     }
1175   }
1176 
1177   
1178 
1179 
1180 
1181 
1182 
1183 
1184   private void recipConvSIMD(int recipOffset, double[] work, int workOffset) {
1185 
1186     
1187     if (internalImZ != 1) {
1188       logger.severe(" Real and imaginary parts must be interleaved.");
1189     }
1190 
1191     
1192     int length = nY * nZ * 2;
1193     
1194     int vectorSize2 = vectorSize * 2;
1195     int vectorizedLength = (length / vectorSize2) * vectorSize2;
1196 
1197     
1198     int i = 0;
1199     for (; i < vectorizedLength; i += vectorSize2) {
1200       
1201       DoubleVector recipVector = DoubleVector.fromArray(species, recip, recipOffset + i / 2);
1202 
1203       
1204       DoubleVector complexVector = DoubleVector.fromArray(species, work, workOffset + i);
1205       DoubleVector firstHalf = recipVector.rearrange(expandFirstHalf);
1206       complexVector = complexVector.mul(firstHalf);
1207       complexVector.intoArray(work, workOffset + i);
1208 
1209       
1210       complexVector = DoubleVector.fromArray(species, work, workOffset + vectorSize + i);
1211       DoubleVector secondHalf = recipVector.rearrange(expandSecondHalf);
1212       complexVector = complexVector.mul(secondHalf);
1213       complexVector.intoArray(work, workOffset + vectorSize + i);
1214     }
1215 
1216     
1217     for (; i < length; i+=2) {
1218       double r = recip[recipOffset + i / 2];
1219       work[workOffset + i] *= r;
1220       work[workOffset + i + internalImZ] *= r;
1221     }
1222   }
1223 
1224   
1225 
1226 
1227 
1228 
1229 
1230 
1231 
1232   private void transpose(double[] input, int inputOffset, double[] output, int outputOffset) {
1233     
1234     
1235     for (int z = 0; z < nZ; z++) {
1236       for (int y = 0; y < nY; y++) {
1237         double real = input[inputOffset + y * nextY + z * nextZ];
1238         double imag = input[inputOffset + y * nextY + z * nextZ + im];
1239         output[outputOffset + y * trNextY + z * trNextZ] = real;
1240         output[outputOffset + y * trNextY + z * trNextZ + internalImZ] = imag;
1241       }
1242     }
1243   }
1244 
1245   
1246 
1247 
1248 
1249 
1250 
1251 
1252 
1253   private void unTranspose(double[] input, int inputOffset, double[] output, int outputOffset) {
1254     
1255     
1256     for (int z = 0; z < nZ; z++) {
1257       for (int y = 0; y < nY; y++) {
1258         double real = output[outputOffset + y * trNextY + z * trNextZ];
1259         double imag = output[outputOffset + y * trNextY + z * trNextZ + internalImZ];
1260         input[inputOffset + y * nextY + z * nextZ] = real;
1261         input[inputOffset + y * nextY + z * nextZ + im] = imag;
1262       }
1263     }
1264   }
1265 
1266   
1267 
1268 
1269 
1270 
1271 
1272 
1273 
1274   public static double[] initRandomData(int dim, ParallelTeam parallelTeam) {
1275     int n = dim * dim * dim;
1276     double[] data = new double[2 * n];
1277     try {
1278       parallelTeam.execute(
1279           new ParallelRegion() {
1280             @Override
1281             public void run() {
1282               try {
1283                 execute(
1284                     0,
1285                     dim - 1,
1286                     new IntegerForLoop() {
1287                       @Override
1288                       public void run(final int lb, final int ub) {
1289                         Random randomNumberGenerator = new Random(1);
1290                         int index = dim * dim * lb * 2;
1291                         for (int i = lb; i <= ub; i++) {
1292                           for (int j = 0; j < dim; j++) {
1293                             for (int k = 0; k < dim; k++) {
1294                               double randomNumber = randomNumberGenerator.nextDouble();
1295                               data[index] = randomNumber;
1296                               index += 2;
1297                             }
1298                           }
1299                         }
1300                       }
1301                     });
1302               } catch (Exception e) {
1303                 System.out.println(e.getMessage());
1304                 System.exit(-1);
1305               }
1306             }
1307           });
1308     } catch (Exception e) {
1309       System.out.println(e.getMessage());
1310       System.exit(-1);
1311     }
1312     return data;
1313   }
1314 
1315   
1316 
1317 
1318 
1319 
1320 
1321 
1322   public static void main(String[] args) throws Exception {
1323     int dimNotFinal = 128;
1324     int nCPU = ParallelTeam.getDefaultThreadCount();
1325     int reps = 5;
1326     boolean blocked = false;
1327     try {
1328       dimNotFinal = Integer.parseInt(args[0]);
1329       if (dimNotFinal < 1) {
1330         dimNotFinal = 100;
1331       }
1332       nCPU = Integer.parseInt(args[1]);
1333       if (nCPU < 1) {
1334         nCPU = ParallelTeam.getDefaultThreadCount();
1335       }
1336       reps = Integer.parseInt(args[2]);
1337       if (reps < 1) {
1338         reps = 5;
1339       }
1340       blocked = Boolean.parseBoolean(args[3]);
1341     } catch (Exception e) {
1342       
1343     }
1344     final int dim = dimNotFinal;
1345     System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
1346             + "The best timing out of %d repetitions will be used.%n",
1347         dim, nCPU, reps);
1348     
1349     Complex3D complex3D;
1350     Complex3DParallel complex3DParallel;
1351     ParallelTeam parallelTeam = new ParallelTeam(nCPU);
1352     if (blocked) {
1353       complex3D = new Complex3D(dim, dim, dim, DataLayout3D.BLOCKED_X);
1354       complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.BLOCKED_X);
1355     } else {
1356       complex3D = new Complex3D(dim, dim, dim, DataLayout3D.INTERLEAVED);
1357       complex3DParallel = new Complex3DParallel(dim, dim, dim, parallelTeam, DataLayout3D.INTERLEAVED);
1358     }
1359     final int dimCubed = dim * dim * dim;
1360     final double[] data = initRandomData(dim, parallelTeam);
1361     final double[] work = new double[dimCubed];
1362     Arrays.fill(work, 1.0);
1363 
1364     double toSeconds = 0.000000001;
1365     long seqTime = Long.MAX_VALUE;
1366     long parTime = Long.MAX_VALUE;
1367     long seqTimeConv = Long.MAX_VALUE;
1368     long parTimeConv = Long.MAX_VALUE;
1369 
1370     complex3D.setRecip(work);
1371     complex3DParallel.setRecip(work);
1372 
1373     
1374     System.out.println("Warm Up Sequential FFT");
1375     complex3D.fft(data);
1376     System.out.println("Warm Up Sequential IFFT");
1377     complex3D.ifft(data);
1378     System.out.println("Warm Up Sequential Convolution");
1379     complex3D.convolution(data);
1380 
1381 
1382 
1383     for (int i = 0; i < reps; i++) {
1384       System.out.printf(" Iteration %d%n", i + 1);
1385       long time = System.nanoTime();
1386       complex3D.fft(data);
1387       complex3D.ifft(data);
1388       time = (System.nanoTime() - time);
1389       System.out.printf("  Sequential FFT:  %9.6f (sec)%n", toSeconds * time);
1390       if (time < seqTime) {
1391         seqTime = time;
1392       }
1393       time = System.nanoTime();
1394       complex3D.convolution(data);
1395       time = (System.nanoTime() - time);
1396       System.out.printf("  Sequential Conv: %9.6f (sec)%n", toSeconds * time);
1397       if (time < seqTimeConv) {
1398         seqTimeConv = time;
1399       }
1400     }
1401 
1402     
1403     System.out.println("Warm up Parallel FFT");
1404     complex3DParallel.fft(data);
1405     System.out.println("Warm up Parallel IFFT");
1406     complex3DParallel.ifft(data);
1407     System.out.println("Warm up Parallel Convolution");
1408     complex3DParallel.convolution(data);
1409     complex3DParallel.initTiming();
1410 
1411     for (int i = 0; i < reps; i++) {
1412       
1413       if (i == reps / 2) {
1414         complex3DParallel.initTiming();
1415       }
1416 
1417       System.out.printf(" Iteration %d%n", i + 1);
1418       long time = System.nanoTime();
1419       complex3DParallel.fft(data);
1420       complex3DParallel.ifft(data);
1421       time = (System.nanoTime() - time);
1422       System.out.printf("  Parallel FFT:  %9.6f (sec)%n", toSeconds * time);
1423       if (time < parTime) {
1424         parTime = time;
1425       }
1426 
1427       time = System.nanoTime();
1428       complex3DParallel.convolution(data);
1429       time = (System.nanoTime() - time);
1430       System.out.printf("  Parallel Conv: %9.6f (sec)%n", toSeconds * time);
1431       if (time < parTimeConv) {
1432         parTimeConv = time;
1433       }
1434 
1435     }
1436 
1437     System.out.printf(" Best Sequential FFT Time:   %9.6f (sec)%n", toSeconds * seqTime);
1438     System.out.printf(" Best Sequential Conv. Time: %9.6f (sec)%n", toSeconds * seqTimeConv);
1439     System.out.printf(" Best Parallel FFT Time:     %9.6f (sec)%n", toSeconds * parTime);
1440     System.out.printf(" Best Parallel Conv. Time:   %9.6f (sec)%n", toSeconds * parTimeConv);
1441     System.out.printf(" 3D FFT Speedup:             %9.6f X%n", (double) seqTime / parTime);
1442     System.out.printf(" 3D Conv Speedup:            %9.6f X%n", (double) seqTimeConv / parTimeConv);
1443 
1444     System.out.printf(" Parallel Convolution Timings:\n" + complex3DParallel.timingString());
1445 
1446     parallelTeam.shutdown();
1447   }
1448 }