1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import jdk.incubator.vector.DoubleVector;
41
42 import static java.lang.Math.fma;
43 import static jdk.incubator.vector.DoubleVector.broadcast;
44 import static jdk.incubator.vector.DoubleVector.fromArray;
45 import static org.apache.commons.math3.util.FastMath.PI;
46 import static org.apache.commons.math3.util.FastMath.sin;
47 import static org.apache.commons.math3.util.FastMath.sqrt;
48
49
50
51
52 public class MixedRadixFactor5 extends MixedRadixFactor {
53
54 private static final double sqrt5_4 = sqrt(5.0) / 4.0;
55 private static final double sinPI_5 = sin(PI / 5.0);
56 private static final double sin2PI_5 = sin(2.0 * PI / 5.0);
57
58 private final int di2;
59 private final int di3;
60 private final int di4;
61 private final int dj2;
62 private final int dj3;
63 private final int dj4;
64 private final double tau = sqrt5_4;
65
66
67
68
69
70
71 public MixedRadixFactor5(PassConstants passConstants) {
72 super(passConstants);
73 di2 = 2 * di;
74 di3 = 3 * di;
75 di4 = 4 * di;
76 dj2 = 2 * dj;
77 dj3 = 3 * dj;
78 dj4 = 4 * dj;
79 }
80
81
82
83
84
85
86 @Override
87 protected void passScalar(PassData passData) {
88 final double[] data = passData.in;
89 final double[] ret = passData.out;
90 int sign = passData.sign;
91 int i = passData.inOffset;
92 int j = passData.outOffset;
93 final double sin2PI_5s = sign * sin2PI_5;
94 final double sinPI_5s = sign * sinPI_5;
95
96 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
97 final double z0r = data[i];
98 final double z1r = data[i + di];
99 final double z2r = data[i + di2];
100 final double z3r = data[i + di3];
101 final double z4r = data[i + di4];
102 final double z0i = data[i + im];
103 final double z1i = data[i + di + im];
104 final double z2i = data[i + di2 + im];
105 final double z3i = data[i + di3 + im];
106 final double z4i = data[i + di4 + im];
107 final double t1r = z1r + z4r;
108 final double t1i = z1i + z4i;
109 final double t2r = z2r + z3r;
110 final double t2i = z2i + z3i;
111 final double t3r = z1r - z4r;
112 final double t3i = z1i - z4i;
113 final double t4r = z2r - z3r;
114 final double t4i = z2i - z3i;
115 final double t5r = t1r + t2r;
116 final double t5i = t1i + t2i;
117 final double t6r = tau * (t1r - t2r);
118 final double t6i = tau * (t1i - t2i);
119 final double t7r = fma(-0.25, t5r, z0r);
120 final double t7i = fma(-0.25, t5i, z0i);
121 final double t8r = t7r + t6r;
122 final double t8i = t7i + t6i;
123 final double t9r = t7r - t6r;
124 final double t9i = t7i - t6i;
125 final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
126 final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
127 final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
128 final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
129 ret[j] = z0r + t5r;
130 ret[j + im] = z0i + t5i;
131 ret[j + dj] = t8r - t10i;
132 ret[j + dj + im] = t8i + t10r;
133 ret[j + dj2] = t9r - t11i;
134 ret[j + dj2 + im] = t9i + t11r;
135 ret[j + dj3] = t9r + t11i;
136 ret[j + dj3 + im] = t9i - t11r;
137 ret[j + dj4] = t8r + t10i;
138 ret[j + dj4 + im] = t8i - t10r;
139 }
140
141 j += jstep;
142 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
143
144
145
146
147
148
149
150
151
152 final int index = k * 4;
153 final double w1r = wr[index];
154 final double w2r = wr[index + 1];
155 final double w3r = wr[index + 2];
156 final double w4r = wr[index + 3];
157 final double w1i = -sign * wi[index];
158 final double w2i = -sign * wi[index + 1];
159 final double w3i = -sign * wi[index + 2];
160 final double w4i = -sign * wi[index + 3];
161 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
162 final double z0r = data[i];
163 final double z1r = data[i + di];
164 final double z2r = data[i + di2];
165 final double z3r = data[i + di3];
166 final double z4r = data[i + di4];
167 final double z0i = data[i + im];
168 final double z1i = data[i + di + im];
169 final double z2i = data[i + di2 + im];
170 final double z3i = data[i + di3 + im];
171 final double z4i = data[i + di4 + im];
172 final double t1r = z1r + z4r;
173 final double t1i = z1i + z4i;
174 final double t2r = z2r + z3r;
175 final double t2i = z2i + z3i;
176 final double t3r = z1r - z4r;
177 final double t3i = z1i - z4i;
178 final double t4r = z2r - z3r;
179 final double t4i = z2i - z3i;
180 final double t5r = t1r + t2r;
181 final double t5i = t1i + t2i;
182 final double t6r = tau * (t1r - t2r);
183 final double t6i = tau * (t1i - t2i);
184 final double t7r = fma(-0.25, t5r, z0r);
185 final double t7i = fma(-0.25, t5i, z0i);
186 final double t8r = t7r + t6r;
187 final double t8i = t7i + t6i;
188 final double t9r = t7r - t6r;
189 final double t9i = t7i - t6i;
190 final double t10r = fma(sin2PI_5s, t3r, sinPI_5s * t4r);
191 final double t10i = fma(sin2PI_5s, t3i, sinPI_5s * t4i);
192 final double t11r = fma(-sin2PI_5s, t4r, sinPI_5s * t3r);
193 final double t11i = fma(-sin2PI_5s, t4i, sinPI_5s * t3i);
194 ret[j] = z0r + t5r;
195 ret[j + im] = z0i + t5i;
196
197 multiplyAndStore(t8r - t10i, t8i + t10r, w1r, w1i, ret, j + dj, j + dj + im);
198 multiplyAndStore(t9r - t11i, t9i + t11r, w2r, w2i, ret, j + dj2, j + dj2 + im);
199 multiplyAndStore(t9r + t11i, t9i - t11r, w3r, w3i, ret, j + dj3, j + dj3 + im);
200 multiplyAndStore(t8r + t10i, t8i - t10r, w4r, w4i, ret, j + dj4, j + dj4 + im);
201 }
202 }
203 }
204
205
206
207
208
209
210 @Override
211 protected void passSIMD(PassData passData) {
212 if (im == 1) {
213
214 if (innerLoopLimit % INTERLEAVED_LOOP != 0) {
215 passScalar(passData);
216 } else {
217 interleaved(passData);
218 }
219 } else {
220
221 if (innerLoopLimit % BLOCK_LOOP != 0) {
222 passScalar(passData);
223 } else {
224 blocked(passData);
225 }
226 }
227 }
228
229
230
231
232
233
234 protected void interleaved(PassData passData) {
235 final double[] data = passData.in;
236 final double[] ret = passData.out;
237 int sign = passData.sign;
238 int i = passData.inOffset;
239 int j = passData.outOffset;
240 final double sin2PI_5s = sign * sin2PI_5;
241 final double sinPI_5s = sign * sinPI_5;
242
243 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
244 DoubleVector
245 z0 = fromArray(DOUBLE_SPECIES, data, i),
246 z1 = fromArray(DOUBLE_SPECIES, data, i + di),
247 z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
248 z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
249 z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
250 DoubleVector
251 t1 = z1.add(z4),
252 t2 = z2.add(z3),
253 t3 = z1.sub(z4),
254 t4 = z2.sub(z3),
255 t5 = t1.add(t2),
256 t6 = t1.sub(t2).mul(tau),
257 t7 = t5.mul(-0.25).add(z0),
258 t8 = t7.add(t6),
259 t9 = t7.sub(t6),
260 t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
261 t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
262 z0.add(t5).intoArray(ret, j);
263 t8.add(t10.mul(NEGATE_RE)).intoArray(ret, j + dj);
264 t9.add(t11.mul(NEGATE_RE)).intoArray(ret, j + dj2);
265 t9.add(t11.mul(NEGATE_IM)).intoArray(ret, j + dj3);
266 t8.add(t10.mul(NEGATE_IM)).intoArray(ret, j + dj4);
267 }
268
269 j += jstep;
270 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
271
272
273
274
275
276
277
278
279
280
281 final int index = k * 4;
282 final DoubleVector
283 w1r = broadcast(DOUBLE_SPECIES, wr[index]),
284 w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
285 w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
286 w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
287 w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]).mul(NEGATE_IM),
288 w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]).mul(NEGATE_IM),
289 w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]).mul(NEGATE_IM),
290 w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]).mul(NEGATE_IM);
291 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
292 DoubleVector
293 z0 = fromArray(DOUBLE_SPECIES, data, i),
294 z1 = fromArray(DOUBLE_SPECIES, data, i + di),
295 z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
296 z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
297 z4 = fromArray(DOUBLE_SPECIES, data, i + di4);
298 DoubleVector
299 t1 = z1.add(z4),
300 t2 = z2.add(z3),
301 t3 = z1.sub(z4),
302 t4 = z2.sub(z3),
303 t5 = t1.add(t2),
304 t6 = t1.sub(t2).mul(tau),
305 t7 = t5.mul(-0.25).add(z0),
306 t8 = t7.add(t6),
307 t9 = t7.sub(t6),
308 t10 = t3.mul(sin2PI_5s).add(t4.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM),
309 t11 = t4.mul(-sin2PI_5s).add(t3.mul(sinPI_5s)).rearrange(SHUFFLE_RE_IM);
310 z0.add(t5).intoArray(ret, j);
311 DoubleVector
312 x1 = t8.add(t10.mul(NEGATE_RE)),
313 x2 = t9.add(t11.mul(NEGATE_RE)),
314 x3 = t9.add(t11.mul(NEGATE_IM)),
315 x4 = t8.add(t10.mul(NEGATE_IM));
316 w1r.mul(x1).add(w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
317 w2r.mul(x2).add(w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
318 w3r.mul(x3).add(w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
319 w4r.mul(x4).add(w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
320 }
321 }
322 }
323
324
325
326
327
328
329 protected void blocked(PassData passData) {
330 final double[] data = passData.in;
331 final double[] ret = passData.out;
332 int sign = passData.sign;
333 int i = passData.inOffset;
334 int j = passData.outOffset;
335 final double sin2PI_5s = sign * sin2PI_5;
336 final double sinPI_5s = sign * sinPI_5;
337
338 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
339 final DoubleVector
340 z0r = fromArray(DOUBLE_SPECIES, data, i),
341 z1r = fromArray(DOUBLE_SPECIES, data, i + di),
342 z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
343 z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
344 z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
345 z0i = fromArray(DOUBLE_SPECIES, data, i + im),
346 z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
347 z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
348 z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
349 z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
350 final DoubleVector
351 t1r = z1r.add(z4r),
352 t1i = z1i.add(z4i),
353 t2r = z2r.add(z3r),
354 t2i = z2i.add(z3i),
355 t3r = z1r.sub(z4r),
356 t3i = z1i.sub(z4i),
357 t4r = z2r.sub(z3r),
358 t4i = z2i.sub(z3i),
359 t5r = t1r.add(t2r),
360 t5i = t1i.add(t2i),
361 t6r = t1r.sub(t2r).mul(tau),
362 t6i = t1i.sub(t2i).mul(tau),
363 t7r = t5r.mul(-0.25).add(z0r),
364 t7i = t5i.mul(-0.25).add(z0i),
365 t8r = t7r.add(t6r),
366 t8i = t7i.add(t6i),
367 t9r = t7r.sub(t6r),
368 t9i = t7i.sub(t6i),
369 t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
370 t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
371 t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
372 t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
373 z0r.add(t5r).intoArray(ret, j);
374 z0i.add(t5i).intoArray(ret, j + im);
375 t8r.sub(t10i).intoArray(ret, j + dj);
376 t8i.add(t10r).intoArray(ret, j + dj + im);
377 t9r.sub(t11i).intoArray(ret, j + dj2);
378 t9i.add(t11r).intoArray(ret, j + dj2 + im);
379 t9r.add(t11i).intoArray(ret, j + dj3);
380 t9i.sub(t11r).intoArray(ret, j + dj3 + im);
381 t8r.add(t10i).intoArray(ret, j + dj4);
382 t8i.sub(t10r).intoArray(ret, j + dj4 + im);
383 }
384 j += jstep;
385 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
386
387
388
389
390
391
392
393
394
395
396 final int index = k * 4;
397 final DoubleVector
398 w1r = broadcast(DOUBLE_SPECIES, wr[index]),
399 w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
400 w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
401 w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
402 w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]),
403 w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]),
404 w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]),
405 w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]);
406 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
407 final DoubleVector
408 z0r = fromArray(DOUBLE_SPECIES, data, i),
409 z1r = fromArray(DOUBLE_SPECIES, data, i + di),
410 z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
411 z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
412 z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
413 z0i = fromArray(DOUBLE_SPECIES, data, i + im),
414 z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
415 z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
416 z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
417 z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im);
418 final DoubleVector
419 t1r = z1r.add(z4r),
420 t1i = z1i.add(z4i),
421 t2r = z2r.add(z3r),
422 t2i = z2i.add(z3i),
423 t3r = z1r.sub(z4r),
424 t3i = z1i.sub(z4i),
425 t4r = z2r.sub(z3r),
426 t4i = z2i.sub(z3i),
427 t5r = t1r.add(t2r),
428 t5i = t1i.add(t2i),
429 t6r = t1r.sub(t2r).mul(tau),
430 t6i = t1i.sub(t2i).mul(tau),
431 t7r = t5r.mul(-0.25).add(z0r),
432 t7i = t5i.mul(-0.25).add(z0i),
433 t8r = t7r.add(t6r),
434 t8i = t7i.add(t6i),
435 t9r = t7r.sub(t6r),
436 t9i = t7i.sub(t6i),
437 t10r = t3r.mul(sin2PI_5s).add(t4r.mul(sinPI_5s)),
438 t10i = t3i.mul(sin2PI_5s).add(t4i.mul(sinPI_5s)),
439 t11r = t4r.mul(-sin2PI_5s).add(t3r.mul(sinPI_5s)),
440 t11i = t4i.mul(-sin2PI_5s).add(t3i.mul(sinPI_5s));
441 z0r.add(t5r).intoArray(ret, j);
442 z0i.add(t5i).intoArray(ret, j + im);
443 DoubleVector
444 x1r = t8r.sub(t10i), x1i = t8i.add(t10r),
445 x2r = t9r.sub(t11i), x2i = t9i.add(t11r),
446 x3r = t9r.add(t11i), x3i = t9i.sub(t11r),
447 x4r = t8r.add(t10i), x4i = t8i.sub(t10r);
448 w1r.mul(x1r).add(w1i.neg().mul(x1i)).intoArray(ret, j + dj);
449 w2r.mul(x2r).add(w2i.neg().mul(x2i)).intoArray(ret, j + dj2);
450 w3r.mul(x3r).add(w3i.neg().mul(x3i)).intoArray(ret, j + dj3);
451 w4r.mul(x4r).add(w4i.neg().mul(x4i)).intoArray(ret, j + dj4);
452 w1r.mul(x1i).add(w1i.mul(x1r)).intoArray(ret, j + dj + im);
453 w2r.mul(x2i).add(w2i.mul(x2r)).intoArray(ret, j + dj2 + im);
454 w3r.mul(x3i).add(w3i.mul(x3r)).intoArray(ret, j + dj3 + im);
455 w4r.mul(x4i).add(w4i.mul(x4r)).intoArray(ret, j + dj4 + im);
456 }
457 }
458 }
459
460 }