1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import static java.lang.System.arraycopy;
41 import static java.util.Arrays.fill;
42 import static org.apache.commons.math3.util.FastMath.min;
43
44 import edu.rit.pj.IntegerSchedule;
45 import edu.rit.util.Range;
46
47
48
49
50
51
52
53 public class SliceSchedule extends IntegerSchedule {
54
55 private final int fftZ;
56 private final int[] lowerBounds;
57 private int nThreads;
58 private boolean[] threadDone;
59 private Range[] ranges;
60 private int[] weights;
61
62
63
64
65
66
67
68 protected SliceSchedule(int nThreads, int fftZ) {
69 this.nThreads = nThreads;
70 this.fftZ = fftZ;
71 int length = min(nThreads, fftZ);
72 threadDone = new boolean[length];
73 ranges = new Range[length];
74 lowerBounds = new int[length + 1];
75 }
76
77
78 @Override
79 public boolean isFixedSchedule() {
80 return true;
81 }
82
83
84 @Override
85 public Range next(int threadID) {
86 if (threadID >= min(fftZ, nThreads)) {
87 return null;
88 }
89 if (!threadDone[threadID]) {
90 threadDone[threadID] = true;
91 return ranges[threadID];
92 }
93 return null;
94 }
95
96
97 @Override
98 public void start(int nThreads, Range chunkRange) {
99 this.nThreads = nThreads;
100 int length = min(nThreads, fftZ);
101
102 if (length != threadDone.length) {
103 threadDone = new boolean[length];
104 }
105 fill(threadDone, false);
106
107 if (length != ranges.length) {
108 ranges = new Range[length];
109 }
110 fill(lowerBounds, 0);
111 defineRanges();
112 }
113
114
115
116
117
118
119 void updateWeights(int[] weights) {
120 this.weights = weights;
121 }
122
123 private int totalWeight() {
124 int totalWeight = 0;
125 for (int i = 0; i < fftZ; i++) {
126 totalWeight += weights[i];
127 }
128 return totalWeight;
129 }
130
131 private void defineRanges() {
132 double totalWeight = totalWeight();
133
134 int length = min(nThreads, fftZ);
135
136
137 if (totalWeight <= length) {
138 Range temp = new Range(0, fftZ - 1);
139 ranges = temp.subranges(length);
140 return;
141 }
142
143
144 if (nThreads == 1) {
145 ranges[0] = new Range(0, fftZ - 1);
146 return;
147 }
148
149 double targetWeight = (totalWeight / nThreads) * .96;
150 int lastSlice = fftZ - 1;
151
152 int currentSlice = 0;
153 lowerBounds[0] = 0;
154 int currentThread = 0;
155 while (currentThread < length) {
156 int threadWeight = 0;
157 while (threadWeight < targetWeight && currentSlice < lastSlice) {
158 threadWeight += weights[currentSlice];
159 currentSlice++;
160 }
161 currentThread++;
162 if (currentSlice < lastSlice) {
163 lowerBounds[currentThread] = currentSlice;
164 } else {
165 lowerBounds[currentThread] = lastSlice;
166 break;
167 }
168 }
169
170 int lastThread = currentThread;
171
172
173 for (currentThread = 0; currentThread < lastThread - 1; currentThread++) {
174 ranges[currentThread] =
175 new Range(lowerBounds[currentThread], lowerBounds[currentThread + 1] - 1);
176 }
177
178
179 ranges[lastThread - 1] = new Range(lowerBounds[lastThread - 1], lastSlice);
180
181
182 for (int it = lastThread; it < length; it++) {
183 ranges[it] = null;
184 }
185 }
186
187
188
189
190
191
192 int[] getThreadWeights() {
193 int length = min(fftZ, nThreads);
194 int[] weightsToReturn = new int[length];
195 arraycopy(weights, 0, weightsToReturn, 0, length);
196 return weightsToReturn;
197 }
198
199
200
201
202
203
204 int[] getLowerBounds() {
205 int length = min(fftZ, nThreads);
206 int[] boundsToReturn = new int[length];
207 arraycopy(lowerBounds, 1, boundsToReturn, 0, length);
208 return boundsToReturn;
209 }
210 }