# This script uses SHAP to extract explanations for the best CAPE model identified.
# The balanced random forest algorithm used during recursive feature elimination (RFE) was used to extract
# explanations (feature importance measures) for the predictors selected by RFE feature selection.
# The SVM algorithm that gave the best CAPE model was used to extract both global (whole-model) and local
# (individual-instance) explanations of the model predictions.
# Python version 3.6.8 was used.

# Imports
import os
from collections import Counter

import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score
from sklearn.metrics import brier_score_loss
import imblearn
from imblearn.ensemble import BalancedRandomForestClassifier
from sklearn.metrics import balanced_accuracy_score, average_precision_score, f1_score
from sklearn.utils import shuffle
from sklearn.metrics import roc_auc_score
from sklearn.svm import SVC
import shap
import matplotlib.pyplot as plt
# Initialise SHAP's JavaScript visualisation support
shap.initjs()

# Set working directory; the relative file paths used further down resolve against it
working_directory = "/../../"
os.chdir(working_directory)

#############################################
### SHAP FOR EVALUATING FEATURE SELECTION ###
#############################################
# Cleaned, unstandardised early-life dataset
# (data found in IOWBC_data.xlsx, sheet: "Earlylife data")
data_2YR = pd.read_csv(
    "/scratch/dk2e18/Asthma_Prediction_Model/Perinatal_2YR_42F_QC_1368IDs_edit.csv",
    index_col=False,
)
# 'Unnamed: 0' is presumably a row index written out by an earlier to_csv - confirm
data_2YR = data_2YR.drop('Unnamed: 0', axis=1)

# Feature selection only used complete cases, i.e. individuals with no missing
# value in any candidate feature
complete_data_2YR = data_2YR.dropna()

# Features lie between the first column and the final (outcome) column
X = complete_data_2YR.iloc[:, 1:-1]
Y = complete_data_2YR.iloc[:, -1]
# Keep only the predictors selected by RFE
X = X[['Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1', 'Total.Bf.duration',
       'Wheeze_2YR', 'Cough_2YR', 'SES']]

# RFE model settings (defaults). 'bootstrap' and 'criterion' are recorded here
# but are not passed to the classifier below.
best_param1 = {
    'bootstrap': True,
    'criterion': 'gini',
    'max_depth': None,
    'max_features': 'sqrt',
    'min_samples_split': 2,
    'n_estimators': 100,
}

# The balanced random forest used during RFE
bclf = BalancedRandomForestClassifier(
    n_estimators=best_param1["n_estimators"],
    max_depth=best_param1["max_depth"],
    min_samples_split=best_param1["min_samples_split"],
    max_features=best_param1["max_features"],
    random_state=123,
)

# Standardise the first four predictors; the remaining four are left unscaled.
# `cont` gets a fresh 0..n-1 index, so `cat` is re-indexed before the column-wise concat.
scaler = StandardScaler()
cont = pd.DataFrame(
    scaler.fit_transform(X.iloc[:, :4]),
    columns=['Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1'],
)
cat = X.iloc[:, 4:]
SX = pd.concat([cont, cat.reset_index(drop=True)], axis=1)

# Human-readable feature names for the plot axes
SX.columns = ['Maternal age', 'Birthweight', 'Age of solid food introduction', 'BMI at 1 year', 'Total breastfeeding duration', 'Early life wheeze', 'Early life cough', 'Maternal socioeconomic status']

# Fit the forest, then compute per-feature SHAP values with the tree explainer
bclf.fit(SX, Y)
importances = bclf.feature_importances_
explainer = shap.TreeExplainer(bclf)
shap_values = explainer.shap_values(SX)

# Generate the summary plot for class 1 and save it as PDF
plt.clf()
shap.summary_plot(shap_values[1], SX, show=False)
plt.savefig("CAPE_shap_value_summary_plot.pdf", bbox_inches='tight')

#########################################
### SHAP TO OBTAIN MODEL EXPLANATIONS ###
#########################################
# Load the training data that developed the best-performing model: complete data,
# oversampled 0%, undersampled. Data found in IOWBC_training_test_data.xlsx,
# sheet: "Standardised earlylife training"
data_0 = pd.read_csv("Earlylife_standardised_training_dataset_510IDs.csv", index_col=False)
print('Original dataset shape %s' % Counter(data_0.Asthma_10YR))
# Original dataset shape Counter({0: 442, 1: 68})

