#include "cvpbas.hpp"

// Tuning constants of the segmenter, described by how this file uses them.

// Number of background samples collected: init() stores one sample per frame
// until `runs` reaches N; it is also the exclusive upper bound of the random
// sample indices drawn in newInitialization().
#define N 50
// Floor (and start value) of the per-pixel decision threshold R.
#define R_lower 18
// Number of samples that must lie closer than R for a pixel to be background.
#define Raute_min 2
// learningRateRegulator() only stores a new T that lies strictly between
// T_lower and T_upper.
#define T_lower 2
// Also the exclusive upper bound of the randomT / randomTN draws and the
// numerator of the update ratio ceil(T_upper / T) in process().
#define T_upper 200
// Factor applied to the mean minimal distance to obtain the target for R.
#define R_scale 5
// Relative step by which R is raised or lowered per pixel and frame.
#define R_incdec 0.05
// Numerator of the step ADDED to T when the pixel was classified foreground.
#define T_dec 0.05
// Numerator of the step SUBTRACTED from T when the pixel was classified background.
#define T_inc 1

/**
 * Sets the segmenter to its start state and fills the random lookup tables.
 *
 * `width == 0` marks "no frame seen yet"; process() picks up the real frame
 * geometry on its first call.
 */
cvPBAS::cvPBAS(void)
{
    // Weights of the gradient-magnitude and colour terms in the distance.
    alpha = 10.0;
    beta = 1.0;
    formerMeanNorm = 0;

    // process() divides by formerMeanMag before it ever assigns it, so it must
    // hold a defined value for the first frame. 20 is the same floor process()
    // applies to it at the end of every frame.
    formerMeanMag = 20.0;

    // Frame geometry is unknown until the first call to process().
    width = 0;
    height = 0;
    chans = 0;

    foregroundValue = 255;
    backgroundValue = 0;

    // Length of each pre-drawn random number table.
    countOfRandomNumb = 1000;

    // The per-pixel learning rate T starts at the lower bound of R.
    T_init = R_lower;

    // Number of background samples collected so far.
    runs = 0;

    newInitialization();
}

/**
 * Rebuilds the six tables of pre-drawn random numbers.
 *
 * Each table receives countOfRandomNumb entries. Per entry the generator is
 * queried in a fixed order (N, X, Y, MinDist, T, TN):
 *  - randomN, randomMinDist: integers in [0, N)
 *  - randomX, randomY:       integers in [-1, 2), i.e. -1, 0 or 1
 *  - randomT, randomTN:      integers in [0, T_upper)
 */
void cvPBAS::newInitialization()
{
    // Start from empty tables; clearing an already empty vector is harmless.
    randomN.clear();
    randomX.clear();
    randomY.clear();
    randomMinDist.clear();
    randomT.clear();
    randomTN.clear();

    for(int drawn = 0; drawn < countOfRandomNumb; ++drawn)
    {
        // Draw first, store afterwards: the order of the six draws is kept so
        // that the generated sequence stays the same.
        const int sampleSlot  = static_cast<int>(randomGenerator.uniform(static_cast<int>(0), static_cast<int>(N)));
        const int offsetX     = static_cast<int>(randomGenerator.uniform(-1, +2));
        const int offsetY     = static_cast<int>(randomGenerator.uniform(-1, +2));
        const int minDistSlot = static_cast<int>(randomGenerator.uniform(static_cast<int>(0), static_cast<int>(N)));
        const int ownUpdate   = static_cast<int>(randomGenerator.uniform(static_cast<int>(0), static_cast<int>(T_upper)));
        const int neighUpdate = static_cast<int>(randomGenerator.uniform(static_cast<int>(0), static_cast<int>(T_upper)));

        randomN.push_back(sampleSlot);
        randomX.push_back(offsetX);
        randomY.push_back(offsetY);
        randomMinDist.push_back(minDistSlot);
        randomT.push_back(ownUpdate);
        randomTN.push_back(neighUpdate);
    }
}
/**
 * Releases the random tables, every stored background sample and the
 * per-pixel state matrices.
 *
 * The samples are released by iterating over what is actually stored rather
 * than by a fixed index list chosen from `chans`, so no out-of-range access
 * (and therefore no exception) can occur inside the destructor.
 */
