1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import jdk.incubator.vector.DoubleVector;
41
42 import static jdk.incubator.vector.DoubleVector.SPECIES_128;
43 import static jdk.incubator.vector.DoubleVector.SPECIES_256;
44 import static jdk.incubator.vector.DoubleVector.SPECIES_512;
45 import static jdk.incubator.vector.DoubleVector.broadcast;
46 import static jdk.incubator.vector.DoubleVector.fromArray;
47
48
49
50
51 public class MixedRadixFactor2 extends MixedRadixFactor {
52
53
54
55
56
57
58 public MixedRadixFactor2(PassConstants passConstants) {
59 super(passConstants);
60 }
61
62
63
64
65
66
67 @Override
68 protected void passScalar(PassData passData) {
69 final double[] data = passData.in;
70 final double[] ret = passData.out;
71 int sign = passData.sign;
72 int i = passData.inOffset;
73 int j = passData.outOffset;
74
75 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
76 final double z0_r = data[i];
77 final double z0_i = data[i + im];
78 final int idi = i + di;
79 final double z1_r = data[idi];
80 final double z1_i = data[idi + im];
81 ret[j] = z0_r + z1_r;
82 ret[j + im] = z0_i + z1_i;
83 final double x_r = z0_r - z1_r;
84 final double x_i = z0_i - z1_i;
85 final int jdj = j + dj;
86 ret[jdj] = x_r;
87 ret[jdj + im] = x_i;
88 }
89 j += dj;
90 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
91 final double[] twids = twiddles[k];
92 final double w_r = twids[0];
93 final double w_i = -sign * twids[1];
94 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
95 final double z0_r = data[i];
96 final double z0_i = data[i + im];
97 final int idi = i + di;
98 final double z1_r = data[idi];
99 final double z1_i = data[idi + im];
100 ret[j] = z0_r + z1_r;
101 ret[j + im] = z0_i + z1_i;
102 final int jdj = j + dj;
103 multiplyAndStore(z0_r - z1_r, z0_i - z1_i, w_r, w_i, ret, jdj, jdj + im);
104 }
105 }
106 }
107
108
109
110
111
112
113 @Override
114 protected void passSIMD(PassData passData) {
115 if (im == 1) {
116 interleaved(passData);
117 } else {
118 blocked(passData);
119 }
120 }
121
122
123
124
125 private static int[] simdSizes = {8, 4, 2};
126
127
128
129
130
131
132
133 private void interleaved(PassData passData, int simdLength) {
134
135 switch (simdLength) {
136 case 2:
137
138 interleaved128(passData);
139 break;
140 case 4:
141
142 interleaved256(passData);
143 break;
144 case 8:
145
146 interleaved512(passData);
147 break;
148 default:
149 passScalar(passData);
150 }
151 }
152
153
154
155
156
157
158
159 private void blocked(PassData passData, int simdLength) {
160
161 switch (simdLength) {
162 case 2:
163
164 blocked128(passData);
165 break;
166 case 4:
167
168 blocked256(passData);
169 break;
170 case 8:
171
172 blocked512(passData);
173 break;
174 default:
175 passScalar(passData);
176 }
177 }
178
179
180
181
182
183
184 private void interleaved(PassData passData) {
185 if (innerLoopLimit % LOOP == 0) {
186
187 interleaved(passData, LENGTH);
188 } else {
189
190 if (innerLoopLimit % 2 != 0 && innerLoopLimit != 1) {
191 passScalar(passData);
192 return;
193 }
194
195 for (int size : simdSizes) {
196 if (size >= LENGTH) {
197
198 continue;
199 }
200
201 if (innerLoopLimit % (size / 2) == 0) {
202 interleaved(passData, size);
203 }
204 }
205 }
206 }
207
208
209
210
211
212
213 private void blocked(PassData passData) {
214 if (innerLoopLimit % BLOCK_LOOP == 0) {
215
216 blocked(passData, LENGTH);
217 } else {
218
219 if (innerLoopLimit % 2 != 0) {
220 passScalar(passData);
221 return;
222 }
223
224 for (int size : simdSizes) {
225 if (size >= LENGTH) {
226
227 continue;
228 }
229
230 if (innerLoopLimit % size == 0) {
231 blocked(passData, size);
232 }
233 }
234 }
235 }
236
237
238
239
240 private void blocked128(PassData passData) {
241 final double[] data = passData.in;
242 final double[] ret = passData.out;
243 int sign = passData.sign;
244 int i = passData.inOffset;
245 int j = passData.outOffset;
246
247 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
248 DoubleVector
249 z0_r = fromArray(SPECIES_128, data, i),
250 z1_r = fromArray(SPECIES_128, data, i + di),
251 z0_i = fromArray(SPECIES_128, data, i + im),
252 z1_i = fromArray(SPECIES_128, data, i + di + im);
253 z0_r.add(z1_r).intoArray(ret, j);
254 z0_i.add(z1_i).intoArray(ret, j + im);
255 z0_r.sub(z1_r).intoArray(ret, j + dj);
256 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
257 }
258
259 j += dj;
260 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
261 final double[] twids = twiddles[k];
262 DoubleVector
263 w_r = broadcast(SPECIES_128, twids[0]),
264 w_i = broadcast(SPECIES_128, -sign * twids[1]);
265 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
266 DoubleVector
267 z0_r = fromArray(SPECIES_128, data, i),
268 z1_r = fromArray(SPECIES_128, data, i + di),
269 z0_i = fromArray(SPECIES_128, data, i + im),
270 z1_i = fromArray(SPECIES_128, data, i + di + im);
271 z0_r.add(z1_r).intoArray(ret, j);
272 z0_i.add(z1_i).intoArray(ret, j + im);
273 DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
274 w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
275 w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
276 }
277 }
278 }
279
280
281
282
283 private void blocked256(PassData passData) {
284 final double[] data = passData.in;
285 final double[] ret = passData.out;
286 int sign = passData.sign;
287 int i = passData.inOffset;
288 int j = passData.outOffset;
289
290 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
291 DoubleVector
292 z0_r = fromArray(SPECIES_256, data, i),
293 z1_r = fromArray(SPECIES_256, data, i + di),
294 z0_i = fromArray(SPECIES_256, data, i + im),
295 z1_i = fromArray(SPECIES_256, data, i + di + im);
296 z0_r.add(z1_r).intoArray(ret, j);
297 z0_i.add(z1_i).intoArray(ret, j + im);
298 z0_r.sub(z1_r).intoArray(ret, j + dj);
299 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
300 }
301
302 j += dj;
303 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
304 final double[] twids = twiddles[k];
305 DoubleVector
306 w_r = broadcast(SPECIES_256, twids[0]),
307 w_i = broadcast(SPECIES_256, -sign * twids[1]);
308 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
309 DoubleVector
310 z0_r = fromArray(SPECIES_256, data, i),
311 z1_r = fromArray(SPECIES_256, data, i + di),
312 z0_i = fromArray(SPECIES_256, data, i + im),
313 z1_i = fromArray(SPECIES_256, data, i + di + im);
314 z0_r.add(z1_r).intoArray(ret, j);
315 z0_i.add(z1_i).intoArray(ret, j + im);
316 DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
317 w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
318 w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
319 }
320 }
321 }
322
323
324
325
326 private void blocked512(PassData passData) {
327 final double[] data = passData.in;
328 final double[] ret = passData.out;
329 int sign = passData.sign;
330 int i = passData.inOffset;
331 int j = passData.outOffset;
332
333 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
334 DoubleVector
335 z0_r = fromArray(SPECIES_512, data, i),
336 z1_r = fromArray(SPECIES_512, data, i + di),
337 z0_i = fromArray(SPECIES_512, data, i + im),
338 z1_i = fromArray(SPECIES_512, data, i + di + im);
339 z0_r.add(z1_r).intoArray(ret, j);
340 z0_i.add(z1_i).intoArray(ret, j + im);
341 z0_r.sub(z1_r).intoArray(ret, j + dj);
342 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
343 }
344
345 j += dj;
346 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
347 final double[] twids = twiddles[k];
348 DoubleVector
349 w_r = broadcast(SPECIES_512, twids[0]),
350 w_i = broadcast(SPECIES_512, -sign * twids[1]);
351 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
352 DoubleVector
353 z0_r = fromArray(SPECIES_512, data, i),
354 z1_r = fromArray(SPECIES_512, data, i + di),
355 z0_i = fromArray(SPECIES_512, data, i + im),
356 z1_i = fromArray(SPECIES_512, data, i + di + im);
357 z0_r.add(z1_r).intoArray(ret, j);
358 z0_i.add(z1_i).intoArray(ret, j + im);
359 DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
360 w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
361 w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
362 }
363 }
364 }
365
366
367
368
369 private void interleaved128(PassData passData) {
370 final double[] data = passData.in;
371 final double[] ret = passData.out;
372 int sign = passData.sign;
373 int i = passData.inOffset;
374 int j = passData.outOffset;
375
376 for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_128, i += LENGTH_128, j += LENGTH_128) {
377 DoubleVector
378 z0 = fromArray(SPECIES_128, data, i),
379 z1 = fromArray(SPECIES_128, data, i + di);
380 z0.add(z1).intoArray(ret, j);
381 z0.sub(z1).intoArray(ret, j + dj);
382 }
383
384 j += dj;
385 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
386 final double[] twids = twiddles[k];
387 DoubleVector
388 wr = broadcast(SPECIES_128, twids[0]),
389 wi = broadcast(SPECIES_128, -sign * twids[1]).mul(NEGATE_IM_128);
390 for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_128, i += LENGTH_128, j += LENGTH_128) {
391 DoubleVector
392 z0 = fromArray(SPECIES_128, data, i),
393 z1 = fromArray(SPECIES_128, data, i + di);
394 z0.add(z1).intoArray(ret, j);
395 DoubleVector x = z0.sub(z1);
396 x.mul(wr).add(x.mul(wi).rearrange(SHUFFLE_RE_IM_128)).intoArray(ret, j + dj);
397 }
398 }
399 }
400
401
402
403
404 private void interleaved256(PassData passData) {
405 final double[] data = passData.in;
406 final double[] ret = passData.out;
407 int sign = passData.sign;
408 int i = passData.inOffset;
409 int j = passData.outOffset;
410
411 for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_256, i += LENGTH_256, j += LENGTH_256) {
412 DoubleVector
413 z0 = fromArray(SPECIES_256, data, i),
414 z1 = fromArray(SPECIES_256, data, i + di);
415 z0.add(z1).intoArray(ret, j);
416 z0.sub(z1).intoArray(ret, j + dj);
417 }
418
419 j += dj;
420 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
421 final double[] twids = twiddles[k];
422 DoubleVector
423 wr = broadcast(SPECIES_256, twids[0]),
424 wi = broadcast(SPECIES_256, -sign * twids[1]).mul(NEGATE_IM_256);
425 for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_256, i += LENGTH_256, j += LENGTH_256) {
426 DoubleVector
427 z0 = fromArray(SPECIES_256, data, i),
428 z1 = fromArray(SPECIES_256, data, i + di);
429 z0.add(z1).intoArray(ret, j);
430 DoubleVector x = z0.sub(z1);
431 x.mul(wr).add(x.mul(wi).rearrange(SHUFFLE_RE_IM_256)).intoArray(ret, j + dj);
432 }
433 }
434 }
435
436
437
438
439 private void interleaved512(PassData passData) {
440 final double[] data = passData.in;
441 final double[] ret = passData.out;
442 int sign = passData.sign;
443 int i = passData.inOffset;
444 int j = passData.outOffset;
445
446 for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_512, i += LENGTH_512, j += LENGTH_512) {
447 DoubleVector
448 z0 = fromArray(SPECIES_512, data, i),
449 z1 = fromArray(SPECIES_512, data, i + di);
450 z0.add(z1).intoArray(ret, j);
451 z0.sub(z1).intoArray(ret, j + dj);
452 }
453
454 j += dj;
455 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
456 final double[] twids = twiddles[k];
457 DoubleVector
458 wr = broadcast(SPECIES_512, twids[0]),
459 wi = broadcast(SPECIES_512, -sign * twids[1]).mul(NEGATE_IM_512);
460 for (int k1 = 0; k1 < innerLoopLimit; k1 += LOOP_512, i += LENGTH_512, j += LENGTH_512) {
461 DoubleVector
462 z0 = fromArray(SPECIES_512, data, i),
463 z1 = fromArray(SPECIES_512, data, i + di);
464 z0.add(z1).intoArray(ret, j);
465 DoubleVector x = z0.sub(z1);
466 x.mul(wr).add(x.mul(wi).rearrange(SHUFFLE_RE_IM_512)).intoArray(ret, j + dj);
467 }
468 }
469 }
470
471 }