from datetime import datetime

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.dates as mdates
import pytz

# Load CSV file
# file_path = "connection_life_cycle.csv"


def get_latency(self, row):
    """Split a PLI latency message of the form ``"<parcelId>|<latency>"``.

    ``self`` is not used; it is kept because the caller passes it positionally.

    Returns:
        tuple: ``(parcel_id, latency)`` as the raw strings found in the
        message (no numeric conversion is done here).

    Raises:
        ValueError: if the message does not contain exactly one ``|``.
    """
    fields = row.split("|")
    # Two-name unpacking enforces exactly two fields (ValueError otherwise).
    parcel_id, latency_text = fields
    return parcel_id, latency_text


class StreamReport:
    """Build a stream-health report from a connection life-cycle CSV log.

    The CSV must provide ``timestamp`` (epoch milliseconds, see the
    ``unit="ms"`` conversion in ``__init__``), ``event`` and ``message``
    columns. :meth:`plot` processes every row, saves a PNG chart and returns
    summary statistics.
    """

    # Timezone used for the ``Date`` column and the x-axis tick labels.
    TIMEZONE = "America/Los_Angeles"
    # Fixed y-position (primary axis) at which connect/disconnect markers are drawn.
    EVENT_MARKER_Y = 10

    def __init__(self, path, device):
        """Load the CSV at ``path`` and initialise the per-event collectors.

        Args:
            path: CSV file path passed to ``pd.read_csv``.
            device: Device identifier; used to name the saved PNG.
        """
        self.settings_payload = None
        self.info_command_payload = None
        self.device = device
        self.file_path = path
        self.df = pd.read_csv(self.file_path)
        self.df["Date"] = pd.to_datetime(self.df["timestamp"], unit="ms")
        self.df["Date"] = self.df["Date"].dt.tz_localize("UTC").dt.tz_convert(self.TIMEZONE)
        # Earliest timestamp in the log; the *_per_second rates are measured from here.
        self.startTime = self.df["timestamp"].min()
        # Collectors of (Date, value) tuples filled by process_event().
        self.rssi_data = []
        self.connect_events = []
        self.commands_per_second_ary = []
        self.outbound_pli_creates = []
        self.pli_updates_per_second = []
        self.message_latency_events = []
        self.pli_latency_events = []
        self.disconnect_events = []
        # Running counters for the rate computations ("commndCount" spelling
        # kept for backward compatibility with external readers).
        self.commndCount = 0
        self.pli_updates = 0
        self.outboundCreates = 0
        self.commands_per_second = 0
        # Timestamp of the most recent event of each kind; None until the first one.
        self.last_command_timestamp = None
        self.last_pli_update_timestamp = None
        # Kept separate from last_pli_update_timestamp so PLI creates and PLI
        # updates each track their own first event.
        self.last_pli_create_timestamp = None

    def _running_rate(self, timestamp, last_attr, count_attr):
        """Return ``count / seconds since self.startTime`` for one event kind.

        The first event of a kind only records its timestamp and returns 0
        (it is not counted). Later events increment the counter named by
        ``count_attr``. Returns 0 when no time has elapsed since the start.

        Args:
            timestamp: Event time in epoch milliseconds.
            last_attr: Name of the attribute holding the last timestamp of this kind.
            count_attr: Name of the attribute holding this kind's counter.
        """
        is_first = getattr(self, last_attr) is None
        setattr(self, last_attr, timestamp)
        if is_first:
            return 0
        seconds = (timestamp - self.startTime) / 1000
        if seconds == 0:
            return 0
        count = getattr(self, count_attr) + 1
        setattr(self, count_attr, count)
        return count / seconds

    def compute_outbound_pli_creates_per_second(self, timestamp):
        """Return the running OUTBOUND_PLI_CREATE rate (events per second)."""
        return self._running_rate(timestamp, "last_pli_create_timestamp", "outboundCreates")

    def compute_outbound_pli_updates_per_second(self, timestamp):
        """Return the running OUTBOUND_PLI_UPDATE rate (events per second)."""
        return self._running_rate(timestamp, "last_pli_update_timestamp", "pli_updates")

    def compute_commands_per_second(self, timestamp):
        """Return the running INBOUND_COMMAND rate (events per second)."""
        return self._running_rate(timestamp, "last_command_timestamp", "commndCount")

    def process_event(self, row):
        """Route one CSV row to the collector matching its ``event`` name.

        Matching is by substring, so the order of checks matters. For example,
        ``DISCONNECT`` must be tested before ``CONNECTED`` because
        "DISCONNECTED" contains "CONNECTED".
        """
        event = row["event"]
        if not isinstance(event, str):
            # pandas reads blank event cells as NaN (a float); `in` would raise on it.
            return
        date = row["Date"]
        if "BLE_RSSI" in event:
            try:
                rssi = abs(int(row["message"]))
            except (TypeError, ValueError, OverflowError):
                # Best-effort: skip RSSI samples whose message is not an integer.
                return
            self.rssi_data.append((date, rssi))
        elif "DISCONNECT" in event:
            self.disconnect_events.append((date, True))
        elif "CONNECTED" in event:
            self.connect_events.append((date, True))
        elif "INBOUND_COMMAND" in event:
            rate = self.compute_commands_per_second(row["timestamp"])
            self.commands_per_second_ary.append((date, rate))
        elif "OUTBOUND_PLI_CREATE" in event:
            rate = self.compute_outbound_pli_creates_per_second(row["timestamp"])
            self.outbound_pli_creates.append((date, rate))
        elif "OUTBOUND_PLI_UPDATE" in event:
            rate = self.compute_outbound_pli_updates_per_second(row["timestamp"])
            self.pli_updates_per_second.append((date, rate))
        elif "MESSAGE_LATENCY" in event:
            # Divided by 1000; the chart labels this as seconds, so the raw
            # message is presumably milliseconds.
            self.message_latency_events.append((date, float(row["message"]) / 1000))
        elif "SETTINGS" in event:
            self.settings_payload = row["message"]
        elif "INFO_COMMAND" in event:
            self.info_command_payload = row["message"]
        elif "PLI_LATENCY" in event:
            _, latency_text = get_latency(self, row["message"])
            pli_latency = float(latency_text)
            # Only positive latencies are kept.
            if pli_latency > 0:
                self.pli_latency_events.append((date, pli_latency / 1000))

    def plot(self):
        """Process every row, render the report chart and save it as a PNG.

        Not idempotent: each call appends every row to the collectors again.

        Returns:
            tuple: ``(file_name, settings_payload, info_command_payload, event_stats)``,
            where ``event_stats`` is a dict of event counts and average latencies.
        """
        print("rows in df:", len(self.df))
        for _, row in self.df.iterrows():
            self.process_event(row)

        def dates_of(events):
            return [d.tz_convert(self.TIMEZONE) for d, _ in events]

        def values_of(events):
            return [v for _, v in events]

        fig, ax2 = plt.subplots(figsize=(16, 8))
        ax2.set_xlabel("Time")

        print("pli_latency_events:", len(self.pli_latency_events))

        ax2.set_ylim(0, 150)

        print(f"connect events {len(self.connect_events)}")
        print(f"disconnect events {len(self.disconnect_events)}")

        # Connect/disconnect markers are drawn at a fixed height on the primary axis.
        ax2.scatter(dates_of(self.connect_events),
                    [self.EVENT_MARKER_Y] * len(self.connect_events),
                    color="green", label="Connected", marker="o", s=100)
        ax2.scatter(dates_of(self.disconnect_events),
                    [self.EVENT_MARKER_Y] * len(self.disconnect_events),
                    color="red", label="Disconnected", marker="x", s=100)

        # Each series gets its own y-axis, stacked outward on the right.
        ax3 = ax2.twinx()
        ax3.spines['right'].set_position(('outward', 60))
        ax3.set_ylabel("Outbound PLI", color="purple")
        ax3.plot(dates_of(self.outbound_pli_creates), values_of(self.outbound_pli_creates),
                 label="Outbound PLI", linestyle="-", color="purple")
        ax3.tick_params(axis='y', labelcolor="purple")
        ax3.set_ylim(0, 50)

        ax4 = ax2.twinx()
        ax4.spines['right'].set_position(('outward', 120))
        ax4.set_ylabel("PLI Updates Per Second", color="blue")
        ax4.plot(dates_of(self.pli_updates_per_second), values_of(self.pli_updates_per_second),
                 label="PLI Updates Per Second", linestyle="-", color="blue")
        ax4.tick_params(axis='y', labelcolor="blue")
        ax4.set_ylim(0, 50)

        ax5 = ax2.twinx()
        ax5.spines['right'].set_position(('outward', 180))
        ax5.set_ylabel("Inbound Messages", color="red")
        ax5.scatter(dates_of(self.message_latency_events), values_of(self.message_latency_events),
                    color="red", label="Inbound Messages (seconds)", marker="o", s=10)

        ax6 = ax2.twinx()
        ax6.spines['right'].set_position(('outward', 240))
        ax6.set_ylabel("Inbound PLI", color="black")
        ax6.scatter(dates_of(self.pli_latency_events), values_of(self.pli_latency_events),
                    color="black", label="Inbound PLI (seconds)", marker="o", s=10)

        average_message_latency = (np.mean(values_of(self.message_latency_events))
                                   if self.message_latency_events else 0)
        average_pli_latency = (np.mean(values_of(self.pli_latency_events))
                               if self.pli_latency_events else 0)

        if average_message_latency > 0:
            ax5.axhline(y=average_message_latency, color='orange', linestyle='--', linewidth=2.5,
                        label=f'Avg Msg Latency: {average_message_latency:.2f}s')
        if average_pli_latency > 0:
            ax6.axhline(y=average_pli_latency, color='cyan', linestyle='--', linewidth=2.5,
                        label=f'Avg PLI Latency: {average_pli_latency:.2f}s')

        pst_tz = pytz.timezone(self.TIMEZONE)
        ax2.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M:%S', tz=pst_tz))
        ax2.xaxis.set_major_locator(mdates.AutoDateLocator())
        # Twin axes hide their x tick labels, so rotate ax2's via the figure.
        fig.autofmt_xdate(rotation=45)

        # Merge legend entries from every axis; ax2's own legend would only list
        # the connect/disconnect markers. Drawn on the top-most axis (ax6) so the
        # twin-axis artists do not cover it.
        handles, labels = [], []
        for ax in (ax2, ax3, ax4, ax5, ax6):
            ax_handles, ax_labels = ax.get_legend_handles_labels()
            handles.extend(ax_handles)
            labels.extend(ax_labels)
        ax6.legend(handles, labels, loc="upper left")

        ax2.set_title("ATAK Stream Report")
        ax2.grid(True)
        fig.tight_layout()

        file_name = f"{self.device}_stream_report_.png"
        fig.savefig(file_name)
        # Release the figure; pyplot otherwise keeps every figure alive.
        plt.close(fig)

        event_stats = {
            "connect_events": len(self.connect_events),
            "disconnect_events": len(self.disconnect_events),
            "outbound_pli_create_events": len(self.outbound_pli_creates),
            # Count every recorded update; self.pli_updates skips the first one.
            "pli_db_update_events": len(self.pli_updates_per_second),
            "message_latency_events": len(self.message_latency_events),
            "pli_latency_events": len(self.pli_latency_events),
            "average_message_latency": average_message_latency,
            "average_pli_latency": average_pli_latency,
        }
        return file_name, self.settings_payload, self.info_command_payload, event_stats

# Create a StreamReport object
# report = StreamReport("connection_life_cycle_1763701765399.csv", "device_123")
# data = report.plot()
# print(data)
