import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from cycler import cycler
from scipy.stats import pearsonr
from sklearn.metrics import roc_curve, roc_auc_score

# Global Matplotlib text styling, applied to every figure created afterwards.
plt.rc(
    'font',
    size=14,         # default size for all text
    weight='bold',   # render all text in bold
    style='normal',  # upright text (not italic)
    family='Arial',  # default family; 'sans-serif', 'monospace', etc. also work
)

# Axes titles/labels plus the default color cycle for plotted series.
# The colors are taken from Origin Pro's built-in palette, in this order:
# gray, red, blue, green,
# purple, ochre, cyan, brown,
# olive, orange, sky-blue
plt.rc(
    'axes',
    titlesize=18,        # title font size
    titleweight='bold',  # title font weight
    labelsize=16,        # x and y label font size
    labelweight='bold',  # x and y label font weight
    prop_cycle=cycler(color=[
        '#515151', '#f14040', '#1a6fdf', '#37ad6b',
        '#b177de', '#cc9900', '#00cbcc', '#7d4e4e',
        '#8e8e00', '#fb6501', '#6699cc',
    ]),
)

def fit_line(x, y, num_points=100):
    """Fit a straight line to ``(x, y)`` and sample it across the range of ``x``.

    Args:
        x: Sequence of x-coordinates (list, tuple, or array of numbers).
        y: Sequence of y-coordinates, same length as ``x``.
        num_points: Number of evenly spaced samples to generate along the
            fitted line. Defaults to 100.

    Returns:
        A tuple ``(x_line, y_line)`` of NumPy arrays of length ``num_points``:
        ``x_line`` runs from ``min(x)`` to ``max(x)`` and ``y_line`` holds the
        least-squares line evaluated at those positions.

    Raises:
        TypeError: Propagated from ``np.polyfit`` when ``x`` is empty or when
            ``x`` and ``y`` differ in length.
    """
    x_values = np.asarray(x, dtype=float)
    y_values = np.asarray(y, dtype=float)

    # Degree-1 least-squares fit; polyfit returns the highest power first.
    slope, intercept = np.polyfit(x_values, y_values, 1)

    x_line = np.linspace(x_values.min(), x_values.max(), num_points)
    y_line = slope * x_line + intercept
    return x_line, y_line

# Per-patient arithmetic mean
def calculate_means(data):
    """Return a dict mapping each key of ``data`` to ``np.mean`` of its values.

    ``data`` maps a patient identifier to a sequence of numeric readings.
    The result preserves the key order of ``data``.
    """
    means = {}
    for patient_id, readings in data.items():
        means[patient_id] = np.mean(readings)
    return means

# Per-patient median
def calculate_medians(data):
    """Return a dict mapping each key of ``data`` to ``np.median`` of its values.

    ``data`` maps a patient identifier to a sequence of numeric readings.
    The result preserves the key order of ``data``.
    """
    medians = {}
    for patient_id, readings in data.items():
        medians[patient_id] = np.median(readings)
    return medians

def _load_patient_readings(csv_path):
    """Load one CSV as ``{column name: [non-missing values]}``.

    The first column of the file is discarded (presumably a row label or
    index -- confirm against the data files). Every remaining column becomes
    one dictionary entry whose NaN/missing cells have been removed.
    """
    columns = pd.read_csv(csv_path).iloc[:, 1:].to_dict(orient='list')
    cleaned = {}
    for column_name, readings in columns.items():
        cleaned[column_name] = [value for value in readings if pd.notna(value)]
    return cleaned


# Sensor data, NaN-free, keyed by CSV column name
lesional_sensor = _load_patient_readings('lesional_sensor.csv')
nonlesional_sensor = _load_patient_readings('nonlesional_sensor.csv')

# Corneometer data, NaN-free, keyed by CSV column name
lesional_corneo = _load_patient_readings('lesional_corneo.csv')
nonlesional_corneo = _load_patient_readings('nonlesional_corneo.csv')

# Lesional summaries: mean of the capacitance data, median of the corneometer data
lesional_means = calculate_means(lesional_sensor)
lesional_medians = calculate_medians(lesional_corneo)

# Non-lesional summaries, computed the same way
nonlesional_means = calculate_means(nonlesional_sensor)
nonlesional_medians = calculate_medians(nonlesional_corneo)

def _paired_values(means, medians, label):
    """Return aligned ``(x, y)`` lists for the patients usable from both dicts.

    Iterates over the patients in ``means`` (in order) and keeps a patient
    only when it also has an entry in ``medians`` and both values are finite.
    A patient missing from ``medians`` would otherwise raise ``KeyError``, and
    a patient with no valid readings yields a NaN summary that would break the
    line fit and the correlation. Skipped patients are reported on stdout;
    ``label`` only names the cohort in those messages.
    """
    xs, ys = [], []
    for patient, mean_value in means.items():
        if patient not in medians:
            print(f"Skipping {label} patient {patient!r}: no corneometer data")
            continue
        median_value = medians[patient]
        if not (np.isfinite(mean_value) and np.isfinite(median_value)):
            print(f"Skipping {label} patient {patient!r}: no valid readings")
            continue
        xs.append(mean_value)
        ys.append(median_value)
    return xs, ys


# Create scatter plot
plt.figure(figsize=(10, 6))

# Lesional series: x = mean sensor value, y = median corneometer value
lesional_x, lesional_y = _paired_values(lesional_means, lesional_medians, 'lesional')
plt.scatter(lesional_x, lesional_y, label='Lesional')

# Non-lesional series
nonlesional_x, nonlesional_y = _paired_values(
    nonlesional_means, nonlesional_medians, 'non-lesional'
)
plt.scatter(nonlesional_x, nonlesional_y, label='Non-Lesional')

# Single linear fit through both series pooled together
combined_x = lesional_x + nonlesional_x
combined_y = lesional_y + nonlesional_y
combined_x_line, combined_y_line = fit_line(combined_x, combined_y)
plt.plot(combined_x_line, combined_y_line, label='Combined Fit', color='green', linestyle='--')

# Plot customization
# Alternative x label (presumably for area-normalized capacitance):
# plt.xlabel('Mean Normalized Capacitance (pF/cm$^2$)')
plt.xlabel('Mean Sensor Capacitance (pF)')
plt.ylabel('Median Corneometer Reading (A.U.)')
plt.title('Capacitance vs Corneometer Data')
plt.legend()
plt.grid(True)
plt.tight_layout()

# Save before show(): the figure is written to disk before the window opens
plt.savefig('SensorCapVsCorneo.pdf', format='pdf')

plt.show()

# Pearson correlation between the x and y values of each series,
# and of both series pooled together
lesional_corr, lesional_p_value = pearsonr(lesional_x, lesional_y)
nonlesional_corr, nonlesional_p_value = pearsonr(nonlesional_x, nonlesional_y)
combined_corr, combined_p_value = pearsonr(combined_x, combined_y)


# p-values are printed with 3 significant digits (not 3 decimals) so that
# very small values appear as e.g. 1.2e-05 instead of a misleading 0.000
print(f"Lesional correlation: {lesional_corr:.3f}, p-value: {lesional_p_value:.3g}")
print(f"Non-lesional correlation: {nonlesional_corr:.3f}, p-value: {nonlesional_p_value:.3g}")
print(f"Combined correlation: {combined_corr:.3f}, p-value: {combined_p_value:.3g}")