1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential;
39
40 import edu.rit.pj.ParallelRegion;
41 import edu.rit.pj.ParallelSection;
42 import edu.rit.pj.ParallelTeam;
43 import ffx.crystal.Crystal;
44 import ffx.crystal.CrystalPotential;
45 import ffx.numerics.Potential;
46 import ffx.potential.bonded.LambdaInterface;
47 import ffx.potential.utils.EnergyException;
48
49 import java.util.ArrayList;
50 import java.util.Arrays;
51 import java.util.LinkedHashSet;
52 import java.util.List;
53 import java.util.Set;
54 import java.util.function.DoubleBinaryOperator;
55 import java.util.logging.Logger;
56
57 import static java.lang.String.format;
58 import static java.util.Arrays.fill;
59
60
61
62
63
64
65
66
67
68
69
70
71 public class QuadTopologyEnergy implements CrystalPotential, LambdaInterface {
72 private static final Logger logger = Logger.getLogger(QuadTopologyEnergy.class.getName());
73 private final DualTopologyEnergy dualTopA;
74 private final DualTopologyEnergy dualTopB;
75 private final LambdaInterface linterA;
76 private final LambdaInterface linterB;
77
78 private final int nVarA;
79 private final int nVarB;
80 private final int nShared;
81 private final int uniqueA;
82 private final int uniqueB;
83 private final int nVarTot;
84
85
86
87
88
89 private final int[] indexAToGlobal;
90
91 private final int[] indexBToGlobal;
92 private final int[] indexGlobalToA;
93 private final int[] indexGlobalToB;
94
95 private final double[] mass;
96 private final double[] xA;
97 private final double[] xB;
98 private final double[] gA;
99 private final double[] gB;
100
101
102
103
104
105
106 private final double[] tempA;
107
108 private final double[] tempB;
109 private final EnergyRegion region;
110 private STATE state = STATE.BOTH;
111 private double lambda;
112 private double totalEnergy;
113 private double energyA;
114 private double energyB;
115 private double dEdL, dEdL_A, dEdL_B;
116 private double d2EdL2, d2EdL2_A, d2EdL2_B;
117 private boolean inParallel = false;
118 private ParallelTeam team;
119
120 private double[] scaling;
121
122 private VARIABLE_TYPE[] types = null;
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/**
 * Builds a quad topology from two dual topologies without any caller-specified unique
 * variables; delegates to the four-argument constructor with null index lists.
 *
 * @param dualTopologyA the first dual topology
 * @param dualTopologyB the second dual topology
 */
public QuadTopologyEnergy(DualTopologyEnergy dualTopologyA, DualTopologyEnergy dualTopologyB) {
  // Null lists mean only the non-shared tail of each dual topology is treated as unique.
  this(dualTopologyA, dualTopologyB, null, null);
}
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158 public QuadTopologyEnergy(
159 DualTopologyEnergy dualTopologyA,
160 DualTopologyEnergy dualTopologyB,
161 List<Integer> uniqueAList,
162 List<Integer> uniqueBList) {
163 this.dualTopA = dualTopologyA;
164 this.dualTopB = dualTopologyB;
165 dualTopB.setCrystal(dualTopA.getCrystal());
166 dualTopB.reloadCommonMasses(true);
167 linterA = dualTopologyA;
168 linterB = dualTopologyB;
169 nVarA = dualTopA.getNumberOfVariables();
170 nVarB = dualTopB.getNumberOfVariables();
171
172
173
174
175
176
177
178 Set<Integer> uniqueASet = new LinkedHashSet<>();
179 if (uniqueAList != null) {
180 uniqueASet.addAll(uniqueAList);
181 }
182 int nCommon = dualTopA.getNumSharedVariables();
183 for (int i = nCommon; i < nVarA; i++) {
184 uniqueASet.add(i);
185 }
186
187 uniqueAList = new ArrayList<>(uniqueASet);
188 uniqueA = uniqueAList.size();
189 int[] uniquesA = new int[uniqueA];
190 for (int i = 0; i < uniqueA; i++) {
191 uniquesA[i] = uniqueAList.get(i);
192 }
193
194
195
196
197
198
199
200
201
202
203 Set<Integer> uniqueBSet = new LinkedHashSet<>();
204 if (uniqueBList != null) {
205 uniqueBSet.addAll(uniqueBList);
206 }
207 nCommon = dualTopB.getNumSharedVariables();
208 for (int i = nCommon; i < nVarB; i++) {
209 uniqueBSet.add(i);
210 }
211 uniqueBList = new ArrayList<>(uniqueBSet);
212 uniqueB = uniqueBList.size();
213 int[] uniquesB = new int[uniqueB];
214 for (int i = 0; i < uniqueB; i++) {
215 uniquesB[i] = uniqueBList.get(i);
216 }
217
218 nShared = nVarA - uniqueA;
219 assert (nShared == nVarB - uniqueB);
220 nVarTot = nShared + uniqueA + uniqueB;
221
222 indexAToGlobal = new int[nVarA];
223 indexBToGlobal = new int[nVarB];
224 indexGlobalToA = new int[nVarTot];
225 indexGlobalToB = new int[nVarTot];
226
227
228 fill(indexGlobalToA, -1);
229 fill(indexGlobalToB, -1);
230
231 if (uniqueA > 0) {
232 int commonIndex = 0;
233 int uniqueIndex = 0;
234 for (int i = 0; i < nVarA; i++) {
235 if (uniqueIndex < uniqueA && i == uniquesA[uniqueIndex]) {
236 int destIndex = nShared + uniqueIndex;
237 indexAToGlobal[i] = destIndex;
238 indexGlobalToA[destIndex] = i;
239 ++uniqueIndex;
240 } else {
241 indexAToGlobal[i] = commonIndex;
242 indexGlobalToA[commonIndex++] = i;
243 }
244 }
245 } else {
246 for (int i = 0; i < nVarA; i++) {
247 indexAToGlobal[i] = i;
248 indexGlobalToA[i] = i;
249 }
250 }
251
252 if (uniqueB > 0) {
253 int commonIndex = 0;
254 int uniqueIndex = 0;
255 for (int i = 0; i < nVarB; i++) {
256 if (uniqueIndex < uniqueB && i == uniquesB[uniqueIndex]) {
257 int destIndex = nVarA + uniqueIndex;
258 indexBToGlobal[i] = destIndex;
259 indexGlobalToB[destIndex] = i;
260 ++uniqueIndex;
261 } else {
262 indexBToGlobal[i] = commonIndex;
263 indexGlobalToB[commonIndex++] = i;
264 }
265 }
266 } else {
267 for (int i = 0; i < nVarB; i++) {
268 indexBToGlobal[i] = i;
269 indexGlobalToB[i] = i;
270 }
271 }
272
273 xA = new double[nVarA];
274 xB = new double[nVarB];
275 gA = new double[nVarA];
276 gB = new double[nVarB];
277 tempA = new double[nVarA];
278 tempB = new double[nVarB];
279 mass = new double[nVarTot];
280 doublesFromFunction(mass, dualTopA.getMass(), dualTopB.getMass(), Math::max);
281
282 region = new EnergyRegion();
283 team = new ParallelTeam(1);
284 }
285
286
287
288
289
290
291 @Override
292 public boolean dEdLZeroAtEnds() {
293 if (!dualTopA.dEdLZeroAtEnds() || !dualTopB.dEdLZeroAtEnds()) {
294 return false;
295 }
296 return true;
297 }
298
299
300 @Override
301 public boolean destroy() {
302 boolean dtADestroy = dualTopA.destroy();
303 boolean dtBDestroy = dualTopB.destroy();
304 try {
305 if (team != null) {
306 team.shutdown();
307 }
308 return dtADestroy && dtBDestroy;
309 } catch (Exception ex) {
310 logger.warning(format(" Exception in shutting down QuadTopologyEnergy: %s", ex));
311 logger.info(Utilities.stackTraceToString(ex));
312 return false;
313 }
314 }
315
316
317 @Override
318 public double energy(double[] x) {
319 return energy(x, false);
320 }
321
322
323 @Override
324 public double energy(double[] x, boolean verbose) {
325 region.setX(x);
326 region.setVerbose(verbose);
327 try {
328 team.execute(region);
329 } catch (Exception ex) {
330 throw new EnergyException(format(" Exception in calculating quad-topology energy: %s", ex));
331 }
332
333 if (verbose) {
334 logger.info(format(" Total quad-topology energy: %12.4f", totalEnergy));
335 }
336 return totalEnergy;
337 }
338
339
340 @Override
341 public double energyAndGradient(double[] x, double[] g) {
342 return energyAndGradient(x, g, false);
343 }
344
345
346 @Override
347 public double energyAndGradient(double[] x, double[] g, boolean verbose) {
348 assert Arrays.stream(x).allMatch(Double::isFinite);
349 region.setX(x);
350 region.setG(g);
351 region.setVerbose(verbose);
352 try {
353 team.execute(region);
354 } catch (Exception ex) {
355 throw new EnergyException(format(" Exception in calculating quad-topology energy: %s", ex));
356 }
357
358 if (verbose) {
359 logger.info(format(" Total quad-topology energy: %12.4f", totalEnergy));
360 }
361 return totalEnergy;
362 }
363
364
365 @Override
366 public double[] getAcceleration(double[] acceleration) {
367 doublesFrom(acceleration, dualTopA.getAcceleration(tempA), dualTopB.getAcceleration(tempB));
368 return acceleration;
369 }
370
371
372 @Override
373 public double[] getCoordinates(double[] x) {
374 dualTopA.getCoordinates(xA);
375 dualTopB.getCoordinates(xB);
376 doublesFrom(x, xA, xB);
377 return x;
378 }
379
380
381 @Override
382 public Crystal getCrystal() {
383 return dualTopA.getCrystal();
384 }
385
386
387 @Override
388 public void setCrystal(Crystal crystal) {
389 dualTopA.setCrystal(crystal);
390 dualTopB.setCrystal(crystal);
391 }
392
393
394
395
396
397
398 public DualTopologyEnergy getDualTopA() {
399 return dualTopA;
400 }
401
402
403
404
405
406
407 public DualTopologyEnergy getDualTopB() {
408 return dualTopB;
409 }
410
411
412 @Override
413 public STATE getEnergyTermState() {
414 return state;
415 }
416
417
418 @Override
419 public void setEnergyTermState(STATE state) {
420 this.state = state;
421 dualTopA.setEnergyTermState(state);
422 dualTopB.setEnergyTermState(state);
423 }
424
425
426 @Override
427 public double getLambda() {
428 return lambda;
429 }
430
431
432 @Override
433 public void setLambda(double lambda) {
434 if (!Double.isFinite(lambda) || lambda > 1.0 || lambda < 0.0) {
435 throw new ArithmeticException(
436 format(" Attempted to set invalid lambda value of %10.6g", lambda));
437 }
438 this.lambda = lambda;
439 dualTopA.setLambda(lambda);
440 dualTopB.setLambda(lambda);
441 }
442
443
444 @Override
445 public double[] getMass() {
446 return mass;
447 }
448
449
450
451
452
453
454 public int getNumSharedVariables() {
455 return nShared;
456 }
457
458
459 @Override
460 public int getNumberOfVariables() {
461 return nVarTot;
462 }
463
464
465 @Override
466 public double[] getPreviousAcceleration(double[] previousAcceleration) {
467 doublesFrom(
468 previousAcceleration,
469 dualTopA.getPreviousAcceleration(tempA),
470 dualTopB.getPreviousAcceleration(tempB));
471 return previousAcceleration;
472 }
473
474
475 @Override
476 public double[] getScaling() {
477 return scaling;
478 }
479
480
481 @Override
482 public void setScaling(double[] scaling) {
483 this.scaling = scaling;
484 if (scaling != null) {
485 double[] scaleA = new double[nVarA];
486 double[] scaleB = new double[nVarB];
487 doublesTo(scaling, scaleA, scaleB);
488 dualTopA.setScaling(scaleA);
489 dualTopB.setScaling(scaleB);
490 } else {
491 dualTopA.setScaling(null);
492 dualTopB.setScaling(null);
493 }
494 }
495
496
497 @Override
498 public double getTotalEnergy() {
499 return totalEnergy;
500 }
501
502 @Override
503 public List<Potential> getUnderlyingPotentials() {
504 List<Potential> under = new ArrayList<>(6);
505 under.add(dualTopA);
506 under.add(dualTopB);
507 under.addAll(dualTopA.getUnderlyingPotentials());
508 under.addAll(dualTopB.getUnderlyingPotentials());
509 return under;
510 }
511
512
513 @Override
514 public VARIABLE_TYPE[] getVariableTypes() {
515 if (types == null) {
516 VARIABLE_TYPE[] typesA = dualTopA.getVariableTypes();
517 VARIABLE_TYPE[] typesB = dualTopB.getVariableTypes();
518 if (typesA != null && typesB != null) {
519 types = new VARIABLE_TYPE[nVarTot];
520 copyFrom(types, dualTopA.getVariableTypes(), dualTopB.getVariableTypes());
521 } else {
522 logger.fine(
523 " Variable types array remaining null due to null "
524 + "variable types in either A or B dual topology");
525 }
526 }
527 return types;
528 }
529
530
531 @Override
532 public double[] getVelocity(double[] velocity) {
533 doublesFrom(velocity, dualTopA.getVelocity(tempA), dualTopB.getVelocity(tempB));
534 return velocity;
535 }
536
537
538 @Override
539 public double getd2EdL2() {
540 return d2EdL2;
541 }
542
543
544 @Override
545 public double getdEdL() {
546 return dEdL;
547 }
548
549
550 @Override
551 public void getdEdXdL(double[] g) {
552 dualTopA.getdEdXdL(tempA);
553 dualTopB.getdEdXdL(tempB);
554 addDoublesFrom(g, tempA, tempB);
555 }
556
557
558 @Override
559 public void setAcceleration(double[] acceleration) {
560 doublesTo(acceleration, tempA, tempB);
561 dualTopA.setVelocity(tempA);
562 dualTopB.setVelocity(tempB);
563 }
564
565
566
567
568
569
570 public void setParallel(boolean parallel) {
571 this.inParallel = parallel;
572 if (team != null) {
573 try {
574 team.shutdown();
575 } catch (Exception e) {
576 logger.severe(format(" Exception in shutting down old ParallelTeam for DualTopologyEnergy: %s", e));
577 }
578 }
579 team = parallel ? new ParallelTeam(2) : new ParallelTeam(1);
580 }
581
582
583 @Override
584 public void setPreviousAcceleration(double[] previousAcceleration) {
585 doublesTo(previousAcceleration, tempA, tempB);
586 dualTopA.setPreviousAcceleration(tempA);
587 dualTopB.setPreviousAcceleration(tempB);
588 }
589
590
591
592
593
594
595
596
597
598
599
600 public void setPrintOnFailure(boolean onFail, boolean override) {
601 dualTopA.setPrintOnFailure(onFail, override);
602 dualTopB.setPrintOnFailure(onFail, override);
603 }
604
605
606 @Override
607 public void setCoordinates(double[] coordinates) {
608 doublesTo(coordinates, tempA, tempB);
609 dualTopA.setCoordinates(tempA);
610 dualTopB.setCoordinates(tempB);
611 }
612
613
614 @Override
615 public void setVelocity(double[] velocity) {
616 doublesTo(velocity, tempA, tempB);
617 dualTopA.setVelocity(tempA);
618 dualTopB.setVelocity(tempB);
619 }
620
621
622
623
624
625
626
627
628
629
630
631 private <T> void copyFrom(T[] to, T[] fromA, T[] fromB) {
632 if (to == null) {
633 to = Arrays.copyOf(fromA, nVarTot);
634 }
635 for (int i = 0; i < nVarA; i++) {
636 int index = indexAToGlobal[i];
637 to[index] = fromA[i];
638 }
639 for (int i = 0; i < nVarB; i++) {
640 int index = indexBToGlobal[i];
641
642 assert (index >= nShared || to[index].equals(fromB[i]));
643 to[index] = fromB[i];
644 }
645 }
646
647
648
649
650
651
652
653
654 private void doublesTo(double[] from, double[] toA, double[] toB) {
655 toA = (toA == null) ? new double[nVarA] : toA;
656 toB = (toB == null) ? new double[nVarB] : toB;
657 for (int i = 0; i < nVarTot; i++) {
658 int index = indexGlobalToA[i];
659 if (index >= 0) {
660 toA[index] = from[i];
661 }
662 index = indexGlobalToB[i];
663 if (index >= 0) {
664 toB[index] = from[i];
665 }
666 }
667 }
668
669
670
671
672
673
674
675
676
677
678 private void doublesFrom(double[] to, double[] fromA, double[] fromB) {
679 to = (to == null) ? new double[nVarTot] : to;
680 for (int i = 0; i < nVarA; i++) {
681 to[indexAToGlobal[i]] = fromA[i];
682 }
683 for (int i = 0; i < nVarB; i++) {
684 int index = indexBToGlobal[i];
685
686
687 assert (index >= nShared || to[index] == fromB[i]);
688 to[index] = fromB[i];
689 }
690 }
691
692
693
694
695
696
697
698
699
700
701 private void doublesFromFunction(
702 double[] to, double[] fromA, double[] fromB, DoubleBinaryOperator funct) {
703 to = (to == null) ? new double[nVarTot] : to;
704 for (int i = 0; i < nVarA; i++) {
705 to[indexAToGlobal[i]] = fromA[i];
706 }
707 for (int i = 0; i < nVarB; i++) {
708 int index = indexBToGlobal[i];
709 if (index < nShared) {
710 double current = to[index];
711 logger.finer(format(" Current %g, i %d, index %d", current, i, index));
712 to[index] = funct.applyAsDouble(current, fromB[i]);
713 logger.finer(format(" New: %g", to[index]));
714 } else {
715 logger.finer(
716 format(
717 " Applying %g to i %d index %d, current %g", fromB[i], i, index, to[index]));
718 to[index] = fromB[i];
719 }
720 }
721 }
722
723
724
725
726
727
728
729
730
731 private void addDoublesFrom(double[] to, double[] fromA, double[] fromB) {
732 to = (to == null) ? new double[nVarTot] : to;
733 fill(to, 0.0);
734 for (int i = 0; i < nVarA; i++) {
735 to[indexAToGlobal[i]] = fromA[i];
736 }
737 for (int i = 0; i < nVarB; i++) {
738 to[indexBToGlobal[i]] += fromB[i];
739 }
740 }
741
742 private class EnergyRegion extends ParallelRegion {
743
744 private final EnergyASection sectA;
745 private final EnergyBSection sectB;
746 private double[] x;
747 private double[] g;
748 private boolean gradient = false;
749
750 EnergyRegion() {
751 sectA = new EnergyASection();
752 sectB = new EnergyBSection();
753 }
754
755 @Override
756 public void finish() {
757 totalEnergy = energyA + energyB;
758 if (gradient) {
759 addDoublesFrom(g, gA, gB);
760 dEdL = dEdL_A + dEdL_B;
761 d2EdL2 = d2EdL2_A + d2EdL2_B;
762 }
763 gradient = false;
764 }
765
766 @Override
767 public void run() throws Exception {
768 execute(sectA, sectB);
769 }
770
771 public void setG(double[] g) {
772 this.g = g;
773 setGradient(true);
774 }
775
776 public void setGradient(boolean gradient) {
777 this.gradient = gradient;
778 sectA.setGradient(gradient);
779 sectB.setGradient(gradient);
780 }
781
782 public void setVerbose(boolean verbose) {
783 sectA.setVerbose(verbose);
784 sectB.setVerbose(verbose);
785 }
786
787 public void setX(double[] x) {
788 this.x = x;
789 }
790
791 @Override
792 public void start() {
793 doublesTo(x, xA, xB);
794 }
795 }
796
797 private class EnergyASection extends ParallelSection {
798
799 private boolean verbose = false;
800 private boolean gradient = false;
801
802 @Override
803 public void run() throws Exception {
804 if (gradient) {
805 energyA = dualTopA.energyAndGradient(xA, gA, verbose);
806 dEdL_A = linterA.getdEdL();
807 d2EdL2_A = linterA.getd2EdL2();
808 } else {
809 energyA = dualTopA.energy(xA, verbose);
810 }
811 this.verbose = false;
812 this.gradient = false;
813 }
814
815 public void setGradient(boolean gradient) {
816 this.gradient = gradient;
817 }
818
819 public void setVerbose(boolean verbose) {
820 this.verbose = verbose;
821 }
822 }
823
824 private class EnergyBSection extends ParallelSection {
825
826 private boolean verbose = false;
827 private boolean gradient = false;
828
829 @Override
830 public void run() throws Exception {
831 if (gradient) {
832 energyB = dualTopB.energyAndGradient(xB, gB, verbose);
833 dEdL_B = linterB.getdEdL();
834 d2EdL2_B = linterB.getd2EdL2();
835 } else {
836 energyB = dualTopB.energy(xB, verbose);
837 }
838 this.verbose = false;
839 this.gradient = false;
840 }
841
842 public void setGradient(boolean gradient) {
843 this.gradient = gradient;
844 }
845
846 public void setVerbose(boolean verbose) {
847 this.verbose = verbose;
848 }
849 }
850 }