View Javadoc
1   //******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  //******************************************************************************
38  package ffx.xray.commands.test;
39  
40  import ffx.algorithms.cli.AlgorithmsCommand;
41  import ffx.numerics.Potential;
42  import ffx.potential.MolecularAssembly;
43  import ffx.potential.bonded.Atom;
44  import ffx.potential.bonded.LambdaInterface;
45  import ffx.potential.cli.AlchemicalOptions;
46  import ffx.potential.cli.GradientOptions;
47  import ffx.utilities.FFXBinding;
48  import ffx.xray.DiffractionData;
49  import ffx.xray.RefinementEnergy;
50  import ffx.xray.RefinementMinimize.RefinementMode;
51  import ffx.xray.cli.XrayOptions;
52  import org.apache.commons.configuration2.CompositeConfiguration;
53  import picocli.CommandLine.Command;
54  import picocli.CommandLine.Mixin;
55  import picocli.CommandLine.Parameters;
56  
57  import java.util.ArrayList;
58  import java.util.Collections;
59  import java.util.List;
60  import java.util.stream.IntStream;
61  
62  import static ffx.utilities.StringUtils.parseAtomRanges;
63  import static java.lang.String.format;
64  
65  /**
66   * The X-ray test Lambda Gradient script.
67   * <br>
68   * Usage:
69   * <br>
70   * ffxc xray.test.LambdaGradient [options] &lt;filename&gt;
71   */
72  @Command(description = " Test Lambda Derivatives on an X-ray target.", name = "xray.test.LambdaGradient")
73  public class LambdaGradient extends AlgorithmsCommand {
74  
75    @Mixin
76    private XrayOptions xrayOptions;
77  
78    @Mixin
79    private AlchemicalOptions alchemicalOptions;
80  
81    @Mixin
82    private GradientOptions gradientOptions;
83  
84    /**
85     * One or more filenames.
86     */
87    @Parameters(arity = "1..*", paramLabel = "files", description = "PDB and Real Space input files.")
88    private List<String> filenames;
89    
90    private RefinementEnergy refinementEnergy;
91  
92    /**
93     * LambdaGradient constructor.
94     */
95    public LambdaGradient() {
96      super();
97    }
98  
99    /**
100    * LambdaGradient constructor that sets the command line arguments.
101    *
102    * @param args Command line arguments.
103    */
104   public LambdaGradient(String[] args) {
105     super(args);
106   }
107 
108   /**
109    * LambdaGradient constructor.
110    *
111    * @param binding The Binding to use.
112    */
113   public LambdaGradient(FFXBinding binding) {
114     super(binding);
115   }
116 
117   /**
118    * {@inheritDoc}
119    */
120   @Override
121   public LambdaGradient run() {
122 
123     if (!init()) {
124       return this;
125     }
126 
127     xrayOptions.init();
128 
129     // Turn on computation of lambda derivatives
130     System.setProperty("lambdaterm", "true");
131 
132     String filename;
133     MolecularAssembly[] molecularAssemblies;
134     if (filenames != null && !filenames.isEmpty()) {
135       molecularAssemblies = algorithmFunctions.openAll(filenames.get(0));
136       activeAssembly = molecularAssemblies[0];
137       filename = filenames.get(0);
138     } else if (activeAssembly == null) {
139       logger.info(helpString());
140       return this;
141     } else {
142       filename = activeAssembly.getFile().getAbsolutePath();
143       molecularAssemblies = new MolecularAssembly[]{activeAssembly};
144     }
145 
146     alchemicalOptions.setFirstSystemAlchemistry(activeAssembly);
147     alchemicalOptions.setFirstSystemUnchargedAtoms(activeAssembly);
148 
149     logger.info("\n Testing X-ray lambda derivatives for " + filename);
150 
151     // Load parsed X-ray properties.
152     CompositeConfiguration properties = molecularAssemblies[0].getProperties();
153     xrayOptions.setProperties(parseResult, properties);
154 
155     // Set up diffraction data (can be multiple files)
156     DiffractionData diffractionData = xrayOptions.getDiffractionData(filenames, molecularAssemblies, properties);
157     refinementEnergy = xrayOptions.toXrayEnergy(diffractionData);
158 
159     Potential potential = refinementEnergy;
160     LambdaInterface lambdaInterface = refinementEnergy;
161 
162     // Finite-difference step size in Angstroms.
163     double step = gradientOptions.getDx();
164 
165     int n = refinementEnergy.getNumberOfVariables();
166     Atom[] atoms = refinementEnergy.getActiveAtoms();
167     int nAtoms = atoms.length;
168     double[] x = new double[n];
169     double[] gradient = new double[n];
170 
171     // Finite-difference step size.
172     double width = 2.0 * step;
173     // Error tolerence
174     double errTol = 1.0e-3;
175     // Upper bound for typical gradient sizes (expected gradient)
176     double expGrad = 1000.0;
177 
178     double[] lambdaGrad = new double[n];
179     double[][] lambdaGradFD = new double[2][n];
180 
181     double initialLambda = alchemicalOptions.getInitialLambda();
182 
183     // Compute the Lambda = 0.0 energy.
184     double lambda = 0.0;
185     lambdaInterface.setLambda(lambda);
186     potential.getCoordinates(x);
187     double e0 = potential.energy(x, true);
188 
189     // Compute the Lambda = 1.0 energy.
190     lambda = 1.0;
191     lambdaInterface.setLambda(lambda);
192     double e1 = potential.energy(x, true);
193 
194     logger.info(format(" E(0):      %20.8f.", e0));
195     logger.info(format(" E(1):      %20.8f.", e1));
196     logger.info(format(" E(1)-E(0): %20.8f.\n", e1 - e0));
197 
198     // Test Lambda gradient in the neighborhood of the lambda variable.
199     for (int j = 0; j < 3; j++) {
200       lambda = initialLambda - 0.01 + 0.01 * j;
201 
202       if (lambda - step < 0.0) {
203         continue;
204       }
205       if (lambda + step > 1.0) {
206         continue;
207       }
208 
209       logger.info(format(" Current lambda value %6.4f", lambda));
210       lambdaInterface.setLambda(lambda);
211 
212       // Calculate the energy, dE/dX, dE/dL, d2E/dL2 and dE/dL/dX
213       double e = potential.energyAndGradient(x, gradient);
214 
215       // Analytic dEdL, d2E/dL2 and dE/dL/dX
216       double dEdL = lambdaInterface.getdEdL();
217       double d2EdL2 = lambdaInterface.getd2EdL2();
218       for (int i = 0; i < n; i++) {
219         lambdaGrad[i] = 0.0;
220       }
221       lambdaInterface.getdEdXdL(lambdaGrad);
222 
223       // Calculate the finite-difference dEdLambda, d2EdLambda2 and dEdLambdadX
224       lambdaInterface.setLambda(lambda + step);
225       double lp = potential.energyAndGradient(x, lambdaGradFD[0]);
226       double dedlp = lambdaInterface.getdEdL();
227       lambdaInterface.setLambda(lambda - step);
228       double lm = potential.energyAndGradient(x, lambdaGradFD[1]);
229       double dedlm = lambdaInterface.getdEdL();
230 
231       double dEdLFD = (lp - lm) / width;
232       double d2EdL2FD = (dedlp - dedlm) / width;
233 
234       double err = Math.abs(dEdLFD - dEdL);
235       if (err < errTol) {
236         logger.info(format(" dE/dL passed:   %10.6f", err));
237       } else {
238         logger.info(format(" dE/dL failed: %10.6f", err));
239       }
240       logger.info(format(" Numeric:   %15.8f", dEdLFD));
241       logger.info(format(" Analytic:  %15.8f", dEdL));
242 
243       err = Math.abs(d2EdL2FD - d2EdL2);
244       if (err < errTol) {
245         logger.info(format(" d2E/dL2 passed: %10.6f", err));
246       } else {
247         logger.info(format(" d2E/dL2 failed: %10.6f", err));
248       }
249       logger.info(format(" Numeric:   %15.8f", d2EdL2FD));
250       logger.info(format(" Analytic:  %15.8f", d2EdL2));
251 
252       boolean passed = true;
253 
254       for (int i = 0; i < nAtoms; i++) {
255         int ii = i * 3;
256         double dX = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
257         double dXa = lambdaGrad[ii];
258         double eX = dX - dXa;
259         ii++;
260         double dY = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
261         double dYa = lambdaGrad[ii];
262         double eY = dY - dYa;
263         ii++;
264         double dZ = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
265         double dZa = lambdaGrad[ii];
266         double eZ = dZ - dZa;
267 
268         double error = Math.sqrt(eX * eX + eY * eY + eZ * eZ);
269         if (error < errTol) {
270           logger.fine(format(" dE/dX/dL for Atom %d passed: %10.6f", i + 1, error));
271         } else {
272           logger.info(format(" dE/dX/dL for Atom %d failed: %10.6f", i + 1, error));
273           logger.info(format(" Analytic: (%15.8f, %15.8f, %15.8f)", dXa, dYa, dZa));
274           logger.info(format(" Numeric:  (%15.8f, %15.8f, %15.8f)", dX, dY, dZ));
275           passed = false;
276         }
277       }
278       if (passed) {
279         logger.info(format(" dE/dX/dL passed for all atoms"));
280       }
281 
282       logger.info("");
283     }
284 
285     boolean loopPrint = gradientOptions.getVerbose();
286     refinementEnergy.getCoordinates(x);
287     refinementEnergy.energyAndGradient(x, gradient, loopPrint);
288 
289     double[] numeric = new double[3];
290     double avLen = 0.0;
291     int nFailures = 0;
292     double avGrad = 0.0;
293 
294     // Collect atoms to test.
295     List<Integer> atomsToTest;
296     if (gradientOptions.getGradientAtoms().equalsIgnoreCase("NONE")) {
297       logger.info(" The gradient of no atoms will be evaluated.");
298       return this;
299     } else if (gradientOptions.getGradientAtoms().equalsIgnoreCase("ALL")) {
300       logger.info(" Checking gradient for all active atoms.\n");
301       atomsToTest = new ArrayList<>();
302       IntStream.range(0, nAtoms).forEach(val -> atomsToTest.add(val));
303     } else {
304       atomsToTest = parseAtomRanges(" Gradient atoms", gradientOptions.getGradientAtoms(), nAtoms);
305       logger.info(
306           " Checking gradient for active atoms in the range: " + gradientOptions.getGradientAtoms() +
307               "\n");
308     }
309 
310     for (int i : atomsToTest) {
311       int i3 = i * 3;
312       int i0 = i3 + 0;
313       int i1 = i3 + 1;
314       int i2 = i3 + 2;
315 
316       // Find numeric dX
317       double orig = x[i0];
318       x[i0] = x[i0] + step;
319       double e = refinementEnergy.energy(x, loopPrint);
320       x[i0] = orig - step;
321       e -= refinementEnergy.energy(x, loopPrint);
322       x[i0] = orig;
323       numeric[0] = e / width;
324 
325       // Find numeric dY
326       orig = x[i1];
327       x[i1] = x[i1] + step;
328       e = refinementEnergy.energy(x, loopPrint);
329       x[i1] = orig - step;
330       e -= refinementEnergy.energy(x, loopPrint);
331       x[i1] = orig;
332       numeric[1] = e / width;
333 
334       // Find numeric dZ
335       orig = x[i2];
336       x[i2] = x[i2] + step;
337       e = refinementEnergy.energy(x, loopPrint);
338       x[i2] = orig - step;
339       e -= refinementEnergy.energy(x, loopPrint);
340       x[i2] = orig;
341       numeric[2] = e / width;
342 
343       double dx = gradient[i0] - numeric[0];
344       double dy = gradient[i1] - numeric[1];
345       double dz = gradient[i2] - numeric[2];
346       double len = dx * dx + dy * dy + dz * dz;
347       avLen += len;
348       len = Math.sqrt(len);
349 
350       double grad2 =
351           gradient[i0] * gradient[i0] + gradient[i1] * gradient[i1] + gradient[i2] * gradient[i2];
352       avGrad += grad2;
353       grad2 = Math.sqrt(grad2);
354 
355       if (len > errTol) {
356         logger.info(format(" Atom %d failed: %10.6f.", i + 1, len) +
357             format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
358                 gradient[i2]) +
359             format("\n Numeric:  (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
360         ++nFailures;
361         //return
362       } else {
363         logger.info(format(" Atom %d passed: %10.6f.", i + 1, len) +
364             format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
365                 gradient[i2]) +
366             format("\n Numeric:  (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
367       }
368 
369       if (grad2 > expGrad) {
370         logger.info(format(" Atom %d has an unusually large gradient: %10.6f", i + 1, grad2));
371       }
372       logger.info("\n");
373     }
374 
375     avLen = avLen / nAtoms;
376     avLen = Math.sqrt(avLen);
377     if (avLen > errTol) {
378       logger.info(
379           format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
380     } else {
381       logger.info(
382           format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
383     }
384     logger.info(format(" Number of atoms failing gradient test: %d", nFailures));
385 
386     avGrad = avGrad / nAtoms;
387     avGrad = Math.sqrt(avGrad);
388     if (avGrad > expGrad) {
389       logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
390     } else {
391       logger.info(format(" RMS gradient: %10.6f", avGrad));
392     }
393 
394     refinementEnergy = new RefinementEnergy(diffractionData, RefinementMode.BFACTORS);
395     n = refinementEnergy.getNumberOfVariables();
396     gradient = new double[n];
397     x = new double[n];
398 
399     refinementEnergy.getCoordinates(x);
400     refinementEnergy.energyAndGradient(x, gradient);
401 
402     avLen = 0.0;
403     nFailures = 0;
404     avGrad = 0.0;
405     width = 2.0 * step;
406     errTol = 1.0e-3;
407     expGrad = 1000.0;
408 
409     for (int i = 0; i < n; i++) {
410 
411       // Find numeric dB
412       double orig = x[i];
413       x[i] = x[i] + step;
414       double e = refinementEnergy.energy(x);
415       x[i] = orig - step;
416       e -= refinementEnergy.energy(x);
417       x[i] = orig;
418       double fd = e / width;
419 
420       double dB = gradient[i] - fd;
421       double len = dB * dB;
422       avLen += len;
423       len = Math.sqrt(len);
424 
425       double grad2 = dB * dB;
426       avGrad += grad2;
427       grad2 = Math.sqrt(grad2);
428 
429       if (len > errTol) {
430         logger.info(format(" B-Factor %d failed: %10.6f.", i + 1, len) +
431             format("\n Analytic: %12.4f", gradient[i]) +
432             format("\n Numeric:  %12.4f", fd));
433         ++nFailures;
434         //return
435       } else {
436         logger.info(format(" B-Factor %d passed: %10.6f.", i + 1, len) +
437             format("\n Analytic: %12.4f", gradient[i]) +
438             format("\n Numeric:  %12.4f", fd));
439       }
440 
441       if (grad2 > expGrad) {
442         logger.info(format(" B-Factor %d has an unusually large gradient: %10.6f", i + 1, grad2));
443       }
444       logger.info("\n");
445     }
446 
447     avLen = avLen / n;
448     avLen = Math.sqrt(avLen);
449     if (avLen > errTol) {
450       logger.info(
451           format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
452     } else {
453       logger.info(
454           format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
455     }
456     logger.info(format(" Number of B-Factors failing gradient test: %d", nFailures));
457 
458     avGrad = avGrad / n;
459     avGrad = Math.sqrt(avGrad);
460     if (avGrad > expGrad) {
461       logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
462     } else {
463       logger.info(format(" RMS gradient: %10.6f", avGrad));
464     }
465 
466     return this;
467   }
468 
469   /**
470    * {@inheritDoc}
471    */
472   @Override
473   public List<Potential> getPotentials() {
474     return refinementEnergy == null ? Collections.emptyList() :
475         Collections.singletonList(refinementEnergy);
476   }
477 }