From 90670b9671d1829053d35285e3d723d5b3b097ea Mon Sep 17 00:00:00 2001 From: Ed_ Date: Sat, 7 Mar 2026 00:26:34 -0500 Subject: [PATCH] feat(tier4): Integrate patch generation into GUI workflow - Add patch_callback parameter throughout the tool execution chain - Add _render_patch_modal() to gui_2.py with colored diff display - Add patch modal state variables to App.__init__ - Add request_patch_from_tier4() to trigger patch generation - Add run_tier4_patch_callback() to ai_client.py - Update shell_runner to accept and execute patch_callback - Diff colors: green for additions, red for deletions, cyan for headers - 36 tests passing --- src/ai_client.py | 71 ++++++++++++++++++---------- src/app_controller.py | 9 ++-- src/gui_2.py | 84 ++++++++++++++++++++++++++++++++- src/multi_agent_conductor.py | 1 + src/shell_runner.py | 7 ++- tests/test_tier4_interceptor.py | 4 +- 6 files changed, 143 insertions(+), 33 deletions(-) diff --git a/src/ai_client.py b/src/ai_client.py index 778e6a5..e3320db 100644 --- a/src/ai_client.py +++ b/src/ai_client.py @@ -85,7 +85,7 @@ _send_lock: threading.Lock = threading.Lock() _gemini_cli_adapter: Optional[GeminiCliAdapter] = None # Injected by gui.py - called when AI wants to run a command. -confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None +confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]], Optional[Callable[[str, str], Optional[str]]]], Optional[str]]] = None # Injected by gui.py - called whenever a comms entry is appended. comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None @@ -558,7 +558,8 @@ async def _execute_tool_calls_concurrently( pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]], qa_callback: Optional[Callable[[str], str]], r_idx: int, - provider: str + provider: str, + patch_callback: Optional[Callable[[str, str], Optional[str]]] = None ) -> list[tuple[str, str, str, str]]: # tool_name, call_id, output, original_name """ Executes multiple tool calls concurrently using asyncio.gather. @@ -589,7 +590,7 @@ async def _execute_tool_calls_concurrently( else: continue - tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx)) + tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx, patch_callback)) results = await asyncio.gather(*tasks) return results @@ -601,7 +602,8 @@ async def _execute_single_tool_call_async( base_dir: str, pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]], qa_callback: Optional[Callable[[str], str]], - r_idx: int + r_idx: int, + patch_callback: Optional[Callable[[str, str], Optional[str]]] = None ) -> tuple[str, str, str, str]: out = "" tool_executed = False @@ -631,16 +633,16 @@ async def _execute_single_tool_call_async( elif name == TOOL_NAME: scr = cast(str, args.get("script", "")) _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr}) - out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback) + out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback, patch_callback) else: out = f"ERROR: unknown tool '{name}'" return (name, call_id, out, name) -def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str: +def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: if confirm_and_run_callback is None: return "ERROR: no confirmation handler registered" - result = confirm_and_run_callback(script, base_dir, qa_callback) + result = confirm_and_run_callback(script, base_dir, qa_callback, patch_callback) if result is None: output = "USER REJECTED: command was not executed" else: @@ -799,7 +801,8 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, qa_callback: Optional[Callable[[str], str]] = None, enable_tools: bool = True, - stream_callback: Optional[Callable[[str], None]] = None) -> str: + stream_callback: Optional[Callable[[str], None]] = None, + patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at try: _ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir]) @@ -979,11 +982,11 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str, try: loop = asyncio.get_running_loop() results = asyncio.run_coroutine_threadsafe( - _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"), + _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback), loop ).result() except RuntimeError: - results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini")) + results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback)) for i, (name, call_id, out, _) in enumerate(results): # Check if this is the last tool to trigger file refresh @@ -1079,11 +1082,11 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str, try: loop = asyncio.get_running_loop() results = asyncio.run_coroutine_threadsafe( - _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"), + _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback), loop ).result() except RuntimeError: - results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli")) + results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback)) tool_results_for_cli: list[dict[str, Any]] = [] for i, (name, call_id, out, _) in enumerate(results): @@ -1404,7 +1407,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item loop ).result() except RuntimeError: - results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic")) + results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback)) tool_results: list[dict[str, Any]] = [] for i, (name, call_id, out, _) in enumerate(results): @@ -1675,11 +1678,11 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str, try: loop = asyncio.get_running_loop() results = asyncio.run_coroutine_threadsafe( - _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"), + _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback), loop ).result() except RuntimeError: - results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek")) + results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback)) tool_results_for_history: list[dict[str, Any]] = [] for i, (name, call_id, out, _) in enumerate(results): @@ -1891,11 +1894,11 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str, try: loop = asyncio.get_running_loop() results = asyncio.run_coroutine_threadsafe( - _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax"), + _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback), loop ).result() except RuntimeError: - results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax")) + results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback)) tool_results_for_history: list[dict[str, Any]] = [] for i, (name, call_id, out, _) in enumerate(results): @@ -1962,6 +1965,23 @@ def run_tier4_analysis(stderr: str) -> str: return f"[QA ANALYSIS FAILED] {e}" +def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]: + try: + from src import project_manager + file_items = project_manager.get_current_file_items() + file_context = "" + for item in file_items[:5]: + path = item.get("path", "") + content = item.get("content", "")[:2000] + file_context += f"\n\nFile: {path}\n```\n{content}\n```\n" + patch = run_tier4_patch_generation(stderr, file_context) + if patch and "---" in patch and "+++" in patch: + return patch + return None + except Exception as e: + return None + + def run_tier4_patch_generation(error: str, file_context: str) -> str: if not error or not error.strip(): return "" @@ -2032,32 +2052,33 @@ def send( qa_callback: Optional[Callable[[str], str]] = None, enable_tools: bool = True, stream_callback: Optional[Callable[[str], None]] = None, + patch_callback: Optional[Callable[[str, str], Optional[str]]] = None, ) -> str: with _send_lock: if _provider == "gemini": return _send_gemini( md_content, user_message, base_dir, file_items, discussion_history, - pre_tool_callback, qa_callback, enable_tools, stream_callback + pre_tool_callback, qa_callback, enable_tools, stream_callback, patch_callback ) elif _provider == "gemini_cli": return _send_gemini_cli( md_content, user_message, base_dir, file_items, discussion_history, - pre_tool_callback, qa_callback, stream_callback + pre_tool_callback, qa_callback, stream_callback, patch_callback ) elif _provider == "anthropic": return _send_anthropic( md_content, user_message, base_dir, file_items, discussion_history, - pre_tool_callback, qa_callback, stream_callback=stream_callback + pre_tool_callback, qa_callback, stream_callback=stream_callback, patch_callback=patch_callback ) elif _provider == "deepseek": return _send_deepseek( md_content, user_message, base_dir, file_items, discussion_history, - stream, pre_tool_callback, qa_callback, stream_callback + stream, pre_tool_callback, qa_callback, stream_callback, patch_callback ) elif _provider == "minimax": return _send_minimax( md_content, user_message, base_dir, file_items, discussion_history, - stream, pre_tool_callback, qa_callback, stream_callback + stream, pre_tool_callback, qa_callback, stream_callback, patch_callback ) else: raise ValueError(f"Unknown provider: {_provider}") @@ -2079,16 +2100,16 @@ def _is_mutating_tool(name: str) -> bool: """Returns True if the tool name is considered a mutating tool.""" return name in mcp_client.MUTATING_TOOLS or name == TOOL_NAME -def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]: +def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]: """ Wrapper for the confirm_and_run_callback. This is what the providers call to trigger HITL approval. """ if confirm_and_run_callback: - return confirm_and_run_callback(script, base_dir, qa_callback) + return confirm_and_run_callback(script, base_dir, qa_callback, patch_callback) # Fallback to direct execution if no callback registered (headless default) from src import shell_runner - return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback) + return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback) def get_history_bleed_stats(md_content: Optional[str] = None) -> dict[str, Any]: if _provider == "anthropic": diff --git a/src/app_controller.py b/src/app_controller.py index 9919706..a01c35b 100644 --- a/src/app_controller.py +++ b/src/app_controller.py @@ -907,7 +907,8 @@ class AppController: stream=True, stream_callback=lambda text: self._on_ai_stream(text), pre_tool_callback=self._confirm_and_run, - qa_callback=ai_client.run_tier4_analysis + qa_callback=ai_client.run_tier4_analysis, + patch_callback=ai_client.run_tier4_patch_callback ) self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"}) except ai_client.ProviderError as e: @@ -988,14 +989,14 @@ class AppController: "ts": project_manager.now_ts() }) - def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]: + def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]: sys.stderr.write(f"[DEBUG] _confirm_and_run called. test_hooks={self.test_hooks_enabled}, manual_approve={getattr(self, 'ui_manual_approve', False)}\n") sys.stderr.flush() if self.test_hooks_enabled and not getattr(self, "ui_manual_approve", False): sys.stderr.write("[DEBUG] Auto-approving script.\n") sys.stderr.flush() self._set_status("running powershell...") - output = shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback) + output = shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback) self._append_tool_log(script, output) self._set_status("powershell done, awaiting AI...") return output @@ -1033,7 +1034,7 @@ class AppController: self._append_tool_log(final_script, "REJECTED by user") return None self._set_status("running powershell...") - output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback) + output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback) self._append_tool_log(final_script, output) self._set_status("powershell done, awaiting AI...") return output diff --git a/src/gui_2.py b/src/gui_2.py index 72d1593..af63e7f 100644 --- a/src/gui_2.py +++ b/src/gui_2.py @@ -1,4 +1,4 @@ -# gui_2.py +# gui_2.py from __future__ import annotations import tomli_w import time @@ -114,6 +114,11 @@ class App: self._tool_log_dirty: bool = True self._last_ui_focus_agent: Optional[str] = None self._log_registry: Optional[log_registry.LogRegistry] = None + # Patch viewer state for Tier 4 auto-patching + self._pending_patch_text: Optional[str] = None + self._pending_patch_files: list[str] = [] + self._show_patch_modal: bool = False + self._patch_error_message: Optional[str] = None def _handle_approve_tool(self, user_data=None) -> None: """UI-level wrapper for approving a pending tool execution ask.""" @@ -254,6 +259,7 @@ class App: self._process_pending_gui_tasks() self._process_pending_history_adds() self._render_track_proposal_modal() + self._render_patch_modal() # Auto-save (every 60s) now = time.time() if now - self._last_autosave >= self._autosave_interval: @@ -873,6 +879,82 @@ class App: imgui.close_current_popup() imgui.end_popup() + def _render_patch_modal(self) -> None: + if not self._show_patch_modal: + return + imgui.open_popup("Apply Patch?") + if imgui.begin_popup_modal("Apply Patch?", True, imgui.ImVec2(600, 500))[0]: + imgui.text_colored(imgui.ImVec4(1, 0.9, 0.3, 1), "Tier 4 QA Generated a Patch") + imgui.separator() + if self._pending_patch_files: + imgui.text("Files to modify:") + for f in self._pending_patch_files: + imgui.text(f" - {f}") + imgui.separator() + if self._patch_error_message: + imgui.text_colored(imgui.ImVec4(1, 0.3, 0.3, 1), f"Error: {self._patch_error_message}") + imgui.separator() + imgui.text("Diff Preview:") + imgui.begin_child("patch_diff_scroll", imgui.ImVec2(-1, 280), True) + if self._pending_patch_text: + diff_lines = self._pending_patch_text.split("\n") + for line in diff_lines: + if line.startswith("+++") or line.startswith("---") or line.startswith("@@"): + imgui.text_colored(imgui.ImVec4(0.3, 0.7, 1, 1), line) + elif line.startswith("+"): + imgui.text_colored(imgui.ImVec4(0.2, 0.9, 0.2, 1), line) + elif line.startswith("-"): + imgui.text_colored(imgui.ImVec4(0.9, 0.2, 0.2, 1), line) + else: + imgui.text(line) + imgui.end_child() + imgui.separator() + if imgui.button("Apply Patch", imgui.ImVec2(150, 0)): + self._apply_pending_patch() + imgui.same_line() + if imgui.button("Reject", imgui.ImVec2(150, 0)): + self._show_patch_modal = False + self._pending_patch_text = None + self._pending_patch_files = [] + self._patch_error_message = None + imgui.close_current_popup() + imgui.end_popup() + + def _apply_pending_patch(self) -> None: + if not self._pending_patch_text: + self._patch_error_message = "No patch to apply" + return + try: + from src.diff_viewer import apply_patch_to_file + base_dir = str(self.controller.current_project_dir) if hasattr(self.controller, 'current_project_dir') else "." + success, msg = apply_patch_to_file(self._pending_patch_text, base_dir) + if success: + self._show_patch_modal = False + self._pending_patch_text = None + self._pending_patch_files = [] + self._patch_error_message = None + imgui.close_current_popup() + else: + self._patch_error_message = msg + except Exception as e: + self._patch_error_message = str(e) + + def request_patch_from_tier4(self, error: str, file_context: str) -> None: + try: + from src import ai_client + from src.diff_viewer import parse_diff + patch_text = ai_client.run_tier4_patch_generation(error, file_context) + if patch_text and "---" in patch_text and "+++" in patch_text: + diff_files = parse_diff(patch_text) + file_paths = [df.old_path for df in diff_files] + self._pending_patch_text = patch_text + self._pending_patch_files = file_paths + self._show_patch_modal = True + else: + self._patch_error_message = patch_text or "No patch generated" + except Exception as e: + self._patch_error_message = str(e) + def _render_log_management(self) -> None: exp, opened = imgui.begin("Log Management", self.show_windows["Log Management"]) self.show_windows["Log Management"] = bool(opened) diff --git a/src/multi_agent_conductor.py b/src/multi_agent_conductor.py index d7e83d4..cae896c 100644 --- a/src/multi_agent_conductor.py +++ b/src/multi_agent_conductor.py @@ -410,6 +410,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files: base_dir=".", pre_tool_callback=clutch_callback if ticket.step_mode else None, qa_callback=ai_client.run_tier4_analysis, + patch_callback=ai_client.run_tier4_patch_callback, stream_callback=stream_callback ) finally: diff --git a/src/shell_runner.py b/src/shell_runner.py index cf2f4b5..24884f3 100644 --- a/src/shell_runner.py +++ b/src/shell_runner.py @@ -44,13 +44,14 @@ def _build_subprocess_env() -> dict[str, str]: env[key] = os.path.expandvars(str(val)) return env -def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str: +def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str: """ Run a PowerShell script with working directory set to base_dir. Returns a string combining stdout, stderr, and exit code. Environment is configured via mcp_env.toml (project root). If qa_callback is provided and the command fails or has stderr, the callback is called with the stderr content and its result is appended. + If patch_callback is provided, it receives (error, file_context) and returns patch text. """ safe_dir: str = str(base_dir).replace("'", "''") full_script: str = f"Set-Location -LiteralPath '{safe_dir}'\n{script}" @@ -72,6 +73,10 @@ def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[s qa_analysis: Optional[str] = qa_callback(stderr.strip()) if qa_analysis: parts.append(f"\nQA ANALYSIS:\n{qa_analysis}") + if patch_callback and (process.returncode != 0 or stderr.strip()): + patch_text = patch_callback(stderr.strip(), base_dir) + if patch_text: + parts.append(f"\nAUTO_PATCH:\n{patch_text}") return "\n".join(parts) except subprocess.TimeoutExpired: if 'process' in locals() and process: diff --git a/tests/test_tier4_interceptor.py b/tests/test_tier4_interceptor.py index e81ebec..09f991e 100644 --- a/tests/test_tier4_interceptor.py +++ b/tests/test_tier4_interceptor.py @@ -130,5 +130,5 @@ def test_gemini_provider_passes_qa_callback_to_run_script() -> None: base_dir=".", qa_callback=qa_callback ) - # Verify _run_script received the qa_callback - mock_run_script.assert_called_with("dir", ".", qa_callback) + # Verify _run_script received the qa_callback and patch_callback + mock_run_script.assert_called_with("dir", ".", qa_callback, None)