#model_test

import imp
import os

import torch
from torch import detach, nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.functional import normalize


from sklearn.preprocessing import StandardScaler 
from sklearn.model_selection import train_test_split
from scipy.interpolate import make_interp_spline

import numpy as np
import pandas as pd
import csv

import matplotlib.pyplot as plt
import seaborn as sns
                 # 
# --- Experiment configuration -------------------------------------------------
dataset_name = "blackscholes"  # one of: "fluidanimate", "blackscholes", "x264"
NU_node = 25                   # embedded in result-file names (the "<N>NU" part)
epochs = 50                    # fluidanimate: 40, x264: 55, blackscholes: 50
mode = "ht"                    # "ht" or "golden"
dataset_num = "500k"
ht_in_r = "HTin38"
ht_second = 99                 # blackscholes only: values < 64 select the "second_ht" files

# Single place to change when the data directory moves.
_DATA_ROOT = r"C:\Users\owner\Nutstore\1\Academic\Experiment\Data"

# Suffix appended to the detection-result file name for each workload.
_RESULT_SUFFIX = {"blackscholes": "", "fluidanimate": "_fa", "x264": "_x264"}


def _data_path(*parts):
    """Join *parts* below the data root using Windows path separators."""
    return "\\".join((_DATA_ROOT,) + parts)


if dataset_name not in _RESULT_SUFFIX:
    # Fail here with a clear message instead of leaving the path names undefined.
    raise ValueError(
        "Unsupported dataset_name {!r}; expected one of {}".format(
            dataset_name, sorted(_RESULT_SUFFIX)
        )
    )

_suffix = _RESULT_SUFFIX[dataset_name]
_trace_prefix = dataset_name + "_64c_simsmall_packets_ROI2_"
_result_prefix = "detection_results_500_5IN_2HL_" + str(NU_node) + "NU_"

if dataset_name == "blackscholes" and ht_second < 64:
    # These HT traces and their results live in a "second_ht" sub-directory
    # and carry ht_second in the file name.
    _ht_dirs = ("second_ht",)
    _ht_tag = ht_in_r + "_" + str(ht_second)
else:
    _ht_dirs = ()
    _ht_tag = ht_in_r

golden_dataset = _data_path(
    "Golden data",
    _trace_prefix + "golden_" + str(dataset_num) + ".csv",
)
# NOTE: the golden result file name embeds the current ``mode`` value.
det_result_golden = _data_path(
    "Prediction result",
    _result_prefix + str(mode) + "_" + str(dataset_num) + _suffix + ".csv",
)

ht_dataset = _data_path(
    "HT data",
    *_ht_dirs,
    _trace_prefix + _ht_tag + "_" + str(dataset_num) + "_htinjected.csv",
)
det_result = _data_path(
    "Prediction result",
    *_ht_dirs,
    _result_prefix + _ht_tag + "_" + str(dataset_num) + _suffix + ".csv",
)


# ht_dataset = r"C:\Users\owner\Nutstore\1\Academic\Experiment\Data\HT data\blackscholes_64c_simsmall_packets_ROI2_HTin61_10k.csv"
# det_result = r'C:\Users\owner\Nutstore\1\Academic\Experiment\Data\Prediction result\detection_results_500_5IN_2HL_25NU_golden_10k.csv'



# NetraceDataset implementation (for anomaly detection)

