feat(tier4): Integrate patch generation into GUI workflow

- Add patch_callback parameter throughout the tool execution chain
- Add _render_patch_modal() to gui_2.py with colored diff display
- Add patch modal state variables to App.__init__
- Add request_patch_from_tier4() to trigger patch generation
- Add run_tier4_patch_callback() to ai_client.py
- Update shell_runner to accept and execute patch_callback
- Diff colors: green for additions, red for deletions, cyan for headers
- 36 tests passing
This commit is contained in:
2026-03-07 00:26:34 -05:00
parent 72a71706e3
commit 90670b9671
6 changed files with 143 additions and 33 deletions

View File

@@ -85,7 +85,7 @@ _send_lock: threading.Lock = threading.Lock()
_gemini_cli_adapter: Optional[GeminiCliAdapter] = None _gemini_cli_adapter: Optional[GeminiCliAdapter] = None
# Injected by gui.py - called when AI wants to run a command. # Injected by gui.py - called when AI wants to run a command.
confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]], Optional[Callable[[str, str], Optional[str]]]], Optional[str]]] = None
# Injected by gui.py - called whenever a comms entry is appended. # Injected by gui.py - called whenever a comms entry is appended.
comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None
@@ -558,7 +558,8 @@ async def _execute_tool_calls_concurrently(
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]], pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
qa_callback: Optional[Callable[[str], str]], qa_callback: Optional[Callable[[str], str]],
r_idx: int, r_idx: int,
provider: str provider: str,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> list[tuple[str, str, str, str]]: # tool_name, call_id, output, original_name ) -> list[tuple[str, str, str, str]]: # tool_name, call_id, output, original_name
""" """
Executes multiple tool calls concurrently using asyncio.gather. Executes multiple tool calls concurrently using asyncio.gather.
@@ -589,7 +590,7 @@ async def _execute_tool_calls_concurrently(
else: else:
continue continue
tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx)) tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx, patch_callback))
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
return results return results
@@ -601,7 +602,8 @@ async def _execute_single_tool_call_async(
base_dir: str, base_dir: str,
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]], pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
qa_callback: Optional[Callable[[str], str]], qa_callback: Optional[Callable[[str], str]],
r_idx: int r_idx: int,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
) -> tuple[str, str, str, str]: ) -> tuple[str, str, str, str]:
out = "" out = ""
tool_executed = False tool_executed = False
@@ -631,16 +633,16 @@ async def _execute_single_tool_call_async(
elif name == TOOL_NAME: elif name == TOOL_NAME:
scr = cast(str, args.get("script", "")) scr = cast(str, args.get("script", ""))
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr}) _append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback) out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback, patch_callback)
else: else:
out = f"ERROR: unknown tool '{name}'" out = f"ERROR: unknown tool '{name}'"
return (name, call_id, out, name) return (name, call_id, out, name)
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str: def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
if confirm_and_run_callback is None: if confirm_and_run_callback is None:
return "ERROR: no confirmation handler registered" return "ERROR: no confirmation handler registered"
result = confirm_and_run_callback(script, base_dir, qa_callback) result = confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
if result is None: if result is None:
output = "USER REJECTED: command was not executed" output = "USER REJECTED: command was not executed"
else: else:
@@ -799,7 +801,8 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None, pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
qa_callback: Optional[Callable[[str], str]] = None, qa_callback: Optional[Callable[[str], str]] = None,
enable_tools: bool = True, enable_tools: bool = True,
stream_callback: Optional[Callable[[str], None]] = None) -> str: stream_callback: Optional[Callable[[str], None]] = None,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
try: try:
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir]) _ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
@@ -979,11 +982,11 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe( results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"), _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback),
loop loop
).result() ).result()
except RuntimeError: except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini")) results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback))
for i, (name, call_id, out, _) in enumerate(results): for i, (name, call_id, out, _) in enumerate(results):
# Check if this is the last tool to trigger file refresh # Check if this is the last tool to trigger file refresh
@@ -1079,11 +1082,11 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe( results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"), _execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback),
loop loop
).result() ).result()
except RuntimeError: except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli")) results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback))
tool_results_for_cli: list[dict[str, Any]] = [] tool_results_for_cli: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results): for i, (name, call_id, out, _) in enumerate(results):
@@ -1404,7 +1407,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
loop loop
).result() ).result()
except RuntimeError: except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic")) results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback))
tool_results: list[dict[str, Any]] = [] tool_results: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results): for i, (name, call_id, out, _) in enumerate(results):
@@ -1675,11 +1678,11 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe( results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"), _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback),
loop loop
).result() ).result()
except RuntimeError: except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek")) results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback))
tool_results_for_history: list[dict[str, Any]] = [] tool_results_for_history: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results): for i, (name, call_id, out, _) in enumerate(results):
@@ -1891,11 +1894,11 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str,
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
results = asyncio.run_coroutine_threadsafe( results = asyncio.run_coroutine_threadsafe(
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax"), _execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback),
loop loop
).result() ).result()
except RuntimeError: except RuntimeError:
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax")) results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback))
tool_results_for_history: list[dict[str, Any]] = [] tool_results_for_history: list[dict[str, Any]] = []
for i, (name, call_id, out, _) in enumerate(results): for i, (name, call_id, out, _) in enumerate(results):
@@ -1962,6 +1965,23 @@ def run_tier4_analysis(stderr: str) -> str:
return f"[QA ANALYSIS FAILED] {e}" return f"[QA ANALYSIS FAILED] {e}"
def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]:
    """Generate a unified-diff patch for a failed command via Tier 4 QA.

    Collects a bounded amount of file context from the current project and
    asks the patch-generation model for a fix.  The result is returned only
    when it at least resembles a unified diff (contains both ``---`` and
    ``+++`` markers); anything else yields ``None``.  Never raises: any
    failure is treated as "no patch available".

    Args:
        stderr: stderr text of the failed command, used as the error prompt.
        base_dir: working directory of the failed command.  Currently unused
            here — kept for signature compatibility with patch_callback
            consumers (shell_runner passes it positionally).

    Returns:
        The patch text, or ``None`` when no usable patch was produced.
    """
    try:
        # Local import avoids a circular import at module load time.
        from src import project_manager

        file_items = project_manager.get_current_file_items()
        # Bound prompt size: at most 5 files, 2000 chars of content each.
        file_context = ""
        for item in file_items[:5]:
            path = item.get("path", "")
            content = item.get("content", "")[:2000]
            file_context += f"\n\nFile: {path}\n```\n{content}\n```\n"
        patch = run_tier4_patch_generation(stderr, file_context)
        # Only surface output that looks like a unified diff.
        if patch and "---" in patch and "+++" in patch:
            return patch
        return None
    except Exception:
        # Best-effort feature: a patch-generation failure must never break
        # the surrounding tool-output pipeline.
        return None
