# -*- coding: utf-8 -*-
"""
Created on Tue Feb 21 12:31:13 2023

@author: liuzh
"""

import numpy as np
import cv2
import random
import matplotlib.pylab as plt
from scipy.ndimage import gaussian_filter1d
from skimage.metrics import peak_signal_noise_ratio
from scipy.signal import medfilt





sample = 30    # default number of simulated clean signals
frm = 5        # default number of noise frames (rows of the noise matrix)
length = 1000  # default number of points per signal


def sigSim(std_dev, sample=sample, frm=frm, length=length):
    """Simulate clean multi-peak Lorentzian signals and one noise matrix.

    Each clean signal is a sum of 7-15 Lorentzian peaks with random centres
    in ``[0, length]``, areas in ``[0.2, 1]`` and half-widths in ``[1, 7]``,
    normalised so that its maximum is 1.  Randomness comes from the global
    ``random`` module (peaks) and the global ``np.random`` state (noise).

    Parameters
    ----------
    std_dev : float
        Standard deviation of the zero-mean Gaussian noise.
    sample : int
        Number of clean signals to generate.
    frm : int
        Number of noise rows (frames).
    length : int
        Number of points per signal.

    Returns
    -------
    Y : ndarray, shape (sample, length)
        Clean signals, one per row.
    noise : ndarray, shape (frm, length)
        A single noise realisation (not drawn per signal).
    """
    print(sample, frm)

    def lorentzian(x, x0, gamma, h):
        """Lorentzian with area ``h``, centre ``x0`` and half-width ``gamma``."""
        return (h * gamma / np.pi) / ((x - x0) ** 2 + gamma ** 2)

    x = np.linspace(0, length, length)
    mean = 0
    Y = []
    for _ in range(sample):
        num_peaks = random.randint(7, 15)
        # Draw order (positions, heights, widths) is kept so seeded runs match.
        positions = np.sort(np.array([random.uniform(0, length) for _ in range(num_peaks)]))
        heights = np.array([random.uniform(0.2, 1) for _ in range(num_peaks)])
        half_widths = np.array([random.uniform(1, 7) for _ in range(num_peaks)])

        y = np.zeros_like(x)
        for x0, gamma, h in zip(positions, half_widths, heights):
            # The height enters once, through lorentzian's ``h`` argument;
            # multiplying by it again would square every peak height.
            y += lorentzian(x, x0, gamma, h)
        Y.append(y / y.max())

    # reshape keeps the documented 2-D shape even when sample == 0
    Y = np.array(Y, dtype=float).reshape(sample, length)
    return Y, np.random.normal(mean, std_dev, (frm, length))

# %%

def validation(originalSignal, processedSignal):
    """Score a processed signal against its clean reference.

    Parameters
    ----------
    originalSignal : ndarray
        Clean reference signal.
    processedSignal : ndarray
        Signal to score; same shape as ``originalSignal``.

    Returns
    -------
    tuple of (snr, psnr, sse)
        snr  -- ``20*log10(||original|| / ||processed - original||)`` in dB.
        psnr -- skimage ``peak_signal_noise_ratio(original, processed)``.
        sse  -- squared L2 norm of the residual (sum of squared errors).
    """
    residual = processedSignal - originalSignal
    residual_norm = np.linalg.norm(residual)
    signal_norm = np.linalg.norm(originalSignal)

    snr = 20 * np.log10(signal_norm / residual_norm)
    psnr = peak_signal_noise_ratio(originalSignal, processedSignal)
    return snr, psnr, residual_norm ** 2


