#include <cmath>
#include <limits>

#include "hmmlike.h"

// Numerically stable log(sum_k exp(x[k])) over n values.
// Returns -inf when every value is -inf (or n == 0), so an impossible state stays
// impossible rather than turning into NaN (shifting by -inf would give inf - inf).
// A NaN input is propagated unchanged.
static double cfwdLogSumExp(const double *x, const int n) {
    double top = -std::numeric_limits<double>::infinity();
    for (int k = 0; k < n; k++) {
        if (std::isnan(x[k])) {
            return x[k];
        }
        if (x[k] > top) {
            top = x[k];
        }
    }
    if (std::isinf(top)) {
        return top;
    }
    double acc = 0.0;
    for (int k = 0; k < n; k++) {
        acc += std::exp(x[k] - top);  // largest term is exp(0) == 1: no overflow
    }
    return top + std::log(acc);
}

// Forward pass, in log space, over a set of regions for a 3-state model.
//
// The returned matrix is 4 x (3 * datalen). Position `pos` owns columns
// 3*pos .. 3*pos+2, one column per current state s.
//   rows 0..2 (r): unnormalized log weight of "previous state r -> current state s",
//                  i.e. previous row-3 value for r + emission log-lik of s + log T(r, s);
//   row 3       : log probabilities of the current state s, normalized so that
//                  they sum (in probability space) to 1 over s.
// At the first position of each region all four rows are filled with the initial
// log probabilities: logTs["phit"] when tn[pos] is true, logTs["phint"] otherwise.
//
// Parameters, as used by this function:
//   poslen[j], csum[j] - number of positions in region j and index of its first position
//   datalen            - total number of positions (column triples in the result)
//   reglen             - number of regions
//   logTs              - list holding log transition matrices "t" / "nt" and initial
//                        log probabilities "phit" / "phint"; the pair is chosen by tn[pos]
//   X, M               - per-position values passed to the emission helpers
//   tn                 - per-position flag; when false, state 2 gets emission log-lik -inf
//                        and the zero-inflation term (fzib) is not added
//   c                  - lbnbr is called with d = c + 1
//   kab                - per-state (row s) parameters passed to lbnbr / lnormcbnb
//   p, pi              - per-state parameters passed to fzib (pi also as log(1 - pi))
// If every path into a position has log weight -inf, row 3 for that position is NaN.
// [[Rcpp::export]]

NumericMatrix cforwardP(const IntegerVector poslen, const int datalen, const int reglen,
        const IntegerVector csum, List logTs, const IntegerVector X, const IntegerVector M,
        const LogicalVector tn, const int c, const NumericMatrix kab, const NumericVector p,
        const NumericVector pi) {
    const NumericMatrix logTt = logTs["t"];
    const NumericMatrix logTnt = logTs["nt"];
    const NumericVector lphit = logTs["phit"];
    const NumericVector lphint = logTs["phint"];

    const int d = c + 1;
    const double ninf = -std::numeric_limits<double>::infinity();

    NumericMatrix lP(4, 3 * datalen);

    // Position-independent per-state constants.
    double lomp[3];   // log(1 - pi[s]); log1p keeps precision when pi[s] is tiny
    double pbnbr[3];  // subtracted from every lbnbr term (presumably its log normalizer)
    for (int s = 0; s < 3; s++) {
        lomp[s] = std::log1p(-pi[s]);
        pbnbr[s] = lnormcbnb(kab(s, 0), kab(s, 1), kab(s, 2), s);
    }

    for (int j = 0; j < reglen; j++) {
        for (int i = 0; i < poslen[j]; i++) {
            const int pos = i + csum[j];
            const int col = 3 * pos;
            const bool isT = tn[pos];

            if (i == 0) {
                // Region start: seed every row with the initial log probabilities.
                // NOTE(review): the emission at this position (X[pos], M[pos]) is not
                // used here; confirm this matches the model / the backward pass.
                const NumericVector &lphi = isT ? lphit : lphint;
                for (int r = 0; r < 4; r++) {
                    for (int s = 0; s < 3; s++) {
                        lP(r, s + col) = lphi[s];
                    }
                }
                continue;
            }

            // Emission log-likelihood of each current state.
            double logl[3];
            if (isT) {
                for (int s = 0; s < 3; s++) {
                    logl[s] = fzib(X[pos], M[pos], p[s], pi[s], lomp[s]) +
                            lbnbr(X[pos], d, kab(s, 0), kab(s, 1), kab(s, 2), s) - pbnbr[s];
                }
            } else {
                for (int s = 0; s < 2; s++) {
                    logl[s] = lbnbr(X[pos], d, kab(s, 0), kab(s, 1), kab(s, 2), s) - pbnbr[s];
                }
                logl[2] = ninf;  // state 2 is not reachable when tn[pos] is false
            }

            const NumericMatrix &logT = isT ? logTt : logTnt;
            const int prevCol = col - 3;

            // Normalize entirely in log space: a per-state log-sum-exp over the
            // previous state r, then a log-sum-exp across current states. Each sum
            // is shifted by its own maximum, so neither underflow of very unlikely
            // states nor a wide spread of log weights loses information.
            double logcol[3];
            for (int s = 0; s < 3; s++) {
                double terms[3];
                for (int r = 0; r < 3; r++) {
                    terms[r] = lP(3, r + prevCol) + logl[s] + logT(r, s);
                    lP(r, s + col) = terms[r];
                }
                logcol[s] = cfwdLogSumExp(terms, 3);
            }
            const double logtot = cfwdLogSumExp(logcol, 3);
            for (int s = 0; s < 3; s++) {
                lP(3, s + col) = logcol[s] - logtot;
            }
        }
    }
    return lP;
}