# Undersample the controls (Asthma_10YR == 0) to a 1:1 ratio with the cases
s1 = data_0.loc[data_0['Asthma_10YR'] == 1]
s0 = data_0.loc[data_0['Asthma_10YR'] == 0]
s0 = shuffle(s0, random_state=123)
# Keep exactly as many controls as there are cases. This was previously hard-coded
# as 68, which silently breaks the 1:1 ratio if the training file changes.
s0 = s0.iloc[:len(s1)]
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0. pd.concat
# yields the same rows in the same order (cases first, then controls).
data_0_U = pd.concat([s1, s0], ignore_index=True)
data_0_U = shuffle(data_0_U, random_state=123)
print('Undersampled dataset shape %s' % Counter(data_0_U.Asthma_10YR))
# Undersampled dataset shape Counter({0: 68, 1: 68})

# NOTE(review): assumes column 0 is an ID column and the last column is
# Asthma_10YR (the test file has 'Study_ID' and 'Asthma_10YR') - confirm
X_train = data_0_U.iloc[:, 1:-1]
y_train = data_0_U.iloc[:, -1]
X_train.columns = ['Maternal age', 'Birthweight', 'Age of solid food introduction', 'BMI at 1 year', 'Total breastfeeding duration', 'Early life wheeze', 'Early life cough', 'Maternal socioeconomic status']

# Import test data - data found in IOWBC_training_test_data.xlsx, sheet: "Standardised earlylife test set"
# (this note was previously appended to the read_csv line without '#', which was a SyntaxError)
test = pd.read_csv("Earlylife_standardised_test_dataset_255IDs.csv", index_col=False)

# Split test data into features and outcome
X_test = test.drop(['Study_ID', 'Asthma_10YR'], axis=1)
# Positional rename: assumes the remaining columns are in the same order as the
# training features - confirm against the training file
X_test.columns = ['Maternal age', 'Birthweight', 'Age of solid food introduction', 'BMI at 1 year', 'Total breastfeeding duration', 'Early life wheeze', 'Early life cough', 'Maternal socioeconomic status']
y_test = test['Asthma_10YR']

# Best-performing algorithm: RBF-kernel SVM with the tuned hyperparameters.
# probability=True enables predict_proba, which the SHAP explainer below wraps.
svc_rbf = SVC(kernel='rbf', C=45.1, gamma=0.0054, probability=True, random_state=123)
svc_rbf.fit(X_train, y_train)

# Model-agnostic Kernel SHAP with the (undersampled) training set as background
# data, then explain every prediction in the test set
explainer = shap.KernelExplainer(svc_rbf.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)

### Global model explanations ###
# shap_values is indexed by class below (shap_values[1]). This assumes the shap
# release in use returns one array per predict_proba column, as older releases
# did for multi-output models - confirm if upgrading shap.
# Given that per-class list, summary_plot draws a horizontal bar graph of the
# mean |SHAP value| of each variable, split by class.
plt.clf()
shap.summary_plot(shap_values, X_test, show=False)
plt.savefig("CAPE_model_shap_values_summary_bar_plot.pdf", bbox_inches='tight')

### Local model explanations ###
# Class index 1 is the asthma output (Asthma_10YR == 1).
# expected value = mean predicted probability of the corresponding class, so
# expected_value[1] is the mean asthma probability.
# The base value must come from the SAME class as the SHAP values. Otherwise
# base value + feature attributions does not equal the model's output for that
# class; the previous code paired expected_value[0] with shap_values[1].

# Plot asthma-class SHAP values for all test instances
f = shap.force_plot(explainer.expected_value[1], shap_values[1], X_test, show=False)
shap.save_html("Earlylife_model_shap_values_all_class01.htm", f)

# Plot asthma-class SHAP values for a single test instance (row 48).
# KernelExplainer was built on predict_proba with its default identity link, so
# these SHAP values are already on the probability scale. link="logit" would
# wrongly treat them as log-odds and relabel the axis, so the default
# (identity) link is used, as in the plot above.
f1 = shap.force_plot(explainer.expected_value[1], shap_values[1][48, :], X_test.iloc[48, :])
shap.save_html("Earlylife_model_shap_values_instance_48-11.htm", f1)

