View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2026.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  import jdk.incubator.vector.VectorShuffle;
42  import jdk.incubator.vector.VectorSpecies;
43  
44  import static java.lang.Math.fma;
45  
46  /**
47   * Mixed radix factor is extended by the pass classes to apply the mixed radix factor.
48   */
49  public abstract class MixedRadixFactor {
50  
51    private static final double[] negateReal = {-1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0};
52    private static final int[] shuffleMask = {1, 0, 3, 2, 5, 4, 7, 6};
53  
54    /**
55     * The preferred vector species for double precision.
56     */
57    protected static final VectorSpecies<Double> DOUBLE_SPECIES = DoubleVector.SPECIES_PREFERRED;
58    /**
59     * Vector used to change the sign of the imaginary members of the vector via multiplication.
60     */
61    protected static final DoubleVector NEGATE_IM;
62    /**
63     * Vector used to change the sign of the real members of the vector via multiplication.
64     */
65    protected static final DoubleVector NEGATE_RE;
66    /**
67     * Shuffle used to swap real and imaginary members of the vector.
68     */
69    protected static final VectorShuffle<Double> SHUFFLE_RE_IM;
70    /**
71     * The number of contiguous elements that will be read from the input data array.
72     */
73    protected static final int LENGTH = DOUBLE_SPECIES.length();
74    /**
75     * The number of complex elements that will be processed in each inner loop iteration for
76     * interleaved data and the preferred SIMD species. The number of elements to process in the inner loop
77     * must be evenly divisible by this loop increment.
78     */
79    protected static final int INTERLEAVED_LOOP = LENGTH / 2;
80    /**
81     * The number of complex elements that will be processed in each inner loop iteration for block data
82     * and the preferred SIMD species. The number of elements to process in the inner loop must be
83     * evenly divisible by this loop increment.
84     */
85    protected static final int BLOCK_LOOP = LENGTH;
86  
87    /**
88     * Vector used to change the sign of the imaginary members of the vector via multiplication.
89     */
90    protected static final DoubleVector NEGATE_IM_128;
91    /**
92     * Vector used to change the sign of the real members of the vector via multiplication.
93     */
94    protected static final DoubleVector NEGATE_RE_128;
95    /**
96     * Shuffle used to swap real and imaginary members of the vector.
97     */
98    protected static final VectorShuffle<Double> SHUFFLE_RE_IM_128;
99    /**
100    * The number of contiguous elements that will be read from the input data array.
101    */
102   protected static final int LENGTH_128 = DoubleVector.SPECIES_128.length();
103   /**
104    * 1 complex element will be processed in each inner loop iteration for
105    * interleaved data and a 128-bit SIMD width.
106    */
107   protected static final int INTERLEAVED_LOOP_128 = LENGTH_128 / 2;
108   /**
109    * 2 complex elements will be processed in each inner loop iteration for block data
110    * and a 128-bit SIMD width. The number of elements to process in the inner loop must be
111    * evenly divisible by 2.
112    */
113   protected static final int BLOCK_LOOP_128 = LENGTH_128;
114 
115   /**
116    * Vector used to change the sign of the imaginary members of the vector via multiplication.
117    */
118   protected static final DoubleVector NEGATE_IM_256;
119   /**
120    * Vector used to change the sign of the real members of the vector via multiplication.
121    */
122   protected static final DoubleVector NEGATE_RE_256;
123   /**
124    * Shuffle used to swap real and imaginary members of the vector.
125    */
126   protected static final VectorShuffle<Double> SHUFFLE_RE_IM_256;
127   /**
128    * The number of contiguous elements that will be read from the input data array.
129    */
130   protected static final int LENGTH_256 = DoubleVector.SPECIES_256.length();
131   /**
132    * 2 complex elements will be processed in each inner loop iteration for
133    * interleaved data and the 256-bit SIMD width. The number of elements to process in the inner loop
134    * must be evenly divisible by 2.
135    */
136   protected static final int INTERLEAVED_LOOP_256 = LENGTH_256 / 2;
137   /**
138    * 4 complex elements will be processed in each inner loop iteration for block data
139    * and a 256-bit SIMD width. The number of elements to process in the inner loop must be
140    * evenly divisible by 4.
141    */
142   protected static final int BLOCK_LOOP_256 = LENGTH_256;
143 
144   /**
145    * Vector used to change the sign of the imaginary members of the vector via multiplication.
146    */
147   protected static final DoubleVector NEGATE_IM_512;
148   /**
149    * Vector used to change the sign of the real members of the vector via multiplication.
150    */
151   protected static final DoubleVector NEGATE_RE_512;
152   /**
153    * Shuffle used to swap real and imaginary members of the vector.
154    */
155   protected static final VectorShuffle<Double> SHUFFLE_RE_IM_512;
156   /**
157    * The number of contiguous elements that will be read from the input data array.
158    */
159   protected static final int LENGTH_512 = DoubleVector.SPECIES_512.length();
160   /**
161    * 4 complex elements will be processed in each inner loop iteration for
162    * interleaved data and the 512-bit SIMD width. The number of elements to process in the inner loop
163    * must be evenly divisible by 4.
164    */
165   protected static final int INTERLEAVED_LOOP_512 = LENGTH_512 / 2;
166   /**
167    * 8 complex elements will be processed in each inner loop iteration for block data
168    * and a 512-bit SIMD width. The number of elements to process in the inner loop must be
169    * evenly divisible by 8.
170    */
171   protected static final int BLOCK_LOOP_512 = LENGTH_512;
172 
173   static {
174     // Assume that 512 is the largest vector size.
175     if (LENGTH > 8) {
176       throw new IllegalStateException("Unsupported SIMD vector size: " + LENGTH);
177     }
178 
179     NEGATE_RE_128 = DoubleVector.fromArray(DoubleVector.SPECIES_128, negateReal, 0);
180     NEGATE_IM_128 = NEGATE_RE_128.mul(-1.0);
181     SHUFFLE_RE_IM_128 = VectorShuffle.fromArray(DoubleVector.SPECIES_128, shuffleMask, 0);
182 
183     NEGATE_RE_256 = DoubleVector.fromArray(DoubleVector.SPECIES_256, negateReal, 0);
184     NEGATE_IM_256 = NEGATE_RE_256.mul(-1.0);
185     SHUFFLE_RE_IM_256 = VectorShuffle.fromArray(DoubleVector.SPECIES_256, shuffleMask, 0);
186 
187     NEGATE_RE_512 = DoubleVector.fromArray(DoubleVector.SPECIES_512, negateReal, 0);
188     NEGATE_IM_512 = NEGATE_RE_512.mul(-1.0);
189     SHUFFLE_RE_IM_512 = VectorShuffle.fromArray(DoubleVector.SPECIES_512, shuffleMask, 0);
190 
191     switch (LENGTH) {
192       case 2:
193         NEGATE_RE = NEGATE_RE_128;
194         NEGATE_IM = NEGATE_IM_128;
195         SHUFFLE_RE_IM = SHUFFLE_RE_IM_128;
196         break;
197       case 4:
198         NEGATE_RE = NEGATE_RE_256;
199         NEGATE_IM = NEGATE_IM_256;
200         SHUFFLE_RE_IM = SHUFFLE_RE_IM_256;
201         break;
202       case 8:
203         NEGATE_RE = NEGATE_RE_512;
204         NEGATE_IM = NEGATE_IM_512;
205         SHUFFLE_RE_IM = SHUFFLE_RE_IM_512;
206         break;
207       default:
208         throw new IllegalStateException("Unsupported SIMD DoubleVector size: " + LENGTH);
209     }
210   }
211 
212   /**
213    * The size of the input.
214    */
215   protected final int n;
216   /**
217    * The number of FFTs to process (default = 1).
218    */
219   protected final int nFFTs;
220   /**
221    * The imaginary offset.
222    */
223   protected final int im;
224   /**
225    * The mixed radix factor.
226    */
227   protected final int factor;
228   /**
229    * The product of all factors applied so far.
230    */
231   protected final int product;
232   /**
233    * The outer loop limit (n / product).
234    */
235   protected final int outerLoopLimit;
236   /**
237    * The inner loop limit (product / factor).
238    */
239   protected final int innerLoopLimit;
240   /**
241    * The next input (n / factor).
242    * This is the separation between the input data for each pass.
243    */
244   protected final int nextInput;
245   /**
246    * Equal to 2 * nextInput for interleaved complex data.
247    * Equal to nextInput for blocked real and imaginary arrays.
248    */
249   protected final int di;
250   /**
251    * Equal to 2 * innerLoopLimit for interleaved complex data.
252    * Equal to innerLoopLimit for blocked real and imaginary arrays.
253    */
254   protected final int dj;
255   /**
256    * The twiddle factors for this pass.
257    */
258   protected final double[][] twiddles;
259   /**
260    * The real twiddle factors for this pass.
261    */
262   protected final double[] wr;
263   /**
264    * The imaginary twiddle factors for this pass.
265    */
266   protected final double[] wi;
267   /**
268    * The increment for input data within the inner loop.
269    * This is equal to 2 for interleaved complex data.
270    * This is equal to 1 for separate real and imaginary arrays.
271    */
272   protected final int ii;
273   /**
274    * Increment for the inner loop.
275    */
276   protected final int jstep;
277   /**
278    * The SIMD width to use.
279    */
280   protected int simdWidth;
281 
282   /**
283    * Constructor for the mixed radix factor.
284    *
285    * @param passConstants the pass constants.
286    */
287   public MixedRadixFactor(PassConstants passConstants) {
288     n = passConstants.n();
289     nFFTs = passConstants.nFFTs();
290     im = passConstants.im();
291     factor = passConstants.factor();
292     product = passConstants.product();
293     twiddles = passConstants.twiddles();
294     outerLoopLimit = n / product;
295     innerLoopLimit = (product / factor) * nFFTs;
296     nextInput = (n / factor) * nFFTs;
297     if (im == 1) {
298       ii = 2;
299       // For interleaved complex data, the di and dj offsets are doubled.
300       di = 2 * nextInput;
301       dj = 2 * innerLoopLimit;
302     } else {
303       ii = 1;
304       // For separate real and imaginary arrays, the di and dj offsets
305       // are the same as the next input and inner loop limit.
306       di = nextInput;
307       dj = innerLoopLimit;
308     }
309     jstep = (factor - 1) * dj;
310 
311     int f1 = factor - 1;
312     wr = new double[outerLoopLimit * f1];
313     wi = new double[outerLoopLimit * f1];
314     for (int k = 0; k < outerLoopLimit; k++) {
315       final double[] twids = twiddles[k];
316       final int index = k * f1;
317       for (int j = 0; j < f1; j++) {
318         wr[index + j] = twids[2 * j];
319         wi[index + j] = twids[2 * j + 1];
320       }
321     }
322 
323     simdWidth = getOptimalSIMDWidth();
324   }
325 
326   /**
327    * Return a string representation of the mixed radix factor.
328    *
329    * @return a string representation of the mixed radix factor.
330    */
331   public String toString() {
332     return " MixedRadixFactor {" +
333         "\n N: " + n +
334         "\n Factor: " + factor +
335         "\n Number of FFTs: " + nFFTs +
336         "\n Next real value: " + ii +
337         "\n Imaginary offset: " + im +
338         "\n Product: " + product +
339         "\n Outer Loop Limit: " + outerLoopLimit +
340         "\n Inner Loop Limit: " + innerLoopLimit +
341         "\n SIMD width: " + simdWidth +
342         "\n Next input: " + nextInput +
343         "\n Step between input values: " + di +
344         "\n Step between output values: " + dj +
345         "\n jstep =" + jstep +
346         "}\n";
347   }
348 
349   /**
350    * Apply the mixed radix factor using scalar operations.
351    *
352    * @param passData the pass data.
353    */
354   protected abstract void passScalar(PassData passData);
355 
356   /**
357    * Check if the requested SIMD length is valid.
358    * @param width Requested SIMD species width.
359    * @return True if this width is supported.
360    */
361   protected boolean isValidSIMDWidth(int width) {
362     if (width != LENGTH) {
363       return false;
364     }
365     if (im == 1) {
366       // Interleaved
367       return innerLoopLimit % INTERLEAVED_LOOP == 0;
368     } else {
369       // Blocked
370       return innerLoopLimit % BLOCK_LOOP == 0;
371     }
372   }
373 
374   /**
375    * Determine the optimal SIMD width. Currently supported widths are 2, 4 and 8.
376    * If no SIMD width is valid, return 0 to indicate use of the scalar path.
377    * @return The optimal SIMD width.
378    */
379   protected int getOptimalSIMDWidth() {
380     // Check the platform specific preferred width.
381     if (isValidSIMDWidth(LENGTH)) {
382       return LENGTH;
383     }
384     // No valid SIMD width.
385     return 0;
386   }
387 
388   /**
389    * Set the SIMD width to use. If the supplied width is not supported, then the
390    * width will be set by the "getOptimalSIMDWidth" method.
391    *
392    * @param width The SIMD width to use.
393    * @return the SIMD width selected.
394    */
395   public int setSIMDWidth(int width) {
396     if (isValidSIMDWidth(width)) {
397       simdWidth = width;
398     } else {
399       simdWidth = getOptimalSIMDWidth();
400     }
401     return simdWidth;
402   }
403 
404   /**
405    * Apply the mixed radix factor using SIMD operations.
406    *
407    * @param passData the pass data.
408    */
409   protected abstract void passSIMD(PassData passData);
410 
411   /**
412    * Multiply two complex numbers [x_r, x_i] and [w_r, w_i] and store the result.
413    *
414    * @param x_r the real part of the complex number.
415    * @param x_i the imaginary part of the complex number.
416    * @param w_r the real part of the twiddle factor.
417    * @param w_i the imaginary part of the twiddle factor.
418    * @param ret the array to store the result.
419    * @param re  the real part index in the result array.
420    * @param im  the imaginary part index in the result array.
421    */
422   protected static void multiplyAndStore(double x_r, double x_i, double w_r, double w_i, double[] ret, int re, int im) {
423     ret[re] = fma(w_r, x_r, -w_i * x_i);
424     ret[im] = fma(w_r, x_i, w_i * x_r);
425   }
426 }