import cv2 as cv
import time
import numpy as np
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt
from tensorflow.python.platform import gfile
from PIL import Image
import sys
sys.path.insert(0, './TensorFlow-YOLOv3/')
from core import utils
# Location of the frozen, TensorRT-optimized YOLOv3 graph loaded below.
TENSORRT_YOLOv3_MODEL = "./TensorFlow-YOLOv3/checkpoint/TensorRT_YOLOv3.pb"
# Network input resolution (two equal sides); the input image is resized to
# tuple(SIZE) before being fed to the graph.
SIZE = [416] * 2
# Class names loaded via utils.read_coco_names (presumably one entry per COCO
# class, indexed by label id — confirm against core/utils), and their count.
classes = utils.read_coco_names('./TensorFlow-YOLOv3/data/coco.names')
num_classes = len(classes)

# Import the TensorRT YOLOv3 graph into the default graph and look up the
# tensors we need: the first name is the input placeholder that receives the
# preprocessed image batch, the remaining three are the detection outputs that
# are unpacked as (boxes, scores, labels) after sess.run.
input_tensor, output_tensors = utils.read_pb_return_tensors(
    tf.get_default_graph(),
    TENSORRT_YOLOv3_MODEL,
    [
        "Placeholder:0",
        "concat_10:0",
        "concat_11:0",
        "concat_12:0",
    ],
)

# perform inference: run the TensorRT YOLOv3 graph on one image, draw the
# detections onto it, print timing, and save the annotated image to result.jpg
with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.5))) as sess:
    image_path = "../images/image1.jpg"
    frame = cv.imread(image_path)
    # cv.imread signals failure by returning None rather than raising, so check
    # explicitly instead of letting Image.fromarray fail with an opaque error.
    if frame is None:
        raise IOError("Could not read image: {}".format(image_path))
    # OpenCV decodes to BGR channel order, while PIL (including Image.save)
    # interprets the array as RGB. Convert so both the network input and the
    # saved result.jpg have the correct colors.
    frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
    image = Image.fromarray(frame)
    # Resize to the network input size and scale pixel values to [0, 1].
    img_resized = np.array(image.resize(size=tuple(SIZE)), dtype=np.float32) / 255.

    prev_time = time.time()
    # np.expand_dims adds the leading batch dimension of size 1.
    boxes, scores, labels = sess.run(
        output_tensors,
        feed_dict={input_tensor: np.expand_dims(img_resized, axis=0)})
    image = utils.draw_boxes(image, boxes, scores, labels, classes, SIZE, show=False)
    # The timed span covers both sess.run and draw_boxes.
    exec_time = time.time() - prev_time

    # NOTE(review): this is a single timed run; the first sess.run may include
    # one-time initialization, so the number can overstate steady-state
    # latency — TODO confirm and consider an untimed warm-up run.
    fps = 1.0 / exec_time if exec_time > 0 else float("inf")
    info = "time:{:.2f} ms, FPS: {:.1f}".format(1000 * exec_time, fps)
    print(info)
    image.save("result.jpg")