Source code for eopod._eopod_cli

# Copyright 2025 The EasyDeL/eopod Author @erfanzar (Erfan Zare Chavoshi).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import logging
import pathlib
import re
import shlex
import shutil
import subprocess
import time
from datetime import datetime

import click
import yaml
from rich.console import Console
from rich.logging import RichHandler
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Prompt
from rich.table import Table
from rich.theme import Theme

from ._utils import EOPOD_PATH, PYTHON_PATH, EOConfig, TPUManager, async_command, run_command

console = Console(theme=Theme({"info": "cyan", "warning": "yellow", "error": "white", "success": "green"}))

logging.basicConfig(
    level=logging.INFO, format="%(message)s", handlers=[RichHandler(console=console, rich_tracebacks=True)]
)


@click.group()
def cli():
    """eopod - Enhanced TPU Command Runner"""
    pass


def _read_gcloud_config_property(property_name: str) -> str:
    """Read gcloud config property and fail clearly when unset."""
    result = subprocess.run(
        ["gcloud", "config", "get", property_name],
        capture_output=True,
        text=True,
        check=True,
    )
    value = result.stdout.strip()
    if not value or value == "(unset)":
        raise ValueError(f"gcloud config property '{property_name}' is unset")
    return value


def _network_name(network_value: str | None) -> str:
    """Normalize full network URI to simple network name."""
    if not network_value:
        return "default"
    return network_value.rsplit("/", 1)[-1]


def _basename(resource_name: str | None) -> str | None:
    if not resource_name:
        return None
    return resource_name.rsplit("/", 1)[-1]


