from bisect import insort
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Iterable, Iterator, Set, Union

from nextcloud_tasks_api import TaskList
from nextcloud_tasks_api.ical import Task

from rich.color import Color, ColorParseError
from rich.style import Style
from rich.text import Text
from textual.binding import Binding
from textual.message import Message
from textual.reactive import Reactive, reactive
from textual.widgets import Tree, Select
from textual.widgets.tree import TreeNode


@dataclass()
class TaskTreeNode:
    """Per-task data attached to a node of the task tree.

    Carries only what is needed to render a task's label and to relate it to
    its parent task by UID.
    """

    uid: str  # task UID, used as the lookup key
    parent: Optional[str]  # UID of the parent task, or None for a top-level task

    completed: bool  # whether this task itself is marked completed
    summary: str  # task title
    description: bool  # flag: task has a non-blank description (not the text itself)


@dataclass
class TaskTreeNodeTree(TaskTreeNode):
    """A task node that also owns its child nodes; used to assemble the task hierarchy."""

    children: List['TaskTreeNodeTree'] = field(default_factory=list)

    def __lt__(self, other: 'TaskTreeNodeTree') -> bool:
        """Order case-insensitively by summary; this is the key bisect.insort sorts by."""
        return self.summary.lower() < other.summary.lower()

    def _find_r(self, uid: str) -> Optional['TaskTreeNodeTree']:
        """Return this node or the first descendant (depth-first) with ``uid``, else None."""
        if self.uid == uid:
            return self
        return self._find(self.children, uid)

    @classmethod
    def _find(cls, nodes: List['TaskTreeNodeTree'], uid: str) -> Optional['TaskTreeNodeTree']:
        """Search each of ``nodes`` and its subtree in order; return the first match or None."""
        for child in nodes:
            found: Optional[TaskTreeNodeTree] = child._find_r(uid)
            if found is not None:
                return found
        return None

    @classmethod
    def from_task(cls, ical: Task) -> 'TaskTreeNodeTree':
        """Create a childless node from an iCal task.

        Raises:
            ValueError: if the task has no UID.
        """
        description: Optional[str] = ical.description
        uid: Optional[str] = ical.uid
        if uid is None:
            raise ValueError("Empty task UID")
        return cls(uid=uid, parent=ical.related_to,
                   completed=ical.completed is not None,
                   summary=ical.summary or "",
                   description=bool(description and description.strip()))

    @classmethod
    def _build_flat(cls, flat_nodes: List['TaskTreeNodeTree']) -> None:
        """Nest nodes under their parents, in place.

        A node is moved into its parent's ``children`` when the parent UID exists anywhere
        in the forest. Nodes whose parent is unknown stay in ``flat_nodes`` as roots. So do
        nodes whose parent lies within their own subtree (a self-referencing or circular
        parent relation), since attaching them would form a cycle and drop them from the
        forest. Processing is in list order, so roots and each node's children keep their
        relative (sorted) order.
        """
        i: int = 0
        while i < len(flat_nodes):
            node: TaskTreeNodeTree = flat_nodes[i]
            related_to: Optional[str] = node.parent
            # The guard refuses to hang a node below itself or one of its own descendants.
            if related_to is not None and node._find_r(related_to) is None:
                parent: Optional[TaskTreeNodeTree] = cls._find(flat_nodes, related_to)
                if parent is not None:
                    parent.children.append(node)
                    flat_nodes.pop(i)
                    continue  # the next node has shifted into slot i
            i += 1

    @classmethod
    def build(cls, tasks: Iterable[Task]) -> List['TaskTreeNodeTree']:
        """Build the forest of task nodes, with roots and siblings sorted by summary.

        Tasks that ``from_task`` rejects with ValueError (e.g. missing UID) are skipped.
        """
        flat_nodes: List[TaskTreeNodeTree] = []
        for task in tasks:
            try:
                node: TaskTreeNodeTree = cls.from_task(task)
            except ValueError:
                continue
            insort(flat_nodes, node)
        cls._build_flat(flat_nodes)
        return flat_nodes


