import ast
import json
import os
import pathlib
import re
import sys
import textwrap
from typing import Any, Dict, List, Optional, Tuple

# A docstring line matching any of these markers is an SDM tag line: it is
# dropped before the current tags from the report are appended again.
_SDM_TAG_RE = re.compile(r'\[C:.*\]|\[M:.*\]|\[U:.*\]|\[VARS:.*\]')

# Docstring delimiters that process_file knows how to rewrite in place.
_TRIPLE_QUOTES = ('"""', "'''")


def find_closing_quotes_pos(line: str) -> Tuple[int, str]:
    """Return the position and kind of the right-most triple quote on *line*.

    Returns ``(index, quote)`` where *quote* is whichever triple-quote
    delimiter (double or single) occurs furthest right, or ``(-1, "")`` when
    neither occurs.  Kept for existing callers; ``process_file`` locates the
    delimiters from the AST column offsets instead.
    """
    pos_double = line.rfind('"""')
    pos_single = line.rfind("'''")
    if pos_double == -1 and pos_single == -1:
        return -1, ""
    if pos_double > pos_single:
        return pos_double, '"""'
    return pos_single, "'''"


class SdmDocstringInjectorVisitor(ast.NodeVisitor):
    """Collect docstring edits for the classes and functions of one file.

    Visiting a parsed module fills ``targets_to_modify`` with one dict per
    node: ``type == 'append'`` for nodes that already have a docstring (its
    tags get refreshed) and ``type == 'new'`` for nodes without one that have
    tags in the SDM report.  The edits are applied by ``process_file``.
    """

    def __init__(self, file_path: str, sdm_tags_map: Dict[str, Any], lines: List[str]):
        self.file_path = file_path
        self.sdm_tags_map = sdm_tags_map
        # Source lines of the file.  They are indexed with AST line numbers, so
        # they must be split on "\n" exactly as process_file does.
        self.lines = lines
        self.targets_to_modify: List[Dict[str, Any]] = []
        self.current_class_name: Optional[str] = None
        self.project_root = pathlib.Path.cwd().resolve()
        # Report key of this file; it cannot change, so resolve it only once.
        self._rel_path = self.get_rel_path(file_path)

    def get_rel_path(self, path):
        """Return *path* relative to ``project_root`` with forward slashes.

        ``project_root`` is the working directory at construction time; paths
        outside it are returned absolute (still with forward slashes).
        """
        p = pathlib.Path(path).resolve()
        try:
            return str(p.relative_to(self.project_root)).replace("\\", "/")
        except (ValueError, RuntimeError):
            return str(p).replace("\\", "/")

    def _get_sdm_tags(self, name: str, node_type: str, parent_class_name: Optional[str] = None) -> List[str]:
        """Look up the report tag for a class, method or function.

        Returns a list holding the tag, or an empty list when the file or the
        name is not in the report (or its tag is empty).
        """
        file_data = self.sdm_tags_map.get(self._rel_path)
        if not file_data:
            return []
        classes = file_data.get('classes', {})
        tag = None
        if node_type == 'ClassDef':
            tag = classes.get(name, {}).get('class_tag')
        elif node_type in ('FunctionDef', 'AsyncFunctionDef'):
            if parent_class_name:
                tag = classes.get(parent_class_name, {}).get('methods', {}).get(name)
            else:
                tag = file_data.get('functions', {}).get(name)
        return [tag] if tag else []

    def _process_node(self, node, node_type: str):
        """Record an 'append' or 'new' docstring target for *node*."""
        if not node.body:
            return
        sdm_tags = self._get_sdm_tags(node.name, node_type, self.current_class_name)
        first_body_node = node.body[0]
        # Text in front of the first statement on its line.  Real indentation
        # is ASCII whitespace, so the UTF-8 byte offset equals the str index.
        indent = self.lines[first_body_node.lineno - 1][:first_body_node.col_offset]
        if indent.strip():
            # The body shares a line with the header (`def f(): pass`, or the
            # last line of a multi-line signature): no room for a docstring.
            return

        docstring_node = None
        if (isinstance(first_body_node, ast.Expr)
                and isinstance(first_body_node.value, ast.Constant)
                and isinstance(first_body_node.value.value, str)):
            docstring_node = first_body_node.value

        if docstring_node is not None:
            self.targets_to_modify.append({
                'type': 'append',
                'node': node,
                'name': node.name,
                'sdm_tags': sdm_tags,
                'start_lineno': docstring_node.lineno,
                'end_lineno': docstring_node.end_lineno,
                'start_col': docstring_node.col_offset,  # UTF-8 byte offset
                'end_col': docstring_node.end_col_offset,  # UTF-8 byte offset
                'indent_count': first_body_node.col_offset,
                'indent': indent,
                'existing_doc': docstring_node.value,
            })
        elif sdm_tags:
            # A decorated first statement begins at its first decorator, not
            # at its `def`/`class` line; inserting there would separate the
            # decorator from what it decorates.
            decorators = getattr(first_body_node, 'decorator_list', [])
            insert_lineno = min([first_body_node.lineno] + [d.lineno for d in decorators])
            self.targets_to_modify.append({
                'type': 'new',
                'node': node,
                'name': node.name,
                'sdm_tags': sdm_tags,
                'insert_lineno': insert_lineno,
                'indent_count': first_body_node.col_offset,
                'indent': indent,
            })

    def visit_ClassDef(self, node):
        self._process_node(node, 'ClassDef')
        old_class = self.current_class_name
        self.current_class_name = node.name
        # NOTE(review): current_class_name stays set while visiting method
        # bodies, so functions nested inside a method are looked up as methods
        # of this class - confirm that matches how the SDM report is keyed.
        self.generic_visit(node)
        self.current_class_name = old_class

    def visit_FunctionDef(self, node):
        self._process_node(node, 'FunctionDef')
        self.generic_visit(node)

    def visit_AsyncFunctionDef(self, node):
        self._process_node(node, 'AsyncFunctionDef')
        self.generic_visit(node)


