From 173ea96fb4bc097a6e110e6b282dcdbe59ee8d7a Mon Sep 17 00:00:00 2001 From: Ed_ Date: Sat, 28 Feb 2026 19:36:38 -0500 Subject: [PATCH] refactor(indentation): Apply codebase-wide 1-space ultra-compact refactor. Formatted 21 core modules and tests. --- api_hooks.py | 6 +- gui_2.py | 19 ++ gui_legacy.py | 2 - mcp_client.py | 308 ++++++++++--------- project_manager.py | 34 ++- scripts/apply_type_hints.py | 35 +-- scripts/claude_mma_exec.py | 481 ++++++++++++++---------------- scripts/claude_tool_bridge.py | 135 ++++----- scripts/inject_tools.py | 25 +- scripts/mcp_server.py | 102 +++---- scripts/scan_all_hints.py | 3 +- session_logger.py | 7 + shell_runner.py | 2 +- tests/mock_alias_tool.py | 33 +- tests/test_api_events.py | 170 +++++------ tests/test_conductor_tech_lead.py | 186 ++++++------ tests/test_gemini_cli_adapter.py | 170 +++++------ tests/test_orchestrator_pm.py | 120 ++++---- tests/test_session_logging.py | 2 +- tests/test_sync_hooks.py | 5 +- tests/verify_mma_gui_robust.py | 2 +- 21 files changed, 917 insertions(+), 930 deletions(-) diff --git a/api_hooks.py b/api_hooks.py index 15a03a1..d5fd907 100644 --- a/api_hooks.py +++ b/api_hooks.py @@ -11,8 +11,8 @@ class HookServerInstance(ThreadingHTTPServer): """Custom HTTPServer that carries a reference to the main App instance.""" def __init__(self, server_address: tuple[str, int], RequestHandlerClass: type, app: Any) -> None: - super().__init__(server_address, RequestHandlerClass) - self.app = app + super().__init__(server_address, RequestHandlerClass) + self.app = app class HookHandler(BaseHTTPRequestHandler): """Handles incoming HTTP requests for the API hooks.""" @@ -276,7 +276,7 @@ class HookHandler(BaseHTTPRequestHandler): self.wfile.write(json.dumps({'error': str(e)}).encode('utf-8')) def log_message(self, format: str, *args: Any) -> None: - logging.info("Hook API: " + format % args) + logging.info("Hook API: " + format % args) class HookServer: def __init__(self, app: Any, port: int = 8999) -> None: diff --git a/gui_2.py b/gui_2.py index ee44a1f..96abdc1 100644 --- a/gui_2.py +++ b/gui_2.py @@ -102,6 +102,7 @@ class ConfirmDialog: self._condition = threading.Condition() self._done = False self._approved = False + def wait(self) -> tuple[bool, str]: with self._condition: while not self._done: @@ -115,6 +116,7 @@ class MMAApprovalDialog: self._condition = threading.Condition() self._done = False self._approved = False + def wait(self) -> tuple[bool, str]: with self._condition: while not self._done: @@ -131,6 +133,7 @@ class MMASpawnApprovalDialog: self._done = False self._approved = False self._abort = False + def wait(self) -> dict[str, Any]: with self._condition: while not self._done: @@ -293,6 +296,7 @@ class App: def _prune_old_logs(self) -> None: """Asynchronously prunes old insignificant logs on startup.""" + def run_prune() -> None: try: registry = LogRegistry("logs/log_registry.toml") @@ -306,6 +310,7 @@ class App: @property def current_provider(self) -> str: return self._current_provider + @current_provider.setter def current_provider(self, value: str) -> None: if value != self._current_provider: @@ -325,6 +330,7 @@ class App: @property def current_model(self) -> str: return self._current_model + @current_model.setter def current_model(self, value: str) -> None: if value != self._current_model: @@ -390,15 +396,18 @@ class App: def create_api(self) -> FastAPI: """Creates and configures the FastAPI application for headless mode.""" api = FastAPI(title="Manual Slop Headless API") + class GenerateRequest(BaseModel): prompt: str auto_add_history: bool = True temperature: float | None = None max_tokens: int | None = None + class ConfirmRequest(BaseModel): approved: bool API_KEY_NAME = "X-API-KEY" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + async def get_api_key(header_key: str = Depends(api_key_header)) -> str: """Validates the API key from the request header against configuration.""" headless_cfg = self.config.get("headless", {}) @@ -410,10 +419,12 @@ class App: if header_key == target_key: return header_key raise HTTPException(status_code=403, detail="Could not validate API Key") + @api.get("/health") def health() -> dict[str, str]: """Basic health check endpoint.""" return {"status": "ok"} + @api.get("/status", dependencies=[Depends(get_api_key)]) def status() -> dict[str, Any]: """Returns the current status of the AI provider and active project.""" @@ -424,6 +435,7 @@ class App: "ai_status": self.ai_status, "session_usage": self.session_usage } + @api.get("/api/v1/pending_actions", dependencies=[Depends(get_api_key)]) def pending_actions() -> list[dict[str, Any]]: """Lists all PowerShell scripts awaiting manual confirmation.""" @@ -442,6 +454,7 @@ class App: "base_dir": self._pending_dialog._base_dir }) return actions + @api.post("/api/v1/confirm/{action_id}", dependencies=[Depends(get_api_key)]) def confirm_action(action_id: str, req: ConfirmRequest) -> dict[str, Any]: """Approves or denies a pending PowerShell script execution.""" @@ -449,6 +462,7 @@ class App: if not success: raise HTTPException(status_code=404, detail=f"Action ID {action_id} not found") return {"status": "success", "action_id": action_id, "approved": req.approved} + @api.get("/api/v1/sessions", dependencies=[Depends(get_api_key)]) def list_sessions() -> list[str]: """Lists all available session log files.""" @@ -456,6 +470,7 @@ class App: if not log_dir.exists(): return [] return sorted([f.name for f in log_dir.glob("*.log")], reverse=True) + @api.get("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)]) def get_session(filename: str) -> dict[str, str]: """Retrieves the content of a specific session log file.""" @@ -469,6 +484,7 @@ class App: return {"filename": filename, "content": content} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @api.delete("/api/v1/sessions/{filename}", dependencies=[Depends(get_api_key)]) def delete_session(filename: str) -> dict[str, str]: """Deletes a specific session log file.""" @@ -482,6 +498,7 @@ class App: return {"status": "success", "message": f"Deleted {filename}"} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + @api.get("/api/v1/context", dependencies=[Depends(get_api_key)]) def get_context() -> dict[str, Any]: """Returns the current file and screenshot context configuration.""" @@ -491,6 +508,7 @@ class App: "files_base_dir": self.ui_files_base_dir, "screenshots_base_dir": self.ui_shots_base_dir } + @api.post("/api/v1/generate", dependencies=[Depends(get_api_key)]) def generate(req: GenerateRequest) -> dict[str, Any]: """Triggers an AI generation request using the current project context.""" @@ -547,6 +565,7 @@ class App: raise HTTPException(status_code=502, detail=f"AI Provider Error: {e.ui_message()}") except Exception as e: raise HTTPException(status_code=500, detail=f"In-flight AI request failure: {e}") + @api.post("/api/v1/stream", dependencies=[Depends(get_api_key)]) async def stream(req: GenerateRequest) -> Any: """Placeholder for streaming AI generation responses (Not yet implemented).""" diff --git a/gui_legacy.py b/gui_legacy.py index 79c1477..d52b88d 100644 --- a/gui_legacy.py +++ b/gui_legacy.py @@ -2400,5 +2400,3 @@ def main() -> None: if __name__ == "__main__": main() - - diff --git a/mcp_client.py b/mcp_client.py index 4f59605..47bc19a 100644 --- a/mcp_client.py +++ b/mcp_client.py @@ -520,175 +520,167 @@ def get_git_diff(path: str, base_rev: str = "HEAD", head_rev: str = "") -> str: return f"ERROR running git diff: {e.stderr}" except Exception as e: return f"ERROR: {e}" - + def py_find_usages(path: str, name: str) -> str: - """Finds exact string matches of a symbol in a given file or directory.""" - p, err = _resolve_and_check(path) - if err: return err - try: - import re - pattern = re.compile(r"\b" + re.escape(name) + r"\b") - results = [] - def _search_file(fp): - if fp.name == "history.toml" or fp.name.endswith("_history.toml"): return - if not _is_allowed(fp): return - try: - text = fp.read_text(encoding="utf-8") - lines = text.splitlines() - for i, line in enumerate(lines, 1): - if pattern.search(line): - rel = fp.relative_to(_primary_base_dir if _primary_base_dir else Path.cwd()) - results.append(f"{rel}:{i}: {line.strip()[:100]}") - except Exception: - pass - - if p.is_file(): - _search_file(p) - else: - for root, dirs, files in os.walk(p): - dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('__pycache__', 'venv', 'env')] - for file in files: - if file.endswith(('.py', '.md', '.toml', '.txt', '.json')): - _search_file(Path(root) / file) - - if not results: - return f"No usages found for '{name}' in {p}" - if len(results) > 100: - return "\n".join(results[:100]) + f"\n... (and {len(results)-100} more)" - return "\n".join(results) - except Exception as e: - return f"ERROR finding usages for '{name}': {e}" + """Finds exact string matches of a symbol in a given file or directory.""" + p, err = _resolve_and_check(path) + if err: return err + try: + import re + pattern = re.compile(r"\b" + re.escape(name) + r"\b") + results = [] + + def _search_file(fp): + if fp.name == "history.toml" or fp.name.endswith("_history.toml"): return + if not _is_allowed(fp): return + try: + text = fp.read_text(encoding="utf-8") + lines = text.splitlines() + for i, line in enumerate(lines, 1): + if pattern.search(line): + rel = fp.relative_to(_primary_base_dir if _primary_base_dir else Path.cwd()) + results.append(f"{rel}:{i}: {line.strip()[:100]}") + except Exception: + pass + if p.is_file(): + _search_file(p) + else: + for root, dirs, files in os.walk(p): + dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('__pycache__', 'venv', 'env')] + for file in files: + if file.endswith(('.py', '.md', '.toml', '.txt', '.json')): + _search_file(Path(root) / file) + if not results: + return f"No usages found for '{name}' in {p}" + if len(results) > 100: + return "\n".join(results[:100]) + f"\n... (and {len(results)-100} more)" + return "\n".join(results) + except Exception as e: + return f"ERROR finding usages for '{name}': {e}" def py_get_imports(path: str) -> str: - """Parses a file's AST and returns a strict list of its dependencies.""" - p, err = _resolve_and_check(path) - if err: return err - if not p.is_file() or p.suffix != ".py": return f"ERROR: not a python file: {path}" - try: - import ast - code = p.read_text(encoding="utf-8") - tree = ast.parse(code) - imports = [] - for node in tree.body: - if isinstance(node, ast.Import): - for alias in node.names: - imports.append(alias.name) - elif isinstance(node, ast.ImportFrom): - module = node.module or "" - for alias in node.names: - imports.append(f"{module}.{alias.name}" if module else alias.name) - if not imports: return "No imports found." - return "Imports:\n" + "\n".join(f" - {i}" for i in imports) - except Exception as e: - return f"ERROR getting imports for '{path}': {e}" + """Parses a file's AST and returns a strict list of its dependencies.""" + p, err = _resolve_and_check(path) + if err: return err + if not p.is_file() or p.suffix != ".py": return f"ERROR: not a python file: {path}" + try: + import ast + code = p.read_text(encoding="utf-8") + tree = ast.parse(code) + imports = [] + for node in tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + imports.append(alias.name) + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + for alias in node.names: + imports.append(f"{module}.{alias.name}" if module else alias.name) + if not imports: return "No imports found." + return "Imports:\n" + "\n".join(f" - {i}" for i in imports) + except Exception as e: + return f"ERROR getting imports for '{path}': {e}" def py_check_syntax(path: str) -> str: - """Runs a quick syntax check on a Python file.""" - p, err = _resolve_and_check(path) - if err: return err - if not p.is_file() or p.suffix != ".py": return f"ERROR: not a python file: {path}" - try: - import ast - code = p.read_text(encoding="utf-8") - ast.parse(code) - return f"Syntax OK: {path}" - except SyntaxError as e: - return f"SyntaxError in {path} at line {e.lineno}, offset {e.offset}: {e.msg}\n{e.text}" - except Exception as e: - return f"ERROR checking syntax for '{path}': {e}" + """Runs a quick syntax check on a Python file.""" + p, err = _resolve_and_check(path) + if err: return err + if not p.is_file() or p.suffix != ".py": return f"ERROR: not a python file: {path}" + try: + import ast + code = p.read_text(encoding="utf-8") + ast.parse(code) + return f"Syntax OK: {path}" + except SyntaxError as e: + return f"SyntaxError in {path} at line {e.lineno}, offset {e.offset}: {e.msg}\n{e.text}" + except Exception as e: + return f"ERROR checking syntax for '{path}': {e}" def py_get_hierarchy(path: str, class_name: str) -> str: - """Scans the project to find subclasses of a given class.""" - p, err = _resolve_and_check(path) - if err: return err - import ast - subclasses = [] - - def _search_file(fp): - if not _is_allowed(fp): return - try: - code = fp.read_text(encoding="utf-8") - tree = ast.parse(code) - for node in ast.walk(tree): - if isinstance(node, ast.ClassDef): - for base in node.bases: - if isinstance(base, ast.Name) and base.id == class_name: - subclasses.append(f"{fp.name}: class {node.name}({class_name})") - elif isinstance(base, ast.Attribute) and base.attr == class_name: - subclasses.append(f"{fp.name}: class {node.name}({base.value.id}.{class_name})") - except Exception: - pass - - try: - if p.is_file(): - _search_file(p) - else: - for root, dirs, files in os.walk(p): - dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('__pycache__', 'venv', 'env')] - for file in files: - if file.endswith('.py'): - _search_file(Path(root) / file) - - if not subclasses: - return f"No subclasses of '{class_name}' found in {p}" - return f"Subclasses of '{class_name}':\n" + "\n".join(f" - {s}" for s in subclasses) - except Exception as e: - return f"ERROR finding subclasses of '{class_name}': {e}" + """Scans the project to find subclasses of a given class.""" + p, err = _resolve_and_check(path) + if err: return err + import ast + subclasses = [] + + def _search_file(fp): + if not _is_allowed(fp): return + try: + code = fp.read_text(encoding="utf-8") + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + for base in node.bases: + if isinstance(base, ast.Name) and base.id == class_name: + subclasses.append(f"{fp.name}: class {node.name}({class_name})") + elif isinstance(base, ast.Attribute) and base.attr == class_name: + subclasses.append(f"{fp.name}: class {node.name}({base.value.id}.{class_name})") + except Exception: + pass + try: + if p.is_file(): + _search_file(p) + else: + for root, dirs, files in os.walk(p): + dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('__pycache__', 'venv', 'env')] + for file in files: + if file.endswith('.py'): + _search_file(Path(root) / file) + if not subclasses: + return f"No subclasses of '{class_name}' found in {p}" + return f"Subclasses of '{class_name}':\n" + "\n".join(f" - {s}" for s in subclasses) + except Exception as e: + return f"ERROR finding subclasses of '{class_name}': {e}" def py_get_docstring(path: str, name: str) -> str: - """Extracts the docstring for a specific module, class, or function.""" - p, err = _resolve_and_check(path) - if err: return err - if not p.is_file() or p.suffix != ".py": return f"ERROR: not a python file: {path}" - try: - import ast - code = p.read_text(encoding="utf-8") - tree = ast.parse(code) - if not name or name == "module": - doc = ast.get_docstring(tree) - return doc if doc else "No module docstring found." - - node = _get_symbol_node(tree, name) - if not node: return f"ERROR: could not find symbol '{name}' in {path}" - doc = ast.get_docstring(node) - return doc if doc else f"No docstring found for '{name}'." - except Exception as e: - return f"ERROR getting docstring for '{name}': {e}" + """Extracts the docstring for a specific module, class, or function.""" + p, err = _resolve_and_check(path) + if err: return err + if not p.is_file() or p.suffix != ".py": return f"ERROR: not a python file: {path}" + try: + import ast + code = p.read_text(encoding="utf-8") + tree = ast.parse(code) + if not name or name == "module": + doc = ast.get_docstring(tree) + return doc if doc else "No module docstring found." + node = _get_symbol_node(tree, name) + if not node: return f"ERROR: could not find symbol '{name}' in {path}" + doc = ast.get_docstring(node) + return doc if doc else f"No docstring found for '{name}'." + except Exception as e: + return f"ERROR getting docstring for '{name}': {e}" def get_tree(path: str, max_depth: int = 2) -> str: - """Returns a directory structure up to a max depth.""" - p, err = _resolve_and_check(path) - if err: return err - if not p.is_dir(): return f"ERROR: not a directory: {path}" - - try: - max_depth = int(max_depth) - def _build_tree(dir_path, current_depth, prefix=""): - if current_depth > max_depth: return [] - lines = [] - try: - entries = sorted(dir_path.iterdir(), key=lambda e: (e.is_file(), e.name.lower())) - except PermissionError: - return [] - - # Filter - entries = [e for e in entries if not e.name.startswith('.') and e.name not in ('__pycache__', 'venv', 'env') and e.name != "history.toml" and not e.name.endswith("_history.toml")] - - for i, entry in enumerate(entries): - is_last = (i == len(entries) - 1) - connector = "└── " if is_last else "├── " - lines.append(f"{prefix}{connector}{entry.name}") - if entry.is_dir(): - extension = " " if is_last else "│ " - lines.extend(_build_tree(entry, current_depth + 1, prefix + extension)) - return lines + """Returns a directory structure up to a max depth.""" + p, err = _resolve_and_check(path) + if err: return err + if not p.is_dir(): return f"ERROR: not a directory: {path}" + try: + max_depth = int(max_depth) - tree_lines = [f"{p.name}/"] + _build_tree(p, 1) - return "\n".join(tree_lines) - except Exception as e: - return f"ERROR generating tree for '{path}': {e}" - -# ------------------------------------------------------------------ web tools + def _build_tree(dir_path, current_depth, prefix=""): + if current_depth > max_depth: return [] + lines = [] + try: + entries = sorted(dir_path.iterdir(), key=lambda e: (e.is_file(), e.name.lower())) + except PermissionError: + return [] + # Filter + entries = [e for e in entries if not e.name.startswith('.') and e.name not in ('__pycache__', 'venv', 'env') and e.name != "history.toml" and not e.name.endswith("_history.toml")] + for i, entry in enumerate(entries): + is_last = (i == len(entries) - 1) + connector = "└── " if is_last else "├── " + lines.append(f"{prefix}{connector}{entry.name}") + if entry.is_dir(): + extension = " " if is_last else "│ " + lines.extend(_build_tree(entry, current_depth + 1, prefix + extension)) + return lines + tree_lines = [f"{p.name}/"] + _build_tree(p, 1) + return "\n".join(tree_lines) + except Exception as e: + return f"ERROR generating tree for '{path}': {e}" + # ------------------------------------------------------------------ web tools class _DDGParser(HTMLParser): def __init__(self) -> None: @@ -1306,4 +1298,4 @@ MCP_TOOL_SPECS: list[dict[str, Any]] = [ "required": ["path"] } } -] \ No newline at end of file +] diff --git a/project_manager.py b/project_manager.py index d7d58ae..d6b844b 100644 --- a/project_manager.py +++ b/project_manager.py @@ -16,14 +16,17 @@ from pathlib import Path if TYPE_CHECKING: from models import TrackState TS_FMT: str = "%Y-%m-%dT%H:%M:%S" + def now_ts() -> str: return datetime.datetime.now().strftime(TS_FMT) + def parse_ts(s: str) -> Optional[datetime.datetime]: try: return datetime.datetime.strptime(s, TS_FMT) except Exception: return None -# ── entry serialisation ────────────────────────────────────────────────────── + # ── entry serialisation ────────────────────────────────────────────────────── + def entry_to_str(entry: dict[str, Any]) -> str: """Serialise a disc entry dict -> stored string.""" ts = entry.get("ts", "") @@ -32,6 +35,7 @@ def entry_to_str(entry: dict[str, Any]) -> str: if ts: return f"@{ts}\n{role}:\n{content}" return f"{role}:\n{content}" + def str_to_entry(raw: str, roles: list[str]) -> dict[str, Any]: """Parse a stored string back to a disc entry dict.""" ts = "" @@ -56,7 +60,8 @@ def str_to_entry(raw: str, roles: list[str]) -> dict[str, Any]: matched_role = next((r for r in known if r.lower() == raw_role.lower()), raw_role) content = parts[1].strip() if len(parts) > 1 else "" return {"role": matched_role, "content": content, "collapsed": False, "ts": ts} -# ── git helpers ────────────────────────────────────────────────────────────── + # ── git helpers ────────────────────────────────────────────────────────────── + def get_git_commit(git_dir: str) -> str: try: r = subprocess.run( @@ -66,6 +71,7 @@ def get_git_commit(git_dir: str) -> str: return r.stdout.strip() if r.returncode == 0 else "" except Exception: return "" + def get_git_log(git_dir: str, n: int = 5) -> str: try: r = subprocess.run( @@ -75,9 +81,11 @@ def get_git_log(git_dir: str, n: int = 5) -> str: return r.stdout.strip() if r.returncode == 0 else "" except Exception: return "" -# ── default structures ─────────────────────────────────────────────────────── + # ── default structures ─────────────────────────────────────────────────────── + def default_discussion() -> dict[str, Any]: return {"git_commit": "", "last_updated": now_ts(), "history": []} + def default_project(name: str = "unnamed") -> dict[str, Any]: return { "project": {"name": name, "git_dir": "", "system_prompt": "", "main_context": ""}, @@ -108,11 +116,13 @@ def default_project(name: str = "unnamed") -> dict[str, Any]: "tracks": [] } } -# ── load / save ────────────────────────────────────────────────────────────── + # ── load / save ────────────────────────────────────────────────────────────── + def get_history_path(project_path: Union[str, Path]) -> Path: """Return the Path to the sibling history TOML file for a given project.""" p = Path(project_path) return p.parent / f"{p.stem}_history.toml" + def load_project(path: Union[str, Path]) -> dict[str, Any]: """ Load a project TOML file. @@ -131,6 +141,7 @@ def load_project(path: Union[str, Path]) -> dict[str, Any]: if hist_path.exists(): proj["discussion"] = load_history(path) return proj + def load_history(project_path: Union[str, Path]) -> dict[str, Any]: """Load the segregated discussion history from its dedicated TOML file.""" hist_path = get_history_path(project_path) @@ -138,6 +149,7 @@ def load_history(project_path: Union[str, Path]) -> dict[str, Any]: with open(hist_path, "rb") as f: return tomllib.load(f) return {} + def clean_nones(data: Any) -> Any: """Recursively remove None values from a dictionary/list.""" if isinstance(data, dict): @@ -145,6 +157,7 @@ def clean_nones(data: Any) -> Any: elif isinstance(data, list): return [clean_nones(v) for v in data if v is not None] return data + def save_project(proj: dict[str, Any], path: Union[str, Path], disc_data: Optional[dict[str, Any]] = None) -> None: """ Save the project TOML. @@ -163,7 +176,8 @@ def save_project(proj: dict[str, Any], path: Union[str, Path], disc_data: Option hist_path = get_history_path(path) with open(hist_path, "wb") as f: tomli_w.dump(disc_data, f) -# ── migration helper ───────────────────────────────────────────────────────── + # ── migration helper ───────────────────────────────────────────────────────── + def migrate_from_legacy_config(cfg: dict[str, Any]) -> dict[str, Any]: """Build a fresh project dict from a legacy flat config.toml. Does NOT save.""" name = cfg.get("output", {}).get("namespace", "project") @@ -177,7 +191,8 @@ def migrate_from_legacy_config(cfg: dict[str, Any]) -> dict[str, Any]: main_disc["history"] = disc.get("history", []) main_disc["last_updated"] = now_ts() return proj -# ── flat config for aggregate.run() ───────────────────────────────────────── + # ── flat config for aggregate.run() ───────────────────────────────────────── + def flat_config(proj: dict[str, Any], disc_name: Optional[str] = None, track_id: Optional[str] = None) -> dict[str, Any]: """Return a flat config dict compatible with aggregate.run().""" disc_sec = proj.get("discussion", {}) @@ -197,7 +212,8 @@ def flat_config(proj: dict[str, Any], disc_name: Optional[str] = None, track_id: "history": history, }, } -# ── track state persistence ───────────────────────────────────────────────── + # ── track state persistence ───────────────────────────────────────────────── + def save_track_state(track_id: str, state: 'TrackState', base_dir: Union[str, Path] = ".") -> None: """ Saves a TrackState object to conductor/tracks//state.toml. @@ -208,6 +224,7 @@ def save_track_state(track_id: str, state: 'TrackState', base_dir: Union[str, Pa data = clean_nones(state.to_dict()) with open(state_file, "wb") as f: tomli_w.dump(data, f) + def load_track_state(track_id: str, base_dir: Union[str, Path] = ".") -> Optional['TrackState']: """ Loads a TrackState object from conductor/tracks//state.toml. @@ -219,6 +236,7 @@ def load_track_state(track_id: str, base_dir: Union[str, Path] = ".") -> Optiona with open(state_file, "rb") as f: data = tomllib.load(f) return TrackState.from_dict(data) + def load_track_history(track_id: str, base_dir: Union[str, Path] = ".") -> list[str]: """ Loads the discussion history for a specific track from its state.toml. @@ -236,6 +254,7 @@ def load_track_history(track_id: str, base_dir: Union[str, Path] = ".") -> list[ e["ts"] = ts.strftime(TS_FMT) history.append(entry_to_str(e)) return history + def save_track_history(track_id: str, history: list[str], base_dir: Union[str, Path] = ".") -> None: """ Saves the discussion history for a specific track to its state.toml. @@ -249,6 +268,7 @@ def save_track_history(track_id: str, history: list[str], base_dir: Union[str, P entries = [str_to_entry(h, roles) for h in history] state.discussion = entries save_track_state(track_id, state, base_dir) + def get_all_tracks(base_dir: Union[str, Path] = ".") -> list[dict[str, Any]]: """ Scans the conductor/tracks/ directory and returns a list of dictionaries diff --git a/scripts/apply_type_hints.py b/scripts/apply_type_hints.py index 9ff1bf2..7b9e8c9 100644 --- a/scripts/apply_type_hints.py +++ b/scripts/apply_type_hints.py @@ -29,6 +29,7 @@ def has_value_return(node: ast.AST) -> bool: def collect_auto_none(tree: ast.Module) -> list[tuple[str, ast.AST]]: """Collect functions that can safely get -> None annotation.""" results = [] + def scan(scope, prefix=""): for node in ast.iter_child_nodes(scope): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): @@ -61,9 +62,9 @@ def apply_return_none_single_pass(filepath: str) -> int: for name, node in candidates: if not node.body: continue - # The colon is on the last line of the signature - # For single-line defs: `def foo(self):` -> colon at end - # For multi-line defs: last line ends with `):` or similar + # The colon is on the last line of the signature + # For single-line defs: `def foo(self):` -> colon at end + # For multi-line defs: last line ends with `):` or similar body_start = node.body[0].lineno # 1-indexed sig_last_line_idx = body_start - 2 # 0-indexed, the line before body # But for single-line signatures, sig_last_line_idx == node.lineno - 1 @@ -96,11 +97,11 @@ def apply_return_none_single_pass(filepath: str) -> int: if colon_idx < 0: stats["errors"].append(f"no colon found: {filepath}:{name} L{sig_last_line_idx+1}") continue - # Check not already annotated + # Check not already annotated if '->' in code_part: continue edits.append((sig_last_line_idx, colon_idx)) - # Apply edits in reverse order to preserve line indices + # Apply edits in reverse order to preserve line indices edits.sort(key=lambda x: x[0], reverse=True) count = 0 for line_idx, colon_col in edits: @@ -111,11 +112,10 @@ def apply_return_none_single_pass(filepath: str) -> int: with open(fp, 'w', encoding='utf-8', newline='') as f: f.writelines(lines) return count - -# --- Manual signature replacements --- -# These use regex on the def line to do a targeted replacement. -# Each entry: (dotted_name, old_params_pattern, new_full_sig_line) -# We match by finding the exact def line and replacing it. + # --- Manual signature replacements --- + # These use regex on the def line to do a targeted replacement. + # Each entry: (dotted_name, old_params_pattern, new_full_sig_line) + # We match by finding the exact def line and replacing it. def apply_manual_sigs(filepath: str, sig_replacements: list[tuple[str, str]]) -> int: """Apply manual signature replacements. @@ -164,10 +164,9 @@ def verify_syntax(filepath: str) -> str: return f"Syntax OK: {filepath}" except SyntaxError as e: return f"SyntaxError in {filepath} at line {e.lineno}: {e.msg}" - -# ============================================================ -# gui_2.py manual signatures (Tier 3 items) -# ============================================================ + # ============================================================ + # gui_2.py manual signatures (Tier 3 items) + # ============================================================ GUI2_MANUAL_SIGS: list[tuple[str, str]] = [ (r'def resolve_pending_action\(self, action_id: str, approved: bool\):', r'def resolve_pending_action(self, action_id: str, approved: bool) -> bool:'), @@ -281,7 +280,6 @@ if __name__ == "__main__": n = apply_return_none_single_pass("gui_legacy.py") stats["auto_none"] += n print(f" gui_legacy.py: {n} applied") - # Verify syntax after Phase A for f in ["gui_2.py", "gui_legacy.py"]: r = verify_syntax(f) @@ -289,7 +287,6 @@ if __name__ == "__main__": print(f" ABORT: {r}") sys.exit(1) print(" Syntax OK after Phase A") - print("\n=== Phase B: Manual signatures (regex) ===") n = apply_manual_sigs("gui_2.py", GUI2_MANUAL_SIGS) stats["manual_sig"] += n @@ -297,7 +294,6 @@ if __name__ == "__main__": n = apply_manual_sigs("gui_legacy.py", LEGACY_MANUAL_SIGS) stats["manual_sig"] += n print(f" gui_legacy.py: {n} applied") - # Verify syntax after Phase B for f in ["gui_2.py", "gui_legacy.py"]: r = verify_syntax(f) @@ -305,9 +301,9 @@ if __name__ == "__main__": print(f" ABORT: {r}") sys.exit(1) print(" Syntax OK after Phase B") - print("\n=== Phase C: Variable annotations (regex) ===") # Use re.MULTILINE so ^ matches line starts + def apply_var_replacements_m(filepath, replacements): fp = abs_path(filepath) with open(fp, 'r', encoding='utf-8') as f: @@ -323,14 +319,12 @@ if __name__ == "__main__": with open(fp, 'w', encoding='utf-8', newline='') as f: f.write(code) return count - n = apply_var_replacements_m("gui_2.py", GUI2_VAR_REPLACEMENTS) stats["vars"] += n print(f" gui_2.py: {n} applied") n = apply_var_replacements_m("gui_legacy.py", LEGACY_VAR_REPLACEMENTS) stats["vars"] += n print(f" gui_legacy.py: {n} applied") - print("\n=== Final Syntax Verification ===") all_ok = True for f in ["gui_2.py", "gui_legacy.py"]: @@ -338,7 +332,6 @@ if __name__ == "__main__": print(f" {f}: {r}") if "Error" in r: all_ok = False - print(f"\n=== Summary ===") print(f" Auto -> None: {stats['auto_none']}") print(f" Manual sigs: {stats['manual_sig']}") diff --git a/scripts/claude_mma_exec.py b/scripts/claude_mma_exec.py index 028bd96..18fe8cf 100644 --- a/scripts/claude_mma_exec.py +++ b/scripts/claude_mma_exec.py @@ -11,279 +11,256 @@ import tree_sitter_python LOG_FILE: str = 'logs/claude_mma_delegation.log' MODEL_MAP: dict[str, str] = { - 'tier1-orchestrator': 'claude-opus-4-6', - 'tier1': 'claude-opus-4-6', - 'tier2-tech-lead': 'claude-sonnet-4-6', - 'tier2': 'claude-sonnet-4-6', - 'tier3-worker': 'claude-sonnet-4-6', - 'tier3': 'claude-sonnet-4-6', - 'tier4-qa': 'claude-haiku-4-5', - 'tier4': 'claude-haiku-4-5', + 'tier1-orchestrator': 'claude-opus-4-6', + 'tier1': 'claude-opus-4-6', + 'tier2-tech-lead': 'claude-sonnet-4-6', + 'tier2': 'claude-sonnet-4-6', + 'tier3-worker': 'claude-sonnet-4-6', + 'tier3': 'claude-sonnet-4-6', + 'tier4-qa': 'claude-haiku-4-5', + 'tier4': 'claude-haiku-4-5', } - def generate_skeleton(code: str) -> str: - """ + """ Parses Python code and replaces function/method bodies with '...', preserving docstrings if present. """ - try: - PY_LANGUAGE = tree_sitter.Language(tree_sitter_python.language()) - parser = tree_sitter.Parser(PY_LANGUAGE) - tree = parser.parse(bytes(code, "utf8")) - edits = [] + try: + PY_LANGUAGE = tree_sitter.Language(tree_sitter_python.language()) + parser = tree_sitter.Parser(PY_LANGUAGE) + tree = parser.parse(bytes(code, "utf8")) + edits = [] - def is_docstring(node): - if node.type == "expression_statement" and node.child_count > 0: - if node.children[0].type == "string": - return True - return False - - def walk(node): - if node.type == "function_definition": - body = node.child_by_field_name("body") - if body and body.type == "block": - indent = " " * body.start_point.column - first_stmt = None - for child in body.children: - if child.type != "comment": - first_stmt = child - break - if first_stmt and is_docstring(first_stmt): - start_byte = first_stmt.end_byte - end_byte = body.end_byte - if end_byte > start_byte: - edits.append((start_byte, end_byte, f"\n{indent}...")) - else: - start_byte = body.start_byte - end_byte = body.end_byte - edits.append((start_byte, end_byte, "...")) - for child in node.children: - walk(child) - - walk(tree.root_node) - edits.sort(key=lambda x: x[0], reverse=True) - code_bytes = bytearray(code, "utf8") - for start, end, replacement in edits: - code_bytes[start:end] = bytes(replacement, "utf8") - return code_bytes.decode("utf8") - except Exception as e: - return f"# Error generating skeleton: {e}\n{code}" + def is_docstring(node): + if node.type == "expression_statement" and node.child_count > 0: + if node.children[0].type == "string": + return True + return False + def walk(node): + if node.type == "function_definition": + body = node.child_by_field_name("body") + if body and body.type == "block": + indent = " " * body.start_point.column + first_stmt = None + for child in body.children: + if child.type != "comment": + first_stmt = child + break + if first_stmt and is_docstring(first_stmt): + start_byte = first_stmt.end_byte + end_byte = body.end_byte + if end_byte > start_byte: + edits.append((start_byte, end_byte, f"\n{indent}...")) + else: + start_byte = body.start_byte + end_byte = body.end_byte + edits.append((start_byte, end_byte, "...")) + for child in node.children: + walk(child) + walk(tree.root_node) + edits.sort(key=lambda x: x[0], reverse=True) + code_bytes = bytearray(code, "utf8") + for start, end, replacement in edits: + code_bytes[start:end] = bytes(replacement, "utf8") + return code_bytes.decode("utf8") + except Exception as e: + return f"# Error generating skeleton: {e}\n{code}" def get_model_for_role(role: str) -> str: - """Returns the Claude model to use for a given tier role.""" - return MODEL_MAP.get(role, 'claude-haiku-4-5') - + """Returns the Claude model to use for a given tier role.""" + return MODEL_MAP.get(role, 'claude-haiku-4-5') def get_role_documents(role: str) -> list[str]: - if role in ('tier1-orchestrator', 'tier1'): - return ['conductor/product.md', 'conductor/product-guidelines.md'] - elif role in ('tier2-tech-lead', 'tier2'): - return ['conductor/tech-stack.md', 'conductor/workflow.md'] - elif role in ('tier3-worker', 'tier3'): - return ['conductor/workflow.md'] - return [] - + if role in ('tier1-orchestrator', 'tier1'): + return ['conductor/product.md', 'conductor/product-guidelines.md'] + elif role in ('tier2-tech-lead', 'tier2'): + return ['conductor/tech-stack.md', 'conductor/workflow.md'] + elif role in ('tier3-worker', 'tier3'): + return ['conductor/workflow.md'] + return [] def log_delegation(role: str, full_prompt: str, result: str | None = None, summary_prompt: str | None = None) -> str: - os.makedirs('logs/claude_agents', exist_ok=True) - timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - log_file = f'logs/claude_agents/claude_{role}_task_{timestamp}.log' - with open(log_file, 'w', encoding='utf-8') as f: - f.write("==================================================\n") - f.write(f"ROLE: {role}\n") - f.write(f"TIMESTAMP: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - f.write("--------------------------------------------------\n") - f.write(f"FULL PROMPT:\n{full_prompt}\n") - f.write("--------------------------------------------------\n") - if result: - f.write(f"RESULT:\n{result}\n") - f.write("==================================================\n") - os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) - display_prompt = summary_prompt if summary_prompt else full_prompt - with open(LOG_FILE, 'a', encoding='utf-8') as f: - f.write(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {role}: {display_prompt[:100]}... (Log: {log_file})\n") - return log_file - + os.makedirs('logs/claude_agents', exist_ok=True) + timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + log_file = f'logs/claude_agents/claude_{role}_task_{timestamp}.log' + with open(log_file, 'w', encoding='utf-8') as f: + f.write("==================================================\n") + f.write(f"ROLE: {role}\n") + f.write(f"TIMESTAMP: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write("--------------------------------------------------\n") + f.write(f"FULL PROMPT:\n{full_prompt}\n") + f.write("--------------------------------------------------\n") + if result: + f.write(f"RESULT:\n{result}\n") + f.write("==================================================\n") + os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True) + display_prompt = summary_prompt if summary_prompt else full_prompt + with open(LOG_FILE, 'a', encoding='utf-8') as f: + f.write(f"[{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] {role}: {display_prompt[:100]}... (Log: {log_file})\n") + return log_file def get_dependencies(filepath: str) -> list[str]: - """Identify top-level module imports from a Python file.""" - try: - with open(filepath, 'r', encoding='utf-8') as f: - tree = ast.parse(f.read()) - dependencies = [] - for node in tree.body: - if isinstance(node, ast.Import): - for alias in node.names: - dependencies.append(alias.name.split('.')[0]) - elif isinstance(node, ast.ImportFrom): - if node.module: - dependencies.append(node.module.split('.')[0]) - seen = set() - result = [] - for d in dependencies: - if d not in seen: - result.append(d) - seen.add(d) - return result - except Exception as e: - print(f"Error getting dependencies for {filepath}: {e}") - return [] - + """Identify top-level module imports from a Python file.""" + try: + with open(filepath, 'r', encoding='utf-8') as f: + tree = ast.parse(f.read()) + dependencies = [] + for node in tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + dependencies.append(alias.name.split('.')[0]) + elif isinstance(node, ast.ImportFrom): + if node.module: + dependencies.append(node.module.split('.')[0]) + seen = set() + result = [] + for d in dependencies: + if d not in seen: + result.append(d) + seen.add(d) + return result + except Exception as e: + print(f"Error getting dependencies for {filepath}: {e}") + return [] def execute_agent(role: str, prompt: str, docs: list[str]) -> str: - model = get_model_for_role(role) - - # Advanced Context: Dependency skeletons for Tier 3 - injected_context = "" - UNFETTERED_MODULES: list[str] = ['mcp_client', 'project_manager', 'events', 'aggregate'] - - if role in ['tier3', 'tier3-worker']: - for doc in docs: - if doc.endswith('.py') and os.path.exists(doc): - deps = get_dependencies(doc) - for dep in deps: - dep_file = f"{dep}.py" - if dep_file in docs: - continue - if os.path.exists(dep_file) and dep_file != doc: - try: - if dep in UNFETTERED_MODULES: - with open(dep_file, 'r', encoding='utf-8') as f: - full_content = f.read() - injected_context += f"\n\nFULL MODULE CONTEXT: {dep_file}\n{full_content}\n" - else: - with open(dep_file, 'r', encoding='utf-8') as f: - skeleton = generate_skeleton(f.read()) - injected_context += f"\n\nDEPENDENCY SKELETON: {dep_file}\n{skeleton}\n" - except Exception as e: - print(f"Error gathering context for {dep_file}: {e}") - if len(injected_context) > 15000: - injected_context = injected_context[:15000] + "... [TRUNCATED FOR COMMAND LINE LIMITS]" - - # MMA Protocol: Tier 3 and 4 are stateless. Build system directive. - if role in ['tier3', 'tier3-worker']: - system_directive = ( - "STRICT SYSTEM DIRECTIVE: You are a stateless Tier 3 Worker (Contributor). " - "Your goal is to implement specific code changes or tests based on the provided task. " - "You have access to tools for reading and writing files (Read, Write, Edit), " - "codebase investigation (Glob, Grep), " - "version control (Bash git commands), and web tools (WebFetch, WebSearch). " - "You CAN execute PowerShell scripts via Bash for verification and testing. " - "Follow TDD and return success status or code changes. No pleasantries, no conversational filler." - ) - elif role in ['tier4', 'tier4-qa']: - system_directive = ( - "STRICT SYSTEM DIRECTIVE: You are a stateless Tier 4 QA Agent. " - "Your goal is to analyze errors, summarize logs, or verify tests. " - "You have access to tools for reading files and exploring the codebase (Read, Glob, Grep). " - "You CAN execute PowerShell scripts via Bash (read-only) for diagnostics. " - "ONLY output the requested analysis. No pleasantries." - ) - else: - system_directive = ( - f"STRICT SYSTEM DIRECTIVE: You are a stateless {role}. " - "ONLY output the requested text. No pleasantries." - ) - - command_text = f"{system_directive}\n\n{injected_context}\n\n" - - # Inline documents to ensure sub-agent has context in headless mode - for doc in docs: - if os.path.exists(doc): - try: - with open(doc, 'r', encoding='utf-8') as f: - content = f.read() - command_text += f"\n\nFILE CONTENT: {doc}\n{content}\n" - except Exception as e: - print(f"Error inlining {doc}: {e}") - - command_text += f"\n\nTASK: {prompt}\n\n" - - # Spawn claude CLI non-interactively via PowerShell - ps_command = ( - "if (Test-Path 'C:\\projects\\misc\\setup_claude.ps1') " - "{ . 'C:\\projects\\misc\\setup_claude.ps1' }; " - f"claude --model {model} --print" - ) - cmd = ['powershell.exe', '-NoProfile', '-Command', ps_command] - - try: - env = os.environ.copy() - env['CLAUDE_CLI_HOOK_CONTEXT'] = 'mma_headless' - process = subprocess.run( - cmd, - input=command_text, - capture_output=True, - text=True, - encoding='utf-8', - env=env - ) - # claude --print outputs plain text — no JSON parsing needed - result = process.stdout if process.stdout else f"Error: {process.stderr}" - log_file = log_delegation(role, command_text, result, summary_prompt=prompt) - print(f"Sub-agent log created: {log_file}") - return result - except Exception as e: - err_msg = f"Execution failed: {str(e)}" - log_delegation(role, command_text, err_msg) - return err_msg - + model = get_model_for_role(role) + # Advanced Context: Dependency skeletons for Tier 3 + injected_context = "" + UNFETTERED_MODULES: list[str] = ['mcp_client', 'project_manager', 'events', 'aggregate'] + if role in ['tier3', 'tier3-worker']: + for doc in docs: + if doc.endswith('.py') and os.path.exists(doc): + deps = get_dependencies(doc) + for dep in deps: + dep_file = f"{dep}.py" + if dep_file in docs: + continue + if os.path.exists(dep_file) and dep_file != doc: + try: + if dep in UNFETTERED_MODULES: + with open(dep_file, 'r', encoding='utf-8') as f: + full_content = f.read() + injected_context += f"\n\nFULL MODULE CONTEXT: {dep_file}\n{full_content}\n" + else: + with open(dep_file, 'r', encoding='utf-8') as f: + skeleton = generate_skeleton(f.read()) + injected_context += f"\n\nDEPENDENCY SKELETON: {dep_file}\n{skeleton}\n" + except Exception as e: + print(f"Error gathering context for {dep_file}: {e}") + if len(injected_context) > 15000: + injected_context = injected_context[:15000] + "... [TRUNCATED FOR COMMAND LINE LIMITS]" + # MMA Protocol: Tier 3 and 4 are stateless. Build system directive. + if role in ['tier3', 'tier3-worker']: + system_directive = ( + "STRICT SYSTEM DIRECTIVE: You are a stateless Tier 3 Worker (Contributor). " + "Your goal is to implement specific code changes or tests based on the provided task. " + "You have access to tools for reading and writing files (Read, Write, Edit), " + "codebase investigation (Glob, Grep), " + "version control (Bash git commands), and web tools (WebFetch, WebSearch). " + "You CAN execute PowerShell scripts via Bash for verification and testing. " + "Follow TDD and return success status or code changes. No pleasantries, no conversational filler." + ) + elif role in ['tier4', 'tier4-qa']: + system_directive = ( + "STRICT SYSTEM DIRECTIVE: You are a stateless Tier 4 QA Agent. " + "Your goal is to analyze errors, summarize logs, or verify tests. " + "You have access to tools for reading files and exploring the codebase (Read, Glob, Grep). " + "You CAN execute PowerShell scripts via Bash (read-only) for diagnostics. " + "ONLY output the requested analysis. No pleasantries." + ) + else: + system_directive = ( + f"STRICT SYSTEM DIRECTIVE: You are a stateless {role}. " + "ONLY output the requested text. No pleasantries." + ) + command_text = f"{system_directive}\n\n{injected_context}\n\n" + # Inline documents to ensure sub-agent has context in headless mode + for doc in docs: + if os.path.exists(doc): + try: + with open(doc, 'r', encoding='utf-8') as f: + content = f.read() + command_text += f"\n\nFILE CONTENT: {doc}\n{content}\n" + except Exception as e: + print(f"Error inlining {doc}: {e}") + command_text += f"\n\nTASK: {prompt}\n\n" + # Spawn claude CLI non-interactively via PowerShell + ps_command = ( + "if (Test-Path 'C:\\projects\\misc\\setup_claude.ps1') " + "{ . 'C:\\projects\\misc\\setup_claude.ps1' }; " + f"claude --model {model} --print" + ) + cmd = ['powershell.exe', '-NoProfile', '-Command', ps_command] + try: + env = os.environ.copy() + env['CLAUDE_CLI_HOOK_CONTEXT'] = 'mma_headless' + process = subprocess.run( + cmd, + input=command_text, + capture_output=True, + text=True, + encoding='utf-8', + env=env + ) + # claude --print outputs plain text — no JSON parsing needed + result = process.stdout if process.stdout else f"Error: {process.stderr}" + log_file = log_delegation(role, command_text, result, summary_prompt=prompt) + print(f"Sub-agent log created: {log_file}") + return result + except Exception as e: + err_msg = f"Execution failed: {str(e)}" + log_delegation(role, command_text, err_msg) + return err_msg def create_parser() -> argparse.ArgumentParser: - parser = argparse.ArgumentParser(description="Claude MMA Execution Script") - parser.add_argument( - "--role", - choices=['tier1', 'tier2', 'tier3', 'tier4', - 'tier1-orchestrator', 'tier2-tech-lead', 'tier3-worker', 'tier4-qa'], - help="The tier role to execute" - ) - parser.add_argument( - "--task-file", - type=str, - help="TOML file defining the task" - ) - parser.add_argument( - "prompt", - type=str, - nargs='?', - help="The prompt for the tier (optional if --task-file is used)" - ) - return parser - + parser = argparse.ArgumentParser(description="Claude MMA Execution Script") + parser.add_argument( + "--role", + choices=['tier1', 'tier2', 'tier3', 'tier4', + 'tier1-orchestrator', 'tier2-tech-lead', 'tier3-worker', 'tier4-qa'], + help="The tier role to execute" + ) + parser.add_argument( + "--task-file", + type=str, + help="TOML file defining the task" + ) + parser.add_argument( + "prompt", + type=str, + nargs='?', + help="The prompt for the tier (optional if --task-file is used)" + ) + return parser def main() -> None: - parser = create_parser() - args = parser.parse_args() - role = args.role - prompt = args.prompt - docs = [] - - if args.task_file and os.path.exists(args.task_file): - with open(args.task_file, "rb") as f: - task_data = tomllib.load(f) - role = task_data.get("role", role) - prompt = task_data.get("prompt", prompt) - docs = task_data.get("docs", []) - - if not role or not prompt: - parser.print_help() - return - - if not docs: - docs = get_role_documents(role) - - # Extract @file references from the prompt - file_refs: list[str] = re.findall(r"@([\w./\\]+)", prompt) - for ref in file_refs: - if os.path.exists(ref) and ref not in docs: - docs.append(ref) - - print(f"Executing role: {role} with docs: {docs}") - result = execute_agent(role, prompt, docs) - print(result) - + parser = create_parser() + args = parser.parse_args() + role = args.role + prompt = args.prompt + docs = [] + if args.task_file and os.path.exists(args.task_file): + with open(args.task_file, "rb") as f: + task_data = tomllib.load(f) + role = task_data.get("role", role) + prompt = task_data.get("prompt", prompt) + docs = task_data.get("docs", []) + if not role or not prompt: + parser.print_help() + return + if not docs: + docs = get_role_documents(role) + # Extract @file references from the prompt + file_refs: list[str] = re.findall(r"@([\w./\\]+)", prompt) + for ref in file_refs: + if os.path.exists(ref) and ref not in docs: + docs.append(ref) + print(f"Executing role: {role} with docs: {docs}") + result = execute_agent(role, prompt, docs) + print(result) if __name__ == "__main__": - main() + main() diff --git a/scripts/claude_tool_bridge.py b/scripts/claude_tool_bridge.py index 95f189e..ae70b89 100644 --- a/scripts/claude_tool_bridge.py +++ b/scripts/claude_tool_bridge.py @@ -6,83 +6,72 @@ import os # Add project root to sys.path so we can import api_hook_client project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) if project_root not in sys.path: - sys.path.append(project_root) + sys.path.append(project_root) try: - from api_hook_client import ApiHookClient + from api_hook_client import ApiHookClient except ImportError: - print("FATAL: Failed to import ApiHookClient. Ensure it's in the Python path.", file=sys.stderr) - sys.exit(1) - + print("FATAL: Failed to import ApiHookClient. Ensure it's in the Python path.", file=sys.stderr) + sys.exit(1) def main() -> None: - logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', stream=sys.stderr) - logging.debug("Claude Tool Bridge script started.") - try: - input_data = sys.stdin.read() - if not input_data: - logging.debug("No input received from stdin. Exiting gracefully.") - return - logging.debug(f"Received raw input data: {input_data}") - try: - hook_input = json.loads(input_data) - except json.JSONDecodeError: - logging.error("Failed to decode JSON from stdin.") - print(json.dumps({"decision": "deny", "reason": "Invalid JSON received from stdin."})) - return - - # Claude Code PreToolUse hook format: tool_name + tool_input - tool_name = hook_input.get('tool_name') - tool_input = hook_input.get('tool_input', {}) - - if tool_name is None: - logging.error("Could not determine tool name from input. Expected 'tool_name'.") - print(json.dumps({"decision": "deny", "reason": "Missing 'tool_name' in hook input."})) - return - - if not isinstance(tool_input, dict): - logging.warning(f"tool_input is not a dict: {tool_input}. Treating as empty.") - tool_input = {} - - logging.debug(f"Resolved tool_name: '{tool_name}', tool_input: {tool_input}") - - # Check context — if not running via Manual Slop, pass through - hook_context = os.environ.get("CLAUDE_CLI_HOOK_CONTEXT") - logging.debug(f"Checking CLAUDE_CLI_HOOK_CONTEXT: '{hook_context}'") - - if hook_context == 'mma_headless': - # Sub-agents in headless MMA mode: auto-allow all tools - logging.debug("CLAUDE_CLI_HOOK_CONTEXT is 'mma_headless'. Allowing for sub-agent.") - print(json.dumps({"decision": "allow", "reason": "Sub-agent headless mode (MMA)."})) - return - - if hook_context != 'manual_slop': - # Not a programmatic Manual Slop session — allow through silently - logging.debug(f"CLAUDE_CLI_HOOK_CONTEXT is '{hook_context}', not 'manual_slop'. Allowing.") - print(json.dumps({"decision": "allow", "reason": f"Non-programmatic usage (CLAUDE_CLI_HOOK_CONTEXT={hook_context})."})) - return - - # manual_slop context: route to GUI for approval - logging.debug("CLAUDE_CLI_HOOK_CONTEXT is 'manual_slop'. Routing to API Hook Client.") - client = ApiHookClient(base_url="http://127.0.0.1:8999") - try: - logging.debug(f"Requesting confirmation for tool '{tool_name}' with args: {tool_input}") - response = client.request_confirmation(tool_name, tool_input) - if response and response.get('approved') is True: - logging.debug("User approved tool execution.") - print(json.dumps({"decision": "allow"})) - else: - reason = response.get('reason', 'User rejected tool execution in GUI.') if response else 'No response from GUI.' - logging.debug(f"User denied tool execution. Reason: {reason}") - print(json.dumps({"decision": "deny", "reason": reason})) - except Exception as e: - logging.error(f"API Hook Client error: {str(e)}", exc_info=True) - print(json.dumps({"decision": "deny", "reason": f"Manual Slop hook server unreachable: {str(e)}"})) - - except Exception as e: - logging.error(f"Unexpected error in bridge: {str(e)}", exc_info=True) - print(json.dumps({"decision": "deny", "reason": f"Internal bridge error: {str(e)}"})) - + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s', stream=sys.stderr) + logging.debug("Claude Tool Bridge script started.") + try: + input_data = sys.stdin.read() + if not input_data: + logging.debug("No input received from stdin. Exiting gracefully.") + return + logging.debug(f"Received raw input data: {input_data}") + try: + hook_input = json.loads(input_data) + except json.JSONDecodeError: + logging.error("Failed to decode JSON from stdin.") + print(json.dumps({"decision": "deny", "reason": "Invalid JSON received from stdin."})) + return + # Claude Code PreToolUse hook format: tool_name + tool_input + tool_name = hook_input.get('tool_name') + tool_input = hook_input.get('tool_input', {}) + if tool_name is None: + logging.error("Could not determine tool name from input. Expected 'tool_name'.") + print(json.dumps({"decision": "deny", "reason": "Missing 'tool_name' in hook input."})) + return + if not isinstance(tool_input, dict): + logging.warning(f"tool_input is not a dict: {tool_input}. Treating as empty.") + tool_input = {} + logging.debug(f"Resolved tool_name: '{tool_name}', tool_input: {tool_input}") + # Check context — if not running via Manual Slop, pass through + hook_context = os.environ.get("CLAUDE_CLI_HOOK_CONTEXT") + logging.debug(f"Checking CLAUDE_CLI_HOOK_CONTEXT: '{hook_context}'") + if hook_context == 'mma_headless': + # Sub-agents in headless MMA mode: auto-allow all tools + logging.debug("CLAUDE_CLI_HOOK_CONTEXT is 'mma_headless'. Allowing for sub-agent.") + print(json.dumps({"decision": "allow", "reason": "Sub-agent headless mode (MMA)."})) + return + if hook_context != 'manual_slop': + # Not a programmatic Manual Slop session — allow through silently + logging.debug(f"CLAUDE_CLI_HOOK_CONTEXT is '{hook_context}', not 'manual_slop'. Allowing.") + print(json.dumps({"decision": "allow", "reason": f"Non-programmatic usage (CLAUDE_CLI_HOOK_CONTEXT={hook_context})."})) + return + # manual_slop context: route to GUI for approval + logging.debug("CLAUDE_CLI_HOOK_CONTEXT is 'manual_slop'. Routing to API Hook Client.") + client = ApiHookClient(base_url="http://127.0.0.1:8999") + try: + logging.debug(f"Requesting confirmation for tool '{tool_name}' with args: {tool_input}") + response = client.request_confirmation(tool_name, tool_input) + if response and response.get('approved') is True: + logging.debug("User approved tool execution.") + print(json.dumps({"decision": "allow"})) + else: + reason = response.get('reason', 'User rejected tool execution in GUI.') if response else 'No response from GUI.' + logging.debug(f"User denied tool execution. Reason: {reason}") + print(json.dumps({"decision": "deny", "reason": reason})) + except Exception as e: + logging.error(f"API Hook Client error: {str(e)}", exc_info=True) + print(json.dumps({"decision": "deny", "reason": f"Manual Slop hook server unreachable: {str(e)}"})) + except Exception as e: + logging.error(f"Unexpected error in bridge: {str(e)}", exc_info=True) + print(json.dumps({"decision": "deny", "reason": f"Internal bridge error: {str(e)}"})) if __name__ == "__main__": - main() + main() diff --git a/scripts/inject_tools.py b/scripts/inject_tools.py index b769541..dddd28d 100644 --- a/scripts/inject_tools.py +++ b/scripts/inject_tools.py @@ -2,13 +2,11 @@ import os import re with open('mcp_client.py', 'r', encoding='utf-8') as f: - content: str = f.read() - -# 1. Add import os if not there + content: str = f.read() + # 1. Add import os if not there if 'import os' not in content: - content: str = content.replace('import summarize', 'import os\nimport summarize') - -# 2. Add the functions before "# ------------------------------------------------------------------ web tools" + content: str = content.replace('import summarize', 'import os\nimport summarize') + # 2. Add the functions before "# ------------------------------------------------------------------ web tools" functions_code: str = r''' def py_find_usages(path: str, name: str) -> str: """Finds exact string matches of a symbol in a given file or directory.""" @@ -184,11 +182,10 @@ content: str = content.replace('# ---------------------------------------------- # 3. Update TOOL_NAMES old_tool_names_match: re.Match | None = re.search(r'TOOL_NAMES\s*=\s*\{([^}]*)\}', content) if old_tool_names_match: - old_names: str = old_tool_names_match.group(1) - new_names: str = old_names + ', "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"' - content: str = content.replace(old_tool_names_match.group(0), f'TOOL_NAMES = {{{new_names}}}') - -# 4. Update dispatch + old_names: str = old_tool_names_match.group(1) + new_names: str = old_names + ', "py_find_usages", "py_get_imports", "py_check_syntax", "py_get_hierarchy", "py_get_docstring", "get_tree"' + content: str = content.replace(old_tool_names_match.group(0), f'TOOL_NAMES = {{{new_names}}}') + # 4. Update dispatch dispatch_additions: str = r''' if tool_name == "py_find_usages": return py_find_usages(tool_input.get("path", ""), tool_input.get("name", "")) @@ -205,7 +202,7 @@ dispatch_additions: str = r''' return f"ERROR: unknown MCP tool '{tool_name}'" ''' content: str = re.sub( -r' return f"ERROR: unknown MCP tool \'{tool_name}\'"', dispatch_additions.strip(), content) + r' return f"ERROR: unknown MCP tool \'{tool_name}\'"', dispatch_additions.strip(), content) # 5. Update MCP_TOOL_SPECS mcp_tool_specs_addition: str = r''' @@ -283,9 +280,9 @@ mcp_tool_specs_addition: str = r''' ''' content: str = re.sub( -r'\]\s*$', mcp_tool_specs_addition.strip(), content) + r'\]\s*$', mcp_tool_specs_addition.strip(), content) with open('mcp_client.py', 'w', encoding='utf-8') as f: - f.write(content) + f.write(content) print("Injected new tools.") diff --git a/scripts/mcp_server.py b/scripts/mcp_server.py index 674ad42..906db76 100644 --- a/scripts/mcp_server.py +++ b/scripts/mcp_server.py @@ -26,69 +26,65 @@ from mcp.types import Tool, TextContent # run_powershell is handled by shell_runner, not mcp_client.dispatch() # Define its spec here since it's not in MCP_TOOL_SPECS RUN_POWERSHELL_SPEC = { - "name": "run_powershell", - "description": ( - "Run a PowerShell script within the project base directory. " - "Returns combined stdout, stderr, and exit code. " - "60-second timeout. Use for builds, tests, and system commands." - ), - "parameters": { - "type": "object", - "properties": { - "script": { - "type": "string", - "description": "PowerShell script content to execute." - } - }, - "required": ["script"] - } + "name": "run_powershell", + "description": ( + "Run a PowerShell script within the project base directory. " + "Returns combined stdout, stderr, and exit code. " + "60-second timeout. Use for builds, tests, and system commands." + ), + "parameters": { + "type": "object", + "properties": { + "script": { + "type": "string", + "description": "PowerShell script content to execute." + } + }, + "required": ["script"] + } } server = Server("manual-slop-tools") - @server.list_tools() async def list_tools() -> list[Tool]: - tools = [] - for spec in mcp_client.MCP_TOOL_SPECS: - tools.append(Tool( - name=spec["name"], - description=spec["description"], - inputSchema=spec["parameters"], - )) - # Add run_powershell - tools.append(Tool( - name=RUN_POWERSHELL_SPEC["name"], - description=RUN_POWERSHELL_SPEC["description"], - inputSchema=RUN_POWERSHELL_SPEC["parameters"], - )) - return tools - + tools = [] + for spec in mcp_client.MCP_TOOL_SPECS: + tools.append(Tool( + name=spec["name"], + description=spec["description"], + inputSchema=spec["parameters"], + )) + # Add run_powershell + tools.append(Tool( + name=RUN_POWERSHELL_SPEC["name"], + description=RUN_POWERSHELL_SPEC["description"], + inputSchema=RUN_POWERSHELL_SPEC["parameters"], + )) + return tools @server.call_tool() async def call_tool(name: str, arguments: dict) -> list[TextContent]: - try: - if name == "run_powershell": - script = arguments.get("script", "") - result = shell_runner.run_powershell(script, os.getcwd()) - else: - result = mcp_client.dispatch(name, arguments) - return [TextContent(type="text", text=str(result))] - except Exception as e: - return [TextContent(type="text", text=f"ERROR: {e}")] - + try: + if name == "run_powershell": + script = arguments.get("script", "") + result = shell_runner.run_powershell(script, os.getcwd()) + else: + result = mcp_client.dispatch(name, arguments) + return [TextContent(type="text", text=str(result))] + except Exception as e: + return [TextContent(type="text", text=f"ERROR: {e}")] async def main() -> None: - # Configure mcp_client with the project root so py_* tools are not ACCESS DENIED - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - mcp_client.configure([], extra_base_dirs=[project_root]) - async with stdio_server() as (read_stream, write_stream): - await server.run( - read_stream, - write_stream, - server.create_initialization_options(), - ) - +# Configure mcp_client with the project root so py_* tools are not ACCESS DENIED + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + mcp_client.configure([], extra_base_dirs=[project_root]) + async with stdio_server() as (read_stream, write_stream): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main()) diff --git a/scripts/scan_all_hints.py b/scripts/scan_all_hints.py index 9e045e5..d2a54af 100644 --- a/scripts/scan_all_hints.py +++ b/scripts/scan_all_hints.py @@ -18,8 +18,9 @@ for root, dirs, files in os.walk('.'): except Exception: continue counts: list[int] = [0, 0, 0] # nr, up, uv + def scan(scope: ast.AST, prefix: str = '') -> None: - # Iterate top-level nodes in this scope + # Iterate top-level nodes in this scope for node in ast.iter_child_nodes(scope): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): if node.returns is None: diff --git a/session_logger.py b/session_logger.py index 1ee7ede..02bfab2 100644 --- a/session_logger.py +++ b/session_logger.py @@ -30,8 +30,10 @@ _comms_fh: Optional[TextIO] = None # file handle: logs//comms.log _tool_fh: Optional[TextIO] = None # file handle: logs//toolcalls.log _api_fh: Optional[TextIO] = None # file handle: logs//apihooks.log _cli_fh: Optional[TextIO] = None # file handle: logs//clicalls.log + def _now_ts() -> str: return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + def open_session(label: Optional[str] = None) -> None: """ Called once at GUI startup. Creates the log directories if needed and @@ -64,6 +66,7 @@ def open_session(label: Optional[str] = None) -> None: except Exception as e: print(f"Warning: Could not register session in LogRegistry: {e}") atexit.register(close_session) + def close_session() -> None: """Flush and close all log files. Called on clean exit.""" global _comms_fh, _tool_fh, _api_fh, _cli_fh, _session_id, _LOG_DIR @@ -87,6 +90,7 @@ def close_session() -> None: registry.update_auto_whitelist_status(_session_id) except Exception as e: print(f"Warning: Could not update auto-whitelist on close: {e}") + def log_api_hook(method: str, path: str, payload: str) -> None: """Log an API hook invocation.""" if _api_fh is None: @@ -97,6 +101,7 @@ def log_api_hook(method: str, path: str, payload: str) -> None: _api_fh.flush() except Exception: pass + def log_comms(entry: dict[str, Any]) -> None: """ Append one comms entry to the comms log file as a JSON-L line. @@ -108,6 +113,7 @@ def log_comms(entry: dict[str, Any]) -> None: _comms_fh.write(json.dumps(entry, ensure_ascii=False, default=str) + "\n") except Exception: pass + def log_tool_call(script: str, result: str, script_path: Optional[str]) -> Optional[str]: """ Append a tool-call record to the toolcalls log and write the PS1 script to @@ -139,6 +145,7 @@ def log_tool_call(script: str, result: str, script_path: Optional[str]) -> Optio except Exception: pass return str(ps1_path) if ps1_path else None + def log_cli_call(command: str, stdin_content: Optional[str], stdout_content: Optional[str], stderr_content: Optional[str], latency: float) -> None: """Log details of a CLI subprocess execution.""" if _cli_fh is None: diff --git a/shell_runner.py b/shell_runner.py index b6c5a63..2acd743 100644 --- a/shell_runner.py +++ b/shell_runner.py @@ -33,7 +33,7 @@ def _build_subprocess_env() -> dict[str, str]: prepend_dirs = _ENV_CONFIG.get("path", {}).get("prepend", []) if prepend_dirs: env["PATH"] = os.pathsep.join(prepend_dirs) + os.pathsep + env.get("PATH", "") - # Apply [env] key-value pairs, expanding ${VAR} references + # Apply [env] key-value pairs, expanding ${VAR} references for key, val in _ENV_CONFIG.get("env", {}).items(): env[key] = os.path.expandvars(str(val)) return env diff --git a/tests/mock_alias_tool.py b/tests/mock_alias_tool.py index ce94548..29262fa 100644 --- a/tests/mock_alias_tool.py +++ b/tests/mock_alias_tool.py @@ -1,21 +1,20 @@ import sys, json, os, subprocess prompt = sys.stdin.read() if '"role": "tool"' in prompt: - print(json.dumps({"type": "message", "role": "assistant", "content": "Tool worked!"}), flush=True) - print(json.dumps({"type": "result", "stats": {"total_tokens": 20}}), flush=True) + print(json.dumps({"type": "message", "role": "assistant", "content": "Tool worked!"}), flush=True) + print(json.dumps({"type": "result", "stats": {"total_tokens": 20}}), flush=True) else: - # We must call the bridge to trigger the GUI approval! - tool_call = {"name": "list_directory", "input": {"dir_path": "."}} - bridge_cmd = [sys.executable, "C:/projects/manual_slop/scripts/cli_tool_bridge.py"] - proc = subprocess.Popen(bridge_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True) - stdout, _ = proc.communicate(input=json.dumps(tool_call)) - - # Even if bridge says allow, we emit the tool_use to the adapter - print(json.dumps({"type": "message", "role": "assistant", "content": "I will list the directory."}), flush=True) - print(json.dumps({ - "type": "tool_use", - "name": "list_directory", - "id": "alias_call", - "args": {"dir_path": "."} - }), flush=True) - print(json.dumps({"type": "result", "stats": {"total_tokens": 10}}), flush=True) +# We must call the bridge to trigger the GUI approval! + tool_call = {"name": "list_directory", "input": {"dir_path": "."}} + bridge_cmd = [sys.executable, "C:/projects/manual_slop/scripts/cli_tool_bridge.py"] + proc = subprocess.Popen(bridge_cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True) + stdout, _ = proc.communicate(input=json.dumps(tool_call)) + # Even if bridge says allow, we emit the tool_use to the adapter + print(json.dumps({"type": "message", "role": "assistant", "content": "I will list the directory."}), flush=True) + print(json.dumps({ + "type": "tool_use", + "name": "list_directory", + "id": "alias_call", + "args": {"dir_path": "."} + }), flush=True) + print(json.dumps({"type": "result", "stats": {"total_tokens": 10}}), flush=True) diff --git a/tests/test_api_events.py b/tests/test_api_events.py index ce54913..dd0bd2f 100644 --- a/tests/test_api_events.py +++ b/tests/test_api_events.py @@ -4,104 +4,104 @@ from unittest.mock import MagicMock, patch import ai_client class MockUsage: - def __init__(self) -> None: - self.prompt_token_count = 10 - self.candidates_token_count = 5 - self.total_token_count = 15 - self.cached_content_token_count = 0 + def __init__(self) -> None: + self.prompt_token_count = 10 + self.candidates_token_count = 5 + self.total_token_count = 15 + self.cached_content_token_count = 0 class MockPart: - def __init__(self, text: Any, function_call: Any) -> None: - self.text = text - self.function_call = function_call + def __init__(self, text: Any, function_call: Any) -> None: + self.text = text + self.function_call = function_call class MockContent: - def __init__(self, parts: Any) -> None: - self.parts = parts + def __init__(self, parts: Any) -> None: + self.parts = parts class MockCandidate: - def __init__(self, parts: Any) -> None: - self.content = MockContent(parts) - self.finish_reason = MagicMock() - self.finish_reason.name = "STOP" + def __init__(self, parts: Any) -> None: + self.content = MockContent(parts) + self.finish_reason = MagicMock() + self.finish_reason.name = "STOP" def test_ai_client_event_emitter_exists() -> None: - # This should fail initially because 'events' won't exist on ai_client - assert hasattr(ai_client, 'events') +# This should fail initially because 'events' won't exist on ai_client + assert hasattr(ai_client, 'events') def test_event_emission() -> None: - callback = MagicMock() - ai_client.events.on("test_event", callback) - ai_client.events.emit("test_event", payload={"data": 123}) - callback.assert_called_once_with(payload={"data": 123}) + callback = MagicMock() + ai_client.events.on("test_event", callback) + ai_client.events.emit("test_event", payload={"data": 123}) + callback.assert_called_once_with(payload={"data": 123}) def test_send_emits_events() -> None: - with patch("ai_client._send_gemini") as mock_send_gemini, \ - patch("ai_client._send_anthropic") as mock_send_anthropic: - mock_send_gemini.return_value = "gemini response" - start_callback = MagicMock() - response_callback = MagicMock() - ai_client.events.on("request_start", start_callback) - ai_client.events.on("response_received", response_callback) - ai_client.set_provider("gemini", "gemini-2.5-flash-lite") - ai_client.send("context", "message") - # We mocked _send_gemini so it doesn't emit events inside. - # But wait, ai_client.send itself emits request_start and response_received? - # Actually, ai_client.send delegates to _send_gemini. - # Let's mock _gemini_client instead to let _send_gemini run and emit events. - pass + with patch("ai_client._send_gemini") as mock_send_gemini, \ + patch("ai_client._send_anthropic") as mock_send_anthropic: + mock_send_gemini.return_value = "gemini response" + start_callback = MagicMock() + response_callback = MagicMock() + ai_client.events.on("request_start", start_callback) + ai_client.events.on("response_received", response_callback) + ai_client.set_provider("gemini", "gemini-2.5-flash-lite") + ai_client.send("context", "message") + # We mocked _send_gemini so it doesn't emit events inside. + # But wait, ai_client.send itself emits request_start and response_received? + # Actually, ai_client.send delegates to _send_gemini. + # Let's mock _gemini_client instead to let _send_gemini run and emit events. + pass def test_send_emits_events_proper() -> None: - with patch("ai_client._ensure_gemini_client"), \ - patch("ai_client._gemini_client") as mock_client: - mock_chat = MagicMock() - mock_client.chats.create.return_value = mock_chat - mock_response = MagicMock() - mock_response.candidates = [MockCandidate([MockPart("gemini response", None)])] - mock_response.usage_metadata = MockUsage() - mock_chat.send_message.return_value = mock_response - start_callback = MagicMock() - response_callback = MagicMock() - ai_client.events.on("request_start", start_callback) - ai_client.events.on("response_received", response_callback) - ai_client.set_provider("gemini", "gemini-2.5-flash-lite") - ai_client.send("context", "message") - assert start_callback.called - assert response_callback.called - args, kwargs = start_callback.call_args - assert kwargs['payload']['provider'] == 'gemini' + with patch("ai_client._ensure_gemini_client"), \ + patch("ai_client._gemini_client") as mock_client: + mock_chat = MagicMock() + mock_client.chats.create.return_value = mock_chat + mock_response = MagicMock() + mock_response.candidates = [MockCandidate([MockPart("gemini response", None)])] + mock_response.usage_metadata = MockUsage() + mock_chat.send_message.return_value = mock_response + start_callback = MagicMock() + response_callback = MagicMock() + ai_client.events.on("request_start", start_callback) + ai_client.events.on("response_received", response_callback) + ai_client.set_provider("gemini", "gemini-2.5-flash-lite") + ai_client.send("context", "message") + assert start_callback.called + assert response_callback.called + args, kwargs = start_callback.call_args + assert kwargs['payload']['provider'] == 'gemini' def test_send_emits_tool_events() -> None: - import mcp_client - with patch("ai_client._ensure_gemini_client"), \ - patch("ai_client._gemini_client") as mock_client, \ - patch("mcp_client.dispatch") as mock_dispatch: - mock_chat = MagicMock() - mock_client.chats.create.return_value = mock_chat - # 1. Setup mock response with a tool call - mock_fc = MagicMock() - mock_fc.name = "read_file" - mock_fc.args = {"path": "test.txt"} - mock_response_with_tool = MagicMock() - mock_response_with_tool.candidates = [MockCandidate([MockPart("tool call text", mock_fc)])] - mock_response_with_tool.usage_metadata = MockUsage() - # 2. Setup second mock response (final answer) - mock_response_final = MagicMock() - mock_response_final.candidates = [MockCandidate([MockPart("final answer", None)])] - mock_response_final.usage_metadata = MockUsage() - mock_chat.send_message.side_effect = [mock_response_with_tool, mock_response_final] - mock_dispatch.return_value = "file content" - ai_client.set_provider("gemini", "gemini-2.5-flash-lite") - tool_callback = MagicMock() - ai_client.events.on("tool_execution", tool_callback) - ai_client.send("context", "message") - # Should be called twice: once for 'started', once for 'completed' - assert tool_callback.call_count == 2 - # Check 'started' call - args, kwargs = tool_callback.call_args_list[0] - assert kwargs['payload']['status'] == 'started' - assert kwargs['payload']['tool'] == 'read_file' - # Check 'completed' call - args, kwargs = tool_callback.call_args_list[1] - assert kwargs['payload']['status'] == 'completed' - assert kwargs['payload']['result'] == 'file content' + import mcp_client + with patch("ai_client._ensure_gemini_client"), \ + patch("ai_client._gemini_client") as mock_client, \ + patch("mcp_client.dispatch") as mock_dispatch: + mock_chat = MagicMock() + mock_client.chats.create.return_value = mock_chat + # 1. Setup mock response with a tool call + mock_fc = MagicMock() + mock_fc.name = "read_file" + mock_fc.args = {"path": "test.txt"} + mock_response_with_tool = MagicMock() + mock_response_with_tool.candidates = [MockCandidate([MockPart("tool call text", mock_fc)])] + mock_response_with_tool.usage_metadata = MockUsage() + # 2. Setup second mock response (final answer) + mock_response_final = MagicMock() + mock_response_final.candidates = [MockCandidate([MockPart("final answer", None)])] + mock_response_final.usage_metadata = MockUsage() + mock_chat.send_message.side_effect = [mock_response_with_tool, mock_response_final] + mock_dispatch.return_value = "file content" + ai_client.set_provider("gemini", "gemini-2.5-flash-lite") + tool_callback = MagicMock() + ai_client.events.on("tool_execution", tool_callback) + ai_client.send("context", "message") + # Should be called twice: once for 'started', once for 'completed' + assert tool_callback.call_count == 2 + # Check 'started' call + args, kwargs = tool_callback.call_args_list[0] + assert kwargs['payload']['status'] == 'started' + assert kwargs['payload']['tool'] == 'read_file' + # Check 'completed' call + args, kwargs = tool_callback.call_args_list[1] + assert kwargs['payload']['status'] == 'completed' + assert kwargs['payload']['result'] == 'file content' diff --git a/tests/test_conductor_tech_lead.py b/tests/test_conductor_tech_lead.py index ba6c0c5..73a33d8 100644 --- a/tests/test_conductor_tech_lead.py +++ b/tests/test_conductor_tech_lead.py @@ -5,105 +5,105 @@ import json import conductor_tech_lead class TestConductorTechLead(unittest.TestCase): - @patch('ai_client.send') - @patch('ai_client.set_provider') - @patch('ai_client.reset_session') - def test_generate_tickets_success(self, mock_reset_session: Any, mock_set_provider: Any, mock_send: Any) -> None: - mock_tickets = [ - { - "id": "ticket_1", - "type": "Ticket", - "goal": "Test goal", - "target_file": "test.py", - "depends_on": [], - "context_requirements": [] - } - ] - mock_send.return_value = "```json\n" + json.dumps(mock_tickets) + "\n```" - track_brief = "Test track brief" - module_skeletons = "Test skeletons" - # Call the function - tickets = conductor_tech_lead.generate_tickets(track_brief, module_skeletons) - # Verify set_provider was called - mock_set_provider.assert_called_with('gemini', 'gemini-2.5-flash-lite') - mock_reset_session.assert_called_once() - # Verify send was called - mock_send.assert_called_once() - args, kwargs = mock_send.call_args - self.assertEqual(kwargs['md_content'], "") - self.assertIn(track_brief, kwargs['user_message']) - self.assertIn(module_skeletons, kwargs['user_message']) - # Verify tickets were parsed correctly - self.assertEqual(tickets, mock_tickets) + @patch('ai_client.send') + @patch('ai_client.set_provider') + @patch('ai_client.reset_session') + def test_generate_tickets_success(self, mock_reset_session: Any, mock_set_provider: Any, mock_send: Any) -> None: + mock_tickets = [ + { + "id": "ticket_1", + "type": "Ticket", + "goal": "Test goal", + "target_file": "test.py", + "depends_on": [], + "context_requirements": [] + } + ] + mock_send.return_value = "```json\n" + json.dumps(mock_tickets) + "\n```" + track_brief = "Test track brief" + module_skeletons = "Test skeletons" + # Call the function + tickets = conductor_tech_lead.generate_tickets(track_brief, module_skeletons) + # Verify set_provider was called + mock_set_provider.assert_called_with('gemini', 'gemini-2.5-flash-lite') + mock_reset_session.assert_called_once() + # Verify send was called + mock_send.assert_called_once() + args, kwargs = mock_send.call_args + self.assertEqual(kwargs['md_content'], "") + self.assertIn(track_brief, kwargs['user_message']) + self.assertIn(module_skeletons, kwargs['user_message']) + # Verify tickets were parsed correctly + self.assertEqual(tickets, mock_tickets) - @patch('ai_client.send') - @patch('ai_client.set_provider') - @patch('ai_client.reset_session') - def test_generate_tickets_parse_error(self, mock_reset_session: Any, mock_set_provider: Any, mock_send: Any) -> None: - # Setup mock invalid response - mock_send.return_value = "Invalid JSON" - # Call the function - tickets = conductor_tech_lead.generate_tickets("brief", "skeletons") - # Verify it returns an empty list on parse error - self.assertEqual(tickets, []) + @patch('ai_client.send') + @patch('ai_client.set_provider') + @patch('ai_client.reset_session') + def test_generate_tickets_parse_error(self, mock_reset_session: Any, mock_set_provider: Any, mock_send: Any) -> None: + # Setup mock invalid response + mock_send.return_value = "Invalid JSON" + # Call the function + tickets = conductor_tech_lead.generate_tickets("brief", "skeletons") + # Verify it returns an empty list on parse error + self.assertEqual(tickets, []) class TestTopologicalSort(unittest.TestCase): - def test_topological_sort_empty(self) -> None: - tickets = [] - sorted_tickets = conductor_tech_lead.topological_sort(tickets) - self.assertEqual(sorted_tickets, []) + def test_topological_sort_empty(self) -> None: + tickets = [] + sorted_tickets = conductor_tech_lead.topological_sort(tickets) + self.assertEqual(sorted_tickets, []) - def test_topological_sort_linear(self) -> None: - tickets = [ - {"id": "t2", "depends_on": ["t1"]}, - {"id": "t1", "depends_on": []}, - {"id": "t3", "depends_on": ["t2"]}, - ] - sorted_tickets = conductor_tech_lead.topological_sort(tickets) - ids = [t["id"] for t in sorted_tickets] - self.assertEqual(ids, ["t1", "t2", "t3"]) + def test_topological_sort_linear(self) -> None: + tickets = [ + {"id": "t2", "depends_on": ["t1"]}, + {"id": "t1", "depends_on": []}, + {"id": "t3", "depends_on": ["t2"]}, + ] + sorted_tickets = conductor_tech_lead.topological_sort(tickets) + ids = [t["id"] for t in sorted_tickets] + self.assertEqual(ids, ["t1", "t2", "t3"]) - def test_topological_sort_complex(self) -> None: - # t1 - # | \ - # t2 t3 - # | / - # t4 - tickets = [ - {"id": "t4", "depends_on": ["t2", "t3"]}, - {"id": "t3", "depends_on": ["t1"]}, - {"id": "t2", "depends_on": ["t1"]}, - {"id": "t1", "depends_on": []}, - ] - sorted_tickets = conductor_tech_lead.topological_sort(tickets) - ids = [t["id"] for t in sorted_tickets] - # Possible valid orders: [t1, t2, t3, t4] or [t1, t3, t2, t4] - self.assertEqual(ids[0], "t1") - self.assertEqual(ids[-1], "t4") - self.assertSetEqual(set(ids[1:3]), {"t2", "t3"}) + def test_topological_sort_complex(self) -> None: + # t1 + # | \ + # t2 t3 + # | / + # t4 + tickets = [ + {"id": "t4", "depends_on": ["t2", "t3"]}, + {"id": "t3", "depends_on": ["t1"]}, + {"id": "t2", "depends_on": ["t1"]}, + {"id": "t1", "depends_on": []}, + ] + sorted_tickets = conductor_tech_lead.topological_sort(tickets) + ids = [t["id"] for t in sorted_tickets] + # Possible valid orders: [t1, t2, t3, t4] or [t1, t3, t2, t4] + self.assertEqual(ids[0], "t1") + self.assertEqual(ids[-1], "t4") + self.assertSetEqual(set(ids[1:3]), {"t2", "t3"}) - def test_topological_sort_cycle(self) -> None: - tickets = [ - {"id": "t1", "depends_on": ["t2"]}, - {"id": "t2", "depends_on": ["t1"]}, - ] - with self.assertRaises(ValueError) as cm: - conductor_tech_lead.topological_sort(tickets) - self.assertIn("Circular dependency detected", str(cm.exception)) + def test_topological_sort_cycle(self) -> None: + tickets = [ + {"id": "t1", "depends_on": ["t2"]}, + {"id": "t2", "depends_on": ["t1"]}, + ] + with self.assertRaises(ValueError) as cm: + conductor_tech_lead.topological_sort(tickets) + self.assertIn("Circular dependency detected", str(cm.exception)) - def test_topological_sort_missing_dependency(self) -> None: - # If a ticket depends on something not in the list, we should probably handle it or let it fail. - # Usually in our context, we only care about dependencies within the same track. - tickets = [ - {"id": "t1", "depends_on": ["missing"]}, - ] - # For now, let's assume it should raise an error if a dependency is missing within the set we are sorting, - # OR it should just treat it as "ready" if it's external? - # Actually, let's just test that it doesn't crash if it's not a cycle. - # But if 'missing' is not in tickets, it will never be satisfied. - # Let's say it raises ValueError for missing internal dependencies. - with self.assertRaises(ValueError): - conductor_tech_lead.topological_sort(tickets) + def test_topological_sort_missing_dependency(self) -> None: + # If a ticket depends on something not in the list, we should probably handle it or let it fail. + # Usually in our context, we only care about dependencies within the same track. + tickets = [ + {"id": "t1", "depends_on": ["missing"]}, + ] + # For now, let's assume it should raise an error if a dependency is missing within the set we are sorting, + # OR it should just treat it as "ready" if it's external? + # Actually, let's just test that it doesn't crash if it's not a cycle. + # But if 'missing' is not in tickets, it will never be satisfied. + # Let's say it raises ValueError for missing internal dependencies. + with self.assertRaises(ValueError): + conductor_tech_lead.topological_sort(tickets) if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/tests/test_gemini_cli_adapter.py b/tests/test_gemini_cli_adapter.py index 363bc4b..46ec0e8 100644 --- a/tests/test_gemini_cli_adapter.py +++ b/tests/test_gemini_cli_adapter.py @@ -13,105 +13,105 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from gemini_cli_adapter import GeminiCliAdapter class TestGeminiCliAdapter(unittest.TestCase): - def setUp(self) -> None: - self.adapter = GeminiCliAdapter(binary_path="gemini") + def setUp(self) -> None: + self.adapter = GeminiCliAdapter(binary_path="gemini") - @patch('subprocess.Popen') - def test_send_starts_subprocess_with_correct_args(self, mock_popen: Any) -> None: - """ + @patch('subprocess.Popen') + def test_send_starts_subprocess_with_correct_args(self, mock_popen: Any) -> None: + """ Verify that send(message) correctly starts the subprocess with --output-format stream-json and the provided message via stdin using communicate. """ - # Setup mock process with a minimal valid JSONL termination - process_mock = MagicMock() - stdout_content = json.dumps({"type": "result", "usage": {}}) + "\n" - process_mock.communicate.return_value = (stdout_content, "") - process_mock.poll.return_value = 0 - process_mock.wait.return_value = 0 - mock_popen.return_value = process_mock - message = "Hello Gemini CLI" - self.adapter.send(message) - # Verify subprocess.Popen call - mock_popen.assert_called_once() - args, kwargs = mock_popen.call_args - cmd = args[0] - # Check mandatory CLI components - self.assertIn("gemini", cmd) - self.assertIn("--output-format", cmd) - self.assertIn("stream-json", cmd) - # Message should NOT be in cmd now - self.assertNotIn(message, cmd) - # Verify message was sent via communicate - process_mock.communicate.assert_called_once_with(input=message) - # Check process configuration - self.assertEqual(kwargs.get('stdout'), subprocess.PIPE) - self.assertEqual(kwargs.get('stdin'), subprocess.PIPE) - self.assertEqual(kwargs.get('text'), True) + # Setup mock process with a minimal valid JSONL termination + process_mock = MagicMock() + stdout_content = json.dumps({"type": "result", "usage": {}}) + "\n" + process_mock.communicate.return_value = (stdout_content, "") + process_mock.poll.return_value = 0 + process_mock.wait.return_value = 0 + mock_popen.return_value = process_mock + message = "Hello Gemini CLI" + self.adapter.send(message) + # Verify subprocess.Popen call + mock_popen.assert_called_once() + args, kwargs = mock_popen.call_args + cmd = args[0] + # Check mandatory CLI components + self.assertIn("gemini", cmd) + self.assertIn("--output-format", cmd) + self.assertIn("stream-json", cmd) + # Message should NOT be in cmd now + self.assertNotIn(message, cmd) + # Verify message was sent via communicate + process_mock.communicate.assert_called_once_with(input=message) + # Check process configuration + self.assertEqual(kwargs.get('stdout'), subprocess.PIPE) + self.assertEqual(kwargs.get('stdin'), subprocess.PIPE) + self.assertEqual(kwargs.get('text'), True) - @patch('subprocess.Popen') - def test_send_parses_jsonl_output(self, mock_popen: Any) -> None: - """ + @patch('subprocess.Popen') + def test_send_parses_jsonl_output(self, mock_popen: Any) -> None: + """ Verify that it correctly parses multiple JSONL 'message' events and returns the combined text. """ - jsonl_output = [ - json.dumps({"type": "message", "role": "model", "text": "The quick brown "}), - json.dumps({"type": "message", "role": "model", "text": "fox jumps."}), - json.dumps({"type": "result", "usage": {"prompt_tokens": 5, "candidates_tokens": 5}}) - ] - stdout_content = "\n".join(jsonl_output) + "\n" - process_mock = MagicMock() - process_mock.communicate.return_value = (stdout_content, "") - process_mock.poll.return_value = 0 - process_mock.wait.return_value = 0 - mock_popen.return_value = process_mock - result = self.adapter.send("test message") - self.assertEqual(result["text"], "The quick brown fox jumps.") - self.assertEqual(result["tool_calls"], []) + jsonl_output = [ + json.dumps({"type": "message", "role": "model", "text": "The quick brown "}), + json.dumps({"type": "message", "role": "model", "text": "fox jumps."}), + json.dumps({"type": "result", "usage": {"prompt_tokens": 5, "candidates_tokens": 5}}) + ] + stdout_content = "\n".join(jsonl_output) + "\n" + process_mock = MagicMock() + process_mock.communicate.return_value = (stdout_content, "") + process_mock.poll.return_value = 0 + process_mock.wait.return_value = 0 + mock_popen.return_value = process_mock + result = self.adapter.send("test message") + self.assertEqual(result["text"], "The quick brown fox jumps.") + self.assertEqual(result["tool_calls"], []) - @patch('subprocess.Popen') - def test_send_handles_tool_use_events(self, mock_popen: Any) -> None: - """ + @patch('subprocess.Popen') + def test_send_handles_tool_use_events(self, mock_popen: Any) -> None: + """ Verify that it correctly handles 'tool_use' events in the stream by continuing to read until the final 'result' event. """ - jsonl_output = [ - json.dumps({"type": "message", "role": "assistant", "text": "Calling tool..."}), - json.dumps({"type": "tool_use", "name": "read_file", "args": {"path": "test.txt"}}), - json.dumps({"type": "message", "role": "assistant", "text": "\nFile read successfully."}), - json.dumps({"type": "result", "usage": {}}) - ] - stdout_content = "\n".join(jsonl_output) + "\n" - process_mock = MagicMock() - process_mock.communicate.return_value = (stdout_content, "") - process_mock.poll.return_value = 0 - process_mock.wait.return_value = 0 - mock_popen.return_value = process_mock - result = self.adapter.send("read test.txt") - # Result should contain the combined text from all 'message' events - self.assertEqual(result["text"], "Calling tool...\nFile read successfully.") - self.assertEqual(len(result["tool_calls"]), 1) - self.assertEqual(result["tool_calls"][0]["name"], "read_file") + jsonl_output = [ + json.dumps({"type": "message", "role": "assistant", "text": "Calling tool..."}), + json.dumps({"type": "tool_use", "name": "read_file", "args": {"path": "test.txt"}}), + json.dumps({"type": "message", "role": "assistant", "text": "\nFile read successfully."}), + json.dumps({"type": "result", "usage": {}}) + ] + stdout_content = "\n".join(jsonl_output) + "\n" + process_mock = MagicMock() + process_mock.communicate.return_value = (stdout_content, "") + process_mock.poll.return_value = 0 + process_mock.wait.return_value = 0 + mock_popen.return_value = process_mock + result = self.adapter.send("read test.txt") + # Result should contain the combined text from all 'message' events + self.assertEqual(result["text"], "Calling tool...\nFile read successfully.") + self.assertEqual(len(result["tool_calls"]), 1) + self.assertEqual(result["tool_calls"][0]["name"], "read_file") - @patch('subprocess.Popen') - def test_send_captures_usage_metadata(self, mock_popen: Any) -> None: - """ + @patch('subprocess.Popen') + def test_send_captures_usage_metadata(self, mock_popen: Any) -> None: + """ Verify that usage data is extracted from the 'result' event. """ - usage_data = {"total_tokens": 42} - jsonl_output = [ - json.dumps({"type": "message", "text": "Finalizing"}), - json.dumps({"type": "result", "usage": usage_data}) - ] - stdout_content = "\n".join(jsonl_output) + "\n" - process_mock = MagicMock() - process_mock.communicate.return_value = (stdout_content, "") - process_mock.poll.return_value = 0 - process_mock.wait.return_value = 0 - mock_popen.return_value = process_mock - self.adapter.send("usage test") - # Verify the usage was captured in the adapter instance - self.assertEqual(self.adapter.last_usage, usage_data) + usage_data = {"total_tokens": 42} + jsonl_output = [ + json.dumps({"type": "message", "text": "Finalizing"}), + json.dumps({"type": "result", "usage": usage_data}) + ] + stdout_content = "\n".join(jsonl_output) + "\n" + process_mock = MagicMock() + process_mock.communicate.return_value = (stdout_content, "") + process_mock.poll.return_value = 0 + process_mock.wait.return_value = 0 + mock_popen.return_value = process_mock + self.adapter.send("usage test") + # Verify the usage was captured in the adapter instance + self.assertEqual(self.adapter.last_usage, usage_data) if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/tests/test_orchestrator_pm.py b/tests/test_orchestrator_pm.py index 511860a..9e4960d 100644 --- a/tests/test_orchestrator_pm.py +++ b/tests/test_orchestrator_pm.py @@ -7,67 +7,67 @@ import mma_prompts class TestOrchestratorPM(unittest.TestCase): - @patch('summarize.build_summary_markdown') - @patch('ai_client.send') - def test_generate_tracks_success(self, mock_send: Any, mock_summarize: Any) -> None: - # Setup mocks - mock_summarize.return_value = "REPO_MAP_CONTENT" - mock_response_data = [ - { - "id": "track_1", - "type": "Track", - "module": "test_module", - "persona": "Tech Lead", - "severity": "Medium", - "goal": "Test goal", - "acceptance_criteria": ["criteria 1"] - } - ] - mock_send.return_value = json.dumps(mock_response_data) - user_request = "Implement unit tests" - project_config = {"files": {"paths": ["src"]}} - file_items = [{"path": "src/main.py", "content": "print('hello')"}] - # Execute - result = orchestrator_pm.generate_tracks(user_request, project_config, file_items) - # Verify summarize call - mock_summarize.assert_called_once_with(file_items) - # Verify ai_client.send call - expected_system_prompt = mma_prompts.PROMPTS['tier1_epic_init'] - mock_send.assert_called_once() - args, kwargs = mock_send.call_args - self.assertEqual(kwargs['md_content'], "") - # Cannot check system_prompt via mock_send kwargs anymore as it's set globally - # But we can verify user_message was passed - self.assertIn(user_request, kwargs['user_message']) - self.assertIn("REPO_MAP_CONTENT", kwargs['user_message']) - # Verify result - self.assertEqual(result[0]['id'], mock_response_data[0]['id']) + @patch('summarize.build_summary_markdown') + @patch('ai_client.send') + def test_generate_tracks_success(self, mock_send: Any, mock_summarize: Any) -> None: + # Setup mocks + mock_summarize.return_value = "REPO_MAP_CONTENT" + mock_response_data = [ + { + "id": "track_1", + "type": "Track", + "module": "test_module", + "persona": "Tech Lead", + "severity": "Medium", + "goal": "Test goal", + "acceptance_criteria": ["criteria 1"] + } + ] + mock_send.return_value = json.dumps(mock_response_data) + user_request = "Implement unit tests" + project_config = {"files": {"paths": ["src"]}} + file_items = [{"path": "src/main.py", "content": "print('hello')"}] + # Execute + result = orchestrator_pm.generate_tracks(user_request, project_config, file_items) + # Verify summarize call + mock_summarize.assert_called_once_with(file_items) + # Verify ai_client.send call + expected_system_prompt = mma_prompts.PROMPTS['tier1_epic_init'] + mock_send.assert_called_once() + args, kwargs = mock_send.call_args + self.assertEqual(kwargs['md_content'], "") + # Cannot check system_prompt via mock_send kwargs anymore as it's set globally + # But we can verify user_message was passed + self.assertIn(user_request, kwargs['user_message']) + self.assertIn("REPO_MAP_CONTENT", kwargs['user_message']) + # Verify result + self.assertEqual(result[0]['id'], mock_response_data[0]['id']) - @patch('summarize.build_summary_markdown') - @patch('ai_client.send') - def test_generate_tracks_markdown_wrapped(self, mock_send: Any, mock_summarize: Any) -> None: - mock_summarize.return_value = "REPO_MAP" - mock_response_data = [{"id": "track_1"}] - expected_result = [{"id": "track_1", "title": "Untitled Track"}] - # Wrapped in ```json ... ``` - mock_send.return_value = f"Here is the plan:\n```json\n{json.dumps(mock_response_data)}\n```\nHope this helps." - result = orchestrator_pm.generate_tracks("req", {}, []) - self.assertEqual(result, expected_result) - # Wrapped in ``` ... ``` - mock_send.return_value = f"```\n{json.dumps(mock_response_data)}\n```" - result = orchestrator_pm.generate_tracks("req", {}, []) - self.assertEqual(result, expected_result) + @patch('summarize.build_summary_markdown') + @patch('ai_client.send') + def test_generate_tracks_markdown_wrapped(self, mock_send: Any, mock_summarize: Any) -> None: + mock_summarize.return_value = "REPO_MAP" + mock_response_data = [{"id": "track_1"}] + expected_result = [{"id": "track_1", "title": "Untitled Track"}] + # Wrapped in ```json ... ``` + mock_send.return_value = f"Here is the plan:\n```json\n{json.dumps(mock_response_data)}\n```\nHope this helps." + result = orchestrator_pm.generate_tracks("req", {}, []) + self.assertEqual(result, expected_result) + # Wrapped in ``` ... ``` + mock_send.return_value = f"```\n{json.dumps(mock_response_data)}\n```" + result = orchestrator_pm.generate_tracks("req", {}, []) + self.assertEqual(result, expected_result) - @patch('summarize.build_summary_markdown') - @patch('ai_client.send') - def test_generate_tracks_malformed_json(self, mock_send: Any, mock_summarize: Any) -> None: - mock_summarize.return_value = "REPO_MAP" - mock_send.return_value = "NOT A JSON" - # Should return empty list and print error (we can mock print if we want to be thorough) - with patch('builtins.print') as mock_print: - result = orchestrator_pm.generate_tracks("req", {}, []) - self.assertEqual(result, []) - mock_print.assert_any_call("Error parsing Tier 1 response: Expecting value: line 1 column 1 (char 0)") + @patch('summarize.build_summary_markdown') + @patch('ai_client.send') + def test_generate_tracks_malformed_json(self, mock_send: Any, mock_summarize: Any) -> None: + mock_summarize.return_value = "REPO_MAP" + mock_send.return_value = "NOT A JSON" + # Should return empty list and print error (we can mock print if we want to be thorough) + with patch('builtins.print') as mock_print: + result = orchestrator_pm.generate_tracks("req", {}, []) + self.assertEqual(result, []) + mock_print.assert_any_call("Error parsing Tier 1 response: Expecting value: line 1 column 1 (char 0)") if __name__ == '__main__': - unittest.main() + unittest.main() diff --git a/tests/test_session_logging.py b/tests/test_session_logging.py index f23bff9..bd6bfcb 100644 --- a/tests/test_session_logging.py +++ b/tests/test_session_logging.py @@ -10,7 +10,7 @@ import session_logger @pytest.fixture def temp_logs(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Generator[Path, None, None]: - # Ensure closed before starting +# Ensure closed before starting session_logger.close_session() monkeypatch.setattr(session_logger, "_comms_fh", None) # Mock _LOG_DIR in session_logger diff --git a/tests/test_sync_hooks.py b/tests/test_sync_hooks.py index 27a4168..0aab231 100644 --- a/tests/test_sync_hooks.py +++ b/tests/test_sync_hooks.py @@ -15,15 +15,14 @@ def test_api_ask_client_method(live_gui) -> None: def make_blocking_request() -> None: try: - # This call should block until we respond + # This call should block until we respond results["response"] = client.request_confirmation( tool_name="powershell", args={"command": "echo hello"} ) except Exception as e: results["error"] = str(e) - - # Start the request in a background thread + # Start the request in a background thread t = threading.Thread(target=make_blocking_request) t.start() # Poll for the 'ask_received' event diff --git a/tests/verify_mma_gui_robust.py b/tests/verify_mma_gui_robust.py index 821f7ed..4a5f95d 100644 --- a/tests/verify_mma_gui_robust.py +++ b/tests/verify_mma_gui_robust.py @@ -14,7 +14,7 @@ from api_hook_client import ApiHookClient class TestMMAGUIRobust(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - # 1. Launch gui_2.py with --enable-test-hooks + # 1. Launch gui_2.py with --enable-test-hooks cls.gui_command = [sys.executable, "gui_2.py", "--enable-test-hooks"] print(f"Launching GUI: {' '.join(cls.gui_command)}") cls.gui_process = subprocess.Popen(