#!/usr/bin/env python3
import argparse
import ipaddress
import json
import os
import re
import shlex
import shutil
import select
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable

# --- Identity ---------------------------------------------------------------
APP_NAME = "Bahari Network Manager"
APP_SUBTITLE = "Adaptive Network Control Surface"
# Absolute path of this script; re-executed to launch the rollback daemon.
SCRIPT_PATH = os.path.realpath(__file__)

# --- Rollback guard confirmation windows (seconds) ---------------------------
DEFAULT_GUARD_TIMEOUT = 45
DANGEROUS_GUARD_TIMEOUT = 60

# --- ANSI styling -------------------------------------------------------------
# Matches SGR colour/style sequences so visible text width can be measured.
ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
RESET, BOLD, DIM = "\033[0m", "\033[1m", "\033[2m"
# 256-colour foreground palette, keyed by semantic tone.
FG = dict(
    frame="\033[38;5;39m",
    primary="\033[38;5;51m",
    secondary="\033[38;5;117m",
    muted="\033[38;5;109m",
    success="\033[38;5;120m",
    warning="\033[38;5;214m",
    danger="\033[38;5;203m",
)
# Matching 256-colour backgrounds used for badges.
BG = dict(
    primary="\033[48;5;24m",
    secondary="\033[48;5;31m",
    muted="\033[48;5;236m",
    success="\033[48;5;22m",
    warning="\033[48;5;94m",
    danger="\033[48;5;52m",
)

# --- Menu hints: one-line description per main-menu action label ---------------
ACTION_HINTS = {
    "Add IP": "attach the first live address",
    "Change IP": "safe replace with rollback shield",
    "Remove IP": "destructive path with SSH protection",
    "Set Gateway": "update default route targeting",
    "Set DNS": "push resolver servers",
    "Interface Manage": "switch, enable, or disable link",
    "Restart Network": "reconfigure live network state",
    "Refresh": "scan the current node again",
}


class CommandError(RuntimeError):
    """Raised when an external command exits with a non-zero status.

    The exception text is the command's own error output when there is any,
    otherwise a generic ``Command failed: <cmd>`` line.

    Attributes:
        cmd: the argument vector that was executed.
        stderr: the error output passed in, or ``""``.
    """

    def __init__(self, cmd: list[str], stderr: str):
        if stderr:
            text = stderr
        else:
            text = f"Command failed: {shlex.join(cmd)}"
        super().__init__(text)
        self.cmd = cmd
        self.stderr = stderr if stderr else ""


class PromptCancelled(RuntimeError):
    """Raised when the operator interrupts a prompt with Ctrl+C."""


class InputClosed(RuntimeError):
    """Raised when standard input reaches end-of-file while prompting."""


@dataclass
class Operation:
    """A single step of an execution plan.

    ``label`` names the step, ``preview`` is the text shown to the operator
    before confirmation, and ``executor`` performs the step when called.
    """

    label: str
    preview: str
    executor: Callable[[], None]


@dataclass
class IPWizardResult:
    cidr: str
    gateway: str | None
    dns_servers: list[str] | None


@dataclass
class Snapshot:
    interface: str
    link_up: bool
    addresses: list[str]
    default_routes: list[dict[str, Any]]
    resolver_mode: str
    dns_servers: list[str]
    resolv_conf: str | None = None

    def to_dict(self) -> dict[str, Any]:
        return {
            "interface": self.interface,
            "link_up": self.link_up,
            "addresses": self.addresses,
            "default_routes": self.default_routes,
            "resolver_mode": self.resolver_mode,
            "dns_servers": self.dns_servers,
            "resolv_conf": self.resolv_conf,
        }

    @classmethod
    def from_dict(cls, payload: dict[str, Any]) -> "Snapshot":
        return cls(
            interface=payload["interface"],
            link_up=payload["link_up"],
            addresses=list(payload.get("addresses", [])),
            default_routes=list(payload.get("default_routes", [])),
            resolver_mode=payload.get("resolver_mode", "resolv.conf"),
            dns_servers=list(payload.get("dns_servers", [])),
            resolv_conf=payload.get("resolv_conf"),
        )


@dataclass
class InterfaceSummary:
    name: str
    state: str
    addresses: list[str]
    gateways: dict[str, str]
    dns_servers: list[str]
    mac: str | None
    mtu: int | None

    @property
    def primary_gateway(self) -> str | None:
        return self.gateways.get("inet") or self.gateways.get("inet6")


@dataclass
class MenuAction:
    """One main-menu entry: the label shown to the user and the handler it runs."""

    label: str
    handler: Callable[[], None]


def command_exists(name: str) -> bool:
    """Return True when ``name`` resolves to an executable on PATH."""
    location = shutil.which(name)
    return location is not None


def run_command(
    cmd: list[str],
    *,
    check: bool = True,
    capture: bool = True,
) -> str:
    """Run ``cmd`` without a shell and return its stripped standard output.

    Args:
        cmd: argument vector.
        check: raise on a non-zero exit status.
        capture: capture stdout/stderr. When False, output goes straight to
            the terminal and the return value is ``""``.

    Raises:
        CommandError: when ``check`` is set and the exit status is non-zero.
            Its message is stderr, or stdout if stderr is empty.
    """
    result = subprocess.run(cmd, text=True, capture_output=capture, check=False)
    out_text = (result.stdout or "").strip()
    err_text = (result.stderr or "").strip()
    failed = result.returncode != 0
    if failed and check:
        raise CommandError(cmd, err_text or out_text)
    return out_text


def best_effort(cmd: list[str]) -> None:
    """Run ``cmd`` and ignore any failure.

    Both failure modes are ignored: a non-zero exit status, and the command
    being impossible to start (e.g. the binary is missing, which
    ``subprocess.run`` reports as ``OSError``/``FileNotFoundError``). This
    helper is used on restore paths where one failing step must not stop
    the remaining steps.
    """
    try:
        run_command(cmd, check=False, capture=True)
    except OSError:
        # Deliberately ignored: "best effort" must not abort a partially applied rollback.
        pass


def run_json(cmd: list[str]) -> list[dict[str, Any]]:
    """Run ``cmd`` (e.g. ``ip -j ...``) and decode its stdout as a JSON list.

    Empty output, or JSON that is not a list, yields ``[]``. Errors from
    ``run_command`` and JSON decoding propagate.
    """
    raw = run_command(cmd)
    decoded = json.loads(raw) if raw else []
    return decoded if isinstance(decoded, list) else []


def clear_screen() -> None:
    """Clear the terminal and home the cursor; do nothing when stdout is not a TTY."""
    if not sys.stdout.isatty():
        return
    sys.stdout.write("\033[2J\033[H")
    sys.stdout.flush()


def color_enabled(stream: Any | None = None) -> bool:
    target = stream or sys.stdout
    is_tty = getattr(target, "isatty", lambda: False)()
    term = os.environ.get("TERM", "")
    return bool(is_tty and term and term.lower() != "dumb" and "NO_COLOR" not in os.environ)


def paint(text: str, *codes: str) -> str:
    """Wrap ``text`` in the given ANSI codes plus a trailing RESET, when colour is on."""
    if color_enabled():
        prefix = "".join(codes)
        return f"{prefix}{text}{RESET}"
    return text


def strip_ansi(text: str) -> str:
    """Return ``text`` with every ANSI SGR sequence matched by ``ANSI_RE`` removed."""
    cleaned = ANSI_RE.sub("", text)
    return cleaned


def visible_len(text: str) -> int:
    """Length of ``text`` as displayed, i.e. ignoring ANSI escape sequences."""
    plain = strip_ansi(text)
    return len(plain)


def terminal_width() -> int:
    """Usable layout width: terminal columns minus 2, clamped to [72, 110].

    Falls back to 96 columns when the size cannot be determined.
    """
    usable = shutil.get_terminal_size((96, 24)).columns - 2
    return min(max(usable, 72), 110)


