/*
 * Decompiled with CFR 0.152.
 */
package ffx.algorithms.thermodynamics;

import edu.rit.mp.Buf;
import edu.rit.mp.LongBuf;
import edu.rit.mp.buf.LongItemBuf;
import edu.rit.pj.Comm;
import ffx.algorithms.cli.DynamicsOptions;
import ffx.algorithms.cli.OSTOptions;
import ffx.algorithms.dynamics.MDWriteAction;
import ffx.algorithms.dynamics.MolecularDynamics;
import ffx.algorithms.mc.BoltzmannMC;
import ffx.algorithms.thermodynamics.MonteCarloOST;
import ffx.algorithms.thermodynamics.OrthogonalSpaceTempering;
import ffx.algorithms.thermodynamics.SendSynchronous;
import ffx.potential.MolecularAssembly;
import ffx.potential.cli.WriteoutOptions;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.LongConsumer;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.IntStream;
import org.apache.commons.configuration2.CompositeConfiguration;
import org.apache.commons.io.FilenameUtils;

/**
 * Replica-exchange driver for Orthogonal Space Tempering (OST).
 *
 * <p>Each MPI rank initially owns the histogram whose index equals its rank. The driver alternates
 * blocks of sampling (MD, one-step MC or two-step MC) with proposed exchanges of histograms
 * between adjacent histogram indices. Even-indexed pairs are tried on even exchanges and
 * odd-indexed pairs on odd exchanges.
 *
 * <p>Trajectory (archive) files follow the histogram a rank currently owns. The restart
 * ({@code .dyn}) file is tied to the rank.
 */
public class RepExOST {
    private static final Logger logger = Logger.getLogger(RepExOST.class.getName());
    /** Tag used for the world barriers that bracket each sampling block. */
    private static final int mainLoopTag = 2020;
    /** Boltzmann constant in kcal/(mol K); bias energies are reported in kcal/mol. */
    private static final double KB_KCAL_PER_MOL_K = 0.0019872042586408316;

    private final OrthogonalSpaceTempering orthogonalSpaceTempering;
    private final OrthogonalSpaceTempering.Histogram[] allHistograms;
    private final SendSynchronous[] sendSynchronous;
    /** Runs one sampling block of the given number of steps (MD, MC one-step or MC two-step). */
    private final LongConsumer algoRun;
    private final MolecularDynamics molecularDynamics;
    private final DynamicsOptions dynamicsOptions;
    private final String fileType;
    /** Null when running MD-based OST. */
    private final MonteCarloOST monteCarloOST;
    private final long stepsBetweenExchanges;
    private final Comm world;
    private final int rank;
    /** Number of adjacent histogram pairs (world size - 1). */
    private final int numPairs;
    /** rankToHisto[r] is the histogram currently owned by rank r. */
    private final int[] rankToHisto;
    /** histoToRank[h] is the rank currently owning histogram h. */
    private final int[] histoToRank;
    private final boolean isMC;
    /** Seeded identically on every rank (seed broadcast from rank 0) so all ranks draw the same numbers. */
    private final Random random;
    /** Negative inverse thermal energy, -1/(kB*T); passed as-is to BoltzmannMC. */
    private final double invKT;
    private final String basePath;
    private final String[] allFilenames;
    private final File dynFile;
    private final String extension;
    private final long[] totalSwaps;
    private final long[] acceptedSwaps;
    private boolean reinitVelocities = true;
    private int currentHistoIndex;
    private double currentLambda;

