1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray.commands.test;
39
import ffx.algorithms.cli.AlgorithmsCommand;
import ffx.numerics.Potential;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.Atom;
import ffx.potential.bonded.LambdaInterface;
import ffx.potential.cli.AlchemicalOptions;
import ffx.potential.cli.GradientOptions;
import ffx.utilities.FFXBinding;
import ffx.xray.DiffractionData;
import ffx.xray.RefinementEnergy;
import ffx.xray.RefinementMinimize.RefinementMode;
import ffx.xray.cli.XrayOptions;
import org.apache.commons.configuration2.CompositeConfiguration;
import picocli.CommandLine.Command;
import picocli.CommandLine.Mixin;
import picocli.CommandLine.Parameters;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static ffx.utilities.StringUtils.parseAtomRanges;
import static java.lang.String.format;
64
65
66
67
68
69
70
71
72 @Command(description = " Test Lambda Derivatives on an X-ray target.", name = "xray.test.LambdaGradient")
73 public class LambdaGradient extends AlgorithmsCommand {
74
75 @Mixin
76 private XrayOptions xrayOptions;
77
78 @Mixin
79 private AlchemicalOptions alchemicalOptions;
80
81 @Mixin
82 private GradientOptions gradientOptions;
83
84
85
86
87 @Parameters(arity = "1..*", paramLabel = "files", description = "PDB and Real Space input files.")
88 private List<String> filenames;
89
90 private RefinementEnergy refinementEnergy;
91
92
93
94
95 public LambdaGradient() {
96 super();
97 }
98
99
100
101
102
103
104 public LambdaGradient(String[] args) {
105 super(args);
106 }
107
108
109
110
111
112
113 public LambdaGradient(FFXBinding binding) {
114 super(binding);
115 }
116
117
118
119
120 @Override
121 public LambdaGradient run() {
122
123 if (!init()) {
124 return this;
125 }
126
127 xrayOptions.init();
128
129
130 System.setProperty("lambdaterm", "true");
131
132 String filename;
133 MolecularAssembly[] molecularAssemblies;
134 if (filenames != null && !filenames.isEmpty()) {
135 molecularAssemblies = algorithmFunctions.openAll(filenames.get(0));
136 activeAssembly = molecularAssemblies[0];
137 filename = filenames.get(0);
138 } else if (activeAssembly == null) {
139 logger.info(helpString());
140 return this;
141 } else {
142 filename = activeAssembly.getFile().getAbsolutePath();
143 molecularAssemblies = new MolecularAssembly[]{activeAssembly};
144 }
145
146 alchemicalOptions.setFirstSystemAlchemistry(activeAssembly);
147 alchemicalOptions.setFirstSystemUnchargedAtoms(activeAssembly);
148
149 logger.info("\n Testing X-ray lambda derivatives for " + filename);
150
151
152 CompositeConfiguration properties = molecularAssemblies[0].getProperties();
153 xrayOptions.setProperties(parseResult, properties);
154
155
156 DiffractionData diffractionData = xrayOptions.getDiffractionData(filenames, molecularAssemblies, properties);
157 refinementEnergy = xrayOptions.toXrayEnergy(diffractionData);
158
159 Potential potential = refinementEnergy;
160 LambdaInterface lambdaInterface = refinementEnergy;
161
162
163 double step = gradientOptions.getDx();
164
165 int n = refinementEnergy.getNumberOfVariables();
166 Atom[] atoms = refinementEnergy.getActiveAtoms();
167 int nAtoms = atoms.length;
168 double[] x = new double[n];
169 double[] gradient = new double[n];
170
171
172 double width = 2.0 * step;
173
174 double errTol = 1.0e-3;
175
176 double expGrad = 1000.0;
177
178 double[] lambdaGrad = new double[n];
179 double[][] lambdaGradFD = new double[2][n];
180
181 double initialLambda = alchemicalOptions.getInitialLambda();
182
183
184 double lambda = 0.0;
185 lambdaInterface.setLambda(lambda);
186 potential.getCoordinates(x);
187 double e0 = potential.energy(x, true);
188
189
190 lambda = 1.0;
191 lambdaInterface.setLambda(lambda);
192 double e1 = potential.energy(x, true);
193
194 logger.info(format(" E(0): %20.8f.", e0));
195 logger.info(format(" E(1): %20.8f.", e1));
196 logger.info(format(" E(1)-E(0): %20.8f.\n", e1 - e0));
197
198
199 for (int j = 0; j < 3; j++) {
200 lambda = initialLambda - 0.01 + 0.01 * j;
201
202 if (lambda - step < 0.0) {
203 continue;
204 }
205 if (lambda + step > 1.0) {
206 continue;
207 }
208
209 logger.info(format(" Current lambda value %6.4f", lambda));
210 lambdaInterface.setLambda(lambda);
211
212
213 double e = potential.energyAndGradient(x, gradient);
214
215
216 double dEdL = lambdaInterface.getdEdL();
217 double d2EdL2 = lambdaInterface.getd2EdL2();
218 for (int i = 0; i < n; i++) {
219 lambdaGrad[i] = 0.0;
220 }
221 lambdaInterface.getdEdXdL(lambdaGrad);
222
223
224 lambdaInterface.setLambda(lambda + step);
225 double lp = potential.energyAndGradient(x, lambdaGradFD[0]);
226 double dedlp = lambdaInterface.getdEdL();
227 lambdaInterface.setLambda(lambda - step);
228 double lm = potential.energyAndGradient(x, lambdaGradFD[1]);
229 double dedlm = lambdaInterface.getdEdL();
230
231 double dEdLFD = (lp - lm) / width;
232 double d2EdL2FD = (dedlp - dedlm) / width;
233
234 double err = Math.abs(dEdLFD - dEdL);
235 if (err < errTol) {
236 logger.info(format(" dE/dL passed: %10.6f", err));
237 } else {
238 logger.info(format(" dE/dL failed: %10.6f", err));
239 }
240 logger.info(format(" Numeric: %15.8f", dEdLFD));
241 logger.info(format(" Analytic: %15.8f", dEdL));
242
243 err = Math.abs(d2EdL2FD - d2EdL2);
244 if (err < errTol) {
245 logger.info(format(" d2E/dL2 passed: %10.6f", err));
246 } else {
247 logger.info(format(" d2E/dL2 failed: %10.6f", err));
248 }
249 logger.info(format(" Numeric: %15.8f", d2EdL2FD));
250 logger.info(format(" Analytic: %15.8f", d2EdL2));
251
252 boolean passed = true;
253
254 for (int i = 0; i < nAtoms; i++) {
255 int ii = i * 3;
256 double dX = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
257 double dXa = lambdaGrad[ii];
258 double eX = dX - dXa;
259 ii++;
260 double dY = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
261 double dYa = lambdaGrad[ii];
262 double eY = dY - dYa;
263 ii++;
264 double dZ = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
265 double dZa = lambdaGrad[ii];
266 double eZ = dZ - dZa;
267
268 double error = Math.sqrt(eX * eX + eY * eY + eZ * eZ);
269 if (error < errTol) {
270 logger.fine(format(" dE/dX/dL for Atom %d passed: %10.6f", i + 1, error));
271 } else {
272 logger.info(format(" dE/dX/dL for Atom %d failed: %10.6f", i + 1, error));
273 logger.info(format(" Analytic: (%15.8f, %15.8f, %15.8f)", dXa, dYa, dZa));
274 logger.info(format(" Numeric: (%15.8f, %15.8f, %15.8f)", dX, dY, dZ));
275 passed = false;
276 }
277 }
278 if (passed) {
279 logger.info(format(" dE/dX/dL passed for all atoms"));
280 }
281
282 logger.info("");
283 }
284
285 boolean loopPrint = gradientOptions.getVerbose();
286 refinementEnergy.getCoordinates(x);
287 refinementEnergy.energyAndGradient(x, gradient, loopPrint);
288
289 double[] numeric = new double[3];
290 double avLen = 0.0;
291 int nFailures = 0;
292 double avGrad = 0.0;
293
294
295 List<Integer> atomsToTest;
296 if (gradientOptions.getGradientAtoms().equalsIgnoreCase("NONE")) {
297 logger.info(" The gradient of no atoms will be evaluated.");
298 return this;
299 } else if (gradientOptions.getGradientAtoms().equalsIgnoreCase("ALL")) {
300 logger.info(" Checking gradient for all active atoms.\n");
301 atomsToTest = new ArrayList<>();
302 IntStream.range(0, nAtoms).forEach(val -> atomsToTest.add(val));
303 } else {
304 atomsToTest = parseAtomRanges(" Gradient atoms", gradientOptions.getGradientAtoms(), nAtoms);
305 logger.info(
306 " Checking gradient for active atoms in the range: " + gradientOptions.getGradientAtoms() +
307 "\n");
308 }
309
310 for (int i : atomsToTest) {
311 int i3 = i * 3;
312 int i0 = i3 + 0;
313 int i1 = i3 + 1;
314 int i2 = i3 + 2;
315
316
317 double orig = x[i0];
318 x[i0] = x[i0] + step;
319 double e = refinementEnergy.energy(x, loopPrint);
320 x[i0] = orig - step;
321 e -= refinementEnergy.energy(x, loopPrint);
322 x[i0] = orig;
323 numeric[0] = e / width;
324
325
326 orig = x[i1];
327 x[i1] = x[i1] + step;
328 e = refinementEnergy.energy(x, loopPrint);
329 x[i1] = orig - step;
330 e -= refinementEnergy.energy(x, loopPrint);
331 x[i1] = orig;
332 numeric[1] = e / width;
333
334
335 orig = x[i2];
336 x[i2] = x[i2] + step;
337 e = refinementEnergy.energy(x, loopPrint);
338 x[i2] = orig - step;
339 e -= refinementEnergy.energy(x, loopPrint);
340 x[i2] = orig;
341 numeric[2] = e / width;
342
343 double dx = gradient[i0] - numeric[0];
344 double dy = gradient[i1] - numeric[1];
345 double dz = gradient[i2] - numeric[2];
346 double len = dx * dx + dy * dy + dz * dz;
347 avLen += len;
348 len = Math.sqrt(len);
349
350 double grad2 =
351 gradient[i0] * gradient[i0] + gradient[i1] * gradient[i1] + gradient[i2] * gradient[i2];
352 avGrad += grad2;
353 grad2 = Math.sqrt(grad2);
354
355 if (len > errTol) {
356 logger.info(format(" Atom %d failed: %10.6f.", i + 1, len) +
357 format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
358 gradient[i2]) +
359 format("\n Numeric: (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
360 ++nFailures;
361
362 } else {
363 logger.info(format(" Atom %d passed: %10.6f.", i + 1, len) +
364 format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
365 gradient[i2]) +
366 format("\n Numeric: (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
367 }
368
369 if (grad2 > expGrad) {
370 logger.info(format(" Atom %d has an unusually large gradient: %10.6f", i + 1, grad2));
371 }
372 logger.info("\n");
373 }
374
375 avLen = avLen / nAtoms;
376 avLen = Math.sqrt(avLen);
377 if (avLen > errTol) {
378 logger.info(
379 format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
380 } else {
381 logger.info(
382 format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
383 }
384 logger.info(format(" Number of atoms failing gradient test: %d", nFailures));
385
386 avGrad = avGrad / nAtoms;
387 avGrad = Math.sqrt(avGrad);
388 if (avGrad > expGrad) {
389 logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
390 } else {
391 logger.info(format(" RMS gradient: %10.6f", avGrad));
392 }
393
394 refinementEnergy = new RefinementEnergy(diffractionData, RefinementMode.BFACTORS);
395 n = refinementEnergy.getNumberOfVariables();
396 gradient = new double[n];
397 x = new double[n];
398
399 refinementEnergy.getCoordinates(x);
400 refinementEnergy.energyAndGradient(x, gradient);
401
402 avLen = 0.0;
403 nFailures = 0;
404 avGrad = 0.0;
405 width = 2.0 * step;
406 errTol = 1.0e-3;
407 expGrad = 1000.0;
408
409 for (int i = 0; i < n; i++) {
410
411
412 double orig = x[i];
413 x[i] = x[i] + step;
414 double e = refinementEnergy.energy(x);
415 x[i] = orig - step;
416 e -= refinementEnergy.energy(x);
417 x[i] = orig;
418 double fd = e / width;
419
420 double dB = gradient[i] - fd;
421 double len = dB * dB;
422 avLen += len;
423 len = Math.sqrt(len);
424
425 double grad2 = dB * dB;
426 avGrad += grad2;
427 grad2 = Math.sqrt(grad2);
428
429 if (len > errTol) {
430 logger.info(format(" B-Factor %d failed: %10.6f.", i + 1, len) +
431 format("\n Analytic: %12.4f", gradient[i]) +
432 format("\n Numeric: %12.4f", fd));
433 ++nFailures;
434
435 } else {
436 logger.info(format(" B-Factor %d passed: %10.6f.", i + 1, len) +
437 format("\n Analytic: %12.4f", gradient[i]) +
438 format("\n Numeric: %12.4f", fd));
439 }
440
441 if (grad2 > expGrad) {
442 logger.info(format(" B-Factor %d has an unusually large gradient: %10.6f", i + 1, grad2));
443 }
444 logger.info("\n");
445 }
446
447 avLen = avLen / n;
448 avLen = Math.sqrt(avLen);
449 if (avLen > errTol) {
450 logger.info(
451 format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
452 } else {
453 logger.info(
454 format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
455 }
456 logger.info(format(" Number of B-Factors failing gradient test: %d", nFailures));
457
458 avGrad = avGrad / n;
459 avGrad = Math.sqrt(avGrad);
460 if (avGrad > expGrad) {
461 logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
462 } else {
463 logger.info(format(" RMS gradient: %10.6f", avGrad));
464 }
465
466 return this;
467 }
468
469
470
471
472 @Override
473 public List<Potential> getPotentials() {
474 return refinementEnergy == null ? Collections.emptyList() :
475 Collections.singletonList(refinementEnergy);
476 }
477 }