1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray.parallel;
39
40 import edu.rit.pj.IntegerSchedule;
41 import edu.rit.util.Range;
42
43 import static java.lang.System.arraycopy;
44 import static java.util.Arrays.fill;
45 import static org.apache.commons.math3.util.FastMath.min;
46
47
48
49
50
51
52
53 public class SliceSchedule extends IntegerSchedule {
54
55 private final int fftZ;
56 private final int[] lowerBounds;
57 private int nThreads;
58 private boolean[] threadDone;
59 private Range[] ranges;
60 private int[] weights;
61
62
63
64
65
66
67
68 public SliceSchedule(int nThreads, int fftZ) {
69 this.nThreads = nThreads;
70 this.fftZ = fftZ;
71 int length = min(nThreads, fftZ);
72 threadDone = new boolean[length];
73 ranges = new Range[length];
74 lowerBounds = new int[length + 1];
75 }
76
77
78
79
80 @Override
81 public boolean isFixedSchedule() {
82 return true;
83 }
84
85
86
87
88 @Override
89 public Range next(int threadID) {
90 if (threadID >= min(fftZ, nThreads)) {
91 return null;
92 }
93 if (!threadDone[threadID]) {
94 threadDone[threadID] = true;
95 return ranges[threadID];
96 }
97 return null;
98 }
99
100
101
102
103 @Override
104 public void start(int nThreads, Range chunkRange) {
105 this.nThreads = nThreads;
106 int length = min(nThreads, fftZ);
107
108 if (length != threadDone.length) {
109 threadDone = new boolean[length];
110 }
111 fill(threadDone, false);
112
113 if (length != ranges.length) {
114 ranges = new Range[length];
115 }
116 fill(lowerBounds, 0);
117 defineRanges();
118 }
119
120
121
122
123
124
125 public void updateWeights(int[] weights) {
126 this.weights = weights;
127 }
128
129 private int totalWeight() {
130 int totalWeight = 0;
131 for (int i = 0; i < fftZ; i++) {
132 totalWeight += weights[i];
133 }
134 return totalWeight;
135 }
136
137 private void defineRanges() {
138 double totalWeight = totalWeight();
139
140 int length = min(nThreads, fftZ);
141
142
143 if (totalWeight <= length) {
144 Range temp = new Range(0, fftZ - 1);
145 ranges = temp.subranges(length);
146 return;
147 }
148
149
150 if (nThreads == 1) {
151 ranges[0] = new Range(0, fftZ - 1);
152 return;
153 }
154
155 double targetWeight = (totalWeight / nThreads) * .96;
156 int lastSlice = fftZ - 1;
157
158 int currentSlice = 0;
159 lowerBounds[0] = 0;
160 int currentThread = 0;
161 while (currentThread < length) {
162 int threadWeight = 0;
163 while (threadWeight < targetWeight && currentSlice < lastSlice) {
164 threadWeight += weights[currentSlice];
165 currentSlice++;
166 }
167 currentThread++;
168 if (currentSlice < lastSlice) {
169 lowerBounds[currentThread] = currentSlice;
170 } else {
171 lowerBounds[currentThread] = lastSlice;
172 break;
173 }
174 }
175
176 int lastThread = currentThread;
177
178
179 for (currentThread = 0; currentThread < lastThread - 1; currentThread++) {
180 ranges[currentThread] =
181 new Range(lowerBounds[currentThread], lowerBounds[currentThread + 1] - 1);
182 }
183
184
185 ranges[lastThread - 1] = new Range(lowerBounds[lastThread - 1], lastSlice);
186
187
188 for (int it = lastThread; it < length; it++) {
189 ranges[it] = null;
190 }
191 }
192
193
194
195
196
197
198 public int[] getThreadWeights() {
199 int length = min(fftZ, nThreads);
200 int[] weightsToReturn = new int[length];
201 arraycopy(weights, 0, weightsToReturn, 0, length);
202 return weightsToReturn;
203 }
204
205
206
207
208
209
210 public int[] getLowerBounds() {
211 int length = min(fftZ, nThreads);
212 int[] boundsToReturn = new int[length];
213 arraycopy(lowerBounds, 1, boundsToReturn, 0, length);
214 return boundsToReturn;
215 }
216 }