1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import static ffx.numerics.math.DoubleMath.dot;
41 import static ffx.numerics.math.MatrixMath.mat3Mat3;
42 import static ffx.numerics.math.MatrixMath.mat3SymVec6;
43 import static ffx.numerics.math.MatrixMath.transpose3;
44 import static ffx.numerics.math.MatrixMath.vec3Mat3;
45 import static ffx.numerics.special.ModifiedBessel.i1OverI0;
46 import static ffx.numerics.special.ModifiedBessel.lnI0;
47 import static java.lang.Double.isNaN;
48 import static java.lang.String.format;
49 import static java.lang.System.arraycopy;
50 import static java.util.Arrays.fill;
51 import static org.apache.commons.math3.util.FastMath.PI;
52 import static org.apache.commons.math3.util.FastMath.abs;
53 import static org.apache.commons.math3.util.FastMath.atan;
54 import static org.apache.commons.math3.util.FastMath.cos;
55 import static org.apache.commons.math3.util.FastMath.cosh;
56 import static org.apache.commons.math3.util.FastMath.exp;
57 import static org.apache.commons.math3.util.FastMath.log;
58 import static org.apache.commons.math3.util.FastMath.sin;
59 import static org.apache.commons.math3.util.FastMath.sqrt;
60 import static org.apache.commons.math3.util.FastMath.tanh;
61
62 import edu.rit.pj.IntegerForLoop;
63 import edu.rit.pj.ParallelRegion;
64 import edu.rit.pj.ParallelTeam;
65 import edu.rit.pj.reduction.SharedDouble;
66 import edu.rit.pj.reduction.SharedDoubleArray;
67 import edu.rit.pj.reduction.SharedInteger;
68 import ffx.crystal.Crystal;
69 import ffx.crystal.HKL;
70 import ffx.crystal.ReflectionList;
71 import ffx.crystal.ReflectionSpline;
72 import ffx.numerics.OptimizationInterface;
73 import ffx.numerics.math.ComplexNumber;
74 import ffx.xray.CrystalReciprocalSpace.SolventModel;
75
76 import java.util.logging.Logger;
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99 public class SigmaAEnergy implements OptimizationInterface {
100
101 private static final Logger logger = Logger.getLogger(SigmaAEnergy.class.getName());
102 private static final double twoPI2 = 2.0 * PI * PI;
103 private static final double sim_a = 1.639294;
104 private static final double sim_b = 3.553967;
105 private static final double sim_c = 2.228716;
106 private static final double sim_d = 3.524142;
107 private static final double sim_e = 7.107935;
108 private static final double sim_A = -1.28173889903;
109 private static final double sim_B = 0.69231689903;
110 private static final double sim_g = 2.13643992379;
111 private static final double sim_p = 0.04613803811;
112 private static final double sim_q = 1.82167089029;
113 private static final double sim_r = -0.74817947490;
114
115 private final ReflectionList reflectionList;
116 private final DiffractionRefinementData refinementData;
117 private final ParallelTeam parallelTeam;
118 private final Crystal crystal;
119 private final double[][] fSigF;
120 private final double[][] fcTot;
121 private final double[][] fomPhi;
122 private final double[][] foFc1;
123 private final double[][] foFc2;
124 private final double[][] dFc;
125 private final double[][] dFs;
126 private final int nBins;
127
128
129
130
131 private final double dfScale;
132
133
134
135 private final double[][] transposeA;
136
137 private final double[] sa;
138 private final double[] wa;
139
140 private final SigmaARegion sigmaARegion;
141 private final boolean useCernBessel;
142 private double[] optimizationScaling = null;
143 private double totalEnergy;
144
145
146
147
148
149
150
151
152 SigmaAEnergy(
153 ReflectionList reflectionList,
154 DiffractionRefinementData refinementData,
155 ParallelTeam parallelTeam) {
156 this.reflectionList = reflectionList;
157 this.refinementData = refinementData;
158 this.parallelTeam = parallelTeam;
159 this.crystal = reflectionList.crystal;
160 this.fSigF = refinementData.fSigF;
161 this.fcTot = refinementData.fcTot;
162 this.fomPhi = refinementData.fomPhi;
163 this.foFc1 = refinementData.foFc1;
164 this.foFc2 = refinementData.foFc2;
165 this.dFc = refinementData.dFc;
166 this.dFs = refinementData.dFs;
167 this.nBins = refinementData.nBins;
168
169
170 assert (refinementData.crystalReciprocalSpaceFc != null);
171 double nGrid2 = 2.0
172 * refinementData.crystalReciprocalSpaceFc.getXDim()
173 * refinementData.crystalReciprocalSpaceFc.getYDim()
174 * refinementData.crystalReciprocalSpaceFc.getZDim();
175 dfScale = (crystal.volume * crystal.volume) / nGrid2;
176 transposeA = transpose3(crystal.A);
177 sa = new double[nBins];
178 wa = new double[nBins];
179
180 sigmaARegion = new SigmaARegion(this.parallelTeam.getThreadCount());
181 useCernBessel = true;
182 }
183
184
185
186
187
188
189
190
191
192
193 private static double sim(double x) {
194 if (x >= 0.0) {
195 return (((x + sim_a) * x + sim_b) * x) / (((x + sim_c) * x + sim_d) * x + sim_e);
196 } else {
197 return -(-(-(-x + sim_a) * x + sim_b) * x) / (-(-(-x + sim_c) * x + sim_d) * x + sim_e);
198 }
199 }
200
201
202
203
204
205
206
207 private static double sim_integ(double x0) {
208 double x = abs(x0);
209 double z = (x + sim_p) / sim_q;
210 return sim_A * log(x + sim_g) + 0.5 * sim_B * log(z * z + 1.0) + sim_r * atan(z) + x + 1.0;
211 }
212
213
214
215
216 @Override
217 public boolean destroy() {
218
219 return true;
220 }
221
222
223
224
225 @Override
226 public double energy(double[] x) {
227 unscaleCoordinates(x);
228 double sum = target(x, null, false, false);
229 scaleCoordinates(x);
230 return sum;
231 }
232
233
234
235
236 @Override
237 public double energyAndGradient(double[] x, double[] g) {
238 unscaleCoordinates(x);
239 double sum = target(x, g, true, false);
240 scaleCoordinatesAndGradient(x, g);
241 return sum;
242 }
243
244
245
246
247 @Override
248 public double[] getCoordinates(double[] parameters) {
249 throw new UnsupportedOperationException("Not supported yet.");
250 }
251
252
253
254
255 @Override
256 public int getNumberOfVariables() {
257 throw new UnsupportedOperationException("Not supported yet.");
258 }
259
260
261
262
263 @Override
264 public double[] getScaling() {
265 return optimizationScaling;
266 }
267
268
269
270
271 @Override
272 public void setScaling(double[] scaling) {
273 if (scaling != null && scaling.length == nBins * 2) {
274 optimizationScaling = scaling;
275 } else {
276 optimizationScaling = null;
277 }
278 }
279
280
281
282
283 @Override
284 public double getTotalEnergy() {
285 return totalEnergy;
286 }
287
288
289
290
291
292
293
294
295
296
297 public double target(double[] x, double[] g, boolean gradient, boolean print) {
298
299 try {
300 sigmaARegion.init(x, g, gradient);
301 parallelTeam.execute(sigmaARegion);
302 } catch (Exception e) {
303 logger.info(e.toString());
304 }
305
306 double sum = sigmaARegion.sum.get();
307 double sumR = sigmaARegion.sumR.get();
308 refinementData.llkR = sumR;
309 refinementData.llkF = sum;
310
311 if (print) {
312 int nSum = sigmaARegion.nSum.get();
313 int nSumr = sigmaARegion.nSumR.get();
314 StringBuilder sb = new StringBuilder("\n");
315 sb.append(" sigmaA (s and w) fit using only R free reflections\n");
316 sb.append(format(" # HKL: %10d (free set) %10d (working set) %10d (total)\n", nSum, nSumr, nSum + nSumr));
317 sb.append(format(" residual: %10g (free set) %10g (working set) %10g (total)\n", sum, sumR, sum + sumR));
318 sb.append(" x: ");
319 for (double x1 : x) {
320 sb.append(format("%g ", x1));
321 }
322 sb.append("\n g: ");
323 for (double v : g) {
324 sb.append(format("%g ", v));
325 }
326 sb.append("\n");
327 logger.info(sb.toString());
328 }
329
330 totalEnergy = sum;
331 return totalEnergy;
332 }
333
334 private class SigmaARegion extends ParallelRegion {
335
336 private final double[][] resm = new double[3][3];
337 private final double[] model_b = new double[6];
338 private final double[][] ustar = new double[3][3];
339 boolean gradient = true;
340 double modelK;
341 double solventK;
342 double solventUEq;
343 double[] x;
344 double[] g;
345 SharedInteger nSum;
346 SharedInteger nSumR;
347 SharedDouble sum;
348 SharedDouble sumR;
349 SharedDoubleArray grad;
350 SigmaALoop[] sigmaALoop;
351
352 SigmaARegion(int nThreads) {
353 sigmaALoop = new SigmaALoop[nThreads];
354 nSum = new SharedInteger();
355 nSumR = new SharedInteger();
356 sum = new SharedDouble();
357 sumR = new SharedDouble();
358 }
359
360 @Override
361 public void finish() {
362 if (gradient) {
363 for (int i = 0; i < g.length; i++) {
364 g[i] = grad.get(i);
365 }
366 }
367 }
368
369 public void init(double[] x, double[] g, boolean gradient) {
370 this.x = x;
371 this.g = g;
372 this.gradient = gradient;
373 }
374
375 @Override
376 public void run() {
377 int ti = getThreadIndex();
378
379 if (sigmaALoop[ti] == null) {
380 sigmaALoop[ti] = new SigmaALoop();
381 }
382
383 try {
384 execute(0, reflectionList.hklList.size() - 1, sigmaALoop[ti]);
385 } catch (Exception e) {
386 logger.info(e.toString());
387 }
388 }
389
390 @Override
391 public void start() {
392
393 if (gradient) {
394 if (grad == null) {
395 grad = new SharedDoubleArray(g.length);
396 }
397 for (int i = 0; i < g.length; i++) {
398 grad.set(i, 0.0);
399 }
400 }
401 sum.set(0.0);
402 nSum.set(0);
403 sumR.set(0.0);
404 nSumR.set(0);
405
406 modelK = refinementData.modelScaleK;
407 solventK = refinementData.bulkSolventK;
408 solventUEq = refinementData.bulkSolventUeq;
409 arraycopy(refinementData.modelAnisoB, 0, model_b, 0, 6);
410
411
412 mat3SymVec6(crystal.A, model_b, resm);
413 mat3Mat3(resm, transposeA, ustar);
414
415 for (int i = 0; i < nBins; i++) {
416 sa[i] = 1.0 + x[i];
417 wa[i] = x[nBins + i];
418 }
419
420
421 for (int i = 0; i < nBins; i++) {
422 if (wa[i] <= 0.0) {
423 wa[i] = 1.0e-6;
424 }
425 }
426 }
427
428 private class SigmaALoop extends IntegerForLoop {
429
430 private final double[] lGrad;
431 private final double[] resv = new double[3];
432 private final double[] ihc = new double[3];
433 private final ComplexNumber resc = new ComplexNumber();
434 private final ComplexNumber fcc = new ComplexNumber();
435 private final ComplexNumber fsc = new ComplexNumber();
436 private final ComplexNumber fct = new ComplexNumber();
437 private final ComplexNumber kfct = new ComplexNumber();
438 private final ComplexNumber ecc = new ComplexNumber();
439 private final ComplexNumber esc = new ComplexNumber();
440 private final ComplexNumber ect = new ComplexNumber();
441 private final ComplexNumber kect = new ComplexNumber();
442 private final ComplexNumber mfo = new ComplexNumber();
443 private final ComplexNumber mfo2 = new ComplexNumber();
444 private final ComplexNumber dfcc = new ComplexNumber();
445 private final ReflectionSpline spline = new ReflectionSpline(reflectionList, nBins);
446
447 private double lSum;
448 private double lSumR;
449 private int lSumN;
450 private int lSumRN;
451
452 SigmaALoop() {
453 lGrad = new double[2 * nBins];
454 }
455
456 @Override
457 public void finish() {
458 sum.addAndGet(lSum);
459 sumR.addAndGet(lSumR);
460 nSum.addAndGet(lSumN);
461 nSumR.addAndGet(lSumRN);
462 for (int i = 0; i < lGrad.length; i++) {
463 grad.getAndAdd(i, lGrad[i]);
464 }
465 }
466
467 @Override
468 public void run(int lb, int ub) {
469 for (int j = lb; j <= ub; j++) {
470 HKL ih = reflectionList.hklList.get(j);
471 int i = ih.getIndex();
472
473 ihc[0] = ih.getH();
474 ihc[1] = ih.getK();
475 ihc[2] = ih.getL();
476 vec3Mat3(ihc, ustar, resv);
477 double u = modelK - dot(resv, ihc);
478 double s = crystal.invressq(ih);
479 double ebs = exp(-twoPI2 * solventUEq * s);
480 double ksebs = solventK * ebs;
481 double kmems = exp(0.25 * u);
482 double km2 = exp(0.5 * u);
483 double epsc = ih.epsilonc();
484
485
486 double ecscale = spline.f(s, refinementData.esqFc);
487 double eoscale = spline.f(s, refinementData.esqFo);
488 double sqrtECScale = sqrt(ecscale);
489 double sqrtEOScale = sqrt(eoscale);
490 double iSqrtEOScale = 1.0 / sqrtEOScale;
491
492 double sai = spline.f(s, sa);
493 double wai = spline.f(s, wa);
494 double sa2 = sai * sai;
495
496
497 refinementData.getFcIP(i, fcc);
498 refinementData.getFsIP(i, fsc);
499 fct.copy(fcc);
500 if (refinementData.crystalReciprocalSpaceFs.solventModel != SolventModel.NONE) {
501 resc.copy(fsc);
502 resc.timesIP(ksebs);
503 fct.plusIP(resc);
504 }
505 kfct.copy(fct);
506 kfct.timesIP(kmems);
507
508 ecc.copy(fcc);
509 ecc.timesIP(sqrtECScale);
510 esc.copy(fsc);
511 esc.timesIP(sqrtECScale);
512 ect.copy(fct);
513 ect.timesIP(sqrtECScale);
514 kect.copy(kfct);
515 kect.timesIP(sqrtECScale);
516 double eo = fSigF[i][0] * sqrtEOScale;
517 double sigeo = fSigF[i][1] * sqrtEOScale;
518 double eo2 = eo * eo;
519 double akect = kect.abs();
520 double kect2 = akect * akect;
521
522
523 double d = 2.0 * sigeo * sigeo + epsc * wai;
524 double id = 1.0 / d;
525 double id2 = id * id;
526 double fomx = 2.0 * eo * sai * kect.abs() * id;
527
528 double inot, dinot, cf;
529 if (ih.centric()) {
530 inot = (abs(fomx) < 10.0) ? log(cosh(fomx)) : abs(fomx) + log(0.5);
531 dinot = tanh(fomx);
532 cf = 0.5;
533 } else {
534 if (useCernBessel) {
535 inot = lnI0(fomx);
536 dinot = i1OverI0(fomx);
537 } else {
538 inot = sim_integ(fomx);
539 dinot = sim(fomx);
540 }
541 cf = 1.0;
542 }
543 double llk = cf * log(d) + (eo2 + sa2 * kect2) * id - inot;
544
545
546 double f = dinot * eo;
547 double phi = kect.phase();
548 double sinPhi = sin(phi);
549 double cosPhi = cos(phi);
550 fomPhi[i][0] = dinot;
551 fomPhi[i][1] = phi;
552 mfo.re(f * cosPhi);
553 mfo.im(f * sinPhi);
554 mfo2.re(2.0 * f * cosPhi);
555 mfo2.im(2.0 * f * sinPhi);
556 akect = kect.abs();
557 dfcc.re(sai * akect * cosPhi);
558 dfcc.im(sai * akect * sinPhi);
559
560 foFc1[i][0] = 0.0;
561 foFc1[i][1] = 0.0;
562 foFc2[i][0] = 0.0;
563 foFc2[i][1] = 0.0;
564 dFc[i][0] = 0.0;
565 dFc[i][1] = 0.0;
566 dFs[i][0] = 0.0;
567 dFs[i][1] = 0.0;
568 if (isNaN(fcTot[i][0])) {
569 if (!isNaN(fSigF[i][0])) {
570 foFc2[i][0] = mfo.re() * iSqrtEOScale;
571 foFc2[i][1] = mfo.im() * iSqrtEOScale;
572 }
573 continue;
574 }
575 if (isNaN(fSigF[i][0])) {
576 if (!isNaN(fcTot[i][0])) {
577 foFc2[i][0] = dfcc.re() * iSqrtEOScale;
578 foFc2[i][1] = dfcc.im() * iSqrtEOScale;
579 }
580 continue;
581 }
582
583 fcTot[i][0] = kfct.re();
584 fcTot[i][1] = kfct.im();
585
586 resc.copy(mfo);
587 resc.minusIP(dfcc);
588 foFc1[i][0] = resc.re() * iSqrtEOScale;
589 foFc1[i][1] = resc.im() * iSqrtEOScale;
590
591 resc.copy(mfo2);
592 resc.minusIP(dfcc);
593 foFc2[i][0] = resc.re() * iSqrtEOScale;
594 foFc2[i][1] = resc.im() * iSqrtEOScale;
595
596
597 double dafct = d * fct.abs();
598 double idafct = 1.0 / dafct;
599 double dfp1 = 2.0 * sa2 * km2 * ecscale;
600 double dfp2 = 2.0 * eo * sai * kmems * sqrt(ecscale);
601 double dfp1id = dfp1 * id;
602 double dfp2id = dfp2 * idafct * dinot;
603 double dfp12 = dfp1id - dfp2id;
604 double dfp21 = ksebs * (dfp2id - dfp1id);
605 double dfcr = fct.re() * dfp12;
606 double dfci = fct.im() * dfp12;
607 double dfsr = fct.re() * dfp21;
608 double dfsi = fct.im() * dfp21;
609 double dfsa = 2.0 * (sai * kect2 - eo * akect * dinot) * id;
610 double dfwa =
611 epsc * (cf * id - (eo2 + sa2 * kect2) * id2 + 2.0 * eo * sai * akect * id2 * dinot);
612
613
614 dFc[i][0] = dfcr * dfScale;
615 dFc[i][1] = dfci * dfScale;
616 dFs[i][0] = dfsr * dfScale;
617 dFs[i][1] = dfsi * dfScale;
618
619
620 if (refinementData.isFreeR(i)) {
621 lSum += llk;
622 lSumN++;
623 } else {
624 lSumR += llk;
625 lSumRN++;
626 dfsa = dfwa = 0.0;
627 }
628
629 if (gradient) {
630 int i0 = spline.i0();
631 int i1 = spline.i1();
632 int i2 = spline.i2();
633 double g0 = spline.dfi0();
634 double g1 = spline.dfi1();
635 double g2 = spline.dfi2();
636
637 lGrad[i0] += dfsa * g0;
638 lGrad[i1] += dfsa * g1;
639 lGrad[i2] += dfsa * g2;
640
641 lGrad[nBins + i0] += dfwa * g0;
642 lGrad[nBins + i1] += dfwa * g1;
643 lGrad[nBins + i2] += dfwa * g2;
644 }
645 }
646 }
647
648 @Override
649 public void start() {
650 lSum = 0.0;
651 lSumR = 0.0;
652 lSumN = 0;
653 lSumRN = 0;
654 fill(lGrad, 0.0);
655 }
656 }
657 }
658 }