main
  1from __future__ import annotations
  2
  3import os
  4import sys
  5import shutil
  6import tarfile
  7import platform
  8import subprocess
  9from typing import TYPE_CHECKING, List
 10from pathlib import Path
 11from argparse import ArgumentParser
 12
 13import httpx
 14
 15from .._errors import CLIError, SilentCLIError
 16from .._models import BaseModel
 17
 18if TYPE_CHECKING:
 19    from argparse import _SubParsersAction
 20
 21
 22def register(subparser: _SubParsersAction[ArgumentParser]) -> None:
 23    sub = subparser.add_parser("migrate")
 24    sub.set_defaults(func=migrate, args_model=MigrateArgs, allow_unknown_args=True)
 25
 26    sub = subparser.add_parser("grit")
 27    sub.set_defaults(func=grit, args_model=GritArgs, allow_unknown_args=True)
 28
 29
# Argument model for the `grit` subcommand (attached as its `args_model` in
# `register`).
class GritArgs(BaseModel):
    # internal
    # Extra command-line arguments that `grit()` forwards unchanged to the Grit
    # binary. Presumably filled in by the CLI dispatcher because the subcommand
    # is registered with `allow_unknown_args=True` — confirm against the caller.
    # NOTE(review): the `[]` default is only safe if `BaseModel` copies field
    # defaults per instance (pydantic does) — confirm `.._models.BaseModel` is
    # pydantic-based.
    unknown_args: List[str] = []
 33
 34
 35def grit(args: GritArgs) -> None:
 36    grit_path = install()
 37
 38    try:
 39        subprocess.check_call([grit_path, *args.unknown_args])
 40    except subprocess.CalledProcessError:
 41        # stdout and stderr are forwarded by subprocess so an error will already
 42        # have been displayed
 43        raise SilentCLIError() from None
 44
 45
# Argument model for the `migrate` subcommand (attached as its `args_model` in
# `register`).
class MigrateArgs(BaseModel):
    # internal
    # Extra command-line arguments that `migrate()` appends after
    # `apply openai` when invoking the Grit binary. Presumably filled in by the
    # CLI dispatcher because the subcommand is registered with
    # `allow_unknown_args=True` — confirm against the caller.
    # NOTE(review): the `[]` default is only safe if `BaseModel` copies field
    # defaults per instance (pydantic does) — confirm `.._models.BaseModel` is
    # pydantic-based.
    unknown_args: List[str] = []
 49
 50
 51def migrate(args: MigrateArgs) -> None:
 52    grit_path = install()
 53
 54    try:
 55        subprocess.check_call([grit_path, "apply", "openai", *args.unknown_args])
 56    except subprocess.CalledProcessError:
 57        # stdout and stderr are forwarded by subprocess so an error will already
 58        # have been displayed
 59        raise SilentCLIError() from None
 60
 61
# Handles downloading the Grit CLI until Grit publishes its own PyPI package.
 63
# Presumably a Keygen account identifier (going by the name).
# NOTE(review): nothing in this module's visible code reads this constant; the
# installer below downloads from GitHub releases. Confirm whether anything else
# imports it before removing.
KEYGEN_ACCOUNT = "custodian-dev"
 65
 66
 67def _cache_dir() -> Path:
 68    xdg = os.environ.get("XDG_CACHE_HOME")
 69    if xdg is not None:
 70        return Path(xdg)
 71
 72    return Path.home() / ".cache"
 73
 74
 75def _debug(message: str) -> None:
 76    if not os.environ.get("DEBUG"):
 77        return
 78
 79    sys.stdout.write(f"[DEBUG]: {message}\n")
 80
 81
 82def install() -> Path:
 83    """Installs the Grit CLI and returns the location of the binary"""
 84    if sys.platform == "win32":
 85        raise CLIError("Windows is not supported yet in the migration CLI")
 86
 87    _debug("Using Grit installer from GitHub")
 88
 89    platform = "apple-darwin" if sys.platform == "darwin" else "unknown-linux-gnu"
 90
 91    dir_name = _cache_dir() / "openai-python"
 92    install_dir = dir_name / ".install"
 93    target_dir = install_dir / "bin"
 94
 95    target_path = target_dir / "grit"
 96    temp_file = target_dir / "grit.tmp"
 97
 98    if target_path.exists():
 99        _debug(f"{target_path} already exists")
100        sys.stdout.flush()
101        return target_path
102
103    _debug(f"Using Grit CLI path: {target_path}")
104
105    target_dir.mkdir(parents=True, exist_ok=True)
106
107    if temp_file.exists():
108        temp_file.unlink()
109
110    arch = _get_arch()
111    _debug(f"Using architecture {arch}")
112
113    file_name = f"grit-{arch}-{platform}"
114    download_url = f"https://github.com/getgrit/gritql/releases/latest/download/{file_name}.tar.gz"
115
116    sys.stdout.write(f"Downloading Grit CLI from {download_url}\n")
117    with httpx.Client() as client:
118        download_response = client.get(download_url, follow_redirects=True)
119        if download_response.status_code != 200:
120            raise CLIError(f"Failed to download Grit CLI from {download_url}")
121        with open(temp_file, "wb") as file:
122            for chunk in download_response.iter_bytes():
123                file.write(chunk)
124
125    unpacked_dir = target_dir / "cli-bin"
126    unpacked_dir.mkdir(parents=True, exist_ok=True)
127
128    with tarfile.open(temp_file, "r:gz") as archive:
129        if sys.version_info >= (3, 12):
130            archive.extractall(unpacked_dir, filter="data")
131        else:
132            archive.extractall(unpacked_dir)
133
134    _move_files_recursively(unpacked_dir, target_dir)
135
136    shutil.rmtree(unpacked_dir)
137    os.remove(temp_file)
138    os.chmod(target_path, 0o755)
139
140    sys.stdout.flush()
141
142    return target_path
143
144
def _move_files_recursively(source_dir: Path, target_dir: Path) -> None:
    """Flatten the tree under ``source_dir`` into ``target_dir``.

    Every regular file found at any depth below ``source_dir`` is moved directly
    into ``target_dir`` under its own base name; sub-directory structure is
    discarded. Directories themselves are left in place, now emptied of files.
    Entries are visited in ``iterdir()`` order. On POSIX a later file with the
    same name silently replaces an earlier one.
    """
    for entry in source_dir.iterdir():
        if entry.is_dir():
            _move_files_recursively(entry, target_dir)
            continue
        if entry.is_file():
            entry.rename(target_dir / entry.name)
151
152
153def _get_arch() -> str:
154    architecture = platform.machine().lower()
155
156    # Map the architecture names to Grit equivalents
157    arch_map = {
158        "x86_64": "x86_64",
159        "amd64": "x86_64",
160        "armv7l": "aarch64",
161        "arm64": "aarch64",
162    }
163
164    return arch_map.get(architecture, architecture)