import asyncio
import json

import websockets
from websockets.legacy.protocol import broadcast
from websockets.protocol import State

from celilo.actions import *
from celilo.test_suite import devices

# Websocket registered per client id; the test-driver client is stored under 'python'.
CONNECTIONS = dict()
# Connect payload most recently reported by each device, keyed by its sender id.
DEVICES = dict()


class Wss:
    """WebSocket relay between the python test-driver client and devices under test.

    Sockets are tracked in the module-level CONNECTIONS dict, keyed by the
    ``sender`` id from each device's connect message plus the special key
    'python' for the test driver. Each device's connect payload is kept in
    DEVICES.
    """

    def __init__(self, ip):
        """Remember the host/IP that :meth:`start` will bind the server to."""
        self.ip = ip

    async def handle_fw_update(self, message):
        """Relay a firmware-update message unchanged to the device named in ``sender``.

        If that device has no registered connection, the message is dropped and
        logged. The KeyError is not allowed to close the requester's socket.
        """
        j = json.loads(message)
        sender_id = j["sender"]
        assert sender_id

        rec_socket = CONNECTIONS.get(sender_id)
        if rec_socket is None:
            print("device {} not connected, dropping firmware update".format(sender_id))
            return

        await rec_socket.send(message)

    async def handle_realm_connect(self, message):
        """Send the python client the connect payloads of every known device.

        The client then chooses which device to connect to. ``message`` is
        accepted for parity with the other handlers; its content is not used.
        """
        devices_json = json.dumps(DEVICES)
        print("devices json: {}".format(devices_json))

        # The python client is the machine that sent the realm request;
        # handler() registers it under 'python' before dispatching here.
        python_client = CONNECTIONS.get('python')
        if python_client is None:
            # Nobody to reply to, so log instead of failing.
            print("python client not connected, cannot send device list")
            return

        await python_client.send(devices_json)

    async def handle_request_broadcast(self, message):
        """Re-serialize ``message`` and send it to every registered connection.

        A peer whose connection has closed is logged and skipped, so the other
        peers still receive the broadcast.
        """
        j = json.loads(message)
        print("received broadcast request: {}".format(j))
        payload = json.dumps(j)
        # Iterate over a snapshot: every await yields to the event loop, and a
        # concurrent handle_connect() may register a new socket meanwhile,
        # which would raise "dictionary changed size during iteration".
        for key, websocket in list(CONNECTIONS.items()):
            try:
                await websocket.send(payload)
            except websockets.exceptions.ConnectionClosed:
                print("broadcast skipped {}: connection closed".format(key))

    async def handle_realm_request(self, message, websocket=None):
        """Forward a realm request verbatim to the device named in ``sender``.

        :param message: raw JSON text; its ``sender`` field names the target device.
        :param websocket: the requester's socket (optional). When given, it
            receives a plain-text error if the target device is not connected.
        """
        j = json.loads(message)
        sender_id = j["sender"]
        assert sender_id

        print("requesting realm capabilities: device {}".format(sender_id))
        rec_socket = CONNECTIONS.get(sender_id)
        if rec_socket is None:
            msg = "device {} not connected".format(sender_id)
            print(msg)
            if websocket is not None:
                await websocket.send(msg)
            return

        await rec_socket.send(message)

    async def handle_realm_response(self, message):
        """Pass a device's realm response back to the python client."""
        j = json.loads(message)
        print("received realm response: {}".format(j))
        python_client = CONNECTIONS.get('python')
        if python_client is None:
            print("python client not connected, dropping realm response")
            return
        # NOTE(review): ``message`` is already JSON text, so json.dumps() wraps
        # it in a JSON string literal. Kept because the client presumably
        # decodes twice; confirm before changing the wire format.
        await python_client.send(json.dumps(message))

    async def handle_connect(self, message, websocket):
        """Register a device's socket and connect payload, then acknowledge it.

        The acknowledgement is an ACTION_CONNECTED request sent to the
        connecting socket through websockets' ``broadcast`` helper.
        """
        j = json.loads(message)
        sender = j["sender"]
        assert sender

        print("device connected: {}{}".format(sender, j))
        DEVICES[sender] = j
        CONNECTIONS[sender] = websocket
        request = {
            "action": ACTION_CONNECTED,
            "data": "connected"
        }
        broadcast({websocket}, message=json.dumps(request))

    async def handle_send(self, message):
        """Instruct the ``sender`` device to send a message to the ``receiver`` device.

        The whole request is wrapped in an ACTION_SEND_MESSAGE envelope and
        delivered to the sender's socket. Nothing is sent when either device is
        not connected.
        """
        j = json.loads(message)
        sender_id = j["sender"]
        receiver_id = j["receiver"]
        assert sender_id
        assert receiver_id

        if receiver_id not in CONNECTIONS:
            print(f"Receiver {receiver_id} not connected make sure the id is correct in the test")
            return
        sender_socket = CONNECTIONS.get(sender_id)
        if sender_socket is None:
            print(f"Sender {sender_id} not connected make sure the id is correct in the test")
            return

        print("Connected devices: {}".format(CONNECTIONS.keys()))
        message_text = j["extras"]
        print(f"Sending message from {devices[sender_id]} to {devices[receiver_id]}: {message_text}")

        request = {
            "action": ACTION_SEND_MESSAGE,
            "data": j
        }
        await sender_socket.send(json.dumps(request))

    async def handle_test_response(self, message):
        """Pass a device's test result back to the python client for validation."""
        j = json.loads(message)
        print("received test response: {}".format(j))
        python_client = CONNECTIONS.get('python')
        if python_client is None:
            print("python client not connected, dropping test response")
            return

        print("python client connected: {}".format(python_client.state == State.OPEN))
        # NOTE(review): double JSON encoding, same wire format as handle_realm_response.
        await python_client.send(json.dumps(message))

    async def handler(self, websocket):
        """Per-connection loop: parse each JSON message and dispatch on its ``action``."""
        async for message in websocket:
            print("received on WSS: {}".format(message))
            action = json.loads(message).get("action")

            if action in ("connect", ACTION_CONNECT):
                await self.handle_connect(message, websocket)
            elif action == ACTION_SEND_MESSAGE:
                # Only the python client sends messages, and it opens a new
                # connection for each test-suite run, so refresh its reference.
                # Other clients must send connect messages further in advance.
                CONNECTIONS['python'] = websocket
                await self.handle_send(message)
            elif action == RESPONSE_TYPE_MESSAGE_TEST:
                await self.handle_test_response(message)
            elif action == ACTION_REALM_CONNECT:
                CONNECTIONS['python'] = websocket
                await self.handle_realm_connect(message)
            elif action == ACTION_REALM_CAPABILITIES:
                CONNECTIONS['python'] = websocket
                await self.handle_realm_request(message, websocket)
            elif action == ACTION_SQL_EXPORT:
                await self.handle_request_broadcast(message)
            elif action in (RESPONSE_TYPE_REALM, RESPONSE_TYPE_REALM_RESULTS):
                await self.handle_realm_response(message)
            elif action == ACTION_REALM_QUERY:
                await self.handle_realm_request(message, websocket)
            elif action in (ACTION_FIRMWARE_UPDATE_START,
                            ACTION_FIRMWARE_UPDATE_PROGRESS,
                            ACTION_FIRMWARE_UPDATE_END):
                await self.handle_fw_update(message)
            else:
                print("unhandled action: {}".format(action))

    async def main(self, ip, port=5000):
        """Serve :meth:`handler` on ``ip``:``port`` until the task is cancelled."""
        async with websockets.serve(self.handler, ip, port):
            print(f"WSS listening {ip} port {port}")
            await asyncio.Future()  # run forever

    def start(self, port=5000):
        """Run the server on the configured ip in a fresh asyncio event loop (blocks)."""
        asyncio.run(self.main(self.ip, port))
