1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.commands.test;
39
40 import ffx.numerics.Potential;
41 import ffx.potential.ForceFieldEnergy;
42 import ffx.potential.bonded.Atom;
43 import ffx.potential.bonded.Residue;
44 import ffx.potential.cli.AtomSelectionOptions;
45 import ffx.potential.cli.GradientOptions;
46 import ffx.potential.cli.PotentialCommand;
47 import ffx.potential.extended.ExtendedSystem;
48 import ffx.potential.terms.AnglePotentialEnergy;
49 import ffx.potential.terms.BondPotentialEnergy;
50 import ffx.potential.terms.ImproperTorsionPotentialEnergy;
51 import ffx.potential.terms.OutOfPlaneBendPotentialEnergy;
52 import ffx.potential.terms.PiOrbitalTorsionPotentialEnergy;
53 import ffx.potential.terms.StretchBendPotentialEnergy;
54 import ffx.potential.terms.TorsionPotentialEnergy;
55 import ffx.potential.terms.TorsionTorsionPotentialEnergy;
56 import ffx.potential.terms.UreyBradleyPotentialEnergy;
57 import ffx.utilities.FFXBinding;
58 import picocli.CommandLine.Command;
59 import picocli.CommandLine.Mixin;
60 import picocli.CommandLine.Option;
61 import picocli.CommandLine.Parameters;
62
63 import java.util.ArrayList;
64 import java.util.Collections;
65 import java.util.HashMap;
66 import java.util.List;
67 import java.util.Map;
68 import java.util.stream.IntStream;
69
70 import static ffx.utilities.StringUtils.parseAtomRanges;
71 import static java.lang.String.format;
72 import static org.apache.commons.math3.util.FastMath.abs;
73 import static org.apache.commons.math3.util.FastMath.pow;
74 import static org.apache.commons.math3.util.FastMath.sqrt;
75
76
77
78
79
80
81
82
83 @Command(description = " Test the potential energy gradient for CpHMD.", name = "test.PhGradient")
84 public class PhGradient extends PotentialCommand {
85
86 @Mixin
87 GradientOptions gradientOptions;
88
89 @Mixin
90 AtomSelectionOptions atomSelectionOptions;
91
92
93 @Option(names = {"--pH", "--constantPH"}, paramLabel = "7.4",
94 description = "Constant pH value for the test.")
95 double pH = 7.4;
96
97
98 @Option(names = {"--esvLambda"}, paramLabel = "0.5",
99 description = "ESV Lambda at which to test gradient.")
100 double esvLambda = 0.5;
101
102
103 @Option(names = {"--scanLambdas"}, paramLabel = "false",
104 description = "Scan titration and tautomer lambda landscape.")
105 boolean scan = false;
106
107
108 @Option(names = {"--testEndStateEnergies"}, paramLabel = "false",
109 description = "Test both ESV energy end states as if the polarization damping factor is initialized from the respective protonated or deprotonated state")
110 boolean testEndstateEnergies = false;
111
112
113 @Parameters(arity = "1", paramLabel = "file", description = "The atomic coordinate file in PDB format.")
114 String filename = null;
115
116 private ForceFieldEnergy energy;
117
118 public HashMap<String, double[]> endstateEnergyMap = new HashMap<>();
119 public int nFailures = 0;
120 public int nESVFailures = 0;
121 public double minEnergy = 0.0;
122 public String minLambdaList = "";
123
124
/**
 * Creates a PhGradient command without a binding or command line arguments.
 */
public PhGradient() {
  // The no-argument superclass constructor is invoked implicitly.
}

/**
 * Creates a PhGradient command from a binding.
 *
 * @param binding The binding handed to the PotentialCommand constructor.
 */
public PhGradient(FFXBinding binding) {
  super(binding);
}

/**
 * Creates a PhGradient command from command line arguments.
 *
 * @param args The arguments handed to the PotentialCommand constructor.
 */
public PhGradient(String[] args) {
  super(args);
}
144
145
146 @Override
147 public PhGradient run() {
148
149 if (!init()) {
150 return this;
151 }
152 activeAssembly = getActiveAssembly(filename);
153 if (activeAssembly == null) {
154 logger.info(helpString());
155 return this;
156 }
157
158
159 filename = activeAssembly.getFile().getAbsolutePath();
160
161 logger.info("\n Testing the atomic coordinate gradient of " + filename + "\n");
162
163
164 energy = activeAssembly.getPotentialEnergy();
165
166 ExtendedSystem esvSystem = new ExtendedSystem(activeAssembly, pH, null);
167 esvSystem.setConstantPh(pH);
168 List<Residue> extendedResidues = esvSystem.getExtendedResidueList();
169 List<Residue> titratingResidues = esvSystem.getTitratingResidueList();
170 List<Residue> tautomerResidues = esvSystem.getTautomerizingResidueList();
171
172 int numESVs = esvSystem.getExtendedResidueList().size();
173 energy.attachExtendedSystem(esvSystem);
174 logger.info(format(" Attached extended system with %d residues.", numESVs));
175
176
177 for (Residue residue : extendedResidues) {
178 esvSystem.setTitrationLambda(residue, esvLambda);
179 esvSystem.setTautomerLambda(residue, esvLambda);
180 }
181
182 Atom[] atoms = activeAssembly.getAtomArray();
183 int nAtoms = atoms.length;
184
185
186 atomSelectionOptions.setActiveAtoms(activeAssembly);
187
188
189 double step = gradientOptions.getDx();
190 logger.info(" Finite-difference step size:\t" + step);
191
192
193 boolean print = gradientOptions.getVerbose();
194 logger.info(" Verbose printing:\t\t" + print);
195
196
197 List<Integer> atomsToTest;
198 if (gradientOptions.getGradientAtoms().equalsIgnoreCase("NONE")) {
199 logger.info(" The gradient of no atoms will be evaluated.");
200 return this;
201 } else if (gradientOptions.getGradientAtoms().equalsIgnoreCase("ALL")) {
202 logger.info(" Checking gradient for all active atoms.\n");
203 atomsToTest = new ArrayList<>();
204 IntStream.range(0, nAtoms).forEach(val -> atomsToTest.add(val));
205 } else {
206 atomsToTest = parseAtomRanges(" Gradient atoms", gradientOptions.getGradientAtoms(), nAtoms);
207 logger.info(
208 " Checking gradient for active atoms in the range: " + gradientOptions.getGradientAtoms() +
209 "\n");
210 }
211
212
213 Map<Integer, Integer> allToActive = new HashMap<>();
214 int nActive = 0;
215 for (int i = 0; i < nAtoms; i++) {
216 Atom atom = atoms[i];
217 if (atom.isActive()) {
218 allToActive.put(i, nActive);
219 nActive++;
220 }
221 }
222
223
224 int n = energy.getNumberOfVariables();
225 double[] x = new double[n];
226 double[] g = new double[n];
227 energy.getCoordinates(x);
228 energy.energyAndGradient(x, g);
229 int index = 0;
230 double[][] allAnalytic = new double[nAtoms][3];
231 for (Atom a : atoms) {
232 a.getXYZGradient(allAnalytic[index++]);
233 }
234
235
236 double expGrad = 1000.0;
237 double gradientTolerance = gradientOptions.getTolerance();
238 double width = 2.0 * step;
239 double avLen = 0.0;
240 double avGrad = 0.0;
241 double expGrad2 = expGrad * expGrad;
242
243 int nTested = 0;
244 for (int k : atomsToTest) {
245 Atom a0 = atoms[k];
246 if (!a0.isActive()) {
247 continue;
248 }
249
250 nTested++;
251 double[] analytic = allAnalytic[k];
252
253
254 int ia = allToActive.get(k);
255 int i3 = ia * 3;
256 int i0 = i3 + 0;
257 int i1 = i3 + 1;
258 int i2 = i3 + 2;
259 double[] numeric = new double[3];
260
261
262 double orig = x[i0];
263 x[i0] = x[i0] + step;
264 double e = energy.energy(x);
265 x[i0] = orig - step;
266 e -= energy.energy(x);
267 x[i0] = orig;
268 numeric[0] = e / width;
269
270
271 orig = x[i1];
272 x[i1] = x[i1] + step;
273 e = energy.energy(x);
274 x[i1] = orig - step;
275 e -= energy.energy(x);
276 x[i1] = orig;
277 numeric[1] = e / width;
278
279
280 orig = x[i2];
281 x[i2] = x[i2] + step;
282 e = energy.energy(x);
283 x[i2] = orig - step;
284 e -= energy.energy(x);
285 x[i2] = orig;
286 numeric[2] = e / width;
287
288 double dx = analytic[0] - numeric[0];
289 double dy = analytic[1] - numeric[1];
290 double dz = analytic[2] - numeric[2];
291 double len = dx * dx + dy * dy + dz * dz;
292 avLen += len;
293 len = sqrt(len);
294
295 double grad2 =
296 analytic[0] * analytic[0] + analytic[1] * analytic[1] + analytic[2] * analytic[2];
297 avGrad += grad2;
298
299 if (len > gradientTolerance) {
300 logger.info(format(" %s\n Failed: %10.6f\n", a0.toString(), len) +
301 format(" Analytic: (%12.4f, %12.4f, %12.4f)\n", analytic[0], analytic[1], analytic[2]) +
302 format(" Numeric: (%12.4f, %12.4f, %12.4f)\n", numeric[0], numeric[1], numeric[2]));
303 ++nFailures;
304 } else {
305 logger.info(format(" %s\n Passed: %10.6f\n", a0.toString(), len) +
306 format(" Analytic: (%12.4f, %12.4f, %12.4f)\n", analytic[0], analytic[1], analytic[2]) +
307 format(" Numeric: (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
308 }
309
310 if (grad2 > expGrad2) {
311 logger.info(format(" Atom %d has an unusually large gradient: %10.6f", ia + 1, Math.sqrt(
312 grad2)));
313 }
314 logger.info("\n");
315 }
316
317 avLen = avLen / nTested;
318 avLen = sqrt(avLen);
319 if (avLen > gradientTolerance) {
320 logger.info(format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen,
321 gradientTolerance));
322 } else {
323 logger.info(format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen,
324 gradientTolerance));
325 }
326 logger.info(format(" Number of atoms failing analytic test: %d", nFailures));
327
328 avGrad = avGrad / nTested;
329 avGrad = sqrt(avGrad);
330 if (avGrad > expGrad) {
331 logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
332 } else {
333 logger.info(format(" RMS gradient: %10.6f", avGrad));
334 }
335 energy.setCoordinates(x);
336 energy.getCoordinates(x);
337 energy.energyAndGradient(x, g);
338 double[] esvDerivs = esvSystem.getDerivatives();
339
340
341
342 for (int i = 0; i < titratingResidues.size(); i++) {
343 double eMinusTitr = 0.0;
344 double ePlusTitr = 0.0;
345 double eMinusTaut = 0.0;
346 double ePlusTaut = 0.0;
347 Residue residue = titratingResidues.get(i);
348 int tautomerIndex = tautomerResidues.indexOf(residue) + titratingResidues.size();
349
350 if (esvLambda + step > 1.0) {
351 logger.info("Backward finite difference being applied. Consider using a smaller step size than the default in this case.\n");
352 esvSystem.setTitrationLambda(residue, esvLambda - 2 * step);
353 eMinusTitr = energy.energy(x);
354 esvSystem.setTitrationLambda(residue, esvLambda);
355 ePlusTitr = energy.energy(x);
356
357 if (esvSystem.isTautomer(residue)) {
358 esvSystem.setTautomerLambda(residue, esvLambda - 2 * step);
359 eMinusTaut = energy.energy(x);
360 esvSystem.setTautomerLambda(residue, esvLambda);
361 ePlusTaut = energy.energy(x);
362 }
363 }
364
365
366 else if (esvLambda - step < 0.0) {
367 logger.info("Forward finite difference being applied. Consider using a smaller step size than the default in this case.\n");
368 esvSystem.setTitrationLambda(residue, esvLambda + 2 * step);
369 ePlusTitr = energy.energy(x);
370 esvSystem.setTitrationLambda(residue, esvLambda);
371 eMinusTitr = energy.energy(x);
372
373 if (esvSystem.isTautomer(residue)) {
374 esvSystem.setTautomerLambda(residue, esvLambda + 2 * step);
375 ePlusTaut = energy.energy(x);
376 esvSystem.setTautomerLambda(residue, esvLambda);
377 eMinusTaut = energy.energy(x);
378 }
379 }
380
381
382 else {
383 esvSystem.setTitrationLambda(residue, esvLambda + step);
384 ePlusTitr = energy.energy(x);
385 esvSystem.setTitrationLambda(residue, esvLambda - step);
386 eMinusTitr = energy.energy(x);
387 esvSystem.setTitrationLambda(residue, esvLambda);
388
389 if (esvSystem.isTautomer(residue)) {
390 esvSystem.setTautomerLambda(residue, esvLambda + step);
391 ePlusTaut = energy.energy(x);
392 esvSystem.setTautomerLambda(residue, esvLambda - step);
393 eMinusTaut = energy.energy(x);
394 esvSystem.setTautomerLambda(residue, esvLambda);
395 }
396 }
397
398 double fdDerivTitr = (ePlusTitr - eMinusTitr) / width;
399 double errorTitr = abs(fdDerivTitr - esvDerivs[i]);
400
401 if (errorTitr > gradientTolerance) {
402 logger.info(format(" Residue: %s Chain: %s ESV %d\n Failed: %10.6f\n", residue.toString(), residue.getChainID(), i, errorTitr) +
403 format(" Analytic: %12.4f vs. Numeric: %12.4f\n", esvDerivs[i], fdDerivTitr));
404 ++nESVFailures;
405 } else {
406 logger.info(format(" Residue: %s Chain: %s ESV %d\n Passed: %10.6f\n", residue.toString(), residue.getChainID(), i, errorTitr) +
407 format(" Analytic: %12.4f vs. Numeric: %12.4f\n", esvDerivs[i], fdDerivTitr));
408 }
409
410 if (esvSystem.isTautomer(residue)) {
411 double fdDerivTaut = (ePlusTaut - eMinusTaut) / width;
412 int ti = tautomerIndex;
413 double errorTaut = abs(fdDerivTaut - esvDerivs[ti]);
414 if (errorTaut > gradientTolerance) {
415 logger.info(format(" Residue: %s (Tautomer) Chain: %s ESV %d\n Failed: %10.6f\n", residue.toString(), residue.getChainID(), ti, errorTaut) +
416 format(" Analytic: %12.4f vs. Numeric: %12.4f\n", esvDerivs[ti], fdDerivTaut));
417 ++nESVFailures;
418 } else {
419 logger.info(format(" Residue: %s (Tautomer) Chain: %s ESV %d\n Passed: %10.6f\n", residue.toString(), residue.getChainID(), ti, errorTaut) +
420 format(" Analytic: %12.4f vs. Numeric: %12.4f\n", esvDerivs[ti], fdDerivTaut));
421 }
422 }
423 if (nESVFailures > 0) {
424 logger.info(format(" %d ESVs failed the gradient test.\n", nESVFailures));
425 }
426 }
427 if (nESVFailures == 0) {
428 logger.info(" All ESVs passed the gradient test.\n");
429 }
430
431 if (scan) {
432 for (Residue residue : esvSystem.getTitratingResidueList()) {
433 esvSystem.setTitrationLambda(residue, 0.0);
434 esvSystem.setTautomerLambda(residue, 0.0);
435 }
436 scanLambdas(esvSystem, energy, x);
437 printPermutations(esvSystem, titratingResidues.size(), energy, x);
438 }
439
440 if (testEndstateEnergies) {
441 testEndState(x, esvSystem, 0.0, 0.0);
442 testEndState(x, esvSystem, 1.0, 0.0);
443 }
444
445 return this;
446 }
447
448 private void printPermutations(ExtendedSystem esvSystem, int numESVs, ForceFieldEnergy energy, double[] x) {
449 for (Residue residue : esvSystem.getTitratingResidueList()) {
450 esvSystem.setTitrationLambda(residue, 0.0);
451 }
452 energy.getCoordinates(x);
453 printPermutationsR(esvSystem, numESVs - 1, energy, x);
454 logger.info("Minimum Energy:" + minEnergy + " acheived with lambdas: " + minLambdaList);
455 }
456
457 private void printPermutationsR(ExtendedSystem esvSystem, int esvID, ForceFieldEnergy energy, double[] x) {
458 for (int i = 0; i <= 1; i++) {
459 Residue residue = esvSystem.getTitratingResidueList().get(esvID);
460 esvSystem.setTitrationLambda(residue, (double) i);
461 if (esvID != 0) {
462 printPermutationsR(esvSystem, esvID - 1, energy, x);
463 } else {
464 String lambdaList = esvSystem.getLambdaList();
465 logger.info(format("Lambda List: %s", lambdaList));
466 double stateEnergy = energy.energy(x, true);
467 if (stateEnergy < minEnergy) {
468 minEnergy = stateEnergy;
469 minLambdaList = lambdaList;
470 }
471 logger.info(format("\n"));
472 }
473 }
474 }
475
476 private void scanLambdas(ExtendedSystem esvSystem, ForceFieldEnergy energy, double[] x) {
477 int nTitrESVs = esvSystem.getTitratingResidueList().size();
478 double[][][] peLandscape = new double[nTitrESVs][11][11];
479 for (int i = 0; i < nTitrESVs; i++) {
480 Residue residue = esvSystem.getTitratingResidueList().get(i);
481 for (int j = 0; j < 11; j++) {
482 esvSystem.setTitrationLambda(residue, (double) j / 10.0);
483 for (int k = 0; k < 11; k++) {
484 esvSystem.setTautomerLambda(residue, (double) k / 10.0);
485 peLandscape[i][j][k] = energy.energy(x, false);
486 }
487 }
488 esvSystem.setTitrationLambda(residue, 0.0);
489 esvSystem.setTautomerLambda(residue, 0.0);
490 }
491 StringBuilder tautomerHeader = new StringBuilder(" X→ ");
492 for (int k = 0; k < 11; k++) {
493 double lb = (double) k / 10;
494 tautomerHeader.append(String.format("%1$12s", "[" + lb + "]"));
495 }
496 tautomerHeader.append("\nλ↓");
497 for (int i = 0; i < nTitrESVs; i++) {
498 logger.info(format("ESV: %d \n", i));
499 logger.info(tautomerHeader.toString());
500 for (int j = 0; j < 11; j++) {
501 double lb = (double) j / 10;
502 StringBuilder histogram = new StringBuilder();
503 for (int k = 0; k < 11; k++) {
504 StringBuilder hisvalue = new StringBuilder();
505 String value = String.format("%5.4f", peLandscape[i][j][k]);
506 hisvalue.append(String.format("%1$12s", value));
507 histogram.append(hisvalue);
508 }
509 logger.info("[" + lb + "] " + histogram);
510 }
511 logger.info("\n");
512 }
513 }
514
515 private void testEndState(double[] x, ExtendedSystem esvSystem, double titrLambda, double tautLambda) {
516 for (Residue residue : esvSystem.getTitratingResidueList()) {
517 esvSystem.setTitrationLambda(residue, titrLambda);
518 esvSystem.setTautomerLambda(residue, tautLambda);
519 }
520
521
522
523
524 for (Atom atom : esvSystem.getExtendedAtoms()) {
525 int atomIndex = atom.getArrayIndex();
526 if (esvSystem.isTitratingHeavy(atomIndex)) {
527
528 if (atom.getPolarizeType() != null) {
529 double endstatePolar = esvSystem.getTitrationUtils().getPolarizability(atom, titrLambda, tautLambda, atom.getPolarizeType().polarizability);
530 double sixth = 1.0 / 6.0;
531 atom.getPolarizeType().pdamp = pow(endstatePolar, sixth);
532 }
533 }
534 }
535
536
537 String lambdaList = esvSystem.getLambdaList();
538 energy.setCoordinates(x);
539 energy.getCoordinates(x);
540 double stateEnergy = energy.energy(x, true);
541 double[] energyAndInteractionList = new double[26];
542
543 BondPotentialEnergy bondPotentialEnergy = energy.getBondPotentialEnergy();
544 if (bondPotentialEnergy != null) {
545 energyAndInteractionList[0] = bondPotentialEnergy.getEnergy();
546 energyAndInteractionList[1] = (double) bondPotentialEnergy.getNumberOfBonds();
547 }
548
549 AnglePotentialEnergy anglePotentialEnergy = energy.getAnglePotentialEnergy();
550 if (anglePotentialEnergy != null) {
551 energyAndInteractionList[2] = anglePotentialEnergy.getEnergy();
552 energyAndInteractionList[3] = (double) anglePotentialEnergy.getNumberOfAngles();
553 }
554
555 StretchBendPotentialEnergy stretchBendPotentialEnergy = energy.getStretchBendPotentialEnergy();
556 if (stretchBendPotentialEnergy != null) {
557 energyAndInteractionList[4] = stretchBendPotentialEnergy.getEnergy();
558 energyAndInteractionList[5] = (double) stretchBendPotentialEnergy.getNumberOfStretchBends();
559 }
560
561 UreyBradleyPotentialEnergy ureyBradleyPotentialEnergy = energy.getUreyBradleyPotentialEnergy();
562 if (ureyBradleyPotentialEnergy != null) {
563 energyAndInteractionList[6] = ureyBradleyPotentialEnergy.getEnergy();
564 energyAndInteractionList[7] = (double) ureyBradleyPotentialEnergy.getNumberOfUreyBradleys();
565 }
566
567 OutOfPlaneBendPotentialEnergy ofPlaneBendPotentialEnergy = energy.getOutOfPlaneBendPotentialEnergy();
568 if (ofPlaneBendPotentialEnergy != null) {
569 energyAndInteractionList[8] = ofPlaneBendPotentialEnergy.getEnergy();
570 energyAndInteractionList[9] = (double) ofPlaneBendPotentialEnergy.getNumberOfOutOfPlaneBends();
571 }
572
573 TorsionPotentialEnergy torsionPotentialEnergy = energy.getTorsionPotentialEnergy();
574 if (torsionPotentialEnergy != null) {
575 energyAndInteractionList[10] = torsionPotentialEnergy.getEnergy();
576 energyAndInteractionList[11] = (double) torsionPotentialEnergy.getNumberOfTorsions();
577 }
578
579 ImproperTorsionPotentialEnergy improperTorsionPotentialEnergy = energy.getImproperTorsionPotentialEnergy();
580 if (improperTorsionPotentialEnergy != null) {
581 energyAndInteractionList[12] = improperTorsionPotentialEnergy.getEnergy();
582 energyAndInteractionList[13] = (double) improperTorsionPotentialEnergy.getNumberOfImproperTorsions();
583 }
584
585 PiOrbitalTorsionPotentialEnergy piOrbitalTorsionPotentialEnergy = energy.getPiOrbitalTorsionPotentialEnergy();
586 if (piOrbitalTorsionPotentialEnergy != null) {
587 energyAndInteractionList[14] = piOrbitalTorsionPotentialEnergy.getEnergy();
588 energyAndInteractionList[15] = (double) piOrbitalTorsionPotentialEnergy.getNumberOfPiOrbitalTorsions();
589 }
590
591 TorsionTorsionPotentialEnergy torsionTorsionPotentialEnergy = energy.getTorsionTorsionPotentialEnergy();
592 if (torsionTorsionPotentialEnergy != null) {
593 energyAndInteractionList[16] = torsionTorsionPotentialEnergy.getEnergy();
594 energyAndInteractionList[17] = (double) torsionTorsionPotentialEnergy.getNumberOfTorsionTorsions();
595 }
596
597 energyAndInteractionList[18] = energy.getVanDerWaalsEnergy();
598 energyAndInteractionList[19] = (double) energy.getVanDerWaalsInteractions();
599
600 energyAndInteractionList[20] = energy.getPermanentMultipoleEnergy();
601 energyAndInteractionList[21] = (double) energy.getPermanentInteractions();
602
603 energyAndInteractionList[22] = energy.getPolarizationEnergy();
604 energyAndInteractionList[23] = (double) energy.getPermanentInteractions();
605
606 energyAndInteractionList[24] = energy.getEsvBiasEnergy();
607
608 energyAndInteractionList[25] = energy.getTotalEnergy();
609 endstateEnergyMap.put(lambdaList, energyAndInteractionList);
610 }
611
612 @Override
613 public List<Potential> getPotentials() {
614 List<Potential> potentials;
615 if (energy == null) {
616 potentials = Collections.emptyList();
617 } else {
618 potentials = Collections.singletonList(energy);
619 }
620 return potentials;
621 }
622 }