1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.realspace.commands.test;
39
40 import ffx.algorithms.cli.AlgorithmsCommand;
41 import ffx.numerics.Potential;
42 import ffx.potential.MolecularAssembly;
43 import ffx.potential.bonded.LambdaInterface;
44 import ffx.potential.cli.AlchemicalOptions;
45 import ffx.potential.cli.GradientOptions;
46 import ffx.realspace.RealSpaceData;
47 import ffx.realspace.cli.RealSpaceOptions;
48 import ffx.realspace.parsers.RealSpaceFile;
49 import ffx.utilities.FFXBinding;
50 import ffx.xray.RefinementEnergy;
51 import picocli.CommandLine.Command;
52 import picocli.CommandLine.Mixin;
53 import picocli.CommandLine.Parameters;
54
55 import java.util.ArrayList;
56 import java.util.Collections;
57 import java.util.List;
58 import java.util.stream.IntStream;
59
60 import static ffx.utilities.StringUtils.parseAtomRanges;
61
62
63
64
65
66
67
68
69 @Command(description = " Test Lambda Derivatives on a Real Space target.", name = "realspace.test.LambdaGradient")
70 public class LambdaGradient extends AlgorithmsCommand {
71
72 @Mixin
73 private RealSpaceOptions realSpaceOptions;
74
75 @Mixin
76 private AlchemicalOptions alchemicalOptions;
77
78 @Mixin
79 private GradientOptions gradientOptions;
80
81
82
83
84 @Parameters(arity = "1..*", paramLabel = "files", description = "PDB and Real Space input files.")
85 private List<String> filenames;
86
87 private Potential potential;
88
89
90
91
92 public LambdaGradient() {
93 super();
94 }
95
96
97
98
99
100 public LambdaGradient(String[] args) {
101 super(args);
102 }
103
104
105
106
107
108 public LambdaGradient(FFXBinding binding) {
109 super(binding);
110 }
111
112 @Override
113 public LambdaGradient run() {
114
115 if (!init()) {
116 return this;
117 }
118
119
120 System.setProperty("lambdaterm", "true");
121
122 String modelfilename;
123 MolecularAssembly[] assemblies;
124 if (filenames != null && filenames.size() > 0) {
125 assemblies = algorithmFunctions.openAll(filenames.get(0));
126 activeAssembly = assemblies[0];
127 modelfilename = filenames.get(0);
128 } else if (activeAssembly == null) {
129 logger.info(helpString());
130 return this;
131 } else {
132 modelfilename = activeAssembly.getFile().getAbsolutePath();
133 assemblies = new MolecularAssembly[]{activeAssembly};
134 }
135
136 alchemicalOptions.setFirstSystemAlchemistry(activeAssembly);
137 alchemicalOptions.setFirstSystemUnchargedAtoms(activeAssembly);
138
139 logger.info("\n Testing lambda derivatives for " + modelfilename);
140
141 List<RealSpaceFile> mapfiles = realSpaceOptions.processData(filenames, activeAssembly);
142 RealSpaceFile[] mapFileArray = mapfiles.toArray(new RealSpaceFile[0]);
143
144 RealSpaceData realspacedata = new RealSpaceData(activeAssembly, activeAssembly.getProperties(),
145 activeAssembly.getParallelTeam(), mapFileArray);
146
147 RefinementEnergy refinementEnergy = new RefinementEnergy(realspacedata,
148 realSpaceOptions.refinementMode);
149 potential = refinementEnergy;
150 LambdaInterface lambdaInterface = refinementEnergy;
151
152
153 int n = potential.getNumberOfVariables();
154 double[] x = new double[n];
155 double[] gradient = new double[n];
156 double[] lambdaGrad = new double[n];
157 double[][] lambdaGradFD = new double[2][n];
158
159
160 assert (n % 3 == 0);
161 int nAtoms = n / 3;
162
163
164 double lambda = 0.0;
165 lambdaInterface.setLambda(lambda);
166 potential.getCoordinates(x);
167 double e0 = potential.energy(x, true);
168
169
170 lambda = 1.0;
171 lambdaInterface.setLambda(lambda);
172 double e1 = potential.energy(x, true);
173
174 logger.info(String.format(" E(0): %20.8f.", e0));
175 logger.info(String.format(" E(1): %20.8f.", e1));
176 logger.info(String.format(" E(1)-E(0): %20.8f.\n", e1 - e0));
177
178
179 double step = gradientOptions.getDx();
180 double width = 2.0 * step;
181
182
183 double errTol = 1.0e-3;
184
185 double expGrad = 1000.0;
186
187 double initialLambda = alchemicalOptions.getInitialLambda();
188
189
190 for (int j = 0; j < 3; j++) {
191 lambda = initialLambda - 0.01 + 0.01 * j;
192
193 if (lambda - step < 0.0) {
194 continue;
195 }
196 if (lambda + step > 1.0) {
197 continue;
198 }
199
200 logger.info(String.format(" Current lambda value %6.4f", lambda));
201 lambdaInterface.setLambda(lambda);
202
203
204 double e = potential.energyAndGradient(x, gradient);
205
206
207 double dEdL = lambdaInterface.getdEdL();
208 double d2EdL2 = lambdaInterface.getd2EdL2();
209 for (int i = 0; i < n; i++) {
210 lambdaGrad[i] = 0.0;
211 }
212
213 lambdaInterface.getdEdXdL(lambdaGrad);
214
215
216 lambdaInterface.setLambda(lambda + step);
217 double lp = potential.energyAndGradient(x, lambdaGradFD[0]);
218 double dedlp = lambdaInterface.getdEdL();
219 lambdaInterface.setLambda(lambda - step);
220 double lm = potential.energyAndGradient(x, lambdaGradFD[1]);
221 double dedlm = lambdaInterface.getdEdL();
222
223 double dEdLFD = (lp - lm) / width;
224 double d2EdL2FD = (dedlp - dedlm) / width;
225
226 double err = Math.abs(dEdLFD - dEdL);
227 if (err < errTol) {
228 logger.info(String.format(" dE/dL passed: %10.6f", err));
229 } else {
230 logger.info(String.format(" dE/dL failed: %10.6f", err));
231 }
232 logger.info(String.format(" Numeric: %15.8f", dEdLFD));
233 logger.info(String.format(" Analytic: %15.8f", dEdL));
234
235 err = Math.abs(d2EdL2FD - d2EdL2);
236 if (err < errTol) {
237 logger.info(String.format(" d2E/dL2 passed: %10.6f", err));
238 } else {
239 logger.info(String.format(" d2E/dL2 failed: %10.6f", err));
240 }
241 logger.info(String.format(" Numeric: %15.8f", d2EdL2FD));
242 logger.info(String.format(" Analytic: %15.8f", d2EdL2));
243
244 boolean passed = true;
245
246 for (int i = 0; i < nAtoms; i++) {
247 int ii = i * 3;
248 double dX = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
249 double dXa = lambdaGrad[ii];
250 double eX = dX - dXa;
251 ii++;
252 double dY = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
253 double dYa = lambdaGrad[ii];
254 double eY = dY - dYa;
255 ii++;
256 double dZ = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
257 double dZa = lambdaGrad[ii];
258 double eZ = dZ - dZa;
259
260 double error = Math.sqrt(eX * eX + eY * eY + eZ * eZ);
261 if (error < errTol) {
262 logger.fine(String.format(" dE/dX/dL for Atom %d passed: %10.6f", i + 1, error));
263 } else {
264 logger.info(String.format(" dE/dX/dL for Atom %d failed: %10.6f", i + 1, error));
265 logger.info(String.format(" Analytic: (%15.8f, %15.8f, %15.8f)", dXa, dYa, dZa));
266 logger.info(String.format(" Numeric: (%15.8f, %15.8f, %15.8f)", dX, dY, dZ));
267 passed = false;
268 }
269 }
270 if (passed) {
271 logger.info(String.format(" dE/dX/dL passed for all atoms"));
272 }
273
274 logger.info("");
275 }
276
277 boolean loopPrint = gradientOptions.getVerbose();
278 lambdaInterface.setLambda(initialLambda);
279 potential.getCoordinates(x);
280 potential.energyAndGradient(x, gradient, loopPrint);
281
282 logger.info(String.format(" Checking Cartesian coordinate gradient"));
283
284 double[] numeric = new double[3];
285 double avLen = 0.0;
286 int nFailures = 0;
287 double avGrad = 0.0;
288
289
290
291 List<Integer> atomsToTest;
292 if (gradientOptions.getGradientAtoms().equalsIgnoreCase("NONE")) {
293 logger.info(" The gradient of no atoms will be evaluated.");
294 return this;
295 } else if (gradientOptions.getGradientAtoms().equalsIgnoreCase("ALL")) {
296 logger.info(" Checking gradient for all active atoms.\n");
297 atomsToTest = new ArrayList<>();
298 IntStream.range(0, nAtoms).forEach(val -> atomsToTest.add(val));
299 } else {
300 atomsToTest = parseAtomRanges(" Gradient atoms", gradientOptions.getGradientAtoms(), nAtoms);
301 logger.info(
302 " Checking gradient for active atoms in the range: " + gradientOptions.getGradientAtoms() + "\n");
303 }
304
305 for (int i : atomsToTest) {
306 int i3 = i * 3;
307 int i0 = i3 + 0;
308 int i1 = i3 + 1;
309 int i2 = i3 + 2;
310
311
312 double orig = x[i0];
313 x[i0] = x[i0] + step;
314 double e = potential.energyAndGradient(x, lambdaGradFD[0], loopPrint);
315 x[i0] = orig - step;
316 e -= potential.energyAndGradient(x, lambdaGradFD[1], loopPrint);
317 x[i0] = orig;
318 numeric[0] = e / width;
319
320
321 orig = x[i1];
322 x[i1] = x[i1] + step;
323 e = potential.energyAndGradient(x, lambdaGradFD[0], loopPrint);
324 x[i1] = orig - step;
325 e -= potential.energyAndGradient(x, lambdaGradFD[1], loopPrint);
326 x[i1] = orig;
327 numeric[1] = e / width;
328
329
330 orig = x[i2];
331 x[i2] = x[i2] + step;
332 e = potential.energyAndGradient(x, lambdaGradFD[0], loopPrint);
333 x[i2] = orig - step;
334 e -= potential.energyAndGradient(x, lambdaGradFD[1], loopPrint);
335 x[i2] = orig;
336 numeric[2] = e / width;
337
338 double dx = gradient[i0] - numeric[0];
339 double dy = gradient[i1] - numeric[1];
340 double dz = gradient[i2] - numeric[2];
341 double len = dx * dx + dy * dy + dz * dz;
342 avLen += len;
343 len = Math.sqrt(len);
344
345 double grad2 =
346 gradient[i0] * gradient[i0] + gradient[i1] * gradient[i1] + gradient[i2] * gradient[i2];
347 avGrad += grad2;
348 grad2 = Math.sqrt(grad2);
349
350 if (len > errTol) {
351 logger.info(String.format(" Atom %d failed: %10.6f.", i + 1, len)
352 + String.format("\n Analytic: (%12.4f, %12.4f, %12.4f)\n", gradient[i0], gradient[i1],
353 gradient[i2])
354 + String.format(" Numeric: (%12.4f, %12.4f, %12.4f)\n", numeric[0], numeric[1],
355 numeric[2]));
356 ++nFailures;
357
358 } else {
359 logger.info(String.format(" Atom %d passed: %10.6f.", i + 1, len)
360 + String.format("\n Analytic: (%12.4f, %12.4f, %12.4f)\n", gradient[i0], gradient[i1],
361 gradient[i2])
362 + String.format(" Numeric: (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1],
363 numeric[2]));
364 }
365
366 if (grad2 > expGrad) {
367 logger.info(String.format(" Atom %d has an unusually large gradient: %10.6f", i + 1, grad2));
368 }
369 logger.info("\n");
370 }
371
372 avLen = avLen / nAtoms;
373 avLen = Math.sqrt(avLen);
374 if (avLen > errTol) {
375 logger.info(String.
376 format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
377 } else {
378 logger.info(String.
379 format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
380 }
381 logger.info(String.format(" Number of atoms failing gradient test: %d", nFailures));
382
383 avGrad = avGrad / nAtoms;
384 avGrad = Math.sqrt(avGrad);
385 if (avGrad > expGrad) {
386 logger.info(String.format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
387 } else {
388 logger.info(String.format(" RMS gradient: %10.6f", avGrad));
389 }
390
391 return this;
392 }
393
394 @Override
395 public List<Potential> getPotentials() {
396 return potential == null ? Collections.emptyList() : Collections.singletonList(potential);
397 }
398 }