class TaskTree(Tree[TaskTreeNode]):
    """Tree widget listing the tasks of one task list, nested by parent task.

    Key bindings and cursor movements are not acted upon here. They are forwarded as
    messages for an enclosing widget/app to handle. A UID -> tree node cache mirrors
    the displayed nodes so tasks can be looked up by UID.
    """

    class QuitAction(Message):
        """Posted by the ``quit`` action (escape)."""

    class _Action(Message):
        """Base message carrying the affected tree node and the UID of its task."""

        def __init__(self, node: Optional[TreeNode[TaskTreeNode]]) -> None:
            self.node: Optional[TreeNode[TaskTreeNode]] = node
            # None when there is no node (e.g. empty tree) or the node carries no task data.
            self.uid: Optional[str] = node.data.uid if node is not None and node.data is not None else None
            super().__init__()

    class HighlightAction(_Action):
        """Posted whenever ``task_file`` is assigned (cursor moved, tree filled or emptied)."""

    class SelectAction(_Action):
        """Posted by the ``invoke_select`` action (enter) for the node under the cursor."""

    class ToggleAction(_Action):
        """Posted by the ``invoke_toggle`` action (space) for the node under the cursor."""

    class CreateAction(_Action):
        """Posted by the ``invoke_insert`` action (insert) for the node under the cursor."""

    class DeleteAction(_Action):
        """Posted by the ``invoke_delete`` action (delete) for the node under the cursor."""

    BINDINGS = [
        Binding("escape", "quit", "Back", show=True),
        Binding("space", "invoke_toggle", "Toggle", show=True),
        Binding("enter", "invoke_select", "Edit", show=True),
        Binding("insert", "invoke_insert", "Create", show=True),
        Binding("delete", "invoke_delete", "Delete", show=True),
    ]

    # Currently highlighted node. always_update: re-assigning the same node still re-posts HighlightAction.
    task_file: Reactive[Optional[TreeNode[TaskTreeNode]]] = reactive(None, always_update=True)

    def __init__(self) -> None:
        super().__init__("Tasks")
        self.root.expand()
        self.show_root = False
        self.show_guides = False
        self.guide_depth = 2

        # UID -> displayed tree node; maintained by update, add_node and _remove_node.
        self._node_cache: Dict[str, TreeNode[TaskTreeNode]] = {}

    @property
    def cursor_uid(self) -> Optional[str]:
        """UID of the task under the cursor, or None if there is none."""
        node: Optional[TreeNode[TaskTreeNode]] = self.cursor_node
        if node is not None and node.data is not None:
            return node.data.uid
        return None

    def action_quit(self) -> None:
        self.post_message(self.QuitAction())

    def action_invoke_toggle(self) -> None:
        self.post_message(self.ToggleAction(self.cursor_node))

    def action_invoke_select(self) -> None:
        self.post_message(self.SelectAction(self.cursor_node))

    def action_invoke_insert(self) -> None:
        self.post_message(self.CreateAction(self.cursor_node))

    def action_invoke_delete(self) -> None:
        self.post_message(self.DeleteAction(self.cursor_node))

    def watch_task_file(self, task: Optional[TreeNode[TaskTreeNode]]) -> None:
        """Announce the newly highlighted node (or None)."""
        self.post_message(self.HighlightAction(task))

    def on_tree_node_highlighted(self, evt: Tree.NodeHighlighted) -> None:
        """Mirror the tree's highlighted node into ``task_file``."""
        self.task_file = evt.node

    def _find_node(self, uid: str) -> TreeNode[TaskTreeNode]:
        """Return the displayed node for ``uid``.

        Raises:
            KeyError: if no node with that UID is in the tree.
        """
        try:
            return self._node_cache[uid]
        except KeyError:
            raise KeyError(uid) from None

    def _get_node_data(self, node: TreeNode[TaskTreeNode]) -> TaskTreeNode:
        """Return the task data of ``node``; raise KeyError if it carries none."""
        if node.data is None:
            raise KeyError(node.label)
        return node.data

    def _find_node_data(self, uid: str) -> TaskTreeNode:
        """Return the task data of the displayed node for ``uid`` (KeyError if unknown)."""
        return self._get_node_data(self._find_node(uid))

    def _is_completed_r(self, task: TaskTreeNode) -> bool:
        """Tell whether ``task`` or any of its displayed ancestors is completed.

        Ancestors are resolved by parent UID through the node cache. The walk stops at a
        parent that is not (or not yet) in the tree, such as an orphaned task's missing
        parent, and at a UID already visited (circular parent references). Either case
        previously looped forever.
        """
        visited: Set[str] = set()
        while not task.completed:
            if task.parent is None or task.uid in visited:
                return False
            visited.add(task.uid)
            try:
                task = self._find_node_data(task.parent)  # parents are inserted before their children
            except KeyError:
                return False  # parent unknown to this tree
        return True

    def _get_label(self, task: TaskTreeNode) -> Text:
        """Render a task's label: check mark or bullet, summary, ``[…]`` if it has a description.

        The label is dimmed when the task or one of its ancestors is completed.
        """
        prefix: str = " ✓ " if task.completed else " • "
        label: str = task.summary.strip()
        suffix: str = " […] " if task.description else " "
        return Text("".join((prefix, label, suffix)), style="dim" if self._is_completed_r(task) else "")

    def _insert_r(self, root: TreeNode[TaskTreeNode], node: TaskTreeNodeTree) -> None:
        """Add ``node`` below ``root`` (expanded), register it, then recurse into its children."""
        tree_node: TreeNode[TaskTreeNode] = root.add_leaf(self._get_label(node), node)  # add(expand=False)
        tree_node.expand()
        self._node_cache[node.uid] = tree_node
        for child in node.children:
            self._insert_r(tree_node, child)

    def update(self, tree: List[TaskTreeNodeTree]) -> None:
        """Replace the whole tree with ``tree`` and highlight the first top-level task, if any."""
        self.clear()
        self._node_cache.clear()

        for node in tree:
            self._insert_r(self.root, node)

        if self.root.children:
            self.cursor_line = 0
            self.task_file = self.root.children[0]
        else:
            self.cursor_line = -1
            self.task_file = None

    def _find_parent(self, node: TaskTreeNode) -> TreeNode[TaskTreeNode]:
        """Return the displayed parent of ``node``, or the root for top-level tasks (KeyError if unknown)."""
        return self._find_node(node.parent) if node.parent is not None else self.root

    def _update_node(self, tree_node: TreeNode[TaskTreeNode]) -> None:
        """Re-render the label of ``tree_node`` and all its descendants."""
        tree_node.label = self._get_label(self._get_node_data(tree_node))
        for child in tree_node.children:
            self._update_node(child)  # due to recursive completed label

    def get_node(self, uid: str) -> TaskTreeNode:
        """Return the task data displayed for ``uid`` (KeyError if unknown)."""
        return self._find_node_data(uid)

    def update_node(self, node: TaskTreeNode) -> None:
        """Replace the data of the displayed task ``node.uid`` and refresh its subtree's labels."""
        tree_node: TreeNode[TaskTreeNode] = self._find_node(node.uid)
        tree_node.data = node
        self._update_node(tree_node)

    def add_node(self, node: TaskTreeNode) -> None:
        """Append ``node`` under its displayed parent (or the root).

        Raises:
            KeyError: if a node with the same UID is already shown, or its parent is unknown.
        """
        # TODO: can somehow insert at first or proper position?
        if node.uid in self._node_cache:
            raise KeyError(node.uid)
        parent: TreeNode[TaskTreeNode] = self._find_parent(node)
        tree_node: TreeNode[TaskTreeNode] = parent.add_leaf(self._get_label(node), node)
        tree_node.expand()
        self._node_cache[node.uid] = tree_node
        if parent is self.root and len(self.root.children) == 1:  # no highlight message when previously none
            self.cursor_line = 0
            self.task_file = tree_node

    def _dfs_tree_r(self, tree_node: TreeNode[TaskTreeNode]) -> Iterator[List[TaskTreeNode]]:
        """Yield the children lists of ``tree_node``'s subtree, deepest levels first."""
        for child in tree_node.children:
            yield from self._dfs_tree_r(child)
        if tree_node.children:
            yield [self._get_node_data(child) for child in tree_node.children]

    def dfs_tree(self, node: TaskTreeNode) -> Iterator[List[TaskTreeNode]]:
        """Depth-first list for orderly deletion.

        Yields groups of sibling task data such that every task comes after all of its
        descendants; ``node`` itself is yielded last.
        """
        tree_node: TreeNode[TaskTreeNode] = self._find_node(node.uid)
        yield from self._dfs_tree_r(tree_node)
        yield [self._get_node_data(tree_node)]

    def _remove_node(self, node: TaskTreeNode) -> None:
        """Remove the displayed node for ``node.uid`` from the tree and the cache."""
        tree_node: TreeNode[TaskTreeNode] = self._find_node(node.uid)
        tree_node.remove()  # NB: is recursive
        # Only this UID leaves the cache; presumably callers remove descendants first (dfs_tree order).
        del self._node_cache[node.uid]

    def remove_nodes(self, nodes: List[TaskTreeNode]) -> None:
        """Remove ``nodes`` and re-sync the cursor highlight.

        When the tree ends up empty, the cursor and ``task_file`` are cleared wherever the
        cursor was before. Otherwise the previous cursor line is re-applied.
        """
        line: int = self.cursor_line  # watch_cursor_line seems to not update cursor_node for deletions
        for node in nodes:
            self._remove_node(node)
        if not self.root.children:
            self.cursor_line = -1
            self.task_file = None
        else:
            # Bounce through -1 so re-assigning the previous line registers as a change.
            self.cursor_line = -1
            self.cursor_line = line