def fit_text(text: str, width: int) -> str:
    """Strip ANSI codes and pad or truncate the result to exactly ``width`` columns.

    Overlong text is cut to ``width - 1`` characters plus an ellipsis.
    """
    plain = strip_ansi(text)
    if len(plain) <= width:
        return plain.ljust(width)
    marker = "…" if width > 0 else ""
    return (plain[: max(0, width - 1)] + marker).ljust(width)


def render_badge(label: str, tone: str = "primary") -> str:
    """Render ``label`` as a coloured pill, or as ``[label]`` without colour.

    Unknown tones fall back to the ``primary`` palette.
    """
    if color_enabled():
        foreground = FG.get(tone, FG["primary"])
        background = BG.get(tone, BG["primary"])
        return f"{background}{foreground}{BOLD} {label} {RESET}"
    return f"[{label}]"


def render_panel(title: str, lines: list[str], *, tone: str = "primary") -> str:
    """Draw a heavy box-drawing panel with an upper-cased title row and body rows.

    The inner width is the widest title/visible line, capped at
    ``terminal_width() - 6`` and never below 42.
    """
    body = lines or [""]
    widest = max(len(title), *(visible_len(line) for line in body))
    content_width = max(42, min(widest, terminal_width() - 6))

    rule = "━" * (content_width + 2)
    frame = FG["frame"]
    rows = [
        paint(f"┏{rule}┓", frame),
        paint(f"┃ {fit_text(title.upper(), content_width)} ┃", BOLD, FG.get(tone, FG["primary"])),
        paint(f"┣{rule}┫", frame),
    ]
    rows.extend(paint(f"┃ {fit_text(line, content_width)} ┃", FG["muted"]) for line in body)
    rows.append(paint(f"┗{rule}┛", frame))
    return "\n".join(rows)


def message_tone(message: str) -> str:
    """Choose a display tone for a status message from keywords it contains.

    Rules are checked in priority order: danger, then warning, then success.
    Anything unmatched is ``"primary"``.
    """
    lowered = message.lower()
    rules = (
        ("danger", ("failed", "invalid", "warning", "error")),
        ("warning", ("rolled back", "cancelled")),
        ("success", ("success", "confirmed", "refreshed", "switched")),
    )
    for tone, keywords in rules:
        if any(keyword in lowered for keyword in keywords):
            return tone
    return "primary"


def prompt_text(text: str) -> str:
    """Decorate a prompt with a coloured ``›`` marker; return it unchanged without colour."""
    if color_enabled():
        marker = paint("› ", BOLD, FG["primary"])
        return marker + paint(text, BOLD, FG["secondary"])
    return text


def prompt_input(text: str) -> str:
    """Read one line from the user and return it stripped.

    Raises:
        InputClosed: stdin hit end-of-file.
        PromptCancelled: the user pressed Ctrl+C.
    """
    try:
        answer = input(prompt_text(text))
    except EOFError as exc:
        print()
        raise InputClosed from exc
    except KeyboardInterrupt as exc:
        print()
        raise PromptCancelled from exc
    return answer.strip()


def wait_for_enter(prompt: str = "Press Enter to continue...") -> None:
    """Pause until Enter is pressed; EOF and Ctrl+C simply end the pause."""
    ignored = (EOFError, InputClosed, PromptCancelled)
    try:
        prompt_input(prompt)
    except ignored:
        return


def prompt_yes_no(text: str, *, default: bool = False) -> bool:
    """Ask a yes/no question and repeat until the answer is understood.

    An empty answer returns ``default``. Accepted answers are y/yes/n/no
    (case-insensitive).
    """
    suffix = "[Y/n]" if default else "[y/N]"
    answers = {"y": True, "yes": True, "n": False, "no": False}
    while True:
        reply = prompt_input(f"{text} {suffix}: ").lower()
        if not reply:
            return default
        if reply in answers:
            return answers[reply]
        print("Please answer yes or no.")


def _discard_pending_input() -> None:
    """Best-effort flush of unread terminal input; a no-op off a POSIX TTY."""
    try:
        import termios
    except ImportError:
        return
    try:
        if sys.stdin.isatty():
            termios.tcflush(sys.stdin.fileno(), termios.TCIFLUSH)
    except (termios.error, OSError, ValueError):
        # Not a flushable terminal (redirected or replaced stdin): nothing to discard.
        pass


def prompt_timeout(text: str, timeout: int) -> tuple[str, bool]:
    """Prompt for one line, giving up after ``timeout`` seconds.

    Returns:
        ``(answer, False)`` with the stripped answer when a line arrived in
        time, or ``("", True)`` when the timeout expired.

    Raises:
        PromptCancelled: Ctrl+C while waiting or reading.
        InputClosed: stdin reached end-of-file.
    """
    print(prompt_text(text), end="", flush=True)
    try:
        ready, _, _ = select.select([sys.stdin], [], [], timeout)
    except KeyboardInterrupt as exc:
        print()
        raise PromptCancelled from exc
    if not ready:
        print()
        # In canonical terminal mode, an answer typed without Enter is not
        # readable yet, so select() times out. It would otherwise be
        # prepended to the next prompt's input.
        _discard_pending_input()
        return "", True
    try:
        value = sys.stdin.readline()
    except KeyboardInterrupt as exc:
        print()
        raise PromptCancelled from exc
    if not value:
        print()  # end the prompt line, as prompt_input does on EOF
        raise InputClosed
    return value.strip(), False


def header_box(title: str) -> str:
    """Render the double-line application banner with ``title`` and the subtitle centred.

    The inner width is ``terminal_width() - 4``, clamped to [56, 84].
    """
    width = max(56, min(84, terminal_width() - 4))
    bar = "═" * width
    rows = [
        paint(f"╔{bar}╗", FG["frame"]),
        paint(f"║{title.center(width)}║", BOLD, FG["primary"]),
        paint(f"║{APP_SUBTITLE.center(width)}║", FG["secondary"]),
        paint(f"╚{bar}╝", FG["frame"]),
    ]
    return "\n".join(rows)


def normalize_cidr(value: str) -> str:
    """Canonicalise ``value`` to ``<compressed-ip>/<prefixlen>``.

    A bare address gets a full-length prefix. Raises ValueError when
    ``value`` cannot be parsed.
    """
    parsed = ipaddress.ip_interface(value.strip())
    host = parsed.ip.compressed
    prefix = parsed.network.prefixlen
    return f"{host}/{prefix}"


def normalize_ip(value: str) -> str:
    """Return the compressed textual form of an IPv4/IPv6 address (ValueError if invalid)."""
    address = ipaddress.ip_address(value.strip())
    return address.compressed


def parse_dns_servers(value: str) -> list[str]:
    """Split a comma- and/or whitespace-separated server list into compressed IPs.

    Blank input yields ``[]``. Any invalid entry raises ValueError.
    """
    return [
        ipaddress.ip_address(token).compressed
        for token in value.replace(",", " ").split()
    ]


def family_flag_for_ip(value: str) -> str:
    """Return the ``ip`` family flag (``-4`` or ``-6``) matching address ``value``."""
    version = ipaddress.ip_address(value).version
    return {4: "-4", 6: "-6"}[version]


def family_name_for_ip(value: str) -> str:
    """Return the iproute2 family name (``inet`` or ``inet6``) for address ``value``."""
    version = ipaddress.ip_address(value).version
    return {4: "inet", 6: "inet6"}[version]