def nlm_dogKer(rmMatPad,
               h_tmp=-1,
               w_tmp=9,
               w_tmp_L=5,
               h_serch=-1,
               sigma1=1,
               # sigma2 = 45,
               **karg):
    """Non-local-means style denoising of a (signal x frame) matrix.

    Each shifted copy of the input is compared with the input itself. Copies
    are shifted by ``i`` rows (``-w_serch <= i <= w_serch``) and cyclically
    by every possible number of columns. For each copy:

    - the difference is smoothed along axis 0 with a Gaussian of width
      ``sigma1``;
    - it is mapped to the weight ``exp(-(d / smoother)**2)``;
    - the weight is smoothed across columns (Gaussian, sigma 5);
    - the weight is used to accumulate a weighted average of the shifted
      copies.

    Rows are extended by symmetric (edge-repeating) reflection before
    shifting.

    Parameters
    ----------
    rmMatPad : array_like, shape (n_rows, n_frames)
        Input; axis 0 is the signal axis, axis 1 the frames.  A 1-D input is
        treated as a single frame; non-float input is converted to float.
    h_tmp, w_tmp_L, h_serch :
        Unused; kept only for backward compatibility.
    w_tmp : int, default 9
        Extra margin kept around the central rows.  Differences and weights
        are evaluated on the central ``n_rows`` rows only.
    sigma1 : float, default 1
        Gaussian width used to smooth the differences along axis 0.
    **karg
        ``w_serch`` (int, default 30): maximum row shift.
        ``smoother`` (float, default 0.01): weight decay scale.

    Returns
    -------
    ndarray, shape (n_rows, n_frames)
        Weighted average of the shifted copies.
    """
    w_serch = karg.get('w_serch', 30)
    smoother = karg.get('smoother', 0.01)
    print(karg, smoother, w_serch)

    data = np.asarray(rmMatPad)
    if data.ndim == 1:
        data = data[:, np.newaxis]
    if not np.issubdtype(data.dtype, np.floating):
        # in-place accumulation of float weights needs a float buffer
        data = data.astype(float)
    n_rows, n_frames = data.shape

    pad = w_serch + w_tmp
    # numpy 'symmetric' == OpenCV BORDER_REFLECT (fedcba|abcdefgh|hgfedcba)
    padded = np.pad(data, ((pad, pad), (0, 0)), mode='symmetric')

    # Explicit end indices: negative slicing (``[a:-b]``) collapses to an
    # empty array when b == 0, i.e. when w_tmp == 0 or w_serch == 0.
    lo, hi = pad - w_tmp, pad + n_rows + w_tmp
    window = padded[lo:hi]
    core = slice(w_tmp, w_tmp + n_rows)
    sweight = np.zeros_like(window[core])
    average = np.zeros_like(window[core])

    # The frame count comes from the input, not from a module global.
    # n_frames consecutive offsets hit every cyclic column shift exactly once.
    frame_shifts = np.arange(n_frames) - 1 - int(n_frames / 2)
    for j in frame_shifts:
        for i in np.arange(-w_serch, w_serch + 1):
            shifted = np.roll(padded, (i, j), (0, 1))[lo:hi]
            diff = (window - shifted)[core]
            gf_diff = gaussian_filter1d(diff, sigma1, axis=0, order=0, mode='reflect')
            w = np.exp(-(gf_diff ** 2 / smoother ** 2))
            # share weights between neighbouring frames
            w = gaussian_filter1d(w, 5, axis=1, order=0, mode='reflect')
            sweight += w
            average += w * shifted[core]
    return average / sweight






def proc(std_dev, sample=sample, frm=frm, **karg):
    """Simulate noisy signals and score several denoisers on them.

    For every clean signal ``y`` from ``sigSim``:

    - the shared noise matrix is added, giving ``frm`` noisy frames;
    - the frames are filtered by:
        - Gaussian filters (sigma 0.1, 0.5, 1, 2) along the signal axis;
        - median filters (kernel 3, 5, 7, 9) along the signal axis;
        - ``nlm_dogKer`` (``**karg`` is forwarded to it);
    - each result is averaged over frames and scored with ``validation``.

    Parameters
    ----------
    std_dev : float
        Noise standard deviation passed to ``sigSim``.
    sample, frm : int
        Number of signals and of noise frames.  Note that these defaults are
        bound when the function is defined; pass them explicitly to change
        them.

    Returns
    -------
    tuple
        ``(gfQ, mdQ, dnlm_yQ, nQ, Y, Yn, gauss, median, nlm)``:

        - ``gfQ`` / ``mdQ``: arrays with 4 rows per signal (one per
          sigma / kernel), each holding ``validation``'s (snr, psnr, sse);
        - ``dnlm_yQ``: scores of the NLM result, one row per signal;
        - ``nQ``: scores of the unfiltered frame mean, one row per signal;
        - ``Y``: the clean signals;
        - ``Yn``: list of noisy (length, frm) matrices;
        - ``gauss`` / ``median`` / ``nlm``: the sigma=1 Gaussian, kernel-3
          median and NLM outputs of the *last* signal.

    Raises
    ------
    ValueError
        If ``sample`` is 0 (there would be no last-signal outputs).
    """
    Y, Noise = sigSim(std_dev, sample=sample, frm=frm)
    if len(Y) == 0:
        raise ValueError('proc needs at least one simulated sample')

    gf_sigmas = (0.1, 0.5, 1, 2)
    md_kernels = (3, 5, 7, 9)

    gfQ, mdQ, dnlm_yQ, nQ, Yn = [], [], [], [], []
    for y in Y:
        # NOTE(review): the same noise realisation is added to every signal.
        frames = (y + Noise).T  # (length, frm)
        nQ.append(validation(y, frames.mean(1)))

        dnlm_y = nlm_dogKer(frames, **karg)

        gfDn = [gaussian_filter1d(frames, s, axis=0, order=0, mode='reflect').mean(1)
                for s in gf_sigmas]
        gfQ.extend(validation(y, d) for d in gfDn)

        # Kernel (k, 1): filter each frame along the signal axis only.  A
        # scalar kernel would also filter across frames with zero padding,
        # which with a single frame makes the median identically zero.
        mdDn = [medfilt(frames, (k, 1)).mean(1) for k in md_kernels]
        mdQ.extend(validation(y, d) for d in mdDn)

        Yn.append(frames)
        dnlm_yQ.append(validation(y, dnlm_y.mean(1)))

    gfQ = np.array(gfQ)
    mdQ = np.array(mdQ)
    dnlm_yQ = np.array(dnlm_yQ)
    nQ = np.array(nQ)
    print('gf' + str(gfQ.reshape(-1, 4, 3).mean(0).max(0)))
    print('df' + str(mdQ.reshape(-1, 4, 3).mean(0).max(0)))
    print('nlm' + str(dnlm_yQ.mean(0)))
    return gfQ, mdQ, dnlm_yQ, nQ, Y, Yn, gfDn[2], mdDn[0], dnlm_y.mean(1)
   


