#include "bcjr_decoder.h"

// Default constructor. It does no work of its own: the trellis, state_count
// and the max* approximation mode are all (re)established at the start of
// every block_decode() call.
BCJR_decoder::BCJR_decoder() {}

/*
 * Log-MAP BCJR decoder for one frame of the code whose 8-state trellis is
 * hard-coded below. Produces the extrinsic LLRs of the uncoded bits.
 *
 * LLRs follow the convention implied by the computations below:
 *   LLR = ln(P(bit = 1) / P(bit = 0)).
 *
 * apriori_uncoded   : a priori LLRs of the uncoded bits, one per trellis step.
 * apriori_encoded   : a priori LLRs of the encoded bits; must have the same
 *                     length as apriori_uncoded.
 * extrinsic_uncoded : output. Resized to the frame length and filled with
 *                     (a posteriori LLR - a priori LLR) for every uncoded bit.
 *
 * The encoder is assumed to start and to finish in state 1 (see the alpha and
 * beta initialisation). If the two input lengths differ, an error message is
 * printed and extrinsic_uncoded is left untouched.
 */
void BCJR_decoder::block_decode(vec apriori_uncoded, vec apriori_encoded, vec &extrinsic_uncoded)
{
    if (apriori_uncoded.length() != apriori_encoded.length())
    {
        cout << "LLR sequences must have the same length !!!" << endl;
        // Decoding anyway would index past the end of the shorter sequence.
        return;
    }

    const int frame_length = apriori_encoded.length();

    // One output LLR per trellis step. Sizing the output here (a no-op when the
    // caller already did so) keeps the final loop inside deltas/apriori_uncoded.
    extrinsic_uncoded.set_size(frame_length);
    if (frame_length == 0)
        return; // Nothing to decode; the alpha/beta set-up would index column 0 / -1.

    // Set Log-MAP approximation method
    max_star.initialize(1);

    // Matrix describing the trellis: one row per transition. States are
    // numbered from 1 rather than 0.
    //   Columns: FromState, ToState, UncodedBit, EncodedBit
    mat transitions = "1 1 0 0; "
                      "2 5 0 0; "
                      "3 6 0 1; "
                      "4 2 0 1; "
                      "5 3 0 1; "
                      "6 7 0 1; "
                      "7 8 0 0; "
                      "8 4 0 0; "
                      "1 5 1 1; "
                      "2 1 1 1; "
                      "3 2 1 0; "
                      "4 6 1 0; "
                      "5 7 1 0; "
                      "6 3 1 0; "
                      "7 4 1 1; "
                      "8 8 1 1";
    const int transition_count = transitions.rows();

    // Find the largest state index in the transitions matrix
    state_count = std::max(itpp::max(transitions.get_col(0)), itpp::max(transitions.get_col(1)));

    // A priori transition log-probabilities (up to a constant):
    //   gamma = UncodedBit * La(uncoded) + EncodedBit * La(encoded)
    mat gammas = zeros(transition_count, frame_length);
    for (int bit_index = 0; bit_index < frame_length; bit_index++)
        for (int t = 0; t < transition_count; t++)
            gammas(t, bit_index) = transitions(t, 2) * apriori_uncoded(bit_index)
                                 + transitions(t, 3) * apriori_encoded(bit_index);

    // Scratch buffers shared by both recursions. "candidates" is sized for the
    // worst case (every transition touching the same state) instead of assuming
    // exactly two transitions per state.
    vec temp(transition_count);
    vec candidates(transition_count);

    // Forward recursion: alphas(:, k) holds the state log-probabilities at the
    // start of trellis step k (the FromState side of step k's transitions).
    // The encoder starts in state 1.
    mat alphas = zeros(state_count, frame_length);
    alphas.set_col(0, -INFINITY * ones(state_count));
    alphas(0, 0) = 0.0;
    for (int bit_index = 1; bit_index < frame_length; bit_index++)
    {
        for (int t = 0; t < transition_count; t++)
            temp(t) = alphas(static_cast<int>(transitions(t, 0)) - 1, bit_index - 1)
                    + gammas(t, bit_index - 1);

        for (int state_index = 0; state_index < state_count; state_index++)
        {
            // Gather the transitions that end in this state, in row order.
            int n = 0;
            for (int t = 0; t < transition_count; t++)
                if (transitions(t, 1) == state_index + 1)
                    candidates(n++) = temp(t);
            alphas(state_index, bit_index) = max_star.approximation(candidates.left(n));
        }
    }

    // Backward recursion: betas(:, k) holds the state log-probabilities at the
    // end of trellis step k (the ToState side of step k's transitions).
    // The encoder is assumed to finish in state 1.
    mat betas = zeros(state_count, frame_length);
    betas.set_col(frame_length - 1, -INFINITY * ones(state_count));
    betas(0, frame_length - 1) = 0.0;
    for (int bit_index = frame_length - 2; bit_index >= 0; bit_index--)
    {
        for (int t = 0; t < transition_count; t++)
            temp(t) = betas(static_cast<int>(transitions(t, 1)) - 1, bit_index + 1)
                    + gammas(t, bit_index + 1);

        for (int state_index = 0; state_index < state_count; state_index++)
        {
            // Gather the transitions that leave this state, in row order.
            int n = 0;
            for (int t = 0; t < transition_count; t++)
                if (transitions(t, 0) == state_index + 1)
                    candidates(n++) = temp(t);
            betas(state_index, bit_index) = max_star.approximation(candidates.left(n));
        }
    }

    // A posteriori transition log-probabilities.
    mat deltas = zeros(transition_count, frame_length);
    for (int bit_index = 0; bit_index < frame_length; bit_index++)
        for (int t = 0; t < transition_count; t++)
            deltas(t, bit_index) = alphas(static_cast<int>(transitions(t, 0)) - 1, bit_index)
                                 + betas(static_cast<int>(transitions(t, 1)) - 1, bit_index)
                                 + gammas(t, bit_index);

    // Uncoded extrinsic LLRs. Transitions are split by their UncodedBit label
    // (column 2) rather than by row position, so the result does not depend on
    // the 0-labelled rows being listed first.
    vec metrics_p0(transition_count);
    vec metrics_p1(transition_count);
    for (int bit_index = 0; bit_index < frame_length; bit_index++)
    {
        int n0 = 0;
        int n1 = 0;
        for (int t = 0; t < transition_count; t++)
        {
            if (transitions(t, 2) == 0)
                metrics_p0(n0++) = deltas(t, bit_index);
            else if (transitions(t, 2) == 1)
                metrics_p1(n1++) = deltas(t, bit_index);
        }
        double log_p0 = max_star.approximation(metrics_p0.left(n0));
        double log_p1 = max_star.approximation(metrics_p1.left(n1));
        extrinsic_uncoded(bit_index) = log_p1 - log_p0 - apriori_uncoded(bit_index);
    }
}

