main
1# fork of https://github.com/asottile/blacken-docs adapted for ruff
2from __future__ import annotations
3
4import re
5import sys
6import argparse
7import textwrap
8import contextlib
9import subprocess
10from typing import Match, Optional, Sequence, Generator, NamedTuple, cast
11
# Fenced Markdown block whose info string is ``python``.
#   before: the opening fence line (including its newline)
#   indent: the run of spaces in front of the opening fence
#   code:   everything up to the first line that is the same indent + ```
#   after:  the closing fence
MD_RE = re.compile(
    r"(?P<before>^(?P<indent> *)```\s*python\n)"
    r"(?P<code>.*?)"
    r"(?P<after>^(?P=indent)```\s*$)",
    re.DOTALL | re.MULTILINE,
)
# Fenced Markdown block whose info string is ``pycon`` (interactive session).
# The closing fence may be followed by arbitrary text on the same line, but it
# must stay on that line: under re.DOTALL a bare ``.*$`` would run to the end
# of the document, swallow every later block into ``after`` and leave all but
# the first pycon block unformatted. ``[^\n]*`` keeps the match on one line.
MD_PYCON_RE = re.compile(
    r"(?P<before>^(?P<indent> *)```\s*pycon\n)"
    r"(?P<code>.*?)"
    r"(?P<after>^(?P=indent)```[^\n]*$)",
    re.DOTALL | re.MULTILINE,
)
# Prompt that starts a statement in a pycon block.
PYCON_PREFIX = ">>> "
# Prompt that continues a multi-line statement in a pycon block.
PYCON_CONTINUATION_PREFIX = "..."
# A continuation line is the continuation prompt followed by a space or by
# nothing at all (so "...." is not treated as a continuation).
PYCON_CONTINUATION_RE = re.compile(
    rf"^{re.escape(PYCON_CONTINUATION_PREFIX)}( |$)",
)
# Line length handed to ``ruff format`` and used as the default for ``-l``.
DEFAULT_LINE_LENGTH = 100
26
27
class CodeBlockError(NamedTuple):
    """A failure raised while formatting one code block."""

    # Character offset of the start of the regex match for the failing block.
    offset: int
    # The exception that was raised while formatting that block.
    exc: Exception
31
32
def format_str(
    src: str,
) -> tuple[str, Sequence[CodeBlockError]]:
    """Format the Python code inside the fenced Markdown blocks of *src*.

    Two kinds of block are rewritten:

    * ```python fences (``MD_RE``): the body is dedented, formatted as one
      unit and re-indented to the fence's indentation.
    * ```pycon fences (``MD_PYCON_RE``): every ``>>> `` statement, together
      with its ``...`` continuation lines, is formatted on its own; all other
      lines (e.g. REPL output) are copied through.

    Returns ``(new_src, errors)``. A block or statement whose formatting
    raises is emitted unformatted and contributes a ``CodeBlockError`` holding
    the start offset of the enclosing regex match and the exception.

    NOTE: the pycon pass runs on the text already rewritten by the python
    pass, so offsets recorded for pycon blocks index into that intermediate
    text rather than into the original *src*.
    """
    errors: list[CodeBlockError] = []

    @contextlib.contextmanager
    def _collect_error(match: Match[str]) -> Generator[None, None, None]:
        # Record the failure against the block instead of aborting the run.
        try:
            yield
        except Exception as e:
            errors.append(CodeBlockError(match.start(), e))

    def _md_match(match: Match[str]) -> str:
        code = textwrap.dedent(match["code"])
        with _collect_error(match):
            code = format_code_block(code)
        code = textwrap.indent(code, match["indent"])
        return f"{match['before']}{code}{match['after']}"

    def _pycon_match(match: Match[str]) -> str:
        # Output lines, joined once at the end.
        pieces: list[str] = []
        # Source of the statement currently being accumulated, if any.
        fragment: Optional[str] = None

        def finish_fragment() -> None:
            nonlocal fragment

            if fragment is None:
                return
            with _collect_error(match):
                fragment = format_code_block(fragment)
            # The formatter can hand back an empty string (e.g. for a bare
            # ">>> " prompt); fall back to a single empty line so the
            # indexing below cannot raise IndexError outside the collector.
            fragment_lines = fragment.splitlines() or [""]
            pieces.append(f"{PYCON_PREFIX}{fragment_lines[0]}\n")
            for line in fragment_lines[1:]:
                # Skip blank lines to handle the formatter adding a blank
                # above functions within blocks. A blank line would end the
                # REPL continuation prompt.
                #
                # >>> if True:
                # ...     def f():
                # ...         pass
                # ...
                if line:
                    pieces.append(f"{PYCON_CONTINUATION_PREFIX} {line}\n")
            # A statement ending on an indented line needs a bare "..." to
            # close the block in the REPL.
            if fragment_lines[-1].startswith(" "):
                pieces.append(f"{PYCON_CONTINUATION_PREFIX}\n")
            fragment = None

        # Indentation of the first non-blank line; stripped from passthrough lines.
        indentation = None
        for line in match["code"].splitlines():
            orig_line, line = line, line.lstrip()
            if indentation is None and line:
                indentation = len(orig_line) - len(line)
            continuation_match = PYCON_CONTINUATION_RE.match(line)
            if continuation_match and fragment is not None:
                fragment += line[continuation_match.end() :] + "\n"
            else:
                finish_fragment()
                if line.startswith(PYCON_PREFIX):
                    fragment = line[len(PYCON_PREFIX) :] + "\n"
                else:
                    pieces.append(orig_line[indentation:] + "\n")
        finish_fragment()
        return "".join(pieces)

    def _md_pycon_match(match: Match[str]) -> str:
        code = _pycon_match(match)
        code = textwrap.indent(code, match["indent"])
        return f"{match['before']}{code}{match['after']}"

    src = MD_RE.sub(_md_match, src)
    src = MD_PYCON_RE.sub(_md_pycon_match, src)
    return src, errors
