1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import edu.rit.pj.IntegerForLoop;
41 import edu.rit.pj.ParallelRegion;
42 import edu.rit.pj.ParallelTeam;
43 import edu.rit.pj.reduction.SharedDouble;
44 import edu.rit.pj.reduction.SharedDoubleArray;
45 import edu.rit.pj.reduction.SharedInteger;
46 import ffx.crystal.Crystal;
47 import ffx.crystal.HKL;
48 import ffx.crystal.ReflectionList;
49 import ffx.crystal.ReflectionSpline;
50 import ffx.numerics.OptimizationInterface;
51 import ffx.numerics.math.ComplexNumber;
52 import ffx.xray.solvent.SolventModel;
53
54 import java.util.logging.Logger;
55
56 import static ffx.numerics.math.DoubleMath.dot;
57 import static ffx.numerics.math.MatrixMath.mat3Mat3;
58 import static ffx.numerics.math.MatrixMath.mat3SymVec6;
59 import static ffx.numerics.math.MatrixMath.transpose3;
60 import static ffx.numerics.math.MatrixMath.vec3Mat3;
61 import static ffx.numerics.special.ModifiedBessel.i1OverI0;
62 import static ffx.numerics.special.ModifiedBessel.lnI0;
63 import static java.lang.Double.isNaN;
64 import static java.lang.String.format;
65 import static java.lang.System.arraycopy;
66 import static java.util.Arrays.fill;
67 import static org.apache.commons.math3.util.FastMath.PI;
68 import static org.apache.commons.math3.util.FastMath.abs;
69 import static org.apache.commons.math3.util.FastMath.atan;
70 import static org.apache.commons.math3.util.FastMath.cos;
71 import static org.apache.commons.math3.util.FastMath.cosh;
72 import static org.apache.commons.math3.util.FastMath.exp;
73 import static org.apache.commons.math3.util.FastMath.log;
74 import static org.apache.commons.math3.util.FastMath.sin;
75 import static org.apache.commons.math3.util.FastMath.sqrt;
76 import static org.apache.commons.math3.util.FastMath.tanh;
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100 public class SigmaAEnergy implements OptimizationInterface {
101
102 private static final Logger logger = Logger.getLogger(SigmaAEnergy.class.getName());
103 private static final double twoPI2 = 2.0 * PI * PI;
104
105 private final ReflectionList reflectionList;
106 private final DiffractionRefinementData refinementData;
107 private final ParallelTeam parallelTeam;
108 private final Crystal crystal;
109 private final double[][] fSigF;
110 private final double[][] fcTot;
111 private final double[][] fomPhi;
112 private final double[][] foFc1;
113 private final double[][] foFc2;
114 private final double[][] dFc;
115 private final double[][] dFs;
116 private final int nBins;
117
118
119
120
121 private final double dfScale;
122
123
124
125 private final double[][] transposeA;
126
127
128
129
130
131
132 private final double[] sa;
133
134
135
136 private final double[] wa;
137
138 private final SigmaARegion sigmaARegion;
139 private final boolean useCernBessel;
140 private double[] optimizationScaling = null;
141 private double totalEnergy;
142
143
144
145
146
147
148
149
150 SigmaAEnergy(ReflectionList reflectionList, DiffractionRefinementData refinementData,
151 ParallelTeam parallelTeam) {
152 this.reflectionList = reflectionList;
153 this.refinementData = refinementData;
154 this.parallelTeam = parallelTeam;
155 this.crystal = reflectionList.crystal;
156 this.fSigF = refinementData.fSigF;
157 this.fcTot = refinementData.fcTot;
158 this.fomPhi = refinementData.fomPhi;
159 this.foFc1 = refinementData.foFc1;
160 this.foFc2 = refinementData.foFc2;
161 this.dFc = refinementData.dFc;
162 this.dFs = refinementData.dFs;
163 this.nBins = refinementData.nBins;
164
165
166 assert (refinementData.crystalReciprocalSpaceFc != null);
167 double nGrid2 = 2.0
168 * refinementData.crystalReciprocalSpaceFc.getXDim()
169 * refinementData.crystalReciprocalSpaceFc.getYDim()
170 * refinementData.crystalReciprocalSpaceFc.getZDim();
171 dfScale = (crystal.volume * crystal.volume) / nGrid2;
172 transposeA = transpose3(crystal.A);
173 sa = new double[nBins];
174 wa = new double[nBins];
175
176 sigmaARegion = new SigmaARegion(this.parallelTeam.getThreadCount());
177 useCernBessel = true;
178 }
179
180
181
182
183 @Override
184 public boolean destroy() {
185
186 return true;
187 }
188
189
190
191
192 @Override
193 public double energy(double[] x) {
194 unscaleCoordinates(x);
195 double sum = target(x, null, false, false);
196 scaleCoordinates(x);
197 return sum;
198 }
199
200
201
202
203 @Override
204 public double energyAndGradient(double[] x, double[] g) {
205 unscaleCoordinates(x);
206 double sum = target(x, g, true, false);
207 scaleCoordinatesAndGradient(x, g);
208 return sum;
209 }
210
211
212
213
214 @Override
215 public double[] getCoordinates(double[] parameters) {
216 throw new UnsupportedOperationException("Not supported yet.");
217 }
218
219
220
221
222 @Override
223 public void setCoordinates(double[] parameters) {
224 throw new UnsupportedOperationException("Not supported yet.");
225 }
226
227
228
229
230 @Override
231 public int getNumberOfVariables() {
232 throw new UnsupportedOperationException("Not supported yet.");
233 }
234
235
236
237
238 @Override
239 public double[] getScaling() {
240 return optimizationScaling;
241 }
242
243
244
245
246 @Override
247 public void setScaling(double[] scaling) {
248 if (scaling != null && scaling.length == nBins * 2) {
249 optimizationScaling = scaling;
250 } else {
251 optimizationScaling = null;
252 }
253 }
254
255
256
257
258 @Override
259 public double getTotalEnergy() {
260 return totalEnergy;
261 }
262
263
264
265
266
267
268
269
270
271
272 public double target(double[] x, double[] g, boolean gradient, boolean print) {
273
274 try {
275 sigmaARegion.init(x, g, gradient);
276 parallelTeam.execute(sigmaARegion);
277 } catch (Exception e) {
278 logger.info(e.toString());
279 }
280
281 double sum = sigmaARegion.sum.get();
282 double sumR = sigmaARegion.sumR.get();
283 refinementData.llkR = sumR;
284 refinementData.llkF = sum;
285
286 if (print) {
287 int nSum = sigmaARegion.nSum.get();
288 int nSumr = sigmaARegion.nSumR.get();
289 StringBuilder sb = new StringBuilder("\n");
290 sb.append(" Sigma A (s and w) fit using only R free reflections\n");
291 sb.append(format(" # HKL: %10d (free set) %10d (working set) %10d (total)\n", nSum, nSumr, nSum + nSumr));
292 sb.append(format(" residual: %10.4f (free set) %10.4f (working set) %10.4f (total)\n", sum, sumR, sum + sumR));
293 sb.append(" X: ");
294 for (double x1 : x) {
295 sb.append(format("%8.5f ", x1));
296 }
297 if (gradient) {
298 sb.append("\n G: ");
299 for (double v : g) {
300 sb.append(format("%8.5f ", v));
301 }
302 }
303 sb.append("\n");
304 logger.info(sb.toString());
305 }
306
307 totalEnergy = sum;
308 return totalEnergy;
309 }
310
311 private class SigmaARegion extends ParallelRegion {
312
313 private final double[][] resm = new double[3][3];
314 private final double[] model_b = new double[6];
315 private final double[][] ustar = new double[3][3];
316 boolean gradient = true;
317 double modelK;
318 double solventK;
319 double solventUEq;
320 double[] x;
321 double[] g;
322 SharedInteger nSum;
323 SharedInteger nSumR;
324 SharedDouble sum;
325 SharedDouble sumR;
326 SharedDoubleArray grad;
327 SigmaALoop[] sigmaALoop;
328
329 SigmaARegion(int nThreads) {
330 sigmaALoop = new SigmaALoop[nThreads];
331 nSum = new SharedInteger();
332 nSumR = new SharedInteger();
333 sum = new SharedDouble();
334 sumR = new SharedDouble();
335 }
336
337 @Override
338 public void finish() {
339 if (gradient) {
340 for (int i = 0; i < g.length; i++) {
341 g[i] = grad.get(i);
342 }
343 }
344 }
345
346 public void init(double[] x, double[] g, boolean gradient) {
347 this.x = x;
348 this.g = g;
349 this.gradient = gradient;
350 }
351
352 @Override
353 public void run() {
354 int ti = getThreadIndex();
355
356 if (sigmaALoop[ti] == null) {
357 sigmaALoop[ti] = new SigmaALoop();
358 }
359
360 try {
361 execute(0, reflectionList.hklList.size() - 1, sigmaALoop[ti]);
362 } catch (Exception e) {
363 logger.info(e.toString());
364 }
365 }
366
367 @Override
368 public void start() {
369
370 if (gradient) {
371 if (grad == null) {
372 grad = new SharedDoubleArray(g.length);
373 }
374 for (int i = 0; i < g.length; i++) {
375 grad.set(i, 0.0);
376 }
377 }
378 sum.set(0.0);
379 nSum.set(0);
380 sumR.set(0.0);
381 nSumR.set(0);
382
383 modelK = refinementData.modelScaleK;
384 solventK = refinementData.bulkSolventK;
385 solventUEq = refinementData.bulkSolventUeq;
386 arraycopy(refinementData.modelAnisoB, 0, model_b, 0, 6);
387
388
389 mat3SymVec6(crystal.A, model_b, resm);
390 mat3Mat3(resm, transposeA, ustar);
391
392 for (int i = 0; i < nBins; i++) {
393 sa[i] = x[i];
394 wa[i] = x[nBins + i];
395 }
396
397
398 for (int i = 0; i < nBins; i++) {
399 if (wa[i] <= 0.0) {
400
401 wa[i] = 1.0e-6;
402 }
403 }
404 }
405
406 private class SigmaALoop extends IntegerForLoop {
407
408 private final double[] lGrad;
409 private final double[] resv = new double[3];
410 private final double[] ihc = new double[3];
411 private final ComplexNumber resc = new ComplexNumber();
412 private final ComplexNumber fcc = new ComplexNumber();
413 private final ComplexNumber fsc = new ComplexNumber();
414 private final ComplexNumber fct = new ComplexNumber();
415 private final ComplexNumber kfct = new ComplexNumber();
416 private final ComplexNumber ecc = new ComplexNumber();
417 private final ComplexNumber esc = new ComplexNumber();
418 private final ComplexNumber ect = new ComplexNumber();
419 private final ComplexNumber kect = new ComplexNumber();
420 private final ComplexNumber mfo = new ComplexNumber();
421 private final ComplexNumber mfo2 = new ComplexNumber();
422 private final ComplexNumber dfcc = new ComplexNumber();
423 private final ReflectionSpline spline = new ReflectionSpline(reflectionList, nBins);
424
425 private double lSum;
426 private double lSumR;
427 private int lSumN;
428 private int lSumRN;
429
430 SigmaALoop() {
431 lGrad = new double[2 * nBins];
432 }
433
434 @Override
435 public void finish() {
436 sum.addAndGet(lSum);
437 sumR.addAndGet(lSumR);
438 nSum.addAndGet(lSumN);
439 nSumR.addAndGet(lSumRN);
440 if (gradient) {
441 for (int i = 0; i < lGrad.length; i++) {
442 grad.getAndAdd(i, lGrad[i]);
443 }
444 }
445 }
446
447 @Override
448 public void run(int lb, int ub) {
449 for (int j = lb; j <= ub; j++) {
450 HKL ih = reflectionList.hklList.get(j);
451 int i = ih.getIndex();
452
453 ihc[0] = ih.getH();
454 ihc[1] = ih.getK();
455 ihc[2] = ih.getL();
456 double s = crystal.invressq(ih);
457
458 double ebs = exp(-twoPI2 * solventUEq * s);
459 double ksebs = solventK * ebs;
460
461 vec3Mat3(ihc, ustar, resv);
462 double u = modelK - dot(resv, ihc);
463 double kmems = exp(0.25 * u);
464 double km2 = exp(0.5 * u);
465
466
467
468 double epsc = ih.epsilonc();
469
470
471 double ecscale = spline.f(s, refinementData.esqFc);
472
473 double eoscale = spline.f(s, refinementData.esqFo);
474 double sqrtECScale = sqrt(ecscale);
475 double sqrtEOScale = sqrt(eoscale);
476 double iSqrtEOScale = 1.0 / sqrtEOScale;
477
478
479 double sai = spline.f(s, sa);
480 double wai = spline.f(s, wa);
481 double sa2 = sai * sai;
482
483
484
485 refinementData.getFcIP(i, fcc);
486 fct.copy(fcc);
487
488 refinementData.getFsIP(i, fsc);
489 if (refinementData.crystalReciprocalSpaceFs.getSolventModel() != SolventModel.NONE) {
490 resc.copy(fsc);
491 resc.timesIP(ksebs);
492
493 fct.plusIP(resc);
494 }
495
496 kfct.copy(fct);
497 kfct.timesIP(kmems);
498
499 ecc.copy(fcc);
500 ecc.timesIP(sqrtECScale);
501
502 esc.copy(fsc);
503 esc.timesIP(sqrtECScale);
504
505 ect.copy(fct);
506 ect.timesIP(sqrtECScale);
507
508 kect.copy(kfct);
509 kect.timesIP(sqrtECScale);
510
511 double eo = fSigF[i][0] * sqrtEOScale;
512 double eo2 = eo * eo;
513
514 double sigeo = fSigF[i][1] * sqrtEOScale;
515
516 double akect = kect.abs();
517 double kect2 = akect * akect;
518
519
520
521 double d = 2.0 * sigeo * sigeo + epsc * wai;
522 double id = 1.0 / d;
523 double id2 = id * id;
524
525 double fomx = 2.0 * eo * sai * akect * id;
526
527 double inot, dinot, cf;
528 if (ih.centric()) {
529
530 inot = (abs(fomx) < 10.0) ? log(cosh(fomx)) : abs(fomx) + log(0.5);
531
532 dinot = tanh(fomx);
533 cf = 0.5;
534 } else {
535 if (useCernBessel) {
536
537 inot = lnI0(fomx);
538
539 dinot = i1OverI0(fomx);
540 } else {
541 inot = lnI0_clipper(fomx);
542 dinot = i1OverI0_clipper(fomx);
543 }
544 cf = 1.0;
545 }
546
547
548 double llk = cf * log(d) + (eo2 + sa2 * kect2) * id - inot;
549
550
551 double f = dinot * eo;
552
553 double phi = kect.phase();
554 double sinPhi = sin(phi);
555 double cosPhi = cos(phi);
556
557 fomPhi[i][0] = dinot;
558 fomPhi[i][1] = phi;
559 mfo.re(f * cosPhi);
560 mfo.im(f * sinPhi);
561 mfo2.re(2.0 * f * cosPhi);
562 mfo2.im(2.0 * f * sinPhi);
563 dfcc.re(sai * akect * cosPhi);
564 dfcc.im(sai * akect * sinPhi);
565
566 foFc1[i][0] = 0.0;
567 foFc1[i][1] = 0.0;
568 foFc2[i][0] = 0.0;
569 foFc2[i][1] = 0.0;
570 dFc[i][0] = 0.0;
571 dFc[i][1] = 0.0;
572 dFs[i][0] = 0.0;
573 dFs[i][1] = 0.0;
574 if (isNaN(fcTot[i][0])) {
575 if (!isNaN(fSigF[i][0])) {
576 foFc2[i][0] = mfo.re() * iSqrtEOScale;
577 foFc2[i][1] = mfo.im() * iSqrtEOScale;
578 }
579 continue;
580 }
581 if (isNaN(fSigF[i][0])) {
582 if (!isNaN(fcTot[i][0])) {
583 foFc2[i][0] = dfcc.re() * iSqrtEOScale;
584 foFc2[i][1] = dfcc.im() * iSqrtEOScale;
585 }
586 continue;
587 }
588
589 fcTot[i][0] = kfct.re();
590 fcTot[i][1] = kfct.im();
591
592 resc.copy(mfo);
593 resc.minusIP(dfcc);
594 foFc1[i][0] = resc.re() * iSqrtEOScale;
595 foFc1[i][1] = resc.im() * iSqrtEOScale;
596
597 resc.copy(mfo2);
598 resc.minusIP(dfcc);
599 foFc2[i][0] = resc.re() * iSqrtEOScale;
600 foFc2[i][1] = resc.im() * iSqrtEOScale;
601
602
603 double dafct = d * fct.abs();
604 double idafct = 1.0 / dafct;
605 double dfp1 = 2.0 * sa2 * km2 * ecscale;
606 double dfp2 = 2.0 * eo * sai * kmems * sqrt(ecscale);
607 double dfp1id = dfp1 * id;
608 double dfp2id = dfp2 * idafct * dinot;
609 double dfp12 = dfp1id - dfp2id;
610 double dfp21 = ksebs * (dfp2id - dfp1id);
611 double dfcr = fct.re() * dfp12;
612 double dfci = fct.im() * dfp12;
613 double dfsr = fct.re() * dfp21;
614 double dfsi = fct.im() * dfp21;
615 double dfsa = 2.0 * (sai * kect2 - eo * akect * dinot) * id;
616 double dfwa = epsc * (cf * id - (eo2 + sa2 * kect2) * id2 + 2.0 * eo * sai * akect * id2 * dinot);
617
618
619 dFc[i][0] = dfcr * dfScale;
620 dFc[i][1] = dfci * dfScale;
621 dFs[i][0] = dfsr * dfScale;
622 dFs[i][1] = dfsi * dfScale;
623
624
625 if (refinementData.isFreeR(i)) {
626 lSum += llk;
627 lSumN++;
628 } else {
629 lSumR += llk;
630 lSumRN++;
631 dfsa = 0.0;
632 dfwa = 0.0;
633 }
634
635 if (gradient) {
636
637 int i0 = spline.i0();
638 int i1 = spline.i1();
639 int i2 = spline.i2();
640
641 double g0 = spline.dfi0();
642 double g1 = spline.dfi1();
643 double g2 = spline.dfi2();
644
645 lGrad[i0] += dfsa * g0;
646 lGrad[i1] += dfsa * g1;
647 lGrad[i2] += dfsa * g2;
648
649 lGrad[nBins + i0] += dfwa * g0;
650 lGrad[nBins + i1] += dfwa * g1;
651 lGrad[nBins + i2] += dfwa * g2;
652 }
653 }
654 }
655
656 @Override
657 public void start() {
658 lSum = 0.0;
659 lSumR = 0.0;
660 lSumN = 0;
661 lSumRN = 0;
662 if (gradient) {
663 fill(lGrad, 0.0);
664 }
665 }
666 }
667 }
668
669
670
671
672 private static final double sim_a = 1.639294;
673 private static final double sim_b = 3.553967;
674 private static final double sim_c = 2.228716;
675 private static final double sim_d = 3.524142;
676 private static final double sim_e = 7.107935;
677 private static final double sim_A = -1.28173889903;
678 private static final double sim_B = 0.69231689903;
679 private static final double sim_g = 2.13643992379;
680 private static final double sim_p = 0.04613803811;
681 private static final double sim_q = 1.82167089029;
682 private static final double sim_r = -0.74817947490;
683
684
685
686
687
688
689
690
691
692
693 private static double i1OverI0_clipper(double x) {
694 if (x >= 0.0) {
695 return (((x + sim_a) * x + sim_b) * x) / (((x + sim_c) * x + sim_d) * x + sim_e);
696 } else {
697 return -(-(-(-x + sim_a) * x + sim_b) * x) / (-(-(-x + sim_c) * x + sim_d) * x + sim_e);
698 }
699 }
700
701
702
703
704
705
706
707
708
709
710 private static double lnI0_clipper(double x0) {
711 double x = abs(x0);
712 double z = (x + sim_p) / sim_q;
713 return sim_A * log(x + sim_g) + 0.5 * sim_B * log(z * z + 1.0) + sim_r * atan(z) + x + 1.0;
714 }
715 }