class NetraceDataset(Dataset):
    """Dataset backed by a header-less CSV trace file.

    Columns 0-4 are standardised with ``StandardScaler`` and stored as the
    float32 tensor ``X``. Column 4 is additionally kept, unscaled, as the
    tensor ``y``.
    """

    def __init__(self, file_name):
        """Read ``file_name`` and build the ``X`` and ``y`` tensors."""
        frame = pd.read_csv(file_name, header=None)
        features = frame.iloc[:, :5].values
        targets = frame.iloc[:, 4].values

        # Feature scaling; the scaler is fitted on this file's own rows.
        scaled = StandardScaler().fit_transform(features)

        # Convert to torch tensors (y keeps whatever dtype torch infers).
        self.X = torch.tensor(scaled, dtype=torch.float32)
        self.y = torch.tensor(targets)

    def __len__(self):
        """Return the number of rows."""
        return self.y.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, target)`` pair at position ``idx``."""
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label

# Model definition

class NeuralNetwork(nn.Module):
    """Fully connected classifier: two hidden ReLU layers and a linear output.

    Args:
        hidden_units: Width of both hidden layers. When ``None`` (the default)
            the module-level ``NU_node`` setting is used, so ``NeuralNetwork()``
            builds the same network as before.
        in_features: Number of input features per sample (default 5).
        out_features: Number of output logits per sample (default 2).

    The attribute names ``flatten`` and ``linear_relu_stack`` and the order of
    the layers inside the stack determine the ``state_dict`` keys, so they must
    stay as they are for previously saved checkpoints to load.
    """

    def __init__(self, hidden_units=None, in_features=5, out_features=2):
        super().__init__()
        if hidden_units is None:
            # Looked up lazily so an explicit width needs no global setting.
            hidden_units = NU_node
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, out_features),
        )

    def forward(self, x):
        """Return the raw logits, one row of ``out_features`` values per sample."""
        return self.linear_relu_stack(self.flatten(x))

pred_show = 20  # number of leading prediction rows echoed to stdout

# Checkpoint name suffix for each workload.
# Alternative blackscholes checkpoint used previously: suffix "e_2out".
_MODEL_SUFFIX = {
    "blackscholes": "e_black_htlabel",
    "fluidanimate": "e_fa",
    "x264": "e_x264",
}
if dataset_name not in _MODEL_SUFFIX:
    # Without this guard an unknown workload surfaces later as a NameError.
    raise ValueError(
        "Unsupported dataset_name {!r}; expected one of {}".format(
            dataset_name, sorted(_MODEL_SUFFIX)
        )
    )
model_name = (
    "NN_model_5IN_2HL_" + str(NU_node) + "NU_scaler_in_" + str(epochs)
    + _MODEL_SUFFIX[dataset_name]
)

model = NeuralNetwork()
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine; the model and its inputs stay on the CPU in this script.
model.load_state_dict(
    torch.load(
        r"C:\Users\owner\OneDrive - University of Southampton\Documents\Experiment_Materials\Code\abnormal detection\The latest\{}.pth".format(model_name),
        map_location="cpu",
    )
)
model.eval()  # inference only

if mode == "golden":
    netrace_tensor = NetraceDataset(golden_dataset)
    _inputs = netrace_tensor.X
elif mode == "ht":
    netrace_tensor_ht = NetraceDataset(ht_dataset)
    _inputs = netrace_tensor_ht.X
else:
    raise ValueError("Unsupported mode {!r}; expected 'golden' or 'ht'".format(mode))

# No gradients are needed for prediction; skipping autograd avoids building a
# graph over the whole dataset.
with torch.no_grad():
    pred = model(_inputs)

print("===============")
print(pred[0:pred_show])


###TODO store results
# with open(r'C:\Users\owner\Nutstore\1\Academic\Experiment\Data\Prediction result\detection_results_500_5IN_2HL_20NU.csv', 'w', newline='') as out_file_pre:
#     writer_pre = csv.writer(out_file_pre)
#      #   writer.writerow(('title', 'intro'))
#     for i in range(len(pred)):
#        pred[i, 0] = max(pred[i])
#        pred_ht[i, 0] = max(pred_ht[i])
#        row = [round(float(pred[i].detach().numpy()), 4), round(float(pred_ht[i].detach().numpy()), 4)]
#        writer_pre.writerow(row)

# Pick the output file for the current mode; both modes write the same format.
if mode == "golden":
    _result_path = det_result_golden
elif mode == "ht":
    _result_path = det_result
else:
    raise ValueError("Unsupported mode {!r}; expected 'golden' or 'ht'".format(mode))

# Overwrite column 0 of every row with that row's maximum, in one vectorised
# step instead of a Python loop over every row. ``detach()`` shares storage
# with ``pred``, so ``pred`` itself is updated exactly as the per-row
# assignment did.
# NOTE(review): rows containing NaN may be handled differently from the
# built-in max() used before -- confirm if NaN logits can occur.
_scores = pred.detach()
_scores[:, 0] = _scores.max(dim=1).values

with open(_result_path, 'w', newline='') as out_file_pre_ht:
    writer_pre_ht = csv.writer(out_file_pre_ht)
    # Rows stay NumPy arrays so each value is formatted as it was when written
    # one row at a time.
    writer_pre_ht.writerows(_scores.numpy())




