feat(tier4): Integrate patch generation into GUI workflow
- Add patch_callback parameter throughout the tool execution chain - Add _render_patch_modal() to gui_2.py with colored diff display - Add patch modal state variables to App.__init__ - Add request_patch_from_tier4() to trigger patch generation - Add run_tier4_patch_callback() to ai_client.py - Update shell_runner to accept and execute patch_callback - Diff colors: green for additions, red for deletions, cyan for headers - 36 tests passing
This commit is contained in:
@@ -85,7 +85,7 @@ _send_lock: threading.Lock = threading.Lock()
|
||||
_gemini_cli_adapter: Optional[GeminiCliAdapter] = None
|
||||
|
||||
# Injected by gui.py - called when AI wants to run a command.
|
||||
confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None
|
||||
confirm_and_run_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]], Optional[Callable[[str, str], Optional[str]]]], Optional[str]]] = None
|
||||
|
||||
# Injected by gui.py - called whenever a comms entry is appended.
|
||||
comms_log_callback: Optional[Callable[[dict[str, Any]], None]] = None
|
||||
@@ -558,7 +558,8 @@ async def _execute_tool_calls_concurrently(
|
||||
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
|
||||
qa_callback: Optional[Callable[[str], str]],
|
||||
r_idx: int,
|
||||
provider: str
|
||||
provider: str,
|
||||
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
|
||||
) -> list[tuple[str, str, str, str]]: # tool_name, call_id, output, original_name
|
||||
"""
|
||||
Executes multiple tool calls concurrently using asyncio.gather.
|
||||
@@ -589,7 +590,7 @@ async def _execute_tool_calls_concurrently(
|
||||
else:
|
||||
continue
|
||||
|
||||
tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx))
|
||||
tasks.append(_execute_single_tool_call_async(name, args, call_id, base_dir, pre_tool_callback, qa_callback, r_idx, patch_callback))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
||||
@@ -601,7 +602,8 @@ async def _execute_single_tool_call_async(
|
||||
base_dir: str,
|
||||
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]],
|
||||
qa_callback: Optional[Callable[[str], str]],
|
||||
r_idx: int
|
||||
r_idx: int,
|
||||
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None
|
||||
) -> tuple[str, str, str, str]:
|
||||
out = ""
|
||||
tool_executed = False
|
||||
@@ -631,16 +633,16 @@ async def _execute_single_tool_call_async(
|
||||
elif name == TOOL_NAME:
|
||||
scr = cast(str, args.get("script", ""))
|
||||
_append_comms("OUT", "tool_call", {"name": TOOL_NAME, "id": call_id, "script": scr})
|
||||
out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback)
|
||||
out = await asyncio.to_thread(_run_script, scr, base_dir, qa_callback, patch_callback)
|
||||
else:
|
||||
out = f"ERROR: unknown tool '{name}'"
|
||||
|
||||
return (name, call_id, out, name)
|
||||
|
||||
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> str:
|
||||
def _run_script(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
|
||||
if confirm_and_run_callback is None:
|
||||
return "ERROR: no confirmation handler registered"
|
||||
result = confirm_and_run_callback(script, base_dir, qa_callback)
|
||||
result = confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
|
||||
if result is None:
|
||||
output = "USER REJECTED: command was not executed"
|
||||
else:
|
||||
@@ -799,7 +801,8 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
||||
pre_tool_callback: Optional[Callable[[str, str, Optional[Callable[[str], str]]], Optional[str]]] = None,
|
||||
qa_callback: Optional[Callable[[str], str]] = None,
|
||||
enable_tools: bool = True,
|
||||
stream_callback: Optional[Callable[[str], None]] = None) -> str:
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> str:
|
||||
global _gemini_chat, _gemini_cache, _gemini_cache_md_hash, _gemini_cache_created_at
|
||||
try:
|
||||
_ensure_gemini_client(); mcp_client.configure(file_items or [], [base_dir])
|
||||
@@ -979,11 +982,11 @@ def _send_gemini(md_content: str, user_message: str, base_dir: str,
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"),
|
||||
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini"))
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini", patch_callback))
|
||||
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
# Check if this is the last tool to trigger file refresh
|
||||
@@ -1079,11 +1082,11 @@ def _send_gemini_cli(md_content: str, user_message: str, base_dir: str,
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"),
|
||||
_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli"))
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(calls, base_dir, pre_tool_callback, qa_callback, r_idx, "gemini_cli", patch_callback))
|
||||
|
||||
tool_results_for_cli: list[dict[str, Any]] = []
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
@@ -1404,7 +1407,7 @@ def _send_anthropic(md_content: str, user_message: str, base_dir: str, file_item
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic"))
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(response.content, base_dir, pre_tool_callback, qa_callback, round_idx, "anthropic", patch_callback))
|
||||
|
||||
tool_results: list[dict[str, Any]] = []
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
@@ -1675,11 +1678,11 @@ def _send_deepseek(md_content: str, user_message: str, base_dir: str,
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"),
|
||||
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek"))
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "deepseek", patch_callback))
|
||||
|
||||
tool_results_for_history: list[dict[str, Any]] = []
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
@@ -1891,11 +1894,11 @@ def _send_minimax(md_content: str, user_message: str, base_dir: str,
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
results = asyncio.run_coroutine_threadsafe(
|
||||
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax"),
|
||||
_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback),
|
||||
loop
|
||||
).result()
|
||||
except RuntimeError:
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax"))
|
||||
results = asyncio.run(_execute_tool_calls_concurrently(tool_calls_raw, base_dir, pre_tool_callback, qa_callback, round_idx, "minimax", patch_callback))
|
||||
|
||||
tool_results_for_history: list[dict[str, Any]] = []
|
||||
for i, (name, call_id, out, _) in enumerate(results):
|
||||
@@ -1962,6 +1965,23 @@ def run_tier4_analysis(stderr: str) -> str:
|
||||
return f"[QA ANALYSIS FAILED] {e}"
|
||||
|
||||
|
||||
def run_tier4_patch_callback(stderr: str, base_dir: str) -> Optional[str]:
|
||||
try:
|
||||
from src import project_manager
|
||||
file_items = project_manager.get_current_file_items()
|
||||
file_context = ""
|
||||
for item in file_items[:5]:
|
||||
path = item.get("path", "")
|
||||
content = item.get("content", "")[:2000]
|
||||
file_context += f"\n\nFile: {path}\n```\n{content}\n```\n"
|
||||
patch = run_tier4_patch_generation(stderr, file_context)
|
||||
if patch and "---" in patch and "+++" in patch:
|
||||
return patch
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
def run_tier4_patch_generation(error: str, file_context: str) -> str:
|
||||
if not error or not error.strip():
|
||||
return ""
|
||||
@@ -2032,32 +2052,33 @@ def send(
|
||||
qa_callback: Optional[Callable[[str], str]] = None,
|
||||
enable_tools: bool = True,
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
patch_callback: Optional[Callable[[str, str], Optional[str]]] = None,
|
||||
) -> str:
|
||||
with _send_lock:
|
||||
if _provider == "gemini":
|
||||
return _send_gemini(
|
||||
md_content, user_message, base_dir, file_items, discussion_history,
|
||||
pre_tool_callback, qa_callback, enable_tools, stream_callback
|
||||
pre_tool_callback, qa_callback, enable_tools, stream_callback, patch_callback
|
||||
)
|
||||
elif _provider == "gemini_cli":
|
||||
return _send_gemini_cli(
|
||||
md_content, user_message, base_dir, file_items, discussion_history,
|
||||
pre_tool_callback, qa_callback, stream_callback
|
||||
pre_tool_callback, qa_callback, stream_callback, patch_callback
|
||||
)
|
||||
elif _provider == "anthropic":
|
||||
return _send_anthropic(
|
||||
md_content, user_message, base_dir, file_items, discussion_history,
|
||||
pre_tool_callback, qa_callback, stream_callback=stream_callback
|
||||
pre_tool_callback, qa_callback, stream_callback=stream_callback, patch_callback=patch_callback
|
||||
)
|
||||
elif _provider == "deepseek":
|
||||
return _send_deepseek(
|
||||
md_content, user_message, base_dir, file_items, discussion_history,
|
||||
stream, pre_tool_callback, qa_callback, stream_callback
|
||||
stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
|
||||
)
|
||||
elif _provider == "minimax":
|
||||
return _send_minimax(
|
||||
md_content, user_message, base_dir, file_items, discussion_history,
|
||||
stream, pre_tool_callback, qa_callback, stream_callback
|
||||
stream, pre_tool_callback, qa_callback, stream_callback, patch_callback
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {_provider}")
|
||||
@@ -2079,16 +2100,16 @@ def _is_mutating_tool(name: str) -> bool:
|
||||
"""Returns True if the tool name is considered a mutating tool."""
|
||||
return name in mcp_client.MUTATING_TOOLS or name == TOOL_NAME
|
||||
|
||||
def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None) -> Optional[str]:
|
||||
def _confirm_and_run(script: str, base_dir: str, qa_callback: Optional[Callable[[str], str]] = None, patch_callback: Optional[Callable[[str, str], Optional[str]]] = None) -> Optional[str]:
|
||||
"""
|
||||
Wrapper for the confirm_and_run_callback.
|
||||
This is what the providers call to trigger HITL approval.
|
||||
"""
|
||||
if confirm_and_run_callback:
|
||||
return confirm_and_run_callback(script, base_dir, qa_callback)
|
||||
return confirm_and_run_callback(script, base_dir, qa_callback, patch_callback)
|
||||
# Fallback to direct execution if no callback registered (headless default)
|
||||
from src import shell_runner
|
||||
return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback)
|
||||
return shell_runner.run_powershell(script, base_dir, qa_callback=qa_callback, patch_callback=patch_callback)
|
||||
|
||||
def get_history_bleed_stats(md_content: Optional[str] = None) -> dict[str, Any]:
|
||||
if _provider == "anthropic":
|
||||
|
||||
Reference in New Issue
Block a user