1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import jdk.incubator.vector.DoubleVector;
41
42 import static java.lang.Math.fma;
43 import static jdk.incubator.vector.DoubleVector.broadcast;
44 import static jdk.incubator.vector.DoubleVector.fromArray;
45 import static org.apache.commons.math3.util.FastMath.sqrt;
46
47
48
49
50 public class MixedRadixFactor6 extends MixedRadixFactor {
51
52 private static final double sqrt3_2 = sqrt(3.0) / 2.0;
53
54 private final int di2;
55 private final int di3;
56 private final int di4;
57 private final int di5;
58 private final int dj2;
59 private final int dj3;
60 private final int dj4;
61 private final int dj5;
62
63
64
65
66
67
68 public MixedRadixFactor6(PassConstants passConstants) {
69 super(passConstants);
70 di2 = 2 * di;
71 di3 = 3 * di;
72 di4 = 4 * di;
73 di5 = 5 * di;
74 dj2 = 2 * dj;
75 dj3 = 3 * dj;
76 dj4 = 4 * dj;
77 dj5 = 5 * dj;
78 }
79
80
81
82
83
84
85 @Override
86 protected void passScalar(PassData passData) {
87 final double[] data = passData.in;
88 final double[] ret = passData.out;
89 int sign = passData.sign;
90 int i = passData.inOffset;
91 int j = passData.outOffset;
92 final double tau = sign * sqrt3_2;
93
94 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
95 final double z0r = data[i];
96 final double z1r = data[i + di];
97 final double z2r = data[i + di2];
98 final double z3r = data[i + di3];
99 final double z4r = data[i + di4];
100 final double z5r = data[i + di5];
101 final double z0i = data[i + im];
102 final double z1i = data[i + di + im];
103 final double z2i = data[i + di2 + im];
104 final double z3i = data[i + di3 + im];
105 final double z4i = data[i + di4 + im];
106 final double z5i = data[i + di5 + im];
107 final double ta1r = z2r + z4r;
108 final double ta1i = z2i + z4i;
109 final double ta2r = fma(-0.5, ta1r, z0r);
110 final double ta2i = fma(-0.5, ta1i, z0i);
111 final double ta3r = tau * (z2r - z4r);
112 final double ta3i = tau * (z2i - z4i);
113 final double a0r = z0r + ta1r;
114 final double a0i = z0i + ta1i;
115 final double a1r = ta2r - ta3i;
116 final double a1i = ta2i + ta3r;
117 final double a2r = ta2r + ta3i;
118 final double a2i = ta2i - ta3r;
119 final double tb1r = z5r + z1r;
120 final double tb1i = z5i + z1i;
121 final double tb2r = fma(-0.5, tb1r, z3r);
122 final double tb2i = fma(-0.5, tb1i, z3i);
123 final double tb3r = tau * (z5r - z1r);
124 final double tb3i = tau * (z5i - z1i);
125 final double b0r = z3r + tb1r;
126 final double b0i = z3i + tb1i;
127 final double b1r = tb2r - tb3i;
128 final double b1i = tb2i + tb3r;
129 final double b2r = tb2r + tb3i;
130 final double b2i = tb2i - tb3r;
131 ret[j] = a0r + b0r;
132 ret[j + im] = a0i + b0i;
133 ret[j + dj] = a1r - b1r;
134 ret[j + dj + im] = a1i - b1i;
135 ret[j + dj2] = a2r + b2r;
136 ret[j + dj2 + im] = a2i + b2i;
137 ret[j + dj3] = a0r - b0r;
138 ret[j + dj3 + im] = a0i - b0i;
139 ret[j + dj4] = a1r + b1r;
140 ret[j + dj4 + im] = a1i + b1i;
141 ret[j + dj5] = a2r - b2r;
142 ret[j + dj5 + im] = a2i - b2i;
143 }
144
145 j += jstep;
146 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
147
148
149
150
151
152
153
154
155
156
157
158 final int index = k * 5;
159 final double w1r = wr[index];
160 final double w2r = wr[index + 1];
161 final double w3r = wr[index + 2];
162 final double w4r = wr[index + 3];
163 final double w5r = wr[index + 4];
164 final double w1i = -sign * wi[index];
165 final double w2i = -sign * wi[index + 1];
166 final double w3i = -sign * wi[index + 2];
167 final double w4i = -sign * wi[index + 3];
168 final double w5i = -sign * wi[index + 4];
169 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
170 final double z0r = data[i];
171 final double z1r = data[i + di];
172 final double z2r = data[i + di2];
173 final double z3r = data[i + di3];
174 final double z4r = data[i + di4];
175 final double z5r = data[i + di5];
176 final double z0i = data[i + im];
177 final double z1i = data[i + di + im];
178 final double z2i = data[i + di2 + im];
179 final double z3i = data[i + di3 + im];
180 final double z4i = data[i + di4 + im];
181 final double z5i = data[i + di5 + im];
182 final double ta1r = z2r + z4r;
183 final double ta1i = z2i + z4i;
184 final double ta2r = fma(-0.5, ta1r, z0r);
185 final double ta2i = fma(-0.5, ta1i, z0i);
186 final double ta3r = tau * (z2r - z4r);
187 final double ta3i = tau * (z2i - z4i);
188 final double a0r = z0r + ta1r;
189 final double a0i = z0i + ta1i;
190 final double a1r = ta2r - ta3i;
191 final double a1i = ta2i + ta3r;
192 final double a2r = ta2r + ta3i;
193 final double a2i = ta2i - ta3r;
194 final double tb1r = z5r + z1r;
195 final double tb1i = z5i + z1i;
196 final double tb2r = fma(-0.5, tb1r, z3r);
197 final double tb2i = fma(-0.5, tb1i, z3i);
198 final double tb3r = tau * (z5r - z1r);
199 final double tb3i = tau * (z5i - z1i);
200 final double b0r = z3r + tb1r;
201 final double b0i = z3i + tb1i;
202 final double b1r = tb2r - tb3i;
203 final double b1i = tb2i + tb3r;
204 final double b2r = tb2r + tb3i;
205 final double b2i = tb2i - tb3r;
206 ret[j] = a0r + b0r;
207 ret[j + im] = a0i + b0i;
208 multiplyAndStore(a1r - b1r, a1i - b1i, w1r, w1i, ret, j + dj, j + dj + im);
209 multiplyAndStore(a2r + b2r, a2i + b2i, w2r, w2i, ret, j + dj2, j + dj2 + im);
210 multiplyAndStore(a0r - b0r, a0i - b0i, w3r, w3i, ret, j + dj3, j + dj3 + im);
211 multiplyAndStore(a1r + b1r, a1i + b1i, w4r, w4i, ret, j + dj4, j + dj4 + im);
212 multiplyAndStore(a2r - b2r, a2i - b2i, w5r, w5i, ret, j + dj5, j + dj5 + im);
213 }
214 }
215 }
216
217
218
219
220
221
222 @Override
223 protected void passSIMD(PassData passData) {
224
225 if (im == 1) {
226
227 if (innerLoopLimit % INTERLEAVED_LOOP != 0) {
228 passScalar(passData);
229 } else {
230 interleaved(passData);
231 }
232
233 } else {
234
235 if (innerLoopLimit % BLOCK_LOOP != 0) {
236 passScalar(passData);
237 } else {
238 blocked(passData);
239 }
240 }
241 }
242
243
244
245
246
247
248 private void interleaved(PassData passData) {
249 final double[] data = passData.in;
250 final double[] ret = passData.out;
251 int sign = passData.sign;
252 int i = passData.inOffset;
253 int j = passData.outOffset;
254 final double tau = sign * sqrt3_2;
255
256 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
257 DoubleVector
258 z0 = fromArray(DOUBLE_SPECIES, data, i),
259 z1 = fromArray(DOUBLE_SPECIES, data, i + di),
260 z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
261 z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
262 z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
263 z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
264 DoubleVector
265 ta1 = z2.add(z4),
266 ta2 = ta1.mul(-0.5).add(z0),
267 ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
268 a0 = z0.add(ta1),
269 a1 = ta2.add(ta3.mul(NEGATE_RE)),
270 a2 = ta2.add(ta3.mul(NEGATE_IM)),
271 tb1 = z5.add(z1),
272 tb2 = tb1.mul(-0.5).add(z3),
273 tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
274 b0 = z3.add(tb1),
275 b1 = tb2.add(tb3.mul(NEGATE_RE)),
276 b2 = tb2.add(tb3.mul(NEGATE_IM));
277 a0.add(b0).intoArray(ret, j);
278 a1.sub(b1).intoArray(ret, j + dj);
279 a2.add(b2).intoArray(ret, j + dj2);
280 a0.sub(b0).intoArray(ret, j + dj3);
281 a1.add(b1).intoArray(ret, j + dj4);
282 a2.sub(b2).intoArray(ret, j + dj5);
283 }
284
285 j += jstep;
286 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
287
288
289
290
291
292
293
294
295
296
297
298
299 final int index = k * 5;
300 final DoubleVector
301 w1r = broadcast(DOUBLE_SPECIES, wr[index]),
302 w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
303 w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
304 w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
305 w5r = broadcast(DOUBLE_SPECIES, wr[index + 4]),
306 w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]).mul(NEGATE_IM),
307 w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]).mul(NEGATE_IM),
308 w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]).mul(NEGATE_IM),
309 w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]).mul(NEGATE_IM),
310 w5i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 4]).mul(NEGATE_IM);
311 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
312 DoubleVector
313 z0 = fromArray(DOUBLE_SPECIES, data, i),
314 z1 = fromArray(DOUBLE_SPECIES, data, i + di),
315 z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
316 z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
317 z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
318 z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
319 DoubleVector
320 ta1 = z2.add(z4),
321 ta2 = ta1.mul(-0.5).add(z0),
322 ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
323 a0 = z0.add(ta1),
324 a1 = ta2.add(ta3.mul(NEGATE_RE)),
325 a2 = ta2.add(ta3.mul(NEGATE_IM)),
326 tb1 = z5.add(z1),
327 tb2 = tb1.mul(-0.5).add(z3),
328 tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
329 b0 = z3.add(tb1),
330 b1 = tb2.add(tb3.mul(NEGATE_RE)),
331 b2 = tb2.add(tb3.mul(NEGATE_IM));
332 a0.add(b0).intoArray(ret, j);
333 DoubleVector
334 x1 = a1.sub(b1),
335 x2 = a2.add(b2),
336 x3 = a0.sub(b0),
337 x4 = a1.add(b1),
338 x5 = a2.sub(b2);
339 w1r.mul(x1).add(w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
340 w2r.mul(x2).add(w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
341 w3r.mul(x3).add(w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
342 w4r.mul(x4).add(w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
343 w5r.mul(x5).add(w5i.mul(x5).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj5);
344 }
345 }
346 }
347
348
349
350
351
352
353 private void blocked(PassData passData) {
354 final double[] data = passData.in;
355 final double[] ret = passData.out;
356 int sign = passData.sign;
357 int i = passData.inOffset;
358 int j = passData.outOffset;
359 final double tau = sign * sqrt3_2;
360
361 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
362 final DoubleVector
363 z0r = fromArray(DOUBLE_SPECIES, data, i),
364 z1r = fromArray(DOUBLE_SPECIES, data, i + di),
365 z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
366 z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
367 z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
368 z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
369 z0i = fromArray(DOUBLE_SPECIES, data, i + im),
370 z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
371 z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
372 z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
373 z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
374 z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
375 final DoubleVector
376 ta1r = z2r.add(z4r),
377 ta1i = z2i.add(z4i),
378 ta2r = ta1r.mul(-0.5).add(z0r),
379 ta2i = ta1i.mul(-0.5).add(z0i),
380 ta3r = z2r.sub(z4r).mul(tau),
381 ta3i = z2i.sub(z4i).mul(tau),
382 a0r = z0r.add(ta1r),
383 a0i = z0i.add(ta1i),
384 a1r = ta2r.sub(ta3i),
385 a1i = ta2i.add(ta3r),
386 a2r = ta2r.add(ta3i),
387 a2i = ta2i.sub(ta3r),
388 tb1r = z5r.add(z1r),
389 tb1i = z5i.add(z1i),
390 tb2r = tb1r.mul(-0.5).add(z3r),
391 tb2i = tb1i.mul(-0.5).add(z3i),
392 tb3r = z5r.sub(z1r).mul(tau),
393 tb3i = z5i.sub(z1i).mul(tau),
394 b0r = z3r.add(tb1r),
395 b0i = z3i.add(tb1i),
396 b1r = tb2r.sub(tb3i),
397 b1i = tb2i.add(tb3r),
398 b2r = tb2r.add(tb3i),
399 b2i = tb2i.sub(tb3r);
400 a0r.add(b0r).intoArray(ret, j);
401 a0i.add(b0i).intoArray(ret, j + im);
402 a1r.sub(b1r).intoArray(ret, j + dj);
403 a1i.sub(b1i).intoArray(ret, j + dj + im);
404 a2r.add(b2r).intoArray(ret, j + dj2);
405 a2i.add(b2i).intoArray(ret, j + dj2 + im);
406 a0r.sub(b0r).intoArray(ret, j + dj3);
407 a0i.sub(b0i).intoArray(ret, j + dj3 + im);
408 a1r.add(b1r).intoArray(ret, j + dj4);
409 a1i.add(b1i).intoArray(ret, j + dj4 + im);
410 a2r.sub(b2r).intoArray(ret, j + dj5);
411 a2i.sub(b2i).intoArray(ret, j + dj5 + im);
412 }
413
414 j += jstep;
415 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
416
417
418
419
420
421
422
423
424
425
426
427
428 final int index = k * 5;
429 final DoubleVector
430 w1r = broadcast(DOUBLE_SPECIES, wr[index]),
431 w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
432 w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
433 w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
434 w5r = broadcast(DOUBLE_SPECIES, wr[index + 4]),
435 w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]),
436 w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]),
437 w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]),
438 w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]),
439 w5i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 4]);
440 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
441 final DoubleVector
442 z0r = fromArray(DOUBLE_SPECIES, data, i),
443 z1r = fromArray(DOUBLE_SPECIES, data, i + di),
444 z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
445 z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
446 z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
447 z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
448 z0i = fromArray(DOUBLE_SPECIES, data, i + im),
449 z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
450 z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
451 z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
452 z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
453 z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
454 final DoubleVector
455 ta1r = z2r.add(z4r),
456 ta1i = z2i.add(z4i),
457 ta2r = ta1r.mul(-0.5).add(z0r),
458 ta2i = ta1i.mul(-0.5).add(z0i),
459 ta3r = z2r.sub(z4r).mul(tau),
460 ta3i = z2i.sub(z4i).mul(tau),
461 a0r = z0r.add(ta1r),
462 a0i = z0i.add(ta1i),
463 a1r = ta2r.sub(ta3i),
464 a1i = ta2i.add(ta3r),
465 a2r = ta2r.add(ta3i),
466 a2i = ta2i.sub(ta3r),
467 tb1r = z5r.add(z1r),
468 tb1i = z5i.add(z1i),
469 tb2r = tb1r.mul(-0.5).add(z3r),
470 tb2i = tb1i.mul(-0.5).add(z3i),
471 tb3r = z5r.sub(z1r).mul(tau),
472 tb3i = z5i.sub(z1i).mul(tau),
473 b0r = z3r.add(tb1r),
474 b0i = z3i.add(tb1i),
475 b1r = tb2r.sub(tb3i),
476 b1i = tb2i.add(tb3r),
477 b2r = tb2r.add(tb3i),
478 b2i = tb2i.sub(tb3r);
479 a0r.add(b0r).intoArray(ret, j);
480 a0i.add(b0i).intoArray(ret, j + im);
481 DoubleVector
482 x1r = a1r.sub(b1r), x1i = a1i.sub(b1i),
483 x2r = a2r.add(b2r), x2i = a2i.add(b2i),
484 x3r = a0r.sub(b0r), x3i = a0i.sub(b0i),
485 x4r = a1r.add(b1r), x4i = a1i.add(b1i),
486 x5r = a2r.sub(b2r), x5i = a2i.sub(b2i);
487 w1r.mul(x1r).add(w1i.neg().mul(x1i)).intoArray(ret, j + dj);
488 w2r.mul(x2r).add(w2i.neg().mul(x2i)).intoArray(ret, j + dj2);
489 w3r.mul(x3r).add(w3i.neg().mul(x3i)).intoArray(ret, j + dj3);
490 w4r.mul(x4r).add(w4i.neg().mul(x4i)).intoArray(ret, j + dj4);
491 w5r.mul(x5r).add(w5i.neg().mul(x5i)).intoArray(ret, j + dj5);
492 w1r.mul(x1i).add(w1i.mul(x1r)).intoArray(ret, j + dj + im);
493 w2r.mul(x2i).add(w2i.mul(x2r)).intoArray(ret, j + dj2 + im);
494 w3r.mul(x3i).add(w3i.mul(x3r)).intoArray(ret, j + dj3 + im);
495 w4r.mul(x4i).add(w4i.mul(x4r)).intoArray(ret, j + dj4 + im);
496 w5r.mul(x5i).add(w5i.mul(x5r)).intoArray(ret, j + dj5 + im);
497 }
498 }
499 }
500 }