1 package ffx.numerics.estimator;
2
import ffx.numerics.estimator.MultistateBennettAcceptanceRatio.SeedType;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.logging.Logger;
import java.util.regex.Pattern;

import static org.apache.commons.lang3.math.NumberUtils.max;
import static org.apache.commons.lang3.math.NumberUtils.min;
15
16
17
18
19
20
21
22
23
24
25
26
27
28 public class MBARFilter {
29 private static final Logger logger = Logger.getLogger(MBARFilter.class.getName());
30 private File[] barFiles;
31 private final File fileLocation;
32 private double[][][] eAll;
33 private double[][] eAllFlat;
34 private double[] temperatures;
35 private int[] snaps;
36 private final int[] numLambdas;
37 private int windowsRead;
38 private int windows;
39 private MultistateBennettAcceptanceRatio mbar;
40 private int startIndex = -1;
41 private int endIndex = -1;
42 private int numLambda;
43 private boolean oneFile = false;
44
45
46
47
48
49
50
51 public MBARFilter(File fileLocation, boolean continuousLambda) {
52 this.fileLocation = fileLocation;
53 barFiles = fileLocation.listFiles((dir, name) -> name.matches("energy_\\d+.mbar") || name.matches("energy_\\d+.bar"));
54 assert barFiles != null;
55 if (barFiles.length == 0) {
56 logger.severe(" No files matching 'energy_\\d+.mbar' or 'energy_\\d+.bar' found in " +
57 fileLocation.getAbsolutePath());
58 }
59
60 Arrays.sort(barFiles, (f1, f2) -> {
61 int state1 = Integer.parseInt(f1.getName().split("\\.")[0].split("_")[1]);
62 int state2 = Integer.parseInt(f2.getName().split("\\.")[0].split("_")[1]);
63 return Integer.compare(state1, state2);
64 });
65 windows = barFiles.length;
66 temperatures = new double[windows];
67 snaps = new int[windows];
68 numLambdas = new int[windows];
69 if (continuousLambda) {
70 this.parseFile();
71 oneFile = true;
72 } else {
73 this.parseFiles();
74 }
75 }
76
77
78
79
80
81
82
83 public MultistateBennettAcceptanceRatio getMBAR(SeedType seedType) {
84 return getMBAR(seedType, 1e-7);
85 }
86
87
88
89
90
91
92
93
94 public MultistateBennettAcceptanceRatio getMBAR(SeedType seedType, double tolerance) {
95 double[] lambda = new double[windows];
96 for (int i = 0; i < windows; i++) {
97 lambda[i] = i / (windows - 1.0);
98 }
99 if (eAll != null) {
100 this.mbar = new MultistateBennettAcceptanceRatio(lambda, eAll, temperatures, tolerance, seedType);
101 } else {
102 this.mbar = new MultistateBennettAcceptanceRatio(lambda, snaps, eAllFlat, temperatures, tolerance, seedType);
103 }
104 return this.mbar;
105 }
106
107
108
109
110
111
112
113
114 public MultistateBennettAcceptanceRatio[] getPeriodComparisonMBAR(SeedType seedType, double tolerance) {
115 double[] lambda = new double[windows];
116 for (int i = 0; i < windows; i++) {
117 lambda[i] = i / (windows - 1.0);
118 }
119 MultistateBennettAcceptanceRatio[] mbar = new MultistateBennettAcceptanceRatio[10];
120 for (int i = 0; i < 10; i++) {
121 double[][][] e = new double[windows][][];
122 int maxSamples = max(snaps);
123 int timePeriod = maxSamples / 10;
124 for (int j = 0; j < windows; j++) {
125 e[j] = new double[windows][];
126 for (int k = 0; k < windows; k++) {
127 e[j][k] = new double[timePeriod];
128 if (timePeriod * (i + 1) > maxSamples) {
129 System.arraycopy(eAll[j][k], timePeriod * i, e[j][k], 0, maxSamples - timePeriod * i);
130 } else {
131 System.arraycopy(eAll[j][k], timePeriod * i, e[j][k], 0, timePeriod);
132 }
133 }
134 }
135 logger.info(" Period: " + (timePeriod * i) + " - " + (timePeriod * (i + 1)) + " samples calculation.");
136 mbar[i] = new MultistateBennettAcceptanceRatio(lambda, e, temperatures, tolerance, seedType);
137 }
138 return mbar;
139 }
140
141
142
143
144 private void parseFiles() {
145 eAll = new double[windows][][];
146 for (int i = 0; i < windows; i++) {
147 eAll[i] = readFile(barFiles[i].getName(), i);
148 }
149 if (windowsRead != windows) {
150 logger.severe("Failed to read all files in " + fileLocation.getAbsolutePath());
151 }
152 int minSnaps = min(snaps);
153 int maxSnaps = max(snaps);
154
155
156 boolean warn = minSnaps != maxSnaps;
157 if (warn) {
158 logger.warning("NOT ALL FILES CONTAINED THE SAME NUMBER OF SNAPSHOTS. ");
159 logger.warning("SAMPLES PER WINDOW: " + Arrays.toString(snaps));
160 double[][][] temp = new double[eAll.length][eAll[0].length][maxSnaps];
161 for (int j = 0; j < windows; j++) {
162 for (int k = 0; k < windows; k++) {
163 System.arraycopy(eAll[j][k], 0, temp[j][k], 0, snaps[j]);
164 for (int l = snaps[j]; l < maxSnaps; l++) {
165 temp[j][k][l] = Double.NaN;
166 }
167 }
168 }
169 eAll = temp;
170 }
171
172 int maxLambdas = max(numLambdas);
173 for (int i = 0; i < windows; i++) {
174 if (numLambdas[i] != maxLambdas) {
175 logger.severe(" Number of lambda evaluations in file " + barFiles[i].getName() +
176 " does not match the number of lambda evaluations in the other files. This is unrecoverable.");
177 }
178 }
179
180 warn = maxLambdas != windows;
181 if (warn) {
182 String symbol = maxLambdas > windows ? "MORE" : "LESS";
183 logger.warning("FILES CONTAIN " + symbol + " LAMBDA EVALUATIONS THAN ACTUAL TRAJECTORIES.");
184 if (windows == 1) {
185 logger.warning(" USE --continuousLambda FLAG IF USING A SINGLE FILE.");
186 }
187 symbol = maxLambdas > windows ? "Add" : "Remove";
188 logger.severe(symbol + " completely empty files (zero lines) to fill in the gaps.");
189 }
190 }
191
192
193
194
195 private void parseFile() {
196 eAllFlat = readFile(barFiles[0].getName(), 0);
197
198 snaps = new int[numLambda];
199 for (int i = 0; i < numLambda; i++) {
200 snaps[i] = eAllFlat[i].length;
201 }
202 windows = numLambda;
203 double temp = temperatures[0];
204 temperatures = new double[windows];
205 for (int i = 0; i < windows; i++) {
206 temperatures[i] = temp;
207 }
208 }
209
210
211
212
213
214
215
216
217 public void writeFiles(File mbarFileLoc, double[][][] energies, double[] temperatures) {
218 if (temperatures.length != windows) {
219 double temp = temperatures[0];
220 temperatures = new double[windows];
221 for (int i = 0; i < windows; i++) {
222 temperatures[i] = temp;
223 }
224 }
225 for (int i = 0; i < windows; i++) {
226 File file = new File(mbarFileLoc, "energy_" + i + ".mbar");
227 writeFile(energies[i], file, temperatures[i]);
228 }
229 }
230
231
232
233
234
235
236
237 private double[][] readFile(String fileName, int state) {
238 File tempBarFile = new File(fileLocation, fileName);
239 ArrayList<ArrayList<Double>> tempFileEnergies = new ArrayList<>();
240 for (int i = 0; i < windows; i++) {
241 tempFileEnergies.add(new ArrayList<>());
242 }
243 double[][] fileEnergies;
244 try (FileReader fr1 = new FileReader(tempBarFile);
245 BufferedReader br1 = new BufferedReader(fr1);) {
246
247 String line = br1.readLine();
248 if (line == null) {
249 for (int i = 0; i < windows; i++) {
250 tempFileEnergies.get(i).add(Double.NaN);
251 }
252 snaps[state] = 0;
253 temperatures[state] = 298;
254 MultistateBennettAcceptanceRatio.FORCE_ZEROS_SEED = true;
255 if (state != 0) {
256 numLambdas[state] = numLambdas[state - 1];
257 }
258 windowsRead++;
259 fileEnergies = new double[windows][];
260 for (int i = 0; i < windows; i++) {
261 fileEnergies[i] = new double[tempFileEnergies.get(i).size()];
262 for (int j = 0; j < tempFileEnergies.get(i).size(); j++) {
263 fileEnergies[i][j] = tempFileEnergies.get(i).get(j);
264 }
265 }
266 return fileEnergies;
267 }
268 String[] tokens = line.trim().split("\\t *| +");
269 temperatures[state] = Double.parseDouble(tokens[1]);
270
271 int count = 0;
272 numLambda = 0;
273 line = br1.readLine();
274 while (line != null) {
275 tokens = line.trim().split("\\t *| +");
276 numLambda = tokens.length - 1;
277 for (int i = 1; i < tokens.length; i++) {
278 if (tempFileEnergies.size() < i) {
279 tempFileEnergies.add(new ArrayList<>());
280 }
281 tempFileEnergies.get(i - 1).add(Double.parseDouble(tokens[i]));
282 }
283 count++;
284 line = br1.readLine();
285 }
286 numLambdas[state] = numLambda;
287 if (state != 0 && numLambdas[0] == 0) {
288 numLambdas[0] = numLambda;
289 }
290 snaps[state] = count;
291 } catch (IOException e) {
292 logger.info("Failed to read MBAR file: " + tempBarFile.getAbsolutePath());
293 throw new RuntimeException(e);
294 }
295
296 fileEnergies = new double[tempFileEnergies.size()][];
297 for (int i = 0; i < tempFileEnergies.size(); i++) {
298 fileEnergies[i] = new double[tempFileEnergies.get(i).size()];
299 for (int j = 0; j < tempFileEnergies.get(i).size(); j++) {
300 fileEnergies[i][j] = tempFileEnergies.get(i).get(j);
301 }
302 }
303 windowsRead++;
304 return fileEnergies;
305 }
306
307
308
309
310
311
312
313
314 public void writeFile(double[][] energies, File file, double temperature) {
315 MultistateBennettAcceptanceRatio.writeFile(energies, file, temperature);
316 }
317
318
319
320
321
322
323 public void setStartSnapshot(int startIndex) {
324 this.startIndex = startIndex;
325 if (oneFile) {
326 for (int i = 0; i < eAllFlat.length; i++) {
327 try {
328 eAllFlat[i] = Arrays.copyOfRange(eAllFlat[i], startIndex, eAllFlat[i].length);
329 } catch (ArrayIndexOutOfBoundsException e) {
330 logger.severe("Start index " + startIndex + " is out of bounds for file " + barFiles[i].getName());
331 }
332 }
333 } else {
334 for (int i = 0; i < eAll.length; i++) {
335 for (int j = 0; j < eAll[0].length; j++) {
336 try {
337 eAll[i][j] = Arrays.copyOfRange(eAll[i][j], startIndex, eAll[i][j].length);
338 } catch (ArrayIndexOutOfBoundsException e) {
339 logger.severe("Start index " + startIndex + " is out of bounds for file " + barFiles[i].getName());
340 }
341 }
342 }
343 }
344 }
345
346
347
348
349
350
351 public void setEndSnapshot(int endIndex) {
352 this.endIndex = endIndex;
353 if (oneFile) {
354 for (int i = 0; i < eAllFlat.length; i++) {
355 try {
356 eAllFlat[i] = Arrays.copyOfRange(eAllFlat[i], 0, endIndex);
357 } catch (ArrayIndexOutOfBoundsException e) {
358 logger.severe("End index " + endIndex + " is out of bounds for file " + barFiles[i].getName());
359 }
360 }
361 } else {
362 for (int i = 0; i < eAll.length; i++) {
363 for (int j = 0; j < eAll[0].length; j++) {
364 try {
365 eAll[i][j] = Arrays.copyOfRange(eAll[i][j], 0, endIndex);
366 } catch (ArrayIndexOutOfBoundsException e) {
367 logger.severe("End index " + endIndex + " is out of bounds for file " + barFiles[i].getName());
368 }
369 }
370 }
371 }
372 }
373
374
375
376
377
378
379
380
381
382 public boolean readObservableData(boolean multiDataObservable, boolean isBiasData, boolean isDerivativeData) {
383 if (isDerivativeData && !isBiasData) {
384 barFiles = fileLocation.listFiles((dir, name) -> name.matches("derivative_\\d+.mbar") ||
385 name.matches("derivative_\\d+.bar") ||
386 name.matches("derivatives_\\d+.mbar") ||
387 name.matches("derivatives_\\d+.bar") ||
388 name.matches("observable_\\d+.mbar") ||
389 name.matches("observable_\\d+.bar"));
390 } else if (isBiasData) {
391 barFiles = fileLocation.listFiles((dir, name) -> name.matches("bias_\\d+.mbar") ||
392 name.matches("bias_\\d+.bar"));
393 }
394 if (barFiles == null || barFiles.length == 0) {
395 return false;
396 }
397
398 Arrays.sort(barFiles, (f1, f2) -> {
399 int state1 = Integer.parseInt(f1.getName().split("\\.")[0].split("_")[1]);
400 int state2 = Integer.parseInt(f2.getName().split("\\.")[0].split("_")[1]);
401 return Integer.compare(state1, state2);
402 });
403 if (oneFile) {
404 eAllFlat = readFile(barFiles[0].getName(), 0);
405 multiDataObservable = eAllFlat.length != numLambda;
406 } else {
407 eAll = new double[windows][][];
408 for (int i = 0; i < windows; i++) {
409 eAll[i] = readFile(barFiles[i].getName(), i);
410 }
411 int minSnaps = min(snaps);
412 int maxSnaps = max(snaps);
413
414
415 boolean warn = minSnaps != maxSnaps;
416 if (warn) {
417 double[][][] temp = new double[eAll.length][eAll[0].length][maxSnaps];
418 for (int j = 0; j < windows; j++) {
419 for (int k = 0; k < windows; k++) {
420 System.arraycopy(eAll[j][k], 0, temp[j][k], 0, snaps[j]);
421 for (int l = snaps[j]; l < maxSnaps; l++) {
422 temp[j][k][l] = Double.NaN;
423 }
424 }
425 }
426 eAll = temp;
427 }
428 }
429
430 if (this.startIndex != -1) {
431 this.setStartSnapshot(this.startIndex);
432 }
433 if (this.endIndex != -1) {
434 this.setEndSnapshot(this.endIndex);
435 }
436
437
438 if (isBiasData) {
439 if (!oneFile) {
440 mbar.setBiasData(eAll, multiDataObservable);
441 } else {
442 mbar.setBiasData(eAllFlat);
443 }
444 } else {
445 if (!oneFile) {
446 mbar.setObservableData(eAll, multiDataObservable, false);
447 } else {
448 mbar.setObservableData(eAllFlat, false);
449 }
450 }
451
452 return true;
453 }
454 }