View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  
42  import static java.lang.Math.fma;
43  import static jdk.incubator.vector.DoubleVector.broadcast;
44  import static jdk.incubator.vector.DoubleVector.fromArray;
45  import static org.apache.commons.math3.util.FastMath.sqrt;
46  
47  /**
48   * The MixedRadixFactor6 class handles factors of 6 in the FFT.
49   */
50  public class MixedRadixFactor6 extends MixedRadixFactor {
51  
52    private static final double sqrt3_2 = sqrt(3.0) / 2.0;
53  
54    private final int di2;
55    private final int di3;
56    private final int di4;
57    private final int di5;
58    private final int dj2;
59    private final int dj3;
60    private final int dj4;
61    private final int dj5;
62  
63    /**
64     * Construct a MixedRadixFactor6.
65     *
66     * @param passConstants PassConstants.
67     */
68    public MixedRadixFactor6(PassConstants passConstants) {
69      super(passConstants);
70      di2 = 2 * di;
71      di3 = 3 * di;
72      di4 = 4 * di;
73      di5 = 5 * di;
74      dj2 = 2 * dj;
75      dj3 = 3 * dj;
76      dj4 = 4 * dj;
77      dj5 = 5 * dj;
78    }
79  
80    /**
81     * Handle factors of 6.
82     *
83     * @param passData PassData.
84     */
85    @Override
86    protected void passScalar(PassData passData) {
87      final double[] data = passData.in;
88      final double[] ret = passData.out;
89      int sign = passData.sign;
90      int i = passData.inOffset;
91      int j = passData.outOffset;
92      final double tau = sign * sqrt3_2;
93      // First pass of the 6-point FFT has no twiddle factors.
94      for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
95        final double z0r = data[i];
96        final double z1r = data[i + di];
97        final double z2r = data[i + di2];
98        final double z3r = data[i + di3];
99        final double z4r = data[i + di4];
100       final double z5r = data[i + di5];
101       final double z0i = data[i + im];
102       final double z1i = data[i + di + im];
103       final double z2i = data[i + di2 + im];
104       final double z3i = data[i + di3 + im];
105       final double z4i = data[i + di4 + im];
106       final double z5i = data[i + di5 + im];
107       final double ta1r = z2r + z4r;
108       final double ta1i = z2i + z4i;
109       final double ta2r = fma(-0.5, ta1r, z0r);
110       final double ta2i = fma(-0.5, ta1i, z0i);
111       final double ta3r = tau * (z2r - z4r);
112       final double ta3i = tau * (z2i - z4i);
113       final double a0r = z0r + ta1r;
114       final double a0i = z0i + ta1i;
115       final double a1r = ta2r - ta3i;
116       final double a1i = ta2i + ta3r;
117       final double a2r = ta2r + ta3i;
118       final double a2i = ta2i - ta3r;
119       final double tb1r = z5r + z1r;
120       final double tb1i = z5i + z1i;
121       final double tb2r = fma(-0.5, tb1r, z3r);
122       final double tb2i = fma(-0.5, tb1i, z3i);
123       final double tb3r = tau * (z5r - z1r);
124       final double tb3i = tau * (z5i - z1i);
125       final double b0r = z3r + tb1r;
126       final double b0i = z3i + tb1i;
127       final double b1r = tb2r - tb3i;
128       final double b1i = tb2i + tb3r;
129       final double b2r = tb2r + tb3i;
130       final double b2i = tb2i - tb3r;
131       ret[j] = a0r + b0r;
132       ret[j + im] = a0i + b0i;
133       ret[j + dj] = a1r - b1r;
134       ret[j + dj + im] = a1i - b1i;
135       ret[j + dj2] = a2r + b2r;
136       ret[j + dj2 + im] = a2i + b2i;
137       ret[j + dj3] = a0r - b0r;
138       ret[j + dj3 + im] = a0i - b0i;
139       ret[j + dj4] = a1r + b1r;
140       ret[j + dj4 + im] = a1i + b1i;
141       ret[j + dj5] = a2r - b2r;
142       ret[j + dj5 + im] = a2i - b2i;
143     }
144 
145     j += jstep;
146     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
147       final double[] twids = twiddles[k];
148       final double w1r = twids[0];
149       final double w1i = -sign * twids[1];
150       final double w2r = twids[2];
151       final double w2i = -sign * twids[3];
152       final double w3r = twids[4];
153       final double w3i = -sign * twids[5];
154       final double w4r = twids[6];
155       final double w4i = -sign * twids[7];
156       final double w5r = twids[8];
157       final double w5i = -sign * twids[9];
158       for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
159         final double z0r = data[i];
160         final double z1r = data[i + di];
161         final double z2r = data[i + di2];
162         final double z3r = data[i + di3];
163         final double z4r = data[i + di4];
164         final double z5r = data[i + di5];
165         final double z0i = data[i + im];
166         final double z1i = data[i + di + im];
167         final double z2i = data[i + di2 + im];
168         final double z3i = data[i + di3 + im];
169         final double z4i = data[i + di4 + im];
170         final double z5i = data[i + di5 + im];
171         final double ta1r = z2r + z4r;
172         final double ta1i = z2i + z4i;
173         final double ta2r = fma(-0.5, ta1r, z0r);
174         final double ta2i = fma(-0.5, ta1i, z0i);
175         final double ta3r = tau * (z2r - z4r);
176         final double ta3i = tau * (z2i - z4i);
177         final double a0r = z0r + ta1r;
178         final double a0i = z0i + ta1i;
179         final double a1r = ta2r - ta3i;
180         final double a1i = ta2i + ta3r;
181         final double a2r = ta2r + ta3i;
182         final double a2i = ta2i - ta3r;
183         final double tb1r = z5r + z1r;
184         final double tb1i = z5i + z1i;
185         final double tb2r = fma(-0.5, tb1r, z3r);
186         final double tb2i = fma(-0.5, tb1i, z3i);
187         final double tb3r = tau * (z5r - z1r);
188         final double tb3i = tau * (z5i - z1i);
189         final double b0r = z3r + tb1r;
190         final double b0i = z3i + tb1i;
191         final double b1r = tb2r - tb3i;
192         final double b1i = tb2i + tb3r;
193         final double b2r = tb2r + tb3i;
194         final double b2i = tb2i - tb3r;
195         ret[j] = a0r + b0r;
196         ret[j + im] = a0i + b0i;
197         multiplyAndStore(a1r - b1r, a1i - b1i, w1r, w1i, ret, j + dj, j + dj + im);
198         multiplyAndStore(a2r + b2r, a2i + b2i, w2r, w2i, ret, j + dj2, j + dj2 + im);
199         multiplyAndStore(a0r - b0r, a0i - b0i, w3r, w3i, ret, j + dj3, j + dj3 + im);
200         multiplyAndStore(a1r + b1r, a1i + b1i, w4r, w4i, ret, j + dj4, j + dj4 + im);
201         multiplyAndStore(a2r - b2r, a2i - b2i, w5r, w5i, ret, j + dj5, j + dj5 + im);
202       }
203     }
204   }
205 
206   /**
207    * Handle factors of 6 using SIMD vectors.
208    *
209    * @param passData PassData.
210    */
211   @Override
212   protected void passSIMD(PassData passData) {
213     // Interleaved.
214     if (im == 1) {
215       // If the inner loop limit is not divisible by the loop increment, use the scalar method.
216       if (innerLoopLimit % LOOP != 0) {
217         passScalar(passData);
218       } else {
219         interleaved(passData);
220       }
221       // Blocked.
222     } else {
223       // If the inner loop limit is not divisible by the loop increment, use the scalar method.
224       if (innerLoopLimit % BLOCK_LOOP != 0) {
225         passScalar(passData);
226       } else {
227         blocked(passData);
228       }
229     }
230   }
231 
232   /**
233    * Handle factors of 6.
234    *
235    * @param passData PassData.
236    */
237   private void interleaved(PassData passData) {
238     final double[] data = passData.in;
239     final double[] ret = passData.out;
240     int sign = passData.sign;
241     int i = passData.inOffset;
242     int j = passData.outOffset;
243     final double tau = sign * sqrt3_2;
244     // First pass of the 6-point FFT has no twiddle factors.
245     for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP, i += LENGTH, j += LENGTH) {
246       DoubleVector
247           z0 = fromArray(DOUBLE_SPECIES, data, i),
248           z1 = fromArray(DOUBLE_SPECIES, data, i + di),
249           z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
250           z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
251           z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
252           z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
253       DoubleVector
254           ta1 = z2.add(z4),
255           ta2 = ta1.mul(-0.5).add(z0),
256           ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
257           a0 = z0.add(ta1),
258           a1 = ta2.add(ta3.mul(NEGATE_RE)),
259           a2 = ta2.add(ta3.mul(NEGATE_IM)),
260           tb1 = z5.add(z1),
261           tb2 = tb1.mul(-0.5).add(z3),
262           tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
263           b0 = z3.add(tb1),
264           b1 = tb2.add(tb3.mul(NEGATE_RE)),
265           b2 = tb2.add(tb3.mul(NEGATE_IM));
266       a0.add(b0).intoArray(ret, j);
267       a1.sub(b1).intoArray(ret, j + dj);
268       a2.add(b2).intoArray(ret, j + dj2);
269       a0.sub(b0).intoArray(ret, j + dj3);
270       a1.add(b1).intoArray(ret, j + dj4);
271       a2.sub(b2).intoArray(ret, j + dj5);
272     }
273 
274     j += jstep;
275     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
276       final double[] twids = twiddles[k];
277       DoubleVector
278           w1r = broadcast(DOUBLE_SPECIES, twids[0]),
279           w1i = broadcast(DOUBLE_SPECIES, -sign * twids[1]).mul(NEGATE_IM),
280           w2r = broadcast(DOUBLE_SPECIES, twids[2]),
281           w2i = broadcast(DOUBLE_SPECIES, -sign * twids[3]).mul(NEGATE_IM),
282           w3r = broadcast(DOUBLE_SPECIES, twids[4]),
283           w3i = broadcast(DOUBLE_SPECIES, -sign * twids[5]).mul(NEGATE_IM),
284           w4r = broadcast(DOUBLE_SPECIES, twids[6]),
285           w4i = broadcast(DOUBLE_SPECIES, -sign * twids[7]).mul(NEGATE_IM),
286           w5r = broadcast(DOUBLE_SPECIES, twids[8]),
287           w5i = broadcast(DOUBLE_SPECIES, -sign * twids[9]).mul(NEGATE_IM);
288       for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP, i += LENGTH, j += LENGTH) {
289         DoubleVector
290             z0 = fromArray(DOUBLE_SPECIES, data, i),
291             z1 = fromArray(DOUBLE_SPECIES, data, i + di),
292             z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
293             z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
294             z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
295             z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
296         DoubleVector
297             ta1 = z2.add(z4),
298             ta2 = ta1.mul(-0.5).add(z0),
299             ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
300             a0 = z0.add(ta1),
301             a1 = ta2.add(ta3.mul(NEGATE_RE)),
302             a2 = ta2.add(ta3.mul(NEGATE_IM)),
303             tb1 = z5.add(z1),
304             tb2 = tb1.mul(-0.5).add(z3),
305             tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
306             b0 = z3.add(tb1),
307             b1 = tb2.add(tb3.mul(NEGATE_RE)),
308             b2 = tb2.add(tb3.mul(NEGATE_IM));
309         a0.add(b0).intoArray(ret, j);
310         DoubleVector
311             x1 = a1.sub(b1),
312             x2 = a2.add(b2),
313             x3 = a0.sub(b0),
314             x4 = a1.add(b1),
315             x5 = a2.sub(b2);
316         w1r.mul(x1).add(w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
317         w2r.mul(x2).add(w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
318         w3r.mul(x3).add(w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
319         w4r.mul(x4).add(w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
320         w5r.mul(x5).add(w5i.mul(x5).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj5);
321       }
322     }
323   }
324 
325   /**
326    * Handle factors of 6.
327    *
328    * @param passData PassData.
329    */
330   private void blocked(PassData passData) {
331     final double[] data = passData.in;
332     final double[] ret = passData.out;
333     int sign = passData.sign;
334     int i = passData.inOffset;
335     int j = passData.outOffset;
336     final double tau = sign * sqrt3_2;
337     // First pass of the 6-point FFT has no twiddle factors.
338     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
339       final DoubleVector
340           z0r = fromArray(DOUBLE_SPECIES, data, i),
341           z1r = fromArray(DOUBLE_SPECIES, data, i + di),
342           z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
343           z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
344           z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
345           z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
346           z0i = fromArray(DOUBLE_SPECIES, data, i + im),
347           z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
348           z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
349           z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
350           z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
351           z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
352       final DoubleVector
353           ta1r = z2r.add(z4r),
354           ta1i = z2i.add(z4i),
355           ta2r = ta1r.mul(-0.5).add(z0r),
356           ta2i = ta1i.mul(-0.5).add(z0i),
357           ta3r = z2r.sub(z4r).mul(tau),
358           ta3i = z2i.sub(z4i).mul(tau),
359           a0r = z0r.add(ta1r),
360           a0i = z0i.add(ta1i),
361           a1r = ta2r.sub(ta3i),
362           a1i = ta2i.add(ta3r),
363           a2r = ta2r.add(ta3i),
364           a2i = ta2i.sub(ta3r),
365           tb1r = z5r.add(z1r),
366           tb1i = z5i.add(z1i),
367           tb2r = tb1r.mul(-0.5).add(z3r),
368           tb2i = tb1i.mul(-0.5).add(z3i),
369           tb3r = z5r.sub(z1r).mul(tau),
370           tb3i = z5i.sub(z1i).mul(tau),
371           b0r = z3r.add(tb1r),
372           b0i = z3i.add(tb1i),
373           b1r = tb2r.sub(tb3i),
374           b1i = tb2i.add(tb3r),
375           b2r = tb2r.add(tb3i),
376           b2i = tb2i.sub(tb3r);
377       a0r.add(b0r).intoArray(ret, j);
378       a0i.add(b0i).intoArray(ret, j + im);
379       a1r.sub(b1r).intoArray(ret, j + dj);
380       a1i.sub(b1i).intoArray(ret, j + dj + im);
381       a2r.add(b2r).intoArray(ret, j + dj2);
382       a2i.add(b2i).intoArray(ret, j + dj2 + im);
383       a0r.sub(b0r).intoArray(ret, j + dj3);
384       a0i.sub(b0i).intoArray(ret, j + dj3 + im);
385       a1r.add(b1r).intoArray(ret, j + dj4);
386       a1i.add(b1i).intoArray(ret, j + dj4 + im);
387       a2r.sub(b2r).intoArray(ret, j + dj5);
388       a2i.sub(b2i).intoArray(ret, j + dj5 + im);
389     }
390 
391     j += jstep;
392     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
393       final double[] twids = twiddles[k];
394       DoubleVector
395           w1r = broadcast(DOUBLE_SPECIES, twids[0]),
396           w1i = broadcast(DOUBLE_SPECIES, -sign * twids[1]),
397           w2r = broadcast(DOUBLE_SPECIES, twids[2]),
398           w2i = broadcast(DOUBLE_SPECIES, -sign * twids[3]),
399           w3r = broadcast(DOUBLE_SPECIES, twids[4]),
400           w3i = broadcast(DOUBLE_SPECIES, -sign * twids[5]),
401           w4r = broadcast(DOUBLE_SPECIES, twids[6]),
402           w4i = broadcast(DOUBLE_SPECIES, -sign * twids[7]),
403           w5r = broadcast(DOUBLE_SPECIES, twids[8]),
404           w5i = broadcast(DOUBLE_SPECIES, -sign * twids[9]);
405       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
406         final DoubleVector
407             z0r = fromArray(DOUBLE_SPECIES, data, i),
408             z1r = fromArray(DOUBLE_SPECIES, data, i + di),
409             z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
410             z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
411             z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
412             z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
413             z0i = fromArray(DOUBLE_SPECIES, data, i + im),
414             z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
415             z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
416             z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
417             z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
418             z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
419         final DoubleVector
420             ta1r = z2r.add(z4r),
421             ta1i = z2i.add(z4i),
422             ta2r = ta1r.mul(-0.5).add(z0r),
423             ta2i = ta1i.mul(-0.5).add(z0i),
424             ta3r = z2r.sub(z4r).mul(tau),
425             ta3i = z2i.sub(z4i).mul(tau),
426             a0r = z0r.add(ta1r),
427             a0i = z0i.add(ta1i),
428             a1r = ta2r.sub(ta3i),
429             a1i = ta2i.add(ta3r),
430             a2r = ta2r.add(ta3i),
431             a2i = ta2i.sub(ta3r),
432             tb1r = z5r.add(z1r),
433             tb1i = z5i.add(z1i),
434             tb2r = tb1r.mul(-0.5).add(z3r),
435             tb2i = tb1i.mul(-0.5).add(z3i),
436             tb3r = z5r.sub(z1r).mul(tau),
437             tb3i = z5i.sub(z1i).mul(tau),
438             b0r = z3r.add(tb1r),
439             b0i = z3i.add(tb1i),
440             b1r = tb2r.sub(tb3i),
441             b1i = tb2i.add(tb3r),
442             b2r = tb2r.add(tb3i),
443             b2i = tb2i.sub(tb3r);
444         a0r.add(b0r).intoArray(ret, j);
445         a0i.add(b0i).intoArray(ret, j + im);
446         DoubleVector
447             x1r = a1r.sub(b1r), x1i = a1i.sub(b1i),
448             x2r = a2r.add(b2r), x2i = a2i.add(b2i),
449             x3r = a0r.sub(b0r), x3i = a0i.sub(b0i),
450             x4r = a1r.add(b1r), x4i = a1i.add(b1i),
451             x5r = a2r.sub(b2r), x5i = a2i.sub(b2i);
452         w1r.mul(x1r).add(w1i.neg().mul(x1i)).intoArray(ret, j + dj);
453         w2r.mul(x2r).add(w2i.neg().mul(x2i)).intoArray(ret, j + dj2);
454         w3r.mul(x3r).add(w3i.neg().mul(x3i)).intoArray(ret, j + dj3);
455         w4r.mul(x4r).add(w4i.neg().mul(x4i)).intoArray(ret, j + dj4);
456         w5r.mul(x5r).add(w5i.neg().mul(x5i)).intoArray(ret, j + dj5);
457         w1r.mul(x1i).add(w1i.mul(x1r)).intoArray(ret, j + dj + im);
458         w2r.mul(x2i).add(w2i.mul(x2r)).intoArray(ret, j + dj2 + im);
459         w3r.mul(x3i).add(w3i.mul(x3r)).intoArray(ret, j + dj3 + im);
460         w4r.mul(x4i).add(w4i.mul(x4r)).intoArray(ret, j + dj4 + im);
461         w5r.mul(x5i).add(w5i.mul(x5r)).intoArray(ret, j + dj5 + im);
462       }
463     }
464   }
465 }