def _metadata_value(path: str) -> str | None:
    url = f"http://metadata.google.internal/computeMetadata/v1/{path}"
    result = subprocess.run(
        ["curl", "-fsS", "-H", "Metadata-Flavor: Google", url],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None
    value = result.stdout.strip()
    return value or None


def _detect_project_id_from_metadata() -> str | None:
    return _metadata_value("project/project-id")


def _detect_zone_from_metadata() -> str | None:
    zone_output = _metadata_value("instance/zone")
    if not zone_output:
        return None
    zone_match = re.search(r"/zones/([^/]+)", zone_output)
    if zone_match:
        return zone_match.group(1)
    return None


def _detect_self_internal_ip_from_metadata() -> str | None:
    return _metadata_value("instance/network-interfaces/0/ip")


def _list_tpu_vms(project_id: str, zone: str) -> list[dict]:
    cmd = [
        "gcloud",
        "compute",
        "tpus",
        "tpu-vm",
        "list",
        f"--project={project_id}",
        f"--zone={zone}",
        "--format=json(name,queuedResource,state,networkEndpoints.ipAddress)",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(result.stderr.strip() or "Failed to list TPU VMs")
    try:
        data = json.loads(result.stdout.strip() or "[]")
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Failed to parse TPU VM list output: {e}") from e
    return data if isinstance(data, list) else []


def _detect_tpu_identity_from_current_vm(project_id: str, zone: str) -> tuple[str | None, str | None]:
    self_ip = _detect_self_internal_ip_from_metadata()
    if not self_ip:
        return None, None

    # Preferred path: let gcloud do endpoint flatten/filter directly.
    try:
        cmd = [
            "gcloud",
            "compute",
            "tpus",
            "tpu-vm",
            "list",
            f"--project={project_id}",
            f"--zone={zone}",
            "--flatten=networkEndpoints[]",
            f"--filter=networkEndpoints.ipAddress={self_ip}",
            "--format=value(name.basename(),queuedResource.basename())",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode == 0:
            line = next((ln.strip() for ln in result.stdout.splitlines() if ln.strip()), "")
            if line:
                parts = line.split()
                detected_tpu = parts[0] if len(parts) >= 1 else None
                detected_queued = parts[1] if len(parts) >= 2 else None
                return detected_tpu, detected_queued
    except Exception:
        pass

    # Fallback path: parse structured JSON from list output.
    try:
        for node in _list_tpu_vms(project_id, zone):
            endpoints = node.get("networkEndpoints", []) or []
            endpoint_ips = [ep.get("ipAddress") for ep in endpoints if ep.get("ipAddress")]
            if self_ip in endpoint_ips:
                return _basename(node.get("name")), _basename(node.get("queuedResource"))
    except Exception:
        return None, None
    return None, None


def _resolve_tpu_from_queued_resource(project_id: str, zone: str, queued_resource: str) -> str | None:
    queued_resource = _basename(queued_resource) or queued_resource
    try:
        matches = []
        for node in _list_tpu_vms(project_id, zone):
            queued_basename = _basename(node.get("queuedResource"))
            if queued_basename == queued_resource:
                matches.append(node)

        if not matches:
            return None

        active = next((m for m in matches if str(m.get("state", "")).upper() in {"READY", "ACTIVE"}), None)
        selected = active or matches[0]
        return _basename(selected.get("name"))
    except Exception:
        return None


def _describe_tpu_vm(project_id: str, zone: str, tpu_name: str) -> dict | None:
    cmd = [
        "gcloud",
        "compute",
        "tpus",
        "tpu-vm",
        "describe",
        tpu_name,
        f"--project={project_id}",
        f"--zone={zone}",
        "--format=json",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        return None
    try:
        return json.loads(result.stdout.strip() or "{}")
    except json.JSONDecodeError:
        return None


def _normalize_target_tags(tags: list[str] | tuple[str, ...] | None) -> list[str]:
    normalized = []
    seen = set()
    for tag in tags or []:
        value = str(tag).strip()
        if not value or value in seen:
            continue
        seen.add(value)
        normalized.append(value)
    return normalized


def _preferred_tpu_target_tags(tags: list[str] | tuple[str, ...] | None) -> list[str]:
    normalized = _normalize_target_tags(tags)
    if not normalized:
        return []

    # TPU-specific per-node tags are the safest match for firewall rules.
    tpu_specific = [tag for tag in normalized if tag.startswith("tpu-")]
    if tpu_specific:
        return tpu_specific

    non_google = [tag for tag in normalized if not tag.startswith("x-google-")]
    return non_google or normalized


def _detect_instance_tags_from_metadata() -> list[str]:
    raw_tags = _metadata_value("instance/tags")
    if not raw_tags:
        return []

    try:
        parsed = json.loads(raw_tags)
    except json.JSONDecodeError:
        parsed = None

    if isinstance(parsed, list):
        return _preferred_tpu_target_tags(parsed)

    return _preferred_tpu_target_tags([line.strip() for line in raw_tags.splitlines() if line.strip()])


def _project_number_from_resource_name(resource_name: str | None) -> str | None:
    if not resource_name:
        return None

    match = re.search(r"(?:^|/)projects/(\d+)(?:/|$)", resource_name)
    if match:
        return match.group(1)
    return None


def _read_project_number(project_id: str) -> str | None:
    result = subprocess.run(
        ["gcloud", "projects", "describe", project_id, "--format=value(projectNumber)"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None
    value = result.stdout.strip()
    return value or None


def _describe_firewall_rule(project_id: str, rule_name: str) -> dict | None:
    result = subprocess.run(
        [
            "gcloud",
            "compute",
            "firewall-rules",
            "describe",
            rule_name,
            f"--project={project_id}",
            "--format=json",
        ],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return None
    try:
        return json.loads(result.stdout.strip() or "{}")
    except json.JSONDecodeError:
        return None


def _resolve_tpu_vm_target_tags(
    project_id: str,
    zone: str,
    tpu_name: str,
    tpu_info: dict | None = None,
) -> list[str]:
    tpu_info = tpu_info or _describe_tpu_vm(project_id, zone, tpu_name)
    if not tpu_info:
        return []

    direct_tags = _preferred_tpu_target_tags(
        tpu_info.get("tags") or tpu_info.get("networkConfig", {}).get("networkTags", [])
    )
    if direct_tags:
        return direct_tags

    project_number = (
        _project_number_from_resource_name(tpu_info.get("queuedResource"))
        or _project_number_from_resource_name(tpu_info.get("name"))
        or _read_project_number(project_id)
    )
    tpu_id = str(tpu_info.get("id") or "").strip()
    if project_number and tpu_id:
        healthcheck_rule = _describe_firewall_rule(
            project_id,
            f"tpu-{project_number}-{tpu_id}-healthcheck-fw",
        )
        healthcheck_tags = _preferred_tpu_target_tags((healthcheck_rule or {}).get("targetTags", []))
        if healthcheck_tags:
            return healthcheck_tags

    detected_tpu, _ = _detect_tpu_identity_from_current_vm(project_id, zone)
    if detected_tpu == tpu_name:
        metadata_tags = _detect_instance_tags_from_metadata()
        if metadata_tags:
            return metadata_tags

    return []


def _resolve_runtime_credentials(require_tpu: bool, verbose: bool = False):
    config = EOConfig()
    project_id, zone, tpu_name = config.get_credentials()
    queued_resource = config.get_queued_resource()
    original = (project_id, zone, tpu_name, queued_resource)

    if not project_id:
        project_id = _detect_project_id_from_metadata()
        if project_id and verbose:
            console.print(f"[yellow]Auto-detected project ID from metadata: {project_id}[/yellow]")
    if not zone:
        zone = _detect_zone_from_metadata()
        if zone and verbose:
            console.print(f"[yellow]Auto-detected zone from metadata: {zone}[/yellow]")

    if not project_id:
        try:
            project_id = _read_gcloud_config_property("project")
            if verbose:
                console.print(f"[yellow]Using gcloud config project: {project_id}[/yellow]")
        except Exception:
            pass
    if not zone:
        try:
            zone = _read_gcloud_config_property("compute/zone")
            if verbose:
                console.print(f"[yellow]Using gcloud config zone: {zone}[/yellow]")
        except Exception:
            pass

    if project_id and zone:
        if not tpu_name:
            detected_tpu, detected_queued = _detect_tpu_identity_from_current_vm(project_id, zone)
            if detected_tpu:
                tpu_name = detected_tpu
                queued_resource = queued_resource or detected_queued
                if verbose:
                    console.print(f"[yellow]Auto-detected TPU name: {tpu_name}[/yellow]")
            elif queued_resource:
                resolved_tpu = _resolve_tpu_from_queued_resource(project_id, zone, queued_resource)
                if resolved_tpu:
                    tpu_name = resolved_tpu
                    if verbose:
                        console.print(f"[yellow]Resolved TPU name from queued resource: {tpu_name}[/yellow]")
        else:
            tpu_info = _describe_tpu_vm(project_id, zone, tpu_name)
            if tpu_info:
                current_queued = _basename(tpu_info.get("queuedResource"))
                if current_queued:
                    queued_resource = current_queued
            else:
                resolved_tpu = None
                if queued_resource:
                    resolved_tpu = _resolve_tpu_from_queued_resource(project_id, zone, queued_resource)
                if not resolved_tpu:
                    detected_tpu, detected_queued = _detect_tpu_identity_from_current_vm(project_id, zone)
                    resolved_tpu = detected_tpu
                    queued_resource = queued_resource or detected_queued
                if resolved_tpu:
                    if verbose:
                        console.print(
                            f"[yellow]Configured TPU '{tpu_name}' is stale; using '{resolved_tpu}' instead.[/yellow]"
                        )
                    tpu_name = resolved_tpu

    has_required = all([project_id, zone, tpu_name]) if require_tpu else all([project_id, zone])
    if not has_required:
        return None

    if (project_id, zone, tpu_name, queued_resource) != original and project_id and zone and tpu_name:
        config.set_credentials(project_id, zone, tpu_name, queued_resource=queued_resource)
        config.save_config()

    return config, project_id, zone, tpu_name, queued_resource


def _get_config_and_manager():
    resolved = _resolve_runtime_credentials(require_tpu=True, verbose=True)
    if not resolved:
        raise click.ClickException(
            "Could not resolve project/zone/tpu automatically. Run 'eopod configure' or pass values explicitly."
        )
    config, project_id, zone, tpu_name, _queued_resource = resolved
    tpu_manager = TPUManager(project_id, zone, tpu_name)
    return config, tpu_manager


@cli.command()
@click.option("--project-id", help="Google Cloud Project ID (optional if running on GCP)")
@click.option("--zone", help="Google Cloud Zone (optional if running on GCP)")
@click.option("--tpu-name", required=False, help="TPU Name (auto-detected on TPU VM if omitted)")
def configure(project_id, zone, tpu_name):
    """Configure eopod with your Google Cloud details"""
    config = EOConfig()
    stored_project_id, stored_zone, stored_tpu_name = config.get_credentials()
    queued_resource = config.get_queued_resource()

    if not project_id:
        project_id = _detect_project_id_from_metadata()
        if project_id:
            console.print(f"[yellow]Auto-detected project ID: {project_id}[/yellow]")
    if not project_id:
        try:
            project_id = _read_gcloud_config_property("project")
            console.print(f"[yellow]Using gcloud config project: {project_id}[/yellow]")
        except Exception:
            pass

    if not project_id and stored_project_id:
        project_id = stored_project_id
        console.print(f"[yellow]Using saved project ID: {project_id}[/yellow]")

    if not project_id:
        console.print("[red]Failed to auto-detect project ID. Please provide it manually.[/red]")
        return

    if not zone:
        zone = _detect_zone_from_metadata()
        if zone:
            console.print(f"[yellow]Auto-detected zone: {zone}[/yellow]")
    if not zone:
        try:
            zone = _read_gcloud_config_property("compute/zone")
            console.print(f"[yellow]Using gcloud config zone: {zone}[/yellow]")
        except Exception:
            pass

    if not zone and stored_zone:
        zone = stored_zone
        console.print(f"[yellow]Using saved zone: {zone}[/yellow]")

    if not zone:
        console.print("[red]Failed to auto-detect zone. Please provide it manually.[/red]")
        return

    if not tpu_name:
        detected_tpu, detected_queued = _detect_tpu_identity_from_current_vm(project_id, zone)
        if detected_tpu:
            tpu_name = detected_tpu
            queued_resource = detected_queued or queued_resource
            console.print(f"[yellow]Auto-detected TPU name: {tpu_name}[/yellow]")
        elif queued_resource:
            resolved_tpu = _resolve_tpu_from_queued_resource(project_id, zone, queued_resource)
            if resolved_tpu:
                tpu_name = resolved_tpu
                console.print(f"[yellow]Resolved TPU name from queued resource: {tpu_name}[/yellow]")
        elif stored_tpu_name:
            tpu_name = stored_tpu_name
            console.print(f"[yellow]Using saved TPU name: {tpu_name}[/yellow]")

    if not tpu_name:
        console.print("[red]Failed to auto-detect TPU name. Please provide --tpu-name manually.[/red]")
        return

    tpu_info = _describe_tpu_vm(project_id, zone, tpu_name)
    if not tpu_info and queued_resource:
        resolved_tpu = _resolve_tpu_from_queued_resource(project_id, zone, queued_resource)
        if resolved_tpu and resolved_tpu != tpu_name:
            console.print(f"[yellow]Saved TPU name '{tpu_name}' is stale; using '{resolved_tpu}'.[/yellow]")
            tpu_name = resolved_tpu
            tpu_info = _describe_tpu_vm(project_id, zone, tpu_name)

    if tpu_info:
        current_queued = _basename(tpu_info.get("queuedResource"))
        if current_queued:
            queued_resource = current_queued
    else:
        console.print(f"[yellow]Could not verify TPU '{tpu_name}' in {zone}. Saving configuration as provided.[/yellow]")

    config.set_credentials(project_id, zone, tpu_name, queued_resource=queued_resource)
    config.save_config()
    console.print("[green]Configuration saved successfully![/green]")
    if queued_resource:
        console.print(f"[green]Queued resource: {queued_resource}[/green]")


async def _install_package_uv(packages, uv_location):
    """
    Install one or more Python packages via uv on TPU workers.

    Example:
        install-package-uv torch numpy
    """
    _, tpu_manager = _get_config_and_manager()

    if uv_location is None:
        uv_location = str(pathlib.Path().home() / ".local" / "bin" / "uv")

    packages_str = " ".join(packages)

    cmd = f"{uv_location} pip install --python {PYTHON_PATH} {packages_str}"
    await tpu_manager.execute_command(cmd)


@cli.command()
@click.argument("packages", nargs=-1, required=True)
@click.option(
    "--uv-location",
    default=None,
    help="Path to uv executable (default: ~/.local/bin/uv)",
)
@async_command
async def install_package_uv(packages, uv_location):
    """
    Install one or more Python packages via uv on TPU workers.

    Example:
        install-package-uv torch numpy
    """
    await _install_package_uv(packages, uv_location)


@cli.command()
@async_command
async def get_internal_ips():
    """Get internal IP addresses of TPU workers."""
    _config, tpu_manager = _get_config_and_manager()
    try:
        internal_ips = await tpu_manager.get_internal_ips()
        tpu_manager.display_ips(internal_ips, "internal", output_format="comma")
    except Exception as e:
        console.print(f"[red]Failed to get internal IPs: {e!s}[/red]")


@cli.command()
@async_command
async def get_external_ips():
    """Get external IP addresses of TPU workers."""

    _config, tpu_manager = _get_config_and_manager()
    try:
        external_ips = await tpu_manager.get_external_ips()
        console.print(external_ips)
    except Exception as e:
        console.print(f"[red]Failed to get external IPs: {e!s}[/red]")


@cli.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.argument("cmd_args", nargs=-1, type=click.UNPROCESSED)
@click.option("--worker", default="all", help='Specific worker or "all"')
@click.option("--retry", default=1, help="Number of retries for failed commands")
@click.option("--delay", default=5, help="Delay between retries in seconds")
@click.option("--timeout", default=-1, help="Command timeout in seconds")
@click.option("--no-stream", is_flag=True, help="Disable output streaming")
@click.option("--background", is_flag=True, help="Run command in background")
@async_command
async def run(cmd_args, worker, retry, delay, timeout, no_stream, background):
    """Run a command on TPU VM with advanced features"""
    if not cmd_args:
        console.print("[red]No command provided[/red]")
        return

    command = " ".join(cmd_args)
    stream = not no_stream
    if timeout == -1:
        timeout = None

    config, tpu_manager = _get_config_and_manager()
    start_time = datetime.now()
    console.print(f"[cyan]Started at: {start_time.strftime('%Y-%m-%d %H:%M:%S')}[/cyan]")
    console.print(f"[cyan]Executing: {command}[/cyan]")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        disable=stream,
    ) as progress:
        task = progress.add_task(description=f"Executing command: {command}", total=None)

        for attempt in range(1, retry + 1):
            try:
                returncode, stdout, stderr = await asyncio.wait_for(
                    tpu_manager.execute_command(command, worker, stream=stream, background=background),
                    timeout=timeout,
                )

                if returncode == 0:
                    if not stream and not background:
                        progress.update(task, description="[green]Command completed successfully![/green]")
                        console.print("\nOutput:")
                        console.print(stdout)

                    end_time = datetime.now()
                    duration = end_time - start_time
                    console.print(f"[cyan]Duration: {duration}[/cyan]")

                    config.save_command_history(command, "success", stdout if not stream else "Streamed output")
                    break
                else:
                    progress.update(task, description=f"[red]Attempt {attempt} failed[/red]")
                    console.print(f"[red]Error: {stderr}[/red]")
                    config.save_error_log(command, stderr)

            except TimeoutError:
                console.print(f"[red]Command timed out after {timeout} seconds[/red]")
                config.save_error_log(command, "Command timed out")
            except Exception as e:
                console.print(f"[red]Error: {e!s}[/red]")
                config.save_error_log(command, str(e))
                break

            if attempt < retry:
                await asyncio.sleep(delay)


@cli.command()
@click.argument("pid_args", nargs=-1)
@click.option("--worker", default="all", help='Specific worker or "all"')
@async_command
async def check_background(pid_args, worker):
    """Check status of background processes"""

    _config, tpu = _get_config_and_manager()

    if pid_args:
        pids = " ".join(pid_args)
        command = f"ps -p {pids} -f"
    else:
        command = "ps aux | grep nohup | grep -v grep"

    returncode, stdout, stderr = await tpu.execute_command(command, worker)

    if returncode == 0:
        console.print("[green]Background Processes:[/green]")
        console.print(stdout)
    else:
        console.print(f"[red]Error checking background processes:[/red] {stderr}")


@cli.command()
@async_command
async def setup_path():
    """Add ~/.local/bin to PATH on all TPU workers if not already present"""
    config, tpu = _get_config_and_manager()

    path_command = """
    if [[ ":$PATH:" != *":$HOME/.local/bin:"* ]]; then
        echo 'export PATH=$PATH:$HOME/.local/bin' >> ~/.bashrc
        source ~/.bashrc
        echo "[success] Added ~/.local/bin to PATH"
    else
        echo "[info] ~/.local/bin is already in PATH"
    fi
    """

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    ) as progress:
        task = progress.add_task(description="Adding ~/.local/bin to PATH on all workers...", total=None)

        try:
            returncode, stdout, stderr = await tpu.execute_command(path_command, worker="all", stream=False)

            if returncode == 0:
                progress.update(task, description="[green]Successfully updated PATH on all workers[/green]")
                console.print("\nDetailed results:")
                console.print(stdout)
            else:
                progress.update(task, description=f"[red]Failed to update PATH: {stderr}[/red]")

        except Exception as e:
            progress.update(task, description=f"[red]Error: {e!s}[/red]")
            config.save_error_log("add_local_bin_to_path", str(e))


@cli.command()
@click.argument("pid_args", nargs=-1, required=True)
@click.option("--worker", default="all", help='Specific worker or "all"')
@click.option("--force", is_flag=True, help="Force kill the process")
@async_command
async def kill(pid_args, worker, force):
    """Kill a background process"""
    pids = " ".join(pid_args)
    _config, tpu = _get_config_and_manager()

    signal = "-9" if force else "-15"
    command = f"kill {signal} {pids}"

    returncode, _stdout, stderr = await tpu.execute_command(command, worker)

    if returncode == 0:
        console.print(f"[green]Successfully {'force ' if force else ''}killed process(es) {pids}[/green]")
    else:
        console.print(f"[red]Error killing process(es):[/red] {stderr}")


@cli.command()
@async_command
async def status():
    """Show TPU status and information"""
    _config, tpu = _get_config_and_manager()
    try:
        status = await tpu.get_status()
        network = _network_name(status.get("networkConfig", {}).get("network") or status.get("network"))

        table = Table(title="TPU Status")
        table.add_column("Property")
        table.add_column("Value")

        table.add_row("Name", status.get("name", ""))
        table.add_row("State", status.get("state", ""))
        table.add_row("Type", status.get("acceleratorType", ""))
        table.add_row("Network", network)
        table.add_row("API Version", status.get("apiVersion", ""))

        console.print(table)

    except RuntimeError as e:
        console.print(f"[red]{e}[/red]")


@cli.command()
@click.option("--worker", default="all", help='Specific worker or "all"')
@click.option("--force", is_flag=True, help="Force kill all processes")
@click.option("--pid", multiple=True, type=int, help="Specific PIDs to kill")
@async_command
async def kill_tpu(worker, force, pid):
    """Kill processes using TPU resources"""
    config, tpu = _get_config_and_manager()

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    ) as progress:
        task = progress.add_task(description="Scanning for TPU processes...", total=None)

        try:
            status = await tpu.get_status()

            worker_count = 1
            if "networkEndpoints" in status:
                worker_count = len(status["networkEndpoints"])

            workers = range(worker_count) if worker == "all" else [int(worker)]

            check_process_cmd = (
                "ps aux | grep -E 'python|jax|tensorflow' | "
                "grep -v grep | awk '{print $2}' | "
                "while read pid; do "
                "  if [ -d /proc/$pid ] && grep -q 'accel' /proc/$pid/maps 2>/dev/null; then "
                "    echo $pid;"
                "  fi; "
                "done"
            )

            async def scan_worker(w):
                returncode, stdout, _stderr = await tpu.execute_command(
                    check_process_cmd,
                    worker=str(w),
                    stream=False,
                )
                if returncode == 0 and stdout.strip():
                    pids = [int(p.strip()) for p in stdout.splitlines() if p.strip()]
                    return w, pids
                return w, []

            tasks = [scan_worker(w) for w in workers]
            results = await asyncio.gather(*tasks)

            worker_processes = {w: pids for w, pids in results if pids}

            if not worker_processes:
                console.print("[green]No TPU processes found.[/green]")
                return

            console.print("\n[yellow]Found TPU processes:[/yellow]")
            for w, pids in worker_processes.items():
                console.print(f"Worker {w}: PIDs {', '.join(map(str, pids))}")

            if pid:
                filtered_processes = {}
                for w, pids in worker_processes.items():
                    matching_pids = [p for p in pids if p in pid]
                    if matching_pids:
                        filtered_processes[w] = matching_pids
                worker_processes = filtered_processes

            if not force:
                if not click.confirm("[yellow]Do you want to kill these processes?[/yellow]"):
                    return

            async def kill_worker_processes(w, pids):
                results = []
                for pid in pids:
                    kill_cmd = f"kill {'-9' if force else ''} {pid}"
                    returncode, _stdout, stderr = await tpu.execute_command(kill_cmd, worker=str(w), stream=False)
                    results.append((pid, returncode == 0, stderr))
                return w, results

            kill_tasks = [kill_worker_processes(w, pids) for w, pids in worker_processes.items()]
            kill_results = await asyncio.gather(*kill_tasks)

            for w, results in kill_results:
                for pid, success, error in results:
                    if success:
                        console.print(f"[green]Successfully killed process {pid} on worker {w}[/green]")
                    else:
                        console.print(f"[red]Failed to kill process {pid} on worker {w}: {error}[/red]")
            cleanup_commands = [
                "sudo rm -f /tmp/libtpu_lockfile",
                "sudo rmmod tpu || true",
                "sudo modprobe tpu || true",
            ]

            async def cleanup_worker(w):
                results = []
                for cmd in cleanup_commands:
                    returncode, _stdout, stderr = await tpu.execute_command(cmd, worker=str(w), stream=False)
                    results.append((cmd, returncode == 0, stderr))
                return w, results

            cleanup_tasks = [cleanup_worker(w) for w in worker_processes.keys()]
            cleanup_results = await asyncio.gather(*cleanup_tasks)

            for w, results in cleanup_results:  # noqa
                progress.update(task, description=f"Cleaned up TPU resources on worker {w}")

            progress.update(task, description="Verifying TPU status...")
            final_status = await tpu.get_status()
            console.print(f"[blue]Current TPU Status: {final_status.get('state', 'Unknown')}[/blue]")

        except Exception as e:
            console.print(f"[red]Error during TPU process cleanup: {e!s}[/red]")
            config.save_error_log("kill_tpu", str(e))


async def _execute_terminal_command(project_id, zone, tpu_name, command, worker):
    """Helper function to execute commands in terminal mode"""
    try:
        tpu = TPUManager(project_id, zone, tpu_name)

        # Show a simple spinner while executing
        with Progress(SpinnerColumn(), TextColumn("Executing..."), console=console) as progress:
            task = progress.add_task("exec", total=None)
            returncode, stdout, stderr = await tpu.execute_command(command, worker, stream=False)
            progress.remove_task(task)

        if returncode == 0:
            if stdout.strip():
                console.print(stdout)
        else:
            console.print(f"[red]Command failed (exit code {returncode})[/red]")
            if stderr.strip():
                console.print(f"[red]{stderr}[/red]")

    except Exception as e:
        console.print(f"[red]Error executing command: {e}[/red]")


async def _show_status_async(project_id, zone, tpu_name):
    """Helper function to show TPU status in terminal"""
    try:
        tpu = TPUManager(project_id, zone, tpu_name)
        status = await tpu.get_status()
        network = _network_name(status.get("networkConfig", {}).get("network") or status.get("network"))

        table = Table(title="TPU Status")
        table.add_column("Property", style="cyan")
        table.add_column("Value", style="white")

        table.add_row("Name", status.get("name", ""))
        table.add_row("State", status.get("state", ""))
        table.add_row("Type", status.get("acceleratorType", ""))
        table.add_row("Network", network)

        console.print(table)
    except Exception as e:
        console.print(f"[red]Error fetching status: {e}[/red]")


@cli.command()
def history():
    """Show command execution history"""
    config = EOConfig()

    if not config.history_file.exists():
        console.print("No command history found.")
        return

    with open(config.history_file, "r") as f:
        history = yaml.safe_load(f) or []

    table = Table(title="Command History")
    table.add_column("Timestamp")
    table.add_column("Command")
    table.add_column("Status")
    table.add_column("Output (truncated)")

    for entry in history[-15:]:
        table.add_row(entry["timestamp"], entry["command"], entry["status"], entry["output"])

    console.print(table)


@cli.command()
def show_config():
    """Show current configuration"""
    resolved = _resolve_runtime_credentials(require_tpu=False, verbose=False)
    if resolved:
        _config, project_id, zone, tpu_name, queued_resource = resolved
    else:
        config = EOConfig()
        project_id, zone, tpu_name = config.get_credentials()
        queued_resource = config.get_queued_resource()

    if project_id and zone:
        table = Table(title="Current Configuration")
        table.add_column("Setting")
        table.add_column("Value")

        table.add_row("Project ID", project_id)
        table.add_row("Zone", zone)
        table.add_row("TPU Name", tpu_name or "(unset)")
        table.add_row("Queued Resource", queued_resource or "(unset)")

        console.print(table)
    else:
        console.print("[red]No configuration found. Please run 'eopod configure' first.[/red]")


@cli.command()
def doctor():
    """Run environment diagnostics for eopod and TPU access."""
    checks = []

    def add_check(name: str, status: str, details: str, fix: str = ""):
        checks.append({"name": name, "status": status, "details": details, "fix": fix})

    gcloud_path = shutil.which("gcloud")
    if gcloud_path:
        add_check("gcloud", "PASS", f"Found at {gcloud_path}")
    else:
        add_check("gcloud", "FAIL", "gcloud CLI not found on PATH", "Install Google Cloud CLI and re-run doctor")

    active_account = None
    if gcloud_path:
        auth_result = subprocess.run(
            ["gcloud", "auth", "list", "--filter=status:ACTIVE", "--format=value(account)"],
            capture_output=True,
            text=True,
        )
        if auth_result.returncode == 0 and auth_result.stdout.strip():
            active_account = auth_result.stdout.strip().splitlines()[0]
            add_check("Auth", "PASS", f"Active account: {active_account}")
        else:
            add_check("Auth", "FAIL", "No active gcloud account", "Run: gcloud auth login")

    resolved = _resolve_runtime_credentials(require_tpu=False, verbose=False)
    if resolved:
        _config, project_id, zone, tpu_name, queued_resource = resolved
    else:
        cfg = EOConfig()
        project_id, zone, tpu_name = cfg.get_credentials()
        queued_resource = cfg.get_queued_resource()

    if project_id:
        add_check("Project", "PASS", f"Project ID: {project_id}")
    else:
        add_check(
            "Project",
            "FAIL",
            "Project ID is unset",
            "Run: eopod configure --project-id <PROJECT_ID> (or set gcloud project)",
        )

    if zone:
        add_check("Zone", "PASS", f"Zone: {zone}")
    else:
        add_check("Zone", "FAIL", "Zone is unset", "Run: eopod configure --zone <ZONE> (or set gcloud compute/zone)")

    if project_id and gcloud_path:
        api_result = subprocess.run(
            [
                "gcloud",
                "services",
                "list",
                "--enabled",
                f"--project={project_id}",
                "--filter=config.name:tpu.googleapis.com",
                "--format=value(config.name)",
            ],
            capture_output=True,
            text=True,
        )
        if api_result.returncode == 0 and "tpu.googleapis.com" in api_result.stdout:
            add_check("TPU API", "PASS", "tpu.googleapis.com is enabled")
        else:
            add_check(
                "TPU API",
                "FAIL",
                "tpu.googleapis.com not enabled (or not accessible)",
                f"Run: gcloud services enable tpu.googleapis.com --project={project_id}",
            )

    if tpu_name:
        detail = f"TPU name: {tpu_name}"
        if queued_resource:
            detail += f" (queued: {queued_resource})"
        add_check("TPU Name", "PASS", detail)
    else:
        add_check(
            "TPU Name",
            "FAIL",
            "TPU name is unset",
            "Run: eopod configure --tpu-name <TPU_NAME> (or run from within TPU VM for auto-detect)",
        )

    tpu_reachable = False
    if project_id and zone and tpu_name and gcloud_path:
        tpu_info = _describe_tpu_vm(project_id, zone, tpu_name)
        if tpu_info:
            state = tpu_info.get("state", "UNKNOWN")
            add_check("TPU Describe", "PASS", f"Reachable via API (state: {state})")
            tpu_reachable = True
        else:
            add_check(
                "TPU Describe",
                "FAIL",
                "Failed to describe TPU VM (missing permissions, wrong zone/name, or stale config)",
                "Re-run: eopod configure (will auto-refresh stale TPU names when possible)",
            )

    if tpu_reachable and project_id and zone and tpu_name:
        ssh_result = subprocess.run(
            [
                "gcloud",
                "compute",
                "tpus",
                "tpu-vm",
                "ssh",
                tpu_name,
                f"--project={project_id}",
                f"--zone={zone}",
                "--worker=0",
                "--command=true",
            ],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if ssh_result.returncode == 0:
            add_check("TPU SSH", "PASS", "Non-interactive SSH command succeeded")
        else:
            ssh_error = (
                (ssh_result.stderr or ssh_result.stdout).strip().splitlines()[-1]
                if (ssh_result.stderr or ssh_result.stdout)
                else "SSH command failed"
            )
            add_check(
                "TPU SSH",
                "WARN",
                ssh_error,
                "Check VM networking/firewall, IAM, and SSH key setup. Re-run: gcloud compute tpus tpu-vm ssh ...",
            )

    table = Table(title="eopod Doctor")
    table.add_column("Check", style="cyan")
    table.add_column("Status")
    table.add_column("Details")
    table.add_column("Suggested Fix")

    status_color = {"PASS": "green", "WARN": "yellow", "FAIL": "red"}
    for item in checks:
        color = status_color.get(item["status"], "white")
        table.add_row(
            item["name"],
            f"[{color}]{item['status']}[/{color}]",
            item["details"],
            item["fix"],
        )

    pass_count = sum(1 for item in checks if item["status"] == "PASS")
    warn_count = sum(1 for item in checks if item["status"] == "WARN")
    fail_count = sum(1 for item in checks if item["status"] == "FAIL")

    console.print(table)
    console.print(
        f"[cyan]Summary:[/cyan] [green]{pass_count} pass[/green], "
        f"[yellow]{warn_count} warn[/yellow], [red]{fail_count} fail[/red]"
    )


@cli.command()
@click.option("--worker", default="all", help='Specific worker or "all"')
@click.option("--shell", default="/bin/bash", help="Shell to use (default: /bin/bash)")
def terminal(worker, shell):
    """Open an interactive terminal session with TPU workers"""
    _ = shell
    resolved = _resolve_runtime_credentials(require_tpu=True, verbose=True)
    if not resolved:
        console.print("[red]Could not resolve project/zone/tpu automatically. Run 'eopod configure' first.[/red]")
        return
    _config, project_id, zone, tpu_name, _queued_resource = resolved

    # Show welcome message
    welcome_panel = Panel.fit(
        f"[bold green]TPU Interactive Terminal[/bold green]\n"
        f"[cyan]TPU:[/cyan] {tpu_name}\n"
        f"[cyan]Worker:[/cyan] {worker}\n"
        f"[cyan]Zone:[/cyan] {zone}\n\n"
        f"[yellow]Commands:[/yellow]\n"
        f"  [bold]exit[/bold] or [bold]quit[/bold] - Exit terminal\n"
        f"  [bold]:help[/bold] - Show help\n"
        f"  [bold]:status[/bold] - Show TPU status\n"
        f"  [bold]:worker <num>[/bold] - Switch to specific worker\n"
        f"  [bold]:background <cmd>[/bold] - Run command in background\n",
        title="Welcome",
        border_style="blue",
    )
    console.print(welcome_panel)

    current_worker = worker

    while True:
        try:
            # Create a rich prompt
            prompt_text = (
                f"[bold green]eopod[/bold green]:[bold blue]{tpu_name}[/bold blue]:"
                f"[bold yellow]worker-{current_worker}[/bold yellow]$ "
            )
            command = Prompt.ask(prompt_text, console=console)

            if not command.strip():
                continue

            # Handle special commands
            if command.lower() in ["exit", "quit"]:
                console.print("[yellow]Goodbye![/yellow]")
                break
            elif command.startswith(":help"):
                help_panel = Panel.fit(
                    "[bold yellow]Available Commands:[/bold yellow]\n\n"
                    "[bold]Regular commands[/bold] - Execute directly on TPU\n"
                    "[bold]:help[/bold] - Show this help\n"
                    "[bold]:status[/bold] - Show current TPU status\n"
                    "[bold]:worker <num>[/bold] - Switch to specific worker (or 'all')\n"
                    "[bold]:background <cmd>[/bold] - Run command in background\n"
                    "[bold]:history[/bold] - Show recent command history\n"
                    "[bold]:clear[/bold] - Clear screen\n"
                    "[bold]exit/quit[/bold] - Exit terminal\n",
                    title="Help",
                    border_style="yellow",
                )
                console.print(help_panel)
            elif command.startswith(":status"):
                # Run async status command
                asyncio.run(_show_status_async(project_id, zone, tpu_name))
            elif command.startswith(":worker"):
                parts = command.split()
                if len(parts) == 2:
                    new_worker = parts[1]
                    current_worker = new_worker
                    console.print(f"[green]Switched to worker: {current_worker}[/green]")
                else:
                    console.print("[red]Usage: :worker <worker_num|all>[/red]")
            elif command.startswith(":background"):
                bg_command = command[11:].strip()  # Remove ':background '
                if bg_command:
                    console.print(f"[yellow]Running in background: {bg_command}[/yellow]")
                    asyncio.run(_execute_background_command(project_id, zone, tpu_name, bg_command, current_worker))
                else:
                    console.print("[red]Usage: :background <command>[/red]")
            elif command.startswith(":history"):
                _show_history()
            elif command.startswith(":clear"):
                console.clear()
            else:
                # Execute regular command on TPU
                console.print(f"[cyan]Executing on worker {current_worker}: {command}[/cyan]")
                asyncio.run(_execute_terminal_command(project_id, zone, tpu_name, command, current_worker))

        except KeyboardInterrupt:
            console.print("\n[yellow]Use 'exit' or 'quit' to leave the terminal[/yellow]")
        except EOFError:
            console.print("\n[yellow]Goodbye![/yellow]")
            break
        except Exception as e:
            console.print(f"[red]Error: {e}[/red]")


async def _execute_background_command(project_id, zone, tpu_name, command, worker):
    """Helper function to execute background commands"""
    try:
        tpu = TPUManager(project_id, zone, tpu_name)
        returncode, pid, stderr = await tpu.execute_command(command, worker, background=True)

        if returncode == 0:
            console.print(f"[green]Background process started with PID: {pid}[/green]")
        else:
            console.print(f"[red]Failed to start background process: {stderr}[/red]")

    except Exception as e:
        console.print(f"[red]Error starting background process: {e}[/red]")


def _show_history():
    """Helper function to show command history in terminal"""
    config = EOConfig()

    if not config.history_file.exists():
        console.print("[yellow]No command history found.[/yellow]")
        return

    with open(config.history_file, "r") as f:
        history = yaml.safe_load(f) or []

    if not history:
        console.print("[yellow]No command history found.[/yellow]")
        return

    table = Table(title="Recent Command History")
    table.add_column("Time", style="cyan")
    table.add_column("Command", style="white")
    table.add_column("Status", style="green")

    # Show last 10 commands
    for entry in history[-10:]:
        timestamp = entry["timestamp"].split("T")[1][:8]  # Show only time part
        table.add_row(timestamp, entry["command"][:50], entry["status"])

    console.print(table)


@cli.command()
@async_command
async def smi():
    """Show TPU utilization (like nvidia-smi)"""
    _config, tpu = _get_config_and_manager()

    try:
        _, text, _ = await tpu.execute_command(
            f'{PYTHON_PATH} -c "from tpu_info import cli; cli.print_chip_info()"',
            stream=False,
        )

        pattern = r"│\s+(\d+)\s+│\s+([\d.]+ GiB / [\d.]+ GiB)\s+│\s+([\d.]+%)\s+│"
        matches = re.findall(pattern, text)

        if matches:
            table = Table(title="[bold magenta]TPU Utilization[/bold magenta]")
            table.add_column("📟 Device", justify="center", style="bold blue")
            table.add_column("💾 Memory Usage", justify="left", style="white")
            table.add_column("âš¡ Duty Cycle", justify="right", style="white")

            for device_index, memory_usage, duty_cycle in matches:
                table.add_row(device_index, memory_usage, duty_cycle)

            console.print(table)
        else:
            console.print("[yellow]Could not parse TPU utilization data[/yellow]")
            console.print(text)  # Show raw output

    except Exception as e:
        console.print(f"[red]Error getting TPU utilization: {e}[/red]")


@cli.command()
@async_command
async def clean_logs():
    """Clean up logs and temporary files on the TPU VM"""
    _config, tpu = _get_config_and_manager()
    command = r"""
    sudo bash -c 'echo "[*] Vacuuming journal logs (keeping 1 second)..." && journalctl --vacuum-time=1s && echo "[*] Deleting rotated/compressed logs..." && find /var/log -type f \( -name "*.gz" -o -name "*.1" -o -name "*.old" -o -name "*.bak" -o -name "*-????????" -o -name "*.log.[0-9]*" \) -print -delete && echo "[*] Truncating active log files..." && find /var/log -type f -name "*.log" -exec truncate -s 0 {} \; && echo "[*] Vacuuming journal logs to 50MB cap..." && journalctl --vacuum-size=5M && docker system prune -af --volumes && echo "[✔] Cleanup complete."'
    """  # noqa
    await tpu.execute_command(command.strip(), stream=False)


@cli.command()
@click.option("--external", is_flag=True, help="Use external IPs instead of internal IPs")
@click.option("--stop", is_flag=True, help="Stop the Ray cluster")
@click.option("--verify", is_flag=True, help="Verify the Ray cluster setup")
@click.option("--tpu-version", help="Set TPU version (auto-detected if not provided)")
@click.option("--tpu-slice", type=int, help="Set TPU slice size (auto-detected if not provided)")
@click.option("--num-slices", type=int, default=1, help="Number of TPU slices to combine (default: 1)")
@click.option("--ssh-user", help="SSH username to use")
@click.option("--config", help="Path to YAML config file with IP addresses")
@click.option("--test-ssh", is_flag=True, help="Test SSH connectivity to all nodes")
@click.option("--external-ips", help="Comma-separated list of external IPs")
@click.option("--self-job", is_flag=True, help="Run only on the current machine (no SSH)")
@click.option("--slice-config", help="Path to YAML config file with slice configurations")
@click.option("--python-path", help="Path to venv or python interpreter")
@click.option("--head-node-ip", help="IP address of external head node (if not using first IP in list)")
@click.option("--head-only", is_flag=True, help="Run this node as head only (no TPU resources)")
@click.option("--spot-tpu-name", help="Name of spot TPU to configure as workers")
@click.option("--spot-tpu-project-id", help="Project ID for spot TPU (defaults to current)")
@click.option("--spot-tpu-zone", help="Zone for spot TPU (defaults to current)")
@click.option("--head-ip", help="IP address to use as head (defaults to current machine)")
def auto_config_ray(
    external,
    stop,
    verify,
    tpu_version,
    tpu_slice,
    num_slices,
    ssh_user,
    config,
    test_ssh,
    external_ips,
    self_job,
    slice_config,
    python_path,
    head_node_ip,
    head_only,
    spot_tpu_name,
    spot_tpu_project_id,
    spot_tpu_zone,
    head_ip,
):
    """
    Auto-configure Ray on TPU cluster using internal IPs from current setup.
    Automatically detects TPU version and slice size if not specified.

    Examples:
        # Setup v4-64 with external v2-8 head:
        eformer auto-config-ray --self-job --head-node-ip 10.x.x.x

        # Setup v2-8 as head-only:
        eformer auto-config-ray --self-job --head-only --head-node-ip 10.x.x.x
    """
    import asyncio
    import re
    import subprocess

    try:
        current_internal_ip = subprocess.check_output("hostname -I", shell=True, text=True).strip().split()[0]
        current_external_ip = subprocess.check_output("curl -s https://api.ipify.org", shell=True, text=True).strip()
    except subprocess.CalledProcessError:
        console.print("[red]Could not determine current machine's IP[/red]")
        return

    if not head_ip:
        head_ip = current_external_ip if external else current_internal_ip
        console.print(f"[green]Using current machine as head: {head_ip}[/green]")

    is_head_machine = head_ip in [current_internal_ip, current_external_ip]
    if spot_tpu_name and (not spot_tpu_project_id or not spot_tpu_zone):
        try:
            if not spot_tpu_project_id:
                spot_tpu_project_id = _read_gcloud_config_property("project")
                console.print(f"[green]Using current project for spot TPU: {spot_tpu_project_id}[/green]")

            if not spot_tpu_zone:
                spot_tpu_zone = _read_gcloud_config_property("compute/zone")
                console.print(f"[green]Using current zone for spot TPU: {spot_tpu_zone}[/green]")
        except Exception as e:
            console.print(f"[red]Failed to get default project/zone: {e!s}[/red]")
            return

    if spot_tpu_name:
        console.print(f"[cyan]Configuring spot TPU '{spot_tpu_name}' to connect to head at {head_ip}[/cyan]")

        async def fetch_spot_tpu_ips():
            console.print(f"[yellow]Fetching IPs for spot TPU: {spot_tpu_name}[/yellow]")
            manager = TPUManager(spot_tpu_project_id, spot_tpu_zone, spot_tpu_name)

            try:
                internal_ips_dict = await manager.get_internal_ips()
                internal_ips_list = list(internal_ips_dict.values())

                external_ips_list = []
                if external:
                    external_ips_str = await manager.get_external_ips()
                    external_ips_list = [ip.strip() for ip in external_ips_str.split(",") if ip.strip()]

                return internal_ips_list, external_ips_list
            except Exception as e:
                console.print(f"[red]Failed to fetch IPs for {spot_tpu_name}: {e!s}[/red]")
                return [], []

        spot_internal_ips, spot_external_ips = asyncio.run(fetch_spot_tpu_ips())

        if not spot_internal_ips:
            console.print("[red]No IPs found for spot TPU[/red]")
            return

        if is_head_machine and not stop:
            console.print("[cyan]Setting up current machine as Ray head...[/cyan]")

            try:
                accelerator_type = subprocess.check_output(
                    "curl -s 'http://metadata.google.internal/computeMetadata/v1/instance/attributes/accelerator-type'"
                    " -H 'Metadata-Flavor: Google'",
                    shell=True,
                    text=True,
                ).strip()

                match = re.match(r"v(\d+[a-zA-Z]?)-(\d+)", accelerator_type)
                if match:
                    head_tpu_version = f"v{match.group(1)}"
                    head_tpu_slice = int(match.group(2))
                    head_has_tpu = True
                    console.print(f"[green]Head has TPU: {head_tpu_version}-{head_tpu_slice}[/green]")
                else:
                    head_has_tpu = False
            except Exception:
                head_has_tpu = False
                console.print("[yellow]Head machine has no TPU (CPU-only head)[/yellow]")

            cmd = [
                python_path or PYTHON_PATH,
                "-m",
                "eformer.executor.patch_tpus_ray",
                "--self-job",
                "--head-only" if not head_has_tpu else None,
                "--head-node-ip",
                head_ip,
            ]
            cmd = [part for part in cmd if part]

            subprocess.run(" ".join(cmd), shell=True, check=True)
            console.print("[green]Ray head started successfully[/green]")
            time.sleep(5)

        console.print(f"[cyan]Configuring {len(spot_internal_ips)} spot TPU workers...[/cyan]")

        cmd_parts = [
            python_path or PYTHON_PATH,
            "-m",
            "eformer.executor.patch_tpus_ray",
            "--tpu-version",
            tpu_version or "v4",
            "--tpu-slice",
            str(tpu_slice or len(spot_internal_ips) * 8),
            "--internal-ips",
            ",".join(spot_internal_ips),
            "--self-job",
            "--head-node-ip",
            head_ip,
        ]

        if external and spot_external_ips:
            cmd_parts.extend(["--external-ips", ",".join(spot_external_ips)])

        if stop:
            cmd_parts.append("--stop")

        gcloud_cmd = [
            "gcloud",
            "compute",
            "tpus",
            "tpu-vm",
            "ssh",
            spot_tpu_name,
            f"--zone={spot_tpu_zone}",
            f"--project={spot_tpu_project_id}",
            "--worker=all",
            "--command",
            shlex.quote(" ".join(shlex.quote(str(part)) for part in cmd_parts)),
        ]

        final_cmd = " ".join(gcloud_cmd)
        console.print(f"[yellow]Executing on spot TPU: {final_cmd}[/yellow]")

        try:
            subprocess.run(final_cmd, shell=True, check=True)
            console.print("[green]Spot TPU workers configured successfully![/green]")

            if not stop:
                console.print("\n[cyan]Ray cluster ready![/cyan]")
                console.print(f"[cyan]Head node: {head_ip}[/cyan]")
                console.print(f"[cyan]Workers: {len(spot_internal_ips)} nodes from {spot_tpu_name}[/cyan]")
                console.print(f"[cyan]Dashboard: http://{head_ip}:8265[/cyan]")
                console.print(f"[cyan]Connect with: ray.init(address='{head_ip}:6379')[/cyan]")

        except subprocess.CalledProcessError as e:
            console.print(f"[red]Failed to configure spot TPU workers: {e!s}[/red]")
    else:
        if head_only:
            console.print("[cyan]Running in head-only mode (no TPU resources)[/cyan]")
            cmd_parts = [
                EOPOD_PATH,
                "run",
                python_path or PYTHON_PATH,
                "-m",
                "eformer.executor.patch_tpus_ray",
                "--head-only",
                "--self-job",
            ]

            if head_node_ip:
                cmd_parts.extend(["--head-node-ip", head_node_ip])
            else:
                # Get current machine's IP for head-only mode
                try:
                    local_ip = subprocess.check_output("hostname -I", shell=True, text=True).strip().split()[0]
                    cmd_parts.extend(["--head-node-ip", local_ip])
                    console.print(f"[green]Using local IP as head node: {local_ip}[/green]")
                except subprocess.CalledProcessError:
                    console.print("[red]Could not determine local IP for head node[/red]")
                    return

            if stop:
                cmd_parts.append("--stop")
            if verify:
                cmd_parts.append("--verify")

            final_cmd = " ".join(shlex.quote(str(part)) for part in cmd_parts)
            console.print(f"[yellow]Executing: {final_cmd}[/yellow]")

            try:
                subprocess.run(final_cmd, shell=True, check=True, text=True)
                console.print("[green]Head-only node configuration completed successfully![/green]")
                if not stop:
                    console.print("[cyan]Now run the worker setup on your TPU nodes with --head-node-ip[/cyan]")
            except subprocess.CalledProcessError as e:
                console.print(f"[red]Failed to configure head-only node: {e!s}[/red]")
            return

        # Regular TPU node setup (existing logic)
        try:
            console.print("[yellow]Fetching internal IPs from eopod...[/yellow]")
            internal_ips_output = subprocess.check_output(
                f"{EOPOD_PATH} get-internal-ips", shell=True, text=True
            ).strip()

            sanitized_ips_output = internal_ips_output.replace("\n", "").replace("\r", "")
            internal_ips = [ip.strip() for ip in sanitized_ips_output.split(",") if ip.strip()]
            if not internal_ips:
                console.print("[red]No internal IPs found. Make sure eopod is configured correctly.[/red]")
                return

            internal_ips_str = ",".join(internal_ips)
            console.print(f"[green]Found internal IPs: {internal_ips_str}[/green]")

        except subprocess.CalledProcessError as e:
            console.print(f"[red]Failed to get internal IPs: {e!s}[/red]")
            return

        if not tpu_version or not tpu_slice:
            try:
                console.print("[yellow]Auto-detecting TPU configuration...[/yellow]")
                accelerator_type = subprocess.check_output(
                    "curl -s 'http://metadata.google.internal/computeMetadata/v1/instance/attributes/accelerator-type' -H 'Metadata-Flavor: Google'",  # noqa
                    shell=True,
                    text=True,
                ).strip()

                match = re.match(r"v(\d+[a-zA-Z]?)-(\d+)", accelerator_type)
                if match:
                    detected_version = match.group(1)
                    detected_slice = int(match.group(2))

                    if not tpu_version:
                        tpu_version = f"v{detected_version}"
                        console.print(f"[green]Auto-detected TPU version: {tpu_version}[/green]")

                    if not tpu_slice:
                        tpu_slice = detected_slice
                        console.print(f"[green]Auto-detected TPU slice size: {tpu_slice}[/green]")
                else:
                    console.print(
                        f"[yellow]Could not parse accelerator type: {accelerator_type}. Please provide --tpu-version and --tpu-slice manually.[/yellow]"  # noqa
                    )
                    if not tpu_version or not tpu_slice:
                        console.print("[red]TPU version and slice size are required. Exiting.[/red]")
                        return
            except subprocess.CalledProcessError:
                console.print(
                    "[yellow]Failed to auto-detect TPU configuration. Please provide --tpu-version and --tpu-slice manually.[/yellow]"  # noqa
                )
                if not tpu_version or not tpu_slice:
                    console.print("[red]TPU version and slice size are required. Exiting.[/red]")
                    return

        # Show configuration summary
        if head_node_ip:
            console.print(f"[cyan]Using external head node at: {head_node_ip}[/cyan]")
            console.print("[cyan]This TPU cluster will connect as workers to the external head[/cyan]")

        cmd_parts = [
            EOPOD_PATH,
            "run",
            python_path or PYTHON_PATH,
            "-m",
            "eformer.executor.patch_tpus_ray",
        ]

        cmd_parts.extend(["--tpu-version", str(tpu_version)])
        cmd_parts.extend(["--tpu-slice", str(tpu_slice)])
        cmd_parts.extend(["--num-slices", str(num_slices)])
        cmd_parts.extend(["--internal-ips", internal_ips_str])

        if external:
            cmd_parts.append("--external")
        if stop:
            cmd_parts.append("--stop")
        if verify:
            cmd_parts.append("--verify")
        if self_job:
            cmd_parts.append("--self-job")
        if test_ssh:
            cmd_parts.append("--test-ssh")

        if ssh_user:
            cmd_parts.extend(["--ssh-user", ssh_user])
        if config:
            cmd_parts.extend(["--config", config])
        if external_ips:
            cmd_parts.extend(["--external-ips", external_ips])
        if slice_config:
            cmd_parts.extend(["--slice-config", slice_config])
        if head_node_ip:
            cmd_parts.extend(["--head-node-ip", head_node_ip])

        final_cmd = " ".join(shlex.quote(str(part)) for part in cmd_parts)
        console.print(f"[yellow]Executing: {final_cmd}[/yellow]")

        try:
            subprocess.run(final_cmd, shell=True, check=True, text=True)
            console.print("[green]Ray cluster configuration completed successfully![/green]")

            if head_node_ip and not stop:
                console.print(f"\n[cyan]Workers connected to head node at: {head_node_ip}[/cyan]")
                console.print(f"[cyan]Ray dashboard available at: http://{head_node_ip}:8265[/cyan]")
                console.print(f"[cyan]Connect to cluster with: ray.init(address='{head_node_ip}:6379')[/cyan]")

        except subprocess.CalledProcessError as e:
            console.print(f"[red]Failed to configure Ray cluster: {e!s}[/red]")


@cli.command()
@click.option(
    "--port",
    "-p",
    type=int,
    required=True,
    multiple=True,
    help="Port number(s) to open. Can be specified multiple times (e.g., -p 80 -p 443).",
)
@click.option(
    "--direction",
    type=click.Choice(["ingress", "egress", "both"], case_sensitive=False),
    default="both",
    show_default=True,
    help="Direction of traffic to allow.",
)
@click.option(
    "--protocol",
    default="tcp",
    show_default=True,
    type=click.Choice(["tcp", "udp", "icmp", "all"], case_sensitive=False),
    help="Protocol to allow.",
)
@click.option(
    "--target-tag",
    default=None,
    help="Network tag for VMs. If omitted, eopod auto-detects the TPU VM tag from GCP.",
)
@click.option(
    "--source-ranges",
    default="0.0.0.0/0",
    show_default=True,
    help="Source IP CIDR range for ingress rules.",
)
@click.option(
    "--destination-ranges",
    default="0.0.0.0/0",
    show_default=True,
    help="Destination IP CIDR range for egress rules.",
)
@click.option(
    "--priority",
    type=int,
    default=1000,
    show_default=True,
    help="Firewall rule priority (lower number = higher priority).",
)
@click.option(
    "--description",
    default="Rule created by eopod",
    show_default=True,
    help="Description for the firewall rule.",
)
@click.option(
    "--network",
    default=None,
    help="Network name to use. If omitted, will attempt to detect from TPU configuration.",
)
@click.option(
    "--update-existing/--skip-existing",
    default=False,
    show_default=True,
    help="Whether to update existing rules or skip them.",
)
@click.option(
    "--verify-tag",
    is_flag=True,
    default=False,
    help="Verify that the target tag is applied to the TPU VM before creating rules.",
)
@async_command
async def open_port(
    port,
    direction,
    protocol,
    target_tag,
    source_ranges,
    destination_ranges,
    priority,
    description,
    network,
    update_existing,
    verify_tag,
):
    """Creates GCP firewall rules to open ports for TPU VMs."""
    resolved = _resolve_runtime_credentials(require_tpu=True, verbose=True)
    if not resolved:
        console.print("[red]Could not resolve project/zone/tpu automatically. Run 'eopod configure' first.[/red]")
        return
    _config, project_id, zone, tpu_name, _queued_resource = resolved

    safe_tpu_name = tpu_name.lower().replace("_", "-")

    tpu_manager = TPUManager(project_id, zone, tpu_name)
    tpu_info = None

    try:
        tpu_info = await tpu_manager.get_tpu_info()
    except Exception as e:
        console.print(f"[yellow]Could not fetch TPU details automatically: {e}[/yellow]")

    if network is None:
        try:
            raw_network = (tpu_info or {}).get("networkConfig", {}).get("network", "default")
            network = _network_name(raw_network)
            console.print(f"Using network: {network}")
        except Exception as e:
            console.print(f"[yellow]Could not determine network from TPU config: {e}[/yellow]")
            console.print("[yellow]Using 'default' network instead[/yellow]")
            network = "default"

    detected_target_tags = []
    if target_tag is None:
        detected_target_tags = _resolve_tpu_vm_target_tags(project_id, zone, tpu_name, tpu_info=tpu_info)
        if detected_target_tags:
            console.print(f"[green]Auto-detected TPU VM target tag(s): {', '.join(detected_target_tags)}[/green]")
        else:
            detected_target_tags = [safe_tpu_name]
            console.print(
                "[yellow]Could not auto-detect the TPU VM target tag. "
                f"Falling back to '{safe_tpu_name}'. If the port still does not open, "
                "pass --target-tag explicitly.[/yellow]"
            )
    else:
        detected_target_tags = _normalize_target_tags(target_tag.split(","))
        console.print(f"[green]Using user-provided target tag(s): {', '.join(detected_target_tags)}[/green]")

    if verify_tag:
        try:
            vm_tags = _resolve_tpu_vm_target_tags(project_id, zone, tpu_name, tpu_info=tpu_info)
            missing_tags = [tag for tag in detected_target_tags if tag not in vm_tags]
            if missing_tags:
                console.print(f"[red]Target tag(s) '{', '.join(missing_tags)}' are not applied to the TPU VM![/red]")
                console.print(f"[yellow]Available tags: {', '.join(vm_tags) if vm_tags else 'None'}[/yellow]")
                if click.confirm("Do you want to add this tag to the TPU VM?", default=False):
                    console.print("[yellow]Adding tag functionality not implemented yet[/yellow]")
                else:
                    return
        except Exception as e:
            console.print(f"[yellow]Could not verify tags: {e}[/yellow]")
            if not click.confirm("Continue without verifying tags?", default=False):
                return

    directions_to_process = ["ingress", "egress"] if direction.lower() == "both" else [direction.lower()]

    for p in port:
        for current_direction in directions_to_process:
            rule_name = f"a-allow-{safe_tpu_name}-{p}-{current_direction}".lower()[:63]

            try:
                cmd = f"gcloud compute firewall-rules describe {rule_name} --project={project_id} --format=json"
                existing_rule_raw = await run_command(cmd, capture_output=True)
                existing_rule = json.loads(existing_rule_raw)
                rule_exists = True
                console.print(f"Rule '{rule_name}' already exists.")

                existing_target_tags = _normalize_target_tags(existing_rule.get("targetTags", []))
                if sorted(existing_target_tags) != sorted(detected_target_tags):
                    console.print(
                        "[yellow]Existing rule uses stale target tags. "
                        "Updating it to match the TPU VM automatically.[/yellow]"
                    )
                elif not update_existing:
                    console.print("[yellow]Skipping (use --update-existing to update)[/yellow]")
                    continue

            except Exception:
                rule_exists = False

            cmd_parts = [
                "gcloud",
                "compute",
                "firewall-rules",
                "update" if rule_exists else "create",
                rule_name,
                f"--project={project_id}",
                f"--priority={priority}",
            ]

            if not rule_exists:
                cmd_parts.extend(
                    [
                        f"--direction={current_direction.upper()}",
                        f"--network={network}",
                        "--action=ALLOW",
                    ]
                )

            if protocol == "all":
                cmd_parts.append("--rules=all")
            elif protocol == "icmp":
                cmd_parts.append("--rules=icmp")
            else:
                cmd_parts.append(f"--rules={protocol}:{p}")

            if current_direction.lower() == "ingress":
                cmd_parts.append(f"--source-ranges={source_ranges}")

            if current_direction.lower() == "egress":
                cmd_parts.append(f"--destination-ranges={destination_ranges}")

            cmd_parts.append(f"--target-tags={','.join(detected_target_tags)}")

            cmd_parts.append(f"--description={shlex.quote(description)}")

            cmd = " ".join(cmd_parts)
            console.print(f"[green]Executing:[/green]\n{cmd}")

            try:
                await run_command(cmd)
                console.print(
                    f"[green]Successfully {'updated' if rule_exists else 'created'} firewall rule '{rule_name}'[/green]"
                )
            except Exception as e:
                console.print(f"[red]Failed to {'update' if rule_exists else 'create'} firewall rule: {e}[/red]")


@cli.command()
def errors():
    """Show recent command execution errors"""
    config = EOConfig()

    if not config.error_log_file.exists():
        console.print("[yellow]No error log found.[/yellow]")
        return

    with open(config.error_log_file, "r") as f:
        try:
            error_log = yaml.safe_load(f) or []
        except yaml.YAMLError as e:
            console.print(f"[red]Error loading error log: {e}[/red]")
            return

    if not error_log:
        console.print("[green]No errors found![/green]")
        return

    table = Table(title="Error Log", style="red")
    table.add_column("Timestamp")
    table.add_column("Command")
    table.add_column("Error")

    for entry in error_log[-10:]:  # Show last 10 errors
        table.add_row(entry["timestamp"], entry["command"][:30], entry["error"][:100])

    console.print(table)


[docs]def main(): """ Main entry point for the eopod CLI. """ try: asyncio.run(cli()) except click.exceptions.Exit as e: if e.exit_code != 0: console.print(f"[red]Error:[/red] Command failed with exit code {e.exit_code}") logging.exception("Click command failed") except Exception as e: console.print(f"[red]Unexpected Error:[/red] {e!s}")