    /**
     * Builds the driver and assigns histogram {@code rank} to this rank.
     *
     * @param repexInterval time between exchange attempts. It is divided by {@code dt * 0.001}, so
     *     presumably it is in ps with dt in fs (TODO confirm). It is clamped to at least one step.
     * @throws IllegalArgumentException if the OST type is unrecognized, or if a histogram lacks a
     *     SynchronousSend object, or if there are no such objects at all.
     */
    private RepExOST(OrthogonalSpaceTempering orthogonalSpaceTempering, MonteCarloOST monteCarloOST, MolecularDynamics molecularDynamics, OstType ostType, DynamicsOptions dynamicsOptions, OSTOptions ostOptions, CompositeConfiguration properties, String fileType, double repexInterval) throws IOException {
        this.orthogonalSpaceTempering = orthogonalSpaceTempering;
        switch (ostType) {
            case MD:
                this.algoRun = this::runMD;
                this.isMC = false;
                break;
            case MC_ONESTEP:
                this.algoRun = this::runMCOneStep;
                this.isMC = true;
                break;
            case MC_TWOSTEP:
                this.algoRun = this::runMCTwoStep;
                this.isMC = true;
                break;
            default:
                throw new IllegalArgumentException(" Could not recognize whether this is supposed to be MD, MC 1-step, or MC 2-step!");
        }

        // The main loop decides when files are written, so turn off automatic writeouts.
        this.molecularDynamics = molecularDynamics;
        this.molecularDynamics.setAutomaticWriteouts(false);
        this.dynamicsOptions = dynamicsOptions;
        this.fileType = fileType;
        this.monteCarloOST = monteCarloOST;
        if (monteCarloOST != null) {
            monteCarloOST.setAutomaticWriteouts(false);
        }
        this.extension = WriteoutOptions.toArchiveExtension(fileType);

        this.world = Comm.world();
        this.rank = world.rank();
        int size = world.size();

        MolecularAssembly[] allAssemblies = molecularDynamics.getAssemblyArray();
        this.allFilenames = Arrays.stream(allAssemblies)
                .map(MolecularAssembly::getFile)
                .map(File::getName)
                .map(FilenameUtils::getBaseName)
                .toArray(String[]::new);
        File firstFile = allAssemblies[0].getFile();
        this.basePath = FilenameUtils.getFullPath(firstFile.getAbsolutePath()) + File.separator;
        String baseFileName = FilenameUtils.getBaseName(firstFile.getAbsolutePath());
        // Restart file lives in a per-rank subdirectory: <basePath><rank>/<baseName>.dyn
        this.dynFile = new File(String.format("%s%d%s%s.dyn", basePath, rank, File.separator, baseFileName));
        this.molecularDynamics.setFallbackDynFile(dynFile);

        this.allHistograms = orthogonalSpaceTempering.getAllHistograms();
        this.numPairs = size - 1;
        this.invKT = -1.0 / (KB_KCAL_PER_MOL_K * dynamicsOptions.getTemperature());

        // Rank 0 picks the seed and shares it, so every rank makes identical exchange decisions.
        LongItemBuf seedBuf = LongBuf.buffer(0L);
        if (rank == 0) {
            seedBuf.put(0, properties.getLong("randomseed", ThreadLocalRandom.current().nextLong()));
        }
        world.broadcast(0, seedBuf);
        this.random = new Random(seedBuf.get(0));

        double timestep = dynamicsOptions.getDt() * 0.001;
        this.stepsBetweenExchanges = Math.max(1L, (long) (repexInterval / timestep));

        // Fail with a clear message instead of a bare NoSuchElementException from Optional.get().
        this.sendSynchronous = IntStream.range(0, allHistograms.length)
                .mapToObj(h -> allHistograms[h].getSynchronousSend().orElseThrow(
                        () -> new IllegalArgumentException(String.format(" Histogram %d has no SynchronousSend object!", h))))
                .toArray(SendSynchronous[]::new);
        if (sendSynchronous.length < 1) {
            throw new IllegalArgumentException(" No SynchronousSend objects were found!");
        }

        // Identity mapping to start: rank r owns histogram r.
        this.rankToHisto = IntStream.range(0, size).toArray();
        this.histoToRank = Arrays.copyOf(rankToHisto, size);
        for (SendSynchronous ss : sendSynchronous) {
            ss.setHistograms(allHistograms, rankToHisto);
        }
        this.totalSwaps = new long[numPairs];
        this.acceptedSwaps = new long[numPairs];

        // Switch to this rank's histogram first so the archive files match it.
        setHistogram(rank);
        setFiles();
    }

    /**
     * Creates a replica-exchange driver that samples with Monte Carlo OST.
     *
     * @param twoStep use two-step MC sampling when true, otherwise one-step sampling.
     */
    public static RepExOST repexMC(OrthogonalSpaceTempering orthogonalSpaceTempering, MonteCarloOST monteCarloOST, DynamicsOptions dynamicsOptions, OSTOptions ostOptions, CompositeConfiguration compositeConfiguration, String fileType, boolean twoStep, double repexInterval) throws IOException {
        MolecularDynamics md = monteCarloOST.getMD();
        OstType type = twoStep ? OstType.MC_TWOSTEP : OstType.MC_ONESTEP;
        return new RepExOST(orthogonalSpaceTempering, monteCarloOST, md, type, dynamicsOptions, ostOptions, compositeConfiguration, fileType, repexInterval);
    }

