main
  1# fork of https://github.com/asottile/blacken-docs adapted for ruff
  2from __future__ import annotations
  3
  4import re
  5import sys
  6import argparse
  7import textwrap
  8import contextlib
  9import subprocess
 10from typing import Match, Optional, Sequence, Generator, NamedTuple, cast
 11
 12MD_RE = re.compile(
 13    r"(?P<before>^(?P<indent> *)```\s*python\n)" r"(?P<code>.*?)" r"(?P<after>^(?P=indent)```\s*$)",
 14    re.DOTALL | re.MULTILINE,
 15)
 16MD_PYCON_RE = re.compile(
 17    r"(?P<before>^(?P<indent> *)```\s*pycon\n)" r"(?P<code>.*?)" r"(?P<after>^(?P=indent)```.*$)",
 18    re.DOTALL | re.MULTILINE,
 19)
 20PYCON_PREFIX = ">>> "
 21PYCON_CONTINUATION_PREFIX = "..."
 22PYCON_CONTINUATION_RE = re.compile(
 23    rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)",
 24)
 25DEFAULT_LINE_LENGTH = 100
 26
 27
 28class CodeBlockError(NamedTuple):
 29    offset: int
 30    exc: Exception
 31
 32
 33def format_str(
 34    src: str,
 35) -> tuple[str, Sequence[CodeBlockError]]:
 36    errors: list[CodeBlockError] = []
 37
 38    @contextlib.contextmanager
 39    def _collect_error(match: Match[str]) -> Generator[None, None, None]:
 40        try:
 41            yield
 42        except Exception as e:
 43            errors.append(CodeBlockError(match.start(), e))
 44
 45    def _md_match(match: Match[str]) -> str:
 46        code = textwrap.dedent(match["code"])
 47        with _collect_error(match):
 48            code = format_code_block(code)
 49        code = textwrap.indent(code, match["indent"])
 50        return f"{match['before']}{code}{match['after']}"
 51
 52    def _pycon_match(match: Match[str]) -> str:
 53        code = ""
 54        fragment = cast(Optional[str], None)
 55
 56        def finish_fragment() -> None:
 57            nonlocal code
 58            nonlocal fragment
 59
 60            if fragment is not None:
 61                with _collect_error(match):
 62                    fragment = format_code_block(fragment)
 63                fragment_lines = fragment.splitlines()
 64                code += f"{PYCON_PREFIX}{fragment_lines[0]}\n"
 65                for line in fragment_lines[1:]:
 66                    # Skip blank lines to handle Black adding a blank above
 67                    # functions within blocks. A blank line would end the REPL
 68                    # continuation prompt.
 69                    #
 70                    # >>> if True:
 71                    # ...     def f():
 72                    # ...         pass
 73                    # ...
 74                    if line:
 75                        code += f"{PYCON_CONTINUATION_PREFIX} {line}\n"
 76                if fragment_lines[-1].startswith(" "):
 77                    code += f"{PYCON_CONTINUATION_PREFIX}\n"
 78                fragment = None
 79
 80        indentation = None
 81        for line in match["code"].splitlines():
 82            orig_line, line = line, line.lstrip()
 83            if indentation is None and line:
 84                indentation = len(orig_line) - len(line)
 85            continuation_match = PYCON_CONTINUATION_RE.match(line)
 86            if continuation_match and fragment is not None:
 87                fragment += line[continuation_match.end() :] + "\n"
 88            else:
 89                finish_fragment()
 90                if line.startswith(PYCON_PREFIX):
 91                    fragment = line[len(PYCON_PREFIX) :] + "\n"
 92                else:
 93                    code += orig_line[indentation:] + "\n"
 94        finish_fragment()
 95        return code
 96
 97    def _md_pycon_match(match: Match[str]) -> str:
 98        code = _pycon_match(match)
 99        code = textwrap.indent(code, match["indent"])
100        return f"{match['before']}{code}{match['after']}"
101
102    src = MD_RE.sub(_md_match, src)
103    src = MD_PYCON_RE.sub(_md_pycon_match, src)
104    return src, errors
105
106
107def format_code_block(code: str) -> str:
108    return subprocess.check_output(
109        [
110            sys.executable,
111            "-m",
112            "ruff",
113            "format",
114            "--stdin-filename=script.py",
115            f"--line-length={DEFAULT_LINE_LENGTH}",
116        ],
117        encoding="utf-8",
118        input=code,
119    )
120
121
122def format_file(
123    filename: str,
124    skip_errors: bool,
125) -> int:
126    with open(filename, encoding="UTF-8") as f:
127        contents = f.read()
128    new_contents, errors = format_str(contents)
129    for error in errors:
130        lineno = contents[: error.offset].count("\n") + 1
131        print(f"{filename}:{lineno}: code block parse error {error.exc}")
132    if errors and not skip_errors:
133        return 1
134    if contents != new_contents:
135        print(f"{filename}: Rewriting...")
136        with open(filename, "w", encoding="UTF-8") as f:
137            f.write(new_contents)
138        return 0
139    else:
140        return 0
141
142
143def main(argv: Sequence[str] | None = None) -> int:
144    parser = argparse.ArgumentParser()
145    parser.add_argument(
146        "-l",
147        "--line-length",
148        type=int,
149        default=DEFAULT_LINE_LENGTH,
150    )
151    parser.add_argument(
152        "-S",
153        "--skip-string-normalization",
154        action="store_true",
155    )
156    parser.add_argument("-E", "--skip-errors", action="store_true")
157    parser.add_argument("filenames", nargs="*")
158    args = parser.parse_args(argv)
159
160    retv = 0
161    for filename in args.filenames:
162        retv |= format_file(filename, skip_errors=args.skip_errors)
163    return retv
164
165
# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())