import cv2
import socket
import sys

import mediapipe as mp
import numpy as np


# Shortcut to MediaPipe's face-mesh solution module.
mp_face_mesh = mp.solutions.face_mesh

# Binary sink that mirrors the raw stdin stream when the script is run with
# ``--debug``. It is rebound by the ``__main__`` block. The module-level default
# of None lets main() and read_int_or_exit() also work when this module is
# imported rather than executed as a script (previously a NameError).
debug_file = None


def _encode_landmarks_message(frame, top, left, width, height, landmarks):
    """Build one "face found" response for the stdout protocol.

    Layout, with every integer a little-endian uint32:

    - total size;
    - header ``frame, top, left, width, height, len(landmarks)``;
    - the serialized landmark bytes.

    The total size counts the 24-byte header plus the landmark bytes, but not
    the size field itself.
    """
    fields = (frame, top, left, width, height, len(landmarks))
    header = b"".join(value.to_bytes(4, "little") for value in fields)
    total_size = len(header) + len(landmarks)
    return b"".join([total_size.to_bytes(4, "little"), header, landmarks])


def main(should_connect, port):
    """Run the face-mesh worker loop over a binary stdin/stdout protocol.

    Each request on stdin is six little-endian uint32 values, read in the
    order ``frame, length, stride, width, top, left``, followed by ``length``
    bytes of 4-byte-per-pixel image data laid out in rows of ``stride`` bytes.

    For every request exactly one response is written to stdout. If a face
    was found, it is a size-prefixed message carrying the request metadata and
    the serialized landmarks of the first face. Otherwise it is a single
    uint32 zero.

    Args:
        should_connect: Accepted for the ``--connect`` flag; not used here.
        port: Accepted for the ``--port`` option; not used here.

    The process exits with status 1 on EOF or a short read of stdin, and on
    frame geometry that cannot be reshaped into an image.
    """
    print("Started mediapipe python face_mesh", file=sys.stderr)
    stdin = sys.stdin.buffer
    stdout = sys.stdout.buffer

    with mp_face_mesh.FaceMesh(
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
        min_tracking_confidence=0.5,
    ) as face_mesh:
        while True:
            frame = read_int_or_exit(stdin)
            length = read_int_or_exit(stdin)
            stride = read_int_or_exit(stdin)
            width = read_int_or_exit(stdin)
            top = read_int_or_exit(stdin)
            left = read_int_or_exit(stdin)

            # Reject geometry that would otherwise crash further down: division
            # by zero, an empty image, or a reshape whose element count does
            # not match the buffer (stride must hold whole 4-byte pixels and
            # length must hold whole rows).
            if length == 0 or stride == 0 or stride % 4 or length % stride:
                print(
                    f"Invalid frame geometry (length={length}, "
                    f"stride={stride}), exiting",
                    file=sys.stderr,
                )
                sys.exit(1)
            height = length // stride

            buf = stdin.read(length)

            if debug_file:
                debug_file.write(buf)
                debug_file.flush()

            if len(buf) != length:
                print(f"Read {len(buf)}/{length} bytes, exiting", file=sys.stderr)
                sys.exit(1)

            # One row per `stride` bytes, 4 bytes per pixel.
            # NOTE(review): if stride > 4 * width, the row padding stays in the
            # image passed to the model. Confirm whether it should be cropped
            # to `width` columns.
            raw_img = np.frombuffer(buf, dtype=np.uint8).reshape(
                (height, stride // 4, 4)
            )
            # NOTE(review): FaceMesh.process documents an RGB input. RGBA2BGR
            # produces RGB only if the incoming pixels are really BGRA.
            # Confirm the producer's channel order.
            img = cv2.cvtColor(raw_img, cv2.COLOR_RGBA2BGR)

            results = face_mesh.process(img)

            if results.multi_face_landmarks:
                landmarks = results.multi_face_landmarks[0].SerializeToString()
                stdout.write(
                    _encode_landmarks_message(
                        frame, top, left, width, height, landmarks
                    )
                )
            else:
                # A bare zero size tells the reader that no face was found.
                stdout.write((0).to_bytes(4, "little"))
            stdout.flush()


def read_int_or_exit(input):
    """Read one little-endian unsigned 32-bit integer from the stream *input*.

    If fewer than four bytes come back, the process terminates with exit
    status 1 without printing anything. On success, when ``debug_file`` is
    set, the four raw bytes are also written to it (without flushing).
    """
    raw = input.read(4)
    if len(raw) == 4:
        if debug_file:
            debug_file.write(raw)
        return int.from_bytes(raw, "little")
    sys.exit(1)


if __name__ == "__main__":
    import argparse
    import contextlib

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--connect", action="store_true")
    parser.add_argument("-p", "--port", type=int)
    args = parser.parse_args()

    # With --debug, the raw bytes read from stdin are mirrored into this file.
    # Binding it through `with` guarantees the file is flushed and closed even
    # though main() normally ends via sys.exit(). read_int_or_exit() never
    # flushes, so its trailing bytes were previously at risk of being lost.
    # Without --debug, nullcontext() binds debug_file to None.
    debug_sink = (
        open("/tmp/expressions_debug.bin", "wb")
        if args.debug
        else contextlib.nullcontext()
    )
    with debug_sink as debug_file:
        main(args.connect, args.port)
