import onnxruntime as ort
import numpy as np

import time
from PIL import Image
import os
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde

class ControlAlgorithmONNX:
    """Thin wrapper around an ONNX Runtime session.

    Only the model's first input and first output are used.
    """

    def __init__(self, onnx_model_path: str, providers=None):
        """Load the ONNX model stored at *onnx_model_path*.

        Args:
            onnx_model_path: Path to the ``.onnx`` file.
            providers: Optional list of ONNX Runtime execution providers,
                e.g. ``['CPUExecutionProvider']``. ``None`` keeps ONNX
                Runtime's default behaviour. Builds with several providers
                available (ORT >= 1.9) require this to be set explicitly.
        """
        self.model_path = onnx_model_path
        # Only forward `providers` when given, so older ORT versions see the
        # exact same call as before.
        session_kwargs = {} if providers is None else {'providers': providers}
        self.session = ort.InferenceSession(self.model_path, **session_kwargs)
        # Resolve the tensor names once; every run() call needs them.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def get_next_output(self, image) -> np.ndarray:
        """Run one forward pass and return the raw model output.

        Args:
            image: Model input tensor, e.g. the array returned by
                ``prepare_image_GRAYSCALE``.

        Returns:
            The first batch element of the model's first output. Phase values
            in it are NOT wrapped to a principal range.
        """
        output = self.session.run([self.output_name], {self.input_name: image})[0]
        return output[0]

    def prepare_image_GRAYSCALE(self, img):
        """Convert an image into a max-normalised float32 NCHW tensor.

        The image is converted to grayscale (mode ``'L'``). It is then scaled
        so its brightest pixel equals 1.0 and given two leading singleton
        axes, giving shape ``(1, 1, height, width)``.

        Args:
            img: A PIL ``Image``, or any object whose ``convert('L')`` result
                NumPy can turn into a 2-D array.

        Returns:
            ``np.ndarray`` of dtype float32. An all-black image is returned as
            zeros rather than NaNs.
        """
        pixels = np.asarray(img.convert('L'), dtype=np.float32)
        peak = pixels.max()
        if peak > 0:
            # Per-image max normalisation. It is skipped for an all-black
            # frame, where it would compute 0/0 = NaN and poison the output.
            pixels = pixels / peak
        return pixels[np.newaxis, np.newaxis, ...]

def normalise_angle(angle):
    """Wrap *angle* (radians; scalar or array) into the range [-pi, pi].

    The result is the direction of the unit vector (cos, sin), so the
    operation works element-wise on NumPy arrays.
    """
    sine = np.sin(angle)
    cosine = np.cos(angle)
    return np.arctan2(sine, cosine)

# Example usage: the dataset and the exported model share one root folder.
base_dir = 'C:/Users/ONNX'
control_algo_onnx = ControlAlgorithmONNX(f'{base_dir}/mobilnet.onnx')

# Per-image inference durations in seconds, appended by visualise__predictions.
times = []