    /** Creates a replica-exchange driver that samples with molecular dynamics. */
    public static RepExOST repexMD(OrthogonalSpaceTempering orthogonalSpaceTempering, MolecularDynamics molecularDynamics, DynamicsOptions dynamicsOptions, OSTOptions ostOptions, CompositeConfiguration compositeConfiguration, String fileType, double repexInterval) throws IOException {
        return new RepExOST(orthogonalSpaceTempering, null, molecularDynamics, OstType.MD, dynamicsOptions, ostOptions, compositeConfiguration, fileType, repexInterval);
    }

    /** Returns the OST instance driven by this object. */
    public OrthogonalSpaceTempering getOST() {
        return orthogonalSpaceTempering;
    }

    /**
     * Runs the replica-exchange loop.
     *
     * <p>During equilibration, all {@code numTimesteps} run on the current histogram without
     * exchanges. Otherwise, {@code numTimesteps / stepsBetweenExchanges} blocks are run (any
     * remainder is dropped). Each block is followed by an exchange attempt and file writeouts.
     *
     * @param numTimesteps total number of steps to sample.
     * @param equilibrate whether to equilibrate without exchanges.
     * @throws IOException propagated from the communicator barriers or file output.
     */
    public void mainLoop(long numTimesteps, boolean equilibrate) throws IOException {
        if (isMC) {
            monteCarloOST.setEquilibration(equilibrate);
        }
        currentLambda = orthogonalSpaceTempering.getLambda();
        Arrays.fill(totalSwaps, 0L);
        Arrays.fill(acceptedSwaps, 0L);

        if (equilibrate) {
            logger.info(String.format(" Equilibrating RepEx OST without exchanges on histogram %d.", currentHistoIndex));
            algoRun.accept(numTimesteps);
            reinitVelocities = false;
        } else {
            long numExchanges = numTimesteps / stepsBetweenExchanges;
            for (long i = 0; i < numExchanges; i++) {
                logger.info(String.format(" Beginning of RepEx loop %d of %d, operating on histogram %d", i + 1, numExchanges, currentHistoIndex));
                world.barrier(mainLoopTag);
                algoRun.accept(stepsBetweenExchanges);
                orthogonalSpaceTempering.logOutputFiles(currentHistoIndex);
                world.barrier(mainLoopTag);

                // Capture the lambda reached in this block. switchHistos re-applies it after a
                // histogram change, so a swap must not rewind lambda to the previous exchange.
                currentLambda = orthogonalSpaceTempering.getLambda();
                // Alternate even/odd pairs so each histogram takes part in at most one proposal.
                proposeSwaps((int) (i % 2), 2);
                setFiles();

                // Cumulative steps completed after this block.
                long mdMoveNum = (i + 1) * stepsBetweenExchanges;
                currentLambda = orthogonalSpaceTempering.getLambda();
                boolean trySnapshot = currentLambda >= orthogonalSpaceTempering.getLambdaWriteOut();
                EnumSet<MDWriteAction> written = molecularDynamics.writeFilesForStep(mdMoveNum, trySnapshot, true);
                if (written.contains(MDWriteAction.RESTART)) {
                    orthogonalSpaceTempering.writeAdditionalRestartInfo(false);
                }
                reinitVelocities = false;
            }
        }
        logger.info(" Final rank-to-histogram mapping: " + Arrays.toString(rankToHisto));
    }

    /** Records the histogram index and switches the OST to it. */
    private void setHistogram(int index) {
        currentHistoIndex = index;
        orthogonalSpaceTempering.switchHistogram(index);
    }

    /** Logs at INFO level on rank 0 only. */
    private void logIfMaster(String message) {
        logIfMaster(Level.INFO, message);
    }

    /** Logs at the given level on rank 0 only. */
    private void logIfMaster(Level level, String message) {
        if (rank == 0) {
            logger.log(level, message);
        }
    }

    /** Logs exchange-related information (currently rank 0 only). */
    private void logIfSwapping(String message) {
        logIfMaster(message);
    }

    /** Points the archive files at the directory of the histogram this rank currently owns. */
    private void setFiles() {
        File[] trajFiles = Arrays.stream(allFilenames)
                .map(fn -> String.format("%s%d%s%s.%s", basePath, currentHistoIndex, File.separator, fn, extension))
                .map(File::new)
                .toArray(File[]::new);
        molecularDynamics.setArchiveFiles(trajFiles);
    }