cvPBAS::~cvPBAS(void)
{
    randomN.clear();
    randomX.clear();
    randomY.clear();
    randomMinDist.clear();
    randomT.clear();
    randomTN.clear();

    const int sampleCount = static_cast<int>(backgroundModel.size());
    for(int k = 0; k < sampleCount; ++k)
    {
        // A sample holds the gradient-magnitude planes followed by the colour planes.
        const int planeCount = static_cast<int>(backgroundModel.at(k).size());
        for(int p = 0; p < planeCount; ++p)
        {
            backgroundModel.at(k).at(p).release();
        }
    }

    backgroundModel.clear();
    meanMinDist.release();

    actualR.release();
    actualT.release();

    sobelX.release();
    sobelY.release();
}


/// Squared difference of two feature values, evaluated in double precision.
static inline double pbasSquaredDiff(double a, double b)
{
    const double delta = a - b;
    return delta * delta;
}

/**
 * Segments one frame into foreground (255) and background (0).
 *
 * For every pixel the current features (gradient magnitude and colour per
 * channel) are compared with the stored background samples. A pixel is
 * background once Raute_min samples are closer than its threshold R. Background
 * pixels may randomly overwrite one of their own samples and one sample of a
 * neighbouring pixel (only once N samples exist). Afterwards R and the learning
 * rate T of the pixel are adapted.
 *
 * @param input   Frame to segment; must be non-empty with 1 or 3 channels.
 *                The colour planes are read as uchar.
 * @param output  Receives the CV_8UC1 mask (filled through cv::Mat::copyTo).
 * @return false if a pointer is null, the image is empty or the channel count
 *         is unsupported; true otherwise.
 *
 * If the frame size or channel count differs from the previous frame, the
 * collected background model is discarded and rebuilt from this frame on.
 *
 * NOTE(review): formerMeanMag is read here before this function assigns it for
 * the first time, so it has to be initialised before the first frame - confirm.
 */
