1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import static ffx.numerics.math.DoubleMath.dot;
41 import static ffx.numerics.math.MatrixMath.mat3Mat3;
42 import static ffx.numerics.math.MatrixMath.mat3SymVec6;
43 import static ffx.numerics.math.MatrixMath.transpose3;
44 import static ffx.numerics.math.MatrixMath.vec3Mat3;
45 import static java.lang.Double.isNaN;
46 import static java.util.Arrays.fill;
47 import static org.apache.commons.math3.util.FastMath.PI;
48 import static org.apache.commons.math3.util.FastMath.abs;
49 import static org.apache.commons.math3.util.FastMath.exp;
50
51 import edu.rit.pj.IntegerForLoop;
52 import edu.rit.pj.ParallelRegion;
53 import edu.rit.pj.ParallelTeam;
54 import edu.rit.pj.reduction.SharedDouble;
55 import edu.rit.pj.reduction.SharedDoubleArray;
56 import ffx.crystal.Crystal;
57 import ffx.crystal.HKL;
58 import ffx.crystal.ReflectionList;
59 import ffx.numerics.OptimizationInterface;
60 import ffx.numerics.math.ComplexNumber;
61 import ffx.xray.CrystalReciprocalSpace.SolventModel;
62
63 import java.util.logging.Logger;
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79 public class ScaleBulkEnergy implements OptimizationInterface {
80
81 private static final Logger logger = Logger.getLogger(ScaleBulkEnergy.class.getName());
82 private static final double twopi2 = 2.0 * PI * PI;
83 private static final double[][] u11 = {{1.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}};
84 private static final double[][] u22 = {{0.0, 0.0, 0.0}, {0.0, 1.0, 0.0}, {0.0, 0.0, 0.0}};
85 private static final double[][] u33 = {{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 1.0}};
86 private static final double[][] u12 = {{0.0, 1.0, 0.0}, {1.0, 0.0, 0.0}, {0.0, 0.0, 0.0}};
87 private static final double[][] u13 = {{0.0, 0.0, 1.0}, {0.0, 0.0, 0.0}, {1.0, 0.0, 0.0}};
88 private static final double[][] u23 = {{0.0, 0.0, 0.0}, {0.0, 0.0, 1.0}, {0.0, 1.0, 0.0}};
89 private final double[][] recipt;
90 private final double[][] j11;
91 private final double[][] j22;
92 private final double[][] j33;
93 private final double[][] j12;
94 private final double[][] j13;
95 private final double[][] j23;
96
97 private final ReflectionList reflectionList;
98 private final Crystal crystal;
99 private final DiffractionRefinementData refinementData;
100 private final double[][] fc;
101 private final double[][] fcTot;
102 private final double[][] fSigF;
103 private final int n;
104 private final int solventN;
105 private final ParallelTeam parallelTeam;
106 private final ScaleBulkEnergyRegion scaleBulkEnergyRegion;
107 private double[] optimizationScaling = null;
108 private double totalEnergy;
109
110
111
112
113
114
115
116
117
118
119 ScaleBulkEnergy(
120 ReflectionList reflectionList,
121 DiffractionRefinementData refinementData,
122 int n,
123 ParallelTeam parallelTeam) {
124 this.reflectionList = reflectionList;
125 this.crystal = reflectionList.crystal;
126 this.refinementData = refinementData;
127 this.fc = refinementData.fc;
128 this.fcTot = refinementData.fcTot;
129 this.fSigF = refinementData.fSigF;
130 this.n = n;
131 this.solventN = n - refinementData.nScale;
132
133 recipt = transpose3(crystal.A);
134 j11 = mat3Mat3(mat3Mat3(crystal.A, u11), recipt);
135 j22 = mat3Mat3(mat3Mat3(crystal.A, u22), recipt);
136 j33 = mat3Mat3(mat3Mat3(crystal.A, u33), recipt);
137 j12 = mat3Mat3(mat3Mat3(crystal.A, u12), recipt);
138 j13 = mat3Mat3(mat3Mat3(crystal.A, u13), recipt);
139 j23 = mat3Mat3(mat3Mat3(crystal.A, u23), recipt);
140
141 int threadCount = parallelTeam.getThreadCount();
142 this.parallelTeam = parallelTeam;
143 scaleBulkEnergyRegion = new ScaleBulkEnergyRegion(threadCount);
144 }
145
146
147
148
149 @Override
150 public boolean destroy() {
151
152 return true;
153 }
154
155
156
157
158 @Override
159 public double energy(double[] x) {
160 unscaleCoordinates(x);
161 double sum = target(x, null, false, false);
162 scaleCoordinates(x);
163 return sum;
164 }
165
166
167
168
169 @Override
170 public double energyAndGradient(double[] x, double[] g) {
171 unscaleCoordinates(x);
172 double sum = target(x, g, true, false);
173 scaleCoordinatesAndGradient(x, g);
174 return sum;
175 }
176
177
178
179
180 @Override
181 public double[] getCoordinates(double[] parameters) {
182 throw new UnsupportedOperationException("Not supported yet.");
183 }
184
185
186
187
188 @Override
189 public int getNumberOfVariables() {
190 throw new UnsupportedOperationException("Not supported yet.");
191 }
192
193
194
195
196 @Override
197 public double[] getScaling() {
198 return optimizationScaling;
199 }
200
201
202
203
204 @Override
205 public void setScaling(double[] scaling) {
206 if (scaling != null && scaling.length == n) {
207 optimizationScaling = scaling;
208 } else {
209 optimizationScaling = null;
210 }
211 }
212
213
214
215
216 @Override
217 public double getTotalEnergy() {
218 return totalEnergy;
219 }
220
221
222
223
224
225
226
227
228
229
230 public double target(double[] x, double[] g, boolean gradient, boolean print) {
231
232 try {
233 scaleBulkEnergyRegion.init(x, g, gradient);
234 parallelTeam.execute(scaleBulkEnergyRegion);
235 } catch (Exception e) {
236 logger.info(e.toString());
237 }
238
239 double sum = scaleBulkEnergyRegion.sum.get();
240 double sumfo = scaleBulkEnergyRegion.sumFo.get();
241 double r = scaleBulkEnergyRegion.r.get();
242 double rf = scaleBulkEnergyRegion.rf.get();
243 double rfree = scaleBulkEnergyRegion.rFree.get();
244 double rfreef = scaleBulkEnergyRegion.rFreeF.get();
245
246 if (gradient) {
247 double isumfo = 1.0 / sumfo;
248 for (int i = 0; i < g.length; i++) {
249 g[i] *= isumfo;
250 }
251 }
252
253 if (print) {
254 StringBuilder sb = new StringBuilder("\n");
255 sb.append("Bulk solvent and scale fit\n");
256 sb.append(String.format(" residual: %8.3f\n", sum / sumfo));
257 sb.append(
258 String.format(
259 " R: %8.3f Rfree: %8.3f\n", (r / rf) * 100.0, (rfree / rfreef) * 100.0));
260 sb.append("x: ");
261 for (double x1 : x) {
262 sb.append(String.format("%8g ", x1));
263 }
264 sb.append("\ng: ");
265 for (double v : g) {
266 sb.append(String.format("%8g ", v));
267 }
268 sb.append("\n");
269 logger.info(sb.toString());
270 }
271 totalEnergy = sum / sumfo;
272 return sum / sumfo;
273 }
274
275 private class ScaleBulkEnergyRegion extends ParallelRegion {
276
277 private final double[] modelB = new double[6];
278 private final double[][] uStar = new double[3][3];
279 private final double[][] resM = new double[3][3];
280 boolean gradient = true;
281 double[] x;
282 double[] g;
283 double solventK;
284 double modelK;
285 double solventUEq;
286 SharedDouble r;
287 SharedDouble rf;
288 SharedDouble rFree;
289 SharedDouble rFreeF;
290 SharedDouble sum;
291 SharedDouble sumFo;
292 SharedDoubleArray grad;
293 ScaleBulkEnergyLoop[] scaleBulkEnergyLoop;
294
295 ScaleBulkEnergyRegion(int nThreads) {
296 scaleBulkEnergyLoop = new ScaleBulkEnergyLoop[nThreads];
297 r = new SharedDouble();
298 rf = new SharedDouble();
299 rFree = new SharedDouble();
300 rFreeF = new SharedDouble();
301 sum = new SharedDouble();
302 sumFo = new SharedDouble();
303 }
304
305 @Override
306 public void finish() {
307 if (gradient) {
308 for (int i = 0; i < g.length; i++) {
309 g[i] = grad.get(i);
310 }
311 }
312 }
313
314 public void init(double[] x, double[] g, boolean gradient) {
315 this.x = x;
316 this.g = g;
317 this.gradient = gradient;
318 }
319
320 @Override
321 public void run() {
322 int ti = getThreadIndex();
323 if (scaleBulkEnergyLoop[ti] == null) {
324 scaleBulkEnergyLoop[ti] = new ScaleBulkEnergyLoop();
325 }
326
327 try {
328 execute(0, reflectionList.hklList.size() - 1, scaleBulkEnergyLoop[ti]);
329 } catch (Exception e) {
330 logger.info(e.toString());
331 }
332 }
333
334 @Override
335 public void start() {
336 r.set(0.0);
337 rf.set(0.0);
338 rFree.set(0.0);
339 rFreeF.set(0.0);
340 sum.set(0.0);
341 sumFo.set(0.0);
342
343 for (int i = 0; i < 6; i++) {
344 if (crystal.scaleB[i] >= 0) {
345 modelB[i] = x[solventN + crystal.scaleB[i]];
346 }
347 }
348
349 modelK = x[0];
350 solventK = refinementData.bulkSolventK;
351 solventUEq = refinementData.bulkSolventUeq;
352 if (solventN > 1) {
353 solventK = x[1];
354 solventUEq = x[2];
355 }
356
357
358 mat3SymVec6(crystal.A, modelB, resM);
359 mat3Mat3(resM, recipt, uStar);
360
361 if (gradient) {
362 if (grad == null) {
363 grad = new SharedDoubleArray(g.length);
364 }
365 for (int i = 0; i < g.length; i++) {
366 grad.set(i, 0.0);
367 }
368 }
369 }
370
371 private class ScaleBulkEnergyLoop extends IntegerForLoop {
372
373 private final double[] resv = new double[3];
374 private final double[] ihc = new double[3];
375 private final ComplexNumber resc = new ComplexNumber();
376 private final ComplexNumber fcc = new ComplexNumber();
377 private final ComplexNumber fsc = new ComplexNumber();
378 private final ComplexNumber fct = new ComplexNumber();
379 private final ComplexNumber kfct = new ComplexNumber();
380 private final double[] lgrad;
381 private double lr;
382 private double lrf;
383 private double lrfree;
384 private double lrfreef;
385 private double lsum;
386 private double lsumfo;
387
388 ScaleBulkEnergyLoop() {
389 lgrad = new double[g.length];
390 }
391
392 @Override
393 public void finish() {
394 r.addAndGet(lr);
395 rf.addAndGet(lrf);
396 rFree.addAndGet(lrfree);
397 rFreeF.addAndGet(lrfreef);
398 sum.addAndGet(lsum);
399 sumFo.addAndGet(lsumfo);
400 for (int i = 0; i < lgrad.length; i++) {
401 grad.getAndAdd(i, lgrad[i]);
402 }
403 }
404
405 @Override
406 public void run(int lb, int ub) {
407
408 for (int j = lb; j <= ub; j++) {
409 HKL ih = reflectionList.hklList.get(j);
410 int i = ih.getIndex();
411 if (isNaN(fc[i][0]) || isNaN(fSigF[i][0]) || fSigF[i][1] <= 0.0) {
412 continue;
413 }
414
415
416 double s = crystal.invressq(ih);
417 ihc[0] = ih.getH();
418 ihc[1] = ih.getK();
419 ihc[2] = ih.getL();
420 vec3Mat3(ihc, uStar, resv);
421 double u = modelK - dot(resv, ihc);
422 double expBS = exp(-twopi2 * solventUEq * s);
423 double ksExpBS = solventK * expBS;
424 double expU = exp(0.25 * u);
425
426
427 refinementData.getFcIP(i, fcc);
428 refinementData.getFsIP(i, fsc);
429 fct.copy(fcc);
430 if (refinementData.crystalReciprocalSpaceFs.solventModel != SolventModel.NONE) {
431 resc.copy(fsc);
432 resc.timesIP(ksExpBS);
433 fct.plusIP(resc);
434 }
435 kfct.copy(fct);
436 kfct.timesIP(expU);
437
438
439 fcTot[i][0] = kfct.re();
440 fcTot[i][1] = kfct.im();
441
442
443 double f1 = refinementData.getF(i);
444 double akfct = kfct.abs();
445 double af1 = abs(f1);
446 double d = f1 - akfct;
447 double d2 = d * d;
448 double dr = -2.0 * d;
449
450 lsum += d2;
451 lsumfo += f1 * f1;
452
453 if (refinementData.isFreeR(i)) {
454 lrfree += abs(af1 - abs(akfct));
455 lrfreef += af1;
456 } else {
457 lr += abs(af1 - abs(akfct));
458 lrf += af1;
459 }
460
461 if (gradient) {
462
463 double dfm = 0.25 * akfct * dr;
464
465 double afsc = fsc.abs();
466 double dfb =
467 expBS * (fcc.re() * fsc.re() + fcc.im() * fsc.im() + ksExpBS * afsc * afsc);
468
469
470 lgrad[0] += dfm;
471 if (solventN > 1) {
472 double iafct = 1.0 / fct.abs();
473
474 lgrad[1] += expU * dfb * dr * iafct;
475
476 lgrad[2] += expU * -twopi2 * s * solventK * dfb * dr * iafct;
477 }
478
479 for (int jj = 0; jj < 6; jj++) {
480 if (crystal.scaleB[jj] >= 0) {
481 switch (jj) {
482 case (0) -> {
483
484 vec3Mat3(ihc, j11, resv);
485 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
486 }
487 case (1) -> {
488
489 vec3Mat3(ihc, j22, resv);
490 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
491 }
492 case (2) -> {
493
494 vec3Mat3(ihc, j33, resv);
495 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
496 }
497 case (3) -> {
498
499 vec3Mat3(ihc, j12, resv);
500 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
501 }
502 case (4) -> {
503
504 vec3Mat3(ihc, j13, resv);
505 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
506 }
507 case (5) -> {
508
509 vec3Mat3(ihc, j23, resv);
510 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
511 }
512 }
513 }
514 }
515 }
516 }
517 }
518
519 @Override
520 public void start() {
521 lr = 0.0;
522 lrf = 0.0;
523 lrfree = 0.0;
524 lrfreef = 0.0;
525 lsum = 0.0;
526 lsumfo = 0.0;
527 fill(lgrad, 0.0);
528 }
529 }
530 }
531 }