/*
 * Decompiled with CFR 0.152.
 */
package ffx.algorithms.mc;

import ffx.algorithms.dynamics.thermostats.Thermostat;
import ffx.algorithms.mc.MCMove;
import ffx.algorithms.mc.MonteCarloListener;
import ffx.algorithms.mc.RosenbluthChi0Move;
import ffx.potential.ForceFieldEnergy;
import ffx.potential.MolecularAssembly;
import ffx.potential.bonded.AminoAcidUtils;
import ffx.potential.bonded.Atom;
import ffx.potential.bonded.Residue;
import ffx.potential.bonded.ResidueState;
import ffx.potential.bonded.Torsion;
import ffx.potential.parsers.PDBFilter;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.logging.Logger;
import org.apache.commons.io.FilenameUtils;
import org.apache.commons.math3.util.FastMath;

public class RosenbluthOBMC
implements MonteCarloListener {
    private static final Logger logger = Logger.getLogger(RosenbluthOBMC.class.getName());
    private final MolecularAssembly molecularAssembly;
    private final ForceFieldEnergy forceFieldEnergy;
    private final Thermostat thermostat;
    private final List<Residue> targets;
    private final int mcFrequency;
    private final int trialSetSize;
    private int steps = 0;
    private double Wn;
    private double Wo;
    private int numMovesProposed = 0;
    private StringBuilder report = new StringBuilder();
    private boolean writeSnapshots = false;

    public RosenbluthOBMC(MolecularAssembly molecularAssembly, ForceFieldEnergy forceFieldEnergy, Thermostat thermostat, List<Residue> targets, int mcFrequency, int trialSetSize) {
        this.targets = targets;
        this.mcFrequency = mcFrequency;
        this.trialSetSize = trialSetSize;
        this.molecularAssembly = molecularAssembly;
        this.forceFieldEnergy = forceFieldEnergy;
        this.thermostat = thermostat;
    }

    public RosenbluthOBMC(MolecularAssembly molecularAssembly, ForceFieldEnergy forceFieldEnergy, Thermostat thermostat, List<Residue> targets, int mcFrequency, int trialSetSize, boolean writeSnapshots) {
        this(molecularAssembly, forceFieldEnergy, thermostat, targets, mcFrequency, trialSetSize);
        this.writeSnapshots = writeSnapshots;
    }

    @Override
    public boolean mcUpdate(double temperature) {
        ++this.steps;
        if (this.steps % this.mcFrequency == 0) {
            return this.mcStep();
        }
        return false;
    }

    private boolean mcStep() {
        boolean accepted;
        ++this.numMovesProposed;
        int index = ThreadLocalRandom.current().nextInt(this.targets.size());
        Residue target = this.targets.get(index);
        ResidueState origState = target.storeState();
        Torsion chi0 = this.getChiZeroTorsion(target);
        this.writeSnapshot("orig");
        List<MCMove> oldTrialSet = this.createTrialSet(target, origState, this.trialSetSize - 1);
        List<MCMove> newTrialSet = this.createTrialSet(target, origState, this.trialSetSize);
        this.report = new StringBuilder();
        this.report.append(String.format(" Rosenbluth Rotamer MC Move: %4d\n", this.numMovesProposed));
        this.report.append(String.format("    residue:   %s\n", target));
        this.report.append(String.format("    chi0:      %s\n", chi0.toString()));
        MCMove proposal = this.calculateRosenbluthFactors(target, chi0, origState, oldTrialSet, origState, newTrialSet);
        this.setState(target, origState);
        this.writeSnapshot("uIndO");
        double uIndO = this.getTotalEnergy() - this.getTorsionEnergy(chi0);
        proposal.move();
        this.writeSnapshot("uIndN");
        double uIndN = this.getTotalEnergy() - this.getTorsionEnergy(chi0);
        double temperature = this.thermostat.getCurrentTemperature();
        double beta = 1.0 / (0.0019872042586408316 * temperature);
        double dInd = uIndN - uIndO;
        double dIndE = FastMath.exp((double)(-beta * dInd));
        double criterion = this.Wn / this.Wo * FastMath.exp((double)(-beta * (uIndN - uIndO)));
        double metropolis = FastMath.min((double)1.0, (double)criterion);
        double rng = ThreadLocalRandom.current().nextDouble();
        this.report.append(String.format("    theta:     %3.2f\n", ((RosenbluthChi0Move)proposal).theta));
        this.report.append(String.format("    criterion: %1.4f\n", criterion));
        this.report.append(String.format("       Wn/Wo:     %.2f\n", this.Wn / this.Wo));
        this.report.append(String.format("       uIndN,O:  %7.2f\t%7.2f\n", uIndN, uIndO));
        this.report.append(String.format("       dInd(E):  %7.2f\t%7.2f\n", dInd, dIndE));
        this.report.append(String.format("    rng:       %1.4f\n", rng));
        if (rng < metropolis) {
            this.report.append(" Accepted.\n");
            accepted = true;
        } else {
            proposal.revertMove();
            this.report.append(" Denied.\n");
            accepted = false;
        }
        logger.info(this.report.toString());
        this.Wn = 0.0;
        this.Wo = 0.0;
        return accepted;
    }

    private List<MCMove> createTrialSet(Residue target, ResidueState state, int setSize) {
        ArrayList<MCMove> moves = new ArrayList<MCMove>();
        this.setState(target, state);
        for (int i = 0; i < setSize; ++i) {
            moves.add(new RosenbluthChi0Move(target));
        }
        return moves;
    }

    private MCMove calculateRosenbluthFactors(Residue target, Torsion chi0, ResidueState oldConf, List<MCMove> oldTrialSet, ResidueState newConf, List<MCMove> newTrialSet) {
        double temperature = this.thermostat.getCurrentTemperature();
        double beta = 1.0 / (0.0019872042586408316 * temperature);
        this.Wo = FastMath.exp((double)(-beta * this.getTorsionEnergy(chi0)));
        this.report.append(String.format("    TestSet (Old): %5s\t%7s\t\t%7s\n", "uDepO", "uDepOe", "Sum(Wo)"));
        this.report.append(String.format("       Orig %d:   %7.4f\t%7.4f\t\t%7.4f\n", 0, this.getTorsionEnergy(chi0), FastMath.exp((double)(-beta * this.getTorsionEnergy(chi0))), this.Wo));
        for (int i = 0; i < oldTrialSet.size(); ++i) {
            this.setState(target, oldConf);
            MCMove move = oldTrialSet.get(i);
            move.move();
            double uDepO = this.getTorsionEnergy(chi0);
            double uDepOe = FastMath.exp((double)(-beta * uDepO));
            this.Wo += uDepOe;
            if (i < 5 || i >= oldTrialSet.size() - 5) {
                this.report.append(String.format("       Prop %d:   %7.4f\t%7.4f\t\t%7.4f\n", i + 1, uDepO, uDepOe, this.Wo));
                this.writeSnapshot("ots");
                continue;
            }
            if (i != 5) continue;
            this.report.append("        ... \n");
        }
        this.Wn = 0.0;
        double[] uDepN = new double[newTrialSet.size()];
        double[] uDepNe = new double[newTrialSet.size()];
        this.report.append(String.format("    TestSet (New): %5s\t%7s\t\t%7s\n", "uDepN", "uDepNe", "Sum(Wn)"));
        for (int i = 0; i < newTrialSet.size(); ++i) {
            this.setState(target, newConf);
            MCMove move = newTrialSet.get(i);
            move.move();
            uDepN[i] = this.getTorsionEnergy(chi0);
            uDepNe[i] = FastMath.exp((double)(-beta * uDepN[i]));
            this.Wn += uDepNe[i];
            if (i < 5 || i >= newTrialSet.size() - 5) {
                this.report.append(String.format("       Prop %d:   %7.4f\t%7.4f\t\t%7.4f\n", i, uDepN[i], uDepNe[i], this.Wn));
                this.writeSnapshot("nts");
                continue;
            }
            if (i != 5) continue;
            this.report.append("        ... \n");
        }
        this.setState(target, oldConf);
        MCMove proposal = null;
        double rng = ThreadLocalRandom.current().nextDouble(this.Wn);
        double running = 0.0;
        for (int i = 0; i < newTrialSet.size(); ++i) {
            if (!(rng < (running += uDepNe[i]))) continue;
            proposal = newTrialSet.get(i);
            double prob = uDepNe[i] / this.Wn * 100.0;
            this.report.append(String.format("       Chose %d   %7.4f\t%7.4f\t  %4.1f%%\n", i, uDepN[i], uDepNe[i], prob));
            break;
        }
        if (proposal == null) {
            logger.severe("Programming error.");
        }
        return proposal;
    }

    private double getTotalEnergy() {
        double[] x = new double[this.forceFieldEnergy.getNumberOfVariables() * 3];
        this.forceFieldEnergy.getCoordinates(x);
        return this.forceFieldEnergy.energy(x);
    }

    private double getTorsionEnergy(Torsion torsion) {
        return torsion.energy(false);
    }

    private Torsion getChiZeroTorsion(Residue residue) {
        AminoAcidUtils.AminoAcid3 name = AminoAcidUtils.AminoAcid3.valueOf((String)residue.getName());
        List torsions = residue.getTorsionList();
        switch (name) {
            case VAL: {
                Atom N = (Atom)residue.getAtomNode("N");
                Atom CA = (Atom)residue.getAtomNode("CA");
                Atom CB = (Atom)residue.getAtomNode("CB");
                Atom CG1 = (Atom)residue.getAtomNode("CG1");
                for (Torsion torsion : torsions) {
                    if (!torsion.compare(N, CA, CB, CG1)) continue;
                    return torsion;
                }
                break;
            }
            case ILE: {
                Atom N = (Atom)residue.getAtomNode("N");
                Atom CA = (Atom)residue.getAtomNode("CA");
                Atom CB = (Atom)residue.getAtomNode("CB");
                Atom CG1 = (Atom)residue.getAtomNode("CG1");
                for (Torsion torsion : torsions) {
                    if (!torsion.compare(N, CA, CB, CG1)) continue;
                    return torsion;
                }
                break;
            }
            case SER: {
                Atom N = (Atom)residue.getAtomNode("N");
                Atom CA = (Atom)residue.getAtomNode("CA");
                Atom CB = (Atom)residue.getAtomNode("CB");
                Atom OG = (Atom)residue.getAtomNode("OG");
                for (Torsion torsion : torsions) {
                    if (!torsion.compare(N, CA, CB, OG)) continue;
                    return torsion;
                }
                break;
            }
            case THR: {
                Atom N = (Atom)residue.getAtomNode("N");
                Atom CA = (Atom)residue.getAtomNode("CA");
                Atom CB = (Atom)residue.getAtomNode("CB");
                Atom OG1 = (Atom)residue.getAtomNode("OG1");
                for (Torsion torsion : torsions) {
                    if (!torsion.compare(N, CA, CB, OG1)) continue;
                    return torsion;
                }
                break;
            }
            case CYX: {
                Atom N = (Atom)residue.getAtomNode("N");
                Atom CA = (Atom)residue.getAtomNode("CA");
                Atom CB = (Atom)residue.getAtomNode("CB");
                Atom SG = (Atom)residue.getAtomNode("SG");
                for (Torsion torsion : torsions) {
                    if (!torsion.compare(N, CA, CB, SG)) continue;
                    return torsion;
                }
                break;
            }
            case CYD: {
                Atom N = (Atom)residue.getAtomNode("N");
                Atom CA = (Atom)residue.getAtomNode("CA");
                Atom CB = (Atom)residue.getAtomNode("CB");
                Atom SG = (Atom)residue.getAtomNode("SG");
                for (Torsion torsion : torsions) {
                    if (!torsion.compare(N, CA, CB, SG)) continue;
                    return torsion;
                }
                break;
            }
            default: {
                Atom N = (Atom)residue.getAtomNode("N");
                Atom CA = (Atom)residue.getAtomNode("CA");
                Atom CB = (Atom)residue.getAtomNode("CB");
                Atom CG = (Atom)residue.getAtomNode("CG");
                for (Torsion torsion : torsions) {
                    if (!torsion.compare(N, CA, CB, CG)) continue;
                    return torsion;
                }
                logger.info("Couldn't find chi[0] for residue " + String.valueOf(residue));
                return null;
            }
        }
        logger.info("Couldn't find chi[0] for residue " + String.valueOf(residue));
        return null;
    }

    private void setState(Residue target, ResidueState state) {
        target.revertState(state);
        for (Torsion torsion : target.getTorsionList()) {
            torsion.update();
        }
    }

    private void writeSnapshot(String suffix) {
        if (!this.writeSnapshots) {
            return;
        }
        String filename = FilenameUtils.removeExtension((String)this.molecularAssembly.getFile().toString()) + "." + suffix + "-" + this.numMovesProposed;
        File file = new File(filename);
        PDBFilter writer = new PDBFilter(file, this.molecularAssembly, null, null);
        writer.writeFile(file, false);
    }
}