105
106
def format_code_block(code: str) -> str:
    """Format *code* with ``ruff format`` in a subprocess and return its stdout.

    Ruff is run as ``python -m ruff`` under the current interpreter, with
    *code* supplied on stdin and the line length taken from
    ``DEFAULT_LINE_LENGTH`` at call time.

    Raises:
        subprocess.CalledProcessError: if the ruff process exits non-zero.
    """
    command = [
        sys.executable,
        "-m",
        "ruff",
        "format",
        "--stdin-filename=script.py",
        f"--line-length={DEFAULT_LINE_LENGTH}",
    ]
    completed = subprocess.run(
        command,
        input=code,
        stdout=subprocess.PIPE,
        encoding="utf-8",
        check=True,
    )
    return completed.stdout
120
121
def format_file(
    filename: str,
    skip_errors: bool,
) -> int:
    """Format the code blocks of one file in place and return an exit status.

    Every error reported by ``format_str`` is printed as
    ``<filename>:<line>: code block parse error <exc>``. If there are errors
    and *skip_errors* is false, the file is left untouched and 1 is returned.
    Otherwise the file is rewritten only when its contents changed, and 0 is
    returned.
    """
    with open(filename, encoding="UTF-8") as source_file:
        original = source_file.read()

    formatted, problems = format_str(original)

    for problem in problems:
        # Turn the character offset into a 1-based line number.
        line_number = original[: problem.offset].count("\n") + 1
        print(f"{filename}:{line_number}: code block parse error {problem.exc}")

    if problems and not skip_errors:
        return 1

    if formatted != original:
        print(f"{filename}: Rewriting...")
        with open(filename, "w", encoding="UTF-8") as target_file:
            target_file.write(formatted)
    return 0
141
142
def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point: format the code blocks of each given file.

    Args:
        argv: argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        The bitwise OR of the ``format_file`` results: 0 when every file was
        processed cleanly, non-zero otherwise.
    """
    global DEFAULT_LINE_LENGTH

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-l",
        "--line-length",
        type=int,
        default=DEFAULT_LINE_LENGTH,
        help="line length passed to ruff format (default: %(default)s)",
    )
    parser.add_argument(
        "-S",
        "--skip-string-normalization",
        action="store_true",
        help="accepted but currently has no effect",
    )
    parser.add_argument("-E", "--skip-errors", action="store_true")
    parser.add_argument("filenames", nargs="*")
    args = parser.parse_args(argv)

    # format_code_block reads the module-level DEFAULT_LINE_LENGTH when it
    # builds the ruff command line, so rebinding it for the duration of this
    # run is what makes -l take effect. The previous value is restored
    # afterwards so repeated calls to main() do not leak the override.
    previous_line_length = DEFAULT_LINE_LENGTH
    DEFAULT_LINE_LENGTH = args.line_length
    try:
        retv = 0
        for filename in args.filenames:
            retv |= format_file(filename, skip_errors=args.skip_errors)
        return retv
    finally:
        DEFAULT_LINE_LENGTH = previous_line_length
164
165
# Run the CLI only when executed as a script; main()'s result is the exit status.
if __name__ == "__main__":
    sys.exit(main())