1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.realspace.commands.test;
39
40 import ffx.algorithms.cli.AlgorithmsCommand;
41 import ffx.numerics.Potential;
42 import ffx.potential.MolecularAssembly;
43 import ffx.potential.bonded.LambdaInterface;
44 import ffx.potential.cli.AlchemicalOptions;
45 import ffx.potential.cli.GradientOptions;
46 import ffx.realspace.RealSpaceData;
47 import ffx.realspace.cli.RealSpaceOptions;
48 import ffx.realspace.parsers.RealSpaceFile;
49 import ffx.utilities.FFXBinding;
50 import ffx.xray.RefinementEnergy;
51 import picocli.CommandLine.Command;
52 import picocli.CommandLine.Mixin;
53 import picocli.CommandLine.Parameters;
54
55 import java.util.ArrayList;
56 import java.util.Collections;
57 import java.util.List;
58 import java.util.stream.IntStream;
59
60 import static ffx.utilities.StringUtils.parseAtomRanges;
61
62
63
64
65
66
67
68
69 @Command(description = " Test Lambda Derivatives on a Real Space target.", name = "realspace.test.LambdaGradient")
70 public class LambdaGradient extends AlgorithmsCommand {
71
72 @Mixin
73 private RealSpaceOptions realSpaceOptions;
74
75 @Mixin
76 private AlchemicalOptions alchemicalOptions;
77
78 @Mixin
79 private GradientOptions gradientOptions;
80
81
82
83
84 @Parameters(arity = "1..*", paramLabel = "files", description = "PDB and Real Space input files.")
85 private List<String> filenames;
86
87 private Potential potential;
88
89
90
91
92 public LambdaGradient() {
93 super();
94 }
95
96
97
98
99
100 public LambdaGradient(String[] args) {
101 super(args);
102 }
103
104
105
106
107
108 public LambdaGradient(FFXBinding binding) {
109 super(binding);
110 }
111
112 @Override
113 public LambdaGradient run() {
114
115 if (!init()) {
116 return this;
117 }
118
119
120 System.setProperty("lambdaterm", "true");
121
122 String modelfilename;
123 MolecularAssembly[] assemblies;
124 if (filenames != null && filenames.size() > 0) {
125 assemblies = algorithmFunctions.openAll(filenames.get(0));
126 activeAssembly = assemblies[0];
127 modelfilename = filenames.get(0);
128 } else if (activeAssembly == null) {
129 logger.info(helpString());
130 return this;
131 } else {
132 modelfilename = activeAssembly.getFile().getAbsolutePath();
133 assemblies = new MolecularAssembly[]{activeAssembly};
134 }
135
136 alchemicalOptions.setFirstSystemAlchemistry(activeAssembly);
137 alchemicalOptions.setFirstSystemUnchargedAtoms(activeAssembly);
138
139 logger.info("\n Testing lambda derivatives for " + modelfilename);
140
141 List<RealSpaceFile> mapfiles = realSpaceOptions.processData(filenames, activeAssembly);
142 RealSpaceFile[] mapFileArray = mapfiles.toArray(new RealSpaceFile[0]);
143
144 RealSpaceData realspacedata = new RealSpaceData(activeAssembly, activeAssembly.getProperties(),
145 activeAssembly.getParallelTeam(), mapFileArray);
146 RefinementEnergy refinementEnergy = new RefinementEnergy(realspacedata);
147 potential = refinementEnergy;
148 LambdaInterface lambdaInterface = refinementEnergy;
149
150
151 int n = potential.getNumberOfVariables();
152 double[] x = new double[n];
153 double[] gradient = new double[n];
154 double[] lambdaGrad = new double[n];
155 double[][] lambdaGradFD = new double[2][n];
156
157
158 assert (n % 3 == 0);
159 int nAtoms = n / 3;
160
161
162 double lambda = 0.0;
163 lambdaInterface.setLambda(lambda);
164 potential.getCoordinates(x);
165 double e0 = potential.energy(x, true);
166
167
168 lambda = 1.0;
169 lambdaInterface.setLambda(lambda);
170 double e1 = potential.energy(x, true);
171
172 logger.info(String.format(" E(0): %20.8f.", e0));
173 logger.info(String.format(" E(1): %20.8f.", e1));
174 logger.info(String.format(" E(1)-E(0): %20.8f.\n", e1 - e0));
175
176
177 double step = gradientOptions.getDx();
178 double width = 2.0 * step;
179
180
181 double errTol = 1.0e-3;
182
183 double expGrad = 1000.0;
184
185 double initialLambda = alchemicalOptions.getInitialLambda();
186
187
188 for (int j = 0; j < 3; j++) {
189 lambda = initialLambda - 0.01 + 0.01 * j;
190
191 if (lambda - step < 0.0) {
192 continue;
193 }
194 if (lambda + step > 1.0) {
195 continue;
196 }
197
198 logger.info(String.format(" Current lambda value %6.4f", lambda));
199 lambdaInterface.setLambda(lambda);
200
201
202 double e = potential.energyAndGradient(x, gradient);
203
204
205 double dEdL = lambdaInterface.getdEdL();
206 double d2EdL2 = lambdaInterface.getd2EdL2();
207 for (int i = 0; i < n; i++) {
208 lambdaGrad[i] = 0.0;
209 }
210
211 lambdaInterface.getdEdXdL(lambdaGrad);
212
213
214 lambdaInterface.setLambda(lambda + step);
215 double lp = potential.energyAndGradient(x, lambdaGradFD[0]);
216 double dedlp = lambdaInterface.getdEdL();
217 lambdaInterface.setLambda(lambda - step);
218 double lm = potential.energyAndGradient(x, lambdaGradFD[1]);
219 double dedlm = lambdaInterface.getdEdL();
220
221 double dEdLFD = (lp - lm) / width;
222 double d2EdL2FD = (dedlp - dedlm) / width;
223
224 double err = Math.abs(dEdLFD - dEdL);
225 if (err < errTol) {
226 logger.info(String.format(" dE/dL passed: %10.6f", err));
227 } else {
228 logger.info(String.format(" dE/dL failed: %10.6f", err));
229 }
230 logger.info(String.format(" Numeric: %15.8f", dEdLFD));
231 logger.info(String.format(" Analytic: %15.8f", dEdL));
232
233 err = Math.abs(d2EdL2FD - d2EdL2);
234 if (err < errTol) {
235 logger.info(String.format(" d2E/dL2 passed: %10.6f", err));
236 } else {
237 logger.info(String.format(" d2E/dL2 failed: %10.6f", err));
238 }
239 logger.info(String.format(" Numeric: %15.8f", d2EdL2FD));
240 logger.info(String.format(" Analytic: %15.8f", d2EdL2));
241
242 boolean passed = true;
243
244 for (int i = 0; i < nAtoms; i++) {
245 int ii = i * 3;
246 double dX = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
247 double dXa = lambdaGrad[ii];
248 double eX = dX - dXa;
249 ii++;
250 double dY = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
251 double dYa = lambdaGrad[ii];
252 double eY = dY - dYa;
253 ii++;
254 double dZ = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
255 double dZa = lambdaGrad[ii];
256 double eZ = dZ - dZa;
257
258 double error = Math.sqrt(eX * eX + eY * eY + eZ * eZ);
259 if (error < errTol) {
260 logger.fine(String.format(" dE/dX/dL for Atom %d passed: %10.6f", i + 1, error));
261 } else {
262 logger.info(String.format(" dE/dX/dL for Atom %d failed: %10.6f", i + 1, error));
263 logger.info(String.format(" Analytic: (%15.8f, %15.8f, %15.8f)", dXa, dYa, dZa));
264 logger.info(String.format(" Numeric: (%15.8f, %15.8f, %15.8f)", dX, dY, dZ));
265 passed = false;
266 }
267 }
268 if (passed) {
269 logger.info(String.format(" dE/dX/dL passed for all atoms"));
270 }
271
272 logger.info("");
273 }
274
275 boolean loopPrint = gradientOptions.getVerbose();
276 lambdaInterface.setLambda(initialLambda);
277 potential.getCoordinates(x);
278 potential.energyAndGradient(x, gradient, loopPrint);
279
280 logger.info(String.format(" Checking Cartesian coordinate gradient"));
281
282 double[] numeric = new double[3];
283 double avLen = 0.0;
284 int nFailures = 0;
285 double avGrad = 0.0;
286
287
288
289 List<Integer> atomsToTest;
290 if (gradientOptions.getGradientAtoms().equalsIgnoreCase("NONE")) {
291 logger.info(" The gradient of no atoms will be evaluated.");
292 return this;
293 } else if (gradientOptions.getGradientAtoms().equalsIgnoreCase("ALL")) {
294 logger.info(" Checking gradient for all active atoms.\n");
295 atomsToTest = new ArrayList<>();
296 IntStream.range(0, nAtoms).forEach(val -> atomsToTest.add(val));
297 } else {
298 atomsToTest = parseAtomRanges(" Gradient atoms", gradientOptions.getGradientAtoms(), nAtoms);
299 logger.info(
300 " Checking gradient for active atoms in the range: " + gradientOptions.getGradientAtoms() + "\n");
301 }
302
303 for (int i : atomsToTest) {
304 int i3 = i * 3;
305 int i0 = i3 + 0;
306 int i1 = i3 + 1;
307 int i2 = i3 + 2;
308
309
310 double orig = x[i0];
311 x[i0] = x[i0] + step;
312 double e = potential.energyAndGradient(x, lambdaGradFD[0], loopPrint);
313 x[i0] = orig - step;
314 e -= potential.energyAndGradient(x, lambdaGradFD[1], loopPrint);
315 x[i0] = orig;
316 numeric[0] = e / width;
317
318
319 orig = x[i1];
320 x[i1] = x[i1] + step;
321 e = potential.energyAndGradient(x, lambdaGradFD[0], loopPrint);
322 x[i1] = orig - step;
323 e -= potential.energyAndGradient(x, lambdaGradFD[1], loopPrint);
324 x[i1] = orig;
325 numeric[1] = e / width;
326
327
328 orig = x[i2];
329 x[i2] = x[i2] + step;
330 e = potential.energyAndGradient(x, lambdaGradFD[0], loopPrint);
331 x[i2] = orig - step;
332 e -= potential.energyAndGradient(x, lambdaGradFD[1], loopPrint);
333 x[i2] = orig;
334 numeric[2] = e / width;
335
336 double dx = gradient[i0] - numeric[0];
337 double dy = gradient[i1] - numeric[1];
338 double dz = gradient[i2] - numeric[2];
339 double len = dx * dx + dy * dy + dz * dz;
340 avLen += len;
341 len = Math.sqrt(len);
342
343 double grad2 =
344 gradient[i0] * gradient[i0] + gradient[i1] * gradient[i1] + gradient[i2] * gradient[i2];
345 avGrad += grad2;
346 grad2 = Math.sqrt(grad2);
347
348 if (len > errTol) {
349 logger.info(String.format(" Atom %d failed: %10.6f.", i + 1, len)
350 + String.format("\n Analytic: (%12.4f, %12.4f, %12.4f)\n", gradient[i0], gradient[i1],
351 gradient[i2])
352 + String.format(" Numeric: (%12.4f, %12.4f, %12.4f)\n", numeric[0], numeric[1],
353 numeric[2]));
354 ++nFailures;
355
356 } else {
357 logger.info(String.format(" Atom %d passed: %10.6f.", i + 1, len)
358 + String.format("\n Analytic: (%12.4f, %12.4f, %12.4f)\n", gradient[i0], gradient[i1],
359 gradient[i2])
360 + String.format(" Numeric: (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1],
361 numeric[2]));
362 }
363
364 if (grad2 > expGrad) {
365 logger.info(String.format(" Atom %d has an unusually large gradient: %10.6f", i + 1, grad2));
366 }
367 logger.info("\n");
368 }
369
370 avLen = avLen / nAtoms;
371 avLen = Math.sqrt(avLen);
372 if (avLen > errTol) {
373 logger.info(String.
374 format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
375 } else {
376 logger.info(String.
377 format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
378 }
379 logger.info(String.format(" Number of atoms failing gradient test: %d", nFailures));
380
381 avGrad = avGrad / nAtoms;
382 avGrad = Math.sqrt(avGrad);
383 if (avGrad > expGrad) {
384 logger.info(String.format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
385 } else {
386 logger.info(String.format(" RMS gradient: %10.6f", avGrad));
387 }
388
389 return this;
390 }
391
392 @Override
393 public List<Potential> getPotentials() {
394 return potential == null ? Collections.emptyList() : Collections.singletonList(potential);
395 }
396 }