View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2026.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  import jdk.incubator.vector.VectorShuffle;
42  import jdk.incubator.vector.VectorSpecies;
43  
44  import static jdk.incubator.vector.DoubleVector.SPECIES_128;
45  import static jdk.incubator.vector.DoubleVector.SPECIES_256;
46  import static jdk.incubator.vector.DoubleVector.SPECIES_512;
47  import static jdk.incubator.vector.DoubleVector.broadcast;
48  import static jdk.incubator.vector.DoubleVector.fromArray;
49  
50  /**
51   * The MixedRadixFactor2 class handles factors of 2 in the FFT.
52   */
53  public class MixedRadixFactor2 extends MixedRadixFactor {
54  
55    /**
56     * Available SIMD sizes for Pass 2.
57     */
58    private static int[] simdSizes = {8, 4, 2};
59  
60    /**
61     * Create a new MixedRadixFactor2 instance.
62     *
63     * @param passConstants The pass constants.
64     */
65    public MixedRadixFactor2(PassConstants passConstants) {
66      super(passConstants);
67    }
68  
69    /**
70     * Check if the requested SIMD length is valid.
71     *
72     * @param width Requested SIMD species width.
73     * @return True if this width is supported.
74     */
75    @Override
76    public boolean isValidSIMDWidth(int width) {
77      // Must be a supported width.
78      if (width != 2 && width != 4 && width != 8) {
79        return false;
80      }
81      if (im == 1) {
82        // Interleaved
83        return innerLoopLimit % (width / 2) == 0;
84      } else {
85        // Blocked
86        return innerLoopLimit % width == 0;
87      }
88    }
89  
90    /**
91     * Determine the optimal SIMD width. Currently supported widths are 2, 4 and 8.
92     * If no SIMD width is valid, return 0 to indicate use of the scalar path.
93     *
94     * @return The optimal SIMD width.
95     */
96    @Override
97    public int getOptimalSIMDWidth() {
98      // Check the platform specific preferred width.
99      if (isValidSIMDWidth(LENGTH)) {
100       return LENGTH;
101     }
102     // Fall back to a smaller SIMD vector that fits the inner loop limit.
103     for (int size : simdSizes) {
104       if (size >= LENGTH) {
105         // Skip anything greater than or equal to the preferred SIMD vector size (which was too big).
106         continue;
107       }
108       if (isValidSIMDWidth(size)) {
109         return size;
110       }
111     }
112     // No valid SIMD width is found.
113     return 0;
114   }
115 
116   /**
117    * Handle factors of 2.
118    *
119    * @param passData The pass data.
120    */
121   @Override
122   protected void passScalar(PassData passData) {
123     final double[] data = passData.in;
124     final double[] ret = passData.out;
125     int sign = passData.sign;
126     int i = passData.inOffset;
127     int j = passData.outOffset;
128     // First pass of the 2-point FFT has no twiddle factors.
129     for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
130       final double z0_r = data[i];
131       final double z0_i = data[i + im];
132       final int idi = i + di;
133       final double z1_r = data[idi];
134       final double z1_i = data[idi + im];
135       ret[j] = z0_r + z1_r;
136       ret[j + im] = z0_i + z1_i;
137       final double x_r = z0_r - z1_r;
138       final double x_i = z0_i - z1_i;
139       final int jdj = j + dj;
140       ret[jdj] = x_r;
141       ret[jdj + im] = x_i;
142     }
143     j += dj;
144     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
145       final double w_r = wr[k];
146       final double w_i = -sign * wi[k];
147       for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
148         final double z0_r = data[i];
149         final double z0_i = data[i + im];
150         final int idi = i + di;
151         final double z1_r = data[idi];
152         final double z1_i = data[idi + im];
153         ret[j] = z0_r + z1_r;
154         ret[j + im] = z0_i + z1_i;
155         final int jdj = j + dj;
156         multiplyAndStore(z0_r - z1_r, z0_i - z1_i, w_r, w_i, ret, jdj, jdj + im);
157       }
158     }
159   }
160 
161   /**
162    * Handle factors of 2 using SIMD vectors.
163    *
164    * @param passData The pass data.
165    */
166   @Override
167   protected void passSIMD(PassData passData) {
168     if (!isValidSIMDWidth(simdWidth)) {
169       passScalar(passData);
170     } else {
171       if (im == 1) {
172         interleaved(passData, simdWidth);
173       } else {
174         blocked(passData, simdWidth);
175       }
176     }
177   }
178 
179   /**
180    * Handle factors of 2 using the chosen SIMD vector.
181    *
182    * @param passData   The pass data.
183    * @param simdLength The SIMD vector length.
184    */
185   private void interleaved(PassData passData, int simdLength) {
186     // Use the preferred SIMD vector.
187     switch (simdLength) {
188       case 2:
189         // 1 complex number per loop iteration.
190         interleaved128(passData);
191         break;
192       case 4:
193         // 2 complex numbers per loop iteration.
194         interleaved256(passData);
195         break;
196       case 8:
197         // 4 complex numbers per loop iteration.
198         interleaved512(passData);
199         break;
200       default:
201         passScalar(passData);
202     }
203   }
204 
205   /**
206    * Handle factors of 2 using the chosen SIMD vector.
207    *
208    * @param passData   The pass data.
209    * @param simdLength The SIMD vector length.
210    */
211   private void blocked(PassData passData, int simdLength) {
212     // Use the preferred SIMD vector.
213     switch (simdLength) {
214       case 2:
215         // 2 complex numbers per loop iteration.
216         blocked128(passData);
217         break;
218       case 4:
219         // 4 complex numbers per loop iteration.
220         blocked256(passData);
221         break;
222       case 8:
223         // 8 complex numbers per loop iteration.
224         blocked512(passData);
225         break;
226       default:
227         passScalar(passData);
228     }
229   }
230 
231   /**
232    * ButterFly for radix-2 with blocked data.
233    *
234    * @param data The input array to retrieve data from.
235    * @param i    The index to read the first real input.
236    * @param w_r  The real part of the twiddle factor.
237    * @param w_i  The imaginary part of the twiddle factor.
238    * @param ret  The array to store into.
239    * @param j    The index to store the first real output.
240    */
241   private void butterFly2Blocked(
242       VectorSpecies<Double> species,
243       double[] data, int i, double w_r, double w_i, double[] ret, int j) {
244     final DoubleVector
245         z0_r = fromArray(species, data, i),
246         z0_i = fromArray(species, data, i + im),
247         z1_r = fromArray(species, data, i + di),
248         z1_i = fromArray(species, data, i + di + im);
249     z0_r.add(z1_r).intoArray(ret, j);
250     z0_i.add(z1_i).intoArray(ret, j + im);
251     final DoubleVector
252         x_r = z0_r.sub(z1_r),
253         x_i = z0_i.sub(z1_i);
254     x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
255     x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
256   }
257 
258   /**
259    * ButterFly for radix-2 with blocked data.
260    *
261    * @param z0  The first vector of real + imaginary data.
262    * @param z1  The second vector of real + imaginary data.
263    * @param w_r The real part of the twiddle factor.
264    * @param w_i The imaginary part of the twiddle factor.
265    * @param j   The index to store the first real output.
266    * @param ret The array to store into.
267    */
268   private void butterFly2Interleaved(
269       DoubleVector z0, DoubleVector z1,
270       DoubleVector w_r, DoubleVector w_i,
271       VectorShuffle<Double> shuffle_re_im,
272       int j, double[] ret) {
273     z0.add(z1).intoArray(ret, j);
274     final DoubleVector x = z0.sub(z1);
275     x.mul(w_r).add(x.mul(w_i).rearrange(shuffle_re_im)).intoArray(ret, j + dj);
276   }
277 
278   /**
279    * Handle factors of 2 using the 128-bit SIMD vectors.
280    */
281   private void blocked128(PassData passData) {
282     final double[] data = passData.in;
283     final double[] ret = passData.out;
284     final int sign = passData.sign;
285     int i = passData.inOffset;
286     int j = passData.outOffset;
287     // First pass of the 2-point FFT has no twiddle factors.
288     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
289       final DoubleVector
290           z0_r = fromArray(SPECIES_128, data, i),
291           z0_i = fromArray(SPECIES_128, data, i + im),
292           z1_r = fromArray(SPECIES_128, data, i + di),
293           z1_i = fromArray(SPECIES_128, data, i + di + im);
294       z0_r.add(z1_r).intoArray(ret, j);
295       z0_i.add(z1_i).intoArray(ret, j + im);
296       z0_r.sub(z1_r).intoArray(ret, j + dj);
297       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
298     }
299 
300     j += dj;
301     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
302       final double
303           w_r = wr[k],
304           w_i = -sign * wi[k];
305       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
306         final DoubleVector
307             z0_r = fromArray(SPECIES_128, data, i),
308             z0_i = fromArray(SPECIES_128, data, i + im),
309             z1_r = fromArray(SPECIES_128, data, i + di),
310             z1_i = fromArray(SPECIES_128, data, i + di + im);
311         z0_r.add(z1_r).intoArray(ret, j);
312         z0_i.add(z1_i).intoArray(ret, j + im);
313         final DoubleVector
314             x_r = z0_r.sub(z1_r),
315             x_i = z0_i.sub(z1_i);
316         x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
317         x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
318       }
319     }
320   }
321 
322   /**
323    * Handle factors of 2 using the 256-bit SIMD vectors.
324    */
325   private void blocked256(PassData passData) {
326     final double[] data = passData.in;
327     final double[] ret = passData.out;
328     final int sign = passData.sign;
329     int i = passData.inOffset;
330     int j = passData.outOffset;
331     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
332       final DoubleVector
333           z0_r = fromArray(SPECIES_256, data, i),
334           z0_i = fromArray(SPECIES_256, data, i + im),
335           z1_r = fromArray(SPECIES_256, data, i + di),
336           z1_i = fromArray(SPECIES_256, data, i + di + im);
337       z0_r.add(z1_r).intoArray(ret, j);
338       z0_i.add(z1_i).intoArray(ret, j + im);
339       z0_r.sub(z1_r).intoArray(ret, j + dj);
340       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
341     }
342 
343     j += dj;
344     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
345       final double
346           w_r = wr[k],
347           w_i = -sign * wi[k];
348       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
349         final DoubleVector
350             z0_r = fromArray(SPECIES_256, data, i),
351             z0_i = fromArray(SPECIES_256, data, i + im),
352             z1_r = fromArray(SPECIES_256, data, i + di),
353             z1_i = fromArray(SPECIES_256, data, i + di + im);
354         z0_r.add(z1_r).intoArray(ret, j);
355         z0_i.add(z1_i).intoArray(ret, j + im);
356         final DoubleVector
357             x_r = z0_r.sub(z1_r),
358             x_i = z0_i.sub(z1_i);
359         x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
360         x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
361       }
362     }
363   }
364 
365   /**
366    * Handle factors of 2 using the 512-bit SIMD vectors.
367    */
368   private void blocked512(PassData passData) {
369     final double[] data = passData.in;
370     final double[] ret = passData.out;
371     final int sign = passData.sign;
372     int i = passData.inOffset;
373     int j = passData.outOffset;
374     // First pass of the 2-point FFT has no twiddle factors.
375     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
376       final DoubleVector
377           z0_r = fromArray(SPECIES_512, data, i),
378           z0_i = fromArray(SPECIES_512, data, i + im),
379           z1_r = fromArray(SPECIES_512, data, i + di),
380           z1_i = fromArray(SPECIES_512, data, i + di + im);
381       z0_r.add(z1_r).intoArray(ret, j);
382       z0_i.add(z1_i).intoArray(ret, j + im);
383       z0_r.sub(z1_r).intoArray(ret, j + dj);
384       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
385     }
386 
387     j += dj;
388     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
389       final double
390           w_r = wr[k],
391           w_i = -sign * wi[k];
392       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
393         final DoubleVector
394             z0_r = fromArray(SPECIES_512, data, i),
395             z1_r = fromArray(SPECIES_512, data, i + di),
396             z0_i = fromArray(SPECIES_512, data, i + im),
397             z1_i = fromArray(SPECIES_512, data, i + di + im);
398         z0_r.add(z1_r).intoArray(ret, j);
399         z0_i.add(z1_i).intoArray(ret, j + im);
400         final DoubleVector
401             x_r = z0_r.sub(z1_r),
402             x_i = z0_i.sub(z1_i);
403         x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
404         x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
405       }
406     }
407   }
408 
409   /**
410    * Handle factors of 2 using the 128-bit SIMD vectors.
411    */
412   private void interleaved128(PassData passData) {
413     final double[] data = passData.in;
414     final double[] ret = passData.out;
415     final int sign = passData.sign;
416     int i = passData.inOffset;
417     int j = passData.outOffset;
418     // First pass of the 2-point FFT has no twiddle factors.
419     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
420       final DoubleVector
421           z0 = fromArray(SPECIES_128, data, i),
422           z1 = fromArray(SPECIES_128, data, i + di);
423       z0.add(z1).intoArray(ret, j);
424       z0.sub(z1).intoArray(ret, j + dj);
425     }
426 
427     j += dj;
428     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
429       final DoubleVector
430           w_r = broadcast(SPECIES_128, wr[k]),
431           w_i = broadcast(SPECIES_128, -sign * wi[k]).mul(NEGATE_IM_128);
432       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
433         final DoubleVector
434             z0 = fromArray(SPECIES_128, data, i),
435             z1 = fromArray(SPECIES_128, data, i + di);
436         z0.add(z1).intoArray(ret, j);
437         final DoubleVector x = z0.sub(z1);
438         x.fma(w_r, x.mul(w_i).rearrange(SHUFFLE_RE_IM_128)).intoArray(ret, j + dj);
439       }
440     }
441   }
442 
443   /**
444    * Handle factors of 2 using the 256-bit SIMD vectors.
445    */
446   private void interleaved256(PassData passData) {
447     final double[] data = passData.in;
448     final double[] ret = passData.out;
449     final int sign = passData.sign;
450     int i = passData.inOffset;
451     int j = passData.outOffset;
452     // First pass of the 2-point FFT has no twiddle factors.
453     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
454       final DoubleVector
455           z0 = fromArray(SPECIES_256, data, i),
456           z1 = fromArray(SPECIES_256, data, i + di);
457       z0.add(z1).intoArray(ret, j);
458       z0.sub(z1).intoArray(ret, j + dj);
459     }
460 
461     j += dj;
462     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
463       final DoubleVector
464           w_r = broadcast(SPECIES_256, wr[k]),
465           w_i = broadcast(SPECIES_256, -sign * wi[k]).mul(NEGATE_IM_256);
466       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
467         final DoubleVector
468             z0 = fromArray(SPECIES_256, data, i),
469             z1 = fromArray(SPECIES_256, data, i + di);
470         z0.add(z1).intoArray(ret, j);
471         final DoubleVector x = z0.sub(z1);
472         x.fma(w_r, x.mul(w_i).rearrange(SHUFFLE_RE_IM_256)).intoArray(ret, j + dj);
473       }
474     }
475   }
476 
477   /**
478    * Handle factors of 2 using the 512-bit SIMD vectors.
479    */
480   private void interleaved512(PassData passData) {
481     final double[] data = passData.in;
482     final double[] ret = passData.out;
483     final int sign = passData.sign;
484     int i = passData.inOffset;
485     int j = passData.outOffset;
486     // First pass of the 2-point FFT has no twiddle factors.
487     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
488       final DoubleVector
489           z0 = fromArray(SPECIES_512, data, i),
490           z1 = fromArray(SPECIES_512, data, i + di);
491       z0.add(z1).intoArray(ret, j);
492       z0.sub(z1).intoArray(ret, j + dj);
493     }
494 
495     j += dj;
496     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
497       final DoubleVector
498           w_r = broadcast(SPECIES_512, wr[k]),
499           w_i = broadcast(SPECIES_512, -sign * wi[k]).mul(NEGATE_IM_512);
500       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
501         final DoubleVector
502             z0 = fromArray(SPECIES_512, data, i),
503             z1 = fromArray(SPECIES_512, data, i + di);
504         z0.add(z1).intoArray(ret, j);
505         final DoubleVector x = z0.sub(z1);
506         x.fma(w_r, x.mul(w_i).rearrange(SHUFFLE_RE_IM_512)).intoArray(ret, j + dj);
507       }
508     }
509   }
510 
511 }