# This script aims to extract explanations using SHAP for the best CAPP model identified
# The balanced random forest algorithm used during RFE was used to extract explanations / feature importance measures for the predictors selected from RFE feature selection
# The SVM algorithm that gave the best CAPP model was used to extract both global (whole model) and local (individual instances) explanations of the model predictions. 
# Python version 3.6.8 was used 

# Imports
# Standard library
import os
from collections import Counter

# Third-party: data handling
import pandas as pd
import numpy as np

# Third-party: scikit-learn
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
from sklearn.metrics import roc_auc_score
from sklearn.metrics import brier_score_loss
from sklearn.metrics import balanced_accuracy_score, average_precision_score, f1_score
from sklearn.metrics import roc_auc_score
from sklearn.svm import SVC
from sklearn.utils import shuffle

# Third-party: imbalanced-learn, SHAP, plotting
import imblearn
from imblearn.ensemble import BalancedRandomForestClassifier
import shap
import matplotlib.pyplot as plt
# Initialise SHAP's JavaScript visualisation library (required for interactive
# force plots when running in a notebook environment).
shap.initjs()

# Set working directory
# NOTE(review): "/../../" resolves to the filesystem root — this looks like a
# redacted/placeholder path; replace with the real project directory before running.
os.chdir("/../../")

#############################################
### SHAP FOR EVALUATING FEATURE SELECTION ###
#############################################
# Import the cleaned, unstandardised preschool dataset
# (data found in IOWBC_data.xlsx, sheet: "Preschool data").
data_4YR = pd.read_csv("Preschool_QC_1368IDs.csv", index_col=False)
# Drop the stray index column written out by the previous pipeline step.
del data_4YR['Unnamed: 0']
# Expected dimensions: 1368 IDs x 59 columns.

# Restrict to complete cases: only individuals with data for every candidate
# feature were included in the feature selection.
complete_data_4YR = data_4YR.dropna()

# Separate predictors (everything between the ID column and the outcome)
# from the outcome (last column).
X = complete_data_4YR.iloc[:, 1:-1]
Y = complete_data_4YR.iloc[:, -1]

# Keep only the candidate features carried forward into RFE.
X = X[[
    'Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1', 'SDS_BMI_4',
    'Total.Bf.duration', 'Wheeze_4YR', 'Cough_4YR', 'Noct_Symp_4YR',
    'Atopy_4YR', 'Polysensitisation_4YR', 'SES',
]]
	   
# Hyperparameters for the RFE model — default settings for the
# balanced random forest.
best_param1 = {'bootstrap': True, 'criterion': 'gini', 'max_depth': None,
               'max_features': 'sqrt', 'min_samples_split': 2, 'n_estimators': 100}

# Define the RFE model.
# FIX: the original constructor silently dropped 'bootstrap' and 'criterion'
# from best_param1; both happened to equal the library defaults, so forwarding
# them now leaves the fitted model unchanged while keeping the dict and the
# constructor consistent.
bclf = BalancedRandomForestClassifier(
    n_estimators=best_param1["n_estimators"],
    criterion=best_param1["criterion"],
    max_depth=best_param1["max_depth"],
    min_samples_split=best_param1["min_samples_split"],
    max_features=best_param1["max_features"],
    bootstrap=best_param1["bootstrap"],
    random_state=123,
)


# Standardise the first five (continuous) predictors; the remaining columns
# are passed through unscaled.
# NOTE(review): 'Total.Bf.duration' (column 5) is grouped with the unscaled
# columns — confirm this is intentional.
continuous_cols = ['Mat_age', 'Birthweight', 'Solid_food', 'SDS_BMI_1', 'SDS_BMI_4']
scaler = StandardScaler()
cont = pd.DataFrame(scaler.fit_transform(X.iloc[:, :5]), columns=continuous_cols)
cat = X.iloc[:, 5:]
# cont carries a fresh RangeIndex, so reset cat's index before concatenating.
SX = pd.concat([cont, cat.reset_index(drop=True)], axis=1)

# Extract SHAP importance values for each feature
# Relabel predictors with human-readable names so the plot is self-explanatory.
SX.columns = ['Maternal age', 'Birthweight', 'Age of solid food introduction', 'BMI at 1 year', 'BMI at 4 years', 'Total breastfeeding duration', 'Preschool wheeze', 'Preschool cough', 'Preschool nocturnal symptoms', 'Preschool atopy', 'Preschool polysensitisation', 'Maternal socioeconomic status']
# Fit the balanced random forest (defined above) on the standardised data.
bclf.fit(SX,Y)
# Impurity-based importances from the fitted forest (computed for reference;
# not used elsewhere in this script).
importances = bclf.feature_importances_
# TreeExplainer computes SHAP values efficiently for tree ensembles; for a
# binary classifier shap_values holds one array per class.
explainer = shap.TreeExplainer(bclf)
shap_values = explainer.shap_values(SX)

# Generate beeswarm summary plot for class index 1 (the positive class).
plt.clf()
shap.summary_plot(shap_values[1], SX, show=False)
plt.savefig("CAPP_shap_value_summary_plot.pdf",bbox_inches='tight')


