# fix_indent_ast.py -- AST-based Python indentation fixer
# (source listing header: 104 lines, 3.3 KiB, Python)
import ast
import io
import sys
import tokenize
from pathlib import Path


# The parent of the directory containing this file (two `.parent` steps up
# from __file__). The path is used as given -- it is not resolved, so it stays
# relative if __file__ is relative. Not referenced by the code in this file.
ROOT_DIR = Path(__file__).parent.parent


class PythonIndentationFixer(ast.NodeVisitor):
    """Re-indent syntactically valid Python source with a uniform indent unit.

    Block structure is read from the INDENT/DEDENT tokens produced by
    :mod:`tokenize` rather than from AST nodes. Tokens also cover the lines
    that have no statement node of their own: ``else:``/``except:``/
    ``finally:`` headers, decorators, comments and continuation lines.
    :func:`ast.parse` is used to reject invalid input and to verify that the
    rewrite parses to the same tree as the original.

    Per physical line, :meth:`fix` applies these rules:

    * first line of a logical line -> ``indent * depth`` plus its text;
    * continuation line of a multi-line statement -> the statement's old
      leading whitespace is swapped for its new one, which keeps alignment
      inside brackets. Lines whose indentation does not start with the
      statement's old indentation are left unchanged;
    * comment-only line -> aligned with the next statement when both had the
      same old indentation; else mapped like the most recent earlier
      statement with that indentation; else left unchanged;
    * line that starts inside a multi-line string literal -> copied verbatim,
      because its leading whitespace is part of the string's value;
    * whitespace-only line -> emptied.

    The class still derives from ``ast.NodeVisitor`` for backward
    compatibility; no AST traversal is performed.
    """

    def __init__(self, source_lines: list[str], indent: str = "    "):
        """Create a fixer for *source_lines*.

        Args:
            source_lines: the source split into lines, with or without their
                line terminators.
            indent: text emitted once per block level (four spaces by default).

        Raises:
            ValueError: if *indent* is empty or contains characters other than
                spaces and tabs.
        """
        if not indent or indent.strip(" \t"):
            raise ValueError(f"indent must be non-empty spaces/tabs, got {indent!r}")
        self.source_lines = source_lines
        # Rewritten lines (no terminators); filled in by fix().
        self.result: list[str] = []
        self.indent = indent

    def fix(self) -> str:
        """Return the re-indented source.

        The result is the rewritten lines joined with newlines, with no
        trailing newline. Line terminators carried by the input lines are
        stripped first. ``self.result`` is set to the list of rewritten lines.

        Raises:
            SyntaxError: if the source does not parse.
            ValueError: if the rewrite would parse to a different AST (a
                safety net; it should not happen for valid input).
        """
        lines = [line.rstrip("\r\n") for line in self.source_lines]
        source = "".join(line + "\n" for line in lines)
        tree = ast.parse(source)

        self.result = self._reindent(lines, source)
        fixed = "\n".join(self.result)

        # Re-indentation must never change what the program means.
        if ast.dump(ast.parse(fixed)) != ast.dump(tree):
            raise ValueError("re-indentation changed the parsed program")
        return fixed

    @staticmethod
    def _split_indent(line: str) -> tuple[str, str]:
        """Split *line* into (leading whitespace, remaining text)."""
        text = line.lstrip(" \t\f")
        return line[: len(line) - len(text)], text

    @staticmethod
    def _scan(source: str) -> tuple[dict[int, int], dict[int, int], set[int], set[int]]:
        """Tokenize *source* and classify its 1-based rows.

        Returns ``(depths, owners, comments, verbatim)``:

        * ``depths`` maps the first row of each logical line to its block depth;
        * ``owners`` maps each continuation row to the first row of its
          logical line;
        * ``comments`` holds rows that contain only a comment;
        * ``verbatim`` holds rows that begin inside a multi-line string.
        """
        depths: dict[int, int] = {}
        owners: dict[int, int] = {}
        comments: set[int] = set()
        verbatim: set[int] = set()
        string_starts: list[int] = []  # open f-/t-strings (Python 3.12+)
        depth = 0
        start_row = None  # first row of the logical line being read, if any

        for tok in tokenize.generate_tokens(io.StringIO(source).readline):
            kind = tok.type
            first, last = tok.start[0], tok.end[0]
            name = tokenize.tok_name.get(kind, "")

            # A token spanning rows (a string or f-string fragment) means
            # the rows it continues onto start inside a literal.
            if last > first:
                verbatim.update(range(first + 1, last + 1))
            # On 3.12+ f-/t-strings are split into START/MIDDLE/END tokens;
            # protect the whole span, including replacement fields.
            if name.endswith("STRING_START"):
                string_starts.append(first)
            elif name.endswith("STRING_END") and string_starts:
                verbatim.update(range(string_starts.pop() + 1, last + 1))

            if kind == tokenize.INDENT:
                depth += 1
            elif kind == tokenize.DEDENT:
                depth -= 1
            elif kind == tokenize.NEWLINE:
                if start_row is not None:
                    for row in range(start_row + 1, first + 1):
                        owners[row] = start_row
                start_row = None
            elif kind == tokenize.COMMENT:
                if start_row is None:  # nothing precedes it on this line
                    comments.add(first)
            elif kind in (tokenize.NL, tokenize.ENDMARKER, tokenize.ENCODING):
                pass
            elif start_row is None:
                # First real token of a logical line; INDENT/DEDENT tokens
                # for it have already been counted.
                start_row = first
                depths[first] = depth
        return depths, owners, comments, verbatim

    def _reindent(self, lines: list[str], source: str) -> list[str]:
        """Rewrite *lines* (the rows of *source*) per the class-level rules."""
        depths, owners, comments, verbatim = self._scan(source)

        # Old and new leading whitespace of every statement-start row.
        old_ws = {row: self._split_indent(lines[row - 1])[0] for row in depths}
        new_ws = {row: self.indent * depth for row, depth in depths.items()}

        # For each comment-only row, the first statement row that follows it.
        next_stmt: dict[int, int] = {}
        upcoming = None
        for row in range(len(lines), 0, -1):
            if row in depths:
                upcoming = row
            elif row in comments and upcoming is not None:
                next_stmt[row] = upcoming

        seen: dict[str, str] = {}  # old -> new indentation of statements so far
        out: list[str] = []
        for row, line in enumerate(lines, start=1):
            if row in verbatim:
                out.append(line)
                continue
            ws, text = self._split_indent(line)
            if not text:
                out.append("")
            elif row in depths:
                seen[ws] = new_ws[row]
                out.append(new_ws[row] + text)
            elif row in owners:
                head = owners[row]
                prefix = old_ws[head]
                if ws.startswith(prefix):
                    out.append(new_ws[head] + line[len(prefix):])
                else:
                    out.append(line)  # less indented than its statement: keep
            elif row in comments:
                nxt = next_stmt.get(row)
                if nxt is not None and old_ws[nxt] == ws:
                    out.append(new_ws[nxt] + text)
                else:
                    out.append(seen.get(ws, ws) + text)
            else:
                out.append(line)
        return out


