View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2024.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import edu.rit.pj.IntegerForLoop;
41  import edu.rit.pj.IntegerSchedule;
42  import edu.rit.pj.ParallelRegion;
43  import edu.rit.pj.ParallelTeam;
44  
45  import javax.annotation.Nullable;
46  import java.util.Random;
47  import java.util.logging.Level;
48  import java.util.logging.Logger;
49  
50  import static java.util.Objects.requireNonNullElseGet;
51  
52  /**
53   * Compute the 3D FFT of real, double precision input of arbitrary dimensions in parallel.
54   *
55   * <p>
56   *
57   * @author Michal J. Schnieders
58   * @see Real
59   * @since 1.0
60   */
61  public class Real3DParallel {
62  
63    private static final Logger logger = Logger.getLogger(Real3DParallel.class.getName());
64    private final int nX, nY, nZ, nZ2, nX1;
65    private final int n, nextX, nextY, nextZ;
66    private final ParallelTeam parallelTeam;
67    private final int threadCount;
68    private final ParallelIFFT parallelIFFT;
69    private final ParallelFFT parallelFFT;
70    private final ParallelConvolution parallelConvolution;
71    private final double[] recip;
72    private final IntegerSchedule schedule;
73  
74    /**
75     * Initialize the FFT for real input.
76     *
77     * @param nX           X-dimension.
78     * @param nY           Y-dimension.
79     * @param nZ           Z-dimension.
80     * @param parallelTeam a {@link edu.rit.pj.ParallelTeam} object.
81     * @since 1.0
82     */
83    public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
84      this.nX = nX / 2;
85      this.nY = nY;
86      this.nZ = nZ;
87      this.parallelTeam = parallelTeam;
88      n = nX;
89      nX1 = this.nX + 1;
90      nZ2 = this.nZ * 2;
91      nextX = 2;
92      nextY = n + 2;
93      nextZ = nextY * nY;
94      recip = new double[nX1 * nY * nZ];
95      threadCount = parallelTeam.getThreadCount();
96      parallelFFT = new ParallelFFT();
97      parallelIFFT = new ParallelIFFT();
98      parallelConvolution = new ParallelConvolution();
99      schedule = IntegerSchedule.fixed();
100   }
101 
102   /**
103    * Initialize the FFT for real input.
104    *
105    * @param nX              X-dimension.
106    * @param nY              Y-dimension.
107    * @param nZ              Z-dimension.
108    * @param parallelTeam    The ParallelTeam that will execute the transforms.
109    * @param integerSchedule The IntegerSchedule to use.
110    * @since 1.0
111    */
112   public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
113     this.nX = nX / 2;
114     this.nY = nY;
115     this.nZ = nZ;
116     this.parallelTeam = parallelTeam;
117     n = nX;
118     nX1 = this.nX + 1;
119     nZ2 = this.nZ * 2;
120     nextX = 2;
121     nextY = n + 2;
122     nextZ = nextY * nY;
123     recip = new double[nX1 * nY * nZ];
124     threadCount = parallelTeam.getThreadCount();
125     schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
126     parallelFFT = new ParallelFFT();
127     parallelIFFT = new ParallelIFFT();
128     parallelConvolution = new ParallelConvolution();
129   }
130 
131   /**
132    * Test the real 3D FFT.
133    *
134    * @param args an array of {@link java.lang.String} objects.
135    * @since 1.0
136    */
137   public static void main(String[] args) {
138     int dimNotFinal = 128;
139     int nCPU = ParallelTeam.getDefaultThreadCount();
140     int reps = 5;
141     try {
142       dimNotFinal = Integer.parseInt(args[0]);
143       if (dimNotFinal < 1) {
144         dimNotFinal = 100;
145       }
146       nCPU = Integer.parseInt(args[1]);
147       if (nCPU < 1) {
148         nCPU = ParallelTeam.getDefaultThreadCount();
149       }
150       reps = Integer.parseInt(args[2]);
151       if (reps < 1) {
152         reps = 5;
153       }
154     } catch (Exception e) {
155       //
156     }
157     if (dimNotFinal % 2 != 0) {
158       dimNotFinal++;
159     }
160     final int dim = dimNotFinal;
161     System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
162         + "The best timing out of %d repetitions will be used.%n", dim, nCPU, reps);
163 
164     Real3D real3D = new Real3D(dim, dim, dim);
165     ParallelTeam parallelTeam = new ParallelTeam(nCPU);
166     Real3DParallel real3DParallel = new Real3DParallel(dim, dim, dim, parallelTeam);
167 
168     final int dimCubed = (dim + 2) * dim * dim;
169     final double[] data = new double[dimCubed];
170     final double[] work = new double[dimCubed];
171 
172     // Parallel Array Initialization.
173     try {
174       parallelTeam.execute(
175           new ParallelRegion() {
176             @Override
177             public void run() {
178               try {
179                 execute(
180                     0,
181                     dim - 1,
182                     new IntegerForLoop() {
183                       @Override
184                       public void run(int lb, int ub) {
185                         Random randomNumberGenerator = new Random(1);
186                         int index = dim * dim * lb;
187                         for (int z = lb; z <= ub; z++) {
188                           for (int y = 0; y < dim; y++) {
189                             for (int x = 0; x < dim; x++) {
190                               double randomNumber = randomNumberGenerator.nextDouble();
191                               data[index] = randomNumber;
192                               index++;
193                             }
194                           }
195                         }
196                       }
197                     });
198               } catch (Exception e) {
199                 System.out.println(e.getMessage());
200                 System.exit(-1);
201               }
202             }
203           });
204     } catch (Exception e) {
205       System.out.println(e.getMessage());
206       System.exit(-1);
207     }
208 
209     double toSeconds = 0.000000001;
210     long parTime = Long.MAX_VALUE;
211     long seqTime = Long.MAX_VALUE;
212     real3D.setRecip(work);
213     real3DParallel.setRecip(work);
214     for (int i = 0; i < reps; i++) {
215       System.out.printf("Iteration %d%n", i + 1);
216       long time = System.nanoTime();
217       real3D.fft(data);
218       real3D.ifft(data);
219       time = (System.nanoTime() - time);
220       System.out.printf("Sequential: %8.3f%n", toSeconds * time);
221       if (time < seqTime) {
222         seqTime = time;
223       }
224       time = System.nanoTime();
225       real3D.convolution(data);
226       time = (System.nanoTime() - time);
227       System.out.printf("Sequential: %8.3f (Convolution)%n", toSeconds * time);
228       if (time < seqTime) {
229         seqTime = time;
230       }
231       time = System.nanoTime();
232       real3DParallel.fft(data);
233       real3DParallel.ifft(data);
234       time = (System.nanoTime() - time);
235       System.out.printf("Parallel:   %8.3f%n", toSeconds * time);
236       if (time < parTime) {
237         parTime = time;
238       }
239       time = System.nanoTime();
240       real3DParallel.convolution(data);
241       time = (System.nanoTime() - time);
242       System.out.printf("Parallel:   %8.3f (Convolution)\n%n", toSeconds * time);
243       if (time < parTime) {
244         parTime = time;
245       }
246     }
247     System.out.printf("Best Sequential Time:  %8.3f%n", toSeconds * seqTime);
248     System.out.printf("Best Parallel Time:    %8.3f%n", toSeconds * parTime);
249     System.out.printf("Speedup: %15.5f%n", (double) seqTime / parTime);
250   }
251 
252   /**
253    * Compute a convolution in parallel.
254    *
255    * @param input The input array must be of size (nX + 2) * nY * nZ.
256    * @since 1.0
257    */
258   public void convolution(final double[] input) {
259     parallelConvolution.input = input;
260     try {
261       parallelTeam.execute(parallelConvolution);
262     } catch (Exception e) {
263       String message = "Fatal exception evaluating a 3D convolution in parallel.\n";
264       logger.log(Level.SEVERE, message, e);
265       System.exit(-1);
266     }
267   }
268 
269   /**
270    * Compute the 3D FFT.
271    *
272    * @param input The input array must be of size (nX + 2) * nY * nZ.
273    * @since 1.0
274    */
275   public void fft(final double[] input) {
276     parallelFFT.input = input;
277     try {
278       parallelTeam.execute(parallelFFT);
279     } catch (Exception e) {
280       String message = "Fatal exception evaluating real 3D FFT in parallel.\n";
281       logger.log(Level.SEVERE, message, e);
282       System.exit(-1);
283     }
284   }
285 
286   /**
287    * Compute the inverse 3D FFT.
288    *
289    * @param input The input array must be of size (nX + 2) * nY * nZ.
290    * @since 1.0
291    */
292   public void ifft(final double[] input) {
293     parallelIFFT.input = input;
294     try {
295       parallelTeam.execute(parallelIFFT);
296     } catch (Exception e) {
297       String message = "Fatal exception evaluating real 3D inverse FFT in parallel.\n";
298       logger.log(Level.SEVERE, message, e);
299       System.exit(-1);
300     }
301   }
302 
303   /**
304    * Setter for the field <code>recip</code>.
305    *
306    * @param recip The recip array must be of size [(nX/2 + 1) * nY * nZ].
307    */
308   public void setRecip(double[] recip) {
309     // Reorder the reciprocal space data into the order it is needed by the convolution routine.
310     for (int index = 0, offset = 0, y = 0; y < nY; y++) {
311       for (int x = 0; x < nX1; x++, offset += 1) {
312         for (int i = 0, z = offset; i < nZ; i++, z += nX1 * nY) {
313           this.recip[index++] = recip[z];
314         }
315       }
316     }
317   }
318 
319   /**
320    * Implement the 3D parallel FFT.
321    *
322    * @author Michael J. Schnieders
323    * @since 1.0
324    */
325   private class ParallelFFT extends ParallelRegion {
326 
327     private final int nZm1;
328     private final FFTXYLoop[] fftXYLoop;
329     private final FFTZLoop[] fftZLoop;
330     public double[] input;
331 
332     private ParallelFFT() {
333       nZm1 = nZ - 1;
334       fftXYLoop = new FFTXYLoop[threadCount];
335       fftZLoop = new FFTZLoop[threadCount];
336       for (int i = 0; i < threadCount; i++) {
337         fftXYLoop[i] = new FFTXYLoop();
338         fftZLoop[i] = new FFTZLoop();
339       }
340     }
341 
342     @Override
343     public void run() {
344       int threadIndex = getThreadIndex();
345       fftXYLoop[threadIndex].input = input;
346       fftZLoop[threadIndex].input = input;
347       try {
348         execute(0, nZm1, fftXYLoop[threadIndex]);
349         // There are nX + 1 frequencies.
350         execute(0, nX, fftZLoop[threadIndex]);
351       } catch (Exception e) {
352         logger.severe(e.toString());
353       }
354     }
355 
356     private class FFTXYLoop extends IntegerForLoop {
357 
358       private final Real fftX;
359       private final Complex fftY;
360       public double[] input;
361 
362       private FFTXYLoop() {
363         fftY = new Complex(nY);
364         fftX = new Real(n);
365       }
366 
367       @Override
368       public void run(final int lb, final int ub) {
369         for (int z = lb; z <= ub; z++) {
370           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
371             fftX.fft(input, offset);
372           }
373           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
374             fftY.fft(input, offset, nextY);
375           }
376         }
377       }
378 
379       @Override
380       public IntegerSchedule schedule() {
381         return schedule;
382       }
383     }
384 
385     private class FFTZLoop extends IntegerForLoop {
386 
387       private final double[] work;
388       private final Complex fft;
389       public double[] input;
390 
391       private FFTZLoop() {
392         work = new double[nZ2];
393         fft = new Complex(nZ);
394       }
395 
396       @Override
397       public void run(final int lb, final int ub) {
398         for (int x = lb; x <= ub; x++) {
399           for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
400             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
401               work[i] = input[z];
402               work[i + 1] = input[z + 1];
403             }
404             fft.fft(work, 0, 2);
405             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
406               input[z] = work[i];
407               input[z + 1] = work[i + 1];
408             }
409           }
410         }
411       }
412 
413       @Override
414       public IntegerSchedule schedule() {
415         return schedule;
416       }
417     }
418   }
419 
420   /**
421    * Implement the 3D parallel inverse FFT.
422    *
423    * @author Michael J. Schnieders
424    * @since 1.0
425    */
426   private class ParallelIFFT extends ParallelRegion {
427 
428     private final int nZm1;
429     private final IFFTXYLoop[] ifftXYLoop;
430     private final IFFTZLoop[] ifftZLoop;
431     public double[] input;
432 
433     private ParallelIFFT() {
434       nZm1 = nZ - 1;
435       ifftXYLoop = new IFFTXYLoop[threadCount];
436       ifftZLoop = new IFFTZLoop[threadCount];
437       for (int i = 0; i < threadCount; i++) {
438         ifftXYLoop[i] = new IFFTXYLoop();
439         ifftZLoop[i] = new IFFTZLoop();
440       }
441     }
442 
443     @Override
444     public void run() {
445       int threadIndex = getThreadIndex();
446       ifftXYLoop[threadIndex].input = input;
447       ifftZLoop[threadIndex].input = input;
448       try {
449         // There are xDim + 1 frequencies.
450         execute(0, nX, ifftZLoop[threadIndex]);
451         execute(0, nZm1, ifftXYLoop[threadIndex]);
452       } catch (Exception e) {
453         logger.severe(e.toString());
454       }
455     }
456 
457     private class IFFTZLoop extends IntegerForLoop {
458 
459       private final double[] work;
460       private final Complex fft;
461       public double[] input;
462 
463       private IFFTZLoop() {
464         fft = new Complex(nZ);
465         work = new double[nZ2];
466       }
467 
468       @Override
469       public void run(final int lb, final int ub) {
470         for (int x = lb; x <= ub; x++) {
471           for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
472             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
473               work[i] = input[z];
474               work[i + 1] = input[z + 1];
475             }
476             fft.ifft(work, 0, 2);
477             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
478               input[z] = work[i];
479               input[z + 1] = work[i + 1];
480             }
481           }
482         }
483       }
484 
485       @Override
486       public IntegerSchedule schedule() {
487         return schedule;
488       }
489     }
490 
491     private class IFFTXYLoop extends IntegerForLoop {
492 
493       private final Real fftX;
494       private final Complex fftY;
495       public double[] input;
496 
497       private IFFTXYLoop() {
498         fftX = new Real(n);
499         fftY = new Complex(nY);
500       }
501 
502       @Override
503       public void run(final int lb, final int ub) {
504         for (int z = lb; z <= ub; z++) {
505           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
506             fftY.ifft(input, offset, nextY);
507           }
508           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
509             fftX.ifft(input, offset);
510           }
511         }
512       }
513 
514       @Override
515       public IntegerSchedule schedule() {
516         return schedule;
517       }
518     }
519   }
520 
521   /**
522    * Implement the 3D parallel convolution.
523    *
524    * @author Michael J. Schnieders
525    * @since 1.0
526    */
527   private class ParallelConvolution extends ParallelRegion {
528 
529     private final int nZm1, nYm1, nX1nZ;
530     private final FFTXYLoop[] fftXYLoop;
531     private final FFTZ_Multiply_IFFTZLoop[] fftZ_Multiply_ifftZLoop;
532     private final IFFTXYLoop[] ifftXYLoop;
533     public double[] input;
534 
535     private ParallelConvolution() {
536       nZm1 = nZ - 1;
537       nYm1 = nY - 1;
538       nX1nZ = nX1 * nZ;
539       fftXYLoop = new FFTXYLoop[threadCount];
540       fftZ_Multiply_ifftZLoop = new FFTZ_Multiply_IFFTZLoop[threadCount];
541       ifftXYLoop = new IFFTXYLoop[threadCount];
542       for (int i = 0; i < threadCount; i++) {
543         fftXYLoop[i] = new FFTXYLoop();
544         fftZ_Multiply_ifftZLoop[i] = new FFTZ_Multiply_IFFTZLoop();
545         ifftXYLoop[i] = new IFFTXYLoop();
546       }
547     }
548 
549     @Override
550     public void run() {
551       int threadIndex = getThreadIndex();
552       fftXYLoop[threadIndex].input = input;
553       fftZ_Multiply_ifftZLoop[threadIndex].input = input;
554       ifftXYLoop[threadIndex].input = input;
555       try {
556         execute(0, nZm1, fftXYLoop[threadIndex]);
557         execute(0, nYm1, fftZ_Multiply_ifftZLoop[threadIndex]);
558         execute(0, nZm1, ifftXYLoop[threadIndex]);
559       } catch (Exception e) {
560         logger.severe(e.toString());
561       }
562     }
563 
564     private class FFTXYLoop extends IntegerForLoop {
565 
566       private final Real fftX;
567       private final Complex fftY;
568       public double[] input;
569 
570       private FFTXYLoop() {
571         fftY = new Complex(nY);
572         fftX = new Real(n);
573       }
574 
575       @Override
576       public void run(final int lb, final int ub) {
577         for (int z = lb; z <= ub; z++) {
578           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
579             fftX.fft(input, offset);
580           }
581           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
582             fftY.fft(input, offset, nextY);
583           }
584         }
585       }
586 
587       @Override
588       public IntegerSchedule schedule() {
589         return schedule;
590       }
591     }
592 
593     private class FFTZ_Multiply_IFFTZLoop extends IntegerForLoop {
594 
595       private final double[] work;
596       private final Complex fft;
597       public double[] input;
598 
599       private FFTZ_Multiply_IFFTZLoop() {
600         work = new double[nZ2];
601         fft = new Complex(nZ);
602       }
603 
604       @Override
605       public void run(final int lb, final int ub) {
606         int index = lb * nX1nZ;
607         for (int offset = lb * nextY, y = lb; y <= ub; y++) {
608           for (int x = 0; x < nX1; x++, offset += nextX) {
609             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
610               work[i] = input[z];
611               work[i + 1] = input[z + 1];
612             }
613             fft.fft(work, 0, 2);
614             for (int i = 0; i < nZ2; i += 2) {
615               double r = recip[index++];
616               work[i] *= r;
617               work[i + 1] *= r;
618             }
619             fft.ifft(work, 0, 2);
620             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
621               input[z] = work[i];
622               input[z + 1] = work[i + 1];
623             }
624           }
625         }
626       }
627 
628       @Override
629       public IntegerSchedule schedule() {
630         return schedule;
631       }
632     }
633 
634     private class IFFTXYLoop extends IntegerForLoop {
635 
636       private final Real fftX;
637       private final Complex fftY;
638       public double[] input;
639 
640       private IFFTXYLoop() {
641         fftY = new Complex(nY);
642         fftX = new Real(n);
643       }
644 
645       @Override
646       public void run(final int lb, final int ub) {
647         for (int z = lb; z <= ub; z++) {
648           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
649             fftY.ifft(input, offset, nextY);
650           }
651           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
652             fftX.ifft(input, offset);
653           }
654         }
655       }
656 
657       @Override
658       public IntegerSchedule schedule() {
659         return schedule;
660       }
661     }
662   }
663 }