1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
import edu.rit.pj.IntegerForLoop;
import edu.rit.pj.ParallelRegion;
import edu.rit.pj.ParallelTeam;
import edu.rit.pj.reduction.SharedDouble;
import edu.rit.pj.reduction.SharedDoubleArray;
import ffx.crystal.Crystal;
import ffx.crystal.HKL;
import ffx.crystal.ReflectionList;
import ffx.numerics.OptimizationInterface;
import ffx.numerics.math.ComplexNumber;
import ffx.xray.solvent.SolventModel;

import javax.annotation.Nullable;
import java.util.logging.Level;
import java.util.logging.Logger;

import static ffx.numerics.math.DoubleMath.dot;
import static ffx.numerics.math.MatrixMath.mat3Mat3;
import static ffx.numerics.math.MatrixMath.mat3SymVec6;
import static ffx.numerics.math.MatrixMath.transpose3;
import static ffx.numerics.math.MatrixMath.vec3Mat3;
import static java.lang.Double.isNaN;
import static java.lang.String.format;
import static java.util.Arrays.fill;
import static org.apache.commons.math3.util.FastMath.PI;
import static org.apache.commons.math3.util.FastMath.abs;
import static org.apache.commons.math3.util.FastMath.exp;
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82 public class ScaleBulkEnergy implements OptimizationInterface {
83
84 private static final Logger logger = Logger.getLogger(ScaleBulkEnergy.class.getName());
85 private static final double twopi2 = 2.0 * PI * PI;
86 private static final double[] v000 = {0.0, 0.0, 0.0};
87 private static final double[] v100 = {1.0, 0.0, 0.0};
88 private static final double[] v010 = {0.0, 1.0, 0.0};
89 private static final double[] v001 = {0.0, 0.0, 1.0};
90 private static final double[][] u11 = {v100, v000, v000};
91 private static final double[][] u22 = {v000, v010, v000};
92 private static final double[][] u33 = {v000, v000, v001};
93 private static final double[][] u12 = {v010, v100, v000};
94 private static final double[][] u13 = {v001, v000, v100};
95 private static final double[][] u23 = {v000, v001, v010};
96 private final double[][] recipt;
97 private final double[][] j11;
98 private final double[][] j22;
99 private final double[][] j33;
100 private final double[][] j12;
101 private final double[][] j13;
102 private final double[][] j23;
103
104 private final ReflectionList reflectionList;
105 private final Crystal crystal;
106 private final DiffractionRefinementData refinementData;
107 private final double[][] fc;
108 private final double[][] fcTot;
109 private final double[][] fSigF;
110 private final int n;
111 private final int solventN;
112 private final ParallelTeam parallelTeam;
113 private final ScaleBulkEnergyRegion scaleBulkEnergyRegion;
114 private double[] optimizationScaling = null;
115 private double totalEnergy;
116 private double R;
117 private double Rfree;
118
119
120
121
122
123
124
125
126
127 ScaleBulkEnergy(ReflectionList reflectionList, DiffractionRefinementData refinementData,
128 int n, ParallelTeam parallelTeam) {
129 this.reflectionList = reflectionList;
130 this.crystal = reflectionList.crystal;
131 this.refinementData = refinementData;
132 this.fc = refinementData.fc;
133 this.fcTot = refinementData.fcTot;
134 this.fSigF = refinementData.fSigF;
135 this.n = n;
136 this.solventN = n - refinementData.nScale;
137
138 recipt = transpose3(crystal.A);
139 j11 = mat3Mat3(mat3Mat3(crystal.A, u11), recipt);
140 j22 = mat3Mat3(mat3Mat3(crystal.A, u22), recipt);
141 j33 = mat3Mat3(mat3Mat3(crystal.A, u33), recipt);
142 j12 = mat3Mat3(mat3Mat3(crystal.A, u12), recipt);
143 j13 = mat3Mat3(mat3Mat3(crystal.A, u13), recipt);
144 j23 = mat3Mat3(mat3Mat3(crystal.A, u23), recipt);
145
146 int threadCount = parallelTeam.getThreadCount();
147 this.parallelTeam = parallelTeam;
148 scaleBulkEnergyRegion = new ScaleBulkEnergyRegion(threadCount);
149 }
150
151
152
153
154 @Override
155 public boolean destroy() {
156
157 return true;
158 }
159
160
161
162
163 @Override
164 public double energy(double[] x) {
165 unscaleCoordinates(x);
166 double sum = target(x, null, false, false);
167 scaleCoordinates(x);
168 return sum;
169 }
170
171
172
173
174 @Override
175 public double energyAndGradient(double[] x, double[] g) {
176 unscaleCoordinates(x);
177 double sum = target(x, g, true, false);
178 scaleCoordinatesAndGradient(x, g);
179 return sum;
180 }
181
182 public double getR() {
183 return R;
184 }
185 public double getRfree() {
186 return Rfree;
187 }
188
189
190
191
192 @Override
193 public double[] getCoordinates(double[] parameters) {
194 throw new UnsupportedOperationException("Not supported yet.");
195 }
196
197
198
199
200 @Override
201 public void setCoordinates(double[] parameters) {
202 throw new UnsupportedOperationException("Not supported yet.");
203 }
204
205
206
207
208 @Override
209 public int getNumberOfVariables() {
210 throw new UnsupportedOperationException("Not supported yet.");
211 }
212
213
214
215
216 @Override
217 public double[] getScaling() {
218 return optimizationScaling;
219 }
220
221
222
223
224 @Override
225 public void setScaling(@Nullable double[] scaling) {
226 if (scaling != null && scaling.length == n) {
227 optimizationScaling = scaling;
228 } else {
229 optimizationScaling = null;
230 }
231 }
232
233
234
235
236 @Override
237 public double getTotalEnergy() {
238 return totalEnergy;
239 }
240
241
242
243
244
245
246
247
248
249
250 public double target(double[] x, double[] g, boolean gradient, boolean print) {
251 try {
252 scaleBulkEnergyRegion.init(x, g, gradient);
253 parallelTeam.execute(scaleBulkEnergyRegion);
254 } catch (Exception e) {
255 logger.severe(e.toString());
256 }
257
258 double sum = scaleBulkEnergyRegion.sum.get();
259 double sumfo = scaleBulkEnergyRegion.sumFo.get();
260 double r = scaleBulkEnergyRegion.r.get();
261 double rf = scaleBulkEnergyRegion.rf.get();
262 double rfree = scaleBulkEnergyRegion.rFree.get();
263 double rfreef = scaleBulkEnergyRegion.rFreeF.get();
264
265 R = (r / rf) * 100.0;
266 Rfree = (rfree / rfreef) * 100.0;
267 totalEnergy = sum / sumfo;
268
269 if (gradient) {
270 double isumfo = 1.0 / sumfo;
271 for (int i = 0; i < g.length; i++) {
272 g[i] *= isumfo;
273 }
274 }
275
276 if (print) {
277 StringBuilder sb = new StringBuilder("\n");
278 sb.append(" Bulk solvent and scale fit\n");
279 sb.append(format(" Residual: %10.5f R: %8.3f Rfree: %8.3f\n", totalEnergy, R, Rfree));
280 sb.append(" Params: ");
281 for (double x1 : x) {
282 sb.append(format("%16.8f ", x1));
283 }
284 if (gradient) {
285 sb.append("\n Gradient: ");
286 for (double v : g) {
287 sb.append(format("%16.8f ", v));
288 }
289 }
290 sb.append("\n");
291 logger.info(sb.toString());
292 }
293
294 return totalEnergy;
295 }
296
297 private class ScaleBulkEnergyRegion extends ParallelRegion {
298
299 private final double[] modelB = new double[6];
300 private final double[][] uStar = new double[3][3];
301 private final double[][] resM = new double[3][3];
302 boolean gradient = true;
303 double[] x;
304 double[] g;
305 double solventK;
306 double modelK;
307 double solventUEq;
308 SharedDouble r;
309 SharedDouble rf;
310 SharedDouble rFree;
311 SharedDouble rFreeF;
312 SharedDouble sum;
313 SharedDouble sumFo;
314 SharedDoubleArray grad;
315 ScaleBulkEnergyLoop[] scaleBulkEnergyLoop;
316
317 ScaleBulkEnergyRegion(int nThreads) {
318 scaleBulkEnergyLoop = new ScaleBulkEnergyLoop[nThreads];
319 r = new SharedDouble();
320 rf = new SharedDouble();
321 rFree = new SharedDouble();
322 rFreeF = new SharedDouble();
323 sum = new SharedDouble();
324 sumFo = new SharedDouble();
325 }
326
327 @Override
328 public void finish() {
329 if (gradient) {
330 for (int i = 0; i < g.length; i++) {
331 g[i] = grad.get(i);
332 }
333 }
334 }
335
336 public void init(double[] x, double[] g, boolean gradient) {
337 this.x = x;
338 this.g = g;
339 this.gradient = gradient;
340 }
341
342 @Override
343 public void run() {
344 int ti = getThreadIndex();
345 if (scaleBulkEnergyLoop[ti] == null) {
346 scaleBulkEnergyLoop[ti] = new ScaleBulkEnergyLoop();
347 }
348
349 try {
350 execute(0, reflectionList.hklList.size() - 1, scaleBulkEnergyLoop[ti]);
351 } catch (Exception e) {
352 logger.info(e.toString());
353 }
354 }
355
356 @Override
357 public void start() {
358 r.set(0.0);
359 rf.set(0.0);
360 rFree.set(0.0);
361 rFreeF.set(0.0);
362 sum.set(0.0);
363 sumFo.set(0.0);
364
365 for (int i = 0; i < 6; i++) {
366 if (crystal.scaleB[i] >= 0) {
367 modelB[i] = x[solventN + crystal.scaleB[i]];
368 }
369 }
370
371 modelK = x[0];
372 solventK = refinementData.bulkSolventK;
373 solventUEq = refinementData.bulkSolventUeq;
374 if (solventN > 1) {
375 solventK = x[1];
376 solventUEq = x[2];
377 }
378
379
380 mat3SymVec6(crystal.A, modelB, resM);
381 mat3Mat3(resM, recipt, uStar);
382
383 if (gradient) {
384 if (grad == null) {
385 grad = new SharedDoubleArray(g.length);
386 }
387 for (int i = 0; i < g.length; i++) {
388 grad.set(i, 0.0);
389 }
390 }
391 }
392
393 private class ScaleBulkEnergyLoop extends IntegerForLoop {
394
395 private final double[] resv = new double[3];
396 private final double[] ihc = new double[3];
397 private final ComplexNumber resc = new ComplexNumber();
398 private final ComplexNumber fcc = new ComplexNumber();
399 private final ComplexNumber fsc = new ComplexNumber();
400 private final ComplexNumber fct = new ComplexNumber();
401 private final ComplexNumber kfct = new ComplexNumber();
402 private final double[] lgrad;
403 private double lr;
404 private double lrf;
405 private double lrfree;
406 private double lrfreef;
407 private double lsum;
408 private double lsumfo;
409
410 ScaleBulkEnergyLoop() {
411 lgrad = new double[n];
412 }
413
414 @Override
415 public void finish() {
416 r.addAndGet(lr);
417 rf.addAndGet(lrf);
418 rFree.addAndGet(lrfree);
419 rFreeF.addAndGet(lrfreef);
420 sum.addAndGet(lsum);
421 sumFo.addAndGet(lsumfo);
422 if (gradient) {
423 for (int i = 0; i < lgrad.length; i++) {
424 grad.getAndAdd(i, lgrad[i]);
425 }
426 }
427 }
428
429 @Override
430 public void run(int lb, int ub) {
431
432 for (int j = lb; j <= ub; j++) {
433 HKL ih = reflectionList.hklList.get(j);
434 int i = ih.getIndex();
435 if (isNaN(fc[i][0]) || isNaN(fSigF[i][0]) || fSigF[i][1] <= 0.0) {
436 continue;
437 }
438
439
440 double s = crystal.invressq(ih);
441 ihc[0] = ih.getH();
442 ihc[1] = ih.getK();
443 ihc[2] = ih.getL();
444 vec3Mat3(ihc, uStar, resv);
445 double u = modelK - dot(resv, ihc);
446 double expBS = exp(-twopi2 * solventUEq * s);
447 double ksExpBS = solventK * expBS;
448 double expU = exp(0.25 * u);
449
450
451 refinementData.getFcIP(i, fcc);
452 refinementData.getFsIP(i, fsc);
453 fct.copy(fcc);
454 if (refinementData.crystalReciprocalSpaceFs.getSolventModel() != SolventModel.NONE) {
455 resc.copy(fsc);
456 resc.timesIP(ksExpBS);
457 fct.plusIP(resc);
458 }
459 kfct.copy(fct);
460 kfct.timesIP(expU);
461
462
463 fcTot[i][0] = kfct.re();
464 fcTot[i][1] = kfct.im();
465
466
467 double f1 = refinementData.getF(i);
468 double akfct = kfct.abs();
469 double af1 = abs(f1);
470 double d = f1 - akfct;
471 double d2 = d * d;
472 double dr = -2.0 * d;
473
474 lsum += d2;
475 lsumfo += f1 * f1;
476
477 if (refinementData.isFreeR(i)) {
478 lrfree += abs(af1 - abs(akfct));
479 lrfreef += af1;
480 } else {
481 lr += abs(af1 - abs(akfct));
482 lrf += af1;
483 }
484
485 if (gradient) {
486
487 double dfm = 0.25 * akfct * dr;
488
489 double afsc = fsc.abs();
490 double dfb = expBS * (fcc.re() * fsc.re() + fcc.im() * fsc.im() + ksExpBS * afsc * afsc);
491
492
493 lgrad[0] += dfm;
494 if (solventN > 1) {
495 double iafct = 1.0 / fct.abs();
496
497 lgrad[1] += expU * dfb * dr * iafct;
498
499 lgrad[2] += expU * -twopi2 * s * solventK * dfb * dr * iafct;
500 }
501
502 for (int jj = 0; jj < 6; jj++) {
503 if (crystal.scaleB[jj] >= 0) {
504 switch (jj) {
505 case (0) -> {
506
507 vec3Mat3(ihc, j11, resv);
508 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
509 }
510 case (1) -> {
511
512 vec3Mat3(ihc, j22, resv);
513 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
514 }
515 case (2) -> {
516
517 vec3Mat3(ihc, j33, resv);
518 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
519 }
520 case (3) -> {
521
522 vec3Mat3(ihc, j12, resv);
523 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
524 }
525 case (4) -> {
526
527 vec3Mat3(ihc, j13, resv);
528 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
529 }
530 case (5) -> {
531
532 vec3Mat3(ihc, j23, resv);
533 lgrad[solventN + crystal.scaleB[jj]] += -dfm * dot(resv, ihc);
534 }
535 }
536 }
537 }
538 }
539 }
540 }
541
542 @Override
543 public void start() {
544 lr = 0.0;
545 lrf = 0.0;
546 lrfree = 0.0;
547 lrfreef = 0.0;
548 lsum = 0.0;
549 lsumfo = 0.0;
550 fill(lgrad, 0.0);
551 }
552 }
553 }
554 }