"""Hoist ``import`` statements in Python files to the top of the module.

Imports found inside functions, classes, ``if`` blocks, etc. are removed from
their original location, de-duplicated, and re-inserted right after the module
docstring (if any), with ``from __future__`` imports placed first.

Imports inside ``try`` statements and ``if TYPE_CHECKING:`` blocks are left
where they are: hoisting them would change runtime behaviour (optional
dependencies become mandatory, type-only imports start executing).

NOTE: a file that does need changes is regenerated with ``ast.unparse``, so its
comments and original formatting are not preserved. Files whose AST would not
change are never rewritten.
"""
import ast
import sys
import pathlib


def _is_type_checking_test(test):
    """Return True if *test* is ``TYPE_CHECKING`` or ``<something>.TYPE_CHECKING``."""
    if isinstance(test, ast.Name):
        return test.id == "TYPE_CHECKING"
    if isinstance(test, ast.Attribute):
        return test.attr == "TYPE_CHECKING"
    return False


class ImportCollector(ast.NodeTransformer):
    """Remove hoistable import statements from a tree and remember them.

    After ``visit`` returns, ``collected_imports`` holds the removed
    ``ast.Import`` / ``ast.ImportFrom`` nodes in the order they were visited.
    """

    def __init__(self):
        super().__init__()
        self.collected_imports = []

    def visit_Import(self, node):
        self.collected_imports.append(node)
        # Returning None makes NodeTransformer drop the node from its parent.
        return None

    def visit_ImportFrom(self, node):
        self.collected_imports.append(node)
        return None

    def visit_Try(self, node):
        # Guarded imports (``try: import x`` / ``except ImportError: ...``)
        # must stay inside their try statement, so the subtree is not visited.
        return node

    # ``try ... except*`` (Python 3.11+) is dispatched to visit_TryStar.
    visit_TryStar = visit_Try

    def visit_If(self, node):
        # Imports under ``if TYPE_CHECKING:`` exist only for type checkers;
        # hoisting them would run them at runtime (risking import cycles).
        if _is_type_checking_test(node.test):
            return node
        return self.generic_visit(node)


class PassFiller(ast.NodeTransformer):
    """Ensures that blocks that became empty after import removal have a 'pass' statement."""

    def generic_visit(self, node):
        super().generic_visit(node)
        # An empty statement list (function/class/if/for/with/handler body)
        # is invalid syntax; only the module body may legitimately be empty.
        if (
            not isinstance(node, ast.Module)
            and isinstance(getattr(node, "body", None), list)
            and not node.body
        ):
            node.body.append(ast.Pass())
        return node


def _is_future_import(node):
    """Return True if *node* is a ``from __future__ import ...`` statement."""
    return isinstance(node, ast.ImportFrom) and node.module == "__future__"


def _dedupe_imports(nodes):
    """Return *nodes* without textual duplicates, keeping first occurrences in order."""
    unique = {}
    for node in nodes:
        # Identical unparsed source means an identical import statement.
        unique.setdefault(ast.unparse(node).strip(), node)
    return list(unique.values())


def _insertion_index(tree):
    """Return the index in ``tree.body`` just after a leading module docstring."""
    if tree.body:
        first = tree.body[0]
        if (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        ):
            return 1
    return 0


def fix_imports(file_path):
    """Move the imports of the Python file at *file_path* to the top of the module.

    Progress and errors are reported with ``print``; nothing is raised for
    missing, unreadable, or unparsable files. The file is only written when
    its AST actually changes, so already-tidy files keep their comments and
    formatting.

    Args:
        file_path: Path (str or path-like) of the file to rewrite in place.
    """
    path = pathlib.Path(file_path)
    if not path.exists():
        print(f"File not found: {file_path}")
        return

    try:
        content = path.read_text(encoding='utf-8')
        tree = ast.parse(content)
    except Exception as e:  # per-file boundary: report and move on
        print(f"Error parsing {file_path}: {e}")
        return

    # Snapshot the structure (without positions) to detect no-op runs later.
    original_dump = ast.dump(tree)

    collector = ImportCollector()
    tree = collector.visit(tree)
    # Fill bodies emptied by import removal with ``pass``.
    tree = PassFiller().visit(tree)

    if not collector.collected_imports:
        print(f"No imports to move in {file_path}")
        return

    unique_imports = _dedupe_imports(collector.collected_imports)
    # Stable sort: __future__ imports first, everything else keeps its order.
    ordered = sorted(unique_imports, key=lambda n: not _is_future_import(n))
    idx = _insertion_index(tree)
    tree.body[idx:idx] = ordered

    # Regenerating the file drops comments and formatting, so only do it
    # when the program structure really changed.
    if ast.dump(tree) == original_dump:
        print(f"No changes for {file_path}")
        return

    try:
        new_code = ast.unparse(tree)
        # ast.unparse omits the final newline; keep the file POSIX-friendly.
        path.write_text(new_code + "\n", encoding='utf-8')
        print(f"Updated {file_path}")
    except Exception as e:  # per-file boundary: report and move on
        print(f"Error unparsing or writing {file_path}: {e}")


def main(argv=None):
    """Run :func:`fix_imports` on each path in *argv* (default ``sys.argv[1:]``).

    Returns the process exit status: 0 normally, 2 when no paths were given.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        print("Usage: python scripts/fix_imports.py ...")
        return 2
    for arg in args:
        fix_imports(arg)
    return 0


if __name__ == "__main__":
    sys.exit(main())