package ch.ethz.bsse.quasirecomb.model.hmm;

import ch.ethz.bsse.quasirecomb.distance.KullbackLeibler;
import ch.ethz.bsse.quasirecomb.informationholder.Globals;
import ch.ethz.bsse.quasirecomb.informationholder.ParallelJHMMStorage;
import ch.ethz.bsse.quasirecomb.informationholder.Read;
import ch.ethz.bsse.quasirecomb.informationholder.Threading;
import ch.ethz.bsse.quasirecomb.model.hmm.parallel.CallableReadHMMList;
import ch.ethz.bsse.quasirecomb.utils.Random;
import ch.ethz.bsse.quasirecomb.utils.StatusUpdate;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.javatuples.Pair;
import org.javatuples.Triplet;

/* loaded from: input_file:main/QuasiRecomb-1.0.jar:ch/ethz/bsse/quasirecomb/model/hmm/JHMM.class */
public class JHMM extends Garage {
    // Last constructor argument; forwarded to StatusUpdate.printPercentage in eStep().
    protected int Kmin;
    // Second constructor argument as stored by prepare(); within this class it is only exposed via getN().
    protected int N;
    // First dimension of nJKL, nJKV, snv and nneqPos (presumably the number of positions - TODO confirm).
    protected int L;
    // Second dimension of nJKL/nJKV/mu and both trailing dimensions of nJKL (presumably the number of generators - TODO confirm).
    protected int K;
    // Last dimension of nJKV, mu and snv; also the length of muPrior (presumably the alphabet size - TODO confirm).
    protected int n;
    // [L][n] table rebuilt by computeSNVPosterior().
    protected double[][] snv;
    // Updated by maximizeRho(), which writes rho[j - 1][k][l] for j = 1..L-1.
    protected double[][][] rho;
    // pi[j] is replaced for every j < L by maximizePi().
    protected double[][] pi;
    // mu[j][k][v] is overwritten by maximizeMu(); perturbed by mStep()/biasMu().
    protected double[][][] mu;
    // Per-index error value; either supplied by the caller or filled with one constant, then re-estimated by maximizeEps().
    protected double[] eps;
    // antieps[j] == 1 - (n - 1) * eps[j]; kept in sync wherever eps is written in this class.
    protected double[] antieps;
    // Sum of the values returned by the callables in the most recent eStep().
    protected double loglikelihood;
    // [L][K][K] counts copied from the merged garage in updateExpectedCounts(); read by maximizeRho().
    protected double[][][] nJKL;
    // [L][K][n] counts copied from the merged garage in updateExpectedCounts(); read by maximizeMu()/maximizePi().
    protected double[][][] nJKV;
    // [L] counts copied from the merged garage; read by maximizeEps().
    protected double[] nneqPos;
    // Length-n array filled with Globals.getALPHA_H() in prepare(); copied as the starting prior in maximizeMu().
    protected double[] muPrior;
    // Reads passed to the constructor; sliced into chunks for the worker callables in eStep().
    protected Read[] allReads;
    // Incremented by restart(); passed to Regularizations.regularizeOnce / regularizeOnceRho.
    protected int restart;
    // Taken from Globals.getTAU_OMEGA().getCoverage() in prepare(); indexed by position.
    protected int[] coverage;
    // Number of mu entries whose update exceeded Globals.getPCHANGE() since construction or the last restart().
    protected int muChanged;
    // Number of rho entries whose update exceeded Globals.getPCHANGE() since construction or the last restart().
    protected int rhoChanged;
    // Snapshot of Globals.isPAIRED() taken in prepare().
    protected boolean paired;
    // NOTE(review): prepare() stores the SAME array reference as rho here (no copy); not read anywhere in this class.
    protected double[][][] rho_old;
    // NOTE(review): prepare() stores the SAME array reference as mu here (no copy); not read anywhere in this class.
    protected double[][][] mu_old;
    // Result of getMuFlats() from the previous mStep() when BIAS_MU is enabled; starts at -1.
    private int oldFlatMu;
    // Whether the current mStep() is allowed to call biasMu().
    private boolean biasMu;
    // Number of biasMu() invocations since the counters were last reset in mStep().
    private int biasCounter;
    // Number of mStep() rounds spent in the "no bias" phase; both counters reset once this exceeds 200.
    private int unBiasCounter;
    // Incremented once per compute(); feeds the interpolation weights in maximizeMu/maximizeRho/maximizePi.
    private int s;
    // Passed to Regularizations.deterministicAnnealing; starts at 1.0E-4 and is adjusted through setBeta/incBeta.
    private double beta;
    // Worker tasks built lazily on the first eStep() and reused on every later call.
    List<Callable<Double>> callables;

