import cv2
import socket
import sys
import datetime

import mediapipe as mp
import numpy as np

from mediapipe.framework.formats.detection_pb2 import DetectionList

mp_face_detection = mp.solutions.face_detection


def main(should_connect, port):
    """Run MediaPipe face detection on frames streamed over stdin.

    Every frame on stdin consists of six 4-byte little-endian integers
    (frame id, payload length, row stride, width, top, left) followed by
    ``length`` bytes of pixel data, which is interpreted as rows of
    ``stride`` bytes holding 4-byte RGBA pixels.

    When ``should_connect`` is true, one UDP datagram is sent per frame to
    127.0.0.1:``port``: either a 20-byte header (frame, top, left, width,
    height as 4-byte little-endian integers) followed by a serialized
    ``DetectionList``, or a single zero byte when nothing was detected.

    Args:
        should_connect: Whether to connect the UDP socket and send results.
        port: UDP port on 127.0.0.1 to send to; only used when
            ``should_connect`` is true.

    The function only returns by exiting the process: ``sys.exit(1)`` is
    called when stdin ends, a payload is truncated, or the frame geometry is
    inconsistent.
    """
    print("Started mediapipe python face_detection", file=sys.stderr)
    stream = sys.stdin.buffer
    # ``debug_file`` is only assigned when the file runs as a script; treat a
    # missing global as "debug capture disabled" instead of a NameError.
    debug_sink = globals().get("debug_file")

    # The context manager guarantees the socket is closed on every exit path,
    # including the SystemExit raised by sys.exit().
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        if should_connect:
            # TODO add to argparse
            sock.connect(("127.0.0.1", port))

        # model_selection 1 is full range model best for faces within 5 meters
        with mp_face_detection.FaceDetection(
            model_selection=1, min_detection_confidence=0.5
        ) as face_detection:
            frames_seen = 0
            while True:
                # Progress heartbeat once every 1000 frames.
                if frames_seen % 1000 == 0:
                    print(f"{datetime.datetime.now()}: Processing frame {frames_seen}")
                frames_seen += 1

                # Header fields, in wire order.
                frame = read_int_or_exit(stream)
                length = read_int_or_exit(stream)
                stride = read_int_or_exit(stream)
                width = read_int_or_exit(stream)
                top = read_int_or_exit(stream)
                left = read_int_or_exit(stream)

                # The payload must be a whole number of rows, each a whole
                # number of 4-byte pixels; otherwise the reshape below cannot
                # succeed (and a zero stride would divide by zero).
                if stride <= 0 or stride % 4 != 0 or length % stride != 0:
                    print(
                        f"Invalid frame geometry: length={length}, "
                        f"stride={stride}, exiting",
                        file=sys.stderr,
                    )
                    sys.exit(1)
                height = length // stride

                buf = stream.read(length)

                # Capture whatever was read (even a short read) before
                # validating it, so the debug file mirrors stdin.
                if debug_sink:
                    debug_sink.write(buf)
                    debug_sink.flush()

                if len(buf) != length:
                    print(f"Read {len(buf)}/{length} bytes, exiting", file=sys.stderr)
                    sys.exit(1)

                # Rows are ``stride`` bytes wide, i.e. stride // 4 RGBA pixels.
                raw_img = np.frombuffer(buf, dtype=np.uint8).reshape(
                    (height, stride // 4, 4)
                )
                img = cv2.cvtColor(raw_img, cv2.COLOR_RGBA2RGB)

                results = face_detection.process(img)

                if not should_connect:
                    continue

                if not results.detections:
                    # Single zero byte signals "no detections for this frame".
                    sock.sendall(b"\x00")
                    continue

                header = b"".join(
                    value.to_bytes(4, "little")
                    for value in (frame, top, left, width, height)
                )
                dl = DetectionList()
                dl.detection.extend(results.detections)
                sock.sendall(header + dl.SerializeToString())


def read_int_or_exit(input):
    """Read one 4-byte little-endian unsigned integer from ``input``.

    Args:
        input: Binary stream exposing ``read(n)``.

    Returns:
        The decoded non-negative integer.

    Exits the process with status 1 (via ``sys.exit``) when fewer than four
    bytes are available. A partially read value is reported on stderr; a
    clean end of stream exits silently.

    When a module-level ``debug_file`` is set, the four bytes read are also
    appended to it.
    """
    buf = input.read(4)
    if len(buf) != 4:
        if buf:
            # A partial integer means the stream was cut mid-header.
            print(f"Read {len(buf)}/4 header bytes, exiting", file=sys.stderr)
        sys.exit(1)

    # ``debug_file`` is only assigned when the file runs as a script; look it
    # up defensively so the function also works when imported.
    debug_sink = globals().get("debug_file")
    if debug_sink:
        debug_sink.write(buf)
    return int.from_bytes(buf, byteorder="little")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Run MediaPipe face detection on frames read from stdin."
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="copy everything read from stdin to /tmp/face_detection_debug.bin",
    )
    parser.add_argument(
        "--connect",
        action="store_true",
        help="send detection results over UDP to 127.0.0.1",
    )
    parser.add_argument(
        "-p", "--port", type=int, help="UDP port to send to (needed with --connect)"
    )
    args = parser.parse_args()

    # Without a port, sock.connect() would fail later with an opaque
    # TypeError; reject the combination up front with a usage error.
    if args.connect and args.port is None:
        parser.error("--connect requires -p/--port")

    # Must be a module-level name: main() and read_int_or_exit() read it as a
    # global.
    if args.debug:
        debug_file = open("/tmp/face_detection_debug.bin", "wb")
    else:
        debug_file = None

    try:
        main(args.connect, args.port)
    finally:
        # main() leaves via sys.exit(); make sure the capture is flushed and
        # closed on that path too.
        if debug_file is not None:
            debug_file.close()
