"""
Git Graph visualization panel for Paplix Version Control Plugin.

Provides a visual representation of git commit history with branch lines,
merge points, and commit details similar to VS Code's Git Graph extension.
"""

import wx
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass

from .git_utils import GitUtils, GraphCommit, BranchInfo


# Color palette for branch rails. Rail N is drawn with entry
# N % len(RAIL_COLORS), so the palette cycles once rails outnumber colors.
RAIL_COLORS = [
    wx.Colour(red, green, blue)
    for red, green, blue in (
        (34, 197, 94),    # Green
        (59, 130, 246),   # Blue
        (249, 115, 22),   # Orange
        (168, 85, 247),   # Purple
        (6, 182, 212),    # Cyan
        (236, 72, 153),   # Pink
        (234, 179, 8),    # Yellow
        (239, 68, 68),    # Red
    )
]


@dataclass
class RenderCommit:
    """
    Commit with rendering information.

    Attributes:
        commit: The underlying commit record.
        rail: Column index the commit's node is drawn on.
        color: Fill color of the commit's node.
        merge_lines: ``(parent_hash, parent_rail, commit_rail)`` entries that
            describe merge connections from secondary parents to this commit.
    """
    commit: GraphCommit
    rail: int  # Which column/rail this commit is on
    color: wx.Colour
    # (parent_hash, parent_rail, commit_rail) per drawn merge connection;
    # the hash lets the painter look up the parent's row once it is loaded.
    merge_lines: List[Tuple[str, int, int]]