    /**
     * Builds a model with generated starting parameters and a single error value, then delegates
     * to the full constructor.
     *
     * <p>The starting tables come from {@code Random.generateInitRho(i2 - 1, i3)},
     * {@code Random.generateInitPi(i2, i3)} and {@code Random.generateMuInit(i2, i3, i4)}.
     *
     * @param readArr reads to fit
     * @param i value stored as {@code N}
     * @param i2 value stored as {@code L}
     * @param i3 value stored as {@code K}
     * @param i4 value stored as {@code n}
     * @param d initial value of every {@code eps} entry
     * @param i5 value stored as {@code Kmin}
     */
    public JHMM(Read[] readArr, int i, int i2, int i3, int i4, double d, int i5) {
        this(
                readArr,
                i,
                i2,
                i3,
                i4,
                d,
                Random.generateInitRho(i2 - 1, i3),
                Random.generateInitPi(i2, i3),
                Random.generateMuInit(i2, i3, i4),
                i5);
    }

    public JHMM(Read[] readArr, int i, int i2, int i3, int i4, double d, double[][][] dArr, double[][] dArr2, double[][][] dArr3, int i5) {
        this.restart = 0;
        this.muChanged = 0;
        this.rhoChanged = 0;
        this.oldFlatMu = -1;
        this.biasMu = false;
        this.biasCounter = 0;
        this.unBiasCounter = 0;
        this.s = 0;
        this.beta = 1.0E-4d;
        this.callables = new LinkedList();
        this.eps = new double[i2];
        this.antieps = new double[i2];
        for (int i6 = 0; i6 < i2; i6++) {
            this.eps[i6] = d;
            this.antieps[i6] = 1.0d - ((i4 - 1) * d);
        }
        this.Kmin = i5;
        prepare(readArr, i, i2, i3, i4, dArr, dArr2, dArr3);
        compute();
    }

    public JHMM(Read[] readArr, int i, int i2, int i3, int i4, double[] dArr, double[][][] dArr2, double[][] dArr3, double[][][] dArr4, int i5) {
        this.restart = 0;
        this.muChanged = 0;
        this.rhoChanged = 0;
        this.oldFlatMu = -1;
        this.biasMu = false;
        this.biasCounter = 0;
        this.unBiasCounter = 0;
        this.s = 0;
        this.beta = 1.0E-4d;
        this.callables = new LinkedList();
        this.eps = dArr;
        this.antieps = new double[i2];
        for (int i6 = 0; i6 < i2; i6++) {
            this.antieps[i6] = 1.0d - ((i4 - 1) * dArr[i6]);
        }
        this.Kmin = i5;
        prepare(readArr, i, i2, i3, i4, dArr2, dArr3, dArr4);
        compute();
    }

    /**
     * Runs one full round: allocates zeroed count tables, performs the E-step followed by the
     * M-step, and advances the round counter {@code s}.
     */
    private void compute() {
        final int positions = this.L;
        final int generators = this.K;
        this.nJKL = new double[positions][generators][generators];
        this.nJKV = new double[positions][generators][this.n];
        this.nneqPos = new double[positions];

        eStep();
        mStep();

        this.s += 1;
    }

