import threading
#import analyse
import cv2
import EasyPySpin
import datetime
import time
from PIL import Image
import os
import PySpin
import config

#from collections import deque
#from multiprocessing.pool import ThreadPool


class cameraclass:
    '''Class to implement camera functionality.

    Wraps an EasyPySpin.VideoCapture connection and uses its PySpin node map
    to configure gain, red/blue balance and pixel format. Frames can be shown
    in a live view window or saved to disk, either in the calling thread
    (cap_n_blocking) or on background threads (capture_n, capture_secs) that
    can be stopped with abortcapture().
    '''
    # TODO: Reorder functions more logically
    cameraconnection = None     # Stores the camera connection wrapper object
    imageformat = '.jpg'        # Image save format
    savedir = 'captures/test/'  # Image save location
    bluelevel = 1.5             # Blue channel balance level
    redlevel = 1.9              # Red channel balance level
    gain = 19                   # Gain in dB (init() applies setgain's default)
    exposure = 305              # Exposure in us
    pixelformat = 'RGB8'        # Pixel format (bayer, RGB, etc)
    abort = False               # Set true to kill threaded capture
    # Capture threads started by this object. None until the first thread is
    # started, after which it becomes a per-instance list (a list created here
    # on the class would be shared by every cameraclass instance).
    _capturethreads = None
    # Longest single sleep (s) while pacing frames, so abort is noticed quickly
    _abortpollinterval = 0.01

    def openconnection(self, index=0):
        '''Initiates a connection to the camera accessed at index specified
        (usually 0 unless multiple cameras are connected)'''
        self.cameraconnection = EasyPySpin.VideoCapture(index)
        print('Opened camera connection: ' +
              str(self.cameraconnection.cam.DeviceModelName()))

    def closeconnection(self):
        '''Closes the camera connection safely; does nothing if no connection
        is currently open'''
        if self.cameraconnection is None:
            return
        self.cameraconnection.release()
        self.cameraconnection = None

    def setgain(self, gain=13):
        '''Sets the camera gain (in dB) through the "Gain" node and stores it
        in self.gain once the camera has accepted it'''
        node_gain = PySpin.CFloatPtr(
            self.cameraconnection.nodemap.GetNode("Gain"))
        node_gain.SetValue(gain)
        self.gain = gain

    def setredbluebalance(self, red=1.9, blue=1.5):
        '''Sets the red and blue balance ratios of the camera and stores them
        in self.redlevel / self.bluelevel'''
        # BalanceRatio applies to whichever channel BalanceRatioSelector
        # currently selects, so select each channel before writing its ratio
        node_balanceratioselector = PySpin.CEnumerationPtr(
            self.cameraconnection.nodemap.GetNode("BalanceRatioSelector"))
        node_balanceratio = PySpin.CFloatPtr(
            self.cameraconnection.nodemap.GetNode("BalanceRatio"))
        for channel, level in (("Red", red), ("Blue", blue)):
            entry = node_balanceratioselector.GetEntryByName(channel)
            node_balanceratioselector.SetIntValue(entry.GetValue())
            node_balanceratio.SetValue(level)
        self.bluelevel = blue
        self.redlevel = red

    def setpixelformat(self, pformat='RGB8'):
        '''Sets the pixel format (e.g. RGB8, BayerXXYY etc) by entry name and
        stores it in self.pixelformat'''
        node_imageformat = PySpin.CEnumerationPtr(
            self.cameraconnection.nodemap.GetNode("PixelFormat"))
        # Look up the enumeration entry for the requested format name
        imageformatobject = node_imageformat.GetEntryByName(pformat)
        node_imageformat.SetIntValue(imageformatobject.GetValue())
        self.pixelformat = pformat

    def setexposure(self, exposure=305):
        '''Records the exposure time (in microseconds) in self.exposure.

        NOTE: the camera itself is not configured yet (see TODO), so this
        currently has no effect on captured frames.'''
        # TODO: Set exposure time on the camera
        self.exposure = exposure

    def setsavedir(self, directory):
        '''Changes the default save directory, creating it if nonexistent'''
        # exist_ok avoids a race between checking for and creating the folder
        os.makedirs(directory, exist_ok=True)
        self.savedir = directory

    def init(self):
        '''Connects to camera and loads default camera settings'''
        self.openconnection()
        self.setgain()
        self.setredbluebalance()
        self.setpixelformat()
        # self.setexposure()

    @staticmethod
    def _frameperiod(FPS):
        '''Returns the frame period in seconds for FPS, or None when FPS is
        None or 0 (meaning "as fast as possible").

        Raises ValueError for a negative FPS.'''
        if not FPS:
            return None
        if FPS < 0:
            raise ValueError(f'FPS must be positive, got {FPS}')
        return 1. / FPS

    def _waitfornextframe(self, period, starttime):
        '''Sleeps until the next multiple of period (s) after starttime, so
        frames stay on a fixed schedule, returning early if self.abort is set'''
        now = time.time()
        deadline = now + period - ((now - starttime) % period)
        # Sleep in short slices rather than one long sleep so that an abort is
        # seen promptly even at low frame rates
        while not self.abort:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            time.sleep(min(remaining, self._abortpollinterval))

    def _start_capture_thread(self, target, *args):
        '''Runs target(*args) on a new thread, records it so abortcapture()
        can wait for it, and returns the started thread'''
        # Drop finished threads; also creates the per-instance list on first use
        self._capturethreads = [
            t for t in (self._capturethreads or []) if t.is_alive()]
        thread = threading.Thread(target=target, args=args)
        thread.start()
        self._capturethreads.append(thread)
        return thread

    def liveview(self, FPS=None):
        '''Displays a window with a live video feed from the camera (at max FPS
        unless a display FPS is specified). Press Esc or call abortcapture()
        to stop; the window is closed either way.'''
        # TODO: threaded approach?
        period = self._frameperiod(FPS)
        starttime = time.time()
        cv2.namedWindow('Live View', cv2.WINDOW_NORMAL)
        try:
            while not self.abort:
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
                img = self.grabimage()
                # Skip frames the camera failed to deliver
                if img is not None:
                    cv2.imshow('Live View', img)
                if period is not None:
                    self._waitfornextframe(period, starttime)
        finally:
            cv2.destroyAllWindows()

    def grabimage(self):
        '''Returns a single image frame from the camera: the second element of
        the connection's read() result (presumably OpenCV's (retval, image)
        convention, so None if the read failed - TODO confirm)'''
        return self.cameraconnection.read()[1]

    def saveimage(self, frame, directory=None):
        '''Saves the passed frame to the directory given, or to self.savedir if
        none is specified, creating the directory if necessary.

        The file name is the current local time (millisecond resolution) plus
        self.imageformat. Returns the target file path.'''
        if directory is None:
            directory = self.savedir
        if directory:
            # cv2.imwrite does not create missing directories
            os.makedirs(directory, exist_ok=True)
        timestamp = datetime.datetime.now().strftime(
            '%d-%m-%Y-%H_%M_%S_%f')[:-3]
        # os.path.join also handles directories given without a trailing
        # separator, which plain concatenation would glue onto the file name
        filename = os.path.join(directory, timestamp + self.imageformat)
        if not cv2.imwrite(filename, frame):
            print('Failed to save image: ' + filename)
        return filename

    def abortcapture(self, timeout=5.0):
        '''Aborts any current image capture threads.

        Sets self.abort so running capture/liveview loops exit at their next
        check, waits up to `timeout` seconds per capture thread started by this
        object for it to finish, then clears the flag so new captures can run.
        The timeout keeps a hung camera read from blocking the caller forever.'''
        self.abort = True
        try:
            # Grace period for loops this object does not track, e.g. a
            # liveview() running in a thread the caller created
            time.sleep(0.1)
            current = threading.current_thread()
            for thread in self._capturethreads or []:
                # A thread cannot join itself
                if thread is not current:
                    thread.join(timeout)
        finally:
            self.abort = False
            self._capturethreads = [
                t for t in (self._capturethreads or []) if t.is_alive()]

    def cap_n_blocking(self, n, FPS=None):
        '''Captures and saves n frames in the calling thread, at max FPS or at
        the specified FPS (None or 0 means max FPS). Stops early if
        self.abort is set.'''
        period = self._frameperiod(FPS)
        initiated = time.time()
        n = int(n)
        count = 0
        while count < n and not self.abort:
            image = self.grabimage()
            self.saveimage(image, self.savedir)
            count += 1
            # No need to wait for a frame slot after the last frame
            if period is not None and count < n:
                self._waitfornextframe(period, initiated)
        endtime = time.time()
        print(f'Captured {count} frames in {endtime - initiated} s.')

    def capture_n(self, n, FPS=None):
        '''Captures n frames at max FPS or the specified FPS on a background
        thread, and returns that thread (stop it early with abortcapture())'''
        # Validate arguments here so bad input raises in the caller's thread
        n = int(n)
        self._frameperiod(FPS)
        return self._start_capture_thread(self.cap_n_blocking, n, FPS)

    def capture_backgroundimage(self):
        '''Grabs one frame and saves it to captures/background/'''
        img = self.grabimage()
        self.saveimage(frame=img, directory='captures/background/')
        print('Saved background image')

    def _capture_secs_blocking(self, s):
        '''Captures and saves frames as fast as possible for s seconds in the
        calling thread, stopping early if self.abort is set'''
        endtime = time.time() + s
        count = 0
        while time.time() < endtime and not self.abort:
            image = self.grabimage()
            self.saveimage(image, self.savedir)
            count += 1
        print(f'Captured {count} frames in {s} s.')

    def capture_secs(self, s: float):
        '''Captures images continuously from the camera for s seconds on a
        background thread, and returns that thread.

        Raises ValueError unless s is positive.'''
        # Explicit check instead of assert, which is stripped under python -O
        if not s > 0:
            raise ValueError(f'Capture duration must be positive, got {s}')
        return self._start_capture_thread(self._capture_secs_blocking, s)



