import cv2
import socket
import sys

import mediapipe as mp
import numpy as np

from mediapipe.framework.formats.landmark_pb2 import NormalizedLandmarkListCollection

# Shorthand for MediaPipe's hands solution module; main() builds its Hands
# tracker from it.
mp_hands = mp.solutions.hands


def _encode_hands(frame, top, left, width, height, hand_landmarks):
    """Build the datagram payload describing the hands found in one frame.

    The payload is five little-endian 4-byte integers (frame, top, left,
    width, height) followed by a serialized NormalizedLandmarkListCollection
    holding one landmark list per entry of ``hand_landmarks``.
    """
    header = b"".join(
        value.to_bytes(4, "little") for value in (frame, top, left, width, height)
    )
    collection = NormalizedLandmarkListCollection()
    for landmarks in hand_landmarks:
        collection.landmark_list.add().landmark.extend(landmarks.landmark)
    return header + collection.SerializeToString()


def main(should_connect):
    """Run the hand-tracking loop over frames streamed on stdin.

    Each frame arrives on stdin as six little-endian 4-byte unsigned integers
    (frame, length, stride, width, top, left) followed by ``length`` bytes of
    pixel data, which is interpreted as rows of ``stride`` bytes made of
    4-byte pixels. The loop never returns normally; the process exits with
    status 1 when stdin runs out of data or a frame header is unusable.

    When ``should_connect`` is true, one UDP datagram per frame is sent to
    127.0.0.1:40404: either the payload built by ``_encode_hands`` when at
    least one hand was detected, or a single zero byte otherwise.

    Args:
        should_connect: Whether to send results over UDP. When false, frames
            are still read and processed but nothing is sent.
    """
    print("Started mediapipe python hands", file=sys.stderr)
    stream = sys.stdin.buffer
    # debug_file is only bound when this file runs as a script; fall back to
    # None so main() also works when the module is imported.
    capture = globals().get("debug_file")

    # Context managers release the socket and the tracker on every exit path,
    # including the SystemExit raised by sys.exit().
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        if should_connect:
            # TODO add to argparse
            sock.connect(("127.0.0.1", 40404))

        with mp_hands.Hands(
            static_image_mode=False,
            max_num_hands=2,
            model_complexity=1,
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        ) as hands:
            while True:
                frame = read_int_or_exit(stream)
                length = read_int_or_exit(stream)
                stride = read_int_or_exit(stream)
                width = read_int_or_exit(stream)
                top = read_int_or_exit(stream)
                left = read_int_or_exit(stream)

                payload = stream.read(length)

                if capture:
                    # Mirror the raw pixels (even a truncated read) and flush
                    # so the capture is complete up to this frame.
                    capture.write(payload)
                    capture.flush()

                if len(payload) != length:
                    print(
                        f"Read {len(payload)}/{length} bytes, exiting",
                        file=sys.stderr,
                    )
                    sys.exit(1)

                # The buffer must split into whole rows of whole 4-byte
                # pixels; otherwise the reshape below cannot succeed.
                if stride == 0 or stride % 4 != 0 or length % stride != 0:
                    print(
                        f"Invalid frame geometry (length={length}, "
                        f"stride={stride}), exiting",
                        file=sys.stderr,
                    )
                    sys.exit(1)
                height = length // stride

                # Rows are `stride` bytes wide, so the image keeps any row
                # padding beyond `width` pixels.
                raw_img = np.frombuffer(payload, dtype=np.uint8).reshape(
                    (height, stride // 4, 4)
                )
                # NOTE(review): this swaps channels 0 and 2 and drops alpha.
                # MediaPipe documents RGB input, which this yields only if the
                # producer sends BGRA - confirm against the producer.
                img = cv2.cvtColor(raw_img, cv2.COLOR_RGBA2BGR)
                results = hands.process(img)

                if not should_connect:
                    continue

                if results.multi_hand_landmarks:
                    sock.sendall(
                        _encode_hands(
                            frame,
                            top,
                            left,
                            width,
                            height,
                            results.multi_hand_landmarks,
                        )
                    )
                else:
                    # A single zero byte tells the receiver no hand was found.
                    sock.sendall(b"\x00")


def read_int_or_exit(input):
    """Read one little-endian unsigned 4-byte integer from a binary stream.

    The raw bytes are also copied to the module-level ``debug_file`` when one
    is set.

    Args:
        input: Binary file-like object whose ``read(n)`` returns ``bytes``.
            The name shadows the builtin but is kept so existing callers
            are unaffected.

    Returns:
        The decoded non-negative integer.

    Raises:
        SystemExit: With status 1 when fewer than four bytes could be read.
    """
    buf = input.read(4)
    if len(buf) != 4:
        if buf:
            # A partial integer means the stream was cut mid-record; report
            # it in the same form main() uses for short pixel reads.
            print(f"Read {len(buf)}/4 bytes, exiting", file=sys.stderr)
        sys.exit(1)
    # debug_file is only bound when this file runs as a script; look it up
    # defensively so calling this from an import does not raise NameError.
    capture = globals().get("debug_file")
    if capture:
        capture.write(buf)
    return int.from_bytes(buf, byteorder="little")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--debug",
        action="store_true",
        help="mirror everything read from stdin to /tmp/hands_debug.bin",
    )
    parser.add_argument(
        "--connect",
        action="store_true",
        help="send results over UDP to 127.0.0.1:40404",
    )
    args = parser.parse_args()

    # debug_file must stay a module-level name: main() and read_int_or_exit()
    # read it as a global to mirror the bytes they consume from stdin.
    debug_file = open("/tmp/hands_debug.bin", "wb") if args.debug else None
    try:
        main(args.connect)
    finally:
        # main() leaves its loop only by raising (sys.exit on end of input),
        # so close the capture here to make sure buffered bytes reach disk.
        if debug_file is not None:
            debug_file.close()
