1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.algorithms.commands.test;
39
40 import ffx.algorithms.cli.AlgorithmsCommand;
41 import ffx.numerics.Potential;
42 import ffx.numerics.math.RunningStatistics;
43 import ffx.numerics.math.SummaryStatistics;
44 import ffx.potential.ForceFieldEnergy;
45 import ffx.potential.bonded.Residue;
46 import ffx.potential.extended.ExtendedSystem;
47 import ffx.potential.parsers.XPHFilter;
48 import ffx.utilities.FFXBinding;
49 import org.apache.commons.lang3.ArrayUtils;
50 import picocli.CommandLine.Command;
51 import picocli.CommandLine.Option;
52 import picocli.CommandLine.Parameters;
53
54 import java.io.File;
55 import java.util.ArrayList;
56 import java.util.Random;
57
58 import static ffx.numerics.estimator.EstimateBootstrapper.getBootstrapIndices;
59 import static java.lang.String.format;
60
61
62
63
64
65
66
67
68 @Command(description = " Use the Rao-Blackwell estimator to get a free energy difference for residues in a CpHMD system.", name = "test.RaoBlackwellEstimator")
69 public class RaoBlackwellEstimator extends AlgorithmsCommand {
70
71 @Option(names = {"--aFi", "--arcFile"}, paramLabel = "traj",
72 description = "A file containing the the PDB from which to build the ExtendedSystem. There is currently no default.")
73 private String arcFileName = null;
74
75 @Option(names = {"--numSnaps"}, paramLabel = "-1", defaultValue = "-1",
76 description = "Number of snapshots to use from an archive file. Default is all.")
77 private int numSnaps;
78
79 @Option(names = {"--specifiedResidues", "--sR"}, paramLabel = "<selection>", defaultValue = "",
80 description = "Specified residues to do analysis.")
81 private String specified;
82
83 @Option(names = {"--startSnap"}, paramLabel = "-1", defaultValue = "-1",
84 description = "Start energy evaluations at a snap other than 2.")
85 private int startSnap;
86
87 @Option(names = {"--bootstrapIter"}, paramLabel = "100000", defaultValue = "100000",
88 description = "Number of bootstrap iterations. Set -1 for no bootstrapping.")
89 private int bootstrapIter;
90
91 @Option(names = {"--skip"}, paramLabel = "-1", defaultValue = "-1",
92 description = "Calculate energies on snaps with this interval.")
93 private int skip;
94
95 @Option(names = {"--writeFrequency"}, paramLabel = "100", defaultValue = "100",
96 description = "Calculate the RBE and print at this snapshot read frequency.")
97 private int writeFrequency;
98
99 @Parameters(arity = "1..*", paramLabel = "files",
100 description = "PDB input file in the same directory as the ARC file.")
101 private String filename;
102
103 private Potential forceFieldEnergy;
104 private ArrayList<Double>[][] oneZeroDeltaLists;
105 private ArrayList<Double>[][] tautomerOneZeroDeltaList;
106 private int numESVs;
107 private int numTautomerESVs;
108
109
110
111
112 public RaoBlackwellEstimator() {
113 super();
114 }
115
116
117
118
119
120 public RaoBlackwellEstimator(FFXBinding binding) {
121 super(binding);
122 }
123
124
125
126
127
128 public RaoBlackwellEstimator(String[] args) {
129 super(args);
130 }
131
132 @Override
133 @SuppressWarnings("unchecked")
134 public RaoBlackwellEstimator run() {
135
136 if (!init()) {
137 return this;
138 }
139
140
141 File arcFile = new File(arcFileName);
142 if(!arcFile.exists()){
143 logger.severe(format(" ARC file %s does not exist.", arcFile));
144 }
145 else{
146 logger.info(format("Using ARC file %s.", arcFile));
147 }
148
149 boolean bootstrap = false;
150 if(bootstrapIter >= 50)
151 {
152 bootstrap = true;
153 } else if (bootstrapIter != -1){
154 logger.severe("Too few bootstrap iterations specified. Must be at least 50.");
155 }
156
157 activeAssembly = getActiveAssembly(filename);
158 if (activeAssembly == null) {
159 logger.info(helpString());
160 return this;
161 }
162 forceFieldEnergy = activeAssembly.getPotentialEnergy();
163
164
165 String filename = activeAssembly.getFile().getAbsolutePath();
166
167
168 ExtendedSystem esvSystem = new ExtendedSystem(activeAssembly, 7.0, null);
169
170
171 Residue specialResidue = null;
172 int numberOfStates = 1;
173 int[][] states = null;
174 if(esvSystem.getSpecialResidueList().size() > 1){
175 logger.severe(" Multiple special residues were identified in the key file. " +
176 "Only one can be specified with this algorithm.");
177 } else if (esvSystem.getSpecialResidueList().size() == 1) {
178 int specialResidueNumber = esvSystem.getSpecialResidueList().get(0).intValue();
179 for (Residue residue : esvSystem.getTitratingResidueList()) {
180 if (residue.getResidueNumber() == specialResidueNumber) {
181 specialResidue = residue;
182 }
183 }
184 if(specialResidue != null){
185 numberOfStates = !esvSystem.isTautomer(specialResidue) ? 3 : 4;
186 switch (specialResidue.getName()) {
187
188 case "ASD":
189 case "GLD":
190 states = new int[3][2];
191 states[0][0] = 0;
192 states[0][1] = 0;
193 states[1][0] = 1;
194 states[1][1] = 0;
195 states[2][0] = 1;
196 states[2][1] = 1;
197 break;
198
199 case "HIS":
200 states = new int[3][2];
201 states[0][0] = 0;
202 states[0][1] = 0;
203 states[1][0] = 0;
204 states[1][1] = 1;
205 states[2][0] = 1;
206 states[2][1] = 0;
207 break;
208
209 case "LYS":
210 case "CYS":
211 states = new int[2][2];
212 states[0][0] = 0;
213 states[0][1] = 0;
214 states[1][0] = 1;
215 states[1][1] = 0;
216 break;
217 }
218
219 } else {
220 logger.severe(" The special residue specified in the key file was not found in the titrating residue list.");
221 }
222 }
223
224
225
226 int[] specifiedResidues = null;
227 if(specified != null && !specified.isEmpty()){
228 String[] specifiedResiduesString = specified.split(",");
229 specifiedResidues = new int[specifiedResiduesString.length];
230 for (int i = 0; i < specifiedResiduesString.length; i++) {
231 specifiedResidues[i] = Integer.parseInt(specifiedResiduesString[i].trim());
232 }
233 }
234 ArrayList<Residue> onlyResidues = new ArrayList<>();
235 ArrayList<Integer> onlyResidueIndices = new ArrayList<>();
236 if(specifiedResidues != null){
237 for (int i = 0; i < esvSystem.getTitratingResidueList().size(); i++) {
238 Residue residue = esvSystem.getTitratingResidueList().get(i);
239 if (ArrayUtils.contains(specifiedResidues, residue.getResidueNumber())) {
240 onlyResidues.add(residue);
241 onlyResidueIndices.add(i);
242 }
243 }
244 if(onlyResidues.size() != specifiedResidues.length){
245 logger.severe("Could not find all residues from --specifiedResidues input.");
246 }
247 }
248 else{
249 for (int i = 0; i < esvSystem.getTitratingResidueList().size(); i++) {
250 onlyResidueIndices.add(i);
251 }
252 }
253
254
255
256 numESVs = esvSystem.getTitratingResidueList().size();
257 oneZeroDeltaLists = new ArrayList[numESVs][numberOfStates + 1];
258 for (int i = 0; i < numESVs; i++) {
259 for (int j = 0; j < numberOfStates + 1; j++) {
260 oneZeroDeltaLists[i][j] = new ArrayList<>();
261 }
262 }
263
264 numTautomerESVs = esvSystem.getTautomerizingResidueList().size();
265 tautomerOneZeroDeltaList = new ArrayList[numTautomerESVs][numberOfStates + 1];
266 for (int i = 0; i < numTautomerESVs; i++) {
267 for (int j = 0; j < numberOfStates + 1; j++) {
268 tautomerOneZeroDeltaList[i][j] = new ArrayList<>();
269 }
270 }
271
272
273 activeAssembly.setFile(arcFile);
274 XPHFilter xphFilter = new XPHFilter(
275 arcFile,
276 activeAssembly,
277 activeAssembly.getForceField(),
278 activeAssembly.getProperties(),
279 esvSystem);
280 xphFilter.readFile();
281
282
283 esvSystem.setFixedTitrationState(true);
284 esvSystem.setFixedTautomerState(true);
285 ((ForceFieldEnergy) forceFieldEnergy).attachExtendedSystem(esvSystem);
286 logger.info(format(" Attached extended system with %d residues.", numESVs));
287
288
289 double[] x = new double[forceFieldEnergy.getNumberOfVariables()];
290 forceFieldEnergy.getCoordinates(x);
291 forceFieldEnergy.energy(x, true);
292
293
294 double pH = 0.0;
295 String[] parts = xphFilter.getRemarkLines()[0].split(" ");
296 for(int i = 0; i < parts.length; i++) {
297 if (parts[i].contains("pH")) {
298 pH = Double.parseDouble(parts[i+1]);
299 }
300 }
301 logger.info("\n Setting constant pH to " + pH + ".");
302 esvSystem.setConstantPh(pH);
303
304
305 int evals = 0;
306 if(numSnaps != -1) {
307 logger.info(format(" Using %d snapshots.", numSnaps));
308 }
309 else {
310 logger.info(format(" Using all %d snapshots.", xphFilter.countNumModels()));
311 }
312
313
314 while(xphFilter.readNext()) {
315
316 if(startSnap != -1 && startSnap > 2 && evals == 0){
317 for(int i = 0; i < startSnap - 2; i++){
318 xphFilter.readNext();
319 }
320 }
321
322
323 forceFieldEnergy.getCoordinates(x);
324 for (int i : onlyResidueIndices) {
325 double titrationState = 0;
326 double tautomerState = 0;
327 if(specialResidue != null){
328 titrationState = esvSystem.getTitrationLambda(specialResidue);
329 tautomerState = esvSystem.getTautomerLambda(specialResidue);
330 }
331
332
333 for (int j = 0; j < numberOfStates; j++) {
334 if (j != 0) {
335 esvSystem.setTitrationLambda(specialResidue, states[j-1][0], false);
336 if(numberOfStates != 3) {
337 esvSystem.setTautomerLambda(specialResidue, states[j-1][1], false);
338 }
339 }
340
341
342 ArrayList<Double> results = getZeroOneDeltas(i, esvSystem, (ForceFieldEnergy) forceFieldEnergy, x);
343
344 Residue res = esvSystem.getTitratingResidueList().get(i);
345 if(esvSystem.isTautomer(res)){
346 tautomerOneZeroDeltaList[esvSystem.getTautomerizingResidueList().indexOf(res)][j].add(results.get(0));
347 oneZeroDeltaLists[i][j].add(results.get(1));
348 } else{
349 oneZeroDeltaLists[i][j].add(results.get(0));
350 }
351 }
352
353 if(specialResidue == esvSystem.getExtendedResidueList().get(i)) {
354 break;
355 }
356
357
358 if(specialResidue != null) {
359 esvSystem.setTitrationLambda(specialResidue, titrationState, false);
360 esvSystem.setTautomerLambda(specialResidue, tautomerState, false);
361 }
362 }
363
364 evals++;
365 if(evals % writeFrequency == 0 || evals == numSnaps) {
366
367 int tautomerCount = 0;
368 double[][] energyLists = new double[numESVs][numberOfStates];
369 double[][] energyStdLists = new double[numESVs][numberOfStates];
370 double[][] tautomerEnergyLists = new double[numTautomerESVs][numberOfStates];
371 double[][] tautomerEnergyStdLists = new double[numTautomerESVs][numberOfStates];
372 for (int i : onlyResidueIndices) {
373 Residue res = esvSystem.getExtendedResidueList().get(i);
374 logger.info("\n Performing Rao-Blackwell Estimator on " + res.getAminoAcid3() + ".");
375 if (bootstrap) {
376 logger.info(" Performing bootstrap with " + bootstrapIter + " iterations.");
377 } else {
378 logger.info(" Performing RBE without bootstrap. Ignore standard deviation values.");
379 }
380 for (int j = 0; j < numberOfStates; j++) {
381 double[] bootstrapMeanStd = RBE(oneZeroDeltaLists[i][j], bootstrap, bootstrapIter);
382 energyLists[i][j] = bootstrapMeanStd[0];
383 if (bootstrap) {
384 energyStdLists[i][j] = bootstrapMeanStd[1];
385 }
386
387 if (esvSystem.getTautomerizingResidueList().contains(res)) {
388 bootstrapMeanStd = RBE(tautomerOneZeroDeltaList[esvSystem.getTautomerizingResidueList().indexOf(res)][j],
389 bootstrap, bootstrapIter);
390 tautomerEnergyLists[tautomerCount][j] = bootstrapMeanStd[0];
391 if (bootstrap) {
392 tautomerEnergyStdLists[tautomerCount][j] = bootstrapMeanStd[1];
393 }
394 }
395 }
396 if (esvSystem.isTautomer(res)) {
397 tautomerCount++;
398 }
399 if (specialResidue == res) {
400 break;
401 }
402 }
403
404
405 printResults(specialResidue, esvSystem, energyLists, energyStdLists, tautomerEnergyLists, tautomerEnergyStdLists,
406 states, numberOfStates, numESVs, onlyResidueIndices);
407 }
408 if (numSnaps != -1 && evals >= numSnaps) {
409 break;
410 }
411
412 if(skip != -1){
413 for(int i = 0; i < skip-1; i++){
414 xphFilter.readNext();
415 }
416 }
417 }
418 return this;
419 }
420
421
422
423 private static double[] RBE(ArrayList<Double> deltaUList, boolean bootstrap, int bootstrapIter){
424 ArrayList<Double> deltaU = deltaUList;
425 double temperature = 298.0;
426 double boltzmann = 0.001985875;
427 double beta = 1.0 / (temperature * boltzmann);
428
429 ArrayList<Double> deltaExp = exp(mult(-beta, deltaU));
430 ArrayList<Double> numerator = div(mult(beta, mult(deltaU, deltaExp)), subtract(1.0, deltaExp));
431 ArrayList<Double> denominator = div(mult(beta, deltaU), subtract(1.0, deltaExp));
432
433 double[] deltaGRBE = bootstrap ?
434 bootStrap(numerator, denominator, bootstrapIter) :
435 new double[] {-(1.0 / beta) * Math.log(average(numerator) / average(denominator))};
436 return deltaGRBE;
437 }
438
439
440 private static double[] bootStrap(ArrayList<Double> numerator, ArrayList<Double> denominator, int iter) {
441 RunningStatistics estimates = new RunningStatistics();
442 for (int k = 0; k < iter; k++) {
443 Random rng = new Random();
444 int[] trial = getBootstrapIndices(numerator.size(), rng);
445 double estimate = estimateDg(numerator, denominator, trial);
446 estimates.addValue(estimate);
447 }
448 SummaryStatistics stats = new SummaryStatistics(estimates);
449 return new double[] {stats.mean, stats.getSd()};
450 }
451
452
453 private static double estimateDg(ArrayList<Double> num, ArrayList<Double> denom, int[] index) {
454 double temperature = 298.0;
455 double boltzmann = 0.001985875;
456 double beta = 1.0 / (temperature * boltzmann);
457 ArrayList<Double> numerator = new ArrayList<>();
458 numerator.ensureCapacity(index.length);
459 ArrayList<Double> denominator = new ArrayList<>();
460 denominator.ensureCapacity(index.length);
461
462 for(int i = 0; i < index.length; i++) {
463 numerator.add(num.get(index[i]));
464 denominator.add(denom.get(index[i]));
465 }
466
467 return -(1.0 / beta) * Math.log(average(numerator) / average(denominator));
468 }
469
470
471 private static void printResults(Residue specialResidue, ExtendedSystem esvSystem, double[][] energyLists, double[][] energyStdLists,
472 double[][] tautomerEnergyLists, double[][]tautomerStdLists, int[][] states,
473 int numberOfStates, int numESVs, ArrayList<Integer> onlyResidueIndex)
474 {
475 logger.info("\n Rao-Blackwell Estimator Results: ");
476 ArrayList<String> line = new ArrayList<>();
477 if(specialResidue != null){
478 logger.info(" Special Residue: " + specialResidue.toString());
479 if(esvSystem.isTautomer(specialResidue)){
480 logger.info(format(" %-10s %-10s %-23s %-28s %-28s %-28s", "Residue", "Tautomer", "DeltaGTitr", "DeltaG-SpecialRes=(" + states[0][0] + "," + states[0][1] + ")", "DeltaG-SpecialRes=(" + states[1][0] + "," + states[1][1] + ")", "DeltaG-SpecialRes=(" + states[2][0] + "," + states[2][1] + ")"));
481 } else{
482 logger.info(format(" %-10s %-10s %-23s %-28s %-28s", "Residue", "Tautomer", "DeltaGTitr","DeltaG-SpecialRes=(" + states[0][0] + "," + states[0][1] + ")", "DeltaG-SpecialRes=(" + states[1][0] + "," + states[1][1] + ")"));
483 }
484 } else
485 {
486 logger.info(format(" %-10s %-10s %-23s", "Residue", "Tautomer", "DeltaGTitr"));
487 }
488 int tautomerCount = 0;
489 for(int i : onlyResidueIndex){
490 Residue res = esvSystem.getExtendedResidueList().get(i);
491 line.add(res.toString());
492 line.add("0");
493 line.add(Double.toString(energyLists[i][0]));
494 line.add(Double.toString(energyStdLists[i][0]));
495 for(int j = 1; j < numberOfStates; j++){
496 line.add(Double.toString(energyLists[i][j]));
497 line.add(Double.toString(energyStdLists[i][j]));
498 }
499 if(specialResidue != null && esvSystem.isTautomer(specialResidue)) {
500 logger.info(format(" %-10s %-10s %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f",
501 line.get(0),
502 line.get(1),
503 Double.parseDouble(line.get(2)),
504 Double.parseDouble(line.get(3)),
505 Double.parseDouble(line.get(4)),
506 Double.parseDouble(line.get(5)),
507 Double.parseDouble(line.get(6)),
508 Double.parseDouble(line.get(7)),
509 Double.parseDouble(line.get(8)),
510 Double.parseDouble(line.get(9))));
511 } else if (specialResidue != null) {
512 logger.info(format(" %-10s %-10s %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f",
513 line.get(0),
514 line.get(1),
515 Double.parseDouble(line.get(2)),
516 Double.parseDouble(line.get(3)),
517 Double.parseDouble(line.get(4)),
518 Double.parseDouble(line.get(5)),
519 Double.parseDouble(line.get(6)),
520 Double.parseDouble(line.get(7))));
521 } else {
522 logger.info(format(" %-10s %-10s %-10.5f +/- %-5.3f",
523 line.get(0),
524 line.get(1),
525 Double.parseDouble(line.get(2)),
526 Double.parseDouble(line.get(3))));
527 }
528 line.clear();
529
530 if(esvSystem.isTautomer(res)) {
531 line.add(res.toString());
532 line.add("1");
533 line.add(Double.toString(tautomerEnergyLists[tautomerCount][0]));
534 line.add(Double.toString(tautomerStdLists[tautomerCount][0]));
535 for(int j = 1; j < numberOfStates; j++){
536 line.add(Double.toString(tautomerEnergyLists[tautomerCount][j]));
537 line.add(Double.toString(tautomerStdLists[tautomerCount][j]));
538 }
539 tautomerCount++;
540 if(specialResidue != null && esvSystem.isTautomer(specialResidue)) {
541 logger.info(format(" %-10s %-10s %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f",
542 line.get(0),
543 line.get(1),
544 Double.parseDouble(line.get(2)),
545 Double.parseDouble(line.get(3)),
546 Double.parseDouble(line.get(4)),
547 Double.parseDouble(line.get(5)),
548 Double.parseDouble(line.get(6)),
549 Double.parseDouble(line.get(7)),
550 Double.parseDouble(line.get(8)),
551 Double.parseDouble(line.get(9))));
552 } else if (specialResidue != null) {
553 logger.info(format(" %-10s %-10s %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f %-10.5f +/- %-5.3f",
554 line.get(0),
555 line.get(1),
556 Double.parseDouble(line.get(2)),
557 Double.parseDouble(line.get(3)),
558 Double.parseDouble(line.get(4)),
559 Double.parseDouble(line.get(5)),
560 Double.parseDouble(line.get(6)),
561 Double.parseDouble(line.get(7))));
562 } else {
563 logger.info(format(" %-10s %-10s %-10.5f +/- %-5.3f",
564 line.get(0),
565 line.get(1),
566 Double.parseDouble(line.get(2)),
567 Double.parseDouble(line.get(3))));
568 }
569 line.clear();
570 }
571 }
572 }
573
574 private static ArrayList<Double> getZeroOneDeltas(int i, ExtendedSystem esv,
575 ForceFieldEnergy forceFieldEnergy, double[] x)
576 {
577 ArrayList<Double> deltaU = new ArrayList<Double>();
578 Residue res = esv.getExtendedResidueList().get(i);
579 double titrationState = esv.getTitrationLambda(res);
580 double tautomerState = esv.getTautomerLambda(res);
581
582 if (esv.getTautomerizingResidueList().contains(res)) {
583 esv.setTautomerLambda(res, 1, false);
584
585 esv.setTitrationLambda(res, 0, false);
586 double zeroEnergy = forceFieldEnergy.energy(x, false);
587
588 esv.setTitrationLambda(res, 1, false);
589 double oneEnergy = forceFieldEnergy.energy(x, false);
590
591 esv.setTitrationLambda(res, titrationState, false);
592
593 deltaU.add(oneEnergy - zeroEnergy);
594 esv.setTautomerLambda(res, 0, false);
595 }
596
597 esv.setTitrationLambda(res, 0, false);
598 double zeroEnergy = forceFieldEnergy.energy(x, false);
599
600 esv.setTitrationLambda(res, 1, false);
601 double oneEnergy = forceFieldEnergy.energy(x, false);
602
603 esv.setTitrationLambda(res, titrationState, false);
604
605 deltaU.add(oneEnergy - zeroEnergy);
606 if (esv.getTautomerizingResidueList().contains(res)) {
607 esv.setTautomerLambda(res, tautomerState, false);
608 }
609
610 return deltaU;
611 }
612
613 private static double average(ArrayList<Double> list) {
614 double sum = 0.0;
615 for (Double d : list) {
616 sum += d;
617 }
618 return sum / list.size();
619 }
620
621 private static ArrayList<Double> mult(double a, ArrayList<Double> u) {
622 ArrayList<Double> result = new ArrayList<Double>();
623 for (Double d : u) {
624 result.add(a * d);
625 }
626 return result;
627 }
628
629 private static ArrayList<Double> mult(ArrayList<Double> v, ArrayList<Double> u) {
630 if (v.size() != u.size()) {
631 throw new IllegalArgumentException("Vector sizes must be equal.");
632 }
633 ArrayList<Double> result = new ArrayList<Double>();
634 for (int i = 0; i < v.size(); i++) {
635 result.add(v.get(i) * u.get(i));
636 }
637 return result;
638 }
639
640 private static ArrayList<Double> subtract(double a, ArrayList<Double> u) {
641 ArrayList<Double> result = new ArrayList<Double>();
642 for (Double d : u) {
643 result.add(a - d);
644 }
645 return result;
646 }
647
648 private static ArrayList<Double> exp(ArrayList<Double> u) {
649 ArrayList<Double> result = new ArrayList<Double>();
650 for (Double d : u) {
651 result.add(Math.exp(d));
652 }
653 return result;
654 }
655
656 private static ArrayList<Double> div(ArrayList<Double> a, ArrayList<Double> b){
657 if (a.size() != b.size()) {
658 throw new IllegalArgumentException("Vector sizes must be equal.");
659 }
660 ArrayList<Double> result = new ArrayList<Double>();
661 for (int i = 0; i < a.size(); i++) {
662 result.add(a.get(i) / b.get(i));
663 }
664 return result;
665 }
666 }