add indentation scripts
This commit is contained in:
@@ -0,0 +1,104 @@
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent
|
||||
|
||||
class PythonIndentationFixer(ast.NodeVisitor):
|
||||
def __init__(self, source_lines: list[str]):
|
||||
self.source_lines = source_lines
|
||||
self.result: list[str] = []
|
||||
self._depth = 0
|
||||
self._pending: list[tuple[int, str]] = []
|
||||
|
||||
def fix(self) -> str:
|
||||
tree = ast.parse("".join(self.source_lines))
|
||||
self._walk_module(tree)
|
||||
return "\n".join(self.result)
|
||||
|
||||
def _get_line(self, lineno: int) -> str:
|
||||
if 0 < lineno <= len(self.source_lines):
|
||||
return self.source_lines[lineno - 1]
|
||||
return ""
|
||||
|
||||
def _walk_module(self, node: ast.Module):
|
||||
for item in node.body:
|
||||
self._process_item(item, 0)
|
||||
self.generic_visit(node)
|
||||
|
||||
def _process_item(self, node: ast.AST, base_depth: int):
|
||||
lineno = node.lineno
|
||||
line = self._get_line(lineno)
|
||||
stripped = line.lstrip()
|
||||
leading = len(line) - len(stripped)
|
||||
expected = base_depth
|
||||
|
||||
if leading != expected:
|
||||
self.result.append(" " * expected + stripped)
|
||||
else:
|
||||
self.result.append(line.rstrip("\n"))
|
||||
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
|
||||
body_depth = base_depth + 1
|
||||
for child in node.body:
|
||||
self._process_item(child, body_depth)
|
||||
elif isinstance(node, (ast.If, ast.For, ast.While, ast.With, ast.Try)):
|
||||
body_depth = base_depth + 1
|
||||
for child in node.body:
|
||||
self._process_item(child, body_depth)
|
||||
if isinstance(node, ast.If) and node.orelse:
|
||||
self._process_item(node.orelse, base_depth + 1)
|
||||
if isinstance(node, ast.For) and node.orelse:
|
||||
for child in node.orelse:
|
||||
self._process_item(child, body_depth)
|
||||
elif isinstance(node, ast.Try):
|
||||
for handler in node.handlers:
|
||||
for child in handler.body:
|
||||
self._process_item(child, body_depth)
|
||||
if node.orelse:
|
||||
for child in node.orelse:
|
||||
self._process_item(child, body_depth)
|
||||
if node.finalbody:
|
||||
for child in node.finalbody:
|
||||
self._process_item(child, body_depth)
|
||||
else:
|
||||
self.generic_visit(node)
|
||||
|
||||
def generic_visit(self, node: ast.AST):
|
||||
pass
|
||||
|
||||
def fix_file_ast(filepath: Path) -> tuple[bool, str]:
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8", newline="") as f:
|
||||
source = f.read()
|
||||
|
||||
lines = source.splitlines()
|
||||
fixer = PythonIndentationFixer(lines)
|
||||
new_source = fixer.fix()
|
||||
|
||||
ast.parse(new_source)
|
||||
|
||||
if new_source == source:
|
||||
return False, "No changes needed"
|
||||
|
||||
with open(filepath, "w", encoding="utf-8", newline="") as f:
|
||||
f.write(new_source)
|
||||
|
||||
return True, "Fixed"
|
||||
except SyntaxError as e:
|
||||
return False, f"SyntaxError: {e}"
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
def main():
|
||||
if len(sys.argv) > 1:
|
||||
filepath = Path(sys.argv[1])
|
||||
changed, msg = fix_file_ast(filepath)
|
||||
print(f"{filepath}: {msg}")
|
||||
return
|
||||
|
||||
print("AST-based Python indentation fixer")
|
||||
print("Usage: python fix_indent_ast.py <filepath>")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user