1   
2   
3   
4   
5   
6   
7   
8   
9   
10  
11  
12  
13  
14  
15  
16  
17  
18  
19  
20  
21  
22  
23  
24  
25  
26  
27  
28  
29  
30  
31  
32  
33  
34  
35  
36  
37  
38  package ffx.potential;
39  
40  import edu.rit.pj.ParallelRegion;
41  import edu.rit.pj.ParallelSection;
42  import edu.rit.pj.ParallelTeam;
43  import ffx.crystal.Crystal;
44  import ffx.crystal.CrystalPotential;
45  import ffx.numerics.Potential;
46  import ffx.potential.bonded.LambdaInterface;
47  import ffx.potential.utils.EnergyException;
48  
49  import java.util.ArrayList;
50  import java.util.Arrays;
51  import java.util.LinkedHashSet;
52  import java.util.List;
53  import java.util.Set;
54  import java.util.function.DoubleBinaryOperator;
55  import java.util.logging.Logger;
56  
57  import static java.lang.String.format;
58  import static java.util.Arrays.fill;
59  
60  
61  
62  
63  
64  
65  
66  
67  
68  
69  
70  
71  public class QuadTopologyEnergy implements CrystalPotential, LambdaInterface {
72    private static final Logger logger = Logger.getLogger(QuadTopologyEnergy.class.getName());
73    private final DualTopologyEnergy dualTopA;
74    private final DualTopologyEnergy dualTopB;
75    private final LambdaInterface linterA;
76    private final LambdaInterface linterB;
77  
78    private final int nVarA;
79    private final int nVarB;
80    private final int nShared;
81    private final int uniqueA;
82    private final int uniqueB;
83    private final int nVarTot;
84  
85    
86  
87  
88  
89    private final int[] indexAToGlobal;
90  
91    private final int[] indexBToGlobal;
92    private final int[] indexGlobalToA;
93    private final int[] indexGlobalToB;
94  
95    private final double[] mass;
96    private final double[] xA;
97    private final double[] xB;
98    private final double[] gA;
99    private final double[] gB;
100   
101 
102 
103 
104 
105 
106   private final double[] tempA;
107 
108   private final double[] tempB;
109   private final EnergyRegion region;
110   private STATE state = STATE.BOTH;
111   private double lambda;
112   private double totalEnergy;
113   private double energyA;
114   private double energyB;
115   private double dEdL, dEdL_A, dEdL_B;
116   private double d2EdL2, d2EdL2_A, d2EdL2_B;
117   private boolean inParallel = false;
118   private ParallelTeam team;
119   
120   private double[] scaling;
121 
122   private VARIABLE_TYPE[] types = null;
123 
124   
125 
126 
127 
128 
129 
130 
131 
132 
133 
134 
135 
136 
137 
138 
139   public QuadTopologyEnergy(DualTopologyEnergy dualTopologyA, DualTopologyEnergy dualTopologyB) {
140     this(dualTopologyA, dualTopologyB, null, null);
141   }
142 
143   
144 
145 
146 
147 
148 
149 
150 
151 
152 
153 
154 
155 
156 
157 
158   public QuadTopologyEnergy(
159       DualTopologyEnergy dualTopologyA,
160       DualTopologyEnergy dualTopologyB,
161       List<Integer> uniqueAList,
162       List<Integer> uniqueBList) {
163     this.dualTopA = dualTopologyA;
164     this.dualTopB = dualTopologyB;
165     dualTopB.setCrystal(dualTopA.getCrystal());
166     dualTopB.reloadCommonMasses(true);
167     linterA = dualTopologyA;
168     linterB = dualTopologyB;
169     nVarA = dualTopA.getNumberOfVariables();
170     nVarB = dualTopB.getNumberOfVariables();
171 
172     
173 
174 
175 
176 
177 
178     Set<Integer> uniqueASet = new LinkedHashSet<>();
179     if (uniqueAList != null) {
180       uniqueASet.addAll(uniqueAList);
181     }
182     int nCommon = dualTopA.getNumSharedVariables();
183     for (int i = nCommon; i < nVarA; i++) {
184       uniqueASet.add(i);
185     }
186 
187     uniqueAList = new ArrayList<>(uniqueASet);
188     uniqueA = uniqueAList.size();
189     int[] uniquesA = new int[uniqueA];
190     for (int i = 0; i < uniqueA; i++) {
191       uniquesA[i] = uniqueAList.get(i);
192     }
193     
194     
195     
196 
197     
198 
199 
200 
201 
202 
203     Set<Integer> uniqueBSet = new LinkedHashSet<>();
204     if (uniqueBList != null) {
205       uniqueBSet.addAll(uniqueBList);
206     }
207     nCommon = dualTopB.getNumSharedVariables();
208     for (int i = nCommon; i < nVarB; i++) {
209       uniqueBSet.add(i);
210     }
211     uniqueBList = new ArrayList<>(uniqueBSet);
212     uniqueB = uniqueBList.size();
213     int[] uniquesB = new int[uniqueB];
214     for (int i = 0; i < uniqueB; i++) {
215       uniquesB[i] = uniqueBList.get(i);
216     }
217 
218     nShared = nVarA - uniqueA;
219     assert (nShared == nVarB - uniqueB);
220     nVarTot = nShared + uniqueA + uniqueB;
221 
222     indexAToGlobal = new int[nVarA];
223     indexBToGlobal = new int[nVarB];
224     indexGlobalToA = new int[nVarTot];
225     indexGlobalToB = new int[nVarTot];
226     
227     
228     fill(indexGlobalToA, -1);
229     fill(indexGlobalToB, -1);
230 
231     if (uniqueA > 0) {
232       int commonIndex = 0;
233       int uniqueIndex = 0;
234       for (int i = 0; i < nVarA; i++) {
235         if (uniqueIndex < uniqueA && i == uniquesA[uniqueIndex]) {
236           int destIndex = nShared + uniqueIndex;
237           indexAToGlobal[i] = destIndex;
238           indexGlobalToA[destIndex] = i;
239           ++uniqueIndex;
240         } else {
241           indexAToGlobal[i] = commonIndex;
242           indexGlobalToA[commonIndex++] = i;
243         }
244       }
245     } else {
246       for (int i = 0; i < nVarA; i++) {
247         indexAToGlobal[i] = i;
248         indexGlobalToA[i] = i;
249       }
250     }
251 
252     if (uniqueB > 0) {
253       int commonIndex = 0;
254       int uniqueIndex = 0;
255       for (int i = 0; i < nVarB; i++) {
256         if (uniqueIndex < uniqueB && i == uniquesB[uniqueIndex]) {
257           int destIndex = nVarA + uniqueIndex;
258           indexBToGlobal[i] = destIndex;
259           indexGlobalToB[destIndex] = i;
260           ++uniqueIndex;
261         } else {
262           indexBToGlobal[i] = commonIndex;
263           indexGlobalToB[commonIndex++] = i;
264         }
265       }
266     } else {
267       for (int i = 0; i < nVarB; i++) {
268         indexBToGlobal[i] = i;
269         indexGlobalToB[i] = i;
270       }
271     }
272 
273     xA = new double[nVarA];
274     xB = new double[nVarB];
275     gA = new double[nVarA];
276     gB = new double[nVarB];
277     tempA = new double[nVarA];
278     tempB = new double[nVarB];
279     mass = new double[nVarTot];
280     doublesFromFunction(mass, dualTopA.getMass(), dualTopB.getMass(), Math::max);
281 
282     region = new EnergyRegion();
283     team = new ParallelTeam(1);
284   }
285 
286   
287 
288 
289 
290 
291   @Override
292   public boolean dEdLZeroAtEnds() {
293     if (!dualTopA.dEdLZeroAtEnds() || !dualTopB.dEdLZeroAtEnds()) {
294       return false;
295     }
296     return true;
297   }
298 
299   
300   @Override
301   public boolean destroy() {
302     boolean dtADestroy = dualTopA.destroy();
303     boolean dtBDestroy = dualTopB.destroy();
304     try {
305       if (team != null) {
306         team.shutdown();
307       }
308       return dtADestroy && dtBDestroy;
309     } catch (Exception ex) {
310       logger.warning(format(" Exception in shutting down QuadTopologyEnergy: %s", ex));
311       logger.info(Utilities.stackTraceToString(ex));
312       return false;
313     }
314   }
315 
316   
317   @Override
318   public double energy(double[] x) {
319     return energy(x, false);
320   }
321 
322   
323   @Override
324   public double energy(double[] x, boolean verbose) {
325     region.setX(x);
326     region.setVerbose(verbose);
327     try {
328       team.execute(region);
329     } catch (Exception ex) {
330       throw new EnergyException(format(" Exception in calculating quad-topology energy: %s", ex));
331     }
332 
333     if (verbose) {
334       logger.info(format(" Total quad-topology energy: %12.4f", totalEnergy));
335     }
336     return totalEnergy;
337   }
338 
339   
340   @Override
341   public double energyAndGradient(double[] x, double[] g) {
342     return energyAndGradient(x, g, false);
343   }
344 
345   
346   @Override
347   public double energyAndGradient(double[] x, double[] g, boolean verbose) {
348     assert Arrays.stream(x).allMatch(Double::isFinite);
349     region.setX(x);
350     region.setG(g);
351     region.setVerbose(verbose);
352     try {
353       team.execute(region);
354     } catch (Exception ex) {
355       throw new EnergyException(format(" Exception in calculating quad-topology energy: %s", ex));
356     }
357 
358     if (verbose) {
359       logger.info(format(" Total quad-topology energy: %12.4f", totalEnergy));
360     }
361     return totalEnergy;
362   }
363 
364   
365   @Override
366   public double[] getAcceleration(double[] acceleration) {
367     doublesFrom(acceleration, dualTopA.getAcceleration(tempA), dualTopB.getAcceleration(tempB));
368     return acceleration;
369   }
370 
371   
372   @Override
373   public double[] getCoordinates(double[] x) {
374     dualTopA.getCoordinates(xA);
375     dualTopB.getCoordinates(xB);
376     doublesFrom(x, xA, xB);
377     return x;
378   }
379 
380   
381   @Override
382   public Crystal getCrystal() {
383     return dualTopA.getCrystal();
384   }
385 
386   
387   @Override
388   public void setCrystal(Crystal crystal) {
389     dualTopA.setCrystal(crystal);
390     dualTopB.setCrystal(crystal);
391   }
392 
393   
394 
395 
396 
397 
398   public DualTopologyEnergy getDualTopA() {
399     return dualTopA;
400   }
401 
402   
403 
404 
405 
406 
407   public DualTopologyEnergy getDualTopB() {
408     return dualTopB;
409   }
410 
411   
412   @Override
413   public STATE getEnergyTermState() {
414     return state;
415   }
416 
417   
418   @Override
419   public void setEnergyTermState(STATE state) {
420     this.state = state;
421     dualTopA.setEnergyTermState(state);
422     dualTopB.setEnergyTermState(state);
423   }
424 
425   
426   @Override
427   public double getLambda() {
428     return lambda;
429   }
430 
431   
432   @Override
433   public void setLambda(double lambda) {
434     if (!Double.isFinite(lambda) || lambda > 1.0 || lambda < 0.0) {
435       throw new ArithmeticException(
436           format(" Attempted to set invalid lambda value of %10.6g", lambda));
437     }
438     this.lambda = lambda;
439     dualTopA.setLambda(lambda);
440     dualTopB.setLambda(lambda);
441   }
442 
443   
444   @Override
445   public double[] getMass() {
446     return mass;
447   }
448 
449   
450 
451 
452 
453 
454   public int getNumSharedVariables() {
455     return nShared;
456   }
457 
458   
459   @Override
460   public int getNumberOfVariables() {
461     return nVarTot;
462   }
463 
464   
465   @Override
466   public double[] getPreviousAcceleration(double[] previousAcceleration) {
467     doublesFrom(
468         previousAcceleration,
469         dualTopA.getPreviousAcceleration(tempA),
470         dualTopB.getPreviousAcceleration(tempB));
471     return previousAcceleration;
472   }
473 
474   
475   @Override
476   public double[] getScaling() {
477     return scaling;
478   }
479 
480   
481   @Override
482   public void setScaling(double[] scaling) {
483     this.scaling = scaling;
484     if (scaling != null) {
485       double[] scaleA = new double[nVarA];
486       double[] scaleB = new double[nVarB];
487       doublesTo(scaling, scaleA, scaleB);
488       dualTopA.setScaling(scaleA);
489       dualTopB.setScaling(scaleB);
490     } else {
491       dualTopA.setScaling(null);
492       dualTopB.setScaling(null);
493     }
494   }
495 
496   
497   @Override
498   public double getTotalEnergy() {
499     return totalEnergy;
500   }
501 
502   @Override
503   public List<Potential> getUnderlyingPotentials() {
504     List<Potential> under = new ArrayList<>(6);
505     under.add(dualTopA);
506     under.add(dualTopB);
507     under.addAll(dualTopA.getUnderlyingPotentials());
508     under.addAll(dualTopB.getUnderlyingPotentials());
509     return under;
510   }
511 
512   
513   @Override
514   public VARIABLE_TYPE[] getVariableTypes() {
515     if (types == null) {
516       VARIABLE_TYPE[] typesA = dualTopA.getVariableTypes();
517       VARIABLE_TYPE[] typesB = dualTopB.getVariableTypes();
518       if (typesA != null && typesB != null) {
519         types = new VARIABLE_TYPE[nVarTot];
520         copyFrom(types, dualTopA.getVariableTypes(), dualTopB.getVariableTypes());
521       } else {
522         logger.fine(
523             " Variable types array remaining null due to null "
524                 + "variable types in either A or B dual topology");
525       }
526     }
527     return types;
528   }
529 
530   
531   @Override
532   public double[] getVelocity(double[] velocity) {
533     doublesFrom(velocity, dualTopA.getVelocity(tempA), dualTopB.getVelocity(tempB));
534     return velocity;
535   }
536 
537   
538   @Override
539   public double getd2EdL2() {
540     return d2EdL2;
541   }
542 
543   
544   @Override
545   public double getdEdL() {
546     return dEdL;
547   }
548 
549   
550   @Override
551   public void getdEdXdL(double[] g) {
552     dualTopA.getdEdXdL(tempA);
553     dualTopB.getdEdXdL(tempB);
554     addDoublesFrom(g, tempA, tempB);
555   }
556 
557   
558   @Override
559   public void setAcceleration(double[] acceleration) {
560     doublesTo(acceleration, tempA, tempB);
561     dualTopA.setVelocity(tempA);
562     dualTopB.setVelocity(tempB);
563   }
564 
565   
566 
567 
568 
569 
570   public void setParallel(boolean parallel) {
571     this.inParallel = parallel;
572     if (team != null) {
573       try {
574         team.shutdown();
575       } catch (Exception e) {
576         logger.severe(format(" Exception in shutting down old ParallelTeam for DualTopologyEnergy: %s", e));
577       }
578     }
579     team = parallel ? new ParallelTeam(2) : new ParallelTeam(1);
580   }
581 
582   
583   @Override
584   public void setPreviousAcceleration(double[] previousAcceleration) {
585     doublesTo(previousAcceleration, tempA, tempB);
586     dualTopA.setPreviousAcceleration(tempA);
587     dualTopB.setPreviousAcceleration(tempB);
588   }
589 
590   
591 
592 
593 
594 
595 
596 
597 
598 
599 
600   public void setPrintOnFailure(boolean onFail, boolean override) {
601     dualTopA.setPrintOnFailure(onFail, override);
602     dualTopB.setPrintOnFailure(onFail, override);
603   }
604 
605   
606   @Override
607   public void setCoordinates(double[] coordinates) {
608     doublesTo(coordinates, tempA, tempB);
609     dualTopA.setCoordinates(tempA);
610     dualTopB.setCoordinates(tempB);
611   }
612 
613   
614   @Override
615   public void setVelocity(double[] velocity) {
616     doublesTo(velocity, tempA, tempB);
617     dualTopA.setVelocity(tempA);
618     dualTopB.setVelocity(tempB);
619   }
620 
621   
622 
623 
624 
625 
626 
627 
628 
629 
630 
631   private <T> void copyFrom(T[] to, T[] fromA, T[] fromB) {
632     if (to == null) {
633       to = Arrays.copyOf(fromA, nVarTot);
634     }
635     for (int i = 0; i < nVarA; i++) {
636       int index = indexAToGlobal[i];
637       to[index] = fromA[i];
638     }
639     for (int i = 0; i < nVarB; i++) {
640       int index = indexBToGlobal[i];
641       
642       assert (index >= nShared || to[index].equals(fromB[i]));
643       to[index] = fromB[i];
644     }
645   }
646 
647   
648 
649 
650 
651 
652 
653 
654   private void doublesTo(double[] from, double[] toA, double[] toB) {
655     toA = (toA == null) ? new double[nVarA] : toA;
656     toB = (toB == null) ? new double[nVarB] : toB;
657     for (int i = 0; i < nVarTot; i++) {
658       int index = indexGlobalToA[i];
659       if (index >= 0) {
660         toA[index] = from[i];
661       }
662       index = indexGlobalToB[i];
663       if (index >= 0) {
664         toB[index] = from[i];
665       }
666     }
667   }
668 
669   
670 
671 
672 
673 
674 
675 
676 
677 
678   private void doublesFrom(double[] to, double[] fromA, double[] fromB) {
679     to = (to == null) ? new double[nVarTot] : to;
680     for (int i = 0; i < nVarA; i++) {
681       to[indexAToGlobal[i]] = fromA[i];
682     }
683     for (int i = 0; i < nVarB; i++) {
684       int index = indexBToGlobal[i];
685       
686       
687       assert (index >= nShared || to[index] == fromB[i]);
688       to[index] = fromB[i];
689     }
690   }
691 
692   
693 
694 
695 
696 
697 
698 
699 
700 
701   private void doublesFromFunction(
702       double[] to, double[] fromA, double[] fromB, DoubleBinaryOperator funct) {
703     to = (to == null) ? new double[nVarTot] : to;
704     for (int i = 0; i < nVarA; i++) {
705       to[indexAToGlobal[i]] = fromA[i];
706     }
707     for (int i = 0; i < nVarB; i++) {
708       int index = indexBToGlobal[i];
709       if (index < nShared) {
710         double current = to[index];
711         logger.finer(format(" Current %g, i %d, index %d", current, i, index));
712         to[index] = funct.applyAsDouble(current, fromB[i]);
713         logger.finer(format(" New: %g", to[index]));
714       } else {
715         logger.finer(
716             format(
717                 " Applying %g to i %d index %d, current %g", fromB[i], i, index, to[index]));
718         to[index] = fromB[i];
719       }
720     }
721   }
722 
723   
724 
725 
726 
727 
728 
729 
730 
731   private void addDoublesFrom(double[] to, double[] fromA, double[] fromB) {
732     to = (to == null) ? new double[nVarTot] : to;
733     fill(to, 0.0);
734     for (int i = 0; i < nVarA; i++) {
735       to[indexAToGlobal[i]] = fromA[i];
736     }
737     for (int i = 0; i < nVarB; i++) {
738       to[indexBToGlobal[i]] += fromB[i];
739     }
740   }
741 
742   private class EnergyRegion extends ParallelRegion {
743 
744     private final EnergyASection sectA;
745     private final EnergyBSection sectB;
746     private double[] x;
747     private double[] g;
748     private boolean gradient = false;
749 
750     EnergyRegion() {
751       sectA = new EnergyASection();
752       sectB = new EnergyBSection();
753     }
754 
755     @Override
756     public void finish() {
757       totalEnergy = energyA + energyB;
758       if (gradient) {
759         addDoublesFrom(g, gA, gB);
760         dEdL = dEdL_A + dEdL_B;
761         d2EdL2 = d2EdL2_A + d2EdL2_B;
762       }
763       gradient = false;
764     }
765 
766     @Override
767     public void run() throws Exception {
768       execute(sectA, sectB);
769     }
770 
771     public void setG(double[] g) {
772       this.g = g;
773       setGradient(true);
774     }
775 
776     public void setGradient(boolean gradient) {
777       this.gradient = gradient;
778       sectA.setGradient(gradient);
779       sectB.setGradient(gradient);
780     }
781 
782     public void setVerbose(boolean verbose) {
783       sectA.setVerbose(verbose);
784       sectB.setVerbose(verbose);
785     }
786 
787     public void setX(double[] x) {
788       this.x = x;
789     }
790 
791     @Override
792     public void start() {
793       doublesTo(x, xA, xB);
794     }
795   }
796 
797   private class EnergyASection extends ParallelSection {
798 
799     private boolean verbose = false;
800     private boolean gradient = false;
801 
802     @Override
803     public void run() throws Exception {
804       if (gradient) {
805         energyA = dualTopA.energyAndGradient(xA, gA, verbose);
806         dEdL_A = linterA.getdEdL();
807         d2EdL2_A = linterA.getd2EdL2();
808       } else {
809         energyA = dualTopA.energy(xA, verbose);
810       }
811       this.verbose = false;
812       this.gradient = false;
813     }
814 
815     public void setGradient(boolean gradient) {
816       this.gradient = gradient;
817     }
818 
819     public void setVerbose(boolean verbose) {
820       this.verbose = verbose;
821     }
822   }
823 
824   private class EnergyBSection extends ParallelSection {
825 
826     private boolean verbose = false;
827     private boolean gradient = false;
828 
829     @Override
830     public void run() throws Exception {
831       if (gradient) {
832         energyB = dualTopB.energyAndGradient(xB, gB, verbose);
833         dEdL_B = linterB.getdEdL();
834         d2EdL2_B = linterB.getd2EdL2();
835       } else {
836         energyB = dualTopB.energy(xB, verbose);
837       }
838       this.verbose = false;
839       this.gradient = false;
840     }
841 
842     public void setGradient(boolean gradient) {
843       this.gradient = gradient;
844     }
845 
846     public void setVerbose(boolean verbose) {
847       this.verbose = verbose;
848     }
849   }
850 }