1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential;
39
40 import static java.lang.String.format;
41 import static java.util.Arrays.fill;
42
43 import edu.rit.pj.ParallelRegion;
44 import edu.rit.pj.ParallelSection;
45 import edu.rit.pj.ParallelTeam;
46 import ffx.crystal.Crystal;
47 import ffx.crystal.CrystalPotential;
48 import ffx.numerics.Potential;
49 import ffx.potential.bonded.LambdaInterface;
50 import ffx.potential.utils.EnergyException;
51 import java.util.ArrayList;
52 import java.util.Arrays;
53 import java.util.LinkedHashSet;
54 import java.util.List;
55 import java.util.Set;
56 import java.util.function.DoubleBinaryOperator;
57 import java.util.logging.Logger;
58
59
60
61
62
63
64
65
66
67
68
69
70 public class QuadTopologyEnergy implements CrystalPotential, LambdaInterface {
71 private static final Logger logger = Logger.getLogger(QuadTopologyEnergy.class.getName());
72 private final DualTopologyEnergy dualTopA;
73 private final DualTopologyEnergy dualTopB;
74 private final LambdaInterface linterA;
75 private final LambdaInterface linterB;
76
77 private final int nVarA;
78 private final int nVarB;
79 private final int nShared;
80 private final int uniqueA;
81 private final int uniqueB;
82 private final int nVarTot;
83
84
85
86
87
88 private final int[] indexAToGlobal;
89
90 private final int[] indexBToGlobal;
91 private final int[] indexGlobalToA;
92 private final int[] indexGlobalToB;
93
94 private final double[] mass;
95 private final double[] xA;
96 private final double[] xB;
97 private final double[] gA;
98 private final double[] gB;
99
100
101
102
103
104
105 private final double[] tempA;
106
107 private final double[] tempB;
108 private final EnergyRegion region;
109 private STATE state = STATE.BOTH;
110 private double lambda;
111 private double totalEnergy;
112 private double energyA;
113 private double energyB;
114 private double dEdL, dEdL_A, dEdL_B;
115 private double d2EdL2, d2EdL2_A, d2EdL2_B;
116 private boolean inParallel = false;
117 private ParallelTeam team;
118
119 private double[] scaling;
120
121 private VARIABLE_TYPE[] types = null;
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138 public QuadTopologyEnergy(DualTopologyEnergy dualTopologyA, DualTopologyEnergy dualTopologyB) {
139 this(dualTopologyA, dualTopologyB, null, null);
140 }
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157 public QuadTopologyEnergy(
158 DualTopologyEnergy dualTopologyA,
159 DualTopologyEnergy dualTopologyB,
160 List<Integer> uniqueAList,
161 List<Integer> uniqueBList) {
162 this.dualTopA = dualTopologyA;
163 this.dualTopB = dualTopologyB;
164 dualTopB.setCrystal(dualTopA.getCrystal());
165 dualTopB.reloadCommonMasses(true);
166 linterA = dualTopologyA;
167 linterB = dualTopologyB;
168 nVarA = dualTopA.getNumberOfVariables();
169 nVarB = dualTopB.getNumberOfVariables();
170
171
172
173
174
175
176
177 Set<Integer> uniqueASet = new LinkedHashSet<>();
178 if (uniqueAList != null) {
179 uniqueASet.addAll(uniqueAList);
180 }
181 int nCommon = dualTopA.getNumSharedVariables();
182 for (int i = nCommon; i < nVarA; i++) {
183 uniqueASet.add(i);
184 }
185
186 uniqueAList = new ArrayList<>(uniqueASet);
187 uniqueA = uniqueAList.size();
188 int[] uniquesA = new int[uniqueA];
189 for (int i = 0; i < uniqueA; i++) {
190 uniquesA[i] = uniqueAList.get(i);
191 }
192
193
194
195
196
197
198
199
200
201
202 Set<Integer> uniqueBSet = new LinkedHashSet<>();
203 if (uniqueBList != null) {
204 uniqueBSet.addAll(uniqueBList);
205 }
206 nCommon = dualTopB.getNumSharedVariables();
207 for (int i = nCommon; i < nVarB; i++) {
208 uniqueBSet.add(i);
209 }
210 uniqueBList = new ArrayList<>(uniqueBSet);
211 uniqueB = uniqueBList.size();
212 int[] uniquesB = new int[uniqueB];
213 for (int i = 0; i < uniqueB; i++) {
214 uniquesB[i] = uniqueBList.get(i);
215 }
216
217 nShared = nVarA - uniqueA;
218 assert (nShared == nVarB - uniqueB);
219 nVarTot = nShared + uniqueA + uniqueB;
220
221 indexAToGlobal = new int[nVarA];
222 indexBToGlobal = new int[nVarB];
223 indexGlobalToA = new int[nVarTot];
224 indexGlobalToB = new int[nVarTot];
225
226
227 fill(indexGlobalToA, -1);
228 fill(indexGlobalToB, -1);
229
230 if (uniqueA > 0) {
231 int commonIndex = 0;
232 int uniqueIndex = 0;
233 for (int i = 0; i < nVarA; i++) {
234 if (uniqueIndex < uniqueA && i == uniquesA[uniqueIndex]) {
235 int destIndex = nShared + uniqueIndex;
236 indexAToGlobal[i] = destIndex;
237 indexGlobalToA[destIndex] = i;
238 ++uniqueIndex;
239 } else {
240 indexAToGlobal[i] = commonIndex;
241 indexGlobalToA[commonIndex++] = i;
242 }
243 }
244 } else {
245 for (int i = 0; i < nVarA; i++) {
246 indexAToGlobal[i] = i;
247 indexGlobalToA[i] = i;
248 }
249 }
250
251 if (uniqueB > 0) {
252 int commonIndex = 0;
253 int uniqueIndex = 0;
254 for (int i = 0; i < nVarB; i++) {
255 if (uniqueIndex < uniqueB && i == uniquesB[uniqueIndex]) {
256 int destIndex = nVarA + uniqueIndex;
257 indexBToGlobal[i] = destIndex;
258 indexGlobalToB[destIndex] = i;
259 ++uniqueIndex;
260 } else {
261 indexBToGlobal[i] = commonIndex;
262 indexGlobalToB[commonIndex++] = i;
263 }
264 }
265 } else {
266 for (int i = 0; i < nVarB; i++) {
267 indexBToGlobal[i] = i;
268 indexGlobalToB[i] = i;
269 }
270 }
271
272 xA = new double[nVarA];
273 xB = new double[nVarB];
274 gA = new double[nVarA];
275 gB = new double[nVarB];
276 tempA = new double[nVarA];
277 tempB = new double[nVarB];
278 mass = new double[nVarTot];
279 doublesFromFunction(mass, dualTopA.getMass(), dualTopB.getMass(), Math::max);
280
281 region = new EnergyRegion();
282 team = new ParallelTeam(1);
283 }
284
285
286
287
288
289
290 @Override
291 public boolean dEdLZeroAtEnds() {
292 if (!dualTopA.dEdLZeroAtEnds() || !dualTopB.dEdLZeroAtEnds()) {
293 return false;
294 }
295 return true;
296 }
297
298
299 @Override
300 public boolean destroy() {
301 boolean dtADestroy = dualTopA.destroy();
302 boolean dtBDestroy = dualTopB.destroy();
303 try {
304 if (team != null) {
305 team.shutdown();
306 }
307 return dtADestroy && dtBDestroy;
308 } catch (Exception ex) {
309 logger.warning(format(" Exception in shutting down QuadTopologyEnergy: %s", ex));
310 logger.info(Utilities.stackTraceToString(ex));
311 return false;
312 }
313 }
314
315
316 @Override
317 public double energy(double[] x) {
318 return energy(x, false);
319 }
320
321
322 @Override
323 public double energy(double[] x, boolean verbose) {
324 region.setX(x);
325 region.setVerbose(verbose);
326 try {
327 team.execute(region);
328 } catch (Exception ex) {
329 throw new EnergyException(format(" Exception in calculating quad-topology energy: %s", ex));
330 }
331
332 if (verbose) {
333 logger.info(format(" Total quad-topology energy: %12.4f", totalEnergy));
334 }
335 return totalEnergy;
336 }
337
338
339 @Override
340 public double energyAndGradient(double[] x, double[] g) {
341 return energyAndGradient(x, g, false);
342 }
343
344
345 @Override
346 public double energyAndGradient(double[] x, double[] g, boolean verbose) {
347 assert Arrays.stream(x).allMatch(Double::isFinite);
348 region.setX(x);
349 region.setG(g);
350 region.setVerbose(verbose);
351 try {
352 team.execute(region);
353 } catch (Exception ex) {
354 throw new EnergyException(format(" Exception in calculating quad-topology energy: %s", ex));
355 }
356
357 if (verbose) {
358 logger.info(format(" Total quad-topology energy: %12.4f", totalEnergy));
359 }
360 return totalEnergy;
361 }
362
363
364 @Override
365 public double[] getAcceleration(double[] acceleration) {
366 doublesFrom(acceleration, dualTopA.getAcceleration(tempA), dualTopB.getAcceleration(tempB));
367 return acceleration;
368 }
369
370
371 @Override
372 public double[] getCoordinates(double[] x) {
373 dualTopA.getCoordinates(xA);
374 dualTopB.getCoordinates(xB);
375 doublesFrom(x, xA, xB);
376 return x;
377 }
378
379
380 @Override
381 public Crystal getCrystal() {
382 return dualTopA.getCrystal();
383 }
384
385
386 @Override
387 public void setCrystal(Crystal crystal) {
388 dualTopA.setCrystal(crystal);
389 dualTopB.setCrystal(crystal);
390 }
391
392
393
394
395
396
397 public DualTopologyEnergy getDualTopA() {
398 return dualTopA;
399 }
400
401
402
403
404
405
406 public DualTopologyEnergy getDualTopB() {
407 return dualTopB;
408 }
409
410
411 @Override
412 public STATE getEnergyTermState() {
413 return state;
414 }
415
416
417 @Override
418 public void setEnergyTermState(STATE state) {
419 this.state = state;
420 dualTopA.setEnergyTermState(state);
421 dualTopB.setEnergyTermState(state);
422 }
423
424
425 @Override
426 public double getLambda() {
427 return lambda;
428 }
429
430
431 @Override
432 public void setLambda(double lambda) {
433 if (!Double.isFinite(lambda) || lambda > 1.0 || lambda < 0.0) {
434 throw new ArithmeticException(
435 format(" Attempted to set invalid lambda value of %10.6g", lambda));
436 }
437 this.lambda = lambda;
438 dualTopA.setLambda(lambda);
439 dualTopB.setLambda(lambda);
440 }
441
442
443 @Override
444 public double[] getMass() {
445 return mass;
446 }
447
448
449
450
451
452
453 public int getNumSharedVariables() {
454 return nShared;
455 }
456
457
458 @Override
459 public int getNumberOfVariables() {
460 return nVarTot;
461 }
462
463
464 @Override
465 public double[] getPreviousAcceleration(double[] previousAcceleration) {
466 doublesFrom(
467 previousAcceleration,
468 dualTopA.getPreviousAcceleration(tempA),
469 dualTopB.getPreviousAcceleration(tempB));
470 return previousAcceleration;
471 }
472
473
474 @Override
475 public double[] getScaling() {
476 return scaling;
477 }
478
479
480 @Override
481 public void setScaling(double[] scaling) {
482 this.scaling = scaling;
483 if (scaling != null) {
484 double[] scaleA = new double[nVarA];
485 double[] scaleB = new double[nVarB];
486 doublesTo(scaling, scaleA, scaleB);
487 dualTopA.setScaling(scaleA);
488 dualTopB.setScaling(scaleB);
489 } else {
490 dualTopA.setScaling(null);
491 dualTopB.setScaling(null);
492 }
493 }
494
495
496 @Override
497 public double getTotalEnergy() {
498 return totalEnergy;
499 }
500
501 @Override
502 public List<Potential> getUnderlyingPotentials() {
503 List<Potential> under = new ArrayList<>(6);
504 under.add(dualTopA);
505 under.add(dualTopB);
506 under.addAll(dualTopA.getUnderlyingPotentials());
507 under.addAll(dualTopB.getUnderlyingPotentials());
508 return under;
509 }
510
511
512 @Override
513 public VARIABLE_TYPE[] getVariableTypes() {
514 if (types == null) {
515 VARIABLE_TYPE[] typesA = dualTopA.getVariableTypes();
516 VARIABLE_TYPE[] typesB = dualTopB.getVariableTypes();
517 if (typesA != null && typesB != null) {
518 types = new VARIABLE_TYPE[nVarTot];
519 copyFrom(types, dualTopA.getVariableTypes(), dualTopB.getVariableTypes());
520 } else {
521 logger.fine(
522 " Variable types array remaining null due to null "
523 + "variable types in either A or B dual topology");
524 }
525 }
526 return types;
527 }
528
529
530 @Override
531 public double[] getVelocity(double[] velocity) {
532 doublesFrom(velocity, dualTopA.getVelocity(tempA), dualTopB.getVelocity(tempB));
533 return velocity;
534 }
535
536
537 @Override
538 public double getd2EdL2() {
539 return d2EdL2;
540 }
541
542
543 @Override
544 public double getdEdL() {
545 return dEdL;
546 }
547
548
549 @Override
550 public void getdEdXdL(double[] g) {
551 dualTopA.getdEdXdL(tempA);
552 dualTopB.getdEdXdL(tempB);
553 addDoublesFrom(g, tempA, tempB);
554 }
555
556
557 @Override
558 public void setAcceleration(double[] acceleration) {
559 doublesTo(acceleration, tempA, tempB);
560 dualTopA.setVelocity(tempA);
561 dualTopB.setVelocity(tempB);
562 }
563
564
565
566
567
568
569 public void setParallel(boolean parallel) {
570 this.inParallel = parallel;
571 if (team != null) {
572 try {
573 team.shutdown();
574 } catch (Exception e) {
575 logger.severe(format(" Exception in shutting down old ParallelTeam for DualTopologyEnergy: %s", e));
576 }
577 }
578 team = parallel ? new ParallelTeam(2) : new ParallelTeam(1);
579 }
580
581
582 @Override
583 public void setPreviousAcceleration(double[] previousAcceleration) {
584 doublesTo(previousAcceleration, tempA, tempB);
585 dualTopA.setPreviousAcceleration(tempA);
586 dualTopB.setPreviousAcceleration(tempB);
587 }
588
589
590
591
592
593
594
595
596
597
598
599 public void setPrintOnFailure(boolean onFail, boolean override) {
600 dualTopA.setPrintOnFailure(onFail, override);
601 dualTopB.setPrintOnFailure(onFail, override);
602 }
603
604
605 @Override
606 public void setVelocity(double[] velocity) {
607 doublesTo(velocity, tempA, tempB);
608 dualTopA.setVelocity(tempA);
609 dualTopB.setVelocity(tempB);
610 }
611
612
613
614
615
616
617
618
619
620
621
622 private <T> void copyFrom(T[] to, T[] fromA, T[] fromB) {
623 if (to == null) {
624 to = Arrays.copyOf(fromA, nVarTot);
625 }
626 for (int i = 0; i < nVarA; i++) {
627 int index = indexAToGlobal[i];
628 to[index] = fromA[i];
629 }
630 for (int i = 0; i < nVarB; i++) {
631 int index = indexBToGlobal[i];
632
633 assert (index >= nShared || to[index].equals(fromB[i]));
634 to[index] = fromB[i];
635 }
636 }
637
638
639
640
641
642
643
644
645 private void doublesTo(double[] from, double[] toA, double[] toB) {
646 toA = (toA == null) ? new double[nVarA] : toA;
647 toB = (toB == null) ? new double[nVarB] : toB;
648 for (int i = 0; i < nVarTot; i++) {
649 int index = indexGlobalToA[i];
650 if (index >= 0) {
651 toA[index] = from[i];
652 }
653 index = indexGlobalToB[i];
654 if (index >= 0) {
655 toB[index] = from[i];
656 }
657 }
658 }
659
660
661
662
663
664
665
666
667
668
669 private void doublesFrom(double[] to, double[] fromA, double[] fromB) {
670 to = (to == null) ? new double[nVarTot] : to;
671 for (int i = 0; i < nVarA; i++) {
672 to[indexAToGlobal[i]] = fromA[i];
673 }
674 for (int i = 0; i < nVarB; i++) {
675 int index = indexBToGlobal[i];
676
677
678 assert (index >= nShared || to[index] == fromB[i]);
679 to[index] = fromB[i];
680 }
681 }
682
683
684
685
686
687
688
689
690
691
692 private void doublesFromFunction(
693 double[] to, double[] fromA, double[] fromB, DoubleBinaryOperator funct) {
694 to = (to == null) ? new double[nVarTot] : to;
695 for (int i = 0; i < nVarA; i++) {
696 to[indexAToGlobal[i]] = fromA[i];
697 }
698 for (int i = 0; i < nVarB; i++) {
699 int index = indexBToGlobal[i];
700 if (index < nShared) {
701 double current = to[index];
702 logger.finer(format(" Current %g, i %d, index %d", current, i, index));
703 to[index] = funct.applyAsDouble(current, fromB[i]);
704 logger.finer(format(" New: %g", to[index]));
705 } else {
706 logger.finer(
707 format(
708 " Applying %g to i %d index %d, current %g", fromB[i], i, index, to[index]));
709 to[index] = fromB[i];
710 }
711 }
712 }
713
714
715
716
717
718
719
720
721
722 private void addDoublesFrom(double[] to, double[] fromA, double[] fromB) {
723 to = (to == null) ? new double[nVarTot] : to;
724 fill(to, 0.0);
725 for (int i = 0; i < nVarA; i++) {
726 to[indexAToGlobal[i]] = fromA[i];
727 }
728 for (int i = 0; i < nVarB; i++) {
729 to[indexBToGlobal[i]] += fromB[i];
730 }
731 }
732
733 private class EnergyRegion extends ParallelRegion {
734
735 private final EnergyASection sectA;
736 private final EnergyBSection sectB;
737 private double[] x;
738 private double[] g;
739 private boolean gradient = false;
740
741 EnergyRegion() {
742 sectA = new EnergyASection();
743 sectB = new EnergyBSection();
744 }
745
746 @Override
747 public void finish() {
748 totalEnergy = energyA + energyB;
749 if (gradient) {
750 addDoublesFrom(g, gA, gB);
751 dEdL = dEdL_A + dEdL_B;
752 d2EdL2 = d2EdL2_A + d2EdL2_B;
753 }
754 gradient = false;
755 }
756
757 @Override
758 public void run() throws Exception {
759 execute(sectA, sectB);
760 }
761
762 public void setG(double[] g) {
763 this.g = g;
764 setGradient(true);
765 }
766
767 public void setGradient(boolean gradient) {
768 this.gradient = gradient;
769 sectA.setGradient(gradient);
770 sectB.setGradient(gradient);
771 }
772
773 public void setVerbose(boolean verbose) {
774 sectA.setVerbose(verbose);
775 sectB.setVerbose(verbose);
776 }
777
778 public void setX(double[] x) {
779 this.x = x;
780 }
781
782 @Override
783 public void start() {
784 doublesTo(x, xA, xB);
785 }
786 }
787
788 private class EnergyASection extends ParallelSection {
789
790 private boolean verbose = false;
791 private boolean gradient = false;
792
793 @Override
794 public void run() throws Exception {
795 if (gradient) {
796 energyA = dualTopA.energyAndGradient(xA, gA, verbose);
797 dEdL_A = linterA.getdEdL();
798 d2EdL2_A = linterA.getd2EdL2();
799 } else {
800 energyA = dualTopA.energy(xA, verbose);
801 }
802 this.verbose = false;
803 this.gradient = false;
804 }
805
806 public void setGradient(boolean gradient) {
807 this.gradient = gradient;
808 }
809
810 public void setVerbose(boolean verbose) {
811 this.verbose = verbose;
812 }
813 }
814
815 private class EnergyBSection extends ParallelSection {
816
817 private boolean verbose = false;
818 private boolean gradient = false;
819
820 @Override
821 public void run() throws Exception {
822 if (gradient) {
823 energyB = dualTopB.energyAndGradient(xB, gB, verbose);
824 dEdL_B = linterB.getdEdL();
825 d2EdL2_B = linterB.getd2EdL2();
826 } else {
827 energyB = dualTopB.energy(xB, verbose);
828 }
829 this.verbose = false;
830 this.gradient = false;
831 }
832
833 public void setGradient(boolean gradient) {
834 this.gradient = gradient;
835 }
836
837 public void setVerbose(boolean verbose) {
838 this.verbose = verbose;
839 }
840 }
841 }