    /**
     * E-step: clears the garage, runs every worker task on the shared executor, sums the returned
     * values into {@code loglikelihood} and refreshes the expected-count tables.
     *
     * <p>The worker tasks are built only on the first call (while {@code callables} is empty) and
     * are reused afterwards. Each task receives a contiguous slice of {@code allReads}. When
     * {@code Globals.getSTEPS()} is 2 the slice length is derived from the processor count,
     * otherwise the configured value is used as the slice length.
     *
     * @throws IllegalStateException if the thread is interrupted while waiting for the tasks; the
     *     interrupt flag is restored before throwing
     */
    private void eStep() {
        clearGarage(this.L, this.K, this.n);
        this.loglikelihood = 0.0d;

        if (this.callables.isEmpty()) {
            final int readCount = this.allReads.length;
            final int configuredSteps = Globals.getINSTANCE().getSTEPS();
            int chunkSize;
            if (configuredSteps == 2) {
                // availableProcessors() may be 1; keep at least one worker to avoid dividing by zero.
                final int workers = Math.max(1, Runtime.getRuntime().availableProcessors() - 1);
                chunkSize = readCount / workers;
            } else {
                chunkSize = configuredSteps;
            }
            // A slice length below 1 would never advance the loop below.
            chunkSize = Math.max(1, chunkSize);

            int from = 0;
            while (from < readCount) {
                // Compare against the remainder instead of adding first, so the bound cannot overflow.
                final int to = (readCount - from <= chunkSize) ? readCount : from + chunkSize;
                this.callables.add(new CallableReadHMMList(this, Arrays.copyOfRange(this.allReads, from, to)));
                // NOTE(review): from / readCount is integer division and is 0 for every slice;
                // confirm whether printPercentage expects a fraction before changing it.
                StatusUpdate.getINSTANCE().printPercentage(this.K, from / readCount, this.Kmin);
                from = to;
            }
        }

        final List<Future<Double>> results;
        try {
            results = Threading.getINSTANCE().getExecutor().invokeAll(this.callables);
        } catch (InterruptedException e) {
            // Restore the flag for callers and fail loudly instead of dereferencing a null result list.
            Thread.currentThread().interrupt();
            Logger.getLogger(JHMM.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            throw new IllegalStateException("E-step interrupted while waiting for worker tasks", e);
        }

        for (Future<Double> result : results) {
            try {
                this.loglikelihood += result.get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                Logger.getLogger(JHMM.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            } catch (ExecutionException e) {
                // Best effort, as before: a failed task is logged and contributes nothing to the sum.
                Logger.getLogger(JHMM.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
            }
        }
        updateExpectedCounts();
    }

    /**
     * Rebuilds {@code nJKL}, {@code nJKV} and {@code nneqPos} from the merged garage.
     *
     * <p>Does nothing unless {@code Globals.isSTORAGE()} is set. The tables are freshly allocated
     * (zeroed) and the merged values are added entry by entry.
     */
    private void updateExpectedCounts() {
        if (!Globals.getINSTANCE().isSTORAGE()) {
            return;
        }
        final ParallelJHMMStorage merged = mergeGarage();
        this.nJKL = new double[this.L][this.K][this.K];
        this.nJKV = new double[this.L][this.K][this.n];
        this.nneqPos = new double[this.L];
        for (int j = 0; j < this.L; j++) {
            for (int k = 0; k < this.K; k++) {
                for (int l = 0; l < this.K; l++) {
                    this.nJKL[j][k][l] += merged.getnJKL()[j][k][l];
                }
                for (int v = 0; v < this.n; v++) {
                    this.nJKV[j][k][v] += merged.getnJKV()[j][k][v];
                }
            }
            this.nneqPos[j] += merged.getNneqPos()[j];
        }
    }

    /**
     * Recomputes {@code snv} from {@code nJKV}.
     *
     * <p>First, for every (j, k) the smallest entry of {@code nJKV[j][k]} is subtracted from the
     * whole row. Then {@code snv[j][v]} becomes the sum over k of the shifted value divided by
     * {@code coverage[j]}.
     */
    public void computeSNVPosterior() {
        final double[][][] shifted = new double[this.L][this.K][this.n];
        for (int j = 0; j < this.L; j++) {
            for (int k = 0; k < this.K; k++) {
                double rowMin = Double.MAX_VALUE;
                for (int v = 0; v < this.n; v++) {
                    rowMin = Math.min(this.nJKV[j][k][v], rowMin);
                }
                for (int v = 0; v < this.n; v++) {
                    shifted[j][k][v] = this.nJKV[j][k][v] - rowMin;
                }
            }
        }

        for (int j = 0; j < this.L; j++) {
            for (int v = 0; v < this.n; v++) {
                this.snv[j][v] = 0.0d;
                for (int k = 0; k < this.K; k++) {
                    this.snv[j][v] += shifted[j][k][v] / this.coverage[j];
                }
            }
        }
    }

    /**
     * Re-estimates {@code mu[j][k]} for every (j, k) from the counts in {@code nJKV}.
     *
     * <p>Per row:
     * <ol>
     *   <li>A raw estimate is obtained from {@code Regularizations}: deterministic annealing when
     *       {@code isANNEALING()}, an interpolation step when {@code getINTERPOLATE_MU() > 0},
     *       otherwise {@code ml}.</li>
     *   <li>{@code regularizeOnce} is applied repeatedly. While the row still looks uniform the
     *       multiplier is doubled and the prior is scaled by 10; the loop stops once the
     *       multiplier exceeds 1000.</li>
     *   <li>When {@code isMAX()} is set the row is collapsed to a one-hot vector at its largest
     *       entry.</li>
     *   <li>The row is written to {@code mu}, and every entry is reported to {@code changedMu}
     *       before being overwritten.</li>
     * </ol>
     *
     * <p>If an entry is NaN, a diagnostic table of counts versus estimates is printed for the row.
     */
    private void maximizeMu() {
        for (int j = 0; j < this.L; j++) {
            for (int k = 0; k < this.K; k++) {
                double[] estimate;
                if (Globals.getINSTANCE().isANNEALING()) {
                    estimate = Regularizations.deterministicAnnealing(this.nJKV[j][k], this.mu[j][k], this.beta);
                } else if (Globals.getINSTANCE().getINTERPOLATE_MU() > 0.0d) {
                    estimate = Regularizations.step(this.nJKV[j][k], this.mu[j][k], Math.pow(Math.pow(this.s, 2.0d) + 2.0d, -Globals.getINSTANCE().getINTERPOLATE_MU()), Globals.getINSTANCE().isPAIRED());
                } else {
                    estimate = Regularizations.ml(this.nJKV[j][k]);
                }

                double multiplier = Globals.getINSTANCE().getMULT_MU();
                // Work on a private copy: the prior is scaled in place while the row stays flat.
                final double[] prior = new double[this.n];
                System.arraycopy(this.muPrior, 0, prior, 0, this.n);

                boolean flat = true;
                do {
                    estimate = Regularizations.regularizeOnce(estimate, this.restart, prior, multiplier);
                    final double first = estimate[0];
                    for (int v = 1; v < this.n; v++) {
                        if (first != estimate[v] && first > 0.0d) {
                            flat = false;
                        }
                    }
                    if (flat) {
                        multiplier *= 2.0d;
                        for (int v = 0; v < this.n; v++) {
                            prior[v] = prior[v] * 10.0d;
                        }
                    }
                    // Give up strengthening the prior once the multiplier is out of range.
                    if (multiplier > 1000.0d) {
                        flat = false;
                    }
                } while (flat);

                if (Globals.getINSTANCE().isMAX()) {
                    double best = 0.0d;
                    int bestIndex = 0;
                    for (int v = 0; v < this.n; v++) {
                        if (estimate[v] > best) {
                            best = estimate[v];
                            bestIndex = v;
                        }
                    }
                    for (int v = 0; v < this.n; v++) {
                        estimate[v] = v == bestIndex ? 1.0d : 0.0d;
                    }
                }

                for (int v = 0; v < this.n; v++) {
                    changedMu(this.mu[j][k][v], estimate[v]);
                    this.mu[j][k][v] = estimate[v];
                    if (Double.isNaN(estimate[v])) {
                        System.out.println("R nan, j " + j + ", k " + k);
                        for (int w = 0; w < this.n; w++) {
                            // Print the count that belongs to the same column as the estimate
                            // (previously the outer index was used, repeating one count n times).
                            System.out.println(this.nJKV[j][k][w] + "\t" + estimate[w]);
                        }
                    }
                }
            }
        }
    }

    /**
     * Re-estimates {@code rho[j - 1][k]} for every j in 1..L-1 and every k from the counts in
     * {@code nJKL[j][k]}.
     *
     * <p>A row is "spiked" (replaced by a one-hot vector at index k) when {@code isSPIKERHO()} is
     * set and either the normalised counts or the regularised estimate put more than 0.5 of their
     * mass on an index other than k. Otherwise the regularised estimate is used as is.
     *
     * @return {@code true} if at least one row was spiked
     */
    private boolean maximizeRho() {
        boolean anySpike = false;
        for (int j = 1; j < this.L; j++) {
            for (int k = 0; k < this.K; k++) {
                double[] estimate;
                if (Globals.getINSTANCE().getINTERPOLATE_RHO() > 0.0d) {
                    estimate = Regularizations.step(
                            this.nJKL[j][k],
                            this.rho[j - 1][k],
                            Math.pow(Math.pow(this.s, 2.0d) + 2.0d, -Globals.getINSTANCE().getINTERPOLATE_RHO()),
                            false);
                } else {
                    estimate = Regularizations.ml(this.nJKL[j][k]);
                }

                // Normalised copy of the raw counts and the position of its largest entry.
                double countTotal = 0.0d;
                final double[] share = new double[this.K];
                for (int l = 0; l < this.K; l++) {
                    share[l] = this.nJKL[j][k][l];
                    countTotal += share[l];
                }
                double topShare = 0.0d;
                int topIndex = -1;
                for (int l = 0; l < this.K; l++) {
                    share[l] = share[l] / countTotal;
                    if (share[l] > topShare) {
                        topShare = share[l];
                        topIndex = l;
                    }
                }

                final double multiplier = Globals.getINSTANCE().getMULT_RHO();
                final double[] prior = new double[this.K];
                boolean spike = false;
                for (int l = 0; l < this.K; l++) {
                    if (l != k) {
                        prior[l] = Globals.getINSTANCE().getALPHA_Z();
                        continue;
                    }
                    if (topShare > 0.5d && topIndex != k && Globals.getINSTANCE().isSPIKERHO()) {
                        // Counts already favour another index: stop building the prior and spike.
                        prior[l] = 100.0d;
                        spike = true;
                        anySpike = true;
                        break;
                    }
                    prior[l] = Globals.getINSTANCE().getALPHA_Z() * 10.0d;
                }

                if (!spike) {
                    estimate = Regularizations.regularizeOnceRho(k, estimate, this.restart, prior, multiplier);
                    int argMax = -1;
                    double maxValue = -1.0d;
                    for (int l = 0; l < this.K; l++) {
                        if (estimate[l] > maxValue) {
                            maxValue = estimate[l];
                            argMax = l;
                        }
                    }
                    if (maxValue > 0.5d && argMax != k && Globals.getINSTANCE().isSPIKERHO()) {
                        spike = true;
                        anySpike = true;
                    }
                }

                if (spike) {
                    // One-hot row: all mass on index k.
                    estimate = new double[this.K];
                    estimate[k] = 1.0d;
                }

                for (int l = 0; l < this.K; l++) {
                    changedRho(this.rho[j - 1][k][l], estimate[l]);
                    this.rho[j - 1][k][l] = estimate[l];
                }
            }
        }
        return anySpike;
    }

    /**
     * Re-estimates {@code pi[j]} for every j from the counts in {@code nJKV}.
     *
     * <p>For each j the counts are summed over the last index to get one value per k, these are
     * divided by their grand total, and the result is blended with the current {@code pi[j]} by
     * {@code Regularizations.step} using the weight {@code 1 / (s + 2)}.
     *
     * <p>The previous implementation also assembled a tab-separated debug string that was never
     * read (and whose trimming threw when K was 0); that dead work has been removed.
     */
    private void maximizePi() {
        final double weight = Math.pow(this.s + 2, -1.0d);
        for (int j = 0; j < this.L; j++) {
            final double[] share = new double[this.K];
            double total = 0.0d;
            for (int k = 0; k < this.K; k++) {
                for (int v = 0; v < this.n; v++) {
                    // Accumulate entry by entry so the floating-point sums match the old code exactly.
                    share[k] += this.nJKV[j][k][v];
                    total += this.nJKV[j][k][v];
                }
            }
            for (int k = 0; k < this.K; k++) {
                share[k] /= total;
            }
            this.pi[j] = Regularizations.step(share, this.pi[j], weight, false);
        }
    }

    /**
     * M-step: refreshes {@code snv}, then re-estimates rho (unless {@code isNO_RECOMB()}), pi, mu
     * and finally eps (unless {@code isFLAT_EPSILON_PRIOR()}).
     *
     * <p>If any rho row was spiked, every mu row receives a small random perturbation
     * ({@code Math.random() / 100} per entry) and is renormalised. When {@code isBIAS_MU()} is
     * set, a status tag (DIFF / BIAS / UNB) is printed and {@code biasMu()} may be invoked,
     * governed by the bias counters.
     */
    private void mStep() {
        computeSNVPosterior();

        boolean rhoSpiked = false;
        if (!Globals.getINSTANCE().isNO_RECOMB()) {
            rhoSpiked = maximizeRho();
        }
        maximizePi();
        maximizeMu();

        if (rhoSpiked) {
            for (int j = 0; j < this.L; j++) {
                for (int k = 0; k < this.K; k++) {
                    double rowTotal = 0.0d;
                    for (int v = 0; v < this.n; v++) {
                        this.mu[j][k][v] += Math.random() / 100.0d;
                        rowTotal += this.mu[j][k][v];
                    }
                    for (int v = 0; v < this.n; v++) {
                        this.mu[j][k][v] /= rowTotal;
                    }
                }
            }
        }

        if (Globals.getINSTANCE().isBIAS_MU()) {
            final int flatRows = getMuFlats();
            if (this.biasCounter >= 1) {
                // Cool-down phase: no biasing until more than 200 rounds have passed.
                this.unBiasCounter++;
                this.biasMu = false;
                if (this.unBiasCounter > 200) {
                    this.biasCounter = 0;
                    this.unBiasCounter = 0;
                }
            } else {
                this.biasMu = true;
            }

            if (this.oldFlatMu != flatRows) {
                System.out.print("DIFF\t");
            } else if (this.biasMu) {
                biasMu();
                System.out.print("BIAS\t");
                this.biasCounter++;
            } else {
                System.out.print("UNB\t");
            }
            this.oldFlatMu = flatRows;
        }

        if (!Globals.getINSTANCE().isFLAT_EPSILON_PRIOR()) {
            maximizeEps();
        }
    }

    /**
     * Randomly perturbs {@code mu} at every index j that has at least one row whose largest entry
     * is strictly smaller than the row sum.
     *
     * <p>At such an index every row k gets {@code Math.random() / 10} added to each entry and is
     * then renormalised to sum to one.
     */
    public void biasMu() {
        for (int j = 0; j < this.L; j++) {
            boolean hasSpreadRow = false;
            for (int k = 0; k < this.K && !hasSpreadRow; k++) {
                double peak = 0.0d;
                double total = 0.0d;
                for (int v = 0; v < this.n; v++) {
                    peak = Math.max(this.mu[j][k][v], peak);
                    total += this.mu[j][k][v];
                }
                hasSpreadRow = peak < total;
            }
            if (!hasSpreadRow) {
                continue;
            }
            for (int k = 0; k < this.K; k++) {
                double rowTotal = 0.0d;
                for (int v = 0; v < this.n; v++) {
                    this.mu[j][k][v] += Math.random() / 10.0d;
                    rowTotal += this.mu[j][k][v];
                }
                for (int v = 0; v < this.n; v++) {
                    this.mu[j][k][v] /= rowTotal;
                }
            }
        }
    }

    /**
     * Re-estimates {@code eps[j]} and {@code antieps[j]} for every j.
     *
     * <p>The estimate is a ratio of two {@code Regularizations.f} values, one built from
     * {@code nneqPos[j]} and one from {@code coverage[j] * (n - 1)}; the additive constants 20 and
     * 2357 are not explained in this file. An estimate above {@code 1 / n} is replaced by 0.05.
     */
    private void maximizeEps() {
        final double ceiling = 1.0d / this.n;
        for (int j = 0; j < this.L; j++) {
            double estimate = Regularizations.f(this.nneqPos[j] + 20.0d) / Regularizations.f(((this.coverage[j] * (this.n - 1)) + 20.0d) + 2357.0d);
            if (estimate > ceiling) {
                estimate = 0.05d;
            }
            this.eps[j] = estimate;
            this.antieps[j] = 1.0d - ((this.n - 1) * estimate);
        }
    }

    /**
     * Starts the next round: bumps the restart counter, clears both change counters and runs
     * {@code compute()} again.
     */
    public void restart() {
        this.rhoChanged = 0;
        this.muChanged = 0;
        this.restart += 1;
        compute();
    }

    /**
     * Finds the pair of indices (a, b) with a &lt; b &lt; K whose
     * {@code KullbackLeibler.symmetric(mu, a, b)} value is smallest.
     *
     * <p>Pairs are visited in the same order as before (a ascending, then b ascending) and ties
     * keep the first pair found, so the result is unchanged; the visited-pair set was only
     * emulating this {@code a < b} iteration.
     *
     * @return the two indices and their divergence
     * @throws IllegalStateException if no pair has a divergence below {@code Double.MAX_VALUE},
     *     for example when K is less than 2 (previously this surfaced as a NullPointerException)
     */
    public Triplet<Integer, Integer, Double> minKL() {
        double smallest = Double.MAX_VALUE;
        int bestFirst = -1;
        int bestSecond = -1;
        for (int a = 0; a < this.K; a++) {
            for (int b = a + 1; b < this.K; b++) {
                final double divergence = KullbackLeibler.symmetric(this.mu, a, b);
                if (divergence < smallest) {
                    smallest = divergence;
                    bestFirst = a;
                    bestSecond = b;
                }
            }
        }
        if (bestFirst < 0) {
            throw new IllegalStateException("minKL found no comparable pair of generators (K=" + this.K + ")");
        }
        return Triplet.with(Integer.valueOf(bestFirst), Integer.valueOf(bestSecond), Double.valueOf(smallest));
    }

    /**
     * Counts a mu update as a change when the absolute difference between the old and new value
     * exceeds {@code Globals.getPCHANGE()}.
     *
     * @param d previous value
     * @param d2 new value
     */
    protected void changedMu(double d, double d2) {
        final double delta = Math.abs(d - d2);
        if (delta > Globals.getINSTANCE().getPCHANGE()) {
            this.muChanged += 1;
        }
    }

    /**
     * Counts a rho update as a change when the absolute difference between the old and new value
     * exceeds {@code Globals.getPCHANGE()}.
     *
     * @param d previous value
     * @param d2 new value
     */
    protected void changedRho(double d, double d2) {
        final double delta = Math.abs(d - d2);
        if (delta > Globals.getINSTANCE().getPCHANGE()) {
            this.rhoChanged += 1;
        }
    }

    /**
     * Stores the dimensions, reads and starting parameters, allocates {@code snv} and
     * {@code muPrior}, and caches the coverage and paired flag from {@code Globals}.
     *
     * <p>{@code rho_old} and {@code mu_old} receive the same array references as {@code rho} and
     * {@code mu}; no copies are made.
     *
     * @param readArr reads to fit
     * @param i value stored as {@code N}
     * @param i2 value stored as {@code L}
     * @param i3 value stored as {@code K}
     * @param i4 value stored as {@code n}
     * @param dArr initial {@code rho}
     * @param dArr2 initial {@code pi}
     * @param dArr3 initial {@code mu}
     */
    protected final void prepare(Read[] readArr, int i, int i2, int i3, int i4, double[][][] dArr, double[][] dArr2, double[][][] dArr3) {
        // Dimensions.
        this.N = i;
        this.L = i2;
        this.K = i3;
        this.n = i4;

        // Data and starting parameters (shared references, not copies).
        this.allReads = readArr;
        this.pi = dArr2;
        this.rho = dArr;
        this.rho_old = dArr;
        this.mu = dArr3;
        this.mu_old = dArr3;

        // Derived tables.
        this.snv = new double[i2][i4];
        this.muPrior = new double[i4];
        for (int v = 0; v < i4; v++) {
            this.muPrior[v] = Globals.getINSTANCE().getALPHA_H();
        }

        // Global settings sampled once.
        this.coverage = Globals.getINSTANCE().getTAU_OMEGA().getCoverage();
        this.paired = Globals.getINSTANCE().isPAIRED();
    }

    /**
     * Counts the rows {@code mu[j][k]} (j &lt; L, k &lt; K) whose largest entry is strictly smaller
     * than the sum of their first {@code n} entries.
     *
     * @return number of such rows
     */
    public int getMuFlats() {
        int flatRows = 0;
        for (int j = 0; j < this.L; j++) {
            for (int k = 0; k < this.K; k++) {
                double peak = 0.0d;
                double total = 0.0d;
                for (int v = 0; v < this.n; v++) {
                    final double value = this.mu[j][k][v];
                    peak = Math.max(value, peak);
                    total += value;
                }
                if (peak < total) {
                    flatRows += 1;
                }
            }
        }
        return flatRows;
    }

    /**
     * Counts the rows {@code nJKV[j][k]} (j &lt; L, k &lt; K) whose largest entry is strictly
     * smaller than the sum of their first {@code n} entries.
     *
     * @return number of such rows
     */
    public int getNjkvFlats() {
        int flatRows = 0;
        for (int j = 0; j < this.L; j++) {
            for (int k = 0; k < this.K; k++) {
                double peak = 0.0d;
                double total = 0.0d;
                for (int v = 0; v < this.n; v++) {
                    final double value = this.nJKV[j][k][v];
                    peak = Math.max(value, peak);
                    total += value;
                }
                if (peak < total) {
                    flatRows += 1;
                }
            }
        }
        return flatRows;
    }

    /**
     * Counts the rows {@code rho[j][k]} (j &lt; L - 1, k &lt; K) whose largest entry is strictly
     * smaller than the sum of their first {@code K} entries.
     *
     * @return number of such rows
     */
    public int getRhoFlats() {
        int flatRows = 0;
        final int transitions = this.L - 1;
        for (int j = 0; j < transitions; j++) {
            for (int k = 0; k < this.K; k++) {
                double peak = 0.0d;
                double total = 0.0d;
                for (int l = 0; l < this.K; l++) {
                    final double value = this.rho[j][k][l];
                    peak = Math.max(value, peak);
                    total += value;
                }
                if (peak < total) {
                    flatRows += 1;
                }
            }
        }
        return flatRows;
    }

    /**
     * Counts the rows {@code nJKL[j][k]} (j &lt; L - 1, k &lt; K) whose largest entry is strictly
     * smaller than the sum of their first {@code K} entries.
     *
     * @return number of such rows
     */
    public int getNjklFlats() {
        int flatRows = 0;
        final int transitions = this.L - 1;
        for (int j = 0; j < transitions; j++) {
            for (int k = 0; k < this.K; k++) {
                double peak = 0.0d;
                double total = 0.0d;
                for (int l = 0; l < this.K; l++) {
                    final double value = this.nJKL[j][k][l];
                    peak = Math.max(value, peak);
                    total += value;
                }
                if (peak < total) {
                    flatRows += 1;
                }
            }
        }
        return flatRows;
    }

    /** @return the dimension stored as {@code K} */
    public int getK() {
        return K;
    }

    /** @return the dimension stored as {@code L} */
    public int getL() {
        return L;
    }

    /** @return the value stored as {@code N} */
    public int getN() {
        return N;
    }

    /** @return the dimension stored as {@code n} */
    public int getn() {
        return n;
    }

    /** @return the live {@code eps} array (not a copy) */
    public double[] getEps() {
        return eps;
    }

    /** @return the live {@code antieps} array (not a copy) */
    public double[] getAntieps() {
        return antieps;
    }

    /** @return the value accumulated by the most recent E-step */
    public double getLoglikelihood() {
        return loglikelihood;
    }

    /** @return the live {@code mu} table (not a copy) */
    public double[][][] getMu() {
        return mu;
    }

    /** @return the live {@code pi} table (not a copy) */
    public double[][] getPi() {
        return pi;
    }

    /** @return the live {@code rho} table (not a copy) */
    public double[][][] getRho() {
        return rho;
    }

    /** @return how many times {@code restart()} has been called */
    public int getRestart() {
        return restart;
    }

    /** @return number of mu entries counted as changed since construction or the last restart */
    public int getMuChanged() {
        return muChanged;
    }

    /** @return number of rho entries counted as changed since construction or the last restart */
    public int getRhoChanged() {
        return rhoChanged;
    }

    /** @return the live {@code snv} table (not a copy) */
    public double[][] getSnv() {
        return snv;
    }

    /** @return the current {@code beta} value */
    public double getBeta() {
        return beta;
    }

    /**
     * Multiplies {@code beta} by the given factor.
     *
     * @param d factor to scale {@code beta} by
     */
    public void incBeta(double d) {
        beta = beta * d;
    }

    /**
     * Replaces {@code beta}.
     *
     * @param d new value
     */
    public void setBeta(double d) {
        beta = d;
    }
}
