import datetime
import os
import sys
import cv2
import mediapipe as mp
import numpy as np

# Short aliases for the MediaPipe solution modules referenced in this file.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_pose = mp.solutions.pose

# Four-component background colour whose last component is 0 (the original
# note labels it "TRANSPARENT"). In the visible code it is referenced only by
# the commented-out image-blend path in ``main``.
BG_COLOR = (192, 192, 192, 0)  # TRANSPARENT


def main(should_connect, port):
    """Run the pose-segmentation loop: frames in on stdin, replies on stdout.

    Each request read from stdin consists of six 4-byte integers (decoded by
    ``read_int_or_exit``) -- frame, length, stride, width, top, left --
    followed by ``length`` bytes of pixel data. The pixel data is interpreted
    as rows of ``stride`` bytes with four 8-bit channels per pixel.

    When ``should_connect`` is true, every request gets exactly one reply:

    * if a segmentation mask and pose landmarks were produced: a 4-byte
      little-endian payload size, seven 4-byte little-endian integers
      (frame, width, height, stride, top, left, landmarks size), the raw
      bytes of the segmentation mask, and the serialized pose landmarks;
    * otherwise: a single 4-byte zero.

    When ``should_connect`` is false the frame is still processed, but
    nothing is written to stdout.

    Args:
        should_connect: Whether replies are written to stdout.
        port: Currently unused by this function.

    The loop never returns normally; it terminates the process with
    ``sys.exit(1)`` on a short read or on an inconsistent length/stride pair
    (``read_int_or_exit`` also exits when a header field cannot be read).
    """
    print("Started mediapipe python pose segmentation", file=sys.stderr)
    stream_in = sys.stdin.buffer
    stream_out = sys.stdout.buffer
    # ``debug_file`` is only bound when this module runs as a script; look it
    # up defensively so ``main`` also works when the module is imported.
    debug = globals().get("debug_file")

    with mp_pose.Pose(
        model_complexity=2, enable_segmentation=True, min_detection_confidence=0.5
    ) as pose:
        while True:
            frame = read_int_or_exit(stream_in)
            length = read_int_or_exit(stream_in)
            stride = read_int_or_exit(stream_in)
            width = read_int_or_exit(stream_in)
            top = read_int_or_exit(stream_in)
            left = read_int_or_exit(stream_in)
            buf = stream_in.read(length)

            # Capture whatever was received, even a truncated payload.
            if debug is not None:
                debug.write(buf)
                debug.flush()

            if len(buf) != length:
                print(f"Read {len(buf)}/{length} bytes, exiting", file=sys.stderr)
                sys.exit(1)

            # The reshape below needs whole rows made of whole 4-byte pixels;
            # fail with a diagnostic instead of a ZeroDivisionError/ValueError
            # traceback.
            if stride == 0 or stride % 4 != 0 or length % stride != 0:
                print(
                    f"Invalid frame geometry (length={length}, stride={stride}),"
                    " exiting",
                    file=sys.stderr,
                )
                sys.exit(1)
            height = length // stride

            raw_img = np.frombuffer(buf, dtype=np.uint8).reshape(
                (height, stride // 4, 4)
            )
            # NOTE(review): this drops the 4th channel and swaps channels 0
            # and 2 -- confirm the resulting order is what MediaPipe expects
            # for the incoming data.
            img = cv2.cvtColor(raw_img, cv2.COLOR_RGBA2BGR)
            results = pose.process(img)

            if not should_connect:
                continue

            if results.segmentation_mask is None or not results.pose_landmarks:
                # Nothing detected: reply with a zero-length payload marker.
                stream_out.write((0).to_bytes(4, "little"))
                stream_out.flush()
                continue

            # IMAGE BLEND IN PYTHON
            # condition = (
            #     np.stack((results.segmentation_mask,) * 4, axis=-1) > 0.1
            # )
            # bg_image = np.zeros(raw_img.shape, dtype=np.uint8)
            # bg_image[:] = BG_COLOR
            # new_image = np.where(condition, raw_img, bg_image)
            # raw_bytes = new_image.tobytes()

            # SEND JUST SEGMENTATION MASK
            # NOTE(review): the processed image is stride // 4 pixels wide,
            # not ``width``; presumably the mask has the same dimensions --
            # confirm the consumer accounts for that.
            raw_bytes = results.segmentation_mask.tobytes()
            landmarks = results.pose_landmarks.SerializeToString()

            header = b"".join(
                value.to_bytes(4, "little")
                for value in (
                    frame,
                    width,
                    height,
                    stride,
                    top,
                    left,
                    len(landmarks),
                )
            )
            # The size prefix counts everything that follows it.
            size = len(header) + len(raw_bytes) + len(landmarks)
            stream_out.write(
                b"".join([size.to_bytes(4, "little"), header, raw_bytes, landmarks])
            )
            stream_out.flush()


def read_int_or_exit(input):
    """Read one 4-byte little-endian unsigned integer from ``input``.

    Args:
        input: Binary stream exposing ``read(n)``.

    Returns:
        The decoded non-negative integer.

    Exits the process with status 1 (``SystemExit``) when fewer than four
    bytes are available. When a module-level ``debug_file`` is set, the four
    raw bytes are also appended to it.
    """
    buf = input.read(4)
    if len(buf) != 4:
        sys.exit(1)
    # ``debug_file`` is only bound when the module runs as a script; resolve
    # it defensively so the function also works when the module is imported.
    debug = globals().get("debug_file")
    if debug is not None:
        debug.write(buf)
    return int.from_bytes(buf, byteorder="little")


# Script entry point: parse the command line, optionally open the debug
# capture file, then run the processing loop.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--connect", action="store_true")
    parser.add_argument("-p", "--port", type=int)
    args = parser.parse_args()

    # Module-level on purpose: ``main`` and ``read_int_or_exit`` read this
    # global to decide whether to mirror their input into the capture file.
    debug_file = open("/tmp/hands_debug.bin", "wb") if args.debug else None

    # NOTE(review): this goes to stdout, the same stream ``main`` uses for its
    # binary replies -- confirm the consumer expects this text there.
    print("Start mediapipe")
    try:
        main(args.connect, args.port)
    finally:
        # ``main`` leaves via ``sys.exit``; make sure the capture file is
        # flushed and closed on the way out.
        if debug_file is not None:
            debug_file.close()
