1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray.commands.test;
39
40 import ffx.algorithms.cli.AlgorithmsCommand;
41 import ffx.numerics.Potential;
42 import ffx.potential.MolecularAssembly;
43 import ffx.potential.bonded.Atom;
44 import ffx.potential.bonded.LambdaInterface;
45 import ffx.potential.cli.AlchemicalOptions;
46 import ffx.potential.cli.GradientOptions;
47 import ffx.utilities.FFXBinding;
48 import ffx.xray.DiffractionData;
49 import ffx.xray.RefinementEnergy;
50 import ffx.xray.refine.RefinementMode;
51 import ffx.xray.refine.RefinementModel;
52 import ffx.xray.cli.XrayOptions;
53 import org.apache.commons.configuration2.CompositeConfiguration;
54 import picocli.CommandLine.Command;
55 import picocli.CommandLine.Mixin;
56 import picocli.CommandLine.Parameters;
57
58 import java.util.ArrayList;
59 import java.util.Collections;
60 import java.util.List;
61 import java.util.stream.IntStream;
62
63 import static ffx.utilities.StringUtils.parseAtomRanges;
64 import static java.lang.String.format;
65
66
67
68
69
70
71
72
73 @Command(description = " Test Lambda Derivatives on an X-ray target.", name = "xray.test.LambdaGradient")
74 public class LambdaGradient extends AlgorithmsCommand {
75
76 @Mixin
77 private XrayOptions xrayOptions;
78
79 @Mixin
80 private AlchemicalOptions alchemicalOptions;
81
82 @Mixin
83 private GradientOptions gradientOptions;
84
85
86
87
88 @Parameters(arity = "1..*", paramLabel = "files", description = "PDB and Real Space input files.")
89 private List<String> filenames;
90
91 private RefinementEnergy refinementEnergy;
92
93
94
95
96 public LambdaGradient() {
97 super();
98 }
99
100
101
102
103
104
105 public LambdaGradient(String[] args) {
106 super(args);
107 }
108
109
110
111
112
113
114 public LambdaGradient(FFXBinding binding) {
115 super(binding);
116 }
117
118
119
120
121 @Override
122 public LambdaGradient run() {
123
124 if (!init()) {
125 return this;
126 }
127
128 xrayOptions.init();
129
130
131 System.setProperty("lambdaterm", "true");
132
133 String filename;
134 MolecularAssembly[] molecularAssemblies;
135 if (filenames != null && !filenames.isEmpty()) {
136 molecularAssemblies = algorithmFunctions.openAll(filenames.get(0));
137 activeAssembly = molecularAssemblies[0];
138 filename = filenames.get(0);
139 } else if (activeAssembly == null) {
140 logger.info(helpString());
141 return this;
142 } else {
143 filename = activeAssembly.getFile().getAbsolutePath();
144 molecularAssemblies = new MolecularAssembly[]{activeAssembly};
145 }
146
147 alchemicalOptions.setFirstSystemAlchemistry(activeAssembly);
148 alchemicalOptions.setFirstSystemUnchargedAtoms(activeAssembly);
149
150 logger.info("\n Testing X-ray lambda derivatives for " + filename);
151
152
153 CompositeConfiguration properties = molecularAssemblies[0].getProperties();
154 xrayOptions.setProperties(parseResult, properties);
155
156
157 DiffractionData diffractionData = xrayOptions.getDiffractionData(filenames, molecularAssemblies, properties);
158 refinementEnergy = xrayOptions.toXrayEnergy(diffractionData);
159
160 Potential potential = refinementEnergy;
161 LambdaInterface lambdaInterface = refinementEnergy;
162
163
164 double step = gradientOptions.getDx();
165
166 RefinementModel refinementModel = diffractionData.getRefinementModel();
167
168 int n = refinementModel.getNumParameters();
169 Atom[] atoms = refinementModel.getActiveAtoms();
170 int nAtoms = atoms.length;
171 double[] x = new double[n];
172 double[] gradient = new double[n];
173
174
175 double width = 2.0 * step;
176
177 double errTol = 1.0e-3;
178
179 double expGrad = 1000.0;
180
181 double[] lambdaGrad = new double[n];
182 double[][] lambdaGradFD = new double[2][n];
183
184 double initialLambda = alchemicalOptions.getInitialLambda();
185
186
187 double lambda = 0.0;
188 lambdaInterface.setLambda(lambda);
189 potential.getCoordinates(x);
190 double e0 = potential.energy(x, true);
191
192
193 lambda = 1.0;
194 lambdaInterface.setLambda(lambda);
195 double e1 = potential.energy(x, true);
196
197 logger.info(format(" E(0): %20.8f.", e0));
198 logger.info(format(" E(1): %20.8f.", e1));
199 logger.info(format(" E(1)-E(0): %20.8f.\n", e1 - e0));
200
201
202 for (int j = 0; j < 3; j++) {
203 lambda = initialLambda - 0.01 + 0.01 * j;
204
205 if (lambda - step < 0.0) {
206 continue;
207 }
208 if (lambda + step > 1.0) {
209 continue;
210 }
211
212 logger.info(format(" Current lambda value %6.4f", lambda));
213 lambdaInterface.setLambda(lambda);
214
215
216 double e = potential.energyAndGradient(x, gradient);
217
218
219 double dEdL = lambdaInterface.getdEdL();
220 double d2EdL2 = lambdaInterface.getd2EdL2();
221 for (int i = 0; i < n; i++) {
222 lambdaGrad[i] = 0.0;
223 }
224 lambdaInterface.getdEdXdL(lambdaGrad);
225
226
227 lambdaInterface.setLambda(lambda + step);
228 double lp = potential.energyAndGradient(x, lambdaGradFD[0]);
229 double dedlp = lambdaInterface.getdEdL();
230 lambdaInterface.setLambda(lambda - step);
231 double lm = potential.energyAndGradient(x, lambdaGradFD[1]);
232 double dedlm = lambdaInterface.getdEdL();
233
234 double dEdLFD = (lp - lm) / width;
235 double d2EdL2FD = (dedlp - dedlm) / width;
236
237 double err = Math.abs(dEdLFD - dEdL);
238 if (err < errTol) {
239 logger.info(format(" dE/dL passed: %10.6f", err));
240 } else {
241 logger.info(format(" dE/dL failed: %10.6f", err));
242 }
243 logger.info(format(" Numeric: %15.8f", dEdLFD));
244 logger.info(format(" Analytic: %15.8f", dEdL));
245
246 err = Math.abs(d2EdL2FD - d2EdL2);
247 if (err < errTol) {
248 logger.info(format(" d2E/dL2 passed: %10.6f", err));
249 } else {
250 logger.info(format(" d2E/dL2 failed: %10.6f", err));
251 }
252 logger.info(format(" Numeric: %15.8f", d2EdL2FD));
253 logger.info(format(" Analytic: %15.8f", d2EdL2));
254
255 boolean passed = true;
256
257 for (int i = 0; i < nAtoms; i++) {
258 int ii = i * 3;
259 double dX = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
260 double dXa = lambdaGrad[ii];
261 double eX = dX - dXa;
262 ii++;
263 double dY = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
264 double dYa = lambdaGrad[ii];
265 double eY = dY - dYa;
266 ii++;
267 double dZ = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
268 double dZa = lambdaGrad[ii];
269 double eZ = dZ - dZa;
270
271 double error = Math.sqrt(eX * eX + eY * eY + eZ * eZ);
272 if (error < errTol) {
273 logger.fine(format(" dE/dX/dL for Atom %d passed: %10.6f", i + 1, error));
274 } else {
275 logger.info(format(" dE/dX/dL for Atom %d failed: %10.6f", i + 1, error));
276 logger.info(format(" Analytic: (%15.8f, %15.8f, %15.8f)", dXa, dYa, dZa));
277 logger.info(format(" Numeric: (%15.8f, %15.8f, %15.8f)", dX, dY, dZ));
278 passed = false;
279 }
280 }
281 if (passed) {
282 logger.info(format(" dE/dX/dL passed for all atoms"));
283 }
284
285 logger.info("");
286 }
287
288 boolean loopPrint = gradientOptions.getVerbose();
289 refinementEnergy.getCoordinates(x);
290 refinementEnergy.energyAndGradient(x, gradient, loopPrint);
291
292 double[] numeric = new double[3];
293 double avLen = 0.0;
294 int nFailures = 0;
295 double avGrad = 0.0;
296
297
298 List<Integer> atomsToTest;
299 if (gradientOptions.getGradientAtoms().equalsIgnoreCase("NONE")) {
300 logger.info(" The gradient of no atoms will be evaluated.");
301 return this;
302 } else if (gradientOptions.getGradientAtoms().equalsIgnoreCase("ALL")) {
303 logger.info(" Checking gradient for all active atoms.\n");
304 atomsToTest = new ArrayList<>();
305 IntStream.range(0, nAtoms).forEach(val -> atomsToTest.add(val));
306 } else {
307 atomsToTest = parseAtomRanges(" Gradient atoms", gradientOptions.getGradientAtoms(), nAtoms);
308 logger.info(
309 " Checking gradient for active atoms in the range: " + gradientOptions.getGradientAtoms() +
310 "\n");
311 }
312
313 for (int i : atomsToTest) {
314 int i3 = i * 3;
315 int i0 = i3 + 0;
316 int i1 = i3 + 1;
317 int i2 = i3 + 2;
318
319
320 double orig = x[i0];
321 x[i0] = x[i0] + step;
322 double e = refinementEnergy.energy(x, loopPrint);
323 x[i0] = orig - step;
324 e -= refinementEnergy.energy(x, loopPrint);
325 x[i0] = orig;
326 numeric[0] = e / width;
327
328
329 orig = x[i1];
330 x[i1] = x[i1] + step;
331 e = refinementEnergy.energy(x, loopPrint);
332 x[i1] = orig - step;
333 e -= refinementEnergy.energy(x, loopPrint);
334 x[i1] = orig;
335 numeric[1] = e / width;
336
337
338 orig = x[i2];
339 x[i2] = x[i2] + step;
340 e = refinementEnergy.energy(x, loopPrint);
341 x[i2] = orig - step;
342 e -= refinementEnergy.energy(x, loopPrint);
343 x[i2] = orig;
344 numeric[2] = e / width;
345
346 double dx = gradient[i0] - numeric[0];
347 double dy = gradient[i1] - numeric[1];
348 double dz = gradient[i2] - numeric[2];
349 double len = dx * dx + dy * dy + dz * dz;
350 avLen += len;
351 len = Math.sqrt(len);
352
353 double grad2 =
354 gradient[i0] * gradient[i0] + gradient[i1] * gradient[i1] + gradient[i2] * gradient[i2];
355 avGrad += grad2;
356 grad2 = Math.sqrt(grad2);
357
358 if (len > errTol) {
359 logger.info(format(" Atom %d failed: %10.6f.", i + 1, len) +
360 format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
361 gradient[i2]) +
362 format("\n Numeric: (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
363 ++nFailures;
364
365 } else {
366 logger.info(format(" Atom %d passed: %10.6f.", i + 1, len) +
367 format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
368 gradient[i2]) +
369 format("\n Numeric: (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
370 }
371
372 if (grad2 > expGrad) {
373 logger.info(format(" Atom %d has an unusually large gradient: %10.6f", i + 1, grad2));
374 }
375 logger.info("\n");
376 }
377
378 avLen = avLen / nAtoms;
379 avLen = Math.sqrt(avLen);
380 if (avLen > errTol) {
381 logger.info(
382 format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
383 } else {
384 logger.info(
385 format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
386 }
387 logger.info(format(" Number of atoms failing gradient test: %d", nFailures));
388
389 avGrad = avGrad / nAtoms;
390 avGrad = Math.sqrt(avGrad);
391 if (avGrad > expGrad) {
392 logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
393 } else {
394 logger.info(format(" RMS gradient: %10.6f", avGrad));
395 }
396
397 diffractionData.getRefinementModel().setRefinementMode(RefinementMode.BFACTORS);
398 refinementEnergy = new RefinementEnergy(diffractionData);
399 n = refinementEnergy.getNumberOfVariables();
400 gradient = new double[n];
401 x = new double[n];
402
403 refinementEnergy.getCoordinates(x);
404 refinementEnergy.energyAndGradient(x, gradient);
405
406 avLen = 0.0;
407 nFailures = 0;
408 avGrad = 0.0;
409 width = 2.0 * step;
410 errTol = 1.0e-3;
411 expGrad = 1000.0;
412
413 for (int i = 0; i < n; i++) {
414
415
416 double orig = x[i];
417 x[i] = x[i] + step;
418 double e = refinementEnergy.energy(x);
419 x[i] = orig - step;
420 e -= refinementEnergy.energy(x);
421 x[i] = orig;
422 double fd = e / width;
423
424 double dB = gradient[i] - fd;
425 double len = dB * dB;
426 avLen += len;
427 len = Math.sqrt(len);
428
429 double grad2 = dB * dB;
430 avGrad += grad2;
431 grad2 = Math.sqrt(grad2);
432
433 if (len > errTol) {
434 logger.info(format(" B-Factor %d failed: %10.6f.", i + 1, len) +
435 format("\n Analytic: %12.4f", gradient[i]) +
436 format("\n Numeric: %12.4f", fd));
437 ++nFailures;
438
439 } else {
440 logger.info(format(" B-Factor %d passed: %10.6f.", i + 1, len) +
441 format("\n Analytic: %12.4f", gradient[i]) +
442 format("\n Numeric: %12.4f", fd));
443 }
444
445 if (grad2 > expGrad) {
446 logger.info(format(" B-Factor %d has an unusually large gradient: %10.6f", i + 1, grad2));
447 }
448 logger.info("\n");
449 }
450
451 avLen = avLen / n;
452 avLen = Math.sqrt(avLen);
453 if (avLen > errTol) {
454 logger.info(
455 format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
456 } else {
457 logger.info(
458 format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
459 }
460 logger.info(format(" Number of B-Factors failing gradient test: %d", nFailures));
461
462 avGrad = avGrad / n;
463 avGrad = Math.sqrt(avGrad);
464 if (avGrad > expGrad) {
465 logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
466 } else {
467 logger.info(format(" RMS gradient: %10.6f", avGrad));
468 }
469
470 return this;
471 }
472
473
474
475
476 @Override
477 public List<Potential> getPotentials() {
478 return refinementEnergy == null ? Collections.emptyList() :
479 Collections.singletonList(refinementEnergy);
480 }
481 }