class GitGraphPanel(wx.ScrolledWindow):
    """
    Scrollable panel that displays git commit history as a graph.

    Features:
    - Colored branch rails (vertical lines)
    - Commit nodes as circles with borders
    - Smooth eased curves for branch and merge connections
    - Row hover highlighting
    - Commit info (hash, message, author, time)
    - Lazy loading on scroll

    Rows are drawn in the order ``GitUtils.get_graph_log`` returns commits;
    the layout assumes a commit's parents appear in later rows.
    """

    # Layout constants
    ROW_HEIGHT = 32
    NODE_RADIUS = 6
    RAIL_SPACING = 18
    LEFT_MARGIN = 16
    MIN_GRAPH_WIDTH = 60

    # Loading
    INITIAL_LOAD = 30
    LOAD_MORE = 20
    # Keep at least this many loaded rows below the last visible row.
    LOAD_AHEAD_ROWS = 5

    # Colors
    BG_COLOR = wx.Colour(30, 30, 32)
    HOVER_COLOR = wx.Colour(45, 45, 50)
    SEPARATOR_COLOR = wx.Colour(50, 50, 55)
    HASH_COLOR = wx.Colour(120, 120, 120)
    MESSAGE_COLOR = wx.Colour(220, 220, 220)
    META_COLOR = wx.Colour(100, 100, 100)

    def __init__(self, parent, git: GitUtils):
        super().__init__(parent, style=wx.VSCROLL | wx.ALWAYS_SHOW_SB)

        self.git = git
        self.commits: List[RenderCommit] = []
        self.all_loaded = False
        self._loading = False
        self._hovered_row = -1
        self._last_head_hash: Optional[str] = None
        # Wheel rotation not yet converted into whole rows. High-resolution
        # wheels/trackpads report rotations smaller than one wheel delta.
        self._wheel_rotation = 0

        # Rail tracking
        self._reserved_rails: Dict[str, int] = {}    # Parent hash -> reserved rail
        self._commit_rails: Dict[str, int] = {}      # Commit hash -> assigned rail
        self._commit_row_index: Dict[str, int] = {}  # Commit hash -> row index
        self._rail_occupancy: List[bool] = []        # Rail index -> currently reserved
        self._max_rail = 0                           # Highest rail index ever assigned

        # Fonts are built once here instead of on every paint.
        self._empty_font = wx.Font(10, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_ITALIC, wx.FONTWEIGHT_NORMAL)
        self._hash_font = wx.Font(9, wx.FONTFAMILY_MODERN, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)
        self._message_font = wx.Font(9, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)
        self._meta_font = wx.Font(8, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_NORMAL)
        self._badge_font = wx.Font(8, wx.FONTFAMILY_DEFAULT, wx.FONTSTYLE_NORMAL, wx.FONTWEIGHT_BOLD)

        # One vertical scroll unit == one row.
        self.SetScrollRate(0, self.ROW_HEIGHT)
        # AutoBufferedPaintDC expects BG_STYLE_PAINT. _on_paint clears the
        # background itself, so this also avoids erase-then-paint flicker.
        self.SetBackgroundStyle(wx.BG_STYLE_PAINT)
        self.SetBackgroundColour(self.BG_COLOR)

        # Bind events
        self.Bind(wx.EVT_PAINT, self._on_paint)
        self.Bind(wx.EVT_SIZE, self._on_size)
        self.Bind(wx.EVT_SCROLLWIN, self._on_scroll)
        self.Bind(wx.EVT_MOUSEWHEEL, self._on_mousewheel)
        self.Bind(wx.EVT_MOTION, self._on_motion)
        self.Bind(wx.EVT_LEAVE_WINDOW, self._on_leave)

        # Initial load
        self.refresh()

    def refresh(self):
        """Discard all loaded rows and rail state, then reload history from the top."""
        self.commits = []
        self._reserved_rails = {}
        self._commit_rails = {}
        self._commit_row_index = {}
        self._rail_occupancy = []
        self._max_rail = 0
        self._last_head_hash = None
        self.all_loaded = False
        self._load_commits(self.INITIAL_LOAD)
        self._update_virtual_size()
        self.Scroll(0, 0)
        # A window taller than INITIAL_LOAD rows cannot scroll, so no scroll
        # event would ever request the next page; fill the viewport now.
        self._check_load_more()
        self.Refresh()
        # Cache the newest row's hash for incremental updates.
        if self.commits:
            self._last_head_hash = self.commits[0].commit.hash

    def check_for_new_commits(self) -> bool:
        """
        Quick check whether HEAD moved since the last refresh.

        Returns:
            True if refresh is needed. That is the case when no rows are
            loaded, when HEAD cannot be read, or when HEAD differs from the
            hash cached for the first loaded row.
        """
        last_head = self._last_head_hash
        if not self.commits or not last_head:
            return True

        current_head = self.git.get_head_hash_quick()
        if not current_head:
            return True

        # Either side may be an abbreviated hash, so a prefix match in either
        # direction means both name the same commit.
        # NOTE(review): assumes the first row from get_graph_log is HEAD.
        same_commit = last_head.startswith(current_head) or current_head.startswith(last_head)
        return not same_commit

    def refresh_incremental(self) -> bool:
        """
        Incremental refresh - only reload if new commits detected.

        Returns:
            True if refresh was performed
        """
        if self.check_for_new_commits():
            self.refresh()
            return True
        return False

    def _get_graph_width(self) -> int:
        """Return the pixel width of the graph column; the text columns start here."""
        return max(self.MIN_GRAPH_WIDTH,
                   self.LEFT_MARGIN + (self._max_rail + 1) * self.RAIL_SPACING + 20)

    def _load_commits(self, count: int):
        """
        Fetch up to ``count`` more commits and append them as rows.

        Sets ``all_loaded`` once git returns fewer than ``count`` commits.
        ``_loading`` prevents a nested call from starting a second fetch while
        one is in progress.
        """
        if self.all_loaded or self._loading:
            return

        self._loading = True
        try:
            new_commits = self.git.get_graph_log(limit=count, skip=len(self.commits))

            if not new_commits:
                self.all_loaded = True
                return

            if len(new_commits) < count:
                self.all_loaded = True

            for commit in new_commits:
                row_index = len(self.commits)
                self.commits.append(self._process_commit(commit))
                # Row lookup used when drawing links to parents.
                self._commit_row_index[commit.hash] = row_index

            self._update_virtual_size()
        finally:
            self._loading = False

    def _process_commit(self, commit: GraphCommit) -> RenderCommit:
        """Assign ``commit`` to a rail and record connections to its secondary parents."""
        rail = self._find_rail_for_commit(commit)
        color = RAIL_COLORS[rail % len(RAIL_COLORS)]

        self._commit_rails[commit.hash] = rail
        # Track the widest graph so the text column starts past every rail,
        # whichever allocator produced this rail.
        self._max_rail = max(self._max_rail, rail)

        # One (parent_hash, parent_rail, commit_rail) entry per secondary parent;
        # the parent's row is resolved at paint time because it may not be loaded yet.
        merge_lines = []
        for parent_hash in commit.parents[1:]:
            if parent_hash in self._commit_rails:
                parent_rail = self._commit_rails[parent_hash]
            elif parent_hash in self._reserved_rails:
                parent_rail = self._reserved_rails[parent_hash]
            else:
                # Parent not yet processed: reserve a rail for it.
                parent_rail = self._get_next_rail()
                self._reserved_rails[parent_hash] = parent_rail
            # Record the edge even when the parent landed on this commit's
            # (just released) rail; it is then drawn as a straight line.
            merge_lines.append((parent_hash, parent_rail, rail))

        return RenderCommit(
            commit=commit,
            rail=rail,
            color=color,
            merge_lines=merge_lines,
        )

    def _find_rail_for_commit(self, commit: GraphCommit) -> int:
        """
        Pick the rail for ``commit`` and update rail reservations.

        The commit takes the rail an already-placed child reserved for it, or
        the lowest free rail if there is no reservation. The rail then passes
        to the commit's first parent when that parent has no reservation yet.
        Otherwise the rail is released for reuse: this happens for a root
        commit, or when the first parent is already reserved or placed.
        """
        if commit.hash in self._reserved_rails:
            rail = self._reserved_rails.pop(commit.hash)
        else:
            rail = self._get_first_available_rail()

        continues_to_parent = False
        if commit.parents:
            first_parent = commit.parents[0]
            # A parent that was already placed would never consume a new
            # reservation, so reserving for it would leak the rail.
            if (first_parent not in self._reserved_rails
                    and first_parent not in self._commit_rails):
                self._reserved_rails[first_parent] = rail
                continues_to_parent = True

        while len(self._rail_occupancy) <= rail:
            self._rail_occupancy.append(False)
        # Occupied exactly while a pending reservation lives on this rail.
        self._rail_occupancy[rail] = continues_to_parent
        return rail

    def _get_first_available_rail(self) -> int:
        """Claim and return the lowest free rail (same policy as ``_get_next_rail``)."""
        return self._get_next_rail()

    def _get_next_rail(self) -> int:
        """Claim the lowest free rail, appending a new one if all are occupied."""
        for i, occupied in enumerate(self._rail_occupancy):
            if not occupied:
                self._rail_occupancy[i] = True
                self._max_rail = max(self._max_rail, i)
                return i

        self._rail_occupancy.append(True)
        self._max_rail = max(self._max_rail, len(self._rail_occupancy) - 1)
        return len(self._rail_occupancy) - 1

    def _update_virtual_size(self):
        """Size the scrollable area to the loaded rows plus a small bottom margin."""
        height = len(self.commits) * self.ROW_HEIGHT + 20
        width = self.GetClientSize().width
        self.SetVirtualSize((width, height))

    def _on_size(self, event):
        self._update_virtual_size()
        # Growing the window may expose rows that are not loaded yet.
        self._check_load_more()
        self.Refresh()
        event.Skip()

    def _on_scroll(self, event):
        # EVT_SCROLLWIN reaches this handler before wx's scroll helper applies
        # the new position, so check once the view has actually moved.
        wx.CallAfter(self._deferred_check_load_more)
        event.Skip()

    def _deferred_check_load_more(self):
        """Run ``_check_load_more`` from ``wx.CallAfter`` if the window still exists."""
        if self:  # the C++ window may be destroyed before CallAfter runs
            self._check_load_more()

    def _on_mousewheel(self, event):
        """Scroll one row per wheel notch, accumulating sub-notch rotation."""
        if event.GetWheelAxis() == wx.MOUSE_WHEEL_HORIZONTAL:
            # The panel has no horizontal scrolling; don't scroll vertically
            # on sideways swipes.
            event.Skip()
            return

        delta = event.GetWheelDelta() or 120  # guard against a zero delta
        self._wheel_rotation += event.GetWheelRotation()
        # int() truncates toward zero, so partial rotation in either direction
        # is carried over instead of being floored to -1 row.
        rows = int(self._wheel_rotation / delta)
        if rows == 0:
            return
        self._wheel_rotation -= rows * delta

        current_y = self.GetViewStart()[1]
        self.Scroll(0, max(0, current_y - rows))

        self._check_load_more()

    def _on_motion(self, event):
        """Track which row is under the mouse for hover highlighting."""
        logical_y = self.CalcUnscrolledPosition(event.GetPosition()).y
        new_hovered = logical_y // self.ROW_HEIGHT

        if new_hovered != self._hovered_row:
            self._hovered_row = new_hovered
            self.Refresh()
        event.Skip()

    def _on_leave(self, event):
        """Clear the hover highlight when the mouse leaves the panel."""
        if self._hovered_row != -1:
            self._hovered_row = -1
            self.Refresh()
        event.Skip()

    def _check_load_more(self):
        """
        Load pages until at least ``LOAD_AHEAD_ROWS`` rows beyond the viewport are loaded.

        Stops when everything is loaded or a load makes no progress.
        """
        loaded_any = False
        while not self.all_loaded:
            view_start = self.GetViewStart()[1]
            visible_rows = self.GetClientSize().height // self.ROW_HEIGHT
            last_visible_row = view_start + visible_rows
            total_rows = len(self.commits)

            if last_visible_row < total_rows - self.LOAD_AHEAD_ROWS:
                break

            self._load_commits(self.LOAD_MORE)
            if len(self.commits) == total_rows:
                break  # nothing added (e.g. a load is already in progress)
            loaded_any = True

        if loaded_any:
            self.Refresh()

    def _on_paint(self, event):
        """Paint the git graph."""
        dc = wx.AutoBufferedPaintDC(self)
        self.PrepareDC(dc)

        dc.SetBackground(wx.Brush(self.BG_COLOR))
        dc.Clear()

        if not self.commits:
            dc.SetTextForeground(wx.Colour(150, 150, 150))
            dc.SetFont(self._empty_font)
            dc.DrawText("No commits yet", 20, 20)
            return

        # Visible row range, padded by a few rows on each side.
        view_start = self.GetViewStart()[1]
        client_size = self.GetClientSize()
        first_visible = max(0, view_start - 2)
        last_visible = min(len(self.commits),
                           view_start + (client_size.height // self.ROW_HEIGHT) + 3)
        row_width = client_size.width + 500

        graph_width = self._get_graph_width()

        # Hover background first so everything else draws on top of it.
        if 0 <= self._hovered_row < len(self.commits):
            dc.SetBrush(wx.Brush(self.HOVER_COLOR))
            dc.SetPen(wx.TRANSPARENT_PEN)
            dc.DrawRectangle(0, self._hovered_row * self.ROW_HEIGHT, row_width, self.ROW_HEIGHT)

        # Row separators
        dc.SetPen(wx.Pen(self.SEPARATOR_COLOR, 1))
        for i in range(first_visible, last_visible + 1):
            y = i * self.ROW_HEIGHT
            dc.DrawLine(0, y, row_width, y)

        # First-parent links
        self._draw_rail_lines(dc, first_visible, last_visible)

        # Merge links start at the merge commit and may reach into the
        # viewport from rows above it, so visit every row up to the bottom.
        for i in range(last_visible):
            self._draw_merge_lines(dc, i, min_row=first_visible)

        # Commit nodes and text
        for i in range(first_visible, last_visible):
            self._draw_commit(dc, i, graph_width)

    def _get_rail_x(self, rail: int) -> int:
        """Get the X coordinate for a rail."""
        return self.LEFT_MARGIN + rail * self.RAIL_SPACING + self.NODE_RADIUS

    def _get_row_y(self, row: int) -> int:
        """Get the Y coordinate for a row (center of the row)."""
        return row * self.ROW_HEIGHT + self.ROW_HEIGHT // 2

    def _pending_parent_y(self) -> int:
        """Y coordinate where links to parents that are not loaded yet end (bottom of loaded rows)."""
        return len(self.commits) * self.ROW_HEIGHT

    def _draw_curve(self, dc: wx.DC, x1: float, y1: float, x2: float, y2: float, steps: int = 8):
        """Draw an eased S-curve: linear in y, smoothstep-eased in x."""
        points = []
        for step in range(steps + 1):
            t = step / steps
            ease_t = t * t * (3 - 2 * t)
            points.append(wx.Point(round(x1 + (x2 - x1) * ease_t),
                                   round(y1 + (y2 - y1) * t)))
        dc.DrawLines(points)

    def _draw_link(self, dc: wx.DC, x_top: int, y_top: int, x_bottom: int, y_bottom: int):
        """
        Draw a connection from a commit (top) down to a parent (bottom).

        If the two ends are on different rails, the line first curves across
        within one row below the top end. It then runs straight down the
        bottom end's rail, which stays reserved until the parent is placed.
        """
        if x_top == x_bottom:
            dc.DrawLine(x_top, y_top, x_bottom, y_bottom)
            return
        bend_y = min(y_bottom, y_top + self.ROW_HEIGHT)
        self._draw_curve(dc, x_top, y_top, x_bottom, bend_y)
        if bend_y < y_bottom:
            dc.DrawLine(x_bottom, bend_y, x_bottom, y_bottom)

    def _draw_rail_lines(self, dc: wx.DC, first_row: int, last_row: int):
        """
        Draw each commit's connection to its first parent.

        A link ends at the first parent's node when that row is loaded.
        Otherwise it ends at the bottom of the loaded rows, on the rail
        reserved for the parent. Links lying entirely above ``first_row`` are
        skipped, and commits after ``last_row`` are not visited.
        """
        total = len(self.commits)
        pending_y = self._pending_parent_y()

        for row in range(min(last_row + 1, total)):
            render_commit = self.commits[row]
            parents = render_commit.commit.parents
            if not parents:
                continue  # root commit: nothing below it

            first_parent = parents[0]
            parent_row = self._commit_row_index.get(first_parent)
            if parent_row is None:
                bottom_row = total
                parent_rail = self._reserved_rails.get(first_parent, render_commit.rail)
                parent_y = pending_y
            else:
                bottom_row = parent_row
                parent_rail = self.commits[parent_row].rail
                parent_y = self._get_row_y(parent_row)

            if max(row, bottom_row) < first_row:
                continue

            dc.SetPen(wx.Pen(RAIL_COLORS[parent_rail % len(RAIL_COLORS)], 2))
            self._draw_link(dc,
                            self._get_rail_x(render_commit.rail), self._get_row_y(row),
                            self._get_rail_x(parent_rail), parent_y)

    def _draw_merge_lines(self, dc: wx.DC, row: int, min_row: int = 0):
        """
        Draw the links from ``row``'s secondary parents up into the merge commit.

        Links whose lower end lies above ``min_row`` are skipped. A parent that
        is not loaded yet is drawn as continuing to the bottom of the loaded rows.
        """
        if row >= len(self.commits):
            return

        merge_y = self._get_row_y(row)

        for parent_hash, from_rail, to_rail in self.commits[row].merge_lines:
            parent_row = self._commit_row_index.get(parent_hash)
            if parent_row is None:
                bottom_row = len(self.commits)
                parent_y = self._pending_parent_y()
            else:
                bottom_row = parent_row
                parent_y = self._get_row_y(parent_row)

            if max(row, bottom_row) < min_row:
                continue

            dc.SetPen(wx.Pen(RAIL_COLORS[from_rail % len(RAIL_COLORS)], 2))
            self._draw_link(dc,
                            self._get_rail_x(to_rail), merge_y,
                            self._get_rail_x(from_rail), parent_y)

    def _draw_commit(self, dc: wx.DC, row: int, graph_width: int):
        """Draw the node for ``row`` and its text, starting at x = ``graph_width``."""
        if row >= len(self.commits):
            return

        render_commit = self.commits[row]
        commit = render_commit.commit

        x = self._get_rail_x(render_commit.rail)
        y = self._get_row_y(row)

        # Commit node with a dark border
        dc.SetBrush(wx.Brush(render_commit.color))
        dc.SetPen(wx.Pen(wx.Colour(20, 20, 22), 2))
        dc.DrawCircle(x, y, self.NODE_RADIUS)

        text_x = graph_width

        # Hash (monospace, dimmed) in a fixed 65px column
        dc.SetTextForeground(self.HASH_COLOR)
        dc.SetFont(self._hash_font)
        dc.DrawText(commit.short_hash, text_x, y - 7)
        text_x += 65

        # Measure what follows the message so the message can be truncated
        # to the remaining width.
        author_text = f"{commit.author} • {commit.relative_date}"
        dc.SetFont(self._meta_font)
        author_width = dc.GetTextExtent(author_text)[0]

        badge_width = 0
        if commit.refs:
            dc.SetFont(self._badge_font)
            for ref in commit.refs[:2]:
                badge_width += dc.GetTextExtent(ref)[0] + 16

        client_width = self.GetClientSize().width
        max_msg_width = max(80, client_width - text_x - badge_width - author_width - 40)

        # Message, shortened with an ellipsis until it fits (but never below 10 chars)
        dc.SetTextForeground(self.MESSAGE_COLOR)
        dc.SetFont(self._message_font)
        message = commit.message
        msg_width = dc.GetTextExtent(message)[0]
        while msg_width > max_msg_width and len(message) > 10:
            message = message[:-4] + "…"
            msg_width = dc.GetTextExtent(message)[0]
        dc.DrawText(message, text_x, y - 7)
        text_x += msg_width + 10

        # Ref badges inline after the message
        if commit.refs:
            text_x = self._draw_ref_badges_inline(dc, commit.refs, text_x, y, render_commit.color)

        # Author and relative time after the badges
        text_x += 8
        dc.SetTextForeground(self.META_COLOR)
        dc.SetFont(self._meta_font)
        dc.DrawText(author_text, text_x, y - 6)

    def _draw_ref_badges_inline(self, dc: wx.DC, refs: List[str], start_x: int, y: int,
                                 base_color: wx.Colour) -> int:
        """
        Draw up to two branch/tag badges starting at ``start_x``.

        Refs starting with ``origin/`` or ``remotes/`` get a neutral gray
        badge; other refs use ``base_color``. Returns the x position just
        past the last badge.
        """
        badge_x = start_x
        badge_y = y - 8
        badge_height = 16

        dc.SetFont(self._badge_font)

        for ref in refs[:2]:  # Limit to 2 badges inline
            if ref.startswith(('origin/', 'remotes/')):
                bg_color = wx.Colour(60, 60, 65)
                text_color = wx.Colour(180, 180, 180)
            else:
                bg_color = base_color
                text_color = wx.Colour(255, 255, 255)

            badge_width = dc.GetTextExtent(ref)[0] + 10

            dc.SetBrush(wx.Brush(bg_color))
            dc.SetPen(wx.TRANSPARENT_PEN)
            dc.DrawRoundedRectangle(badge_x, badge_y, badge_width, badge_height, 3)

            dc.SetTextForeground(text_color)
            dc.DrawText(ref, badge_x + 5, badge_y + 1)

            badge_x += badge_width + 4

        return badge_x