def run_tier4_patch_generation(error: str, file_context: str) -> str: def run_tier4_patch_generation(error: str, file_context: str) -> str:
if not error or not error.strip(): if not error or not error.strip():
return "" return ""
@@ -2032,32 +2052,33 @@ def send(
qa_callback: Optional[Callable[[str], str]] = None, qa_callback: Optional[Callable[[str], str]] = None,
enable_tools: bool = True, enable_tools: bool = True,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
) -> str: ) -> str:
with _send_lock: with _send_lock:
if _provider == "gemini": if _provider == "gemini":
return _send_gemini( return _send_gemini(
md_content, user_message, base_dir, file_items, discussion_history, md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, enable_tools, stream_callback pre_tool_callback, qa_callback, enable_tools, stream_callback, patch_callback
) )
elif _provider == "gemini_cli": elif _provider == "gemini_cli":
return _send_gemini_cli( return _send_gemini_cli(
md_content, user_message, base_dir, file_items, discussion_history, md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, stream_callback pre_tool_callback, qa_callback, stream_callback, patch_callback
) )
elif _provider == "anthropic": elif _provider == "anthropic":
return _send_anthropic( return _send_anthropic(
md_content, user_message, base_dir, file_items, discussion_history, md_content, user_message, base_dir, file_items, discussion_history,
pre_tool_callback, qa_callback, stream_callback=stream_callback pre_tool_callback, qa_callback, stream_callback=stream_callback, patch_callback=patch_callback
) )
elif _provider == "deepseek": elif _provider == "deepseek":
return _send_deepseek( return _send_deepseek(
md_content, user_message, base_dir, file_items, discussion_history, md_content, user_message, base_dir, file_items, discussion_history,
stream, pre_tool_callback, qa_callback, stream_callback stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
) )
elif _provider == "minimax": elif _provider == "minimax":
return _send_minimax( return _send_minimax(
md_content, user_message, base_dir, file_items, discussion_history, md_content, user_message, base_dir, file_items, discussion_history,
stream, pre_tool_callback, qa_callback, stream_callback stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
) )
else: else:
raise ValueError(f"Unknown provider: {_provider}") raise ValueError(f"Unknown provider: {_provider}")
@@ -2079,16 +2100,16 @@ def _is_mutating_tool(name: str) -> bool:
"""Returns True if the tool name is considered a mutating tool.""" """Returns True if the tool name is considered a mutating tool."""
return name in mcp_client.MUTATING_TOOLS or name == TOOL_NAME return name in mcp_client.MUTATING_TOOLS or name == TOOL_NAME
def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]: def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]:
""" """
Wrapper for the confirm_and_run_callback. Wrapper for the confirm_and_run_callback.
This is what the providers call to trigger HITL approval. This is what the providers call to trigger HITL approval.
""" """
if confirm_and_run_callback: if confirm_and_run_callback:
return confirm_and_run_callback(script, base_dir, qa_callback) return confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
# Fallback to direct execution if no callback registered (headless default) # Fallback to direct execution if no callback registered (headless default)
from src import shell_runner from src import shell_runner
return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback) return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
def get_history_bleed_stats(md_content: Optional[str] = None) -> dict[str, Any]: def get_history_bleed_stats(md_content: Optional[str] = None) -> dict[str, Any]:
if _provider == "anthropic": if _provider == "anthropic":

