1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 package ffx.potential.nonbonded;
39
40 import ffx.potential.MolecularAssembly;
41 import ffx.potential.bonded.Atom;
42 import ffx.potential.bonded.LambdaInterface;
43 import ffx.potential.bonded.MSNode;
44 import ffx.potential.bonded.Polymer;
45 import ffx.potential.parameters.ForceField;
46
47 import java.util.List;
48 import java.util.logging.Logger;
49
50 import static ffx.numerics.math.DoubleMath.length2;
51 import static java.util.Arrays.fill;
52 import static org.apache.commons.math3.util.FastMath.pow;
53
54
55
56
57
58
59
60 public class COMRestraint implements LambdaInterface {
61
62 private static final Logger logger = Logger.getLogger(COMRestraint.class.getName());
63 private final Atom[] atoms;
64 private final int nAtoms;
65 private final Polymer[] polymers;
66 private final List<MSNode> molecules;
67 private final List<MSNode> water;
68 private final List<MSNode> ions;
69 private final int nMolecules;
70
71
72
73 private final double forceConstant;
74
75 private final double[][] initialCOM;
76 private final double[][] currentCOM;
77 private final double[] dx = new double[3];
78 private final double[] dcomdx;
79 private final double lambdaExp = 1.0;
80 private final double[] lambdaGradient;
81 private double lambda = 1.0;
82 private double lambdaPow = pow(lambda, lambdaExp);
83 private double dLambdaPow = lambdaExp * pow(lambda, lambdaExp - 1.0);
84 private double d2LambdaPow = 0;
85 private double dEdL = 0.0;
86 private double d2EdL2 = 0.0;
87 private boolean lambdaTerm;
88
89
90
91
92
93
94
95 public COMRestraint(MolecularAssembly molecularAssembly) {
96 this.atoms = molecularAssembly.getAtomArray();
97 nAtoms = atoms.length;
98 this.polymers = molecularAssembly.getChains();
99 this.molecules = molecularAssembly.getMolecules();
100 this.water = molecularAssembly.getWater();
101 this.ions = molecularAssembly.getIons();
102 ForceField forceField = molecularAssembly.getForceField();
103
104 nMolecules = countMolecules();
105 initialCOM = new double[3][nMolecules];
106 currentCOM = new double[3][nMolecules];
107
108
109 lambdaTerm = false;
110
111 if (lambdaTerm) {
112 lambdaGradient = new double[nAtoms * 3];
113 } else {
114 lambdaGradient = null;
115 lambda = 1.0;
116 lambdaPow = 1.0;
117 dLambdaPow = 0.0;
118 d2LambdaPow = 0.0;
119 }
120 dcomdx = new double[nAtoms];
121 forceConstant = forceField.getDouble("COMRESTRAINT_K", 10.0);
122
123 computeCOM(initialCOM, nMolecules);
124
125 logger.info("\n COM restraint initialized");
126 }
127
128
129
130
131 @Override
132 public double getLambda() {
133 return lambda;
134 }
135
136
137
138
139 @Override
140 public void setLambda(double lambda) {
141 if (lambdaTerm) {
142 this.lambda = lambda;
143
144 double lambdaWindow = 1.0;
145
146 if (this.lambda <= lambdaWindow) {
147 double dldgl = 1.0 / lambdaWindow;
148 double l = dldgl * this.lambda;
149 double l2 = l * l;
150 double l3 = l2 * l;
151 double l4 = l2 * l2;
152 double l5 = l4 * l;
153 double c3 = 10.0;
154 double c4 = -15.0;
155 double c5 = 6.0;
156 double threeC3 = 3.0 * c3;
157 double sixC3 = 6.0 * c3;
158 double fourC4 = 4.0 * c4;
159 double twelveC4 = 12.0 * c4;
160 double fiveC5 = 5.0 * c5;
161 double twentyC5 = 20.0 * c5;
162 lambdaPow = c3 * l3 + c4 * l4 + c5 * l5;
163 dLambdaPow = (threeC3 * l2 + fourC4 * l3 + fiveC5 * l4) * dldgl;
164 d2LambdaPow = (sixC3 * l + twelveC4 * l2 + twentyC5 * l3) * dldgl * dldgl;
165 } else {
166 lambdaPow = 1.0;
167 dLambdaPow = 0.0;
168 d2LambdaPow = 0.0;
169 }
170 } else {
171 this.lambda = 1.0;
172 lambdaPow = 1.0;
173 dLambdaPow = 0.0;
174 d2LambdaPow = 0.0;
175 }
176 }
177
178
179
180
181 @Override
182 public double getd2EdL2() {
183 if (lambdaTerm) {
184 return d2EdL2;
185 } else {
186 return 0.0;
187 }
188 }
189
190
191
192
193 @Override
194 public double getdEdL() {
195 if (lambdaTerm) {
196 return dEdL;
197 } else {
198 return 0.0;
199 }
200 }
201
202
203
204
205 @Override
206 public void getdEdXdL(double[] gradient) {
207 if (lambdaTerm) {
208 int n3 = nAtoms * 3;
209 for (int i = 0; i < n3; i++) {
210 gradient[i] += lambdaGradient[i];
211 }
212 }
213 }
214
215
216
217
218
219
220
221
222 public double residual(boolean gradient, boolean print) {
223 if (lambdaTerm) {
224 dEdL = 0.0;
225 d2EdL2 = 0.0;
226 fill(lambdaGradient, 0.0);
227 }
228 double residual = 0.0;
229 double fx2 = forceConstant * 2.0;
230 computeCOM(currentCOM, nMolecules);
231 computedcomdx();
232 for (int i = 0; i < nMolecules; i++) {
233 dx[0] = currentCOM[0][i] - initialCOM[0][i];
234 dx[1] = currentCOM[1][i] - initialCOM[1][i];
235 dx[2] = currentCOM[2][i] - initialCOM[2][i];
236
237 double r2 = length2(dx);
238 residual += r2;
239 for (int j = 0; j < nAtoms; j++) {
240 if (gradient || lambdaTerm) {
241 final double dedx = dx[0] * fx2 * dcomdx[j];
242 final double dedy = dx[1] * fx2 * dcomdx[j];
243 final double dedz = dx[2] * fx2 * dcomdx[j];
244
245 Atom atom = atoms[j];
246 if (gradient) {
247 atom.addToXYZGradient(lambdaPow * dedx, lambdaPow * dedy, lambdaPow * dedz);
248 }
249 if (lambdaTerm) {
250 int j3 = i * 3;
251 lambdaGradient[j3] = dLambdaPow * dedx;
252 lambdaGradient[j3 + 1] = dLambdaPow * dedy;
253 lambdaGradient[j3 + 2] = dLambdaPow * dedz;
254 }
255 }
256 }
257 }
258 if (lambdaTerm) {
259 dEdL = dLambdaPow * forceConstant * residual;
260 d2EdL2 = d2LambdaPow * forceConstant * residual;
261 }
262 return forceConstant * residual * lambdaPow;
263 }
264
265
266
267
268
269
270 public void setLambdaTerm(boolean lambdaTerm) {
271 this.lambdaTerm = lambdaTerm;
272 setLambda(lambda);
273 }
274
275 private void computeCOM(double[][] com, int nMolecules) {
276 int i = 0;
277 while (i < nMolecules) {
278 if (polymers != null && polymers.length > 0) {
279
280 for (Polymer polymer : polymers) {
281 List<Atom> list = polymer.getAtomList();
282 com[0][i] = 0.0;
283 com[1][i] = 0.0;
284 com[2][i] = 0.0;
285 double totalMass = 0.0;
286 for (Atom atom : list) {
287 double m = atom.getMass();
288 com[0][i] += atom.getX() * m;
289 com[1][i] += atom.getY() * m;
290 com[2][i] += atom.getZ() * m;
291 totalMass += m;
292 }
293 com[0][i] /= totalMass;
294 com[1][i] /= totalMass;
295 com[2][i] /= totalMass;
296 i++;
297 }
298 }
299
300
301 for (MSNode molecule : molecules) {
302 List<Atom> list = molecule.getAtomList();
303
304 com[0][i] = 0.0;
305 com[1][i] = 0.0;
306 com[2][i] = 0.0;
307 double totalMass = 0.0;
308 for (Atom atom : list) {
309 double m = atom.getMass();
310 com[0][i] += atom.getX() * m;
311 com[1][i] += atom.getY() * m;
312 com[2][i] += atom.getZ() * m;
313 totalMass += m;
314 }
315 com[0][i] /= totalMass;
316 com[1][i] /= totalMass;
317 com[2][i] /= totalMass;
318 i++;
319 }
320
321
322 for (MSNode water : water) {
323 List<Atom> list = water.getAtomList();
324
325 com[0][i] = 0.0;
326 com[1][i] = 0.0;
327 com[2][i] = 0.0;
328 double totalMass = 0.0;
329 for (Atom atom : list) {
330 double m = atom.getMass();
331 com[0][i] += atom.getX() * m;
332 com[1][i] += atom.getY() * m;
333 com[2][i] += atom.getZ() * m;
334 totalMass += m;
335 }
336 com[0][i] /= totalMass;
337 com[1][i] /= totalMass;
338 com[2][i] /= totalMass;
339 i++;
340 }
341
342
343 for (MSNode ion : ions) {
344 List<Atom> list = ion.getAtomList();
345
346 com[0][i] = 0.0;
347 com[1][i] = 0.0;
348 com[2][i] = 0.0;
349 double totalMass = 0.0;
350 for (Atom atom : list) {
351 double m = atom.getMass();
352 com[0][i] += atom.getX() * m;
353 com[1][i] += atom.getY() * m;
354 com[2][i] += atom.getZ() * m;
355 totalMass += m;
356 }
357 com[0][i] /= totalMass;
358 com[1][i] /= totalMass;
359 com[2][i] /= totalMass;
360 i++;
361 }
362 }
363 }
364
365 private void computedcomdx() {
366
367 int i = 0;
368 while (i < nAtoms) {
369 if (polymers != null && polymers.length > 0) {
370 for (Polymer polymer : polymers) {
371 List<Atom> list = polymer.getAtomList();
372 double totalMass = 0.0;
373 for (Atom atom : list) {
374 double m = atom.getMass();
375 totalMass += m;
376 }
377 for (Atom atom : list) {
378 dcomdx[i] = atom.getMass();
379 dcomdx[i] /= totalMass;
380 i++;
381 }
382 }
383 }
384
385
386 for (MSNode molecule : molecules) {
387 List<Atom> list = molecule.getAtomList();
388 double totalMass = 0.0;
389 for (Atom atom : list) {
390 double m = atom.getMass();
391 totalMass += m;
392 }
393 for (Atom atom : list) {
394 dcomdx[i] = atom.getMass();
395 dcomdx[i] /= totalMass;
396 i++;
397 }
398 }
399
400
401 for (MSNode water : water) {
402 List<Atom> list = water.getAtomList();
403 double totalMass = 0.0;
404 for (Atom atom : list) {
405 double m = atom.getMass();
406 totalMass += m;
407 }
408 for (Atom atom : list) {
409 dcomdx[i] = atom.getMass();
410 dcomdx[i] /= totalMass;
411 i++;
412 }
413 }
414
415
416 for (MSNode ion : ions) {
417 List<Atom> list = ion.getAtomList();
418 double totalMass = 0.0;
419 for (Atom atom : list) {
420 double m = atom.getMass();
421 totalMass += m;
422 }
423 for (Atom atom : list) {
424 dcomdx[i] = atom.getMass();
425 dcomdx[i] /= totalMass;
426 i++;
427 }
428 }
429 }
430
431
432
433
434
435 }
436
437 private int countMolecules() {
438 int count = 0;
439
440 if (polymers != null && polymers.length > 0) {
441 count += polymers.length;
442 }
443 if (molecules != null) {
444 count += molecules.size();
445 }
446 if (water != null) {
447 count += water.size();
448 }
449 if (ions != null) {
450 count += ions.size();
451 }
452 return count;
453 }
454 }