import json
import os
from time import sleep

from celilo.reporting.MessageSqliteSource import MessageSqliteSource
from celilo.reporting.report_output import generate_report

# Suffix appended to a device serial to form its pulled-data directory name,
# and the base name (without ".json") of the per-device state file.
db_suffix, state_file_name = '_atak', 'sw_device_state'


def get_android_devices():
    """Return the serial numbers of the devices listed by ``adb devices``.

    Runs ``adb devices`` through the shell, drops its header line with
    ``tail -n +2`` and keeps only the first tab-separated field of each
    remaining line (``cut -s`` skips lines that contain no tab). The output
    is split on whitespace, so an empty listing yields an empty list.
    """
    cmd = 'adb devices | tail -n +2 | cut -sf 1'
    # Use the pipe as a context manager so it is closed and the shell child
    # is waited on deterministically, rather than left to garbage collection.
    with os.popen(cmd) as pipe:
        return pipe.read().split()


class MessageTransmissionReport:
    """Export, pull and cross-compare per-device message databases.

    ``run`` asks every attached Android device (via adb) to export its
    message database, pulls the exported files with a helper script, then
    compares every ordered pair of pulled databases and prints the report
    produced by ``generate_report``.
    """

    # Package name that selects the stock ATAK scripts/log lines; any other
    # package is treated as the ATAK test app.
    ATAK_CIV_PACKAGE = "com.atakmap.app.civ"

    def __init__(self, app_package, workspace_id):
        """Create a report runner.

        Args:
            app_package: Android package of the app under test; compared with
                ``ATAK_CIV_PACKAGE`` to pick which helper scripts to run.
            workspace_id: Workspace id interpolated into the deep link and
                passed to the restart scripts.
        """
        self.db_suffix = '_atak'
        self.state_file_name = 'sw_device_state'
        # device_id -> parsed JSON state, filled by load_json_state_obj.
        self.device_state = {}
        self.message_sources = {}
        self.app_package = app_package
        self.workspace_id = workspace_id

    def _is_atak_civ(self):
        """Return True when the app under test is the stock ATAK package."""
        return self.app_package == self.ATAK_CIV_PACKAGE

    def trigger_sqlite_export(self):
        """Ask every attached device to export its message database.

        Sends a VIEW intent with the workspace deep link and the extra
        ``cmd=export_sql`` to each device. The intent is the same for both
        app flavours; only the log line differs.
        """
        devices = get_android_devices()
        if self._is_atak_civ():
            print("Testing ATAK")
        for device in devices:
            cmd = (
                f'adb -s {device} shell am start -a android.intent.action.VIEW '
                f'-d somewear://workspace/{self.workspace_id} --es "cmd" "export_sql"'
            )
            print(f"Running: {cmd}")
            os.system(cmd)

    def delete_messages(self):
        """Broadcast a ``DELETE_MESSAGES`` test command to every attached device."""
        devices = get_android_devices()
        for device in devices:
            cmd = f'adb -s {device} shell am broadcast -a com.somewearlabs.ataklibs.TestHarnessBroadcastReceiver.TEST ' \
                  f'--es data "DELETE_MESSAGES"'
            print(f"Running: {cmd}")
            os.system(cmd)

    def pull_sqlite_file(self):
        """Run the pull script for this app package.

        Presumably the script copies each device's database into
        ``reporting/device/`` (what ``analyze_sqlite_file`` reads) — confirm
        against ``../cli/pullAtakSqliteDatabase.sh``.
        """
        cmd = f'sh ../cli/pullAtakSqliteDatabase.sh {self.app_package}'
        os.system(cmd)

    def kill_app(self):
        """Run the kill script matching the app flavour under test."""
        if self._is_atak_civ():
            print("Testing ATAK")
            atak_kill_script = "../cli/atak_kill.sh"
        else:
            print("Testing ATAK TestApp")
            atak_kill_script = "../cli/atak_test_kill.sh"
        os.system(atak_kill_script)

    def restart_app(self):
        """Run the restart script matching the app flavour, passing the workspace id."""
        if self._is_atak_civ():
            print("Testing ATAK")
            atak_restart_script = f"../cli/atak_restart.sh {self.workspace_id}"
        else:
            print("Testing ATAK TestApp")
            atak_restart_script = f"../cli/atakTestRestart.sh {self.workspace_id}"
        os.system(atak_restart_script)

    def compare_dbs(self, start_time, message_source, message_source2):
        """Look up ``message_source``'s messages for the day in ``message_source2``.

        Messages are read from ``message_source`` between ``start_time`` and
        23:59:59 of the same date, then each one is looked up by key in
        ``message_source2``. Both sources are closed before returning, even
        when an error is raised.

        Args:
            start_time: String whose first space-separated token is the date,
                e.g. ``"2025-12-10 15:20:45"``.
            message_source: Source whose messages are checked.
            message_source2: Source searched for those messages.

        Returns:
            A dict with:

            - ``"successes"``: maps a message key to a set of found-in device
              ids. The extra key ``"messages"`` maps message keys to the found
              message objects.
            - ``"failures"``: maps a message key to a set of device ids where
              it was missing.
            - ``"cell_backhaul_success"`` and ``"sat_backhaul_success"``: lists
              of found message keys for which ``message_source2.is_backhaul``
              reported ``"Cellular"`` / ``"Satellite"``.
        """
        results = {
            "successes": {"messages": {}},
            "failures": {},
            "cell_backhaul_success": [],
            "sat_backhaul_success": [],
        }
        try:
            print(f"Comparing databases: start_time: {start_time} device_id: {message_source.device_id} ")
            midnight = start_time.split(" ")[0] + " 23:59:59"
            messages = message_source.get_messages_for_day(start_time, midnight)
            other_device_id = message_source2.device_id

            for message in messages:
                message_key = message["key"]
                parcel_id = message["parcelId"]
                message_found = message_source2.get_message_by_id(message_key)

                if not message_found:
                    results["failures"].setdefault(message_key, set()).add(other_device_id)
                    continue

                if message_source2.is_backhaul(parcel_id, "Cellular"):
                    results["cell_backhaul_success"].append(message_key)
                if message_source2.is_backhaul(parcel_id, "Satellite"):
                    results["sat_backhaul_success"].append(message_key)

                results["successes"].setdefault(message_key, set()).add(other_device_id)
                results["successes"]["messages"][message_key] = message_found
        finally:
            # Close both sources even if a lookup raises, so connections are
            # not leaked.
            message_source.close()
            message_source2.close()
        return results

    def merge_maps(self, dict1, dict2):
        """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

        For keys present in both, ``dict1[key].update(value)`` is called (set
        union or dict update). New keys are stored by reference, not copied.
        """
        for key, value in dict2.items():
            if key in dict1:
                dict1[key].update(value)
            else:
                dict1[key] = value
        return dict1

    def load_json_state_obj(self, device_id, device_state_path):
        """Best-effort load of a device's JSON state into ``self.device_state``.

        Missing, unreadable or malformed files only print a warning; in those
        cases ``self.device_state`` is left unchanged for ``device_id``.
        """
        if not os.path.exists(device_state_path):
            print(f"Warning: Device state file not found for {device_id} at {device_state_path}, skipping state load")
            return

        try:
            # JSON is UTF-8 by spec; do not depend on the platform default.
            with open(device_state_path, encoding="utf-8") as json_file:
                data = json.load(json_file)
                self.device_state[device_id] = data
        except json.JSONDecodeError as e:
            print(f"Warning: Failed to parse JSON for device {device_id} at {device_state_path}: {e}")
        except Exception as e:
            # Deliberately broad: state is optional for the report.
            print(f"Warning: Failed to load device state for {device_id} at {device_state_path}: {e}")

    def _device_paths(self, device_id):
        """Return ``(root_dir, sqlite_path, state_json_path)`` for a pulled device."""
        root = f"reporting/device/{device_id}{self.db_suffix}"
        return root, f"{root}/atak.sqlite", f"{root}/device/{self.state_file_name}.json"

    def _available_devices(self, devices):
        """Return the devices whose pulled directory and database both exist.

        Prints one skip message for each device that is missing either.
        """
        available = []
        for device_id in devices:
            root, db_path, _ = self._device_paths(device_id)
            if not os.path.exists(root):
                print(f"Device {device_id} does not exist at path {root}, skipping")
            elif not os.path.exists(db_path):
                print(f"Database file not found for device {device_id} at {db_path}, skipping")
            else:
                available.append(device_id)
        return available

    @staticmethod
    def _extend_unique(target, items):
        """Append items not already in ``target``, keeping first-seen order; return ``target``."""
        seen = set(target)
        for item in items:
            if item not in seen:
                seen.add(item)
                target.append(item)
        return target

    def analyze_sqlite_file(self, start_time):
        """Cross-compare every pulled device database, then print and return the report.

        A device is used only if its directory and ``atak.sqlite`` file exist.
        Every ordered pair of such devices is compared with ``compare_dbs``;
        this includes a device paired with itself, as before. Per-pair results
        are merged into totals and passed to ``generate_report``.

        Args:
            start_time: Start of the comparison window; see ``compare_dbs``.

        Returns:
            The value returned by ``generate_report``.
        """
        devices = get_android_devices()
        print("Devices found: ", devices)
        available = self._available_devices(devices)

        # Each device's state is loaded once here instead of once per pair.
        for device_id in available:
            _, _, state_path = self._device_paths(device_id)
            self.load_json_state_obj(device_id, state_path)

        total_failures = {}
        total_success = {}
        cell_backhaul_success = []
        sat_backhaul_success = []

        for device_id in available:
            _, path, _ = self._device_paths(device_id)
            for device_id2 in available:
                _, path2, _ = self._device_paths(device_id2)

                message_source = MessageSqliteSource(path, device_id)
                message_source2 = MessageSqliteSource(path2, device_id2)
                results = self.compare_dbs(start_time, message_source, message_source2)

                total_failures = self.merge_maps(total_failures, results["failures"])
                total_success = self.merge_maps(total_success, results["successes"])
                # Accumulate across all pairs. Previously these lists were
                # overwritten on every pair, so only the last pair counted.
                # Keys are de-duplicated to match the per-key totals above.
                self._extend_unique(cell_backhaul_success, results["cell_backhaul_success"])
                self._extend_unique(sat_backhaul_success, results["sat_backhaul_success"])

        report = generate_report(
            total_success,
            total_failures,
            self.device_state,
            cell_backhaul_success,
            sat_backhaul_success
        )
        print(report)
        return report

    def run(self, start_time):
        """Export, wait, pull, analyze, and return the report.

        The fixed 10 s sleep presumably gives devices time to finish the
        export before pulling — TODO confirm the required delay.
        """
        # self.kill_app()
        self.trigger_sqlite_export()
        sleep(10)
        self.pull_sqlite_file()
        report = self.analyze_sqlite_file(start_time=start_time)
        current_directory = os.getcwd()
        print(f"Report directory: {current_directory}/reporting/device")
        print('done')
        return report


if __name__ == "__main__":
    # Alternative target: the ATAK test app.
    # report = MessageTransmissionReport("com.somewearlabs.atak.debug", "49353")
    atak_report = MessageTransmissionReport(
        app_package="com.atakmap.app.civ",
        workspace_id="49353",
    )
    start_time = "2025-12-10 15:20:45"
    atak_report.run(start_time)