class TaskListSelect(Select[TaskList]):
    """Drop-down for choosing a task list; announces every assignment via ``Updated``."""

    class Updated(Message):
        """Posted whenever ``task_list`` is assigned; ``task_list`` is None when nothing is selected."""

        def __init__(self, task_list: Optional[TaskList]) -> None:
            self.task_list: Optional[TaskList] = task_list
            super().__init__()

    # always_update: re-assigning the same list still posts Updated.
    task_list: Reactive[Optional[TaskList]] = reactive(None, always_update=True)

    @classmethod
    def _get_prompt(cls, option: TaskList) -> Union[str, Text]:
        """Return the bullet shown before a list's name, colored with the list color if possible.

        An uncolored bullet is returned when the list has no color or Rich cannot parse it.
        """
        bullet: str = "⚫"
        if option.color is None:
            return bullet
        spec: str = option.color.strip()
        # Some CalDAV clients store '#RRGGBBAA'; Rich only parses '#RRGGBB', so drop the alpha part.
        if len(spec) == 9 and spec.startswith("#"):
            spec = spec[:7]
        try:
            color: Color = Color.parse(spec)
        except ColorParseError:
            return bullet  # unrecognized color syntax: show the list uncolored instead of crashing
        return Text(bullet, style=Style(color=color))

    def update(self, options: List[TaskList], preselect: Optional[str]) -> None:
        """Replace the options and select the first list named ``preselect``.

        If ``preselect`` is None or matches no list, the selection is cleared.
        """
        self.set_options([(Text.assemble(self._get_prompt(task_list), task_list.name or ""), task_list)
                          for task_list in options])
        if preselect is not None:
            for option in options:
                if option.name == preselect:
                    self.value = option
                    self.task_list = option
                    return
        self.value = Select.BLANK
        self.task_list = None

    def on_select_changed(self, evt: Select.Changed) -> None:
        """Mirror the user's choice into ``task_list`` (None for the blank entry)."""
        self.task_list = evt.value if isinstance(evt.value, TaskList) else None

    def watch_task_list(self, task_list: Optional[TaskList]) -> None:
        """Announce the current task list."""
        self.post_message(self.Updated(task_list))