bool cvPBAS::process(cv::Mat* input, cv::Mat* output)
{
    if(!input || !output)
    {
        std::cout << "Error: Null image pointer passed to cvPBAS. STOPPING " << std::endl;
        return false;
    }

    // Validate on every call (not only when the width changed), and before any
    // state is modified.
    if(input->rows < 1  || input->cols < 1)
    {
        std::cout << "Error: Occurrence of to small (or empty?) image size in cvPBAS. STOPPING " << std::endl;
        return false;
    }

    const int inputChans = input->channels();
    if(inputChans != 1 && inputChans != 3)
    {
        // calculateFeatures() only produces per-channel features for 1 or 3 channels.
        std::cout << "Error: cvPBAS supports only 1- or 3-channel images. STOPPING " << std::endl;
        return false;
    }

    // width is 0 until the first frame, so on the first call the comparison
    // short-circuits before height/chans are read.
    if(width != input->cols || height != input->rows || chans != inputChans)
    {
        // Samples of a different geometry cannot be indexed with the new one.
        backgroundModel.clear();
        runs = 0;

        width = input->cols;
        height = input->rows;
        chans = inputChans;
    }

    // Adds this frame as a background sample while fewer than N are stored.
    init(input);

    // Local mask: released automatically, also when an exception is thrown.
    cv::Mat result(input->rows, input->cols, CV_8UC1);

    calculateFeatures(&currentFeatures, input);

    sumMagnitude = 0;
    long glCounterFore = 0;

    for(int j = 0; j < result.rows; ++j)
    {
        // Row pointers into the mask, the current features and the pixel state.
        resultMap_Pt = result.ptr<uchar>(j);
        meanMinDist_Pt = meanMinDist.ptr<float>(j);
        actualR_Pt = actualR.ptr<float>(j);
        actualT_Pt = actualT.ptr<float>(j);

        currentFeaturesM_Pt.clear();
        currentFeaturesC_Pt.clear();
        B_Mag_Pts.clear();
        B_Col_Pts.clear();
        for(int z = 0; z < chans; ++z)
        {
            // Features are stored as [magnitude planes..., colour planes...].
            currentFeaturesM_Pt.push_back(currentFeatures.at(z).ptr<float>(j));
            currentFeaturesC_Pt.push_back(currentFeatures.at(z + chans).ptr<uchar>(j));

            B_Mag_Pts.push_back(std::vector<float*>());
            B_Col_Pts.push_back(std::vector<uchar*>());
        }

        // Row pointers into every stored background sample, per channel.
        for(int k = 0; k < runs; ++k)
        {
            for(int z = 0; z < chans; ++z)
            {
                B_Mag_Pts.at(z).push_back(backgroundModel.at(k).at(z).ptr<float>(j));
                B_Col_Pts.at(z).push_back(backgroundModel.at(k).at(z + chans).ptr<uchar>(j));
            }
        }

        for(int i = 0; i < result.cols; ++i)
        {
            int count = 0;
            int index = 0;

            double norm = 0.0;
            double dist = 0.0;
            double minDist = 1000.0;

            // entry-1 and entry+1 are used below, hence the margin at both ends.
            const int entry = randomGenerator.uniform(3, countOfRandomNumb-4);

            // Compare with the samples until enough matches are found or all
            // samples have been visited.
            do
            {
                if(chans == 3)
                {
                    norm = sqrt(
                        pbasSquaredDiff(B_Mag_Pts.at(0).at(index)[i], *currentFeaturesM_Pt.at(0)) +
                        pbasSquaredDiff(B_Mag_Pts.at(1).at(index)[i], *currentFeaturesM_Pt.at(1)) +
                        pbasSquaredDiff(B_Mag_Pts.at(2).at(index)[i], *currentFeaturesM_Pt.at(2))
                    );

                    dist = sqrt(
                        pbasSquaredDiff(B_Col_Pts.at(0).at(index)[i], *currentFeaturesC_Pt.at(0)) +
                        pbasSquaredDiff(B_Col_Pts.at(1).at(index)[i], *currentFeaturesC_Pt.at(1)) +
                        pbasSquaredDiff(B_Col_Pts.at(2).at(index)[i], *currentFeaturesC_Pt.at(2))
                    );
                }
                else
                {
                    // The single-channel path uses the squared difference (no
                    // square root). A square is never negative, so the former
                    // unqualified abs() around it was redundant - and could
                    // bind to the integer abs() overload on some toolchains.
                    norm = pbasSquaredDiff(B_Mag_Pts.at(0).at(index)[i], *currentFeaturesM_Pt.at(0));
                    dist = pbasSquaredDiff(B_Col_Pts.at(0).at(index)[i], *currentFeaturesC_Pt.at(0));
                }

                // Combined distance; the gradient term is normalised by the
                // mean magnitude distance of the previous frame.
                dist = ((double)alpha*(norm/formerMeanMag) + beta*dist);

                if((dist < *actualR_Pt))
                {
                    ++count;
                    if(minDist > dist)
                        minDist = dist;
                }
                else
                {
                    sumMagnitude += (double)(norm);
                    ++glCounterFore;
                }
                ++index;
            }
            while((count < Raute_min) && (index < runs));

            if(count >= Raute_min)
            {
                // Background pixel.
                *resultMap_Pt = 0;
                const double ratio = std::ceil((double)T_upper/(double)(*actualT_Pt));

                // While the model is still filling, keep a running mean of the
                // minimal distance.
                if(runs < N && runs > 2)
                {
                    *meanMinDist_Pt = ((((float)(runs-1)) * (*meanMinDist_Pt)) + (float)minDist) / ((float)runs);
                }
                else if(runs < N && runs == 2)
                {
                    *meanMinDist_Pt = (float)minDist;
                }

                if(runs == N)
                {
                    // Random update of one of this pixel's own samples.
                    if(randomT.at(entry) < ratio)
                    {
                        const int slot = randomN.at(entry+1);
                        for(int z = 0; z < chans; ++z)
                        {
                            B_Mag_Pts.at(z).at(slot)[i] = (float)*currentFeaturesM_Pt.at(z);
                            B_Col_Pts.at(z).at(slot)[i] = (uchar)*currentFeaturesC_Pt.at(z);
                        }

                        *meanMinDist_Pt = ((((float)(N-1)) * (*meanMinDist_Pt)) + (float)minDist) / ((float)N);
                    }

                    // Random update of one sample of a neighbouring pixel,
                    // using the neighbour's own current features.
                    if(randomTN.at(entry) < ratio)
                    {
                        int xNeigh = randomX.at(entry)+i;
                        int yNeigh = randomY.at(entry)+j;
                        checkValid(&xNeigh, &yNeigh);

                        const int slot = randomN.at(entry-1);
                        for(int z = 0; z < chans; ++z)
                        {
                            (backgroundModel.at(slot)).at(z).at<float>(yNeigh,xNeigh) = currentFeatures.at(z).at<float>(yNeigh,xNeigh);
                            (backgroundModel.at(slot)).at(z + chans).at<uchar>(yNeigh,xNeigh) = currentFeatures.at(z+chans).at<uchar>(yNeigh,xNeigh);
                        }
                    }
                }
            }
            else
            {
                // Foreground pixel.
                *resultMap_Pt = 255;
            }

            decisionThresholdRegulator(actualR_Pt, meanMinDist_Pt);

            learningRateRegulator(actualT_Pt, meanMinDist_Pt, resultMap_Pt);

            // Advance all per-pixel pointers to the next column.
            ++resultMap_Pt;
            for(int z = 0; z < chans; ++z)
            {
                ++currentFeaturesM_Pt.at(z);
                ++currentFeaturesC_Pt.at(z);
            }
            ++meanMinDist_Pt;
            ++actualR_Pt;
            ++actualT_Pt;
        }
    }

    result.copyTo(*output);

    // Mean magnitude distance of the non-matching comparisons, floored at 20;
    // used as the normaliser for the next frame.
    const double meanMag = sumMagnitude/(double)(glCounterFore + 1);
    if(meanMag > 20)
        formerMeanMag = meanMag;
    else
        formerMeanMag = 20;

    // Drop the per-frame feature planes.
    const int planeCount = static_cast<int>(currentFeatures.size());
    for(int p = 0; p < planeCount; ++p)
    {
        currentFeatures.at(p).release();
    }

    return true;
}