def sampleProc():
    """Plot and save one single-frame denoising example (noise std 0.05).

    The degraded frames, the sigma=1 Gaussian, kernel-3 median and NLM
    results, and the clean signal of one simulated sample are drawn on
    figure ``'sep'`` with a vertical offset of 0.5 between curves.  Each is
    also written to ``./valid/sig<Name>.txt``; the directory is created if
    needed.
    """
    import os

    # ``proc``'s ``frm`` default is bound at definition time, so rebinding the
    # module-level ``frm`` does not reach it; pass the frame count explicitly.
    _, _, _, _, Origin, Degraded, Guassian, Median, NLM = proc(0.05, frm=1)

    # proc returns the filtered outputs of its *last* sample, so the clean
    # and degraded references are taken from that same sample.
    signals = {
        'Degraded': Degraded[-1],
        'Guassian': Guassian,
        'Median': Median,
        'NLM': NLM,
        'Origin': Origin[-1],
    }

    out_dir = os.path.join('.', 'valid')
    os.makedirs(out_dir, exist_ok=True)

    plt.close('all')
    plt.figure('sep')
    spc = 0.5
    for offset, (name, data) in enumerate(signals.items()):
        plt.plot(data + offset * spc, label=name)
        np.savetxt(os.path.join(out_dir, 'sig' + name + '.txt'), data)
    plt.legend()





def saveData(std_dev):
    """Run ``proc`` at one noise level and write the quality tables.

    Files are written to ``./valid/`` (created if needed), with ``std_dev``
    appended to each file stem:

    - ``gfQ_std_dev_``: Gaussian-filter scores;
    - ``mdQ_std_dev_``: median-filter scores;
    - ``nlmQ_std_dev_``: NLM scores;
    - ``nQ_std_dev``: scores of the unfiltered frame mean;
    - ``all_stddv``: the per-method averages stacked together.

    Each header includes the sample-averaged scores.
    """
    import os

    gfQ, mdQ, dnlm_yQ, nQ, _, _, _, _, _ = proc(std_dev)
    gf_avg = gfQ.reshape(-1, 4, 3).mean(0)
    md_avg = mdQ.reshape(-1, 4, 3).mean(0)
    nlm_avg = dnlm_yQ.mean(0)
    noisy_avg = nQ.mean(0)

    # os.path.join instead of hard-coded '\' so the files land in ./valid on
    # every platform (a backslash is an ordinary filename character on POSIX).
    out_dir = os.path.join('.', 'valid')
    os.makedirs(out_dir, exist_ok=True)

    def _path(stem):
        """Full output path for ``<stem><std_dev>.txt``."""
        return os.path.join(out_dir, stem + str(std_dev) + '.txt')

    np.savetxt(_path('gfQ_std_dev_'), gfQ,
               header='sigma=0.1,0.5,1,2 with avg\n\n' + str(gf_avg) + '\n\n\n')
    np.savetxt(_path('mdQ_std_dev_'), mdQ,
               header='kerlen=3 5 7 9 with avg\n\n' + str(md_avg) + '\n\n\n')
    np.savetxt(_path('nlmQ_std_dev_'), dnlm_yQ,
               header='nlm avg\n\n' + str(nlm_avg) + '\n\n\n')
    # nQ holds the unfiltered scores, not median kernels (old header was a copy-paste).
    np.savetxt(_path('nQ_std_dev'), nQ,
               header='noisy (unfiltered) avg\n\n' + str(noisy_avg) + '\n\n\n')

    alQ = np.vstack([gf_avg, md_avg, nlm_avg, noisy_avg])
    np.savetxt(_path('all_stddv'), alQ,
               header='sigma=0.1,0.5,1,2 with avg\n\n'
                      + 'kerlen=3 5 7 9 with avg\n\n'
                      + 'nlm avg\n\n'
                      + 'nQ')
