View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  
42  import static java.lang.Math.fma;
43  import static jdk.incubator.vector.DoubleVector.broadcast;
44  import static jdk.incubator.vector.DoubleVector.fromArray;
45  import static org.apache.commons.math3.util.FastMath.PI;
46  import static org.apache.commons.math3.util.FastMath.sin;
47  import static org.apache.commons.math3.util.FastMath.sqrt;
48  
49  /**
50   * The MixedRadixFactor5 class handles factors of 5 in the FFT.
51   */
52  public class MixedRadixFactor5 extends MixedRadixFactor {
53  
54    private static final double sqrt5_4 = sqrt(5.0) / 4.0;
55    private static final double sinPI_5 = sin(PI / 5.0);
56    private static final double sin2PI_5 = sin(2.0 * PI / 5.0);
57  
58    private final int di2;
59    private final int di3;
60    private final int di4;
61    private final int dj2;
62    private final int dj3;
63    private final int dj4;
64    private final double tau = sqrt5_4;
65  
66    /**
67     * Construct a MixedRadixFactor5.
68     *
69     * @param passConstants PassConstants.
70     */
71    public MixedRadixFactor5(PassConstants passConstants) {
72      super(passConstants);
73      di2 = 2 * di;
74      di3 = 3 * di;
75      di4 = 4 * di;
76      dj2 = 2 * dj;
77      dj3 = 3 * dj;
78      dj4 = 4 * dj;
79    }
80  
81    /**
82     * Handle factors of 5.
83     *
84     * @param passData PassData.
85     */
86    @Override
87    protected void passScalar(PassData passData) {
88      final double[] data = passData.in;
89      final double[] ret = passData.out;
90      int sign = passData.sign;
91      int i = passData.inOffset;
92      int j = passData.outOffset;
93      final double sin2PI_5s = sign * sin2PI_5;
94      final double sinPI_5s = sign * sinPI_5;
95      // First pass of the 5-point FFT has no twiddle factors.
96      for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
97        final double z0r = data[i];
98        final double z1r = data[i + di];
99        final double z2r = data[i + di2];
100       final double z3r = data[i + di3];
101       final double z4r = data[i + di4];
102       final double z0i = data[i + im];
103       final double z1i = data[i + di + im];
104       final double z2i = data[i + di2 + im];
105       final double z3i = data[i + di3 + im];
106       final double z4i = data[i + di4 + im];
107       final double t1r = z1r + z4r;
108       final double t1i = z1i + z4i;
109       final double t2r = z2r + z3r;
110       final double t2i = z2i + z3i;
111       final double t3r = z1r - z4r;
112       final double t3i = z1i - z4i;
113       final double t4r = z2r - z3r;
114       final double t4i = z2i - z3i;
115       final double t5r = t1r + t2r;
116       final double t5i = t1i + t2i;
117       final double t6r = tau * (t1r - t2r);
118       final double t6i = tau * (t1i - t2i);
119       final double t7r = fma(-0.25, t5r, z0r);
120       final double t7i = fma(-0.25, t5i, z0i);
121       final double t8r = t7r + t6r;
122       final double t8i = t7i + t6i;
123       final double t9r = t7r - t6r;
124       final double t9i = t7i - t6i;
125       final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
126       final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
127       final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
128       final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
129       ret[j] = z0r + t5r;
130       ret[j + im] = z0i + t5i;
131       ret[j + dj] = t8r - t10i;
132       ret[j + dj + im] = t8i + t10r;
133       ret[j + dj2] = t9r - t11i;
134       ret[j + dj2 + im] = t9i + t11r;
135       ret[j + dj3] = t9r + t11i;
136       ret[j + dj3 + im] = t9i - t11r;
137       ret[j + dj4] = t8r + t10i;
138       ret[j + dj4 + im] = t8i - t10r;
139     }
140 
141     j += jstep;
142     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
143       final double[] twids = twiddles[k];
144       final double w1r = twids[0];
145       final double w1i = -sign * twids[1];
146       final double w2r = twids[2];
147       final double w2i = -sign * twids[3];
148       final double w3r = twids[4];
149       final double w3i = -sign * twids[5];
150       final double w4r = twids[6];
151       final double w4i = -sign * twids[7];
152       for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
153         final double z0r = data[i];
154         final double z1r = data[i + di];
155         final double z2r = data[i + di2];
156         final double z3r = data[i + di3];
157         final double z4r = data[i + di4];
158         final double z0i = data[i + im];
159         final double z1i = data[i + di + im];
160         final double z2i = data[i + di2 + im];
161         final double z3i = data[i + di3 + im];
162         final double z4i = data[i + di4 + im];
163         final double t1r = z1r + z4r;
164         final double t1i = z1i + z4i;
165         final double t2r = z2r + z3r;
166         final double t2i = z2i + z3i;
167         final double t3r = z1r - z4r;
168         final double t3i = z1i - z4i;
169         final double t4r = z2r - z3r;
170         final double t4i = z2i - z3i;
171         final double t5r = t1r + t2r;
172         final double t5i = t1i + t2i;
173         final double t6r = tau * (t1r - t2r);
174         final double t6i = tau * (t1i - t2i);
175         final double t7r = fma(-0.25, t5r, z0r);
176         final double t7i = fma(-0.25, t5i, z0i);
177         final double t8r = t7r + t6r;
178         final double t8i = t7i + t6i;
179         final double t9r = t7r - t6r;
180         final double t9i = t7i - t6i;
181         final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
182         final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
183         final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
184         final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
185         ret[j] = z0r + t5r;
186         ret[j + im] = z0i + t5i;
187 
188 
189         multiplyAndStore(t8r - t10i, t8i + t10r, w1r, w1i, ret, j + dj, j + dj + im);
190         multiplyAndStore(t9r - t11i, t9i + t11r, w2r, w2i, ret, j + dj2, j + dj2 + im);
191         multiplyAndStore(t9r + t11i, t9i - t11r, w3r, w3i, ret, j + dj3, j + dj3 + im);
192         multiplyAndStore(t8r + t10i, t8i - t10r, w4r, w4i, ret, j + dj4, j + dj4 + im);
193       }
194     }
195   }
196 
197   /**
198    * Handle factors of 5 using SIMD vectors.
199    *
200    * @param passData PassData.
201    */
202   @Override
203   protected void passSIMD(PassData passData) {
204     if (im == 1) {
205       // If the inner loop limit is not divisible by the loop increment, use the scalar method.
206       if (innerLoopLimit % LOOP != 0) {
207         passScalar(passData);
208       } else {
209         interleaved(passData);
210       }
211     } else {
212       // If the inner loop limit is not divisible by the loop increment, use the scalar method.
213       if (innerLoopLimit % BLOCK_LOOP != 0) {
214         passScalar(passData);
215       } else {
216         blocked(passData);
217       }
218     }
219   }
220 
221   /**
222    * Handle factors of 5.
223    *
224    * @param passData PassData.
225    */
226   protected void interleaved(PassData passData) {
227     final double[] data = passData.in;
228     final double[] ret = passData.out;
229     int sign = passData.sign;
230     int i = passData.inOffset;
231     int j = passData.outOffset;
232     final double sin2PI_5s = sign * sin2PI_5;
233     final double sinPI_5s = sign * sinPI_5;
234     // First pass of the 5-point FFT has no twiddle factors.
235     for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP, i += LENGTH, j += LENGTH) {
236       DoubleVector
237           z0 = fromArray(DOUBLE_SPECIES, data, i),
238           z1 = fromArray(DOUBLE_SPECIES, data, i + di),
239           z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
240           z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
241           z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
242       DoubleVector
243           t1 = z1.add(z4),
244           t2 = z2.add(z3),
245           t3 = z1.sub(z4),
246           t4 = z2.sub(z3),
247           t5 = t1.add(t2),
248           t6 = t1.sub(t2).mul(tau),
249           t7 = t5.mul(-0.25).add(z0),
250           t8 = t7.add(t6),
251           t9 = t7.sub(t6),
252           t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
253           t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
254       z0.add(t5).intoArray(ret, j);
255       t8.add(t10.mul(NEGATE_RE)).intoArray(ret, j + dj);
256       t9.add(t11.mul(NEGATE_RE)).intoArray(ret, j + dj2);
257       t9.add(t11.mul(NEGATE_IM)).intoArray(ret, j + dj3);
258       t8.add(t10.mul(NEGATE_IM)).intoArray(ret, j + dj4);
259     }
260 
261     j += jstep;
262     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
263       final double[] twids = twiddles[k];
264       DoubleVector
265           w1r = broadcast(DOUBLE_SPECIES, twids[0]),
266           w1i = broadcast(DOUBLE_SPECIES, -sign * twids[1]).mul(NEGATE_IM),
267           w2r = broadcast(DOUBLE_SPECIES, twids[2]),
268           w2i = broadcast(DOUBLE_SPECIES, -sign * twids[3]).mul(NEGATE_IM),
269           w3r = broadcast(DOUBLE_SPECIES, twids[4]),
270           w3i = broadcast(DOUBLE_SPECIES, -sign * twids[5]).mul(NEGATE_IM),
271           w4r = broadcast(DOUBLE_SPECIES, twids[6]),
272           w4i = broadcast(DOUBLE_SPECIES, -sign * twids[7]).mul(NEGATE_IM);
273       for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP, i += LENGTH, j += LENGTH) {
274         DoubleVector
275             z0 = fromArray(DOUBLE_SPECIES, data, i),
276             z1 = fromArray(DOUBLE_SPECIES, data, i + di),
277             z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
278             z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
279             z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
280         DoubleVector
281             t1 = z1.add(z4),
282             t2 = z2.add(z3),
283             t3 = z1.sub(z4),
284             t4 = z2.sub(z3),
285             t5 = t1.add(t2),
286             t6 = t1.sub(t2).mul(tau),
287             t7 = t5.mul(-0.25).add(z0),
288             t8 = t7.add(t6),
289             t9 = t7.sub(t6),
290             t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
291             t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
292         z0.add(t5).intoArray(ret, j);
293         DoubleVector
294             x1 = t8.add(t10.mul(NEGATE_RE)),
295             x2 = t9.add(t11.mul(NEGATE_RE)),
296             x3 = t9.add(t11.mul(NEGATE_IM)),
297             x4 = t8.add(t10.mul(NEGATE_IM));
298         w1r.mul(x1).add(w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
299         w2r.mul(x2).add(w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
300         w3r.mul(x3).add(w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
301         w4r.mul(x4).add(w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
302       }
303     }
304   }
305 
306   /**
307    * Handle factors of 5.
308    *
309    * @param passData PassData.
310    */
311   protected void blocked(PassData passData) {
312     final double[] data = passData.in;
313     final double[] ret = passData.out;
314     int sign = passData.sign;
315     int i = passData.inOffset;
316     int j = passData.outOffset;
317     final double sin2PI_5s = sign * sin2PI_5;
318     final double sinPI_5s = sign * sinPI_5;
319     // First pass of the 5-point FFT has no twiddle factors.
320     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
321       final DoubleVector
322           z0r = fromArray(DOUBLE_SPECIES, data, i),
323           z1r = fromArray(DOUBLE_SPECIES, data, i + di),
324           z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
325           z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
326           z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
327           z0i = fromArray(DOUBLE_SPECIES, data, i + im),
328           z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
329           z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
330           z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
331           z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
332       final DoubleVector
333           t1r = z1r.add(z4r),
334           t1i = z1i.add(z4i),
335           t2r = z2r.add(z3r),
336           t2i = z2i.add(z3i),
337           t3r = z1r.sub(z4r),
338           t3i = z1i.sub(z4i),
339           t4r = z2r.sub(z3r),
340           t4i = z2i.sub(z3i),
341           t5r = t1r.add(t2r),
342           t5i = t1i.add(t2i),
343           t6r = t1r.sub(t2r).mul(tau),
344           t6i = t1i.sub(t2i).mul(tau),
345           t7r = t5r.mul(-0.25).add(z0r),
346           t7i = t5i.mul(-0.25).add(z0i),
347           t8r = t7r.add(t6r),
348           t8i = t7i.add(t6i),
349           t9r = t7r.sub(t6r),
350           t9i = t7i.sub(t6i),
351           t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
352           t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
353           t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
354           t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
355       z0r.add(t5r).intoArray(ret, j);
356       z0i.add(t5i).intoArray(ret, j + im);
357       t8r.sub(t10i).intoArray(ret, j + dj);
358       t8i.add(t10r).intoArray(ret, j + dj + im);
359       t9r.sub(t11i).intoArray(ret, j + dj2);
360       t9i.add(t11r).intoArray(ret, j + dj2 + im);
361       t9r.add(t11i).intoArray(ret, j + dj3);
362       t9i.sub(t11r).intoArray(ret, j + dj3 + im);
363       t8r.add(t10i).intoArray(ret, j + dj4);
364       t8i.sub(t10r).intoArray(ret, j + dj4 + im);
365     }
366     j += jstep;
367     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
368       final double[] twids = twiddles[k];
369       DoubleVector
370           w1r = broadcast(DOUBLE_SPECIES, twids[0]),
371           w1i = broadcast(DOUBLE_SPECIES, -sign * twids[1]),
372           w2r = broadcast(DOUBLE_SPECIES, twids[2]),
373           w2i = broadcast(DOUBLE_SPECIES, -sign * twids[3]),
374           w3r = broadcast(DOUBLE_SPECIES, twids[4]),
375           w3i = broadcast(DOUBLE_SPECIES, -sign * twids[5]),
376           w4r = broadcast(DOUBLE_SPECIES, twids[6]),
377           w4i = broadcast(DOUBLE_SPECIES, -sign * twids[7]);
378       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
379         final DoubleVector
380             z0r = fromArray(DOUBLE_SPECIES, data, i),
381             z1r = fromArray(DOUBLE_SPECIES, data, i + di),
382             z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
383             z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
384             z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
385             z0i = fromArray(DOUBLE_SPECIES, data, i + im),
386             z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
387             z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
388             z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
389             z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
390         final DoubleVector
391             t1r = z1r.add(z4r),
392             t1i = z1i.add(z4i),
393             t2r = z2r.add(z3r),
394             t2i = z2i.add(z3i),
395             t3r = z1r.sub(z4r),
396             t3i = z1i.sub(z4i),
397             t4r = z2r.sub(z3r),
398             t4i = z2i.sub(z3i),
399             t5r = t1r.add(t2r),
400             t5i = t1i.add(t2i),
401             t6r = t1r.sub(t2r).mul(tau),
402             t6i = t1i.sub(t2i).mul(tau),
403             t7r = t5r.mul(-0.25).add(z0r),
404             t7i = t5i.mul(-0.25).add(z0i),
405             t8r = t7r.add(t6r),
406             t8i = t7i.add(t6i),
407             t9r = t7r.sub(t6r),
408             t9i = t7i.sub(t6i),
409             t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
410             t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
411             t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
412             t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
413         z0r.add(t5r).intoArray(ret, j);
414         z0i.add(t5i).intoArray(ret, j + im);
415         DoubleVector
416             x1r = t8r.sub(t10i), x1i = t8i.add(t10r),
417             x2r = t9r.sub(t11i), x2i = t9i.add(t11r),
418             x3r = t9r.add(t11i), x3i = t9i.sub(t11r),
419             x4r = t8r.add(t10i), x4i = t8i.sub(t10r);
420         w1r.mul(x1r).add(w1i.neg().mul(x1i)).intoArray(ret, j + dj);
421         w2r.mul(x2r).add(w2i.neg().mul(x2i)).intoArray(ret, j + dj2);
422         w3r.mul(x3r).add(w3i.neg().mul(x3i)).intoArray(ret, j + dj3);
423         w4r.mul(x4r).add(w4i.neg().mul(x4i)).intoArray(ret, j + dj4);
424         w1r.mul(x1i).add(w1i.mul(x1r)).intoArray(ret, j + dj + im);
425         w2r.mul(x2i).add(w2i.mul(x2r)).intoArray(ret, j + dj2 + im);
426         w3r.mul(x3i).add(w3i.mul(x3r)).intoArray(ret, j + dj3 + im);
427         w4r.mul(x4i).add(w4i.mul(x4r)).intoArray(ret, j + dj4 + im);
428       }
429     }
430   }
431 
432 }