/**
 * Moves one pixel's decision threshold R by a relative step of R_incdec.
 *
 * R is raised while it is below meanDist * R_scale and lowered otherwise; the
 * stored result is never smaller than R_lower.
 *
 * @param pt        In/out: the pixel's threshold R.
 * @param meanDist  The pixel's mean minimal distance (read only).
 */
void cvPBAS::decisionThresholdRegulator(float* pt, float* meanDist)
{
    const double target = (*meanDist)*R_scale;

    double threshold = *pt;
    const double step = threshold * R_incdec;
    threshold = (threshold < target) ? (threshold + step) : (threshold - step);

    // Anything that is not >= R_lower is replaced by R_lower.
    *pt = (threshold >= R_lower) ? static_cast<float>(threshold)
                                 : static_cast<float>(R_lower);
}
/**
 * Adapts one pixel's learning-rate value T from its classification.
 *
 * The step is divided by (meanDist + 1). Note the constant names: for a
 * background pixel (mask value below 128) T is lowered by T_inc / (meanDist + 1),
 * for a foreground pixel it is raised by T_dec / (meanDist + 1).
 * The new value is stored only if it lies strictly between T_lower and T_upper;
 * otherwise T keeps its previous value.
 *
 * @param pt        In/out: the pixel's T value.
 * @param meanDist  The pixel's mean minimal distance (read only).
 * @param isFore    The pixel's mask value (read only).
 */
void cvPBAS::learningRateRegulator(float* pt, float* meanDist,uchar* isFore)
{
    const double damping = *meanDist + 1.0;
    const bool isBackground = static_cast<int>(*isFore) < 128;

    const double updated = isBackground ? (*pt - T_inc/damping)
                                        : (*pt + T_dec/damping);

    if(updated > T_lower && updated < T_upper)
    {
        *pt = static_cast<float>(updated);
    }
}

