1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.algorithms.optimize.manybody;
39
40 import edu.rit.mp.DoubleBuf;
41 import edu.rit.pj.Comm;
42 import edu.rit.pj.IntegerSchedule;
43 import edu.rit.pj.WorkerIntegerForLoop;
44 import edu.rit.pj.WorkerRegion;
45 import ffx.algorithms.optimize.RotamerOptimization;
46 import ffx.potential.bonded.Residue;
47 import ffx.potential.bonded.Rotamer;
48 import ffx.potential.utils.EnergyException;
49
50 import java.io.BufferedWriter;
51 import java.io.IOException;
52 import java.util.List;
53 import java.util.Map;
54 import java.util.Set;
55 import java.util.logging.Level;
56 import java.util.logging.Logger;
57
58 import static java.lang.String.format;
59
60
61 public class TwoBodyEnergyRegion extends WorkerRegion {
62
63 private static final Logger logger = Logger.getLogger(TwoBodyEnergyRegion.class.getName());
64 private final Residue[] residues;
65 private final RotamerOptimization rO;
66 private final DistanceMatrix dM;
67 private final EnergyExpansion eE;
68 private final EliminatedRotamers eR;
69
70
71
72
73 private final List<Residue> allResiduesList;
74
75 private final Map<Integer, Integer[]> twoBodyEnergyMap;
76
77 private final BufferedWriter energyWriter;
78
79 private final Comm world;
80
81 private final int numProc;
82
83 private final boolean prunePairClashes;
84
85
86
87
88 private final double superpositionThreshold;
89
90 private final boolean master;
91
92 private final int rank;
93
94 private final boolean verbose;
95
96 private final boolean writeEnergyRestart;
97
98
99
100
101 private final boolean printFiles;
102
103 private Set<Integer> keySet;
104
105 public TwoBodyEnergyRegion(RotamerOptimization rotamerOptimization, DistanceMatrix dM,
106 EnergyExpansion eE, EliminatedRotamers eR, Residue[] residues, List<Residue> allResiduesList,
107 BufferedWriter energyWriter, Comm world, int numProc, boolean prunePairClashes,
108 double superpositionThreshold, boolean master, int rank, boolean verbose,
109 boolean writeEnergyRestart, boolean printFiles) {
110 this.rO = rotamerOptimization;
111 this.dM = dM;
112 this.eE = eE;
113 this.eR = eR;
114 this.residues = residues;
115 this.allResiduesList = allResiduesList;
116 this.energyWriter = energyWriter;
117 this.world = world;
118 this.numProc = numProc;
119 this.prunePairClashes = prunePairClashes;
120 this.superpositionThreshold = superpositionThreshold;
121 this.master = master;
122 this.rank = rank;
123 this.verbose = verbose;
124 this.writeEnergyRestart = writeEnergyRestart;
125 this.printFiles = printFiles;
126
127 this.twoBodyEnergyMap = eE.getTwoBodyEnergyMap();
128 logger.info(format(" Number of 2-body energies to calculate: %d", twoBodyEnergyMap.size()));
129 }
130
131 @Override
132 public void finish() {
133
134 eR.prePrunePairs(residues);
135
136
137 if (prunePairClashes) {
138 eR.prunePairClashes(residues);
139 }
140
141
142 if (master && verbose) {
143 for (int i = 0; i < residues.length; i++) {
144 Residue resI = residues[i];
145 Rotamer[] rotI = resI.getRotamers();
146 for (int ri = 0; ri < rotI.length; ri++) {
147 if (eR.check(i, ri)) {
148 continue;
149 }
150 for (int j = i + 1; j < residues.length; j++) {
151 Residue resJ = residues[j];
152 Rotamer[] rotJ = resJ.getRotamers();
153 for (int rj = 0; rj < rotJ.length; rj++) {
154 if (eR.check(j, rj) || eR.check(i, ri, j, rj)) {
155 continue;
156 }
157 logger.info(
158 format(" Pair energy %8s %-2d, %8s %-2d: %s", residues[i].toString(rotI[ri]), ri,
159 residues[j].toString(rotJ[rj]), rj,
160 rO.formatEnergy(eE.get2Body(i, ri, j, rj))));
161 }
162 }
163 }
164 }
165 }
166 }
167
168 @Override
169 public void run() throws Exception {
170 if (!keySet.isEmpty()) {
171 execute(0, keySet.size() - 1, new TwoBodyEnergyLoop());
172 }
173 }
174
175 @Override
176 public void start() {
177 int numPair = twoBodyEnergyMap.size();
178 int remainder = numPair % numProc;
179
180
181 Integer[] padding = {-1, -1, -1, -1};
182
183 int padKey = numPair;
184 while (remainder != 0) {
185 twoBodyEnergyMap.put(padKey++, padding);
186 remainder = twoBodyEnergyMap.size() % numProc;
187 }
188
189 numPair = twoBodyEnergyMap.size();
190 if (numPair % numProc != 0) {
191 logger.severe(" Logic error padding pair energies.");
192 }
193
194
195 keySet = twoBodyEnergyMap.keySet();
196 }
197
198 private class TwoBodyEnergyLoop extends WorkerIntegerForLoop {
199
200 final DoubleBuf[] resultBuffer;
201 final DoubleBuf myBuffer;
202
203 TwoBodyEnergyLoop() {
204 resultBuffer = new DoubleBuf[numProc];
205 for (int i = 0; i < numProc; i++) {
206 resultBuffer[i] = DoubleBuf.buffer(new double[5]);
207 }
208 myBuffer = resultBuffer[rank];
209 }
210
211 @Override
212 public void run(int lb, int ub) {
213 for (int key = lb; key <= ub; key++) {
214 long time = -System.nanoTime();
215 Integer[] job = twoBodyEnergyMap.get(key);
216 int i = job[0];
217 int ri = job[1];
218 int j = job[2];
219 int rj = job[3];
220
221 myBuffer.put(0, i);
222 myBuffer.put(1, ri);
223 myBuffer.put(2, j);
224 myBuffer.put(3, rj);
225 myBuffer.put(4, 0.0);
226
227
228 if (i >= 0 && ri >= 0 && j >= 0 && rj >= 0) {
229 if (!eR.check(i, ri) || !eR.check(j, rj) || !eR.check(i, ri, j, rj)) {
230 Residue residueI = residues[i];
231 Residue residueJ = residues[j];
232 Rotamer[] rotI = residues[i].getRotamers();
233 Rotamer[] rotJ = residues[j].getRotamers();
234 int indexI = allResiduesList.indexOf(residueI);
235 int indexJ = allResiduesList.indexOf(residueJ);
236 double resDist = dM.getResidueDistance(indexI, ri, indexJ, rj);
237 String resDistString = "large";
238 if (resDist < Double.MAX_VALUE) {
239 resDistString = format("%5.3f", resDist);
240 }
241
242 double dist = dM.checkDistMatrix(indexI, ri, indexJ, rj);
243 String distString = " large";
244 if (dist < Double.MAX_VALUE) {
245 distString = format("%10.3f", dist);
246 }
247
248 double twoBodyEnergy;
249 if (dist < superpositionThreshold) {
250
251 twoBodyEnergy = Double.NaN;
252 logger.info(
253 format(" Pair %8s %-2d, %8s %-2d:\t NaN at %10.3f A (%s A by res) < %5.3f Ang",
254 residueI.toString(rotI[ri]), ri, residueJ.toString(rotJ[rj]), rj, dist,
255 resDist, superpositionThreshold));
256 } else if (dM.checkPairDistThreshold(indexI, ri, indexJ, rj)) {
257
258 twoBodyEnergy = 0.0;
259 time += System.nanoTime();
260 logger.info(
261 format(" Pair %8s %-2d, %8s %-2d: %s at %s A (%s A by res) in %6.4f (sec).",
262 residueI.toString(rotI[ri]), ri, residueJ.toString(rotJ[rj]), rj,
263 rO.formatEnergy(twoBodyEnergy), distString, resDistString, time * 1.0e-9));
264 } else {
265 try {
266 twoBodyEnergy = eE.compute2BodyEnergy(residues, i, ri, j, rj);
267 time += System.nanoTime();
268 logger.info(
269 format(" Pair %8s %-2d, %8s %-2d: %s at %s A (%s A by res) in %6.4f (sec).",
270 residueI.toString(rotI[ri]), ri, residueJ.toString(rotJ[rj]), rj,
271 rO.formatEnergy(twoBodyEnergy), distString, resDistString, time * 1.0e-9));
272 } catch (EnergyException ex) {
273 twoBodyEnergy = ex.getEnergy();
274 time += System.nanoTime();
275 logger.info(
276 format(" Pair %8s %-2d, %8s %-2d: %s at %s A (%s A by res) in %6.4f (sec).",
277 residueI.toString(rotI[ri]), ri, residueJ.toString(rotJ[rj]), rj,
278 rO.formatEnergy(twoBodyEnergy), distString, resDistString, time * 1.0e-9));
279 }
280 }
281 myBuffer.put(4, twoBodyEnergy);
282 }
283 }
284
285
286 if (numProc > 1) {
287 try {
288 world.allGather(myBuffer, resultBuffer);
289 } catch (Exception e) {
290 logger.log(Level.SEVERE, " Exception communicating pair energies.", e);
291 }
292 }
293
294
295 for (DoubleBuf doubleBuf : resultBuffer) {
296 int resI = (int) doubleBuf.get(0);
297 int rotI = (int) doubleBuf.get(1);
298 int resJ = (int) doubleBuf.get(2);
299 int rotJ = (int) doubleBuf.get(3);
300 double energy = doubleBuf.get(4);
301
302 if (resI >= 0 && rotI >= 0 && resJ >= 0 && rotJ >= 0) {
303 if (!Double.isFinite(energy)) {
304 logger.info(
305 " Rotamer pair eliminated: " + resI + ", " + rotI + ", " + resJ + ", " + rotJ);
306 eR.eliminateRotamerPair(residues, resI, rotI, resJ, rotJ, false);
307 }
308 eE.set2Body(resI, rotI, resJ, rotJ, energy);
309 if (rank == 0 && writeEnergyRestart && printFiles) {
310 try {
311 energyWriter.append(
312 format("Pair %d %d, %d %d: %16.8f", resI, rotI, resJ, rotJ, energy));
313 energyWriter.newLine();
314 energyWriter.flush();
315 } catch (IOException ex) {
316 logger.log(Level.SEVERE, " Exception writing energy restart file.", ex);
317 }
318 }
319 }
320 }
321 }
322 }
323
324 @Override
325 public IntegerSchedule schedule() {
326
327 return IntegerSchedule.fixed();
328 }
329 }
330 }