1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
import edu.rit.pj.IntegerForLoop;
import edu.rit.pj.ParallelRegion;
import edu.rit.pj.ParallelTeam;
import edu.rit.pj.reduction.SharedDouble;
import edu.rit.pj.reduction.SharedDoubleArray;
import edu.rit.pj.reduction.SharedInteger;
import ffx.crystal.Crystal;
import ffx.crystal.HKL;
import ffx.crystal.ReflectionList;
import ffx.crystal.ReflectionSpline;
import ffx.numerics.OptimizationInterface;
import ffx.numerics.math.ComplexNumber;
import ffx.xray.solvent.SolventModel;

import java.util.logging.Level;
import java.util.logging.Logger;

import static ffx.numerics.math.DoubleMath.dot;
import static ffx.numerics.math.MatrixMath.mat3Mat3Multiply;
import static ffx.numerics.math.MatrixMath.mat3SymVec6;
import static ffx.numerics.math.MatrixMath.mat3Transpose;
import static ffx.numerics.math.MatrixMath.vec3Mat3;
import static ffx.numerics.special.ModifiedBessel.i1OverI0;
import static ffx.numerics.special.ModifiedBessel.lnI0;
import static java.lang.Double.isNaN;
import static java.lang.String.format;
import static java.lang.System.arraycopy;
import static java.util.Arrays.fill;
import static org.apache.commons.math3.util.FastMath.PI;
import static org.apache.commons.math3.util.FastMath.abs;
import static org.apache.commons.math3.util.FastMath.atan;
import static org.apache.commons.math3.util.FastMath.cos;
import static org.apache.commons.math3.util.FastMath.cosh;
import static org.apache.commons.math3.util.FastMath.exp;
import static org.apache.commons.math3.util.FastMath.log;
import static org.apache.commons.math3.util.FastMath.sin;
import static org.apache.commons.math3.util.FastMath.sqrt;
import static org.apache.commons.math3.util.FastMath.tanh;
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99 public class SigmaAEnergy implements OptimizationInterface {
100
101 private static final Logger logger = Logger.getLogger(SigmaAEnergy.class.getName());
102 private static final double twoPI2 = 2.0 * PI * PI;
103
104 private final ReflectionList reflectionList;
105 private final DiffractionRefinementData refinementData;
106 private final ParallelTeam parallelTeam;
107 private final Crystal crystal;
108 private final double[][] fSigF;
109 private final double[][] fcTot;
110 private final double[][] fomPhi;
111 private final double[][] foFc1;
112 private final double[][] foFc2;
113 private final double[][] dFc;
114 private final double[][] dFs;
115 private final int nBins;
116
117
118
119
120 private final double dfScale;
121
122
123
124 private final double[][] transposeA;
125
126
127
128
129
130
131 private final double[] sa;
132
133
134
135 private final double[] wa;
136
137 private final SigmaARegion sigmaARegion;
138 private final boolean useCernBessel;
139 private double[] optimizationScaling = null;
140 private double totalEnergy;
141
142
143
144
145
146
147
148
149 SigmaAEnergy(ReflectionList reflectionList, DiffractionRefinementData refinementData,
150 ParallelTeam parallelTeam) {
151 this.reflectionList = reflectionList;
152 this.refinementData = refinementData;
153 this.parallelTeam = parallelTeam;
154 this.crystal = reflectionList.crystal;
155 this.fSigF = refinementData.fSigF;
156 this.fcTot = refinementData.fcTot;
157 this.fomPhi = refinementData.fomPhi;
158 this.foFc1 = refinementData.foFc1;
159 this.foFc2 = refinementData.foFc2;
160 this.dFc = refinementData.dFc;
161 this.dFs = refinementData.dFs;
162 this.nBins = refinementData.nBins;
163
164
165 assert (refinementData.crystalReciprocalSpaceFc != null);
166 double nGrid2 = 2.0
167 * refinementData.crystalReciprocalSpaceFc.getXDim()
168 * refinementData.crystalReciprocalSpaceFc.getYDim()
169 * refinementData.crystalReciprocalSpaceFc.getZDim();
170 dfScale = (crystal.volume * crystal.volume) / nGrid2;
171 transposeA = mat3Transpose(crystal.A);
172 sa = new double[nBins];
173 wa = new double[nBins];
174
175 sigmaARegion = new SigmaARegion(this.parallelTeam.getThreadCount());
176 useCernBessel = true;
177 }
178
179
180
181
182 @Override
183 public boolean destroy() {
184
185 return true;
186 }
187
188
189
190
191 @Override
192 public double energy(double[] x) {
193 unscaleCoordinates(x);
194 double sum = target(x, null, false, false);
195 scaleCoordinates(x);
196 return sum;
197 }
198
199
200
201
202 @Override
203 public double energyAndGradient(double[] x, double[] g) {
204 unscaleCoordinates(x);
205 double sum = target(x, g, true, false);
206 scaleCoordinatesAndGradient(x, g);
207 return sum;
208 }
209
210
211
212
213 @Override
214 public double[] getCoordinates(double[] parameters) {
215 throw new UnsupportedOperationException("Not supported yet.");
216 }
217
218
219
220
221 @Override
222 public void setCoordinates(double[] parameters) {
223 throw new UnsupportedOperationException("Not supported yet.");
224 }
225
226
227
228
229 @Override
230 public int getNumberOfVariables() {
231 throw new UnsupportedOperationException("Not supported yet.");
232 }
233
234
235
236
237 @Override
238 public double[] getScaling() {
239 return optimizationScaling;
240 }
241
242
243
244
245 @Override
246 public void setScaling(double[] scaling) {
247 if (scaling != null && scaling.length == nBins * 2) {
248 optimizationScaling = scaling;
249 } else {
250 optimizationScaling = null;
251 }
252 }
253
254
255
256
257 @Override
258 public double getTotalEnergy() {
259 return totalEnergy;
260 }
261
262
263
264
265
266
267
268
269
270
271 public double target(double[] x, double[] g, boolean gradient, boolean print) {
272
273 try {
274 sigmaARegion.init(x, g, gradient);
275 parallelTeam.execute(sigmaARegion);
276 } catch (Exception e) {
277 logger.info(e.toString());
278 }
279
280 double sum = sigmaARegion.sum.get();
281 double sumR = sigmaARegion.sumR.get();
282 refinementData.llkR = sumR;
283 refinementData.llkF = sum;
284
285 if (print) {
286 int nSum = sigmaARegion.nSum.get();
287 int nSumr = sigmaARegion.nSumR.get();
288 StringBuilder sb = new StringBuilder("\n");
289 sb.append(" Sigma A (s and w) fit using only R free reflections\n");
290 sb.append(format(" # HKL: %10d (free set) %10d (working set) %10d (total)\n", nSum, nSumr, nSum + nSumr));
291 sb.append(format(" residual: %10.4f (free set) %10.4f (working set) %10.4f (total)\n", sum, sumR, sum + sumR));
292 sb.append(" X: ");
293 for (double x1 : x) {
294 sb.append(format("%8.5f ", x1));
295 }
296 if (gradient) {
297 sb.append("\n G: ");
298 for (double v : g) {
299 sb.append(format("%8.5f ", v));
300 }
301 }
302 sb.append("\n");
303 logger.info(sb.toString());
304 }
305
306 totalEnergy = sum;
307 return totalEnergy;
308 }
309
310 private class SigmaARegion extends ParallelRegion {
311
312 private final double[] model_b = new double[6];
313 private final double[][] uStar = new double[3][3];
314 boolean gradient = true;
315 double modelK;
316 double solventK;
317 double solventUEq;
318 double[] x;
319 double[] g;
320 SharedInteger nSum;
321 SharedInteger nSumR;
322 SharedDouble sum;
323 SharedDouble sumR;
324 SharedDoubleArray grad;
325 SigmaALoop[] sigmaALoop;
326
327 SigmaARegion(int nThreads) {
328 sigmaALoop = new SigmaALoop[nThreads];
329 nSum = new SharedInteger();
330 nSumR = new SharedInteger();
331 sum = new SharedDouble();
332 sumR = new SharedDouble();
333 }
334
335 @Override
336 public void finish() {
337 if (gradient) {
338 for (int i = 0; i < g.length; i++) {
339 g[i] = grad.get(i);
340 }
341 }
342 }
343
344 public void init(double[] x, double[] g, boolean gradient) {
345 this.x = x;
346 this.g = g;
347 this.gradient = gradient;
348 }
349
350 @Override
351 public void run() {
352 int ti = getThreadIndex();
353
354 if (sigmaALoop[ti] == null) {
355 sigmaALoop[ti] = new SigmaALoop();
356 }
357
358 try {
359 execute(0, reflectionList.hklList.size() - 1, sigmaALoop[ti]);
360 } catch (Exception e) {
361 logger.info(e.toString());
362 }
363 }
364
365 @Override
366 public void start() {
367
368 if (gradient) {
369 if (grad == null) {
370 grad = new SharedDoubleArray(g.length);
371 }
372 for (int i = 0; i < g.length; i++) {
373 grad.set(i, 0.0);
374 }
375 }
376 sum.set(0.0);
377 nSum.set(0);
378 sumR.set(0.0);
379 nSumR.set(0);
380
381 modelK = refinementData.modelScaleK;
382 solventK = refinementData.bulkSolventK;
383 solventUEq = refinementData.bulkSolventUeq;
384 arraycopy(refinementData.modelAnisoB, 0, model_b, 0, 6);
385
386
387 mat3SymVec6(crystal.A, model_b, uStar);
388 mat3Mat3Multiply(uStar, transposeA, uStar);
389
390 for (int i = 0; i < nBins; i++) {
391 sa[i] = x[i];
392 wa[i] = x[nBins + i];
393 }
394
395
396 for (int i = 0; i < nBins; i++) {
397 if (wa[i] <= 0.0) {
398 wa[i] = 1.0e-6;
399 }
400 }
401 }
402
403 private class SigmaALoop extends IntegerForLoop {
404
405 private final double[] lGrad;
406 private final double[] resv = new double[3];
407 private final double[] ihc = new double[3];
408 private final ComplexNumber resc = new ComplexNumber();
409 private final ComplexNumber fcc = new ComplexNumber();
410 private final ComplexNumber fsc = new ComplexNumber();
411 private final ComplexNumber fct = new ComplexNumber();
412 private final ComplexNumber kfct = new ComplexNumber();
413 private final ComplexNumber ecc = new ComplexNumber();
414 private final ComplexNumber esc = new ComplexNumber();
415 private final ComplexNumber ect = new ComplexNumber();
416 private final ComplexNumber kect = new ComplexNumber();
417 private final ComplexNumber mfo = new ComplexNumber();
418 private final ComplexNumber mfo2 = new ComplexNumber();
419 private final ComplexNumber dfcc = new ComplexNumber();
420 private final ReflectionSpline spline = new ReflectionSpline(reflectionList, nBins);
421
422 private double lSum;
423 private double lSumR;
424 private int lSumN;
425 private int lSumRN;
426
427 SigmaALoop() {
428 lGrad = new double[2 * nBins];
429 }
430
431 @Override
432 public void finish() {
433 sum.addAndGet(lSum);
434 sumR.addAndGet(lSumR);
435 nSum.addAndGet(lSumN);
436 nSumR.addAndGet(lSumRN);
437 if (gradient) {
438 for (int i = 0; i < lGrad.length; i++) {
439 grad.getAndAdd(i, lGrad[i]);
440 }
441 }
442 }
443
444 @Override
445 public void run(int lb, int ub) {
446 for (int j = lb; j <= ub; j++) {
447 HKL ih = reflectionList.hklList.get(j);
448 int i = ih.getIndex();
449
450 ihc[0] = ih.getH();
451 ihc[1] = ih.getK();
452 ihc[2] = ih.getL();
453 double s = crystal.invressq(ih);
454
455 double ebs = exp(-twoPI2 * solventUEq * s);
456 double ksebs = solventK * ebs;
457
458 vec3Mat3(ihc, uStar, resv);
459 double u = modelK - dot(resv, ihc);
460 double kmems = exp(0.25 * u);
461 double km2 = exp(0.5 * u);
462
463
464
465 double epsc = ih.epsilonc();
466
467
468 double ecscale = spline.f(s, refinementData.esqFc);
469
470 double eoscale = spline.f(s, refinementData.esqFo);
471 double sqrtECScale = sqrt(ecscale);
472 double sqrtEOScale = sqrt(eoscale);
473 double iSqrtEOScale = 1.0 / sqrtEOScale;
474
475
476 double sai = spline.f(s, sa);
477 double wai = spline.f(s, wa);
478 double sa2 = sai * sai;
479
480
481
482 refinementData.getFcIP(i, fcc);
483 fct.copy(fcc);
484
485 refinementData.getFsIP(i, fsc);
486 if (refinementData.crystalReciprocalSpaceFs.getSolventModel() != SolventModel.NONE) {
487 resc.copy(fsc);
488 resc.timesIP(ksebs);
489
490 fct.plusIP(resc);
491 }
492
493 kfct.copy(fct);
494 kfct.timesIP(kmems);
495
496 ecc.copy(fcc);
497 ecc.timesIP(sqrtECScale);
498
499 esc.copy(fsc);
500 esc.timesIP(sqrtECScale);
501
502 ect.copy(fct);
503 ect.timesIP(sqrtECScale);
504
505 kect.copy(kfct);
506 kect.timesIP(sqrtECScale);
507
508 double eo = fSigF[i][0] * sqrtEOScale;
509 double eo2 = eo * eo;
510
511 double sigeo = fSigF[i][1] * sqrtEOScale;
512
513 double akect = kect.abs();
514 double kect2 = akect * akect;
515
516
517
518 double d = 2.0 * sigeo * sigeo + epsc * wai;
519 double id = 1.0 / d;
520 double id2 = id * id;
521
522 double fomx = 2.0 * eo * sai * akect * id;
523
524 double inot, dinot, cf;
525 if (ih.centric()) {
526
527 inot = (abs(fomx) < 10.0) ? log(cosh(fomx)) : abs(fomx) + log(0.5);
528
529 dinot = tanh(fomx);
530 cf = 0.5;
531 } else {
532 if (useCernBessel) {
533
534 inot = lnI0(fomx);
535
536 dinot = i1OverI0(fomx);
537 } else {
538 inot = lnI0_clipper(fomx);
539 dinot = i1OverI0_clipper(fomx);
540 }
541 cf = 1.0;
542 }
543
544
545 double llk = cf * log(d) + (eo2 + sa2 * kect2) * id - inot;
546
547
548 double f = dinot * eo;
549
550 double phi = kect.phase();
551 double sinPhi = sin(phi);
552 double cosPhi = cos(phi);
553
554 fomPhi[i][0] = dinot;
555 fomPhi[i][1] = phi;
556 mfo.re(f * cosPhi);
557 mfo.im(f * sinPhi);
558 mfo2.re(2.0 * f * cosPhi);
559 mfo2.im(2.0 * f * sinPhi);
560 dfcc.re(sai * akect * cosPhi);
561 dfcc.im(sai * akect * sinPhi);
562
563 foFc1[i][0] = 0.0;
564 foFc1[i][1] = 0.0;
565 foFc2[i][0] = 0.0;
566 foFc2[i][1] = 0.0;
567 dFc[i][0] = 0.0;
568 dFc[i][1] = 0.0;
569 dFs[i][0] = 0.0;
570 dFs[i][1] = 0.0;
571 if (isNaN(fcTot[i][0])) {
572 if (!isNaN(fSigF[i][0])) {
573 foFc2[i][0] = mfo.re() * iSqrtEOScale;
574 foFc2[i][1] = mfo.im() * iSqrtEOScale;
575 }
576 continue;
577 }
578 if (isNaN(fSigF[i][0])) {
579 if (!isNaN(fcTot[i][0])) {
580 foFc2[i][0] = dfcc.re() * iSqrtEOScale;
581 foFc2[i][1] = dfcc.im() * iSqrtEOScale;
582 }
583 continue;
584 }
585
586 fcTot[i][0] = kfct.re();
587 fcTot[i][1] = kfct.im();
588
589 resc.copy(mfo);
590 resc.minusIP(dfcc);
591 foFc1[i][0] = resc.re() * iSqrtEOScale;
592 foFc1[i][1] = resc.im() * iSqrtEOScale;
593
594 resc.copy(mfo2);
595 resc.minusIP(dfcc);
596 foFc2[i][0] = resc.re() * iSqrtEOScale;
597 foFc2[i][1] = resc.im() * iSqrtEOScale;
598
599
600 double dafct = d * fct.abs();
601 double idafct = 1.0 / dafct;
602 double dfp1 = 2.0 * sa2 * km2 * ecscale;
603 double dfp2 = 2.0 * eo * sai * kmems * sqrt(ecscale);
604 double dfp1id = dfp1 * id;
605 double dfp2id = dfp2 * idafct * dinot;
606 double dfp12 = dfp1id - dfp2id;
607 double dfp21 = ksebs * (dfp2id - dfp1id);
608 double dfcr = fct.re() * dfp12;
609 double dfci = fct.im() * dfp12;
610 double dfsr = fct.re() * dfp21;
611 double dfsi = fct.im() * dfp21;
612 double dfsa = 2.0 * (sai * kect2 - eo * akect * dinot) * id;
613 double dfwa = epsc * (cf * id - (eo2 + sa2 * kect2) * id2 + 2.0 * eo * sai * akect * id2 * dinot);
614
615
616 dFc[i][0] = dfcr * dfScale;
617 dFc[i][1] = dfci * dfScale;
618 dFs[i][0] = dfsr * dfScale;
619 dFs[i][1] = dfsi * dfScale;
620
621
622 if (refinementData.isFreeR(i)) {
623 lSum += llk;
624 lSumN++;
625 } else {
626 lSumR += llk;
627 lSumRN++;
628 dfsa = 0.0;
629 dfwa = 0.0;
630 }
631
632 if (gradient) {
633
634 int i0 = spline.i0();
635 int i1 = spline.i1();
636 int i2 = spline.i2();
637
638 double g0 = spline.dfi0();
639 double g1 = spline.dfi1();
640 double g2 = spline.dfi2();
641
642 lGrad[i0] += dfsa * g0;
643 lGrad[i1] += dfsa * g1;
644 lGrad[i2] += dfsa * g2;
645
646 lGrad[nBins + i0] += dfwa * g0;
647 lGrad[nBins + i1] += dfwa * g1;
648 lGrad[nBins + i2] += dfwa * g2;
649 }
650 }
651 }
652
653 @Override
654 public void start() {
655 lSum = 0.0;
656 lSumR = 0.0;
657 lSumN = 0;
658 lSumRN = 0;
659 if (gradient) {
660 fill(lGrad, 0.0);
661 }
662 }
663 }
664 }
665
666
667
668
669 private static final double sim_a = 1.639294;
670 private static final double sim_b = 3.553967;
671 private static final double sim_c = 2.228716;
672 private static final double sim_d = 3.524142;
673 private static final double sim_e = 7.107935;
674 private static final double sim_A = -1.28173889903;
675 private static final double sim_B = 0.69231689903;
676 private static final double sim_g = 2.13643992379;
677 private static final double sim_p = 0.04613803811;
678 private static final double sim_q = 1.82167089029;
679 private static final double sim_r = -0.74817947490;
680
681
682
683
684
685
686
687
688
689
690 private static double i1OverI0_clipper(double x) {
691 if (x >= 0.0) {
692 return (((x + sim_a) * x + sim_b) * x) / (((x + sim_c) * x + sim_d) * x + sim_e);
693 } else {
694 return -(-(-(-x + sim_a) * x + sim_b) * x) / (-(-(-x + sim_c) * x + sim_d) * x + sim_e);
695 }
696 }
697
698
699
700
701
702
703
704
705
706
707 private static double lnI0_clipper(double x0) {
708 double x = abs(x0);
709 double z = (x + sim_p) / sim_q;
710 return sim_A * log(x + sim_g) + 0.5 * sim_B * log(z * z + 1.0) + sim_r * atan(z) + x + 1.0;
711 }
712 }