1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import java.util.Arrays;
41 import java.util.Random;
42 import java.util.Vector;
43 import java.util.logging.Level;
44 import java.util.logging.Logger;
45
46 import static java.lang.Integer.max;
47 import static java.lang.Math.fma;
48 import static java.lang.System.arraycopy;
49 import static org.apache.commons.math3.util.FastMath.PI;
50 import static org.apache.commons.math3.util.FastMath.cos;
51 import static org.apache.commons.math3.util.FastMath.sin;
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78 public class Complex {
79
80 private static final Logger logger = Logger.getLogger(Complex.class.getName());
81
82
83
84 private static final int[] availableFactors = {7, 6, 5, 4, 3, 2};
85 private static final int firstUnavailablePrime = 11;
86
87
88
89 private final int n;
90
91
92
93 private final int nFFTs;
94
95
96
97
98
99 private final int externalIm;
100
101
102
103
104
105
106 private final int im;
107
108
109
110 private final int ii;
111
112
113
114 private final int[] factors;
115
116
117
118 private final double[][][] twiddle;
119
120
121
122 private final double[] packedData;
123
124
125
126 private final double[] scratch;
127
128
129
130 private final MixedRadixFactor[] mixedRadixFactors;
131
132
133
134
135 private final PassData[] passData;
136
137
138
139 private boolean useSIMD;
140
141
142
143 private int minSIMDLoopLength;
144
145
146
147
148 private static int lastN = -1;
149
150
151
152 private static int lastIm = -1;
153
154
155
156 private static int lastNFFTs = -1;
157
158
159
160 private static int[] factorsCache;
161
162
163
164 private static double[][][] twiddleCache = null;
165
166
167
168 private static MixedRadixFactor[] mixedRadixFactorsCache = null;
169
170
171
172
173
174
175
176
177 public Complex(int n) {
178 this(n, DataLayout1D.INTERLEAVED, 1);
179 }
180
181
182
183
184
185
186
187
188
189
190
191 public Complex(int n, DataLayout1D dataLayout, int imOffset) {
192 this(n, dataLayout, imOffset, 1);
193 }
194
195
196
197
198
199
200
201
202
203
204
205
206 public Complex(int n, DataLayout1D dataLayout, int imOffset, int nFFTs) {
207 assert (n > 1);
208 this.n = n;
209 this.nFFTs = nFFTs;
210 this.externalIm = imOffset;
211
212
213
214
215
216
217
218 if (dataLayout == DataLayout1D.INTERLEAVED) {
219 im = 1;
220 ii = 2;
221 } else {
222 im = n * nFFTs;
223 ii = 1;
224 }
225 packedData = new double[2 * n * nFFTs];
226 scratch = new double[2 * n * nFFTs];
227 passData = new PassData[2];
228 passData[0] = new PassData(1, packedData, 0, scratch, 0);
229 passData[1] = new PassData(1, packedData, 0, scratch, 0);
230
231
232
233
234
235 synchronized (Complex.class) {
236
237 if (this.n == lastN && this.im == lastIm && this.nFFTs == lastNFFTs) {
238 factors = factorsCache;
239 twiddle = twiddleCache;
240 mixedRadixFactors = mixedRadixFactorsCache;
241 } else {
242
243 factors = factor(n);
244 twiddle = wavetable(n, factors);
245 mixedRadixFactors = new MixedRadixFactor[factors.length];
246 lastN = this.n;
247 lastIm = this.im;
248 lastNFFTs = this.nFFTs;
249 factorsCache = factors;
250 twiddleCache = twiddle;
251 mixedRadixFactorsCache = mixedRadixFactors;
252
253 int product = 1;
254 for (int i = 0; i < factors.length; i++) {
255 final int factor = factors[i];
256 product *= factor;
257 PassConstants passConstants = new PassConstants(n, im, nFFTs, factor, product, twiddle[i]);
258 switch (factor) {
259 case 2 -> mixedRadixFactors[i] = new MixedRadixFactor2(passConstants);
260 case 3 -> mixedRadixFactors[i] = new MixedRadixFactor3(passConstants);
261 case 4 -> mixedRadixFactors[i] = new MixedRadixFactor4(passConstants);
262 case 5 -> mixedRadixFactors[i] = new MixedRadixFactor5(passConstants);
263 case 6 -> mixedRadixFactors[i] = new MixedRadixFactor6(passConstants);
264 case 7 -> mixedRadixFactors[i] = new MixedRadixFactor7(passConstants);
265 default -> {
266 if (dataLayout == DataLayout1D.BLOCKED) {
267 throw new IllegalArgumentException(
268 " Prime factors greater than 7 are only supported for interleaved data: " + factor);
269 }
270 mixedRadixFactors[i] = new MixedRadixFactorPrime(passConstants);
271 }
272 }
273 }
274 }
275
276
277 useSIMD = false;
278 String simd = System.getProperty("fft.useSIMD", Boolean.toString(useSIMD));
279 try {
280 useSIMD = Boolean.parseBoolean(simd);
281 } catch (Exception e) {
282 logger.info(" Invalid value for fft.useSIMD: " + simd);
283 useSIMD = false;
284 }
285
286
287
288
289
290 if (im == 1) {
291
292 minSIMDLoopLength = MixedRadixFactor.LENGTH / 2;
293 } else {
294
295 minSIMDLoopLength = MixedRadixFactor.LENGTH;
296 }
297 String loop = System.getProperty("fft.minLoop", Integer.toString(minSIMDLoopLength));
298 try {
299 minSIMDLoopLength = max(minSIMDLoopLength, Integer.parseInt(loop));
300 } catch (Exception e) {
301 logger.info(" Invalid value for fft.minLoop: " + loop);
302 if (im == 1) {
303
304 minSIMDLoopLength = MixedRadixFactor.LENGTH / 2;
305 } else {
306
307 minSIMDLoopLength = MixedRadixFactor.LENGTH;
308 }
309 }
310 }
311 }
312
313
314
315
316
317
318 @Override
319 public String toString() {
320 StringBuilder sb = new StringBuilder(" Complex FFT: n = " + n + ", nFFTs = " + nFFTs + ", im = " + externalIm);
321 sb.append("\n Factors: ").append(Arrays.toString(factors));
322 return sb.toString();
323 }
324
325
326
327
328
329
330 public void setUseSIMD(boolean useSIMD) {
331 this.useSIMD = useSIMD;
332 }
333
334
335
336
337
338
339
340
341
342 public void setMinSIMDLoopLength(int minSIMDLoopLength) {
343 if (im == 1 && minSIMDLoopLength < 1) {
344 throw new IllegalArgumentException(" Minimum SIMD loop length for interleaved data is 1 or greater.");
345 }
346 if (im > 2 && minSIMDLoopLength < 2) {
347 throw new IllegalArgumentException(" Minimum SIMD loop length for blocked data is 2 or greater.");
348 }
349 this.minSIMDLoopLength = minSIMDLoopLength;
350 }
351
352
353
354
355
356
357
358 public static boolean preferredDimension(int dim) {
359 if (dim < 2) {
360 return false;
361 }
362
363
364 for (int factor : availableFactors) {
365 while ((dim % factor) == 0) {
366 dim /= factor;
367 }
368 }
369 return dim <= 1;
370 }
371
372
373
374
375
376
377 public int[] getFactors() {
378 return factors;
379 }
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396 public void fft(double[] data, int offset, int stride) {
397 transformInternal(data, offset, stride, -1, 2 * n);
398 }
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418 public void fft(double[] data, int offset, int stride, int nextFFT) {
419 transformInternal(data, offset, stride, -1, nextFFT);
420 }
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435 public void ifft(double[] data, int offset, int stride) {
436 transformInternal(data, offset, stride, +1, 2 * n);
437 }
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457 public void ifft(double[] data, int offset, int stride, int nextFFT) {
458 transformInternal(data, offset, stride, +1, nextFFT);
459 }
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476 public void inverse(double[] data, int offset, int stride) {
477 inverse(data, offset, stride, 2 * n);
478 }
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498 public void inverse(double[] data, int offset, int stride, int nextFFT) {
499 ifft(data, offset, stride, nextFFT);
500
501
502 double norm = normalization();
503 int index = 0;
504 for (int f = 0; f < nFFTs; f++) {
505 for (int i = 0; i < 2 * n; i++) {
506 data[index++] *= norm;
507 }
508 }
509 }
510
511
512
513
514
515
516
517
518
519
520 private void transformInternal(
521 final double[] data, final int offset, final int stride, final int sign, final int nextFFT) {
522
523
524 passData[0].sign = sign;
525 passData[0].in = data;
526 passData[0].inOffset = offset;
527 passData[0].out = scratch;
528 passData[0].outOffset = 0;
529
530 passData[1].sign = sign;
531 passData[1].in = scratch;
532 passData[1].inOffset = 0;
533 passData[1].out = data;
534 passData[1].outOffset = offset;
535
536
537 boolean packed = false;
538 if (stride > 2 || externalIm > n * nFFTs) {
539
540
541 packed = true;
542 pack(data, offset, stride, nextFFT);
543
544 passData[0].in = packedData;
545 passData[0].inOffset = 0;
546
547 passData[1].out = packedData;
548 passData[1].outOffset = 0;
549 }
550
551
552 final int nfactors = factors.length;
553 for (int i = 0; i < nfactors; i++) {
554 final int pass = i % 2;
555 MixedRadixFactor mixedRadixFactor = mixedRadixFactors[i];
556 if (useSIMD && mixedRadixFactor.innerLoopLimit >= minSIMDLoopLength) {
557 mixedRadixFactor.passSIMD(passData[pass]);
558 } else {
559 mixedRadixFactor.passScalar(passData[pass]);
560 }
561
562 }
563
564
565 if (nfactors % 2 == 1) {
566
567 if (stride <= 2 && (im == externalIm) && nextFFT == 2 * n) {
568 arraycopy(scratch, 0, data, offset, 2 * n * nFFTs);
569 } else {
570 unpack(scratch, data, offset, stride, nextFFT);
571 }
572
573 } else if (packed) {
574 unpack(packedData, data, offset, stride, nextFFT);
575 }
576 }
577
578
579
580
581
582
583
584
585
586 private void pack(double[] data, int offset, int stride, int nextFFT) {
587 int i = 0;
588 for (int f = 0; f < nFFTs; f++) {
589 int inputOffset = offset + f * nextFFT;
590 for (int index = inputOffset, k = 0; k < n; k++, i += ii, index += stride) {
591 packedData[i] = data[index];
592 packedData[i + im] = data[index + externalIm];
593 }
594 }
595 }
596
597
598
599
600
601
602
603
604
605 private void unpack(double[] source, double[] data, int offset, int stride, int nextFFT) {
606 int i = 0;
607 for (int f = 0; f < nFFTs; f++) {
608 int outputOffset = offset + f * nextFFT;
609 for (int index = outputOffset, k = 0; k < n; k++, i += ii, index += stride) {
610 data[index] = source[i];
611 data[index + externalIm] = source[i + im];
612 }
613 }
614 }
615
616
617
618
619
620
621
622 private double normalization() {
623 return 1.0 / n;
624 }
625
626
627
628
629
630
631
632
633 private static int[] factor(int n) {
634 if (n < 2) {
635 return null;
636 }
637 Vector<Integer> v = new Vector<>();
638 int nTest = n;
639
640
641 for (int factor : availableFactors) {
642 while ((nTest % factor) == 0) {
643 nTest /= factor;
644 v.add(factor);
645 }
646 }
647
648
649 int factor = firstUnavailablePrime;
650 while (nTest > 1) {
651 while ((nTest % factor) != 0) {
652 factor += 2;
653 }
654 nTest /= factor;
655 v.add(factor);
656 }
657 int product = 1;
658 int nf = v.size();
659 int[] ret = new int[nf];
660 for (int i = 0; i < nf; i++) {
661 ret[i] = v.get(i);
662 product *= ret[i];
663 }
664
665
666 if (product != n) {
667 StringBuilder sb = new StringBuilder(" FFT factorization failed for " + n + "\n");
668 for (int i = 0; i < nf; i++) {
669 sb.append(" ");
670 sb.append(ret[i]);
671 }
672 sb.append("\n");
673 sb.append(" Factor product = ");
674 sb.append(product);
675 sb.append("\n");
676 logger.severe(sb.toString());
677 System.exit(-1);
678 } else {
679 if (logger.isLoggable(Level.FINEST)) {
680 StringBuilder sb = new StringBuilder(" FFT factorization for " + n + " = ");
681 for (int i = 0; i < nf - 1; i++) {
682 sb.append(ret[i]);
683 sb.append(" * ");
684 }
685 sb.append(ret[nf - 1]);
686 logger.finest(sb.toString());
687 }
688 }
689 return ret;
690 }
691
692
693
694
695
696
697
698
699 private static double[][][] wavetable(int n, int[] factors) {
700 if (n < 2) {
701 return null;
702 }
703
704
705
706
707 final double TwoPI_N = -2.0 * PI / n;
708 final double[][][] ret = new double[factors.length][][];
709 int product = 1;
710 for (int i = 0; i < factors.length; i++) {
711 int factor = factors[i];
712 int product_1 = product;
713 product *= factor;
714
715 int outLoopLimit = n / product;
716
717 if (factor >= firstUnavailablePrime) {
718 outLoopLimit += 1;
719 }
720 final int nTwiddle = factor - 1;
721
722
723 ret[i] = new double[outLoopLimit][2 * nTwiddle];
724
725 final double[][] twid = ret[i];
726 for (int j = 0; j < factor - 1; j++) {
727 twid[0][2 * j] = 1.0;
728 twid[0][2 * j + 1] = 0.0;
729
730 }
731 for (int k = 1; k < outLoopLimit; k++) {
732 int m = 0;
733 for (int j = 0; j < nTwiddle; j++) {
734 m += k * product_1;
735 m %= n;
736 final double theta = TwoPI_N * m;
737 twid[k][2 * j] = cos(theta);
738 twid[k][2 * j + 1] = sin(theta);
739
740 }
741 }
742 }
743 return ret;
744 }
745
746
747
748
749
750
751
752 public static void dft(double[] in, double[] out) {
753 int n = in.length / 2;
754 for (int k = 0; k < n; k++) {
755 double sumReal = 0;
756 double simImag = 0;
757 for (int t = 0; t < n; t++) {
758 double angle = (2 * PI * t * k) / n;
759 int re = 2 * t;
760 int im = 2 * t + 1;
761 sumReal = fma(in[re], cos(angle), sumReal);
762 sumReal = fma(in[im], sin(angle), sumReal);
763 simImag = fma(-in[re], sin(angle), simImag);
764 simImag = fma(in[im], cos(angle), simImag);
765 }
766 int re = 2 * k;
767 int im = 2 * k + 1;
768 out[re] = sumReal;
769 out[im] = simImag;
770 }
771 }
772
773
774
775
776
777
778
779 public static void dftBlocked(double[] in, double[] out) {
780 int n = in.length / 2;
781 for (int k = 0; k < n; k++) {
782 double sumReal = 0;
783 double simImag = 0;
784 for (int t = 0; t < n; t++) {
785 double angle = (2 * PI * t * k) / n;
786 int re = t;
787 int im = t + n;
788 sumReal = fma(in[re], cos(angle), sumReal);
789 sumReal = fma(in[im], sin(angle), sumReal);
790 simImag = fma(-in[re], sin(angle), simImag);
791 simImag = fma(in[im], cos(angle), simImag);
792 }
793 int re = k;
794 int im = k + n;
795 out[re] = sumReal;
796 out[im] = simImag;
797 }
798 }
799
800
801
802
803
804
805
806
807 public static void main(String[] args) throws Exception {
808 int dimNotFinal = 128;
809 int reps = 5;
810 try {
811 dimNotFinal = Integer.parseInt(args[0]);
812 if (dimNotFinal < 1) {
813 dimNotFinal = 100;
814 }
815 reps = Integer.parseInt(args[1]);
816 if (reps < 1) {
817 reps = 5;
818 }
819 } catch (Exception e) {
820
821 }
822 final int dim = dimNotFinal;
823 System.out.printf("Initializing a 1D array of length %d.\n"
824 + "The best timing out of %d repetitions will be used.%n", dim, reps);
825 Complex complex = new Complex(dim);
826 final double[] data = new double[dim * 2];
827 Random random = new Random(1);
828 for (int i = 0; i < dim; i++) {
829 data[2 * i] = random.nextDouble();
830 }
831 double toSeconds = 0.000000001;
832 long seqTime = Long.MAX_VALUE;
833 for (int i = 0; i < reps; i++) {
834 System.out.printf("Iteration %d%n", i + 1);
835 long time = System.nanoTime();
836 complex.fft(data, 0, 2);
837 complex.ifft(data, 0, 2);
838 time = (System.nanoTime() - time);
839 System.out.printf("Sequential: %12.9f%n", toSeconds * time);
840 if (time < seqTime) {
841 seqTime = time;
842 }
843 }
844 System.out.printf("Best Sequential Time: %12.9f%n", toSeconds * seqTime);
845 }
846
847 }