def render_resolv_conf(existing: str, servers: list[str]) -> str:
    """Return resolv.conf text whose ``nameserver`` entries are exactly ``servers``.

    Every existing ``nameserver`` line is dropped, whatever whitespace follows
    the keyword (resolv.conf allows tabs as well as spaces). All other lines
    (search, options, comments, ...) are kept, right-stripped. They follow the
    new nameserver lines, separated by a blank line.

    Returns:
        The new file contents ending in a newline, or ``""`` when there is
        nothing to write.
    """
    preserved: list[str] = []
    for line in existing.splitlines():
        fields = line.split(None, 1)
        if fields and fields[0] == "nameserver":
            continue
        preserved.append(line.rstrip())
    lines = [f"nameserver {server}" for server in servers]
    if preserved:
        if lines:
            lines.append("")
        lines.extend(preserved)
    return ("\n".join(lines).rstrip() + "\n") if lines else ""


def is_ssh_session() -> bool:
    """True when any of SSH_CONNECTION, SSH_CLIENT or SSH_TTY is set to a non-empty value."""
    for variable in ("SSH_CONNECTION", "SSH_CLIENT", "SSH_TTY"):
        if os.environ.get(variable):
            return True
    return False


def make_command_operation(cmd: list[str], label: str | None = None) -> Operation:
    """Wrap an external command as an :class:`Operation`.

    The preview is the shell-quoted command line. The label defaults to that
    preview. Executing the operation runs the command with output going to the
    terminal, and raises CommandError on a non-zero exit.
    """
    rendered = shlex.join(cmd)

    def execute() -> str:
        return run_command(cmd, capture=False)

    return Operation(label=label if label else rendered, preview=rendered, executor=execute)


class RollbackGuard:
    """Dead-man switch that hands a snapshot to a detached rollback process.

    ``arm`` does three things:
    - creates a private temporary directory;
    - writes ``snapshot.json`` and an ``active`` marker into it;
    - re-executes this script with ``--rollback-daemon <snapshot> <timeout>``.

    The daemon runs in its own session (``start_new_session=True``),
    presumably so it keeps running if the operator's terminal goes away.
    ``cancel`` removes the ``active`` marker.

    NOTE(review): the daemon is assumed to skip its rollback once ``active``
    is gone. Confirm this against the ``--rollback-daemon`` code.
    """

    def __init__(self, snapshot_dir: str, active_file: str, timeout: int):
        self.snapshot_dir = snapshot_dir
        self.active_file = active_file
        self.timeout = timeout

    @classmethod
    def arm(cls, snapshot: "Snapshot", timeout: int) -> "RollbackGuard":
        """Persist ``snapshot`` and start the rollback daemon with ``timeout`` seconds.

        If any step fails, the temporary directory is removed and the error is
        re-raised, so no half-armed guard is left behind.
        """
        snapshot_dir = tempfile.mkdtemp(prefix="bhipconfig-rollback-")
        snapshot_file = os.path.join(snapshot_dir, "snapshot.json")
        active_file = os.path.join(snapshot_dir, "active")
        try:
            Path(snapshot_file).write_text(json.dumps(snapshot.to_dict()), encoding="utf-8")
            Path(active_file).write_text(str(int(time.time())), encoding="utf-8")
            subprocess.Popen(
                [sys.executable, SCRIPT_PATH, "--rollback-daemon", snapshot_file, str(timeout)],
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                close_fds=True,
                start_new_session=True,
            )
        except BaseException:
            # Nothing is watching this directory yet; do not leak it.
            shutil.rmtree(snapshot_dir, ignore_errors=True)
            raise
        return cls(snapshot_dir, active_file, timeout)

    def cancel(self) -> None:
        """Disarm the guard by removing the ``active`` marker. Safe to call repeatedly."""
        try:
            os.remove(self.active_file)
        except FileNotFoundError:
            # Already removed (earlier cancel, or by the daemon): nothing to do.
            pass