View File

@@ -907,7 +907,8 @@ class AppController:
stream=True, stream=True,
stream_callback=lambda text: self._on_ai_stream(text), stream_callback=lambda text: self._on_ai_stream(text),
pre_tool_callback=self._confirm_and_run, pre_tool_callback=self._confirm_and_run,
qa_callback=ai_client.run_tier4_analysis qa_callback=ai_client.run_tier4_analysis,
patch_callback=ai_client.run_tier4_patch_callback
) )
self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"}) self.event_queue.put("response", {"text": resp, "status": "done", "role": "AI"})
except ai_client.ProviderError as e: except ai_client.ProviderError as e:
@@ -988,14 +989,14 @@ class AppController:
"ts": project_manager.now_ts() "ts": project_manager.now_ts()
}) })
def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]: def _confirm_and_run(self, script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]:
sys.stderr.write(f"[DEBUG] _confirm_and_run called. test_hooks={self.test_hooks_enabled}, manual_approve={getattr(self, 'ui_manual_approve', False)}\n") sys.stderr.write(f"[DEBUG] _confirm_and_run called. test_hooks={self.test_hooks_enabled}, manual_approve={getattr(self, 'ui_manual_approve', False)}\n")
sys.stderr.flush() sys.stderr.flush()
if self.test_hooks_enabled and not getattr(self, "ui_manual_approve", False): if self.test_hooks_enabled and not getattr(self, "ui_manual_approve", False):
sys.stderr.write("[DEBUG] Auto-approving script.\n") sys.stderr.write("[DEBUG] Auto-approving script.\n")
sys.stderr.flush() sys.stderr.flush()
self._set_status("running powershell...") self._set_status("running powershell...")
output = shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback) output = shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
self._append_tool_log(script, output) self._append_tool_log(script, output)
self._set_status("powershell done, awaiting AI...") self._set_status("powershell done, awaiting AI...")
return output return output
@@ -1033,7 +1034,7 @@ class AppController:
self._append_tool_log(final_script, "REJECTED by user") self._append_tool_log(final_script, "REJECTED by user")
return None return None
self._set_status("running powershell...") self._set_status("running powershell...")
output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback) output = shell_runner.run_powershell(final_script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
self._append_tool_log(final_script, output) self._append_tool_log(final_script, output)
self._set_status("powershell done, awaiting AI...") self._set_status("powershell done, awaiting AI...")
return output return output

View File

@@ -1,4 +1,4 @@
# gui_2.py # gui_2.py
from __future__ import annotations from __future__ import annotations
import tomli_w import tomli_w
import time import time
@@ -114,6 +114,11 @@ class App:
self._tool_log_dirty: bool = True self._tool_log_dirty: bool = True
self._last_ui_focus_agent: Optional[str] = None self._last_ui_focus_agent: Optional[str] = None
self._log_registry: Optional[log_registry.LogRegistry] = None self._log_registry: Optional[log_registry.LogRegistry] = None
# Patch viewer state for Tier 4 auto-patching
self._pending_patch_text: Optional[str] = None
self._pending_patch_files: list[str] = []
self._show_patch_modal: bool = False
self._patch_error_message: Optional[str] = None
def _handle_approve_tool(self, user_data=None) -> None: def _handle_approve_tool(self, user_data=None) -> None:
"""UI-level wrapper for approving a pending tool execution ask.""" """UI-level wrapper for approving a pending tool execution ask."""
@@ -254,6 +259,7 @@ class App:
self._process_pending_gui_tasks() self._process_pending_gui_tasks()
self._process_pending_history_adds() self._process_pending_history_adds()
self._render_track_proposal_modal() self._render_track_proposal_modal()
self._render_patch_modal()
# Auto-save (every 60s) # Auto-save (every 60s)
now = time.time() now = time.time()
if now - self._last_autosave >= self._autosave_interval: if now - self._last_autosave >= self._autosave_interval:
@@ -873,6 +879,82 @@ class App:
imgui.close_current_popup() imgui.close_current_popup()
imgui.end_popup() imgui.end_popup()
    def _render_patch_modal(self) -> None:
        """Render the Tier 4 auto-patch approval modal, if a patch is pending.

        Called once per frame from the main render loop.  Does nothing unless
        ``self._show_patch_modal`` is set.  Shows the target file list, any
        error message, a color-coded diff preview, and Apply/Reject buttons.
        """
        if not self._show_patch_modal:
            return
        # Immediate-mode pattern: re-open the popup every frame while pending.
        imgui.open_popup("Apply Patch?")
        if imgui.begin_popup_modal("Apply Patch?", True, imgui.ImVec2(600, 500))[0]:
            imgui.text_colored(imgui.ImVec4(1, 0.9, 0.3, 1), "Tier 4 QA Generated a Patch")
            imgui.separator()
            if self._pending_patch_files:
                imgui.text("Files to modify:")
                for f in self._pending_patch_files:
                    imgui.text(f" - {f}")
                imgui.separator()
            if self._patch_error_message:
                imgui.text_colored(imgui.ImVec4(1, 0.3, 0.3, 1), f"Error: {self._patch_error_message}")
                imgui.separator()
            imgui.text("Diff Preview:")
            # Scrollable child region so long diffs don't blow out the modal.
            imgui.begin_child("patch_diff_scroll", imgui.ImVec2(-1, 280), True)
            if self._pending_patch_text:
                diff_lines = self._pending_patch_text.split("\n")
                for line in diff_lines:
                    # Diff coloring: cyan for headers/hunks, green for
                    # additions, red for deletions, default otherwise.
                    # NOTE: header prefixes are tested before the bare "+"/"-"
                    # cases, so "+++"/"---" are not misclassified.
                    if line.startswith("+++") or line.startswith("---") or line.startswith("@@"):
                        imgui.text_colored(imgui.ImVec4(0.3, 0.7, 1, 1), line)
                    elif line.startswith("+"):
                        imgui.text_colored(imgui.ImVec4(0.2, 0.9, 0.2, 1), line)
                    elif line.startswith("-"):
                        imgui.text_colored(imgui.ImVec4(0.9, 0.2, 0.2, 1), line)
                    else:
                        imgui.text(line)
            imgui.end_child()
            imgui.separator()
            if imgui.button("Apply Patch", imgui.ImVec2(150, 0)):
                self._apply_pending_patch()
            imgui.same_line()
            if imgui.button("Reject", imgui.ImVec2(150, 0)):
                # Rejection discards all pending-patch state.
                self._show_patch_modal = False
                self._pending_patch_text = None
                self._pending_patch_files = []
                self._patch_error_message = None
                imgui.close_current_popup()
            # end_popup is only legal when begin_popup_modal returned True.
            imgui.end_popup()
def _apply_pending_patch(self) -> None:
if not self._pending_patch_text:
self._patch_error_message = "No patch to apply"
return
try:
from src.diff_viewer import apply_patch_to_file
base_dir = str(self.controller.current_project_dir) if hasattr(self.controller, 'current_project_dir') else "."
success, msg = apply_patch_to_file(self._pending_patch_text, base_dir)
if success:
self._show_patch_modal = False
self._pending_patch_text = None
self._pending_patch_files = []
self._patch_error_message = None
imgui.close_current_popup()
else:
self._patch_error_message = msg
except Exception as e:
self._patch_error_message = str(e)
def request_patch_from_tier4(self, error: str, file_context: str) -> None:
try:
from src import ai_client
from src.diff_viewer import parse_diff
patch_text = ai_client.run_tier4_patch_generation(error, file_context)
if patch_text and "---" in patch_text and "+++" in patch_text:
diff_files = parse_diff(patch_text)
file_paths = [df.old_path for df in diff_files]
self._pending_patch_text = patch_text
self._pending_patch_files = file_paths
self._show_patch_modal = True
else:
self._patch_error_message = patch_text or "No patch generated"
except Exception as e:
self._patch_error_message = str(e)
def _render_log_management(self) -> None: def _render_log_management(self) -> None:
exp, opened = imgui.begin("Log Management", self.show_windows["Log Management"]) exp, opened = imgui.begin("Log Management", self.show_windows["Log Management"])
self.show_windows["Log Management"] = bool(opened) self.show_windows["Log Management"] = bool(opened)

View File

@@ -410,6 +410,7 @@ def run_worker_lifecycle(ticket: Ticket, context: WorkerContext, context_files:
base_dir=".", base_dir=".",
pre_tool_callback=clutch_callback if ticket.step_mode else None, pre_tool_callback=clutch_callback if ticket.step_mode else None,
qa_callback=ai_client.run_tier4_analysis, qa_callback=ai_client.run_tier4_analysis,
patch_callback=ai_client.run_tier4_patch_callback,
stream_callback=stream_callback stream_callback=stream_callback
) )
finally: finally:

View File

@@ -44,13 +44,14 @@ def _build_subprocess_env() -> dict[str, str]:
env[key] = os.path.expandvars(str(val)) env[key] = os.path.expandvars(str(val))
return env return env
def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str: def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
""" """
Run a PowerShell script with working directory set to base_dir. Run a PowerShell script with working directory set to base_dir.
Returns a string combining stdout, stderr, and exit code. Returns a string combining stdout, stderr, and exit code.
Environment is configured via mcp_env.toml (project root). Environment is configured via mcp_env.toml (project root).
If qa_callback is provided and the command fails or has stderr, If qa_callback is provided and the command fails or has stderr,
the callback is called with the stderr content and its result is appended. the callback is called with the stderr content and its result is appended.
If patch_callback is provided, it receives (error, file_context) and returns patch text.
""" """
safe_dir: str = str(base_dir).replace("'", "''") safe_dir: str = str(base_dir).replace("'", "''")
full_script: str = f"Set-Location -LiteralPath '{safe_dir}'\n{script}" full_script: str = f"Set-Location -LiteralPath '{safe_dir}'\n{script}"
@@ -72,6 +73,10 @@ def run_powershell(script: str, base_dir: str, qa_callback: Optional[Callable[[s
qa_analysis: Optional[str] = qa_callback(stderr.strip()) qa_analysis: Optional[str] = qa_callback(stderr.strip())
if qa_analysis: if qa_analysis:
parts.append(f"\nQA ANALYSIS:\n{qa_analysis}") parts.append(f"\nQA ANALYSIS:\n{qa_analysis}")
if patch_callback and (process.returncode != 0 or stderr.strip()):
patch_text = patch_callback(stderr.strip(), base_dir)
if patch_text:
parts.append(f"\nAUTO_PATCH:\n{patch_text}")
return "\n".join(parts) return "\n".join(parts)
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
if 'process' in locals() and process: if 'process' in locals() and process:

View File

@@ -130,5 +130,5 @@ def test_gemini_provider_passes_qa_callback_to_run_script() -> None:
base_dir=".", base_dir=".",
qa_callback=qa_callback qa_callback=qa_callback
) )
# Verify _run_script received the qa_callback # Verify _run_script received the qa_callback and patch_callback
mock_run_script.assert_called_with("dir", ".", qa_callback) mock_run_script.assert_called_with("dir", ".", qa_callback, None)