### Deprecated but not yet replaced
def capture_ml_analyse_nosave(ml: float = 1.0):
    '''Captures and analyses round(config.images_per_ml * ml) frames without
    saving any images, and returns the total cell count.

    Frames are read from config.cameraconnection and downscaled by
    config.resizeratio before being passed to analyse.count_phyto. The
    function returns the sum of count_phyto over all captured frames.
    '''
    # The top-of-file `import analyse` is commented out, so the
    # analyse.count_phyto call below used to raise NameError. Import it here
    # instead. NOTE(review): presumably the module-level import was disabled
    # to avoid an import cycle, which a function-scope import sidesteps -
    # confirm.
    import analyse
    images_tocapture = round(config.images_per_ml * ml)
    print(f'Preparing to capture and analyse {images_tocapture} images')
    framecount = 0
    detected_phytoplankton = 0
    starttime = time.time()
    while framecount < images_tocapture:
        framecount += 1
        frame = config.cameraconnection.read()[1]
        # Downscale before analysis (fx/fy are scale factors; (0, 0) lets
        # cv2.resize derive the output size from them)
        detected_phytoplankton += analyse.count_phyto(cv2.resize(
            frame, (0, 0), fx=config.resizeratio, fy=config.resizeratio))
    endtime = time.time()
    print(f'Captured: {framecount} images '
          f'({framecount / config.images_per_ml} ml) '
          f'in {endtime - starttime} s')
    print(f'Counted {detected_phytoplankton} cells')
    return detected_phytoplankton
