1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import static java.lang.System.arraycopy;
41 import static java.util.Arrays.fill;
42
43 import edu.rit.pj.IntegerSchedule;
44 import edu.rit.util.Range;
45
46
47
48
49
50
51
52 public class GradientSchedule extends IntegerSchedule {
53
54 private final int[] lowerBounds;
55 private final int nAtoms;
56 private int nThreads;
57 private boolean[] threadDone;
58 private Range[] ranges;
59 private int[] weights;
60
61
62
63
64
65
66
67 protected GradientSchedule(int nThreads, int nAtoms) {
68 this.nThreads = nThreads;
69 threadDone = new boolean[nThreads];
70 ranges = new Range[nThreads];
71 lowerBounds = new int[nThreads + 1];
72 this.nAtoms = nAtoms;
73 }
74
75
76
77
78
79
80 public int[] getLowerBounds() {
81 int[] boundsToReturn = new int[nThreads];
82 arraycopy(lowerBounds, 1, boundsToReturn, 0, nThreads);
83 return boundsToReturn;
84 }
85
86
87
88
89
90
91 public int[] getThreadWeights() {
92 int[] weightsToReturn = new int[nThreads];
93 arraycopy(weights, 0, weightsToReturn, 0, nThreads);
94 return weightsToReturn;
95 }
96
97
98 @Override
99 public boolean isFixedSchedule() {
100 return true;
101 }
102
103
104 @Override
105 public Range next(int threadID) {
106 if (!threadDone[threadID]) {
107 threadDone[threadID] = true;
108 return ranges[threadID];
109 }
110 return null;
111 }
112
113
114 @Override
115 public void start(int nThreads, Range chunkRange) {
116 this.nThreads = nThreads;
117
118 if (nThreads != threadDone.length) {
119 threadDone = new boolean[nThreads];
120 }
121 fill(threadDone, false);
122
123 if (nThreads != ranges.length) {
124 ranges = new Range[nThreads];
125 }
126 fill(lowerBounds, 0);
127 defineRanges();
128 }
129
130
131
132
133
134
135 void updateWeights(int[] weights) {
136 this.weights = weights;
137 }
138
139 private int totalWeight() {
140 int totalWeight = 0;
141 for (int i = 0; i < nAtoms; i++) {
142 totalWeight += weights[i];
143 }
144 return totalWeight;
145 }
146
147 private void defineRanges() {
148 double totalWeight = totalWeight();
149
150
151
152
153
154 if (totalWeight <= nThreads) {
155 Range temp = new Range(0, nAtoms - 1);
156 ranges = temp.subranges(nThreads);
157 return;
158 }
159
160
161
162
163
164 if (nThreads == 1) {
165 ranges[0] = new Range(0, nAtoms - 1);
166 return;
167 }
168
169 double targetWeight = (totalWeight / nThreads);
170 int lastAtom = nAtoms - 1;
171
172 int currentAtom = 0;
173 lowerBounds[0] = 0;
174 int currentThread = 0;
175 while (currentThread < nThreads) {
176 int threadWeight = 0;
177 while (threadWeight < targetWeight && currentAtom < lastAtom) {
178 threadWeight += weights[currentAtom];
179 currentAtom++;
180 }
181 currentThread++;
182 if (currentAtom < lastAtom) {
183 lowerBounds[currentThread] = currentAtom;
184 } else {
185 lowerBounds[currentThread] = lastAtom;
186 break;
187 }
188 }
189
190 int lastThread = currentThread;
191
192
193 for (currentThread = 0; currentThread < lastThread - 1; currentThread++) {
194 ranges[currentThread] =
195 new Range(lowerBounds[currentThread], lowerBounds[currentThread + 1] - 1);
196 }
197
198
199 ranges[lastThread - 1] = new Range(lowerBounds[lastThread - 1], lastAtom);
200
201
202 for (int it = lastThread; it < nThreads; it++) {
203 ranges[it] = null;
204 }
205 }
206 }