1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.fft;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.IntegerSchedule;
42 import edu.rit.pj.ParallelRegion;
43 import edu.rit.pj.ParallelTeam;
44
45 import javax.annotation.Nullable;
46 import java.util.Random;
47 import java.util.logging.Level;
48 import java.util.logging.Logger;
49
50 import static java.util.Objects.requireNonNullElseGet;
51
52
53
54
55
56
57
58
59 public class Real3DParallel {
60
61 private static final Logger logger = Logger.getLogger(Real3DParallel.class.getName());
62 private final int nX, nY, nZ, nZ2, nX1;
63 private final int n, nextX, nextY, nextZ;
64 private final ParallelTeam parallelTeam;
65 private final int threadCount;
66 private final ParallelIFFT parallelIFFT;
67 private final ParallelFFT parallelFFT;
68 private final ParallelConvolution parallelConvolution;
69 private final double[] recip;
70 private final IntegerSchedule schedule;
71
72
73
74
75
76
77
78
79
80
/**
 * Create a parallel real 3D transform whose loops use the default fixed schedule.
 *
 * <p>This is a convenience overload: it delegates to the five-argument constructor with a
 * {@code null} schedule, which that constructor replaces with {@code IntegerSchedule.fixed()}.
 *
 * @param nX X dimension of the grid.
 * @param nY Y dimension of the grid.
 * @param nZ Z dimension of the grid.
 * @param parallelTeam team of threads used to execute the transforms.
 */
public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam) {
  // Single source of truth for field initialization lives in the full constructor.
  this(nX, nY, nZ, parallelTeam, null);
}
99
100
101
102
103
104
105
106
107
108
109
110 public Real3DParallel(int nX, int nY, int nZ, ParallelTeam parallelTeam, @Nullable IntegerSchedule integerSchedule) {
111 this.nX = nX / 2;
112 this.nY = nY;
113 this.nZ = nZ;
114 this.parallelTeam = parallelTeam;
115 n = nX;
116 nX1 = this.nX + 1;
117 nZ2 = this.nZ * 2;
118 nextX = 2;
119 nextY = n + 2;
120 nextZ = nextY * nY;
121 recip = new double[nX1 * nY * nZ];
122 threadCount = parallelTeam.getThreadCount();
123 schedule = requireNonNullElseGet(integerSchedule, IntegerSchedule::fixed);
124 parallelFFT = new ParallelFFT();
125 parallelIFFT = new ParallelIFFT();
126 parallelConvolution = new ParallelConvolution();
127 }
128
129
130
131
132
133
134
135 public static void main(String[] args) {
136 int dimNotFinal = 128;
137 int nCPU = ParallelTeam.getDefaultThreadCount();
138 int reps = 5;
139 try {
140 dimNotFinal = Integer.parseInt(args[0]);
141 if (dimNotFinal < 1) {
142 dimNotFinal = 100;
143 }
144 nCPU = Integer.parseInt(args[1]);
145 if (nCPU < 1) {
146 nCPU = ParallelTeam.getDefaultThreadCount();
147 }
148 reps = Integer.parseInt(args[2]);
149 if (reps < 1) {
150 reps = 5;
151 }
152 } catch (Exception e) {
153
154 }
155 if (dimNotFinal % 2 != 0) {
156 dimNotFinal++;
157 }
158 final int dim = dimNotFinal;
159 System.out.printf("Initializing a %d cubed grid for %d CPUs.\n"
160 + "The best timing out of %d repetitions will be used.%n", dim, nCPU, reps);
161
162 Real3D real3D = new Real3D(dim, dim, dim);
163 ParallelTeam parallelTeam = new ParallelTeam(nCPU);
164 Real3DParallel real3DParallel = new Real3DParallel(dim, dim, dim, parallelTeam);
165
166 final int dimCubed = (dim + 2) * dim * dim;
167 final double[] data = new double[dimCubed];
168 final double[] work = new double[dimCubed];
169
170
171 try {
172 parallelTeam.execute(
173 new ParallelRegion() {
174 @Override
175 public void run() {
176 try {
177 execute(
178 0,
179 dim - 1,
180 new IntegerForLoop() {
181 @Override
182 public void run(int lb, int ub) {
183 Random randomNumberGenerator = new Random(1);
184 int index = dim * dim * lb;
185 for (int z = lb; z <= ub; z++) {
186 for (int y = 0; y < dim; y++) {
187 for (int x = 0; x < dim; x++) {
188 double randomNumber = randomNumberGenerator.nextDouble();
189 data[index] = randomNumber;
190 index++;
191 }
192 }
193 }
194 }
195 });
196 } catch (Exception e) {
197 System.out.println(e.getMessage());
198 System.exit(-1);
199 }
200 }
201 });
202 } catch (Exception e) {
203 System.out.println(e.getMessage());
204 System.exit(-1);
205 }
206
207 double toSeconds = 0.000000001;
208 long parTime = Long.MAX_VALUE;
209 long seqTime = Long.MAX_VALUE;
210 real3D.setRecip(work);
211 real3DParallel.setRecip(work);
212 for (int i = 0; i < reps; i++) {
213 System.out.printf("Iteration %d%n", i + 1);
214 long time = System.nanoTime();
215 real3D.fft(data);
216 real3D.ifft(data);
217 time = (System.nanoTime() - time);
218 System.out.printf("Sequential: %8.3f%n", toSeconds * time);
219 if (time < seqTime) {
220 seqTime = time;
221 }
222 time = System.nanoTime();
223 real3D.convolution(data);
224 time = (System.nanoTime() - time);
225 System.out.printf("Sequential: %8.3f (Convolution)%n", toSeconds * time);
226 if (time < seqTime) {
227 seqTime = time;
228 }
229 time = System.nanoTime();
230 real3DParallel.fft(data);
231 real3DParallel.ifft(data);
232 time = (System.nanoTime() - time);
233 System.out.printf("Parallel: %8.3f%n", toSeconds * time);
234 if (time < parTime) {
235 parTime = time;
236 }
237 time = System.nanoTime();
238 real3DParallel.convolution(data);
239 time = (System.nanoTime() - time);
240 System.out.printf("Parallel: %8.3f (Convolution)\n%n", toSeconds * time);
241 if (time < parTime) {
242 parTime = time;
243 }
244 }
245 System.out.printf("Best Sequential Time: %8.3f%n", toSeconds * seqTime);
246 System.out.printf("Best Parallel Time: %8.3f%n", toSeconds * parTime);
247 System.out.printf("Speedup: %15.5f%n", (double) seqTime / parTime);
248 }
249
250
251
252
253
254
255
256 public void convolution(final double[] input) {
257 parallelConvolution.input = input;
258 try {
259 parallelTeam.execute(parallelConvolution);
260 } catch (Exception e) {
261 String message = "Fatal exception evaluating a 3D convolution in parallel.\n";
262 logger.log(Level.SEVERE, message, e);
263 System.exit(-1);
264 }
265 }
266
267
268
269
270
271
272
273 public void fft(final double[] input) {
274 parallelFFT.input = input;
275 try {
276 parallelTeam.execute(parallelFFT);
277 } catch (Exception e) {
278 String message = "Fatal exception evaluating real 3D FFT in parallel.\n";
279 logger.log(Level.SEVERE, message, e);
280 System.exit(-1);
281 }
282 }
283
284
285
286
287
288
289
290 public void ifft(final double[] input) {
291 parallelIFFT.input = input;
292 try {
293 parallelTeam.execute(parallelIFFT);
294 } catch (Exception e) {
295 String message = "Fatal exception evaluating real 3D inverse FFT in parallel.\n";
296 logger.log(Level.SEVERE, message, e);
297 System.exit(-1);
298 }
299 }
300
301
302
303
304
305
306 public void setRecip(double[] recip) {
307
308 for (int index = 0, offset = 0, y = 0; y < nY; y++) {
309 for (int x = 0; x < nX1; x++, offset += 1) {
310 for (int i = 0, z = offset; i < nZ; i++, z += nX1 * nY) {
311 this.recip[index++] = recip[z];
312 }
313 }
314 }
315 }
316
317
318
319
320
321
322
323 private class ParallelFFT extends ParallelRegion {
324
325 private final int nZm1;
326 private final FFTXYLoop[] fftXYLoop;
327 private final FFTZLoop[] fftZLoop;
328 public double[] input;
329
330 private ParallelFFT() {
331 nZm1 = nZ - 1;
332 fftXYLoop = new FFTXYLoop[threadCount];
333 fftZLoop = new FFTZLoop[threadCount];
334 for (int i = 0; i < threadCount; i++) {
335 fftXYLoop[i] = new FFTXYLoop();
336 fftZLoop[i] = new FFTZLoop();
337 }
338 }
339
340 @Override
341 public void run() {
342 int threadIndex = getThreadIndex();
343 fftXYLoop[threadIndex].input = input;
344 fftZLoop[threadIndex].input = input;
345 try {
346 execute(0, nZm1, fftXYLoop[threadIndex]);
347
348 execute(0, nX, fftZLoop[threadIndex]);
349 } catch (Exception e) {
350 logger.severe(e.toString());
351 }
352 }
353
354 private class FFTXYLoop extends IntegerForLoop {
355
356 private final Real fftX;
357 private final Complex fftY;
358 public double[] input;
359
360 private FFTXYLoop() {
361 fftY = new Complex(nY);
362 fftX = new Real(n);
363 }
364
365 @Override
366 public void run(final int lb, final int ub) {
367 for (int z = lb; z <= ub; z++) {
368 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
369 fftX.fft(input, offset);
370 }
371 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
372 fftY.fft(input, offset, nextY);
373 }
374 }
375 }
376
377 @Override
378 public IntegerSchedule schedule() {
379 return schedule;
380 }
381 }
382
383 private class FFTZLoop extends IntegerForLoop {
384
385 private final double[] work;
386 private final Complex fft;
387 public double[] input;
388
389 private FFTZLoop() {
390 work = new double[nZ2];
391 fft = new Complex(nZ);
392 }
393
394 @Override
395 public void run(final int lb, final int ub) {
396 for (int x = lb; x <= ub; x++) {
397 for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
398 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
399 work[i] = input[z];
400 work[i + 1] = input[z + 1];
401 }
402 fft.fft(work, 0, 2);
403 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
404 input[z] = work[i];
405 input[z + 1] = work[i + 1];
406 }
407 }
408 }
409 }
410
411 @Override
412 public IntegerSchedule schedule() {
413 return schedule;
414 }
415 }
416 }
417
418
419
420
421
422
423
424 private class ParallelIFFT extends ParallelRegion {
425
426 private final int nZm1;
427 private final IFFTXYLoop[] ifftXYLoop;
428 private final IFFTZLoop[] ifftZLoop;
429 public double[] input;
430
431 private ParallelIFFT() {
432 nZm1 = nZ - 1;
433 ifftXYLoop = new IFFTXYLoop[threadCount];
434 ifftZLoop = new IFFTZLoop[threadCount];
435 for (int i = 0; i < threadCount; i++) {
436 ifftXYLoop[i] = new IFFTXYLoop();
437 ifftZLoop[i] = new IFFTZLoop();
438 }
439 }
440
441 @Override
442 public void run() {
443 int threadIndex = getThreadIndex();
444 ifftXYLoop[threadIndex].input = input;
445 ifftZLoop[threadIndex].input = input;
446 try {
447
448 execute(0, nX, ifftZLoop[threadIndex]);
449 execute(0, nZm1, ifftXYLoop[threadIndex]);
450 } catch (Exception e) {
451 logger.severe(e.toString());
452 }
453 }
454
455 private class IFFTZLoop extends IntegerForLoop {
456
457 private final double[] work;
458 private final Complex fft;
459 public double[] input;
460
461 private IFFTZLoop() {
462 fft = new Complex(nZ);
463 work = new double[nZ2];
464 }
465
466 @Override
467 public void run(final int lb, final int ub) {
468 for (int x = lb; x <= ub; x++) {
469 for (int offset = x * 2, y = 0; y < nY; y++, offset += nextY) {
470 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
471 work[i] = input[z];
472 work[i + 1] = input[z + 1];
473 }
474 fft.ifft(work, 0, 2);
475 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
476 input[z] = work[i];
477 input[z + 1] = work[i + 1];
478 }
479 }
480 }
481 }
482
483 @Override
484 public IntegerSchedule schedule() {
485 return schedule;
486 }
487 }
488
489 private class IFFTXYLoop extends IntegerForLoop {
490
491 private final Real fftX;
492 private final Complex fftY;
493 public double[] input;
494
495 private IFFTXYLoop() {
496 fftX = new Real(n);
497 fftY = new Complex(nY);
498 }
499
500 @Override
501 public void run(final int lb, final int ub) {
502 for (int z = lb; z <= ub; z++) {
503 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
504 fftY.ifft(input, offset, nextY);
505 }
506 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
507 fftX.ifft(input, offset);
508 }
509 }
510 }
511
512 @Override
513 public IntegerSchedule schedule() {
514 return schedule;
515 }
516 }
517 }
518
519
520
521
522
523
524
525 private class ParallelConvolution extends ParallelRegion {
526
527 private final int nZm1, nYm1, nX1nZ;
528 private final FFTXYLoop[] fftXYLoop;
529 private final FFTZ_Multiply_IFFTZLoop[] fftZ_Multiply_ifftZLoop;
530 private final IFFTXYLoop[] ifftXYLoop;
531 public double[] input;
532
533 private ParallelConvolution() {
534 nZm1 = nZ - 1;
535 nYm1 = nY - 1;
536 nX1nZ = nX1 * nZ;
537 fftXYLoop = new FFTXYLoop[threadCount];
538 fftZ_Multiply_ifftZLoop = new FFTZ_Multiply_IFFTZLoop[threadCount];
539 ifftXYLoop = new IFFTXYLoop[threadCount];
540 for (int i = 0; i < threadCount; i++) {
541 fftXYLoop[i] = new FFTXYLoop();
542 fftZ_Multiply_ifftZLoop[i] = new FFTZ_Multiply_IFFTZLoop();
543 ifftXYLoop[i] = new IFFTXYLoop();
544 }
545 }
546
547 @Override
548 public void run() {
549 int threadIndex = getThreadIndex();
550 fftXYLoop[threadIndex].input = input;
551 fftZ_Multiply_ifftZLoop[threadIndex].input = input;
552 ifftXYLoop[threadIndex].input = input;
553 try {
554 execute(0, nZm1, fftXYLoop[threadIndex]);
555 execute(0, nYm1, fftZ_Multiply_ifftZLoop[threadIndex]);
556 execute(0, nZm1, ifftXYLoop[threadIndex]);
557 } catch (Exception e) {
558 logger.severe(e.toString());
559 }
560 }
561
562 private class FFTXYLoop extends IntegerForLoop {
563
564 private final Real fftX;
565 private final Complex fftY;
566 public double[] input;
567
568 private FFTXYLoop() {
569 fftY = new Complex(nY);
570 fftX = new Real(n);
571 }
572
573 @Override
574 public void run(final int lb, final int ub) {
575 for (int z = lb; z <= ub; z++) {
576 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
577 fftX.fft(input, offset);
578 }
579 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
580 fftY.fft(input, offset, nextY);
581 }
582 }
583 }
584
585 @Override
586 public IntegerSchedule schedule() {
587 return schedule;
588 }
589 }
590
591 private class FFTZ_Multiply_IFFTZLoop extends IntegerForLoop {
592
593 private final double[] work;
594 private final Complex fft;
595 public double[] input;
596
597 private FFTZ_Multiply_IFFTZLoop() {
598 work = new double[nZ2];
599 fft = new Complex(nZ);
600 }
601
602 @Override
603 public void run(final int lb, final int ub) {
604 int index = lb * nX1nZ;
605 for (int offset = lb * nextY, y = lb; y <= ub; y++) {
606 for (int x = 0; x < nX1; x++, offset += nextX) {
607 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
608 work[i] = input[z];
609 work[i + 1] = input[z + 1];
610 }
611 fft.fft(work, 0, 2);
612 for (int i = 0; i < nZ2; i += 2) {
613 double r = recip[index++];
614 work[i] *= r;
615 work[i + 1] *= r;
616 }
617 fft.ifft(work, 0, 2);
618 for (int z = offset, i = 0; i < nZ2; i += 2, z += nextZ) {
619 input[z] = work[i];
620 input[z + 1] = work[i + 1];
621 }
622 }
623 }
624 }
625
626 @Override
627 public IntegerSchedule schedule() {
628 return schedule;
629 }
630 }
631
632 private class IFFTXYLoop extends IntegerForLoop {
633
634 private final Real fftX;
635 private final Complex fftY;
636 public double[] input;
637
638 private IFFTXYLoop() {
639 fftY = new Complex(nY);
640 fftX = new Real(n);
641 }
642
643 @Override
644 public void run(final int lb, final int ub) {
645 for (int z = lb; z <= ub; z++) {
646 for (int offset = z * nextZ, x = 0; x < nX1; x++, offset += nextX) {
647 fftY.ifft(input, offset, nextY);
648 }
649 for (int offset = z * nextZ, y = 0; y < nY; y++, offset += nextY) {
650 fftX.ifft(input, offset);
651 }
652 }
653 }
654
655 @Override
656 public IntegerSchedule schedule() {
657 return schedule;
658 }
659 }
660 }
661 }