def fix_file_ast(filepath: Path) -> tuple[bool, str]:
    """Re-indent the Python file at *filepath* in place.

    The file is read and written as UTF-8 with no newline translation. Its
    line terminator (CRLF if the file contains any, otherwise LF) and whether
    it ends with a newline are preserved. A file that needs no re-indentation
    therefore compares equal and is not rewritten.

    Args:
        filepath: path of the file to rewrite.

    Returns:
        ``(changed, message)``:

        * ``(True, "Fixed")`` after rewriting the file;
        * ``(False, "No changes needed")`` when the output equals the input;
        * ``(False, <error description>)`` on failure. Errors are reported,
          never raised, so a caller can carry on with other files.
    """
    try:
        with open(filepath, "r", encoding="utf-8", newline="") as f:
            source = f.read()

        newline = "\r\n" if "\r\n" in source else "\n"
        ends_with_newline = source.endswith("\n")
        # Split on "\n" only. str.splitlines() would also break on \f, \v,
        # \x1c-\x1e, \x85, \u2028 and \u2029, and rejoining with "\n" would
        # silently turn those characters (e.g. inside string literals) into
        # line breaks.
        lines = source.replace("\r\n", "\n").split("\n")
        if ends_with_newline:
            lines.pop()  # empty element after the final terminator

        new_source = PythonIndentationFixer(lines).fix()
        ast.parse(new_source)  # guard: never write code that does not parse

        if ends_with_newline:
            new_source += "\n"
        new_source = new_source.replace("\n", newline)

        if new_source == source:
            return False, "No changes needed"

        with open(filepath, "w", encoding="utf-8", newline="") as f:
            f.write(new_source)

        return True, "Fixed"

    except SyntaxError as e:
        return False, f"SyntaxError: {e}"

    except Exception as e:
        # Per-file boundary: report any failure (I/O, decoding, fixer safety
        # check) to the caller instead of aborting.
        return False, f"{type(e).__name__}: {e}"


def main(argv=None) -> int:
    """Re-indent every file named on the command line and report each result.

    Args:
        argv: file paths to process; defaults to ``sys.argv[1:]``.

    Returns:
        Exit status:

        * 0 if every file was fixed or needed no change;
        * 1 if any file could not be processed;
        * 2 (after printing usage) if no path was given.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        print("AST-based Python indentation fixer")
        print("Usage: python fix_indent_ast.py <filepath> [<filepath> ...]")
        return 2

    status = 0
    for arg in args:
        filepath = Path(arg)
        changed, msg = fix_file_ast(filepath)
        print(f"{filepath}: {msg}")
        # fix_file_ast returns (False, "No changes needed") for a clean file
        # and (False, <error>) for a failure; only the latter is an error.
        if not changed and msg != "No changes needed":
            status = 1
    return status


# Script entry point. The process exit status is main()'s return value;
# sys.exit(None) exits with status 0.
if __name__ == "__main__":
    sys.exit(main())