refactor(scripts): Add strict type hints to utility scripts

This commit is contained in:
2026-02-28 18:58:53 -05:00
parent c368caf43a
commit 53c2bbfa81
6 changed files with 45 additions and 42 deletions

View File

@@ -9,9 +9,10 @@ import ast
import re import re
import sys import sys
import os import os
from typing import Any, Callable
BASE = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) BASE: str = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
stats = {"auto_none": 0, "manual_sig": 0, "vars": 0, "errors": []} stats: dict[str, Any] = {"auto_none": 0, "manual_sig": 0, "vars": 0, "errors": []}
def abs_path(filename: str) -> str: def abs_path(filename: str) -> str:
return os.path.join(BASE, filename) return os.path.join(BASE, filename)
@@ -167,7 +168,7 @@ def verify_syntax(filepath: str) -> str:
# ============================================================ # ============================================================
# gui_2.py manual signatures (Tier 3 items) # gui_2.py manual signatures (Tier 3 items)
# ============================================================ # ============================================================
GUI2_MANUAL_SIGS = [ GUI2_MANUAL_SIGS: list[tuple[str, str]] = [
(r'def resolve_pending_action\(self, action_id: str, approved: bool\):', (r'def resolve_pending_action\(self, action_id: str, approved: bool\):',
r'def resolve_pending_action(self, action_id: str, approved: bool) -> bool:'), r'def resolve_pending_action(self, action_id: str, approved: bool) -> bool:'),
(r'def _cb_start_track\(self, user_data=None\):', (r'def _cb_start_track\(self, user_data=None\):',
@@ -185,7 +186,7 @@ GUI2_MANUAL_SIGS = [
# ============================================================ # ============================================================
# gui_legacy.py manual signatures (Tier 3 items) # gui_legacy.py manual signatures (Tier 3 items)
# ============================================================ # ============================================================
LEGACY_MANUAL_SIGS = [ LEGACY_MANUAL_SIGS: list[tuple[str, str]] = [
(r'def _add_kv_row\(parent: str, key: str, val, val_color=None\):', (r'def _add_kv_row\(parent: str, key: str, val, val_color=None\):',
r'def _add_kv_row(parent: str, key: str, val: Any, val_color: tuple[int, int, int] | None = None) -> None:'), r'def _add_kv_row(parent: str, key: str, val: Any, val_color: tuple[int, int, int] | None = None) -> None:'),
(r'def _make_remove_file_cb\(self, idx: int\):', (r'def _make_remove_file_cb\(self, idx: int\):',
@@ -229,7 +230,7 @@ LEGACY_MANUAL_SIGS = [
# ============================================================ # ============================================================
# gui_2.py variable type annotations # gui_2.py variable type annotations
# ============================================================ # ============================================================
GUI2_VAR_REPLACEMENTS = [ GUI2_VAR_REPLACEMENTS: list[tuple[str, str]] = [
(r'^CONFIG_PATH = ', 'CONFIG_PATH: Path = '), (r'^CONFIG_PATH = ', 'CONFIG_PATH: Path = '),
(r'^PROVIDERS = ', 'PROVIDERS: list[str] = '), (r'^PROVIDERS = ', 'PROVIDERS: list[str] = '),
(r'^COMMS_CLAMP_CHARS = ', 'COMMS_CLAMP_CHARS: int = '), (r'^COMMS_CLAMP_CHARS = ', 'COMMS_CLAMP_CHARS: int = '),
@@ -255,7 +256,7 @@ GUI2_VAR_REPLACEMENTS = [
# ============================================================ # ============================================================
# gui_legacy.py variable type annotations # gui_legacy.py variable type annotations
# ============================================================ # ============================================================
LEGACY_VAR_REPLACEMENTS = [ LEGACY_VAR_REPLACEMENTS: list[tuple[str, str]] = [
(r'^CONFIG_PATH = ', 'CONFIG_PATH: Path = '), (r'^CONFIG_PATH = ', 'CONFIG_PATH: Path = '),
(r'^PROVIDERS = ', 'PROVIDERS: list[str] = '), (r'^PROVIDERS = ', 'PROVIDERS: list[str] = '),
(r'^COMMS_CLAMP_CHARS = ', 'COMMS_CLAMP_CHARS: int = '), (r'^COMMS_CLAMP_CHARS = ', 'COMMS_CLAMP_CHARS: int = '),

View File

@@ -8,9 +8,9 @@ import tomllib
import tree_sitter import tree_sitter
import tree_sitter_python import tree_sitter_python
LOG_FILE = 'logs/claude_mma_delegation.log' LOG_FILE: str = 'logs/claude_mma_delegation.log'
MODEL_MAP = { MODEL_MAP: dict[str, str] = {
'tier1-orchestrator': 'claude-opus-4-6', 'tier1-orchestrator': 'claude-opus-4-6',
'tier1': 'claude-opus-4-6', 'tier1': 'claude-opus-4-6',
'tier2-tech-lead': 'claude-sonnet-4-6', 'tier2-tech-lead': 'claude-sonnet-4-6',
@@ -86,7 +86,7 @@ def get_role_documents(role: str) -> list[str]:
return [] return []
def log_delegation(role, full_prompt, result=None, summary_prompt=None): def log_delegation(role: str, full_prompt: str, result: str | None = None, summary_prompt: str | None = None) -> str:
os.makedirs('logs/claude_agents', exist_ok=True) os.makedirs('logs/claude_agents', exist_ok=True)
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f'logs/claude_agents/claude_{role}_task_{timestamp}.log' log_file = f'logs/claude_agents/claude_{role}_task_{timestamp}.log'
@@ -137,7 +137,7 @@ def execute_agent(role: str, prompt: str, docs: list[str]) -> str:
# Advanced Context: Dependency skeletons for Tier 3 # Advanced Context: Dependency skeletons for Tier 3
injected_context = "" injected_context = ""
UNFETTERED_MODULES = ['mcp_client', 'project_manager', 'events', 'aggregate'] UNFETTERED_MODULES: list[str] = ['mcp_client', 'project_manager', 'events', 'aggregate']
if role in ['tier3', 'tier3-worker']: if role in ['tier3', 'tier3-worker']:
for doc in docs: for doc in docs:
@@ -231,7 +231,7 @@ def execute_agent(role: str, prompt: str, docs: list[str]) -> str:
return err_msg return err_msg
def create_parser(): def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Claude MMA Execution Script") parser = argparse.ArgumentParser(description="Claude MMA Execution Script")
parser.add_argument( parser.add_argument(
"--role", "--role",
@@ -275,7 +275,7 @@ def main() -> None:
docs = get_role_documents(role) docs = get_role_documents(role)
# Extract @file references from the prompt # Extract @file references from the prompt
file_refs = re.findall(r"@([\w./\\]+)", prompt) file_refs: list[str] = re.findall(r"@([\w./\\]+)", prompt)
for ref in file_refs: for ref in file_refs:
if os.path.exists(ref) and ref not in docs: if os.path.exists(ref) and ref not in docs:
docs.append(ref) docs.append(ref)

View File

@@ -2,14 +2,14 @@ import os
import re import re
with open('mcp_client.py', 'r', encoding='utf-8') as f: with open('mcp_client.py', 'r', encoding='utf-8') as f:
content = f.read() content: str = f.read()
# 1. Add import os if not there # 1. Add import os if not there
if 'import os' not in content: if 'import os' not in content:
content = content.replace('import summarize', 'import os\nimport summarize') content: str = content.replace('import summarize', 'import os\nimport summarize')
# 2. Add the functions before "# ------------------------------------------------------------------ web tools" # 2. Add the functions before "# ------------------------------------------------------------------ web tools"
functions_code = r''' functions_code: str = r'''
def py_find_usages(path: str, name: str) -> str: def py_find_usages(path: str, name: str) -> str:
"""Finds exact string matches of a symbol in a given file or directory.""" """Finds exact string matches of a symbol in a given file or directory."""
p, err = _resolve_and_check(path) p, err = _resolve_and_check(path)
@@ -179,17 +179,17 @@ def get_tree(path: str, max_depth: int = 2) -> str:
# ------------------------------------------------------------------ web tools''' # ------------------------------------------------------------------ web tools'''
content = content.replace('# ------------------------------------------------------------------ web tools', functions_code) content: str = content.replace('# ------------------------------------------------------------------ web tools', functions_code)
# 3. Update TOOL_NAMES # 3. Update TOOL_NAMES
old_tool_names_match = re.search(r'TOOL_NAMES\s*=\s*\{([^}]*)\}', content) old_tool_names_match: re.Match | None = re.search(r'TOOL_NAMES\s*=\s*\{([^}]*)\}', content)
if old_tool_names_match: if old_tool_names_match:
old_names = old_tool_names_match.group(1) old_names: str = old_tool_names_match.group(1)
new_names = old_names + ', "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"' new_names: str = old_names + ', "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"'
content = content.replace(old_tool_names_match.group(0), f'TOOL_NAMES = {{{new_names}}}') content: str = content.replace(old_tool_names_match.group(0), f'TOOL_NAMES = {{{new_names}}}')
# 4. Update dispatch # 4. Update dispatch
dispatch_additions = r''' dispatch_additions: str = r'''
if tool_name == "py_find_usages": if tool_name == "py_find_usages":
return py_find_usages(tool_input.get("path", ""), tool_input.get("name", "")) return py_find_usages(tool_input.get("path", ""), tool_input.get("name", ""))
if tool_name == "py_get_imports": if tool_name == "py_get_imports":
@@ -204,10 +204,11 @@ dispatch_additions = r'''
return get_tree(tool_input.get("path", ""), tool_input.get("max_depth", 2)) return get_tree(tool_input.get("path", ""), tool_input.get("max_depth", 2))
return f"ERROR: unknown MCP tool '{tool_name}'" return f"ERROR: unknown MCP tool '{tool_name}'"
''' '''
content = re.sub(r' return f"ERROR: unknown MCP tool \'{tool_name}\'"', dispatch_additions.strip(), content) content: str = re.sub(
r' return f"ERROR: unknown MCP tool \'{tool_name}\'"', dispatch_additions.strip(), content)
# 5. Update MCP_TOOL_SPECS # 5. Update MCP_TOOL_SPECS
mcp_tool_specs_addition = r''' mcp_tool_specs_addition: str = r'''
{ {
"name": "py_find_usages", "name": "py_find_usages",
"description": "Finds exact string matches of a symbol in a given file or directory.", "description": "Finds exact string matches of a symbol in a given file or directory.",
@@ -281,7 +282,8 @@ mcp_tool_specs_addition = r'''
] ]
''' '''
content = re.sub(r'\]\s*$', mcp_tool_specs_addition.strip(), content) content: str = re.sub(
r'\]\s*$', mcp_tool_specs_addition.strip(), content)
with open('mcp_client.py', 'w', encoding='utf-8') as f: with open('mcp_client.py', 'w', encoding='utf-8') as f:
f.write(content) f.write(content)

View File

@@ -8,7 +8,7 @@ import tree_sitter_python
import ast import ast
import datetime import datetime
LOG_FILE = 'logs/mma_delegation.log' LOG_FILE: str = 'logs/mma_delegation.log'
def generate_skeleton(code: str) -> str: def generate_skeleton(code: str) -> str:
""" """
@@ -79,7 +79,7 @@ def get_role_documents(role: str) -> list[str]:
return ['conductor/workflow.md'] return ['conductor/workflow.md']
return [] return []
def log_delegation(role, full_prompt, result=None, summary_prompt=None): def log_delegation(role: str, full_prompt: str, result: str | None = None, summary_prompt: str | None = None) -> str:
os.makedirs('logs/agents', exist_ok=True) os.makedirs('logs/agents', exist_ok=True)
timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = f'logs/agents/mma_{role}_task_{timestamp}.log' log_file = f'logs/agents/mma_{role}_task_{timestamp}.log'
@@ -130,7 +130,7 @@ def execute_agent(role: str, prompt: str, docs: list[str]) -> str:
injected_context = "" injected_context = ""
# Whitelist of modules that sub-agents have "unfettered" (full) access to. # Whitelist of modules that sub-agents have "unfettered" (full) access to.
# These will be provided in full if imported, instead of just skeletons. # These will be provided in full if imported, instead of just skeletons.
UNFETTERED_MODULES = ['mcp_client', 'project_manager', 'events', 'aggregate'] UNFETTERED_MODULES: list[str] = ['mcp_client', 'project_manager', 'events', 'aggregate']
if role in ['tier3', 'tier3-worker']: if role in ['tier3', 'tier3-worker']:
for doc in docs: for doc in docs:
if doc.endswith('.py') and os.path.exists(doc): if doc.endswith('.py') and os.path.exists(doc):
@@ -219,7 +219,7 @@ def execute_agent(role: str, prompt: str, docs: list[str]) -> str:
log_delegation(role, command_text, err_msg) log_delegation(role, command_text, err_msg)
return err_msg return err_msg
def create_parser(): def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="MMA Execution Script") parser = argparse.ArgumentParser(description="MMA Execution Script")
parser.add_argument( parser.add_argument(
"--role", "--role",

View File

@@ -1,24 +1,24 @@
"""Scan all .py files for missing type hints. Writes scan_report.txt.""" """Scan all .py files for missing type hints. Writes scan_report.txt."""
import ast, os import ast, os
SKIP = {'.git', '__pycache__', '.venv', 'venv', 'node_modules', '.claude', '.gemini'} SKIP: set[str] = {'.git', '__pycache__', '.venv', 'venv', 'node_modules', '.claude', '.gemini'}
BASE = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) BASE: str = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
os.chdir(BASE) os.chdir(BASE)
results = {} results: dict[str, tuple[int, int, int, int]] = {}
for root, dirs, files in os.walk('.'): for root, dirs, files in os.walk('.'):
dirs[:] = [d for d in dirs if d not in SKIP] dirs[:] = [d for d in dirs if d not in SKIP]
for f in files: for f in files:
if not f.endswith('.py'): if not f.endswith('.py'):
continue continue
path = os.path.join(root, f).replace('\\', '/') path: str = os.path.join(root, f).replace('\\', '/')
try: try:
with open(path, 'r', encoding='utf-8-sig') as fh: with open(path, 'r', encoding='utf-8-sig') as fh:
tree = ast.parse(fh.read()) tree = ast.parse(fh.read())
except Exception: except Exception:
continue continue
counts = [0, 0, 0] # nr, up, uv counts: list[int] = [0, 0, 0] # nr, up, uv
def scan(scope, prefix=''): def scan(scope: ast.AST, prefix: str = '') -> None:
for node in ast.iter_child_nodes(scope): for node in ast.iter_child_nodes(scope):
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
if node.returns is None: if node.returns is None:
@@ -34,16 +34,16 @@ for root, dirs, files in os.walk('.'):
scan(node, prefix=f'{node.name}.') scan(node, prefix=f'{node.name}.')
scan(tree) scan(tree)
nr, up, uv = counts nr, up, uv = counts
total = nr + up + uv total: int = nr + up + uv
if total > 0: if total > 0:
results[path] = (nr, up, uv, total) results[path] = (nr, up, uv, total)
lines = [] lines: list[str] = []
lines.append(f'Files with untyped items: {len(results)}') lines.append(f'Files with untyped items: {len(results)}')
lines.append('') lines.append('')
lines.append(f'{"File":<58} {"NoRet":>6} {"Params":>7} {"Vars":>5} {"Total":>6}') lines.append(f'{"File":<58} {"NoRet":>6} {"Params":>7} {"Vars":>5} {"Total":>6}')
lines.append('-' * 85) lines.append('-' * 85)
gt = 0 gt: int = 0
for path in sorted(results, key=lambda x: results[x][3], reverse=True): for path in sorted(results, key=lambda x: results[x][3], reverse=True):
nr, up, uv, t = results[path] nr, up, uv, t = results[path]
lines.append(f'{path:<58} {nr:>6} {up:>7} {uv:>5} {t:>6}') lines.append(f'{path:<58} {nr:>6} {up:>7} {uv:>5} {t:>6}')
@@ -51,6 +51,6 @@ for path in sorted(results, key=lambda x: results[x][3], reverse=True):
lines.append('-' * 85) lines.append('-' * 85)
lines.append(f'{"TOTAL":<58} {"":>6} {"":>7} {"":>5} {gt:>6}') lines.append(f'{"TOTAL":<58} {"":>6} {"":>7} {"":>5} {gt:>6}')
report = '\n'.join(lines) report: str = '\n'.join(lines)
with open('scan_report.txt', 'w', encoding='utf-8') as f: with open('scan_report.txt', 'w', encoding='utf-8') as f:
f.write(report) f.write(report)

View File

@@ -1,14 +1,14 @@
import sys import sys
import ast import ast
def get_slice(filepath, start_line, end_line): def get_slice(filepath: str, start_line: int | str, end_line: int | str) -> str:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, 'r', encoding='utf-8') as f:
lines = f.readlines() lines = f.readlines()
start_idx = int(start_line) - 1 start_idx = int(start_line) - 1
end_idx = int(end_line) end_idx = int(end_line)
return "".join(lines[start_idx:end_idx]) return "".join(lines[start_idx:end_idx])
def set_slice(filepath, start_line, end_line, new_content): def set_slice(filepath: str, start_line: int | str, end_line: int | str, new_content: str) -> None:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, 'r', encoding='utf-8') as f:
lines = f.readlines() lines = f.readlines()
start_idx = int(start_line) - 1 start_idx = int(start_line) - 1
@@ -20,7 +20,7 @@ def set_slice(filepath, start_line, end_line, new_content):
with open(filepath, 'w', encoding='utf-8', newline='') as f: with open(filepath, 'w', encoding='utf-8', newline='') as f:
f.writelines(lines) f.writelines(lines)
def get_def(filepath, symbol_name): def get_def(filepath: str, symbol_name: str) -> str:
with open(filepath, 'r', encoding='utf-8') as f: with open(filepath, 'r', encoding='utf-8') as f:
content = f.read() content = f.read()
tree = ast.parse(content) tree = ast.parse(content)
@@ -35,7 +35,7 @@ def get_def(filepath, symbol_name):
return f"{start},{end}{chr(10)}{slice_content}" return f"{start},{end}{chr(10)}{slice_content}"
return "NOT_FOUND" return "NOT_FOUND"
def set_def(filepath, symbol_name, new_content): def set_def(filepath: str, symbol_name: str, new_content: str) -> None:
res = get_def(filepath, symbol_name) res = get_def(filepath, symbol_name)
if res == "NOT_FOUND": if res == "NOT_FOUND":
print(f"Error: Symbol '{symbol_name}' not found in {filepath}") print(f"Error: Symbol '{symbol_name}' not found in {filepath}")