import cv2
import imutils
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import camera
import config
import time


def _bg_sub(image, bgblur):
    '''
    Builds a binary foreground mask by comparing image against a background.

    image is median-blurred (kernel size 29), converted from BGR to grayscale
    and compared with bgblur using an absolute difference. Pixels differing by
    more than 27 grey levels become 255, everything else 0. The mask is then
    dilated twice to close small holes and smooth blob outlines.

    bgblur is used as-is: it is expected to already be a blurred grayscale
    background (NOTE(review): presumably the same size as image and produced
    with the same blur settings - confirm against whatever sets config.bgblur).
    '''
    # Only the incoming image is blurred here; the background is pre-blurred
    image_blur = cv2.medianBlur(image, 29)
    image_blur_gray = cv2.cvtColor(image_blur, cv2.COLOR_BGR2GRAY)
    # absdiff is symmetric, so both darker and brighter deviations are kept
    difference = cv2.absdiff(image_blur_gray, bgblur)
    thresh = cv2.threshold(difference, 27, 255, cv2.THRESH_BINARY)[1]
    # kernel=None makes OpenCV use its default 3x3 structuring element
    return cv2.dilate(thresh, None, iterations=2)


def _remove_small_objects(img, min_size=150):
    '''
    Removes blobs from a binary image whose area is below min_size pixels.
    Required to eliminate false positives from detritus, beads, etc.

    Returns a new image; the input img is left unmodified.
    '''
    # Label every 8-connected white blob and collect per-label statistics
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        img, connectivity=8)
    # One boolean per label: True when the blob is large enough to keep
    keep = stats[:, cv2.CC_STAT_AREA] >= min_size
    # Label 0 is the background, which is never a detection. Its pixels are
    # already 0 in img, so clearing them again changes nothing.
    keep[0] = False
    cleaned = img.copy()
    # Index the per-label table with the label image to obtain a per-pixel
    # mask in a single pass, instead of one full-image scan per blob
    cleaned[~keep[labels]] = 0
    return cleaned


def preprocess(image):
    '''
    Turns a BGR image into a binary mask of candidate detections.

    The image is background-subtracted against config.bgblur and thresholded
    (see _bg_sub), then blobs smaller than 150 px are discarded as false
    positives (see _remove_small_objects). If config.bgblur is None, a
    background image is requested from the camera module first.

    Returns the resulting thresholded image (values 0 or 255).
    '''
    if config.bgblur is None:
        print("Didn't detect a bg image - taking one now")
        # NOTE(review): presumably this call populates config.bgblur - confirm
        camera.capture_backgroundimage()
    return _remove_small_objects(_bg_sub(image, config.bgblur))


def count_phyto(image):
    '''
    Counts the detections (cells, or false positives) present in image.

    The image is first reduced to a binary mask by preprocess(). The mask is
    cleaned with a 3x3 morphological opening, converted to a Euclidean
    distance transform, and only pixels whose distance exceeds 30% of the
    maximum are kept. The connected components that remain are counted.

    Returns the number of components found, excluding the background.
    '''
    # Background subtraction, thresholding and small object removal
    mask = preprocess(image)
    # Morphological opening with a 3x3 square structuring element
    opened = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8))
    # Distance of every foreground pixel to the nearest background pixel
    distances = cv2.distanceTransform(opened, cv2.DIST_L2, 5)
    # Keep only the cores of the blobs: pixels deeper than 30% of the maximum
    cutoff = 0.3 * distances.max()
    cores = cv2.threshold(distances, cutoff, 255, cv2.THRESH_BINARY)[1]
    # connectedComponents needs an 8-bit image; the threshold output is float
    n_labels = cv2.connectedComponents(cores.astype(np.uint8))[0]
    # The label count includes the background, which is not a detection
    return n_labels - 1


#im = cv2.imread('captures/CULTURE2_1ML_17NOV2020/17-11-2020-15_08_27_632.JPG')
#config.bgblur = cv2.cvtColor(cv2.medianBlur(im, 29), cv2.COLOR_BGR2GRAY)

#config.bgblur = cv2.cvtColor(cv2.medianBlur(cv2.resize(im, (0,0), fx=config.resizeratio, fy=config.resizeratio) , 29), cv2.COLOR_BGR2GRAY)

def count_in_folder(folder: str) -> int:
    '''
    Counts detections in every image of a folder and prints a timing summary.

    Every directory entry whose name contains '.JPG' is loaded, resized by
    config.resizeratio in both dimensions and passed to count_phyto(). Files
    that OpenCV cannot decode are reported and skipped rather than aborting
    the whole run.

    Returns the total number of detections over all images processed.
    '''
    count = 0
    imagecount = 0
    # perf_counter is monotonic, so the elapsed time cannot go negative if the
    # wall clock is adjusted while the folder is being processed
    starttime = time.perf_counter()
    for filename in os.listdir(folder):
        if '.JPG' not in filename:
            continue
        path = os.path.join(folder, filename)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising when a file is
            # missing, corrupt or not an image; cv2.resize would then fail
            # with an unhelpful error
            print('Skipping unreadable image: ' + path)
            continue
        img = cv2.resize(img, (0, 0), fx=config.resizeratio,
                         fy=config.resizeratio)
        count += count_phyto(img)
        imagecount += 1
    elapsed = time.perf_counter() - starttime
    # Guard the rate against a zero-length interval (e.g. empty folder on a
    # platform with a coarse timer)
    fps = imagecount / elapsed if elapsed > 0 else 0.0
    print('Counted {} cells in {} images, took {} s ({}FPS)'.format(
        count, imagecount, elapsed, fps))
    return count
