"""
Author: Henry (hw1n21@soton.ac.uk)
Untitled-1 (c) 2022
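Desc: Hardware-Trojan (HT) detection on 64-core NoC packet traces: calibration, DCI suspicious-packet detection and DSCT malicious-node localisation.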
Created:  2022-06-08T12:38:04.193Z
"""
import csv
from collections import Counter
from math import sqrt
from turtle import up

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sympy import elliptic_f

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
workload_name = "blackscholes"  # one of: "blackscholes", "fluidanimate" (no 50k data), "x264"
ht_data_level = "50k"
# Total packet count and number of actual HT packets (GP, the recall denominator).
# These defaults are used for blackscholes; fluidanimate/x264 override them per level below.
pkt_data_num = 50000
GP_num = 7070
calib_rate = 0.05  # too large a rate over-calibrates the algorithm (largest value noted: 0.056)
ht_node = 61  # e.g. 61 or 38
ht_node_second = 25  # values >= 64 select the single-HT data files (no second HT)
ht_detect_time = 5
dsct_and_caging = ht_detect_time

_DATA_DIR = r"C:\Users\owner\Nutstore\1\Academic\Experiment\Data"
_HT_DIR = _DATA_DIR + r"\HT data"
_PRED_DIR = _DATA_DIR + r"\Prediction result"

# (pkt_data_num, GP_num) per data level for workloads whose counts vary by level.
_LEVEL_COUNTS = {
    "fluidanimate": {
        "100k": (100000, 6516),
        "10k": (10000, 1404),  # alternative GP count noted: 902
        "500k": (500000, 26976),  # alternative GP count noted: 30503
        "1000k": (1000000, 57191),
        "20989k": (20989000, 1675629),
    },
    "x264": {
        "100k": (100000, 6133),
        "50k": (50000, 2491),
        "10k": (10000, 285),
        "500k": (500000, 31271),
        "1000k": (1000000, 63153),
        "9999k": (9999000, 609383),
    },
}

# Only blackscholes has two-HT data sets; ht_node_second affects file selection only there.
if workload_name == "blackscholes":
    if ht_node_second >= 64:
        # single-HT experiment
        ht_file = _HT_DIR + r"\blackscholes_64c_simsmall_packets_ROI2_HTin" + str(ht_node) + "_" + ht_data_level + "_htinjected.csv"
        ht_pred_file = _PRED_DIR + r"\detection_results_500_5IN_2HL_25NU_HTin" + str(ht_node) + "_" + ht_data_level + ".csv"
    else:
        # two-HT experiment: data lives in the "second_ht" sub-folders
        ht_file = _HT_DIR + r"\second_ht\blackscholes_64c_simsmall_packets_ROI2_HTin" + str(ht_node) + "_" + str(ht_node_second) + "_" + ht_data_level + "_htinjected.csv"
        ht_pred_file = _PRED_DIR + r"\second_ht\detection_results_500_5IN_2HL_25NU_HTin" + str(ht_node) + "_" + str(ht_node_second) + "_" + ht_data_level + ".csv"
    golden_pred_file = _PRED_DIR + r"\detection_results_500_5IN_2HL_25NU_golden_10k.csv"
elif workload_name == "fluidanimate":
    ht_file = _HT_DIR + r"\fluidanimate_64c_simsmall_packets_ROI2_HTin" + str(ht_node) + "_" + ht_data_level + "_htinjected.csv"
    ht_pred_file = _PRED_DIR + r"\detection_results_500_5IN_2HL_25NU_HTin" + str(ht_node) + "_" + ht_data_level + "_fa.csv"
    golden_pred_file = _PRED_DIR + r"\detection_results_500_5IN_2HL_25NU_golden_10k_fa.csv"
elif workload_name == "x264":
    ht_file = _HT_DIR + r"\x264_64c_simsmall_packets_ROI2_HTin" + str(ht_node) + "_" + ht_data_level + "_htinjected.csv"
    ht_pred_file = _PRED_DIR + r"\detection_results_500_5IN_2HL_25NU_HTin" + str(ht_node) + "_" + ht_data_level + "_x264.csv"
    golden_pred_file = _PRED_DIR + r"\detection_results_500_5IN_2HL_25NU_golden_10k_x264.csv"
else:
    # Fail here with a clear message instead of a NameError on ht_file further down.
    raise ValueError("Unknown workload_name: %r" % (workload_name,))

_level_table = _LEVEL_COUNTS.get(workload_name)
if _level_table is not None:
    if ht_data_level in _level_table:
        pkt_data_num, GP_num = _level_table[ht_data_level]
    else:
        # Previously this fell back silently; make the fallback visible.
        print("WARNING: no packet/GP counts recorded for", workload_name, ht_data_level,
              "- using pkt_data_num =", pkt_data_num, "GP_num =", GP_num)


def _read_first_column(path):
    """Return column 0 of the CSV file at *path* as a float numpy array."""
    with open(path, 'rt', newline='') as csv_file:
        return np.array([float(row[0]) for row in csv.reader(csv_file)])


# Per-packet prediction values for the HT run and for the golden reference run.
pac_result_ht = _read_first_column(ht_pred_file)
pac_result_golden = _read_first_column(golden_pred_file)

# Packet trace: column 0 is the source router id, column 1 the destination router id.
# Read the file once and split the two columns.
with open(ht_file, 'rt', newline='') as packet_data:
    _pkt_rows = [(float(row[0]), float(row[1])) for row in csv.reader(packet_data)]
pkt_src = np.array([src for src, _ in _pkt_rows])
pkt_dst = np.array([dst for _, dst in _pkt_rows])

if len(pkt_src) != len(pac_result_ht):
    print("ERROR: ht file is different with pred file!", len(pac_result_ht), len(pkt_src))
    # Non-zero exit status; the bare exit() builtin exited with status 0 and needs the site module.
    raise SystemExit(1)



#######     X-Y routing algorithm   ########
def xy_routing(src, dst, mesh_width=8):
    """Return the routers an X-Y (dimension-order) routed packet passes through.

    Routers are numbered row-major on a 2-D mesh: ``id = y * mesh_width + x``.
    The packet first moves along X until it reaches the destination column,
    then along Y to the destination row.

    Args:
        src: source router id.
        dst: destination router id.
        mesh_width: routers per mesh row (8 for the 64-core mesh in this script).

    Returns:
        Router ids in traversal order, starting with ``src``. Every intermediate
        router is included, including the router where the route turns from X to Y.
        ``dst`` itself is excluded. Returns ``[src]`` when ``src == dst``.
    """
    src_y, src_x = divmod(src, mesh_width)
    dst_y, dst_x = divmod(dst, mesh_width)

    hops = [src]
    current = src
    # X leg: one column per hop toward the destination column.
    x_step = 1 if dst_x > src_x else -1
    for _ in range(abs(dst_x - src_x)):
        current += x_step
        hops.append(current)
    # Y leg: continue from the turn router, one row (mesh_width ids) per hop.
    y_step = mesh_width if dst_y > src_y else -mesh_width
    for _ in range(abs(dst_y - src_y)):
        current += y_step
        hops.append(current)
    # The last hop is dst; the destination router is not reported.
    return hops[:-1] if len(hops) > 1 else hops

def takeSecond(elem):
    """Sort-key helper: return the element at index 1 of *elem*.

    Intended for ``[node, credit]`` rows so they can be ordered by credit.
    """
    second_item = elem[1]
    return second_item

#######    DSCT & malicious nodes detection     ########
def _update_credits(credits, hits, reward):
    """Apply one DSCT credit round to the ``{node: credit}`` dict *credits* in place.

    Each node in *hits* gains *reward* (once per occurrence). Every other tracked
    node loses 1, and nodes whose credit falls to <= 0 are removed.
    """
    for node in hits:
        credits[node] = credits.get(node, 0) + reward
    hit_set = set(hits)
    # Iterate over a snapshot of the keys: deleting while iterating the live
    # container would skip the next node's decrement.
    for node in list(credits):
        if node in hit_set:
            continue
        credits[node] -= 1
        if credits[node] <= 0:
            del credits[node]


def _ranked_table(credits):
    """Return an (n, 2) int array of ``[node, credit]`` rows.

    Rows are sorted by credit, highest first. Ties are broken by node id, highest first.
    """
    table = np.array(list(credits.items()), dtype=int).reshape(-1, 2)
    return table[np.lexsort((-table[:, 0], -table[:, 1]))]


def DSCT(sus_pac, pkt_src, pkt_dst, ht_node, ht_node_second, ht_detect_time,
         dsct_and_caging, and_top_open=None, total_pkts=None):
    """Locate malicious (HT-infected) routers from a stream of suspicious packets.

    Two credit tables are maintained while walking ``sus_pac``:
      * path table: each router on the packet's route (``xy_routing``) gains 3,
        and every other tracked router loses 1;
      * source table: the packet's source router gains 5, and every other tracked
        source loses 1.
    Entries are dropped once their credit reaches <= 0.

    A router is reported as an HT location when either protocol fires:
      * SAME-TOP: both tables hold at least ``ht_detect_time`` entries, the top
        two path credits differ, and the top path router is also the top source
        router and has not been reported before;
      * AND-TOP (only if ``and_top_open``): more than ``dsct_and_caging`` routers
        appear in both tables, and the one with the highest summed credit has
        not been reported before.

    Args:
        sus_pac: indices into ``pkt_src``/``pkt_dst`` of the suspicious packets.
        pkt_src, pkt_dst: per-packet source / destination router ids.
        ht_node: ground-truth HT router.
        ht_node_second: second ground-truth HT router; values >= 64 mean there is no second HT.
        ht_detect_time: minimum table size before SAME-TOP may fire.
        dsct_and_caging: AND-TOP needs more than this many shared routers.
        and_top_open: enables AND-TOP. ``None`` reads the module-level ``AND_top_open``.
        total_pkts: denominator for the progress percentage. ``None`` reads the
            module-level ``pkt_data_num``.

    Returns:
        ``(cor_ht_src_num, cor_ht_src, cor_ht_pkt, top10_path_nodes,
        top10_src_nodes, local_preci, local_recall)``. The two top-10 lists are
        kept for interface compatibility and are always empty.
    """
    if and_top_open is None:
        # NOTE(review): AND_top_open is not defined in this file — confirm where it is set.
        and_top_open = AND_top_open
    if total_pkts is None:
        total_pkts = pkt_data_num
    has_second_ht = ht_node_second < 64  # router ids on the 64-core mesh are 0..63

    cor_ht_src_num = 0   # suspicious packets sourced at ht_node
    cor_ht_src = []      # their 1-based packet indices
    cor_ht_pkt = 0       # suspicious packets whose route touches an HT router
    cor_ht_posi = 0      # reported routers that really are HT routers
    HT_nodes = []        # routers reported so far
    path_credit = {}
    src_credit = {}
    # Initialised so the final report works even when sus_pac is empty.
    dscta = _ranked_table(path_credit)
    dscta_src = _ranked_table(src_credit)

    print('==============================DSCT Phase=============================')
    for current_sus_pac, cor in enumerate(sus_pac, start=1):
        node_src = pkt_src[cor]
        if node_src == ht_node:
            cor_ht_src_num += 1
            cor_ht_src.append(cor + 1)
        path_nodes = xy_routing(int(node_src), int(pkt_dst[cor]))
        if ht_node in path_nodes or (has_second_ht and ht_node_second in path_nodes):
            cor_ht_pkt += 1

        _update_credits(path_credit, path_nodes, reward=3)
        _update_credits(src_credit, [int(node_src)], reward=5)
        dscta = _ranked_table(path_credit)
        dscta_src = _ranked_table(src_credit)

        # AND table: routers present in both tables (matched by router id), credits summed.
        dsct_and_src = [row.tolist() for row in dscta_src if row[0] in path_credit]
        dsct_and_a = [row.tolist() for row in dscta if row[0] in src_credit]
        dsct_and = []
        if len(dsct_and_a) > dsct_and_caging:
            dsct_and = [[node, credit + path_credit[node]] for node, credit in dsct_and_src]
            dsct_and.sort(key=lambda row: row[1], reverse=True)

        enough_entries = len(dscta) >= ht_detect_time and len(dscta_src) >= ht_detect_time
        clear_leader = enough_entries and len(dscta) > 1 and dscta[0][1] != dscta[1][1]
        same_top = (clear_leader and dscta[0][0] not in HT_nodes
                    and dscta[0][0] == dscta_src[0][0])
        and_top = (bool(and_top_open) and len(dsct_and) > dsct_and_caging
                   and dsct_and[0][0] not in HT_nodes)
        if not (same_top or and_top):
            continue

        if same_top:
            ht_inj_node, protocol = dscta[0][0], 'SAME-TOP'
        else:
            ht_inj_node, protocol = dsct_and[0][0], 'AND-TOP'
        message = '!!!!!!HARDWARE TROJAN detected in NODE (from ' + protocol + ' protocol)'
        if ht_inj_node == ht_node or ht_inj_node == ht_node_second:
            print('\033[4;33m' + message, ht_inj_node, '\033[0m')
            cor_ht_posi += 1
        else:
            print('\033[31m' + message, ht_inj_node, '\033[0m')

        print('dsct:', dscta[0:10, :].tolist())
        print('dsct_src:', dscta_src[0:10, :].tolist())
        print('dsct_and_src', dsct_and_src)
        print('dsct_and_a', dsct_and_a)
        print('dsct_and', dsct_and)
        HT_nodes.append(ht_inj_node)
        print('NOW is pacekt:', cor, '/', cor / total_pkts * 100, '%')
        print('NOW The total detected suspicious packets is', current_sus_pac)
        print('NOW The correct ht infected packet number is', cor_ht_pkt)
        print('NOW Precision: TP/(TP + FP) =', round(cor_ht_pkt / current_sus_pac, 4) * 100, '%,Higher is better')

    print('==============================DSCT Done===============================')
    print('Final DSCTs: ')
    print('dsct:', dscta.tolist())
    print('dsct_src:', dscta_src.tolist())

    if HT_nodes:
        local_preci = cor_ht_posi / len(HT_nodes)
        # Recall is measured against the number of HTs actually injected (1 or 2).
        local_recall = cor_ht_posi / (2 if has_second_ht else 1)
    else:
        local_preci = 0
        local_recall = 0
    top10_path_nodes = []
    top10_src_nodes = []
    return cor_ht_src_num, cor_ht_src, cor_ht_pkt, top10_path_nodes, top10_src_nodes, local_preci, local_recall


#################################### Run the detection pipeline ####################################
# NOTE(review): calibration_phase and DCI are called here but are not defined in this file —
# presumably they are provided elsewhere in the project; confirm before running.
# Calibration uses the first 10k HT predictions (original note: pac_result_golden).
low_b, up_b, scale_rate, calib_up_array, calib_down_array, calib_input_array = calibration_phase(
    calib_rate, pac_result_ht[0:10000])

sus_pac_dci, sus_pac_num, low_b_array, up_b_array, input_new_array = DCI(pac_result_ht, low_b, up_b, scale_rate)

cor_ht_src_num, cor_ht_src_dsct, cor_ht_pkt, top10_path_nodes, top10_src_nodes, local_preci, local_recall = DSCT(
    sus_pac_dci, pkt_src, pkt_dst, ht_node, ht_node_second, ht_detect_time, dsct_and_caging)


def _safe_ratio(num, den):
    """Return ``num / den``, or NaN when ``den`` is 0 so the summary still prints."""
    return num / den if den else float('nan')


# Confusion matrix. Every flagged packet is a TP or an FP (FP = sus_pac_num - TP).
# Every actual HT packet (GP_num of them) is a TP or an FN.
GN = pkt_data_num - GP_num  # actual negatives
TP = cor_ht_pkt
FP = sus_pac_num - cor_ht_pkt
FN = GP_num - TP  # actual HT packets that were not flagged
TN = GN - FP      # actual negatives that were not flagged
Specificity = _safe_ratio(TN, GN)

print('====================Detection Sumarry of ',workload_name,'=====================')
print('===============================DCI performance=================================')
print('Based on:',ht_data_level, 'The total detected suspicious packets is', '\033[31m' ,sus_pac_num, '\033[0m')
print('Based on:',ht_data_level, 'The correct ht infected packet number is', '\033[31m' , TP, '\033[0m')
print('Based on:',ht_data_level, 'Precision: TP/(TP + FP) =', round(_safe_ratio(TP, sus_pac_num), 3)*100,'%,Higher is better')
print('Based on:',ht_data_level, 'Sensitivity/Recall/TPR: TP/GP =', round(_safe_ratio(TP, GP_num), 3)*100,'%,Higher is better')
print('Based on:',ht_data_level, 'Specificity/TNR: TN/GN =', round(Specificity, 3)*100,'%,Higher is better')
print('Based on:',ht_data_level, 'FPR: FP/GN =', round(_safe_ratio(FP, GN), 3)*100,'%, Lower is better')
# FNR = FN / (TP + FN) = FN / GP (the previous version divided by GN).
print('Based on:',ht_data_level, 'FNR: FN/GP =', round(_safe_ratio(FN, GP_num), 3)*100,'%, Lower is better')
print('===============================DSCT performance================================')
print('Based on:',ht_data_level,'Precision: True node/(True node + False node) = ', round(local_preci,3)*100,'%, Higher is better')
print('Based on:',ht_data_level,'Recall: True node/(GP) = ', round(local_recall,3)*100,'%, Higher is better')

####################################### DCI visualization #######################################
left_num = 0
right_num = min(1000, len(pac_result_ht))  # never slice past the end of the data
data_num = right_num - left_num

fig, ax1 = plt.subplots(1, 1)

x_plt = np.arange(left_num, right_num)  # one x position per packet index
y_plt = pac_result_ht

ax1.plot(x_plt, y_plt[left_num:right_num], linestyle='-', color='black', linewidth=0.5)
ax1.set_xlabel('packet', fontsize=10)  # x-axis label
ax1.set_ylabel('prediction result', fontsize=10)  # y-axis label

#plt.show()