class NetworkController:
    """Inspect and change live interface state through iproute2 and the resolver.

    - Query helpers parse ``ip -j`` JSON output.
    - ``plan_*`` methods only build :class:`Operation` lists; nothing runs
      until an operation's executor is called.
    - ``capture_snapshot`` / ``apply_snapshot`` record and restore the state
      that rollback relies on.
    """

    def __init__(self) -> None:
        self.resolver_mode = self.detect_resolver_mode()

    def detect_resolver_mode(self) -> str:
        """Return ``"systemd-resolved"`` when ``resolvectl status`` succeeds, else ``"resolv.conf"``."""
        if command_exists("resolvectl"):
            try:
                run_command(["resolvectl", "status"])
                return "systemd-resolved"
            except CommandError:
                pass
        return "resolv.conf"

    def default_routes(self, interface: str | None = None) -> list[dict[str, Any]]:
        """Return default routes of both address families, IPv4 first.

        ``ip route show`` without a family flag lists IPv4 routes only. Without
        an explicit ``-6`` query, IPv6 default routes would never be
        summarised, captured in a snapshot, or restored on rollback.

        Every entry is tagged with a ``family`` key. When ``interface`` is
        given, the query is limited to that device and entries also get a
        ``dev`` key: iproute2 leaves ``dev`` out of device-filtered output,
        and the rollback path needs it to rebuild the route.

        Raises:
            CommandError: if the IPv4 query fails. A failing IPv6 query (e.g.
                IPv6 unavailable on the host) just yields no IPv6 routes.
        """
        routes: list[dict[str, Any]] = []
        for flag, family in (("-4", "inet"), ("-6", "inet6")):
            cmd = ["ip", "-j", flag, "route", "show", "default"]
            if interface:
                cmd.extend(["dev", interface])
            try:
                payload = run_json(cmd)
            except CommandError:
                if family == "inet":
                    raise
                payload = []
            for route in payload:
                entry = dict(route)
                entry.setdefault("family", family)
                if interface:
                    entry.setdefault("dev", interface)
                routes.append(entry)
        return routes

    def interface_summaries(self) -> list[InterfaceSummary]:
        """Summarise every interface.

        Each summary holds state, addresses, default gateways, DNS, MAC and
        MTU. Only ``global``- and ``host``-scope addresses are listed.
        """
        addr_payload = run_json(["ip", "-j", "address", "show"])
        gateways_by_dev: dict[str, dict[str, str]] = {}
        for route in self.default_routes():
            dev = route.get("dev")
            gateway = route.get("gateway")
            if not dev or not gateway:
                continue
            family = route.get("family", family_name_for_ip(gateway))
            gateways_by_dev.setdefault(dev, {})[family] = gateway

        summaries: list[InterfaceSummary] = []
        for entry in addr_payload:
            name = entry.get("ifname")
            if not name:
                continue
            addresses: list[str] = []
            for addr_info in entry.get("addr_info", []):
                scope = addr_info.get("scope")
                if scope not in {"global", "host"}:
                    continue
                local = addr_info.get("local")
                prefix = addr_info.get("prefixlen")
                if local and prefix is not None:
                    addresses.append(normalize_cidr(f"{local}/{prefix}"))
            dns_servers = self.get_dns_servers(name)
            summaries.append(
                InterfaceSummary(
                    name=name,
                    state=entry.get("operstate", "UNKNOWN"),
                    addresses=addresses,
                    gateways=gateways_by_dev.get(name, {}),
                    dns_servers=dns_servers,
                    mac=entry.get("address"),
                    mtu=entry.get("mtu"),
                )
            )
        return summaries

    def displayable_interfaces(self) -> list[InterfaceSummary]:
        """All interfaces except ``lo``, unless ``lo`` is the only one."""
        summaries = self.interface_summaries()
        visible = [item for item in summaries if item.name != "lo"]
        return visible or summaries

    def default_interface_name(self) -> str | None:
        """Return the first non-loopback default-route device, IPv4 routes first.

        Falls back to the first displayable interface, or None if there is none.
        """
        for route in self.default_routes():
            dev = route.get("dev")
            if dev and dev != "lo":
                return dev
        interfaces = self.displayable_interfaces()
        return interfaces[0].name if interfaces else None

    def summary_for(self, interface: str) -> InterfaceSummary | None:
        """Return the summary for ``interface``, or None if it does not exist."""
        for summary in self.interface_summaries():
            if summary.name == interface:
                return summary
        return None

    def current_addresses(self, interface: str) -> list[str]:
        """Global/host-scope addresses of ``interface`` in CIDR form (copy)."""
        summary = self.summary_for(interface)
        return summary.addresses[:] if summary else []

    def get_dns_servers(self, interface: str) -> list[str]:
        """Return the DNS servers in effect for ``interface``; unparsable entries are skipped.

        In systemd-resolved mode the per-link list comes from
        ``resolvectl dns``. Otherwise the global ``nameserver`` entries of
        /etc/resolv.conf are returned.
        """
        if self.resolver_mode == "systemd-resolved":
            try:
                output = run_command(["resolvectl", "dns", interface])
            except CommandError:
                return []
            if ":" not in output:
                return []
            _, values = output.split(":", 1)
            link_servers: list[str] = []
            for token in values.split():
                try:
                    link_servers.append(normalize_ip(token))
                except ValueError:
                    continue
            return link_servers

        resolv_conf = Path("/etc/resolv.conf")
        if not resolv_conf.exists():
            return []
        servers: list[str] = []
        for line in resolv_conf.read_text(encoding="utf-8").splitlines():
            # resolv.conf separates keyword and value by any whitespace, tabs included.
            fields = line.split()
            if len(fields) < 2 or fields[0] != "nameserver":
                continue
            try:
                servers.append(normalize_ip(fields[1]))
            except ValueError:
                continue
        return servers

    def _link_admin_up(self, interface: str) -> bool | None:
        """Return whether ``interface`` is administratively UP, or None if that can't be read.

        The ``UP`` link flag is used rather than ``operstate``: a link that is
        configured up but has no carrier reports operstate ``DOWN``. Treating
        that as "down" would make a rollback disable the link.
        """
        try:
            payload = run_json(["ip", "-j", "link", "show", "dev", interface])
        except (CommandError, ValueError):
            return None
        for entry in payload:
            flags = entry.get("flags")
            if isinstance(flags, list):
                return "UP" in flags
        return None

    def capture_snapshot(self, interface: str) -> Snapshot:
        """Record link state, addresses, default routes (both families) and DNS for ``interface``.

        In resolv.conf mode, the file's full text is saved as well, when the file exists.
        """
        summary = self.summary_for(interface)
        link_up = self._link_admin_up(interface)
        if link_up is None:
            # Flags unavailable: fall back to an operstate-based guess.
            link_up = bool(summary and summary.state != "DOWN")
        resolv_conf = None
        dns_servers = self.get_dns_servers(interface)
        if self.resolver_mode == "resolv.conf":
            resolv_path = Path("/etc/resolv.conf")
            if resolv_path.exists():
                resolv_conf = resolv_path.read_text(encoding="utf-8")
        return Snapshot(
            interface=interface,
            link_up=link_up,
            addresses=self.current_addresses(interface),
            default_routes=self.default_routes(interface),
            resolver_mode=self.resolver_mode,
            dns_servers=dns_servers,
            resolv_conf=resolv_conf,
        )

    @staticmethod
    def _route_key(route: dict[str, Any], interface: str) -> tuple[Any, ...]:
        """Identity of a default route for diffing: (family, gateway, dev, metric)."""
        gateway = route.get("gateway")
        family = route.get("family") or (family_name_for_ip(gateway) if gateway else "inet")
        return (family, gateway, route.get("dev") or interface, route.get("metric"))

    def apply_snapshot(self, snapshot: Snapshot) -> None:
        """Bring ``snapshot.interface`` back to the recorded state.

        Steps run in this order:
        1. link up (if it was up);
        2. address diff;
        3. default-route diff;
        4. DNS;
        5. link down (if it was down).

        Each change command goes through ``best_effort``, so a failing step
        does not stop the rest. State queries (``run_json``) can still raise
        CommandError.
        """
        interface = snapshot.interface
        if snapshot.link_up:
            best_effort(["ip", "link", "set", "dev", interface, "up"])

        current_addrs = set(self.current_addresses(interface))
        desired_addrs = set(snapshot.addresses)
        for address in sorted(current_addrs - desired_addrs):
            best_effort(["ip", "addr", "del", address, "dev", interface])
        for address in sorted(desired_addrs - current_addrs):
            best_effort(["ip", "addr", "add", address, "dev", interface])

        # Only touch routes that actually differ. Deleting and re-adding an
        # unchanged route would drop kernel/daemon-owned attributes such as its
        # protocol (dhcp/ra).
        current_routes = self.default_routes(interface)
        desired_keys = {self._route_key(route, interface) for route in snapshot.default_routes}
        current_keys = {self._route_key(route, interface) for route in current_routes}
        for route in current_routes:
            if self._route_key(route, interface) not in desired_keys:
                best_effort(self.build_route_delete_command(route))
        for route in snapshot.default_routes:
            if self._route_key(route, interface) in current_keys:
                continue
            restored = dict(route)
            restored.setdefault("dev", interface)
            best_effort(self.build_route_replace_from_snapshot(restored))

        self.restore_dns(snapshot)

        if not snapshot.link_up:
            best_effort(["ip", "link", "set", "dev", interface, "down"])

    def restore_dns(self, snapshot: Snapshot) -> None:
        """Put DNS back as recorded.

        In systemd-resolved mode (when ``resolvectl`` exists), the recorded
        servers are set, or the link is reverted if none were recorded.
        Otherwise /etc/resolv.conf is rewritten with its saved text, when some
        was captured.
        """
        if snapshot.resolver_mode == "systemd-resolved" and command_exists("resolvectl"):
            if snapshot.dns_servers:
                best_effort(["resolvectl", "dns", snapshot.interface, *snapshot.dns_servers])
            else:
                best_effort(["resolvectl", "revert", snapshot.interface])
            return
        if snapshot.resolv_conf is not None:
            Path("/etc/resolv.conf").write_text(snapshot.resolv_conf, encoding="utf-8")

    def route_needs_onlink(self, interface: str, gateway: str, extra_cidrs: list[str] | None = None) -> bool:
        """True when ``gateway`` lies outside every same-family network on ``interface``.

        The networks considered are the current addresses plus ``extra_cidrs``.
        Returns False when the interface has no same-family address at all.
        """
        networks = []
        for cidr in self.current_addresses(interface) + (extra_cidrs or []):
            try:
                networks.append(ipaddress.ip_interface(cidr).network)
            except ValueError:
                continue
        gateway_ip = ipaddress.ip_address(gateway)
        same_family_networks = [network for network in networks if network.version == gateway_ip.version]
        if not same_family_networks:
            return False
        return not any(gateway_ip in network for network in same_family_networks)

    def build_gateway_command(
        self,
        interface: str,
        gateway: str,
        *,
        extra_cidrs: list[str] | None = None,
    ) -> list[str]:
        """Build ``ip route replace default via <gateway> dev <interface>``.

        For IPv4, ``onlink`` is appended when the gateway is not inside a local subnet.
        """
        flag = family_flag_for_ip(gateway)
        command = ["ip", flag, "route", "replace", "default", "via", gateway, "dev", interface]
        if ipaddress.ip_address(gateway).version == 4 and self.route_needs_onlink(interface, gateway, extra_cidrs):
            command.append("onlink")
        return command

    def build_route_replace_from_snapshot(self, route: dict[str, Any]) -> list[str]:
        """Rebuild an ``ip route replace default`` command from a recorded route dict."""
        gateway = route.get("gateway")
        family = route.get("family", family_name_for_ip(gateway)) if gateway else route.get("family", "inet")
        flag = "-6" if family == "inet6" else "-4"
        command = ["ip", flag, "route", "replace", "default"]
        if gateway:
            command.extend(["via", gateway])
        dev = route.get("dev")
        if dev:
            command.extend(["dev", dev])
        metric = route.get("metric")
        if metric is not None:
            command.extend(["metric", str(metric)])
        flags = route.get("flags") or []
        if "onlink" in flags:
            command.append("onlink")
        return command

    def build_route_delete_command(self, route: dict[str, Any]) -> list[str]:
        """Build the ``ip route del default`` command matching a route dict."""
        gateway = route.get("gateway")
        family = route.get("family", family_name_for_ip(gateway)) if gateway else route.get("family", "inet")
        flag = "-6" if family == "inet6" else "-4"
        command = ["ip", flag, "route", "del", "default"]
        if gateway:
            command.extend(["via", gateway])
        dev = route.get("dev")
        if dev:
            command.extend(["dev", dev])
        metric = route.get("metric")
        if metric is not None:
            command.extend(["metric", str(metric)])
        return command

    def plan_add_ip(
        self,
        interface: str,
        cidr: str,
        *,
        gateway: str | None = None,
        dns_servers: list[str] | None = None,
    ) -> list[Operation]:
        """Plan: add ``cidr``, then optionally set the gateway and DNS."""
        operations = [make_command_operation(["ip", "addr", "add", cidr, "dev", interface], f"Add {cidr}")]
        if gateway:
            # The new address is passed in so the onlink decision accounts for it.
            gateway_cmd = self.build_gateway_command(interface, gateway, extra_cidrs=[cidr])
            operations.append(make_command_operation(gateway_cmd, f"Set gateway {gateway}"))
        if dns_servers:
            operations.extend(self.plan_set_dns(interface, dns_servers))
        return operations

    def plan_remove_ip(self, interface: str, cidr: str) -> list[Operation]:
        """Plan: delete ``cidr`` from ``interface``."""
        return [make_command_operation(["ip", "addr", "del", cidr, "dev", interface], f"Remove {cidr}")]

    def plan_change_ip(
        self,
        interface: str,
        old_cidr: str,
        new_cidr: str,
        *,
        gateway: str | None = None,
        dns_servers: list[str] | None = None,
    ) -> list[Operation]:
        """Plan an address swap.

        The new address is added first, then the optional gateway and DNS are
        set, and only then is the old address removed.
        """
        operations = [
            make_command_operation(["ip", "addr", "add", new_cidr, "dev", interface], f"Add {new_cidr}"),
        ]
        if gateway:
            gateway_cmd = self.build_gateway_command(interface, gateway, extra_cidrs=[new_cidr])
            operations.append(make_command_operation(gateway_cmd, f"Set gateway {gateway}"))
        if dns_servers:
            operations.extend(self.plan_set_dns(interface, dns_servers))
        operations.append(
            make_command_operation(["ip", "addr", "del", old_cidr, "dev", interface], f"Remove {old_cidr}")
        )
        return operations

    def plan_set_gateway(self, interface: str, gateway: str) -> list[Operation]:
        """Plan: replace the default route of ``gateway``'s family on ``interface``."""
        return [make_command_operation(self.build_gateway_command(interface, gateway), f"Set gateway {gateway}")]

    def plan_set_dns(self, interface: str, servers: list[str]) -> list[Operation]:
        """Plan a DNS change.

        Uses ``resolvectl dns`` in systemd-resolved mode; otherwise rewrites
        /etc/resolv.conf.
        """
        if self.resolver_mode == "systemd-resolved":
            return [
                make_command_operation(
                    ["resolvectl", "dns", interface, *servers],
                    f"Set DNS to {', '.join(servers)}",
                )
            ]

        def write_resolv_conf() -> None:
            path = Path("/etc/resolv.conf")
            existing = path.read_text(encoding="utf-8") if path.exists() else ""
            path.write_text(render_resolv_conf(existing, servers), encoding="utf-8")

        preview = f"write /etc/resolv.conf with nameservers: {', '.join(servers)}"
        return [Operation(label="Update resolv.conf", preview=preview, executor=write_resolv_conf)]

    def plan_link_state(self, interface: str, enable: bool) -> list[Operation]:
        """Plan: set the link administratively up (``enable``) or down."""
        state = "up" if enable else "down"
        label = "Enable interface" if enable else "Disable interface"
        return [make_command_operation(["ip", "link", "set", "dev", interface, state], label)]

    def plan_restart(self, interface: str) -> list[Operation]:
        """Plan an interface restart.

        Uses ``networkctl reconfigure`` when available, falling back to a link
        down/up cycle if it fails. Without networkctl, plans the down/up
        cycle directly.
        """
        if command_exists("networkctl"):
            def reconfigure() -> None:
                try:
                    run_command(["networkctl", "reconfigure", interface], capture=False)
                except CommandError:
                    run_command(["ip", "link", "set", "dev", interface, "down"], capture=False)
                    run_command(["ip", "link", "set", "dev", interface, "up"], capture=False)

            preview = f"networkctl reconfigure {interface} (fallback: ip link set down/up)"
            return [Operation(label="Reconfigure interface", preview=preview, executor=reconfigure)]

        return [
            make_command_operation(["ip", "link", "set", "dev", interface, "down"], "Bring interface down"),
            make_command_operation(["ip", "link", "set", "dev", interface, "up"], "Bring interface up"),
        ]


