from concurrent import futures
import time
import logging
import sys
import grpc
import numpy as np 
import cv2 as cv
# import pickle
sys.path.insert(0, '../protos/')
sys.path.insert(0, './TensorFlow-YOLOv3/')
from core import utils
import deepcam_pb2
import deepcam_pb2_grpc
import tensorflow as tf
from PIL import Image
import io
import tensorflow.contrib.tensorrt as trt
from tensorflow.python.platform import gfile
# Input resolution fed to the YOLOv3 graph; SIZE is the list form used by
# the resize calls in the request handlers.
IMAGE_H, IMAGE_W = 416, 416
SIZE = [IMAGE_H, IMAGE_W]

# Class-name lookup loaded from the COCO names file shipped with the
# TensorFlow-YOLOv3 checkout.
classes = utils.read_coco_names('./TensorFlow-YOLOv3/data/coco.names')
num_classes = len(classes)

# Frozen graph served by this process.
TENSORRT_YOLOv3_MODEL = "./TensorFlow-YOLOv3/checkpoint/TensorRT_YOLOv3.pb"

# Tensor names looked up in the frozen graph. The handlers unpack the three
# fetched outputs as (boxes, scores, labels).
# NOTE(review): presumably the first name is the input placeholder and the
# rest are the outputs, in that order - confirm against
# utils.read_pb_return_tensors.
_YOLO_TENSOR_NAMES = ["Placeholder:0", "concat_10:0", "concat_11:0", "concat_12:0"]

# The graph is imported into the default graph at import time.
input_tensor, output_tensors = utils.read_pb_return_tensors(
    tf.get_default_graph(),
    TENSORRT_YOLOv3_MODEL,
    _YOLO_TENSOR_NAMES,
)

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