/*
 * Log-MAP BCJR decoder for one frame of the code whose 8-state trellis is
 * hard-coded below. Produces extrinsic LLRs for both the uncoded and the
 * encoded bits.
 *
 * LLRs follow the convention implied by the computations below:
 *   LLR = ln(P(bit = 1) / P(bit = 0)).
 *
 * apriori_uncoded   : a priori LLRs of the uncoded bits, one per trellis step.
 * apriori_encoded   : a priori LLRs of the encoded bits; must have the same
 *                     length as apriori_uncoded.
 * extrinsic_uncoded : output. Resized to the frame length and filled with
 *                     (a posteriori LLR - a priori LLR) for every uncoded bit.
 * extrinsic_encoded : output. Resized to the frame length and filled with
 *                     (a posteriori LLR - a priori LLR) for every encoded bit.
 *
 * The encoder is assumed to start and to finish in state 1 (see the alpha and
 * beta initialisation). If the two input lengths differ, an error message is
 * printed and both outputs are left untouched.
 */
void BCJR_decoder::block_decode(vec apriori_uncoded, vec apriori_encoded, vec &extrinsic_uncoded, vec &extrinsic_encoded)
{
    if (apriori_uncoded.length() != apriori_encoded.length())
    {
        cout << "LLR sequences must have the same length !!!" << endl;
        // Decoding anyway would index past the end of the shorter sequence.
        return;
    }

    const int frame_length = apriori_encoded.length();

    // One output LLR per trellis step for each output. Sizing them here (a no-op
    // when the caller already did so) keeps both output loops inside
    // deltas/apriori_* regardless of how the caller sized them.
    extrinsic_uncoded.set_size(frame_length);
    extrinsic_encoded.set_size(frame_length);
    if (frame_length == 0)
        return; // Nothing to decode; the alpha/beta set-up would index column 0 / -1.

    // Set Log-MAP approximation method
    max_star.initialize(1);

    // Matrix describing the trellis: one row per transition. States are
    // numbered from 1 rather than 0.
    //   Columns: FromState, ToState, UncodedBit, EncodedBit
    mat transitions = "1 1 0 0; "
                      "2 5 0 0; "
                      "3 6 0 1; "
                      "4 2 0 1; "
                      "5 3 0 1; "
                      "6 7 0 1; "
                      "7 8 0 0; "
                      "8 4 0 0; "
                      "1 5 1 1; "
                      "2 1 1 1; "
                      "3 2 1 0; "
                      "4 6 1 0; "
                      "5 7 1 0; "
                      "6 3 1 0; "
                      "7 4 1 1; "
                      "8 8 1 1";
    const int transition_count = transitions.rows();

    // Find the largest state index in the transitions matrix
    state_count = std::max(itpp::max(transitions.get_col(0)), itpp::max(transitions.get_col(1)));

    // A priori transition log-probabilities (up to a constant):
    //   gamma = UncodedBit * La(uncoded) + EncodedBit * La(encoded)
    mat gammas = zeros(transition_count, frame_length);
    for (int bit_index = 0; bit_index < frame_length; bit_index++)
        for (int t = 0; t < transition_count; t++)
            gammas(t, bit_index) = transitions(t, 2) * apriori_uncoded(bit_index)
                                 + transitions(t, 3) * apriori_encoded(bit_index);

    // Scratch buffers shared by both recursions. "candidates" is sized for the
    // worst case (every transition touching the same state) instead of assuming
    // exactly two transitions per state.
    vec temp(transition_count);
    vec candidates(transition_count);

    // Forward recursion: alphas(:, k) holds the state log-probabilities at the
    // start of trellis step k (the FromState side of step k's transitions).
    // The encoder starts in state 1.
    mat alphas = zeros(state_count, frame_length);
    alphas.set_col(0, -INFINITY * ones(state_count));
    alphas(0, 0) = 0.0;
    for (int bit_index = 1; bit_index < frame_length; bit_index++)
    {
        for (int t = 0; t < transition_count; t++)
            temp(t) = alphas(static_cast<int>(transitions(t, 0)) - 1, bit_index - 1)
                    + gammas(t, bit_index - 1);

        for (int state_index = 0; state_index < state_count; state_index++)
        {
            // Gather the transitions that end in this state, in row order.
            int n = 0;
            for (int t = 0; t < transition_count; t++)
                if (transitions(t, 1) == state_index + 1)
                    candidates(n++) = temp(t);
            alphas(state_index, bit_index) = max_star.approximation(candidates.left(n));
        }
    }

    // Backward recursion: betas(:, k) holds the state log-probabilities at the
    // end of trellis step k (the ToState side of step k's transitions).
    // The encoder is assumed to finish in state 1.
    mat betas = zeros(state_count, frame_length);
    betas.set_col(frame_length - 1, -INFINITY * ones(state_count));
    betas(0, frame_length - 1) = 0.0;
    for (int bit_index = frame_length - 2; bit_index >= 0; bit_index--)
    {
        for (int t = 0; t < transition_count; t++)
            temp(t) = betas(static_cast<int>(transitions(t, 1)) - 1, bit_index + 1)
                    + gammas(t, bit_index + 1);

        for (int state_index = 0; state_index < state_count; state_index++)
        {
            // Gather the transitions that leave this state, in row order.
            int n = 0;
            for (int t = 0; t < transition_count; t++)
                if (transitions(t, 0) == state_index + 1)
                    candidates(n++) = temp(t);
            betas(state_index, bit_index) = max_star.approximation(candidates.left(n));
        }
    }

    // A posteriori transition log-probabilities.
    mat deltas = zeros(transition_count, frame_length);
    for (int bit_index = 0; bit_index < frame_length; bit_index++)
        for (int t = 0; t < transition_count; t++)
            deltas(t, bit_index) = alphas(static_cast<int>(transitions(t, 0)) - 1, bit_index)
                                 + betas(static_cast<int>(transitions(t, 1)) - 1, bit_index)
                                 + gammas(t, bit_index);

    // Buffers for the a posteriori metrics of the 0- and 1-labelled transitions.
    vec metrics_p0(transition_count);
    vec metrics_p1(transition_count);

    // Uncoded extrinsic LLRs: split transitions by their UncodedBit label
    // (column 2) rather than by row position.
    for (int bit_index = 0; bit_index < frame_length; bit_index++)
    {
        int n0 = 0;
        int n1 = 0;
        for (int t = 0; t < transition_count; t++)
        {
            if (transitions(t, 2) == 0)
                metrics_p0(n0++) = deltas(t, bit_index);
            else if (transitions(t, 2) == 1)
                metrics_p1(n1++) = deltas(t, bit_index);
        }
        double log_p0 = max_star.approximation(metrics_p0.left(n0));
        double log_p1 = max_star.approximation(metrics_p1.left(n1));
        extrinsic_uncoded(bit_index) = log_p1 - log_p0 - apriori_uncoded(bit_index);
    }

    // Encoded extrinsic LLRs: split transitions by their EncodedBit label
    // (column 3). Only the entries actually collected are combined, so the
    // number of transitions per label need not equal state_count.
    for (int bit_index = 0; bit_index < frame_length; bit_index++)
    {
        int n0 = 0;
        int n1 = 0;
        for (int t = 0; t < transition_count; t++)
        {
            if (transitions(t, 3) == 0)
                metrics_p0(n0++) = deltas(t, bit_index);
            else if (transitions(t, 3) == 1)
                metrics_p1(n1++) = deltas(t, bit_index);
        }
        double log_p0 = max_star.approximation(metrics_p0.left(n0));
        double log_p1 = max_star.approximation(metrics_p1.left(n1));
        extrinsic_encoded(bit_index) = log_p1 - log_p0 - apriori_encoded(bit_index);
    }
}
