feat(tier4): Integrate patch generation into GUI workflow

- Add patch_callback parameter throughout the tool execution chain
- Add _render_patch_modal() to gui_2.py with colored diff display
- Add patch modal state variables to App.__init__
- Add request_patch_from_tier4() to trigger patch generation
- Add run_tier4_patch_callback() to ai_client.py
- Update shell_runner to accept and execute patch_callback
- Diff colors: green for additions, red for deletions, cyan for headers
- 36 tests passing
This commit is contained in:
2026-03-07 00:26:34 -05:00
parent 72a71706e3
commit 90670b9671
6 changed files with 143 additions and 33 deletions

View File

@@ -85,7 +85,7 @@ _send_lock: threading.Lock = threading.Lock()
_gemini_cli_adapter: Optional[GeminiCliAdapter] = None
# Injected by gui.py - called when AI wants to run a command.
confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None
confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]], Optional[Callable[[str, str], Optional[str]]]], Optional[str]]] = None
# Injected by gui.py - called whenever a comms entry is appended.
comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None
@@ -558,7 +558,8 @@ async def _execute_tool_calls_concurrently(
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
qa_callback: Optional[Callable[[str], str]],
r_idx: int,
provider: str
provider: str,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> list[tuple[str, str, str, str]]: # tool_name, call_id, output, original_name
"""
Executes multiple tool calls concurrently using asyncio.gather.
@@ -589,7 +590,7 @@ async def _execute_tool_calls_concurrently(
else:
continue
tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx))
tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx, patch_callback))
results = await asyncio.gather(*tasks)
return results
@@ -601,7 +602,8 @@ async def _execute_single_tool_call_async(
base_dir: str,
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
qa_callback: Optional[Callable[[str], str]],
r_idx: int
r_idx: int,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> tuple[str, str, str, str]:
out = ""
tool_executed = False
@@ -631,16 +633,16 @@ async def _execute_single_tool_call_async(
elif name == TOOL_NAME:
scr = cast(str, args.get("script", ""))
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback)
out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback, patch_callback)
else:
out = f"ERROR: unknown tool '{name}'"
return (name, call_id, out, name)
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
if confirm_and_run_callback is None:
return "ERROR: no confirmation handler registered"
result = confirm_and_run_callback(script, base_dir, qa_callback)
result = confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
if result is None:
output = "USER REJECTED: command was not executed"
else:
@@ -799,7 +801,8 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
qa_callback: Optional[Callable[[str], str]] = None,
enable_tools: bool = True,
stream_callback: Optional[Callable[[str], None]] = None) -> str:
stream_callback: Optional[Callable[[str], None]] = None,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
try:
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
@@ -979,11 +982,11 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
try:
loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"),
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback),
loop
).result()
except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"))
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback))
for i, (name, call_id, out, _) in enumerate(results):
# Check if this is the last tool to trigger file refresh
@@ -1079,11 +1082,11 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
try:
loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"),
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback),
loop
).result()
except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"))
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback))
tool_results_for_cli: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results):
@@ -1404,7 +1407,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
loop
).result()
except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"))
results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback))
tool_results: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results):
@@ -1675,11 +1678,11 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
try:
loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"),
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback),
loop
).result()
except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"))
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback))
tool_results_for_history: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results):
@@ -1891,11 +1894,11 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str,
try:
loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax"),
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback),
loop
).result()
except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax"))
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback))
tool_results_for_history: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results):
@@ -1962,6 +1965,23 @@ def run_tier4_analysis(stderr: str) -> str:
return f"[QA ANALYSIS FAILED] {e}"
def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]:
try:
from src import project_manager
file_items = project_manager.get_current_file_items()
file_context = ""
for item in file_items[:5]:
path = item.get("path", "")
content = item.get("content", "")[:2000]
file_context += f"\n\nFile: {path}\n```\n{content}\n```\n"
patch = run_tier4_patch_generation(stderr, file_context)
if patch and "---" in patch and "+++" in patch:
return patch
return None
except Exception as e:
return None
def run_tier4_patch_generation(error: str, file_context: str) -> str:
if not error or not error.strip():
return ""
@@ -2032,32 +2052,33 @@ def send(
qa_callback: Optional[Callable[[str], str]] = None,
enable_tools: bool = True,
stream_callback: Optional[Callable[[str], None]] = None,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
) -> str:
with _send_lock:
if _provider == "gemini":
return _send_gemini(
md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, enable_tools, stream_callback
pre_tool_callback, qa_callback, enable_tools, stream_callback, patch_callback
)
elif _provider == "gemini_cli":
return _send_gemini_cli(
md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, stream_callback
pre_tool_callback, qa_callback, stream_callback, patch_callback
)
elif _provider == "anthropic":
return _send_anthropic(
md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, stream_callback=stream_callback
pre_tool_callback, qa_callback, stream_callback=stream_callback, patch_callback=patch_callback
)
elif _provider == "deepseek":
return _send_deepseek(
md_content, user_message, base_dir, file_items, discussion_history,
stream, pre_tool_callback, qa_callback, stream_callback
stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
)
elif _provider == "minimax":
return _send_minimax(
md_content, user_message, base_dir, file_items, discussion_history,
stream, pre_tool_callback, qa_callback, stream_callback
stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
)
else:
raise ValueError(f"Unknown provider: {_provider}")
@@ -2079,16 +2100,16 @@ def _is_mutating_tool(name: str) -> bool:
"""Returns True if the tool name is considered a mutating tool."""
return name in mcp_client.MUTATING_TOOLS or name == TOOL_NAME
def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]:
def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]:
"""
Wrapper for the confirm_and_run_callback.
This is what the providers call to trigger HITL approval.
"""
if confirm_and_run_callback:
return confirm_and_run_callback(script, base_dir, qa_callback)
return confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
# Fallback to direct execution if no callback registered (headless default)
from src import shell_runner
return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback)
return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
def get_history_bleed_stats(md_content: Optional[str] = None) -> dict[str, Any]:
if _provider == "anthropic":

View File

@@ -907,7 +907,8 @@ class AppController:
stream=True,
stream_callback=lambda text: self._on_ai_stream(text),
pre_tool_callback=self._confirm_and_run,
qa_callback=ai_client.run_tier4_analysis
qa_callback=ai_client.run_tier4_analysis,
patch_callback=ai_client.run_tier4_patch_callback
)
self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"})
except ai_client.ProviderError as e:
@@ -988,14 +989,14 @@ class AppController:
"ts": project_manager.now_ts()
})
def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]:
def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]:
sys.stderr.write(f"[DEBUG] _confirm_and_run called. test_hooks={self.test_hooks_enabled}, manual_approve={getattr(self, 'ui_manual_approve', False)}\n")
sys.stderr.flush()
if self.test_hooks_enabled and not getattr(self, "ui_manual_approve", False):
sys.stderr.write("[DEBUG] Auto-approving script.\n")
sys.stderr.flush()
self._set_status("running powershell...")
output = shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback)
output = shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
self._append_tool_log(script, output)
self._set_status("powershell done, awaiting AI...")
return output
@@ -1033,7 +1034,7 @@ class AppController:
self._append_tool_log(final_script, "REJECTED by user")
return None
self._set_status("running powershell...")
output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback)
output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
self._append_tool_log(final_script, output)
self._set_status("powershell done, awaiting AI...")
return output

View File

@@ -1,4 +1,4 @@
# gui_2.py
# gui_2.py
from __future__ import annotations
import tomli_w
import time
@@ -114,6 +114,11 @@ class App:
self._tool_log_dirty: bool = True
self._last_ui_focus_agent: Optional[str] = None
self._log_registry: Optional[log_registry.LogRegistry] = None
# Patch viewer state for Tier 4 auto-patching
self._pending_patch_text: Optional[str] = None
self._pending_patch_files: list[str] = []
self._show_patch_modal: bool = False
self._patch_error_message: Optional[str] = None
def _handle_approve_tool(self, user_data=None) -> None:
"""UI-level wrapper for approving a pending tool execution ask."""
@@ -254,6 +259,7 @@ class App:
self._process_pending_gui_tasks()
self._process_pending_history_adds()
self._render_track_proposal_modal()
self._render_patch_modal()
# Auto-save (every 60s)
now = time.time()
if now - self._last_autosave >= self._autosave_interval:
@@ -873,6 +879,82 @@ class App:
imgui.close_current_popup()
imgui.end_popup()
def _render_patch_modal(self) -> None:
if not self._show_patch_modal:
return
imgui.open_popup("Apply Patch?")
if imgui.begin_popup_modal("Apply Patch?", True, imgui.ImVec2(600, 500))[0]:
imgui.text_colored(imgui.ImVec4(1, 0.9, 0.3, 1), "Tier 4 QA Generated a Patch")
imgui.separator()
if self._pending_patch_files:
imgui.text("Files to modify:")
for f in self._pending_patch_files:
imgui.text(f" - {f}")
imgui.separator()
if self._patch_error_message:
imgui.text_colored(imgui.ImVec4(1, 0.3, 0.3, 1), f"Error: {self._patch_error_message}")
imgui.separator()
imgui.text("Diff Preview:")
imgui.begin_child("patch_diff_scroll", imgui.ImVec2(-1, 280), True)
if self._pending_patch_text:
diff_lines = self._pending_patch_text.split("\n")
for line in diff_lines:
if line.startswith("+++") or line.startswith("---") or line.startswith("@@"):
imgui.text_colored(imgui.ImVec4(0.3, 0.7, 1, 1), line)
elif line.startswith("+"):
imgui.text_colored(imgui.ImVec4(0.2, 0.9, 0.2, 1), line)
elif line.startswith("-"):
imgui.text_colored(imgui.ImVec4(0.9, 0.2, 0.2, 1), line)
else:
imgui.text(line)
imgui.end_child()
imgui.separator()
if imgui.button("Apply Patch", imgui.ImVec2(150, 0)):
self._apply_pending_patch()
imgui.same_line()
if imgui.button("Reject", imgui.ImVec2(150, 0)):
self._show_patch_modal = False
self._pending_patch_text = None
self._pending_patch_files = []
self._patch_error_message = None
imgui.close_current_popup()
imgui.end_popup()
def _apply_pending_patch(self) -> None:
if not self._pending_patch_text:
self._patch_error_message = "No patch to apply"
return
try:
from src.diff_viewer import apply_patch_to_file
base_dir = str(self.controller.current_project_dir) if hasattr(self.controller, 'current_project_dir') else "."
success, msg = apply_patch_to_file(self._pending_patch_text, base_dir)
if success:
self._show_patch_modal = False
self._pending_patch_text = None
self._pending_patch_files = []
self._patch_error_message = None
imgui.close_current_popup()
else:
self._patch_error_message = msg
except Exception as e:
self._patch_error_message = str(e)
def request_patch_from_tier4(self, error: str, file_context: str) -> None:
try:
from src import ai_client
from src.diff_viewer import parse_diff
patch_text = ai_client.run_tier4_patch_generation(error, file_context)
if patch_text and "---" in patch_text and "+++" in patch_text:
diff_files = parse_diff(patch_text)
file_paths = [df.old_path for df in diff_files]
self._pending_patch_text = patch_text
self._pending_patch_files = file_paths
self._show_patch_modal = True
else:
self._patch_error_message = patch_text or "No patch generated"
except Exception as e:
self._patch_error_message = str(e)
def _render_log_management(self) -> None:
exp, opened = imgui.begin("Log Management", self.show_windows["Log Management"])
self.show_windows["Log Management"] = bool(opened)

View File

@@ -410,6 +410,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
base_dir=".",
pre_tool_callback=clutch_callback if ticket.step_mode else None,
qa_callback=ai_client.run_tier4_analysis,
patch_callback=ai_client.run_tier4_patch_callback,
stream_callback=stream_callback
)
finally:

View File

@@ -44,13 +44,14 @@ def _build_subprocess_env() -> dict[str, str]:
env[key] = os.path.expandvars(str(val))
return env
def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
"""
Run a PowerShell script with working directory set to base_dir.
Returns a string combining stdout, stderr, and exit code.
Environment is configured via mcp_env.toml (project root).
If qa_callback is provided and the command fails or has stderr,
the callback is called with the stderr content and its result is appended.
If patch_callback is provided, it receives (error, file_context) and returns patch text.
"""
safe_dir: str = str(base_dir).replace("'", "''")
full_script: str = f"Set-Location -LiteralPath '{safe_dir}'\n{script}"
@@ -72,6 +73,10 @@ def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[s
qa_analysis: Optional[str] = qa_callback(stderr.strip())
if qa_analysis:
parts.append(f"\nQA ANALYSIS:\n{qa_analysis}")
if patch_callback and (process.returncode != 0 or stderr.strip()):
patch_text = patch_callback(stderr.strip(), base_dir)
if patch_text:
parts.append(f"\nAUTO_PATCH:\n{patch_text}")
return "\n".join(parts)
except subprocess.TimeoutExpired:
if 'process' in locals() and process:

View File

@@ -130,5 +130,5 @@ def test_gemini_provider_passes_qa_callback_to_run_script() -> None:
base_dir=".",
qa_callback=qa_callback
)
# Verify _run_script received the qa_callback
mock_run_script.assert_called_with("dir", ".", qa_callback)
# Verify _run_script received the qa_callback and patch_callback
mock_run_script.assert_called_with("dir", ".", qa_callback, None)