class DeepCam(deepcam_pb2_grpc.DeepCamServicer):
    """gRPC servicer that runs the loaded YOLOv3 graph on encoded images.

    The handlers rely on module-level state: ``sess`` (the TensorFlow session
    bound by the ``with`` statement in the ``__main__`` guard),
    ``input_tensor`` / ``output_tensors``, ``classes`` and the ``SIZE`` /
    ``IMAGE_H`` / ``IMAGE_W`` constants.
    """

    @staticmethod
    def _decode_image(img_data):
        """Decode encoded image bytes into a PIL image.

        Args:
            img_data: bytes of an encoded image, as sent in ``request.img_data``.

        Returns:
            A PIL image, or ``None`` when the buffer is empty or OpenCV cannot
            decode it.
        """
        encoded = np.frombuffer(img_data, dtype="uint8")
        if encoded.size == 0:
            return None
        # cv.imdecode reports failure by returning None instead of raising.
        decoded = cv.imdecode(encoded, 1)
        if decoded is None:
            return None
        # NOTE(review): OpenCV documents imdecode output as BGR; the array is
        # handed to PIL without a channel swap, exactly as before. Confirm
        # that this matches what the model and the client expect.
        return Image.fromarray(decoded)

    @staticmethod
    def _to_network_input(pil_img):
        """Resize to ``SIZE``, scale pixel values by 1/255 and add a batch axis."""
        resized = np.array(pil_img.resize(size=tuple(SIZE)), dtype=np.float32)
        return np.expand_dims(resized / 255., axis=0)

    @staticmethod
    def _mark_undecodable(context):
        """Flag the RPC as failed because ``img_data`` was not a valid image."""
        context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
        context.set_details("img_data could not be decoded as an image")

    def PleaseReply(self, request, context):
        """Echo handler: log ``msg_number`` and greet with ``msg_send``."""
        print("Received :", request.msg_number)
        return deepcam_pb2.ReplyMessage(msg_recv='Hello, %s!' % request.msg_send)

    def DetectionRequest(self, request, context):
        """Run detection and return the image with boxes drawn on it.

        Args:
            request: message whose ``img_data`` holds the encoded input image.
            context: gRPC servicer context for this call.

        Returns:
            ``ReplyImgMessage`` carrying the annotated image encoded as JPEG
            (``img_data``), its shape (``img_h``, ``img_w``, ``img_ch``) and
            ``bbox_result``. An empty message with status
            ``INVALID_ARGUMENT`` is returned when the input cannot be decoded.
        """
        pil_img = self._decode_image(request.img_data)
        if pil_img is None:
            self._mark_undecodable(context)
            return deepcam_pb2.ReplyImgMessage()

        # bbox_result is currently a fixed array of five ones, not derived
        # from the detections.
        bbox_array_bytes = np.ones((5), dtype=int).tobytes()

        # Resize to the target size and normalise. The original code referenced
        # an undefined name `image` here, raising NameError on every call.
        batch = self._to_network_input(pil_img)

        start = time.time()
        boxes, scores, labels = sess.run(output_tensors, feed_dict={input_tensor: batch})
        print("=> nms on gpu the number of boxes= %d  time=%.2f ms" %(len(boxes), 1000*(time.time()-start)))

        result_pil_image = utils.draw_boxes(pil_img, boxes, scores, labels, classes, [IMAGE_H, IMAGE_W], show=False)
        result_rgb_image = np.array(result_pil_image)
        _, encoded_result = cv.imencode('.jpg', result_rgb_image)
        img_h, img_w, img_ch = result_rgb_image.shape
        return deepcam_pb2.ReplyImgMessage(
            img_h=img_h,
            img_w=img_w,
            img_ch=img_ch,
            img_data=encoded_result.tobytes(),
            bbox_result=bbox_array_bytes,
        )

    def YOLORequest(self, request, context):
        """Run detection and return the raw output arrays as bytes.

        Args:
            request: message whose ``img_data`` holds the encoded input image.
            context: gRPC servicer context for this call.

        Returns:
            ``YOLOMessage`` with ``boxes``, ``scores`` and ``labels`` set to
            the ``tobytes()`` dump of the corresponding arrays. An empty
            message with status ``INVALID_ARGUMENT`` is returned when the
            input cannot be decoded.
        """
        pil_img = self._decode_image(request.img_data)
        if pil_img is None:
            self._mark_undecodable(context)
            return deepcam_pb2.YOLOMessage()

        batch = self._to_network_input(pil_img)

        prev_time = time.time()
        boxes, scores, labels = sess.run(output_tensors, feed_dict={input_tensor: batch})
        exec_time = time.time() - prev_time

        print(boxes)
        print(scores)
        print(labels)

        # Guard the FPS figure against a zero elapsed time on coarse clocks.
        fps = 1.0 / exec_time if exec_time > 0 else float("inf")
        info = "time:" + str(round(1000*exec_time, 2)) + " ms, FPS: " + str(round(fps, 1))
        print(info)
        return deepcam_pb2.YOLOMessage(boxes=boxes.tobytes(), scores=scores.tobytes(), labels=labels.tobytes())

def serve():
    """Start the DeepCam gRPC server on port 2000 and block until Ctrl-C.

    Raises:
        RuntimeError: if the listening port cannot be bound.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    deepcam_pb2_grpc.add_DeepCamServicer_to_server(DeepCam(), server)

    # Some gRPC releases report a failed bind by returning port 0 instead of
    # raising; without this check the process would run while serving nothing.
    bound_port = server.add_insecure_port('[::]:2000')
    if bound_port == 0:
        raise RuntimeError("could not bind gRPC server to [::]:2000")

    server.start()
    print("Server Running")
    try:
        # server.start() does not block, so park the main thread. Sleeping in
        # long intervals avoids waking up a hundred times per second;
        # KeyboardInterrupt still interrupts the sleep immediately.
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == '__main__':
    logging.basicConfig()
    # Limit this process to half of the GPU memory.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
    session_config = tf.ConfigProto(gpu_options=gpu_options)
    # `sess` must stay a module-level name: the DeepCam handlers read it as a
    # global when they call sess.run().
    with tf.Session(config=session_config) as sess:
        serve()
        