#########################################
### SHAP TO OBTAIN MODEL EXPLANATIONS ###
#########################################
# Load the training data that developed the best performing model:
# complete data, oversampled 300%, then undersampled below.
data_300_O = pd.read_csv("Oversampled_preschool_dataset_300%.csv", index_col=False)
data_300_O = data_300_O.iloc[0:518, :]
print('Original dataset shape %s' % Counter(data_300_O.Asthma_10YR))
# Original dataset shape Counter({0: 314, 1: 204})

# Undersample the controls (class 0) down to the number of cases (class 1).
s1 = data_300_O.loc[data_300_O['Asthma_10YR'] == 1]
s0 = data_300_O.loc[data_300_O['Asthma_10YR'] == 0]
s0 = shuffle(s0, random_state=123)
s0 = s0.iloc[:204, ]
# FIX: pd.concat replaces DataFrame.append, which is deprecated and removed
# in pandas 2.0; the result is identical (cases followed by sampled controls,
# reindexed), and the redundant pd.DataFrame(data=s0) wrapper is dropped.
data_300_OU = pd.concat([s1, s0], ignore_index=True)
data_300_OU = shuffle(data_300_OU, random_state=123)
print('Original dataset shape %s' % Counter(data_300_OU.Asthma_10YR))
# Original dataset shape Counter({0: 204, 1: 204})

# Split into features (between ID and outcome) and outcome (last column),
# and apply the human-readable feature labels.
X_train = data_300_OU.iloc[:, 1:-1]
y_train = data_300_OU.iloc[:, -1]
X_train.columns = ['Maternal age', 'Birthweight', 'Age of solid food introduction', 'BMI at 1 year', 'BMI at 4 years', 'Total breastfeeding duration', 'Preschool wheeze', 'Preschool cough', 'Preschool nocturnal symptoms', 'Preschool atopy', 'Preschool polysensitisation', 'Maternal socioeconomic status']

# Import the standardised preschool test data
# (data found in IOWBC_training_test_data.xlsx, sheet: "Standardised preschool test set").
test = pd.read_csv("Preschool_standardised_test_dataset_183IDs.csv", index_col=False)

# Split the test data into features and outcome: keep the outcome column,
# drop the study ID, and relabel features with the same human-readable names
# used for the training set.
y_test = test['Asthma_10YR']
X_test = test.drop(columns=['Study_ID', 'Asthma_10YR'])
X_test.columns = ['Maternal age', 'Birthweight', 'Age of solid food introduction', 'BMI at 1 year', 'BMI at 4 years', 'Total breastfeeding duration', 'Preschool wheeze', 'Preschool cough', 'Preschool nocturnal symptoms', 'Preschool atopy', 'Preschool polysensitisation', 'Maternal socioeconomic status']

# The algorithm and hyperparameters that gave the best performance: a linear
# SVM with C=0.33. probability=True enables predict_proba, which the kernel
# explainer below requires.
svc_linear = SVC(kernel='linear', C=0.33, probability=True, random_state=123)

# Fit the optimised model on the over/undersampled training set.
svc_linear.fit(X_train, y_train)

# Kernel SHAP: model-agnostic explanation of predict_proba, with the training
# set as the background distribution; explain every prediction in the test set.
explainer = shap.KernelExplainer(svc_linear.predict_proba, X_train)
shap_values = explainer.shap_values(X_test)

### Global model explanations ###
# Summary horizontal bar graph: mean |SHAP| value per feature for each class.
plt.clf()
shap.summary_plot(shap_values, X_test, show=False)
plt.savefig("CAPP_model_shap_values_summary_bar_plot.pdf", bbox_inches='tight')

### Local model explanations ###
# A force plot pairs a base value (explainer.expected_value[c], the mean
# predicted probability of class c over the background data) with the
# per-feature SHAP attributions for that same class (shap_values[c]).
#
# Plot SHAP values for all test instances for the asthma class (c = 1).
# FIX: the original passed expected_value[0] (non-asthma base value) together
# with shap_values[1] (asthma attributions); the class indices must match —
# the comment above ("mean value for asthmatics") and all three
# single-instance plots below already use class 1.
f = shap.force_plot(explainer.expected_value[1], shap_values[1], X_test, show=False)
shap.save_html("CAPP_model_shap_values_all_class01.htm", f)

# Plot SHAP values for a single instance with the asthma output - e.g. sample 36
f1 = shap.force_plot(explainer.expected_value[1], shap_values[1][36,:], X_test.iloc[36,:], link="logit")
shap.save_html("/scratch/dk2e18/Asthma_Prediction_Model/Oversampling/Final_models/Sensitivity_analyses/Interpretable_ML/Preschool_model_shap_values_instance_36-11.htm", f1)

# Plot SHAP values for a single instance with the asthma output - e.g. sample 12
f1 = shap.force_plot(explainer.expected_value[1], shap_values[1][12,:], X_test.iloc[12,:], link="logit")
shap.save_html("/scratch/dk2e18/Asthma_Prediction_Model/Oversampling/Final_models/Sensitivity_analyses/Interpretable_ML/Preschool_model_shap_values_instance_12-11.htm", f1)

# Plot SHAP values for a single instance with the asthma output - e.g. sample 50
f1 = shap.force_plot(explainer.expected_value[1], shap_values[1][50,:], X_test.iloc[50,:], link="logit")
shap.save_html("/scratch/dk2e18/Asthma_Prediction_Model/Oversampling/Final_models/Sensitivity_analyses/Interpretable_ML/Preschool_model_shap_values_instance_50-11.htm", f1)

