1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.xray;
39
40 import ffx.algorithms.AlgorithmListener;
41 import ffx.algorithms.dynamics.thermostats.Thermostat;
42 import ffx.crystal.Crystal;
43 import ffx.crystal.CrystalPotential;
44 import ffx.numerics.Potential;
45 import ffx.potential.ForceFieldEnergy;
46 import ffx.potential.MolecularAssembly;
47 import ffx.potential.bonded.Atom;
48 import ffx.potential.bonded.LambdaInterface;
49 import ffx.realspace.RealSpaceData;
50 import ffx.realspace.RealSpaceEnergy;
51 import ffx.xray.refine.RefinementMode;
52 import ffx.xray.refine.RefinementModel;
53
54 import java.util.Arrays;
55 import java.util.List;
56 import java.util.logging.Logger;
57 import java.util.stream.Collectors;
58 import java.util.stream.Stream;
59
60 import static ffx.utilities.Constants.KCAL_TO_GRAM_ANG2_PER_PS2;
61 import static ffx.utilities.Constants.kB;
62 import static java.lang.String.format;
63 import static java.util.Arrays.fill;
64
65
66
67
68
69
70
71
72
73 public class RefinementEnergy implements LambdaInterface, CrystalPotential, AlgorithmListener {
74
75 private static final Logger logger = Logger.getLogger(RefinementEnergy.class.getName());
76
77
78
79
80 private final DataContainer data;
81
82
83
84 private final RefinementModel refinementModel;
85
86
87
88
89
90 private final RefinementMode refinementMode;
91
92
93
94 private double totalEnergy;
95
96
97
98 private final Atom[] scatteringAtoms;
99
100
101
102 private final int nAtoms;
103
104
105
106 private final MolecularAssembly[] molecularAssemblies;
107
108
109
110 private final double[][] xChemical;
111
112
113
114 private final double[][] gChemical;
115
116
117
118 private CrystalPotential dataEnergy;
119
120
121
122 private double[] gExperiment;
123
124
125
126 private double[] optimizationScaling;
127
128
129
130 private final int n;
131
132
133
134 private final int nXYZ;
135
136
137
138 private final int nBFactor;
139
140
141
142 private final int nOccupancy;
143
144
145
146 private STATE state = STATE.BOTH;
147
148
149
150
151 protected Thermostat thermostat;
152
153
154
155 private double kTScale;
156
157
158
159
160
161
162
163 public RefinementEnergy(DataContainer dataContainer) {
164 this.data = dataContainer;
165 refinementModel = dataContainer.getRefinementModel();
166 refinementMode = refinementModel.getRefinementMode();
167 molecularAssemblies = refinementModel.getMolecularAssemblies();
168 scatteringAtoms = refinementModel.getScatteringAtoms();
169 nAtoms = scatteringAtoms.length;
170
171 thermostat = null;
172 kTScale = 1.0;
173
174
175 nXYZ = refinementModel.getNumCoordParameters();
176 nBFactor = refinementModel.getNumBFactorParameters();
177 nOccupancy = refinementModel.getNumOccupancyParameters();
178 n = nXYZ + nBFactor + nOccupancy;
179
180
181 for (MolecularAssembly molecularAssembly : molecularAssemblies) {
182 ForceFieldEnergy forceFieldEnergy = molecularAssembly.getPotentialEnergy();
183 if (forceFieldEnergy == null) {
184 forceFieldEnergy = ForceFieldEnergy.energyFactory(molecularAssembly);
185 molecularAssembly.setPotential(forceFieldEnergy);
186 }
187 }
188
189 if (dataContainer instanceof DiffractionData diffractionData) {
190 if (!diffractionData.getScaled()[0]) {
191 diffractionData.printStats();
192 }
193 dataEnergy = new XRayEnergy(diffractionData);
194
195 dataEnergy.setScaling(null);
196 } else if (dataContainer instanceof RealSpaceData realSpaceData) {
197 dataEnergy = new RealSpaceEnergy(realSpaceData);
198
199 dataEnergy.setScaling(null);
200 }
201
202 int assemblySize = molecularAssemblies.length;
203 xChemical = new double[assemblySize][];
204 gChemical = new double[assemblySize][];
205 for (int i = 0; i < assemblySize; i++) {
206 int len = molecularAssemblies[i].getActiveAtomArray().length * 3;
207 xChemical[i] = new double[len];
208 gChemical[i] = new double[len];
209 }
210 gExperiment = new double[n];
211 }
212
213
214
215
216 @Override
217 public boolean algorithmUpdate(MolecularAssembly active) {
218 if (thermostat != null) {
219 kTScale = KCAL_TO_GRAM_ANG2_PER_PS2 / (thermostat.getTargetTemperature() * kB);
220 }
221 logger.info(" kTscale: " + kTScale);
222 logger.info(data.printEnergyUpdate());
223 return true;
224 }
225
226
227
228
229 @Override
230 public boolean destroy() {
231 return dataEnergy.destroy();
232 }
233
234
235
236
237 @Override
238 public double energy(double[] x) {
239 return energy(x, false);
240 }
241
242
243
244
245 @Override
246 public double energy(double[] x, boolean print) {
247 double weight = data.getWeight();
248 double e = 0.0;
249
250 if (thermostat != null) {
251 kTScale = KCAL_TO_GRAM_ANG2_PER_PS2 / (thermostat.getTargetTemperature() * kB);
252 }
253
254 unscaleCoordinates(x);
255 refinementModel.setParameters(x);
256 RefinementMode refinementMode = refinementModel.getRefinementMode();
257
258 int numAssemblies = molecularAssemblies.length;
259
260 if (refinementMode.includesCoordinates()) {
261
262 for (int i = 0; i < numAssemblies; i++) {
263 ForceFieldEnergy forceFieldEnergy = molecularAssemblies[i].getPotentialEnergy();
264 forceFieldEnergy.getCoordinates(xChemical[i]);
265 double curE = forceFieldEnergy.energy(xChemical[i], print);
266 e += curE;
267 }
268 e = e * kTScale / numAssemblies;
269
270 e += weight * dataEnergy.energy(x, print);
271 } else {
272
273 e = dataEnergy.energy(x, print);
274 }
275
276 scaleCoordinates(x);
277
278 totalEnergy = e;
279 return e;
280 }
281
282
283
284
285
286
287 @Override
288 public double energyAndGradient(double[] x, double[] g) {
289 return energyAndGradient(x, g, false);
290 }
291
292
293
294
295
296
297 @Override
298 public double energyAndGradient(double[] x, double[] g, boolean print) {
299 double weight = data.getWeight();
300 double e = 0.0;
301 fill(g, 0.0);
302 fill(gExperiment, 0.0);
303
304 if (thermostat != null) {
305 kTScale = KCAL_TO_GRAM_ANG2_PER_PS2 / (thermostat.getTargetTemperature() * kB);
306 }
307
308 unscaleCoordinates(x);
309 refinementModel.setParameters(x);
310
311 if (refinementMode.includesCoordinates()) {
312 int numAssemblies = molecularAssemblies.length;
313
314 for (int i = 0; i < numAssemblies; i++) {
315 ForceFieldEnergy forceFieldEnergy = molecularAssemblies[i].getPotentialEnergy();
316 forceFieldEnergy.getCoordinates(xChemical[i]);
317 double curE = forceFieldEnergy.energyAndGradient(xChemical[i], gChemical[i], print);
318 e += curE;
319
320 refinementModel.addAssemblyGradient(i, g);
321 }
322
323 e = kTScale * e / numAssemblies;
324
325 if (numAssemblies > 1) {
326 for (int i = 0; i < nXYZ; i++) {
327 g[i] /= numAssemblies;
328 }
329 }
330 for (int i = 0; i < nXYZ; i++) {
331 g[i] *= kTScale;
332 }
333
334 double xE = dataEnergy.energyAndGradient(x, gExperiment);
335 e += weight * xE;
336
337
338 for (int i = 0; i < nXYZ; i++) {
339 g[i] += weight * gExperiment[i];
340 }
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357 if (refinementMode.includesBFactors() || refinementMode.includesOccupancies()) {
358 for (int i = nXYZ; i < n; i++) {
359 g[i] = weight * gExperiment[i];
360 }
361 }
362 } else if (refinementMode.includesBFactors() || refinementMode.includesOccupancies()) {
363
364 e = dataEnergy.energyAndGradient(x, g);
365 }
366
367 scaleCoordinatesAndGradient(x, g);
368 totalEnergy = e;
369 return e;
370 }
371
372
373
374
375 @Override
376 public double[] getAcceleration(double[] acceleration) {
377 return dataEnergy.getAcceleration(acceleration);
378 }
379
380
381
382
383 @Override
384 public double[] getCoordinates(double[] parameters) {
385 return dataEnergy.getCoordinates(parameters);
386 }
387
388
389
390
391 @Override
392 public void setCoordinates(double[] parameters) {
393 dataEnergy.setCoordinates(parameters);
394 }
395
396
397
398
399 @Override
400 public Crystal getCrystal() {
401 return dataEnergy.getCrystal();
402 }
403
404
405
406
407 @Override
408 public void setCrystal(Crystal crystal) {
409 logger.severe(" RefinementEnergy does implement setCrystal yet.");
410 }
411
412
413
414
415
416
417 public CrystalPotential getDataEnergy() {
418 return dataEnergy;
419 }
420
421
422
423
424 @Override
425 public STATE getEnergyTermState() {
426 return state;
427 }
428
429
430
431
432 @Override
433 public void setEnergyTermState(STATE state) {
434 this.state = state;
435 for (MolecularAssembly molecularAssembly : molecularAssemblies) {
436 ForceFieldEnergy fe = molecularAssembly.getPotentialEnergy();
437 fe.setEnergyTermState(state);
438 }
439 dataEnergy.setEnergyTermState(state);
440 }
441
442
443
444
445 @Override
446 public double getLambda() {
447 double lambda = 1.0;
448 if (data instanceof DiffractionData) {
449 XRayEnergy xRayEnergy = (XRayEnergy) dataEnergy;
450 lambda = xRayEnergy.getLambda();
451 } else if (data instanceof RealSpaceData) {
452 RealSpaceEnergy realSpaceEnergy = (RealSpaceEnergy) dataEnergy;
453 lambda = realSpaceEnergy.getLambda();
454 }
455 return lambda;
456 }
457
458
459
460
461 @Override
462 public void setLambda(double lambda) {
463 for (MolecularAssembly molecularAssembly : molecularAssemblies) {
464 ForceFieldEnergy forceFieldEnergy = molecularAssembly.getPotentialEnergy();
465 forceFieldEnergy.setLambda(lambda);
466 }
467 if (data instanceof DiffractionData) {
468 XRayEnergy xRayEnergy = (XRayEnergy) dataEnergy;
469 xRayEnergy.setLambda(lambda);
470 } else if (data instanceof RealSpaceData) {
471 RealSpaceEnergy realSpaceEnergy = (RealSpaceEnergy) dataEnergy;
472 realSpaceEnergy.setLambda(lambda);
473 }
474 }
475
476
477
478
479 @Override
480 public double[] getMass() {
481 return dataEnergy.getMass();
482 }
483
484
485
486
487 @Override
488 public int getNumberOfVariables() {
489 return dataEnergy.getNumberOfVariables();
490 }
491
492
493
494
495 @Override
496 public double[] getPreviousAcceleration(double[] previousAcceleration) {
497 return dataEnergy.getPreviousAcceleration(previousAcceleration);
498 }
499
500
501
502
503 @Override
504 public double[] getScaling() {
505 return optimizationScaling;
506 }
507
508
509
510
511 @Override
512 public void setScaling(double[] scaling) {
513 optimizationScaling = scaling;
514 }
515
516
517
518
519
520
521 public Thermostat getThermostat() {
522 return thermostat;
523 }
524
525
526
527
528
529
530 public void setThermostat(Thermostat thermostat) {
531 this.thermostat = thermostat;
532 }
533
534
535
536
537 @Override
538 public double getTotalEnergy() {
539 return totalEnergy;
540 }
541
542 @Override
543 public List<Potential> getUnderlyingPotentials() {
544 Stream<Potential> directPEs =
545 Arrays.stream(molecularAssemblies).map(MolecularAssembly::getPotentialEnergy);
546 Stream<Potential> allPEs =
547 Arrays.stream(molecularAssemblies)
548 .map(MolecularAssembly::getPotentialEnergy)
549 .map(Potential::getUnderlyingPotentials)
550 .flatMap(List::stream);
551 return Stream.concat(directPEs, allPEs).collect(Collectors.toList());
552 }
553
554
555
556
557
558
559 @Override
560 public VARIABLE_TYPE[] getVariableTypes() {
561 return dataEnergy.getVariableTypes();
562 }
563
564
565
566
567 @Override
568 public double[] getVelocity(double[] velocity) {
569 return dataEnergy.getVelocity(velocity);
570 }
571
572
573
574
575 @Override
576 public double getd2EdL2() {
577 double d2EdL2 = 0.0;
578 if (thermostat != null) {
579 kTScale = KCAL_TO_GRAM_ANG2_PER_PS2 / (thermostat.getTargetTemperature() * kB);
580 }
581 int assemblysize = molecularAssemblies.length;
582
583
584 for (int i = 0; i < assemblysize; i++) {
585 ForceFieldEnergy forceFieldEnergy = molecularAssemblies[i].getPotentialEnergy();
586 double curE = forceFieldEnergy.getd2EdL2();
587 d2EdL2 += (curE - d2EdL2) / (i + 1);
588 }
589 d2EdL2 *= kTScale;
590
591
592 return d2EdL2;
593 }
594
595
596
597
598 @Override
599 public double getdEdL() {
600 double dEdL = 0.0;
601 if (thermostat != null) {
602 kTScale = KCAL_TO_GRAM_ANG2_PER_PS2 / (thermostat.getTargetTemperature() * kB);
603 }
604 int assemblysize = molecularAssemblies.length;
605
606
607 for (int i = 0; i < assemblysize; i++) {
608 ForceFieldEnergy forceFieldEnergy = molecularAssemblies[i].getPotentialEnergy();
609 double curdEdL = forceFieldEnergy.getdEdL();
610 dEdL += (curdEdL - dEdL) / (i + 1);
611 }
612 dEdL *= kTScale;
613 double weight = data.getWeight();
614 if (data instanceof DiffractionData) {
615 XRayEnergy xRayEnergy = (XRayEnergy) dataEnergy;
616 dEdL += weight * xRayEnergy.getdEdL();
617 } else if (data instanceof RealSpaceData) {
618 RealSpaceEnergy realSpaceEnergy = (RealSpaceEnergy) dataEnergy;
619 dEdL += weight * realSpaceEnergy.getdEdL();
620 }
621 return dEdL;
622 }
623
624
625
626
627 @Override
628 public void getdEdXdL(double[] gradient) {
629 double weight = data.getWeight();
630 if (thermostat != null) {
631 kTScale = KCAL_TO_GRAM_ANG2_PER_PS2 / (thermostat.getTargetTemperature() * kB);
632 }
633 int assemblysize = molecularAssemblies.length;
634
635
636 for (int i = 0; i < assemblysize; i++) {
637 ForceFieldEnergy forcefieldEnergy = molecularAssemblies[i].getPotentialEnergy();
638 Arrays.fill(gChemical[i], 0.0);
639 forcefieldEnergy.getdEdXdL(gChemical[i]);
640 }
641 for (int i = 0; i < assemblysize; i++) {
642 for (int j = 0; j < nXYZ; j++) {
643 gradient[j] += gChemical[i][j];
644 }
645 }
646
647
648 if (assemblysize > 1) {
649 for (int i = 0; i < nXYZ; i++) {
650 gradient[i] /= assemblysize;
651 }
652 }
653 for (int i = 0; i < nXYZ; i++) {
654 gradient[i] *= kTScale;
655 }
656
657
658 if (gExperiment == null || gExperiment.length != nXYZ) {
659 gExperiment = new double[nXYZ];
660 } else {
661 for (int j = 0; j < nXYZ; j++) {
662 gExperiment[j] = 0.0;
663 }
664 }
665 if (data instanceof DiffractionData) {
666 XRayEnergy xRayEnergy = (XRayEnergy) dataEnergy;
667 xRayEnergy.getdEdXdL(gExperiment);
668 } else if (data instanceof RealSpaceData) {
669 RealSpaceEnergy realSpaceEnergy = (RealSpaceEnergy) dataEnergy;
670 realSpaceEnergy.getdEdXdL(gExperiment);
671 }
672
673
674 for (int i = 0; i < nXYZ; i++) {
675 gradient[i] += weight * gExperiment[i];
676 }
677 }
678
679
680
681
682 @Override
683 public void setAcceleration(double[] acceleration) {
684 dataEnergy.setAcceleration(acceleration);
685 }
686
687
688
689
690 @Override
691 public void setPreviousAcceleration(double[] previousAcceleration) {
692 dataEnergy.setPreviousAcceleration(previousAcceleration);
693 }
694
695
696
697
698 @Override
699 public void setVelocity(double[] velocity) {
700 dataEnergy.setVelocity(velocity);
701 }
702 }