View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import edu.rit.pj.IntegerForLoop;
41  import edu.rit.pj.IntegerSchedule;
42  import edu.rit.pj.ParallelRegion;
43  import edu.rit.pj.ParallelTeam;
44  
45  import javax.annotation.Nullable;
46  import java.util.Random;
47  import java.util.logging.Level;
48  import java.util.logging.Logger;
49  
50  import static java.util.Objects.requireNonNullElseGet;
51  
/**
 * Compute the 3D FFT of real, double precision input of arbitrary dimensions in parallel.
 *
 * @author Michael J. Schnieders
 * @see Real
 * @since 1.0
 */
59  public class Real3DParallel {
60  
61    private static final Logger logger = Logger.getLogger(Real3DParallel.class.getName());
62    private final int nX, nY, nZ, nZ2, nX1;
63    private final int n, nextX, nextY, nextZ;
64    private final ParallelTeam parallelTeam;
65    private final int threadCount;
66    private final ParallelIFFT parallelIFFT;
67    private final ParallelFFT parallelFFT;
68    private final ParallelConvolution parallelConvolution;
69    private final double[] recip;
70    private final IntegerSchedule schedule;
71  
72    /**
73     * Initialize the FFT for real input.
74     *
75     * @param nX           X-dimension.
76     * @param nY           Y-dimension.
77     * @param nZ           Z-dimension.
78     * @param parallelTeam a {@link edu.rit.pj.ParallelTeam} object.
79     * @since 1.0
80     */
81    public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
82      this.nX = nX / 2;
83      this.nY = nY;
84      this.nZ = nZ;
85      this.parallelTeam = parallelTeam;
86      n = nX;
87      nX1 = this.nX + 1;
88      nZ2 = this.nZ * 2;
89      nextX = 2;
90      nextY = n + 2;
91      nextZ = nextY * nY;
92      recip = new double[nX1 * nY * nZ];
93      threadCount = parallelTeam.getThreadCount();
94      parallelFFT = new ParallelFFT();
95      parallelIFFT = new ParallelIFFT();
96      parallelConvolution = new ParallelConvolution();
97      schedule = IntegerSchedule.fixed();
98    }
99  
100   /**
101    * Initialize the FFT for real input.
102    *
103    * @param nX              X-dimension.
104    * @param nY              Y-dimension.
105    * @param nZ              Z-dimension.
106    * @param parallelTeam    The ParallelTeam that will execute the transforms.
107    * @param integerSchedule The IntegerSchedule to use.
108    * @since 1.0
109    */
110   public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
111     this.nX = nX / 2;
112     this.nY = nY;
113     this.nZ = nZ;
114     this.parallelTeam = parallelTeam;
115     n = nX;
116     nX1 = this.nX + 1;
117     nZ2 = this.nZ * 2;
118     nextX = 2;
119     nextY = n + 2;
120     nextZ = nextY * nY;
121     recip = new double[nX1 * nY * nZ];
122     threadCount = parallelTeam.getThreadCount();
123     schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
124     parallelFFT = new ParallelFFT();
125     parallelIFFT = new ParallelIFFT();
126     parallelConvolution = new ParallelConvolution();
127   }
128 
129   /**
130    * Test the real 3D FFT.
131    *
132    * @param args an array of {@link java.lang.String} objects.
133    * @since 1.0
134    */
135   public static void main(String[] args) {
136     int dimNotFinal = 128;
137     int nCPU = ParallelTeam.getDefaultThreadCount();
138     int reps = 5;
139     try {
140       dimNotFinal = Integer.parseInt(args[0]);
141       if (dimNotFinal < 1) {
142         dimNotFinal = 100;
143       }
144       nCPU = Integer.parseInt(args[1]);
145       if (nCPU < 1) {
146         nCPU = ParallelTeam.getDefaultThreadCount();
147       }
148       reps = Integer.parseInt(args[2]);
149       if (reps < 1) {
150         reps = 5;
151       }
152     } catch (Exception e) {
153       //
154     }
155     if (dimNotFinal % 2 != 0) {
156       dimNotFinal++;
157     }
158     final int dim = dimNotFinal;
159     System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
160         + "The best timing out of %d repetitions will be used.%n", dim, nCPU, reps);
161 
162     Real3D real3D = new Real3D(dim, dim, dim);
163     ParallelTeam parallelTeam = new ParallelTeam(nCPU);
164     Real3DParallel real3DParallel = new Real3DParallel(dim, dim, dim, parallelTeam);
165 
166     final int dimCubed = (dim + 2) * dim * dim;
167     final double[] data = new double[dimCubed];
168     final double[] work = new double[dimCubed];
169 
170     // Parallel Array Initialization.
171     try {
172       parallelTeam.execute(
173           new ParallelRegion() {
174             @Override
175             public void run() {
176               try {
177                 execute(
178                     0,
179                     dim - 1,
180                     new IntegerForLoop() {
181                       @Override
182                       public void run(int lb, int ub) {
183                         Random randomNumberGenerator = new Random(1);
184                         int index = dim * dim * lb;
185                         for (int z = lb; z <= ub; z++) {
186                           for (int y = 0; y < dim; y++) {
187                             for (int x = 0; x < dim; x++) {
188                               double randomNumber = randomNumberGenerator.nextDouble();
189                               data[index] = randomNumber;
190                               index++;
191                             }
192                           }
193                         }
194                       }
195                     });
196               } catch (Exception e) {
197                 System.out.println(e.getMessage());
198                 System.exit(-1);
199               }
200             }
201           });
202     } catch (Exception e) {
203       System.out.println(e.getMessage());
204       System.exit(-1);
205     }
206 
207     double toSeconds = 0.000000001;
208     long parTime = Long.MAX_VALUE;
209     long seqTime = Long.MAX_VALUE;
210     real3D.setRecip(work);
211     real3DParallel.setRecip(work);
212     for (int i = 0; i < reps; i++) {
213       System.out.printf("Iteration %d%n", i + 1);
214       long time = System.nanoTime();
215       real3D.fft(data);
216       real3D.ifft(data);
217       time = (System.nanoTime() - time);
218       System.out.printf("Sequential: %8.3f%n", toSeconds * time);
219       if (time < seqTime) {
220         seqTime = time;
221       }
222       time = System.nanoTime();
223       real3D.convolution(data);
224       time = (System.nanoTime() - time);
225       System.out.printf("Sequential: %8.3f (Convolution)%n", toSeconds * time);
226       if (time < seqTime) {
227         seqTime = time;
228       }
229       time = System.nanoTime();
230       real3DParallel.fft(data);
231       real3DParallel.ifft(data);
232       time = (System.nanoTime() - time);
233       System.out.printf("Parallel:   %8.3f%n", toSeconds * time);
234       if (time < parTime) {
235         parTime = time;
236       }
237       time = System.nanoTime();
238       real3DParallel.convolution(data);
239       time = (System.nanoTime() - time);
240       System.out.printf("Parallel:   %8.3f (Convolution)\n%n", toSeconds * time);
241       if (time < parTime) {
242         parTime = time;
243       }
244     }
245     System.out.printf("Best Sequential Time:  %8.3f%n", toSeconds * seqTime);
246     System.out.printf("Best Parallel Time:    %8.3f%n", toSeconds * parTime);
247     System.out.printf("Speedup: %15.5f%n", (double) seqTime / parTime);
248   }
249 
250   /**
251    * Compute a convolution in parallel.
252    *
253    * @param input The input array must be of size (nX + 2) * nY * nZ.
254    * @since 1.0
255    */
256   public void convolution(final double[] input) {
257     parallelConvolution.input = input;
258     try {
259       parallelTeam.execute(parallelConvolution);
260     } catch (Exception e) {
261       String message = "Fatal exception evaluating a 3D convolution in parallel.\n";
262       logger.log(Level.SEVERE, message, e);
263       System.exit(-1);
264     }
265   }
266 
267   /**
268    * Compute the 3D FFT.
269    *
270    * @param input The input array must be of size (nX + 2) * nY * nZ.
271    * @since 1.0
272    */
273   public void fft(final double[] input) {
274     parallelFFT.input = input;
275     try {
276       parallelTeam.execute(parallelFFT);
277     } catch (Exception e) {
278       String message = "Fatal exception evaluating real 3D FFT in parallel.\n";
279       logger.log(Level.SEVERE, message, e);
280       System.exit(-1);
281     }
282   }
283 
284   /**
285    * Compute the inverse 3D FFT.
286    *
287    * @param input The input array must be of size (nX + 2) * nY * nZ.
288    * @since 1.0
289    */
290   public void ifft(final double[] input) {
291     parallelIFFT.input = input;
292     try {
293       parallelTeam.execute(parallelIFFT);
294     } catch (Exception e) {
295       String message = "Fatal exception evaluating real 3D inverse FFT in parallel.\n";
296       logger.log(Level.SEVERE, message, e);
297       System.exit(-1);
298     }
299   }
300 
301   /**
302    * Setter for the field <code>recip</code>.
303    *
304    * @param recip The recip array must be of size [(nX/2 + 1) * nY * nZ].
305    */
306   public void setRecip(double[] recip) {
307     // Reorder the reciprocal space data into the order it is needed by the convolution routine.
308     for (int index = 0, offset = 0, y = 0; y < nY; y++) {
309       for (int x = 0; x < nX1; x++, offset += 1) {
310         for (int i = 0, z = offset; i < nZ; i++, z += nX1 * nY) {
311           this.recip[index++] = recip[z];
312         }
313       }
314     }
315   }
316 
317   /**
318    * Implement the 3D parallel FFT.
319    *
320    * @author Michael J. Schnieders
321    * @since 1.0
322    */
323   private class ParallelFFT extends ParallelRegion {
324 
325     private final int nZm1;
326     private final FFTXYLoop[] fftXYLoop;
327     private final FFTZLoop[] fftZLoop;
328     public double[] input;
329 
330     private ParallelFFT() {
331       nZm1 = nZ - 1;
332       fftXYLoop = new FFTXYLoop[threadCount];
333       fftZLoop = new FFTZLoop[threadCount];
334       for (int i = 0; i < threadCount; i++) {
335         fftXYLoop[i] = new FFTXYLoop();
336         fftZLoop[i] = new FFTZLoop();
337       }
338     }
339 
340     @Override
341     public void run() {
342       int threadIndex = getThreadIndex();
343       fftXYLoop[threadIndex].input = input;
344       fftZLoop[threadIndex].input = input;
345       try {
346         execute(0, nZm1, fftXYLoop[threadIndex]);
347         // There are nX + 1 frequencies.
348         execute(0, nX, fftZLoop[threadIndex]);
349       } catch (Exception e) {
350         logger.severe(e.toString());
351       }
352     }
353 
354     private class FFTXYLoop extends IntegerForLoop {
355 
356       private final Real fftX;
357       private final Complex fftY;
358       public double[] input;
359 
360       private FFTXYLoop() {
361         fftY = new Complex(nY);
362         fftX = new Real(n);
363       }
364 
365       @Override
366       public void run(final int lb, final int ub) {
367         for (int z = lb; z <= ub; z++) {
368           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
369             fftX.fft(input, offset);
370           }
371           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
372             fftY.fft(input, offset, nextY);
373           }
374         }
375       }
376 
377       @Override
378       public IntegerSchedule schedule() {
379         return schedule;
380       }
381     }
382 
383     private class FFTZLoop extends IntegerForLoop {
384 
385       private final double[] work;
386       private final Complex fft;
387       public double[] input;
388 
389       private FFTZLoop() {
390         work = new double[nZ2];
391         fft = new Complex(nZ);
392       }
393 
394       @Override
395       public void run(final int lb, final int ub) {
396         for (int x = lb; x <= ub; x++) {
397           for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
398             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
399               work[i] = input[z];
400               work[i + 1] = input[z + 1];
401             }
402             fft.fft(work, 0, 2);
403             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
404               input[z] = work[i];
405               input[z + 1] = work[i + 1];
406             }
407           }
408         }
409       }
410 
411       @Override
412       public IntegerSchedule schedule() {
413         return schedule;
414       }
415     }
416   }
417 
418   /**
419    * Implement the 3D parallel inverse FFT.
420    *
421    * @author Michael J. Schnieders
422    * @since 1.0
423    */
424   private class ParallelIFFT extends ParallelRegion {
425 
426     private final int nZm1;
427     private final IFFTXYLoop[] ifftXYLoop;
428     private final IFFTZLoop[] ifftZLoop;
429     public double[] input;
430 
431     private ParallelIFFT() {
432       nZm1 = nZ - 1;
433       ifftXYLoop = new IFFTXYLoop[threadCount];
434       ifftZLoop = new IFFTZLoop[threadCount];
435       for (int i = 0; i < threadCount; i++) {
436         ifftXYLoop[i] = new IFFTXYLoop();
437         ifftZLoop[i] = new IFFTZLoop();
438       }
439     }
440 
441     @Override
442     public void run() {
443       int threadIndex = getThreadIndex();
444       ifftXYLoop[threadIndex].input = input;
445       ifftZLoop[threadIndex].input = input;
446       try {
447         // There are xDim + 1 frequencies.
448         execute(0, nX, ifftZLoop[threadIndex]);
449         execute(0, nZm1, ifftXYLoop[threadIndex]);
450       } catch (Exception e) {
451         logger.severe(e.toString());
452       }
453     }
454 
455     private class IFFTZLoop extends IntegerForLoop {
456 
457       private final double[] work;
458       private final Complex fft;
459       public double[] input;
460 
461       private IFFTZLoop() {
462         fft = new Complex(nZ);
463         work = new double[nZ2];
464       }
465 
466       @Override
467       public void run(final int lb, final int ub) {
468         for (int x = lb; x <= ub; x++) {
469           for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
470             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
471               work[i] = input[z];
472               work[i + 1] = input[z + 1];
473             }
474             fft.ifft(work, 0, 2);
475             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
476               input[z] = work[i];
477               input[z + 1] = work[i + 1];
478             }
479           }
480         }
481       }
482 
483       @Override
484       public IntegerSchedule schedule() {
485         return schedule;
486       }
487     }
488 
489     private class IFFTXYLoop extends IntegerForLoop {
490 
491       private final Real fftX;
492       private final Complex fftY;
493       public double[] input;
494 
495       private IFFTXYLoop() {
496         fftX = new Real(n);
497         fftY = new Complex(nY);
498       }
499 
500       @Override
501       public void run(final int lb, final int ub) {
502         for (int z = lb; z <= ub; z++) {
503           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
504             fftY.ifft(input, offset, nextY);
505           }
506           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
507             fftX.ifft(input, offset);
508           }
509         }
510       }
511 
512       @Override
513       public IntegerSchedule schedule() {
514         return schedule;
515       }
516     }
517   }
518 
519   /**
520    * Implement the 3D parallel convolution.
521    *
522    * @author Michael J. Schnieders
523    * @since 1.0
524    */
525   private class ParallelConvolution extends ParallelRegion {
526 
527     private final int nZm1, nYm1, nX1nZ;
528     private final FFTXYLoop[] fftXYLoop;
529     private final FFTZ_Multiply_IFFTZLoop[] fftZ_Multiply_ifftZLoop;
530     private final IFFTXYLoop[] ifftXYLoop;
531     public double[] input;
532 
533     private ParallelConvolution() {
534       nZm1 = nZ - 1;
535       nYm1 = nY - 1;
536       nX1nZ = nX1 * nZ;
537       fftXYLoop = new FFTXYLoop[threadCount];
538       fftZ_Multiply_ifftZLoop = new FFTZ_Multiply_IFFTZLoop[threadCount];
539       ifftXYLoop = new IFFTXYLoop[threadCount];
540       for (int i = 0; i < threadCount; i++) {
541         fftXYLoop[i] = new FFTXYLoop();
542         fftZ_Multiply_ifftZLoop[i] = new FFTZ_Multiply_IFFTZLoop();
543         ifftXYLoop[i] = new IFFTXYLoop();
544       }
545     }
546 
547     @Override
548     public void run() {
549       int threadIndex = getThreadIndex();
550       fftXYLoop[threadIndex].input = input;
551       fftZ_Multiply_ifftZLoop[threadIndex].input = input;
552       ifftXYLoop[threadIndex].input = input;
553       try {
554         execute(0, nZm1, fftXYLoop[threadIndex]);
555         execute(0, nYm1, fftZ_Multiply_ifftZLoop[threadIndex]);
556         execute(0, nZm1, ifftXYLoop[threadIndex]);
557       } catch (Exception e) {
558         logger.severe(e.toString());
559       }
560     }
561 
562     private class FFTXYLoop extends IntegerForLoop {
563 
564       private final Real fftX;
565       private final Complex fftY;
566       public double[] input;
567 
568       private FFTXYLoop() {
569         fftY = new Complex(nY);
570         fftX = new Real(n);
571       }
572 
573       @Override
574       public void run(final int lb, final int ub) {
575         for (int z = lb; z <= ub; z++) {
576           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
577             fftX.fft(input, offset);
578           }
579           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
580             fftY.fft(input, offset, nextY);
581           }
582         }
583       }
584 
585       @Override
586       public IntegerSchedule schedule() {
587         return schedule;
588       }
589     }
590 
591     private class FFTZ_Multiply_IFFTZLoop extends IntegerForLoop {
592 
593       private final double[] work;
594       private final Complex fft;
595       public double[] input;
596 
597       private FFTZ_Multiply_IFFTZLoop() {
598         work = new double[nZ2];
599         fft = new Complex(nZ);
600       }
601 
602       @Override
603       public void run(final int lb, final int ub) {
604         int index = lb * nX1nZ;
605         for (int offset = lb * nextY, y = lb; y <= ub; y++) {
606           for (int x = 0; x < nX1; x++, offset += nextX) {
607             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
608               work[i] = input[z];
609               work[i + 1] = input[z + 1];
610             }
611             fft.fft(work, 0, 2);
612             for (int i = 0; i < nZ2; i += 2) {
613               double r = recip[index++];
614               work[i] *= r;
615               work[i + 1] *= r;
616             }
617             fft.ifft(work, 0, 2);
618             for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
619               input[z] = work[i];
620               input[z + 1] = work[i + 1];
621             }
622           }
623         }
624       }
625 
626       @Override
627       public IntegerSchedule schedule() {
628         return schedule;
629       }
630     }
631 
632     private class IFFTXYLoop extends IntegerForLoop {
633 
634       private final Real fftX;
635       private final Complex fftY;
636       public double[] input;
637 
638       private IFFTXYLoop() {
639         fftY = new Complex(nY);
640         fftX = new Real(n);
641       }
642 
643       @Override
644       public void run(final int lb, final int ub) {
645         for (int z = lb; z <= ub; z++) {
646           for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
647             fftY.ifft(input, offset, nextY);
648           }
649           for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
650             fftX.ifft(input, offset);
651           }
652         }
653       }
654 
655       @Override
656       public IntegerSchedule schedule() {
657         return schedule;
658       }
659     }
660   }
661 }