def test(case):
    """Run one of the benchmark experiments.

    Parameters
    ----------
    case : int
        1 -- write quality tables via ``saveData`` for noise std 0.01, 0.05,
             0.1 and 0.2; returns None.
        2 -- sweep the number of frames 1..9 (noise std 0.05, 30 samples);
             returns ``(rSNR, rPSNR, rMSE, std)`` with one entry per frame
             count.
        3 -- sweep the NLM ``smoother`` (20 values from 0.005 to 0.05)
             against ``w_serch`` 3..29; returns ``MSE`` as a nested list
             indexed ``[smoother][w_serch]``.

    Each rSNR / rPSNR / rMSE / MSE entry is ``[gaussian(sigma=1), NLM,
    unfiltered]`` averaged over samples.  The "MSE" value is the third
    metric returned by ``validation`` (squared residual norm).

    Raises
    ------
    ValueError
        If ``case`` is not 1, 2 or 3.
    """
    global frm

    def _summary(gfQ, dnlm_yQ, nQ):
        """Print and return per-metric averages and their spreads."""
        gf_grid = gfQ.reshape(-1, 4, 3)
        gf_avg = gf_grid.mean(0)[2]  # row 2 <-> Gaussian sigma = 1
        nlm_avg = dnlm_yQ.mean(0)
        noisy_avg = nQ.mean(0)
        print(gf_avg, nlm_avg, noisy_avg)
        per_metric = [[gf_avg[m], nlm_avg[m], noisy_avg[m]] for m in range(3)]
        spread = [gf_grid.std(0), dnlm_yQ.std(0), nQ.std(0)]
        return per_metric, spread

    if case == 1:
        for std_dev in [0.01, 0.05, 0.1, 0.2]:
            saveData(std_dev)
        return None

    if case == 2:
        rSNR, rPSNR, rMSE, std = [], [], [], []
        saved_frm = frm
        try:
            for n_frames in np.arange(1, 10).astype(int):
                # Keep the module-level ``frm`` in step with the sweep so any
                # code reading it sees the same frame count that proc uses.
                frm = int(n_frames)
                gfQ, _, dnlm_yQ, nQ, *_ = proc(0.05, 30, n_frames)
                (snr, psnr, mse), spread = _summary(gfQ, dnlm_yQ, nQ)
                rSNR.append(snr)
                rPSNR.append(psnr)
                rMSE.append(mse)
                std.append(spread)
        finally:
            frm = saved_frm
        return rSNR, rPSNR, rMSE, std

    if case == 3:
        MSE = []
        for smoother in np.linspace(0.1, 1, 20) * 0.05:
            rMSE = []
            for w_serch in np.arange(3, 30):
                gfQ, _, dnlm_yQ, nQ, *_ = proc(0.05, 30, w_serch=w_serch,
                                               smoother=smoother)
                (_, _, mse), _ = _summary(gfQ, dnlm_yQ, nQ)
                rMSE.append(mse)
            MSE.append(rMSE)
        # previously computed and then discarded
        return MSE

    raise ValueError('unknown test case: %r (expected 1, 2 or 3)' % (case,))
from matplotlib import ticker

if __name__ == '__main__':
    # Script entry point.  Guarded so that importing this module (to reuse
    # sigSim / nlm_dogKer / proc) does not read files or open figures.
    #
    # Alternative runs:
    #   rSNR, rPSNR, rMSE, std = test(2)
    #   sampleProc()
    #   np.savetxt(r".\valid\1_10frm_MSE.txt", rMSE,
    #              header='g_f sigma=0.1,0.5,1,2 with avg\n nlm\n')

    # Contour plot of a saved MSE sweep.  The axes mirror test(3): x is the
    # search half-width w_serch (3..29), y is the smoother factor k
    # (smoother = 0.05 * k).  Presumably sweep2.txt is that 20 x 27 grid --
    # confirm against how the file was produced.
    MSE = np.loadtxt('./valid/sweep2.txt')
    plt.contourf(np.arange(3, 30), np.linspace(0.1, 1, 20), MSE * 1e-3, 50)
    plt.ylabel(r'$k$')
    plt.xlabel(r'$w$  (pixels)')

    fmt = ticker.ScalarFormatter(useMathText=True)
    fmt.set_powerlimits((0, 0))  # always use scientific offset notation

    cb = plt.colorbar(format=fmt)
    cb.update_ticks()
    # NOTE(review): the formatter already draws its own x10^n offset text;
    # confirm this hand-placed exponent label is still wanted.
    cb.ax.text(-0.25, 1, r'$\times$10$^{-1}$', va='bottom', ha='left')
    cb.set_label('MSE')
