View Javadoc
1   //******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  //******************************************************************************
38  package ffx.xray.commands.test;
39  
40  import ffx.algorithms.cli.AlgorithmsCommand;
41  import ffx.numerics.Potential;
42  import ffx.potential.MolecularAssembly;
43  import ffx.potential.bonded.Atom;
44  import ffx.potential.bonded.LambdaInterface;
45  import ffx.potential.cli.AlchemicalOptions;
46  import ffx.potential.cli.GradientOptions;
47  import ffx.utilities.FFXBinding;
48  import ffx.xray.DiffractionData;
49  import ffx.xray.RefinementEnergy;
50  import ffx.xray.refine.RefinementMode;
51  import ffx.xray.refine.RefinementModel;
52  import ffx.xray.cli.XrayOptions;
53  import org.apache.commons.configuration2.CompositeConfiguration;
54  import picocli.CommandLine.Command;
55  import picocli.CommandLine.Mixin;
56  import picocli.CommandLine.Parameters;
57  
58  import java.util.ArrayList;
59  import java.util.Collections;
60  import java.util.List;
61  import java.util.stream.IntStream;
62  
63  import static ffx.utilities.StringUtils.parseAtomRanges;
64  import static java.lang.String.format;
65  
66  /**
67   * The X-ray test Lambda Gradient script.
68   * <br>
69   * Usage:
70   * <br>
71   * ffxc xray.test.LambdaGradient [options] &lt;filename&gt;
72   */
73  @Command(description = " Test Lambda Derivatives on an X-ray target.", name = "xray.test.LambdaGradient")
74  public class LambdaGradient extends AlgorithmsCommand {
75  
76    @Mixin
77    private XrayOptions xrayOptions;
78  
79    @Mixin
80    private AlchemicalOptions alchemicalOptions;
81  
82    @Mixin
83    private GradientOptions gradientOptions;
84  
85    /**
86     * One or more filenames.
87     */
88    @Parameters(arity = "1..*", paramLabel = "files", description = "PDB and Real Space input files.")
89    private List<String> filenames;
90    
91    private RefinementEnergy refinementEnergy;
92  
93    /**
94     * LambdaGradient constructor.
95     */
96    public LambdaGradient() {
97      super();
98    }
99  
100   /**
101    * LambdaGradient constructor that sets the command line arguments.
102    *
103    * @param args Command line arguments.
104    */
105   public LambdaGradient(String[] args) {
106     super(args);
107   }
108 
109   /**
110    * LambdaGradient constructor.
111    *
112    * @param binding The Binding to use.
113    */
114   public LambdaGradient(FFXBinding binding) {
115     super(binding);
116   }
117 
118   /**
119    * {@inheritDoc}
120    */
121   @Override
122   public LambdaGradient run() {
123 
124     if (!init()) {
125       return this;
126     }
127 
128     xrayOptions.init();
129 
130     // Turn on computation of lambda derivatives
131     System.setProperty("lambdaterm", "true");
132 
133     String filename;
134     MolecularAssembly[] molecularAssemblies;
135     if (filenames != null && !filenames.isEmpty()) {
136       molecularAssemblies = algorithmFunctions.openAll(filenames.get(0));
137       activeAssembly = molecularAssemblies[0];
138       filename = filenames.get(0);
139     } else if (activeAssembly == null) {
140       logger.info(helpString());
141       return this;
142     } else {
143       filename = activeAssembly.getFile().getAbsolutePath();
144       molecularAssemblies = new MolecularAssembly[]{activeAssembly};
145     }
146 
147     alchemicalOptions.setFirstSystemAlchemistry(activeAssembly);
148     alchemicalOptions.setFirstSystemUnchargedAtoms(activeAssembly);
149 
150     logger.info("\n Testing X-ray lambda derivatives for " + filename);
151 
152     // Load parsed X-ray properties.
153     CompositeConfiguration properties = molecularAssemblies[0].getProperties();
154     xrayOptions.setProperties(parseResult, properties);
155 
156     // Set up diffraction data (can be multiple files)
157     DiffractionData diffractionData = xrayOptions.getDiffractionData(filenames, molecularAssemblies, properties);
158     refinementEnergy = xrayOptions.toXrayEnergy(diffractionData);
159 
160     Potential potential = refinementEnergy;
161     LambdaInterface lambdaInterface = refinementEnergy;
162 
163     // Finite-difference step size in Angstroms.
164     double step = gradientOptions.getDx();
165 
166     RefinementModel refinementModel = diffractionData.getRefinementModel();
167 
168     int n = refinementModel.getNumParameters();
169     Atom[] atoms = refinementModel.getActiveAtoms();
170     int nAtoms = atoms.length;
171     double[] x = new double[n];
172     double[] gradient = new double[n];
173 
174     // Finite-difference step size.
175     double width = 2.0 * step;
176     // Error tolerence
177     double errTol = 1.0e-3;
178     // Upper bound for typical gradient sizes (expected gradient)
179     double expGrad = 1000.0;
180 
181     double[] lambdaGrad = new double[n];
182     double[][] lambdaGradFD = new double[2][n];
183 
184     double initialLambda = alchemicalOptions.getInitialLambda();
185 
186     // Compute the Lambda = 0.0 energy.
187     double lambda = 0.0;
188     lambdaInterface.setLambda(lambda);
189     potential.getCoordinates(x);
190     double e0 = potential.energy(x, true);
191 
192     // Compute the Lambda = 1.0 energy.
193     lambda = 1.0;
194     lambdaInterface.setLambda(lambda);
195     double e1 = potential.energy(x, true);
196 
197     logger.info(format(" E(0):      %20.8f.", e0));
198     logger.info(format(" E(1):      %20.8f.", e1));
199     logger.info(format(" E(1)-E(0): %20.8f.\n", e1 - e0));
200 
201     // Test Lambda gradient in the neighborhood of the lambda variable.
202     for (int j = 0; j < 3; j++) {
203       lambda = initialLambda - 0.01 + 0.01 * j;
204 
205       if (lambda - step < 0.0) {
206         continue;
207       }
208       if (lambda + step > 1.0) {
209         continue;
210       }
211 
212       logger.info(format(" Current lambda value %6.4f", lambda));
213       lambdaInterface.setLambda(lambda);
214 
215       // Calculate the energy, dE/dX, dE/dL, d2E/dL2 and dE/dL/dX
216       double e = potential.energyAndGradient(x, gradient);
217 
218       // Analytic dEdL, d2E/dL2 and dE/dL/dX
219       double dEdL = lambdaInterface.getdEdL();
220       double d2EdL2 = lambdaInterface.getd2EdL2();
221       for (int i = 0; i < n; i++) {
222         lambdaGrad[i] = 0.0;
223       }
224       lambdaInterface.getdEdXdL(lambdaGrad);
225 
226       // Calculate the finite-difference dEdLambda, d2EdLambda2 and dEdLambdadX
227       lambdaInterface.setLambda(lambda + step);
228       double lp = potential.energyAndGradient(x, lambdaGradFD[0]);
229       double dedlp = lambdaInterface.getdEdL();
230       lambdaInterface.setLambda(lambda - step);
231       double lm = potential.energyAndGradient(x, lambdaGradFD[1]);
232       double dedlm = lambdaInterface.getdEdL();
233 
234       double dEdLFD = (lp - lm) / width;
235       double d2EdL2FD = (dedlp - dedlm) / width;
236 
237       double err = Math.abs(dEdLFD - dEdL);
238       if (err < errTol) {
239         logger.info(format(" dE/dL passed:   %10.6f", err));
240       } else {
241         logger.info(format(" dE/dL failed: %10.6f", err));
242       }
243       logger.info(format(" Numeric:   %15.8f", dEdLFD));
244       logger.info(format(" Analytic:  %15.8f", dEdL));
245 
246       err = Math.abs(d2EdL2FD - d2EdL2);
247       if (err < errTol) {
248         logger.info(format(" d2E/dL2 passed: %10.6f", err));
249       } else {
250         logger.info(format(" d2E/dL2 failed: %10.6f", err));
251       }
252       logger.info(format(" Numeric:   %15.8f", d2EdL2FD));
253       logger.info(format(" Analytic:  %15.8f", d2EdL2));
254 
255       boolean passed = true;
256 
257       for (int i = 0; i < nAtoms; i++) {
258         int ii = i * 3;
259         double dX = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
260         double dXa = lambdaGrad[ii];
261         double eX = dX - dXa;
262         ii++;
263         double dY = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
264         double dYa = lambdaGrad[ii];
265         double eY = dY - dYa;
266         ii++;
267         double dZ = (lambdaGradFD[0][ii] - lambdaGradFD[1][ii]) / width;
268         double dZa = lambdaGrad[ii];
269         double eZ = dZ - dZa;
270 
271         double error = Math.sqrt(eX * eX + eY * eY + eZ * eZ);
272         if (error < errTol) {
273           logger.fine(format(" dE/dX/dL for Atom %d passed: %10.6f", i + 1, error));
274         } else {
275           logger.info(format(" dE/dX/dL for Atom %d failed: %10.6f", i + 1, error));
276           logger.info(format(" Analytic: (%15.8f, %15.8f, %15.8f)", dXa, dYa, dZa));
277           logger.info(format(" Numeric:  (%15.8f, %15.8f, %15.8f)", dX, dY, dZ));
278           passed = false;
279         }
280       }
281       if (passed) {
282         logger.info(format(" dE/dX/dL passed for all atoms"));
283       }
284 
285       logger.info("");
286     }
287 
288     boolean loopPrint = gradientOptions.getVerbose();
289     refinementEnergy.getCoordinates(x);
290     refinementEnergy.energyAndGradient(x, gradient, loopPrint);
291 
292     double[] numeric = new double[3];
293     double avLen = 0.0;
294     int nFailures = 0;
295     double avGrad = 0.0;
296 
297     // Collect atoms to test.
298     List<Integer> atomsToTest;
299     if (gradientOptions.getGradientAtoms().equalsIgnoreCase("NONE")) {
300       logger.info(" The gradient of no atoms will be evaluated.");
301       return this;
302     } else if (gradientOptions.getGradientAtoms().equalsIgnoreCase("ALL")) {
303       logger.info(" Checking gradient for all active atoms.\n");
304       atomsToTest = new ArrayList<>();
305       IntStream.range(0, nAtoms).forEach(val -> atomsToTest.add(val));
306     } else {
307       atomsToTest = parseAtomRanges(" Gradient atoms", gradientOptions.getGradientAtoms(), nAtoms);
308       logger.info(
309           " Checking gradient for active atoms in the range: " + gradientOptions.getGradientAtoms() +
310               "\n");
311     }
312 
313     for (int i : atomsToTest) {
314       int i3 = i * 3;
315       int i0 = i3 + 0;
316       int i1 = i3 + 1;
317       int i2 = i3 + 2;
318 
319       // Find numeric dX
320       double orig = x[i0];
321       x[i0] = x[i0] + step;
322       double e = refinementEnergy.energy(x, loopPrint);
323       x[i0] = orig - step;
324       e -= refinementEnergy.energy(x, loopPrint);
325       x[i0] = orig;
326       numeric[0] = e / width;
327 
328       // Find numeric dY
329       orig = x[i1];
330       x[i1] = x[i1] + step;
331       e = refinementEnergy.energy(x, loopPrint);
332       x[i1] = orig - step;
333       e -= refinementEnergy.energy(x, loopPrint);
334       x[i1] = orig;
335       numeric[1] = e / width;
336 
337       // Find numeric dZ
338       orig = x[i2];
339       x[i2] = x[i2] + step;
340       e = refinementEnergy.energy(x, loopPrint);
341       x[i2] = orig - step;
342       e -= refinementEnergy.energy(x, loopPrint);
343       x[i2] = orig;
344       numeric[2] = e / width;
345 
346       double dx = gradient[i0] - numeric[0];
347       double dy = gradient[i1] - numeric[1];
348       double dz = gradient[i2] - numeric[2];
349       double len = dx * dx + dy * dy + dz * dz;
350       avLen += len;
351       len = Math.sqrt(len);
352 
353       double grad2 =
354           gradient[i0] * gradient[i0] + gradient[i1] * gradient[i1] + gradient[i2] * gradient[i2];
355       avGrad += grad2;
356       grad2 = Math.sqrt(grad2);
357 
358       if (len > errTol) {
359         logger.info(format(" Atom %d failed: %10.6f.", i + 1, len) +
360             format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
361                 gradient[i2]) +
362             format("\n Numeric:  (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
363         ++nFailures;
364         //return
365       } else {
366         logger.info(format(" Atom %d passed: %10.6f.", i + 1, len) +
367             format("\n Analytic: (%12.4f, %12.4f, %12.4f)", gradient[i0], gradient[i1],
368                 gradient[i2]) +
369             format("\n Numeric:  (%12.4f, %12.4f, %12.4f)", numeric[0], numeric[1], numeric[2]));
370       }
371 
372       if (grad2 > expGrad) {
373         logger.info(format(" Atom %d has an unusually large gradient: %10.6f", i + 1, grad2));
374       }
375       logger.info("\n");
376     }
377 
378     avLen = avLen / nAtoms;
379     avLen = Math.sqrt(avLen);
380     if (avLen > errTol) {
381       logger.info(
382           format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
383     } else {
384       logger.info(
385           format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
386     }
387     logger.info(format(" Number of atoms failing gradient test: %d", nFailures));
388 
389     avGrad = avGrad / nAtoms;
390     avGrad = Math.sqrt(avGrad);
391     if (avGrad > expGrad) {
392       logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
393     } else {
394       logger.info(format(" RMS gradient: %10.6f", avGrad));
395     }
396 
397     diffractionData.getRefinementModel().setRefinementMode(RefinementMode.BFACTORS);
398     refinementEnergy = new RefinementEnergy(diffractionData);
399     n = refinementEnergy.getNumberOfVariables();
400     gradient = new double[n];
401     x = new double[n];
402 
403     refinementEnergy.getCoordinates(x);
404     refinementEnergy.energyAndGradient(x, gradient);
405 
406     avLen = 0.0;
407     nFailures = 0;
408     avGrad = 0.0;
409     width = 2.0 * step;
410     errTol = 1.0e-3;
411     expGrad = 1000.0;
412 
413     for (int i = 0; i < n; i++) {
414 
415       // Find numeric dB
416       double orig = x[i];
417       x[i] = x[i] + step;
418       double e = refinementEnergy.energy(x);
419       x[i] = orig - step;
420       e -= refinementEnergy.energy(x);
421       x[i] = orig;
422       double fd = e / width;
423 
424       double dB = gradient[i] - fd;
425       double len = dB * dB;
426       avLen += len;
427       len = Math.sqrt(len);
428 
429       double grad2 = dB * dB;
430       avGrad += grad2;
431       grad2 = Math.sqrt(grad2);
432 
433       if (len > errTol) {
434         logger.info(format(" B-Factor %d failed: %10.6f.", i + 1, len) +
435             format("\n Analytic: %12.4f", gradient[i]) +
436             format("\n Numeric:  %12.4f", fd));
437         ++nFailures;
438         //return
439       } else {
440         logger.info(format(" B-Factor %d passed: %10.6f.", i + 1, len) +
441             format("\n Analytic: %12.4f", gradient[i]) +
442             format("\n Numeric:  %12.4f", fd));
443       }
444 
445       if (grad2 > expGrad) {
446         logger.info(format(" B-Factor %d has an unusually large gradient: %10.6f", i + 1, grad2));
447       }
448       logger.info("\n");
449     }
450 
451     avLen = avLen / n;
452     avLen = Math.sqrt(avLen);
453     if (avLen > errTol) {
454       logger.info(
455           format(" Test failure: RMSD from analytic solution is %10.6f > %10.6f", avLen, errTol));
456     } else {
457       logger.info(
458           format(" Test success: RMSD from analytic solution is %10.6f < %10.6f", avLen, errTol));
459     }
460     logger.info(format(" Number of B-Factors failing gradient test: %d", nFailures));
461 
462     avGrad = avGrad / n;
463     avGrad = Math.sqrt(avGrad);
464     if (avGrad > expGrad) {
465       logger.info(format(" Unusually large RMS gradient: %10.6f > %10.6f", avGrad, expGrad));
466     } else {
467       logger.info(format(" RMS gradient: %10.6f", avGrad));
468     }
469 
470     return this;
471   }
472 
473   /**
474    * {@inheritDoc}
475    */
476   @Override
477   public List<Potential> getPotentials() {
478     return refinementEnergy == null ? Collections.emptyList() :
479         Collections.singletonList(refinementEnergy);
480   }
481 }