def strip_tags(docstring: str) -> str:
    """Remove SDM tag lines and trailing blank lines from *docstring*."""
    kept = [line for line in docstring.splitlines() if not _SDM_TAG_RE.search(line)]
    while kept and not kept[-1].strip():
        kept.pop()
    return "\n".join(kept)


def _byte_col_to_index(line: str, byte_col: int) -> int:
    """Convert an AST column (UTF-8 byte offset) on *line* into a str index."""
    return len(line.encode('utf-8')[:byte_col].decode('utf-8'))


def _tag_lines(sdm_tags: List[str]) -> List[str]:
    """Flatten the (possibly multi-line) tags into individual lines."""
    return [line for tag in sdm_tags for line in tag.splitlines()]


def _indent_lines(indent: str, content: List[str]) -> List[str]:
    """Prefix each line with *indent*; blank lines stay empty (no trailing spaces)."""
    return [indent + line if line.strip() else '' for line in content]


def _rewrite_existing_docstring(lines: List[str], target: Dict[str, Any]) -> Optional[List[str]]:
    """Build replacement lines for an existing docstring ('append' target).

    Works on the docstring's source text rather than its evaluated value, so
    escape sequences survive, and removes the common indentation before
    re-indenting, so repeated runs are stable.  Returns None when the
    docstring must be left alone: it is not a single triple-quoted literal,
    it neither carries old tags nor receives new ones, or nothing would change.
    """
    start_idx = target['start_lineno'] - 1
    end_idx = target['end_lineno'] - 1
    first_line, last_line = lines[start_idx], lines[end_idx]

    # Skip string prefixes (r, u, ...) to reach the opening delimiter.
    quote_pos = _byte_col_to_index(first_line, target['start_col'])
    while quote_pos < len(first_line) and first_line[quote_pos].isalpha():
        quote_pos += 1
    quote = first_line[quote_pos:quote_pos + 3]
    close_end = _byte_col_to_index(last_line, target['end_col'])
    close_pos = close_end - 3
    if quote not in _TRIPLE_QUOTES or last_line[close_pos:close_end] != quote:
        return None

    if start_idx == end_idx:
        raw = first_line[quote_pos + 3:close_pos]
    else:
        raw = '\n'.join([first_line[quote_pos + 3:], *lines[start_idx + 1:end_idx], last_line[:close_pos]])
    if quote in raw:
        # Implicitly concatenated literals or an escaped delimiter: the text
        # between the outer quotes is not a single docstring body.
        return None

    sdm_tags = target['sdm_tags']
    head, *tail = raw.split('\n')
    text = head.strip()
    if tail:
        # Drop the indentation shared by the continuation lines; it is added
        # back exactly once below.
        text += '\n' + textwrap.dedent('\n'.join(tail))
    if not sdm_tags and not _SDM_TAG_RE.search(text):
        return None  # no stale tags to remove and no new tags to add

    body = strip_tags(text).split('\n')
    while body and not body[0].strip():
        body.pop(0)
    content = body + _tag_lines(sdm_tags)

    indent = target['indent']
    prefix, suffix = first_line[:quote_pos + 3], last_line[close_pos:]
    if start_idx == end_idx and len(content) <= 1:
        replacement = [prefix + (content[0] if content else '') + suffix]
    else:
        replacement = [prefix] + _indent_lines(indent, content) + [indent + suffix]
    if replacement == lines[start_idx:end_idx + 1]:
        return None
    return replacement


