"""
Author: Henry (hw1n21@soton.ac.uk)
Untitled-1 (c) 2022
Desc: Train a small feed-forward neural network on Netrace packet data for abnormal detection.
Created:  2022-06-10T10:05:15.121Z
"""

#NN training on Netrace data

import imp
import os

import torch
from torch import detach, nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.functional import normalize

from torchvision import datasets
from torchvision.transforms import ToTensor

from sklearn.preprocessing import StandardScaler 
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
import csv

# Make CUDA kernel launches synchronous so that errors surface at the call
# that caused them (a debugging aid; it slows GPU execution down).
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")



# NetraceDataset implementation for abnormal detection.
# In the labelled CSV, columns 2 .. feature_num-1 are the model inputs and
# column `feature_num` is the label.
feature_num = 7


class NetraceDataset(Dataset):
    """Standardised Netrace packet features and labels loaded from a CSV file.

    Features are the columns ``feature_start`` .. ``label_col - 1`` (by
    position); each is scaled to zero mean and unit variance with a
    ``StandardScaler`` fitted on the whole file. Labels are the
    ``label_col`` column, left unscaled.

    Args:
        file_name: Path of the CSV file to read.
        feature_start: Position of the first feature column (default 2).
        label_col: Position of the label column; defaults to ``feature_num``.

    Attributes:
        X: ``float32`` tensor of scaled features, one row per CSV row.
        y: Label tensor (dtype inferred from the CSV column).
        scaler: The fitted ``StandardScaler``, kept so that other traces can
            be transformed the same way before being fed to a trained model.
    """

    def __init__(self, file_name, *, feature_start=2, label_col=None):
        if label_col is None:
            label_col = feature_num
        file_out = pd.read_csv(file_name)
        x = file_out.iloc[:, feature_start:label_col].values
        y = file_out.iloc[:, label_col].values

        # Feature scaling. NOTE(review): the scaler sees every row, including
        # those later placed in the test split, so test statistics leak into
        # the scaling.
        self.scaler = StandardScaler()
        x = self.scaler.fit_transform(x)

        # Convert to torch tensors.
        self.X = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair for sample ``idx``."""
        return self.X[idx], self.y[idx]


#TODO NetraceDataset Implementation (for dst prediction)

# class NetraceDataset(Dataset):
#     def __init__(self, file_name):
#      file_out = pd.read_csv(file_name)
#      x = file_out.iloc[0:len(file_out), 0:4].values
#      y = file_out.iloc[0:len(file_out), 4].values

#      #Feature Scaling
#      sc = StandardScaler()
#      x = x #sc.fit_transform(x)
#      y = y

#      #converting to torch tensors
#      self.X = torch.tensor(x, dtype=torch.float32)
#      #self.X = normalize(self.X, p=2.0, dim=0) 
#      self.X[:, 0] = normalize(self.X[:, 0], p=2.0, dim=0) *100000
#      self.X[:, 1] = normalize(self.X[:, 1], p=2.0, dim=0) * 10000
#      self.X[:, 2] = normalize(self.X[:, 2], p=2.0, dim=0) * 10000
#      self.X[:, 3] = normalize(self.X[:, 3], p=2.0, dim=0) * 10000 

#      self.y = torch.tensor(y, dtype=torch.float32)
#      self.y = normalize(self.y, p=2.0, dim=0) * 1000

#     def __len__(self):
#         return len(self.y)

#     def __getitem__(self, idx):
#         return self.X[idx], self.y[idx]



class NewData(Dataset):
    """Map-style dataset over index-aligned feature and label sequences.

    Used here to wrap the train/test outputs of ``train_test_split``.

    Args:
        X_train: Indexable samples (e.g. a tensor whose first dimension
            indexes samples).
        y_train: Indexable labels, one per sample in ``X_train``.

    Raises:
        ValueError: If ``X_train`` and ``y_train`` differ in length, which
            would otherwise pair samples with the wrong labels or fail late
            inside a DataLoader.
    """

    def __init__(self, X_train, y_train):
        if len(X_train) != len(y_train):
            raise ValueError(
                "X_train and y_train must have the same length, "
                f"got {len(X_train)} and {len(y_train)}"
            )
        self.X = X_train
        self.y = y_train

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.X[idx], self.y[idx]


# Create the Netrace dataset and split it into train/test subsets.
netrace_tensor = NetraceDataset(r"C:\Users\owner\Nutstore\1\Academic\Experiment\Data\Golden data\blackscholes_64c_simsmall_packets_ROI2_golden_label_all.csv")
# train_size + test_size = 0.7, so roughly 30% of the samples are in neither
# split. `shuffle` must be a bool: recent scikit-learn rejects None, and older
# releases treated None like True (only `shuffle is False` disables
# shuffling), so True keeps the existing behaviour.
X_train, X_test, y_train, y_test = train_test_split(
    netrace_tensor.X,
    netrace_tensor.y,
    test_size=0.3,
    train_size=0.4,
    random_state=None,
    shuffle=True,
)

NetraceTrain = NewData(X_train, y_train)
NetraceTest = NewData(X_test, y_test)

# Full-batch loaders: each yields its whole split as a single batch.
batch_size_train = len(NetraceTrain)
batch_size_test = len(NetraceTest)

train_dataloader = DataLoader(NetraceTrain, batch_size=batch_size_train, shuffle=False)
test_dataloader = DataLoader(NetraceTest, batch_size=batch_size_test, shuffle=False)

# Sanity check: shape and dtype of one training batch.
X, y = next(iter(train_dataloader))
print(f"Shape of X: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

# Define model
class NeuralNetwork(nn.Module):
    """Fully connected network: in_features -> hidden -> hidden -> 1 logit.

    Two ReLU hidden layers followed by a single linear output unit. The
    output is a raw logit (no sigmoid), matching the ``BCEWithLogitsLoss``
    used for training in this file.

    Args:
        in_features: Width of each (flattened) input sample. The default 5
            matches the feature columns 2..6 selected by ``NetraceDataset``.
        hidden: Units in each hidden layer (default 20).

    The layer order is fixed, so ``state_dict`` keys do not depend on the
    arguments and checkpoints of the default-sized model still load.
    """

    def __init__(self, in_features=5, hidden=20):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        """Return one logit per sample, shape ``(batch, 1)``."""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits



# Instantiate the network on the selected device and show its layers.
model = NeuralNetwork().to(device)
print(model)

# The network emits one raw logit per sample, so use the logits form of
# binary cross-entropy (the sigmoid is folded into the loss).
# loss_fn = nn.CrossEntropyLoss()
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001)  # plain SGD, lr = 1e-3

def train_opt(dataloader, model, loss_fn, optimizer, log_every=100):
    """Run one training epoch over ``dataloader``, updating ``model`` in place.

    Each batch is moved to the module-level ``device``. The model's first
    output column is compared with the labels cast to float (the BCE loss
    used in this file needs floating-point targets).

    Args:
        dataloader: Yields ``(X, y)`` batches.
        model: Network being trained; switched to training mode.
        loss_fn: Called as ``loss_fn(logits, targets)`` on 1-D tensors.
        optimizer: Optimizer over ``model``'s parameters.
        log_every: Print the loss every this many batches (default 100).

    Returns:
        Mean of the per-batch losses for this epoch, or ``nan`` if the
        dataloader produced no batches.
    """
    size = len(dataloader.dataset)
    model.train()
    seen = 0  # samples processed so far, for the progress counter
    num_batches = 0
    running_loss = 0.0  # becomes a tensor; avoids a host sync every batch
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred[:, 0], y.float())

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        num_batches += 1
        running_loss = running_loss + loss.detach()
        if batch % log_every == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

    if num_batches == 0:
        return float("nan")
    return float(running_loss) / num_batches


def train(dataloader, model):
    """Unfinished training loop: forward passes only, no parameter updates.

    Puts ``model`` in training mode and runs every batch of ``dataloader``
    through it on the module-level ``device``, discarding the outputs. No
    loss is computed and no optimizer step is taken; ``train_opt`` is the
    working training loop.
    """
    len(dataloader.dataset)  # dataset size (currently unused)
    model.train()
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        # Compute prediction error -- TODO: loss not implemented yet.
        pred = model(X)  # noqa: F841  (result currently discarded)
        # Backpropagation -- TODO: not implemented yet.


def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader``; print and return loss and accuracy.

    Batches are moved to the module-level ``device``. As in ``train_opt``, the
    model's first output column is scored against the labels cast to float
    (the BCE loss used in this file needs floating-point targets).

    Accuracy reads that output as a logit: a sample is predicted positive
    when its logit is > 0 (sigmoid > 0.5) and is correct when this matches
    ``y > 0.5``. NOTE(review): assumes 0/1 labels -- confirm against the
    label column of the CSV.

    Returns:
        ``(avg_loss, accuracy)``: the mean per-batch loss and the fraction of
        correctly classified samples in [0, 1]; either is ``nan`` when there
        is no data.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            logits = model(X)[:, 0]
            test_loss += loss_fn(logits, y.float()).item()
            correct += ((logits > 0) == (y > 0.5)).sum().item()
    avg_loss = test_loss / num_batches if num_batches else float("nan")
    accuracy = correct / size if size else float("nan")
    print(f"Test Error: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {avg_loss:>8f} \n")
    return avg_loss, accuracy

######TODO training and testing#######
epochs = 50
# Per-epoch output-range statistics: largest logit, smallest logit and their
# difference, measured after each epoch on (up to) the first 10000 samples of
# the first training batch.
max_range = []
min_range = []
diff_range = []
# Opened before training, so an unwritable path is reported before any epoch runs.
with open(r'C:\Users\owner\Nutstore\1\Academic\Experiment\Data\pretraining_verfication.csv', 'w', newline='') as out_file_pre:
    writer_pre = csv.writer(out_file_pre)
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_opt(train_dataloader, model, loss_fn, optimizer)

        # Monitoring pass only: no gradients are needed, and the batch goes to
        # `device` (not a hard-coded 'cuda') so the script also runs on CPU.
        X, _ = next(iter(train_dataloader))
        with torch.no_grad():
            pred = model(X[0:10000].to(device))
        # Reduce on the tensor and keep plain floats, so the CSV holds numbers
        # rather than the text of 1-element arrays.
        hi = pred.max().item()
        lo = pred.min().item()
        max_range.append(hi)
        min_range.append(lo)
        diff_range.append(hi - lo)
        print('range diff is : ', diff_range[-1])

    # One row per epoch: max, min, max - min.
    writer_pre.writerows(zip(max_range, min_range, diff_range))


print("Done!")

# prediction_dir = r"C:\Users\owner\OneDrive - University of Southampton\Documents\Experiment_Materials\Code\dst prediction\NN_model_2HL_15NU_normalized_in_30e.pth"
# detection_dir = r"C:\Users\owner\OneDrive - University of Southampton\Documents\Experiment_Materials\Code\abnormal detection\NN_model_5IN_2HL_20NU_scaler_in_50e.pth"
# torch.save(model.state_dict(), detection_dir)
# print("Saved PyTorch Model State to" + detection_dir)

# model = NeuralNetwork()
# model.load_state_dict(torch.load(r"C:\Users\owner\OneDrive - University of Southampton\Documents\Experiment_Materials\Code\NN_model_3HL_30NU.pth"))

# pred = model(netrace_tensor.X)
# print(pred)
# print(netrace_tensor.y)





