1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.numerics.atomic;
39
40 import static ffx.numerics.atomic.AtomicDoubleArray.atomicDoubleArrayFactory;
41 import static java.lang.String.format;
42
43 import edu.rit.pj.IntegerForLoop;
44 import edu.rit.pj.ParallelRegion;
45 import edu.rit.pj.ParallelTeam;
46 import ffx.numerics.atomic.AtomicDoubleArray.AtomicDoubleArrayImpl;
47 import ffx.numerics.math.Double3;
48 import ffx.numerics.math.DoubleMath;
49 import java.util.Objects;
50
51
52
53
54
55
56
57 public class AtomicDoubleArray3D {
58
59
60
61
62 private final AtomicDoubleArray[] atomicDoubleArray;
63
64
65
66
67 private final AtomicDoubleArrayImpl atomicDoubleArrayImpl;
68
69
70
71
72 private final ParallelRegion3D parallelRegion3D = new ParallelRegion3D();
73
74
75
76
77
78
79
80 public AtomicDoubleArray3D(AtomicDoubleArrayImpl atomicDoubleArrayImpl, int size) {
81 this(atomicDoubleArrayImpl, size, ParallelTeam.getDefaultThreadCount());
82 }
83
84
85
86
87
88
89
90
91
92 public AtomicDoubleArray3D(AtomicDoubleArrayImpl atomicDoubleArrayImpl, int size, int nThreads) {
93 atomicDoubleArray = new AtomicDoubleArray[3];
94 atomicDoubleArray[0] = atomicDoubleArrayFactory(atomicDoubleArrayImpl, nThreads, size);
95 atomicDoubleArray[1] = atomicDoubleArrayFactory(atomicDoubleArrayImpl, nThreads, size);
96 atomicDoubleArray[2] = atomicDoubleArrayFactory(atomicDoubleArrayImpl, nThreads, size);
97 this.atomicDoubleArrayImpl = atomicDoubleArrayImpl;
98 }
99
100 public AtomicDoubleArray3D(AtomicDoubleArray x, AtomicDoubleArray y, AtomicDoubleArray z) {
101 atomicDoubleArray = new AtomicDoubleArray[3];
102 atomicDoubleArray[0] = x;
103 atomicDoubleArray[1] = y;
104 atomicDoubleArray[2] = z;
105 if (x instanceof MultiDoubleArray) {
106 this.atomicDoubleArrayImpl = AtomicDoubleArrayImpl.MULTI;
107 } else if (x instanceof AdderDoubleArray) {
108 this.atomicDoubleArrayImpl = AtomicDoubleArrayImpl.ADDER;
109 } else {
110 this.atomicDoubleArrayImpl = AtomicDoubleArrayImpl.PJ;
111 }
112 }
113
114
115
116
117
118
119
120
121
122
123 public void add(int threadID, int index, double x, double y, double z) {
124 atomicDoubleArray[0].add(threadID, index, x);
125 atomicDoubleArray[1].add(threadID, index, y);
126 atomicDoubleArray[2].add(threadID, index, z);
127 }
128
129
130
131
132
133
134
135
136 public void add(int threadID, int index, Double3 d3) {
137 atomicDoubleArray[0].add(threadID, index, d3.x());
138 atomicDoubleArray[1].add(threadID, index, d3.y());
139 atomicDoubleArray[2].add(threadID, index, d3.z());
140 }
141
142
143
144
145
146
147 public void alloc(int size) {
148 atomicDoubleArray[0].alloc(size);
149 atomicDoubleArray[1].alloc(size);
150 atomicDoubleArray[2].alloc(size);
151 }
152
153
154
155
156
157
158
159
160 public double get(int dim, int index) {
161 return atomicDoubleArray[dim].get(index);
162 }
163
164
165
166
167
168
169
170
171 public Double3 get(int index) {
172 return new Double3(
173 atomicDoubleArray[0].get(index),
174 atomicDoubleArray[1].get(index),
175 atomicDoubleArray[2].get(index));
176 }
177
178
179
180
181
182
183
184
185 public double getX(int index) {
186 return atomicDoubleArray[0].get(index);
187 }
188
189
190
191
192
193
194
195
196 public double getY(int index) {
197 return atomicDoubleArray[1].get(index);
198 }
199
200
201
202
203
204
205
206
207 public double getZ(int index) {
208 return atomicDoubleArray[2].get(index);
209 }
210
211
212
213
214
215
216
217 public void reduce(int lb, int ub) {
218
219 if (Objects.requireNonNull(atomicDoubleArrayImpl) == AtomicDoubleArrayImpl.MULTI) {
220 atomicDoubleArray[0].reduce(lb, ub);
221 atomicDoubleArray[1].reduce(lb, ub);
222 atomicDoubleArray[2].reduce(lb, ub);
223 }
224 }
225
226
227
228
229
230
231 public void reduce(ParallelTeam parallelTeam) {
232 if (Objects.requireNonNull(atomicDoubleArrayImpl) == AtomicDoubleArrayImpl.MULTI) {
233 parallelRegion3D.setOperation(Operation.REDUCE);
234 try {
235 parallelTeam.execute(parallelRegion3D);
236 } catch (Exception e) {
237 throw new RuntimeException(e);
238 }
239 }
240 }
241
242
243
244
245
246
247
248
249 public void reset(int threadID, int lb, int ub) {
250 atomicDoubleArray[0].reset(threadID, lb, ub);
251 atomicDoubleArray[1].reset(threadID, lb, ub);
252 atomicDoubleArray[2].reset(threadID, lb, ub);
253 }
254
255
256
257
258
259
260 public void reset(ParallelTeam parallelTeam) {
261 parallelRegion3D.setOperation(Operation.RESET);
262 try {
263 parallelTeam.execute(parallelRegion3D);
264 } catch (Exception e) {
265 throw new RuntimeException(e);
266 }
267 }
268
269
270
271
272
273
274
275
276 public void scale(int threadID, int index, double scale) {
277 atomicDoubleArray[0].scale(threadID, index, scale);
278 atomicDoubleArray[1].scale(threadID, index, scale);
279 atomicDoubleArray[2].scale(threadID, index, scale);
280 }
281
282
283
284
285
286
287
288
289
290
291 public void set(int threadID, int index, double x, double y, double z) {
292 atomicDoubleArray[0].set(threadID, index, x);
293 atomicDoubleArray[1].set(threadID, index, y);
294 atomicDoubleArray[2].set(threadID, index, z);
295 }
296
297
298
299
300
301
302
303
304 public void set(int threadID, int index, Double3 d3) {
305 atomicDoubleArray[0].set(threadID, index, d3.x());
306 atomicDoubleArray[1].set(threadID, index, d3.y());
307 atomicDoubleArray[2].set(threadID, index, d3.z());
308 }
309
310
311
312
313
314
315
316
317
318
319 public void sub(int threadID, int index, double x, double y, double z) {
320 atomicDoubleArray[0].sub(threadID, index, x);
321 atomicDoubleArray[1].sub(threadID, index, y);
322 atomicDoubleArray[2].sub(threadID, index, z);
323 }
324
325
326
327
328
329
330
331
332 public void sub(int threadID, int index, Double3 d3) {
333 atomicDoubleArray[0].sub(threadID, index, d3.x());
334 atomicDoubleArray[1].sub(threadID, index, d3.y());
335 atomicDoubleArray[2].sub(threadID, index, d3.z());
336 }
337
338
339
340
341
342
343
344 public String toString(int index) {
345 String defaultLabel = " " + index + ": ";
346 return toString(index, defaultLabel);
347 }
348
349
350
351
352
353
354
355
356 public String toString(int index, String label) {
357 var d = new double[] {getX(index), getY(index), getZ(index)};
358 return DoubleMath.toString(d, label);
359 }
360
361
362
363
364
365
366 public String toString() {
367 StringBuilder sb = new StringBuilder();
368 for (int i = 0; i < atomicDoubleArray.length; i++) {
369 sb.append(toString(i)).append("\n");
370 }
371 return sb.toString();
372 }
373
374
375
376
377
378
379
380
381 public String toString(String label) {
382 StringBuilder sb = new StringBuilder();
383 if (label.contains("%d")) {
384 for (int i = 0; i < atomicDoubleArray[0].size(); i++) {
385 sb.append(toString(i, format(label, i))).append("\n");
386 }
387 } else {
388 for (int i = 0; i < atomicDoubleArray[0].size(); i++) {
389 sb.append(toString(i, label)).append("\n");
390 }
391 }
392 return sb.toString();
393 }
394
395
396
397
398 private enum Operation {
399 RESET,
400 REDUCE
401 }
402
403
404
405
406 private class ParallelRegion3D extends ParallelRegion {
407
408 private Operation operation;
409 private IntegerForLoop3D[] integerForLoop3D;
410
411
412
413
414 public ParallelRegion3D() {
415 operation = Operation.RESET;
416 }
417
418
419
420
421
422
423 public void setOperation(Operation operation) {
424 this.operation = operation;
425 }
426
427
428 @Override
429 public void start() {
430 int nThreads = getThreadCount();
431 if (integerForLoop3D == null) {
432 integerForLoop3D = new IntegerForLoop3D[nThreads];
433 }
434 }
435
436
437 @Override
438 public void run() throws Exception {
439 int threadID = getThreadIndex();
440 if (integerForLoop3D[threadID] == null) {
441 integerForLoop3D[threadID] = new IntegerForLoop3D();
442 }
443 int size = atomicDoubleArray[0].size();
444 execute(0, size - 1, integerForLoop3D[threadID]);
445 }
446
447
448
449
450 private class IntegerForLoop3D extends IntegerForLoop {
451
452
453 @Override
454 public void run(int lb, int ub) {
455 int threadID = getThreadIndex();
456 switch (operation) {
457 case RESET -> reset(threadID, lb, ub);
458 case REDUCE -> reduce(lb, ub);
459 }
460 }
461 }
462 }
463
464
465 }