/*---------------------------------------------------------------------------*\

    DAFoam  : Discrete Adjoint with OpenFOAM
    Version : v3

\*---------------------------------------------------------------------------*/

#include "DABoxAvgObjFunc.H"
#include "fvCFD.H"
#include "fvMesh.H"
#include "runTimeSelectionTables.H"

// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

using namespace std;
namespace Foam

{
// Define the run-time type name and the debug switch (initial value 0) for DABoxAvgObjFunc
defineTypeNameAndDebug(DABoxAvgObjFunc, 0);
// Register DABoxAvgObjFunc in the "dictionary" run-time selection table of the DAObjFunc base class
addToRunTimeSelectionTable(DAObjFunc, DABoxAvgObjFunc, dictionary);
// * * * * * * * * * * * * * * * * Constructors  * * * * * * * * * * * * * * //

DABoxAvgObjFunc::DABoxAvgObjFunc(
    const fvMesh& mesh,
    const DAOption& daOption,
    const DAModel& daModel,
    const DAIndex& daIndex,
    const DAResidual& daResidual,
    const word objFuncName,
    const word objFuncPart,
    const dictionary& objFuncDict)
    : DAObjFunc(
        mesh,
        daOption,
        daModel,
        daIndex,
        daResidual,
        objFuncName,
        objFuncPart,
        objFuncDict),
      daTurb_(daModel.getDATurbulenceModel())

{
    /*
    Description:
        Read the settings of this objective from objFuncDict_ and set up the
        level-0 state connectivity (objFuncConInfo_).

    Input (entries read from objFuncDict_):
        type: stored in objFuncType_

        data: selects the branch evaluated in calcObjFunc ("beta" or "UData")

        scale: scaling factor stored in scale_

        weightedSum: if true, the entry "weight" is read as well and stored in weight_

    Output:
        objFuncConInfo_: {{stateName_}} if stateName_ is one of daIndex.adjStateNames,
        otherwise a single empty level
    */

    objFuncDict_.readEntry<word>("type", objFuncType_);
    data_ = objFuncDict_.getWord("data");
    scale_ = objFuncDict_.getScalar("scale");
    objFuncDict_.readEntry<bool>("weightedSum", weightedSum_);
    if (weightedSum_)
    {
        objFuncDict_.readEntry<scalar>("weight", weight_);
    }

    // The connectivity check below reads stateName_, but the only other place
    // that assigns it is calcObjFunc, which runs after construction. Select the
    // field name here so that the check sees the state this objective actually
    // depends on. The names match the fields looked up in calcObjFunc.
    // NOTE(review): confirm against the header that stateName_ has no other
    // intended initial value.
    if (data_ == "beta")
    {
        stateName_ = "betaFieldInversion";
    }
    else if (data_ == "UData")
    {
        stateName_ = "U";
    }

    // setup the connectivity, this is needed in Foam::DAJacCondFdW
    // only the zero-level (cell-local) dependence on the state is declared here
    if (DAUtility::isInList<word>(stateName_, daIndex.adjStateNames))
    {
        objFuncConInfo_ = {{stateName_}}; // level 0
    }
    else
    {
        objFuncConInfo_ = {{}}; // level 0
    }
}

/// calculate the value of objective function
void DABoxAvgObjFunc::calcObjFunc(
    const labelList& objFuncFaceSources,
    const labelList& objFuncCellSources,
    scalarList& objFuncFaceValues,
    scalarList& objFuncCellValues,
    scalar& objFuncValue)
{	
    /*
    Description:
        Calculate the stateErrorNorm
        f = scale * L2Norm( state-stateRef )

    Input:
        objFuncFaceSources: List of face source (index) for this objective
    
        objFuncCellSources: List of cell source (index) for this objective

    Output:
        objFuncFaceValues: the discrete value of objective for each face source (index). 
        This  will be used for computing df/dw in the adjoint.
    
        objFuncCellValues: the discrete value of objective on each cell source (index). 
        This will be used for computing df/dw in the adjoint.
    
        objFuncValue: the sum of objective, reduced across all processors and scaled by "scale"
    */
   

    forAll(objFuncCellValues, idxI)
        {
          objFuncCellValues[idxI] = 0.0;
        }
    // initialize objFunValue
    objFuncValue = 0.0; 

    const objectRegistry& db = mesh_.thisDb();

    if (data_ == "beta")
    {
        stateName_ = "betaFieldInversion";
        const volScalarField betaFieldInversion_ = db.lookupObject<volScalarField>(stateName_);
        
        forAll(objFuncCellSources, idxI)
        {
            const label& cellI = objFuncCellSources[idxI];
            objFuncCellValues[idxI] = scale_ * (sqr(betaFieldInversion_[cellI] - 1));
            objFuncValue += objFuncCellValues[idxI];
        }
        // need to reduce the sum of all objectives across all processors
        reduce(objFuncValue, sumOp<scalar>());

        if (weightedSum_ == true)
        {
            objFuncValue = weight_ * objFuncValue;
        }
    }
    else if (data_ == "UData")
    {
      
    
    // user input in DAOptions read here
    scalarList UAv = daOption_.getOption<scalarList>("UAv");
    scalarList VAv = daOption_.getOption<scalarList>("VAv");
    
    const volVectorField state = db.lookupObject<volVectorField>("U");
    
    scalarList ux = state.component(0);
    scalarList uy = state.component(1);
    scalarList uz = state.component(2);
    
    // set the sizes of the box averaged lists to ensure contiguous memory
    List<scalar> ubavg(indices.size());
    List<scalar> vbavg(indices.size());
    List<scalar> wbavg(indices.size());

    // this block does box averaging
    forAll(indices, idx_exp)
    {
        scalar tempVelU = 0.0;
        scalar tempVelV = 0.0;
        scalar tempVelW = 0.0;
        scalar tempVol = 0.0;

        List<label> tempIndExp = indices[idx_exp];
        List<scalar> intersectVols_t = intersectVols[idx_exp];


        forAll(tempIndExp, idx_cfd)
        {
            label tempIndCFD = tempIndExp[idx_cfd];
            label listSize = intersectVols[idx_exp].size();

            if (listSize > 0)
            {
                tempVelU = tempVelU + ux[tempIndCFD]*intersectVols_t[idx_cfd];
                tempVelV = tempVelV + uy[tempIndCFD]*intersectVols_t[idx_cfd];
                tempVelW = tempVelW + uz[tempIndCFD]*intersectVols_t[idx_cfd];
                tempVol = tempVol + intersectVols_t[idx_cfd];    
            }
                
        }
        if (tempVol > 0)
        {
            ubavg[idx_exp] = tempVelU/tempVol;
            vbavg[idx_exp] = tempVelV/tempVol;
            wbavg[idx_exp] = tempVelW/tempVol;
         
        }
    }

 
    // temporary objective function cell values list
    scalarList objFuncValues_t(indices.size()); 

    forAll(objFuncValues_t, idx)
    {
        objFuncValues_t[idx] = 0.0;
    }
    
    // calculate the objective function cell values on the experimental grid
    forAll(objFuncValues_t, index)
    {
        if(UAv[index]!=0)
            objFuncValues_t[index] = 0.5*(sqr(scale_ * ubavg[index] - UAv[index])  + sqr(scale_ * vbavg[index] - VAv[index]) + sqr(scale_ * wbavg[index]));
    }

    List<List<scalar>> error_tmp(mesh_.nCells());
    List<scalar> error_OG(mesh_.nCells());

    // do the broadcasting
    forAll(indices, idx)
    {
        labelList tmpInd = indices[idx];
        scalarList tmpweights = volWeights[idx];
        forAll(tmpInd, idx2)
        {
            label tmpIndtmp = tmpInd[idx2];
            scalar tmpweightstmp = tmpweights[idx2];

            error_tmp[tmpIndtmp].append(tmpweightstmp*objFuncValues_t[idx]);
        }
    }

    forAll(error_tmp, idx)
    {
        scalarList tmp = error_tmp[idx];
        forAll(tmp, idx2)
        {
            error_OG[idx] = error_OG[idx] + tmp[idx2];
        }
    }

    // calculate the objective function value using the broadcasted cell values
    forAll(error_OG, idx)
    {
        objFuncCellValues[idx] = error_OG[idx];
        objFuncValue = objFuncValue + objFuncCellValues[idx];
    }
    reduce(objFuncValue, sumOp<scalar>());

    if (weightedSum_ == true)
        {
            objFuncValue = weight_ * objFuncValue;
        }
    }

    
    else
    {
        FatalErrorIn("") << "dataType: " << data_
                            << " not supported for field inversion! "
                            << "Available options are: UData, pData, surfacePressureData, surfaceFrictionData, aeroCoeffData, and surfaceFrictionDataPeriodicHill."
                            << abort(FatalError);
    }

    return;
}
// * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * //

} // End namespace Foam

// ************************************************************************* //