/**
 * Clamps a pixel coordinate to the current frame.
 *
 * Negative values become 0; values at or beyond width / height become
 * width - 1 / height - 1. Coordinates already inside the frame are unchanged.
 *
 * @param x  In/out: column index.
 * @param y  In/out: row index.
 */
void cvPBAS::checkValid(int *x, int *y)
{
    int& column = *x;
    int& row = *y;

    column = (column < 0) ? 0 : ((column >= width)  ? width - 1  : column);
    row    = (row < 0)    ? 0 : ((row >= height)    ? height - 1 : row);
}

/**
 * Adds the features of `input` as a new background sample while fewer than N
 * samples are stored; does nothing once N samples exist.
 *
 * On the very first sample (runs == 0) the per-pixel state is set up as well:
 * meanMinDist is zero-filled, actualR is filled with R_lower and actualT with
 * T_init.
 *
 * @param input  Current frame; passed on to calculateFeatures().
 */
void cvPBAS::init(cv::Mat* input)
{
    if(runs >= N)
        return;

    // The model stores copies of the matrix headers, which share the pixel
    // data; the local vector can therefore simply go out of scope.
    std::vector<cv::Mat> sample;
    calculateFeatures(&sample, input);
    backgroundModel.push_back(sample);

    if(runs == 0)
    {
        // cv::Mat::zeros is a static factory: its result has to be assigned.
        // Calling it on the instance and ignoring the return value (as done
        // previously) left meanMinDist uninitialised.
        meanMinDist = cv::Mat::zeros(input->size(), CV_32FC1);

        actualR.create(input->rows, input->cols, CV_32FC1);
        actualT.create(input->rows, input->cols, CV_32FC1);

        for(int row = 0; row < actualR.rows; ++row)
        {
            float* ptRs = actualR.ptr<float>(row);
            float* ptTs = actualT.ptr<float>(row);

            for(int col = 0; col < actualR.cols; ++col)
            {
                ptRs[col] = (float)R_lower;
                ptTs[col] = (float)T_init;
            }
        }
    }

    ++runs;
}

/**
 * Computes the per-pixel features of an image.
 *
 * `feature` is emptied and then filled with:
 *  - 3-channel input: the Sobel gradient magnitude (CV_32F) of each channel,
 *    followed by the three split channels themselves (6 matrices);
 *  - any other input: the gradient magnitude of the image followed by a copy
 *    of the image (2 matrices).
 *
 * @param feature     Output vector of feature matrices.
 * @param inputImage  Image to analyse.
 */
void cvPBAS::calculateFeatures(std::vector<cv::Mat>* feature, cv::Mat* inputImage)
{
    feature->clear();

    // Gradient direction: produced by cartToPolar but not stored anywhere.
    cv::Mat direction;

    if(inputImage->channels() != 3)
    {
        cv::Mat magnitude;
        cv::Sobel(*inputImage, sobelX, CV_32F, 1, 0, 3, 1, 0.0);
        cv::Sobel(*inputImage, sobelY, CV_32F, 0, 1, 3, 1, 0.0);
        cv::cartToPolar(sobelX, sobelY, magnitude, direction, true);
        feature->push_back(magnitude);

        // Deep copy, so the stored feature does not share data with the input.
        cv::Mat pixelCopy;
        inputImage->copyTo(pixelCopy);
        feature->push_back(pixelCopy);
        return;
    }

    std::vector<cv::Mat> planes(3);
    cv::split(*inputImage, planes);

    for(int c = 0; c < 3; ++c)
    {
        // A fresh matrix per channel, so the stored magnitudes never share data.
        cv::Mat magnitude;
        cv::Sobel(planes.at(c), sobelX, CV_32F, 1, 0, 3, 1, 0.0);
        cv::Sobel(planes.at(c), sobelY, CV_32F, 0, 1, 3, 1, 0.0);

        cv::cartToPolar(sobelX, sobelY, magnitude, direction, true);
        feature->push_back(magnitude);

        sobelX.release();
        sobelY.release();
    }

    // The colour planes follow the magnitude planes, in channel order.
    for(int c = 0; c < 3; ++c)
    {
        feature->push_back(planes.at(c));
    }
}

