View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2026.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.numerics.fft;
39  
40  import jdk.incubator.vector.DoubleVector;
41  
42  import static java.lang.Math.fma;
43  import static jdk.incubator.vector.DoubleVector.broadcast;
44  import static jdk.incubator.vector.DoubleVector.fromArray;
45  import static org.apache.commons.math3.util.FastMath.sqrt;
46  
47  /**
48   * The MixedRadixFactor6 class handles factors of 6 in the FFT.
49   */
50  public class MixedRadixFactor6 extends MixedRadixFactor {
51  
52    private static final double sqrt3_2 = sqrt(3.0) / 2.0;
53  
54    private final int di2;
55    private final int di3;
56    private final int di4;
57    private final int di5;
58    private final int dj2;
59    private final int dj3;
60    private final int dj4;
61    private final int dj5;
62  
63    /**
64     * Construct a MixedRadixFactor6.
65     *
66     * @param passConstants PassConstants.
67     */
68    public MixedRadixFactor6(PassConstants passConstants) {
69      super(passConstants);
70      di2 = 2 * di;
71      di3 = 3 * di;
72      di4 = 4 * di;
73      di5 = 5 * di;
74      dj2 = 2 * dj;
75      dj3 = 3 * dj;
76      dj4 = 4 * dj;
77      dj5 = 5 * dj;
78    }
79  
80    /**
81     * Handle factors of 6.
82     *
83     * @param passData PassData.
84     */
85    @Override
86    protected void passScalar(PassData passData) {
87      final double[] data = passData.in;
88      final double[] ret = passData.out;
89      int sign = passData.sign;
90      int i = passData.inOffset;
91      int j = passData.outOffset;
92      final double tau = sign * sqrt3_2;
93      // First pass of the 6-point FFT has no twiddle factors.
94      for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
95        final double z0r = data[i];
96        final double z1r = data[i + di];
97        final double z2r = data[i + di2];
98        final double z3r = data[i + di3];
99        final double z4r = data[i + di4];
100       final double z5r = data[i + di5];
101       final double z0i = data[i + im];
102       final double z1i = data[i + di + im];
103       final double z2i = data[i + di2 + im];
104       final double z3i = data[i + di3 + im];
105       final double z4i = data[i + di4 + im];
106       final double z5i = data[i + di5 + im];
107       final double ta1r = z2r + z4r;
108       final double ta1i = z2i + z4i;
109       final double ta2r = fma(-0.5, ta1r, z0r);
110       final double ta2i = fma(-0.5, ta1i, z0i);
111       final double ta3r = tau * (z2r - z4r);
112       final double ta3i = tau * (z2i - z4i);
113       final double a0r = z0r + ta1r;
114       final double a0i = z0i + ta1i;
115       final double a1r = ta2r - ta3i;
116       final double a1i = ta2i + ta3r;
117       final double a2r = ta2r + ta3i;
118       final double a2i = ta2i - ta3r;
119       final double tb1r = z5r + z1r;
120       final double tb1i = z5i + z1i;
121       final double tb2r = fma(-0.5, tb1r, z3r);
122       final double tb2i = fma(-0.5, tb1i, z3i);
123       final double tb3r = tau * (z5r - z1r);
124       final double tb3i = tau * (z5i - z1i);
125       final double b0r = z3r + tb1r;
126       final double b0i = z3i + tb1i;
127       final double b1r = tb2r - tb3i;
128       final double b1i = tb2i + tb3r;
129       final double b2r = tb2r + tb3i;
130       final double b2i = tb2i - tb3r;
131       ret[j] = a0r + b0r;
132       ret[j + im] = a0i + b0i;
133       ret[j + dj] = a1r - b1r;
134       ret[j + dj + im] = a1i - b1i;
135       ret[j + dj2] = a2r + b2r;
136       ret[j + dj2 + im] = a2i + b2i;
137       ret[j + dj3] = a0r - b0r;
138       ret[j + dj3 + im] = a0i - b0i;
139       ret[j + dj4] = a1r + b1r;
140       ret[j + dj4 + im] = a1i + b1i;
141       ret[j + dj5] = a2r - b2r;
142       ret[j + dj5 + im] = a2i - b2i;
143     }
144 
145     j += jstep;
146     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
147       final int index = k * 5;
148       final double w1r = wr[index];
149       final double w2r = wr[index + 1];
150       final double w3r = wr[index + 2];
151       final double w4r = wr[index + 3];
152       final double w5r = wr[index + 4];
153       final double w1i = -sign * wi[index];
154       final double w2i = -sign * wi[index + 1];
155       final double w3i = -sign * wi[index + 2];
156       final double w4i = -sign * wi[index + 3];
157       final double w5i = -sign * wi[index + 4];
158       for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
159         final double z0r = data[i];
160         final double z1r = data[i + di];
161         final double z2r = data[i + di2];
162         final double z3r = data[i + di3];
163         final double z4r = data[i + di4];
164         final double z5r = data[i + di5];
165         final double z0i = data[i + im];
166         final double z1i = data[i + di + im];
167         final double z2i = data[i + di2 + im];
168         final double z3i = data[i + di3 + im];
169         final double z4i = data[i + di4 + im];
170         final double z5i = data[i + di5 + im];
171         final double ta1r = z2r + z4r;
172         final double ta1i = z2i + z4i;
173         final double ta2r = fma(-0.5, ta1r, z0r);
174         final double ta2i = fma(-0.5, ta1i, z0i);
175         final double ta3r = tau * (z2r - z4r);
176         final double ta3i = tau * (z2i - z4i);
177         final double a0r = z0r + ta1r;
178         final double a0i = z0i + ta1i;
179         final double a1r = ta2r - ta3i;
180         final double a1i = ta2i + ta3r;
181         final double a2r = ta2r + ta3i;
182         final double a2i = ta2i - ta3r;
183         final double tb1r = z5r + z1r;
184         final double tb1i = z5i + z1i;
185         final double tb2r = fma(-0.5, tb1r, z3r);
186         final double tb2i = fma(-0.5, tb1i, z3i);
187         final double tb3r = tau * (z5r - z1r);
188         final double tb3i = tau * (z5i - z1i);
189         final double b0r = z3r + tb1r;
190         final double b0i = z3i + tb1i;
191         final double b1r = tb2r - tb3i;
192         final double b1i = tb2i + tb3r;
193         final double b2r = tb2r + tb3i;
194         final double b2i = tb2i - tb3r;
195         ret[j] = a0r + b0r;
196         ret[j + im] = a0i + b0i;
197         multiplyAndStore(a1r - b1r, a1i - b1i, w1r, w1i, ret, j + dj, j + dj + im);
198         multiplyAndStore(a2r + b2r, a2i + b2i, w2r, w2i, ret, j + dj2, j + dj2 + im);
199         multiplyAndStore(a0r - b0r, a0i - b0i, w3r, w3i, ret, j + dj3, j + dj3 + im);
200         multiplyAndStore(a1r + b1r, a1i + b1i, w4r, w4i, ret, j + dj4, j + dj4 + im);
201         multiplyAndStore(a2r - b2r, a2i - b2i, w5r, w5i, ret, j + dj5, j + dj5 + im);
202       }
203     }
204   }
205 
206   /**
207    * Handle factors of 6 using SIMD vectors.
208    *
209    * @param passData PassData.
210    */
211   @Override
212   protected void passSIMD(PassData passData) {
213     if (!isValidSIMDWidth(simdWidth)) {
214       passScalar(passData);
215     } else {
216       if (im == 1) {
217         interleaved(passData);
218       } else {
219         blocked(passData);
220       }
221     }
222   }
223 
224   /**
225    * Handle factors of 6.
226    *
227    * @param passData PassData.
228    */
229   private void interleaved(PassData passData) {
230     final double[] data = passData.in;
231     final double[] ret = passData.out;
232     int sign = passData.sign;
233     int i = passData.inOffset;
234     int j = passData.outOffset;
235     final double tau = sign * sqrt3_2;
236     // First pass of the 6-point FFT has no twiddle factors.
237     for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
238       final DoubleVector
239           z0 = fromArray(DOUBLE_SPECIES, data, i),
240           z1 = fromArray(DOUBLE_SPECIES, data, i + di),
241           z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
242           z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
243           z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
244           z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
245       final DoubleVector
246           ta1 = z2.add(z4),
247           ta2 = ta1.mul(-0.5).add(z0),
248           ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
249           a0 = z0.add(ta1),
250           a1 = ta3.fma(NEGATE_RE, ta2),
251           a2 = ta3.fma(NEGATE_IM, ta2),
252           tb1 = z5.add(z1),
253           tb2 = tb1.mul(-0.5).add(z3),
254           tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
255           b0 = z3.add(tb1),
256           b1 = tb3.fma(NEGATE_RE, tb2),
257           b2 = tb3.fma(NEGATE_IM, tb2);
258       a0.add(b0).intoArray(ret, j);
259       a1.sub(b1).intoArray(ret, j + dj);
260       a2.add(b2).intoArray(ret, j + dj2);
261       a0.sub(b0).intoArray(ret, j + dj3);
262       a1.add(b1).intoArray(ret, j + dj4);
263       a2.sub(b2).intoArray(ret, j + dj5);
264     }
265 
266     j += jstep;
267     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
268       final int index = k * 5;
269       final DoubleVector
270           w1r = broadcast(DOUBLE_SPECIES, wr[index]),
271           w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
272           w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
273           w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
274           w5r = broadcast(DOUBLE_SPECIES, wr[index + 4]),
275           w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]).mul(NEGATE_IM),
276           w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]).mul(NEGATE_IM),
277           w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]).mul(NEGATE_IM),
278           w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]).mul(NEGATE_IM),
279           w5i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 4]).mul(NEGATE_IM);
280       for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
281         final DoubleVector
282             z0 = fromArray(DOUBLE_SPECIES, data, i),
283             z1 = fromArray(DOUBLE_SPECIES, data, i + di),
284             z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
285             z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
286             z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
287             z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
288         final DoubleVector
289             ta1 = z2.add(z4),
290             ta2 = ta1.mul(-0.5).add(z0),
291             ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
292             a0 = z0.add(ta1),
293             a1 = ta3.fma(NEGATE_RE, ta2),
294             a2 = ta3.fma(NEGATE_IM, ta2),
295             tb1 = z5.add(z1),
296             tb2 = tb1.mul(-0.5).add(z3),
297             tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
298             b0 = z3.add(tb1),
299             b1 = tb3.fma(NEGATE_RE, tb2),
300             b2 = tb3.fma(NEGATE_IM, tb2);
301         a0.add(b0).intoArray(ret, j);
302         final DoubleVector
303             x1 = a1.sub(b1),
304             x2 = a2.add(b2),
305             x3 = a0.sub(b0),
306             x4 = a1.add(b1),
307             x5 = a2.sub(b2);
308         w1r.fma(x1, w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
309         w2r.fma(x2, w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
310         w3r.fma(x3, w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
311         w4r.fma(x4, w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
312         w5r.fma(x5, w5i.mul(x5).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj5);
313       }
314     }
315   }
316 
317   /**
318    * Handle factors of 6.
319    *
320    * @param passData PassData.
321    */
322   private void blocked(PassData passData) {
323     final double[] data = passData.in;
324     final double[] ret = passData.out;
325     int sign = passData.sign;
326     int i = passData.inOffset;
327     int j = passData.outOffset;
328     final double tau = sign * sqrt3_2;
329     // First pass of the 6-point FFT has no twiddle factors.
330     for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
331       final DoubleVector
332           z0r = fromArray(DOUBLE_SPECIES, data, i),
333           z1r = fromArray(DOUBLE_SPECIES, data, i + di),
334           z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
335           z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
336           z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
337           z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
338           z0i = fromArray(DOUBLE_SPECIES, data, i + im),
339           z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
340           z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
341           z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
342           z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
343           z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
344       final DoubleVector
345           ta1r = z2r.add(z4r),
346           ta1i = z2i.add(z4i),
347           ta2r = ta1r.mul(-0.5).add(z0r),
348           ta2i = ta1i.mul(-0.5).add(z0i),
349           ta3r = z2r.sub(z4r).mul(tau),
350           ta3i = z2i.sub(z4i).mul(tau),
351           a0r = z0r.add(ta1r),
352           a0i = z0i.add(ta1i),
353           a1r = ta2r.sub(ta3i),
354           a1i = ta2i.add(ta3r),
355           a2r = ta2r.add(ta3i),
356           a2i = ta2i.sub(ta3r),
357           tb1r = z5r.add(z1r),
358           tb1i = z5i.add(z1i),
359           tb2r = tb1r.mul(-0.5).add(z3r),
360           tb2i = tb1i.mul(-0.5).add(z3i),
361           tb3r = z5r.sub(z1r).mul(tau),
362           tb3i = z5i.sub(z1i).mul(tau),
363           b0r = z3r.add(tb1r),
364           b0i = z3i.add(tb1i),
365           b1r = tb2r.sub(tb3i),
366           b1i = tb2i.add(tb3r),
367           b2r = tb2r.add(tb3i),
368           b2i = tb2i.sub(tb3r);
369       a0r.add(b0r).intoArray(ret, j);
370       a0i.add(b0i).intoArray(ret, j + im);
371       a1r.sub(b1r).intoArray(ret, j + dj);
372       a1i.sub(b1i).intoArray(ret, j + dj + im);
373       a2r.add(b2r).intoArray(ret, j + dj2);
374       a2i.add(b2i).intoArray(ret, j + dj2 + im);
375       a0r.sub(b0r).intoArray(ret, j + dj3);
376       a0i.sub(b0i).intoArray(ret, j + dj3 + im);
377       a1r.add(b1r).intoArray(ret, j + dj4);
378       a1i.add(b1i).intoArray(ret, j + dj4 + im);
379       a2r.sub(b2r).intoArray(ret, j + dj5);
380       a2i.sub(b2i).intoArray(ret, j + dj5 + im);
381     }
382 
383     j += jstep;
384     for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
385       final int index = k * 5;
386       final double
387           w1r = wr[index],
388           w2r = wr[index + 1],
389           w3r = wr[index + 2],
390           w4r = wr[index + 3],
391           w5r = wr[index + 4],
392           w1i = -sign * wi[index],
393           w2i = -sign * wi[index + 1],
394           w3i = -sign * wi[index + 2],
395           w4i = -sign * wi[index + 3],
396           w5i = -sign * wi[index + 4];
397       for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
398         final DoubleVector
399             z0r = fromArray(DOUBLE_SPECIES, data, i),
400             z1r = fromArray(DOUBLE_SPECIES, data, i + di),
401             z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
402             z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
403             z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
404             z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
405             z0i = fromArray(DOUBLE_SPECIES, data, i + im),
406             z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
407             z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
408             z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
409             z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
410             z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
411         final DoubleVector
412             ta1r = z2r.add(z4r),
413             ta1i = z2i.add(z4i),
414             ta2r = ta1r.mul(-0.5).add(z0r),
415             ta2i = ta1i.mul(-0.5).add(z0i),
416             ta3r = z2r.sub(z4r).mul(tau),
417             ta3i = z2i.sub(z4i).mul(tau),
418             a0r = z0r.add(ta1r),
419             a0i = z0i.add(ta1i),
420             a1r = ta2r.sub(ta3i),
421             a1i = ta2i.add(ta3r),
422             a2r = ta2r.add(ta3i),
423             a2i = ta2i.sub(ta3r),
424             tb1r = z5r.add(z1r),
425             tb1i = z5i.add(z1i),
426             tb2r = tb1r.mul(-0.5).add(z3r),
427             tb2i = tb1i.mul(-0.5).add(z3i),
428             tb3r = z5r.sub(z1r).mul(tau),
429             tb3i = z5i.sub(z1i).mul(tau),
430             b0r = z3r.add(tb1r),
431             b0i = z3i.add(tb1i),
432             b1r = tb2r.sub(tb3i),
433             b1i = tb2i.add(tb3r),
434             b2r = tb2r.add(tb3i),
435             b2i = tb2i.sub(tb3r);
436         a0r.add(b0r).intoArray(ret, j);
437         a0i.add(b0i).intoArray(ret, j + im);
438         final DoubleVector
439             x1r = a1r.sub(b1r), x1i = a1i.sub(b1i),
440             x2r = a2r.add(b2r), x2i = a2i.add(b2i),
441             x3r = a0r.sub(b0r), x3i = a0i.sub(b0i),
442             x4r = a1r.add(b1r), x4i = a1i.add(b1i),
443             x5r = a2r.sub(b2r), x5i = a2i.sub(b2i);
444         x1r.mul(w1r).sub(x1i.mul(w1i)).intoArray(ret, j + dj);
445         x2r.mul(w2r).sub(x2i.mul(w2i)).intoArray(ret, j + dj2);
446         x3r.mul(w3r).sub(x3i.mul(w3i)).intoArray(ret, j + dj3);
447         x4r.mul(w4r).sub(x4i.mul(w4i)).intoArray(ret, j + dj4);
448         x5r.mul(w5r).sub(x5i.mul(w5i)).intoArray(ret, j + dj5);
449         x1i.mul(w1r).add(x1r.mul(w1i)).intoArray(ret, j + dj + im);
450         x2i.mul(w2r).add(x2r.mul(w2i)).intoArray(ret, j + dj2 + im);
451         x3i.mul(w3r).add(x3r.mul(w3i)).intoArray(ret, j + dj3 + im);
452         x4i.mul(w4r).add(x4r.mul(w4i)).intoArray(ret, j + dj4 + im);
453         x5i.mul(w5r).add(x5r.mul(w5i)).intoArray(ret, j + dj5 + im);
454       }
455     }
456   }
457 }