class BhipConfigApp:
    def __init__(self) -> None:
        """Create the network controller and select the starting interface.

        The starting interface is whatever the controller's
        ``default_interface_name`` picks. The status message starts empty.
        """
        controller = NetworkController()
        self.controller = controller
        self.current_interface = controller.default_interface_name()
        self.message = ""

    def run(self) -> int:
        """Drive the interactive main menu until the user leaves.

        Returns:
            1 when no interface exists, 130 when the menu prompt is interrupted
            with Ctrl+C, and 0 on "0"/exit or when stdin closes.
        """
        if not self.current_interface:
            print("No network interface found.")
            return 1

        while True:
            actions = self.build_main_menu_actions()
            self.render(actions)
            try:
                choice = prompt_input("Select an option: ")
            except PromptCancelled:
                print("Exiting bhipconfig.")
                return 130
            except InputClosed:
                print("Input closed. Exiting bhipconfig.")
                return 0

            if choice == "0":
                return 0
            action_map = {str(index): action.handler for index, action in enumerate(actions, start=1)}
            handler = action_map.get(choice)
            if handler is None:
                self.message = f"Unknown option: {choice or 'blank'}"
                continue

            try:
                handler()
            except PromptCancelled:
                self.message = "Current action cancelled."
            except InputClosed:
                print("Input closed. Exiting bhipconfig.")
                return 0
            except CommandError as exc:
                # A failing ip/resolvectl query inside a workflow (e.g. while
                # capturing the rollback snapshot) is reported in the event
                # log instead of crashing the whole tool.
                self.message = f"Command failed: {exc}"

    def current_summary(self) -> InterfaceSummary | None:
        """Fresh summary of the selected interface, or None when none is selected or found."""
        interface = self.current_interface
        if not interface:
            return None
        return self.controller.summary_for(interface)

    def managed_addresses(self) -> list[str]:
        """Copy of the selected interface's addresses (empty when unavailable)."""
        summary = self.current_summary()
        if summary is None:
            return []
        return list(summary.addresses)

    def build_main_menu_actions(self) -> list[MenuAction]:
        """Build the main menu entries.

        The menu offers Change/Remove IP when the interface has addresses, or
        Add IP when it has none, followed by the actions that are always
        available.
        """
        if self.managed_addresses():
            address_actions = [
                MenuAction("Change IP", self.action_change_ip),
                MenuAction("Remove IP", self.action_remove_ip),
            ]
        else:
            address_actions = [MenuAction("Add IP", self.action_add_ip)]

        always_available = [
            MenuAction("Set Gateway", self.action_set_gateway),
            MenuAction("Set DNS", self.action_set_dns),
            MenuAction("Interface Manage", self.menu_interface_manage),
            MenuAction("Restart Network", self.action_restart_network),
            MenuAction("Refresh", self.action_refresh),
        ]
        return address_actions + always_available

    def render_status_badges(self, summary: InterfaceSummary | None, default_iface: str | None) -> str:
        """Render the badge row.

        With an interface, the row is: link state, DEFAULT (when it is the
        default-route device), SSH GUARD, resolver mode. Without one, it is
        NO INTERFACE followed by the resolver mode.
        """
        resolver_badge = render_badge(self.controller.resolver_mode.upper(), "secondary")
        if not summary:
            return "  ".join([render_badge("NO INTERFACE", "danger"), resolver_badge])

        state_tone = "success" if summary.state in {"UP", "UNKNOWN"} else "warning"
        badges = [render_badge(summary.state, state_tone)]
        if summary.name == default_iface:
            badges.append(render_badge("DEFAULT", "primary"))
        badges.append(render_badge("SSH GUARD", "success" if is_ssh_session() else "secondary"))
        badges.append(resolver_badge)
        return "  ".join(badges)

    def render_summary_panel(self, summary: InterfaceSummary | None, default_iface: str | None) -> str:
        """Render the "Node Snapshot" panel describing the selected interface."""
        if not summary:
            return render_panel(
                "Node Snapshot",
                [
                    "Interface   unavailable",
                    "Telemetry   no live interface data is available",
                ],
                tone="danger",
            )

        def joined(items: list[str]) -> str:
            return ", ".join(items) if items else "None"

        role = " default route target" if summary.name == default_iface else " standby path"
        lines = [
            f"Interface   {summary.name}{role}",
            f"IP Stack    {joined(summary.addresses)}",
            f"Gateway     {summary.primary_gateway or 'None'}",
            f"Resolvers   {joined(summary.dns_servers)}",
            f"Link Meta   MTU {summary.mtu or 'unknown'} | MAC {summary.mac or 'unknown'}",
        ]
        return render_panel("Node Snapshot", lines, tone="primary")

    def render_action_panel(self, actions: list[MenuAction]) -> str:
        """Render the numbered "Control Matrix" menu with a one-line hint per action."""
        lines = [
            f"[{number}] {action.label:<17} // {ACTION_HINTS.get(action.label, 'network control operation')}"
            for number, action in enumerate(actions, start=1)
        ]
        lines += [
            "",
            "Use the action number to open a workflow. Ctrl+C cancels the current stage.",
        ]
        return render_panel("Control Matrix", lines, tone="secondary")

    def render_message_panel(self) -> str | None:
        """Render the "Event Log" panel for the current message, or None if there is none."""
        message = self.message
        if message:
            return render_panel("Event Log", [message], tone=message_tone(message))
        return None

    def render_screen(self, title: str, body_lines: list[str], *, tone: str = "secondary") -> None:
        """Clear the screen and print the banner followed by one titled panel."""
        clear_screen()
        print(header_box(APP_NAME), end="\n\n")
        print(render_panel(title, body_lines, tone=tone))

    def render(self, actions: list[MenuAction]) -> None:
        """Redraw the main screen.

        Sections, separated by blank lines: banner, badges, interface summary,
        action menu, exit hint, and the event log when a message is pending.
        """
        clear_screen()
        summary = self.current_summary()
        default_iface = self.controller.default_interface_name()
        sections = [
            header_box(APP_NAME),
            self.render_status_badges(summary, default_iface),
            self.render_summary_panel(summary, default_iface),
            self.render_action_panel(actions),
            paint("[0] Exit", BOLD, FG["muted"]),
        ]
        message_panel = self.render_message_panel()
        if message_panel:
            sections.append(message_panel)
        print("\n\n".join(sections))

    def execute_plan(
        self,
        title: str,
        operations: list[Operation],
        *,
        action: str,
        confirm_mode: str = "yesno",
        confirm_value: str | None = None,
        warning_lines: list[str] | None = None,
    ) -> None:
        """Preview ``operations``, confirm with the user, run them, and guard the result.

        Confirmation: ``confirm_mode="typed"`` requires typing ``confirm_value``
        (default ``"CONFIRM"``) exactly; any other mode accepts ``y``/``yes``.

        A snapshot is captured before anything runs. If an operation raises,
        the snapshot is restored.

        For actions that ``guard_required`` flags, a ``RollbackGuard`` is armed
        before any change is made. The user must then confirm within the time
        left on the guard, otherwise the change is rolled back.

        The outcome is reported through ``self.message``.
        """
        if not operations:
            self.message = f"{title}: nothing to do."
            return

        body_lines = [
            f"Interface   {self.current_interface}",
            "",
            "Execution preview:",
            *[f"  - {operation.preview}" for operation in operations],
        ]
        self.render_screen(title, body_lines, tone="primary")
        if warning_lines:
            print()
            print(render_panel("High Risk Notice", warning_lines, tone="danger"))

        if confirm_mode == "typed":
            token = confirm_value or "CONFIRM"
            confirm = prompt_input(f"\nType {token} to apply these changes: ")
            if confirm != token:
                self.message = f"{title} cancelled."
                return
        else:
            confirm = prompt_input("\nApply these changes? [y/N]: ").lower()
            if confirm not in {"y", "yes"}:
                self.message = f"{title} cancelled."
                return

        snapshot = self.controller.capture_snapshot(self.current_interface)
        timeout = self.guard_timeout(action)
        # The guard's daemon is launched after this point with the same
        # timeout, so (assuming it waits `timeout` seconds from its own start
        # -- NOTE(review): confirm) it cannot fire before this deadline.
        guard_deadline = time.monotonic() + timeout
        guard = None
        if self.guard_required(action):
            try:
                guard = RollbackGuard.arm(snapshot, timeout)
            except OSError as exc:
                # Never apply a guarded change without its safety net.
                self.message = f"{title} failed: could not arm the rollback guard: {exc}"
                return

        try:
            for operation in operations:
                operation.executor()
        except Exception as exc:
            error = self._rollback(snapshot, guard)
            if error is None:
                self.message = f"{title} failed and was rolled back: {exc}"
            else:
                self.message = f"{title} failed ({exc}); rollback also failed: {error}"
            return

        if not guard:
            self.message = f"{title} applied successfully."
            return

        # The operations above consumed part of the guard's window. Only offer
        # the time that is left, minus a small margin, so that a "yes" can
        # never arrive after the daemon has already rolled back.
        margin_seconds = 2
        window = int(guard_deadline - time.monotonic()) - margin_seconds
        if window < 1:
            error = self._rollback(snapshot, guard)
            if error is None:
                self.message = f"{title} rolled back because the confirmation window expired while applying."
            else:
                self.message = f"{title} rollback failed: {error}"
            return

        try:
            answer, timed_out = prompt_timeout(
                f"Type 'yes' within {window}s to keep the change; anything else will roll it back: ",
                window,
            )
        except (PromptCancelled, InputClosed):
            error = self._rollback(snapshot, guard)
            self.message = f"{title} rolled back." if error is None else f"{title} rollback failed: {error}"
            return
        if answer.lower() in {"y", "yes", "keep"}:
            guard.cancel()
            self.message = f"{title} applied and confirmed."
            return

        error = self._rollback(snapshot, guard)
        if error is not None:
            self.message = f"{title} rollback failed: {error}"
        elif timed_out:
            self.message = f"{title} rolled back because no confirmation arrived."
        else:
            self.message = f"{title} rolled back."

    def _rollback(self, snapshot: "Snapshot", guard: "RollbackGuard | None") -> str | None:
        """Restore ``snapshot`` locally, then disarm ``guard`` (if any).

        Returns None on success. If the local restore raises, the error text
        is returned and the guard is deliberately left armed, so the detached
        rollback process started by ``RollbackGuard.arm`` still gets its chance
        to restore the snapshot.
        """
        try:
            self.controller.apply_snapshot(snapshot)
        except Exception as exc:
            note = " (rollback guard left armed)" if guard else ""
            return f"{exc}{note}"
        if guard:
            guard.cancel()
        return None

    def guard_required(self, action: str) -> bool:
        """Decide whether *action* must run under a rollback guard.

        The loopback interface ("lo") is never guarded. On any other
        interface, only the action names listed below are guarded. Every
        other name (e.g. "add_ip", "set_dns", "enable_interface") returns
        False.
        """
        guarded_actions = (
            "add_ip_with_gateway",
            "remove_ip",
            "change_ip",
            "set_gateway",
            "disable_interface",
            "restart_network",
        )
        return self.current_interface != "lo" and action in guarded_actions

    def guard_timeout(self, action: str) -> int:
        """Return the rollback-guard window in seconds for *action*.

        Actions that touch routing or link state get the longer
        ``DANGEROUS_GUARD_TIMEOUT``. Everything else gets
        ``DEFAULT_GUARD_TIMEOUT``.
        """
        long_window_actions = {
            "add_ip_with_gateway",
            "disable_interface",
            "restart_network",
            "set_gateway",
        }
        if action not in long_window_actions:
            return DEFAULT_GUARD_TIMEOUT
        return DANGEROUS_GUARD_TIMEOUT

    def prompt_ip_wizard(
        self,
        *,
        title: str,
        existing_addresses: list[str],
        replacing_cidr: str | None = None,
    ) -> IPWizardResult | None:
        """Walk the user through entering a new address plus optional extras.

        Shows the current interface, and also *replacing_cidr* and
        *existing_addresses* when given. Then prompts for:

        * a new IP/CIDR. It is required; blank cancels. It is normalised with
          ``normalize_cidr`` and rejected if already in *existing_addresses*.
        * an optional gateway. It is normalised with ``normalize_ip`` and must
          be the same IP version as the new address.
        * optional DNS servers, parsed with ``parse_dns_servers``.

        A blank gateway or DNS answer leaves that field as ``None``. On any
        cancellation or validation error, ``self.message`` is set and
        ``None`` is returned. Otherwise the collected values are returned.
        """
        summary = [f"Interface   {self.current_interface}"]
        if replacing_cidr:
            summary.append(f"Current IP  {replacing_cidr}")
        if existing_addresses:
            summary.append(f"Existing IPs {' | '.join(existing_addresses)}")
        summary += [
            "",
            "Wizard collects: new IP/CIDR, optional gateway, optional DNS",
            "Blank gateway or DNS keeps the current value unchanged.",
        ]
        self.render_screen(title, summary, tone="secondary")

        # Step 1: the address itself (mandatory).
        cidr_text = prompt_input("New IP/CIDR: ")
        if not cidr_text:
            self.message = f"{title} cancelled."
            return None
        try:
            new_cidr = normalize_cidr(cidr_text)
        except ValueError as err:
            self.message = f"Invalid CIDR: {err}"
            return None
        if new_cidr in existing_addresses:
            self.message = f"{new_cidr} is already assigned on {self.current_interface}."
            return None

        # Step 2: optional gateway, which must share the address family.
        gateway_text = prompt_input("Optional gateway (blank to keep current): ")
        new_gateway: str | None = None
        if gateway_text:
            try:
                new_gateway = normalize_ip(gateway_text)
            except ValueError as err:
                self.message = f"Invalid gateway: {err}"
                return None
            cidr_family = ipaddress.ip_interface(new_cidr).ip.version
            gateway_family = ipaddress.ip_address(new_gateway).version
            if cidr_family != gateway_family:
                self.message = "Gateway family must match the new IP address family."
                return None

        # Step 3: optional resolvers.
        dns_text = prompt_input("Optional DNS servers (space or comma separated, blank to keep current): ")
        new_dns: list[str] | None = None
        if dns_text:
            try:
                new_dns = parse_dns_servers(dns_text)
            except ValueError as err:
                self.message = f"Invalid DNS entry: {err}"
                return None

        return IPWizardResult(cidr=new_cidr, gateway=new_gateway, dns_servers=new_dns)

    def choose_existing_address(self) -> str | None:
        """Pick one of the current interface's addresses for a workflow.

        When exactly one address is assigned, it is returned without
        prompting. Otherwise the addresses are listed and the user may answer
        with a 1-based index, or with a CIDR (normalised via
        ``normalize_cidr``) that must already be assigned.

        Returns None, with ``self.message`` explaining why, when no address
        exists, the prompt is left blank, or the selection is invalid.
        """
        addresses = self.managed_addresses()
        if not addresses:
            self.message = "No IP address is assigned to the current interface."
            return None
        if len(addresses) == 1:
            return addresses[0]
        lines = ["Select the IP target for this workflow.", ""]
        lines.extend(f"[{index}] {address}" for index, address in enumerate(addresses, start=1))
        self.render_screen("Address Selector", lines, tone="secondary")
        raw = prompt_input("Choose address number or type the CIDR: ")
        if not raw:
            # Give feedback like every other prompt in this class does,
            # instead of returning silently.
            self.message = "Address selection cancelled."
            return None
        # isdecimal() (not isdigit()) guarantees int() accepts the text.
        # isdigit() is also true for characters such as "²", which make
        # int() raise ValueError.
        if raw.isdecimal():
            idx = int(raw) - 1
            if 0 <= idx < len(addresses):
                return addresses[idx]
            self.message = "Invalid address selection."
            return None
        try:
            cidr = normalize_cidr(raw)
        except ValueError as exc:
            self.message = f"Invalid CIDR: {exc}"
            return None
        if cidr not in addresses:
            self.message = f"{cidr} is not assigned on {self.current_interface}."
            return None
        return cidr

    def run_replace_flow(self, *, title: str, old_cidr: str) -> None:
        """Swap *old_cidr* on the current interface for a wizard-chosen address.

        The IP wizard is shown with the interface's current addresses. If it
        yields settings, a ``plan_change_ip`` plan is built (carrying the
        optional gateway and DNS servers) and executed as a "change_ip"
        action.
        """
        current_addresses = self.managed_addresses()
        settings = self.prompt_ip_wizard(
            title=title,
            existing_addresses=current_addresses,
            replacing_cidr=old_cidr,
        )
        if settings is None:
            return
        change_plan = self.controller.plan_change_ip(
            self.current_interface,
            old_cidr,
            settings.cidr,
            gateway=settings.gateway,
            dns_servers=settings.dns_servers,
        )
        self.execute_plan(title, change_plan, action="change_ip")

    def action_add_ip(self) -> None:
        """Attach a first address to an interface that currently has none.

        The action is refused (with a message) when the interface already
        has an address. With a gateway, the plan runs as the guarded
        "add_ip_with_gateway" action; otherwise it runs as plain "add_ip".
        """
        addresses = self.managed_addresses()
        if addresses:
            self.message = "Add IP is only available when the selected interface has no IP address."
            return
        # Reuse the (empty) list fetched above instead of querying the
        # interface a second time.
        settings = self.prompt_ip_wizard(title="Add IP", existing_addresses=addresses)
        if not settings:
            return
        action = "add_ip_with_gateway" if settings.gateway else "add_ip"
        operations = self.controller.plan_add_ip(
            self.current_interface,
            settings.cidr,
            gateway=settings.gateway,
            dns_servers=settings.dns_servers,
        )
        self.execute_plan("Add IP", operations, action=action)

    def action_remove_ip(self) -> None:
        """Remove an address from the current interface, preferring a replace.

        After the user picks an address and confirms the removal, they are
        offered (default yes) the chance to configure a replacement first.
        Accepting hands off to the "Change IP" replace flow. Declining leads
        to a bare removal that shows a high-risk warning and requires typing
        ``REMOVE``.
        """
        cidr = self.choose_existing_address()
        if not cidr:
            return
        question = f"Remove {cidr} from {self.current_interface}?"
        if not prompt_yes_no(question, default=False):
            self.message = "Remove IP cancelled."
            return
        wants_replacement = prompt_yes_no("Set up a new IP before removing this one?", default=True)
        if wants_replacement:
            self.run_replace_flow(title="Change IP", old_cidr=cidr)
            return
        removal_plan = self.controller.plan_remove_ip(self.current_interface, cidr)
        self.execute_plan(
            "Remove IP",
            removal_plan,
            action="remove_ip",
            confirm_mode="typed",
            confirm_value="REMOVE",
            warning_lines=[
                f"Warning: removing {cidr} without a replacement may close your SSH connection immediately.",
                "You can lose network access until rollback restores the previous settings.",
            ],
        )

    def action_change_ip(self) -> None:
        """Pick an existing address and run the replace wizard against it."""
        selected = self.choose_existing_address()
        if selected:
            self.run_replace_flow(title="Change IP", old_cidr=selected)

    def action_set_gateway(self) -> None:
        """Prompt for a default gateway and apply it to the current interface.

        A blank answer cancels. A value rejected by ``normalize_ip`` sets an
        error message. Otherwise the ``plan_set_gateway`` plan runs as the
        "set_gateway" action.
        """
        entered = prompt_input("Gateway IP: ")
        if not entered:
            self.message = "Set gateway cancelled."
            return
        try:
            gateway = normalize_ip(entered)
        except ValueError as err:
            self.message = f"Invalid gateway: {err}"
        else:
            gateway_plan = self.controller.plan_set_gateway(self.current_interface, gateway)
            self.execute_plan("Set Gateway", gateway_plan, action="set_gateway")

    def action_set_dns(self) -> None:
        """Prompt for resolver addresses and apply them to the current interface.

        A blank answer cancels. Input rejected by ``parse_dns_servers`` sets
        an error message. Otherwise the ``plan_set_dns`` plan runs as the
        "set_dns" action.
        """
        entered = prompt_input("DNS servers (space or comma separated): ")
        if not entered:
            self.message = "Set DNS cancelled."
            return
        try:
            servers = parse_dns_servers(entered)
        except ValueError as err:
            self.message = f"Invalid DNS entry: {err}"
        else:
            dns_plan = self.controller.plan_set_dns(self.current_interface, servers)
            self.execute_plan("Set DNS", dns_plan, action="set_dns")

    def action_refresh(self) -> None:
        """Re-scan the node while keeping the user's interface selection.

        The current interface is kept as long as ``displayable_interfaces()``
        still lists it. Only when it has disappeared does the selection fall
        back to the controller's default interface. If there is no default,
        the selection stays unchanged.
        """
        # Unconditionally jumping back to the default interface would silently
        # retarget the next (possibly destructive) action at a different link
        # than the one the user explicitly switched to.
        visible = {summary.name for summary in self.controller.displayable_interfaces()}
        if self.current_interface not in visible:
            self.current_interface = self.controller.default_interface_name() or self.current_interface
        self.message = "State refreshed."

    def menu_interface_manage(self) -> None:
        while True:
            lines = [
                f"Interface   {self.current_interface}",
                "",
                "[1] Switch Interface",
                "[2] Enable Interface",
                "[3] Disable Interface",
            ]
            self.render_screen("Interface Matrix", lines, tone="secondary")
            print()
            print(paint("[0] Back", BOLD, FG["muted"]))
            choice = prompt_input("Select an option: ")
            if choice == "1":
                self.action_switch_interface()
            elif choice == "2":
                self.execute_plan(
                    "Enable Interface",
                    self.controller.plan_link_state(self.current_interface, True),
                    action="enable_interface",
                )
            elif choice == "3":
                self.execute_plan(
                    "Disable Interface",
                    self.controller.plan_link_state(self.current_interface, False),
                    action="disable_interface",
                )
            elif choice == "0":
                return
            else:
                self.message = f"Unknown option: {choice or 'blank'}"
                return

    def action_switch_interface(self) -> None:
        """Let the user choose which interface subsequent actions target.

        Lists ``displayable_interfaces()`` with state and addresses, and
        marks the current one with ``<==``. The user answers with a 1-based
        number. Blank cancels. Non-numeric or out-of-range answers set an
        error message and leave the selection unchanged.
        """
        interfaces = self.controller.displayable_interfaces()
        if not interfaces:
            self.message = "No interfaces available to switch to."
            return
        lines = []
        for index, summary in enumerate(interfaces, start=1):
            marker = " <==" if summary.name == self.current_interface else ""
            addresses = ", ".join(summary.addresses) if summary.addresses else "None"
            lines.append(f"[{index}] {summary.name} ({summary.state}) IPs: {addresses}{marker}")
        self.render_screen("Interface Selector", lines, tone="secondary")
        raw = prompt_input("Choose interface: ")
        if not raw:
            self.message = "Switch interface cancelled."
            return
        # isdecimal() (not isdigit()) guarantees int() accepts the text.
        # isdigit() is also true for characters such as "²", which make
        # int() raise ValueError.
        if not raw.isdecimal():
            self.message = "Interface selection must be a number."
            return
        idx = int(raw) - 1
        if not (0 <= idx < len(interfaces)):
            self.message = "Invalid interface selection."
            return
        self.current_interface = interfaces[idx].name
        self.message = f"Switched to {self.current_interface}."

    def action_restart_network(self) -> None:
        """Build the restart plan for the current interface and execute it."""
        restart_plan = self.controller.plan_restart(self.current_interface)
        self.execute_plan("Restart Network", restart_plan, action="restart_network")


