1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import static java.lang.System.arraycopy;
41 import static java.util.Arrays.fill;
42
43 import edu.rit.pj.IntegerSchedule;
44 import edu.rit.util.Range;
45
46
47
48
49
50
51
52 public class RowSchedule extends IntegerSchedule {
53
54 private final int[] lowerBounds;
55 private final int fftZ;
56 private final int fftY;
57 private int nThreads;
58 private boolean[] threadDone;
59 private Range[] ranges;
60 private int[] weights;
61
62
63
64
65
66
67
68
69 protected RowSchedule(int nThreads, int fftZ, int fftY) {
70 this.nThreads = nThreads;
71 threadDone = new boolean[nThreads];
72 ranges = new Range[nThreads];
73 lowerBounds = new int[nThreads + 1];
74 this.fftY = fftY;
75 this.fftZ = fftZ;
76 }
77
78
79 @Override
80 public boolean isFixedSchedule() {
81 return true;
82 }
83
84
85 @Override
86 public Range next(int threadID) {
87 if (!threadDone[threadID]) {
88 threadDone[threadID] = true;
89 return ranges[threadID];
90 }
91 return null;
92 }
93
94
95 @Override
96 public void start(int nThreads, Range chunkRange) {
97 this.nThreads = nThreads;
98
99 if (nThreads != threadDone.length) {
100 threadDone = new boolean[nThreads];
101 }
102 fill(threadDone, false);
103
104 if (nThreads != ranges.length) {
105 ranges = new Range[nThreads];
106 }
107 fill(lowerBounds, 0);
108 defineRanges();
109 }
110
111
112
113
114
115
116 void updateWeights(int[] weights) {
117 this.weights = weights;
118 }
119
120 private int totalWeight() {
121 int totalWeight = 0;
122 for (int i = 0; i < fftZ * fftY; i++) {
123 totalWeight += weights[i];
124 }
125 return totalWeight;
126 }
127
128 private void defineRanges() {
129 double totalWeight = totalWeight();
130
131
132 if (totalWeight <= nThreads) {
133 Range temp = new Range(0, fftZ * fftY - 1);
134 ranges = temp.subranges(nThreads);
135 return;
136 }
137
138
139 if (nThreads == 1) {
140 ranges[0] = new Range(0, fftZ * fftY - 1);
141 return;
142 }
143
144 double targetWeight = (totalWeight / nThreads);
145 int lastRow = fftZ * fftY - 1;
146
147 int currentRow = 0;
148 lowerBounds[0] = 0;
149 int currentThread = 0;
150 while (currentThread < nThreads) {
151 int threadWeight = 0;
152 while (threadWeight < targetWeight && currentRow < lastRow) {
153 threadWeight += weights[currentRow];
154 currentRow++;
155 }
156 currentThread++;
157 if (currentRow < lastRow) {
158 lowerBounds[currentThread] = currentRow;
159 } else {
160 lowerBounds[currentThread] = lastRow;
161 break;
162 }
163 }
164
165 int lastThread = currentThread;
166
167
168 for (currentThread = 0; currentThread < lastThread - 1; currentThread++) {
169 ranges[currentThread] =
170 new Range(lowerBounds[currentThread], lowerBounds[currentThread + 1] - 1);
171
172
173
174 }
175
176
177 ranges[lastThread - 1] = new Range(lowerBounds[lastThread - 1], lastRow);
178
179
180
181
182 for (int it = lastThread; it < nThreads; it++) {
183 ranges[it] = null;
184 }
185 }
186
187
188
189
190
191
192 int[] getThreadWeights() {
193 int[] weightsToReturn = new int[nThreads];
194 arraycopy(weights, 0, weightsToReturn, 0, nThreads);
195 return weightsToReturn;
196 }
197
198
199
200
201
202
203 int[] getLowerBounds() {
204 int[] boundsToReturn = new int[nThreads];
205 arraycopy(lowerBounds, 1, boundsToReturn, 0, nThreads);
206 return boundsToReturn;
207 }
208 }