def _new_docstring(lines: List[str], target: Dict[str, Any]) -> Tuple[int, List[str]]:
    """Return ``(insert index, lines)`` for a fresh docstring holding only tags."""
    indent = target['indent']
    new_doc = [f'{indent}"""'] + _indent_lines(indent, _tag_lines(target['sdm_tags'])) + [f'{indent}"""']
    insert_idx = target['insert_lineno'] - 1
    # Place the docstring right under the header, above any blank lines that
    # separated the header from the first statement.
    while insert_idx > 0 and not lines[insert_idx - 1].strip():
        insert_idx -= 1
    return insert_idx, new_doc


def _edit_lineno(target: Dict[str, Any]) -> int:
    """First source line touched by *target*; used to apply edits bottom-up."""
    return target['start_lineno'] if target['type'] == 'append' else target['insert_lineno']


def process_file(py_file_path: pathlib.Path, sdm_tags_map):
    """Refresh the SDM tags in the docstrings of one Python file, in place.

    Args:
        py_file_path: path (``pathlib.Path`` or ``str``) of the file to edit.
        sdm_tags_map: parsed SDM report, keyed by the file path relative to
            the working directory (see ``SdmDocstringInjectorVisitor``).

    Empty files and files that fail to parse are skipped.  The file is only
    rewritten when some docstring actually changes; its newline convention
    and trailing newline are preserved.  Any other error is reported on
    stderr and does not propagate.
    """
    try:
        py_file_path = pathlib.Path(py_file_path)
        with open(py_file_path, 'r', encoding='utf-8') as f:
            content = f.read()
            # Newline style seen while reading (None or a tuple when there is
            # none or it is mixed); reused on write so CRLF files stay CRLF.
            newline = f.newlines if isinstance(f.newlines, str) else '\n'
        if not content:
            return
        try:
            tree = ast.parse(content)
        except SyntaxError:
            return

        # Split on "\n" only: str.splitlines() also breaks on \f, \x1c,
        # \u2028, ..., which the tokenizer does not count as line ends.
        # A trailing newline leaves a final '' so the join below restores it.
        lines = content.split('\n')
        visitor = SdmDocstringInjectorVisitor(str(py_file_path.resolve()), sdm_tags_map, lines)
        visitor.visit(tree)
        if not visitor.targets_to_modify:
            return

        # Apply edits bottom-up so the line numbers of pending edits stay valid.
        visitor.targets_to_modify.sort(key=_edit_lineno, reverse=True)
        modified_lines = lines[:]
        file_modified = False
        for target in visitor.targets_to_modify:
            if target['type'] == 'append':
                replacement = _rewrite_existing_docstring(modified_lines, target)
                if replacement is None:
                    continue
                start_idx = target['start_lineno'] - 1
                end_idx = target['end_lineno'] - 1
                modified_lines[start_idx:end_idx + 1] = replacement
            elif target['sdm_tags']:
                insert_idx, new_doc = _new_docstring(modified_lines, target)
                modified_lines[insert_idx:insert_idx] = new_doc
            else:
                continue
            file_modified = True

        if file_modified:
            with open(py_file_path, 'w', encoding='utf-8', newline=newline) as f:
                f.write('\n'.join(modified_lines))
    except Exception as e:
        print(f"Error processing {py_file_path}: {e}", file=sys.stderr)


def main():
    """Inject SDM tags into the Python files selected on the command line.

    Tags are read from ``sdm_report_refined.json`` in the working directory
    (exit status 1 if it is missing).  Without arguments every ``*.py`` under
    the existing ones of ``src``, ``simulation`` and ``tests`` is processed;
    otherwise each argument is a file, or a directory searched recursively.
    """
    sdm_report_path = "sdm_report_refined.json"
    if not pathlib.Path(sdm_report_path).exists():
        print(f"Error: {sdm_report_path} not found.", file=sys.stderr)
        sys.exit(1)
    with open(sdm_report_path, 'r', encoding='utf-8') as f:
        sdm_tags_map = json.load(f)

    targets = sys.argv[1:]
    if not targets:
        for d in ["src", "simulation", "tests"]:
            sd = pathlib.Path(d)
            if sd.exists():
                for f in sd.rglob("*.py"):
                    process_file(f, sdm_tags_map)
    else:
        for t in targets:
            tp = pathlib.Path(t)
            if tp.is_file():
                process_file(tp, sdm_tags_map)
            elif tp.is_dir():
                for f in tp.rglob("*.py"):
                    process_file(f, sdm_tags_map)


if __name__ == "__main__":
    main()