View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2026.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  
42  import static java.lang.Math.fma;
43  import static jdk.incubator.vector.DoubleVector.broadcast;
44  import static jdk.incubator.vector.DoubleVector.fromArray;
45  import static org.apache.commons.math3.util.FastMath.PI;
46  import static org.apache.commons.math3.util.FastMath.sin;
47  import static org.apache.commons.math3.util.FastMath.sqrt;
48  
49  /**
50   * The MixedRadixFactor5 class handles factors of 5 in the FFT.
51   */
52  public class MixedRadixFactor5 extends MixedRadixFactor {
53  
54    private static final double sqrt5_4 = sqrt(5.0) / 4.0;
55    private static final double sinPI_5 = sin(PI / 5.0);
56    private static final double sin2PI_5 = sin(2.0 * PI / 5.0);
57  
58    private final int di2;
59    private final int di3;
60    private final int di4;
61    private final int dj2;
62    private final int dj3;
63    private final int dj4;
64    private final double tau = sqrt5_4;
65  
66    /**
67     * Construct a MixedRadixFactor5.
68     *
69     * @param passConstants PassConstants.
70     */
71    public MixedRadixFactor5(PassConstants passConstants) {
72      super(passConstants);
73      di2 = 2 * di;
74      di3 = 3 * di;
75      di4 = 4 * di;
76      dj2 = 2 * dj;
77      dj3 = 3 * dj;
78      dj4 = 4 * dj;
79    }
80  
81    /**
82     * Handle factors of 5.
83     *
84     * @param passData PassData.
85     */
86    @Override
87    protected void passScalar(PassData passData) {
88      final double[] data = passData.in;
89      final double[] ret = passData.out;
90      int sign = passData.sign;
91      int i = passData.inOffset;
92      int j = passData.outOffset;
93      final double sin2PI_5s = sign * sin2PI_5;
94      final double sinPI_5s = sign * sinPI_5;
95      // First pass of the 5-point FFT has no twiddle factors.
96      for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
97        final double z0r = data[i];
98        final double z1r = data[i + di];
99        final double z2r = data[i + di2];
100       final double z3r = data[i + di3];
101       final double z4r = data[i + di4];
102       final double z0i = data[i + im];
103       final double z1i = data[i + di + im];
104       final double z2i = data[i + di2 + im];
105       final double z3i = data[i + di3 + im];
106       final double z4i = data[i + di4 + im];
107       final double t1r = z1r + z4r;
108       final double t1i = z1i + z4i;
109       final double t2r = z2r + z3r;
110       final double t2i = z2i + z3i;
111       final double t3r = z1r - z4r;
112       final double t3i = z1i - z4i;
113       final double t4r = z2r - z3r;
114       final double t4i = z2i - z3i;
115       final double t5r = t1r + t2r;
116       final double t5i = t1i + t2i;
117       final double t6r = tau * (t1r - t2r);
118       final double t6i = tau * (t1i - t2i);
119       final double t7r = fma(-0.25, t5r, z0r);
120       final double t7i = fma(-0.25, t5i, z0i);
121       final double t8r = t7r + t6r;
122       final double t8i = t7i + t6i;
123       final double t9r = t7r - t6r;
124       final double t9i = t7i - t6i;
125       final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
126       final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
127       final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
128       final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
129       ret[j] = z0r + t5r;
130       ret[j + im] = z0i + t5i;
131       ret[j + dj] = t8r - t10i;
132       ret[j + dj + im] = t8i + t10r;
133       ret[j + dj2] = t9r - t11i;
134       ret[j + dj2 + im] = t9i + t11r;
135       ret[j + dj3] = t9r + t11i;
136       ret[j + dj3 + im] = t9i - t11r;
137       ret[j + dj4] = t8r + t10i;
138       ret[j + dj4 + im] = t8i - t10r;
139     }
140 
141     j += jstep;
142     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
143 //      final double[] twids = twiddles[k];
144 //      final double w1r = twids[0];
145 //      final double w1i = -sign * twids[1];
146 //      final double w2r = twids[2];
147 //      final double w2i = -sign * twids[3];
148 //      final double w3r = twids[4];
149 //      final double w3i = -sign * twids[5];
150 //      final double w4r = twids[6];
151 //      final double w4i = -sign * twids[7];
152       final int index = k * 4;
153       final double w1r = wr[index];
154       final double w2r = wr[index + 1];
155       final double w3r = wr[index + 2];
156       final double w4r = wr[index + 3];
157       final double w1i = -sign * wi[index];
158       final double w2i = -sign * wi[index + 1];
159       final double w3i = -sign * wi[index + 2];
160       final double w4i = -sign * wi[index + 3];
161       for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
162         final double z0r = data[i];
163         final double z1r = data[i + di];
164         final double z2r = data[i + di2];
165         final double z3r = data[i + di3];
166         final double z4r = data[i + di4];
167         final double z0i = data[i + im];
168         final double z1i = data[i + di + im];
169         final double z2i = data[i + di2 + im];
170         final double z3i = data[i + di3 + im];
171         final double z4i = data[i + di4 + im];
172         final double t1r = z1r + z4r;
173         final double t1i = z1i + z4i;
174         final double t2r = z2r + z3r;
175         final double t2i = z2i + z3i;
176         final double t3r = z1r - z4r;
177         final double t3i = z1i - z4i;
178         final double t4r = z2r - z3r;
179         final double t4i = z2i - z3i;
180         final double t5r = t1r + t2r;
181         final double t5i = t1i + t2i;
182         final double t6r = tau * (t1r - t2r);
183         final double t6i = tau * (t1i - t2i);
184         final double t7r = fma(-0.25, t5r, z0r);
185         final double t7i = fma(-0.25, t5i, z0i);
186         final double t8r = t7r + t6r;
187         final double t8i = t7i + t6i;
188         final double t9r = t7r - t6r;
189         final double t9i = t7i - t6i;
190         final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
191         final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
192         final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
193         final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
194         ret[j] = z0r + t5r;
195         ret[j + im] = z0i + t5i;
196 
197         multiplyAndStore(t8r - t10i, t8i + t10r, w1r, w1i, ret, j + dj, j + dj + im);
198         multiplyAndStore(t9r - t11i, t9i + t11r, w2r, w2i, ret, j + dj2, j + dj2 + im);
199         multiplyAndStore(t9r + t11i, t9i - t11r, w3r, w3i, ret, j + dj3, j + dj3 + im);
200         multiplyAndStore(t8r + t10i, t8i - t10r, w4r, w4i, ret, j + dj4, j + dj4 + im);
201       }
202     }
203   }
204 
205   /**
206    * Handle factors of 5 using SIMD vectors.
207    *
208    * @param passData PassData.
209    */
210   @Override
211   protected void passSIMD(PassData passData) {
212     if (!isValidSIMDWidth(simdWidth)) {
213       passScalar(passData);
214     } else {
215       if (im == 1) {
216         interleaved(passData);
217       } else {
218         blocked(passData);
219       }
220     }
221   }
222 
223   /**
224    * Handle factors of 5.
225    *
226    * @param passData PassData.
227    */
228   protected void interleaved(PassData passData) {
229     final double[] data = passData.in;
230     final double[] ret = passData.out;
231     int sign = passData.sign;
232     int i = passData.inOffset;
233     int j = passData.outOffset;
234     final double sin2PI_5s = sign * sin2PI_5;
235     final double sinPI_5s = sign * sinPI_5;
236     // First pass of the 5-point FFT has no twiddle factors.
237     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
238       final DoubleVector
239           z0 = fromArray(DOUBLE_SPECIES, data, i),
240           z1 = fromArray(DOUBLE_SPECIES, data, i + di),
241           z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
242           z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
243           z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
244       final DoubleVector
245           t1 = z1.add(z4),
246           t2 = z2.add(z3),
247           t3 = z1.sub(z4),
248           t4 = z2.sub(z3),
249           t5 = t1.add(t2),
250           t6 = t1.sub(t2).mul(tau),
251           t7 = t5.mul(-0.25).add(z0),
252           t8 = t7.add(t6),
253           t9 = t7.sub(t6),
254           t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
255           t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
256       z0.add(t5).intoArray(ret, j);
257       t8.add(t10.mul(NEGATE_RE)).intoArray(ret, j + dj);
258       t9.add(t11.mul(NEGATE_RE)).intoArray(ret, j + dj2);
259       t9.add(t11.mul(NEGATE_IM)).intoArray(ret, j + dj3);
260       t8.add(t10.mul(NEGATE_IM)).intoArray(ret, j + dj4);
261     }
262 
263     j += jstep;
264     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
265       final int index = k * 4;
266       final DoubleVector
267           w1r = broadcast(DOUBLE_SPECIES, wr[index]),
268           w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
269           w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
270           w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
271           w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]).mul(NEGATE_IM),
272           w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]).mul(NEGATE_IM),
273           w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]).mul(NEGATE_IM),
274           w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]).mul(NEGATE_IM);
275       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
276         final DoubleVector
277             z0 = fromArray(DOUBLE_SPECIES, data, i),
278             z1 = fromArray(DOUBLE_SPECIES, data, i + di),
279             z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
280             z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
281             z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
282         final DoubleVector
283             t1 = z1.add(z4),
284             t2 = z2.add(z3),
285             t3 = z1.sub(z4),
286             t4 = z2.sub(z3),
287             t5 = t1.add(t2),
288             t6 = t1.sub(t2).mul(tau),
289             t7 = t5.mul(-0.25).add(z0),
290             t8 = t7.add(t6),
291             t9 = t7.sub(t6),
292             t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
293             t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
294         z0.add(t5).intoArray(ret, j);
295         final DoubleVector
296             x1 = t10.fma(NEGATE_RE, t8),
297             x2 = t11.fma(NEGATE_RE, t9),
298             x3 = t11.fma(NEGATE_IM, t9),
299             x4 = t10.fma(NEGATE_IM, t8);
300         w1r.fma(x1, w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
301         w2r.fma(x2, w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
302         w3r.fma(x3, w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
303         w4r.fma(x4, w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
304       }
305     }
306   }
307 
308   /**
309    * Handle factors of 5.
310    *
311    * @param passData PassData.
312    */
313   protected void blocked(PassData passData) {
314     final double[] data = passData.in;
315     final double[] ret = passData.out;
316     int sign = passData.sign;
317     int i = passData.inOffset;
318     int j = passData.outOffset;
319     final double sin2PI_5s = sign * sin2PI_5;
320     final double sinPI_5s = sign * sinPI_5;
321     // First pass of the 5-point FFT has no twiddle factors.
322     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
323       final DoubleVector
324           z0r = fromArray(DOUBLE_SPECIES, data, i),
325           z1r = fromArray(DOUBLE_SPECIES, data, i + di),
326           z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
327           z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
328           z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
329           z0i = fromArray(DOUBLE_SPECIES, data, i + im),
330           z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
331           z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
332           z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
333           z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
334       final DoubleVector
335           t1r = z1r.add(z4r),
336           t1i = z1i.add(z4i),
337           t2r = z2r.add(z3r),
338           t2i = z2i.add(z3i),
339           t3r = z1r.sub(z4r),
340           t3i = z1i.sub(z4i),
341           t4r = z2r.sub(z3r),
342           t4i = z2i.sub(z3i),
343           t5r = t1r.add(t2r),
344           t5i = t1i.add(t2i),
345           t6r = t1r.sub(t2r).mul(tau),
346           t6i = t1i.sub(t2i).mul(tau),
347           t7r = t5r.mul(-0.25).add(z0r),
348           t7i = t5i.mul(-0.25).add(z0i),
349           t8r = t7r.add(t6r),
350           t8i = t7i.add(t6i),
351           t9r = t7r.sub(t6r),
352           t9i = t7i.sub(t6i),
353           t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
354           t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
355           t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
356           t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
357       z0r.add(t5r).intoArray(ret, j);
358       z0i.add(t5i).intoArray(ret, j + im);
359       t8r.sub(t10i).intoArray(ret, j + dj);
360       t8i.add(t10r).intoArray(ret, j + dj + im);
361       t9r.sub(t11i).intoArray(ret, j + dj2);
362       t9i.add(t11r).intoArray(ret, j + dj2 + im);
363       t9r.add(t11i).intoArray(ret, j + dj3);
364       t9i.sub(t11r).intoArray(ret, j + dj3 + im);
365       t8r.add(t10i).intoArray(ret, j + dj4);
366       t8i.sub(t10r).intoArray(ret, j + dj4 + im);
367     }
368     j += jstep;
369     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
370       final int index = k * 4;
371       final double
372           w1r = wr[index],
373           w2r = wr[index + 1],
374           w3r = wr[index + 2],
375           w4r = wr[index + 3],
376           w1i = -sign * wi[index],
377           w2i = -sign * wi[index + 1],
378           w3i = -sign * wi[index + 2],
379           w4i = -sign * wi[index + 3];
380       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
381         final DoubleVector
382             z0r = fromArray(DOUBLE_SPECIES, data, i),
383             z1r = fromArray(DOUBLE_SPECIES, data, i + di),
384             z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
385             z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
386             z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
387             z0i = fromArray(DOUBLE_SPECIES, data, i + im),
388             z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
389             z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
390             z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
391             z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
392         final DoubleVector
393             t1r = z1r.add(z4r),
394             t1i = z1i.add(z4i),
395             t2r = z2r.add(z3r),
396             t2i = z2i.add(z3i),
397             t3r = z1r.sub(z4r),
398             t3i = z1i.sub(z4i),
399             t4r = z2r.sub(z3r),
400             t4i = z2i.sub(z3i),
401             t5r = t1r.add(t2r),
402             t5i = t1i.add(t2i),
403             t6r = t1r.sub(t2r).mul(tau),
404             t6i = t1i.sub(t2i).mul(tau),
405             t7r = t5r.mul(-0.25).add(z0r),
406             t7i = t5i.mul(-0.25).add(z0i),
407             t8r = t7r.add(t6r),
408             t8i = t7i.add(t6i),
409             t9r = t7r.sub(t6r),
410             t9i = t7i.sub(t6i),
411             t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
412             t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
413             t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
414             t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
415         z0r.add(t5r).intoArray(ret, j);
416         z0i.add(t5i).intoArray(ret, j + im);
417         final DoubleVector
418             x1r = t8r.sub(t10i), x1i = t8i.add(t10r),
419             x2r = t9r.sub(t11i), x2i = t9i.add(t11r),
420             x3r = t9r.add(t11i), x3i = t9i.sub(t11r),
421             x4r = t8r.add(t10i), x4i = t8i.sub(t10r);
422         x1r.mul(w1r).sub(x1i.mul(w1i)).intoArray(ret, j + dj);
423         x2r.mul(w2r).sub(x2i.mul(w2i)).intoArray(ret, j + dj2);
424         x3r.mul(w3r).sub(x3i.mul(w3i)).intoArray(ret, j + dj3);
425         x4r.mul(w4r).sub(x4i.mul(w4i)).intoArray(ret, j + dj4);
426         x1i.mul(w1r).add(x1r.mul(w1i)).intoArray(ret, j + dj + im);
427         x2i.mul(w2r).add(x2r.mul(w2i)).intoArray(ret, j + dj2 + im);
428         x3i.mul(w3r).add(x3r.mul(w3i)).intoArray(ret, j + dj3 + im);
429         x4i.mul(w4r).add(x4r.mul(w4i)).intoArray(ret, j + dj4 + im);
430       }
431     }
432   }
433 }