View Javadoc
1   // ******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2024.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  // ******************************************************************************
38  package ffx.xray;
39  
40  import static ffx.numerics.math.DoubleMath.dot;
41  import static ffx.numerics.math.MatrixMath.mat3Mat3;
42  import static ffx.numerics.math.MatrixMath.mat3SymVec6;
43  import static ffx.numerics.math.MatrixMath.transpose3;
44  import static ffx.numerics.math.MatrixMath.vec3Mat3;
45  import static ffx.numerics.special.ModifiedBessel.i1OverI0;
46  import static ffx.numerics.special.ModifiedBessel.lnI0;
47  import static java.lang.Double.isNaN;
48  import static java.lang.String.format;
49  import static java.lang.System.arraycopy;
50  import static java.util.Arrays.fill;
51  import static org.apache.commons.math3.util.FastMath.PI;
52  import static org.apache.commons.math3.util.FastMath.abs;
53  import static org.apache.commons.math3.util.FastMath.atan;
54  import static org.apache.commons.math3.util.FastMath.cos;
55  import static org.apache.commons.math3.util.FastMath.cosh;
56  import static org.apache.commons.math3.util.FastMath.exp;
57  import static org.apache.commons.math3.util.FastMath.log;
58  import static org.apache.commons.math3.util.FastMath.sin;
59  import static org.apache.commons.math3.util.FastMath.sqrt;
60  import static org.apache.commons.math3.util.FastMath.tanh;
61  
62  import edu.rit.pj.IntegerForLoop;
63  import edu.rit.pj.ParallelRegion;
64  import edu.rit.pj.ParallelTeam;
65  import edu.rit.pj.reduction.SharedDouble;
66  import edu.rit.pj.reduction.SharedDoubleArray;
67  import edu.rit.pj.reduction.SharedInteger;
68  import ffx.crystal.Crystal;
69  import ffx.crystal.HKL;
70  import ffx.crystal.ReflectionList;
71  import ffx.crystal.ReflectionSpline;
72  import ffx.numerics.OptimizationInterface;
73  import ffx.numerics.math.ComplexNumber;
74  import ffx.xray.CrystalReciprocalSpace.SolventModel;
75  
76  import java.util.logging.Logger;
77  
78  /**
79   * Optimize SigmaA coefficients (using spline coefficients) and structure factor derivatives using a
80   * likelihood target function.
81   *
82   * <p>This target can also be used for structure refinement.
83   *
84   * @author Timothy D. Fenn<br>
85   * @see <a href="http://dx.doi.org/10.1107/S0021889804031474" target="_blank"> K. Cowtan, J. Appl.
86   * Cryst. (2005). 38, 193-198</a>
87   * @see <a href="http://dx.doi.org/10.1107/S0907444992007352" target="_blank"> A. T. Brunger, Acta
88   * Cryst. (1993). D49, 24-36</a>
89   * @see <a href="http://dx.doi.org/10.1107/S0907444996012255" target="_blank"> G. N. Murshudov, A.
90   * A. Vagin and E. J. Dodson, Acta Cryst. (1997). D53, 240-255</a>
91   * @see <a href="http://dx.doi.org/10.1107/S0108767388009183" target="_blank"> A. T. Brunger, Acta
92   * Cryst. (1989). A45, 42-50.</a>
93   * @see <a href="http://dx.doi.org/10.1107/S0108767386099622" target="_blank"> R. J. Read, Acta
94   * Cryst. (1986). A42, 140-149.</a>
95   * @see <a href="http://dx.doi.org/10.1107/S0108767396004370" target="_blank"> N. S. Pannu and R. J.
96   * Read, Acta Cryst. (1996). A52, 659-668.</a>
97   * @since 1.0
98   */
99  public class SigmaAEnergy implements OptimizationInterface {
100 
private static final Logger logger = Logger.getLogger(SigmaAEnergy.class.getName());
// 2 * PI^2; used in the exponent exp(-twoPI2 * solventUEq * s) of the bulk solvent term.
private static final double twoPI2 = 2.0 * PI * PI;
// Coefficients of the rational function evaluated by sim(double).
private static final double sim_a = 1.639294;
private static final double sim_b = 3.553967;
private static final double sim_c = 2.228716;
private static final double sim_d = 3.524142;
private static final double sim_e = 7.107935;
// Coefficients of the closed-form expression evaluated by sim_integ(double).
private static final double sim_A = -1.28173889903;
private static final double sim_B = 0.69231689903;
private static final double sim_g = 2.13643992379;
private static final double sim_p = 0.04613803811;
private static final double sim_q = 1.82167089029;
private static final double sim_r = -0.74817947490;

// Collaborators handed to the constructor.
private final ReflectionList reflectionList;
private final DiffractionRefinementData refinementData;
private final ParallelTeam parallelTeam;
// Taken from reflectionList.crystal.
private final Crystal crystal;
// The arrays below are references to the arrays of the same name in refinementData (not copies),
// so writes made by the target are visible through refinementData. They are indexed by
// HKL.getIndex(), and the code reads/writes elements [0] and [1] of each row.
private final double[][] fSigF;
private final double[][] fcTot;
private final double[][] fomPhi;
private final double[][] foFc1;
private final double[][] foFc2;
private final double[][] dFc;
private final double[][] dFs;
// Number of spline bins (refinementData.nBins); the parameter vector holds 2 * nBins values.
private final int nBins;

/**
 * Crystal Volume^2 / (2 * number of grid points)
 */
private final double dfScale;
/**
 * Transpose of the matrix 'A' that converts from Cartesian to fractional coordinates.
 */
private final double[][] transposeA;

// Per-bin s and w spline values, refilled from the parameter vector at the start of every
// evaluation: sa[i] = 1.0 + x[i], wa[i] = x[nBins + i] (non-positive w replaced by 1.0e-6).
private final double[] sa;
private final double[] wa;

// Parallel region that evaluates the target over all reflections.
private final SigmaARegion sigmaARegion;
// When true, acentric reflections use lnI0/i1OverI0; otherwise sim_integ/sim. Set to true in the
// constructor.
private final boolean useCernBessel;
// Optional scaling array; only accepted by setScaling when its length is 2 * nBins.
private double[] optimizationScaling = null;
// Value returned by the most recent call to target().
private double totalEnergy;
144 
145   /**
146    * Constructor for SigmaAEnergy.
147    *
148    * @param reflectionList a {@link ffx.crystal.ReflectionList} object.
149    * @param refinementData a {@link ffx.xray.DiffractionRefinementData} object.
150    * @param parallelTeam   the ParallelTeam to execute the SigmaAEnergy.
151    */
152   SigmaAEnergy(
153       ReflectionList reflectionList,
154       DiffractionRefinementData refinementData,
155       ParallelTeam parallelTeam) {
156     this.reflectionList = reflectionList;
157     this.refinementData = refinementData;
158     this.parallelTeam = parallelTeam;
159     this.crystal = reflectionList.crystal;
160     this.fSigF = refinementData.fSigF;
161     this.fcTot = refinementData.fcTot;
162     this.fomPhi = refinementData.fomPhi;
163     this.foFc1 = refinementData.foFc1;
164     this.foFc2 = refinementData.foFc2;
165     this.dFc = refinementData.dFc;
166     this.dFs = refinementData.dFs;
167     this.nBins = refinementData.nBins;
168 
169     // Initialize parameters.
170     assert (refinementData.crystalReciprocalSpaceFc != null);
171     double nGrid2 = 2.0
172         * refinementData.crystalReciprocalSpaceFc.getXDim()
173         * refinementData.crystalReciprocalSpaceFc.getYDim()
174         * refinementData.crystalReciprocalSpaceFc.getZDim();
175     dfScale = (crystal.volume * crystal.volume) / nGrid2;
176     transposeA = transpose3(crystal.A);
177     sa = new double[nBins];
178     wa = new double[nBins];
179 
180     sigmaARegion = new SigmaARegion(this.parallelTeam.getThreadCount());
181     useCernBessel = true;
182   }
183 
184   /**
185    * From sim and sim_integ functions in clipper utils:
186    * http://www.ysbl.york.ac.uk/~cowtan/clipper/clipper.html
187    * and from lnI0 and i1OverI0 functions in bessel.h in scitbx module of cctbx:
188    * http://cci.lbl.gov/cctbx_sources/scitbx/math/bessel.h
189    *
190    * @param x a double.
191    * @return a double.
192    */
193   private static double sim(double x) {
194     if (x >= 0.0) {
195       return (((x + sim_a) * x + sim_b) * x) / (((x + sim_c) * x + sim_d) * x + sim_e);
196     } else {
197       return -(-(-(-x + sim_a) * x + sim_b) * x) / (-(-(-x + sim_c) * x + sim_d) * x + sim_e);
198     }
199   }
200 
201   /**
202    * sim_integ
203    *
204    * @param x0 a double.
205    * @return a double.
206    */
207   private static double sim_integ(double x0) {
208     double x = abs(x0);
209     double z = (x + sim_p) / sim_q;
210     return sim_A * log(x + sim_g) + 0.5 * sim_B * log(z * z + 1.0) + sim_r * atan(z) + x + 1.0;
211   }
212 
213   /**
214    * {@inheritDoc}
215    */
216   @Override
217   public boolean destroy() {
218     // Should be destroyed upstream in DiffractionData.
219     return true;
220   }
221 
222   /**
223    * {@inheritDoc}
224    */
225   @Override
226   public double energy(double[] x) {
227     unscaleCoordinates(x);
228     double sum = target(x, null, false, false);
229     scaleCoordinates(x);
230     return sum;
231   }
232 
233   /**
234    * {@inheritDoc}
235    */
236   @Override
237   public double energyAndGradient(double[] x, double[] g) {
238     unscaleCoordinates(x);
239     double sum = target(x, g, true, false);
240     scaleCoordinatesAndGradient(x, g);
241     return sum;
242   }
243 
244   /**
245    * {@inheritDoc}
246    */
247   @Override
248   public double[] getCoordinates(double[] parameters) {
249     throw new UnsupportedOperationException("Not supported yet.");
250   }
251 
252   /**
253    * {@inheritDoc}
254    */
255   @Override
256   public int getNumberOfVariables() {
257     throw new UnsupportedOperationException("Not supported yet.");
258   }
259 
260   /**
261    * {@inheritDoc}
262    */
263   @Override
264   public double[] getScaling() {
265     return optimizationScaling;
266   }
267 
268   /**
269    * {@inheritDoc}
270    */
271   @Override
272   public void setScaling(double[] scaling) {
273     if (scaling != null && scaling.length == nBins * 2) {
274       optimizationScaling = scaling;
275     } else {
276       optimizationScaling = null;
277     }
278   }
279 
280   /**
281    * {@inheritDoc}
282    */
283   @Override
284   public double getTotalEnergy() {
285     return totalEnergy;
286   }
287 
288   /**
289    * target
290    *
291    * @param x        an array of double.
292    * @param g        an array of double.
293    * @param gradient a boolean.
294    * @param print    a boolean.
295    * @return a double.
296    */
297   public double target(double[] x, double[] g, boolean gradient, boolean print) {
298 
299     try {
300       sigmaARegion.init(x, g, gradient);
301       parallelTeam.execute(sigmaARegion);
302     } catch (Exception e) {
303       logger.info(e.toString());
304     }
305 
306     double sum = sigmaARegion.sum.get();
307     double sumR = sigmaARegion.sumR.get();
308     refinementData.llkR = sumR;
309     refinementData.llkF = sum;
310 
311     if (print) {
312       int nSum = sigmaARegion.nSum.get();
313       int nSumr = sigmaARegion.nSumR.get();
314       StringBuilder sb = new StringBuilder("\n");
315       sb.append(" sigmaA (s and w) fit using only R free reflections\n");
316       sb.append(format("      # HKL: %10d (free set) %10d (working set) %10d (total)\n", nSum, nSumr, nSum + nSumr));
317       sb.append(format("   residual: %10g (free set) %10g (working set) %10g (total)\n", sum, sumR, sum + sumR));
318       sb.append("    x: ");
319       for (double x1 : x) {
320         sb.append(format("%g ", x1));
321       }
322       sb.append("\n    g: ");
323       for (double v : g) {
324         sb.append(format("%g ", v));
325       }
326       sb.append("\n");
327       logger.info(sb.toString());
328     }
329 
330     totalEnergy = sum;
331     return totalEnergy;
332   }
333 
334   private class SigmaARegion extends ParallelRegion {
335 
// Scratch 3x3 matrix: receives the result of mat3SymVec6(crystal.A, model_b, resm) in start().
private final double[][] resm = new double[3][3];
// Copy of the six values of refinementData.modelAnisoB taken in start().
private final double[] model_b = new double[6];
// Result of mat3Mat3(resm, transposeA, ustar); applied to each HKL in the loop.
private final double[][] ustar = new double[3][3];
// Whether the current evaluation also computes the gradient; set by init().
boolean gradient = true;
// Snapshots of refinementData.modelScaleK, bulkSolventK and bulkSolventUeq taken in start().
double modelK;
double solventK;
double solventUEq;
// Parameter vector and gradient array of the current evaluation; set by init().
double[] x;
double[] g;
// Shared reductions: reflection counts and likelihood sums for the free R flagged reflections
// (nSum, sum) and for all other reflections (nSumR, sumR).
SharedInteger nSum;
SharedInteger nSumR;
SharedDouble sum;
SharedDouble sumR;
// Shared gradient accumulator; allocated lazily in start() on the first gradient evaluation.
SharedDoubleArray grad;
// One loop instance per thread index, created on first use in run().
SigmaALoop[] sigmaALoop;
351 
/**
 * Create the region with one (initially empty) loop slot per thread and fresh shared
 * accumulators for the likelihood sums and reflection counts.
 *
 * @param nThreads number of loop slots to reserve.
 */
SigmaARegion(int nThreads) {
  this.sigmaALoop = new SigmaALoop[nThreads];
  // Likelihood sums.
  this.sum = new SharedDouble();
  this.sumR = new SharedDouble();
  // Reflection counts.
  this.nSum = new SharedInteger();
  this.nSumR = new SharedInteger();
}
359 
360     @Override
361     public void finish() {
362       if (gradient) {
363         for (int i = 0; i < g.length; i++) {
364           g[i] = grad.get(i);
365         }
366       }
367     }
368 
369     public void init(double[] x, double[] g, boolean gradient) {
370       this.x = x;
371       this.g = g;
372       this.gradient = gradient;
373     }
374 
375     @Override
376     public void run() {
377       int ti = getThreadIndex();
378 
379       if (sigmaALoop[ti] == null) {
380         sigmaALoop[ti] = new SigmaALoop();
381       }
382 
383       try {
384         execute(0, reflectionList.hklList.size() - 1, sigmaALoop[ti]);
385       } catch (Exception e) {
386         logger.info(e.toString());
387       }
388     }
389 
390     @Override
391     public void start() {
392       // Zero out the gradient
393       if (gradient) {
394         if (grad == null) {
395           grad = new SharedDoubleArray(g.length);
396         }
397         for (int i = 0; i < g.length; i++) {
398           grad.set(i, 0.0);
399         }
400       }
401       sum.set(0.0);
402       nSum.set(0);
403       sumR.set(0.0);
404       nSumR.set(0);
405 
406       modelK = refinementData.modelScaleK;
407       solventK = refinementData.bulkSolventK;
408       solventUEq = refinementData.bulkSolventUeq;
409       arraycopy(refinementData.modelAnisoB, 0, model_b, 0, 6);
410 
411       // Generate Ustar
412       mat3SymVec6(crystal.A, model_b, resm);
413       mat3Mat3(resm, transposeA, ustar);
414 
415       for (int i = 0; i < nBins; i++) {
416         sa[i] = 1.0 + x[i];
417         wa[i] = x[nBins + i];
418       }
419 
420       // Cheap method of preventing negative w values.
421       for (int i = 0; i < nBins; i++) {
422         if (wa[i] <= 0.0) {
423           wa[i] = 1.0e-6;
424         }
425       }
426     }
427 
428     private class SigmaALoop extends IntegerForLoop {
429 
// Thread-local gradient accumulator with 2 * nBins entries (s derivatives, then w derivatives).
private final double[] lGrad;
// Scratch vectors: ihc holds the (h, k, l) indices as doubles, resv the result of
// vec3Mat3(ihc, ustar, resv).
private final double[] resv = new double[3];
private final double[] ihc = new double[3];
// Scratch complex numbers reused for every reflection handled by this loop instance.
private final ComplexNumber resc = new ComplexNumber();
private final ComplexNumber fcc = new ComplexNumber();
private final ComplexNumber fsc = new ComplexNumber();
private final ComplexNumber fct = new ComplexNumber();
private final ComplexNumber kfct = new ComplexNumber();
private final ComplexNumber ecc = new ComplexNumber();
private final ComplexNumber esc = new ComplexNumber();
private final ComplexNumber ect = new ComplexNumber();
private final ComplexNumber kect = new ComplexNumber();
private final ComplexNumber mfo = new ComplexNumber();
private final ComplexNumber mfo2 = new ComplexNumber();
private final ComplexNumber dfcc = new ComplexNumber();
// Spline evaluator owned by this loop instance; also supplies the bin indices and weights used
// to distribute the s and w derivatives.
private final ReflectionSpline spline = new ReflectionSpline(reflectionList, nBins);
// Thread local work variables.
// Likelihood sum and count for free R flagged reflections (lSum, lSumN) and for all other
// reflections (lSumR, lSumRN); added to the shared totals in finish().
private double lSum;
private double lSumR;
private int lSumN;
private int lSumRN;
451 
452       SigmaALoop() {
453         lGrad = new double[2 * nBins];
454       }
455 
456       @Override
457       public void finish() {
458         sum.addAndGet(lSum);
459         sumR.addAndGet(lSumR);
460         nSum.addAndGet(lSumN);
461         nSumR.addAndGet(lSumRN);
462         for (int i = 0; i < lGrad.length; i++) {
463           grad.getAndAdd(i, lGrad[i]);
464         }
465       }
466 
467       @Override
468       public void run(int lb, int ub) {
469         for (int j = lb; j <= ub; j++) {
470           HKL ih = reflectionList.hklList.get(j);
471           int i = ih.getIndex();
472           // Constants
473           ihc[0] = ih.getH();
474           ihc[1] = ih.getK();
475           ihc[2] = ih.getL();
476           vec3Mat3(ihc, ustar, resv);
477           double u = modelK - dot(resv, ihc);
478           double s = crystal.invressq(ih);
479           double ebs = exp(-twoPI2 * solventUEq * s);
480           double ksebs = solventK * ebs;
481           double kmems = exp(0.25 * u);
482           double km2 = exp(0.5 * u);
483           double epsc = ih.epsilonc();
484 
485           // Spline setup
486           double ecscale = spline.f(s, refinementData.esqFc);
487           double eoscale = spline.f(s, refinementData.esqFo);
488           double sqrtECScale = sqrt(ecscale);
489           double sqrtEOScale = sqrt(eoscale);
490           double iSqrtEOScale = 1.0 / sqrtEOScale;
491 
492           double sai = spline.f(s, sa);
493           double wai = spline.f(s, wa);
494           double sa2 = sai * sai;
495 
496           // Structure factors
497           refinementData.getFcIP(i, fcc);
498           refinementData.getFsIP(i, fsc);
499           fct.copy(fcc);
500           if (refinementData.crystalReciprocalSpaceFs.solventModel != SolventModel.NONE) {
501             resc.copy(fsc);
502             resc.timesIP(ksebs);
503             fct.plusIP(resc);
504           }
505           kfct.copy(fct);
506           kfct.timesIP(kmems);
507 
508           ecc.copy(fcc);
509           ecc.timesIP(sqrtECScale);
510           esc.copy(fsc);
511           esc.timesIP(sqrtECScale);
512           ect.copy(fct);
513           ect.timesIP(sqrtECScale);
514           kect.copy(kfct);
515           kect.timesIP(sqrtECScale);
516           double eo = fSigF[i][0] * sqrtEOScale;
517           double sigeo = fSigF[i][1] * sqrtEOScale;
518           double eo2 = eo * eo;
519           double akect = kect.abs();
520           double kect2 = akect * akect;
521 
522           // FOM
523           double d = 2.0 * sigeo * sigeo + epsc * wai;
524           double id = 1.0 / d;
525           double id2 = id * id;
526           double fomx = 2.0 * eo * sai * kect.abs() * id;
527 
528           double inot, dinot, cf;
529           if (ih.centric()) {
530             inot = (abs(fomx) < 10.0) ? log(cosh(fomx)) : abs(fomx) + log(0.5);
531             dinot = tanh(fomx);
532             cf = 0.5;
533           } else {
534             if (useCernBessel) {
535               inot = lnI0(fomx);
536               dinot = i1OverI0(fomx);
537             } else {
538               inot = sim_integ(fomx);
539               dinot = sim(fomx);
540             }
541             cf = 1.0;
542           }
543           double llk = cf * log(d) + (eo2 + sa2 * kect2) * id - inot;
544 
545           // Map coefficients
546           double f = dinot * eo;
547           double phi = kect.phase();
548           double sinPhi = sin(phi);
549           double cosPhi = cos(phi);
550           fomPhi[i][0] = dinot;
551           fomPhi[i][1] = phi;
552           mfo.re(f * cosPhi);
553           mfo.im(f * sinPhi);
554           mfo2.re(2.0 * f * cosPhi);
555           mfo2.im(2.0 * f * sinPhi);
556           akect = kect.abs();
557           dfcc.re(sai * akect * cosPhi);
558           dfcc.im(sai * akect * sinPhi);
559           // Set up map coefficients
560           foFc1[i][0] = 0.0;
561           foFc1[i][1] = 0.0;
562           foFc2[i][0] = 0.0;
563           foFc2[i][1] = 0.0;
564           dFc[i][0] = 0.0;
565           dFc[i][1] = 0.0;
566           dFs[i][0] = 0.0;
567           dFs[i][1] = 0.0;
568           if (isNaN(fcTot[i][0])) {
569             if (!isNaN(fSigF[i][0])) {
570               foFc2[i][0] = mfo.re() * iSqrtEOScale;
571               foFc2[i][1] = mfo.im() * iSqrtEOScale;
572             }
573             continue;
574           }
575           if (isNaN(fSigF[i][0])) {
576             if (!isNaN(fcTot[i][0])) {
577               foFc2[i][0] = dfcc.re() * iSqrtEOScale;
578               foFc2[i][1] = dfcc.im() * iSqrtEOScale;
579             }
580             continue;
581           }
582           // Update Fctot
583           fcTot[i][0] = kfct.re();
584           fcTot[i][1] = kfct.im();
585           // mFo - DFc
586           resc.copy(mfo);
587           resc.minusIP(dfcc);
588           foFc1[i][0] = resc.re() * iSqrtEOScale;
589           foFc1[i][1] = resc.im() * iSqrtEOScale;
590           // 2mFo - DFc
591           resc.copy(mfo2);
592           resc.minusIP(dfcc);
593           foFc2[i][0] = resc.re() * iSqrtEOScale;
594           foFc2[i][1] = resc.im() * iSqrtEOScale;
595 
596           // Derivatives
597           double dafct = d * fct.abs();
598           double idafct = 1.0 / dafct;
599           double dfp1 = 2.0 * sa2 * km2 * ecscale;
600           double dfp2 = 2.0 * eo * sai * kmems * sqrt(ecscale);
601           double dfp1id = dfp1 * id;
602           double dfp2id = dfp2 * idafct * dinot;
603           double dfp12 = dfp1id - dfp2id;
604           double dfp21 = ksebs * (dfp2id - dfp1id);
605           double dfcr = fct.re() * dfp12;
606           double dfci = fct.im() * dfp12;
607           double dfsr = fct.re() * dfp21;
608           double dfsi = fct.im() * dfp21;
609           double dfsa = 2.0 * (sai * kect2 - eo * akect * dinot) * id;
610           double dfwa =
611               epsc * (cf * id - (eo2 + sa2 * kect2) * id2 + 2.0 * eo * sai * akect * id2 * dinot);
612 
613           // Partial LLK wrt Fc or Fs
614           dFc[i][0] = dfcr * dfScale;
615           dFc[i][1] = dfci * dfScale;
616           dFs[i][0] = dfsr * dfScale;
617           dFs[i][1] = dfsi * dfScale;
618 
619           // Only use free R flagged reflections in overall sum
620           if (refinementData.isFreeR(i)) {
621             lSum += llk;
622             lSumN++;
623           } else {
624             lSumR += llk;
625             lSumRN++;
626             dfsa = dfwa = 0.0;
627           }
628 
629           if (gradient) {
630             int i0 = spline.i0();
631             int i1 = spline.i1();
632             int i2 = spline.i2();
633             double g0 = spline.dfi0();
634             double g1 = spline.dfi1();
635             double g2 = spline.dfi2();
636             // s derivative
637             lGrad[i0] += dfsa * g0;
638             lGrad[i1] += dfsa * g1;
639             lGrad[i2] += dfsa * g2;
640             // w derivative
641             lGrad[nBins + i0] += dfwa * g0;
642             lGrad[nBins + i1] += dfwa * g1;
643             lGrad[nBins + i2] += dfwa * g2;
644           }
645         }
646       }
647 
648       @Override
649       public void start() {
650         lSum = 0.0;
651         lSumR = 0.0;
652         lSumN = 0;
653         lSumRN = 0;
654         fill(lGrad, 0.0);
655       }
656     }
657   }
658 }