def visualise__predictions(base_dir, model=None, timings=None):
    """Run the model over every pattern image and collect labels/predictions.

    ``<base_dir>/ampl lower limit 0.75`` must contain three folders:

    * ``pattern``: the input images.
    * ``phase`` and ``power``: ``.npy`` label files named after each image's
      file stem.

    Args:
        base_dir: Dataset root directory.
        model: Object providing ``prepare_image_GRAYSCALE`` and
            ``get_next_output``. Defaults to the module-level
            ``control_algo_onnx``.
        timings: List that receives one inference duration (seconds) per
            image. Defaults to the module-level ``times`` list.

    Returns:
        Tuple ``(label_phase, label_power, output_phase, output_power)``. Each
        element is a list with one entry per image, in sorted filename order.
        Predicted phase is ``output[7:]`` and predicted power is
        ``output[:7]``; both are wrapped with ``normalise_angle``.
    """
    if model is None:
        model = control_algo_onnx
    if timings is None:
        timings = times

    label_phase = []
    label_power = []
    output_phase = []
    output_power = []

    data_dir = os.path.join(base_dir, 'ampl lower limit 0.75')
    pattern_dir = os.path.join(data_dir, 'pattern')

    # os.listdir order is arbitrary; sort so runs are reproducible.
    for filename in sorted(os.listdir(pattern_dir)):
        npy_filename = os.path.splitext(filename)[0] + '.npy'
        label_phase.append(np.load(os.path.join(data_dir, 'phase', npy_filename)))
        label_power.append(np.load(os.path.join(data_dir, 'power', npy_filename)))

        # PIL opens files lazily and keeps the handle; the context manager
        # guarantees it is closed once the pixels have been read.
        with Image.open(os.path.join(pattern_dir, filename)) as upload_image:
            image = model.prepare_image_GRAYSCALE(upload_image)

        # perf_counter is monotonic and high-resolution. time.time() can be
        # too coarse (notably on Windows) to time a single inference.
        start_time = time.perf_counter()
        output = model.get_next_output(image)
        timings.append(time.perf_counter() - start_time)

        output_phase.append(normalise_angle(np.array(output[7:])))
        # NOTE(review): power is wrapped like an angle. That leaves values in
        # (-pi, pi) unchanged but wraps anything outside; confirm this is
        # intended for the power head.
        output_power.append(normalise_angle(np.array(output[:7])))

    return label_phase, label_power, output_phase, output_power


# Run inference over the whole dataset; each result is a list holding one
# entry per image.
(label_phase, label_power,
 output_phase, output_power) = visualise__predictions(base_dir=base_dir)




def plot_with_kde(x, y, xlabel, ylabel, title, label):
    """Scatter-plot *y* against *x*, colouring each point by local density.

    Density is a 2-D Gaussian KDE evaluated at the data points themselves.
    Points are drawn from lowest to highest density so dense regions are not
    hidden under sparse ones.

    Args:
        x: 1-D sequence of ground-truth values.
        y: 1-D sequence of predicted values, same length as *x*.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        title: Figure title.
        label: Legend label for the scatter series.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    xy = np.vstack([x, y])
    try:
        z = gaussian_kde(xy)(xy)
    except (np.linalg.LinAlgError, ValueError) as err:
        # gaussian_kde fails on degenerate input: too few points, or points
        # lying on a line (e.g. x == y everywhere). Plot with a uniform
        # colour instead of aborting, but make the fallback visible.
        import warnings
        warnings.warn(f'KDE failed ({err}); using uniform point colour.')
        z = np.ones(xy.shape[1])

    # Sort by density so the densest points are drawn last (on top).
    idx = z.argsort()
    x, y, z = x[idx], y[idx], z[idx]

    plt.figure(figsize=(10, 5))
    scatter = plt.scatter(x, y, c=z, s=10, cmap='viridis', label=label)
    plt.colorbar(scatter, label='Density')

    # Add labels, title, and grid
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

# Merge the per-image arrays into flat 1-D vectors. The subtraction below
# assumes labels and outputs line up element-for-element.
label_phase = np.concatenate(label_phase).flatten()
output_phase = np.concatenate(output_phase).flatten()
label_power = np.concatenate(label_power).flatten()
output_power = np.concatenate(output_power).flatten()

# Ground truth vs. prediction scatter plots, coloured by point density.
plot_with_kde(
    x=label_phase,
    y=output_phase,
    xlabel='Phase Ground Truth',
    ylabel='Phase Prediction',
    title='NN Phase Prediction with Density',
    label='Phase Mean Error',
)
plot_with_kde(
    x=label_power,
    y=output_power,
    xlabel='Power Ground Truth',
    ylabel='Power Prediction',
    title='NN Power Prediction with Density',
    label='Power Mean Error',
)

# Summary statistics. Differences are wrapped into [-pi, pi] before taking
# the absolute value, so an error of 2*pi - e counts as e.
# NOTE(review): the power error is wrapped the same way; confirm intended.
print('Mean Prediction Time', f'{np.mean(times):.5f}', '(s)')
print('Mean Phase Prediction Error:',
      np.mean(np.abs(normalise_angle(label_phase - output_phase))),
      ' radians')
print('Mean Power prediction Error:',
      np.mean(np.abs(normalise_angle(label_power - output_power))),
      'normalised')