    /**
     * Proposes Metropolis exchanges between histograms i and i+1 for i = offset, offset + stride, ...
     *
     * <p>Every rank runs this with the same shared-seed Random. Presumably the histograms' last
     * received lambda and dU/dL values are identical on all ranks (TODO confirm); if so, every
     * rank reaches the same accept/reject decisions.
     */
    private void proposeSwaps(int offset, int stride) {
        for (int i = offset; i < numPairs; i += stride) {
            int rankLow = histoToRank[i];
            int rankHigh = histoToRank[i + 1];
            OrthogonalSpaceTempering.Histogram histoLow = allHistograms[i];
            OrthogonalSpaceTempering.Histogram histoHigh = allHistograms[i + 1];
            double lamLow = histoLow.getLastReceivedLambda();
            double dUdLLow = histoLow.getLastReceivedDUDL();
            double lamHigh = histoHigh.getLastReceivedLambda();
            double dUdLHigh = histoHigh.getLastReceivedDUDL();

            // Bias energies before (eii, ejj) and after (eij, eji) exchanging the histograms.
            double eii = histoLow.computeBiasEnergy(lamLow, dUdLLow);
            double eij = histoLow.computeBiasEnergy(lamHigh, dUdLHigh);
            double eji = histoHigh.computeBiasEnergy(lamLow, dUdLLow);
            double ejj = histoHigh.computeBiasEnergy(lamHigh, dUdLHigh);
            logIfSwapping(String.format("\n Proposing exchange between histograms %d (rank %d) and %d (rank %d).\n Li: %.6f dU/dLi: %.6f Lj: %.6f dU/dLj: %.6f", i, rankLow, i + 1, rankHigh, lamLow, dUdLLow, lamHigh, dUdLHigh));

            double e1 = eii + ejj;
            double e2 = eji + eij;
            boolean accept = BoltzmannMC.evaluateMove(random, invKT, e1, e2);
            double acceptChance = BoltzmannMC.acceptChance(invKT, e1, e2);
            String desc = accept ? "Accepted" : "Rejected";
            logIfSwapping(String.format(" %s exchange with probability %.5f based on Eii %.6f, Ejj %.6f, Eij %.6f, Eji %.6f kcal/mol", desc, acceptChance, eii, ejj, eij, eji));

            totalSwaps[i]++;
            if (accept) {
                acceptedSwaps[i]++;
                switchHistos(rankLow, rankHigh, i);
            }
            double acceptRate = (double) acceptedSwaps[i] / (double) totalSwaps[i];
            logIfSwapping(String.format(" Replica exchange acceptance rate for pair %d-%d is %.3f%%", i, i + 1, acceptRate * 100.0));
        }
    }

    /**
     * Swaps ownership of histograms {@code histoLow} and {@code histoLow + 1} between two ranks.
     * It then re-selects this rank's histogram and re-applies the current lambda.
     */
    private void switchHistos(int rankLow, int rankHigh, int histoLow) {
        int histoHigh = histoLow + 1;
        rankToHisto[rankLow] = histoHigh;
        rankToHisto[rankHigh] = histoLow;
        histoToRank[histoLow] = rankHigh;
        histoToRank[histoHigh] = rankLow;
        setHistogram(rankToHisto[rank]);
        orthogonalSpaceTempering.setLambda(currentLambda);
        for (SendSynchronous send : sendSynchronous) {
            send.updateRanks(rankToHisto);
        }
    }

    /** Runs one block of one-step Monte Carlo OST. */
    private void runMCOneStep(long numSteps) {
        monteCarloOST.setTotalSteps(numSteps);
        monteCarloOST.sampleOneStep();
    }

    /** Runs one block of two-step Monte Carlo OST. */
    private void runMCTwoStep(long numSteps) {
        monteCarloOST.setTotalSteps(numSteps);
        monteCarloOST.sampleTwoStep();
    }

    /** Runs one block of molecular dynamics with the configured dynamics options. */
    private void runMD(long numSteps) {
        molecularDynamics.dynamic(numSteps, dynamicsOptions.getDt(), dynamicsOptions.getReport(), dynamicsOptions.getSnapshotInterval(), dynamicsOptions.getTemperature(), reinitVelocities, fileType, dynamicsOptions.getCheckpoint(), dynFile);
    }

    /** Sampling engine used between exchange attempts. */
    private enum OstType {
        MD,
        MC_ONESTEP,
        MC_TWOSTEP
    }
}

