1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import jdk.incubator.vector.DoubleVector;
41
42 import static jdk.incubator.vector.DoubleVector.SPECIES_128;
43 import static jdk.incubator.vector.DoubleVector.SPECIES_256;
44 import static jdk.incubator.vector.DoubleVector.SPECIES_512;
45 import static jdk.incubator.vector.DoubleVector.broadcast;
46 import static jdk.incubator.vector.DoubleVector.fromArray;
47
48
49
50
51 public class MixedRadixFactor2 extends MixedRadixFactor {
52
53
54
55
56
57
58 public MixedRadixFactor2(PassConstants passConstants) {
59 super(passConstants);
60 }
61
62
63
64
65
66
67 @Override
68 protected void passScalar(PassData passData) {
69 final double[] data = passData.in;
70 final double[] ret = passData.out;
71 int sign = passData.sign;
72 int i = passData.inOffset;
73 int j = passData.outOffset;
74
75 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
76 final double z0_r = data[i];
77 final double z0_i = data[i + im];
78 final int idi = i + di;
79 final double z1_r = data[idi];
80 final double z1_i = data[idi + im];
81 ret[j] = z0_r + z1_r;
82 ret[j + im] = z0_i + z1_i;
83 final double x_r = z0_r - z1_r;
84 final double x_i = z0_i - z1_i;
85 final int jdj = j + dj;
86 ret[jdj] = x_r;
87 ret[jdj + im] = x_i;
88 }
89 j += dj;
90 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
91 final double w_r = wr[k];
92 final double w_i = -sign * wi[k];
93 for (int k1 = 0; k1 < innerLoopLimit; k1++, i += ii, j += ii) {
94 final double z0_r = data[i];
95 final double z0_i = data[i + im];
96 final int idi = i + di;
97 final double z1_r = data[idi];
98 final double z1_i = data[idi + im];
99 ret[j] = z0_r + z1_r;
100 ret[j + im] = z0_i + z1_i;
101 final int jdj = j + dj;
102 multiplyAndStore(z0_r - z1_r, z0_i - z1_i, w_r, w_i, ret, jdj, jdj + im);
103 }
104 }
105 }
106
107
108
109
110
111
112 @Override
113 protected void passSIMD(PassData passData) {
114 if (im == 1) {
115 interleaved(passData);
116 } else {
117 blocked(passData);
118 }
119 }
120
121
122
123
124 private static int[] simdSizes = {8, 4, 2};
125
126
127
128
129
130
131
132 private void interleaved(PassData passData, int simdLength) {
133
134 switch (simdLength) {
135 case 2:
136
137 interleaved128(passData);
138 break;
139 case 4:
140
141 interleaved256(passData);
142 break;
143 case 8:
144
145 interleaved512(passData);
146 break;
147 default:
148 passScalar(passData);
149 }
150 }
151
152
153
154
155
156
157
158 private void blocked(PassData passData, int simdLength) {
159
160 switch (simdLength) {
161 case 2:
162
163 blocked128(passData);
164 break;
165 case 4:
166
167 blocked256(passData);
168 break;
169 case 8:
170
171 blocked512(passData);
172 break;
173 default:
174 passScalar(passData);
175 }
176 }
177
178
179
180
181
182
183 private void interleaved(PassData passData) {
184 if (innerLoopLimit % INTERLEAVED_LOOP == 0) {
185
186 interleaved(passData, LENGTH);
187 } else {
188
189 if (innerLoopLimit % 2 != 0 && innerLoopLimit != 1) {
190 passScalar(passData);
191 return;
192 }
193
194 for (int size : simdSizes) {
195 if (size >= LENGTH) {
196
197 continue;
198 }
199
200 if (innerLoopLimit % (size / 2) == 0) {
201 interleaved(passData, size);
202 }
203 }
204 }
205 }
206
207
208
209
210
211
212 private void blocked(PassData passData) {
213 if (innerLoopLimit % BLOCK_LOOP == 0) {
214
215 blocked(passData, LENGTH);
216 } else {
217
218 if (innerLoopLimit % 2 != 0) {
219 passScalar(passData);
220 return;
221 }
222
223 for (int size : simdSizes) {
224 if (size >= LENGTH) {
225
226 continue;
227 }
228
229 if (innerLoopLimit % size == 0) {
230 blocked(passData, size);
231 }
232 }
233 }
234 }
235
236
237
238
239 private void blocked128(PassData passData) {
240 final double[] data = passData.in;
241 final double[] ret = passData.out;
242 int sign = passData.sign;
243 int i = passData.inOffset;
244 int j = passData.outOffset;
245
246 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
247 DoubleVector
248 z0_r = fromArray(SPECIES_128, data, i),
249 z1_r = fromArray(SPECIES_128, data, i + di),
250 z0_i = fromArray(SPECIES_128, data, i + im),
251 z1_i = fromArray(SPECIES_128, data, i + di + im);
252 z0_r.add(z1_r).intoArray(ret, j);
253 z0_i.add(z1_i).intoArray(ret, j + im);
254 z0_r.sub(z1_r).intoArray(ret, j + dj);
255 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
256 }
257
258 j += dj;
259 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
260 final DoubleVector
261 w_r = broadcast(SPECIES_128, wr[k]),
262 w_i = broadcast(SPECIES_128, -sign * wi[k]);
263 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
264 final DoubleVector
265 z0_r = fromArray(SPECIES_128, data, i),
266 z1_r = fromArray(SPECIES_128, data, i + di),
267 z0_i = fromArray(SPECIES_128, data, i + im),
268 z1_i = fromArray(SPECIES_128, data, i + di + im);
269 z0_r.add(z1_r).intoArray(ret, j);
270 z0_i.add(z1_i).intoArray(ret, j + im);
271 DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
272 w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
273 w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
274 }
275 }
276 }
277
278
279
280
281 private void blocked256(PassData passData) {
282 final double[] data = passData.in;
283 final double[] ret = passData.out;
284 int sign = passData.sign;
285 int i = passData.inOffset;
286 int j = passData.outOffset;
287
288 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
289 DoubleVector
290 z0_r = fromArray(SPECIES_256, data, i),
291 z1_r = fromArray(SPECIES_256, data, i + di),
292 z0_i = fromArray(SPECIES_256, data, i + im),
293 z1_i = fromArray(SPECIES_256, data, i + di + im);
294 z0_r.add(z1_r).intoArray(ret, j);
295 z0_i.add(z1_i).intoArray(ret, j + im);
296 z0_r.sub(z1_r).intoArray(ret, j + dj);
297 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
298 }
299
300 j += dj;
301 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
302 final DoubleVector
303 w_r = broadcast(SPECIES_256, wr[k]),
304 w_i = broadcast(SPECIES_256, -sign * wi[k]);
305 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
306 DoubleVector
307 z0_r = fromArray(SPECIES_256, data, i),
308 z1_r = fromArray(SPECIES_256, data, i + di),
309 z0_i = fromArray(SPECIES_256, data, i + im),
310 z1_i = fromArray(SPECIES_256, data, i + di + im);
311 z0_r.add(z1_r).intoArray(ret, j);
312 z0_i.add(z1_i).intoArray(ret, j + im);
313 DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
314 w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
315 w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
316 }
317 }
318 }
319
320
321
322
323 private void blocked512(PassData passData) {
324 final double[] data = passData.in;
325 final double[] ret = passData.out;
326 int sign = passData.sign;
327 int i = passData.inOffset;
328 int j = passData.outOffset;
329
330 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
331 DoubleVector
332 z0_r = fromArray(SPECIES_512, data, i),
333 z1_r = fromArray(SPECIES_512, data, i + di),
334 z0_i = fromArray(SPECIES_512, data, i + im),
335 z1_i = fromArray(SPECIES_512, data, i + di + im);
336 z0_r.add(z1_r).intoArray(ret, j);
337 z0_i.add(z1_i).intoArray(ret, j + im);
338 z0_r.sub(z1_r).intoArray(ret, j + dj);
339 z0_i.sub(z1_i).intoArray(ret, j + dj + im);
340 }
341
342 j += dj;
343 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
344 final DoubleVector
345 w_r = broadcast(SPECIES_512, wr[k]),
346 w_i = broadcast(SPECIES_512, -sign * wi[k]);
347 for (int k1 = 0; k1 < innerLoopLimit; k1 += BLOCK_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
348 DoubleVector
349 z0_r = fromArray(SPECIES_512, data, i),
350 z1_r = fromArray(SPECIES_512, data, i + di),
351 z0_i = fromArray(SPECIES_512, data, i + im),
352 z1_i = fromArray(SPECIES_512, data, i + di + im);
353 z0_r.add(z1_r).intoArray(ret, j);
354 z0_i.add(z1_i).intoArray(ret, j + im);
355 DoubleVector x_r = z0_r.sub(z1_r), x_i = z0_i.sub(z1_i);
356 w_r.mul(x_r).add(w_i.neg().mul(x_i)).intoArray(ret, j + dj);
357 w_r.mul(x_i).add(w_i.mul(x_r)).intoArray(ret, j + dj + im);
358 }
359 }
360 }
361
362
363
364
365 private void interleaved128(PassData passData) {
366 final double[] data = passData.in;
367 final double[] ret = passData.out;
368 int sign = passData.sign;
369 int i = passData.inOffset;
370 int j = passData.outOffset;
371
372 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
373 DoubleVector
374 z0 = fromArray(SPECIES_128, data, i),
375 z1 = fromArray(SPECIES_128, data, i + di);
376 z0.add(z1).intoArray(ret, j);
377 z0.sub(z1).intoArray(ret, j + dj);
378 }
379
380 j += dj;
381 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
382 final DoubleVector
383 w_r = broadcast(SPECIES_128, wr[k]),
384 w_i = broadcast(SPECIES_128, -sign * wi[k]).mul(NEGATE_IM_128);
385 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_128, i += LENGTH_128, j += LENGTH_128) {
386 DoubleVector
387 z0 = fromArray(SPECIES_128, data, i),
388 z1 = fromArray(SPECIES_128, data, i + di);
389 z0.add(z1).intoArray(ret, j);
390 DoubleVector x = z0.sub(z1);
391 x.mul(w_r).add(x.mul(w_i).rearrange(SHUFFLE_RE_IM_128)).intoArray(ret, j + dj);
392 }
393 }
394 }
395
396
397
398
399 private void interleaved256(PassData passData) {
400 final double[] data = passData.in;
401 final double[] ret = passData.out;
402 int sign = passData.sign;
403 int i = passData.inOffset;
404 int j = passData.outOffset;
405
406 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
407 DoubleVector
408 z0 = fromArray(SPECIES_256, data, i),
409 z1 = fromArray(SPECIES_256, data, i + di);
410 z0.add(z1).intoArray(ret, j);
411 z0.sub(z1).intoArray(ret, j + dj);
412 }
413
414 j += dj;
415 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
416 final DoubleVector
417 w_r = broadcast(SPECIES_256, wr[k]),
418 w_i = broadcast(SPECIES_256, -sign * wi[k]).mul(NEGATE_IM_256);
419 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_256, i += LENGTH_256, j += LENGTH_256) {
420 DoubleVector
421 z0 = fromArray(SPECIES_256, data, i),
422 z1 = fromArray(SPECIES_256, data, i + di);
423 z0.add(z1).intoArray(ret, j);
424 DoubleVector x = z0.sub(z1);
425 x.mul(w_r).add(x.mul(w_i).rearrange(SHUFFLE_RE_IM_256)).intoArray(ret, j + dj);
426 }
427 }
428 }
429
430
431
432
433 private void interleaved512(PassData passData) {
434 final double[] data = passData.in;
435 final double[] ret = passData.out;
436 int sign = passData.sign;
437 int i = passData.inOffset;
438 int j = passData.outOffset;
439
440 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
441 DoubleVector
442 z0 = fromArray(SPECIES_512, data, i),
443 z1 = fromArray(SPECIES_512, data, i + di);
444 z0.add(z1).intoArray(ret, j);
445 z0.sub(z1).intoArray(ret, j + dj);
446 }
447
448 j += dj;
449 for (int k = 1; k < outerLoopLimit; k++, j += dj) {
450 final DoubleVector
451 w_r = broadcast(SPECIES_512, wr[k]),
452 w_i = broadcast(SPECIES_512, -sign * wi[k]).mul(NEGATE_IM_512);
453 for (int k1 = 0; k1 < innerLoopLimit; k1 += INTERLEAVED_LOOP_512, i += LENGTH_512, j += LENGTH_512) {
454 DoubleVector
455 z0 = fromArray(SPECIES_512, data, i),
456 z1 = fromArray(SPECIES_512, data, i + di);
457 z0.add(z1).intoArray(ret, j);
458 DoubleVector x = z0.sub(z1);
459 x.mul(w_r).add(x.mul(w_i).rearrange(SHUFFLE_RE_IM_512)).intoArray(ret, j + dj);
460 }
461 }
462 }
463
464 }