1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.IntegerSchedule;
42 import edu.rit.pj.ParallelRegion;
43 import edu.rit.pj.ParallelTeam;
44
45 import javax.annotation.Nullable;
46 import java.util.Random;
47 import java.util.logging.Level;
48 import java.util.logging.Logger;
49
50 import static java.util.Objects.requireNonNullElseGet;
51
52
53
54
55
56
57
58
59
60
61 public class Real3DParallel {
62
63 private static final Logger logger = Logger.getLogger(Real3DParallel.class.getName());
64 private final int nX, nY, nZ, nZ2, nX1;
65 private final int n, nextX, nextY, nextZ;
66 private final ParallelTeam parallelTeam;
67 private final int threadCount;
68 private final ParallelIFFT parallelIFFT;
69 private final ParallelFFT parallelFFT;
70 private final ParallelConvolution parallelConvolution;
71 private final double[] recip;
72 private final IntegerSchedule schedule;
73
74
75
76
77
78
79
80
81
82
/**
 * Create a parallel real 3D FFT that uses the default fixed loop schedule.
 *
 * <p>This constructor delegates to the five-argument constructor with a {@code null}
 * schedule; that constructor substitutes {@link IntegerSchedule#fixed()} for {@code null}.
 * Delegating keeps the dimension and stride set-up in a single place.
 *
 * @param nX X-dimension of the grid. It is halved with integer division internally.
 *     NOTE(review): odd values look unsupported because the loops step rows by
 *     {@code nX + 2} while covering {@code 2 * (nX / 2 + 1)} entries per row; confirm.
 * @param nY Y-dimension of the grid.
 * @param nZ Z-dimension of the grid.
 * @param parallelTeam Team of threads that will execute the transforms; must not be null
 *     because its thread count is queried during construction.
 */
public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
  this(nX, nY, nZ, parallelTeam, null);
}
101
102
103
104
105
106
107
108
109
110
111
112 public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
113 this.nX = nX / 2;
114 this.nY = nY;
115 this.nZ = nZ;
116 this.parallelTeam = parallelTeam;
117 n = nX;
118 nX1 = this.nX + 1;
119 nZ2 = this.nZ * 2;
120 nextX = 2;
121 nextY = n + 2;
122 nextZ = nextY * nY;
123 recip = new double[nX1 * nY * nZ];
124 threadCount = parallelTeam.getThreadCount();
125 schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
126 parallelFFT = new ParallelFFT();
127 parallelIFFT = new ParallelIFFT();
128 parallelConvolution = new ParallelConvolution();
129 }
130
131
132
133
134
135
136
137 public static void main(String[] args) {
138 int dimNotFinal = 128;
139 int nCPU = ParallelTeam.getDefaultThreadCount();
140 int reps = 5;
141 try {
142 dimNotFinal = Integer.parseInt(args[0]);
143 if (dimNotFinal < 1) {
144 dimNotFinal = 100;
145 }
146 nCPU = Integer.parseInt(args[1]);
147 if (nCPU < 1) {
148 nCPU = ParallelTeam.getDefaultThreadCount();
149 }
150 reps = Integer.parseInt(args[2]);
151 if (reps < 1) {
152 reps = 5;
153 }
154 } catch (Exception e) {
155
156 }
157 if (dimNotFinal % 2 != 0) {
158 dimNotFinal++;
159 }
160 final int dim = dimNotFinal;
161 System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
162 + "The best timing out of %d repetitions will be used.%n", dim, nCPU, reps);
163
164 Real3D real3D = new Real3D(dim, dim, dim);
165 ParallelTeam parallelTeam = new ParallelTeam(nCPU);
166 Real3DParallel real3DParallel = new Real3DParallel(dim, dim, dim, parallelTeam);
167
168 final int dimCubed = (dim + 2) * dim * dim;
169 final double[] data = new double[dimCubed];
170 final double[] work = new double[dimCubed];
171
172
173 try {
174 parallelTeam.execute(
175 new ParallelRegion() {
176 @Override
177 public void run() {
178 try {
179 execute(
180 0,
181 dim - 1,
182 new IntegerForLoop() {
183 @Override
184 public void run(int lb, int ub) {
185 Random randomNumberGenerator = new Random(1);
186 int index = dim * dim * lb;
187 for (int z = lb; z <= ub; z++) {
188 for (int y = 0; y < dim; y++) {
189 for (int x = 0; x < dim; x++) {
190 double randomNumber = randomNumberGenerator.nextDouble();
191 data[index] = randomNumber;
192 index++;
193 }
194 }
195 }
196 }
197 });
198 } catch (Exception e) {
199 System.out.println(e.getMessage());
200 System.exit(-1);
201 }
202 }
203 });
204 } catch (Exception e) {
205 System.out.println(e.getMessage());
206 System.exit(-1);
207 }
208
209 double toSeconds = 0.000000001;
210 long parTime = Long.MAX_VALUE;
211 long seqTime = Long.MAX_VALUE;
212 real3D.setRecip(work);
213 real3DParallel.setRecip(work);
214 for (int i = 0; i < reps; i++) {
215 System.out.printf("Iteration %d%n", i + 1);
216 long time = System.nanoTime();
217 real3D.fft(data);
218 real3D.ifft(data);
219 time = (System.nanoTime() - time);
220 System.out.printf("Sequential: %8.3f%n", toSeconds * time);
221 if (time < seqTime) {
222 seqTime = time;
223 }
224 time = System.nanoTime();
225 real3D.convolution(data);
226 time = (System.nanoTime() - time);
227 System.out.printf("Sequential: %8.3f (Convolution)%n", toSeconds * time);
228 if (time < seqTime) {
229 seqTime = time;
230 }
231 time = System.nanoTime();
232 real3DParallel.fft(data);
233 real3DParallel.ifft(data);
234 time = (System.nanoTime() - time);
235 System.out.printf("Parallel: %8.3f%n", toSeconds * time);
236 if (time < parTime) {
237 parTime = time;
238 }
239 time = System.nanoTime();
240 real3DParallel.convolution(data);
241 time = (System.nanoTime() - time);
242 System.out.printf("Parallel: %8.3f (Convolution)\n%n", toSeconds * time);
243 if (time < parTime) {
244 parTime = time;
245 }
246 }
247 System.out.printf("Best Sequential Time: %8.3f%n", toSeconds * seqTime);
248 System.out.printf("Best Parallel Time: %8.3f%n", toSeconds * parTime);
249 System.out.printf("Speedup: %15.5f%n", (double) seqTime / parTime);
250 }
251
252
253
254
255
256
257
258 public void convolution(final double[] input) {
259 parallelConvolution.input = input;
260 try {
261 parallelTeam.execute(parallelConvolution);
262 } catch (Exception e) {
263 String message = "Fatal exception evaluating a 3D convolution in parallel.\n";
264 logger.log(Level.SEVERE, message, e);
265 System.exit(-1);
266 }
267 }
268
269
270
271
272
273
274
275 public void fft(final double[] input) {
276 parallelFFT.input = input;
277 try {
278 parallelTeam.execute(parallelFFT);
279 } catch (Exception e) {
280 String message = "Fatal exception evaluating real 3D FFT in parallel.\n";
281 logger.log(Level.SEVERE, message, e);
282 System.exit(-1);
283 }
284 }
285
286
287
288
289
290
291
292 public void ifft(final double[] input) {
293 parallelIFFT.input = input;
294 try {
295 parallelTeam.execute(parallelIFFT);
296 } catch (Exception e) {
297 String message = "Fatal exception evaluating real 3D inverse FFT in parallel.\n";
298 logger.log(Level.SEVERE, message, e);
299 System.exit(-1);
300 }
301 }
302
303
304
305
306
307
308 public void setRecip(double[] recip) {
309
310 for (int index = 0, offset = 0, y = 0; y < nY; y++) {
311 for (int x = 0; x < nX1; x++, offset += 1) {
312 for (int i = 0, z = offset; i < nZ; i++, z += nX1 * nY) {
313 this.recip[index++] = recip[z];
314 }
315 }
316 }
317 }
318
319
320
321
322
323
324
325 private class ParallelFFT extends ParallelRegion {
326
327 private final int nZm1;
328 private final FFTXYLoop[] fftXYLoop;
329 private final FFTZLoop[] fftZLoop;
330 public double[] input;
331
332 private ParallelFFT() {
333 nZm1 = nZ - 1;
334 fftXYLoop = new FFTXYLoop[threadCount];
335 fftZLoop = new FFTZLoop[threadCount];
336 for (int i = 0; i < threadCount; i++) {
337 fftXYLoop[i] = new FFTXYLoop();
338 fftZLoop[i] = new FFTZLoop();
339 }
340 }
341
342 @Override
343 public void run() {
344 int threadIndex = getThreadIndex();
345 fftXYLoop[threadIndex].input = input;
346 fftZLoop[threadIndex].input = input;
347 try {
348 execute(0, nZm1, fftXYLoop[threadIndex]);
349
350 execute(0, nX, fftZLoop[threadIndex]);
351 } catch (Exception e) {
352 logger.severe(e.toString());
353 }
354 }
355
356 private class FFTXYLoop extends IntegerForLoop {
357
358 private final Real fftX;
359 private final Complex fftY;
360 public double[] input;
361
362 private FFTXYLoop() {
363 fftY = new Complex(nY);
364 fftX = new Real(n);
365 }
366
367 @Override
368 public void run(final int lb, final int ub) {
369 for (int z = lb; z <= ub; z++) {
370 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
371 fftX.fft(input, offset);
372 }
373 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
374 fftY.fft(input, offset, nextY);
375 }
376 }
377 }
378
379 @Override
380 public IntegerSchedule schedule() {
381 return schedule;
382 }
383 }
384
385 private class FFTZLoop extends IntegerForLoop {
386
387 private final double[] work;
388 private final Complex fft;
389 public double[] input;
390
391 private FFTZLoop() {
392 work = new double[nZ2];
393 fft = new Complex(nZ);
394 }
395
396 @Override
397 public void run(final int lb, final int ub) {
398 for (int x = lb; x <= ub; x++) {
399 for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
400 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
401 work[i] = input[z];
402 work[i + 1] = input[z + 1];
403 }
404 fft.fft(work, 0, 2);
405 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
406 input[z] = work[i];
407 input[z + 1] = work[i + 1];
408 }
409 }
410 }
411 }
412
413 @Override
414 public IntegerSchedule schedule() {
415 return schedule;
416 }
417 }
418 }
419
420
421
422
423
424
425
426 private class ParallelIFFT extends ParallelRegion {
427
428 private final int nZm1;
429 private final IFFTXYLoop[] ifftXYLoop;
430 private final IFFTZLoop[] ifftZLoop;
431 public double[] input;
432
433 private ParallelIFFT() {
434 nZm1 = nZ - 1;
435 ifftXYLoop = new IFFTXYLoop[threadCount];
436 ifftZLoop = new IFFTZLoop[threadCount];
437 for (int i = 0; i < threadCount; i++) {
438 ifftXYLoop[i] = new IFFTXYLoop();
439 ifftZLoop[i] = new IFFTZLoop();
440 }
441 }
442
443 @Override
444 public void run() {
445 int threadIndex = getThreadIndex();
446 ifftXYLoop[threadIndex].input = input;
447 ifftZLoop[threadIndex].input = input;
448 try {
449
450 execute(0, nX, ifftZLoop[threadIndex]);
451 execute(0, nZm1, ifftXYLoop[threadIndex]);
452 } catch (Exception e) {
453 logger.severe(e.toString());
454 }
455 }
456
457 private class IFFTZLoop extends IntegerForLoop {
458
459 private final double[] work;
460 private final Complex fft;
461 public double[] input;
462
463 private IFFTZLoop() {
464 fft = new Complex(nZ);
465 work = new double[nZ2];
466 }
467
468 @Override
469 public void run(final int lb, final int ub) {
470 for (int x = lb; x <= ub; x++) {
471 for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
472 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
473 work[i] = input[z];
474 work[i + 1] = input[z + 1];
475 }
476 fft.ifft(work, 0, 2);
477 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
478 input[z] = work[i];
479 input[z + 1] = work[i + 1];
480 }
481 }
482 }
483 }
484
485 @Override
486 public IntegerSchedule schedule() {
487 return schedule;
488 }
489 }
490
491 private class IFFTXYLoop extends IntegerForLoop {
492
493 private final Real fftX;
494 private final Complex fftY;
495 public double[] input;
496
497 private IFFTXYLoop() {
498 fftX = new Real(n);
499 fftY = new Complex(nY);
500 }
501
502 @Override
503 public void run(final int lb, final int ub) {
504 for (int z = lb; z <= ub; z++) {
505 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
506 fftY.ifft(input, offset, nextY);
507 }
508 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
509 fftX.ifft(input, offset);
510 }
511 }
512 }
513
514 @Override
515 public IntegerSchedule schedule() {
516 return schedule;
517 }
518 }
519 }
520
521
522
523
524
525
526
527 private class ParallelConvolution extends ParallelRegion {
528
529 private final int nZm1, nYm1, nX1nZ;
530 private final FFTXYLoop[] fftXYLoop;
531 private final FFTZ_Multiply_IFFTZLoop[] fftZ_Multiply_ifftZLoop;
532 private final IFFTXYLoop[] ifftXYLoop;
533 public double[] input;
534
535 private ParallelConvolution() {
536 nZm1 = nZ - 1;
537 nYm1 = nY - 1;
538 nX1nZ = nX1 * nZ;
539 fftXYLoop = new FFTXYLoop[threadCount];
540 fftZ_Multiply_ifftZLoop = new FFTZ_Multiply_IFFTZLoop[threadCount];
541 ifftXYLoop = new IFFTXYLoop[threadCount];
542 for (int i = 0; i < threadCount; i++) {
543 fftXYLoop[i] = new FFTXYLoop();
544 fftZ_Multiply_ifftZLoop[i] = new FFTZ_Multiply_IFFTZLoop();
545 ifftXYLoop[i] = new IFFTXYLoop();
546 }
547 }
548
549 @Override
550 public void run() {
551 int threadIndex = getThreadIndex();
552 fftXYLoop[threadIndex].input = input;
553 fftZ_Multiply_ifftZLoop[threadIndex].input = input;
554 ifftXYLoop[threadIndex].input = input;
555 try {
556 execute(0, nZm1, fftXYLoop[threadIndex]);
557 execute(0, nYm1, fftZ_Multiply_ifftZLoop[threadIndex]);
558 execute(0, nZm1, ifftXYLoop[threadIndex]);
559 } catch (Exception e) {
560 logger.severe(e.toString());
561 }
562 }
563
564 private class FFTXYLoop extends IntegerForLoop {
565
566 private final Real fftX;
567 private final Complex fftY;
568 public double[] input;
569
570 private FFTXYLoop() {
571 fftY = new Complex(nY);
572 fftX = new Real(n);
573 }
574
575 @Override
576 public void run(final int lb, final int ub) {
577 for (int z = lb; z <= ub; z++) {
578 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
579 fftX.fft(input, offset);
580 }
581 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
582 fftY.fft(input, offset, nextY);
583 }
584 }
585 }
586
587 @Override
588 public IntegerSchedule schedule() {
589 return schedule;
590 }
591 }
592
593 private class FFTZ_Multiply_IFFTZLoop extends IntegerForLoop {
594
595 private final double[] work;
596 private final Complex fft;
597 public double[] input;
598
599 private FFTZ_Multiply_IFFTZLoop() {
600 work = new double[nZ2];
601 fft = new Complex(nZ);
602 }
603
604 @Override
605 public void run(final int lb, final int ub) {
606 int index = lb * nX1nZ;
607 for (int offset = lb * nextY, y = lb; y <= ub; y++) {
608 for (int x = 0; x < nX1; x++, offset += nextX) {
609 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
610 work[i] = input[z];
611 work[i + 1] = input[z + 1];
612 }
613 fft.fft(work, 0, 2);
614 for (int i = 0; i < nZ2; i += 2) {
615 double r = recip[index++];
616 work[i] *= r;
617 work[i + 1] *= r;
618 }
619 fft.ifft(work, 0, 2);
620 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
621 input[z] = work[i];
622 input[z + 1] = work[i + 1];
623 }
624 }
625 }
626 }
627
628 @Override
629 public IntegerSchedule schedule() {
630 return schedule;
631 }
632 }
633
634 private class IFFTXYLoop extends IntegerForLoop {
635
636 private final Real fftX;
637 private final Complex fftY;
638 public double[] input;
639
640 private IFFTXYLoop() {
641 fftY = new Complex(nY);
642 fftX = new Real(n);
643 }
644
645 @Override
646 public void run(final int lb, final int ub) {
647 for (int z = lb; z <= ub; z++) {
648 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
649 fftY.ifft(input, offset, nextY);
650 }
651 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
652 fftX.ifft(input, offset);
653 }
654 }
655 }
656
657 @Override
658 public IntegerSchedule schedule() {
659 return schedule;
660 }
661 }
662 }
663 }