View Javadoc
1   package ffx.potential.bonded;
2   
3   import ffx.numerics.atomic.AtomicDoubleArray3D;
4   import ffx.potential.parameters.TorsionType;
5   
6   import java.io.Serial;
7   import java.util.Arrays;
8   import java.util.function.DoubleUnaryOperator;
9   
10  import static org.apache.commons.math3.util.FastMath.*;
11  
12  public class RestraintTorsion extends BondedTerm implements LambdaInterface {
13  
14    @Serial
15    private static final long serialVersionUID = 1L;
16  
17    private final Atom[] atoms;
18    public final TorsionType torsionType;
19    private final boolean lambdaTerm;
20    private final DoubleUnaryOperator lamMapper;
21    public final double units;
22  
23    private double lambda = 1.0;
24    private double dEdL = 0.0;
25    private double d2EdL2 = 0.0;
26  
27    public RestraintTorsion(Atom a1, Atom a2, Atom a3, Atom a4, TorsionType tType, boolean lamEnabled,
28        boolean revLam, double units) {
29      atoms = new Atom[] {a1, a2, a3, a4};
30      this.torsionType = tType;
31      lambdaTerm = lamEnabled;
32      if (this.lambdaTerm) {
33        if (revLam) {
34          lamMapper = (double l) -> 1.0 - l;
35        } else {
36          lamMapper = (double l) -> l;
37        }
38      } else {
39        lamMapper = (double l) -> 1.0;
40      }
41      this.units = units;
42    }
43  
44    @Override
45    public double energy(boolean gradient, int threadID, AtomicDoubleArray3D grad, AtomicDoubleArray3D lambdaGrad) {
46      energy = 0.0;
47      value = 0.0;
48      dEdL = 0.0;
49      // Only compute this term if at least one atom is being used.
50      if (!getUse()) {
51        return energy;
52      }
53      var atomA = atoms[0];
54      var atomB = atoms[1];
55      var atomC = atoms[2];
56      var atomD = atoms[3];
57      var va = atomA.getXYZ();
58      var vb = atomB.getXYZ();
59      var vc = atomC.getXYZ();
60      var vd = atomD.getXYZ();
61      var vba = vb.sub(va);
62      var vcb = vc.sub(vb);
63      var vdc = vd.sub(vc);
64      var vt = vba.X(vcb);
65      var vu = vcb.X(vdc);
66      var rt2 = vt.length2();
67      var ru2 = vu.length2();
68      var rtru2 = rt2 * ru2;
69      if (rtru2 != 0.0) {
70        var rr = sqrt(rtru2);
71        var rcb = vcb.length();
72        var cosine = vt.dot(vu) / rr;
73        var sine = vcb.dot(vt.X(vu)) / (rcb * rr);
74        value = toDegrees(acos(cosine));
75        if (sine < 0.0) {
76          value = -value;
77        }
78        var amp = torsionType.amplitude;
79        var tsin = torsionType.sine;
80        var tcos = torsionType.cosine;
81        energy = amp[0] * (1.0 + cosine * tcos[0] + sine * tsin[0]);
82        var dedphi = amp[0] * (cosine * tsin[0] - sine * tcos[0]);
83        var cosprev = cosine;
84        var sinprev = sine;
85        var n = torsionType.terms;
86        for (int i = 1; i < n; i++) {
87          var cosn = cosine * cosprev - sine * sinprev;
88          var sinn = sine * cosprev + cosine * sinprev;
89          var phi = 1.0 + cosn * tcos[i] + sinn * tsin[i];
90          var dphi = (1.0 + i) * (cosn * tsin[i] - sinn * tcos[i]);
91          energy = energy + amp[i] * phi;
92          dedphi = dedphi + amp[i] * dphi;
93          cosprev = cosn;
94          sinprev = sinn;
95        }
96        energy = units * energy * lambda;
97        dEdL = units * energy;
98        if (gradient || lambdaTerm) {
99          dedphi = units * dedphi;
100         var vca = vc.sub(va);
101         var vdb = vd.sub(vb);
102         var dedt = vt.X(vcb).scaleI(dedphi / (rt2 * rcb));
103         var dedu = vu.X(vcb).scaleI(-dedphi / (ru2 * rcb));
104         var ga = dedt.X(vcb);
105         var gb = vca.X(dedt).addI(dedu.X(vdc));
106         var gc = dedt.X(vba).addI(vdb.X(dedu));
107         var gd = dedu.X(vcb);
108         int ia = atomA.getIndex() - 1;
109         int ib = atomB.getIndex() - 1;
110         int ic = atomC.getIndex() - 1;
111         int id = atomD.getIndex() - 1;
112         if (lambdaTerm) {
113           lambdaGrad.add(threadID, ia, ga);
114           lambdaGrad.add(threadID, ib, gb);
115           lambdaGrad.add(threadID, ic, gc);
116           lambdaGrad.add(threadID, id, gd);
117         }
118         if (gradient) {
119           grad.add(threadID, ia, ga.scaleI(lambda));
120           grad.add(threadID, ib, gb.scaleI(lambda));
121           grad.add(threadID, ic, gc.scaleI(lambda));
122           grad.add(threadID, id, gd.scaleI(lambda));
123         }
124       }
125     }
126 
127     return energy;
128   }
129 
130   @Override
131   public double getLambda() {
132     return lambda;
133   }
134 
135   public Atom[] getAtoms() {
136     return Arrays.copyOf(atoms, atoms.length);
137   }
138 
139   @Override
140   public Atom getAtom(int index) {
141     return atoms[index];
142   }
143 
144   @Override
145   public boolean applyLambda() {
146     return lambdaTerm;
147   }
148 
149   @Override
150   public void setLambda(double lambda) {
151     this.lambda = lamMapper.applyAsDouble(lambda);
152   }
153 
154   @Override
155   public double getd2EdL2() {
156     return d2EdL2;
157   }
158 
159   @Override
160   public double getdEdL() {
161     return dEdL;
162   }
163 
164   public double mapLambda(double lambda) {
165     return lamMapper.applyAsDouble(lambda);
166   }
167 
168   @Override
169   public void getdEdXdL(double[] gradient) {
170     // The chain rule term is at least supposedly zero.
171   }
172 
173   @Override
174   public String toString() {
175     return String.format(" t-type %s, val %.3f, e %.3f", torsionType, value, energy);
176   }
177 }