1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import jdk.incubator.vector.DoubleVector;
41
42 import static java.lang.Math.fma;
43 import static jdk.incubator.vector.DoubleVector.broadcast;
44 import static jdk.incubator.vector.DoubleVector.fromArray;
45 import static org.apache.commons.math3.util.FastMath.sqrt;
46
47
48
49
50 public class MixedRadixFactor6 extends MixedRadixFactor {
51
52 private static final double sqrt3_2 = sqrt(3.0) / 2.0;
53
54 private final int di2;
55 private final int di3;
56 private final int di4;
57 private final int di5;
58 private final int dj2;
59 private final int dj3;
60 private final int dj4;
61 private final int dj5;
62
63
64
65
66
67
68 public MixedRadixFactor6(PassConstants passConstants) {
69 super(passConstants);
70 di2 = 2 * di;
71 di3 = 3 * di;
72 di4 = 4 * di;
73 di5 = 5 * di;
74 dj2 = 2 * dj;
75 dj3 = 3 * dj;
76 dj4 = 4 * dj;
77 dj5 = 5 * dj;
78 }
79
80
81
82
83
84
85 @Override
86 protected void passScalar(PassData passData) {
87 final double[] data = passData.in;
88 final double[] ret = passData.out;
89 int sign = passData.sign;
90 int i = passData.inOffset;
91 int j = passData.outOffset;
92 final double tau = sign * sqrt3_2;
93
94 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
95 final double z0r = data[i];
96 final double z1r = data[i + di];
97 final double z2r = data[i + di2];
98 final double z3r = data[i + di3];
99 final double z4r = data[i + di4];
100 final double z5r = data[i + di5];
101 final double z0i = data[i + im];
102 final double z1i = data[i + di + im];
103 final double z2i = data[i + di2 + im];
104 final double z3i = data[i + di3 + im];
105 final double z4i = data[i + di4 + im];
106 final double z5i = data[i + di5 + im];
107 final double ta1r = z2r + z4r;
108 final double ta1i = z2i + z4i;
109 final double ta2r = fma(-0.5, ta1r, z0r);
110 final double ta2i = fma(-0.5, ta1i, z0i);
111 final double ta3r = tau * (z2r - z4r);
112 final double ta3i = tau * (z2i - z4i);
113 final double a0r = z0r + ta1r;
114 final double a0i = z0i + ta1i;
115 final double a1r = ta2r - ta3i;
116 final double a1i = ta2i + ta3r;
117 final double a2r = ta2r + ta3i;
118 final double a2i = ta2i - ta3r;
119 final double tb1r = z5r + z1r;
120 final double tb1i = z5i + z1i;
121 final double tb2r = fma(-0.5, tb1r, z3r);
122 final double tb2i = fma(-0.5, tb1i, z3i);
123 final double tb3r = tau * (z5r - z1r);
124 final double tb3i = tau * (z5i - z1i);
125 final double b0r = z3r + tb1r;
126 final double b0i = z3i + tb1i;
127 final double b1r = tb2r - tb3i;
128 final double b1i = tb2i + tb3r;
129 final double b2r = tb2r + tb3i;
130 final double b2i = tb2i - tb3r;
131 ret[j] = a0r + b0r;
132 ret[j + im] = a0i + b0i;
133 ret[j + dj] = a1r - b1r;
134 ret[j + dj + im] = a1i - b1i;
135 ret[j + dj2] = a2r + b2r;
136 ret[j + dj2 + im] = a2i + b2i;
137 ret[j + dj3] = a0r - b0r;
138 ret[j + dj3 + im] = a0i - b0i;
139 ret[j + dj4] = a1r + b1r;
140 ret[j + dj4 + im] = a1i + b1i;
141 ret[j + dj5] = a2r - b2r;
142 ret[j + dj5 + im] = a2i - b2i;
143 }
144
145 j += jstep;
146 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
147 final int index = k * 5;
148 final double w1r = wr[index];
149 final double w2r = wr[index + 1];
150 final double w3r = wr[index + 2];
151 final double w4r = wr[index + 3];
152 final double w5r = wr[index + 4];
153 final double w1i = -sign * wi[index];
154 final double w2i = -sign * wi[index + 1];
155 final double w3i = -sign * wi[index + 2];
156 final double w4i = -sign * wi[index + 3];
157 final double w5i = -sign * wi[index + 4];
158 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
159 final double z0r = data[i];
160 final double z1r = data[i + di];
161 final double z2r = data[i + di2];
162 final double z3r = data[i + di3];
163 final double z4r = data[i + di4];
164 final double z5r = data[i + di5];
165 final double z0i = data[i + im];
166 final double z1i = data[i + di + im];
167 final double z2i = data[i + di2 + im];
168 final double z3i = data[i + di3 + im];
169 final double z4i = data[i + di4 + im];
170 final double z5i = data[i + di5 + im];
171 final double ta1r = z2r + z4r;
172 final double ta1i = z2i + z4i;
173 final double ta2r = fma(-0.5, ta1r, z0r);
174 final double ta2i = fma(-0.5, ta1i, z0i);
175 final double ta3r = tau * (z2r - z4r);
176 final double ta3i = tau * (z2i - z4i);
177 final double a0r = z0r + ta1r;
178 final double a0i = z0i + ta1i;
179 final double a1r = ta2r - ta3i;
180 final double a1i = ta2i + ta3r;
181 final double a2r = ta2r + ta3i;
182 final double a2i = ta2i - ta3r;
183 final double tb1r = z5r + z1r;
184 final double tb1i = z5i + z1i;
185 final double tb2r = fma(-0.5, tb1r, z3r);
186 final double tb2i = fma(-0.5, tb1i, z3i);
187 final double tb3r = tau * (z5r - z1r);
188 final double tb3i = tau * (z5i - z1i);
189 final double b0r = z3r + tb1r;
190 final double b0i = z3i + tb1i;
191 final double b1r = tb2r - tb3i;
192 final double b1i = tb2i + tb3r;
193 final double b2r = tb2r + tb3i;
194 final double b2i = tb2i - tb3r;
195 ret[j] = a0r + b0r;
196 ret[j + im] = a0i + b0i;
197 multiplyAndStore(a1r - b1r, a1i - b1i, w1r, w1i, ret, j + dj, j + dj + im);
198 multiplyAndStore(a2r + b2r, a2i + b2i, w2r, w2i, ret, j + dj2, j + dj2 + im);
199 multiplyAndStore(a0r - b0r, a0i - b0i, w3r, w3i, ret, j + dj3, j + dj3 + im);
200 multiplyAndStore(a1r + b1r, a1i + b1i, w4r, w4i, ret, j + dj4, j + dj4 + im);
201 multiplyAndStore(a2r - b2r, a2i - b2i, w5r, w5i, ret, j + dj5, j + dj5 + im);
202 }
203 }
204 }
205
206
207
208
209
210
211 @Override
212 protected void passSIMD(PassData passData) {
213 if (!isValidSIMDWidth(simdWidth)) {
214 passScalar(passData);
215 } else {
216 if (im == 1) {
217 interleaved(passData);
218 } else {
219 blocked(passData);
220 }
221 }
222 }
223
224
225
226
227
228
229 private void interleaved(PassData passData) {
230 final double[] data = passData.in;
231 final double[] ret = passData.out;
232 int sign = passData.sign;
233 int i = passData.inOffset;
234 int j = passData.outOffset;
235 final double tau = sign * sqrt3_2;
236
237 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
238 final DoubleVector
239 z0 = fromArray(DOUBLE_SPECIES, data, i),
240 z1 = fromArray(DOUBLE_SPECIES, data, i + di),
241 z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
242 z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
243 z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
244 z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
245 final DoubleVector
246 ta1 = z2.add(z4),
247 ta2 = ta1.mul(-0.5).add(z0),
248 ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
249 a0 = z0.add(ta1),
250 a1 = ta3.fma(NEGATE_RE, ta2),
251 a2 = ta3.fma(NEGATE_IM, ta2),
252 tb1 = z5.add(z1),
253 tb2 = tb1.mul(-0.5).add(z3),
254 tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
255 b0 = z3.add(tb1),
256 b1 = tb3.fma(NEGATE_RE, tb2),
257 b2 = tb3.fma(NEGATE_IM, tb2);
258 a0.add(b0).intoArray(ret, j);
259 a1.sub(b1).intoArray(ret, j + dj);
260 a2.add(b2).intoArray(ret, j + dj2);
261 a0.sub(b0).intoArray(ret, j + dj3);
262 a1.add(b1).intoArray(ret, j + dj4);
263 a2.sub(b2).intoArray(ret, j + dj5);
264 }
265
266 j += jstep;
267 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
268 final int index = k * 5;
269 final DoubleVector
270 w1r = broadcast(DOUBLE_SPECIES, wr[index]),
271 w2r = broadcast(DOUBLE_SPECIES, wr[index + 1]),
272 w3r = broadcast(DOUBLE_SPECIES, wr[index + 2]),
273 w4r = broadcast(DOUBLE_SPECIES, wr[index + 3]),
274 w5r = broadcast(DOUBLE_SPECIES, wr[index + 4]),
275 w1i = broadcast(DOUBLE_SPECIES, -sign * wi[index]).mul(NEGATE_IM),
276 w2i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 1]).mul(NEGATE_IM),
277 w3i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 2]).mul(NEGATE_IM),
278 w4i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 3]).mul(NEGATE_IM),
279 w5i = broadcast(DOUBLE_SPECIES, -sign * wi[index + 4]).mul(NEGATE_IM);
280 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP, i += LENGTH, j += LENGTH) {
281 final DoubleVector
282 z0 = fromArray(DOUBLE_SPECIES, data, i),
283 z1 = fromArray(DOUBLE_SPECIES, data, i + di),
284 z2 = fromArray(DOUBLE_SPECIES, data, i + di2),
285 z3 = fromArray(DOUBLE_SPECIES, data, i + di3),
286 z4 = fromArray(DOUBLE_SPECIES, data, i + di4),
287 z5 = fromArray(DOUBLE_SPECIES, data, i + di5);
288 final DoubleVector
289 ta1 = z2.add(z4),
290 ta2 = ta1.mul(-0.5).add(z0),
291 ta3 = z2.sub(z4).mul(tau).rearrange(SHUFFLE_RE_IM),
292 a0 = z0.add(ta1),
293 a1 = ta3.fma(NEGATE_RE, ta2),
294 a2 = ta3.fma(NEGATE_IM, ta2),
295 tb1 = z5.add(z1),
296 tb2 = tb1.mul(-0.5).add(z3),
297 tb3 = z5.sub(z1).mul(tau).rearrange(SHUFFLE_RE_IM),
298 b0 = z3.add(tb1),
299 b1 = tb3.fma(NEGATE_RE, tb2),
300 b2 = tb3.fma(NEGATE_IM, tb2);
301 a0.add(b0).intoArray(ret, j);
302 final DoubleVector
303 x1 = a1.sub(b1),
304 x2 = a2.add(b2),
305 x3 = a0.sub(b0),
306 x4 = a1.add(b1),
307 x5 = a2.sub(b2);
308 w1r.fma(x1, w1i.mul(x1).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj);
309 w2r.fma(x2, w2i.mul(x2).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj2);
310 w3r.fma(x3, w3i.mul(x3).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj3);
311 w4r.fma(x4, w4i.mul(x4).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj4);
312 w5r.fma(x5, w5i.mul(x5).rearrange(SHUFFLE_RE_IM)).intoArray(ret, j + dj5);
313 }
314 }
315 }
316
317
318
319
320
321
322 private void blocked(PassData passData) {
323 final double[] data = passData.in;
324 final double[] ret = passData.out;
325 int sign = passData.sign;
326 int i = passData.inOffset;
327 int j = passData.outOffset;
328 final double tau = sign * sqrt3_2;
329
330 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
331 final DoubleVector
332 z0r = fromArray(DOUBLE_SPECIES, data, i),
333 z1r = fromArray(DOUBLE_SPECIES, data, i + di),
334 z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
335 z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
336 z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
337 z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
338 z0i = fromArray(DOUBLE_SPECIES, data, i + im),
339 z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
340 z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
341 z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
342 z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
343 z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
344 final DoubleVector
345 ta1r = z2r.add(z4r),
346 ta1i = z2i.add(z4i),
347 ta2r = ta1r.mul(-0.5).add(z0r),
348 ta2i = ta1i.mul(-0.5).add(z0i),
349 ta3r = z2r.sub(z4r).mul(tau),
350 ta3i = z2i.sub(z4i).mul(tau),
351 a0r = z0r.add(ta1r),
352 a0i = z0i.add(ta1i),
353 a1r = ta2r.sub(ta3i),
354 a1i = ta2i.add(ta3r),
355 a2r = ta2r.add(ta3i),
356 a2i = ta2i.sub(ta3r),
357 tb1r = z5r.add(z1r),
358 tb1i = z5i.add(z1i),
359 tb2r = tb1r.mul(-0.5).add(z3r),
360 tb2i = tb1i.mul(-0.5).add(z3i),
361 tb3r = z5r.sub(z1r).mul(tau),
362 tb3i = z5i.sub(z1i).mul(tau),
363 b0r = z3r.add(tb1r),
364 b0i = z3i.add(tb1i),
365 b1r = tb2r.sub(tb3i),
366 b1i = tb2i.add(tb3r),
367 b2r = tb2r.add(tb3i),
368 b2i = tb2i.sub(tb3r);
369 a0r.add(b0r).intoArray(ret, j);
370 a0i.add(b0i).intoArray(ret, j + im);
371 a1r.sub(b1r).intoArray(ret, j + dj);
372 a1i.sub(b1i).intoArray(ret, j + dj + im);
373 a2r.add(b2r).intoArray(ret, j + dj2);
374 a2i.add(b2i).intoArray(ret, j + dj2 + im);
375 a0r.sub(b0r).intoArray(ret, j + dj3);
376 a0i.sub(b0i).intoArray(ret, j + dj3 + im);
377 a1r.add(b1r).intoArray(ret, j + dj4);
378 a1i.add(b1i).intoArray(ret, j + dj4 + im);
379 a2r.sub(b2r).intoArray(ret, j + dj5);
380 a2i.sub(b2i).intoArray(ret, j + dj5 + im);
381 }
382
383 j += jstep;
384 for (int k = 1; k < outerLoopLimit; k++, j += jstep) {
385 final int index = k * 5;
386 final double
387 w1r = wr[index],
388 w2r = wr[index + 1],
389 w3r = wr[index + 2],
390 w4r = wr[index + 3],
391 w5r = wr[index + 4],
392 w1i = -sign * wi[index],
393 w2i = -sign * wi[index + 1],
394 w3i = -sign * wi[index + 2],
395 w4i = -sign * wi[index + 3],
396 w5i = -sign * wi[index + 4];
397 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP, i += LENGTH, j += LENGTH) {
398 final DoubleVector
399 z0r = fromArray(DOUBLE_SPECIES, data, i),
400 z1r = fromArray(DOUBLE_SPECIES, data, i + di),
401 z2r = fromArray(DOUBLE_SPECIES, data, i + di2),
402 z3r = fromArray(DOUBLE_SPECIES, data, i + di3),
403 z4r = fromArray(DOUBLE_SPECIES, data, i + di4),
404 z5r = fromArray(DOUBLE_SPECIES, data, i + di5),
405 z0i = fromArray(DOUBLE_SPECIES, data, i + im),
406 z1i = fromArray(DOUBLE_SPECIES, data, i + di + im),
407 z2i = fromArray(DOUBLE_SPECIES, data, i + di2 + im),
408 z3i = fromArray(DOUBLE_SPECIES, data, i + di3 + im),
409 z4i = fromArray(DOUBLE_SPECIES, data, i + di4 + im),
410 z5i = fromArray(DOUBLE_SPECIES, data, i + di5 + im);
411 final DoubleVector
412 ta1r = z2r.add(z4r),
413 ta1i = z2i.add(z4i),
414 ta2r = ta1r.mul(-0.5).add(z0r),
415 ta2i = ta1i.mul(-0.5).add(z0i),
416 ta3r = z2r.sub(z4r).mul(tau),
417 ta3i = z2i.sub(z4i).mul(tau),
418 a0r = z0r.add(ta1r),
419 a0i = z0i.add(ta1i),
420 a1r = ta2r.sub(ta3i),
421 a1i = ta2i.add(ta3r),
422 a2r = ta2r.add(ta3i),
423 a2i = ta2i.sub(ta3r),
424 tb1r = z5r.add(z1r),
425 tb1i = z5i.add(z1i),
426 tb2r = tb1r.mul(-0.5).add(z3r),
427 tb2i = tb1i.mul(-0.5).add(z3i),
428 tb3r = z5r.sub(z1r).mul(tau),
429 tb3i = z5i.sub(z1i).mul(tau),
430 b0r = z3r.add(tb1r),
431 b0i = z3i.add(tb1i),
432 b1r = tb2r.sub(tb3i),
433 b1i = tb2i.add(tb3r),
434 b2r = tb2r.add(tb3i),
435 b2i = tb2i.sub(tb3r);
436 a0r.add(b0r).intoArray(ret, j);
437 a0i.add(b0i).intoArray(ret, j + im);
438 final DoubleVector
439 x1r = a1r.sub(b1r), x1i = a1i.sub(b1i),
440 x2r = a2r.add(b2r), x2i = a2i.add(b2i),
441 x3r = a0r.sub(b0r), x3i = a0i.sub(b0i),
442 x4r = a1r.add(b1r), x4i = a1i.add(b1i),
443 x5r = a2r.sub(b2r), x5i = a2i.sub(b2i);
444 x1r.mul(w1r).sub(x1i.mul(w1i)).intoArray(ret, j + dj);
445 x2r.mul(w2r).sub(x2i.mul(w2i)).intoArray(ret, j + dj2);
446 x3r.mul(w3r).sub(x3i.mul(w3i)).intoArray(ret, j + dj3);
447 x4r.mul(w4r).sub(x4i.mul(w4i)).intoArray(ret, j + dj4);
448 x5r.mul(w5r).sub(x5i.mul(w5i)).intoArray(ret, j + dj5);
449 x1i.mul(w1r).add(x1r.mul(w1i)).intoArray(ret, j + dj + im);
450 x2i.mul(w2r).add(x2r.mul(w2i)).intoArray(ret, j + dj2 + im);
451 x3i.mul(w3r).add(x3r.mul(w3i)).intoArray(ret, j + dj3 + im);
452 x4i.mul(w4r).add(x4r.mul(w4i)).intoArray(ret, j + dj4 + im);
453 x5i.mul(w5r).add(x5r.mul(w5i)).intoArray(ret, j + dj5 + im);
454 }
455 }
456 }
457 }