def run_rollback_daemon(snapshot_file: str, timeout: int) -> int:
    """Sleep *timeout* seconds, then restore the snapshot if still armed.

    The guard counts as armed while a file named ``active`` exists in the
    same directory as *snapshot_file*. If that file is gone when the sleep
    ends, nothing happens. Otherwise the JSON snapshot is loaded, turned
    into a ``Snapshot`` and applied through a fresh ``NetworkController``.
    The ``active`` marker is removed afterwards whether or not the restore
    succeeded; exceptions still propagate. Returns 0 on normal completion.
    """
    marker = os.path.join(os.path.dirname(snapshot_file), "active")
    time.sleep(timeout)
    if os.path.exists(marker):
        try:
            payload = json.loads(Path(snapshot_file).read_text(encoding="utf-8"))
            snapshot = Snapshot.from_dict(payload)
            NetworkController().apply_snapshot(snapshot)
        finally:
            if os.path.exists(marker):
                os.remove(marker)
    return 0


def ensure_root() -> None:
    """Make sure the process runs with root privileges.

    Returns immediately when the effective UID is 0. Otherwise the current
    process is replaced (``os.execvp``) by ``sudo -E`` re-running this
    script with the same arguments. If ``sudo`` is not on PATH, an
    explanatory message is printed and the program exits with status 1.
    """
    if os.geteuid() != 0:
        sudo_path = shutil.which("sudo")
        if not sudo_path:
            print("bhipconfig needs root privileges. Run it with sudo or as root.")
            raise SystemExit(1)
        os.execvp(sudo_path, [sudo_path, "-E", SCRIPT_PATH, *sys.argv[1:]])


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(prog="bhipconfig", add_help=True)
    parser.add_argument("--rollback-daemon", nargs=2, metavar=("SNAPSHOT", "TIMEOUT"))
    return parser.parse_args()


def main() -> int:
    """Program entry point; returns the process exit status.

    With ``--rollback-daemon SNAPSHOT TIMEOUT``, it runs the rollback
    daemon directly (TIMEOUT is converted with ``int``). Otherwise it
    escalates to root if needed and runs the interactive application.
    """
    daemon_args = parse_args().rollback_daemon
    if daemon_args:
        snapshot_file, timeout_raw = daemon_args
        return run_rollback_daemon(snapshot_file, int(timeout_raw))

    ensure_root()
    return BhipConfigApp().run()


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
