View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  
42  import static java.lang.Math.fma;
43  import static jdk.incubator.vector.DoubleVector.broadcast;
44  import static jdk.incubator.vector.DoubleVector.fromArray;
45  import static org.apache.commons.math3.util.FastMath.PI;
46  import static org.apache.commons.math3.util.FastMath.sin;
47  import static org.apache.commons.math3.util.FastMath.sqrt;
48  
49  /**
50   * The MixedRadixFactor5 class handles factors of 5 in the FFT.
51   */
52  public class MixedRadixFactor5 extends MixedRadixFactor {
53  
54    private static final double sqrt5_4 = sqrt(5.0) / 4.0;
55    private static final double sinPI_5 = sin(PI / 5.0);
56    private static final double sin2PI_5 = sin(2.0 * PI / 5.0);
57  
58    private final int di2;
59    private final int di3;
60    private final int di4;
61    private final int dj2;
62    private final int dj3;
63    private final int dj4;
64    private final double tau = sqrt5_4;
65  
66    /**
67     * Construct a MixedRadixFactor5.
68     *
69     * @param passConstants PassConstants.
70     */
71    public MixedRadixFactor5(PassConstants passConstants) {
72      super(passConstants);
73      di2 = 2 * di;
74      di3 = 3 * di;
75      di4 = 4 * di;
76      dj2 = 2 * dj;
77      dj3 = 3 * dj;
78      dj4 = 4 * dj;
79    }
80  
81    /**
82     * Handle factors of 5.
83     *
84     * @param passData PassData.
85     */
86    @Override
87    protected void passScalar(PassData passData) {
88      final double[] data = passData.in;
89      final double[] ret = passData.out;
90      int sign = passData.sign;
91      int i = passData.inOffset;
92      int j = passData.outOffset;
93      final double sin2PI_5s = sign * sin2PI_5;
94      final double sinPI_5s = sign * sinPI_5;
95      // First pass of the 5-point FFT has no twiddle factors.
96      for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
97        final double z0r = data[i];
98        final double z1r = data[i + di];
99        final double z2r = data[i + di2];
100       final double z3r = data[i + di3];
101       final double z4r = data[i + di4];
102       final double z0i = data[i + im];
103       final double z1i = data[i + di + im];
104       final double z2i = data[i + di2 + im];
105       final double z3i = data[i + di3 + im];
106       final double z4i = data[i + di4 + im];
107       final double t1r = z1r + z4r;
108       final double t1i = z1i + z4i;
109       final double t2r = z2r + z3r;
110       final double t2i = z2i + z3i;
111       final double t3r = z1r - z4r;
112       final double t3i = z1i - z4i;
113       final double t4r = z2r - z3r;
114       final double t4i = z2i - z3i;
115       final double t5r = t1r + t2r;
116       final double t5i = t1i + t2i;
117       final double t6r = tau * (t1r - t2r);
118       final double t6i = tau * (t1i - t2i);
119       final double t7r = fma(-0.25, t5r, z0r);
120       final double t7i = fma(-0.25, t5i, z0i);
121       final double t8r = t7r + t6r;
122       final double t8i = t7i + t6i;
123       final double t9r = t7r - t6r;
124       final double t9i = t7i - t6i;
125       final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
126       final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
127       final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
128       final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
129       ret[j] = z0r + t5r;
130       ret[j + im] = z0i + t5i;
131       ret[j + dj] = t8r - t10i;
132       ret[j + dj + im] = t8i + t10r;
133       ret[j + dj2] = t9r - t11i;
134       ret[j + dj2 + im] = t9i + t11r;
135       ret[j + dj3] = t9r + t11i;
136       ret[j + dj3 + im] = t9i - t11r;
137       ret[j + dj4] = t8r + t10i;
138       ret[j + dj4 + im] = t8i - t10r;
139     }
140 
141     j += jstep;
142     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
143 //      final double[] twids = twiddles[k];
144 //      final double w1r = twids[0];
145 //      final double w1i = -sign * twids[1];
146 //      final double w2r = twids[2];
147 //      final double w2i = -sign * twids[3];
148 //      final double w3r = twids[4];
149 //      final double w3i = -sign * twids[5];
150 //      final double w4r = twids[6];
151 //      final double w4i = -sign * twids[7];
152       final int index = k * 4;
153       final double w1r = wr[index];
154       final double w2r = wr[index + 1];
155       final double w3r = wr[index + 2];
156       final double w4r = wr[index + 3];
157       final double w1i = -sign * wi[index];
158       final double w2i = -sign * wi[index + 1];
159       final double w3i = -sign * wi[index + 2];
160       final double w4i = -sign * wi[index + 3];
161       for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
162         final double z0r = data[i];
163         final double z1r = data[i + di];
164         final double z2r = data[i + di2];
165         final double z3r = data[i + di3];
166         final double z4r = data[i + di4];
167         final double z0i = data[i + im];
168         final double z1i = data[i + di + im];
169         final double z2i = data[i + di2 + im];
170         final double z3i = data[i + di3 + im];
171         final double z4i = data[i + di4 + im];
172         final double t1r = z1r + z4r;
173         final double t1i = z1i + z4i;
174         final double t2r = z2r + z3r;
175         final double t2i = z2i + z3i;
176         final double t3r = z1r - z4r;
177         final double t3i = z1i - z4i;
178         final double t4r = z2r - z3r;
179         final double t4i = z2i - z3i;
180         final double t5r = t1r + t2r;
181         final double t5i = t1i + t2i;
182         final double t6r = tau * (t1r - t2r);
183         final double t6i = tau * (t1i - t2i);
184         final double t7r = fma(-0.25, t5r, z0r);
185         final double t7i = fma(-0.25, t5i, z0i);
186         final double t8r = t7r + t6r;
187         final double t8i = t7i + t6i;
188         final double t9r = t7r - t6r;
189         final double t9i = t7i - t6i;
190         final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
191         final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
192         final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
193         final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
194         ret[j] = z0r + t5r;
195         ret[j + im] = z0i + t5i;
196 
197         multiplyAndStore(t8r - t10i, t8i + t10r, w1r, w1i, ret, j + dj, j + dj + im);
198         multiplyAndStore(t9r - t11i, t9i + t11r, w2r, w2i, ret, j + dj2, j + dj2 + im);
199         multiplyAndStore(t9r + t11i, t9i - t11r, w3r, w3i, ret, j + dj3, j + dj3 + im);
200         multiplyAndStore(t8r + t10i, t8i - t10r, w4r, w4i, ret, j + dj4, j + dj4 + im);
201       }
202     }
203   }
204 
205   /**
206    * Handle factors of 5 using SIMD vectors.
207    *
208    * @param passData PassData.
209    */
210   @Override
211   protected void passSIMD(PassData passData) {
212     if (im == 1) {
213       // If the inner loop limit is not divisible by the loop increment, use the scalar method.
214       if (innerLoopLimit % INTERLEAVED_LOOP != 0) {
215         passScalar(passData);
216       } else {
217         interleaved(passData);
218       }
219     } else {
220       // If the inner loop limit is not divisible by the loop increment, use the scalar method.
221       if (innerLoopLimit % BLOCK_LOOP != 0) {
222         passScalar(passData);
223       } else {
224         blocked(passData);
225       }
226     }
227   }
228 
229   /**
230    * Handle factors of 5.
231    *
232    * @param passData PassData.
233    */
234   protected void interleaved(PassData passData) {
235     final double[] data = passData.in;
236     final double[] ret = passData.out;
237     int sign = passData.sign;
238     int i = passData.inOffset;
239     int j = passData.outOffset;
240     final double sin2PI_5s = sign * sin2PI_5;
241     final double sinPI_5s = sign * sinPI_5;
242     // First pass of the 5-point FFT has no twiddle factors.
243     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
244       DoubleVector
245           z0 = fromArray(DOUBLE_SPECIES, data, i),
246           z1 = fromArray(DOUBLE_SPECIES, data, i + di),
247           z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
248           z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
249           z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
250       DoubleVector
251           t1 = z1.add(z4),
252           t2 = z2.add(z3),
253           t3 = z1.sub(z4),
254           t4 = z2.sub(z3),
255           t5 = t1.add(t2),
256           t6 = t1.sub(t2).mul(tau),
257           t7 = t5.mul(-0.25).add(z0),
258           t8 = t7.add(t6),
259           t9 = t7.sub(t6),
260           t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
261           t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
262       z0.add(t5).intoArray(ret, j);
263       t8.add(t10.mul(NEGATE_RE)).intoArray(ret, j + dj);
264       t9.add(t11.mul(NEGATE_RE)).intoArray(ret, j + dj2);
265       t9.add(t11.mul(NEGATE_IM)).intoArray(ret, j + dj3);
266       t8.add(t10.mul(NEGATE_IM)).intoArray(ret, j + dj4);
267     }
268 
269     j += jstep;
270     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
271 //      final double[] twids = twiddles[k];
272 //      DoubleVector
273 //          w1r = broadcast(DOUBLE_SPECIES, twids[0]),
274 //          w1i = broadcast(DOUBLE_SPECIES, -sign * twids[1]).mul(NEGATE_IM),
275 //          w2r = broadcast(DOUBLE_SPECIES, twids[2]),
276 //          w2i = broadcast(DOUBLE_SPECIES, -sign * twids[3]).mul(NEGATE_IM),
277 //          w3r = broadcast(DOUBLE_SPECIES, twids[4]),
278 //          w3i = broadcast(DOUBLE_SPECIES, -sign * twids[5]).mul(NEGATE_IM),
279 //          w4r = broadcast(DOUBLE_SPECIES, twids[6]),
280 //          w4i = broadcast(DOUBLE_SPECIES, -sign * twids[7]).mul(NEGATE_IM);
281       final int index = k * 4;
282       final DoubleVector
283           w1r = broadcast(DOUBLE_SPECIES, wr[index]),
284           w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
285           w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
286           w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
287           w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]).mul(NEGATE_IM),
288           w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]).mul(NEGATE_IM),
289           w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]).mul(NEGATE_IM),
290           w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]).mul(NEGATE_IM);
291       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
292         DoubleVector
293             z0 = fromArray(DOUBLE_SPECIES, data, i),
294             z1 = fromArray(DOUBLE_SPECIES, data, i + di),
295             z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
296             z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
297             z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
298         DoubleVector
299             t1 = z1.add(z4),
300             t2 = z2.add(z3),
301             t3 = z1.sub(z4),
302             t4 = z2.sub(z3),
303             t5 = t1.add(t2),
304             t6 = t1.sub(t2).mul(tau),
305             t7 = t5.mul(-0.25).add(z0),
306             t8 = t7.add(t6),
307             t9 = t7.sub(t6),
308             t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
309             t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
310         z0.add(t5).intoArray(ret, j);
311         DoubleVector
312             x1 = t8.add(t10.mul(NEGATE_RE)),
313             x2 = t9.add(t11.mul(NEGATE_RE)),
314             x3 = t9.add(t11.mul(NEGATE_IM)),
315             x4 = t8.add(t10.mul(NEGATE_IM));
316         w1r.mul(x1).add(w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
317         w2r.mul(x2).add(w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
318         w3r.mul(x3).add(w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
319         w4r.mul(x4).add(w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
320       }
321     }
322   }
323 
324   /**
325    * Handle factors of 5.
326    *
327    * @param passData PassData.
328    */
329   protected void blocked(PassData passData) {
330     final double[] data = passData.in;
331     final double[] ret = passData.out;
332     int sign = passData.sign;
333     int i = passData.inOffset;
334     int j = passData.outOffset;
335     final double sin2PI_5s = sign * sin2PI_5;
336     final double sinPI_5s = sign * sinPI_5;
337     // First pass of the 5-point FFT has no twiddle factors.
338     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
339       final DoubleVector
340           z0r = fromArray(DOUBLE_SPECIES, data, i),
341           z1r = fromArray(DOUBLE_SPECIES, data, i + di),
342           z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
343           z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
344           z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
345           z0i = fromArray(DOUBLE_SPECIES, data, i + im),
346           z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
347           z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
348           z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
349           z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
350       final DoubleVector
351           t1r = z1r.add(z4r),
352           t1i = z1i.add(z4i),
353           t2r = z2r.add(z3r),
354           t2i = z2i.add(z3i),
355           t3r = z1r.sub(z4r),
356           t3i = z1i.sub(z4i),
357           t4r = z2r.sub(z3r),
358           t4i = z2i.sub(z3i),
359           t5r = t1r.add(t2r),
360           t5i = t1i.add(t2i),
361           t6r = t1r.sub(t2r).mul(tau),
362           t6i = t1i.sub(t2i).mul(tau),
363           t7r = t5r.mul(-0.25).add(z0r),
364           t7i = t5i.mul(-0.25).add(z0i),
365           t8r = t7r.add(t6r),
366           t8i = t7i.add(t6i),
367           t9r = t7r.sub(t6r),
368           t9i = t7i.sub(t6i),
369           t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
370           t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
371           t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
372           t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
373       z0r.add(t5r).intoArray(ret, j);
374       z0i.add(t5i).intoArray(ret, j + im);
375       t8r.sub(t10i).intoArray(ret, j + dj);
376       t8i.add(t10r).intoArray(ret, j + dj + im);
377       t9r.sub(t11i).intoArray(ret, j + dj2);
378       t9i.add(t11r).intoArray(ret, j + dj2 + im);
379       t9r.add(t11i).intoArray(ret, j + dj3);
380       t9i.sub(t11r).intoArray(ret, j + dj3 + im);
381       t8r.add(t10i).intoArray(ret, j + dj4);
382       t8i.sub(t10r).intoArray(ret, j + dj4 + im);
383     }
384     j += jstep;
385     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
386 //      final double[] twids = twiddles[k];
387 //      DoubleVector
388 //          w1r = broadcast(DOUBLE_SPECIES, twids[0]),
389 //          w1i = broadcast(DOUBLE_SPECIES, -sign * twids[1]),
390 //          w2r = broadcast(DOUBLE_SPECIES, twids[2]),
391 //          w2i = broadcast(DOUBLE_SPECIES, -sign * twids[3]),
392 //          w3r = broadcast(DOUBLE_SPECIES, twids[4]),
393 //          w3i = broadcast(DOUBLE_SPECIES, -sign * twids[5]),
394 //          w4r = broadcast(DOUBLE_SPECIES, twids[6]),
395 //          w4i = broadcast(DOUBLE_SPECIES, -sign * twids[7]);
396       final int index = k * 4;
397       final DoubleVector
398           w1r = broadcast(DOUBLE_SPECIES, wr[index]),
399           w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
400           w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
401           w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
402           w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]),
403           w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]),
404           w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]),
405           w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]);
406       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
407         final DoubleVector
408             z0r = fromArray(DOUBLE_SPECIES, data, i),
409             z1r = fromArray(DOUBLE_SPECIES, data, i + di),
410             z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
411             z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
412             z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
413             z0i = fromArray(DOUBLE_SPECIES, data, i + im),
414             z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
415             z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
416             z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
417             z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
418         final DoubleVector
419             t1r = z1r.add(z4r),
420             t1i = z1i.add(z4i),
421             t2r = z2r.add(z3r),
422             t2i = z2i.add(z3i),
423             t3r = z1r.sub(z4r),
424             t3i = z1i.sub(z4i),
425             t4r = z2r.sub(z3r),
426             t4i = z2i.sub(z3i),
427             t5r = t1r.add(t2r),
428             t5i = t1i.add(t2i),
429             t6r = t1r.sub(t2r).mul(tau),
430             t6i = t1i.sub(t2i).mul(tau),
431             t7r = t5r.mul(-0.25).add(z0r),
432             t7i = t5i.mul(-0.25).add(z0i),
433             t8r = t7r.add(t6r),
434             t8i = t7i.add(t6i),
435             t9r = t7r.sub(t6r),
436             t9i = t7i.sub(t6i),
437             t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
438             t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
439             t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
440             t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
441         z0r.add(t5r).intoArray(ret, j);
442         z0i.add(t5i).intoArray(ret, j + im);
443         DoubleVector
444             x1r = t8r.sub(t10i), x1i = t8i.add(t10r),
445             x2r = t9r.sub(t11i), x2i = t9i.add(t11r),
446             x3r = t9r.add(t11i), x3i = t9i.sub(t11r),
447             x4r = t8r.add(t10i), x4i = t8i.sub(t10r);
448         w1r.mul(x1r).add(w1i.neg().mul(x1i)).intoArray(ret, j + dj);
449         w2r.mul(x2r).add(w2i.neg().mul(x2i)).intoArray(ret, j + dj2);
450         w3r.mul(x3r).add(w3i.neg().mul(x3i)).intoArray(ret, j + dj3);
451         w4r.mul(x4r).add(w4i.neg().mul(x4i)).intoArray(ret, j + dj4);
452         w1r.mul(x1i).add(w1i.mul(x1r)).intoArray(ret, j + dj + im);
453         w2r.mul(x2i).add(w2i.mul(x2r)).intoArray(ret, j + dj2 + im);
454         w3r.mul(x3i).add(w3i.mul(x3r)).intoArray(ret, j + dj3 + im);
455         w4r.mul(x4i).add(w4i.mul(x4r)).intoArray(ret, j + dj4 + im);
456       }
457     }
458   }
459 
460 }