View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  
42  import static jdk.incubator.vector.DoubleVector.SPECIES_128;
43  import static jdk.incubator.vector.DoubleVector.SPECIES_256;
44  import static jdk.incubator.vector.DoubleVector.SPECIES_512;
45  import static jdk.incubator.vector.DoubleVector.broadcast;
46  import static jdk.incubator.vector.DoubleVector.fromArray;
47  
48  /**
49   * The MixedRadixFactor2 class handles factors of 2 in the FFT.
50   */
51  public class MixedRadixFactor2 extends MixedRadixFactor {
52  
53    /**
54     * Create a new MixedRadixFactor2 instance.
55     *
56     * @param passConstants The pass constants.
57     */
58    public MixedRadixFactor2(PassConstants passConstants) {
59      super(passConstants);
60    }
61  
62    /**
63     * Handle factors of 2.
64     *
65     * @param passData The pass data.
66     */
67    @Override
68    protected void passScalar(PassData passData) {
69      final double[] data = passData.in;
70      final double[] ret = passData.out;
71      int sign = passData.sign;
72      int i = passData.inOffset;
73      int j = passData.outOffset;
74      // First pass of the 2-point FFT has no twiddle factors.
75      for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
76        final double z0_r = data[i];
77        final double z0_i = data[i + im];
78        final int idi = i + di;
79        final double z1_r = data[idi];
80        final double z1_i = data[idi + im];
81        ret[j] = z0_r + z1_r;
82        ret[j + im] = z0_i + z1_i;
83        final double x_r = z0_r - z1_r;
84        final double x_i = z0_i - z1_i;
85        final int jdj = j + dj;
86        ret[jdj] = x_r;
87        ret[jdj + im] = x_i;
88      }
89      j += dj;
90      for (int k = 1; k < outerLoopLimit; k++, j += dj) {
91        final double w_r = wr[k];
92        final double w_i = -sign * wi[k];
93        for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
94          final double z0_r = data[i];
95          final double z0_i = data[i + im];
96          final int idi = i + di;
97          final double z1_r = data[idi];
98          final double z1_i = data[idi + im];
99          ret[j] = z0_r + z1_r;
100         ret[j + im] = z0_i + z1_i;
101         final int jdj = j + dj;
102         multiplyAndStore(z0_r - z1_r, z0_i - z1_i, w_r, w_i, ret, jdj, jdj + im);
103       }
104     }
105   }
106 
107   /**
108    * Handle factors of 2 using SIMD vectors.
109    *
110    * @param passData The pass data.
111    */
112   @Override
113   protected void passSIMD(PassData passData) {
114     if (im == 1) {
115       interleaved(passData);
116     } else {
117       blocked(passData);
118     }
119   }
120 
121   /**
122    * Available SIMD sizes for Pass 2.
123    */
124   private static int[] simdSizes = {8, 4, 2};
125 
126   /**
127    * Handle factors of 2 using the chosen SIMD vector.
128    *
129    * @param passData   The pass data.
130    * @param simdLength The SIMD vector length.
131    */
132   private void interleaved(PassData passData, int simdLength) {
133     // Use the preferred SIMD vector.
134     switch (simdLength) {
135       case 2:
136         // 1 complex number per loop iteration.
137         interleaved128(passData);
138         break;
139       case 4:
140         // 2 complex numbers per loop iteration.
141         interleaved256(passData);
142         break;
143       case 8:
144         // 4 complex numbers per loop iteration.
145         interleaved512(passData);
146         break;
147       default:
148         passScalar(passData);
149     }
150   }
151 
152   /**
153    * Handle factors of 2 using the chosen SIMD vector.
154    *
155    * @param passData   The pass data.
156    * @param simdLength The SIMD vector length.
157    */
158   private void blocked(PassData passData, int simdLength) {
159     // Use the preferred SIMD vector.
160     switch (simdLength) {
161       case 2:
162         // 2 complex numbers per loop iteration.
163         blocked128(passData);
164         break;
165       case 4:
166         // 4 complex numbers per loop iteration.
167         blocked256(passData);
168         break;
169       case 8:
170         // 8 complex numbers per loop iteration.
171         blocked512(passData);
172         break;
173       default:
174         passScalar(passData);
175     }
176   }
177 
178   /**
179    * Handle factors of 2 using the SIMD vectors.
180    *
181    * @param passData The interleaved pass data.
182    */
183   private void interleaved(PassData passData) {
184     if (innerLoopLimit % INTERLEAVED_LOOP == 0) {
185       // Use the preferred SIMD vector.
186       interleaved(passData, LENGTH);
187     } else {
188       // If the inner loop limit is odd, use the scalar method unless the inner loop limit is 1.
189       if (innerLoopLimit % 2 != 0 && innerLoopLimit != 1) {
190         passScalar(passData);
191         return;
192       }
193       // Fall back to a smaller SIMD vector that fits the inner loop limit.
194       for (int size : simdSizes) {
195         if (size >= LENGTH) {
196           // Skip anything greater than or equal to the preferred SIMD vector size (which was too big).
197           continue;
198         }
199         // Divide the SIMD size by two because for interleaved a single SIMD vectors stores both real and imaginary parts.
200         if (innerLoopLimit % (size / 2) == 0) {
201           interleaved(passData, size);
202         }
203       }
204     }
205   }
206 
207   /**
208    * Handle factors of 2 using the SIMD vectors.
209    *
210    * @param passData The pass blocked data.
211    */
212   private void blocked(PassData passData) {
213     if (innerLoopLimit % BLOCK_LOOP == 0) {
214       // Use the preferred SIMD vector.
215       blocked(passData, LENGTH);
216     } else {
217       // If the inner loop limit is odd, use the scalar method unless the inner loop limit is 1.
218       if (innerLoopLimit % 2 != 0) {
219         passScalar(passData);
220         return;
221       }
222       // Fall back to a smaller SIMD vector that fits the inner loop limit.
223       for (int size : simdSizes) {
224         if (size >= LENGTH) {
225           // Skip anything greater than or equal to the preferred SIMD vector size (which was too big).
226           continue;
227         }
228         // Divide the SIMD size by two because for interleaved a single SIMD vectors stores both real and imaginary parts.
229         if (innerLoopLimit % size == 0) {
230           blocked(passData, size);
231         }
232       }
233     }
234   }
235 
236   /**
237    * Handle factors of 2 using the 128-bit SIMD vectors.
238    */
239   private void blocked128(PassData passData) {
240     final double[] data = passData.in;
241     final double[] ret = passData.out;
242     int sign = passData.sign;
243     int i = passData.inOffset;
244     int j = passData.outOffset;
245     // First pass of the 2-point FFT has no twiddle factors.
246     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
247       DoubleVector
248           z0_r = fromArray(SPECIES_128, data, i),
249           z1_r = fromArray(SPECIES_128, data, i + di),
250           z0_i = fromArray(SPECIES_128, data, i + im),
251           z1_i = fromArray(SPECIES_128, data, i + di + im);
252       z0_r.add(z1_r).intoArray(ret, j);
253       z0_i.add(z1_i).intoArray(ret, j + im);
254       z0_r.sub(z1_r).intoArray(ret, j + dj);
255       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
256     }
257 
258     j += dj;
259     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
260       final DoubleVector
261           w_r = broadcast(SPECIES_128, wr[k]),
262           w_i = broadcast(SPECIES_128, -sign * wi[k]);
263       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
264         final DoubleVector
265             z0_r = fromArray(SPECIES_128, data, i),
266             z1_r = fromArray(SPECIES_128, data, i + di),
267             z0_i = fromArray(SPECIES_128, data, i + im),
268             z1_i = fromArray(SPECIES_128, data, i + di + im);
269         z0_r.add(z1_r).intoArray(ret, j);
270         z0_i.add(z1_i).intoArray(ret, j + im);
271         DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
272         w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
273         w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
274       }
275     }
276   }
277 
278   /**
279    * Handle factors of 2 using the 256-bit SIMD vectors.
280    */
281   private void blocked256(PassData passData) {
282     final double[] data = passData.in;
283     final double[] ret = passData.out;
284     int sign = passData.sign;
285     int i = passData.inOffset;
286     int j = passData.outOffset;
287     // First pass of the 2-point FFT has no twiddle factors.
288     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
289       DoubleVector
290           z0_r = fromArray(SPECIES_256, data, i),
291           z1_r = fromArray(SPECIES_256, data, i + di),
292           z0_i = fromArray(SPECIES_256, data, i + im),
293           z1_i = fromArray(SPECIES_256, data, i + di + im);
294       z0_r.add(z1_r).intoArray(ret, j);
295       z0_i.add(z1_i).intoArray(ret, j + im);
296       z0_r.sub(z1_r).intoArray(ret, j + dj);
297       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
298     }
299 
300     j += dj;
301     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
302       final DoubleVector
303           w_r = broadcast(SPECIES_256, wr[k]),
304           w_i = broadcast(SPECIES_256, -sign * wi[k]);
305       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
306         DoubleVector
307             z0_r = fromArray(SPECIES_256, data, i),
308             z1_r = fromArray(SPECIES_256, data, i + di),
309             z0_i = fromArray(SPECIES_256, data, i + im),
310             z1_i = fromArray(SPECIES_256, data, i + di + im);
311         z0_r.add(z1_r).intoArray(ret, j);
312         z0_i.add(z1_i).intoArray(ret, j + im);
313         DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
314         w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
315         w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
316       }
317     }
318   }
319 
320   /**
321    * Handle factors of 2 using the 512-bit SIMD vectors.
322    */
323   private void blocked512(PassData passData) {
324     final double[] data = passData.in;
325     final double[] ret = passData.out;
326     int sign = passData.sign;
327     int i = passData.inOffset;
328     int j = passData.outOffset;
329     // First pass of the 2-point FFT has no twiddle factors.
330     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
331       DoubleVector
332           z0_r = fromArray(SPECIES_512, data, i),
333           z1_r = fromArray(SPECIES_512, data, i + di),
334           z0_i = fromArray(SPECIES_512, data, i + im),
335           z1_i = fromArray(SPECIES_512, data, i + di + im);
336       z0_r.add(z1_r).intoArray(ret, j);
337       z0_i.add(z1_i).intoArray(ret, j + im);
338       z0_r.sub(z1_r).intoArray(ret, j + dj);
339       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
340     }
341 
342     j += dj;
343     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
344       final DoubleVector
345           w_r = broadcast(SPECIES_512, wr[k]),
346           w_i = broadcast(SPECIES_512, -sign * wi[k]);
347       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
348         DoubleVector
349             z0_r = fromArray(SPECIES_512, data, i),
350             z1_r = fromArray(SPECIES_512, data, i + di),
351             z0_i = fromArray(SPECIES_512, data, i + im),
352             z1_i = fromArray(SPECIES_512, data, i + di + im);
353         z0_r.add(z1_r).intoArray(ret, j);
354         z0_i.add(z1_i).intoArray(ret, j + im);
355         DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
356         w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
357         w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
358       }
359     }
360   }
361 
362   /**
363    * Handle factors of 2 using the 128-bit SIMD vectors.
364    */
365   private void interleaved128(PassData passData) {
366     final double[] data = passData.in;
367     final double[] ret = passData.out;
368     int sign = passData.sign;
369     int i = passData.inOffset;
370     int j = passData.outOffset;
371     // First pass of the 2-point FFT has no twiddle factors.
372     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
373       DoubleVector
374           z0 = fromArray(SPECIES_128, data, i),
375           z1 = fromArray(SPECIES_128, data, i + di);
376       z0.add(z1).intoArray(ret, j);
377       z0.sub(z1).intoArray(ret, j + dj);
378     }
379 
380     j += dj;
381     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
382       final DoubleVector
383           w_r = broadcast(SPECIES_128, wr[k]),
384           w_i = broadcast(SPECIES_128, -sign * wi[k]).mul(NEGATE_IM_128);
385       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
386         DoubleVector
387             z0 = fromArray(SPECIES_128, data, i),
388             z1 = fromArray(SPECIES_128, data, i + di);
389         z0.add(z1).intoArray(ret, j);
390         DoubleVector x = z0.sub(z1);
391         x.mul(w_r).add(x.mul(w_i).rearrange(SHUFFLE_RE_IM_128)).intoArray(ret, j + dj);
392       }
393     }
394   }
395 
396   /**
397    * Handle factors of 2 using the 256-bit SIMD vectors.
398    */
399   private void interleaved256(PassData passData) {
400     final double[] data = passData.in;
401     final double[] ret = passData.out;
402     int sign = passData.sign;
403     int i = passData.inOffset;
404     int j = passData.outOffset;
405     // First pass of the 2-point FFT has no twiddle factors.
406     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
407       DoubleVector
408           z0 = fromArray(SPECIES_256, data, i),
409           z1 = fromArray(SPECIES_256, data, i + di);
410       z0.add(z1).intoArray(ret, j);
411       z0.sub(z1).intoArray(ret, j + dj);
412     }
413 
414     j += dj;
415     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
416       final DoubleVector
417           w_r = broadcast(SPECIES_256, wr[k]),
418           w_i = broadcast(SPECIES_256, -sign * wi[k]).mul(NEGATE_IM_256);
419       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
420         DoubleVector
421             z0 = fromArray(SPECIES_256, data, i),
422             z1 = fromArray(SPECIES_256, data, i + di);
423         z0.add(z1).intoArray(ret, j);
424         DoubleVector x = z0.sub(z1);
425         x.mul(w_r).add(x.mul(w_i).rearrange(SHUFFLE_RE_IM_256)).intoArray(ret, j + dj);
426       }
427     }
428   }
429 
430   /**
431    * Handle factors of 2 using the 512-bit SIMD vectors.
432    */
433   private void interleaved512(PassData passData) {
434     final double[] data = passData.in;
435     final double[] ret = passData.out;
436     int sign = passData.sign;
437     int i = passData.inOffset;
438     int j = passData.outOffset;
439     // First pass of the 2-point FFT has no twiddle factors.
440     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
441       DoubleVector
442           z0 = fromArray(SPECIES_512, data, i),
443           z1 = fromArray(SPECIES_512, data, i + di);
444       z0.add(z1).intoArray(ret, j);
445       z0.sub(z1).intoArray(ret, j + dj);
446     }
447 
448     j += dj;
449     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
450       final DoubleVector
451           w_r = broadcast(SPECIES_512, wr[k]),
452           w_i = broadcast(SPECIES_512, -sign * wi[k]).mul(NEGATE_IM_512);
453       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
454         DoubleVector
455             z0 = fromArray(SPECIES_512, data, i),
456             z1 = fromArray(SPECIES_512, data, i + di);
457         z0.add(z1).intoArray(ret, j);
458         DoubleVector x = z0.sub(z1);
459         x.mul(w_r).add(x.mul(w_i).rearrange(SHUFFLE_RE_IM_512)).intoArray(ret, j + dj);
460       }
461     }
462   }
463 
464 }