View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  
42  import static jdk.incubator.vector.DoubleVector.SPECIES_128;
43  import static jdk.incubator.vector.DoubleVector.SPECIES_256;
44  import static jdk.incubator.vector.DoubleVector.SPECIES_512;
45  import static jdk.incubator.vector.DoubleVector.broadcast;
46  import static jdk.incubator.vector.DoubleVector.fromArray;
47  
48  /**
49   * The MixedRadixFactor2 class handles factors of 2 in the FFT.
50   */
51  public class MixedRadixFactor2 extends MixedRadixFactor {
52  
53    /**
54     * Create a new MixedRadixFactor2 instance.
55     *
56     * @param passConstants The pass constants.
57     */
58    public MixedRadixFactor2(PassConstants passConstants) {
59      super(passConstants);
60    }
61  
62    /**
63     * Handle factors of 2.
64     *
65     * @param passData The pass data.
66     */
67    @Override
68    protected void passScalar(PassData passData) {
69      final double[] data = passData.in;
70      final double[] ret = passData.out;
71      int sign = passData.sign;
72      int i = passData.inOffset;
73      int j = passData.outOffset;
74      // First pass of the 2-point FFT has no twiddle factors.
75      for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
76        final double z0_r = data[i];
77        final double z0_i = data[i + im];
78        final int idi = i + di;
79        final double z1_r = data[idi];
80        final double z1_i = data[idi + im];
81        ret[j] = z0_r + z1_r;
82        ret[j + im] = z0_i + z1_i;
83        final double x_r = z0_r - z1_r;
84        final double x_i = z0_i - z1_i;
85        final int jdj = j + dj;
86        ret[jdj] = x_r;
87        ret[jdj + im] = x_i;
88      }
89      j += dj;
90      for (int k = 1; k < outerLoopLimit; k++, j += dj) {
91        final double[] twids = twiddles[k];
92        final double w_r = twids[0];
93        final double w_i = -sign * twids[1];
94        for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
95          final double z0_r = data[i];
96          final double z0_i = data[i + im];
97          final int idi = i + di;
98          final double z1_r = data[idi];
99          final double z1_i = data[idi + im];
100         ret[j] = z0_r + z1_r;
101         ret[j + im] = z0_i + z1_i;
102         final int jdj = j + dj;
103         multiplyAndStore(z0_r - z1_r, z0_i - z1_i, w_r, w_i, ret, jdj, jdj + im);
104       }
105     }
106   }
107 
108   /**
109    * Handle factors of 2 using SIMD vectors.
110    *
111    * @param passData The pass data.
112    */
113   @Override
114   protected void passSIMD(PassData passData) {
115     if (im == 1) {
116       interleaved(passData);
117     } else {
118       blocked(passData);
119     }
120   }
121 
122   /**
123    * Available SIMD sizes for Pass 2.
124    */
125   private static int[] simdSizes = {8, 4, 2};
126 
127   /**
128    * Handle factors of 2 using the chosen SIMD vector.
129    *
130    * @param passData   The pass data.
131    * @param simdLength The SIMD vector length.
132    */
133   private void interleaved(PassData passData, int simdLength) {
134     // Use the preferred SIMD vector.
135     switch (simdLength) {
136       case 2:
137         // 1 complex number per loop iteration.
138         interleaved128(passData);
139         break;
140       case 4:
141         // 2 complex numbers per loop iteration.
142         interleaved256(passData);
143         break;
144       case 8:
145         // 4 complex numbers per loop iteration.
146         interleaved512(passData);
147         break;
148       default:
149         passScalar(passData);
150     }
151   }
152 
153   /**
154    * Handle factors of 2 using the chosen SIMD vector.
155    *
156    * @param passData   The pass data.
157    * @param simdLength The SIMD vector length.
158    */
159   private void blocked(PassData passData, int simdLength) {
160     // Use the preferred SIMD vector.
161     switch (simdLength) {
162       case 2:
163         // 2 complex numbers per loop iteration.
164         blocked128(passData);
165         break;
166       case 4:
167         // 4 complex numbers per loop iteration.
168         blocked256(passData);
169         break;
170       case 8:
171         // 8 complex numbers per loop iteration.
172         blocked512(passData);
173         break;
174       default:
175         passScalar(passData);
176     }
177   }
178 
179   /**
180    * Handle factors of 2 using the 128-bit SIMD vectors.
181    *
182    * @param passData The interleaved pass data.
183    */
184   private void interleaved(PassData passData) {
185     if (innerLoopLimit % LOOP == 0) {
186       // Use the preferred SIMD vector.
187       interleaved(passData, LENGTH);
188     } else {
189       // If the inner loop limit is odd, use the scalar method unless the inner loop limit is 1.
190       if (innerLoopLimit % 2 != 0 && innerLoopLimit != 1) {
191         passScalar(passData);
192         return;
193       }
194       // Fall back to a smaller SIMD vector that fits the inner loop limit.
195       for (int size : simdSizes) {
196         if (size >= LENGTH) {
197           // Skip anything greater than or equal to the preferred SIMD vector size (which was too big).
198           continue;
199         }
200         // Divide the SIMD size by two because for interleaved a single SIMD vectors stores both real and imaginary parts.
201         if (innerLoopLimit % (size / 2) == 0) {
202           interleaved(passData, size);
203         }
204       }
205     }
206   }
207 
208   /**
209    * Handle factors of 2 using the 128-bit SIMD vectors.
210    *
211    * @param passData The pass blocked data.
212    */
213   private void blocked(PassData passData) {
214     if (innerLoopLimit % BLOCK_LOOP == 0) {
215       // Use the preferred SIMD vector.
216       blocked(passData, LENGTH);
217     } else {
218       // If the inner loop limit is odd, use the scalar method unless the inner loop limit is 1.
219       if (innerLoopLimit % 2 != 0) {
220         passScalar(passData);
221         return;
222       }
223       // Fall back to a smaller SIMD vector that fits the inner loop limit.
224       for (int size : simdSizes) {
225         if (size >= LENGTH) {
226           // Skip anything greater than or equal to the preferred SIMD vector size (which was too big).
227           continue;
228         }
229         // Divide the SIMD size by two because for interleaved a single SIMD vectors stores both real and imaginary parts.
230         if (innerLoopLimit % size == 0) {
231           blocked(passData, size);
232         }
233       }
234     }
235   }
236 
237   /**
238    * Handle factors of 2 using the 128-bit SIMD vectors.
239    */
240   private void blocked128(PassData passData) {
241     final double[] data = passData.in;
242     final double[] ret = passData.out;
243     int sign = passData.sign;
244     int i = passData.inOffset;
245     int j = passData.outOffset;
246     // First pass of the 2-point FFT has no twiddle factors.
247     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
248       DoubleVector
249           z0_r = fromArray(SPECIES_128, data, i),
250           z1_r = fromArray(SPECIES_128, data, i + di),
251           z0_i = fromArray(SPECIES_128, data, i + im),
252           z1_i = fromArray(SPECIES_128, data, i + di + im);
253       z0_r.add(z1_r).intoArray(ret, j);
254       z0_i.add(z1_i).intoArray(ret, j + im);
255       z0_r.sub(z1_r).intoArray(ret, j + dj);
256       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
257     }
258 
259     j += dj;
260     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
261       final double[] twids = twiddles[k];
262       DoubleVector
263           w_r = broadcast(SPECIES_128, twids[0]),
264           w_i = broadcast(SPECIES_128, -sign * twids[1]);
265       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
266         DoubleVector
267             z0_r = fromArray(SPECIES_128, data, i),
268             z1_r = fromArray(SPECIES_128, data, i + di),
269             z0_i = fromArray(SPECIES_128, data, i + im),
270             z1_i = fromArray(SPECIES_128, data, i + di + im);
271         z0_r.add(z1_r).intoArray(ret, j);
272         z0_i.add(z1_i).intoArray(ret, j + im);
273         DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
274         w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
275         w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
276       }
277     }
278   }
279 
280   /**
281    * Handle factors of 2 using the 256-bit SIMD vectors.
282    */
283   private void blocked256(PassData passData) {
284     final double[] data = passData.in;
285     final double[] ret = passData.out;
286     int sign = passData.sign;
287     int i = passData.inOffset;
288     int j = passData.outOffset;
289     // First pass of the 2-point FFT has no twiddle factors.
290     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
291       DoubleVector
292           z0_r = fromArray(SPECIES_256, data, i),
293           z1_r = fromArray(SPECIES_256, data, i + di),
294           z0_i = fromArray(SPECIES_256, data, i + im),
295           z1_i = fromArray(SPECIES_256, data, i + di + im);
296       z0_r.add(z1_r).intoArray(ret, j);
297       z0_i.add(z1_i).intoArray(ret, j + im);
298       z0_r.sub(z1_r).intoArray(ret, j + dj);
299       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
300     }
301 
302     j += dj;
303     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
304       final double[] twids = twiddles[k];
305       DoubleVector
306           w_r = broadcast(SPECIES_256, twids[0]),
307           w_i = broadcast(SPECIES_256, -sign * twids[1]);
308       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
309         DoubleVector
310             z0_r = fromArray(SPECIES_256, data, i),
311             z1_r = fromArray(SPECIES_256, data, i + di),
312             z0_i = fromArray(SPECIES_256, data, i + im),
313             z1_i = fromArray(SPECIES_256, data, i + di + im);
314         z0_r.add(z1_r).intoArray(ret, j);
315         z0_i.add(z1_i).intoArray(ret, j + im);
316         DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
317         w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
318         w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
319       }
320     }
321   }
322 
323   /**
324    * Handle factors of 2 using the 512-bit SIMD vectors.
325    */
326   private void blocked512(PassData passData) {
327     final double[] data = passData.in;
328     final double[] ret = passData.out;
329     int sign = passData.sign;
330     int i = passData.inOffset;
331     int j = passData.outOffset;
332     // First pass of the 2-point FFT has no twiddle factors.
333     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
334       DoubleVector
335           z0_r = fromArray(SPECIES_512, data, i),
336           z1_r = fromArray(SPECIES_512, data, i + di),
337           z0_i = fromArray(SPECIES_512, data, i + im),
338           z1_i = fromArray(SPECIES_512, data, i + di + im);
339       z0_r.add(z1_r).intoArray(ret, j);
340       z0_i.add(z1_i).intoArray(ret, j + im);
341       z0_r.sub(z1_r).intoArray(ret, j + dj);
342       z0_i.sub(z1_i).intoArray(ret, j + dj + im);
343     }
344 
345     j += dj;
346     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
347       final double[] twids = twiddles[k];
348       DoubleVector
349           w_r = broadcast(SPECIES_512, twids[0]),
350           w_i = broadcast(SPECIES_512, -sign * twids[1]);
351       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
352         DoubleVector
353             z0_r = fromArray(SPECIES_512, data, i),
354             z1_r = fromArray(SPECIES_512, data, i + di),
355             z0_i = fromArray(SPECIES_512, data, i + im),
356             z1_i = fromArray(SPECIES_512, data, i + di + im);
357         z0_r.add(z1_r).intoArray(ret, j);
358         z0_i.add(z1_i).intoArray(ret, j + im);
359         DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
360         w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
361         w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
362       }
363     }
364   }
365 
366   /**
367    * Handle factors of 2 using the 128-bit SIMD vectors.
368    */
369   private void interleaved128(PassData passData) {
370     final double[] data = passData.in;
371     final double[] ret = passData.out;
372     int sign = passData.sign;
373     int i = passData.inOffset;
374     int j = passData.outOffset;
375     // First pass of the 2-point FFT has no twiddle factors.
376     for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_128, i += LENGTH_128, j += LENGTH_128) {
377       DoubleVector
378           z0 = fromArray(SPECIES_128, data, i),
379           z1 = fromArray(SPECIES_128, data, i + di);
380       z0.add(z1).intoArray(ret, j);
381       z0.sub(z1).intoArray(ret, j + dj);
382     }
383 
384     j += dj;
385     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
386       final double[] twids = twiddles[k];
387       DoubleVector
388           wr = broadcast(SPECIES_128, twids[0]),
389           wi = broadcast(SPECIES_128, -sign * twids[1]).mul(NEGATE_IM_128);
390       for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_128, i += LENGTH_128, j += LENGTH_128) {
391         DoubleVector
392             z0 = fromArray(SPECIES_128, data, i),
393             z1 = fromArray(SPECIES_128, data, i + di);
394         z0.add(z1).intoArray(ret, j);
395         DoubleVector x = z0.sub(z1);
396         x.mul(wr).add(x.mul(wi).rearrange(SHUFFLE_RE_IM_128)).intoArray(ret, j + dj);
397       }
398     }
399   }
400 
401   /**
402    * Handle factors of 2 using the 256-bit SIMD vectors.
403    */
404   private void interleaved256(PassData passData) {
405     final double[] data = passData.in;
406     final double[] ret = passData.out;
407     int sign = passData.sign;
408     int i = passData.inOffset;
409     int j = passData.outOffset;
410     // First pass of the 2-point FFT has no twiddle factors.
411     for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_256, i += LENGTH_256, j += LENGTH_256) {
412       DoubleVector
413           z0 = fromArray(SPECIES_256, data, i),
414           z1 = fromArray(SPECIES_256, data, i + di);
415       z0.add(z1).intoArray(ret, j);
416       z0.sub(z1).intoArray(ret, j + dj);
417     }
418 
419     j += dj;
420     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
421       final double[] twids = twiddles[k];
422       DoubleVector
423           wr = broadcast(SPECIES_256, twids[0]),
424           wi = broadcast(SPECIES_256, -sign * twids[1]).mul(NEGATE_IM_256);
425       for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_256, i += LENGTH_256, j += LENGTH_256) {
426         DoubleVector
427             z0 = fromArray(SPECIES_256, data, i),
428             z1 = fromArray(SPECIES_256, data, i + di);
429         z0.add(z1).intoArray(ret, j);
430         DoubleVector x = z0.sub(z1);
431         x.mul(wr).add(x.mul(wi).rearrange(SHUFFLE_RE_IM_256)).intoArray(ret, j + dj);
432       }
433     }
434   }
435 
436   /**
437    * Handle factors of 2 using the 512-bit SIMD vectors.
438    */
439   private void interleaved512(PassData passData) {
440     final double[] data = passData.in;
441     final double[] ret = passData.out;
442     int sign = passData.sign;
443     int i = passData.inOffset;
444     int j = passData.outOffset;
445     // First pass of the 2-point FFT has no twiddle factors.
446     for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_512, i += LENGTH_512, j += LENGTH_512) {
447       DoubleVector
448           z0 = fromArray(SPECIES_512, data, i),
449           z1 = fromArray(SPECIES_512, data, i + di);
450       z0.add(z1).intoArray(ret, j);
451       z0.sub(z1).intoArray(ret, j + dj);
452     }
453 
454     j += dj;
455     for (int k = 1; k < outerLoopLimit; k++, j += dj) {
456       final double[] twids = twiddles[k];
457       DoubleVector
458           wr = broadcast(SPECIES_512, twids[0]),
459           wi = broadcast(SPECIES_512, -sign * twids[1]).mul(NEGATE_IM_512);
460       for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_512, i += LENGTH_512, j += LENGTH_512) {
461         DoubleVector
462             z0 = fromArray(SPECIES_512, data, i),
463             z1 = fromArray(SPECIES_512, data, i + di);
464         z0.add(z1).intoArray(ret, j);
465         DoubleVector x = z0.sub(z1);
466         x.mul(wr).add(x.mul(wi).rearrange(SHUFFLE_RE_IM_512)).intoArray(ret, j + dj);
467       }
468     }
469   }
470 
471 }