View Javadoc
1   package ffx.numerics.estimator;
2   
3   import ffx.numerics.estimator.MultistateBennettAcceptanceRatio.SeedType;
4   
5   import java.io.BufferedReader;
6   import java.io.File;
7   import java.io.FileReader;
8   import java.io.IOException;
9   import java.util.ArrayList;
10  import java.util.Arrays;
11  import java.util.logging.Logger;
12  
13  import static org.apache.commons.lang3.math.NumberUtils.max;
14  import static org.apache.commons.lang3.math.NumberUtils.min;
15  
16  /**
17   * The MBARFilter class parses mbar (*.mbar or *.bar) files. Expected file format is a header
18   * line including the number of snapshots contained, a name, and a temperature. Following the header
19   * is a list of energies for each snapshot at each lambda value being considered with an index to start
20   * the line. Then the energies go from least (0) to greatest (1) lambda value.
21   * <p>
22   * Files less than numLambda states are handled. Users should simply generate MBAR files with desired number of lambda
23   * windows, this filter should handle the rest and warn about potential issues.
24   *
25   * @author Matthew J. Speranza
26   * @since 1.0
27   */
28  public class MBARFilter {
29    private static final Logger logger = Logger.getLogger(MBARFilter.class.getName());
30    private File[] barFiles;
31    private final File fileLocation;
32    private double[][][] eAll;
33    private double[][] eAllFlat;
34    private double[] temperatures;
35    private int[] snaps;
36    private final int[] numLambdas;
37    private int windowsRead;
38    private int windows;
39    private MultistateBennettAcceptanceRatio mbar;
40    private int startIndex = -1;
41    private int endIndex = -1;
42    private int numLambda;
43    private boolean oneFile = false;
44  
45    /**
46     * Constructor for MBARFilter.
47     *
48     * @param fileLocation     the directory containing the mbar files
49     * @param continuousLambda whether a single file is being used.
50     */
51    public MBARFilter(File fileLocation, boolean continuousLambda) {
52      this.fileLocation = fileLocation;
53      barFiles = fileLocation.listFiles((dir, name) -> name.matches("energy_\\d+.mbar") || name.matches("energy_\\d+.bar"));
54      assert barFiles != null;
55      if (barFiles.length == 0) {
56        logger.severe(" No files matching 'energy_\\d+.mbar' or 'energy_\\d+.bar' found in " +
57            fileLocation.getAbsolutePath());
58      }
59      // Sort files by state number
60      Arrays.sort(barFiles, (f1, f2) -> {
61        int state1 = Integer.parseInt(f1.getName().split("\\.")[0].split("_")[1]);
62        int state2 = Integer.parseInt(f2.getName().split("\\.")[0].split("_")[1]);
63        return Integer.compare(state1, state2);
64      });
65      windows = barFiles.length;
66      temperatures = new double[windows];
67      snaps = new int[windows];
68      numLambdas = new int[windows];
69      if (continuousLambda) {
70        this.parseFile();
71        oneFile = true;
72      } else {
73        this.parseFiles();
74      }
75    }
76  
77    /**
78     * Create an MBAR instance with the given seed type.
79     *
80     * @param seedType the seed type to use.
81     * @return an MBAR instance.
82     */
83    public MultistateBennettAcceptanceRatio getMBAR(SeedType seedType) {
84      return getMBAR(seedType, 1e-7);
85    }
86  
87    /**
88     * Create an MBAR instance with the given seed type and tolerance.
89     *
90     * @param seedType  the seed type to use.
91     * @param tolerance the tolerance to use.
92     * @return an MBAR instance.
93     */
94    public MultistateBennettAcceptanceRatio getMBAR(SeedType seedType, double tolerance) {
95      double[] lambda = new double[windows];
96      for (int i = 0; i < windows; i++) {
97        lambda[i] = i / (windows - 1.0);
98      }
99      if (eAll != null) {
100       this.mbar = new MultistateBennettAcceptanceRatio(lambda, eAll, temperatures, tolerance, seedType);
101     } else {
102       this.mbar = new MultistateBennettAcceptanceRatio(lambda, snaps, eAllFlat, temperatures, tolerance, seedType);
103     }
104     return this.mbar;
105   }
106 
107   /**
108    * 10% of the total samples at different time points.
109    *
110    * @param seedType  the seed type to use.
111    * @param tolerance the tolerance to use.
112    * @return an array of MBAR objects
113    */
114   public MultistateBennettAcceptanceRatio[] getPeriodComparisonMBAR(SeedType seedType, double tolerance) {
115     double[] lambda = new double[windows];
116     for (int i = 0; i < windows; i++) {
117       lambda[i] = i / (windows - 1.0);
118     }
119     MultistateBennettAcceptanceRatio[] mbar = new MultistateBennettAcceptanceRatio[10];
120     for (int i = 0; i < 10; i++) {
121       double[][][] e = new double[windows][][];
122       int maxSamples = max(snaps);
123       int timePeriod = maxSamples / 10;
124       for (int j = 0; j < windows; j++) {
125         e[j] = new double[windows][];
126         for (int k = 0; k < windows; k++) {
127           e[j][k] = new double[timePeriod];
128           if (timePeriod * (i + 1) > maxSamples) {
129             System.arraycopy(eAll[j][k], timePeriod * i, e[j][k], 0, maxSamples - timePeriod * i);
130           } else {
131             System.arraycopy(eAll[j][k], timePeriod * i, e[j][k], 0, timePeriod);
132           }
133         }
134       }
135       logger.info(" Period: " + (timePeriod * i) + " - " + (timePeriod * (i + 1)) + " samples calculation.");
136       mbar[i] = new MultistateBennettAcceptanceRatio(lambda, e, temperatures, tolerance, seedType);
137     }
138     return mbar;
139   }
140 
141   /**
142    * Parse all files in the directory.
143    */
144   private void parseFiles() {
145     eAll = new double[windows][][];
146     for (int i = 0; i < windows; i++) {
147       eAll[i] = readFile(barFiles[i].getName(), i);
148     }
149     if (windowsRead != windows) {
150       logger.severe("Failed to read all files in " + fileLocation.getAbsolutePath());
151     }
152     int minSnaps = min(snaps);
153     int maxSnaps = max(snaps);
154 
155     // Basically just make sure eAll isn't jagged
156     boolean warn = minSnaps != maxSnaps;
157     if (warn) {
158       logger.warning("NOT ALL FILES CONTAINED THE SAME NUMBER OF SNAPSHOTS. ");
159       logger.warning("SAMPLES PER WINDOW: " + Arrays.toString(snaps));
160       double[][][] temp = new double[eAll.length][eAll[0].length][maxSnaps];
161       for (int j = 0; j < windows; j++) {
162         for (int k = 0; k < windows; k++) {
163           System.arraycopy(eAll[j][k], 0, temp[j][k], 0, snaps[j]);
164           for (int l = snaps[j]; l < maxSnaps; l++) { // Fill in the rest with NaNs
165             temp[j][k][l] = Double.NaN;
166           }
167         }
168       }
169       eAll = temp;
170     }
171     // Fail if not all files have the same number of energy evaluations across lambda
172     int maxLambdas = max(numLambdas);
173     for (int i = 0; i < windows; i++) {
174       if (numLambdas[i] != maxLambdas) {
175         logger.severe(" Number of lambda evaluations in file " + barFiles[i].getName() +
176             " does not match the number of lambda evaluations in the other files. This is unrecoverable.");
177       }
178     }
179     // Handle files with more lambda windows than actual trajectories
180     warn = maxLambdas != windows;
181     if (warn) {
182       String symbol = maxLambdas > windows ? "MORE" : "LESS";
183       logger.warning("FILES CONTAIN " + symbol + " LAMBDA EVALUATIONS THAN ACTUAL TRAJECTORIES.");
184       if (windows == 1) {
185         logger.warning(" USE --continuousLambda FLAG IF USING A SINGLE FILE.");
186       }
187       symbol = maxLambdas > windows ? "Add" : "Remove";
188       logger.severe(symbol + " completely empty files (zero lines) to fill in the gaps.");
189     }
190   }
191 
192   /**
193    * Parse a single file.
194    */
195   private void parseFile() {
196     eAllFlat = readFile(barFiles[0].getName(), 0);
197     // Reset variables
198     snaps = new int[numLambda];
199     for (int i = 0; i < numLambda; i++) {
200       snaps[i] = eAllFlat[i].length;
201     }
202     windows = numLambda;
203     double temp = temperatures[0];
204     temperatures = new double[windows];
205     for (int i = 0; i < windows; i++) {
206       temperatures[i] = temp;
207     }
208   }
209 
210   /**
211    * Write the energies to files.
212    *
213    * @param mbarFileLoc  the directory to write the files to.
214    * @param energies     the energies to write.
215    * @param temperatures the temperatures to write.
216    */
217   public void writeFiles(File mbarFileLoc, double[][][] energies, double[] temperatures) {
218     if (temperatures.length != windows) {
219       double temp = temperatures[0];
220       temperatures = new double[windows];
221       for (int i = 0; i < windows; i++) {
222         temperatures[i] = temp;
223       }
224     }
225     for (int i = 0; i < windows; i++) {
226       File file = new File(mbarFileLoc, "energy_" + i + ".mbar");
227       writeFile(energies[i], file, temperatures[i]);
228     }
229   }
230 
231   /**
232    * Parses the file matching the name given in the directory of 'fileLocation'.
233    *
234    * @param fileName the name of the file to be parsed matching 'energy_\d+.mbar' or 'energy_\d+.bar'.
235    * @return a double[][] of the energies for each snapshot at each lambda value
236    */
237   private double[][] readFile(String fileName, int state) {
238     File tempBarFile = new File(fileLocation, fileName);
239     ArrayList<ArrayList<Double>> tempFileEnergies = new ArrayList<>();
240     for (int i = 0; i < windows; i++) {
241       tempFileEnergies.add(new ArrayList<>());
242     }
243     double[][] fileEnergies;
244     try (FileReader fr1 = new FileReader(tempBarFile);
245          BufferedReader br1 = new BufferedReader(fr1);) {
246       // Read header
247       String line = br1.readLine();
248       if (line == null) { // Empty file
249         for (int i = 0; i < windows; i++) {
250           tempFileEnergies.get(i).add(Double.NaN);
251         }
252         snaps[state] = 0;
253         temperatures[state] = 298; // Assumed default temp since 0 leads to division by zero
254         MultistateBennettAcceptanceRatio.FORCE_ZEROS_SEED = true;
255         if (state != 0) {
256           numLambdas[state] = numLambdas[state - 1];
257         }
258         windowsRead++;
259         fileEnergies = new double[windows][];
260         for (int i = 0; i < windows; i++) {
261           fileEnergies[i] = new double[tempFileEnergies.get(i).size()];
262           for (int j = 0; j < tempFileEnergies.get(i).size(); j++) {
263             fileEnergies[i][j] = tempFileEnergies.get(i).get(j);
264           }
265         }
266         return fileEnergies;
267       }
268       String[] tokens = line.trim().split("\\t *| +");
269       temperatures[state] = Double.parseDouble(tokens[1]);
270       // Read energies (however many there are)
271       int count = 0;
272       numLambda = 0;
273       line = br1.readLine();
274       while (line != null) {
275         tokens = line.trim().split("\\t *| +");
276         numLambda = tokens.length - 1;
277         for (int i = 1; i < tokens.length; i++) {
278           if (tempFileEnergies.size() < i) {
279             tempFileEnergies.add(new ArrayList<>());
280           }
281           tempFileEnergies.get(i - 1).add(Double.parseDouble(tokens[i]));
282         }
283         count++;
284         line = br1.readLine();
285       }
286       numLambdas[state] = numLambda;
287       if (state != 0 && numLambdas[0] == 0) { // If the zeroth window is missing this wasn't set yet
288         numLambdas[0] = numLambda;
289       }
290       snaps[state] = count;
291     } catch (IOException e) {
292       logger.info("Failed to read MBAR file: " + tempBarFile.getAbsolutePath());
293       throw new RuntimeException(e);
294     }
295     // Convert to double[][]
296     fileEnergies = new double[tempFileEnergies.size()][];
297     for (int i = 0; i < tempFileEnergies.size(); i++) {
298       fileEnergies[i] = new double[tempFileEnergies.get(i).size()];
299       for (int j = 0; j < tempFileEnergies.get(i).size(); j++) {
300         fileEnergies[i][j] = tempFileEnergies.get(i).get(j);
301       }
302     }
303     windowsRead++;
304     return fileEnergies;
305   }
306 
307   /**
308    * Write the energies to a file.
309    *
310    * @param energies    the energies to write
311    * @param file        the file to write to
312    * @param temperature the temperature to write
313    */
314   public void writeFile(double[][] energies, File file, double temperature) {
315     MultistateBennettAcceptanceRatio.writeFile(energies, file, temperature);
316   }
317 
318   /**
319    * Set the start snapshot.
320    *
321    * @param startIndex the start index.
322    */
323   public void setStartSnapshot(int startIndex) {
324     this.startIndex = startIndex;
325     if (oneFile) {
326       for (int i = 0; i < eAllFlat.length; i++) {
327         try {
328           eAllFlat[i] = Arrays.copyOfRange(eAllFlat[i], startIndex, eAllFlat[i].length);
329         } catch (ArrayIndexOutOfBoundsException e) {
330           logger.severe("Start index " + startIndex + " is out of bounds for file " + barFiles[i].getName());
331         }
332       }
333     } else {
334       for (int i = 0; i < eAll.length; i++) {
335         for (int j = 0; j < eAll[0].length; j++) {
336           try {
337             eAll[i][j] = Arrays.copyOfRange(eAll[i][j], startIndex, eAll[i][j].length);
338           } catch (ArrayIndexOutOfBoundsException e) {
339             logger.severe("Start index " + startIndex + " is out of bounds for file " + barFiles[i].getName());
340           }
341         }
342       }
343     }
344   }
345 
346   /**
347    * Set the end snapshot.
348    *
349    * @param endIndex the end index.
350    */
351   public void setEndSnapshot(int endIndex) {
352     this.endIndex = endIndex;
353     if (oneFile) {
354       for (int i = 0; i < eAllFlat.length; i++) {
355         try {
356           eAllFlat[i] = Arrays.copyOfRange(eAllFlat[i], 0, endIndex);
357         } catch (ArrayIndexOutOfBoundsException e) {
358           logger.severe("End index " + endIndex + " is out of bounds for file " + barFiles[i].getName());
359         }
360       }
361     } else {
362       for (int i = 0; i < eAll.length; i++) {
363         for (int j = 0; j < eAll[0].length; j++) {
364           try {
365             eAll[i][j] = Arrays.copyOfRange(eAll[i][j], 0, endIndex);
366           } catch (ArrayIndexOutOfBoundsException e) {
367             logger.severe("End index " + endIndex + " is out of bounds for file " + barFiles[i].getName());
368           }
369         }
370       }
371     }
372   }
373 
374   /**
375    * Read in observable data, try to leave as many fields in-tact as possible.
376    *
377    * @param multiDataObservable whether the observable data is multi-data.
378    * @param isBiasData          whether the data is bias data.
379    * @param isDerivativeData    whether the data is derivative data.
380    * @return whether the data was read successfully.
381    */
382   public boolean readObservableData(boolean multiDataObservable, boolean isBiasData, boolean isDerivativeData) {
383     if (isDerivativeData && !isBiasData) {
384       barFiles = fileLocation.listFiles((dir, name) -> name.matches("derivative_\\d+.mbar") ||
385           name.matches("derivative_\\d+.bar") ||
386           name.matches("derivatives_\\d+.mbar") ||
387           name.matches("derivatives_\\d+.bar") ||
388           name.matches("observable_\\d+.mbar") ||
389           name.matches("observable_\\d+.bar"));
390     } else if (isBiasData) {
391       barFiles = fileLocation.listFiles((dir, name) -> name.matches("bias_\\d+.mbar") ||
392           name.matches("bias_\\d+.bar"));
393     }
394     if (barFiles == null || barFiles.length == 0) {
395       return false;
396     }
397     // Sort files by state number
398     Arrays.sort(barFiles, (f1, f2) -> {
399       int state1 = Integer.parseInt(f1.getName().split("\\.")[0].split("_")[1]);
400       int state2 = Integer.parseInt(f2.getName().split("\\.")[0].split("_")[1]);
401       return Integer.compare(state1, state2);
402     });
403     if (oneFile) {
404       eAllFlat = readFile(barFiles[0].getName(), 0);
405       multiDataObservable = eAllFlat.length != numLambda;
406     } else {
407       eAll = new double[windows][][];
408       for (int i = 0; i < windows; i++) {
409         eAll[i] = readFile(barFiles[i].getName(), i);
410       }
411       int minSnaps = min(snaps);
412       int maxSnaps = max(snaps);
413 
414       // Basically just make sure eAll isn't jagged
415       boolean warn = minSnaps != maxSnaps;
416       if (warn) {
417         double[][][] temp = new double[eAll.length][eAll[0].length][maxSnaps];
418         for (int j = 0; j < windows; j++) {
419           for (int k = 0; k < windows; k++) {
420             System.arraycopy(eAll[j][k], 0, temp[j][k], 0, snaps[j]);
421             for (int l = snaps[j]; l < maxSnaps; l++) { // Fill in the rest with NaNs
422               temp[j][k][l] = Double.NaN;
423             }
424           }
425         }
426         eAll = temp;
427       }
428     }
429     // Apply cutoffs set by user
430     if (this.startIndex != -1) {
431       this.setStartSnapshot(this.startIndex);
432     }
433     if (this.endIndex != -1) {
434       this.setEndSnapshot(this.endIndex);
435     }
436 
437     // Set observable data and compute observable averages
438     if (isBiasData) {
439       if (!oneFile) {
440         mbar.setBiasData(eAll, multiDataObservable);
441       } else {
442         mbar.setBiasData(eAllFlat);
443       }
444     } else {
445       if (!oneFile) {
446         mbar.setObservableData(eAll, multiDataObservable, false);
447       } else {
448         mbar.setObservableData(eAllFlat, false);
449       }
450     }
451 
452     return true;
453   }
454 }