1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import ffx.crystal.Crystal;
41 import ffx.crystal.CrystalPotential;
42 import ffx.potential.bonded.Atom;
43 import ffx.potential.bonded.LambdaInterface;
44 import ffx.potential.parameters.ForceField;
45 import ffx.xray.refine.RefinementMode;
46 import ffx.xray.refine.RefinementModel;
47
48 import javax.annotation.Nullable;
49 import java.util.List;
50 import java.util.logging.Logger;
51
52 import static ffx.numerics.math.MatrixMath.determinant3;
53 import static ffx.numerics.math.ScalarMath.b2u;
54 import static ffx.numerics.math.ScalarMath.u2b;
55 import static ffx.utilities.Constants.R;
56 import static java.lang.Math.pow;
57 import static java.lang.String.format;
58 import static java.lang.System.arraycopy;
59 import static org.apache.commons.math3.util.FastMath.PI;
60
61
62
63
64
65
66
67
68 public class XRayEnergy implements LambdaInterface, CrystalPotential {
69
70 private static final Logger logger = Logger.getLogger(XRayEnergy.class.getName());
71 private static final double eightPI2 = 8.0 * PI * PI;
72 private static final double eightPI23 = eightPI2 * eightPI2 * eightPI2;
73
74 private final DiffractionData diffractionData;
75 private final RefinementModel refinementModel;
76 private final RefinementMode refinementMode;
77 private final Atom[] activeAtomArray;
78 private double[] optimizationScaling = null;
79 private final double kTbNonzero;
80 private final double kTbSimWeight;
81 private final boolean lambdaTerm;
82 private final double[] g2;
83 private final double[] dUdXdL;
84 protected double lambda = 1.0;
85 private final int nXYZ;
86 private final int nB;
87 private final int nOCC;
88 private boolean xrayTerms = true;
89 private boolean restraintTerms = true;
90 private double totalEnergy;
91 private double dEdL;
92 private STATE state = STATE.BOTH;
93
94
95
96
97
98
99 public XRayEnergy(DiffractionData diffractionData) {
100 this.diffractionData = diffractionData;
101
102 refinementModel = diffractionData.getRefinementModel();
103 refinementMode = refinementModel.getRefinementMode();
104 nXYZ = refinementModel.getNumCoordParameters();
105 nB = refinementModel.getNumBFactorParameters();
106 nOCC = refinementModel.getNumOccupancyParameters();
107
108 double temperature = 50.0;
109 kTbNonzero = R * temperature * diffractionData.getbNonZeroWeight();
110 kTbSimWeight = R * temperature * diffractionData.getbSimWeight();
111
112 ForceField forceField = diffractionData.getAssembly()[0].getForceField();
113 lambdaTerm = forceField.getBoolean("LAMBDATERM", false);
114
115 activeAtomArray = refinementModel.getActiveAtoms();
116 int count = activeAtomArray.length;
117 dUdXdL = new double[count * 3];
118 g2 = new double[count * 3];
119
120 if (refinementMode.includesBFactors()) {
121 logger.info("\n B-Factor Refinement Parameters");
122 logger.info(" Temperature: " + temperature);
123 logger.info(" Non-zero restraint weight: " + diffractionData.getbNonZeroWeight());
124 logger.info(" Similarity restraint weight: " + diffractionData.getbSimWeight());
125 }
126 }
127
128
129
130
131 @Override
132 public boolean destroy() {
133 return diffractionData.destroy();
134 }
135
136
137
138
139 @Override
140 public double energy(double[] x) {
141 double e = 0.0;
142
143
144 unscaleCoordinates(x);
145
146
147 refinementModel.setParameters(x);
148
149
150 if (refinementMode.includesCoordinates()) {
151 diffractionData.updateCoordinates();
152 }
153
154 if (xrayTerms) {
155
156 if (lambdaTerm) {
157 diffractionData.setLambdaTerm(false);
158 }
159
160
161 diffractionData.computeAtomicDensity();
162
163 e = diffractionData.computeLikelihood();
164
165 if (lambdaTerm) {
166
167
168 diffractionData.setLambdaTerm(true);
169
170
171 diffractionData.computeAtomicDensity();
172
173
174 double e2 = diffractionData.computeLikelihood();
175
176 dEdL = e - e2;
177
178 e = lambda * e + (1.0 - lambda) * e2;
179
180 diffractionData.setLambdaTerm(false);
181 }
182 }
183
184 if (restraintTerms) {
185 if (refinementMode.includesBFactors()) {
186
187 e += getBFactorRestraints(false);
188 }
189 }
190
191
192 scaleCoordinates(x);
193
194 totalEnergy = e;
195 return e;
196 }
197
198
199
200
201 @Override
202 public double energyAndGradient(double[] x, double[] g) {
203 double e = 0.0;
204
205
206 unscaleCoordinates(x);
207
208
209 refinementModel.setParameters(x);
210
211
212 refinementModel.zeroGradient();
213
214
215 if (refinementMode.includesCoordinates()) {
216 diffractionData.updateCoordinates();
217 }
218
219 if (xrayTerms) {
220
221 if (lambdaTerm) {
222 diffractionData.setLambdaTerm(false);
223 }
224
225
226 diffractionData.computeAtomicDensity();
227
228
229 e = diffractionData.computeLikelihood();
230
231
232 diffractionData.computeAtomicGradients(refinementMode);
233
234 if (lambdaTerm) {
235 logger.severe(" Lambda Refinement is not supported.");
236 int n = dUdXdL.length;
237 arraycopy(g, 0, dUdXdL, 0, n);
238
239 for (Atom a : activeAtomArray) {
240 a.setXYZGradient(0.0, 0.0, 0.0);
241 a.setLambdaXYZGradient(0.0, 0.0, 0.0);
242 }
243
244
245 diffractionData.setLambdaTerm(true);
246
247
248 diffractionData.computeAtomicDensity();
249
250
251 double e2 = diffractionData.computeLikelihood();
252
253
254 diffractionData.computeAtomicGradients(refinementMode);
255
256 dEdL = e - e2;
257 e = lambda * e + (1.0 - lambda) * e2;
258
259 for (int i = 0; i < g.length; i++) {
260 dUdXdL[i] -= g2[i];
261 g[i] = lambda * g[i] + (1.0 - lambda) * g2[i];
262 }
263
264 diffractionData.setLambdaTerm(false);
265 }
266 }
267
268 if (restraintTerms) {
269 if (refinementMode.includesBFactors()) {
270
271 e += getBFactorRestraints(true);
272 }
273 }
274
275
276 refinementModel.getGradient(g);
277
278
279 scaleCoordinatesAndGradient(x, g);
280
281 totalEnergy = e;
282 return e;
283 }
284
285
286
287
288 @Override
289 public double[] getCoordinates(double[] x) {
290 if (x == null || x.length != refinementModel.getNumParameters()) {
291 x = new double[refinementModel.getNumParameters()];
292 }
293 refinementModel.getParameters(x);
294 return x;
295 }
296
297
298
299
300 @Override
301 public void setCoordinates(double[] x) {
302 refinementModel.setParameters(x);
303 }
304
305
306
307
308 @Override
309 public Crystal getCrystal() {
310 return diffractionData.getCrystal()[0];
311 }
312
313
314
315
316 @Override
317 public void setCrystal(Crystal crystal) {
318 logger.severe(" XRayEnergy does implement setCrystal yet.");
319 }
320
321
322
323
324 @Override
325 public STATE getEnergyTermState() {
326 return state;
327 }
328
329
330
331
332 @Override
333 public void setEnergyTermState(STATE state) {
334 this.state = state;
335 switch (state) {
336 case FAST:
337 xrayTerms = false;
338 restraintTerms = true;
339 break;
340 case SLOW:
341 xrayTerms = true;
342 restraintTerms = false;
343 break;
344 default:
345 xrayTerms = true;
346 restraintTerms = true;
347 }
348 }
349
350
351
352
353 @Override
354 public double getLambda() {
355 return lambda;
356 }
357
358
359
360
361 @Override
362 public void setLambda(double lambda) {
363 if (lambda <= 1.0 && lambda >= 0.0) {
364 this.lambda = lambda;
365 } else {
366 String message = format("Lambda value %8.3f is not in the range [0..1].", lambda);
367 logger.warning(message);
368 }
369 }
370
371
372
373
374 @Override
375 public double[] getMass() {
376 double[] mass = new double[nXYZ + nB + nOCC];
377 refinementModel.getMass(mass);
378 return mass;
379 }
380
381
382
383
384 @Override
385 public int getNumberOfVariables() {
386 return nXYZ + nB + nOCC;
387 }
388
389
390
391
392 @Override
393 public double[] getScaling() {
394 return optimizationScaling;
395 }
396
397
398
399
400 @Override
401 public void setScaling(@Nullable double[] scaling) {
402 optimizationScaling = scaling;
403 }
404
405
406
407
408 @Override
409 public double getTotalEnergy() {
410 return totalEnergy;
411 }
412
413
414
415
416
417
418 @Override
419 public VARIABLE_TYPE[] getVariableTypes() {
420 VARIABLE_TYPE[] vtypes = new VARIABLE_TYPE[nXYZ + nB + nOCC];
421 int i = 0;
422 if (refinementMode.includesCoordinates()) {
423 for (Atom a : activeAtomArray) {
424 vtypes[i++] = VARIABLE_TYPE.X;
425 vtypes[i++] = VARIABLE_TYPE.Y;
426 vtypes[i++] = VARIABLE_TYPE.Z;
427 }
428 }
429 if (refinementMode.includesBFactors()) {
430 for (int j = i; j < nXYZ + nB; i++, j++) {
431 vtypes[j] = VARIABLE_TYPE.OTHER;
432 }
433 }
434 if (refinementMode.includesOccupancies()) {
435 for (int j = i; j < nXYZ + nB + nOCC; i++, j++) {
436 vtypes[j] = VARIABLE_TYPE.OTHER;
437 }
438 }
439 return vtypes;
440 }
441
442
443
444
445 @Override
446 public double[] getVelocity(double[] velocity) {
447 if (velocity == null || velocity.length != refinementModel.getNumParameters()) {
448 velocity = new double[refinementModel.getNumParameters()];
449 }
450 refinementModel.getVelocity(velocity);
451 return velocity;
452 }
453
454
455
456
457 @Override
458 public void setVelocity(double[] velocity) {
459 refinementModel.setVelocity(velocity);
460 }
461
462
463
464
465 @Override
466 public double getd2EdL2() {
467 return 0.0;
468 }
469
470
471
472
473 @Override
474 public double getdEdL() {
475 return dEdL;
476 }
477
478
479
480
481 @Override
482 public void getdEdXdL(double[] gradient) {
483 int n = dUdXdL.length;
484 arraycopy(dUdXdL, 0, gradient, 0, n);
485 }
486
487
488
489
490 @Override
491 public void setAcceleration(double[] acceleration) {
492 refinementModel.setAcceleration(acceleration);
493 }
494
495
496
497
498 @Override
499 public double[] getAcceleration(double[] acceleration) {
500 if (acceleration == null || acceleration.length != refinementModel.getNumParameters()) {
501 acceleration = new double[refinementModel.getNumParameters()];
502 }
503 refinementModel.getAcceleration(acceleration);
504 return acceleration;
505 }
506
507
508
509
510 @Override
511 public void setPreviousAcceleration(double[] previousAcceleration) {
512 refinementModel.setPreviousAcceleration(previousAcceleration);
513 }
514
515
516
517
518 @Override
519 public double[] getPreviousAcceleration(double[] previousAcceleration) {
520 if (previousAcceleration == null || previousAcceleration.length != refinementModel.getNumParameters()) {
521 previousAcceleration = new double[refinementModel.getNumParameters()];
522 }
523 refinementModel.getPreviousAcceleration(previousAcceleration);
524 return previousAcceleration;
525 }
526
527
528
529
530
531
532
533
534 private double getBFactorRestraints(boolean gradient) {
535 double e = 0.0;
536 double[] anisou1 = new double[6];
537 double[] anisou2;
538 double[] gradu = new double[6];
539 double threeHalves = 3.0 / 2.0;
540 double oneHalf = 1.0 / 2.0;
541
542
543 for (Atom a : activeAtomArray) {
544 if (a.getAnisou(null) == null) {
545
546
547
548
549 double biso = a.getTempFactor();
550 e += -kTbNonzero * Math.log(pow(biso, threeHalves));
551 if (gradient) {
552 double gradb = -kTbNonzero * threeHalves / biso;
553 a.addToTempFactorGradient(gradb);
554 }
555 } else {
556
557 anisou1 = a.getAnisou(anisou1);
558
559
560
561
562 double det = determinant3(anisou1);
563 e += u2b(-oneHalf * kTbNonzero * Math.log(det));
564 if (gradient) {
565 gradu[0] = u2b(-oneHalf * kTbNonzero * ((anisou1[1] * anisou1[2] - anisou1[5] * anisou1[5]) / det));
566 gradu[1] = u2b(-oneHalf * kTbNonzero * ((anisou1[0] * anisou1[2] - anisou1[4] * anisou1[4]) / det));
567 gradu[2] = u2b(-oneHalf * kTbNonzero * ((anisou1[0] * anisou1[1] - anisou1[3] * anisou1[3]) / det));
568 gradu[3] = u2b(-oneHalf * kTbNonzero * ((2.0 * (anisou1[4] * anisou1[5] - anisou1[3] * anisou1[2])) / det));
569 gradu[4] = u2b(-oneHalf * kTbNonzero * ((2.0 * (anisou1[3] * anisou1[5] - anisou1[4] * anisou1[1])) / det));
570 gradu[5] = u2b(-oneHalf * kTbNonzero * ((2.0 * (anisou1[3] * anisou1[4] - anisou1[5] * anisou1[0])) / det));
571 a.addToAnisouGradient(gradu);
572 }
573 }
574 }
575
576
577 List<Atom[]> bonds = refinementModel.getBFactorRestraints();
578 for (Atom[] atoms : bonds) {
579 Atom a1 = atoms[0];
580 Atom a2 = atoms[1];
581 boolean isAnisou1 = a1.getAnisou(null) != null;
582 boolean isAnisou2 = a2.getAnisou(null) != null;
583 if (!isAnisou1 && !isAnisou2) {
584
585 double b1 = a1.getTempFactor();
586 double b2 = a2.getTempFactor();
587 double bdiff = b1 - b2;
588 e += kTbSimWeight * bdiff * bdiff;
589 if (gradient) {
590 double gradb = 2.0 * kTbSimWeight * bdiff;
591 a1.addToTempFactorGradient(gradb);
592 a2.addToTempFactorGradient(-gradb);
593 }
594 } else if (isAnisou1 && isAnisou2) {
595
596 anisou1 = a1.getAnisou(anisou1);
597 anisou2 = a2.getAnisou(anisou1);
598 double det1 = determinant3(anisou1);
599 double det2 = determinant3(anisou2);
600 double bdiff = det1 - det2;
601 double bdiff2 = bdiff * bdiff;
602 e += eightPI23 * kTbSimWeight * bdiff2;
603 if (gradient) {
604 double gradb = eightPI23 * 2.0 * kTbSimWeight * bdiff;
605
606 gradu[0] = gradb * (anisou1[1] * anisou1[2] - anisou1[5] * anisou1[5]);
607 gradu[1] = gradb * (anisou1[0] * anisou1[2] - anisou1[4] * anisou1[4]);
608 gradu[2] = gradb * (anisou1[0] * anisou1[1] - anisou1[3] * anisou1[3]);
609 gradu[3] = gradb * (2.0 * (anisou1[4] * anisou1[5] - anisou1[3] * anisou1[2]));
610 gradu[4] = gradb * (2.0 * (anisou1[3] * anisou1[5] - anisou1[4] * anisou1[1]));
611 gradu[5] = gradb * (2.0 * (anisou1[3] * anisou1[4] - anisou1[5] * anisou1[0]));
612 a1.addToAnisouGradient(gradu);
613
614 gradu[0] = gradb * (anisou2[5] * anisou2[5] - anisou2[1] * anisou2[2]);
615 gradu[1] = gradb * (anisou2[4] * anisou2[4] - anisou2[0] * anisou2[2]);
616 gradu[2] = gradb * (anisou2[3] * anisou2[3] - anisou2[0] * anisou2[1]);
617 gradu[3] = gradb * (2.0 * (anisou2[3] * anisou2[2] - anisou2[4] * anisou2[5]));
618 gradu[4] = gradb * (2.0 * (anisou2[4] * anisou2[1] - anisou2[3] * anisou2[5]));
619 gradu[5] = gradb * (2.0 * (anisou2[5] * anisou2[0] - anisou2[3] * anisou2[4]));
620 a2.addToAnisouGradient(gradu);
621 }
622 } else {
623 if (!isAnisou1) {
624
625 a1 = atoms[1];
626 a2 = atoms[0];
627 }
628 anisou1 = a1.getAnisou(anisou1);
629 double u2 = b2u(a2.getTempFactor());
630 double det1 = determinant3(anisou1);
631
632 double det2 = u2 * u2 * u2;
633 double bdiff = det1 - det2;
634 double bdiff2 = bdiff * bdiff;
635 e += eightPI23 * kTbSimWeight * bdiff2;
636 if (gradient) {
637 double gradb = eightPI23 * 2.0 * kTbSimWeight * bdiff;
638
639 gradu[0] = gradb * (anisou1[1] * anisou1[2] - anisou1[5] * anisou1[5]);
640 gradu[1] = gradb * (anisou1[0] * anisou1[2] - anisou1[4] * anisou1[4]);
641 gradu[2] = gradb * (anisou1[0] * anisou1[1] - anisou1[3] * anisou1[3]);
642 gradu[3] = gradb * (2.0 * (anisou1[4] * anisou1[5] - anisou1[3] * anisou1[2]));
643 gradu[4] = gradb * (2.0 * (anisou1[3] * anisou1[5] - anisou1[4] * anisou1[1]));
644 gradu[5] = gradb * (2.0 * (anisou1[3] * anisou1[4] - anisou1[5] * anisou1[0]));
645 a1.addToAnisouGradient(gradu);
646
647 double gradBiso = u2b(-gradb * u2 * u2);
648 a2.addToTempFactorGradient(gradBiso);
649 }
650 }
651 }
652 return e;
653 }
654 }