View Javadoc
1   package ffx.potential.bonded;
2   
3   import ffx.numerics.atomic.AtomicDoubleArray3D;
4   import ffx.potential.parameters.TorsionType;
5   
6   import java.util.Arrays;
7   import java.util.function.DoubleUnaryOperator;
8   
9   import static org.apache.commons.math3.util.FastMath.*;
10  
11  public class RestraintTorsion extends BondedTerm implements LambdaInterface {
12  
13      private final Atom[] atoms;
14      public final TorsionType torsionType;
15      private final boolean lambdaTerm;
16      private final DoubleUnaryOperator lamMapper;
17      public final double units;
18  
19      private double lambda = 1.0;
20      private double dEdL = 0.0;
21      private double d2EdL2 = 0.0;
22  
23      public RestraintTorsion(Atom a1, Atom a2, Atom a3, Atom a4, TorsionType tType, boolean lamEnabled, boolean revLam, double units) {
24          atoms = new Atom[]{a1, a2, a3, a4};
25          this.torsionType = tType;
26          lambdaTerm = lamEnabled;
27          if (this.lambdaTerm) {
28              if (revLam) {
29                  lamMapper = (double l) -> 1.0 - l;
30              } else {
31                  lamMapper = (double l) -> l;
32              }
33          } else {
34              lamMapper = (double l) -> 1.0;
35          }
36          this.units = units;
37      }
38  
39      @Override
40      public double energy(boolean gradient, int threadID, AtomicDoubleArray3D grad, AtomicDoubleArray3D lambdaGrad) {
41          energy = 0.0;
42          value = 0.0;
43          dEdL = 0.0;
44          var atomA = atoms[0];
45          var atomB = atoms[1];
46          var atomC = atoms[2];
47          var atomD = atoms[3];
48          var va = atomA.getXYZ();
49          var vb = atomB.getXYZ();
50          var vc = atomC.getXYZ();
51          var vd = atomD.getXYZ();
52          var vba = vb.sub(va);
53          var vcb = vc.sub(vb);
54          var vdc = vd.sub(vc);
55          var vt = vba.X(vcb);
56          var vu = vcb.X(vdc);
57          var rt2 = vt.length2();
58          var ru2 = vu.length2();
59          var rtru2 = rt2 * ru2;
60          if (rtru2 != 0.0) {
61              var rr = sqrt(rtru2);
62              var rcb = vcb.length();
63              var cosine = vt.dot(vu) / rr;
64              var sine = vcb.dot(vt.X(vu)) / (rcb * rr);
65              value = toDegrees(acos(cosine));
66              if (sine < 0.0) {
67                  value = -value;
68              }
69              var amp = torsionType.amplitude;
70              var tsin = torsionType.sine;
71              var tcos = torsionType.cosine;
72              energy = amp[0] * (1.0 + cosine * tcos[0] + sine * tsin[0]);
73              var dedphi = amp[0] * (cosine * tsin[0] - sine * tcos[0]);
74              var cosprev = cosine;
75              var sinprev = sine;
76              var n = torsionType.terms;
77              for (int i = 1; i < n; i++) {
78                  var cosn = cosine * cosprev - sine * sinprev;
79                  var sinn = sine * cosprev + cosine * sinprev;
80                  var phi = 1.0 + cosn * tcos[i] + sinn * tsin[i];
81                  var dphi = (1.0 + i) * (cosn * tsin[i] - sinn * tcos[i]);
82                  energy = energy + amp[i] * phi;
83                  dedphi = dedphi + amp[i] * dphi;
84                  cosprev = cosn;
85                  sinprev = sinn;
86              }
87              energy = units * energy * lambda;
88              dEdL = units * energy;
89              if (gradient || lambdaTerm) {
90                  dedphi = units * dedphi;
91                  var vca = vc.sub(va);
92                  var vdb = vd.sub(vb);
93                  var dedt = vt.X(vcb).scaleI(dedphi / (rt2 * rcb));
94                  var dedu = vu.X(vcb).scaleI(-dedphi / (ru2 * rcb));
95                  var ga = dedt.X(vcb);
96                  var gb = vca.X(dedt).addI(dedu.X(vdc));
97                  var gc = dedt.X(vba).addI(vdb.X(dedu));
98                  var gd = dedu.X(vcb);
99                  int ia = atomA.getIndex() - 1;
100                 int ib = atomB.getIndex() - 1;
101                 int ic = atomC.getIndex() - 1;
102                 int id = atomD.getIndex() - 1;
103                 if (lambdaTerm) {
104                     lambdaGrad.add(threadID, ia, ga);
105                     lambdaGrad.add(threadID, ib, gb);
106                     lambdaGrad.add(threadID, ic, gc);
107                     lambdaGrad.add(threadID, id, gd);
108                 }
109                 if (gradient) {
110                     grad.add(threadID, ia, ga.scaleI(lambda));
111                     grad.add(threadID, ib, gb.scaleI(lambda));
112                     grad.add(threadID, ic, gc.scaleI(lambda));
113                     grad.add(threadID, id, gd.scaleI(lambda));
114                 }
115             }
116         }
117 
118         return energy;
119     }
120 
121     @Override
122     public double getLambda() {
123         return lambda;
124     }
125 
126     public Atom[] getAtoms() {
127         return Arrays.copyOf(atoms, atoms.length);
128     }
129 
130     @Override
131     public Atom getAtom(int index) {
132         return atoms[index];
133     }
134 
135     @Override
136     public boolean applyLambda() {
137         return lambdaTerm;
138     }
139 
140     @Override
141     public void setLambda(double lambda) {
142         this.lambda = lamMapper.applyAsDouble(lambda);
143     }
144 
145     @Override
146     public double getd2EdL2() {
147         return d2EdL2;
148     }
149 
150     @Override
151     public double getdEdL() {
152         return dEdL;
153     }
154 
155     public double mapLambda(double lambda) {
156         return lamMapper.applyAsDouble(lambda);
157     }
158 
159     @Override
160     public void getdEdXdL(double[] gradient) {
161         // The chain rule term is at least supposedly zero.
162     }
163 
164     @Override
165     public String toString() {
166         return String.format(" t-type %s, val %.3f, e %.3f", torsionType, value, energy);
167     }
168 }