View Javadoc
1   //******************************************************************************
2   //
3   // Title:       Force Field X.
4   // Description: Force Field X - Software for Molecular Biophysics.
5   // Copyright:   Copyright (c) Michael J. Schnieders 2001-2025.
6   //
7   // This file is part of Force Field X.
8   //
9   // Force Field X is free software; you can redistribute it and/or modify it
10  // under the terms of the GNU General Public License version 3 as published by
11  // the Free Software Foundation.
12  //
13  // Force Field X is distributed in the hope that it will be useful, but WITHOUT
14  // ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
15  // FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
16  // details.
17  //
18  // You should have received a copy of the GNU General Public License along with
19  // Force Field X; if not, write to the Free Software Foundation, Inc., 59 Temple
20  // Place, Suite 330, Boston, MA 02111-1307 USA
21  //
22  // Linking this library statically or dynamically with other modules is making a
23  // combined work based on this library. Thus, the terms and conditions of the
24  // GNU General Public License cover the whole combination.
25  //
26  // As a special exception, the copyright holders of this library give you
27  // permission to link this library with independent modules to produce an
28  // executable, regardless of the license terms of these independent modules, and
29  // to copy and distribute the resulting executable under terms of your choice,
30  // provided that you also meet, for each linked independent module, the terms
31  // and conditions of the license of that module. An independent module is a
32  // module which is not derived from or based on this library. If you modify this
33  // library, you may extend this exception to your version of the library, but
34  // you are not obligated to do so. If you do not wish to do so, delete this
35  // exception statement from your version.
36  //
37  //******************************************************************************
38  package ffx.algorithms;
39  
40  import edu.rit.mp.DoubleBuf;
41  import edu.rit.mp.IntegerBuf;
42  import edu.rit.pj.Comm;
43  import ffx.crystal.Crystal;
44  import ffx.crystal.CrystalPotential;
45  import ffx.numerics.Potential;
46  import ffx.potential.MolecularAssembly;
47  import ffx.potential.Utilities;
48  import ffx.potential.bonded.LambdaInterface;
49  import ffx.potential.parsers.SystemFilter;
50  import org.apache.commons.configuration2.CompositeConfiguration;
51  
52  import java.io.File;
53  import java.util.logging.Logger;
54  
55  import static java.lang.String.format;
56  import static java.lang.System.arraycopy;
57  
58  /**
59   * The ParallelStateEnergy class evaluates the energy of a system at different lambda values.
60   *
61   * @since 1.0
62   * @author Michael J. Schnieders
63   */
64  public class ParallelStateEnergy {
65  
66    private static final Logger logger = Logger.getLogger(ParallelStateEnergy.class.getName());
67  
68    /**
69     * Parallel Java world communicator.
70     */
71    private Comm world;
72    /**
73     * If false, do not use MPI communication.
74     */
75    private final boolean useMPI;
76    /**
77     * Number of processes.
78     */
79    private int numProc;
80    /**
81     * Rank of this process.
82     */
83    private int rank;
84    /**
85     * Number of states.
86     */
87    private int nStates;
88    /**
89     * Lambda value for each state.
90     */
91    private double[] lambdaValues;
92    /**
93     * The amount of work based on windows for each process.
94     */
95    private int statesPerProcess;
96  
97    /**
98     * Number of samples for each state. The array is of size [statesPerProcess].
99     */
100   private final int[] nSamples;
101   /**
102    * The energy from evaluating at L - dL. The array is of size [statesPerProcess][snapshots].
103    */
104   private final double[][] energiesLowPJ;
105   /**
106    * The energy from evaluating at L. The array is of size [statesPerProcess][snapshots].
107    */
108   private final double[][] energiesAtPJ;
109   /**
110    * The energy from evaluating at L + dL. The array is of size [statesPerProcess][snapshots].
111    */
112   private final double[][] energiesHighPJ;
113   /**
114    * The volume of each snapshot. The array is of size [statesPerProcess][snapshots].
115    */
116   private final double[][] volumePJ;
117   /**
118    * Number of samples for each state. The array is of size [statesPerProcess].
119    */
120   private final IntegerBuf bufferNSamples;
121   /**
122    * Convenience reference for the DoubleBuf of this process.
123    */
124   private DoubleBuf bufferLow;
125   /**
126    * Convenience reference for the DoubleBuf of this process.
127    */
128   private DoubleBuf bufferAt;
129   /**
130    * Convenience reference for the DoubleBuf of this process.
131    */
132   private DoubleBuf bufferHigh;
133   /**
134    * Convenience reference for the DoubleBuf of this process.
135    */
136   private DoubleBuf bufferVolume;
137 
138   /**
139    * The MolecularAssembly for each topology.
140    */
141   private final MolecularAssembly[] molecularAssemblies;
142   /**
143    * The SystemFilter for each topology.
144    */
145   private final SystemFilter[] openers;
146   /**
147    * The potential to evaluate.
148    */
149   private final Potential potential;
150   /**
151    * The full file paths for each state.
152    */
153   private final String[][] fullFilePaths;
154 
155   /**
156    * The ParallelEnergy constructor.
157    *
158    * @param nStates             The number of states.
159    * @param lambdaValues        The lambda values.
160    * @param molecularAssemblies The molecular assemblies.
161    * @param potential           The potential to evaluate.
162    * @param fullFilePaths       The full file paths for each state.
163    * @param openers             The system filters.
164    */
165   public ParallelStateEnergy(int nStates, double[] lambdaValues,
166                              MolecularAssembly[] molecularAssemblies, Potential potential,
167                              String[][] fullFilePaths, SystemFilter[] openers) {
168 
169     this.nStates = nStates;
170     this.lambdaValues = lambdaValues;
171     this.molecularAssemblies = molecularAssemblies;
172     this.potential = potential;
173     this.fullFilePaths = fullFilePaths;
174     this.openers = openers;
175 
176     // Default to a single process that processes all states.
177     world = Comm.world();
178     numProc = 1;
179     rank = 0;
180     statesPerProcess = nStates;
181 
182     // Determine if use of PJ is specified.
183     CompositeConfiguration properties = molecularAssemblies[0].getProperties();
184     useMPI = properties.getBoolean("pj.use.mpi", true);
185     if (useMPI) {
186       // Number of processes.
187       numProc = world.size();
188       // Each processor gets its own rank.
189       rank = world.rank();
190 
191       // Padding of the target array size (inner loop limit) is for parallelization.
192       // Target states are parallelized over available nodes.
193       // For example, if numProc = 8 and nStates = 12, then paddednWindows = 16.
194       int extra = nStates % numProc;
195       int paddednWindows = nStates;
196       if (extra != 0) {
197         paddednWindows = nStates - extra + numProc;
198       }
199       statesPerProcess = paddednWindows / numProc;
200 
201       if (numProc > 1) {
202         logger.fine(format(" Number of MPI Processes:  %d", numProc));
203         logger.fine(format(" Rank of this MPI Process: %d", rank));
204         logger.fine(format(" States per process per row: %d", statesPerProcess));
205       }
206     }
207 
208     // Initialize arrays for storing energy values.
209     nSamples = new int[statesPerProcess];
210     energiesLowPJ = new double[statesPerProcess][];
211     energiesAtPJ = new double[statesPerProcess][];
212     energiesHighPJ = new double[statesPerProcess][];
213     volumePJ = new double[statesPerProcess][];
214     bufferNSamples = IntegerBuf.buffer(nSamples);
215     bufferLow = DoubleBuf.buffer(energiesLowPJ);
216     bufferAt = DoubleBuf.buffer(energiesAtPJ);
217     bufferHigh = DoubleBuf.buffer(energiesHighPJ);
218     bufferVolume = DoubleBuf.buffer(volumePJ);
219   }
220 
221   /**
222    * Get the rank of this process.
223    *
224    * @return The rank.
225    */
226   public int getRank() {
227     return rank;
228   }
229 
230   /**
231    * Evaluate the energies for each state.
232    *
233    * @param energiesLow  The energy from evaluating at L - dL.
234    * @param energiesAt   The energy from evaluating at L.
235    * @param energiesHigh The energy from evaluating at L + dL.
236    * @param volume       The volume of each snapshot.
237    */
238   public void evaluateStates(double[][] energiesLow,
239                              double[][] energiesAt,
240                              double[][] energiesHigh,
241                              double[][] volume) {
242     double[] currentLambdas;
243     double[][] energy;
244     int nCurrLambdas;
245     for (int state = 0; state < nStates; state++) {
246       if (state == 0) {
247         currentLambdas = new double[2];
248         currentLambdas[0] = lambdaValues[state];
249         currentLambdas[1] = lambdaValues[state + 1];
250       } else if (state == nStates - 1) {
251         currentLambdas = new double[2];
252         currentLambdas[0] = lambdaValues[state - 1];
253         currentLambdas[1] = lambdaValues[state];
254       } else {
255         currentLambdas = new double[3];
256         currentLambdas[0] = lambdaValues[state - 1];
257         currentLambdas[1] = lambdaValues[state];
258         currentLambdas[2] = lambdaValues[state + 1];
259       }
260       nCurrLambdas = currentLambdas.length;
261       energy = new double[nCurrLambdas][];
262       evaluateEnergies(state, currentLambdas, energy, fullFilePaths);
263     }
264 
265     gatherAllValues(energiesLow, energiesAt, energiesHigh, volume);
266   }
267 
268 
269   /**
270    * Evaluate the energies for a given state.
271    *
272    * @param state         The state.
273    * @param lambdaValues  The current lambda values.
274    * @param energy        The energy values.
275    * @param fullFilePaths The full file paths.
276    */
277   private void evaluateEnergies(int state, double[] lambdaValues,
278                                 double[][] energy, String[][] fullFilePaths) {
279     if (state % numProc == rank) {
280       int workItem = state / numProc;
281       double[] vol = getEnergyForLambdas(lambdaValues, fullFilePaths[state], energy);
282       int len = energy[0].length;
283       nSamples[workItem] = len;
284       energiesLowPJ[workItem] = new double[len];
285       energiesAtPJ[workItem] = new double[len];
286       energiesHighPJ[workItem] = new double[len];
287       volumePJ[workItem] = new double[len];
288       if (state == 0) {
289         arraycopy(energy[0], 0, energiesAtPJ[workItem], 0, len);
290         arraycopy(energy[1], 0, energiesHighPJ[workItem], 0, len);
291       } else if (state < nStates - 1) {
292         arraycopy(energy[0], 0, energiesLowPJ[workItem], 0, len);
293         arraycopy(energy[1], 0, energiesAtPJ[workItem], 0, len);
294         arraycopy(energy[2], 0, energiesHighPJ[workItem], 0, len);
295       } else if (state == nStates - 1) {
296         arraycopy(energy[0], 0, energiesLowPJ[workItem], 0, len);
297         arraycopy(energy[1], 0, energiesAtPJ[workItem], 0, len);
298       }
299       if (vol != null) {
300         arraycopy(vol, 0, volumePJ[workItem], 0, len);
301       }
302     }
303   }
304 
305   /**
306    * Gather all energy values from all nodes.
307    * This method calls <code>world.gather</code> to collect numProc values.
308    *
309    * @param energiesLow  The energy from evaluating at L - dL.
310    * @param energiesAt   The energy from evaluating at L.
311    * @param energiesHigh The energy from evaluating at L + dL.
312    * @param volume       The volume of each snapshot.
313    */
314   private void gatherAllValues(double[][] energiesLow,
315                                double[][] energiesAt,
316                                double[][] energiesHigh,
317                                double[][] volume) {
318     if (useMPI) {
319       try {
320         if (rank != 0) {
321           // Send all results to node 0.
322           world.send(0, bufferNSamples);
323           for (int workItem = 0; workItem < statesPerProcess; workItem++) {
324             bufferLow = DoubleBuf.buffer(energiesLowPJ[workItem]);
325             bufferAt = DoubleBuf.buffer(energiesAtPJ[workItem]);
326             bufferHigh = DoubleBuf.buffer(energiesHighPJ[workItem]);
327             bufferVolume = DoubleBuf.buffer(volumePJ[workItem]);
328             world.send(0, bufferLow);
329             world.send(0, bufferAt);
330             world.send(0, bufferHigh);
331             world.send(0, bufferVolume);
332           }
333         } else {
334           for (int proc = 0; proc < numProc; proc++) {
335             // Receive all results from another node.
336             if (proc > 0) {
337               // Receive the number of samples for each state.
338               world.receive(proc, bufferNSamples);
339             }
340             // Store results in the appropriate arrays.
341             for (int workItem = 0; workItem < statesPerProcess; workItem++) {
342               final int state = numProc * workItem + proc;
343               final int nSnapshots = nSamples[workItem];
344               // Do not include padded results.
345               if (state < nStates) {
346                 if (proc > 0) {
347                   // Ensure array size.
348                   updateMemory(workItem, nSnapshots);
349                   world.receive(proc, bufferLow);
350                   world.receive(proc, bufferAt);
351                   world.receive(proc, bufferHigh);
352                   world.receive(proc, bufferVolume);
353                 }
354                 energiesLow[state] = new double[nSnapshots];
355                 energiesAt[state] = new double[nSnapshots];
356                 energiesHigh[state] = new double[nSnapshots];
357                 volume[state] = new double[nSnapshots];
358                 arraycopy(energiesLowPJ[workItem], 0, energiesLow[state], 0, nSnapshots);
359                 arraycopy(energiesAtPJ[workItem], 0, energiesAt[state], 0, nSnapshots);
360                 arraycopy(energiesHighPJ[workItem], 0, energiesHigh[state], 0, nSnapshots);
361                 arraycopy(volumePJ[workItem], 0, volume[state], 0, nSnapshots);
362               }
363             }
364           }
365         }
366         // Wait for all processes to finish.
367         world.barrier();
368       } catch (Exception ex) {
369         logger.severe(" Exception collecting energy values." + ex + Utilities.stackTraceToString(ex));
370       }
371     } else {
372       for (int i = 0; i < nStates; i++) {
373         int len = energiesAtPJ[rank].length;
374         energiesLow[i] = new double[len];
375         energiesAt[i] = new double[len];
376         energiesHigh[i] = new double[len];
377         volume[i] = new double[len];
378         arraycopy(energiesLowPJ[i], 0, energiesLow[i], 0, len);
379         arraycopy(energiesAtPJ[i], 0, energiesAt[i], 0, len);
380         arraycopy(energiesHighPJ[i], 0, energiesHigh[i], 0, len);
381         arraycopy(volumePJ[i], 0, volume[i], 0, len);
382       }
383     }
384   }
385 
386   /**
387    * Compute energy values for each lambda value.
388    *
389    * @param lambdaValues The lambda values.
390    * @param arcFileName  The archive file names.
391    * @param energy       The energy values.
392    * @return The volume of each snapshot.
393    */
394   private double[] getEnergyForLambdas(double[] lambdaValues,
395                                        String[] arcFileName, double[][] energy) {
396 
397     int numTopologies = molecularAssemblies.length;
398 
399     // Initialize the potential to use the correct archive files.
400     StringBuilder sb = new StringBuilder("\n");
401     for (int j = 0; j < numTopologies; j++) {
402       File archiveFile = new File(arcFileName[j]);
403       openers[j].setFile(archiveFile);
404       molecularAssemblies[j].setFile(archiveFile);
405       sb.append(format(" Evaluating energies for file: %s\n", arcFileName[j]));
406     }
407     sb.append("\n");
408     logger.info(sb.toString());
409 
410     int nSnapshots = openers[0].countNumModels();
411     double[] x = new double[potential.getNumberOfVariables()];
412     double[] vol = new double[nSnapshots];
413     int nLambdas = lambdaValues.length;
414     for (int k = 0; k < nLambdas; k++) {
415       energy[k] = new double[nSnapshots];
416     }
417 
418     LambdaInterface linter1 = (LambdaInterface) potential;
419 
420     int endWindow = nStates - 1;
421     String endWindows = endWindow + File.separator;
422 
423     if (arcFileName[0].contains(endWindows)) {
424       logger.info(format(" %s     %s   %s     %s   %s ", "Snapshot", "Lambda Low",
425           "Energy Low", "Lambda At", "Energy At"));
426     } else if (arcFileName[0].contains("0/")) {
427       logger.info(format(" %s     %s   %s     %s   %s ", "Snapshot", "Lambda At",
428           "Energy At", "Lambda High", "Energy High"));
429     } else {
430       logger.info(format(" %s     %s   %s     %s   %s     %s   %s ", "Snapshot", "Lambda Low",
431           "Energy Low", "Lambda At", "Energy At", "Lambda High", "Energy High"));
432     }
433 
434     for (int i = 0; i < nSnapshots; i++) {
435       boolean resetPosition = (i == 0);
436       int nOpeners = openers.length;
437       for (int n = 0; n < nOpeners; n++) {
438         openers[n].readNext(resetPosition, false);
439       }
440 
441       x = potential.getCoordinates(x);
442       nLambdas = lambdaValues.length;
443       for (int k = 0; k < nLambdas; k++) {
444         double lambda = lambdaValues[k];
445         linter1.setLambda(lambda);
446         energy[k][i] = potential.energy(x, false);
447       }
448 
449       if (nLambdas == 2) {
450         logger.info(format(" %8d     %6.3f   %14.4f     %6.3f   %14.4f ", i + 1,
451             lambdaValues[0], energy[0][i], lambdaValues[1], energy[1][i]));
452       } else {
453         logger.info(format(" %8d     %6.3f   %14.4f     %6.3f   %14.4f     %6.3f   %14.4f ", i + 1,
454             lambdaValues[0], energy[0][i], lambdaValues[1], energy[1][i], lambdaValues[2], energy[2][i]));
455       }
456 
457       Crystal unitCell;
458       if (potential instanceof CrystalPotential) {
459         unitCell = ((CrystalPotential) potential).getCrystal().getUnitCell();
460       } else {
461         unitCell = molecularAssemblies[0].getCrystal().getUnitCell();
462       }
463 
464       if (!unitCell.aperiodic()) {
465         int nSymm = unitCell.getNumSymOps();
466         vol[i] = unitCell.volume / nSymm;
467       }
468     }
469 
470     return vol;
471   }
472 
473   /**
474    * Update the memory to receive energy values given the number of snapshots.
475    *
476    * @param workItem   The work item to update.
477    * @param nSnapshots The number of snapshots.
478    */
479   private void updateMemory(int workItem, int nSnapshots) {
480     if (energiesAtPJ[workItem].length < nSnapshots) {
481       energiesLowPJ[workItem] = new double[nSnapshots];
482       energiesAtPJ[workItem] = new double[nSnapshots];
483       energiesHighPJ[workItem] = new double[nSnapshots];
484       volumePJ[workItem] = new double[nSnapshots];
485     }
486     bufferLow = DoubleBuf.buffer(energiesLowPJ[workItem]);
487     bufferAt = DoubleBuf.buffer(energiesAtPJ[workItem]);
488     bufferHigh = DoubleBuf.buffer(energiesHighPJ[workItem]);
489     bufferVolume = DoubleBuf.buffer(volumePJ[workItem]);
490   }
491 }