import sqlite3
import sqlite3
from typing import List, Optional


class MessageSqliteSource:
    """Accessor for the ``Message`` table of a SQLite database.

    The connection is opened in ``__init__`` and stays open until ``close()``
    is called. Instances can also be used as context managers, in which case
    the connection is closed when the ``with`` block exits.
    """

    def __init__(self, db_path: str, device_id):
        """Open the SQLite database at *db_path*.

        Args:
            db_path: Path handed to ``sqlite3.connect``.
            device_id: Stored on the instance as ``device_id``; no query in
                this class reads it.
        """
        self.db_path = db_path
        self.device_id = device_id
        self.connection = sqlite3.connect(db_path)
        # sqlite3.Row allows each row to be converted to a dict keyed by
        # column name.
        self.connection.row_factory = sqlite3.Row
        self.cursor = self.connection.cursor()

    def __enter__(self) -> "MessageSqliteSource":
        """Return the instance so it can be bound by ``with ... as``."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Close the connection; exceptions from the block propagate."""
        self.close()

    def _fetch_all(self, sql: str, params: tuple = ()) -> List[dict]:
        """Run *sql* with bound *params* and return every row as a dict."""
        self.cursor.execute(sql, params)
        return [dict(row) for row in self.cursor.fetchall()]

    def get_message_by_id(self, key: str) -> Optional[dict]:
        """Return the message whose ``key`` column equals *key*.

        Returns:
            The first matching row as a dict, or ``None`` when nothing
            matches.
        """
        self.cursor.execute("SELECT * FROM Message WHERE key = ?", (key,))
        row = self.cursor.fetchone()
        return dict(row) if row else None

    def is_backhaul(self, parcelId: str, channel: str) -> bool:
        """Report whether a parcel was delivered on *channel* via a hotspot.

        A row matches when, inside the ``parcelStatuses`` JSON document,
        ``<channel>.hotspotId.userId`` is not 0 and
        ``<channel>.channelStatus`` equals ``'Delivered'``.

        Args:
            parcelId: Value compared against the ``parcelId`` column.
            channel: Top-level key of ``parcelStatuses`` to inspect.

        Returns:
            ``True`` if at least one row matches, otherwise ``False``.
        """
        # The JSON paths are bound as parameters instead of being formatted
        # into the SQL text, so a channel containing quotes cannot break or
        # alter the statement.
        channel_path = f"$.{channel}"
        query = """
        SELECT 1
        FROM Message
        WHERE parcelId = ?
          AND json_extract(parcelStatuses, ?) != 0
          AND json_extract(parcelStatuses, ?) = 'Delivered'
        LIMIT 1
        """
        self.cursor.execute(
            query,
            (
                parcelId,
                f"{channel_path}.hotspotId.userId",
                f"{channel_path}.channelStatus",
            ),
        )
        return self.cursor.fetchone() is not None

    def get_messages_by_conversation(self, conversation_id: str) -> List[dict]:
        """Return all messages whose ``conversation`` equals *conversation_id*."""
        return self._fetch_all(
            "SELECT * FROM Message WHERE conversation = ?", (conversation_id,)
        )

    def get_unread_messages(self) -> List[dict]:
        """Return all unread messages (rows where ``readDate`` is NULL)."""
        return self._fetch_all("SELECT * FROM Message WHERE readDate IS NULL")

    def get_all_messages(self) -> List[dict]:
        """Return every row of the ``Message`` table."""
        return self._fetch_all("SELECT * FROM Message")

    def get_messages_for_day(self, start_date: str, end_date: str) -> List[dict]:
        """Return messages with ``dateSent`` between the two bounds, inclusive.

        Args:
            start_date: Lower bound compared with ``dateSent >= ?``.
            end_date: Upper bound compared with ``dateSent <= ?``.
        """
        return self._fetch_all(
            "SELECT * FROM Message WHERE dateSent >= ? AND dateSent <= ?",
            (start_date, end_date),
        )

    def close(self):
        """Close the database connection."""
        self.connection.close()


# main
if __name__ == "__main__":
    # NOTE(review): the path and device id look like placeholder values —
    # confirm before running.
    path = "path/to/atak.sqlite"
    device_id = "device123"
    message_source = MessageSqliteSource(path, device_id)
    try:
        # Print every message in the table, one row per line.
        for message in message_source.get_all_messages():
            print(message)
    finally:
        # Close the connection even if reading or printing fails.
        message_source.close()
