1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.algorithms.mc;
39
40 import ffx.algorithms.dynamics.thermostats.Thermostat;
41 import ffx.potential.ForceFieldEnergy;
42 import ffx.potential.MolecularAssembly;
43 import ffx.potential.bonded.AminoAcidUtils;
44 import ffx.potential.bonded.AminoAcidUtils.AminoAcid3;
45 import ffx.potential.bonded.Atom;
46 import ffx.potential.bonded.Residue;
47 import ffx.potential.bonded.ResidueState;
48 import ffx.potential.bonded.Torsion;
49 import ffx.potential.parsers.PDBFilter;
50 import org.apache.commons.io.FilenameUtils;
51
52 import java.io.File;
53 import java.util.ArrayList;
54 import java.util.List;
55 import java.util.concurrent.ThreadLocalRandom;
56 import java.util.logging.Logger;
57
58 import static ffx.utilities.Constants.R;
59 import static java.lang.String.format;
60 import static org.apache.commons.math3.util.FastMath.exp;
61 import static org.apache.commons.math3.util.FastMath.min;
62
63
64
65
66
67
68
69
70
71 public class RosenbluthOBMC implements MonteCarloListener {
72
73 private static final Logger logger = Logger.getLogger(RosenbluthOBMC.class.getName());
74
75
76 private final MolecularAssembly molecularAssembly;
77
78 private final ForceFieldEnergy forceFieldEnergy;
79
80 private final Thermostat thermostat;
81
82 private final List<Residue> targets;
83
84 private final int mcFrequency;
85
86 private final int trialSetSize;
87
88 private int steps = 0;
89
90 private double Wn;
91
92 private double Wo;
93
94 private int numMovesProposed = 0;
95
96 private StringBuilder report = new StringBuilder();
97
98 private boolean writeSnapshots = false;
99
100
101
102
103
104
105
106
107
108
109
110 public RosenbluthOBMC(MolecularAssembly molecularAssembly, ForceFieldEnergy forceFieldEnergy,
111 Thermostat thermostat, List<Residue> targets, int mcFrequency, int trialSetSize) {
112 this.targets = targets;
113 this.mcFrequency = mcFrequency;
114 this.trialSetSize = trialSetSize;
115 this.molecularAssembly = molecularAssembly;
116 this.forceFieldEnergy = forceFieldEnergy;
117 this.thermostat = thermostat;
118 }
119
120
121
122
123
124
125
126
127
128
129
130
131 public RosenbluthOBMC(MolecularAssembly molecularAssembly, ForceFieldEnergy forceFieldEnergy,
132 Thermostat thermostat, List<Residue> targets, int mcFrequency, int trialSetSize,
133 boolean writeSnapshots) {
134 this(molecularAssembly, forceFieldEnergy, thermostat, targets, mcFrequency, trialSetSize);
135 this.writeSnapshots = writeSnapshots;
136 }
137
138
139 @Override
140 public boolean mcUpdate(double temperature) {
141 steps++;
142 if (steps % mcFrequency == 0) {
143 return mcStep();
144 }
145 return false;
146 }
147
148
149
150
151
152
153 private boolean mcStep() {
154 numMovesProposed++;
155 boolean accepted;
156
157
158 int index = ThreadLocalRandom.current().nextInt(targets.size());
159 Residue target = targets.get(index);
160 ResidueState origState = target.storeState();
161 Torsion chi0 = getChiZeroTorsion(target);
162 writeSnapshot("orig");
163
164
165
166
167
168
169 List<MCMove> oldTrialSet = createTrialSet(target, origState, trialSetSize - 1);
170 List<MCMove> newTrialSet = createTrialSet(target, origState, trialSetSize);
171 report = new StringBuilder();
172 report.append(format(" Rosenbluth Rotamer MC Move: %4d\n", numMovesProposed));
173 report.append(format(" residue: %s\n", target));
174 report.append(format(" chi0: %s\n", chi0.toString()));
175 MCMove proposal = calculateRosenbluthFactors(target, chi0, origState, oldTrialSet, origState,
176 newTrialSet);
177
178
179
180
181
182 setState(target, origState);
183 writeSnapshot("uIndO");
184 double uIndO = getTotalEnergy() - getTorsionEnergy(chi0);
185 proposal.move();
186 writeSnapshot("uIndN");
187 double uIndN = getTotalEnergy() - getTorsionEnergy(chi0);
188
189
190 double temperature = thermostat.getCurrentTemperature();
191 double beta = 1.0 / (R * temperature);
192 double dInd = uIndN - uIndO;
193 double dIndE = exp(-beta * dInd);
194 double criterion = (Wn / Wo) * exp(-beta * (uIndN - uIndO));
195 double metropolis = min(1, criterion);
196 double rng = ThreadLocalRandom.current().nextDouble();
197
198 report.append(format(" theta: %3.2f\n", ((RosenbluthChi0Move) proposal).theta));
199 report.append(format(" criterion: %1.4f\n", criterion));
200 report.append(format(" Wn/Wo: %.2f\n", Wn / Wo));
201 report.append(format(" uIndN,O: %7.2f\t%7.2f\n", uIndN, uIndO));
202 report.append(format(" dInd(E): %7.2f\t%7.2f\n", dInd, dIndE));
203 report.append(format(" rng: %1.4f\n", rng));
204 if (rng < metropolis) {
205 report.append(" Accepted.\n");
206 accepted = true;
207 } else {
208 proposal.revertMove();
209 report.append(" Denied.\n");
210 accepted = false;
211 }
212 logger.info(report.toString());
213
214
215 Wn = 0.0;
216 Wo = 0.0;
217 return accepted;
218 }
219
220
221
222
223
224 private List<MCMove> createTrialSet(Residue target, ResidueState state, int setSize) {
225 List<MCMove> moves = new ArrayList<>();
226
227 setState(target, state);
228 for (int i = 0; i < setSize; i++) {
229 moves.add(new RosenbluthChi0Move(target));
230 }
231 return moves;
232 }
233
234
235
236
237
238
239
240 private MCMove calculateRosenbluthFactors(Residue target, Torsion chi0, ResidueState oldConf,
241 List<MCMove> oldTrialSet, ResidueState newConf, List<MCMove> newTrialSet) {
242 double temperature = thermostat.getCurrentTemperature();
243 double beta = 1.0 / (R * temperature);
244
245
246 Wo = exp(-beta * getTorsionEnergy(chi0));
247 report.append(format(" TestSet (Old): %5s\t%7s\t\t%7s\n", "uDepO", "uDepOe", "Sum(Wo)"));
248 report.append(format(" Orig %d: %7.4f\t%7.4f\t\t%7.4f\n", 0, getTorsionEnergy(chi0),
249 exp(-beta * getTorsionEnergy(chi0)), Wo));
250 for (int i = 0; i < oldTrialSet.size(); i++) {
251 setState(target, oldConf);
252 MCMove move = oldTrialSet.get(i);
253 move.move();
254 double uDepO = getTorsionEnergy(chi0);
255 double uDepOe = exp(-beta * uDepO);
256 Wo += uDepOe;
257 if (i < 5 || i >= oldTrialSet.size() - 5) {
258 report.append(format(" Prop %d: %7.4f\t%7.4f\t\t%7.4f\n", i + 1, uDepO, uDepOe, Wo));
259 writeSnapshot("ots");
260 } else if (i == 5) {
261 report.append(" ... \n");
262 }
263 }
264
265
266 Wn = 0.0;
267 double[] uDepN = new double[newTrialSet.size()];
268 double[] uDepNe = new double[newTrialSet.size()];
269 report.append(format(" TestSet (New): %5s\t%7s\t\t%7s\n", "uDepN", "uDepNe", "Sum(Wn)"));
270 for (int i = 0; i < newTrialSet.size(); i++) {
271 setState(target, newConf);
272 MCMove move = newTrialSet.get(i);
273 move.move();
274 uDepN[i] = getTorsionEnergy(chi0);
275 uDepNe[i] = exp(-beta * uDepN[i]);
276 Wn += uDepNe[i];
277 if (i < 5 || i >= newTrialSet.size() - 5) {
278 report.append(
279 format(" Prop %d: %7.4f\t%7.4f\t\t%7.4f\n", i, uDepN[i], uDepNe[i], Wn));
280 writeSnapshot("nts");
281 } else if (i == 5) {
282 report.append(" ... \n");
283 }
284 }
285 setState(target, oldConf);
286
287
288 MCMove proposal = null;
289 double rng = ThreadLocalRandom.current().nextDouble(Wn);
290 double running = 0.0;
291 for (int i = 0; i < newTrialSet.size(); i++) {
292 running += uDepNe[i];
293 if (rng < running) {
294 proposal = newTrialSet.get(i);
295 double prob = uDepNe[i] / Wn * 100;
296 report.append(
297 format(" Chose %d %7.4f\t%7.4f\t %4.1f%%\n", i, uDepN[i], uDepNe[i], prob));
298 break;
299 }
300 }
301 if (proposal == null) {
302 logger.severe("Programming error.");
303 }
304 return proposal;
305 }
306
307 private double getTotalEnergy() {
308 double[] x = new double[forceFieldEnergy.getNumberOfVariables() * 3];
309 forceFieldEnergy.getCoordinates(x);
310 return forceFieldEnergy.energy(x);
311 }
312
313 private double getTorsionEnergy(Torsion torsion) {
314 return torsion.energy(false);
315 }
316
317 private Torsion getChiZeroTorsion(Residue residue) {
318 AminoAcid3 name = AminoAcidUtils.AminoAcid3.valueOf(residue.getName());
319 List<Torsion> torsions = residue.getTorsionList();
320 switch (name) {
321 case VAL -> {
322 Atom N = (Atom) residue.getAtomNode("N");
323 Atom CA = (Atom) residue.getAtomNode("CA");
324 Atom CB = (Atom) residue.getAtomNode("CB");
325 Atom CG1 = (Atom) residue.getAtomNode("CG1");
326 for (Torsion torsion : torsions) {
327 if (torsion.compare(N, CA, CB, CG1)) {
328 return torsion;
329 }
330 }
331 }
332 case ILE -> {
333 Atom N = (Atom) residue.getAtomNode("N");
334 Atom CA = (Atom) residue.getAtomNode("CA");
335 Atom CB = (Atom) residue.getAtomNode("CB");
336 Atom CG1 = (Atom) residue.getAtomNode("CG1");
337 for (Torsion torsion : torsions) {
338 if (torsion.compare(N, CA, CB, CG1)) {
339 return torsion;
340 }
341 }
342 }
343 case SER -> {
344 Atom N = (Atom) residue.getAtomNode("N");
345 Atom CA = (Atom) residue.getAtomNode("CA");
346 Atom CB = (Atom) residue.getAtomNode("CB");
347 Atom OG = (Atom) residue.getAtomNode("OG");
348 for (Torsion torsion : torsions) {
349 if (torsion.compare(N, CA, CB, OG)) {
350 return torsion;
351 }
352 }
353 }
354 case THR -> {
355 Atom N = (Atom) residue.getAtomNode("N");
356 Atom CA = (Atom) residue.getAtomNode("CA");
357 Atom CB = (Atom) residue.getAtomNode("CB");
358 Atom OG1 = (Atom) residue.getAtomNode("OG1");
359 for (Torsion torsion : torsions) {
360 if (torsion.compare(N, CA, CB, OG1)) {
361 return torsion;
362 }
363 }
364 }
365 case CYX -> {
366 Atom N = (Atom) residue.getAtomNode("N");
367 Atom CA = (Atom) residue.getAtomNode("CA");
368 Atom CB = (Atom) residue.getAtomNode("CB");
369 Atom SG = (Atom) residue.getAtomNode("SG");
370 for (Torsion torsion : torsions) {
371 if (torsion.compare(N, CA, CB, SG)) {
372 return torsion;
373 }
374 }
375 }
376 case CYD -> {
377 Atom N = (Atom) residue.getAtomNode("N");
378 Atom CA = (Atom) residue.getAtomNode("CA");
379 Atom CB = (Atom) residue.getAtomNode("CB");
380 Atom SG = (Atom) residue.getAtomNode("SG");
381 for (Torsion torsion : torsions) {
382 if (torsion.compare(N, CA, CB, SG)) {
383 return torsion;
384 }
385 }
386 }
387 default -> {
388 Atom N = (Atom) residue.getAtomNode("N");
389 Atom CA = (Atom) residue.getAtomNode("CA");
390 Atom CB = (Atom) residue.getAtomNode("CB");
391 Atom CG = (Atom) residue.getAtomNode("CG");
392 for (Torsion torsion : torsions) {
393 if (torsion.compare(N, CA, CB, CG)) {
394 return torsion;
395 }
396 }
397 logger.info("Couldn't find chi[0] for residue " + residue);
398 return null;
399 }
400 }
401 logger.info("Couldn't find chi[0] for residue " + residue);
402 return null;
403 }
404
405
406
407
408
409 private void setState(Residue target, ResidueState state) {
410 target.revertState(state);
411 for (Torsion torsion : target.getTorsionList()) {
412 torsion.update();
413 }
414 }
415
416 private void writeSnapshot(String suffix) {
417 if (!writeSnapshots) {
418 return;
419 }
420 String filename =
421 FilenameUtils.removeExtension(molecularAssembly.getFile().toString()) + "." + suffix + "-"
422 + numMovesProposed;
423 File file = new File(filename);
424 PDBFilter writer = new PDBFilter(file, molecularAssembly, null, null);
425 writer.writeFile(file, false);
426 }
427 }