1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import jdk.incubator.vector.DoubleVector;
41 import jdk.incubator.vector.VectorShuffle;
42 import jdk.incubator.vector.VectorSpecies;
43
44 import static jdk.incubator.vector.DoubleVector.SPECIES_128;
45 import static jdk.incubator.vector.DoubleVector.SPECIES_256;
46 import static jdk.incubator.vector.DoubleVector.SPECIES_512;
47 import static jdk.incubator.vector.DoubleVector.broadcast;
48 import static jdk.incubator.vector.DoubleVector.fromArray;
49
50
51
52
53 public class MixedRadixFactor2 extends MixedRadixFactor {
54
55
56
57
58 private static int[] simdSizes = {8, 4, 2};
59
60
61
62
63
64
65 public MixedRadixFactor2(PassConstants passConstants) {
66 super(passConstants);
67 }
68
69
70
71
72
73
74
75 @Override
76 public boolean isValidSIMDWidth(int width) {
77
78 if (width != 2 && width != 4 && width != 8) {
79 return false;
80 }
81 if (im == 1) {
82
83 return innerLoopLimit % (width / 2) == 0;
84 } else {
85
86 return innerLoopLimit % width == 0;
87 }
88 }
89
90
91
92
93
94
95
96 @Override
97 public int getOptimalSIMDWidth() {
98
99 if (isValidSIMDWidth(LENGTH)) {
100 return LENGTH;
101 }
102
103 for (int size : simdSizes) {
104 if (size >= LENGTH) {
105
106 continue;
107 }
108 if (isValidSIMDWidth(size)) {
109 return size;
110 }
111 }
112
113 return 0;
114 }
115
116
117
118
119
120
121 @Override
122 protected void passScalar(PassData passData) {
123 final double[] data = passData.in;
124 final double[] ret = passData.out;
125 int sign = passData.sign;
126 int i = passData.inOffset;
127 int j = passData.outOffset;
128
129 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
130 final double z0_r = data[i];
131 final double z0_i = data[i + im];
132 final int idi = i + di;
133 final double z1_r = data[idi];
134 final double z1_i = data[idi + im];
135 ret[j] = z0_r + z1_r;
136 ret[j + im] = z0_i + z1_i;
137 final double x_r = z0_r - z1_r;
138 final double x_i = z0_i - z1_i;
139 final int jdj = j + dj;
140 ret[jdj] = x_r;
141 ret[jdj + im] = x_i;
142 }
143 j += dj;
144 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
145 final double w_r = wr[k];
146 final double w_i = -sign * wi[k];
147 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
148 final double z0_r = data[i];
149 final double z0_i = data[i + im];
150 final int idi = i + di;
151 final double z1_r = data[idi];
152 final double z1_i = data[idi + im];
153 ret[j] = z0_r + z1_r;
154 ret[j + im] = z0_i + z1_i;
155 final int jdj = j + dj;
156 multiplyAndStore(z0_r - z1_r, z0_i - z1_i, w_r, w_i, ret, jdj, jdj + im);
157 }
158 }
159 }
160
161
162
163
164
165
166 @Override
167 protected void passSIMD(PassData passData) {
168 if (!isValidSIMDWidth(simdWidth)) {
169 passScalar(passData);
170 } else {
171 if (im == 1) {
172 interleaved(passData, simdWidth);
173 } else {
174 blocked(passData, simdWidth);
175 }
176 }
177 }
178
179
180
181
182
183
184
185 private void interleaved(PassData passData, int simdLength) {
186
187 switch (simdLength) {
188 case 2:
189
190 interleaved128(passData);
191 break;
192 case 4:
193
194 interleaved256(passData);
195 break;
196 case 8:
197
198 interleaved512(passData);
199 break;
200 default:
201 passScalar(passData);
202 }
203 }
204
205
206
207
208
209
210
211 private void blocked(PassData passData, int simdLength) {
212
213 switch (simdLength) {
214 case 2:
215
216 blocked128(passData);
217 break;
218 case 4:
219
220 blocked256(passData);
221 break;
222 case 8:
223
224 blocked512(passData);
225 break;
226 default:
227 passScalar(passData);
228 }
229 }
230
231
232
233
234
235
236
237
238
239
240
241 private void butterFly2Blocked(
242 VectorSpecies<Double> species,
243 double[] data, int i, double w_r, double w_i, double[] ret, int j) {
244 final DoubleVector
245 z0_r = fromArray(species, data, i),
246 z0_i = fromArray(species, data, i + im),
247 z1_r = fromArray(species, data, i + di),
248 z1_i = fromArray(species, data, i + di + im);
249 z0_r.add(z1_r).intoArray(ret, j);
250 z0_i.add(z1_i).intoArray(ret, j + im);
251 final DoubleVector
252 x_r = z0_r.sub(z1_r),
253 x_i = z0_i.sub(z1_i);
254 x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
255 x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
256 }
257
258
259
260
261
262
263
264
265
266
267
268 private void butterFly2Interleaved(
269 DoubleVector z0, DoubleVector z1,
270 DoubleVector w_r, DoubleVector w_i,
271 VectorShuffle<Double> shuffle_re_im,
272 int j, double[] ret) {
273 z0.add(z1).intoArray(ret, j);
274 final DoubleVector x = z0.sub(z1);
275 x.mul(w_r).add(x.mul(w_i).rearrange(shuffle_re_im)).intoArray(ret, j + dj);
276 }
277
278
279
280
281 private void blocked128(PassData passData) {
282 final double[] data = passData.in;
283 final double[] ret = passData.out;
284 final int sign = passData.sign;
285 int i = passData.inOffset;
286 int j = passData.outOffset;
287
288 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
289 final DoubleVector
290 z0_r = fromArray(SPECIES_128, data, i),
291 z0_i = fromArray(SPECIES_128, data, i + im),
292 z1_r = fromArray(SPECIES_128, data, i + di),
293 z1_i = fromArray(SPECIES_128, data, i + di + im);
294 z0_r.add(z1_r).intoArray(ret, j);
295 z0_i.add(z1_i).intoArray(ret, j + im);
296 z0_r.sub(z1_r).intoArray(ret, j + dj);
297 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
298 }
299
300 j += dj;
301 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
302 final double
303 w_r = wr[k],
304 w_i = -sign * wi[k];
305 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
306 final DoubleVector
307 z0_r = fromArray(SPECIES_128, data, i),
308 z0_i = fromArray(SPECIES_128, data, i + im),
309 z1_r = fromArray(SPECIES_128, data, i + di),
310 z1_i = fromArray(SPECIES_128, data, i + di + im);
311 z0_r.add(z1_r).intoArray(ret, j);
312 z0_i.add(z1_i).intoArray(ret, j + im);
313 final DoubleVector
314 x_r = z0_r.sub(z1_r),
315 x_i = z0_i.sub(z1_i);
316 x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
317 x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
318 }
319 }
320 }
321
322
323
324
325 private void blocked256(PassData passData) {
326 final double[] data = passData.in;
327 final double[] ret = passData.out;
328 final int sign = passData.sign;
329 int i = passData.inOffset;
330 int j = passData.outOffset;
331 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
332 final DoubleVector
333 z0_r = fromArray(SPECIES_256, data, i),
334 z0_i = fromArray(SPECIES_256, data, i + im),
335 z1_r = fromArray(SPECIES_256, data, i + di),
336 z1_i = fromArray(SPECIES_256, data, i + di + im);
337 z0_r.add(z1_r).intoArray(ret, j);
338 z0_i.add(z1_i).intoArray(ret, j + im);
339 z0_r.sub(z1_r).intoArray(ret, j + dj);
340 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
341 }
342
343 j += dj;
344 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
345 final double
346 w_r = wr[k],
347 w_i = -sign * wi[k];
348 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
349 final DoubleVector
350 z0_r = fromArray(SPECIES_256, data, i),
351 z0_i = fromArray(SPECIES_256, data, i + im),
352 z1_r = fromArray(SPECIES_256, data, i + di),
353 z1_i = fromArray(SPECIES_256, data, i + di + im);
354 z0_r.add(z1_r).intoArray(ret, j);
355 z0_i.add(z1_i).intoArray(ret, j + im);
356 final DoubleVector
357 x_r = z0_r.sub(z1_r),
358 x_i = z0_i.sub(z1_i);
359 x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
360 x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
361 }
362 }
363 }
364
365
366
367
368 private void blocked512(PassData passData) {
369 final double[] data = passData.in;
370 final double[] ret = passData.out;
371 final int sign = passData.sign;
372 int i = passData.inOffset;
373 int j = passData.outOffset;
374
375 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
376 final DoubleVector
377 z0_r = fromArray(SPECIES_512, data, i),
378 z0_i = fromArray(SPECIES_512, data, i + im),
379 z1_r = fromArray(SPECIES_512, data, i + di),
380 z1_i = fromArray(SPECIES_512, data, i + di + im);
381 z0_r.add(z1_r).intoArray(ret, j);
382 z0_i.add(z1_i).intoArray(ret, j + im);
383 z0_r.sub(z1_r).intoArray(ret, j + dj);
384 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
385 }
386
387 j += dj;
388 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
389 final double
390 w_r = wr[k],
391 w_i = -sign * wi[k];
392 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
393 final DoubleVector
394 z0_r = fromArray(SPECIES_512, data, i),
395 z1_r = fromArray(SPECIES_512, data, i + di),
396 z0_i = fromArray(SPECIES_512, data, i + im),
397 z1_i = fromArray(SPECIES_512, data, i + di + im);
398 z0_r.add(z1_r).intoArray(ret, j);
399 z0_i.add(z1_i).intoArray(ret, j + im);
400 final DoubleVector
401 x_r = z0_r.sub(z1_r),
402 x_i = z0_i.sub(z1_i);
403 x_r.mul(w_r).sub(x_i.mul(w_i)).intoArray(ret, j + dj);
404 x_i.mul(w_r).add(x_r.mul(w_i)).intoArray(ret, j + dj + im);
405 }
406 }
407 }
408
409
410
411
412 private void interleaved128(PassData passData) {
413 final double[] data = passData.in;
414 final double[] ret = passData.out;
415 final int sign = passData.sign;
416 int i = passData.inOffset;
417 int j = passData.outOffset;
418
419 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
420 final DoubleVector
421 z0 = fromArray(SPECIES_128, data, i),
422 z1 = fromArray(SPECIES_128, data, i + di);
423 z0.add(z1).intoArray(ret, j);
424 z0.sub(z1).intoArray(ret, j + dj);
425 }
426
427 j += dj;
428 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
429 final DoubleVector
430 w_r = broadcast(SPECIES_128, wr[k]),
431 w_i = broadcast(SPECIES_128, -sign * wi[k]).mul(NEGATE_IM_128);
432 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
433 final DoubleVector
434 z0 = fromArray(SPECIES_128, data, i),
435 z1 = fromArray(SPECIES_128, data, i + di);
436 z0.add(z1).intoArray(ret, j);
437 final DoubleVector x = z0.sub(z1);
438 x.fma(w_r, x.mul(w_i).rearrange(SHUFFLE_RE_IM_128)).intoArray(ret, j + dj);
439 }
440 }
441 }
442
443
444
445
446 private void interleaved256(PassData passData) {
447 final double[] data = passData.in;
448 final double[] ret = passData.out;
449 final int sign = passData.sign;
450 int i = passData.inOffset;
451 int j = passData.outOffset;
452
453 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
454 final DoubleVector
455 z0 = fromArray(SPECIES_256, data, i),
456 z1 = fromArray(SPECIES_256, data, i + di);
457 z0.add(z1).intoArray(ret, j);
458 z0.sub(z1).intoArray(ret, j + dj);
459 }
460
461 j += dj;
462 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
463 final DoubleVector
464 w_r = broadcast(SPECIES_256, wr[k]),
465 w_i = broadcast(SPECIES_256, -sign * wi[k]).mul(NEGATE_IM_256);
466 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
467 final DoubleVector
468 z0 = fromArray(SPECIES_256, data, i),
469 z1 = fromArray(SPECIES_256, data, i + di);
470 z0.add(z1).intoArray(ret, j);
471 final DoubleVector x = z0.sub(z1);
472 x.fma(w_r, x.mul(w_i).rearrange(SHUFFLE_RE_IM_256)).intoArray(ret, j + dj);
473 }
474 }
475 }
476
477
478
479
480 private void interleaved512(PassData passData) {
481 final double[] data = passData.in;
482 final double[] ret = passData.out;
483 final int sign = passData.sign;
484 int i = passData.inOffset;
485 int j = passData.outOffset;
486
487 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
488 final DoubleVector
489 z0 = fromArray(SPECIES_512, data, i),
490 z1 = fromArray(SPECIES_512, data, i + di);
491 z0.add(z1).intoArray(ret, j);
492 z0.sub(z1).intoArray(ret, j + dj);
493 }
494
495 j += dj;
496 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
497 final DoubleVector
498 w_r = broadcast(SPECIES_512, wr[k]),
499 w_i = broadcast(SPECIES_512, -sign * wi[k]).mul(NEGATE_IM_512);
500 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
501 final DoubleVector
502 z0 = fromArray(SPECIES_512, data, i),
503 z1 = fromArray(SPECIES_512, data, i + di);
504 z0.add(z1).intoArray(ret, j);
505 final DoubleVector x = z0.sub(z1);
506 x.fma(w_r, x.mul(w_i).rearrange(SHUFFLE_RE_IM_512)).intoArray(ret, j + dj);
507 }
508 }
509 }
510
511 }